# Evaluate

In [1]:
import torch
from torch.utils.data import DataLoader
from torch.nn import CTCLoss
from tqdm import tqdm

from dataset import Synth90kDataset, synth90k_collate_fn
from model import CRNN
from ctc_decoder import ctc_decode
from config import evaluate_config as config

# Disable cuDNN globally for everything below.
# NOTE(review): presumably done so CTCLoss avoids the cuDNN CTC code path,
# which has stricter input requirements than the native one — confirm.
torch.backends.cudnn.enabled = False

In [2]:
def evaluate(crnn, dataloader, criterion,
             max_iter=None, decode_method='beam_search', beam_size=10):
    """Evaluate a CTC-trained CRNN over ``dataloader``.

    Args:
        crnn: the model; it is switched to eval mode before iterating.
        dataloader: yields ``(images, targets, target_lengths)`` tensor triples.
            The slicing below assumes ``targets`` is the flat concatenation of
            every sample's label sequence, with ``target_lengths`` giving each
            sample's length.
        criterion: CTC loss called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``.
            The reported loss is ``sum of batch losses / number of samples``,
            so this expects ``reduction='sum'`` (as ``main`` uses).
        max_iter: if truthy, stop after this many batches; ``None``/``0`` means
            run over the whole loader.
        decode_method: forwarded to ``ctc_decode`` as ``method``.
        beam_size: forwarded to ``ctc_decode``.

    Returns:
        dict with
        ``'loss'`` -- accumulated loss divided by the number of samples,
        ``'acc'`` -- fraction of samples whose decoded sequence equals the
        target exactly,
        ``'wrong_cases'`` -- list of ``(real, pred)`` pairs for mismatches.
        If no samples were evaluated (e.g. empty loader), ``loss`` and ``acc``
        are NaN instead of raising ``ZeroDivisionError``.
    """
    crnn.eval()
    # Use the parameters' actual device (keeps the index, e.g. 'cuda:1');
    # it does not change between batches, so look it up once.
    device = next(crnn.parameters()).device

    tot_count = 0
    tot_loss = 0
    tot_correct = 0
    wrong_cases = []

    pbar_total = max_iter if max_iter else len(dataloader)

    # tqdm as a context manager guarantees the bar is closed even if a batch raises.
    with torch.no_grad(), tqdm(total=pbar_total, desc="Evaluate") as pbar:
        for i, data in enumerate(dataloader):
            if max_iter and i >= max_iter:
                break

            images, targets, target_lengths = [d.to(device) for d in data]

            logits = crnn(images)
            log_probs = torch.nn.functional.log_softmax(logits, dim=2)

            batch_size = images.size(0)
            # Every sample is given the full output length, logits.size(0).
            input_lengths = torch.LongTensor([logits.size(0)] * batch_size)

            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            preds = ctc_decode(log_probs, method=decode_method, beam_size=beam_size)
            reals = targets.cpu().numpy().tolist()
            target_lengths = target_lengths.cpu().numpy().tolist()

            tot_count += batch_size
            tot_loss += loss.item()
            # Walk the flat label list with a running offset to recover each
            # sample's own target sequence.
            offset = 0
            for pred, target_length in zip(preds, target_lengths):
                real = reals[offset:offset + target_length]
                offset += target_length
                if pred == real:
                    tot_correct += 1
                else:
                    wrong_cases.append((real, pred))

            pbar.update(1)

    if tot_count == 0:
        # Nothing was evaluated; the metrics are undefined rather than zero.
        mean_loss = acc = float('nan')
    else:
        mean_loss = tot_loss / tot_count
        acc = tot_correct / tot_count

    evaluation = {
        'loss': mean_loss,
        'acc': acc,
        'wrong_cases': wrong_cases
    }
    return evaluation

In [3]:
def main(reload_checkpoint='../checkpoints/crnn_synth90k.pt', cpu_workers=None):
    """Evaluate a CRNN checkpoint on the Synth90k ``'test'`` split and print the result.

    Args:
        reload_checkpoint: path of the saved ``state_dict`` to load. Defaults to
            the path that was previously hard-coded here.
        cpu_workers: number of DataLoader worker processes. ``None`` (default)
            keeps the previous behaviour of reading ``config['cpu_workers']``;
            pass ``0`` to load batches in the main process.

    Returns:
        The dict produced by ``evaluate`` (keys ``loss``, ``acc``,
        ``wrong_cases``). Previously nothing was returned, so existing callers
        are unaffected.
    """
    eval_batch_size = config['eval_batch_size']
    if cpu_workers is None:
        cpu_workers = config['cpu_workers']

    img_height = config['img_height']
    img_width = config['img_width']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    test_dataset = Synth90kDataset(root_dir=config['data_dir'], mode='test',
                                   img_height=img_height, img_width=img_width)

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=cpu_workers,
        collate_fn=synth90k_collate_fn)

    # One class beyond the label alphabet (presumably the CTC blank).
    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(1, img_height, img_width, num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    # reduction='sum' because `evaluate` divides the accumulated loss by the
    # number of samples to report a per-sample loss.
    criterion = CTCLoss(reduction='sum')
    criterion.to(device)

    evaluation = evaluate(crnn, test_loader, criterion,
                          decode_method=config['decode_method'],
                          beam_size=config['beam_size'])
    print('test_evaluation: loss={loss}, acc={acc}'.format(**evaluation))
    return evaluation


In [4]:
# main()

In [5]:
# Build the test loader, model and loss directly in the notebook namespace
# (same steps as `main`) so each object can be inspected in later cells.
eval_batch_size = config['eval_batch_size']
# Deliberately overrides config['cpu_workers']: load batches in the main process.
# NOTE(review): presumably because DataLoader worker processes misbehave in
# this notebook environment -- confirm.
cpu_workers = 0
reload_checkpoint = '../checkpoints/crnn_synth90k.pt'

img_height = config['img_height']
img_width = config['img_width']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

test_dataset = Synth90kDataset(root_dir=config['data_dir'], mode='test',
                               img_height=img_height, img_width=img_width)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=eval_batch_size,
    shuffle=False,
    num_workers=cpu_workers,
    collate_fn=synth90k_collate_fn)

# One class beyond the label alphabet (presumably the CTC blank).
num_class = len(Synth90kDataset.LABEL2CHAR) + 1
crnn = CRNN(1, img_height, img_width, num_class,
            map_to_seq_hidden=config['map_to_seq_hidden'],
            rnn_hidden=config['rnn_hidden'],
            leaky_relu=config['leaky_relu'])
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
crnn.to(device)

criterion = CTCLoss(reduction='sum')
criterion.to(device)

device: cuda


<All keys matched successfully>

CRNN(
  (cnn): Sequential(
    (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu0): ReLU(inplace=True)
    (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU(inplace=True)
    (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3): ReLU(inplace=True)
    (pooling2): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
    (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu4): ReLU(inplace=True)
    (conv5): Conv2d(512, 512, 

CTCLoss()

In [6]:
## testing
# Take the first test batch for the step-by-step walkthrough below.
# Raises StopIteration on an empty loader instead of silently keeping a
# stale `data` from a previous run.
data = next(iter(test_loader))

In [9]:
# Decoding settings for the walkthrough; these match the defaults of `evaluate`.
decode_method, beam_size = 'beam_search', 10

In [11]:
# Run a single batch through the same steps as the loop body of `evaluate`,
# leaving every intermediate in the notebook namespace for inspection.
with torch.no_grad():
    # NOTE(review): rebinds `device` (a torch.device in the setup cell) to a
    # plain 'cuda'/'cpu' string derived from where the parameters live.
    device = 'cuda' if next(crnn.parameters()).is_cuda else 'cpu'

    images, targets, target_lengths = [d.to(device) for d in data]

    logits = crnn(images)
    # Log-probabilities over the class dimension (dim 2). The recorded shape of
    # this tensor is (24, 234, 37); presumably (time steps, batch, num_class)
    # -- confirm against CRNN.forward.
    log_probs = torch.nn.functional.log_softmax(logits, dim=2)

    batch_size = images.size(0)
    # Every sample is given the full output length, logits.size(0).
    input_lengths = torch.LongTensor([logits.size(0)] * batch_size)

    loss = criterion(log_probs, targets, input_lengths, target_lengths)

    preds = ctc_decode(log_probs, method=decode_method, beam_size=beam_size)
    # Plain Python lists, compared against `preds` in the accuracy cell.
    # Note: `target_lengths` is rebound from a tensor to a list here.
    reals = targets.cpu().numpy().tolist()
    target_lengths = target_lengths.cpu().numpy().tolist()

In [16]:
# Inspect the log-probability tensor's dimensions.
log_probs.size()

torch.Size([24, 234, 37])

In [20]:
# Fresh accumulators for the single-batch accuracy check.
tot_count = tot_loss = tot_correct = 0
wrong_cases = []

In [21]:
# Fold this batch into the running totals.
# NOTE: not idempotent -- re-running this cell without re-running the
# accumulator-initialisation cell counts the batch twice.
tot_loss += loss.item()
tot_count += batch_size

# `reals` is one flat list of labels; peel off `target_length` labels per
# sample using a running offset. `pred` and `real` keep the last sample's
# values for the inspection cells that follow.
target_length_counter = 0
for pred, target_length in zip(preds, target_lengths):
    real = reals[target_length_counter:target_length_counter + target_length]
    target_length_counter += target_length
    if pred != real:
        wrong_cases.append((real, pred))
    else:
        tot_correct += 1

In [26]:
# Last decoded sequence left over from the accuracy loop. A sample counts as
# correct only when every element of its sequence matches the target exactly.
pred

[32, 15, 15, 20, 11, 35, 29]

In [27]:
# Ground-truth label sequence for that same last sample.
real

[32, 15, 15, 20, 11, 35, 29]