In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import numpy as np
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from tqdm import tqdm
import json

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU is not available")

# add reproducibility stuff
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


GPU is available


In [2]:
def fastfood_vars(DD, device=0):
    """
    Returns parameters for fast food transform
    :param DD: desired dimension
    :return:
    """
    ll = int(np.ceil(np.log(DD) / np.log(2)))
    LL = 2**ll

    # Binary scaling matrix where $B_{i,i} \in \{\pm 1 \}$ drawn iid
    BB = torch.FloatTensor(LL).uniform_(0, 2).type(torch.LongTensor)
    BB = (BB * 2 - 1).type(torch.FloatTensor).to(device)
    BB.requires_grad = False

    # Random permutation matrix
    Pi = torch.LongTensor(np.random.permutation(LL)).to(device)
    Pi.requires_grad = False

    # Gaussian scaling matrix, whose elements $G_{i,i} \sim \mathcal{N}(0, 1)$
    GG = (
        torch.FloatTensor(
            LL,
        )
        .normal_()
        .to(device)
    )
    GG.requires_grad = False

    divisor = torch.sqrt(LL * torch.sum(torch.pow(GG, 2)))

    return [BB, Pi, GG, divisor, LL]


In [3]:
def fast_walsh_hadamard_torched(x, axis=0, normalize=False):
    """
    Performs fast Walsh Hadamard transform
    :param x:
    :param axis:
    :param normalize:
    :return:
    """
    orig_shape = x.size()
    assert axis >= 0 and axis < len(
        orig_shape
    ), "For a vector of shape %s, axis must be in [0, %d] but it is %d" % (
        orig_shape,
        len(orig_shape) - 1,
        axis,
    )
    h_dim = orig_shape[axis]
    h_dim_exp = int(round(np.log(h_dim) / np.log(2)))
    assert h_dim == 2**h_dim_exp, (
        "hadamard can only be computed over axis with size that is a power of two, but"
        " chosen axis %d has size %d" % (axis, h_dim)
    )

    working_shape_pre = [int(np.prod(orig_shape[:axis]))]  # prod of empty array is 1 :)
    working_shape_post = [
        int(np.prod(orig_shape[axis + 1 :]))
    ]  # prod of empty array is 1 :)
    working_shape_mid = [2] * h_dim_exp
    working_shape = working_shape_pre + working_shape_mid + working_shape_post

    ret = x.view(working_shape)

    for ii in range(h_dim_exp):
        dim = ii + 1
        arrs = torch.chunk(ret, 2, dim=dim)
        assert len(arrs) == 2
        ret = torch.cat((arrs[0] + arrs[1], arrs[0] - arrs[1]), axis=dim)

    if normalize:
        ret = ret / torch.sqrt(float(h_dim))

    ret = ret.view(orig_shape)

    return ret


In [4]:
def fastfood_torched(x, DD, param_list=None, device=0):
    """
    Fastfood transform
    :param x: array of dd dimension
    :param DD: desired dimension
    :return:
    """
    dd = x.size(0)

    if not param_list:

        BB, Pi, GG, divisor, LL = fastfood_vars(DD, device=device)

    else:

        BB, Pi, GG, divisor, LL = param_list

    # Padd x if needed
    dd_pad = F.pad(x, pad=(0, LL - dd), value=0, mode="constant")

    # From left to right HGPiH(BX), where H is Walsh-Hadamard matrix
    mul_1 = torch.mul(BB, dd_pad)
    # HGPi(HBX)
    mul_2 = fast_walsh_hadamard_torched(mul_1, 0, normalize=False)

    # HG(PiHBX)
    mul_3 = mul_2[Pi]

    # H(GPiHBX)
    mul_4 = torch.mul(mul_3, GG)

    # (HGPiHBX)
    mul_5 = fast_walsh_hadamard_torched(mul_4, 0, normalize=False)

    ret = torch.div(mul_5[:DD], divisor * np.sqrt(float(DD) / LL))

    return ret


In [5]:
class FastfoodWrapper(nn.Module):
    def __init__(self, module, intrinsic_dimension, device):
        """
        Wrapper to estimate the intrinsic dimensionality of the
        objective landscape for a specific task given a specific model using FastFood transform
        :param module: pytorch nn.Module
        :param intrinsic_dimension: dimensionality within which we search for solution
        :param device: cuda device id
        """
        super(FastfoodWrapper, self).__init__()

        # Hide this from inspection by get_parameters()
        self.m = [module]

        self.name_base_localname = []

        # Stores the initial value: \theta_{0}^{D}
        self.initial_value = dict()

        # Fastfood parameters
        self.fastfood_params = {}

        # Parameter vector that is updated
        # Initialised with zeros as per text: \theta^{d}
        V = nn.Parameter(torch.zeros((intrinsic_dimension)).to(device))
        self.register_parameter("V", V)
        v_size = (intrinsic_dimension,)

        # Iterate over layers in the module
        for name, param in module.named_parameters():
            # If param requires grad update
            if param.requires_grad:

                # Saves the initial values of the initialised parameters from param.data and sets them to no grad.
                # (initial values are the 'origin' of the search)
                self.initial_value[name] = v0 = (
                    param.clone().detach().requires_grad_(False).to(device)
                )

                # Generate fastfood parameters
                DD = np.prod(v0.size())
                self.fastfood_params[name] = fastfood_vars(DD, device)

                base, localname = module, name
                while "." in localname:
                    prefix, localname = localname.split(".", 1)
                    base = base.__getattr__(prefix)
                self.name_base_localname.append((name, base, localname))

        for name, base, localname in self.name_base_localname:
            delattr(base, localname)

    def forward(self, x):
        # Iterate over layers
        for name, base, localname in self.name_base_localname:

            init_shape = self.initial_value[name].size()
            DD = np.prod(init_shape)

            # Fastfood transform te replace dence P
            ray = fastfood_torched(self.V, DD, self.fastfood_params[name]).view(
                init_shape
            )

            param = self.initial_value[name] + ray

            setattr(base, localname, param)

        # Pass through the model, by getting hte module from a list self.m
        module = self.m[0]
        x = module(x)
        return x


In [6]:
class DenseWrap(nn.Module):
    def __init__(self, module, intrinsic_dimension, device):
        """
        Wrapper to estimate the intrinsic dimensionality of the
        objective landscape for a specific task given a specific model
        :param module: pytorch nn.Module
        :param intrinsic_dimension: dimensionality within which we search for solution
        :param device: cuda device id
        """
        super(DenseWrap, self).__init__()

        # Hide this from inspection by get_parameters()
        self.m = [module]

        self.name_base_localname = []

        # Stores the initial value: \theta_{0}^{D}
        self.initial_value = dict()

        # Stores the randomly generated projection matrix P
        self.random_matrix = dict()

        # Parameter vector that is updated, initialised with zeros as per text: \theta^{d}
        V = nn.Parameter(torch.zeros((intrinsic_dimension, 1)).to(device))
        self.register_parameter("V", V)
        v_size = (intrinsic_dimension,)

        # Iterates over layers in the Neural Network
        for name, param in module.named_parameters():
            # If the parameter requires gradient update
            if param.requires_grad:

                # Saves the initial values of the initialised parameters from param.data and sets them to no grad.
                # (initial values are the 'origin' of the search)
                self.initial_value[name] = v0 = (
                    param.clone().detach().requires_grad_(False).to(device)
                )

                # If v0.size() is [4, 3], then below operation makes it [4, 3, v_size]
                matrix_size = v0.size() + v_size

                # Generates random projection matrices P, sets them to no grad
                self.random_matrix[name] = (
                    torch.randn(matrix_size, requires_grad=False).to(device)
                    / intrinsic_dimension**0.5
                )

                # NOTE!: lines below are not clear!
                base, localname = module, name
                while "." in localname:
                    prefix, localname = localname.split(".", 1)
                    base = base.__getattr__(prefix)
                self.name_base_localname.append((name, base, localname))

        for name, base, localname in self.name_base_localname:
            delattr(base, localname)

    def forward(self, x):
        # Iterate over the layers
        for name, base, localname in self.name_base_localname:

            # Product between matrix P and \theta^{d}
            ray = torch.matmul(self.random_matrix[name], self.V)

            # Add the \theta_{0}^{D} to P \dot \theta^{d}
            param = self.initial_value[name] + torch.squeeze(ray, -1)

            setattr(base, localname, param)

        # Pass through the model, by getting the module from a list self.m
        module = self.m[0]
        x = module(x)
        return x


In [7]:
# implementig the code from the paper https://arxiv.org/abs/1804.08838 in pytorch
"""class Classifier(nn.Module):
    def __init__(self, input_dim, n_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_dim, n_classes)
        self.maxpool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        x = self.maxpool(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x


def get_resnet(encoder_name, num_classes, pretrained=False):
    assert encoder_name in [
        "resnet18",
        "resnet50",
    ], "{} is a wrong encoder name!".format(encoder_name)
    if encoder_name == "resnet18":
        model = models.resnet18(pretrained=pretrained)
        latent_dim = 512
    else:
        model = models.resnet50(pretrained=pretrained)
        latent_dim = 2048
    children = list(model.children())[:-2] + [Classifier(latent_dim, num_classes)]
    model = torch.nn.Sequential(*children)
    return model


# Get model and wrap it in fastfood
model = get_resnet("resnet18", num_classes=YOUR_NUMBER_OF_CLASSES).cuda()
model = FastfoodWrapper(model, intrinsic_dimension=100, device=device)"""


'class Classifier(nn.Module):\n    def __init__(self, input_dim, n_classes):\n        super(Classifier, self).__init__()\n        self.fc = nn.Linear(input_dim, n_classes)\n        self.maxpool = nn.AdaptiveMaxPool2d(1)\n\n    def forward(self, x):\n        x = self.maxpool(x)\n        x = x.reshape(x.size(0), -1)\n        x = self.fc(x)\n        return x\n\n\ndef get_resnet(encoder_name, num_classes, pretrained=False):\n    assert encoder_name in [\n        "resnet18",\n        "resnet50",\n    ], "{} is a wrong encoder name!".format(encoder_name)\n    if encoder_name == "resnet18":\n        model = models.resnet18(pretrained=pretrained)\n        latent_dim = 512\n    else:\n        model = models.resnet50(pretrained=pretrained)\n        latent_dim = 2048\n    children = list(model.children())[:-2] + [Classifier(latent_dim, num_classes)]\n    model = torch.nn.Sequential(*children)\n    return model\n\n\n# Get model and wrap it in fastfood\nmodel = get_resnet("resnet18", num_classes=YO

In [8]:
BATCH_SIZE = 128
DATASET_NAME = "CIFAR10"

img_transform = transforms.Compose(
    [transforms.ToTensor()]
)

train_dataset = None
test_dataset = None
if DATASET_NAME == "MNIST":
    train_dataset = MNIST(
        root="./data/MNIST", download=True, train=True, transform=img_transform
    )
    test_dataset = MNIST(
    root="./data/MNIST", download=True, train=False, transform=img_transform
    )
elif DATASET_NAME == "CIFAR10":
    train_dataset = CIFAR10(
    root="./data/CIFAR10", download=True, train=True, transform=img_transform
    )
    test_dataset = CIFAR10(
    root="./data/CIFAR10", download=True, train=False, transform=img_transform
    )
else:
    raise Exception("Name of dataset not in: [MNIST, CIFAR10]")



Files already downloaded and verified
Files already downloaded and verified


In [9]:
if DATASET_NAME == "MNIST":
    channel_in = 1
    input_dim = 28*28*channel_in
    output_dim = 10
    idx = 0 # @param {type:"slider", min:0, max:59999, step:1}
else:
    channel_in = 3
    input_dim = 32*32*channel_in
    output_dim = 10
    idx = 0 # @param {type:"slider", min:0, max:49999, step:1}

px.imshow(train_dataset.data[idx])

In [10]:
# Class for a Fully Connected Network
class FullyConnectedNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(FullyConnectedNetwork, self).__init__()
        self.num_layers = num_layers
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        if num_layers > 0:
            self.fcs = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)])
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc_in(x))
        if self.num_layers > 0:
            for fc in self.fcs: 
                x = F.relu(fc(x))
        x = self.fc_out(x)
        return x


In [15]:
# Class for Standard LeNet Network, reference http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf with some modification to follow the same number of parameters as the main paper for the task does
class LeNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LeNet, self).__init__()
        # 6 kernels 5x5
        self.conv1 = nn.Conv2d(input_dim, 6, 5, padding='valid',)
        # max-pooling over 2x2
        self.pool1 = nn.MaxPool2d(2, stride=2)
        # 16 kernels 5x5
        self.conv2 = nn.Conv2d(6, 16, 5, padding='valid')
        # max-pooling over 2x2
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # 120 kernels 4x4 to match the dimensionality of the fully connected network
        self.conv3 = nn.Conv2d(16, 120, 5,)
        # 120 fully connected neurons, too many parameter in this case w.r.t. the paper
        #self.fc1 = nn.Linear(16 * 5 * 5, 120,)
        self.flat = nn.Flatten(start_dim=1)
        # 84 fully connected neurons
        self.fc2 = nn.Linear(120, 84)
        # 10 fully connected neurons
        self.fc3 = nn.Linear(84, output_dim,)

    def forward(self, x):
        #x = x.view(-1, 1, 28, 28)
        print(x.shape)
        x = self.conv1(x)
        print(x.shape)
        x = self.pool1(F.relu(x))
        print(x.shape)
        x = self.pool2(F.relu(self.conv2(x)))
        print(x.shape)
        x = F.relu(self.conv3(x))
        x = self.flat(x)
        #x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [7]:
from torch.nn.modules.utils import _pair

# from https://discuss.pytorch.org/t/locally-connected-layers/26979
class LocallyConnected2d(nn.Module):
    def __init__(self, in_channels, out_channels, output_size, kernel_size, stride=1, bias=True):
        super(LocallyConnected2d, self).__init__()
        output_size = _pair(output_size)
        self.weight = nn.Parameter(
            nn.init.kaiming_normal_(torch.randn(1, out_channels, in_channels, output_size[0], output_size[1], kernel_size**2), nonlinearity='relu')
        )
        if bias:
            self.bias = nn.Parameter(nn.init.kaiming_normal_(
                torch.randn(1, out_channels, output_size[0], output_size[1]), nonlinearity='relu')
            )
        else:
            self.register_parameter('bias', None)
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        
    def forward(self, x):
        _, c, h, w = x.size()
        kh, kw = self.kernel_size
        dh, dw = self.stride
        x = x.unfold(2, kh, dh).unfold(3, kw, dw)
        x = x.contiguous().view(*x.size()[:-2], -1)
        # Sum in in_channel and kernel_size dims
        out = (x.unsqueeze(1) * self.weight).sum([2, -1])
        if self.bias is not None:
            out += self.bias
        return out

# Class for Untied LeNet Network
class Untied_LeNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Untied_LeNet, self).__init__()
        # 6 kernels 5x5
        self.conv1 = LocallyConnected2d(input_dim, 6, (24,24), 5)
        # max-pooling over 2x2
        self.pool1 = nn.MaxPool2d(2, stride=2)
        # 16 kernels 5x5
        self.conv2 = LocallyConnected2d(6, 16, (8,8), 5)
        # max-pooling over 2x2
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # 120 kernels 4x4 to match the dimensionality of the fully connected network
        self.conv3 = LocallyConnected2d(16, 120, (1,1), 4)
        # 120 fully connected neurons, too many parameter in this case w.r.t. the paper
        #self.fc1 = nn.Linear(16 * 5 * 5, 120,)
        self.flat = nn.Flatten(start_dim=1)
        # 84 fully connected neurons
        self.fc2 = nn.Linear(120, 84)
        # 10 fully connected neurons
        self.fc3 = nn.Linear(84, output_dim,)

    def forward(self, x):
        #x = x.view(-1, 1, 28, 28)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.flat(x)
        #x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [10]:
# Class for FC-LeNet Network
class FcLeNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(FcLeNet, self).__init__()
        # 6 kernels 5x5
        self.fcconv1 = nn.Linear(input_dim, 3456)
        # max-pooling over 2x2
        self.pool1 = nn.MaxPool2d(2, stride=2)
        # 16 kernels 5x5
        self.fcconv2 = nn.Linear(864, 1024)
        # max-pooling over 2x2
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # 120 kernels 4x4 to match the dimensionality of the fully connected network
        self.fcconv3 = nn.Linear(256, 120)
        # 120 fully connected neurons, too many parameter in this case w.r.t. the paper
        #self.fc1 = nn.Linear(16 * 5 * 5, 120,)
        self.flat = nn.Flatten(start_dim=1)
        # 84 fully connected neurons
        self.fc2 = nn.Linear(120, 84)
        # 10 fully connected neurons
        self.fc3 = nn.Linear(84, output_dim,)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = self.pool1(F.relu(self.fcconv1(x)).view(-1,6,24,24))
        x = self.flat(x)
        x = self.pool2(F.relu(self.fcconv2(x)).view(-1,16,8,8))
        x = self.flat(x)
        x = F.relu(self.fcconv3(x))
        #x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [22]:
# Class for FCTied-LeNet
class FCTied_LeNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(FCTied_LeNet, self).__init__()
        # 6 kernels 5x5
        self.conv1 = nn.Conv2d(input_dim, 6, 55, padding='same',)
        # max-pooling over 2x2
        self.pool1 = nn.MaxPool2d(2, stride=2)
        # 16 kernels 5x5
        self.conv2 = nn.Conv2d(6, 16, 27, padding='same')
        # max-pooling over 2x2
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # 120 kernels 4x4 to match the dimensionality of the fully connected network
        self.conv3 = nn.Conv2d(16, 120, 7,)
        # 120 fully connected neurons, too many parameter in this case w.r.t. the paper
        #self.fc1 = nn.Linear(16 * 5 * 5, 120,)
        self.flat = nn.Flatten(start_dim=1)
        # 84 fully connected neurons
        self.fc2 = nn.Linear(120, 84)
        # 10 fully connected neurons
        self.fc3 = nn.Linear(84, output_dim,)

    def forward(self, x):
        x = x.view(-1, 1, 28, 28)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.flat(x)
        #x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [11]:
hidden_dim = 200
num_layers = 1
model = FullyConnectedNetwork(input_dim, hidden_dim, output_dim, num_layers)
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: ", num_params)

#save information to file
with open('results.txt', 'a') as f:
    f.write("\n##############################################")
    f.write(f"\nNumber_of_parameters: {num_params}")
    f.write(f"\nhidden_dim: {hidden_dim}")
    f.write(f"\nnum_layers: {num_layers}")
f.close()


Number of parameters:  199210


In [12]:
modules = [module for module in model.modules()]
# Print Model Summary
print(modules[0])

FullyConnectedNetwork(
  (fc_in): Linear(in_features=784, out_features=200, bias=True)
  (fcs): ModuleList(
    (0): Linear(in_features=200, out_features=200, bias=True)
  )
  (fc_out): Linear(in_features=200, out_features=10, bias=True)
)


In [16]:
model = LeNet(channel_in, output_dim)
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: %d" % num_params)

with open('results.txt', 'a') as f:
    f.write("\n##############################################")
    f.write(f"\nNumber_of_parameters: {num_params}")
    f.write("\ntype: LeNet")
f.close()

Number of parameters: 62006


In [12]:
modules = [module for module in model.modules()]
# Print Model Summary
print(modules[0])

LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1), padding=valid)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), padding=valid)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(16, 120, kernel_size=(4, 4), stride=(1, 1))
  (flat): Flatten(start_dim=1, end_dim=-1)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [9]:
model = Untied_LeNet(channel_in, output_dim)
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: %d" % num_params)

Number of parameters: 286334


In [11]:
model = FcLeNet(input_dim, output_dim)
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: %d" % num_params)


Number of parameters: 3640574


In [12]:
modules = [module for module in model.modules()]
# Print Model Summary
print(modules[0])

FcLeNet(
  (fcconv1): Linear(in_features=784, out_features=3456, bias=True)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fcconv2): Linear(in_features=864, out_features=1024, bias=True)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fcconv3): Linear(in_features=256, out_features=120, bias=True)
  (flat): Flatten(start_dim=1, end_dim=-1)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [12]:
model = FCTied_LeNet(channel_in, output_dim)
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: %d" % num_params)

NameError: name 'FCTied_LeNet' is not defined

In [17]:
intrinsic_dim = 150
model_intrinsic = DenseWrap(model, intrinsic_dimension=intrinsic_dim, device=device)
num_params_intrinsic = sum(p.numel() for p in model_intrinsic.parameters() if p.requires_grad)
print("Number of parameters: %d" % num_params_intrinsic)
with open('results.txt', 'a') as f:
    f.write(f"\nintrinsic_dim: {intrinsic_dim}")
f.close()

Number of parameters: 150


In [18]:
#torch.autograd.set_detect_anomaly(True) #this line can have huge performance impact
# train the model
from logging import raiseExceptions

# training step

def train(model, train_loader, optimizer, epoch):
    model.train()
    #train_loss_averager = make_averager()  # mantain a running average of the loss

    # TRAIN
    tqdm_iterator = tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        desc="",
        leave=True,)

    len_tr_dl_ds = len(train_loader.dataset)

    for batch_idx, (data, target) in tqdm_iterator:
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        output = model(data)
        loss = F.cross_entropy(output, target)
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        #train_loss_averager(loss.item())
        tqdm_iterator.set_description(
            f"Train Epoch: {epoch} [ {batch_idx * len(data)}/{len_tr_dl_ds} \tLoss: {loss.item():.6f}]"
        )
        tqdm_iterator.refresh()  # to show immediately the update
    tqdm_iterator.close()

    if np.isnan(loss.item()):
        print("Loss is nan")
        raise Exception("Loss is nan")
        exit()


# testing step
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    tqdm_iterator = tqdm(
        enumerate(test_loader),
        total=len(test_loader),
        desc="",
        leave=True,)

    len_ts_dl_ds = len(test_loader.dataset)

    with torch.no_grad():
        for batch_idx, (data, target) in tqdm_iterator:
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            output = model(data)
            test_loss += F.cross_entropy(output, target).item()  # sum up batch loss
            # get the index of the max probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).cpu().sum().item()
            tqdm_iterator.set_description(
            f"Test Epoch: {epoch} [{batch_idx * len(data)}/{len_ts_dl_ds} \tLoss: {test_loss:.6f}, Accuracy: {correct}/{len_ts_dl_ds} ({100.0 * correct / len_ts_dl_ds}%)"
        )
    tqdm_iterator.refresh()  # to show immediately the update
    test_loss /= len(test_loader.dataset)
    print(f"Validation Average loss: {test_loss:.6f}")
    
    tqdm_iterator.close()
    
    # show an histogram of the weights of the model
    """start = -1
    stop = 1
    bins = 30
    for param in model.parameters():
        if param.requires_grad:
            
            hist = torch.histc(param.data, bins = bins, min = start, max = stop)
            x = np.arange(start, stop, (stop-start)/bins)
            plt.bar(x, hist.cpu(), align='center')
            plt.ylabel('Frequency')
            plt.show() """

    return correct / len(test_loader.dataset)

if __name__ == "__main__":
    learning_rate = 0.01
    optimizer = optim.SGD(model_intrinsic.parameters(), lr=learning_rate)
    # download and load MNIST Dataset
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = 4, pin_memory=True, persistent_workers=True)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = 2, pin_memory=True, persistent_workers=True)
    # train the model
    best_acc, best_epoch = 0, 0
    for epoch in range(1, 101):
        train(model_intrinsic, train_dataloader, optimizer, epoch)
        accuracy = test(model_intrinsic, test_dataloader)
        print("Validation Accuracy: {}".format(accuracy))
        #save information to file
        with open('results.txt', 'a') as f:
            f.write(f"\nEpoch: {epoch}")
            f.write(f"\nValidation Accuracy: {accuracy}")
        f.close()
        if accuracy > best_acc:
            best_acc = accuracy
            best_epoch = epoch
            if best_acc >= 0.90:
                torch.save(model_intrinsic.state_dict(), f"lenet_mnist_{intrinsic_dim}.pt")
                #torch.save(model_intrinsic.state_dict(), f"model_best_h{hidden_dim}_id{intrinsic_dim}_lay{num_layers}.pt")
                """ j = None
                with open('results.json', 'r') as f:
                    j = json.load(f)
                f.close()
                with open('results.json', 'w') as f:
                    j[f"fcmodel_h{hidden_dim}_id{intrinsic_dim}_lay{num_layers}"] = {"number_parameter": num_params, "hidden_dimension": hidden_dim, "number_layers": num_layers, "intrinsic_dimension": intrinsic_dim, "epoch": epoch, "validation_accuracy": accuracy}
                    json.dump(j, f, indent=4, separators=(',', ': ')) 
                break """

    j = None
    with open('results_lenet.json', 'r') as f:
        j = json.load(f)
    f.close()
    with open('results_lenet.json', 'w') as f:
        j[f"lenet_model_id{intrinsic_dim}_lr{learning_rate}"] = {"number_parameter": num_params, "intrinsic_dimension": intrinsic_dim, "epoch": epoch, 
        "validation_accuracy": accuracy, "learning_rate": learning_rate, "best_epoch": best_epoch, "best_accuracy": best_acc}
        json.dump(j, f, indent=4, separators=(',', ': '))
    f.close()

    """ j = None
    with open('results_lenet.json', 'r') as f:
        j = json.load(f)
    f.close()
    with open('results_lenet.json', 'w') as f:
        j[f"fcmodel_h{hidden_dim}_id{intrinsic_dim}_lay{num_layers}_lr{learning_rate}"] = {"number_parameter": num_params, 
        "hidden_dimension": hidden_dim, "number_layers": num_layers, "intrinsic_dimension": intrinsic_dim, "epoch": epoch, 
        "validation_accuracy": accuracy, "learning_rate": learning_rate, "best_epoch": best_epoch, "best_accuracy": best_acc}
        json.dump(j, f, indent=4, separators=(',', ': '))
    f.close() """
"""     j = None
    with open('results.json', 'r') as f:
        j = json.load(f)
    f.close()
    with open('results.json', 'w') as f:
        j[f"fcmodel_h{hidden_dim}_id{intrinsic_dim}_lay{num_layers}"] = {"number_parameter": num_params, "hidden_dimension": hidden_dim, "number_layers": num_layers, "intrinsic_dimension": intrinsic_dim, "epoch": epoch, "validation_accuracy": accuracy}
        json.dump(j, f, indent=4, separators=(',', ': ')) """


Train Epoch: 1 [ 128/50000 	Loss: 2.303157]:   0%|          | 1/391 [00:00<01:13,  5.32it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])


Train Epoch: 1 [ 2560/50000 	Loss: 2.300147]:   5%|▌         | 20/391 [00:00<00:05, 66.09it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 5632/50000 	Loss: 2.306211]:   8%|▊         | 33/391 [00:00<00:04, 86.89it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 8832/50000 	Loss: 2.305226]:  16%|█▌        | 61/391 [00:00<00:02, 112.81it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 12288/50000 	Loss: 2.297648]:  23%|██▎       | 90/391 [00:00<00:02, 127.74it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 15616/50000 	Loss: 2.297792]:  30%|███       | 118/391 [00:01<00:02, 133.31it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 18944/50000 	Loss: 2.306632]:  38%|███▊      | 147/391 [00:01<00:01, 136.24it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 22272/50000 	Loss: 2.299867]:  41%|████      | 161/391 [00:01<00:01, 135.75it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 25600/50000 	Loss: 2.308383]:  49%|████▊     | 190/391 [00:01<00:01, 137.55it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 28928/50000 	Loss: 2.306080]:  56%|█████▌    | 219/391 [00:01<00:01, 138.72it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 32128/50000 	Loss: 2.305287]:  63%|██████▎   | 247/391 [00:02<00:01, 137.86it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 35456/50000 	Loss: 2.295831]:  70%|███████   | 275/391 [00:02<00:00, 136.75it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 38656/50000 	Loss: 2.302664]:  74%|███████▍  | 289/391 [00:02<00:00, 136.10it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 41984/50000 	Loss: 2.306463]:  81%|████████  | 317/391 [00:02<00:00, 135.78it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 45440/50000 	Loss: 2.299818]:  89%|████████▊ | 347/391 [00:02<00:00, 139.02it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 48768/50000 	Loss: 2.307637]:  96%|█████████▋| 377/391 [00:03<00:00, 139.66it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Train Epoch: 1 [ 31200/50000 	Loss: 2.300783]: 100%|██████████| 391/391 [00:03<00:00, 125.44it/s]


torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([80, 3, 32, 32])
torch.Size([80, 6, 28, 28])
torch.Size([80, 6, 14, 1

Test Epoch: 1 [2432/10000 	Loss: 46.032489, Accuracy: 269/10000 (2.69%):  11%|█▏        | 9/79 [00:00<00:00, 89.11it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14

Test Epoch: 1 [4736/10000 	Loss: 87.485102, Accuracy: 511/10000 (5.11%):  47%|████▋     | 37/79 [00:00<00:00, 101.37it/s]

torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14, 14])
torch.Size([128, 16, 5, 5])
torch.Size([128, 3, 32, 32])
torch.Size([128, 6, 28, 28])
torch.Size([128, 6, 14




In [60]:
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=[1, 3.2, 5.4, 7.6, 9.8],
    y=[1, 3.2, 5.4, 7.6, 9.8],
    mode='markers',
    marker=dict(
        color=[1, 2, 3, 4, 5],
        size=30,
        showscale=True
        ),
    error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=[1, 2, 3, 4, 5],
            visible=True),
    name='layer width: 50'
))

fig.add_trace(go.Scatter(
    x=[1, 3.2, 5.4, 7.6, 9.8],
    y=[1, 3.2, 5.4, 7.6, 9.8],
    mode='markers',
    marker=dict(
        color=[1, 2, 3, 4, 5],
        size=55,
        showscale=True
        ),
    error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=[1, 2, 3, 4, 5],
            visible=True),
    name='layer width: 100'
))

fig.add_trace(go.Scatter(
    x=[1, 3.2, 5.4, 7.6, 9.8],
    y=[1, 3.2, 5.4, 7.6, 9.8],
    mode='markers',
    marker=dict(
        color=[1, 2, 3, 4, 5],
        size=70,
        showscale=True
        ),
    error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=[1, 2, 3, 4, 5],
            visible=True),
    name='layer width: 200'
))

fig.add_trace(go.Scatter(
    x=[1, 3.2, 5.4, 7.6, 9.8],
    y=[1, 3.2, 5.4, 7.6, 9.8],
    mode='markers',
    marker=dict(
        color=[1, 2, 3, 4, 5],
        size=90,
        showscale=True
        ),
    error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=[1, 2, 3, 4, 5],
            visible=True),
    name='layer width: 400'
))

fig.update_layout(legend=dict(
    yanchor="top",
    y=0.99,
    xanchor="left",
    x=0.01
))

fig.show()