In [None]:
import os
import sys

# Ensure no W&B logging will be performed
# The command line is faked before the project's main module is imported.
# NOTE(review): "-log tb" appears twice in the string; it is kept as-is because
# the string is runtime input to the project's argument parser.
sys.argv = "main.py -log tb -name tst -reset 1 -lm.eval.enable 0 -log tb -batch_size 20 -restore paper/moe_universal/checkpoints/0gbyzhhc/model.ckpt".split(" ")
# sys.argv = "main.py -log tb -name tst -reset 1 -lm.eval.enable 0 -log tb -batch_size 20 -restore paper/moe_universal/checkpoints/plvywltl/model.ckpt".split(" ")

# Pretend we are in the main directory.
# The directory the kernel started in is remembered on the first run, so that
# re-running this cell always lands in the same place instead of climbing two
# more levels each time (which would break the relative -restore and savefig paths).
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, "../../"))

In [None]:
from main import initialize
import torch
import torch.nn.functional as F
from layers.moe_layer import MoE

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib.ticker import ScalarFormatter

# Figure styling: TeX-rendered text, sans-serif everywhere, 200 dpi.
plt.rcParams.update({
    'text.usetex': True,  # let TeX do the typesetting
    'text.latex.preamble': '\\usepackage{sansmath}\n\\sansmath',  # force sans-serif math mode (for axes labels)
    'font.family': 'sans-serif',  # ... and for regular text
    'font.sans-serif': 'Helvetica, Avant Garde, Computer Modern Sans serif',  # preferred fonts, in order
    # Same resolution for inline display and for saved files.
    'figure.dpi': 200,
    'savefig.dpi': 200,
})

In [None]:
# Build the experiment helper and the task through the project's main.initialize().
# NOTE(review): presumably this consumes the sys.argv override set at the top of
# the notebook (including the -restore checkpoint) -- confirm against main.py.
helper, task = initialize()
task.create_data_fetcher()

# Keep a handle on the task's original (bound) validation step so that the
# wrapper installed later on task.run_model_validation can delegate to it.
orig_run_model_valid = task.run_model_validation

In [None]:
# Settings of the restored run, read once from the parsed arguments.
_cfg = task.helper.args
nexp = _cfg.moe.n_experts
ntok = _cfg.sentencepiece.n_pieces
ngrp = _cfg.transformer.universal.group_size
nlayers = _cfg.transformer.encoder_n_layers
context = _cfg.lm.unroll
k = _cfg.pkm.n_heads
bsz = _cfg.batch_size

# Number of entries along the per-group layer axis.
_n_rep = nlayers // ngrp

# Accumulators updated by the validation wrapper.
token_counts = 0   # running per-token occurrence counts
cnt = 0            # number of positions summed into simmap so far
simmap = torch.zeros(ngrp, _n_rep, _n_rep)

# Expert indices written by the patched MoE.topk during the current batch.
thissel = torch.zeros(ngrp, _n_rep, bsz, context, k, dtype=torch.long)



In [None]:
# `this_data` is created by run_model_validation below; a module-level `global`
# statement has no effect, so none is needed here.

def run_model_validation(self, data):
    """Validation-step wrapper that also accumulates expert-overlap statistics.

    Delegates to the original validation step (``orig_run_model_valid``) and then
    converts the expert indices stored in ``thissel`` during that forward pass
    into an intersection-over-union (Jaccard) similarity between every pair of
    entries on the layer axis, summed over all positions of the batch.

    Side effects on notebook globals:
        token_counts: incremented by the per-token occurrence counts of the batch.
        this_data:    set to the batch that was just processed.
        thissel:      zeroed before the forward pass (and then filled by the hooks).
        simmap:       incremented by the per-batch similarity sums, shape
                      (ngrp, layers, layers).
        cnt:          incremented by the number of positions in the batch.

    Args:
        self: the object this function is bound to; unused, the original bound
            method is called instead.
        data: batch mapping; ``data["data"]`` holds the integer token ids.

    Returns:
        Whatever the original validation step returns, unchanged.
    """
    global token_counts
    global this_data
    global cnt
    global simmap

    # Same counts as F.one_hot(tokens, ntok).sum(0), but without materialising a
    # (tokens x vocabulary) int64 matrix for every batch.
    token_counts = token_counts + torch.bincount(data["data"].flatten().long(), minlength=ntok)

    this_data = data

    thissel.zero_()

    res = orig_run_model_valid(data)

    # thissel: (ngrp, layers, bsz, context, k) -> (ngrp, bsz*context, layers, k)
    sel = thissel.flatten(2, 3).permute(0, 2, 1, 3)

    # Binary expert-usage mask, shape: ngrp, bsz*context, nlayer, nexp.
    # Scattering ones is equivalent to (one_hot(sel).sum(-2) > 0).float() but
    # skips the (..., k, nexp) int64 intermediate.
    ohsel = torch.zeros(*sel.shape[:-1], nexp, dtype=torch.float32).scatter_(-1, sel, 1.0)

    # Intersection size for every pair of layers.
    overlap = torch.einsum("nglk,ngok->nglo", ohsel, ohsel)

    # Union size by inclusion-exclusion: |A| + |B| - |A and B|. For 0/1 masks this
    # equals maximum(A, B).sum(-1) exactly, without the (layers x layers x nexp)
    # tensor per position.
    sizes = ohsel.sum(-1)
    norm = sizes.unsqueeze(-1) + sizes.unsqueeze(-2) - overlap

    simcnt = (overlap / norm).sum(1)
    cnt += ohsel.shape[1]
    simmap += simcnt

    return res

# Install the wrapper on this task instance only (bound via the descriptor protocol).
task.run_model_validation = run_model_validation.__get__(task)

In [None]:
# id(module) -> running index assigned in patching order; used as the first
# index into `thissel`.
id_map = {}


def patch_module(module):
    """Wrap ``module.topk`` so that the selected expert indices are recorded.

    The first call for a given module registers it in ``id_map`` and installs an
    instance-level ``topk`` that forwards to ``MoE.topk`` and stores the returned
    indices in ``thissel[index, module.layer // ngrp]``. Further calls for the
    same module do nothing.

    Args:
        module: the layer instance to hook.

    Returns:
        None.
    """
    key = id(module)
    if key not in id_map:
        id_map[key] = len(id_map)

        def new_topk(self, *args, **kwargs):
            # Resolve the slot first, then run the real top-k selection.
            row = id_map[id(self)]
            values, indices = MoE.topk(self, *args, **kwargs)
            thissel[row, self.layer // ngrp] = indices
            return values, indices

        # Bind to this instance only; the MoE class itself is left untouched.
        module.topk = new_topk.__get__(module)


In [None]:
# Hook every MoE layer found in the model.
for m in (mod for mod in task.model.modules() if isinstance(mod, MoE)):
    patch_module(m)

In [None]:
# Run the task's validation pass. Presumably this calls the wrapped
# task.run_model_validation for every batch, filling token_counts, simmap and cnt.
# NOTE(review): those accumulators are not reset here, so re-running this cell
# adds a second pass on top of the first.
task.validate()

In [None]:
# Mean per-position similarity: accumulated sums divided by the number of
# positions seen. Guard against running this cell before any validation batch
# was accumulated, which would otherwise silently produce NaN/inf values.
if cnt == 0:
    raise RuntimeError("No validation positions were accumulated; run task.validate() first.")
avg = simmap / cnt

In [None]:
# Index of the hooked module (first axis of `avg`) whose similarity map is drawn.
gid = 0

# Number of rows/columns of the heatmap.
n_rep = nlayers // ngrp

# One tick on every second row/column, never past the last one (a tick outside
# the image would make matplotlib expand the axes).
tick_pos = list(range(0, n_rep, 2))
# 1-based layer number shown for heatmap index `pos`.
# NOTE(review): assumes index `pos` of module `gid` corresponds to absolute layer
# pos * ngrp + gid; this reproduces the previous hard-coded labels (1, 5, 9, ...)
# for ngrp == 2, gid == 0 -- confirm for other configurations.
tick_lab = [str(pos * ngrp + gid + 1) for pos in tick_pos]

fig, ax = plt.subplots(figsize=(2.4, 2))
im = ax.imshow(avg[gid], cmap='viridis', aspect='auto')
fig.colorbar(im, ax=ax)
ax.set_xticks(tick_pos)
ax.set_xticklabels(tick_lab)
ax.set_yticks(tick_pos)
ax.set_yticklabels(tick_lab)
ax.set_xlabel("Layer")
ax.set_ylabel("Layer")
fig.tight_layout()
fig.savefig("paper/moe_universal/layer_similarity.pdf", bbox_inches='tight')