# Import e path

In [None]:
# Root folder of the datasets and folder that receives the per-run CSV results.
dataset_path = "../Datasets/"
results_save_path = "../Results/"

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch import optim, nn
from torch.nn import functional
import os
import time
import csv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shutil
from torchvision.utils import make_grid
from random import randint
from PIL import Image
import random

from einops import rearrange
from einops.layers.torch import Rearrange

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


import timm

from TRAM import TRAM
from ViT import ViT
from TopK import TopKViT
from PatchMergerViT import PatchMergerViT
from ATSViT import ATSViT

<h3>Validi per ogni modello</h3>

In [None]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Covers Python's ``random`` module, NumPy and PyTorch (CPU, the current
    CUDA device and all CUDA devices), stores the seed in the
    ``PYTHONHASHSEED`` environment variable and enables cuDNN's
    deterministic mode.

    Args:
        seed: Seed value shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Python and NumPy generators.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    # PyTorch generators: CPU first, then the CUDA ones.
    for seeder in (torch.manual_seed,
                   torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)

    torch.backends.cudnn.deterministic = True


# Seed once at start-up.
set_seed(42)

In [None]:
# Select the compute device. The second GPU (cuda:1) is preferred, as before,
# but a machine with a single GPU falls back to cuda:0 instead of requesting a
# device index that does not exist; without CUDA the CPU is used.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

# Funzioni per train e validation del modello

<h3>Funzione per il Training</h3>

In [None]:
def train_iter(model, optimz, data_load, loss_val, device):
    """Train ``model`` for one epoch and record the epoch's mean training loss.

    Every batch is moved to ``device``, the negative log-likelihood of the
    log-softmax of the model output is minimised with ``optimz``, and a
    progress line is printed every 100 batches.

    Args:
        model: Network to train; called as ``model(data)``.
        optimz: Optimizer used for the parameter updates.
        data_load: Loader yielding ``(data, target)`` batches; its
            ``.dataset`` length is used for the progress display.
        loss_val: List that receives one value per call: the sample-weighted
            mean loss over the whole epoch (``nan`` if the loader yields no
            batch).
        device: Device the batches are moved to.
    """
    samples = len(data_load.dataset)
    model.train()

    loss_sum = 0.0  # sum of per-sample losses seen in this epoch
    seen = 0        # number of samples processed in this epoch

    for i, (data, target) in enumerate(data_load):
        data = data.to(device)
        target = target.to(device)

        optimz.zero_grad()
        out = functional.log_softmax(model(data), dim=1)
        loss = functional.nll_loss(out, target)
        loss.backward()
        optimz.step()

        # nll_loss returns the batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        batch_loss = loss.item()
        loss_sum += batch_loss * len(data)
        seen += len(data)

        if i % 100 == 0:
            print('[' +  '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_load)) + '%)]  Loss: ' +
                  '{:6.4f}'.format(batch_loss))

    # Record the epoch average rather than whichever mini-batch happened to be
    # last, which matches how evaluate() reports a mean over the dataset.
    loss_val.append(loss_sum / seen if seen else float('nan'))

<h3>Funzione per la validation</h3>

In [None]:
def evaluate(model, optimizer, data_load, loss_val, device):
    """Run ``model`` over ``data_load`` and report mean loss and accuracy.

    Args:
        model: Network to evaluate; switched to eval mode here.
        optimizer: Unused.
        data_load: Loader yielding ``(data, target)`` batches; the length of
            its ``.dataset`` is the divisor for both metrics.
        loss_val: List that receives the per-sample mean loss.
        device: Device the batches are moved to.

    Returns:
        The accuracy in percent, as a tensor moved to the CPU.
    """
    model.eval()
    n_samples = len(data_load.dataset)

    n_correct = 0   # running count of correct predictions
    loss_total = 0  # running sum of per-sample losses

    with torch.no_grad():
        for inputs, labels in data_load:
            inputs = inputs.to(device)
            labels = labels.to(device)

            log_probs = functional.log_softmax(model(inputs), dim=1)
            loss_total += functional.nll_loss(log_probs, labels, reduction='sum').item()

            predicted = log_probs.max(dim=1).indices
            n_correct += predicted.eq(labels).sum()

    mean_loss = loss_total / n_samples
    loss_val.append(mean_loss)
    accuracy = (100.0 * n_correct / n_samples).cpu()

    print(f'\nAverage test loss: {mean_loss:.4f}'
          f'  Accuracy:{n_correct:5}/{n_samples:5}'
          f' ({accuracy:4.2f}%)\n')

    return accuracy

<h3>Funzione per allenare e validare il modello</h3>

In [None]:
def train_validation(model, optimizer, train_loader, validation_loader, nome_file, epoche, device):
    """Alternate one training pass and one validation pass for ``epoche`` epochs.

    Args:
        model: Network to train and validate.
        optimizer: Optimizer forwarded to ``train_iter`` and ``evaluate``.
        train_loader: Loader used for the training pass.
        validation_loader: Loader used for the validation pass.
        nome_file: Unused in this function.
        epoche: Number of epochs to run.
        device: Device forwarded to ``train_iter`` and ``evaluate``.

    Returns:
        A tuple ``(train losses, validation losses, validation accuracies,
        epoch durations in seconds)``; each list has one entry per epoch.
    """
    train_losses, val_losses, val_accs, durations = [], [], [], []

    for epoch in range(1, epoche + 1):
        tic = time.time()

        print(f'Epoch: {epoch}/{epoche}')
        print("INIZIO TRAINING")
        train_iter(model, optimizer, train_loader, train_losses, device)
        print("INIZIO VALIDATION")
        val_accs.append(evaluate(model, optimizer, validation_loader, val_losses, device))

        # The timing covers both the training and the validation pass.
        elapsed = time.time() - tic
        durations.append(elapsed)

        print('Execution time:', f'{elapsed:5.2f}', 'seconds')
        print("#" * 40)

    return train_losses, val_losses, val_accs, durations

In [None]:
def create_patch_list(total_patches, cut, depth=12, block_size=3):
    """Build the per-layer patch budget for staged patch reduction.

    The first ``block_size`` layers keep ``total_patches``. At the start of
    every following group of ``block_size`` layers the budget is multiplied
    by ``cut / 100`` and truncated to an integer.

    Args:
        total_patches: Number of patches before any reduction.
        cut: Percentage of the current budget kept at each reduction step.
        depth: Length of the returned list. Defaults to 12, the value that
            used to be hard-coded.
        block_size: Number of consecutive layers sharing one budget.
            Defaults to 3, the value that used to be hard-coded.

    Returns:
        A list of ``depth`` integers, one budget per layer.

    Raises:
        ValueError: If ``block_size`` is smaller than 1.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")

    remaining = total_patches
    patch_list = []
    for layer in range(depth):
        # Shrink the budget when a new group starts (never before layer 0).
        if layer != 0 and layer % block_size == 0:
            remaining = int(remaining * (cut / 100))
        patch_list.append(remaining)
    return patch_list



def performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size):
    """Fine-tune one model variant and save its per-epoch metrics as a CSV file.

    Relies on the notebook globals ``train_loader``, ``validation_loader``,
    ``nome_file``, ``results_save_path`` and ``device``.

    Args:
        modello: Name of the variant, forwarded to ``get_weights``.
        total_patches: Number of patches before any reduction.
        learning_rate: Learning rate of the Adam optimizer.
        num_classes: Number of output classes.
        epoche: Number of training epochs.
        model_timm: Source model whose weights initialise the variant.
        dim: Forwarded to ``get_weights``.
        heads: Forwarded to ``get_weights``.
        cut: Percentage forwarded to ``create_patch_list``; also part of the
            output file name.
        size: Label used only in the output file name.
        img_size: Forwarded to ``get_weights`` as the image size.
        patch_size: Forwarded to ``get_weights``.

    Returns:
        The return value of ``DataFrame.to_csv``; the metrics are written to
        ``{results_save_path}{nome_file}/{modello}_{size}_{cut}%.csv``.
    """
    # Re-seed so every run starts from the same random state.
    set_seed(42)

    patch = create_patch_list(total_patches, cut)

    # PatchMergerViT receives (layer index, patch count) pairs instead of the
    # full per-layer list.
    if modello == 'PatchMergerViT':
        patch = [(2, patch[3]), (5, patch[6]), (8, patch[9])]

    model = get_weights(modello, patch, num_classes, model_timm, dim, heads, img_size, patch_size)

    output_dir = f'{results_save_path}{nome_file}'
    csv_path = f'{output_dir}/{modello}_{size}_{cut}%.csv'
    print(csv_path)

    # Create the result folder before training: otherwise a missing directory
    # would only surface in to_csv, after the whole run has already finished.
    os.makedirs(output_dir, exist_ok=True)

    # Move the model to the selected device.
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    train_loss, validation_loss, validation_acc, epoch_time = train_validation(
        model, optimizer, train_loader, validation_loader, nome_file, epoche, device=device)

    # evaluate() returns accuracies as tensors; store plain floats in the CSV.
    df = pd.DataFrame({'train_loss': train_loss,
                       'validation_loss': validation_loss,
                       'validation_acc': [acc.item() for acc in validation_acc],
                       'epoch_time': epoch_time
                       })

    return df.to_csv(csv_path, index=False)

# Funzioni per importare i pesi

In [None]:
def get_weights(modello, n_patch, num_classes, model_timm, dim, heads, image_size, patch_size):
    """Build the requested model and initialise it from ``model_timm``'s weights.

    The state dict of ``model_timm`` is renamed key by key to the naming used
    by the target model and then loaded with ``load_state_dict``. Keys that
    are not renamed here are passed through unchanged.

    Args:
        modello: One of ``'TRAM'``, ``'ATSViT'``, ``'PatchMergerViT'``,
            ``'TopKViT'`` or ``'ViT'``.
        n_patch: Patch-reduction configuration handed to the model
            constructor (ignored for ``'ViT'``).
        num_classes: Number of output classes.
        model_timm: Source model; must expose ``blocks`` and ``state_dict()``.
        dim: Model width; the MLP width is ``4 * dim``.
        heads: Number of attention heads.
        image_size: Input image size.
        patch_size: Patch size.

    Returns:
        The newly built model with the converted weights loaded.

    Raises:
        ValueError: If ``modello`` is not one of the supported names.
    """
    # Fixed architecture parameters.
    depth = 12
    mlp_dim = dim * 4

    model_kwargs = dict(
        image_size=image_size,
        patch_size=patch_size,
        num_classes=num_classes,
        dim=dim,
        depth=depth,
        heads=heads,
        mlp_dim=mlp_dim,
    )

    # Model class and the keyword through which it receives ``n_patch``
    # (None for the plain ViT, which takes no reduction configuration).
    builders = {
        'TRAM': (TRAM, 'n_patch'),
        'ATSViT': (ATSViT, 'max_tokens_per_depth'),
        'PatchMergerViT': (PatchMergerViT, 'patch_merge_layers'),
        'TopKViT': (TopKViT, 'n_patch'),
        'ViT': (ViT, None),
    }
    if modello not in builders:
        raise ValueError(
            f"Unknown model '{modello}'; expected one of {sorted(builders)}")

    model_cls, patch_kwarg = builders[modello]
    if patch_kwarg is not None:
        model_kwargs[patch_kwarg] = n_patch
    model_vit = model_cls(**model_kwargs)

    # Source weights (renamed below) and the target's freshly initialised ones.
    model_timm_weights = model_timm.state_dict()
    model_vit_weights = model_vit.state_dict()

    # Per-block renames: suffix under 'blocks.{i}.' -> suffix under
    # 'transformer.layers.{i}.'.
    block_renames = (
        ('norm1.weight', '0.norm.weight'),
        ('norm1.bias', '0.norm.bias'),
        ('attn.qkv.weight', '0.to_qkv.weight'),
        ('attn.proj.weight', '0.to_out.0.weight'),
        ('attn.proj.bias', '0.to_out.0.bias'),
        ('norm2.weight', '1.net.0.weight'),
        ('norm2.bias', '1.net.0.bias'),
        ('mlp.fc1.weight', '1.net.1.weight'),
        ('mlp.fc1.bias', '1.net.1.bias'),
        ('mlp.fc2.weight', '1.net.4.weight'),
        ('mlp.fc2.bias', '1.net.4.bias'),
    )

    # Only blocks present in both models are converted.
    for i, _ in enumerate(zip(model_timm.blocks, model_vit.transformer.layers)):
        src_prefix = f'blocks.{i}.'
        dst_prefix = f'transformer.layers.{i}.'
        for old_suffix, new_suffix in block_renames:
            model_timm_weights[dst_prefix + new_suffix] = model_timm_weights.pop(src_prefix + old_suffix)
        # The qkv bias is discarded; presumably the target's to_qkv layer has
        # no bias term -- confirm against the model definitions.
        model_timm_weights.pop(src_prefix + 'attn.qkv.bias')

    # Classifier head, final norm and positional embedding.
    top_level_renames = (
        ('head.weight', 'mlp_head.weight'),
        ('head.bias', 'mlp_head.bias'),
        ('norm.weight', 'transformer.norm.weight'),
        ('norm.bias', 'transformer.norm.bias'),
        ('pos_embed', 'pos_embedding'),
    )
    for old_key, new_key in top_level_renames:
        model_timm_weights[new_key] = model_timm_weights.pop(old_key)

    if modello == 'PatchMergerViT':
        # The source class token and head are dropped; the patch mergers and
        # the head keep the values the new model was initialised with.
        for key in ("cls_token", "mlp_head.weight", "mlp_head.bias"):
            model_timm_weights.pop(key)
        kept_keys = [
            f"transformer.patch_mergers.{layer}.{suffix}"
            for layer in (2, 5, 8)
            for suffix in ("queries", "norm.weight", "norm.bias")
        ]
        kept_keys += ["mlp_head.1.weight", "mlp_head.1.bias"]
        for key in kept_keys:
            model_timm_weights[key] = model_vit_weights[key]

    if modello == 'ATSViT':
        # ATSViT stores the final norm and the classifier under mlp_head.0
        # and mlp_head.1 respectively.
        ats_renames = (
            ('transformer.norm.weight', 'mlp_head.0.weight'),
            ('transformer.norm.bias', 'mlp_head.0.bias'),
            ('mlp_head.weight', 'mlp_head.1.weight'),
            ('mlp_head.bias', 'mlp_head.1.bias'),
        )
        for old_key, new_key in ats_renames:
            model_timm_weights[new_key] = model_timm_weights.pop(old_key)

    model_vit.load_state_dict(model_timm_weights)

    return model_vit

# Imagenette

## Dataset e Modello

<h3>Parametri del dataset</h3>

In [None]:
# Imagenette run configuration.
nome_file = "Imagenette_pretrained"  # result sub-folder for this dataset
img_size = 224                       # images are resized to img_size x img_size
patch_size = 16
batch_size = 64

## Divisione tra train e test

In [None]:
# Per-channel statistics used by Normalize in both pipelines.
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

# Validation pipeline: resize, convert to tensor, normalise.
transform_validation = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Training pipeline: same steps plus a random horizontal flip after resizing.
transform_train = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Datasets are read from dataset_path; no download is requested here.
trainset = torchvision.datasets.Imagenette(
    root=dataset_path, split='train', transform=transform_train)
validationset = torchvision.datasets.Imagenette(
    root=dataset_path, split='val', transform=transform_validation)

# Only the training loader shuffles.
train_loader = DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
validation_loader = DataLoader(
    validationset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

classes = trainset.classes
print(classes)

## Inizializzazione

In [None]:
# Fine-tuning hyper-parameters for the Imagenette runs.
num_classes = 10
epoche = 5
learning_rate = 1e-4
cut_list = [75, 50, 25]  # percentages passed to performance() as `cut`
total_patches = (img_size // patch_size) ** 2  # patches per image

## Base

In [None]:
# Base-sized source checkpoint and the matching model hyper-parameters.
size = 'Base'
model_name = 'vit_base_patch16_224.augreg2_in21k_ft_in1k'
pretrained = True
dim = 768
heads = 12

# Load the source model through timm with a num_classes-way head.
model_timm = timm.create_model(model_name, num_classes=num_classes, pretrained=pretrained)

In [None]:
# Fine-tune every patch-reduction variant at each cut percentage.
for cut in cut_list:
    for modello in ('TRAM', 'PatchMergerViT', 'ATSViT', 'TopKViT'):
        performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

# Baseline ViT, trained once. NOTE: `cut` still holds the last value of the
# loop above; get_weights ignores the patch list for 'ViT', so it only shows
# up in the name of the saved CSV file.
modello = 'ViT'
performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

## Small

In [None]:
# Small-sized source checkpoint and the matching model hyper-parameters.
size = 'Small'
model_name = 'vit_small_patch16_224.augreg_in21k_ft_in1k'
pretrained = True
dim = 384
heads = 6

# Load the source model through timm with a num_classes-way head.
model_timm = timm.create_model(model_name, num_classes=num_classes, pretrained=pretrained)

In [None]:
# Fine-tune every patch-reduction variant at each cut percentage.
for cut in cut_list:
    for modello in ('TRAM', 'PatchMergerViT', 'ATSViT', 'TopKViT'):
        performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

# Baseline ViT, trained once. NOTE: `cut` still holds the last value of the
# loop above; get_weights ignores the patch list for 'ViT', so it only shows
# up in the name of the saved CSV file.
modello = 'ViT'
performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

# CIFAR10

## Dataset e Modello

<h3>Parametri del dataset</h3>

In [None]:
# CIFAR10 run configuration.
nome_file = "CIFAR10_pretrained"  # result sub-folder for this dataset
img_size = 224                    # images are resized to img_size x img_size
patch_size = 16
batch_size = 64

## Divisione tra train e test

In [None]:
# Per-channel statistics used by Normalize in both pipelines.
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

# Validation pipeline: resize, convert to tensor, normalise.
transform_validation = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Training pipeline: same steps plus a random horizontal flip after resizing.
transform_train = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# The training split requests a download; the test split serves as the
# validation set and is created without the download flag.
trainset = torchvision.datasets.CIFAR10(
    root=dataset_path, train=True, download=True, transform=transform_train)
validationset = torchvision.datasets.CIFAR10(
    root=dataset_path, train=False, transform=transform_validation)

# Only the training loader shuffles.
train_loader = DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
validation_loader = DataLoader(
    validationset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

classes = trainset.classes
print(classes)

## Inizializzazione

In [None]:
# Fine-tuning hyper-parameters for the CIFAR10 runs.
num_classes = 10
epoche = 5
learning_rate = 1e-4
cut_list = [75, 50]  # percentages passed to performance() as `cut`
total_patches = (img_size // patch_size) ** 2  # patches per image

## Base

In [None]:
# Base-sized source checkpoint and the matching model hyper-parameters.
size = 'Base'
model_name = 'vit_base_patch16_224.augreg2_in21k_ft_in1k'
pretrained = True
dim = 768
heads = 12

# Load the source model through timm with a num_classes-way head.
model_timm = timm.create_model(model_name, num_classes=num_classes, pretrained=pretrained)

In [None]:
# Fine-tune every patch-reduction variant at each cut percentage.
for cut in cut_list:
    for modello in ('TRAM', 'PatchMergerViT', 'ATSViT', 'TopKViT'):
        performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

# Baseline ViT, trained once. NOTE: `cut` still holds the last value of the
# loop above; get_weights ignores the patch list for 'ViT', so it only shows
# up in the name of the saved CSV file.
modello = 'ViT'
performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

## Small

In [None]:
# Cut percentages used for the Small runs on this dataset.
cut_list = [75, 50]

# Small-sized source checkpoint and the matching model hyper-parameters.
size = 'Small'
model_name = 'vit_small_patch16_224.augreg_in21k_ft_in1k'
pretrained = True
dim = 384
heads = 6

# Load the source model through timm with a num_classes-way head.
model_timm = timm.create_model(model_name, num_classes=num_classes, pretrained=pretrained)

In [None]:
# Fine-tune every patch-reduction variant at each cut percentage.
for cut in cut_list:
    for modello in ('TRAM', 'PatchMergerViT', 'ATSViT', 'TopKViT'):
        performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

# Baseline ViT, trained once. NOTE: `cut` still holds the last value of the
# loop above; get_weights ignores the patch list for 'ViT', so it only shows
# up in the name of the saved CSV file.
modello = 'ViT'
performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

# FMNIST

## Dataset e Modello

<h3>Parametri del dataset</h3>

In [None]:
# FashionMNIST run configuration.
nome_file = "FMNIST_pretrained"  # result sub-folder for this dataset
img_size = 224                   # images are resized to img_size x img_size
patch_size = 16
batch_size = 64

## Divisione tra train e test

In [None]:
# Per-channel statistics used by Normalize in both pipelines.
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)


def to_three_channels(x):
    """Repeat a single-channel tensor three times along dim 0; pass others through."""
    if x.size(0) == 1:
        return x.repeat(3, 1, 1)
    return x


# Applied after ToTensor so Normalize receives three channels.
trans = transforms.Lambda(to_three_channels)

# Validation pipeline: resize, convert to tensor, expand channels, normalise.
transform_validation = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    trans,
    transforms.Normalize(mean, std),
])

# Training pipeline: same steps plus a random horizontal flip after resizing.
transform_train = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    trans,
    transforms.Normalize(mean, std),
])

# Datasets are read from dataset_path; no download is requested here. The
# test split serves as the validation set.
trainset = torchvision.datasets.FashionMNIST(
    root=dataset_path, train=True, transform=transform_train)
validationset = torchvision.datasets.FashionMNIST(
    root=dataset_path, train=False, transform=transform_validation)

# Only the training loader shuffles.
train_loader = DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
validation_loader = DataLoader(
    validationset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

classes = trainset.classes
print(classes)

## Inizializzazione

In [None]:
# Fine-tuning hyper-parameters for the FashionMNIST runs.
num_classes = 10
epoche = 5
learning_rate = 1e-4
cut_list = [50]  # only 50 is active here; 75 and 25 are disabled for the Base runs
total_patches = (img_size // patch_size) ** 2  # patches per image

## Base

In [None]:
# Base-sized source checkpoint and the matching model hyper-parameters.
size = 'Base'
model_name = 'vit_base_patch16_224.augreg2_in21k_ft_in1k'
pretrained = True
dim = 768
heads = 12

# Load the source model through timm with a num_classes-way head.
model_timm = timm.create_model(model_name, num_classes=num_classes, pretrained=pretrained)

In [None]:
# Fine-tune every patch-reduction variant at each cut percentage.
for cut in cut_list:
    for modello in ('TRAM', 'PatchMergerViT', 'ATSViT', 'TopKViT'):
        performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

# Baseline ViT, trained once. NOTE: `cut` still holds the last value of the
# loop above; get_weights ignores the patch list for 'ViT', so it only shows
# up in the name of the saved CSV file.
modello = 'ViT'
performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

## Small

In [None]:
# Cut percentages used for the Small runs on this dataset.
cut_list = [75, 50]

# Small-sized source checkpoint and the matching model hyper-parameters.
size = 'Small'
model_name = 'vit_small_patch16_224.augreg_in21k_ft_in1k'
pretrained = True
dim = 384
heads = 6

# Load the source model through timm with a num_classes-way head.
model_timm = timm.create_model(model_name, num_classes=num_classes, pretrained=pretrained)

In [None]:
# Fine-tune the patch-reduction variants at each cut percentage (TRAM is not
# part of this cell).
for cut in cut_list:
    for modello in ('PatchMergerViT', 'ATSViT', 'TopKViT'):
        performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)

# Baseline ViT, trained once. NOTE: `cut` still holds the last value of the
# loop above; get_weights ignores the patch list for 'ViT', so it only shows
# up in the name of the saved CSV file.
modello = 'ViT'
performance(modello, total_patches, learning_rate, num_classes, epoche, model_timm, dim, heads, cut, size, img_size, patch_size)