In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/handwriting-hdf5/data/washington.hdf5
/kaggle/input/handwriting-hdf5/data/bentham.hdf5
/kaggle/input/handwriting-hdf5/data/iam.hdf5
/kaggle/input/handwriting-hdf5/data/saintgall.hdf5
/kaggle/input/handwriting-hdf5/data/merged/merged.hdf5


In [2]:
if not os.path.exists('/kaggle/working/experimental-transformer-ocr'):
    ! git clone https://github.com/aritra-github26/experimental-transformer-ocr.git

In [3]:
os.chdir('/kaggle/working/experimental-transformer-ocr')
! pwd

/kaggle/working/experimental-transformer-ocr


In [4]:
!pip install -r requirements.txt

Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->-r requirements.txt (line 10))
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->-r requirements.txt (line 10))
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->-r requirements.txt (line 10))
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch->-r requirements.txt (line 10))
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch->-r requirements.txt (line 10))
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch->-r requirements.txt (line 10))
  Downloading nvid

In [5]:
os.chdir('/kaggle/working/experimental-transformer-ocr/src')
! pwd

/kaggle/working/experimental-transformer-ocr/src


In [6]:
from pathlib import Path
import numpy as np
import math
from itertools import groupby
import h5py
import numpy as np
import unicodedata
import cv2
import torch
from torch import nn
from torchvision.models import resnet50, resnet34
from torch.autograd import Variable
import torchvision
from data import preproc as pp
from data import evaluation
from torch.utils.data import Dataset
import time
import torch.nn.functional as F

In [7]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=128):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class OCR(nn.Module):

    def __init__(self, vocab_len, hidden_dim):
        super().__init__()

        # create ResNet-50 backbone
        self.backbone = resnet50(pretrained=True)  # Use pretrained weights
        del self.backbone.fc

        # Add batch normalization after CNN features
        self.batch_norm = nn.BatchNorm2d(2048)
        
        # create conversion layer with dropout
        self.dropout1 = nn.Dropout(0.2)
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        
        # Add layer normalization
        self.layer_norm = nn.LayerNorm(hidden_dim)

        # create a 2-layer BiLSTM with dropout
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2,
                           bidirectional=True, batch_first=True, dropout=0.2)
        
        # Add dropout before final projection
        self.dropout2 = nn.Dropout(0.2)
        
        # prediction heads with length of vocab
        self.vocab = nn.Linear(hidden_dim * 2, vocab_len)

        # spatial positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.query_pos = PositionalEncoding(hidden_dim, 0.1)  # Reduced dropout

    def get_feature(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)   
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        return x

    def forward(self, inputs):
        """
        Forward pass through the model with improved regularization.
        
        Args:
            inputs: Input images (batch of images).
        
        Returns:
            Output predictions for the sequences.
        """
        # Propagate inputs through ResNet-50
        x = self.get_feature(inputs)
        
        # Apply batch normalization
        x = self.batch_norm(x)
        
        # Apply dropout and conv
        x = self.dropout1(x)
        h = self.conv(x)  # shape: (batch, hidden_dim, H, W)

        # Add spatial positional encodings
        h_shape = h.shape
        H, W = h_shape[2], h_shape[3]
        row_emb = self.row_embed[:H].unsqueeze(1).repeat(1, W, 1)
        col_emb = self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1)
        pos_emb = torch.cat([row_emb, col_emb], dim=-1).permute(2, 0, 1).unsqueeze(0)
        pos_emb = pos_emb.to(h.device)
        h = h + pos_emb

        # Prepare input for LSTM
        h = h.flatten(2).permute(0, 2, 1)  # (batch, seq_len, feature)
        
        # Apply layer normalization
        h = self.layer_norm(h)

        # Add positional encoding
        h = h.permute(1, 0, 2)  # (seq_len, batch, feature)
        h = self.query_pos(h)
        h = h.permute(1, 0, 2)  # (batch, seq_len, feature)

        # Pass through BiLSTM
        h, _ = self.lstm(h)  # (batch, seq_len, hidden_dim*2)
        
        # Apply dropout before final projection
        h = self.dropout2(h)
        
        # Calculate output with temperature scaling
        temperature = 0.1
        h = h.permute(1, 0, 2)  # (seq_len, batch, hidden_dim*2)
        output = self.vocab(h) / temperature  # Scale logits

        return output


def make_model(vocab_len, hidden_dim=256):
    
    return OCR(vocab_len, hidden_dim)

In [None]:
def inspect_batch(dataloader):
    batch = next(iter(dataloader))
    imgs, labels = batch
    
    print("Batch shapes:")
    print(f"Images: {imgs.shape}")
    print(f"Labels: {labels.shape}")
    
    # Show first image and its label
    img = imgs[0].numpy().transpose(1, 2, 0)
    img = (img * 255).astype(np.uint8)
    plt.figure(figsize=(10, 3))
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.title(f"Ground Truth: {tokenizer.decode(labels[0].tolist())}")
    plt.show()
    
    # Print first few labels
    print("\nFirst 3 decoded labels:")
    for i in range(min(3, len(labels))):
        print(f"{i+1}. {tokenizer.decode(labels[i].tolist())}")

print("Inspecting training data:")
inspect_batch(train_loader)


In [None]:
# Updated training configuration
import os
import datetime
import string
from torch.optim.lr_scheduler import ReduceLROnPlateau
import math

# Training parameters
batch_size = 32  # Increased for better batch statistics
epochs = 50      # Reduced but with better scheduling
warmup_epochs = 3
max_lr = 0.001
min_lr = 1e-6

# Gradient accumulation for effective larger batch size
gradient_accumulation_steps = 2

# define paths
source_path = '/kaggle/input/handwriting-hdf5/data/merged/merged.hdf5'
output_path = '/kaggle/working/output'
target_path = output_path + '/merged_training_weights_ctc_150.pt'

os.makedirs(output_path, exist_ok=True)

# define input size, number max of chars per line and list of valid chars
input_size = (1024, 128, 1)
max_text_length = 128
charset_base = string.printable[:95]

print("Source:", source_path)
print("Output:", output_path)
print("Target:", target_path)
print("Charset:", charset_base)
print(f"Training config: batch_size={batch_size}, epochs={epochs}, warmup_epochs={warmup_epochs}")
print(f"Learning rates: max_lr={max_lr}, min_lr={min_lr}")


In [8]:
"""
Uses generator functions to supply train/test with data.
Image renderings and text are created on the fly each time.
"""

class DataGenerator(Dataset):
    """Generator class with data streaming"""

    def __init__(self, source,charset, max_text_length, split, transform):
        self.tokenizer = Tokenizer(charset, max_text_length)
        self.transform = transform
        
        self.split = split
        self.dataset = dict()

        with h5py.File(source, "r") as f:
            self.dataset[self.split] = dict()

            self.dataset[self.split]['dt'] = np.array(f[self.split]['dt'])
            self.dataset[self.split]['gt'] = np.array(f[self.split]['gt'])
          
            randomize = np.arange(len(self.dataset[self.split]['gt']))
            np.random.seed(42)
            np.random.shuffle(randomize)

            self.dataset[self.split]['dt'] = self.dataset[self.split]['dt'][randomize]
            self.dataset[self.split]['gt'] = self.dataset[self.split]['gt'][randomize]

            # decode sentences from byte
            self.dataset[self.split]['gt'] = [x.decode() for x in self.dataset[self.split]['gt']]
            
        self.size = len(self.dataset[self.split]['gt'])


    def __getitem__(self, i):
        img = self.dataset[self.split]['dt'][i]
        
        #making image compatible with resnet
        img = np.repeat(img[..., np.newaxis],3, -1)    
        img = pp.normalization(img)
        
        if self.transform is not None:
            img = self.transform(img)

        y_train = self.tokenizer.encode(self.dataset[self.split]['gt'][i]) 
        
        #padding till max length
        y_train = np.pad(y_train, (0, self.tokenizer.maxlen - len(y_train)))

        gt = torch.Tensor(y_train)

        return img, gt          

    def __len__(self):
      return self.size



class Tokenizer():
    """Manager tokens functions and charset/dictionary properties"""

    def __init__(self, chars, max_text_length=128):
        self.PAD_TK, self.UNK_TK,self.SOS,self.EOS = "¶", "¤", "SOS", "EOS"
        self.chars = [self.PAD_TK] + [self.UNK_TK ]+ [self.SOS] + [self.EOS] +list(chars)
        self.PAD = self.chars.index(self.PAD_TK)
        self.UNK = self.chars.index(self.UNK_TK)

        self.vocab_size = len(self.chars)
        self.maxlen = max_text_length

    def encode(self, text):
        """Encode text to vector"""

        text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII")
        text = " ".join(text.split())

        groups = ["".join(group) for _, group in groupby(text)]
        text = "".join([self.UNK_TK.join(list(x)) if len(x) > 1 else x for x in groups])
        encoded = []

        text = ['SOS'] + list(text) + ['EOS']
        for item in text:
            index = self.chars.index(item)
            index = self.UNK if index == -1 else index
            encoded.append(index)

        return np.asarray(encoded)

    def decode(self, text):
        """Decode vector to text"""
        
        decoded = "".join([self.chars[int(x)] for x in text if x > -1])
        decoded = self.remove_tokens(decoded)
        decoded = pp.text_standardize(decoded)

        return decoded

    def remove_tokens(self, text):
        """Remove tokens (PAD) from text"""

        return text.replace(self.PAD_TK, "").replace(self.UNK_TK, "")



In [9]:
import os
import datetime
import string

batch_size = 16
epochs = 200

# define paths
#change paths accordingly

source_path = '/kaggle/input/handwriting-hdf5/data/merged/merged.hdf5'
output_path = '/kaggle/working/output'
target_path = output_path + '/merged_training_weights_ctc_150.pt'

os.makedirs(output_path, exist_ok=True)

# define input size, number max of chars per line and list of valid chars
input_size = (1024, 128, 1)
max_text_length = 128
charset_base = string.printable[:95]

print("source:", source_path)
print("output:", output_path)
print("target", target_path)
print("charset:", charset_base)

source: /kaggle/input/handwriting-hdf5/data/merged/merged.hdf5
output: /kaggle/working/output
target /kaggle/working/output/merged_training_weights_ctc_150.pt
charset: 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ 


In [10]:
import torchvision.transforms as T

device = torch.device("cuda")
transform = T.Compose([
    T.ToTensor()])
tokenizer = Tokenizer(charset_base)

In [10]:


train_loader = torch.utils.data.DataLoader(DataGenerator(source_path,charset_base,max_text_length,'train',transform), batch_size=batch_size, shuffle=False, num_workers=2)
val_loader = torch.utils.data.DataLoader(DataGenerator(source_path,charset_base,max_text_length,'valid',transform), batch_size=batch_size, shuffle=False, num_workers=2)


In [9]:
model = make_model(vocab_len=tokenizer.vocab_size)
_=model.to(device)

NameError: name 'tokenizer' is not defined

In [None]:
# Main training loop with improved monitoring
best_valid_loss = float('inf')
early_stopping_patience = 10
early_stopping_counter = 0
training_history = []

print("Starting training...")
print(f"Training on device: {device}")
print(f"Batch size: {batch_size} (effective batch size: {batch_size * gradient_accumulation_steps})")
print(f"Number of epochs: {epochs}")
print(f"Learning rate range: {min_lr} to {max_lr}")

for epoch in range(epochs):
    start_time = time.time()
    
    # Training phase
    train_loss = train(model, criterion, optimizer, scheduler, 
                      train_loader, tokenizer.vocab_size, device)
    
    # Validation phase
    valid_loss = evaluate(model, criterion, val_loader, tokenizer.vocab_size, device)
    
    # Update learning rate schedulers
    scheduler.step()  # Update the warmup/decay scheduler
    plateau_scheduler.step(valid_loss)  # Update the plateau scheduler
    
    # Calculate epoch time
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Save best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_valid_loss': best_valid_loss,
            'tokenizer': tokenizer,
        }, target_path)
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
    
    # Store training history
    training_history.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'valid_loss': valid_loss,
        'learning_rate': optimizer.param_groups[0]['lr'],
        'time': epoch_mins * 60 + epoch_secs
    })
    
    # Print epoch results
    print(f'\nEpoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.4f}')
    print(f'\tValid Loss: {valid_loss:.4f}')
    print(f'\tLearning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
    
    # Early stopping check
    if early_stopping_counter >= early_stopping_patience:
        print(f'\nEarly stopping triggered after {epoch+1} epochs')
        break
    
    # Clear memory
    torch.cuda.empty_cache()
    gc.collect()

# Save training history
import pandas as pd
history_df = pd.DataFrame(training_history)
history_df.to_csv(f'{output_path}/training_history.csv', index=False)

print(f'\nTraining completed. Best validation loss: {best_valid_loss:.4f}')


In [12]:
# Loss function with blank token at index 0
criterion = nn.CTCLoss(blank=0, zero_infinity=True, reduction='mean')
criterion.to(device)

# Optimizer with improved weight decay and initial learning rate
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=min_lr,  # Start with min_lr for warmup
    weight_decay=0.01,  # Increased weight decay for better regularization
    betas=(0.9, 0.999),  # Default Adam betas
    eps=1e-8
)

# Learning rate scheduler with warmup and plateau detection
def get_lr(epoch):
    if epoch < warmup_epochs:
        # Linear warmup
        return min_lr + (max_lr - min_lr) * epoch / warmup_epochs
    return max_lr * (0.95 ** (epoch - warmup_epochs))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=get_lr)

# Additional plateau scheduler for fine-tuning
plateau_scheduler = ReduceLROnPlateau(
    optimizer, 
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True,
    min_lr=min_lr
)


In [11]:


def train(model, criterion, optimizer, scheduler, dataloader, vocab_length, device):
    """
    Train the model using the provided dataloader with gradient accumulation.
    
    Args:
        model: The OCR model
        criterion: Loss function (CTC)
        optimizer: Optimizer instance
        scheduler: Learning rate scheduler
        dataloader: Training data loader
        vocab_length: Size of vocabulary
        device: Device to train on
        
    Returns:
        float: Average loss for the epoch
    """
    model.train()
    total_loss = 0
    total_items = 0
    optimizer.zero_grad()  # Zero gradients at start of epoch
    
    for batch_idx, (imgs, labels_y) in enumerate(dataloader):
        imgs = imgs.to(device)
        labels_y = labels_y.to(device)
        batch_size = imgs.size(0)
        
        # Forward pass
        output = model(imgs.float())
        
        # Ensure output is in (batch, seq_len, vocab_size) format
        if output.dim() != 3:
            raise ValueError(f"Expected 3D output tensor, got shape: {output.shape}")
        
        # Apply temperature scaling and log_softmax
        temperature = 0.1
        log_probs = F.log_softmax(output / temperature, dim=2).permute(1, 0, 2)
        
        # Calculate input sequence lengths
        input_lengths = torch.full((batch_size,), 
                                 log_probs.size(0), 
                                 dtype=torch.long,
                                 device=device)
        
        # Process target sequences
        target_lengths = []
        labels_list = []
        valid_samples = []
        
        for i in range(batch_size):
            # Find non-zero elements (non-padding)
            non_zero = labels_y[i].nonzero().squeeze()
            if non_zero.dim() == 0:
                continue
            
            length = non_zero.shape[0]
            if length > log_probs.size(0):
                continue
                
            sequence = labels_y[i, :length]
            target_lengths.append(length)
            labels_list.append(sequence)
            valid_samples.append(i)
        
        if not valid_samples:
            continue
            
        # Keep only valid samples
        valid_samples = torch.tensor(valid_samples, device=device)
        log_probs = log_probs[:, valid_samples, :]
        input_lengths = input_lengths[valid_samples]
        target_lengths = torch.tensor(target_lengths, dtype=torch.long, device=device)
        labels_packed = torch.cat(labels_list)
        
        try:
            # Calculate CTC loss
            loss = criterion(log_probs,
                           labels_packed,
                           input_lengths,
                           target_lengths)
            
            # Scale loss for gradient accumulation
            loss = loss / gradient_accumulation_steps
            loss.backward()
            
            # Update weights if we've accumulated enough gradients
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
                optimizer.step()
                optimizer.zero_grad()
            
            # Track statistics
            total_loss += loss.item() * len(valid_samples) * gradient_accumulation_steps
            total_items += len(valid_samples)
            
            # Print progress every 100 batches
            if (batch_idx + 1) % 100 == 0:
                avg_loss = total_loss / total_items if total_items > 0 else float('inf')
                print(f'Batch {batch_idx + 1}/{len(dataloader)}, '
                      f'Loss: {avg_loss:.4f}, '
                      f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
            
        except RuntimeError as e:
            print(f"Error in batch {batch_idx}:")
            print(f"log_probs shape: {log_probs.shape}")
            print(f"labels_packed shape: {labels_packed.shape}")
            print(f"input_lengths shape: {input_lengths.shape}")
            print(f"target_lengths shape: {target_lengths.shape}")
            raise e
        
        # Clear memory
        del output, log_probs, loss
        if batch_idx % 10 == 0:  # Every 10 batches
            torch.cuda.empty_cache()
    
    # Handle any remaining accumulated gradients
    if (batch_idx + 1) % gradient_accumulation_steps != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        optimizer.zero_grad()
    
    return total_loss / total_items if total_items > 0 else float('inf')

def evaluate(model, criterion, dataloader, vocab_length, device):
    """
    Evaluate the model using the provided dataloader.
    
    Args:
        model: The OCR model
        criterion: Loss function (CTC)
        dataloader: Validation data loader
        vocab_length: Size of vocabulary
        device: Device to evaluate on
        
    Returns:
        float: Average loss for the epoch
    """
    model.eval()
    total_loss = 0
    total_items = 0

    with torch.no_grad():
        for batch, (imgs, labels_y) in enumerate(dataloader):
            imgs = imgs.to(device)
            labels_y = labels_y.to(device)
            batch_size = imgs.size(0)

            # Forward pass
            output = model(imgs.float())
            
            # Ensure output is in (batch, seq_len, vocab_size) format
            if output.dim() != 3:
                raise ValueError(f"Expected 3D output tensor, got shape: {output.shape}")
            
            # Permute to (seq_len, batch, vocab_size) for CTC loss
            log_probs = F.log_softmax(output, dim=2).permute(1, 0, 2)
            
            # Calculate input sequence lengths (all are same length after CNN processing)
            input_lengths = torch.full((batch_size,), 
                                     log_probs.size(0), 
                                     dtype=torch.long,
                                     device=device)
            
            # Calculate target lengths (excluding padding)
            target_lengths = []
            labels_list = []
            valid_samples = []
            
            # Process each sequence in the batch
            for i in range(batch_size):
                # Find non-zero elements (non-padding)
                non_zero = labels_y[i].nonzero().squeeze()
                if non_zero.dim() == 0:  # Handle case of empty sequence
                    continue  # Skip this sample
                
                length = non_zero.shape[0]
                if length > log_probs.size(0):  # Skip if target is longer than output
                    continue
                    
                sequence = labels_y[i, :length]
                target_lengths.append(length)
                labels_list.append(sequence)
                valid_samples.append(i)
            
            # Skip batch if no valid samples
            if not valid_samples:
                continue
                
            # Keep only valid samples
            valid_samples = torch.tensor(valid_samples, device=device)
            log_probs = log_probs[:, valid_samples, :]
            input_lengths = input_lengths[valid_samples]
            
            # Convert target lengths to tensor
            target_lengths = torch.tensor(target_lengths, dtype=torch.long, device=device)
            
            # Concatenate all label sequences
            labels_packed = torch.cat(labels_list)
            
            try:
                # CTC loss calculation
                loss = criterion(log_probs,
                               labels_packed,
                               input_lengths,
                               target_lengths)
                
                total_loss += loss.item() * len(valid_samples)
                total_items += len(valid_samples)
                
            except RuntimeError as e:
                print(f"Error in batch {batch}:")
                print(f"log_probs shape: {log_probs.shape}")
                print(f"labels_packed shape: {labels_packed.shape}")
                print(f"input_lengths shape: {input_lengths.shape}")
                print(f"target_lengths shape: {target_lengths.shape}")
                raise e

    return total_loss / total_items if total_items > 0 else float('inf')

def get_memory(model, imgs):
    """
    Extract features and apply positional encoding for the BiLSTM model.
    
    Args:
        model: The OCR model
        imgs: Input images tensor
        
    Returns:
        Memory tensor with shape (seq_len, batch, hidden_dim*2)
    """
    with torch.no_grad():
        # Extract CNN features
        features = model.get_feature(imgs)
        
        # Apply conv layer
        conv_out = model.conv(features)
        
        # Get spatial dimensions
        bs, c, h, w = conv_out.size()
        
        # Add positional encodings
        row_emb = model.row_embed[:h].unsqueeze(1).repeat(1, w, 1)  # (H, W, hidden_dim//2)
        col_emb = model.col_embed[:w].unsqueeze(0).repeat(h, 1, 1)  # (H, W, hidden_dim//2)
        pos_emb = torch.cat([row_emb, col_emb], dim=-1).permute(2, 0, 1).unsqueeze(0)  # (1, hidden_dim, H, W)
        pos_emb = pos_emb.to(conv_out.device)
        conv_out = conv_out + pos_emb
        
        # Flatten spatial dimensions and permute for LSTM
        lstm_input = conv_out.flatten(2).permute(0, 2, 1)  # (batch, seq_len, feature)
        
        # Add positional encoding to LSTM input
        lstm_input = lstm_input.permute(1, 0, 2)  # (seq_len, batch, feature)
        lstm_input = model.query_pos(lstm_input)
        lstm_input = lstm_input.permute(1, 0, 2)  # (batch, seq_len, feature)
        
        # Apply BiLSTM
        lstm_out, _ = model.lstm(lstm_input)
        
        # Return in shape (seq_len, batch, hidden_dim*2)
        return lstm_out.permute(1, 0, 2)

def single_image_inference(model, img, tokenizer, transform, device):
    """
    Run inference on single image using greedy decoding.
    
    Args:
        model: The OCR model
        img: Input image
        tokenizer: Tokenizer for encoding/decoding text
        transform: Image transform pipeline
        device: Device to run inference on
        
    Returns:
        pred_text: Predicted text string
    """
    model.eval()
    
    # Preprocess image
    img = transform(img)
    imgs = img.unsqueeze(0).float().to(device)
    
    with torch.no_grad():
        # Forward pass
        output = model(imgs)
        
        # Ensure output is in (seq_len, batch, vocab_size) format
        if output.dim() == 3:
            output = output.permute(1, 0, 2)
        
        # Apply log softmax and get predictions
        output = F.log_softmax(output, dim=2)
        output = output.argmax(dim=2)
        output = output.squeeze(1)  # Remove batch dimension
        
        # Convert prediction to text (handle special tokens)
        out_indices = []
        for idx in output:
            token = idx.item()
            if token == tokenizer.chars.index('EOS'):
                break
            if token > tokenizer.chars.index('EOS'):  # Skip special tokens
                out_indices.append(token)
        
        # Decode the prediction
        pred_text = tokenizer.decode(out_indices)
    
    return pred_text

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


def run_epochs(model, criterion, optimizer, scheduler, train_loader, val_loader, epochs, tokenizer, target_path, device):
    """
    Run training for specified number of epochs.
    
    Args:
        model: The OCR model
        criterion: Loss function (CTC)
        optimizer: Optimizer instance
        scheduler: Learning rate scheduler
        train_loader: Training data loader
        val_loader: Validation data loader
        epochs: Number of epochs to train
        tokenizer: Tokenizer for encoding/decoding text
        target_path: Path to save model checkpoints
        device: Device to train on
    """
    best_valid_loss = float('inf')
    patience = 0
    max_patience = 4  # Number of epochs to wait before reducing learning rate
    
    for epoch in range(epochs):
        print(f'Epoch: {epoch + 1:02} | Learning rate: {scheduler.get_last_lr()[0]:.6f}')
        
        start_time = time.time()
        
        # Training phase
        train_loss = train(model, criterion, optimizer, scheduler, 
                          train_loader, tokenizer.vocab_size, device)
        
        # Validation phase
        valid_loss = evaluate(model, criterion, val_loader, tokenizer.vocab_size, device)
        
        epoch_mins, epoch_secs = epoch_time(start_time, time.time())
        
        # Save best model based on validation loss
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_valid_loss': best_valid_loss,
            }, target_path)
            patience = 0
        else:
            patience += 1
        
        # Reduce learning rate if validation loss hasn't improved
        if patience >= max_patience:
            scheduler.step()
            patience = 0
        
        print(f'Time: {epoch_mins}m {epoch_secs}s')
        print(f'Train Loss: {train_loss:.3f}')
        print(f'Val   Loss: {valid_loss:.3f}')
        print(f'Best Val Loss: {best_valid_loss:.3f}')
    
    print(f'Training completed. Best validation loss: {best_valid_loss:.3f}')








In [14]:
# train model
# This is the actual code :)



import joblib
training_results = [] # format: [training loss, validation loss, epoch time in seconds]
 
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


# def run_epochs(model, criterion, optimizer, scheduler, train_loader, val_loader, epochs, tokenizer, target_path, device):
'''
run one epoch for a model
'''
epochs = 3
c = 0
best_valid_loss = float('inf')

for epoch in range(epochs):     
    print(f'Epoch: {epoch + 1:02}', 'learning rate{}'.format(scheduler.get_last_lr()))
    
    start_time = time.time()
    
    # Training phase
    train_loss = train(model, criterion, optimizer, scheduler, 
                              train_loader, tokenizer.vocab_size, device)
    
    # Validation phase
    valid_loss = evaluate(model, criterion, val_loader, tokenizer.vocab_size, device)
    
    epoch_mins, epoch_secs = epoch_time(start_time, time.time())
    
    # Save best model based on validation loss
    c += 1
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        end_time = time.time()
        torch.save(model.state_dict(), target_path)
        training_results.append((train_loss, valid_loss, int(end_time - start_time)))
        c = 0
    
    if c > 4:
        scheduler.step()
        c = 0
    
    
    print(f'Time: {epoch_mins}m {epoch_secs}s')
    print(f'Train Loss: {train_loss:.3f}')
    print(f'Val   Loss: {valid_loss:.3f}')

print(best_valid_loss)


# training_results.append((train_loss, valid_loss, int(end_time - start_time)))
joblib.dump(training_results, output_path + '/training_results.joblib')


Epoch: 01 learning rate[0.0001]
Time: 3m 44s
Train Loss: 3.726
Val   Loss: 3.645
Epoch: 02 learning rate[0.0001]
Time: 3m 42s
Train Loss: 3.612
Val   Loss: 3.650
Epoch: 03 learning rate[0.0001]
Time: 3m 44s
Train Loss: 3.601
Val   Loss: 3.648
3.6454115870972754


['/kaggle/working/output/training_results.joblib']

In [12]:
model = make_model(vocab_len=tokenizer.vocab_size)
_=model.to(device)

# model.load_state_dict(torch.load(target_path))
model.load_state_dict(torch.load(target_path))

<All keys matched successfully>

In [13]:
# def get_memory(model, imgs):
#     # Refactored get_memory for resnet50-biLSTM model: simply extract features and apply conv and lstm
#     with torch.no_grad():
#         features = model.get_feature(imgs)
#         conv_out = model.conv(features)
#         bs, c, h, w = conv_out.size()
#         # Flatten spatial dimensions and permute for LSTM input: (batch, seq_len, feature)
#         lstm_input = conv_out.flatten(2).permute(0, 2, 1)
#         lstm_out, _ = model.lstm(lstm_input)
#     return lstm_out.permute(1, 0, 2)  # Return in shape (seq_len, batch, feature) for compatibility


In [14]:
# def test(model, test_loader, max_text_length, tokenizer):
#     """
#     Evaluate and predict model with the test dataloader.
    
#     Args:
#         model: The OCR model to evaluate.
#         test_loader: DataLoader for test dataset.
#         max_text_length: Maximum length of output sequence.
#         tokenizer: Tokenizer for decoding output tokens.
    
#     Returns:
#         predicts: List of predicted text sequences.
#         gt: List of ground truth text sequences.
#         imgs: List of input images.
#     """
#     model.eval()
#     predicts = []
#     gt = []
#     imgs = []
#     device = next(model.parameters()).device
    
#     with torch.no_grad():
#         for batch in test_loader:
#             src, trg = batch
#             imgs.append(src.flatten(0,1))
#             src = src.to(device)
#             trg = trg.to(device)
            
#             # Forward pass without teacher forcing
#             output = model(src.float())
            
#             # Ensure output is in (seq_len, batch, vocab_size) format
#             if output.dim() == 3:
#                 output = output.permute(1, 0, 2)
            
#             # Apply log softmax and get predictions
#             output = F.log_softmax(output, dim=2)
#             predictions = output.argmax(dim=2)
            
#             # Process each sequence in the batch
#             for pred, target in zip(predictions.transpose(0,1), trg):
#                 # Convert prediction to text (handle special tokens)
#                 pred_indices = []
#                 for idx in pred:
#                     token = idx.item()
#                     if token == tokenizer.chars.index('EOS'):
#                         break
#                     if token > tokenizer.chars.index('EOS'):  # Skip special tokens
#                         pred_indices.append(token)
                
#                 # Decode prediction
#                 pred_text = tokenizer.decode(pred_indices)
#                 pred_text = pred_text.replace('SOS', '').replace('EOS', '')
#                 predicts.append(pred_text)
                
#                 # Convert target to text (handle special tokens)
#                 target_indices = []
#                 for idx in target:
#                     token = idx.item()
#                     if token == tokenizer.chars.index('EOS'):
#                         break
#                     if token > tokenizer.chars.index('EOS'):  # Skip special tokens
#                         target_indices.append(token)
                
#                 # Decode target
#                 target_text = tokenizer.decode(target_indices)
#                 target_text = target_text.replace('SOS', '').replace('EOS', '')
#                 gt.append(target_text)
    
#     return predicts, gt, imgs


def calculate_cer(pred_text, target_text):
    """Calculate Character Error Rate using Levenshtein distance."""
    if len(target_text) == 0:
        return 0 if len(pred_text) == 0 else 1
        
    matrix = [[0 for _ in range(len(pred_text) + 1)] 
              for _ in range(len(target_text) + 1)]
    
    for i in range(len(target_text) + 1):
        matrix[i][0] = i
    for j in range(len(pred_text) + 1):
        matrix[0][j] = j
        
    for i in range(1, len(target_text) + 1):
        for j in range(1, len(pred_text) + 1):
            if target_text[i-1] == pred_text[j-1]:
                matrix[i][j] = matrix[i-1][j-1]
            else:
                matrix[i][j] = min(matrix[i-1][j-1] + 1,    # substitution
                                 matrix[i][j-1] + 1,         # insertion
                                 matrix[i-1][j] + 1)         # deletion
                
    return matrix[len(target_text)][len(pred_text)] / len(target_text)

def calculate_wer(pred_text, target_text):
    """Calculate Word Error Rate using word-level Levenshtein distance."""
    pred_words = pred_text.split()
    target_words = target_text.split()
    
    if len(target_words) == 0:
        return 0 if len(pred_words) == 0 else 1
        
    matrix = [[0 for _ in range(len(pred_words) + 1)] 
              for _ in range(len(target_words) + 1)]
    
    for i in range(len(target_words) + 1):
        matrix[i][0] = i
    for j in range(len(pred_words) + 1):
        matrix[0][j] = j
        
    for i in range(1, len(target_words) + 1):
        for j in range(1, len(pred_words) + 1):
            if target_words[i-1] == pred_words[j-1]:
                matrix[i][j] = matrix[i-1][j-1]
            else:
                matrix[i][j] = min(matrix[i-1][j-1] + 1,    # substitution
                                 matrix[i][j-1] + 1,         # insertion
                                 matrix[i-1][j] + 1)         # deletion
                
    return matrix[len(target_words)][len(pred_words)] / len(target_words)

In [15]:
def test(model, test_loader, max_text_length, tokenizer):
    """
    Evaluate and predict model with the test dataloader.
    Memory-efficient version that processes data in chunks.
    
    Args:
        model: The OCR model to evaluate.
        test_loader: DataLoader for test dataset.
        max_text_length: Maximum length of output sequence.
        tokenizer: Tokenizer for decoding output tokens.
    
    Returns:
        predicts: List of predicted text sequences.
        gt: List of ground truth text sequences.
        imgs: List of input images.
    """
    model.eval()
    predicts = []
    gt = []
    imgs = []
    device = next(model.parameters()).device
    
    # Clear memory before starting
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    
    chunk_size = 10  # Process 10 samples at a time
    current_chunk = {'imgs': [], 'preds': [], 'gts': []}
    
    with torch.no_grad():
        try:
            for batch_idx, batch in enumerate(test_loader):
                src, trg = batch
                
                # Store CPU version of image
                current_chunk['imgs'].append(src.flatten(0,1).cpu())
                
                # Move tensors to device
                src = src.to(device)
                trg = trg.to(device)
                
                try:
                    # Forward pass without teacher forcing
                    output = model(src.float())
                    
                    # Free memory
                    del src
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                    
                    # Ensure output is in (seq_len, batch, vocab_size) format
                    if output.dim() == 3:
                        output = output.permute(1, 0, 2)
                    
                    # Apply log softmax and get predictions
                    output = F.log_softmax(output, dim=2)
                    predictions = output.argmax(dim=2)
                    
                    # Free memory
                    del output
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                    
                    # Process each sequence in the batch
                    for pred, target in zip(predictions.transpose(0,1), trg):
                        # Convert prediction to text (handle special tokens)
                        pred_indices = []
                        for idx in pred:
                            token = idx.item()
                            if token == tokenizer.chars.index('EOS'):
                                break
                            if token > tokenizer.chars.index('EOS'):  # Skip special tokens
                                pred_indices.append(token)
                        
                        # Decode prediction
                        pred_text = tokenizer.decode(pred_indices)
                        pred_text = pred_text.replace('SOS', '').replace('EOS', '')
                        current_chunk['preds'].append(pred_text)
                        
                        # Convert target to text (handle special tokens)
                        target_indices = []
                        for idx in target:
                            token = idx.item()
                            if token == tokenizer.chars.index('EOS'):
                                break
                            if token > tokenizer.chars.index('EOS'):  # Skip special tokens
                                target_indices.append(token)
                        
                        # Decode target
                        target_text = tokenizer.decode(target_indices)
                        target_text = target_text.replace('SOS', '').replace('EOS', '')
                        current_chunk['gts'].append(target_text)
                    
                    # Free memory
                    del predictions, trg
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                    
                except RuntimeError as e:
                    print(f"Error processing batch {batch_idx}: {e}")
                    continue
                
                # If chunk is full or this is the last batch, append to main lists and clear chunk
                if len(current_chunk['imgs']) >= chunk_size or batch_idx == len(test_loader) - 1:
                    imgs.extend(current_chunk['imgs'])
                    predicts.extend(current_chunk['preds'])
                    gt.extend(current_chunk['gts'])
                    
                    # Clear chunk
                    current_chunk = {'imgs': [], 'preds': [], 'gts': []}
                    
                    # Force garbage collection
                    import gc
                    gc.collect()
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                
        except Exception as e:
            print(f"Unexpected error during testing: {e}")
            # Save what we have so far
            if current_chunk['imgs']:
                imgs.extend(current_chunk['imgs'])
                predicts.extend(current_chunk['preds'])
                gt.extend(current_chunk['gts'])
    
    return predicts, gt, imgs

In [16]:
# Clear any existing cached memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
import gc
gc.collect()

# Configure DataLoader for minimal memory usage
test_loader = torch.utils.data.DataLoader(
    DataGenerator(source_path, charset_base, max_text_length, 'test', transform), 
    batch_size=1,  # Keep batch size at 1 for minimum memory usage
    shuffle=False, 
    num_workers=0,  # Single process loading
    pin_memory=False,  # Disable pin_memory to reduce memory usage
    persistent_workers=False,  # Disable persistent workers
    prefetch_factor=None  # Disable prefetching
)

In [17]:


predicts, gt, imgs = test(model, test_loader, max_text_length, tokenizer)


# the part below is causing error 
predicts = list(map(lambda x : x.replace('SOS','').replace('EOS',''),predicts))
gt = list(map(lambda x : x.replace('SOS','').replace('EOS',''),gt))

In [18]:
evaluate = evaluation.ocr_metrics(predicts=predicts,
                                  ground_truth=gt,)
 
print("Calculate Character Error Rate {}, Word Error Rate {} and Sequence Error Rate {}".format(evaluate[0],evaluate[1],evaluate[2]))

Calculate Character Error Rate 1.0, Word Error Rate 1.0 and Sequence Error Rate 1.0


In [19]:
import cv2
import numpy as np
from data import preproc as pp

def show_predictions(imgs, gt, predicts, max_display=10):
    """
    Display images with their ground truth and predicted text.

    Args:
        imgs: List of image tensors (C, H, W).
        gt: List of ground truth strings.
        predicts: List of predicted strings.
        max_display: Maximum number of images to display.
    """
    for i, item in enumerate(imgs[:max_display]):
        print("=" * 80)
        img = item.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
        # Convert to grayscale if image has 3 channels
        if img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = pp.adjust_to_see(img)
        cv2.imshow('Line', img)
        print("Ground truth:", gt[i])
        print("Prediction :", predicts[i], "\n")
        cv2.waitKey(0)
    cv2.destroyAllWindows()

In [None]:
# from src.utils.display_results import show_predictions

# Assuming you have run the test function and obtained these:
# predicts, gt, imgs = test(model, test_loader, max_text_length, tokenizer)

# Call the display function to show images with predictions and ground truth
show_predictions(imgs, gt, predicts, max_display=10)




In [None]:
! pip install matplotlib seaborn

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

training_results = joblib.load(output_path + '/training_results.joblib')

# Set Seaborn style
sns.set(style="whitegrid")


# Extract data
train_losses = [r[0] for r in training_results]
valid_losses = [r[1] for r in training_results]
durations = [r[2] for r in training_results]
epoch_nums = list(range(1, len(training_results) + 1))

# Plotting
fig, axs = plt.subplots(1, 3, figsize=(20, 6))

# Training Loss
sns.lineplot(x=epoch_nums, y=train_losses, marker='o', ax=axs[0], color='blue')
axs[0].set_title('Training Loss per Epoch')
axs[0].set_xlabel('Epoch')
axs[0].set_ylabel('Training Loss')

# Validation Loss
sns.lineplot(x=epoch_nums, y=valid_losses, marker='o', ax=axs[1], color='green')
axs[1].set_title('Validation Loss per Epoch')
axs[1].set_xlabel('Epoch')
axs[1].set_ylabel('Validation Loss')

# Time per Epoch
sns.lineplot(x=epoch_nums, y=durations, marker='o', ax=axs[2], color='red')
axs[2].set_title('Time per Epoch')
axs[2].set_xlabel('Epoch')
axs[2].set_ylabel('Time (seconds)')

plt.tight_layout()
plt.show()
