In [2]:
import json
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcl
import math
import numpy as np
import pandas as pd
import re

from ast import literal_eval
from itertools import product
from matplotlib.transforms import ScaledTranslation
from matplotlib.ticker import NullFormatter
from os import makedirs
from os.path import isdir, isfile
from pathlib import Path
from string import ascii_lowercase
from time import time
from tqdm import tqdm
from constants import *
from UTILS.mutils import njoin, str2bool, str2ls, create_model_dir, convert_train_history
from UTILS.mutils import collect_model_dirs, find_subdirs, load_model_files
from UTILS.figure_utils import matrixify_axs, label_axs
from plot_results import get_metric_curves, load_seed_runs, final_epoch_stats

# Phase ensembles

In [1]:
models_root = njoin(DROOT, '4L-ps=2')

# collect subdirs containing the model directories
model_root_dirs = models_roots = find_subdirs(models_root, MODEL_SUFFIX)
print(model_root_dirs)                  

# all trained model types
model_types = []   
DCT_ALL = {} 
for model_root_dir in model_root_dirs:
    DCT_cur = collect_model_dirs(model_root_dir, suffix=MODEL_SUFFIX)
    for model_type, df_model_cur in DCT_cur.items():
        df_clean = df_model_cur.dropna(subset='alpha') if 'alpha' in df_model_cur.columns else df_model_cur
        if model_type not in DCT_ALL:
            model_types.append(model_type)
            DCT_ALL[model_type] = df_clean
        else:
            DCT_ALL[model_type] = pd.concat([DCT_ALL[model_type], df_clean], ignore_index=True)   

# isolate partiulcar setting for qk_share
df_model = DCT_ALL[[model_type for model_type in list(DCT_ALL.keys()) if fns_manifold in model_type][0]]
df_model.reset_index(drop=True, inplace=True)
qk_shares = list(df_model.loc[:,'qk_share'].unique())            

NameError: name 'njoin' is not defined

In [4]:
DCT_ALL.keys()

dict_keys(['rdfnsvit', 'opdpvit', 'oprdfnsvit', 'dpvit'])

In [5]:
DCT_ALL[list(DCT_ALL.keys())[0]].loc[:,'dataset_name'].unique()

array(['cifar10'], dtype=object)

In [6]:
df_test = pd.read_csv('Q://scratch//uu69//cq5024//projects//fractional-attn//vit-pytorch//.droot//4L-ps=2//config_qkv//cifar10//layers=4-heads=6-hidden=48-qkv//dpvit-cifar10-qkv//model=0//run_performance.csv')

In [7]:
df_test

Unnamed: 0.1,Unnamed: 0,iter,lr,train_loss,val_loss,train_acc,val_acc,secs_per_eval
0,0,782,0.00010,2.064233,1.933128,0.212935,0.258161,
1,1,1564,0.00010,1.887160,1.817175,0.280291,0.317576,25.863362
2,2,2346,0.00010,1.765761,1.675470,0.318994,0.359773,23.614602
3,3,3128,0.00010,1.695333,1.664246,0.345408,0.366441,23.750036
4,4,3910,0.00010,1.638615,1.583613,0.364850,0.398388,23.821391
...,...,...,...,...,...,...,...,...
245,245,192372,0.00001,0.544838,0.774390,0.806025,0.733977,24.138856
246,246,193154,0.00001,0.544255,0.777900,0.807505,0.732982,24.300847
247,247,193936,0.00001,0.547899,0.779497,0.804807,0.733380,24.111060
248,248,194718,0.00001,0.541615,0.774529,0.806965,0.736167,24.493165


In [8]:
def phase_ensembles(models_root, selected_dataset='cifar10',
                    fns_manifold='rd', qk_share=True, selected_alphas='1,2',
                    metrics='val_acc,val_loss',
                    is_ops = [True],  # [False,True]
                    cbar_separate=False, display=False):

    global qk_shares

    assert fns_manifold in ['sp', 'rd', 'v2_rd'], f'{fns_manifold} does not exist!'
    qk_share, cbar_separate, display = map(str2bool, (qk_share, cbar_separate, display))
    metrics, is_ops = str2ls(metrics), str2ls(is_ops)

    # collect subdirs containing the model directories
    model_root_dirs = models_roots = find_subdirs(njoin(models_root), MODEL_SUFFIX)
    print(model_root_dirs)                  

    # all trained model types
    model_types = []   
    DCT_ALL = {} 
    for model_root_dir in model_root_dirs:
        DCT_cur = collect_model_dirs(model_root_dir, suffix=MODEL_SUFFIX)
        for model_type, df_model_cur in DCT_cur.items():
            df_clean = df_model_cur.dropna(subset='alpha') if 'alpha' in df_model_cur.columns else df_model_cur
            if model_type not in DCT_ALL:
                model_types.append(model_type)
                DCT_ALL[model_type] = df_clean
            else:
                DCT_ALL[model_type] = pd.concat([DCT_ALL[model_type], df_clean], ignore_index=True)                    

    # isolate partiulcar setting for qk_share
    df_model = DCT_ALL[[model_type for model_type in list(DCT_ALL.keys()) if fns_manifold in model_type][0]]
    df_model.reset_index(drop=True, inplace=True)
    qk_shares = list(df_model.loc[:,'qk_share'].unique())
    print(qk_shares)
    assert qk_share in qk_shares, f'qk_share = {qk_share} setting does not exist!'
    
    # print('df_model')
    # print(df_model)

    # ---- col names ----
    stats_colnames = ['min', 'max', 'mid', 'median', 'mean', 'std', 'counter']   

    # ----- general settings -----
    num_attention_heads, num_hidden_layers, hidden_size =\
         DCT_ALL[list(DCT_ALL.keys())[0]].loc[0,['num_attention_heads', 'num_hidden_layers', 'hidden_size']]
    #dataset = DCT_ALL[list(DCT_ALL.keys())[0]].loc[0,'dataset_name']
    assert selected_dataset in DCT_ALL[list(DCT_ALL.keys())[0]].loc[:,'dataset_name'].unique(), 'selected_dataset does not exist'

    # ----- fns setting -----
    alphas = sorted(df_model.loc[:,'alpha'].unique())[::-1]  # large to small
    epss = sorted(df_model.loc[:,'bandwidth'].unique())    
    if selected_alphas.lower() == 'none':
        selected_alphas = alphas
    else:
        selected_alphas = [float(selected_alpha) for selected_alpha in str2ls(selected_alphas)]
    #eps = epss[0]
    eps = 1  # hard coded

    # ----- models to plot -----
    fns_model_type = fns_manifold + 'fns' + MODEL_SUFFIX    
    other_model_types = ['dp' + MODEL_SUFFIX]  # 'sink' + MODEL_SUFFIX
    model_types_to_plot = [fns_model_type] + other_model_types
            
    print(f'model_types_to_plot: {model_types_to_plot}')

    nrows, ncols = len(metrics), len(is_ops)     
    # figsize = (3*ncols,3.5*nrows)
    fig, axs = plt.subplots(nrows,ncols,figsize=(5,4))
    # axs = matrixify_axs(axs, nrows, ncols)  # convert axs to 2D array
    # label_axs(fig, axs)  # alphabetically label subfigures             

    model_types_plotted = []
    model_types_seeds = {}     
    for (row_idx, metric), (col_idx, is_op) in product(enumerate(metrics), enumerate(is_ops)):
        ax = axs[row_idx, col_idx] 
        # summary statistics
        row_stats = []

        print(f'model_type = {model_type}')        
        for model_type in model_types_to_plot:
            if is_op:
                model_type = 'op' + model_type
            if model_type in DCT_ALL.keys():
                df_model = DCT_ALL[model_type]
            else:
                continue
            # matching conditions for model setup
            condition0 = (df_model['ensembles']>0)&(df_model['qk_share']==qk_share)&(df_model['is_op']==is_op)&\
                         (df_model['model_dir'].str.contains(selected_dataset))&\
                         (df_model['model_dir'].str.contains(f'/{model_type}-'))
            matching_df = df_model[condition0]

            if model_type not in model_types_plotted:
                model_types_plotted.append(model_type)

            lstyle_model = LINESTYLE_DICT[model_type]
            for alpha in selected_alphas:
                is_fns = 'fns' in model_type
                alpha = alpha if is_fns else None
                matching_df.reset_index(drop=True, inplace=True)                

                print('matching_df')
                print(matching_df)                      

                # color
                if is_fns:
                    color = '#2E63A6' if alpha == 1.2 else '#A4292F'
                else:
                    # color = 'k'
                    color = '#636363'
                # color = HYP_CMAP(HYP_CNORM(alpha)) if is_fns else OTHER_COLORS_DICT[model_type]  
                # -------------------- SINK, DP -------------------- 
                model_info = matching_df 
                # -------------------- FNS --------------------
                if is_fns:
                    # matching conditions for FNS setup
                    condition = (matching_df['alpha']==alpha) & (matching_df['bandwidth']==eps)
                    model_info = model_info[condition]
                # get aggregated training curves
                if model_info.shape[0] > 0:
                    seeds, qk_share = (model_info[k].item() for k in ('seeds', 'qk_share'))                
                    epochs, run_perf_all = load_seed_runs(model_info['model_dir'].item(), seeds, metric)   
                else:
                    continue

                if run_perf_all is not None:
                    counter = run_perf_all.shape[1]
                    metric_curves = get_metric_curves(run_perf_all)      
                    exe_plot = ax.plot(epochs, metric_curves[1], linestyle='-', c=color, alpha=1, clip_on=False, label='DP' if not is_fns else rf'$\alpha = {alpha}$')
                    if (row_idx,col_idx) == (0,0):
                        im = exe_plot                      
                    # Calculate std                       
                    metric_std = np.nanstd(run_perf_all.to_numpy(), axis=1)
                    ax.fill_between(epochs, metric_curves[1]-metric_std, metric_curves[1]+metric_std, color=color, alpha=0.3, clip_on=False, edgecolor='none') 

                    # results of the final epoch
                    row_stats.append([model_type, alpha] +\
                                     final_epoch_stats(run_perf_all,metric) + [counter])    
                    ax.spines['top'].set_visible(False)
                    ax.spines['right'].set_visible(False)
                    ax.set_xlim([0,20])
                    ax.set_xticks([0, 5, 10, 15, 20])
                    if row_idx == 0:
                        ax.set_ylim(bottom=72,top=85)
                        ax.set_yticks([75,80,85])
                    elif row_idx == 1:
                        # ax.set_ylim([0.45, 0.6])
                        ax.set_yticks([0.45, 0.5, 0.55])
                if not is_fns:
                    break  # only do once if model is not FNS type

        summary_stats = pd.DataFrame(data=row_stats, columns=['model_type','alpha']+stats_colnames)

        # print message
        # print(metric)
        # print(f'is_op = {is_op}, qk_share = {qk_share}')
        # print(summary_stats)
        # print('\n')                    

    # # labels
    # model_labels = []
    # for model_type in model_types_plotted:  
    #     if model_type[:2] != 'op': 
    #         color = 'k' if 'fns' in model_type else OTHER_COLORS_DICT[model_type]            
    #         model_label = NAMES_DICT[model_type]
    #         if model_label not in model_labels:            
    #             axs[0,0].plot([], [], c=color, linestyle=LINESTYLE_DICT[model_type], label=model_label)
    #             model_labels.append(model_label)

    # # legend
    axs[0,0].legend(loc='best', frameon=False)                     
    # for alpha in selected_alphas[::-1]:
    #     axs[0,0].plot([], [], c=HYP_CMAP(HYP_CNORM(alpha)), linestyle='solid', 
    #                   label=rf'$\alpha$ = {alpha}')         
    # ncol_legend = 2  #if len(model_types_plotted) == 3 else 1
    # if len(model_types_plotted) >= 2:
    #     #axs[0,0].legend(loc='best', ncol=ncol_legend, frameon=False)           
    #     axs[0,0].legend(loc='best', ncol=ncol_legend, frameon=False)                     

    # Add shared x and y labels     
    #fig.supxlabel('Epochs', fontsize='medium'); fig.supylabel(NAMES_DICT[metrics[0]], fontsize='medium')

    for row_idx in range(len(qk_shares)):        
        for col_idx, is_op in enumerate(is_ops):  
            ax = axs[row_idx, col_idx]
            #ax.set_ylabel(NAMES_DICT[metric])
            if row_idx == 0:
                #ax.set_title(NAMES_DICT[metric])
                ax_title = r'$W \in O(d)$' if is_ops[col_idx] else r'$W \notin O(d)$'
                ax.set_title(ax_title)
            
            axs[row_idx,col_idx].sharey(axs[row_idx, 0])
            axs[-1,col_idx].set_xlabel('Epochs')
        # axs[row_idx,0].set_ylabel(NAMES_DICT[metrics[row_idx]])
    axs[0,0].set_ylabel('Testing accuracy (%)')
    axs[1,0].set_ylabel('Testing loss')

    # Adjust layout
    plt.subplots_adjust(wspace=0.4, hspace=0.3)
    # plt.tight_layout()  # Leave space for the right label                 

    dataset_name_short = ''
    if isinstance(selected_dataset,str):
        if '_' in selected_dataset:
            for s in selected_dataset.split('_'):
                dataset_name_short += s[0]
        else:
            dataset_name_short += selected_dataset

    model_types_short = [model_type.replace(MODEL_SUFFIX,'') for model_type in model_types_plotted]
    
    return fig, axs

    # from constants import FIGS_DIR
    # SAVE_DIR = njoin(FIGS_DIR, 'nlp-task')
    # if display:
    #     plt.show()
    # else:
    #     if not isdir(SAVE_DIR): makedirs(SAVE_DIR)
    #     fig_file = models_root.split('/')[1] + '-'
    #     #fig_file += f'layers={num_hidden_layers}-heads={num_attention_heads}-hidden={hidden_size}-'            
    #     fig_file += f'l={num_hidden_layers}-d={hidden_size}-'
    #     fig_file += 'qqv-' if qk_share else 'qkv-'
    #     fig_file += '_'.join(model_types_short)+ '-' + metrics[0] + '-' + f'ds={dataset_name_short}'
    #     fig_file += '.pdf'
    #     plt.savefig(njoin(SAVE_DIR, fig_file))            
    #     print(f'Figure saved in {njoin(SAVE_DIR, fig_file)}')

    # # separate colorbar
    # if cbar_separate:    
    #     """
    #     #fig.subplots_adjust(right=0.8)
    #     fig = plt.figure()
    #     cbar_ax = fig.add_axes([0.85, 0.20, 0.03, 0.75])
    #     cbar_ticks = list(np.arange(1,2.01,0.2))
    #     cbar = fig.colorbar(im, cax=cbar_ax, ticks=cbar_ticks)
    #     cbar.ax.set_yticklabels(cbar_ticks)
    #     cbar.ax.tick_params(axis='y', labelsize=tick_size)
    #     """
        
    #     fig = plt.figure()
    #     cbar_ax = fig.add_axes([0.85, 0.20, 0.03, 0.75])
    #     cbar_ticks = list(np.linspace(1,2,6))
        
    #     cbar = mpl.colorbar.ColorbarBase(cbar_ax, norm=HYP_CNORM, cmap=HYP_CM)
    #     cbar.ax.set_yticklabels(cbar_ticks)
    #     cbar.ax.tick_params(axis='y', labelsize=16.5)

    #     plt.savefig(njoin(SAVE_DIR,"alpha_colorbar.pdf"), bbox_inches='tight')  

In [9]:
from constants import FIGS_DIR
fig, axs = phase_ensembles(njoin(DROOT, '4L-ps=2'), is_ops=[False,True], selected_alphas='1.2,2', qk_share=False)
# plt.tight_layout()
# SAVE_DIR = njoin(FIGS_DIR, 'vit-task')    
# fig_file = 'phase_ensembles'
# fig_file += '.pdf'
# #plt.savefig(njoin(SAVE_DIR, fig_file), bbox_inches='tight')
# plt.show()

['q:\\scratch\\uu69\\cq5024\\projects\\fractional-attn\\vit-pytorch\\.droot\\4L-ps=2\\config_qkv\\cifar10\\layers=4-heads=6-hidden=48-qkv']


  df = df._append(model_dir_dct, ignore_index=True)
  df = df._append(model_dir_dct, ignore_index=True)


[False]
df_model
   alpha  bandwidth    a qk_share qkv_bias dataset_name  \
0    1.2        1.0  0.0    False    False      cifar10   
1    2.0        1.0  0.0    False    False      cifar10   
2    1.2        1.0  0.0    False    False      cifar10   
3    2.0        1.0  0.0    False    False      cifar10   

                                          train_loss  \
0  [0.6191950440406799, 0.5639753937721252, 0.646...   
1  [0.6191950440406799, 0.5639753937721252, 0.646...   
2  [0.6191950440406799, 0.5639753937721252, 0.646...   
3  [0.6191950440406799, 0.5639753937721252, 0.646...   

                                            val_loss  \
0  [0.8076232075691223, 0.7732414603233337, 0.788...   
1  [0.8076232075691223, 0.7732414603233337, 0.788...   
2  [0.8076232075691223, 0.7732414603233337, 0.788...   
3  [0.8076232075691223, 0.7732414603233337, 0.788...   

                                           train_acc  \
0  [0.780610203742981, 0.8005113005638123, 0.7708...   
1  [0.7806102

  axs[0,0].legend(loc='best', frameon=False)


: 