# Efficient Group Equivariant Network
This notebook contains the code for creating and comparing an efficient group equivariant CNN.
The architecture of this network is adapted from the NeurIPS 2021 paper "Efficient Equivariant Network" by Lingshen He, Yuxuan Chen, Zhengyang Shen, Yiming Dong, Yisen Wang, and Zhouchen Lin. Adaptations and improvements will be noted.

In [None]:
# import all necessary tools
# import libraries
import torch
import torchvision
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import math
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(device)

cuda


## Experiment
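This experiment will compare the cross-entropy loss and final classification accuracy of four neural networks. Each network will be trained and tested on the CIFAR-10 and MNIST datasets, both as-is and under the four 90° rotations of the cyclic group C4 (a discrete subgroup of the special orthogonal group SO(2)).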

The four neural networks are
1.  **Traditional CNN.**
  
    A traditional CNN is not rotation equivariant by construction; it can only learn approximate rotation invariance by being trained on every rotation of each input with the same label.

2.  **Group Equivariant CNN.**

    The group equivariant CNN (G-CNN) will be created using the approach proposed by He et al. in the NeurIPS 2021 paper "Efficient Equivariant Network".
3.  **Group Equivariant CNN with Equivariant Max Pooling.**

    Equivariant max pooling layers will be created and added using the approach proposed by Xu et al. in the NeurIPS 2021 paper "Group Equivariant Subsampling".

4.  **Group Equivariant CNN with Equivariant Max Pooling and Non-equivariant Attention.**

    Adding non-equivariant attention encoders allows the network to also learn non-equivariant features of the input. This matters because the output may depend on the relative rotation of the detected features.



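## SO(2) Rotations (C4 Subgroup)
Define a function that applies the group's rotations to input images; it is used to train and test our models.
Our group consists of the four rotations by multiples of 90° (0°, 90°, 180°, 270°), i.e. the cyclic group C4, a discrete subgroup of SO(2).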

In [None]:
def s02_r4(x):
  """Return the four 90-degree rotations (the C4 orbit) of a batch of images.

  Args:
    x: tensor of shape (b, c, h, w). Rotations act on the (h, w) plane, so h must
       equal w for all rotated copies to share one shape.

  Returns:
    Tensor of shape (b, 4, c, h, w) on the same device and with the same dtype as
    `x`; index k along dim 1 holds `torch.rot90(x, k, [2, 3])` (k = 0 is the original).
  """
  # torch.stack keeps x's device and dtype, so no float32 CPU staging buffer is needed.
  return torch.stack([torch.rot90(x, k, [2, 3]) for k in range(4)], dim=1)

In [None]:
# create the cifar10 datsets, transforms, group elements, and plot
class cifar():
  """CIFAR-10 train/test loaders plus helpers to un-normalize and plot images."""
  def __init__(self):
    self.mean = np.array([0.49139968, 0.48215827 ,0.44653124])
    self.std = np.array([0.24703233, 0.24348505, 0.26158768])
    self.t = T.Compose( [T.ToTensor(),    # convert images to tensor form (pushes the channel dimension to the front)
             T.Normalize( self.mean,      # normalize with the known mean
                          self.std) ])    # normalize with the known standard deviation
    # inverse of the Normalize above: x * std + mean
    self.inv_t = T.Normalize(mean=(-self.mean/self.std), std=(1/self.std))

    # set the batch size for training
    self.batch_size = 4

    # load the training set and create a data loader for the training set
    self.train_data = torchvision.datasets.CIFAR10(root='./data', download=True,train=True, transform=self.t)
    self.train_loader = torch.utils.data.DataLoader(self.train_data, batch_size=self.batch_size,
                                              shuffle=True, num_workers=2)

    self.test_data = torchvision.datasets.CIFAR10(root='./data', download=True, transform=self.t, train=False)
    self.test_loader = torch.utils.data.DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, num_workers=2)

    self.classes = ( 'plane', 'car', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck' )

  def transform_image(self, img):
    """Turn a normalized (C, H, W) tensor into an (H, W, C) numpy image in [0, 1] for imshow."""
    img = self.inv_t(img)    # unnormalize
    img = img.detach().cpu().numpy()
    img = np.transpose(img, (1,2,0))
    # Float rounding in the un-normalization can step slightly outside [0, 1];
    # matplotlib would clip such values and warn, so clip explicitly.
    return np.clip(img, 0.0, 1.0)

  def plot_rotations(self, images):
    """Plot one row per image and one column per rotation.

    images: tensor shaped (n_images, n_rotations, C, H, W), e.g. the output of s02_r4.
    """
    n_images, n_rotations = images.shape[0], images.shape[1]
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(n_images, n_rotations, figsize=(8,8), squeeze=False)
    for i in range(n_images):
      for j in range(n_rotations):
        ax = axes[i, j]
        ax.imshow(self.transform_image(images[i, j]))
        ax.axis('off')



# Instantiate the CIFAR-10 helper (downloads into ./data on first run). The lengths
# reported below are numbers of mini-batches, not of images.
Ciphar = cifar()
n_train_batches = len(Ciphar.train_loader)
n_test_batches = len(Ciphar.test_loader)
print(f"Training size is {n_train_batches}")
print(f"Testing size is {n_test_batches}")

In [None]:
# Take one CIFAR-10 training batch and display every image under all four C4 rotations.
dataiter = iter(Ciphar.train_loader)
images, labels = next(dataiter)
shape_before = images.shape
images = s02_r4(images)   # (batch, 4, channels, height, width)

print(f"Images before rotation: {shape_before}")
print(f"Images after rotation: {images.shape}")
Ciphar.plot_rotations(images)

In [None]:
class mnist():
  """MNIST train/test loaders plus helpers to plot images."""
  def __init__(self):
    self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                      torchvision.transforms.Normalize((0.1307,),(0.3081,))])

    self.train_dataset = torchvision.datasets.MNIST('data', train=True, download=True, transform=self.transform)
    self.test_dataset = torchvision.datasets.MNIST('data', train=False, download=True, transform=self.transform)

    self.batch_size = 4
    self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
    self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

  def transform_image(self, img):
    """Return the single channel of a (1, H, W) tensor as an (H, W) numpy array for imshow."""
    return img[0].detach().cpu().numpy()

  def plot_rotations(self, images):
    """Plot one row per image and one column per rotation.

    images: tensor shaped (n_images, n_rotations, 1, H, W), e.g. the output of s02_r4.
    """
    n_images, n_rotations = images.shape[0], images.shape[1]
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(n_images, n_rotations, figsize=(8,8), squeeze=False)
    for i in range(n_images):
      for j in range(n_rotations):
        ax = axes[i, j]
        ax.imshow(self.transform_image(images[i, j]))
        ax.axis('off')


# Instantiate the MNIST helper (downloads into data/ on first run). The lengths
# reported below are numbers of mini-batches, not of images.
Mnist = mnist()
n_train_batches, n_test_batches = len(Mnist.train_loader), len(Mnist.test_loader)
print(f"Training size is {n_train_batches}")
print(f"Testing size is {n_test_batches}")

In [None]:
# Take one MNIST training batch and display every digit under all four C4 rotations.
dataiter = iter(Mnist.train_loader)
images, labels = next(dataiter)
shape_before = images.shape
images = s02_r4(images)   # (batch, 4, 1, 28, 28)

print(f"Images before rotation: {shape_before}")
print(f"Images after rotation: {images.shape}")
Mnist.plot_rotations(images)

In [None]:
# Check that the rotations can be folded into the batch dimension with matching labels.
dataiter = iter(Mnist.train_loader)
images, labels = next(dataiter)

print(f"Images before rotation: {images.shape}")
images = s02_r4(images)
print(f"Images after rotation: {images.shape}")

b, r, c, h, w = images.shape
images = images.reshape(-1, c, h, w)
# The reshape orders the batch image-major (x0r0, x0r1, x0r2, x0r3, x1r0, ...), so each
# label is repeated consecutively (repeat_interleave), not tiled.
labels = labels.repeat_interleave(r)

print(f"Images after expansion into batch dimension: {images.shape}")
print(f"Labels after expansion for rotations: {labels.shape}")

# squeeze=False keeps `axes` 2-D (n_rows x 1) regardless of the batch size.
fig, axes = plt.subplots(images.shape[0], figsize=(15,15), squeeze=False)
for i in range(images.shape[0]):
  ax = axes[i, 0]
  ax.imshow(images[i][0])
  ax.set_title(f"Label: {labels[i].item()}")
  ax.axis('off')

plt.subplots_adjust(wspace=0.4, hspace=2)


In [None]:
# train the model and print the error periodically
class Trainer():
  """Train and evaluate a classifier on a pair of data loaders.

  `train` / `test` feed loader batches unchanged. `train_c4` / `test_c4` first expand
  every batch with `s02_r4` into its four 90-degree rotations, each copy keeping the
  label of the image it came from.

  Args:
    model: the nn.Module to train; moved to the global `device` before use.
    criterion: loss taking (logits, targets), e.g. nn.CrossEntropyLoss().
    optimizer: optimizer already bound to `model.parameters()`.
    train_loader, test_loader: loaders yielding (images, targets) batches.
  """
  def __init__(self, model, criterion, optimizer, train_loader, test_loader):
    self.model = model
    self.criterion = criterion
    self.optimizer = optimizer
    self.train_loader = train_loader
    self.test_loader = test_loader

  @staticmethod
  def _expand_c4(images, targets):
    """Fold the 4 rotations of every image into the batch: (b, c, h, w) -> (4b, c, h, w)."""
    b, c, h, w = images.shape
    images = s02_r4(images).reshape(-1, c, h, w).to(device)
    # The reshape orders the batch image-major (x0 r0..r3, x1 r0..r3, ...), so each label
    # must be repeated consecutively; `targets.repeat(4)` would tile the labels
    # (t0, t1, ..., t0, t1, ...) and pair most rotations with the wrong label.
    targets = targets.repeat_interleave(4).to(device)
    return images, targets

  def _fit(self, epochs, expand_c4, log_every):
    """Shared training loop; prints the mean loss of each window of `log_every` batches."""
    self.model.to(device)
    self.model.train() # we need to set the mode for our model
    n_samples = len(self.train_loader.dataset)
    for e in range(epochs):
      running_loss = 0.0
      seen = 0
      for batch_idx, (images, targets) in enumerate(self.train_loader):
        images, targets = images.to(device), targets.to(device)
        seen += len(images)   # un-rotated samples, comparable with n_samples
        if expand_c4:
          images, targets = self._expand_c4(images, targets)
        self.optimizer.zero_grad()
        output = self.model(images)
        loss = self.criterion(output, targets)
        loss.backward()
        self.optimizer.step()

        running_loss += loss.item()
        if (batch_idx + 1) % log_every == 0:
          print(f'Epoch {e}: [{seen}/{n_samples}] Loss: {running_loss/log_every}')
          running_loss = 0.0
    print("Finished training")

  def train(self, epochs):
    """Train for `epochs` epochs on the un-rotated data, logging every 200 batches."""
    self._fit(epochs, expand_c4=False, log_every=200)

  def train_c4(self, epochs):
    """Train for `epochs` epochs on all four rotations of every batch, logging every 20 batches."""
    self._fit(epochs, expand_c4=True, log_every=20)

  def _evaluate(self, expand_c4):
    """Return (mean per-sample loss, accuracy in percent) over the test loader."""
    self.model.to(device)
    self.model.eval() # we need to set the mode for our model
    total_loss = 0.0
    correct = 0
    n_samples = 0
    with torch.no_grad():
      for images, targets in self.test_loader:
        images, targets = images.to(device), targets.to(device)
        if expand_c4:
          images, targets = self._expand_c4(images, targets)
        output = self.model(images)
        # assumes the criterion averages over the batch (reduction='mean', the default),
        # so re-weight by batch size to obtain an exact per-sample mean
        total_loss += self.criterion(output, targets).item() * len(targets)
        correct += (output.argmax(dim=1) == targets).sum().item()
        n_samples += len(targets)
    n_samples = max(n_samples, 1)
    return total_loss / n_samples, 100. * correct / n_samples

  def test(self):
    """Evaluate on the un-rotated test set and print the average loss and accuracy."""
    test_loss, accuracy = self._evaluate(expand_c4=False)
    print(f'Test: Avg loss is {test_loss}, Accuracy: {accuracy}%')

  def test_c4(self):
    """Evaluate on all four rotations of the test set and print the average loss and accuracy."""
    test_loss, accuracy = self._evaluate(expand_c4=True)
    print(f'Test C4: Avg loss is {test_loss}, Accuracy: {accuracy}%')

## Network 1. Traditional CNN

In [None]:
class MyCNN(nn.Module):
  """LeNet-style baseline: two [5x5 conv -> ReLU -> 2x2 max-pool] stages and three linear layers.

  Args:
    in_channels: number of input image channels (1 for MNIST, 3 for CIFAR-10).
    in_h, in_w: input height and width, used to size the first linear layer.
  """
  def __init__(self, in_channels, in_h, in_w):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)

    # nn.MaxPool2d(2, 2) uses ceil_mode=False, so every pooling floors the size.
    def out_size(n):
      n = (n - 5 + 1) // 2     # conv1 (5x5, no padding) + pool
      return (n - 5 + 1) // 2  # conv2 (5x5, no padding) + pool

    self.fc1 = nn.Linear(16 * out_size(in_h) * out_size(in_w), 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = torch.flatten(x, 1) # flatten all dimensions except batch
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

Train the baseline CNN on the MNIST dataset

In [None]:
# Fresh baseline CNN for MNIST (1 x 28 x 28 inputs), cross-entropy loss and SGD with momentum.
BaselineCNN = MyCNN(in_channels=1, in_h=28, in_w=28).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=BaselineCNN.parameters(), lr=0.001, momentum=0.9)

# count the trainable parameters
par = sum(param.numel() for param in BaselineCNN.parameters() if param.requires_grad)
print(f"Total parameters of BaselineCNN: {par}")

Total parameters of BaselineCNN: 44426


In [None]:
# Train on the un-rotated MNIST images, then evaluate on the plain and rotated test sets.
BaselineCNN_mnist = Trainer(model=BaselineCNN, criterion=criterion, optimizer=optimizer,
                            train_loader=Mnist.train_loader, test_loader=Mnist.test_loader)
print("Baseline CNN no equivarience")
BaselineCNN_mnist.train(epochs=2)
for evaluate in (BaselineCNN_mnist.test, BaselineCNN_mnist.test_c4):
  evaluate()

In [None]:
# Re-initialise the MNIST baseline so the rotation-augmented run starts from scratch.
BaselineCNN = MyCNN(in_channels=1, in_h=28, in_w=28).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=BaselineCNN.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Train on all four rotations of every MNIST batch, then evaluate on plain and rotated test sets.
BaselineCNN_mnist = Trainer(model=BaselineCNN, criterion=criterion, optimizer=optimizer,
                            train_loader=Mnist.train_loader, test_loader=Mnist.test_loader)
print("Baseline CNN with equivarience training")
BaselineCNN_mnist.train_c4(epochs=1)
for evaluate in (BaselineCNN_mnist.test, BaselineCNN_mnist.test_c4):
  evaluate()

Train the baseline CNN on the CIFAR-10 dataset

In [None]:
# Fresh baseline CNN for CIFAR-10 (3 x 32 x 32 inputs), cross-entropy loss and SGD with momentum.
BaselineCNN = MyCNN(in_channels=3, in_h=32, in_w=32).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=BaselineCNN.parameters(), lr=0.001, momentum=0.9)

# count the trainable parameters
par = sum(param.numel() for param in BaselineCNN.parameters() if param.requires_grad)
print(f"Total parameters of BaselineCNN: {par}")

Total parameters of BaselineCNN: 62006


In [None]:
# Train on the un-rotated CIFAR-10 images, then evaluate on the plain and rotated test sets.
BaselineCNN_cifar = Trainer(model=BaselineCNN, criterion=criterion, optimizer=optimizer,
                            train_loader=Ciphar.train_loader, test_loader=Ciphar.test_loader)
print("Baseline CNN no equivarience")
BaselineCNN_cifar.train(epochs=1)
for evaluate in (BaselineCNN_cifar.test, BaselineCNN_cifar.test_c4):
  evaluate()

In [None]:
# Re-initialise the CIFAR-10 baseline so the rotation-augmented run starts from scratch.
BaselineCNN = MyCNN(in_channels=3, in_h=32, in_w=32).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=BaselineCNN.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Train on all four rotations of every CIFAR-10 batch, then evaluate on plain and rotated test sets.
BaselineCNN_cifar = Trainer(model=BaselineCNN, criterion=criterion, optimizer=optimizer,
                            train_loader=Ciphar.train_loader, test_loader=Ciphar.test_loader)
print("Baseline CNN with equivarience training")
BaselineCNN_cifar.train_c4(epochs=1)
for evaluate in (BaselineCNN_cifar.test, BaselineCNN_cifar.test_c4):
  evaluate()

## Network 2. Efficient Group Equivariant CNN
An efficient G-CNN can be built in two steps:


1.   Lifting Convolution
2.   G to G' mapping that reuses the same parameters





In [None]:
# encoder module

##################################################
class C_4_1x1(nn.Module):
    """Pointwise (1x1) group convolution between two C4 feature maps.

    Channels use the layout `feature * 4 + orientation`. One base bank of shape
    (out_channels, in_channels, 4) is shared by all output orientations: the bank for
    output orientation r is the base bank cyclically shifted by r along its
    input-orientation axis. No bias.
    """
    def __init__(self, in_channels, out_channels):
        super(C_4_1x1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # std = sqrt(2 / fan_in) with fan_in = 4 * in_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, 4) / math.sqrt(4 * in_channels / 2))

    def forward(self, x):
        # roll by r maps index j -> (j - r) mod 4, i.e. the permutations [0,1,2,3], [3,0,1,2], ...
        shifted = [torch.roll(self.weight, shifts=r, dims=-1) for r in range(4)]
        bank = torch.stack(shifted, dim=1).to(x.device)   # (out, 4, in, 4)
        bank = bank.reshape(self.out_channels * 4, self.in_channels * 4, 1, 1)
        return torch.nn.functional.conv2d(x, bank, stride=1, padding=0)

##################################################
class C_4_1x1_(nn.Module):
    """Pointwise (1x1) channel mixing applied identically to each of the 4 orientation copies.

    Input (b, in_channels * 4, h, w) with layout `feature * 4 + orientation`; output
    (b, out_channels * 4, h, w) in the same layout. Unlike C_4_1x1, orientations are not mixed.
    """
    def __init__(self, in_channels, out_channels):
        super(C_4_1x1_, self).__init__()
        self.net = nn.Conv3d(in_channels, out_channels, 1, bias=True)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        batch, channels, height, width = x.shape
        grouped = x.view(batch, channels // 4, 4, height, width)
        mixed = self.net(grouped)   # (batch, out_channels, 4, height, width)
        return mixed.reshape(batch, self.out_channels * 4, height, width)

##################################################
class C_4_3x3(nn.Module):
    """3x3 group convolution between two C4 feature maps (no padding: H and W shrink by 2).

    Channels use the layout `feature * 4 + orientation`. One base bank of shape
    (out, in, 4, 3, 3) is shared by all four output orientations: the filters for output
    orientation r are the base bank with its input-orientation axis cyclically shifted by r
    and its 3x3 taps rotated r times with torch.rot90. No bias.
    """
    def __init__(self, in_channels, out_channels):
        super(C_4_3x3, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # std = sqrt(2 / fan_in) with fan_in = 4 * in_channels * 9
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, 4, 3, 3) / math.sqrt(4 * in_channels * 9 / 2))

    def forward(self, x):
        per_orientation = [
            torch.rot90(torch.roll(self.weight, shifts=r, dims=2), r, [3, 4])
            for r in range(4)
        ]
        bank = torch.stack(per_orientation, dim=1).to(x.device)   # (out, 4, in, 4, 3, 3)
        bank = bank.reshape(self.out_channels * 4, self.in_channels * 4, 3, 3)
        return torch.nn.functional.conv2d(x, bank)

##################################################
class C_4_BN(nn.Module):
    """Batch norm for C4 feature maps.

    The input (b, in_channels * 4, h, w) is viewed as (b, in_channels, 4, h, w), so each
    feature gets one set of statistics and affine parameters shared by its 4 orientations.
    """
    def __init__(self, in_channels):
        super(C_4_BN, self).__init__()
        self.bn = nn.BatchNorm3d(in_channels)

    def forward(self, x):
        batch, channels, height, width = x.shape
        normed = self.bn(x.reshape(batch, channels // 4, 4, height, width))
        return normed.reshape(batch, channels, height, width)

##################################################
class C_4_Pool(nn.Module):
    """Max over the 4 orientation copies: (b, c, h, w) -> (b, c // 4, h, w)."""
    def __init__(self):
        super(C_4_Pool, self).__init__()
        self.pool = nn.MaxPool3d((4, 1, 1), (4, 1, 1))

    def forward(self, x):
        batch, channels, height, width = x.shape
        grouped = x.reshape(batch, channels // 4, 4, height, width)
        pooled = self.pool(grouped)   # (batch, channels // 4, 1, height, width)
        return pooled.squeeze(2)

##################################################
class E4_C4(nn.Module):
    """C4 layer with input-dependent, per-pixel kernels (adapted from He et al., E4-Net).

    A k x k kernel is predicted at every pixel from the input by two pointwise C4
    layers (`conv1`, `conv2`). The kernel belonging to orientation r is rotated r
    times with torch.rot90, then applied to the unfolded k x k neighbourhood of the
    value projection `v(x)`. The spatial size is preserved (Unfold with stride 1 and
    padding (k - 1) // 2), which requires an odd `kernel_size`; otherwise the `.view`
    calls in forward fail.

    Args:
        in_channels: features per orientation; the input has in_channels * 4 channels.
        out_channels: features per orientation of the output (out_channels * 4 channels).
        kernel_size: side length k of the generated kernels.
        reduction_ratio: bottleneck factor of the kernel-generating branch.
        groups: number of output features sharing one generated kernel (stored as
            `group_channels`); `out_channels` must be divisible by it for the
            `.view` in forward to succeed.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 reduction_ratio=2,
                 groups=1
                 ):

        super(E4_C4, self).__init__()
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.reduction_ratio = reduction_ratio
        # Note the naming: the `groups` argument is the number of channels per group,
        # and `self.groups` is the resulting number of kernel groups.
        self.group_channels = groups
        self.groups = self.out_channels // self.group_channels
        self.dim_g = 4  # size of the rotation group; not read elsewhere in this class

        # value projection: (b, in*4, h, w) -> (b, out*4, h, w)
        self.v = nn.Sequential(C_4_1x1(in_channels, out_channels))
        # kernel-generating branch; GroupNorm normalises each feature's 4 orientation copies together
        self.conv1 = nn.Sequential(C_4_1x1(in_channels, int(in_channels // reduction_ratio)),
                                    nn.GroupNorm(int(in_channels // reduction_ratio),int(in_channels // reduction_ratio)*4), nn.ReLU())
        # -> (b, k*k*groups*4, h, w): one k x k kernel per group, pixel and orientation
        self.conv2 = nn.Sequential(C_4_1x1_(int(in_channels // reduction_ratio), kernel_size ** 2 * self.groups))

        # extracts k x k neighbourhoods with "same" padding for odd k
        self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride=1)

    def forward(self, x):
        # predicted kernels: (b, groups * k * k * 4, h, w), channel = kernel_index * 4 + orientation
        weight = self.conv2(self.conv1(x))
        b, c, h, w = weight.shape
        # -> (b, groups, k, k, 4, h, w); dim 4 indexes the orientation r
        weight = weight.view(b, self.groups, self.kernel_size, self.kernel_size, 4, h, w)
        # rotate the k x k kernel of orientation r by r quarter turns (in place on the view)
        weight[::, ::, ::, ::, 1, ::, ::] = torch.rot90(weight[::, ::, ::, ::, 1, ::, ::], 1, [2, 3])
        weight[::, ::, ::, ::, 2, ::, ::] = torch.rot90(weight[::, ::, ::, ::, 2, ::, ::], 2, [2, 3])
        weight[::, ::, ::, ::, 3, ::, ::] = torch.rot90(weight[::, ::, ::, ::, 3, ::, ::], 3, [2, 3])
        # -> (b, groups, 1, 4, k*k, h, w); the singleton dim broadcasts over group_channels
        weight = weight.reshape(b, self.groups, self.kernel_size ** 2, 4, h, w).unsqueeze(2).transpose(3, 4)
        x = self.v(x)
        # neighbourhoods of v(x): (b, groups, group_channels, 4, k*k, h, w)
        out = self.unfold(x).view(b, self.groups, self.group_channels, 4, self.kernel_size ** 2, h, w)
        # weighted sum over the k*k neighbourhood -> (b, out_channels * 4, h, w)
        out = (weight * out).sum(dim=4).view(b, self.out_channels * 4, h, w)
        return out

##################################################
class C_4_Conv(nn.Module):
    """Lifting 3x3 convolution from a plain image to a C4 feature map (padding=1, size kept).

    One base bank of shape (out_channels, in_channels, 3, 3) is rotated by 0, 1, 2 and 3
    quarter turns. The output has out_channels * 4 channels with the layout
    `feature * 4 + orientation`. No bias.
    """
    def __init__(self, in_channels, out_channels):
        super(C_4_Conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # std = sqrt(2 / fan_in) with fan_in = 9 * in_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, 3, 3) / math.sqrt(9 * in_channels / 2))

    def forward(self, input):
        rotated = [torch.rot90(self.weight, quarter_turns, [2, 3]) for quarter_turns in range(4)]
        bank = torch.stack(rotated, dim=1).to(input.device)   # (out, 4, in, 3, 3)
        bank = bank.reshape(self.out_channels * 4, self.in_channels, 3, 3)
        return nn.functional.conv2d(input, bank, padding=1)

##################################################
class E4_net(nn.Module):
    """E4-style C4-equivariant classifier with 10 output logits.

    Pipeline: lifting conv -> six E4_C4 layers with two 2x max-pools -> max over
    orientations -> global spatial max -> linear layer.

    Args:
        in_channels: image channels.
        kernel_size, groups, reduction_ratio: forwarded to every E4_C4 layer.
        drop: dropout probability applied after conv3..conv7.

    Note: bn1..bn7 are all constructed, but forward only uses bn2 and bn5.
    """
    def __init__(self, in_channels=1, kernel_size=5, groups=8, reduction_ratio=1, drop=0.2):
        super(E4_net, self).__init__()
        width = 16

        # lifting convolution
        self.conv1 = C_4_Conv(in_channels, width)

        # conv2 .. conv7: E4 layers (registered in the same order as explicit assignments)
        for idx in range(2, 8):
            setattr(self, f"conv{idx}", E4_C4(width, width, kernel_size, reduction_ratio=reduction_ratio, groups=groups))
        self.pool = nn.MaxPool2d(2, 2)

        # bn1 .. bn7
        for idx in range(1, 8):
            setattr(self, f"bn{idx}", C_4_BN(width))

        self.drop = nn.Dropout(drop)

        self.group_pool = C_4_Pool()
        self.global_pool = nn.AdaptiveMaxPool2d(1)

        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(width, 10),
        )

    def forward(self, x):
        relu = torch.relu
        x = relu(self.conv1(x))
        x = self.pool(relu(self.bn2(self.conv2(x))))
        for layer in (self.conv3, self.conv4):
            x = self.drop(relu(layer(x)))
        x = self.pool(x)
        x = self.drop(relu(self.bn5(self.conv5(x))))
        for layer in (self.conv6, self.conv7):
            x = self.drop(relu(layer(x)))
        x = self.group_pool(x)                                   # (b, 16, h, w)
        x = self.global_pool(x).reshape(x.size(0), -1)           # (b, 16)
        return self.fully_net(x)

##################################################
class C4Basic(nn.Module):
    """Small C4 network: lifting conv, four 3x3 C4 convs, 2x max-pool, three linear layers.

    Args:
        in_h, in_w: input height and width (used to size fc1).
        in_channels: image channels.
        kernel, reduction, groups: accepted for signature parity with the E4 models;
            not used by this network.
    """
    def __init__(self, in_h, in_w, in_channels, kernel, reduction, groups):
        super().__init__()

        # lifting conv: (b, in_channels, H, W) -> (b, 16 * 4, H, W) (padding=1)
        self.conv1=C_4_Conv(in_channels, 16)

        self.forward_function1 = nn.Sequential(
            C_4_3x3(16, 16),
            C_4_BN(16),
            nn.ReLU(inplace=True),
            C_4_3x3(16, 16),
            C_4_BN(16)
        )

        self.forward_function2 = nn.Sequential(
            C_4_3x3(16, 16),
            C_4_BN(16),
            nn.ReLU(inplace=True),
            C_4_3x3(16, 4),
            C_4_BN(4)
        )

        self.pool = nn.MaxPool2d(2, 2)
        # Four unpadded 3x3 C4 convs remove 8 pixels per side; nn.MaxPool2d(2, 2) has
        # ceil_mode=False, so the pooled size is floored.
        h = (in_h - 8) // 2
        w = (in_w - 8) // 2

        in_nodes = 4 * 4 * h * w   # 4 features x 4 orientations per pixel
        self.fc1 = nn.Linear(in_nodes, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.forward_function1(x))
        x = F.relu(self.forward_function2(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
# Smoke tests: run each C4 building block and both full models on one CIFAR-10 batch
# (on the CPU) and check the output shapes.
dataiter = iter(Ciphar.train_loader)
images, labels = next(dataiter)
n = images.shape[0]
# 12 channels = 3 features x 4 orientations; a shape-only input, not real C4 features.
images_c4 = images.repeat(1, 4, 1, 1)
print(f"Inputs: images={images.shape}, label={labels.shape}")

TestC4_1x1 = C_4_1x1(3,1)
out_a = TestC4_1x1(images_c4)
assert out_a.shape == (n, 4, 32, 32)
print(f"C_4_1x1 outputs: {out_a.shape}")

TestC4_1x1_ = C_4_1x1_(3,2)
out_b = TestC4_1x1_(images_c4)
assert out_b.shape == (n, 8, 32, 32)
print(f"C_4_1x1_ outputs: {out_b.shape}")

TestC4_3x3 = C_4_3x3(3,3)
out_c = TestC4_3x3(images_c4)
assert out_c.shape == (n, 12, 30, 30)   # unpadded 3x3 conv
print(f"TestC4_3x3 outputs: {out_c.shape}")

TestEC4 = E4_C4(3,4,5, reduction_ratio=2, groups=1)
out_d = TestEC4(images_c4)
assert out_d.shape == (n, 16, 32, 32)
print(f"TestEC4 outputs: {out_d.shape}")

TestC4Basic = C4Basic(
                      in_h=32,
                      in_w=32,
                      in_channels=3,
                      kernel=5,
                      reduction=1,
                      groups=1)
out_e = TestC4Basic(images)
assert out_e.shape == (n, 10)
print(f"TestC4Basic outputs: {out_e.shape}")

TestE4_net = E4_net(in_channels=3,kernel_size=5, groups=8, reduction_ratio=1, drop=0.2)
out_f = TestE4_net(images)
assert out_f.shape == (n, 10)
print(f"TestE4_net outputs: {out_f.shape}")


Basic equivariant network (C4Basic) trained without rotation augmentation

In [None]:
# C4Basic for MNIST (1 x 28 x 28); kernel/reduction/groups are not used by C4Basic.
C4BasicNet = C4Basic(
                      in_h=28,
                      in_w=28,
                      in_channels=1,
                      kernel=5,
                      reduction=2,
                      groups=1).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(C4BasicNet.parameters(), lr=0.001, momentum=0.9)

# print total number of parameters
par = sum(p.numel() for p in C4BasicNet.parameters() if p.requires_grad)
print(f"Total parameters of C4BasicNet: {par}")

In [None]:
# Train C4Basic on un-rotated MNIST, then evaluate on the plain and rotated test sets.
C4BasicNet_mnist = Trainer(C4BasicNet, criterion, optimizer, Mnist.train_loader, Mnist.test_loader)
print("C4Basic no equivariance training")
C4BasicNet_mnist.train(1)
C4BasicNet_mnist.test()
C4BasicNet_mnist.test_c4()

Basic equivariant network (C4Basic) trained with rotation augmentation (all four C4 rotations of every image)

In [None]:
# Fresh C4Basic for the rotation-augmented MNIST run (kernel/reduction/groups are unused).
C4BasicNet = C4Basic(
                      in_h=28,
                      in_w=28,
                      in_channels=1,
                      kernel=5,
                      reduction=2,
                      groups=1).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(C4BasicNet.parameters(), lr=0.001, momentum=0.9)

# print total number of parameters
par = sum(p.numel() for p in C4BasicNet.parameters() if p.requires_grad)
print(f"Total parameters of C4BasicNet: {par}")

In [None]:
# Train C4Basic on all four rotations of every MNIST batch, then evaluate.
C4BasicNet_mnist = Trainer(C4BasicNet, criterion, optimizer, Mnist.train_loader, Mnist.test_loader)
print("C4Basic with equivariance training")
C4BasicNet_mnist.train_c4(1)
C4BasicNet_mnist.test()
C4BasicNet_mnist.test_c4()

In [None]:
# E4_net for MNIST (in_channels defaults to 1), cross-entropy loss and SGD with momentum.
TestE4_net = E4_net(kernel_size=5, groups=8, reduction_ratio=1, drop=0.2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=TestE4_net.parameters(), lr=0.001, momentum=0.9)

# print total number of parameters
par = sum(param.numel() for param in TestE4_net.parameters() if param.requires_grad)
print(f"Total parameters of E4_net: {par}")

In [None]:
# Train E4_net on un-rotated MNIST, then evaluate on the plain and rotated test sets.
E4_net_mnist = Trainer(model=TestE4_net, criterion=criterion, optimizer=optimizer,
                       train_loader=Mnist.train_loader, test_loader=Mnist.test_loader)
print("E4_net no equivarience")
E4_net_mnist.train(epochs=1)
for evaluate in (E4_net_mnist.test, E4_net_mnist.test_c4):
  evaluate()

In [None]:
# Fresh E4_net so the rotation-augmented run is not contaminated by the model trained
# in the previous cell.
TestE4_net = E4_net(kernel_size=5, groups=8, reduction_ratio=1, drop=0.2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(TestE4_net.parameters(), lr=0.001, momentum=0.9)

E4_net_mnist = Trainer(TestE4_net, criterion, optimizer, Mnist.train_loader, Mnist.test_loader)
print("E4_net with equivariance training")
E4_net_mnist.train_c4(1)
E4_net_mnist.test()
E4_net_mnist.test_c4()

## Network 3. Efficient Group Equivariant CNN with Equivariant Max Pooling

In [None]:
import torch
class EquivariantSubsample(torch.nn.Module):
  """Strided max-pooling whose sampling grid is anchored at the input's arg-max.

  Following the idea of Xu et al. ("Group Equivariant Subsampling"), the phase of the
  pooling grid is read from the input: for every (batch, channel), the position of the
  global maximum modulo the stride gives the row/column offset of the first window.
  NOTE(review): the offset is chosen per channel rather than shared across channels.

  Args:
    reduction: (width_reduction, height_reduction) window size and stride. Note the
      order: reduction[0] is the width factor.
  """
  def __init__(self, reduction=(2,2)):
    super().__init__()
    self.width_reduction = reduction[0]
    self.height_reduction = reduction[1]

  def max(self, sample):
    """Global spatial max of a (b, c, h, w) tensor.

    Returns:
      max_values: (b, c) maxima.
      max_indices: (b, c, 2) (row, col) positions of those maxima.
    """
    row_max, col_of_row_max = torch.max(sample, dim=3)   # (b, c, h)
    max_values, h = torch.max(row_max, dim=2)             # (b, c)
    w = torch.gather(col_of_row_max, dim=2, index=h.unsqueeze(-1)).squeeze(-1)
    max_indices = torch.stack((h, w), dim=-1)
    return max_values, max_indices

  def get_p(self, image):
    """Return (row_offsets, col_offsets), each (b, c): the arg-max position modulo the stride."""
    _, max_index = self.max(image)
    max_h_idx = max_index[:, :, 0] % self.height_reduction
    max_w_idx = max_index[:, :, 1] % self.width_reduction
    return max_h_idx, max_w_idx

  def block_pool(self, img, y_offsets, x_offsets):
    """Max over one (height_reduction x width_reduction) window per (batch, channel).

    y_offsets, x_offsets: (b, c) top-left corners, clamped so the window stays inside
    the image. Returns a (b, c) tensor on img's device.
    """
    batch_size, channels, height, width = img.shape
    h = self.height_reduction
    w = self.width_reduction
    dev = img.device

    # Clamp to ensure offsets are within bounds for the pool_size
    y_offsets = torch.clamp(y_offsets, 0, height - h).to(dev)
    x_offsets = torch.clamp(x_offsets, 0, width - w).to(dev)

    # absolute row/col of every window element; broadcasting yields (b, c, h, w) indices
    y_indices = y_offsets[:, :, None, None] + torch.arange(h, device=dev).view(1, 1, h, 1)
    x_indices = x_offsets[:, :, None, None] + torch.arange(w, device=dev).view(1, 1, 1, w)

    block = img[torch.arange(batch_size, device=dev).view(-1, 1, 1, 1),
                torch.arange(channels, device=dev).view(1, -1, 1, 1),
                y_indices, x_indices]
    block_max, _ = self.max(block)
    return block_max

  def forward(self, images, p):
    """Subsample (b, c, h, w) -> (b, c, h // height_reduction, w // width_reduction).

    p: the (row_offsets, col_offsets) pair returned by `get_p`.
    """
    b, c, h, w = images.shape
    dev = images.device
    # get_p returns (row offsets, column offsets): p[0] drives height, p[1] drives width.
    p_h = p[0].unsqueeze(-1).to(dev)
    p_w = p[1].unsqueeze(-1).to(dev)

    h_sample_indices = p_h + torch.arange(0, h, self.height_reduction, device=dev)
    w_sample_indices = p_w + torch.arange(0, w, self.width_reduction, device=dev)

    out_h = h // self.height_reduction
    out_w = w // self.width_reduction

    output = torch.zeros(b, c, out_h, out_w, dtype=images.dtype, device=dev)
    for i in range(out_h):
      for j in range(out_w):
        output[:, :, i, j] = self.block_pool(images, h_sample_indices[:, :, i], w_sample_indices[:, :, j])
    return output

In [None]:
# Sanity-check EquivariantSubsample on a small hand-written input of shape (4, 3, 5, 5).
sub = EquivariantSubsample(reduction=(2,2))
test_image = torch.tensor([[
                [15,2,3,4,5],
                [6,7,8,9,10],
                [11,12,13,14,15],
                [16,17,18,19,20],
                [21,22,23,24,25]
              ],

                [[1,2,3,4,5],
                [6,7,8,9,10],
                [11,12,13,14,15],
                [16,17,2,25,20],
                [21,22,22,24,23]],

              [[1,2,3,4,5],
                [6,7,8,9,10],
                [11,12,13,14,15],
                [16,17,18,19,20],
                [21,22,23,25,24]
              ]
              ]).repeat(4,1,1,1).to(device)
print(test_image)
print(test_image.shape)
m, indices = sub.max(test_image)
p = sub.get_p(test_image)
print(f"max height: {p[0]}, max width: {p[1]}")
print(m)
print(indices)

o = sub(test_image, p)
assert o.shape == (4, 3, 2, 2)
print(f"Output: {o}")


Equivariant network with equivariant subsampling

In [None]:
class E4_Pooling(nn.Module):
    """E4-style C4 network that downsamples with EquivariantSubsample instead of MaxPool2d.

    Pipeline: lifting conv -> [3x3 C4 conv -> equivariant 2x subsample] x 2 ->
    three E4_C4 layers with C4 batch norm -> max over orientations -> two small E4_C4
    "reduction" layers -> flatten -> two linear layers -> 10 logits.

    Args:
        in_w: input side length; the input is assumed square because the classifier
            size is derived from in_w only.
        in_channels: image channels.
        kernel_size, groups, reduction_ratio: forwarded to the E4_C4 layers.
        drop: accepted for signature parity with E4_net; not used by this model.
    """
    def __init__(self, in_w, in_channels=1, kernel_size=5, groups=8, reduction_ratio=1, drop=0.2):
        super(E4_Pooling, self).__init__()

        # lifting convolution
        self.liftingConv=C_4_Conv(in_channels, 16)

        # conv 3x3 + equivariant pooling
        self.conv3x3_1 = C_4_3x3(16,16)
        self.pool1 = EquivariantSubsample(reduction=(2,2))
        self.conv3x3_2 = C_4_3x3(16,16)
        self.pool2 = EquivariantSubsample(reduction=(2,2))

        # E4 network
        self.conv1=E4_C4(16, 16, kernel_size, reduction_ratio=reduction_ratio, groups=groups)
        self.conv2=E4_C4(16, 16, kernel_size, reduction_ratio=reduction_ratio, groups=groups)
        self.conv3=E4_C4(16, 16, kernel_size, reduction_ratio=1, groups=groups)

        # applied after group pooling: 16 channels -> 8 -> 4
        self.ReductionConv1=E4_C4(4, 2, kernel_size, reduction_ratio=1, groups=1)
        self.ReductionConv2=E4_C4(2, 1, kernel_size, reduction_ratio=1, groups=1)

        # norm functions
        self.bn1=C_4_BN(16)
        self.bn2=C_4_BN(16)
        self.bn3=C_4_BN(16)

        self.group_pool=C_4_Pool()
        # side length after [unpadded 3x3 conv (-2) -> subsample (//2)] applied twice
        final_size = ((((in_w-2)//2 ) - 2) // 2)**2
        # NOTE(review): no nonlinearity between the two Linear layers, so they compose to one affine map
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(final_size*4, 36),
            torch.nn.Linear(36, 10)
        )

    def forward(self, x):
        # lifting conv: (b, in_channels, s, s) -> (b, 64, s, s)
        x = torch.relu(self.liftingConv(x))

        # 3x3 conv 1 + equivariant subsampling
        x = self.conv3x3_1(x)
        p1 = self.pool1.get_p(x)
        x = self.pool1(x, p1)

        # 3x3 conv 2 + equivariant subsampling
        x = self.conv3x3_2(x)
        p2 = self.pool2.get_p(x)
        x = self.pool2(x, p2)

        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = torch.relu(self.bn3(self.conv3(x)))

        # max over the 4 orientations: 64 -> 16 channels
        x = self.group_pool(x)
        # NOTE(review): the 16 group-pooled channels are re-read as 4 features x 4 orientations here
        x = self.ReductionConv1(x)   # -> 8 channels
        x = self.ReductionConv2(x)   # -> 4 channels

        x = torch.flatten(x, 1)
        return self.fully_net(x)

In [None]:
# E4_Pooling for CIFAR-10 (3 x 32 x 32 inputs), cross-entropy loss and SGD with momentum.
E4_Pooling_net = E4_Pooling(in_w=32, in_channels=3, kernel_size=5, groups=8, reduction_ratio=1, drop=0.2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(E4_Pooling_net.parameters(), lr=0.001, momentum=0.9)

# print total number of parameters
par = sum(p.numel() for p in E4_Pooling_net.parameters() if p.requires_grad)
print(f"Total parameters of E4_Pooling_net: {par}")

In [None]:
# Smoke test: push one CIFAR-10 batch through the E4_Pooling model built above.
dataiter = iter(Ciphar.train_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)
print(f"Inputs: images={images.shape}, label={labels.shape}")
out_a = E4_Pooling_net(images)
assert out_a.shape == (images.shape[0], 10)
print(f"Final output: {out_a.shape}")
print(out_a)

In [None]:
# E4_Pooling on MNIST trained without rotation augmentation.
E4_Pooling_net = E4_Pooling(in_w=28,in_channels=1, kernel_size=5, groups=8, reduction_ratio=1, drop=0.2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(E4_Pooling_net.parameters(), lr=0.001, momentum=0.9)

# print total number of parameters
par = sum(p.numel() for p in E4_Pooling_net.parameters() if p.requires_grad)
print(f"Total parameters of E4_Pooling_net: {par}")

E4_Pooling_net_mnist = Trainer(E4_Pooling_net, criterion, optimizer, Mnist.train_loader, Mnist.test_loader)
print("E4_Pooling_net no equivariance training")
E4_Pooling_net_mnist.train(2)
E4_Pooling_net_mnist.test()
E4_Pooling_net_mnist.test_c4()

In [None]:
# Fresh E4_Pooling on MNIST trained on all four rotations of every batch.
E4_Pooling_net = E4_Pooling(in_w=28,in_channels=1, kernel_size=5, groups=8, reduction_ratio=1, drop=0.2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(E4_Pooling_net.parameters(), lr=0.001, momentum=0.9)

# print total number of parameters
par = sum(p.numel() for p in E4_Pooling_net.parameters() if p.requires_grad)
print(f"Total parameters of E4_Pooling_net: {par}")

E4_Pooling_net_mnist = Trainer(E4_Pooling_net, criterion, optimizer, Mnist.train_loader, Mnist.test_loader)
print("E4_Pooling_net with equivariance training")
E4_Pooling_net_mnist.train_c4(2)
E4_Pooling_net_mnist.test()
E4_Pooling_net_mnist.test_c4()

In [None]:
# Re-evaluate the most recently trained E4_Pooling model on the plain and rotated test sets.
for evaluate in (E4_Pooling_net_mnist.test, E4_Pooling_net_mnist.test_c4):
  evaluate()

## Network 4. Efficient Group Equivariant CNN with Equivariant Max Pooling and Non-equivariant Attention Layers