In [1]:
import torch
import functorch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from tqdm import tqdm
%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
os.getcwd()

'/home/bkosa2/RNP/RNP_PyTorch/Pytorch_hypernets/notebooks'

In [3]:
# Important for using GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


cuda:0


In [4]:
# hyperparams
# input_size = 784  # 28x28
# synthnet_hidden_size = 100
# num_classes = 10
num_epochs = 2
batch_size = 100
learning_rate = 0.001

In [5]:
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/mnist_data', 
                                                          download=True, 
                                                          train=True,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                          ])), 
                                           batch_size=batch_size, 
                                           shuffle=True)

# download and transform test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/mnist_data', 
                                                          download=True, 
                                                          train=False,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                          ])), 
                                           batch_size=batch_size, 
                                           shuffle=True)

In [64]:
class HyperNetwork(torch.nn.Module):
    def __init__(self, hypnet: torch.nn.Module, synthnet: torch.nn.Module):
        # hypnet is the network that takes an embedding z and produces the parameters of synthnet
        # synthnet is the network that takes input x and produces h
        super().__init__()
        s_func, s_params0 = functorch.make_functional(synthnet)

        # store the information about the parameters
        self._sp_shapes = [sp.shape for sp in s_params0]

        # These are the index offsets for each parameter (e.g. set of weights between
        # each layer and set of biases for each layer)
        self._sp_offsets = np.array([0, *np.cumsum([sp.numel() for sp in s_params0])])

        # make the synthnet_func to accept batched parameters
        synthnet_func = functorch.vmap(s_func)
        # a workaround of functorch's bug #793
        # self._synthnet_batched_func = [synthnet_func]
        self._synthnet_batched_func = synthnet_func
        self._hypnet = hypnet

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, nx), z: (batch_size, nz)
        params = self._hypnet(z)  # params: (batch_size, nparams_tot)

        # rearrange params to have the same shape as the synthnet params, except on the batch dimension
        # print(f"self._sp_offsets: {self._sp_offsets}")
        params_lst = []
        for i, shape in enumerate(self._sp_shapes):
            j0, j1 = self._sp_offsets[i], self._sp_offsets[i + 1]
            params_lst.append(params[..., j0:j1].reshape(-1, *shape))

        # print(f"params_lst: {params_lst}")

        # apply the function to the batched parameters and x
        h = self._synthnet_batched_func(params_lst, x)
        return h
        # return params_lst

In [65]:
class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Building an linear encoder with Linear
        # layer followed by Relu activation function
        # 784 ==> 9
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 9)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        return encoded

In [66]:
# np.array([0, *np.cumsum([sp.numel() for sp in a])])

In [67]:
synthnet = torch.nn.Sequential(
    torch.nn.Linear(28*28, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 10),
).to(device)


In [68]:
_, synthnet_params = functorch.make_functional(synthnet)
n_synthnet_params = sum([p.numel() for p in synthnet_params])

In [69]:
# Ben's Code - Just me peeking inside the synthnet_params to understand what PyTorch Module params actually are under the hood. synthnet is just a nn with 1 hidden layer.
print(len(synthnet_params))
print(synthnet_params[0].size())  # The weights for the connections from input layer (784 units) to hidden layer (100 units)
print(synthnet_params[1].size())  # The biases for each of the 100 units in the hidden layer???
print(synthnet_params[2].size())  # The weights for the connections from the hidden layer (100 units) to the output layer (10 units)
print(synthnet_params[3].size())  # The biases for each of the 10 units in the output layer???
print(n_synthnet_params)

4
torch.Size([100, 784])
torch.Size([100])
torch.Size([10, 100])
torch.Size([10])
79510


In [70]:
#! todo 
hypnet = torch.nn.Sequential(
    torch.nn.Linear(9, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, n_synthnet_params),
).to(device)

In [71]:
module = HyperNetwork(hypnet, synthnet).to(device)
encoder = Encoder().to(device)

In [72]:
# Just for Ben to run one iteraiton of hypernet forward func to see what's going on
# under the hood.
inputs, _ = next(iter(train_loader))
inputs = inputs.view(-1, 28*28).to(device)
z = encoder(inputs).to(device)
module(inputs, z)

tensor([[ 2.2939e-01, -1.0539e+00,  1.2904e+00, -8.5967e-01, -9.5300e-01,
          4.1317e-01,  1.9479e-01, -1.6215e+00,  4.4721e-01,  4.0094e-01],
        [-4.1470e-02,  1.2322e+00,  1.2806e+00, -1.2209e-01, -5.6639e-01,
          7.0445e-01, -6.7706e-02, -3.2500e-01, -1.2897e-01, -6.0678e-01],
        [ 1.5933e-01, -4.2136e-01,  3.6571e-01, -6.8924e-01,  3.4137e-02,
         -1.7620e-01,  6.7944e-01, -1.3010e+00,  5.4643e-01,  4.5805e-01],
        [-2.0620e-01,  6.4094e-01,  5.7286e-01, -3.1576e-01,  1.6518e-01,
         -5.5355e-01,  1.1691e-01, -3.1885e-02, -5.1861e-01, -1.0369e+00],
        [ 6.0133e-01, -2.4626e-01, -8.8934e-01, -7.9846e-01, -9.3119e-01,
          6.6587e-01,  3.8584e-01, -1.2794e-01,  2.4409e-01, -6.5646e-01],
        [ 2.8435e-01,  1.9275e-01,  3.9300e-01, -3.0803e-01, -1.3049e-01,
         -7.7536e-02,  4.9321e-01, -7.1867e-01, -3.6915e-01,  1.0619e-01],
        [ 1.3212e-01, -5.8051e-01,  3.0817e-01, -4.7559e-01, -4.5391e-01,
          3.9928e-01,  1.5846e-0

In [73]:
# This is for if you want to test out our HyperNet pipeline with a random input embedding z (that is not generated from an encoder)
# z = torch.randn(32, 3)
# x, y = next(iter(test_loader))

In [74]:
criterion = torch.nn.CrossEntropyLoss()

In [75]:
optimizer = torch.optim.SGD(module.parameters(), lr=learning_rate)

In [76]:
for epoch in range(num_epochs):  # loop over the dataset num_epochs times
    running_loss = 0.0
    i = 0  # Just keeps track of which batch we're on until we get to batch 500, in which case we print
           # out the aggregated loss over the 500 batches and reset (i.e. allows us to just print out the
           # average loss every 500 batches instead of every batch)
    for inputs, labels in tqdm(train_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs = inputs.view(-1, 28*28).to(device)
        labels = labels.to(device)
        

        # zero the parameter gradients
        optimizer.zero_grad()

        z = encoder(inputs).to(device)

        # forward + backward + optimize
        # outputs = module(inputs, weights).reshape(-1, 10)
        outputs = module(inputs, z).reshape(-1, 10)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        i += 1
        if i % 100 == 0:    # print every 500 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')

 18%|█▊        | 108/600 [00:02<00:08, 55.07it/s]

[1,   101] loss: 2.300


 35%|███▌      | 210/600 [00:03<00:07, 53.03it/s]

[1,   201] loss: 2.091


 51%|█████     | 306/600 [00:05<00:05, 53.43it/s]

[1,   301] loss: 1.913


 68%|██████▊   | 408/600 [00:07<00:03, 53.57it/s]

[1,   401] loss: 1.739


 84%|████████▍ | 504/600 [00:09<00:01, 52.84it/s]

[1,   501] loss: 1.550


100%|██████████| 600/600 [00:11<00:00, 53.51it/s]


[1,   601] loss: 1.374


 18%|█▊        | 107/600 [00:01<00:08, 58.49it/s]

[2,   101] loss: 1.221


 35%|███▌      | 210/600 [00:03<00:07, 54.84it/s]

[2,   201] loss: 1.080


 51%|█████     | 306/600 [00:05<00:05, 53.87it/s]

[2,   301] loss: 0.965


 68%|██████▊   | 408/600 [00:07<00:03, 54.83it/s]

[2,   401] loss: 0.868


 85%|████████▌ | 510/600 [00:09<00:01, 52.84it/s]

[2,   501] loss: 0.795


100%|██████████| 600/600 [00:10<00:00, 55.32it/s]

[2,   601] loss: 0.754
Finished Training





In [77]:
# Test the accuracy of our trained model
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        z = encoder(inputs).to(device)
        outputs = module(inputs, z).reshape(-1, 10)
        # max returns (value ,index)
        _, predicted = torch.max(outputs.data, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network on the 10000 test images: {acc} %')

Accuracy of the network on the 10000 test images: 10.08 %
