# References

* Thanks to the PyTorch ArcFace implementation by @tanulsingh077, available [here](https://www.kaggle.com/tanulsingh077/pytorch-metric-learning-pipeline-only-images).

**Task:** 

Train nfNet10 and EfficientNet (b4–b6) models; the backbone is selected by setting `model_name` in **pretrained_model_config**.

- **nfNet10**         - Trained for 12 epochs - ~8 hours
- **EfficientNet b4** - Trained for 12 epochs - ~10 hours
- **EfficientNet b5** - Trained for 12 epochs - ~15 hours
- **EfficientNet b6** - Trained for 10 epochs - ~34 hours


In [None]:
!pip install timm
!pip install --upgrade --force-reinstall --no-deps albumentations
!pip install torchvision 
!pip install tqdm

import math
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F 

import torchvision

from tqdm import tqdm

Collecting albumentations
  Using cached https://files.pythonhosted.org/packages/03/58/63fb1d742dc42d9ba2800ea741de1f2bc6bb05548d8724aa84794042eaf2/albumentations-0.5.2-py3-none-any.whl
Installing collected packages: albumentations
  Found existing installation: albumentations 0.5.2
    Uninstalling albumentations-0.5.2:
      Successfully uninstalled albumentations-0.5.2
Successfully installed albumentations-0.5.2




In [None]:
# Google Colab only: mount Google Drive at /content/gdrive.
from google.colab import drive

drive.mount('/content/gdrive')


Mounted at /content/gdrive


In [None]:
cd '/content/gdrive/MyDrive/256_TermProj'

/content/gdrive/MyDrive/256_TermProj


In [None]:
torch.cuda.empty_cache()

# Training Utilities (train/eval loops, dataset, LR scheduler, transforms)

In [None]:
def train_fn(model, data_loader, optimizer, scheduler, epoch, device):
    """Run one training epoch and return the mean per-batch loss.

    Every batch must be a dict of tensors whose keys match the keyword
    arguments of ``model``; the model call must return a pair whose second
    element is the loss. The scheduler is stepped once, after the last batch.

    Args:
        model: module to optimise; switched to training mode here.
        data_loader: iterable of batch dicts (must support ``len``).
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per epoch.
        epoch: zero-based epoch index (shown one-based in the progress bar).
        device: device every batch tensor is moved to.

    Returns:
        Sum of the batch losses divided by ``len(data_loader)``.
    """
    model.train()
    running_loss = 0.0
    progress = tqdm(data_loader, desc = "Training epoch: " + str(epoch+1))

    for step, batch in enumerate(progress):
        optimizer.zero_grad()

        # Move all tensors of the batch to the target device, in place.
        batch.update({key: tensor.to(device) for key, tensor in batch.items()})

        _, loss = model(**batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Progress bar shows the running mean loss and the current LR of
        # the first parameter group.
        mean_so_far = running_loss / (step + 1)
        progress.set_postfix({
            'loss' : '%.6f' % float(mean_so_far),
            'LR' : optimizer.param_groups[0]['lr'],
        })

    scheduler.step()
    return running_loss / len(data_loader)


def eval_fn(model, data_loader, epoch, device):
    """Run one validation pass without gradients and return the mean loss.

    Every batch must be a dict of tensors whose keys match the keyword
    arguments of ``model``; the model call must return a pair whose second
    element is the loss.

    NOTE(review): ``ProductImageModel.forward`` in this file returns only the
    feature tensor when the model is in eval mode, so it does not satisfy
    this contract as written -- confirm before wiring this function in.

    Args:
        model: module to evaluate; switched to eval mode here.
        data_loader: iterable of batch dicts (must support ``len``).
        epoch: zero-based epoch index (shown one-based in the progress bar).
        device: device every batch tensor is moved to.

    Returns:
        Sum of the batch losses divided by ``len(data_loader)``.
    """
    model.eval()
    running_loss = 0.0
    progress = tqdm(data_loader, desc = "Validation epoch: " + str(epoch+1))

    with torch.no_grad():
        for step, batch in enumerate(progress):
            # Move all tensors of the batch to the target device, in place.
            batch.update({key: tensor.to(device) for key, tensor in batch.items()})

            _, loss = model(**batch)
            running_loss += loss.item()

            mean_so_far = running_loss / (step + 1)
            progress.set_postfix({'loss' : '%.6f' % float(mean_so_far)})

    return running_loss / len(data_loader)

#dataset
import os
import cv2
import numpy as np 

import torch


class Product_Images(torch.utils.data.Dataset):
    """Dataset that loads product images from disk together with their labels.

    Args:
        df: table indexed positionally via ``.iloc``; each row must expose
            an ``image`` attribute (file name relative to ``root_dir``) and a
            ``label_group`` attribute (value accepted by ``torch.tensor``).
        root_dir: directory that contains the image files.
        transform: optional callable invoked as ``transform(image=array)``
            that returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, df, root_dir, transform=None):
        self.df = df
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, colour-convert and (optionally) transform one sample.

        Returns:
            dict with ``'image'`` (the transformed image, or the RGB array
            when no transform is set) and ``'label'`` (int64 tensor).

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        row = self.df.iloc[idx]
        label = row.label_group

        img_path = os.path.join(self.root_dir, row.image)
        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(
                'Could not read image (missing or unreadable): {}'.format(img_path)
            )
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return {
            'image' : image,
            'label' : torch.tensor(label).long()
        }


#custom_scheduler.py
import torch 
from torch.optim.lr_scheduler import _LRScheduler


class ProductImageScheduler(_LRScheduler):
    """Per-step LR schedule: linear ramp-up, optional plateau, exponential decay.

    With ``e = last_epoch`` (incremented by one on every ``step()`` call):

    * ``e == 0``: ``lr_start``
    * ``e < lr_ramp_ep``: linear interpolation from ``lr_start`` towards
      ``lr_max``
    * ``e < lr_ramp_ep + lr_sus_ep``: ``lr_max``
    * otherwise: ``(lr_max - lr_min) * lr_decay ** (e - lr_ramp_ep -
      lr_sus_ep) + lr_min``

    The same value is applied to every parameter group.

    NOTE: earlier revisions also incremented ``last_epoch`` inside
    ``get_lr``. Because ``step()`` already increments it, the schedule ran
    at twice the configured speed. ``get_lr`` no longer mutates state, so
    ``last_epoch`` values stored in checkpoints written by the old code
    correspond to roughly twice the number of steps actually taken.

    Args:
        optimizer: wrapped optimizer.
        lr_start: learning rate at step 0.
        lr_max: learning rate reached at the end of the ramp.
        lr_min: asymptote of the decay phase.
        lr_ramp_ep: number of steps in the linear ramp.
        lr_sus_ep: number of steps to hold ``lr_max``.
        lr_decay: multiplicative decay factor per step after the plateau.
        last_epoch: index of the last step; ``-1`` starts from scratch.
    """

    def __init__(self, optimizer, lr_start=5e-6, lr_max=1e-5,
                 lr_min=1e-6, lr_ramp_ep=5, lr_sus_ep=0, lr_decay=0.4,
                 last_epoch=-1):
        # Schedule parameters must exist before the base constructor runs,
        # because it performs an initial step() that calls get_lr().
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.lr_ramp_ep = lr_ramp_ep
        self.lr_sus_ep = lr_sus_ep
        self.lr_decay = lr_decay
        super(ProductImageScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of each param group for ``last_epoch``."""
        if not self._get_lr_called_within_step:
            # Imported locally: the surrounding file does not import it.
            import warnings
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        """Return the scheduled LR for ``last_epoch`` without side effects."""
        if self.last_epoch == 0:
            # Explicit so that lr_ramp_ep == 0 still starts at lr_start.
            lr = self.lr_start
        else:
            lr = self._compute_lr_from_epoch()
        return [lr for _ in self.optimizer.param_groups]

    def _compute_lr_from_epoch(self):
        """Evaluate the ramp / plateau / decay formula at ``last_epoch``."""
        if self.last_epoch < self.lr_ramp_ep:
            lr = ((self.lr_max - self.lr_start) /
                  self.lr_ramp_ep * self.last_epoch +
                  self.lr_start)
        elif self.last_epoch < self.lr_ramp_ep + self.lr_sus_ep:
            lr = self.lr_max
        else:
            lr = ((self.lr_max - self.lr_min) * self.lr_decay**
                  (self.last_epoch - self.lr_ramp_ep - self.lr_sus_ep) +
                  self.lr_min)
        return lr

import albumentations
from albumentations.pytorch.transforms import ToTensorV2


def get_train_transforms(img_size=512):
    """Build the albumentations pipeline applied to training images.

    Steps, in order: resize to ``img_size`` x ``img_size``, random
    horizontal and vertical flips, random rotation, random brightness
    change, per-channel normalisation, conversion to a tensor.

    Args:
        img_size: side length of the square output image.

    Returns:
        An ``albumentations.Compose`` instance.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        albumentations.Resize(img_size, img_size, always_apply=True),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.Rotate(limit=120, p=0.8),
        albumentations.RandomBrightness(limit=(0.09, 0.6), p=0.5),
        albumentations.Normalize(mean = channel_mean, std = channel_std),
        ToTensorV2(p=1.0),
    ]
    return albumentations.Compose(steps)

def get_valid_transforms(img_size=512):
    """Build the albumentations pipeline applied to validation images.

    Steps, in order: resize to ``img_size`` x ``img_size``, per-channel
    normalisation, conversion to a tensor. No random augmentation is used.

    Args:
        img_size: side length of the square output image.

    Returns:
        An ``albumentations.Compose`` instance.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        albumentations.Resize(img_size, img_size, always_apply=True),
        albumentations.Normalize(mean = channel_mean, std = channel_std),
        ToTensorV2(p=1.0),
    ]
    return albumentations.Compose(steps)


# Config and Directories

In [None]:
# Directory holding the training images; joined with each row's `image` value.
DATA_DIR = 'input/shopee-product-matching/train_images'
# NOTE(review): not referenced in the visible code -- run_training reads
# 'input/shopee-product-matching/train.csv' directly. Confirm which is intended.
TRAIN_CSV = 'input/utils-shopee/folds.csv'
# Prefix prepended (by string concatenation) to the saved model file names.
MODEL_PATH = 'output/'


class pretrained_model_config:
    """Hyper-parameters shared by the model, data loader and LR scheduler."""

    # Not referenced elsewhere in the visible code.
    seed = 54

    # Side length passed to the image transforms.
    img_size = 512

    # Defaults of ProductImageModel: number of classes, margin-head scale and
    # margin, and width of the FC layer in front of the head.
    classes = 11014
    scale = 30
    margin = 0.5
    fc_dim = 512

    # Training loop and DataLoader settings.
    epochs = 10
    batch_size = 4
    num_workers = 4

    # Name handed to timm.create_model.
    model_name = 'tf_efficientnet_b6'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Keyword arguments for ProductImageScheduler. The peak learning rate is
    # scaled with the batch size (1e-5 per sample, e.g. 3.2e-4 for a batch
    # size of 32).
    scheduler_params = dict(
        lr_start=1e-5,
        lr_max=1e-5 * batch_size,
        lr_min=1e-6,
        lr_ramp_ep=5,
        lr_sus_ep=0,
        lr_decay=0.8,
    )

# Create Model

In [None]:
class ImageRecog_MarginLoss_ArcFace(nn.Module):
    """Additive-angular-margin classification head that also returns its loss.

    The head holds one weight vector per class. Inputs and weights are
    L2-normalised, so their product is the cosine of the angle between
    them. For the target class the angle is enlarged by ``margin`` before
    all logits are multiplied by ``scale`` and fed to cross-entropy.

    Args:
        in_features: size of each input embedding.
        out_features: number of classes.
        scale: factor applied to the final logits.
        margin: angular margin (radians) added to the target-class angle.
        easy_margin: if true, apply the margin only where ``cosine > 0`` and
            leave the cosine untouched elsewhere.
        ls_eps: label-smoothing amount applied to the one-hot target mask;
            ``0`` disables smoothing.
    """

    def __init__(self, in_features, out_features, scale=30.0, margin=0.50, easy_margin=False, ls_eps=0.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Constants for cos(theta + margin) = cos*cos_m - sin*sin_m.
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # Below this cosine, theta + margin would exceed pi; a linear
        # penalty (cosine - mm) is used there instead.
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, input, label):
        """Compute margin-adjusted logits and their cross-entropy loss.

        Args:
            input: float tensor of shape ``(batch, in_features)``.
            label: integer class indices of shape ``(batch,)``.

        Returns:
            Tuple ``(logits, loss)`` where ``logits`` has shape
            ``(batch, out_features)`` and ``loss`` is a scalar tensor.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Rounding can push |cosine| marginally above 1; clamp so the square
        # root never receives a negative value (which would yield NaN).
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0.0, 1.0))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Build the target mask on the same device/dtype as the logits
        # rather than assuming CUDA.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features

        # Target entries take the margin-adjusted value, others the cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.scale
        return output, F.cross_entropy(output, label)


class ProductImageModel(nn.Module):
    """timm backbone with an optional FC/BN neck and a margin-loss head.

    In training mode ``forward`` returns the ``(logits, loss)`` pair produced
    by the margin head. In eval mode it returns the pooled backbone features
    and ignores ``label``.

    Args:
        n_classes: number of classes of the margin head.
        model_name: name passed to ``timm.create_model``.
        fc_dim: output width of the FC layer in front of the head.
        margin: margin forwarded to the head.
        scale: scale forwarded to the head.
        use_fc: whether to insert dropout + FC + batch-norm before the head.
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(
        self,
        n_classes = pretrained_model_config.classes,
        model_name = pretrained_model_config.model_name,
        fc_dim = pretrained_model_config.fc_dim,
        margin = pretrained_model_config.margin,
        scale = pretrained_model_config.scale,
        use_fc = True,
        pretrained = True):

        super().__init__()
        print('Building Model Backbone for {} model'.format(model_name))

        self.backbone = timm.create_model(model_name, pretrained=pretrained)
        # NOTE(review): this reads `backbone.classifier`, so it only works
        # for timm models that expose their head under that attribute --
        # confirm before switching `model_name` to another architecture.
        in_features = self.backbone.classifier.in_features
        # Strip the backbone's own pooling and classifier; pooling is done
        # by self.pooling below.
        self.backbone.classifier = nn.Identity()
        self.backbone.global_pool = nn.Identity()
        self.pooling =  nn.AdaptiveAvgPool2d(1)
        self.use_fc = use_fc

        if use_fc:
            self.dropout = nn.Dropout(p=0.1)
            self.classifier = nn.Linear(in_features, fc_dim)
            self.bn = nn.BatchNorm1d(fc_dim)
            self._init_params()
            in_features = fc_dim

        self.final = ImageRecog_MarginLoss_ArcFace(
            in_features,
            n_classes,
            scale = scale,
            margin = margin,
            easy_margin = False,
            ls_eps = 0.0
        )

    def _init_params(self):
        """Initialise the FC layer (Xavier normal, zero bias) and batch-norm."""
        nn.init.xavier_normal_(self.classifier.weight)
        nn.init.constant_(self.classifier.bias, 0)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def forward(self, image, label=None):
        """Return ``(logits, loss)`` when training, features otherwise.

        Args:
            image: batch of images fed to the backbone.
            label: class indices; required in training mode, unused (and
                therefore optional) in eval mode.

        Raises:
            ValueError: if the model is in training mode and ``label`` is
                ``None``.
        """
        features = self.extract_features(image)
        if not self.training:
            return features
        if label is None:
            raise ValueError('label is required when the model is in training mode')
        return self.final(features, label)

    def extract_features(self, x):
        """Run the backbone, average-pool and (when training) apply the neck.

        NOTE(review): dropout/FC/batch-norm are applied only in training
        mode, so eval-mode features skip the FC layer and keep the pooled
        backbone output -- confirm this is intended for inference.
        """
        batch_size = x.shape[0]
        x = self.backbone(x)
        x = self.pooling(x).view(batch_size, -1)

        if self.use_fc and self.training:
            x = self.dropout(x)
            x = self.classifier(x)
            x = self.bn(x)
        return x


# Training

In [None]:
def run_training():
    """Train ``ProductImageModel`` on the full training CSV and save it.

    Reads ``input/shopee-product-matching/train.csv``, label-encodes
    ``label_group``, and trains for ``pretrained_model_config.epochs``
    epochs. After every epoch two files are overwritten under the
    ``MODEL_PATH`` prefix: the bare model weights, and a resumable
    checkpoint holding the epoch index, the mean training loss and the
    model / optimizer / scheduler states.
    """
    cfg = pretrained_model_config

    # Resolve both output files once, with the same concatenation as before
    # so existing file locations are unchanged.
    weights_path = MODEL_PATH + 'arcface_512x512_{}.pt'.format(cfg.model_name)
    checkpoint_path = MODEL_PATH + 'arcface_512x512_{}_checkpoints.pt'.format(cfg.model_name)

    # Create the output directory up front: otherwise a missing folder is
    # only discovered by torch.save after a complete epoch of training.
    out_dir = os.path.dirname(weights_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    df = pd.read_csv('input/shopee-product-matching/train.csv')

    # Map the raw label_group values to contiguous integer class indices.
    labelencoder = LabelEncoder()
    df['label_group'] = labelencoder.fit_transform(df['label_group'])

    trainset = Product_Images(df,
                              DATA_DIR,
                              transform = get_train_transforms(img_size = cfg.img_size))

    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size = cfg.batch_size,
        num_workers = cfg.num_workers,
        pin_memory = True,
        shuffle = True,
        drop_last = True
    )

    model = ProductImageModel()
    model.to(cfg.device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr = cfg.scheduler_params['lr_start'])
    scheduler = ProductImageScheduler(optimizer, **cfg.scheduler_params)

    for epoch in range(cfg.epochs):
        torch.cuda.empty_cache()
        avg_loss_train = train_fn(model, trainloader, optimizer, scheduler, epoch, cfg.device)

        torch.save(model.state_dict(), weights_path)
        torch.save({
            'epoch': epoch,
            'train_loss': avg_loss_train,
            'model_state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
            },
            checkpoint_path
        )

In [None]:
run_training()

Building Model Backbone for tf_efficientnet_b6 model


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Training epoch: 10:  71%|███████   | 6062/8562 [24:35<09:55,  4.20it/s, loss=16.646075, LR=3.14e-6][A
Training epoch: 10:  71%|███████   | 6063/8562 [24:35<09:51,  4.22it/s, loss=16.646075, LR=3.14e-6][A
Training epoch: 10:  71%|███████   | 6063/8562 [24:36<09:51,  4.22it/s, loss=16.646473, LR=3.14e-6][A
Training epoch: 10:  71%|███████   | 6064/8562 [24:36<09:53,  4.21it/s, loss=16.646473, LR=3.14e-6][A
Training epoch: 10:  71%|███████   | 6064/8562 [24:36<09:53,  4.21it/s, loss=16.647086, LR=3.14e-6][A
Training epoch: 10:  71%|███████   | 6065/8562 [24:36<09:49,  4.23it/s, loss=16.647086, LR=3.14e-6][A
Training epoch: 10:  71%|███████   | 6065/8562 [24:36<09:49,  4.23it/s, loss=16.647514, LR=3.14e-6][A
Training epoch: 10:  71%|███████   | 6066/8562 [24:36<09:54,  4.20it/s, loss=16.647514, LR=3.14e-6][A
Training epoch: 10:  71%|███████   | 6066/8562 [24:36<09:54,  4.20it/s, loss=16.646045, LR=3.14e-6][A
Training