# Imports

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import ViTMAEForPreTraining, AutoImageProcessor

from src.vitmae.dataset import BubblesDataset

# Device

In [2]:
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Image Processor

In [3]:
# Hugging Face checkpoint whose image processor the datasets below use to turn
# images into model inputs.
image_processor_checkpoint = "facebook/vit-mae-base"
image_processor = AutoImageProcessor.from_pretrained(
    image_processor_checkpoint
)

# Dataset

In [4]:
# Root folder containing the `train` and `val` image sub-folders. Set the
# VITMAE_DATA_DIR environment variable to point elsewhere instead of editing
# this machine-specific absolute path.
data_root = os.environ.get("VITMAE_DATA_DIR", r"C:\Internship\ITMO_ML\data\ViTMAE")
train_images_dir = os.path.join(data_root, "train")
val_images_dir = os.path.join(data_root, "val")

In [5]:
# One BubblesDataset per split; both share the same image processor.
train_dataset, val_dataset = (
    BubblesDataset(images_dir=split_dir, image_processor=image_processor)
    for split_dir in (train_images_dir, val_images_dir)
)

In [6]:
# Same batch size for the training and validation loaders.
batch_size_train = 96
batch_size_val = 96
# Pinned host memory only speeds up host-to-GPU copies; requesting it on a
# CPU-only machine just makes the DataLoader emit a warning, so tie it to CUDA.
pin_memory = torch.cuda.is_available()
# Worker processes used by each DataLoader.
num_workers = 4

In [7]:
# Both loaders share the pinning/worker settings; only the dataset, batch size and
# shuffling differ (training data is reshuffled every epoch, validation order is fixed).
train_dataloader, val_dataloader = (
    DataLoader(
        split_dataset,
        batch_size=split_batch_size,
        shuffle=split_shuffle,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )
    for split_dataset, split_batch_size, split_shuffle in (
        (train_dataset, batch_size_train, True),
        (val_dataset, batch_size_val, False),
    )
)

# Model

In [8]:
def save_model(model, path):
    """Save `model`'s state dict to `path`, creating parent directories if needed.

    Every tensor is stored as a CPU copy, so the checkpoint can be loaded on a
    machine without a GPU and the live model stays on its current device.

    Args:
        model: a `torch.nn.Module` whose `state_dict()` is saved.
        path: destination file path.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:
        # torch.save needs the target directory to exist already.
        os.makedirs(parent_dir, exist_ok=True)

    # state_dict() returns a fresh mapping each call, so replacing its values does
    # not touch the model; mutating it in place keeps its `_metadata` attribute.
    state = model.state_dict()
    for name in list(state.keys()):
        state[name] = state[name].cpu()
    torch.save(state, path)

In [9]:
# Load the pretrained weights from the same checkpoint as the image processor, so
# preprocessing and weights cannot silently drift apart, then move the model to
# the selected device (`.to()` on a module returns the module itself).
model = ViTMAEForPreTraining.from_pretrained(image_processor_checkpoint).to(device)

In [10]:
# Adam over every model parameter. lr=1e-3 is PyTorch's default for Adam, written
# out so the value is visible.
# NOTE(review): 1e-3 is high for fine-tuning pretrained weights; consider ~1e-4 or AdamW.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [11]:
# Folder for the per-epoch checkpoints written during training. Set the
# VITMAE_CHECKPOINT_DIR environment variable to point elsewhere instead of editing
# this machine-specific absolute path.
save_dir = os.environ.get(
    "VITMAE_CHECKPOINT_DIR",
    r"C:\Internship\ITMO_ML\CTCI\checkpoints\vit\vitmae_on_bubbles",
)

In [12]:
def _prepare_batch(inputs):
    """Move a collated batch to `device` and drop the singleton axis of `pixel_values`.

    The collated `pixel_values` tensor is 5-D (batch, 1, channels, height, width);
    reshaping it to 4-D (batch, channels, height, width) only succeeds when the
    second axis has size 1. Presumably each dataset item keeps the image
    processor's own batch-of-one axis (TODO confirm in BubblesDataset).

    Assumes `inputs` is a transformers BatchFeature-like object exposing `.to()`
    and a `.data` dict.
    """
    inputs = inputs.to(device)
    pixel_values = inputs.data["pixel_values"]
    batch_size, _, num_channels, height, width = pixel_values.shape
    inputs.data["pixel_values"] = torch.reshape(
        pixel_values, (batch_size, num_channels, height, width)
    )
    return inputs


def _mean(values):
    """Arithmetic mean of `values`, or NaN when the list is empty (e.g. an empty loader)."""
    return sum(values) / len(values) if values else float("nan")


def train(model, optimizer, train_dataloder, val_dataloader, num_epochs=5):
    """Train `model` for `num_epochs` epochs, validating and checkpointing after each one.

    Args:
        model: module whose forward call returns an object with a `.loss` tensor
            (as ViTMAEForPreTraining does).
        optimizer: optimizer over `model`'s parameters.
        train_dataloder: iterable of training batches (misspelled name kept so
            existing keyword callers keep working).
        val_dataloader: iterable of validation batches.
        num_epochs: number of passes over the training data.

    Returns:
        dict with keys "train" and "val", each the list of per-batch loss values
        accumulated over all epochs.

    Side effects:
        Creates the module-level `save_dir` if needed and writes an `epoch_<n>`
        checkpoint there via `save_model` after every epoch.
    """
    history = {"train": [], "val": []}
    # Create the checkpoint folder (or fail on a bad path) before any training
    # time is spent; torch.save needs the directory to exist.
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}:")
        epoch_history = {"train": [], "val": []}

        model.train()
        for inputs in tqdm(train_dataloder):
            optimizer.zero_grad()
            inputs = _prepare_batch(inputs)
            outputs = model(**inputs)

            loss = outputs.loss
            loss.backward()
            optimizer.step()

            # .item() synchronises with the device; call it once per step.
            loss_value = loss.item()
            history["train"].append(loss_value)
            epoch_history["train"].append(loss_value)

        epoch_train_loss = _mean(epoch_history["train"])
        print(f"Epoch train loss: {epoch_train_loss}")

        model.eval()
        # Without no_grad every validation forward pass records an autograd graph
        # that is never used, wasting memory and time.
        with torch.no_grad():
            for inputs in tqdm(val_dataloader):
                inputs = _prepare_batch(inputs)
                loss_value = model(**inputs).loss.item()
                history["val"].append(loss_value)
                epoch_history["val"].append(loss_value)

        epoch_val_loss = _mean(epoch_history["val"])
        print(f"Epoch val loss: {epoch_val_loss}\n")

        # Moving the model to CPU before saving makes the checkpoint hold CPU
        # tensors; it is then moved back to the training device.
        save_model(model.to("cpu"), path=os.path.join(save_dir, f"epoch_{epoch+1}"))
        model = model.to(device)

    return history


In [None]:
# Fine-tune for 20 epochs; `history` keeps every per-batch train/val loss for the
# plots below. Arguments are passed positionally in train()'s parameter order.
history = train(model, optimizer, train_dataloader, val_dataloader, num_epochs=20)

Epoch 1:


100%|██████████| 195/195 [02:58<00:00,  1.09it/s]


Epoch train loss: 0.1625838197194613


 50%|█████     | 11/22 [02:28<02:21, 12.82s/it]

In [None]:
# Per-batch loss curves, one panel per split. The x-axis counts batches across all
# epochs, so the two panels use different x scales when the loaders differ in length.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(18, 6))

for panel, split, title in zip(ax, ("train", "val"), ("Train loss", "Val loss")):
    panel.plot(history[split])  # x defaults to 0..len-1, as range(len(...)) did
    panel.set(title=title, xlabel="Batch (cumulative over epochs)", ylabel="Loss")

plt.show()