In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import numpy as np
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from tqdm import tqdm
import json

# Pick the compute device once; later cells move models and batches onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("GPU is available" if device.type == "cuda" else "GPU is not available")

# Reproducibility: make cuDNN pick deterministic kernels (no autotuning), then seed
# the torch CPU/CUDA generators and NumPy's legacy global generator.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
np.random.seed(42)


GPU is available


In [2]:
def fastfood_vars(DD, device=0):
    """
    Draw the random factors of a Fastfood transform for a target dimension DD.

    The transform operates on vectors of length LL, the smallest power of two that is
    >= DD (the Walsh-Hadamard transform needs a power-of-two length).

    :param DD: desired dimension (positive integer; NumPy integers are accepted)
    :param device: torch device (or CUDA index) on which the tensors are created
    :return: [BB, Pi, GG, divisor, LL] where
        BB      -- (LL,) float tensor of i.i.d. random signs +-1 (binary scaling matrix B)
        Pi      -- (LL,) long tensor holding a random permutation of range(LL)
        GG      -- (LL,) float tensor of i.i.d. N(0, 1) entries (Gaussian scaling matrix G)
        divisor -- scalar tensor sqrt(LL * sum(GG ** 2)), used to normalise the projection
        LL      -- the padded length (int)
    :raises ValueError: if DD < 1
    """
    DD = int(DD)
    if DD < 1:
        raise ValueError("fastfood_vars needs DD >= 1, got %d" % DD)

    # Exact integer ceil(log2(DD)). The float form ceil(log(DD) / log(2)) can land just
    # above an integer for exact powers of two (e.g. 2**29), which needlessly doubles LL.
    ll = (DD - 1).bit_length()
    LL = 2**ll

    # Binary scaling matrix where $B_{i,i} \in \{\pm 1 \}$ drawn iid
    # (uniform on [0, 2) truncated to {0, 1}, then mapped to {-1, +1})
    BB = torch.FloatTensor(LL).uniform_(0, 2).type(torch.LongTensor)
    BB = (BB * 2 - 1).type(torch.FloatTensor).to(device)
    BB.requires_grad = False

    # Random permutation matrix (stored as an index vector)
    Pi = torch.LongTensor(np.random.permutation(LL)).to(device)
    Pi.requires_grad = False

    # Gaussian scaling matrix, whose elements $G_{i,i} \sim \mathcal{N}(0, 1)$
    GG = (
        torch.FloatTensor(
            LL,
        )
        .normal_()
        .to(device)
    )
    GG.requires_grad = False

    divisor = torch.sqrt(LL * torch.sum(torch.pow(GG, 2)))

    return [BB, Pi, GG, divisor, LL]


In [3]:
def fast_walsh_hadamard_torched(x, axis=0, normalize=False):
    """
    Fast Walsh-Hadamard transform of ``x`` along ``axis`` (Sylvester ordering,
    H_{2n} = [[H_n, H_n], [H_n, -H_n]]), computed with log2(N) butterfly passes.

    :param x: tensor whose size along ``axis`` is a power of two
    :param axis: non-negative index of the axis to transform
    :param normalize: if True, divide by sqrt(N) so the transform is orthonormal
        (and therefore its own inverse)
    :return: tensor with the same shape as ``x``
    """
    orig_shape = x.size()
    assert axis >= 0 and axis < len(
        orig_shape
    ), "For a vector of shape %s, axis must be in [0, %d] but it is %d" % (
        orig_shape,
        len(orig_shape) - 1,
        axis,
    )
    h_dim = orig_shape[axis]
    # Exact integer exponent; the assert below rejects non-powers of two (including 0).
    h_dim_exp = h_dim.bit_length() - 1
    assert h_dim == 2**h_dim_exp, (
        "hadamard can only be computed over axis with size that is a power of two, but"
        " chosen axis %d has size %d" % (axis, h_dim)
    )

    # Split the transformed axis into h_dim_exp axes of size 2, keeping everything
    # before / after it collapsed into one axis each (prod of an empty shape is 1).
    working_shape_pre = [int(np.prod(orig_shape[:axis]))]
    working_shape_post = [int(np.prod(orig_shape[axis + 1 :]))]
    working_shape_mid = [2] * h_dim_exp
    working_shape = working_shape_pre + working_shape_mid + working_shape_post

    # reshape (not view) so non-contiguous inputs, e.g. transposes, are accepted
    ret = x.reshape(working_shape)

    # One butterfly (a + b, a - b) per size-2 axis.
    for ii in range(h_dim_exp):
        dim = ii + 1
        arrs = torch.chunk(ret, 2, dim=dim)
        assert len(arrs) == 2
        ret = torch.cat((arrs[0] + arrs[1], arrs[0] - arrs[1]), dim=dim)

    if normalize:
        # h_dim is a Python int; torch.sqrt only accepts tensors, so use ** 0.5
        ret = ret / (h_dim ** 0.5)

    return ret.reshape(orig_shape)


In [4]:
def fastfood_torched(x, DD, param_list=None, device=0):
    """
    Project a 1-D vector ``x`` (length dd) into ``DD`` dimensions with the Fastfood
    transform H G Pi H B x, keeping the first DD entries and rescaling them.

    :param x: 1-D tensor of length dd
    :param DD: desired output dimension
    :param param_list: [BB, Pi, GG, divisor, LL] as produced by fastfood_vars; when it
        is falsy, a fresh set is drawn with fastfood_vars(DD, device=device)
    :param device: only used when fresh Fastfood factors have to be drawn
    :return: 1-D tensor of length DD
    """
    dd = x.size(0)
    BB, Pi, GG, divisor, LL = param_list if param_list else fastfood_vars(DD, device=device)

    # Zero-pad x up to length LL.
    # NOTE(review): when dd > LL the pad amount is negative, and F.pad then trims x to its
    # first LL entries, so the remaining coordinates of x do not influence the output.
    padded = F.pad(x, pad=(0, LL - dd), value=0, mode="constant")

    # Apply the factors right to left: B, then H, then Pi, then G, then H.
    after_bh = fast_walsh_hadamard_torched(torch.mul(BB, padded), 0, normalize=False)
    after_pigh = fast_walsh_hadamard_torched(torch.mul(after_bh[Pi], GG), 0, normalize=False)

    # Keep the first DD entries and normalise.
    scale = divisor * np.sqrt(float(DD) / LL)
    return torch.div(after_pigh[:DD], scale)


In [5]:
class FastfoodWrapper(nn.Module):
    def __init__(self, module, intrinsic_dimension, device):
        """
        Wrapper to estimate the intrinsic dimensionality of the objective landscape for
        a specific task given a specific model, using the Fastfood transform as a cheap
        stand-in for a dense random projection.

        Every trainable parameter theta^{D} of ``module`` is recomputed on each forward
        pass as theta_0^{D} + Fastfood(V), where theta_0^{D} is its initial value and V
        (``self.V``, length d, initialised to zeros) is the only trainable tensor.

        :param module: pytorch nn.Module; modified in place (its trainable parameters
            are deleted and replaced by plain tensors during forward)
        :param intrinsic_dimension: dimensionality d within which we search for a solution
        :param device: device on which V, the initial values and Fastfood factors are created
        """
        super(FastfoodWrapper, self).__init__()

        # Keep the wrapped module in a plain list so nn.Module does not register it:
        # parameters() / state_dict() then expose only V.
        self.m = [module]

        # (full_name, owning_submodule, attribute_name) for every projected parameter
        self.name_base_localname = []

        # Stores the initial value: \theta_{0}^{D}
        self.initial_value = dict()

        # Fastfood factors [BB, Pi, GG, divisor, LL] per parameter name
        self.fastfood_params = {}

        # NOTE: the two dicts above are not registered buffers, so .to(...) on this
        # wrapper does not move them; they stay on `device`.

        # Parameter vector that is updated, initialised with zeros as per text: \theta^{d}
        V = nn.Parameter(torch.zeros((intrinsic_dimension)).to(device))
        self.register_parameter("V", V)

        for name, param in module.named_parameters():
            if param.requires_grad:
                # Frozen copy of the initial parameter: the 'origin' of the search
                self.initial_value[name] = v0 = (
                    param.clone().detach().requires_grad_(False).to(device)
                )

                # One independent set of Fastfood factors per parameter tensor
                DD = np.prod(v0.size())
                self.fastfood_params[name] = fastfood_vars(DD, device)

                # Resolve a dotted name such as "fcs.0.weight" to the submodule that owns
                # the tensor (module.fcs[0]) and its local attribute name ("weight").
                base, localname = module, name
                while "." in localname:
                    prefix, localname = localname.split(".", 1)
                    base = base.__getattr__(prefix)
                self.name_base_localname.append((name, base, localname))

        # Remove the original nn.Parameters so forward() can assign plain tensors under
        # the same attribute names.
        for name, base, localname in self.name_base_localname:
            delattr(base, localname)

    def train(self, mode=True):
        """
        Set training mode on the wrapper and on the hidden wrapped module.

        nn.Module.train only recurses into registered children; the wrapped module lives
        in a plain list, so without this override ``.eval()`` would leave e.g. dropout or
        batch-norm layers of the wrapped model in training mode.
        """
        super(FastfoodWrapper, self).train(mode)
        self.m[0].train(mode)
        return self

    def forward(self, x):
        # Rebuild every projected parameter from V before running the wrapped module.
        for name, base, localname in self.name_base_localname:
            init_shape = self.initial_value[name].size()
            DD = np.prod(init_shape)

            # Fastfood transform replaces the dense projection P
            ray = fastfood_torched(self.V, DD, self.fastfood_params[name]).view(
                init_shape
            )

            setattr(base, localname, self.initial_value[name] + ray)

        # Run the wrapped module (kept in the list self.m)
        return self.m[0](x)


In [6]:
class DenseWrap(nn.Module):
    def __init__(self, module, intrinsic_dimension, device):
        """
        Wrapper to estimate the intrinsic dimensionality of the objective landscape for
        a specific task given a specific model, using a dense random projection.

        Every trainable parameter theta^{D} of ``module`` is recomputed on each forward
        pass as theta_0^{D} + P theta^{d}, where theta_0^{D} is its initial value, P is a
        fixed Gaussian matrix with entries N(0, 1/d), and theta^{d} (``self.V``, shape
        (d, 1), initialised to zeros) is the only trainable tensor.

        :param module: pytorch nn.Module; modified in place (its trainable parameters
            are deleted and replaced by plain tensors during forward)
        :param intrinsic_dimension: dimensionality d within which we search for a solution
        :param device: device on which V, the initial values and the projections are created
        """
        super(DenseWrap, self).__init__()

        # Keep the wrapped module in a plain list so nn.Module does not register it:
        # parameters() / state_dict() then expose only V.
        self.m = [module]

        # (full_name, owning_submodule, attribute_name) for every projected parameter
        self.name_base_localname = []

        # Stores the initial value: \theta_{0}^{D}
        self.initial_value = dict()

        # Stores the randomly generated projection matrix P per parameter
        self.random_matrix = dict()

        # NOTE: the two dicts above are not registered buffers, so .to(...) on this
        # wrapper does not move them. Each P holds numel(param) * d floats, so total
        # memory grows as D * d.

        # Parameter vector that is updated, initialised with zeros as per text: \theta^{d}
        V = nn.Parameter(torch.zeros((intrinsic_dimension, 1)).to(device))
        self.register_parameter("V", V)
        v_size = (intrinsic_dimension,)

        for name, param in module.named_parameters():
            if param.requires_grad:
                # Frozen copy of the initial parameter: the 'origin' of the search
                self.initial_value[name] = v0 = (
                    param.clone().detach().requires_grad_(False).to(device)
                )

                # If v0.size() is [4, 3], the projection for it has size [4, 3, d]
                matrix_size = v0.size() + v_size

                # Random projection P with entries ~ N(0, 1/d), excluded from training
                self.random_matrix[name] = (
                    torch.randn(matrix_size, requires_grad=False).to(device)
                    / intrinsic_dimension**0.5
                )

                # Resolve a dotted name such as "fcs.0.weight" to the submodule that owns
                # the tensor (module.fcs[0]) and its local attribute name ("weight"), so
                # forward() can overwrite exactly that attribute.
                base, localname = module, name
                while "." in localname:
                    prefix, localname = localname.split(".", 1)
                    base = base.__getattr__(prefix)
                self.name_base_localname.append((name, base, localname))

        # Remove the original nn.Parameters so forward() can assign plain tensors under
        # the same attribute names.
        for name, base, localname in self.name_base_localname:
            delattr(base, localname)

    def train(self, mode=True):
        """
        Set training mode on the wrapper and on the hidden wrapped module.

        nn.Module.train only recurses into registered children; the wrapped module lives
        in a plain list, so without this override ``.eval()`` would leave e.g. dropout or
        batch-norm layers of the wrapped model in training mode.
        """
        super(DenseWrap, self).train(mode)
        self.m[0].train(mode)
        return self

    def forward(self, x):
        for name, base, localname in self.name_base_localname:
            # P @ theta^{d}: shape param.shape + (1,)
            ray = torch.matmul(self.random_matrix[name], self.V)

            # theta_0^{D} + P theta^{d}; squeeze drops the trailing singleton from V
            param = self.initial_value[name] + torch.squeeze(ray, -1)

            setattr(base, localname, param)

        # Run the wrapped module (kept in the list self.m)
        module = self.m[0]
        return module(x)


In [7]:
# implementing the code from the paper https://arxiv.org/abs/1804.08838 in pytorch
# Disabled reference snippet: builds a ResNet backbone with a max-pool + linear head and
# wraps it in FastfoodWrapper. It is kept inside a string literal, so it never executes;
# because the string is the cell's last expression, Jupyter echoes it as cell output.
# NOTE(review): before running it, replace the placeholder YOUR_NUMBER_OF_CLASSES and the
# unconditional .cuda() call (the rest of the notebook uses `device`).
"""class Classifier(nn.Module):
    def __init__(self, input_dim, n_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_dim, n_classes)
        self.maxpool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        x = self.maxpool(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x


def get_resnet(encoder_name, num_classes, pretrained=False):
    assert encoder_name in [
        "resnet18",
        "resnet50",
    ], "{} is a wrong encoder name!".format(encoder_name)
    if encoder_name == "resnet18":
        model = models.resnet18(pretrained=pretrained)
        latent_dim = 512
    else:
        model = models.resnet50(pretrained=pretrained)
        latent_dim = 2048
    children = list(model.children())[:-2] + [Classifier(latent_dim, num_classes)]
    model = torch.nn.Sequential(*children)
    return model


# Get model and wrap it in fastfood
model = get_resnet("resnet18", num_classes=YOUR_NUMBER_OF_CLASSES).cuda()
model = FastfoodWrapper(model, intrinsic_dimension=100, device=device)"""


'class Classifier(nn.Module):\n    def __init__(self, input_dim, n_classes):\n        super(Classifier, self).__init__()\n        self.fc = nn.Linear(input_dim, n_classes)\n        self.maxpool = nn.AdaptiveMaxPool2d(1)\n\n    def forward(self, x):\n        x = self.maxpool(x)\n        x = x.reshape(x.size(0), -1)\n        x = self.fc(x)\n        return x\n\n\ndef get_resnet(encoder_name, num_classes, pretrained=False):\n    assert encoder_name in [\n        "resnet18",\n        "resnet50",\n    ], "{} is a wrong encoder name!".format(encoder_name)\n    if encoder_name == "resnet18":\n        model = models.resnet18(pretrained=pretrained)\n        latent_dim = 512\n    else:\n        model = models.resnet50(pretrained=pretrained)\n        latent_dim = 2048\n    children = list(model.children())[:-2] + [Classifier(latent_dim, num_classes)]\n    model = torch.nn.Sequential(*children)\n    return model\n\n\n# Get model and wrap it in fastfood\nmodel = get_resnet("resnet18", num_classes=YO

In [8]:
BATCH_SIZE = 128
DATASET_NAME = "MNIST"

# Only ToTensor is applied: no normalisation or augmentation.
img_transform = transforms.Compose(
    [transforms.ToTensor()]
)

# Download (if needed) and load the train/test splits of the selected dataset.
train_dataset = None
test_dataset = None
if DATASET_NAME == "MNIST":
    train_dataset = MNIST(
        root="./data/MNIST", download=True, train=True, transform=img_transform
    )
    test_dataset = MNIST(
        root="./data/MNIST", download=True, train=False, transform=img_transform
    )
elif DATASET_NAME == "CIFAR10":
    train_dataset = CIFAR10(
        root="./data/CIFAR10", download=True, train=True, transform=img_transform
    )
    test_dataset = CIFAR10(
        root="./data/CIFAR10", download=True, train=False, transform=img_transform
    )
else:
    # ValueError (a subclass of Exception) signals a bad configuration value.
    raise ValueError("Name of dataset not in: [MNIST, CIFAR10]")



In [9]:
# Input geometry per dataset: MNIST is 1 x 28 x 28, CIFAR10 is 3 x 32 x 32; both have 10 classes.
# `idx` selects which training image to preview (Colab slider form field).
if DATASET_NAME == "MNIST":
    channel_in, input_dim = 1, 28 * 28 * 1
    idx = 0 # @param {type:"slider", min:0, max:59999, step:1}
else:
    channel_in, input_dim = 3, 32 * 32 * 3
    idx = 0 # @param {type:"slider", min:0, max:49999, step:1}
output_dim = 10

px.imshow(train_dataset.data[idx])

In [10]:
# Fully connected baseline network (MLP).
class FullyConnectedNetwork(nn.Module):
    """
    MLP: input projection, ``num_layers`` hidden layers of width ``hidden_dim``, then an
    output projection. ReLU follows every layer except the last; inputs are flattened
    from dim 1 onward, so image batches can be passed directly.

    :param input_dim: number of features after flattening
    :param hidden_dim: width of every hidden layer
    :param output_dim: number of output logits
    :param num_layers: number of hidden (hidden_dim -> hidden_dim) layers
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        hidden_layers = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        self.fcs = nn.ModuleList(hidden_layers)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = F.relu(self.fc_in(x.flatten(start_dim=1)))
        for layer in self.fcs:
            hidden = F.relu(layer(hidden))
        return self.fc_out(hidden)


In [5]:
# Standard LeNet (reference: http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf),
# modified to match the parameter budget used for this task in the main paper: the
# 16*5*5 -> 120 dense layer (too many parameters here) is replaced by a 4x4 convolution
# with 120 filters.
class LeNet(nn.Module):
    """
    conv(5x5, 6) -> pool -> conv(5x5, 16) -> pool -> conv(4x4, 120) -> fc(84) -> fc(out),
    with ReLU after every hidden layer. Spatial sizes are arranged for 28x28 inputs, for
    which conv3 produces a 1x1 map, i.e. exactly 120 features for fc2.

    :param input_dim: number of input channels (not a flattened feature count)
    :param output_dim: number of classes
    """

    def __init__(self, input_dim, output_dim):
        super(LeNet, self).__init__()
        # 28x28 -> 24x24 with 6 maps, pooled to 12x12
        self.conv1 = nn.Conv2d(input_dim, 6, 5, padding='valid')
        self.pool1 = nn.MaxPool2d(2, stride=2)
        # 12x12 -> 8x8 with 16 maps, pooled to 4x4
        self.conv2 = nn.Conv2d(6, 16, 5, padding='valid')
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # 4x4 -> 1x1 with 120 maps, standing in for the 120-unit dense layer
        self.conv3 = nn.Conv2d(16, 120, 4)
        self.flat = nn.Flatten(start_dim=1)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, output_dim)

    def forward(self, x):
        features = x
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            features = pool(F.relu(conv(features)))
        features = self.flat(F.relu(self.conv3(features)))
        hidden = F.relu(self.fc2(features))
        return self.fc3(hidden)


In [7]:
from torch.nn.modules.utils import _pair

# from https://discuss.pytorch.org/t/locally-connected-layers/26979
class LocallyConnected2d(nn.Module):
    """
    2-D locally connected layer: like an unpadded Conv2d, but with a separate (untied)
    set of weights and bias for every output position.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param output_size: (H_out, W_out) or int; must equal the spatial size obtained by
        sliding a kernel_size window with ``stride`` over the input (no padding)
    :param kernel_size: int or (kh, kw)
    :param stride: int or (sh, sw)
    :param bias: whether to add a learned per-position bias
    """

    def __init__(self, in_channels, out_channels, output_size, kernel_size, stride=1, bias=True):
        super(LocallyConnected2d, self).__init__()
        output_size = _pair(output_size)
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        kh, kw = self.kernel_size

        # He (Kaiming) normal init using the fan-in of one output unit: in_channels*kh*kw.
        # nn.init.kaiming_normal_ is not applied to this 6-D tensor because it would take
        # size(1) (= out_channels) times every trailing dim as the fan-in, shrinking the
        # weights by orders of magnitude.
        std = (2.0 / (in_channels * kh * kw)) ** 0.5
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, output_size[0], output_size[1], kh * kw) * std
        )
        if bias:
            # Bias initialisation kept as in the referenced forum implementation.
            self.bias = nn.Parameter(nn.init.kaiming_normal_(
                torch.randn(1, out_channels, output_size[0], output_size[1]), nonlinearity='relu')
            )
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        kh, kw = self.kernel_size
        dh, dw = self.stride
        # (B, C, H, W) -> (B, C, H_out, W_out, kh, kw) -> (B, C, H_out, W_out, kh*kw)
        patches = x.unfold(2, kh, dh).unfold(3, kw, dw)
        patches = patches.contiguous().view(*patches.size()[:-2], -1)
        # Broadcast against weight (1, O, C, H_out, W_out, K); sum over C and K.
        out = (patches.unsqueeze(1) * self.weight).sum([2, -1])
        if self.bias is not None:
            out = out + self.bias
        return out

# Untied LeNet: LeNet's topology with every convolution replaced by a LocallyConnected2d,
# so filter weights are not shared across spatial positions.
class Untied_LeNet(nn.Module):
    """
    :param input_dim: number of input channels
    :param output_dim: number of classes

    The hard-coded output sizes (24x24, 8x8, 1x1) correspond to 28x28 inputs.
    """

    def __init__(self, input_dim, output_dim):
        super(Untied_LeNet, self).__init__()
        # 6 untied 5x5 filters: 28x28 -> 24x24, pooled to 12x12
        self.conv1 = LocallyConnected2d(input_dim, 6, (24, 24), 5)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        # 16 untied 5x5 filters: 12x12 -> 8x8, pooled to 4x4
        self.conv2 = LocallyConnected2d(6, 16, (8, 8), 5)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # 120 untied 4x4 filters: 4x4 -> 1x1, in place of LeNet's 120-unit dense layer
        self.conv3 = LocallyConnected2d(16, 120, (1, 1), 4)
        self.flat = nn.Flatten(start_dim=1)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, output_dim)

    def forward(self, x):
        for layer, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            x = pool(F.relu(layer(x)))
        x = self.flat(F.relu(self.conv3(x)))
        return self.fc3(F.relu(self.fc2(x)))

In [30]:
# FC-LeNet: LeNet's layout with every convolution replaced by a dense layer whose output
# size equals the feature map of the convolution it replaces.
class FcLeNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        """
        :param input_dim: number of input features after flattening (e.g. 784 for 1x28x28)
        :param output_dim: number of classes
        """
        super(FcLeNet, self).__init__()
        # dense stand-in for 6 kernels 5x5: 6 x 24 x 24 = 3456 outputs
        self.fcconv1 = nn.Linear(input_dim, 3456)
        # max-pooling over 2x2 -> 6 x 12 x 12 = 864
        self.pool1 = nn.MaxPool2d(2, stride=2)
        # dense stand-in for 16 kernels 5x5: 16 x 8 x 8 = 1024 outputs
        self.fcconv2 = nn.Linear(864, 1024)
        # max-pooling over 2x2 -> 16 x 4 x 4 = 256
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # dense stand-in for 120 kernels 4x4
        self.fcconv3 = nn.Linear(256, 120)
        self.flat = nn.Flatten(start_dim=1)
        # 84 fully connected neurons
        self.fc2 = nn.Linear(120, 84)
        # output layer
        self.fc3 = nn.Linear(84, output_dim,)

    def forward(self, x):
        # DataLoader batches arrive as (B, C, H, W), and nn.Linear acts on the last dim
        # only, so flatten first. This is a no-op for inputs already shaped (B, input_dim).
        x = self.flat(x)
        x = self.pool1(F.relu(self.fcconv1(x)).view(-1, 6, 24, 24))
        x = self.flat(x)
        x = self.pool2(F.relu(self.fcconv2(x)).view(-1, 16, 8, 8))
        x = self.flat(x)
        x = F.relu(self.fcconv3(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [22]:
# FCTied-LeNet: convolutions whose kernels span the whole input map ('same' padding with
# kernel size 2*H-1), so every output position sees the full input while the weights are
# still shared across positions.
class FCTied_LeNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        """
        :param input_dim: number of input channels
        :param output_dim: number of classes

        Sized for 28x28 inputs (forward reshapes the input to C x 28 x 28).
        """
        super(FCTied_LeNet, self).__init__()
        # 6 kernels 55x55, 'same' padding: 28x28 -> 28x28 (55 = 2*28 - 1 covers every offset)
        self.conv1 = nn.Conv2d(input_dim, 6, 55, padding='same',)
        # max-pooling over 2x2 -> 14x14
        self.pool1 = nn.MaxPool2d(2, stride=2)
        # 16 kernels 27x27, 'same' padding: 14x14 -> 14x14 (27 = 2*14 - 1)
        self.conv2 = nn.Conv2d(6, 16, 27, padding='same')
        # max-pooling over 2x2 -> 7x7
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # 120 kernels 7x7: 7x7 -> 1x1
        self.conv3 = nn.Conv2d(16, 120, 7,)
        self.flat = nn.Flatten(start_dim=1)
        # 84 fully connected neurons
        self.fc2 = nn.Linear(120, 84)
        # output layer
        self.fc3 = nn.Linear(84, output_dim,)

    def forward(self, x):
        # Accept flat (B, C*28*28) or image (B, C, 28, 28) input, using the channel count
        # the network was built with rather than a hard-coded 1. reshape also tolerates
        # non-contiguous inputs.
        x = x.reshape(-1, self.conv1.in_channels, 28, 28)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.flat(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [11]:
# Fully connected baseline: input_dim -> hidden_dim -> num_layers x (hidden_dim -> hidden_dim) -> output_dim
hidden_dim = 400
num_layers = 2
model = FullyConnectedNetwork(input_dim, hidden_dim, output_dim, num_layers)
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: ", num_params)

# Append this configuration to the plain-text run log (the `with` block closes the file).
with open('results.txt', 'a') as f:
    f.write("\n##############################################")
    f.write(f"\nNumber_of_parameters: {num_params}")
    f.write(f"\nhidden_dim: {hidden_dim}")
    f.write(f"\nnum_layers: {num_layers}")


Number of parameters:  638810


In [12]:
# model.modules() yields the model itself first, then every nested submodule.
modules = list(model.modules())
# Print Model Summary: the root module's repr shows the whole layer tree.
print(modules[0])

FullyConnectedNetwork(
  (fc_in): Linear(in_features=784, out_features=400, bias=True)
  (fcs): ModuleList(
    (0): Linear(in_features=400, out_features=400, bias=True)
    (1): Linear(in_features=400, out_features=400, bias=True)
  )
  (fc_out): Linear(in_features=400, out_features=10, bias=True)
)


In [6]:
# LeNet baseline. Pass the dataset's channel count instead of a hard-coded 1 so this cell
# agrees with the Untied/FCTied cells (identical for MNIST, where channel_in == 1).
model = LeNet(channel_in, output_dim)
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: %d" % num_params)

Number of parameters: 44426


In [9]:
# Untied (locally connected) LeNet; Module.to returns the module itself.
model = Untied_LeNet(channel_in, output_dim).to(device)
num_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("Number of parameters: %d" % num_params)

Number of parameters: 286334


In [31]:
# FC-LeNet takes the flattened feature count (input_dim), not the channel count.
model = FcLeNet(input_dim, output_dim).to(device)
num_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("Number of parameters: %d" % num_params)


Number of parameters: 3640574


In [23]:
# FCTied-LeNet: whole-image kernels with weight sharing across positions.
model = FCTied_LeNet(channel_in, output_dim).to(device)
num_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("Number of parameters: %d" % num_params)

Number of parameters: 193370


In [12]:
# Re-parameterise `model` so that only an intrinsic_dim-dimensional vector is trained.
intrinsic_dim = 750
# DenseWrap deletes the wrapped model's parameters, so re-running this cell without first
# re-creating `model` would silently wrap a network with nothing left to project.
assert any(p.requires_grad for p in model.parameters()), "re-create `model` before wrapping it again"
model_intrinsic = DenseWrap(model, intrinsic_dimension=intrinsic_dim, device=device)
num_params_intrinsic = sum(p.numel() for p in model_intrinsic.parameters() if p.requires_grad)
print("Number of parameters: %d" % num_params_intrinsic)
# Append to the plain-text run log (the `with` block closes the file).
with open('results.txt', 'a') as f:
    f.write(f"\nintrinsic_dim: {intrinsic_dim}")

Number of parameters: 750


In [13]:
# Autograd anomaly detection adds checks to every backward pass and slows training
# considerably; it is a debugging aid, so keep it off unless chasing a NaN/in-place error.
torch.autograd.set_detect_anomaly(False)
# train the model


# training step
def train(model, train_loader, optimizer, epoch):
    """
    Run one epoch of optimisation over ``train_loader`` with a live tqdm progress bar.

    Batches are moved to the notebook-global ``device``; the loss is cross-entropy.

    :param model: network returning class logits
    :param train_loader: DataLoader yielding ``(data, target)`` batches
    :param optimizer: optimizer over the trainable parameters of ``model``
    :param epoch: epoch number, only used in the progress-bar text
    """
    model.train()

    progress = tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        desc=f"batch [loss: None]",
        leave=False,
    )
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)

    for step, (inputs, labels) in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)
        loss = F.cross_entropy(model(inputs), labels)

        model.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        percent = 100.0 * step / n_batches
        progress.set_description(
            f"Train Epoch: [{epoch} {step * len(inputs)}/{n_samples} ({percent:.0f}%)\tLoss: {loss.item():.6f}]"
        )
        progress.refresh()  # show the update immediately
    progress.close()


# testing step
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target).item()  # sum up batch loss
            # get the index of the max probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).cpu().sum().item()

    test_loss /= len(test_loader.dataset)
    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset),
        )
    )
    
    # show an histogram of the weights of the model
    """start = -1
    stop = 1
    bins = 30
    for param in model.parameters():
        if param.requires_grad:
            
            hist = torch.histc(param.data, bins = bins, min = start, max = stop)
            x = np.arange(start, stop, (stop-start)/bins)
            plt.bar(x, hist.cpu(), align='center')
            plt.ylabel('Frequency')
            plt.show() """

    return correct / len(test_loader.dataset)

if __name__ == "__main__":
    optimizer = optim.SGD(model_intrinsic.parameters(), lr=0.09)
    # Loaders over the datasets created earlier (MNIST or CIFAR10, per DATASET_NAME).
    # Shuffling the test set does not change accuracy; it is kept as-is so runs stay
    # comparable with the previously recorded results.
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)

    for epoch in range(1, 101):
        train(model_intrinsic, train_dataloader, optimizer, epoch)
        accuracy = test(model_intrinsic, test_dataloader)
        print("Validation Accuracy: {}".format(accuracy))
        # Append per-epoch results to the plain-text run log.
        with open('results.txt', 'a') as f:
            f.write(f"\nEpoch: {epoch}")
            f.write(f"\nValidation Accuracy: {accuracy}")

        # Stop at the first epoch above 90% test accuracy and save the wrapper's state_dict.
        if accuracy > 0.90:
            torch.save(model_intrinsic.state_dict(), f"model_best_h{hidden_dim}_id{intrinsic_dim}_lay{num_layers}.pt")
            break

    # Build the summary entry BEFORE opening results.json for writing: mode "w" truncates
    # the file, so an error while building the entry (e.g. hidden_dim undefined for a
    # non-FC model) would otherwise wipe every previously stored result.
    run_key = f"fcmodel_h{hidden_dim}_id{intrinsic_dim}_lay{num_layers}"
    run_entry = {"number_parameter": num_params, "hidden_dimension": hidden_dim, "number_layers": num_layers, "intrinsic_dimension": intrinsic_dim, "epoch": epoch, "validation_accuracy": accuracy}

    try:
        with open('results.json', 'r') as f:
            j = json.load(f)
    except FileNotFoundError:
        # First run: no results file yet.
        j = {}
    j[run_key] = run_entry
    with open('results.json', 'w') as f:
        json.dump(j, f, indent=4, separators=(',', ': '))


                                                                                                     


Test set: Average loss: 0.0047, Accuracy: 8097/10000 (81%)

Validation Accuracy: 0.8097


                                                                                                     


Test set: Average loss: 0.0103, Accuracy: 7290/10000 (73%)

Validation Accuracy: 0.729


                                                                                                     


Test set: Average loss: 0.0062, Accuracy: 7597/10000 (76%)

Validation Accuracy: 0.7597


                                                                                                     


Test set: Average loss: 0.0043, Accuracy: 8255/10000 (83%)

Validation Accuracy: 0.8255


                                                                                                     


Test set: Average loss: 0.0035, Accuracy: 8590/10000 (86%)

Validation Accuracy: 0.859


                                                                                                     


Test set: Average loss: 0.0142, Accuracy: 6047/10000 (60%)

Validation Accuracy: 0.6047


                                                                                                     


Test set: Average loss: 0.0050, Accuracy: 8046/10000 (80%)

Validation Accuracy: 0.8046


                                                                                                     


Test set: Average loss: 0.0033, Accuracy: 8712/10000 (87%)

Validation Accuracy: 0.8712


                                                                                                     


Test set: Average loss: 0.0063, Accuracy: 7644/10000 (76%)

Validation Accuracy: 0.7644


                                                                                                      


Test set: Average loss: 0.0051, Accuracy: 7942/10000 (79%)

Validation Accuracy: 0.7942


                                                                                                      


Test set: Average loss: 0.0110, Accuracy: 7040/10000 (70%)

Validation Accuracy: 0.704


                                                                                                      


Test set: Average loss: 0.0044, Accuracy: 8265/10000 (83%)

Validation Accuracy: 0.8265


                                                                                                      


Test set: Average loss: 0.0054, Accuracy: 7999/10000 (80%)

Validation Accuracy: 0.7999


                                                                                                      


Test set: Average loss: 0.0046, Accuracy: 8250/10000 (82%)

Validation Accuracy: 0.825


                                                                                                      


Test set: Average loss: 0.0043, Accuracy: 8215/10000 (82%)

Validation Accuracy: 0.8215


                                                                                                      


Test set: Average loss: 0.0037, Accuracy: 8510/10000 (85%)

Validation Accuracy: 0.851


                                                                                                      


Test set: Average loss: 0.0091, Accuracy: 7190/10000 (72%)

Validation Accuracy: 0.719


                                                                                                      


Test set: Average loss: 0.0044, Accuracy: 8322/10000 (83%)

Validation Accuracy: 0.8322


                                                                                                      


Test set: Average loss: 0.0037, Accuracy: 8550/10000 (86%)

Validation Accuracy: 0.855


                                                                                                      


Test set: Average loss: 0.0036, Accuracy: 8599/10000 (86%)

Validation Accuracy: 0.8599


                                                                                                      


Test set: Average loss: 0.0032, Accuracy: 8765/10000 (88%)

Validation Accuracy: 0.8765


                                                                                                      


Test set: Average loss: 0.0052, Accuracy: 8035/10000 (80%)

Validation Accuracy: 0.8035


                                                                                                      


Test set: Average loss: 0.0041, Accuracy: 8335/10000 (83%)

Validation Accuracy: 0.8335


                                                                                                      


Test set: Average loss: 0.0031, Accuracy: 8741/10000 (87%)

Validation Accuracy: 0.8741


                                                                                                      


Test set: Average loss: 0.0035, Accuracy: 8629/10000 (86%)

Validation Accuracy: 0.8629


                                                                                                      


Test set: Average loss: 0.0058, Accuracy: 7600/10000 (76%)

Validation Accuracy: 0.76


                                                                                                      


Test set: Average loss: 0.0054, Accuracy: 8047/10000 (80%)

Validation Accuracy: 0.8047


                                                                                                      


Test set: Average loss: 0.0038, Accuracy: 8549/10000 (85%)

Validation Accuracy: 0.8549


                                                                                                      


Test set: Average loss: 0.0034, Accuracy: 8650/10000 (86%)

Validation Accuracy: 0.865


                                                                                                      


Test set: Average loss: 0.0039, Accuracy: 8452/10000 (85%)

Validation Accuracy: 0.8452


                                                                                                      


Test set: Average loss: 0.0055, Accuracy: 8045/10000 (80%)

Validation Accuracy: 0.8045


                                                                                                      


Test set: Average loss: 0.0038, Accuracy: 8507/10000 (85%)

Validation Accuracy: 0.8507


                                                                                                      


Test set: Average loss: 0.0037, Accuracy: 8549/10000 (85%)

Validation Accuracy: 0.8549


                                                                                                      


Test set: Average loss: 0.0042, Accuracy: 8341/10000 (83%)

Validation Accuracy: 0.8341


                                                                                                      


Test set: Average loss: 0.0051, Accuracy: 7962/10000 (80%)

Validation Accuracy: 0.7962


                                                                                                      


Test set: Average loss: 0.0031, Accuracy: 8748/10000 (87%)

Validation Accuracy: 0.8748


                                                                                                      


Test set: Average loss: 0.0038, Accuracy: 8484/10000 (85%)

Validation Accuracy: 0.8484


                                                                                                      


Test set: Average loss: 0.0073, Accuracy: 7509/10000 (75%)

Validation Accuracy: 0.7509


                                                                                                      


Test set: Average loss: 0.0041, Accuracy: 8277/10000 (83%)

Validation Accuracy: 0.8277


                                                                                                      


Test set: Average loss: 0.0033, Accuracy: 8689/10000 (87%)

Validation Accuracy: 0.8689


                                                                                                      


Test set: Average loss: 0.0039, Accuracy: 8478/10000 (85%)

Validation Accuracy: 0.8478


                                                                                                      


Test set: Average loss: 0.0037, Accuracy: 8592/10000 (86%)

Validation Accuracy: 0.8592


                                                                                                      


Test set: Average loss: 0.0050, Accuracy: 8072/10000 (81%)

Validation Accuracy: 0.8072


                                                                                                      


Test set: Average loss: 0.0046, Accuracy: 8232/10000 (82%)

Validation Accuracy: 0.8232


                                                                                                      


Test set: Average loss: 0.0051, Accuracy: 8088/10000 (81%)

Validation Accuracy: 0.8088


                                                                                                      


Test set: Average loss: 0.0038, Accuracy: 8416/10000 (84%)

Validation Accuracy: 0.8416


                                                                                                      


Test set: Average loss: 0.0078, Accuracy: 7193/10000 (72%)

Validation Accuracy: 0.7193


                                                                                                      


Test set: Average loss: 0.0050, Accuracy: 8102/10000 (81%)

Validation Accuracy: 0.8102


                                                                                                      


Test set: Average loss: 0.0040, Accuracy: 8489/10000 (85%)

Validation Accuracy: 0.8489


                                                                                                      


Test set: Average loss: 0.0037, Accuracy: 8551/10000 (86%)

Validation Accuracy: 0.8551


                                                                                                      


Test set: Average loss: 0.0042, Accuracy: 8424/10000 (84%)

Validation Accuracy: 0.8424


                                                                                                      


Test set: Average loss: 0.0033, Accuracy: 8723/10000 (87%)

Validation Accuracy: 0.8723


                                                                                                      


Test set: Average loss: 0.0032, Accuracy: 8720/10000 (87%)

Validation Accuracy: 0.872


                                                                                                      


Test set: Average loss: 0.0034, Accuracy: 8699/10000 (87%)

Validation Accuracy: 0.8699


                                                                                                      


Test set: Average loss: 0.0036, Accuracy: 8604/10000 (86%)

Validation Accuracy: 0.8604


                                                                                                      


Test set: Average loss: 0.0035, Accuracy: 8591/10000 (86%)

Validation Accuracy: 0.8591


                                                                                                      


Test set: Average loss: 0.0041, Accuracy: 8379/10000 (84%)

Validation Accuracy: 0.8379


                                                                                                      


Test set: Average loss: 0.0066, Accuracy: 7317/10000 (73%)

Validation Accuracy: 0.7317


                                                                                                      


Test set: Average loss: 0.0038, Accuracy: 8513/10000 (85%)

Validation Accuracy: 0.8513


                                                                                                      


Test set: Average loss: 0.0062, Accuracy: 7743/10000 (77%)

Validation Accuracy: 0.7743


                                                                                                      


Test set: Average loss: 0.0033, Accuracy: 8781/10000 (88%)

Validation Accuracy: 0.8781


                                                                                                      


Test set: Average loss: 0.0042, Accuracy: 8375/10000 (84%)

Validation Accuracy: 0.8375


                                                                                                      


Test set: Average loss: 0.0034, Accuracy: 8659/10000 (87%)

Validation Accuracy: 0.8659


                                                                                                      


Test set: Average loss: 0.0074, Accuracy: 7277/10000 (73%)

Validation Accuracy: 0.7277


                                                                                                      


Test set: Average loss: 0.0035, Accuracy: 8560/10000 (86%)

Validation Accuracy: 0.856


                                                                                                      


Test set: Average loss: 0.0055, Accuracy: 7935/10000 (79%)

Validation Accuracy: 0.7935


                                                                                                      


Test set: Average loss: 0.0044, Accuracy: 8344/10000 (83%)

Validation Accuracy: 0.8344


                                                                                                      


Test set: Average loss: 0.0037, Accuracy: 8577/10000 (86%)

Validation Accuracy: 0.8577


                                                                                                      


Test set: Average loss: 0.0034, Accuracy: 8713/10000 (87%)

Validation Accuracy: 0.8713


                                                                                                      


Test set: Average loss: 0.0032, Accuracy: 8787/10000 (88%)

Validation Accuracy: 0.8787


                                                                                                      


Test set: Average loss: 0.0037, Accuracy: 8631/10000 (86%)

Validation Accuracy: 0.8631


                                                                                                      


Test set: Average loss: 0.0036, Accuracy: 8478/10000 (85%)

Validation Accuracy: 0.8478


                                                                                                      


Test set: Average loss: 0.0033, Accuracy: 8727/10000 (87%)

Validation Accuracy: 0.8727


                                                                                                      


Test set: Average loss: 0.0033, Accuracy: 8769/10000 (88%)

Validation Accuracy: 0.8769


                                                                                                      


Test set: Average loss: 0.0060, Accuracy: 7792/10000 (78%)

Validation Accuracy: 0.7792


                                                                                                      


Test set: Average loss: 0.0039, Accuracy: 8447/10000 (84%)

Validation Accuracy: 0.8447


                                                                                                      


Test set: Average loss: 0.0041, Accuracy: 8288/10000 (83%)

Validation Accuracy: 0.8288


                                                                                                      


Test set: Average loss: 0.0069, Accuracy: 7492/10000 (75%)

Validation Accuracy: 0.7492


                                                                                                      


Test set: Average loss: 0.0034, Accuracy: 8647/10000 (86%)

Validation Accuracy: 0.8647


                                                                                                      


Test set: Average loss: 0.0050, Accuracy: 8050/10000 (80%)

Validation Accuracy: 0.805


                                                                                                      


Test set: Average loss: 0.0042, Accuracy: 8417/10000 (84%)

Validation Accuracy: 0.8417


                                                                                                      


Test set: Average loss: 0.0035, Accuracy: 8614/10000 (86%)

Validation Accuracy: 0.8614


                                                                                                      


Test set: Average loss: 0.0032, Accuracy: 8732/10000 (87%)

Validation Accuracy: 0.8732


                                                                                                      


Test set: Average loss: 0.0033, Accuracy: 8753/10000 (88%)

Validation Accuracy: 0.8753


                                                                                                      


Test set: Average loss: 0.0037, Accuracy: 8545/10000 (85%)

Validation Accuracy: 0.8545


                                                                                                      


Test set: Average loss: 0.0077, Accuracy: 7367/10000 (74%)

Validation Accuracy: 0.7367


                                                                                                      


Test set: Average loss: 0.0043, Accuracy: 8277/10000 (83%)

Validation Accuracy: 0.8277


                                                                                                      


Test set: Average loss: 0.0034, Accuracy: 8751/10000 (88%)

Validation Accuracy: 0.8751


                                                                                                      


Test set: Average loss: 0.0034, Accuracy: 8687/10000 (87%)

Validation Accuracy: 0.8687


                                                                                                      


Test set: Average loss: 0.0063, Accuracy: 7517/10000 (75%)

Validation Accuracy: 0.7517


                                                                                                      


Test set: Average loss: 0.0044, Accuracy: 8226/10000 (82%)

Validation Accuracy: 0.8226


                                                                                                      


Test set: Average loss: 0.0050, Accuracy: 8097/10000 (81%)

Validation Accuracy: 0.8097


                                                                                                      


Test set: Average loss: 0.0042, Accuracy: 8420/10000 (84%)

Validation Accuracy: 0.842


                                                                                                      


Test set: Average loss: 0.0037, Accuracy: 8597/10000 (86%)

Validation Accuracy: 0.8597


                                                                                                      


Test set: Average loss: 0.0044, Accuracy: 8298/10000 (83%)

Validation Accuracy: 0.8298


                                                                                                      


Test set: Average loss: 0.0075, Accuracy: 7479/10000 (75%)

Validation Accuracy: 0.7479


                                                                                                      


Test set: Average loss: 0.0037, Accuracy: 8582/10000 (86%)

Validation Accuracy: 0.8582


                                                                                                      


Test set: Average loss: 0.0032, Accuracy: 8781/10000 (88%)

Validation Accuracy: 0.8781


                                                                                                      


Test set: Average loss: 0.0122, Accuracy: 7144/10000 (71%)

Validation Accuracy: 0.7144


                                                                                                       


Test set: Average loss: 0.0061, Accuracy: 7806/10000 (78%)

Validation Accuracy: 0.7806


In [60]:
# Bubble chart comparing networks of different hidden-layer widths.
# NOTE(review): every trace below uses the same placeholder x/y/color/error
# values, so the four traces render exactly on top of each other. Replace them
# with the real per-width results (and add axis titles describing them) before
# drawing conclusions from this figure.
PLACEHOLDER_X = [1, 3.2, 5.4, 7.6, 9.8]
PLACEHOLDER_Y = [1, 3.2, 5.4, 7.6, 9.8]
PLACEHOLDER_COLOR = [1, 2, 3, 4, 5]
PLACEHOLDER_ERROR = [1, 2, 3, 4, 5]

# (layer width shown in the legend, marker size in px) for each trace.
LAYER_WIDTH_MARKER_SIZES = [(50, 30), (100, 55), (200, 70), (400, 90)]

fig = go.Figure()

for trace_idx, (layer_width, marker_size) in enumerate(LAYER_WIDTH_MARKER_SIZES):
    fig.add_trace(go.Scatter(
        x=PLACEHOLDER_X,
        y=PLACEHOLDER_Y,
        mode='markers',
        marker=dict(
            color=PLACEHOLDER_COLOR,
            size=marker_size,
            # All traces share identical color values, so one colorbar covers
            # them all; enabling it on every trace stacks four identical bars
            # at the same default position.
            showscale=(trace_idx == 0),
        ),
        error_y=dict(
            type='data',  # value of error bar given in data coordinates
            array=PLACEHOLDER_ERROR,
            visible=True,
        ),
        name=f'layer width: {layer_width}',
    ))

# Pin the legend inside the top-left corner of the plot area.
fig.update_layout(
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)

fig.show()