In [1]:
import VisionTransformer as vit

import datetime
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

from datetime import datetime
from gc import collect
from os import cpu_count
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from tqdm import tqdm

In [2]:
# Seed every RNG used by the run so results are repeatable.
random_seed = 1
np.random.seed(random_seed)
torch.manual_seed(random_seed)  # seeds the CPU generator and all CUDA devices
torch.cuda.manual_seed(random_seed)
# Disable cuDNN autotuning AND restrict cuDNN to deterministic algorithms:
# benchmark=False alone still allows non-deterministic kernels to be chosen.
cudnn.benchmark = False
cudnn.deterministic = True

In [3]:
# Run Python's garbage collector and drop PyTorch's cached CUDA memory left over
# from earlier runs, then use the GPU when PyTorch can see one.
collect()
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Running on device: {device}")

Running on device: cpu


## Preparing dataset

In [4]:
def load_h5_channels_first(path, key="data"):
    """Read dataset ``key`` from the HDF5 file at ``path`` and move axis 3 to axis 1.

    The file is opened read-only and closed before returning. The axis move presumably
    converts a channels-last (N, H, W, C) array into channels-first (N, C, H, W) —
    TODO confirm against the files' layout.
    """
    with h5py.File(path, "r") as h5:
        data = np.array(h5[key])
    return np.moveaxis(data, 3, 1)


dspl = load_h5_channels_first('displacements_25.h5')
trac = load_h5_channels_first('tractions_25.h5')

# Inputs are displacement fields, targets are traction fields; the model runs in float64.
X_train = torch.from_numpy(dspl).double()
Y_train = torch.from_numpy(trac).double()

In [5]:
X_train.shape  # sanity-check the loaded input tensor's dimensions

torch.Size([25, 2, 104, 104])

In [6]:
train_set = TensorDataset(X_train, Y_train)
# val_set = TensorDataset(X_val, y_val)

batch_size = 8

# `device` is a torch.device, which never compares equal to a plain string such as
# 'cpu'; compare its `.type` instead so the CPU branch is actually reachable.
if device.type == 'cpu':
    # os.cpu_count() returns None when the count cannot be determined.
    num_workers = os.cpu_count() or 0
else:
    num_workers = 4 * torch.cuda.device_count()

# Page-locked host memory only speeds up host-to-GPU copies, so request it only on CUDA.
use_pin_memory = device.type == 'cuda'

dataloaders = {}
dataloaders['train'] = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers, pin_memory=use_pin_memory)
# dataloaders['val'] = DataLoader(val_set, batch_size=10*batch_size, num_workers=num_workers, pin_memory=use_pin_memory)

In [7]:
# NOTE(review): the training cell's last run failed inside Attention with
# "shape '[8, 169, 2, 12, 10]' is invalid for input of size 519168": the reshape used
# 12 heads x head_dim 10 = 120, which does not match embed_dim=128 (the qkv projection
# is 384 = 3 x 128 wide, so the '2' split there also looks suspicious). Pass a head count
# that divides 128 (e.g. 8) — check VisionTransformer's signature for the keyword name.
vit_model = vit.VisionTransformer(embed_dim=128).double()

# Number of trainable parameters (those with requires_grad set).
n_params = sum(
    param.numel()
    for param in vit_model.parameters()
    if param.requires_grad
)

In [8]:
# Squared error averaged over every element of the batch.
loss = torch.nn.MSELoss(reduction='mean')

# AdamW (decoupled weight decay) — the usual optimizer choice for ViTs.
optimizer = torch.optim.AdamW(
    params=vit_model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
)

# fp16_scaler = torch.cuda.amp.GradScaler()

In [9]:
def run_epoch(model, loss_fn, dataloader, device, epoch, optimizer, train):
    """Run ``model`` once over every batch of ``dataloader``, optionally training it.

    Args:
        model: network to run; put in train mode when ``train`` is true, else eval mode.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor. The
            per-sample averaging below assumes it is a mean over the batch, as with the
            ``MSELoss(reduction='mean')`` used in this notebook.
        dataloader: iterable of ``(inputs, targets)`` tensor batches.
        device: device each batch is moved to before the forward pass.
        epoch: epoch number, used only to label the progress bar.
        optimizer: stepped once per batch when ``train`` is true; unused (may be None) otherwise.
        train: if true, backpropagate, clip the gradient norm to 1.0 and update the
            parameters; if false, run the forward pass with gradients disabled.

    Returns:
        ``(epoch_loss, epoch_rmse)``: ``epoch_loss`` is the loss averaged over all samples
        seen, and ``epoch_rmse`` is ``sqrt(2 * epoch_loss)``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    # train(False) is equivalent to eval().
    model.train(train)

    total_loss = 0.0
    n_samples = 0

    with tqdm(dataloader, unit="batch", desc=f"Epoch {epoch}") as tepoch:
        for xb, yb in tepoch:
            xb, yb = xb.to(device), yb.to(device)

            if train:
                # Reset the gradients accumulated by the previous step.
                optimizer.zero_grad(set_to_none=True)

            # Build the autograd graph only when we are going to backpropagate.
            with torch.set_grad_enabled(train):
                pred = model(xb)
                loss = loss_fn(pred, yb)

                if train:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2)
                    optimizer.step()

            # loss_fn returns a batch mean, so weight it by the batch size to get a true
            # per-sample mean even when batches differ in size (e.g. the last partial one).
            n_batch = xb.size(0)
            total_loss += loss.item() * n_batch
            n_samples += n_batch
            tepoch.set_postfix(loss=total_loss / n_samples)

    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")

    epoch_loss = total_loss / n_samples
    # NOTE(review): the factor 2 presumably turns the per-component MSE of the 2-channel
    # targets into a mean squared vector error before taking the root — confirm.
    epoch_rmse = np.sqrt(2 * epoch_loss)
    return epoch_loss, epoch_rmse

In [10]:
def fit(model, loss_fn, dataloaders, optimizer, device, writer, NAME, max_epochs, patience):
    """Train ``model`` with early stopping on the training RMSE and save the best weights.

    Each epoch calls ``run_epoch`` on ``dataloaders['train']``, logs ``train_loss`` and
    ``train_rmse`` to ``writer``, and snapshots the weights whenever the RMSE improves.
    Training stops after ``patience`` epochs without improvement, or after ``max_epochs``.
    The best weights are saved to ``f"{NAME}_best_train_rmse_{rmse rounded to 3}.pth"``.

    Args:
        model: network to train.
        loss_fn: loss callable forwarded to ``run_epoch``.
        dataloaders: dict with a ``'train'`` loader.
        optimizer: optimizer forwarded to ``run_epoch``.
        device: device forwarded to ``run_epoch``.
        writer: object with ``add_scalar(tag, value, step)`` (a TensorBoard SummaryWriter here).
        NAME: path prefix for the saved checkpoint.
        max_epochs: maximum number of epochs to run.
        patience: number of epochs without improvement before stopping.

    Returns:
        ``(best_epoch, best_train_rmse)``. ``best_epoch`` is -1 and nothing is saved if no
        epoch's RMSE was below ``inf`` (e.g. every epoch produced NaN).
    """
    import copy  # used only to snapshot the best weights

    best_train_rmse = np.inf
    best_epoch = -1
    best_model_weights = {}

    for epoch in range(1, max_epochs + 1):
        train_loss, train_rmse = run_epoch(model, loss_fn, dataloaders['train'], device, epoch, optimizer, train=True)
        # val_loss, val_rmse = run_epoch(model, loss_fn, dataloaders['val'], device, epoch, optimizer=None, train=False)
        print(
            f"Epoch {epoch}/{max_epochs}, train_loss: {train_loss:.3f}, train_rmse: {train_rmse:.3f}")

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_rmse', train_rmse, epoch)
        # writer.add_scalar('val_loss', val_loss, epoch)
        # writer.add_scalar('val_rmse', val_rmse, epoch)

        # Save best weights. Deep-copy, because state_dict() references the live tensors
        # that later optimizer steps would overwrite.
        if train_rmse < best_train_rmse:
            best_train_rmse = train_rmse
            best_epoch = epoch
            best_model_weights = copy.deepcopy(model.state_dict())

        # Early stopping
        print(
            f"best train_rmse: {best_train_rmse:.3f}, epoch: {epoch}, best_epoch: {best_epoch}, current_patience: {patience - (epoch - best_epoch)}")
        if epoch - best_epoch >= patience:
            break

    if best_epoch > 0:
        torch.save(best_model_weights, f'{NAME}_best_train_rmse_{np.round(best_train_rmse, 3)}.pth')
    return best_epoch, best_train_rmse

In [11]:
# Timestamped run name, used for the TensorBoard log directory and the checkpoint file.
# The time uses '-' rather than ':' because ':' is not allowed in Windows paths.
NAME = "ViT-{:%Y-%b-%d %H-%M-%S}".format(datetime.now())
writer = SummaryWriter(log_dir=NAME)
vit_model.to(device)
try:
    fit(vit_model, loss, dataloaders, optimizer, device, writer, NAME, max_epochs=100, patience=5)
finally:
    # Flush buffered events and release the event file even if training raises.
    writer.close()

Epoch 1:   0%|                                                            | 0/4 [00:00<?, ?batch/s]


Forward call of PatchEmbed: x of shape torch.Size([8, 2, 104, 104])
Forward call of PatchEmbed: After proj, x of shape torch.Size([8, 128, 13, 13])
Forward call of PatchEmbed: After proj and flatten, x of shape torch.Size([8, 128, 169])
Forward call of PatchEmbed: After proj and flatten and transpose, x of shape torch.Size([8, 169, 128])
Forward call of Attention: x of shape torch.Size([8, 169, 128])
Forward call of Attention: x of shape torch.Size([8, 169, 384])


RuntimeError: shape '[8, 169, 2, 12, 10]' is invalid for input of size 519168