In [None]:
# Switch the working directory to the package folder so that the relative
# paths used by later cells ("..", "../data", "saved/...") resolve.
# NOTE(review): the path is relative to the directory the kernel starts in, so
# re-running this cell from inside that folder will presumably fail — confirm.
%cd GEVit-DL2-Project/src/modern_eq_vit/

In [None]:
import sys
sys.path.append("..")
import models
from g_selfatt import utils
import g_selfatt.groups as groups
from datasets import MNIST_rot, PCam

import torch
import torch.nn as nn
import torchvision.transforms as tvtf
import torchvision.transforms.functional as TF
from torch.optim.lr_scheduler import StepLR,LambdaLR

import os
import copy
import math
import wandb
import random
import argparse
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class CustomRotation:
    """Transform that rotates an image by one angle picked from a fixed list."""

    def __init__(self, angles):
        # Candidate rotation angles; a single one is selected on every call.
        self.angles = angles

    def __call__(self, img):
        """Return ``img`` rotated by a randomly chosen entry of ``self.angles``."""
        chosen_angle = random.choice(self.angles)
        rotate = tvtf.functional.rotate
        return rotate(img, chosen_angle)

In [None]:
# Per-channel statistics passed to Normalize (three channels).
data_mean = (0.701, 0.538, 0.692)
data_stddev = (0.235, 0.277, 0.213)

# Training pipeline: random right-angle rotation and random flips as
# augmentation, followed by tensor conversion and normalisation.
train_steps = [
    CustomRotation([0, 90, 180, 270]),
    tvtf.RandomHorizontalFlip(),
    tvtf.RandomVerticalFlip(),
    tvtf.ToTensor(),
    tvtf.Normalize(data_mean, data_stddev),
]
transform_train = tvtf.Compose(train_steps)

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
eval_steps = [
    tvtf.ToTensor(),
    tvtf.Normalize(data_mean, data_stddev),
]
transform_test = tvtf.Compose(eval_steps)

# Build the three PCam splits. Arguments shared by every split are collected
# once; the training split gets the augmenting transform, the other two the
# plain evaluation transform.
pcam_common = {"root": "../data", "download": True}
train_set = PCam(train=True, transform=transform_train, data_fraction=1, **pcam_common)
validation_set = PCam(train=False, valid=True, transform=transform_test, data_fraction=1, **pcam_common)
test_set = PCam(train=False, transform=transform_test, **pcam_common)

# Mini-batch loaders for the three PCam splits.
batch_size = 256
num_workers = 8  # worker processes used by each loader

# Training: reshuffle the data and drop the final incomplete batch.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
# Validation: iterate in a fixed order. Shuffling brings no benefit when only
# evaluating and makes runs harder to compare, so it is disabled here.
val_loader = torch.utils.data.DataLoader(
    validation_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
# Test: fixed order, every sample kept.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

In [None]:
# Hyper-parameters of the G-CNN front-end.
kernel_size = 5
hidden_channels = 8
num_hidden = 8

# Group-convolutional feature extractor operating on 3-channel input.
gcnn = models.get_gcnn(
    order=4,
    in_channels=3,
    out_channels=hidden_channels,
    kernel_size=kernel_size,
    num_hidden=num_hidden,
    hidden_channels=hidden_channels,
)

# Group transformer consuming the G-CNN feature maps. Its input channels and
# image size are read from the G-CNN so the two stages stay compatible.
transformer_config = dict(
    group=groups.SE2(num_elements=4),
    in_channels=gcnn.out_channels,
    num_channels=20,
    block_sizes=[2, 3],
    expansion_per_block=0,
    crop_per_layer=[1, 0, 0, 0, 0],
    image_size=gcnn.output_dimensionality,
    num_classes=2,
    dropout_rate_after_maxpooling=0.0,
    maxpool_after_last_block=True,
    normalize_between_layers=True,
    patch_size=5,
    num_heads=9,
    norm_type="LayerNorm",
    activation_function="Swish",
    attention_dropout_rate=0.1,
    value_dropout_rate=0.1,
    whitening_scale=1.41421356,
)
group_transformer = models.GroupTransformer(**transformer_config)

# Chain both stages into one hybrid network and move it to the chosen device.
hybrid = models.Hybrid(gcnn, group_transformer)
model = hybrid.to(device)

In [None]:
# Restore the pretrained hybrid weights. `map_location` remaps the stored
# tensors onto the device selected earlier; without it, a checkpoint saved
# from CUDA cannot be opened on a CPU-only machine.
checkpoint = torch.load("saved/modern_eq_vit.pt", map_location=device)
model.load_state_dict(checkpoint)

In [None]:
class ModifiedHybrid(nn.Module):
    """Linear classifier on top of a frozen G-CNN feature extractor.

    Args:
        gcnn: Feature-extractor module. Its parameters are frozen in place
            (``requires_grad = False``) and it is kept in eval mode, which
            also affects any other object sharing the same module.
        output_dim: Number of output logits.
        in_features: Flattened size of the ``gcnn`` output per sample. The
            default (288 = 36 * 8) matches the original hard-coded value.
    """

    def __init__(self, gcnn, output_dim=2, in_features=288):
        super().__init__()
        self.gcnn = gcnn
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_features, output_dim)
        # Freeze the backbone: only the linear head receives gradients.
        self.gcnn.requires_grad_(False)
        self.gcnn.eval()

    def train(self, mode=True):
        """Set the training mode of the head while keeping ``gcnn`` in eval mode.

        Freezing parameters alone does not stop layers such as BatchNorm or
        Dropout inside the backbone from changing behaviour (or running
        statistics) in train mode, so the backbone is pinned to eval mode.
        """
        super().train(mode)
        self.gcnn.eval()
        return self

    def forward(self, x):
        """Return class logits of shape ``(batch, output_dim)`` for input ``x``."""
        features = self.gcnn(x)
        features = self.flatten(features)
        return self.fc(features)

In [None]:
# Reuse the G-CNN stage of the loaded hybrid model as the backbone of the
# linear-head classifier, then move the result to the chosen device.
adj_model = ModifiedHybrid(model.gcnn)
adj_model = adj_model.to(device)

In [None]:
# Adam with a learning rate of 1e-4
# (original author's note: 0.001 works well here for the "floris" model).
learning_rate = 0.0001
optimizer = torch.optim.Adam(adj_model.parameters(), lr=learning_rate)

In [None]:
# Classification loss, plus the gradient scaler used for mixed-precision training.
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

In [None]:
# One pass over the training loader with mixed-precision forward passes.
adj_model.train()
losses = []  # loss of every optimisation step, kept for later inspection
progress = tqdm(train_loader)
for inputs, labels in progress:
    inputs, labels = inputs.to(device), labels.to(device)

    optimizer.zero_grad()
    # autocast handles mixed precision during the forward pass.
    with autocast():
        outputs = adj_model(inputs)
        loss = criterion(outputs, labels)

    # Read the scalar once: each .item() call forces a device-to-host sync.
    loss_value = loss.item()
    if math.isnan(loss_value):
        # Skip batches that produced a NaN loss instead of stepping on them.
        continue

    # Backward on the scaled loss creates scaled gradients.
    scaler.scale(loss).backward()
    # step() unscales the gradients before applying the optimizer update.
    scaler.step(optimizer)
    # Adjust the scale factor for the next iteration.
    scaler.update()

    losses.append(loss_value)
    # Show the running loss on the progress bar rather than printing a line
    # per step, which would flood the cell output.
    progress.set_postfix(loss=f"{loss_value:.4f}")


In [None]:
# Measure top-1 accuracy on the test loader.
adj_model.eval()
correct, total = 0, 0
with torch.no_grad():  # inference only: no autograd bookkeeping needed
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = adj_model(inputs)
        # Index of the largest logit per sample is the predicted class.
        predicted = torch.max(outputs, dim=1).indices
        total += labels.shape[0]
        correct += (predicted == labels).sum().item()
test_acc = 100 * correct / total

In [None]:
# Display the test accuracy (in percent) as the cell output.
test_acc