In [7]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch
from torchvision import transforms as T
import scipy.io 
from PIL import Image
import numpy as np

pa_100k_group_order = [7,8,13,14,15,16,17,18,19,20,21,22,23,24,25,9,10,11,12,1,2,3,0,4,5,6]
pa_100k_num_in_group = [2, 6, 6, 1, 4, 3, 1, 3]


class PA100KDataset(Dataset):
    def __init__(self, root_dir, transforms, split, use_multitask=False):
        self.annotations = scipy.io.loadmat("./data/PA-100K/annotation.mat")
        self.file_paths = self.annotations[f"{split}_images_name"]
        self.labels = self.annotations[f"{split}_label"]
        self.root_dir = root_dir
        self.transforms = transforms
        self.use_multitask = use_multitask
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        image_path = self.file_paths[index][0][0]
        image = Image.open(os.path.join(self.root_dir, image_path))
        if self.transforms:
            image = self.transforms(image)
        if self.use_multitask:
            group_label = []
            label = self.labels[index]

            for group in range(len(pa_100k_num_in_group)):
                group_num = pa_100k_num_in_group[group]
                start_index = pa_100k_group_order[sum(pa_100k_num_in_group[:group])]
                end_index = pa_100k_group_order[sum(pa_100k_num_in_group[:group]) + group_num + 1]
                print(start_index, end_index)
                group_label.append(np.argmax(self.labels[index][start_index:end_index]))
            return group_label

        label = self.labels[index]
        return image, label

In [9]:
dataset = PA100KDataset(
    root_dir="./data/PA-100K/release_data/release_data", 
    transforms=T.ToTensor(), 
    split="train", 
    use_multitask=True
)

In [10]:
dataset.__getitem__(0)

7 14
13 20
19 9


ValueError: attempt to get argmax of an empty sequence

In [2]:
from st_moe_pytorch import MoE

moe = MoE(
    dim = 512,
    num_experts = 4,               # increase the experts (# parameters) of your model without increasing computation
    gating_top_n = 2,               # default to top 2 gating, but can also be more (3 was tested in the paper with a lower threshold)
)

In [57]:
import torch.nn as nn

class MoETransformer(nn.Module):
    def __init__(self, input_dim, num_heads, num_experts):
        super(MoETransformer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(input_dim, num_heads, batch_first=True)
        self.moelayer = MoE(dim=input_dim, num_experts=num_experts)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
    
    def forward(self, q, k, v):
        attn_output, _ = self.multihead_attn(q, k, v)
        
        x = self.norm1(q + attn_output)
        moe_output, total_aux_loss, _, _ = self.moelayer(x)
        x = self.norm2(x + moe_output)
        return x, total_aux_loss
    
class MoEFusionHead(nn.Module):
    def __init__(self, input_dim, num_heads, num_experts):
        super(MoEFusionHead, self).__init__()
        self.transformer = MoETransformer(input_dim, num_heads, num_experts)
        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x1, x2):
        outputs, aux_loss = self.transformer(x1, x2, x2)
        outputs = self.pooling(outputs.transpose(1, 2)).squeeze(-1)
        return self.fc(outputs), aux_loss

In [58]:
import torch
features1 = torch.rand(1, 49, 768)
features2 = torch.rand(1, 197, 768)

In [60]:
fusion_layer = MoEFusionHead(768, 2, 4)
outs, loss = fusion_layer(features2, features1)

torch.Size([1, 197, 768])
torch.Size([1, 768])
