In [17]:
import torch 
torch.set_default_device("cuda")
import torch.nn as nn  
import torch.nn.functional as F
import cv2
import numpy as np
import os
import pandas as pd
import pickle

def load_data():
    train_data = []
    train_label = []
    test_data = []
    test_label = []
    with open("./cifar-10-batches-py/data_batch_1", 'rb') as fo:
        dict = pickle.load(fo, encoding='latin1')
        train_data.extend(dict["data"])
        train_label.extend(dict["labels"])
    with open("./cifar-10-batches-py/data_batch_2", 'rb') as fo:
        dict = pickle.load(fo, encoding='latin1')
        train_data.extend(dict["data"])
        train_label.extend(dict["labels"])
    with open("./cifar-10-batches-py/data_batch_3", 'rb') as fo:
        dict = pickle.load(fo, encoding='latin1')
        train_data.extend(dict["data"])
        train_label.extend(dict["labels"])
    with open("./cifar-10-batches-py/data_batch_4", 'rb') as fo:
        dict = pickle.load(fo, encoding='latin1')
        train_data.extend(dict["data"])
        train_label.extend(dict["labels"])
    with open("./cifar-10-batches-py/data_batch_5", 'rb') as fo:
        dict = pickle.load(fo, encoding='latin1')
        train_data.extend(dict["data"])
        train_label.extend(dict["labels"])
    with open("./cifar-10-batches-py/test_batch", 'rb') as fo:
        dict = pickle.load(fo, encoding='latin1')
        test_data.extend(dict["data"])
        test_label.extend(dict["labels"])
    return (torch.tensor(train_data)/255.0, 
    F.one_hot(torch.tensor(train_label), num_classes=10), 
    torch.tensor(test_data)/255.0, 
    F.one_hot(torch.tensor(test_label), num_classes=10))
    return (train_data,train_label,test_data,test_label)

train_data,train_label,test_data,test_label=load_data()
 

  return func(*args, **kwargs)


In [18]:
train_label = train_label *1.0
train_label
test_label = test_label *1.0
test_label

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.]], device='cuda:0')

In [19]:
train_data = train_data.reshape((50000,32,32,3))
test_data = test_data.reshape((10000,32,32,3))


In [20]:
class Expert(nn.Module):  
    def __init__(self, output_dim):  
        super(Expert, self).__init__()  
        self.conv1 = nn.Conv2d(32, 64, (2,1))
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 128, (2,1))
        self.fc1 = nn.Linear(1792, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, output_dim)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
          
    def forward(self, x):  
        x= F.leaky_relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        x = x.view(-1, 1792)
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [147]:
print(test_data[:1].shape)
ex =Expert(11)
ex(test_data[:1])

torch.Size([1, 32, 32, 3])


tensor([[0.5041, 0.5182, 0.5359, 0.5344, 0.4728, 0.5175, 0.4742, 0.5073, 0.4653,
         0.4608, 0.4966]], grad_fn=<SigmoidBackward0>)

In [21]:
class SparseGate(nn.Module):  
    def __init__(self,num_experts, threshold=0.5):  
        super(SparseGate, self).__init__()  
        self.num_experts = num_experts
        self.conv1 = nn.Conv2d(32, 64, (2,1))
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 128, (2,1))
        self.fc1 = nn.Linear(1792, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_experts)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.threshold = threshold  
          
    def forward(self, x):  
        x= F.leaky_relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        x = x.view(-1, 1792)
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        scores = torch.sigmoid(x)   
        # get k of top k
        k = self.num_experts * self.threshold
        topk,_ = torch.topk(scores, int(k))
        mask = []
        for index in range(self.num_experts):
            #if scores[0][index] > topk.min() or scores[0][index] == topk.min():
            if scores[0][index] > self.threshold:
                mask.append(scores[0][index])
            else:
                mask.append(0.0)
        mask =torch.tensor(mask)
        # mask all zero scores
        return mask

In [22]:
class ExpertMask(nn.Module): 
    def __init__(self, output_dim):  
        super(ExpertMask, self).__init__()
        self.output_dim = output_dim
        self.fl = nn.Flatten()
        self.mask = nn.Linear(1024*3, output_dim)
    def forward(self, x):  
        v = self.mask(self.fl(x))*0.0
        return v

In [58]:
class MoE(nn.Module):  
    def __init__(self,  output_dim, num_experts, threshold=0.5):  
        super(MoE, self).__init__()  
        self.experts = nn.ModuleList([Expert( output_dim) for _ in range(num_experts)])  
        self.gate = SparseGate(num_experts, threshold)  
        self.num_experts = num_experts
        self.output_dim =output_dim
          
    def human_setup_compute(self, x, human_set_gate_factors = None):
        if human_set_gate_factors is not None:
            expert_scores = human_set_gate_factors
        activated_experts  = []
        for expert_index in range(self.num_experts):
            if expert_scores[expert_index] > torch.tensor(0.0):
                activated_experts.append(self.experts[expert_index])
            else:
                activated_experts.append(ExpertMask(self.output_dim))
        expert_outputs = torch.stack([expert(x) if isinstance(expert, Expert) else torch.zeros((x.shape[0],self.output_dim)) for expert in activated_experts], dim=1) 
        
        gate_outputs = expert_scores.unsqueeze(-1) 
        
        final_output = torch.sum(gate_outputs * expert_outputs, dim=1) 
        final_output = F.sigmoid(final_output)
        #print(final_output)
        return final_output  

    def forward(self, x):  
        expert_scores = self.gate(x)  
        activated_experts  = []
        for expert_index in range(self.num_experts):
            if expert_scores[expert_index] > torch.tensor(0.0):
                activated_experts.append(self.experts[expert_index])
            else:
                activated_experts.append(ExpertMask(self.output_dim))
        expert_outputs = torch.stack([expert(x) if isinstance(expert, Expert) else torch.zeros((x.shape[0],self.output_dim)) for expert in activated_experts], dim=1)
        gate_outputs = expert_scores.unsqueeze(-1) 
        
        final_output = torch.sum(gate_outputs * expert_outputs, dim=1) 
        final_output = F.sigmoid(final_output)
        #print(final_output)
        return final_output  

In [24]:
class OneHotLoss(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x,y):
        return torch.sum(torch.abs(x-y))

In [81]:
def train(train_data: torch.Tensor, train_label: torch.Tensor, test_data: torch.Tensor, test_label: torch.Tensor):
    output_dim = 10  
    num_experts =20
    threshold = 0.5
    batch_size = 100
    #model = Expert(output_dim)
    model = MoE(output_dim, num_experts, threshold)  
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_function = nn.BCELoss()
    loss_sum = 0
    for epoch in range(10):
        for step in range(0, 50000, batch_size):
            optimizer.zero_grad()
            step_result = model(train_data[step:step+batch_size])
            loss = loss_function(step_result, train_label[step:step+batch_size])
            loss_sum = loss_sum+ loss.item()
            loss.backward() 
            optimizer.step() 
            #print([x.grad for x in optimizer.param_groups[0]["params"]])
            # for params in model.parameters():
            #     print(f"G: {params.grad}")
        print(f'E:{epoch} {loss_sum/500}')
        loss_sum = 0
    return model
                 


In [82]:
trained_model = train(train_data, train_label, test_data, test_label)

E:0 0.2285371019244194
E:1 0.18469052982330322
E:2 0.16142360980808734
E:3 0.1393040187060833
E:4 0.1167462460398674
E:5 0.09433456543833017
E:6 0.07594361340999603
E:7 0.06724514252319932
E:8 0.060273028373718264
E:9 0.05009632034227252


In [83]:
index =78

In [84]:
with torch.no_grad():
    print(trained_model.gate(test_data[index].reshape((1,32,32,3))))
    print(trained_model(test_data[index].reshape((1,32,32,3))))
    print(test_label[index])

tensor([0.0000, 0.0000, 0.5315, 0.0000, 0.5111, 0.0000, 0.0000, 0.0000, 0.5145,
        0.5742, 0.5114, 0.0000, 0.0000, 0.0000, 0.0000, 0.5085, 0.5149, 0.5038,
        0.0000, 0.5587], device='cuda:0')
tensor([[8.6667e-04, 3.6607e-01, 1.9934e-02, 7.4488e-01, 2.2469e-04, 4.7849e-05,
         3.3677e-03, 3.7516e-05, 2.6871e-02, 6.3244e-01]], device='cuda:0')
tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], device='cuda:0')


In [96]:
with torch.no_grad():
    ll = trained_model.human_setup_compute(test_data[index:index+1], 
    torch.tensor([0.0000, 0.0000, 0.5315, 0.0000, 0.5111, 0.0000, 0.0000, 0.0000, 0.5145,
        0.5742, 0.5114, 0.0000, 0.0000, 0.0000, 0.0000, 0.5085, 0.9149, 0.9038,
        0.0000, 0.2587]))
    print(ll)

tensor([[4.5552e-04, 3.7660e-01, 4.0578e-02, 9.6198e-01, 3.8759e-04, 2.2229e-05,
         5.4764e-03, 3.8834e-05, 2.1863e-02, 5.6670e-01]], device='cuda:0')
