In [None]:
import os,re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import scipy.sparse as sp
from tools.model_func import get_input
import seaborn as sns
from sklearn.metrics import f1_score,precision_score
import warnings

# metric helpers

In [None]:
def get_sparse_k(y_true,y_pred,k,include_rank=False):
    """Turn the top-``k`` columns of ``y_pred`` into a sparse indicator matrix.

    Args:
        y_true: sparse label matrix; only its shape ``(n_samples, n_labels)`` is used.
        y_pred: 2-D array of predicted label ids per sample, best first.
        k: number of leading prediction columns to keep.
        include_rank: if True, store the 1-based rank of each prediction instead of 1.

    Returns:
        scipy CSR matrix with the shape of ``y_true``.
    """
    n_rows, n_cols = y_true.shape
    row_idx = np.arange(n_rows).repeat(k)
    col_idx = y_pred[:, :k].reshape(-1)
    if not include_rank:
        values = np.ones(row_idx.shape, dtype=row_idx.dtype)
    else:
        # ranks 1..k repeated for every sample, aligned with col_idx
        values = np.tile(np.arange(1, k + 1), n_rows)
    return sp.csr_matrix((values, (row_idx, col_idx)), shape=(n_rows, n_cols))
def get_pAtk(y_true,y_pred,k):
    """Count, per sample, how many of the top-``k`` predictions are true labels.

    The count is not divided by ``k``; get_multilabel_pAtk is the normalised form.
    """
    hits = y_true.multiply(get_sparse_k(y_true, y_pred, k))
    return hits.sum(axis=1).A1
def get_nDCGAtk(y_true,y_pred,k):
    """nDCG@k per sample.

    The prediction at rank ``r`` (1-based) is discounted by ``1/log(r+1)``. The
    ideal DCG of a sample with ``n`` true labels is the sum of the first
    ``min(k, n)`` discounts. Both use the natural log, so the base cancels.

    The ideal DCG is computed from each row's own label count. When every
    sample has the same number of labels, this gives the same result as
    reading the count from row 0. Rows with no true labels score 0 instead of
    dividing by zero.

    Args:
        y_true: sparse (n_samples, n_labels) label matrix (presumably 0/1).
        y_pred: 2-D array of predicted label ids per sample, best first.
        k: cut-off rank.

    Returns:
        1-D float array of length n_samples.
    """
    pred = get_sparse_k(y_true,y_pred,k,include_rank=True)
    pred.data = 1/np.log(pred.data+1)
    dcg = y_true.multiply(pred).sum(axis=1).A1
    # ideal_dcg[n] = DCG of a perfect ranking with n relevant labels, n = 0..k
    ideal_dcg = np.array([(1/np.log(np.arange(n)+2)).sum() for n in range(k+1)])
    num_labs = np.asarray(y_true.sum(axis=1)).ravel()
    norm_const = ideal_dcg[np.minimum(num_labs, k).astype(int)]
    ndcg = np.zeros(len(dcg), dtype=float)
    np.divide(dcg, norm_const, out=ndcg, where=norm_const > 0)
    return ndcg
def get_multilabel_pAtk(y_true,y_pred,k):
    """Precision@k per sample: the fraction of the top-``k`` predictions that are true labels."""
    top_k = get_sparse_k(y_true, y_pred, k)
    n_hits = y_true.multiply(top_k).sum(axis=1)
    return (n_hits / k).A1
def get_macro_precision(y_true,y_pred,k):
    """Per-label precision of the top-``k`` predictions.

    Args:
        y_true: sparse (n_samples, n_labels) label matrix.
        y_pred: 2-D array of predicted label ids per sample, best first.
        k: number of predictions per sample counted as positive.

    Returns:
        Array with one precision value per label (sklearn ``average=None``).
    """
    pred = get_sparse_k(y_true,y_pred,k)
    # score against the y_true argument, not a notebook-global `trues`
    return precision_score(y_true,pred,average=None)
def get_macro_F1(y_true,y_pred,k):
    """Per-label F1 score of the top-``k`` predictions.

    Args:
        y_true: sparse (n_samples, n_labels) label matrix.
        y_pred: 2-D array of predicted label ids per sample, best first.
        k: number of predictions per sample counted as positive.

    Returns:
        Array with one F1 value per label (sklearn ``average=None``).
    """
    pred = get_sparse_k(y_true,y_pred,k)
    # score against the y_true argument, not a notebook-global `trues`
    return f1_score(y_true,pred,average=None)

# input helper

In [None]:
def get_args(in_dir):
    """Collect the ``args.csv`` of every completed run directory under ``in_dir``.

    A child of ``in_dir`` counts as a run only if it contains both ``train.log``
    and ``args.csv``. The log itself is not read; only its presence is checked.

    Args:
        in_dir: directory whose immediate children are run directories.

    Returns:
        DataFrame with the concatenated rows of all ``args.csv`` files, plus a
        ``dir`` column holding each run's path. Runs are in sorted path order.
        The DataFrame is empty when no run is found.
    """
    frames = []
    for run_dir in sorted(os.path.join(in_dir, name) for name in os.listdir(in_dir)):
        log_path = os.path.join(run_dir, 'train.log')
        args_path = os.path.join(run_dir, 'args.csv')
        if not (os.path.exists(log_path) and os.path.exists(args_path)):
            continue
        run_args = pd.read_csv(args_path)
        run_args['dir'] = run_dir
        frames.append(run_args)
    if not frames:
        # pd.concat([]) would raise "No objects to concatenate"
        return pd.DataFrame()
    return pd.concat(frames, ignore_index = True, sort = False)

In [None]:
def get_preds(model_dir,y_tests,per_hierarchy=False,top_k=10):
    """Load a run's saved test-set predictions as global label ids.

    Ids predicted for hierarchy level ``i`` are shifted by the number of labels
    in levels ``0..i-1`` (``offsets[i]``), so they index the columns of
    ``sp.hstack(y_tests)``.

    Args:
        model_dir: run output directory. A path containing ``'FastText'`` is
            read with the FastText file layout.
        y_tests: per-level test label matrices; only ``.shape[1]`` is used.
        per_hierarchy: False returns one merged ranking over all levels. True
            returns the first prediction (column 0) of every level.
        top_k: length of the merged ranking built from the per-level
            ``pred_outputs*`` / ``pred_logits*`` files. Defaults to 10, the
            previously hard-coded value. Unused for FastText or when the cached
            ``combined_pred_outputs.txt`` already exists.

    Returns:
        Integer ndarray. Merged mode: ``(n_samples, n_preds)``.
        Per-hierarchy mode: ``(n_samples, n_levels)``.

    Side effects:
        In merged, non-FastText mode, the merged ranking is written to
        ``combined_pred_outputs.txt`` and reused on later calls.
    """
    cnts = [y.shape[1] for y in y_tests]
    # offsets[i] = number of labels in all levels before level i
    offsets = [0] + [int(c) for c in np.cumsum(cnts)]

    if per_hierarchy:
        if 'FastText' in model_dir:
            # NOTE(review): drops the first entry of the sorted listing, presumably the
            # merged pred_outputs.txt -- confirm against the FastText output layout.
            pred_paths = sorted(os.path.join(model_dir, f) for f in os.listdir(model_dir))[1:]
        else:
            pred_paths = sorted(os.path.join(model_dir, f) for f in os.listdir(model_dir)
                                if f.startswith('pred_outputs'))
        level_preds = [np.loadtxt(p, dtype=int, usecols=0) + offsets[i] for i, p in enumerate(pred_paths)]
        return np.vstack(level_preds).T

    if 'FastText' in model_dir:
        return np.loadtxt(os.path.join(model_dir, 'pred_outputs.txt'), dtype=int, ndmin=2)

    long_dir = os.path.join(model_dir, 'combined_pred_outputs.txt')
    if os.path.exists(long_dir):
        # dtype=int so a cached result matches a freshly merged one; the default
        # float dtype would hand float label ids to get_sparse_k
        return np.loadtxt(long_dir, dtype=int, ndmin=2)

    files = sorted(os.listdir(model_dir))
    pred_paths = [os.path.join(model_dir, f) for f in files if f.startswith('pred_outputs')]
    logit_paths = [os.path.join(model_dir, f) for f in files if f.startswith('pred_logits')]
    # per-level predictions in the global id space, side by side
    preds = np.concatenate([np.loadtxt(p, dtype=int, ndmin=2) + offsets[i]
                            for i, p in enumerate(pred_paths)], axis=1)
    logits = np.concatenate([np.loadtxt(p, ndmin=2) for p in logit_paths], axis=1)
    # merged top-k prediction: candidates from all levels ordered by logit, highest first
    ranking = np.argsort(logits, axis=1)[:, ::-1][:, :top_k]
    preds = np.take_along_axis(preds, ranking, axis=1)
    np.savetxt(long_dir, preds, fmt='%d')
    print('SAVE COMBINED PREDICTIONS:\n{}'.format(long_dir))
    return preds
        

# get metrics at different levels

In [None]:
# get multi-label metrics
def get_multi_label_metrics(d,trues,y_tests,metrics,ks=[1,3,5]):
    """Print and return mean multi-label metrics for the merged predictions of run ``d``.

    Args:
        d: run output directory.
        trues: sparse label matrix over all levels (columns = global label ids).
        y_tests: per-level test label matrices (passed through to get_preds).
        metrics: mapping name -> metric(y_true, y_pred, k) returning per-sample scores.
        ks: cut-offs to evaluate (not mutated).

    Returns:
        ``outs[m] == [[score@k for k in ks]]``: one single-row list per metric.
    """
    print(d)
    preds = get_preds(d,y_tests,per_hierarchy=False)
    outs = []
    for name, scorer in metrics.items():
        row = []
        for k in ks:
            value = scorer(trues, preds, k).mean()
            row.append(value)
            print('{}{}:{:.2f}'.format(name, k, value * 100), end=' ')
        print()
        outs.append([row])
    return outs

In [None]:
# get_per_H_metrics
def get_per_H_metrics(d,y_tests,metrics,ks=[1,3,5]):
    """Print and return metrics for each hierarchy level of the run in ``d``.

    ``get_preds(..., per_hierarchy=True)`` returns an ``(n_samples, n_levels)``
    array. It holds the first prediction of every level as a global label id,
    with level offsets added. For level ``H`` this function takes column ``H``
    and removes the offset, so the ids index the columns of ``y_tests[H]``.

    Only one prediction per level is stored. A ``k`` larger than the number of
    prediction columns therefore cannot be scored; it is reported as ``n/a``
    and skipped instead of raising inside get_sparse_k.

    Args:
        d: run output directory.
        y_tests: list of per-level sparse test label matrices.
        metrics: mapping name -> metric(y_true, y_pred, k) returning per-sample scores.
        ks: cut-offs to evaluate (not mutated).

    Returns:
        ``outs[m][H]`` = list of mean scores of metric ``m`` on level ``H``, one
        per usable ``k``.
    """
    print(d)
    preds = get_preds(d,y_tests,per_hierarchy=True)
    cnts = [y.shape[1] for y in y_tests]
    offsets = [sum(cnts[:H]) for H in range(len(cnts))]
    outs = []
    for key,func in metrics.items():
        scoress = []
        for H in range(preds.shape[1]):
            # column H = level H; shift back to level-local ids to match y_tests[H]
            preds_H = preds[:, H:H+1] - offsets[H]
            n_cols = preds_H.shape[1]
            print('H{}:  '.format(H),end='')
            scores = []
            for i in ks:
                if i > n_cols:
                    print('{}{}:n/a'.format(key,i),end=' ')
                    continue
                score = func(y_tests[H],preds_H,i).mean()
                scores.append(score)
                print('{}{}:{:.2f}'.format(key,i,score*100),end=' ')
            print()
            scoress.append(scores)
        outs.append(scoress)
    return outs

In [None]:
def get_macro_scores(d,trues,y_tests,metrics,groups,ks,per_hierarchy):
    """Per-label macro scores of run ``d``, tagged with each label's frequency group.

    Args:
        d: run output directory; its last ``_``-separated token is used as the model name.
        trues: sparse label matrix over all levels.
        y_tests: per-level test label matrices (passed through to get_preds).
        metrics: mapping name (may contain ``{}`` for k) -> per-label metric(y_true, y_pred, k).
        groups: list of label-id lists, as returned by get_groups.
        ks: cut-offs to evaluate.
        per_hierarchy: forwarded to get_preds.

    Returns:
        DataFrame with one row per (dir, lab_ind). Rows produced by different
        metrics/ks are merged with ``max`` (NaNs from other metrics are ignored).
    """
    print(d)
    preds = get_preds(d,y_tests,per_hierarchy=per_hierarchy)
    model_name = d.split('_')[-1]
    records = []
    for key, func in metrics.items():
        for k in ks:
            per_label = func(trues, preds, k)
            metric = key.format(k)
            for g, group in enumerate(groups):
                group_name = 'G{}'.format(g)
                records.extend(
                    {'dir': d, 'model': model_name, 'group': group_name,
                     metric: per_label[ind], 'lab_ind': ind}
                    for ind in group
                )
    df = pd.DataFrame(records)
    return df.groupby(['dir','lab_ind']).max().reset_index()

In [None]:
# get groups
def get_groups(y_trains,num_groups = 3):
    """Split labels into training-frequency groups and summarise each group.

    Label columns of all levels are concatenated in level order, so ids match
    ``sp.hstack(y_trains)``. Labels are sorted by training count, ascending, and
    cut greedily into groups of roughly ``total_count / num_groups`` occurrences.

    Args:
        y_trains: list of sparse per-level training label matrices.
        num_groups: target number of groups.

    Returns:
        ``(groups, df)``. ``groups`` is a list of label-id lists, rarest group
        first. ``df`` has one row per group with columns ``group``,
        ``num_train``, ``perc_train``, ``num_train_cut_off``, ``t_bound``,
        ``num_labels`` and ``perc_labels``.
    """
    # training frequency of every label
    train_cnts = np.hstack([y.sum(axis=0).A1 for y in y_trains])
    lab_to_cnts = dict(enumerate(train_cnts))
    sorted_labs = sorted(lab_to_cnts, key=lab_to_cnts.__getitem__)

    # greedy split by cumulative count
    target = sum(lab_to_cnts.values())/num_groups
    groups, current, running = [], [], 0
    for lab in sorted_labs:
        running += lab_to_cnts[lab]
        if running > target and len(groups) < num_groups:
            # close the current group; `lab` opens the next one.
            # NOTE(review): the running total restarts at 0, so the opening label's own
            # count is not charged to its new group -- confirm this is intended.
            groups.append(current)
            current = [lab]
            running = 0
        else:
            current.append(lab)
    groups.append(current)

    # summary table; the cut-off of a group is the count of its most frequent label
    cut_off = [0]+[lab_to_cnts[g[-1]] for g in groups]
    df = pd.DataFrame({
        'group': ['G{}'.format(g) for g in range(len(groups))],
        'num_train': [sum(lab_to_cnts[lab] for lab in g) for g in groups],
    })
    df['perc_train'] = df['num_train']/df['num_train'].sum()*100
    df['num_train_cut_off'] = cut_off[1:]
    df['t_bound'] = ['${} < t < {}$'.format(lo, hi) for lo, hi in zip(cut_off[:-1], cut_off[1:])]
    df['num_labels'] = [len(g) for g in groups]
    df['perc_labels'] = df['num_labels']/df['num_labels'].sum()*100
    return groups,df

# Label-frequency group table (LaTeX export)

In [None]:
# One label-group summary per dataset, stacked into a (data, group)-indexed table.
names = ['SIC Code','AmazonCat-13k']
dfs = []
for name, DATA in zip(names, ['sic_hierarchy','amazon_hierarchy_2']):
    IN_DIR = 'data/{}'.format(DATA)
    _, y_trains, _, _ = get_input(mode='cat', in_dir=IN_DIR, sparse=True, get_output=[0,1,0,0])
    group_df = get_groups(y_trains, num_groups=3)[1]
    group_df['data'] = name
    dfs.append(group_df)
df = pd.concat(dfs).set_index(['data','group'])
# unnamed index levels keep the LaTeX header clean
df.index.names = [None, None]

In [None]:
# Render the group table as LaTeX. Runs of padding spaces are removed, and dollar
# signs escaped by pandas are restored so the `$...$` math in t_bound renders.
cols = ['num_train','num_labels','perc_labels','t_bound']
header = ['Train Samples','No. Labels','% Labels','Train Samples per label (t)']
ll = df[cols].to_latex(header = header,index=True,float_format='%.2f',multirow=True)
ll = re.sub(' {2,}', '',ll)
# raw string: '\$' is an invalid escape sequence (SyntaxWarning on Python 3.12+)
ll = ll.replace(r'\$','$')
print(ll)

# Per-group macro metrics for all datasets and runs

In [None]:
# get args
args = get_args(in_dir='outputs')  # one row per run directory that has train.log and args.csv

In [None]:
# Per-label macro precision / F1 for every run: both datasets, both losses, and both
# prediction modes (first prediction per level vs merged ranking).
dfs = []
# k scored per dataset, in the same order as the dataset list below
kss=[[4],[3]]
metrics = {
    'precision':get_macro_precision,
    'F1':get_macro_F1
}
# Silence warnings (presumably sklearn's undefined-precision warnings for labels that
# are never predicted) for this loop only, instead of for the rest of the session.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    # Dataset is the outer loop, so labels and groups (independent of the loss) are
    # loaded once per dataset. Rows in `dfs` are therefore dataset-major.
    for i,DATA in enumerate(['sic_hierarchy','amazon_hierarchy_2']):
        ks = kss[i]
        in_dir = 'data/{}'.format(DATA)
        _,y_trains,_,y_tests = get_input(mode='cat', in_dir = in_dir, sparse = True, get_output= [0,1,0,1])
        trues = sp.hstack(y_tests).tocsr()
        groups,_ = get_groups(y_trains,num_groups = 3)
        for LOSS in ['categorical','binary']:
            runs = args[(args['mode']=='cat') & (args['input']==in_dir) & (args['loss']==LOSS)]
            # the FastText baseline directory is added explicitly for every loss
            dirs = ['outputs/{}_c_FastText'.format(DATA)] + sorted(runs.dir.to_list())
            for ph in [True,False]:
                for d in dirs:
                    run_scores = get_macro_scores(d,trues,y_tests,metrics,groups,ks=ks,per_hierarchy=ph)
                    run_scores['loss']=LOSS
                    run_scores['input']=in_dir
                    run_scores['per_hierarchy']=ph
                    dfs.append(run_scores)

## save

In [None]:
# Stack all per-run score tables. Each one restarts its index at 0, so renumber to
# avoid duplicate index labels.
df = pd.concat(dfs,sort=False,ignore_index=True)
df.head()

In [None]:
# Cache the combined per-label scores so the plotting section can start from here.
os.makedirs('outputs/dfs', exist_ok=True)
df.to_pickle('outputs/dfs/combined0.pkl')

# macro average scores

In [None]:
# Macro average over labels per (loss, dataset, prediction mode), one column per model.
order = ['FastText','xmlcnn','attentionxml','attention']
df1 = df[df['model'].isin(order)].drop(columns=['lab_ind'])
# numeric_only: 'dir' and 'group' are strings; pandas >= 2 raises instead of dropping them
df1 = df1.groupby(['loss','input','per_hierarchy','model']).mean(numeric_only=True).unstack([-1])
df1

# Get pretty graphs

In [None]:
COMBINED_PKL = 'outputs/dfs/combined0.pkl'  # written by the save cell above
df = pd.read_pickle(COMBINED_PKL)

In [None]:
# Human-readable names for metrics and datasets, used in the figure labels below.
metric_dict = dict(zip(['precision', 'F1'], ['Precision', 'F1 score']))
data_dict = dict(zip(['data/sic_hierarchy', 'data/amazon_hierarchy_2'], ['SIC Code', 'AmazonCat-13k']))

### PLOT : baseline comparison

In [None]:
# PARAMS
datas = ['data/sic_hierarchy','data/amazon_hierarchy_2']
models = ['FastText','xmlcnn','attentionxml','attention']
metrics = ['precision','F1']

# function
df1 = df
df1 = df1[df1.loss=='binary']
df1 = df1[df1.model.isin(models)]
df1 = df1[df1.per_hierarchy == False]
# bar chart of mean
for data in datas:
    df2 = df1[df1.input==data]
    for metric in metrics:
        # y tick counts
        cnts = (df2.group.value_counts()/df2.group.value_counts().sum()*100).to_dict()
        groups = sorted(cnts.keys())
        # plot
        fig,ax = plt.subplots()
        bar = sns.barplot(
            x = 'group',
            y=metric,
            hue='model',
            data = df2,
            ax = ax, 
            palette="Set3",
            order=groups,
            hue_order = models
           )
        ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
               ncol=len(models), mode="expand", borderaxespad=0.)
        ax.set_ylabel(metric_dict[metric])
        ax.set_xlabel('{} label groups (% labels in group)'.format(data_dict[data]))
        ax.set_xticklabels(['{} ({:.2f}%)'.format(key,cnts[key]) for key in groups])
        plt.show()

In [1]:
### TABLE : baseline comparison

In [None]:
# TABLE
datas = ['data/sic_hierarchy','data/amazon_hierarchy_2']
models = ['FastText','xmlcnn','attentionxml','attention']
metrics = ['precision','F1']
for data in datas:
    print(data)
    df2 = df1[df1.input==data]
    srs = (df2.groupby('model')[metrics].mean()*100).to_dict()
    for metric in metrics:
        print('{:10}'.format(metric),end=':')
        print('&'.join(['{:.2f}'.format(srs[metric][key]) for key in models]))

### Exploration: effect of the calculation method and of the loss

In [None]:
# change in model performance with changing calculation method
df1 = df
df1 = df1[df1.loss=='binary']
# df1 = df1[df1.model!='attention']
# df1 = df1[df1.model!='bert']
# df1 = df1[df1.model=='attentionxml']
df1  = df1.set_index(['loss','input','dir','model','group','lab_ind','per_hierarchy']).stack().unstack([-2,-1])
df1 = (df1.loc[:,(True)]-df1.loc[:,(False)]).reset_index()
df1 = df1[df1.input=='data/amazon_hierarchy_2']
# bar chart of mean
metrics = ['precision','F1']
y_labs = ['binary accuracy','F1 score']
for i in range(len(metrics)):
    metric = metrics[i]
    y_lab = y_labs[i]
    fig,ax = plt.subplots()
    sns.barplot(x = 'group',y=metric,hue='model',data = df1,ax = ax, palette="Set3",
                order=sorted(df1.group.unique()),
                hue_order = order
               )
    ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
           ncol=len(df1.model.unique()), mode="expand", borderaxespad=0.)
    ax.set_ylabel(y_lab)
    ax.set_xlabel('AmazonCat-13k label groups')
    plt.show()

In [None]:
# Keep only scores computed from the per-level predictions; the flag is then constant.
df1 = df.loc[df['per_hierarchy'] == True].drop(columns=['per_hierarchy'])


In [None]:
# Drop the run-directory column so runs can be grouped by model;
# errors='ignore' keeps the cell safe to re-run.
df1 = df1.drop(columns=['dir'], errors='ignore')

In [None]:
# Group keys are passed positionally (DataFrame.groupby has no `columns=` argument).
# The GroupBy object is lazy; aggregate it (e.g. `.mean()`) to inspect values.
df1.groupby(['input','model','group','lab_ind','loss'])

In [None]:
# change in model performance with changing calculation method
df1 = df
df1 = df1[df1.loss=='binary']
df1 = df1[df1['per_hierarchy']==True]
# df1 = df1[df1.model!='attention']
# df1 = df1[df1.model!='bert']
# df1 = df1[df1.model=='attentionxml']
df1  = df1.set_index(['input','dir','model','group','lab_ind','loss']).stack().unstack([-2,-1])
df1 = (df1.loc[:,('categorical')]-df1.loc[:,('binary')]).reset_index()
df1 = df1[df1.input=='data/amazon_hierarchy_2']
# bar chart of mean
metrics = ['precision','F1']
y_labs = ['binary accuracy','F1 score']
for i in range(len(metrics)):
    metric = metrics[i]
    y_lab = y_labs[i]
    fig,ax = plt.subplots()
    sns.barplot(x = 'group',y=metric,hue='model',data = df1,ax = ax, palette="Set3",
                order=sorted(df1.group.unique()),
                hue_order = order
               )
    ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
           ncol=len(df1.model.unique()), mode="expand", borderaxespad=0.)
    ax.set_ylabel(y_lab)
    ax.set_xlabel('AmazonCat-13k label groups')
    plt.show()

# Multi-label metrics