In [1]:
import wandb
import torch
import numpy as np
import math
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
import torch.nn as nn

from models.resnet.resnet18 import ResNet18
from datasets.cifar100_dataset import CIFAR100Dataset
from utils.dataset_utils import train_test_split, get_transformation
from utils.cifar100_utils import CIFAR100_LABELS, get_superclass, CIFAR100_DECODING

from models.moe_layer.resnet18.resnet18_moe import ResNet18MoE
from models.moe_layer.soft_gating_networks import SimpleGate
from models.moe_layer.resnet.moe_block_layer import ResidualMoeBlockLayer, MoeBlockLayer
from models.moe_layer.resnet18.resnet18_experts import NarrowResNet18Expert
from torchvision import transforms


In [2]:
num_experts = 4
position = 4
loss = 'importance'

gate = SimpleGate(
    in_channels=256, 
    num_experts=num_experts,
    top_k=2,
    use_noise=True,
    name='SimpleGate',
    loss_fkt='importance',
    w_aux_loss=0.5
    )
        
moe_layer = ResidualMoeBlockLayer(
    num_experts=num_experts, 
    layer_position=position, 
    top_k=2,
    gating_network=gate,
    resnet_expert=NarrowResNet18Expert)

model = ResNet18MoE(
    moe_layers=[moe_layer],
)

# file_model = wandb.restore('Residual_4_topK=2_loss=importance_w_aux=0.5_moePosition=1_1_final.tar', run_path='lukas-struppek/final_resnet_18/16aj0dnf')
# model.load_state_dict(torch.load(file_model.name)['model_state_dict'])
file_model = wandb.restore('Residual_4_topK=2_loss=importance_w_aux=0.5_moePosition=4_2_final.tar', run_path='lukas-struppek/final_resnet_18/1ezu238w')
model.load_state_dict(torch.load(file_model.name)['model_state_dict'])
model = model.to(model.device)

In [3]:
transformations_test = get_transformation('cifar100', phase='test')
no_transform = get_transformation('no_transform')
cifar_test = CIFAR100Dataset(root_dir='/home/lb4653/mixture-of-experts-thesis/data/cifar100/testing', transform=transformations_test)
cifar_test_no_transform = CIFAR100Dataset(root_dir='/home/lb4653/mixture-of-experts-thesis/data/cifar100/testing', transform=no_transform)
dataloader = torch.utils.data.DataLoader(cifar_test, batch_size=256)

In [4]:
class Hook:
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.inputs = []
        self.outputs = []
        self.weights = []
        self.selected_experts = []
        
    def hook_fn(self, module, input, output):
        for pred in output:
            self.outputs.append(pred.detach().cpu().numpy())
            top_k_logits, top_k_indices = pred.topk(2)
            top_k_weights = nn.functional.softmax(top_k_logits)
            weights = torch.zeros(num_experts, requires_grad=False)
            weights[top_k_indices[0]] = top_k_weights[0]
            weights[top_k_indices[1]] = top_k_weights[1]
            self.weights.append(weights.detach().cpu().numpy())
            experts = [top_k_indices[0].cpu().item(), top_k_indices[1].cpu().item()]
            self.selected_experts.append(experts)
        
    def close(self):
        self.hook.remove()


In [5]:
hook = Hook(model.layers[position-1].gate.fc)

In [6]:
labels = []
superclass_labels = []
for image, label in dataloader:
    for l in label:
        labels.append(CIFAR100_DECODING[l.cpu().item()])
        superclass_labels.append(get_superclass(CIFAR100_DECODING[l.cpu().item()]))
    model(image.to('cuda:0'))

In [7]:
tsne = TSNE(n_components=2, perplexity=15, learning_rate=5, verbose=1, n_iter=1500).fit_transform(hook.outputs)

[t-SNE] Computing 46 nearest neighbors...
[t-SNE] Indexed 10000 samples in 0.004s...
[t-SNE] Computed neighbors for 10000 samples in 0.117s...
[t-SNE] Computed conditional probabilities for sample 1000 / 10000
[t-SNE] Computed conditional probabilities for sample 2000 / 10000
[t-SNE] Computed conditional probabilities for sample 3000 / 10000
[t-SNE] Computed conditional probabilities for sample 4000 / 10000
[t-SNE] Computed conditional probabilities for sample 5000 / 10000
[t-SNE] Computed conditional probabilities for sample 6000 / 10000
[t-SNE] Computed conditional probabilities for sample 7000 / 10000
[t-SNE] Computed conditional probabilities for sample 8000 / 10000
[t-SNE] Computed conditional probabilities for sample 9000 / 10000
[t-SNE] Computed conditional probabilities for sample 10000 / 10000
[t-SNE] Mean sigma: 0.200277
[t-SNE] KL divergence after 250 iterations with early exaggeration: 95.310150
[t-SNE] KL divergence after 1500 iterations: 2.135578


In [8]:
tx, ty = tsne[:,0], tsne[:,1]
tx = (tx-np.min(tx)) / (np.max(tx) - np.min(tx))
ty = (ty-np.min(ty)) / (np.max(ty) - np.min(ty))

df_subset = dict()
df_subset['tsne-2d-one'] = tx
df_subset['tsne-2d-two'] = ty

In [9]:
width = 3500
height = 3500
max_dim = 100
full_image = Image.new('RGB', (width, height))
for idx, x in enumerate(cifar_test_no_transform):
    tile = transforms.ToPILImage()(x[0].squeeze_(0))

    rs = max(1, tile.width / max_dim, tile.height / max_dim)
    tile = tile.resize((int(tile.width / rs),
                        int(tile.height / rs)),
                    Image.ANTIALIAS)
    full_image.paste(tile, (int((width-max_dim) * tx[idx]),
                            int((height-max_dim) * ty[idx])))
full_image.save('./images/{}_pos{}_exp{}_logits.png'.format(loss, position, num_experts))

In [10]:
for exp in range(4):
    width = 3500
    height = 3500
    max_dim = 100
    full_image = Image.new('RGB', (width, height))
    for idx, x in enumerate(cifar_test_no_transform):
        if exp in hook.selected_experts[idx]:
            tile = transforms.ToPILImage()(x[0].squeeze_(0))

            rs = max(1, tile.width / max_dim, tile.height / max_dim)
            tile = tile.resize((int(tile.width / rs),
                                int(tile.height / rs)),
                            Image.ANTIALIAS)
            full_image.paste(tile, (int((width-max_dim) * tx[idx]),
                                    int((height-max_dim) * ty[idx])))
    full_image.save('./images/{}_pos{}_exp{}_logits_expert{}.png'.format(loss, position, num_experts, exp))

In [11]:
tsne = TSNE(n_components=2, perplexity=30, learning_rate=5, verbose=1, n_iter=1500).fit_transform(hook.weights)

tx, ty = tsne[:,0], tsne[:,1]
tx = (tx-np.min(tx)) / (np.max(tx) - np.min(tx))
ty = (ty-np.min(ty)) / (np.max(ty) - np.min(ty))

df_subset = dict()
df_subset['tsne-2d-one'] = tx
df_subset['tsne-2d-two'] = ty

[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 10000 samples in 0.020s...
[t-SNE] Computed neighbors for 10000 samples in 0.098s...
[t-SNE] Computed conditional probabilities for sample 1000 / 10000
[t-SNE] Computed conditional probabilities for sample 2000 / 10000
[t-SNE] Computed conditional probabilities for sample 3000 / 10000
[t-SNE] Computed conditional probabilities for sample 4000 / 10000
[t-SNE] Computed conditional probabilities for sample 5000 / 10000
[t-SNE] Computed conditional probabilities for sample 6000 / 10000
[t-SNE] Computed conditional probabilities for sample 7000 / 10000
[t-SNE] Computed conditional probabilities for sample 8000 / 10000
[t-SNE] Computed conditional probabilities for sample 9000 / 10000
[t-SNE] Computed conditional probabilities for sample 10000 / 10000
[t-SNE] Mean sigma: 0.006425
[t-SNE] KL divergence after 250 iterations with early exaggeration: 80.477547
[t-SNE] KL divergence after 1500 iterations: 0.910992


In [12]:
width = 3500
height = 3500
max_dim = 100
full_image = Image.new('RGB', (width, height))
for idx, x in enumerate(cifar_test_no_transform):
    tile = transforms.ToPILImage()(x[0].squeeze_(0))

    rs = max(1, tile.width / max_dim, tile.height / max_dim)
    tile = tile.resize((int(tile.width / rs),
                        int(tile.height / rs)),
                       Image.ANTIALIAS)
    full_image.paste(tile, (int((width-max_dim) * tx[idx]),
                            int((height-max_dim) * ty[idx])))
full_image.save('./images/{}_pos{}_exp{}_weights.png'.format(loss, position, num_experts))

In [13]:
for exp in range(4):
    width = 3500
    height = 3500
    max_dim = 100
    full_image = Image.new('RGB', (width, height))
    for idx, x in enumerate(cifar_test_no_transform):
        if exp in hook.selected_experts[idx]:
            tile = transforms.ToPILImage()(x[0].squeeze_(0))

            rs = max(1, tile.width / max_dim, tile.height / max_dim)
            tile = tile.resize((int(tile.width / rs),
                                int(tile.height / rs)),
                            Image.ANTIALIAS)
            full_image.paste(tile, (int((width-max_dim) * tx[idx]),
                                    int((height-max_dim) * ty[idx])))
    full_image.save('./images/{}_pos{}_exp{}_weights_expert{}.png'.format(loss, position, num_experts, exp))

In [None]:
for label in ['skyscraper', 'bus', 'lobster', 'man', 'leopard']:
    width = 3500
    height = 3500
    max_dim = 100
    full_image = Image.new('RGB', (width, height))
    for idx, x in enumerate(cifar_test_no_transform):
        if CIFAR100_DECODING[x[1]] == label:
            tile = transforms.ToPILImage()(x[0].squeeze_(0))

            rs = max(1, tile.width / max_dim, tile.height / max_dim)
            tile = tile.resize((int(tile.width / rs),
                                int(tile.height / rs)),
                            Image.ANTIALIAS)
            full_image.paste(tile, (int((width-max_dim) * tx[idx]),
                                    int((height-max_dim) * ty[idx])))

    full_image.save('./images/{}_pos{}_exp{}_weights_label_{}.png'.format(loss, position, num_experts, label))

In [14]:
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
results = pca.fit_transform(hook.outputs)

In [15]:
tx, ty = results[:,0], results[:,1]
tx = (tx-np.min(tx)) / (np.max(tx) - np.min(tx))
ty = (ty-np.min(ty)) / (np.max(ty) - np.min(ty))

df_subset = dict()
df_subset['tsne-2d-one'] = tx
df_subset['tsne-2d-two'] = ty

In [16]:
width = 3500
height = 3500
max_dim = 100
full_image = Image.new('RGB', (width, height))
for idx, x in enumerate(cifar_test_no_transform):
    tile = transforms.ToPILImage()(x[0].squeeze_(0))

    rs = max(1, tile.width / max_dim, tile.height / max_dim)
    tile = tile.resize((int(tile.width / rs),
                        int(tile.height / rs)),
                       Image.ANTIALIAS)
    full_image.paste(tile, (int((width-max_dim) * tx[idx]),
                            int((height-max_dim) * ty[idx])))

full_image.save('./images/{}_pos{}_exp{}_pca.png'.format(loss, position, num_experts))

In [18]:
for exp in range(4):
    width = 3500
    height = 3500
    max_dim = 100
    full_image = Image.new('RGB', (width, height))
    for idx, x in enumerate(cifar_test_no_transform):
        if exp in hook.selected_experts[idx]:
            tile = transforms.ToPILImage()(x[0].squeeze_(0))

            rs = max(1, tile.width / max_dim, tile.height / max_dim)
            tile = tile.resize((int(tile.width / rs),
                                int(tile.height / rs)),
                            Image.ANTIALIAS)
            full_image.paste(tile, (int((width-max_dim) * tx[idx]),
                                    int((height-max_dim) * ty[idx])))

    full_image.save('./images/{}_pos{}_exp{}_pca_expert{}.png'.format(loss, position, num_experts, exp))