# Quantifying Relations in Lower Dimensional Subspace
  
  - Finding a lower dimensional subspace wherein 95% of variance is captured for a certain relation.
  - Checking if upon projecting the relations on that subspace and measuring the orthogonal leakage, is the correct option satisfying the pair closest.
  - Finding nearby terms in the embedded space for a certain word.  


## Setup
 - Importing the libraries, global parameters etc
 - Different Types of Embeddings for the pairs:
    - Word2Vec
    - Glove
    - BERT
    - Roberta
    - SBERT
    - LABSE

### File Imports

In [1]:
# --- Core Libraries for Data Handling and Numerics ---
import io
import re
import os
import json
import zipfile
import requests
import shutil
import subprocess
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- TensorFlow and TF-Hub for Word2Vec ---
import tensorflow as tf
import tensorflow_hub as hub

# --- PyTorch and Hugging Face Transformers for BERT ---
import torch
from transformers import AutoTokenizer, AutoModel

print("All libraries imported successfully.")

All libraries imported successfully.


### Global Parameters and Word Pair Definitions

In [2]:
# --- Analysis Parameters ---
K_MAX = 64                      # The maximum subspace dimension to evaluate
VAR_TARGET = 0.95               # The target explained variance for choosing our optimal dimension, k*
NORMALIZE_ROWS = True           # Whether to L2-normalize embeddings before analysis
NUM_SUBSPACES = 3               # The number of subspaces to extract using the deflation method
K_BLOCK = 5                     # The dimension of each block for the deflation method

In [3]:
from typing import List, Tuple

# =============================================================
# Each relation below contains EXACTLY 60 unique (head, tail) pairs,
# crafted so that, within a relation, all examples come from a single
# coherent domain to help MCQ scoring.
# Token style: snake_case, ASCII only.
# =============================================================

# 1) KARTA — Agent → Verb  (domain: healthcare & life‑sciences roles)
KARTA_60: List[Tuple[str, str]] = [
    ("surgeon", "operate"), ("physician", "diagnose"), ("nurse", "care"), ("anesthetist", "sedate"),
    ("radiologist", "scan"), ("cardiologist", "treat"), ("neurologist", "assess"), ("oncologist", "chemotherapy"),
    ("dermatologist", "examine"), ("pediatrician", "vaccinate"), ("gynecologist", "deliver"), ("urologist", "treat"),
    ("orthopedist", "set_bone"), ("pathologist", "analyze"), ("microbiologist", "culture"), ("virologist", "sequence"),
    ("immunologist", "test"), ("hematologist", "monitor"), ("endocrinologist", "regulate"), ("nephrologist", "dialyze"),
    ("gastroenterologist", "scope"), ("pulmonologist", "ventilate"), ("psychiatrist", "counsel"), ("psychologist", "evaluate"),
    ("physiotherapist", "rehabilitate"), ("occupational_therapist", "train"), ("speech_therapist", "coach"), ("pharmacist", "dispense"),
    ("dietitian", "plan"), ("paramedic", "stabilize"), ("emt", "transport"), ("dentist", "extract"),
    ("dental_hygienist", "clean"), ("optometrist", "prescribe"), ("ophthalmologist", "operate"), ("audiologist", "assess"),
    ("lab_technologist", "process"), ("biostatistician", "analyze"), ("epidemiologist", "model"), ("genetic_counselor", "advise"),
    ("bioinformatician", "compute"), ("research_scientist", "experiment"), ("clinical_researcher", "trial"), ("toxicologist", "test"),
    ("public_health_officer", "monitor"), ("health_inspector", "inspect"), ("medical_coder", "code"), ("medical_scribe", "document"),
    ("hospitalist", "coordinate"), ("triage_nurse", "prioritize"), ("scrub_nurse", "assist"), ("circulating_nurse", "manage"),
    ("icu_physician", "intubate"), ("neonatologist", "resuscitate"), ("palliative_care_physician", "comfort"), ("geriatrician", "support"),
    ("infection_control_specialist", "contain"), ("wound_care_nurse", "dress"), ("midwife", "deliver"), ("respiratory_therapist", "nebulize")
]

# 2) KARMA — Patient/Object → Verb  (domain: cooking & kitchen tasks)
KARMA_60: List[Tuple[str, str]] = [
    ("onion", "chop"), ("garlic", "mince"), ("tomato", "dice"), ("potato", "peel"),
    ("carrot", "julienne"), ("spinach", "saute"), ("mushroom", "saute"), ("egg", "whisk"),
    ("batter", "mix"), ("dough", "knead"), ("bread", "slice"), ("steak", "grill"),
    ("fish", "pan_sear"), ("chicken", "roast"), ("turkey", "roast"), ("pork", "braise"),
    ("tofu", "press"), ("paneer", "cube"), ("noodles", "boil"), ("pasta", "boil"),
    ("rice", "steam"), ("quinoa", "rinse"), ("lentils", "simmer"), ("beans", "soak"),
    ("soup", "simmer"), ("sauce", "reduce"), ("stock", "strain"), ("gravy", "thicken"),
    ("butter", "clarify"), ("cream", "whip"), ("custard", "set"), ("pudding", "chill"),
    ("cake", "bake"), ("cookies", "bake"), ("pie", "bake"), ("pastry", "laminate"),
    ("vegetables", "blanch"), ("peppers", "roast"), ("eggplant", "char"), ("corn", "shuck"),
    ("peas", "shell"), ("herbs", "infuse"), ("spices", "toast"), ("nuts", "toast"),
    ("seeds", "grind"), ("coffee_beans", "grind"), ("tea_leaves", "steep"), ("garam_masala", "blend"),
    ("salad", "toss"), ("fries", "deep_fry"), ("fritters", "deep_fry"), ("pancakes", "flip"),
    ("waffles", "press"), ("sandwich", "assemble"), ("burger", "assemble"), ("tortilla", "warm"),
    ("salsa", "pulse"), ("guacamole", "mash"), ("sushi_rice", "season"), ("dumplings", "steam")
]

# 3) KARANA — Instrument/Means → Verb  (domain: kitchen tools)
KARANA_60: List[Tuple[str, str]] = [
    ("chef_knife", "slice"), ("paring_knife", "peel"), ("bread_knife", "saw"), ("cleaver", "chop"),
    ("mandoline", "shave"), ("peeler", "peel"), ("grater", "grate"), ("zester", "zest"),
    ("mortar", "pound"), ("pestle", "crush"), ("spatula", "flip"), ("tongs", "grip"),
    ("whisk", "whisk"), ("ladle", "serve"), ("skimmer", "skim"), ("slotted_spoon", "drain"),
    ("wooden_spoon", "stir"), ("silicone_spatula", "scrape"), ("colander", "drain"), ("sieve", "sift"),
    ("rolling_pin", "roll"), ("bench_scraper", "portion"), ("pastry_brush", "glaze"), ("measuring_cup", "measure"),
    ("measuring_spoon", "measure"), ("scale", "weigh"), ("thermometer", "probe"), ("timer", "time"),
    ("cast_iron_skillet", "sear"), ("nonstick_pan", "fry"), ("saucepan", "simmer"), ("stockpot", "boil"),
    ("wok", "stir_fry"), ("pressure_cooker", "pressure_cook"), ("slow_cooker", "braise"), ("rice_cooker", "steam"),
    ("oven", "bake"), ("toaster", "toast"), ("microwave", "reheat"), ("air_fryer", "crisp"),
    ("immersion_blender", "puree"), ("stand_mixer", "mix"), ("food_processor", "pulse"), ("blender", "blend"),
    ("espresso_machine", "brew"), ("coffee_grinder", "grind"), ("kettle", "boil"), ("sous_vide_circulator", "circulate"),
    ("sheet_pan", "roast"), ("baking_stone", "crisp"), ("dutch_oven", "braise"), ("roasting_rack", "elevate"),
    ("thermo_probe", "monitor"), ("butcher_block", "butcher"), ("fish_spatula", "turn"), ("pasta_maker", "extrude"),
    ("pizza_peel", "launch"), ("cookie_cutter", "cut"), ("piping_bag", "pipe"), ("ice_cream_scoop", "scoop")
]

# 4) SAMPRADANA — Recipient → Benefaction verb  (domain: education & grants; no "receive")
SAMPRADANA_60: List[Tuple[str, str]] = [
    ("student", "earn"), ("scholar", "win"), ("applicant", "secure"), ("candidate", "obtain"),
    ("intern", "gain"), ("trainee", "acquire"), ("fellow", "attain"), ("researcher", "garner"),
    ("postdoc", "obtain"), ("professor", "collect"), ("lab", "acquire"), ("department", "award"),
    ("university", "attract"), ("school", "attract"), ("library", "accept"), ("museum", "accept"),
    ("nonprofit", "receive_funding"), ("ngo", "secure_funding"), ("community_center", "collect"), ("orphanage", "accept"),
    ("team", "claim"), ("club", "claim"), ("startup", "raise"), ("incubator", "allocate"),
    ("accelerator", "allocate"), ("foundation", "grant"), ("charity", "collect"), ("trust", "endow"),
    ("athlete", "earn"), ("coach", "obtain"), ("teacher", "obtain"), ("schoolchild", "win"),
    ("valedictorian", "earn"), ("salutatorian", "earn"), ("graduate", "secure"), ("undergraduate", "secure"),
    ("exchange_student", "obtain"), ("visiting_scholar", "obtain"), ("principal_investigator", "win"), ("co_investigator", "win"),
    ("open_source_project", "receive_sponsorship"), ("lab_consortium", "pool"), ("student_club", "fundraise"), ("robotics_team", "fundraise"),
    ("debate_team", "win"), ("math_circle", "win"), ("hackathon_team", "claim"), ("art_collective", "obtain"),
    ("youth_program", "secure"), ("music_school", "accept"), ("dance_troupe", "accept"), ("theater_group", "accept"),
    ("science_fair_winner", "claim"), ("olympiad_winner", "claim"), ("grant_holder", "renew"), ("stipend_holder", "renew"),
    ("faculty", "receive_award"), ("alumni_association", "collect"), ("parent_teacher_association", "collect"), ("education_board", "allocate")
]

# 5) APADANA — Source → Separation/Origin verb  (domain: infrastructure & fluids)
APADANA_60: List[Tuple[str, str]] = [
    ("reservoir", "release"), ("dam", "spill"), ("spillway", "discharge"), ("canal", "drain"),
    ("sluice", "vent"), ("culvert", "drain"), ("storm_drain", "outflow"), ("gutter", "drip"),
    ("downspout", "pour"), ("roof", "runoff"), ("pipe", "leak"), ("nozzle", "spray"),
    ("hydrant", "discharge"), ("valve", "bleed"), ("radiator", "vent"), ("chimney", "emit"),
    ("smokestack", "emit"), ("exhaust_duct", "exhaust"), ("vent", "exhale"), ("air_outlet", "blow"),
    ("compressor", "purge"), ("boiler", "blowdown"), ("cooling_tower", "evaporate"), ("condenser", "drain"),
    ("evaporator", "drip"), ("sprinkler", "spray"), ("fountain", "jet"), ("spring", "bubble"),
    ("well", "draw"), ("aquifer", "seep"), ("sewer", "discharge"), ("treatment_plant", "effuse"),
    ("filter", "backwash"), ("settling_tank", "decant"), ("separator", "vent"), ("reactor", "vent"),
    ("smelter", "emit"), ("kiln", "vent"), ("flare_stack", "burnoff"), ("gasometer", "release"),
    ("oil_rig", "vent"), ("pipeline", "bleed"), ("manifold", "divert"), ("header", "distribute"),
    ("fuel_tank", "drain"), ("radiator_core", "weap"), ("heat_exchanger", "leak"), ("conduit", "drain"),
    ("carburetor", "drip"), ("injector", "atomize"), ("atomizer", "mist"), ("aerosol_can", "spray"),
    ("paint_gun", "spray"), ("hose", "spurt"), ("nozzle_tip", "jet"), ("spray_bar", "fan"),
    ("cloud", "rain"), ("glacier", "calve"), ("volcano", "erupt"), ("delta", "shed")
]

# 6) ADHIKARANA — Locus/Location → Occurrence verb  (domain: medical & lab settings)
ADHIKARANA_60: List[Tuple[str, str]] = [
    ("operating_room", "operate"), ("icu", "stabilize"), ("emergency_room", "triage"), ("ward", "recover"),
    ("clinic", "treat"), ("pharmacy", "dispense"), ("radiology_suite", "scan"), ("cat_lab", "image"),
    ("mri_room", "image"), ("ct_room", "image"), ("ultrasound_suite", "scan"), ("lab_bench", "analyze"),
    ("biosafety_cabinet", "culture"), ("incubator", "grow"), ("centrifuge_room", "separate"), ("freezer_room", "store"),
    ("blood_bank", "store"), ("vaccination_booth", "immunize"), ("rehab_gym", "rehabilitate"), ("physiotherapy_room", "train"),
    ("consultation_room", "assess"), ("telemedicine_booth", "consult"), ("sleep_lab", "monitor"), ("endoscopy_suite", "scope"),
    ("dialysis_unit", "dialyze"), ("chemotherapy_bay", "infuse"), ("isolation_room", "contain"), ("burn_unit", "treat"),
    ("neonatal_icu", "resuscitate"), ("post_op_bay", "observe"), ("triage_desk", "prioritize"), ("nurse_station", "coordinate"),
    ("cafeteria", "dine"), ("waiting_area", "queue"), ("parking_garage", "park"), ("ambulance_bay", "arrive"),
    ("helipad", "land"), ("conference_room", "meet"), ("boardroom", "decide"), ("auditorium", "present"),
    ("classroom", "learn"), ("simulation_lab", "train"), ("animal_facility", "house"), ("vivarium", "maintain"),
    ("cold_room", "store"), ("clean_room", "assemble"), ("tissue_culture_room", "culture"), ("autoclave_room", "sterilize"),
    ("waste_room", "segregate"), ("linen_room", "launder"), ("supply_room", "stock"), ("server_room", "host"),
    ("data_center", "compute"), ("it_helpdesk", "support"), ("security_office", "monitor"), ("records_room", "archive"),
    ("chapel", "pray"), ("gift_shop", "sell"), ("kitchen", "cook"), ("rooftop_garden", "grow")
]

# 7) HETU — Cause → Effect  (domain: health & physiology)
HETU_60: List[Tuple[str, str]] = [
    ("infection", "fever"), ("allergy", "sneeze"), ("dehydration", "headache"), ("malnutrition", "fatigue"),
    ("overexertion", "cramp"), ("sedentary_lifestyle", "obesity"), ("sleep_deprivation", "drowsiness"), ("stress", "anxiety"),
    ("sun_exposure", "sunburn"), ("cold_exposure", "frostbite"), ("smoking", "cough"), ("alcohol", "hangover"),
    ("food_poisoning", "vomiting"), ("motion_sickness", "nausea"), ("migraine_trigger", "aura"), ("pollen", "itching"),
    ("dust", "sneezing"), ("mold", "wheezing"), ("noise", "tinnitus"), ("bright_light", "photophobia"),
    ("screen_time", "eyestrain"), ("caffeine", "jitters"), ("nicotine", "tachycardia"), ("hypertension", "stroke_risk"),
    ("hyperglycemia", "thirst"), ("hypoglycemia", "dizziness"), ("obesity", "insulin_resistance"), ("anemia", "weakness"),
    ("vitamin_d_deficiency", "bone_pain"), ("iodine_deficiency", "goiter"), ("iron_deficiency", "pallor"), ("b12_deficiency", "neuropathy"),
    ("viral_infection", "myalgia"), ("bacterial_infection", "pus"), ("fungal_infection", "rash"), ("parasite", "diarrhea"),
    ("trauma", "swelling"), ("burn", "blister"), ("frostbite", "necrosis"), ("radiation", "erythema"),
    ("chemotherapy", "hair_loss"), ("antibiotic_overuse", "resistance"), ("steroid_use", "immunosuppression"), ("dehydration_heat", "syncope"),
    ("overhydration", "hyponatremia"), ("overtraining", "fatigue"), ("poor_posture", "back_pain"), ("repetitive_strain", "tendinitis"),
    ("contagion", "outbreak"), ("unprotected_sun", "melanoma_risk"), ("loud_music", "hearing_loss"), ("sedatives", "drowsiness"),
    ("depressants", "slowed_breathing"), ("stimulants", "insomnia"), ("poor_hygiene", "infection"), ("contaminated_water", "diarrhea"),
    ("high_altitude", "hypoxia"), ("low_altitude", "edema_relief"), ("anaphylaxis_trigger", "shock"), ("asthma_trigger", "bronchospasm")
]

# 8) SAMBANDHA — Country → Capital  (domain: political geography)
SAMBANDHA_60: List[Tuple[str, str]] = [
    ("algeria", "algiers"), ("angola", "luanda"), ("benin", "porto_novo"), ("botswana", "gaborone"),
    ("egypt", "cairo"), ("ethiopia", "addis_ababa"), ("ghana", "accra"), ("kenya", "nairobi"),
    ("libya", "tripoli"), ("morocco", "rabat"), ("nigeria", "abuja"), ("rwanda", "kigali"),
    ("senegal", "dakar"), ("somalia", "mogadishu"), ("sudan", "khartoum"), ("uganda", "kampala"),
    ("zambia", "lusaka"), ("zimbabwe", "harare"), ("argentina", "buenos_aires"), ("brazil", "brasilia"),
    ("canada", "ottawa"), ("chile", "santiago"), ("colombia", "bogota"), ("cuba", "havana"),
    ("jamaica", "kingston"), ("mexico", "mexico_city"), ("peru", "lima"), ("venezuela", "caracas"),
    ("afghanistan", "kabul"), ("australia", "canberra"), ("bangladesh", "dhaka"), ("china", "beijing"),
    ("india", "new_delhi"), ("indonesia", "jakarta"), ("iran", "tehran"), ("iraq", "baghdad"),
    ("israel", "jerusalem"), ("japan", "tokyo"), ("jordan", "amman"), ("malaysia", "kuala_lumpur"),
    ("nepal", "kathmandu"), ("pakistan", "islamabad"), ("philippines", "manila"), ("qatar", "doha"),
    ("russia", "moscow"), ("syria", "damascus"), ("thailand", "bangkok"), ("turkey", "ankara"),
    ("vietnam", "hanoi"), ("austria", "vienna"), ("belgium", "brussels"), ("bulgaria", "sofia"),
    ("denmark", "copenhagen"), ("finland", "helsinki"), ("france", "paris"), ("germany", "berlin"),
    ("greece", "athens"), ("hungary", "budapest"), ("ireland", "dublin"), ("italy", "rome")
]

# 9) IS_A — Hyponym → Hypernym  (domain: animalia)
IS_A: List[Tuple[str, str]] = [
    ("lion", "animal"), ("tiger", "animal"), ("leopard", "animal"), ("cheetah", "animal"),
    ("jaguar", "animal"), ("panther", "animal"), ("cougar", "animal"), ("lynx", "animal"),
    ("bobcat", "animal"), ("ocelot", "animal"), ("wolf", "animal"), ("coyote", "animal"),
    ("fox", "animal"), ("jackal", "animal"), ("hyena", "animal"), ("bear", "animal"),
    ("panda", "animal"), ("polar_bear", "animal"), ("sloth", "animal"), ("otter", "animal"),
    ("weasel", "animal"), ("wolverine", "animal"), ("badger", "animal"), ("raccoon", "animal"),
    ("elephant", "animal"), ("rhino", "animal"), ("hippo", "animal"), ("giraffe", "animal"),
    ("zebra", "animal"), ("buffalo", "animal"), ("bison", "animal"), ("antelope", "animal"),
    ("gazelle", "animal"), ("camel", "animal"), ("llama", "animal"), ("alpaca", "animal"),
    ("goat", "animal"), ("sheep", "animal"), ("cow", "animal"), ("yak", "animal"),
    ("horse", "animal"), ("donkey", "animal"), ("pig", "animal"), ("boar", "animal"),
    ("kangaroo", "animal"), ("koala", "animal"), ("platypus", "animal"), ("echidna", "animal"),
    ("orangutan", "animal"), ("gorilla", "animal"), ("chimpanzee", "animal"), ("baboon", "animal"),
    ("lemur", "animal"), ("mongoose", "animal"), ("meerkat", "animal"), ("porcupine", "animal"),
    ("beaver", "animal"), ("squirrel", "animal"), ("hare", "animal"), ("rabbit", "animal")
]

# 10) PART_OF — Component → Whole  (domain: computer hardware)
PART_OF: List[Tuple[str, str]] = [
    ("cpu", "computer"), ("gpu", "computer"), ("motherboard", "computer"), ("ram_module", "computer"),
    ("power_supply", "computer"), ("case", "computer"), ("cooler", "computer"), ("heat_sink", "computer"),
    ("fan", "computer"), ("nvme_drive", "computer"), ("sata_drive", "computer"), ("optical_drive", "computer"),
    ("ethernet_card", "computer"), ("wifi_card", "computer"), ("sound_card", "computer"), ("capture_card", "computer"),
    ("usb_controller", "computer"), ("chipset", "computer"), ("bios_chip", "computer"), ("cmos_battery", "computer"),
    ("vrm", "motherboard"), ("pcie_slot", "motherboard"), ("m2_slot", "motherboard"), ("sata_port", "motherboard"),
    ("cpu_socket", "motherboard"), ("ram_slot", "motherboard"), ("backplate", "motherboard"), ("io_shield", "case"),
    ("front_panel_header", "motherboard"), ("heatsink_fins", "cooler"), ("heatpipe", "cooler"), ("thermal_paste", "cooler"),
    ("fan_blade", "fan"), ("fan_hub", "fan"), ("radiator", "liquid_cooler"), ("pump", "liquid_cooler"),
    ("reservoir", "liquid_cooler"), ("tubing", "liquid_cooler"), ("fuse", "power_supply"), ("transformer", "power_supply"),
    ("inductor", "power_supply"), ("capacitor", "power_supply"), ("bridge_rectifier", "power_supply"), ("pwm_controller", "fan"),
    ("keyboard", "desktop_setup"), ("mouse", "desktop_setup"), ("monitor", "desktop_setup"), ("speaker", "desktop_setup"),
    ("microphone", "desktop_setup"), ("webcam", "desktop_setup"), ("usb_cable", "desktop_setup"), ("hdmi_cable", "desktop_setup"),
    ("display_port_cable", "desktop_setup"), ("surge_protector", "desktop_setup"), ("kvm_switch", "desktop_setup"), ("desk_mat", "desktop_setup"),
    ("wrist_rest", "desktop_setup"), ("mouse_pad", "desktop_setup"), ("vesa_mount", "monitor"), ("backlight", "monitor")
]

# 11) MEMBER_OF — Member → Collection  (domain: team sports)
MEMBER_OF: List[Tuple[str, str]] = [
    ("goalkeeper", "football_team"), ("left_back", "football_team"), ("right_back", "football_team"), ("center_back", "football_team"),
    ("defensive_midfielder", "football_team"), ("central_midfielder", "football_team"), ("attacking_midfielder", "football_team"), ("left_winger", "football_team"),
    ("right_winger", "football_team"), ("striker", "football_team"), ("captain", "football_team"), ("coach", "football_team"),
    ("referee", "match_officials"), ("linesman", "match_officials"), ("fourth_official", "match_officials"), ("var_referee", "match_officials"),
    ("point_guard", "basketball_team"), ("shooting_guard", "basketball_team"), ("small_forward", "basketball_team"), ("power_forward", "basketball_team"),
    ("center", "basketball_team"), ("head_coach", "basketball_team"), ("assistant_coach", "basketball_team"), ("trainer", "basketball_team"),
    ("pitcher", "baseball_team"), ("catcher", "baseball_team"), ("first_baseman", "baseball_team"), ("second_baseman", "baseball_team"),
    ("shortstop", "baseball_team"), ("third_baseman", "baseball_team"), ("left_fielder", "baseball_team"), ("center_fielder", "baseball_team"),
    ("right_fielder", "baseball_team"), ("designated_hitter", "baseball_team"), ("manager", "baseball_team"), ("bullpen_coach", "baseball_team"),
    ("setter", "volleyball_team"), ("libero", "volleyball_team"), ("middle_blocker", "volleyball_team"), ("opposite_hitter", "volleyball_team"),
    ("outside_hitter", "volleyball_team"), ("server", "volleyball_team"), ("wing_defense", "netball_team"), ("goal_defense", "netball_team"),
    ("goal_attack", "netball_team"), ("goal_shooter", "netball_team"), ("wing_attack", "netball_team"), ("center_position", "netball_team"),
    ("scrum_half", "rugby_team"), ("fly_half", "rugby_team"), ("hooker", "rugby_team"), ("prop", "rugby_team"),
    ("lock", "rugby_team"), ("flanker", "rugby_team"), ("number_eight", "rugby_team"), ("fullback", "rugby_team")
]

# 12) LOCATED_IN — Entity → Place  (domain: US cities → state)
LOCATED_IN: List[Tuple[str, str]] = [
    ("new_york_city", "new_york"), ("los_angeles", "california"), ("chicago", "illinois"), ("houston", "texas"),
    ("phoenix", "arizona"), ("philadelphia", "pennsylvania"), ("san_antonio", "texas"), ("san_diego", "california"),
    ("dallas", "texas"), ("san_jose", "california"), ("austin", "texas"), ("jacksonville", "florida"),
    ("fort_worth", "texas"), ("columbus", "ohio"), ("charlotte", "north_carolina"), ("san_francisco", "california"),
    ("indianapolis", "indiana"), ("seattle", "washington"), ("denver", "colorado"), ("washington_dc", "district_of_columbia"),
    ("boston", "massachusetts"), ("el_paso", "texas"), ("nashville", "tennessee"), ("detroit", "michigan"),
    ("oklahoma_city", "oklahoma"), ("portland", "oregon"), ("las_vegas", "nevada"), ("memphis", "tennessee"),
    ("louisville", "kentucky"), ("baltimore", "maryland"), ("milwaukee", "wisconsin"), ("albuquerque", "new_mexico"),
    ("tucson", "arizona"), ("fresno", "california"), ("sacramento", "california"), ("kansas_city", "missouri"),
    ("mesa", "arizona"), ("atlanta", "georgia"), ("omaha", "nebraska"), ("colorado_springs", "colorado"),
    ("raleigh", "north_carolina"), ("miami", "florida"), ("long_beach", "california"), ("virginia_beach", "virginia"),
    ("oakland", "california"), ("minneapolis", "minnesota"), ("tulsa", "oklahoma"), ("wichita", "kansas"),
    ("new_orleans", "louisiana"), ("arlington", "texas"), ("tampa", "florida"), ("aurora", "colorado"),
    ("honolulu", "hawaii"), ("anaheim", "california"), ("santa_ana", "california"), ("st_louis", "missouri"),
    ("pittsburgh", "pennsylvania"), ("orlando", "florida"), ("cincinnati", "ohio"), ("cleveland", "ohio")
]

# 13) MADE_OF — Object → Material  (domain: consumer & household items)
MADE_OF: List[Tuple[str, str]] = [
    ("wine_bottle", "glass"), ("window_pane", "glass"), ("mirror_panel", "glass"), ("laboratory_beaker", "glass"),
    ("statue_bust", "marble"), ("kitchen_counter", "marble"), ("floor_tile", "marble"), ("sculpture_block", "marble"),
    ("bridge_cable", "steel"), ("cutlery", "steel"), ("spring_coil", "steel"), ("knife_blade", "steel"),
    ("hemp_rope", "hemp"), ("hemp_sack", "hemp"), ("cotton_canvas", "cotton"), ("cotton_shirt", "cotton"),
    ("wool_sweater", "wool"), ("wool_blanket", "wool"), ("copper_wire", "copper"), ("copper_pipe", "copper"),
    ("fiberglass_board", "fiberglass"), ("polycarbonate_helmet", "polycarbonate"), ("kevlar_vest", "kevlar"), ("pvc_pipe", "pvc"),
    ("clay_brick", "clay"), ("clay_pot", "clay"), ("ceramic_tile", "ceramic"), ("ceramic_mug", "ceramic"),
    ("cast_iron_pan", "iron"), ("cast_iron_skillet", "iron"), ("aluminum_foil", "aluminum"), ("aluminum_can", "aluminum"),
    ("nickel_coin", "nickel"), ("lithium_battery", "lithium"), ("silica_lens", "silica"), ("silicon_chip", "silicon"),
    ("wooden_table", "wood"), ("wooden_chair", "wood"), ("wooden_door", "wood"), ("wooden_floor", "wood"),
    ("plastic_bottle", "plastic"), ("plastic_bag", "plastic"), ("plastic_bucket", "plastic"), ("plastic_toy", "plastic"),
    ("gold_ring", "gold"), ("gold_necklace", "gold"), ("silver_bracelet", "silver"), ("steel_spoon", "steel"),
    ("bronze_statue", "bronze"), ("bronze_bell", "bronze"), ("carbon_fiber_frame", "carbon_fiber"), ("bike_frame", "carbon_fiber"),
    ("concrete_slab", "concrete"), ("asphalt_road", "asphalt"), ("graphite_pencil_core", "graphite"), ("glass_bulb_envelope", "glass"),
    ("nylon_guitar_string", "nylon"), ("kevlar_helmet", "kevlar"), ("copper_heat_pipe", "copper"), ("silicon_solar_cell", "silicon")
]

# 14) CREATED_BY — Artifact/Work → Creator  (domain: literature — novels → authors)
CREATED_BY: List[Tuple[str, str]] = [
    ("pride_and_prejudice", "jane_austen"), ("sense_and_sensibility", "jane_austen"), ("emma", "jane_austen"), ("moby_dick", "herman_melville"),
    ("the_scarlet_letter", "nathaniel_hawthorne"), ("great_expectations", "charles_dickens"), ("oliver_twist", "charles_dickens"), ("david_copperfield", "charles_dickens"),
    ("jane_eyre", "charlotte_bronte"), ("wuthering_heights", "emily_bronte"), ("middlemarch", "george_eliot"), ("the_picture_of_dorian_gray", "oscar_wilde"),
    ("dracula", "bram_stoker"), ("frankenstein", "mary_shelley"), ("crime_and_punishment", "fyodor_dostoevsky"), ("war_and_peace", "leo_tolstoy"),
    ("anna_karenina", "leo_tolstoy"), ("the_brothers_karamazov", "fyodor_dostoevsky"), ("ulysses", "james_joyce"), ("the_great_gatsby", "f_scott_fitzgerald"),
    ("to_kill_a_mockingbird", "harper_lee"), ("catch_22", "joseph_heller"), ("1984", "george_orwell"), ("animal_farm", "george_orwell"),
    ("brave_new_world", "aldous_huxley"), ("lord_of_the_flies", "william_goldings"), ("the_lord_of_the_rings", "j_r_r_tolkien"), ("the_hobbit", "j_r_r_tolkien"),
    ("one_hundred_years_of_solitude", "gabriel_garcia_marquez"), ("love_in_the_time_of_cholera", "gabriel_garcia_marquez"), ("the_stranger", "albert_camus"), ("the_plague", "albert_camus"),
    ("the_trial", "franz_kafka"), ("don_quixote", "miguel_de_cervantes"), ("faust", "johann_wolfgang_von_goethe"), ("madame_bovary", "gustave_flaubert"),
    ("les_miserables", "victor_hugo"), ("the_count_of_monte_cristo", "alexandre_dumas"), ("the_three_musketeers", "alexandre_dumas"), ("the_idiot", "fyodor_dostoevsky"),
    ("white_whale", "herman_melville"), ("the_sun_also_rises", "ernest_hemingway"), ("for_whom_the_bell_tolls", "ernest_hemingway"), ("old_man_and_the_sea", "ernest_hemingway"),
    ("beloved", "toni_morrison"), ("song_of_solomon", "toni_morrison"), ("a_farewell_to_arms", "ernest_hemingway"), ("slaughterhouse_five", "kurt_vonnegut"),
    ("the_sound_and_the_fury", "william_faulkner"), ("as_i_lay_dying", "william_faulkner"), ("the_handmaids_tale", "margaret_atwood"), ("the_blind_assassin", "margaret_atwood"),
    ("the_name_of_the_rose", "umberto_eco"), ("if_on_a_winters_night_a_traveler", "italo_calvino"), ("norwegian_wood", "haruki_murakami"), ("kafka_on_the_shore", "haruki_murakami")
]

# 15) WORKS_FOR — Person/role → Organization  (domain: tech & research orgs)
WORKS_FOR: List[Tuple[str, str]] = [
    ("software_engineer", "google"), ("research_scientist", "google"), ("product_manager", "google"), ("site_reliability_engineer", "google"),
    ("data_scientist", "microsoft"), ("ml_engineer", "microsoft"), ("program_manager", "microsoft"), ("ux_designer", "microsoft"),
    ("hardware_engineer", "apple"), ("ios_developer", "apple"), ("silicon_architect", "apple"), ("industrial_designer", "apple"),
    ("android_developer", "meta"), ("research_engineer", "meta"), ("data_engineer", "meta"), ("security_engineer", "meta"),
    ("systems_engineer", "amazon"), ("solutions_architect", "amazon"), ("applied_scientist", "amazon"), ("devops_engineer", "amazon"),
    ("graphics_engineer", "nvidia"), ("systems_researcher", "nvidia"), ("compiler_engineer", "nvidia"), ("asic_engineer", "nvidia"),
    ("cpu_architect", "amd"), ("firmware_engineer", "amd"), ("driver_engineer", "amd"), ("design_engineer", "amd"),
    ("quant_researcher", "two_sigma"), ("quant_trader", "two_sigma"), ("risk_analyst", "two_sigma"), ("software_engineer", "two_sigma"),
    ("research_scientist", "openai"), ("policy_researcher", "openai"), ("alignment_researcher", "openai"), ("applied_engineer", "openai"),
    ("robotics_engineer", "boston_dynamics"), ("controls_engineer", "boston_dynamics"), ("field_engineer", "boston_dynamics"), ("test_engineer", "boston_dynamics"),
    ("autonomy_engineer", "tesla"), ("battery_engineer", "tesla"), ("manufacturing_engineer", "tesla"), ("vehicle_designer", "tesla"),
    ("bioinformatician", "broad_institute"), ("genomics_scientist", "broad_institute"), ("computational_biologist", "broad_institute"), ("research_assistant", "broad_institute"),
    ("assistant_professor", "mit"), ("associate_professor", "mit"), ("full_professor", "mit"), ("postdoc", "mit"),
    ("phd_student", "stanford"), ("ms_student", "stanford"), ("research_engineer", "stanford"), ("lab_manager", "stanford")
]

# 16) AGENT_USES — Agent/profession → Tool  (domain: photography & filmmaking)
AGENT_USES: List[Tuple[str, str]] = [
    ("photographer", "dslr_camera"), ("photographer", "mirrorless_camera"), ("photographer", "tripod"), ("photographer", "light_meter"),
    ("videographer", "cinema_camera"), ("videographer", "gimbal"), ("videographer", "external_recorder"), ("videographer", "follow_focus"),
    ("gaffer", "light_panel"), ("gaffer", "softbox"), ("gaffer", "c_stand"), ("gaffer", "flag"),
    ("grip", "dolly"), ("grip", "slider"), ("grip", "jib"), ("grip", "clamp"),
    ("sound_recordist", "boom_mic"), ("sound_recordist", "shotgun_mic"), ("sound_recordist", "field_mixer"), ("sound_recordist", "lav_mic"),
    ("colorist", "grading_panel"), ("colorist", "reference_monitor"), ("editor", "nle_workstation"), ("editor", "raid_array"),
    ("director", "viewfinder"), ("director", "monitor"), ("script_supervisor", "slate"), ("script_supervisor", "continuity_log"),
    ("focus_puller", "wireless_follow_focus"), ("camera_op", "fluid_head"), ("camera_assistant", "matte_box"), ("camera_assistant", "nd_filter"),
    ("steadicam_operator", "steadicam_rig"), ("drone_pilot", "quadcopter"), ("drone_pilot", "remote_controller"), ("dit", "ingest_station"),
    ("lighting_tech", "barn_doors"), ("lighting_tech", "gel"), ("lighting_tech", "scrim"), ("lighting_tech", "diffuser"),
    ("production_sound_mixer", "timecode_slate"), ("production_sound_mixer", "headphones"), ("boom_operator", "boom_pole"), ("boom_operator", "shock_mount"),
    ("stills_photographer", "prime_lens"), ("stills_photographer", "zoom_lens"), ("stills_photographer", "flash"), ("stills_photographer", "reflector"),
    ("gaffer", "light_dimmer"), ("grip", "apple_box"), ("grip", "sandbag"), ("grip", "spud"),
    ("script_supervisor", "shot_list"), ("director", "storyboard"), ("editor", "color_checker"), ("colorist", "lut_box")
]

# 17) CAUSES — Cause → Effect  (domain: weather & environment)
CAUSES: List[Tuple[str, str]] = [
    ("rain", "flooding"), ("drought", "crop_failure"), ("heatwave", "dehydration"), ("cold_snap", "frost_damage"),
    ("hailstorm", "dented_roofs"), ("blizzard", "whiteout"), ("ice_storm", "power_outage"), ("windstorm", "treefall"),
    ("thunderstorm", "lightning"), ("lightning", "wildfire"), ("wildfire", "smoke"), ("smoke", "haze"),
    ("volcanic_eruption", "ashfall"), ("earthquake", "building_collapse"), ("tsunami", "inundation"), ("landslide", "road_blockage"),
    ("coastal_erosion", "beach_loss"), ("river_erosion", "bank_collapse"), ("sedimentation", "river_shallowing"), ("desertification", "soil_infertility"),
    ("urban_sprawl", "long_commute"), ("traffic", "congestion"), ("congestion", "delay"), ("air_pollution", "smog"),
    ("water_pollution", "fish_kill"), ("noise_pollution", "annoyance"), ("light_pollution", "skyglow"), ("plastic_dumping", "ocean_gyres"),
    ("overfishing", "stock_depletion"), ("deforestation", "habitat_loss"), ("afforestation", "carbon_sink"), ("reforestation", "soil_stability"),
    ("heat_island", "higher_temperatures"), ("cold_front", "freeze"), ("warm_front", "rain"), ("pressure_drop", "storminess"),
    ("cyclone", "storm_surge"), ("hurricane", "flooding"), ("typhoon", "wind_damage"), ("tornado", "structural_damage"),
    ("snowmelt", "spring_flood"), ("glacier_retreat", "sea_level_rise"), ("permafrost_thaw", "methane_release"), ("ocean_warming", "coral_bleaching"),
    ("el_nino", "drought"), ("la_nina", "flooding"), ("acid_rain", "lake_acidification"), ("ozone_depletion", "uv_exposure"),
    ("oil_spill", "coastline_damage"), ("chemical_leak", "toxicity"), ("nuclear_accident", "radiation"), ("mine_tailings", "river_contamination"),
    ("dam_failure", "downstream_flood"), ("levee_breach", "inundation"), ("storm_surge", "coastal_flood"), ("king_tide", "tidal_flood"),
    ("siltation", "navigation_hazard"), ("over_irrigation", "salinization"), ("overgrazing", "land_degradation"), ("peat_fire", "haze")
]


### Generating Embeddings

#### **Word2Vec (via Tensorflow Hub)**

In [4]:
### 1) TF-Hub Wiki-words (word2vec-style) wrapper
W2V_HANDLE = "https://tfhub.dev/google/Wiki-words-500-with-normalization/2"

def init_w2v_embedder(handle: str = W2V_HANDLE):
    """
    Returns a simple callable embed(words: List[str]) -> np.ndarray (float32).
    Uses hub.load() which returns a callable SavedModel for this TF2 handle.
    """
    try:
        module = hub.load(handle)   # TF2 saved model; returns callable in many TF-Hub word models
    except Exception as e:
        # fallback to KerasLayer (works for many TF2 hub models)
        module = hub.KerasLayer(handle)

    def embed(words: List[str]) -> np.ndarray:
        # Accepts list of python strings and returns (N, D) float32 numpy array
        try:
            out = module(words)           # often works: module(["hello","world"])
        except Exception:
            # try calling via tf.constant
            out = module(tf.constant(words))
        # ensure numpy float32
        arr = np.asarray(out)
        if arr.dtype != np.float32:
            arr = arr.astype(np.float32)
        return arr

    return embed

#### **GlOVE (via Stanford NLP)**

In [5]:
def load_glove_from_zip(zip_path: str, dim: int = 100, filename: Optional[str] = None,
                        progress: bool = False) -> Dict[str, np.ndarray]:
    """
    Load glove.6B.{dim}d.txt from a zip (e.g., glove.6B.zip).
    Returns dict: word -> np.float32 vector.
    This avoids gensim and reads directly from the .txt inside the zip.
    """
    if filename is None:
        filename = f"glove.6B.{dim}d.txt"

    embeddings = {}
    with zipfile.ZipFile(zip_path, "r") as zf:
        members = zf.namelist()
        if filename not in members:
            raise FileNotFoundError(f"{filename} not in {zip_path}; available: {members[:10]}")

        with zf.open(filename, "r") as f:
            for i, raw_line in enumerate(f):
                line = raw_line.decode("utf-8", errors="replace").strip()
                if not line:
                    continue
                parts = line.split()
                word = parts[0]
                vals = np.asarray(parts[1:], dtype=np.float32)
                if vals.shape[0] != dim:
                    # skip / warn
                    continue
                embeddings[word] = vals
                if progress and (i % 200000 == 0):
                    print(f"read {i} lines...")

    return embeddings

def build_and_save_glove_matrix(vocab: List[str], glove_dict: Dict[str, np.ndarray],
                                dim: int, out_path: str):
    """
    Build embedding matrix aligned to vocab (list of tokens).
    Unknown tokens get zero vectors. Saves matrix and vocab->idx mapping.
    """
    vocab_lower = [w.lower() for w in vocab]  # match typical glove tokens
    V = len(vocab_lower)
    mat = np.zeros((V, dim), dtype=np.float32)
    missing = []
    for i, w in enumerate(vocab_lower):
        if w in glove_dict:
            mat[i] = glove_dict[w]
        else:
            # try some normalization fallbacks
            w2 = w.replace('-', ' ')
            if w2 in glove_dict:
                mat[i] = glove_dict[w2]
            else:
                missing.append(w)
                # leave zero vector or random init if you prefer:
                # mat[i] = np.random.normal(scale=0.01, size=(dim,)).astype(np.float32)
    # save
    np.save(out_path + ".npy", mat)
    with open(out_path + ".vocab.json", "w", encoding="utf-8") as f:
        json.dump(vocab_lower, f, ensure_ascii=False, indent=2)
    return mat, missing

#### **BERT (via Hugging Face Transformers)**

In [6]:
class BertEmbedder:
    """A wrapper for Hugging Face's BERT model to generate embeddings."""
    def __init__(self, model_id: str, device: Optional[str] = None):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id)
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model.to(self.device).eval()
        self.dim = self.model.config.hidden_size
        print(f"BERT model loaded on device: {self.device}")

    @torch.no_grad()
    def encode(self, texts: List[str], batch_size=32, max_length=32) -> np.ndarray:
        """Encodes a list of texts into embeddings using mean pooling."""
        all_embeds = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            inputs = self.tokenizer(batch, padding=True, truncation=True,
                                    max_length=max_length, return_tensors="pt").to(self.device)

            outputs = self.model(**inputs)

            # Perform mean pooling
            mask = inputs["attention_mask"].unsqueeze(-1).expand(outputs.last_hidden_state.size())
            masked_embeddings = outputs.last_hidden_state * mask
            summed = torch.sum(masked_embeddings, 1)
            counts = torch.clamp(mask.sum(1), min=1e-9)
            mean_pooled = summed / counts

            all_embeds.append(mean_pooled.cpu().numpy())

        return np.vstack(all_embeds).astype(np.float32)

# --- Initialize BERT ---
print("\nSetting up BERT...")
BERT_MODEL_ID = "bert-base-uncased"
bert_embedder = BertEmbedder(BERT_MODEL_ID)
print(f"✅ BERT is ready! Model: {BERT_MODEL_ID}, Embedding dimension: {bert_embedder.dim}")


Setting up BERT...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BERT model loaded on device: cpu
✅ BERT is ready! Model: bert-base-uncased, Embedding dimension: 768


#### **Roberta**

In [7]:
# This is the same class you provided, renamed for generality
class HuggingFaceEmbedder:
    """A generalized wrapper for Hugging Face models to generate phrase embeddings."""
    def __init__(self, model_id: str, device: Optional[str] = None):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id)
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model.to(self.device).eval()
        self.dim = self.model.config.hidden_size
        print(f"Model '{model_id}' loaded on device: {self.device}")

    @torch.no_grad()
    def encode(self, texts: List[str], batch_size=32, max_length=32) -> np.ndarray:
        """Encodes a list of texts into embeddings using mean pooling."""
        all_embeds = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            inputs = self.tokenizer(batch, padding=True, truncation=True,
                                    max_length=max_length, return_tensors="pt").to(self.device)

            outputs = self.model(**inputs)

            # Perform mean pooling on the last hidden state
            mask = inputs["attention_mask"].unsqueeze(-1)
            summed = (outputs.last_hidden_state * mask).sum(dim=1)
            counts = torch.clamp(mask.sum(dim=1), min=1e-9)
            mean_pooled = summed / counts

            all_embeds.append(mean_pooled.cpu().numpy())

        return np.vstack(all_embeds).astype(np.float32)

# --- Initialize RoBERTa ---
print("Setting up RoBERTa...")
ROBERTA_MODEL_ID = "roberta-base"
roberta_embedder = HuggingFaceEmbedder(ROBERTA_MODEL_ID)
print(f"✅ RoBERTa is ready! Model: {ROBERTA_MODEL_ID}, Embedding dimension: {roberta_embedder.dim}")

Setting up RoBERTa...


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model 'roberta-base' loaded on device: cpu
✅ RoBERTa is ready! Model: roberta-base, Embedding dimension: 768


#### **LaBSE**

In [8]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # <-- CRITICAL IMPORT: This registers the custom ops
import numpy as np
from typing import List

class LaBSEmbedder:
    """
    A corrected wrapper for the LaBSE model from TensorFlow Hub that ensures
    the necessary pre-processor and its dependencies are loaded correctly.
    """
    def __init__(self, handle: str):
        # Define the handles for the pre-processor and the encoder
        preprocessor_handle = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

        # Load the pre-processor and the main LaBSE encoder model
        self.preprocessor = hub.KerasLayer(preprocessor_handle)
        self.encoder = hub.KerasLayer(handle, trainable=False)

        # Probe the model to get its output dimension
        test_text = tf.constant(["probing the model"])
        processed_test_text = self.preprocessor(test_text)
        embedding_result = self.encoder(processed_test_text)
        self.dim = int(embedding_result["pooled_output"].shape[-1])

    def encode(self, texts: List[str]) -> np.ndarray:
        """
        Encodes a list of texts by first running them through the pre-processor
        and then the LaBSE encoder.
        """
        text_tensor = tf.constant(texts)
        processed_text = self.preprocessor(text_tensor)
        embedding_result = self.encoder(processed_text)
        return embedding_result["pooled_output"].numpy().astype(np.float32)

# --- Initialize LaBSE with the corrected class ---
print("Setting up LaBSE...")
LABSE_HANDLE = "https://tfhub.dev/google/LaBSE/2"
labse_embedder = LaBSEmbedder(LABSE_HANDLE)
print(f"✅ LaBSE is ready! Embedding dimension: {labse_embedder.dim}")

# --- Example Usage ---
sample_embedding = labse_embedder.encode(["hello world"])
print(f"Successfully created a sample embedding with shape: {sample_embedding.shape}")

Setting up LaBSE...
✅ LaBSE is ready! Embedding dimension: 768
Successfully created a sample embedding with shape: (1, 768)


## Relation PCA

In [9]:
import os
import math
from collections import defaultdict

REPORT_DIR = "./relation_pca_reports"
os.makedirs(REPORT_DIR, exist_ok=True)

In [10]:
# -------------------------
# Helpers: normalization
# -------------------------
def l2_normalize_rows(X: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / np.clip(norms, eps, None)

def safe_center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(axis=0, keepdims=True)

def explained_variance_from_singular_values(s: np.ndarray, n_samples: int) -> np.ndarray:
    # For centered data X (n x d), with SVD X = U S V^T, eigenvalues of covariance (1/(n-1)) X^T X are (S^2)/(n-1).
    # Explained variance ratios are proportional to s^2.
    lam = (s ** 2) / max(n_samples - 1, 1)
    total = lam.sum() if lam.sum() > 0 else 1.0
    evr = lam / total
    return lam, evr

In [11]:
# -------------------------
# Embedding registry
# -------------------------

# 1) Word2Vec (TF-Hub)
try:
    w2v_embed = init_w2v_embedder(W2V_HANDLE)
    W2V_READY = True
except Exception as e:
    print(f"[WARN] Could not init TF-Hub Word2Vec: {e}")
    w2v_embed, W2V_READY = None, False

# 2) GloVe (zip reading; optional)
GLOVE_READY = False
GLOVE_DIM = 100
GLOVE_DICT = {}
_GLOVE_CANDIDATES = ["/content/glove.6B.zip", "./glove.6B.zip", "/mnt/data/glove.6B.zip"]
for _cand in _GLOVE_CANDIDATES:
    if os.path.exists(_cand):
        try:
            GLOVE_DICT = load_glove_from_zip(_cand, dim=GLOVE_DIM, progress=False)
            GLOVE_READY = True
            print(f"[INFO] Loaded GloVe {GLOVE_DIM}d from: {_cand} (|V|={len(GLOVE_DICT)})")
            break
        except Exception as e:
            print(f"[WARN] Failed loading GloVe from {_cand}: {e}")

# 3) BERT / RoBERTa / LaBSE already instantiated above:
BERT_READY = True if 'bert_embedder' in globals() else False
ROBERTA_READY = True if 'roberta_embedder' in globals() else False
LABSE_READY = True if 'labse_embedder' in globals() else False

def embed_words(words: List[str], embedding_kind: str) -> np.ndarray:
    """
    Returns an array of shape (N, D) for a list of tokens/strings using the selected embedding.
    Supported: 'w2v', 'glove', 'bert', 'roberta', 'labse'
    """
    if embedding_kind == "w2v":
        if not W2V_READY:
            raise RuntimeError("Word2Vec embedder not initialized.")
        arr = w2v_embed(words)  # TF-Hub callable
        return np.asarray(arr, dtype=np.float32)

    if embedding_kind == "glove":
        if not GLOVE_READY:
            raise RuntimeError("GloVe not loaded. Put glove.6B.zip in one of the known paths.")
        # Look up each word (lowercased); OOV -> zeros
        out = np.zeros((len(words), GLOVE_DIM), dtype=np.float32)
        for i, w in enumerate(words):
            key = w.lower()
            vec = GLOVE_DICT.get(key)
            if vec is None:
                # try a few mild normalizations
                key2 = key.replace('-', ' ')
                vec = GLOVE_DICT.get(key2)
            if vec is not None:
                out[i] = vec
        return out

    if embedding_kind == "bert":
        if not BERT_READY:
            raise RuntimeError("BERT embedder not initialized.")
        return bert_embedder.encode(words)

    if embedding_kind == "roberta":
        if not ROBERTA_READY:
            raise RuntimeError("RoBERTa embedder not initialized.")
        return roberta_embedder.encode(words)

    if embedding_kind == "labse":
        if not LABSE_READY:
            raise RuntimeError("LaBSE embedder not initialized.")
        return labse_embedder.encode(words)

    raise ValueError(f"Unknown embedding_kind: {embedding_kind}")

In [12]:
def relation_pca_for_pairs(pairs: List[Tuple[str, str]],
                           embedding_kind: str,
                           normalize_rows: bool = NORMALIZE_ROWS,
                           k_max: int = K_MAX,
                           var_target: float = VAR_TARGET,
                           drop_zero_rows: bool = True) -> Dict:
    """
    Compute PCA on relation difference vectors r_i = e(v_i) - e(u_i).
    Returns dict of metrics, spectrum, basis (top-k*) and bookkeeping.
    """
    if len(pairs) == 0:
        raise ValueError("Empty pair list")

    U = [u for (u, v) in pairs]
    V = [v for (u, v) in pairs]

    # Embed
    EU = embed_words(U, embedding_kind)
    EV = embed_words(V, embedding_kind)

    # Optional row-wise L2 normalize BEFORE differencing (common for relational subspaces)
    if normalize_rows:
        EU = l2_normalize_rows(EU)
        EV = l2_normalize_rows(EV)

    # Differences
    R = EV - EU  # shape (n, d)

    if drop_zero_rows:
        # Filter any rows where either EU or EV is (near) zero norm (can happen with OOV for GloVe/W2V)
        mask_u = np.linalg.norm(EU, axis=1) > 1e-9
        mask_v = np.linalg.norm(EV, axis=1) > 1e-9
        mask = mask_u & mask_v
        if mask.sum() < len(pairs):
            print(f"[{embedding_kind}] Dropping {(~mask).sum()} pairs due to OOV/zero vectors.")
        R = R[mask]
        kept_idx = np.where(mask)[0].tolist()
    else:
        kept_idx = list(range(len(pairs)))

    n, d = R.shape
    if n < 2:
        raise ValueError(f"Not enough valid pairs after filtering for {embedding_kind}: n={n}")

    # Center
    Rc = safe_center(R)

    # SVD
    U_svd, s, Vt = np.linalg.svd(Rc, full_matrices=False)

    # Spectrum
    lam, evr = explained_variance_from_singular_values(s, n)
    cum_evr = np.cumsum(evr)

    # Select k* (smallest k achieving target)
    k_eff = min(k_max, len(s))
    k_star = int(np.searchsorted(cum_evr[:k_eff], var_target) + 1)
    if k_star > k_eff:
        k_star = k_eff

    evr_at_k = float(cum_evr[k_star - 1])
    leakage_at_k = float(1.0 - evr_at_k)

    # Basis vectors for top-k*
    basis_k = Vt[:k_star].copy()  # shape (k*, d)

    # Some descriptive stats on relation vectors
    norms = np.linalg.norm(Rc, axis=1)
    med_norm = float(np.median(norms))
    mean_norm = float(np.mean(norms))
    std_norm = float(np.std(norms))

    # Alignment with first PC (signed projection)
    pc1 = Vt[0]
    proj_pc1 = Rc @ pc1
    med_abs_proj_pc1 = float(np.median(np.abs(proj_pc1)))
    mean_abs_proj_pc1 = float(np.mean(np.abs(proj_pc1)))

    # Per-dimension table
    spectrum = []
    for i in range(len(s)):
        spectrum.append({
            "component": i + 1,
            "singular_value": float(s[i]),
            "eigenvalue": float(lam[i]),
            "explained_variance_ratio": float(evr[i]),
            "cumulative_evr": float(cum_evr[i]),
        })

    # Pack results
    result = {
        "n_pairs_used": int(n),
        "embedding_kind": embedding_kind,
        "embedding_dim": int(d),
        "k_max": int(k_max),
        "var_target": float(var_target),
        "k_star": int(k_star),
        "evr_at_k": evr_at_k,
        "orthogonal_leakage_at_k": leakage_at_k,
        "median_centered_norm": med_norm,
        "mean_centered_norm": mean_norm,
        "std_centered_norm": std_norm,
        "median_abs_proj_pc1": med_abs_proj_pc1,
        "mean_abs_proj_pc1": mean_abs_proj_pc1,
        "spectrum": spectrum,
        "basis_top_k": basis_k,      # np.ndarray (k*, d)
        "kept_indices": kept_idx,    # which pairs were kept
    }
    return result

In [13]:
RELATION_SETS: Dict[str, List[Tuple[str, str]]] = {
    # Kāraka-style
    "KARTA": KARTA_60,
    "KARMA": KARMA_60,
    "KARANA": KARANA_60,
    "SAMPRADANA": SAMPRADANA_60,
    "APADANA": APADANA_60,
    "ADHIKARANA": ADHIKARANA_60,
    "HETU": HETU_60,
    "SAMBANDHA_country_capital": SAMBANDHA_60,
    # KG-style
    "IS_A": IS_A,
    "PART_OF": PART_OF,
    "MEMBER_OF": MEMBER_OF,
    "LOCATED_IN": LOCATED_IN,
    "MADE_OF": MADE_OF,
    "CREATED_BY": CREATED_BY,
    "WORKS_FOR": WORKS_FOR,
    "AGENT_USES": AGENT_USES,
    "CAUSES": CAUSES,
}

EMBEDDING_KINDS = []
if W2V_READY: EMBEDDING_KINDS.append("w2v")
if GLOVE_READY: EMBEDDING_KINDS.append("glove")
if BERT_READY: EMBEDDING_KINDS.append("bert")
if ROBERTA_READY: EMBEDDING_KINDS.append("roberta")
if LABSE_READY: EMBEDDING_KINDS.append("labse")

print(f"[INFO] Running Relation-PCA for embeddings: {EMBEDDING_KINDS}")

[INFO] Running Relation-PCA for embeddings: ['w2v', 'bert', 'roberta', 'labse']


In [14]:
# -------------------------
# Run analyses & write reports
# -------------------------
OVERVIEW_ROWS = []
KSTAR = {}
EVR_AT_K = {}
LEAK_AT_K = {}

for rel_name, pairs in RELATION_SETS.items():
    for emb_kind in EMBEDDING_KINDS:
        try:
            res = relation_pca_for_pairs(
                pairs=pairs,
                embedding_kind=emb_kind,
                normalize_rows=NORMALIZE_ROWS,
                k_max=K_MAX,
                var_target=VAR_TARGET,
                drop_zero_rows=True
            )

            # Store key metrics in memory
            KSTAR[(emb_kind, rel_name)] = res["k_star"]
            EVR_AT_K[(emb_kind, rel_name)] = res["evr_at_k"]
            LEAK_AT_K[(emb_kind, rel_name)] = res["orthogonal_leakage_at_k"]

            # Write detailed spectrum CSV
            spec_csv = os.path.join(REPORT_DIR, f"spectrum_{emb_kind}_{rel_name}.csv")
            with open(spec_csv, "w", encoding="utf-8") as f:
                f.write("component,singular_value,eigenvalue,explained_variance_ratio,cumulative_evr\n")
                for row in res["spectrum"]:
                    f.write(f"{row['component']},{row['singular_value']},{row['eigenvalue']},{row['explained_variance_ratio']},{row['cumulative_evr']}\n")

            # Write compact JSON (without basis to keep file small)
            compact = {k: v for k, v in res.items() if k not in ("basis_top_k",)}
            json_path = os.path.join(REPORT_DIR, f"summary_{emb_kind}_{rel_name}.json")
            with open(json_path, "w", encoding="utf-8") as jf:
                json.dump(compact, jf, ensure_ascii=False, indent=2)

            # Optionally save basis as .npy for reproducibility
            basis_path = os.path.join(REPORT_DIR, f"basis_{emb_kind}_{rel_name}_topk.npy")
            np.save(basis_path, res["basis_top_k"])

            OVERVIEW_ROWS.append({
                "relation": rel_name,
                "embedding": emb_kind,
                "n_pairs_used": res["n_pairs_used"],
                "embedding_dim": res["embedding_dim"],
                "k_star": res["k_star"],
                "evr_at_k": res["evr_at_k"],
                "orthogonal_leakage_at_k": res["orthogonal_leakage_at_k"],
                "median_centered_norm": res["median_centered_norm"],
            })

            print(f"[OK] {emb_kind:8s} | {rel_name:28s} | k*={res['k_star']:2d} | EVR={res['evr_at_k']:.3f} | leak={res['orthogonal_leakage_at_k']:.3f} | n={res['n_pairs_used']}")

        except Exception as e:
            print(f"[FAIL] {emb_kind} | {rel_name}: {e}")

# Save overview
if OVERVIEW_ROWS:
    overview_df = pd.DataFrame(OVERVIEW_ROWS).sort_values(["embedding", "relation"]).reset_index(drop=True)
    overview_csv = os.path.join(REPORT_DIR, "summary_overview.csv")
    overview_json = os.path.join(REPORT_DIR, "summary_overview.json")
    overview_df.to_csv(overview_csv, index=False)
    with open(overview_json, "w", encoding="utf-8") as jf:
        json.dump(OVERVIEW_ROWS, jf, ensure_ascii=False, indent=2)
    print(f"[INFO] Wrote overview to:\n  {overview_csv}\n  {overview_json}")
else:
    print("[WARN] No overview rows written (all runs failed?).")

# -------------------------
# Tabular presentation instead of chart
# -------------------------
if OVERVIEW_ROWS:
    rels = list(RELATION_SETS.keys())
    embs = EMBEDDING_KINDS

    # 1) Pretty, human-readable table with rows = embeddings, cols = relations
    #    and each cell showing: k*, EVR, leakage
    pretty_cells = {emb: {} for emb in embs}
    for emb in embs:
        for rel in rels:
            key = (emb, rel)
            if key in EVR_AT_K:
                pretty_cells[emb][rel] = f"k={KSTAR[key]}, evr={EVR_AT_K[key]:.2f}, leak={LEAK_AT_K[key]:.2f}"
            else:
                pretty_cells[emb][rel] = "—"

    pretty_df = pd.DataFrame.from_dict(pretty_cells, orient="index", columns=rels)
    pretty_df.index.name = "embedding"

    # Save the pretty table (CSV + Markdown)
    pretty_csv = os.path.join(REPORT_DIR, "table_pretty_metrics.csv")
    pretty_md  = os.path.join(REPORT_DIR, "table_pretty_metrics.md")
    pretty_df.to_csv(pretty_csv)
    with open(pretty_md, "w", encoding="utf-8") as f:
        f.write(pretty_df.to_markdown())

    # 2) Separate numeric matrices (k*, EVR, leakage) with the same layout
    def _matrix_from(metric_dict, fill_value=np.nan):
        mat = {emb: {} for emb in embs}
        for emb in embs:
            for rel in rels:
                mat[emb][rel] = metric_dict.get((emb, rel), fill_value)
        return pd.DataFrame.from_dict(mat, orient="index", columns=rels)

    kstar_df = _matrix_from(KSTAR, fill_value=np.nan)
    evr_df   = _matrix_from(EVR_AT_K, fill_value=np.nan)
    leak_df  = _matrix_from(LEAK_AT_K, fill_value=np.nan)

    kstar_df.index.name = evr_df.index.name = leak_df.index.name = "embedding"

    # Save numeric matrices
    kstar_csv = os.path.join(REPORT_DIR, "matrix_kstar.csv")
    evr_csv   = os.path.join(REPORT_DIR, "matrix_evr_at_k.csv")
    leak_csv  = os.path.join(REPORT_DIR, "matrix_leakage_at_k.csv")
    kstar_df.to_csv(kstar_csv)
    evr_df.to_csv(evr_csv)
    leak_df.to_csv(leak_csv)

    # Console printout (compact)
    print("\n=== Relation-PCA Key Metrics (rows = embedding, cols = relation) ===")
    print(pretty_df.to_string(max_cols=120, max_rows=200, justify="center"))
    print(f"\n[INFO] Saved tables to:\n- {pretty_csv}\n- {pretty_md}\n- {kstar_csv}\n- {evr_csv}\n- {leak_csv}")
else:
    print("[WARN] No results available to tabulate.")


[w2v] Dropping 3 pairs due to OOV/zero vectors.
[OK] w2v      | KARTA                        | k*=43 | EVR=0.951 | leak=0.049 | n=57
[OK] bert     | KARTA                        | k*=44 | EVR=0.953 | leak=0.047 | n=60
[OK] roberta  | KARTA                        | k*=43 | EVR=0.952 | leak=0.048 | n=60
[OK] labse    | KARTA                        | k*=45 | EVR=0.953 | leak=0.047 | n=60
[OK] w2v      | KARMA                        | k*=44 | EVR=0.952 | leak=0.048 | n=60
[OK] bert     | KARMA                        | k*=42 | EVR=0.955 | leak=0.045 | n=60
[OK] roberta  | KARMA                        | k*=42 | EVR=0.955 | leak=0.045 | n=60
[OK] labse    | KARMA                        | k*=45 | EVR=0.954 | leak=0.046 | n=60
[w2v] Dropping 2 pairs due to OOV/zero vectors.
[OK] w2v      | KARANA                       | k*=44 | EVR=0.953 | leak=0.047 | n=58
[OK] bert     | KARANA                       | k*=43 | EVR=0.951 | leak=0.049 | n=60
[OK] roberta  | KARANA                       | k*=41 |

In [16]:
# =========================
# EVR Convergence Plot (Fig. 1)
# =========================
# Produces a single figure with one curve per relation showing cumulative EVR vs k,
# plus an optional mean curve with IQR shading across relations.

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Choose which embedding’s spectra to plot (must match files written above)
SELECT_EMB = "bert"   # options: "bert", "roberta", "labse", "w2v", "glove" (if available)
K_PLOT = min(64, K_MAX)   # plot up to K_MAX components
OUT_BASENAME = f"evr_convergence_{SELECT_EMB}"

# Shorten labels a bit for the legend
def _short_rel(name: str) -> str:
    return (name
            .replace("_country_capital", "")
            .replace("_", " ")
            .title())

curves = []
labels = []
x_grid = None

for rel in RELATION_SETS.keys():
    spec_csv = os.path.join(REPORT_DIR, f"spectrum_{SELECT_EMB}_{rel}.csv")
    if not os.path.exists(spec_csv):
        # If spectrum file is missing (e.g. run failed for some emb-rel), skip gracefully
        print(f"[SKIP] Missing spectrum for {SELECT_EMB} | {rel}: {spec_csv}")
        continue

    df = pd.read_csv(spec_csv)
    # Ensure we have the expected columns
    if "component" not in df.columns or "cumulative_evr" not in df.columns:
        print(f"[SKIP] Malformed spectrum file: {spec_csv}")
        continue

    # Truncate to K_PLOT points
    comp = df["component"].values[:K_PLOT]
    cev  = df["cumulative_evr"].values[:K_PLOT]

    # Align x_grid across relations
    if x_grid is None:
        x_grid = comp
    else:
        # If component counts differ slightly, align by min length
        m = min(len(x_grid), len(comp))
        x_grid = x_grid[:m]
        cev = cev[:m]

    curves.append(cev)
    labels.append(_short_rel(rel))

if not curves:
    print("[WARN] No curves available to plot. Make sure spectra were generated earlier.")
else:
    # Stack for summary stats across relations
    C = np.vstack(curves)  # shape: (n_relations, k_points)
    mean_curve = C.mean(axis=0)
    q25 = np.percentile(C, 25, axis=0)
    q75 = np.percentile(C, 75, axis=0)

    # --- Plot ---
    plt.figure(figsize=(8.0, 4.5))  # single-column friendly; tweak if you’ll use figure*
    # individual relation curves
    for cev, lab in zip(curves, labels):
        plt.plot(x_grid, cev, linewidth=1.0, alpha=0.8, label=lab)

    # mean + IQR band (optional but helpful)
    plt.fill_between(x_grid, q25, q75, alpha=0.15, linewidth=0)
    plt.plot(x_grid, mean_curve, linewidth=2.0, alpha=1.0)

    # target threshold
    plt.axhline(VAR_TARGET, linestyle="--", linewidth=1.0)

    plt.xlabel("k (components)")
    plt.ylabel("Cumulative EVR")
    plt.ylim(0.0, 1.0)
    plt.xlim(x_grid[0], x_grid[-1])
    plt.grid(True, alpha=0.3)

    # Legend outside to avoid clutter
    plt.legend(ncol=3, fontsize=7, frameon=False, bbox_to_anchor=(1.02, 1.0), loc="upper left")
    plt.tight_layout()

    out_png = os.path.join(REPORT_DIR, OUT_BASENAME + ".png")
    out_pdf = os.path.join(REPORT_DIR, OUT_BASENAME + ".pdf")
    plt.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.savefig(out_pdf, bbox_inches="tight")
    plt.close()

    print(f"[INFO] Saved EVR plot to:\n  {out_png}\n  {out_pdf}")


ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 60 and the array at index 10 has size 56

In [None]:
# ============================================
# MCQ-style evaluation via Relation Subspaces
# ============================================
# Goal:
#   For each (relation set, embedding), use the learned PCA subspace (basis_{emb}_{rel}_topk.npy)
#   to answer MCQ of the form: given u, pick v that minimizes orthogonal residual
#       r = e(v) - e(u)
#       residual_k(r) = || r - P_k r ||_2,  where P_k = V_k^T V_k (rows of V_k are orthonormal PCs)
#   We compute accuracy for k=1..k*, report the smallest k after which accuracy is always 100%,
#   and write per-dimension results to disk. Finally, aggregate key metrics in tables.

import os, json
from typing import Dict, List, Tuple

EVAL_DIR = os.path.join(REPORT_DIR, "mcq_eval")
os.makedirs(EVAL_DIR, exist_ok=True)

def _load_basis_and_meta(emb_kind: str, rel_name: str):
    """Load top-k basis (k* x d) and summary meta (including kept_indices)."""
    basis_path = os.path.join(REPORT_DIR, f"basis_{emb_kind}_{rel_name}_topk.npy")
    sum_path   = os.path.join(REPORT_DIR, f"summary_{emb_kind}_{rel_name}.json")
    if not os.path.exists(basis_path):
        raise FileNotFoundError(f"Missing basis file: {basis_path}")
    if not os.path.exists(sum_path):
        raise FileNotFoundError(f"Missing summary JSON: {sum_path}")
    basis = np.load(basis_path)
    with open(sum_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    kept_idx = meta.get("kept_indices", list(range(len(RELATION_SETS[rel_name]))))
    return basis, meta, kept_idx

def _projection_residual_norms(R: np.ndarray, V_k: np.ndarray) -> np.ndarray:
    """
    R: (..., d)
    V_k: (k, d) with orthonormal rows (from SVD V^T)
    Returns: residual L2 norm for each vector in R, shape R.shape[:-1]
    """
    # Project r onto span(V_k): P r = V_k^T (V_k r)
    # Do batched operations: proj_coeffs = R @ V_k.T    [shape (..., k)]
    proj_coeffs = R @ V_k.T
    recon = proj_coeffs @ V_k     # (..., d)
    resid = R - recon
    return np.linalg.norm(resid, axis=-1)

def _accuracy_vs_k(pairs: List[Tuple[str, str]],
                   kept_idx: List[int],
                   emb_kind: str,
                   basis_topk: np.ndarray,
                   normalize_rows: bool = NORMALIZE_ROWS) -> Dict[str, object]:
    """
    Compute accuracy for k=1..k* using full candidate set from the SAME relation list.
    Only evaluate over kept_idx (consistent with PCA training).
    """
    # Slice pairs to kept ones
    pairs_kept = [pairs[i] for i in kept_idx]
    if len(pairs_kept) < 2:
        raise ValueError("Not enough valid pairs for evaluation.")

    # Build candidate lists
    U = [u for (u, v) in pairs_kept]
    V = [v for (u, v) in pairs_kept]  # candidate pool

    # Embed once
    EU = embed_words(U, emb_kind)
    EV_all = embed_words(V, emb_kind)

    # Optional normalize BEFORE differencing
    if normalize_rows:
        EU = l2_normalize_rows(EU)
        EV_all = l2_normalize_rows(EV_all)

    n = len(pairs_kept)
    d = EU.shape[1]
    k_star = basis_topk.shape[0]

    # Precompute for each i the difference r(i,j) for all candidate j
    # R_ij shape: (n, n, d) where R_ij[i, j, :] = e(v_j) - e(u_i)
    # This is small (<= 60x60xd), fine.
    # Expand dims for broadcasting
    R_ij = EV_all[None, :, :] - EU[:, None, :]

    # ground-truth index mapping: for pair i, the correct v is V[i]
    # Accuracy per k (1..k*)
    acc_per_k = []
    per_item_records = []  # detailed rows to write out

    for k in range(1, k_star + 1):
        V_k = basis_topk[:k, :]  # (k, d)

        # Residual norms for all options
        # shape (n, n)
        resid = _projection_residual_norms(R_ij.reshape(-1, d), V_k).reshape(n, n)

        # For each i, choose argmin_j of residual[i, j]
        pred_idx = np.argmin(resid, axis=1)
        correct_idx = np.arange(n)
        correct_flags = (pred_idx == correct_idx)
        acc = float(np.mean(correct_flags))

        acc_per_k.append({"k": k, "accuracy": acc})

        # Save detailed per-item rows for this k (OPTIONAL: compact recording)
        for i in range(n):
            per_item_records.append({
                "k": k,
                "i": i,
                "u": U[i],
                "true_v": V[i],
                "pred_v": V[pred_idx[i]],
                "correct": bool(correct_flags[i]),
                "resid_true": float(resid[i, correct_idx[i]]),
                "resid_pred": float(resid[i, pred_idx[i]]),
                # A small rank statistic: where is the true_v among all candidates by residual?
                "true_rank": int(np.argsort(resid[i, :]).tolist().index(correct_idx[i]) + 1),
            })

    # Find smallest k after which accuracy is ALWAYS 1.0 up to k*
    k_always_100 = None
    for idx, row in enumerate(acc_per_k):
        if row["accuracy"] >= 0.999999:  # robust to float
            # check tail
            tail_ok = all(r["accuracy"] >= 0.999999 for r in acc_per_k[idx:])
            if tail_ok:
                k_always_100 = row["k"]
                break

    return {
        "k_star": k_star,
        "acc_per_k": acc_per_k,
        "k_always_100": k_always_100,
        "per_item_records": per_item_records,
        "n_eval": n,
    }

In [None]:
# ---------------------------------------------------
# Run MCQ evaluation across all embeddings & relations
# ---------------------------------------------------
EVAL_OVERVIEW = []  # one row per (emb, rel)
ACC_CURVES: Dict[Tuple[str, str], List[Dict[str, float]]] = {}

for rel_name, pairs in RELATION_SETS.items():
    for emb_kind in EMBEDDING_KINDS:
        try:
            basis_topk, meta, kept_idx = _load_basis_and_meta(emb_kind, rel_name)
            eval_res = _accuracy_vs_k(
                pairs=pairs,
                kept_idx=kept_idx,
                emb_kind=emb_kind,
                basis_topk=basis_topk,
                normalize_rows=NORMALIZE_ROWS
            )
            ACC_CURVES[(emb_kind, rel_name)] = eval_res["acc_per_k"]

            # Persist detailed iterations
            curve_path = os.path.join(EVAL_DIR, f"acc_curve_{emb_kind}_{rel_name}.csv")
            with open(curve_path, "w", encoding="utf-8") as f:
                f.write("k,accuracy\n")
                for row in eval_res["acc_per_k"]:
                    f.write(f"{row['k']},{row['accuracy']:.6f}\n")

            detail_path = os.path.join(EVAL_DIR, f"detailed_{emb_kind}_{rel_name}.csv")
            with open(detail_path, "w", encoding="utf-8") as f:
                f.write("k,i,u,true_v,pred_v,correct,resid_true,resid_pred,true_rank\n")
                for r in eval_res["per_item_records"]:
                    f.write(f"{r['k']},{r['i']},{r['u']},{r['true_v']},{r['pred_v']},{int(r['correct'])},{r['resid_true']:.6f},{r['resid_pred']:.6f},{r['true_rank']}\n")

            # Aggregate row
            acc_at_kstar = next((row["accuracy"] for row in eval_res["acc_per_k"] if row["k"] == eval_res["k_star"]), None)
            EVAL_OVERVIEW.append({
                "embedding": emb_kind,
                "relation": rel_name,
                "n_eval": eval_res["n_eval"],
                "k_star": eval_res["k_star"],
                "k_always_100": eval_res["k_always_100"],
                "acc_at_k_star": acc_at_kstar,
            })

            print(f"[QAEVAL OK] {emb_kind:8s} | {rel_name:28s} | n={eval_res['n_eval']:2d} | k*={eval_res['k_star']:2d} | k_100={eval_res['k_always_100']} | acc@k*={acc_at_kstar:.3f}")

        except Exception as e:
            print(f"[QAEVAL FAIL] {emb_kind} | {rel_name}: {e}")

[QAEVAL OK] w2v      | KARTA                        | n=57 | k*=43 | k_100=None | acc@k*=0.807
[QAEVAL OK] bert     | KARTA                        | n=60 | k*=44 | k_100=None | acc@k*=0.900
[QAEVAL OK] roberta  | KARTA                        | n=60 | k*=43 | k_100=None | acc@k*=0.817
[QAEVAL OK] labse    | KARTA                        | n=60 | k*=45 | k_100=None | acc@k*=0.867
[QAEVAL OK] w2v      | KARMA                        | n=60 | k*=44 | k_100=None | acc@k*=0.767
[QAEVAL OK] bert     | KARMA                        | n=60 | k*=42 | k_100=None | acc@k*=0.750
[QAEVAL OK] roberta  | KARMA                        | n=60 | k*=42 | k_100=None | acc@k*=0.650
[QAEVAL OK] labse    | KARMA                        | n=60 | k*=45 | k_100=None | acc@k*=0.783
[QAEVAL OK] w2v      | KARANA                       | n=58 | k*=44 | k_100=None | acc@k*=0.897
[QAEVAL OK] bert     | KARANA                       | n=60 | k*=43 | k_100=None | acc@k*=0.917
[QAEVAL OK] roberta  | KARANA                     

In [None]:
# ============================================
# Why do some relations hit 100% (and earlier than k*)?
# Quantitative diagnostics & correlations
# ============================================

from itertools import combinations

DIAG_DIR = os.path.join(REPORT_DIR, "diagnostics")
os.makedirs(DIAG_DIR, exist_ok=True)

def _load_summary(emb_kind: str, rel_name: str):
    sum_path = os.path.join(REPORT_DIR, f"summary_{emb_kind}_{rel_name}.json")
    with open(sum_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    return meta

def _relation_deltas_and_entities(pairs: List[Tuple[str, str]],
                                  kept_idx: List[int],
                                  emb_kind: str,
                                  normalize_rows: bool = NORMALIZE_ROWS):
    """Returns (R, Rc, EU, EV) with R = EV - EU (optionally row-normalized beforehand)."""
    pairs_kept = [pairs[i] for i in kept_idx]
    U = [u for (u, v) in pairs_kept]
    V = [v for (u, v) in pairs_kept]
    EU = embed_words(U, emb_kind)
    EV = embed_words(V, emb_kind)
    if normalize_rows:
        EU = l2_normalize_rows(EU)
        EV = l2_normalize_rows(EV)
    R = EV - EU
    Rc = safe_center(R)
    return R, Rc, EU, EV, U, V

def _pairwise_cosine_stats(X: np.ndarray):
    """Median & IQR of pairwise cosine similarity over rows of X (slow O(n^2) but n<=60)."""
    Xn = l2_normalize_rows(X)
    n = Xn.shape[0]
    if n < 2:
        return np.nan, np.nan
    cosines = []
    for i in range(n):
        for j in range(i+1, n):
            cosines.append(float(np.dot(Xn[i], Xn[j])))
    if not cosines:
        return np.nan, np.nan
    arr = np.asarray(cosines)
    return float(np.median(arr)), float(np.percentile(arr, 75) - np.percentile(arr, 25))

def _centroid_alignment_stats(R: np.ndarray):
    """Cosine alignment to uncentered delta centroid; median & IQR."""
    if R.shape[0] < 2:
        return np.nan, np.nan
    c = R.mean(axis=0)
    c_norm = np.linalg.norm(c)
    if c_norm < 1e-12:
        return np.nan, np.nan
    c_unit = c / c_norm
    proj = (l2_normalize_rows(R) @ c_unit)
    return float(np.median(proj)), float(np.percentile(proj, 75) - np.percentile(proj, 25))

def _spectral_metrics(meta_summary: Dict):
    """Steepness/aniso of EVR: top-1 EVR, EVR@3, spectral entropy, condition @ k*."""
    spectrum = meta_summary["spectrum"]
    evr = np.array([row["explained_variance_ratio"] for row in spectrum], dtype=np.float32)
    lam = np.array([row["eigenvalue"] for row in spectrum], dtype=np.float32)
    evr_top1 = float(evr[0]) if evr.size else np.nan
    evr_top3 = float(np.sum(evr[:min(3, len(evr))])) if evr.size else np.nan
    # spectral entropy (lower = spikier)
    p = evr + 1e-12
    p /= p.sum()
    spec_entropy = float(-np.sum(p * np.log(p)))
    # condition number (λ1 / λ_k*) within the learned subspace
    k_star = meta_summary["k_star"]
    if k_star >= 1 and lam.size >= k_star:
        cond_k = float((lam[0] + 1e-12) / (lam[k_star-1] + 1e-12))
    else:
        cond_k = np.nan
    return {
        "evr_top1": evr_top1,
        "evr_top3": evr_top3,
        "spectral_entropy": spec_entropy,
        "condnum_within_kstar": cond_k,
    }

def _type_heterogeneity(EU: np.ndarray, EV: np.ndarray):
    """
    Simple proxy: spread of heads vs tails separately (median pairwise cosine).
    Lower median cosine (or higher IQR) => more heterogeneity.
    """
    med_u, iqr_u = _pairwise_cosine_stats(EU)
    med_v, iqr_v = _pairwise_cosine_stats(EV)
    return {
        "head_pairwise_cos_median": med_u,
        "head_pairwise_cos_iqr": iqr_u,
        "tail_pairwise_cos_median": med_v,
        "tail_pairwise_cos_iqr": iqr_v,
    }

def _ambiguity_and_margins(pairs: List[Tuple[str, str]],
                           kept_idx: List[int],
                           emb_kind: str,
                           basis_topk: np.ndarray,
                           k_probe_list: List[int]):
    """
    For each k in k_probe_list, compute per-question margin:
        margin = (second_best_residual - best_residual) at that k
    Also compute accuracy, median margin, and fraction of 'easy' questions (large margin).
    """
    pairs_kept = [pairs[i] for i in kept_idx]
    U = [u for (u, v) in pairs_kept]
    V = [v for (u, v) in pairs_kept]
    EU = embed_words(U, emb_kind)
    EV_all = embed_words(V, emb_kind)
    if NORMALIZE_ROWS:
        EU = l2_normalize_rows(EU)
        EV_all = l2_normalize_rows(EV_all)
    n, d = EU.shape
    R_ij = EV_all[None, :, :] - EU[:, None, :]  # (n, n, d)

    out = []
    for k in k_probe_list:
        V_k = basis_topk[:k, :]
        resid = _projection_residual_norms(R_ij.reshape(-1, d), V_k).reshape(n, n)
        order = np.argsort(resid, axis=1)
        best = resid[np.arange(n), order[:, 0]]
        second = resid[np.arange(n), order[:, 1]]
        margin = second - best
        pred = order[:, 0]
        acc = float(np.mean(pred == np.arange(n)))
        out.append({
            "k": int(k),
            "acc": acc,
            "margin_median": float(np.median(margin)),
            "margin_iqr": float(np.percentile(margin, 75) - np.percentile(margin, 25)),
            "frac_easy_margin>0.1": float(np.mean(margin > 0.1)),
            "frac_easy_margin>0.05": float(np.mean(margin > 0.05)),
        })
    return out

def _gini_from_evr(meta_summary: Dict):
    """Gini coefficient over EVR as another 'spikiness' indicator."""
    evr = np.array([row["explained_variance_ratio"] for row in meta_summary["spectrum"]], dtype=np.float64)
    if evr.size == 0:
        return np.nan
    # Gini over a probability vector p:
    p = evr / evr.sum()
    # sort
    ps = np.sort(p)
    n = len(ps)
    cum = np.cumsum(ps)
    # Gini for distributions: 1 - 2 * sum_i (cum_i) / n
    gini = 1.0 - 2.0 * np.sum(cum) / n + 1.0 / n
    return float(max(0.0, gini))

# -------------------------
# Compute diagnostics
# -------------------------
DIAG_ROWS = []
MARGIN_ROWS = []  # per (emb, rel, k) margins/acc summary

for rel_name, pairs in RELATION_SETS.items():
    for emb_kind in EMBEDDING_KINDS:
        try:
            basis_topk, meta, kept_idx = _load_basis_and_meta(emb_kind, rel_name)
            R, Rc, EU, EV, U, V = _relation_deltas_and_entities(pairs, kept_idx, emb_kind, normalize_rows=NORMALIZE_ROWS)

            # Tightness & concentration
            pair_med_cos, pair_iqr_cos = _pairwise_cosine_stats(R)
            cent_med, cent_iqr = _centroid_alignment_stats(R)

            # Spectral properties from earlier PCA summary
            spec = _spectral_metrics(meta)
            gini = _gini_from_evr(meta)

            # Subspace energy at small k
            lam = np.array([row["eigenvalue"] for row in meta["spectrum"]], dtype=np.float64)
            lam_sum = lam.sum() + 1e-12
            evr_k1 = float((lam[0] if lam.size else 0.0) / lam_sum)
            evr_k3 = float(np.sum(lam[:min(3, len(lam))]) / lam_sum)
            evr_k5 = float(np.sum(lam[:min(5, len(lam))]) / lam_sum)

            # Type heterogeneity (heads/tails)
            type_het = _type_heterogeneity(EU, EV)

            # Ambiguity margins at small k and at k*
            k_star = int(meta["k_star"])
            k_probe = list({1, 2, 3, min(5, k_star), k_star})
            k_probe = sorted(set([k for k in k_probe if 1 <= k <= basis_topk.shape[0]]))
            margin_summary = _ambiguity_and_margins(pairs, kept_idx, emb_kind, basis_topk, k_probe)
            for ms in margin_summary:
                row_ms = {
                    "embedding": emb_kind,
                    "relation": rel_name,
                    **ms
                }
                MARGIN_ROWS.append(row_ms)

            # From earlier MCQ eval
            acc_curve = ACC_CURVES.get((emb_kind, rel_name), [])
            acc_k1 = next((r["accuracy"] for r in acc_curve if r["k"] == 1), np.nan)
            acc_k3 = next((r["accuracy"] for r in acc_curve if r["k"] == 3), np.nan)
            acc_k5 = next((r["accuracy"] for r in acc_curve if r["k"] == 5), np.nan)
            acc_kstar = next((r["accuracy"] for r in acc_curve if r["k"] == k_star), np.nan)
            k_always_100 = next((row["k_always_100"] for row in EVAL_OVERVIEW
                                 if row["embedding"] == emb_kind and row["relation"] == rel_name), None)

            # OOV rate proxy: fraction dropped vs original n
            n_orig = len(RELATION_SETS[rel_name])
            n_kept = len(kept_idx)
            oov_frac = float(1.0 - n_kept / max(1, n_orig))

            DIAG_ROWS.append({
                "embedding": emb_kind,
                "relation": rel_name,
                "n_used": n_kept,
                "k_star": k_star,
                "k_always_100": k_always_100,
                "acc_k1": acc_k1,
                "acc_k3": acc_k3,
                "acc_k5": acc_k5,
                "acc_kstar": acc_kstar,

                # tightness & centroid alignment
                "pairwise_cos_median": pair_med_cos,
                "pairwise_cos_iqr": pair_iqr_cos,
                "centroid_cos_median": cent_med,
                "centroid_cos_iqr": cent_iqr,

                # spectral / spikiness / anisotropy
                "evr_top1": spec["evr_top1"],
                "evr_top3": spec["evr_top3"],
                "spectral_entropy": spec["spectral_entropy"],
                "gini_evr": gini,
                "condnum_within_kstar": spec["condnum_within_kstar"],
                "evr_k1": evr_k1,
                "evr_k3": evr_k3,
                "evr_k5": evr_k5,

                # type heterogeneity
                **type_het,

                # OOV proxy
                "oov_frac": oov_frac,
            })
        except Exception as e:
            print(f"[DIAG FAIL] {emb_kind} | {rel_name}: {e}")

# Save diagnostics
if DIAG_ROWS:
    diag_df = pd.DataFrame(DIAG_ROWS).sort_values(["embedding", "relation"]).reset_index(drop=True)
    diag_csv = os.path.join(DIAG_DIR, "relation_diagnostics.csv")
    diag_df.to_csv(diag_csv, index=False)

    marg_df = pd.DataFrame(MARGIN_ROWS).sort_values(["embedding", "relation", "k"]).reset_index(drop=True)
    marg_csv = os.path.join(DIAG_DIR, "margin_summaries.csv")
    marg_df.to_csv(marg_csv, index=False)

    print(f"[INFO] Wrote diagnostics:\n  {diag_csv}\n  {marg_csv}")

# ---------------------------------------------
# Aggregate table: what predicts early perfection?
# ---------------------------------------------
# Hypothesis: higher tightness (centroid_cos_median, pairwise_cos_median), spikier spectra
# (evr_top1, evr_k3/gini), lower head/tail heterogeneity -> earlier k_always_100 and higher acc@small k.

def _pearson(x, y):
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    mask = np.isfinite(x) & np.isfinite(y)
    if mask.sum() < 3:
        return np.nan
    x = x[mask]; y = y[mask]
    x = (x - x.mean()) / (x.std() + 1e-12)
    y = (y - y.mean()) / (y.std() + 1e-12)
    return float(np.mean(x * y))

if DIAG_ROWS:
    df = pd.DataFrame(DIAG_ROWS)

    # Target outcomes:
    # - smaller k_always_100 (fill None as NaN)
    # - acc_k1 and acc_k3
    df["k_always_100_num"] = pd.to_numeric(df["k_always_100"], errors="coerce")

    predictors = [
        "pairwise_cos_median",
        "centroid_cos_median",
        "pairwise_cos_iqr",
        "centroid_cos_iqr",
        "evr_top1",
        "evr_top3",
        "evr_k1",
        "evr_k3",
        "gini_evr",
        "spectral_entropy",
        "condnum_within_kstar",
        "head_pairwise_cos_median",
        "head_pairwise_cos_iqr",
        "tail_pairwise_cos_median",
        "tail_pairwise_cos_iqr",
        "oov_frac",
    ]
    outcomes = ["k_always_100_num", "acc_k1", "acc_k3"]

    corr_rows = []
    for pred in predictors:
        for out in outcomes:
            r = _pearson(df[pred].values, df[out].values)
            corr_rows.append({"predictor": pred, "outcome": out, "pearson_r": r})

    corr_df = pd.DataFrame(corr_rows).pivot(index="predictor", columns="outcome", values="pearson_r").sort_values(by="k_always_100_num", ascending=True)
    corr_csv = os.path.join(DIAG_DIR, "predictor_outcome_correlations.csv")
    corr_df.to_csv(corr_csv)

    # Human-readable summary table (by embedding rows)
    show_cols = [
        "relation", "k_star", "k_always_100", "acc_k1", "acc_k3", "acc_kstar",
        "pairwise_cos_median", "centroid_cos_median",
        "evr_top1", "evr_top3", "gini_evr", "spectral_entropy",
        "head_pairwise_cos_median", "tail_pairwise_cos_median",
        "oov_frac",
    ]
    summaries = []
    for emb in EMBEDDING_KINDS:
        sub = df[df["embedding"] == emb][["embedding"] + show_cols].copy()
        summaries.append(sub)
        # Save per-embedding table
        sub_csv = os.path.join(DIAG_DIR, f"summary_why_100_{emb}.csv")
        sub.to_csv(sub_csv, index=False)

    big_summary = pd.concat(summaries, axis=0).reset_index(drop=True)
    big_summary_csv = os.path.join(DIAG_DIR, "summary_why_100_all.csv")
    big_summary.to_csv(big_summary_csv, index=False)

    # Console prints
    print("\n=== Aggregate diagnostics suggesting why some relations perfect out early ===")
    # Display one compact table per embedding
    for emb in EMBEDDING_KINDS:
        sub = big_summary[big_summary["embedding"] == emb].drop(columns=["embedding"]).set_index("relation")
        # Keep only a subset for console brevity
        print(f"\n[Embedding: {emb}]")
        print(sub[["k_star","k_always_100","acc_k1","acc_k3","acc_kstar",
                  "pairwise_cos_median","centroid_cos_median","evr_top1","evr_top3","gini_evr"]]
              .round(3)
              .to_string(max_cols=120, max_rows=200, justify="center"))

    print(f"\n[INFO] Correlations written to: {corr_csv}")

# -------------------------------------------------------
# (Optional) Alternative scorer sanity check (whitened)
# -------------------------------------------------------
# In anisotropic subspaces, whitening can help earlier k. We compute accuracy at k in {1,3,5,k*}
# using Mahalanobis-like residuals inside the subspace (Λ^{-1/2} scaling).

def _whitened_residuals(R: np.ndarray, V_k: np.ndarray, lam: np.ndarray, k: int):
    """
    R: (..., d); V_k: (k, d) row-orthonormal; lam: eigenvalues of covariance (descending).
    Return residual computed as sqrt( || r_orth ||^2 + || Λ^{-1/2} (V_k r) ||^2 )  [so small values prefer directions with high variance]
    """
    # components in subspace:
    coeffs = R @ V_k.T                     # (..., k)
    lamk = lam[:k].copy()
    lamk = np.maximum(lamk, 1e-9)
    coeffs_white = coeffs / np.sqrt(lamk)  # whitening
    recon_white_norm2 = np.sum(coeffs_white**2, axis=-1)  # || Λ^{-1/2} V^T r ||^2

    # orthogonal component:
    recon = coeffs @ V_k                   # (..., d)
    orth = R - recon
    orth_norm2 = np.sum(orth**2, axis=-1)

    # combined residual (Mahalanobis-like)
    return np.sqrt(orth_norm2 + recon_white_norm2)

ALT_EVAL_DIR = os.path.join(EVAL_DIR, "alt_scorer")
os.makedirs(ALT_EVAL_DIR, exist_ok=True)

ALT_OVERVIEW = []

for rel_name, pairs in RELATION_SETS.items():
    for emb_kind in EMBEDDING_KINDS:
        try:
            basis_topk, meta, kept_idx = _load_basis_and_meta(emb_kind, rel_name)
            # eigenvalues from summary (descending)
            lam = np.array([row["eigenvalue"] for row in meta["spectrum"]], dtype=np.float64)
            k_star = int(meta["k_star"])

            # Prepare candidates
            pairs_kept = [pairs[i] for i in kept_idx]
            U = [u for (u, v) in pairs_kept]
            V = [v for (u, v) in pairs_kept]
            EU = embed_words(U, emb_kind)
            EV_all = embed_words(V, emb_kind)
            if NORMALIZE_ROWS:
                EU = l2_normalize_rows(EU)
                EV_all = l2_normalize_rows(EV_all)
            n, d = EU.shape
            R_ij = EV_all[None, :, :] - EU[:, None, :]
            k_probe = [1, 3, min(5, k_star), k_star]
            k_probe = sorted(set([k for k in k_probe if 1 <= k <= basis_topk.shape[0]]))

            # Evaluate
            rows = []
            for k in k_probe:
                V_k = basis_topk[:k, :]
                resid = _whitened_residuals(R_ij.reshape(-1, d), V_k, lam, k).reshape(n, n)
                pred = np.argmin(resid, axis=1)
                acc = float(np.mean(pred == np.arange(n)))
                rows.append({"k": k, "acc_whitened": acc})

            # Persist and aggregate
            alt_csv = os.path.join(ALT_EVAL_DIR, f"whitened_{emb_kind}_{rel_name}.csv")
            with open(alt_csv, "w", encoding="utf-8") as f:
                f.write("k,acc_whitened\n")
                for r in rows:
                    f.write(f"{r['k']},{r['acc_whitened']:.6f}\n")

            # Compare to original acc at same ks
            base_curve = {r["k"]: r["accuracy"] for r in ACC_CURVES.get((emb_kind, rel_name), [])}
            ALT_OVERVIEW.append({
                "embedding": emb_kind,
                "relation": rel_name,
                **{f"acc_base_k{r['k']}": float(base_curve.get(r["k"], np.nan)) for r in rows},
                **{f"acc_white_k{r['k']}": r["acc_whitened"] for r in rows},
            })
        except Exception as e:
            print(f"[ALT FAIL] {emb_kind} | {rel_name}: {e}")

if ALT_OVERVIEW:
    alt_df = pd.DataFrame(ALT_OVERVIEW).sort_values(["embedding", "relation"]).reset_index(drop=True)
    alt_csv = os.path.join(ALT_EVAL_DIR, "whitened_vs_base_comparison.csv")
    alt_df.to_csv(alt_csv, index=False)
    print(f"[INFO] Whitened scorer comparison written to: {alt_csv}")

# -------------------------------------------------------
# Final compact table: key predictors vs outcomes
# -------------------------------------------------------
if DIAG_ROWS:
    df = pd.DataFrame(DIAG_ROWS)
    display_cols = [
        "embedding","relation","n_used","k_star","k_always_100","acc_k1","acc_k3","acc_kstar",
        "centroid_cos_median","pairwise_cos_median","evr_top1","evr_top3","gini_evr","spectral_entropy",
        "head_pairwise_cos_median","tail_pairwise_cos_median","oov_frac"
    ]
    final_table = df[display_cols].copy().sort_values(["embedding","relation"]).reset_index(drop=True)
    final_csv = os.path.join(DIAG_DIR, "final_predictors_outcomes_table.csv")
    final_table.to_csv(final_csv, index=False)

    print("\n=== Final predictors vs outcomes (rows = embedding, per relation) ===")
    for emb in EMBEDDING_KINDS:
        sub = final_table[final_table["embedding"] == emb].drop(columns=["embedding"]).set_index("relation")
        print(f"\n[Embedding: {emb}]")
        print(sub.round(3).to_string(max_cols=120, max_rows=200, justify="center"))
    print(f"\n[INFO] Saved final table to: {final_csv}")

# ----------------------------------------------------------------------------------------
# Notes:
# - Metrics like high 'centroid_cos_median', high 'pairwise_cos_median', high 'evr_top1/3',
#   high 'gini_evr' (spiky spectrum) typically correlate with smaller k_always_100
#   and higher accuracy at low k (acc_k1, acc_k3).
# - Large head/tail heterogeneity (low medians / high IQRs) and high spectral_entropy
#   tend to push k_always_100 upward and depress early-k accuracy.
# - The whitened scorer provides a sanity check for anisotropy; improvements at small k
#   suggest within-subspace variance is highly anisotropic.
# ----------------------------------------------------------------------------------------


[INFO] Wrote diagnostics:
  ./relation_pca_reports/diagnostics/relation_diagnostics.csv
  ./relation_pca_reports/diagnostics/margin_summaries.csv

=== Aggregate diagnostics suggesting why some relations perfect out early ===

[Embedding: w2v]
                           k_star  k_always_100  acc_k1  acc_k3  acc_kstar  pairwise_cos_median  centroid_cos_median  evr_top1  evr_top3  gini_evr
relation                                                                                                                                          
KARTA                        43         NaN       0.035   0.070    0.807           0.270                0.530           0.097     0.240     0.509 
KARMA                        44         NaN       0.050   0.033    0.767           0.178                0.482           0.106     0.235     0.513 
KARANA                       44         NaN       0.224   0.259    0.897           0.045                0.286           0.106     0.224     0.494 
SAMPRADANA            

In [None]:
# =============================
# Country → Currency (60 pairs)
# =============================
COUNTRY_CURRENCY_60 = [
    ("united_states", "us_dollar"),
    ("canada", "canadian_dollar"),
    ("australia", "australian_dollar"),
    ("new_zealand", "new_zealand_dollar"),
    ("singapore", "singapore_dollar"),
    ("hong_kong", "hong_kong_dollar"),
    ("taiwan", "new_taiwan_dollar"),
    ("jamaica", "jamaican_dollar"),
    ("guyana", "guyanese_dollar"),
    ("namibia", "namibian_dollar"),

    ("mexico", "mexican_peso"),
    ("argentina", "argentine_peso"),
    ("chile", "chilean_peso"),
    ("colombia", "colombian_peso"),
    ("philippines", "philippine_peso"),

    ("india", "indian_rupee"),
    ("pakistan", "pakistani_rupee"),
    ("sri_lanka", "sri_lankan_rupee"),
    ("nepal", "nepalese_rupee"),
    ("mauritius", "mauritian_rupee"),

    ("indonesia", "indonesian_rupiah"),
    ("japan", "japanese_yen"),
    ("china", "chinese_yuan"),
    ("south_korea", "south_korean_won"),
    ("vietnam", "vietnamese_dong"),
    ("thailand", "thai_baht"),
    ("malaysia", "malaysian_ringgit"),

    ("saudi_arabia", "saudi_riyal"),
    ("qatar", "qatari_riyal"),
    ("united_arab_emirates", "uae_dirham"),
    ("morocco", "moroccan_dirham"),

    ("iraq", "iraqi_dinar"),
    ("jordan", "jordanian_dinar"),
    ("algeria", "algerian_dinar"),
    ("tunisia", "tunisian_dinar"),
    ("bahrain", "bahraini_dinar"),
    ("kuwait", "kuwaiti_dinar"),

    ("turkey", "turkish_lira"),
    ("israel", "israeli_shekel"),
    ("united_kingdom", "pound_sterling"),
    ("egypt", "egyptian_pound"),
    ("sudan", "sudanese_pound"),

    ("south_africa", "south_african_rand"),
    ("ethiopia", "ethiopian_birr"),
    ("nigeria", "nigerian_naira"),
    ("ghana", "ghanaian_cedi"),
    ("kenya", "kenyan_shilling"),
    ("uganda", "ugandan_shilling"),
    ("tanzania", "tanzanian_shilling"),

    ("switzerland", "swiss_franc"),
    ("senegal", "west_african_cfa_franc"),
    ("cameroon", "central_african_cfa_franc"),

    ("bangladesh", "bangladeshi_taka"),
    ("maldives", "maldivian_rufiyaa"),
    ("afghanistan", "afghan_afghani"),
    ("iran", "iranian_rial"),
    ("oman", "omani_rial"),
    ("armenia", "armenian_dram"),
    ("georgia", "georgian_lari"),
    ("azerbaijan", "azerbaijani_manat"),
]
assert len(COUNTRY_CURRENCY_60) == 60, "Need exactly 60 pairs."

# =============================
# Build matrices + save to CSV
# =============================
ambient_ccy60, proj_ccy60, sum_ccy60 = mcq_matrices(
    COUNTRY_CURRENCY_60, embedding_kind="bert", k=None, normalize_rows=True
)

ccy_dir = os.path.join(REPORT_DIR, "country_currency_60")
os.makedirs(ccy_dir, exist_ok=True)
ambient_ccy60.to_csv(os.path.join(ccy_dir, "ambient_residuals.csv"))
proj_ccy60.to_csv(os.path.join(ccy_dir, "orthogonal_leakage.csv"))
sum_ccy60.to_csv(os.path.join(ccy_dir, "mcq_summary.csv"))

print(f"[WRITE] Ambient residuals  -> {os.path.join(ccy_dir, 'ambient_residuals.csv')}")
print(f"[WRITE] Orthogonal leakage -> {os.path.join(ccy_dir, 'orthogonal_leakage.csv')}")
print(f"[WRITE] MCQ summary        -> {os.path.join(ccy_dir, 'mcq_summary.csv')}")


[MCQ] embedding=bert | k=43 | accuracy=1.000
[WRITE] Ambient residuals  -> ./relation_pca_reports/country_currency_60/ambient_residuals.csv
[WRITE] Orthogonal leakage -> ./relation_pca_reports/country_currency_60/orthogonal_leakage.csv
[WRITE] MCQ summary        -> ./relation_pca_reports/country_currency_60/mcq_summary.csv


In [None]:
# ==========================================
# Learn PCA + build residual/leakage matrices
# ==========================================
import os, json
import numpy as np
import pandas as pd

def learn_relation_space_and_matrices(
    pairs,
    embedding_kind="bert",
    var_target=0.95,
    normalize_rows=True,
    outdir=os.path.join(REPORT_DIR, "country_currency_60_pca")
):
    os.makedirs(outdir, exist_ok=True)

    U = [u for (u, v) in pairs]  # rows / queries (countries)
    V = [v for (u, v) in pairs]  # cols / candidates (currencies)

    # ---- Embed + (optional) row-normalize BEFORE differencing ----
    EU = embed_words(U, embedding_kind)
    EV = embed_words(V, embedding_kind)
    if normalize_rows:
        EU = l2_normalize_rows(EU)
        EV = l2_normalize_rows(EV)

    # ---- Relation deltas and PCA ----
    R  = EV - EU                # (n, d)
    Rc = safe_center(R)
    n, d = Rc.shape

    U_svd, s, Vt = np.linalg.svd(Rc, full_matrices=False)         # Rc = U S V^T
    lam = (s**2) / max(n - 1, 1)                                  # covariance eigenvalues
    evr = lam / (lam.sum() if lam.sum() > 0 else 1.0)             # explained variance ratio
    cum = np.cumsum(evr)

    k_star = int(np.searchsorted(cum, var_target) + 1)
    k_star = max(1, min(k_star, len(s)))                          # clip

    V_k = Vt[:k_star, :]                                          # (k*, d)
    Pk  = V_k.T @ V_k                                             # projection matrix onto span(V_k)

    # ---- Build matrices over all candidates ----
    # R_ij[i, j, :] = e(v_j) - e(u_i)
    R_ij = EV[None, :, :] - EU[:, None, :]

    ambient = np.linalg.norm(R_ij, axis=2)                        # ||r||
    proj    = np.linalg.norm(R_ij - (R_ij @ Pk), axis=2)          # ||(I - Pk) r||

    # ---- MCQ accuracy under projected residual ----
    pred_idx = np.argmin(proj, axis=1)
    acc = float(np.mean(pred_idx == np.arange(n)))

    # ---- Save spectrum + matrices + summary ----
    spectrum_df = pd.DataFrame({
        "component": np.arange(1, len(s) + 1),
        "singular_value": s,
        "eigenvalue": lam,
        "explained_variance_ratio": evr,
        "cumulative_evr": np.cumsum(evr),
    })
    ambient_df = pd.DataFrame(ambient, index=U, columns=V)
    leak_df    = pd.DataFrame(proj,    index=U, columns=V)
    summary_df = pd.DataFrame({
        "true_tail": V,
        "pred_tail": [V[j] for j in pred_idx],
        "correct": (pred_idx == np.arange(n)),
    }, index=U)

    spectrum_df.to_csv(os.path.join(outdir, "pca_spectrum.csv"), index=False)
    ambient_df.to_csv(os.path.join(outdir, "ambient_residuals.csv"))
    leak_df.to_csv(os.path.join(outdir, "orthogonal_leakage.csv"))
    summary_df.to_csv(os.path.join(outdir, "mcq_summary.csv"))

    meta = {
        "embedding_kind": embedding_kind,
        "n_pairs": n,
        "embedding_dim": int(d),
        "k_star": int(k_star),
        "evr_at_k": float(cum[k_star - 1]),
        "orthogonal_leakage_at_k": float(1.0 - cum[k_star - 1]),
        "mcq_accuracy_projected": acc,
        "normalize_rows": bool(normalize_rows),
        "var_target": float(var_target),
    }
    with open(os.path.join(outdir, "summary.json"), "w") as f:
        json.dump(meta, f, indent=2)

    print(f"[REL-PCA] emb={embedding_kind} | n={n} | d={d} | k*={k_star} | EVR@k={meta['evr_at_k']:.3f} | acc={acc:.3f}")
    print(f"[WRITE] {outdir}")
    return ambient_df, leak_df, spectrum_df, summary_df, meta

# ------------------------------------------
# Use your 60 country→currency pairs & run
# ------------------------------------------
# Reuse the 60 you defined earlier; assert for safety:
assert len(COUNTRY_CURRENCY_60) == 60

ambient_ccy60, leak_ccy60, spec_ccy60, sum_ccy60, meta_ccy60 = learn_relation_space_and_matrices(
    COUNTRY_CURRENCY_60,
    embedding_kind="bert",    # or "roberta" / "labse" / "w2v" (if ready)
    var_target=0.95,
    normalize_rows=True,
    outdir=os.path.join(REPORT_DIR, "country_currency_60_pca")
)

# (Optional, quick peek in notebooks)
display(spec_ccy60.head()); display(ambient_ccy60.round(3)); display(leak_ccy60.round(3)); display(sum_ccy60)
print(meta_ccy60)


[REL-PCA] emb=bert | n=60 | d=768 | k*=43 | EVR@k=0.952 | acc=1.000
[WRITE] ./relation_pca_reports/country_currency_60_pca


Unnamed: 0,component,singular_value,eigenvalue,explained_variance_ratio,cumulative_evr
0,1,1.690237,0.048422,0.136243,0.136243
1,2,1.668711,0.047197,0.132795,0.269037
2,3,1.223347,0.025366,0.07137,0.340408
3,4,1.098023,0.020435,0.057497,0.397904
4,5,1.006111,0.017157,0.048274,0.446178


Unnamed: 0,us_dollar,canadian_dollar,australian_dollar,new_zealand_dollar,singapore_dollar,hong_kong_dollar,new_taiwan_dollar,jamaican_dollar,guyanese_dollar,namibian_dollar,...,west_african_cfa_franc,central_african_cfa_franc,bangladeshi_taka,maldivian_rufiyaa,afghan_afghani,iranian_rial,omani_rial,armenian_dram,georgian_lari,azerbaijani_manat
united_states,0.626,0.673,0.684,0.702,0.739,0.832,0.698,0.696,0.892,0.718,...,0.855,0.877,0.74,0.862,0.785,0.736,0.775,0.73,0.758,0.77
canada,0.996,0.938,0.996,0.971,0.91,1.093,1.032,0.963,1.069,0.954,...,1.157,1.155,0.973,1.135,1.007,0.969,1.043,0.991,1.018,0.955
australia,0.964,0.938,0.897,0.909,0.884,1.02,0.973,0.929,1.023,0.917,...,1.104,1.093,0.965,1.09,0.987,0.946,0.991,0.941,0.992,0.963
new_zealand,0.765,0.748,0.746,0.534,0.74,0.846,0.738,0.76,0.864,0.751,...,0.969,0.974,0.804,0.911,0.822,0.802,0.807,0.831,0.842,0.844
singapore,1.0,0.983,0.982,0.962,0.778,1.018,0.961,0.957,1.02,0.934,...,1.16,1.143,0.963,1.134,1.015,0.976,1.018,0.972,1.006,0.948
hong_kong,0.827,0.782,0.798,0.701,0.703,0.547,0.724,0.761,0.94,0.774,...,0.879,0.874,0.796,0.942,0.878,0.838,0.839,0.811,0.852,0.837
taiwan,0.954,0.948,0.944,0.947,0.808,0.971,0.895,0.917,0.998,0.921,...,1.118,1.107,0.953,1.112,0.972,0.938,0.977,0.95,0.973,0.928
jamaica,0.959,0.937,0.941,0.932,0.887,1.025,0.953,0.851,0.998,0.899,...,1.096,1.092,0.96,1.101,0.981,0.973,0.995,0.954,0.973,0.935
guyana,0.981,0.953,0.966,0.957,0.892,1.065,0.98,0.92,0.998,0.922,...,1.126,1.11,0.986,1.11,0.951,0.969,0.985,0.967,0.983,0.949
namibia,1.03,0.996,1.0,0.97,0.916,1.069,1.005,0.976,1.079,0.87,...,1.13,1.118,1.041,1.162,1.007,1.029,1.041,1.022,1.038,1.013


Unnamed: 0,us_dollar,canadian_dollar,australian_dollar,new_zealand_dollar,singapore_dollar,hong_kong_dollar,new_taiwan_dollar,jamaican_dollar,guyanese_dollar,namibian_dollar,...,west_african_cfa_franc,central_african_cfa_franc,bangladeshi_taka,maldivian_rufiyaa,afghan_afghani,iranian_rial,omani_rial,armenian_dram,georgian_lari,azerbaijani_manat
united_states,0.271,0.449,0.447,0.542,0.552,0.638,0.535,0.496,0.519,0.516,...,0.526,0.529,0.544,0.543,0.551,0.539,0.536,0.505,0.509,0.54
canada,0.457,0.326,0.413,0.526,0.544,0.626,0.538,0.462,0.503,0.505,...,0.513,0.52,0.545,0.549,0.541,0.536,0.55,0.517,0.522,0.536
australia,0.474,0.411,0.308,0.517,0.515,0.617,0.528,0.482,0.521,0.496,...,0.52,0.521,0.527,0.522,0.535,0.532,0.545,0.517,0.506,0.535
new_zealand,0.563,0.542,0.53,0.313,0.589,0.624,0.56,0.589,0.577,0.579,...,0.627,0.624,0.642,0.607,0.635,0.632,0.586,0.643,0.613,0.637
singapore,0.528,0.504,0.489,0.545,0.343,0.551,0.443,0.528,0.533,0.501,...,0.542,0.541,0.542,0.563,0.552,0.575,0.563,0.556,0.557,0.558
hong_kong,0.617,0.588,0.59,0.538,0.533,0.316,0.509,0.582,0.619,0.601,...,0.611,0.618,0.605,0.63,0.629,0.658,0.641,0.654,0.625,0.631
taiwan,0.501,0.478,0.47,0.511,0.443,0.497,0.292,0.479,0.501,0.491,...,0.509,0.525,0.535,0.527,0.546,0.526,0.558,0.518,0.51,0.533
jamaica,0.474,0.433,0.434,0.516,0.51,0.591,0.504,0.282,0.445,0.473,...,0.468,0.49,0.492,0.491,0.532,0.531,0.531,0.514,0.485,0.517
guyana,0.475,0.435,0.445,0.505,0.495,0.611,0.491,0.428,0.267,0.439,...,0.47,0.473,0.535,0.505,0.529,0.535,0.527,0.516,0.503,0.534
namibia,0.507,0.481,0.464,0.53,0.513,0.631,0.525,0.498,0.474,0.31,...,0.481,0.499,0.556,0.546,0.532,0.579,0.565,0.543,0.524,0.537


Unnamed: 0,true_tail,pred_tail,correct
united_states,us_dollar,us_dollar,True
canada,canadian_dollar,canadian_dollar,True
australia,australian_dollar,australian_dollar,True
new_zealand,new_zealand_dollar,new_zealand_dollar,True
singapore,singapore_dollar,singapore_dollar,True
hong_kong,hong_kong_dollar,hong_kong_dollar,True
taiwan,new_taiwan_dollar,new_taiwan_dollar,True
jamaica,jamaican_dollar,jamaican_dollar,True
guyana,guyanese_dollar,guyanese_dollar,True
namibia,namibian_dollar,namibian_dollar,True


{'embedding_kind': 'bert', 'n_pairs': 60, 'embedding_dim': 768, 'k_star': 43, 'evr_at_k': 0.9518915414810181, 'orthogonal_leakage_at_k': 0.048108458518981934, 'mcq_accuracy_projected': 1.0, 'normalize_rows': True, 'var_target': 0.95}


In [None]:
# ===========================================================
# Learn PCA on the entire space; evaluate MCQ on 6 countries
# ===========================================================
import os, json
import numpy as np
import pandas as pd

VAR_TARGET_FULL = 0.95    # target EVR for choosing k* on the full space
NORMALIZE_ROWS  = True    # match your pipeline

def learn_relation_pca_from_fullspace(pairs, embedding_kind="bert",
                                      var_target=VAR_TARGET_FULL,
                                      normalize_rows=NORMALIZE_ROWS):
    """
    Learn relation PCA on the full country->currency list (pairs).
    Returns dict with projector P_k*, basis V_k, eigenvalues, and meta.
    """
    U = [u for (u, v) in pairs]
    V = [v for (u, v) in pairs]
    EU = embed_words(U, embedding_kind)
    EV = embed_words(V, embedding_kind)
    if normalize_rows:
        EU = l2_normalize_rows(EU)
        EV = l2_normalize_rows(EV)

    R  = EV - EU
    Rc = safe_center(R)
    n, d = Rc.shape

    U_svd, s, Vt = np.linalg.svd(Rc, full_matrices=False)
    lam = (s**2) / max(n - 1, 1)
    evr = lam / (lam.sum() if lam.sum() > 0 else 1.0)
    cum = np.cumsum(evr)

    k_star = int(np.searchsorted(cum, var_target) + 1)
    k_star = max(1, min(k_star, len(s)))
    V_k = Vt[:k_star, :]             # (k*, d)
    Pk  = V_k.T @ V_k                 # (d, d)

    return {
        "Pk": Pk,
        "V_k": V_k,
        "k_star": k_star,
        "lam": lam,
        "evr": evr,
        "evr_at_k": float(cum[k_star - 1]),
        "leakage_at_k": float(1.0 - cum[k_star - 1]),
        "n_full": int(n),
        "d": int(d),
    }

def mcq_matrices_with_fixed_projector(test_pairs, projector, embedding_kind="bert",
                                      normalize_rows=NORMALIZE_ROWS):
    """
    Using a fixed projector (learned from the full space), compute:
      - ambient residuals ||e(v_j)-e(u_i)||
      - orthogonal leakage ||(I-P) (e(v_j)-e(u_i))||
    Returns ambient_df, leakage_df, summary_df (MCQ using leakage argmin).
    """
    U = [u for (u, v) in test_pairs]
    V = [v for (u, v) in test_pairs]

    EU = embed_words(U, embedding_kind)
    EV = embed_words(V, embedding_kind)
    if normalize_rows:
        EU = l2_normalize_rows(EU)
        EV = l2_normalize_rows(EV)

    # Build all r_ij = e(v_j) - e(u_i)
    R_ij = EV[None, :, :] - EU[:, None, :]           # (n, n, d)
    I_minus_P = np.eye(projector.shape[0], dtype=np.float32) - projector

    ambient = np.linalg.norm(R_ij, axis=2)
    leakage = np.linalg.norm(R_ij @ I_minus_P, axis=2)

    # MCQ prediction per row = argmin leakage
    pred_idx = np.argmin(leakage, axis=1)
    acc = float(np.mean(pred_idx == np.arange(len(U))))

    ambient_df = pd.DataFrame(ambient, index=U, columns=V)
    leakage_df = pd.DataFrame(leakage, index=U, columns=V)
    summary_df = pd.DataFrame({
        "true_tail": V,
        "pred_tail": [V[j] for j in pred_idx],
        "correct": (pred_idx == np.arange(len(U))),
    }, index=U)

    return ambient_df, leakage_df, summary_df, acc

# -------------------------
# 1) Learn on the ENTIRE set
# -------------------------
# Assume you already defined COUNTRY_CURRENCY_60 (full training space, ~60 pairs)
# e.g., COUNTRY_CURRENCY_60 = [("india","indian_rupee"), ("japan","japanese_yen"), ...]
full_meta = learn_relation_pca_from_fullspace(
    pairs=COUNTRY_CURRENCY_60,
    embedding_kind="bert",       # or "roberta"/"labse"/"w2v"/"glove" if wired up
    var_target=0.95,
    normalize_rows=True
)
print(f"[TRAIN] k*={full_meta['k_star']} | EVR@k={full_meta['evr_at_k']:.3f} | leak={full_meta['leakage_at_k']:.3f} | n={full_meta['n_full']} | d={full_meta['d']}")

# -----------------------------------------------
# 2) Evaluate MCQ/leakage ONLY on these 6 pairs
#    (India/Japan/China/US plus 2 more)
# -----------------------------------------------
TEST_COUNTRY_CURRENCY_6 = [
    ("india",          "indian_rupee"),
    ("japan",          "japanese_yen"),
    ("china",          "chinese_yuan"),     # or "renminbi" if that’s your token
    ("united_states",  "us_dollar"),
    ("canada",         "canadian_dollar"),
    ("australia",      "australian_dollar"),
]

ambient6, leak6, sum6, acc6 = mcq_matrices_with_fixed_projector(
    test_pairs=TEST_COUNTRY_CURRENCY_6,
    projector=full_meta["Pk"],
    embedding_kind="bert",
    normalize_rows=True
)

print(f"[TEST-6] MCQ accuracy (projected residual): {acc6:.3f}")
# Optional: peek
display(ambient6.round(3)); display(leak6.round(3)); display(sum6)


[TRAIN] k*=43 | EVR@k=0.952 | leak=0.048 | n=60 | d=768
[TEST-6] MCQ accuracy (projected residual): 1.000


Unnamed: 0,indian_rupee,japanese_yen,chinese_yuan,us_dollar,canadian_dollar,australian_dollar
india,0.936,0.877,0.938,0.967,0.952,0.954
japan,1.03,0.797,0.907,0.979,0.965,0.959
china,0.982,0.842,0.82,0.927,0.91,0.902
united_states,0.828,0.69,0.726,0.626,0.673,0.684
canada,1.1,0.923,0.959,0.996,0.938,0.996
australia,1.02,0.884,0.923,0.964,0.938,0.897


Unnamed: 0,indian_rupee,japanese_yen,chinese_yuan,us_dollar,canadian_dollar,australian_dollar
india,0.312,0.468,0.493,0.501,0.478,0.473
japan,0.514,0.277,0.484,0.531,0.519,0.504
china,0.495,0.449,0.277,0.496,0.48,0.472
united_states,0.492,0.501,0.506,0.271,0.449,0.447
canada,0.495,0.51,0.51,0.457,0.326,0.413
australia,0.487,0.489,0.488,0.474,0.411,0.308


Unnamed: 0,true_tail,pred_tail,correct
india,indian_rupee,indian_rupee,True
japan,japanese_yen,japanese_yen,True
china,chinese_yuan,chinese_yuan,True
united_states,us_dollar,us_dollar,True
canada,canadian_dollar,canadian_dollar,True
australia,australian_dollar,australian_dollar,True
