In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from transformers import AutoImageProcessor, ViTMAEForPreTraining

In [12]:
#read all files from the folder CUB_200_2011 and assign the subfolder as a class
#the subfolder name is the class name

path = './vegetables/'
classes = os.listdir(path)
classes.sort()
print(classes)
#read all files from the subfolders

data = []
for i in range(len(classes)):
    folder = os.path.join(path,classes[i])
    files = os.listdir(folder)
    for j in range(len(files)):
        data.append([classes[i],os.path.join(folder,files[j])])

#convert the list to a dataframe
df = pd.DataFrame(data,columns=['class','path'])
df.head()



['Apple', 'Banana', 'Carambola', 'Guava', 'Kiwi', 'Mango', 'Orange', 'Peach', 'Pear', 'Persimmon', 'Pitaya', 'Plum', 'Pomegranate', 'Tomatoes', 'muskmelon']


Unnamed: 0,class,path
0,Apple,./vegetables/Apple\Apple 01.png
1,Apple,./vegetables/Apple\Apple 010.png
2,Apple,./vegetables/Apple\Apple 0100.png
3,Apple,./vegetables/Apple\Apple 01000.png
4,Apple,./vegetables/Apple\Apple 01001.png


In [13]:
#split the data into train and test and validation
# from sklearn.model_selection import train_test_split
# train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)
import datasets_utils
train_df, val_df, test_df = datasets_utils.get_train_valid_test_split(df, SPLITS=[0.6,0.2,0.2], SEED=42)


In [14]:
#import albumenations
import albumentations as A
from albumentations.pytorch import ToTensorV2

#transformations for train and validation
transform = A.Compose([
    A.Resize(224, 224)
    ,ToTensorV2()])

In [15]:
# #create the Dataset class
# image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
# class BirdDataset(Dataset):
#     def __init__(self,df):
#         self.df = df
#         self.images = self.df['path'].values
#         self.classes = self.df['class'].values
#         self.classes = np.array([classes.index(i) for i in self.classes])
        
#     def __len__(self):
#         return len(self.df)
    
#     def __getitem__(self,idx):
#         image = self.images[idx]
#         image = plt.imread(image)
#         #print(image.shape)
#         if len(image.shape) == 2:
#             image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
#         image = transform(image=image)['image']
#         if image.shape[1] == 1:
#             image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
#         image=image_processor(images=image, return_tensors="pt")
#         #print(image['pixel_values'].shape)
#         return image
#create the Dataset class
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
class MelanzaDataset(Dataset):
    def __init__(self,df):
        self.df = df
        self.images = self.df['path'].values
        self.classes = self.df['class'].values
        self.classes = np.array([classes.index(i) for i in self.classes])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self,idx):
        image = self.images[idx]
        image = plt.imread(image)
        #print(image.shape)
        if len(image.shape) == 2:
            image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
        image = transform(image=image)['image']
        if image.shape[1] == 1:
            image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
        image=image_processor(images=image, return_tensors="pt")
        #print(image['pixel_values'].shape)
        return image

In [16]:
#create the dataloaders
batch_size = 64
train_loader = DataLoader(MelanzaDataset(train_df),batch_size=batch_size,shuffle=True)
val_loader = DataLoader(MelanzaDataset(val_df),batch_size=batch_size,shuffle=False)
test_loader = DataLoader(MelanzaDataset(test_df),batch_size=batch_size,shuffle=False)


In [17]:
#try the dataloader
for index,i in enumerate(train_loader):
    #print pixel values and class
    #print(i.shape)
    #image=image_processor(images=i, return_tensors="pt")
    #reshape the pixel values to nx3x224x224
    i['pixel_values']=i['pixel_values'].to(torch.float32)
    s = i['pixel_values'].shape
    i['pixel_values']=i['pixel_values'].view(s[0],s[-3],s[-2],s[-1])
    print(i['pixel_values'].shape)
    #if index == 1000:
    #    break
    break 


torch.Size([64, 3, 224, 224])


In [18]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
model.to(device)



ViTMAEForPreTraining(
  (vit): ViTMAEModel(
    (embeddings): ViTMAEEmbeddings(
      (patch_embeddings): ViTMAEPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
    )
    (encoder): ViTMAEEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTMAELayer(
          (attention): ViTMAEAttention(
            (attention): ViTMAESelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTMAESelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTMAEIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bi

In [19]:
#train loop
from torchvision.utils import save_image
output_path = './models/melanzana/'
def train(model,train_loader,val_loader,epochs=10,lr=1e-2):
    current_valid_loss , best_valid_loss = float('inf'), float('inf')

    optimizer = torch.optim.AdamW(model.parameters(),lr=lr)
    train_loss = []
    val_loss = []
    for epoch in range(epochs):
        model.train()
        train_loss_ = []
        val_loss_ = []
        for i,images in enumerate(train_loader):
            images = images.to(device)
            images['pixel_values']=images['pixel_values'].to(torch.float32)
            s = images['pixel_values'].shape
            images['pixel_values']=images['pixel_values'].view(s[0],s[-3],s[-2],s[-1])
            optimizer.zero_grad()
            #print(images['pixel_values'].shape)
            output = model(**images)
            loss = output.loss
            loss.backward()
            optimizer.step()
            train_loss_.append(loss.item())
            #every 100 batches, print the loss
            if i%128 == 0:

                print(f'Epoch: {epoch+1}, Batch: {i}, Loss: {train_loss_[-1]}')

                #reconstruct the original image shape
                masks = output.mask
                #print(masks.shape)
                #porzione immagine
                img = output.logits
                img = img.reshape(
                batch_size, 14, 14, 16, 16, 3
                )
                #batch
                img = torch.einsum("nhwpqc->nchpwq", img)
                img = img.reshape(batch_size, 3, 224, 224)
                #first img of the batch
                img = img[0]
                img = img.view(3,224,224)
                img_recon = img.cpu().detach().numpy().transpose(1,2,0)
                #img = np.maximum(img,0)
                img_recon = (img_recon-img_recon.min())/(img_recon.max()-img_recon.min())#np.minimum(img,1)
                #print(img.shape)
                
                img_originale = images['pixel_values'][0].cpu()
                #bring image to range 0-1
                img_originale = (img_originale-img_originale.min())/(img_originale.max()-img_originale.min())
                #print max and min values
                #print(img_originale.max(),img_originale.min())
                img_originale = img_originale.numpy().transpose(1,2,0)
                #cat immagine originale e immagine con maschera
                #porzione maschera sopra immagine originale
                mask = masks[0]
                mask_new = torch.ones(196,768).to(device)
                for j in range(196):
                    mask_new[j,:] =( 1-mask[j])*mask_new[j,:]
                img = mask_new
                img = img.reshape(
                 14, 14, 16, 16, 3
                )
                #batch
                img = torch.einsum("hwpqc->chpwq", img)
                img = img.reshape(3, 224, 224)
                #first img of the batch
                
                img = img.view(3,224,224)
                img_masked = img.cpu().detach().numpy().transpose(1,2,0)
                img_masked_org = img_masked*img_originale
                img_reconv2 = ((1-img_masked)*img_recon)+img_masked_org
                img_recon= ((1-img_masked)*img_recon)
                #img = np.maximum(img,0)
                #img_masked = (img_masked-img_masked.min())/(img_masked.max()-img_masked.min())#np.minimum(img,1)
                img_cat = np.concatenate((img_originale,img_masked_org,img_recon,img_reconv2),axis=1)
                plt.imsave(f'outputs_newB/epoch_{epoch+1}_batch_{i}.png',img_cat)
        train_loss.append(np.mean(train_loss_))
        
        model.eval()
        with torch.no_grad():
            for i,images in enumerate(val_loader):
                
                images = images.to(device)
                images['pixel_values']=images['pixel_values'].to(torch.float32)
                s = images['pixel_values'].shape
                images['pixel_values']=images['pixel_values'].view(s[0],s[-3],s[-2],s[-1])
                output = model(**images)
                loss = output.loss
                val_loss_.append(loss.item())

            mean_loss = np.mean(val_loss_)
            current_valid_loss = mean_loss
            if current_valid_loss < best_valid_loss:
                best_valid_loss = current_valid_loss
                print(f"\nBest validation loss: {best_valid_loss}")
                print(f"\nSaving best model for epoch: {epoch+1}\n")
                # torch.save(optimizer.state_dict(), 'jacoExperiments/best_distilled_model.pth')
            model.save_pretrained(output_path + '_(call_save_pretrained)_BESTmodel.pth')
            val_loss.append(mean_loss)
                
            val_loss.append(mean_loss)
        print(f'Epoch: {epoch+1}, Train Loss: {train_loss[-1]}, Val Loss: {val_loss[-1]}')
        #save the last output image on the disk
        #save the model
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        if epoch%101==0: 
            model.save_pretrained(output_path + '_(call_save_pretrained)_model_' + str(epoch) + '.pth')
        #torch.save(model.state_dict(),f'last3v2.pth')
        
    return train_loss,val_loss

In [20]:
train_loss,val_loss = train(model,train_loader,val_loader,epochs=500,lr=1e-4)

Epoch: 1, Batch: 0, Loss: 0.21643948554992676
Epoch: 1, Batch: 128, Loss: 0.17925329506397247
Epoch: 1, Batch: 256, Loss: 0.15531757473945618
Epoch: 1, Batch: 384, Loss: 0.14093269407749176

Best validation loss: 0.1566525823878546

Saving best model for epoch: 1



AttributeError: 'ViTMAEForPreTraining' object has no attribute 'save'