In [None]:
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from loader import PathDataModule
from tqdm import tqdm

# --- 1. Configuration and Data Loading ---
config_path = 'full.json'
# Read the config inside a context manager so the file handle is closed right away
# (a bare open() passed to json.load() leaves closing to the garbage collector).
# JSON text is UTF-8, so don't depend on the platform's default encoding.
with open(config_path, 'r', encoding='utf-8') as config_file:
    config_data = json.load(config_file)
print(config_data)

In [None]:

# Build the path data module from the same config file, then prepare its 'fit' stage
# so the per-split edges and path lookups used below are populated.
print("Setting up DataModule...")
dm = PathDataModule(batch_size=32, config_path=config_path)
dm.setup('fit')
print("Data loaded.")


In [None]:

# --- 2. Data Extraction and Preparation ---
# Records (one dict per path) collected by the extraction loop in the next cell.
plot_data = []

# Which split to analyse. It must be a key of dm.data, dm.pos_paths and dm.neg_paths
# (presumably 'train', 'valid' or 'test' -- TODO confirm against loader.py).
# NOTE(review): only dm.setup('fit') ran above; confirm 'test' is populated by that stage.
split = 'valid'

print(f"Processing data for '{split}' split...")
# Edge table for this split; its index is the edge id and it carries a 'label' column.
edges_df = dm.data[split]
# Path lookups for this split, keyed by the stringified edge id.
pos_paths = dm.pos_paths[split]
neg_paths = dm.neg_paths[split]


In [None]:

# Build the records from scratch: appending to a list created in an earlier cell
# duplicated every row whenever this cell was re-run.
plot_data = []

# Iterate the index and the 'label' column directly rather than iterrows(), which
# builds a Series per row (slow) and does not preserve column dtypes.
for eid, label in tqdm(zip(edges_df.index, edges_df['label']),
                       total=len(edges_df), desc="Extracting paths"):
    eid_str = str(eid)  # path lookups are keyed by the stringified edge id
    label_name = 'true_link' if label == 1 else 'false_link'

    # Positive path: its 'nodes' entry lists the path's nodes; skip missing or empty ones.
    if eid_str in pos_paths and pos_paths[eid_str].get('nodes'):
        plot_data.append({
            'path_length': len(pos_paths[eid_str]['nodes']),
            'path_type': 'positive',
            'label': label_name,
        })

    # Negative paths: every stored path for this edge contributes one record.
    if eid_str in neg_paths:
        for neg_path_interleaved in neg_paths[eid_str]:
            # As per loader.py, nodes are at even indices of the interleaved sequence.
            plot_data.append({
                'path_length': len(neg_path_interleaved[::2]),
                'path_type': 'negative',
                'label': label_name,
            })


In [None]:

# Turn the per-path records into a tidy frame with one row per path and
# path_length / path_type / label columns (an empty frame if nothing was collected).
plot_df = pd.DataFrame(data=plot_data)
print("Data prepared for plotting.")


In [None]:

# --- 3. Visualization ---
# Colour per edge label, shared by both figures so their legends match.
LABEL_PALETTE = {'true_link': 'blue', 'false_link': 'red'}


def plot_path_lengths(df, plot_fn, title, **plot_kws):
    """Draw one path-length figure: length on x, path type on y, hue = edge label.

    Args:
        df: frame with 'path_length', 'path_type' and 'label' columns.
        plot_fn: a seaborn categorical plotter, e.g. ``sns.stripplot`` or ``sns.boxplot``.
        title: figure title.
        **plot_kws: extra keyword arguments forwarded to ``plot_fn``.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(16, 9))
    plot_fn(data=df, x='path_length', y='path_type', hue='label',
            palette=LABEL_PALETTE, ax=ax, **plot_kws)
    ax.set_title(title, fontsize=18)
    ax.set_xlabel('Path Length (Number of Nodes)', fontsize=12)
    ax.set_ylabel('Path Type', fontsize=12)
    # One tick per integer path length across the observed range.
    ax.set_xticks(range(int(df['path_length'].min()),
                        int(df['path_length'].max()) + 1))
    # Retitle the legend seaborn already drew. Rebuilding it with plt.legend()
    # depends on how the installed seaborn version labels its artists.
    legend = ax.get_legend()
    if legend is not None:
        legend.set_title('Edge Label')
    plt.show()


if not plot_df.empty:
    print("Generating plots...")
    sns.set_theme(style="whitegrid")

    # Strip plot: every individual path as a jittered point, dodged by edge label.
    plot_path_lengths(plot_df, sns.stripplot,
                      f'Path Length Scatter Plot for {split.capitalize()} Set',
                      jitter=0.35, alpha=0.6, dodge=True)

    # Box plot: quartile summary of the same distributions.
    plot_path_lengths(plot_df, sns.boxplot,
                      f'Path Length Distribution for {split.capitalize()} Set')
else:
    print("No data available to plot.")