In [1]:
import torch
from e2cnn import gspaces
from e2cnn import nn
import sys
from torch.utils.data import Dataset
from torchvision.transforms import RandomRotation
from torchvision.transforms import Pad
from torchvision.transforms import Resize
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose

import numpy as np

from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [2]:
class C4SteerableCNN(torch.nn.Module):
    
    def __init__(self, n_classes=10):
        
        super(C4SteerableCNN, self).__init__()
        
        # the model is equivariant under rotations by 90 degrees, modelled by C4
        self.r2_act = gspaces.Rot2dOnR2(N=4)
        
        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        
        # to store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type
        
        # convolution 1
        # first specify the output type of the convolutional layer
        # we choose 24 feature fields, each transforming under the regular representation of C4
        out_type = nn.FieldType(self.r2_act, 24*[self.r2_act.regular_repr])
        self.block1 = nn.SequentialModule(
            nn.MaskModule(in_type, 29, margin=1),
            nn.R2Conv(in_type, out_type, kernel_size=7, padding=1, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        
        #convolution 2
        in_type = self.block1.out_type
        out_type = nn.FieldType(self.r2_act, 48*[self.r2_act.regular_repr])
        
        self.block2 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        self.pool1 = nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        
        #convolution 3
        in_type = self.block2.out_type
        out_type = nn.FieldType(self.r2_act, 48*[self.r2_act.regular_repr])
        self.block3 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        
        #convolution 4
        in_type = self.block3.out_type
        out_type = nn.FieldType(self.r2_act, 96*[self.r2_act.regular_repr])
        self.block4 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )
        self.pool2 = nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        self.gpool = nn.GroupPooling(out_type)
        
        # number of output channels
        c = self.gpool.out_type.size
        
        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c*7*7, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
    
    def forward(self, input: torch.Tensor):
        # wrap the input tensor in a GeometricTensor
        # (associate it with the input type)
        x = nn.GeometricTensor(input, self.input_type)
        
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)
        
        x = self.block3(x)
        x = self.block4(x)
        x = self.pool2(x)
        
        # pool over the group
        x = self.gpool(x)
        # unwrap the output GeometricTensor
        # (take the Pytorch tensor and discard the associated representation)
        x = x.tensor
        
        x = self.fully_net(x.reshape(x.shape[0], -1))
        
        return x


In [4]:
class MnistRotDataset(Dataset):   
    def __init__(self, mode, transform=None):
        assert mode in ['train', 'test']
            
        if mode == "train":
            file = "mnist_rotation_new/mnist_all_rotation_normalized_float_train_valid.amat"
        else:
            file = "mnist_rotation_new/mnist_all_rotation_normalized_float_test.amat"
        
        self.transform = transform

        data = np.loadtxt(file, delimiter=' ')
            
        self.images = data[:, :-1].reshape(-1, 28, 28).astype(np.float32)
        self.labels = data[:, -1].astype(np.int64)
        self.num_samples = len(self.labels)
    
    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        image = Image.fromarray(image)
        if self.transform is not None:
            image = self.transform(image)
        return image, label
    
    def __len__(self):
        return len(self.labels)

# images are padded to have shape 29x29.
# this allows to use odd-size filters with stride 2 when downsampling a feature map in the model
pad = Pad((0, 0, 1, 1), fill=0)

# to reduce interpolation artifacts (e.g. when testing the model on rotated images),
# we upsample an image by a factor of 3, rotate it and finally downsample it again
resize1 = Resize(87)
resize2 = Resize(29)

# Defining the dataloaders
train_transform = Compose([pad, resize1,
    RandomRotation(180, resample=Image.BILINEAR, expand=False),
    resize2,ToTensor(),])

mnist_train = MnistRotDataset(mode='train', transform=train_transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64)

test_transform = Compose([pad,ToTensor(),])
mnist_test = MnistRotDataset(mode='test', transform=test_transform)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=64)


In [5]:
model = C4SteerableCNN().to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-5)



In [6]:
for epoch in range(31):
    model.train()
    for i, (x, t) in enumerate(train_loader):
        optimizer.zero_grad()

        x = x.to(device)
        t = t.to(device)

        y = model(x)
        #sys.exit()

        loss = loss_function(y, t)
        if i%5==0:
            print(f"Epoch numer {epoch} | Batch number {i}| Loss:",loss)

        loss.backward()

        optimizer.step()
    
    if epoch % 5 == 1:
        total = 0
        correct = 0
        with torch.no_grad():
            model.eval()
            for i, (x, t) in enumerate(test_loader):

                x = x.to(device)
                t = t.to(device)
                
                y = model(x)

                _, prediction = torch.max(y.data, 1)
                total += t.shape[0]
                correct += (prediction == t).sum().item()
        print(f"After epoch {epoch} | Test accuracy: {correct/total*100.}")

Epoch numer 0 | Batch number 0| Loss: tensor(2.5043, device='cuda:0', grad_fn=<NllLossBackward>)
Epoch numer 0 | Batch number 5| Loss: tensor(2.1237, device='cuda:0', grad_fn=<NllLossBackward>)
Epoch numer 0 | Batch number 10| Loss: tensor(1.9712, device='cuda:0', grad_fn=<NllLossBackward>)
Epoch numer 0 | Batch number 15| Loss: tensor(1.7131, device='cuda:0', grad_fn=<NllLossBackward>)
Epoch numer 0 | Batch number 20| Loss: tensor(1.7183, device='cuda:0', grad_fn=<NllLossBackward>)
Epoch numer 0 | Batch number 25| Loss: tensor(1.6067, device='cuda:0', grad_fn=<NllLossBackward>)
Epoch numer 0 | Batch number 30| Loss: tensor(1.6867, device='cuda:0', grad_fn=<NllLossBackward>)
Epoch numer 0 | Batch number 35| Loss: tensor(1.5074, device='cuda:0', grad_fn=<NllLossBackward>)
Epoch numer 0 | Batch number 40| Loss: tensor(1.4283, device='cuda:0', grad_fn=<NllLossBackward>)
Epoch numer 0 | Batch number 45| Loss: tensor(1.2501, device='cuda:0', grad_fn=<NllLossBackward>)
Epoch numer 0 | Batch 

KeyboardInterrupt: 