In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from vit_pytorch import ViT
from torch.optim import Adam
from torch import nn as nn
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import top_k_accuracy_score
import os
import timm
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from torch.utils.data import Dataset
from torchvision.transforms import functional as F
from torchview import draw_graph
import graphviz
from torchsummary import summary
from einops import rearrange
from PIL import Image
# Have graphviz objects render as PNG when displayed in Jupyter.
graphviz.set_jupyter_format('png')

# IPython magic: draw matplotlib figures inline in the notebook output.
%matplotlib inline
# Default figure size in inches (11.7 x 8.27).
plt.rcParams['figure.figsize'] = [11.7, 8.27]
sns.set_theme(style='white')
# Display pandas floats with a precision of 3.
pd.set_option("display.precision", 3)

# Use the GPU when one is available and fall back to the CPU otherwise, so
# every later `.to(device)` call also works on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# When True the Hugging Face code paths are used instead of the timm ones.
HF_MODEL = False

In [None]:
import math
import random

import numpy as np
from scipy.stats import beta
import torch


def fftfreqnd(h, w=None, z=None):
    """ Compute the magnitude of the DFT sample frequencies for a grid of size (h, w, z)

    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: Array of frequency magnitudes; the `w` frequencies are truncated to
        their first `w // 2 + 1` entries (`w // 2 + 2` when `w` is odd)
    """
    freq_y = np.fft.fftfreq(h)
    freq_x = 0
    freq_z = 0

    if w is not None:
        # Trailing axis so the h frequencies broadcast against the w frequencies.
        freq_y = freq_y[..., np.newaxis]
        keep = w // 2 + (2 if w % 2 == 1 else 1)
        freq_x = np.fft.fftfreq(w)[:keep]

    if z is not None:
        freq_y = freq_y[..., np.newaxis]
        # Same expression for odd and even z: all z frequencies as a column.
        freq_z = np.fft.fftfreq(z)[:, np.newaxis]

    return np.sqrt(freq_x * freq_x + freq_y * freq_y + freq_z * freq_z)


def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
    """ Draw Gaussian noise in fourier space and attenuate it by 1 / f**decay_power

    :param freqs: Bin values for the discrete fourier transform
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param ch: Number of channels for the resulting mask
    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: Array of shape (ch, *freqs.shape, 2)
    """
    # Frequencies are floored at 1 / (largest dimension) before the division.
    lowest_freq = np.array([1. / max(w, h, z)])
    clamped = np.maximum(freqs, lowest_freq)
    attenuation = np.ones(1) / (clamped ** decay_power)

    noise = np.random.randn(ch, *freqs.shape, 2)

    # Add a leading channel axis and a trailing axis so it broadcasts over noise.
    return attenuation[None, ..., None] * noise


def make_low_freq_image(decay, shape, ch=1):
    """ Sample a low frequency image from fourier space

    :param decay: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param ch: Number of channels sampled in fourier space; only the first one is returned
    :return: Array of shape (1, *shape), shifted and scaled so its minimum is 0 and its maximum is 1
    """
    freqs = fftfreqnd(*shape)
    spectrum = get_spectrum(freqs, decay, ch, *shape)
    # `get_spectrum` appends a trailing axis of size 2. Combine that axis into a
    # complex spectrum; indexing `[:, 0]` / `[:, 1]` would instead pick the first
    # two rows of the first frequency axis and discard the rest of the spectrum.
    spectrum = spectrum[..., 0] + 1j * spectrum[..., 1]
    mask = np.real(np.fft.irfftn(spectrum, shape))

    # Keep the first channel only and crop every spatial axis to the requested size.
    crop = (slice(0, 1),) + tuple(slice(0, dim) for dim in shape)
    mask = mask[crop]

    mask = mask - mask.min()
    mask = mask / mask.max()
    return mask


def sample_lam(alpha, reformulate=False):
    """ Draw the mixing coefficient lambda from a beta distribution

    :param alpha: Alpha value for beta distribution
    :param reformulate: If True, samples from Beta(alpha + 1, alpha) instead of the
        symmetric Beta(alpha, alpha) (the reformulation of [1]).
    :return: Sampled lambda
    """
    first_shape = alpha + 1 if reformulate else alpha
    return beta.rvs(first_shape, alpha)


def binarise_mask(mask, lam, in_shape, max_soft=0.0):
    """ Binarises a given low frequency image such that it has mean lambda.

    The input array is not modified; a new array is returned.

    :param mask: Low frequency image, usually the result of `make_low_freq_image`
    :param lam: Mean value of final mask
    :param in_shape: Shape of inputs
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :return: Array of shape (1, *in_shape) with the largest `lam` fraction of entries set to 1,
        the rest set to 0, and a linear ramp in between when softening is active
    """
    # Work on a flattened copy: `reshape(-1)` alone may return a view, in which
    # case the assignments below would overwrite the caller's array.
    flat = mask.reshape(-1).copy()
    order = flat.argsort()[::-1]  # indices from largest to smallest value

    # lam * size is rarely an integer; round up or down with equal probability.
    target = lam * flat.size
    num = math.ceil(target) if random.random() > 0.5 else math.floor(target)

    # The softened band may not be wider than the smaller of the two regions.
    eff_soft = min(max_soft, lam, 1 - lam)
    soft = int(flat.size * eff_soft)
    num_low = num - soft
    num_high = num + soft

    flat[order[:num_high]] = 1
    flat[order[num_low:]] = 0
    flat[order[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))

    return flat.reshape((1, *in_shape))


def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """ Draw a lambda from the beta distribution, sample a low frequency image and
    binarise that image so that its mean matches the drawn lambda

    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims (a bare int means a 1-d mask)
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, uses the reformulation of [1].
    :return: Tuple (lambda, mask)
    """
    shape = (shape,) if isinstance(shape, int) else shape

    # Lambda is drawn before the image so the random draws keep their order.
    lam = sample_lam(alpha, reformulate)
    low_freq = make_low_freq_image(decay_power, shape)

    return lam, binarise_mask(low_freq, lam, shape, max_soft)


def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """ Sample an FMix mask and use it to blend a batch with a shuffled copy of itself

    :param x: Image batch on which to apply fmix of shape [b, c, shape*]
    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, uses the reformulation of [1].
    :return: mixed input, permutation indices, lambda value of mix
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
    index = np.random.permutation(x.shape[0])

    # One mask is shared by the whole batch: masked entries come from the original
    # sample, the complement from its randomly chosen partner.
    kept = x * mask
    swapped = x[index] * (1 - mask)
    mixed = kept + swapped
    return mixed, index, lam


class FMix:
    r""" FMix augmentation

        Calling an instance on a batch returns the mixed batch together with the
        permutation and lambda that were used; the latter two are also stored on
        the instance as ``index`` and ``lam``.

        Args:
            decay_power (float): Decay power for frequency decay prop 1/f**d
            alpha (float): Alpha value for beta distribution from which to sample mean of mask
            size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
            max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
            reformulate (bool): If True, uses the reformulation of [1].
    """

    def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
        super().__init__()
        self.decay_power = decay_power
        self.reformulate = reformulate
        self.size = size
        self.alpha = alpha
        self.max_soft = max_soft
        # Filled in by the most recent __call__.
        self.index = None
        self.lam = None

    def __call__(self, x):
        """ Apply FMix to a batch.

        :param x: Tensor batch of shape [b, c, size*]
        :return: Tuple (mixed batch as a CPU tensor created with torch.Tensor, permutation indices, lambda)
        """
        # detach() so tensors that require grad can be converted to numpy.
        batch = x.detach().cpu().numpy()
        mixed, index, lam = sample_and_apply(
            batch, self.alpha, self.decay_power, self.size, self.max_soft, self.reformulate
        )
        # Remember what was sampled so the attributes reflect the last call
        # instead of staying None forever.
        self.index = index
        self.lam = lam
        return torch.Tensor(mixed), index, lam

    def loss(self, *args, **kwargs):
        """ Not implemented here; always raises NotImplementedError. """
        raise NotImplementedError

In [None]:
# Query timm for pretrained models matching the filter.
# NOTE(review): the result is neither assigned nor printed, and this is not the
# last expression of the cell, so the list is not displayed.
timm.list_models(filter = "*vit_small_patch16*", pretrained = True)
# Model identifier; also used to build the checkpoint file name further down.
model_name = 'vit_small_patch16_224.augreg_in21k'
# NOTE(review): name is spelled "organistation" and is not referenced in the visible code.
organistation = 'google' # For HF models
# The two images that are loaded, batched and augmented for the visualisation.
image1_name = '../datasets/augmented/images/train/happy/25.jpg'
image2_name = '../datasets/augmented/images/train/disgust/synthetic_0.jpg'
# Size of the classification head passed to the model constructors.
num_classes = 7

In [None]:
# Build the normalisation transform (`processor`) applied to every input image.
if not HF_MODEL:
    # timm path: create the model and read the mean/std from its resolved data config.
    model = timm.create_model(model_name, pretrained = True, num_classes = num_classes, drop_rate = 0.1).to(device)
    timm_data_config = timm.data.resolve_data_config({}, model = model)
    processor = transforms.Normalize(mean = timm_data_config['mean'], std = timm_data_config['std'])
else:
    # HF path: fixed mean/std of 0.5 per channel; no model is created in this branch.
    # processor = transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
    # processor = AutoImageProcessor.from_pretrained(model_name)
    processor = transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))

def transform_images(x):
    """Turn a PIL image into a normalised 3-channel 224x224 tensor.

    :param x: PIL image in any mode (greyscale, RGB, RGBA, palette, ...)
    :return: Tensor of shape [3, 224, 224], normalised with the module-level `processor`
    """
    # Convert to RGB up front so the result always has exactly three channels.
    # Greyscale input is replicated across the channels (what the previous
    # single-channel `repeat` did), and modes such as RGBA or palette images,
    # which used to reach the 3-channel normaliser unchanged, are handled too.
    x = x.convert('RGB').resize((224, 224))
    x = transforms.ToTensor()(x)
    return processor(x)

In [None]:
def mixup(x, y):
    """Blend every sample (and its label) with another sample of the same batch.

    The mixing coefficient is drawn from Beta(0.8, 0.8). Sample ``i`` is paired
    with sample ``i - 1`` (cyclically), so for a batch of two the samples are
    simply swapped and the pairing is deterministic.

    :param x: Input batch, first dimension is the batch dimension
    :param y: Labels with the same batch dimension as `x`
    :return: Tuple (mixed inputs, mixed labels)
    """
    lam = np.random.beta(0.8, 0.8)
    # Cyclic shift instead of a hard-coded [1, 0], which only worked for a batch
    # of exactly two. For two samples this is still tensor([1, 0]).
    indices = torch.arange(x.shape[0]).roll(1)

    shuffled_x = x[indices]
    shuffled_y = y[indices]

    new_x = x * lam + shuffled_x * (1 - lam)
    new_y = y * lam + shuffled_y * (1 - lam)

    return new_x, new_y

In [None]:
def rand_bbox(size, lam):
    """Sample a random box covering roughly a (1 - lam) fraction of the area.

    :param size: Shape of the batch; entries 2 and 3 give the extents of the two boxed axes
    :param lam: Mixing coefficient; each box side is sqrt(1 - lam) times the axis extent
    :return: Tuple (start on axis 2, start on axis 3, end on axis 2, end on axis 3),
        clipped to the axis extents
    """
    extent_a, extent_b = size[2], size[3]
    side_ratio = np.sqrt(1. - lam)
    half_a = int(extent_a * side_ratio) // 2
    half_b = int(extent_b * side_ratio) // 2

    # The centre is drawn uniformly from the middle half of each axis.
    centre_a = np.random.randint(extent_a // 4, extent_a - extent_a // 4)
    centre_b = np.random.randint(extent_b // 4, extent_b - extent_b // 4)

    start_a = np.clip(centre_a - half_a, 0, extent_a)
    start_b = np.clip(centre_b - half_b, 0, extent_b)
    end_a = np.clip(centre_a + half_a, 0, extent_a)
    end_b = np.clip(centre_b + half_b, 0, extent_b)

    return start_a, start_b, end_a, end_b

def cutmix(x, y):
    """Paste a random box from a shuffled copy of the batch into the batch.

    Note that `x` is modified in place and also returned.

    :param x: Input batch indexed as [batch, channel, axis 2, axis 3]
    :param y: Labels with the same batch dimension as `x`
    :return: Tuple (modified `x`, labels mixed by the fraction of replaced pixels)
    """
    lam = np.random.beta(0.8, 0.8)
    perm = torch.randperm(x.shape[0])

    donor_x, donor_y = x[perm], y[perm]

    lo_a, lo_b, hi_a, hi_b = rand_bbox(x.shape, lam)
    x[:, :, lo_a:hi_a, lo_b:hi_b] = donor_x[:, :, lo_a:hi_a, lo_b:hi_b]

    # The sampled box is clipped and rounded, so recompute lambda from the
    # number of pixels that were actually replaced.
    replaced = (hi_a - lo_a) * (hi_b - lo_b)
    lam = 1 - (replaced / (x.shape[-1] * x.shape[-2]))
    mixed_y = y * lam + donor_y * (1 - lam)

    return x, mixed_y

In [None]:
def cutout(x, y):
    """Zero out one random box in every image of the batch.

    Note that `x` is modified in place and also returned; `y` is passed through.

    :param x: Input batch indexed as [batch, channel, axis 2, axis 3]
    :param y: Labels, returned unchanged
    :return: Tuple (modified `x`, `y`)
    """
    # The Beta(0.8, 0.8) draw only determines the size of the box.
    box_lam = np.random.beta(0.8, 0.8)
    lo_a, lo_b, hi_a, hi_b = rand_bbox(x.shape, box_lam)

    x[:, :, lo_a:hi_a, lo_b:hi_b] = 0

    return x, y

In [None]:
def fmix(x, y):
    """Apply FMix with its default settings to a batch and mix the labels accordingly.

    :param x: Input batch; the mask is sampled with the shape of the dimensions after the first two
    :param y: Labels with the same batch dimension as `x`
    :return: Tuple (mixed inputs on the same device as `x`, mixed labels)
    """
    new_x, index, lam = FMix(size = x.shape[2 :])(x)
    new_y = y * lam + y[index] * (1 - lam)

    # FMix returns a CPU tensor; move it back to wherever the input lives rather
    # than to the module-global device, so the function also works for inputs
    # that are not on that device.
    return new_x.to(x.device), new_y

In [None]:
# Registry mapping an augmentation name to its function; every function takes
# (x, y) and returns the augmented (x, y).
augmentation_functions = {'mixup': mixup, 'cutmix': cutmix, 'cutout': cutout, 'fmix': fmix}
# Change function here
# The chosen key is also used in the file name of the saved figure.
used_aug_function = 'cutout'

In [None]:
# Rebuild the architecture, then overwrite its weights with the saved checkpoint.
if not HF_MODEL:
    model = timm.create_model(model_name, pretrained = True, num_classes = num_classes, drop_rate = 0.1).to(device)
else:
    model = AutoModelForImageClassification.from_pretrained(model_name, num_labels = num_classes, ignore_mismatched_sizes = True).to(device)

# map_location remaps the stored tensors onto the device in use, so a checkpoint
# saved on a different device (e.g. another GPU index) can still be loaded.
model.load_state_dict(torch.load(f'../fer/saved_models/{model_name}_full.pth', map_location = device))
model.eval()
# NOTE(review): get_classifier() is called in both branches; confirm the HF model provides it.
classifier = model.get_classifier()
print(classifier.weight.shape)

#print(type(model))
#timm.models.vision_transformer.VisionTransformer

# Compute a class activation map (CAM) for the first augmented image and plot
# it next to the image itself.
with torch.no_grad():
    # Load both images, preprocess them and add a leading batch dimension.
    image1 = Image.open(image1_name)
    image1 = transform_images(image1)[None, :]
    image1 = image1.to(device)
    
    image2 = Image.open(image2_name)
    image2 = transform_images(image2)[None, :]
    image2 = image2.to(device)
    
    # Batch of two, so augmentations that mix samples have a partner image.
    x = torch.concat([image1, image2])
    
    # The labels are placeholder zeros; only the augmented images are used below.
    x, y = augmentation_functions[used_aug_function](x, torch.zeros(2))
    # Keep only the first augmented image, with its batch dimension.
    image = x[0:1]
    
    # NOTE(review): `feats` and `weights` are only assigned in this branch, so
    # the code below raises NameError when HF_MODEL is True.
    if not HF_MODEL:
        feats = model.forward_features(image)
        out = model.forward_head(feats)
        print(out)
        # Classifier weight row of the predicted class.
        class_out = torch.argmax(out, 1).item()
        weights = classifier.weight[class_out]
    
    # Drop the first token (presumably the class token - TODO confirm) and lay
    # the remaining tokens out as a 14 x 14 grid with channels first.
    feats = feats[:, 1:, :]
    feats = rearrange(feats, 'b (h w) c -> b c h w', h = 14, w = 14)
    
    # CAM = sum over channels of (channel map * classifier weight of that channel).
    cam = torch.zeros((feats.shape[-1], feats.shape[-1])).to(device)
    for feat, weight in zip(feats[0], weights):
        cam += feat * weight
    
    # Upsample to 224 x 224, min-max normalise and colour with the jet colormap
    # (the alpha channel is dropped).
    cam = transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC)(cam[None, :])[0]
    cam = ((cam - cam.min()) / (cam.max() - cam.min())).cpu()
    cam = plt.cm.jet(cam)[:, :, :3]
    # NOTE(review): `* 0.5 + 0.5` undoes a normalisation with mean 0.5 and std 0.5;
    # confirm this matches the mean/std `processor` was built with.
    image = (image[0] * 0.5 + 0.5).permute(1, 2, 0).cpu()
    
    # Left: the augmented image. Right: the image blended 50/50 with the CAM.
    #plt.imshow(image * 0.5 + cam * 0.5)
    _, ax = plt.subplots(1, 2)
    ax[0].axis('off')
    ax[0].imshow(image)
    ax[1].axis('off')
    ax[1].imshow(image * 0.5 + cam * 0.5)
    
    # for module_name in model._modules:
    #     module = model._modules[module_name]
    #     print(x.shape)
    #     x = module(x)
    #     print(f"After going through {module_name}: {x.shape}")
# exist_ok=True creates the directory only when it is missing, without the
# separate existence check (which could race with another process).
os.makedirs('saved_images', exist_ok=True)
# Save the current figure; the file name contains the augmentation that was used.
plt.savefig(f'saved_images/cam-{used_aug_function}2.png')