# 🌱🌿🌾Sorghum PyTorch TTA Inference Tutorial🚀

This notebook is ***the inference part*** of the [Sorghum Speed UP](https://www.kaggle.com/code/leoooo333/sorghum-speed-up) tutorial.  

Thanks to the small JPEG versions of the Sorghum images from https://www.kaggle.com/datasets/mithilsalunkhe/small-jpegs-fgvc

If you have any ***questions*** about my baseline, please *feel free* to ***leave a comment***. I will reply as soon as possible! If you like it, please **upvote**👏👏👏


## Inference results:
>##### The first fold of the 5-fold model, without TTA, gets LB **Acc@1: 82.9%**
>##### Folds 1 and 2 ensembled (averaged), without TTA, get LB **Acc@1: 85.5%**
>##### A model trained on all training images (instead of k-fold), without TTA, gets LB **Acc@1: 85.6%**
>##### A model trained on all training images (instead of k-fold), with flip and crop TTA, gets LB **Acc@1: 86.3%**
>##### Folds 1, 2 and 3 of the 5-fold model ensembled with the all-images model, with flip and crop TTA, get LB **Acc@1: 87.7%**

## Main ideas
+ *Pre-process* the images (using CLAHE)
+ Visualize and check the pre-processing
+ Use a pretrained model (via timm)
+ Image augmentation
+ [Train the model with DDP (and mixed precision)](https://www.kaggle.com/code/leoooo333/sorghum-speed-up)
+ Inference

## Tricks

#### To learn about training tricks
Click here: [Sorghum Speed UP](https://www.kaggle.com/code/leoooo333/sorghum-speed-up)!

#### Inference
+ TTA (test-time augmentation): apply several light augmentations to each test image, predict on every augmented copy, and average all the results.

In [None]:
!pip install seaborn
!pip install timm

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, models
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image as Img
from tqdm import tqdm
import os
from sklearn.metrics import accuracy_score
import timm
from tqdm import tqdm  

from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

import torchvision
import matplotlib.pyplot as plt
import re
from sklearn.model_selection import train_test_split, StratifiedKFold
import pytorch_lightning as pl
import seaborn as sns
import cv2 as cv
import numpy as np
import torch.nn.functional as F

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import math

# Config 

In [None]:
# Backbone and data-loading settings
MODEL_NAME = 'tf_efficientnet_b5_ns'  # timm model identifier
BATCH_SIZE, IMAGE_SIZE, NUM_WORKERS = 64, 900, 15
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
USE_AMP, INIT = True, False

# Input folder (images and label CSVs) and output folder (CSV / checkpoint files)
root_in = '../input/small-jpegs-fgvc'
root_out = './'
# Set to True once the cultivar labels have already been mapped to integer indices
have_index = False

# ArcFace parameters
NUM_CLASSES = 100
EMBEDDING_SIZE = 1024
S, M = 30.0, 0.5  # S: scale applied to the cosine logits; M: additive angular margin (radians)
EASY_MERGING, LS_EPS = False, 0.0  # easy-margin switch and label-smoothing epsilon for the ArcFace head

# Model

In [None]:
class ArcMarginProduct(nn.Module):
    r"""ArcFace (additive angular margin) classification head.

    Computes cos(theta) between L2-normalised embeddings and L2-normalised
    class weights, replaces the target-class logit by cos(theta + m) and
    scales every logit by ``s``.

    Args:
        in_features: size of each input embedding.
        out_features: number of classes.
        s: scale applied to the (margin-adjusted) cosine logits.
        m: additive angular margin, in radians.
        easy_margin: if True, only apply the margin where cos(theta) > 0.
        ls_eps: label-smoothing epsilon applied to the one-hot targets (0 disables it).
        rank: device of the owning (DDP) process. Kept for backward compatibility;
            the one-hot buffer is now created on the device of the logits.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        s: float,
        m: float,
        easy_margin: bool,
        ls_eps: float,
        rank
    ):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cos(theta) > th  <=>  theta + m < pi, i.e. the margin keeps the angle in range.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self.rank = rank

    def forward(self, input: torch.Tensor, label: torch.Tensor, device = 'cuda') -> torch.Tensor:
        """Return scaled ArcFace logits of shape (batch, out_features).

        Args:
            input: embeddings; expected shape (batch, in_features).
            label: integer class index per sample; reshaped to (batch, 1).
            device: unused; kept for backward compatibility.
        """
        # cos(theta) between normalised embeddings and normalised class centres.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Under autocast F.linear may return float16; do the margin maths in float32.
        cosine = cosine.to(torch.float32)

        # Rounding can push |cos| slightly above 1, which would make sqrt() return NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0, max=1.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # When theta + m would exceed pi, use the cosine minus a fixed penalty instead.
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # One-hot targets on the same device/dtype as the logits.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # Target class gets phi; every other class keeps its plain cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

In [None]:
class SorghumModel(nn.Module):
    """timm backbone + pooled features + linear embedding + ArcFace head.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        embedding_size: size of the embedding fed to the ArcFace head.
        map_location: unused here; kept for call-site compatibility.
        k_fold: unused here; kept for call-site compatibility.
        rank: device handed to the ArcFace head (``ArcMarginProduct``).
        pretrained: whether timm loads pretrained weights for the backbone.
    """
    def __init__(self, model_name, embedding_size, map_location, k_fold, rank, pretrained=True):
        super(SorghumModel, self).__init__()

        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=NUM_CLASSES)

        print('load Start!!!')
        in_features = self.model.classifier.in_features
        # Strip timm's classifier and global pool so the backbone returns the raw feature map;
        # the original global pool is reused below as one of the two pooling branches.
        self.model.classifier = nn.Identity()
        self.pooling = self.model.global_pool
        self.model.global_pool = nn.Identity()
        self.rank = rank
        # nn.ModuleList (not a plain list) so .train()/.eval() reach these layers.
        # Dropout has no parameters, so state_dict keys are unchanged.
        self.multiple_dropout = nn.ModuleList([nn.Dropout(0.25) for _ in range(8)])
        # Global-pool and max-pool features are concatenated, hence in_features * 2.
        self.embedding = nn.Linear(in_features * 2, embedding_size)
        self.fc = ArcMarginProduct(embedding_size,
                                   NUM_CLASSES,
                                   S,
                                   M,
                                   EASY_MERGING,
                                   LS_EPS,
                                   self.rank)

    def _pooled_features(self, images):
        """Concatenate the backbone's global-pool output and a global max-pool of the feature map."""
        features = self.model(images)
        pooled_avg = self.pooling(features).flatten(1)
        pooled_max = F.adaptive_max_pool2d(features, (1, 1)).flatten(1)
        return torch.cat((pooled_avg, pooled_max), dim=1)

    def forward(self, images, labels):
        """Training forward pass: multi-sample dropout -> embedding -> ArcFace logits."""
        pooled = self._pooled_features(images)
        # Multi-sample dropout: average the outputs of 8 independent dropout masks.
        # The float32 accumulator keeps the sum precise under autocast.
        dropped = torch.zeros(pooled.shape, device=pooled.device)
        for dropout in self.multiple_dropout:
            dropped += dropout(pooled)
        dropped /= len(self.multiple_dropout)
        embedding = self.embedding(dropped)
        return self.fc(embedding, labels)

    def extract(self, images):
        """Return the embedding (no dropout, no ArcFace head) used at inference time."""
        return self.embedding(self._pooled_features(images))

# Dataset

In [None]:
class Sorghum_Train_Dataset(Dataset):
    '''Training dataset yielding (image, label_index) pairs from ``root_in/train``.

    Rows come from ``df`` when it is given, otherwise from the CSV at
    ``img_path_csv``. Columns are read by position: column 0 is the image file
    name and column 4 the label index, so the CSV column order matters.
    '''
    def __init__(self, img_path_csv='', df=None, transform=None):
        self.df = pd.read_csv(img_path_csv) if df is None else df
        self.transform = transform

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, index):
        image = Img.open(os.path.join(root_in, 'train', self.df.iloc[index, 0]))
        target = self.df.iloc[index, 4]
        if self.transform is None:
            return image, target
        return self.transform(image), target

In [None]:
class Sorghum_Test_Dataset(Sorghum_Train_Dataset):
    '''Test dataset: same CSV handling as the training dataset, but images are
    read from ``root_in/test`` and only the (transformed) image is returned.'''
    def __getitem__(self, index):
        image_path = os.path.join(root_in, 'test', self.df.iloc[index, 0])
        image = Img.open(image_path)
        return self.transform(image) if self.transform else image

# Data pre-processing

In [None]:
def data_pre_access(file, output):
    '''Map every cultivar name to an integer index and save the table.

    Reads ``file`` (indexed by its ``image`` column), adds a ``label_index``
    column that numbers the cultivars in order of first appearance, and writes
    the result to ``output``. Rows with a missing cultivar get -1 (the previous
    loop silently left them at class 0).

    Returns:
        dict mapping label index (int) -> cultivar name.
    '''
    labels = pd.read_csv(file, index_col='image')
    # factorize numbers values in order of first appearance (the same order as
    # ``cultivar.unique()``) in one vectorised pass, instead of one .loc per cultivar.
    codes, cultivars = pd.factorize(labels['cultivar'])
    labels['label_index'] = codes.astype(np.int32)
    labels.to_csv(output)

    return dict(enumerate(cultivars))

# Check labels_map (maps index to cultivar)

##### If you already have labels_index.csv, set have_index=True

In [None]:
if have_index:
    # Labels were already mapped: rebuild index -> cultivar from the saved CSV.
    train_df = pd.read_csv(os.path.join(root_out,'labels_index.csv'), index_col='image')
    labels_map = dict(zip(map(int, train_df['label_index']), train_df['cultivar']))
else:
    # First run: build the mapping and write labels_index.csv.
    labels_map = data_pre_access(os.path.join(root_in,'train_cultivar_mapping.csv'),
                                 output=os.path.join(root_out,'labels_index.csv'))
    train_df = pd.read_csv(os.path.join(root_out,'labels_index.csv'), index_col='image')
num_classes = len(labels_map)

In [None]:
# Display the number of distinct cultivars found (presumably should equal NUM_CLASSES).
num_classes

In [None]:
# Count the labels_map entries confirmed by train_df: every row carrying a given
# label_index must have exactly the cultivar stored for that index.
check_sum = 0
for key, val in tqdm(labels_map.items()):
    cultivars_for_key = train_df[train_df.label_index == key].cultivar.unique()
    check_sum += int(len(cultivars_for_key) == 1 and cultivars_for_key[0] == val)

In [None]:
# Compare check_sum with the number of entries in labels_map.
check_sum == len(labels_map)

# Load model

##### Don't forget to enable the code below when running inference

In [None]:
#model_dict = torch.load(os.path.join(root_out,'tf_efficientnet_b5_ns_F_0/Sorghum17.params'))

In [None]:
# Strip the 'module.' prefix from checkpoint keys (presumably added by the DDP wrapper
# used during training) so they match a plain, non-DDP SorghumModel.
# Kept inside a string literal so it does not run; remove the quotes for inference.
'''
model_dict_noDist = {}
for key, value in model_dict.items():
    model_dict_noDist[key.split('module.')[-1]] = model_dict[key]
'''

In [None]:
# Build the model on rank 0 and load the prefix-stripped weights.
# Kept inside a string literal so it does not run; remove the quotes for inference.
# NOTE(review): pretrained=True presumably downloads backbone weights that
# load_state_dict immediately overwrites; pretrained=False may suffice — confirm.
'''
model = SorghumModel(MODEL_NAME, EMBEDDING_SIZE, map_location={'cuda:%d' % 0: 'cuda:%d' % 0}, k_fold=0, rank=0)
model.load_state_dict(model_dict_noDist)
'''

##### Here we keep the probabilities of all classes (acc@all) instead of only the index of the maximum (acc@1)

In [None]:
def predict_test_raw(net, test_iter, device=None):
    '''Run inference and return the full class-probability matrix.

    For each batch the model embedding (``net.extract``) is compared with the
    ArcFace class centres (``net.fc.weight``) by cosine similarity, scaled by
    ``S`` and passed through softmax. No angular margin is applied at inference.

    Args:
        net: model exposing ``extract(images)`` and an ArcFace head at ``net.fc``.
        test_iter: iterable of image batches (a tensor, or a list of tensors).
        device: device to run on; defaults to the device of the model's parameters.

    Returns:
        numpy array of shape (num_samples, num_classes) with softmax probabilities,
        rows in the order produced by ``test_iter`` (empty array if it yields nothing).
    '''
    if not device:
        device = next(iter(net.parameters())).device
    device = torch.device(device)
    net.to(device)
    net.eval()
    # Mixed precision only applies on CUDA; keep autocast off elsewhere.
    use_amp = device.type == 'cuda'
    batch_probs = []
    with torch.no_grad():
        for X in tqdm(test_iter):
            if isinstance(X, list):
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                embeddings = net.extract(X)
                logits = S * F.linear(F.normalize(embeddings), F.normalize(net.fc.weight))
            # Softmax in float32 so the probabilities are not rounded to fp16.
            batch_probs.append(torch.softmax(logits.float(), dim=1).cpu())
    if not batch_probs:
        return np.array([])
    # One concatenation instead of converting every row to its own numpy array.
    return torch.cat(batch_probs).numpy()

## Save inference result

In [None]:
def CLAHE_Convert(origin_input):
    '''Equalise the brightness channel of a PIL image with CLAHE.

    The image is converted to HSV, CLAHE (clip limit 40, 10x10 tiles) is
    applied to the last (V) channel only, and the result is converted back and
    returned as a PIL image. NOTE: the array from PIL is passed to OpenCV's
    *BGR* conversions; the same convention is used in both directions, so the
    channel order of the output matches the input. Assumes a 3-channel 8-bit
    image — TODO confirm behaviour for grayscale/RGBA inputs.
    '''
    hsv = cv.cvtColor(np.asarray(origin_input), cv.COLOR_BGR2HSV)
    equaliser = cv.createCLAHE(clipLimit=40, tileGridSize=(10, 10))
    hsv[:, :, -1] = equaliser.apply(hsv[:, :, -1])
    return Img.fromarray(cv.cvtColor(hsv, cv.COLOR_HSV2BGR))

In [None]:
# Training pipeline: CLAHE, resize, then random photometric and geometric augmentation,
# finishing with ImageNet normalisation for the pretrained backbone.
train_transform = transforms.Compose([
    CLAHE_Convert,
    transforms.Resize(IMAGE_SIZE),
    transforms.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomApply(
        transforms=[
            transforms.RandomResizedCrop(
                size=IMAGE_SIZE,
                scale=(0.3, 0.4),
                ratio=(1 / 3, 3),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
        ],
        p=0.2,
    ),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Validation/test pipeline: the same deterministic preprocessing, no augmentation.
val_test_transform = transforms.Compose([
    CLAHE_Convert,
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

In [None]:
# Un-augmented test set; shuffle=False keeps predictions in CSV row order.
sorghum_test_dataset = Sorghum_Test_Dataset(
    '../input/sorghum-jpeg-test-csv/test.csv',
    transform=val_test_transform,
)
sorghum_test_loader = DataLoader(
    sorghum_test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [None]:
'''Enable code below when you inference'''
# Predict on the un-augmented test set and cache the probabilities in root_out.
#result_raw_original = predict_test_raw(model, sorghum_test_loader, device='cuda')
#np.save(os.path.join(root_out, 'test_result_raw_Original.npy'), result_raw_original)

# TTA

## TTA Transforms

In [None]:
# TTA pipeline 0: small random rotation (0-45 degrees), shift and zoom-out.
tta_transform0 = transforms.Compose([
    CLAHE_Convert,
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomAffine(
        degrees=(0, 45),
        translate=(0.05, 0.1),
        scale=(0.95, 1),
    ),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# TTA pipeline 1: random flips plus mild colour jitter.
tta_transform1 = transforms.Compose([
    CLAHE_Convert,
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.01, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# TTA pipeline 2: occasional random crop. It fires with p=0.2, so most
# images in this pass are identical to the un-augmented input.
tta_transform2 = transforms.Compose([
    CLAHE_Convert,
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomApply(
        transforms=[
            transforms.RandomResizedCrop(
                size=IMAGE_SIZE,
                scale=(0.4, 0.5),
                ratio=(1 / 3, 3),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
        ],
        p=0.2,
    ),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

In [None]:
# The (disabled) TTA loop runs the full test set once through each of these pipelines.
tta_transforms = [tta_transform0, tta_transform1, tta_transform2]

## Save TTA result

In [None]:
'''Enable the last line when you inference.
    The first line is just used to show result'''
# Demo: load precomputed un-augmented probabilities shipped with a Kaggle dataset.
result_raw_original = np.load('../input/sorghum-jpeg-test-csv/test_result_raw_Original.npy')
# Real run: load the file written by the inference cell into root_out.
#result_raw_original = np.load(os.path.join(root_out, 'test_result_raw_Original.npy'))

In [None]:
# 'avg' must be its own array: the TTA loop adds into it in place (+=) and the
# averaging divides it in place (/=). Aliasing result_raw_original would silently
# modify the 'origin' entry as well.
result_raw_ttas = {'origin': result_raw_original, 'avg': result_raw_original.copy()}

In [None]:
# Enable the code below when running inference (remove the surrounding quotes).
# For each TTA pipeline: rebuild the test loader, predict, save the probabilities
# to root_out, store them under 'tta_<i>' and add them to the running 'avg' sum.
# NOTE(review): this reads os.path.join(root_in, 'test.csv'), while the
# un-augmented pass reads '../input/sorghum-jpeg-test-csv/test.csv' — confirm both
# list the test images in the same order.
'''for i in range(len(tta_transforms)):
    torch.cuda.empty_cache()
    sorghum_test_dataset = Sorghum_Test_Dataset(os.path.join(root_in, 'test.csv'), transform=tta_transforms[i])
    sorghum_test_loader = DataLoader(sorghum_test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    result_raw_tta = predict_test_raw(model, sorghum_test_loader, device='cuda')
    np.save(os.path.join(root_out, 'test_result_raw_' + 'tta_' + str(i) + '_.npy'), result_raw_tta)
    result_raw_ttas['tta_' + str(i)] = result_raw_tta
    result_raw_ttas['avg'] += result_raw_tta'''

In [None]:
# 'avg' holds the sum of every other stored prediction ('origin' plus each 'tta_i');
# divide by how many there are. NOTE: re-running this cell divides again.
n_predictions = sum(1 for name in result_raw_ttas if name != 'avg')
result_raw_ttas['avg'] /= n_predictions

# Inspect the result

## Sort the result

In [None]:
# Descending-sorted probabilities and their class indices, keyed by prediction-set name.
result_ttas_sorted_val, result_ttas_sorted_idx = {}, {}

In [None]:
# Show which prediction sets ('origin', 'avg', 'tta_<i>') are available.
result_raw_ttas.keys()

In [None]:
# Sort each prediction set's class probabilities in descending order. Done in NumPy:
# the previous hard-coded device='cuda' tensor crashed on CPU-only runs, and a GPU
# round-trip buys nothing for a single sort.
for key, val in result_raw_ttas.items():
    probs = np.asarray(val, dtype=np.float32)
    order = np.argsort(-probs, axis=1, kind='stable')
    result_ttas_sorted_idx[key] = order
    result_ttas_sorted_val[key] = np.take_along_axis(probs, order, axis=1)

## Visualize

In [None]:
# Empirical CDF of the top-3 probabilities per image (un-augmented predictions).
sns.displot(result_ttas_sorted_val.get('origin')[:,[0,1,2]], kind='ecdf')

In [None]:
# Empirical CDF of the top-3 probabilities per image (TTA-averaged predictions).
sns.displot(result_ttas_sorted_val.get('avg')[:,[0,1,2]], kind='ecdf')

In [None]:
# Sorted un-augmented probabilities as a DataFrame (column 0 = top-1, column 1 = top-2, ...).
result_original_df = pd.DataFrame(result_ttas_sorted_val['origin'])

In [None]:
# Summary statistics of the top-1 and top-2 probabilities (un-augmented).
result_original_df.iloc[:,[0,1]].describe(percentiles=[0.05, 0.25, 0.35, 0.45, 0.65, 0.75, 0.95])

In [None]:
# Sorted TTA-averaged probabilities as a DataFrame (column 0 = top-1, column 1 = top-2, ...).
result_tta_avg_df = pd.DataFrame(result_ttas_sorted_val['avg'])

In [None]:
# Summary statistics of the top-1 and top-2 probabilities (TTA-averaged).
result_tta_avg_df.iloc[:,[0,1]].describe(percentiles=[0.05, 0.25, 0.35, 0.45, 0.65, 0.75, 0.95])

## Find the confidence threshold (not used in this competition)

In [None]:
# Confidence threshold on the TTA-averaged top-1 probability: mean minus half a
# standard deviation, i.e. top-1 probabilities at or above μ - σ/2 count as trusted.
Threshold = result_tta_avg_df.iloc[:,0].mean() - result_tta_avg_df.iloc[:,0].std() / 2

In [None]:
# Display the computed threshold.
Threshold

## Predict and make submission

In [None]:
# (number of test images, number of classes) of the averaged predictions.
result_ttas_sorted_val['avg'].shape

In [None]:
# The submission uses the TTA-averaged predictions.
result_sorted_val, result_sorted_idx = result_ttas_sorted_val['avg'], result_ttas_sorted_idx['avg']

In [None]:
# Number of predicted test images.
len(result_sorted_val)

In [None]:
# Competition submission template; used to verify row order against test.csv.
sub_file =  pd.read_csv(os.path.join(root_in, 'sample_submission.csv'))

In [None]:
'''check the order between test.csv.image and sample_submission.csv.filename'''
result = pd.read_csv('../input/sorghum-jpeg-test-csv/test.csv')
# Predictions follow test.csv row order but the submission is written in
# sample_submission order, so the two file lists (ignoring extensions) must agree
# row by row. Stop here rather than write a silently misaligned submission.
order_matches = (result.image.map(lambda x: x.split('.jpeg')[0])
                 == sub_file.filename.map(lambda x: x.split('.png')[0]))
assert order_matches.all(), f'{(~order_matches).sum()} rows differ between test.csv and sample_submission.csv'
order_matches.sum()

In [None]:
# Build submission.csv from the template: top-1 class index per image -> cultivar name.
result = pd.read_csv(os.path.join(root_in, 'sample_submission.csv'))
# Index labels_map directly (not .get) so an unknown index fails loudly instead of
# writing an empty cultivar into the submission.
result['cultivar'] = [labels_map[int(idx)] for idx in result_sorted_idx[:, 0]]
result = result.set_index('filename')
result.to_csv(os.path.join(root_out, 'submission.csv'))