In [1]:
import pandas as pd
import random
from sklearn.model_selection import train_test_split
from cls_global import gen_khmer_text_image

# ============================================
# PART 1: Load data from datasets/all_cleaned_words.txt
# ============================================
import sys

# Plain-text word list: one word per line, UTF-8.
words_path = "datasets/all_cleaned_words.txt"

print(f"Loading data from {words_path}...")
try:
    with open(words_path, 'r', encoding='utf-8') as file:
        content = file.read()
except FileNotFoundError:
    # sys.exit(str) prints the message and exits with a non-zero status,
    # unlike a bare exit() which reports success.
    sys.exit(f"Error: {words_path} not found!")
except Exception as e:
    sys.exit(f"Error loading {words_path}: {e}")

# splitlines() also handles '\r\n' line endings. Each word is stripped so stray
# surrounding whitespace never ends up in a rendered image or in its label;
# blank lines are dropped.
words = [line.strip() for line in content.splitlines()]
data = pd.DataFrame({'word': [word for word in words if word]})
data['category'] = "Text Process"

print(f"Loaded {len(data)} entries from {words_path}")

# `data` is built fresh, so its index is already 0..n-1; that index (+1) is
# later used as the image file number for every word.

# Print statistics
print("\n=== Data Summary ===")
print(f"Total Data: {len(data)}")

# ============================================
# PART 2: Font and Background Variants
# ============================================
import os
import glob
import sys

# Automatically load all .ttf fonts from the fonts folder. glob() returns files
# in filesystem-dependent order, so the list is sorted to make any
# random.choice(fonts) sequence reproducible for a fixed random seed.
fonts_folder = "fonts"
fonts = sorted(glob.glob(os.path.join(fonts_folder, "*.ttf")))

if not fonts:
    # Nothing can be rendered without a font: stop with a non-zero status.
    sys.exit(f"Error: No .ttf fonts found in '{fonts_folder}' folder!")

print("\n=== Fonts Loaded ===")
print(f"Found {len(fonts)} fonts:")
for font in fonts:
    print(f"  - {os.path.basename(font)}")

# Rendering options; the generation step samples one of each per image.
font_sizes = [12, 16]

# Background colours as 4-tuples (presumably RGBA) - currently opaque white only.
bg_colors = [
    (255, 255, 255, 255),
]

# Forwarded verbatim to gen_khmer_text_image; their meaning is defined in
# cls_global.
noise_levels = ["low", "none"]
blur_levels = [0]

# ============================================
# PART 3: Train/Valid/Test Split
# ============================================
import os

data_folder = "data_v1"

# Seed the sampler behind the per-image font/size/background/noise/blur choices
# so re-running this cell renders the same dataset (the split itself is already
# pinned by random_state=42).
random.seed(42)


def _generate_split_images(split_df, data_type):
    """Render one image per row of ``split_df`` for the ``data_type`` subset.

    Every row gets an independently sampled font size, font, background,
    noise level and blur level. ``index + 1`` (the row's label in ``data``) is
    passed as the image index; the label files written by
    ``_save_split_labels`` assume the image lands at
    ``<data_type>/<index + 1>.png`` - presumably how ``gen_khmer_text_image``
    (cls_global) names its output; confirm there.
    """
    total = len(split_df)
    for i, (index, row) in enumerate(split_df.iterrows(), 1):
        font_size = random.choice(font_sizes)
        font = random.choice(fonts)
        bg = random.choice(bg_colors)
        noise_level = random.choice(noise_levels)
        blur_level = random.choice(blur_levels)

        gen_khmer_text_image(
            index=index + 1,
            content=row["word"],
            data_type=data_type,
            bg=bg,
            noise_level=noise_level,
            blur_level=blur_level,
            font_path=font,
            font_size=font_size,
            data_folder=data_folder,
        )
        if i % 100 == 0 or i == total:
            print(f"{i} of {total}: complete")


def _save_split_labels(split_df, data_type, description):
    """Write ``<data_folder>/<data_type>.txt``: one tab-separated image-path/word pair per line."""
    label_lines = [
        f"{data_type}/{index + 1}.png\t{word}"
        for index, word in split_df["word"].items()
    ]
    label_path = f"{data_folder}/{data_type}.txt"
    with open(label_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(label_lines))
    print(f"Saved {len(label_lines)} {description} labels to {label_path}")


if len(data) > 0:
    train_valid, test = train_test_split(
        data,
        test_size=0.2,
        stratify=data["category"],
        random_state=42
    )
    train, valid = train_test_split(
        train_valid,
        test_size=0.1,
        stratify=train_valid["category"],
        random_state=42
    )

    print("\n=== Split Summary ===")
    print(f"Train: {len(train)}")
    print(f"Valid: {len(valid)}")
    print(f"Test: {len(test)}")

    # (subset, folder / label-file name, heading word, label-summary word)
    splits = [
        (train, "train", "Training", "training"),
        (valid, "valid", "Validation", "validation"),
        (test, "test", "Testing", "test"),
    ]

    # ============================================
    # PART 4: Generate Images
    # ============================================
    for split_df, data_type, heading, _ in splits:
        print(f"\n=== Generating {heading} Images ===")
        _generate_split_images(split_df, data_type)

    print("\n=== Image Generation Complete ===")

    # ============================================
    # PART 5: Save Train/Valid/Test Labels
    # ============================================
    print("\n=== Saving Label Files ===")
    # Labels are written only after every image exists, so an interrupted run
    # never leaves label files pointing at missing images.
    os.makedirs(data_folder, exist_ok=True)
    for split_df, data_type, _, description in splits:
        _save_split_labels(split_df, data_type, description)

else:
    print("No data available for splitting and image generation.")

Loading data from text_process.text...
Loaded 127491 entries from text_process.text

=== Data Summary ===
Total Data: 127491

=== Fonts Loaded ===
Found 15 fonts:
  - KhmerDigital-Black.ttf
  - KhmerDigital-Bold.ttf
  - KhmerDigital-ExtraBold.ttf
  - KhmerDigital-ExtraLight.ttf
  - KhmerDigital-Light.ttf
  - KhmerDigital-Medium.ttf
  - KhmerDigital-Regular.ttf
  - KhmerDigital-SemiBold.ttf
  - KhmerDigital-Thin.ttf
  - KhmerDigitalMax.ttf
  - KhmerDigitalNumber.ttf
  - KhmerDigitalNumberMax.ttf
  - KhmerMPTC.ttf
  - KhmerOS_muollight.ttf
  - KhmerOS_siemreap.ttf

=== Split Summary ===
Train: 91792
Valid: 10200
Test: 25499

=== Generating Training Images ===
100 of 91792: complete
200 of 91792: complete
300 of 91792: complete
400 of 91792: complete
500 of 91792: complete
600 of 91792: complete
700 of 91792: complete
800 of 91792: complete
900 of 91792: complete
1000 of 91792: complete
1100 of 91792: complete
1200 of 91792: complete
1300 of 91792: complete
1400 of 91792: complete
1500 of

KeyboardInterrupt: 

In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from torch.optim import AdamW
from tqdm import tqdm
from jiwer import cer, wer
import matplotlib.pyplot as plt
from torchvision.transforms import Resize, ToTensor, Normalize, Compose, RandomRotation, ToPILImage
import shutil

# ============================================
# PART 1: Dataset Class
# ============================================
class KhmerTextDataset(Dataset):
    """Map-style dataset pairing rendered word images with tokenized targets.

    Column 0 of ``dataframe`` is an image path relative to ``root_dir`` and
    column 1 is the target text. Each item is a dict with ``pixel_values``
    (produced by ``transform`` when one is given, otherwise by ``processor``)
    and ``labels`` (token ids padded/truncated to ``max_target_length``).
    """

    def __init__(self, dataframe, root_dir, processor, transform=None, max_target_length=128):
        self.dataframe = dataframe
        self.root_dir = root_dir
        self.processor = processor
        self.transform = transform
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.dataframe)

    def _load_pixels(self, relative_path):
        """Open one image as RGB and turn it into the model's pixel tensor."""
        rgb = Image.open(os.path.join(self.root_dir, relative_path)).convert("RGB")
        if not self.transform:
            # No custom pipeline: fall back to the processor's own preprocessing.
            return self.processor(images=rgb, return_tensors="pt").pixel_values.squeeze()
        return self.transform(rgb)

    def __getitem__(self, idx):
        pixels = self._load_pixels(self.dataframe.iloc[idx, 0])
        target_text = self.dataframe.iloc[idx, 1]

        encoded = self.processor.tokenizer(
            target_text,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
        )
        return {"pixel_values": pixels, "labels": torch.tensor(encoded.input_ids)}

# ============================================
# PART 2: Helper Functions
# ============================================
def load_dataset(file_path):
    """Load a tab-separated ``image<TAB>label`` file (no header row).

    Every field is read verbatim as text:

    * pandas' default NA detection is disabled - otherwise words such as
      "NA", "NULL" or "nan" become NaN floats and later crash the tokenizer;
    * quote handling is switched off - otherwise a word containing ``"``
      starts a quoted field that can swallow the following tabs/newlines.

    Args:
        file_path: Path to the label file.

    Returns:
        DataFrame with string columns ``image`` and ``label``.
    """
    import csv

    data = pd.read_csv(
        file_path,
        sep="\t",
        header=None,
        names=["image", "label"],
        dtype=str,
        keep_default_na=False,
        quoting=csv.QUOTE_NONE,
    )
    return data

def create_dataloader(data, root_dir, processor, batch_size=16, shuffle=True, max_length=128, transform=None, data_type="train"):
    """Build a DataLoader over one subset of the generated data.

    Images are looked up under ``root_dir/data_type`` (e.g. ``data_v1/train``),
    wrapped in a ``KhmerTextDataset`` and batched with ``batch_size``.
    ``max_length`` is forwarded as the dataset's ``max_target_length``.
    """
    subset_dir = os.path.join(root_dir, data_type)
    subset = KhmerTextDataset(
        data,
        subset_dir,
        processor,
        transform=transform,
        max_target_length=max_length,
    )
    return DataLoader(subset, batch_size=batch_size, shuffle=shuffle)

def save_checkpoint(model, processor, optimizer, epoch, step, checkpoint_dir="checkpoint_latest"):
    """Save the latest training checkpoint, replacing any previous one.

    The model and processor are written with ``save_pretrained``. The optimizer
    state plus ``epoch``/``step`` go to ``training_state.pt``. Everything is
    first written to a sibling ``<checkpoint_dir>.tmp`` folder and only swapped
    in once complete. A crash or disconnect mid-save therefore leaves the
    previous checkpoint intact instead of an empty or partial folder.

    When running in Google Colab, the checkpoint is also zipped and offered as
    a browser download. That download is best-effort: a failure is reported
    and training continues, because the checkpoint is already on disk.

    Args:
        model: Object exposing ``save_pretrained(directory)``.
        processor: Object exposing ``save_pretrained(directory)``.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Epoch number recorded in the checkpoint.
        step: Global step recorded in the checkpoint.
        checkpoint_dir: Destination folder (overwritten on every call).
    """
    # normpath strips a trailing separator so the staging folder is a sibling
    # ("ckpt.tmp"), not a child ("ckpt/.tmp") that would be deleted with it.
    final_dir = os.path.normpath(checkpoint_dir)
    staging_dir = final_dir + ".tmp"

    # Clear leftovers from a previously interrupted save.
    if os.path.exists(staging_dir):
        shutil.rmtree(staging_dir)
    os.makedirs(staging_dir)

    # Save model and processor
    model.save_pretrained(staging_dir)
    processor.save_pretrained(staging_dir)

    # Save optimizer state and training info
    torch.save({
        'epoch': epoch,
        'step': step,
        'optimizer_state_dict': optimizer.state_dict(),
    }, os.path.join(staging_dir, 'training_state.pt'))

    # The new checkpoint is complete: only now replace the old one.
    if os.path.exists(final_dir):
        shutil.rmtree(final_dir)
    os.replace(staging_dir, final_dir)

    print(f"\n✓ Checkpoint saved: Epoch {epoch}, Step {step}")

    # Auto-download if running in Google Colab
    try:
        from google.colab import files
    except ImportError:
        # Not in Colab, skip download
        print(f"✓ Checkpoint saved to: {checkpoint_dir}")
        return

    import zipfile

    zip_filename = f"checkpoint_epoch{epoch}_step{step}.zip"
    try:
        with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, files_list in os.walk(final_dir):
                for file in files_list:
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, final_dir)
                    zipf.write(file_path, arcname)

        print(f"✓ Checkpoint zipped: {zip_filename}")
        print(f"⬇ Downloading checkpoint...")
        files.download(zip_filename)
        print(f"✓ Download complete!")
    except Exception as exc:
        # Best-effort convenience: a failed download must not kill training.
        print(f"⚠ Checkpoint download failed ({exc}); checkpoint is still saved to: {checkpoint_dir}")
    finally:
        # Clean up the zip whether or not the download succeeded.
        if os.path.exists(zip_filename):
            os.remove(zip_filename)

# ============================================
# PART 3: Configuration
# ============================================
# Batch size and maximum label length (in tokens) shared by every DataLoader.
batch_size, max_length = 16, 128
# Root folder written by the image-generation cell (train/, valid/, test/ + label files).
data_path = "data_v1"
epochs = 20
# One checkpoint folder, overwritten at the end of every epoch.
checkpoint_dir = "checkpoint_latest"

# ============================================
# PART 4: Load Datasets
# ============================================
print("Loading datasets...")
# One label file per subset, loaded in train / valid / test order.
train_data, valid_data, test_data = (
    load_dataset(f"{data_path}/{subset}.txt") for subset in ("train", "valid", "test")
)

print("\n=== Dataset Summary ===")
for _caption, _frame in (("Train", train_data), ("Valid", valid_data), ("Test", test_data)):
    print(f"{_caption}: {len(_frame)} samples")
print("\nSample train data:")
print(train_data.head())

# ============================================
# PART 5: Load Model and Processor
# ============================================
print("\n=== Loading TrOCR Model ===")
pretrained_name = "microsoft/trocr-large-printed"
processor = TrOCRProcessor.from_pretrained(pretrained_name)
model = VisionEncoderDecoderModel.from_pretrained(pretrained_name)

# Tie the decoder's special tokens to the tokenizer: decoding starts from the
# CLS token, padding uses the tokenizer's pad id, and SEP ends a sequence.
trocr_tokenizer = processor.tokenizer
model.config.decoder_start_token_id = trocr_tokenizer.cls_token_id
model.config.pad_token_id = trocr_tokenizer.pad_token_id
model.config.eos_token_id = trocr_tokenizer.sep_token_id

print("Model loaded successfully!")
parameter_count = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {parameter_count:,}")

# ============================================
# PART 6: Create DataLoaders
# ============================================
print("\n=== Creating DataLoaders ===")
# Training pipeline: resize to the ViT input size, apply a slight random
# rotation as augmentation, then scale pixels to [-1, 1]. The rotation fills
# exposed corners with white (255), matching the white backgrounds used by the
# generator, instead of torchvision's default black wedges.
train_transform = Compose([
    Resize((384, 384)),  # Resize to match ViT input size
    RandomRotation(degrees=5, fill=255),  # Add slight rotation
    ToTensor(),  # Convert to PyTorch Tensor
    Normalize(mean=[0.5], std=[0.5])  # Normalize pixel values
])
# Evaluation pipeline: identical preprocessing without random augmentation, so
# validation/test metrics are deterministic and measured on the clean images.
eval_transform = Compose([
    Resize((384, 384)),
    ToTensor(),
    Normalize(mean=[0.5], std=[0.5])
])
# Kept so any later code referring to `transform` still gets the training pipeline.
transform = train_transform

train_loader = create_dataloader(
    train_data, f"{data_path}/", processor,
    batch_size=batch_size, max_length=max_length, transform=train_transform, data_type="train"
)
# Evaluation loaders are not shuffled: metrics are order-independent and a
# fixed order makes the printed sample predictions reproducible.
valid_loader = create_dataloader(
    valid_data, f"{data_path}/", processor,
    batch_size=batch_size, shuffle=False, max_length=max_length, transform=eval_transform, data_type="valid"
)
test_loader = create_dataloader(
    test_data, f"{data_path}/", processor,
    batch_size=batch_size, shuffle=False, max_length=max_length, transform=eval_transform, data_type="test"
)

print(f"Train batches: {len(train_loader)}")
print(f"Valid batches: {len(valid_loader)}")
print(f"Test batches: {len(test_loader)}")

# ============================================
# PART 7: Visualize Sample Batch
# ============================================
print("\n=== Visualizing Sample Batch ===")
reverse_transform = ToPILImage()

for i, batch in enumerate(train_loader):
    print(f"\nBatch {i + 1}:")
    print("Pixel Values Shape:", batch["pixel_values"].shape)
    print("Labels Shape:", batch["labels"].shape)

    # Show first image in batch
    label = batch["labels"][0]
    decoded_label = processor.tokenizer.decode(label.tolist(), skip_special_tokens=True)
    print(f"Decoded Label: {decoded_label}")

    # Undo Normalize(mean=0.5, std=0.5): the tensor is in [-1, 1], but
    # ToPILImage expects floats in [0, 1] (it scales by 255 and casts to
    # uint8), so negative values would otherwise render as garbage.
    pixel_values = batch["pixel_values"][0]
    image = reverse_transform((pixel_values * 0.5 + 0.5).clamp(0, 1))

    plt.figure(figsize=(8, 4))
    plt.imshow(image)
    plt.title(f"Label: {decoded_label}")
    plt.axis("off")
    plt.show()

    if i == 2:  # Show only first 3 batches
        break

# ============================================
# PART 8: Training Setup
# ============================================
print("\n=== Training Setup ===")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

# Per-epoch metric history, plotted after training.
training_losses, validation_losses = [], []
cer_scores, wer_scores = [], []

# Optimizer steps taken across all epochs (recorded in every checkpoint).
global_step = 0

# ============================================
# PART 9: Training Loop with Checkpointing
# ============================================
print("\n=== Starting Training ===")
print(f"Checkpoints will be saved and downloaded after each epoch\n")

# The dataset pads every label to max_length with the tokenizer's pad id.
# Those positions must not count towards the loss: Hugging Face seq2seq models
# ignore targets equal to -100, so padded ids are replaced with -100 before
# each forward pass. Without this, most of the loss comes from predicting
# padding.
pad_token_id = processor.tokenizer.pad_token_id

for epoch in range(epochs):
    # Training phase
    model.train()
    total_loss = 0

    print(f"\nEpoch {epoch + 1}/{epochs} - Training")
    for batch in tqdm(train_loader, desc="Training"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"]
        labels = labels.masked_fill(labels == pad_token_id, -100).to(device)

        # Forward pass
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Increment global step
        global_step += 1

    avg_train_loss = total_loss / len(train_loader)
    training_losses.append(avg_train_loss)
    print(f"Epoch {epoch + 1}/{epochs}, Training Loss: {avg_train_loss:.4f}, Global Step: {global_step}")

    # Validation phase
    model.eval()
    val_loss = 0
    all_predictions = []
    all_references = []

    with torch.no_grad():
        print(f"Epoch {epoch + 1}/{epochs} - Validation")
        for batch in tqdm(valid_loader, desc="Validation"):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"]  # unmasked pad ids, kept on CPU for decoding
            is_pad = labels == pad_token_id

            # Forward pass (same -100 masking as training, so losses are comparable)
            outputs = model(pixel_values=pixel_values, labels=labels.masked_fill(is_pad, -100).to(device))
            val_loss += outputs.loss.item()

            # Teacher-forced predictions: argmax(logits)[:, t] is compared with
            # labels[:, t]. Positions that are padding in the reference are no
            # longer trained, so blank them to pad before decoding rather than
            # letting arbitrary tokens there inflate CER/WER.
            # NOTE: this is a teacher-forced metric; the test phase uses
            # model.generate() for true free-running decoding.
            predicted_ids = torch.argmax(outputs.logits, dim=-1).cpu()
            predicted_ids = predicted_ids.masked_fill(is_pad, pad_token_id)
            predictions = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            references = processor.batch_decode(labels, skip_special_tokens=True)

            all_predictions.extend(predictions)
            all_references.extend(references)

    avg_val_loss = val_loss / len(valid_loader)
    validation_losses.append(avg_val_loss)

    # Calculate CER and WER
    cer_score = cer(all_references, all_predictions)
    wer_score = wer(all_references, all_predictions)
    cer_scores.append(cer_score)
    wer_scores.append(wer_score)

    print(f"Epoch {epoch + 1}/{epochs}, Validation Loss: {avg_val_loss:.4f}")
    print(f"Epoch {epoch + 1}/{epochs}, CER: {cer_score:.4f}, WER: {wer_score:.4f}")

    # Save checkpoint at the end of each epoch
    print(f"\n{'='*60}")
    print(f"Saving checkpoint for Epoch {epoch + 1}")
    print(f"{'='*60}")
    save_checkpoint(model, processor, optimizer, epoch + 1, global_step, checkpoint_dir)

# ============================================
# PART 10: Plot Training Results
# ============================================
print("\n=== Plotting Results ===")
epochs_range = range(1, epochs + 1)


def _plot_metric_pair(first, second, ylabel, title):
    """Draw two per-epoch series (circle vs. square markers) on a new figure.

    ``first`` and ``second`` are ``(legend label, values)`` pairs.
    """
    (first_label, first_values), (second_label, second_values) = first, second
    plt.figure(figsize=(12, 6))
    plt.plot(epochs_range, first_values, label=first_label, marker='o')
    plt.plot(epochs_range, second_values, label=second_label, marker='s')
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()


# Training and Validation Loss
_plot_metric_pair(
    ("Training Loss", training_losses),
    ("Validation Loss", validation_losses),
    "Loss",
    "Training and Validation Loss",
)

# CER and WER
_plot_metric_pair(
    ("CER", cer_scores),
    ("WER", wer_scores),
    "Score",
    "CER and WER over Epochs",
)

# ============================================
# PART 11: Save Final Model
# ============================================
print("\n=== Saving Final Model ===")
final_model_dir = "khmer_text_recognition_model_v3"
final_processor_dir = "khmer_text_recognition_processor_v3"
model.save_pretrained(final_model_dir)
processor.save_pretrained(final_processor_dir)
print("Model and processor saved successfully!")

# Auto-download final model (Google Colab only, best-effort)
try:
    from google.colab import files
except ImportError:
    files = None
    print("Not running in Colab - skipping auto-download")

if files is not None:
    import zipfile

    final_zip = "khmer_text_recognition_model_final.zip"
    try:
        # One archive holding both folders, under model/ and processor/.
        with zipfile.ZipFile(final_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for source_dir, zip_prefix in ((final_model_dir, "model"), (final_processor_dir, "processor")):
                for root, _dirs, files_list in os.walk(source_dir):
                    for file in files_list:
                        file_path = os.path.join(root, file)
                        arcname = os.path.relpath(file_path, source_dir)
                        zipf.write(file_path, os.path.join(zip_prefix, arcname))

        print(f"✓ Final model zipped: {final_zip}")
        print(f"⬇ Downloading final model...")
        files.download(final_zip)
        print(f"✓ Final model download complete!")
    except Exception as exc:
        # The model is already on disk; a failed browser download must not
        # abort the test-set evaluation that follows.
        print(f"⚠ Final model download failed ({exc}); model saved to: {final_model_dir}")

# ============================================
# PART 12: Test the Model
# ============================================
print("\n=== Testing Model ===")
# Reload from disk so the evaluation exercises exactly what was saved.
processor = TrOCRProcessor.from_pretrained("khmer_text_recognition_processor_v3")
model = VisionEncoderDecoderModel.from_pretrained("khmer_text_recognition_model_v3")
model.to(device)
model.eval()

# Free-running decoding with generate() over the whole test set.
test_preds, test_refs = [], []
with torch.no_grad():
    for test_batch in tqdm(test_loader, desc="Testing"):
        generated_ids = model.generate(test_batch["pixel_values"].to(device), max_new_tokens=128)
        test_preds += processor.batch_decode(generated_ids, skip_special_tokens=True)
        test_refs += processor.batch_decode(test_batch["labels"].to(device), skip_special_tokens=True)

# Corpus-level error rates
test_cer = cer(test_refs, test_preds)
test_wer = wer(test_refs, test_preds)

print("\n=== Sample Predictions ===")
for rank, (pred, ref) in enumerate(zip(test_preds[:10], test_refs[:10]), start=1):
    print(f"\n{rank}.")
    print(f"Prediction: {pred}")
    print(f"Reference:  {ref}")
    print("-" * 60)

print("\n=== Overall Test Metrics ===")
print(f"Character Error Rate (CER): {test_cer:.4f}")
print(f"Word Error Rate (WER): {test_wer:.4f}")
print(f"\nTotal training steps: {global_step}")
print(f"Total epochs completed: {epochs}")
print("\n=== Training Complete ===")