In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
from collections import OrderedDict, abc
from itertools import repeat

from typing import Type, Callable, List

# Inference and learning final project

## Introduction


## Imports

In [3]:
# out of memory,
#try to reimplement according to https://github.com/keras-team/keras/blob/v2.11.0/keras/layers/locally_connected/locally_connected2d.py#L34
class LocalLinear_custom(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int,padding : int, image_size: int):
        super(LocalLinear_custom, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        fold_num = (image_size+ 2* padding - self.kernel_size)//self.stride+1
        self.weight = nn.Parameter(torch.randn( in_channels, fold_num,fold_num, kernel_size, kernel_size, out_channels))
        self.bias = nn.Parameter(torch.randn(fold_num,fold_num,out_channels))

    def forward(self, x:torch.Tensor):
        #print("intial xshape",x.shape)
        #print("weight",self.weight.shape)
        #print("bias",self.bias.shape)
        x = F.pad(x, [self.padding]*4, value= 0)
        #print("pad",x.shape)
        x = x.unfold(2, self.kernel_size, self.stride).unfold(3, self.kernel_size, self.stride)
        #print("unfold",x.shape)
        x = x.unsqueeze(1)
        x = torch.matmul(x,self.weight).squeeze(2).sum([-1,-2,-5])
        #print("matmul",x.shape)
        x = x + self.bias
        #print("final xshape",x.shape,end="\n\n")
        return x

class view_custom(nn.Module):
    def __init__(self,args):
        super().__init__()
        self.args = args
    def forward(self, x):
        return x.view(*self.args)


In [4]:
def funConv2d(size_in: int,size_out: int, kernel_size: int, stride: int, image_size: int ) -> tuple[int,int,List[nn.Module]]:
    padding = kernel_size//2
    
    modules = [ 
    nn.Conv2d(size_in,size_out,kernel_size,stride = stride,padding = padding),
    nn.BatchNorm2d(size_out),
    nn.ReLU()]
    im_out = (image_size+ 2 * padding - kernel_size)//stride+1
    return im_out,size_out,modules

def funLocalLinear(size_in: int,size_out: int, kernel_size: int, stride: int, image_size: int ) -> tuple[int,int,List[nn.Module]]:
    padding = kernel_size//2
    modules = [ 
    LocalLinear_custom(size_in,size_out,kernel_size,stride = stride,padding = padding,image_size = image_size),
    nn.BatchNorm2d(size_out),
    nn.ReLU()]
    im_out = (image_size+ 2* padding - kernel_size)//stride + 1
    return im_out,size_out,modules

def funFullyConnected(size_in: int,size_out: int, kernel_size: int, stride: int, image_size: int ) -> tuple[int,int,List[nn.Module]]:
    modules = []
    modules.extend([
    view_custom([-1,image_size*image_size*size_in]),
    nn.Linear(size_in * image_size**2,size_out * (image_size//stride)**2),
    view_custom([-1,size_out,image_size//stride,image_size//stride]),
    nn.BatchNorm2d(size_out),
    nn.ReLU()])
    im_out = image_size//stride
    return im_out,size_out,modules

In [5]:
class D_Conv(nn.Module):
    ### only works with square image, kernel size and stride
    def __init__(self, base_channel_size: int, loc : Callable[[int,int,int,int,int],tuple[int,int,List[nn.Module]]], image_size :int ) -> None:
        super().__init__()
        
        alpha = base_channel_size
        channel_size :int = 3
        im_size : int = image_size
        
        modules : List[nn.Module] = []
        size_conv = [
            (alpha,1),
            (2*alpha,2), #remove stride for now
            (2*alpha,1),
            (4*alpha,2), #here too
            (4*alpha,1),
            (8*alpha,2), # here too
            (8*alpha,1),
            (16*alpha,2)] # here too
        
        for i,val in enumerate(size_conv):
            size_out, stride = val
            im_size,channel_size,new_module = loc(channel_size,size_out,3,stride,im_size)
            modules.extend(new_module)
            
        self.conv = nn.Sequential(*modules)
        im_size,channel_size,modules = funFullyConnected(channel_size,64*alpha,0,1,im_size)
        self.fc = nn.Sequential(*modules)
        self.fc_final = nn.Linear(channel_size*im_size**2,10)

    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        x = torch.flatten(x, 1)
        x = self.fc_final(x)
        return x

class S_Conv(nn.Module):
    def __init__(self, base_channel_size: int, loc : Callable[[int,int,int,int,int],tuple[int,int,List[nn.Module]]], image_size :int) -> None:
        super().__init__()
        
        alpha = base_channel_size
        channel_size :int = 3
        im_size : int = image_size
        
        im_size,channel_size, modules = loc(channel_size,alpha,9,2,im_size)
        self.conv =  nn.Sequential(*modules)
        
        im_size,channel_size, modules = funFullyConnected(channel_size,24*alpha,0,1,im_size)
        self.fc = nn.Sequential(*modules)
        
        self.fc_final = nn.Linear(channel_size*im_size**2,10)

    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        x = torch.flatten(x, 1)
        x = self.fc_final(x)
        return x

In [6]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
batch_size = 512
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size,shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
networks = [(D_Conv,"D_Conv"),(S_Conv,"S_Conv")]
convolutions = [(funConv2d,"Conv2d"),(funFullyConnected,"FullyConnected")] #(funLocalLinear,"LocalLinear"),
nets = [(ni+"_"+nj,i,j) for i,ni in networks[1:] for j,nj in convolutions[1:] ]
torch.cuda.empty_cache()
torch.cuda.memory_allocated()

0

In [8]:
for name,net_type,call in nets:
    net = net_type(10,call,32).to(DEVICE)
    print(name)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    
    for epoch in range(400):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
        
        print('*', end='')
        if(epoch % 25 == 24):
            print(f" {epoch+1}")

    print('\nFinished Training')
    PATH = f"networks/{name}.pth"
    torch.save(net.state_dict(), PATH)
    del net

S_Conv_FullyConnected
************************* 25
************************* 50
************************* 75


In [15]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

Accuracy of the network on the 10000 test images: 36 %
