# BIP Native-Language Morality Experiment (v9)

**Testing Universal Moral Structure with NO Translation Bridge**

This experiment tests whether moral cognition has invariant structure using a rigorous methodology:

1. **Native-language pattern matching**: Each language uses its OWN patterns to extract moral labels
2. **Native text input**: Model sees original Hebrew/Chinese/Arabic text
3. **Mathematical alignment only**: The ONLY connection between languages is the learned latent space

---

## Why This Matters

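**v8 approach** (weaker): Labels were extracted from English translations → the translation itself acts as a semantic bridge between languages, so any cross-lingual transfer could be an artifact of it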

**v9 approach** (stronger): 
- Hebrew text → Hebrew patterns → Hebrew moral labels
- Chinese text → Chinese patterns → Chinese moral labels  
- Arabic text → Arabic patterns → Arabic moral labels
- English text → English patterns → English moral labels

If the model can still transfer bond representations across languages, that is evidence that the **mathematical structure of morality is shared across languages** - and not just an artifact of translation.

---

## Native Pattern Examples

| Concept | Hebrew | Chinese | Arabic | English |
|---------|--------|---------|--------|--------|
| Obligation | חייב, צריך, אסור | 必須, 應該, 當 | يجب، واجب، فرض | must, shall, ought |
| Harm | הרג, נזק, פגע | 殺, 害, 傷 | قتل، ضرر، أذى | kill, harm, hurt |
| Family | אב, אם, כבד | 父, 母, 孝 | والدين، أب، أم | father, mother, honor |
| Authority | מלך, שופט, צווה | 君, 臣, 命 | طاعة، حكم، أمر | king, judge, command |

---

In [None]:
#@title 1. Setup { display-mode: "form" }

import time

# Reference timestamp for the "Elapsed" readout printed with the task checklist.
EXPERIMENT_START = time.time()

print("\n".join([
    "=" * 60,
    "BIP NATIVE-LANGUAGE EXPERIMENT (v9)",
    "No Translation Bridge - Pure Mathematical Alignment",
    "=" * 60,
]))
print()

# Pipeline stages, in the order they are listed in the progress checklist.
TASKS = [
    "Install dependencies",
    "Download corpora",
    "Define native patterns",
    "Extract bonds (native)",
    "Generate splits",
    "Train BIP model",
    "Linear probe test",
    "Evaluate results",
]

# Every stage starts out pending; durations are filled in as stages finish.
task_status = dict.fromkeys(TASKS, "pending")
task_times = {}
task_start_time = None

def print_progress():
    """Print the task checklist, recorded task durations and total elapsed minutes."""
    markers = {"done": "[X]", "running": "[>]"}
    rule = "-" * 50

    print()
    print(rule)
    for name in TASKS:
        # Anything that is neither done nor running is shown as an empty box.
        marker = markers.get(task_status[name], "[ ]")
        if name in task_times:
            suffix = f" ({task_times[name]:.1f}s)"
        else:
            suffix = ""
        print(f"  {marker} {name}{suffix}")
    elapsed_min = (time.time() - EXPERIMENT_START) / 60
    print(f"  Elapsed: {elapsed_min:.1f} min")
    print(rule, flush=True)

def mark_task(task, status):
    """Record a status change for ``task`` and reprint the checklist.

    "running" restarts the shared stopwatch; "done" stores the time since
    the stopwatch was last started as the task's duration.
    """
    global task_start_time
    now = time.time()
    if status == "running":
        task_start_time = now
    elif status == "done" and task_start_time:
        task_times[task] = now - task_start_time
    task_status[task] = status
    print_progress()

print_progress()
mark_task("Install dependencies", "running")

import os, subprocess, sys

# Each package is installed quietly through the running interpreter's pip;
# check_call raises CalledProcessError if an install exits non-zero.
for dep in (
    "transformers",
    "torch",
    "sentence-transformers",
    "pandas",
    "tqdm",
    "psutil",
    "scikit-learn",
    "requests",
):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", dep])

import torch, json, psutil, shutil, gc
import torch.nn as nn
import torch.nn.functional as F

# Pick the compute device and a base batch size for it.
# Defines: USE_TPU, device, BASE_BATCH_SIZE, GPU_TIER and IS_L4 on every path.
USE_TPU = 'COLAB_TPU_ADDR' in os.environ

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    device = torch.device("cuda")
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

    # Adjust batch size based on GPU: known models are matched by name first,
    # anything else falls back to VRAM thresholds.
    if 'A100' in gpu_name:
        BASE_BATCH_SIZE = 512
        GPU_TIER = 'A100'
    elif 'L4' in gpu_name:
        BASE_BATCH_SIZE = 512
        GPU_TIER = 'L4'
    elif 'T4' in gpu_name:
        BASE_BATCH_SIZE = 256  # T4 has 16GB, be conservative
        GPU_TIER = 'T4'
    elif vram_gb >= 20:
        BASE_BATCH_SIZE = 512
        GPU_TIER = f'HIGH ({vram_gb:.0f}GB)'
    elif vram_gb >= 12:
        BASE_BATCH_SIZE = 256
        GPU_TIER = f'MED ({vram_gb:.0f}GB)'
    else:
        BASE_BATCH_SIZE = 128
        GPU_TIER = f'LOW ({vram_gb:.0f}GB)'

    # Previously IS_L4 only existed on the TPU/CPU paths; define it here too.
    IS_L4 = GPU_TIER == 'L4'

    print(f"GPU: {gpu_name}")
    print(f"VRAM: {vram_gb:.1f} GB")
    print(f"Tier: {GPU_TIER}")
elif USE_TPU:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    IS_L4 = False
    # GPU_TIER must exist on this path as well: it is read by the summary print below.
    GPU_TIER = 'TPU'
    BASE_BATCH_SIZE = 256
else:
    device = torch.device("cpu")
    IS_L4 = False
    # Without this assignment the summary print below raised NameError on CPU.
    GPU_TIER = 'CPU'
    BASE_BATCH_SIZE = 64

print(f"Device: {device}")
print(f"Batch size: {BASE_BATCH_SIZE} (optimized for {GPU_TIER})")

# Mixed precision is switched on exactly when a CUDA device is available;
# otherwise no gradient scaler is created.
USE_AMP = torch.cuda.is_available()
scaler = None
if USE_AMP:
    from torch.cuda.amp import GradScaler
    scaler = GradScaler()

# Mount Google Drive (Colab-only import) and make sure the output folder on it exists.
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = '/content/drive/MyDrive/BIP_native_v9'
os.makedirs(SAVE_DIR, exist_ok=True)

# Working directories, created relative to the current working directory.
for d in ["data/processed", "data/splits", "data/raw", "models/checkpoints", "results"]:
    os.makedirs(d, exist_ok=True)

def print_resources(label=""):
    """Print used/total host RAM and, when CUDA is available, allocated GPU memory.

    ``label`` is shown in square brackets at the start of the line.
    """
    ram = psutil.virtual_memory()
    parts = [f"[{label}] RAM: {ram.used/1e9:.1f}/{ram.total/1e9:.1f}GB"]
    if torch.cuda.is_available():
        parts.append(f"GPU: {torch.cuda.memory_allocated()/1e9:.1f}GB")
    print(" | ".join(parts))

mark_task("Install dependencies", "done")


In [None]:
#@title 2. Download All Corpora { display-mode: "form" }

import subprocess
import os
import json
import requests
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm

mark_task("Download corpora", "running")

# ========== SEFARIA ==========
print("="*60)
print("1. SEFARIA (Hebrew/Aramaic)")
print("="*60)

sefaria_path = 'data/raw/Sefaria-Export'
if not os.path.exists(f"{sefaria_path}/json"):
    # Shallow clone; stderr is merged into stdout so the clone output can be
    # streamed line by line as it arrives.
    with subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/Sefaria/Sefaria-Export.git', sefaria_path],
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as process:
        for line in process.stdout:
            print(line, end='', flush=True)
    # Leaving the ``with`` block closes the pipe and waits for git, so the
    # exit status is available here. A failed clone used to go unnoticed.
    if process.returncode != 0:
        raise RuntimeError(
            f"git clone of Sefaria-Export failed with exit code {process.returncode}")
else:
    print("Already downloaded.")

# ========== CHINESE ==========
print()
print("=" * 60)
print("2. CHINESE CLASSICS")
print("=" * 60)

# Destination folder for the Chinese passages written below.
chinese_dir = Path('data') / 'raw' / 'chinese'
chinese_dir.mkdir(parents=True, exist_ok=True)

# Embedded Chinese classics, original text only (no translations).
# Each entry is (text, source, period, century).
CHINESE_TEXTS = [
    # Analects - Confucian ethics
    ("學而時習之，不亦說乎？有朋自遠方來，不亦樂乎？人不知而不慍，不亦君子乎？", "analects", "CONFUCIAN", -5),
    ("其為人也孝弟，而好犯上者，鮮矣；不好犯上，而好作亂者，未之有也。君子務本，本立而道生。孝弟也者，其為仁之本與！", "analects", "CONFUCIAN", -5),
    ("巧言令色，鮮矣仁！", "analects", "CONFUCIAN", -5),
    ("吾日三省吾身：為人謀而不忠乎？與朋友交而不信乎？傳不習乎？", "analects", "CONFUCIAN", -5),
    ("弟子入則孝，出則弟，謹而信，汎愛眾，而親仁。", "analects", "CONFUCIAN", -5),
    ("君子不重則不威，學則不固。主忠信，無友不如己者，過則勿憚改。", "analects", "CONFUCIAN", -5),
    ("父在，觀其志；父沒，觀其行；三年無改於父之道，可謂孝矣。", "analects", "CONFUCIAN", -5),
    ("禮之用，和為貴。", "analects", "CONFUCIAN", -5),
    ("信近於義，言可復也。恭近於禮，遠恥辱也。", "analects", "CONFUCIAN", -5),
    ("君子食無求飽，居無求安，敏於事而慎於言，就有道而正焉。", "analects", "CONFUCIAN", -5),
    ("不患人之不己知，患不知人也。", "analects", "CONFUCIAN", -5),
    ("為政以德，譬如北辰，居其所而眾星共之。", "analects", "CONFUCIAN", -5),
    ("道之以政，齊之以刑，民免而無恥；道之以德，齊之以禮，有恥且格。", "analects", "CONFUCIAN", -5),
    ("吾十有五而志于學，三十而立，四十而不惑，五十而知天命，六十而耳順，七十而從心所欲，不踰矩。", "analects", "CONFUCIAN", -5),
    ("溫故而知新，可以為師矣。", "analects", "CONFUCIAN", -5),
    ("君子不器。", "analects", "CONFUCIAN", -5),
    ("先行其言而後從之。", "analects", "CONFUCIAN", -5),
    ("君子周而不比，小人比而不周。", "analects", "CONFUCIAN", -5),
    ("學而不思則罔，思而不學則殆。", "analects", "CONFUCIAN", -5),
    ("知之為知之，不知為不知，是知也。", "analects", "CONFUCIAN", -5),
    # Dao De Jing - Daoist philosophy
    ("道可道，非常道。名可名，非常名。無名天地之始；有名萬物之母。", "daodejing", "DAOIST", -6),
    ("天下皆知美之為美，斯惡已。皆知善之為善，斯不善已。", "daodejing", "DAOIST", -6),
    ("不尚賢，使民不爭；不貴難得之貨，使民不為盜。", "daodejing", "DAOIST", -6),
    ("道沖而用之或不盈。淵兮似萬物之宗。", "daodejing", "DAOIST", -6),
    ("天地不仁，以萬物為芻狗；聖人不仁，以百姓為芻狗。", "daodejing", "DAOIST", -6),
    ("上善若水。水善利萬物而不爭，處眾人之所惡，故幾於道。", "daodejing", "DAOIST", -6),
    ("持而盈之，不如其已；揣而銳之，不可長保。", "daodejing", "DAOIST", -6),
    ("三十輻共一轂，當其無，有車之用。", "daodejing", "DAOIST", -6),
    ("五色令人目盲；五音令人耳聾；五味令人口爽。", "daodejing", "DAOIST", -6),
    ("大道廢，有仁義；智慧出，有大偽。", "daodejing", "DAOIST", -6),
    ("絕聖棄智，民利百倍；絕仁棄義，民復孝慈；絕巧棄利，盜賊無有。", "daodejing", "DAOIST", -6),
    ("曲則全，枉則直，窪則盈，敝則新，少則得，多則惑。", "daodejing", "DAOIST", -6),
    ("希言自然。故飄風不終朝，驟雨不終日。", "daodejing", "DAOIST", -6),
    ("人法地，地法天，天法道，道法自然。", "daodejing", "DAOIST", -6),
    ("知人者智，自知者明。勝人者有力，自勝者強。", "daodejing", "DAOIST", -6),
    # Mencius - Confucian ethics on human nature
    ("人皆有不忍人之心。先王有不忍人之心，斯有不忍人之政矣。", "mencius", "CONFUCIAN", -4),
    ("惻隱之心，仁之端也；羞惡之心，義之端也；辭讓之心，禮之端也；是非之心，智之端也。", "mencius", "CONFUCIAN", -4),
    ("人之所以異於禽獸者幾希，庶民去之，君子存之。", "mencius", "CONFUCIAN", -4),
    ("得道者多助，失道者寡助。", "mencius", "CONFUCIAN", -4),
    ("天時不如地利，地利不如人和。", "mencius", "CONFUCIAN", -4),
    ("老吾老，以及人之老；幼吾幼，以及人之幼。", "mencius", "CONFUCIAN", -4),
    ("民為貴，社稷次之，君為輕。", "mencius", "CONFUCIAN", -4),
    ("生於憂患，死於安樂。", "mencius", "CONFUCIAN", -4),
    ("富貴不能淫，貧賤不能移，威武不能屈，此之謂大丈夫。", "mencius", "CONFUCIAN", -4),
    ("窮則獨善其身，達則兼善天下。", "mencius", "CONFUCIAN", -4),
    # Zhuangzi - Daoist philosophy
    ("北冥有魚，其名為鯤。鯤之大，不知其幾千里也。", "zhuangzi", "DAOIST", -4),
    ("昔者莊周夢為胡蝶，栩栩然胡蝶也。不知周之夢為胡蝶與，胡蝶之夢為周與？", "zhuangzi", "DAOIST", -4),
    ("吾生也有涯，而知也無涯。以有涯隨無涯，殆已！", "zhuangzi", "DAOIST", -4),
    ("泉涸，魚相與處於陸，相呴以濕，相濡以沫，不如相忘於江湖。", "zhuangzi", "DAOIST", -4),
    ("人皆知有用之用，而莫知無用之用也。", "zhuangzi", "DAOIST", -4),
    # Xunzi - Confucian philosophy (human nature is bad, needs cultivation)
    ("人之性惡，其善者偽也。", "xunzi", "CONFUCIAN", -3),
    ("故木受繩則直，金就礪則利，君子博學而日參省乎己，則知明而行無過矣。", "xunzi", "CONFUCIAN", -3),
    ("青，取之於藍，而青於藍；冰，水為之，而寒於水。", "xunzi", "CONFUCIAN", -3),
    ("不積跬步，無以至千里；不積小流，無以成江海。", "xunzi", "CONFUCIAN", -3),
    ("鍥而舍之，朽木不折；鍥而不舍，金石可鏤。", "xunzi", "CONFUCIAN", -3),
]

# One record per passage; ids are the passage's position in CHINESE_TEXTS.
data = [
    {'id': f'chinese_{i}', 'text': text, 'source': source, 'period': period, 'century': century}
    for i, (text, source, period, century) in enumerate(CHINESE_TEXTS)
]
with open(chinese_dir / 'chinese_native.json', 'w', encoding='utf-8') as f:
    # ensure_ascii=False keeps the Chinese characters readable in the file.
    json.dump(data, f, ensure_ascii=False, indent=2)
print(f"Saved {len(CHINESE_TEXTS)} Chinese passages")

# ========== ISLAMIC ==========
print()
print("=" * 60)
print("3. ISLAMIC TEXTS (Arabic)")
print("=" * 60)

# Destination folder for the Arabic passages written below.
islamic_dir = Path('data') / 'raw' / 'islamic'
islamic_dir.mkdir(parents=True, exist_ok=True)

# Key Quranic and Hadith texts in Arabic
ISLAMIC_TEXTS = [
    # Quran - core moral teachings
    ("بِسْمِ اللَّهِ الرَّحْمَٰنِ الرَّحِيمِ", "quran_1_1", "QURANIC", 7),
    ("وَقَضَىٰ رَبُّكَ أَلَّا تَعْبُدُوا إِلَّا إِيَّاهُ وَبِالْوَالِدَيْنِ إِحْسَانًا", "quran_17_23", "QURANIC", 7),
    ("وَلَا تَقْتُلُوا النَّفْسَ الَّتِي حَرَّمَ اللَّهُ إِلَّا بِالْحَقِّ", "quran_17_33", "QURANIC", 7),
    ("وَأَوْفُوا بِالْعَهْدِ إِنَّ الْعَهْدَ كَانَ مَسْئُولًا", "quran_17_34", "QURANIC", 7),
    ("وَلَا تَقْرَبُوا مَالَ الْيَتِيمِ إِلَّا بِالَّتِي هِيَ أَحْسَنُ", "quran_17_34b", "QURANIC", 7),
    ("إِنَّ اللَّهَ يَأْمُرُ بِالْعَدْلِ وَالْإِحْسَانِ وَإِيتَاءِ ذِي الْقُرْبَىٰ", "quran_16_90", "QURANIC", 7),
    ("يَا أَيُّهَا الَّذِينَ آمَنُوا كُونُوا قَوَّامِينَ بِالْقِسْطِ شُهَدَاءَ لِلَّهِ", "quran_4_135", "QURANIC", 7),
    ("وَتَعَاوَنُوا عَلَى الْبِرِّ وَالتَّقْوَىٰ وَلَا تَعَاوَنُوا عَلَى الْإِثْمِ وَالْعُدْوَانِ", "quran_5_2", "QURANIC", 7),
    ("مَنْ قَتَلَ نَفْسًا بِغَيْرِ نَفْسٍ أَوْ فَسَادٍ فِي الْأَرْضِ فَكَأَنَّمَا قَتَلَ النَّاسَ جَمِيعًا", "quran_5_32a", "QURANIC", 7),
    ("وَمَنْ أَحْيَاهَا فَكَأَنَّمَا أَحْيَا النَّاسَ جَمِيعًا", "quran_5_32b", "QURANIC", 7),
    ("وَلَا تَأْكُلُوا أَمْوَالَكُم بَيْنَكُم بِالْبَاطِلِ", "quran_2_188", "QURANIC", 7),
    ("يَا أَيُّهَا الَّذِينَ آمَنُوا أَوْفُوا بِالْعُقُودِ", "quran_5_1", "QURANIC", 7),
    ("وَالَّذِينَ هُمْ لِأَمَانَاتِهِمْ وَعَهْدِهِمْ رَاعُونَ", "quran_23_8", "QURANIC", 7),
    ("إِنَّمَا الْمُؤْمِنُونَ إِخْوَةٌ فَأَصْلِحُوا بَيْنَ أَخَوَيْكُمْ", "quran_49_10", "QURANIC", 7),
    ("وَلَا يَغْتَب بَّعْضُكُم بَعْضًا", "quran_49_12", "QURANIC", 7),
    ("يَا أَيُّهَا النَّاسُ إِنَّا خَلَقْنَاكُم مِّن ذَكَرٍ وَأُنثَىٰ وَجَعَلْنَاكُمْ شُعُوبًا وَقَبَائِلَ لِتَعَارَفُوا", "quran_49_13", "QURANIC", 7),
    ("لَّا يَنْهَاكُمُ اللَّهُ عَنِ الَّذِينَ لَمْ يُقَاتِلُوكُمْ فِي الدِّينِ وَلَمْ يُخْرِجُوكُم مِّن دِيَارِكُمْ أَن تَبَرُّوهُمْ وَتُقْسِطُوا إِلَيْهِمْ", "quran_60_8", "QURANIC", 7),
    ("وَإِذَا حَكَمْتُم بَيْنَ النَّاسِ أَن تَحْكُمُوا بِالْعَدْلِ", "quran_4_58", "QURANIC", 7),
    ("خُذِ الْعَفْوَ وَأْمُرْ بِالْعُرْفِ وَأَعْرِضْ عَنِ الْجَاهِلِينَ", "quran_7_199", "QURANIC", 7),
    ("ادْفَعْ بِالَّتِي هِيَ أَحْسَنُ فَإِذَا الَّذِي بَيْنَكَ وَبَيْنَهُ عَدَاوَةٌ كَأَنَّهُ وَلِيٌّ حَمِيمٌ", "quran_41_34", "QURANIC", 7),
    # Hadith - Prophetic traditions on ethics
    ("إنما الأعمال بالنيات وإنما لكل امرئ ما نوى", "bukhari_1", "HADITH", 9),
    ("لا يؤمن أحدكم حتى يحب لأخيه ما يحب لنفسه", "bukhari_13", "HADITH", 9),
    ("من كان يؤمن بالله واليوم الآخر فليقل خيرا أو ليصمت", "bukhari_6018", "HADITH", 9),
    ("المسلم من سلم المسلمون من لسانه ويده", "bukhari_10", "HADITH", 9),
    ("لا ضرر ولا ضرار", "ibn_majah_2341", "HADITH", 9),
    ("ارحموا من في الأرض يرحمكم من في السماء", "tirmidhi_1924", "HADITH", 9),
    ("الدين النصيحة", "muslim_55", "HADITH", 9),
    ("من رأى منكم منكرا فليغيره بيده فإن لم يستطع فبلسانه فإن لم يستطع فبقلبه", "muslim_49", "HADITH", 9),
    ("لا يحل مال امرئ مسلم إلا بطيب نفس منه", "ahmad_20172", "HADITH", 9),
    ("كلكم راع وكلكم مسؤول عن رعيته", "bukhari_7138", "HADITH", 9),
    ("إن الله كتب الإحسان على كل شيء", "muslim_1955", "HADITH", 9),
    ("ليس منا من لم يرحم صغيرنا ويوقر كبيرنا", "tirmidhi_1919", "HADITH", 9),
    ("المؤمن للمؤمن كالبنيان يشد بعضه بعضا", "bukhari_481", "HADITH", 9),
    ("من غشنا فليس منا", "muslim_101", "HADITH", 9),
    ("اتق الله حيثما كنت وأتبع السيئة الحسنة تمحها وخالق الناس بخلق حسن", "tirmidhi_1987", "HADITH", 9),
    ("إن الله لا ينظر إلى صوركم وأموالكم ولكن ينظر إلى قلوبكم وأعمالكم", "muslim_2564", "HADITH", 9),
    ("أحب الناس إلى الله أنفعهم للناس", "tabarani", "HADITH", 9),
    ("خيركم خيركم لأهله وأنا خيركم لأهلي", "tirmidhi_3895", "HADITH", 9),
    ("ما نقصت صدقة من مال", "muslim_2588", "HADITH", 9),
    ("الظلم ظلمات يوم القيامة", "bukhari_2447", "HADITH", 9),
]

# One record per passage; ids are the passage's position in ISLAMIC_TEXTS.
data = [
    {'id': f'islamic_{i}', 'text': text, 'source': source, 'period': period, 'century': century}
    for i, (text, source, period, century) in enumerate(ISLAMIC_TEXTS)
]
with open(islamic_dir / 'islamic_native.json', 'w', encoding='utf-8') as f:
    # ensure_ascii=False keeps the Arabic script readable in the file.
    json.dump(data, f, ensure_ascii=False, indent=2)
print(f"Saved {len(ISLAMIC_TEXTS)} Arabic passages")

# ========== DEAR ABBY ==========
print("\n" + "="*60)
print("4. DEAR ABBY (English)")
print("="*60)

# Shallow-clone the repository holding the Dear Abby CSV unless the folder already exists.
# NOTE(review): the exit status of this clone is not checked.
if not os.path.exists('sqnd-probe-data'):
    subprocess.run(['git', 'clone', '--depth', '1', 'https://github.com/ahb-sjsu/sqnd-probe.git', 'sqnd-probe-data'])

# IPython shell escape (valid only inside a notebook): copy the CSV into data/raw.
# cp's error output is discarded; if cp fails, "Using existing" is echoed instead.
!cp sqnd-probe-data/dear_abby_data/raw_da_qs.csv data/raw/dear_abby.csv 2>/dev/null || echo "Using existing"
df = pd.read_csv('data/raw/dear_abby.csv')
print(f"Loaded {len(df):,} Dear Abby entries")

print_resources("After downloads")
mark_task("Download corpora", "done")


In [None]:
#@title 3. Define Native-Language Patterns + Text Normalization { display-mode: "form" }
#@markdown Native patterns for each language + proper Unicode normalization.

import re
import unicodedata
from enum import Enum, auto
from dataclasses import dataclass, asdict
from typing import Dict, List, Set, Tuple

print("\n".join([
    "=" * 60,
    "TEXT NORMALIZATION & NATIVE PATTERNS",
    "=" * 60,
]))
print()

# ============================================================
# TEXT NORMALIZATION (Critical for Hebrew/Arabic pattern matching)
# ============================================================

# Final (sofit) letter forms mapped to their regular forms.
_HEBREW_FINALS = str.maketrans({
    'ך': 'כ',  # final kaf -> kaf
    'ם': 'מ',  # final mem -> mem
    'ן': 'נ',  # final nun -> nun
    'ף': 'פ',  # final pe -> pe
    'ץ': 'צ',  # final tsadi -> tsadi
})


def normalize_hebrew(text: str) -> str:
    """Normalize Hebrew-script text for pattern matching.

    Steps: Unicode NFKC, maqaf turned into a space, cantillation marks and
    points removed, final letter forms folded into their regular forms.

    Args:
        text: Hebrew (or Aramaic) text, possibly pointed.

    Returns:
        The normalized text.
    """
    # Unicode NFKC normalization
    text = unicodedata.normalize('NFKC', text)

    # Maqaf (U+05BE) is a hyphen joining two words, not a diacritic. Deleting
    # it fused the two words, so patterns that expect a separator between
    # them (e.g. 'פקוח.נפש') could not match; keep the word boundary instead.
    text = text.replace('\u05BE', ' ')

    # Remove the remaining code points in U+0591..U+05C7
    # (cantillation marks, nikud and the other signs in that range).
    text = re.sub(r'[\u0591-\u05BD\u05BF-\u05C7]', '', text)

    # Normalize final letters to regular forms in a single pass.
    return text.translate(_HEBREW_FINALS)

# Single-character folds applied after the combining marks are removed.
_ARABIC_FOLDS = str.maketrans({
    '\u0640': None,      # tatweel (kashida) is removed
    '\u0623': '\u0627',  # alef with hamza above -> alef
    '\u0625': '\u0627',  # alef with hamza below -> alef
    '\u0622': '\u0627',  # alef with madda -> alef
    '\u0671': '\u0627',  # alef wasla -> alef
    '\u0629': '\u0647',  # teh marbuta -> heh
    '\u0649': '\u064A',  # alef maksura -> yeh
})

# Tashkeel (U+064B..U+065F), superscript alef (U+0670) and the combining
# Quranic annotation marks in U+06D6..U+06ED.
_ARABIC_MARKS = re.compile(
    r'[\u064B-\u065F\u0670\u06D6-\u06DC\u06DF-\u06E4\u06E7\u06E8\u06EA-\u06ED]')


def normalize_arabic(text: str) -> str:
    """Normalize Arabic text for pattern matching.

    Steps: Unicode NFKC, combining marks removed, tatweel removed, alef
    variants folded to plain alef, teh marbuta to heh, alef maksura to yeh.

    Args:
        text: Arabic text, possibly vocalized.

    Returns:
        The normalized text.
    """
    # Unicode NFKC normalization
    text = unicodedata.normalize('NFKC', text)

    # Remove diacritics. The original range stopped at U+065F, which left the
    # superscript alef (U+0670) used in Quranic spelling inside the
    # "normalized" text.
    text = _ARABIC_MARKS.sub('', text)

    # Tatweel removal and letter folds in one pass; the folds do not chain,
    # so a single translate is equivalent to the sequential replacements.
    return text.translate(_ARABIC_FOLDS)

def normalize_text(text: str, language: str) -> str:
    """Apply the normalization that belongs to ``language``.

    Hebrew and Aramaic share the Hebrew-script normalizer, Arabic has its
    own, Classical Chinese gets NFKC only, and every other language
    (English included) is lowercased and then NFKC-normalized.
    """
    if language in ('hebrew', 'aramaic'):
        # Aramaic uses Hebrew script, same normalization
        return normalize_hebrew(text)
    if language == 'arabic':
        return normalize_arabic(text)
    if language != 'classical_chinese':
        # English and others: lowercase for matching
        text = text.lower()
    return unicodedata.normalize('NFKC', text)

print("\n".join([
    "Normalization functions defined:",
    "  - normalize_hebrew(): removes nikud, normalizes finals",
    "  - normalize_arabic(): removes tashkeel, normalizes alef/teh",
    "  - normalize_text(): dispatches by language",
]))
print()

# Show each normalizer on one vocalized sample word.
test_hebrew = "הָאָדָם"  # with nikud
test_arabic = "الْإِنْسَانُ"  # with tashkeel
print(f"Hebrew normalization: '{test_hebrew}' -> '{normalize_hebrew(test_hebrew)}'")
print(f"Arabic normalization: '{test_arabic}' -> '{normalize_arabic(test_arabic)}'")
print()

# ============================================================
# BOND AND HOHFELD DEFINITIONS
# ============================================================

class BondType(Enum):
    """Bond categories used as keys of the per-language bond pattern tables."""

    # Explicit values equal to the ones auto() assigns (1, 2, 3, ...).
    HARM_PREVENTION = 1
    RECIPROCITY = 2
    AUTONOMY = 3
    PROPERTY = 4
    FAMILY = 5
    AUTHORITY = 6
    CARE = 7
    FAIRNESS = 8
    CONTRACT = 9
    NONE = 10

class HohfeldState(Enum):
    """Hohfeld states used as keys of the per-language Hohfeld pattern tables."""

    # Explicit values equal to the ones auto() assigns (1, 2, 3, 4).
    OBLIGATION = 1
    RIGHT = 2
    LIBERTY = 3
    NO_RIGHT = 4
# ============================================================
# BOND PATTERNS BY LANGUAGE (using normalized forms)
# ============================================================
# Patterns should match NORMALIZED text (no diacritics, normalized letters)

# Regex patterns per language and bond type. They are searched in NORMALIZED
# text, so every pattern must itself be written in normalized form:
#   - Hebrew/Aramaic: no final letter forms (ך ם ן ף ץ are folded to כ מ נ פ צ)
#   - Arabic: plain alef only, heh instead of teh marbuta
# Patterns written with a final nun or a teh marbuta could never match.
ALL_BOND_PATTERNS = {
    'hebrew': {
        BondType.HARM_PREVENTION: [r'הרג', r'רצח', r'נזק', r'הכה', r'הציל', r'שמר', r'פקוח.נפש'],
        BondType.RECIPROCITY: [r'גמול', r'השיב', r'פרע', r'נתנ.*קבל', r'מדה.כנגד'],
        BondType.AUTONOMY: [r'בחר', r'רצונ', r'חפש', r'עצמ'],
        BondType.PROPERTY: [r'קנה', r'מכר', r'גזל', r'גנב', r'ממונ', r'נכס', r'ירש'],
        BondType.FAMILY: [r'אב', r'אמ', r'בנ', r'כבד.*אב', r'כבד.*אמ', r'משפחה', r'אח', r'אחות'],
        BondType.AUTHORITY: [r'מלכ', r'שופט', r'צוה', r'תורה', r'מצוה', r'דינ', r'חק'],
        BondType.CARE: [r'חסד', r'רחמ', r'עזר', r'תמכ', r'צדקה'],
        BondType.FAIRNESS: [r'צדק', r'משפט', r'ישר', r'שוה'],
        BondType.CONTRACT: [r'ברית', r'נדר', r'שבוע', r'התחיב', r'ערב'],
    },
    'aramaic': {
        BondType.HARM_PREVENTION: [r'קטל', r'נזק', r'חבל', r'שזיב', r'פצי'],
        BondType.RECIPROCITY: [r'פרע', r'שלמ', r'אגר'],
        BondType.AUTONOMY: [r'צבי', r'רעו'],
        BondType.PROPERTY: [r'זבנ', r'קנה', r'גזל', r'ממונא', r'נכסי'],
        BondType.FAMILY: [r'אבא', r'אמא', r'ברא', r'ברתא', r'יקר', r'אחא'],
        BondType.AUTHORITY: [r'מלכא', r'דינא', r'דיינא', r'פקודא', r'אורית'],
        BondType.CARE: [r'חסד', r'רחמ', r'סעד'],
        BondType.FAIRNESS: [r'דינא', r'קשוט', r'תריצ'],
        BondType.CONTRACT: [r'קימא', r'שבועה', r'נדרא', r'ערבא'],
    },
    'classical_chinese': {
        BondType.HARM_PREVENTION: [r'殺', r'害', r'傷', r'救', r'護', r'衛', r'暴'],
        BondType.RECIPROCITY: [r'報', r'還', r'償', r'酬', r'答'],
        BondType.AUTONOMY: [r'自', r'由', r'任', r'意', r'志'],
        BondType.PROPERTY: [r'財', r'物', r'產', r'盜', r'竊', r'賣', r'買'],
        BondType.FAMILY: [r'孝', r'父', r'母', r'親', r'子', r'弟', r'兄', r'家'],
        BondType.AUTHORITY: [r'君', r'臣', r'王', r'命', r'令', r'法', r'治'],
        BondType.CARE: [r'仁', r'愛', r'慈', r'惠', r'恩', r'憐'],
        BondType.FAIRNESS: [r'義', r'正', r'公', r'平', r'均'],
        BondType.CONTRACT: [r'約', r'盟', r'誓', r'諾', r'信'],
    },
    'arabic': {
        # Patterns for NORMALIZED Arabic (no diacritics)
        BondType.HARM_PREVENTION: [r'قتل', r'ضرر', r'اذ[يى]', r'ظلم', r'انقذ', r'حفظ', r'امان'],
        BondType.RECIPROCITY: [r'جزا', r'رد', r'قصاص', r'مثل', r'عوض'],
        BondType.AUTONOMY: [r'حر', r'اراده', r'اختيار', r'مشيئ'],
        BondType.PROPERTY: [r'مال', r'ملك', r'سرق', r'بيع', r'شرا', r'ميراث', r'غصب'],
        BondType.FAMILY: [r'والد', r'ابو', r'ام', r'ابن', r'بنت', r'اهل', r'قرب[يى]', r'رحم'],
        BondType.AUTHORITY: [r'طاع', r'امر', r'حكم', r'سلطان', r'خليف', r'امام', r'شريع'],
        BondType.CARE: [r'رحم', r'احسان', r'عطف', r'صدق', r'زكا'],
        BondType.FAIRNESS: [r'عدل', r'قسط', r'حق', r'انصاف', r'سو[يى]'],
        BondType.CONTRACT: [r'عهد', r'عقد', r'نذر', r'يمين', r'وفا', r'امان'],
    },
    'english': {
        BondType.HARM_PREVENTION: [r'\bkill', r'\bmurder', r'\bharm', r'\bhurt', r'\bsave', r'\bprotect', r'\bviolence'],
        BondType.RECIPROCITY: [r'\breturn', r'\brepay', r'\bexchange', r'\bgive.*back', r'\breciproc'],
        BondType.AUTONOMY: [r'\bfree', r'\bchoice', r'\bchoose', r'\bconsent', r'\bautonomy', r'\bright to'],
        BondType.PROPERTY: [r'\bsteal', r'\btheft', r'\bown', r'\bproperty', r'\bbelong', r'\binherit'],
        BondType.FAMILY: [r'\bfather', r'\bmother', r'\bparent', r'\bchild', r'\bfamily', r'\bhonor.*parent'],
        BondType.AUTHORITY: [r'\bobey', r'\bcommand', r'\bauthority', r'\blaw', r'\brule', r'\bgovern'],
        BondType.CARE: [r'\bcare', r'\bhelp', r'\bkind', r'\bcompassion', r'\bcharity', r'\bmercy'],
        BondType.FAIRNESS: [r'\bfair', r'\bjust', r'\bequal', r'\bequity', r'\bright\b'],
        BondType.CONTRACT: [r'\bpromise', r'\bcontract', r'\bagreem', r'\bvow', r'\boath', r'\bcommit'],
    },
}

# Regex patterns per language and Hohfeld state. Like the bond patterns they
# are searched in NORMALIZED text, so Hebrew/Aramaic patterns must not
# contain final letter forms: a pattern written with a final nun could never
# match, because the text's final nun has already been folded to a regular nun.
ALL_HOHFELD_PATTERNS = {
    'hebrew': {
        HohfeldState.OBLIGATION: [r'חייב', r'צריכ', r'מוכרח', r'מצווה'],
        HohfeldState.RIGHT: [r'זכות', r'רשאי', r'זכאי', r'מגיע'],
        HohfeldState.LIBERTY: [r'מותר', r'רשות', r'פטור', r'יכול'],
        HohfeldState.NO_RIGHT: [r'אסור', r'אינו רשאי', r'אינ.*זכות'],
    },
    'aramaic': {
        HohfeldState.OBLIGATION: [r'חייב', r'מחויב', r'בעי'],
        HohfeldState.RIGHT: [r'זכות', r'רשאי', r'זכי'],
        HohfeldState.LIBERTY: [r'שרי', r'מותר', r'פטור'],
        HohfeldState.NO_RIGHT: [r'אסור', r'לא.*רשאי'],
    },
    'classical_chinese': {
        HohfeldState.OBLIGATION: [r'必', r'須', r'當', r'應', r'宜'],
        HohfeldState.RIGHT: [r'可', r'得', r'權', r'宜'],
        HohfeldState.LIBERTY: [r'許', r'任', r'聽', r'免'],
        HohfeldState.NO_RIGHT: [r'不可', r'勿', r'禁', r'莫', r'非'],
    },
    'arabic': {
        HohfeldState.OBLIGATION: [r'يجب', r'واجب', r'فرض', r'لازم', r'وجوب'],
        HohfeldState.RIGHT: [r'حق', r'يحق', r'جائز', r'يجوز'],
        HohfeldState.LIBERTY: [r'مباح', r'حلال', r'جائز', r'اباح'],
        HohfeldState.NO_RIGHT: [r'حرام', r'محرم', r'ممنوع', r'لا يجوز', r'نه[يى]'],
    },
    'english': {
        HohfeldState.OBLIGATION: [r'\bmust\b', r'\bshall\b', r'\bobligat', r'\bduty', r'\brequir'],
        HohfeldState.RIGHT: [r'\bright\b', r'\bentitle', r'\bdeserve', r'\bclaim'],
        HohfeldState.LIBERTY: [r'\bmay\b', r'\bpermit', r'\ballow', r'\bfree to'],
        HohfeldState.NO_RIGHT: [r'\bforbid', r'\bprohibit', r'\bmust not', r'\bshall not'],
    },
}

print("Native patterns defined for 5 languages:")
for lang in ALL_BOND_PATTERNS:
    # Total number of regexes across all categories of this language.
    n_bond = sum(map(len, ALL_BOND_PATTERNS[lang].values()))
    n_hohfeld = sum(map(len, ALL_HOHFELD_PATTERNS.get(lang, {}).values()))
    print(f"  {lang:20s}: {n_bond:3d} bond patterns, {n_hohfeld:2d} Hohfeld patterns")

print()
print("\n".join([
    "PATTERNS USE NORMALIZED TEXT (no diacritics)!",
    "",
    "✓ Cell 3 complete",
]))


In [None]:
#@title 4. Load Corpora and Extract Bonds (Native Patterns) { display-mode: "form" }
#@markdown Labels extracted using NATIVE patterns with proper normalization.
#@markdown **Includes label quality metrics for reviewer concerns.**

import json
import hashlib
import re
import gc
import random
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from collections import defaultdict

mark_task("Extract bonds (native)", "running")

print("\n".join([
    "=" * 60,
    "LOADING CORPORA & EXTRACTING BONDS (NATIVE)",
    "=" * 60,
    "",
    "CRITICAL: Labels extracted from NATIVE text using NATIVE patterns!",
    "Text is NORMALIZED before pattern matching (no diacritics).",
]))
print()

@dataclass
class Passage:
    """A single corpus passage together with its metadata."""
    id: str
    text: str  # Original text
    text_normalized: str  # Normalized for pattern matching
    language: str
    time_period: str
    century: int
    source: str
    source_type: str

    def to_dict(self):
        """Return the fields as a dict, leaving out ``text_normalized``.

        The normalized text is not saved because it can be regenerated.
        """
        return {
            name: value
            for name, value in asdict(self).items()
            if name != 'text_normalized'
        }

def detect_sefaria_language(text: str, category: str) -> str:
    """Guess the language of a Sefaria text from its category and script.

    Returns 'aramaic' for the Talmud/Bavli/Yerushalmi/Zohar categories,
    'arabic' when more than 30% of the characters fall in U+0600..U+06FF,
    and 'hebrew' otherwise.
    """
    if category in {'Talmud', 'Bavli', 'Yerushalmi', 'Zohar'}:
        return 'aramaic'

    n_arabic = len([ch for ch in text if '\u0600' <= ch <= '\u06FF'])
    mostly_arabic = n_arabic > len(text) * 0.3
    return 'arabic' if mostly_arabic else 'hebrew'

# ============================================================
# BOND EXTRACTION WITH PATTERN TRACKING
# ============================================================

# Track which patterns fire (for quality analysis)
def _new_pattern_counter():
    """Innermost level: pattern -> count."""
    return defaultdict(int)


def _new_bond_table():
    """Middle level: bond -> pattern -> count."""
    return defaultdict(_new_pattern_counter)


# lang -> bond -> pattern -> count
pattern_hits = defaultdict(_new_bond_table)
# lang -> list of (text, bond, pattern) for manual audit
audit_samples = defaultdict(list)

def extract_bonds_native(text: str, text_normalized: str, language: str) -> Dict:
    """Extract bonds using NATIVE language patterns on NORMALIZED text."""
    bond_patterns = ALL_BOND_PATTERNS.get(language, {})
    hohfeld_patterns = ALL_HOHFELD_PATTERNS.get(language, {})
    
    # Find matching bonds (on normalized text)
    found_bonds = []
    matched_patterns = []
    
    for bond_type, patterns in bond_patterns.items():
        for pattern in patterns:
            if re.search(pattern, text_normalized):
                found_bonds.append(bond_type.name)
                matched_patterns.append((bond_type.name, pattern))
                # Track pattern hits
                pattern_hits[language][bond_type.name][pattern] += 1
                break  # One pattern per bond type is enough
    
    if not found_bonds:
        found_bonds = ['NONE']
        matched_patterns = [('NONE', None)]
    
    # Find Hohfeld state
    hohfeld = None
    hohfeld_pattern = None
    for state, patterns in hohfeld_patterns.items():
        for pattern in patterns:
            if re.search(pattern, text_normalized):
                hohfeld = state.name
                hohfeld_pattern = pattern
                break
        if hohfeld:
            break
    
    return {
        'primary_bond': found_bonds[0],
        'all_bonds': found_bonds,
        'hohfeld': hohfeld,
        'language': language,
        'matched_patterns': matched_patterns,  # For audit
        'hohfeld_pattern': hohfeld_pattern,
    }

# Time period mappings
# Sefaria category name -> period label. The Sefaria loader looks up the first
# path component of each file here and falls back to 'AMORAIC' when it is missing.
CATEGORY_TO_PERIOD = {
    'Tanakh': 'BIBLICAL', 'Torah': 'BIBLICAL', 'Prophets': 'BIBLICAL', 'Writings': 'BIBLICAL',
    'Mishnah': 'TANNAITIC', 'Tosefta': 'TANNAITIC',
    'Talmud': 'AMORAIC', 'Bavli': 'AMORAIC', 'Yerushalmi': 'AMORAIC', 'Midrash': 'AMORAIC',
    'Halakhah': 'RISHONIM', 'Kabbalah': 'RISHONIM', 'Philosophy': 'RISHONIM',
    'Chasidut': 'ACHRONIM', 'Musar': 'ACHRONIM', 'Responsa': 'ACHRONIM',
}

# Period label -> representative century.
# NOTE(review): negative values presumably denote centuries BCE - confirm.
PERIOD_TO_CENTURY = {
    'BIBLICAL': -6, 'TANNAITIC': 2, 'AMORAIC': 4, 'RISHONIM': 12, 'ACHRONIM': 17,
    'CONFUCIAN': -4, 'DAOIST': -5, 'QURANIC': 7, 'HADITH': 9, 'DEAR_ABBY': 20,
}

# Accumulates Passage objects from every corpus loaded in this cell.
all_passages = []

# ========== LOAD SEFARIA ==========
print("Loading Sefaria (Hebrew/Aramaic)...")
sefaria_path = Path('data/raw/Sefaria-Export/json')
if sefaria_path.exists():
    json_files = list(sefaria_path.rglob('*.json'))
    print(f"  Found {len(json_files):,} files")

    for json_file in tqdm(json_files, desc="Sefaria", unit="file"):
        # Best effort: unreadable or malformed files are skipped. Only I/O and
        # decode/parse failures are caught (JSONDecodeError and
        # UnicodeDecodeError are ValueError subclasses), so KeyboardInterrupt
        # and genuine bugs are no longer swallowed.
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            continue

        # The first path component under the export root is used as the category.
        rel_parts = json_file.relative_to(sefaria_path).parts
        category = str(rel_parts[0]) if rel_parts else "unknown"
        period = CATEGORY_TO_PERIOD.get(category, 'AMORAIC')
        century = PERIOD_TO_CENTURY.get(period, 4)

        if isinstance(data, dict):
            hebrew = data.get('he', [])

            def flatten_hebrew(h, ref=""):
                """Recursively turn nested lists of strings into Passage objects.

                ``ref`` is the dotted 1-based position of ``h`` inside the
                nesting. Strings are stripped of HTML-like tags and kept only
                when 20..2000 characters long; anything that is neither a
                string nor a list yields no passages.
                """
                if isinstance(h, str):
                    h_clean = re.sub(r'<[^>]+>', '', h).strip()
                    if 20 <= len(h_clean) <= 2000:
                        lang = detect_sefaria_language(h_clean, category)
                        # Short ID derived from file stem, position and text prefix
                        pid = hashlib.md5(f"{json_file.stem}:{ref}:{h_clean[:30]}".encode()).hexdigest()[:12]
                        # NORMALIZE TEXT
                        h_normalized = normalize_text(h_clean, lang)
                        return [Passage(
                            id=f"sefaria_{pid}",
                            text=h_clean,
                            text_normalized=h_normalized,
                            language=lang,
                            time_period=period,
                            century=century,
                            source=f"{json_file.stem} {ref}".strip(),
                            source_type='sefaria'
                        )]
                    return []
                elif isinstance(h, list):
                    result = []
                    for i, hh in enumerate(h):
                        result.extend(flatten_hebrew(hh, f"{ref}.{i+1}" if ref else str(i+1)))
                    return result
                return []

            all_passages.extend(flatten_hebrew(hebrew))
else:
    # Make a missing corpus visible instead of silently reporting zero passages.
    print(f"  WARNING: {sefaria_path} not found - no Sefaria passages loaded")

print(f"  Loaded {len(all_passages):,} Sefaria passages")

# ========== LOAD CHINESE ==========
# Reads a JSON list of records; each record must provide the keys
# 'id', 'text', 'period', 'century' and 'source' (all accessed below).
print("\nLoading Chinese classics...")
with open('data/raw/chinese/chinese_native.json', 'r', encoding='utf-8') as f:
    chinese_data = json.load(f)

for item in chinese_data:
    # Normalize with the 'classical_chinese' rules before pattern matching
    text_norm = normalize_text(item['text'], 'classical_chinese')
    all_passages.append(Passage(
        id=item['id'],
        text=item['text'],
        text_normalized=text_norm,
        language='classical_chinese',
        time_period=item['period'],
        century=item['century'],
        source=item['source'],
        source_type='chinese'
    ))
print(f"  Loaded {len(chinese_data)} Chinese passages")

# ========== LOAD ISLAMIC ==========
# Reads a JSON list of records; each record must provide the keys
# 'id', 'text', 'period', 'century' and 'source' (all accessed below).
print("\nLoading Islamic texts...")
with open('data/raw/islamic/islamic_native.json', 'r', encoding='utf-8') as f:
    islamic_data = json.load(f)

for item in islamic_data:
    # Normalize with the 'arabic' rules before pattern matching
    text_norm = normalize_text(item['text'], 'arabic')
    all_passages.append(Passage(
        id=item['id'],
        text=item['text'],
        text_normalized=text_norm,
        language='arabic',
        time_period=item['period'],
        century=item['century'],
        source=item['source'],
        source_type='islamic'
    ))
print(f"  Loaded {len(islamic_data)} Arabic passages")

# ========== LOAD DEAR ABBY ==========
print("\nLoading Dear Abby (English)...")
df = pd.read_csv('data/raw/dear_abby.csv')
abby_count = 0
for idx, row in tqdm(df.iterrows(), total=len(df), desc="Dear Abby", unit="row"):
    # Missing questions arrive as NaN, which str() turns into 'nan'
    question = str(row.get('question_only', ''))
    if not question or question == 'nan' or len(question) < 50 or len(question) > 2000:
        continue

    # A missing (NaN) or non-numeric year falls back to 1990, the same default
    # used when the column is absent, instead of aborting the whole load.
    try:
        year = int(row.get('year', 1990))
    except (TypeError, ValueError, OverflowError):
        year = 1990

    # Short ID derived from the row index and the first 50 characters
    pid = hashlib.md5(f"abby:{idx}:{question[:50]}".encode()).hexdigest()[:12]
    text_norm = normalize_text(question, 'english')

    all_passages.append(Passage(
        id=f"abby_{pid}",
        text=question,
        text_normalized=text_norm,
        language='english',
        time_period='DEAR_ABBY',
        century=20 if year < 2000 else 21,
        source=f"Dear Abby {year}",
        source_type='dear_abby'
    ))
    abby_count += 1
print(f"  Loaded {abby_count:,} English passages")

print()
print("="*60)
print(f"TOTAL PASSAGES: {len(all_passages):,}")
print("="*60)

# ========== EXTRACT BONDS WITH QUALITY TRACKING ==========
# Labels every passage and writes two line-aligned JSONL files: line i of
# passages.jsonl and line i of bonds.jsonl describe the same passage.
print()
print("="*60)
print("EXTRACTING BONDS (NATIVE PATTERNS ON NORMALIZED TEXT)")
print("="*60)

# lang -> primary bond -> number of passages
bond_counts = defaultdict(lambda: defaultdict(int))
# NOTE(review): initialized here but not referenced again in this cell.
n_mismatches = 0

# Collect audit samples (200 per language)
# NOTE(review): the samples kept below are the FIRST 200 passages seen per
# language, not a random draw - confirm that is acceptable for the audit.
AUDIT_SAMPLE_SIZE = 200

with open('data/processed/passages.jsonl', 'w') as f_pass, \
     open('data/processed/bonds.jsonl', 'w') as f_bond:
    
    for p in tqdm(all_passages, desc="Extracting", unit="passage"):
        # NATIVE extraction on NORMALIZED text
        bonds = extract_bonds_native(p.text, p.text_normalized, p.language)
        
        bond_counts[p.language][bonds['primary_bond']] += 1
        
        # Collect audit samples
        if len(audit_samples[p.language]) < AUDIT_SAMPLE_SIZE:
            audit_samples[p.language].append({
                'text': p.text[:200],
                'bond': bonds['primary_bond'],
                # Pattern that produced the primary bond (None for 'NONE')
                'pattern': bonds['matched_patterns'][0][1] if bonds['matched_patterns'] else None,
            })
        
        # One JSON object per line; the matched patterns are not persisted here
        f_pass.write(json.dumps(p.to_dict()) + '\n')
        f_bond.write(json.dumps({
            'passage_id': p.id,
            'bonds': {
                'primary_bond': bonds['primary_bond'],
                'all_bonds': bonds['all_bonds'],
                'hohfeld': bonds['hohfeld'],
                'language': bonds['language'],
            }
        }) + '\n')

# ============================================================
# LABEL QUALITY METRICS (Reviewer Concern #1)
# ============================================================
# Reports coverage, the most frequent patterns and the bond distribution per
# language, and saves the audit samples and pattern statistics to disk.
print()
print("="*60)
print("LABEL QUALITY METRICS")
print("="*60)

# 1. Coverage (% that are NOT 'NONE')
print("\n1. COVERAGE (% with labels, not NONE):")
for lang in sorted(bond_counts.keys()):
    total = sum(bond_counts[lang].values())
    none_count = bond_counts[lang].get('NONE', 0)
    coverage = (total - none_count) / total * 100 if total > 0 else 0
    print(f"   {lang:20s}: {coverage:5.1f}% labeled ({total - none_count:,}/{total:,})")

# 2. Top triggering patterns per class per language
print("\n2. TOP TRIGGERING PATTERNS:")
for lang in sorted(pattern_hits.keys()):
    print(f"\n   {lang.upper()}:")
    for bond in sorted(pattern_hits[lang].keys()):
        if bond == 'NONE':
            continue
        patterns = pattern_hits[lang][bond]
        # Three most frequent patterns for this bond
        top = sorted(patterns.items(), key=lambda x: -x[1])[:3]
        if top:
            top_str = ", ".join([f"'{p}' ({c})" for p, c in top])
            print(f"     {bond:20s}: {top_str}")

# 3. Save audit samples for manual review
print("\n3. AUDIT SAMPLES (saved for manual review):")
audit_file = 'data/processed/audit_samples.json'
with open(audit_file, 'w', encoding='utf-8') as f:
    json.dump(dict(audit_samples), f, indent=2, ensure_ascii=False)
print(f"   Saved {sum(len(v) for v in audit_samples.values())} samples to {audit_file}")
for lang, samples in audit_samples.items():
    print(f"   {lang:20s}: {len(samples)} samples")

# 4. Bond distribution summary
print("\n4. BOND DISTRIBUTION BY LANGUAGE:")
for lang in sorted(bond_counts.keys()):
    print(f"\n  {lang.upper()}:")
    total = sum(bond_counts[lang].values())
    # Five most frequent primary bonds
    for bond, cnt in sorted(bond_counts[lang].items(), key=lambda x: -x[1])[:5]:
        pct = 100 * cnt / total
        print(f"    {bond:20s}: {cnt:>6,} ({pct:5.1f}%)")

# Save pattern statistics
# Convert the nested defaultdicts to plain dicts for serialization.
pattern_stats = {
    lang: {bond: dict(patterns) for bond, patterns in bonds.items()}
    for lang, bonds in pattern_hits.items()
}
# Written as UTF-8 without ASCII escaping, like audit_samples.json above, so
# the native-script patterns stay human-readable; the parsed JSON is unchanged.
with open('data/processed/pattern_stats.json', 'w', encoding='utf-8') as f:
    json.dump(pattern_stats, f, indent=2, ensure_ascii=False)
print(f"\nPattern statistics saved to data/processed/pattern_stats.json")

# Keep the count, then release the in-memory passages
n_passages = len(all_passages)
del all_passages
gc.collect()

print_resources("After extraction")
mark_task("Extract bonds (native)", "done")


In [None]:
#@title 5. Generate Splits { display-mode: "form" }
#@markdown Language-family and temporal splits.

import random
import json
import shutil
from collections import defaultdict
from tqdm.auto import tqdm

# Fixed seed so the shuffles that build the splits below are reproducible.
random.seed(42)

mark_task("Generate splits", "running")

print("="*60)
print("GENERATING SPLITS")
print("="*60)

# Read metadata
# One JSON object per line, as written by the extraction cell.
passage_meta = []
with open('data/processed/passages.jsonl', 'r') as f:
    for line in tqdm(f, desc="Reading", unit="line"):
        p = json.loads(line)
        passage_meta.append(p)

print(f"Total: {len(passage_meta):,}")

# language -> list of passage IDs, in file order
by_lang = defaultdict(list)
for p in passage_meta:
    by_lang[p['language']].append(p['id'])

# Largest language first
print("\nBy language:")
for lang, ids in sorted(by_lang.items(), key=lambda x: -len(x[1])):
    print(f"  {lang:20s}: {len(ids):>8,}")

def _build_split(name, train_ids, valid_ids, test_ids):
    """Assemble one split record.

    The ``*_size`` fields are taken from the ID lists themselves, so they
    always agree with what is actually stored in the split.
    """
    return {
        'name': name,
        'train_ids': train_ids,
        'valid_ids': valid_ids,
        'test_ids': test_ids,
        'train_size': len(train_ids),
        'valid_size': len(valid_ids),
        'test_size': len(test_ids),
    }

# NOTE: the order of the random.shuffle calls below is part of the seeded,
# reproducible behavior - do not reorder them.

# ========== SPLIT 1: Hebrew -> All Others ==========
print("\n" + "-"*60)
print("SPLIT 1: HEBREW → ALL OTHERS")
print("-"*60)

# hebrew_ids aliases by_lang['hebrew'], so the shuffle below also reorders
# the list that split 2 concatenates later.
hebrew_ids = by_lang['hebrew']
other_ids = [p['id'] for p in passage_meta if p['language'] != 'hebrew']
random.shuffle(hebrew_ids)
random.shuffle(other_ids)

# Validation: first 10% of the held-out IDs, capped at 5000
n_valid = min(5000, len(other_ids)//10)
split_hebrew = _build_split('hebrew_to_others', hebrew_ids,
                            other_ids[:n_valid], other_ids[n_valid:])
print(f"  Train (Hebrew): {split_hebrew['train_size']:,}")
print(f"  Test (Others):  {split_hebrew['test_size']:,}")

# ========== SPLIT 2: Semitic -> East Asian + English ==========
print("\n" + "-"*60)
print("SPLIT 2: SEMITIC → CHINESE + ENGLISH")
print("-"*60)

semitic_ids = by_lang['hebrew'] + by_lang.get('aramaic', []) + by_lang.get('arabic', [])
non_semitic_ids = by_lang.get('classical_chinese', []) + by_lang.get('english', [])
random.shuffle(semitic_ids)
random.shuffle(non_semitic_ids)

n_valid = len(non_semitic_ids)//10
split_semitic = _build_split('semitic_to_non_semitic', semitic_ids,
                             non_semitic_ids[:n_valid], non_semitic_ids[n_valid:])
print(f"  Train (Semitic): {split_semitic['train_size']:,}")
print(f"  Test (Non-Semitic): {split_semitic['test_size']:,}")

# ========== SPLIT 3: Ancient -> Modern ==========
print("\n" + "-"*60)
print("SPLIT 3: ANCIENT → MODERN")
print("-"*60)

ancient_periods = {'BIBLICAL', 'TANNAITIC', 'AMORAIC', 'CONFUCIAN', 'DAOIST', 'QURANIC', 'HADITH'}
modern_periods = {'RISHONIM', 'ACHRONIM', 'DEAR_ABBY'}

ancient_ids = [p['id'] for p in passage_meta if p['time_period'] in ancient_periods]
modern_ids = [p['id'] for p in passage_meta if p['time_period'] in modern_periods]
random.shuffle(ancient_ids)
random.shuffle(modern_ids)

n_valid = len(modern_ids)//10
split_temporal = _build_split('ancient_to_modern', ancient_ids,
                              modern_ids[:n_valid], modern_ids[n_valid:])
print(f"  Train (Ancient): {split_temporal['train_size']:,}")
print(f"  Test (Modern): {split_temporal['test_size']:,}")

# ========== SPLIT 4: Mixed Baseline ==========
print("\n" + "-"*60)
print("SPLIT 4: MIXED (In-Domain Baseline)")
print("-"*60)

all_ids = [p['id'] for p in passage_meta]
random.shuffle(all_ids)
n = len(all_ids)

# 70% / 15% / 15% cut points. The valid slice can hold one more item than
# int(0.15*n) (e.g. n=13), which is why the sizes are measured, not recomputed.
train_end = int(0.7*n)
valid_end = int(0.85*n)
split_mixed = _build_split('mixed_baseline', all_ids[:train_end],
                           all_ids[train_end:valid_end], all_ids[valid_end:])
print(f"  Train: {split_mixed['train_size']:,}")
print(f"  Test: {split_mixed['test_size']:,}")

# ========== SPLIT 5: Chinese -> All Others ==========
print("\n" + "-"*60)
print("SPLIT 5: CHINESE → ALL OTHERS")
print("-"*60)

chinese_ids = by_lang.get('classical_chinese', [])
non_chinese_ids = [p['id'] for p in passage_meta if p['language'] != 'classical_chinese']
random.shuffle(non_chinese_ids)

# Validation is capped at 500 (or the Chinese set size), test at 10000 IDs
n_valid = min(len(chinese_ids), 500)
split_chinese = _build_split('chinese_to_others', chinese_ids,
                             non_chinese_ids[:n_valid],
                             non_chinese_ids[n_valid:n_valid+10000])
print(f"  Train (Chinese): {split_chinese['train_size']:,}")
print(f"  Test (Others): {split_chinese['test_size']:,}")

# Save
# All five splits go into one JSON file keyed by split name; later cells read
# individual splits back from it.
all_splits = {
    'hebrew_to_others': split_hebrew,
    'semitic_to_non_semitic': split_semitic,
    'ancient_to_modern': split_temporal,
    'mixed_baseline': split_mixed,
    'chinese_to_others': split_chinese,
}

with open('data/splits/all_splits.json', 'w') as f:
    json.dump(all_splits, f, indent=2)

# Compute baselines
# Walks bonds.jsonl and passages.jsonl in lockstep, verifies that they describe
# the same passages line by line, and derives uniform-chance rates
# (1 / number of distinct classes observed).
from itertools import zip_longest  # local import: exposes files of unequal length

bond_counts = defaultdict(int)
lang_counts = defaultdict(int)
period_counts = defaultdict(int)

with open('data/processed/bonds.jsonl', 'r') as fb, \
     open('data/processed/passages.jsonl', 'r') as fp:
    # zip_longest pads the shorter file with None; a plain zip would stop
    # early and report a truncated file as consistent.
    for line_no, (b_line, p_line) in enumerate(zip_longest(fb, fp), start=1):
        if b_line is None or p_line is None:
            raise ValueError(
                f"bonds.jsonl and passages.jsonl differ in length "
                f"(first unmatched line: {line_no})"
            )
        b = json.loads(b_line)
        p = json.loads(p_line)
        # Explicit check instead of assert, which is stripped under -O
        if b['passage_id'] != p['id']:
            raise ValueError(
                f"ID mismatch at line {line_no}: "
                f"bond {b['passage_id']!r} vs passage {p['id']!r}"
            )
        bond_counts[b['bonds']['primary_bond']] += 1
        lang_counts[p['language']] += 1
        period_counts[p['time_period']] += 1

if not bond_counts:
    # Fail with a clear message rather than a ZeroDivisionError below
    raise ValueError("No passages found in data/processed - cannot compute baselines")

print("\nID integrity: PASSED ✓")

baselines = {
    'bond_counts': dict(bond_counts),
    'language_counts': dict(lang_counts),
    'period_counts': dict(period_counts),
    'chance_bond': 1.0 / len(bond_counts),
    'chance_language': 1.0 / len(lang_counts),
    'chance_period': 1.0 / len(period_counts),
}

with open('data/splits/baselines.json', 'w') as f:
    json.dump(baselines, f, indent=2)

print(f"\nChance bond:     {baselines['chance_bond']:.1%}")
print(f"Chance language: {baselines['chance_language']:.1%}")
print(f"Chance period:   {baselines['chance_period']:.1%}")

# Save to Drive
# Mirror the processed data and splits under SAVE_DIR (defined in an earlier
# cell); dirs_exist_ok=True lets a re-run write into existing directories.
shutil.copytree('data/processed', f'{SAVE_DIR}/processed', dirs_exist_ok=True)
shutil.copytree('data/splits', f'{SAVE_DIR}/splits', dirs_exist_ok=True)

mark_task("Generate splits", "done")


In [None]:
#@title 6. Define Model Architecture { display-mode: "form" }
#@markdown Multilingual encoder with adversarial language/time heads.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from tqdm.auto import tqdm
import json
import gc

print("="*60)
print("MODEL ARCHITECTURE")
print("="*60)

# Release Python garbage and cached GPU memory before the model is defined.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

class GradientReversal(torch.autograd.Function):
    """Autograd function: copies the input forward, negates and scales the gradient backward."""

    @staticmethod
    def forward(ctx, x, lambda_):
        # Stash the scale so backward() can apply it.
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output * (-ctx.lambda_)
        # The second input (lambda_) receives no gradient.
        return reversed_grad, None


def gradient_reversal(x, lambda_=1.0):
    """Return a copy of ``x`` whose incoming gradient is multiplied by ``-lambda_``."""
    return GradientReversal.apply(x, lambda_)

class BIPModel(nn.Module):
    """Pretrained multilingual encoder with a bond projection and four linear heads.

    The bond and Hohfeld heads read the projected representation ``z_bond``
    directly. The language and period heads read it through a gradient
    reversal, so their gradients reach the projection and encoder negated.
    """

    def __init__(self, d_model=384, d_bond=64, n_bonds=10, n_langs=5, n_periods=10, n_hohfeld=4):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

        # Two-layer projection from the pooled encoder output to bond space
        half_width = d_model // 2
        self.bond_proj = nn.Sequential(
            nn.Linear(d_model, half_width),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(half_width, d_bond),
        )

        # Task heads on z_bond
        self.bond_classifier = nn.Linear(d_bond, n_bonds)
        self.hohfeld_classifier = nn.Linear(d_bond, n_hohfeld)

        # Adversarial heads: predict the confounds (language, period) from z_bond
        self.language_classifier = nn.Linear(d_bond, n_langs)
        self.period_classifier = nn.Linear(d_bond, n_periods)

    def _encode(self, input_ids, attention_mask):
        """Encode, mean-pool over the attention mask, and project to bond space."""
        hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        token_weights = attention_mask.unsqueeze(-1)
        # Clamp so a row with an all-zero mask does not divide by zero
        token_counts = attention_mask.sum(-1, keepdim=True).clamp(min=1e-9)
        pooled = (hidden_states * token_weights).sum(1) / token_counts
        return self.bond_proj(pooled)

    def forward(self, input_ids, attention_mask, adv_lambda=1.0):
        """Return a dict with ``z_bond`` and the logits of all four heads.

        ``adv_lambda`` scales the reversed gradient fed back from the
        language and period heads.
        """
        z_bond = self._encode(input_ids, attention_mask)
        z_bond_adv = gradient_reversal(z_bond, adv_lambda)

        return {
            'z_bond': z_bond,
            'bond_pred': self.bond_classifier(z_bond),
            'hohfeld_pred': self.hohfeld_classifier(z_bond),
            'language_pred': self.language_classifier(z_bond_adv),
            'period_pred': self.period_classifier(z_bond_adv),
        }

    def extract_z_bond(self, input_ids, attention_mask):
        """Compute ``z_bond`` with gradient tracking disabled."""
        with torch.no_grad():
            return self._encode(input_ids, attention_mask)

# Index mappings
# Label string -> integer class index used by the dataset and the model heads.
# BondType is defined in an earlier cell; its members are numbered in
# iteration order.
BOND_TO_IDX = {b.name: i for i, b in enumerate(BondType)}
LANG_TO_IDX = {'hebrew': 0, 'aramaic': 1, 'classical_chinese': 2, 'arabic': 3, 'english': 4}
PERIOD_TO_IDX = {'BIBLICAL': 0, 'TANNAITIC': 1, 'AMORAIC': 2, 'RISHONIM': 3, 'ACHRONIM': 4,
                 'CONFUCIAN': 5, 'DAOIST': 6, 'QURANIC': 7, 'HADITH': 8, 'DEAR_ABBY': 9}
# A missing Hohfeld state (None) gets its own class index
HOHFELD_TO_IDX = {'OBLIGATION': 0, 'RIGHT': 1, 'LIBERTY': 2, None: 3}

class NativeDataset(Dataset):
    """Dataset using NATIVE text only.

    Reads ``passages_file`` and ``bonds_file`` in lockstep (line i of each
    must describe the same passage) and keeps the passages whose ID is in
    ``passage_ids``. Line pairs whose IDs disagree are skipped, counted in
    ``self.n_mismatches`` and reported once loading finishes.
    """
    def __init__(self, passage_ids: set, passages_file: str, bonds_file: str, tokenizer, max_len=128):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.data = []
        self.n_mismatches = 0  # Track ID mismatches

        with open(passages_file, 'r') as fp, open(bonds_file, 'r') as fb:
            for p_line, b_line in tqdm(zip(fp, fb), desc="Loading", unit="line"):
                p = json.loads(p_line)
                b = json.loads(b_line)
                # DATA INTEGRITY CHECK (Reviewer concern: silent data drops)
                # Mismatched pairs are counted rather than asserted on; the
                # total is reported after the loop.
                if b['passage_id'] != p['id']:
                    self.n_mismatches += 1
                    continue
                if p['id'] in passage_ids:
                    self.data.append({
                        'text': p['text'][:1000],  # NATIVE text!
                        'language': p['language'],
                        'period': p['time_period'],
                        'bond': b['bonds']['primary_bond'],
                        'hohfeld': b['bonds']['hohfeld'],
                    })
        print(f"  Loaded {len(self.data):,} samples")
        if self.n_mismatches:
            # Surface dropped lines so they are never lost silently
            print(f"  WARNING: skipped {self.n_mismatches:,} lines with mismatched passage/bond IDs")

    def __len__(self):
        """Number of retained samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and attach its integer labels.

        Text is truncated/padded to ``max_len`` tokens. Labels missing from
        the index mappings fall back to fixed indices (bond 9, language 4,
        period 9, Hohfeld 3).
        """
        item = self.data[idx]
        enc = self.tokenizer(item['text'], truncation=True, max_length=self.max_len,
                            padding='max_length', return_tensors='pt')
        return {
            # squeeze(0) drops the batch dimension added by return_tensors='pt'
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            'bond_label': BOND_TO_IDX.get(item['bond'], 9),
            'language_label': LANG_TO_IDX.get(item['language'], 4),
            'period_label': PERIOD_TO_IDX.get(item['period'], 9),
            'hohfeld_label': HOHFELD_TO_IDX.get(item['hohfeld'], 3),
            'language': item['language'],
        }

def collate_fn(batch):
    """Merge a list of per-sample dicts into a single batch dict.

    Tensor fields are stacked along a new first dimension, integer labels
    become 1-D tensors, and the language names stay a plain list.
    """
    def stacked(key):
        return torch.stack([sample[key] for sample in batch])

    def as_tensor(key):
        return torch.tensor([sample[key] for sample in batch])

    return {
        'input_ids': stacked('input_ids'),
        'attention_mask': stacked('attention_mask'),
        'bond_labels': as_tensor('bond_label'),
        'language_labels': as_tensor('language_label'),
        'period_labels': as_tensor('period_label'),
        'hohfeld_labels': as_tensor('hohfeld_label'),
        'languages': [sample['language'] for sample in batch],
    }

# Summary of label-space sizes and the base batch size
# (BASE_BATCH_SIZE is defined in an earlier cell).
print(f"Bond types: {len(BOND_TO_IDX)}")
print(f"Languages: {len(LANG_TO_IDX)}")
print(f"Periods: {len(PERIOD_TO_IDX)}")
print(f"Batch size: {BASE_BATCH_SIZE}")


In [None]:
#@title 7. Train BIP Model { display-mode: "form" }
#@markdown Training with adversarial language/period invariance.

#@markdown **Select splits:**
TRAIN_HEBREW_TO_OTHERS = True  #@param {type:"boolean"}
TRAIN_SEMITIC_TO_NON_SEMITIC = True  #@param {type:"boolean"}
TRAIN_ANCIENT_TO_MODERN = True  #@param {type:"boolean"}
TRAIN_MIXED_BASELINE = True  #@param {type:"boolean"}

import time
import gc
from sklearn.metrics import f1_score
from transformers import AutoTokenizer

mark_task("Train BIP model", "running")

print("="*60)
print("TRAINING BIP MODEL (Native Patterns)")
print("="*60)
print()
print("CRITICAL: Model sees NATIVE text, labels from NATIVE patterns.")
print("NO English translation used anywhere!")
print()

# Tokenizer for the same checkpoint that BIPModel loads as its encoder
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# Translate the form toggles above into the list of split names to train.
# NOTE(review): there is no toggle for the 'chinese_to_others' split, so it is
# never trained by this cell - confirm that is intended.
splits_to_train = []
if TRAIN_HEBREW_TO_OTHERS:
    splits_to_train.append('hebrew_to_others')
if TRAIN_SEMITIC_TO_NON_SEMITIC:
    splits_to_train.append('semitic_to_non_semitic')
if TRAIN_ANCIENT_TO_MODERN:
    splits_to_train.append('ancient_to_modern')
if TRAIN_MIXED_BASELINE:
    splits_to_train.append('mixed_baseline')

print(f"Training {len(splits_to_train)} splits")

# split name -> metrics dict, filled in by the training loop
all_results = {}

# Train one fresh model per selected split, keep the checkpoint with the
# lowest average training loss, then evaluate that checkpoint on the split's
# test IDs and record the metrics in all_results.
for split_idx, split_name in enumerate(splits_to_train):
    split_start = time.time()
    print()
    print("="*60)
    print(f"[{split_idx+1}/{len(splits_to_train)}] {split_name}")
    print("="*60)

    with open('data/splits/all_splits.json', 'r') as f:
        split = json.load(f)[split_name]

    print(f"Train: {split['train_size']:,} | Test: {split['test_size']:,}")

    model = BIPModel().to(device)

    train_dataset = NativeDataset(set(split['train_ids']), 'data/processed/passages.jsonl',
                                   'data/processed/bonds.jsonl', tokenizer)
    test_dataset = NativeDataset(set(split['test_ids']), 'data/processed/passages.jsonl',
                                  'data/processed/bonds.jsonl', tokenizer)

    if len(train_dataset) == 0:
        print("ERROR: No data!")
        continue

    # Adjust for GPU memory
    if torch.cuda.is_available() and 'T4' in gpu_name:
        batch_size = min(192, max(32, len(train_dataset) // 20))  # Conservative for T4
        GRAD_ACCUM_STEPS = 2  # Simulate larger batch with accumulation
    else:
        batch_size = min(BASE_BATCH_SIZE, max(32, len(train_dataset) // 20))
        GRAD_ACCUM_STEPS = 1

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn, drop_last=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size*2, shuffle=False,
                             collate_fn=collate_fn, num_workers=4, pin_memory=True)

    if len(train_loader) == 0:
        # drop_last=True discards the only (partial) batch when the training
        # set is smaller than batch_size; nothing could be trained.
        print("ERROR: Training set is smaller than one batch!")
        continue

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

    n_epochs = 5

    # ADVERSARIAL WARMUP SCHEDULE (Reviewer concern #2)
    def get_adv_lambda(epoch, n_epochs, warmup_epochs=2):
        """Warmup adversarial loss to avoid early instability.

        Linear in ``epoch`` up to ``warmup_epochs``, then constant 1.0.
        NOTE(review): epochs are 1-based in the loop below, so the values
        actually used are 0.55 (epoch 1) and 1.0 (epoch 2 onward); the 0.1
        starting point is never reached. Confirm whether that is intended.
        """
        if epoch <= warmup_epochs:
            return 0.1 + 0.9 * (epoch / warmup_epochs)  # 0.1 -> 1.0
        return 1.0

    # Loss weights (tuned to avoid over-penalizing representation)
    BOND_WEIGHT = 1.0
    LANG_WEIGHT = 0.5  # Adversarial
    PERIOD_WEIGHT = 0.5  # Adversarial
    # NOTE(review): the best checkpoint is chosen by TRAINING loss; the
    # split's valid_ids are not used anywhere in this cell.
    best_loss = float('inf')

    for epoch in range(1, n_epochs + 1):
        model.train()
        total_loss = 0
        n_batches = 0

        # The schedule depends only on the epoch, so compute it once per epoch
        current_adv_lambda = get_adv_lambda(epoch, n_epochs)

        # Discard gradients left over from an incomplete accumulation window
        # at the end of the previous epoch, so they cannot leak into this one.
        optimizer.zero_grad()

        for batch in tqdm(train_loader, desc=f"Epoch {epoch}", unit="batch", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            bond_labels = batch['bond_labels'].to(device)
            language_labels = batch['language_labels'].to(device)
            period_labels = batch['period_labels'].to(device)

            with torch.cuda.amp.autocast(enabled=USE_AMP):
                out = model(input_ids, attention_mask, adv_lambda=current_adv_lambda)

                loss_bond = F.cross_entropy(out['bond_pred'], bond_labels)
                loss_lang = F.cross_entropy(out['language_pred'], language_labels)
                loss_period = F.cross_entropy(out['period_pred'], period_labels)

            # Weighted loss (Reviewer concern #2: unweighted adversarial)
            loss = BOND_WEIGHT * loss_bond + LANG_WEIGHT * loss_lang + PERIOD_WEIGHT * loss_period

            # Record the true per-batch loss before it is scaled down for
            # accumulation, so the reported average does not depend on
            # GRAD_ACCUM_STEPS.
            batch_loss = loss.item()

            # Gradient accumulation for smaller GPUs
            loss = loss / GRAD_ACCUM_STEPS

            if USE_AMP and scaler:
                scaler.scale(loss).backward()
                if (n_batches + 1) % GRAD_ACCUM_STEPS == 0:
                    # Unscale first so clipping sees the real gradient norm
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
            else:
                loss.backward()
                if (n_batches + 1) % GRAD_ACCUM_STEPS == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    optimizer.zero_grad()

            total_loss += batch_loss
            n_batches += 1

        avg_loss = total_loss / n_batches
        print(f"Epoch {epoch}: Loss={avg_loss:.4f} (adv_λ={current_adv_lambda:.2f})")

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f'models/checkpoints/best_{split_name}.pt')
            # Also save to Drive for v10 analysis
            torch.save(model.state_dict(), f'{SAVE_DIR}/best_{split_name}.pt')

    # Evaluate
    print("\nEvaluating...")
    model.load_state_dict(torch.load(f'models/checkpoints/best_{split_name}.pt'))
    model.eval()

    all_preds = {'bond': [], 'lang': []}
    all_labels = {'bond': [], 'lang': []}
    all_languages = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing", unit="batch"):
            out = model(batch['input_ids'].to(device), batch['attention_mask'].to(device), 0)
            all_preds['bond'].extend(out['bond_pred'].argmax(-1).cpu().tolist())
            all_preds['lang'].extend(out['language_pred'].argmax(-1).cpu().tolist())
            all_labels['bond'].extend(batch['bond_labels'].tolist())
            all_labels['lang'].extend(batch['language_labels'].tolist())
            all_languages.extend(batch['languages'])

    if not all_preds['bond']:
        # An empty test set would divide by zero in the accuracy computation
        print("ERROR: Empty test set - skipping evaluation")
        del model, train_dataset, test_dataset
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        continue

    bond_f1 = f1_score(all_labels['bond'], all_preds['bond'], average='macro', zero_division=0)
    bond_acc = sum(p == l for p, l in zip(all_preds['bond'], all_labels['bond'])) / len(all_preds['bond'])
    lang_acc = sum(p == l for p, l in zip(all_preds['lang'], all_labels['lang'])) / len(all_preds['lang'])

    # Per-language F1
    lang_f1 = {}
    for lang in set(all_languages):
        mask = [l == lang for l in all_languages]
        if sum(mask) > 0:
            preds = [p for p, m in zip(all_preds['bond'], mask) if m]
            labels = [l for l, m in zip(all_labels['bond'], mask) if m]
            lang_f1[lang] = {'f1': f1_score(labels, preds, average='macro', zero_division=0), 'n': sum(mask)}

    all_results[split_name] = {
        'bond_f1_macro': bond_f1,
        'bond_acc': bond_acc,
        'language_acc_adversary': lang_acc,
        'per_language_f1': lang_f1,
        'training_time': time.time() - split_start
    }

    print(f"\n{split_name} RESULTS:")
    print(f"  Bond F1 (macro): {bond_f1:.3f}")
    print(f"  Bond accuracy:   {bond_acc:.1%}")
    print(f"  Language acc (adversary): {lang_acc:.1%}")
    print("  Per-language Bond F1:")
    for lang, m in sorted(lang_f1.items(), key=lambda x: -x[1]['n']):
        print(f"    {lang:20s}: F1={m['f1']:.3f} (n={m['n']:,})")

    # Free the model and datasets before the next split
    del model, train_dataset, test_dataset
    gc.collect()
    torch.cuda.empty_cache() if torch.cuda.is_available() else None

# Closing banner; mark the training stage as finished.
print()
print("="*60)
print("TRAINING COMPLETE")
print("="*60)

mark_task("Train BIP model", "done")


In [None]:
#@title 8. Critical Baselines & Ablations { display-mode: "form" }
#@markdown **Reviewer-requested controls to validate universality claims.**
#@markdown 
#@markdown 1. **Rule Baseline**: Pattern matcher only (no neural network)
#@markdown 2. **Shuffle Control**: Shuffle labels within each language
#@markdown 3. **Keyword Masking**: Remove pattern-matching tokens from input

import random
import re
import copy
from sklearn.metrics import f1_score
from collections import defaultdict

print("="*60)
print("CRITICAL BASELINES & ABLATIONS")
print("="*60)
print()
print("These controls address reviewer concern #4:")
print("'You need 2-3 killer baselines to make universality harder to dismiss'")
print()

# Load splits and data
# Both files were written by the split-generation cell.
with open('data/splits/all_splits.json', 'r') as f:
    all_splits = json.load(f)

with open('data/splits/baselines.json', 'r') as f:
    baselines = json.load(f)

# result name -> metrics dict, filled in by the baselines below
baseline_results = {}

# ============================================================
# BASELINE 1: Rule-based prediction (pattern matcher only)
# ============================================================
print("="*60)
print("BASELINE 1: RULE-BASED (Pattern Matcher Only)")
print("="*60)
print()
print("If the model beats this cross-lingually, neural learning adds value.")
print()

def rule_based_predict(text: str, language: str) -> str:
    """Return the name of the first bond type whose regex matches *text*.

    The text is run through ``normalize_text`` for *language* and then
    searched with the patterns registered for that language in
    ``ALL_BOND_PATTERNS``, in the mapping's iteration order. No neural
    network is involved.

    Returns:
        The ``.name`` of the first matching bond type, or ``'NONE'`` when
        nothing matches or the language has no registered patterns.
    """
    normalized = normalize_text(text, language)
    per_language = ALL_BOND_PATTERNS.get(language, {})

    # Lazily yields bond names in registration order; next() stops at the
    # first hit, so later bond types are never evaluated.
    hits = (
        bond_type.name
        for bond_type, regexes in per_language.items()
        if any(re.search(regex, normalized) for regex in regexes)
    )
    return next(hits, 'NONE')

# Score the pattern-only predictor on the test ids of each transfer split.
print("Evaluating rule baseline on test sets...")

# Index every passage by its id (a later duplicate id overwrites an
# earlier one).
with open('data/processed/passages.jsonl', 'r') as passages_file:
    passages_by_id = {
        record['id']: record for record in map(json.loads, passages_file)
    }

# Index each passage's bond annotation by passage id.
with open('data/processed/bonds.jsonl', 'r') as bonds_file:
    bonds_by_id = {
        record['passage_id']: record['bonds']
        for record in map(json.loads, bonds_file)
    }

# NOTE(review): if the 'primary_bond' labels in bonds.jsonl were themselves
# produced by ALL_BOND_PATTERNS, this baseline is scored against its own
# output - confirm against the extraction cell before interpreting it.
for split_name in ('hebrew_to_others', 'semitic_to_non_semitic', 'ancient_to_modern'):
    if split_name not in all_splits:
        continue

    # Only the first 10,000 test ids are scored, to keep this cell fast;
    # ids missing from either index are ignored.
    scored_ids = [
        pid
        for pid in all_splits[split_name]['test_ids'][:10000]
        if pid in passages_by_id and pid in bonds_by_id
    ]
    if not scored_ids:
        continue

    gold_bonds = [bonds_by_id[pid]['primary_bond'] for pid in scored_ids]
    rule_bonds = [
        rule_based_predict(passages_by_id[pid]['text'], passages_by_id[pid]['language'])
        for pid in scored_ids
    ]

    rule_f1 = f1_score(gold_bonds, rule_bonds, average='macro', zero_division=0)
    baseline_results[f'{split_name}_rule_baseline'] = {
        'f1_macro': rule_f1,
        'n_samples': len(gold_bonds),
        'description': 'Pattern matcher only, no neural network'
    }
    print(f"  {split_name}: Rule F1 = {rule_f1:.3f}")

print()

# ============================================================
# BASELINE 2: Shuffle Control (scrambled labels)
# ============================================================
print("="*60)
print("BASELINE 2: SHUFFLE CONTROL")
print("="*60)
print()
print("Shuffle bond labels WITHIN each language (preserve class balance).")
print("If transfer collapses, the model learned real cross-lingual structure.")
print()

#@markdown **Run shuffle control?** (adds ~15 min)
RUN_SHUFFLE_CONTROL = True  #@param {type:"boolean"}

if RUN_SHUFFLE_CONTROL:
    # Create shuffled bond labels
    print("Creating shuffled labels...")

    shuffled_bonds = {}
    by_language = defaultdict(list)

    # Group (passage id, primary bond) pairs by the passage's language;
    # bonds whose passage is unknown are left out.
    for pid, bonds in bonds_by_id.items():
        if pid in passages_by_id:
            lang = passages_by_id[pid]['language']
            by_language[lang].append((pid, bonds['primary_bond']))

    # Shuffle within each language: the ids stay in place and only the
    # labels are permuted, so each language keeps its label distribution.
    random.seed(42)
    for lang, items in by_language.items():
        pids = [x[0] for x in items]
        labels = [x[1] for x in items]
        random.shuffle(labels)  # Shuffle labels only
        for pid, label in zip(pids, labels):
            shuffled_bonds[pid] = label

    print(f"Shuffled {len(shuffled_bonds):,} labels across {len(by_language)} languages")

    # Save shuffled bonds for training. Records are emitted in the order of
    # bonds_by_id (the order bonds.jsonl was read in) rather than grouped
    # by language, so the file keeps the source record order for any reader
    # that pairs it line-by-line with another JSONL file. Iterating
    # shuffled_bonds directly would cluster records by language and break
    # that order whenever languages interleave in the source.
    with open('data/processed/bonds_shuffled.jsonl', 'w') as f:
        for pid in bonds_by_id:
            if pid not in shuffled_bonds:
                continue
            bond = shuffled_bonds[pid]
            f.write(json.dumps({
                'passage_id': pid,
                'bonds': {
                    'primary_bond': bond,
                    'all_bonds': [bond],
                    'hohfeld': None,
                    'language': passages_by_id[pid]['language']
                }
            }) + '\n')

    # Train a model on shuffled labels (abbreviated training)
    print("\nTraining model on SHUFFLED labels (hebrew_to_others split)...")

    # A dedicated dataset class (defined next) reads passages together with
    # a chosen bonds file, so training can use the shuffled labels.
    class ShuffledDataset(torch.utils.data.Dataset):
        """Tokenized passages labelled from an arbitrary bonds JSONL file.

        Passages and labels are joined on passage id, so the two files do
        not have to list records in the same order or have the same length.
        A purely positional (line-by-line) pairing silently drops every
        record once the files fall out of step.

        Args:
            ids: Passage ids to include.
            passages_file: JSONL path; each record provides 'id', 'text',
                'language' and 'time_period'.
            bonds_file: JSONL path; each record provides 'passage_id' and
                'bonds' -> 'primary_bond'.
            tokenizer: Callable applied to the text in ``__getitem__``; it
                must return 'input_ids' and 'attention_mask' tensors with
                a leading batch dimension of 1.
            max_len: Maximum token length passed to the tokenizer.
        """

        def __init__(self, ids, passages_file, bonds_file, tokenizer, max_len=128):
            self.tokenizer = tokenizer
            self.max_len = max_len
            self.data = []

            ids_set = set(ids)

            # Index the labels of the requested passages by id first; if an
            # id occurs more than once in the bonds file the last one wins.
            bond_by_id = {}
            with open(bonds_file, 'r') as fb:
                for b_line in fb:
                    b = json.loads(b_line)
                    if b['passage_id'] in ids_set:
                        bond_by_id[b['passage_id']] = b['bonds']['primary_bond']

            # Stream the passages and keep those that have a label; items
            # end up in passages-file order.
            with open(passages_file, 'r') as fp:
                for p_line in fp:
                    p = json.loads(p_line)
                    if p['id'] in bond_by_id:
                        self.data.append({
                            'id': p['id'],
                            'text': p['text'][:500],  # cap raw text at 500 characters
                            'language': p['language'],
                            'period': p['time_period'],
                            'bond': bond_by_id[p['id']],
                        })

        def __len__(self):
            """Return the number of labelled passages that were loaded."""
            return len(self.data)

        def __getitem__(self, idx):
            """Tokenize item *idx*; bond and language stay as raw strings."""
            item = self.data[idx]
            enc = self.tokenizer(item['text'], truncation=True, max_length=self.max_len,
                                padding='max_length', return_tensors='pt')
            return {
                # squeeze(0) removes the batch dimension added by return_tensors='pt'
                'input_ids': enc['input_ids'].squeeze(0),
                'attention_mask': enc['attention_mask'].squeeze(0),
                'bond': item['bond'],
                'language': item['language'],
            }
    
    # Quick training (2 epochs)
    shuffle_model = BIPModel().to(device)
    split = all_splits['hebrew_to_others']
    
    shuffle_train = ShuffledDataset(
        split['train_ids'], 
        'data/processed/passages.jsonl',
        'data/processed/bonds_shuffled.jsonl',
        tokenizer
    )
    shuffle_test = ShuffledDataset(
        split['test_ids'][:5000],
        'data/processed/passages.jsonl', 
        'data/processed/bonds.jsonl',  # Test on REAL labels
        tokenizer
    )
    
    if len(shuffle_train) > 0:
        def simple_collate(batch):
            bond_to_idx = {b.name: i for i, b in enumerate(BondType)}
            return {
                'input_ids': torch.stack([x['input_ids'] for x in batch]),
                'attention_mask': torch.stack([x['attention_mask'] for x in batch]),
                'bond_labels': torch.tensor([bond_to_idx.get(x['bond'], 0) for x in batch]),
            }
        
        train_loader = DataLoader(shuffle_train, batch_size=256, shuffle=True, collate_fn=simple_collate)
        test_loader = DataLoader(shuffle_test, batch_size=256, collate_fn=simple_collate)
        
        optimizer = torch.optim.AdamW(shuffle_model.parameters(), lr=2e-5)
        
        # Quick training
        shuffle_model.train()
        for epoch in range(2):
            for batch in tqdm(train_loader, desc=f"Shuffle Epoch {epoch+1}", leave=False):
                out = shuffle_model(batch['input_ids'].to(device), batch['attention_mask'].to(device), 0)
                loss = F.cross_entropy(out['bond_pred'], batch['bond_labels'].to(device))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
        # Evaluate
        shuffle_model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in test_loader:
                out = shuffle_model(batch['input_ids'].to(device), batch['attention_mask'].to(device), 0)
                all_preds.extend(out['bond_pred'].argmax(-1).cpu().tolist())
                all_labels.extend(batch['bond_labels'].tolist())
        
        shuffle_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        baseline_results['shuffle_control'] = {
            'f1_macro': shuffle_f1,
            'description': 'Trained on shuffled labels, tested on real labels',
            'expected': 'Should be near chance if model learns real structure'
        }
        print(f"\n  Shuffle control F1: {shuffle_f1:.3f}")
        print(f"  (Chance: {baselines['chance_bond']:.3f})")
        
        if shuffle_f1 < baselines['chance_bond'] * 1.5:
            print("  ✓ Transfer collapsed as expected!")
        else:
            print("  ✗ WARNING: Model may be learning spurious patterns")
        
        del shuffle_model
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
else:
    print("Shuffle control skipped (enable RUN_SHUFFLE_CONTROL)")

print()

# ============================================================
# BASELINE 3: Keyword Masking Ablation
# ============================================================
print("="*60)
print("BASELINE 3: KEYWORD MASKING ABLATION")
print("="*60)
print()
print("Mask/remove tokens that match bond patterns from input text.")
print("If transfer holds, model isn't just doing keyword spotting.")
print()

#@markdown **Run keyword masking ablation?** (adds ~20 min)
RUN_KEYWORD_MASKING = True  #@param {type:"boolean"}

if RUN_KEYWORD_MASKING:
    def mask_keywords(text: str, language: str, mask_token: str = "[MASK]") -> str:
        """Substitute *mask_token* for every pattern match in *text*.

        All bond and Hohfeld regexes registered for *language* are applied
        one after another, case-insensitively, each on the output of the
        previous substitution. A language without registered patterns
        leaves the text unchanged.

        NOTE(review): rule_based_predict matches these same bond patterns
        against normalize_text() output, whereas this function substitutes
        on the text as given - confirm that un-normalized input (e.g. with
        diacritics) does not let keywords escape masking.
        """
        regexes = [
            regex
            for table in (ALL_BOND_PATTERNS.get(language, {}),
                          ALL_HOHFELD_PATTERNS.get(language, {}))
            for group in table.values()
            for regex in group
        ]

        result = text
        for regex in regexes:
            result = re.sub(regex, mask_token, result, flags=re.IGNORECASE)
        return result

    # Show the effect on one English sentence as a quick sanity check.
    sample_sentence = "You must honor your father and mother"
    sample_masked = mask_keywords(sample_sentence, 'english')
    print("Example masking:")
    print(f"  Original: '{sample_sentence}'")
    print(f"  Masked:   '{sample_masked}'")
    print()
    
    # Dataset variant whose input text has its pattern keywords masked out.
    class MaskedDataset(torch.utils.data.Dataset):
        """Passages with keyword-masked text, paired with their primary bond.

        Only ids present in both *passages_by_id* and *bonds_by_id* are
        kept. Each text is cut to 500 characters and then passed through
        ``mask_keywords`` for the passage's language.
        """

        def __init__(self, ids, passages_by_id, bonds_by_id, tokenizer, max_len=128):
            self.tokenizer = tokenizer
            self.max_len = max_len

            usable_ids = (
                pid for pid in ids
                if pid in passages_by_id and pid in bonds_by_id
            )
            self.data = [
                {
                    'text': mask_keywords(passages_by_id[pid]['text'][:500],
                                          passages_by_id[pid]['language']),
                    'bond': bonds_by_id[pid]['primary_bond'],
                    'language': passages_by_id[pid]['language'],
                }
                for pid in usable_ids
            ]

        def __len__(self):
            """Return how many masked passages were kept."""
            return len(self.data)

        def __getitem__(self, idx):
            """Tokenize the masked text of item *idx* and attach its bond name."""
            record = self.data[idx]
            encoded = self.tokenizer(
                record['text'],
                truncation=True,
                max_length=self.max_len,
                padding='max_length',
                return_tensors='pt',
            )
            # squeeze(0) drops the batch dimension added by return_tensors='pt'.
            token_ids = encoded['input_ids'].squeeze(0)
            mask = encoded['attention_mask'].squeeze(0)
            return {
                'input_ids': token_ids,
                'attention_mask': mask,
                'bond': record['bond'],
            }
    
    print("Training model on MASKED text (hebrew_to_others split)...")
    
    masked_model = BIPModel().to(device)
    split = all_splits['hebrew_to_others']
    
    masked_train = MaskedDataset(split['train_ids'], passages_by_id, bonds_by_id, tokenizer)
    masked_test = MaskedDataset(split['test_ids'][:5000], passages_by_id, bonds_by_id, tokenizer)
    
    if len(masked_train) > 0:
        train_loader = DataLoader(masked_train, batch_size=256, shuffle=True, collate_fn=simple_collate)
        test_loader = DataLoader(masked_test, batch_size=256, collate_fn=simple_collate)
        
        optimizer = torch.optim.AdamW(masked_model.parameters(), lr=2e-5)
        
        masked_model.train()
        for epoch in range(3):
            for batch in tqdm(train_loader, desc=f"Masked Epoch {epoch+1}", leave=False):
                out = masked_model(batch['input_ids'].to(device), batch['attention_mask'].to(device), 0)
                loss = F.cross_entropy(out['bond_pred'], batch['bond_labels'].to(device))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
        # Evaluate on masked test set
        masked_model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in test_loader:
                out = masked_model(batch['input_ids'].to(device), batch['attention_mask'].to(device), 0)
                all_preds.extend(out['bond_pred'].argmax(-1).cpu().tolist())
                all_labels.extend(batch['bond_labels'].tolist())
        
        masked_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        
        # Compare to main model
        main_model_f1 = all_results.get('hebrew_to_others', {}).get('bond_f1_macro', 0)
        
        baseline_results['keyword_masking'] = {
            'f1_macro': masked_f1,
            'main_model_f1': main_model_f1,
            'retention': masked_f1 / main_model_f1 if main_model_f1 > 0 else 0,
            'description': 'Keywords masked from input, tests if model learns beyond patterns'
        }
        
        print(f"\n  Keyword-masked F1: {masked_f1:.3f}")
        print(f"  Main model F1:     {main_model_f1:.3f}")
        print(f"  Retention:         {masked_f1/main_model_f1*100:.1f}%" if main_model_f1 > 0 else "N/A")
        
        if masked_f1 > baselines['chance_bond'] * 1.3:
            print("  ✓ Transfer holds even without keywords!")
            print("  → Model learns semantic structure, not just keyword spotting")
        else:
            print("  ✗ Performance dropped significantly")
            print("  → Model may rely heavily on keyword patterns")
        
        del masked_model
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
else:
    print("Keyword masking ablation skipped (enable RUN_KEYWORD_MASKING)")

# Persist every baseline/ablation metric collected in this cell.
with open('results/baseline_ablation_results.json', 'w') as results_file:
    results_file.write(json.dumps(baseline_results, indent=2))

# One summary line per entry; every entry carries an 'f1_macro' value.
summary_lines = ["", "=" * 60, "BASELINE SUMMARY", "=" * 60]
summary_lines.extend(
    f"  {name}: F1={res['f1_macro']:.3f}"
    for name, res in baseline_results.items()
)
summary_lines.append("")
summary_lines.append("Results saved to results/baseline_ablation_results.json")
print("\n".join(summary_lines))


In [None]:
#@title 9. Linear Probe Test { display-mode: "form" }
#@markdown Can language/period be decoded from frozen z_bond?

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

mark_task("Linear probe test", "running")

print("="*60)
print("LINEAR PROBE TEST")
print("="*60)
print()
print("Testing if language/period can be decoded from z_bond.")
print("If probe accuracy ≈ chance, confounds are removed!")
print()


def _probe_accuracy(X_train, y_train, X_test, y_test):
    """Fit a linear probe on one half and score it on the held-out half.

    Args:
        X_train, y_train: Features and integer labels used to fit the probe.
        X_test, y_test: Held-out features and labels used for scoring.

    Returns:
        ``(accuracy, chance)`` where chance is ``1 / n_classes`` counted
        over the held-out labels.
    """
    train_classes = np.unique(y_train)
    if len(train_classes) < 2:
        # LogisticRegression raises when the target holds a single class;
        # predicting that class everywhere is the equivalent probe.
        predictions = np.full(len(y_test), train_classes[0])
    else:
        probe = LogisticRegression(max_iter=1000, multi_class='multinomial', n_jobs=-1)
        probe.fit(X_train, y_train)
        predictions = probe.predict(X_test)
    accuracy = (predictions == y_test).mean()
    chance = 1.0 / len(np.unique(y_test))
    return accuracy, chance


linear_probe_results = {}

# Read the split definitions once instead of re-parsing the file per split.
with open('data/splits/all_splits.json', 'r') as f:
    probe_splits = json.load(f)

for split_name in [s for s in all_results.keys() if s != 'mixed_baseline']:
    print(f"\n{'='*50}")
    print(f"PROBE: {split_name}")
    print(f"{'='*50}")

    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    model = BIPModel().to(device)
    model.load_state_dict(torch.load(f'models/checkpoints/best_{split_name}.pt',
                                     map_location=device))
    model.eval()

    split = probe_splits[split_name]

    test_dataset = NativeDataset(set(split['test_ids']), 'data/processed/passages.jsonl',
                                  'data/processed/bonds.jsonl', tokenizer)
    test_loader = DataLoader(test_dataset, batch_size=BASE_BATCH_SIZE, collate_fn=collate_fn, num_workers=4)

    all_z = []
    all_lang = []
    all_period = []

    # Collect the frozen z_bond vectors together with the confound labels.
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Extract", unit="batch"):
            z = model.extract_z_bond(batch['input_ids'].to(device), batch['attention_mask'].to(device))
            all_z.append(z.cpu().numpy())
            all_lang.extend(batch['language_labels'].tolist())
            all_period.extend(batch['period_labels'].tolist())

    if len(all_lang) < 2:
        # A 50/50 split needs at least one sample on each side.
        print("\n  Skipped: fewer than 2 test samples, cannot fit a probe")
    else:
        X = np.vstack(all_z)
        y_lang = np.array(all_lang)
        y_period = np.array(all_period)

        # Fixed-seed 50/50 split of the test representations.
        np.random.seed(42)
        idx = np.random.permutation(len(X))
        train_idx, test_idx = idx[:len(idx)//2], idx[len(idx)//2:]

        # Fit the scaler on the probe's training half only, so held-out
        # statistics do not leak into the features the probe is scored on.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X[train_idx])
        X_test = scaler.transform(X[test_idx])

        # Language probe
        lang_acc, lang_chance = _probe_accuracy(
            X_train, y_lang[train_idx], X_test, y_lang[test_idx])
        # CORRECT: Low accuracy = CAN'T decode = INVARIANT (good!)
        # v4 bug: 0% accuracy was marked as "not invariant" - that's backwards!
        # Less than 15 percentage points above chance counts as invariant.
        lang_invariant = lang_acc < (lang_chance + 0.15)

        # Period probe
        period_acc, period_chance = _probe_accuracy(
            X_train, y_period[train_idx], X_test, y_period[test_idx])
        period_invariant = period_acc < (period_chance + 0.15)

        print(f"\nRESULTS:")
        print(f"  Language: {lang_acc:.1%} (chance: {lang_chance:.1%}) {'✓' if lang_invariant else '✗'}")
        print(f"  Period:   {period_acc:.1%} (chance: {period_chance:.1%}) {'✓' if period_invariant else '✗'}")

        # Cast to plain Python types so the dict is JSON-serializable.
        linear_probe_results[split_name] = {
            'lang_probe_acc': float(lang_acc),
            'lang_chance': float(lang_chance),
            'lang_invariant': bool(lang_invariant),
            'period_probe_acc': float(period_acc),
            'period_chance': float(period_chance),
            'period_invariant': bool(period_invariant),
        }

    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

with open('results/linear_probe_results.json', 'w') as f:
    json.dump(linear_probe_results, f, indent=2)

mark_task("Linear probe test", "done")


In [None]:
#@title 10. Final Evaluation { display-mode: "form" }
#@markdown Comprehensive results with proper interpretation of asymmetric transfers.
#@markdown 
#@markdown **Key insight from v4**: Ancient→Modern showed TIME-INVARIANT transfer
#@markdown (0% time accuracy = can't decode time = GOOD for invariance!)

import time
import json

# Flag the evaluation task as running via mark_task.
mark_task("Evaluate results", "running")

# Title banner and a reminder of how the labels were obtained.
print("="*60)
print("FINAL BIP EVALUATION (v9 - Native Patterns)")
print("="*60)
print()
print("METHODOLOGY: Labels extracted from NATIVE text using NATIVE patterns.")
print("NO English translation used. Only mathematical alignment connects languages.")
print()

# Reference v4 findings
# (static text: the figures below are hard-coded, not computed in this cell)
print("-"*60)
print("CONTEXT: v4 Findings (Ancient→Modern)")
print("-"*60)
print("""
  v4 showed ASYMMETRIC results:
  
  Ancient→Modern (3.9M → 19K):
    - Time accuracy: 0.0%  ← CAN'T decode time = INVARIANT ✓
    - Hohfeld accuracy: 48.6% (1.94x chance) = TRANSFER ✓
    
  Modern→Ancient (19K → 3.9M):  
    - Failed (likely data quantity issue)
    
  Key insight: Low probe accuracy is GOOD - it means the
  representation doesn't encode the confound!
""")
print("-"*60)
print()

# Chance-level reference values for bond, language and period prediction.
with open('data/splits/baselines.json', 'r') as f:
    baselines = json.load(f)

chance_bond = baselines['chance_bond']
chance_language = baselines['chance_language']
chance_period = baselines['chance_period']

# Load baseline ablation results. The file is optional, so a missing,
# unreadable or malformed file falls back to "no ablations"; any other
# error is allowed to surface instead of being swallowed.
try:
    with open('results/baseline_ablation_results.json', 'r') as f:
        baseline_ablation = json.load(f)
except (OSError, json.JSONDecodeError) as err:
    print(f"  (baseline ablation results unavailable: {err})")
    baseline_ablation = {}

print("="*60)
print("CROSS-DOMAIN TRANSFER RESULTS")
print("="*60)

# Track which splits show the pattern we want
successful_transfers = []
asymmetric_notes = []

for split_name, res in all_results.items():
    # The mixed baseline is not a transfer split, so it is not assessed here.
    if split_name == 'mixed_baseline':
        continue

    print(f"\n{'='*60}")
    print(f"{split_name.upper()}")
    print("="*60)

    # A split without probe results must not be scored as invariant: the
    # accuracy defaults below are 0, which would otherwise read as
    # "below chance".
    probe = linear_probe_results.get(split_name) or {}
    has_probe = bool(probe)

    # Get metrics
    bond_f1 = res.get('bond_f1_macro', 0)
    lang_probe_acc = probe.get('lang_probe_acc', 0)
    lang_chance = probe.get('lang_chance', 0.2)
    period_probe_acc = probe.get('period_probe_acc', 0)
    period_chance = probe.get('period_chance', 0.1)

    # CORRECT INTERPRETATION:
    # - Low probe accuracy (near or below chance) = GOOD (invariant)
    # - High bond F1 (above chance) = GOOD (transfer works)

    # Check invariance: the probe must stay less than 15 percentage points
    # above chance, and probe results must actually exist.
    lang_invariant = has_probe and lang_probe_acc < (lang_chance + 0.15)
    period_invariant = has_probe and period_probe_acc < (period_chance + 0.15)

    # Check transfer (bond prediction should beat chance)
    transfer_ratio = bond_f1 / chance_bond if chance_bond > 0 else 0
    transfer_works = transfer_ratio > 1.5

    print(f"\n  INVARIANCE (low probe acc = GOOD):")
    if not has_probe:
        print(f"    (no linear-probe results for this split - invariance not assessed)")
    else:
        print(f"    Language probe: {lang_probe_acc:.1%} (chance: {lang_chance:.1%})")
        if lang_probe_acc < lang_chance:
            print(f"      → BELOW chance = STRONGLY INVARIANT ✓✓")
        elif lang_invariant:
            print(f"      → Near chance = INVARIANT ✓")
        else:
            print(f"      → Above chance = NOT invariant ✗")

        print(f"    Period probe:   {period_probe_acc:.1%} (chance: {period_chance:.1%})")
        if period_probe_acc < period_chance:
            print(f"      → BELOW chance = STRONGLY INVARIANT ✓✓")
        elif period_invariant:
            print(f"      → Near chance = INVARIANT ✓")
        else:
            print(f"      → Above chance = NOT invariant ✗")

    print(f"\n  TRANSFER (high bond F1 = GOOD):")
    print(f"    Bond F1: {bond_f1:.3f} (chance: {chance_bond:.3f})")
    print(f"    Transfer ratio: {transfer_ratio:.2f}x chance")
    if transfer_ratio > 2.0:
        print(f"      → STRONG transfer ✓✓")
    elif transfer_works:
        print(f"      → Transfer works ✓")
    else:
        print(f"      → Weak/no transfer ✗")

    # Per-language breakdown (important for asymmetric analysis),
    # largest test language first.
    if 'per_language_f1' in res:
        print(f"\n  PER-LANGUAGE BREAKDOWN:")
        lang_f1s = res['per_language_f1']
        for lang, m in sorted(lang_f1s.items(), key=lambda x: -x[1].get('n', 0)):
            f1 = m.get('f1', 0)
            n = m.get('n', 0)
            ratio = f1 / chance_bond if chance_bond > 0 else 0
            status = "✓" if ratio > 1.5 else "~" if ratio > 1.0 else "✗"
            print(f"    {lang:20s}: F1={f1:.3f} ({ratio:.1f}x) {status}  (n={n:,})")

    # Data size analysis (asymmetry check): prefer sizes stored with the
    # training result, else fall back to the split definition.
    train_size = res.get('train_size', 0) or all_splits.get(split_name, {}).get('train_size', 0)
    test_size = res.get('test_size', 0) or all_splits.get(split_name, {}).get('test_size', 0)

    if train_size and test_size:
        size_ratio = train_size / test_size if test_size > 0 else 0
        print(f"\n  DATA SIZES:")
        print(f"    Train: {train_size:,} | Test: {test_size:,} | Ratio: {size_ratio:.1f}x")

        # More than a 10x imbalance in either direction is noted for the
        # asymmetry section printed later.
        if size_ratio > 10:
            asymmetric_notes.append(f"{split_name}: Large→small ({size_ratio:.0f}x)")
        elif size_ratio < 0.1:
            asymmetric_notes.append(f"{split_name}: Small→large ({1/size_ratio:.0f}x)")

    # Overall assessment for this split. Invariance to EITHER confound
    # (language or period) is enough to count as invariant here.
    print(f"\n  ASSESSMENT:")
    if (lang_invariant or period_invariant) and transfer_works:
        print(f"    ✓ SUCCESS: Invariant representation + working transfer")
        successful_transfers.append(split_name)
    elif transfer_works and not (lang_invariant or period_invariant):
        print(f"    ~ PARTIAL: Transfer works but representation encodes confounds")
        print(f"      (Could be data asymmetry - check sizes above)")
    elif (lang_invariant or period_invariant) and not transfer_works:
        print(f"    ~ PARTIAL: Invariant but weak transfer")
        print(f"      (May need more training data or epochs)")
    else:
        print(f"    ✗ FAILED: Neither invariance nor transfer")

# ============================================================
# BASELINE COMPARISON
# ============================================================
print()
print("\n".join(["=" * 60, "BASELINE COMPARISONS", "=" * 60]))

# Outcome flags; each one feeds the final verdict.
beats_rule_baseline = False
shuffle_collapsed = False
keyword_robust = False

if not baseline_ablation:
    print("\n  (Baseline ablations not run)")
else:
    # 1) Neural model vs. the pattern-only baseline, split by split.
    print("\n1. NEURAL vs RULE BASELINE:")
    for key, entry in baseline_ablation.items():
        if 'rule_baseline' not in key:
            continue
        rule_f1 = entry.get('f1_macro', 0)
        split_label = key.replace('_rule_baseline', '')
        neural_f1 = all_results.get(split_label, {}).get('bond_f1_macro', 0)
        if rule_f1 > 0:
            improvement = (neural_f1 - rule_f1) / rule_f1 * 100
        else:
            improvement = 0
        print(f"    {split_label}:")
        print(f"      Rule:   {rule_f1:.3f}")
        print(f"      Neural: {neural_f1:.3f}")
        print(f"      Improvement: {improvement:+.1f}%")
        # A single split beating the rules by more than 10% sets the flag.
        if neural_f1 > rule_f1 * 1.1:
            beats_rule_baseline = True
            print(f"      → Neural adds value ✓")

    # 2) Shuffle control: staying under 1.5x chance counts as collapsed.
    if 'shuffle_control' in baseline_ablation:
        print("\n2. SHUFFLE CONTROL:")
        shuffle_f1 = baseline_ablation['shuffle_control'].get('f1_macro', 0)
        print(f"    Shuffled labels F1: {shuffle_f1:.3f}")
        print(f"    Chance: {chance_bond:.3f}")
        shuffle_collapsed = bool(shuffle_f1 < chance_bond * 1.5)
        if shuffle_collapsed:
            print(f"    → Transfer collapsed as expected ✓")
        else:
            print(f"    → WARNING: May learn spurious patterns ✗")

    # 3) Keyword masking: more than 1.3x chance counts as robust.
    if 'keyword_masking' in baseline_ablation:
        print("\n3. KEYWORD MASKING:")
        masking_entry = baseline_ablation['keyword_masking']
        masked_f1 = masking_entry.get('f1_macro', 0)
        main_f1 = masking_entry.get('main_model_f1', 0)
        retention = masked_f1 / main_f1 * 100 if main_f1 > 0 else 0
        print(f"    With keywords:    {main_f1:.3f}")
        print(f"    Without keywords: {masked_f1:.3f}")
        print(f"    Retention: {retention:.1f}%")
        keyword_robust = bool(masked_f1 > chance_bond * 1.3)
        if keyword_robust:
            print(f"    → Model learns beyond keywords ✓")
        else:
            print(f"    → Relies heavily on keywords ✗")

# ============================================================
# ASYMMETRIC ANALYSIS
# ============================================================
# Printed only when at least one split was flagged for a large
# train/test size imbalance.
if asymmetric_notes:
    report_lines = [
        "",
        "=" * 60,
        "ASYMMETRIC TRANSFER ANALYSIS",
        "=" * 60,
        "",
        "Like v4, some splits have large data asymmetry:",
    ]
    report_lines += [f"  • {note}" for note in asymmetric_notes]
    report_lines += [
        "",
        "Large→small transfers typically work better (more training data).",
        "This is expected and doesn't invalidate the universality claim.",
    ]
    print("\n".join(report_lines))

# ============================================================
# FINAL VERDICT
# ============================================================
print()
print("="*60)
print("FINAL VERDICT")
print("="*60)

# Only the transfer splits count towards the verdict; the mixed baseline
# is not a transfer split and is excluded everywhere in this evaluation.
transfer_results = [r for name, r in all_results.items() if name != 'mixed_baseline']

# Count successes
n_successful = len(successful_transfers)
n_total_splits = len(transfer_results)
baseline_checks = sum([beats_rule_baseline, shuffle_collapsed, keyword_robust])

print(f"""
EVIDENCE SUMMARY:
─────────────────
  Successful transfers (invariant + transfer): {n_successful}/{n_total_splits}
  Splits: {', '.join(successful_transfers) if successful_transfers else 'None'}
  
  Baseline checks passed: {baseline_checks}/3
    • Neural > Rule:        {'✓' if beats_rule_baseline else '✗'}
    • Shuffle collapsed:    {'✓' if shuffle_collapsed else '✗'}
    • Keyword-robust:       {'✓' if keyword_robust else '✗'}
""")

# Determine verdict. The tiers are checked from strongest to weakest.
if n_successful >= 2 and baseline_checks >= 2:
    verdict = "STRONGLY_SUPPORTED"
    box = """
    ╔══════════════════════════════════════════════════════════╗
    ║     BIP: STRONGLY SUPPORTED                              ║
    ╠══════════════════════════════════════════════════════════╣
    ║  • Multiple splits show invariant cross-lingual transfer ║
    ║  • Native patterns (no English bridge)                   ║
    ║  • Baselines confirm neural model learns real structure  ║
    ║                                                          ║
    ║  Evidence supports universal mathematical structure      ║
    ║  in moral cognition across languages and time.           ║
    ╚══════════════════════════════════════════════════════════╝
    """
elif n_successful >= 1 or (n_total_splits > 0 and baseline_checks >= 2):
    verdict = "SUPPORTED"
    box = """
    ╔══════════════════════════════════════════════════════════╗
    ║     BIP: SUPPORTED                                       ║
    ╠══════════════════════════════════════════════════════════╣
    ║  • Some cross-lingual transfer demonstrated              ║
    ║  • May have asymmetric effects (data size dependent)     ║
    ║  • Similar to v4: large→small direction works better     ║
    ╚══════════════════════════════════════════════════════════╝
    """
# Any transfer split above 1.3x chance; the in-distribution mixed baseline
# must not be able to trigger this tier.
elif any(r.get('bond_f1_macro', 0) > chance_bond * 1.3 for r in transfer_results):
    verdict = "PARTIAL"
    box = """
    ╔══════════════════════════════════════════════════════════╗
    ║     BIP: PARTIAL SUPPORT                                 ║
    ╠══════════════════════════════════════════════════════════╣
    ║  • Some transfer above chance, but:                      ║
    ║  • Representation may encode confounds                   ║
    ║  • Or baseline checks failed                             ║
    ║                                                          ║
    ║  Needs investigation of specific failure modes.          ║
    ╚══════════════════════════════════════════════════════════╝
    """
else:
    verdict = "INCONCLUSIVE"
    box = """
    ╔══════════════════════════════════════════════════════════╗
    ║     BIP: INCONCLUSIVE                                    ║
    ╠══════════════════════════════════════════════════════════╣
    ║  • Transfer not demonstrated                             ║
    ║                                                          ║
    ║  Possible issues:                                        ║
    ║  • Label quality (check coverage metrics)                ║
    ║  • Data quantity (check asymmetry)                       ║
    ║  • Training (try more epochs)                            ║
    ║  • Or: BIP may not hold (null result)                    ║
    ╚══════════════════════════════════════════════════════════╝
    """

print(box)

# v4 comparison note
# (static text: the v4 figures are hard-coded here, not computed in this cell)
print()
print("─"*60)
print("COMPARISON TO v4:")
print("─"*60)
print("""
v4 (Ancient→Modern) showed:
  • 0% time accuracy (GOOD - time invariant)
  • 48.6% Hohfeld accuracy (1.94x chance - transfer works)

This was actually a SUCCESS that was mislabeled as "inconclusive"
because the evaluation logic was inverted.

v9 improvements over v4:
  • Native patterns (no translation bridge)
  • Text normalization (Hebrew/Arabic diacritics)
  • Multiple splits (not just temporal)
  • Baseline ablations (rule, shuffle, masking)
  • Correct interpretation of probe accuracy
""")

# ============================================================
# SAVE RESULTS
# ============================================================
# Wall-clock time since the setup cell recorded EXPERIMENT_START.
total_time = time.time() - EXPERIMENT_START

# Load split info; a missing or malformed file falls back to an empty
# mapping, while any other error is allowed to surface.
try:
    with open('data/splits/all_splits.json', 'r') as f:
        all_splits = json.load(f)
except (OSError, json.JSONDecodeError):
    all_splits = {}

# Everything needed to reproduce the verdict, gathered into one document.
final_results = {
    'model_results': all_results,
    'linear_probe_results': linear_probe_results,
    'baseline_ablation_results': baseline_ablation,
    'successful_transfers': successful_transfers,
    'baseline_checks': {
        'beats_rule_baseline': beats_rule_baseline,
        'shuffle_collapsed': shuffle_collapsed,
        'keyword_robust': keyword_robust,
        'total_passed': baseline_checks,
    },
    'verdict': verdict,
    'total_time_minutes': total_time / 60,
    'methodology': {
        'approach': 'Native-language pattern matching',
        'translation_used': False,
        'text_normalization': True,
        'languages': ['hebrew', 'aramaic', 'classical_chinese', 'arabic', 'english'],
        'label_source': 'Native text with language-specific patterns',
        'connection': 'Mathematical only (shared latent space)',
        'adversarial_warmup': True,
    },
    'baselines': baselines,
    # Hard-coded reference figures from the earlier v4 run.
    'v4_comparison': {
        'ancient_to_modern_time_acc': 0.0,
        'ancient_to_modern_hohfeld_acc': 0.486,
        'note': 'v4 was actually a success - 0% time acc means invariant'
    }
}

# default=str stringifies any value json cannot encode natively.
with open('results/final_results.json', 'w') as f:
    json.dump(final_results, f, indent=2, default=str)

# Also save to Drive. This copy is best-effort: an undefined SAVE_DIR
# (presumably set only when Drive was mounted - confirm in the setup cell)
# or an unwritable path is reported and skipped rather than aborting.
try:
    with open(f'{SAVE_DIR}/final_results.json', 'w') as f:
        json.dump(final_results, f, indent=2, default=str)
except (NameError, OSError) as err:
    print(f"  (Drive copy skipped: {err})")

print(f"\nTotal time: {total_time/60:.1f} minutes")
print(f"Results saved to results/final_results.json")

mark_task("Evaluate results", "done")
print_progress()


In [None]:
#@title 11. Download Results & Models { display-mode: "form" }
#@markdown Downloads trained models and results as a zip file.
#@markdown **Models can be re-uploaded to v10 for analysis without Google Drive.**

import shutil
import os
from google.colab import files

# Bundle checkpoints, data files and result JSONs into one zip archive and
# hand it to the browser. Lines starting with "!" are IPython shell
# commands, so this cell only runs inside a notebook runtime.
print("="*60)
print("PACKAGING RESULTS FOR DOWNLOAD")
print("="*60)

# Staging directory layout: models/ and data/ below download_package/.
!mkdir -p download_package/models
!mkdir -p download_package/data

# Copy trained models
# One checkpoint per entry in all_results; a missing file is skipped.
print("\nCopying trained models...")
for split_name in all_results.keys():
    src = f'models/checkpoints/best_{split_name}.pt'
    if os.path.exists(src):
        !cp "{src}" download_package/models/
        size_mb = os.path.getsize(src) / 1e6
        print(f"  {split_name}: {size_mb:.1f} MB")

# Copy data files
# Errors are discarded (2>/dev/null); the name is echoed only if cp succeeds.
print("\nCopying data files...")
!cp data/processed/passages.jsonl download_package/data/ 2>/dev/null && echo "  passages.jsonl"
!cp data/processed/bonds.jsonl download_package/data/ 2>/dev/null && echo "  bonds.jsonl"
!cp data/splits/*.json download_package/data/ 2>/dev/null && echo "  splits/*.json"

# Copy results
# "|| true" keeps the command from failing when no result files exist.
print("\nCopying results...")
!cp results/*.json download_package/ 2>/dev/null || true

# Create zip
# Writes bip_v9_complete.zip from the contents of download_package/.
print("\nCreating zip archive...")
shutil.make_archive('bip_v9_complete', 'zip', 'download_package')

# Show contents
print("\n" + "="*60)
print("PACKAGE CONTENTS:")
print("="*60)
!find download_package -type f -exec ls -lh {} \;

# Sum the sizes of all staged files (uncompressed, before zipping).
total_size = sum(os.path.getsize(os.path.join(dp, f)) for dp, dn, fn in os.walk('download_package') for f in fn)
print(f"\nTotal package size: {total_size/1e6:.1f} MB")

print("\n" + "="*60)
print("DOWNLOADING...")
print("="*60)
print("Save this zip file - it contains everything needed for v10 analysis!")

# google.colab helper that sends the archive to the user's browser.
files.download('bip_v9_complete.zip')
