# Global Imports

In [1]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers.modeling_outputs import Seq2SeqLMOutput
from torch.optim import AdamW
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


# Local Imports

In [2]:
from models.trocr_apl import TrocrApl
from dataset.trocr_dataset import HandwrittenTextDataset

# Constants

In [3]:
MAX_LINE_LENGTH: int = 128
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
CSV_SEPERATOR: str = "⯑"
MAX_STRING_LENGTH: int = 128
BATCH_SIZE: int = 1

# Paths

In [4]:

file_dirpath: str = os.path.abspath(".")
root_dirpath: str = os.path.join(
    file_dirpath,
    os.pardir
)
dataset_dirpath: str = os.path.join(
    root_dirpath,
    "dataset"
)
apl_dataset_dirpath: str = os.path.join(
    dataset_dirpath,
    "apl_dataset"
)
metadata_csv_filepath: str = os.path.join(
    dataset_dirpath,
    "metadata_apl_fix.csv"
)
checkpoint_dirpath: str = os.path.join(
    root_dirpath,
    "model_checkpoints"
)
os.makedirs(checkpoint_dirpath, exist_ok=True)

# Load Model

In [5]:
trocr_model: TrocrApl = TrocrApl(
    max_target_length=128
    
)


Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.46.2"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder

# Load Dataset

In [6]:
metadata_df: pd.DataFrame = pd.read_csv(
    metadata_csv_filepath,
    delimiter=CSV_SEPERATOR,
    encoding="utf-8",
    engine="python"
)
filenames: list[str] = metadata_df["filename"].to_list()
labels: list[str] = metadata_df["label"].to_list()

In [7]:

train_filenames: list[str] = []
val_filenames: list[str] = []
train_labels: list[str] = []
val_labels: list[str] = []

train_filenames, val_filenames, train_labels, val_labels = train_test_split(
    filenames[:4], 
    labels[:4],
    train_size=0.75
)


In [8]:

train_dataset: HandwrittenTextDataset = HandwrittenTextDataset(
    dataset_dirpath=apl_dataset_dirpath,
    filenames=train_filenames,
    label_strings=train_labels,
    processor=trocr_model.processor,
    max_target_length=MAX_STRING_LENGTH
)
val_dataset: HandwrittenTextDataset = HandwrittenTextDataset(
    dataset_dirpath=apl_dataset_dirpath,
    filenames=val_filenames,
    label_strings=val_labels,
    processor=trocr_model.processor,
    max_target_length=MAX_STRING_LENGTH
)

train_dataloader: DataLoader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

val_dataloader: DataLoader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False
)

# Train Model

In [9]:
optimiser: AdamW = AdamW(
    params=trocr_model.model.parameters(),
    lr=5e-5
)

In [10]:


epoch: int
for epoch in range(10000):

    train_loss: float = 0.0
    
    image: torch.Tensor
    encoded_label: torch.Tensor
    for image, encoded_label in tqdm(
        iterable=train_dataloader,
        desc="Training model...",
        total=len(train_dataset)//BATCH_SIZE
    ):
        
        image = image.to(DEVICE)
        encoded_label = encoded_label.to(DEVICE)
        
        trocr_model.model.to(device=DEVICE)
        trocr_model.model = trocr_model.model.train()
        
        trocr_output: Seq2SeqLMOutput = trocr_model.forward(
            pixels=image,
            encoded_labels=encoded_label
        )
        
        predicted_encoded_string: torch.Tensor = trocr_model.model.generate(
            image
        )
        
        predicted_string: str = trocr_model.decode_model_output(
            predicted_encoded_string
        )
        y_string: str = trocr_model.decode_model_output(
            encoded_label
        )
        print(f"Training: \ny:{y_string}\ny_hat{predicted_string}\n==================")
        loss: torch.Tensor = trocr_output.loss
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

        train_loss += loss.item()
        
    print(f"Loss after epoch {epoch}:", train_loss/len(train_dataloader))
    
    trocr_model.model = trocr_model.model.eval()
    
    valid_cer: float = 0.0
    with torch.no_grad():
        
        image: torch.Tensor
        encoded_label: torch.Tensor        
        for image, encoded_label in tqdm(
            train_dataloader
        ):  
            
            image = image.to(device=DEVICE)
            encoded_label = encoded_label.to(device=DEVICE)
            
            encoded_string_prediction: torch.Tensor = trocr_model.model.generate(
                image
            )
            
        
            predicted_output: list[str] = trocr_model.decode_model_output(
                encoded_string_prediction
            )
            correct_output: list[str] = trocr_model.decode_model_output(
                encoded_label
            )
            
            #plt.imshow(X.detach().cpu()[0, :, :, :].permute((1, 2, 0)))
            #plt.show()S
            for y_, y_hat in zip(correct_output, predicted_output):
                print(f"y: {y_}\ny_hat: {y_hat}\n=====================================")
            
            cer: float = trocr_model.compute_character_error_rate(
                pred_ids=encoded_string_prediction, 
                label_ids=encoded_label
            )
            
            valid_cer += cer 

    print("Validation CER:", valid_cer / len(val_dataloader))

    trocr_model.model.save_pretrained(checkpoint_dirpath)



Training: 
y:['gck,←(B 1)(B 2)(B 3)(B 4)']
y_hat[' (3)']


Training model...:  33%|███▎      | 1/3 [02:29<04:58, 149.03s/it]

Training: 
y:['GC←{']
y_hat['�']


Training model...:  67%|██████▋   | 2/3 [08:55<04:27, 267.73s/it]


KeyboardInterrupt: 