In [5]:
from src.resnet import ResNet
import torch 
from torch import nn

In [9]:
# Small 1-D sample used by the following cells to try out tensor ops.
tensor_a = torch.tensor([-10.0, 2.0, 3.0])

# softplus(x) = log(1 + exp(x)); echoed as the cell output.
nn.functional.softplus(tensor_a)

tensor([4.5399e-05, 2.1269e+00, 3.0486e+00])

In [11]:
# Single largest entry of tensor_a (values and indices) along its last dimension.
top_k = tensor_a.topk(k=1)
top_k

torch.return_types.topk(
values=tensor([3.]),
indices=tensor([2]))

In [13]:
# In place: every entry that differs from the top-1 value becomes -inf.
tensor_a.masked_fill_(tensor_a.ne(top_k.values), float('-inf'))

tensor([-inf, -inf, 3.])

In [15]:

class MoEGate(nn.Module):
    """
    Mixture of Experts (MoE) Gate module.

    Computes noisy gating logits from two linear projections of the input and
    keeps only the ``k`` largest logits along the last dimension; every other
    position is set to ``-inf``.
    """
    def __init__(self, input_channels: int, gate_dim: int, k: int, bias: bool = False) -> None:
        """
        Args:
            input_channels (int): Size of the last dimension of the input.
            gate_dim (int): Dimensionality of the gate (number of logits produced).
            k (int): Number of top gate values to keep.
            bias (bool, optional): Whether or not to add a bias term in the linear layers.

        Raises:
            ValueError: If ``k`` is not in the range ``[1, gate_dim]``.
        """
        # nn.Module.__init__ takes no positional arguments.
        super().__init__()

        if not 1 <= k <= gate_dim:
            raise ValueError(f"k must be in [1, {gate_dim}], got {k}")

        self.w_gate = nn.Linear(input_channels, gate_dim, bias=bias)
        self.w_noise = nn.Linear(input_channels, gate_dim, bias=bias)

        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the output of the MoE gate by computing the outputs of two linear layers,
        using one of them (through softplus) to scale standard normal noise that is added
        to the other, and keeping only the top k largest values.

        Args:
            x (torch.Tensor): Input tensor whose last dimension has size ``input_channels``.

        Returns:
            torch.Tensor: Tensor with last dimension ``gate_dim``; the ``k`` largest
            noisy logits are kept and all other entries are ``-inf``.
        """
        # softplus keeps the per-logit noise scale strictly positive.
        noise_scale = torch.nn.functional.softplus(self.w_noise(x))
        # randn_like already matches the device and dtype of its argument.
        # The scaling is element-wise: Tensor.dot only supports 1-D inputs.
        h_noise = torch.randn_like(noise_scale) * noise_scale

        h = self.w_gate(x) + h_noise

        top_k = torch.topk(h, k=self.k, dim=-1)
        # Scatter the kept values into a -inf canvas by index. Comparing h with
        # top_k.values does not broadcast for 1 < k < gate_dim, and an in-place
        # fill would overwrite h.
        output = torch.full_like(h, -float('inf')).scatter(-1, top_k.indices, top_k.values)

        return output

        

class MoE(nn.Module):
    """
    Dummy mixture of experts model.

    Holds ``num_experts`` ResNet experts and one MoEGate. The forward pass is
    not implemented yet.
    """
    def __init__(
        self,
        num_experts: int,
        input_channels: int,
        num_classes: int,
        channel_sizes: list[int],
        gate_dim: int,
        k: int = 1,
    ) -> None:
        """
        Args:
            num_experts (int): Number of ResNet experts to create.
            input_channels (int): Passed as the first argument of every expert
                and as ``input_channels`` of the gate.
            num_classes (int): Passed to every expert as ``num_channels``.
            channel_sizes (list[int]): Passed to every expert as ``filters``.
            gate_dim (int): Dimensionality of the gate.
            k (int, optional): Number of top gate values the gate keeps.
                Defaults to 1.
        """
        super().__init__()
        self.num_experts = num_experts

        # NOTE(review): num_classes is forwarded as ``num_channels`` - confirm
        # against the ResNet signature in src.resnet.
        self.experts = nn.ModuleList(
            [
                ResNet(input_channels, num_channels=num_classes, filters=channel_sizes)
                for _ in range(num_experts)
            ]
        )

        # MoEGate requires ``k``; omitting it raises a TypeError at construction.
        self.gate = MoEGate(input_channels=input_channels, gate_dim=gate_dim, k=k)

    def forward(self, x: torch.Tensor):
        """
        Placeholder forward pass; not implemented, currently returns None.

        Args:
            x (torch.Tensor): Input tensor (unused for now).
        """
        # TODO: route x through the gate and combine the expert outputs.
        return None