# Load dataset from Google Drive

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pdb
import random
from tqdm.notebook import tqdm
from PIL import Image

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter
import torchvision
from torch.utils.data import Dataset, DataLoader

# Load Data

In [3]:
SOS_CHAR = '<start>' # start of sequence character
EOS_CHAR = '<end>' # end of sequence character
PAD_CHAR = '<pad>' # padding character

In [4]:
class VNOnDB(torch.utils.data.Dataset):
    def __init__(self, image_folder, csv, image_transform=None):
        self.df = pd.read_csv(csv, sep='\t', keep_default_na=False, index_col=0)
        self.image_folder = image_folder
        self.image_transform = image_transform

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        image_path = os.path.join(self.image_folder, self.df['id'][idx]+'.png')
        image = Image.open(image_path)
        
        if self.image_transform:
            image = self.image_transform(image)
        
        label = self.df['label'][idx]
        label = [SOS_CHAR] + list(label) + [EOS_CHAR]
            
        return image, label

In [5]:
class ScaleImageByHeight(object):
    def __init__(self, target_height):
        self.target_height = target_height

    def __call__(self, image):
        width, height = image.size
        factor = self.target_height / height
        new_width = int(width * factor)
        new_height = int(height * factor)
        image = image.resize((new_width, new_height))
        return image

In [6]:
image_transform = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(3),
    ScaleImageByHeight(32),
    torchvision.transforms.ToTensor(),
])

In [7]:
all_data_csv = './data/VNOnDB/all_word.csv'
train_data_csv = './data/VNOnDB/train_word.csv'
val_data_csv = './data/VNOnDB/validation_word.csv'
test_data_csv = './data/VNOnDB/test_word.csv'

train_image_folder = './data/VNOnDB/word_train'
val_image_folder = './data/VNOnDB/word_val'
test_image_folder = './data/VNOnDB/word_test'

train_data = VNOnDB(train_image_folder, train_data_csv, image_transform)
validation_data = VNOnDB(val_image_folder, val_data_csv, image_transform)

In [8]:
all_data_df = pd.read_csv(all_data_csv, sep='\t', keep_default_na=False, index_col=0)
alphabets = sorted(list(set.union(*all_data_df.label.apply(set))) + [SOS_CHAR, EOS_CHAR, PAD_CHAR])

char2int = dict((c, i) for i, c in enumerate(alphabets))
int2char = dict((i, c) for i, c in enumerate(alphabets))
vocab_size = len(alphabets)    

In [9]:
def collate_fn(samples):
    '''
    :param samples: list of tuples:
        - image: tensor of [C, H, W]
        - label: list of characters including '<start>' and '<end>' at both ends
    :returns:
        - images: tensor of [B, C, H, W]
        - labels: tensor of [max_T, B, 1]
        - lengths: tensor of [B, 1]
    '''
    batch_size = len(samples)
    samples.sort(key=lambda sample: len(sample[1]), reverse=True)
    image_samples, label_samples = list(zip(*samples))

    # images: [B, 3, H, W]
    max_image_row = max([image.size(1) for image in image_samples])
    max_image_col = max([image.size(2) for image in image_samples])
    images = torch.ones(batch_size, 3, max_image_row, max_image_col)
    for i, image in enumerate(image_samples):
        image_row = image.shape[1]
        image_col = image.shape[2]
        images[i, :, :image_row, :image_col] = image

    label_lengths = [len(label) for label in label_samples]
    max_length = max(label_lengths)
    label_samples = [label + [PAD_CHAR] * (max_length - len(label)) for label in label_samples]
    
    labels = torch.zeros(max(label_lengths), batch_size, 1, dtype=torch.long) # [max_T, B, 1]
    for i, label in enumerate(label_samples):
        label_int = torch.tensor([char2int[char] for char in label]).view(-1, 1) # [T, 1]
        labels[:, i] = label_int
        
    labels_onehot = torch.zeros(max(label_lengths), batch_size, vocab_size, dtype=torch.long) # [max_T, B, vocab_size]
    for label_i, label in enumerate(label_samples):
        for char_i, char in enumerate(label):
            char_int = char2int[char]
            onehot = torch.zeros(vocab_size, dtype=torch.long)
            onehot[char_int] = 1
            labels_onehot[char_i, label_i] = onehot

    return images, labels, labels_onehot, torch.tensor(label_lengths).view(-1, 1)

# Define model

In [10]:
class Encoder(nn.Module):
    def __init__(self, depth, n_blocks, growth_rate):
        super(Encoder, self).__init__()

        self.cnn = torchvision.models.DenseNet(
            growth_rate=growth_rate,
            block_config=[depth]*n_blocks
        ).features

        # TODO: fix me
        self.n_features = self.cnn.norm5.num_features
    
    def forward(self, inputs):
        '''
        :param inputs: [B, C, H, W]
        :returms: [num_pixels, B, C']
        '''
        batch_size = inputs.size(0)
        outputs = self.cnn(inputs) # [B, C', H', W']
        outputs = outputs.view(batch_size, self.n_features, -1) # [B, C', H' x W'] == [B, C', num_pixels]
        outputs = outputs.permute(2, 0, 1) # [num_pixels, B, C']
        return outputs

In [11]:
class Attention(nn.Module):
    def __init__(self, feature_size, hidden_size, attn_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.Wa = nn.Linear(feature_size, attn_size)
        self.Ua = nn.Linear(hidden_size, attn_size)
        self.va = nn.Linear(attn_size, 1)

    def forward(self, last_hidden, encoder_outputs):
        '''
        Input:
        :param last_hidden: [1, B, H]
        :param encoder_outputs: [num_pixels, B, C]
        Output:
        weights: [num_pixels, B, 1]
        '''
        attn1 = self.Wa(encoder_outputs) # [num_pixels, B, A]
        attn2 = self.Ua(last_hidden) # [1, B, A]
        attn = self.va(torch.tanh(attn1 + attn2)) # [num_pixels, B, 1]
        
        weights = F.softmax(attn.squeeze(2), 1).unsqueeze(2) # [num_pixels, B, 1]
        context = (weights * encoder_outputs).sum(0, keepdim=True) # [1, B, C]
        
        return context, weights

In [12]:
class Decoder(nn.Module):
    def __init__(self, feature_size, hidden_size, vocab_size, attn_size):
        super(Decoder, self).__init__()

        self.feature_size = feature_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.attn_size = attn_size

        self.rnn = nn.GRU(
            input_size=self.vocab_size+self.feature_size,
            hidden_size=self.hidden_size,
        )

        self.attention = Attention(
            self.feature_size,
            self.hidden_size,
            self.attn_size)

        self.character_distribution = nn.Linear(self.hidden_size, self.vocab_size)

    def init_hidden(self, batch_size):
        return torch.zeros(1, batch_size, self.hidden_size)

    def forward(self, img_features, targets, teacher_forcing_ratio=0.5):
        '''
        :param img_features: tensor of [num_pixels, B, C]
        :param targets: tensor of [T, B, V], each target has <start> and <end> at begin and end of the word
        :return:
            outputs: tensor of [T, B, V]
            weights: tensor of [T, B, num_pixels]
        '''

        num_pixels = img_features.size(0)
        batch_size = img_features.size(1)
        max_length = targets.size(0)

        targets = targets.float()
        rnn_input = targets[[0]].float() # [1, B, V]
        hidden = self.init_hidden(batch_size).to(img_features.device)

        outputs = torch.zeros(max_length, batch_size, self.vocab_size, device=img_features.device)
        weights = torch.zeros(max_length, batch_size, num_pixels, device=img_features.device) 

        # pdb.set_trace()
        for t in range(max_length - 1):
            context, weight = self.attention(hidden, img_features) # [1, B, C], [num_pixels, B, 1]

            teacher_force = random.random() < teacher_forcing_ratio
            if self.training and teacher_force:
                rnn_input = torch.cat((targets[[t]], context), -1)
            else:
                rnn_input = torch.cat((rnn_input, context), -1)

            output, hidden = self.rnn(rnn_input, hidden)
            output = self.character_distribution(output)

            outputs[[t]] = output
            weights[[t]] = weight.transpose(0, 2)
            
            rnn_input = output
            
        return outputs, weights

# Training

In [13]:
def accuracy(outputs, targets):
    batch_size = outputs.size(0)
    _, ind = outputs.topk(1, 1, True, True)
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total.item() / batch_size

In [14]:
class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


In [15]:
def train_one_epoch(epoch, train_loader, encoder, decoder, optimizer, criterion, writer, log_interval=100):
    global train_step
    encoder.train()
    decoder.train()
    
    losses = AverageMeter()
    accs = AverageMeter()
    
    for i, (imgs, targets, targets_onehot, lengths) in enumerate(train_loader):

        optimizer.zero_grad()

        imgs = imgs.to(device)
        targets = targets.to(device)
        targets_onehot = targets_onehot.to(device)

        img_features = encoder(imgs)
        outputs, weights = decoder(img_features, targets_onehot)

        packed_outputs = torch.nn.utils.rnn.pack_padded_sequence(outputs, lengths.squeeze())[0]
        packed_targets = torch.nn.utils.rnn.pack_padded_sequence(targets.squeeze(), lengths.squeeze())[0]
        
        loss = criterion(packed_outputs, packed_targets)
        acc = accuracy(packed_outputs, packed_targets)
        
        total_characters = lengths.sum().item()
        losses.update(loss, total_characters)
        accs.update(acc, total_characters)
        
        loss.backward()
        optimizer.step()

        train_step += 1
        writer.add_scalar('Train/Loss', loss.item(), train_step)
        writer.add_scalar('Train/Accuracy', acc, train_step)
        
        if (i+1) % log_interval == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Accuracy {accs.val:.3f} ({accs.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                    loss=losses,
                                                                    accs=accs))
    return losses.avg, accs.avg

In [16]:
def validate(epoch, val_loader, encoder, decoder, criterion, writer, log_interval=100):
    global val_step
    
    losses = AverageMeter()
    accs = AverageMeter()
    
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        for i, (imgs, targets, targets_onehot, lengths) in enumerate(val_loader):

            imgs = imgs.to(device)
            targets = targets.to(device)
            targets_onehot = targets_onehot.to(device)

            img_features = encoder(imgs)
            outputs, weights = decoder(img_features, targets_onehot)
            

            packed_outputs = torch.nn.utils.rnn.pack_padded_sequence(outputs, lengths.squeeze())[0]
            packed_targets = torch.nn.utils.rnn.pack_padded_sequence(targets.squeeze(), lengths.squeeze())[0]
            loss = criterion(packed_outputs, packed_targets)
            acc = accuracy(packed_outputs, packed_targets)

            total_characters = lengths.sum().item()
            losses.update(loss, total_characters)
            accs.update(acc, total_characters)

            val_step += 1
            writer.add_scalar('Validate/Loss', loss.item(), val_step)
            writer.add_scalar('Validate/Accuracy', acc, val_step)

            if (i+1) % log_interval == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Accuracy {accs.val:.3f} ({accs.avg:.3f})'.format(epoch, i, len(val_loader),
                                                                        loss=losses,
                                                                        accs=accs))
    return losses.avg, accs.avg

In [17]:
def save_checkpoint(epoch, train_step, val_step, encoder, decoder, optimizer, lr, best_val_acc, is_best=False):
    info = {
        'epoch': epoch,
        'train_step': train_step,
        'val_step': val_step,
        'encoder_state': encoder.state_dict(),
        'decoder_state': decoder.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'lr': lr,
        'best_val_acc': best_val_acc,
    }
    torch.save('./ckpt/weights.pt')
    if is_best:
        torch.save('./ckpt/best_weights.pt')

In [18]:
config = {
  'batch_size': 64,
  'hidden_size': 256,
  'attn_size': 256,
  'max_length': 10,
  'n_epochs_decrease_lr': 15,
  'start_learning_rate': 1e-5, # NOTE: paper start with 1e-8
  'end_learning_rate': 1e-11,
  'depth': 4,
  'n_blocks': 3,
  'growth_rate': 96,
}

In [19]:
train_loader = DataLoader(train_data, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_fn, num_workers=8)
val_loader = DataLoader(validation_data, batch_size=config['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=8)

In [20]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [21]:
encoder = Encoder(config['depth'], config['n_blocks'], config['growth_rate']).to(device)
decoder = Decoder(encoder.n_features, config['hidden_size'], vocab_size, config['attn_size']).to(device)

In [22]:
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(params, lr=config['start_learning_rate'])
criterion = nn.CrossEntropyLoss().to(device)

In [23]:
writer = SummaryWriter()
train_step = 0
val_step = 0

# Train from scratch

In [None]:
log_interval = 200
epoch = 0
best_val_acc = 0
count_decrease_lr = 0
lr = config['start_learning_rate']
while True:
    epoch += 1 
    train_loss, train_acc = train_one_epoch(epoch, train_loader, encoder, decoder, optimizer, criterion, writer, log_interval)
    val_loss, val_acc = validate(epoch, val_loader, encoder, decoder, criterion, writer, log_interval)
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        count_decrease_lr = 0
        save_checkpoint(epoch, train_step, val_step, encoder, decoder, optimizer, lr, best_val_acc, True)
    else:
        count_decrease_lr += 1
        if count_decrease_lr == config['n_epochs_decrease_lr']:
            lr = lr * 0.1
            print('Decrease learning rate to', lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            count_decrease_lr = 0

    save_checkpoint(epoch, train_step, val_step, encoder, decoder, optimizer, lr, best_val_acc, False)
    if lr <= config['end_learning_rate']:
        print('Training done')

Epoch: [1][199/1047]	Loss 4.2969 (4.7703)	Accuracy 0.181 (0.133)
Epoch: [1][399/1047]	Loss 3.5187 (4.3190)	Accuracy 0.370 (0.204)
Epoch: [1][599/1047]	Loss 3.1056 (3.9839)	Accuracy 0.372 (0.257)
Epoch: [1][799/1047]	Loss 2.8779 (3.7405)	Accuracy 0.367 (0.284)
