# Train a Khmer Text Recognition Model (base model: microsoft/trocr-base-printed)

## Step 1. Generate image dataset for Khmer text recognition

### 1. Import libraries for image generation

In [None]:
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
import os
import random
import numpy as np
from sklearn.model_selection import train_test_split
import shutil
from IPython.display import FileLink

### 2. Load datasets

#### 2.1 Loading the word list

In [None]:
# Path of the newline-separated word list
dataset_path = 'combined_cleaned.txt'

# Keep every non-blank line, trimmed of surrounding whitespace
words = []
with open(dataset_path, 'r', encoding='utf-8') as f:
    for raw_line in f:
        cleaned = raw_line.strip()
        if cleaned:
            words.append(cleaned)

print(f"\n✓ Loaded {len(words)} words from {dataset_path}")
print(f"Sample words: {words[:5]}")

# Single-column DataFrame consumed by the train/valid/test split
df = pd.DataFrame({'word': words})
print(f"\nDataFrame shape: {df.shape}")

### 3. Generate text to images

#### 3.1. Define the function that renders text to an image

In [None]:
def gen_khmer_text_image(index, content, data_type, bg,
                        font_path, font_size, data_folder, padding=10):
    """
    Render a piece of Khmer text to a tightly cropped PNG image.

    The canvas is sized from the measured text bounding box plus ``padding``
    on every side, and the file is written to
    ``<data_folder>/<data_type>/<index as 6 digits>.png``.

    Args:
        index: Index number used for the filename
        content: The text to render
        data_type: 'train', 'valid', or 'test' (name of the output sub-folder)
        bg: Background color (R, G, B, A)
        font_path: Path to the font file
        font_size: Size of the font
        data_folder: Base folder for output
        padding: Padding around text (pixels)

    Returns:
        Filename (not the full path) of the generated image
    """
    # Load the font first so the text can be measured.
    # Only font-file problems fall back to the default font; a bare
    # ``except`` would also swallow KeyboardInterrupt.
    try:
        font = ImageFont.truetype(font_path, font_size)
    except OSError:
        print(f"Warning: Could not load font {font_path}, using default")
        font = ImageFont.load_default()

    # Measure the text on a throw-away 1x1 canvas
    scratch = ImageDraw.Draw(Image.new('RGBA', (1, 1)))
    left, top, right, bottom = scratch.textbbox((0, 0), content, font=font)
    text_width = right - left
    text_height = bottom - top

    # Canvas size = text size + padding; never smaller than 1x1 so that
    # empty text with padding=0 still yields a saveable image
    img_width = max(1, text_width + (padding * 2))
    img_height = max(1, text_height + (padding * 2))

    image = Image.new('RGBA', (img_width, img_height), bg)
    draw = ImageDraw.Draw(image)

    # The bounding box measured from (0, 0) usually does not start at the
    # origin, so shift the drawing origin by (-left, -top). This puts the
    # ink exactly inside the padded area instead of off-centre or clipped.
    draw.text((padding - left, padding - top), content,
              font=font, fill=(0, 0, 0, 255))

    # Filename with 6-digit zero-padded index
    filename = f"{index:06d}.png"

    # Create the output directory if it doesn't exist
    output_dir = os.path.join(data_folder, data_type)
    os.makedirs(output_dir, exist_ok=True)

    image.save(os.path.join(output_dir, filename))

    return filename


#### 3.2. Define the variant values for the function parameters

In [None]:
# Folder that holds the .ttf / .otf files used for rendering
fonts_dir = "fonts"
font_extensions = ('.ttf', '.otf')

if os.path.isdir(fonts_dir):
    # Compare the lower-cased name so '.TTF', '.Otf', ... are all accepted;
    # sort the full paths alphabetically for consistency
    fonts = sorted(
        os.path.join(fonts_dir, filename)
        for filename in os.listdir(fonts_dir)
        if filename.lower().endswith(font_extensions)
    )
else:
    print(f"Warning: '{fonts_dir}' folder not found!")
    fonts = []

if not fonts:
    print("ERROR: No font files found in 'fonts/' folder!")
    print("Please ensure .ttf or .otf font files are in the 'fonts/' directory")
    # Raise instead of calling exit(): inside a notebook exit() shuts the
    # kernel down, and as a script it would end with a success status
    raise FileNotFoundError(f"No .ttf/.otf font files found in '{fonts_dir}'")

print(f"\nDiscovered fonts:")
for font in fonts:
    print(f"  • {font}")

# Font sizes
font_sizes = [9,10,11,12,13,14,15,16]

# Background colors (R, G, B, A)
bg_colors = [
    (255, 255, 255, 255),
]

print(f"\n✓ {len(fonts)} fonts")
print(f"✓ {len(font_sizes)} font sizes")
print(f"✓ {len(bg_colors)} background colors")


#### 3.3 Splitting The Dataset: Train, Validation, Test

In [None]:
# Shuffle-split the words: 70% train, then halve the remaining 30%
# into validation and test (15% each)
print("\n2.3. Splitting the dataset...")

train, temp = train_test_split(df, random_state=42, test_size=0.3)
valid, test = train_test_split(temp, random_state=42, test_size=0.5)

# Renumber every split from 0 so the row index can be used for file numbering
train, valid, test = (
    part.reset_index(drop=True) for part in (train, valid, test)
)

for split_name, split_frame in (("Train", train), ("Validation", valid), ("Test", test)):
    print(f"✓ {split_name}: {len(split_frame)} words")


#### 3.4 Generating Text to Images

In [None]:
# Root folder that receives the per-split image folders and label files
data_folder = "data_v1"
os.makedirs(data_folder, exist_ok=True)

# One list of "filename<TAB>word" label lines per split
train_labels, valid_labels, test_labels = [], [], []

#### 3.5 Generating training images

In [None]:
# Render one image per training word with a randomly chosen font size,
# font and background, and record a "filename<TAB>word" label for each success.
n = len(train)
for i, (index, row) in enumerate(train.iterrows(), start=1):
    font_size = random.choice(font_sizes)
    font = random.choice(fonts)
    bg = random.choice(bg_colors)

    try:
        filename = gen_khmer_text_image(
            index=index + 1,  # file numbering starts at 000001
            content=row["word"],
            data_type="train",
            bg=bg,
            font_path=font,
            font_size=font_size,
            data_folder=data_folder
        )
    except Exception as e:
        # Best effort: report the failing word and carry on with the rest
        print(f"Error processing word '{row['word']}': {e}")
    else:
        train_labels.append(f"{filename}\t{row['word']}")

    # `i` counts every attempted row (failures included), so progress stays
    # accurate and the final "n of n" line is always printed
    if i % 100 == 0 or i == n:
        print(f"{i} of {n}: complete")

#### 3.6 Generating validation images

In [None]:
# Render one image per validation word with a randomly chosen font size,
# font and background, and record a "filename<TAB>word" label for each success.
n = len(valid)
for i, (index, row) in enumerate(valid.iterrows(), start=1):
    font_size = random.choice(font_sizes)
    font = random.choice(fonts)
    bg = random.choice(bg_colors)

    try:
        filename = gen_khmer_text_image(
            index=index + 1,  # file numbering starts at 000001
            content=row["word"],
            data_type="valid",
            bg=bg,
            font_path=font,
            font_size=font_size,
            data_folder=data_folder
        )
    except Exception as e:
        # Best effort: report the failing word and carry on with the rest
        print(f"Error processing word '{row['word']}': {e}")
    else:
        valid_labels.append(f"{filename}\t{row['word']}")

    # `i` counts every attempted row (failures included), so progress stays
    # accurate and the final "n of n" line is always printed
    if i % 100 == 0 or i == n:
        print(f"{i} of {n}: complete")


#### 3.7 Generating test images

In [None]:
# Render one image per test word with a randomly chosen font size,
# font and background, and record a "filename<TAB>word" label for each success.
n = len(test)
for i, (index, row) in enumerate(test.iterrows(), start=1):
    font_size = random.choice(font_sizes)
    font = random.choice(fonts)
    bg = random.choice(bg_colors)

    try:
        filename = gen_khmer_text_image(
            index=index + 1,  # file numbering starts at 000001
            content=row["word"],
            data_type="test",
            bg=bg,
            font_path=font,
            font_size=font_size,
            data_folder=data_folder
        )
    except Exception as e:
        # Best effort: report the failing word and carry on with the rest
        print(f"Error processing word '{row['word']}': {e}")
    else:
        test_labels.append(f"{filename}\t{row['word']}")

    # `i` counts every attempted row (failures included), so progress stays
    # accurate and the final "n of n" line is always printed
    if i % 100 == 0 or i == n:
        print(f"{i} of {n}: complete")

#### 3.8 Saving label files

In [None]:

# Write one label file per split: train.txt, valid.txt and test.txt.
# Lines are joined with '\n' and no trailing newline is written.
for label_file_name, label_lines in (
    ('train.txt', train_labels),
    ('valid.txt', valid_labels),
    ('test.txt', test_labels),
):
    with open(os.path.join(data_folder, label_file_name), 'w', encoding='utf-8') as f:
        f.write('\n'.join(label_lines))
    print(f"✓ Saved {label_file_name} ({len(label_lines)} entries)")

## Step 2: TrOCR Training for Khmer Text Recognition

### 1: Install Required Packages

In [None]:
!pip install -q transformers datasets pillow evaluate jiwer accelerate

### 2: Import Libraries

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import (
    TrOCRProcessor, 
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)
from datasets import load_metric
import numpy as np
from tqdm import tqdm
import pandas as pd

# Prefer CUDA when PyTorch reports it as available, otherwise use the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Report the name of the first GPU when one is present
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")


### 3: Custom Dataset Class

In [None]:
class KhmerTextRecognitionDataset(Dataset):
    """Dataset of (image, text) pairs for Khmer text recognition.

    Samples are read from a label file with one ``filename<TAB>text`` entry
    per line; images are loaded lazily from ``image_dir`` in ``__getitem__``.
    """

    def __init__(self, label_file, image_dir, processor, max_target_length=128):
        """
        Args:
            label_file: Path to the label file (train.txt, valid.txt, test.txt)
            image_dir: Directory containing the images
            processor: TrOCRProcessor instance
            max_target_length: Maximum length of target text (in tokens)
        """
        self.image_dir = image_dir
        self.processor = processor
        self.max_target_length = max_target_length

        # Load labels; blank lines are ignored, and lines that do not split
        # into exactly two tab-separated fields are skipped and counted
        self.samples = []
        skipped = 0
        with open(label_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('\t')
                if len(parts) != 2:
                    skipped += 1
                    continue
                filename, text = parts
                self.samples.append({'filename': filename, 'text': text})

        print(f"Loaded {len(self.samples)} samples from {label_file}")
        if skipped:
            # Make dropped samples visible instead of losing them silently
            print(f"Warning: skipped {skipped} malformed line(s) in {label_file}")

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{'pixel_values': ..., 'labels': ...}`` for sample ``idx``."""
        sample = self.samples[idx]

        # Load image
        image_path = os.path.join(self.image_dir, sample['filename'])
        try:
            image = Image.open(image_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Best effort: substitute a blank image so the batch can proceed.
            # NOTE(review): the sample keeps its real label, so a missing
            # file still contributes a (blank image, text) pair.
            image = Image.new('RGB', (384, 64), color='white')

        # Process image
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # Process text: pad/truncate to max_target_length tokens
        labels = self.processor.tokenizer(
            sample['text'],
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
            return_tensors="pt"
        ).input_ids

        # Replace padding token ids with -100 so they are ignored by the loss
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        # Drop only the leading batch axis added by return_tensors="pt";
        # a bare squeeze() would also collapse any other size-1 axis
        # (e.g. the sequence axis when max_target_length == 1)
        return {
            'pixel_values': pixel_values.squeeze(0),
            'labels': labels.squeeze(0)
        }


### 4: Setup Model and Processor

In [None]:
print("\n" + "="*60)
print("Loading TrOCR Model and Processor")
print("="*60)

# Processor and model both come from the same pretrained checkpoint
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")

# Move model to the selected device
model.to(device)

# Special token ids are taken from the processor's tokenizer
_trocr_tokenizer = processor.tokenizer
model.config.decoder_start_token_id = _trocr_tokenizer.cls_token_id
model.config.pad_token_id = _trocr_tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# Beam search parameters, assigned on the model config in this order
model.config.eos_token_id = _trocr_tokenizer.sep_token_id
for _option, _value in (
    ("max_length", 128),
    ("early_stopping", True),
    ("no_repeat_ngram_size", 3),
    ("length_penalty", 2.0),
    ("num_beams", 4),
):
    setattr(model.config, _option, _value)

_parameter_count = sum(p.numel() for p in model.parameters())
print(f"✓ Model loaded successfully")
print(f"✓ Model parameters: {_parameter_count:,}")


### 5: Load Datasets

In [None]:
print("\n" + "="*60)
print("Loading Datasets")
print("="*60)

# Define paths (adjust these to your data location)
data_folder = "data_v1"

# Label files: <data_folder>/<split>.txt
train_label_file, valid_label_file, test_label_file = (
    os.path.join(data_folder, f"{split}.txt")
    for split in ("train", "valid", "test")
)

# Image folders: <data_folder>/<split>
train_image_dir, valid_image_dir, test_image_dir = (
    os.path.join(data_folder, split)
    for split in ("train", "valid", "test")
)

# Create datasets in train, valid, test order (each prints its sample count)
train_dataset = KhmerTextRecognitionDataset(train_label_file, train_image_dir, processor)
valid_dataset = KhmerTextRecognitionDataset(valid_label_file, valid_image_dir, processor)
test_dataset = KhmerTextRecognitionDataset(test_label_file, test_image_dir, processor)

for _split_title, _split_dataset in (
    ("Train", train_dataset),
    ("Valid", valid_dataset),
    ("Test", test_dataset),
):
    print(f"✓ {_split_title} dataset: {len(_split_dataset)} samples")


### 6: Compute Metrics Function

In [None]:
def compute_metrics(pred):
    """Compute the mean CER (Character Error Rate) over an evaluation batch.

    Args:
        pred: Object exposing ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, with -100 marking padding).

    Returns:
        ``{"cer": value}`` where value is the average over samples of
        ``levenshtein(prediction, label) / len(label)``, computed after
        removing spaces from both strings. A sample with an empty label
        counts 0 when the prediction is empty too and 1 otherwise.
        ``nan`` is returned when there are no samples.
    """
    label_ids = pred.label_ids
    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Some models return (sequences, extras...); the ids come first
        pred_ids = pred_ids[0]

    # -100 cannot be decoded, so map it back to the pad token.
    # np.where builds new arrays, leaving pred.label_ids unmodified.
    pad_id = processor.tokenizer.pad_token_id
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(label_ids != -100, label_ids, pad_id)

    # Decode predictions and labels
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    if not pred_str:
        # Nothing to score; avoid a ZeroDivisionError
        return {"cer": float("nan")}

    total_cer = 0.0
    for predicted, reference in zip(pred_str, label_str):
        # Spaces are ignored on both sides
        predicted = predicted.replace(" ", "")
        reference = reference.replace(" ", "")

        if not reference:
            # Empty reference: any non-empty prediction is a full error
            total_cer += 1 if predicted else 0
        else:
            total_cer += levenshtein_distance(predicted, reference) / len(reference)

    return {"cer": total_cer / len(pred_str)}

def levenshtein_distance(s1, s2):
    """Return the edit distance (insert / delete / substitute) between s1 and s2.

    Uses the two-row dynamic-programming table, iterating over the longer
    sequence in the outer loop.
    """
    # Put the longer sequence on the outer loop (ties keep the given order)
    if len(s1) < len(s2):
        longer, shorter = s2, s1
    else:
        longer, shorter = s1, s2

    # Turning a sequence into nothing costs one deletion per element
    if len(shorter) == 0:
        return len(longer)

    prev_costs = list(range(len(shorter) + 1))
    for row_number, outer_char in enumerate(longer, start=1):
        row_costs = [row_number]
        for col, inner_char in enumerate(shorter):
            row_costs.append(min(
                prev_costs[col + 1] + 1,                      # insertion
                row_costs[col] + 1,                           # deletion
                prev_costs[col] + (outer_char != inner_char)  # substitution
            ))
        prev_costs = row_costs

    return prev_costs[-1]


### 7: Custom Trainer with Checkpoint Saving

In [None]:
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that additionally writes named model snapshots.

    After every regular checkpoint save, if the just-completed epoch is listed
    in ``checkpoint_epochs``, the model (and tokenizer/processor, if any) is
    also saved to ``<output_dir>/khmer-text-recognition-<epoch // 10>``.
    """

    def __init__(self, *args, checkpoint_epochs=(10, 20, 30, 40, 50), **kwargs):
        """
        Args:
            *args, **kwargs: Forwarded unchanged to ``Seq2SeqTrainer``.
            checkpoint_epochs: Epoch numbers at which a named snapshot is
                written. Copied into a list so the caller's object (and the
                default) is never shared or mutated.
        """
        super().__init__(*args, **kwargs)
        self.checkpoint_epochs = list(checkpoint_epochs)
        self.current_epoch = 0

    def _save_checkpoint(self, model, trial, *args, **kwargs):
        """Save the regular checkpoint, then a named one at selected epochs.

        Extra positional/keyword arguments (e.g. ``metrics``) are passed
        through untouched, because the parent signature differs between
        transformers releases.
        """
        super()._save_checkpoint(model, trial, *args, **kwargs)

        # The epoch is read from the trainer state here: Trainer does not
        # call on_epoch_end on itself, so a value tracked there never updates.
        epoch = self.state.epoch
        if epoch is None:
            return
        completed = int(round(epoch))
        if abs(epoch - completed) > 1e-3:
            # Save triggered in the middle of an epoch: no named snapshot
            return
        self.current_epoch = completed

        if self.current_epoch not in self.checkpoint_epochs:
            return

        checkpoint_name = f"khmer-text-recognition-{self.current_epoch // 10}"
        output_dir = os.path.join(self.args.output_dir, checkpoint_name)

        print(f"\n{'='*60}")
        print(f"Saving checkpoint at epoch {self.current_epoch}: {checkpoint_name}")
        print(f"{'='*60}")

        self.model.save_pretrained(output_dir)

        # Newer transformers releases expose the tokenizer/processor as
        # `processing_class`; older ones as `tokenizer`
        text_processor = getattr(self, "processing_class", None)
        if text_processor is None:
            text_processor = getattr(self, "tokenizer", None)
        if text_processor is not None:
            text_processor.save_pretrained(output_dir)

        print(f"✓ Saved model to {output_dir}")

    def training_step(self, model, inputs, *args, **kwargs):
        """Delegate to the parent implementation.

        Extra arguments are forwarded so the override keeps working when the
        parent passes additional parameters.
        """
        return super().training_step(model, inputs, *args, **kwargs)

    def on_epoch_end(self, args, state, control, **kwargs):
        """Record the epoch from ``state`` and return ``control`` unchanged.

        Kept for backward compatibility only. This name is a TrainerCallback
        hook; it is not expected to be invoked on the trainer itself
        (NOTE(review): confirm against the installed transformers version),
        so epoch tracking is done in ``_save_checkpoint``.
        """
        if state.epoch is not None:
            self.current_epoch = int(round(state.epoch))
        return control


### 8: Training Arguments

In [None]:
print("\n" + "="*60)
print("Setting Up Training Configuration")
print("="*60)

# All training options gathered in one mapping, grouped by purpose
_training_options = {
    "output_dir": "./khmer-trocr-checkpoints",

    # Training parameters
    "num_train_epochs": 50,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,

    # Optimization
    "learning_rate": 5e-5,
    "weight_decay": 0.01,
    "warmup_steps": 500,

    # Evaluate and save once per epoch, keeping at most 5 checkpoints
    "eval_strategy": "epoch",
    "save_strategy": "epoch",
    "save_total_limit": 5,

    # Logging
    "logging_dir": "./logs",
    "logging_steps": 100,
    "logging_strategy": "steps",

    # Generation during evaluation
    "predict_with_generate": True,
    "generation_max_length": 128,
    "generation_num_beams": 4,

    # Performance: mixed precision only when CUDA is available
    "fp16": torch.cuda.is_available(),
    "dataloader_num_workers": 2,

    # Model selection: lowest CER wins
    "load_best_model_at_end": True,
    "metric_for_best_model": "cer",
    "greater_is_better": False,
    "push_to_hub": False,
    "report_to": "none",
}

training_args = Seq2SeqTrainingArguments(**_training_options)

print("✓ Training arguments configured")
print(f"  • Epochs: {training_args.num_train_epochs}")
print(f"  • Batch size: {training_args.per_device_train_batch_size}")
print(f"  • Learning rate: {training_args.learning_rate}")
print(f"  • FP16: {training_args.fp16}")


### 9: Initialize Trainer

In [None]:
print("\n" + "="*60)
print("Initializing Trainer")
print("="*60)

# Wire model, data, metrics and the named-checkpoint epochs into the trainer
_trainer_options = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "eval_dataset": valid_dataset,
    "tokenizer": processor.tokenizer,
    "data_collator": default_data_collator,
    "compute_metrics": compute_metrics,
    "checkpoint_epochs": [10, 20, 30, 40, 50],
}
trainer = CustomSeq2SeqTrainer(**_trainer_options)

print("✓ Trainer initialized successfully")


### 10: Start Training

In [None]:
# Announce the run
for _banner_line in (
    "\n" + "="*60,
    "STARTING TRAINING",
    "="*60,
    "This will take several hours depending on your dataset size and GPU",
    "Checkpoints will be saved at epochs: 10, 20, 30, 40, 50",
    "="*60 + "\n",
):
    print(_banner_line)

# Train the model
trainer.train()

for _banner_line in ("\n" + "="*60, "TRAINING COMPLETED!", "="*60):
    print(_banner_line)


### 11: Save Final Best Model

In [None]:
print("\n" + "="*60)
print("Saving Final Best Model")
print("="*60)

# Model weights first, then the processor, both into the same folder
best_model_dir = "./khmer-text-recognition-best"
for _artifact in (model, processor):
    _artifact.save_pretrained(best_model_dir)

print(f"✓ Best model saved to {best_model_dir}")


### 12: Evaluate on Test Set

In [None]:
print("\n" + "="*60)
print("Evaluating on Test Set")
print("="*60)

# Run the trainer's evaluation loop on the held-out test split
test_results = trainer.evaluate(test_dataset)

# Print every returned metric with four decimals
print("\nTest Results:")
for metric_name in test_results:
    metric_value = test_results[metric_name]
    print(f"  • {metric_name}: {metric_value:.4f}")


### 13: Test Predictions

In [None]:
print("\n" + "="*60)
print("Sample Predictions")
print("="*60)

# Show predictions for the first few test samples
model.eval()
num_samples = 5
_preview_count = min(num_samples, len(test_dataset))

for idx in range(_preview_count):
    sample = test_dataset[idx]

    with torch.no_grad():
        # Add the batch axis and generate with the model's configured settings
        batch = sample['pixel_values'].unsqueeze(0).to(device)
        generated_ids = model.generate(batch)
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
        generated_text = decoded[0]

        # Ground truth: restore pad ids where -100 was used, then decode
        labels = sample['labels'].cpu().numpy()
        labels[labels == -100] = processor.tokenizer.pad_token_id
        ground_truth = processor.tokenizer.decode(labels, skip_special_tokens=True)

        if generated_text == ground_truth:
            match_mark = '✓'
        else:
            match_mark = '✗'

        print(f"\nSample {idx + 1}:")
        print(f"  Ground Truth: {ground_truth}")
        print(f"  Prediction:   {generated_text}")
        print(f"  Match: {match_mark}")


#### 13.5: Generate Training Graphs

In [None]:
import matplotlib.pyplot as plt
import json

print("\n" + "="*60)
print("Generating Training Graphs")
print("="*60)

# Read training logs
log_history = trainer.state.log_history

# Extract metrics
train_steps = []
train_loss = []
eval_loss = []
eval_cer = []
epochs = []
seen_eval_steps = set()

for log in log_history:
    if 'epoch' not in log:
        continue
    if 'loss' in log:
        # Keep the real optimizer step so the x-axis really is "Steps"
        train_steps.append(log.get('step', len(train_loss)))
        train_loss.append(log['loss'])
    if 'eval_loss' in log:
        # Trainer.evaluate() also appends its metrics to log_history, so the
        # test-set evaluation run earlier shows up as a second 'eval_*' entry
        # at the final step. Keep only the first entry per step so test
        # results do not leak into the validation curves or the "best" values.
        step = log.get('step')
        if step is not None and step in seen_eval_steps:
            continue
        seen_eval_steps.add(step)
        eval_loss.append(log['eval_loss'])
        eval_cer.append(log.get('eval_cer', 0))
        epochs.append(log['epoch'])


def _style_axis(ax, title, xlabel, ylabel):
    """Apply the shared title, axis labels, grid and background to one subplot."""
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    ax.set_facecolor('#F8F9FA')


# Create figure with subplots
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Khmer Text Recognition Training Metrics', fontsize=16, fontweight='bold')

# Plot 1: Training Loss (against the logged global step)
if train_loss:
    axes[0, 0].plot(train_steps, train_loss, color='#2E86AB', linewidth=2)
    _style_axis(axes[0, 0], 'Training Loss', 'Steps', 'Loss')

# Plot 2: Validation Loss
if eval_loss and epochs:
    axes[0, 1].plot(epochs, eval_loss, color='#A23B72', linewidth=2, marker='o')
    _style_axis(axes[0, 1], 'Validation Loss', 'Epoch', 'Loss')

# Plot 3: Character Error Rate (CER)
if eval_cer and epochs:
    axes[1, 0].plot(epochs, eval_cer, color='#F18F01', linewidth=2, marker='s')
    _style_axis(axes[1, 0], 'Character Error Rate (CER)', 'Epoch', 'CER')

    # Add best CER annotation
    best_cer = min(eval_cer)
    best_epoch = epochs[eval_cer.index(best_cer)]
    axes[1, 0].axhline(y=best_cer, color='green', linestyle='--', alpha=0.5)
    axes[1, 0].text(0.02, 0.98, f'Best CER: {best_cer:.4f}\nEpoch: {best_epoch:.0f}',
                   transform=axes[1, 0].transAxes,
                   verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Plot 4: Accuracy (1 - CER)
if eval_cer and epochs:
    accuracy = [1 - cer for cer in eval_cer]
    axes[1, 1].plot(epochs, accuracy, color='#06A77D', linewidth=2, marker='D')
    _style_axis(axes[1, 1], 'Recognition Accuracy (1 - CER)', 'Epoch', 'Accuracy')
    axes[1, 1].set_ylim([0, 1])

    # Add best accuracy annotation
    best_acc = max(accuracy)
    best_epoch_acc = epochs[accuracy.index(best_acc)]
    axes[1, 1].axhline(y=best_acc, color='green', linestyle='--', alpha=0.5)
    axes[1, 1].text(0.02, 0.02, f'Best Accuracy: {best_acc:.4f}\nEpoch: {best_epoch_acc:.0f}',
                   transform=axes[1, 1].transAxes,
                   bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))

plt.tight_layout()

# Save the plot
plot_filename = 'training_metrics.png'
plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
print(f"✓ Saved training graphs to {plot_filename}")

# Display the plot
plt.show()

# Save metrics to JSON for later reference
metrics_data = {
    'epochs': epochs,
    'train_steps': train_steps,
    'train_loss': train_loss,
    'eval_loss': eval_loss,
    'eval_cer': eval_cer,
    'accuracy': [1 - cer for cer in eval_cer] if eval_cer else []
}

with open('training_metrics.json', 'w', encoding='utf-8') as f:
    json.dump(metrics_data, f, indent=2)
print(f"✓ Saved metrics data to training_metrics.json")

print("\n" + "="*60)


### 14: Save Training Summary

In [None]:
print("\n" + "="*60)
print("Training Summary")
print("="*60)

# Collect the headline numbers of the run in one place
summary = {
    # Read from the configuration instead of repeating the literal 50,
    # so the summary stays correct when the epoch count is changed
    'total_epochs': training_args.num_train_epochs,
    'train_samples': len(train_dataset),
    'valid_samples': len(valid_dataset),
    'test_samples': len(test_dataset),
    # 'N/A' when the test evaluation did not report a CER
    'final_test_cer': test_results.get('eval_cer', 'N/A'),
    'model_name': 'microsoft/trocr-base-printed',
    'best_model_location': best_model_dir
}

print("\nFinal Summary:")
for key, value in summary.items():
    print(f"  • {key}: {value}")


### 15: Zip and Automatically Download All Checkpoints

In [None]:
import shutil
from google.colab import files
import time

print("\n" + "="*60)
print("Automatic Model Download")
print("="*60)

# (source folder, archive base name) for the five named checkpoints,
# followed by the final best model
models_to_download = [
    (f"./khmer-trocr-checkpoints/khmer-text-recognition-{number}",
     f"khmer-text-recognition-{number}")
    for number in range(1, 6)
]
models_to_download.append((best_model_dir, "khmer-text-recognition-best"))

print("\nStarting automatic download process...")
print("Zipping and downloading all models...")
print("This may take several minutes depending on model size.\n")

successful_downloads = []
failed_downloads = []
total_models = len(models_to_download)

for idx, (model_path, zip_name) in enumerate(models_to_download, 1):
    print(f"[{idx}/{total_models}] Processing {zip_name}...")

    # Folders that were never created are recorded as failures and skipped
    if not os.path.exists(model_path):
        failed_downloads.append((zip_name, "Path not found"))
        print(f"  Warning: {model_path} not found, skipping...\n")
        continue

    archive_path = f"{zip_name}.zip"
    try:
        # Zip the folder into the current working directory
        print(f"  Creating {archive_path}...")
        shutil.make_archive(zip_name, 'zip', model_path)

        # Report the archive size in MB
        size_mb = os.path.getsize(archive_path) / (1024*1024)
        print(f"  Size: {size_mb:.2f} MB")

        # Hand the archive to the browser
        print(f"  Downloading {archive_path}...")
        files.download(archive_path)

        successful_downloads.append(zip_name)
        print(f"  {zip_name} downloaded successfully!\n")

        # Small delay between downloads to prevent issues
        time.sleep(2)

    except Exception as e:
        failed_downloads.append((zip_name, str(e)))
        print(f"  Error downloading {zip_name}: {e}\n")

# Summary
print("="*60)
print("DOWNLOAD SUMMARY")
print("="*60)
print(f"\nSuccessfully downloaded: {len(successful_downloads)}/{total_models} models")
for model_name in successful_downloads:
    print(f"   • {model_name}.zip")

if not failed_downloads:
    print("\nAll models downloaded successfully!")
else:
    print(f"\nFailed downloads: {len(failed_downloads)}")
    for model_name, error in failed_downloads:
        print(f"   • {model_name}: {error}")
print("\nCheck your browser's Downloads folder for the zip files.")
print("="*60)
