In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#read all images from CUB_200_2011/images: every sub-folder is one class and its
#folder name (e.g. '001.Black_footed_Albatross') is used as the class label

path = './CUB_200_2011/images/'
# keep directories only (stray files such as .DS_Store would break os.listdir below)
# and sort them so class indices are stable across machines
classes = sorted(
    name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))
)
print(classes)

# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting the file
# names makes the DataFrame row order, and therefore the seeded train/val/test split
# in the next cell, reproducible
data = [
    [class_name, os.path.join(path, class_name, file_name)]
    for class_name in classes
    for file_name in sorted(os.listdir(os.path.join(path, class_name)))
]

#convert the list to a dataframe with one row per image
df = pd.DataFrame(data, columns=['class', 'path'])
df.head()



['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet', '009.Brewer_Blackbird', '010.Red_winged_Blackbird', '011.Rusty_Blackbird', '012.Yellow_headed_Blackbird', '013.Bobolink', '014.Indigo_Bunting', '015.Lazuli_Bunting', '016.Painted_Bunting', '017.Cardinal', '018.Spotted_Catbird', '019.Gray_Catbird', '020.Yellow_breasted_Chat', '021.Eastern_Towhee', '022.Chuck_will_Widow', '023.Brandt_Cormorant', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant', '026.Bronzed_Cowbird', '027.Shiny_Cowbird', '028.Brown_Creeper', '029.American_Crow', '030.Fish_Crow', '031.Black_billed_Cuckoo', '032.Mangrove_Cuckoo', '033.Yellow_billed_Cuckoo', '034.Gray_crowned_Rosy_Finch', '035.Purple_Finch', '036.Northern_Flicker', '037.Acadian_Flycatcher', '038.Great_Crested_Flycatcher', '039.Least_Flycatcher', '040.Olive_sided_Flycatcher', '041.Scissor_tailed_Flycatcher', '042.Ver

Unnamed: 0,class,path
0,001.Black_footed_Albatross,./CUB_200_2011/images/001.Black_footed_Albatro...
1,001.Black_footed_Albatross,./CUB_200_2011/images/001.Black_footed_Albatro...
2,001.Black_footed_Albatross,./CUB_200_2011/images/001.Black_footed_Albatro...
3,001.Black_footed_Albatross,./CUB_200_2011/images/001.Black_footed_Albatro...
4,001.Black_footed_Albatross,./CUB_200_2011/images/001.Black_footed_Albatro...


In [3]:
#hold out 20% of all images as the test set, then 20% of the remainder as the
#validation set -> roughly 64% train / 16% val / 20% test overall.
#The fixed random_state keeps the split reproducible between runs.
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df, random_state=42, test_size=0.2)
train_df, val_df = train_test_split(train_df, random_state=42, test_size=0.2)


In [4]:
#import albumentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

#preprocessing shared by the train, validation and test splits:
#resize to the 256x256 ViT input, then convert the HWC array to a CHW torch tensor
transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        ToTensorV2(),
    ]
)

In [5]:
#create the Dataset class
class BirdDataset(Dataset):
    """Map-style dataset over CUB-200-2011 image files.

    Each item is ``(image, label)`` where ``image`` is the float tensor produced by the
    albumentations pipeline (channels first after ``ToTensorV2``, always 3 channels) and
    ``label`` is a ``torch.long`` scalar: the index of the image's class in ``class_names``.

    Args:
        df: DataFrame with a ``'path'`` column (image file paths) and a ``'class'``
            column (class folder names).
        class_names: ordered class names that define the label indices. Defaults to the
            notebook-level ``classes`` list.
        transform: albumentations pipeline applied to each HxWx3 float image. Defaults
            to the notebook-level ``transform``.
    """

    def __init__(self, df, class_names=None, transform=None):
        self.df = df
        self.images = self.df['path'].values
        if class_names is None:
            class_names = classes  # notebook-global list of sorted class folder names
        # dict lookup: O(1) per row instead of list.index's O(#classes) scan
        class_to_idx = {name: idx for idx, name in enumerate(class_names)}
        self.classes = np.array([class_to_idx[name] for name in self.df['class'].values])
        self.transform = transform

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _prepare_array(image):
        """Return ``image`` as an HxWx3 float32 array with values in [0, 1].

        ``plt.imread`` yields (H, W) for grayscale files, (H, W, 4) for modes it converts
        to RGBA, integer pixels for JPEG and float pixels in [0, 1] for PNG.
        """
        if image.ndim == 2:
            # grayscale: add a trailing channel axis so it can be repeated below
            image = image[..., np.newaxis]
        if image.shape[-1] == 1:
            image = np.repeat(image, 3, axis=-1)
        elif image.shape[-1] == 4:
            # drop alpha so every sample has exactly 3 channels for the ViT
            image = image[..., :3]
        if np.issubdtype(image.dtype, np.integer):
            # integer pixels are assumed to be 8-bit: rescale [0, 255] -> [0, 1]
            image = image.astype(np.float32) / 255.0
        else:
            # float pixels are already in [0, 1]; dividing again would darken them
            image = image.astype(np.float32)
        return image

    def __getitem__(self, idx):
        image = self._prepare_array(plt.imread(self.images[idx]))
        pipeline = self.transform if self.transform is not None else transform
        image = pipeline(image=image)['image'].float()
        label = torch.tensor(self.classes[idx], dtype=torch.long)
        return image, label

In [6]:
#create one DataLoader per split; only the training split is shuffled
batch_size = 16
train_loader, val_loader, test_loader = (
    DataLoader(BirdDataset(split_df), batch_size=batch_size, shuffle=shuffle)
    for split_df, shuffle in ((train_df, True), (val_df, False), (test_df, False))
)


In [7]:
from vit_pytorch import ViT, MAE
from mae_bellino import MAE as MAEBELLINO
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ViT encoder, randomly initialised (no pretrained weights are loaded in this cell)
v = ViT(
    image_size = 256,   # must match the A.Resize(256, 256) used by the transform
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
).to(device)

# Decoder settings shared by both MAE wrappers. They must stay identical because train()
# copies `mae`'s state_dict into an MAEBELLINO wrapper with load_state_dict; keep them in
# sync with the MAEBELLINO configuration built inside train().
MAE_KWARGS = dict(
    masking_ratio = 0.75,   # the paper recommended 75% masked patches
    decoder_dim = 512,      # paper showed good results with just 512
    decoder_depth = 6       # anywhere from 1 to 8
)

# both wrappers are built around the same encoder object `v`
mae_bellino = MAEBELLINO(encoder = v, **MAE_KWARGS).to(device)
mae = MAE(encoder = v, **MAE_KWARGS).to(device)



In [8]:
#helpers for the training loop

def _save_reconstruction(model, vis_model, images, out_path):
    """Copy ``model``'s weights into ``vis_model`` and save its output for ``images[0]`` as a PNG.

    The forward pass runs under ``torch.no_grad()``: the image is only for inspection, so
    no autograd graph is built. Pixel values are clipped to [0, 1] before saving.
    """
    vis_model.load_state_dict(model.state_dict())
    with torch.no_grad():
        output = vis_model(images)
    # assumes vis_model returns a (batch, C, H, W) image tensor — TODO confirm in mae_bellino
    first = output[0].detach().cpu().numpy().transpose(1, 2, 0)
    plt.imsave(out_path, np.clip(first, 0, 1))


def _evaluate(model, loader):
    """Return the mean loss of ``model`` over ``loader`` (eval mode, no gradients)."""
    model.eval()
    losses = []
    with torch.no_grad():
        for images, _labels in loader:
            losses.append(model(images.to(device)).item())
    return np.mean(losses)


#train loop
def train(model, train_loader, val_loader, epochs=10, lr=1e-2,
          vis_every=80, output_dir='outputs', checkpoint_path='last.pth'):
    """Train ``model`` with Adam and return the per-epoch mean train / validation losses.

    ``model(images)`` must return a scalar loss for a batch (e.g. the vit_pytorch ``MAE``
    wrapper). Every ``vis_every`` training batches, the current weights are copied into a
    ``MAEBELLINO`` wrapper built on the notebook-global encoder ``v``. Its reconstruction
    of the first image in the batch is saved to
    ``output_dir/epoch_<e>_batch_<i>.png``. After every epoch, ``model.state_dict()`` is
    saved to ``checkpoint_path``.

    Args:
        model: module whose forward returns a scalar loss for a batch of images.
        train_loader: yields ``(images, labels)``; labels are ignored.
        val_loader: same format as ``train_loader``, evaluated after every epoch.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        vis_every: save a reconstruction every this many training batches (0 disables).
        output_dir: folder for reconstruction PNGs; created if missing.
        checkpoint_path: file the latest ``state_dict`` is written to each epoch.

    Returns:
        ``(train_loss, val_loss)``: lists holding the mean batch loss of each epoch.
    """
    vis_model = None
    if vis_every:
        os.makedirs(output_dir, exist_ok=True)  # plt.imsave fails if the folder is missing
        # separate wrapper used only to render reconstructions; it shares the global
        # encoder `v` and receives the remaining weights via load_state_dict
        vis_model = MAEBELLINO(
            encoder = v,
            masking_ratio = 0.75,   # the paper recommended 75% masked patches
            decoder_dim = 512,      # paper showed good results with just 512
            decoder_depth = 6       # anywhere from 1 to 8
        ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_loss, val_loss = [], []
    for epoch in range(epochs):
        model.train()
        batch_losses = []
        for i, (images, _labels) in enumerate(train_loader):
            images = images.to(device)
            optimizer.zero_grad()
            loss = model(images)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
            if vis_every and i % vis_every == 0:
                out_path = os.path.join(output_dir, f'epoch_{epoch+1}_batch_{i}.png')
                _save_reconstruction(model, vis_model, images, out_path)
        train_loss.append(np.mean(batch_losses))

        val_loss.append(_evaluate(model, val_loader))
        print(f'Epoch: {epoch+1}, Train Loss: {train_loss[-1]}, Val Loss: {val_loss[-1]}')
        # overwrite the checkpoint each epoch so it always holds the latest weights
        torch.save(model.state_dict(), checkpoint_path)

    return train_loss, val_loss

In [9]:
# long self-supervised MAE run (500 epochs at lr=1e-4); labels from the loaders are unused
train_loss, val_loss = train(mae, train_loader, val_loader, lr=1e-4, epochs=500)

Epoch: 1, Train Loss: 0.0767658321676239, Val Loss: 0.0518376429454755
Epoch: 2, Train Loss: 0.05049062111430754, Val Loss: 0.049360179421255146
Epoch: 3, Train Loss: 0.048815903995754355, Val Loss: 0.04760433742934364
Epoch: 4, Train Loss: 0.04752320779282284, Val Loss: 0.04716150566809258
Epoch: 5, Train Loss: 0.046860323989984846, Val Loss: 0.04711972052326142
Epoch: 6, Train Loss: 0.04668159530324451, Val Loss: 0.047814717390022035
Epoch: 7, Train Loss: 0.046062051551416516, Val Loss: 0.04687287080717289
Epoch: 8, Train Loss: 0.04562902234699893, Val Loss: 0.0456416071787224
Epoch: 9, Train Loss: 0.04537453663030292, Val Loss: 0.04624925236517595
Epoch: 10, Train Loss: 0.045330212120029884, Val Loss: 0.04715892221084086
Epoch: 11, Train Loss: 0.04504454727047833, Val Loss: 0.046525340337874525
Epoch: 12, Train Loss: 0.044771281678734696, Val Loss: 0.04506942544574455
Epoch: 13, Train Loss: 0.04418137640755434, Val Loss: 0.04506376256250729
Epoch: 14, Train Loss: 0.0439798513514195,