In [1]:
%cd src/

from pathlib import Path
import numpy as np
import math
from itertools import groupby
import h5py
import numpy as np
import unicodedata
import cv2
import torch
from torch import nn
from torchvision.models import resnet50, resnet101
from torch.autograd import Variable
import torchvision
from data import preproc as pp
from data import evaluation
from torch.utils.data import Dataset
import time
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl 

/home/mhamdan/seq2seqAttenHTR/Transformer_ocr/src


In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=128):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [8]:
class OCR(nn.Module):

    def __init__(self, vocab_len, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers):
        super().__init__()

        # create ResNet-101 backbone
        self.backbone = resnet101()
        del self.backbone.fc

        # create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads with length of vocab
        # DETR used basic 3 layer MLP for output
        self.vocab = nn.Linear(hidden_dim,vocab_len)

        # output positional encodings (object queries)
        self.decoder = nn.Embedding(vocab_len, hidden_dim)
        self.query_pos = PositionalEncoding(hidden_dim, .2)

        # spatial positional encodings, sine positional encoding can be used.
        # Detr baseline uses sine positional encoding.
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.trg_mask = None
  
    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask==1, float('-inf'))
        return mask

    def get_feature(self,x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)   
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        return x


    def make_len_mask(self, inp):
        return (inp == 0).transpose(0, 1)


    def forward(self, inputs, trg):
        # propagate inputs through ResNet-101 up to avg-pool layer
        x = self.get_feature(inputs)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        bs,_,H, W = h.shape
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # generating subsequent mask for target
        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            self.trg_mask = self.generate_square_subsequent_mask(trg.shape[1]).to(trg.device)

        # Padding mask
        trg_pad_mask = self.make_len_mask(trg)

        # Getting postional encoding for target
        trg = self.decoder(trg)
        trg = self.query_pos(trg)
        
        output = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1), trg.permute(1,0,2), tgt_mask=self.trg_mask, 
                                  tgt_key_padding_mask=trg_pad_mask.permute(1,0))

        return self.vocab(output.transpose(0,1))

def make_model(vocab_len, hidden_dim=256, nheads=4,
                 num_encoder_layers=4, num_decoder_layers=4):
    
    return OCR(vocab_len, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers)

In [26]:
"""
Uses generator functions to supply train/test with data.
Image renderings and text are created on the fly each time.
"""

class DataGenerator(Dataset):
    """Generator class with data streaming"""

    def __init__(self, source, split, transform, tokenizer):
        self.tokenizer = tokenizer
        self.transform = transform
        
        self.split = split
        self.dataset = dict()

        with h5py.File(source, "r") as f:
            self.dataset[self.split] = dict()

            self.dataset[self.split]['dt'] = np.array(f[self.split]['dt'])
            self.dataset[self.split]['gt'] = np.array(f[self.split]['gt'])
          
            randomize = np.arange(len(self.dataset[self.split]['gt']))
            np.random.seed(42)
            np.random.shuffle(randomize)

            self.dataset[self.split]['dt'] = self.dataset[self.split]['dt'][randomize]
            self.dataset[self.split]['gt'] = self.dataset[self.split]['gt'][randomize]

            # decode sentences from byte
            self.dataset[self.split]['gt'] = [x.decode() for x in self.dataset[self.split]['gt']]
            
        self.size = len(self.dataset[self.split]['gt'])


    def __getitem__(self, i):
        img = self.dataset[self.split]['dt'][i]
        
        #making image compatible with resnet
        img = np.repeat(img[..., np.newaxis],3, -1)    
        img = pp.normalization(img).astype("float32")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)        
#         if self.transform is not None:
#             img = self.transform(img)

        if self.transform is not None:
            augmented = self.transform(image = img)
            img = augmented['image']
            
        y_train = self.tokenizer.encode(self.dataset[self.split]['gt'][i]) 
        
        #padding till max length
        y_train = np.pad(y_train, (0, self.tokenizer.maxlen - len(y_train)))

        gt = torch.Tensor(y_train)

        return img, gt          

    def __len__(self):
      return self.size

In [10]:
class Tokenizer():
    """Manager tokens functions and charset/dictionary properties"""

    def __init__(self, chars, max_text_length=128):
        self.PAD_TK, self.UNK_TK,self.SOS,self.EOS = "¶", "¤", "SOS", "EOS"
        self.chars = [self.PAD_TK] + [self.UNK_TK ]+ [self.SOS] + [self.EOS] +list(chars)
        self.PAD = self.chars.index(self.PAD_TK)
        self.UNK = self.chars.index(self.UNK_TK)

        self.vocab_size = len(self.chars)
        self.maxlen = max_text_length

    def encode(self, text):
        """Encode text to vector"""
        text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII").lower()
        text = " ".join(text.split())

        groups = ["".join(group) for _, group in groupby(text)]
        text = "".join([self.UNK_TK.join(list(x)) if len(x) > 1 else x for x in groups])
        encoded = []

        text = ['SOS'] + list(text) + ['EOS']
        for item in text:
            index = self.chars.index(item)
            index = self.UNK if index == -1 else index
            encoded.append(index)

        return np.asarray(encoded)

    def decode(self, text):
        """Decode vector to text"""
        
        decoded = "".join([self.chars[int(x)] for x in text if x > -1])
        decoded = self.remove_tokens(decoded)
        decoded = pp.text_standardize(decoded)

        return decoded

    def remove_tokens(self, text):
        """Remove tokens (PAD) from text"""

        return text.replace(self.PAD_TK, "").replace(self.UNK_TK, "")

import os
import datetime
import string

batch_size = 16
epochs = 200

# define paths
#change paths accordingly
source = 'iam'
source_path = '../data/{}.hdf5'.format(source)
output_path = os.path.join("..", "output", source)
target_path = os.path.join(output_path, "checkpoint_weights_iam.hdf5")
os.makedirs(output_path, exist_ok=True)

# define input size, number max of chars per line and list of valid chars
input_size = (1024, 128, 1)
max_text_length = 128
# charset_base = string.printable[:95]
charset_base = string.printable[:36].lower() + string.printable[36+26:95].lower() 

print("source:", source_path)
print("output", output_path)
print("target", target_path)
print("charset:", charset_base)

import torchvision.transforms as T
transform = T.Compose([
    T.ToTensor()])
tokenizer = Tokenizer(charset_base)

source: ../data/iam.hdf5
output ../output/iam
target ../output/iam/checkpoint_weights_iam.hdf5
charset: 0123456789abcdefghijklmnopqrstuvwxyz!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ 


In [1]:
import albumentations
import albumentations.pytorch


albumentations.Compose([
    albumentations.RandomContrast(),
  albumentations.MotionBlur(p=.2),
  albumentations.OpticalDistortion(p=.3),
  albumentations.GaussNoise(p=.2),
    albumentations.RandomBrightnessContrast(p=0.2),       
    albumentations.pytorch.transforms.ToTensorV2()
])



Compose([
  RandomContrast(always_apply=False, p=0.5, limit=(-0.2, 0.2)),
  MotionBlur(always_apply=False, p=0.2, blur_limit=(3, 7)),
  OpticalDistortion(always_apply=False, p=0.3, distort_limit=(-0.05, 0.05), shift_limit=(-0.05, 0.05), interpolation=1, border_mode=4, value=None, mask_value=None),
  GaussNoise(always_apply=False, p=0.2, var_limit=(10.0, 50.0), per_channel=True, mean=0),
  RandomBrightnessContrast(always_apply=False, p=0.2, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), brightness_by_max=True),
  ToTensorV2(always_apply=True, p=1.0, transpose_mask=False),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={})

In [27]:
testt = albumentations.Compose([
    albumentations.pytorch.transforms.ToTensorV2()
])

In [11]:
# data = DataModule(source_path, 16, tokenizer, transform)

model = make_model(tokenizer.vocab_size,hidden_dim=256, nheads=4,
                 num_encoder_layers=4, num_decoder_layers=4)

In [12]:
d = torch.load("../output/iam/checkpoint_weights_iam_small_4_enc_4_dec_aug.hdf5")

In [13]:
f = {}
for i in d:
    f[i.replace("module.","")] = d[i]

In [14]:
model.load_state_dict(f)

<All keys matched successfully>

In [15]:
def get_memory(model,imgs):
    x = model.conv(model.get_feature(imgs))
    bs,_,H, W = x.shape
    pos = torch.cat([
            model.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            model.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

    return model.transformer.encoder(pos +  0.1 * x.flatten(2).permute(2, 0, 1))
    

def test(model, test_loader, max_text_length):
    model.eval()
    predicts = []
    gt = []
    imgs = []
    c=0
    with torch.no_grad():
        for batch in test_loader:
            src, trg = batch
            imgs.append(src.flatten(0,1))
            src, trg = src.cuda(), trg.cuda()            
            memory = get_memory(model,src.float())
            out_indexes = [tokenizer.chars.index('SOS'), ]
            for i in range(max_text_length):
                mask = model.generate_square_subsequent_mask(i+1).to(device)
                trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(device)
                output = model.vocab(model.transformer.decoder(model.query_pos(model.decoder(trg_tensor)), memory,tgt_mask=mask))
                out_token = output.argmax(2)[-1].item()
                out_indexes.append(out_token)
                if out_token == tokenizer.chars.index('EOS'):
                    break
            predicts.append(tokenizer.decode(out_indexes))
            gt.append(tokenizer.decode(trg.flatten(0,1)))
            if c==100:
                break
            c+=1
    return predicts, gt, imgs

In [16]:
device ='cuda'
_ =  model.to("cuda")

In [28]:
test_loader = torch.utils.data.DataLoader(DataGenerator(source_path,'test',testt, tokenizer), batch_size=1, shuffle=False, num_workers=6)

In [29]:
len(test_loader)

1425

In [30]:
predicts, gt, imgs = test(model, test_loader, max_text_length)

In [31]:
predicts = list(map(lambda x : x.replace('SOS','').replace('EOS',''),predicts))
gt = list(map(lambda x : x.replace('SOS','').replace('EOS',''),gt))

evaluate = evaluation.ocr_metrics(predicts=predicts,
                                  ground_truth=gt,)

In [32]:
evaluate

array([0.76314335, 0.9618147 , 1.        ])

In [69]:
evaluate

array([0.1498361 , 0.38938587, 0.97052632])

In [64]:

for i in range(20):
    print(gt[i], "-",predicts[i])


quite unable to explain why he should feel - gofite unalle to explain whey hald feel
meet the deanes , and as soon as guy had - mect the decemes , and as scon as 63uy had
the horses and drank enough to cure our - the harses and denk ghough to use uw
you not killed ? ' ' because we know all things , ' the - you not lilled . ' becuse we know all things , ' the
with the possibility of faulty design . " he held - with the possibility of faulty design . " ste held
all due deference , miss deane - come off it !  - all due deference , miss deaue - coms oft ' . 
it would have been acceptable to all concerned - 3t would have been accepstable to all concermed
to make you understand just what happened - to wake you udustard just what hapered
course of action should be . first , to avoid the - course of action should be . first to avoid the
 ( stamp department ) while sally sulked at home .  - istams cepartment ) while sally sulted at homo . 
bill is good man , and bueno buck is raised on lake .  

In [20]:
import torch
xcit_small_12_p8 = torch.hub.load('facebookresearch/dino:main',"dino_resnet50" )

Using cache found in /home/mhamdan/.cache/torch/hub/facebookresearch_dino_main
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth" to /home/mhamdan/.cache/torch/hub/checkpoints/dino_resnet50_pretrain.pth


  0%|          | 0.00/90.0M [00:00<?, ?B/s]

In [35]:
a = torch.rand(1,3,1024,128).cuda()

In [22]:
_ = xcit_small_12_p8.cuda()

In [9]:
# import torch.distributed as dist
# dist.init_process_group('gloo', init_method='file:///tmp/somefile', rank=0, world_size=1)

KeyboardInterrupt: 

In [32]:
    def get_feature(backbone,x):
        x = backbone.conv1(x)
        x = backbone.bn1(x)   
        x = backbone.relu(x)
        x = backbone.maxpool(x)

        x = backbone.layer1(x)
        x = backbone.layer2(x)
        x = backbone.layer3(x)
        x = backbone.layer4(x)
        return x


In [36]:
b = get_feature(xcit_small_12_p8, a)

In [37]:
b.shape

torch.Size([1, 2048, 32, 4])

In [27]:
b =  xcit_small_12_p8.forward(a)
d =  xcit_small_12_p8(a)

In [28]:
d.shape, b.shape

(torch.Size([1, 2048]), torch.Size([1, 2048]))

In [10]:
class OCR(nn.Module):

    def __init__(self, vocab_len, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers):
        super().__init__()
    
        self.hidden_dim = hidden_dim
        # create ResNet-101 backbone
        self.backbone = torch.hub.load('facebookresearch/dino:main',"dino_xcit_small_12_p8")

        # create conversion layer
        self.converter = nn.Linear(self.backbone.embed_dim, hidden_dim)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads with length of vocab
        # DETR used basic 3 layer MLP for output
        self.vocab = nn.Linear(hidden_dim,vocab_len)

        # output positional encodings (object queries)
        self.decoder = nn.Embedding(vocab_len, hidden_dim)
        self.query_pos = PositionalEncoding(hidden_dim, .2)
#         self.cnn_pos = PositionalEncodingCNN(hidden_dim,.2)

        # spatial positional encodings, sine positional encoding can be used.
        # Detr baseline uses sine positional encoding.
#         self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
#         self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.trg_mask = None
  
    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask==1, float('-inf'))
        return mask

    def get_feature(self,x):
        x = self.backbone.forward_features(x)
        return x


    def make_len_mask(self, inp):
        return (inp == 0).transpose(0, 1)


    def forward(self, inputs, trg):
        # propagate inputs through ResNet-101 up to avg-pool layer
        h = self.get_feature(inputs)
        h = h.unsqueeze(2).repeat(1,1,self.hidden_dim)
        # convert from 2048 to 256 feature planes for the transformer
#         h = self.converter(x)
#         # construct positional encodings
#         bs,_,H, W = h.shape
#         pos = torch.cat([
#             self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
#             self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
#         ], dim=-1).flatten(0, 1).unsqueeze(1)

#         # generating subsequent mask for target
        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            self.trg_mask = self.generate_square_subsequent_mask(trg.shape[1]).to(trg.device)

        # Padding mask
        trg_pad_mask = self.make_len_mask(trg)

        # Getting postional encoding for target
        trg = self.decoder(trg)
        trg = self.query_pos(trg)
        
        output = self.transformer(h.permute(1, 0, 2), trg.permute(1,0,2), tgt_mask=self.trg_mask, 
                                  tgt_key_padding_mask=trg_pad_mask.permute(1,0))

        return self.vocab(output.transpose(0,1))


In [11]:
def make_model(vocab_len, hidden_dim=256, nheads=4,
                 num_encoder_layers=4, num_decoder_layers=4):
    
    return OCR(vocab_len, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers)


In [12]:
m = make_model(20)

Using cache found in /home/mhamdan/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /home/mhamdan/.cache/torch/hub/facebookresearch_xcit_master


In [24]:
a = torch.rand(1,3,225,225).cuda()
b = torch.rand(1,128).long().cuda()

In [45]:
_ = m.cuda()

In [46]:
m(a,b)

torch.Size([1, 384, 256])
torch.Size([1, 128, 256])


tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       grad_fn=<AddBackward0>)

In [34]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    

In [35]:
count_parameters(m)

+-------------------------------------------------------------+------------+
|                           Modules                           | Parameters |
+-------------------------------------------------------------+------------+
|                          row_embed                          |    6400    |
|                          col_embed                          |    6400    |
|                      backbone.cls_token                     |    384     |
|             backbone.patch_embed.proj.0.0.weight            |    2592    |
|             backbone.patch_embed.proj.0.1.weight            |     96     |
|              backbone.patch_embed.proj.0.1.bias             |     96     |
|             backbone.patch_embed.proj.2.0.weight            |   165888   |
|             backbone.patch_embed.proj.2.1.weight            |    192     |
|              backbone.patch_embed.proj.2.1.bias             |    192     |
|             backbone.patch_embed.proj.4.0.weight            |   663552   |

37525972