In [None]:
!pip install torchinfo
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.optim import Adam
from torch import nn as nn
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import top_k_accuracy_score
import os
import timm
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, AutoImageProcessor
from torch.utils.data import Dataset
from torchvision.transforms import functional as F
import graphviz
from torchinfo import summary
graphviz.set_jupyter_format('png')

%matplotlib inline
plt.rcParams['figure.figsize'] = [11.7, 8.27]
sns.set_theme(style='white')
pd.set_option("display.precision", 3)

device = torch.device('cuda')

HF_MODEL = True

### FMix Implementation
Taken from https://github.com/ecs-vlc/FMix

In [None]:
import math
import random

import numpy as np
from scipy.stats import beta
import torch


def fftfreqnd(h, w=None, z=None):
    """ Build the grid of frequency magnitudes for a discrete fourier transform of size (h, w, z)

    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: ``sqrt(fx**2 + fy**2 + fz**2)``, obtained by broadcasting the per-axis bin values
    """
    freq_y = np.fft.fftfreq(h)
    freq_x = 0
    freq_z = 0

    if w is not None:
        freq_y = freq_y[..., None]
        # Only the leading bins of the second axis are kept:
        # w // 2 + 2 of them for odd w, w // 2 + 1 for even w.
        kept = w // 2 + 1 + (w % 2)
        freq_x = np.fft.fftfreq(w)[:kept]

    if z is not None:
        freq_y = freq_y[..., None]
        # Odd and even z are treated alike: every bin is used, as a column vector.
        freq_z = np.fft.fftfreq(z)[:, None]

    return np.sqrt(freq_x * freq_x + freq_y * freq_y + freq_z * freq_z)


def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
    """ Draws gaussian fourier coefficients and attenuates them by 1 / f**decay_power

    :param freqs: Bin values for the discrete fourier transform
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param ch: Number of channels for the resulting mask
    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: Array of shape (ch, *freqs.shape, 2)
    """
    # Frequencies are clamped from below at 1 / (largest dimension) so that the
    # zero-frequency bin does not produce a division by zero.
    lowest = np.array([1. / max(w, h, z)])
    clamped = np.maximum(freqs, lowest)
    scale = np.ones(1) / clamped ** decay_power

    # Standard-normal samples; the trailing axis of length 2 holds value pairs.
    noise = np.random.randn(ch, *freqs.shape, 2)

    # Add a leading (channel) and a trailing (pair) axis so the scale broadcasts.
    return scale[None, ..., None] * noise


def make_low_freq_image(decay, shape, ch=1):
    """ Sample a low frequency image by inverse-transforming a decayed random spectrum

    :param decay: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param ch: Number of channels for desired mask
    :return: Array shifted and scaled so that its minimum is 0 and its maximum is 1.
        For 1 to 3 dims only the first channel is kept, cropped to ``shape``.
    """
    spectrum = get_spectrum(fftfreqnd(*shape), decay, ch, *shape)

    # Entries 0 and 1 along axis 1 are used as the real and imaginary parts.
    real_part, imag_part = spectrum[:, 0], spectrum[:, 1]
    mask = np.real(np.fft.irfftn(real_part + 1j * imag_part, shape))

    if 1 <= len(shape) <= 3:
        # First channel only, each remaining axis cropped to the requested extent.
        crop = (slice(None, 1),) + tuple(slice(None, extent) for extent in shape)
        mask = mask[crop]

    shifted = mask - mask.min()
    return shifted / shifted.max()


def sample_lam(alpha, reformulate=False):
    """ Draw the mixing coefficient lambda from a beta distribution

    :param alpha: Alpha value for beta distribution
    :param reformulate: If True, draws from Beta(alpha + 1, alpha) instead of the
        symmetric Beta(alpha, alpha).
    :return: The sampled lambda
    """
    first_shape = alpha + 1 if reformulate else alpha
    return beta.rvs(first_shape, alpha)


def binarise_mask(mask, lam, in_shape, max_soft=0.0):
    """ Thresholds a low frequency image so that roughly a fraction lambda of it is 1.

    The work is done on ``mask.reshape(-1)``, so when that is a view the array
    passed in is overwritten.

    :param mask: Low frequency image, usually the result of `make_low_freq_image`
    :param lam: Mean value of final mask
    :param in_shape: Shape of inputs
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :return: Mask of shape (1, *in_shape)
    """
    flat = mask.reshape(-1)
    descending = flat.argsort()[::-1]

    # lam * size is rounded up or down with equal probability.
    rounding = math.ceil if random.random() > 0.5 else math.floor
    num = rounding(lam * flat.size)

    # The soft band may not be wider than the smaller of the two regions.
    too_wide = max_soft > lam or max_soft > (1 - lam)
    eff_soft = min(lam, 1 - lam) if too_wide else max_soft

    soft = int(flat.size * eff_soft)
    num_low, num_high = num - soft, num + soft

    # Highest-ranked entries become 1, the rest 0, and the band in between
    # is a linear ramp from 1 down to 0.
    flat[descending[:num_high]] = 1
    flat[descending[num_low:]] = 0
    flat[descending[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))

    return flat.reshape((1, *in_shape))


def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """ Draws lambda from a beta distribution, samples a low frequency image and
    binarises that image according to lambda

    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims (a bare int is treated as a 1-tuple)
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, lambda is drawn from Beta(alpha + 1, alpha).
    :return: Tuple (lambda, mask)
    """
    dims = (shape,) if isinstance(shape, int) else shape

    lam = sample_lam(alpha, reformulate)
    low_freq = make_low_freq_image(decay_power, dims)

    return lam, binarise_mask(low_freq, lam, dims, max_soft)


def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """ Mixes every sample of a batch with another sample of the same batch through one shared mask

    :param x: Image batch on which to apply fmix of shape [b, c, shape*]
    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, lambda is drawn from Beta(alpha + 1, alpha).
    :return: mixed input, permutation indices, lambda value of mix
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
    index = np.random.permutation(x.shape[0])

    # Where the mask is 1 the original sample is kept; elsewhere the permuted partner shows.
    mixed = x * mask + x[index] * (1 - mask)
    return mixed, index, lam


class FMix:
    r""" FMix augmentation: callable wrapper that stores the sampling settings and
    applies `sample_and_apply` to a batch.

        Args:
            decay_power (float): Decay power for frequency decay prop 1/f**d
            alpha (float): Alpha value for beta distribution from which to sample mean of mask
            size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
            max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
            reformulate (bool): If True, lambda is drawn from Beta(alpha + 1, alpha).
    """

    def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
        super().__init__()
        self.decay_power, self.reformulate = decay_power, reformulate
        self.size, self.alpha = size, alpha
        self.max_soft = max_soft
        # Placeholders; nothing in this class assigns them after construction.
        self.index = None
        self.lam = None

    def __call__(self, x):
        """ Mix a batch.

        :param x: Tensor batch; it is moved to the CPU and converted to numpy for the mixing
        :return: mixed batch as a new ``torch.Tensor``, permutation indices, lambda value of mix
        """
        mixed, index, lam = sample_and_apply(
            x.cpu().numpy(),
            self.alpha,
            self.decay_power,
            self.size,
            self.max_soft,
            self.reformulate,
        )
        return torch.Tensor(mixed), index, lam

    def loss(self, *args, **kwargs):
        """ Not implemented; always raises ``NotImplementedError``. """
        raise NotImplementedError

### Hyperparams

In [None]:
# Queries timm for pretrained checkpoints matching the pattern.
# NOTE: the returned list is discarded here because this call is not the last
# expression of the cell.
timm.list_models(filter = "*vit_small_patch16*", pretrained = True)
# Alternative checkpoints (presumably timm model names, i.e. for HF_MODEL = False):
# model_name = 'tiny_vit_5m_224.dist_in22k_ft_in1k'
# model_name = 'vit_small_patch16_224.augreg_in21k'
model_name = 'google/mobilenet_v2_0.75_160'
# Size of the classification head and of the one-hot label vectors.
num_classes = 7
batch_size = 32
num_epochs = 25
# Learning rate handed to the Adam optimiser.
learning_rate = 0.0002
# Batch-level augmentation used by the training collate function; None disables it.
online_augmentation_function = None # Possible values: None, 'fmix', 'cutmix', 'mixup', 'cutout'

### Train function

In [None]:
# Class index -> class name. Empty placeholder at this point; it is rebuilt from
# `train.classes` once the datasets are loaded, before `train_network` reads it.
num_to_class = {}

def calculate_loss(model, criterion, test):
    """Return the mean per-batch loss of `model` over the `test` loader.

    The model is put into evaluation mode for the pass, so dropout is disabled and
    normalisation layers use (and do not update) their running statistics. The
    mode it was in beforehand is restored on exit.

    :param model: Network to evaluate; its output is read directly when HF_MODEL is
        False and through `.logits` otherwise
    :param criterion: Loss callable invoked as `criterion(predictions, labels)`
    :param test: Iterable yielding `(features, labels)` batches
    :return: Unweighted mean of the batch losses
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            losses = []
            for features, labels in test:
                features = features.to(device)
                labels = labels.to(device)

                if not HF_MODEL:
                    y = model(features)
                else:
                    y = model(features).logits
                loss = criterion(y, labels)
                losses.append(loss.item())
            return np.mean(losses)
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

def predict_multiple(model, test):
    """Run `model` over every batch of `test` and collect labels and predictions.

    The model is put into evaluation mode for the pass (dropout off, normalisation
    layers on running statistics) and its previous mode is restored on exit.

    :param model: Network to evaluate; its output is read directly when HF_MODEL is
        False and through `.logits` otherwise
    :param test: Iterable yielding `(features, labels)` batches; labels are read with
        `.numpy()` and are never moved to the device
    :return: Tuple of numpy arrays `(golden labels, argmax predictions, raw outputs)`
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs_raw = []
            outputs = []
            golden = []
            for features, labels in test:
                features = features.to(device)

                if not HF_MODEL:
                    y = model(features)
                else:
                    y = model(features).logits

                outputs_raw.append(y.cpu().numpy())
                outputs.append(torch.argmax(y, dim = 1).cpu().numpy())
                golden.append(labels.numpy())
            return np.concatenate(golden), np.concatenate(outputs), np.concatenate(outputs_raw)
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

def train_network(model : nn.Module, optimizer : torch.optim.Optimizer, criterion, train, test, epochs = 50, use_scheduler = True):
    """Train `model` and evaluate it on `test` after every epoch.

    The model is switched to training mode at the start of each epoch and to
    evaluation mode before the per-epoch evaluation. Without the explicit
    `model.train()` a model that arrives in evaluation mode (Hugging Face
    `from_pretrained` returns models that way) would be optimised with dropout
    disabled and frozen normalisation statistics. On return the model is left in
    evaluation mode.

    :param model: Network to train; its output is read directly when HF_MODEL is
        False and through `.logits` otherwise
    :param optimizer: Optimizer stepping the model parameters
    :param criterion: Loss callable invoked as `criterion(predictions, labels)`
    :param train: Iterable of `(features, labels)` training batches
    :param test: Iterable of `(features, labels)` evaluation batches; it is traversed
        twice per epoch (once for the loss, once for the predictions)
    :param epochs: Number of epochs
    :param use_scheduler: If True, the learning rate is multiplied by 0.25 at epochs
        15, 30, 45, 60, 75 and 90
    :return: `(train_losses, test_losses, table)`; `table` maps a metric name to its
        per-epoch values (accuracy, top-2 accuracy, and per-class precision/recall/F1)
    """
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30, 45, 60, 75, 90], gamma=0.25)

    # Passing the full label list to the sklearn metrics keeps the per-class arrays
    # aligned with the class index even when a class is absent from both the
    # evaluation labels and the predictions.
    class_ids = list(range(num_classes))

    train_losses = []
    test_losses = []

    table = {'Accuracy': [], 'Accuracy-2': []}
    for label in class_ids:
        table['Precision ' + num_to_class[label]] = []
        table['Recall ' + num_to_class[label]] = []
        table['F1 ' + num_to_class[label]] = []

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_train_losses = []

        for features, labels in train:
            features = features.to(device)
            labels = labels.to(device)

            if not HF_MODEL:
                y = model(features)
            else:
                y = model(features).logits
            loss = criterion(y, labels)
            model.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_train_losses.append(loss.item())

        model.eval()
        train_losses.append(np.mean(epoch_train_losses))
        test_losses.append(calculate_loss(model, criterion, test))

        golden, outputs, outputs_raw = predict_multiple(model, test)

        prfs = precision_recall_fscore_support(golden, outputs, labels = class_ids, zero_division = 0)

        table['Accuracy'].append(accuracy_score(golden, outputs))
        table['Accuracy-2'].append(top_k_accuracy_score(golden, outputs_raw, k = 2, labels = class_ids))
        for label in class_ids:
            table['Precision ' + num_to_class[label]].append(prfs[0][label])
            table['Recall ' + num_to_class[label]].append(prfs[1][label])
            table['F1 ' + num_to_class[label]].append(prfs[2][label])

        last_acc = table['Accuracy'][-1]
        last_acc_2 = table['Accuracy-2'][-1]
        print(f'Epoch {epoch}: Train loss {train_losses[-1]}, Test loss {test_losses[-1]}, Test acc {last_acc}, Test acc-2 {last_acc_2}')

        if use_scheduler:
            scheduler.step()

    return train_losses, test_losses, table
        

In [None]:
# Build the classifier with a `num_classes`-way head and move it to `device`.
if HF_MODEL:
    # ignore_mismatched_sizes lets the checkpoint load although its classification
    # head does not match `num_classes`.
    model = AutoModelForImageClassification.from_pretrained(
        model_name,
        num_labels = num_classes,
        ignore_mismatched_sizes = True,
    ).to(device)
else:
    model = timm.create_model(model_name, pretrained = True, num_classes = num_classes, drop_rate = 0.1).to(device)
    # Preprocessing settings of the checkpoint; the mean/std entries are used for normalisation.
    timm_data_config = timm.data.resolve_data_config({}, model=model)
    print(timm_data_config)

### Image preprocessing

In [None]:
# Per-channel normalisation applied to every image tensor.
if HF_MODEL:
    # Alternatives that were tried:
    #     processor = transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
    #     processor = AutoImageProcessor.from_pretrained(model_name)
    processor = transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
else:
    # timm checkpoints: use the statistics resolved from the model's data config.
    processor = transforms.Normalize(mean = timm_data_config['mean'], std = timm_data_config['std'])

def transform_images(x, size = (160, 160), augment = True):
    """Turn one image into a normalised tensor.

    Steps: resize, convert to tensor, optional random horizontal flip, then
    normalisation with the module-level `processor`.

    :param x: Image object exposing `.resize(size)` - presumably the PIL image that
        `ImageFolder` hands to its transform (confirm against the dataset loader)
    :param size: Target size passed to `x.resize`; defaults to the previously
        hard-coded (160, 160)
    :param augment: If True (default, the previous behaviour) the image is flipped
        horizontally with probability 0.5. Pass False for deterministic
        preprocessing, e.g. when evaluating.
    :return: Normalised image tensor
    """
    x = x.resize(size)
    x = transforms.ToTensor()(x)
    if augment:
        x = transforms.RandomHorizontalFlip(0.5)(x)
    x = processor(x)
    return x

# Show the selected normalisation transform as the cell output.
processor

### Mixup

In [None]:
def mixup(x, y):
    """Blend every sample of a batch with a randomly chosen partner from the same batch.

    The mixing weight is drawn from Beta(0.8, 0.8) once per call; the same weight
    and the same permutation are applied to inputs and targets.

    :param x: Batch of inputs, indexed along dimension 0
    :param y: Batch of targets aligned with `x` (they are scaled and added, so they
        must be numeric, e.g. one-hot vectors)
    :return: Tuple `(mixed inputs, mixed targets)` as new tensors
    """
    lam = np.random.beta(0.8, 0.8)
    partner = torch.randperm(x.shape[0])

    def blend(batch):
        # lam of the original sample plus (1 - lam) of its partner.
        return batch * lam + batch[partner] * (1 - lam)

    return blend(x), blend(y)

### Cutmix

In [None]:
def rand_bbox(size, lam):
    """Sample a box covering roughly a fraction (1 - lam) of the two trailing dimensions.

    :param size: Batch shape; entries 2 and 3 are the extents the box lives in
    :param lam: Mixing coefficient; each side of the box is sqrt(1 - lam) of the extent
    :return: `(bbx1, bby1, bbx2, bby2)`; the 1/2 pairs are the lower and upper bounds,
        `bbx*` along dimension 2 and `bby*` along dimension 3, clipped to the extents
    """
    extent_x, extent_y = size[2], size[3]
    side_ratio = np.sqrt(1. - lam)
    half_x = int(extent_x * side_ratio) // 2
    half_y = int(extent_y * side_ratio) // 2

    # The centre is drawn uniformly from the middle half of each extent.
    centre_x = np.random.randint(extent_x // 4, extent_x - extent_x // 4)
    centre_y = np.random.randint(extent_y // 4, extent_y - extent_y // 4)

    lower_x = np.clip(centre_x - half_x, 0, extent_x)
    lower_y = np.clip(centre_y - half_y, 0, extent_y)
    upper_x = np.clip(centre_x + half_x, 0, extent_x)
    upper_y = np.clip(centre_y + half_y, 0, extent_y)

    return lower_x, lower_y, upper_x, upper_y

def cutmix(x, y):
    """Paste a random box from a shuffled copy of the batch into every sample.

    The inputs are left untouched: the box is pasted into a copy of `x`, in line
    with `mixup` and `fmix`, which also return new tensors.

    :param x: Batch of inputs with at least four dimensions; the box is cut from
        dimensions 2 and 3
    :param y: Batch of numeric targets aligned with `x` (e.g. one-hot vectors)
    :return: Tuple `(mixed inputs, mixed targets)`
    """
    lam = np.random.beta(0.8, 0.8)
    indices = torch.randperm(x.shape[0])

    shuffled_x = x[indices]
    shuffled_y = y[indices]

    bbx1, bby1, bbx2, bby2 = rand_bbox(x.shape, lam)

    # Copy first so the caller's batch is not modified.
    new_x = x.clone()
    new_x[:, :, bbx1:bbx2, bby1:bby2] = shuffled_x[:, :, bbx1:bbx2, bby1:bby2]

    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.shape[-1] * x.shape[-2]))
    new_y = y * lam + shuffled_y * (1 - lam)

    return new_x, new_y

### Cutout

In [None]:
def cutout(x, y):
    """Zero out one random box in every sample of the batch.

    The box size follows a Beta(0.8, 0.8) draw passed to `rand_bbox`. The inputs are
    left untouched: the box is erased in a copy of `x`, in line with `mixup` and
    `fmix`, which also return new tensors.

    :param x: Batch of inputs with at least four dimensions; the box is cut from
        dimensions 2 and 3
    :param y: Targets; returned unchanged
    :return: Tuple `(inputs with the box zeroed, y)`
    """
    lam = np.random.beta(0.8, 0.8)
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.shape, lam)

    # Copy first so the caller's batch is not modified.
    new_x = x.clone()
    new_x[:, :, bbx1:bbx2, bby1:bby2] = 0

    return new_x, y

### FMix

In [None]:
def fmix(x, y):
    """Apply FMix to a batch and mix the targets with the same permutation and lambda.

    :param x: Batch of inputs; the mask size is taken from `x.shape[2:]`
    :param y: Batch of numeric targets aligned with `x` (e.g. one-hot vectors)
    :return: Tuple `(mixed inputs, mixed targets)`
    """
    mixer = FMix(size = x.shape[2 :])
    new_x, index, lam = mixer(x)
    return new_x, y * lam + y[index] * (1 - lam)

In [None]:
# Maps the name stored in `online_augmentation_function` to the batch-level
# augmentation callable; every callable takes and returns `(x, y)`.
augmentation_functions = {'mixup': mixup, 'cutmix': cutmix, 'cutout': cutout, 'fmix': fmix}

In [None]:
# Probability that a training batch is augmented when an augmentation function is selected.
augment_prob = 0.5

def collate_fn(batch):
    """Stack samples into a batch with one-hot float targets and optionally augment it.

    :param batch: Sequence of samples; element 0 of each sample is the image tensor,
        element 1 the integer class label
    :return: Tuple `(inputs, targets)`; targets have `num_classes` columns
    """
    x = torch.stack([sample[0] for sample in batch])
    y = torch.stack([
        nn.functional.one_hot(torch.tensor(sample[1]), num_classes = num_classes).float()
        for sample in batch
    ])

    # The random draw only happens when an augmentation is configured.
    if online_augmentation_function is not None and np.random.rand() < augment_prob:
        new_x, new_y = augmentation_functions[online_augmentation_function](x, y)
        return new_x, new_y

    return x, y

In [None]:
# Root folder holding the `train` and `validation` image folders.
# DATASET_DIR = '/kaggle/input/face-expression-recognition-dataset/images'
DATASET_DIR = '/kaggle/input/stargan-fer-augmented/images'

# NOTE(review): both splits use `transform_images`, which includes a random
# horizontal flip, so validation images are randomly flipped as well and
# evaluation results are not deterministic - confirm this is intended.
train = torchvision.datasets.ImageFolder(os.path.join(DATASET_DIR, 'train'), transform_images)
val = torchvision.datasets.ImageFolder(os.path.join(DATASET_DIR, 'validation'), transform_images)

# NOTE(review): there is no separate test split; `test` is the validation set, so
# every "test" score computed from `test_data` is a validation score.
test = val

#hflip_aug = FlipAugment(train, horizontal = True)
#vflip_aug = FlipAugment(train, horizontal = False)

print(train.classes)
print(test.classes)
print(val.classes)

# All splits must expose the same classes in the same order.
assert train.classes == test.classes and test.classes == val.classes

# Class index -> class name, used to label the per-class metric columns.
num_to_class = {i : c for i, c in enumerate(train.classes)}

#final_train_set = torch.utils.data.ConcatDataset([train, hflip_aug, vflip_aug])
final_train_set = train
# Only the training loader shuffles and uses `collate_fn` (one-hot targets plus the
# optional batch augmentation); the other loaders use the default collation.
train_data = torch.utils.data.DataLoader(final_train_set, batch_size = batch_size, shuffle = True, collate_fn = collate_fn)
test_data = torch.utils.data.DataLoader(test, batch_size = batch_size)
val_data = torch.utils.data.DataLoader(val, batch_size = batch_size)

# Sanity check: smallest value of the first image in one training batch.
print(next(iter(train_data))[0][0].min())

optim = Adam(model.parameters(), lr = learning_rate)
# The criterion receives one-hot float targets from the training loader and
# integer class indices from the validation loader.
criterion = torch.nn.CrossEntropyLoss()

# Trains for `num_epochs` epochs, evaluating on the validation loader after each one.
# `use_scheduler` is left at its default (True).
train_losses, test_losses, scores = train_network(model, optim, criterion, train_data, val_data, num_epochs)

# Per-epoch training and validation loss: saved to CSV and plotted.
df = pd.DataFrame({'train loss': train_losses, 'dev loss': test_losses})
df.to_csv('losses.csv', sep = ',', index = False)
sns.lineplot(df, dashes = False)
plt.show()

# Per-epoch validation metrics (one column per metric and class), saved to CSV.
# Note that `df` is rebound here to a different frame.
df = pd.DataFrame(scores)
df.to_csv('dev_scores.csv', sep = ',', index = False)
# Accuracy columns are copied as they are; precision, recall and F1 are averaged
# over the classes with equal weight (sum of the per-class columns / num_classes).
avg_df = pd.DataFrame()
avg_df['Accuracy'] = df['Accuracy']
avg_df['Accuracy-2'] = df['Accuracy-2']
metrics = ['Precision', 'Recall', 'F1']
for m in metrics:
    s = None
    for label in range(num_classes):
        if s is None:
            s = df[m + ' ' + num_to_class[label]]
        else:
            s = s + df[m + ' ' + num_to_class[label]]
    s = s / num_classes
    avg_df[m] = s

sns.lineplot(avg_df, dashes = False)
plt.show()

In [None]:
# Final evaluation of the trained model on `test_data`, written to test_scores.csv.
model.eval()

# Passing the full label list to the sklearn metrics keeps the per-class arrays
# aligned with the class index even when a class is absent from both the labels
# and the predictions.
class_ids = list(range(num_classes))

golden, outputs, outputs_raw = predict_multiple(model, test_data)

prfs = precision_recall_fscore_support(golden, outputs, labels = class_ids, zero_division = 0)

# Single-row table with the same columns as the per-epoch validation table.
# k = 2 is spelled out so that 'Accuracy-2' matches the per-epoch computation.
table = {
    'Accuracy': [accuracy_score(golden, outputs)],
    'Accuracy-2': [top_k_accuracy_score(golden, outputs_raw, k = 2, labels = class_ids)],
}
for label in class_ids:
    table['Precision ' + num_to_class[label]] = [prfs[0][label]]
    table['Recall ' + num_to_class[label]] = [prfs[1][label]]
    table['F1 ' + num_to_class[label]] = [prfs[2][label]]

df = pd.DataFrame(table)
df.to_csv('test_scores.csv', sep = ',', index = False)
df

In [None]:
# Save two checkpoints: the full classifier ('_full.pth'), then the weights after
# the classification head has been removed ('_feats.pth').
# NOTE: this changes `model` in place - from here on it no longer has its
# classification head, and every later cell works with the headless model.
if not HF_MODEL:
    torch.save(model.state_dict(), model_name + '_full.pth')
    # presumably strips the timm classifier head - confirm against the timm API
    model.reset_classifier(0)
    torch.save(model.state_dict(), model_name + '_feats.pth')
else:
    # File names use the part of `model_name` after the '/' (the name must contain one).
    torch.save(model.state_dict(), model_name.split('/')[1] + '_full.pth')
    # The head becomes a pass-through module.
    model.classifier = torch.nn.Identity()
    torch.save(model.state_dict(), model_name.split('/')[1] + '_feats.pth')

### OOD

In [None]:
# Output folder for the out-of-distribution (Mahalanobis) parameters.
OOD_PARAMS_DIR = 'ood_params'

# exist_ok avoids the separate existence check and makes re-running the cell safe.
os.makedirs(OOD_PARAMS_DIR, exist_ok = True)

# Statistics are computed over the training and validation images together.
ood_data = torch.utils.data.DataLoader(torch.utils.data.ConcatDataset([final_train_set, val]), batch_size = 32)

# Probe the model with one random input to read off the width of its output;
# no gradients are needed for this forward pass.
with torch.no_grad():
    if not HF_MODEL:
        latent_feats_dim = model.forward(torch.randn(1, 3, 224, 224).to(device)).shape[-1]
    else:
        latent_feats_dim = model.forward(torch.randn(1, 3, 224, 224).to(device)).logits.shape[-1]

# prepare mahalanobis distance params
mean_feature_maps = torch.zeros((num_classes, latent_feats_dim)).to(device)   # per-class feature sums, later means
mean_feature_map_0 = torch.zeros(latent_feats_dim).to(device)                 # mean over all samples
Nk = torch.zeros(num_classes).to(device)                                      # samples per class
covar = torch.zeros((latent_feats_dim, latent_feats_dim)).to(device)          # scatter around the class means
covar_0 = torch.zeros((latent_feats_dim, latent_feats_dim)).to(device)        # scatter around the overall mean

def save_tensor(x, file_path):
    """Write a 1-D or 2-D tensor to a plain-text file, one number per line.

    File layout: number of rows, number of columns, then every value in row-major
    order. A 1-D tensor is written as a single row.

    :param x: Tensor with one or two dimensions, on any device
    :param file_path: Destination path; an existing file is overwritten
    """
    if len(x.shape) == 1:
        x = x[None, :]

    # One bulk copy to host memory instead of indexing (and, for tensors on an
    # accelerator, synchronising) once per element.
    rows = x.detach().cpu().tolist()

    with open(file_path, "w") as f:
        f.write(f"{x.shape[0]}\n{x.shape[1]}\n")
        f.writelines(f"{value}\n" for row in rows for value in row)

# Two passes over `ood_data`: the first accumulates the class means and the
# overall mean of the model outputs, the second the scatter of the outputs around
# those means. The pseudo-inverses of the resulting covariance matrices are saved.
# Evaluation mode is set explicitly so this cell does not depend on an earlier one
# having called `model.eval()`.
model.eval()

with torch.no_grad():
    # Calculate mean feature maps
    print("Calculating mean feature maps...")
    for batch, label in ood_data:
        batch = batch.to(device)
        label = label.to(device)

        if not HF_MODEL:
            z = model(batch)
        else:
            z = model(batch).logits

        # Batched accumulation: add every output row to the sum of its class and
        # count the samples of each class.
        mean_feature_maps.index_add_(0, label, z)
        Nk += torch.bincount(label, minlength = num_classes).to(Nk.dtype)

    mean_feature_map_0 = torch.sum(mean_feature_maps, dim = 0) / torch.sum(Nk)
    # NOTE: a class without any sample has Nk == 0 and yields a non-finite row here.
    mean_feature_maps = (mean_feature_maps.T / Nk).T
    torch.save(mean_feature_map_0, os.path.join(OOD_PARAMS_DIR, "mean_feature_map_0.pt"))
    save_tensor(mean_feature_map_0, os.path.join(OOD_PARAMS_DIR, "mean_feature_map_0.matrix"))
    torch.save(mean_feature_maps, os.path.join(OOD_PARAMS_DIR, "mean_feature_maps.pt"))
    save_tensor(mean_feature_maps, os.path.join(OOD_PARAMS_DIR, "mean_feature_maps.matrix"))

    # Calculate covariance matrices
    print("Calculating covariance matrices")
    for batch, label in ood_data:
        batch = batch.to(device)
        label = label.to(device)

        if not HF_MODEL:
            z = model(batch)
        else:
            z = model(batch).logits

        # For centred rows F (batch x dim), F.T @ F is the sum of the per-sample
        # outer products, so one matrix product handles the whole batch.
        centred_0 = z - mean_feature_map_0
        centred = z - mean_feature_maps[label]
        covar_0 += centred_0.T @ centred_0
        covar += centred.T @ centred

    covar_0 = covar_0 / torch.sum(Nk)
    covar = covar / torch.sum(Nk)

    # Each pseudo-inverse is computed once and written in both formats.
    covar_0_inverse = torch.linalg.pinv(covar_0.cpu())
    covar_inverse = torch.linalg.pinv(covar.cpu())
    torch.save(covar_0_inverse, os.path.join(OOD_PARAMS_DIR, "covar_0_inverse.pt"))
    save_tensor(covar_0_inverse, os.path.join(OOD_PARAMS_DIR, "covar_0_inverse.matrix"))
    torch.save(covar_inverse, os.path.join(OOD_PARAMS_DIR, "covar_inverse.pt"))
    save_tensor(covar_inverse, os.path.join(OOD_PARAMS_DIR, "covar_inverse.matrix"))

In [None]:
# Layer-by-layer summary of `model` for a single 3 x 224 x 224 input.
# NOTE(review): when the cells are run in order, the checkpoint-saving cell has
# already removed/replaced the classification head, so this summarises the
# headless model - confirm that this is the intended one.
summary(model, input_size=(1, 3, 224, 224))