## Configure Single GPU Use

In [None]:
import os
import sys
import time
import random
import string
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from collections import OrderedDict

from utils import AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from model import Model
from evaluation import validation

# --- For reproducibility ---
# One seed shared by every random number generator used in this notebook.
SEED = 1111
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Seed every visible GPU, not only the current one; this call is silently
# ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(SEED)

# cuDNN benchmark mode auto-tunes and may pick different convolution
# algorithms from run to run, which defeats the reproducibility goal above,
# so it is disabled while deterministic algorithms are requested.
cudnn.benchmark = False
cudnn.deterministic = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device.type}")

## Model Config

In [None]:
class ModelConfig:
    """Plain attribute container holding every setting used by this notebook.

    Calling ``ModelConfig()`` with no arguments yields the default
    fine-tuning configuration. Any existing attribute can be replaced at
    construction time by keyword, e.g. ``ModelConfig(batch_size=32, rgb=True)``.

    ``input_channel`` is derived from ``rgb`` (3 when ``rgb`` is true, else 1)
    after overrides are applied, unless ``input_channel`` itself is overridden.

    Raises:
        TypeError: if a keyword does not name an existing setting.
    """

    def __init__(self, **overrides):
        # --- Paths & Naming ---
        self.exp_name = "TPS-ResNet-BiLSTM-Attn-FineTune-Mediseen"
        self.train_data = "data/train"
        self.valid_data = "data/val"
        self.saved_model = "saved_models/TPS-ResNet-BiLSTM-Attn.pth"

        # --- Training Specs ---
        self.manualSeed = 1111
        self.workers = 4
        self.batch_size = 64
        self.num_iter = 25000
        self.valInterval = 500
        self.FT = True
        self.adam = True
        self.lr = 0.0001
        self.beta1 = 0.9
        self.rho = 0.95
        self.eps = 1e-8
        self.grad_clip = 5

        # --- Data Processing ---
        # select_data / batch_ratio are "-"-separated strings; the notebook
        # splits them into lists before building the training dataset.
        self.select_data = "train"
        self.batch_ratio = "1.0"
        self.total_data_usage_ratio = "1.0"
        self.batch_max_length = 25
        self.imgH = 32
        self.imgW = 100
        self.rgb = False  # GRAY ONLY
        self.character = "0123456789abcdefghijklmnopqrstuvwxyz"
        self.sensitive = False
        self.PAD = True
        self.data_filtering_off = False

        # --- Model Architecture ---
        self.Transformation = "TPS"
        self.FeatureExtraction = "ResNet"
        self.SequenceModeling = "BiLSTM"
        self.Prediction = "Attn"
        self.num_fiducial = 20
        self.input_channel = 3 if self.rgb else 1
        self.output_channel = 512
        self.hidden_size = 256

        # --- GPU ---
        self.num_gpu = torch.cuda.device_count()

        # --- Overrides ---
        # Reject unknown names so a typo does not silently create a new,
        # unused attribute.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(
                f"Unknown ModelConfig setting(s): {', '.join(unknown)}"
            )
        for setting, value in overrides.items():
            setattr(self, setting, value)

        # Keep the derived channel count in step with an overridden `rgb`.
        if "input_channel" not in overrides:
            self.input_channel = 3 if self.rgb else 1


opt = ModelConfig()

# Checkpoints written during fine-tuning go under this experiment's folder.
os.makedirs(f"./saved_models/{opt.exp_name}", exist_ok=True)

# Case-sensitive runs use string.printable minus its last six characters.
if opt.sensitive:
    opt.character = string.printable[:-6]

# The converter's character list determines the number of output classes;
# num_class is attached to opt before the model is constructed from it.
converter = AttnLabelConverter(opt.character)
opt.num_class = len(converter.character)

# With several GPUs, scale loader workers and batch size by the GPU count.
if opt.num_gpu > 1:
    print(f"Using {opt.num_gpu} GPUs")
    opt.workers *= opt.num_gpu
    opt.batch_size *= opt.num_gpu

# Build the model only now that every opt attribute it may read is set.
model = Model(opt)
print("Model configured successfully.")
reported_settings = (
    "imgH",
    "imgW",
    "num_fiducial",
    "input_channel",
    "output_channel",
    "hidden_size",
    "num_class",
    "batch_max_length",
    "Transformation",
    "FeatureExtraction",
    "SequenceModeling",
    "Prediction",
)
print(
    "Model input parameters:",
    *(getattr(opt, setting) for setting in reported_settings),
)

# Freeze every parameter whose name mentions FeatureExtraction; all other
# parameters are explicitly marked trainable.
for param_name, param in model.named_parameters():
    param.requires_grad = "FeatureExtraction" not in param_name


## Load Pretrained Weights and Set Up Loss, Optimizer, and Scheduler

In [None]:
print(f"Loading pretrained model from {opt.saved_model}")
try:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(opt.saved_model, map_location=device)

    # Strip a leading "module." (added when a checkpoint is saved from a
    # DataParallel-wrapped model). Only the prefix is removed, so a key that
    # merely contains "module." further along is left intact.
    prefix = "module."
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    # Load the cleaned state dictionary into the (not yet wrapped) model.
    model.load_state_dict(new_state_dict)
    print("✅ Successfully loaded and cleaned pre-trained model weights.")

except Exception as e:
    # Best-effort: report the problem and continue with the model's current
    # weights.
    # NOTE(review): FeatureExtraction parameters were frozen when the model
    # was built, so continuing here trains on top of unloaded, frozen
    # features — consider re-raising instead.
    print(f"❌ Error loading model: {e}")
    print("Training from scratch instead.")
# =====================================================================


# DataParallel for multi-GPU
model = torch.nn.DataParallel(model).to(device)

# --- Loss, Averager, and Optimizer ---
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
    device
)  # ignore [GO] token = 0
loss_avg = Averager()

# Optimize only the parameters that still require gradients.
filtered_parameters = [p for p in model.parameters() if p.requires_grad]
params_num = [np.prod(p.size()) for p in filtered_parameters]
print("Trainable params num : ", sum(params_num))

# Setup optimizer
if opt.adam:
    # weight_decay adds L2 regularization on the trainable parameters.
    optimizer = optim.Adam(
        filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=0.001
    )
else:
    optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)

# Halves the learning rate after the monitored metric fails to improve for
# more than `patience` consecutive scheduler.step(metric) calls; it has no
# effect unless scheduler.step(metric) is called.
scheduler = ReduceLROnPlateau(optimizer, "min", patience=5, factor=0.5)

print("Optimizer:")
print(optimizer)

In [None]:
print("--- Initializing Dataloaders ---")

# --- Dataloaders ---
# Validation loader
AlignCollate_valid = AlignCollate(
    imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD
)
valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=opt.batch_size,
    # NOTE(review): shuffling a validation loader is unusual; presumably it
    # varies which samples appear in the printed predictions — confirm.
    shuffle=True,
    num_workers=int(opt.workers),
    collate_fn=AlignCollate_valid,
    pin_memory=True,
)
print(valid_dataset_log)

# Split the "-"-separated configuration strings into lists. The isinstance
# guard keeps this cell re-runnable: on a second run the values are already
# lists and calling .split() on them would raise AttributeError.
if isinstance(opt.select_data, str):
    opt.select_data = opt.select_data.split("-")
if isinstance(opt.batch_ratio, str):
    opt.batch_ratio = opt.batch_ratio.split("-")

# Training loader
train_dataset = Batch_Balanced_Dataset(opt)
print("Dataloaders initialized successfully.")

In [26]:
start_time = time.time()
best_accuracy = -1
best_norm_ED = -1

# Metrics collected at every validation step, kept for plotting afterwards.
history = {
    "iterations": [],
    "train_loss": [],
    "valid_loss": [],
    "accuracy": [],
    "norm_ED": [],
}


def _trim_at_eos(decoded, eos="[s]"):
    """Return `decoded` up to the first `eos` marker, or unchanged if absent.

    A plain ``decoded[: decoded.find(eos)]`` drops the last character when
    the marker is missing, because ``find`` returns -1 in that case.
    """
    eos_at = decoded.find(eos)
    return decoded if eos_at == -1 else decoded[:eos_at]


print("--- Starting Fine-Tuning ---")
for iteration in range(opt.num_iter):
    model.train()
    # --- Train one iteration ---
    image_tensors, labels = train_dataset.get_batch()
    image = image_tensors.to(device)
    text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
    batch_size = image.size(0)

    preds = model(image, text[:, :-1])  # Exclude the [s] token for input
    target = text[:, 1:]  # Exclude the [GO] token for target

    cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

    model.zero_grad()
    cost.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
    optimizer.step()

    loss_avg.add(cost)

    # --- Validation part ---
    # Runs after the very first iteration and then every valInterval iterations.
    if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
        elapsed_time = time.time() - start_time
        print(f"\n--- Validation at Iteration {iteration + 1}/{opt.num_iter} ---")

        # Switch to evaluation mode
        model.eval()
        with torch.no_grad():
            (
                valid_loss,
                current_accuracy,
                current_norm_ED,
                preds,
                confidence_score,
                labels,
                infer_time,
                length_of_data,
            ) = validation(model, criterion, valid_loader, converter, opt)

        # Convert the losses to plain floats so `history` can be handed to
        # pandas/matplotlib even if the averager or validation return
        # zero-dimensional tensors.
        train_loss_value = float(loss_avg.val())
        valid_loss_value = float(valid_loss)

        # The scheduler was created in "min" mode; feed it the validation
        # loss so the learning rate is actually reduced on a plateau.
        scheduler.step(valid_loss_value)

        # Store metrics
        history["iterations"].append(iteration + 1)
        history["train_loss"].append(train_loss_value)
        history["valid_loss"].append(valid_loss_value)
        history["accuracy"].append(current_accuracy)
        history["norm_ED"].append(current_norm_ED)

        # Logging
        loss_log = f"Train loss: {train_loss_value:0.5f}, Valid loss: {valid_loss_value:0.5f}, Elapsed_time: {elapsed_time:0.5f}s"
        loss_avg.reset()

        current_model_log = (
            f"Accuracy: {current_accuracy:0.3f} | Norm_ED: {current_norm_ED:0.2f}"
        )
        print(loss_log)
        print(current_model_log)

        # Save the best models
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            torch.save(
                model.state_dict(), f"./saved_models/{opt.exp_name}/best_accuracy.pth"
            )
            print("✨ New best accuracy! Saved to best_accuracy.pth")
        if current_norm_ED > best_norm_ED:
            best_norm_ED = current_norm_ED
            torch.save(
                model.state_dict(), f"./saved_models/{opt.exp_name}/best_norm_ED.pth"
            )
            print("✨ New best Norm ED! Saved to best_norm_ED.pth")

        best_model_log = (
            f"Best Accuracy: {best_accuracy:0.3f} | Best Norm_ED: {best_norm_ED:0.2f}"
        )
        print(best_model_log)

        # Show some predicted results
        dashed_line = "-" * 80
        head = f"{'Ground Truth':25s} | {'Prediction':25s} | Confidence Score & T/F"
        predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
        for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
            gt = _trim_at_eos(gt)
            pred = _trim_at_eos(pred)
            predicted_result_log += (
                f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
            )
        predicted_result_log += f"{dashed_line}"
        print(predicted_result_log)


print("\n--- End of Training ---")

# Notebook output: the training cell above was stopped manually (KeyboardInterrupt).

In [None]:
history_df = pd.DataFrame(history)

fig, axes = plt.subplots(1, 3, figsize=(24, 6))
fig.suptitle("Model Training and Validation Metrics", fontsize=16)

# One entry per subplot: (title, y-axis label, series), where each series is
# (history column, legend label, line color).
panel_specs = [
    (
        "Training vs. Validation Loss",
        "Loss",
        [
            ("train_loss", "Training Loss", "blue"),
            ("valid_loss", "Validation Loss", "orange"),
        ],
    ),
    (
        "Validation Accuracy",
        "Accuracy (%)",
        [("accuracy", "Validation Accuracy", "green")],
    ),
    (
        "Validation Normalized Edit Distance",
        "Norm ED",
        [("norm_ED", "Validation Norm ED", "red")],
    ),
]

for ax, (panel_title, y_label, series) in zip(axes, panel_specs):
    for column, legend_label, line_color in series:
        ax.plot(
            history_df["iterations"],
            history_df[column],
            label=legend_label,
            color=line_color,
            marker="o",
        )
    ax.set_title(panel_title)
    ax.set_xlabel("Iterations")
    ax.set_ylabel(y_label)
    ax.legend()
    ax.grid(True)

# Leave room at the top of the figure for the suptitle.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

In [None]:
print("\n--- Running Final Evaluation on the Best Model ---")

# Build a fresh model, wrapped in DataParallel like the training model whose
# state dict was saved, then restore the best-accuracy checkpoint into it.
final_model = torch.nn.DataParallel(Model(opt)).to(device)

best_model_path = f"./saved_models/{opt.exp_name}/best_accuracy.pth"
print(f"Loading weights from: {best_model_path}")
best_weights = torch.load(best_model_path, map_location=device)
final_model.load_state_dict(best_weights)

# Evaluate without gradient tracking; only the first three of the eight
# values returned by validation() are reported here.
final_model.eval()
with torch.no_grad():
    final_metrics = validation(final_model, criterion, valid_loader, converter, opt)
valid_loss, accuracy, norm_ED, _, _, _, _, _ = final_metrics

print("\n--- Final Performance on Validation Set ---")
for report_line in (
    f"  -> Final Accuracy: {accuracy:.3f}%",
    f"  -> Final Normalized Edit Distance: {norm_ED:.3f}",
    f"  -> Final Validation Loss: {valid_loss:.5f}",
):
    print(report_line)

In [None]:
from PIL import Image
import torchvision.transforms as transforms
from dataset import ResizeNormalize, NormalizePAD


def run_inference(model, image_path, opt, converter):
    """Predict the text contained in a single image file.

    Args:
        model: Network invoked as ``model(image, text, is_train=False)``.
        image_path: Path of the image to read.
        opt: Configuration object; ``rgb``, ``PAD``, ``imgH``, ``imgW`` and
            ``batch_max_length`` are read from it.
        converter: Label converter whose ``decode(indices, lengths)`` returns
            the decoded strings.

    Returns:
        The predicted string, cut at the first ``"[s]"`` marker when one is
        present, or ``None`` (after printing an error) if ``image_path``
        does not exist.

    Note:
        Uses the notebook-level ``device`` to place tensors.
    """
    import math  # needed for math.ceil in the padded-resize branch below

    if not os.path.exists(image_path):
        print(f"Error: Image not found at {image_path}")
        return None

    # --- Image Preprocessing ---
    # convert() returns a new image, so the file handle can be closed right away.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB" if opt.rgb else "L")

    if opt.PAD:
        # Replicate the logic from AlignCollate for padded resize
        resized_max_w = opt.imgW
        input_channel = 3 if opt.rgb else 1
        transform = NormalizePAD((input_channel, opt.imgH, resized_max_w))

        # Scale to height imgH keeping the aspect ratio, capping width at imgW.
        w, h = image.size
        ratio = w / float(h)
        resized_w = min(math.ceil(opt.imgH * ratio), opt.imgW)

        resized_image = image.resize((resized_w, opt.imgH), Image.BICUBIC)
        image_tensor = transform(resized_image)
    else:
        transform = ResizeNormalize((opt.imgW, opt.imgH))
        image_tensor = transform(image)

    # Add a leading batch dimension of size 1.
    image_tensor = image_tensor.unsqueeze(0).to(device)

    # --- Model Inference ---
    model.eval()
    with torch.no_grad():
        batch_size = image_tensor.size(0)
        # Dummy inputs for attention decoder
        text_for_pred = (
            torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
        )

        preds = model(image_tensor, text_for_pred, is_train=False)

        # Select max probability (greedy decoding)
        _, preds_index = preds.max(2)

        # Decode index to character string
        preds_str = converter.decode(
            preds_index, torch.IntTensor([opt.batch_max_length] * batch_size)
        )

    # --- Post-processing ---
    # Keep only the text before the end-of-sequence marker, when present.
    pred_text = preds_str[0]
    pred_EOS = pred_text.find("[s]")
    if pred_EOS != -1:
        pred_text = pred_text[:pred_EOS]

    return pred_text


```
# --- Example Usage ---
# 1. Create a folder named `test_images` in your project root.
# 2. Add some sample images to it.
# 3. Update the `sample_image_path` variable below.

# Make a dummy test image if it doesn't exist
os.makedirs('test_images', exist_ok=True)
sample_image_path = 'test_images/sample1.png'
if not os.path.exists(sample_image_path):
    try:
        from PIL import Image, ImageDraw
        img = Image.new('RGB', (200, 60), color = (255, 255, 255))
        d = ImageDraw.Draw(img)
        d.text((10,10), "hello world", fill=(0,0,0))
        img.save(sample_image_path)
        print(f"Created a dummy test image at: {sample_image_path}")
    except Exception as exc:
        # Best-effort: creating the dummy image is optional, but report why it
        # failed, and do not swallow KeyboardInterrupt/SystemExit as a bare
        # `except:` would.
        print(f"Please place a test image at {sample_image_path}")
        print(f"(Could not create a dummy image: {exc})")

if os.path.exists(sample_image_path):
    predicted_text = run_inference(final_model, sample_image_path, opt, converter)

    # Display the image and the prediction; the context manager closes the
    # image file once the figure has been shown.
    with Image.open(sample_image_path) as img:
        plt.figure(figsize=(10, 5))
        plt.imshow(img)
        plt.title(f'Prediction: "{predicted_text}"', fontsize=16)
        plt.axis('off')
        plt.show()
```