In [1]:
import torch
# import sys
# sys.path.append("/home/sbalakri/PycharmProjects/PhD_1st_sem/e2cnn")
# from e2cnn import gspaces
# from e2cnn import nn
from matplotlib import pyplot as plt
import torch.nn as nn
hidden_channels = 1

In [2]:
class Conv_CNN(torch.nn.Module):
    
    def __init__(self, n_classes=10):
        
        super(Conv_CNN, self).__init__()
        
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        
        # convolution encoder 2
        self.block2 = nn.Sequential(
            nn.Conv2d(16, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True)
        )
        self.pool1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
        )
        
        # convolution encoder 3

        self.block3 = nn.Sequential(
            nn.Conv2d(8, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True)
        )
        
        # convolution encoder 4
        self.block4 = nn.Sequential(
            nn.Conv2d(4, 1, kernel_size=3, padding=1),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True)
        )
        self.pool2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=4)
        )
        
        
        
        
        # convolution decoder 1 block 1
        self.block3_dec = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=2, stride=1),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True)
        )

        self.block4_dec = nn.Sequential(
            nn.Conv2d(4, 8, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True)
        )
        self.pool2_dec = nn.Sequential(
            nn.Upsample(scale_factor=5/2)
        )
        
        
       # convolution decoder 2 block 1
        self.block1_dec = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )               
        # convolution decoder 2 block 2
        self.block2_dec = nn.Sequential(
            nn.Conv2d(16, 1, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True)
        )
        
        self.pool1_dec = nn.Sequential(
            nn.Upsample(scale_factor=29/15)
        )
        
    
    def encode(self, input: torch.Tensor):
        
        x = input
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)
        
        x = self.block3(x)         
        x = self.block4(x)
        x = self.pool2(x)
        
        return x 
    
    def decode(self, input: torch.Tensor):
        
        x = input
        x = self.block3_dec(x)
        x = self.block4_dec(x)
        x = self.pool2_dec(x)
        
        x = self.block1_dec(x)
        x = self.block2_dec(x)
        x = self.pool1_dec(x)
    
        return x 
    
    def forward(self, input: torch.Tensor):

        x = input
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)
        
        x = self.block3(x)
        x = self.block4(x)
        x = self.pool2(x)
        
        x = self.block3_dec(x)
        x = self.block4_dec(x)
        x = self.pool2_dec(x)
        
        x = self.block1_dec(x)
        x = self.block2_dec(x)
        x = self.pool1_dec(x)

        return x

In [3]:
# download the dataset
!wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip
# uncompress the zip file
!unzip -n mnist_rotation_new.zip -d mnist_rotation_new

File ‘mnist_rotation_new.zip’ already there; not retrieving.

Archive:  mnist_rotation_new.zip


In [4]:
from torch.utils.data import Dataset
from torchvision.transforms import RandomRotation
from torchvision.transforms import Pad
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose

import numpy as np

from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [5]:
class MnistRotDataset(Dataset):
    
    def __init__(self, mode, transform=None):
        assert mode in ['train', 'test']
            
        if mode == "train":
            file = "mnist_rotation_new/mnist_all_rotation_normalized_float_train_valid.amat"
        else:
            file = "mnist_rotation_new/mnist_all_rotation_normalized_float_test.amat"
        
        self.transform = transform

        data = np.loadtxt(file, delimiter=' ')
        if mode == "train":    
            self.images = data[:, :-1][:500].reshape(-1, 28, 28).astype(np.float32)
            self.labels = data[:, -1][:500].astype(np.int64)
        else:
            self.images = data[:, :-1][:30000].reshape(-1, 28, 28).astype(np.float32)
            self.labels = data[:, -1][:30000].astype(np.int64)
        self.num_samples = len(self.labels)
    
    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        image = Image.fromarray(image)
        if self.transform is not None:
            image = self.transform(image)
#         import pdb; pdb.set_trace()
        return image, label
    
    def __len__(self):
        return len(self.labels)

# images are padded to have shape 29x29.
# this allows to use odd-size filters with stride 2 when downsampling a feature map in the model
pad = Pad((0, 0, 1, 1), fill=0)
totensor = ToTensor()

In [6]:
model = Conv_CNN().to(device)

In [7]:
# from torchsummary import summary
# summary(model, (1, 29, 29))

In [8]:
def test_model(model: torch.nn.Module, x: Image):
    # evaluate the `model` on 8 rotated versions of the input image `x`
    model.eval()
    
    x = pad(x)
    
    print(x)
    print('##########################################################################################')
    header = 'angle |  ' + '  '.join(["{:6d}".format(d) for d in range(10)])
    print(header)
    with torch.no_grad():
        for r in range(8):
            x_transformed = totensor(x.rotate(r*45., Image.BILINEAR)).reshape(1, 1, 29, 29)
            x_transformed.to(device)

            y = model(x_transformed)
            y = y.to('cpu').numpy().squeeze()
            
            angle = r * 45
            print("{:5d} : {}".format(angle, y))
    print('##########################################################################################')
    print()

    
# build the test set    
mnist_test = MnistRotDataset(mode='test')

# retrieve the first image from the test set
x, y = next(iter(mnist_test))

# evaluate the model
# test_model(model, x)

In [9]:
root = './data'
import os
pad = Pad((0, 0, 1, 1), fill=0)
totensor = ToTensor()
if not os.path.exists(root):
    os.mkdir(root)
import torchvision.datasets as dset
import torchvision.transforms as transforms
trans = transforms.Compose([
    pad,
    totensor,
])

train_set = dset.MNIST(root=root, train=True, transform=trans, download=True)
train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=16,
                 shuffle=True)

test_transform = Compose([
    pad,
    totensor,
])
mnist_test = MnistRotDataset(mode='test', transform=test_transform)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=16)

loss_function = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.00001)

In [10]:
import time
t0 = time.time()
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.00001)
epoch = 0
for epoch in range(50):
    model.train()
    running_loss = 0
    num_samples = 0
    for i, (x, t) in enumerate(train_loader):
        
        if i>200:
            break
        optimizer.zero_grad()
        y = model(x)
        loss = loss_function(y, x)
        loss.backward()
        optimizer.step()   
        
        running_loss += loss.item()
        num_samples = num_samples + 1
    if epoch % 5 == 0:   # print every 2000 mini-batches
        print('[%d, %5d] loss: %.20f' %
              (epoch + 1, 0 , running_loss / num_samples))
              
    print(epoch, running_loss / num_samples)
    
print('{} seconds'.format(time.time() - t0))

[1,     0] loss: 0.08661569517791567474
0 0.08661569517791567
1 0.0535745929315019
2 0.048022759142354945
3 0.04413232930115799
4 0.04142974391208952
[6,     0] loss: 0.03959260869826843748
5 0.03959260869826844
6 0.03741387418698316
7 0.03680647712256482
8 0.036074053375652775
9 0.034603436537717115
[11,     0] loss: 0.03410158927241960930
10 0.03410158927241961
11 0.033982008315659874
12 0.033736320555358384
13 0.03317086393968086
14 0.03284355938731141
[16,     0] loss: 0.03272702271554304293
15 0.03272702271554304
16 0.03219154077716431
17 0.03215129692941459
18 0.03178376872185154
19 0.03185673639993763
[21,     0] loss: 0.03194062441439177863
20 0.03194062441439178
21 0.03183569898133847
22 0.03166331714065514
23 0.03146551709751881
24 0.031124192081503015
[26,     0] loss: 0.03086214082602837736
25 0.030862140826028377
26 0.030659768962074275
27 0.030762676742687747
28 0.03075851059856996
29 0.030566231389321498
[31,     0] loss: 0.03016493642767566902
30 0.03016493642767567
31 

In [11]:
enc_rep = torch.tensor([])
t_all = torch.tensor([])
for i, (x, t) in enumerate(train_loader):

    y = model.encode(x)
    enc_rep=torch.cat((enc_rep, y), 0)

    t = t.float()
    t_all = torch.cat((t_all, t), 0)
    
    if i > 200:
        break


In [12]:
# Export  
enc_rep_all = enc_rep.reshape(-1, hidden_channels*16).detach().numpy()
import pickle
features_path = "conv_autoe_features_conv_train.pickle"
labels_path = "conv_autoe_labels_conv_train.pickle"
orig_path = "conv_autoe_orig_features_conv_train.pickle"
pickle.dump(enc_rep_all, open(features_path, 'wb'))

t_all1 = t_all.detach().numpy().reshape(t_all.shape[0], 1)
pickle.dump(t_all1, open(labels_path, 'wb'))

In [13]:
enc_rep = torch.tensor([])
t_all = torch.tensor([])
for i, (x, t) in enumerate(test_loader):

    y = model.encode(x)
    enc_rep=torch.cat((enc_rep, y), 0)

    t = t.float()
    t_all = torch.cat((t_all, t), 0)


In [14]:
# Export  
enc_rep_all = enc_rep.reshape(-1, hidden_channels*16).detach().numpy()
import pickle
features_path = "conv_autoe_features_conv_test.pickle"
labels_path = "conv_autoe_labels_conv_test.pickle"
orig_path = "conv_autoe_orig_features_conv_test.pickle"
pickle.dump(enc_rep_all, open(features_path, 'wb'))

t_all1 = t_all.detach().numpy().reshape(t_all.shape[0], 1)
pickle.dump(t_all1, open(labels_path, 'wb'))

In [15]:
#gram pooling
enc_rep = enc_rep.reshape(-1, hidden_channels*8, 49)
encoded_rep_all_sum = torch.tensor([])
for s in range(enc_rep.shape[0]):   
    encoded_rep = enc_rep[s]
    all_encs =  torch.tensor([])
    sum_enc_rep = torch.empty(hidden_channels, hidden_channels, dtype=torch.float)
    for mm in range(hidden_channels):
        m = 8 * mm
        for i in range(hidden_channels - mm):
            sum_enc_rep[mm, i+mm] = torch.dot(encoded_rep[m], encoded_rep[8*i + m]) + torch.dot(encoded_rep[m+1],encoded_rep[8*i + m+1]) + torch.dot(encoded_rep[m+2], encoded_rep[8*i + m+2]) + torch.dot(encoded_rep[m+3], encoded_rep[8*i + m+3]) + torch.dot(encoded_rep[m+4], encoded_rep[8*i + m+4]) + torch.dot(encoded_rep[m+5],encoded_rep[8*i + m+5]) + torch.dot(encoded_rep[m+6],encoded_rep[8*i + m+6]) + torch.dot(encoded_rep[m+7],encoded_rep[8*i + m+7])
            sum_enc_rep[i+mm, mm] = sum_enc_rep[mm, i+mm]
    encoded_rep_all_sum = torch.cat((encoded_rep_all_sum, sum_enc_rep.reshape(1, hidden_channels*hidden_channels)), 0)


RuntimeError: shape '[-1, 8, 49]' is invalid for input of size 480000

In [None]:
# enc_rep_all = encoded_rep_all_sum.reshape(-1, hidden_channels*hidden_channels).detach().numpy()
# features_path = "conv_autoe_features_test2.pickle"
# labels_path = "conv_autoe_labels_test2.pickle"
# orig_path = "conv_autoe_orig_features_test2.pickle"
# pickle.dump(enc_rep_all, open(features_path, 'wb'))

# t_all1 = t_all.detach().numpy().reshape(t_all.shape[0], 1)
# pickle.dump(t_all1, open(labels_path, 'wb'))

In [None]:
# enc_rep_all = encoded_rep_all_sum.reshape(-1, hidden_channels*hidden_channels).detach().numpy()
# features_path = "conv_autoe_features_train2.pickle"
# labels_path = "conv_autoe_labels_train2.pickle"
# orig_path = "conv_autoe_orig_features_train2.pickle"
# pickle.dump(enc_rep_all, open(features_path, 'wb'))

# t_all1 = t_all.detach().numpy().reshape(t_all.shape[0], 1)
# pickle.dump(t_all1, open(labels_path, 'wb'))