In [1]:
# ! pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# ! pip3 install einops
# ! pip3 install vit_pytorch
# ! pip3 install pandas
# ! pip3 install scikit-learn
# ! pip3 install albumentations
# ! pip3 install matplotlib

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from einops.layers.torch import Rearrange
from vit_pytorch.vit import Transformer

import os
import pandas as pd

from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torch.utils.data import Dataset, DataLoader

import numpy as np


In [3]:
import datasets_utils

# Build train/val/test loaders for the Chaoyang dataset (batch size 16, seed 42,
# SPLITS presumably train/val/test fractions). The loader helper also returns the
# class list and the image size used to configure the ViT below.
# Alternative dataset (note: this helper returns no img_size):
# train_loader, val_loader, test_loader, classes = datasets_utils.get_CUB_loaders(BATCH_SIZE=16, SEED=42, SPLITS=[0.8,0.1,0.1])
(train_loader, val_loader, test_loader,
 classes, img_size) = datasets_utils.get_Chaoyang_loaders(
    BATCH_SIZE=16,
    SEED=42,
    SPLITS=[0.8, 0.1, 0.1],
)
N_CLASSES = len(classes)

In [4]:
# Re-batched, shuffled view of the training set. It must wrap the underlying
# Dataset: wrapping `train_loader` (already a DataLoader) fails on iteration,
# because a DataLoader is not indexable.
# NOTE(review): not used by the cells below — training iterates `train_loader`.
train_dataloader = DataLoader(train_loader.dataset, batch_size=16, shuffle=True)

In [5]:
# Sanity check: shape of the image tensor of the first training sample.
first_train_image = train_loader.dataset[0][0]
first_train_image.shape

torch.Size([3, 224, 224])

In [6]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [7]:
from torchvision.models import resnet50, ResNet50_Weights

from vit_pytorch.distill import DistillableViT, DistillWrapper


# Teacher: ImageNet-pretrained ResNet-50. `ResNet50_Weights.IMAGENET1K_V1` is the
# checkpoint the deprecated `pretrained=True` flag used to load.
teacher = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# Replace the ImageNet head with a fresh fc layer that has N_CLASSES outputs.
# NOTE(review): this new head is randomly initialised and is not trained anywhere
# in this notebook, so the teacher's soft targets carry little class information.
# Fine-tuning the teacher on train_loader before distilling is probably needed;
# this may be related to the `num_classes` TODO below.
teacher.fc = nn.Linear(teacher.fc.in_features, N_CLASSES)
teacher = teacher.to(device)

# Student: ViT variant that can consume an extra distillation token.
student_vit = DistillableViT(
    image_size = img_size,
    patch_size = 32,
    num_classes = N_CLASSES, # TODO: This doesn't seem to work well (see teacher-head note above)
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
student_vit = student_vit.to(device)

# Wraps teacher + student; calling distiller(images, labels) returns the training loss.
distiller = DistillWrapper(
    student = student_vit,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5,               # trade between main loss and distillation loss
    hard = False               # whether to use soft or hard distillation
)
distiller = distiller.to(device)


In [8]:
#train loop
def train(model,train_loader,val_loader,epochs=10,lr=1e-2):
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    train_loss = []
    val_loss = []
    for epoch in range(epochs):
        model.train()
        train_loss_ = []
        val_loss_ = []
        for i,(images,classes) in enumerate(train_loader):
            images = images.to(device)
            classes = classes.to(device)
            optimizer.zero_grad()
            
            # print("images shape: ", images.shape)
            # print("classes shape: ", classes.shape)
            loss = model(images, classes)
            loss.backward()
            optimizer.step()
            train_loss_.append(loss.item())
            # #every 100 batches, print the loss
            # if i%80 == 0:
            #     #transfer weights from model to maebellino model
            #     mae_bellino.load_state_dict(model.state_dict())
            #     #get the output image
            #     output = mae_bellino(images) 
            #     o = output[0].cpu().detach().numpy().transpose(1,2,0)
            #     #apply relu
            #     o = np.maximum(o,0)
            #     o = np.minimum(o,1)
            #     plt.imsave(f'outputs/epoch_{epoch+1}_batch_{i}.png',o)
        train_loss.append(np.mean(train_loss_))
        
        # model.eval()
        # with torch.no_grad():
        #     for i,(images,classes) in enumerate(val_loader):
                
        #         images = images.to(device)
        #         classes = classes.to(device)
        #         loss = model(images, classes)
        #         val_loss_.append(loss.item())
                
        #     val_loss.append(np.mean(val_loss_))
        # print(f'Epoch: {epoch+1}, Train Loss: {train_loss[-1]}, Val Loss: {val_loss[-1]}')
        print(f'Epoch: {epoch+1}, Train Loss: {train_loss[-1]}')
        #save the last output image on the disk
        #save the model
        torch.save(model.state_dict(),f'jacoExperiments/distil.pth')
        
    return train_loss,val_loss

In [9]:
# Distillation run: 3 epochs at a conservative Adam learning rate.
train_loss, val_loss = train(
    distiller,
    train_loader,
    val_loader,
    epochs=3,
    lr=1e-4,
)

Epoch: 1, Train Loss: 0.7104098368369812


KeyboardInterrupt: 

In [None]:
# Evaluate the distilled student on the validation split: mean cross-entropy and
# top-1 accuracy. Called without a distillation token, the student returns plain
# class logits. Loop variables are named `labels` so the global `classes`
# (class names) is not overwritten, and the training `val_loss` history is not mutated.
student_vit.eval()
student_val_batch_losses = []
n_correct = 0
n_seen = 0
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = student_vit(images)
        student_val_batch_losses.append(F.cross_entropy(logits, labels).item())
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

# Guard against an empty loader (np.mean([]) would warn and return nan anyway).
student_val_loss = float(np.mean(student_val_batch_losses)) if student_val_batch_losses else float('nan')
student_val_acc = n_correct / n_seen if n_seen else float('nan')
print(f'Val Loss: {student_val_loss:.4f}, Val Accuracy: {student_val_acc:.4f}')
        


In [None]:
# The DistillableViT class is identical to ViT except for how the forward pass is handled,
# so its parameters can be loaded back into a plain ViT once distillation training is done.
# The isinstance guard makes the cell safe to re-run: after conversion `student_vit`
# is a plain ViT (not a DistillableViT), so the conversion is skipped.
# TODO: It might be useful if we want to use a custom ViT.
if isinstance(student_vit, DistillableViT):
    student_vit = student_vit.to_vit()
type(student_vit)  # expected: vit_pytorch's plain ViT class (exact module path depends on version)