In [93]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from loader import PathDataModule
from tqdm import tqdm

# --- 1. Configuration and Data Loading ---
config_path = 'full.json'
config_data = json.load(open(config_path, 'r'))
print(config_data)

{'dataset': 'ogbl-collab', 'storage_dir': '../data/', 'embedding_config': './full_embedding.json', 'train_ratio': 0.3, 'hidden_dim': 128, 'max_hops': 4, 'num_neg': 50, 'num_threads': 'vast', 'max_epochs': 10, 'batch_size': 'vast', 'dim_feedforward': 128, 'nhead': 4, 'num_layers': 4, 'dropout': 0.1, 'store': 'model', 'save_text_embeddings': True, 'shallow': True, 'pre_scan': ['train'], 'adjust_no_neg_paths_samples': True, 'max_adjust': 5.0, 'positive_deviation': True, 'embedding': 'all', 'test_time': 4, 'num_ckpt': 2, 'scale_loss': True, 'chi2': False, 'lr': 0.001, 'wandb_project': 'thesis-graph'}


In [94]:
import os, requests, json, pprint

try:
    cid  = os.getenv("CONTAINER_ID")          # ← set by Vast.ai inside every container
    key  = os.getenv("CONTAINER_API_KEY")     # ← scoped token for this one instance
    assert cid and key, "Not running on a Vast.ai container!"

    resp = requests.get(
        f"https://console.vast.ai/api/v0/instances/{cid}/",
        headers={"Authorization": f"Bearer {key}",
                "accept": "application/json"},
        timeout=10,
    )

    info = resp.json()
    print("Effective vCPUs:", info['instances']["cpu_cores_effective"])
    # print(info['instances'].keys())
except Exception as e:
    print("Not running on Vast.ai container, or failed to fetch instance info.")
    print("Error:", e)

Not running on Vast.ai container, or failed to fetch instance info.
Error: Not running on a Vast.ai container!


In [95]:
import json
import os

embedding_cfg_path = config_data['embedding_config'] # path to config file
print("Embedding config path:", embedding_cfg_path)
# Read config from json file at embedding cfg path
with open(embedding_cfg_path, 'r') as f:
    embedding_cfg = json.load(f)
    
embedding_cfg

Embedding config path: ./full_embedding.json


{'batch_size': 4096,
 'lr': 0.001,
 'epochs': 10,
 'model_name': 'transe',
 'hidden_channels': 128,
 'p_norm': 2}

In [96]:


# --- 2. Load exported JSON results from model.py ---

# Construct the path to your export (update these variables as needed)
save_dir = config_data['storage_dir'] + embedding_cfg['model_name'] + \
    "/" + config_data['dataset'] + "/" + config_data['wandb_project']
    
print(f"Save directory: {save_dir}")


Save directory: ../data/transe/ogbl-collab/thesis-graph


In [97]:
wandb_name = "1"      
stage = "val"  # or "test"
epoch = "0"  # e.g. "9"
test_time = False  # set True if test_time

prefix = "test" if test_time else "train"
export_path = os.path.join(save_dir, wandb_name, f"{prefix}_{stage}_{epoch}_raw.json")
print(export_path)

../data/transe/ogbl-collab/thesis-graph/1/train_val_0_raw.json


In [98]:
# Load the JSON file
with open(export_path, "r") as f:
    export_items = json.load(f)

print(f"Loaded {len(export_items)}")

Loaded 160084


In [99]:
export_items[0]

{'score': 0.10174286365509033,
 'length': 1,
 'label': 1,
 'has_neg': True,
 'pos_dist': [22.353134155273438],
 'neg_dist': [[21.08070182800293],
  [21.08070182800293],
  [21.08070182800293],
  [21.08070182800293],
  [20.94063949584961],
  [22.353134155273438],
  [22.353134155273438],
  [22.353134155273438],
  [21.128007888793945],
  [21.128007888793945],
  [21.128007888793945],
  [21.313785552978516],
  [21.313785552978516],
  [21.754823684692383],
  [21.798219680786133],
  [22.58941650390625],
  [21.19548225402832],
  [22.160743713378906],
  [21.9285831451416],
  [22.265535354614258]],
 'adjusted_score': 4.101742744445801}

In [100]:
# recompute adjusted score
max_hops = config_data['max_hops']
max_adjust = config_data['max_adjust']
print(f"Max hops: {max_hops}, Max adjust: {max_adjust}")

Max hops: 4, Max adjust: 5.0


In [101]:
# length zero adjusted to max_hops + 2
for item in tqdm(export_items):
    if item['length'] == 0:
        item['length'] = max_hops + 2

min_length = min(item['length'] for item in export_items)
print(f"Minimum length: {min_length}")

100%|██████████| 160084/160084 [00:00<00:00, 2898652.48it/s]

Minimum length: 1





In [102]:
# adjusted score
for item in tqdm(export_items):
    # ratio = 1 - (item['length'] - min_length) / (max_hops + 1 - min_length)
    # item['adjusted_score'] = item['score'] + (ratio * max_adjust)
    item['adjusted_score'] = item['score']

100%|██████████| 160084/160084 [00:00<00:00, 2575224.89it/s]


In [103]:
export_items[3]

{'score': 0.0,
 'length': 6,
 'label': 1,
 'has_neg': False,
 'pos_dist': None,
 'neg_dist': None,
 'adjusted_score': 0.0}

In [104]:
# Get list of label 0 but non-zero length
neg_items = [item for item in export_items if item['label'] == 0 and item['has_neg']]
print(f"Filtered non-zero items: {len(neg_items)}")

Filtered non-zero items: 7083


In [105]:
neg_items[200]

{'score': 0.07763171195983887,
 'length': 4,
 'label': 0,
 'has_neg': True,
 'pos_dist': [21.756620407104492,
  20.32785987854004,
  22.24242401123047,
  21.86017417907715],
 'neg_dist': [[21.756620407104492,
   20.529212951660156,
   20.60442352294922,
   20.41798973083496],
  [21.756620407104492,
   20.529212951660156,
   20.60442352294922,
   20.450239181518555],
  [22.24396324157715, 20.29498863220215, 20.3772029876709, 20.47519302368164],
  [22.24396324157715, 20.29498863220215, 20.3772029876709, 20.47519302368164],
  [22.24396324157715, 20.29498863220215, 20.3772029876709, 20.47519302368164],
  [22.24396324157715, 20.29498863220215, 20.3772029876709, 20.376419067382812],
  [22.24396324157715, 20.29498863220215, 20.3772029876709, 20.376419067382812],
  [22.24396324157715, 20.29498863220215, 20.3772029876709, 20.376419067382812],
  [21.756620407104492,
   20.32785987854004,
   20.29673194885254,
   20.334518432617188],
  [22.20726203918457,
   20.6400089263916,
   20.39439392089843

In [106]:
pos_items = [item for item in export_items if item['label'] == 1 and item['has_neg']]
print(f"Filtered positive items: {len(pos_items)}")

Filtered positive items: 47553


In [107]:
pos_items[0]

{'score': 0.10174286365509033,
 'length': 1,
 'label': 1,
 'has_neg': True,
 'pos_dist': [22.353134155273438],
 'neg_dist': [[21.08070182800293],
  [21.08070182800293],
  [21.08070182800293],
  [21.08070182800293],
  [20.94063949584961],
  [22.353134155273438],
  [22.353134155273438],
  [22.353134155273438],
  [21.128007888793945],
  [21.128007888793945],
  [21.128007888793945],
  [21.313785552978516],
  [21.313785552978516],
  [21.754823684692383],
  [21.798219680786133],
  [22.58941650390625],
  [21.19548225402832],
  [22.160743713378906],
  [21.9285831451416],
  [22.265535354614258]],
 'adjusted_score': 0.10174286365509033}

In [108]:
K = 50
# Select top K items in negative highest adjsuted scores
top_neg = sorted(neg_items, key=lambda x: x['adjusted_score'], reverse=True)[:K]
print(f"Top {K} negative items: {len(top_neg)}")

Top 50 negative items: 50


In [109]:
top_neg[-1]

{'score': 1.0,
 'length': 4,
 'label': 0,
 'has_neg': True,
 'pos_dist': [20.422826766967773,
  22.248979568481445,
  21.150657653808594,
  20.23993682861328],
 'neg_dist': [[20.422826766967773,
   21.050504684448242,
   20.53862762451172,
   20.229887008666992],
  [20.422826766967773,
   21.050504684448242,
   20.53862762451172,
   20.229887008666992],
  [20.422826766967773,
   20.60368537902832,
   20.368629455566406,
   20.921024322509766],
  [20.422826766967773,
   20.60368537902832,
   20.368629455566406,
   21.689128875732422],
  [20.422826766967773,
   20.60368537902832,
   20.368629455566406,
   21.322595596313477],
  [20.422826766967773,
   20.60368537902832,
   20.368629455566406,
   20.888479232788086],
  [20.422826766967773,
   21.050504684448242,
   20.498796463012695,
   22.062734603881836],
  [20.422826766967773,
   21.050504684448242,
   20.498796463012695,
   22.062734603881836],
  [20.422826766967773,
   20.76075553894043,
   20.410066604614258,
   21.19136619567871],

In [110]:
top_neg_score = top_neg[-1]['adjusted_score']
print(f"Top negative score: {top_neg_score}")

Top negative score: 1.0


In [111]:
top_neg[-1]['neg_dist']

[[20.422826766967773,
  21.050504684448242,
  20.53862762451172,
  20.229887008666992],
 [20.422826766967773,
  21.050504684448242,
  20.53862762451172,
  20.229887008666992],
 [20.422826766967773,
  20.60368537902832,
  20.368629455566406,
  20.921024322509766],
 [20.422826766967773,
  20.60368537902832,
  20.368629455566406,
  21.689128875732422],
 [20.422826766967773,
  20.60368537902832,
  20.368629455566406,
  21.322595596313477],
 [20.422826766967773,
  20.60368537902832,
  20.368629455566406,
  20.888479232788086],
 [20.422826766967773,
  21.050504684448242,
  20.498796463012695,
  22.062734603881836],
 [20.422826766967773,
  21.050504684448242,
  20.498796463012695,
  22.062734603881836],
 [20.422826766967773,
  20.76075553894043,
  20.410066604614258,
  21.19136619567871],
 [20.422826766967773,
  20.76075553894043,
  20.410066604614258,
  21.19136619567871],
 [20.422826766967773,
  21.050504684448242,
  20.573421478271484,
  21.43611717224121],
 [20.422826766967773,
  21.05050

In [112]:
# Sort the positive items by adjusted score
top_pos = sorted(pos_items, key=lambda x: x['adjusted_score'], reverse=True)
print(f"Top positive items: {len(top_pos)}")

Top positive items: 47553


In [113]:
top_pos[-1]

{'score': 0.0,
 'length': 3,
 'label': 1,
 'has_neg': True,
 'pos_dist': [20.20102882385254, 22.482969284057617, 20.511489868164062],
 'neg_dist': [[20.20102882385254, 21.71308708190918, 20.47885513305664],
  [20.20102882385254, 21.71308708190918, 20.47885513305664],
  [20.20102882385254, 21.71308708190918, 20.47885513305664],
  [20.20102882385254, 21.71308708190918, 21.766225814819336],
  [20.20102882385254, 21.71308708190918, 21.766225814819336],
  [20.20102882385254, 21.71308708190918, 21.766225814819336],
  [20.20102882385254, 21.71308708190918, 20.76085090637207],
  [20.20102882385254, 21.71308708190918, 20.76085090637207],
  [20.20102882385254, 21.71308708190918, 20.76085090637207],
  [20.20102882385254, 21.71308708190918, 21.326335906982422],
  [20.20102882385254, 21.71308708190918, 21.326335906982422],
  [20.20102882385254, 21.71308708190918, 21.326335906982422],
  [20.20102882385254, 21.71308708190918, 21.37709617614746],
  [20.20102882385254, 21.71308708190918, 21.37709617614

In [114]:
# How many pos items higher adjusted scores than last item in top_neg
higher_pos = [
    item for item in top_pos if item['adjusted_score'] > top_neg_score]
len(higher_pos)

0

In [115]:
len(higher_pos) / len(pos_items)

0.0

In [116]:
top_neg[2]

{'score': 1.0,
 'length': 4,
 'label': 0,
 'has_neg': True,
 'pos_dist': [22.197872161865234,
  20.35811996459961,
  22.22220230102539,
  20.43638801574707],
 'neg_dist': [[22.197872161865234,
   20.62787437438965,
   20.551225662231445,
   20.351850509643555],
  [22.197872161865234,
   20.62787437438965,
   20.551225662231445,
   20.351850509643555],
  [22.197872161865234,
   20.62787437438965,
   20.551225662231445,
   20.351850509643555],
  [22.197872161865234,
   20.62787437438965,
   20.551225662231445,
   20.351850509643555],
  [22.197872161865234,
   20.62787437438965,
   20.551225662231445,
   20.351850509643555],
  [22.197872161865234,
   20.35811996459961,
   20.552207946777344,
   20.958778381347656],
  [22.197872161865234,
   20.408058166503906,
   20.686731338500977,
   21.594724655151367],
  [22.197872161865234,
   20.35811996459961,
   20.420076370239258,
   21.008882522583008],
  [22.197872161865234,
   20.601146697998047,
   20.30323028564453,
   21.058637619018555],
 

In [117]:
higher_pos[-1]

IndexError: list index out of range

In [118]:
# example = top_pos[40000]
example = top_neg[-2]
print(f"Example item: {example}")

Example item: {'score': 1.0, 'length': 4, 'label': 0, 'has_neg': True, 'pos_dist': [22.09030532836914, 20.504453659057617, 22.31141471862793, 21.428071975708008], 'neg_dist': [[22.09030532836914, 21.486425399780273, 20.59282875061035, 20.512922286987305], [22.09030532836914, 21.486425399780273, 20.59282875061035, 20.512922286987305], [22.09030532836914, 21.486425399780273, 20.59282875061035, 20.357187271118164], [22.09030532836914, 21.486425399780273, 20.59282875061035, 20.42481803894043], [22.09030532836914, 21.486425399780273, 20.22919273376465, 20.083999633789062], [22.09030532836914, 22.296140670776367, 20.83771514892578, 20.6279354095459], [22.09030532836914, 21.486425399780273, 20.59282875061035, 20.515310287475586], [22.09030532836914, 21.486425399780273, 20.59282875061035, 20.515310287475586], [22.09030532836914, 22.296140670776367, 20.83771514892578, 21.211776733398438], [22.09030532836914, 22.296140670776367, 20.496936798095703, 21.0692138671875], [22.09030532836914, 21.48642

In [119]:
import torch

In [120]:
pos = torch.tensor(example['pos_dist']) # (length, )
neg = torch.tensor(example['neg_dist'])  # (num_neg, length)

In [121]:
if config_data["positive_deviation"]:
    # Add pos (length, ) to neg pool (num_neg, length)
    neg = torch.cat([neg, pos.unsqueeze(0)], dim=0)

In [122]:
pos.shape, neg.shape

(torch.Size([4]), torch.Size([21, 4]))

In [123]:
mean, std = neg.mean(0), neg.std(0, correction=0)
print(mean, std)

tensor([22.0903, 21.7096, 20.6324, 20.8592]) tensor([0.0000, 0.4633, 0.4196, 0.4275])


In [124]:
z = (pos - mean) / (std + 1e-8)
mean_z = z.mean()
print(z, mean_z)

tensor([190.7349,  -2.6010,   4.0015,   1.3306]) tensor(48.3665)


In [125]:
# arcsinh(mean z)
arcsinh_mean_z = torch.asinh(mean_z)
print(f"Arcsinh mean z: {arcsinh_mean_z}")

Arcsinh mean z: 4.572061538696289


In [126]:
percentile_pos = 1.0 - torch.special.ndtr(mean_z).item()

In [127]:
print(f"Percentile of positive item: {percentile_pos:.4f}")

Percentile of positive item: 0.0000
