In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from Experts import BasicExpert
from Routers import DeepRouter
import numpy as np

In [16]:
'''
k: top k experts
m: number of shared experts
If certain experts are always picked during training, they are selected as shared experts
'''
class DeepMOE(nn.Module):
    def __init__(self, in_features=128, out_features=32, num_experts=8, k=2, balance_weight=0.01, shared_experts_indices=None, inference_mode=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_experts = num_experts
        self.k = k
        self.balance_weight = balance_weight
        self.shared_experts_indices = shared_experts_indices
        self.experts = nn.ModuleList([BasicExpert(in_features, out_features) for _ in range(num_experts)])
        self.expert_indices = np.arange(num_experts)
        self.router = DeepRouter(in_features, k, num_experts, 2, inference_mode)
        self.inference_mode = inference_mode

    def forward(self, x):
        device = x.device
        batch_size = x.size(0)
        output = torch.zeros(batch_size, self.out_features, device=device)
        expert_usage = torch.zeros(self.num_experts, device=device)
        balance_loss = 0

        if self.inference_mode:
            # outputs of shared experts
            for i in self.shared_experts_indices:
                output += self.experts[i](x)
            # exempt shared experts from candidates
            new_experts = [self.experts[i] for i in self.expert_indices if i not in self.shared_experts_indices]

            for i in range(batch_size):
                expert_indices, expert_weights = self.router(x[i].to(device))
                for j in range(len(expert_indices)):
                    expert_idx = expert_indices[j]
                    output = output + expert_weights[j] * new_experts[expert_indices[j]](x[i].to(device))

            return output
            
        else:
            # batch-wise calculation
            for i in range(batch_size):
                expert_indices, expert_weights = self.router(x[i].to(device))
                for j in range(len(expert_indices)):
                    expert_idx = expert_indices[j]
                    output = output + expert_weights[j] * self.experts[expert_indices[j]](x[i].to(device))
                    expert_usage[expert_idx] += 1 # calculate number of usage times

            balance_loss = self.compute_balance_loss(expert_usage, batch_size)
            return output, balance_loss, expert_usage
    
    def compute_balance_loss(self, expert_usage, batch_size):
        expert_freq = expert_usage / batch_size
        # load balancing loss using variance
        balance_loss = torch.var(expert_freq) * self.balance_weight
        return balance_loss

In [17]:
class testModel(nn.Module):
    def __init__(self, inference_mode=False):
        super(testModel, self).__init__()
        self.inference_mode = inference_mode
        self.fc1 = nn.Linear(32*32*3, 128)
        self.moe = DeepMOE()
        self.fc2 = nn.Linear(32, 10)
    
    def forward(self, x):
        device = x.device
        x = F.relu(self.fc1(x.to(device)))
        if self.inference_mode:
            self.moe.inference_mode = self.inference_mode
            x = self.moe(x.to(device))
            x = self.fc2(x.to(device))
            return x
        
        x, balance_loss = self.moe(x.to(device))
        x = self.fc2(x).to(device)

        return x, balance_loss

In [18]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])

train_dataset = datasets.CIFAR10(root='/work/datasets/CIFAR10', train=True, transform=transform)
test_dataset = datasets.CIFAR10(root='/work/datasets/CIFAR10', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=3)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=3)

In [19]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = testModel()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [20]:
num_epochs = 50
total_expert_usage_sum = np.zeros(8)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_balance_loss = 0.0
    expert_usage_sum = np.zeros(8)
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    
    for inputs, labels in train_loader_tqdm:
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = inputs.view(inputs.size(0), -1)
        
        
        optimizer.zero_grad()
        outputs, balance_loss, expert_usage = model(inputs)
        main_loss = criterion(outputs, labels)
        total_loss = main_loss + balance_loss

        total_loss.backward()
        optimizer.step()
        
        running_loss += main_loss.item()
        running_balance_loss += balance_loss.item()
        
        train_loader_tqdm.set_postfix({
            "Loss": running_loss / (train_loader_tqdm.n + 1),
            "Balance Loss": running_balance_loss / (train_loader_tqdm.n + 1)
        })

        expert_usage = expert_usage.cpu().detach().numpy()
        expert_usage_sum += expert_usage

    total_expert_usage_sum += expert_usage_sum

    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Balance Loss: {running_balance_loss/len(train_loader):.4f}")

shared_experts_indices = np.argsort(total_expert_usage_sum)[-2:]

                                                   

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [None]:
model.inference_mode = True
model.moe.shared_experts_indices = shared_experts_indices
model.eval()
correct = 0
total = 0
test_loss = 0
test_balance_loss = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = inputs.view(inputs.size(0), -1)
        
        outputs = model(inputs)
        main_loss = criterion(outputs, labels)
        
        test_loss += main_loss.item()
        
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

test_loss = test_loss / len(test_loader)
test_balance_loss = test_balance_loss / len(test_loader)
accuracy = 100. * correct / total

print(f'Test Loss: {test_loss:.4f} | Test Balance Loss: {test_balance_loss:.4f} | Accuracy: {accuracy:.2f}%')

In [8]:
sum = np.zeros(2)
for i in range(2):
    for j in range(2):
        sum += np.array([j, j])
    total += sum

print(total)

[7. 7.]
