In [210]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import datasets, models, transforms as T
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

In [211]:
from tqdm.notebook import tqdm
import pathlib
import os
from PIL import Image
import string
from typing import Tuple
import datetime
import copy
import time
import numpy as np
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
%matplotlib inline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running training on [{device}]")

Running training on [cuda]


In [212]:
BATCH_SIZE = 128
WORKERS = 1
EPOCHS = 20
MAX_WORD_LENGTH = 10
ALPHABET = string.ascii_letters + string.digits + "_" #blank char for CTC
OUTPUT_SEQUENCE_LENGTH = 25
OUTPUT_STEP_SIZE = 5

In [213]:
class CustomDataset(Dataset):
	def __init__(self, root_path, type="train"):
		self.root_path = root_path
		self.type = type
		self.images_paths = list(pathlib.Path(self.root_path + "./images").glob('*.png'))
		self.transforms = {
			'train' : T.Compose([
				# T.Resize((40,200)),
				T.RandomRotation(20),
				T.GaussianBlur(3),
				T.ToTensor()
			]),
			'valid' : T.Compose([
				T.ToTensor()
			])
		}
		global ALPHABET
		self.alphabet = ALPHABET
		self.alphabet_size = len(self.alphabet)
		print(f"Alphabet size: {self.alphabet_size}")

	def __getitem__(self, idx):
		image_path = self.images_paths[idx]
		sample_name = str(image_path).split(os.sep)[-1].split(".")[0]
		text_path = self.root_path + "/transcripts/" + str(int(sample_name) + 1) + ".txt"

		image = Image.open(image_path).convert("RGB")
		with open(text_path) as f:
			text = f.read()

		image = self.transforms[self.type](image)
		text_tensor = self.wordToTensor(text)
		return image, (text_tensor, len(text), text)
		

	def __len__(self):
		return len(self.images_paths)

	def letterToIndex(self, letter):
		return self.alphabet.find(letter)

	def letterToTensor(self, letter):
		tensor = torch.zeros(1, n_letters)
		tensor[0][letterToIndex(letter)] = 1
		return tensor

	def wordToTensor(self, word):
		# tensor = torch.zeros(MAX_WORD_LENGTH, self.alphabet_size)
		# for li, letter in enumerate(word):
		# 	tensor[li][self.letterToIndex(letter)] = 1

		tensor = torch.zeros(MAX_WORD_LENGTH)
		for li, letter in enumerate(word):
			tensor[li] = self.letterToIndex(letter)
		return tensor


In [214]:
class DataHandler:
	def __init__(self, run_config):
		self._training_dataset = None
		self._validation_dataset = None
		self._run_config = run_config

		self._load_datasets()
		
	def _load_datasets(self):
		self._training_dataset = CustomDataset("dataset/training")
		self._validation_dataset = CustomDataset("dataset/validation")

	def get_data_loaders(self) -> Tuple[DataLoader]:
		return (
			DataLoader(self._training_dataset, batch_size=self._run_config["batch_size"], shuffle=True, pin_memory=True, drop_last=True), 
			DataLoader(self._validation_dataset, batch_size=self._run_config["batch_size"], shuffle=True, pin_memory=True, drop_last=True)
		)

	def get_datasets(self) -> Tuple[Dataset]:
		return self._training_dataset, self._validation_dataset

	def get_datasets_sizes(self) -> Tuple[int]:
		return len(self._training_dataset), len(self._validation_dataset)

In [215]:
data_handler = DataHandler(run_config = {
    "batch_size": BATCH_SIZE,
    "workers": WORKERS
})
train_loader, validation_loader = data_handler.get_data_loaders()
training_dataset_size, validation_dataset_size = data_handler.get_datasets_sizes()

Alphabet size: 63
Alphabet size: 63


In [216]:
class TranscribeModel(nn.Module):
    def __init__(self):
        super(TranscribeModel, self).__init__()
        self.conv_block1 = nn.Sequential(
			nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
			nn.MaxPool2d(kernel_size=(2, 2)),
			nn.BatchNorm2d(16),
			nn.LeakyReLU(0.2, inplace=True),

			nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
			nn.MaxPool2d(kernel_size=(2, 2)),
			nn.BatchNorm2d(32),
			nn.LeakyReLU(0.2, inplace=True)
		)

        self.rnn_block1 = nn.Sequential(
            nn.LSTM(input_size=OUTPUT_STEP_SIZE * 32, hidden_size=len(ALPHABET), num_layers=2, batch_first=True, bidirectional=True)
        )

        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, input):
        out = self.conv_block1(input)
        # print(out.shape) #[128, 32, 10, 50]
        out = out.permute([0, 3, 2, 1])
        # (B, S, ) 
        out = out.reshape(out.size(0), out.size(1), -1)

        out, (_, _) = self.rnn_block1(out)
        out = self.softmax(out)
        return out

In [217]:
model = TranscribeModel()
model.to(device)
optimizer = optim.AdamW(
    model.parameters(), 
    lr=0.0001, 
    betas=(0.9, 0.999), 
    eps=1e-08, 
    weight_decay=1e-4
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
loss_criterion = nn.CTCLoss(blank=len(ALPHABET)-1, zero_infinity=True, reduction="mean")
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, steps_per_epoch=10, epochs=EPOCHS, anneal_strategy='linear')

log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(log_dir)

TranscribeModel(
  (conv_block1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (rnn_block1): Sequential(
    (0): LSTM(160, 63, num_layers=2, batch_first=True, bidirectional=True)
  )
  (softmax): LogSoftmax(dim=2)
)

In [218]:
def tensorToWord(tensor):
	tensor = tensor[..., :63]
	indices = torch.argmax(tensor, dim=2).tolist()
	
	words = []
	for batch_idx in range(BATCH_SIZE):
		cur_word_indices = indices[batch_idx]
		cur_word = []
		last_letter = None
		for idx in range(OUTPUT_SEQUENCE_LENGTH):
			if ALPHABET[cur_word_indices[idx]] == ALPHABET[-1]:
				last_letter = None
				continue
			else:
				if last_letter == None or (last_letter is not None and last_letter != cur_word_indices[idx]):
					cur_word.append(ALPHABET[cur_word_indices[idx]])
					last_letter = cur_word_indices[idx]
					continue
		words.append("".join(cur_word))
	return words

	

In [219]:
def levenshteinDistance(string_one, string_two):
    dist = np.zeros((len(string_one) + 1, len(string_two) + 1))
    dist[1:, 0] = [i + 1 for i in range(len(string_one))]
    dist[0, 1:] = [i + 1 for i in range(len(string_two))]

    for row in range(1, len(string_one) + 1):
        for col in range(1, len(string_two) + 1):
            gain = 1
            if string_one[row - 1] == string_two[col -1]:
                dist[row, col] = dist[row - 1, col -1]
            else:
                dist[row, col] = gain + min(dist[row - 1, col], dist[row, col - 1], dist[row - 1, col - 1])
    
    return dist[-1,-1]

def batchLevenshteinDistance(arr_one, arr_two):
    cumulative_distance = 0
    for idx in range(len(arr_one)):
        cumulative_distance += levenshteinDistance(arr_one[idx], arr_two[idx])
    return cumulative_distance


In [223]:
image, (label_tensor, text_length, label_text) = iter(train_loader).next()
image, label_tensor = image.to(device), label_tensor.to(device)
output = model(image)
print(text_length)

tensor([ 6,  4,  4,  4,  4,  5,  5,  3,  1,  8, 10,  2,  7,  1,  8,  7,  9,  6,
        10,  3,  2,  9,  8,  5,  8,  9,  8,  3,  9,  4, 10,  8,  3, 10,  8,  7,
         3,  2,  3,  4,  9,  5,  4,  1,  7,  2, 10,  5,  6,  4,  8,  5,  7,  5,
         1,  7,  9,  1,  5,  9,  1,  8,  9,  1, 10,  5,  2,  4, 10,  5,  7,  8,
         2,  4,  4,  1,  2,  6,  9,  2,  8,  1,  4,  7, 10,  3,  9,  6,  2,  1,
         1,  4,  4,  1,  6,  7,  6,  8,  5, 10,  8, 10,  4,  4,  7,  3,  9,  7,
         2,  5,  9, 10,  5,  6,  1,  8,  4,  1,  9,  2,  1,  6,  3, 10,  5,  7,
         7,  5])


In [224]:
output.shape
out = output[..., :63]
# torch.argmax(out, dim=2)
print(">>>>>>> label_text <<<<<<<")
print(label_text)

print(">>>>>>> label_tensor <<<<<<<")
print(label_tensor)

print(">>>>>>> output <<<<<<<<")
print(output)

tensorToWord(out)

torch.Size([128, 25, 126])

>>>>>>> label_text <<<<<<<
['wlA1iH', '3G4S', 'asOy', 'qPa6', '3vBR', 'OkDkN', '0w2Sf', 'cbF', 'e', 'mEjW3Fjz', 'Wq8p6S4Gse', 'XO', 'ZzmdIgr', 'X', 'rTs8amUO', 'zyEAFWB', '0Agek7n3p', 'xc8mtY', 'nQwEJBVtCj', 'TvW', 'Qh', 'UAdyNGSg6', 'gH0X9sbs', 'Ibznk', 'Af9lbBuv', 'MtWYaEHmU', 'a8pGe1ks', 'hMr', 'cYV4rLkPa', 'f0ah', 'EsQQIbeYCm', 'ZbVd45EJ', '9XC', 'bPFBqvP5Xe', 'M9aG2LCE', 'CPSy6kV', 'fp9', 'f4', 'mjK', 'CFUh', 'GnEgTLW0c', '5jWOu', 'gnts', 'E', 'T9VZVH2', '6d', '8Y1hiftC0X', 'dWBbL', '4CJFHJ', 'ios5', '5PNxH1hu', 'auj1A', 'Lz0qE6M', 'zGYZs', 'g', 'MCGlYSh', 'eIE4SgIl9', 'b', 'oWgNY', '0Ptmnpnz9', 'u', 'kzLwuKuC', 'vbw8sEZma', '9', 'ifZaRNIenI', 'Kzbqf', '7c', 'VdZ5', 'oVcRtCSQC7', 'KWAXi', 'jcCwJWV', 'XfB8Hwyp', '3E', 'nImH', 'xaos', '4', 'KR', 'vA1ihq', 'zgsLaXlPq', 'BS', '38gdY1Ay', 'y', '7Lxt', 'Urm90Ag', 'LdWIY04yaQ', 'uhm', 'kI3xnzVA6', '4Zo68U', '8o', 's', 'A', '08wS', 'MZlC', 'm', 'yX8DIE', 'VAfwNlj', '8fvLMj', '9IzF2jZY', 'eRqGR', 'V3chJ3RJkP', 'cW8DRwsI', 'bKZOOBiprd', 'yt

['',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '']

In [222]:
since = time.time()
for epoch in range(EPOCHS):
	print('Epoch {}/{}'.format(epoch, EPOCHS))
	print('-' * 10)

	########### Training step ###########
	model = model.train()
	training_loss = []
	running_loss = 0.0
	running_corrects = 0
			
	for i, data in enumerate(tqdm(train_loader, desc=f"Epoch [{epoch + 1}] progress")):

		x_batch, (label_batch, label_length, label_text) = data
		x_batch, label_batch, label_length = x_batch.to(device), label_batch.to(device), label_length.to(device)

		optimizer.zero_grad()
		outputs = model(x_batch)
		outputs_permuted = outputs.permute((1, 0, 2))

		loss = loss_criterion(outputs_permuted, label_batch, torch.full((BATCH_SIZE,), OUTPUT_SEQUENCE_LENGTH).to(device), label_length)

		loss.backward()
		optimizer.step()

		# statistics
		running_loss += loss.item() * x_batch.size(0)
		predictions = tensorToWord(outputs)
		running_corrects += batchLevenshteinDistance(predictions, label_text)
		training_loss.append(loss.item())

	epoch_loss = running_loss / training_dataset_size
	epoch_acc = running_corrects / training_dataset_size

	# tensorboard logging
	writer.add_scalar("Loss/train", epoch_loss, epoch)

	print('Training step => Loss: {:.4f} | Dist: {:.4f}'.format(
		epoch_loss, epoch_acc
	))

	scheduler.step()


	########### Validation step ###########
	model = model.eval()
	validation_loss = []
	running_loss = 0.0
	running_corrects = 0

	for i, data in enumerate(validation_loader):
		with torch.no_grad():
			x_batch, (label_batch, label_length, label_text) = data
			x_batch, label_batch, label_length = x_batch.to(device), label_batch.to(device), label_length.to(device)

			outputs = model(x_batch)
			outputs_permuted = outputs.permute((1, 0, 2))
			loss = loss_criterion(outputs_permuted, label_batch, torch.full((BATCH_SIZE,), OUTPUT_SEQUENCE_LENGTH).to(device), label_length)

			running_loss += loss.item() * x_batch.size(0)
			
			predictions = tensorToWord(outputs)
			running_corrects += batchLevenshteinDistance(predictions, label_text)
			validation_loss.append(loss.item())
			
	epoch_loss = running_loss / validation_dataset_size
	epoch_acc = running_corrects / validation_dataset_size

	# tensorboard logging
	writer.add_scalar("Loss/validation", epoch_loss, epoch)

	print('Evaluation step => Loss: {:.4f} | Dist {:.4f}'.format(
		epoch_loss, epoch_acc
	))
	best_acc = 0
	#Save the best model based on accuracy
	if True:
		best_model_wts = copy.deepcopy(model.state_dict())

	#Checkpoint
	torch.save({
		"epoch": epoch,
		"model_state_dict": model.state_dict(),
		"optimizer_state_dict": optimizer.state_dict()
	}, "./checkpoints/ckp.pt")



time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
	time_elapsed // 60, time_elapsed % 60
))
print('Best (so far) validation Acc: {:4f}'.format(best_acc))

print('-' * 10)
print('### Final results ###\n')
print('Best validation Acc: {:4f}'.format(best_acc))

Epoch 0/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [1] progress', max=48.0, style=ProgressStyle(descri…


Training step => Loss: 31.3369 | Dist: 6.5242
Evaluation step => Loss: 26.2891 | Dist 4.3871
Epoch 1/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [2] progress', max=48.0, style=ProgressStyle(descri…


Training step => Loss: 30.1149 | Dist: 5.3958
Evaluation step => Loss: 24.4581 | Dist 4.3839
Epoch 2/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [3] progress', max=48.0, style=ProgressStyle(descri…


Training step => Loss: 27.5865 | Dist: 5.4139
Evaluation step => Loss: 20.7836 | Dist 4.5516
Epoch 3/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [4] progress', max=48.0, style=ProgressStyle(descri…


Training step => Loss: 25.1225 | Dist: 5.4079
Evaluation step => Loss: 20.3415 | Dist 4.4129
Epoch 4/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [5] progress', max=48.0, style=ProgressStyle(descri…


Training step => Loss: 23.5602 | Dist: 5.4106
Evaluation step => Loss: 17.9795 | Dist 4.5258
Epoch 5/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [6] progress', max=48.0, style=ProgressStyle(descri…


Training step => Loss: 22.4290 | Dist: 5.4079
Evaluation step => Loss: 17.9040 | Dist 4.4871
Epoch 6/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [7] progress', max=48.0, style=ProgressStyle(descri…


Training step => Loss: 21.7630 | Dist: 5.4058
Evaluation step => Loss: 17.5179 | Dist 4.5016
Epoch 7/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [8] progress', max=48.0, style=ProgressStyle(descri…


Training step => Loss: 21.4845 | Dist: 5.4044
Evaluation step => Loss: 17.2680 | Dist 4.4419
Epoch 8/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [9] progress', max=48.0, style=ProgressStyle(descri…


Training step => Loss: 21.3541 | Dist: 5.4027
Evaluation step => Loss: 17.3930 | Dist 4.4645
Epoch 9/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [10] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.3081 | Dist: 5.4097
Evaluation step => Loss: 17.2609 | Dist 4.4645
Epoch 10/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [11] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.3000 | Dist: 5.4055
Evaluation step => Loss: 17.7164 | Dist 4.4097
Epoch 11/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [12] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.2810 | Dist: 5.4015
Evaluation step => Loss: 17.5225 | Dist 4.4226
Epoch 12/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [13] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.2057 | Dist: 5.4110
Evaluation step => Loss: 17.2110 | Dist 4.4694
Epoch 13/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [14] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.2365 | Dist: 5.4085
Evaluation step => Loss: 17.1271 | Dist 4.4839
Epoch 14/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [15] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.2247 | Dist: 5.4105
Evaluation step => Loss: 17.2019 | Dist 4.4355
Epoch 15/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [16] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.2528 | Dist: 5.4011
Evaluation step => Loss: 17.6210 | Dist 4.4548
Epoch 16/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [17] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.2059 | Dist: 5.4074
Evaluation step => Loss: 17.0845 | Dist 4.4806
Epoch 17/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [18] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.2040 | Dist: 5.4058
Evaluation step => Loss: 17.4289 | Dist 4.4645
Epoch 18/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [19] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.2160 | Dist: 5.4103
Evaluation step => Loss: 17.7020 | Dist 4.4129
Epoch 19/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [20] progress', max=48.0, style=ProgressStyle(descr…


Training step => Loss: 21.2078 | Dist: 5.4034
Evaluation step => Loss: 17.4039 | Dist 4.4419
Training complete in 2m 50s
Best (so far) validation Acc: 0.000000
----------
### Final results ###

Best validation Acc: 0.000000


<All keys matched successfully>