<a href="https://colab.research.google.com/github/aycaKula/Artificial_Intelligence/blob/main/SuperpositionOfManyModelsIntoOne.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Neural Network Superposition**

**IMPORT LIBRARIES**

In [1]:
import numpy as np
import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

import skimage.transform
from torchvision import datasets, transforms

In [2]:
import os
from datetime import datetime
import socket
import json
import time
import torchvision.utils as tvu 
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
import pprint

In [3]:
!pip install tensorboardcolab
import torch
from torch.utils.tensorboard import SummaryWriter

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorboardcolab
  Downloading tensorboardcolab-0.0.22.tar.gz (2.5 kB)
Building wheels for collected packages: tensorboardcolab
  Building wheel for tensorboardcolab (setup.py) ... [?25l[?25hdone
  Created wheel for tensorboardcolab: filename=tensorboardcolab-0.0.22-py3-none-any.whl size=3858 sha256=4a38e4b708640425f4e9bf329c0c3e1405f964ddee4a99c3170b82b26b3ccd5f
  Stored in directory: /root/.cache/pip/wheels/69/4e/4a/1c6c267395cb10edded1050df12af165d3254cfce324e80941
Successfully built tensorboardcolab
Installing collected packages: tensorboardcolab
Successfully installed tensorboardcolab-0.0.22


'''Define the arguments for each class
'''

In [4]:
class RotatingMNISTBinaryHash:
    def __init__(self):
        # activation functions
        self.activation = 'relu'
        # batch size
        self.batch_size = 128
        self.cheat_period = 1000
        # dataset
        self.dataset = 'rotating_mnist'
        self.desc = 'rot_mnist_binhash'
        self.key_pick = 'hash'
        self.learn_key = True
        self.lr = 1e-4
        self.momentum = 0.5
        self.n_layers = 3
        self.n_units = 256
        self.net = 'binaryhash'
        self.net_period = 50
        self.no_cuda = False
        self.optimizer = 'rmsprop'
        self.period = 1000
        self.rotate_continually = False
        self.seed = 1
        self.shuffle_test = False
        self.stationary = 0
        self.steps = 10000 # 50000
        self.test_batch_size = 1000
        self.test_steps = 10
        self.test_time = 0
        self.time_slow = 100.
        self.time_loss_coeff = 0.
        self.s_loss_coeff = 0.


class RotatingMNISTUnitaryHash(RotatingMNISTBinaryHash):
    def __init__(self):
        super(RotatingMNISTUnitaryHash, self).__init__()
        self.net = 'rotatehash'
        self.desc = 'rot_mnist_rothash'


class RotatingMNISTComplex(RotatingMNISTBinaryHash):
    def __init__(self):
        super(RotatingMNISTComplex, self).__init__()
        self.net = 'complex'
        self.desc = 'rot_mnist_complex'


class RotatingMNISTReal(RotatingMNISTBinaryHash):
    def __init__(self):
        super(RotatingMNISTReal, self).__init__()
        self.net = 'real'
        self.desc = 'rot_mnist_real'


class ICIFARResNet18:
    def __init__(self):
        self.activation = 'relu'
        self.batch_size = 128
        self.cheat_period = 100000
        self.dataset = 'incrementing_cifar'
        self.desc = 'icifar_resnet18'
        self.key_pick = 'hash'
        self.learn_key = True
        self.lr = 0.001
        self.momentum = 0.5
        self.n_layers = 6
        self.n_units = 64
        self.net = 'staticbnresnet18'
        self.net_period = 10
        self.no_cuda = False
        self.optimizer = 'rmsprop'
        self.period = 20000
        self.rotate_continually = False
        self.seed = 1
        self.shuffle_test = False
        self.stationary = 0
        self.steps = 10000 # 100000
        self.test_batch_size = 1000
        self.test_steps = 10
        self.test_time = 0
        self.time_slow = 20000.
        self.time_loss_coeff = 0.
        self.s_loss_coeff = 0.
        

class ICIFAR100ResNet18(ICIFARResNet18):
    def __init__(self):
        super(ICIFAR100ResNet18, self).__init__()
        self.dataset = 'incrementing_cifar100'
        self.desc = 'icifar100_resnet18'


class ICIFARHashResNet18(ICIFARResNet18):
    def __init__(self):
        super(ICIFARHashResNet18, self).__init__()
        self.net = 'hashresnet18'
        self.desc = 'icifar_hashresnet18'



**DATASETS**

In [5]:
'''
Rotation --> Rotate the image by angle
RotatingCIFAR
'''

class NonstationaryLoader(object):
    def __init__(self):
        # Start time at -1 which means no data has been drawn 
        self.current_time = -1

    def time(self):
        return self.current_time

    def set_time(self, new_time):
        self.current_time = new_time
        
    def get_data(self):
        raise NotImplementedError

    def get_dim(self):
        raise NotImplementedError


class Rotation(object):
    """Rotate the image by angle."""
    def __init__(self, resample=False, expand=False, center=None):
        self.resample = resample
        self.expand = expand
        self.center = center
        self.angle = 0.
        
    def __call__(self, img):
        print(self.angle)
        return transforms.functional.rotate(img,
                                            self.angle,
                                            self.resample,
                                            self.expand,
                                            self.center)


def rotate_image_batch(images, angle):
    # Normalize images for skimage.transform.rotate
    norm_images = np.copy(images)
    imgs_min = norm_images.min()
    norm_images -= imgs_min 
    norm_max = norm_images.max()
    norm_images /= norm_max

    norm_images = norm_images.transpose(2,3,0,1)
    h,w,n,c = norm_images.shape
    norm_images = norm_images.reshape((h,w,n*c))
    rot_images = skimage.transform.rotate(norm_images, angle)
    rot_images = rot_images.reshape((h,w,n,c))
    rot_images = rot_images.transpose(2,3,0,1)
    # Un-normalize images
    rot_images *= norm_max
    rot_images += imgs_min

    return rot_images


class RotatingData(NonstationaryLoader):
    def __init__(self, dataset, rotate_period, batch_size, draw_and_rotate, kwargs={}):
        super(RotatingData, self).__init__()
        self.data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True, **kwargs)
        self.rotate_period = rotate_period
        self.draw_and_rotate = draw_and_rotate
        self.data_iter = iter(self.data_loader)
        self.sample_buffer = None
        self.draws = 0

    def get_data(self):
        self.current_time += 1
        if self.draw_and_rotate:
            try:
                self.sample_buffer = self.data_iter.next()
            except StopIteration:
                # Epoch finished, reset data loader
                self.data_iter = iter(self.data_loader)
                self.sample_buffer = self.data_iter.next()
            self.draws += 1
        else:
            # While loop handles arbitrary skips in time larger than a period
            while int(self.current_time/self.rotate_period) >= self.draws:
                try:
                    self.sample_buffer = self.data_iter.next()
                except StopIteration:
                    # Epoch finished, reset data loader
                    self.data_iter = iter(self.data_loader)
                    self.sample_buffer = self.data_iter.next()
                self.draws += 1

        images, labels = self.sample_buffer
        angle = 360.*(self.current_time/self.rotate_period)
        rot_images = rotate_image_batch(images.numpy(), angle)
        input_data = torch.from_numpy(rot_images).float()
        return input_data, labels


class RotatingMNIST(RotatingData):
    def __init__(self, rotate_period, batch_size, train=True, draw_and_rotate=True, kwargs={}):
        dataset = datasets.MNIST('data', train=train, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ]))
        super(RotatingMNIST, self).__init__(dataset, rotate_period, batch_size, draw_and_rotate, kwargs)

    def get_dim(self):
        return self.data_loader.dataset.data.shape[1:], 10

####################################################################################################################
###########################################################################
############################################## Incrementing CIFAR Dataset
##########################
class IncrementingCIFAR(NonstationaryLoader):
    def __init__(self, change_period, batch_size, n_class=10, use_cifar10=True, train=True, seed=1234, kwargs={}):
        super(IncrementingCIFAR, self).__init__()
        transform = transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                    ])
        self.dataset_0 = datasets.CIFAR10('data/cifar10', train=train, download=True,
                                          transform=transform)
        self.dataset = datasets.CIFAR100('data/cifar100', train=train, download=True,
                                         transform=transform)
        self.change_period = change_period
        self.transform = transform
        self.batch_size = batch_size
        self.train = train
        self.rand = np.random.RandomState(seed)
        self.n_class = n_class
        self.use_cifar10 = use_cifar10

    def get_data(self):
        self.current_time += 1
        task_offset = int(self.current_time/self.change_period) % int(100/self.n_class)

        # Train initially on the bigger CIFAR10 dataset to stabilize the classifier
        # (neural networks need a lot of data to perform consistently)
        if task_offset == 0 and self.use_cifar10:
            cur_dataset = self.dataset_0
        # Switch to the smaller CIFAR100 dataset to generate new tasks 
        else:
            cur_dataset = self.dataset

        data = cur_dataset.data
        labels = np.array(cur_dataset.targets)

        # Retrieve labels within a contiguous set of class labels
        label_range = range(self.n_class*task_offset, self.n_class*(task_offset+1))
        idx = np.isin(labels, label_range)
        # Indexing like the one below only works when current_data is a numpy variable
        # and not a torch variable.
        current_data = data[idx]
        current_labels = labels[idx]

        # Retrieve a random minibatch from the current task data 
        rand_idx = self.rand.permutation(current_data.shape[0])
        mb_data = current_data[rand_idx[:self.batch_size]]
        mb_labels = current_labels[rand_idx[:self.batch_size]]

        # Preprocess input to (-1,+1) range
        mb_data = torch.Tensor(mb_data.transpose(0,3,1,2))
        mb_data = 2.*(mb_data/255.) - 1.
        
        mb_labels = torch.Tensor(mb_labels) - self.n_class*task_offset
        mb_labels = mb_labels.long()
        return mb_data, mb_labels

    def get_dim(self):
        return self.dataset.data.shape[1:], (self.n_class,)
  


In [None]:
def from_polar(r, phi):
    a = r*torch.cos(phi)
    b = r*torch.sin(phi)
    return a, b


def to_polar(v):
    r = torch.norm(v, p=2, dim=0)
    phi = torch.atan2(v[1], v[0])
    return r, phi
    

class ComplexVar(object):
    def __init__(self, a, b, polar_init=False):
        self.s_r = torch.FloatTensor(*((2,) + a.shape))
        if polar_init:
            a, b = from_polar(a, b)

        self.s_r[0] = a
        self.s_r[1] = b

    def to_polar(self):
        return to_polar(self.s_r) 

**MODELS**

**torch.nn.init.xavier_uniform_(tensor, gain=1.0):** Fills the input Tensor with values according to the method described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a uniform distribution. The resulting tensor will have values sampled from U(−a,a) where

**nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')**


**Tensor.requires_grad** : Is True if gradients need to be computed for this Tensor, False otherwise.

In [6]:
class RouteLinear(nn.Module):
    def __init__(self, n_in, n_out, period, key_pick='hash', learn_key=True):
        super(RouteLinear, self).__init__()
        self.key_pick = key_pick
        w = nn.init.xavier_normal_(torch.empty(n_in, n_out))
        self.w = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(n_out))
        rots = []
        
        for key_i in range(period):
            row_idx = np.random.permutation(n_in)
            random_route = np.eye(n_in)[row_idx].astype('float32')
            rots.append(torch.from_numpy(random_route))

        rots = torch.stack(rots)
        self.rots = nn.Parameter(rots)
        
        if not learn_key:
            self.rots.requires_grad = False
    
    def forward(self, x, time):
        m = torch.mm(x, self.rots[int(time)])
        return torch.mm(m, self.w) + self.bias


class HashConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, period, 
                 stride=1, padding=0, bias=True,
                 key_pick='hash', learn_key=True):
        super(HashConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)

        w = torch.zeros(self.out_channels, self.in_channels, *self.kernel_size)
        nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
        self.w = nn.Parameter(w)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)

        o_dim = self.in_channels*self.kernel_size[0]*self.kernel_size[1]
        # TODO(briancheung): The line below will cause problems when saving a model
        o = torch.from_numpy( np.random.binomial( p=.5, n=1, size = (o_dim, period) ).astype(np.float32) * 2 - 1 )
        self.o = nn.Parameter(o, requires_grad=False)

    def forward(self, x, time):
        net_time = time % self.o.shape[1]
        o = self.o[:, net_time].view(1,
                                     self.in_channels,
                                     self.kernel_size[0],
                                     self.kernel_size[1])
        return F.conv2d(x, self.w*o, self.bias, stride=self.stride, padding=self.padding)


class BinaryHashLinear(nn.Module):
    def __init__(self, n_in, n_out, period, key_pick='hash', learn_key=True):
        super(BinaryHashLinear, self).__init__()
        self.key_pick = key_pick
        w = nn.init.xavier_normal_(torch.empty(n_in, n_out))
        rand_01 = np.random.binomial(p=.5, n=1, size=(n_in, period)).astype(np.float32)
        o = torch.from_numpy(rand_01*2 - 1)

        self.w = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(n_out))
        self.o = nn.Parameter(o)
        if not learn_key:
            self.o.requires_grad = False # Is True if gradients need to be computed for this Tensor, False otherwise.

    def forward(self, x, time):
        o = self.o[:, int(time)]
        m = x*o
        r = torch.mm(m, self.w)
        return r


class RealLinear(nn.Module):
    '''Complex layer that throws away imaginary part'''
    def __init__(self, n_in, n_out):
        super(RealLinear, self).__init__()
        a = torch.empty(n_in, n_out)
        a = nn.init.xavier_normal_(a)
        b = torch.zeros(n_in, n_out)
        self.cv = ComplexVar(a, b, polar_init=True)
        self.w = nn.Parameter(self.cv.s_r)
        self.bias = nn.Parameter(torch.zeros(n_out))

    def forward(self, x_a, x_b):
        x_b = x_b*0
        w_a = self.w[0]
        w_b = self.w[1]
        r_a = torch.mm(x_a, w_a) - torch.mm(x_b, w_b)
        r_b = torch.mm(x_b, w_a) + torch.mm(x_a, w_b)
        return r_a + self.bias, r_b


class ComplexLinear(nn.Module):
    '''Complex layer'''
    def __init__(self, n_in, n_out):
        super(ComplexLinear, self).__init__()
        a = torch.empty(n_in, n_out)
        a = nn.init.xavier_normal_(a)
        b = torch.Tensor(n_in, n_out).uniform_(-np.pi, np.pi) 
        self.cv = ComplexVar(a, b, polar_init=True)
        self.w = nn.Parameter(self.cv.s_r)
        self.bias = nn.Parameter(torch.zeros(n_out))

    def forward(self, x_a, x_b):
        w_a = self.w[0]
        w_b = self.w[1]
        r_a = torch.mm(x_a, w_a) - torch.mm(x_b, w_b)
        r_b = torch.mm(x_b, w_a) + torch.mm(x_a, w_b)
        return r_a + self.bias, r_b

In [7]:
#################################################################################
############################################ BASELINE NET
#############################
class BaselineNet(nn.Module):
    def __init__(self, input_dim, n_units, activation, n_layers, layer_type):
        super(BaselineNet, self).__init__()
        self.n_units = n_units
        self.activation = activation
        self.n_layers = n_layers
        
        layer_units = [np.prod(input_dim)]
        layer_units += [n_units for i in range(n_layers)]
        layer_units += [10]

        layers = []
        for i in range(len(layer_units)-1):
            layers.append(layer_type(layer_units[i], layer_units[i+1]))
        self.layers = nn.ModuleList(layers)

    def forward(self, x, time):
        preactivations = []
        r = x
        for layer_i, layer in enumerate(self.layers):
            if layer_i > 0:
                r = self.activation(r)
            r = layer(r)
            preactivations.append(r)

        return r, None, preactivations


class SimpleNet(BaselineNet):
    def forward(self, x, time):
        preactivations = []
        r_a, r_b = x, torch.zeros_like(x)
        for layer_i, layer in enumerate(self.layers):
            if layer_i > 0:
                r_a = self.activation(r_a)
                r_b = self.activation(r_b)
            r_a, r_b = layer(r_a, r_b)
            preactivations.append(r_a)

        return r_a, r_b, preactivations

#################################################################################
############################################ CONTEXT NET
#############################
class HashNet(nn.Module):
    def __init__(self, input_dim, n_units,
                 activation, n_layers,
                 period, key_pick, 
                 learn_key, layer_type):
        super(HashNet, self).__init__()
        self.n_units = n_units
        self.activation = activation

        layer_units = [np.prod(input_dim)]
        layer_units += [n_units for i in range(n_layers)]
        layer_units += [10]

        layers = []
        for i in range(len(layer_units)-1):
            layers.append(layer_type(layer_units[i],
                                     layer_units[i+1],
                                     period, key_pick, learn_key))
        self.layers = nn.ModuleList(layers)


class RealHashNet(HashNet):
    def forward(self, x, time):
        preactivations = []
        r = x
        for layer_i, layer in enumerate(self.layers):
            if layer_i > 0:
                r = self.activation(r)
            r = layer(r, time)
            preactivations.append(r)

        return r, None, preactivations


#################################################################################
############################################ RESNET
#############################
''' https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py '''


class HashBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, period, stride=1):
        super(HashBasicBlock, self).__init__()
        self.conv1 = HashConv2d(in_planes, planes, 3, period, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, affine=False, track_running_stats=False)
        self.conv2 = HashConv2d(planes, planes, 3, period, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, affine=False, track_running_stats=False)

        self.shortcut = nn.ModuleList()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.ModuleList(
                [HashConv2d(in_planes, self.expansion*planes, 1, period, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes, affine=False, track_running_stats=False)]
            )

    def forward(self, x, time):
        out = F.relu(self.bn1(self.conv1(x, time)))
        out = self.bn2(self.conv2(out, time))
        if len(self.shortcut) > 0:
            sout = self.shortcut[0](x, time)
            sout = self.shortcut[1](sout)
        else:
            sout = x
        out += sout 
        out = F.relu(out)
        return out

class StaticBNBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(StaticBNBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, affine=False, track_running_stats=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, affine=False, track_running_stats=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes, affine=False, track_running_stats=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, bn1_affine=True, bn1_track_stats=True):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=bn1_affine, track_running_stats=bn1_track_stats)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, time):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out, None, []


class HashResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(HashResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=False, track_running_stats=False)
        self.layer1 = nn.ModuleList(self._make_layer(block, 64, num_blocks[0], stride=1, period=1000))
        self.layer2 = nn.ModuleList(self._make_layer(block, 128, num_blocks[1], stride=2, period=1000))
        self.layer3 = nn.ModuleList(self._make_layer(block, 256, num_blocks[2], stride=2, period=1000))
        self.layer4 = nn.ModuleList(self._make_layer(block, 512, num_blocks[3], stride=2, period=1000))
        self.linear = BinaryHashLinear(512*block.expansion,
					  num_classes,
				          1000)
        self.cheat_period = 1000000
        self.time_slow = 20000

    def _make_layer(self, block, planes, num_blocks, stride, period):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, period, stride))
            self.in_planes = planes * block.expansion
        return layers

    def forward(self, x, time):
        time = int(time)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1[0](out, time)
        out = self.layer1[1](out, time)
        out = self.layer2[0](out, time)
        out = self.layer2[1](out, time)
        out = self.layer3[0](out, time)
        out = self.layer3[1](out, time)
        out = self.layer4[0](out, time)
        out = self.layer4[1](out, time)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out, time)
        return out, None, []

def StaticBNResNet18():
      return ResNet(StaticBNBasicBlock, [2,2,2,2],bn1_affine=False, bn1_track_stats=False)

def HashResNet18(num_classes):
    return HashResNet(HashBasicBlock, [2,2,2,2], num_classes=num_classes)

      

**GET FULLY CONNECTED OR CONVOLUTIONAL NEURAL NETWORK MODEL**

The neural network model has been chosen with respect to each arguments self.net value.

**Extension to Convolutional Networks:** 
Usually the dimensionality of convolution
parameters is smaller than the input image, therefore, it makes more sense computationally to **apply context to the weights** rather than the input.

In [8]:
def get_fc_net(name, input_dim, output_dim, activation, args):
    mynet = None
    if name == 'real':
        mynet = SimpleNet(input_dim,
                          args.n_units,
                          activation,
                          args.n_layers,
                          RealLinear)
        
    elif name == 'complex':
        mynet = SimpleNet(input_dim,
                        args.n_units,
                        activation,
                        args.n_layers,
                        ComplexLinear)
        
    elif name == 'rotatehash':
      layer_type = RouteLinear
      mynet = RealHashNet(input_dim,
                          args.n_units,
                          activation,
                          args.n_layers,
                          args.net_period,
                          args.key_pick,
                          args.learn_key,
                          layer_type)
      
    return mynet

      
def get_conv_net(name, input_dim, output_dim, activation, args):
    mynet = None
    if name == 'hashresnet18':
      mynet = HashResNet18(np.prod(output_dim))

    elif name == 'staticbnresnet18':
      mynet = StaticBNResNet18()

    return mynet
    



**GET DATASET, ACTIVATION LAYER, OPTİMİZER**

The has been chosen with respect to each arguments value.

In [9]:
def get_dataset(name, period, batch_size, train, kwargs):
  '''
  get_dataset --> load_data
  '''
  if name == 'rotating_mnist':
        data_loader = RotatingMNIST(period,
                                    batch_size,
                                    train=train,
                                    draw_and_rotate=True,
                                    kwargs=kwargs)

  elif name == 'incrementing_cifar':
        data_loader = IncrementingCIFAR(period,
                                        batch_size,
                                        train=train,
                                        kwargs=kwargs)
        
  return data_loader


def get_activation(name): 
    if name == 'tanh':
        activation = torch.tanh
    elif name == 'sigmoid':
        activation = torch.sigmoid
    elif name == 'relu':
        activation = torch.relu
    elif name == 'none':
        activation = torch.identity

    return activation


def get_optimizer(name, params, args):
    if name == 'rmsprop':
        optimizer = optim.RMSprop(params, lr=args.lr)
    
    return optimizer


def get_model_params(model, clone=True):
    myparams = []
    for parameter in model.parameters():
        if clone:
            myparams.append(parameter.clone())
        else:
            myparams.append(parameter)

    return myparams
    

def get_uncoupled_norm(x_list, y_list):
    sqsum = 0.
    for x,y in zip(x_list, y_list):
       sqsum += ((x - y)**2).sum().item()
    norm = np.sqrt(sqsum)
    return norm

In [10]:
def get_preprocess(flatten=False):
        if flatten:
            return lambda x: x.view(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3])
        return lambda x: x

**TRAIN THE MODEL**

In [11]:
def train(model, optimizer, time, data, loss_coeffs):
    input_data, target_data = data
    time_loss_coeff, s_loss_coeff = loss_coeffs

    optimizer.zero_grad()
    out_a, out_b, preacts = model(input_data, time)
    out_class = out_a
    logsm_class = F.log_softmax(out_class, 1)
    sm_class = F.softmax(out_class, 1)

    # Calculate entropy of classes
    loss_entropy = (-sm_class*logsm_class).sum(1).mean()
    loss_class = F.nll_loss(logsm_class, target_data)

    loss = loss_class
    loss.backward()
    optimizer.step()
    return loss, loss_class, loss_entropy

**TEST THE ACCURACY AND THE LOSS**

In [12]:
def test_set(model, test_loader, device, time, period, preprocess, steps):
    test_loss_class = 0
    test_loss_time = 0
    test_loss_entropy = 0
    correct = 0
    num_seen = 0
    model.eval()
    with torch.no_grad():
        for batch_idx in range(steps):
            # Set time before getting data to get correct angle
            test_loader.set_time(time*period - 1)
            input_data, target = test_loader.get_data()
            input_data, target = input_data.to(device), target.to(device)
            pp_input = preprocess(input_data)
            out_a, out_b, preacts = model(pp_input, time)

            out_class = out_a
            logsm_class = F.log_softmax(out_class, 1)
            sm_class = F.softmax(out_class, 1)

            loss_entropy = (-sm_class*logsm_class).sum()
            loss_class = F.nll_loss(logsm_class, target, reduction='sum')

            test_loss_class += loss_class.item()
            test_loss_entropy += loss_entropy.item()
            pred = logsm_class.max(1, keepdim=True)[1]

            correct += pred.eq(target.view_as(pred)).sum().item()
            num_seen += input_data.shape[0]

    test_loss_class /= num_seen
    test_loss_entropy /= num_seen
    test_acc_class = 100. * correct / num_seen 
    print('\nTest set: Time: {:5f}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        time, test_loss_class, correct, num_seen,
        test_acc_class))

    return test_loss_class, test_acc_class, test_loss_entropy

**WHICH TASKS DID I CHOSE?**

In [13]:
# Rotating MNIST dataset experiments
rotmnist_exps = [RotatingMNISTUnitaryHash(), # Dataset: rotating_mnist ,  period: 1000, batch_size: 128, net: rotatehash
                RotatingMNISTComplex(),      # Dataset: rotating_mnist ,  period: 1000, batch_size: 128, net: complex
                RotatingMNISTReal()]         # Dataset: rotating_mnist ,  period: 1000, batch_size: 128, net: real

# ICIFAR dataset experiments
icifar_exps = [ICIFARResNet18(),             # Dataset: incrementing_cifar ,  period: 20000, batch_size: 128, net: staticbnresnet18
              ICIFARHashResNet18()]          # Dataset: incrementing_cifar ,  period: 20000, batch_size: 128, net: hashresnet18   
    
exps_to_run = (rotmnist_exps + icifar_exps)


**MAINCODE**

In [14]:

for args in exps_to_run:
  # Print the datasets in a readable way
  pprint.pprint(vars(args))

  # Obtain reproducible results
  # args.seed --> gives 1 for every dataset
  torch.manual_seed((args.seed)) 

  use_cuda = not args.no_cuda and torch.cuda.is_available()
  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  
  #
  train_loader = get_dataset(args.dataset, args.period, args.batch_size, True, kwargs)
  input_dim, output_dim = train_loader.get_dim() # --> (32, 32, 3) (10,)
  
  test_loader = get_dataset(args.dataset, args.period, args.test_batch_size, False, kwargs)
 
  # get_activation
  activation = get_activation(args.activation)
  
  mynet = get_fc_net(args.net, input_dim, output_dim, activation, args)

  if mynet:
    # Fully Connected Layer
    flat_input = True
  else:
    # Convolutional Layer
    mynet = get_conv_net(args.net, input_dim, output_dim, activation, args)
    flat_input = False 
  mynet = mynet.to(device)

  # Optimizer
  optimizer = get_optimizer(args.optimizer, mynet.parameters(), args)

  preprocess = get_preprocess(flat_input)

  current_time = datetime.now().strftime('%b%d_%H-%M-%S')

  # Save directory location
  # Example log_dir: runs/icifar_hashresnet18/Nov25_21-08-45_4d5526ef1f13
  log_dir = os.path.join('runs', args.desc, current_time + '_' + socket.gethostname())

  # create a summary writer using the specified (log_dir) folder name.
  writer = SummaryWriter(log_dir=log_dir)
  writer.add_text('Args', pprint.pformat(vars(args)), 0)

  with open(os.path.join(log_dir, 'args.json'), 'w') as fp:
    json.dump(vars(args), fp, sort_keys=True, indent=4)

    for batch_idx in range(args.steps):
        global_step = batch_idx
        if global_step < args.stationary:
            train_loader.set_time(np.random.randint(args.period))

        time_start = time.time()
        input_data, target = train_loader.get_data()
        input_data, target = input_data.to(device), target.to(device)
        pp_input = preprocess(input_data)

        mynet.train()
        if batch_idx % 100 == 0:
            params_tm1 = get_model_params(mynet, clone=True)
        net_time = train_loader.time() % args.cheat_period
        net_time /= args.time_slow
        losses = train(mynet,
                       optimizer,
                       net_time,
                       (pp_input, target),
                       (args.time_loss_coeff, args.s_loss_coeff))
        if batch_idx % 100 == 0:
            params_t = get_model_params(mynet, clone=False)
            norm_delta_params = get_uncoupled_norm(params_tm1, params_t)

        loss, loss_class, loss_entropy = losses
        time_stop = time.time()

        if batch_idx % 100 == 0:
            print(batch_idx, loss.item(), train_loader.current_time, time_stop-time_start)
            if args.shuffle_test:
                test_time = np.random.randint(args.period)
            else:
                test_time = args.test_time

            test_time = test_time % args.cheat_period
            test_time /= args.time_slow
            test_losses = test_set(mynet,
                                   test_loader,
                                   device,
                                   test_time,
                                   args.period,
                                   preprocess,
                                   args.test_steps)
            test_loss_class, test_acc_class, test_loss_entropy = test_losses

            writer.add_scalar('norm_delta_params', norm_delta_params, global_step)
            writer.add_scalar('local_loss_class', loss_class, global_step)
            writer.add_scalar('local_loss_entropy', loss_entropy, global_step)
            writer.add_scalar('local_train_loss', loss, global_step)
            writer.add_scalar('test_acc_class', test_acc_class, global_step)
            writer.add_scalar('test_loss_class', test_loss_class, global_step)
            writer.add_scalar('test_loss_entropy', test_loss_entropy, global_step)
            writer.add_scalar('system_loop_time', time_stop-time_start, global_step) 
            img = tvu.make_grid(input_data[:32], normalize=True)
            writer.add_image('train_image', img, global_step)

        if batch_idx % 1000 == 0:
            save_path = os.path.join(log_dir, 'mynet_%d.pth' % batch_idx)
            torch.save(mynet, save_path)

    # Test on all the tasks to get an average accuracy accross all tasks
    n_tasks = int(args.steps/args.period)
    print('args.steps:', args.steps)
    print('args.period:', args.period)
    total_acc = 0.
    for task_i in range(n_tasks):
        test_time = task_i 
        test_losses = test_set(mynet,
                               test_loader,
                               device,
                               test_time,
                               args.period,
                               preprocess,
                               args.test_steps)
        test_loss_class, test_acc_class, test_loss_entropy = test_losses
        writer.add_scalar('retro_acc', test_acc_class, task_i)
        print('test_time:', test_time, 'acc:', test_acc_class)
        total_acc += test_acc_class

    print('total_acc/n_tasks:', total_acc/n_tasks, 'n_tasks:', n_tasks)
    #writer.add_scalar('avg_acc', total_acc/n_tasks, global_step)
    writer.close()

 

{'activation': 'relu',
 'batch_size': 128,
 'cheat_period': 1000,
 'dataset': 'rotating_mnist',
 'desc': 'rot_mnist_rothash',
 'key_pick': 'hash',
 'learn_key': True,
 'lr': 0.0001,
 'momentum': 0.5,
 'n_layers': 3,
 'n_units': 256,
 'net': 'rotatehash',
 'net_period': 50,
 'no_cuda': False,
 'optimizer': 'rmsprop',
 'period': 1000,
 'rotate_continually': False,
 's_loss_coeff': 0.0,
 'seed': 1,
 'shuffle_test': False,
 'stationary': 0,
 'steps': 10000,
 'test_batch_size': 1000,
 'test_steps': 10,
 'test_time': 0,
 'time_loss_coeff': 0.0,
 'time_slow': 100.0}
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

0 2.455657482147217 0 6.666759967803955

Test set: Time: 0.000000, Average loss: 2.2040, Accuracy: 3033/10000 (30%)

100 2.441880941390991 100 0.04157900810241699

Test set: Time: 0.000000, Average loss: 0.6800, Accuracy: 7871/10000 (79%)

200 2.3349008560180664 200 0.059699058532714844

Test set: Time: 0.000000, Average loss: 0.6953, Accuracy: 7859/10000 (79%)

300 2.5598978996276855 300 0.04298043251037598

Test set: Time: 0.000000, Average loss: 0.7089, Accuracy: 7838/10000 (78%)

400 2.3328888416290283 400 0.043889760971069336

Test set: Time: 0.000000, Average loss: 0.7317, Accuracy: 7789/10000 (78%)

500 2.401526927947998 500 0.04452848434448242

Test set: Time: 0.000000, Average loss: 0.7481, Accuracy: 7765/10000 (78%)

600 2.5326545238494873 600 0.04862189292907715

Test set: Time: 0.000000, Average loss: 0.7690, Accuracy: 7724/10000 (77%)

700 2.4958410263061523 700 0.04379749298095703

Test set: Time: 0.00

  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar10/cifar-10-python.tar.gz to data/cifar10
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar100/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting data/cifar100/cifar-100-python.tar.gz to data/cifar100
Files already downloaded and verified
Files already downloaded and verified
0 2.4115757942199707 0 4.2054526805877686

Test set: Time: 0.000000, Average loss: 4.7507, Accuracy: 1042/10000 (10%)

100 1.548608422279358 100 0.21002483367919922

Test set: Time: 0.000000, Average loss: 1.6664, Accuracy: 3686/10000 (37%)

200 1.2843883037567139 200 0.20899462699890137

Test set: Time: 0.000000, Average loss: 1.4094, Accuracy: 4815/10000 (48%)

300 1.5061014890670776 300 0.2150418758392334

Test set: Time: 0.000000, Average loss: 1.3180, Accuracy: 5021/10000 (50%)

400 1.2290178537368774 400 0.21500205993652344

Test set: Time: 0.000000, Average loss: 1.1730, Accuracy: 5835/10000 (58%)

500 1.2036361694335938 500 0.22011208534240723

Test set: Time: 0.000000, Average loss: 1.1808, Accuracy: 5740/10000 (57%)

600 1.0075725317001343 600 0.21592044830322266

Test set: Time: 0.000000, Average loss: 1.0453, Accuracy: 6205/10000 (62%

ZeroDivisionError: ignored

In [None]:
mynet

In [None]:
x = torch.randn(3)
x_c = x.to('cpu')