In [1]:
import torch
import functorch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
%matplotlib inline

In [2]:
import os
os.getcwd()

'c:\\Users\\aresf\\Desktop\\Code\\Pytorch_hypernets\\notebooks'

In [4]:
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/mnist_data', 
                                                          download=True, 
                                                          train=True,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                          ])), 
                                           batch_size=32, 
                                           shuffle=True)

# download and transform test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/mnist_data', 
                                                          download=True, 
                                                          train=False,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                          ])), 
                                           batch_size=32, 
                                           shuffle=True)

In [64]:
class HyperNetwork(torch.nn.Module):
    def __init__(self, hypnet: torch.nn.Module, synthnet: torch.nn.Module):
        # hypnet is the network that takes x and produces the parameters of synthnet
        # synthnet is the network that takes z and produces h
        super().__init__()
        s_func, s_params0 = functorch.make_functional(synthnet)

        # store the information about the parameters
        self._sp_shapes = [sp.shape for sp in s_params0]

        self._sp_offsets = np.array([0, *np.cumsum([sp.numel() for sp in s_params0])])

        # make the synthnet_func to accept batched parameters
        synthnet_func = functorch.vmap(s_func)
        # a workaround of functorch's bug #793
        # self._synthnet_batched_func = [synthnet_func]
        self._synthnet_batched_func = synthnet_func
        self._hypnet = hypnet

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, nx), z: (batch_size, nz)
        params = self._hypnet(x)  # params: (batch_size, nparams_tot)

        # rearrange params to have the same shape as the synthnet params, except on the batch dimension
        params_lst = []
        for i, shape in enumerate(self._sp_shapes):
            j0, j1 = self._sp_offsets[i], self._sp_offsets[i + 1]
            params_lst.append(params[..., j0:j1].reshape(-1, *shape))

        # apply the function to the batched parameters and z
        h = self._synthnet_batched_func(params_lst, z)
        return h
        # return params_lst

In [36]:
np.array([0, *np.cumsum([sp.numel() for sp in a])])

array([    0, 78400, 78500, 79500, 79510])

In [70]:
synthnet = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 10),
)


In [71]:
_, synthnet_params = functorch.make_functional(synthnet)
n_synthnet_params = sum([p.numel() for p in synthnet_params])

In [72]:
#! todo 
hypnet = torch.nn.Sequential(
    torch.nn.Linear(3, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, n_synthnet_params),
)

In [73]:
module = HyperNetwork(hypnet, synthnet)

In [76]:
z = torch.randn(32, 3)
x, y = next(iter(test_loader))

In [78]:
criterion = torch.nn.CrossEntropyLoss()

In [86]:
optimizer = torch.optim.SGD(module.parameters(), lr=0.001)

In [87]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(train_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        #!todo
        # z = Encoder(inputs)

        # forward + backward + optimize
        outputs = module(z, inputs).reshape(-1, 10)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 500 == 0:    # print every 500 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 500:.3f}')
            running_loss = 0.0

print('Finished Training')

[1,     1] loss: 0.013
[1,   501] loss: 2.221
[1,  1001] loss: 1.049
[1,  1501] loss: 0.803
[2,     1] loss: 0.002
[2,   501] loss: 0.647
[2,  1001] loss: 0.583
[2,  1501] loss: 0.554
Finished Training
