In [4]:
import os
import random
import functools
from functools import partial
import PIL

import numpy as np 
import pandas as pd

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

import timm

In [5]:
class SoftMaxDS(Dataset):
    """Image-classification dataset backed by a DataFrame.

    Every item is an ``(image, label)`` pair: ``image`` is a float32 tensor in
    channel-first layout with values scaled to ``[0, 1]`` and ``label`` is the
    pandas category code of the row's ``label_group`` value.

    Args:
        data: DataFrame with an ``image`` column (file names relative to
            ``images_path``) and a ``label_group`` column.
        images_path: Directory that contains the image files.
        return_triplet: Unused; kept so existing call sites keep working.
    """

    def __init__(self, data, images_path, return_triplet = True):
        super().__init__()
        self.imgs = data['image'].tolist()
        self.unique_labels = data['label_group'].unique().tolist()
        # Categorical view of the labels; ``cat.codes`` gives one integer per row.
        self.labels = data['label_group'].astype('category')
        self.label_codes = self.labels.cat.codes

        self.images_path = images_path

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_code)`` for the row at position ``idx``."""
        img = self._get_item(idx)
        label = self.label_codes.iloc[idx]
        return img, label

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.imgs)

    def _get_item(self, idx):
        """Load image ``idx`` from disk as a ``(3, H, W)`` float tensor in ``[0, 1]``."""
        path = os.path.join(self.images_path, self.imgs[idx])
        # The context manager closes the file handle as soon as the pixels are
        # read. ``convert('RGB')`` guarantees three channels, so grayscale,
        # palette or RGBA files no longer break the ``permute`` below.
        with PIL.Image.open(path) as im:
            pixels = np.array(im.convert('RGB'))
        return torch.tensor(pixels / 255.0, dtype = torch.float).permute(2,0,1)

In [6]:
# load in data

df = pd.read_csv('data/train.csv')
small_images_dir = 'data/small_train_images/'
n_classes = df['label_group'].nunique()

# Seed every RNG the notebook draws from. Seeding NumPy alone does not make
# runs repeatable, because DataLoader shuffling, the torchvision augmentations
# and the head initialisation all draw from torch's generator.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# train val split
# NOTE(review): the split is positional (first 70% of rows in file order, no
# shuffling), so it is only meaningful if the CSV row order is arbitrary.

train_perc = 0.7
n_train_examples = int(train_perc * len(df))

train_df = df.iloc[:n_train_examples]
val_df = df.iloc[n_train_examples:]

In [7]:
# creating dataloaders

vision_model = 'resnet50'

bs = 64

# Use the GPU when one is present and fall back to the CPU otherwise, so the
# notebook still runs (slowly) on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the training set is built from the full `df`, not `train_df`.
tr_ds = SoftMaxDS(df, small_images_dir)
# Pinned host memory only helps host-to-GPU copies, so enable it only on CUDA.
tr_dl = DataLoader(tr_ds, batch_size = bs, shuffle = True, pin_memory = device.type == 'cuda')

In [8]:
class EMBCLass(nn.Module) :
    """Pretrained timm backbone, global average pooling and a linear head.

    ``forward`` returns the head output (one score per ``output_dim`` entry);
    ``_get_embs`` returns the pooled backbone features that feed the head.

    Args:
        pretrained_image_embedor: Model name passed to ``timm.create_model``.
        output_dim: Number of outputs of the linear head.
    """

    def __init__(self, pretrained_image_embedor='resnet50',
                output_dim=512) :
        super(EMBCLass, self).__init__()
        self.image_embedor = timm.create_model(pretrained_image_embedor, pretrained=True)
        self.image_pool = nn.AdaptiveAvgPool2d((1,1))
        # Size the head from the backbone's reported feature width instead of
        # hard-coding it; 2048 (the previous constant) is used when the
        # backbone does not expose ``num_features``.
        in_features = getattr(self.image_embedor, 'num_features', 2048)
        self.head = nn.Sequential(nn.Linear(in_features, output_dim),
                                  #nn.ReLU(),
                                  )

        # Head initialisation: N(0, 1/fan_in) weights and zero biases for
        # linear layers, identity affine parameters for normalisation layers.
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                sz = m.weight.data.size(-1)
                m.weight.data.normal_(mean=0.0, std=sz ** -0.5)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
                m.weight.data.fill_(1.0)
                m.bias.data.zero_()

    def _get_embs(self, x) :
        """Return pooled backbone features with shape ``(batch, channels)``."""
        out_images = self.image_embedor.forward_features(x)
        # ``flatten(1)`` removes only the two pooled 1x1 axes. A bare
        # ``squeeze()`` would also drop the batch axis for a single-image
        # batch and break the head, the loss and any later concatenation.
        out_images = self.image_pool(out_images).flatten(1)
        #return F.normalize(out_images, dim=-1)
        return out_images

    def forward(self, x) :
        """Return the linear head applied to the pooled features of ``x``."""
        out_images = self._get_embs(x)

        return self.head(out_images)

In [7]:
# Classifier with one output per label group, placed on the compute device.
model = EMBCLass(pretrained_image_embedor=vision_model, output_dim=n_classes)
model = model.to(device)

In [8]:

# Channel-wise normalisation with the standard ImageNet mean/std.
normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))

# Augmentations run on a whole batch tensor at once (see the training loop),
# so each random draw is applied to the batch as a unit.
train_transforms = transforms.Compose([transforms.ColorJitter(.3,.3,.3),
                                       transforms.RandomRotation(5),
                                       transforms.RandomCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       normalize
                                       ])

val_transforms = transforms.Compose([
        normalize
                                     ])

n_epochs = 30

lf = nn.CrossEntropyLoss()

lr = 1e-2
wd = 0
no_decay = ["bias", "BatchNorm2d.weight", "BatchNorm2d.bias", "LayerNorm.weight", 'LayerNorm.bias',
            "BatchNorm1d.weight", "BatchNorm1d.bias"]

# Split parameters into "decayed" and "not decayed" groups.
# `named_parameters()` yields attribute paths, not class names, so the
# "BatchNorm2d.weight"-style patterns above never match and normalisation
# weights would be decayed as soon as `wd` is non-zero. Biases and
# normalisation scales are all 0-D/1-D tensors, so the dimensionality check
# catches them; the name patterns are kept as an additional filter.
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if param.ndim <= 1 or any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {
        "params": decay_params,
        "weight_decay": wd,
    },
    {
        "params": no_decay_params,
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)

# learning rate scheduler
# One scheduler step per optimizer step, hence epochs * batches-per-epoch.
sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr =lr, pct_start = 0.3, #anneal_strategy = 'linear',
                                            total_steps = int(n_epochs * len(tr_dl)))

In [9]:
# Per-epoch lists of per-batch training losses.
tr_losses = []
# NOTE(review): never filled; no validation pass is run in this loop.
val_losses = []

# Create the checkpoint directory up front so `torch.save` cannot fail on a
# missing folder after a full epoch of training.
ckpt_pattern = 'data/tests_model_image/model_class_ep_{}.pth'
os.makedirs(os.path.dirname(ckpt_pattern), exist_ok=True)

for ep in tqdm(range(n_epochs)):
    model.train()
    tr_loss = []
    loss_sum = 0.0  # running sum, avoids re-averaging the whole list each step
    pbar = tqdm(tr_dl)
    for imgs, labels in pbar:

        # Augment on the device, on the whole batch at once.
        imgs = train_transforms(imgs.to(device))

        optimizer.zero_grad()
        out = model(imgs)
        loss = lf(out, labels.long().to(device))

        loss.backward()
        optimizer.step()
        sched.step()  # the one-cycle schedule advances once per batch

        batch_loss = loss.item()
        tr_loss.append(batch_loss)
        loss_sum += batch_loss
        pbar.set_description(f"Train loss: {round(loss_sum / len(tr_loss), 3)}")

    # Checkpoint every second epoch, and always after the last one so the
    # fully trained weights are on disk even when the final index is odd.
    if ep % 2 == 0 or ep == n_epochs - 1:
        torch.save(model.state_dict(), ckpt_pattern.format(ep))
    model.eval()
    tr_losses.append(tr_loss)
    summary = f"Ep {ep}: Train loss {np.asarray(tr_loss).mean()}"
    print(summary)
    

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 0: Train loss 8.821764301897874


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 1: Train loss 5.132217973915498


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 2: Train loss 2.0855755416759805


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 3: Train loss 1.2709581281489402


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 4: Train loss 0.9436204743140669


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 5: Train loss 0.7508404035566013


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 6: Train loss 0.6013587592586652


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 7: Train loss 0.4956280533502351


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 8: Train loss 0.3906500585405017


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 9: Train loss 0.3101756614076891


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 10: Train loss 0.2681865827041442


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 11: Train loss 0.22927061314416577


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 12: Train loss 0.20623355696790976


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 13: Train loss 0.17473196958143042


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 14: Train loss 0.1524906879110234


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 15: Train loss 0.1396005724204023


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 16: Train loss 0.11291867159411255


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 17: Train loss 0.10100022109734492


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 18: Train loss 0.08138219460344581


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 19: Train loss 0.06742419608140629


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 20: Train loss 0.059215225652468374


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 21: Train loss 0.047547162775755616


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 22: Train loss 0.03982319059010044


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 23: Train loss 0.034133574037261846


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 24: Train loss 0.02762287608390186


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 25: Train loss 0.024498500820340975


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 26: Train loss 0.02100787410058525


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 27: Train loss 0.01742024289632721


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 28: Train loss 0.0163770934244107


HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))


Ep 29: Train loss 0.015568373700723408



In [11]:
# Full dataset again, iterated without shuffling so embedding rows follow `df` order.
testing_ds = SoftMaxDS(data=df, images_path=small_images_dir)
testing_dl = DataLoader(dataset=testing_ds,
                        batch_size=bs,
                        shuffle=False,
                        pin_memory=True)

In [14]:
# Collect pooled backbone features for every image, one tensor per batch.
embs = []
model.eval()
with torch.no_grad():
    pbar = tqdm(testing_dl)
    for image, labels in pbar:
        x = val_transforms(image.to(device))
        y = model._get_embs(x)
        # Force a (batch, features) layout: if the last batch holds a single
        # image and the batch axis was squeezed away, the later `torch.cat`
        # would fail on mismatched dimensions.
        embs.append(y.reshape(x.size(0), -1).cpu())

HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))




In [15]:
# Stack the per-batch tensors into one (n_images, features) tensor.
# The guard makes the cell safe to re-run: once `embs` is already a tensor,
# passing it to `torch.cat` again would raise.
if not torch.is_tensor(embs):
    embs = torch.cat(embs, 0)

In [16]:
# One column per embedding dimension (emb_0, emb_1, ...), written with the row index.
emb_cols = [f'emb_{i}' for i in range(embs.shape[1])]
embs_df = pd.DataFrame(embs.numpy(), columns=emb_cols)
embs_df.to_csv('data/tests_model_image/train_embs_class_30ep.csv')


## No training: embeddings from the pretrained model without fine-tuning

In [9]:
# Fresh model with pretrained backbone weights only; no training loop is run on it.
model = EMBCLass(pretrained_image_embedor=vision_model, output_dim=n_classes)
model = model.to(device)

In [10]:
# Full dataset, iterated without shuffling so embedding rows follow `df` order.
testing_ds = SoftMaxDS(data=df, images_path=small_images_dir)
testing_dl = DataLoader(dataset=testing_ds,
                        batch_size=bs,
                        shuffle=False,
                        pin_memory=True)

In [12]:
# Channel-wise normalisation shared by both pipelines.
normalize = transforms.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

# Training pipeline: colour jitter, small rotation, 224 crop, horizontal flip, normalise.
train_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=.3, contrast=.3, saturation=.3),
    transforms.RandomRotation(degrees=5),
    transforms.RandomCrop(size=224),
    transforms.RandomHorizontalFlip(),
    normalize,
])

# Evaluation pipeline: normalisation only.
val_transforms = transforms.Compose([normalize])

In [13]:
# Collect pooled backbone features for every image, one tensor per batch.
embs = []
model.eval()
with torch.no_grad():
    pbar = tqdm(testing_dl)
    for image, labels in pbar:
        x = val_transforms(image.to(device))
        y = model._get_embs(x)
        # Force a (batch, features) layout: if the last batch holds a single
        # image and the batch axis was squeezed away, the later `torch.cat`
        # would fail on mismatched dimensions.
        embs.append(y.reshape(x.size(0), -1).cpu())

HBox(children=(FloatProgress(value=0.0, max=536.0), HTML(value='')))




In [14]:
# Stack the per-batch tensors into one (n_images, features) tensor.
# The guard makes the cell safe to re-run: once `embs` is already a tensor,
# passing it to `torch.cat` again would raise.
if not torch.is_tensor(embs):
    embs = torch.cat(embs, 0)

In [15]:
# One column per embedding dimension (emb_0, emb_1, ...), written with the row index.
emb_cols = [f'emb_{i}' for i in range(embs.shape[1])]
embs_df = pd.DataFrame(embs.numpy(), columns=emb_cols)
embs_df.to_csv('data/tests_model_image/train_embs_class_notrain.csv')
