In [33]:
import torch
import functorch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from tqdm import tqdm
%matplotlib inline

In [5]:
import os
os.getcwd()

'c:\\Users\\aresf\\Desktop\\Code\\Pytorch_hypernets\\notebooks'

In [6]:
# Important for using GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


cuda:0


In [29]:
# hyperparams
# input_size = 784  # 28x28
# synthnet_hidden_size = 100
# num_classes = 10
num_epochs = 20
batch_size = 100
learning_rate = 0.001

In [8]:
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/mnist_data', 
                                                          download=True, 
                                                          train=True,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                          ])), 
                                           batch_size=batch_size, 
                                           shuffle=True)

# download and transform test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/mnist_data', 
                                                          download=True, 
                                                          train=False,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                          ])), 
                                           batch_size=batch_size, 
                                           shuffle=True)

In [47]:
class HyperNetwork(torch.nn.Module):
    def __init__(self, hypnet: torch.nn.Module, synthnet: torch.nn.Module):
        # hypnet is the network that takes an embedding z and produces the parameters of synthnet
        # synthnet is the network that takes input x and produces h
        super().__init__()
        s_func, s_params0 = functorch.make_functional(synthnet)

        # store the information about the parameters
        self._sp_shapes = [sp.shape for sp in s_params0]

        # These are the index offsets for each parameter (e.g. set of weights between
        # each layer and set of biases for each layer)
        self._sp_offsets = np.array([0, *np.cumsum([sp.numel() for sp in s_params0])])

        # make the synthnet_func to accept batched parameters
        synthnet_func = functorch.vmap(s_func)
        # a workaround of functorch's bug #793
        # self._synthnet_batched_func = [synthnet_func]
        self._synthnet_batched_func = synthnet_func
        self._hypnet = hypnet

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, nx), z: (batch_size, nz)
        params = self._hypnet(z)  # params: (batch_size, nparams_tot)

        # rearrange params to have the same shape as the synthnet params, except on the batch dimension
        # print(f"self._sp_offsets: {self._sp_offsets}")
        params_lst = []
        for i, shape in enumerate(self._sp_shapes):
            j0, j1 = self._sp_offsets[i], self._sp_offsets[i + 1]
            params_lst.append(params[..., j0:j1].reshape(-1, *shape))

        # print(f"params_lst: {params_lst}")

        # apply the function to the batched parameters and x
        h = self._synthnet_batched_func(params_lst, x)
        h = F.log_softmax(h, dim=1)
        return h
        # return params_lst

In [113]:
class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Building an linear encoder with Linear
        # layer followed by Relu activation function
        # 784 ==> 9
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 64)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        return encoded

In [66]:
# np.array([0, *np.cumsum([sp.numel() for sp in a])])

In [112]:
synthnet = torch.nn.Sequential(
    torch.nn.Linear(28*28, 64),
    torch.nn.ELU(),
    # torch.nn.Linear(64, 64),
    # torch.nn.ELU(),
    torch.nn.Linear(64, 10),
).to(device)


In [102]:
_, synthnet_params = functorch.make_functional(synthnet)
n_synthnet_params = sum([p.numel() for p in synthnet_params])

In [89]:
# Ben's Code - Just me peeking inside the synthnet_params to understand what PyTorch Module params actually are under the hood. synthnet is just a nn with 1 hidden layer.
print(len(synthnet_params))
print(synthnet_params[0].size())  # The weights for the connections from input layer (784 units) to hidden layer (100 units)
print(synthnet_params[1].size())  # The biases for each of the 100 units in the hidden layer???
print(synthnet_params[2].size())  # The weights for the connections from the hidden layer (100 units) to the output layer (10 units)
print(synthnet_params[3].size())  # The biases for each of the 10 units in the output layer???
print(n_synthnet_params)

6
torch.Size([64, 784])
torch.Size([64])
torch.Size([64, 64])
torch.Size([64])
55050


In [114]:
#! todo 
hypnet = torch.nn.Sequential(
    torch.nn.Linear(64, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 100),
    torch.nn.LayerNorm(100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 100),
    torch.nn.LayerNorm(100),
    torch.nn.ReLU(),
    # torch.nn.Linear(100, 100),
    # torch.nn.LayerNorm(100),
    # torch.nn.ReLU(),
    torch.nn.Linear(100, n_synthnet_params),
).to(device)

In [115]:
module = HyperNetwork(hypnet, synthnet).to(device)
encoder = Encoder().to(device)

In [116]:
criterion = torch.nn.CrossEntropyLoss()

In [117]:
optimizer = torch.optim.Adam([*module.parameters(), *encoder.parameters()], lr=0.001)

In [118]:
# Test the accuracy of our trained model
# In test phase, we don't need to compute gradients (for memory efficiency)
def calc_accuracy(test_loader):
    with torch.no_grad():
        n_correct = 0
        n_samples = 0
        for images, labels in test_loader:
            images = images.reshape(-1, 28*28).to(device)
            labels = labels.to(device)
            z = encoder(images).to(device)
            outputs = module(images, z).reshape(-1, 10)
            # max returns (value ,index)
            _, predicted = torch.max(outputs.data, 1)
            n_samples += labels.size(0)
            n_correct += (predicted == labels).sum().item()

        acc = 100.0 * n_correct / n_samples
        return acc
        

In [119]:
calc_accuracy(test_loader)

10.42

In [109]:
for epoch in np.arange(5):  # loop over the dataset num_epochs times
    for inputs, labels in tqdm(train_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs = inputs.view(-1, 28*28).to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        z = encoder(inputs)
        # forward + backward + optimize
        outputs = module(inputs, z).reshape(-1, 10)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
    acc = calc_accuracy(test_loader)
    print(f'Accuracy of the network on the 10000 test images: {acc} %')

print('Finished Training')

100%|██████████| 600/600 [00:09<00:00, 62.61it/s]


Accuracy of the network on the 10000 test images: 10.27 %


100%|██████████| 600/600 [00:09<00:00, 64.40it/s]
100%|██████████| 600/600 [00:09<00:00, 66.19it/s]
100%|██████████| 600/600 [00:09<00:00, 63.17it/s]
100%|██████████| 600/600 [00:09<00:00, 64.98it/s]
100%|██████████| 600/600 [00:08<00:00, 68.39it/s]


Accuracy of the network on the 10000 test images: 9.6 %


100%|██████████| 600/600 [00:08<00:00, 67.02it/s]
 93%|█████████▎| 560/600 [00:09<00:00, 60.89it/s]


KeyboardInterrupt: 