In [2]:
from src.resnet import ResNet
import torch 
from torch import nn

In [3]:
tensor_a = torch.tensor([-10., 2., 3.])

torch.nn.functional.softplus(tensor_a)

tensor([4.5399e-05, 2.1269e+00, 3.0486e+00])

In [4]:
top_k = torch.topk(tensor_a, 1)
top_k

torch.return_types.topk(
values=tensor([3.]),
indices=tensor([2]))

In [5]:
tensor_a.masked_fill_(tensor_a != top_k.values, -float('inf'))

tensor([-inf, -inf, 3.])

In [6]:
class MoEGate(nn.Module):
    """
    Mixture of Experts (MoE) Gate module.
    """
    def __init__(self, input_channels: int, gate_dim: int, k: int, bias: bool = False) -> None:
        """
        Args:
            input_channels (int): Number of input channels.
            gate_dim (int): Dimensionality of the gate.
            k (int): Number of top gate values to keep.
            bias (bool, optional): Whether or not to add a bias term in the linear layers.
        """
        super().__init__(input_channels, gate_dim)

        self.w_gate = nn.Linear(input_channels, gate_dim, bias=bias)
        self.w_noise = nn.Linear(input_channels, gate_dim, bias=bias)

        self.k = k
    
    def forward(self, x: torch.tensor) -> torch.tensor:
        """
        Computes the output of the MoE gate by computing the outputs of two linear layers,
        adding noise to one of them, and keeping only the top k largest values.

        Args:
            x (torch.tensor): Input tensor.

        Returns:
            torch.tensor: Output of the MoE gate.
        """
        h_noise = self.w_noise(x)
        noise = torch.randn_like(h_noise).to(x.device)
        h_noise = torch.nn.functional.softplus(h_noise).dot(noise)

        h_gate = self.w_gate(x)
        h = h_gate + h_noise

        top_k = torch.topk(h, k=self.k, dim=-1)
        output = h.masked_fill_(h != top_k.values, -float('inf'))

        return output

        

class MoE(nn.Module):
    def __init__(
        self,
        num_experts: int,
        input_channels: int,
        num_classes: int,
        channel_sizes: list[int],
        gate_dim: int,
        dropout: float
    ) -> None:
        """
        Dummy mixture of experts model.
        """
        super().__init__()
        self.num_experts = num_experts
        experts = []
        for _ in range(num_experts):
            resnet = ResNet(input_channels, num_channels=num_classes, filters=channel_sizes, dropout=dropout)
            
            # delete last few layers to extract feature maps
            del resnet.linear
            del resnet.flatten
            del resnet.avgpool
            del resnet.dropout
            experts.append(resnet)
        
        self.experts = nn.ModuleList(experts)

        self.gate = MoEGate(input_channels=input_channels, gate_dim=gate_dim)

        self.linear = nn.Linear(channel_sizes[-1], num_classes)
        self.dropout = nn.Dropout(p=dropout)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        
        
    
    def forward(self, x: torch.tensor):
        """
        The forward pass of the MoE model.
        """

SyntaxError: invalid syntax (2975407303.py, line 62)

In [9]:
resnet = ResNet(num_channels=3, num_classes=10, filters=[32, 64, 128])

In [19]:
resnet

ResNet(
  (res_layers): Sequential(
    (0): ResBlock(
      (norm1): Sequential(
        (0): GroupNorm(1, 3, eps=1e-05, affine=True)
        (1): ReLU()
      )
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): GroupNorm(1, 32, eps=1e-05, affine=True)
      )
      (activation): ReLU()
      (idconv): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
      (avgpool): Identity()
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (1): ResBlock(
      (norm1): Sequential(
        (0): GroupNorm(1, 32, eps=1e-05, affine=True)
        (1): ReLU()
      )
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (conv2): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): GroupNorm(1, 64, eps=1e-05, affine=True)
      )
      (activation): ReLU()
      (idcon