In [1]:
import torch
import random
import pandas as pd
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision.models import resnet18, ResNet18_Weights
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
from torchvision import datasets
from torch.optim.lr_scheduler import ExponentialLR
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torchsummary import summary
from sklearn.metrics import f1_score, accuracy_score

In [None]:
# ---- Reproducibility ---------------------------------------------------------
# One seed shared by every RNG the notebook touches (Python, NumPy, PyTorch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ---- Run configuration -------------------------------------------------------
BATCH_SIZE = 32      # images per optimizer step / evaluation batch
NUM_WORKERS = 8      # DataLoader worker processes
NUM_CLASSES = 2_000  # width of the classifier head (one logit per font class)
MAX_EPOCHS = 500     # hard cap on "paper-defined" epochs (1000 iterations each)

In [3]:
class FontDataset(Dataset):
    """Map-style dataset of (image, class-index) pairs described by a CSV file.

    Each CSV row supplies an image path (column ``img_path``, relative to
    ``base_dir``) and a label (column ``text``). Labels are turned into integer
    class indices through a sorted label vocabulary (``self.labels``) and its
    inverse mapping (``self.label_to_index``).

    Why the optional ``labels`` argument exists: by default the vocabulary is
    built from the labels present in *this* CSV only. Two splits (e.g. train
    and test) that do not contain exactly the same set of labels would then
    silently disagree on what each class index means. Pass
    ``labels=train_dataset.labels`` when building evaluation splits to share
    one mapping.
    """

    def __init__(self, csv_file, base_dir, transform=None, labels=None):
        """
        Args:
            csv_file (str): Path to the CSV file with annotations.
            base_dir (str): Base directory to prepend to the image paths.
            transform (callable, optional): Transform to be applied on an image.
            labels (iterable, optional): Explicit label vocabulary. When given,
                it is de-duplicated and sorted and used instead of the labels
                found in ``csv_file``, so several splits can share one
                label -> index mapping.

        Raises:
            KeyError: if ``labels`` is given and ``csv_file`` contains a label
                that is not part of it (raised here, instead of later inside
                ``__getitem__``).
        """
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.base_dir = base_dir

        # Strip any extra spaces from column names (e.g. a "img_path, text" header)
        self.data.columns = [col.strip() for col in self.data.columns]

        # Build a mapping from label text to a class index
        if labels is None:
            self.labels = sorted(self.data['text'].unique())
        else:
            self.labels = sorted(set(labels))
        self.label_to_index = {label: idx for idx, label in enumerate(self.labels)}

        if labels is not None:
            unknown = set(self.data['text'].unique()) - set(self.label_to_index)
            if unknown:
                raise KeyError(
                    f"{len(unknown)} label(s) in {csv_file} are not in `labels`, "
                    f"e.g. {list(unknown)[:5]}"
                )

    def __len__(self):
        """Number of rows (samples) in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for row ``idx``.

        The image is read from ``base_dir/img_path``, converted to RGB and
        passed through ``transform`` when one was given.
        """
        row = self.data.iloc[idx]  # one positional row lookup instead of two
        full_img_path = os.path.join(self.base_dir, row['img_path'])

        # The context manager closes the file handle deterministically;
        # convert() returns an independent, fully loaded copy of the pixels.
        with Image.open(full_img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Convert text label to an integer class index
        label = self.label_to_index[row['text']]
        return image, label


In [4]:
# Image preprocessing shared by every split:
#   1. resize to a fixed 224x224 input,
#   2. convert the PIL image to a float tensor,
#   3. normalize each RGB channel with the ImageNet mean/std
#      (dataset-specific statistics could be substituted here).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

Possibly useful for speeding up data loading (PyTorch Lightning speed-up guide):

https://lightning.ai/docs/pytorch/stable/advanced/speed.html

In [5]:
# Annotation CSV (columns include `img_path` and `text`) and the root directory
# against which the CSV's relative image paths are resolved.
csv_path = "./dataset/data_2K_2.csv"
base_dir = "./dataset"
# Alternative dataset root, kept for reference:
# base_dir = "./NEW_dataset/dataset"

In [6]:
# dataset = FontDataset(csv_file=csv_path, base_dir=base_dir, transform=transform)

In [7]:
data = pd.read_csv(filepath_or_buffer=csv_path)  # full annotation table, split below

In [8]:
# data.head()

In [9]:
data.shape[0]  # number of annotated samples

4000000

In [10]:
# # Specify the folder path
# folder_path = "./NEW_dataset/dataset/word_images/test"  # Replace with your folder path

# # Define the image extensions you want to count
# image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp')

# # Initialize a counter
# image_count = 0

# # Iterate through the files in the folder
# for filename in os.listdir(folder_path):
#     if filename.lower().endswith(image_extensions):
#         image_count += 1

# print(f"Number of images in {folder_path}: {image_count}")

In [11]:
# import matplotlib.pyplot as plt
# import torchvision.transforms as transforms

# # Set this to the desired sample index within each class (0 = first sample, 9 = tenth, etc.)
# sample_index = 222  # change this value as needed

# # Define the unnormalization transformation.
# unnormalize = transforms.Normalize(
#     mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
#     std=[1/0.229, 1/0.224, 1/0.225]
# )

# # For example, if you want to show images from 20 classes:
# samples_per_class = 1000
# num_images_to_show = 20

# # Create indices for 20 images, one per class.
# indices = [i * samples_per_class + sample_index for i in range(num_images_to_show)]

# # Create a figure with 4 rows and 5 columns (4*5 = 20 images).
# fig, axes = plt.subplots(4, 5, figsize=(15, 12))
# axes = axes.flatten()

# for ax, idx in zip(axes, indices):
#     image_tensor, label_idx = dataset[idx]
#     label_text = dataset.labels[label_idx]
    
#     # Unnormalize and convert to a PIL image.
#     image_unnorm = unnormalize(image_tensor)
#     image_pil = transforms.ToPILImage()(image_unnorm)
    
#     ax.imshow(image_pil)
#     ax.set_title(f"Label: {label_text}")
#     ax.axis("off")

# plt.tight_layout()
# plt.show()

In [12]:
# import matplotlib.pyplot as plt
# import torchvision.transforms as transforms

# # Set the desired font block index (0 = first font, 1 = second font, etc.)
# desired_font_index =   499

# # Define the unnormalization transformation.
# unnormalize = transforms.Normalize(
#     mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
#     std=[1/0.229, 1/0.224, 1/0.225]
# )

# # Define how many samples (words) per font (class) are in the dataset.
# samples_per_class = 1000  # Adjust this value as per your dataset

# # Calculate indices for the chosen font's block.
# start_index = desired_font_index * samples_per_class
# end_index = start_index + samples_per_class
# indices = list(range(start_index, end_index))

# # Create a figure with 2 rows and 5 columns.
# fig, axes = plt.subplots(2, 5, figsize=(15, 6))
# axes = axes.flatten()

# for ax, idx in zip(axes, indices):
#     image_tensor, label_idx = dataset[idx]
#     # Here, we assume that dataset.labels contains the font label.
#     label_text = dataset.labels[label_idx]
    
#     # Unnormalize and convert to a PIL image.
#     image_unnorm = unnormalize(image_tensor)
#     image_pil = transforms.ToPILImage()(image_unnorm)
    
#     ax.imshow(image_pil)
#     ax.set_title(f"Label: {label_text}")
#     ax.axis("off")

# plt.tight_layout()
# plt.show()

In [13]:
# -------------------------------------------------
# Three-way split of the annotation table:
#   train = 90%, val = 9%, test = 1% of all rows.
# Stage 1 holds out 10% of the rows (temp_df); stage 2 sends 10% of that
# hold-out to test and keeps the remaining 90% for validation.
# Both stages shuffle with the same SEED, so the split is reproducible.
# NOTE(review): the split is not stratified by the `text` label, so the small
# test split may under-represent some classes. Changing it now would move rows
# the existing checkpoint was trained on into val/test, so leave as-is for this run.
# -------------------------------------------------
train_df, temp_df = train_test_split(data, shuffle=True, random_state=SEED, test_size=0.10)
val_df, test_df = train_test_split(temp_df, shuffle=True, random_state=SEED, test_size=0.10)

print(f"Train samples: {len(train_df)}")
print(f"Val samples:   {len(val_df)}")
print(f"Test samples:  {len(test_df)}")

# Persist the splits; the Dataset objects are built from these files.
train_csv, val_csv, test_csv = "train_split.csv", "val_split.csv", "test_split.csv"
train_df.to_csv(train_csv, index=False)
val_df.to_csv(val_csv, index=False)
test_df.to_csv(test_csv, index=False)

Train samples: 3600000
Val samples:   360000
Test samples:  40000


In [14]:
# Create training and validation datasets using the new CSV splits
train_dataset = FontDataset(csv_file=train_csv, base_dir=base_dir, transform=transform)
val_dataset = FontDataset(csv_file=val_csv, base_dir=base_dir, transform=transform)

# FontDataset derives its label -> index mapping from the labels present in its
# own CSV. Put the validation split on the training vocabulary so that a class
# missing from either split cannot shift the indices of every later class.
unseen_val_labels = set(val_dataset.labels) - set(train_dataset.label_to_index)
assert not unseen_val_labels, (
    f"{len(unseen_val_labels)} validation label(s) never appear in the training split"
)
val_dataset.labels = train_dataset.labels
val_dataset.label_to_index = train_dataset.label_to_index

# The classifier head has NUM_CLASSES outputs; a larger label index would make
# CrossEntropyLoss fail (on CUDA, with an opaque device-side assert).
assert len(train_dataset.labels) <= NUM_CLASSES, (
    f"{len(train_dataset.labels)} training labels exceed NUM_CLASSES={NUM_CLASSES}"
)

# Create DataLoaders for each dataset
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Print dataset sizes for verification
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

Train dataset size: 3600000
Validation dataset size: 360000


In [15]:
# Define ResNet-18 for Arabic font classification
class ArabicResNet18(nn.Module):
    """torchvision ResNet-18, randomly initialized, with a ``num_classes``-way head.

    Only the final fully connected layer is replaced; the wrapped network is
    stored as ``self.resnet``, so state_dict keys are prefixed with ``resnet.``.

    Args:
        num_classes (int): number of output logits (one per font class).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Load standard ResNet-18 (no pre-trained weights)
        self.resnet = models.resnet18(weights=None)
        # Size the new head from the existing one rather than hard-coding 512,
        # so it stays correct if the backbone is swapped for another ResNet.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return class logits, one row per input image."""
        return self.resnet(x)

In [16]:
# One output logit per font class (NUM_CLASSES); run on the GPU when available.
num_classes = NUM_CLASSES
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ArabicResNet18(num_classes).to(device)

In [17]:
# Objective: multi-class cross-entropy over the font classes.
criterion = nn.CrossEntropyLoss()

# Adam with an initial learning rate of 2e-5.
optimizer = optim.Adam(params=model.parameters(), lr=2e-5)

# Exponential decay: LR is multiplied by 10**(-1/90000) on every scheduler.step(),
# i.e. a 10x reduction after 90,000 steps.
# NOTE(review): the training loop calls scheduler.step() once per 1000-iteration
# "epoch", so the 10x reduction only arrives after 90,000 epochs. Confirm whether
# this factor was meant to be applied per iteration.
scheduler = ExponentialLR(optimizer=optimizer, gamma=10 ** (-1 / 90_000))

In [18]:
# Print model summary; adjust the input shape as needed (here, (3, 224, 224) for an RGB image).
# summary() prints the table itself and also returns a summary object. The trailing
# ';' stops the notebook from echoing that object as the cell result, which
# otherwise shows the whole table a second time.
summary(model, (3, 224, 224));

Layer (type:depth-idx)                   Output Shape              Param #
├─ResNet: 1-1                            [-1, 2000]                --
|    └─Conv2d: 2-1                       [-1, 64, 112, 112]        9,408
|    └─BatchNorm2d: 2-2                  [-1, 64, 112, 112]        128
|    └─ReLU: 2-3                         [-1, 64, 112, 112]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 56, 56]          --
|    └─Sequential: 2-5                   [-1, 64, 56, 56]          --
|    |    └─BasicBlock: 3-1              [-1, 64, 56, 56]          73,984
|    |    └─BasicBlock: 3-2              [-1, 64, 56, 56]          73,984
|    └─Sequential: 2-6                   [-1, 128, 28, 28]         --
|    |    └─BasicBlock: 3-3              [-1, 128, 28, 28]         230,144
|    |    └─BasicBlock: 3-4              [-1, 128, 28, 28]         295,424
|    └─Sequential: 2-7                   [-1, 256, 14, 14]         --
|    |    └─BasicBlock: 3-5              [-1, 256, 14, 14]     

Layer (type:depth-idx)                   Output Shape              Param #
├─ResNet: 1-1                            [-1, 2000]                --
|    └─Conv2d: 2-1                       [-1, 64, 112, 112]        9,408
|    └─BatchNorm2d: 2-2                  [-1, 64, 112, 112]        128
|    └─ReLU: 2-3                         [-1, 64, 112, 112]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 56, 56]          --
|    └─Sequential: 2-5                   [-1, 64, 56, 56]          --
|    |    └─BasicBlock: 3-1              [-1, 64, 56, 56]          73,984
|    |    └─BasicBlock: 3-2              [-1, 64, 56, 56]          73,984
|    └─Sequential: 2-6                   [-1, 128, 28, 28]         --
|    |    └─BasicBlock: 3-3              [-1, 128, 28, 28]         230,144
|    |    └─BasicBlock: 3-4              [-1, 128, 28, 28]         295,424
|    └─Sequential: 2-7                   [-1, 256, 14, 14]         --
|    |    └─BasicBlock: 3-5              [-1, 256, 14, 14]     

In [19]:
print(model.resnet.conv1.weight.size())  # sanity check: first conv expects 3 (RGB) input channels

torch.Size([64, 3, 7, 7])


In [20]:
def evaluate(model, data_loader, device):
    """Score ``model`` on every batch of ``data_loader`` without gradients.

    Args:
        model: classifier returning one row of logits per input image.
        data_loader: iterable of ``(images, labels)`` tensor batches.
        device: device the images are moved to before the forward pass.

    Returns:
        tuple: ``(macro_f1, accuracy)`` of the argmax predictions.

    Note: switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch_images, batch_labels in data_loader:
            logits = model(batch_images.to(device))
            y_pred += logits.argmax(dim=1).tolist()
            y_true += batch_labels.tolist()
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    accuracy = accuracy_score(y_true, y_pred)
    return macro_f1, accuracy

In [21]:
# -------------------------------------------------
# 1) Create the ./models_2k/ output directory (checkpoints + best-model weights)
# -------------------------------------------------
os.makedirs(name="./models_2k", exist_ok=True)  # no error if it already exists

In [22]:
# -------------------------------------------------
# 5) Training Parameters
# -------------------------------------------------
# A "paper-defined" epoch is a fixed number of optimizer steps,
# not a full pass over the training split.
num_iterations_per_epoch = 1000
max_epochs = MAX_EPOCHS   # hard cap (500); roughly 5 days at ~14.5 min/epoch on an RTX 3090
patience = 30             # early-stopping patience, in paper-defined epochs
checkpoint_interval = 5   # write a resumable checkpoint every N epochs

# Early-stopping / resume bookkeeping (replaced by the resume cell when a checkpoint exists)
best_val_f1 = float('-inf')  # best validation macro-F1 seen so far
stagnant_epochs = 0          # consecutive epochs without a validation-F1 improvement
start_epoch = 0              # first epoch index the training loop will run

# Output files
checkpoint_path = "./models_2k/checkpoint_last_2k.pth"                # full resumable state
best_model_state_dict_path = "./models_2k/resnet18_pretrained_2k.pth"  # best weights (state_dict)
best_model_full_path = "./models_2k/resnet18_pretrained_full_2k.pth"   # best model (whole module)

In [23]:
# -------------------------------------------------
# 6) If Checkpoint Exists, Resume
# -------------------------------------------------
# A checkpoint written by the training loop restores model, optimizer and
# LR-scheduler state plus the early-stopping bookkeeping. Without one, the
# defaults from the training-parameters cell stay in effect.
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    best_val_f1, stagnant_epochs, start_epoch = (
        checkpoint["best_val_f1"],
        checkpoint["stagnant_epochs"],
        checkpoint["epoch"],  # stored as the *next* epoch to run
    )
    print(f"Resuming training from epoch {start_epoch} with best_val_f1={best_val_f1:.4f}")

Resuming training from epoch 225 with best_val_f1=0.7782


In [None]:
# -------------------------------------------------
# 7) Training Loop
# -------------------------------------------------
# Each outer iteration is one "paper-defined" epoch:
#   num_iterations_per_epoch optimizer steps on freshly shuffled batches,
#   evaluation, an LR-scheduler step, best-model tracking,
#   periodic checkpointing and early stopping.
for epoch in range(start_epoch, max_epochs):
    model.train()  # evaluate() leaves the model in eval mode
    running_loss = 0.0
    num_iterations = 0

    # -------------------------
    # 7a) 1000 training iters
    # -------------------------
    for images, labels in train_loader:
        if num_iterations >= num_iterations_per_epoch:
            break  # Stop after 1000 iterations = "1 epoch"

        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)            # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()                    # Backpropagation
        optimizer.step()                   # Update weights

        running_loss += loss.item()
        num_iterations += 1

    # Mean loss over the iterations actually run. This is fewer than
    # num_iterations_per_epoch only if train_loader runs out of batches.
    epoch_loss = running_loss / max(num_iterations, 1)

    # -------------------------
    # 7b) Evaluate on Train & Val
    # -------------------------
    # NOTE(review): this pass covers the entire training split, far more images
    # than the training steps above. A fixed random subset would likely cut
    # epoch time substantially.
    train_f1, train_acc = evaluate(model, train_loader, device)
    val_f1, val_acc = evaluate(model, val_loader, device)

    # Current LR (assuming all parameter groups share the same LR)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch [{epoch+1}], Loss: {epoch_loss:.6f}, LR: {current_lr:.6e}, "
          f"Train F1: {train_f1:.4f}, Train Acc: {train_acc:.4f}, "
          f"Val F1: {val_f1:.4f}, Val Acc: {val_acc:.4f}")

    # -------------------------
    # 7c) Scheduler Step
    # -------------------------
    # NOTE(review): stepped once per 1000-iteration epoch, so a gamma of
    # 10**(-1/90000) only reaches a 10x decay after 90,000 epochs (the LR is still
    # ~1.99e-5 around epoch 230). If the factor was meant per iteration, step
    # inside 7a instead. Confirm first, since that changes any resumed run.
    scheduler.step()  # Decay LR
    new_lr = optimizer.param_groups[0]['lr']
    print(f"LR updated from {current_lr:.6e} to {new_lr:.6e}")

    # -------------------------
    # 7d) Best Model Saving & Early-Stopping Bookkeeping
    #     (Based on best Val F1)
    # -------------------------
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        stagnant_epochs = 0

        # Save just the state_dict
        torch.save(model.state_dict(), best_model_state_dict_path)

        # Save the full model (less common, but can be convenient)
        torch.save(model, best_model_full_path)

        print(f"** New best model found! Val F1 improved to {best_val_f1:.4f}. Model saved.")
    else:
        stagnant_epochs += 1
        print(f"No improvement in Val F1 for {stagnant_epochs} epoch(s).")

    # -------------------------
    # 7e) Save a Full Checkpoint periodically
    # -------------------------
    # This runs *after* 7d so the stored best_val_f1 / stagnant_epochs include this
    # epoch's result. Saving before the update let a resumed run compare against a
    # stale best F1 and overwrite the best-model files with a worse model.
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint_data = {
            "epoch": epoch + 1,  # Next epoch to continue from
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            # Plain Python float: sklearn returns a NumPy scalar, which
            # torch.load rejects when weights_only=True (default since PyTorch 2.6).
            "best_val_f1": float(best_val_f1),
            "stagnant_epochs": stagnant_epochs,
        }
        torch.save(checkpoint_data, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

    # -------------------------
    # 7f) Early Stopping
    # -------------------------
    if stagnant_epochs >= patience:
        print("Early stopping triggered (no F1 improvement).")
        break

print("Training complete!")

Epoch [226], Loss: 0.417255, LR: 1.988520e-05, Train F1: 0.7852, Train Acc: 0.8059, Val F1: 0.7753, Val Acc: 0.7963
LR updated from 1.988520e-05 to 1.988469e-05
No improvement in Val F1 for 3 epoch(s).
Epoch [227], Loss: 0.415667, LR: 1.988469e-05, Train F1: 0.7885, Train Acc: 0.8099, Val F1: 0.7797, Val Acc: 0.8013
LR updated from 1.988469e-05 to 1.988418e-05
** New best model found! Val F1 improved to 0.7797. Model saved.
Epoch [228], Loss: 0.416571, LR: 1.988418e-05, Train F1: 0.7804, Train Acc: 0.8034, Val F1: 0.7713, Val Acc: 0.7947
LR updated from 1.988418e-05 to 1.988368e-05
No improvement in Val F1 for 1 epoch(s).
Epoch [229], Loss: 0.401082, LR: 1.988368e-05, Train F1: 0.7868, Train Acc: 0.8095, Val F1: 0.7782, Val Acc: 0.8010
LR updated from 1.988368e-05 to 1.988317e-05
No improvement in Val F1 for 2 epoch(s).
Epoch [230], Loss: 0.407652, LR: 1.988317e-05, Train F1: 0.7883, Train Acc: 0.8099, Val F1: 0.7794, Val Acc: 0.8011
LR updated from 1.988317e-05 to 1.988266e-05
Checkpo

Evaluate the best saved model on the held-out test split: macro F1, precision, recall, and top-1/top-5/top-10 accuracy.

In [25]:
# Create test dataset/loader
test_dataset = FontDataset(csv_file=test_csv, base_dir=base_dir, transform=transform)

# The test split is only ~1% of the data. FontDataset builds its label -> index
# mapping from the labels present in this CSV alone, so a font class absent from
# the test split would shift the indices of every later class relative to the
# training labels. Reuse the training vocabulary instead.
unseen_test_labels = set(test_dataset.labels) - set(train_dataset.label_to_index)
assert not unseen_test_labels, (
    f"{len(unseen_test_labels)} test label(s) never appear in the training split"
)
test_dataset.labels = train_dataset.labels
test_dataset.label_to_index = train_dataset.label_to_index

test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# -------------------------------------------------
# 8) Load best saved model for inference on the test set
# -------------------------------------------------

# 8a) Option 1: If you saved only the state_dict
best_model = ArabicResNet18(num_classes).to(device)
# map_location lets weights saved during a GPU run load on a CPU-only machine.
best_model.load_state_dict(torch.load(best_model_state_dict_path, map_location=device))
best_model.eval()

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score


def evaluate_detailed(model, data_loader, device):
    """Compute test-time metrics for ``model`` over every batch of ``data_loader``.

    Named differently from the training-time ``evaluate`` (which returns only
    ``(f1, acc)``) so that running this cell does not shadow it and break a
    later re-run of the training loop.

    Args:
        model: classifier returning one row of logits per input image.
        data_loader: iterable of ``(images, labels)`` tensor batches.
        device: device used for the forward pass.

    Returns:
        tuple: ``(macro_f1, accuracy, macro_precision, macro_recall,
        top5_accuracy, top10_accuracy)``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()

    all_preds = []
    all_labels = []

    # Counters for top-K accuracy
    total_samples = 0
    correct_top5 = 0
    correct_top10 = 0

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)  # one row of logits per image

            # Top-1 predictions (for F1, Accuracy, Precision, Recall)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Top-5 / Top-10 accuracy from a single topk call. topk returns
            # indices sorted by score, so the first 5 columns are the top-5.
            # Comparing on-device avoids a GPU sync per sample. k is capped by
            # the number of classes.
            k_max = min(10, outputs.size(1))
            topk_indices = outputs.topk(k_max, dim=1).indices
            hits = topk_indices == labels.unsqueeze(1)  # True where the label is among the top-k
            correct_top5 += hits[:, :5].any(dim=1).sum().item()
            correct_top10 += hits.any(dim=1).sum().item()

            total_samples += labels.size(0)

    if total_samples == 0:
        raise ValueError("data_loader yielded no samples; cannot compute metrics")

    # Global metrics. zero_division=0 reproduces sklearn's default value for
    # classes that are never predicted, and silences UndefinedMetricWarning.
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)

    top5_acc = correct_top5 / total_samples
    top10_acc = correct_top10 / total_samples

    return f1, acc, precision, recall, top5_acc, top10_acc


test_f1, test_acc, test_precision, test_recall, test_top5, test_top10 = evaluate_detailed(
    best_model, test_loader, device
)

print("\n[Test Evaluation Metrics]")
print(f"  F1 Score       : {test_f1:.4f}")
print(f"  Accuracy       : {test_acc:.4f}")
print(f"  Precision      : {test_precision:.4f}")
print(f"  Recall         : {test_recall:.4f}")
print(f"  Top-5 Accuracy : {test_top5:.4f}")
print(f"  Top-10 Accuracy: {test_top10:.4f}\n")


[Test Evaluation Metrics]
  F1 Score       : 0.7956
  Accuracy       : 0.8195
  Precision      : 0.8053
  Recall         : 0.8171
  Top-5 Accuracy : 0.9963
  Top-10 Accuracy: 0.9994



  _warn_prf(average, modifier, msg_start, len(result))
