# Vision Transformer Test

Motivated by [NASA and IBM's Prithvi](https://arxiv.org/abs/2412.02732), we seek to construct a [Vision Transformer](https://arxiv.org/abs/2010.11929)-based spatio-temporal [Masked Autoencoder](https://arxiv.org/pdf/2111.06377), and to test the principle using a downsampled version of the [Moving MNIST dataset](https://www.cs.toronto.edu/~nitish/unsupervised_video/).

In [None]:
import datetime
import sys
import torch
from time import time

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
sys.path.insert(0, '..')

## Spatial Masked Autoencoder

We first examine the 2D masked autoencoder implemented by Facebook Research at https://github.com/facebookresearch/mae/tree/main, which is licensed under CC-BY-NC 4.0.

In [None]:
from src.FRAME_FM.models.mae_2d.models_mae import MaskedAutoencoderViT

# Use the current accelerator backend when one is reported, otherwise run on CPU.
accelerator = torch.accelerator.current_accelerator()
device = accelerator.type if accelerator is not None else "cpu"

# Hyper-parameters of a deliberately tiny masked autoencoder for 28-pixel,
# single-channel inputs.
mae_kwargs = dict(
    img_size=28,  # Input image height and width, in pixels
    patch_size=4,  # Side length of the patches that attention operates over
    in_chans=1,  # Channel count of the input image
    embed_dim=8,  # Width of the encoder embedding
    depth=4,  # Count of encoder attention layers
    num_heads=4,  # Attention heads per encoder layer
    decoder_embed_dim=8,  # Width of the decoder embedding
    decoder_num_heads=4,  # Attention heads used when decoding
    mlp_ratio=4.,  # MLP hidden width as a multiple of the embedding width
    norm_layer=torch.nn.LayerNorm,  # Normalisation layer class
    norm_pix_loss=False,  # Whether target pixels are normalised in the loss
)
model = MaskedAutoencoderViT(**mae_kwargs).to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.95),
    weight_decay=0.05,
)

# Both MNIST splits share one root directory and are downloaded if absent.
mnist_root = r"C:\Users\matarran\OneDrive - NERC\Documents\Projects\FRAME-FM"
batch_size = 64

# Training split: shuffled, with the final partial batch dropped.
dataset_train = datasets.MNIST(
    mnist_root, train=True, download=True, transform=transforms.ToTensor()
)
dataloader_train = DataLoader(
    dataset_train, batch_size=batch_size, shuffle=True, drop_last=True
)

# Test split: fixed order, every sample kept.
dataset_test = datasets.MNIST(
    mnist_root, train=False, download=True, transform=transforms.ToTensor()
)
dataloader_test = DataLoader(
    dataset_test, batch_size=batch_size, shuffle=False, drop_last=False
)

# Mask ratio handed to the model on every forward pass.
mask_ratio = 0.75

def train(dataloader, model, device, mask_ratio, optimizer, report_period=100):
    size = len(dataloader.dataset)
    model.train()
    for batch_id, (batch, _) in enumerate(dataloader):
        batch = batch.to(device, non_blocking=True)
        loss, _, _ = model(batch, mask_ratio=mask_ratio)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch_id % report_period == 0:
            start_id, end_id = batch_id * batch, (batch_id + 1) * len(batch) - 1
            loss_val = loss.item()
            print(f"{start_id:>5d} to {end_id:>5d} of {size:>5d}: train loss {loss_val:>7f}")

def test(dataloader, model, device, mask_ratio):
    size = len(dataloader.dataset)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch, _ in dataloader:
            batch = batch.to(device)
            loss, _, _ = model(batch, mask_ratio=mask_ratio)
            test_loss += loss.item()
    print(f"Average test loss: {test_loss / size:>7f}")

# Train for a fixed number of epochs, evaluating after each one, and report
# the total wall-clock time at the end.
epochs = 100
print(f"Start training for {epochs} epochs")
start_time = time()
for epoch in range(epochs):
    train(dataloader_train, model, device, mask_ratio, optimizer)
    # Evaluate on the held-out test split rather than the training data, so
    # the printed "test loss" reflects data the optimizer has not seen.
    test(dataloader_test, model, device, mask_ratio)

total_time = time() - start_time
# Truncate to whole seconds so the duration prints as H:MM:SS.
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")

Start training for 100 epochs
Train loss: 0.861305  [   64/60000]
Train loss: 0.295204  [ 6464/60000]
Train loss: 0.207280  [12864/60000]
Train loss: 0.169884  [19264/60000]
Train loss: 0.139803  [25664/60000]
Train loss: 0.109364  [32064/60000]
Train loss: 0.089447  [38464/60000]
Train loss: 0.093352  [44864/60000]
Train loss: 0.083689  [51264/60000]
Train loss: 0.080268  [57664/60000]
Average test loss: 0.001293
Train loss: 0.080804  [   64/60000]
Train loss: 0.080994  [ 6464/60000]
Train loss: 0.076160  [12864/60000]


KeyboardInterrupt: 