In [7]:
import wandb
import torch
import numpy as np
import math
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
import torch.nn as nn
import pandas as pd

from models.resnet.resnet18 import ResNet18
from datasets.cifar100_dataset import CIFAR100Dataset
from utils.dataset_utils import train_test_split, get_transformation
from utils.cifar100_utils import CIFAR100_LABELS, get_superclass, CIFAR100_DECODING

from models.moe_layer.resnet18.resnet18_moe import ResNet18MoE
from models.moe_layer.soft_gating_networks import SimpleGate
from models.moe_layer.resnet.moe_block_layer import ResidualMoeBlockLayer, MoeBlockLayer
from models.moe_layer.resnet18.resnet18_experts import NarrowResNet18Expert
from torchvision import transforms


In [2]:
num_experts = 4
position = 4
# NOTE(review): `loss` is not referenced by any cell visible here, and it matches
# neither the gate's loss_fkt='importance' nor the `loss=kl_divergence` in the
# checkpoint name below — confirm which auxiliary loss the run actually used.
loss = 'mean'
top_k = 2  # shared by the gate and the MoE layer; must match `topK=2` in the checkpoint name
checkpoint_file = 'Residual_4_topK=2_loss=kl_divergence_w_aux=0.5_moePosition=4_0_final.tar'
run_path = 'lukas-struppek/final_resnet_18/2wb6k7ji'

gate = SimpleGate(
    in_channels=256,
    num_experts=num_experts,
    top_k=top_k,
    use_noise=True,
    name='SimpleGate',
    loss_fkt='importance',
    w_aux_loss=0.5,
)

moe_layer = ResidualMoeBlockLayer(
    num_experts=num_experts,
    layer_position=position,
    top_k=top_k,
    gating_network=gate,
    resnet_expert=NarrowResNet18Expert,
)

model = ResNet18MoE(
    moe_layers=[moe_layer],
)

# Fetch the trained checkpoint file from the W&B run.
file_model = wandb.restore(checkpoint_file, run_path=run_path)
# map_location='cpu' lets a checkpoint saved from GPU tensors load on any machine;
# load_state_dict copies the weights into the model, which is then moved to its device.
checkpoint = torch.load(file_model.name, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(model.device)

In [3]:
# Full CIFAR-100 test split with test-phase preprocessing, plus a batched loader over it.
cifar100_test_root = '/home/lb4653/mixture-of-experts-thesis/data/cifar100/testing'
transformations_test = get_transformation('cifar100', phase='test')
cifar_test = CIFAR100Dataset(
    root_dir=cifar100_test_root,
    transform=transformations_test,
)
dataloader = torch.utils.data.DataLoader(dataset=cifar_test, batch_size=256)

In [4]:
def varying_k_accuracy(model, moe_block, k_values=range(1, 5),
                       root_dir='/home/lb4653/mixture-of-experts-thesis/data/cifar100/testing'):
    """Evaluate per-label and overall CIFAR-100 test accuracy for several top_k values.

    For each k in `k_values`, sets `moe_block.gate.top_k = k` and runs
    `model.evaluate` once per CIFAR-100 label and once on the full test split.
    The gate's original `top_k` is restored afterwards, even if evaluation fails.

    Args:
        model: Object exposing `evaluate(dataset)` that returns a dict with an 'acc' entry.
        moe_block: Layer whose `gate.top_k` attribute is swept.
        k_values: Iterable of top_k values to evaluate (default: 1, 2, 3, 4).
        root_dir: Directory of the CIFAR-100 test split.

    Returns:
        dict mapping each k to {label: acc, ..., 'total': acc}.
    """
    transformations_test = get_transformation('cifar100', phase='test')
    # The datasets do not depend on k, so build them once rather than once per k.
    label_datasets = {
        label: CIFAR100Dataset(root_dir=root_dir, transform=transformations_test, labels=[label])
        for label in CIFAR100_LABELS
    }
    full_dataset = CIFAR100Dataset(root_dir=root_dir, transform=transformations_test)

    original_top_k = moe_block.gate.top_k
    total_results = dict()
    try:
        for k in k_values:
            moe_block.gate.top_k = k
            eval_results = {
                label: model.evaluate(dataset)['acc']
                for label, dataset in label_datasets.items()
            }
            eval_results['total'] = model.evaluate(full_dataset)['acc']
            total_results[k] = eval_results
    finally:
        # Don't leave the model in a modified routing configuration for later cells.
        moe_block.gate.top_k = original_top_k
    return total_results

In [5]:
# Sweep top_k on the MoE block. Assumes the MoE block is the last entry of
# `model.layers` — TODO confirm against the ResNet18MoE layer ordering.
moe_block = model.layers[-1]
total_results = varying_k_accuracy(model, moe_block)

---------------------------
Evaluation of  ResNet18MoE
Evaluation on 100 samples
Evaluation complete in  00:00:00
Evaluation Accuracy: 0.5500
------------------------------------ Finished Evaluation ------------------------------------

------------------------------------ Beginning Evaluation ------------------------------------
Evaluation of  ResNet18MoE
Evaluation on 100 samples
Evaluation complete in  00:00:00
Evaluation Accuracy: 0.8200
------------------------------------ Finished Evaluation ------------------------------------

------------------------------------ Beginning Evaluation ------------------------------------
Evaluation of  ResNet18MoE
Evaluation on 100 samples
Evaluation complete in  00:00:00
Evaluation Accuracy: 0.8200
------------------------------------ Finished Evaluation ------------------------------------

------------------------------------ Beginning Evaluation ------------------------------------
Evaluation of  ResNet18MoE
Evaluation on 100 samples
Evaluat

In [18]:
# Outer keys of total_results (the top_k values) become columns; the inner keys
# (CIFAR-100 labels plus 'total') become the row index.
df = pd.DataFrame.from_dict(total_results, orient='columns')

In [20]:
# Accuracy table: one row per CIFAR-100 label plus 'total', one column per top_k value.
df

Unnamed: 0,1,2,3,4
apple,0.8500,0.8800,0.8800,0.8800
aquarium_fish,0.8200,0.8700,0.8800,0.8800
baby,0.5500,0.5800,0.5800,0.5800
bear,0.4900,0.5200,0.5300,0.5400
beaver,0.5600,0.6200,0.6300,0.6300
...,...,...,...,...
willow_tree,0.6800,0.7000,0.7100,0.7100
wolf,0.7000,0.7100,0.7200,0.7300
woman,0.4200,0.4500,0.4600,0.4600
worm,0.7100,0.7200,0.7200,0.7200


In [23]:
# NOTE(review): `df` was built from this same `total_results`, so this sum just
# doubles every accuracy (e.g. apple 0.85 -> 1.70 in the output). Presumably
# meant to accumulate results from another run for averaging — confirm or delete.
df.add(pd.DataFrame(total_results))

Unnamed: 0,1,2,3,4
apple,1.7000,1.760,1.7600,1.7600
aquarium_fish,1.6400,1.740,1.7600,1.7600
baby,1.1000,1.160,1.1600,1.1600
bear,0.9800,1.040,1.0600,1.0800
beaver,1.1200,1.240,1.2600,1.2600
...,...,...,...,...
willow_tree,1.3600,1.400,1.4200,1.4200
wolf,1.4000,1.420,1.4400,1.4600
woman,0.8400,0.900,0.9200,0.9200
worm,1.4200,1.440,1.4400,1.4400
