In [1]:
from pathlib import Path
import numpy as np
import math
import h5py
import numpy as np
import unicodedata
import cv2
import torch
from torch import nn
from torch.utils.data import Dataset
import time
import timm
import random
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
import itertools
import string
from torch.autograd import Variable
from tqdm.autonotebook import tqdm

def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

set_random_seeds(random_seed=13)


class CFG:
    debug = False
    batch_size = 200
    num_workers = 6
    head_lr = 0.0006
    image_encoder_lr = 0.0001
    text_encoder_lr = 0.0001
    weight_decay = 1e-3
    patience = 5
    factor = 0.8
    epochs = 200
    device = torch.device("cuda:1")

    image_embedding = 2048
    text_embedding = 300
    max_length = 30

    pretrained = True # for both image encoder and text encoder
    trainable = True # for both image encoder and text encoder
    temperature = 1.0

    # for projection head; used for both image and text encoders
    num_projection_layers = 1
    projection_dim = 256 
    dropout = 0.1
    
class Tokenizer():
    """Manager tokens functions and charset/dictionary properties"""

    def __init__(self, chars, max_text_length=CFG.max_length):
        self.PAD_TK, self.UNK_TK,self.SOS,self.EOS = "¶", "¤", "SOS", "EOS"
        self.chars = [self.PAD_TK] + [self.UNK_TK ]+ [self.SOS] + [self.EOS] +list(chars)
        self.PAD = self.chars.index(self.PAD_TK)
        self.UNK = self.chars.index(self.UNK_TK)

        self.vocab_size = len(self.chars)
        self.maxlen = max_text_length

    def encode(self, text):
        """Encode text to vector"""
        text = text.decode("utf-8") 
        text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII")
#         text = " ".join(text.split())

#         groups = ["".join(group) for _, group in groupby(text)]
#         text = "".join([self.UNK_TK.join(list(x)) if len(x) > 1 else x for x in groups])
        encoded = []

        text = ['SOS'] + list(text) + ['EOS']
        for item in text:
            index = self.chars.index(item)
            index = self.UNK if index == -1 else index
            encoded.append(index)

        return np.asarray(encoded)

    def decode(self, text):
        """Decode vector to text"""
        
        decoded = "".join([self.chars[int(x)] for x in text if x > -1])
        decoded = self.remove_tokens(decoded)
        decoded = pp.text_standardize(decoded)

        return decoded

    def remove_tokens(self, text):
        """Remove tokens (PAD) from text"""

        return text.replace(self.PAD_TK, "").replace(self.UNK_TK, "")

class DataGenerator(Dataset):
    """Generator class with data streaming"""

    def __init__(self, source, split, transform, tokenizer):
        self.tokenizer = tokenizer
        self.transform = transform
        
        self.split = split
        self.dataset = dict()

#         self.dataset = h5py.File(source, "r")
        with h5py.File(source, "r") as f:
            self.dataset[self.split] = dict()

            self.dataset[self.split]['dt'] = np.array(f[self.split]['dt'])
            self.dataset[self.split]['gt'] = np.array(f[self.split]['gt'])
#             self.dataset[self.split]['label'] = np.array(f[self.split]['label'])            
          
#             randomize = np.arange(len(self.dataset[self.split]['gt']))
#             np.random.seed(42)
#             np.random.shuffle(randomize)

#             self.dataset[self.split]['dt'] = self.dataset[self.split]['dt'][randomize]
#             self.dataset[self.split]['gt'] = self.dataset[self.split]['gt'][randomize]
#         print(self.dataset[self.split]['gt'].shape)
    
        self.size = len(self.dataset[self.split]['gt'])


    def __getitem__(self, i):
        img = self.dataset[self.split]['dt'][i]
        #making image compatible with resnet
#         img = cv2.transpose(img)
#         img = np.repeat(img[..., np.newaxis],3, -1).astype("float32")   
#         img = pp.normalization(img).astype("float32")

        if self.transform is not None:
            aug = self.transform(image=img)
            img = aug['image']
            
            
#             img = self.transform(img)
        y_train = self.tokenizer.encode(self.dataset[self.split]['gt'][i].lower()) 
#         print(self.dataset[self.split]['gt'][i])
#         print(len(self.dataset[self.split]['gt'][i]))
#         if len(y_train)==0:
#             asdas
#         print(y_train)
#         print()
        #padding till max length
        y_train = np.pad(y_train, (0, self.tokenizer.maxlen - len(y_train)))
#         if all(y_train==0):
#             print(self.dataset[self.split]['gt'][i])
#             print("afdas")
#             ssa
        gt = torch.Tensor(y_train)
#         label = self.dataset[self.split]['label'][i]
        label = 1        
        if label==0:
            label = -1
            
        return img, gt,label         

    def __len__(self):
      return self.size

charset_base = string.printable[:95]
tokenizer = Tokenizer(charset_base)
class LabelSmoothing(nn.Module):
    "Implement label smoothing."
    def __init__(self, size, padding_idx=0, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False)
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        
    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))

criterion = LabelSmoothing(size=tokenizer.vocab_size, padding_idx=0, smoothing=.1)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=CFG.max_length):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)
    
    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x
    
cos_loss = nn.CosineEmbeddingLoss(reduction="mean", margin=.5)
class Clip(nn.Module):

    def __init__(self, 
                 tokenizer,
                 temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
):
        super().__init__()
    
#         self.backbone = resnet101(pretrained=args.pretrained)
        self.backbone = timm.create_model(
                    "resnest26d", True, num_classes=0, global_pool="avg"
                )
#         for p in self.backbone.parameters():
#             p.requires_grad = True
            
        self.tokenizer = tokenizer
        self.embeding = nn.Embedding(self.tokenizer.vocab_size,CFG.text_embedding)
        self.conv1 = nn.Conv1d(CFG.text_embedding,32,8)        
        self.gelu = nn.GELU()
        
#         self.pos_encoding = PositionalEncoding(CFG.text_embedding, .2)
            
#         encoder_layer = nn.TransformerEncoderLayer(d_model=CFG.text_embedding, nhead=4, dropout=.2)
#         self.text_encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        
        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0
        self.temperature = temperature
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=736)
        
    def make_len_mask(self, inp):
        return (inp == 0)
            
    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask==1, float('-inf'))
        return mask
        
    def forward(self, img, input_ids, check):        
        
        image_features = self.backbone.forward(img)        
#         with open("das.sa","a") as f:
#             f.write(str(pad_mask.sum(0)))
                
        input_ids = self.embeding(input_ids)
        text_features = self.gelu(self.conv1(input_ids.permute(0,2,1)))
        text_features = text_features.flatten(1)
        
        
        
#         input_ids = self.pos_encoding(input_ids.permute(1,0,2))        
#         input_ids = input_ids.permute(1,0,2)
#         last_hidden_state = self.text_encoder(input_ids)
#         print()
#         print(last_hidden_state[:, self.target_token_idx, :])        
#         text_features = last_hidden_state[:, self.target_token_idx, :]
        
        # Getting Image and Text Embeddings (with same dimension)
        
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features) 
            
# #         print(image_features,text_features)
#         loss = cos_loss(image_embeddings, text_embeddings, check.to(CFG.device))
#         return loss
#         loss = 1-F.cosine_similarity(image_embeddings, text_embeddings)
#         Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
        return loss.mean()




In [12]:
source = "iam_aug"
source_path = '../data/{}.hdf5'.format(source)

In [2]:
import albumentations
import albumentations.pytorch

transform_train = albumentations.Compose([
    albumentations.OneOf(
        [
            albumentations.MotionBlur(p=1, blur_limit=8),
            albumentations.OpticalDistortion(p=1, distort_limit=0.05),
            albumentations.GaussNoise(p=1, var_limit=(10.0, 100.0)),
            albumentations.RandomBrightnessContrast(p=1, brightness_limit=0.2),
            albumentations.Downscale(p=1, scale_min=0.8, scale_max=.9),
        ],
        p=.5,
    ),
#         albumentations.Resize(224,224),
    albumentations.Normalize(),
    albumentations.pytorch.ToTensorV2()

])

transform_valid = albumentations.Compose(
    [
#         albumentations.Resize(224,224),            
        albumentations.Normalize(),
        albumentations.pytorch.ToTensorV2()
    ]
)

In [4]:
model = Clip(tokenizer).to(CFG.device)

In [6]:
# d = torch.load("../output/clip_entropy",map_location="cuda:0")

d = torch.load("../output/clip_entropy_26iam_aug_epoch10.pt",map_location="cpu")
_ = model.load_state_dict(d)

In [7]:
source = "../data/iam_aug.hdf5"
split = "test"
dataset = {}
with h5py.File(source, "r") as f:
    dataset[split] = dict()
    dataset[split]['dt'] = np.array(f[split]['dt'])
    dataset[split]['gt'] = np.array(f[split]['gt'])


In [8]:
img = dataset[split]['dt'][49]
# img = np.repeat(img[..., np.newaxis],3, -1).astype("float32")   
aug = transform_valid(image=img)
img = aug['image']

In [8]:
from collections import Counter

c = Counter(dataset[split]['gt'])

mapping = {}
for i in c.keys():
    mapping[i] = np.where(dataset[split]['gt']==i)[0]

In [9]:
from PIL import Image

In [13]:
valid_loader = torch.utils.data.DataLoader(DataGenerator(source_path,'test',transform_valid, tokenizer), batch_size=300, shuffle=False, num_workers=1)

In [11]:
_=model.eval()

In [14]:
CFG.device = "cpu"

In [15]:
valid_image_embeddings = []
for img, input_ids,check in valid_loader:
    with torch.no_grad():
        image_features = model.backbone(img.to(CFG.device))
        image_embeddings = model.image_projection(image_features)
        image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
        valid_image_embeddings.append(image_embeddings)        
        

In [16]:
valid_image_embeddings = [i.cpu().numpy() for i in valid_image_embeddings]

In [17]:
valid_image_embeddings = np.concatenate(valid_image_embeddings)

In [18]:
scores = []
ct=0
for i in range(len(valid_image_embeddings)):    
    label = dataset[split]['gt'][i]    
    if len(mapping[label])==1:
        continue
    c = valid_image_embeddings[i] @ valid_image_embeddings.T
    score = len(set(np.argsort(c)[::-1][:len(mapping[label])]).intersection(mapping[label]))/len(mapping[label])
    scores.append(score)
#     if score<.6:
#         print(label, i)
    ct+=1

In [19]:
sum(scores)/len(scores)

0.8406309541642519

In [42]:
sum(scores)/len(scores)

0.9806200302718646

In [51]:
sum(scores)/len(scores)

0.9806059417962377

In [39]:
sum(scores)/len(scores)

0.9854132457580715

In [15]:
sum(scores)/len(scores)

0.9849278977727236

In [20]:
t_embeddings = []
for key in mapping.keys():
    encoded_query = tokenizer.encode(key)
    encoded_query = np.pad(encoded_query, (0, 30 - len(encoded_query)))
    # encoded_query = tokenizer(["dsad dsa dasdas"],padding='max_length', truncation=True, max_length=CFG.max_length)
    with torch.no_grad():
        a = model.embeding(torch.Tensor(encoded_query).unsqueeze(0).long().to(CFG.device))
        text_features2 = model.gelu((model.conv1(a.permute(0,2,1))))
        text_features2 = text_features2.flatten(1)    
        text_embeddings = model.text_projection(text_features2)  
        text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
        text_embeddings = text_embeddings[0].cpu().numpy()
        t_embeddings.append(text_embeddings)

In [16]:
t_embeddings = []
for key in mapping.keys():
    encoded_query = tokenizer.encode(key)
    encoded_query = np.pad(encoded_query, (0, 30 - len(encoded_query)))
    # encoded_query = tokenizer(["dsad dsa dasdas"],padding='max_length', truncation=True, max_length=CFG.max_length)
    with torch.no_grad():
        input_ids = model.embeding(torch.Tensor(encoded_query).unsqueeze(0).long().to(CFG.device))
        input_ids = model.pos_encoding(input_ids.permute(1,0,2))        
        input_ids = input_ids.permute(1,0,2)
        last_hidden_state = model.text_encoder(input_ids)
        text_features = last_hidden_state[torch.arange(last_hidden_state.shape[0]), encoded_query.argmax()]
        
        # Getting Image and Text Embeddings (with same dimension)
        
        text_embeddings = model.text_projection(text_features)         
        text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
        text_embeddings = text_embeddings[0].cpu().numpy()
        t_embeddings.append(text_embeddings)

AttributeError: 'Clip' object has no attribute 'pos_encoding'

In [21]:
scores = []
ct=0
for i,label in enumerate(mapping.keys()):    
    if len(mapping[label])==1:
        continue
    c = t_embeddings[i] @ valid_image_embeddings.T
    score = len(set(np.argsort(c)[::-1][:len(mapping[label])]).intersection(mapping[label]))/len(mapping[label])
    scores.append(score)
    if score<.6:
        print(label, i)
    ct+=1

b'writer' 25
b'wrote' 32
b'likes' 40
b'mother' 58
b'chair' 77
b'normally' 93
b'seen' 98
b'courtenay' 103
b'regarded' 112
b't v' 114
b'potter' 127
b'wide' 148
b'arrived' 187
b'fight' 201
b'telephone' 202
b'textual' 249
b'explanations' 250
b'fell' 257
b'book' 340
b'13' 352
b'use' 373
b'writing' 380
b'cut' 387
b'reservoirs' 416
b'corner' 469
b'rock' 474
b'security' 488
b'impression' 498
b'created' 512
b'diligently' 517
b'1ye' 574
b'greatest' 577
b'fruit' 584
b'cure' 618
b'gives' 640
b'meek' 646
b'thank' 671
b'though' 698
b'rome' 701
b'king' 725
b'came' 746
b'battle' 765
b'necessarily' 775
b'uses' 786
b'reverence' 793
b'fear' 842
b'cup' 868
b'bed' 899
b'night' 907
b'danger' 914
b'suppose' 923
b'tired' 924
b'journey' 926
b'seated' 949
b'illness' 966
b'feared' 975
b'comes' 980
b'lived' 983
b'months' 999
b'desperately' 1022
b'conscious' 1027
b'un' 1043
b'save' 1047
b'jobs' 1053
b'next' 1066
b'drift' 1069
b'reached' 1070
b'becomes' 1080
b'interested' 1083
b'toes' 1111
b'words' 1171
b'continue'

In [22]:
sum(scores)/len(scores)

0.8785374417963444

In [54]:
sum(scores)/len(scores)

0.9371581622247703

In [21]:
sum(scores)/len(scores)

0.8335253261688096

In [42]:
sum(scores)/len(scores)

0.9918401742903649

In [35]:
sum(scores)/len(scores)

0.9854395718807263

In [18]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(model)

+---------------------------------------+------------+
|                Modules                | Parameters |
+---------------------------------------+------------+
|        backbone.conv1.0.weight        |    864     |
|        backbone.conv1.1.weight        |     32     |
|         backbone.conv1.1.bias         |     32     |
|        backbone.conv1.3.weight        |    9216    |
|        backbone.conv1.4.weight        |     32     |
|         backbone.conv1.4.bias         |     32     |
|        backbone.conv1.6.weight        |   18432    |
|          backbone.bn1.weight          |     64     |
|           backbone.bn1.bias           |     64     |
|     backbone.layer1.0.conv1.weight    |    4096    |
|      backbone.layer1.0.bn1.weight     |     64     |
|       backbone.layer1.0.bn1.bias      |     64     |
|  backbone.layer1.0.conv2.conv.weight  |   36864    |
|   backbone.layer1.0.conv2.bn0.weight  |    128     |
|    backbone.layer1.0.conv2.bn0.bias   |    128     |
|   backbo

15972804

In [22]:
15972804/1e6

15.972804