In [1]:
import sys
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / 'src'))

import torch
import matplotlib.pyplot as plt
from graphssl.utils.plotting_utils import (
    plot_training_curves,
    plot_downstream_results,
    plot_downstream_distribution,
    plot_all_results,
    print_results_summary
)

%matplotlib inline

In [3]:
# Define path to results directory
results_base = Path.cwd().parent / 'results'
# Fail early with a clear message: otherwise glob() silently lists nothing.
if not results_base.is_dir():
    raise FileNotFoundError(f"Results directory not found: {results_base}")

# List available result directories
available_results = sorted(results_base.glob('exp_*'))
print("Available results directories:")
for i, path in enumerate(available_results, 1):
    print(f"{i}. {path.name}")

# Select one (modify this to your experiment).
# Example: a supervised node experiment (pick any name from the list above).
results_path = results_base / 'exp_supervised_node_27217234_20251130_165042'
# Without this check, a typo only surfaces later as an obscure error inside the plotting helpers.
if not results_path.is_dir():
    raise FileNotFoundError(f"Selected experiment not found: {results_path}")
print(f"\nSelected: {results_path.name}")

Available results directories:
1. exp_ssl_edge_27209540_20251129_085046
2. exp_ssl_edge_27209654_20251129_001448
3. exp_ssl_edge_27217268_20251130_175413
4. exp_ssl_node_sce_27209538_20251129_074740
5. exp_ssl_node_sce_27209655_20251129_010045
6. exp_ssl_node_sce_27209658_20251129_025045
7. exp_ssl_node_sce_27212544_20251129_141147
8. exp_ssl_node_sce_27212545_20251129_162403
9. exp_ssl_node_sce_27217237_20251130_022403
10. exp_ssl_node_sce_27217238_20251130_024722
11. exp_ssl_node_sce_27217289_20251130_030626
12. exp_ssl_tarpfp_27209656_20251129_025015
13. exp_ssl_tarpfp_27217239_20251130_171348
14. exp_ssl_tarpfp_27217244_20251130_175251
15. exp_supervised_link_27212530_20251129_132453
16. exp_supervised_node_27209535_20251129_005635
17. exp_supervised_node_27209637_20251128_234156
18. exp_supervised_node_27212528_20251129_121332
19. exp_supervised_node_27217234_20251130_165042

Selected: exp_supervised_node_27217234_20251130_165042


## 2. Print Results Summary

Get a quick text summary of all results.

In [5]:
# Quick text summary of the result files in the selected experiment directory.
# NOTE(review): under PyTorch >= 2.6 this raised UnpicklingError (see output) because the
# loaded files contain NumPy scalars; those NumPy types must be allowlisted via
# `torch.serialization.add_safe_globals` before this call (or the loader changed).
print_results_summary(results_path)

UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy._core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([numpy._core.multiarray.scalar])` or the `torch.serialization.safe_globals([numpy._core.multiarray.scalar])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

## 3. Plot Training Curves

Visualize the training and validation metrics over epochs.

In [6]:
# Plot all available training metrics (training/validation curves over epochs).
# NOTE(review): this also raised UnpicklingError under PyTorch >= 2.6 (see output); it needs
# the same NumPy-type allowlisting in `torch.serialization` before loading results.
fig = plot_training_curves(results_path)

UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy._core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([numpy._core.multiarray.scalar])` or the `torch.serialization.safe_globals([numpy._core.multiarray.scalar])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

## 4. Plot Downstream Evaluation Results

Visualize downstream task performance with error bars.

In [None]:
# Downstream task performance with error bars; the figure is kept in `fig` for further customization.
fig = plot_downstream_results(results_path)

## 5. Plot Distribution of Test Results

Show the distribution of test metrics across multiple runs (if available).

In [None]:
# Plot distribution of test accuracies across runs.
# `metric` names the result entry to plot (here 'test_accuracies'); presumably other
# per-run metric keys work too — TODO confirm against plotting_utils.
fig = plot_downstream_distribution(results_path, metric='test_accuracies')

## 6. Generate All Plots at Once

Create all available plots and optionally save them.

In [None]:
# Produce every available plot in one call and write them under `<results_path>/plots`.
# `show=False` asks the helper not to display each figure as it is created.
save_dir = results_path / 'plots'
figures = plot_all_results(results_path, save_dir=save_dir, show=False)

# `figures` maps plot name -> figure; list the names that were produced.
figure_count = len(figures)
print(f"\nGenerated {figure_count} figures:")
for figure_name in figures:
    print(f"  - {figure_name}")

## 7. Compare Multiple Experiments

Compare results across different experiment runs.

In [None]:
from graphssl.utils.plotting_utils import load_results

# Compare multiple experiments
experiment_paths = [
    results_base / 'exp_ssl_node_sce_27209655_20251129_010045',
    results_base / 'exp_ssl_node_sce_27209658_20251129_025045',
    # Add more experiment paths here
]

# One entry per compared metric:
# (results section, bar label, key in comparison_data, mean key, std key)
COMPARISON_METRICS = [
    ('downstream_node', 'Node Acc', 'node_acc', 'test_acc_mean', 'test_acc_std'),
    ('downstream_link_multiclass', 'Link F1', 'link_f1', 'test_f1_mean', 'test_f1_std'),
]


def short_experiment_name(exp_path):
    """Return a compact but still-distinguishable plot label for an experiment directory.

    Truncating to a fixed width (the first 20 characters) collapses runs of the same
    method onto one label: both default experiments become "exp_ssl_node_sce_272...".
    Instead, drop the leading ``exp_`` and two trailing all-digit fields (which look
    like a date/time stamp), keeping the method name and job ID.

    Args:
        exp_path: Path to an experiment directory.

    Returns:
        The shortened directory name.
    """
    name = exp_path.name
    if name.startswith('exp_'):
        name = name[len('exp_'):]
    parts = name.rsplit('_', 2)
    # Only strip the suffix when both trailing fields are numeric.
    if len(parts) == 3 and parts[1].isdigit() and parts[2].isdigit():
        name = parts[0]
    return name


def extract_comparison_metrics(exp_path, results):
    """Collect the compared mean/std metrics from one experiment's loaded results.

    Args:
        exp_path: Path to the experiment directory (used for the label).
        results: Mapping returned by ``load_results`` (section name -> metrics dict).

    Returns:
        Dict with 'name' plus '<key>' and '<key>_std' for each section present.
    """
    metrics = {'name': short_experiment_name(exp_path)}
    for section, _label, key, mean_key, std_key in COMPARISON_METRICS:
        if section in results:
            # NaN (not 0) for a missing mean, so no bar is drawn instead of a fake zero score.
            metrics[key] = results[section].get(mean_key, float('nan'))
            metrics[f'{key}_std'] = results[section].get(std_key, 0)
    return metrics


comparison_data = []
for exp_path in experiment_paths:
    if not exp_path.exists():
        print(f"Skipping missing experiment directory: {exp_path}")
        continue
    results = load_results(exp_path)
    comparison_data.append(extract_comparison_metrics(exp_path, results))

# Plot every metric reported by at least one experiment (not only the first one).
plotted_metrics = [
    metric for metric in COMPARISON_METRICS
    if any(metric[2] in d for d in comparison_data)
]

# Plot comparison
if comparison_data and plotted_metrics:
    fig, ax = plt.subplots(figsize=(10, 6))

    names = [d['name'] for d in comparison_data]
    x = range(len(names))
    # Share 0.8 of each x slot between the plotted metrics (0.4 each when there are two).
    width = 0.8 / len(plotted_metrics)

    for j, (_section, label, key, _mean_key, _std_key) in enumerate(plotted_metrics):
        # Center the group of bars on each tick (-0.2 / +0.2 for two metrics).
        offset = (j - (len(plotted_metrics) - 1) / 2) * width
        values = [d.get(key, float('nan')) for d in comparison_data]
        errors = [d.get(f'{key}_std', 0) for d in comparison_data]
        ax.bar([i + offset for i in x], values, width, yerr=errors,
               label=label, alpha=0.7, capsize=5)

    ax.set_xlabel('Experiment')
    ax.set_ylabel('Score')
    ax.set_title('Experiment Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.show()
elif comparison_data:
    print("No downstream metrics found in the selected experiments")
else:
    print("No valid experiments found for comparison")

## Summary

The plotting utilities provide several key functions:

- **`load_results(results_path)`**: Load all available result files from a directory
- **`plot_training_curves(results_path)`**: Visualize training/validation metrics over epochs
- **`plot_downstream_results(results_path)`**: Show downstream task performance with error bars
- **`plot_downstream_distribution(results_path, metric)`**: Display distribution of metrics across runs
- **`plot_all_results(results_path, save_dir)`**: Generate all plots at once
- **`print_results_summary(results_path)`**: Print text summary of results

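All functions accept a `results_path` (the `Path` to an experiment directory). Their return values differ:

- The `plot_*` functions return matplotlib figures that can be further customized.
- `plot_all_results` returns a mapping from plot name to figure.
- `load_results` returns the loaded results keyed by section (e.g. `downstream_node`).
- `print_results_summary` prints a text summary.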