<br>
<h2 style = "font-size:60px; font-family:Garamond ; font-weight : normal; background-color: #f6f5f5 ; color : #fe346e; text-align: center; border-radius: 100px 100px;">FAISS Pytorch Inference</h2>
<br>

<h3>📌 Inference Pipeline is taken from this Notebook:</h3> <h4><a href='https://www.kaggle.com/ks2019/happywhale-arcface-baseline-tpu'>https://www.kaggle.com/ks2019/happywhale-arcface-baseline-tpu</a></h4>

<h3>📌 Train Notebook:</h3> <h4><a href='https://www.kaggle.com/debarshichanda/pytorch-arcface-gem-pooling-starter'>https://www.kaggle.com/debarshichanda/pytorch-arcface-gem-pooling-starter</a></h4>

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Install Required Libraries</h1></span>

In [None]:
# Packages this notebook installs before importing them.
# %pip (unlike !pip) always installs into the environment of the running kernel.
%pip install timm
%pip install faiss-gpu

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Import Required Libraries 📚</h1></span>

In [None]:
import os
import gc
import cv2
import math
import copy
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder, normalize
from sklearn.model_selection import StratifiedKFold

# For Image Models
import timm

# For Similarity Search
import faiss

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
y_ = Fore.YELLOW
sr_ = Style.RESET_ALL

import warnings
warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Configuration ⚙️</h1></span>

In [None]:
# Central settings read by the seed helper, transforms, model and loaders below.
CONFIG = dict(
    # Reproducibility and input resolution
    seed=2022,
    img_size=448,
    # Backbone and embedding head
    model_name="tf_efficientnet_b0_ns",
    num_classes=15587,
    embedding_size=512,
    # Batch sizes and number of folds
    train_batch_size=64,
    valid_batch_size=64,
    n_fold=5,
    # First GPU when CUDA is available, CPU otherwise
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    # ArcFace Hyperparameters
    s=30.0,
    m=0.30,
    ls_eps=0.0,
    easy_margin=False,
)

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Set Seed for Reproducibility</h1></span>

In [None]:
def set_seed(seed=42):
    '''Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy and PyTorch (CPU and CUDA), switches cuDNN
    to deterministic mode and exports PYTHONHASHSEED.

    Parameters
    ----------
    seed : int
        Value applied to all generators.
    '''
    # Python's own generator was imported but previously left unseeded
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Exported for processes started from here; the running interpreter has
    # already fixed its own hash seed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    
# Apply the configured seed before any fold assignment or model creation below
set_seed(CONFIG['seed'])

In [None]:
# Competition data locations; the image folders live directly under ROOT_DIR
ROOT_DIR = '../input/happy-whale-and-dolphin'
TRAIN_DIR = f'{ROOT_DIR}/train_images'
TEST_DIR = f'{ROOT_DIR}/test_images'

In [None]:
def get_train_file_path(id):
    """Return the path of the training image named `id` inside TRAIN_DIR."""
    return "{}/{}".format(TRAIN_DIR, id)

# <h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Read the Data 📖</h1>

In [None]:
# Load the training annotations and attach the on-disk path of every image
df = pd.read_csv(f"{ROOT_DIR}/train.csv")
df['file_path'] = [get_train_file_path(image_name) for image_name in df['image']]
df.head()

In [None]:
# Load the saved label encoder shipped alongside the checkpoint
# (presumably the one fitted during training - verify against the train notebook).
# NOTE: joblib.load unpickles the file, which can execute arbitrary code;
# only load artifacts from a trusted source.
with open("../input/arcface-gap-embed/le.pkl", "rb") as fp:
    encoder = joblib.load(fp)

# Replaces the string ids in place: re-running this cell on the already
# encoded column will not work, reload `df` first.
df['individual_id'] = encoder.transform(df['individual_id'])

In [None]:
# Tag every row with the index of the fold in which it is held out.
# No shuffle/random_state is passed to the splitter.
skf = StratifiedKFold(n_splits=CONFIG['n_fold'])

for fold, (_, val_) in enumerate(skf.split(df, df.individual_id)):
    # val_ holds the positional row indices of this fold's held-out part
    df.loc[val_, "kfold"] = fold

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Dataset Class</h1></span>

In [None]:
class HappyWhaleDataset(Dataset):
    """Image dataset backed by a DataFrame of file paths and labels.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns 'image', 'file_path' and 'individual_id'.
    transforms : callable, optional
        Called as ``transforms(image=img)``; its ``"image"`` entry replaces
        the raw image.
    """
    def __init__(self, df, transforms=None):
        self.df = df
        self.ids = df['image'].values
        self.file_names = df['file_path'].values
        self.labels = df['individual_id'].values
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        """Return a dict with the (transformed) image, its label and its id.

        Raises
        ------
        FileNotFoundError
            If the image file is missing or cannot be decoded.
        """
        idx = self.ids[index]
        img_path = self.file_names[index]
        img = cv2.imread(img_path)
        # imread signals failure by returning None instead of raising
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV loads BGR; convert to RGB before the transforms run
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = self.labels[index]
        
        if self.transforms:
            img = self.transforms(image=img)["image"]
            
        return {
            'image': img,
            'label': torch.tensor(label, dtype=torch.long),
            'id': idx
        }

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Augmentations</h1></span>

In [None]:
# Both splits use the same pipeline: resize to the configured square size,
# normalise each channel with the given mean/std, convert to a tensor.
# Every split gets its own Compose instance.
data_transforms = {
    split: A.Compose(
        [
            A.Resize(CONFIG['img_size'], CONFIG['img_size']),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0,
                p=1.0,
            ),
            ToTensorV2(),
        ],
        p=1.,
    )
    for split in ("train", "valid")
}

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">GeM Pooling</h1></span>

In [None]:
# NOT USED: HappyWhaleModel below never instantiates GeM; the class is kept for reference only.

class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions of the input.

    Computes ``mean(clamp(x, eps) ** p) ** (1 / p)`` with one window covering
    the whole of the last two dimensions. `p` is a learnable parameter.
    """
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Pool `x` with exponent `p`, clamping values below `eps` first."""
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        pooled = F.avg_pool2d(powered, window)
        return pooled.pow(1. / p)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps!s})"

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">ArcFace</h1></span>

In [None]:
class ArcMarginProduct(nn.Module):
    r"""Additive angular margin (ArcFace) head: logits use cos(theta + m) for the true class.

        Args:
            in_features: size of each input sample
            out_features: size of each output sample (number of classes)
            s: scale factor multiplied onto the final logits
            m: angular margin, in radians
            easy_margin: if True, apply the margin only where cos(theta) > 0
            ls_eps: label-smoothing factor applied to the one-hot targets (0 disables it)
        """
    def __init__(self, in_features, out_features, s=30.0, 
                 m=0.50, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        # torch.empty replaces the legacy FloatTensor constructor; values are set by the init below
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Threshold and fallback offset used when theta + m would exceed pi
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return scaled logits of shape (batch, out_features) for `input` and integer `label`."""
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Rounding can push cosine**2 marginally above 1; clamp so sqrt never yields NaN
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        # Allocate on the logits' own device instead of relying on a global setting
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # ------------- margin logit for the target class, plain cosine elsewhere ------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s

        return output

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Model</h1></span>

In [None]:
class HappyWhaleModel(nn.Module):
    """timm backbone + linear embedding layer + ArcFace classification head.

    Parameters
    ----------
    model_name : str
        Name passed to ``timm.create_model``; the created model must expose
        a ``classifier`` attribute with ``in_features``.
    embedding_size : int
        Output size of the embedding layer.
    pretrained : bool
        Forwarded to ``timm.create_model``.
    """
    def __init__(self, model_name, embedding_size, pretrained=True):
        super(HappyWhaleModel, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        in_features = self.model.classifier.in_features
        # Drop the original classifier so the backbone returns its features
        self.model.classifier = nn.Identity()
        self.embedding = nn.Linear(in_features, embedding_size)
        self.fc = ArcMarginProduct(embedding_size, 
                                   CONFIG["num_classes"],
                                   s=CONFIG["s"], 
                                   m=CONFIG["m"], 
                                   # was wired to CONFIG["ls_eps"] by mistake
                                   easy_margin=CONFIG["easy_margin"], 
                                   ls_eps=CONFIG["ls_eps"])
    
    def forward(self, images, labels):
        """Return the ArcFace logits for `images` given their `labels`."""
        embedding = self.extract(images)
        output = self.fc(embedding, labels)
        return output
    
    def extract(self, images):
        """Return the embedding of `images` (no ArcFace head applied)."""
        features = self.model(images)
        embedding = self.embedding(features)
        return embedding
    

# Skip the ImageNet weight download: load_state_dict (strict by default)
# replaces every parameter with the fine-tuned checkpoint right below.
model = HappyWhaleModel(CONFIG['model_name'], CONFIG['embedding_size'], pretrained=False)
# map_location remaps the stored tensors to the configured device, so the
# checkpoint can be loaded on a machine without CUDA as well.
model.load_state_dict(torch.load("../input/arcface-gap-embed/Loss14.0082_epoch10.bin",
                                 map_location=CONFIG['device']))
model.to(CONFIG['device']);

In [None]:
@torch.inference_mode()
def get_embeddings(model, dataloader, device):
    """Run `model.extract` over every batch and collect the results.

    The model is switched to eval mode. Each batch must be a dict with the
    keys 'image', 'label' and 'id'.

    Returns
    -------
    EMBEDS : numpy.ndarray
        Embeddings of all batches stacked row-wise.
    LABELS : numpy.ndarray
        Labels of all batches, concatenated.
    IDS : numpy.ndarray
        Ids of all batches, concatenated.
    """
    model.eval()
    
    LABELS, EMBEDS, IDS = [], [], []
    
    for data in tqdm(dataloader, total=len(dataloader)):
        images = data['image'].to(device, dtype=torch.float)
        outputs = model.extract(images)
        
        # Labels never feed the model, so they are not sent to the device
        LABELS.append(data['label'].to(dtype=torch.long).cpu().numpy())
        EMBEDS.append(outputs.cpu().numpy())
        IDS.append(data['id'])
    
    EMBEDS = np.vstack(EMBEDS)
    LABELS = np.concatenate(LABELS)
    IDS = np.concatenate(IDS)
    
    return EMBEDS, LABELS, IDS

In [None]:
def prepare_loaders(df, fold):
    """Build unshuffled train/valid loaders for one fold.

    Rows with ``kfold == fold`` form the validation set, all other rows the
    training set; both frames are re-indexed from zero.

    Returns
    -------
    (train_loader, valid_loader)
    """
    def _loader(frame, split, batch_key):
        # Order is preserved (shuffle=False)
        dataset = HappyWhaleDataset(frame.reset_index(drop=True),
                                    transforms=data_transforms[split])
        return DataLoader(dataset, batch_size=CONFIG[batch_key],
                          num_workers=2, shuffle=False, pin_memory=True)

    train_loader = _loader(df[df.kfold != fold], "train", 'train_batch_size')
    valid_loader = _loader(df[df.kfold == fold], "valid", 'valid_batch_size')
    return train_loader, valid_loader

In [None]:
# Hold out fold 0 for validation; all other folds supply the vectors indexed below
train_loader, valid_loader = prepare_loaders(df, fold=0)

In [None]:
# Embed both splits with the loaded model (labels and image ids come back alongside)
train_embeds, train_labels, train_ids = get_embeddings(model, train_loader, CONFIG['device'])
valid_embeds, valid_labels, valid_ids = get_embeddings(model, valid_loader, CONFIG['device'])

In [None]:
# Scale every row to unit L2 norm, so the inner products computed by the
# index below are cosine similarities
train_embeds = normalize(train_embeds, axis=1, norm='l2')
valid_embeds = normalize(valid_embeds, axis=1, norm='l2')

In [None]:
# Map the integer class ids back to the original individual_id values.
# NOTE: rebinds the same names - re-running this cell on already decoded
# labels is not expected to work.
train_labels = encoder.inverse_transform(train_labels)
valid_labels = encoder.inverse_transform(valid_labels)

In [None]:
# Flat inner-product FAISS index over the training embeddings
index = faiss.IndexFlatIP(CONFIG['embedding_size'])
index.add(train_embeds)

In [None]:
# For each validation embedding: D holds the scores of its 50 best matches,
# I the row positions of those matches in train_embeds / train_labels
D, I = index.search(valid_embeds, k=50)

In [None]:
# Individuals that occur in the training part of this fold
allowed_targets = np.unique(train_labels)

In [None]:
# Ground truth per validation image. Individuals that do not occur in the
# training part of the fold cannot be retrieved, so their expected answer
# becomes 'new_individual'.
val_targets_df = pd.DataFrame(np.stack([valid_ids, valid_labels], axis=1), columns=['image','target'])
val_targets_df.loc[~val_targets_df.target.isin(allowed_targets), 'target'] = 'new_individual'
val_targets_df.target.value_counts()

In [None]:
# One small frame per validation image, listing its retrieved neighbours
valid_df = []
for i, val_id in tqdm(enumerate(valid_ids), total=len(valid_ids)):
    targets = train_labels[I[i]]
    distances = D[i]
    # Build from a dict so each column keeps its own dtype; np.stack would
    # coerce the string labels and float scores to one common dtype.
    subset_preds = pd.DataFrame({'target': targets, 'distances': distances})
    subset_preds['image'] = val_id
    valid_df.append(subset_preds)

In [None]:
# Merge the per-image frames and keep, for every (image, target) pair,
# the highest score among its neighbours
valid_df = (
    pd.concat(valid_df)
    .reset_index(drop=True)
    .groupby(['image', 'target'])
    .distances.max()
    .reset_index()
)
valid_df.head()

In [None]:
# Best matches first: get_predictions treats the first row it meets for an
# image as that image's top candidate
valid_df = valid_df.sort_values('distances', ascending=False).reset_index(drop=True)
# Saved with the default index column
valid_df.to_csv('val_neighbors.csv')

In [None]:
# Fallback ids that get_predictions uses to pad lists shorter than five entries
sample_list = [
    '938b7e931166',
    '5bf17305f073',
    '7593d2aee842',
    '7362d7a01d00',
    '956562ff2888',
]

In [None]:
def get_predictions(test_df, threshold=0.2):
    """Turn a neighbour table into up to five ranked predictions per image.

    Parameters
    ----------
    test_df : pandas.DataFrame
        Columns 'image', 'target' and 'distances'; rows are consumed in
        order, so the best match of an image must come first.
    threshold : float
        If the first score seen for an image is above this value the match
        is ranked first and 'new_individual' second, otherwise the other
        way round.

    Returns
    -------
    dict
        image -> list of predictions, padded from `sample_list` and cut to 5.
    """
    predictions = {}
    for row in tqdm(test_df.itertuples(index=False), total=len(test_df)):
        ranked = predictions.get(row.image)
        if ranked is None:
            if row.distances > threshold:
                predictions[row.image] = [row.target, 'new_individual']
            else:
                predictions[row.image] = ['new_individual', row.target]
        elif len(ranked) < 5:
            ranked.append(row.target)

    for x in tqdm(predictions):
        ranked = predictions[x]
        if len(ranked) < 5:
            # Pad only with ids this image does not list yet. The previous
            # check looked at the dict of image names, so it never filtered
            # anything and could emit the same id twice.
            remaining = [y for y in sample_list if y not in ranked]
            predictions[x] = (ranked + remaining)[:5]
        
    return predictions

In [None]:
def map_per_image(label, predictions):
    """Score one image by the reciprocal rank of its true label.

    Parameters
    ----------
    label : string
            The true label of the image
    predictions : list
            Ranked predictions; only the first five are considered

    Returns
    -------
    score : double
            1 / rank of `label` within the first five predictions, 0.0 if absent
    """
    top_five = predictions[:5]
    try:
        rank = top_five.index(label) + 1
    except ValueError:
        # label is not among the first five predictions
        return 0.0
    return 1 / rank

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Compute CV</h1></span>

In [None]:
# Sweep the threshold from 0.0 to 1.0 and keep the value with the best mean score
best_th = 0
best_cv = 0
for th in [0.1*x for x in range(11)]:
    all_preds = get_predictions(valid_df, threshold=th)
    # Per-image scores go into a column named after the threshold value;
    # written in one assignment instead of one .loc write per row
    val_targets_df[th] = [
        map_per_image(target, all_preds[image])
        for image, target in zip(val_targets_df.image, val_targets_df.target)
    ]
    cv = val_targets_df[th].mean()
    print(f"CV at threshold {th}: {cv}")
    if cv > best_cv:
        best_th = th
        best_cv = cv

In [None]:
# Report the winner of the sweep, then summarise the per-threshold score columns
print("Best threshold " + str(best_th))
print("Best cv " + str(best_cv))
val_targets_df.describe()

In [None]:
## Adjustment: Since Public lb has nearly 10% 'new_individual' (Be Careful for private LB)
val_targets_df['is_new_individual'] = val_targets_df.target=='new_individual'
print(val_targets_df.is_new_individual.value_counts().to_dict())
# numeric_only skips the string columns ('image', 'target'); without it
# pandas >= 2.0 raises on them instead of silently dropping them
val_scores = val_targets_df.groupby('is_new_individual').mean(numeric_only=True).T
# Re-weight the two groups to the assumed 10% / 90% split
val_scores['adjusted_cv'] = val_scores[True]*0.1+val_scores[False]*0.9
best_threshold_adjusted = val_scores['adjusted_cv'].idxmax()
print("best_threshold",best_threshold_adjusted)
val_scores

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spacing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Inference</h1></span>

In [None]:
# For test-time retrieval the validation fold joins the indexed set.
# NOTE: these names are rebound in place - running this cell twice appends
# the validation rows a second time.
train_embeds = np.concatenate([train_embeds, valid_embeds])
train_labels = np.concatenate([train_labels, valid_labels])
print(train_embeds.shape,train_labels.shape)

In [None]:
# Rebuild the inner-product index, now over train + validation embeddings
index = faiss.IndexFlatIP(CONFIG['embedding_size'])
index.add(train_embeds)

In [None]:
# One row per test image. The listing is sorted because os.listdir returns
# entries in arbitrary order; TEST_DIR replaces the duplicated literal path.
test = pd.DataFrame({"image": sorted(os.listdir(TEST_DIR))})
test["file_path"] = test["image"].apply(lambda x: f"{TEST_DIR}/{x}")
test["individual_id"] = -1  #dummy value
test.head()

In [None]:
# Test images go through the "valid" transforms; order is kept (shuffle=False)
test_dataset = HappyWhaleDataset(test, transforms=data_transforms["valid"])
test_loader = DataLoader(test_dataset, batch_size=CONFIG['valid_batch_size'], 
                         num_workers=2, shuffle=False, pin_memory=True)

In [None]:
# The returned labels are the dummy -1 values and are discarded
test_embeds, _, test_ids = get_embeddings(model, test_loader, CONFIG['device'])
# Same unit-norm scaling as applied to the indexed embeddings
test_embeds = normalize(test_embeds, axis=1, norm='l2')

In [None]:
# 50 best matches per test embedding: D = scores, I = row positions in train_labels
D, I = index.search(test_embeds, k=50)

In [None]:
# Long-format neighbour table, one row per (test image, neighbour), built in
# a single step. Columns are passed separately so labels and scores keep
# their own dtypes (np.stack would coerce them to one common dtype).
n_neighbors = I.shape[1]
test_df = pd.DataFrame({
    'target': train_labels[I].ravel(),
    'distances': D.ravel(),
    # repeat each id once per neighbour, matching the row-major ravel above
    'image': np.repeat(np.asarray(test_ids), n_neighbors),
})

# Keep the best score per (image, target) pair, best matches first
test_df = test_df.groupby(['image','target']).distances.max().reset_index()
test_df = test_df.sort_values('distances', ascending=False).reset_index(drop=True)
test_df.to_csv('test_neighbors.csv')

In [None]:
# Rank candidates with the threshold chosen on the validation fold
predictions = get_predictions(test_df, best_threshold_adjusted)

# Two-column submission table: image name and its space-separated predictions
predictions = pd.DataFrame({
    'image': list(predictions.keys()),
    'predictions': [' '.join(ranked) for ranked in predictions.values()],
})
predictions.to_csv('submission.csv',index=False)
predictions.head()