In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [5]:
class Controller(nn.Module):
    """Low-rank gating network that emits a block-wise one-hot mask.

    The input is projected through two bias-free linear layers
    (``dim_in -> dim_lowrank -> dim_hidden``). The resulting logits are split
    along the last dimension into ``num_blocks`` contiguous groups of
    ``dim_hidden // num_blocks`` entries, and exactly one entry per group is
    switched on.

    Args:
        dim_in: Size of the last dimension of the input.
        dim_lowrank: Width of the low-rank bottleneck between ``U`` and ``V``.
        dim_hidden: Size of the last dimension of the returned mask.
        num_blocks: Number of groups the mask is divided into; must divide
            ``dim_hidden`` evenly.
        tau: Temperature passed to ``F.gumbel_softmax`` in training mode.
            Defaults to 0.1, the value that used to be hard-coded.

    Raises:
        AssertionError: If ``dim_hidden`` is not divisible by ``num_blocks``.
    """

    def __init__(self, dim_in, dim_lowrank, dim_hidden, num_blocks, tau=0.1):
        super(Controller, self).__init__()
        self.dim_in = dim_in
        self.dim_lowrank = dim_lowrank
        self.dim_hidden = dim_hidden
        self.num_blocks = num_blocks
        self.tau = tau
        assert self.dim_hidden % self.num_blocks == 0, "hidden vector must be divisible into N blocks"
        self.U = nn.Linear(dim_in, dim_lowrank, bias = False)
        self.V = nn.Linear(dim_lowrank, dim_hidden, bias = False)

    def forward(self, x):
        """Return a mask of shape ``x.shape[:-1] + (dim_hidden,)``.

        In training mode the mask is a straight-through Gumbel-softmax sample
        (``hard=True``), so it is one-hot per block in the forward pass while
        remaining differentiable. In eval mode it is the deterministic one-hot
        encoding of the per-block argmax, cast to the dtype of the logits so
        that both modes return the same dtype.
        """
        logits = self.V(self.U(x))
        original_shape = logits.shape
        block_size = self.dim_hidden // self.num_blocks
        # Expose the blocks as their own axis so the choice is made per block.
        logits = logits.reshape(*original_shape[:-1], self.num_blocks, block_size)
        if self.training:
            mask = F.gumbel_softmax(logits, tau=self.tau, hard=True)
        else:
            selected = torch.argmax(logits, dim=-1)
            # F.one_hot yields int64; cast so eval output matches training output.
            mask = F.one_hot(selected, num_classes=block_size).to(logits.dtype)
        # Flatten (num_blocks, block_size) back into a single hidden axis.
        return mask.reshape(original_shape)
            
            

In [6]:
# Sanity check for the eval path: take the argmax over the trailing axis of
# size 4, one-hot encode it, then merge the last two axes (3 x 4 -> 12) to see
# how the per-group one-hot vectors are laid out after flattening.
A = torch.rand(2, 5, 3, 4)
am = A.argmax(dim=-1)
onehot = F.one_hot(am, num_classes=4)
print(A, onehot, onehot.reshape(2, 5, 12), sep="\n")

tensor([[[[0.8200, 0.7036, 0.2224, 0.4481],
          [0.6884, 0.9172, 0.6434, 0.4819],
          [0.2516, 0.6864, 0.6995, 0.4641]],

         [[0.7833, 0.0350, 0.8097, 0.2948],
          [0.4111, 0.9093, 0.7296, 0.9314],
          [0.8385, 0.9954, 0.9574, 0.2396]],

         [[0.5223, 0.8909, 0.0233, 0.5707],
          [0.0803, 0.1416, 0.1115, 0.6671],
          [0.7048, 0.2429, 0.8210, 0.7969]],

         [[0.2291, 0.8161, 0.5855, 0.9163],
          [0.9725, 0.7026, 0.8404, 0.8103],
          [0.2154, 0.5107, 0.9583, 0.8376]],

         [[0.4464, 0.6062, 0.3336, 0.8799],
          [0.1320, 0.7307, 0.0491, 0.3558],
          [0.3795, 0.6653, 0.8301, 0.7284]]],


        [[[0.2127, 0.1470, 0.8098, 0.9140],
          [0.9478, 0.1481, 0.3205, 0.0146],
          [0.0550, 0.4292, 0.6248, 0.4054]],

         [[0.4921, 0.7932, 0.3811, 0.2717],
          [0.6979, 0.5866, 0.2212, 0.5477],
          [0.6557, 0.9372, 0.2888, 0.7619]],

         [[0.0366, 0.5994, 0.2568, 0.7806],
          [0.284

In [7]:
# Sanity check for the training path: draw a hard Gumbel-softmax sample over
# the trailing axis of size 4, then merge the last two axes (3 x 4 -> 12) to
# see how the per-group samples are laid out after flattening.
A = torch.rand(2, 5, 3, 4)
gs = F.gumbel_softmax(A, tau=0.1, hard=True, dim=-1)
print(gs, gs.reshape(2, 5, 12), sep="\n")

tensor([[[[1., 0., 0., 0.],
          [0., 0., 0., 1.],
          [1., 0., 0., 0.]],

         [[0., 0., 0., 1.],
          [0., 0., 0., 1.],
          [0., 0., 1., 0.]],

         [[0., 0., 0., 1.],
          [0., 0., 0., 1.],
          [0., 0., 0., 1.]],

         [[1., 0., 0., 0.],
          [0., 1., 0., 0.],
          [0., 1., 0., 0.]],

         [[1., 0., 0., 0.],
          [0., 0., 1., 0.],
          [0., 0., 0., 1.]]],


        [[[0., 0., 0., 1.],
          [0., 0., 1., 0.],
          [0., 0., 0., 1.]],

         [[0., 1., 0., 0.],
          [1., 0., 0., 0.],
          [0., 0., 1., 0.]],

         [[0., 0., 0., 1.],
          [0., 0., 1., 0.],
          [1., 0., 0., 0.]],

         [[0., 1., 0., 0.],
          [0., 1., 0., 0.],
          [0., 0., 0., 1.]],

         [[1., 0., 0., 0.],
          [1., 0., 0., 0.],
          [0., 1., 0., 0.]]]])
tensor([[[1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0.],
         [0., 0., 0

In [8]:
# Smoke test: a (100, 50, 6) batch should come back with its last axis
# replaced by dim_hidden = 8.
cnt = Controller(dim_in=6, dim_lowrank=3, dim_hidden=8, num_blocks=2)
print(cnt(torch.rand(100, 50, 6)).size())

torch.Size([100, 50, 8])


In [9]:
class ControllerFFN(nn.Module):
    """Two-layer feed-forward block whose hidden activations are gated.

    ``layer1`` maps ``dim_in -> dim_hidden`` and ``layer2`` maps
    ``dim_hidden -> dim_in``. Between them, the ReLU activations are
    multiplied elementwise by the output of a ``Controller`` built from the
    same ``(dim_in, dim_lowrank, dim_hidden, num_blocks)`` arguments.

    Args:
        dim_in: Size of the last dimension of both input and output.
        dim_lowrank: Bottleneck width handed to the controller.
        dim_hidden: Width of the hidden layer.
        num_blocks: Number of blocks handed to the controller; must divide
            ``dim_hidden`` evenly.

    Raises:
        AssertionError: If ``dim_hidden`` is not divisible by ``num_blocks``.
    """

    def __init__(self, dim_in, dim_lowrank, dim_hidden, num_blocks):
        super().__init__()
        self.dim_in, self.dim_lowrank = dim_in, dim_lowrank
        self.dim_hidden, self.num_blocks = dim_hidden, num_blocks
        assert self.dim_hidden % self.num_blocks == 0, "hidden vector must be divisible into N blocks"
        # Submodules are created in this order: controller, layer1, layer2.
        self.controller = Controller(dim_in, dim_lowrank, dim_hidden, num_blocks)
        self.layer1 = nn.Linear(dim_in, dim_hidden)
        self.layer2 = nn.Linear(dim_hidden, dim_in)

    def forward(self, x):
        """Apply ``layer2(controller(x) * relu(layer1(x)))``."""
        gate = self.controller(x)
        hidden = F.relu(self.layer1(x))
        return self.layer2(gate * hidden)

In [10]:
# Smoke test: the gated FFN maps dim_in back to dim_in, so a (100, 50, 6)
# batch should keep its shape.
cntffn = ControllerFFN(dim_in=6, dim_lowrank=3, dim_hidden=8, num_blocks=2)
print(cntffn(torch.rand(100, 50, 6)).size())

torch.Size([100, 50, 6])
