In [21]:
import json
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcl
import math
import numpy as np
import pandas as pd
import re

from ast import literal_eval
from itertools import product
from matplotlib.transforms import ScaledTranslation
from matplotlib.ticker import NullFormatter
from matplotlib.colors import to_rgb
from os import makedirs
from os.path import isdir, isfile
from pathlib import Path
from string import ascii_lowercase
from time import time
from tqdm import tqdm
from constants import *
from UTILS.mutils import njoin, str2bool, str2ls, create_model_dir, convert_train_history
from UTILS.mutils import collect_model_dirs, find_subdirs, load_model_files
from UTILS.figure_utils import matrixify_axs, label_axs
from plot_results import *

# Plot palette. COLORS[i] is the colour used for alphas_all[i]; the trailing
# ``None`` key is the entry len_inference looks up for its 'DP' curve.
COLORS = [
    "#636363",
    "#469C76",
    "#2E63A6",
    "#C17DA5",
    "#C66526",
    "#EEE461",
    "#A4292F",
]
alphas_all = [1, 1.2, 1.4, 1.6, 1.8, 2, None]

# alpha value (or None) -> hex colour, paired position by position.
DICT_COLORS = dict(zip(alphas_all, COLORS))

# Opacity values; len_inference only reads the first entry.
MARKER_ALPHAS = [1 for _ in range(4)]

In [24]:
def _length_thresholds(max_seq_len, first_exp=6):
    """Return the powers of two 2**first_exp, 2**(first_exp + 1), ... that are <= max_seq_len."""
    thresholds = []
    exp = first_exp
    while 2**exp <= max_seq_len:
        thresholds.append(2**exp)
        exp += 1
    return thresholds


def _bucket_counts(inference, thresholds):
    """Count correct and total rows of `inference` per sequence-length bucket.

    Bucket 0 holds rows with seq_len <= thresholds[0]; bucket i > 0 holds rows
    with thresholds[i-1] < seq_len <= thresholds[i].  Rows longer than the last
    threshold are not counted.

    Returns a list of (n_correct, n_rows) pairs, one per threshold.
    """
    seq_len = inference["seq_len"]
    counts = []
    lower = None
    for upper in thresholds:
        mask = seq_len <= upper
        if lower is not None:
            mask = mask & (seq_len > lower)
        counts.append((inference.loc[mask, "is_correct"].sum(), mask.sum()))
        lower = upper
    return counts


def len_inference(models_root, n_layer=1,
                  fns_type='fns', manifold='rd', is_rescale_dist=True, selected_alphas=(1.2, 2.0),
                  is_op=True, qk_shares=(False, True), metric='test_acc'):
    """Plot accuracy against sequence-length bucket for FNS and DP runs.

    One panel is drawn per (qk_share, hidden width).  Each panel shows one
    curve per selected alpha plus one 'DP' curve; points are the mean over
    seeds and error bars the standard deviation over seeds.

    Parameters
    ----------
    models_root : str
        Directory containing the grid folders named
        ``<layers>L-hidden=<d>-max_len=512`` (with a ``-rescaled`` suffix when
        `is_rescale_dist` is true).
    n_layer : int
        Depth to plot.  Only runs below the grid folders of this depth are used.
    fns_type : str
        Unused; kept so existing callers keep working.  The FNS type name is
        always built from `manifold`, ``'fns'`` and ``MODEL_SUFFIX``.
    manifold : str
        Prefix of the FNS model type name.
    is_rescale_dist : bool
        Select the ``-rescaled`` grid folders instead of the plain ones.
    selected_alphas : sequence of float
        Alphas of the FNS runs to plot.
    is_op : bool
        Prefix both model type names with ``'op'``.
    qk_shares : sequence of bool
        Values of ``qk_share`` to plot, one subplot row each.
    metric : {'test_acc', 'train_acc'}
        Which per-sample inference CSV to read.

    Returns
    -------
    fig, axs
        The figure and the axes grid as returned by ``matrixify_axs``
        (indexed ``axs[row, col]``).

    Raises
    ------
    ValueError
        If `metric` is not supported.
    FileNotFoundError
        If no run directory with the inference CSV matches the selection.
    """
    # general setting -- fail early instead of hitting an unbound `fname` later
    metric_files = {
        'test_acc': 'bs=1-test_inference.csv',
        'train_acc': 'bs=1-train_inference.csv',
    }
    if metric not in metric_files:
        raise ValueError(
            f'Unsupported metric {metric!r}; expected one of {sorted(metric_files)}')
    fname = metric_files[metric]

    # private list copies: .index() is used below and the caller's objects stay untouched
    selected_alphas = list(selected_alphas)
    qk_shares = list(qk_shares)

    # get layers, emb_ds from the grid folder names
    layer_dir_re = re.compile(
        r"(\d+)L-hidden=(\d+)-max_len=512" + ("-rescaled" if is_rescale_dist else ""))
    root = Path(models_root)
    grid_dirs = {}  # (layers, hidden) -> folder name
    for entry in root.iterdir():
        match = layer_dir_re.fullmatch(entry.name)
        if match and entry.is_dir():
            grid_dirs[(int(match.group(1)), int(match.group(2)))] = entry.name

    # only depths below 4 and widths below 65 are plotted
    layers = np.array(sorted({layer for layer, _ in grid_dirs})); layers = layers[layers < 4]
    emb_ds = np.array(sorted({emb_d for _, emb_d in grid_dirs})); emb_ds = emb_ds[emb_ds < 65]
    assert n_layer in layers, f'{n_layer} does not exist!'
    emb_d_to_col = {int(emb_d): col for col, emb_d in enumerate(emb_ds)}

    # grid folders that belong to the requested depth and a plotted width
    wanted_grid_dirs = {name for (layer, emb_d), name in grid_dirs.items()
                        if layer == n_layer and emb_d in emb_d_to_col}

    # model type names as they appear in the run paths
    fns_model = manifold + 'fns' + MODEL_SUFFIX
    dp_model = 'dp' + MODEL_SUFFIX
    if is_op:
        fns_model = 'op' + fns_model
        dp_model = 'op' + dp_model
    alpha_tags = [f'alpha={float(alpha)}' for alpha in selected_alphas]

    # get all model dirs (seed paths end in "model=<int>")
    seed_dir_re = re.compile(r"model=\d+$")
    model_dirs = []
    for path in root.rglob("*"):
        if not (path.is_dir() and seed_dir_re.search(str(path))):
            continue
        if path.relative_to(root).parts[0] not in wanted_grid_dirs:
            continue
        model_dir = str(path)
        # NOTE(review): substring match, so with is_op=False the 'op...' variants match too.
        is_fns = fns_model in model_dir
        is_dp = dp_model in model_dir
        if not (is_fns or is_dp):
            continue
        # isolate alphas from selected_alphas
        if is_fns and not any(tag in model_dir for tag in alpha_tags):
            continue
        if isfile(njoin(model_dir, fname)):
            model_dirs.append(model_dir)
    model_dirs.sort()  # deterministic processing order

    if not model_dirs:
        raise FileNotFoundError(
            f'No run directory containing {fname!r} found under {models_root!r} '
            f'for n_layer={n_layer}')

    # sequence-length buckets: powers of two from 64 up to the configured seq_len
    # (taken from the first selected run; assumes all runs share it -- TODO confirm)
    _, config, _, _ = load_model_files(model_dirs[0])
    thresholds = _length_thresholds(config['seq_len'])

    # first pass: read each run's setup and decide which array slot it fills
    runs = []
    for model_dir in model_dirs:
        attn_setup, config, _, _ = load_model_files(model_dir)
        if attn_setup['model_name'].endswith('fns' + MODEL_SUFFIX):
            alpha = attn_setup['alpha']
            if alpha not in selected_alphas:
                continue  # e.g. path tag 'alpha=1.2' also matched 'alpha=1.25'
            alpha_idx = selected_alphas.index(alpha)
        else:
            alpha_idx = len(selected_alphas)  # last slot is the DP curve
        qk_share, hidden = attn_setup['qk_share'], config['hidden']
        if qk_share not in qk_shares or hidden not in emb_d_to_col:
            continue
        runs.append((alpha_idx, qk_shares.index(qk_share), emb_d_to_col[hidden],
                     attn_setup['seed'], model_dir))

    # one ensemble slot per distinct seed, so arbitrary seed values are supported
    seed_slot = {seed: slot for slot, seed in enumerate(sorted({run[3] for run in runs}))}
    ensembles = max(len(seed_slot), 1)

    # axis 0: [number correct, number of samples]
    metrics_all = np.full([2, len(selected_alphas)+1, len(qk_shares),
                           len(emb_ds), len(thresholds), ensembles], np.nan)
    for alpha_idx, qk_idx, col, seed, model_dir in runs:
        inference = pd.read_csv(njoin(model_dir, fname))
        counts = np.array(_bucket_counts(inference, thresholds), dtype=float).reshape(-1, 2)
        metrics_all[:, alpha_idx, qk_idx, col, :, seed_slot[seed]] = counts.T

    # accuracy is count / total; an empty bucket (0/0) stays NaN and is ignored below
    with np.errstate(divide='ignore', invalid='ignore'):
        metric_plot = metrics_all[0] / metrics_all[1]

    nrows, ncols = len(qk_shares), len(emb_ds)
    figsize = (6, 3.4)
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, sharex=True, sharey=True)
    axs = matrixify_axs(axs, nrows, ncols)

    scale = 100 if 'acc' in metric else 1  # accuracies are shown in percent
    for sidx, didx, alpha_idx in\
          product(range(len(qk_shares)), range(len(emb_ds)), range(len(selected_alphas)+1)):
        if alpha_idx < len(selected_alphas):
            alpha = selected_alphas[alpha_idx]
            base_color = DICT_COLORS[alpha]
            legend_label = rf'$\alpha$ = {alpha}'
        else:
            base_color = DICT_COLORS[None]
            legend_label = 'DP'
        color = (*to_rgb(base_color), MARKER_ALPHAS[0])

        # mean / std over the seed axis
        values = metric_plot[alpha_idx, sidx, didx, :, :] * scale
        metric_mean = np.nanmean(values, -1)
        metric_std = np.nanstd(values, -1)

        axs[sidx, didx].errorbar(thresholds, metric_mean, yerr=metric_std,
                                 fmt='.', linestyle='-', label=legend_label,
                                 color=color, clip_on=False)

    # legend
    axs[0, 0].legend(frameon=False, loc='best')
    # log x-axis (x is shared, so setting it on one panel is enough)
    axs[0, 0].set_xscale('log')

    for ncol in range(ncols):
        axs[0, ncol].set_title(rf'$d = {emb_ds[ncol]}$')
        axs[-1, ncol].set_xlabel(r'Seq. length $n$')
        # tick labels
        axs[-1, ncol].set_xticks(thresholds)
        axs[-1, ncol].set_xticklabels(thresholds)
        # remove minor ticks
        axs[-1, ncol].xaxis.set_minor_formatter(NullFormatter())
        axs[-1, ncol].xaxis.minorticks_off()
    for nrow in range(nrows):
        axs[nrow, 0].set_ylabel(r'$Q = K$' if qk_shares[nrow] else r'$Q \neq K$')

    # remove top and right spines
    for nrow, ncol in product(range(nrows), range(ncols)):
        ax = axs[nrow, ncol]
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    return fig, axs

In [25]:
%matplotlib inline
n_layer = 1
metric = 'test_acc'
fig, axs = len_inference(njoin(DROOT, 'L-d-grid-v2'), n_layer=n_layer, metric=metric)
plt.tight_layout()   
SAVE_DIR = njoin(FIGS_DIR, 'nlp-task')
if not isdir(SAVE_DIR): makedirs(SAVE_DIR)    
fig_file = f'{n_layer}L-{metric}-len_inference.pdf'
plt.savefig(njoin(SAVE_DIR, fig_file))   
plt.show()