In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before
# each cell runs, so edits to project files are picked up without a restart.
%load_ext autoreload
%autoreload 2

# Importing Libraries

In [None]:
import pandas as pd
from data_store import nc_datasets, gc_datasets, lp_datasets
from tabulate import tabulate
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

In [None]:
# Restore matplotlib's built-in default rcParams so the settings applied in
# the following cells start from a known state.
plt.rcdefaults()

In [None]:
# Render figures at 300 dpi.
mpl.rcParams['figure.dpi'] = 300
# Optional LaTeX-style typography and line/grid settings, currently disabled.
# mpl.rcParams['text.usetex'] = True
# mpl.rcParams['font.family'] = 'serif'
# mpl.rcParams['font.serif'] = ['Computer Modern']
# mpl.rcParams['text.latex.preamble'] = '\\usepackage{amsmath}\\usepackage{amssymb}'
# mpl.rcParams['font.size'] = 11
# mpl.rcParams['lines.linewidth'] = 2
# mpl.rcParams['lines.markersize'] = 6
# mpl.rcParams['grid.linestyle'] = '--'
# mpl.rcParams['grid.linewidth'] = 0.5
# mpl.rcParams['axes.grid'] = True

In [None]:
def get_figsize(columnwidth, wf=0.5, hf=(5.**0.5-1.0)/2.0):
    r"""Compute a matplotlib figure size in inches from a LaTeX column width.

    Parameters:
      - columnwidth [float]: width of the column in LaTeX points. Get this
                             from LaTeX using \showthe\columnwidth
      - wf [float]:  figure width as a fraction of columnwidth
      - hf [float]:  figure height as a fraction of the figure width.
                     Defaults to (sqrt(5) - 1) / 2, about 0.618.
    Returns:  [fig_width, fig_height] in inches, to be given to matplotlib
    """
    # The docstring is a raw string so the LaTeX backslashes are not parsed
    # as (invalid) escape sequences.
    inches_per_pt = 1.0 / 72.27  # 72.27 TeX points per inch
    fig_width = columnwidth * wf * inches_per_pt
    fig_height = fig_width * hf
    return [fig_width, fig_height]

# Width of the LaTeX text column, in points.
column_width = 426.0
# Square figure spanning 60% of the column; plot_metrics reads this value.
fig_size = get_figsize(columnwidth=column_width, wf=0.6, hf=1.0)
fig_size

In [None]:
# Experiment grid: architectures, logged metric names and task prefixes.
models = ['gcn', 'graphsage', 'gat', 'gin']
metrics = ['loss', 'accuracy', 'f1', 'precision', 'recall']
tasks = ['nc', 'gc', 'lp']
log_dir = 'logs'

# Two-digit suffixes '00' .. '10' appended to some run-directory names.
stds = [f'{level:02d}' for level in range(11)]
std = stds[1]

sns.set_style("whitegrid")
# Two neighbouring "Paired" colours per model, consumed by plot_metrics as
# (train colour, validation colour).
paired_colors = sns.color_palette("Paired", n_colors=2 * len(models))
model_colors = dict(zip(models, zip(paired_colors[0::2], paired_colors[1::2])))

In [None]:
def plot_metrics(metrics, model_metrics_df, dataset, save=False, max_epoch=None):
    """Plot train/validation curves of every model, one figure per metric.

    Parameters:
        metrics (list[str]): Metric names. Each dataframe must contain the
                             columns ``train_<metric>`` and ``val_<metric>``.
        model_metrics_df (dict): Maps a model name to its metrics dataframe.
                                 Each dataframe needs an ``epoch`` column and
                                 each model name must be a key of the
                                 module-level ``model_colors``.
        dataset (str): Dataset name, used only in the saved file name.
        save (bool): When True, write each figure to
                     ``trainfig/<dataset>_<metric>.pdf``.
        max_epoch (int | None): Plot only the first ``max_epoch`` points of
                                each curve. ``None`` plots every logged point
                                of each model.

    Raises:
        KeyError: if a required column or a model colour is missing.
    """
    import os  # local import: only needed for the optional save step

    out_dir = 'trainfig'
    if save:
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)

    for metric in metrics:
        # Figure size comes from the module-level ``fig_size``.
        fig, ax = plt.subplots(figsize=fig_size)

        for model_name, df in model_metrics_df.items():
            # Drop the NaN entries of each column so every curve is contiguous.
            train_metric = df[f"train_{metric}"].dropna().reset_index(drop=True)
            val_metric = df[f"val_{metric}"].dropna().reset_index(drop=True)
            epochs = df["epoch"].drop_duplicates().reset_index(drop=True)

            # Resolve the cut-off per model in a local variable. Assigning it
            # back to ``max_epoch`` would make the first model's epoch count
            # truncate every later model and metric.
            limit = len(epochs) if max_epoch is None else max_epoch

            train_color, val_color = model_colors[model_name]
            sns.lineplot(x=epochs[:limit], y=train_metric[:limit], label=f"{model_name} train",
                         linestyle="solid", color=train_color, ax=ax)
            sns.lineplot(x=epochs[:limit], y=val_metric[:limit], label=f"{model_name} val",
                         linestyle="dashed", color=val_color, ax=ax)

        ax.set_xlabel("Epochs")
        ax.set_ylabel(metric.capitalize())
        fig.tight_layout()
        if save:
            fig.savefig(f'{out_dir}/{dataset}_{metric}.pdf', format='pdf')
        plt.show()  # Show each figure separately
        plt.close(fig)  # release the figure so repeated calls do not accumulate
        
def print_test_metrics_table(model_metrics):
    """
    Render the final test metrics of all models as one grid table on stdout.

    Parameters:
        model_metrics (dict): Maps each model name to a dataframe holding the
                              test_* metric columns. The last row of every
                              dataframe is taken as the final test result.
    """
    test_columns = ["test_loss", "test_accuracy", "test_f1", "test_precision", "test_recall"]
    column_titles = ["Model", "Test Loss", "Test Accuracy", "Test F1", "Test Precision", "Test Recall"]

    # One table row per model: its name followed by the last-row test values.
    rows = [
        [name, *frame.iloc[-1][test_columns].values]
        for name, frame in model_metrics.items()
    ]

    print(tabulate(rows, headers=column_titles, tablefmt="fancy_grid", floatfmt=".4f"))

# Link Prediction

In [None]:
# Select the 'lp' task prefix; the log paths in the cells below are built from it.
task = tasks[2]

## Cora

In [None]:
dataset = lp_datasets[0]
# Location of each model's logged metrics CSV for this task and dataset
# (this run-directory name carries no std suffix).
paths = {
    model: f'{log_dir}/{task}_{model}_{dataset}/version_0/metrics.csv'
    for model in models
}
# Load every log into a dataframe keyed by model name.
model_metrics = {
    model_name: pd.read_csv(csv_path)
    for model_name, csv_path in paths.items()
}
plot_metrics(metrics, model_metrics, dataset, save=True)

In [None]:
# Tabulate the test metrics of the runs loaded in the previous cell.
print_test_metrics_table(model_metrics)

# Graph Classification

In [None]:
# Select the 'gc' task prefix; the log paths in the cells below are built from it.
task = tasks[1]

## BA-2motif

In [None]:
dataset = gc_datasets[0]
# Location of each model's logged metrics CSV; this run-directory name ends
# with the selected std suffix.
paths = {
    model: f'{log_dir}/{task}_{model}_{dataset}{std}/version_0/metrics.csv'
    for model in models
}
# Load every log into a dataframe keyed by model name.
model_metrics = {
    model_name: pd.read_csv(csv_path)
    for model_name, csv_path in paths.items()
}
plot_metrics(metrics, model_metrics, dataset, save=True)

In [None]:
# Tabulate the test metrics of the runs loaded in the previous cell.
print_test_metrics_table(model_metrics)

## MUTAG

In [None]:
dataset = gc_datasets[1]
# Location of each model's logged metrics CSV for this task and dataset
# (this run-directory name carries no std suffix).
paths = {
    model: f'{log_dir}/{task}_{model}_{dataset}/version_0/metrics.csv'
    for model in models
}
# Load every log into a dataframe keyed by model name.
model_metrics = {
    model_name: pd.read_csv(csv_path)
    for model_name, csv_path in paths.items()
}
plot_metrics(metrics, model_metrics, dataset, save=True)

In [None]:
# Tabulate the test metrics of the runs loaded in the previous cell.
print_test_metrics_table(model_metrics)

# Node Classification

In [None]:
# Select the 'nc' task prefix; the log paths in the cells below are built from it.
task = tasks[0]

## BA-Shapes

In [None]:
dataset = nc_datasets[0]
# Location of each model's logged metrics CSV; this run-directory name ends
# with the selected std suffix.
paths = {
    model: f'{log_dir}/{task}_{model}_{dataset}{std}/version_0/metrics.csv'
    for model in models
}
# Load every log into a dataframe keyed by model name.
model_metrics = {
    model_name: pd.read_csv(csv_path)
    for model_name, csv_path in paths.items()
}
plot_metrics(metrics, model_metrics, dataset, save=True)

In [None]:
# Tabulate the test metrics of the runs loaded in the previous cell.
print_test_metrics_table(model_metrics)

## Tree-Grid

In [None]:
dataset = nc_datasets[1]
# Location of each model's logged metrics CSV; this run-directory name ends
# with the selected std suffix.
paths = {
    model: f'{log_dir}/{task}_{model}_{dataset}{std}/version_0/metrics.csv'
    for model in models
}
# Load every log into a dataframe keyed by model name.
model_metrics = {
    model_name: pd.read_csv(csv_path)
    for model_name, csv_path in paths.items()
}
plot_metrics(metrics, model_metrics, dataset, save=True)

In [None]:
# Tabulate the test metrics of the runs loaded in the previous cell.
print_test_metrics_table(model_metrics)