# ESEGUIMI SNARCINO

In [1]:
# ! pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# ! pip3 install einops
# ! pip3 install vit_pytorch
# ! pip3 install pandas
# ! pip3 install scikit-learn
# ! pip3 install albumentations
# ! pip3 install matplotlib
# ! pip3 install transformers

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from einops.layers.torch import Rearrange
from vit_pytorch.vit import Transformer

import os
import pandas as pd

from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torch.utils.data import Dataset, DataLoader

import numpy as np

import transformers


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import datasets_utils
train_loader, val_loader, test_loader, classes , img_size = datasets_utils.get_Chaoyang_loaders(BATCH_SIZE=64, SEED=42)
# train_loader, val_loader, test_loader, classes , img_size= datasets_utils.get_CUB_loaders(BATCH_SIZE=256, SEED=42, SPLITS=[0.50,0.25,0.25])
# train_loader, val_loader, test_loader, classes , img_size = datasets_utils.get_vegetables_dataloader(BATCH_SIZE=32, SEED=42, SPLITS=[0.50,0.25,0.25])
N_CLASSES = len(classes)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# loadedmodel = torch.load('jacoExperiments/best_distilled_model.pth')
# distiller.load_state_dict(loadedmodel['model_state_dict'])


In [7]:
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)
teacher.fc = nn.Linear(teacher.fc.in_features, N_CLASSES)
resnet50_loaded = torch.load('jacoExperiments/resnet50_finetuned_sulle_cellule.pth')
teacher.load_state_dict(resnet50_loaded)
teacher = teacher.to(device)

student_vit = DistillableViT(
    image_size = img_size,
    patch_size = 32,
    num_classes = N_CLASSES,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
student_vit = student_vit.to(device)

distiller = DistillWrapper(
    student = student_vit,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5,               # trade between main loss and distillation loss
    hard = False               # whether to use soft or hard distillation
)
distiller = distiller.to(device)



# SNARCI SALTA AL PROSSIMO MARKDOWN

In [None]:
#train loop
def train(model,train_loader,val_loader,epochs=10,lr=1e-2):
    current_valid_loss , best_valid_loss = float('inf'), float('inf')

    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    train_loss = []
    val_loss = []
    for epoch in range(epochs):
        model.train()
        train_loss_ = []
        val_loss_ = []
        for i,(images,classes) in enumerate(train_loader):
            images = datasets_utils.convert_images_dict_to_tensor(images, resize=256)

            images = images.to(device)
            classes = classes.to(device)
            optimizer.zero_grad()
            loss = model(images, classes)
            loss.backward()
            optimizer.step()
            train_loss_.append(loss.item())
            # #every 100 batches, print the loss
        train_loss.append(np.mean(train_loss_))
        
        model.eval()
        with torch.no_grad():
            for i,(images,classes) in enumerate(val_loader):
                images = datasets_utils.convert_images_dict_to_tensor(images, resize=256)

                    
                images = images.to(device)
                classes = classes.to(device)
                loss = model(images, classes)

                current_valid_loss = loss.item()
                val_loss_.append(loss.item())
               
            mean_loss = np.mean(val_loss_)
            current_valid_loss = mean_loss
            if current_valid_loss < best_valid_loss:
                best_valid_loss = current_valid_loss
                print(f"\nBest validation loss: {best_valid_loss}")
                print(f"\nSaving best model for epoch: {epoch+1}\n")
                torch.save(model.state_dict(), 'jacoExperiments/BESt_distilled_model_with_resnet50_already_finetuned.pth')
            val_loss.append(mean_loss)
        print(f'Epoch: {epoch+1}, Train Loss: {train_loss[-1]}, Val Loss: {val_loss[-1]}')

        #save the model
        torch.save(model.state_dict(),f'jacoExperiments/last_distilled_model_with_resnet50_already_finetuned.pth')
        
    return train_loss,val_loss

In [None]:
train_loss,val_loss = train(distiller, train_loader,val_loader,epochs=500,lr=1e-4)
# images shape:  <class 'transformers.image_processing_utils.BatchFeature'>
# classes shape:  torch.Size([64])

# ESEGUI QUESTO SNARCINO

In [8]:
loadedmodel = torch.load('jacoExperiments/last_distilled_model_with_resnet50_already_finetuned.pth')
# print(loadedmodel)

#TODO PER SNARCINOOOOO: SE NON L'HAI ANCORA FATTO, DEVI PRIMA CREARE IL DISTILLER, CIOE' QUESTO
##########################################################################################################
# from torchvision.models import resnet50

# from vit_pytorch.distill import DistillableViT, DistillWrapper

# teacher = resnet50(pretrained = True)
# teacher.fc = nn.Linear(teacher.fc.in_features, N_CLASSES)
# resnet50_loaded = torch.load('jacoExperiments/resnet50_finetuned_sulle_cellule.pth')
# teacher.load_state_dict(resnet50_loaded)
# teacher = teacher.to(device)

# student_vit = DistillableViT(
#     image_size = img_size,
#     patch_size = 32,
#     num_classes = N_CLASSES,
#     dim = 1024,
#     depth = 6,
#     heads = 8,
#     mlp_dim = 2048,
#     dropout = 0.1,
#     emb_dropout = 0.1
# )
# student_vit = student_vit.to(device)

# distiller = DistillWrapper(
#     student = student_vit,
#     teacher = teacher,
#     temperature = 3,           # temperature of distillation
#     alpha = 0.5,               # trade between main loss and distillation loss
#     hard = False               # whether to use soft or hard distillation
# )
# distiller = distiller.to(device)
##########################################################################################################

distiller.load_state_dict(loadedmodel)
student_vit = distiller.student
student_vit.to(device)

DistillableViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=3072, out_features=1024, bias=True)
    (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (dropout): Dropout(p=0.1, inplace=False)
            (to_qkv): Linear(in_features=1024, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=1024, bias=True)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((1024,), eps=1e-

In [9]:
def test(model,loader):
    model.eval()
    with torch.no_grad():
        y_predicted=[]
        y_real=[]
        for i,(images,classes) in enumerate(loader):
            images = datasets_utils.convert_images_dict_to_tensor(images, resize=256)
            images = images.to(device)
            classes = classes.to(device)
            y_real.extend(classes.cpu().numpy().tolist())

            pred = model(images)
            # from pred of shape 64,4 get the index of the max value for each row
            pred = torch.argmax(pred, dim=1)
            y_predicted.extend(pred.cpu().numpy().tolist())


    acc = sum(1 for x,y in zip(y_real,y_predicted) if x == y) / float(len(y_real))
    print(acc)
test(student_vit,test_loader)          


0.7363253856942497


In [None]:
# The DistillableViT class is identical to ViT except for how the forward pass is handled, 
# so you should be able to load the parameters back to ViT after you have completed distillation training.

# TODO: It might be uselful if we want to use a custom vit
student_vit = student_vit.to_vit()
type(student_vit) # <class 'vit_pytorch.vit_pytorch.ViT'>