In [1]:
import os
import random
import functools
from functools import partial
import PIL

import numpy as np 
import pandas as pd

from tqdm.notebook import tqdm
import math

import torch
from torch.nn.parameter import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

import timm

In [2]:
class SoftMaxDS(Dataset):
    """Image-classification dataset yielding ``(image_tensor, label_code)`` pairs.

    Labels are the ``label_group`` column re-encoded as pandas category codes
    (0..K-1 over the distinct groups present in ``data``), suitable as targets
    for ``nn.CrossEntropyLoss`` / the ArcFace head.

    Args:
        data: DataFrame with an ``image`` column (file names relative to
            ``images_path``) and a ``label_group`` column.
        images_path: directory containing the image files.
        return_triplet: accepted for backward compatibility; currently unused.
    """

    def __init__(self, data, images_path, return_triplet = True):
        super().__init__()
        self.imgs = data['image'].tolist()
        self.unique_labels = data['label_group'].unique().tolist()
        self.labels = data['label_group'].astype('category')
        # Integer code per row, in [0, number of distinct label groups).
        self.label_codes = self.labels.cat.codes

        self.images_path = images_path

    def __getitem__(self, idx):
        img = self._get_item(idx)
        label = self.label_codes.iloc[idx]
        return img, label

    def __len__(self):
        return len(self.imgs)

    def _get_item(self, idx):
        """Load image ``idx`` as a float32 (3, H, W) tensor scaled to [0, 1]."""
        path = os.path.join(self.images_path, self.imgs[idx])
        # `with` closes the file handle deterministically. convert('RGB') forces
        # 3 channels, so grayscale/RGBA/palette images don't break the permute
        # below or the 3-channel Normalize applied downstream.
        with PIL.Image.open(path) as im:
            arr = np.asarray(im.convert('RGB'))
        return torch.tensor(arr / 255.0, dtype = torch.float).permute(2, 0, 1)

In [3]:
# load in data

DATA_CSV = 'data/train.csv'
small_images_dir = 'data/small_train_images/'
SEED = 1337

df = pd.read_csv(DATA_CSV)
n_classes = df['label_group'].nunique()

# Seed every RNG the notebook relies on. Previously only numpy was seeded, and
# nothing in this notebook consumed numpy randomness. torch's RNG drives
# DataLoader shuffling, the init of new layers, and the random torchvision
# transforms applied to tensors.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# train val split
# NOTE(review): the split is positional (first 70% of rows). It is neither
# shuffled nor grouped by label_group, so a group can straddle train/val.
# The dataset cells further down are built from `df`, not `train_df`.
train_perc = 0.7
n_train_examples = int(train_perc * len(df))

train_df = df.iloc[:n_train_examples]
val_df = df.iloc[n_train_examples:]

In [4]:
# creating dataloaders

vision_model = 'resnet50'

bs = 64
# NOTE(review): trains on the full `df`; `train_df` / `val_df` from the split
# cell are not used here — confirm this is intentional.
tr_ds = SoftMaxDS(df, small_images_dir)

# Fall back to CPU when no GPU is available instead of failing later on `.to(device)`.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
# Pinned host memory only benefits host->GPU copies.
tr_dl = DataLoader(tr_ds, batch_size = bs, shuffle = True, pin_memory = use_cuda)

In [5]:
class EMBRes(nn.Module) :
    """timm backbone + global average pooling + linear projection.

    Args:
        pretrained_image_embedor: timm model name; created with ``pretrained=True``.
        output_dim: size of the projected embedding returned by ``forward``.
    """
    def __init__(self, pretrained_image_embedor='resnet50',
                output_dim=512) :
        super(EMBRes, self).__init__()
        self.image_embedor = timm.create_model(pretrained_image_embedor, pretrained=True)
        self.image_pool = nn.AdaptiveAvgPool2d((1,1))
        # Channel count of the backbone's feature map (2048 for resnet50). Read it
        # from the timm model instead of hard-coding it so other backbones work too.
        in_dim = getattr(self.image_embedor, 'num_features', 2048)
        self.head = nn.Sequential(nn.Linear(in_dim, output_dim))
        self._init_head()

    def _init_head(self):
        """Init head layers: Linear weight ~ N(0, 1/fan_in) with zero bias; norm layers to identity."""
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                fan_in = m.weight.size(-1)
                nn.init.normal_(m.weight, mean=0.0, std=1/np.sqrt(fan_in))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _get_embs(self, x) :
        """Return the pooled backbone features, shape (N, C).

        NOTE: assumes ``forward_features`` returns an (N, C, H, W) map, as timm
        CNNs such as resnet50 do — TODO confirm for other backbone families.
        """
        out_images = self.image_embedor.forward_features(x)
        # flatten(1) rather than squeeze(): squeeze() would also drop the batch
        # dim when N == 1, returning (C,) instead of (1, C).
        return self.image_pool(out_images).flatten(1)

    def forward(self, x) :
        """Return the projected (not L2-normalised) embedding, shape (N, output_dim)."""
        return self.head(self._get_embs(x))

In [6]:
# https://github.com/ronghuaiyang/arcface-pytorch

class ArcMarginProduct(nn.Module):
    r"""ArcFace (additive angular margin) classification head.

    For every class ``j`` the logit is ``s * cos(theta_j)``, where ``theta_j`` is
    the angle between the L2-normalised input feature and the L2-normalised
    weight row of class ``j``. For the target class the logit becomes
    ``s * cos(theta_y + m)``.

    Args:
        in_features: size of each input sample
        out_features: number of classes
        s: scale applied to the logits (norm of the normalised features)
        m: additive angular margin, in radians
        easy_margin: if True, apply the margin only where cos(theta) > 0;
            otherwise, where theta + m would exceed pi, use ``cos(theta) - mm``.
    """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cos(theta) > th  <=>  theta + m < pi
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return margin-adjusted, scaled logits of shape (N, out_features).

        Args:
            input: (N, in_features) features; L2-normalised here, not by the caller.
            label: (N,) integer class indices.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # One-hot target mask, allocated on cosine's device and dtype. It was
        # previously hard-coded to device='cuda', which failed on CPU.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # Target class gets phi, all other classes keep the plain cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s


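***Embedding normalization is not done in the model itself but inside the ArcFace metric (`ArcMarginProduct` L2-normalizes both the input features and the class weights before taking their dot product).***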

In [7]:
# Backbone + 512-d projection (EMBRes default output_dim), and an ArcFace head
# with one class per distinct label_group in df (n_classes).
model = EMBRes(pretrained_image_embedor=vision_model).to(device)
metric_fc = ArcMarginProduct(in_features=512, out_features=n_classes,
                             s=30, m=0.5, easy_margin=False).to(device)

In [13]:

# Usual ImageNet mean/std.
normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))

# NOTE(review): these are applied to whole (N, C, H, W) batches in the training
# loop. Each random transform therefore samples its parameters once per batch,
# not once per image.
train_transforms = transforms.Compose([transforms.ColorJitter(.3,.3,.3),
                                       transforms.RandomRotation(5),
                                       transforms.RandomCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       normalize
                                       ])

val_transforms = transforms.Compose([normalize
                                     ])

n_epochs = 30

lf = nn.CrossEntropyLoss()

lr = 1e-2
wd = 0
no_decay = ["bias", "BatchNorm2d.weight", "BatchNorm2d.bias", "LayerNorm.weight", 'LayerNorm.bias',
            "BatchNorm1d.weight", "BatchNorm1d.bias"]


def _skip_decay(name, param):
    """True for params that should not get weight decay (biases, norm-layer affine params)."""
    # Parameter names are attribute paths (e.g. 'bn1.weight'), never class names
    # such as 'BatchNorm2d.weight', so the substring list alone only ever matched
    # 'bias'. 1-D params cover biases and all norm-layer weights.
    return param.ndim <= 1 or any(nd in name for nd in no_decay)


params = list(model.named_parameters()) + list(metric_fc.named_parameters())
# Both groups partition the same combined list. Previously the no-decay group
# scanned only `model`, so a no-decay param of `metric_fc` would have been left
# out of the optimizer entirely.
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in params if not _skip_decay(n, p)],
        "weight_decay": wd,
    },
    {
        "params": [p for n, p in params if _skip_decay(n, p)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)

# learning rate scheduler — stepped once per batch, hence total_steps = epochs * batches.
sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr =lr, pct_start = 0.3, #anneal_strategy = 'linear',
                                            total_steps = int(n_epochs * len(tr_dl)))

In [14]:
CKPT_DIR = 'data/tests_model_image'
os.makedirs(CKPT_DIR, exist_ok=True)  # torch.save does not create parent directories

tr_losses = []   # one list of per-batch losses per epoch
val_losses = []  # NOTE(review): never filled — no validation pass is run
for ep in tqdm(range(n_epochs)):
    model.train()
    tr_loss = []
    running_sum = 0.0
    pbar = tqdm(tr_dl)
    for imgs, labels in pbar:
        # Augmentations run on the device, after the host->device copy.
        imgs = train_transforms(imgs.to(device))
        labels = labels.long().to(device)

        optimizer.zero_grad()
        feature = model(imgs)
        out = metric_fc(feature, labels)
        loss = lf(out, labels)

        loss.backward()
        optimizer.step()
        sched.step()  # OneCycleLR is sized in batches, so step every batch

        tr_loss.append(loss.item())
        running_sum += tr_loss[-1]
        # Running mean so far, kept O(1) instead of re-averaging the whole list each batch.
        pbar.set_description(f"Train loss: {round(running_sum / len(tr_loss), 3)}")

    # Checkpoint the model (not metric_fc) every other epoch; ep is 0-based.
    if ep % 2 == 0:
        torch.save(model.state_dict(), f'{CKPT_DIR}/model_class_ep_{ep}.pth')
    model.eval()
    tr_losses.append(tr_loss)
    summary = f"Ep {ep}: Train loss {np.asarray(tr_loss).mean()}"
    print(summary)
    

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 0: Train loss 22.94508082831084


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 1: Train loss 19.762793202898397


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 2: Train loss 16.567221851491215


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 3: Train loss 13.847081554469778


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 4: Train loss 11.133125767778994


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 5: Train loss 9.719569367259297


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 6: Train loss 8.603367827721495


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 7: Train loss 7.66605324531669


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 8: Train loss 6.922346911323604


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 9: Train loss 6.200776666402817


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 10: Train loss 5.60735265932866


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 11: Train loss 5.128292756294137


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 12: Train loss 4.684561143170542


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 13: Train loss 4.259772528463335


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 14: Train loss 3.827615266860421


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 15: Train loss 3.4297792128662565


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 16: Train loss 3.036038393182541


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 17: Train loss 2.6709789286798507


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 18: Train loss 2.2973569568191


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 19: Train loss 1.9213387159936464


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 20: Train loss 1.6221876772704409


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 21: Train loss 1.2947686180027562


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 22: Train loss 1.0781045713420234


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 23: Train loss 0.8719725500681063


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 24: Train loss 0.7260605340906933


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 25: Train loss 0.5906273582255217


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 26: Train loss 0.5194121982707683


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 27: Train loss 0.46470308513033076


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 28: Train loss 0.44491417134708877


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 29: Train loss 0.42045613468638554



In [15]:
# Persist the weights after the final training epoch, next to the periodic checkpoints.
final_weights = model.state_dict()
torch.save(final_weights, 'data/tests_model_image/model_class_ep_30.pth')

In [17]:
# Same rows as training but iterated in file order (shuffle=False), so the
# concatenated embeddings line up row-for-row with df.
testing_ds = SoftMaxDS(data=df, images_path=small_images_dir)
testing_dl = DataLoader(dataset=testing_ds, batch_size=bs, shuffle=False, pin_memory=True)

In [18]:
# Same definitions as the training-config cell, re-declared here; of these,
# only val_transforms is used by the extraction cell below.
normalize = transforms.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

train_transforms = transforms.Compose([
    transforms.ColorJitter(.3, .3, .3),
    transforms.RandomRotation(5),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    normalize,
])

val_transforms = transforms.Compose([normalize])

In [19]:
# One embedding per image, in dataset order, with normalisation only (no augmentation).
# NOTE(review): `_get_embs` returns the pooled backbone features (before
# `model.head`). ArcFace was trained on the output of `model(imgs)`, i.e. the
# head projection. Confirm the pre-head features are what should be exported.
model.eval()
embs = []
with torch.no_grad():
    for image, labels in tqdm(testing_dl):
        x = val_transforms(image.to(device))
        embs.append(model._get_embs(x).cpu())

HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))




In [20]:
# Stack the per-batch tensors and L2-normalise each row, so dot products are cosine similarities.
# NOTE: rebinds `embs` from a list to a tensor, so this cell cannot be re-run on its own.
embs = F.normalize(torch.cat(embs, dim=0), p=2.0, dim=1)

In [21]:
# One column per embedding dimension (emb_0, emb_1, ...). The CSV keeps the
# default integer index, i.e. the row position in the extraction order.
emb_cols = [f'emb_{i}' for i in range(embs.shape[1])]
embs_df = pd.DataFrame(embs.numpy(), columns=emb_cols)
embs_df.to_csv('data/tests_model_image/train_embs_arcf_30ep.csv')
