In [2]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import argparse
import pickle

plt.rcParams['font.family'] = 'Times New Roman'

def get_test_mrr(file_path):
    try:
        with open(file_path, 'r') as f:
            mrr = None
            for lines in f:
                if lines.startswith('\ttest AP'):
                    mrr = float(lines.strip('\n').split(':')[-1])
                    return mrr
    except FileNotFoundError:
        import pdb; pdb.set_trace()
        pass

def get_best_epoch(file_path):
    try:
        with open(file_path, 'r') as f:
            epoch = None
            for lines in f:
                if lines.startswith('Loading'):
                    epoch = int(lines.split(' ')[4])
                    return epoch
    except FileNotFoundError:
        import pdb; pdb.set_trace()
        pass

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument('--trial', type=str, help='trial name')
parser.add_argument('--log_dir', type=str, default='log', help='log file directory')
parser.add_argument('--pkl_path', type=str, default='')
parser.add_argument('--target', type=str, default='mrr', choices=['mrr', 'epoch'])
parser.add_argument('--num_scope', type=int, default=25, help='trial name')
parser.add_argument('--num_neighbor', type=int, default=10, help='trial name')
parser.add_argument('--runs', type=int, default=5, help='trial name')
parser.add_argument('--layers', type=int, default=1, help='layer number')
parser.add_argument('--fontsize', type=int, default=32, help='font size')
parser.add_argument('--no_title', action='store_true')
parser.add_argument('--save_legends', action='store_true')

args = parser.parse_args(['--pkl_path', '../all_mrrs_0318.pkl', '--no_title'])
log_dir = args.log_dir
config_dir = 'config' + '/{}'.format(args.trial)
# Optionally, you can set the font size as well
plt.rcParams['font.size'] = args.fontsize
if args.layers == 1:
    scans = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '20', '50', '100']
    intscans = [int(k) for k in scans]
else:
    scans = ['5x5', '5x10', '10x5', '10x10']
datasets = ['WIKI', 'uci', 'Flights', 'LASTFM', 'mooc']
show_datasets = {
    'WIKI': 'Wikipedia',
    'REDDIT': 'REDDIT',
    'Flights': 'Flights',
    'LASTFM': 'LASTFM',
    'mooc': 'MOOC',
    'uci': 'UCI',
    'CollegeMsg': 'CollegeMsg'
}
aggrs = ['TGAT', 'GraphMixer']
show_aggrs = {
    'TGAT': 'Attention',
    'GraphMixer': 'MLP-Mixer'
}
samplings = ['re', 'uni',]
memorys = ['gru', 'embed', '']
show_memorys = {
    'gru': 'RNN',
    'embed': 'Embedding',
    '': 'None'
}

colors = [
'#ff8a65',
'#ffd54f',
'#aed581',
'#4db6ac',
'#4fc3f7',
'#7986cb'
]

In [4]:
all_data = {}

# load data
if os.path.exists(args.pkl_path):
    with open(args.pkl_path, 'rb') as f:
        all_data = pickle.load(f)
else:
    raise NotImplementedError


In [9]:
settings = {
    'WIKI': ['TGAT', 'gru', 're'],
    'REDDIT' : ['TGAT', 'embed', 're'],
    'uci':['TGAT', 'gru', 're'],
    'LASTFM': ['TGAT', 'embed', 're'],
}
plt.figure(figsize=(10, 8))
fmts = ['-o', '-x', '-s']
for i, dataset in enumerate(settings.keys()):
    df_mean = pd.DataFrame()
    df_std = pd.DataFrame()
    df_all = pd.DataFrame()
    aggr, memory, spl = settings[dataset]
    means = []
    stds = []
    for scan in scans:
        results = np.array(all_data[dataset][scan][aggr][spl][memory])
        if len(results) == args.runs:
            means.append(np.mean(results))
            stds.append(np.std(results))
        else:
            print(dataset, scan, aggr, memory)
    if len(means) != len(scans):
        import pdb; pdb.set_trace()
    plt.errorbar(x=intscans, y=means, yerr=stds, fmt=fmts[i], capsize=5, label=show_datasets[dataset], color=colors[i])

title_str = f""
if not args.no_title:
    plt.title(title_str, x=0.5, y=1.05)
x_labels = scans
plt.xticks(ticks=intscans, labels=x_labels)
plt.xlabel(f'# of Neighbors')
plt.ylabel(f'Mean Reciprocal Rank (MRR)')
plt.tight_layout()

plt.savefig(f'../figures/mem_saturate.pdf')

# if args.save_legends:
#     fig, ax = plt.subplots()

#     legend = ax.legend(handles=handles, loc='center', ncol=len(handles)/3)

#     ax.axis('off')

#     fig.canvas.draw()
#     bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
#     fig.savefig('figures/legend.pdf', bbox_inches=bbox)

REDDIT 1 TGAT embed
REDDIT 6 TGAT embed
REDDIT 7 TGAT embed
REDDIT 8 TGAT embed
REDDIT 9 TGAT embed
> [0;32m/var/folders/g0/tq1p162j3s179yw_6h8mlbqr0000gn/T/ipykernel_55522/4190354618.py[0m(25)[0;36m<module>[0;34m()[0m
[0;32m     23 [0;31m    [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeans[0m[0;34m)[0m [0;34m!=[0m [0mlen[0m[0;34m([0m[0mscans[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m    [0mplt[0m[0;34m.[0m[0merrorbar[0m[0;34m([0m[0mx[0m[0;34m=[0m[0mintscans[0m[0;34m,[0m [0my[0m[0;34m=[0m[0mmeans[0m[0;34m,[0m [0myerr[0m[0;34m=[0m[0mstds[0m[0;34m,[0m [0mfmt[0m[0;34m=[0m[0mfmts[0m[0;34m[[0m[0mi[0m[0;34m][0m[0;34m,[0m [0mcapsize[0m[0;34m=[0m[0;36m5[0m[0;34m,[0m [0mlabel[0m[0;34m=[0m[0mshow_datasets[0m[0;34m[[0m