## Group Equivariant canonicalization for an Invariant Task (Image Classification with ViT-Base)
In this notebook, we test whether the group-equivariant image canonicalizer can properly generate a canonical orientation for sample images, which can then be processed by the prediction network. We also visualize the ground truth and predicted class from the prediction network, which is a Vision Transformer ([Dosovitskiy et al., 2020](https://arxiv.org/abs/2010.11929)). Further, we consider the group to be $C_4$, the cyclic group of four discrete rotations (by multiples of $90^\circ$).

In [None]:
import os

import torch
from tqdm import tqdm

from equiadapt.images.canonicalization_networks.escnn_networks import ESCNNEquivariantNetwork
from equiadapt.images.canonicalization.discrete_group import GroupEquivariantImageCanonicalization
from equiadapt.common.basecanonicalization import IdentityCanonicalization

from examples.images.classification.prepare import STL10DataModule
from examples.images.classification.inference_utils import GroupInference
from examples.images.classification.model_utils import get_prediction_network


In [2]:
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [3]:
class DatasetHyperparams:
    """Dataset and data-loading settings handed to ``STL10DataModule``.

    Args:
        data_path: Path to the dataset. When omitted, the ``STL10_DATA_PATH``
            environment variable is used if it is set and non-empty; otherwise
            the location originally hard-coded in this notebook is kept, so
            existing setups keep working unchanged.
    """

    def __init__(self, data_path=None):
        # Fallback location used when no override is supplied. It is specific to
        # the original author's machine, hence the override mechanisms above.
        default_data_path = "/home/mila/s/siba-smarak.panigrahi/scratch/data/stl10"
        if data_path is None:
            data_path = os.environ.get("STL10_DATA_PATH") or default_data_path

        self.dataset_name = "stl10"  # Name of the dataset to use
        self.data_path = data_path  # Path to the dataset
        self.augment = 1  # Whether to use data augmentation (1) or not (0)
        self.num_workers = 4  # Number of workers for data loading
        self.batch_size = 64  # Number of samples per batch
        
# Build the STL10 data module from the hyperparameters defined above.
dataset_hyperparams = DatasetHyperparams()
data = STL10DataModule(hyperparams=dataset_hyperparams)

# setup() is called without a stage before requesting the training loader.
# NOTE(review): presumably this prepares the train split — confirm in STL10DataModule.
data.setup()
train_loader = data.train_dataloader()

# setup() is called again with stage="test" before requesting the test loader.
data.setup(stage="test")
test_loader = data.test_dataloader()

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Test dataset size:  8000


In [4]:
# Task set-up shared by every cell below: loss, per-image input shape and class count.
loss_fn = torch.nn.CrossEntropyLoss()
image_shape = (3, 224, 224)
num_classes = 10

# Prediction network: a Vision Transformer requested with pretrained weights and
# an unfrozen encoder, moved to the selected device.
prediction_network = get_prediction_network(
    architecture="vit",
    dataset_name="stl10",
    use_pretrained=True,
    freeze_encoder=False,
    input_shape=image_shape,
    num_classes=num_classes,
)
prediction_network = prediction_network.to(device)

In [5]:
# design canonicalization hyperparams class
class CanonicalizationHyperparams:
    """Settings for the canonicalizer and for its equivariant canonicalization network.

    ``network_hyperparams`` is unpacked as keyword arguments into the
    canonicalization network's constructor; the object as a whole is passed to
    the canonicalizer.
    """

    def __init__(self):
        self.canonicalization_type = "group_equivariant"  # Type of canonicalization
        self.network_type = "escnn"  # Backend of the canonicalization network
        self.resize_shape = 96  # Resize shape for the canonicalization network
        self.network_hyperparams = dict(
            kernel_size=5,  # Kernel size for the canonicalization network
            out_channels=32,  # Number of output channels for the canonicalization network
            num_layers=3,  # Number of layers in the canonicalization network
            group_type="rotation",  # Type of group for the canonicalization network
            num_rotations=4,  # Number of rotations for the canonicalization network
        )
        # NOTE(review): beta and input_crop_ratio are consumed inside the canonicalizer;
        # their exact meaning is not visible in this notebook — confirm against equiadapt.
        self.beta = 1.0
        self.input_crop_ratio = 0.8
        
canonicalization_hyperparams = CanonicalizationHyperparams()

# Equivariant network used by the canonicalizer; it receives the image shape plus
# the network hyperparameters unpacked as keyword arguments.
canonicalization_network = ESCNNEquivariantNetwork(
                        image_shape,
                        **canonicalization_hyperparams.network_hyperparams,
                    ).to(device)

# Canonicalizer wrapping the network above. It is called on each input batch
# before the prediction network during fine-tuning.
canonicalizer = GroupEquivariantImageCanonicalization(
            canonicalization_network,
            canonicalization_hyperparams,
            image_shape
        ).to(device)

In [6]:
class InferenceHyperparams:
    """Inference settings handed to ``GroupInference``.

    The group settings mirror the canonicalization network's configuration
    in this notebook: rotations, with 4 discrete elements.
    """

    def __init__(self):
        # NOTE(review): "group" presumably selects evaluation over all group
        # elements — confirm against GroupInference.
        self.method = "group"
        self.group_type, self.num_rotations = "rotation", 4
        
# Inference helper built from the canonicalizer and the prediction network.
# Its get_inference_metrics() is used in the test loop below to obtain the
# "test/acc" and "test/group_acc" metrics.
inference_method = GroupInference(
        canonicalizer,
        prediction_network,
        num_classes,
        InferenceHyperparams(),
        image_shape
    )

### Fine-tuning ViT on STL10 with a $C_4$ equivariant canonicalization network

In [7]:
# One AdamW optimizer with two parameter groups: the pretrained prediction network
# is updated with a small learning rate (1e-5), the canonicalizer with a larger one (1e-3).
optimizer = torch.optim.AdamW([
        {'params': prediction_network.parameters(), 'lr': 1e-5},
        {'params': canonicalizer.parameters(), 'lr': 1e-3},
    ])

In [8]:
epochs = 20
# Weight of the canonicalizer's prior regularization term relative to the task loss.
prior_loss_weight = 100

# Fine-tune the prediction network jointly with the canonicalizer on STL10.
for epoch in range(epochs):
    # Explicitly select training mode so the cell behaves the same when it is
    # re-run after an evaluation cell may have switched the modules to eval mode.
    prediction_network.train()
    canonicalizer.train()

    tqdm_bar = tqdm(enumerate(train_loader), desc=f"Epoch {epoch}", total=len(train_loader))

    # Running sums of the per-batch metrics for this epoch.
    total_acc, total_loss, total_task_loss, total_prior_loss = 0.0, 0.0, 0.0, 0.0
    for batch_idx, batch in tqdm_bar:

        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)
        y = y.to(device)

        # Sanity check: every batch must match the shape the networks were built for.
        batch_size, num_channels, height, width = x.shape
        assert (num_channels, height, width) == image_shape

        # Canonicalize the input batch, then classify the canonicalized images.
        x_canonicalized = canonicalizer(x)
        logits = prediction_network(x_canonicalized)

        # Total loss = task loss + weighted prior regularization of the canonicalizer.
        # The prior loss is queried after the canonicalizer's forward pass, as in the
        # original ordering.
        task_loss = loss_fn(logits, y)
        prior_loss = canonicalizer.get_prior_regularization_loss()
        loss = task_loss + prior_loss * prior_loss_weight

        # Usual training steps
        loss.backward()
        optimizer.step()

        # Batch accuracy from the arg-max class prediction.
        preds = logits.argmax(dim=-1)
        acc = (preds == y).float().mean()

        # Accumulate and report the running mean of the per-batch metrics.
        total_acc += acc.item()
        total_loss += loss.item()
        total_task_loss += task_loss.item()
        total_prior_loss += prior_loss.item()
        num_batches = batch_idx + 1
        training_metrics = {
            "acc": total_acc / num_batches,
            "task_loss": total_task_loss / num_batches,
            "prior_loss": total_prior_loss / num_batches,
            "loss": total_loss / num_batches,
        }
        tqdm_bar.set_postfix(training_metrics)

Epoch 0:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 79/79 [00:42<00:00,  1.85it/s, acc=0.623, task_loss=1.42, prior_loss=1.21, loss=123]
Epoch 1: 100%|██████████| 79/79 [00:41<00:00,  1.90it/s, acc=0.928, task_loss=0.348, prior_loss=1.1, loss=111] 
Epoch 2: 100%|██████████| 79/79 [00:41<00:00,  1.90it/s, acc=0.962, task_loss=0.168, prior_loss=1.05, loss=105]
Epoch 3: 100%|██████████| 79/79 [00:41<00:00,  1.89it/s, acc=0.976, task_loss=0.109, prior_loss=1.01, loss=101]  
Epoch 4: 100%|██████████| 79/79 [00:41<00:00,  1.90it/s, acc=0.985, task_loss=0.0737, prior_loss=0.984, loss=98.5]
Epoch 5: 100%|██████████| 79/79 [00:41<00:00,  1.89it/s, acc=0.988, task_loss=0.0582, prior_loss=0.953, loss=95.4]
Epoch 6: 100%|██████████| 79/79 [00:41<00:00,  1.90it/s, acc=0.989, task_loss=0.0483, prior_loss=0.937, loss=93.7]
Epoch 7: 100%|██████████| 79/79 [00:41<00:00,  1.90it/s, acc=0.994, task_loss=0.0312, prior_loss=0.905, loss=90.5]
Epoch 8: 100%|██████████| 79/79 [00:41<00:00,  1.88it/s, acc=0.994, task_loss=0.0308, prior

In [9]:
# Evaluate in eval mode: layers whose behaviour depends on the module mode
# (e.g. dropout or batch normalization, if present) must not act as in training.
# These are the same objects that were passed to GroupInference above.
canonicalizer.eval()
prediction_network.eval()

test_tqdm_bar = tqdm(enumerate(test_loader), desc="Testing", total=len(test_loader))
total_acc, total_group_acc = 0.0, 0.0

# No parameter update happens here, so autograd bookkeeping is disabled to save
# memory and time.
with torch.no_grad():
    for batch_idx, batch in test_tqdm_bar:
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        # Sanity check: every batch must match the shape the networks were built for.
        batch_size, num_channels, height, width = x.shape
        assert (num_channels, height, width) == image_shape

        test_metrics = inference_method.get_inference_metrics(x, y)

        # float() accepts both Python numbers and single-element tensors.
        total_acc += float(test_metrics["test/acc"])
        total_group_acc += float(test_metrics["test/group_acc"])

# Mean of the per-batch metrics over the test loader.
print(f"Test Accuracy: {total_acc/len(test_loader):.3f}")
print(f"Test Group Accuracy: {total_group_acc/len(test_loader):.3f}")
    

Testing: 100%|██████████| 125/125 [01:23<00:00,  1.51it/s]

Test Accuracy: 0.971
Test Group Accuracy: 0.971





### Fine-tuning ViT on STL10 without a canonicalization network

In [10]:
# Re-create the prediction network so the baseline below starts again from the
# pretrained weights instead of the weights fine-tuned together with the canonicalizer.
prediction_network = get_prediction_network(
    architecture = "vit",
    dataset_name = "stl10",
    use_pretrained = True,
    freeze_encoder = False,
    input_shape = image_shape,
    num_classes = num_classes
).to(device)


# New AdamW optimizer for the baseline: only the freshly created prediction network
# is optimized (no canonicalizer parameters), with learning rate 1e-5.
optimizer = torch.optim.AdamW([
        {'params': prediction_network.parameters(), 'lr': 1e-5},
    ])

epochs = 20

# Fine-tune the prediction network on STL10 WITHOUT a canonicalizer (baseline).
for epoch in range(epochs):
    # Explicitly select training mode so the cell behaves the same when it is
    # re-run after an evaluation cell may have switched the module to eval mode.
    prediction_network.train()

    tqdm_bar = tqdm(enumerate(train_loader), desc=f"Epoch {epoch}", total=len(train_loader))

    # Running sums of the per-batch metrics for this epoch.
    total_acc, total_loss, total_task_loss = 0.0, 0.0, 0.0
    for batch_idx, batch in tqdm_bar:

        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)
        y = y.to(device)

        # Sanity check: every batch must match the shape the network was built for.
        batch_size, num_channels, height, width = x.shape
        assert (num_channels, height, width) == image_shape

        # The raw (non-canonicalized) images go straight into the prediction network.
        logits = prediction_network(x)

        # The task loss is the only loss term in the baseline.
        task_loss = loss_fn(logits, y)
        loss = task_loss

        # Usual training steps
        loss.backward()
        optimizer.step()

        # Batch accuracy from the arg-max class prediction.
        preds = logits.argmax(dim=-1)
        acc = (preds == y).float().mean()

        # Accumulate and report the running mean of the per-batch metrics.
        total_acc += acc.item()
        total_loss += loss.item()
        total_task_loss += task_loss.item()
        num_batches = batch_idx + 1
        training_metrics = {
            "acc": total_acc / num_batches,
            "task_loss": total_task_loss / num_batches,
            "loss": total_loss / num_batches,
        }
        tqdm_bar.set_postfix(training_metrics)

Epoch 0: 100%|██████████| 79/79 [00:34<00:00,  2.31it/s, acc=0.758, task_loss=1.12, loss=1.12]
Epoch 1: 100%|██████████| 79/79 [00:34<00:00,  2.31it/s, acc=0.974, task_loss=0.166, loss=0.166]
Epoch 2: 100%|██████████| 79/79 [00:34<00:00,  2.32it/s, acc=0.989, task_loss=0.0713, loss=0.0713]
Epoch 3: 100%|██████████| 79/79 [00:34<00:00,  2.32it/s, acc=0.993, task_loss=0.0438, loss=0.0438]
Epoch 4: 100%|██████████| 79/79 [00:33<00:00,  2.33it/s, acc=0.997, task_loss=0.0266, loss=0.0266]
Epoch 5: 100%|██████████| 79/79 [00:33<00:00,  2.32it/s, acc=0.999, task_loss=0.0144, loss=0.0144]
Epoch 6: 100%|██████████| 79/79 [00:34<00:00,  2.32it/s, acc=0.999, task_loss=0.0121, loss=0.0121]
Epoch 7: 100%|██████████| 79/79 [00:34<00:00,  2.32it/s, acc=0.999, task_loss=0.0119, loss=0.0119]  
Epoch 8: 100%|██████████| 79/79 [00:34<00:00,  2.31it/s, acc=1, task_loss=0.00626, loss=0.00626]
Epoch 9: 100%|██████████| 79/79 [00:34<00:00,  2.31it/s, acc=1, task_loss=0.00488, loss=0.00488]
Epoch 10: 100%|███

In [11]:
# Inference helper for the baseline. InferenceHyperparams is the class already
# defined earlier in this notebook (identical settings), so it is reused here
# instead of being declared a second time.
# NOTE(review): IdentityCanonicalization presumably leaves its input unchanged,
# standing in for the learned canonicalizer — confirm in equiadapt.
inference_method = GroupInference(
        IdentityCanonicalization(),
        prediction_network,
        num_classes,
        InferenceHyperparams(),
        image_shape
    )

# Evaluate in eval mode: layers whose behaviour depends on the module mode
# (e.g. dropout, if present) must not act as in training.
prediction_network.eval()

test_tqdm_bar = tqdm(enumerate(test_loader), desc="Testing", total=len(test_loader))
total_acc, total_group_acc = 0.0, 0.0

# No parameter update happens here, so autograd bookkeeping is disabled to save
# memory and time.
with torch.no_grad():
    for batch_idx, batch in test_tqdm_bar:
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        # Sanity check: every batch must match the shape the network was built for.
        batch_size, num_channels, height, width = x.shape
        assert (num_channels, height, width) == image_shape

        test_metrics = inference_method.get_inference_metrics(x, y)

        # float() accepts both Python numbers and single-element tensors.
        total_acc += float(test_metrics["test/acc"])
        total_group_acc += float(test_metrics["test/group_acc"])

# Mean of the per-batch metrics over the test loader.
print(f"Test Accuracy: {total_acc/len(test_loader):.3f}")
print(f"Test Group Accuracy: {total_group_acc/len(test_loader):.3f}")
    

Testing:   0%|          | 0/125 [00:00<?, ?it/s]

Testing: 100%|██████████| 125/125 [01:12<00:00,  1.73it/s]

Test Accuracy: 0.982
Test Group Accuracy: 0.775



