# Try
My personal notes 

In [11]:
import torch
import torch.nn as nn

# Feature Selector
class WSS(nn.Module):
    def __init__(self, in_channel, num_classes, num_selects) -> None:
        super().__init__()
        self.fc = nn.Linear(in_channel, num_classes)
        self.num_selects = num_selects
    
    def forward(self, x):
        # Input: x = [B, HxW, C] 
        # Output: class_prediction = [B, HxW, num_classes], selectd_features = [B, num_selects, num_classes]
        # First project the x by linear, sort it, and use the index to collect the original feature (x).
        h = self.fc(x)
        logits = torch.softmax(h, dim=-1)
        _, ids = torch.sort(logits, dim=-1, descending=True)
        selection = ids[:, :self.num_selects]
        return logits, torch.gather(x, 1, selection)



In [13]:
data = torch.rand((1, 16, 8))
model = WSS(in_channel=8, num_classes=6, num_selects=4)
logits, features = model(data)

print(logits.shape, features.shape)


torch.Size([1, 16, 6]) torch.Size([1, 4, 6])
