In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms, models

from dataset import get_data_loader, VNOnDB, RIMES
from utils import ScaleImageByHeight, StringTransform

from PIL import ImageOps

# Reproducible
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
device = f'cuda' if torch.cuda.is_available() else 'cpu'


class CTCModel(nn.Module):
    def __init__(self, vocab):
        super().__init__()

        resnet = models.resnet18(pretrained=True)
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d((1, None))
        encoder_layer = nn.TransformerEncoderLayer(d_model=resnet.fc.in_features, nhead=8)
        self.encoder = nn.TransformerEncoder(encoder_layer, 1)
        self.character_distribution = nn.Linear(resnet.fc.in_features, vocab.size)

    def forward(self, images) -> torch.Tensor:
        '''
        Shapes:
        -------
            images: (N,C,H,W)
        '''
        images = self.cnn(images) # [B,C,H',W']
        images = self.pool(images) # [B,C,1,W']
        images.squeeze_(-2) # [B,C,W']
        images = images.permute(2,0,1) # [S=W',B,C]
        images = self.encoder(images) # [S,B,C]
        images = images.transpose(0,1) # [B,S,C]
        images = self.character_distribution(images) # [B,S,V]
        return images

In [13]:
weight_path = 'runs/06-04-2020_11-43-53_ctc/weights/weights_epoch=6_loss=0.412.pt'

checkpoint = torch.load(weight_path, map_location=device)
root_config = checkpoint['config']

config = root_config['common']

image_transform = transforms.Compose([
    ImageOps.invert,
    ScaleImageByHeight(config['scale_height']),
    transforms.Grayscale(3),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
])

val_loader = get_data_loader(config['dataset'],
                                'test',
                                config['batch_size'],
                                1,
                                image_transform,
                                False,
                                flatten_type=config.get('flatten_type', None))

if config['dataset'] in ['vnondb', 'vnondb_line']:
    vocab = VNOnDB.vocab
elif config['dataset'] == 'rimes':
    vocab = RIMES.vocab
model = CTCModel(vocab)
model.to(device)
model.load_state_dict(checkpoint['model'])
model.eval()


CTCModel(
  (cnn): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  

In [15]:
iter_loader = iter(val_loader)

In [16]:
batch = next(iter_loader)

In [49]:
def step_val(batch):
    with torch.no_grad():
        imgs, targets = batch.images.to(device), batch.labels.to(device)
        targets = targets + 1 # Leave index 0 for '<blank>'

        outputs = model(imgs) # [B,S,V]
        outputs = F.log_softmax(outputs, -1) # [B,S,V]
        outputs_lengths = torch.tensor(outputs.size(1)).expand(outputs.size(0))
        return outputs.argmax(-1)

outputs = step_val(batch)

In [50]:
print(batch.labels[:, 1:-1])

tensor([[ 3,  7,  7, 11,  4,  9,  3, 12,  2,  8, 32],
        [14, 32, 13,  3, 12,  2,  8,  5,  6,  0,  0],
        [61, 61, 42, 38,  6,  0,  0,  0,  0,  0,  0],
        [73, 64,  6,  0,  0,  0,  0,  0,  0,  0,  0],
        [42, 61,  6,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 8,  5,  6,  0,  0,  0,  0,  0,  0,  0,  0],
        [55, 39,  6,  0,  0,  0,  0,  0,  0,  0,  0],
        [11,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0]])


In [66]:
def int2char_label(labels):
    results = []
    for label in labels:
        label = list(map(RIMES.vocab.int2char, label))
        results.append(label[:label.index(RIMES.vocab.EOS)])
    return results
labels = int2char_label(batch.labels[:,1:].tolist())

In [62]:
def int2char_output(outputs):
    results = []
    for output in outputs:
        output = [RIMES.vocab.int2char(i - 1) if i != 0 else '<b>' for i in output]
        results.append(output)
    return results
predicts = int2char_output(outputs.tolist())

In [69]:
for pred, tgt in zip(predicts, labels):
    print('-'*10)
    print(pred)
    print(tgt)

----------
['c', 'o', 'n', 'm', 'm', 'n', 'a', 't', 'i', 'c', 'a', 'l', 'l']
['i', 'm', 'm', 'a', 't', 'r', 'i', 'c', 'u', 'l', 'é']
----------
['v', 'é', 'h', 'i', 'c', 'u', 'u', 'l', 'e', 'e', '<b>', '<b>', '<b>']
['v', 'é', 'h', 'i', 'c', 'u', 'l', 'e']
----------
['7', '<b>', '7', '<b>', '5', '<b>', '2', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>']
['7', '7', '5', '2']
----------
['u', '<b>', 't', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>']
['W', 'Y']
----------
['5', '<b>', '<b>', '7', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>']
['5', '7']
----------
['l', 'e', 'e', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>']
['l', 'e']
----------
['1', '<b>', '<b>', '0', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>']
['1', '0']
----------
['a', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>', '<b>']
['a']
