In [43]:
import torch
import functorch
import torchvision
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

In [51]:
batch_size = 64

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "../data",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "../data",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=True,
)

In [52]:
x, y = next(iter(train_loader))
x.shape, y.shape

(torch.Size([64, 1, 28, 28]), torch.Size([64]))

In [39]:
class PrimaryNet(torch.nn.Module):
    def __init__(self, hypnet: torch.nn.Module, primary: torch.nn.Module):
        super(PrimaryNet, self).__init__()

        p_func, p_params = functorch.make_functional(primary)

        self.p_shapes = [p.shape for p in p_params]
        self.p_offsets = np.array(
            [0, *np.cumsum([p.numel() for p in p_params])])

        self.primary_func = functorch.vmap(p_func)
        self.hypnet = hypnet

    def forward(self, z, x):
        # z is embedding, x is primary input
        params = self.hypnet(z)
        params_lst = []
        for i, shape in enumerate(self.p_shapes):
            j0, j1 = self.p_offsets[i], self.p_offsets[i + 1]
            params_lst.append(params[..., j0:j1].reshape(-1, *shape))

        h = self.primary_func(params_lst, x)
        return h, self.primary_func

In [40]:
primary = torch.nn.Sequential(
    torch.nn.Linear(2, 100), torch.nn.ReLU(), torch.nn.Linear(100, 5)
)

_, primary_params = functorch.make_functional(primary)
n_primary_params = sum([p.numel() for p in primary_params])

hypnet = torch.nn.Sequential(
    torch.nn.Linear(3, 100), torch.nn.ReLU(), torch.nn.Linear(100, n_primary_params)
)

module = PrimaryNet(hypnet, primary)

  warn_deprecated('make_functional', 'torch.func.functional_call')
  warn_deprecated('vmap', 'torch.vmap')


In [36]:
inp = torch.randn(64, 3)
x = torch.randn(64, 2)

In [41]:
out, pn = module(inp, x)
out.shape

torch.Size([64, 5])

<function torch._functorch.apis.vmap.<locals>.wrapped(*args, **kwargs)>