In [1]:
# Mount Google Drive inside the Colab runtime so files under /content/drive are readable.
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# NOTE(review): later cells use the older Lightning API (`@pl.data_loader`,
# `Trainer(max_nb_epochs=..., early_stop_callback=...)`); an unpinned install may
# pull a release that no longer provides them -- consider pinning the version.
!pip install pytorch_lightning



In [0]:
# Root folder (on the mounted Drive) holding the parquet shards and train.csv.
maindir = "/content/drive/My Drive/bengali/bengali_data/"

In [0]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline


import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.optim import SGD, Adam
import torchvision.transforms as transforms

import pytorch_lightning as pl
from pytorch_lightning import Trainer

import torchvision.models as models

import albumentations as albu
import cv2

In [0]:
# Number of classes for each of the three prediction heads.
# The trailing "#n" appears to be the label's column position in the label table
# (the dataset takes labels from columns 1..3) -- TODO confirm.
cl_grapheme_root       = 168 #1
cl_vowel_diacritic     = 11  #2
cl_consonant_diacritic = 7   #3

In [0]:
# Raw image geometry: each flattened pixel row is reshaped to (HEIGHT, WIDTH).
HEIGHT, WIDTH = 137, 236
# Side length of the square image produced by crop_resize and fed to the model.
SIZE = 128


def bbox(img):
    """Return the tight bounding box of the truthy pixels of a 2-D array.

    Returns:
        (rmin, rmax, cmin, cmax): first/last row index and first/last column
        index that contain at least one truthy element (both ends inclusive).

    Raises:
        IndexError: if `img` contains no truthy element at all.
    """
    occupied_rows = np.nonzero(np.any(img, axis=1))[0]
    occupied_cols = np.nonzero(np.any(img, axis=0))[0]
    top, bottom = occupied_rows[0], occupied_rows[-1]
    left, right = occupied_cols[0], occupied_cols[-1]
    return top, bottom, left, right

def crop_resize(img0, size=SIZE, pad=16):
    """Crop an image around its bright content, pad it to a square and resize it.

    Args:
        img0: 2-D image array; pixels above 80 are treated as content when
            locating the bounding box. The input array is NOT modified.
        size: side length of the returned square image.
        pad: extra pixels added to the longer side before squaring.

    Returns:
        The cropped, zero-padded image resized to (size, size) by cv2.resize.

    Raises:
        IndexError: propagated from bbox() when no pixel exceeds the threshold.
    """
    # Locate content while ignoring a 5-pixel border (some images contain lines at the sides).
    # NOTE(review): the returned coordinates are relative to the trimmed view but are
    # used below to slice the untrimmed image, so the box is shifted by 5 px. Left
    # unchanged on purpose so the preprocessing output stays the same.
    ymin, ymax, xmin, xmax = bbox(img0[5:-5, 5:-5] > 80)
    # Cropping may cut too much, so widen the box again, clamped to the image bounds.
    xmin = xmin - 13 if (xmin > 13) else 0
    ymin = ymin - 10 if (ymin > 10) else 0
    xmax = xmax + 13 if (xmax < WIDTH - 13) else WIDTH
    ymax = ymax + 10 if (ymax < HEIGHT - 10) else HEIGHT
    # Copy the crop: a plain slice is a view, and the in-place denoising below
    # would otherwise overwrite pixels in the caller's array.
    img = img0[ymin:ymax, xmin:xmax].copy()
    # Remove low-intensity pixels as noise.
    img[img < 28] = 0
    lx, ly = xmax - xmin, ymax - ymin
    l = max(lx, ly) + pad
    # Pad symmetrically to a (near-)square so the aspect ratio is kept when rescaling.
    img = np.pad(img, [((l - ly) // 2,), ((l - lx) // 2,)], mode='constant')
    return cv2.resize(img, (size, size))

In [0]:
class GraphemeDataset(Dataset):
    """Map-style dataset built from one parquet image table and a CSV label table.

    Each item is `(X, y)` where `X` is a float32 array of shape (1, SIZE, SIZE)
    scaled by 1/255 and `y` is an int64 array holding the three labels taken
    from label-table columns 1..3.
    """

    def __init__(self, tablefile, labelfile, transform = None):
        """Load and join images with labels.

        Args:
            tablefile: path to a parquet file with an `image_id` column followed
                by the flattened pixels of each image.
            labelfile: path to a CSV file with an `image_id` column and labels.
            transform: optional callable invoked as `transform(image=X)` that
                returns a mapping with the augmented image under key 'image'.
        """
        tablefile = pd.read_parquet(tablefile)
        labelfile = pd.read_csv(labelfile)

        # Default (inner) merge: only rows whose image_id occurs in both tables survive.
        datafile = labelfile.merge(tablefile, left_on='image_id', right_on='image_id', copy = False)

        # Pixel columns start at position 5 of the merged frame
        # (assumes the label table contributes exactly 5 columns -- TODO confirm).
        # Intensities are inverted (255 - value).
        self.X_data = 255 - datafile.iloc[:, 5:].values.reshape(-1, HEIGHT, WIDTH).astype(np.uint8)

        # Columns 1..3 hold the three class labels.
        self.y_data   = datafile.iloc[:, 1:4].to_numpy(dtype = 'uint8')

        self.transform = transform

    def __getitem__(self, idx):
        """Return the preprocessed image and its label vector for row `idx`."""
        X = self.X_data[idx]
        # Stretch intensities so the brightest pixel maps to 255.
        X = (X*(255.0/X.max())).astype(np.uint8)
        X = crop_resize(X)

        if self.transform:
            augmented = self.transform(image = X)
            X = augmented['image']

        X = (X / 255).astype("float32")
        # Explicit int64: the "long" alias is only 32 bits wide on some platforms,
        # while cross_entropy requires int64 class-index targets.
        y = self.y_data[idx].astype(np.int64)

        return X.reshape(1, SIZE, SIZE), y

    def __len__(self):
        """Number of samples (rows of the merged table)."""
        return len(self.y_data)

In [0]:
# Training-time augmentation pipeline: Cutout with its default settings, followed by
# an always-applied (p = 1) shift of up to 6.25% with no scaling and no rotation;
# border_mode = cv2.BORDER_CONSTANT selects constant-value border handling.
transform = albu.Compose([
    albu.Cutout(),
    albu.ShiftScaleRotate(shift_limit = 0.0625, scale_limit = 0, rotate_limit = 0, border_mode = cv2.BORDER_CONSTANT, p = 1)
])

In [0]:
# Build one dataset per parquet shard (train_image_data_0 .. train_image_data_3);
# every shard is joined against the same train.csv and shares the same augmentation.
_train_shards = [
    GraphemeDataset(tablefile = maindir + "train_image_data_%d.parquet" % shard,
                    labelfile = maindir + "train.csv",
                    transform = transform)
    for shard in range(4)
]
# Keep the per-shard names available for any cell that refers to them.
train0, train1, train2, train3 = _train_shards

# Single dataset spanning all four shards, consumed by the training DataLoader.
train = ConcatDataset(_train_shards)

In [0]:
# Training hyperparameters read by GraphemeModel and the Trainer cell.
learning_rate = 0.01   # initial Adam learning rate
batch_size = 200       # samples per training batch
epoch = 20             # number of training epochs

class GraphemeModel(pl.LightningModule):
    """ResNet-18 feature extractor with three linear classification heads.

    A 1x1 convolution lifts the single-channel input to 3 channels, ResNet-18
    produces a 1000-dimensional vector, and three independent linear layers map
    it to grapheme-root, vowel-diacritic and consonant-diacritic logits.
    """

    def __init__(self):
        super(GraphemeModel, self).__init__()

        # Backbone: 1 -> 3 channel adapter followed by ResNet-18 with a 1000-wide output.
        self.model  = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size = 1),
            models.resnet18(num_classes = 1000)
            )

        # One head per target, all fed by the same 1000-dimensional feature vector.
        self.output1 = nn.Linear(1000, cl_grapheme_root      )
        self.output2 = nn.Linear(1000, cl_vowel_diacritic    )
        self.output3 = nn.Linear(1000, cl_consonant_diacritic)

    def forward(self, X):
        """Return the logits `(out1, out2, out3)` of the three heads for batch `X`."""
        out = self.model(X)

        out1 = self.output1(out)
        out2 = self.output2(out)
        out3 = self.output3(out)

        return out1, out2, out3

    def training_step(self, batch, batch_nb):
        """Compute the summed cross-entropy loss of the three heads for one batch.

        `batch` is `(X, y)`; column k of `y` is the target for head k+1.
        Returns a dict with the loss under 'loss' and again under 'log'.
        """
        X, y    = batch
        pred1, pred2, pred3 = self.forward(X)

        loss1 = cross_entropy(pred1, y[:, 0])
        loss2 = cross_entropy(pred2, y[:, 1])
        loss3 = cross_entropy(pred3, y[:, 2])

        # Unweighted sum: every target contributes equally to the objective.
        loss  = loss1 + loss2 + loss3

        return {
            'loss'         : loss,
            'log'          : { 'loss' : loss}
        }

    def configure_optimizers(self):
        """Return the Adam optimizer and a MultiStepLR schedule (x0.1 at 4, 8, 12, 16)."""
        # Optimize ALL parameters of the module. Passing only self.model.parameters()
        # would leave output1/output2/output3 frozen at their random initial weights,
        # because those heads live outside self.model.
        optimizer = torch.optim.Adam(self.parameters(), lr = learning_rate)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4, 8, 12, 16], gamma = 0.1)
        return [optimizer], [scheduler]

    @pl.data_loader
    def train_dataloader(self):
        """Return a shuffling DataLoader over the module-level `train` dataset."""
        return DataLoader(train, batch_size = batch_size, shuffle = True)

In [0]:
# Instantiate the network and train it for `epoch` epochs with early stopping
# disabled; gpus = -1 is passed through to the Trainer unchanged.
model = GraphemeModel()
trainer = Trainer(
    gpus = -1,
    max_nb_epochs = epoch,
    early_stop_callback = False,
)
trainer.fit(model)

INFO:root:gpu available: True, used: True
INFO:root:VISIBLE GPUS: 0


INFO:root:
               Name               Type Params
0             model         Sequential   11 M
1           model.0             Conv2d    6  
2           model.1             ResNet   11 M
3     model.1.conv1             Conv2d    9 K
4       model.1.bn1        BatchNorm2d  128  
..              ...                ...    ...
68  model.1.avgpool  AdaptiveAvgPool2d    0  
69       model.1.fc             Linear  513 K
70          output1             Linear  168 K
71          output2             Linear   11 K
72          output3             Linear    7 K

[73 rows x 3 columns]


Epoch 10:  85%|████████▌ | 855/1005 [06:24<01:08,  2.20batch/s, batch_idx=853, gpu=0, loss=0.119, v_num=3]

In [0]:
# Persist the backbone and each classification head as separate state-dict files.
for _module, _filename in (
    (model.model,   "model.pt"),
    (model.output1, "cl_grapheme_root.pt"),
    (model.output2, "cl_vowel_diacritic.pt"),
    (model.output3, "cl_consonant_diacritic.pt"),
):
    torch.save(_module.state_dict(), _filename)