# Import e path

In [1]:
# Filesystem locations used by the notebook. Each value is used as a plain
# string prefix (f"{path}{name}"), so the trailing slash is required.
dataset_path = "../Datasets/"      # root folder handed to the torchvision dataset
models_save_path = "Models/"       # prefix of the saved model checkpoints
results_save_path = "Results/"     # prefix of the per-run CSV result files

In [2]:
import torchvision
import torch
import os
import time
import csv
import matplotlib.pyplot as plt
import numpy as np
import shutil
from random import randint
import pandas as pd
from PIL import Image
import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as functional
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchvision.datasets import ImageFolder
from einops import rearrange
from einops.layers.torch import Rearrange
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import torch.optim.lr_scheduler as lr_scheduler
from vit_pytorch.ats_vit import ViT as ATS

from vit_pytorch import ViT
from vit_pytorch.vit_with_patch_merger import ViT as PatchMerger
from vit_pytorch import SimpleViT as SV

<h3>Validi per ogni modello</h3>

In [3]:
def set_seed(seed):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices) and configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    # The hash seed of the running interpreter is fixed at start-up, so this
    # only affects Python processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner can select different kernels from run to run; keep
    # it off so that the deterministic flag above is actually effective even
    # if an earlier cell enabled benchmarking.
    torch.backends.cudnn.benchmark = False


set_seed(42)

In [4]:
# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# To force CPU execution regardless of GPU availability:
# device = torch.device('cpu')

# Funzioni per train e validation del modello SAMPLING


In [5]:
def train_iter(model, optimz, data_load, loss_val, device, scheduler):
    """Train ``model`` for one epoch and record the epoch's mean training loss.

    Args:
        model: Network to optimise; it is switched to training mode.
        optimz: Optimizer that updates ``model``'s parameters.
        data_load: Loader yielding ``(data, target)`` batches; it must expose
            a ``dataset`` attribute and support ``len()``.
        loss_val: List to which the mean per-sample loss of this epoch is
            appended (one value per call).
        device: Device every batch is moved to before the forward pass.
        scheduler: Learning-rate scheduler, stepped once after the epoch.

    Raises:
        ValueError: If ``data_load`` yields no batches. Nothing is appended
            to ``loss_val`` and the scheduler is not stepped in that case.
    """
    samples = len(data_load.dataset)
    n_batches = len(data_load)
    model.train()

    # Sample-weighted running sum of the batch losses. It is kept as a tensor
    # so that accumulating it does not force a host/device sync every batch.
    loss_sum = 0.0
    seen = 0

    for i, (data, target) in enumerate(data_load):
        data = data.to(device)
        target = target.to(device)

        optimz.zero_grad()
        out = functional.log_softmax(model(data), dim=1)
        loss = functional.nll_loss(out, target)
        loss.backward()
        optimz.step()

        batch_len = len(data)
        # nll_loss averages over the batch, so weight by the batch size to get
        # a true per-sample mean even when the last batch is smaller.
        loss_sum = loss_sum + loss.detach() * batch_len
        seen += batch_len

        if i % 100 == 0:
            print('[' + '{:5}'.format(i * batch_len) + '/' + '{:5}'.format(samples) +
                  ' (' + '{:3.0f}'.format(100 * i / n_batches) + '%)]  Loss: ' +
                  '{:6.4f}'.format(loss.item()))

    if seen == 0:
        raise ValueError('train_iter received a data loader that yields no samples')

    scheduler.step()
    print(scheduler.get_last_lr())
    loss_val.append((loss_sum / seen).item())

def evaluate(model, optimizer, data_load, loss_val, device):
    """Score ``model`` on ``data_load`` without updating its weights.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        optimizer: Not used by this function; kept so the call signature
            stays unchanged.
        data_load: Loader yielding ``(data, target)`` batches; it must expose
            a ``dataset`` attribute.
        loss_val: List to which the average per-sample loss is appended.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The accuracy in percent, as a tensor moved to the CPU.
    """
    model.eval()

    n_total = len(data_load.dataset)
    n_correct = 0   # number of correct predictions
    loss_sum = 0

    with torch.no_grad():
        for batch, labels in data_load:
            batch, labels = batch.to(device), labels.to(device)

            log_probs = functional.log_softmax(model(batch), dim=1)
            # Sum (not mean) per batch so the final division by the dataset
            # size gives the per-sample average.
            loss_sum += functional.nll_loss(log_probs, labels, reduction='sum').item()
            n_correct += (log_probs.argmax(dim=1) == labels).sum()

    mean_loss = loss_sum / n_total
    loss_val.append(mean_loss)
    accuracy = (100.0 * n_correct / n_total).cpu()

    print(f'\nAverage test loss: {mean_loss:.4f}'
          f'  Accuracy:{n_correct:5}/{n_total:5}'
          f' ({accuracy:4.2f}%)\n')

    return accuracy

def train_validation(model, optimizer, train_loader, validation_loader, models_save_path, nome_dataset, epoche,scheduler, device):
    """Train and validate ``model`` for ``epoche`` epochs, checkpointing the best one.

    After every epoch the model is evaluated on ``validation_loader``.
    Whenever the validation accuracy is at least as high as in every previous
    epoch, a checkpoint (model state, optimizer state, latest training loss,
    latest validation loss and accuracy) is written to
    ``f'{models_save_path}{nome_dataset}'``, overwriting the previous one.

    Args:
        model: Network to train.
        optimizer: Optimizer bound to ``model``'s parameters.
        train_loader: Loader used for the training pass.
        validation_loader: Loader used for the validation pass.
        models_save_path: String prefix (typically a folder ending in ``/``)
            of the checkpoint file.
        nome_dataset: Run name, appended to ``models_save_path``.
        epoche: Number of epochs to run.
        scheduler: Learning-rate scheduler forwarded to ``train_iter``.
        device: Device used for training and evaluation.

    Returns:
        Tuple ``(tr_loss, ts_loss, ts_acc, epoch_time_list)`` with one entry
        per epoch: training loss, validation loss, validation accuracy and
        wall-clock duration in seconds.
    """
    tr_loss, ts_loss, ts_acc, epoch_time_list = [], [], [], []

    # Create the checkpoint folder up front: otherwise torch.save would only
    # fail after a complete training epoch has already been spent.
    checkpoint_path = f'{models_save_path}{nome_dataset}'
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(1, epoche + 1):

        start_time = time.time()

        print(f'Epoch: {epoch}/{epoche}')
        print("INIZIO TRAINING")
        train_iter(model, optimizer, train_loader, tr_loss, device, scheduler=scheduler)
        print("INIZIO VALIDATION")
        acc = evaluate(model, optimizer, validation_loader, ts_loss, device)

        # ">=" means a tie with the best accuracy so far also refreshes the
        # checkpoint, i.e. the most recent best epoch is the one kept.
        if not ts_acc or acc >= max(ts_acc):
            checkpoint = {'model_state_dict': model.state_dict(),
                          'optimizer_state_dict': optimizer.state_dict(),
                          'train_loss_state_dict': tr_loss[-1],
                          'val_loss_state_dict': ts_loss[-1],
                          'val_acc_state_dict': acc
                          }
            print(f'Saving Best Accuracy Model')
            torch.save(checkpoint, checkpoint_path)
            print(f'End of Saving \n')

        ts_acc.append(acc)

        epoch_time = time.time() - start_time
        epoch_time_list.append(epoch_time)

        print('Execution time:', '{:5.2f}'.format(epoch_time), 'seconds')
        print("#"*40)

    return tr_loss, ts_loss, ts_acc, epoch_time_list

# Params

In [6]:
# --- Data-pipeline hyper-parameters ---
batch_size = 64
img_size = 160     # every image is resized to img_size x img_size
patch_size = 16    # patch side length handed to the ViT models

# Per-channel statistics passed to Normalize.
# NOTE(review): these look like the usual ImageNet statistics rather than
# FashionMNIST ones — confirm this is intended.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Replicate a single-channel tensor to three channels; any other tensor is
# passed through unchanged.
trans = transforms.Lambda(lambda img: img if img.size(0) != 1 else img.repeat(3, 1, 1))

# Steps shared by the training and validation pipelines.
_resize = transforms.Resize((img_size, img_size))
_to_normalized_tensor = [transforms.ToTensor(), trans, transforms.Normalize(mean, std)]

transform_train = transforms.Compose([
    _resize,
    # Augmentations that were tried here and are currently disabled:
    #   transforms.RandomRotation(10)                 # random rotation by 10 degrees
    #   transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    #   transforms.RandomPerspective(distortion_scale=0.5, p=0.5)
    #   transforms.RandomVerticalFlip(p=0.5)          # vertical flip with 50% probability
    #   transforms.RandomResizedCrop(img_size)
    #   transforms.RandomHorizontalFlip()
    *_to_normalized_tensor,
])

transform_validation = transforms.Compose([_resize, *_to_normalized_tensor])

# Options common to both loaders.
_loader_options = dict(batch_size=batch_size, num_workers=6, pin_memory=True)

# Training split (downloaded into dataset_path when missing), reshuffled every epoch.
trainset = torchvision.datasets.FashionMNIST(
    root=dataset_path, train=True, download=True, transform=transform_train
)
train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_options)

# Test split, used here for validation and read in a fixed order.
validationset = torchvision.datasets.FashionMNIST(
    root=dataset_path, train=False, download=True, transform=transform_validation
)
val_loader = torch.utils.data.DataLoader(validationset, shuffle=False, **_loader_options)

# Modello

In [7]:
from DWTViT_gini import DWTViT_gini
from DWTViT_pruning import DWTViT_pruning
from DWTViT_quantile import DWTViT_quantile

In [8]:
def ViT_DWT(nome_dataset, pruning_locations, dim, heads, wavelet, epoche, model_type = 'pruning', strategy='cA',
            learning_rate=0.0001, depth=12):
    """Build a DWT-ViT variant, train it, and save its per-epoch results as CSV.

    The function relies on notebook-level globals: ``img_size``,
    ``patch_size``, ``device``, ``train_loader``, ``val_loader``,
    ``models_save_path`` and ``results_save_path``.

    Args:
        nome_dataset: Run name; used as the checkpoint file name and as the
            stem of the results CSV.
        pruning_locations: Forwarded to the model constructor.
        dim: Embedding dimension; the MLP dimension is set to ``dim * 4``.
        heads: Number of attention heads.
        wavelet: Wavelet name forwarded to the model constructor.
        epoche: Number of training epochs.
        model_type: Which model class to build: ``'gini'``, ``'pruning'`` or
            ``'quantile'``.
        strategy: Forwarded to ``DWTViT_pruning`` only; must not be ``None``
            when ``model_type == 'pruning'``.
        learning_rate: Initial learning rate of the Adam optimizer.
        depth: Number of transformer layers passed to the model constructor.

    Raises:
        ValueError: If ``model_type`` is not recognised, or if
            ``model_type == 'pruning'`` and ``strategy`` is ``None``.
    """
    # Pick the model class for the requested variant.
    if model_type == 'gini':
        model_cls = DWTViT_gini
    elif model_type == 'pruning':
        if strategy is None:
            raise ValueError("Per DWTViT_pruning devi specificare il parametro 'strategy'")
        model_cls = DWTViT_pruning
    elif model_type == 'quantile':
        model_cls = DWTViT_quantile
    else:
        raise ValueError(f"model_type '{model_type}' non riconosciuto. Usa 'gini', 'pruning' o 'quantile'.")

    # Constructor arguments shared by all three variants.
    model_kwargs = dict(
        image_size=img_size,
        patch_size=patch_size,
        num_classes=10,
        dim=dim,
        depth=depth,
        heads=heads,
        mlp_dim=dim*4,
        dropout=0,
        emb_dropout=0,
        wavelet=wavelet,
        pruning_locations=pruning_locations,
    )
    if model_type == 'pruning':
        # Only the pruning variant takes a strategy argument.
        model_kwargs['strategy'] = strategy

    # Make sure the results folder exists before training starts, so that a
    # missing directory cannot discard the results of a completed run.
    results_file = f'{results_save_path}{nome_dataset}.csv'
    results_dir = os.path.dirname(results_file)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    model = model_cls(**model_kwargs)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Halve the learning rate every 10 epochs.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    initial = time.time()
    _, _, validation_acc, epoch_time = train_validation(
        model, optimizer, train_loader, val_loader, models_save_path, nome_dataset, epoche, device=device, scheduler=scheduler
    )
    print(f'Total Time: {time.time() - initial}')

    # train_validation returns the accuracies as tensors; store plain floats.
    df = pd.DataFrame({
        'validation_acc': [acc.item() for acc in validation_acc],
        'epoch_time': epoch_time
    })

    df.to_csv(results_file, index=False)


# RUN

In [9]:
# wavelet = 'haar'
# pruning_locations = [4, 6, 8]
# epoche = 10

# ViT_DWT(nome_dataset = f"FashionMNIST_gini_DWTSmall75%",pruning_locations = pruning_locations,dim=384,heads=6, wavelet=wavelet, epoche = epoche, model_type = 'gini')

# ViT_DWT(nome_dataset = f"FashionMNIST_gini_DWTBase75%",pruning_locations = pruning_locations,dim=768,heads=12,wavelet=wavelet, epoche = epoche, model_type = 'gini')

In [10]:
# for wavelet in ['haar', 'db2', 'db4', 'sym2']:

#     pruning_locations = [4, 6, 8]
#     epoche = 10
    
#     ViT_DWT(nome_dataset = f"FashionMNIST_{wavelet}_DWTSmall75%",pruning_locations = pruning_locations,dim=384,heads=6, wavelet=wavelet, epoche = epoche)
    
#     ViT_DWT(nome_dataset = f"FashionMNIST_{wavelet}_DWTBase75%",pruning_locations = pruning_locations,dim=768,heads=12,wavelet=wavelet, epoche = epoche)

In [None]:
# wavelet = 'haar'

# pruning_locations = [4, 6, 8]
# epoche = 10

# for strategy in ['cA', 'cD']:
#     ViT_DWT(nome_dataset = f"FashionMNIST_{strategy}_DWTSmall75%",pruning_locations = pruning_locations,dim=384,heads=6, wavelet=wavelet, epoche = epoche, strategy= strategy, model_type = 'pruning')
    
#     ViT_DWT(nome_dataset = f"FashionMNIST_{strategy}_DWTBase75%",pruning_locations = pruning_locations,dim=768,heads=12,wavelet=wavelet, epoche = epoche, strategy=strategy, model_type = 'pruning')