In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import datasets, models, transforms as T
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

In [2]:
from tqdm.notebook import tqdm
import pathlib
import os
from PIL import Image
import string
from typing import Tuple
import datetime
import copy
import time
import numpy as np
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
%matplotlib inline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
print(f"Running training on [{device}]")

Running training on [cuda]


In [3]:
BATCH_SIZE = 64
WORKERS = 1
EPOCHS = 20
MAX_WORD_LENGTH = 10
ALPHABET = string.ascii_letters + string.digits + "_" #blank char for CTC
OUTPUT_SEQUENCE_LENGTH = 10
OUTPUT_STEP_SIZE = 2

In [4]:
class CustomDataset(Dataset):
	def __init__(self, root_path, type="train"):
		self.root_path = root_path
		self.type = type
		self.images_paths = list(pathlib.Path(self.root_path + "./images").glob('*.png'))
		self.transforms = {
			'train' : T.Compose([
				# T.Resize((40,200)),
				T.RandomRotation(20),
				T.GaussianBlur(3),
				T.ToTensor()
			]),
			'valid' : T.Compose([
				T.ToTensor()
			])
		}
		global ALPHABET
		self.alphabet = ALPHABET
		self.alphabet_size = len(self.alphabet)
		print(f"Alphabet size: {self.alphabet_size}")

	def __getitem__(self, idx):
		image_path = self.images_paths[idx]
		sample_name = str(image_path).split(os.sep)[-1].split(".")[0]
		text_path = self.root_path + "/transcripts/" + str(int(sample_name) + 1) + ".txt"

		image = Image.open(image_path).convert("RGB")
		with open(text_path) as f:
			text = f.read()

		image = self.transforms[self.type](image)
		text_tensor = self.wordToTensor(text)
		return image, (text_tensor, len(text), text)
		

	def __len__(self):
		return len(self.images_paths)

	def letterToIndex(self, letter):
		return self.alphabet.find(letter)

	def letterToTensor(self, letter):
		tensor = torch.zeros(1, n_letters)
		tensor[0][letterToIndex(letter)] = 1
		return tensor

	def wordToTensor(self, word):
		tensor = torch.zeros(MAX_WORD_LENGTH, self.alphabet_size)
		# for li, letter in enumerate(word):
		# 	tensor[li][self.letterToIndex(letter)] = 1

		tensor = torch.full((MAX_WORD_LENGTH,), len(ALPHABET) - 1)
		# tensor = torch.zeros(MAX_WORD_LENGTH)
		for li, letter in enumerate(word):
			tensor[li] = self.letterToIndex(letter)
		return tensor


In [5]:
class DataHandler:
	def __init__(self, run_config):
		self._training_dataset = None
		self._validation_dataset = None
		self._run_config = run_config

		self._load_datasets()
		
	def _load_datasets(self):
		self._training_dataset = CustomDataset("dataset/training")
		self._validation_dataset = CustomDataset("dataset/validation")

	def get_data_loaders(self) -> Tuple[DataLoader]:
		return (
			DataLoader(self._training_dataset, batch_size=self._run_config["batch_size"], shuffle=True, pin_memory=True, drop_last=True), 
			DataLoader(self._validation_dataset, batch_size=self._run_config["batch_size"], shuffle=False, pin_memory=True, drop_last=True)
		)

	def get_datasets(self) -> Tuple[Dataset]:
		return self._training_dataset, self._validation_dataset

	def get_datasets_sizes(self) -> Tuple[int]:
		return len(self._training_dataset), len(self._validation_dataset)

In [6]:
data_handler = DataHandler(run_config = {
    "batch_size": BATCH_SIZE,
    "workers": WORKERS
})
train_loader, validation_loader = data_handler.get_data_loaders()
training_dataset_size, validation_dataset_size = data_handler.get_datasets_sizes()

Alphabet size: 63
Alphabet size: 63


In [7]:
class TranscribeModel(nn.Module):
    def __init__(self):
        super(TranscribeModel, self).__init__()
        self.conv_block1 = nn.Sequential(
			nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
			nn.MaxPool2d(kernel_size=(2, 2)),
			nn.BatchNorm2d(16),
			nn.LeakyReLU(0.2, inplace=True),

			nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
			nn.MaxPool2d(kernel_size=(2, 2)),
			nn.BatchNorm2d(32),
			nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
			nn.MaxPool2d(kernel_size=(2, 2)),
			nn.BatchNorm2d(64),
			nn.LeakyReLU(0.2, inplace=True)
		)
        
        self.linear_block1 = nn.Sequential(
            nn.Linear(1536, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 630),
        )

        # self.rnn_block1 = nn.Sequential(
        #     nn.LSTM(input_size=OUTPUT_STEP_SIZE * 64, hidden_size=len(ALPHABET), num_layers=2, batch_first=True, bidirectional=True)
        # )

        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, input):
        out = self.conv_block1(input)
        # print(out.shape) #[128, 32, 10, 50]
        # out = out.permute([0, 3, 2, 1])
        # (B, S, ) 
        # out = out.reshape(out.size(0), out.size(1), -1)

        # out, (_, _) = self.rnn_block1(out)

        out = self.linear_block1(out.view(BATCH_SIZE, -1))
        out = out.view(BATCH_SIZE, 10, len(ALPHABET))
        out = self.softmax(out)
        
        return out

In [8]:
model = TranscribeModel()
model.to(device)
optimizer = optim.AdamW(
    model.parameters(), 
    lr=0.0001, 
    betas=(0.9, 0.999), 
    eps=1e-08, 
    weight_decay=1e-4
)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
# loss_criterion = nn.CTCLoss(blank=len(ALPHABET)-1, zero_infinity=False, reduction="mean")
loss_criterion = nn.NLLLoss()
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, steps_per_epoch=10, epochs=EPOCHS, anneal_strategy='linear')

log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(log_dir)

TranscribeModel(
  (conv_block1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (linear_block1

In [9]:
def tensorToWord(tensor):
	# tensor = tensor[..., :63]
	tensor = tensor.permute([1, 0, 2])
	indices = torch.argmax(tensor, dim=2).tolist()
	
	words = []
	for batch_idx in range(BATCH_SIZE):
		cur_word_indices = indices[batch_idx]
		cur_word = []
		last_letter = None
		for idx in range(OUTPUT_SEQUENCE_LENGTH):
			if ALPHABET[cur_word_indices[idx]] == ALPHABET[-1]:
				last_letter = None
				continue
			else:
				if last_letter == None or (last_letter is not None and last_letter != cur_word_indices[idx]):
					cur_word.append(ALPHABET[cur_word_indices[idx]])
					last_letter = cur_word_indices[idx]
					continue
		words.append("".join(cur_word))
	return words

def tensorToWordSync(tensor):
	indices = torch.argmax(tensor, dim=2).tolist()
	
	words = []
	for batch_idx in range(BATCH_SIZE):
		cur_word_indices = indices[batch_idx]
		cur_word = []
		last_letter = None
		for idx in range(OUTPUT_SEQUENCE_LENGTH):
			cur_word.append(ALPHABET[cur_word_indices[idx]])

		words.append("".join(cur_word))
	return words


In [10]:
def levenshteinDistance(string_one, string_two):
    dist = np.zeros((len(string_one) + 1, len(string_two) + 1))
    dist[1:, 0] = [i + 1 for i in range(len(string_one))]
    dist[0, 1:] = [i + 1 for i in range(len(string_two))]

    for row in range(1, len(string_one) + 1):
        for col in range(1, len(string_two) + 1):
            gain = 1
            if string_one[row - 1] == string_two[col -1]:
                dist[row, col] = dist[row - 1, col -1]
            else:
                dist[row, col] = gain + min(dist[row - 1, col], dist[row, col - 1], dist[row - 1, col - 1])
    
    return dist[-1,-1]

def batchLevenshteinDistance(arr_one, arr_two):
    cumulative_distance = 0
    for idx in range(len(arr_one)):
        cumulative_distance += levenshteinDistance(arr_one[idx], arr_two[idx])
    return cumulative_distance


In [15]:
image, (label_tensor, text_length, label_text) = iter(train_loader).next()
image, label_tensor = image.to(device), label_tensor.to(device)
output = model(image)
print(text_length)

tensor([ 2,  5,  8,  7,  8,  3,  3,  2,  3,  4, 10,  2,  1,  7,  4,  4,  1,  6,
         3,  1,  5,  8,  9,  7,  3,  1,  9,  2,  6,  9,  3,  8,  7,  3, 10,  1,
         5, 10,  6,  5,  7,  1,  4,  8, 10,  7,  7,  2,  5,  3,  4,  9,  8,  2,
         5,  7,  6,  4,  9,  7,  5,  6,  8,  5])


In [18]:
output.shape
out = output[..., :63]
# torch.argmax(out, dim=2)
print(">>>>>>> label_text <<<<<<<")
print(label_text)

# print(">>>>>>> label_tensor <<<<<<<")
# print(label_tensor)

# print(">>>>>>> output <<<<<<<<")
# print(output)

tensorToWordSync(out)

torch.Size([64, 10, 63])

>>>>>>> label_text <<<<<<<
['ik', '4rhXn', 'PB3zCbGT', 'xe2eW10', 'Ut4G0uua', 'xZf', 'SU6', 'HJ', 'nFS', 'iYcy', 'vAxt9ajHtU', 'cZ', 'P', 'dvW64HK', 'BfTg', 'egcK', 'r', 'yKNiUD', 'bOW', 'M', 'RNSbN', 'bfZoak1g', 'fqyDk9cC7', 'NQQWIZ2', 'm0j', 'f', 'bfl2mKpGJ', 'Gy', 'AxGQju', 'Fk2QkYZxu', 'UnG', 'MtSEqwCx', 'ILTDgd9', 'BPR', 'ukirCOsylM', 's', 'pGXPK', 'dRpznLKvSB', 'CIdaSu', 'aOkEr', '1yuNT4f', 'U', 'XBy9', 'CK5HJWil', 'JjlOcV9qMN', 'qocGOMS', 'AeEuIEQ', 'JO', 'I8cBd', 'tio', 'X2bn', '93GhZkCXr', 'B2gonagN', '2A', 'XYAZJ', 'opghYSM', 'z5r4QQ', 'zdAw', 'D71MLSeAd', 'oX51Vxh', 'NlMbz', 'uZ1p39', 'qU285CGw', 'BOk5U']


['hl________',
 'ajX0______',
 'bdgs1bAT__',
 'xcZgYVT___',
 'UMfR0wuS__',
 'xZ________',
 'VSk_______',
 'i1________',
 'wFX_______',
 'yVZ_______',
 'vAxwegjHt_',
 'cz________',
 'P_________',
 'qNWReOQ___',
 'BYTg______',
 'axX_______',
 'P_________',
 'kQOINJ____',
 'gOW_______',
 'Mt________',
 'Ufe5NV____',
 'bSgaA3____',
 'MdjDbxcC7_',
 'KQBWNQCT__',
 'wdj_______',
 'e_________',
 'bfXomKpGC_',
 'Gy________',
 'XXGXM_____',
 'FxE3gTT9u_',
 'yne_______',
 '8BEEgwCx__',
 'XKT2pB9___',
 'BFB_______',
 'uMcZ7wmyM_',
 'e_________',
 'hSZZA_____',
 'dDBmaLJJSB',
 'Cmmk9_____',
 'aOpE______',
 'Vr9TT4____',
 'U_________',
 'iBjq______',
 'ODVHAJiw__',
 '2yFFOnvWKN',
 '64K41MJ___',
 'AeEk9GX___',
 'JD________',
 'HpEg______',
 'Kp________',
 'XSgn______',
 '9TQkSbCO__',
 'ScEmmkg___',
 'ZA________',
 'lrXZ2_____',
 '8egY1Ul___',
 'cBoGQ_____',
 'ERzw______',
 'RY1ML5ew4_',
 'bXZY73____',
 'Amnbs_____',
 'nZJg3K____',
 '8UzS9SGQI_',
 'XBk5A_____']

In [14]:
since = time.time()
for epoch in range(EPOCHS):
	print('Epoch {}/{}'.format(epoch, EPOCHS))
	print('-' * 10)

	########### Training step ###########
	model = model.train()
	training_loss = []
	running_loss = 0.0
	running_corrects = 0
			
	for i, data in enumerate(tqdm(train_loader, desc=f"Epoch [{epoch + 1}] progress")):

		x_batch, (label_batch, label_length, label_text) = data
		x_batch, label_batch, label_length = x_batch.to(device), label_batch.to(device), label_length.to(device)

		optimizer.zero_grad()
		outputs = model(x_batch)
		outputs_permuted = outputs.permute((0, 2, 1))

		# loss = loss_criterion(outputs, label_batch, torch.full((BATCH_SIZE,), OUTPUT_SEQUENCE_LENGTH).to(device), label_length)
		loss = loss_criterion(outputs_permuted, label_batch)
		loss.backward()
		optimizer.step()

		# statistics
		running_loss += loss.item() * x_batch.size(0)
		predictions = tensorToWordSync(outputs)
		running_corrects += batchLevenshteinDistance(predictions, label_text)
		training_loss.append(loss.item())

	epoch_loss = running_loss / training_dataset_size
	epoch_acc = running_corrects / training_dataset_size

	# tensorboard logging
	writer.add_scalar("Loss/train", epoch_loss, epoch)

	print('Training step => Loss: {:.4f} | Dist: {:.4f}'.format(
		epoch_loss, epoch_acc
	))

	scheduler.step()


	########### Validation step ###########
	model = model.eval()
	validation_loss = []
	running_loss = 0.0
	running_corrects = 0

	for i, data in enumerate(validation_loader):
		with torch.no_grad():
			x_batch, (label_batch, label_length, label_text) = data
			x_batch, label_batch, label_length = x_batch.to(device), label_batch.to(device), label_length.to(device)

			outputs = model(x_batch)
			outputs_permuted = outputs.permute((0, 2, 1))
			# loss = loss_criterion(outputs, label_batch, torch.full((BATCH_SIZE,), OUTPUT_SEQUENCE_LENGTH).to(device), label_length)
			loss = loss_criterion(outputs_permuted, label_batch)

			running_loss += loss.item() * x_batch.size(0)
			
			predictions = tensorToWordSync(outputs)
			running_corrects += batchLevenshteinDistance(predictions, label_text)
			validation_loss.append(loss.item())
			
	epoch_loss = running_loss / validation_dataset_size
	epoch_acc = running_corrects / validation_dataset_size

	# tensorboard logging
	writer.add_scalar("Loss/validation", epoch_loss, epoch)

	print('Evaluation step => Loss: {:.4f} | Dist {:.4f}'.format(
		epoch_loss, epoch_acc
	))
	best_acc = 0
	#Save the best model based on accuracy
	if True:
		best_model_wts = copy.deepcopy(model.state_dict())

	#Checkpoint
	torch.save({
		"epoch": epoch,
		"model_state_dict": model.state_dict(),
		"optimizer_state_dict": optimizer.state_dict()
	}, "./checkpoints/ckp.pt")



time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
	time_elapsed // 60, time_elapsed % 60
))
print('Best (so far) validation Acc: {:4f}'.format(best_acc))

print('-' * 10)
print('### Final results ###\n')
print('Best validation Acc: {:4f}'.format(best_acc))

Epoch 0/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [1] progress', max=96.0, style=ProgressStyle(descri…


Training step => Loss: 2.4403 | Dist: 9.8790
Evaluation step => Loss: 2.2286 | Dist 9.2258
Epoch 1/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [2] progress', max=96.0, style=ProgressStyle(descri…


Training step => Loss: 2.3667 | Dist: 9.8240
Evaluation step => Loss: 2.1910 | Dist 9.1774
Epoch 2/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [3] progress', max=96.0, style=ProgressStyle(descri…


Training step => Loss: 2.3393 | Dist: 9.7768
Evaluation step => Loss: 2.1735 | Dist 9.1629
Epoch 3/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [4] progress', max=96.0, style=ProgressStyle(descri…


Training step => Loss: 2.3246 | Dist: 9.7544
Evaluation step => Loss: 2.1693 | Dist 9.1565
Epoch 4/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [5] progress', max=96.0, style=ProgressStyle(descri…


Training step => Loss: 2.3190 | Dist: 9.7455
Evaluation step => Loss: 2.1621 | Dist 9.1500
Epoch 5/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [6] progress', max=96.0, style=ProgressStyle(descri…


Training step => Loss: 2.3122 | Dist: 9.7337
Evaluation step => Loss: 2.1590 | Dist 9.1516
Epoch 6/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [7] progress', max=96.0, style=ProgressStyle(descri…


Training step => Loss: 2.3056 | Dist: 9.7248
Evaluation step => Loss: 2.1589 | Dist 9.1435
Epoch 7/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [8] progress', max=96.0, style=ProgressStyle(descri…


Training step => Loss: 2.3016 | Dist: 9.7194
Evaluation step => Loss: 2.1579 | Dist 9.1274
Epoch 8/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [9] progress', max=96.0, style=ProgressStyle(descri…


Training step => Loss: 2.2923 | Dist: 9.7006
Evaluation step => Loss: 2.1540 | Dist 9.1290
Epoch 9/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [10] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 2.2801 | Dist: 9.6784
Evaluation step => Loss: 2.1481 | Dist 9.1097
Epoch 10/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [11] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 2.2620 | Dist: 9.6468
Evaluation step => Loss: 2.1357 | Dist 9.0871
Epoch 11/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [12] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 2.2215 | Dist: 9.5889
Evaluation step => Loss: 2.1027 | Dist 9.0484
Epoch 12/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [13] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 2.1688 | Dist: 9.5135
Evaluation step => Loss: 2.0664 | Dist 8.9887
Epoch 13/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [14] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 2.0969 | Dist: 9.4008
Evaluation step => Loss: 2.0474 | Dist 8.9435
Epoch 14/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [15] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 2.0200 | Dist: 9.2844
Evaluation step => Loss: 1.9892 | Dist 8.8710
Epoch 15/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [16] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 1.9333 | Dist: 9.1565
Evaluation step => Loss: 1.9493 | Dist 8.7645
Epoch 16/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [17] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 1.8555 | Dist: 9.0213
Evaluation step => Loss: 1.9024 | Dist 8.7065
Epoch 17/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [18] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 1.7763 | Dist: 8.8635
Evaluation step => Loss: 1.8843 | Dist 8.6984
Epoch 18/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [19] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 1.6979 | Dist: 8.7240
Evaluation step => Loss: 1.8390 | Dist 8.5177
Epoch 19/20
----------


HBox(children=(FloatProgress(value=0.0, description='Epoch [20] progress', max=96.0, style=ProgressStyle(descr…


Training step => Loss: 1.6224 | Dist: 8.5813
Evaluation step => Loss: 1.8168 | Dist 8.4952
Training complete in 3m 3s
Best (so far) validation Acc: 0.000000
----------
### Final results ###

Best validation Acc: 0.000000
