In [1]:
%cd src/

from pathlib import Path
import numpy as np
import math
from itertools import groupby
import h5py
import numpy as np
import unicodedata
import cv2
from torchvision.models import resnet50, resnet101
from torch.autograd import Variable
import torchvision
from data import preproc as pp
from data import evaluation
from torch.utils.data import Dataset
import time
import wandb


import torch
from torch import nn
import pandas as pd
import numpy as np
from functools import reduce
from operator import __add__
import random


def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

set_random_seeds(random_seed=13)




#Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
def conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    "Create and initialize a `nn.Conv1d` layer with spectral normalization."
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias: conv.bias.data.zero_()
    return nn.utils.spectral_norm(conv)



# Adapted from SelfAttention layer at https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
# Inspired by https://arxiv.org/pdf/1805.08318.pdf
class SimpleSelfAttention(nn.Module):
    
    def __init__(self, n_in:int, ks=1, sym=False):#, n_out:int):
        super().__init__()
           
        self.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False)      
       
        self.gamma = nn.Parameter(torch.Tensor([0.]))
        
        self.sym = sym
        self.n_in = n_in
        
    def forward(self,x):
        
        
        if self.sym:
            # symmetry hack by https://github.com/mgrankin
            c = self.conv.weight.view(self.n_in,self.n_in)
            c = (c + c.t())/2
            self.conv.weight = c.view(self.n_in,self.n_in,1)
                
        size = x.size()  
        x = x.view(*size[:2],-1)   # (C,N)
        
        # changed the order of mutiplication to avoid O(N^2) complexity
        # (x*xT)*(W*x) instead of (x*(xT*(W*x)))
        
        convx = self.conv(x)   # (C,C) * (C,N) = (C,N)   => O(NC^2)
        xxT = torch.bmm(x,x.permute(0,2,1).contiguous())   # (C,N) * (N,C) = (C,C)   => O(NC^2)
        
        o = torch.bmm(xxT, convx)   # (C,C) * (C,N) = (C,N)   => O(NC^2)
          
        o = self.gamma * o + x
        
          
        return o.view(*size).contiguous()        
        


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=128):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class OCR(nn.Module):

    def __init__(self, vocab_len, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers):
        super().__init__()
        
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, stride=1, padding="same"
        )
#         self.sa1 = SimpleSelfAttention(64)        
        self.batch1 = nn.BatchNorm2d(16)
        self.act1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        ##CNN Layer 2
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=3, stride=1, padding="same"
        )
#         self.sa2 = SimpleSelfAttention(128)        
        self.batch2 = nn.BatchNorm2d(32)
        self.act2 = nn.LeakyReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        ##CNN Layer 3
        self.drop1 = nn.Dropout(0.2)
        self.conv3 = nn.Conv2d(
            in_channels=32, out_channels=48, kernel_size=3, stride=1, padding="same"
        )
#         self.sa3 = SimpleSelfAttention(256)        
        self.batch3 = nn.BatchNorm2d(48)
        self.act3 = nn.LeakyReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        ##CNN Layer 4
        self.drop2 = nn.Dropout(0.2)
        self.conv4 = nn.Conv2d(
            in_channels=48, out_channels=64, kernel_size=3, stride=1, padding="same"
        )
#         self.sa4 = SimpleSelfAttention(512)        
        self.batch4 = nn.BatchNorm2d(64)
        self.act4 = nn.LeakyReLU()
        ##CNN Layer 5
        self.drop3 = nn.Dropout(0.2)
        self.conv5 = nn.Conv2d(
            in_channels=64, out_channels=80, kernel_size=3, stride=1, padding="same"
        )
#         self.sa5 = SimpleSelfAttention(512)        
        self.batch5 = nn.BatchNorm2d(80)
        self.act5 = nn.LeakyReLU()
        
        self.conv = nn.Conv2d(80, hidden_dim, 1)
        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads with length of vocab
        # DETR used basic 3 layer MLP for output
        self.vocab = nn.Linear(hidden_dim,vocab_len)

        # output positional encodings (object queries)
        self.decoder = nn.Embedding(vocab_len, hidden_dim)
        self.query_pos = PositionalEncoding(hidden_dim, .2)

        # spatial positional encodings, sine positional encoding can be used.
        # Detr baseline uses sine positional encoding.
        self.row_embed = nn.Parameter(torch.rand(128, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(16, hidden_dim // 2))
        self.trg_mask = None
  
    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask==1, float('-inf'))
        return mask

    def get_feature(self,x):
        x = self.conv1(x)
#         x = self.sa1(x)
        x = self.batch1(x)
        x = self.act1(x)
        x = self.pool1(x)
        ##CNN Layer 2
        x = self.conv2(x)
#         x = self.sa2(x)
        x = self.batch2(x)
        x = self.act2(x)
        x = self.pool2(x)
        ##CNN Layer 3
        x = self.drop1(x)
        x = self.conv3(x)
#         x = self.sa3(x)
        x = self.batch3(x)
        x = self.act3(x)
        x = self.pool3(x)
        ##CNN Layer 4
        x = self.drop2(x)
        x = self.conv4(x)
#         x = self.sa4(x)        
        x = self.batch4(x)
        x = self.act4(x)
        ##CNN Layer 5
        x = self.drop3(x)
        x = self.conv5(x)
#         x = self.sa5(x)        
        x = self.batch5(x)
        x = self.act5(x)
        return x


    def make_len_mask(self, inp):
        return (inp == 0).transpose(0, 1)


    def forward(self, inputs, trg):
        # propagate inputs through ResNet-101 up to avg-pool layer
        x = self.get_feature(inputs)
        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)
        # construct positional encodings
        bs,_,H, W = h.shape
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # generating subsequent mask for target
        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            self.trg_mask = self.generate_square_subsequent_mask(trg.shape[1]).to(trg.device)

        # Padding mask
        trg_pad_mask = self.make_len_mask(trg)

        # Getting postional encoding for target
        trg = self.decoder(trg)
        trg = self.query_pos(trg)
        
        output = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1), trg.permute(1,0,2), tgt_mask=self.trg_mask, 
                                  tgt_key_padding_mask=trg_pad_mask.permute(1,0))

        return self.vocab(output.transpose(0,1))


def make_model(vocab_len, hidden_dim=256, nheads=4,
                 num_encoder_layers=4, num_decoder_layers=4):
    
    return OCR(vocab_len, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers)

/home/mhamdan/seq2seqAttenHTR/Transformer_ocr/src


In [2]:
# model = make_model(vocab_len=99,hidden_dim=256, nheads=4,
#                  num_encoder_layers=4, num_decoder_layers=4)

# img = torch.rand(1,1,1024,128)
# trg = torch.randint(1,5,(1,128))
# x = model(img,trg)

"""
Uses generator functions to supply train/test with data.
Image renderings and text are created on the fly each time.
"""

class DataGenerator(Dataset):
    """Generator class with data streaming"""

    def __init__(self, source, split, transform, tokenizer):
        self.tokenizer = tokenizer
        self.transform = transform
        
        self.split = split
        self.dataset = dict()

        with h5py.File(source, "r") as f:
            self.dataset[self.split] = dict()

            self.dataset[self.split]['dt'] = np.array(f[self.split]['dt'])
            self.dataset[self.split]['gt'] = np.array(f[self.split]['gt'])
          
            randomize = np.arange(len(self.dataset[self.split]['gt']))
            np.random.seed(42)
            np.random.shuffle(randomize)

            self.dataset[self.split]['dt'] = self.dataset[self.split]['dt'][randomize]
            self.dataset[self.split]['gt'] = self.dataset[self.split]['gt'][randomize]

            # decode sentences from byte
            self.dataset[self.split]['gt'] = [x.decode() for x in self.dataset[self.split]['gt']]
            
        self.size = len(self.dataset[self.split]['gt'])


    def __getitem__(self, i):
        img = self.dataset[self.split]['dt'][i]
        
        #making image compatible with resnet
#         img = cv2.transpose(img)
        img = np.repeat(img[..., np.newaxis],3, -1).astype("float32")   
#         img = pp.normalization(img).astype("float32")
#         img = img.astype("float32")
        if self.transform is not None:
            aug = self.transform(image=img)
            img = aug['image']
            
#             img = self.transform(img)
            
        y_train = self.tokenizer.encode(self.dataset[self.split]['gt'][i]) 
        
        #padding till max length
        y_train = np.pad(y_train, (0, self.tokenizer.maxlen - len(y_train)))

        gt = torch.Tensor(y_train)

        return img, gt          

    def __len__(self):
      return self.size

class Tokenizer():
    """Manager tokens functions and charset/dictionary properties"""

    def __init__(self, chars, max_text_length=128):
        self.PAD_TK, self.UNK_TK,self.SOS,self.EOS = "¶", "¤", "SOS", "EOS"
        self.chars = [self.PAD_TK] + [self.UNK_TK ]+ [self.SOS] + [self.EOS] +list(chars)
        self.PAD = self.chars.index(self.PAD_TK)
        self.UNK = self.chars.index(self.UNK_TK)

        self.vocab_size = len(self.chars)
        self.maxlen = max_text_length

    def encode(self, text):
        """Encode text to vector"""
        text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII")
        text = " ".join(text.split())

        groups = ["".join(group) for _, group in groupby(text)]
        text = "".join([self.UNK_TK.join(list(x)) if len(x) > 1 else x for x in groups])
        encoded = []

        text = ['SOS'] + list(text) + ['EOS']
        for item in text:
            index = self.chars.index(item)
            index = self.UNK if index == -1 else index
            encoded.append(index)

        return np.asarray(encoded)

    def decode(self, text):
        """Decode vector to text"""
        
        decoded = "".join([self.chars[int(x)] for x in text if x > -1])
        decoded = self.remove_tokens(decoded)
        decoded = pp.text_standardize(decoded)

        return decoded

    def remove_tokens(self, text):
        """Remove tokens (PAD) from text"""

        return text.replace(self.PAD_TK, "").replace(self.UNK_TK, "")

In [3]:
import os
import datetime
import string

batch_size = 8
epochs = 200

# define paths
#change paths accordingly
source = 'iam_only_illumintaion'
source_path = '../data/{}.hdf5'.format(source)

# define input size, number max of chars per line and list of valid chars
input_size = (1024, 128, 1)
max_text_length = 128
charset_base = string.printable[:95]
# charset_base = string.printable[:36].lower() + string.printable[36+26:95].lower() 

print("source:", source_path)
print("charset:", charset_base)


import torchvision.transforms as T
local_rank = 1
device = torch.device("cuda:{}".format(local_rank))

# transform = T.Compose([
#     T.ToTensor()])
tokenizer = Tokenizer(charset_base)
import albumentations
import albumentations.pytorch

if True:

    transform_train = albumentations.Compose([
        albumentations.OneOf(
            [
                albumentations.MotionBlur(p=1, blur_limit=8),
                albumentations.OpticalDistortion(p=1, distort_limit=0.05),
                albumentations.GaussNoise(p=1, var_limit=(10.0, 100.0)),
                albumentations.RandomBrightnessContrast(p=1, brightness_limit=0.2),
                albumentations.Downscale(p=1, scale_min=0.3, scale_max=0.5),
            ],
            p=.5,
        ),
#         albumentations.Resize(224,224),
        albumentations.Normalize(),
        albumentations.pytorch.ToTensorV2()

    ])

    transform_valid = albumentations.Compose(
        [
#         albumentations.Resize(224,224),            
            albumentations.Normalize(),
            albumentations.pytorch.ToTensorV2()
        ]
    )

else:
    transform_train = albumentations.Compose(
        [
#         albumentations.Resize(224,224),            
            albumentations.Normalize(),
            albumentations.pytorch.ToTensorV2()
        ]
    )
    
    transform_valid = albumentations.Compose(
        [
#         albumentations.Resize(224,224),
            
            albumentations.Normalize(),
            albumentations.pytorch.ToTensorV2()
        ]
    )
    

train_loader = torch.utils.data.DataLoader(DataGenerator(source_path,'train',transform_train, tokenizer), batch_size=batch_size,  shuffle=True,num_workers=6)
val_loader = torch.utils.data.DataLoader(DataGenerator(source_path,'valid',transform_valid, tokenizer), batch_size=batch_size, shuffle=False, num_workers=6)

source: ../data/iam_only_illumintaion.hdf5
charset: 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ 


In [4]:
num_encoder_layers = 3
num_decoder_layers = 3


model = make_model( vocab_len=tokenizer.vocab_size,hidden_dim=256, nheads=4,
                 num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)

            
# init_funcs = {
# 1: lambda x: torch.nn.init.normal_(x, mean=0., std=1.), # can be bias
# 2: lambda x: torch.nn.init.xavier_normal_(x, gain=1.), # can be weight
# 3: lambda x: torch.nn.init.xavier_uniform_(x, gain=1.), # can be conv1D filter
# 4: lambda x: torch.nn.init.xavier_uniform_(x, gain=1.), # can be conv2D filter
# "default": lambda x: torch.nn.init.constant(x, 1.), # everything else
# }
# for p in model.parameters():
#     init_func = init_funcs.get(len(p.shape), init_funcs["default"])
#     init_func(p)

class LabelSmoothing(nn.Module):
    "Implement label smoothing."
    def __init__(self, size, padding_idx=0, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False)
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        
    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))

smoothing = .4
criterion = LabelSmoothing(size=tokenizer.vocab_size, padding_idx=0, smoothing=smoothing)
criterion.to(device)
lr = .5e-05# learnig rte
backbone_lr = .003
# if not args.pretrained:
#     backbone_lr = backbone_lr*10

# param_dicts = [
#     {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
#     {
#         "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
#         "lr": backbone_lr,
#     },
# ]

scheduler_factor = .8



optimizer = torch.optim.AdamW(model.parameters(), lr=lr,weight_decay=.0004)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=scheduler_factor)



In [5]:
_ = model.to(device)

In [6]:
target_path = "../output_cnnt/only_illumination//{}_best.hdf5"

In [7]:
def train(model, criterion, optimiser,dataloader):
 
    model.train()
    total_loss = 0
    for batch, (imgs, labels_y,) in enumerate(dataloader):
          imgs = imgs.to(device)
          labels_y = labels_y.to(device)
    
          optimiser.zero_grad()
          output = model(imgs.float(),labels_y.long()[:,:-1])
 
          loss = criterion(output.log_softmax(-1).contiguous().view(-1, tokenizer.vocab_size), labels_y[:,1:].contiguous().view(-1).long()) 
 
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), 0.2)
          optimizer.step()
          total_loss += loss.item()
 
    return total_loss / len(dataloader)
 
def evaluate(model, criterion, dataloader,):
 
    model.eval()
    epoch_loss = 0
    cer = 0
    with torch.no_grad():
      for batch, (imgs, labels_y,) in enumerate(dataloader):
            imgs = imgs.to(device)
            labels_y = labels_y.to(device)
 
            output = model(imgs.float(),labels_y.long()[:,:-1])
            o = output.argmax(-1)
            predicts = list(map(lambda x : tokenizer.decode(x).replace('SOS','').replace('EOS',''),o))
            gt = list(map(lambda x : tokenizer.decode(x).replace('SOS','').replace('EOS',''),labels_y))
            cer += evaluation.ocr_metrics(predicts=predicts,
                                   ground_truth=gt)[0]
            
            loss = criterion(output.log_softmax(-1).contiguous().view(-1, tokenizer.vocab_size), labels_y[:,1:].contiguous().view(-1).long())
  
            epoch_loss += loss.item()
    
 
    return epoch_loss / len(dataloader), cer

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
 
best_valid_loss = np.inf

In [8]:
c = 0
for epoch in range(200):

    print(f"Epoch: {epoch+1:02}", "learning rate{}".format(lr_scheduler.get_last_lr()))

    start_time = time.time()

    train_loss = train(model, criterion, optimizer, train_loader)
    valid_loss, cer = evaluate(model, criterion, val_loader)
    epoch_mins, epoch_secs = epoch_time(start_time, time.time())

    c += 1
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), target_path.format("loss"))
        c=0

    if c > 4:
        # decrease lr if loss does not deacrease after 5 steps
        lr_scheduler.step()
        c = 0
                   
    if epoch%10==0:
        torch.save(model.state_dict(), target_path.format("epoch"))


    print(f"Time: {epoch_mins}m {epoch_secs}s")
    print(f"Train Loss: {train_loss:.3f}")
    print(f"Val   Loss: {valid_loss:.3f}")
    print(f"cer: {cer:.3f}")

Epoch: 01 learning rate[5e-06]
Time: 2m 54s
Train Loss: 553.045
Val   Loss: 477.872
cer: 81.738
Epoch: 02 learning rate[5e-06]
Time: 2m 56s
Train Loss: 500.978
Val   Loss: 464.637
cer: 82.186
Epoch: 03 learning rate[5e-06]
Time: 2m 56s
Train Loss: 489.961
Val   Loss: 458.168
cer: 81.807
Epoch: 04 learning rate[5e-06]
Time: 2m 54s
Train Loss: 484.298
Val   Loss: 455.900
cer: 81.932
Epoch: 05 learning rate[5e-06]
Time: 2m 55s
Train Loss: 480.653
Val   Loss: 453.041
cer: 81.196
Epoch: 06 learning rate[5e-06]
Time: 2m 56s
Train Loss: 477.854
Val   Loss: 451.665
cer: 81.545
Epoch: 07 learning rate[5e-06]
Time: 2m 56s
Train Loss: 475.755
Val   Loss: 448.927
cer: 81.018
Epoch: 08 learning rate[5e-06]
Time: 2m 57s
Train Loss: 473.399
Val   Loss: 447.581
cer: 80.770
Epoch: 09 learning rate[5e-06]
Time: 2m 56s
Train Loss: 471.360
Val   Loss: 445.643
cer: 80.343
Epoch: 10 learning rate[5e-06]
Time: 2m 57s
Train Loss: 469.653
Val   Loss: 443.710
cer: 80.262
Epoch: 11 learning rate[5e-06]
Time: 2m 

Time: 3m 0s
Train Loss: 388.915
Val   Loss: 369.582
cer: 70.059
Epoch: 88 learning rate[5e-06]
Time: 3m 0s
Train Loss: 388.587
Val   Loss: 370.301
cer: 70.488
Epoch: 89 learning rate[5e-06]
Time: 2m 59s
Train Loss: 388.052
Val   Loss: 369.865
cer: 69.979
Epoch: 90 learning rate[5e-06]
Time: 2m 57s
Train Loss: 387.147
Val   Loss: 368.169
cer: 71.176
Epoch: 91 learning rate[5e-06]
Time: 2m 59s
Train Loss: 386.964
Val   Loss: 368.666
cer: 69.614
Epoch: 92 learning rate[5e-06]
Time: 3m 0s
Train Loss: 386.151
Val   Loss: 367.487
cer: 69.868
Epoch: 93 learning rate[5e-06]
Time: 3m 0s
Train Loss: 385.460
Val   Loss: 366.873
cer: 70.814
Epoch: 94 learning rate[5e-06]
Time: 3m 0s
Train Loss: 385.327
Val   Loss: 366.821
cer: 69.111
Epoch: 95 learning rate[5e-06]
Time: 3m 1s
Train Loss: 384.645
Val   Loss: 366.177
cer: 69.736
Epoch: 96 learning rate[5e-06]
Time: 3m 0s
Train Loss: 384.426
Val   Loss: 366.822
cer: 69.682
Epoch: 97 learning rate[5e-06]
Time: 2m 57s
Train Loss: 383.802
Val   Loss: 36

Time: 2m 49s
Train Loss: 356.777
Val   Loss: 350.487
cer: 66.886
Epoch: 171 learning rate[4.000000000000001e-06]
Time: 2m 49s
Train Loss: 356.464
Val   Loss: 349.770
cer: 66.932
Epoch: 172 learning rate[4.000000000000001e-06]
Time: 2m 50s
Train Loss: 356.250
Val   Loss: 350.051
cer: 67.000
Epoch: 173 learning rate[4.000000000000001e-06]
Time: 2m 50s
Train Loss: 355.799
Val   Loss: 349.448
cer: 66.842
Epoch: 174 learning rate[4.000000000000001e-06]
Time: 2m 49s
Train Loss: 355.589
Val   Loss: 349.000
cer: 66.658
Epoch: 175 learning rate[4.000000000000001e-06]
Time: 2m 50s
Train Loss: 355.520
Val   Loss: 349.027
cer: 66.647
Epoch: 176 learning rate[4.000000000000001e-06]
Time: 2m 49s
Train Loss: 355.395
Val   Loss: 349.430
cer: 66.597
Epoch: 177 learning rate[4.000000000000001e-06]
Time: 2m 49s
Train Loss: 354.642
Val   Loss: 348.950
cer: 66.837
Epoch: 178 learning rate[4.000000000000001e-06]
Time: 2m 49s
Train Loss: 354.903
Val   Loss: 349.371
cer: 66.759
Epoch: 179 learning rate[4.0000

In [None]:
c = 0
for epoch in range(300):

    print(f"Epoch: {epoch+1:02}", "learning rate{}".format(lr_scheduler.get_last_lr()))

    start_time = time.time()

    train_loss = train(model, criterion, optimizer, train_loader)
    valid_loss, cer = evaluate(model, criterion, val_loader)
    epoch_mins, epoch_secs = epoch_time(start_time, time.time())

    c += 1
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), target_path.format("loss"))
        c=0

    if c > 4:
        # decrease lr if loss does not deacrease after 5 steps
        lr_scheduler.step()
        c = 0
                   
    if epoch%10==0:
        torch.save(model.state_dict(), target_path.format("epoch"))


    print(f"Time: {epoch_mins}m {epoch_secs}s")
    print(f"Train Loss: {train_loss:.3f}")
    print(f"Val   Loss: {valid_loss:.3f}")
    print(f"cer: {cer:.3f}")

In [9]:
d1 = torch.load("../output_cnnt/only_illumination/loss_best.hdf5")
model.load_state_dict(d1)

<All keys matched successfully>

In [10]:
def get_memory(model,imgs):
    x = model.conv(model.get_feature(imgs))
    bs,_,H, W = x.shape
    pos = torch.cat([
            model.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            model.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

    return model.transformer.encoder(pos +  0.1 * x.flatten(2).permute(2, 0, 1))
    

def test(model, test_loader, max_text_length):
    model.eval()
    predicts = []
    gt = []
    imgs = []
    c=0
    with torch.no_grad():
        for batch in test_loader:
            src, trg = batch
            imgs.append(src.flatten(0,1))
            src, trg = src.to(device), trg.to(device)            
            memory = get_memory(model,src.float())
            out_indexes = [tokenizer.chars.index('SOS'), ]
            for i in range(max_text_length):
                mask = model.generate_square_subsequent_mask(i+1).to(device)
                trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(device)
                output = model.vocab(model.transformer.decoder(model.query_pos(model.decoder(trg_tensor)), memory,tgt_mask=mask))
                out_token = output.argmax(2)[-1].item()
                out_indexes.append(out_token)
                if out_token == tokenizer.chars.index('EOS'):
                    break
            predicts.append(tokenizer.decode(out_indexes))
            gt.append(tokenizer.decode(trg.flatten(0,1)))
#             if c==5:
#                 break
            c+=1
    return predicts, gt, imgs

test_loader = torch.utils.data.DataLoader(DataGenerator(source_path,'test',transform_valid, tokenizer), batch_size=1, shuffle=False)

predicts, gt, imgs = test(model,test_loader , max_text_length)

predicts = list(map(lambda x : x.replace('SOS','').replace('EOS',''),predicts))
gt = list(map(lambda x : x.replace('SOS','').replace('EOS',''),gt))

evaluate = evaluation.ocr_metrics(predicts=predicts,
                                  ground_truth=gt,)

In [12]:
evaluate

array([0.55101188, 0.75791297, 1.        ])