In [1]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from loader import PathDataModule
from tqdm import tqdm

# --- 1. Configuration and Data Loading ---
config_path = 'full.json'
config_data = json.load(open(config_path, 'r'))
print(config_data)

{'dataset': 'ogbl-collab', 'storage_dir': '../data/', 'embedding_config': './full_embedding.json', 'train_ratio': 0.3, 'hidden_dim': 128, 'max_hops': 4, 'num_neg': 50, 'num_threads': 'vast', 'max_epochs': 10, 'batch_size': 'vast', 'dim_feedforward': 128, 'nhead': 4, 'num_layers': 4, 'dropout': 0.1, 'store': 'model', 'save_text_embeddings': True, 'shallow': True, 'pre_scan': ['train'], 'adjust_no_neg_paths_samples': True, 'max_adjust': 5.0, 'positive_deviation': True, 'embedding': 'all', 'test_time': 4, 'num_ckpt': 2, 'scale_loss': True, 'chi2': False, 'lr': 0.001, 'wandb_project': 'thesis-graph'}


In [2]:
import os, requests, json, pprint

try:
    cid  = os.getenv("CONTAINER_ID")          # ← set by Vast.ai inside every container
    key  = os.getenv("CONTAINER_API_KEY")     # ← scoped token for this one instance
    assert cid and key, "Not running on a Vast.ai container!"

    resp = requests.get(
        f"https://console.vast.ai/api/v0/instances/{cid}/",
        headers={"Authorization": f"Bearer {key}",
                "accept": "application/json"},
        timeout=10,
    )

    info = resp.json()
    print("Effective vCPUs:", info['instances']["cpu_cores_effective"])
    # print(info['instances'].keys())
except Exception as e:
    print("Not running on Vast.ai container, or failed to fetch instance info.")
    print("Error:", e)

Not running on Vast.ai container, or failed to fetch instance info.
Error: Not running on a Vast.ai container!


In [3]:
import json
import os

embedding_cfg_path = config_data['embedding_config'] # path to config file
print("Embedding config path:", embedding_cfg_path)
# Read config from json file at embedding cfg path
with open(embedding_cfg_path, 'r') as f:
    embedding_cfg = json.load(f)
    
embedding_cfg

Embedding config path: ./full_embedding.json


{'batch_size': 4096,
 'lr': 0.001,
 'epochs': 10,
 'model_name': 'transe',
 'hidden_channels': 128,
 'p_norm': 2}

In [4]:


# --- 2. Load exported JSON results from model.py ---

# Construct the path to your export (update these variables as needed)
save_dir = config_data['storage_dir'] + embedding_cfg['model_name'] + \
    "/" + config_data['dataset'] + "/" + config_data['wandb_project']
    
print(f"Save directory: {save_dir}")


Save directory: ../data/transe/ogbl-collab/thesis-graph


In [5]:
wandb_name = "0"      
stage = "val"  # or "test"
epoch = "0"  # e.g. "9"
test_time = False  # set True if test_time

prefix = "test" if test_time else "train"
export_path = os.path.join(save_dir, wandb_name, f"{prefix}_{stage}_{epoch}_raw.json")
print(export_path)

../data/transe/ogbl-collab/thesis-graph/0/train_val_0_raw.json


In [6]:
# Load the JSON file
with open(export_path, "r") as f:
    export_items = json.load(f)

print(f"Loaded {len(export_items)}")

Loaded 160084


In [7]:
export_items[0]

{'score': 0.09943151473999023,
 'length': 1,
 'label': 1,
 'has_neg': True,
 'pos_dist': [11.837754249572754],
 'neg_dist': [[10.281356811523438],
  [10.281356811523438],
  [10.281356811523438],
  [10.281356811523438],
  [10.34947681427002],
  [11.837754249572754],
  [11.837754249572754],
  [11.837754249572754],
  [10.65380859375],
  [10.65380859375],
  [10.65380859375],
  [9.960492134094238],
  [9.960492134094238],
  [11.205546379089355],
  [11.113570213317871],
  [11.833039283752441],
  [10.587995529174805],
  [11.491482734680176],
  [11.343835830688477],
  [11.866433143615723]],
 'adjusted_score': 4.09943151473999}

In [8]:
# recompute adjusted score
max_hops = config_data['max_hops']
max_adjust = config_data['max_adjust']
print(f"Max hops: {max_hops}, Max adjust: {max_adjust}")

Max hops: 4, Max adjust: 5.0


In [9]:
# length zero adjusted to max_hops + 2
for item in tqdm(export_items):
    if item['length'] == 0:
        item['length'] = max_hops + 2

min_length = min(item['length'] for item in export_items)
print(f"Minimum length: {min_length}")

100%|██████████| 160084/160084 [00:03<00:00, 47345.29it/s]


Minimum length: 1


In [10]:
# adjusted score
for item in tqdm(export_items):
    # ratio = 1 - (item['length'] - min_length) / (max_hops + 1 - min_length)
    # item['adjusted_score'] = item['score'] + (ratio * max_adjust)
    item['adjusted_score'] = item['score']

100%|██████████| 160084/160084 [00:01<00:00, 102707.57it/s]


In [11]:
export_items[3]

{'score': 0.0,
 'length': 6,
 'label': 1,
 'has_neg': False,
 'pos_dist': None,
 'neg_dist': None,
 'adjusted_score': 0.0}

In [12]:
# Get list of label 0 but non-zero length
neg_items = [item for item in export_items if item['label'] == 0 and item['has_neg']]
print(f"Filtered non-zero items: {len(neg_items)}")

Filtered non-zero items: 7083


In [13]:
neg_items[200]

{'score': 0.1156584620475769,
 'length': 4,
 'label': 0,
 'has_neg': True,
 'pos_dist': [10.880391120910645,
  9.0462646484375,
  11.833160400390625,
  11.009916305541992],
 'neg_dist': [[10.880391120910645,
   9.187994956970215,
   9.84019947052002,
   9.160685539245605],
  [10.880391120910645, 9.187994956970215, 9.84019947052002, 9.070316314697266],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.232804298400879],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.232804298400879],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.232804298400879],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.158455848693848],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.158455848693848],
  [11.834572792053223, 9.037266731262207, 9.06799030303955, 9.158455848693848],
  [10.880391120910645, 9.0462646484375, 8.9961576461792, 9.072173118591309],
  [11.69300651550293, 9.3045654296875, 9.246265411376953, 9.40832233428955],
  [11.834572

In [14]:
pos_items = [item for item in export_items if item['label'] == 1 and item['has_neg']]
print(f"Filtered positive items: {len(pos_items)}")

Filtered positive items: 47553


In [25]:
pos_items[0]

{'score': 0.09943151473999023,
 'length': 1,
 'label': 1,
 'has_neg': True,
 'pos_dist': [11.837754249572754],
 'neg_dist': [[10.281356811523438],
  [10.281356811523438],
  [10.281356811523438],
  [10.281356811523438],
  [10.34947681427002],
  [11.837754249572754],
  [11.837754249572754],
  [11.837754249572754],
  [10.65380859375],
  [10.65380859375],
  [10.65380859375],
  [9.960492134094238],
  [9.960492134094238],
  [11.205546379089355],
  [11.113570213317871],
  [11.833039283752441],
  [10.587995529174805],
  [11.491482734680176],
  [11.343835830688477],
  [11.866433143615723]],
 'adjusted_score': 0.09943151473999023}

In [16]:
K = 50
# Select top K items in negative highest adjsuted scores
top_neg = sorted(neg_items, key=lambda x: x['adjusted_score'], reverse=True)[:K]
print(f"Top {K} negative items: {len(top_neg)}")

Top 50 negative items: 50


In [27]:
top_neg[-1]

{'score': 1.0,
 'length': 4,
 'label': 0,
 'has_neg': True,
 'pos_dist': [9.08548641204834,
  11.88125228881836,
  11.885870933532715,
  9.091623306274414],
 'neg_dist': [[9.08548641204834,
   9.949609756469727,
   9.588653564453125,
   8.944148063659668],
  [9.08548641204834, 9.949609756469727, 10.210198402404785, 9.128499031066895],
  [9.08548641204834, 10.216167449951172, 10.950872421264648, 9.19459056854248],
  [9.08548641204834,
   10.216167449951172,
   10.367166519165039,
   10.207352638244629],
  [9.08548641204834,
   10.216167449951172,
   10.367166519165039,
   10.207352638244629],
  [9.08548641204834,
   10.216167449951172,
   10.367166519165039,
   9.508516311645508],
  [9.08548641204834,
   10.216167449951172,
   10.367166519165039,
   9.508516311645508],
  [9.08548641204834,
   10.216167449951172,
   10.950872421264648,
   10.500466346740723],
  [9.08548641204834, 9.949609756469727, 9.588653564453125, 9.90206241607666],
  [9.08548641204834,
   10.216167449951172,
   10.75

In [18]:
top_neg_score = top_neg[-1]['adjusted_score']
print(f"Top negative score: {top_neg_score}")

Top negative score: 1.0


In [None]:
top_neg[-1]['neg_dist']

[[9.08548641204834, 9.949609756469727, 9.588653564453125, 8.944148063659668],
 [9.08548641204834, 9.949609756469727, 10.210198402404785, 9.128499031066895],
 [9.08548641204834, 10.216167449951172, 10.950872421264648, 9.19459056854248],
 [9.08548641204834,
  10.216167449951172,
  10.367166519165039,
  10.207352638244629],
 [9.08548641204834,
  10.216167449951172,
  10.367166519165039,
  10.207352638244629],
 [9.08548641204834, 10.216167449951172, 10.367166519165039, 9.508516311645508],
 [9.08548641204834, 10.216167449951172, 10.367166519165039, 9.508516311645508],
 [9.08548641204834,
  10.216167449951172,
  10.950872421264648,
  10.500466346740723],
 [9.08548641204834, 9.949609756469727, 9.588653564453125, 9.90206241607666],
 [9.08548641204834,
  10.216167449951172,
  10.758651733398438,
  10.691717147827148],
 [9.08548641204834, 10.216167449951172, 10.367166519165039, 9.925450325012207],
 [9.08548641204834, 10.216167449951172, 10.367166519165039, 9.925450325012207],
 [9.08548641204834,

In [19]:
# Sort the positive items by adjusted score
top_pos = sorted(pos_items, key=lambda x: x['adjusted_score'], reverse=True)
print(f"Top positive items: {len(top_pos)}")

Top positive items: 47553


In [20]:
top_pos[-1]

{'score': 0.0,
 'length': 3,
 'label': 1,
 'has_neg': True,
 'pos_dist': [9.50780200958252, 11.829113006591797, 9.670458793640137],
 'neg_dist': [[9.50780200958252, 11.829113006591797, 9.108136177062988],
  [9.50780200958252, 11.829113006591797, 9.115540504455566],
  [9.50780200958252, 11.829113006591797, 9.06275749206543],
  [9.50780200958252, 11.829113006591797, 9.06275749206543],
  [9.50780200958252, 11.829113006591797, 9.06275749206543],
  [9.50780200958252, 11.829113006591797, 9.031770706176758],
  [9.50780200958252, 11.829113006591797, 9.006580352783203],
  [9.50780200958252, 11.829113006591797, 9.006580352783203],
  [9.50780200958252, 11.829113006591797, 9.09538745880127],
  [9.50780200958252, 11.829113006591797, 9.09538745880127],
  [9.50780200958252, 11.829113006591797, 8.990823745727539],
  [9.50780200958252, 11.829113006591797, 9.201242446899414],
  [9.50780200958252, 11.829113006591797, 9.042387008666992],
  [9.50780200958252, 11.829113006591797, 9.157793045043945],
  [9.50

In [21]:
# How many pos items higher adjusted scores than last item in top_neg
higher_pos = [
    item for item in top_pos if item['adjusted_score'] > top_neg_score]
len(higher_pos)

0

In [22]:
len(higher_pos) / len(pos_items)

0.0

In [23]:
top_neg[2]

{'score': 1.0,
 'length': 4,
 'label': 0,
 'has_neg': True,
 'pos_dist': [10.178739547729492,
  9.065033912658691,
  11.839302062988281,
  11.906525611877441],
 'neg_dist': [[10.178739547729492,
   9.065033912658691,
   10.536744117736816,
   8.995314598083496],
  [10.178739547729492,
   9.065033912658691,
   10.960749626159668,
   9.010370254516602],
  [10.178739547729492,
   9.065033912658691,
   10.566941261291504,
   9.425470352172852],
  [10.178739547729492,
   9.065033912658691,
   10.566941261291504,
   9.425470352172852],
  [10.178739547729492,
   9.065033912658691,
   10.960749626159668,
   9.716288566589355],
  [10.178739547729492,
   9.065033912658691,
   9.791705131530762,
   10.57628059387207],
  [10.178739547729492,
   9.065033912658691,
   9.791705131530762,
   10.57628059387207],
  [10.178739547729492,
   9.065033912658691,
   10.536744117736816,
   9.57210636138916],
  [10.178739547729492,
   9.065033912658691,
   10.906560897827148,
   10.175162315368652],
  [10.17873

In [24]:
higher_pos[-1]

IndexError: list index out of range

In [83]:
# example = top_pos[40000]
example = top_neg[-2]
print(f"Example item: {example}")

Example item: {'score': 1.0, 'length': 4, 'label': 0, 'has_neg': True, 'pos_dist': [9.06659984588623, 11.812776565551758, 9.020383834838867, 10.778009414672852], 'neg_dist': [[9.06659984588623, 9.067873001098633, 9.627302169799805, 9.019967079162598], [9.06659984588623, 9.067873001098633, 9.627302169799805, 9.019967079162598], [9.06659984588623, 9.146127700805664, 9.094917297363281, 8.964913368225098], [9.06659984588623, 9.146127700805664, 9.094917297363281, 8.964913368225098], [9.06659984588623, 9.146127700805664, 9.094917297363281, 8.964913368225098], [9.06659984588623, 9.146127700805664, 9.094917297363281, 8.964913368225098], [9.06659984588623, 9.146127700805664, 9.094917297363281, 8.964913368225098], [9.06659984588623, 9.146127700805664, 9.094917297363281, 8.964913368225098], [9.06659984588623, 9.5545072555542, 9.07029914855957, 9.351490020751953], [9.06659984588623, 9.5545072555542, 9.07029914855957, 9.351490020751953], [9.06659984588623, 9.5545072555542, 9.07029914855957, 9.35149

In [84]:
import torch

In [85]:
pos = torch.tensor(example['pos_dist']) # (length, )
neg = torch.tensor(example['neg_dist'])  # (num_neg, length)

In [86]:
if config_data["positive_deviation"]:
    # Add pos (length, ) to neg pool (num_neg, length)
    neg = torch.cat([neg, pos.unsqueeze(0)], dim=0)

In [87]:
pos.shape, neg.shape

(torch.Size([4]), torch.Size([21, 4]))

In [88]:
mean, std = neg.mean(0), neg.std(0, correction=0)
print(mean, std)

tensor([9.0666, 9.4990, 9.1280, 9.1827]) tensor([0.0000, 0.5568, 0.1628, 0.3776])


In [89]:
z = (pos - mean) / (std + 1e-8)
mean_z = z.mean()
print(z, mean_z)

tensor([ 0.0000,  4.1558, -0.6610,  4.2252]) tensor(1.9300)


In [90]:
# arcsinh(mean z)
arcsinh_mean_z = torch.asinh(mean_z)
print(f"Arcsinh mean z: {arcsinh_mean_z}")

Arcsinh mean z: 1.4118947982788086
