In [None]:
import torch
import torch.nn as nn

class Hypernetwork(nn.Module):
    """Chunked hypernetwork that generates fast weights for a target network.

    A shared MLP (``self.hypernet``) is evaluated once per learnable chunk
    embedding on the concatenation ``[x, chunk_embedding]`` and emits
    ``chunk_size`` values squashed by a final Sigmoid. Consecutive chunks are
    then grouped into one weight matrix plus one bias row per generated layer.

    Args:
        input_size: Number of features of the conditioning input ``x``.
        num_layers: Number of hidden Linear+ReLU stages of the hypernetwork
            (the output projection to ``chunk_size`` is added on top).
        layer_size: Width of the hidden stages.
        chunk_size: Number of values produced for every chunk.
        chunk_emb_size: Size of each learnable chunk embedding.
        num_chunks: Number of chunk embeddings. ``forward`` needs at least
            ``num_target_layers * (rows_per_layer + 1)`` of them.
        rows_per_layer: Number of chunks stacked into each generated weight
            matrix (one chunk per row). Defaults to 128.
        num_target_layers: Number of (weight, bias) pairs to generate.
            Defaults to 2.
    """

    def __init__(self, input_size, num_layers, layer_size, chunk_size, chunk_emb_size,
                 num_chunks, rows_per_layer=128, num_target_layers=2):
        super().__init__()
        self.num_chunks = num_chunks
        self.rows_per_layer = rows_per_layer
        self.num_target_layers = num_target_layers
        self.chunk_embeddings = self._generate_chunk_embeddings(chunk_emb_size, num_chunks)

        # The hypernetwork consumes the input concatenated with one chunk embedding.
        input_size = input_size + chunk_emb_size
        hypernet_layers = self._prepare_layers(num_layers, layer_size, input_size, chunk_size)
        self.hypernet = nn.Sequential(*hypernet_layers)

    def forward(self, x):
        """Generate the fast weights conditioned on ``x``.

        Args:
            x: Tensor whose first dimension must match the chunk embeddings'
                first dimension (1), because both are concatenated along dim 1.

        Returns:
            A list ``[weight_0, bias_0, weight_1, bias_1, ...]`` where every
            weight has shape ``(rows_per_layer, chunk_size)`` and every bias
            has shape ``(1, chunk_size)``.

        Raises:
            ValueError: If there are too few chunks for the requested layers.
        """
        fast_weights = []
        for chunk_emb in self.chunk_embeddings:
            cat_ = torch.cat((x, chunk_emb), dim=1)
            fast_weight_chunk = self.hypernet(cat_)
            fast_weights.append(fast_weight_chunk)

        fast_weights = self._merge_layers(fast_weights)

        return fast_weights

    def _generate_chunk_embeddings(self, chunk_emb_size, num_chunks):
        """Create ``num_chunks`` learnable embeddings of shape ``(1, chunk_emb_size)``.

        The embeddings are wrapped in an ``nn.ParameterList`` so that they are
        registered on the module: they show up in ``parameters()`` (and hence
        get optimized), are stored in ``state_dict()`` and follow ``.to()``.
        """
        return nn.ParameterList(
            nn.Parameter(torch.rand((1, chunk_emb_size))) for _ in range(num_chunks)
        )

    def _prepare_layers(self, num_layers, layer_size, input_size, chunk_size):
        """Build the hypernetwork layer list.

        The stack is ``num_layers`` Linear+ReLU stages followed by a Linear
        projection to ``chunk_size`` and a Sigmoid.
        """
        input_layer = nn.Linear(in_features=input_size, out_features=layer_size)
        layers = [input_layer, nn.ReLU()]
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(layer_size, layer_size))
            layers.append(nn.ReLU())

        layers.append(nn.Linear(in_features=layer_size, out_features=chunk_size))
        layers.append(nn.Sigmoid())
        return layers

    def _merge_layers(self, fast_weights):
        """Group the per-chunk outputs into per-layer weights and biases.

        Every generated layer owns ``rows_per_layer + 1`` consecutive chunks:
        the first ``rows_per_layer`` are concatenated along dim 0 into the
        weight matrix and the last one is the bias. Layers therefore start at
        multiples of ``rows_per_layer + 1`` so that no chunk is shared between
        two layers.

        Raises:
            ValueError: If ``fast_weights`` holds fewer chunks than required.
        """
        chunks_per_layer = self.rows_per_layer + 1
        required = self.num_target_layers * chunks_per_layer
        if len(fast_weights) < required:
            raise ValueError(
                f"expected at least {required} chunks "
                f"({self.num_target_layers} layers x {chunks_per_layer} chunks), "
                f"got {len(fast_weights)}"
            )

        merged_params = []
        for layer_idx in range(self.num_target_layers):
            start = layer_idx * chunks_per_layer
            bias_idx = start + self.rows_per_layer
            merged_params.append(torch.cat(fast_weights[start:bias_idx]))
            merged_params.append(fast_weights[bias_idx])
        return merged_params

In [None]:
import torch.nn.functional as F

class Linear_fw(nn.Linear):
    """Linear layer whose parameters carry an optional ``fast`` override.

    Each of ``weight`` and ``bias`` gets a ``fast`` attribute, initialised to
    the parameter itself. ``forward`` uses the ``fast`` tensors when both are
    set and falls back to the regular parameters otherwise.
    """

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # Start with the fast slots pointing at the regular parameters.
        for param in (self.weight, self.bias):
            param.fast = param

    def forward(self, x):
        """Apply the affine map, preferring the fast weight/bias pair."""
        # Either slot being cleared disables the fast path for both.
        if self.weight.fast is None or self.bias.fast is None:
            return F.linear(x, self.weight, self.bias)
        return F.linear(x, self.weight.fast, self.bias.fast)

In [None]:
class MLP_FW(nn.Module):
    """MLP assembled from ``Linear_fw`` layers so its parameters can hold fast weights.

    Args:
        input_size: Number of input features of the first layer.
        num_layers: Total number of ``Linear_fw`` layers (input, hidden, output).
        layer_size: Width of the hidden layers.
        num_classes: Number of outputs of the last layer.
    """

    def __init__(self, input_size, num_layers, layer_size, num_classes):
        super().__init__()
        layers = self._generate_layers(input_size, num_layers, layer_size, num_classes)

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return its output."""
        return self.net(x)

    def _generate_layers(self, input_size, num_layers, layer_size, num_classes):
        """Build the layer list: every ``Linear_fw`` is followed by a ReLU."""
        layers = [Linear_fw(input_size, layer_size), nn.ReLU()]
        for _ in range(num_layers - 2):
            layers.append(Linear_fw(layer_size, layer_size))
            layers.append(nn.ReLU())

        layers.append(Linear_fw(layer_size, num_classes))
        # NOTE(review): the output layer is also followed by a ReLU, so the
        # network output is clamped to >= 0 — confirm this is intended.
        layers.append(nn.ReLU())

        return layers

    def _update_weight(self, weight, update_value):
        """Store ``weight * update_value`` as the parameter's ``fast`` tensor.

        The parameter itself is left untouched; only its ``fast`` attribute
        (read by ``Linear_fw.forward``) is replaced.
        """
        weight.fast = weight * update_value

In [None]:
# Target network whose hidden layers receive the generated fast weights.
mlp = MLP_FW(input_size=768, num_layers=4, layer_size=128, num_classes=5)

In [None]:
# Random stand-in for the conditioning input: one row of 5 * 768 values in [0, 1).
input_embedding = torch.rand(1, 5 * 768)

In [None]:
# 129 chunks per generated layer (128 weight rows + 1 bias row), for 2 layers.
hypernet = Hypernetwork(
    input_size=5 * 768,
    num_layers=4,
    layer_size=500,
    chunk_size=128,
    chunk_emb_size=8,
    num_chunks=129 * 2,
)

In [None]:
# Generate the fast-weight updates conditioned on the input embedding.
# The returned list alternates a weight matrix and a bias row per generated layer.
updates = hypernet(input_embedding)

In [None]:
# Skip the first two and last two parameter tensors (weight and bias of the
# input and output layers); only the layers in between get fast weights.
hidden_params = list(mlp.parameters())[2:-2]
for k, weight in enumerate(hidden_params):
    update_value = updates[k]
    mlp._update_weight(weight, update_value)