In [16]:
from src.resnet import ResNet
import torch 
from torch import nn

In [17]:
# Small 1-D example tensor reused by the scratch cells below.
tensor_a = torch.tensor([-10.0, 2.0, 3.0])

# softplus(x) = log(1 + exp(x)); displayed as the cell result.
nn.functional.softplus(tensor_a)

tensor([4.5399e-05, 2.1269e+00, 3.0486e+00])

In [18]:
# Largest single entry of tensor_a, returned as a (values, indices) named tuple.
top_k = tensor_a.topk(1)
top_k

torch.return_types.topk(
values=tensor([3.]),
indices=tensor([2]))

In [19]:
# Keep only the entries equal to the top-k value and replace every other entry
# with -inf. The out-of-place `masked_fill` returns a new tensor and leaves
# `tensor_a` untouched, so re-running this cell or the cells above always
# starts from the same data.
tensor_a.masked_fill(tensor_a != top_k.values, -float('inf'))

tensor([-inf, -inf, 3.])

In [20]:
class MoEGate(nn.Module):
    """
    Noisy top-k gate for a Mixture of Experts (MoE) model.

    For an input ``x`` the gate computes one logit per gate slot as
    ``w_gate(x) + softplus(w_noise(x)) * eps`` with ``eps ~ N(0, 1)`` drawn
    independently per element, keeps the ``k`` largest logits along the last
    dimension and replaces all others with ``-inf``.
    """
    def __init__(self, input_channels: int, gate_dim: int, k: int, bias: bool = False) -> None:
        """
        Args:
            input_channels (int): Size of the last dimension of the input.
            gate_dim (int): Number of gate logits produced per input.
            k (int): Number of top gate values to keep (``1 <= k <= gate_dim``).
            bias (bool, optional): Whether or not to add a bias term in the linear layers.

        Raises:
            ValueError: If ``k`` is outside ``[1, gate_dim]``.
        """
        # nn.Module.__init__ must be called without the layer sizes; passing
        # them raises a TypeError before any submodule can be registered.
        super().__init__()

        if not 1 <= k <= gate_dim:
            raise ValueError(f"k must be in [1, gate_dim={gate_dim}], got {k}")

        self.w_gate = nn.Linear(input_channels, gate_dim, bias=bias)
        self.w_noise = nn.Linear(input_channels, gate_dim, bias=bias)

        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the noisy gate logits and keeps only the top k of them.

        Args:
            x (torch.Tensor): Input tensor whose last dimension has size
                ``input_channels``; any number of leading dimensions is allowed.

        Returns:
            torch.Tensor: Tensor shaped like ``x`` with the last dimension
            replaced by ``gate_dim``. Exactly ``k`` entries along the last
            dimension hold their noisy logit, the rest are ``-inf``.
        """
        # Per-element noise scale; softplus keeps it strictly positive.
        noise_scale = torch.nn.functional.softplus(self.w_noise(x))
        # randn_like already matches the dtype and device of its argument.
        noise = torch.randn_like(noise_scale)

        # Element-wise product: every gate logit gets its own scaled noise
        # sample. (`Tensor.dot` would only accept 1-D tensors and would reduce
        # them to a single scalar.)
        h = self.w_gate(x) + noise_scale * noise

        top_k = torch.topk(h, k=self.k, dim=-1)
        # Scatter the k winners into a tensor of -inf. Unlike comparing `h`
        # against `top_k.values`, this works for any k, keeps exactly k entries
        # even when values tie, and does not modify `h` in place.
        output = torch.full_like(h, float('-inf')).scatter(-1, top_k.indices, top_k.values)

        return output

        

class MoE(nn.Module):
    def __init__(
        self,
        num_experts: int,
        input_channels: int,
        num_classes: int,
        channel_sizes: list[int],
        gate_dim: int,
        k: int = 1,
    ) -> None:
        """
        Dummy mixture of experts model: a list of ResNet experts plus a MoEGate.

        Args:
            num_experts (int): Number of ResNet experts to create.
            input_channels (int): Passed to every expert and to the gate.
            num_classes (int): Forwarded to each expert as ``num_channels``.
            channel_sizes (list[int]): Forwarded to each expert as ``filters``.
            gate_dim (int): Number of gate logits produced by the gate.
            k (int, optional): Number of top gate values the gate keeps.
                Defaults to 1.
        """
        super().__init__()
        self.num_experts = num_experts

        # NOTE(review): ResNet's signature is not visible in this notebook; the
        # arguments are forwarded exactly as before -- confirm that
        # `num_channels` is really meant to receive `num_classes`.
        self.experts = nn.ModuleList([
            ResNet(input_channels, num_channels=num_classes, filters=channel_sizes)
            for _ in range(num_experts)
        ])

        # MoEGate requires `k`; omitting it made this constructor always fail
        # with a TypeError.
        self.gate = MoEGate(input_channels=input_channels, gate_dim=gate_dim, k=k)

    def forward(self, x: torch.Tensor):
        """
        Placeholder: combining the experts with the gate is not implemented yet.

        Raises:
            NotImplementedError: Always, instead of silently returning ``None``.
        """
        raise NotImplementedError("MoE.forward is not implemented yet")

Problem Statement:

Given an array of integers nums and an integer target, find all unique pairs in nums that sum up to target. Return these pairs as a list of lists, where each inner list contains two elements representing the pair.

Example 1:
Input: nums = [2,7,11,15], target = 9
Output: [[2,7]]

Example 2:
Input: nums = [2,5,8,-4,-6], target = 10
Output: [[2,8]]

Example 3:
Input: nums = [1,2,3,4,5], target = 10
Output: []

Class Definition:

In [None]:
class PairSum:
    """Brute-force solver for the "all unique pairs with a given sum" problem."""

    def two_sum(self, nums: list[int], target: int) -> list[list[int]]:
        """
        Find all unique pairs in the input array that sum up to the target.

        A pair is built from two different positions of `nums`; an element is
        never paired with itself. Pairs are compared by value regardless of
        order, so `[1, 9]` and `[9, 1]` count as the same pair and only the
        first one found is reported.

        Args:
            nums (list[int]): The input array of integers.
            target (int): The target sum.

        Returns:
            list[list[int]]: A list of lists, where each inner list contains two elements representing a pair of numbers from `nums` that sum up to `target`.
            The two numbers keep the order in which they appear in `nums`, and
            pairs are listed in the order they are first found. The list is
            empty when `nums` has fewer than two elements or no pair matches.
        """
        # Brute force, O(n^2): compare every element with every later element.
        # Starting the inner loop at i + 1 both skips pairs that were already
        # checked and prevents an element from being combined with itself.
        pairs = []
        seen = set()  # order-independent keys of the pairs already reported
        for i in range(len(nums) - 1):
            for j in range(i + 1, len(nums)):
                if nums[i] + nums[j] != target:
                    continue
                key = (min(nums[i], nums[j]), max(nums[i], nums[j]))
                if key not in seen:
                    seen.add(key)
                    pairs.append([nums[i], nums[j]])
        return pairs



In [None]:
import unittest

class TestPairSum(unittest.TestCase):
    def setUp(self):
        self.pair_sum = PairSum()

    def test_example_1(self):
        nums = [2, 7, 11, 15]
        target = 9
        expected_output = [[2, 7]]
        self.assertEqual(self.pair_sum.two_sum(nums, target), expected_output)

    def test_example_2(self):
        nums = [2, 5, 8, -4, -6]
        target = 10
        expected_output = [[-6, 4]]
        self.assertEqual(self.pair_sum.two_sum(nums, target), expected_output)

    def test_example_3(self):
        nums = [1, 2, 3, 4, 5]
        target = 10
        expected_output = []
        self.assertEqual(self.pair_sum.two_sum(nums, target), expected_output)

    def test_empty_input_array(self):
        nums = []
        target = 0
        expected_output = []
        self.assertEqual(self.pair_sum.two_sum(nums, target), expected_output)

    def test_single_element_input_array(self):
        nums = [5]
        target = 10
        expected_output = []
        self.assertEqual(self.pair_sum.two_sum(nums, target), expected_output)

if __name__ == "__main__":
    # By default unittest.main() parses sys.argv and finishes with sys.exit().
    # In a notebook sys.argv holds the kernel's launch arguments, so pass a
    # placeholder argv and exit=False to run the tests and report inline
    # without raising SystemExit into the cell.
    unittest.main(argv=["ignored-program-name"], exit=False)
