In [1]:
from google.colab import drive
drive.mount("/content/drive")
# avishka
# %cd /content/drive/MyDrive/Work/AHEAD-AWS/ocr
# chamod
%cd /content/drive/MyDrive/malli/ocr

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/1A6FlH4cDIhWx-1NwKQRuwS6TBZNx_vb_/ocr


In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from time import time as getTime
from datetime import datetime

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import one_hot

import torch
import torchvision
import torchvision.transforms as transforms
from tqdm.auto import tqdm

import cv2 as cv

In [3]:
!pip install torchinfo
from torchinfo import summary

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.7.2-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.2


In [19]:
# @title Dataset
class NumberDataset(Dataset):
  def __init__(self, data_path):
    label_names = os.listdir(data_path)
    self.label_names = label_names
    self.data_path = data_path

    limit_map, img_lsts, n = [], [], 0
    for label in label_names:
      limit_map.append(n)
      img_lst = os.listdir(f"{data_path}/{label}")
      n += len(img_lst)
      img_lst = tuple(map(lambda img_name: f"{data_path}/{label}/{img_name}", img_lst))
      img_lsts.append(img_lst)

    self.limit_map = np.array(limit_map)
    self.img_lsts = img_lsts
    self.len = n

  def __len__(self):
    return self.len

  def __getitem__(self, idx):
    less_than_limits = self.limit_map <= idx
    lbl = less_than_limits.sum()-1
    img_id = idx - self.limit_map[lbl]
    img_path = self.img_lsts[lbl][img_id]

    img = cv.imread(img_path, cv.IMREAD_GRAYSCALE)
    _,img = cv.threshold(img,127,255,cv.THRESH_BINARY)
    img = (img/255).astype(int)[np.newaxis, :]
    img = torch.Tensor(img)

    return img, lbl

train_ds = NumberDataset("./data/alpha_num_char/train")
val_ds = NumberDataset("./data/alpha_num_char/val")
test_ds = NumberDataset("./data/alpha_num_char/test")
micro_ds = NumberDataset("./data/alpha_num_char/micro") # used for testing the code

batch_size = 16
train_dl = DataLoader(train_ds, batch_size, True)
val_dl = DataLoader(val_ds, batch_size, False)
test_dl = DataLoader(test_ds, batch_size, False)
# micro_dl = DataLoader(micro_ds, batch_size, True)

In [5]:
len(train_ds)/36, len(val_ds)/36, len(test_ds)/36

(700.0, 199.44444444444446, 100.0)

In [None]:
# @title load and cache the data into the colab VM
train_dl_cache = DataLoader(train_ds, 1024, True)
val_dl_cache = DataLoader(val_ds, 1024, False)
test_dl_cache = DataLoader(test_ds, 1024, False)
micro_dl_cache = DataLoader(micro_ds, 1024, True)

cache_dls = [train_dl_cache, test_dl_cache]
for dl in cache_dls:
  for batch in dl:
    print(batch[0].shape)

In [7]:
del train_dl_cache
del val_dl_cache
del test_dl_cache
del micro_dl_cache 
del cache_dls

In [8]:
for data in train_dl:
  print(type(data), len(data), type(data[0]), data[0].shape, data[1].shape, data[0].dtype, data[1].dtype)
  break

<class 'list'> 2 <class 'torch.Tensor'> torch.Size([4, 1, 28, 28]) torch.Size([4]) torch.float32 torch.int64


In [9]:
# @title 1C28S -> 36L Net
class NetLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, len(train_ds.label_names))

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device}")

netLP = NetLP().to(device)

Using device cuda


In [10]:
dataiter = iter(train_dl)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)
outputs = netLP(images)
print(images.shape, labels.shape, labels, outputs.shape)

torch.Size([4, 1, 28, 28]) torch.Size([4]) tensor([ 4, 13, 32, 32], device='cuda:0') torch.Size([4, 36])


In [12]:
# @title Training Functions
from typing import Callable, Tuple, Union, Any
from torch.utils.data import DataLoader, Dataset
from decimal import Decimal

def trainLoop(
        training_dataloader: DataLoader,
        model: torch.nn.Module,
        criterion: torch.nn,
        optimizer: Any,
        device: str,
        pbar
    ) -> float:
    train_loss = None
    for i, data in enumerate(training_dataloader, 0):
      # get the inputs; data is a list of [inputs, labels]
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()

      outputs = model(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      train_loss = loss.item()
      
      pbar.postfix = [f'{Decimal(str(train_loss)):.4e}', 'None']
      pbar.update(1)

    return train_loss

def valLoop(
        val_dataloader: DataLoader,
        model: torch.nn.Module,
        criterion: torch.nn,
        train_loss: float,
        device:str,
        pbar
    ) -> float:
    model.eval()
    running_val_loss = None
    with torch.no_grad():
        losses = []
        for i, data in enumerate(val_dataloader, 0):
          # get the inputs; data is a list of [inputs, labels]
          inputs, labels = data
          inputs, labels = inputs.to(device), labels.to(device)

          outputs = model(inputs)
          val_loss = criterion(outputs, labels).item()
          losses.append(val_loss)
          running_val_loss = sum(losses)/len(losses)
          
          pbar.postfix = [f'{Decimal(str(train_loss)):.4e}', f'{Decimal(str(running_val_loss)):.4e}']
          pbar.update(1)

    return running_val_loss

def testLoop(dataloader: DataLoader, model:torch.nn.Module, val_infer_fn:Callable, best_model_accuracy:list, device:str, pbar):
    prd_res = {"prd":[], "lbl":[]}
    accuracy = 0
    with torch.inference_mode():
        num_correct = num_samples = 0
        for batch in dataloader:
            prd, lbl = val_infer_fn(batch, model, device)
            prd_flat = prd.reshape(-1)
            lbl_flat = lbl.reshape(-1)
            prd_res["prd"].extend(prd_flat.tolist())
            prd_res["lbl"].extend(lbl_flat.tolist())
            num_correct += (prd.argmax(axis=1)==lbl.argmax(axis=1)).sum()
            num_samples += lbl.shape[0]
            
            accuracy = float(num_correct)/float(num_samples)
            pbar.update(1)
            pbar.postfix = [accuracy,"None"] if best_model_accuracy==None else [best_model_accuracy, accuracy]
    prd_res["prd"] = np.array(prd_res["prd"])
    prd_res["lbl"] = np.array(prd_res["lbl"])

    return accuracy, prd_res

In [13]:
weights_path = "./model-weights"
if not os.path.exists(weights_path): os.makedirs(weights_path)
MODEL_PATH = f'{weights_path}/netLP-weights.pth'
OPTIM_PATH = f'{weights_path}/netLP-optim-weights.pth'
HSTRY_PATH = f'{weights_path}/history.csv'

In [21]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(netLP.parameters())
epochs = 50
best_model_weights = None
best_optim_weights = None
best_val_loss = np.inf

train_loss, val_loss, batch_rate, item_rate = None, None, None, None
tot_batches = len(train_dl)+len(val_dl)
avg_batch_size = (train_dl.batch_size+val_dl.batch_size)/2
# history of training
history = {"train_loss":[], "val_loss": [], "cumulative_time": [], "batch_processing_rate": [], "item_processing_rate": []}
start_time = getTime()
best_epoch = None

for epoch in range(epochs):  # loop over the dataset multiple times
    print(f"---- Epoch {epoch+1} ----")

    with tqdm(
      total=tot_batches,
      bar_format='{desc}: {percentage:3.0f}%|{bar}| Batch: {n_fmt}/{total_fmt} | Time: [{elapsed}<{remaining}, ' '{rate_fmt}] | Train loss: {postfix[0]} | Validation loss: {postfix[1]}',
      postfix=[train_loss, val_loss]
      ) as pbar:
      train_loss = trainLoop(
        train_dl,
        netLP,
        criterion,
        optimizer,
        device,
        pbar
      )
      val_loss = valLoop(
        test_dl,
        netLP,
        criterion,
        train_loss,
        device,
        pbar
      )

      if val_loss < best_val_loss:
        best_epoch = epoch
        best_val_loss = val_loss
        best_model_weights = netLP.state_dict()
        best_optim_weights = optimizer.state_dict()
        torch.save(best_model_weights, MODEL_PATH)
        torch.save(best_optim_weights, OPTIM_PATH)

      now = datetime.now().timestamp()
      batch_rate = tot_batches/(now - pbar.start_t)
      item_rate = tot_batches*avg_batch_size/(now - pbar.start_t)

    # update the history
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["cumulative_time"].append(getTime() - start_time)
    history["batch_processing_rate"].append(batch_rate)
    history["item_processing_rate"].append(item_rate)

print('Finished Training')

---- Epoch 1 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: None | Validation loss: None

---- Epoch 2 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.05483882874250412 | Validation loss: 0…

---- Epoch 3 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.04130900651216507 | Validation loss: 0…

---- Epoch 4 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.06287972629070282 | Validation loss: 0…

---- Epoch 5 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.009906220249831676 | Validation loss: …

---- Epoch 6 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 5.4623269534204155e-05 | Validation loss…

---- Epoch 7 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0012678737984970212 | Validation loss:…

---- Epoch 8 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.00010940073116216809 | Validation loss…

---- Epoch 9 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 2.345277607673779e-05 | Validation loss:…

---- Epoch 10 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 9.76023784460267e-07 | Validation loss: …

---- Epoch 11 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.00011366252874722704 | Validation loss…

---- Epoch 12 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.00015194002480711788 | Validation loss…

---- Epoch 13 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.06254260241985321 | Validation loss: 0…

---- Epoch 14 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0009554373100399971 | Validation loss:…

---- Epoch 15 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0010574148036539555 | Validation loss:…

---- Epoch 16 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 1.3060365745332092e-05 | Validation loss…

---- Epoch 17 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 2.4512034997314913e-06 | Validation loss…

---- Epoch 18 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 1.572059773025103e-06 | Validation loss:…

---- Epoch 19 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 2.354366188228596e-06 | Validation loss:…

---- Epoch 20 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 2.972747552121291e-06 | Validation loss:…

---- Epoch 21 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 3.747532446141122e-06 | Validation loss:…

---- Epoch 22 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 1.2947795767104253e-05 | Validation loss…

---- Epoch 23 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 1.7881382063933415e-07 | Validation loss…

---- Epoch 24 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 6.705519695060502e-08 | Validation loss:…

---- Epoch 25 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 7.450580152834618e-09 | Validation loss:…

---- Epoch 26 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 1.987613374052128…

---- Epoch 27 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 2.9802251333421736e-07 | Validation loss…

---- Epoch 28 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 5.960463056453591e-08 | Validation loss:…

---- Epoch 29 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 3.180127854187889…

---- Epoch 30 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 1.257611807048419…

---- Epoch 31 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 9.979997427380896…

---- Epoch 32 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 9.085932850341452…

---- Epoch 33 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 3.172259183238503…

---- Epoch 34 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 0.003150417439633…

---- Epoch 35 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 8.41891232994385e-06 | Validation loss: …

---- Epoch 36 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 9.015137720780331e-07 | Validation loss:…

---- Epoch 37 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 0.001560915197236…

---- Epoch 38 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 0.001425492412088…

---- Epoch 39 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 0.001104822616791…

---- Epoch 40 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 0.001327274161841…

---- Epoch 41 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 8.195635814445268e-08 | Validation loss:…

---- Epoch 42 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 0.000963162877337…

---- Epoch 43 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 0.000852033844720…

---- Epoch 44 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 0.000168361154598…

---- Epoch 45 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 7.450580152834618e-09 | Validation loss:…

---- Epoch 46 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.021662140265107155 | Validation loss: …

---- Epoch 47 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 2.2351738238057806e-08 | Validation loss…

---- Epoch 48 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 8.430576883669734…

---- Epoch 49 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 2.2351738238057806e-08 | Validation loss…

---- Epoch 50 ----


  0%|          | Batch: 0/2024 | Time: [00:00<?, ?it/s] | Train loss: 0.0 | Validation loss: 3.872235737524142…

Finished Training


In [None]:
summary(netLP)

In [None]:
history_df = pd.DataFrame(history)
history_df.to_csv(HSTRY_PATH, index=False)

In [None]:
FINAL_MODEL_PATH = f'{weights_path}/netLP-final-weights.pth'
FINAL_OPTIM_PATH = f'{weights_path}/netLP-final-optim-weights.pth'

torch.save(netLP.state_dict(), FINAL_MODEL_PATH)
torch.save(optimizer.state_dict(), FINAL_OPTIM_PATH)