In [3]:
import os
import numpy as np 
import matplotlib.pyplot as plt 
from PIL import Image, ImageOps
import pandas as pd
import csv

import torch
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
from itertools import groupby

from src.crnn_dataset import TRDataset, get_split
from src.crnn_model import CRNN
from src.crnn_decoder import ctc_decode
from src.crnn_train import train_batch
from src.crnn_evaluate import evaluate

from tqdm import tqdm

# Fix PyTorch's global RNG seed. As the last expression of the cell, the call
# also echoes the returned Generator object in the cell output.
torch.manual_seed(42)

<torch._C.Generator at 0x1a6859bdb10>

In [4]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [5]:
# Path handed to torch.load by the config cell below. Set it to a falsy value
# (e.g. None) to build the default configuration instead of reloading one.
reload_checkpoint = 'checkpoints/crnn_synth_100k_config.pt'

In [6]:
# Either start from the default settings or resume from a saved config dict.
if not reload_checkpoint:
    config = dict(
        state_dict=None,            # no weights to restore on a fresh run
        img_height=32,
        img_width=100,
        batch_size=64,
        root_dir="datasets/TR_100k",
        labels="labels.csv",
        splits=[0.98, 0.01, 0.01],  # forwarded to get_split as `splits`
        map_to_seq=64,
        rnn_hidden=256,
    )
else:
    # torch.load unpickles the file, so only point this at checkpoints you trust.
    config = torch.load(reload_checkpoint, map_location=device)

In [7]:
# Build the train / validation / test loaders from the dataset settings in `config`.
train_loader, val_loader, test_loader = get_split(
    root_dir=config['root_dir'],
    labels=config['labels'],
    img_width=config['img_width'],
    img_height=config['img_height'],
    batch_size=config['batch_size'],
    splits=config['splits'],
)

In [8]:
# One output class per entry in the label table, plus one extra class
# (presumably the CTC blank -- TODO confirm against CRNN / ctc_decode).
num_class = 1 + len(TRDataset.LABEL2CHAR)

In [9]:
# Build the recogniser from the configured image geometry and layer sizes.
# The leading 1 matches the single input channel of `conv0` in the printed
# model (presumably grayscale input -- confirm against CRNN's signature).
crnn = CRNN(
    1,
    config['img_height'],
    config['img_width'],
    num_class,
    map_to_seq=config['map_to_seq'],
    rnn_hidden=config['rnn_hidden'],
)

# Restore saved weights only when the configuration carries a non-empty state dict.
if config['state_dict']:
    crnn.load_state_dict(config['state_dict'])

# Last expression of the cell: moves the model to `device` and displays it.
crnn.to(device)

CRNN(
  (cnn): Sequential(
    (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu0): ReLU(inplace=True)
    (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU(inplace=True)
    (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3): ReLU(inplace=True)
    (pooling2): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
    (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu4): ReLU(inplace=True)
    (conv5): Conv2d(512, 512, 

In [None]:
# Persist the current `config` dict (including whatever 'state_dict' entry it
# holds at this point) so a later run can reload it via `reload_checkpoint`.
cp_path = os.path.join('checkpoints','crnn_synth_100k_config.pt')
# torch.save does not create missing parent directories; make sure it exists.
os.makedirs(os.path.dirname(cp_path), exist_ok=True)
torch.save(config,cp_path)

In [10]:
# CTC loss summed over the batch (the training loop divides by the number of
# samples itself); zero_infinity=True zeroes out infinite losses.
criterion = nn.CTCLoss(zero_infinity=True, reduction='sum')
# Adam over all model parameters with a fixed learning rate.
optimizer = torch.optim.Adam(params=crnn.parameters(), lr=5e-4)

In [None]:
def predict(crnn, dataloader, label2char):
    """Run the model over every batch of ``dataloader`` and decode the outputs.

    Args:
        crnn: Model to run. It is switched to eval mode, and each image batch
            is moved to the device of the model's first parameter.
        dataloader: Iterable of batches whose first element is the image
            tensor; any further elements (targets, lengths) are ignored.
        label2char: Mapping forwarded to ``ctc_decode``.

    Returns:
        list: The per-batch results of ``ctc_decode`` concatenated in
        iteration order.
    """
    crnn.eval()
    # The model does not move during inference, so resolve its device once
    # instead of on every batch.
    device = next(crnn.parameters()).device

    all_preds = []
    with torch.no_grad():
        # Let tqdm drive the loop so the bar's total, updates and closing are
        # handled for us.
        for data in tqdm(dataloader, desc="Predict"):
            # Only the images are needed for prediction.
            images = data[0].to(device)

            logits = crnn(images)
            log_probs = torch.nn.functional.log_softmax(logits, dim=2)

            all_preds.extend(ctc_decode(log_probs, label2char=label2char))

    return all_preds


def show_result(preds):
    """Print a banner, then each prediction joined into a single line of text.

    Args:
        preds: Iterable of predictions, each an iterable of strings.
    """
    print('\n===== result =====')
    for line in map(''.join, preds):
        print(line)

In [11]:
# Baseline validation metrics for the current weights, before any further training.
evaluation = evaluate(crnn, val_loader, criterion)
print('valid_evaluation: loss={loss}, acc={acc}'.format_map(evaluation))

Evaluate: 100%|██████████| 16/16 [00:02<00:00,  5.52it/s]

valid_evaluation: loss=2.2308524475097657, acc=0.779





In [None]:
# Show the first validation batch with "<ground truth> > <prediction>" titles.
label2char = TRDataset.LABEL2CHAR

crnn.eval()
with torch.no_grad():
    # Only one batch is visualised, so fetch it directly instead of looping and breaking.
    data = next(iter(val_loader))
    # Use a dedicated name so the notebook-wide `device` variable is not overwritten.
    model_device = next(crnn.parameters()).device

    images, targets, target_lengths = [d.to(model_device) for d in data]

    logits = crnn(images)
    log_probs = torch.nn.functional.log_softmax(logits, dim=2)
    preds = ctc_decode(log_probs, label2char=label2char)

# The targets of the whole batch form one flat label sequence; cut it back into
# per-sample pieces using the recorded target lengths.
flat_targets = targets.cpu().numpy().tolist()
lengths = target_lengths.cpu().numpy().tolist()

all_preds = []
all_reals = []
offset = 0
for pred, length in zip(preds, lengths):
    real = flat_targets[offset:offset + length]
    all_reals.append(''.join(label2char[l] for l in real))
    all_preds.append(''.join(pred))
    offset += length

# Size the grid from the actual batch instead of assuming exactly 64 images;
# a 64-image batch still gives the original 8x8 grid at figsize (24, 24).
n_images = images.size(0)
n_cols = 8
n_rows = max(1, (n_images + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
flat_axes = axes.ravel()

for ax, img, real_text, pred_text in zip(flat_axes, images.cpu().numpy(), all_reals, all_preds):
    # img[0] selects the first channel of the image for display.
    ax.imshow(img[0], cmap='gray')
    ax.set_title(f'{real_text} > {pred_text}')

# Hide the cells of the grid that have no image to show.
for ax in flat_axes[n_images:]:
    ax.axis('off')

plt.show()

In [None]:
epochs = 20

# Per-epoch history: mean train loss per sample, validation loss and accuracy.
train_losses = []
val_losses = []
val_accs = []
for epoch in range(1,epochs+1):
    print(f'EPOCH [{epoch}/{epochs}]')
    # Return to training mode each epoch: the evaluation at the end of the
    # previous epoch may have left the model in eval mode (harmless if
    # train_batch already does this -- TODO confirm).
    crnn.train()
    run_train_loss = 0.
    run_train_count = 0

    for step, train_data in enumerate(train_loader, start=1):
        loss = train_batch(crnn, train_data, optimizer, criterion, device)
        run_train_loss += loss
        run_train_count += train_data[0].size(0)
        if step%100 == 0:
            print(f'Running Train Loss [{step}/{len(train_loader)}] : {run_train_loss/run_train_count :.2f}')

    # Average over the whole epoch, not just the final batch; the max() guard
    # avoids a division by zero when the loader yields nothing.
    train_loss = run_train_loss / max(run_train_count, 1)
    # Named `val_metrics` so the builtin `eval` is not shadowed.
    val_metrics = evaluate(crnn, val_loader, criterion)

    print(f"EPOCH [{epoch}/{epochs}] => Train Loss : {train_loss:.4f} | Val Loss : {val_metrics['loss']:.4f}, Val Acc: {val_metrics['acc']}")
    train_losses.append(train_loss)
    # Record this epoch's metrics, not the stale `evaluation` from an earlier cell.
    val_losses.append(val_metrics['loss'])
    val_accs.append(val_metrics['acc'])

print('[Evaluation]')
final_eval = evaluate(crnn, val_loader, criterion)
print(f"Val Loss : {final_eval['loss']:.4f}, Val Acc: {final_eval['acc']}")

In [None]:
# Persist the trained weights on their own.
# NOTE(review): the reload cell near the top expects a config dict with a
# 'state_dict' entry, so this file cannot be used as `reload_checkpoint` as-is.
cp_path = os.path.join('checkpoints','crnn_100k.pt')
# torch.save does not create missing parent directories; make sure it exists.
os.makedirs(os.path.dirname(cp_path), exist_ok=True)
torch.save(crnn.state_dict(),cp_path)