In [4]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from loader import PathDataModule
from tqdm import tqdm

# --- 1. Configuration and Data Loading ---
# Path of the experiment configuration, relative to the notebook's working directory.
config_path = 'full.json'
# Read the JSON inside a context manager so the file handle is closed as soon as
# parsing finishes (a bare `json.load(open(...))` leaves it open until the
# object happens to be garbage-collected).
with open(config_path, 'r') as config_file:
    config_data = json.load(config_file)
print(config_data)

Matplotlib is building the font cache; this may take a moment.


{'dataset': 'ogbl-collab', 'storage_dir': '../data/', 'embedding_config': './full_embedding.json', 'train_ratio': 0.3, 'hidden_dim': 128, 'max_hops': 4, 'num_threads': 'vast', 'max_epochs': 10, 'batch_size': 'vast', 'dim_feedforward': 128, 'nhead': 4, 'num_layers': 4, 'dropout': 0.1, 'store': 'model', 'save_text_embeddings': True, 'shallow': True, 'pre_scan': ['train'], 'adjust_no_neg_paths_samples': True, 'max_adjust': 5.0, 'positive_deviation': True, 'embedding': 'all', 'test_time': 4, 'num_ckpt': 2, 'scale_loss': True, 'chi2': False, 'lr': 0.001, 'wandb_project': 'thesis-graph'}


In [2]:
import os, requests, json, pprint

# Ask the Vast.ai instance API how many effective vCPUs this container has.
# Every failure mode (missing environment variables, network error, HTTP error
# status, unexpected payload) is reported and the notebook carries on.
try:
    cid  = os.getenv("CONTAINER_ID")          # ← set by Vast.ai inside every container
    key  = os.getenv("CONTAINER_API_KEY")     # ← scoped token for this one instance
    # Explicit check rather than `assert`: assertions are removed when Python runs
    # with -O, which would let a None id/key reach the HTTP request below.
    if not (cid and key):
        raise RuntimeError("Not running on a Vast.ai container!")

    resp = requests.get(
        f"https://console.vast.ai/api/v0/instances/{cid}/",
        headers={"Authorization": f"Bearer {key}",
                "accept": "application/json"},
        timeout=10,
    )
    # Turn 4xx/5xx responses into an exception instead of trying to parse an
    # error body as if it were instance data.
    resp.raise_for_status()

    info = resp.json()
    print("Effective vCPUs:", info['instances']["cpu_cores_effective"])
    # print(info['instances'].keys())
except Exception as e:
    # Deliberately broad: this cell is informational only and must not stop the run.
    print("Not running on Vast.ai container, or failed to fetch instance info.")
    print("Error:", e)

Not running on Vast.ai container, or failed to fetch instance info.
Error: Not running on a Vast.ai container!


In [13]:
import json
import os

# The main config points at a second JSON file holding the embedding settings.
embedding_cfg_path = config_data['embedding_config']
print("Embedding config path:", embedding_cfg_path)

# Parse that file; the handle is released when the `with` block exits.
with open(embedding_cfg_path, 'r') as f:
    embedding_cfg = json.loads(f.read())

# Bare expression so the notebook renders the parsed config.
embedding_cfg

Embedding config path: ./full_embedding.json


{'batch_size': 4096,
 'lr': 0.001,
 'epochs': 10,
 'model_name': 'transe',
 'hidden_channels': 128,
 'p_norm': 2}

In [16]:


# --- 2. Load exported JSON results from model.py ---

# Construct the path to your export (update these variables as needed).
# Layout: <storage_dir>/<embedding model name>/<dataset>/<wandb project>.
# os.path.join is used instead of string concatenation so the result is correct
# whether or not `storage_dir` ends with a path separator, and so this cell
# builds paths the same way as the export-path cell further down.
save_dir = os.path.join(
    config_data['storage_dir'],
    embedding_cfg['model_name'],
    config_data['dataset'],
    config_data['wandb_project'],
)

print(f"Save directory: {save_dir}")


Save directory: ../data/transe/ogbl-collab/thesis-graph


In [18]:
# Choose which exported file to inspect.
wandb_name = "0"
stage = "val"  # or "test"
epoch = "0"  # e.g. "9"
test_time = False  # set True if test_time

# The file name is prefixed with "test" for test-time exports, "train" otherwise.
if test_time:
    prefix = "test"
else:
    prefix = "train"

export_path = os.path.join(
    save_dir,
    wandb_name,
    f"{prefix}_{stage}_{epoch}_raw.json",
)
print(export_path)

../data/transe/ogbl-collab/thesis-graph/0/train_val_0_raw.json


In [19]:
# Read the whole raw export into memory and parse it as JSON.
with open(export_path, "r") as f:
    export_items = json.loads(f.read())

# Report how many records were read.
print(f"Loaded {len(export_items)}")

Loaded 160084


In [52]:
# Display the first exported record to inspect its fields.
export_items[0]

{'score': 0.09943151473999023,
 'length': 1,
 'label': 1,
 'has_neg': True,
 'pos_dist': [11.837754249572754],
 'neg_dist': [[10.281356811523438],
  [10.281356811523438],
  [10.281356811523438],
  [10.281356811523438],
  [10.34947681427002],
  [11.837754249572754],
  [11.837754249572754],
  [11.837754249572754],
  [10.65380859375],
  [10.65380859375],
  [10.65380859375],
  [9.960492134094238],
  [9.960492134094238],
  [11.205546379089355],
  [11.113570213317871],
  [11.833039283752441],
  [10.587995529174805],
  [11.491482734680176],
  [11.343835830688477],
  [11.866433143615723]],
 'adjusted_score': 4.09943151473999}

In [67]:
# Hyper-parameters needed to recompute the adjusted score, read from the main config.
max_hops, max_adjust = config_data['max_hops'], config_data['max_adjust']
print(f"Max hops: {max_hops}, Max adjust: {max_adjust}")

Max hops: 4, Max adjust: 5.0


In [68]:
# Re-code records whose length is 0 to max_hops + 2. The records are modified
# in place; every other record is left untouched.
for item in tqdm(export_items):
    if item['length'] != 0:
        continue
    item['length'] = max_hops + 2

# Smallest length present after the re-coding above.
min_length = min(export_items, key=lambda rec: rec['length'])['length']
print(f"Minimum length: {min_length}")

100%|██████████| 160084/160084 [00:00<00:00, 813324.63it/s]


Minimum length: 1


In [69]:
# Recompute `adjusted_score` for every record (written back in place):
#   ratio          = 1 - (length - min_length) / (max_hops + 1 - min_length)
#   adjusted_score = score + ratio * max_adjust
# The shortest length gets the full `max_adjust` bonus; the bonus falls off
# linearly with length.
length_span = max_hops + 1 - min_length
for item in tqdm(export_items):
    hops_over_min = item['length'] - min_length
    ratio = 1 - hops_over_min / length_span
    item['adjusted_score'] = item['score'] + ratio * max_adjust

  0%|          | 0/160084 [00:00<?, ?it/s]

100%|██████████| 160084/160084 [00:00<00:00, 273830.31it/s]


In [72]:
# Display the record at index 3 to check its recomputed `adjusted_score`.
export_items[3]

{'score': 0.0,
 'length': 6,
 'label': 1,
 'has_neg': False,
 'pos_dist': None,
 'neg_dist': None,
 'adjusted_score': -1.25}

In [79]:
# Negative-labelled records (label == 0) whose `has_neg` flag is truthy.
neg_items = [
    rec
    for rec in export_items
    if rec['label'] == 0 and rec['has_neg']
]
print(f"Filtered non-zero items: {len(neg_items)}")

Filtered non-zero items: 7083


In [80]:
# Display one filtered negative record (index 200) for inspection.
neg_items[200]

{'score': 0.1156584620475769,
 'length': 4,
 'label': 0,
 'has_neg': True,
 'pos_dist': [10.880391120910645,
  9.0462646484375,
  11.833160400390625,
  11.009916305541992],
 'neg_dist': [[10.880391120910645,
   9.187994956970215,
   9.84019947052002,
   9.160685539245605],
  [10.880391120910645, 9.187994956970215, 9.84019947052002, 9.070316314697266],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.232804298400879],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.232804298400879],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.232804298400879],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.158455848693848],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.158455848693848],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.158455848693848],
  [10.880391120910645, 9.0462646484375, 8.9961576461792, 9.072173118591309],
  [11.69300651550293, 9.3045654296875, 9.246265411376953, 9.40832233428955],
  [11.834572

In [81]:
# Positive-labelled records (label == 1) whose `has_neg` flag is truthy.
pos_items = [
    rec
    for rec in export_items
    if rec['label'] == 1 and rec['has_neg']
]
print(f"Filtered positive items: {len(pos_items)}")

Filtered positive items: 47553


In [82]:
# Display one filtered positive record (index 100) for inspection.
pos_items[100]

{'score': 0.10578328371047974,
 'length': 4,
 'label': 1,
 'has_neg': True,
 'pos_dist': [10.618992805480957,
  11.806119918823242,
  10.071521759033203,
  11.036768913269043],
 'neg_dist': [[10.258853912353516,
   9.220107078552246,
   11.754127502441406,
   9.055917739868164],
  [10.258853912353516,
   9.220107078552246,
   11.754127502441406,
   9.055917739868164],
  [10.258853912353516,
   9.220107078552246,
   11.754127502441406,
   9.055917739868164],
  [10.258853912353516,
   9.220107078552246,
   11.754127502441406,
   9.055917739868164],
  [10.258853912353516,
   9.220107078552246,
   11.754127502441406,
   9.055917739868164],
  [10.258853912353516,
   9.220107078552246,
   11.754127502441406,
   9.055917739868164],
  [10.258853912353516,
   9.220107078552246,
   11.754127502441406,
   9.055917739868164],
  [10.258853912353516,
   9.220107078552246,
   11.754127502441406,
   10.539167404174805],
  [9.720569610595703,
   11.568655014038086,
   9.324789047241211,
   9.1733274459

In [83]:
K = 50
# Rank the negative records by adjusted score, highest first, and keep the
# first K of them.
top_neg = list(neg_items)
top_neg.sort(key=lambda rec: rec['adjusted_score'], reverse=True)
top_neg = top_neg[:K]
print(f"Top {K} negative items: {len(top_neg)}")

Top 50 negative items: 50


In [84]:
# `top_neg` is sorted by adjusted score in descending order, so its last entry
# is the lowest-scoring negative among those retained.
top_neg[-1]

{'score': 0.04810476303100586,
 'length': 2,
 'label': 0,
 'has_neg': True,
 'pos_dist': [11.907910346984863, 11.800949096679688],
 'neg_dist': [[11.65769100189209, 9.036953926086426],
  [11.65769100189209, 9.158842086791992],
  [11.65769100189209, 9.09864616394043],
  [11.65769100189209, 9.19129467010498],
  [11.65769100189209, 9.616718292236328],
  [11.65769100189209, 9.363011360168457],
  [11.65769100189209, 9.611128807067871],
  [11.65769100189209, 9.94509506225586],
  [11.65769100189209, 9.872747421264648],
  [11.65769100189209, 10.084528923034668],
  [11.65769100189209, 10.084528923034668],
  [11.65769100189209, 10.397904396057129],
  [11.65769100189209, 10.541487693786621],
  [11.65769100189209, 10.541487693786621],
  [11.65769100189209, 10.26445484161377],
  [11.65769100189209, 10.238680839538574],
  [11.65769100189209, 10.544231414794922],
  [10.154446601867676, 10.298988342285156],
  [11.65769100189209, 11.149565696716309],
  [11.65769100189209, 11.149565696716309]],
 'adjust

In [None]:
# The last retained negative (lowest adjusted score in `top_neg`) defines the
# cut-off score used in the comparison against positives.
cutoff_item = top_neg[-1]
top_neg_score = cutoff_item['adjusted_score']
print(f"Top negative score: {top_neg_score}")

Top negative score: 3.798104763031006


In [86]:
# Order every positive record by adjusted score, highest first. A copy is
# sorted, so `pos_items` keeps its original order.
top_pos = list(pos_items)
top_pos.sort(key=lambda rec: rec['adjusted_score'], reverse=True)
print(f"Top positive items: {len(top_pos)}")

Top positive items: 47553


In [88]:
# `top_pos` is sorted by adjusted score in descending order, so its last entry
# is the lowest-scoring positive record.
top_pos[-1]

{'score': 0.0,
 'length': 4,
 'label': 1,
 'has_neg': True,
 'pos_dist': [8.87608814239502,
  9.663317680358887,
  11.870468139648438,
  9.194988250732422],
 'neg_dist': [[8.87608814239502,
   9.00528335571289,
   9.917184829711914,
   9.127494812011719],
  [8.87608814239502, 9.00528335571289, 9.917184829711914, 9.127494812011719],
  [8.87608814239502, 9.00528335571289, 9.917184829711914, 9.127494812011719],
  [8.87608814239502, 9.00528335571289, 10.477001190185547, 9.342772483825684],
  [8.87608814239502, 9.00528335571289, 10.477001190185547, 9.342772483825684],
  [8.87608814239502, 9.00528335571289, 10.477001190185547, 9.342772483825684],
  [8.87608814239502, 9.00528335571289, 10.477001190185547, 9.904752731323242],
  [8.87608814239502, 9.00528335571289, 10.477001190185547, 9.904752731323242],
  [8.87608814239502, 9.00528335571289, 10.477001190185547, 9.904752731323242],
  [8.87608814239502, 9.00528335571289, 10.477001190185547, 9.904752731323242],
  [8.87608814239502, 9.005283355712

In [100]:
# Positives whose adjusted score is strictly above the cut-off set by the last
# entry of `top_neg`. The order of `top_pos` (descending) is preserved.
higher_pos = [
    rec
    for rec in top_pos
    if top_neg_score < rec['adjusted_score']
]
len(higher_pos)

28221

In [101]:
# Fraction of the filtered positives whose adjusted score exceeds the top-K
# negative cut-off (`top_neg_score`).
len(higher_pos) / len(pos_items)

0.593464134754905

In [212]:
# Display the negative at index 2 of the descending ranking, i.e. the one with
# the third-highest adjusted score.
top_neg[2]

{'score': 0.5525299906730652,
 'length': 1,
 'label': 0,
 'has_neg': True,
 'pos_dist': [10.340187072753906],
 'neg_dist': [[10.9002046585083],
  [10.036107063293457],
  [10.908438682556152],
  [11.678206443786621],
  [11.931611061096191],
  [9.94173812866211],
  [10.060835838317871],
  [9.959172248840332],
  [10.111923217773438],
  [10.485343933105469],
  [11.073307037353516],
  [10.134439468383789],
  [10.20603084564209],
  [10.360928535461426],
  [10.773707389831543],
  [10.060410499572754],
  [11.012125968933105],
  [9.097715377807617],
  [10.538289070129395],
  [9.367842674255371]],
 'adjusted_score': 5.552529990673065}

In [210]:
# `higher_pos` keeps the descending order of `top_pos`, so its last entry is the
# lowest-scoring positive that still beats the negative cut-off.
higher_pos[-1]

{'score': 0.05698955059051514,
 'length': 2,
 'label': 1,
 'has_neg': True,
 'pos_dist': [11.744134902954102, 9.308808326721191],
 'neg_dist': [[10.79800033569336, 9.019627571105957],
  [10.79800033569336, 9.019627571105957],
  [10.79800033569336, 9.019627571105957],
  [10.79800033569336, 9.019627571105957],
  [10.79800033569336, 9.019627571105957],
  [10.79800033569336, 9.019627571105957],
  [10.20862865447998, 9.302032470703125],
  [10.20862865447998, 9.302032470703125],
  [10.20862865447998, 9.254949569702148],
  [10.20862865447998, 9.254949569702148],
  [10.79800033569336, 8.999908447265625],
  [10.057108879089355, 9.060296058654785],
  [10.057108879089355, 9.060296058654785],
  [10.79800033569336, 9.187488555908203],
  [10.057108879089355, 9.345187187194824],
  [10.057108879089355, 9.345187187194824],
  [10.057108879089355, 9.345187187194824],
  [10.057108879089355, 9.345187187194824],
  [11.744134902954102, 9.238731384277344],
  [10.20862865447998, 9.516009330749512]],
 'adjusted

In [202]:
# Pick a single record to walk through the z-score computation in the cells below.
# Alternative example kept for quick switching:
# example = top_pos[30000]
# top_neg[-30] is the 30th entry counted from the end of the descending ranking.
example = top_neg[-30]
print(f"Example item: {example}")

Example item: {'score': 0.8055481314659119, 'length': 2, 'label': 0, 'has_neg': True, 'pos_dist': [9.02879810333252, 11.814641952514648], 'neg_dist': [[10.189638137817383, 11.470064163208008], [10.908744812011719, 11.718202590942383], [11.769558906555176, 11.77409553527832], [11.769558906555176, 11.77409553527832], [11.769558906555176, 12.102071762084961], [11.769558906555176, 12.102071762084961], [10.075501441955566, 11.720455169677734], [10.47824478149414, 10.513073921203613], [10.485727310180664, 10.92050552368164], [11.769558906555176, 11.358835220336914], [11.769558906555176, 11.358835220336914], [11.769558906555176, 11.358835220336914], [11.769558906555176, 11.358835220336914], [10.54875373840332, 10.767017364501953], [10.389164924621582, 11.487698554992676], [10.485727310180664, 10.481659889221191], [11.002202987670898, 10.045196533203125], [11.002202987670898, 10.045196533203125], [11.002202987670898, 10.045196533203125], [11.769558906555176, 11.552481651306152]], 'adjusted_sco

In [203]:
# Convert the example's distance lists into NumPy arrays for vectorised math.
pos, neg = (
    np.array(example['pos_dist']),
    np.array(example['neg_dist']),
)

In [204]:
# Show the shapes of the two distance arrays built from the example record.
pos.shape, neg.shape

((2,), (20, 2))

In [205]:
# Statistics of the negative distances along axis 0. ddof=0 gives the
# population standard deviation.
mean = neg.mean(axis=0)
std = neg.std(axis=0, ddof=0)
print(mean, std)

[11.12470708 11.1977212 ] [0.62915026 0.65042252]


In [206]:
# z-score of each positive distance against the negative statistics.
# The 1e-8 term avoids division by zero when a column of `neg` is constant.
centered = pos - mean
z = centered / (std + 1e-8)
mean_z = z.mean()
print(z, mean_z)

[-3.33133286  0.94849229] -1.1914202810395529


In [207]:
# arcsinh(mean z)
# Apply the inverse hyperbolic sine to the mean z-score. arcsinh keeps the sign
# and grows only logarithmically for large magnitudes.
# NOTE(review): presumably this mirrors the score transform used in model.py — confirm.
arcsinh_mean_z = np.arcsinh(mean_z)
print(f"Arcsinh mean z: {arcsinh_mean_z}")

Arcsinh mean z: -1.0104689225626229
