In [None]:
import torch
import numpy as np
import torch.nn as nn
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from PIL import Image,ImageOps
from torchvision import transforms
import os
from torch.utils.data import Dataset,DataLoader,ConcatDataset,SubsetRandomSampler
import pandas as pd
import glob
import uuid
import random
import cv2 
import albumentations as A
from albumentations.augmentations.geometric.rotate import Rotate
from albumentations.augmentations.geometric.transforms import ElasticTransform
from itertools import groupby

import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init
from torchmetrics import Recall, Precision

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Every glyph (and every spliced composite) is resized to 224x224.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
    ]
)

# Augmentation pipeline: always apply an elastic warp, and rotate by up to
# +/-20 degrees 60% of the time.
aug_transform = A.Compose(
    [
        ElasticTransform(
            p=1.0,
            border_mode=cv2.BORDER_REPLICATE,
            approximate=True,
            same_dxdy=True,
        ),
        Rotate(limit=20, p=0.6),
    ]
)

In [None]:
# Digit metadata. Column 1 holds the image path relative to the dataset root;
# the "label" column holds the digit class.
df = pd.read_csv(filepath_or_buffer="../input/csv-file/filtered.csv")

In [None]:
# Demo: splice the left part of one glyph onto the right part of another,
# resize the result back to 224x224, and run it through the augmentation pipeline.
# The source files are opened with `with` so their handles are closed right away.
# transform() (a Resize) returns a new image, so it does not depend on the open file.
with Image.open("../input/english-handwritten-characters-dataset/Img/img002-032.png") as src:
    img = transform(src)
with Image.open("../input/english-handwritten-characters-dataset/Img/img003-010.png") as src:
    img2 = transform(src)

img = np.asarray(img)
img2 = np.asarray(img2)
# Keep columns [0, 170) of the first glyph and [45, 224) of the second.
a = np.concatenate([img[:, :170], img2[:, 45:]], axis=1)
a = np.asarray(transform(Image.fromarray(a)))
aug = aug_transform(image=a)
aug_image = aug["image"]

In [None]:
# Display the augmented composite from the demo cell.
# Pixel ticks carry no information, so they are hidden.
plt.imshow(aug_image)
plt.title("Augmented spliced composite")
plt.axis("off")
plt.show()

In [None]:
# Preview the first five metadata rows.
df.iloc[:5]

In [None]:
# Total number of elements in the resized composite array from the demo cell.
np.size(a)

In [None]:
# Map each digit character "0".."9" to the image paths (CSV column 1) whose label
# is that digit. create_numbers() looks these lists up by character.
# A comprehension replaces ten copy-pasted lines. Digits with no rows still get an
# empty list, exactly as before.
d = {str(digit): df[df["label"] == digit].iloc[:, 1].tolist() for digit in range(10)}

In [None]:
# Root directory of the handwritten-characters dataset. The image paths from the
# CSV are appended to it by plain string concatenation, so the trailing "/" is required.
root_path="../input/english-handwritten-characters-dataset/"

In [None]:
def number_list(root, df, transform, n_aug=5, augment=None):
    """Load every single-digit image listed in ``df`` and create augmented copies.

    Args:
        root: Dataset root prefix, prepended verbatim to each image path.
        df: DataFrame whose column at position 1 is the image path relative to
            ``root`` and whose column at position 2 is the label.
        transform: Callable applied to each opened PIL image (in this notebook,
            the 224x224 resize).
        n_aug: Number of augmented copies generated per image (default 5).
        augment: Callable used as ``augment(image=array)["image"]``. Defaults to
            the notebook-level ``aug_transform``.

    Returns:
        ``(num_list, aug_num_list)``.
        ``num_list`` holds one ``(PIL image, label)`` pair per row, in row order.
        ``aug_num_list`` holds ``n_aug`` augmented pairs per row, in the same
        order, so row ``k`` owns ``aug_num_list[k * n_aug:(k + 1) * n_aug]``.
    """
    if augment is None:
        augment = aug_transform
    num_list = []
    aug_num_list = []
    # Positional columns 1 and 2 give the same values as the old iterrows()
    # [1][1] / [1][2] access, without the deprecated Series positional fallback.
    for rel_path, label in df.iloc[:, [1, 2]].itertuples(index=False, name=None):
        # Close the file as soon as the resized copy exists.
        with Image.open(root + rel_path) as src:
            image = transform(src)
        image_copy = np.asarray(image)

        for _ in range(n_aug):
            aug_image = Image.fromarray(augment(image=image_copy)["image"])
            aug_num_list.append((aug_image, label))

        num_list.append((image, label))

    return num_list, aug_num_list
    
    

In [None]:
# Single-digit samples straight from the CSV, plus their augmented copies.
num_list_1, aug_num_1 = number_list(root=root_path, df=df, transform=transform)
print(*map(len, (num_list_1, aug_num_1)))

In [None]:
def _splice_digits(root, d, digits, transform):
    """Return an array that concatenates one random glyph crop per character of ``digits``."""
    crops = []
    last = len(digits) - 1
    for pos, ch in enumerate(digits):
        # Each character gets its own random sample, so repeated digits
        # (e.g. the two zeros of "100") are drawn independently.
        with Image.open(root + np.random.choice(d[ch])) as src:
            glyph = np.asarray(transform(src))
        if last == 0:
            crops.append(glyph)              # single digit: keep the whole glyph
        elif pos == 0:
            crops.append(glyph[:, :170])     # first glyph: drop the right margin
        elif pos == last:
            crops.append(glyph[:, 45:])      # last glyph: drop the left margin
        else:
            crops.append(glyph[:, 45:170])   # middle glyph: drop both margins
    return np.concatenate(crops, axis=1)


def create_numbers(root, d, transform, numbers=range(10, 101), samples_per_number=55,
                   n_aug=5, augment=None):
    """Synthesize multi-digit images by splicing randomly chosen single-digit glyphs.

    For every value in ``numbers`` (default 10..100), ``samples_per_number``
    composites are built. For each character of ``str(value)``, a path is drawn from
    ``d[char]``, opened under ``root``, resized with ``transform`` and cropped. The
    crops are concatenated left to right, and the composite is resized again with
    ``transform``. The crop columns (170 / 45) are tuned for the 224-px width
    produced by the notebook's ``transform``.

    Args:
        root: Dataset root prefix, prepended verbatim to each path.
        d: Dict mapping digit characters "0".."9" to lists of image paths.
        transform: Callable applied to every opened image and to each composite.
        numbers: Iterable of integer labels to synthesize.
        samples_per_number: Composites generated per label (default 55).
        n_aug: Augmented copies generated per composite (default 5).
        augment: Callable used as ``augment(image=array)["image"]``. Defaults to
            the notebook-level ``aug_transform``.

    Returns:
        ``(num_list, aug_num_list)``.
        ``num_list`` holds ``(PIL image, int label)`` pairs.
        ``aug_num_list`` holds ``n_aug`` augmented pairs per composite, in the
        same order, so composite ``k`` owns ``aug_num_list[k * n_aug:(k + 1) * n_aug]``.
    """
    if augment is None:
        augment = aug_transform
    num_list = []
    aug_num_list = []
    for value in numbers:
        digits = str(value)
        for _ in range(samples_per_number):
            composite = transform(Image.fromarray(_splice_digits(root, d, digits, transform)))
            composite_arr = np.asarray(composite)

            for _ in range(n_aug):
                aug_image = Image.fromarray(augment(image=composite_arr)["image"])
                aug_num_list.append((aug_image, value))

            num_list.append((composite, value))

    return num_list, aug_num_list
            

In [None]:
# Synthesized multi-digit samples (10..100), plus their augmented copies.
num_list_2, aug_num_2 = create_numbers(root=root_path, d=d, transform=transform)
print(*map(len, (num_list_2, aug_num_2)))

In [None]:
# Pool single- and multi-digit samples. Originals and augmentations stay in
# separate lists, each in the same relative order.
num_list = [*num_list_1, *num_list_2]
aug_list = [*aug_num_1, *aug_num_2]

print(*map(len, (num_list, aug_list)))

In [None]:
# Hold out VAL_TEST_SIZE clean (non-augmented) samples for validation and test.
# number_list() and create_numbers() emit AUG_PER_IMAGE augmentations per original,
# in the same order as the originals. After the pooling cell, original k therefore
# owns aug_list[k * AUG_PER_IMAGE:(k + 1) * AUG_PER_IMAGE]. Those copies are dropped
# together with their source image, so augmented views of held-out samples cannot
# leak into the training set that the next cell builds from num_list + aug_list.
VAL_TEST_SIZE = 3333
AUG_PER_IMAGE = 5
assert len(aug_list) == AUG_PER_IMAGE * len(num_list), \
    "aug_list is no longer aligned with num_list; re-run the pooling cell"
assert len(num_list) >= VAL_TEST_SIZE, "not enough samples for the hold-out split"

held_out_idx = random.sample(range(len(num_list)), VAL_TEST_SIZE)  # random order
held_out = set(held_out_idx)
val_test_list = [num_list[k] for k in held_out_idx]
aug_list = [item for k, item in enumerate(aug_list) if k // AUG_PER_IMAGE not in held_out]
num_list = [item for k, item in enumerate(num_list) if k not in held_out]

In [None]:
# Training pool: the remaining clean samples followed by the augmented samples.
num_list = [*num_list, *aug_list]

In [None]:
class Numbers(Dataset):
    """Dataset yielding ``(image, padded digit labels, label length)`` for CTC training.

    Each entry of ``image_list`` is a ``(PIL image, int label)`` pair. Images are
    converted to grayscale, Gaussian-blurred (5x5) and binarised with Otsu's
    threshold, then passed through ``transform``. The label is split into its
    decimal digits and right-padded with -1 up to ``max_length``.
    """

    def __init__(self, image_list, max_length, transform):
        self.image_list = image_list
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        entry = self.image_list[index]
        image, label = entry[0], entry[1]

        gray = np.asarray(ImageOps.grayscale(image))
        blur = cv2.GaussianBlur(gray, (5, 5), 0)
        # Use maxval=255 so the binary image spans the full uint8 range.
        # ToTensor() scales uint8 by 1/255, so the result is {0.0, 1.0}.
        # With maxval=1 the foreground would end up at 1/255.
        _, binary = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        image = Image.fromarray(binary)

        digits = str(label)
        if len(digits) > self.max_length:
            # Truncating here would desynchronise y_label from the reported length.
            raise ValueError(
                f"label {label!r} has {len(digits)} digits but max_length is {self.max_length}"
            )
        # Digits first, then -1 padding (ignored by CTC beyond the target length).
        y_label = torch.tensor([int(ch) for ch in digits] + [-1] * (self.max_length - len(digits)))

        if self.transform:
            image = self.transform(image)

        return (image, y_label, len(digits))

In [None]:
# Per-item transform for the datasets: PIL image -> float tensor.
transform2 = transforms.Compose(
    [transforms.ToTensor()]
)
# Labels are at most 3 digits long (up to 100).
train_dataset = Numbers(image_list=num_list, max_length=3, transform=transform2)
val_test_dataset = Numbers(image_list=val_test_list, max_length=3, transform=transform2)

In [None]:
# Split the held-out pool 2:1 into validation and test sets.
# The sizes are derived from the pool, so they always sum to len(val_test_dataset)
# (2222/1111 for 3333 samples). The seeded generator makes the split repeatable.
SPLIT_SEED = 0
n_val = (2 * len(val_test_dataset)) // 3
val_set, test_set = torch.utils.data.random_split(
    val_test_dataset,
    [n_val, len(val_test_dataset) - n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
# Evaluation loaders don't need shuffling; a fixed order keeps metrics reproducible.
test_loader = DataLoader(dataset=test_set, batch_size=64, shuffle=False)
val_loader = DataLoader(dataset=val_set, batch_size=64, shuffle=False)

In [None]:
# Layer recipes for the VGG-style feature extractor.
# An int is a 3x3 conv with that many output channels; "M" is a 2x2 max-pool.
# Each stage is written as (channels, number of convs) and followed by one pool.
VGG_types = {
    name: [layer for channels, reps in stages for layer in [channels] * reps + ["M"]]
    for name, stages in {
        "VGG11": ((64, 1), (128, 1), (256, 2), (512, 2), (512, 2)),
        "VGG13": ((64, 2), (128, 2), (256, 2), (512, 2), (512, 2)),
        "VGG16": ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3)),
        "VGG19": ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4)),
    }.items()
}

In [None]:
# CRNN: VGG-style conv stack -> bidirectional GRU over feature-map columns ->
# per-step log-probabilities over 10 digits + 1 CTC blank.
class CRNN(nn.Module):
    """Convolutional-recurrent network producing CTC log-probabilities.

    Args:
        device: Device on which ``reset_hidden`` allocates the GRU initial state.
        cnn_type: Layer recipe (ints are 3x3 conv output channels, "M" is a 2x2
            max-pool), e.g. an entry of ``VGG_types``.
        image_size: Side length of the square single-channel input (default 224).

    ``forward`` maps ``(N, 1, image_size, image_size)`` to log-probs of shape
    ``(N, postconv_width, 11)``. Each time step is one column of the final feature map.
    """

    def __init__(self, device, cnn_type, image_size=224):
        super(CRNN, self).__init__()
        self.in_channels = 1
        self.device = device
        self.num_classes = 10 + 1  # digits 0-9 plus the CTC blank
        self.image_H = 28  # NOTE(review): unused in this notebook; kept for compatibility

        self.vgg = self.create_conv_layers(cnn_type)

        # Each "M" halves H and W with flooring, so the final map is image_size // 2**pools.
        num_pools = sum(1 for x in cnn_type if x == "M")
        conv_channels = [x for x in cnn_type if isinstance(x, int)]
        self.postconv_height = image_size // 2 ** num_pools
        self.postconv_width = image_size // 2 ** num_pools
        self.postconv_channels = conv_channels[-1] if conv_channels else self.in_channels
        # One GRU step consumes one whole feature-map column (height x channels values),
        # so the sequence length equals postconv_width. The previous hard-coded
        # height*64 split every column into several steps (56 steps for VGG19 at 224 px).
        self.gru_input_size = self.postconv_height * self.postconv_channels
        self.gru_hidden_size = 128
        self.gru_num_layers = 2
        self.gru_h = None
        self.gru_cell = None

        self.gru = nn.GRU(self.gru_input_size, self.gru_hidden_size, self.gru_num_layers,
                          batch_first=True, bidirectional=True)

        self.fc = nn.Linear(self.gru_hidden_size * 2, self.num_classes)

    def create_conv_layers(self, architecture):
        """Build Conv(3x3) + InstanceNorm + LeakyReLU blocks and 2x2 max-pools from a recipe."""
        layers = []
        in_channels = self.in_channels

        for x in architecture:
            if isinstance(x, int):
                layers += [
                    nn.Conv2d(in_channels, x, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                    nn.InstanceNorm2d(x),
                    nn.LeakyReLU(0.01),
                ]
                in_channels = x
            elif x == "M":
                layers += [nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))]

        return nn.Sequential(*layers)

    def forward(self, x):
        batch_size = x.shape[0]
        out = self.vgg(x)                  # (N, C, H', W')
        out = out.permute(0, 3, 2, 1)      # (N, W', H', C)
        # Flatten each column into one feature vector, giving a sequence of length W'.
        out = out.reshape(batch_size, -1, self.gru_input_size)

        out, gru_h = self.gru(out, self.gru_h)
        # Keep the final state without its graph; callers reset it per batch via reset_hidden().
        self.gru_h = gru_h.detach()

        # Log-softmax over the class dimension (last axis): (N, T, num_classes).
        return F.log_softmax(self.fc(out), dim=2)

    def reset_hidden(self, batch_size):
        """Set the GRU initial state to zeros for a batch of ``batch_size`` samples."""
        self.gru_h = torch.zeros(self.gru_num_layers * 2, batch_size, self.gru_hidden_size,
                                 device=self.device)

crnn = CRNN(device=device, cnn_type=VGG_types["VGG19"]).to(device)
# The blank is class index 10 (the 11th output). zero_infinity zeroes infinite
# losses and their gradients instead of propagating them.
criterion = nn.CTCLoss(blank=10, reduction="mean", zero_infinity=True)
optimizer = torch.optim.Adam(params=crnn.parameters(), lr=3e-4)
PATH = "best_model_vgg.pth"  # checkpoint written whenever validation improves

In [None]:
# a = torch.randn(1,1,224,224)
# out = crnn(a)
# out.shape

In [None]:
torch.cuda.empty_cache()
# Greedy decoding must drop the same class the CTC loss treats as blank.
BLANK_LABEL = criterion.blank
NUM_EPOCHS = 20
best_val = -1

for e in range(NUM_EPOCHS):
    # ---- training ----
    num_batches = 0
    total_loss = 0.0

    crnn.train()
    for inputs, labels, label_size in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        batch_size = len(inputs)
        crnn.reset_hidden(batch_size)

        optimizer.zero_grad()
        # CTCLoss expects log-probs shaped (T, N, C).
        y_pred = crnn(inputs).permute(1, 0, 2)

        # Take T from the output itself, so the loss covers exactly the time
        # steps that greedy decoding reads at evaluation time.
        input_lengths = torch.full((batch_size,), y_pred.shape[0], dtype=torch.int32)
        target_lengths = torch.as_tensor(label_size, dtype=torch.int32)

        loss = criterion(y_pred, labels, input_lengths, target_lengths)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        num_batches += 1

    # ---- validation ----
    crnn.eval()
    val_correct = 0
    val_total = 0
    precision_sum = 0.0
    recall_sum = 0.0
    # NOTE(review): newer torchmetrics releases require a `task=` argument here;
    # confirm against the installed version.
    prec = Precision()
    rec = Recall()
    with torch.no_grad():
        for inputs, labels, _label_size in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = len(inputs)
            crnn.reset_hidden(batch_size)

            y_pred = crnn(inputs).permute(1, 0, 2)   # (T, N, C)
            max_index = y_pred.argmax(dim=2)         # (T, N): best class per step

            for i in range(batch_size):
                raw_prediction = max_index[:, i].cpu().tolist()
                # Greedy CTC decoding: collapse repeated symbols, then drop blanks.
                prediction = torch.IntTensor([c for c, _ in groupby(raw_prediction) if c != BLANK_LABEL])
                ground_truth = labels[i].cpu()
                ground_truth = ground_truth[ground_truth != -1]  # strip -1 padding

                if len(prediction) == len(ground_truth) and torch.all(prediction.eq(ground_truth)):
                    val_correct += 1
                val_total += 1

                if len(prediction) == 0:
                    # All-blank output: int("") would raise, so count it as a miss (adds 0).
                    continue
                preds = torch.tensor([int("".join(str(k) for k in prediction.tolist()))])
                g_truth = torch.tensor([int("".join(str(k) for k in ground_truth.tolist()))])
                precision_sum += prec(preds, g_truth).item()
                recall_sum += rec(preds, g_truth).item()

    if val_correct > best_val:
        best_val = val_correct
        torch.save(crnn.state_dict(), PATH)
        print("Best val correct:", best_val)
        print("SAVING MODEL")

    # Per-sample averages: every validation sample weighs the same, including those
    # in a short last batch. The accumulators are reset every epoch.
    denom = max(val_total, 1)
    print("Epoch:{e} val_total:{total} val_correct:{correct}".format(e=e, total=val_total, correct=val_correct))
    print("Epoch:{e} precision_val:{prec_val} recall_val:{rec_val}".format(e=e, prec_val=precision_sum / denom, rec_val=recall_sum / denom))
    print("Epoch:{e} train_loss:{Loss}".format(e=e, Loss=total_loss / max(num_batches, 1)))
    

# **Testing**

In [None]:
# Evaluate on the test split using the checkpoint that the training loop saved
# at its highest validation exact-match count, not the last-epoch weights.
BLANK_LABEL = criterion.blank
if os.path.exists(PATH):
    crnn.load_state_dict(torch.load(PATH, map_location=device))

crnn.eval()
test_correct = 0
test_total = 0
precision_sum = 0.0
recall_sum = 0.0
# NOTE(review): newer torchmetrics releases require a `task=` argument here;
# confirm against the installed version.
prec = Precision()
rec = Recall()
with torch.no_grad():
    for inputs, labels, _label_size in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        batch_size = len(inputs)
        crnn.reset_hidden(batch_size)

        y_pred = crnn(inputs).permute(1, 0, 2)   # (T, N, C)
        max_index = y_pred.argmax(dim=2)         # (T, N): best class per step

        for i in range(batch_size):
            raw_prediction = max_index[:, i].cpu().tolist()
            # Greedy CTC decoding: collapse repeated symbols, then drop blanks.
            prediction = torch.IntTensor([c for c, _ in groupby(raw_prediction) if c != BLANK_LABEL])
            ground_truth = labels[i].cpu()
            ground_truth = ground_truth[ground_truth != -1]  # strip -1 padding

            if len(prediction) == len(ground_truth) and torch.all(prediction.eq(ground_truth)):
                test_correct += 1
            test_total += 1

            if len(prediction) == 0:
                # All-blank output: int("") would raise, so count it as a miss (adds 0).
                continue
            preds = torch.tensor([int("".join(str(k) for k in prediction.tolist()))])
            g_truth = torch.tensor([int("".join(str(k) for k in ground_truth.tolist()))])
            precision_sum += prec(preds, g_truth).item()
            recall_sum += rec(preds, g_truth).item()

# Per-sample averages over the whole test split.
denom = max(test_total, 1)
print("Test total:{total} test_correct:{correct}".format(total=test_total, correct=test_correct))
print("Test precision:{p} recall:{r}".format(p=precision_sum / denom, r=recall_sum / denom))