In [7]:
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score

In [8]:
class Net(nn.Module):
    def __init__(self,in_ch,out_ch1,kernel_size,act):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_ch,out_ch1 , kernel_size)
        #self.conv2 = nn.Conv2d(out_ch1, out_ch2, kernel_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(15*15*10, 50)
        self.fc2 = nn.Linear(50, 10)
        self.act = act #0 for linear and 1 for non-linear
    def activation(self,x):
        return x**3
    def forward(self, x):
      #print(x.shape)
      if(self.act==1):
          x = self.activation(self.conv1(x))
      else:
          x = self.conv1(x)
      x = self.pool(x)
      #print(x.shape)
      x = x.view(-1, 15*15*10)
      x = self.fc1(x)
      x = self.fc2(x)
      return x;

In [9]:
class Router(nn.Module):
    def __init__(self,n_experts,in_ch,out_ch1):
      super(Router, self).__init__()
      self.n_experts = n_experts
      self.conv1 = nn.Conv2d(in_ch, 6, 3)
      self.pool = nn.MaxPool2d(2, 2)
      self.fc1 = nn.Linear( 6*3 * 3, 120)
      self.fc2 = nn.Linear(120, 84)
      self.fc3 = nn.Linear(84, n_experts)
    def forward(self, x):
      x = self.pool(F.relu(self.conv1(x)))
      x = x.view(-1, 6 * 3 * 3)
      x = F.relu(self.fc1(x))
      x = F.relu(self.fc2(x))
      x = self.fc3(x)
      return x

In [10]:
class MoE:
  def __init__(self,n_experts,in_ch,out_ch1,patches,device):
    self.n_experts = n_experts
    self.router = Router(n_experts,in_ch,out_ch1).to(device)
    self.experts = nn.ModuleList([Net(in_ch,out_ch1,3,1) for i in range(n_experts)]).to(device)
    self.num_patches = patches
    self.loss = nn.CrossEntropyLoss()
    self.optim_router = torch.optim.SGD(self.router.parameters(), lr=0.001, momentum=0.9)
    self.expert_optimizers = [torch.optim.Adam(self.experts[i].parameters(),lr = 0.001) for i in range(self.n_experts)]
    self.device = device
  def fit(self,trainloader,device):
    patch_size = 32//self.num_patches
    tot_loss = 0
    for batch in trainloader:
      X,y = batch
      y = torch.nn.functional.one_hot(y, num_classes=10).float().to(device)
      for i in range(len(X)):
        x = X[i].to(device)
        #plt.imshow(np.reshape(x,(32,32,3)))
        y_i = y[i].to(device)
        #print(y_i)
        tot_val = torch.zeros(self.n_experts).to(device)
        for j in range(self.num_patches):
          for k in range(self.num_patches):
            x_patch = x[:,j*patch_size:(j+1)*patch_size,k*patch_size:(k+1)*patch_size].to(device)
            #print(x_patch.shape)
            val = self.router(x_patch)
            tot_val = tot_val + val
        expert_chosen = torch.argmax(tot_val)
        y_pred = self.experts[expert_chosen](x).view(-1).to(device)
        #print(y_pred)
        #print(y_i)
        #input()
        loss = self.loss(y_pred,y_i)
        tot_loss = tot_loss + loss
        self.optim_router.zero_grad()
        self.expert_optimizers[expert_chosen].zero_grad()
        loss.backward()
        self.optim_router.step()
        self.expert_optimizers[expert_chosen].step()
    return tot_loss/len(trainloader)
  def predict(self,X):
    patch_size = 32//self.num_patches
    pred_values = np.zeros(len(X))
    for i in range(len(X)):
      x = X[i]
      tot_val = torch.zeros(self.n_experts)
      for j in range(self.num_patches):
        for k in range(self.num_patches):
          x_patch = x[:,j*patch_size:(j+1)*patch_size,k*patch_size:(k+1)*patch_size]
          val = self.router(x_patch)
          tot_val = tot_val + val
      expert_chosen = torch.argmax(tot_val)
      y_pred = self.experts[expert_chosen](x)
      pred_values[i] = torch.argmax(y_pred).cpu().item()
    return pred_values

In [12]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                          shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
model =MoE(4,3,10,4,device)#n_experts,in_ch,out_ch1,patches
n_epochs = 10
loss_list = []
for i in range(n_epochs):
  present_loss = model.fit(trainloader,device)
  loss_list.append(present_loss)
  print(f"Epoch {i+1}/{n_epochs}, Loss: {present_loss}")

Files already downloaded and verified
Files already downloaded and verified
Epoch 1/10, Loss: 6.235317230224609
Epoch 2/10, Loss: 5.600617408752441
Epoch 3/10, Loss: 5.425222396850586
Epoch 4/10, Loss: 5.310474872589111
Epoch 5/10, Loss: 5.220433712005615
Epoch 6/10, Loss: 5.14109468460083
Epoch 7/10, Loss: 5.08414363861084
Epoch 8/10, Loss: 5.053863525390625
Epoch 9/10, Loss: 5.027271270751953
Epoch 10/10, Loss: 4.99997091293335


In [23]:
def predict(model,X,device):
  patch_size = 32//4
  pred_values = np.zeros(len(X))
  num_patches = 4
  for i in range(len(X)):
    x = X[i].to(device)
    tot_val = torch.zeros(model.n_experts)
    for j in range(num_patches):
      for k in range(num_patches):
        x_patch = x[:,j*patch_size:(j+1)*patch_size,k*patch_size:(k+1)*patch_size].to(device)
        val = model.router(x_patch).cpu()
        tot_val = tot_val + val
    expert_chosen = torch.argmax(tot_val)
    y_pred = model.experts[expert_chosen](x)
    pred_values[i] = torch.argmax(y_pred).cpu().item()
  return pred_values

In [24]:
all_predictions = []
all_true_labels = []

# Predict and calculate accuracy
with torch.no_grad():  # Disable gradient tracking for inference
    for batch in testloader:
        X_test, y_test = batch  # Extract test inputs and true labels
        preds = predict(model,X_test,device)  # Get predictions
        all_predictions.extend(preds.tolist())  # Collect predictions
        all_true_labels.extend(y_test.tolist())  # Collect true labels

# Compute accuracy score
accuracy = accuracy_score(all_true_labels, all_predictions)
print(f"Accuracy: {accuracy * 100:.2f}%")

Accuracy: 52.68%


tensor(4)
