# Image text recognition (cropped word images, MJSynth)

## Importing all libraries

In [None]:
import numpy as np
import torch
from torch import nn
import torchvision
import matplotlib.pyplot as plt
%matplotlib inline

import os
import glob
from torch.utils.data import Dataset
from scipy import signal
from scipy.io import wavfile
import cv2
from PIL import Image

from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CTCLoss

from collections import OrderedDict



### Showing a few inputs

In [None]:
# Preview the first four training images from one MJSynth sub-directory.
# Only the images that are actually displayed are opened, and each file is
# closed as soon as its pixels are loaded (Image.open is lazy and would
# otherwise keep one file descriptor open per image in the directory).
image_list = []
for filename in sorted(glob.glob('/mjsynth/90kDICT32px/1/1/*.jpg'))[:4]:
    with Image.open(filename) as im:
        image_list.append(im.copy())  # copy() loads the pixels before the file closes
# Plot whatever was found (up to 4) instead of failing when fewer exist.
for i, im in enumerate(image_list):
    plt.subplot(2, 2, i + 1)
    plt.imshow(im)

# Dataset: MJSynth loading for training, validation and prediction

In [None]:
class create_dataset(Dataset):
    """MJSynth (Synth90k) word-image dataset for CTC training and inference.

    Construct it in one of two ways:

    * ``create_dataset(dir_=..., mode='train' | 'dev' | 'test')`` reads
      ``lexicon.txt`` and the matching annotation file under ``dir_``;
      items are ``(image, target, target_length)``.
    * ``create_dataset(img_path=[...])`` wraps a list of image paths without
      labels; items are just ``image``.

    ``image`` is a float tensor of shape ``(1, img_ht, img_wdt)`` holding the
    grey-scale picture scaled to ``[-1, 1]``; ``target`` holds one label per
    character (see ``charToLabel``) and ``target_length`` is a 1-element tensor.
    """
    all_chars = '0123456789abcdefghijklmnopqrstuvwxyz'
    # Labels start at 1 because label 0 is left for the CTC blank.
    charToLabel = {char: i + 1 for i, char in enumerate(all_chars)}
    labelToChar = {label: char for char, label in charToLabel.items()}

    # Public ``mode`` names -> MJSynth annotation files.
    _annotation_files = {
        'train': 'annotation_train.txt',
        'dev': 'annotation_val.txt',
        'test': 'annotation_test.txt',
    }

    def __init__(self, dir_=None, mode=None, img_path=None, img_ht=32, img_wdt=100):
        """Build a labelled dataset from ``dir_``/``mode`` or an unlabelled one from ``img_path``.

        Raises:
            ValueError: if neither (or both) of the two construction styles is
                requested, or if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if dir_ and mode and not img_path:
            img_path, texts = self.__load_files__(dir_, mode)
        elif not dir_ and not mode and img_path is not None:
            texts = None
        else:
            # Any other combination is ambiguous; fail loudly instead of
            # continuing with ``texts`` undefined.
            raise ValueError(
                'pass either dir_ and mode (labelled data) or img_path '
                '(unlabelled images), not both')

        self.img_path = img_path
        self.texts = texts
        self.img_ht = img_ht
        self.img_wdt = img_wdt

    def __load_files__(self, dir_, mode):
        """Return ``(img_path, texts)`` parsed from the annotation file for *mode*.

        Each annotation line is ``<relative image path> <lexicon index>``; the
        index selects a (0-based) line of ``lexicon.txt``. Blank lines are skipped.

        Raises:
            ValueError: if *mode* is not 'train', 'dev' or 'test'.
        """
        annotation_file = self._annotation_files.get(mode)
        if annotation_file is None:
            raise ValueError(
                f'mode must be one of {sorted(self._annotation_files)}, got {mode!r}')

        with open(os.path.join(dir_, 'lexicon.txt'), 'r') as fr:
            mapping = dict(enumerate(line.strip() for line in fr))

        img_path = []
        texts = []
        with open(os.path.join(dir_, annotation_file), 'r') as fr:
            for line in fr:
                line = line.strip()
                if not line:
                    continue
                path, index_str = line.split(' ')
                img_path.append(os.path.join(dir_, path))
                texts.append(mapping[int(index_str)])
        return img_path, texts

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, index):
        """Return ``image`` (unlabelled) or ``(image, target, target_length)`` (labelled).

        If the file at *index* cannot be read, the next index is tried instead,
        wrapping around at the end, so a few corrupt files do not abort an
        epoch. The returned sample may therefore not correspond to *index*.
        NOTE(review): if every file is unreadable this recurses until
        RecursionError.
        """
        path = self.img_path[index]

        try:
            # Context manager closes the file once the converted copy is loaded.
            with Image.open(path) as img:
                image = img.convert('L')  # grey-scale
        except IOError:
            # Modulo keeps the fallback in range for the last sample.
            return self[(index + 1) % len(self)]

        # Resize to the fixed network input size and scale pixels to [-1, 1].
        image = image.resize((self.img_wdt, self.img_ht), resample=Image.BILINEAR)
        image = np.array(image)
        image = image.reshape((1, self.img_ht, self.img_wdt))
        image = (image / 127.5) - 1.0
        image = torch.FloatTensor(image)

        if self.texts is None:  # unlabelled (prediction) mode
            return image

        text = self.texts[index]
        target = torch.LongTensor([self.charToLabel[c] for c in text])
        target_length = torch.LongTensor([len(target)])
        return image, target, target_length

### Collate function for dataloader

In [None]:
def collate_fun(batch):
    """Merge a list of ``(image, target, target_length)`` samples into one batch.

    Returns:
        A 3-tuple of
        - the images stacked along a new leading batch dimension,
        - every sample's label sequence concatenated into one 1-D tensor
          (the flat layout passed to ``CTCLoss`` together with the lengths),
        - the per-sample target lengths concatenated into one 1-D tensor.
    """
    images, targets, lengths = zip(*batch)
    return (
        torch.stack(images, dim=0),
        torch.cat(targets, dim=0),
        torch.cat(lengths, dim=0),
    )

# Model (CNN feature extractor + two bidirectional LSTMs)

In [None]:
class my_model(nn.Module):
    """CRNN-style text recognizer: a VGG-like CNN followed by two bidirectional LSTMs.

    Args:
        img_channels: number of input image channels (1 for grey-scale).
        img_ht: input image height; the CNN reduces it to ``img_ht // 16 - 1`` rows.
        img_w: input image width. Kept for interface compatibility; the layer
            sizes depend only on the height.
        num_class: number of output classes, including the CTC blank.
        map_to_seq_hidden: size of the per-column feature fed to the first LSTM.
        rnn_hidden: hidden size of each LSTM direction.

    ``forward(x)`` takes ``x`` of shape ``(batch, img_channels, img_ht, img_w)``
    and returns unnormalised scores of shape ``(seq_len, batch, num_class)``
    with ``seq_len = img_w // 4 - 1`` (24 for a 100-pixel-wide input).

    NOTE: the layer names below are the ``state_dict`` keys of saved
    checkpoints; renaming them breaks ``load_state_dict`` on existing files.
    """

    def __init__(self, img_channels, img_ht, img_w, num_class, map_to_seq_hidden=64, rnn_hidden=256):
        super(my_model, self).__init__()
        # For a 32x100 input the spatial size goes
        # 32x100 -> 16x50 (pool1) -> 8x25 (pool2) -> 4x25 (pool4, height only)
        # -> 2x25 (pool6, height only) -> 1x24 (conv7: 2x2 kernel, no padding).
        self.cnn_stack = nn.Sequential(
            OrderedDict([
                ('conv1', nn.Conv2d(img_channels, 64, 3, 1, 1)),
                # ('bn1', nn.BatchNorm2d(64)),
                ('relu1', nn.ReLU(inplace=True)),
                ('pool1', nn.MaxPool2d(kernel_size=2, stride=2)),

                ('conv2', nn.Conv2d(64, 128, 3, 1, 1)),
                # ('bn2', nn.BatchNorm2d(128)),
                ('relu2', nn.ReLU(inplace=True)),
                ('pool2', nn.MaxPool2d(kernel_size=2, stride=2)),

                ('conv3', nn.Conv2d(128, 256, 3, 1, 1)),
                # ('bn3', nn.BatchNorm2d(256)),
                ('relu3', nn.ReLU(inplace=True)),

                ('conv4', nn.Conv2d(256, 256, 3, 1, 1)),
                # ('bn4', nn.BatchNorm2d(256)),
                ('relu4', nn.ReLU(inplace=True)),
                # (2, 1) pooling halves the height but keeps the width, so the
                # horizontal resolution of the output sequence is preserved.
                ('pool4', nn.MaxPool2d(kernel_size=(2, 1))),

                ('conv5', nn.Conv2d(256, 512, 3, 1, 1)),
                ('relu5', nn.ReLU(inplace=True)),
                ('bn5', nn.BatchNorm2d(512)),

                ('conv6', nn.Conv2d(512, 512, 3, 1, 1)),
                ('relu6', nn.ReLU(inplace=True)),
                ('bn6', nn.BatchNorm2d(512)),
                ('pool6', nn.MaxPool2d(kernel_size=(2, 1))),

                ('conv7', nn.Conv2d(512, 512, 2, 1, 0)),
                ('relu7', nn.ReLU(inplace=True)),
            ])
        )
        # Rows left after four height-halvings and the unpadded 2x2 conv7.
        out_ht = img_ht // 16 - 1
        self.map_to_seq = nn.Linear(512 * out_ht, map_to_seq_hidden)
        self.lstm1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        self.lstm2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)
        self.out_ = nn.Linear(2 * rnn_hidden, num_class)

    def forward(self, x):
        conv = self.cnn_stack(x)
        batch, channel, height, width = conv.size()
        # One feature vector per image column: merge channels and rows, then
        # move the column (width) axis to the front as the sequence axis.
        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(2, 0, 1)  # (width, batch, feature)
        seq = self.map_to_seq(conv)
        lstm_, _ = self.lstm1(seq)
        lstm_, _ = self.lstm2(lstm_)
        out = self.out_(lstm_)  # (seq_len, batch, num_class)
        return out

# Greedy CTC character decoder

In [None]:
def decode(log_probs, labelToChar=None, blank=0):
    """Greedy (best-path) CTC decoding.

    Args:
        log_probs: tensor laid out as ``(seq_len, batch, num_class)``.
        labelToChar: optional label-id -> character mapping; when it is given
            and non-empty, decoded ids are translated to characters.
        blank: label id of the CTC blank.

    Returns:
        One list per batch element: the arg-max label of every time step with
        consecutive repeats collapsed and blanks dropped (label ids, or
        characters when ``labelToChar`` is supplied).
    """
    per_sample = log_probs.cpu().numpy().transpose(1, 0, 2)  # -> (batch, seq_len, num_class)
    results = []
    for scores in per_sample:
        best_path = scores.argmax(axis=-1)
        # Keep a step only where its label differs from the previous step's.
        is_new = np.ones(best_path.shape, dtype=bool)
        is_new[1:] = best_path[1:] != best_path[:-1]
        collapsed = best_path[is_new]
        ids = list(collapsed[collapsed != blank])
        results.append([labelToChar[k] for k in ids] if labelToChar else ids)
    return results

# Evaluate the model

In [None]:
def evaluate(model_, dataloader, criterion):
    """Compute the average loss and exact-match word accuracy of *model_* on *dataloader*.

    The dataloader must yield ``(images, targets, target_len)`` batches as built
    by ``collate_fun``: ``targets`` is the concatenation of every label
    sequence in the batch and ``target_len`` holds each sequence's length.

    Returns:
        dict with ``'loss'`` (summed criterion value divided by the number of
        samples, i.e. a per-sample mean for a ``reduction='sum'`` criterion)
        and ``'acc'`` (fraction of samples whose greedy-decoded label sequence
        equals the target exactly). Both are NaN for an empty dataloader.

    Leaves the model in eval mode.
    """
    model_.eval()
    # Use the exact device of the parameters (e.g. 'cuda:1'); a bare 'cuda'
    # would mean the current default GPU and can mismatch the model.
    device = next(model_.parameters()).device

    count = 0
    eval_loss = 0.0
    num_correct_preds = 0

    with torch.no_grad():
        for data in dataloader:
            images, targets, target_len = [d.to(device) for d in data]

            logits = model_(images)  # (seq_len, batch, num_class)
            log_probs = torch.nn.functional.log_softmax(logits, dim=2)

            batch_sz = images.size(0)
            # Every sample uses the full output sequence length.
            input_lengths = torch.LongTensor([logits.size(0)] * batch_sz)

            loss = criterion(log_probs, targets, input_lengths, target_len)

            preds = decode(log_probs)
            reals = targets.cpu().numpy().tolist()
            lengths = target_len.cpu().numpy().tolist()

            count += batch_sz
            eval_loss += loss.item()
            # ``targets`` is flat: walk it in per-sample slices of the given lengths.
            offset = 0
            for pred, length in zip(preds, lengths):
                real = reals[offset:offset + length]
                offset += length
                if pred == real:
                    num_correct_preds += 1

    if count == 0:
        return {'loss': float('nan'), 'acc': float('nan')}
    eval_ = {
        'loss': eval_loss / count,
        'acc': num_correct_preds / count,
    }
    return eval_




# Train function  

In [None]:
def train_batch(model_, data, optimizer, criterion, device):
    """Run one optimisation step on a single ``(images, targets, target_len)`` batch.

    Switches the model to training mode, evaluates *criterion* on the
    log-softmax of the model output (laid out as ``(seq_len, batch,
    num_class)``, every sample using the full ``seq_len``), back-propagates
    and steps *optimizer*.

    Returns:
        The batch loss as a Python float.
    """
    model_.train()
    images, targets, target_len = (tensor.to(device) for tensor in data)

    log_probs = torch.nn.functional.log_softmax(model_(images), dim=2)
    n_steps = log_probs.size(0)
    input_lengths = torch.full((images.size(0),), n_steps, dtype=torch.long)

    loss = criterion(log_probs, targets, input_lengths, target_len.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

## Training parameters

In [None]:
# Optimisation schedule.
epochs = 10
lr = 5e-4
train_batch_sz = 32
eval_batch_sz = 512

# Logging / evaluation / checkpoint cadence, in training iterations (batches).
show_train_loss = 20_000
show_valid_loss = 50_000
save_model = 50_000
cpu_workers = 4  # DataLoader worker processes

# Checkpoint whose weights are loaded before training; set to None to start fresh.
check_pt = '/scratch/pm3140/checkpoints/check_pt1350000.pt'

# Input geometry and file locations.
img_wdt = 100
img_ht = 32
data_path = '/mjsynth/90kDICT32px/'
checkpts_path = '/scratch/pm3140/checkpoints/'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

## Loading data into dataloaders

In [None]:
# Labelled MJSynth splits: items are (image, target, target_length).
train_data = create_dataset(dir_=data_path, mode='train',
                            img_ht=img_ht, img_wdt=img_wdt)
val_data = create_dataset(dir_=data_path, mode='dev',
                          img_ht=img_ht, img_wdt=img_wdt)

# collate_fun concatenates the variable-length targets into one flat tensor.
tr_loader = DataLoader(dataset=train_data, batch_size=train_batch_sz,
                       shuffle=True, num_workers=cpu_workers, collate_fn=collate_fun)
# Evaluation visits every sample exactly once, so shuffling buys nothing;
# a fixed order keeps validation runs reproducible.
val_loader = DataLoader(dataset=val_data, batch_size=eval_batch_sz,
                        shuffle=False, num_workers=cpu_workers, collate_fn=collate_fun)

## Model and Loss

In [None]:
# Output classes: the 36 characters plus the CTC blank (label 0).
num_class = len(create_dataset.labelToChar) + 1
model_ = my_model(1, img_ht, img_wdt, num_class,
                  map_to_seq_hidden=64, rnn_hidden=256)
if check_pt:
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model_.load_state_dict(torch.load(check_pt, map_location=device))
model_.to(device)

optimizer = optim.RMSprop(model_.parameters(), lr=lr)
# zero_infinity=True: a target that cannot be aligned within the model's
# output length (long words or many repeated letters) gives an infinite CTC
# loss whose gradient would fill the weights with inf/NaN. Such samples now
# contribute zero loss and zero gradient instead.
criterion = CTCLoss(reduction='sum', zero_infinity=True)
criterion.to(device)

## Training

In [None]:
# Histories sampled every show_train_loss / show_valid_loss iterations
# (plotted in the cells below).
train_loss_hist = []
val_loss_hist = []
val_acc_hist = []
os.makedirs(checkpts_path, exist_ok=True)
# Global iteration counter across epochs.
# NOTE(review): it restarts at 0 even when resuming from check_pt, so new
# checkpoint files may overwrite ones from an earlier run — confirm intended.
i = 0
for epoch in range(1, epochs + 1):
    print('epoch: ', epoch)
    tot_train_loss = 0.0
    tot_train_count = 0.0
    # Named ``batch`` so the training Dataset ``train_data`` is not overwritten.
    for batch in tr_loader:
        loss = train_batch(model_, batch, optimizer, criterion, device)
        train_size = batch[0].size(0)

        tot_train_loss += loss
        tot_train_count += train_size

        if i % show_train_loss == 0:
            train_loss_hist.append(loss / train_size)
            print('train_batch_loss[%d]: %4f\n' % (i, loss / train_size))

        if i % show_valid_loss == 0:
            print('evaluating on the validation set ...')
            eval_ = evaluate(model_, val_loader, criterion)
            val_loss_hist.append(eval_['loss'])
            val_acc_hist.append(eval_['acc'])
            print('valid: loss=%4f, acc=%4f' % (eval_['loss'], eval_['acc']))

        # Independent of validation: no longer reads ``eval_``, which would be
        # undefined if save_model were not a multiple of show_valid_loss.
        if i % save_model == 0:
            print('saving model ...')
            torch.save(model_.state_dict(),
                       os.path.join(checkpts_path, 'check_pt' + str(i) + '.pt'))

        i += 1

    if tot_train_count:
        print('total train loss: ', tot_train_loss / tot_train_count)



## Loss/Accuracy Plots

In [None]:
# The two histories are sampled at different rates (every show_train_loss vs.
# every show_valid_loss iterations, starting at iteration 0), so each is
# plotted against its own iteration numbers rather than its list index;
# otherwise the curves would be misaligned on the shared x axis.
train_pts = train_loss_hist[0:80]
val_pts = val_loss_hist[0:80]
fig, ax = plt.subplots()
ax.plot([k * show_train_loss for k in range(len(train_pts))], train_pts, label='Train')
ax.plot([k * show_valid_loss for k in range(len(val_pts))], val_pts, label='Validation')
plt.title('Train and Validation loss')
plt.xlabel('iterations')
plt.ylabel('loss')
ax.legend()
plt.show()

In [None]:
# Size the figure in subplots() itself: a separate plt.figure(...) call would
# create an extra empty figure while subplots() starts a default-sized one.
fig, ax = plt.subplots(figsize=(12, 10), dpi=80)
# One accuracy entry per show_valid_loss iterations, starting at iteration 0.
ax.plot([k * show_valid_loss for k in range(len(val_acc_hist))], val_acc_hist, label='accuracy')
plt.title('Validation accuracy')
plt.xlabel('iterations')
plt.ylabel('accuracy')
ax.legend()
plt.show()

In [None]:
# Raw validation-accuracy history, one point per evaluation (x = evaluation index).
plt.plot(range(len(val_acc_hist)), val_acc_hist)
plt.show()

## Prediction function 

In [None]:
def predict(model_, dataloader, labelToChar):
    """Greedy-decode every image yielded by *dataloader* into a list of characters.

    The dataloader must yield plain image tensors (e.g. an unlabelled
    ``create_dataset``). Returns one character list per image, in loader
    order. Leaves the model in eval mode.
    """
    model_.eval()
    # Exact parameter device (e.g. 'cuda:1'); a bare 'cuda' would target the
    # current default GPU and can mismatch the model.
    device = next(model_.parameters()).device
    final_pred = []
    with torch.no_grad():
        for images in dataloader:
            logits = model_(images.to(device))
            log_probs = torch.nn.functional.log_softmax(logits, dim=2)
            final_pred.extend(decode(log_probs, labelToChar))
    return final_pred

### Setting the prediction parameters

In [None]:
# Every file under the test-image directory, sorted so that predictions
# (printed without file names) come out in a reproducible order; os.walk
# order depends on the filesystem.
images = sorted(
    os.path.join(pth, f)
    for pth, _dirs, files in os.walk('/scratch/pm3140/test_images')
    for f in files
)
# images = ['/scratch/pm3140/test_images/100_Classmates_13991.jpg']
img_ht = 32
img_wdt = 100
check_pt = '/scratch/pm3140/checkpoints/check_pt1350000.pt'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

# Unlabelled dataset: yields image tensors only.
pred_data = create_dataset(img_path=images, img_ht=img_ht, img_wdt=img_wdt)
pred_loader = DataLoader(dataset=pred_data, shuffle=False, num_workers=4)

num_class = len(create_dataset.labelToChar) + 1
model_ = my_model(1, img_ht, img_wdt, num_class, map_to_seq_hidden=64, rnn_hidden=256)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model_.load_state_dict(torch.load(check_pt, map_location=device))
model_.to(device)


### Predictions

In [None]:
# Decode every test image and print one recognised word per line.
preds = predict(model_, pred_loader, create_dataset.labelToChar)
words = [''.join(chars) for chars in preds]
for final_pred in words:
    print(final_pred)