In [11]:
%cd /lustre07/scratch/hamdan/Training Process/
import argparse
import string


/lustre07/scratch/hamdan/Training Process


In [23]:
import os
current_path = os.getcwd()
current_path

'/lustre07/scratch/hamdan/Training Process'

In [4]:
from pathlib import Path
import numpy as np
import math
from itertools import groupby
import h5py
import numpy as np
import unicodedata
import cv2
import torch
from torch import nn
from torchvision.models import resnet50, resnet101
from torch.autograd import Variable
import torchvision
from data import preproc as pp
from data import evaluation
from torch.utils.data import Dataset
import time
import timm


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=46):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class OCR(nn.Module):
    def __init__(self, vocab_len, maxlen, hidden_dim, nheads,num_encoder_layers, num_decoder_layers):
        super().__init__()
    
#         self.backbone = timm.create_model('ecaresnet101d', pretrained=True,)
#         del self.backbone.fc
        self.backbone = resnet101(pretrained=True)
        del self.backbone.fc
        
        
#         del self.backbone.classifier, self.backbone.conv_head, self.backbone.bn2,self.backbone.act2,self.backbone.global_pool
        _ = self.backbone.to("cpu")
#         for name,p in self.backbone.named_parameters():
#             if "bn" not in name or "attnpool" in name:
#                 p.requires_grad =  False

        # create a default PyTorch transformer
        # create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads with length of vocab
        # DETR used basic 3 layer MLP for output
        self.vocab = nn.Linear(hidden_dim,vocab_len)

        # output positional encodings (object queries)
        self.decoder = nn.Embedding(vocab_len, hidden_dim)
        self.query_pos = PositionalEncoding(hidden_dim, .2)

        # spatial positional encodings, sine positional encoding can be used.
        # Detr baseline uses sine positional encoding.
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.trg_mask = None
  
    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask==1, float('-inf'))
        return mask
    

    def get_feature(self,x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)   
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        return x


    def make_len_mask(self, inp):
        return (inp == 0).transpose(0, 1)


    def forward(self, inputs, trg):
        # propagate inputs through ResNet-101 up to avg-pool layer
        x = self.get_feature(inputs)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        bs,_,H, W = h.shape
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # generating subsequent mask for target
        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            self.trg_mask = self.generate_square_subsequent_mask(trg.shape[1]).to(trg.device)

        # Padding mask
        trg_pad_mask = self.make_len_mask(trg)

        # Getting postional encoding for target
        trg = self.decoder(trg)
        trg = self.query_pos(trg)
        
        output = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1), trg.permute(1,0,2), tgt_mask=self.trg_mask, 
                                  tgt_key_padding_mask=trg_pad_mask.permute(1,0))

        return self.vocab(output.transpose(0,1))


def make_model(vocab_len, maxlen, hidden_dim=256, nheads=6,
                 num_encoder_layers=2, num_decoder_layers=6):
    
    return OCR(vocab_len, maxlen, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers)


In [5]:
"""
Uses generator functions to supply train/test with data.
Image renderings and text are created on the fly each time.
"""

class DataGenerator(Dataset):
    """Generator class with data streaming"""

    def __init__(self, source, split, transform, tokenizer):
        self.tokenizer = tokenizer
        self.transform = transform
        
        self.split = split
        self.dataset = dict()

        with h5py.File(source, "r") as f:
            self.dataset[self.split] = dict()

            self.dataset[self.split]['dt'] = np.array(f[self.split]['dt'])
            self.dataset[self.split]['gt'] = np.array(f[self.split]['gt'])
          
            randomize = np.arange(len(self.dataset[self.split]['gt']))
            np.random.seed(42)
            np.random.shuffle(randomize)

            self.dataset[self.split]['dt'] = self.dataset[self.split]['dt'][randomize]
            self.dataset[self.split]['gt'] = self.dataset[self.split]['gt'][randomize]

            # decode sentences from byte
            self.dataset[self.split]['gt'] = [x.decode() for x in self.dataset[self.split]['gt']]
            
        self.size = len(self.dataset[self.split]['gt'])


    def __getitem__(self, i):
        img = self.dataset[self.split]['dt'][i]
        
        #making image compatible with resnet
        img = np.repeat(img[..., np.newaxis],3, -1).astype("float32")   
#         img = pp.normalization(img).astype("float32")

        if self.transform is not None:
            aug = self.transform(image=img)
            img = aug['image']
            
#             img = self.transform(img)
            
        y_train = self.tokenizer.encode(self.dataset[self.split]['gt'][i]) 
        
        #padding till max length
        y_train = np.pad(y_train, (0, self.tokenizer.maxlen - len(y_train)))

        gt = torch.Tensor(y_train)

        return img, gt          

    def __len__(self):
      return self.size

In [6]:
class Tokenizer():
    """Manager tokens functions and charset/dictionary properties"""

    def __init__(self, chars, max_text_length=128):
        self.PAD_TK, self.UNK_TK,self.SOS,self.EOS = "¶", "¤", "SOS", "EOS"
        self.chars = [self.PAD_TK] + [self.UNK_TK ]+ [self.SOS] + [self.EOS] +list(chars)
        self.PAD = self.chars.index(self.PAD_TK)
        self.UNK = self.chars.index(self.UNK_TK)

        self.vocab_size = len(self.chars)
        self.maxlen = max_text_length

    def encode(self, text):
        """Encode text to vector"""
        text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII")
        text = " ".join(text.split())

        groups = ["".join(group) for _, group in groupby(text)]
        text = "".join([self.UNK_TK.join(list(x)) if len(x) > 1 else x for x in groups])
        encoded = []

        text = ['SOS'] + list(text) + ['EOS']
        for item in text:
            index = self.chars.index(item)
            index = self.UNK if index == -1 else index
            encoded.append(index)

        return np.asarray(encoded)

    def decode(self, text):
        """Decode vector to text"""
        
        decoded = "".join([self.chars[int(x)] for x in text if x > -1])
        decoded = self.remove_tokens(decoded)
        decoded = pp.text_standardize(decoded)

        return decoded

    def remove_tokens(self, text):
        """Remove tokens (PAD) from text"""

        return text.replace(self.PAD_TK, "").replace(self.UNK_TK, "")


In [7]:
import os
import datetime
import string
import albumentations
import albumentations.pytorch
import torchvision.transforms as T

batch_size = 16
epochs = 200

# define paths
#change paths accordingly
source = 'iam_paragraph_finetune'
source_path = './datahdf5/iam/{}.hdf5'.format(source)
output_path = os.path.join(".", "output", source)
target_path = os.path.join(output_path, "checkpoint_weights_iam_{}.hdf5".format("dsa"))
os.makedirs(output_path, exist_ok=True)

# define input size, number max of chars per line and list of valid chars
input_size = (384, 384, 3)
max_text_length = 633
charset_base = string.printable[:95]
# charset_base = string.printable[:36].lower() + string.printable[36+26:95].lower() 

print("source:", source_path)
print("output", output_path)
print("target", target_path)
print("charset:", charset_base)


source: ./datahdf5/iam/iam_paragraph_finetune.hdf5
output ./output/iam_paragraph_finetune
target ./output/iam_paragraph_finetune/checkpoint_weights_iam_dsa.hdf5
charset: 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ 


In [8]:
local_rank = 1
device = torch.device("cuda:{}".format(local_rank))

# transform = T.Compose([
#     T.ToTensor()])
tokenizer = Tokenizer(charset_base)

transform_valid = albumentations.Compose(
    [
        albumentations.Normalize(),
        albumentations.pytorch.ToTensorV2()
    ]
)


In [9]:
charset_base

'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '

In [10]:
tokenizer = Tokenizer(charset_base)


In [11]:
model = make_model( vocab_len= tokenizer.vocab_size, maxlen= max_text_length ,hidden_dim=256, nheads=4,
                 num_encoder_layers=4, num_decoder_layers=4)



In [12]:
print(device, '  is used')
_ = model.to(device)

cuda:1   is used


RuntimeError: No CUDA GPUs are available

In [2]:
""""
Uses generator functions to supply train/test with data.
Image renderings and text are created on the fly each time.
"""

class DataGenerator(Dataset):
    """Generator class with data streaming"""

    def __init__(self, source, split, transform, tokenizer):
        self.tokenizer = tokenizer
        self.transform = transform
        
        self.split = split
        self.dataset = dict()

        with h5py.File(source, "r") as f:
            self.dataset[self.split] = dict()

            self.dataset[self.split]['dt'] = np.array(f[self.split]['dt'])
            self.dataset[self.split]['gt'] = np.array(f[self.split]['gt'])
          
#             randomize = np.arange(len(self.dataset[self.split]['gt']))
#             np.random.seed(42)
#             np.random.shuffle(randomize)

#             self.dataset[self.split]['dt'] = self.dataset[self.split]['dt'][randomize]
#             self.dataset[self.split]['gt'] = self.dataset[self.split]['gt'][randomize]

            # decode sentences from byte
            self.dataset[self.split]['gt'] = [x.decode() for x in self.dataset[self.split]['gt']]
            
        self.size = len(self.dataset[self.split]['gt'])


    def __getitem__(self, i):
        img = self.dataset[self.split]['dt'][i]
        
        #making image compatible with resnet
#         img = cv2.transpose(img)
#         img = np.repeat(img[..., np.newaxis],3, -1).astype("float32")   
#         img = pp.normalization(img).astype("float32")

        if self.transform is not None:
            aug = self.transform(image=img)
            img = aug['image']
            
#             img = self.transform(img)
        
#         print(self.dataset[self.split]['gt'][i])
        y_train = self.tokenizer.encode(self.dataset[self.split]['gt'][i]) 
        
        #padding till max length
        y_train = np.pad(y_train, (0, self.tokenizer.maxlen - len(y_train)))

        gt = torch.Tensor(y_train)

        return img, gt          

    def __len__(self):
      return self.size

class Tokenizer():
    """Manager tokens functions and charset/dictionary properties"""

    def __init__(self, chars, max_text_length=630):
        self.PAD_TK, self.UNK_TK,self.SOS,self.EOS = "¶", "¤", "SOS", "EOS"
        self.chars = [self.PAD_TK] + [self.UNK_TK ]+ [self.SOS] + [self.EOS] +list(chars)
        self.PAD = self.chars.index(self.PAD_TK)
        self.UNK = self.chars.index(self.UNK_TK)

        self.vocab_size = len(self.chars)
        self.maxlen = max_text_length

    def encode(self, text):
        """Encode text to vector"""
#         text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII")
#         text = " ".join(text.split())

#         groups = ["".join(group) for _, group in groupby(text)]
#         text = "".join([self.UNK_TK.join(list(x)) if len(x) > 1 else x for x in groups])
        text = str(text)
        encoded = []
        
        text = ['SOS'] + list(text.strip()) + ['EOS']
        for item in text:
            index = self.chars.index(item)
            index = self.UNK if index == -1 else index
            encoded.append(index)

        return np.asarray(encoded)

    def decode(self, text):
        """Decode vector to text"""
        
        decoded = "".join([self.chars[int(x)] for x in text if x > -1])
        decoded = self.remove_tokens(decoded)
#         decoded = pp.text_standardize(decoded)

        return decoded

    def remove_tokens(self, text):
        """Remove tokens (PAD) from text"""

        return text.replace(self.PAD_TK, "").replace(self.UNK_TK, "")

import os
import datetime
import string

batch_size = 32
epochs = 200


max_text_length = 770
charset_base = 'kCE3̄cw¬NaorūxbgOÖö5tDsSRā>L(G\nB61ÿ8e,<¾Q0ōäizußY2ZA)y-F9PmTKfdVUüvMn/jph+4ēIJWl7q:ȳH—.̈ '
# charset_base = string.printable[:36].lower() + string.printable[36+26:95].lower()
print("charset:", charset_base)

device = torch.device("cuda:{}".format(0))
tokenizer = Tokenizer(charset_base, max_text_length)


transform_valid = albumentations.Compose(
    [
#         albumentations.Resize(224,224),            
        albumentations.Normalize(),
        albumentations.pytorch.ToTensorV2()
    ]
)

charset: kCE3̄cw¬NaorūxbgOÖö5tDsSRā>L(G
B61ÿ8e,<¾Q0ōäizußY2ZA)y-F9PmTKfdVUüvMn/jph+4ēIJWl7q:ȳH—.̈ 


In [5]:
charset.chars.index(' ')

NameError: name 'charset' is not defined

In [3]:
num_encoder_layers = 2
num_decoder_layers = 6

ddp_model = make_model(vocab_len=tokenizer.vocab_size, maxlen=tokenizer.maxlen, hidden_dim=384, nheads=6,
                 num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)

In [4]:
file_path = "reads_paragraph_finetune.hdf5"
# file_path = "iam_paragraph_finetune.hdf5"
test_loader = torch.utils.data.DataLoader(DataGenerator("{}".format(file_path),'test',transform_valid, tokenizer), batch_size=1, num_workers=1)
# test_loader = torch.utils.data.DataLoader(DataGenerator("{}".format(file_path),'valid',transform_valid, tokenizer), batch_size=1, num_workers=1)
# train_loader = torch.utils.data.DataLoader(DataGenerator("{}".format(file_path),'train',transform_train, tokenizer), batch_size=batch_size, num_workers=1,shuffle=True)

In [5]:
checkpoint = torch.load("swinreads_paragraph_finetunebest_loss.pt", map_location="cpu")

In [6]:
checkpoint = torch.load("teacher_model.pt", map_location="cpu")

In [7]:
d = {}
for i in checkpoint:
    d[i.replace("teacher_model.","")] = checkpoint[i]

In [16]:
_ = ddp_model.load_state_dict(checkpoint['model_state_dict'])
_ = ddp_model.to(device)
_ = ddp_model.eval()

KeyError: 'model_state_dict'

In [18]:
def get_memory(model,imgs):
    
    with torch.no_grad():    
        x = model.backbone.features(imgs)
        h = model.conv(x.permute(0,3,1,2))
        
        bs,_,H, W = h.shape
        pos = torch.cat([
                model.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
                model.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
            ], dim=-1).flatten(0, 1).unsqueeze(1)

        h = pos +  0.1 * h.flatten(2).permute(2, 0, 1)
    return h
    
    
predicts = []
gt = []

def test(model, test_loader, max_text_length):
    imgs = []
    c=0
    with torch.no_grad():
        for batch in test_loader:
            src, trg = batch
            imgs.append(src.flatten(0,1))
            src, trg = src.to(device), trg.to(device)            
            memory = get_memory(model,src.float())
            out_indexes = [tokenizer.chars.index('SOS'), ]
#             print(memory.shape)
            for i in range(max_text_length):
                mask = model.generate_square_subsequent_mask(i+1).to(device)
                trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(device)
                trg_tensor = model.decoder(trg_tensor)
                trg_tensor = model.query_pos(trg_tensor)
#                 trg_tensor = trg_tensor.permute(1,0,2)
#                 print(trg_tensor.shape)
                output = model.vocab(model.transformer_decoder(trg_tensor, memory,tgt_mask=mask))
                out_token = output.argmax(2)[-1].item()
                out_indexes.append(out_token)
                if out_token == tokenizer.chars.index('EOS'):
                    break
            predicts.append(tokenizer.decode(out_indexes))
            gt.append(tokenizer.decode(trg.flatten(0,1)))
#             if c==2:
#                 break
            c+=1
    return predicts, gt, imgs

In [10]:
predicts, gt, imgs = test1(ddp_model,test_loader , max_text_length)

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

In [38]:
predicts = list(map(lambda x : x.replace('SOS','').replace('EOS',''),predicts))
gt = list(map(lambda x : x.replace('SOS','').replace('EOS',''),gt))

In [39]:
gt,predicts

(['226',
  'dieweiln man in disem\nLanndt Allerlai gfar.\nvnd dūrchzüg Zū besorg.\nvnd dise vmbligennde\nGericht. Ir Traid sonst\naūfm KornPlaz. vnnd\nden Peckhen hieher fiern\ndem Zeffer APottegg\nZū Brixen. bei Straff\nAūfzūladen. sich solicher\nAūf khaūffūng. so allain\nZū sein Aignen Nūz. vnd\ngemeltem KornPlaz.\nals ainem Frl: Lehen.\nZū schaden vnd Für khaūff\nangsehen.',
  'Aūf das: von Ir Gn: hn\nLanndthaūbtmann an\nder Etsch etc. Abganng'],
 ['62',
  'Ainem, mann in disem\nLanndt Aller Zū gfar\nvnd die vmbligende\nvnd die sich Zūbemelt.\nGericht. Ir Traid sonst\naūfm Prackhen hieher vnd\ndem Peckhen hier fiern,\ndem Zachen Parteg\nZū Brixen. so soliche\nAūfzūladen. so solle\nZū khaūffen. soll alle\nAūfzūladen. so alle\nZū sein Aignen Waiz. vnd\ngemeltem KornPlaz.\nals ainem Frl: Lehen¬\naūsgeschen. vnd fir khaūff\nZūschehen.',
  'Aūf das: von Ir Gn: hln\nLanndthaūbtmann an\nder Etsch etc. Abganng\ner Etsch etc. Abganng'])

In [19]:
gt, predicts

(['226'], ['62'])

In [14]:
gt, predicts

(['Become a success with a disc and hey presto ! You \'re a star.... Rolly sings with\nassuredness " Bella Bella Marie " ( Parlophone ), a lively song that changes tempo mid-way.\nI don\'t think he will storm the charts with this one, but it \'s a good start.\nCHRIS CHARLES, 39, who lives in Stockton-on-Tees, is an accountant.'],
 ['Become a success with a disc and hey presto ! You \'re a star... Peolly sings with\nassuredness " Bella Bella Marie " ( Parlophone ), a lively song that changes tempo mid-way\nI don\'t think he will storm the charts with this one, but it \'s a good start.\nCherr chiefs, 3 , who has in Station. over. There, is an accountant.'])

In [10]:
len(gt)

100

In [11]:
len(predicts)

100

In [12]:
import string
import unicodedata
import editdistance
import numpy as np


def ocr_metrics(predicts, ground_truth, norm_accentuation=False, norm_punctuation=False):
    """Calculate Character Error Rate (CER), Word Error Rate (WER) and Sequence Error Rate (SER)"""

    if len(predicts) == 0 or len(ground_truth) == 0:
        return (1, 1, 1)

    cer, wer, ser = [], [], []

    for (pd, gt) in zip(predicts, ground_truth):
#         pd, gt = pd.lower(), gt.lower()

        if norm_accentuation:
            pd = unicodedata.normalize("NFKD", pd).encode("ASCII", "ignore").decode("ASCII")
            gt = unicodedata.normalize("NFKD", gt).encode("ASCII", "ignore").decode("ASCII")

        if norm_punctuation:
            pd = pd.translate(str.maketrans("", "", string.punctuation))
            gt = gt.translate(str.maketrans("", "", string.punctuation))

        pd_cer, gt_cer = list(pd), list(gt)
        dist = editdistance.eval(pd_cer, gt_cer)
        cer.append(dist / (max(len(pd_cer), len(gt_cer))))

        pd_wer, gt_wer = pd.split(), gt.split()
        dist = editdistance.eval(pd_wer, gt_wer)
        wer.append(dist / (max(len(pd_wer), len(gt_wer))))

        pd_ser, gt_ser = [pd], [gt]
        dist = editdistance.eval(pd_ser, gt_ser)
        ser.append(dist / (max(len(pd_ser), len(gt_ser))))

    metrics = [cer, wer, ser]
    metrics = np.mean(metrics, axis=1)

    return metrics

In [14]:
ocr_metrics(predicts, gt)

array([0.21861598, 0.34708263, 0.99404762])

In [30]:
ocr_metrics(predicts, gt)

array([0.01626205, 0.03218362, 0.54166667])

In [21]:
ocr_metrics(predicts, gt) #all pretrain finetune iam

array([0.03757567, 0.06686977, 0.83630952])

In [13]:
ocr_metrics(predicts, gt) #all pretrain finetune rimes

array([0.33396599, 0.47621471, 1.        ])