# Imports

In [1]:
import matplotlib.pyplot as plt
import torch 

from transformers import ViTMAEForPreTraining, AutoImageProcessor

from src.features.vitmae.dataset import init_pretext_datasets, init_dataloaders
from src.models.vitmae.training import pretext_task_train

# Device

In [2]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Global PyTorch backend switches (process-wide, set once before any model work).
# Turn on cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True
# Keep TF32 disabled for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = False

# Image Processor

In [4]:
# Checkpoint identifier from which the image processor configuration is loaded.
image_processor_checkpoint = r"facebook/vit-mae-base"

# Build the processor that is later handed to the dataset constructor.
image_processor = AutoImageProcessor.from_pretrained(
    image_processor_checkpoint,
)

# Dataset

In [5]:
# Image directories for the two splits. Both are empty placeholders here and
# must be filled in before the dataset cell is run.
val_images_dir = r""
train_images_dir = r""

In [6]:
# Settings passed to init_dataloaders below.
# Per-split batch sizes.
batch_size_val = 32
batch_size_train = 52

# Loader options; presumably forwarded to torch's DataLoader — confirm in init_dataloaders.
num_workers = 0
pin_memory = True

In [7]:
# Build the train/validation datasets for the pretext task from the two image
# directories, sharing the image processor loaded above.
train_dataset, val_dataset = init_pretext_datasets(
    image_processor=image_processor,
    train_images_dir=train_images_dir,
    val_images_dir=val_images_dir,
)

In [8]:
# Wrap both datasets in dataloaders using the batch sizes and loader options
# configured earlier in the notebook.
train_dataloader, val_dataloader = init_dataloaders(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size_train=batch_size_train,
    batch_size_val=batch_size_val,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Model

In [9]:
# Load the pretrained ViT-MAE weights from the same checkpoint identifier the
# image processor was built from, so the two cannot drift apart if the
# checkpoint is changed in only one place. Then move the model to `device`.
model = ViTMAEForPreTraining.from_pretrained(image_processor_checkpoint).to(device)

In [10]:
# Adam over every model parameter. The learning rate is written out at Adam's
# documented default (1e-3), so behaviour is unchanged but the value is visible.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [11]:
# Output directory passed to the training routine and to model.save_pretrained.
# Empty placeholder: set this to a real path before running the cells below.
save_dir = r""

In [None]:
# Epoch count handed to the training routine.
num_epochs = 10

# Run pretext-task training. The returned `history` is indexed by
# "train_epoch" / "val_epoch" in the plotting cell.
# NOTE(review): the keyword `train_dataloder` is kept exactly as written;
# confirm it matches the parameter name in pretext_task_train's signature.
history = pretext_task_train(
    model=model,
    optimizer=optimizer,
    train_dataloder=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    save_dir=save_dir,
    num_epochs=num_epochs,
)

In [None]:
# Persist the trained model under the configured output directory.
model.save_pretrained(save_directory=save_dir)

In [None]:
# Plot the recorded training and validation losses side by side.
# `history` must provide sized sequences under "train_epoch" and "val_epoch";
# the x axis is the index into each sequence (assumed to be one entry per
# epoch, going by the key names — confirm against pretext_task_train).
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(18, 6))

panels = (
    (ax[0], "train_epoch", "Train loss"),
    (ax[1], "val_epoch", "Val loss"),
)
for panel, key, title in panels:
    losses = history[key]
    panel.plot(range(len(losses)), losses)
    panel.set_title(title)
    # Label both axes so each panel can be read on its own.
    panel.set_xlabel("Epoch")
    panel.set_ylabel("Loss")

plt.show()