In [8]:
# ! pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# ! pip3 install einops
# ! pip3 install vit_pytorch
# ! pip3 install pandas
# ! pip3 install scikit-learn
# ! pip3 install albumentations
# ! pip3 install matplotlib
# ! pip3 install transformers

In [9]:
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from einops.layers.torch import Rearrange
from vit_pytorch.vit import Transformer

import os
import pandas as pd

from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torch.utils.data import Dataset, DataLoader

import numpy as np

import transformers


In [10]:
import datasets_utils
train_loader, val_loader, test_loader, classes , img_size = datasets_utils.get_Chaoyang_loaders(BATCH_SIZE=64, SEED=42)
# train_loader, val_loader, test_loader, classes , img_size= datasets_utils.get_CUB_loaders(BATCH_SIZE=256, SEED=42, SPLITS=[0.50,0.25,0.25])
# train_loader, val_loader, test_loader, classes , img_size = datasets_utils.get_vegetables_dataloader(BATCH_SIZE=32, SEED=42, SPLITS=[0.50,0.25,0.25])
N_CLASSES = len(classes)

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper


teacher = resnet50(pretrained = True)
# replace the head of the teacher with a new fc with N_CLASSES as neurons
teacher.fc = nn.Linear(teacher.fc.in_features, N_CLASSES)
teacher = teacher.to(device)

student_vit = DistillableViT(
    image_size = img_size,
    patch_size = 32,
    num_classes = N_CLASSES,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
student_vit = student_vit.to(device)

distiller = DistillWrapper(
    student = student_vit,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5,               # trade between main loss and distillation loss
    hard = False               # whether to use soft or hard distillation
)
distiller = distiller.to(device)

# img = torch.randn(124, 3, 256, 256)
# img = img.to(device)
# labels = torch.randint(0, 16, (124,))
# labels = labels.to(device)

# loss = distiller(img, labels)
# loss.backward()



In [13]:
#train loop
def train(model,train_loader,val_loader,epochs=10,lr=1e-2):
    current_valid_loss , best_valid_loss = float('inf'), float('inf')

    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    train_loss = []
    val_loss = []
    for epoch in range(epochs):
        model.train()
        train_loss_ = []
        val_loss_ = []
        for i,(images,classes) in enumerate(train_loader):

            if type(images)==transformers.image_processing_utils.BatchFeature:
                images = images['pixel_values']
                images=images.to(torch.float32)
                s = images.shape
                images=images.view(s[0],s[-3],s[-2],s[-1])
                images = F.interpolate(images, size=(256, 256), mode='bilinear', align_corners=False)

            images = images.to(device)
            classes = classes.to(device)
            optimizer.zero_grad()
            loss = model(images, classes)
            loss.backward()
            optimizer.step()
            train_loss_.append(loss.item())
            # #every 100 batches, print the loss
        train_loss.append(np.mean(train_loss_))
        
        model.eval()
        with torch.no_grad():
            for i,(images,classes) in enumerate(val_loader):
                if type(images)==transformers.image_processing_utils.BatchFeature:
                    images = images['pixel_values']
                    images=images.to(torch.float32)
                    s = images.shape
                    images=images.view(s[0],s[-3],s[-2],s[-1])
                    images = F.interpolate(images, size=(256, 256), mode='bilinear', align_corners=False)
                    
                images = images.to(device)
                classes = classes.to(device)
                loss = model(images, classes)

                current_valid_loss = loss.item()
                val_loss_.append(loss.item())
               
            mean_loss = np.mean(val_loss_)
            current_valid_loss = mean_loss
            if current_valid_loss < best_valid_loss:
                best_valid_loss = current_valid_loss
                print(f"\nBest validation loss: {best_valid_loss}")
                print(f"\nSaving best model for epoch: {epoch+1}\n")
                torch.save(optimizer.state_dict(), 'jacoExperiments/best_distilled_model.pth')
            val_loss.append(mean_loss)
        print(f'Epoch: {epoch+1}, Train Loss: {train_loss[-1]}, Val Loss: {val_loss[-1]}')

        # print(f'Epoch: {epoch+1}, Train Loss: {train_loss[-1]}')
        #save the last output image on the disk
        #save the model
        torch.save(model.state_dict(),f'jacoExperiments/last_distilled.pth')
        
    return train_loss,val_loss

In [14]:
train_loss,val_loss = train(distiller, train_loader,val_loader,epochs=500,lr=1e-4)
# images shape:  <class 'transformers.image_processing_utils.BatchFeature'>
# classes shape:  torch.Size([64])


Best validation loss: 0.6310295131471422

Saving best model for epoch: 1

Epoch: 1, Train Loss: 0.775751488118232, Val Loss: 0.6310295131471422

Best validation loss: 0.5922340154647827

Saving best model for epoch: 2

Epoch: 2, Train Loss: 0.6369126152388657, Val Loss: 0.5922340154647827

Best validation loss: 0.5671591560045878

Saving best model for epoch: 3

Epoch: 3, Train Loss: 0.6165906198417084, Val Loss: 0.5671591560045878

Best validation loss: 0.5616585844092898

Saving best model for epoch: 4

Epoch: 4, Train Loss: 0.5905899850628044, Val Loss: 0.5616585844092898

Best validation loss: 0.5013054410616556

Saving best model for epoch: 5

Epoch: 5, Train Loss: 0.5696962113621868, Val Loss: 0.5013054410616556

Best validation loss: 0.4945526553524865

Saving best model for epoch: 6

Epoch: 6, Train Loss: 0.5593898194500163, Val Loss: 0.4945526553524865

Best validation loss: 0.49395152264171177

Saving best model for epoch: 7

Epoch: 7, Train Loss: 0.5336030306695383, Val Los

In [15]:
# pred = student_vit(test_loader) # (2, 1000)
student_vit.eval()
with torch.no_grad():
    val_loss_=[]
    for i,(images,classes) in enumerate(val_loader):
        images = images.to(device)
        classes = classes.to(device)
        pred = student_vit(images)
        
        # Get the maximum predicted class
        pred = pred.argmax(dim=1, keepdim=True)
        print(pred)
        
    val_loss.append(np.mean(val_loss_))
print(f'Test Loss: {val_loss[-1]}')
        


AttributeError: 

In [None]:
# The DistillableViT class is identical to ViT except for how the forward pass is handled, 
# so you should be able to load the parameters back to ViT after you have completed distillation training.

# TODO: It might be uselful if we want to use a custom vit
student_vit = student_vit.to_vit()
type(student_vit) # <class 'vit_pytorch.vit_pytorch.ViT'>

In [None]:
# save student vit
torch.save(student_vit.state_dict(),f'jacoExperiments/distilled_student_vit.pth')