In [1]:
import torch
import torch.nn
import torch.nn.functional as F

In [2]:
from model.encoder import Encoder
from model.decoder import Decoder

In [3]:
# Evaluation device. GPU index 2 is preferred, but a fixed index raises at
# `.to(device)` on hosts with fewer GPUs, so fall back to GPU 0, and to the
# CPU when CUDA is unavailable.
PREFERRED_GPU = 2
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if torch.cuda.device_count() > PREFERRED_GPU else 0
    device = f'cuda:{gpu_index}'
else:
    device = 'cpu'

In [4]:
import os
%matplotlib inline
import matplotlib.pyplot as plt
import skimage
from matplotlib import cm
import numpy as np
from dataset import get_dataset, collate_fn, vocab_size, int2char, char2int, SOS_CHAR, EOS_CHAR
from torch.utils.data import DataLoader
from torchvision import transforms
from utils import ScaleImageByHeight, PaddingWidth, AverageMeter, accuracy

In [5]:
# Settings shared by the cells below. Only some of them are read in this
# notebook; the learning-rate entries are not used by any evaluation cell here.
config = dict(
    batch_size=32,             # batch size of the test DataLoader
    hidden_size=256,           # passed to Decoder
    attn_size=256,             # passed to Decoder
    max_length=10,             # exposed as MAX_LENGTH
    n_epochs_decrease_lr=15,
    start_learning_rate=1e-5,  # NOTE: the paper starts at 1e-8
    end_learning_rate=1e-11,
    depth=4,                   # depth, n_blocks and growth_rate are passed to Encoder
    n_blocks=3,
    growth_rate=96,
)

In [6]:
# Directory that holds the saved checkpoint files.
CKPT_DIR = './ckpt'
# The config's `max_length` entry, exposed as a module-level constant.
MAX_LENGTH = config['max_length']

In [7]:
# Load the saved checkpoint onto the CPU; the 'encoder' and 'decoder' entries
# are read from it further down.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
weights_path = os.path.join(CKPT_DIR, 'BEST_weights.pt')
info = torch.load(weights_path, map_location='cpu')

In [8]:
# Preprocessing applied to every test image, in this order.
preprocessing_steps = [
    transforms.Grayscale(3),   # grayscale, replicated to 3 output channels
    ScaleImageByHeight(128),   # project helper; presumably rescales to height 128 — TODO confirm
    transforms.ToTensor(),
]
image_transform = transforms.Compose(preprocessing_steps)

In [9]:
# Test split with the preprocessing above, served in fixed order.
test_data = get_dataset('test', image_transform)
loader_options = dict(
    batch_size=config['batch_size'],
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=2,
)
test_loader = DataLoader(test_data, **loader_options)

In [10]:
# Rebuild the encoder from the config and restore its saved weights.
encoder_args = [config[key] for key in ('depth', 'n_blocks', 'growth_rate')]
encoder = Encoder(*encoder_args)
encoder.load_state_dict(info['encoder'])

<All keys matched successfully>

In [11]:
# Rebuild the decoder on top of the encoder's feature size and restore its
# saved weights.
decoder = Decoder(
    encoder.n_features,
    config['hidden_size'],
    vocab_size,
    config['attn_size'],
)
decoder.load_state_dict(info['decoder'])

<All keys matched successfully>

In [None]:
# Move both networks to the evaluation device.
encoder, decoder = encoder.to(device), decoder.to(device)

In [None]:
# Call .eval() on both networks before running inference.
for network in (encoder, decoder):
    network.eval()

In [None]:
# Manual iterator over the test loader, so single batches can be pulled and
# inspected in the following cells.
test_iter = iter(test_loader)

In [None]:
# Pull the next batch from the test iterator; re-running this cell advances to
# the following batch.
imgs, targets, targets_onehot, lengths = next(test_iter)

In [None]:
# Shape of the image batch.
imgs.shape

In [None]:
# Show the first image of the batch; the axes are permuted so the leading
# axis of the image tensor comes last, as imshow expects.
first_image = imgs[0].squeeze().permute(1, 2, 0)
plt.imshow(first_image)

In [None]:
# Shape of the target batch.
targets.shape

In [None]:
# Ground-truth text of the first sample: its character ids mapped back to
# characters and concatenated.
first_target_ids = targets[:, 0].squeeze().tolist()
''.join(int2char[char_id] for char_id in first_target_ids)

In [None]:
# First decoder input: a one-hot SOS token for EVERY sample in the batch,
# shaped [1, B, vocab_size]. Indexing with `0, 0` would mark only sample 0 and
# leave all other samples with an all-zero start token. B is taken from the
# batch that was actually loaded, which may be smaller than config['batch_size'].
start_input = torch.zeros(1, imgs.size(0), vocab_size)
start_input[0, :, char2int[SOS_CHAR]] = 1
start_input = start_input.to(device)

In [None]:
# Encode the sampled batch and decode it greedily, with gradient tracking
# disabled. `weights` is what the attention visualisation below reads.
with torch.no_grad():
    imgs = imgs.to(device)
    img_features = encoder(imgs)
    outputs, weights = decoder.greedy(img_features, start_input)

In [None]:
# Shape of the decoder outputs.
outputs.shape

In [None]:
# Raw values returned by decoder.greedy for the sampled batch.
outputs

In [None]:
# Keep the position of the single largest entry along the last dimension.
_, index = outputs.topk(k=1, dim=-1)

In [None]:
# Shape of the top-1 index tensor.
index.shape

In [None]:
# Turn the top-1 indices into one character list per sample, cut right after
# the first EOS. squeeze(-1) removes only the trailing top-k axis; a bare
# squeeze() would also drop a batch or time axis that happens to have size 1.
predicts = index.squeeze(-1).transpose(0, 1)  # [B, T]
predicts_str = []
for predict in predicts:
    chars = [int2char[x.item()] for x in predict]
    # Keep the EOS marker itself; keep the whole sequence if no EOS was produced.
    cut = chars.index(EOS_CHAR) + 1 if EOS_CHAR in chars else len(chars)
    predicts_str.append(chars[:cut])

predicts_str

In [None]:
# Shape of the attention weights returned by decoder.greedy.
weights.shape

In [None]:
# Height and width of the images in the batch (axes 2 and 3 of `imgs`).
img_rows = imgs.shape[2]
img_cols = imgs.shape[3]
print(img_rows, img_cols)

In [None]:
# Number of rows in `predicts`, i.e. the number of samples in the batch
# (predicts is [B, T]) — not a sequence length, despite the name.
length = len(predicts)
length

# Visualize a sample

In [None]:
# Plot the attention weights of one sample, one subplot per predicted character.
sample_index = 1
sample_image = imgs[sample_index]
sample_predict = predicts_str[sample_index]
sample_weigth = weights[:, [sample_index]]  # not used below; kept for ad-hoc inspection
n_samples = imgs.size(0)  # actual batch size, which may differ from config['batch_size']

# squeeze=False always yields a 2-D array of axes, so a one-character
# prediction does not break `.ravel()` below.
fig, axeses = plt.subplots(len(sample_predict), figsize=(15, 15),
                           sharex=True, sharey=True, squeeze=False)

# First channel of the sample image as a 2-D array; identical for every step.
img = sample_image.squeeze().permute(1, 2, 0).cpu().numpy()[:, :, 0]
alpha = 0.5  # share of the image in the blend; the rest is the attention map

for i, axes in enumerate(axeses.ravel()):
    # 16 is the factor by which DenseNet reduces the original image size.
    weight = weights[i].reshape(-1, n_samples, img_rows // 16, img_cols // 16)
    weight_numpy = weight.cpu().numpy()[:, sample_index, :].squeeze()
    # Upsample the attention map back to the image resolution.
    weight_image = skimage.transform.resize(weight_numpy, (img_rows, img_cols))

    blend = img * alpha + weight_image * (1 - alpha)

    axes.set_title(sample_predict[i])
    axes.imshow(blend, cmap='spring')
plt.show()

# Compute CER and WER on the test set

## CER and WER

In [None]:
import editdistance as ed

In [None]:
# Progress/result log for the evaluation below. 'w+' truncates any existing
# file. The handle stays open across cells and is closed in the last cell.
log_test = open('./log_test.txt', 'w+')

In [None]:
# Accumulators for the test-set metrics.
total_characters = 0
total_words = 0
CE = 0  # summed edit distance between predicted and target character lists
WE = 0  # number of samples whose prediction differs from the target
log_interval = 10  # currently unused: progress is logged for every batch


def _decode_until_eos(rows):
    """Turn rows of character ids into character lists cut after the first EOS.

    Args:
        rows: 2-D integer tensor holding one sequence of character ids per row.

    Returns:
        A list with one list of characters per row. A list ends with EOS_CHAR
        when the row contains it; otherwise the whole row is kept.
    """
    decoded = []
    for row in rows.tolist():
        chars = [int2char[char_id] for char_id in row]
        cut = chars.index(EOS_CHAR) + 1 if EOS_CHAR in chars else len(chars)
        decoded.append(chars[:cut])
    return decoded


n_batches = len(test_loader)
with torch.no_grad():
    for batch_idx, (imgs, targets, targets_onehot, lengths) in enumerate(test_loader):
        print(f'[{batch_idx}]/[{n_batches}]', file=log_test)
        log_test.flush()
        batch_size = imgs.size(0)

        # One-hot SOS token for every sample in the batch: [1, B, vocab_size].
        # Indexing with `0, 0` would give only the first sample a start token.
        start_input = torch.zeros(1, batch_size, vocab_size)
        start_input[0, :, char2int[SOS_CHAR]] = 1
        start_input = start_input.to(device)
        imgs = imgs.to(device)

        img_features = encoder(imgs)
        outputs, weights = decoder.greedy(img_features, start_input)

        _, index = outputs.topk(1, -1)
        # squeeze(-1) drops only the top-k axis, so a final batch with a single
        # sample keeps its batch axis (a bare squeeze() would remove it).
        predicts = index.squeeze(-1).transpose(0, 1)  # [B, T]
        predicts_str = _decode_until_eos(predicts)

        # assumes targets is [T, B] or [T, B, 1] — TODO confirm against collate_fn
        targets_bt = targets.transpose(0, 1).reshape(batch_size, -1)  # [B, T]
        targets_str = _decode_until_eos(targets_bt)

        assert len(predicts_str) == len(targets_str)
        for predicted, expected in zip(predicts_str, targets_str):
            CE += ed.distance(predicted, expected)
            # A word error is a prediction that does NOT match its target.
            if predicted != expected:
                WE += 1
        # NOTE(review): confirm that `lengths` counts the same symbols as
        # targets_str (which keeps the EOS marker), otherwise CER is skewed.
        total_characters += lengths.sum().item()
        total_words += len(predicts_str)

In [None]:
# Normalise the accumulated counts and append both rates to the test log.
CER = CE / total_characters
WER = WE / total_words
for metric_name, metric_value in (('CER', CER), ('WER', WER)):
    print(metric_name, metric_value, file=log_test)
log_test.flush()

In [None]:
# Close the test log file opened before the evaluation loop.
log_test.close()