# Import e path

In [None]:
# Root directory of the datasets; passed as `root` to the torchvision
# Imagenette dataset constructors further down in this notebook.
dataset_path = "../Datasets/"

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch import optim, nn
from torch.nn import functional
import os
import time
import csv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shutil
from torchvision.utils import make_grid
from random import randint
from PIL import Image
import random

from einops import rearrange
from einops.layers.torch import Rearrange

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import timm
from TRAM import TRAM

<h3>Validi per ogni modello</h3>

In [None]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices) and configures cuDNN so that it does not introduce
    run-to-run differences.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # NOTE: the hash seed of the running interpreter is fixed at start-up;
    # this export only affects processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmarking enabled cuDNN may auto-tune and pick a different
    # convolution algorithm on each run, so switch it off explicitly.
    torch.backends.cudnn.benchmark = False


set_seed(42)

In [None]:
# Pick the compute device: the second GPU when the machine has one, otherwise
# the first GPU, otherwise the CPU. Selecting 'cuda:1' whenever CUDA is merely
# available yields an invalid device ordinal on single-GPU machines.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Funzioni per train e validation del modello

<h3>Funzione per il Training</h3>

In [None]:
def train_iter(model, optimz, data_load, loss_val, device):
    """Run one training epoch and record the loss of the final batch.

    Args:
        model: Network to optimise; it is switched to training mode.
        optimz: Optimizer that updates the model parameters.
        data_load: DataLoader yielding ``(data, target)`` batches.
        loss_val: List that receives the loss of the last processed batch.
        device: Device the batches are moved to before the forward pass.
    """
    model.train()
    total_samples = len(data_load.dataset)

    for batch_idx, (data, target) in enumerate(data_load):
        data, target = data.to(device), target.to(device)

        optimz.zero_grad()
        log_probs = functional.log_softmax(model(data), dim=1)
        loss = functional.nll_loss(log_probs, target)
        loss.backward()
        optimz.step()

        # Emit a progress line every 100 batches.
        if batch_idx % 100 == 0:
            seen = batch_idx * len(data)
            percent = 100 * batch_idx / len(data_load)
            print(f'[{seen:5}/{total_samples:5} ({percent:3.0f}%)]  Loss: {loss.item():6.4f}')

    # Only the loss of the last batch is stored, not the epoch mean.
    loss_val.append(loss.item())

<h3>Funzione per la validation</h3>

In [None]:
def evaluate(model, optimizer, data_load, loss_val, device):
    """Evaluate the model on a dataset and report loss and accuracy.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        optimizer: Not used by this function.
        data_load: DataLoader yielding ``(data, target)`` batches.
        loss_val: List that receives the mean per-sample loss.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The accuracy in percent, moved to the CPU.
    """
    model.eval()

    total_samples = len(data_load.dataset)
    correct = 0       # running count of correct predictions
    summed_loss = 0   # running sum of per-sample losses

    with torch.no_grad():
        for data, target in data_load:
            data, target = data.to(device), target.to(device)

            log_probs = functional.log_softmax(model(data), dim=1)
            summed_loss += functional.nll_loss(log_probs, target, reduction='sum').item()

            predicted = log_probs.max(dim=1)[1]
            correct += predicted.eq(target).sum()

    mean_loss = summed_loss / total_samples
    loss_val.append(mean_loss)
    acc = (100.0 * correct / total_samples).cpu()

    print(f'\nAverage test loss: {mean_loss:.4f}  Accuracy:{correct:5}/{total_samples:5} ({acc:4.2f}%)\n')

    return acc

<h3>Funzione per allenare e validare il modello</h3>

In [None]:
def train_validation(model, optimizer, train_loader, validation_loader, nome_file, epoche, device):
    """Train and validate the model for a number of epochs.

    Each epoch calls ``train_iter`` followed by ``evaluate`` and records the
    training loss, validation loss, validation accuracy and wall-clock time.

    Args:
        model: Network to train.
        optimizer: Optimizer passed on to ``train_iter`` and ``evaluate``.
        train_loader: DataLoader used for training.
        validation_loader: DataLoader used for validation.
        nome_file: Not used by this function.
        epoche: Number of epochs to run.
        device: Device passed on to ``train_iter`` and ``evaluate``.

    Returns:
        Tuple ``(train losses, validation losses, validation accuracies,
        epoch durations in seconds)``, each a list with one entry per epoch.
    """
    train_losses, val_losses, val_accs, durations = [], [], [], []

    for epoch in range(1, epoche + 1):
        started = time.time()

        print(f'Epoch: {epoch}/{epoche}')
        print("INIZIO TRAINING")
        train_iter(model, optimizer, train_loader, train_losses, device)
        print("INIZIO VALIDATION")
        acc = evaluate(model, optimizer, validation_loader, val_losses, device)

        is_best = not val_accs or acc >= max(val_accs)
        if is_best:
            # NOTE(review): the snapshot of the best epoch is assembled here
            # but never written anywhere; `nome_file` looks like the intended
            # target - confirm whether a save call went missing.
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss_state_dict': train_losses[-1],
                'val_loss_state_dict': val_losses[-1],
                'val_acc_state_dict': acc,
            }

        val_accs.append(acc)

        elapsed = time.time() - started
        durations.append(elapsed)

        print('Execution time:', f'{elapsed:5.2f}', 'seconds')
        print("#" * 40)

    return train_losses, val_losses, val_accs, durations

# Funzioni per importare i pesi

In [None]:
def get_weights(n_patch, num_classes, model_timm, dim, heads, image_size, patch_size):
    """Build a TRAM model and load it with the weights of a timm ViT.

    The timm state dict is renamed key by key to the names TRAM expects and
    then loaded with ``load_state_dict``. The attention ``qkv`` bias of every
    block is discarded (presumably TRAM's ``to_qkv`` projection has no bias -
    verify against the TRAM definition). Keys that are not renamed here are
    passed through unchanged.

    Args:
        n_patch: Forwarded to ``TRAM`` as ``n_patch``.
        num_classes: Number of output classes of the TRAM model.
        model_timm: timm model providing ``state_dict()`` and ``blocks``.
        dim: Embedding dimension; the MLP width is ``4 * dim``.
        heads: Number of attention heads.
        image_size: Forwarded to ``TRAM`` as ``image_size``.
        patch_size: Forwarded to ``TRAM`` as ``patch_size``.

    Returns:
        The TRAM model with the converted weights loaded.
    """
    # ViT hyper-parameters that are fixed for every model size used here.
    depth = 12
    mlp_dim = dim * 4

    model_vit = TRAM(
        image_size=image_size,
        patch_size=patch_size,
        num_classes=num_classes,
        dim=dim,
        depth=depth,
        heads=heads,
        mlp_dim=mlp_dim,
        n_patch=n_patch,
    )

    weights = model_timm.state_dict()

    # (timm suffix, TRAM suffix) of each per-block tensor that is carried over.
    block_renames = (
        ('norm1.weight', '0.norm.weight'),
        ('norm1.bias', '0.norm.bias'),
        ('attn.qkv.weight', '0.to_qkv.weight'),
        ('attn.proj.weight', '0.to_out.0.weight'),
        ('attn.proj.bias', '0.to_out.0.bias'),
        ('norm2.weight', '1.net.0.weight'),
        ('norm2.bias', '1.net.0.bias'),
        ('mlp.fc1.weight', '1.net.1.weight'),
        ('mlp.fc1.bias', '1.net.1.bias'),
        ('mlp.fc2.weight', '1.net.4.weight'),
        ('mlp.fc2.bias', '1.net.4.bias'),
    )

    # zip() limits the conversion to the blocks present in both models.
    for i, _ in enumerate(zip(model_timm.blocks, model_vit.transformer.layers)):
        for timm_suffix, tram_suffix in block_renames:
            weights[f'transformer.layers.{i}.{tram_suffix}'] = weights.pop(f'blocks.{i}.{timm_suffix}')
        # The qkv bias is removed without a replacement key.
        weights.pop(f'blocks.{i}.attn.qkv.bias')

    # Tensors outside the transformer blocks.
    top_level_renames = (
        ('head.weight', 'mlp_head.weight'),
        ('head.bias', 'mlp_head.bias'),
        ('norm.weight', 'transformer.norm.weight'),
        ('norm.bias', 'transformer.norm.bias'),
        ('pos_embed', 'pos_embedding'),
    )
    for timm_key, tram_key in top_level_renames:
        weights[tram_key] = weights.pop(timm_key)

    model_vit.load_state_dict(weights)

    return model_vit

# Funzioni per il dataloader

<h3>Funzione per creare il dataloader in base al tipo di split</h3>

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Image dataset stored as one sub-directory per class.

    For ``split="train"`` the images of a class are read from
    ``<root_dir>/<class>/images``; for ``split="val"`` they are read directly
    from ``<root_dir>/<class>``. The label of a class is the position of its
    folder in the sorted list of class folders, so the mapping does not
    depend on the order in which the file system lists directories.
    """

    def __init__(self, root_dir, split="train", transform=None):
        """Index all image paths and labels below ``root_dir``.

        Args:
            root_dir: Directory containing one sub-directory per class.
            split: Either ``"train"`` or ``"val"`` (selects the layout).
            transform: Optional callable applied to every loaded image.

        Raises:
            ValueError: If ``split`` is neither ``"train"`` nor ``"val"``.
        """
        if split not in ("train", "val"):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        # os.scandir yields entries in arbitrary order; sort so that the
        # class -> label mapping is reproducible across runs and roots.
        with os.scandir(self.root_dir) as entries:
            self.class_folders = sorted(f.path for f in entries if f.is_dir())

        self.image_paths = []
        self.labels = []

        for class_label, class_folder in enumerate(self.class_folders):
            if self.split == "train":
                image_folder_path = os.path.join(class_folder, "images")
            else:
                image_folder_path = class_folder

            # Sorted for the same reason as the class folders.
            image_filenames = sorted(os.listdir(image_folder_path))

            self.image_paths.extend(os.path.join(image_folder_path, f) for f in image_filenames)
            self.labels.extend([class_label] * len(image_filenames))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(image, label)``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(self.image_paths[idx]) as handle:
            image = handle.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


<h3>Funzioni per stampare il batch di immagini passando il dataloader</h3>

In [None]:
def imshow(img):
    """Show an image tensor with matplotlib.

    The tensor is converted to a NumPy array and its first axis is moved to
    the end before plotting.
    """
    channels_last = img.numpy().transpose((1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

def show_batch(dataloader):
    """Plot the first batch produced by ``dataloader`` as a single image grid."""
    images, _ = next(iter(dataloader))
    grid = make_grid(images)
    imshow(grid)

# Dataset

In [None]:
# Data-loading and model-input settings shared by the cells below.
batch_size = 64    # samples per batch in the train/validation loaders
image_size = 224   # images are resized to image_size x image_size
patch_size = 16    # side length, in pixels, of one square patch
nome_file = "Imagenette_pretrained"  # run name handed to train_validation

In [None]:
# Per-channel statistics used by Normalize (and to undo it when plotting).
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

# Training pipeline: resize, random horizontal flip, tensor conversion, normalisation.
transform_train = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Validation pipeline: same as training but without the random flip.
transform_validation = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])


# Prepare dataset
# The datasets are read from dataset_path; the download flag is left disabled.
trainset = torchvision.datasets.Imagenette(root=dataset_path, split='train', transform=transform_train) #download=True,
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

validationset = torchvision.datasets.Imagenette(root=dataset_path, split='val', transform=transform_validation) #download=True

# Validation batches are served in a fixed order (shuffle=False).
validation_loader = torch.utils.data.DataLoader(validationset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Class names as exposed by the dataset object.
classes = trainset.classes
print(classes)

## Functions 

In [None]:
def create_patch_list(total_patches, cut, depth=12, block_size=3):
    """Return the number of patches kept at each layer.

    The count starts at ``total_patches`` and is scaled by ``cut / 100``
    (rounded down) at the start of every ``block_size``-th layer after the
    first, so consecutive groups of ``block_size`` layers share one value.

    Args:
        total_patches: Number of patches at the first layer.
        cut: Percentage of patches kept at each reduction step.
        depth: Number of layers, i.e. length of the returned list.
        block_size: Number of consecutive layers sharing a patch count.

    Returns:
        List of ``depth`` integers, one patch count per layer.
    """
    remaining = total_patches
    patch_list = []
    for layer in range(depth):
        if layer % block_size == 0 and layer != 0:
            # Multiply before dividing: `remaining * (cut / 100)` can land
            # just below the exact result (100 * 0.29 == 28.999...) and
            # would then be truncated one patch too low.
            remaining = int(remaining * cut // 100)
        patch_list.append(remaining)
    return patch_list

In [None]:
def train_model(model, cut, epoche, learning_rate, image_size, patch_size):
    """Create a TRAM model from a pretrained timm ViT and train it.

    Relies on the module-level ``num_classes``, ``device``, ``train_loader``,
    ``validation_loader`` and ``nome_file``.

    Args:
        model: Model size, one of ``'Base'``, ``'Small'`` or ``'Tiny'``.
        cut: Percentage passed to ``create_patch_list``.
        epoche: Number of training epochs.
        learning_rate: Learning rate of the Adam optimizer.
        image_size: Side length of the input images.
        patch_size: Side length of one patch.

    Returns:
        The trained model, or ``None`` (after printing a message) when
        ``model`` is not a known size.
    """
    total_patches = int(image_size/patch_size)**2

    # Model size -> (timm checkpoint name, embedding dim, attention heads).
    variants = {
        'Base': ('vit_base_patch16_224.augreg_in21k_ft_in1k', 768, 12),
        'Small': ('vit_small_patch16_224.augreg_in21k_ft_in1k', 384, 6),
        'Tiny': ('vit_tiny_patch16_224.augreg2_in21k_ft_in1k', 192, 3),
    }
    if model not in variants:
        print('Model not found')
        return
    model_name, dim, heads = variants[model]

    patch = create_patch_list(total_patches, cut)

    # Pretrained timm ViT whose weights are transferred into the TRAM model.
    model_timm = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
    network = get_weights(patch, num_classes, model_timm, dim, heads, image_size, patch_size)

    # Move the model to the selected device before creating the optimizer.
    network.to(device)
    optimizer = optim.Adam(network.parameters(), lr=learning_rate)

    # The per-epoch statistics returned by train_validation are not used here.
    train_validation(network, optimizer, train_loader, validation_loader, nome_file, epoche, device=device)

    return network

In [None]:
def create_mask(patch_list, mask_size, patch_size):
    """Build a square 0/1 mask that is 1 on the listed patches.

    Patch indices are interpreted row-major on a grid with
    ``mask_size // patch_size`` patches per row.

    Args:
        patch_list: Iterable of patch indices to switch on.
        mask_size: Side length of the mask in pixels.
        patch_size: Side length of one patch in pixels.

    Returns:
        Float tensor of shape ``(mask_size, mask_size)`` moved to the
        module-level ``device``.
    """
    grid = torch.zeros((mask_size, mask_size), dtype=torch.float32)
    for patch_index in patch_list:
        per_row = mask_size // patch_size
        top = (patch_index // per_row) * patch_size
        left = (patch_index % per_row) * patch_size
        rows = slice(top, top + patch_size)
        cols = slice(left, left + patch_size)
        grid[rows, cols] = 1
    return grid.to(device)


def seleziona_indici_iterativamente(lista_originale, indici_selezione):
    """Apply successive index selections, each on the previous result.

    The first selection picks positions from ``lista_originale``; every
    following selection picks positions from the list produced by the
    selection before it.

    Args:
        lista_originale: Starting list of items.
        indici_selezione: Iterable of position lists, one per step.

    Returns:
        List with the selected items after each step.
    """
    current = lista_originale[:]
    history = []
    for positions in indici_selezione:
        current = [current[pos] for pos in positions]
        history.append(current)
    return history

## Train 50% and 75%

In [None]:
# Training settings read by train_model (num_classes is used there as a global).
epoche = 1              # number of training epochs
learning_rate = 0.0001  # learning rate handed to the Adam optimizer
num_classes = 10        # size of the classification head

In [None]:
# Train the 'Base' variant with cut = 75: create_patch_list scales the patch
# count by cut / 100 at every reduction step.
cut = 75
model = 'Base'
model_75 = train_model(model, cut, epoche, learning_rate, image_size, patch_size)

In [None]:
# Train the 'Base' variant with cut = 50: create_patch_list scales the patch
# count by cut / 100 at every reduction step.
cut = 50
model = 'Base'
model_50 = train_model(model, cut, epoche, learning_rate, image_size, patch_size)

## Visualization and image saving

In [None]:
# Re-seed PyTorch before building the shuffled loader below.
torch.manual_seed(42)

# we want to visualize one image at a time
test_loader = torch.utils.data.DataLoader(validationset, batch_size=1, shuffle=True, num_workers=1, pin_memory=False)

In [None]:
# Visualise, for the first 50 validation images, which patches model_75 and
# model_50 keep, and save the figures under Imgs/Single and Imgs/United.
mean_tensor = torch.tensor(list(mean))
std_tensor = torch.tensor(list(std))
plot = False

# savefig does not create missing folders, so make sure they exist up front.
os.makedirs("Imgs/Single", exist_ok=True)
os.makedirs("Imgs/United", exist_ok=True)

# Channel statistics reshaped for broadcasting and moved to the device once
# instead of once per mask.
std_on_device = std_tensor.unsqueeze(0).unsqueeze(2).unsqueeze(3).to(device)
mean_on_device = mean_tensor.unsqueeze(0).unsqueeze(2).unsqueeze(3).to(device)

# Titles of the panels after the first one.
# NOTE(review): the panels show selection steps 0, 3, 6 and 9 while the titles
# read 3, 7 and 10 - confirm the intended layer numbering.
layer_titles = {1: 'Layer 3', 2: 'Layer 7', 3: 'Layer 10'}

for i, (data, target) in enumerate(test_loader):

    if i == 50:
        break

    data = data.to(device)
    target = target.to(device)

    # Inference only: no autograd graph is needed for the forward passes.
    # NOTE(review): return_tokens=True presumably yields, per layer, the
    # positions of the tokens that are kept - verify against TRAM.
    with torch.no_grad():
        results_50, tokens_50 = model_50(data, return_tokens=True)
        results_75, tokens_75 = model_75(data, return_tokens=True)

    # Predictions of model_50
    results_50 = functional.softmax(results_50, dim=1)
    predicted_prob_50, predicted_idx_50 = torch.max(results_50, dim=1)
    predicted_label_50 = trainset.classes[predicted_idx_50[0]]
    real_label_50 = trainset.classes[target[0]]

    # Predictions of model_75
    results_75 = functional.softmax(results_75, dim=1)
    predicted_prob_75, predicted_idx_75 = torch.max(results_75, dim=1)
    predicted_label_75 = trainset.classes[predicted_idx_75[0]]
    real_label_75 = trainset.classes[target[0]]

    # All patch indices of the image (196 for 224x224 images, 16x16 patches).
    total = list(range((data.shape[2] // patch_size) ** 2))

    final_50 = [elem[0].tolist() for elem in tokens_50]
    final_75 = [elem[0].tolist() for elem in tokens_75]

    # Map each step's relative selection back to original patch indices.
    indici_50 = seleziona_indici_iterativamente(total, final_50)
    indici_75 = seleziona_indici_iterativamente(total, final_75)

    # Undo the normalisation once per image; every mask reuses the result.
    data_modified = data[0] * std_on_device + mean_on_device

    result_50 = []
    result_75 = []

    for indice in indici_50:
        mask = create_mask(indice, data.shape[2], patch_size)
        result_50.append(data_modified[0] * mask.unsqueeze(0).unsqueeze(1))

    for indice in indici_75:
        mask = create_mask(indice, data.shape[2], patch_size)
        result_75.append(data_modified[0] * mask.unsqueeze(0).unsqueeze(1))

    # Figure 1: the 75% model only.
    fig, axes = plt.subplots(1, 4, figsize=(10, 5))

    for j in range(0, 4):
        image_75 = result_75[j*3][0].cpu()

        axes[j].imshow(image_75.permute(1, 2, 0), cmap=plt.cm.binary)
        axes[j].set_xticks([])
        axes[j].set_yticks([])

        if j == 0:
            axes[j].set_title(f'{real_label_75[0]}')
        else:
            axes[j].set_title(layer_titles[j])

    plt.tight_layout()
    plt.savefig(f"Imgs/Single/{i}_75_{real_label_75[0]}.pdf", dpi=100, bbox_inches='tight')
    if plot:
        plt.show()
    else:
        plt.close()

    # Figure 2: the 50% model only (no titles).
    fig, axes = plt.subplots(1, 4, figsize=(10, 5))

    for j in range(0, 4):
        image_50 = result_50[j*3][0].cpu()

        axes[j].imshow(image_50.permute(1, 2, 0), cmap=plt.cm.binary)
        axes[j].set_xticks([])
        axes[j].set_yticks([])

    plt.tight_layout()
    plt.savefig(f"Imgs/Single/{i}_50_{real_label_50[0]}.pdf", dpi=100, bbox_inches='tight')
    if plot:
        plt.show()
    else:
        plt.close()

    # Figure 3: both models, 75% in the top row and 50% in the bottom row.
    fig, axes = plt.subplots(2, 4, figsize=(10, 5))

    for j in range(0, 4):

        image_75 = result_75[j*3][0].cpu()
        axes[0, j].imshow(image_75.permute(1, 2, 0), cmap=plt.cm.binary)

        image_50 = result_50[j*3][0].cpu()
        axes[1, j].imshow(image_50.permute(1, 2, 0), cmap=plt.cm.binary)

        for row in (0, 1):
            axes[row, j].set_xticks([])
            axes[row, j].set_yticks([])

        if j == 0:
            axes[0, j].set_title(f'{real_label_75[0]}')
        else:
            axes[0, j].set_title(layer_titles[j])

    plt.tight_layout()
    plt.savefig(f"Imgs/United/{i}_{real_label_50[0]}.png", dpi=100, bbox_inches='tight')
    if plot:
        plt.show()
    else:
        plt.close()

    print("#############################")