In [1]:
import torch
import torch.nn
import torch.nn.functional as F

In [2]:
from model.feature_extractor import DenseNetFE
from model.transformer import Transformer

In [3]:
# GPU used for evaluation. Fall back to CPU when CUDA is unavailable OR this machine
# has no GPU with that index (otherwise torch.load(map_location=...) fails later
# with an invalid-device-ordinal error).
GPU_INDEX = 3
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    device = f'cuda:{GPU_INDEX}'
else:
    device = 'cpu'

In [4]:
# import os
# %matplotlib inline
# import matplotlib.pyplot as plt
# import skimage
# from matplotlib import cm
# import numpy as np
from data import get_data_loader, get_vocab, SOS_CHAR, EOS_CHAR, PAD_CHAR
from torchvision import transforms
from utils import ScaleImageByHeight

from inference_tf import inference

In [5]:
# Checkpoint to evaluate: the BEST weights written by this training run.
CKPT = ('./runs/09-02-2020_00-10-39_tf_encoder_8_1_decoder_10_1_both_8'
        '/weights/BEST_weights.pt')

In [6]:
# Load the checkpoint, rebuild the model from the config stored inside it, and run
# `inference` on the first batch of the test loader.
print(f'Device = {device}')
print(f'Load weight from {CKPT}')
checkpoint = torch.load(CKPT, map_location=device)
config = checkpoint['config']

# Architecture hyper-parameters come from the checkpoint, not from this notebook.
cnn = DenseNetFE(config['depth'], config['n_blocks'], config['growth_rate'])
vocab = get_vocab(config['dataset'])
transformer_args = (
    cnn,
    vocab.vocab_size,
    config['attn_size'],
    config['encoder_nhead'], config['decoder_nhead'], config['both_nhead'],
    config['encoder_nlayers'], config['decoder_nlayers'],
)
model = Transformer(*transformer_args)
model.to(device)
model.load_state_dict(checkpoint['model'])

# Test-time preprocessing: 3-channel grayscale -> ScaleImageByHeight(config['scale_height']) -> tensor.
test_transform = transforms.Compose([
    transforms.Grayscale(3),
    ScaleImageByHeight(config['scale_height']),
    transforms.ToTensor(),
])
test_loader = get_data_loader(
    config['dataset'], 'test', config['batch_size'], test_transform, vocab,
)

model.eval()
first_batch_iter = iter(test_loader)
batch = next(first_batch_iter)
with torch.no_grad():
    inference(model, batch, vocab, device)

Device = cuda:3
Load weight from ./runs/09-02-2020_00-10-39_tf_encoder_8_1_decoder_10_1_both_8/weights/BEST_weights.pt
> /home/aioz-interns-1/working-dir/loi/model/transformer.py(87)greedy()
-> for t in range(max_length):


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(88)greedy()
-> output, weight = self.decoder.forward(predicts, image_features)


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(89)greedy()
-> output = self.character_distribution(output)


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(90)greedy()
-> output = F.softmax(output, -1)


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(91)greedy()
-> index = output.topk(1, -1)[1]


(Pdb)  output.shape


torch.Size([1, 32, 150])


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(92)greedy()
-> output = torch.zeros_like(output)


(Pdb)  index.shape


torch.Size([1, 32, 1])


(Pdb)  index


tensor([[[ 57],
         [ 25],
         [ 32],
         [ 51],
         [ 57],
         [ 32],
         [ 24],
         [ 32],
         [ 49],
         [ 40],
         [ 66],
         [ 44],
         [ 51],
         [ 49],
         [ 26],
         [ 32],
         [ 44],
         [ 80],
         [ 34],
         [ 32],
         [ 51],
         [ 14],
         [ 32],
         [ 56],
         [ 24],
         [ 16],
         [ 16],
         [106],
         [ 60],
         [ 49],
         [136],
         [  1]]], device='cuda:3')


(Pdb)  index.squeeze()


tensor([ 57,  25,  32,  51,  57,  32,  24,  32,  49,  40,  66,  44,  51,  49,
         26,  32,  44,  80,  34,  32,  51,  14,  32,  56,  24,  16,  16, 106,
         60,  49, 136,   1], device='cuda:3')


(Pdb)  index.squeeze(-1)


tensor([[ 57,  25,  32,  51,  57,  32,  24,  32,  49,  40,  66,  44,  51,  49,
          26,  32,  44,  80,  34,  32,  51,  14,  32,  56,  24,  16,  16, 106,
          60,  49, 136,   1]], device='cuda:3')


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(93)greedy()
-> output.scatter_(-1, index, 1)


(Pdb)  output


tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:3')


(Pdb)  output.max()


tensor(0., device='cuda:3')


(Pdb)  output.min()


tensor(0., device='cuda:3')


(Pdb)  output.shape


torch.Size([1, 32, 150])


(Pdb)  index.shape


torch.Size([1, 32, 1])


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(94)greedy()
-> predicts = torch.cat([predicts, output], dim=0)


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(95)greedy()
-> weights.append(weight)


(Pdb)  predicts


tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.]]], device='cuda:3')


(Pdb)  predicts.transpose(0,1).topk(1, -1)[1]


tensor([[[ 12],
         [ 57]],

        [[ 12],
         [ 25]],

        [[ 12],
         [ 32]],

        [[ 12],
         [ 51]],

        [[ 12],
         [ 57]],

        [[ 12],
         [ 32]],

        [[ 12],
         [ 24]],

        [[ 12],
         [ 32]],

        [[ 12],
         [ 49]],

        [[ 12],
         [ 40]],

        [[ 12],
         [ 66]],

        [[ 12],
         [ 44]],

        [[ 12],
         [ 51]],

        [[ 12],
         [ 49]],

        [[ 12],
         [ 26]],

        [[ 12],
         [ 32]],

        [[ 12],
         [ 44]],

        [[ 12],
         [ 80]],

        [[ 12],
         [ 34]],

        [[ 12],
         [ 32]],

        [[ 12],
         [ 51]],

        [[ 12],
         [ 14]],

        [[ 12],
         [ 32]],

        [[ 12],
         [ 56]],

        [[ 12],
         [ 24]],

        [[ 12],
         [ 16]],

        [[ 12],
         [ 16]],

        [[ 12],
         [106]],

        [[ 12],
         [ 60]],

        [[ 12]

(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(87)greedy()
-> for t in range(max_length):


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(88)greedy()
-> output, weight = self.decoder.forward(predicts, image_features)


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(89)greedy()
-> output = self.character_distribution(output)


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(90)greedy()
-> output = F.softmax(output, -1)


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(91)greedy()
-> index = output.topk(1, -1)[1]


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(92)greedy()
-> output = torch.zeros_like(output)


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(93)greedy()
-> output.scatter_(-1, index, 1)


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(94)greedy()
-> predicts = torch.cat([predicts, output], dim=0)


(Pdb)  n


> /home/aioz-interns-1/working-dir/loi/model/transformer.py(95)greedy()
-> weights.append(weight)


(Pdb)  predicts.transpose(0,1).topk(1,-1)[1]


tensor([[[ 12],
         [ 57],
         [ 57]],

        [[ 12],
         [ 25],
         [ 25]],

        [[ 12],
         [ 32],
         [ 32]],

        [[ 12],
         [ 51],
         [ 51]],

        [[ 12],
         [ 57],
         [ 57]],

        [[ 12],
         [ 32],
         [ 32]],

        [[ 12],
         [ 24],
         [ 24]],

        [[ 12],
         [ 32],
         [ 32]],

        [[ 12],
         [ 49],
         [ 49]],

        [[ 12],
         [ 40],
         [ 40]],

        [[ 12],
         [ 66],
         [ 66]],

        [[ 12],
         [ 44],
         [ 44]],

        [[ 12],
         [ 51],
         [ 51]],

        [[ 12],
         [ 49],
         [ 49]],

        [[ 12],
         [ 26],
         [ 26]],

        [[ 12],
         [ 32],
         [ 32]],

        [[ 12],
         [ 44],
         [ 44]],

        [[ 12],
         [ 80],
         [ 80]],

        [[ 12],
         [ 34],
         [ 34]],

        [[ 12],
         [ 32],
         [ 32]],



(Pdb)  q


BdbQuit: 

# Visualize a sample

In [None]:
# Overlay the attention map of each predicted character on the input image of one sample.
# NOTE(review): relies on `imgs`, `predicts_str` and `weights` from an earlier run
# (`imgs`/`predicts_str` are set by the CER loop further down; `weights` is not created
# by any top-level cell visible here). Assumes `weights` was produced for this same
# `imgs` batch and that weights[i] is the attention for the i-th predicted character.
import matplotlib.pyplot as plt
import skimage.transform

DENSENET_DOWNSAMPLE = 16  # factor by which DenseNet reduces the original image size
sample_index = 1
sample_image = imgs[sample_index]              # (C, H, W)
sample_predict = predicts_str[sample_index]
batch_size = imgs.size(0)                      # last batch may be smaller than config['batch_size']
img_rows, img_cols = sample_image.shape[-2:]   # must match `img` below for the blend to broadcast

# squeeze=False keeps a 2-D array of Axes even for a single-character prediction.
fig, axeses = plt.subplots(len(sample_predict), figsize=(15, 15),
                           sharex=True, sharey=True, squeeze=False)

# Background image: first channel of the (H, W, C) view.
img = sample_image.permute(1, 2, 0).cpu().numpy()[:, :, 0]
alpha = 0.5

for i, axes in enumerate(axeses.ravel()):
    weight = weights[i].reshape(-1, batch_size,
                                img_rows // DENSENET_DOWNSAMPLE,
                                img_cols // DENSENET_DOWNSAMPLE)
    weight_numpy = weight.cpu().numpy()[:, sample_index, :].squeeze()
    weight_image = skimage.transform.resize(weight_numpy, (img_rows, img_cols))

    blend = img * alpha + weight_image * (1 - alpha)

    axes.set_title(sample_predict[i])
    axes.imshow(blend, cmap='spring')
plt.show()

# Compute CER and WER on the test set

## CER and WER

In [29]:
import editdistance as ed

In [30]:
# Log file for evaluation progress and the final CER/WER; closed in a later cell.
log_test = open(file='./log_test.txt', mode='w+')

In [31]:
# Accumulate character errors (edit distance) and word errors (inexact predictions)
# over the whole test set.
total_characters = 0
total_words = 0
CE = 0  # summed character-level edit distance
WE = 0  # number of samples whose prediction differs from the target
log_interval = 10


def _truncate_at_eos(indices):
    """Map a 1-D sequence of token-index tensors to characters, keeping everything up to
    and including the first EOS_CHAR (the whole sequence if no EOS_CHAR is present)."""
    chars = [int2char[x.item()] for x in indices]
    try:
        end = chars.index(EOS_CHAR) + 1
    except ValueError:
        end = len(chars)
    return chars[:end]


# NOTE(review): `encoder`, `decoder`, `char2int` and `int2char` are not defined by any
# cell in this notebook (the model cell builds `model` and `vocab`); they must come from
# elsewhere before this cell runs.
t = test_loader
with torch.no_grad():
    for i, (imgs, targets, targets_onehot, lengths) in enumerate(t):
        print(f'[{i}]/[{len(t)}]', file=log_test)
        log_test.flush()
        batch_size = imgs.size(0)

        imgs = imgs.to(device)
        img_features = encoder(imgs)
        # Drop the first (start) step from targets and lengths.
        targets_onehot = targets_onehot[1:].to(device)
        targets = targets[1:].to(device)
        lengths = lengths - 1
        outputs = decoder.forward(img_features, targets_onehot, targets, lengths, char2int[PAD_CHAR])

        _, index = outputs.topk(1, -1)
        # Squeeze only the top-k axis: a bare squeeze() would also drop the batch (or time)
        # axis when it has size 1, and the transpose below would then fail.
        predicts = index.squeeze(-1).transpose(0, 1)  # [B, T]
        predicts_str = [_truncate_at_eos(predict) for predict in predicts]

        targets_bt = targets.transpose(0, 1)
        if targets_bt.dim() == 3:  # trailing singleton axis, if present
            targets_bt = targets_bt.squeeze(-1)
        targets_str = [_truncate_at_eos(target) for target in targets_bt]

        assert len(predicts_str) == len(targets_str)
        for j, (pred, target) in enumerate(zip(predicts_str, targets_str)):
            CE += ed.distance(pred, target)
            if pred != target:
                WE += 1
                print(f'Batch {i} - sample {j}: "{pred}"/"{target}"')
        total_characters += lengths.sum().item()
        total_words += len(predicts_str)

Batch 0 - sample 25: "['u', 'n', '<end>']"/"['O', 'n', '<end>']"
Batch 0 - sample 26: "['u', 'n', '<end>']"/"['O', 'n', '<end>']"
Batch 2 - sample 31: "['u', 'n', '<end>']"/"['O', 'n', '<end>']"
Batch 8 - sample 31: "['u', 'n', '<end>']"/"['O', 'n', '<end>']"
Batch 9 - sample 13: "['s', 's', 'n', '<end>']"/"['s', 'ẵ', 'n', '<end>']"
Batch 15 - sample 27: "['u', 'n', '<end>']"/"['O', 'n', '<end>']"
Batch 16 - sample 26: "['u', 'n', '<end>']"/"['O', 'n', '<end>']"
Batch 17 - sample 30: "['u', 'n', '<end>']"/"['O', 'n', '<end>']"
Batch 24 - sample 30: "['u', 'n', '<end>']"/"['O', 'n', '<end>']"
Batch 28 - sample 21: "['s', 's', 'n', '<end>']"/"['s', 'ẵ', 'n', '<end>']"
Batch 29 - sample 25: "['u', 'n', '<end>']"/"['O', 'n', '<end>']"
Batch 31 - sample 27: "['M', 'M', '<end>']"/"['M', 'ỹ', '<end>']"
Batch 39 - sample 22: "['u', 'n', '<end>']"/"['O', 'n', '<end>']"
Batch 39 - sample 28: "['k', 'k', '<end>']"/"['k', 'ỳ', '<end>']"
Batch 39 - sample 31: "['y', 'n', '<end>']"/"['Â', 'n', '<end

In [32]:
# Final error rates; NaN instead of ZeroDivisionError when nothing was evaluated.
CER = CE / total_characters if total_characters else float('nan')
WER = WE / total_words if total_words else float('nan')
print('CER', CER, file=log_test)
print('WER', WER, file=log_test)
log_test.flush()

In [33]:
# Done logging: release the file handle (closing an already-closed file is a no-op anyway).
if not log_test.closed:
    log_test.close()

In [34]:
# Character errors, total characters, and the resulting CER.
print(f'{CE} {total_characters} {CER}')

164 108614 0.0015099342626180786


In [35]:
# Word errors, total words, and the resulting WER.
print(f'{WE} {total_words} {WER}')

156 25115 0.0062114274338044995
