In [None]:
import os
import glob
import torch
import pprint
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image
from tqdm import tqdm

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from sklearn import metrics
from sklearn import preprocessing
from sklearn import model_selection

In [None]:
'''
    OCR DATASET CLASS
    Dataset Used = BanglaWriting
    Dataset Manual = https://arxiv.org/pdf/2011.07499.pdf
    Dataset Download Link - https://data.mendeley.com/datasets/r43wkvdk4w/1
'''

class OCRDataset(Dataset):
    """Dataset of word images paired with integer-encoded label sequences.

    Args:
        img_dir: sequence of image file paths (one per sample).
        targets: sequence of label sequences aligned with ``img_dir``;
            each entry is converted to a ``torch.long`` tensor as-is.

    Each item is a dict with:
        "images":  float tensor of shape (1, 64, 128) holding raw pixel values.
        "targets": long tensor built from ``targets[item]``.
    """

    def __init__(self, img_dir, targets):
        self.img_dir = img_dir
        self.targets = targets

    def __len__(self):
        return len(self.img_dir)

    def __getitem__(self, item):
        # The context manager closes the underlying file once the pixels have
        # been read; convert("L") guarantees a single channel so colour JPEGs
        # no longer produce a 4-D array that breaks the transpose below.
        with Image.open(self.img_dir[item]) as source:
            # PIL sizes are (width, height): the result is 128 wide, 64 high.
            resized = source.convert("L").resize((128, 64), resample=Image.BILINEAR)

        pixels = np.array(resized)                 # (H, W)
        pixels = np.stack((pixels,) * 1, axis=-1)  # (H, W, 1)

        # Reshape to the tensor layout expected by PyTorch: (C, H, W)
        pixels = np.transpose(pixels, (2, 0, 1)).astype(np.float32)

        return {
            "images": torch.tensor(pixels, dtype=torch.float),
            "targets": torch.tensor(self.targets[item], dtype=torch.long),
        }



In [None]:
# NEW MODEL


import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned projections.

    Queries and keys are projected to ``hidden_size // 2`` features, values
    keep ``hidden_size`` features. Scores are divided by the square root of
    the query/key projection width before the softmax.
    """

    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.attention_hidden_size = hidden_size // 2

        projection_width = self.attention_hidden_size
        self.query = nn.Linear(hidden_size, projection_width)
        self.key = nn.Linear(hidden_size, projection_width)
        self.value = nn.Linear(hidden_size, hidden_size)

        # Kept as a plain float32 tensor attribute (not a buffer); it is moved
        # to the input's device on every forward call.
        self.scale = torch.tensor([projection_width], dtype=torch.float32).sqrt()

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: [batch_size, query_len, hidden_size]
            key:   [batch_size, key_len, hidden_size]
            value: [batch_size, key_len, hidden_size]
            mask:  optional tensor broadcastable to the score matrix
                [batch_size, query_len, key_len]; positions where it equals 0
                are suppressed.

        Returns:
            (output, attention): the attended values
            [batch_size, query_len, hidden_size] and the attention weights
            [batch_size, query_len, key_len].
        """
        projected_q = self.query(query)
        projected_k = self.key(key)
        projected_v = self.value(value)

        divisor = self.scale.to(query.device)
        scores = torch.matmul(projected_q, projected_k.permute(0, 2, 1)) / divisor

        if mask is not None:
            # A large negative score becomes ~0 probability after the softmax.
            scores = scores.masked_fill(mask == 0, -1e10)

        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, projected_v), weights

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch-first sequence.

    Args:
        d_model: feature size of the sequences this module will be applied to.
            Both even and odd sizes are supported.
        max_len: number of positions pre-computed in the lookup table; inputs
            must not be longer than this.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so trim the frequencies to match instead of failing on the assignment.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model] so it broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` ([batch, seq_len, d_model]) plus the first ``seq_len`` encodings."""
        # Buffers never require grad, so no detach is needed here.
        return x + self.pe[:, :x.size(1)]

class ResidualBlock(nn.Module):
    """Two conv+batch-norm layers wrapped by a skip connection.

    The first convolution applies ``stride``; the second always uses stride 1.
    When the stride or the channel count changes, the skip path is a 1x1
    convolution followed by batch norm so both branches have the same shape;
    otherwise the input is added unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(ResidualBlock, self).__init__()
        same_padding = kernel_size // 2

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=same_padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding=same_padding)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential returns its input untouched.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))

        skip = self.shortcut(x)
        return F.relu(main + skip)

class SpatialTransformer(nn.Module):
    """Spatial transformer that predicts a 2x3 affine matrix and resamples the input.

    The localisation features are adaptively pooled to a fixed 16x4 grid
    before the regressor, so any input resolution is accepted. For 64x16
    inputs (the only size the previous hard-coded ``128 * 16 * 4`` flatten
    matched) the pooling is an identity operation. The pooling layer has no
    parameters, so the state dict layout is unchanged.

    The regressor starts out predicting the identity transform.
    """

    # Spatial size the localisation features are pooled to before flattening.
    _POOLED_SIZE = (16, 4)

    def __init__(self, in_channels):
        super(SpatialTransformer, self).__init__()
        self.localization = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=5, padding=2),
            nn.MaxPool2d(2),
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.MaxPool2d(2),
            nn.ReLU(True)
        )

        self.fc_loc = nn.Sequential(
            nn.Linear(128 * 16 * 4, 256),
            nn.ReLU(True),
            nn.Linear(256, 6)
        )

        # Makes the flattened feature size independent of the input resolution.
        self.pool = nn.AdaptiveAvgPool2d(self._POOLED_SIZE)

        # Initialize transformation parameters to the identity affine map
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, x):
        """Return ``x`` ([batch, C, H, W]) resampled under the predicted affine transform."""
        xs = self.pool(self.localization(x))
        # Flatten per sample; a ``view(-1, ...)`` with a mismatched size would
        # silently fold spatial positions into the batch dimension.
        xs = xs.flatten(1)
        theta = self.fc_loc(xs).view(-1, 2, 3)

        # align_corners=False is the library default; passing it explicitly
        # keeps the behaviour and avoids the default-change warning.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

class FeaturePyramidNetwork(nn.Module):
    """Merges feature maps level by level with 1x1 lateral and 3x3 output convs.

    Args:
        in_channels_list: channel count of the feature map expected at each level.
        out_channels: channel count produced at every level.

    ``forward`` accepts either
      * a list/tuple with one feature map per level, in the same order as
        ``in_channels_list``. Each level is added to the previous level's
        output resized to its own resolution, so ordering the levels from
        coarsest to finest gives a top-down pathway; or
      * a single tensor, which is fed to every level (the original calling
        convention; only valid when all levels expect the same channel count).
    """

    def __init__(self, in_channels_list, out_channels):
        super(FeaturePyramidNetwork, self).__init__()
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()

        for in_channels in in_channels_list:
            self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, 1))
            self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, 3, padding=1))

    def forward(self, x):
        """Return a list with one ``out_channels`` feature map per level.

        Raises:
            ValueError: if a sequence is given whose length differs from the
                number of levels.
        """
        if isinstance(x, torch.Tensor):
            features = [x] * len(self.inner_blocks)
        else:
            features = list(x)
            if len(features) != len(self.inner_blocks):
                raise ValueError(
                    f"expected {len(self.inner_blocks)} feature maps, got {len(features)}"
                )

        results = []
        for feature, inner_block, layer_block in zip(features, self.inner_blocks, self.layer_blocks):
            out = inner_block(feature)
            if results:
                # Bring the previous level to this level's resolution before merging.
                out = out + F.interpolate(results[-1], size=feature.shape[-2:], mode='nearest')
            out = layer_block(out)
            results.append(out)

        return results


class ComplexOCRModel(nn.Module):
    """STN -> residual CNN -> FPN -> BiGRU/attention stack -> CTC character head.

    Args:
        num_chars: number of character classes (the CTC blank is added on top).
        input_channels: channels of the input images.
        hidden_size: width of the recurrent/attention stack; must be even
            because each GRU direction gets ``hidden_size // 2`` units.

    ``forward(images, targets=None)`` returns ``(logits, loss)`` where
    ``logits`` is [seq_len, batch, num_chars + 1] and ``loss`` is the CTC loss
    (``None`` when no targets are given). Targets are expected to be padded
    with 0, which is also the blank index.

    Notes:
        * The sequence length equals the width of the finest pyramid level
          (input width / 8, e.g. 16 steps for 128-pixel-wide images). CTC
          needs it to be at least as long as every target.
        * ``additional_layers``, ``large_param1..5`` and the fourth attention
          module are kept for parameter-count compatibility but are not used
          in ``forward``.
    """

    # Positions inside ``conv_backbone`` whose outputs (256, 512 and 1024
    # channels respectively) are fed to the feature pyramid.
    _PYRAMID_TAPS = (7, 9, 11)

    def __init__(self, num_chars, input_channels=1, hidden_size=1024):
        super(ComplexOCRModel, self).__init__()

        if hidden_size % 2 != 0:
            raise ValueError("hidden_size must be even (split across two GRU directions)")

        self.hidden_size = hidden_size
        self.embedding_size = 512

        # Spatial Transformer for input preprocessing
        self.stn = SpatialTransformer(input_channels)

        # Convolutional Backbone - Much deeper and wider
        self.conv_backbone = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            # Residual blocks
            ResidualBlock(64, 128),
            ResidualBlock(128, 128),
            ResidualBlock(128, 256, stride=2),
            ResidualBlock(256, 256),
            ResidualBlock(256, 512, stride=2),
            ResidualBlock(512, 512),
            ResidualBlock(512, 1024, stride=2),
            ResidualBlock(1024, 1024),
        )

        # Feature Pyramid Network for multi-scale feature extraction.
        # Levels are listed coarsest -> finest so the last output is the
        # highest-resolution, fully merged map.
        self.fpn = FeaturePyramidNetwork([1024, 512, 256], 256)

        # Sequence encoder
        self.sequence_encoder = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, self.embedding_size, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(self.embedding_size),
        )

        # Positional encoding: sized for the sequence encoder's output
        # (embedding_size), which is what it is added to in forward().
        self.positional_encoding = PositionalEncoding(self.embedding_size)

        # Bidirectional GRU layers
        self.gru_layers = nn.ModuleList([
            nn.GRU(self.embedding_size, hidden_size//2, bidirectional=True, batch_first=True),
            nn.GRU(hidden_size, hidden_size//2, bidirectional=True, batch_first=True),
            nn.GRU(hidden_size, hidden_size//2, bidirectional=True, batch_first=True)
        ])

        # Self-attention mechanism
        self.self_attention = nn.ModuleList([
            Attention(hidden_size) for _ in range(4)
        ])

        # Fully connected layers for feature transformation
        self.fc_layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.Linear(hidden_size, hidden_size)
        ])

        # Dropout layers
        self.dropout = nn.Dropout(0.5)

        # Character prediction head
        self.char_pred = nn.Sequential(
            nn.Linear(hidden_size, 2048),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, num_chars + 1)  # +1 for blank in CTC
        )

        # Additional hidden layers to increase model size (not used in forward)
        self.additional_layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size) for _ in range(30)
        ])

        # Large parameter tensors to increase model size (not used in forward)
        self.large_param1 = nn.Parameter(torch.randn(2000, 2000))
        self.large_param2 = nn.Parameter(torch.randn(2000, 2000))
        self.large_param3 = nn.Parameter(torch.randn(2000, 2000))
        self.large_param4 = nn.Parameter(torch.randn(2000, 2000))
        self.large_param5 = nn.Parameter(torch.randn(2000, 2000))

    @staticmethod
    def _target_lengths(targets):
        """Count non-padding (non-zero) labels per row, with a minimum of 1.

        Returns a CPU int32 tensor of shape [batch].
        """
        lengths = (targets != 0).sum(dim=1).clamp(min=1)
        return lengths.to(device="cpu", dtype=torch.int32)

    def _pyramid_inputs(self, x):
        """Run the backbone and collect the tapped feature maps, finest first."""
        tapped = []
        for index, layer in enumerate(self.conv_backbone):
            x = layer(x)
            if index in self._PYRAMID_TAPS:
                tapped.append(x)
        return tapped

    def forward(self, images, targets=None):
        bs = images.size(0)

        # Apply spatial transformer
        x = self.stn(images)

        # Extract multi-scale features through the convolutional backbone
        tapped = self._pyramid_inputs(x)

        # Apply FPN coarsest -> finest and keep the finest merged level for
        # sequence modeling (it has the longest width, i.e. most time steps).
        fpn_features = self.fpn(tapped[::-1])
        x = fpn_features[-1]

        # Prepare for sequence modeling
        x = self.sequence_encoder(x)

        # Collapse height dimension for sequence modeling
        x = x.mean(dim=2)  # Average pooling over height
        x = x.permute(0, 2, 1)  # [bs, seq_len, channels]

        # Apply positional encoding
        x = self.positional_encoding(x)

        # Apply GRU layers with residual connections, each followed by
        # self-attention and a fully connected transformation.
        gru_out = x
        for gru, attention, fc in zip(self.gru_layers, self.self_attention, self.fc_layers):
            residual = gru_out
            gru_out, _ = gru(gru_out)
            # The first GRU changes the feature size (embedding -> hidden), so
            # the skip connection is only valid when the shapes agree.
            if residual.shape == gru_out.shape:
                gru_out = gru_out + residual

            attn_out, _ = attention(gru_out, gru_out, gru_out)
            gru_out = gru_out + attn_out

            gru_out = gru_out + fc(gru_out)

            gru_out = self.dropout(gru_out)

        # Apply character prediction head
        x = self.char_pred(gru_out)

        # Prepare for CTC loss
        x = x.permute(1, 0, 2)  # [seq_len, bs, num_classes]

        if targets is not None:
            log_probs = F.log_softmax(x, 2)
            input_lengths = torch.full(
                size=(bs,), fill_value=log_probs.size(0), dtype=torch.int32
            )

            # Actual lengths of targets (excluding padding zeros), computed in
            # one vectorised pass instead of a per-element Python loop.
            target_lengths = self._target_lengths(targets)

            # Use CTC loss
            loss = nn.CTCLoss(blank=0, reduction='mean')(
                log_probs, targets, input_lengths, target_lengths
            )
            return x, loss

        return x, None

    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [None]:
# defining the model OLD MODEL

class OCRModel(nn.Module):
    """Small CRNN: two conv/pool stages, a per-column linear layer, a BiGRU and a CTC head.

    Args:
        num_chars: number of character classes; one extra output is added for
            the CTC blank (index 0).

    The linear layer expects 64 channels x 16 rows per image column, i.e.
    input images that are 64 pixels high.

    ``forward(images, targets=None)`` returns ``(logits, loss)`` where
    ``logits`` is [seq_len, batch, num_chars + 1] and ``loss`` is the CTC loss
    or ``None`` when no targets are given. Targets are expected to be padded
    with 0 (the blank index); padding is excluded from the loss.
    """

    def __init__(self, num_chars):
        super(OCRModel, self).__init__()
        self.conv_1 = nn.Conv2d(1, 128, kernel_size=(3, 6), padding=(1, 1))
        self.pool_1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv_2 = nn.Conv2d(128, 64, kernel_size=(3, 6), padding=(1, 1))
        self.pool_2 = nn.MaxPool2d(kernel_size=(2, 2))

        self.linear_1 = nn.Linear(1024, 64) # 1024 = 64*16
        self.drop_1 = nn.Dropout(0.2)
        self.gru = nn.GRU(64, 32, bidirectional=True, num_layers=2, dropout=0.25, batch_first=True)
        self.output = nn.Linear(64, num_chars + 1)

    def forward(self, images, targets=None):
        bs = images.size(0)

        x = F.relu(self.conv_1(images))
        x = self.pool_1(x)
        x = F.relu(self.conv_2(x))
        x = self.pool_2(x)  # (bs, c, h, w), e.g. [8, 64, 16, 29] for 64x128 inputs

        # Treat every image column as one time step: (bs, w, c, h) -> (bs, w, c*h)
        x = x.permute(0, 3, 1, 2)
        x = x.view(bs, x.size(1), -1)
        x = F.relu(self.linear_1(x))
        x = self.drop_1(x)

        x, _ = self.gru(x)
        x = self.output(x)

        # CTC expects (seq_len, bs, num_classes)
        x = x.permute(1, 0, 2)

        if targets is not None:
            log_probs = F.log_softmax(x, 2).to(torch.float64)
            input_lengths = torch.full(
                size=(bs,), fill_value=log_probs.size(0), dtype=torch.int32
            )
            # Count only real labels: 0 is both the padding value and the CTC
            # blank, so using the padded width would score padding as labels.
            target_lengths = (targets != 0).sum(dim=1).clamp(min=1)
            target_lengths = target_lengths.to(device="cpu", dtype=torch.int32)
            loss = nn.CTCLoss(blank=0)(
                log_probs, targets, input_lengths, target_lengths
            )
            return x, loss

        return x, None


#
if __name__ == "__main__":
    # Smoke test: one forward pass with the loss on random data.
    cm = OCRModel(115)
    img = torch.rand((32, 1, 64, 128))
    # CTC targets must be integer class indices; 0 is reserved for the blank,
    # so draw labels from 1..115 instead of passing random floats.
    x, _ = cm(img, torch.randint(1, 116, (32, 15)))

In [None]:
def remove_duplicates(x):
    """Collapse every run of identical consecutive characters in ``x`` to one.

    Inputs shorter than two characters are returned unchanged.
    """
    if len(x) < 2:
        return x

    collapsed = ""
    for symbol in x:
        # Skip a symbol only when it repeats the one just kept.
        if collapsed and symbol == collapsed[-1]:
            continue
        collapsed += symbol
    return collapsed


def decode_predictions(preds, encoder):
    """Greedy CTC decoding of a batch of model outputs into strings.

    Args:
        preds: tensor of shape [seq_len, batch, num_classes] with per-step
            class scores; class 0 is the CTC blank and class ``k`` (k >= 1)
            corresponds to ``encoder`` label ``k - 1``.
        encoder: object with an ``inverse_transform`` method mapping a
            sequence of label indices to characters.

    Returns:
        A list with one decoded string per batch element. Repeated
        predictions are collapsed first and blanks are dropped afterwards, so
        genuine double letters separated by a blank are preserved.
    """
    preds = preds.permute(1, 0, 2)
    # The arg-max of the raw scores equals the arg-max of their softmax.
    best_path = torch.argmax(preds, 2).detach().cpu().numpy()

    word_preds = []
    for sequence in best_path:
        # 1) collapse runs of the same class index
        keep = np.ones(len(sequence), dtype=bool)
        keep[1:] = sequence[1:] != sequence[:-1]
        collapsed = sequence[keep]
        # 2) drop blanks and shift back to the encoder's 0-based labels
        labels = collapsed[collapsed != 0] - 1
        if labels.size:
            # One call per word instead of one per time step.
            word_preds.append("".join(encoder.inverse_transform(labels)))
        else:
            word_preds.append("")
    return word_preds


In [None]:
# define train and test functions

def train_fn(model, data_loader, optimizer):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Every batch is a dict of tensors that is moved to the GPU when one is
    available (CPU otherwise) and unpacked as keyword arguments into
    ``model``, which must return ``(outputs, loss)``.
    """
    model.train()
    running_loss = 0
    device = "cuda" if torch.cuda.is_available() else "cpu"
    progress = tqdm(data_loader, total=len(data_loader))

    for batch in progress:
        # Replace each tensor in the batch dict by its copy on the target device.
        batch.update({name: tensor.to(device) for name, tensor in batch.items()})

        optimizer.zero_grad()
        _, loss = model(**batch)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(data_loader)


def eval_fn(model, data_loader):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Every batch is a dict of tensors that is moved to the GPU when one is
    available (CPU otherwise) and unpacked as keyword arguments into
    ``model``, which must return ``(outputs, loss)``.

    Returns:
        (predictions, mean_loss): the list of per-batch model outputs, in
        loader order, and the loss averaged over the batches.
    """
    model.eval()
    running_loss = 0
    collected = []
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with torch.no_grad():
        progress = tqdm(data_loader, total=len(data_loader))
        for batch in progress:
            # Replace each tensor in the batch dict by its copy on the target device.
            batch.update({name: tensor.to(device) for name, tensor in batch.items()})

            batch_preds, loss = model(**batch)
            running_loss += loss.item()
            collected.append(batch_preds)

    mean_loss = running_loss / len(data_loader)
    return collected, mean_loss


In [None]:
# Directory that train() searches for word images (it globs '*jpg' inside it).
filepath = './img/' 

In [44]:
def train():
    """Train ``OCRModel`` on the images under ``filepath`` and report per-epoch metrics.

    The label of every image is the part of its file name before the first
    space. Characters are label-encoded and shifted by one so that 0 stays
    free for padding / the CTC blank. The data is split 80/20 into train and
    test sets; after every epoch the train loss, test loss and exact-match
    word accuracy on the test set are printed.

    Raises:
        FileNotFoundError: if no '*jpg' files are found under ``filepath``.
    """
    print('train function is running')
    # Sorted so the fixed-seed split below is reproducible across file systems.
    image_files = sorted(glob.glob(os.path.join(filepath, '*jpg')))
    if not image_files:
        raise FileNotFoundError(f"no '*jpg' files found in {filepath!r}")

    # Use the file name itself: splitting the whole path on "/" picks the
    # directory name whenever ``filepath`` has more than one component.
    targets_orig = [os.path.basename(x).split(" ")[0] for x in image_files]
    targets = [[c for c in x] for x in targets_orig]
    targets_flat = [c for clist in targets for c in clist]

    lbl_enc = preprocessing.LabelEncoder()
    lbl_enc.fit(targets_flat)

    # Build the zero-padded label matrix directly: labels are shifted by +1 so
    # 0 marks padding, and every row is padded up to the longest label.
    # (Stacking ragged per-word arrays with np.array fails on current NumPy.)
    maxlen = max(len(chars) for chars in targets)
    targets_enc = np.zeros((len(targets), maxlen), dtype=np.int64)
    for row, chars in enumerate(targets):
        if chars:
            targets_enc[row, :len(chars)] = lbl_enc.transform(chars) + 1

    print("Total unique classes/characters:", len(lbl_enc.classes_))

    # divide into train test
    (
        train_imgs,
        test_imgs,
        train_targets,
        test_targets,
        train_orig_targets,
        test_orig_targets,
    ) = model_selection.train_test_split(
        image_files, targets_enc, targets_orig, test_size=0.2, random_state=42
    )

    # loading images and their corresponding labels to train and test dataset
    train_dataset = OCRDataset(img_dir=train_imgs, targets=train_targets)
    test_dataset = OCRDataset(img_dir=test_imgs, targets=test_targets)

    # defining the data loaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    # The test loader must keep its order: predictions are compared
    # position-by-position with ``test_orig_targets`` below.
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # model goes here
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = OCRModel(len(lbl_enc.classes_))
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # ``verbose`` is deprecated in recent PyTorch releases; learning-rate
    # changes are reported manually after each scheduler step instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.8, patience=5
    )

    # define number of epoch and start training
    num_epoch = 10
    for epoch in range(num_epoch):
        train_loss = train_fn(model, train_loader, optimizer)
        valid_preds, test_loss = eval_fn(model, test_loader)
        valid_word_preds = []

        for vp in valid_preds:
            current_preds = decode_predictions(vp, lbl_enc)
            valid_word_preds.extend(current_preds)
        # Drop any "°" blank placeholders left in the decoded words so they
        # are compared as plain text (assumes "°" is not a label character).
        valid_word_preds = [word.replace("°", "") for word in valid_word_preds]

        combined = list(zip(test_orig_targets, valid_word_preds))
        print(combined[:10])
        # Exact-match accuracy against the unmodified ground-truth words.
        accuracy = metrics.accuracy_score(test_orig_targets, valid_word_preds)
        pprint.pprint(combined[6:11])
        print(
            f"Epoch={epoch}, Train Loss={train_loss}, Test Loss={test_loss} Accuracy={accuracy}"
        )

        previous_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(test_loss)
        current_lr = optimizer.param_groups[0]["lr"]
        if current_lr != previous_lr:
            print(f"Learning rate reduced from {previous_lr} to {current_lr}")

train()

 58%|█████▊    | 226/392 [00:06<00:04, 33.60it/s]


KeyboardInterrupt: 

In [None]:

# Visualize train data and its shape

# import matplotlib.pyplot as plt
# import numpy as np
# %matplotlib inline

# npimg = train_dataset[200]['images'].np()
# print(npimg.shape) # print current shape (torch style)

# # change the orientation of the image to display
# npimg = np.transpose(npimg, (1, 2, 0)).astype(np.float32)
# print(npimg.shape)

# plt.imshow(npimg)
