In [None]:
# -nc: skip the download when the archive is already present, so re-running
# this cell does not create synthetic-data.zip.1, synthetic-data.zip.2, ...
!wget -nc https://www.dropbox.com/s/l2ul3upj7dkv4ou/synthetic-data.zip
# -n: never overwrite files that were already extracted, so a re-run does not
# stop at unzip's interactive "replace?" prompt.
!unzip -n -qq synthetic-data.zip

# Each image is located in the folder synthetic-data.
# The image file name (annotation) encodes the word that is written on the image.
# Example: the word "American" is written on /content/synthetic-data/American@3WPOqS.png

In [2]:
!pip install livelossplot --quiet
from livelossplot import PlotLosses

import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import os
from glob import glob
import cv2
import random
import numpy as np
from sklearn.model_selection import train_test_split

# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

[?25l[K     |▏                               | 10 kB 30.3 MB/s eta 0:00:01[K     |▍                               | 20 kB 36.4 MB/s eta 0:00:01[K     |▋                               | 30 kB 43.8 MB/s eta 0:00:01[K     |▉                               | 40 kB 48.6 MB/s eta 0:00:01[K     |█                               | 51 kB 52.8 MB/s eta 0:00:01[K     |█▎                              | 61 kB 57.4 MB/s eta 0:00:01[K     |█▌                              | 71 kB 50.0 MB/s eta 0:00:01[K     |█▊                              | 81 kB 51.0 MB/s eta 0:00:01[K     |█▉                              | 92 kB 52.1 MB/s eta 0:00:01[K     |██                              | 102 kB 54.0 MB/s eta 0:00:01[K     |██▎                             | 112 kB 54.0 MB/s eta 0:00:01[K     |██▌                             | 122 kB 54.0 MB/s eta 0:00:01[K     |██▊                             | 133 kB 54.0 MB/s eta 0:00:01[K     |███                             | 143 kB 54.0 MB/s eta 0:

In [3]:
#@title Levenshtein Distance


## got help from here: https://blog.paperspace.com/implementing-levenshtein-distance-word-autocomplete-autocorrect/#:~:text=The%20Levenshtein%20distance%20is%20a,transform%20one%20word%20into%20another.


# def printDistances(distances, token1Length, token2Length):
#     for t1 in range(token1Length + 1):
#         for t2 in range(token2Length + 1):
#             print(int(distances[t1][t2]), end=" ")
#         print()

def levenshteinDistanceDP(token1, token2):
    """Return the Levenshtein (edit) distance between ``token1`` and ``token2``.

    The distance is the minimum number of single-element insertions,
    deletions and substitutions that turn one sequence into the other. It is
    computed with the classic dynamic-programming table, where cell
    ``(i, j)`` holds the distance between the first ``i`` elements of
    ``token1`` and the first ``j`` elements of ``token2``.

    Args:
        token1: first sequence (e.g. a string).
        token2: second sequence.

    Returns:
        The distance as a NumPy float, because the table is a float array.
    """
    n_rows, n_cols = len(token1) + 1, len(token2) + 1
    table = np.zeros((n_rows, n_cols))

    # Base cases: reaching / leaving the empty prefix costs the prefix length.
    table[:, 0] = np.arange(n_rows)
    table[0, :] = np.arange(n_cols)

    for i, left in enumerate(token1, start=1):
        for j, right in enumerate(token2, start=1):
            if left == right:
                # Matching elements cost nothing; inherit the diagonal value.
                table[i, j] = table[i - 1, j - 1]
                continue
            # Cheapest of the three neighbouring cells (left, above,
            # diagonal) plus one edit.
            table[i, j] = 1 + min(
                table[i, j - 1],
                table[i - 1, j],
                table[i - 1, j - 1],
            )

    return table[n_rows - 1, n_cols - 1]

## test
#levenshteinDistanceDP('ladooo', 'laaado') 

In [4]:
#@title data handling


# Sort the glob result: directory listing order is arbitrary, and the seeded
# train/validation split further down is only reproducible when the list it
# receives is always in the same order.
images = sorted(glob('./synthetic-data/*.png'))


def imagelabels(fname):
    """Return the ground-truth word encoded in an image file name.

    The label is the part of the last ``/``-separated path component that
    comes before the first ``@``; e.g. ``.../American@3WPOqS.png`` gives
    ``'American'``.
    """
    return fname.split("/")[-1].split('@')[0]


# Characters the recognizer can emit (upper- and lower-case Latin letters).
vocab = 'QWERTYUIOPASDFGHJKLZXCVBNMqwertyuiopasdfghjklzxcvbnm'
# B: batch size, T: default number of time steps per sample, V: vocabulary size.
B,T,V = 64, 32, len(vocab)
# H, W: height and width of the preprocessed image canvas.
H,W = 32, 128

class OCRDataset(Dataset):
    """Word-image dataset whose labels are parsed from the image file names.

    Besides indexing, the class bundles the batch collation, label encoding,
    greedy CTC decoding and error-rate metrics used by the training code.

    Args:
        items: sequence of image file paths.
        vocab: string of characters the model can emit; character ``i`` is
            given index ``i + 1`` because index 0 is reserved for the blank.
        preprocess_shape: ``(height, width)`` of the canvas produced by
            ``preprocess`` when no explicit shape is passed to it.
        timesteps: length every label vector is padded to; also reported as
            the input length of every sample by ``collate_fn``.
    """
    def __init__(self, items, vocab=vocab, preprocess_shape=(H,W), timesteps=T):
        super().__init__()
        self.items = items
        # index -> character; index 0 is the blank, shown as a backtick.
        self.charList = {ix+1:ch for ix,ch in enumerate(vocab)}
        self.charList.update({0: '`'})
        # character -> index (inverse of charList).
        self.invCharList = {v:k for k,v in self.charList.items()}
        self.ts = timesteps
        # Previously accepted but ignored; now it is the default canvas size.
        self.preprocess_shape = preprocess_shape
    def __len__(self):
        return len(self.items)
    def sample(self):
        """Return a uniformly random ``(image, label)`` pair."""
        # random.randint includes BOTH end points, so the upper bound must be
        # len - 1; randint(0, len) could index one past the end.
        return self[random.randint(0, len(self) - 1)]
    def __getitem__(self, ix):
        """Return ``(grayscale image array, label string)`` for item ``ix``.

        The image is whatever ``cv2.imread`` returns (``None`` when the file
        cannot be read; ``preprocess`` tolerates that).
        """
        item = self.items[ix]
        image = cv2.imread(item, 0)
        label = imagelabels(item)
        return image, label
    def collate_fn(self, batch):
        """Turn a list of ``(image, label)`` pairs into training tensors.

        Returns:
            ``(images, label_vectors, label_lengths, input_lengths, labels)``
            where the first four are tensors on ``device`` and ``labels`` is
            the list of raw label strings.
        """
        images, labels, label_lengths, label_vectors, input_lengths = [], [], [], [], []
        for image, label in batch:
            images.append(torch.Tensor(self.preprocess(image))[None,None])
            encoded = self.str2vec(label, pad=False)
            # The target length must count the ENCODED characters: str2vec
            # drops characters outside the vocabulary, so len(label) would
            # overstate the length and make CTC read padding as targets.
            label_lengths.append(len(encoded))
            labels.append(label)
            # Pad with the blank index up to self.ts (same as str2vec(pad=True)).
            label_vectors.append(encoded + [0] * (self.ts - len(encoded)))
            # Every sample is reported with self.ts input steps.
            # NOTE(review): assumes the model emits exactly self.ts time steps.
            input_lengths.append(self.ts)
        images = torch.cat(images).float().to(device)
        label_lengths = torch.Tensor(label_lengths).long().to(device)
        label_vectors = torch.Tensor(label_vectors).long().to(device)
        input_lengths = torch.Tensor(input_lengths).long().to(device)
        return images, label_vectors, label_lengths, input_lengths, labels
    def str2vec(self, string, pad=True):
        """Encode ``string`` as a list of character indices.

        Characters that are not in ``invCharList`` are dropped. With
        ``pad=True`` the list is right-padded with 0 up to ``self.ts``
        entries (longer encodings are returned unpadded and untruncated).
        """
        string = ''.join([s for s in string if s in self.invCharList])
        val = list(map(lambda x: self.invCharList[x], string))
        if pad:
            while len(val) < self.ts:
                val.append(0)
        return val
    def preprocess(self, img, shape=None):
        """Fit ``img`` into a fixed-size canvas and scale it to ``[0, 1]``.

        The image is resized with its aspect ratio preserved, pasted into the
        top-left corner of a white (255) canvas, and the result is inverted
        so the white padding becomes 0.

        Args:
            img: 2-D grayscale array, or ``None``.
            shape: ``(height, width)`` of the canvas; defaults to the
                ``preprocess_shape`` given to the constructor.

        Returns:
            Float array of the requested shape. If the image cannot be
            processed, the canvas stays blank (all zeros after inversion).
        """
        if shape is None:
            shape = self.preprocess_shape
        target = np.ones(shape)*255
        try:
            out_h, out_w = shape
            h, w = img.shape
            # One scale factor for both axes keeps the aspect ratio.
            f = min(out_h/h, out_w/w)
            _h = int(h*f)
            _w = int(w*f)
            # cv2.resize takes the target size as (width, height).
            target[:_h,:_w] = cv2.resize(img, (_w,_h))
        except Exception:
            # Deliberate best effort (unreadable or oddly shaped image): keep
            # the blank canvas. Unlike a bare ``except``, this no longer
            # swallows KeyboardInterrupt / SystemExit.
            pass
        return (255-target)/255
    def decoder_chars(self, pred):
        """Greedy CTC decoding of one sample's per-step class scores.

        Takes the arg-max class of every row of ``pred``, collapses
        consecutive repeats of the same character and removes blanks
        (index 0). A blank between two equal characters keeps both.
        """
        best_path = pred.cpu().detach().numpy().argmax(axis=-1)
        decoded = ""
        last = ""
        for k in best_path:
            if k == 0:
                # Blank: reset so the same character may be emitted again.
                last = ""
                continue
            char = self.charList[int(k)]
            if char != last:
                decoded += char
                last = char
        return decoded.replace(" "," ")
    def wer(self, preds, labels):
        """Fraction of predictions that differ from their label.

        The comparison ignores case and surrounding whitespace. Returns 0.0
        for an empty batch.
        """
        if len(preds) == 0:
            return 0.0
        c = 0
        for p, l in zip(preds, labels):
            c += p.lower().strip() != l.lower().strip()
        return round(c/len(preds), 4)
    def cer(self, preds, labels):
        """Mean edit distance per label character (case-sensitive).

        Returns 0.0 for an empty batch. An empty label is treated as having
        length 1 so the ratio stays finite.
        """
        ratios = [levenshteinDistanceDP(p, l) / max(len(l), 1)
                  for p, l in zip(preds, labels)]
        if not ratios:
            return 0.0
        return round(np.mean(ratios), 4)
    def evaluate(self, model, ims, labels, lower=False):
        """Decode ``model``'s predictions for ``ims`` and score them.

        The model is switched to eval mode (and left there); the forward pass
        runs without gradient tracking. ``lower`` is accepted for
        compatibility but is not used.

        Returns:
            Dict with ``'char-error-rate'``, ``'word-error-rate'``,
            ``'char-accuracy'`` and ``'word-accuracy'``.
        """
        model.eval()
        with torch.no_grad():
            preds = model(ims).permute(1,0,2) # B, T, V+1
        preds = [self.decoder_chars(pred) for pred in preds]
        # Each metric is computed once and reused for its complement.
        char_err = self.cer(preds, labels)
        word_err = self.wer(preds, labels)
        return {'char-error-rate': char_err,
                'word-error-rate': word_err,
                'char-accuracy' : 1 - char_err,
                'word-accuracy' : 1 - word_err}


# Hold out 20% of the image paths for validation; the fixed random_state
# makes the split repeatable for a given ordering of `images`.
trn_items, val_items = train_test_split(
    images,
    test_size=0.2,
    random_state=22,
)
trn_ds = OCRDataset(items=trn_items)
val_ds = OCRDataset(items=val_items)

# Both loaders drop the last incomplete batch; only the training loader shuffles.
trn_dl = DataLoader(
    dataset=trn_ds,
    batch_size=B,
    shuffle=True,
    drop_last=True,
    collate_fn=trn_ds.collate_fn,
)
val_dl = DataLoader(
    dataset=val_ds,
    batch_size=B,
    drop_last=True,
    collate_fn=val_ds.collate_fn,
)

In [5]:
#@title model

class BasicBlock(nn.Module):
    """Convolutional feature block: Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d -> Dropout2d.

    Args:
        ni: number of input channels.
        no: number of output channels.
        ks: convolution kernel size.
        st: convolution stride.
        padding: convolution padding.
        pool: kernel size handed to ``nn.MaxPool2d`` (int or pair).
        drop: probability handed to ``nn.Dropout2d``.
    """
    def __init__(self, ni, no, ks=3, st=1, padding=1, pool=2, drop=0.2):
        super().__init__()
        self.ks = ks
        # Layers are listed in execution order; their position is also their
        # index in the ``block.<n>`` state-dict keys.
        layers = [
            nn.Conv2d(in_channels=ni, out_channels=no, kernel_size=ks, stride=st, padding=padding),
            nn.BatchNorm2d(num_features=no, momentum=0.3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=pool),
            nn.Dropout2d(p=drop),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the five layers in order."""
        features = self.block(x)
        return features


class Ocr(nn.Module):
    """CNN + bidirectional LSTM recognizer that emits per-step log-probabilities.

    Args:
        vocab: number of characters in the vocabulary (an int, not the
            character string). The classifier has ``vocab + 1`` outputs; the
            training code in this file uses index 0 as the CTC blank.
    """
    def __init__(self, vocab):
        super().__init__()
        # Three conv blocks; the pooling sizes shrink height by 2*2*4 = 16
        # and width by 2*2*2 = 8.
        self.model = nn.Sequential(
            BasicBlock( 1, 128),
            BasicBlock(128, 128),
            BasicBlock(128, 256, pool=(4,2)),
        )
        # Kept inside nn.Sequential so the state-dict keys stay ``rnn.0.*``.
        self.rnn = nn.Sequential(
            nn.LSTM(256, 256, num_layers=2, dropout=0.2, bidirectional=True),
        )
        # 512 = 2 directions x 256 hidden units.
        self.classification = nn.Sequential(
            nn.Linear(512, vocab+1),
            nn.LogSoftmax(-1),
        )
    def forward(self, x):
        """Map an image batch to log-probabilities of shape ``(T, B, vocab + 1)``.

        ``x`` is a ``(B, 1, height, width)`` batch. ``T`` is the number of
        spatial positions left after the conv blocks (32 for a 32x128 input).
        """
        x = self.model(x)                       # (B, 256, H', W')
        # Flatten the spatial grid into one time axis, taking the batch and
        # channel sizes from the tensor itself. The former hard-coded
        # ``view(-1, 256, 32)`` gave the same result for 32x128 inputs but
        # folded time steps into the batch axis (or failed) for other sizes.
        x = x.view(x.size(0), x.size(1), -1)    # (B, 256, T)
        x = x.permute(2, 0, 1)                  # (T, B, 256), time-major for the LSTM
        x, lstm_states = self.rnn(x)
        y = self.classification(x)
        return y

In [None]:
#@title training

def ctc(log_probs, target, input_lengths, target_lengths, blank=0):
    """Mean-reduced CTC loss with infinite per-sample losses zeroed out.

    Args:
        log_probs: log-probabilities produced by the model.
        target: padded target index vectors.
        input_lengths: number of valid time steps per sample.
        target_lengths: number of valid target entries per sample.
        blank: class index used as the CTC blank.

    Returns:
        Scalar loss tensor.
    """
    return nn.functional.ctc_loss(
        log_probs,
        target,
        input_lengths,
        target_lengths,
        blank=blank,
        reduction='mean',      # the nn.CTCLoss default
        zero_infinity=True,    # infeasible alignments contribute 0, not inf
    )

def train_batch(data, model, optimizer, criterion):
    """Run one optimization step on a batch and score the updated model on it.

    Args:
        data: ``(imgs, targets, label_lens, input_lens, labels)`` as produced
            by the dataset's ``collate_fn``.
        model: network to train.
        optimizer: optimizer holding ``model``'s parameters.
        criterion: callable ``(preds, targets, input_lens, label_lens) -> loss``.

    Returns:
        ``(loss, results)``: the detached batch loss tensor and the metrics
        dict from ``trn_ds.evaluate``. The metrics are measured AFTER the
        optimizer step, and ``evaluate`` leaves the model in eval mode (the
        next call switches back to train mode).
    """
    model.train()
    imgs, targets, label_lens, input_lens, labels = data
    optimizer.zero_grad()
    preds = model(imgs)
    loss = criterion(preds, targets, input_lens, label_lens)
    loss.backward()
    optimizer.step()
    # Metrics only: run the extra forward pass without building an autograd
    # graph that nobody will backpropagate through.
    with torch.no_grad():
        results = trn_ds.evaluate(model, imgs.to(device), labels)
    # Detach so the returned loss does not keep this batch's graph alive.
    return loss.detach(), results

@torch.no_grad()
def validate_batch(data, model):
    """Compute loss and metrics for one validation batch without gradients.

    Uses the module-level ``criterion`` for the loss and ``val_ds.evaluate``
    for the metrics.

    Args:
        data: ``(imgs, targets, label_lens, input_lens, labels)`` batch tuple.
        model: network to evaluate; it is put into eval mode.

    Returns:
        ``(loss, metrics)``: the batch loss tensor and the metrics dict.
    """
    model.eval()
    image_batch, target_vectors, target_lengths, input_lengths, texts = data
    log_probs = model(image_batch)
    batch_loss = criterion(log_probs, target_vectors, input_lengths, target_lengths)
    metrics = val_ds.evaluate(model, image_batch.to(device), texts)
    return batch_loss, metrics

###-------------------------------- define -----------------------------------
# Ocr takes the vocabulary SIZE; it adds the extra (blank) output class itself.
model = Ocr(len(vocab)).to(device)
# Loss passed to train_batch / train_model; validate_batch reads this
# module-level name directly instead of taking it as an argument.
criterion = ctc
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
# Default for train_model's `nepochs`; the train_model call in this cell
# overrides it with nepochs=10.
n_epochs = 50

###----------------------------------------------------------------------------


def _validation_pass(model, val_dl):
    """Average validate_batch results over ``val_dl``.

    Returns ``(mean loss, mean char accuracy, mean word accuracy)`` as floats,
    or ``None`` when the loader yields no batches.
    """
    batch_losses, char_accs, word_accs = [], [], []
    for batch in val_dl:
        val_loss, val_results = validate_batch(batch, model)
        batch_losses.append(val_loss)
        char_accs.append(val_results['char-accuracy'])
        word_accs.append(val_results['word-accuracy'])
    if not batch_losses:
        return None
    return (torch.stack(batch_losses).mean().item(),
            float(np.mean(char_accs)),
            float(np.mean(word_accs)))


def train_model(model, trn_dl, val_dl, optimizer, criterion, nepochs=n_epochs, eval_every=len(trn_dl), out_dir='./'):
    """Train ``model``, plot the losses live and checkpoint the best weights.

    Validation runs on the very first iteration and then every ``eval_every``
    iterations. After each validation pass the model is scored with
    ``val_loss + (1 - char_accuracy) + (1 - word_accuracy)`` (lower is
    better); whenever the score improves, the state dict is written to
    ``<out_dir>/model_best.pt``. The previous score ADDED the accuracies to
    the loss, which penalized more accurate models.

    Args:
        model: network to train.
        trn_dl: training data loader.
        val_dl: validation data loader.
        optimizer: optimizer holding ``model``'s parameters.
        criterion: loss callable forwarded to ``train_batch``.
        nepochs: number of passes over ``trn_dl``.
        eval_every: validation period in iterations. Note that the default is
            evaluated once, when the function is defined, from the
            module-level ``trn_dl`` -- not from the ``trn_dl`` argument.
        out_dir: directory that receives ``model_best.pt``.

    Returns:
        None. Progress is reported through the livelossplot chart.
    """
    liveloss = PlotLosses()   # to plot training progress

    best = float('inf')       # best (lowest) validation score so far
    average_val_loss = None   # latest validation loss; None until the first validation pass
    it = 0
    for epoch in range(nepochs):
        epoch_losses = []     # training losses of this epoch only

        for data in trn_dl:
            it += 1

            loss, _ = train_batch(data, model, optimizer, criterion)
            epoch_losses.append(loss.detach().cpu().item())

            if (it == 1) or (it % eval_every == 0):
                val_stats = _validation_pass(model, val_dl)
                if val_stats is None:
                    # Empty validation loader: nothing to score or checkpoint.
                    continue
                average_val_loss, val_char_acc, val_word_acc = val_stats

                # Turn the accuracies into error rates so that every term of
                # the score gets smaller as the model gets better.
                score = average_val_loss + (1 - val_char_acc) + (1 - val_word_acc)
                if score < best:     # keep track of best model
                    best = score
                    torch.save(model.state_dict(), os.path.join(out_dir, 'model_best.pt'))  # save best model

        # update liveplot with this epoch's mean training loss and the most
        # recent validation loss (each only when it exists)
        logs = {}
        if epoch_losses:
            logs['loss'] = float(np.mean(epoch_losses))
        if average_val_loss is not None:
            logs['val_loss'] = average_val_loss

        liveloss.update(logs)
        liveloss.send()

# Train for 10 epochs (overriding the n_epochs default); eval_every=len(trn_dl)
# means validation runs on the first iteration and then once per epoch's worth of batches.
train_model(model, trn_dl, val_dl, optimizer, criterion, nepochs=10, eval_every=len(trn_dl), out_dir='./')


In [None]:
## torch.save(model.state_dict(), './saved_model.pth') --> you do not need this because you already saved model versions.
# Rebuild the network and restore the best checkpoint written during training.
model = Ocr(len(vocab)).to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU runtime still loads on a CPU-only one.
model.load_state_dict(torch.load('./model_best.pt', map_location=device))
# Inference mode: disables dropout and uses the BatchNorm running statistics.
model.eval()

In [32]:
## if you want to take generally any random image
# for jx in range(5):
#     img, label = val_ds.sample()
#     _img = torch.Tensor(val_ds.preprocess(img)[None,None]).to(device)
#     pred = model(_img)[:,0,:]
#     pred = trn_ds.decoder_chars(pred)
#     print(f'Pred: `{pred}` :: Truth: `{label}`')

def predict(item):
    """Print the model's prediction for one image next to its true label.

    Uses the module-level ``model``, ``val_ds`` (preprocessing) and
    ``trn_ds`` (decoding). The true label is parsed from the file name.

    Args:
        item: path to an image file.

    Raises:
        OSError: if the image cannot be read. Previously an unreadable path
            was silently turned into a blank canvas and a meaningless
            prediction was printed.
    """
    img = cv2.imread(item, 0)   # flag 0: load as a single-channel grayscale image
    if img is None:
        raise OSError(f'could not read image: {item}')
    label = imagelabels(item)
    _img = torch.Tensor(val_ds.preprocess(img)[None,None]).to(device)
    # Make sure dropout / batch norm are in inference mode, and skip the
    # autograd bookkeeping that a pure prediction does not need.
    model.eval()
    with torch.no_grad():
        pred = model(_img)[:,0,:]   # scores of the single sample in the batch
    pred = trn_ds.decoder_chars(pred)
    print(f'Pred: `{pred}` :: Truth: `{label}`')

# Spot check on a single image; the true label ("American") comes from the file name.
predict('./synthetic-data/American@3WPOqS.png')

Pred: `american` :: Truth: `American`
