In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from transformers import AutoImageProcessor, ViTMAEForPreTraining

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build an index of the CUB_200_2011 images: every sub-folder of `path` is
# treated as one class and the sub-folder name is used as the class name.

path = './CUB_200_2011/images/'
# Keep directories only, so a stray file next to the class folders
# (e.g. .DS_Store, Thumbs.db) is not mistaken for a class and does not make
# the os.listdir(folder) call below fail.
classes = sorted(
    entry for entry in os.listdir(path)
    if os.path.isdir(os.path.join(path, entry))
)
print(classes)

# One [class name, file path] record per image.
# os.listdir returns entries in arbitrary order, so the file names are sorted:
# this keeps the row order of `df` (and therefore any seeded split derived
# from it) identical from run to run and from machine to machine.
data = []
for class_name in classes:
    folder = os.path.join(path, class_name)
    for file_name in sorted(os.listdir(folder)):
        data.append([class_name, os.path.join(folder, file_name)])

# Convert the records to a dataframe with columns 'class' and 'path'.
df = pd.DataFrame(data,columns=['class','path'])
df.head()



['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet', '009.Brewer_Blackbird', '010.Red_winged_Blackbird', '011.Rusty_Blackbird', '012.Yellow_headed_Blackbird', '013.Bobolink', '014.Indigo_Bunting', '015.Lazuli_Bunting', '016.Painted_Bunting', '017.Cardinal', '018.Spotted_Catbird', '019.Gray_Catbird', '020.Yellow_breasted_Chat', '021.Eastern_Towhee', '022.Chuck_will_Widow', '023.Brandt_Cormorant', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant', '026.Bronzed_Cowbird', '027.Shiny_Cowbird', '028.Brown_Creeper', '029.American_Crow', '030.Fish_Crow', '031.Black_billed_Cuckoo', '032.Mangrove_Cuckoo', '033.Yellow_billed_Cuckoo', '034.Gray_crowned_Rosy_Finch', '035.Purple_Finch', '036.Northern_Flicker', '037.Acadian_Flycatcher', '038.Great_Crested_Flycatcher', '039.Least_Flycatcher', '040.Olive_sided_Flycatcher', '041.Scissor_tailed_Flycatcher', '042.Ver

Unnamed: 0,class,path
0,001.Black_footed_Albatross,./CUB_200_2011/images/001.Black_footed_Albatro...
1,001.Black_footed_Albatross,./CUB_200_2011/images/001.Black_footed_Albatro...
2,001.Black_footed_Albatross,./CUB_200_2011/images/001.Black_footed_Albatro...
3,001.Black_footed_Albatross,./CUB_200_2011/images/001.Black_footed_Albatro...
4,001.Black_footed_Albatross,./CUB_200_2011/images/001.Black_footed_Albatro...


In [3]:
#split the data into train and test and validation
from sklearn.model_selection import train_test_split
# First hold out 20% of all rows for testing, then take 20% of what is left
# for validation. The same fixed seed is used for both splits.
train_val_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.2,
    random_state=42,
)


In [4]:
#import albumentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Preprocessing pipeline applied to every image by the dataset:
# resize to 512x512, then convert the result to a torch tensor.
resize_step = A.Resize(512, 512)
to_tensor_step = ToTensorV2()
transform = A.Compose([resize_step, to_tensor_step])

In [5]:
#create the Dataset class
# Module-level image processor, loaded once and used by BirdDataset.__getitem__
# for every sample.
# NOTE(review): this is loaded from "facebook/vit-mae-base" while the model
# further down is loaded from "facebook/vit-mae-large" -- confirm both
# checkpoints share the same preprocessing settings.
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
class BirdDataset(Dataset):
    """Dataset that yields image-processor outputs for the images listed in a dataframe.

    Relies on three notebook-level names: ``classes`` (list of class names),
    ``transform`` (albumentations pipeline) and ``image_processor``.

    Args:
        df: dataframe with a 'path' column (image file paths) and a 'class'
            column (class names that must all appear in ``classes``).
    """

    def __init__(self,df):
        self.df = df
        self.images = self.df['path'].values
        # Integer label per row = position of the class name in `classes`.
        # The labels are stored but not returned by __getitem__.
        self.classes = np.array([classes.index(name) for name in self.df['class'].values])

    def __len__(self):
        """Number of rows (images) in the backing dataframe."""
        return len(self.df)

    def __getitem__(self,idx):
        """Load image ``idx`` and return what ``image_processor`` produces for it."""
        image = plt.imread(self.images[idx])
        # Promote single-channel images to three channels *before* the
        # transform, while the array is still H x W (x C). The previous
        # post-transform check looked at shape[1] of the transformed tensor,
        # which is not the channel axis, so it could not detect a
        # single-channel image and would have mangled any image whose
        # transformed height happened to be 1.
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        if image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)
        image = transform(image=image)['image']
        return image_processor(images=image, return_tensors="pt")

In [6]:
# Wrap each split in a BirdDataset and a DataLoader.
# Only the training loader shuffles; validation and test keep dataframe order.
batch_size = 32

train_dataset = BirdDataset(train_df)
val_dataset = BirdDataset(val_df)
test_dataset = BirdDataset(test_df)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


In [7]:
# Smoke-test the input pipeline: fetch a single batch from the training
# loader, cast it to float32 and print the shape of the pixel tensor.
for sample in train_loader:
    pixels = sample['pixel_values'].to(torch.float32)
    # Keep the batch axis and the last three axes, collapsing whatever sits
    # in between, so the tensor becomes N x 3 x 224 x 224.
    n_images = pixels.shape[0]
    channels, height, width = pixels.shape[-3:]
    sample['pixel_values'] = pixels.view(n_images, channels, height, width)
    print(sample['pixel_values'].shape)
    # one batch is enough for the check
    break


torch.Size([32, 3, 224, 224])


In [8]:

# Use the GPU when one is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Load the pretrained ViT-MAE checkpoint and move it onto the chosen device.
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-large")
model.to(device)



ViTMAEForPreTraining(
  (vit): ViTMAEModel(
    (embeddings): ViTMAEEmbeddings(
      (patch_embeddings): ViTMAEPatchEmbeddings(
        (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      )
    )
    (encoder): ViTMAEEncoder(
      (layer): ModuleList(
        (0-23): 24 x ViTMAELayer(
          (attention): ViTMAEAttention(
            (attention): ViTMAESelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTMAESelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTMAEIntermediate(
            (dense): Linear(in_features=1024, out_feature

In [9]:
#train loop
from torchvision.utils import save_image
def _prepare_batch(images, device):
    """Move a batch to ``device`` and reshape 'pixel_values' to (N, C, H, W) float32."""
    images = images.to(device)
    pixels = images['pixel_values'].to(torch.float32)
    shape = pixels.shape
    # Keep the batch axis and the last three axes; any axes in between are collapsed.
    images['pixel_values'] = pixels.reshape(shape[0], shape[-3], shape[-2], shape[-1])
    return images


def _unpatchify(patches, grid_h, grid_w, patch, channels):
    """Fold a (num_patches, patch*patch*channels) tensor into a (channels, H, W) image."""
    tiles = patches.reshape(grid_h, grid_w, patch, patch, channels)
    # h, w index the patch grid; p, q index pixels inside a patch; c is the channel.
    tiles = torch.einsum("hwpqc->chpwq", tiles)
    return tiles.reshape(channels, grid_h * patch, grid_w * patch)


def _rescale_unit(array):
    """Min-max rescale a numpy array to [0, 1]; a constant array maps to all zeros."""
    low = array.min()
    span = array.max() - low
    if span == 0:
        # Avoid a division by zero (NaN image) when every value is identical.
        return np.zeros_like(array)
    return (array - low) / span


def _save_reconstruction_preview(pixel_values, logits, mask, file_path):
    """Save a 4-panel strip for the first sample of a batch.

    Panels, left to right: original image, original with masked patches
    blacked out, reconstruction on the masked patches only, and the composite
    (visible patches from the original, masked patches from the reconstruction).

    Assumes ``mask`` holds 1 for masked patches and 0 for visible ones and that
    patches are square -- this matches the Hugging Face ViT-MAE outputs.
    """
    channels, height, width = pixel_values.shape[-3:]
    # Patch geometry is derived from the tensors themselves instead of being
    # hard-coded, so the preview works for any batch size / image size.
    patch = int(round((logits.shape[-1] / channels) ** 0.5))
    grid_h, grid_w = height // patch, width // patch

    def to_hwc(patches):
        image = _unpatchify(patches, grid_h, grid_w, patch, channels)
        return image.detach().cpu().numpy().transpose(1, 2, 0)

    recon = _rescale_unit(to_hwc(logits[0].detach().float()))
    original = _rescale_unit(pixel_values[0].detach().cpu().numpy().transpose(1, 2, 0))

    # Per-pixel visibility map: 1 where the patch was visible, 0 where masked.
    keep = (1.0 - mask[0].detach().float()).unsqueeze(-1).expand(-1, logits.shape[-1])
    visible = to_hwc(keep)

    visible_original = visible * original
    recon_masked_only = (1 - visible) * recon
    composite = recon_masked_only + visible_original

    panels = np.concatenate((original, visible_original, recon_masked_only, composite), axis=1)
    plt.imsave(file_path, panels)


def train(model,train_loader,val_loader,epochs=10,lr=1e-2,
          preview_every=64,output_dir='outputs_newL',checkpoint_path='last3l.pth'):
    """Train ``model`` with AdamW, validating and checkpointing after every epoch.

    Args:
        model: module whose forward accepts the batch as keyword arguments and
            returns an object exposing ``loss``, ``logits`` and ``mask``.
        train_loader: iterable of training batches; each batch must support
            ``.to(device)`` and item access on 'pixel_values'.
        val_loader: iterable of validation batches (same format).
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate.
        preview_every: print the loss and save a reconstruction preview every
            this many training batches (including batch 0).
        output_dir: directory for preview images; created if missing.
        checkpoint_path: file the model state dict is written to after every
            epoch (overwritten each time).

    Returns:
        (train_loss, val_loss): two lists holding the mean loss per epoch.
    """
    # Use the device the model already lives on instead of a notebook global.
    device = next(model.parameters()).device
    # plt.imsave does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    optimizer = torch.optim.AdamW(model.parameters(),lr=lr)
    train_loss = []
    val_loss = []
    for epoch in range(epochs):
        model.train()
        train_loss_ = []
        val_loss_ = []
        for i,images in enumerate(train_loader):
            images = _prepare_batch(images, device)
            optimizer.zero_grad()
            output = model(**images)
            loss = output.loss
            loss.backward()
            optimizer.step()
            train_loss_.append(loss.item())
            # every `preview_every` batches, print the loss and save a preview
            if i % preview_every == 0:
                print(f'Epoch: {epoch+1}, Batch: {i}, Loss: {train_loss_[-1]}')
                _save_reconstruction_preview(
                    images['pixel_values'],
                    output.logits,
                    output.mask,
                    os.path.join(output_dir, f'epoch_{epoch+1}_batch_{i}.png'),
                )
        train_loss.append(np.mean(train_loss_))

        model.eval()
        with torch.no_grad():
            for images in val_loader:
                images = _prepare_batch(images, device)
                output = model(**images)
                val_loss_.append(output.loss.item())

            val_loss.append(np.mean(val_loss_))
        print(f'Epoch: {epoch+1}, Train Loss: {train_loss[-1]}, Val Loss: {val_loss[-1]}')
        # save the model (latest weights only; the file is overwritten each epoch)
        torch.save(model.state_dict(), checkpoint_path)

    return train_loss,val_loss

In [10]:
# Launch training: 500 epochs at learning rate 3e-4; keeps the per-epoch
# mean training and validation losses.
train_loss, val_loss = train(
    model,
    train_loader,
    val_loader,
    epochs=500,
    lr=3e-4,
)

Epoch: 1, Batch: 0, Loss: 0.24588117003440857
Epoch: 1, Batch: 64, Loss: 0.2342035174369812
Epoch: 1, Batch: 128, Loss: 0.2543008625507355
Epoch: 1, Batch: 192, Loss: 0.21421873569488525
Epoch: 1, Train Loss: 0.22098694539676278, Val Loss: 0.22613456138109755
Epoch: 2, Batch: 0, Loss: 0.22318096458911896
Epoch: 2, Batch: 64, Loss: 0.21475069224834442
Epoch: 2, Batch: 128, Loss: 0.21557511389255524
Epoch: 2, Batch: 192, Loss: 0.1971559375524521
Epoch: 2, Train Loss: 0.22037122638548834, Val Loss: 0.22794569902500864
Epoch: 3, Batch: 0, Loss: 0.24961717426776886
Epoch: 3, Batch: 64, Loss: 0.23962678015232086
Epoch: 3, Batch: 128, Loss: 0.20404097437858582
Epoch: 3, Batch: 192, Loss: 0.22111573815345764
Epoch: 3, Train Loss: 0.21935206860051318, Val Loss: 0.23033750840162825
Epoch: 4, Batch: 0, Loss: 0.23012675344944
Epoch: 4, Batch: 64, Loss: 0.21092359721660614
Epoch: 4, Batch: 128, Loss: 0.2186511754989624
Epoch: 4, Batch: 192, Loss: 0.24471765756607056
Epoch: 4, Train Loss: 0.21961218

KeyboardInterrupt: 