# Self-supervised learning with fastai

Inspired by [Ayush Thakur's work](https://www.kaggle.com/ayuraj/v2-self-supervised-pretraining-with-swav?scriptVersionId=59516445), I started reading more about SwAV. The results in the [paper](https://arxiv.org/abs/2006.09882) were impressive, so I decided to implement it.

Luckily, I found a PyTorch implementation. It also uses fastai, which felt like a dream come true, since fastai is my comfort zone. You can find more about the [implementation here](https://keremturgutlu.github.io/self_supervised/). I highly recommend checking out the documentation. Besides SwAV, the repository has fastai implementations of other state-of-the-art algorithms as well.

I followed [this tutorial](https://keremturgutlu.github.io/self_supervised/04%20-%20training_swav_iwang.html) to create the notebook you are reading now. So, if something looks off or doesn't make sense, please refer to the original tutorial.

Let's get started . . .

### Installation & imports

In [3]:
!pip install self-supervised -Uq

In [4]:
from fastai.vision.all import *

from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.vision.swav import *

from sklearn.model_selection import StratifiedKFold
import torchvision.models as models
from cuml.neighbors import NearestNeighbors

### Reading Data

In [5]:
df = pd.read_csv('../input/shopee-product-matching/train.csv')
df.head()

To train the model, we also need a validation set. I will use a simple `StratifiedKFold` split and take the first fold as the validation set.

In [6]:
sk_fold = StratifiedKFold(5)
df['is_valid'] = False
for i, (trn_idx, val_idx) in enumerate(sk_fold.split(df, df.label_group)):
    df.loc[val_idx, 'is_valid'] = True
    break
    
df.groupby('is_valid').label_group.value_counts()

As the warning says, some label groups have fewer than 5 postings. Let's see how many such label groups we have . . .

In [7]:
sum(df.label_group.value_counts() < 5)

There are 9620 `label_group` values with fewer than 5 postings. Quite a lot, huh! (Let's handle it in version 2.)
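The same count can be reproduced on toy data without pandas. This is a minimal sketch with made-up labels (`labels`, `n_splits`, and `rare` are illustrative names, not part of the competition data). A class with fewer members than the number of folds cannot appear in every fold, which is the situation the `StratifiedKFold` warning reports.

```python
from collections import Counter

n_splits = 5  # same number of folds as StratifiedKFold(5)

# Hypothetical labels: group "a" has 7 items, "b" has 5, "c" has 3, ...
labels = ["a"] * 7 + ["b"] * 5 + ["c"] * 3 + ["d"] * 2 + ["e"]

counts = Counter(labels)

# Groups with fewer members than folds cannot be spread over every fold
rare = sorted(group for group, c in counts.items() if c < n_splits)
print(len(rare), rare)
```

On the real dataframe, the equivalent check is the `value_counts() < 5` line above.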

### Dataloaders

For now, let's write some helper functions to create the dataloaders.

In [8]:
def get_x(x): return '../input/shopee-product-matching/train_images/' + x['image']

def get_dls(size, bs, workers=None):
    path = Path('../input/shopee-product-matching/train_images/')
    
    db = DataBlock(blocks = (ImageBlock(), CategoryBlock()),
              get_x = get_x, get_y=ColReader('label_group'),
              splitter=ColSplitter(),
              item_tfms=RandomResizedCrop(size, min_scale=1.))
    dls = db.dataloaders(df, bs=bs, num_workers=workers)
    return dls

Let's create our dataloaders . . .

In [9]:
bs, resize, size = 24, 256, 224
dls = get_dls(resize, bs)

### Model & Callbacks

Finally, let's initialize our model and the SwAV callback.

In [10]:
## Model
arch = "resnet50"
encoder = create_encoder(arch, pretrained=True, n_in=3)
model = create_swav_model(encoder)

## SwAV callback
K = bs*2**4
aug_pipelines = get_swav_aug_pipelines(num_crops=[2, 6],
                                       crop_sizes=[size,int(3/4*size)], 
                                       min_scales=[0.25, 0.20],
                                       max_scales=[1.00, 0.35],
                                       rotate=True, rotate_deg=10, jitter=True, bw=True, blur=False)
cbs=[SWAV(aug_pipelines, crop_assgn_ids=[0,1], K=K, queue_start_pct=0.5, temp=0.1)]
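To make the numbers in the cell above concrete, here is the arithmetic spelled out. It only re-evaluates the expressions from the notebook. The reading of `K` as the size of the feature queue is an assumption based on the `self_supervised` documentation, and `views_per_image` and `views_per_batch` are names introduced here for illustration.

```python
bs, size = 24, 224

# Value passed as K to the SWAV callback (assumed to be the queue size)
K = bs * 2 ** 4

# Multi-crop setup: 2 large crops and 6 small crops per image
num_crops = [2, 6]
crop_sizes = [size, int(3 / 4 * size)]

views_per_image = sum(num_crops)
views_per_batch = bs * views_per_image

print(K, crop_sizes, views_per_image, views_per_batch)
```

So each image is seen as two 224px crops and six 168px crops, and a batch of 24 images expands to 192 augmented views.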

Great, we have our data, our model, and the callbacks. Fastai has this amazing class called `Learner`, which puts everything together for training.

In [11]:
learn = Learner(dls, model, cbs=cbs)

Before we actually train the model, let's look at some of the samples.

In [12]:
b = dls.one_batch()
learn._split(b)
learn('before_batch')
learn.swav.show(n=5);

### Training 

Time to train the model . . . 

In [13]:
lr, wd = 1e-2, 1e-2
epochs = 5 # try using 40 or 50
learn.unfreeze()
learn.fit_flat_cos(epochs, lr, wd=wd, pct_start=0.5)
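`fit_flat_cos` keeps the learning rate flat for the first `pct_start` fraction of training and then anneals it with a cosine curve. The sketch below approximates that shape in plain Python. It is not fastai's code: `flat_cos_lr` is a hypothetical helper, and the final divisor of `1e5` is assumed to match fastai's default `div_final`.

```python
import math

def flat_cos_lr(pos, lr=1e-2, pct_start=0.5, div_final=1e5):
    """Approximate learning rate at training progress `pos` in [0, 1]."""
    if pos < pct_start:
        return lr  # flat phase: constant learning rate
    # Cosine phase: rescale progress to [0, 1], then anneal lr -> lr / div_final
    p = (pos - pct_start) / (1 - pct_start)
    end = lr / div_final
    return end + (lr - end) * (1 + math.cos(math.pi * p)) / 2

# Learning rate at the start, quarter, half, three-quarter and end points
schedule = [flat_cos_lr(p) for p in (0.0, 0.25, 0.5, 0.75, 1.0)]
print(schedule)
```

With `pct_start=0.5`, the first half of the 5 epochs runs at `1e-2` and the second half decays towards roughly `1e-7`.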

In [14]:

learn.recorder.plot_loss()

In [16]:
save_name = f'swav_iwang_sz{size}_epc{epochs}'


In [17]:
learn.save(save_name)


In [20]:
torch.save(learn.model.encoder.state_dict(), learn.path/learn.model_dir/f'{save_name}_encoder.pth')

In [23]:
!cd models

In [24]:
!ls


In [25]:
cd models

In [26]:
ls

In [29]:
cp swav_iwang_sz224_epc5_encoder.pth ../input/shopee-product-matching

In [30]:
ls /kaggle/working/models/swav_iwang_sz224_epc5_encoder.pth

In [15]:
# Loops forever until the kernel is interrupted
while True:
    pass

This notebook is intended only for learning. To get a better score, try a different model architecture, a bigger image size, etc.

You can find the inference notebook [here](https://www.kaggle.com/ankursingh12/shopee-swav-inference).

Hope you enjoyed reading this notebook. If yes, then please consider **upvoting**!

In [31]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')

In [32]:
import gc
import os 
import cv2
import timm
import random

import numpy as np 
import pandas as pd 
from tqdm import tqdm
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.vision.swav import *

import torchvision.models as models
from cuml.neighbors import NearestNeighbors

import albumentations as A 
from albumentations.pytorch.transforms import ToTensorV2

In [33]:
path = Path('../input/shopee-product-matching')


class CFG:
    img_size = 512
    batch_size = 12
    seed = 2020
    
    device = 'cuda'
    classes = 11014
    
    scale = 30 
    margin = 0.5

In [34]:
def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_torch(CFG.seed)

In [44]:
def read_dataset():
    df = pd.read_csv('../input/shopee-product-matching/train.csv')
    tmp = df.groupby('label_group').posting_id.agg('unique').to_dict()
    df['target'] = df.label_group.map(tmp)
    
    image_paths = str(path) + '/train_images/' + df['image']
    return df, image_paths

def get_test_transforms():
    return A.Compose([A.Resize(CFG.img_size, CFG.img_size, always_apply=True),
                      A.Normalize(), ToTensorV2(p=1.0)])

class ShopeeDataset(Dataset):
    def __init__(self, image_paths, transforms=None):
        self.image_paths = image_paths
        self.augmentations = transforms

    def __len__(self):
        return self.image_paths.shape[0]

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        if self.augmentations:
            augmented = self.augmentations(image=image)
            image = augmented['image']       
    
        return image,torch.tensor(1)

In [45]:
def get_image_embeddings(image_paths, model, model_path=None):
    embeds = []
    model.eval()
    
    if model_path:
        model.load_state_dict(torch.load(model_path))
        model = model.to(CFG.device)
    
    image_dataset = ShopeeDataset(image_paths,transforms = get_test_transforms())
    image_loader = DataLoader(image_dataset, batch_size=CFG.batch_size, pin_memory=True, 
                              drop_last=False,num_workers=4)
    
    with torch.no_grad():
        for img,label in tqdm(image_loader): 
            img = img.cuda()
            feat = model(img)
            image_embeddings = feat.detach().cpu().numpy()
            embeds.append(image_embeddings)
            
    image_embeddings = np.concatenate(embeds)
    print(f'Our image embeddings shape is {image_embeddings.shape}')
    
    del embeds, model
    gc.collect()
    return image_embeddings


def get_image_predictions(df, embeddings,threshold = 0.0):
    if len(df) > 3: KNN = 50
    else : KNN = 3
    
    model = NearestNeighbors(n_neighbors = KNN, metric = 'euclidean')
    model.fit(embeddings)
    distances, indices = model.kneighbors(embeddings)
    
    predictions = []
    for k in tqdm(range(embeddings.shape[0])):
        idx = np.where(distances[k,] < threshold)[0]
        ids = indices[k,idx]
        posting_ids = df['posting_id'].iloc[ids].values
        predictions.append(posting_ids)
        
    del model, distances, indices
    gc.collect()
    return predictions
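The matching rule in `get_image_predictions` is: for each item, keep the neighbours whose Euclidean distance is below `threshold`. Here is a minimal brute-force sketch of that rule in NumPy, on made-up 2-D embeddings. It swaps cuML's `NearestNeighbors` for a full pairwise distance matrix, so it only suits tiny inputs, and `emb`, `posting_ids`, and `matches` are illustrative names.

```python
import numpy as np

# Hypothetical embeddings: items 0/1 are close, items 2/3 are close
emb = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 0.0], [1.0, 0.2]])
posting_ids = np.array(["p0", "p1", "p2", "p3"])
threshold = 0.36

# Pairwise Euclidean distances, shape (n_items, n_items)
diff = emb[:, None, :] - emb[None, :, :]
dist = np.sqrt((diff ** 2).sum(axis=-1))

# For each item, keep every posting closer than the threshold (itself included)
matches = [posting_ids[row < threshold].tolist() for row in dist]
print(matches)
```

Because the comparison is a strict `<`, the default `threshold = 0.0` in the function above would match nothing, not even the item itself, so a positive threshold has to be passed in.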

In [46]:
def combine_predictions(row):
    x = np.concatenate([row['image_predictions']])
    return ' '.join( np.unique(x))

def getMetric(row, col):
    # Row-wise F1: 2 * |target ∩ predicted| / (|target| + |predicted|)
    n = len(np.intersect1d(row.target, row[col]))
    return 2*n / (len(row.target)+len(row[col]))
    
def evaluate_model(models):
    img_embeddings = 0
    
    if isinstance(models, list):
        for m in models:
            img_embeddings += get_image_embeddings(image_paths.values, m)
        img_embeddings /= len(models)
    else:
        img_embeddings = get_image_embeddings(image_paths.values, models)
        
    img_embeddings = img_embeddings.squeeze()
    image_predictions = get_image_predictions(df, img_embeddings, threshold = 0.36)
    
    df['image_predictions'] = image_predictions
    f1_scores = df.apply(lambda r: getMetric(r, 'image_predictions'), axis=1)
    print(f'CV score for baseline = {f1_scores.mean()}')
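As a worked example of the row-wise F1 that `getMetric` computes, here is the same formula applied to one made-up row. The posting IDs are invented and `f1_score_sets` is a hypothetical standalone version of the function.

```python
import numpy as np

def f1_score_sets(target, predicted):
    # 2 * |intersection| / (|target| + |predicted|)
    n = len(np.intersect1d(target, predicted))
    return 2 * n / (len(target) + len(predicted))

target = np.array(["p1", "p2", "p3"])     # true members of the label group
predicted = np.array(["p1", "p2", "p9"])  # two hits, one false match

score = f1_score_sets(target, predicted)
perfect = f1_score_sets(target, target)
print(score, perfect)
```

Two of the three predictions are correct and one true match is missed, so the score is 2·2 / (3 + 3) = 2/3. A prediction identical to the target scores 1.0.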


In [47]:
df,image_paths = read_dataset()
df.head()


In [40]:
## Pretrained ResNet50 from torchvision
model = models.resnet50(pretrained=True)
model_enc = torch.nn.Sequential(*(list(model.children())[:-1])).cuda()
model_enc.eval()
evaluate_model(model_enc)
# euclidean -  CV score = 0.4930072045298856
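Dropping the last child of a torchvision ResNet50 leaves the global average pool as the final layer, so the features come out with shape `(N, 2048, 1, 1)`. That is why `evaluate_model` calls `.squeeze()`. The sketch below uses NumPy zeros as stand-ins for real embeddings to show what the squeeze does, plus one edge case to watch for. The `reshape` alternative is a suggestion, not what the notebook does.

```python
import numpy as np

# Stand-ins for encoder outputs: a batch of 12 images and a single image
batch = np.zeros((12, 2048, 1, 1), dtype=np.float32)
single = np.zeros((1, 2048, 1, 1), dtype=np.float32)

print(batch.squeeze().shape)   # (12, 2048)
print(single.squeeze().shape)  # (2048,) - the batch axis is dropped too

# Flattening per sample keeps the batch axis even when N == 1
flat = single.reshape(single.shape[0], -1)
print(flat.shape)              # (1, 2048)
```

For the full training set this makes no difference, but it matters if the same code is run on a single image.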

In [41]:
## ResNet50 trained using SwAV, by Facebook
model = torch.hub.load('facebookresearch/swav', 'resnet50')
model_enc = torch.nn.Sequential(*(list(model.children())[:-1])).cuda()
model_enc.eval()
evaluate_model(model_enc)
# euclidean - CV score = 0.5134202771430936

In [51]:
## Fine-tuned Resnet50
arch = 'resnet50'
encoder_path = '/kaggle/working/models/swav_iwang_sz224_epc5_encoder.pth'
encoder = create_encoder(arch, pretrained=False, n_in=3)
encoder.load_state_dict(torch.load(encoder_path))
encoder.cuda()
encoder.eval()
evaluate_model(encoder)