In [1]:
import os

import numpy as np
import torch
import torchvision
from torchsummary import summary

from models.cifar100.cifar100resnet import CifarResNet
from models.cifar100.cifar100expert import CifarExpert
from models.cifar100.gating_network import GatingNetwork
from models.cifar100.moe import MoE
from utils.cifar100_dataset import CIFAR100Dataset

In [2]:
# Training-time preprocessing pipeline: random horizontal flip, random 32x32
# crop after 4 pixels of padding, conversion to a tensor, then per-channel
# normalization with the statistics below.
tv_transforms = torchvision.transforms
normalize_mean = [0.507, 0.487, 0.441]
normalize_std = [0.267, 0.256, 0.276]

augmentation_steps = [
    tv_transforms.RandomHorizontalFlip(p=0.5),
    tv_transforms.RandomCrop(size=32, padding=4),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize(mean=normalize_mean, std=normalize_std),
]
transformations_training = tv_transforms.Compose(augmentation_steps)

In [3]:
# Build the feature extractor that the experts and the MoE below receive, and
# restore its weights from the baseline checkpoint.
feature_extractor = CifarResNet(name='CIFAR100 Feature Extractor')

# map_location='cpu' keeps deserialization from failing on a machine without
# CUDA if the checkpoint tensors were saved from a GPU; load_state_dict then
# copies the values into the module's own parameters.
baseline_state = torch.load('./trained_models/baseline_model.pth', map_location='cpu')

# Left as the last expression so the key-matching report is displayed.
feature_extractor.load_state_dict(baseline_state)

<All keys matched successfully>

In [4]:
# Build one expert per entry of CIFAR100Dataset.CIFAR100_SUPERCLASSES. Each
# expert receives that entry's classes and the common feature extractor, and
# its weights are restored from
# ./trained_models/superclass_experts/<expert.name>.pth.
experts = []
for superclass in CIFAR100Dataset.CIFAR100_SUPERCLASSES:
    expert = CifarExpert(
        classes=CIFAR100Dataset.CIFAR100_SUPERCLASSES[superclass],
        name='expert_' + superclass,
        feature_extractor=feature_extractor,
    )
    # map_location='cpu' keeps deserialization from failing on a machine
    # without CUDA if the checkpoint tensors were saved from a GPU.
    expert_state = torch.load(
        './trained_models/superclass_experts/' + expert.name + '.pth',
        map_location='cpu',
    )
    expert.load_state_dict(expert_state)
    experts.append(expert)

In [5]:
# Folder handed to MoE as `data_folder`. The default is the original
# machine-specific absolute path; set the CIFAR100_TRAINING_DIR environment
# variable to run the notebook elsewhere without editing this cell.
training_data_folder = os.environ.get(
    'CIFAR100_TRAINING_DIR',
    '/home/lb4653/thesis/mixture-of-experts-thesis/data/cifar100/training',
)

# Assemble the mixture of experts from the experts built above, the
# GatingNetwork class (passed as a class, not an instance) and the common
# feature extractor.
moe = MoE(
    classes=CIFAR100Dataset.CIFAR100_LABELS,
    experts=experts,
    Gate=GatingNetwork,
    name='MoE_superclasses',
    feature_extractor=feature_extractor,
    data_folder=training_data_folder,
    transform=transformations_training,
)

In [6]:
# First gate-training run: 10 epochs at a learning rate of 0.01, using the
# training-time transformations.
moe.train_gate(
    transform=transformations_training,
    num_epochs=10,
    learning_rate=0.01,
)

Training of MoE_superclasses_gate
Training on device: cuda:0
Training on 40,000 samples
Validation on 10,000 samples
Trainable parameters: 1,552,920

Epoch 1/10
----------
training Loss: 3.0372  Top1 Accuracy: 0.4361  Top5 Accuracy: 0.7044
validation Loss: 2.8239  Top1 Accuracy: 0.4818  Top5 Accuracy: 0.7509

Epoch 2/10
----------
training Loss: 2.7395  Top1 Accuracy: 0.5041  Top5 Accuracy: 0.7737
validation Loss: 2.6763  Top1 Accuracy: 0.5152  Top5 Accuracy: 0.7881

Epoch 3/10
----------
training Loss: 2.6151  Top1 Accuracy: 0.5318  Top5 Accuracy: 0.7964
validation Loss: 2.6194  Top1 Accuracy: 0.5297  Top5 Accuracy: 0.8004

Epoch 4/10
----------
training Loss: 2.5311  Top1 Accuracy: 0.5501  Top5 Accuracy: 0.8073
validation Loss: 2.5398  Top1 Accuracy: 0.5488  Top5 Accuracy: 0.8067

Epoch 5/10
----------
training Loss: 2.4818  Top1 Accuracy: 0.5615  Top5 Accuracy: 0.8133
validation Loss: 2.5732  Top1 Accuracy: 0.5475  Top5 Accuracy: 0.8089

Epoch 6/10
----------
training Loss: 2.4264  

In [7]:
# Second gate-training run: another 10 epochs with a 10x smaller learning
# rate (0.001).
# NOTE(review): presumably this continues from the gate weights left by the
# previous run rather than re-initializing them; confirm in MoE.train_gate.
moe.train_gate(
    transform=transformations_training,
    num_epochs=10,
    learning_rate=0.001,
)

Training of MoE_superclasses_gate
Training on device: cuda:0
Training on 40,000 samples
Validation on 10,000 samples
Trainable parameters: 1,552,920

Epoch 1/10
----------
training Loss: 2.1816  Top1 Accuracy: 0.6232  Top5 Accuracy: 0.8368
validation Loss: 2.3970  Top1 Accuracy: 0.5883  Top5 Accuracy: 0.8335

Epoch 2/10
----------
training Loss: 2.1369  Top1 Accuracy: 0.6326  Top5 Accuracy: 0.8390
validation Loss: 2.3668  Top1 Accuracy: 0.5904  Top5 Accuracy: 0.8361

Epoch 3/10
----------
training Loss: 2.1269  Top1 Accuracy: 0.6356  Top5 Accuracy: 0.8431
validation Loss: 2.3809  Top1 Accuracy: 0.5882  Top5 Accuracy: 0.8338

Epoch 4/10
----------
training Loss: 2.1010  Top1 Accuracy: 0.6395  Top5 Accuracy: 0.8429
validation Loss: 2.3809  Top1 Accuracy: 0.5831  Top5 Accuracy: 0.8338

Epoch 5/10
----------
training Loss: 2.1025  Top1 Accuracy: 0.6413  Top5 Accuracy: 0.8421
validation Loss: 2.3505  Top1 Accuracy: 0.5930  Top5 Accuracy: 0.8347

Epoch 6/10
----------
training Loss: 2.0789  