## Group Equivariant canonicalization for an Invariant Task (Image Classification with ViT-Base)
In this notebook, we test whether the group-equivariant image canonicalizer can map sample images to a canonical orientation that the prediction network can then process. We also visualize the ground-truth and predicted classes from the prediction network, which is a Vision Transformer ([Dosovitskiy et al., 2020](https://arxiv.org/abs/2010.11929)). Further, we take the group to be $C_4$, the cyclic group of four discrete rotations (by multiples of $90^\circ$).

In [None]:
import copy
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms.functional as F

from equiadapt.images.canonicalization_networks.escnn_networks import ESCNNEquivariantNetwork
from equiadapt.images.canonicalization.discrete_group import GroupEquivariantImageCanonicalization
from equiadapt.common.basecanonicalization import IdentityCanonicalization
from examples.images.classification.inference_utils import GroupInference

from examples.images.classification.model_utils import get_prediction_network
from examples.images.classification.prepare import STL10DataModule 


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
import os


class DatasetHyperparams:
    """Data-loading settings passed to ``STL10DataModule`` below."""

    def __init__(self):
        self.dataset_name = "stl10"  # Name of the dataset to use
        # Path to the dataset. Set the STL10_DATA_PATH environment variable to
        # point at a local copy; the default is the original author's location.
        self.data_path = os.environ.get(
            "STL10_DATA_PATH",
            "/home/mila/s/siba-smarak.panigrahi/scratch/data/stl10",
        )
        self.augment = 1  # Whether to use data augmentation (1) or not (0)
        self.num_workers = 4  # Number of workers for data loading
        self.batch_size = 64  # Number of samples per batch


dataset_hyperparams = DatasetHyperparams()
data = STL10DataModule(hyperparams=dataset_hyperparams)

# Prepare the data module, then build the training loader.
data.setup()
train_loader = data.train_dataloader()

# Prepare the test stage explicitly before building the test loader.
data.setup(stage="test")
test_loader = data.test_dataloader()

In [4]:
# Hyperparameters for the group-equivariant canonicalizer.
class CanonicalizationHyperparams:
    def __init__(self):
        self.canonicalization_type = "group_equivariant"  # canonicalization type
        self.network_type = "escnn"  # group equivariant canonicalization network backend
        self.resize_shape = 32  # resize shape for the canonicalization network
        self.network_hyperparams = {
            "kernel_size": 7,  # Kernel size for the canonicalization network
            "out_channels": 64,  # Number of output channels for the canonicalization network
            "num_layers": 5,  # Number of layers in the canonicalization network
            "group_type": "rotation",  # Type of group for the canonicalization network
            "num_rotations": 4,  # Number of rotations (C4) for the canonicalization network
        }
        # NOTE(review): beta and input_crop_ratio are consumed inside equiadapt's
        # canonicalizer; their exact semantics are not visible here — confirm there.
        self.beta = 1.0
        self.input_crop_ratio = 0.9


canonicalization_hyperparams = CanonicalizationHyperparams()

# Build the C4-equivariant canonicalization network. Its spatial input size is
# taken from resize_shape so the two settings cannot drift apart.
canonicalization_network = ESCNNEquivariantNetwork(
    in_shape=(
        3,
        canonicalization_hyperparams.resize_shape,
        canonicalization_hyperparams.resize_shape,
    ),
    **canonicalization_hyperparams.network_hyperparams,
).to(device)

# Wrap the network in the canonicalizer. Reuse the same hyperparams instance
# that configured the network rather than constructing an independent copy.
canonicalizer = GroupEquivariantImageCanonicalization(
    canonicalization_network=canonicalization_network,
    canonicalization_hyperparams=canonicalization_hyperparams,
    in_shape=(3, 224, 224),
).to(device)

# Prediction network: a pretrained Vision Transformer configured for the
# 10 STL10 classes, built with freeze_encoder=True.
prediction_network = get_prediction_network(
    architecture="vit",
    dataset_name="stl10",
    use_pretrained=True,
    freeze_encoder=True,
    input_shape=(3, 224, 224),
    num_classes=10,
).to(device)

In [None]:
# Baseline: fine-tune the prediction network alone (no canonicalization) on STL10.
optimizer = torch.optim.Adam(prediction_network.parameters(), lr=1e-3)
epochs = 10

# Make sure training-mode behaviour is active (a previous evaluation cell may
# have switched the model to eval mode).
prediction_network.train()

for epoch in range(epochs):
    tqdm_bar = tqdm(train_loader, desc=f"Epoch {epoch}", total=len(train_loader))
    for x, y in tqdm_bar:
        x = x.to(device)
        y = y.to(device)

        # Clear gradients from the previous step; without this they accumulate
        # across batches and every update uses the running sum.
        optimizer.zero_grad()

        logits = prediction_network(x)
        loss = torch.nn.functional.cross_entropy(logits, y)

        loss.backward()
        optimizer.step()

        # Batch accuracy, only for progress reporting.
        with torch.no_grad():
            acc = (logits.argmax(dim=-1) == y).float().mean()

        tqdm_bar.set_postfix(loss=loss.item(), acc=acc.item())

In [None]:
class InferenceHyperparams:
    """Settings for ``GroupInference``: rotation group with 4 elements (C4)."""

    def __init__(self):
        self.group_type = "rotation"
        self.num_rotations = 4


# Baseline evaluation: IdentityCanonicalization, so inputs reach the
# fine-tuned prediction network unchanged.
inference_method = GroupInference(
    canonicalizer=IdentityCanonicalization(),
    prediction_network=prediction_network,
    num_classes=10,
    inference_hyperparams=InferenceHyperparams(),
    in_shape=(3, 224, 224),
)

# Disable training-only behaviour (e.g. dropout) for evaluation.
prediction_network.eval()

test_tqdm_bar = tqdm(test_loader, desc="Testing", total=len(test_loader))
total_acc, total_group_acc, num_samples = 0.0, 0.0, 0
with torch.no_grad():  # gradients are not needed at test time
    for x, y in test_tqdm_bar:
        x = x.to(device)
        y = y.to(device)
        batch_size = x.shape[0]

        test_metrics = inference_method.get_inference_metrics(x, y)
        acc = test_metrics["test/acc"].item()
        group_acc = test_metrics["test/group_acc"].item()

        # add test_metrics acc and group_acc to tqdm bar
        test_tqdm_bar.set_postfix(acc=acc, group_acc=group_acc)

        # Weight each batch by its size so a smaller final batch does not skew
        # the overall mean (assumes the metrics are per-batch means — TODO confirm).
        total_acc += acc * batch_size
        total_group_acc += group_acc * batch_size
        num_samples += batch_size

num_samples = max(num_samples, 1)  # guard against an empty test loader
print(f"Test Accuracy: {total_acc / num_samples:.3f}")
print(f"Test Group Accuracy: {total_group_acc / num_samples:.3f}")
    

In [7]:
# Re-instantiate the ViT from its pretrained weights, with the same arguments as
# before, so the canonicalized run starts from the same point as the baseline
# (the previous instance was already fine-tuned above).
prediction_network = get_prediction_network(
    architecture="vit",
    dataset_name="stl10",
    use_pretrained=True,
    freeze_encoder=True,
    input_shape=(3, 224, 224),
    num_classes=10,
).to(device)

In [None]:
# Fine-tune the prediction network jointly with the group-equivariant
# canonicalizer on STL10.
prior_weight = 100  # weight of the canonicalizer's prior-regularization loss
optimizer = torch.optim.Adam(
    [
        {"params": prediction_network.parameters(), "lr": 1e-3},
        {"params": canonicalizer.parameters(), "lr": 1e-3},
    ],
)
epochs = 10

# Make sure both modules are in training mode (an evaluation cell may have
# switched them to eval mode).
prediction_network.train()
canonicalizer.train()

for epoch in range(epochs):
    tqdm_bar = tqdm(train_loader, desc=f"Epoch {epoch}", total=len(train_loader))
    for x, y in tqdm_bar:
        x = x.to(device)
        y = y.to(device)

        # Clear gradients from the previous step; without this they accumulate.
        optimizer.zero_grad()

        # Canonicalize the inputs with the group-equivariant canonicalizer,
        # then classify the canonicalized images.
        x_canonicalized = canonicalizer(x)
        logits = prediction_network(x_canonicalized)

        loss = torch.nn.functional.cross_entropy(logits, y)
        loss = loss + prior_weight * canonicalizer.get_prior_regularization_loss()

        loss.backward()
        optimizer.step()

        # Batch accuracy, only for progress reporting.
        with torch.no_grad():
            acc = (logits.argmax(dim=-1) == y).float().mean()

        tqdm_bar.set_postfix(loss=loss.item(), acc=acc.item())

In [None]:

class InferenceHyperparams:
    """Settings for ``GroupInference``: rotation group with 4 elements (C4)."""

    def __init__(self):
        self.group_type = "rotation"
        self.num_rotations = 4


# Evaluate the jointly fine-tuned canonicalizer + prediction network.
inference_method = GroupInference(
    canonicalizer=canonicalizer,
    prediction_network=prediction_network,
    num_classes=10,
    inference_hyperparams=InferenceHyperparams(),
    in_shape=(3, 224, 224),
)

# Put both modules in eval mode for testing. NOTE(review): the canonicalizer may
# select orientations differently in train vs eval mode — confirm in equiadapt.
prediction_network.eval()
canonicalizer.eval()

test_tqdm_bar = tqdm(test_loader, desc="Testing", total=len(test_loader))
total_acc, total_group_acc, num_samples = 0.0, 0.0, 0
with torch.no_grad():  # gradients are not needed at test time
    for x, y in test_tqdm_bar:
        x = x.to(device)
        y = y.to(device)
        batch_size = x.shape[0]

        test_metrics = inference_method.get_inference_metrics(x, y)
        acc = test_metrics["test/acc"].item()
        group_acc = test_metrics["test/group_acc"].item()

        # add test_metrics acc and group_acc to tqdm bar
        test_tqdm_bar.set_postfix(acc=acc, group_acc=group_acc)

        # Weight each batch by its size so a smaller final batch does not skew
        # the overall mean (assumes the metrics are per-batch means — TODO confirm).
        total_acc += acc * batch_size
        total_group_acc += group_acc * batch_size
        num_samples += batch_size

num_samples = max(num_samples, 1)  # guard against an empty test loader
print(f"Test Accuracy: {total_acc / num_samples:.3f}")
print(f"Test Group Accuracy: {total_group_acc / num_samples:.3f}")
    