In [108]:
from src.resnet import ResNet
import torch 
from torch import nn

In [109]:
# Small probe vector: one strongly negative entry and two positive ones.
tensor_a = torch.tensor([-10.0, 2.0, 3.0])

# softplus(x) = log(1 + exp(x)): close to 0 for very negative inputs and
# close to x for large positive inputs.
torch.nn.functional.softplus(tensor_a)

tensor([4.5399e-05, 2.1269e+00, 3.0486e+00])

In [110]:
# Largest single entry of the probe vector, returned as (values, indices).
top_k = torch.topk(tensor_a, k=1)
top_k

torch.return_types.topk(
values=tensor([3.]),
indices=tensor([2]))

In [111]:
# In place: overwrite every entry that differs from the top-1 value with -inf.
tensor_a.masked_fill_(tensor_a != top_k.values, float("-inf"))

tensor([-inf, -inf, 3.])

In [112]:
class MoEGate(nn.Module):
    """
    Noisy top-k gate for a Mixture of Experts (MoE) model.

    Two linear projections of the input produce the clean gate logits and a
    per-logit noise scale. Gaussian noise, scaled by the softplus of the
    second projection, is added to the logits; every logit outside the
    per-row top ``k`` is then replaced with ``-inf`` so that a subsequent
    softmax gives it zero weight.
    """

    def __init__(self, input_size: int, gate_dim: int, k: int, bias: bool = False) -> None:
        """
        Args:
            input_size (int): Number of input features of both linear layers.
            gate_dim (int): Number of gate logits produced per input row.
            k (int): Number of top gate values to keep per row.
            bias (bool, optional): Whether or not to add a bias term in the linear layers.
        """
        super().__init__()

        self.w_gate = nn.Linear(input_size, gate_dim, bias=bias)
        self.w_noise = nn.Linear(input_size, gate_dim, bias=bias)

        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the noisy gate logits and keeps only the top k of them along
        the last dimension.

        Args:
            x (torch.Tensor): Input tensor whose last dimension has
                ``input_size`` features.

        Returns:
            torch.Tensor: Tensor shaped like the gate logits in which exactly
            ``k`` entries along the last dimension hold their noisy logit and
            all remaining entries are ``-inf``.
        """
        # softplus keeps the learned noise scale strictly non-negative.
        noise_scale = torch.nn.functional.softplus(self.w_noise(x))
        # randn_like already allocates on noise_scale's device and dtype, so
        # no explicit device transfer is needed.
        h_noise = noise_scale * torch.randn_like(noise_scale)

        h = self.w_gate(x) + h_noise

        top_k = torch.topk(h, k=self.k, dim=-1)

        # Select by index rather than by value: a value-membership test would
        # also keep logits that merely equal a top-k value of another row (or
        # a tied value in the same row), leaving more than k finite entries.
        # Scattering into a fresh -inf tensor also avoids mutating h in place.
        output = torch.full_like(h, float("-inf")).scatter(-1, top_k.indices, top_k.values)

        return output

        

class MoE(nn.Module):
    """
    Dummy mixture of experts model.

    Holds a noisy top-k gate, ``num_experts`` ResNet experts whose own
    ``linear``/``flatten``/``avgpool``/``dropout`` attributes are removed, and
    one set of those layers owned by this module (presumably intended as a
    shared head; the forward pass does not use them yet).
    """

    def __init__(
        self,
        num_experts: int,
        input_size: int,
        input_channels: int,
        num_classes: int,
        channel_sizes: list[int],
        gate_dim: int,
        dropout: float,
        k: int
    ) -> None:
        """
        Args:
            num_experts (int): Number of ResNet experts to build.
            input_size (int): Number of input features seen by the gate.
            input_channels (int): Passed to each ResNet as ``num_channels``.
            num_classes (int): Output size of the final linear layer.
            channel_sizes (list[int]): Passed to each ResNet as ``filters``;
                the last entry is the input size of the final linear layer.
            gate_dim (int): Number of gate logits.
            dropout (float): Dropout probability for the experts and this module.
            k (int): Number of gate logits kept by the gate.
        """
        super().__init__()
        self.num_experts = num_experts

        self.gate = MoEGate(input_size=input_size, gate_dim=gate_dim, k=k)

        backbones = []
        for _ in range(num_experts):
            backbone = ResNet(
                num_channels=input_channels,
                num_classes=num_classes,
                filters=channel_sizes,
                dropout=dropout,
            )
            # Drop the expert's own head layers; this module owns one copy below.
            for head_part in ("linear", "flatten", "avgpool", "dropout"):
                delattr(backbone, head_part)
            backbones.append(backbone)

        self.experts = nn.ModuleList(backbones)
        self.linear = nn.Linear(channel_sizes[-1], num_classes)
        self.dropout = nn.Dropout(p=dropout)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x: torch.tensor):
        """
        The forward pass of the MoE model.

        Currently returns only the gate weights: the softmax of the gate
        output over the last dimension, with all size-1 dimensions squeezed
        away (so a batch of one loses its batch dimension).
        """
        gate_logits = self.gate(x)
        # TODO: combine the experts as a weighted sum, using these softmax
        # weights, and evaluate only the k experts the gate did not mask out.
        return torch.nn.functional.softmax(gate_logits, dim=-1).squeeze()

        


In [113]:
# Two experts; the gate sees 3072 input features and keeps the top 2 of 8 logits.
moe = MoE(
    num_experts=2,
    input_size=3072,
    input_channels=3,
    num_classes=10,
    channel_sizes=[8, 16, 32],
    gate_dim=8,
    dropout=0.1,
    k=2,
)


In [114]:
# Display the module tree to inspect the registered gate and expert submodules.
moe

MoE(
  (gate): MoEGate(
    (w_gate): Linear(in_features=3072, out_features=8, bias=False)
    (w_noise): Linear(in_features=3072, out_features=8, bias=False)
  )
  (experts): ModuleList(
    (0-1): 2 x ResNet(
      (res_layers): Sequential(
        (0): ResBlock(
          (norm1): Sequential(
            (0): GroupNorm(1, 3, eps=1e-05, affine=True)
            (1): ReLU()
          )
          (conv1): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv2): Sequential(
            (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): GroupNorm(1, 8, eps=1e-05, affine=True)
          )
          (activation): ReLU()
          (idconv): Conv2d(3, 8, kernel_size=(1, 1), stride=(1, 1))
          (avgpool): Identity()
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (1): ResBlock(
          (norm1): Sequential(
            (0): GroupNorm(1, 8, eps=1e-05, affine=True)
            (1): ReLU()
          )
       

In [115]:
# One random 3x32x32 sample, flattened to the 3072 features the gate expects.
x = torch.randn(1, 3, 32, 32).view(1, -1)
moe(x)


tensor([0.0000, 0.0000, 0.8182, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
       grad_fn=<SqueezeBackward0>)