In [None]:
import os
import sys

# Fake the command line that main.py parses on import/initialisation.
# Ensure no W&B logging will be performed (the logger is selected with "-log tb").
sys.argv = [
    "main.py",
    "-log", "tb",
    "-name", "tst",
    "-reset", "1",
    "-lm.eval.enable", "0",
    "-log", "tb",
    "-batch_size", "20",
    "-restore", "paper/moe_universal/checkpoints/0gbyzhhc/model.ckpt",
]

# Pretend we are in the main directory, so project-relative imports and paths resolve.
os.chdir("../../")

In [None]:
from main import initialize
import torch
import torch.nn.functional as F
from layers.moe_layer import MoE

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt

# Text rendering: everything goes through TeX, using sans-serif for both math and regular text.
plt.rcParams.update({
    'text.usetex': True,  # Let TeX do the typesetting
    'text.latex.preamble': '\\usepackage{sansmath}\n\\sansmath',  # Force sans-serif math mode (for axes labels)
    'font.family': 'sans-serif',  # ... for regular text
    'font.sans-serif': 'Helvetica, Avant Garde, Computer Modern Sans serif',  # Choose a nice font here
})

# Use the same resolution for inline display and for saved figures.
plt.rcParams.update({
    'figure.dpi': 200,
    'savefig.dpi': 200,
})

In [None]:
# Build the experiment through main.initialize(), which returns the helper
# object and the task. It presumably reads the faked sys.argv set in the first
# cell -- TODO confirm against main.py.
helper, task = initialize()
# Presumably prepares the data loading needed by task.validate() -- verify against the task class.
task.create_data_fetcher()

# Keep a reference to the task's original (bound) validation step so that the
# wrapper installed in a later cell can delegate to it.
orig_run_model_valid = task.run_model_validation

In [None]:
# Sizes read from the parsed experiment arguments.
nexp = task.helper.args.moe.n_experts  # number of experts; third axis of `counts`
ntok = task.helper.args.sentencepiece.n_pieces  # number of token ids; last axis of `counts`
ngrp = task.helper.args.transformer.universal.group_size  # first axis of `counts`
nlayers = task.helper.args.transformer.encoder_n_layers  # total layer count; nlayers // ngrp is the second axis

# Running per-token occurrence count. Starts as the int 0 and becomes a tensor
# once the validation wrapper adds the first batch's counts to it.
token_counts = 0

# Routing statistics, indexed as
#   counts[index of the patched MoE module, self.layer // ngrp, expert, token]
# and incremented by the patched topk every time an expert is selected for a token.
counts = torch.zeros(ngrp, nlayers // ngrp, nexp, ntok)

In [None]:
# Most recent validation batch. It is read by the patched MoE.topk to recover
# the token ids that belong to the forward pass currently being executed.
# (A bare module-level `global` statement would not create the name.)
this_data = None

def run_model_validation(self, data):
    """Validation-step wrapper that records token statistics before delegating.

    Adds the number of occurrences of every token id in ``data["data"]`` to the
    module-level ``token_counts``, remembers the batch in ``this_data`` and then
    calls the original validation step.

    Args:
        self: the task this function is bound to (unused here).
        data: batch dictionary; ``data["data"]`` must hold integer token ids,
            which are expected to lie in ``[0, ntok)``.

    Returns:
        Whatever the original ``run_model_validation`` returns for ``data``.
    """
    global token_counts
    global this_data

    token_ids = data["data"].flatten().long()
    # bincount yields the same per-token totals as one_hot(...).sum(0), but
    # without materialising a (tokens_in_batch, ntok) intermediate tensor.
    token_counts = token_counts + torch.bincount(token_ids, minlength=ntok)

    this_data = data
    return orig_run_model_valid(data)

# Bind the wrapper to `task` through the descriptor protocol so it receives
# `task` as `self`, and install it in place of the task's own validation step.
task.run_model_validation = run_model_validation.__get__(task)

In [None]:
# Maps id(module) -> index of that module in patching order. The index is used
# as the first coordinate into `counts`.
id_map = {}

def patch_module(module):
    """Instrument one MoE module so that its expert selections are tallied per token.

    The module's ``topk`` is replaced by a wrapper that calls the original
    ``MoE.topk`` and, for every (token, selected expert) pair of the current
    batch, adds 1 to ``counts[gid, self.layer // ngrp, expert, token]``, where
    ``gid`` is the index assigned to this module in ``id_map``.

    Calling this function again on an already patched module is a no-op.

    Args:
        module: the MoE instance to instrument. Its ``layer`` attribute is read
            on every ``topk`` call.
    """
    module_key = id(module)
    if module_key in id_map:
        # Already instrumented; patching twice would double count.
        return

    gid = len(id_map)
    id_map[module_key] = gid

    def new_topk(self, *args, **kwargs):
        # Token ids of the batch currently being validated: the last position
        # along the first axis is dropped and the result transposed, so that
        # its shape lines up with the leading dims of sel_index (checked below).
        data = this_data["data"][:-1].T

        sel_val, sel_index = MoE.topk(self, *args, **kwargs)

        assert data.shape == sel_index.shape[:-1]

        data = data.reshape(-1)

        # Shape of counts[gid][layer]: nexp, ntok
        # Linear index: expert * ntok + tok
        seli = sel_index.flatten(end_dim=-2) * ntok
        addi = seli + data[..., None]
        addi = addi.flatten().cpu()

        # view(-1) is guaranteed to alias the storage of `counts` (it raises
        # instead of silently copying), so the in-place add always lands in
        # the global accumulator.
        layer_counts = counts[gid][self.layer // ngrp].view(-1)
        layer_counts.index_add_(0, addi, torch.ones_like(addi, dtype=torch.float32))

        return sel_val, sel_index

    # Bind as an instance attribute so only this module is affected.
    module.topk = new_topk.__get__(module)


In [None]:
# Instrument every MoE layer found anywhere in the model's module tree.
for m in filter(lambda mod: isinstance(mod, MoE), task.model.modules()):
    patch_module(m)

In [None]:
# Run validation. Presumably this goes through the wrapped run_model_validation
# and the patched MoE.topk, which is what fills `token_counts` and `counts`
# -- verify against the task implementation.
task.validate()

In [None]:
# Token ids ranked from the most to the least frequent one.
order = torch.argsort(token_counts, dim=-1, descending=True).cpu()
# Re-index both statistics so that position i along the token axis refers to
# the i-th most frequent token.
token_counts_o = token_counts.cpu().index_select(0, order)
counts_o = counts[..., order]

In [None]:
# Which part of the statistics to visualise: a window of `count` tokens starting
# at frequency rank `ostart`, for module index `gid` and repetition `layer`.
ostart = 3000
count = 100
gid = 1
layer = 1  # set to None to sum over the whole second axis of counts_o

# Labels for the selected token window (presumably the token strings -- verify
# against the vocabulary class). Only needed for the optional y tick labels below.
labels = task.train_set.vocabulary(order[ostart:ostart+count].tolist())

fig, ax = plt.subplots(figsize=(4, 2))
if layer is None:
    plot_slice = counts_o[gid, :, :, ostart:ostart+count].sum(0)
else:
    plot_slice = counts_o[gid, layer, :, ostart:ostart+count]

# Per-token distribution over experts. A token that was never routed in this
# slice would give 0/0 = NaN, which then poisons the quantile and the sort
# below, so such entries are mapped to 0.
plot_slice = (plot_slice / plot_slice.sum(0, keepdim=True)).nan_to_num(nan=0.0)

# Rows: tokens of the window, columns: experts.
plot_slice = plot_slice.T

print("Plot slice shape", plot_slice.shape)

# Order the experts by the (weighted) mean row position of their strongest
# tokens: keep, per expert, only entries at or above its 95% quantile,
# normalise them to weights and average the row indices.
tresh = torch.quantile(plot_slice, 0.95, dim=0, keepdim=True)
total_counts = plot_slice * (plot_slice >= tresh)
# An expert without any mass in the window would again give 0/0.
total_counts = (total_counts / total_counts.sum(0, keepdim=True)).nan_to_num(nan=0.0)
total_counts = total_counts * torch.arange(plot_slice.shape[0], dtype=torch.float)[..., None]
total_counts = total_counts.sum(0)
order3 = total_counts.argsort(descending=False)

plot_slice_o = plot_slice[:, order3]

im = ax.imshow(plot_slice_o.numpy(), aspect='auto', cmap='viridis', interpolation="none")
ax.set_xlabel("Expert")
ax.set_ylabel("Token")
# Optional, unreadable for large `count`: ax.set_yticks(range(count)); ax.set_yticklabels(labels)
fig.colorbar(im, ax=ax)

In [None]:
# Sanity check: one ordering score per expert is expected here.
total_counts.shape

In [None]:
# Sanity check: (ngrp, nlayers // ngrp, nexp, ntok), as allocated with torch.zeros.
counts.shape

In [None]:
# Inspect the frequency ranking of token ids (most frequent first).
order

In [None]:
def plot_group(gid):
    """Plot, for one group, how the usage of each expert is distributed over layers.

    Builds a (number of entries, number of experts) matrix from
    ``count_logs[order[gid]]`` with the entries taken in sorted key order,
    reorders the experts by a quantile-thresholded weighted row index,
    normalises each expert column to sum to one, draws it as a heatmap and
    saves it to ``paper/moe_universal/expert_layer_g{gid}.pdf``.

    NOTE(review): ``count_logs`` is not defined anywhere in the visible
    notebook, so this function raises NameError on a fresh kernel. The
    ``order`` it indexes with is, in the visible cells, a tensor of token ids
    produced by ``torch.argsort`` -- this function looks like a leftover from
    an earlier version of the analysis. Confirm before relying on it.

    Args:
        gid: position in ``order`` of the group to plot; also used in the
            output file name.

    Returns:
        The matplotlib figure that was created.
    """
    # NOTE(review): assumes count_logs[key] is a mapping {layer: per-expert count tensor}
    # and that key 1 exists in it -- TODO confirm.
    n_experts = count_logs[order[0]][1].shape[0]

    # Local `counts` shadows the module-level `counts` tensor inside this function.
    counts = torch.zeros(len(count_logs[order[gid]]), n_experts)
    order2 = list(sorted(count_logs[order[gid]].keys()))
    for j, o in enumerate(order2):
        counts[j] += count_logs[order[gid]][o].cpu()

    # Expert ordering: per expert, keep only the rows strictly above its 90%
    # quantile and sum them weighted by row index; experts are then sorted by
    # that score in ascending order.
    total_counts = counts.float() / counts.sum(0, keepdim=True)
    tresh = torch.quantile(total_counts, 0.9, dim=0)
    total_counts = total_counts * (total_counts > tresh)
    total_counts = total_counts * torch.arange(counts.shape[0], dtype=torch.float)[..., None]
    total_counts = total_counts.sum(0)
    order3 = total_counts.argsort(descending=False)

    counts2 = counts[:, order3]

    # NOTE(review): LogNorm is imported but never used below.
    from matplotlib.colors import LogNorm
    # Normalise every expert column so it sums to one over the rows.
    counts2 = counts2.float() / counts2.sum(0, keepdim=True)
    fig, ax=plt.subplots(figsize=(4, 2))
    plt.imshow(counts2.cpu().numpy(), aspect='auto', cmap='viridis', interpolation="none")
    plt.ylabel("Layer")
    plt.xlabel("Expert ID")
    # Rows are labelled with the sorted keys of the group's log.
    plt.yticks(range(len(counts2)), order2)
    plt.colorbar()
    plt.tight_layout()
    plt.savefig(f"paper/moe_universal/expert_layer_g{gid}.pdf", bbox_inches='tight', dpi=300)
    return fig
    # counts2 = counts

In [None]:
# Render and save the figure for the group at position 1.
# NOTE(review): depends on `count_logs`, which no visible cell defines.
plot_group(1)