In [None]:
# Fetch the project repository; later cells import its modules (datasets, train_vit, post_hoc_equivariant, sub_models).
# NOTE(review): not idempotent — after the %cd below, re-running this cell would clone a nested copy.
!git clone https://github.com/WouterBant/GEVit-DL2-Project.git

In [None]:
# Work from post_hoc_equivariant/: later cells use paths relative to it ("../data", "../saved/...")
# and append the parent directory to sys.path.
%cd GEVit-DL2-Project/post_hoc_equivariant

In [None]:
# Automatically re-import edited project modules without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
!pip install einops
!pip install wandb

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tvtf
import torchvision.transforms.functional as TF

from tqdm import tqdm
import matplotlib.pyplot as plt

import sys
sys.path.append('..')
from datasets import MNIST_rot
from train_vit import VisionTransformer

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_seed(seed):
    """Seed every random number generator the notebook may touch, for reproducibility.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch (CPU and all
    CUDA devices), and configures cuDNN for deterministic kernel selection.

    Args:
        seed: integer seed applied to all generators.
    """
    # Local imports: the notebook's import cell does not import these modules.
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    # cuDNN: use deterministic algorithms and disable the benchmark auto-tuner
    # (benchmark=True can be faster but may pick different kernels between runs).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_non_equivariant_vit(model_path="../saved/results/model.pt"):
    """Build the baseline (non-equivariant) ViT and load pretrained weights into it.

    Args:
        model_path: path of the saved state dict; defaults to the checkpoint used
            throughout this notebook (relative to the notebook's working directory).

    Returns:
        The VisionTransformer, moved to ``device``, with the checkpoint weights loaded.
    """
    model = VisionTransformer(embed_dim=64,
                              hidden_dim=512,
                              num_heads=4,
                              num_layers=6,
                              patch_size=4,
                              num_channels=1,
                              num_patches=49,
                              num_classes=10,
                              dropout=0.1).to(device)
    state_dict = torch.load(model_path, map_location=device)
    # strict=False tolerates key mismatches, so print the returned report
    # (missing / unexpected keys) to make silently skipped weights visible.
    print(model.load_state_dict(state_dict, strict=False))
    return model

model = get_non_equivariant_vit()

In [None]:
data_mean = (0.1307,)
data_stddev = (0.3081,)

# Training augmentation: arbitrary rotations plus random horizontal / vertical flips.
transform_train = tvtf.Compose([
    tvtf.RandomRotation(degrees=(-180, 180)),  # random rotation
    tvtf.RandomHorizontalFlip(),  # random horizontal flip with a probability of 0.5
    tvtf.RandomVerticalFlip(),  # random vertical flip with a probability of 0.5
    tvtf.ToTensor(),
    tvtf.Normalize(data_mean, data_stddev)
])
# Evaluation: no augmentation, only tensor conversion and normalisation.
transform_test = tvtf.Compose([
    tvtf.ToTensor(),
    tvtf.Normalize(data_mean, data_stddev),
])


def make_mnist_rot(stage, transform):
    """MNIST_rot split ``stage`` with ``transform``, using the settings shared by every
    split in this notebook (data_fraction=1, only_3_and_8=False)."""
    return MNIST_rot(root="../data", stage=stage, download=True, transform=transform,
                     data_fraction=1, only_3_and_8=False)


def make_loader(dataset, batch_size=128, shuffle=False, num_workers=4):
    """DataLoader over ``dataset`` with the defaults shared by every loader in this notebook."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


train_set = make_mnist_rot("train", transform_train)
validation_set = make_mnist_rot("validation", transform_test)
test_set = make_mnist_rot("test", transform_test)

train_loader = make_loader(train_set, shuffle=True)
# Evaluation loaders are not shuffled: accuracy does not depend on sample order, and a
# fixed order keeps evaluation passes directly comparable.
val_loader = make_loader(validation_set)
test_loader = make_loader(test_set)
img_loader = make_loader(test_set, batch_size=1)  # single element for visualization purposes

In [None]:
def train(model, n_epochs=5, lr=0.001, loader=None):
    """Train ``model`` with Adam and cross-entropy, printing the mean loss of each epoch.

    Args:
        model: module mapping a batch of images to class logits.
        n_epochs: number of passes over the training data.
        lr: Adam learning rate (the default is the value used for every experiment here).
        loader: iterable of ``(images, targets)`` batches; defaults to ``train_loader``.
    """
    if loader is None:
        loader = train_loader
    model.to(device)
    model.train()
    # Only hand trainable parameters to the optimizer; parameters frozen with
    # requires_grad=False never receive gradients, so Adam would skip them anyway.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in tqdm(range(n_epochs)):
        epoch_losses = []
        for images, targets in loader:
            images = images.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
        # Report NaN for an empty loader instead of raising ZeroDivisionError.
        mean_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan")
        print(f"Epoch {epoch+1}: loss {mean_loss:.4f}")

def _accuracy(model, loader):
    """Top-1 accuracy (in percent) of ``model`` over every batch of ``loader``.

    Switches ``model`` to eval mode and runs without gradient tracking.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


def evaluate(model):
    """Validation-set accuracy (in percent) of ``model``."""
    return _accuracy(model, val_loader)


def test(model):
    """Test-set accuracy (in percent) of ``model``."""
    return _accuracy(model, test_loader)

In [None]:
# Take the first test image. The iterator is not kept in a global, so it (and its
# DataLoader worker processes) can be released as soon as this line finishes.
image, target = next(iter(img_loader))
fig, ax = plt.subplots()
ax.imshow(image.squeeze(), cmap="gray")
ax.set_title(f"First test sample (label {target.item()})")
ax.axis("off")
plt.show()

In [None]:
def get_transforms(image, n_rotations=4, flips=True):
    """
    Returns all transformations of a single input image.

    The rotations by multiples of ``360 / n_rotations`` degrees (the first being the
    identity) come first; if ``flips`` is true they are followed by the horizontal flip of
    each rotation, in the same order. With the defaults this yields the 8 elements of the
    dihedral group D4.

    Args:
        image: batched image tensor, presumably (1, C, H, W) as produced by ``img_loader``;
            results are concatenated along dim 0, so a leading batch dimension is required.
        n_rotations: number of evenly spaced rotations (including the identity).
        flips: whether to also include horizontally flipped copies.

    Returns:
        Tensor holding ``n_rotations * (2 if flips else 1)`` transformed copies of the
        input, concatenated along dim 0.
    """
    rotations = [image]
    for i in range(1, n_rotations):
        rotations.append(TF.rotate(image, i * (360 / n_rotations)))

    transforms = list(rotations)
    if flips:
        # Vertical flips are not added: a vertical flip equals a horizontal flip followed
        # by a 180-degree rotation, so for even n_rotations they are already included.
        transforms.extend(TF.hflip(rotated) for rotated in rotations)

    return torch.cat(transforms)

def visualize_transforms(transformed_images, n_cols=4, col_titles=('Original', '90°', '180°', '270°')):
    """Plot a sequence of (C, H, W) image tensors in a grid with ``n_cols`` columns.

    Args:
        transformed_images: indexable sequence of image tensors, e.g. the output of
            ``get_transforms``.
        n_cols: number of grid columns; the default 4 places one rotation per column for
            ``get_transforms(image, n_rotations=4)``.
        col_titles: titles for the columns of the top row (extra titles are ignored).
    """
    num_images = len(transformed_images)
    num_rows = (num_images - 1) // n_cols + 1
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[0] is always a row.
    fig, axes = plt.subplots(num_rows, n_cols, figsize=(4 * n_cols, 4 * num_rows), squeeze=False)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')  # also hides the unused cells of the last row
        if i < num_images:
            img = transformed_images[i].detach().cpu()  # matplotlib needs CPU data
            ax.imshow(img.permute(1, 2, 0), cmap="gray")  # (C, H, W) -> (H, W, C)

    for ax, title in zip(axes[0], col_titles):
        ax.set_title(title, size="larger")

    plt.tight_layout()
    plt.show()

# Show every transformed copy of the sample image (4 rotations, each also flipped) in one grid.
transformed_images = get_transforms(image, n_rotations=4, flips=True)
visualize_transforms(transformed_images)

In [None]:
# Sanity check: compare the output shape for the single image with the shape for the
# batch of all its transformations (one row per transformed copy is expected).
# Call the module (not .forward) so hooks run; no_grad because nothing is trained here.
with torch.no_grad():
    cls_single = model(image.to(device), output_cls=True)
    cls_transformed = model(get_transforms(image.to(device)), output_cls=True)
cls_single.shape, cls_transformed.shape

#### Some options for combining the latent representations of all transformed inputs equivariantly:
- Mean pooling
- Max pooling
- Sum
- Most probable (product of class probabilities)
- Highest probability among transformations
- Learn weights for a weighted average

In [None]:
# Explicit imports instead of `import *`: they show where each wrapper comes from and
# cannot silently overwrite notebook globals such as `device`.
from post_hoc_equivariant import (
    PostHocEquivariantMax,
    PostHocEquivariantMean,
    PostHocEquivariantMostProbable,
    PostHocEquivariantSum,
    PostHocLearnedAggregation,
    PostHocLearnedScoreAggregation,
    PostHocMostCertain,
)
from sub_models import ScoringModel, Transformer

#### First, keep the original model frozen

In [None]:
# baseline
# Validation accuracy of the pretrained, non-equivariant ViT as loaded above.
evaluate(model)

In [None]:
# mean pooling
# Wraps the same pretrained `model`; evaluation only, no training in this cell.
eq_model_mean = PostHocEquivariantMean(model)
evaluate(eq_model_mean)

In [None]:
# max pooling
# Wraps the same pretrained `model`; evaluation only, no training in this cell.
eq_model_max = PostHocEquivariantMax(model)
evaluate(eq_model_max)

In [None]:
# summing latent dimensions
# Wraps the same pretrained `model`; evaluation only, no training in this cell.
eq_model_sum = PostHocEquivariantSum(model)
evaluate(eq_model_sum)

In [None]:
# product of class probabilities
# Wraps the same pretrained `model`; evaluation only, no training in this cell.
eq_model_most_probable = PostHocEquivariantMostProbable(model)
evaluate(eq_model_most_probable)

In [None]:
# take transformation with highest certainty for class
# Wraps the same pretrained `model`; evaluation only, no training in this cell.
eq_model_most_certain = PostHocMostCertain(model)
evaluate(eq_model_most_certain)

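#### Learn weights for a weighted average

Here too there are a couple of options:
- a) a network takes as input the entire latent representation of one transformation and outputs a scalar representing the weight of that representation in the average.
- b) a network takes as input the i-th entry of every transformation's latent representation; concatenating its outputs over all i gives a new latent representation. To preserve equivariance, the result must not depend on the order in which the transformations are fed in. One solution is a transformer without positional encoding (PE), because without PE self-attention is insensitive to the order of its inputs.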


In [None]:
# a)
# Learned weighted average: ScoringModel is trained to produce the aggregation weights.
# NOTE(review): trained for 10 epochs, while every other learned variant in this notebook
# uses 25 — confirm this is intentional before comparing results.
set_seed(42)
scoring_model = ScoringModel()
eq_model_learned_score_aggregation = PostHocLearnedScoreAggregation(model=model, scoring_model=scoring_model)
train(eq_model_learned_score_aggregation, n_epochs=10)
evaluate(eq_model_learned_score_aggregation)

In [None]:
# b)
# Learned aggregation with a small Transformer (embed_dim=64, matching the ViT's embed_dim).
# NOTE(review): reuses the `model` that the previous cell wrapped and trained; this is only a
# clean comparison if PostHocLearnedScoreAggregation keeps the base model frozen by default
# — confirm.
set_seed(42)
aggregation_model = Transformer(embed_dim=64, hidden_dim=128, num_heads=4, num_layers=2, dropout=0.1)
eq_model_learned_aggregation = PostHocLearnedAggregation(model=model, aggregation_model=aggregation_model)
train(eq_model_learned_aggregation, n_epochs=25)
evaluate(eq_model_learned_aggregation)

#### Now all options again, but also fine-tuning the `mlp_head`

In [None]:
# mean pooling
# Re-seed and reload the pretrained ViT so this run starts from the checkpoint weights and
# the same RNG state, regardless of what earlier cells trained.
set_seed(42)
model = get_non_equivariant_vit()
eq_model_mean = PostHocEquivariantMean(model, finetune_mlp_head=True)
train(eq_model_mean, n_epochs=25)
evaluate(eq_model_mean)

In [None]:
# max pooling
# Fresh seed + fresh pretrained weights, as in every run of this section.
set_seed(42)
model = get_non_equivariant_vit()
eq_model_max = PostHocEquivariantMax(model, finetune_mlp_head=True)
train(eq_model_max, n_epochs=25)
evaluate(eq_model_max)

In [None]:
# summing latent dimensions
# Fresh seed + fresh pretrained weights, as in every run of this section.
set_seed(42)
model = get_non_equivariant_vit()
eq_model_sum = PostHocEquivariantSum(model, finetune_mlp_head=True)
train(eq_model_sum, n_epochs=25)
evaluate(eq_model_sum)

In [None]:
# product of class probabilities
# Fresh seed + fresh pretrained weights, as in every run of this section.
set_seed(42)
model = get_non_equivariant_vit()
eq_model_most_probable = PostHocEquivariantMostProbable(model, finetune_mlp_head=True)
train(eq_model_most_probable, n_epochs=25)
evaluate(eq_model_most_probable)

In [None]:
# take transformation with highest certainty for class
# Fresh seed + fresh pretrained weights, as in every run of this section.
set_seed(42)
model = get_non_equivariant_vit()
eq_model_most_certain = PostHocMostCertain(model, finetune_mlp_head=True)
train(eq_model_most_certain, n_epochs=25)
evaluate(eq_model_most_certain)

In [None]:
# a)
# Learned weighted average (ScoringModel); fresh seed + fresh pretrained weights.
# The ScoringModel is created after the reload, so its initialisation depends on this order.
set_seed(42)
model = get_non_equivariant_vit()
scoring_model = ScoringModel()
eq_model_learned_score_aggregation = PostHocLearnedScoreAggregation(model=model, scoring_model=scoring_model, finetune_mlp_head=True)
train(eq_model_learned_score_aggregation, n_epochs=25)
evaluate(eq_model_learned_score_aggregation)

In [None]:
# b)
# Learned aggregation with a small Transformer, constructed with finetune_mlp_head=True.
set_seed(42)
# Reload fresh pretrained weights like every other run in this section. Without this,
# `model` would be the base model that the previous cell (a) wrapped and trained, so this
# run would not start from the same checkpoint as the variants it is compared with.
model = get_non_equivariant_vit()
aggregation_model = Transformer(embed_dim=64, hidden_dim=128, num_heads=4, num_layers=2, dropout=0.1)
eq_model_learned_aggregation = PostHocLearnedAggregation(model=model, aggregation_model=aggregation_model, finetune_mlp_head=True)
train(eq_model_learned_aggregation, n_epochs=25)
evaluate(eq_model_learned_aggregation)

#### Now all options again, but fine-tuning the entire base model

In [None]:
# mean pooling
# Re-seed and reload the pretrained ViT so this run starts from the checkpoint weights and
# the same RNG state, regardless of what earlier cells trained.
set_seed(42)
model = get_non_equivariant_vit()
eq_model_mean = PostHocEquivariantMean(model, finetune_model=True)
train(eq_model_mean, n_epochs=25)
evaluate(eq_model_mean)

In [None]:
# max pooling
# Fresh seed + fresh pretrained weights, as in every run of this section.
set_seed(42)
model = get_non_equivariant_vit()
eq_model_max = PostHocEquivariantMax(model, finetune_model=True)
train(eq_model_max, n_epochs=25)
evaluate(eq_model_max)

In [None]:
# summing latent dimensions
# Fresh seed + fresh pretrained weights, as in every run of this section.
set_seed(42)
model = get_non_equivariant_vit()
eq_model_sum = PostHocEquivariantSum(model, finetune_model=True)
train(eq_model_sum, n_epochs=25)
evaluate(eq_model_sum)

In [None]:
# product of class probabilities
# Fresh seed + fresh pretrained weights, as in every run of this section.
set_seed(42)
model = get_non_equivariant_vit()
eq_model_most_probable = PostHocEquivariantMostProbable(model, finetune_model=True)
train(eq_model_most_probable, n_epochs=25)
evaluate(eq_model_most_probable)

In [None]:
# take transformation with highest certainty for class
# Fresh seed + fresh pretrained weights, as in every run of this section.
set_seed(42)
model = get_non_equivariant_vit()
eq_model_most_certain = PostHocMostCertain(model, finetune_model=True)
train(eq_model_most_certain, n_epochs=25)
evaluate(eq_model_most_certain)

In [None]:
# a)
# Learned weighted average (ScoringModel); fresh seed + fresh pretrained weights.
# The ScoringModel is created after the reload, so its initialisation depends on this order.
set_seed(42)
model = get_non_equivariant_vit()
scoring_model = ScoringModel()
eq_model_learned_score_aggregation = PostHocLearnedScoreAggregation(model=model, scoring_model=scoring_model, finetune_model=True)
train(eq_model_learned_score_aggregation, n_epochs=25)
evaluate(eq_model_learned_score_aggregation)

In [None]:
# b)
# Learned aggregation with a small Transformer, constructed with finetune_model=True.
set_seed(42)
# Reload fresh pretrained weights like every other run in this section. Without this,
# `model` would be the base model that the previous cell (a) wrapped with
# finetune_model=True and trained, so this run would not start from the checkpoint.
model = get_non_equivariant_vit()
aggregation_model = Transformer(embed_dim=64, hidden_dim=128, num_heads=4, num_layers=2, dropout=0.1)
eq_model_learned_aggregation = PostHocLearnedAggregation(model=model, aggregation_model=aggregation_model, finetune_model=True)
train(eq_model_learned_aggregation, n_epochs=25)
evaluate(eq_model_learned_aggregation)