In [None]:
# Import Statements
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm
from PIL import Image, ImageOps
import cv2
from mpl_toolkits.axes_grid1 import ImageGrid
import random

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from torch.autograd import Variable


import torchvision
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
import torchvision.models as models

import numpy as np
import seaborn as sns

import sys

import os

import gc

In [None]:
# Define the dihedral group of order 8 (the symmetries of a square)
class D8Group:
    """Dihedral group of order 8 acting on the last two axes of a tensor.

    An element is a pair ``(r, f)``:

    * ``f`` in ``{0, 1}`` -- whether the last axis is flipped first;
    * ``r`` in ``{0, 1, 2, 3}`` -- how many 90-degree rotations
      (``torch.rot90`` over axes ``[-2, -1]``) are applied afterwards.

    So ``(r, f)`` stands for the map ``rot90^r o flip^f``.
    """

    def __init__(self):
        self.elements = self._generate_elements()

    def _generate_elements(self):
        """Return all eight ``(r, f)`` pairs, rotation index varying slowest."""
        return [(r, f) for r in range(4) for f in range(2)]

    def order(self):
        """Return the number of group elements (8)."""
        return len(self.elements)

    def apply(self, element, x):
        """Apply ``element`` to the last two axes of tensor ``x``.

        The optional flip of the last axis happens before the rotation.
        """
        r, f = element
        if f == 1:
            x = torch.flip(x, [-1])
        return torch.rot90(x, r, [-2, -1])

    def inverse(self, element):
        """Return the element that undoes ``element`` under :meth:`apply`.

        Pure rotations invert to the opposite rotation.  An element with a
        flip is a reflection: since ``flip o rot90^r o flip == rot90^(-r)``,
        applying ``rot90^r o flip`` twice is the identity, so every
        reflection is its own inverse and the rotation index must NOT be
        negated.
        """
        r, f = element
        if f == 1:
            return r % 4, f
        return (4 - r) % 4, f

In [None]:
# Class Definitions

class EquivariantMaxPool2d(nn.Module):
    """Max pooling averaged over all elements of ``D8Group``.

    For each group element the input is transformed, max-pooled with the
    configured ``kernel_size`` / ``stride`` / ``padding``, and mapped back
    with the element returned by ``D8Group.inverse``.  The per-element
    results are stacked on a new axis 1 and averaged over it.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()
        self.D8 = D8Group()
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

    def _pool_under(self, group_element, x):
        """Pool ``x`` in the frame of ``group_element`` and map the result back."""
        moved = self.D8.apply(group_element, x)
        pooled = F.max_pool2d(moved, self.kernel_size, self.stride, self.padding)
        return self.D8.apply(self.D8.inverse(group_element), pooled)

    def forward(self, x):
        """Return the mean of the pooled outputs over every group element."""
        per_element = [self._pool_under(g, x) for g in self.D8.elements]
        return torch.stack(per_element, dim=1).mean(dim=1)


class EquivariantConv2d(nn.Module):
    """Bias-free 2-D convolution with a kernel averaged over ``D8Group``.

    The learnable ``weights`` tensor has shape
    ``(out_channels, in_channels, kernel_size, kernel_size)``, so
    ``kernel_size`` must be a single int.  The effective kernel is the mean
    of ``weights`` transformed by every group element.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: side length of the square kernel.
        stride: forwarded to ``F.conv2d``.
        padding: forwarded to ``F.conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(EquivariantConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.D8 = D8Group()
        self.weights = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        # Kaiming-uniform initialisation replaces the placeholder values above.
        torch.nn.init.kaiming_uniform_(self.weights, nonlinearity='relu')

    def _symmetrized_weight(self):
        """Return ``weights`` averaged over all D8-transformed copies."""
        transformed = [self.D8.apply(element, self.weights) for element in self.D8.elements]
        return torch.stack(transformed, dim=0).mean(dim=0)

    def forward(self, x):
        """Convolve ``x`` with the group-averaged kernel.

        Convolution is linear in its kernel, so averaging the outputs of one
        convolution per transformed kernel equals a single convolution with
        the averaged kernel.  Doing the average on the (small) weight tensor
        avoids running eight convolutions and stacking eight activation maps.
        """
        return F.conv2d(x, self._symmetrized_weight(), stride=self.stride, padding=self.padding)


class EquivariantVGG19(nn.Module):
    """VGG-19 layout built from ``EquivariantConv2d`` / ``EquivariantMaxPool2d``.

    ``features`` is the convolutional stack produced by :meth:`_make_layers`;
    ``classifier`` is three fully connected layers whose first layer expects
    ``512 * 7 * 7`` inputs.

    Args:
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, num_classes=1000):
        super(EquivariantVGG19, self).__init__()
        self.features = self._make_layers()
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def _make_layers(self):
        """Build the feature extractor from the VGG-19 configuration.

        Integers in ``cfg`` are output widths of 3x3, padding-1 convolutions
        (each followed by an in-place ReLU); ``'M'`` is a 2x2, stride-2 pool.
        """
        layers = []
        in_channels = 3
        cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
        for x in cfg:
            if x == 'M':
                layers += [EquivariantMaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [EquivariantConv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True)]
                in_channels = x
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the feature stack, pool to 7x7, flatten, and classify.

        The adaptive average pool is an identity when the feature map is
        already 7x7, and otherwise resizes it so the fixed-size classifier
        accepts inputs of other spatial sizes.
        """
        x = self.features(x)
        x = F.adaptive_avg_pool2d(x, (7, 7))
        # flatten (unlike view) does not require a contiguous tensor.
        x = torch.flatten(x, 1)
        return self.classifier(x)


def check_layer_equivariance(layer, input_tensor, tolerance=1e-3):
    """Check that ``layer`` commutes with every ``D8Group`` element.

    For each element ``g`` this compares ``layer(g(input))`` with
    ``g(layer(input))``.  Diagnostics are printed and ``False`` is returned
    on the first element where the two differ in shape or by more than
    ``tolerance`` (absolute).

    Args:
        layer: callable mapping a tensor to a tensor.
        input_tensor: tensor fed to ``layer``; D8 acts on its last two axes.
        tolerance: ``atol`` passed to ``torch.allclose``.

    Returns:
        ``True`` if every element passes, ``False`` otherwise.
    """
    d8 = D8Group()
    # Only values are compared, so no autograd graph is needed.
    with torch.no_grad():
        # The untransformed output does not depend on the element; compute it once.
        original_output = layer(input_tensor)
        for element in d8.elements:
            transformed_input = d8.apply(element, input_tensor)
            transformed_output = layer(transformed_input)
            expected_transformed_output = d8.apply(element, original_output)

            print(f"Element: {element}")

            # Compare shapes first: allclose would raise (or silently
            # broadcast) on mismatched shapes instead of reporting a failure.
            shapes_match = transformed_output.shape == expected_transformed_output.shape
            if shapes_match and torch.allclose(transformed_output, expected_transformed_output, atol=tolerance):
                continue

            print(f"Layer equivariance test failed for element: {element}")
            print(f"Transformed Input Shape: {transformed_input.shape}")
            print(f"Original Output Shape: {original_output.shape}")
            print(f"Transformed Output Shape: {transformed_output.shape}")
            print(f"Expected Transformed Output Shape: {expected_transformed_output.shape}")
            if shapes_match:
                print("Difference:")
                print(torch.mean(torch.abs(transformed_output - expected_transformed_output)))
            return False
    return True