In [1]:
# ! pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# ! pip3 install einops
# ! pip3 install vit_pytorch
# ! pip3 install pandas
# ! pip3 install scikit-learn
# ! pip3 install albumentations
# ! pip3 install matplotlib

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from einops.layers.torch import Rearrange
from vit_pytorch.vit import Transformer

import os
import pandas as pd

from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torch.utils.data import Dataset, DataLoader


In [3]:
#read all files from the folder CUB_200_2011 and assign the subfolder as a class
#the subfolder name is the class name

path = './CUB_200_2011/images/'
classes = os.listdir(path)
classes.sort()
print(classes)
#read all files from the subfolders

data = []
for i in range(len(classes)):
    folder = os.path.join(path,classes[i])
    files = os.listdir(folder)
    for j in range(len(files)):
        data.append([classes[i],os.path.join(folder,files[j])])

#convert the list to a dataframe
df = pd.DataFrame(data,columns=['class','path'])
df.head()

N_CLASSES = len(classes)

['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet', '009.Brewer_Blackbird', '010.Red_winged_Blackbird', '011.Rusty_Blackbird', '012.Yellow_headed_Blackbird', '013.Bobolink', '014.Indigo_Bunting', '015.Lazuli_Bunting', '016.Painted_Bunting', '017.Cardinal', '018.Spotted_Catbird', '019.Gray_Catbird', '020.Yellow_breasted_Chat', '021.Eastern_Towhee', '022.Chuck_will_Widow', '023.Brandt_Cormorant', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant', '026.Bronzed_Cowbird', '027.Shiny_Cowbird', '028.Brown_Creeper', '029.American_Crow', '030.Fish_Crow', '031.Black_billed_Cuckoo', '032.Mangrove_Cuckoo', '033.Yellow_billed_Cuckoo', '034.Gray_crowned_Rosy_Finch', '035.Purple_Finch', '036.Northern_Flicker', '037.Acadian_Flycatcher', '038.Great_Crested_Flycatcher', '039.Least_Flycatcher', '040.Olive_sided_Flycatcher', '041.Scissor_tailed_Flycatcher', '042.Ver

In [4]:
#split the data into train and test and validation
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)

In [5]:
#transformations for train and validation
transform = A.Compose([
    A.Resize(256,256),
    ToTensorV2()])

In [6]:
import numpy as np
import matplotlib.pyplot as plt
#create the Dataset class
class BirdDataset(Dataset):
    def __init__(self,df):
        self.df = df
        self.images = self.df['path'].values
        self.classes = self.df['class'].values
        self.classes = np.array([classes.index(i) for i in self.classes])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self,idx):
        image = self.images[idx]
        image = plt.imread(image)
        #image = np.transpose(image,(2,0,1))
        #if is grayscale, convert to rgb
        if image.shape[0] == 1:
            image = np.repeat(image,3,0)
        #cast to float and normalize
        image = image.astype(np.float32)
        image = image/255.0 
        image =  transform(image=image)['image']
        class_ = self.classes[idx]
        class_ = torch.tensor(class_,dtype=torch.long)
        image = image.type(torch.FloatTensor)
        if image.shape[0] == 1:
            image = torch.repeat_interleave(image,3,0)
        return image,class_

In [7]:
#create the dataloaders
BATCH_SIZE = 16
train_loader = DataLoader(BirdDataset(train_df),batch_size=BATCH_SIZE,shuffle=True)
val_loader = DataLoader(BirdDataset(val_df),batch_size=BATCH_SIZE,shuffle=False)
test_loader = DataLoader(BirdDataset(test_df),batch_size=BATCH_SIZE,shuffle=False)


In [8]:
train_loader.dataset[0][0].shape

torch.Size([3, 256, 256])

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper


teacher = resnet50(pretrained = True)
teacher = teacher.to(device)

student_vit = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000, # TODO:This doesn't seem working well
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
student_vit = student_vit.to(device)

distiller = DistillWrapper(
    student = student_vit,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5,               # trade between main loss and distillation loss
    hard = False               # whether to use soft or hard distillation
)
distiller = distiller.to(device)

# img = torch.randn(124, 3, 256, 256)
# img = img.to(device)
# labels = torch.randint(0, 16, (124,))
# labels = labels.to(device)

# loss = distiller(img, labels)
# loss.backward()



In [11]:
#train loop
def train(model,train_loader,val_loader,epochs=10,lr=1e-2):
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    train_loss = []
    val_loss = []
    for epoch in range(epochs):
        model.train()
        train_loss_ = []
        val_loss_ = []
        for i,(images,classes) in enumerate(train_loader):
            images = images.to(device)
            classes = classes.to(device)
            optimizer.zero_grad()
            
            # print("images shape: ", images.shape)
            # print("classes shape: ", classes.shape)
            loss = model(images, classes)
            loss.backward()
            optimizer.step()
            train_loss_.append(loss.item())
            # #every 100 batches, print the loss
            # if i%80 == 0:
            #     #transfer weights from model to maebellino model
            #     mae_bellino.load_state_dict(model.state_dict())
            #     #get the output image
            #     output = mae_bellino(images) 
            #     o = output[0].cpu().detach().numpy().transpose(1,2,0)
            #     #apply relu
            #     o = np.maximum(o,0)
            #     o = np.minimum(o,1)
            #     plt.imsave(f'outputs/epoch_{epoch+1}_batch_{i}.png',o)
        train_loss.append(np.mean(train_loss_))
        
        # model.eval()
        # with torch.no_grad():
        #     for i,(images,classes) in enumerate(val_loader):
                
        #         images = images.to(device)
        #         classes = classes.to(device)
        #         loss = model(images, classes)
        #         val_loss_.append(loss.item())
                
        #     val_loss.append(np.mean(val_loss_))
        # print(f'Epoch: {epoch+1}, Train Loss: {train_loss[-1]}, Val Loss: {val_loss[-1]}')
        print(f'Epoch: {epoch+1}, Train Loss: {train_loss[-1]}')
        #save the last output image on the disk
        #save the model
        torch.save(model.state_dict(),f'jacoExperiments/distil.pth')
        
    return train_loss,val_loss

In [12]:
train_loss,val_loss = train(distiller, train_loader,val_loader,epochs=500,lr=1e-4)


# pred = student_vit(img) # (2, 1000)

Epoch: 1, Train Loss: 3.2733336924496346


RuntimeError: Parent directory jacoExperiments does not exist.

In [None]:
# The DistillableViT class is identical to ViT except for how the forward pass is handled, 
# so you should be able to load the parameters back to ViT after you have completed distillation training.

# TODO: It might be uselful if we want to use a custom vit
student_vit = student_vit.to_vit()
type(student_vit) # <class 'vit_pytorch.vit_pytorch.ViT'>