In [None]:
import gc
import os
import cv2
import sys
import json
import time
import timm
import torch
import random
import sklearn.metrics
import matplotlib.pyplot as plt

from PIL import Image
from pathlib import Path
from functools import partial
from contextlib import contextmanager

import numpy as np
import scipy as sp
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch import linalg
from pytorch_metric_learning.losses import ArcFaceLoss

from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2

from tqdm.notebook import tqdm

# Enumerate GPUs in PCI bus order so that the index below is stable.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# Expose only the GPU with index 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"]="1"
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# Training-split metadata; the printed value is its row count.
train_metadata = pd.read_csv("/DATA/asaginbaev/DF20M/DF20M-train_metadata_PROD.csv")
print(len(train_metadata))

# Public test-split metadata (used as the validation set further down).
test_metadata = pd.read_csv("/DATA/asaginbaev/DF20M/DF20M-public_test_metadata_PROD.csv")
print(len(test_metadata))

In [None]:
# Directory that holds the image files referenced by the metadata.
IMAGE_DIR = '/DATA/asaginbaev/DF20M/DF20M/'


def _resolve_image_path(path, image_dir=IMAGE_DIR):
    """Map a metadata ``image_path`` entry to the on-disk ``.JPG`` file.

    The extension is replaced by ``.JPG`` and ``image_dir`` is prepended.
    A path that already starts with ``image_dir`` has that prefix removed
    first, so re-running this cell does not stack the prefix twice.
    """
    if path.startswith(image_dir):
        path = path[len(image_dir):]
    # splitext drops only the final extension; splitting on the first '.'
    # would truncate any name that contains another dot.
    stem, _extension = os.path.splitext(path)
    return image_dir + stem + '.JPG'


train_metadata['image_path'] = train_metadata['image_path'].map(_resolve_image_path)
test_metadata['image_path'] = test_metadata['image_path'].map(_resolve_image_path)

train_metadata.head()

### Обычный датасет

In [None]:
# Number of distinct class ids present in the training metadata.
N_CLASSES = len(train_metadata['class_id'].unique())

class TrainDataset(Dataset):
    """Map-style dataset yielding ``(image, class_id)`` pairs.

    Args:
        df: DataFrame with an ``image_path`` column (paths readable by
            ``cv2.imread``) and a ``class_id`` column.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the metadata frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at row ``idx`` as RGB and apply the transform.

        Returns:
            ``(image, label)`` where ``image`` is the (optionally transformed)
            RGB image and ``label`` is the row's ``class_id``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        file_path = self.df['image_path'].values[idx]
        label = self.df['class_id'].values[idx]

        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than later in the
        # transform or the batch collation.
        image = cv2.imread(file_path)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        return image, label

### Датасет для triplet-loss обучения

С вероятностью 0.25 в качестве негативного примера выбирается простой, т. е. из произвольного другого класса; с вероятностью 0.75 — тяжёлый, т. е. из того же рода, но другого вида.

In [None]:
# NOTE(review): repeats the N_CLASSES assignment from the plain-dataset cell above.
N_CLASSES = len(train_metadata['class_id'].unique())

class TLTrainDataset(Dataset):
    """Map-style dataset yielding ``(anchor, positive, negative, class_id)``.

    For the anchor at row ``idx``:

    * the positive is another row with the same ``class_id``;
    * the negative is, with probability 0.75, a "hard" one (same ``genus``,
      different ``class_id``) and otherwise any row of a different class.

    Args:
        df: DataFrame with ``image_path``, ``class_id`` and ``genus`` columns.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a mapping with an ``'image'`` entry. It is applied
            independently to each of the three images.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the metadata frame."""
        return len(self.df)

    @staticmethod
    def _read_rgb(file_path):
        """Read ``file_path`` with OpenCV and return it in RGB channel order."""
        # cv2.imread reports failure by returning None instead of raising.
        image = cv2.imread(file_path)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Sample a triplet for the anchor at row ``idx``.

        Raises:
            FileNotFoundError: if any of the three images cannot be read.
            ValueError: if no row has a different ``class_id`` than the anchor
                (raised by ``np.random.choice`` on an empty pool).
        """
        paths = self.df['image_path'].values
        labels = self.df['class_id'].values
        genera = self.df['genus'].values
        label = labels[idx]
        genus = genera[idx]

        other_class = labels != label
        is_hard_negative = np.random.choice([False, True], p=[0.25, 0.75])
        if is_hard_negative:
            negative_pool = np.flatnonzero(other_class & (genera == genus))
            # A genus with a single species has no hard negatives; fall back
            # to an easy negative instead of crashing on an empty pool.
            if negative_pool.size == 0:
                negative_pool = np.flatnonzero(other_class)
        else:
            negative_pool = np.flatnonzero(other_class)
        negative_idx = np.random.choice(negative_pool)

        same_class = np.flatnonzero(labels == label)
        positive_pool = same_class[same_class != idx]
        # A class with a single image has no other positive; reuse the anchor
        # (the transform is still applied to it separately).
        positive_idx = np.random.choice(positive_pool) if positive_pool.size else idx

        image_anchor = self._read_rgb(paths[idx])
        image_positive = self._read_rgb(paths[positive_idx])
        image_negative = self._read_rgb(paths[negative_idx])

        if self.transform:
            image_anchor = self.transform(image=image_anchor)['image']
            image_positive = self.transform(image=image_positive)['image']
            image_negative = self.transform(image=image_negative)['image']

        return image_anchor, image_positive, image_negative, label

Additive-margin-softmax лосс, реализация из https://github.com/Leethony/Additive-Margin-Softmax-Loss-Pytorch

In [None]:
class AdMSoftmaxLoss(nn.Module):
    """Additive-margin softmax (AM-Softmax) loss with a learnable class matrix.

    The logit of class ``j`` is ``s * cos(theta_j)``, where ``theta_j`` is the
    angle between the L2-normalised input and the L2-normalised weight row of
    class ``j``; the target class additionally has the margin ``m`` subtracted
    from its cosine. The result is the mean cross-entropy over the batch.

    Args:
        in_features: Size of each (flattened) input embedding.
        out_features: Number of classes.
        s: Scale applied to the cosines.
        m: Additive margin subtracted from the target-class cosine.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.4):
        super(AdMSoftmaxLoss, self).__init__()
        self.s = s
        self.m = m
        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x, labels):
        """Return the mean AM-Softmax loss.

        Args:
            x: Tensor of shape ``(N, ...)``; everything after the batch
                dimension is flattened to ``in_features``.
            labels: Integer class indices of shape ``(N,)`` with values in
                ``[0, out_features)``.
        """
        x = torch.flatten(x, start_dim=1)
        assert len(x) == len(labels)
        assert torch.min(labels) >= 0
        assert torch.max(labels) < self.out_features

        labels = labels.long()

        # Normalise the weights functionally on every call. Re-binding the
        # loop variable of ``self.fc.parameters()`` does not touch the stored
        # weight, so the logits would not be cosines.
        cosine = F.linear(F.normalize(x, dim=1), F.normalize(self.fc.weight, dim=1))

        # Subtract the margin from the target-class cosine only.
        margin = F.one_hot(labels, num_classes=self.out_features).to(cosine.dtype) * self.m
        logits = self.s * (cosine - margin)

        # cross_entropy computes -log(exp(target) / sum(exp(all))) via a
        # stable log-sum-exp, which is exactly the AM-Softmax objective.
        return F.cross_entropy(logits, labels)

In [None]:
def getModel(architecture_name, target_size, pretrained = False):
    """Create a timm model whose classifier layer outputs ``target_size`` values.

    The attribute name of the classifier is read from the model's
    ``default_cfg['classifier']`` entry; that layer is replaced by a new
    ``nn.Linear`` with the same number of input features.
    """
    backbone = timm.create_model(architecture_name, pretrained=pretrained)
    head_name = backbone.default_cfg['classifier']
    head_inputs = getattr(backbone, head_name).in_features
    new_head = nn.Linear(head_inputs, target_size)
    setattr(backbone, head_name, new_head)
    return backbone

In [None]:
# %%
# Build the pretrained ViT with an N_CLASSES-way classifier and keep the
# normalisation statistics from its timm config; get_transforms() feeds them
# to Normalize.
MODEL_NAME = 'vit_base_patch16_224'
model = getModel(MODEL_NAME, N_CLASSES, pretrained=True)
model_mean = list(model.default_cfg['mean'])
model_std = list(model.default_cfg['std'])

In [None]:
# Load the weights onto the CPU first so that a checkpoint saved from a GPU
# run can also be restored on a machine without that device; the model is
# moved to `device` by the cells that use it.
model.load_state_dict(torch.load('DF20M-ViT_base_patch16_224_best_accuracy.pth', map_location='cpu'))

In [None]:
# Target image size in pixels, used by the crop/resize transforms in get_transforms().
WIDTH, HEIGHT = 224, 224

from albumentations import RandomCrop, HorizontalFlip, VerticalFlip, RandomBrightnessContrast, CenterCrop, PadIfNeeded, RandomResizedCrop

def get_transforms(*, data):
    """Return the albumentations pipeline for the given split.

    Args:
        data: ``'train'`` for the augmenting pipeline (random resized crop,
            flips, brightness/contrast jitter) or ``'valid'`` for a plain
            resize. Both end with normalisation by ``model_mean`` /
            ``model_std`` and conversion to a tensor.

    Raises:
        AssertionError: if ``data`` is neither ``'train'`` nor ``'valid'``.
    """
    assert data in ('train', 'valid')

    if data == 'train':
        return Compose([
            # albumentations takes (height, width); pass them by keyword so a
            # non-square target size is not silently transposed.
            RandomResizedCrop(height=HEIGHT, width=WIDTH, scale=(0.8, 1.0)),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            RandomBrightnessContrast(p=0.2),
            Normalize(mean=model_mean, std=model_std),
            ToTensorV2(),
        ])

    elif data == 'valid':
        return Compose([
            Resize(height=HEIGHT, width=WIDTH),
            Normalize(mean=model_mean, std=model_std),
            ToTensorV2(),
        ])

In [None]:
# Triplet-loss variant of the datasets: each item is (anchor, positive, negative, label).
# NOTE(review): the next cell rebinds train_dataset / valid_dataset to TrainDataset,
# so only one of these two cells should be run before building the loaders.
train_dataset = TLTrainDataset(train_metadata, transform=get_transforms(data='train'))
valid_dataset = TLTrainDataset(test_metadata, transform=get_transforms(data='valid'))

In [None]:
# Plain (image, label) datasets; the public test split serves as the validation set.
# NOTE(review): rebinds the names assigned in the triplet-dataset cell above.
train_dataset = TrainDataset(train_metadata, transform=get_transforms(data='train'))
valid_dataset = TrainDataset(test_metadata, transform=get_transforms(data='valid'))

In [None]:
# Hyper-parameters shared by the training cells below.
BATCH_SIZE, EPOCHS, WORKERS = 32, 100, 4

# Only the training loader shuffles; validation order is kept fixed because
# the evaluation code writes predictions by batch position.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=WORKERS,
    shuffle=True,
)
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    num_workers=WORKERS,
    shuffle=False,
)

In [None]:
import wandb

### Обучение модели с ArcFace лоссом

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score

# Fine-tune the network with an ArcFace loss on the flattened output of
# model.forward_features. The loss has its own learnable parameters, so it
# gets its own optimizer and scheduler. Validation metrics are computed from
# the regular model(images) output.

wandb.init(project="danish-fungi")

n_epochs = EPOCHS
lr = 1e-3

# torch.save does not create missing directories; create the target up front
# so the first checkpoint (after ten epochs) cannot fail on a missing folder.
os.makedirs('checkpoints', exist_ok=True)

model.to(device)

optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=1, verbose=True, eps=1e-6)

criterion_classification = nn.CrossEntropyLoss()
# 197 * 768 is the flattened feature size passed to the loss below
# (presumably 197 tokens x 768 channels for vit_base_patch16_224 -- adjust if
# MODEL_NAME changes). The class count comes from the data instead of a
# hard-coded 182.
criterion_metrics = ArcFaceLoss(N_CLASSES, 197 * 768, margin=5.7, scale=8).to(device)

optimizer_criterion = SGD(criterion_metrics.parameters(), lr=lr, momentum=0.9)
scheduler_criterion = ReduceLROnPlateau(optimizer_criterion, 'min', factor=0.9, patience=1, verbose=True, eps=1e-6)

for epoch in range(n_epochs):

    start_time = time.time()

    # ---- training ----
    model.train()
    avg_loss = 0.

    optimizer.zero_grad()
    optimizer_criterion.zero_grad()

    for i, (images, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):

        images = images.to(device)
        labels = labels.to(device)

        y_features = model.forward_features(images)
        loss_metrics = criterion_metrics(torch.flatten(y_features, start_dim=1), labels)

        loss = loss_metrics

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        optimizer_criterion.step()
        optimizer_criterion.zero_grad()
        wandb.log({
                   "train/arcface loss": loss_metrics.item(),
                   "train/loss total": loss.item(),
                   "epoch": epoch,})

        avg_loss += loss.item() / len(train_loader)
    wandb.log({"train/avg loss": avg_loss,
                "epoch": epoch,})

    # ---- validation ----
    model.eval()
    avg_val_loss = 0.
    preds = np.zeros((len(valid_dataset)))
    preds_raw = []

    for i, (images, labels) in tqdm(enumerate(valid_loader), total=len(valid_loader)):

        images = images.to(device)
        labels = labels.to(device)

        # The loss is evaluated under no_grad as well, so no autograd graph
        # is built for the criterion's own weights during validation.
        with torch.no_grad():
            y_preds = model(images)
            y_features = model.forward_features(images)
            loss_metrics = criterion_metrics(torch.flatten(y_features, start_dim=1), labels)

        preds[i * BATCH_SIZE: (i+1) * BATCH_SIZE] = y_preds.argmax(1).to('cpu').numpy()
        preds_raw.extend(y_preds.to('cpu').numpy())

        loss = loss_metrics
        wandb.log({
                   "val/arcface loss": loss_metrics.item(),
                   "val/loss total": loss.item(),
                   "epoch": epoch,})
        avg_val_loss += loss.item() / len(valid_loader)
    wandb.log({"val/avg loss": avg_val_loss,
                "epoch": epoch,})
    scheduler.step(avg_val_loss)
    scheduler_criterion.step(avg_val_loss)

    score = f1_score(test_metadata['class_id'], preds, average='macro')
    accuracy = accuracy_score(test_metadata['class_id'], preds)
    # Pass the full label set explicitly: without it scikit-learn rejects the
    # score matrix whenever some class is absent from the validation split.
    recall_3 = top_k_accuracy_score(test_metadata['class_id'], preds_raw, k=3,
                                    labels=np.arange(N_CLASSES))

    elapsed = time.time() - start_time

    wandb.log({"val/f1-score": score,
               "val/accuracy": accuracy,
               "val/top-3 accuracy": recall_3,
               "epoch": epoch,
               "time elapsed": elapsed})

    # Checkpoint after epochs 9, 19, 29, ... (every tenth epoch).
    if (epoch - 9) % 10 == 0:
        torch.save(model.state_dict(), f'checkpoints/VIT_arcface_epoch_{epoch}.pth')
        torch.save(criterion_metrics.state_dict(), f'checkpoints/arcface_epoch_{epoch}.pth')

### Обучение с additive margin softmax лоссом

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score

# Fine-tune the network with the additive-margin softmax loss on the output
# of model.forward_features (the loss flattens its input itself). The loss
# has its own learnable class matrix, hence the second optimizer/scheduler.

wandb.init(project="danish-fungi")

n_epochs = EPOCHS
lr = 1e-4

# torch.save does not create missing directories; create the target up front
# so the first checkpoint (after ten epochs) cannot fail on a missing folder.
os.makedirs('checkpoints', exist_ok=True)

model.to(device)

optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=1, verbose=True, eps=1e-6)

criterion_classification = nn.CrossEntropyLoss()
# 768 * 197 is the flattened feature size fed to the loss (presumably
# 197 tokens x 768 channels for vit_base_patch16_224 -- adjust if MODEL_NAME
# changes). The class count comes from the data instead of a hard-coded 182.
criterion_metrics = AdMSoftmaxLoss(768 * 197, N_CLASSES, s=8, m=0.1).to(device)

optimizer_criterion = SGD(criterion_metrics.parameters(), lr=lr, momentum=0.9)
scheduler_criterion = ReduceLROnPlateau(optimizer_criterion, 'min', factor=0.9, patience=1, verbose=True, eps=1e-6)

for epoch in range(n_epochs):

    start_time = time.time()

    # ---- training ----
    model.train()
    avg_loss = 0.

    optimizer.zero_grad()
    optimizer_criterion.zero_grad()

    for i, (images, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):

        images = images.to(device)
        labels = labels.to(device)

        y_features = model.forward_features(images)
        loss_metrics = criterion_metrics(y_features, labels)

        loss = loss_metrics

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        optimizer_criterion.step()
        optimizer_criterion.zero_grad()
        wandb.log({
                   "train/am softmax loss": loss_metrics.item(),
                   "train/loss total": loss.item(),
                   "epoch": epoch,})

        avg_loss += loss.item() / len(train_loader)
    wandb.log({"train/avg loss": avg_loss,
                "epoch": epoch,})

    # ---- validation ----
    model.eval()
    avg_val_loss = 0.
    preds = np.zeros((len(valid_dataset)))
    preds_raw = []

    for i, (images, labels) in tqdm(enumerate(valid_loader), total=len(valid_loader)):

        images = images.to(device)
        labels = labels.to(device)

        # The loss is evaluated under no_grad as well, so no autograd graph
        # is built for the criterion's own weights during validation.
        with torch.no_grad():
            y_preds = model(images)
            y_features = model.forward_features(images)
            loss_metrics = criterion_metrics(y_features, labels)

        preds[i * BATCH_SIZE: (i+1) * BATCH_SIZE] = y_preds.argmax(1).to('cpu').numpy()
        preds_raw.extend(y_preds.to('cpu').numpy())

        loss = loss_metrics
        wandb.log({
                   "val/am softmax loss": loss_metrics.item(),
                   "val/loss total": loss.item(),
                   "epoch": epoch,})
        avg_val_loss += loss.item() / len(valid_loader)
    wandb.log({"val/avg loss": avg_val_loss,
                "epoch": epoch,})
    scheduler.step(avg_val_loss)
    scheduler_criterion.step(avg_val_loss)

    score = f1_score(test_metadata['class_id'], preds, average='macro')
    accuracy = accuracy_score(test_metadata['class_id'], preds)
    # Pass the full label set explicitly: without it scikit-learn rejects the
    # score matrix whenever some class is absent from the validation split.
    recall_3 = top_k_accuracy_score(test_metadata['class_id'], preds_raw, k=3,
                                    labels=np.arange(N_CLASSES))

    elapsed = time.time() - start_time

    wandb.log({"val/f1-score": score,
               "val/accuracy": accuracy,
               "val/top-3 accuracy": recall_3,
               "epoch": epoch,
               "time elapsed": elapsed})

    # Checkpoint after epochs 9, 19, 29, ... (every tenth epoch).
    if (epoch - 9) % 10 == 0:
        torch.save(model.state_dict(), f'checkpoints/VIT_am_softmax_epoch_{epoch}.pth')

### Обучение с triplet лоссом

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score

# Fine-tune the network with a triplet margin loss on the output of
# model.forward_features for (anchor, positive, negative) images.

wandb.init(project="danish-fungi")

n_epochs = EPOCHS
lr = 1e-4

# torch.save does not create missing directories; create the target up front.
os.makedirs('checkpoints', exist_ok=True)

# This cell needs 4-tuples (anchor, positive, negative, label). The shared
# train_loader / valid_loader wrap whichever dataset cell ran last (the plain
# 2-tuple TrainDataset when the notebook is run top to bottom), so dedicated
# triplet loaders are built here instead of relying on that hidden state.
triplet_train_loader = DataLoader(
    TLTrainDataset(train_metadata, transform=get_transforms(data='train')),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=WORKERS)
triplet_valid_loader = DataLoader(
    TLTrainDataset(test_metadata, transform=get_transforms(data='valid')),
    batch_size=BATCH_SIZE, shuffle=False, num_workers=WORKERS)

model.to(device)

optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=1, verbose=True, eps=1e-6)

criterion_classification = nn.CrossEntropyLoss()
# NOTE(review): the features are passed to the loss without flattening, unlike
# in the ArcFace cell -- confirm that this is intended.
criterion_metrics = nn.TripletMarginLoss()

for epoch in range(n_epochs):

    start_time = time.time()

    # ---- training ----
    model.train()
    avg_loss = 0.

    optimizer.zero_grad()

    for i, (images, images_positive, images_negative, labels) in tqdm(enumerate(triplet_train_loader), total=len(triplet_train_loader)):

        images = images.to(device)
        images_positive = images_positive.to(device)
        images_negative = images_negative.to(device)
        labels = labels.to(device)

        # Only the features enter the loss; the class predictions are not
        # needed while training, so no extra forward pass is spent on them.
        y_features = model.forward_features(images)
        y_features_positive = model.forward_features(images_positive)
        y_features_negative = model.forward_features(images_negative)
        loss_metrics = criterion_metrics(y_features, y_features_positive, y_features_negative)

        loss = loss_metrics

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        wandb.log({"train/triplet loss": loss_metrics.item(),
                   "train/loss total": loss.item(),
                   "epoch": epoch,})

        avg_loss += loss.item() / len(triplet_train_loader)
    wandb.log({"train/avg loss": avg_loss,
                "epoch": epoch,})

    # ---- validation ----
    model.eval()
    avg_val_loss = 0.
    preds = np.zeros((len(triplet_valid_loader.dataset)))
    preds_raw = []

    for i, (images, images_positive, images_negative, labels) in tqdm(enumerate(triplet_valid_loader), total=len(triplet_valid_loader)):

        images = images.to(device)
        images_positive = images_positive.to(device)
        images_negative = images_negative.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            y_preds = model(images)
            y_features = model.forward_features(images)
            y_features_positive = model.forward_features(images_positive)
            y_features_negative = model.forward_features(images_negative)
            loss_metrics = criterion_metrics(y_features, y_features_positive, y_features_negative)

        preds[i * BATCH_SIZE: (i+1) * BATCH_SIZE] = y_preds.argmax(1).to('cpu').numpy()
        preds_raw.extend(y_preds.to('cpu').numpy())

        loss = loss_metrics
        wandb.log({"val/triplet loss": loss_metrics.item(),
                   "val/loss total": loss.item(),
                   "epoch": epoch,})
        avg_val_loss += loss.item() / len(triplet_valid_loader)
    wandb.log({"val/avg loss": avg_val_loss,
                "epoch": epoch,})
    scheduler.step(avg_val_loss)

    score = f1_score(test_metadata['class_id'], preds, average='macro')
    accuracy = accuracy_score(test_metadata['class_id'], preds)
    # Pass the full label set explicitly: without it scikit-learn rejects the
    # score matrix whenever some class is absent from the validation split.
    recall_3 = top_k_accuracy_score(test_metadata['class_id'], preds_raw, k=3,
                                    labels=np.arange(N_CLASSES))

    elapsed = time.time() - start_time

    wandb.log({"val/f1-score": score,
               "val/accuracy": accuracy,
               "val/top-3 accuracy": recall_3,
               "epoch": epoch,
               "time elapsed": elapsed})

    # Unlike the other training cells, a checkpoint is written after every epoch.
    torch.save(model.state_dict(), f'checkpoints/VIT_triplet_epoch_{epoch}.pth')

### Замораживаем слои трансформера, оставляя обучаться только классификатор

In [None]:
# Turn off gradients for every parameter of the network; the trailing
# semicolon keeps the notebook from echoing the returned module.
model.requires_grad_(False);

In [None]:
# Turn gradients back on for the classification head only; the trailing
# semicolon keeps the notebook from echoing the returned module.
model.head.requires_grad_(True);

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score

# Train the model with plain cross-entropy on its class predictions. After
# the two freezing cells above only the head still requires gradients.

wandb.init(project="danish-fungi")

EPOCHS = 50
n_epochs = EPOCHS
lr = 1e-3

# torch.save does not create missing directories; create the target up front
# so the first checkpoint (after ten epochs) cannot fail on a missing folder.
os.makedirs('checkpoints', exist_ok=True)

model.to(device)

# Hand the optimizer only the parameters that are still trainable; with
# nothing frozen this is the full parameter list.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = SGD(trainable_params, lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=1, verbose=True, eps=1e-6)

criterion_classification = nn.CrossEntropyLoss()

for epoch in range(n_epochs):

    start_time = time.time()

    # ---- training ----
    model.train()
    avg_loss = 0.

    optimizer.zero_grad()

    for i, (images, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):

        images = images.to(device)
        labels = labels.to(device)

        y_preds = model(images)
        loss_classification = criterion_classification(y_preds, labels)

        loss = loss_classification

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        wandb.log({"train/cross-entropy": loss_classification.item(),
                   "train/loss total": loss.item(),
                   "epoch": epoch,})

        avg_loss += loss.item() / len(train_loader)
    wandb.log({"train/avg loss": avg_loss,
                "epoch": epoch,})

    # ---- validation ----
    model.eval()
    avg_val_loss = 0.
    preds = np.zeros((len(valid_dataset)))
    preds_raw = []

    for i, (images, labels) in tqdm(enumerate(valid_loader), total=len(valid_loader)):

        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            y_preds = model(images)
            loss_classification = criterion_classification(y_preds, labels)

        preds[i * BATCH_SIZE: (i+1) * BATCH_SIZE] = y_preds.argmax(1).to('cpu').numpy()
        preds_raw.extend(y_preds.to('cpu').numpy())

        loss = loss_classification
        wandb.log({"val/cross-entropy": loss_classification.item(),
                   "val/loss total": loss.item(),
                   "epoch": epoch,})
        avg_val_loss += loss.item() / len(valid_loader)
    wandb.log({"val/avg loss": avg_val_loss,
                "epoch": epoch,})
    scheduler.step(avg_val_loss)

    score = f1_score(test_metadata['class_id'], preds, average='macro')
    accuracy = accuracy_score(test_metadata['class_id'], preds)
    # Pass the full label set explicitly: without it scikit-learn rejects the
    # score matrix whenever some class is absent from the validation split.
    recall_3 = top_k_accuracy_score(test_metadata['class_id'], preds_raw, k=3,
                                    labels=np.arange(N_CLASSES))

    elapsed = time.time() - start_time

    wandb.log({"val/f1-score": score,
               "val/accuracy": accuracy,
               "val/top-3 accuracy": recall_3,
               "epoch": epoch,
               "time elapsed": elapsed})

    # Checkpoint after epochs 9, 19, 29, ... (every tenth epoch).
    if (epoch - 9) % 10 == 0:
        torch.save(model.state_dict(), f'checkpoints/VIT_classifier_epoch_{epoch}.pth')

In [None]:
# Close the current Weights & Biases run.
wandb.finish()

### Подсчет финальных метрик

In [None]:
# Run the model once over the validation loader, keeping the arg-max class
# per sample (preds) and the raw model outputs (preds_raw).

# Do not rely on a training cell having left the model on the right device
# and in eval mode: this cell may be run right after loading a checkpoint.
model.to(device)
model.eval()

preds = np.zeros((len(valid_dataset)))
preds_raw = []

for i, (images, labels) in enumerate(valid_loader):

    images = images.to(device)

    with torch.no_grad():
        y_preds = model(images)

    # valid_loader does not shuffle, so batch i covers rows
    # [i * BATCH_SIZE, (i + 1) * BATCH_SIZE) of the validation set.
    preds[i * BATCH_SIZE: (i+1) * BATCH_SIZE] = y_preds.argmax(1).to('cpu').numpy()
    preds_raw.extend(y_preds.to('cpu').numpy())

In [None]:
score = sklearn.metrics.f1_score(test_metadata['class_id'], preds, average='macro')
accuracy = sklearn.metrics.accuracy_score(test_metadata['class_id'], preds)
# Pass the full label set explicitly: without it scikit-learn rejects the
# score matrix whenever some class is absent from the validation split.
recall_3 = sklearn.metrics.top_k_accuracy_score(test_metadata['class_id'], preds_raw, k=3,
                                                labels=np.arange(N_CLASSES))

# Show the results as the cell output instead of computing them silently.
{'f1-score (macro)': score, 'accuracy': accuracy, 'top-3 accuracy': recall_3}

### Попытка обучить на эмбеддингах классификатор на основе k ближайших соседей (kNN) с предварительным NCA

In [None]:
from sklearn.neighbors import (NeighborhoodComponentsAnalysis, KNeighborsClassifier)
from sklearn.pipeline import Pipeline

In [None]:
# Collect the flattened forward_features output and the labels for every
# batch of the training loader (as NumPy arrays, one list entry per batch).
embeddings, classes = [], []
with torch.no_grad():
    for images, labels in tqdm(train_loader, total=len(train_loader)):
        batch_features = model.forward_features(images.to(device))
        flat_features = torch.flatten(batch_features, start_dim=1)
        embeddings.append(flat_features.cpu().detach().numpy())
        classes.append(labels.cpu().detach().numpy())

In [None]:
# Stack the per-batch arrays into one (n_samples, n_features) matrix and one label vector.
embeddings_concat = np.concatenate(embeddings, axis=0)
classes_concat =  np.concatenate(classes)

In [None]:
# NCA projects the features to 128 components; a 7-nearest-neighbour
# classifier is then fitted on the projected features.
nca = NeighborhoodComponentsAnalysis(verbose=1, n_components=128)
knn = KNeighborsClassifier(n_neighbors=7)
nca_pipe = Pipeline([('nca', nca), ('knn', knn)], verbose=True)
# Only the last 2 * 768 columns of each flattened embedding are used --
# presumably the final two tokens of the ViT output; TODO confirm.
nca_pipe.fit(embeddings_concat[:, -768 * 2:], classes_concat)

In [None]:
# Collect the flattened forward_features output and the labels for every
# batch of the validation loader (as NumPy arrays, one list entry per batch).
embeddings_test, classes_test = [], []
with torch.no_grad():
    for images, labels in tqdm(valid_loader, total=len(valid_loader)):
        batch_features = model.forward_features(images.to(device))
        flat_features = torch.flatten(batch_features, start_dim=1)
        embeddings_test.append(flat_features.cpu().detach().numpy())
        classes_test.append(labels.cpu().detach().numpy())

In [None]:
# Stack the per-batch arrays into one (n_samples, n_features) matrix and one label vector.
embeddings_test_concat = np.concatenate(embeddings_test, axis=0)
classes_test_concat =  np.concatenate(classes_test)

In [None]:
# Score the fitted pipeline on the validation embeddings, using the same
# column slice that was used for fitting.
nca_pipe.score(embeddings_test_concat[:, -768 * 2:], classes_test_concat)