In [1]:
import os
import random
import functools
from functools import partial
import PIL

import numpy as np 
import pandas as pd

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

import timm

from transformers import AutoTokenizer
from transformers import DistilBertModel

np.random.seed(1337)

In [2]:
df = pd.read_csv('data/train.csv')

In [3]:
# train val split
labels = np.random.permutation(df['label_group'].unique())

train_perc = 0.7
train_idx = int(train_perc * len(labels))

train_labels = labels[:train_idx]
val_labels = labels[train_idx:]

train_df = df[df['label_group'].isin(train_labels)]
val_df = df[df['label_group'].isin(val_labels)]

In [4]:

images_path = 'data/train_images/'
image_ids = [s.split('.')[0] for s in os.listdir(images_path)]
image_ids[:10]

['a756327777b2fe8e9383ab84468f3e5a',
 'c58ef3697a5ae269fc16a6d3d4b56877',
 '05f1bc3c03271e5d4b5af7a7b263facf',
 '2826339257579760a00d6e0a0065a89e',
 '451b108436f28c18b2dc8bf63b712c08',
 '2f5c66fc9bb86bc81dd461cc6ad574e1',
 'cbd375037243af186a4aa8d3aa08a489',
 '4edcfc239c17445eb949103c9f13eed3',
 '67e712eb6d5b08c0eb9e074cd7cc71c4',
 '712782879a3ef3acbb0c8eb945335528']

In [5]:
class TripletDataset(Dataset) :
    def __init__(self, images_path, df, img_tfms, testing, text_tokenizer=None):
        super(TripletDataset, self).__init__()
        
        self.images_path = images_path
        self.img_tfms = img_tfms
        self.testing = testing
              
        self.df = df.copy()
        self.df['label_group'] = self.df['label_group'].astype('category').cat.codes
        self.df['index'] = range(self.df.shape[0])
        self.labels = self.df['label_group'].unique()
        self.label_to_index_list = self.df.groupby('label_group')['index'].apply(list)
        
    def __getitem__(self, index) :
        index_meta = self.df.iloc[index]
        
        anchor_image, anchor_text = self._get_item(index)
        
        if self.testing: return anchor_image, anchor_text
        
        label = index_meta['label_group']
        
        # positive sample
        pos_index = random.choice(self.label_to_index_list[label])
        # we don't want the positive sample being the same as the anchor
        while pos_index == index :
            pos_index = random.choice(self.label_to_index_list[label])
        pos_image, pos_text = self._get_item(pos_index)
        
        #negative sample
        neg_label = random.choice(self.labels)
        # Negative sample has to be different label from anchor 
        while neg_label == index :
            neg_label = random.choice(self.labels)
        neg_index = random.choice(self.label_to_index_list[neg_label])
        neg_image, neg_text = self._get_item(neg_index)
        
        return anchor_image, anchor_text, pos_image, pos_text, neg_image, neg_text
        
    def _get_item(self, index) :
        image = PIL.Image.open(os.path.join(self.images_path, 
                                            self.df.iloc[index]['image']))
        image = self.img_tfms(image)
        text = self.df.iloc[index]['title']
        return image, text
    
    def __len__(self) :
        return self.df.shape[0]

In [6]:
def collate_fn(tokenizer, samples) :
    batch_size = len(samples)
    if len(samples[0]) == 2:
        images, texts = zip(*samples)
        images = torch.stack(images)
        texts = tokenizer(list(texts), padding=True, truncation=True, return_tensors="pt")
        return images, texts
    anchor_images, anchor_texts, pos_images, pos_texts, neg_images, neg_texts = zip(*samples)
    anchor_images = torch.stack(anchor_images)
    pos_images = torch.stack(pos_images)
    neg_images = torch.stack(neg_images)
    anchor_texts = tokenizer(list(anchor_texts), padding=True, truncation=True, return_tensors="pt")
    pos_texts = tokenizer(list(pos_texts), padding=True, truncation=True, return_tensors="pt")
    neg_texts = tokenizer(list(neg_texts), padding=True, truncation=True, return_tensors="pt")
    return anchor_images, anchor_texts, pos_images, pos_texts, neg_images, neg_texts

In [7]:
def create_dl(images_path, df_paths, img_tfms, pretrianed_tokenizer='distilbert-base-uncased', 
              batch_size=64, shuffle = True, testing = False) :
    dataset = TripletDataset(images_path, df_paths, img_tfms, testing)
    tokenizer = AutoTokenizer.from_pretrained(pretrianed_tokenizer)
    dl = DataLoader(dataset, batch_size=batch_size, collate_fn=partial(collate_fn, tokenizer), 
                    shuffle = shuffle, pin_memory = True)
    return dl

In [8]:
class EmbedorNN(nn.Module) :
    def __init__(self, pretrained_image_embedor='resnet18', pretrained_text_embedor='distilbert-base-uncased',
                output_dim=128) :
        super(EmbedorNN, self).__init__()
        self.image_embedor = timm.create_model(pretrained_image_embedor, pretrained=True)
        self.image_pool = nn.AdaptiveAvgPool2d((1,1))
        self.text_embedor = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.head = nn.Linear(512+768, output_dim)
    
    def forward(self, x) :
        images, texts = x
        out_images = self.image_embedor.forward_features(images)
        out_images = self.image_pool(out_images).squeeze()
        out_text = self.text_embedor(texts['input_ids'], 
                                     attention_mask=texts['attention_mask'])[0][:,0,:]
        out = torch.cat([out_images, out_text], dim=-1)
        return F.normalize(self.head(out), dim=-1)

In [9]:
device = torch.device('cuda')

In [10]:
cls = nn.Sequential(nn.Linear(256, 30), nn.ReLU(), nn.Dropout(0.1), nn.Linear(30,1)).to(device)

In [11]:
def initweights(module):
    """ Initialize the weights """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        sz = module.weight.data.size(-1)
        module.weight.data.normal_(mean=0.0, std=1/np.sqrt(sz))
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero()
        module.weight.data.fill(1.0)
    elif isinstance(module, nn.Conv2d):
        n = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
        module.weight.data.normal_(0, math.sqrt(2. / n))
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

In [12]:
for m in cls.parameters():
    initweights(m)

In [13]:
normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))

train_transforms = transforms.Compose([transforms.Resize((250, 250)),
                                       transforms.ColorJitter(.25,.25,.25),
                                       transforms.RandomRotation(5),
                                       transforms.RandomCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       normalize
                                       ])

val_transforms = transforms.Compose([transforms.Resize((224,224)),
                                     transforms.ToTensor(),
                                     normalize
                                     ])

In [14]:
model = EmbedorNN().to(device)
# weights: https://drive.google.com/drive/folders/19BGTC53p5YIAakWeCZfNIsceMX7cjHgp?usp=sharing
#model.load_state_dict(torch.load('3ep.pth'))

In [15]:
tr_dl = create_dl(images_path, train_df, train_transforms, shuffle = True)

In [16]:
val_dl = create_dl(images_path, val_df, val_transforms, shuffle = False)

In [17]:
testing_dl = create_dl(images_path, df, val_transforms, shuffle = False, testing = True)

In [17]:
n_epochs = 5
swa_start = int(0.75*n_epochs)

lf = nn.TripletMarginLoss()
lf2 = nn.BCEWithLogitsLoss()

lr = 1e-4
wd = 1e-3
no_decay = ["bias", "BatchNorm2d.weight", "BatchNorm2d.bias", "LayerNorm.weight", 'LayerNorm.bias']

optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)] + [p for n, p in cls.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": wd,
    },
    {
        "params": [p for n, p in  model.named_parameters() if any(nd in n for nd in no_decay)] + [p for n, p in  cls.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]


In [18]:
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)

# learning rate scheduler
sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr =lr, pct_start = 0.3, #anneal_strategy = 'linear',
                                            total_steps = int(n_epochs * len(tr_dl)))

scaler = torch.cuda.amp.GradScaler()

In [19]:
tr_triplet_losses = []
tr_bce_losses = []
val_triplet_losses = []
val_bce_losses = []
for ep in tqdm(range(n_epochs)):
    cls.train()
    model.train()
    tr_triplet_loss = []
    tr_bce_loss = []
    pbar = tqdm(tr_dl)
    for anchor_image, anchor_text, pos_image, pos_text, neg_image, neg_text in pbar:
        
        
        anchor = anchor_image.to(device), {'input_ids' : anchor_text['input_ids'].to(device),
                                           'attention_mask' : anchor_text['attention_mask'].to(device)}
        
        pos = pos_image.to(device), {'input_ids' : pos_text['input_ids'].to(device),
                                     'attention_mask' : pos_text['attention_mask'].to(device)}
        
        neg = neg_image.to(device), {'input_ids' : neg_text['input_ids'].to(device),
                                     'attention_mask' : neg_text['attention_mask'].to(device)}
        
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            anchor_emb = model(anchor)
            pos_emb = model(pos)
            neg_emb = model(neg)
            triplet_loss = lf(anchor_emb, pos_emb, neg_emb)
            
            pos_samples_1 = torch.cat([anchor_emb, pos_emb],-1)
            pos_samples_2 = torch.cat([pos_emb, anchor_emb],-1)

            neg_samples_11 = torch.cat([anchor_emb, neg_emb],-1)
            neg_samples_12 = torch.cat([neg_emb, anchor_emb],-1)

            neg_samples_21 = torch.cat([pos_emb, neg_emb],-1)
            neg_samples_22 = torch.cat([neg_emb, pos_emb],-1)

            x = torch.cat([pos_samples_1, pos_samples_2, 
                           neg_samples_11, neg_samples_12,
                           neg_samples_21, neg_samples_22],0)

            y = torch.cat([torch.ones(2*len(pos_samples_1)), 
                               torch.zeros(2*(len(neg_samples_11) 
                                              + len(neg_samples_21)))],0)[:,None].to(device)
            out = cls(x)
            bce_loss = lf2(out,y)
            
        scaler.scale(triplet_loss + bce_loss*0.2).backward()
        scaler.step(optimizer)
        scaler.update()
        sched.step()
        
        tr_triplet_loss.append(triplet_loss.item())
        tr_bce_loss.append(bce_loss.item())
        pbar.set_description(f"Tr triplet: {round(np.mean(tr_triplet_loss),3)}, bce: {round(np.mean(tr_bce_loss),3)}")
        
    model.eval()
    cls.eval()
    val_triplet_loss = []
    val_bce_loss = []
    with torch.no_grad():
        pbar = tqdm(val_dl)
        for anchor_image, anchor_text, pos_image, pos_text, neg_image, neg_text in pbar:

            anchor = anchor_image.to(device), {'input_ids' : anchor_text['input_ids'].to(device),
                                               'attention_mask' : anchor_text['attention_mask'].to(device)}

            pos = pos_image.to(device), {'input_ids' : pos_text['input_ids'].to(device),
                                         'attention_mask' : pos_text['attention_mask'].to(device)}

            neg = neg_image.to(device), {'input_ids' : neg_text['input_ids'].to(device),
                                         'attention_mask' : neg_text['attention_mask'].to(device)}

            with torch.cuda.amp.autocast():
                
                anchor_emb = model(anchor)
                pos_emb = model(pos)
                neg_emb = model(neg)
                triplet_loss = lf(anchor_emb, pos_emb, neg_emb)
                
                pos_samples_1 = torch.cat([anchor_emb, pos_emb],-1)
                pos_samples_2 = torch.cat([pos_emb, anchor_emb],-1)
                
                neg_samples_11 = torch.cat([anchor_emb, neg_emb],-1)
                neg_samples_12 = torch.cat([neg_emb, anchor_emb],-1)
                
                neg_samples_21 = torch.cat([pos_emb, neg_emb],-1)
                neg_samples_22 = torch.cat([neg_emb, pos_emb],-1)
                
                x = torch.cat([pos_samples_1, pos_samples_2, 
                               neg_samples_11, neg_samples_12,
                               neg_samples_21, neg_samples_22],0)
                
                y = torch.cat([torch.ones(2*len(pos_samples_1)), 
                               torch.zeros(2*(len(neg_samples_11) 
                                              + len(neg_samples_21)))],0)[:,None].to(device)
                out = cls(x)
                bce_loss = lf2(out,y)

            val_triplet_loss.append(triplet_loss.item())
            val_bce_loss.append(bce_loss.item())
            pbar.set_description(f"Val triplet: {round(np.mean(val_triplet_loss),3)}, bce: {round(np.mean(val_bce_loss),3)}")
            
    tr_triplet_loss = round(np.mean(tr_triplet_loss),3)
    val_triplet_loss = round(np.mean(val_triplet_loss),3)
    tr_triplet_losses.append(tr_triplet_loss)
    val_triplet_losses.append(val_triplet_loss)
    
    tr_bce_loss = round(np.mean(tr_bce_loss),3)
    val_bce_loss = round(np.mean(val_bce_loss),3)
    tr_bce_losses.append(tr_bce_loss)
    val_bce_losses.append(val_bce_loss)
    summary = f"Ep {ep}: Tr triplet {tr_triplet_loss}, bce {tr_bce_loss} - Val triplet {val_triplet_loss}, bce {val_bce_loss}"
    print(summary) 
    
    

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/161 [00:00<?, ?it/s]

Ep 0: Tr triplet 0.076, bce 0.659 - Val triplet 0.022, bce 0.632


  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/161 [00:00<?, ?it/s]

Ep 1: Tr triplet 0.021, bce 0.603 - Val triplet 0.024, bce 0.542


  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/161 [00:00<?, ?it/s]

Ep 2: Tr triplet 0.018, bce 0.497 - Val triplet 0.017, bce 0.417


  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/161 [00:00<?, ?it/s]

Ep 3: Tr triplet 0.011, bce 0.392 - Val triplet 0.013, bce 0.348


  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/161 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [17]:
torch.save(model.state_dict(), '10ep.pth')

In [20]:
embs = []
model.eval()
with torch.no_grad():
    pbar = tqdm(testing_dl)
    for image, text in pbar:
        x = image.to(device), {'input_ids' : text['input_ids'].to(device),
                               'attention_mask' : text['attention_mask'].to(device)}
        y = model(x)
        embs.append(y.cpu())

  0%|          | 0/536 [00:00<?, ?it/s]

In [18]:
embs = torch.cat(embs,0)

In [19]:
embs_df = pd.DataFrame(embs.numpy())
emb_cols = [f'emb_{i}' for i in embs_df.columns]
embs_df.columns = emb_cols
embs_df['cls'] = df['label_group']
embs_df['cls'] = embs_df['cls'].astype('category').cat.codes

In [21]:
embs_df.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,34240,34241,34242,34243,34244,34245,34246,34247,34248,34249
emb_0,0.500960,0.237719,0.050669,0.718747,-0.581001,-0.097971,1.183702,1.779962,1.124665,-0.972785,...,-1.564198,0.842688,-0.393273,-0.365358,-0.153456,-0.877197,0.058753,-0.112191,0.031476,0.074278
emb_1,-1.301394,-1.292176,-0.944553,-1.888656,-0.108013,-0.747545,-1.263595,-1.211068,-0.682983,0.179870,...,-0.576654,-0.905915,-0.449244,-0.466996,-0.804786,-0.308678,0.259296,-0.835542,-0.930508,-1.531453
emb_2,-0.226438,0.306056,0.002641,0.311137,-0.544753,-0.013903,0.166399,0.306137,0.915265,-0.525376,...,0.455988,0.210848,-0.447949,-0.411966,0.376221,-0.085963,0.045716,0.809573,-0.265023,0.909191
emb_3,-0.323761,0.608568,-0.098363,-1.106508,0.221210,-1.320947,-0.539112,-1.217819,0.357320,-0.854563,...,0.597615,-0.977988,1.731781,1.720604,0.479843,0.006830,0.526009,-0.228569,0.531991,0.439010
emb_4,-1.258861,-0.940595,0.709853,-0.407680,0.241497,-1.016383,-0.178965,-0.405246,0.401834,-0.620351,...,-0.869939,-1.265399,-0.175623,-0.220973,-0.101543,-1.238433,-0.110854,-0.539160,0.548752,-1.179566
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
emb_124,1.327470,0.503480,0.680511,0.594233,-0.113412,-0.175654,0.622813,0.947124,0.934831,-0.190990,...,-0.169104,0.302227,1.240302,1.217515,-0.716665,-0.396859,0.584340,-0.777377,0.648708,-0.493271
emb_125,-1.127460,-0.561340,0.595688,-0.708277,-0.253344,-1.247321,-0.849796,-1.481531,0.011094,-1.126404,...,0.344415,-0.091109,1.479870,1.414771,-0.354530,0.216521,-0.839207,-1.230204,0.316971,-1.293721
emb_126,0.465477,-1.320146,0.288588,0.464027,0.089630,-0.426252,-0.943907,0.281021,0.681009,0.782818,...,1.783656,0.248377,0.240647,0.235544,0.519699,-0.012075,0.809881,1.609108,-0.025901,-0.320838
emb_127,1.034373,-0.168213,-0.304716,-0.008191,0.086427,1.383384,0.928859,0.351101,0.364785,0.460414,...,-0.302482,-0.573974,0.280889,0.266239,-0.264183,-0.374306,1.208320,-0.497938,-0.462315,0.550078


In [22]:
embs_df.to_csv('train_embs.csv')