# CNN tests

In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

from classiferNN import CNN,evaluate_model,train_model
from MOE import MOE_CNN

%load_ext autoreload
%autoreload 2

## Setup

In [2]:
# Everything in this notebook runs on the CPU. To pick a GPU when available, use:
#   torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'

# DataLoader settings reused by the loaders created below.
num_workers = 0
pin_memory = False

print(f"Device used: {device}")

Device used: cpu


In [3]:
# Normalisation statistics passed to transforms.Normalize.
# NOTE(review): presumably the FashionMNIST mean/std after the 26x26 crop -- confirm.
mean = 0.3240
std = 0.3612
batch_size = 128
num_experts = 4
sample_fraction = 0.60  # share of the training split each expert is trained on
seed = 42

# Transform: tensor conversion, 26x26 centre crop, then normalisation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop(26),
    transforms.Normalize(mean=(mean,), std=(std,))
])

# Load datasets
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Split train/val (80/20) with a dedicated, seeded generator so the split is reproducible.
g = torch.Generator().manual_seed(seed)
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = random_split(train_dataset, [train_size, val_size], generator=g)

# Expert loaders: every expert gets its own random subset of the training split.
train_indices = list(range(train_size))
num_samples = int(sample_fraction * train_size)

# Use a local, seeded RNG instead of the global `random` module so that the
# expert subsets are identical after every kernel restart (the global module
# is never seeded in this notebook). The DataLoader shuffle order still comes
# from torch's default generator.
expert_sampler = random.Random(seed)

expert_train_loaders = []
for expert_idx in range(num_experts):
    # Sampling is without replacement within one expert; subsets of different experts may overlap.
    sampled_indices = expert_sampler.sample(train_indices, num_samples)
    sampled_subset = Subset(train_subset, sampled_indices)
    loader = DataLoader(sampled_subset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, pin_memory=pin_memory)
    expert_train_loaders.append(loader)

# Validation and test loaders (no shuffling, so evaluation order is fixed).
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers, pin_memory=pin_memory)


In [4]:
# Sanity check: share of the training split seen by the first expert.
# Count samples in the underlying subset rather than batches * batch_size,
# which would overestimate whenever the final batch is only partially filled.
len(expert_train_loaders[0].dataset) / train_size

0.6

In [5]:
# Build a small CNN purely to inspect its layer layout and parameter count.
# NOTE(review): this configuration (fc_dims=[16,16], kernel_size=5, padding=1)
# is not the one used for the experts in the training loop further down
# (fc_dims=[16,16,16], kernel_size=4, padding=2) -- confirm which one is intended.
model = CNN(
    dim=26,
    conv_channels=[4, 4],
    fc_dims=[16, 16],
    kernel_size=5,
    stride=3,
    padding=1,
    max_pooling=False,
)
model

CNN(
  (conv): Sequential(
    (0): Conv2d(1, 4, kernel_size=(5, 5), stride=(3, 3), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(4, 4, kernel_size=(5, 5), stride=(3, 3), padding=(1, 1))
    (3): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=16, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=16, bias=True)
    (3): ReLU()
    (4): Linear(in_features=16, out_features=10, bias=True)
  )
)

In [6]:
# Report the size of the preview model built above (prints the trainable-parameter total).
model.get_num_params()

Total trainable parameters: 1222


In [7]:
# Train one expert per subsampled loader and checkpoint each to ./pretrained/.

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the (long) training loop starts.
os.makedirs('./pretrained', exist_ok=True)

# Loop-invariant settings shared by all experts.
criterion = nn.CrossEntropyLoss()
epochs=20

for i in range(num_experts):
    print(f"------------------- Training Expert{i} -------------------")
    # Fresh model and optimizer for every expert so that no state is shared between them.
    model = CNN(dim = 26,conv_channels= [4,4],fc_dims=[16,16,16],kernel_size =4,stride=3,padding=2,max_pooling=False)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    train_model(model=model,
                train_loader=expert_train_loaders[i],
                val_loader=val_loader,
                optimizer=optimizer,
                criterion=criterion,
                num_epochs=epochs,
                device=device)
    # The whole module is pickled (not just the state_dict); the loading cell relies on that.
    torch.save(model,f'./pretrained/expert{i}.pth')
    print(f"------------------- Expert{i} finished -------------------")

------------------- Training Expert0 -------------------
Epoch 1/20 completed in 7.21s.
Epoch 2/20 completed in 7.03s.
Epoch 3/20 completed in 6.98s.
Epoch 4/20 completed in 6.98s.
Epoch 5/20 completed in 6.96s.
Validation Accuracy after epoch 5: 79.97%
Epoch 6/20 completed in 6.99s.
Epoch 7/20 completed in 7.10s.
Epoch 8/20 completed in 6.88s.
Epoch 9/20 completed in 6.90s.
Epoch 10/20 completed in 6.93s.
Validation Accuracy after epoch 10: 81.93%
Epoch 11/20 completed in 6.92s.
Epoch 12/20 completed in 7.00s.
Epoch 13/20 completed in 6.92s.
Epoch 14/20 completed in 6.92s.
Epoch 15/20 completed in 6.98s.
Validation Accuracy after epoch 15: 82.82%
Epoch 16/20 completed in 6.95s.
Epoch 17/20 completed in 7.02s.
Epoch 18/20 completed in 6.98s.
Epoch 19/20 completed in 7.10s.
Epoch 20/20 completed in 6.94s.
Validation Accuracy after epoch 20: 83.68%
------------------- Expert0 finished -------------------
------------------- Training Expert1 -------------------
Epoch 1/20 completed in 7.0

## Train MOE 

In [15]:
# Load the pretrained experts, freeze them, and wrap them in a Mixture of Experts.
experts = []
for i in range(num_experts):
    path = f'./pretrained/expert{i}.pth'
    # The checkpoints hold whole pickled modules (written with torch.save(model, ...)),
    # so no CNN needs to be constructed beforehand and weights_only must be False.
    # SECURITY: unpickling executes arbitrary code -- only load checkpoints you created yourself.
    # map_location places the loaded tensors on the notebook's device.
    expert = torch.load(path, map_location=device, weights_only=False)
    expert.eval()  # inference mode for the fixed experts
    expert.requires_grad_(False)  # freeze: only the gating network should be trained
    experts.append(expert)

# Create Mixture of Experts: a separate CNN acts as the gating network.
gating_net = CNN(dim = 26,conv_channels= [8,8],fc_dims=[16,16],kernel_size =5,stride=3,padding=2,max_pooling=True)
moe_model = MOE_CNN(experts=experts, input_dim=26,gating_net = gating_net)

In [16]:
# Display the assembled model (the loaded experts plus the gating network).
moe_model

MOE_CNN(
  (experts): ModuleList(
    (0-3): 4 x CNN(
      (conv): Sequential(
        (0): Conv2d(1, 4, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2))
        (1): ReLU()
        (2): Conv2d(4, 4, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2))
        (3): ReLU()
      )
      (fc): Sequential(
        (0): Linear(in_features=64, out_features=16, bias=True)
        (1): ReLU()
        (2): Linear(in_features=16, out_features=16, bias=True)
        (3): ReLU()
        (4): Linear(in_features=16, out_features=16, bias=True)
        (5): ReLU()
        (6): Linear(in_features=16, out_features=10, bias=True)
      )
    )
  )
  (gate): GatingNetwork(
    (net): CNN(
      (conv): Sequential(
        (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(3, 3), padding=(2, 2))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): Conv2d(8, 8, kernel_size=(5, 5), stride=(3, 3), padding=(2, 2))
        (4): ReLU()
        (5)

In [17]:
# The printed total is labelled "trainable", so the frozen experts should not contribute to it.
moe_model.get_num_params()  # excludes the frozen experts' parameters

Total trainable parameters: 2446


### Load data

In [18]:
# Rebuild the datasets and loaders used to train and evaluate the MoE.
mean= 0.3240
std= 0.3612
batch_size = 128
transform = transforms.Compose([
    transforms.ToTensor(),            # Convert to tensor
    transforms.CenterCrop(26),
    transforms.Normalize(mean=(mean,),std=(std,)),
])

train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Split into train and validation (80/20). The generator is seeded with the
# same `seed` and the same sizes as the split made for the expert loaders, so
# the validation samples here are the ones the experts never trained on.
g = torch.Generator().manual_seed(seed)
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = random_split(train_dataset, [train_size, val_size], generator=g)

# DataLoaders. The training loader is built from train_subset only; an earlier
# loader over the full train_dataset (which contains the validation samples)
# was dead code that got overwritten straight away, so it has been removed.
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True,num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False,num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,num_workers=num_workers, pin_memory=pin_memory)

In [19]:
# Train the MoE on the training split, validating every second epoch.
epochs = 15
criterion = nn.CrossEntropyLoss()
# All parameters are handed to AdamW; those with requires_grad=False never
# receive gradients and are therefore left untouched by the optimizer.
optimizer = optim.AdamW(moe_model.parameters(), lr=1e-4)

train_model(
    model=moe_model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=epochs,
    device=device,
    val_every_k=2,
)

Epoch 1/15 completed in 14.03s.
Epoch 2/15 completed in 13.18s.
Validation Accuracy after epoch 2: 85.98%
Epoch 3/15 completed in 13.12s.
Epoch 4/15 completed in 13.11s.
Validation Accuracy after epoch 4: 85.93%
Epoch 5/15 completed in 13.14s.
Epoch 6/15 completed in 13.59s.
Validation Accuracy after epoch 6: 85.83%
Epoch 7/15 completed in 13.34s.
Epoch 8/15 completed in 13.73s.
Validation Accuracy after epoch 8: 85.78%
Epoch 9/15 completed in 13.71s.
Epoch 10/15 completed in 13.31s.
Validation Accuracy after epoch 10: 85.82%
Epoch 11/15 completed in 13.44s.
Epoch 12/15 completed in 13.18s.
Validation Accuracy after epoch 12: 85.76%
Epoch 13/15 completed in 13.16s.
Epoch 14/15 completed in 13.25s.
Validation Accuracy after epoch 14: 85.83%
Epoch 15/15 completed in 13.09s.
Validation Accuracy after epoch 15: 85.86%


In [20]:
# Final accuracy on the held-out test set. Use the notebook-wide `device`
# setting rather than a hard-coded string so evaluation follows the same
# device as training if that setting is ever changed.
evaluate_model(model=moe_model, data_loader=test_loader, device=device)

85.14