In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import seaborn as sns
from matplotlib import axes
import os 
import pandas as pd
import json

import sys, os
# Extend the module search path before the clutils imports below.
# NOTE(review): hard-coded, user-specific absolute path — presumably the project
# root that provides `clutils`; confirm and consider making it configurable.
sys.path.append('/cluster/home/kamara/Explain')
# NOTE(review): the wildcard import hides where helper names used later in this
# notebook (e.g. parseLogs, enumerateParams) come from — presumably from here.
from clutils.nbutils import *
from clutils.nbutils.params import get_param_ranges
# The return value is discarded: this is not the last expression of the cell,
# so nothing is displayed.
os.getcwd()
# Apply seaborn's default theme to subsequent matplotlib figures.
sns.set()

In [None]:
def get_info(logdir, dataset=None):
    """Collect parsed logs and GNN train/test scores from a sweep's log directory.

    Only files whose name starts with ``_`` and ends with ``.stdout`` are
    considered. Every selected file is parsed three times with ``parseLogs``,
    once per keyword (``__logs``, ``__gnn_train_scores``,
    ``__gnn_test_scores``); the second element of each result is stored under
    the filename with ``.stdout`` removed.

    Parameters
    ----------
    logdir : str
        Directory that contains the ``.stdout`` files.
    dataset : str, optional
        Substring that must occur in the filename for the run to be kept.
        ``None`` (the default) keeps every run, so ``get_info(logdir)`` works.

    Returns
    -------
    tuple of dict
        ``(logs, train_scores, test_scores)``, each keyed by run name.
    """
    logs, train_scores, test_scores = {}, {}, {}
    # os.listdir returns entries in arbitrary order; sort so the key order of
    # the returned dicts is reproducible between runs.
    for filename in sorted(os.listdir(logdir)):
        if not (filename.startswith("_") and filename.endswith(".stdout")):
            continue
        if dataset is not None and dataset not in filename:
            continue
        run_name = filename.replace(".stdout", "")
        path = os.path.join(logdir, filename)
        _, logs[run_name] = parseLogs(path, kw='__logs')
        _, train_scores[run_name] = parseLogs(path, kw='__gnn_train_scores')
        _, test_scores[run_name] = parseLogs(path, kw='__gnn_test_scores')
    return (logs, train_scores, test_scores)

def get_df_results(logs, ranges, name, metrics, selection = 'last'):
    """Build a DataFrame with one record per parameter set (or per logged row).

    For every parameter set produced by ``enumerateParams(ranges)`` the run
    name is obtained with ``name.format(**params_set)``. Runs that are missing
    from ``logs``, have an empty log, or contain none of ``metrics`` are
    skipped.

    With ``selection == 'all'`` one record is emitted per row of the run's log,
    each value read with ``select_value(<one-row frame>, metric, 'last')``.
    Otherwise a single record is emitted per run, each value read with
    ``select_value(<run log>, metric, selection)``. A metric that is absent
    from the run's log is reported as ``-1``.

    Returns the records (parameter set merged with metric values through
    ``dictmerge``) as a ``pandas.DataFrame``.
    """
    records = []
    for combo in enumerateParams(ranges):
        run_id = name.format(**combo)
        if run_id not in logs:
            continue
        run_log = logs[run_id]
        if len(run_log) < 1:
            continue

        # Which of the requested metrics are present in this run's log.
        available = [metric in run_log for metric in metrics]
        if not any(available):
            continue

        if selection == 'all':
            # One single-row frame per logged row, always read with 'last'.
            sources = [(row.to_frame().T, 'last') for _, row in run_log.iterrows()]
        else:
            sources = [(run_log, selection)]

        for frame, how in sources:
            values = {}
            for metric, present in zip(metrics, available):
                # -1 marks a metric that is missing from this run's log.
                values[metric] = select_value(frame, metric, how) if present else -1
            records.append(dictmerge(combo, values))

    return pd.DataFrame(records)

# Gridsearch GNN for FacebookPagePage

In [None]:
# Sweep artefacts of the FacebookPagePage learning-rate grid search.
# NOTE(review): absolute, user-specific path — consider moving it to a config cell.
sweep_dir = '/cluster/home/kamara/Explain/checkpoints/node_classification/facebook/gridsearch_facebook_lr'
logdir = os.path.join(sweep_dir, 'logs')
jsonpath = os.path.join(sweep_dir, 'sweep.json')

# get_info keeps the files whose name contains `dataset`; the empty string
# matches every run found in this (already dataset-specific) directory.
logs, train_scores, test_scores = get_info(logdir, dataset='')

# Show one run name so the `name` template below can be checked against it.
run_names = list(logs.keys())
print(run_names[1] if len(run_names) > 1 else run_names)
name = '{none}_explainer_name={explainer_name}_sparsity={sparsity}_dataset={dataset}_hard_mask={hard_mask}'

ranges = get_param_ranges(jsonpath)

In [None]:
# One colour per learning rate. The palette is built in this cell so that the
# cell also runs on a fresh kernel (it previously relied on a later cell).
lr_values = list(ranges['lr'])
lr_palette = sns.color_palette("Paired", max(len(lr_values), 1))

plt.rcParams['axes.grid'] = True
fig, ax = plt.subplots(1, figsize=(12, 6))

lr_handles = {}       # learning rate -> train line, used for the colour legend
style_handles = None  # (train line, val line) of the first plotted run, used for the line-style legend
for params_set in enumerateParams(ranges):
    key = name.format(**params_set)
    run_log = logs.get(key)
    # Skip parameter combinations without a log file or with an empty log.
    if run_log is None or run_log.empty:
        continue

    lr = params_set['lr']
    colour = lr_palette[lr_values.index(lr)]
    # Solid line: training error; dash-dotted line: validation error.
    l1, = ax.plot(run_log["train_err"], label=lr, c=colour)
    l2, = ax.plot(run_log["val_err"], c=colour, ls='dashdot')

    lr_handles.setdefault(lr, l1)
    if style_handles is None:
        style_handles = (l1, l2)

# set labels
ax.set_xlabel("epoch")
ax.set_ylabel("avg squared error")
ax.set_title("Gridsearch on learning rate for FacebookPagePage")

if lr_handles:
    # Label only the learning rates that were actually plotted, so handles and
    # labels cannot get out of step.
    shown_lrs = [lr for lr in lr_values if lr in lr_handles]
    legend1 = ax.legend([lr_handles[lr] for lr in shown_lrs], shown_lrs,
                        bbox_to_anchor=(1.15, 0.7), loc='center', title='learning rates')
    # Creating a second legend replaces the first one on the Axes, so the
    # first has to be re-added as a plain artist.
    ax.add_artist(legend1)
    ax.legend(list(style_handles), ['train', 'val'], title='error type',
              bbox_to_anchor=(1.15, 0.2), loc='center')

ax.grid(True)
plt.show()

# Gridsearch GNN for Cora, CiteSeer and PubMed

In [None]:
# Sweep artefacts of the Planetoid (Cora, CiteSeer, PubMed) grid search.
# NOTE(review): absolute, user-specific path — consider moving it to a config cell.
sweep_dir = '/cluster/home/kamara/Explain/checkpoints/node_classification/planetoid/test'
logdir = os.path.join(sweep_dir, 'logs')
jsonpath = os.path.join(sweep_dir, 'sweep.json')

# get_info keeps the files whose name contains `dataset`; the empty string
# matches every run, so all datasets in this directory are loaded.
logs, train_scores, test_scores = get_info(logdir, dataset='')

# Show one run name so the `name` template below can be checked against it.
run_names = list(logs.keys())
print(run_names[1] if len(run_names) > 1 else run_names)
name = '{none}_explainer_name={explainer_name}_sparsity={sparsity}_dataset={dataset}_hard_mask={hard_mask}'

ranges_all = get_param_ranges(jsonpath)

## Learning rate

In [None]:
# Fix weight decay and dropout at zero so that only the remaining parameters
# (in particular the learning rate) vary in the plots below.
# `logs` is a dict of per-run logs, not a DataFrame, so it cannot be filtered
# by column; the restriction is applied through `ranges` instead.
ranges = dict(ranges_all)
for fixed_param in ('weight_decay', 'dropout'):
    # Keep the sweep's own representation of zero (0 vs 0.0) and keep the entry
    # a list like the others.
    # NOTE(review): assumes every entry of ranges_all is a list of candidate
    # values, as ranges['lr'] is used later — confirm against get_param_ranges.
    ranges[fixed_param] = [value for value in ranges_all[fixed_param] if value == 0]
    assert ranges[fixed_param], f"sweep has no run with {fixed_param} == 0"

In [None]:
palette = sns.color_palette("Paired", 25)
lr_values = list(ranges['lr'])

plt.rcParams['axes.grid'] = True
fig, axs = plt.subplots(2, 2, figsize=(12, 6), sharex=True, sharey=True)

# One panel per dataset; the fourth panel of the 2x2 grid is not needed.
panels = {'Cora': axs[0, 0], 'CiteSeer': axs[0, 1], 'PubMed': axs[1, 0]}
axs[1, 1].set_visible(False)
# With sharex=True only the bottom row shows x tick labels; the panel below
# CiteSeer is hidden, so CiteSeer has to show its own.
axs[0, 1].tick_params(labelbottom=True)

lr_handles = {}       # learning rate -> train line, used for the colour legend
style_handles = None  # (train line, val line) of the first plotted run, used for the line-style legend
for params_set in enumerateParams(ranges):
    key = name.format(**params_set)
    run_log = logs.get(key)
    # Skip parameter combinations without a log file or with an empty log.
    if run_log is None or run_log.empty:
        continue

    dataset = params_set['dataset']
    panel = panels.get(dataset)
    if panel is None:
        print(f'Unknown dataset: {dataset}')
        continue

    lr = params_set['lr']
    colour = palette[lr_values.index(lr)]
    # Solid line: training error; dash-dotted line: validation error.
    l1, = panel.plot(run_log["train_err"], label=lr, c=colour)
    l2, = panel.plot(run_log["val_err"], c=colour, ls='dashdot')

    lr_handles.setdefault(lr, l1)
    if style_handles is None:
        style_handles = (l1, l2)

for title, panel in panels.items():
    panel.set_title(title)
    panel.grid(True)

# set labels
plt.setp(axs, xlabel="epoch")
plt.setp(axs, ylabel="avg squared error")

if lr_handles:
    # Label only the learning rates that were actually plotted, so handles and
    # labels cannot get out of step. Both legends are figure-level and sit in
    # the area of the unused fourth panel.
    shown_lrs = [lr for lr in lr_values if lr in lr_handles]
    fig.legend([lr_handles[lr] for lr in shown_lrs], shown_lrs,
               title='learning rate', loc='center', bbox_to_anchor=(0.65, 0.28))
    fig.legend(list(style_handles), ['train', 'val'],
               title='error type', loc='center', bbox_to_anchor=(0.82, 0.28))

fig.suptitle('Gridsearch on learning rate', fontsize=16)
plt.show()