# BIP Cross-Cultural Universal Morality Experiment

**Testing the Bond Invariance Principle across 2000+ years AND across languages (Hebrew + Chinese + Arabic → English)**

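This experiment tests whether moral cognition has invariant structure by:
1. Training on texts in their ORIGINAL languages: Hebrew (Sefaria corpus, ~500 BCE - 1800 CE), Classical Chinese (Confucian, Daoist and Mohist classics from the Chinese Text Project) and Classical Arabic (Quran + Hadith)
2. Testing transfer to modern ENGLISH advice columns (Dear Abby, 1956-2020)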

**Hypothesis**: If BIP holds, bond-level features should transfer across 2000 years and across languages with no accuracy drop.

---

## Setup Instructions
1. **Runtime -> Change runtime type -> GPU (T4) or TPU (v5e)**
2. Run cells in order - each shows progress in real-time
3. Expected runtime: ~1-2 hours (TPU) or ~2-4 hours (GPU)

**Supported Accelerators:**
- NVIDIA GPU (T4, V100, A100)
- Google TPU (v2, v3, v4, v5e)
- CPU (slow, not recommended)

---

In [None]:
#@title 1. Setup and Install Dependencies { display-mode: "form" }
#@markdown Installs packages and detects GPU/TPU. Memory-optimized for Colab.

# Banner printed at the top of the setup cell's output.
print("=" * 60)
print("BIP TEMPORAL INVARIANCE EXPERIMENT")
print("=" * 60)
print()

# Progress tracker: the ordered pipeline steps shown by print_progress().
# These exact strings are passed to mark_task(...) in later cells, so they act
# as lookup keys and must not be renamed in one place only.
# NOTE(review): the "(~8GB)" label differs from the ~3.5GB estimate printed by
# the Sefaria download cell; kept verbatim because it is used as a key.
TASKS = [
    "Install dependencies",
    "Clone Sefaria corpus (~8GB)",
    "Clone sqnd-probe repo (Dear Abby data)",
    "Preprocess corpora",
    "Extract bond structures",
    "Generate train/test splits",
    "Train BIP model",
    "Evaluate results"
]
# Every step starts as "pending"; mark_task() overwrites individual entries.
task_status = {task: "pending" for task in TASKS}

def print_progress():
    """Print the TASKS checklist with one status marker per step.

    Markers: "[X]" for "done", "[>]" for "running", "[ ]" for any other
    status (e.g. "pending"). Statuses are read from the global task_status.
    """
    markers = {"done": "[X]", "running": "[>]"}
    rule = "-" * 50

    print()
    print(rule)
    print("EXPERIMENT PROGRESS:")
    print(rule)
    for name in TASKS:
        print(f"  {markers.get(task_status[name], '[ ]')} {name}")
    print(rule)
    # Trailing blank line; flush so Colab shows progress immediately.
    print(flush=True)

def mark_task(task, status):
    """Record ``status`` for ``task`` in task_status and redraw the checklist."""
    task_status.update({task: status})
    print_progress()

# Draw the initial (all-pending) checklist, then flag the install step.
print_progress()

mark_task("Install dependencies", "running")

import os
import subprocess
import sys

# Install dependencies - MINIMAL set to save memory
# Each package is pip-installed into the running interpreter (sys.executable).
# check_call raises CalledProcessError if pip exits non-zero, which stops the cell.
# NOTE(review): torch_xla (used by the TPU branch below) is not in this list;
# presumably it is preinstalled on TPU runtimes - confirm.
print("Installing minimal dependencies...")
deps = [
    "transformers",
    "torch", 
    "sentence-transformers",
    "pandas",
    "tqdm",
    "psutil"
]

for dep in deps:
    print(f"  Installing {dep}...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", dep])

print()

# Detect accelerator - WITHOUT importing tensorflow
# Produces the globals USE_TPU, TPU_TYPE, ACCELERATOR (display string) and
# device (torch.device / XLA device) used for the rest of the notebook.
USE_TPU = False
TPU_TYPE = None

# Check for TPU.
# NOTE(review): only COLAB_TPU_ADDR is checked; confirm it is set on every TPU
# runtime type listed in the header (e.g. v5e).
if 'COLAB_TPU_ADDR' in os.environ:
    USE_TPU = True
    TPU_TYPE = "TPU (Colab)"
    print("TPU detected!")

# Check for GPU (a visible CUDA device takes precedence over a TPU).
import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    ACCELERATOR = f"GPU: {gpu_name} ({gpu_mem:.1f}GB)"
    device = torch.device("cuda")
elif USE_TPU:
    # torch_xla is not in the pip install list above. If it is unavailable,
    # fall back to CPU instead of aborting the whole setup cell with an
    # ImportError, and clear USE_TPU so code that checks it does not assume XLA.
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        print("WARNING: TPU detected but torch_xla is not installed - falling back to CPU.")
        USE_TPU = False
        ACCELERATOR = "CPU (slow!)"
        device = torch.device("cpu")
    else:
        ACCELERATOR = TPU_TYPE
        device = xm.xla_device()
else:
    ACCELERATOR = "CPU (slow!)"
    device = torch.device("cpu")

print(f"Accelerator: {ACCELERATOR}")
print(f"Device: {device}")

# Memory status: system RAM via psutil, plus allocated/total GPU memory when
# CUDA is available.
import psutil
mem = psutil.virtual_memory()
print(f"System RAM: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

if torch.cuda.is_available():
    print(f"GPU RAM: {torch.cuda.memory_allocated()/1e9:.1f}/{torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

# Enable mixed precision for 2-3x speedup
# Sets globals USE_AMP and scaler; scaler is None when CUDA is unavailable.
if torch.cuda.is_available():
    print()
    print("Enabling mixed precision (FP16) for faster training...")
    # NOTE(review): torch.cuda.amp.{autocast,GradScaler} are deprecated in recent
    # PyTorch releases in favor of torch.amp with an explicit "cuda" device type;
    # kept because later code presumably calls autocast() with this signature.
    from torch.cuda.amp import autocast, GradScaler
    USE_AMP = True
    scaler = GradScaler()
else:
    USE_AMP = False
    scaler = None

# torch.compile for PyTorch 2.0+ (10-30% speedup)
# TORCH_COMPILE is False on every path: the branch below only reports that
# torch.compile exists; compilation is intentionally left off.
TORCH_COMPILE = False
if hasattr(torch, 'compile') and torch.cuda.is_available():
    print("PyTorch 2.0+ detected - torch.compile available")
    TORCH_COMPILE = False  # Disabled - overhead > benefit for short runs


# ============================================================
# GOOGLE DRIVE - SAVE RESULTS EVEN IF SESSION DIES
# ============================================================
# Requires the Colab runtime (google.colab); drive.mount prompts for auth.
print()
print("=" * 60)
print("MOUNTING GOOGLE DRIVE FOR PERSISTENT STORAGE")
print("=" * 60)
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = '/content/drive/MyDrive/BIP_results'
os.makedirs(SAVE_DIR, exist_ok=True)

# Create local directories (relative to the notebook's working directory).
# NOTE(review): data/raw is not created here; it is presumably created by the
# git clone into data/raw/Sefaria-Export in the next cell.
os.makedirs("data/processed", exist_ok=True)
os.makedirs("data/splits", exist_ok=True)
os.makedirs("models/checkpoints", exist_ok=True)
os.makedirs("results", exist_ok=True)
print(f"Results will be saved to: {SAVE_DIR}")
print("If session crashes, your data survives.")
print()

mark_task("Install dependencies", "done")



In [None]:
#@title 2. Download Sefaria Corpus (~8GB) { display-mode: "form" }
#@markdown Downloads the complete Sefaria corpus with real-time git progress.

import shutil
import subprocess
import sys

mark_task("Clone Sefaria corpus (~8GB)", "running")

sefaria_path = 'data/raw/Sefaria-Export'

if not os.path.exists(sefaria_path) or not os.path.exists(f"{sefaria_path}/json"):
    # A directory without json/ is a leftover from an interrupted clone. git
    # refuses to clone into a non-empty directory, so without this cleanup
    # every re-run of the cell would fail the same way.
    if os.path.isdir(sefaria_path):
        print(f"Removing incomplete clone at {sefaria_path}...")
        shutil.rmtree(sefaria_path)

    print("="*60)
    print("CLONING SEFARIA CORPUS")
    print("="*60)
    print()
    print("This downloads ~3.5GB and takes 5-15 minutes.")
    print("Git's native progress will display below:")
    print("-"*60)
    print(flush=True)

    # Use subprocess.Popen for real-time output streaming
    process = subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/Sefaria/Sefaria-Export.git', sefaria_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # Git writes progress to stderr
        text=True,
        bufsize=1
    )

    # Stream output in real-time
    for line in process.stdout:
        print(line, end='', flush=True)

    process.wait()

    print("-"*60)
    if process.returncode == 0:
        print("\nSefaria clone COMPLETE!")
    else:
        print(f"\nERROR: Git clone failed with code {process.returncode}")
        print("Try running this cell again, or check your internet connection.")
        # Stop the cell: otherwise the task is marked "done" below and later
        # cells run against a missing corpus.
        raise RuntimeError(
            f"git clone of Sefaria-Export failed (exit code {process.returncode})"
        )
else:
    print("Sefaria already exists, skipping download.")

# Verify and count files
# The two "!" lines are IPython shell escapes (Colab/Jupyter only, not plain
# Python). `du -sh` prints the clone's size. `json_count` captures the stdout
# of `find ... | wc -l` as a list of output lines (an IPython SList), so
# json_count[0] is the file count as a string.
print()
print("Verifying download...")
!du -sh {sefaria_path} 2>/dev/null || echo "Directory not found"
json_count = !find {sefaria_path}/json -name "*.json" 2>/dev/null | wc -l
print(f"Sefaria JSON files found: {json_count[0]}")

# NOTE(review): the task is marked done without checking that json_count > 0.
mark_task("Clone Sefaria corpus (~8GB)", "done")


In [None]:
#@title 3. Download Dear Abby Dataset { display-mode: "form" }
#@markdown Downloads the Dear Abby advice column dataset (68,330 entries).

import pandas as pd
from pathlib import Path

mark_task("Clone sqnd-probe repo (Dear Abby data)", "running")

sqnd_path = 'sqnd-probe-data'
if not os.path.exists(sqnd_path):
    print("Cloning sqnd-probe repo...")
    process = subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/ahb-sjsu/sqnd-probe.git', sqnd_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    )
    for line in process.stdout:
        print(line, end='', flush=True)
    process.wait()
    # Not fatal here: a previously copied data/raw/dear_abby.csv can still be
    # used by the copy step below. Surface the failure instead of hiding it.
    if process.returncode != 0:
        print(f"\nWARNING: git clone of sqnd-probe failed with code {process.returncode}")
else:
    print("Repo already cloned.")

# Copy Dear Abby data into data/raw/, the path read by the loading cell.
import shutil

dear_abby_source = Path('sqnd-probe-data/dear_abby_data/raw_da_qs.csv')
dear_abby_path = Path('data/raw/dear_abby.csv')

if dear_abby_source.exists():
    # Python copy instead of a `!cp` shell escape. The shell escape ignored
    # failures, and nothing earlier creates data/raw/ unless the Sefaria clone
    # ran. copyfile raises on error; the parent directory is created first.
    dear_abby_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(dear_abby_source, dear_abby_path)
    print("\nCopied Dear Abby data")
elif not dear_abby_path.exists():
    raise FileNotFoundError("Dear Abby dataset not found!")

# Verify (assumes the CSV has a numeric 'year' column).
df_check = pd.read_csv(dear_abby_path)
print("\n" + "=" * 50)
print(f"Dear Abby dataset: {len(df_check):,} entries")
print(f"Columns: {list(df_check.columns)}")
print(f"Year range: {df_check['year'].min():.0f} - {df_check['year'].max():.0f}")
print("=" * 50)

mark_task("Clone sqnd-probe repo (Dear Abby data)", "done")

In [None]:
#@title 4. Define Data Classes and Loaders { display-mode: "form" }
#@markdown Defines enums, dataclasses, and corpus loaders.

import json
import hashlib
import re
from dataclasses import dataclass, field, asdict
from typing import List, Dict
from enum import Enum
from collections import defaultdict
from tqdm.auto import tqdm

print("Defining data structures...")

# Historical era of a passage. The loaders in this cell store the member *name*
# (e.g. "BIBLICAL") in Passage.time_period, not the enum itself. The Chinese
# and Arabic cells use other period strings ("CONFUCIAN", "QURANIC", ...)
# that are not members of this enum.
class TimePeriod(Enum):
    BIBLICAL = 0        # ~1000-500 BCE
    SECOND_TEMPLE = 1   # ~500 BCE - 70 CE
    TANNAITIC = 2       # ~70-200 CE
    AMORAIC = 3         # ~200-500 CE
    GEONIC = 4          # ~600-1000 CE
    RISHONIM = 5        # ~1000-1500 CE
    ACHRONIM = 6        # ~1500-1800 CE
    MODERN_HEBREW = 7   # ~1800-present
    DEAR_ABBY = 8       # 1956-2020

# Categories of moral bond. Not referenced elsewhere in this section;
# presumably used by the later bond-extraction cells.
class BondType(Enum):
    HARM_PREVENTION = 0
    RECIPROCITY = 1
    AUTONOMY = 2
    PROPERTY = 3
    FAMILY = 4
    AUTHORITY = 5
    EMERGENCY = 6
    CONTRACT = 7
    CARE = 8
    FAIRNESS = 9

# Hohfeldian positions. Also not referenced in this section.
class HohfeldianState(Enum):
    RIGHT = 0
    OBLIGATION = 1
    LIBERTY = 2
    NO_RIGHT = 3

@dataclass
class Passage:
    """One text passage from a corpus, with original-language and English text.

    For Dear Abby rows, text_original and text_english hold the same string.
    """
    id: str                 # "sefaria_<hash>" or "abby_<hash>" (first 12 hex chars of an MD5)
    text_original: str      # original-language text (HTML tags stripped for Sefaria)
    text_english: str       # English text; the loaders filter on its length
    time_period: str        # TimePeriod member *name*, e.g. "AMORAIC" or "DEAR_ABBY"
    century: int            # representative century; negative values mean BCE
    source: str             # human-readable reference, e.g. "<file stem> 3.12" or "Dear Abby 1987"
    source_type: str        # "sefaria" or "dear_abby" for the loaders in this cell
    category: str           # Sefaria top-level directory, or "general" for Dear Abby
    language: str = "hebrew"
    word_count: int = 0     # whitespace-split word count of text_english
    # The three fields below are not set by the loaders in this cell;
    # presumably filled in by later processing.
    has_dispute: bool = False
    consensus_tier: str = "unknown"
    bond_types: List[str] = field(default_factory=list)
    
    def to_dict(self):
        """Return the passage as a plain dict (dataclasses.asdict)."""
        return asdict(self)

# Era assigned to a Sefaria text by its top-level directory under json/.
# load_sefaria looks up only the first path component, so subcategory keys
# such as 'Torah' and 'Bavli' match only if they occur as top-level directories.
# Categories missing here are treated as AMORAIC by load_sefaria.
CATEGORY_TO_PERIOD = {
    'Tanakh': TimePeriod.BIBLICAL,
    'Torah': TimePeriod.BIBLICAL,
    'Mishnah': TimePeriod.TANNAITIC,
    'Tosefta': TimePeriod.TANNAITIC,
    'Talmud': TimePeriod.AMORAIC,
    'Bavli': TimePeriod.AMORAIC,
    'Midrash': TimePeriod.AMORAIC,
    'Halakhah': TimePeriod.RISHONIM,
    'Chasidut': TimePeriod.ACHRONIM,
}

# Representative century per era (negative = BCE). DEAR_ABBY has no entry;
# load_dear_abby derives the century from each letter's year instead.
PERIOD_TO_CENTURY = {
    TimePeriod.BIBLICAL: -6,
    TimePeriod.SECOND_TEMPLE: -2,
    TimePeriod.TANNAITIC: 2,
    TimePeriod.AMORAIC: 4,
    TimePeriod.GEONIC: 8,
    TimePeriod.RISHONIM: 12,
    TimePeriod.ACHRONIM: 17,
    TimePeriod.MODERN_HEBREW: 20,
}

# Strips HTML/markup tags from Sefaria text segments.
_HTML_TAG_RE = re.compile(r'<[^>]+>')


def _flatten_sefaria(h, e, ref, *, stem, time_period, century, category):
    """Walk parallel Hebrew/English structures; return Passages for leaf string pairs.

    ``ref`` accumulates the 1-based dotted position (e.g. "3.12"). Lists are
    zipped, so unmatched trailing elements are ignored; nodes whose types
    differ produce nothing. Only pairs whose tag-stripped English text is
    50-2000 characters long are kept.
    """
    if isinstance(h, str) and isinstance(e, str):
        h_clean = _HTML_TAG_RE.sub('', h).strip()
        e_clean = _HTML_TAG_RE.sub('', e).strip()
        if not 50 <= len(e_clean) <= 2000:
            return []
        pid = hashlib.md5(f"{stem}:{ref}:{h_clean[:50]}".encode()).hexdigest()[:12]
        return [Passage(
            id=f"sefaria_{pid}",
            text_original=h_clean,
            text_english=e_clean,
            time_period=time_period.name,
            century=century,
            source=f"{stem} {ref}".strip(),
            source_type="sefaria",
            category=category,
            language="hebrew",
            word_count=len(e_clean.split())
        )]
    if isinstance(h, list) and isinstance(e, list):
        result = []
        for i, (hh, ee) in enumerate(zip(h, e), start=1):
            result.extend(_flatten_sefaria(
                hh, ee, f"{ref}.{i}" if ref else str(i),
                stem=stem, time_period=time_period,
                century=century, category=category,
            ))
        return result
    return []


def load_sefaria(base_path: str, max_passages: int = None) -> List[Passage]:
    """Load Sefaria passages from ``<base_path>/json`` with a progress bar.

    Every ``*.json`` file under ``json/`` is read. For files whose top level
    is a dict, the ``he``/``text`` and ``text``/``en`` entries are walked in
    parallel and each leaf string pair becomes one Passage (see
    _flatten_sefaria for the length filter).

    Args:
        base_path: Root directory of a Sefaria-Export checkout.
        max_passages: Maximum number of *passages* to return (it counts
            passages, not files). ``None`` or 0 means unlimited.

    Returns:
        List of Passage objects. Returns ``[]`` after printing a warning when
        ``<base_path>/json`` does not exist. Unreadable or malformed files
        are skipped.
    """
    passages: List[Passage] = []
    json_path = Path(base_path) / "json"

    if not json_path.exists():
        print(f"Warning: {json_path} not found")
        return []

    # Sorted so a capped load selects the same files on every run
    # (rglob order depends on the filesystem).
    json_files = sorted(json_path.rglob("*.json"))
    print(f"Processing {len(json_files):,} JSON files...")

    for json_file in tqdm(json_files, desc="Loading Sefaria", unit="file"):
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Unreadable file or bad JSON / encoding (JSONDecodeError and
            # UnicodeDecodeError are ValueErrors): skip it.
            continue
        if not isinstance(data, dict):
            continue

        rel_path = json_file.relative_to(json_path)
        category = str(rel_path.parts[0]) if rel_path.parts else "unknown"
        # Unmapped top-level categories default to AMORAIC.
        time_period = CATEGORY_TO_PERIOD.get(category, TimePeriod.AMORAIC)
        century = PERIOD_TO_CENTURY.get(time_period, 0)

        # NOTE(review): if the export stores each language in its own file
        # under a single "text" key, both lookups below resolve to the same
        # text; confirm the export layout before relying on
        # text_original vs text_english.
        hebrew = data.get('he', data.get('text', []))
        english = data.get('text', data.get('en', []))

        passages.extend(_flatten_sefaria(
            hebrew, english, "",
            stem=json_file.stem, time_period=time_period,
            century=century, category=category,
        ))
        if max_passages and len(passages) >= max_passages:
            break

    # The last file read may overshoot the cap; trim to exactly max_passages.
    if max_passages:
        del passages[max_passages:]
    return passages

def load_dear_abby(path: str, max_passages: int = None) -> List[Passage]:
    """Load Dear Abby letters from a CSV file with a progress bar.

    Reads the ``question_only`` and ``year`` columns. Rows whose question is
    missing or not 50-2000 characters long are skipped. A missing or NaN year
    falls back to 1990.

    Args:
        path: Path to the CSV file.
        max_passages: Stop after this many passages (``None``/0 = unlimited).

    Returns:
        List of Passage objects with ``text_original == text_english``.
    """
    passages = []
    df = pd.read_csv(path)

    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading Dear Abby", unit="row"):
        question = str(row.get('question_only', ''))
        if not question or question == 'nan' or len(question) < 50 or len(question) > 2000:
            continue

        # pandas reads empty cells as NaN, and int(NaN) raises ValueError, so
        # apply the 1990 default to NaN as well as to a missing column.
        raw_year = row.get('year', 1990)
        year = int(raw_year) if pd.notna(raw_year) else 1990
        pid = hashlib.md5(f"abby:{idx}:{question[:50]}".encode()).hexdigest()[:12]

        passages.append(Passage(
            id=f"abby_{pid}",
            text_original=question,
            text_english=question,
            time_period=TimePeriod.DEAR_ABBY.name,
            century=20 if year < 2000 else 21,
            source=f"Dear Abby {year}",
            source_type="dear_abby",
            category="general",
            language="english",
            word_count=len(question.split())
        ))

        if max_passages and len(passages) >= max_passages:
            break

    return passages

print("Data structures defined!")

In [None]:
#@title 5. Load and Preprocess Corpora { display-mode: "form" }
#@markdown Loads both corpora. Set MAX_SEFARIA_PASSAGES to limit memory usage.

#@markdown **Memory Management:**
MAX_SEFARIA_PASSAGES = 200000  #@param {type:"integer"}
#@markdown **FAST MODE:** 200K passages. Set to 500K+ for full run.
#@markdown Set to 0 for unlimited. Recommended: 500000 for Colab (12GB RAM)

import gc

mark_task("Preprocess corpora", "running")

print("=" * 60)
print("LOADING CORPORA")
print("=" * 60)
print()

if MAX_SEFARIA_PASSAGES > 0:
    print(f"*** MEMORY MODE: Limited to {MAX_SEFARIA_PASSAGES:,} Sefaria passages ***")
    print()

# Load Sefaria with optional limit (non-positive setting -> None, i.e. no limit)
limit = MAX_SEFARIA_PASSAGES if MAX_SEFARIA_PASSAGES > 0 else None
sefaria_passages = load_sefaria("data/raw/Sefaria-Export", max_passages=limit)
print(f"\nSefaria passages loaded: {len(sefaria_passages):,}")

# Force garbage collection
gc.collect()

# Load Dear Abby (no cap)
print()
abby_passages = load_dear_abby("data/raw/dear_abby.csv")
print(f"\nDear Abby passages loaded: {len(abby_passages):,}")

# Combine: Sefaria passages first, then Dear Abby
all_passages = sefaria_passages + abby_passages

# Drop the per-corpus list names. The Passage objects themselves stay referenced
# by all_passages, so this frees only the two list containers.
del sefaria_passages
del abby_passages
gc.collect()

print()
print("=" * 60)
print(f"TOTAL PASSAGES: {len(all_passages):,}")
print("=" * 60)

# Statistics: passage counts per time period and per source type
by_period = defaultdict(int)
by_source = defaultdict(int)
for p in all_passages:
    by_period[p.time_period] += 1
    by_source[p.source_type] += 1

print("\nBy source:")
for source, count in sorted(by_source.items()):
    print(f"  {source}: {count:,}")

print("\nBy time period:")
for period, count in sorted(by_period.items()):
    pct = count / len(all_passages) * 100
    bar = '#' * int(pct / 2)  # one '#' per 2% of all passages
    print(f"  {period:20s}: {count:6,} ({pct:5.1f}%) {bar}")

# Memory status
import psutil
mem = psutil.virtual_memory()
print(f"\nMemory: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

mark_task("Preprocess corpora", "done")


In [None]:
#@title 5a. Load Chinese Classics (Chinese Text Project) { display-mode: "form" }
#@markdown Downloads Confucian and Daoist classics in original Classical Chinese.
#@markdown Source: ctext.org (free API, academic use)

import requests
import time
import json
import os
from tqdm import tqdm

print("=" * 60)
print("LOADING CHINESE CLASSICS")
print("=" * 60)
print()
print("Source: Chinese Text Project (ctext.org)")
print("Period: ~500 BCE - 200 CE")
print("Language: Classical Chinese (文言文)")
print()

# CText API base URL; book IDs below are requested as URNs "ctp:<id>".
CTEXT_API = "https://api.ctext.org"

# Key Confucian / Daoist / Mohist texts with moral content (no Legalist text is
# listed). "century" is a representative century; negative values mean BCE.
# NOTE(review): on ctext the Great Learning and the Doctrine of the Mean are
# chapters of the Liji, so "daxue"/"zhongyong" may not resolve as standalone
# IDs - confirm.
CHINESE_TEXTS = {
    # Confucian Four Books
    "analects": {"title": "論語 Analects", "period": "CONFUCIAN", "century": -5},
    "mengzi": {"title": "孟子 Mencius", "period": "CONFUCIAN", "century": -4},
    "daxue": {"title": "大學 Great Learning", "period": "CONFUCIAN", "century": -5},
    "zhongyong": {"title": "中庸 Doctrine of the Mean", "period": "CONFUCIAN", "century": -5},
    
    # Daoist
    "dao-de-jing": {"title": "道德經 Dao De Jing", "period": "DAOIST", "century": -6},
    "zhuangzi": {"title": "莊子 Zhuangzi", "period": "DAOIST", "century": -4},
    
    # Other: Xunzi (Confucian) and Mozi (Mohist)
    "xunzi": {"title": "荀子 Xunzi", "period": "CONFUCIAN", "century": -3},
    "mozi": {"title": "墨子 Mozi", "period": "MOHIST", "century": -5},
}

def _ctext_get_json(url):
    """GET ``url`` from the ctext API and return the decoded JSON object, or None.

    Returns None for non-200 responses and for JSON that is not an object.
    Network errors (requests exceptions) and invalid JSON (ValueError)
    propagate to the caller.
    """
    resp = requests.get(url, timeout=30)
    if resp.status_code != 200:
        return None
    data = resp.json()
    return data if isinstance(data, dict) else None


def _ctext_chapter_urns(listing):
    """Return the non-empty chapter URNs listed in a ctext response dict.

    Accepts both ``"children"`` (a list of dicts with a ``"urn"`` key) and
    ``"subsections"`` (a list of URN strings).
    """
    entries = listing.get("children") or listing.get("subsections") or []
    urns = []
    for entry in entries:
        urn = entry.get("urn", "") if isinstance(entry, dict) else entry
        if isinstance(urn, str) and urn:
            urns.append(urn)
    return urns


def fetch_ctext_book(textid, max_chapters=50):
    """Fetch a book from the Chinese Text Project API.

    Args:
        textid: ctext identifier, requested as the URN ``ctp:<textid>``.
        max_chapters: Maximum number of chapters to download.

    Returns:
        List of dicts with ``text_original`` (paragraph text longer than 10
        characters), ``text_english`` (the paragraph's ``translation`` when the
        API supplies one, else ``""``) and ``source_ref``
        (``"<chapter urn>:<paragraph index>"``). Best effort: failures are
        printed and yield a partial or empty list; this function does not raise.
    """
    passages = []

    try:
        # Table of contents. If gettoc yields no chapters, ask gettext for the
        # book URN and read its chapter list instead.
        # NOTE(review): confirm which of the two endpoints ctext actually serves.
        toc = _ctext_get_json(f"{CTEXT_API}/gettoc?urn=ctp:{textid}")
        chapter_urns = _ctext_chapter_urns(toc) if toc else []
        if not chapter_urns:
            book = _ctext_get_json(f"{CTEXT_API}/gettext?urn=ctp:{textid}")
            chapter_urns = _ctext_chapter_urns(book) if book else []
        if not chapter_urns:
            print(f"  Warning: No chapters in {textid}")
            return passages

        failed = 0
        for chapter_urn in chapter_urns[:max_chapters]:
            try:
                data = _ctext_get_json(f"{CTEXT_API}/gettext?urn={chapter_urn}")
            except Exception:
                # Best effort: one bad chapter must not abort the whole book.
                data = None
            time.sleep(0.2)  # Rate limiting (failed requests count too)
            if data is None:
                failed += 1
                continue

            # Paragraphs may be dicts ({"text", "translation"}) or plain strings.
            paragraphs = data.get("text") or data.get("fulltext") or []
            for i, para in enumerate(paragraphs):
                if isinstance(para, dict):
                    zh_text = para.get("text", "")
                    en_text = para.get("translation", "")
                elif isinstance(para, str):
                    zh_text, en_text = para, ""
                else:
                    continue

                if isinstance(zh_text, str) and len(zh_text) > 10:
                    passages.append({
                        "text_original": zh_text,
                        "text_english": en_text if isinstance(en_text, str) else "",
                        "source_ref": f"{chapter_urn}:{i}",
                    })

        if failed:
            print(f"  Warning: {failed} chapter(s) of {textid} could not be fetched")

    except Exception as e:
        print(f"  Error fetching {textid}: {e}")

    return passages

# Alternative: Download pre-packaged texts if API fails
def download_chinese_fallback():
    """Download the Analects from a backup JSON source.

    Expects a JSON list of objects with a ``"chinese"`` key and optional
    ``"english"`` / ``"ref"`` keys. Entries whose Chinese text is 10
    characters or shorter are skipped.

    Returns:
        List of passage dicts (``text_original``, ``text_english``,
        ``source_ref``). Best effort: any failure is printed and whatever was
        collected so far (usually ``[]``) is returned.
    """
    passages = []

    # Analects (public domain, widely available)
    # NOTE(review): confirm this repository/path still exists.
    analects_url = "https://raw.githubusercontent.com/cjdd3b/chinese-texts/master/analects.json"

    try:
        resp = requests.get(analects_url, timeout=30)
        if resp.status_code != 200:
            print(f"    Chinese fallback unavailable (HTTP {resp.status_code})")
            return passages
        data = resp.json()
        for item in data:
            if not isinstance(item, dict):
                continue
            chinese = item.get("chinese")
            if isinstance(chinese, str) and len(chinese) > 10:
                passages.append({
                    "text_original": chinese,
                    "text_english": item.get("english", ""),
                    "source_ref": item.get("ref", "analects"),
                })
    except Exception as e:
        # Best effort, as in download_arabic_fallback: report the failure
        # instead of silently returning nothing.
        print(f"    Chinese fallback failed: {e}")

    return passages

# Load Chinese texts
chinese_passages = []
os.makedirs("data/raw/chinese", exist_ok=True)

# Check for cached data first.
# The cache is JSONL written with ensure_ascii=False, so it is read and written
# as UTF-8 explicitly instead of relying on the platform default encoding.
cache_file = "data/raw/chinese/all_passages.jsonl"
if os.path.exists(cache_file):
    print("Loading from cache...")
    with open(cache_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                chinese_passages.append(json.loads(line))
    print(f"Loaded {len(chinese_passages):,} cached passages")
else:
    print("Fetching from Chinese Text Project API...")
    print("(This may take a few minutes)")
    print()

    for textid, info in tqdm(CHINESE_TEXTS.items(), desc="Books"):
        print(f"\n  {info['title']}...")
        book_passages = fetch_ctext_book(textid)

        # Attach corpus-level metadata to every passage of this book.
        for p in book_passages:
            p["source"] = "ctext"
            p["source_type"] = "chinese_classic"
            p["category"] = info["title"]
            p["time_period"] = info["period"]
            p["century"] = info["century"]
            p["language"] = "chinese"

        chinese_passages.extend(book_passages)
        print(f"    Got {len(book_passages)} passages")

    # Fallback if API didn't work well
    if len(chinese_passages) < 1000:
        print("\nAPI limited, trying fallback sources...")
        fallback = download_chinese_fallback()
        for p in fallback:
            p["source"] = "ctext_fallback"
            p["source_type"] = "chinese_classic"
            p["category"] = "Analects"
            p["time_period"] = "CONFUCIAN"
            p["century"] = -5
            p["language"] = "chinese"
        chinese_passages.extend(fallback)

    # Cache results
    if chinese_passages:
        with open(cache_file, 'w', encoding='utf-8') as f:
            for p in chinese_passages:
                f.write(json.dumps(p, ensure_ascii=False) + '\n')
        print(f"\nCached {len(chinese_passages):,} passages")

print()
print("=" * 60)
print(f"CHINESE CLASSICS LOADED: {len(chinese_passages):,} passages")
print("=" * 60)

# Show sample
if chinese_passages:
    sample = chinese_passages[0]
    print(f"\nSample passage:")
    print(f"  Chinese: {sample['text_original'][:100]}...")
    if sample.get('text_english'):
        print(f"  English: {sample['text_english'][:100]}...")
    print(f"  Period: {sample['time_period']}")
    # Centuries are negative for BCE (e.g. -5 for the Analects), so print the
    # magnitude with the right era instead of "-5 CE".
    print(f"  Century: {abs(sample['century'])} {'BCE' if sample['century'] < 0 else 'CE'}")


In [None]:
#@title 5b. Load Arabic Classics (Quran + Hadith) { display-mode: "form" }
#@markdown Downloads Quran and Hadith collections in original Arabic.
#@markdown Sources: quran.com API, sunnah.com

import requests
import time
import json
import os
from tqdm import tqdm

print("=" * 60)
print("LOADING ARABIC CLASSICS")
print("=" * 60)
print()
print("Sources: Quran API, Sunnah.com (Hadith)")
print("Period: 600 CE - 900 CE")
print("Language: Classical Arabic (العربية الفصحى)")
print()

# Quran API (free, no auth needed); used by fetch_quran.
QURAN_API = "https://api.quran.com/api/v4"

# Sunnah.com API for Hadith.
# NOTE(review): not referenced by the fetch functions in this cell
# (fetch_hadith_collection does not use it).
SUNNAH_API = "https://api.sunnah.com/v1"

def fetch_quran():
    """Fetch Quran verses (Uthmani Arabic plus English translation id 131).

    Requests ``verses/by_chapter/<surah>`` for surahs 1-114 and follows the
    response's ``pagination.next_page``, so long surahs are not cut off after
    the first page of results.

    Returns:
        List of passage dicts (time_period "QURANIC", century 7) for verses
        whose Arabic text is longer than 10 characters. Best effort: a surah
        whose request or parsing fails is skipped from that point on; this
        function does not raise.
    """
    import re

    tag_re = re.compile(r'<[^>]+>')  # translation text contains HTML footnote tags
    per_page = 50
    passages = []

    print("Fetching Quran (114 surahs)...")

    for surah in tqdm(range(1, 115), desc="Surahs"):
        page = 1
        while page:
            url = (
                f"{QURAN_API}/verses/by_chapter/{surah}"
                f"?language=en&words=false&translations=131&fields=text_uthmani"
                f"&per_page={per_page}&page={page}"
            )
            try:
                resp = requests.get(url, timeout=30)
                if resp.status_code != 200:
                    break
                data = resp.json()

                for verse in data.get("verses", []):
                    arabic = verse.get("text_uthmani", "")

                    # Get translation if available
                    translations = verse.get("translations") or []
                    english = (translations[0].get("text") or "") if translations else ""
                    english = tag_re.sub('', english)

                    if arabic and len(arabic) > 10:
                        passages.append({
                            "text_original": arabic,
                            "text_english": english,
                            "source_ref": f"quran:{surah}:{verse.get('verse_number', 0)}",
                            "source": "quran",
                            "source_type": "quran",
                            "category": f"Surah {surah}",
                            "time_period": "QURANIC",
                            "century": 7,
                            "language": "arabic",
                        })

                next_page = (data.get("pagination") or {}).get("next_page")
            except Exception:
                # Best effort: give up on the rest of this surah.
                break

            time.sleep(0.1)  # Rate limiting
            # Stop when there is no further page; the "> page" check also
            # guards against a response that never advances.
            page = next_page if isinstance(next_page, int) and next_page > page else None

    return passages

def fetch_hadith_collection(collection, max_hadiths=2000):
    """Placeholder for fetching hadiths from an online collection.

    Not implemented: sunnah.com requires an API key and no key-free source is
    wired up, so this always returns ``[]``. Hadith text in this notebook
    comes from the CORE_HADITH list in create_arabic_from_embedded().

    Args:
        collection: Collection name (e.g. "bukhari", "muslim", "abudawud",
            "tirmidhi").
        max_hadiths: Intended cap on the number of hadiths (currently unused).

    Returns:
        An empty list.
    """
    # TODO: implement from a pre-compiled hadith dataset (GitHub/Kaggle) or a
    # keyed API. Never commit a real API key in the notebook.
    return []

def download_arabic_fallback():
    """Fetch the Quran's Arabic text from a GitHub-hosted JSON fallback.

    Returns a list of passage dicts (``text_english`` is always ``""``) for
    verses whose text is longer than 10 characters. Best effort: a failure is
    printed and whatever was collected before it is returned.
    """
    collected = []
    # NOTE(review): confirm this raw URL path still serves the file.
    source_url = "https://raw.githubusercontent.com/risan/quran-json/main/quran.json"

    try:
        print("  Trying fallback Quran source...")
        response = requests.get(source_url, timeout=60)
        if response.status_code == 200:
            for chapter in response.json():
                chapter_no = chapter.get("id", 0)
                for ayah in chapter.get("verses", []):
                    text = ayah.get("text", "")
                    if not text or len(text) <= 10:
                        continue
                    collected.append({
                        "text_original": text,
                        "text_english": "",  # This source may not have translations
                        "source_ref": f"quran:{chapter_no}:{ayah.get('id', 0)}",
                        "source": "quran",
                        "source_type": "quran",
                        "category": f"Surah {chapter_no}",
                        "time_period": "QURANIC",
                        "century": 7,
                        "language": "arabic",
                    })
            print(f"    Got {len(collected)} verses")
    except Exception as err:
        print(f"    Fallback failed: {err}")

    # Hadith fallbacks (e.g. pre-compiled dumps) are not implemented here.
    return collected

def create_arabic_from_embedded():
    """Build passage dicts from a small embedded set of ethical Quran verses and hadiths.

    Needs no network access. Returns the 15 Quran entries (source "quran",
    time_period "QURANIC", century 7) followed by the 10 hadith entries
    (source "hadith", time_period "HADITH", century 8), all with category
    "Core Ethics" and language "arabic".
    """
    # (arabic, english, reference) for frequently cited verses on ethics.
    quran_verses = [
        # Justice
        ("يَا أَيُّهَا الَّذِينَ آمَنُوا كُونُوا قَوَّامِينَ بِالْقِسْطِ شُهَدَاءَ لِلَّهِ", 
         "O you who believe, be persistently standing firm in justice, witnesses for Allah", "4:135"),
        ("إِنَّ اللَّهَ يَأْمُرُ بِالْعَدْلِ وَالْإِحْسَانِ", 
         "Indeed, Allah orders justice and good conduct", "16:90"),
        
        # Kindness to parents
        ("وَقَضَىٰ رَبُّكَ أَلَّا تَعْبُدُوا إِلَّا إِيَّاهُ وَبِالْوَالِدَيْنِ إِحْسَانًا",
         "Your Lord has decreed that you worship none but Him, and be good to your parents", "17:23"),
        
        # No compulsion
        ("لَا إِكْرَاهَ فِي الدِّينِ",
         "There is no compulsion in religion", "2:256"),
        
        # Honoring contracts
        ("يَا أَيُّهَا الَّذِينَ آمَنُوا أَوْفُوا بِالْعُقُودِ",
         "O you who believe, fulfill your contracts", "5:1"),
        
        # Speak truth
        ("يَا أَيُّهَا الَّذِينَ آمَنُوا اتَّقُوا اللَّهَ وَقُولُوا قَوْلًا سَدِيدًا",
         "O you who believe, fear Allah and speak words of truth", "33:70"),
        
        # Charity
        ("وَآتُوا حَقَّهُ يَوْمَ حَصَادِهِ",
         "And give its due on the day of harvest", "6:141"),
        
        # Patience
        ("وَاصْبِرْ وَمَا صَبْرُكَ إِلَّا بِاللَّهِ",
         "Be patient, for your patience is only through Allah", "16:127"),
        
        # Forgiveness
        ("وَلْيَعْفُوا وَلْيَصْفَحُوا أَلَا تُحِبُّونَ أَنْ يَغْفِرَ اللَّهُ لَكُمْ",
         "Let them pardon and overlook. Would you not like Allah to forgive you?", "24:22"),
        
        # Prohibition of murder
        ("وَلَا تَقْتُلُوا النَّفْسَ الَّتِي حَرَّمَ اللَّهُ إِلَّا بِالْحَقِّ",
         "Do not kill the soul which Allah has forbidden except by right", "6:151"),
        
        # Care for orphans
        ("وَيَسْأَلُونَكَ عَنِ الْيَتَامَىٰ قُلْ إِصْلَاحٌ لَهُمْ خَيْرٌ",
         "They ask you about orphans. Say: Improvement for them is best", "2:220"),
        
        # Honesty in trade
        ("وَأَوْفُوا الْكَيْلَ وَالْمِيزَانَ بِالْقِسْطِ",
         "Give full measure and weight in justice", "6:152"),
        
        # Mutual consultation
        ("وَأَمْرُهُمْ شُورَىٰ بَيْنَهُمْ",
         "Their affair is consultation among themselves", "42:38"),
        
        # Avoid suspicion
        ("يَا أَيُّهَا الَّذِينَ آمَنُوا اجْتَنِبُوا كَثِيرًا مِنَ الظَّنِّ",
         "O you who believe, avoid much suspicion", "49:12"),
        
        # Equality
        ("يَا أَيُّهَا النَّاسُ إِنَّا خَلَقْنَاكُمْ مِنْ ذَكَرٍ وَأُنْثَىٰ وَجَعَلْنَاكُمْ شُعُوبًا وَقَبَائِلَ لِتَعَارَفُوا",
         "O mankind, We created you from male and female and made you peoples and tribes that you may know one another", "49:13"),
    ]

    # (arabic, english, "<collection>:<number>") for core hadiths on ethics.
    hadith_texts = [
        # Golden Rule
        ("لا يُؤْمِنُ أَحَدُكُمْ حَتَّى يُحِبَّ لِأَخِيهِ مَا يُحِبُّ لِنَفْسِهِ",
         "None of you truly believes until he loves for his brother what he loves for himself", "bukhari:13"),
        
        # Kindness
        ("الرَّاحِمُونَ يَرْحَمُهُمُ الرَّحْمَنُ ارْحَمُوا مَنْ فِي الْأَرْضِ يَرْحَمْكُمْ مَنْ فِي السَّمَاءِ",
         "The merciful are shown mercy by the Most Merciful. Show mercy to those on earth, and the One above will show mercy to you", "tirmidhi:1924"),
        
        # Truth
        ("عَلَيْكُمْ بِالصِّدْقِ فَإِنَّ الصِّدْقَ يَهْدِي إِلَى الْبِرِّ",
         "You must be truthful, for truthfulness leads to righteousness", "bukhari:6094"),
        
        # Remove harm
        ("لَا ضَرَرَ وَلَا ضِرَارَ",
         "There should be no harm and no reciprocal harm", "ibn_majah:2341"),
        
        # Good character
        ("إِنَّمَا بُعِثْتُ لِأُتَمِّمَ مَكَارِمَ الْأَخْلَاقِ",
         "I was sent to perfect good character", "malik:muwatta"),
        
        # Helping others
        ("اللَّهُ فِي عَوْنِ الْعَبْدِ مَا كَانَ الْعَبْدُ فِي عَوْنِ أَخِيهِ",
         "Allah helps the servant as long as the servant helps his brother", "muslim:2699"),
        
        # Trust
        ("أَدِّ الْأَمَانَةَ إِلَى مَنِ ائْتَمَنَكَ",
         "Render trusts to those who entrusted you", "tirmidhi:1264"),
        
        # Avoiding anger
        ("لَيْسَ الشَّدِيدُ بِالصُّرَعَةِ إِنَّمَا الشَّدِيدُ الَّذِي يَمْلِكُ نَفْسَهُ عِنْدَ الْغَضَبِ",
         "The strong man is not the one who can wrestle, but the one who controls himself when angry", "bukhari:6114"),
        
        # Feeding the hungry
        ("أَطْعِمُوا الْجَائِعَ وَعُودُوا الْمَرِيضَ وَفُكُّوا الْعَانِيَ",
         "Feed the hungry, visit the sick, and free the captive", "bukhari:5649"),
        
        # Neighbor rights
        ("مَا زَالَ جِبْرِيلُ يُوصِينِي بِالْجَارِ حَتَّى ظَنَنْتُ أَنَّهُ سَيُوَرِّثُهُ",
         "Gabriel kept recommending the neighbor to me until I thought he would make him an heir", "bukhari:6015"),
    ]

    def _entry(kind, period, century, arabic, english, ref):
        # Key order matters: it determines the field order in the JSONL cache.
        return {
            "text_original": arabic,
            "text_english": english,
            "source_ref": f"{kind}:{ref}",
            "source": kind,
            "source_type": kind,
            "category": "Core Ethics",
            "time_period": period,
            "century": century,
            "language": "arabic",
        }

    quran_entries = [_entry("quran", "QURANIC", 7, a, e, r) for a, e, r in quran_verses]
    hadith_entries = [_entry("hadith", "HADITH", 8, a, e, r) for a, e, r in hadith_texts]
    return quran_entries + hadith_entries

# Load Arabic texts
arabic_passages = []
os.makedirs("data/raw/arabic", exist_ok=True)

# The cache is JSONL written with ensure_ascii=False, so it is read and written
# as UTF-8 explicitly instead of relying on the platform default encoding.
cache_file = "data/raw/arabic/all_passages.jsonl"
if os.path.exists(cache_file):
    print("Loading from cache...")
    with open(cache_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                arabic_passages.append(json.loads(line))
    print(f"Loaded {len(arabic_passages):,} cached passages")
else:
    # Try API first
    print("Fetching from Quran API...")
    arabic_passages = fetch_quran()

    # The API, fallback and embedded sets all use "quran:<surah>:<verse>"
    # references, so appending them blindly would duplicate verses that could
    # then land on both sides of a later train/test split. Only add
    # references not seen yet; the API version (with translation) wins.
    _seen_refs = {p["source_ref"] for p in arabic_passages}

    # If API limited, use fallback
    if len(arabic_passages) < 1000:
        print("\nAPI limited, trying fallback...")
        for p in download_arabic_fallback():
            if p["source_ref"] not in _seen_refs:
                _seen_refs.add(p["source_ref"])
                arabic_passages.append(p)

    # Add embedded core texts (no network needed) that are not already present.
    print("\nAdding embedded core ethical texts...")
    for p in create_arabic_from_embedded():
        if p["source_ref"] not in _seen_refs:
            _seen_refs.add(p["source_ref"])
            arabic_passages.append(p)

    # Cache
    if arabic_passages:
        with open(cache_file, 'w', encoding='utf-8') as f:
            for p in arabic_passages:
                f.write(json.dumps(p, ensure_ascii=False) + '\n')
        print(f"Cached {len(arabic_passages):,} passages")

print()
print("=" * 60)
print(f"ARABIC CLASSICS LOADED: {len(arabic_passages):,} passages")
print("=" * 60)

# Show sample
if arabic_passages:
    sample = arabic_passages[0]
    print(f"\nSample passage:")
    print(f"  Arabic: {sample['text_original'][:100]}...")
    if sample.get('text_english'):
        print(f"  English: {sample['text_english'][:100]}...")
    print(f"  Period: {sample['time_period']}")
    print(f"  Century: {sample['century']} CE")


In [None]:
#@title 5c. Merge All Corpora (Hebrew + Chinese + Arabic) { display-mode: "form" }
#@markdown Combines all ancient moral texts into unified dataset.
#@markdown Each passage retains its original language.

import json
import os
from collections import Counter
from tqdm import tqdm

print("=" * 60)
print("MERGING CROSS-CULTURAL CORPUS")
print("=" * 60)
print()

# Aggregate all passages
all_ancient_passages = []

# 1. Hebrew (Sefaria) - load from disk if not in memory.
# This cell runs at notebook top level, so globals() is the namespace to check.
if not globals().get('passages'):
    print("Loading Hebrew passages from disk...")
    passages = []
    with open("data/processed/passages.jsonl", 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                passages.append(json.loads(line))

print(f"Hebrew (Sefaria):  {len(passages):,} passages")
n_skipped_modern = 0
for p in passages:
    if hasattr(p, 'to_dict'):
        p_dict = p.to_dict()
    elif isinstance(p, dict):
        p_dict = p
    else:
        continue
    # passages.jsonl is rewritten by the bond-extraction cell and also holds the
    # modern Dear Abby rows (the split cell reads DEAR_ABBY from it). Those are
    # English advice columns, not Sefaria, so they must not be stamped
    # 'hebrew'/'sefaria' or counted as ancient.
    if p_dict.get('time_period') == 'DEAR_ABBY':
        n_skipped_modern += 1
        continue
    p_dict['language'] = 'hebrew'
    p_dict['corpus'] = 'sefaria'
    all_ancient_passages.append(p_dict)
if n_skipped_modern:
    print(f"  (skipped {n_skipped_modern:,} Dear Abby passages - modern English, not Sefaria)")

# 2. Chinese classics
if 'chinese_passages' in dir() and chinese_passages:
    print(f"Chinese (CText):   {len(chinese_passages):,} passages")
    for p in chinese_passages:
        p['corpus'] = 'ctext'
        all_ancient_passages.append(p)
else:
    print("Chinese: Not loaded (run cell 5a)")

# 3. Arabic (Quran + Hadith)
if 'arabic_passages' in dir() and arabic_passages:
    print(f"Arabic (Quran):    {len(arabic_passages):,} passages")
    for p in arabic_passages:
        p['corpus'] = 'quran_hadith'
        all_ancient_passages.append(p)
else:
    print("Arabic: Not loaded (run cell 5b)")

print()
print(f"TOTAL ANCIENT: {len(all_ancient_passages):,} passages")
print()

# Language distribution
lang_counts = Counter(p.get('language', 'unknown') for p in all_ancient_passages)
print("Language Distribution:")
for lang, count in sorted(lang_counts.items(), key=lambda x: -x[1]):
    pct = count / len(all_ancient_passages) * 100
    print(f"  {lang:10s}: {count:>10,} ({pct:5.1f}%)")

print()

# Time period distribution
period_counts = Counter(p.get('time_period', 'unknown') for p in all_ancient_passages)
print("Time Period Distribution:")
for period, count in sorted(period_counts.items(), key=lambda x: -x[1])[:15]:
    pct = count / len(all_ancient_passages) * 100
    print(f"  {period:15s}: {count:>10,} ({pct:5.1f}%)")

print()

# Update time period mapping to include all cultures
UNIFIED_TIME_PERIODS = {
    # Hebrew
    'BIBLICAL': 0,
    'SECOND_TEMPLE': 1,
    'TANNAITIC': 2,
    'AMORAIC': 3,
    'GEONIC': 4,
    'RISHONIM': 5,
    'ACHRONIM': 6,
    'MODERN_HEBREW': 7,

    # Chinese
    'CONFUCIAN': 8,
    'DAOIST': 9,
    'MOHIST': 10,

    # Arabic
    'QURANIC': 11,
    'HADITH': 12,

    # Modern English
    'DEAR_ABBY': 13,
}

print(f"Unified time periods: {len(UNIFIED_TIME_PERIODS)}")

# Save merged corpus
print()
print("Saving merged corpus...")

os.makedirs("data/processed", exist_ok=True)

# ensure_ascii=False keeps Hebrew/Chinese/Arabic raw, so the encoding must be
# explicit rather than the platform's locale default.
with open("data/processed/all_passages_multilingual.jsonl", 'w', encoding='utf-8') as f:
    for i, p in enumerate(all_ancient_passages):
        p['unified_id'] = f"ancient_{i:07d}"
        f.write(json.dumps(p, ensure_ascii=False) + '\n')

print(f"Saved to data/processed/all_passages_multilingual.jsonl")

# Summary statistics
print()
print("=" * 60)
print("CROSS-CULTURAL CORPUS READY")
print("=" * 60)
print()
print(f"  Languages: {len(lang_counts)}")
print(f"  Hebrew passages:  {lang_counts.get('hebrew', 0):,}")
print(f"  Chinese passages: {lang_counts.get('chinese', 0):,}")
print(f"  Arabic passages:  {lang_counts.get('arabic', 0):,}")
print(f"  Time periods: {len(period_counts)}")
print(f"  Total: {len(all_ancient_passages):,}")
print()
print("The multilingual encoder will map all languages")
print("into a shared semantic space for training.")


In [None]:
#@title 6. Extract Bond Structures { display-mode: "form" }
#@markdown Extracts moral bond structures. Streams to disk to save memory.

import gc

mark_task("Extract bond structures", "running")

# Keyword heuristics: for each BondType, regexes searched case-insensitively
# against a passage's English text; every BondType with a match is recorded.
# NOTE(review): each pattern is a whole-word (\b...\b) list of stems, so
# inflected forms such as "parents", "killed", "commanded" or "required" do NOT
# match. Changing the patterns changes the training labels, so they are left
# as-is here. Also note "right" (AUTONOMY) and "deserve" (FAIRNESS) overlap
# with HOHFELD_PATTERNS below.
RELATION_PATTERNS = {
    BondType.HARM_PREVENTION: [r'\b(kill|murder|harm|hurt|save|rescue|protect|danger)\b'],
    BondType.RECIPROCITY: [r'\b(return|repay|owe|debt|mutual|exchange)\b'],
    BondType.AUTONOMY: [r'\b(choose|decision|consent|agree|force|coerce|right)\b'],
    BondType.PROPERTY: [r'\b(property|own|steal|theft|buy|sell|land)\b'],
    BondType.FAMILY: [r'\b(honor|parent|marry|divorce|inherit|family)\b'],
    BondType.AUTHORITY: [r'\b(obey|command|law|judge|rule|teach)\b'],
    BondType.CARE: [r'\b(care|help|assist|feed|clothe|visit)\b'],
    BondType.FAIRNESS: [r'\b(fair|just|equal|deserve|bias)\b'],
}

# Hohfeldian-state heuristics. Dict order is significant: extract_bond_structure
# keeps only the FIRST state with a matching regex, so OBLIGATION takes
# precedence over RIGHT, which takes precedence over LIBERTY.
HOHFELD_PATTERNS = {
    HohfeldianState.OBLIGATION: [r'\b(must|shall|duty|require|should)\b'],
    HohfeldianState.RIGHT: [r'\b(right to|entitled|deserve)\b'],
    HohfeldianState.LIBERTY: [r'\b(may|can|permitted|allowed)\b'],
}

def extract_bond_structure(passage: Passage) -> Dict:
    """Extract a heuristic bond structure from a passage's English text.

    Every ``RELATION_PATTERNS`` entry with a matching regex contributes its
    BondType name, in dict order; if none match, the relations default to
    ``['CARE']``. The Hohfeldian state is the first ``HOHFELD_PATTERNS`` key
    with a matching regex, or ``None``.

    Args:
        passage: object exposing ``text_english``. A missing or ``None`` value
            is treated as empty text (yielding the ``CARE`` default and no
            Hohfeldian state) instead of raising ``AttributeError``.

    Returns:
        dict with ``bonds`` (list of ``{'relation': name}``),
        ``primary_relation`` (first relation name), ``hohfeld_state``
        (state name or ``None``) and ``signature`` (sorted unique relation
        names joined with ``|``).
    """
    text = (getattr(passage, 'text_english', None) or '').lower()

    def _matches(patterns):
        # True if any regex of the group occurs in the passage text.
        return any(re.search(pattern, text, re.IGNORECASE) for pattern in patterns)

    relations = [
        rel_type.name
        for rel_type, patterns in RELATION_PATTERNS.items()
        if _matches(patterns)
    ]
    if not relations:
        relations = ['CARE']

    # First matching state wins (dict order = precedence).
    hohfeld = next(
        (state.name for state, patterns in HOHFELD_PATTERNS.items() if _matches(patterns)),
        None,
    )

    signature = "|".join(sorted(set(relations)))

    return {
        'bonds': [{'relation': r} for r in relations],
        'primary_relation': relations[0],
        'hohfeld_state': hohfeld,
        'signature': signature
    }

print("=" * 60)
print("EXTRACTING & SAVING (STREAMING)")
print("=" * 60)
print()
print("Writing directly to disk to conserve memory...")
print()

bond_counts = defaultdict(int)

# Label each passage and write it out immediately; only the per-relation
# tallies accumulate in memory.
passages_out_path = "data/processed/passages.jsonl"
bonds_out_path = "data/processed/bond_structures.jsonl"
with open(passages_out_path, 'w') as passages_fh, open(bonds_out_path, 'w') as bonds_fh:
    for passage in tqdm(all_passages, desc="Processing", unit="passage"):
        bond_struct = extract_bond_structure(passage)
        relation_names = [entry['relation'] for entry in bond_struct['bonds']]
        passage.bond_types = relation_names

        for relation_name in relation_names:
            bond_counts[relation_name] += 1

        passage_row = json.dumps(passage.to_dict())
        bond_row = json.dumps({
            'passage_id': passage.id,
            'bond_structure': bond_struct
        })
        passages_fh.write(passage_row + '\n')
        bonds_fh.write(bond_row + '\n')

# The corpus now lives on disk; drop the in-memory copy.
n_passages = len(all_passages)
del all_passages
gc.collect()

print()
print(f"Saved {n_passages:,} passages to disk")
print("Cleared passages from memory")

# Report current RAM usage
import psutil
vm_stats = psutil.virtual_memory()
print(f"Memory: {vm_stats.used/1e9:.1f}/{vm_stats.total/1e9:.1f} GB ({vm_stats.percent}%)")

print()
print("Bond type distribution:")
grand_total = sum(bond_counts.values())
ranked = sorted(bond_counts.items(), key=lambda kv: -kv[1])
for bond_type, count in ranked:
    share = count / grand_total * 100
    print(f"  {bond_type:20s}: {count:6,} ({share:5.1f}%) {'#' * int(share)}")

mark_task("Extract bond structures", "done")


In [None]:
#@title 7. Generate Train/Test Splits { display-mode: "form" }
#@markdown Creates splits from saved files. Memory efficient - reads only IDs.

import os
import random
import gc
random.seed(42)

mark_task("Generate train/test splits", "running")

print("=" * 60)
print("GENERATING SPLITS (MEMORY EFFICIENT)")
print("=" * 60)
print()

# Read only IDs and time periods from disk - don't load full passages
print("Reading passage metadata from disk...")
passage_info = []  # List of (id, time_period) tuples - minimal memory

with open("data/processed/passages.jsonl", 'r') as f:
    for line in tqdm(f, desc="Reading IDs", unit="line"):
        p = json.loads(line)
        passage_info.append((p['id'], p['time_period']))

print(f"Loaded {len(passage_info):,} passage IDs")

# Define time periods.
# NOTE(review): only the Hebrew periods and DEAR_ABBY are assigned to a split.
# Passages tagged CONFUCIAN / DAOIST / MOHIST / QURANIC / HADITH (if present in
# passages.jsonl) are dropped from every split, including the mixed control.
train_periods = {'BIBLICAL', 'SECOND_TEMPLE', 'TANNAITIC', 'AMORAIC', 'GEONIC', 'RISHONIM'}
valid_periods = {'ACHRONIM'}
test_periods = {'MODERN_HEBREW', 'DEAR_ABBY'}

print()
print("Filtering by time period...")
ancient_ids = [(pid, tp) for pid, tp in passage_info if tp in train_periods]
early_modern_ids = [(pid, tp) for pid, tp in passage_info if tp in valid_periods]
modern_ids = [(pid, tp) for pid, tp in passage_info if tp in test_periods]

print(f"  Ancient/Medieval: {len(ancient_ids):,}")
print(f"  Early Modern:     {len(early_modern_ids):,}")
print(f"  Modern:           {len(modern_ids):,}")

# Shuffle (order of shuffles is fixed so splits are reproducible under seed 42)
random.shuffle(ancient_ids)
random.shuffle(early_modern_ids)
random.shuffle(modern_ids)

# ============================================================
# SPLIT A: ANCIENT -> MODERN
# ============================================================
print()
print("-" * 60)
print("SPLIT A: Train ANCIENT, Test MODERN")
print("-" * 60)

temporal_A = {
    'name': 'ancient_to_modern',
    'direction': 'A->M',
    'train_ids': [pid for pid, _ in ancient_ids],
    'valid_ids': [pid for pid, _ in early_modern_ids],
    'test_ids': [pid for pid, _ in modern_ids],
    'train_size': len(ancient_ids),
    'valid_size': len(early_modern_ids),
    'test_size': len(modern_ids)
}
print(f"  Train: {temporal_A['train_size']:,}")
print(f"  Valid: {temporal_A['valid_size']:,}")
print(f"  Test:  {temporal_A['test_size']:,}")

# ============================================================
# SPLIT B: MODERN -> ANCIENT
# ============================================================
print()
print("-" * 60)
print("SPLIT B: Train MODERN, Test ANCIENT")
print("-" * 60)

n_modern = len(modern_ids)
# Ancient test set: the (up to) n_modern ids after the first n_modern shuffled
# ancient ids. Slicing clamps at the end of the list, which is exactly what the
# former "if len >= 2n ... else ancient_ids[n:]" branch did by hand.
ancient_test = ancient_ids[n_modern:2 * n_modern]

temporal_B = {
    'name': 'modern_to_ancient',
    'direction': 'M->A',
    'train_ids': [pid for pid, _ in modern_ids],
    'valid_ids': [pid for pid, _ in early_modern_ids[:len(early_modern_ids)//2]],
    'test_ids': [pid for pid, _ in ancient_test],
    'train_size': len(modern_ids),
    'valid_size': len(early_modern_ids) // 2,
    'test_size': len(ancient_test)
}
print(f"  Train: {temporal_B['train_size']:,}")
print(f"  Valid: {temporal_B['valid_size']:,}")
print(f"  Test:  {temporal_B['test_size']:,}")

# ============================================================
# SPLIT C: MIXED CONTROL
# ============================================================
print()
print("-" * 60)
print("SPLIT C: MIXED (Control)")
print("-" * 60)

all_ids = ancient_ids + modern_ids
random.shuffle(all_ids)
n = len(all_ids)
n_train = int(0.7 * n)
n_valid = int(0.15 * n)

temporal_C = {
    'name': 'mixed_control',
    'direction': 'MIXED',
    'train_ids': [pid for pid, _ in all_ids[:n_train]],
    'valid_ids': [pid for pid, _ in all_ids[n_train:n_train+n_valid]],
    'test_ids': [pid for pid, _ in all_ids[n_train+n_valid:]],
    'train_size': n_train,
    'valid_size': n_valid,
    'test_size': n - n_train - n_valid
}
print(f"  Train: {temporal_C['train_size']:,}")
print(f"  Valid: {temporal_C['valid_size']:,}")
print(f"  Test:  {temporal_C['test_size']:,}")

# Clear temporary data
del passage_info, ancient_ids, early_modern_ids, modern_ids, all_ids
gc.collect()

# Save
print()
print("Saving splits...")
splits = {
    'ancient_to_modern': temporal_A,
    'modern_to_ancient': temporal_B,
    'mixed_control': temporal_C
}

# The splits directory is not guaranteed to exist yet; create it before writing.
os.makedirs("data/splits", exist_ok=True)
with open("data/splits/all_splits.json", 'w') as f:
    json.dump(splits, f, indent=2)

# Memory status
import psutil
mem = psutil.virtual_memory()
print(f"Memory: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

print()
print("SPLITS SAVED:")
print("  - ancient_to_modern (A->M)")
print("  - modern_to_ancient (M->A)")
print("  - mixed_control")





# ============================================================
# DISTRIBUTION CHECK - Catch problems early
# ============================================================
print()
print("=" * 60)
print("LABEL DISTRIBUTION CHECK")
print("=" * 60)

# Count Hohfeld labels and time periods. The bond-extraction cell writes
# bond_structures.jsonl and passages.jsonl in lockstep, so zipping the two
# files pairs rows describing the same passage.
hohfeld_counts = {}
time_counts = {}
with open("data/processed/bond_structures.jsonl", 'r') as fb, \
     open("data/processed/passages.jsonl", 'r') as fp:
    for b_line, p_line in zip(fb, fp):
        b = json.loads(b_line)
        p = json.loads(p_line)
        h = b['bond_structure'].get('hohfeld_state', None)
        t = p['time_period']
        hohfeld_counts[h] = hohfeld_counts.get(h, 0) + 1
        time_counts[t] = time_counts.get(t, 0) + 1

if not time_counts:
    # Every ratio below divides by these totals; fail with an actionable message
    # instead of a bare ZeroDivisionError.
    raise RuntimeError(
        "No rows found in data/processed/passages.jsonl / bond_structures.jsonl - "
        "run the 'Extract Bond Structures' cell first."
    )

print()
print("Hohfeld distribution:")
total_h = sum(hohfeld_counts.values())
for h, c in sorted(hohfeld_counts.items(), key=lambda x: -x[1]):
    pct = 100 * c / total_h
    bar = "#" * int(pct / 2)
    print(f"  {str(h):15s}: {c:>8,} ({pct:5.1f}%) {bar}")

print()
print("Time period distribution:")
total_t = sum(time_counts.values())
for t, c in sorted(time_counts.items(), key=lambda x: -x[1]):
    pct = 100 * c / total_t
    bar = "#" * int(pct / 2)
    print(f"  {t:15s}: {c:>8,} ({pct:5.1f}%) {bar}")

# Compute actual chance baselines (uniform over the classes observed; the None
# Hohfeld class is always counted because the model has a head slot for it).
N_HOHFELD_CLASSES = len([h for h in hohfeld_counts if h is not None]) + 1  # +1 for None
N_TIME_CLASSES = len(time_counts)
CHANCE_HOHFELD = 1.0 / N_HOHFELD_CLASSES
CHANCE_TIME = 1.0 / N_TIME_CLASSES
print()
print(f"Chance baseline - Hohfeld: {CHANCE_HOHFELD:.1%} ({N_HOHFELD_CLASSES} classes)")
print(f"Chance baseline - Time:    {CHANCE_TIME:.1%} ({N_TIME_CLASSES} classes)")

# Save baselines for later
baselines = {
    'hohfeld_counts': {str(k): v for k, v in hohfeld_counts.items()},
    'time_counts': time_counts,
    'chance_hohfeld': CHANCE_HOHFELD,
    'chance_time': CHANCE_TIME,
    'n_hohfeld_classes': N_HOHFELD_CLASSES,
    'n_time_classes': N_TIME_CLASSES
}
with open("data/splits/baselines.json", 'w') as f:
    json.dump(baselines, f, indent=2)

# Warn if severe imbalance
most_common_hohfeld = max(hohfeld_counts.values()) / total_h
if most_common_hohfeld > 0.7:
    print()
    print(f"WARNING: Hohfeld labels severely imbalanced! Most common = {most_common_hohfeld:.1%}")
    print("         Model may just predict majority class.")

# ============================================================
# SAVE PREPROCESSING TO DRIVE
# ============================================================
print()
print("=" * 60)
print("SAVING PREPROCESSED DATA TO GOOGLE DRIVE")
print("=" * 60)
import shutil
shutil.copytree("data/processed", f"{SAVE_DIR}/processed", dirs_exist_ok=True)
shutil.copytree("data/splits", f"{SAVE_DIR}/splits", dirs_exist_ok=True)
print(f"Saved to {SAVE_DIR}")
# f-prefix required: without it the recovery hint printed a literal "{SAVE_DIR}".
print(f"If session dies, run: !cp -r {SAVE_DIR}/* data/")
print("Then skip to Cell 8.")
print()

mark_task("Generate train/test splits", "done")



In [None]:
#@title 8. Define BIP Model Architecture { display-mode: "form" }
#@markdown Defines the model. Clears memory first to avoid OOM.

import gc
import psutil
# Imported before its first use below (torch.cuda.*) instead of after it, so the
# cell does not rely on an earlier cell having imported torch.
import torch

# CRITICAL: Clear memory before loading model
print("Clearing memory before model load...")
gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()

mem = psutil.virtual_memory()
print(f"Memory before model: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")
print()


import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader

print("=" * 60)
print("DEFINING MODEL ARCHITECTURE")
print("=" * 60)
print()
print("*** CROSS-CULTURAL MODE ***")
print("Encoder: paraphrase-multilingual-MiniLM-L12-v2")
print("  - Trained on 50+ languages including Hebrew and English")
print("  - Maps both languages into shared embedding space")
print("  - Sefaria passages: ORIGINAL HEBREW")
print("  - Dear Abby passages: ENGLISH")
print()
print("This is the STRONG test: Does Hebrew moral structure")
print("transfer to English with no translation intermediary?")
print()
print("=" * 60)
print()

class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; scales incoming gradients by ``-lambda_``.

    Placing this between a representation and a classifier makes the upstream
    layers be pushed *against* that classifier's objective (adversarial
    training).
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Stash the reversal strength for backward; hand back a copy of x so the
        # output is a distinct tensor from the input.
        ctx.scale = lambda_
        out = torch.clone(x)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output * -ctx.scale
        # The second forward input (lambda_) is a plain scalar: no gradient.
        return reversed_grad, None


def gradient_reversal(x, lambda_=1.0):
    """Apply :class:`GradientReversal` to ``x`` with reversal strength ``lambda_``."""
    return GradientReversal.apply(x, lambda_)

class BIPEncoder(nn.Module):
    """Sentence embeddings: attention-masked mean of a pretrained transformer's last layer.

    Args:
        model_name: checkpoint passed to ``AutoModel.from_pretrained``.
        d_model: stored as ``self.d_model``; not used inside this class.
            NOTE(review): presumably meant to equal the encoder's hidden size — confirm.
    """
    def __init__(self, model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", d_model=384):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.d_model = d_model

    def forward(self, input_ids, attention_mask):
        """Average token states over positions where ``attention_mask`` is 1."""
        token_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        weights = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(token_states * weights, dim=1)
        # Clamp keeps an all-padding row from dividing by zero.
        counts = torch.clamp(weights.sum(dim=1), min=1e-9)
        return summed / counts

class BIPModel(nn.Module):
    """Bond Invariance Principle model with adversarial disentanglement.

    A shared sentence encoder feeds two projection heads:

    * ``z_bond`` (``d_bond`` dims) drives the Hohfeld classifier and, through a
      gradient-reversal layer, a time classifier — so training pushes
      ``z_bond`` to carry as little time-period information as possible.
    * ``z_label`` (``d_label`` dims) drives an ordinary time classifier.

    Args:
        d_model: encoder output width.
        d_bond, d_label: widths of the two projections.
        n_periods: number of time-period classes.
        n_hohfeld: number of Hohfeld-state classes.
    """
    def __init__(self, d_model=384, d_bond=64, d_label=32, n_periods=14, n_hohfeld=4):
        super().__init__()

        self.encoder = BIPEncoder()

        hidden_width = d_model // 2

        def _projection(d_out):
            # Linear -> ReLU -> Dropout(0.1) -> Linear, from the encoder width.
            return nn.Sequential(
                nn.Linear(d_model, hidden_width),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_width, d_out),
            )

        # Construction order is kept fixed: it determines parameter init order
        # (RNG) and state_dict key order.
        self.bond_proj = _projection(d_bond)
        self.label_proj = _projection(d_label)

        self.time_classifier_bond = nn.Linear(d_bond, n_periods)
        self.time_classifier_label = nn.Linear(d_label, n_periods)
        self.hohfeld_classifier = nn.Linear(d_bond, n_hohfeld)

    def forward(self, input_ids, attention_mask, adversarial_lambda=1.0):
        """Return both embeddings and all three sets of logits.

        ``adversarial_lambda`` scales the reversed gradient reaching ``z_bond``
        from its time classifier (0 disables the adversarial signal).
        """
        pooled = self.encoder(input_ids, attention_mask)

        z_bond = self.bond_proj(pooled)
        z_label = self.label_proj(pooled)

        bond_for_time = gradient_reversal(z_bond, adversarial_lambda)

        results = {}
        results['z_bond'] = z_bond
        results['z_label'] = z_label
        results['time_pred_bond'] = self.time_classifier_bond(bond_for_time)
        results['time_pred_label'] = self.time_classifier_label(z_label)
        results['hohfeld_pred'] = self.hohfeld_classifier(z_bond)
        return results

# Time-period label -> class index used by the time classifiers. Index order is
# the list order below.
_TIME_PERIOD_ORDER = [
    # Hebrew
    'BIBLICAL', 'SECOND_TEMPLE', 'TANNAITIC', 'AMORAIC',
    'GEONIC', 'RISHONIM', 'ACHRONIM', 'MODERN_HEBREW',
    # Chinese
    'CONFUCIAN', 'DAOIST', 'MOHIST',
    # Arabic
    'QURANIC', 'HADITH',
    # Modern
    'DEAR_ABBY',
]
TIME_PERIOD_TO_IDX = {period: idx for idx, period in enumerate(_TIME_PERIOD_ORDER)}

# Hohfeldian state -> class index; None ("no pattern matched") is its own class.
HOHFELD_TO_IDX = {
    state: idx for idx, state in enumerate(['OBLIGATION', 'RIGHT', 'LIBERTY', None])
}

class MoralDataset(Dataset):
    """Tokenising dataset over the passages whose id is in ``passage_ids``.

    The constructor streams ``passages_file`` and ``bonds_file`` together line
    by line and keeps only the matching rows in memory, each as a small dict
    (text truncated to 1000 chars, time period, Hohfeld state); neither file is
    held in memory whole. Tokenisation is deferred to ``__getitem__``.

    The two JSONL files are zipped positionally, so line *i* of each must
    describe the same passage (the bond-extraction cell writes them in lockstep).

    Args:
        passage_ids: ids to keep. A set gives O(1) membership; other iterables
            are converted to a set.
        passages_file: JSONL rows with ``id``, ``time_period``, optional
            ``language`` and ``text_original`` / ``text_english``.
        bonds_file: JSONL rows carrying ``bond_structure.hohfeld_state``.
        tokenizer: Hugging Face tokenizer applied per item.
        max_len: pad/truncate length in tokens.
    """
    # Passages in these languages are encoded from their original script;
    # everything else (e.g. Dear Abby) uses the English text.
    _NATIVE_SCRIPT_LANGUAGES = ('hebrew', 'chinese', 'arabic')

    def __init__(self, passage_ids: set, passages_file: str, bonds_file: str, tokenizer, max_len=64):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.passage_ids = passage_ids
        wanted = passage_ids if isinstance(passage_ids, (set, frozenset)) else set(passage_ids)

        print(f"  Indexing {len(wanted):,} passages...")

        self.data = []  # per-sample dicts: text, time_period, hohfeld

        with open(passages_file, 'r') as f_pass, open(bonds_file, 'r') as f_bond:
            for p_line, b_line in tqdm(zip(f_pass, f_bond), desc="  Loading subset", unit="line", total=None):
                p = json.loads(p_line)
                if p['id'] not in wanted:
                    continue
                # Bond rows are only parsed for passages we keep.
                b = json.loads(b_line)
                if p.get('language') in self._NATIVE_SCRIPT_LANGUAGES:
                    text = p.get('text_original')
                else:
                    text = p.get('text_english')
                self.data.append({
                    # `or ''` turns a missing/None text into '' instead of
                    # crashing on None[:1000].
                    'text': (text or '')[:1000],
                    'time_period': p['time_period'],
                    'hohfeld': b['bond_structure']['hohfeld_state']
                })

        print(f"  Loaded {len(self.data):,} samples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenise sample ``idx`` and return tensors plus integer class labels."""
        item = self.data[idx]

        encoding = self.tokenizer(
            item['text'],
            truncation=True,
            max_length=self.max_len,
            padding='max_length',
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            # Unknown periods fall back to the modern DEAR_ABBY class.
            # NOTE(review): the former literal fallback 8 indexes CONFUCIAN in the
            # 14-period table, silently labelling unknown rows as Chinese; it
            # presumably predates the Chinese/Arabic expansion.
            'time_label': TIME_PERIOD_TO_IDX.get(item['time_period'], TIME_PERIOD_TO_IDX['DEAR_ABBY']),
            # Unknown/None Hohfeld states map to class 3 (the None class).
            'hohfeld_label': HOHFELD_TO_IDX.get(item['hohfeld'], 3)
        }

def collate_fn(batch):
    """Stack per-sample token tensors and gather integer labels into batch tensors."""
    def _stacked(key):
        return torch.stack([sample[key] for sample in batch])

    def _labels(key):
        return torch.tensor([sample[key] for sample in batch])

    return {
        'input_ids': _stacked('input_ids'),
        'attention_mask': _stacked('attention_mask'),
        'time_labels': _labels('time_label'),
        'hohfeld_labels': _labels('hohfeld_label'),
    }

# Collect garbage (and drop cached CUDA blocks) before reporting memory usage.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Model architecture defined!")
print()

# Report current RAM usage
import psutil
vm_stats = psutil.virtual_memory()
used_gb, total_gb = vm_stats.used / 1e9, vm_stats.total / 1e9
print(f"Memory: {used_gb:.1f}/{total_gb:.1f} GB ({vm_stats.percent}%)")


In [None]:
#@title 9. Train BIP Model - BIDIRECTIONAL { display-mode: "form" }
#@markdown Trains on BOTH directions for stronger invariance testing.

import gc
import shutil
import psutil

mark_task("Train BIP model", "running")

# Memory cleanup before training
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

mem = psutil.virtual_memory()
print(f"Memory at start: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB")

print("=" * 60)
print("BIDIRECTIONAL BIP TRAINING")
print("=" * 60)
print()
print(f"Accelerator: {ACCELERATOR}")
print(f"Device: {device}")
print()

# Uniform-chance references for the summaries, derived from the label tables
# the model heads are sized from (the former hard-coded "~11%" assumed 9 periods).
CHANCE_TIME_UNIFORM = 1.0 / len(TIME_PERIOD_TO_IDX)
CHANCE_HOHFELD_UNIFORM = 1.0 / len(HOHFELD_TO_IDX)

PASSAGES_FILE = "data/processed/passages.jsonl"
BONDS_FILE = "data/processed/bond_structures.jsonl"

# Load tokenizer once
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# Splits are identical for both directions, so read them once (keys match Cell 7).
with open("data/splits/all_splits.json", 'r') as f:
    splits = json.load(f)

# Store results for both directions
all_results = {}

for split_name in ['ancient_to_modern', 'modern_to_ancient']:
    # Release the previous direction's model, optimizer, data loaders (and their
    # persistent worker processes) BEFORE building the next ones, so two models
    # never coexist. The final direction's objects stay alive for later cells.
    model = optimizer = pbar = batch = outputs = loss = None
    train_loader = valid_loader = test_loader = None
    train_dataset = valid_dataset = test_dataset = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print()
    print("=" * 60)
    print(f"DIRECTION {split_name}: {'Ancient → Modern' if split_name == 'ancient_to_modern' else 'Modern → Ancient'}")
    print("=" * 60)
    print()

    split = splits[split_name]

    print(f"Train: {split['train_size']:,}")
    print(f"Valid: {split['valid_size']:,}")
    print(f"Test:  {split['test_size']:,}")
    print()

    # Create fresh model for each direction
    print("Creating fresh model...")
    model = BIPModel().to(device)

    # Compile model for speed (PyTorch 2.0+)
    if TORCH_COMPILE:
        print("Compiling model with torch.compile...")
        model = torch.compile(model, mode="reduce-overhead")

    if split_name == 'ancient_to_modern':
        print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create datasets
    print("Creating datasets...")
    train_dataset = MoralDataset(set(split['train_ids']), PASSAGES_FILE, BONDS_FILE, tokenizer)
    valid_dataset = MoralDataset(set(split['valid_ids']), PASSAGES_FILE, BONDS_FILE, tokenizer)
    test_dataset = MoralDataset(set(split['test_ids']), PASSAGES_FILE, BONDS_FILE, tokenizer)

    print(f"Train samples: {len(train_dataset):,}")
    print(f"Valid samples: {len(valid_dataset):,}")
    print(f"Test samples:  {len(test_dataset):,}")
    print()

    if len(train_dataset) == 0:
        print("ERROR: No training data!")
        continue

    # Direction A (large ancient corpus) uses 256; direction B is always 32
    # (the former max(32, min(32, n // 10)) could only ever evaluate to 32).
    batch_size = 256 if split_name == 'ancient_to_modern' else 32
    # drop_last=True on a training set smaller than one batch yields zero
    # batches (and a ZeroDivisionError below); keep the partial batch then.
    drop_last = len(train_dataset) >= batch_size

    loader_kwargs = dict(collate_fn=collate_fn, num_workers=4, pin_memory=True,
                         prefetch_factor=4, persistent_workers=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              drop_last=drop_last, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size*2, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, batch_size=batch_size*2, shuffle=False, **loader_kwargs)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

    # Same short budget for both directions; early stopping may end sooner.
    n_epochs = 3
    best_valid_loss = float('inf')
    patience = 3  # Early stopping patience
    patience_counter = 0
    model_path = f"models/checkpoints/best_model_{split_name}.pt"
    # Only reload a checkpoint written during THIS run; a stale file from an
    # earlier run must not be evaluated as if it were this model.
    checkpoint_saved = False

    print(f"Training for {n_epochs} epochs (batch_size={batch_size})...")
    print()

    for epoch in range(1, n_epochs + 1):
        model.train()
        total_loss = 0
        n_batches = 0

        pbar = tqdm(train_loader, desc=f"[{split_name}] Epoch {epoch}/{n_epochs}", unit="batch")
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            time_labels = batch['time_labels'].to(device)
            hohfeld_labels = batch['hohfeld_labels'].to(device)

            # Mixed precision forward pass
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                outputs = model(input_ids, attention_mask, adversarial_lambda=1.0)

                # Losses
                # STANDARD DANN: Cross-entropy on time classifier + GRL handles adversarial
                # DO NOT use entropy maximization - it double-reverses!
                loss_time_bond = F.cross_entropy(outputs['time_pred_bond'], time_labels)
                loss_time_label = F.cross_entropy(outputs['time_pred_label'], time_labels)
                loss_hohfeld = F.cross_entropy(outputs['hohfeld_pred'], hohfeld_labels)

            loss = loss_hohfeld + loss_time_label + loss_time_bond

            optimizer.zero_grad()
            if USE_TPU:
                # xm.optimizer_step performs the update itself; also calling
                # optimizer.step()/scaler.step() (as before) stepped twice per batch.
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                xm.optimizer_step(optimizer)
                xm.mark_step()
            elif USE_AMP and scaler is not None:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_train_loss = total_loss / n_batches if n_batches > 0 else float('nan')

        # Validation (model selection uses the Hohfeld loss only)
        model.eval()
        valid_loss = 0
        valid_batches = 0
        time_correct = 0
        time_total = 0
        hohfeld_correct = 0
        hohfeld_total = 0

        with torch.no_grad():
            for batch in valid_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                time_labels = batch['time_labels'].to(device)
                hohfeld_labels = batch['hohfeld_labels'].to(device)

                outputs = model(input_ids, attention_mask, adversarial_lambda=0)
                loss = F.cross_entropy(outputs['hohfeld_pred'], hohfeld_labels)
                valid_loss += loss.item()
                valid_batches += 1

                time_preds = outputs['time_pred_bond'].argmax(dim=-1)
                time_correct += (time_preds == time_labels).sum().item()
                time_total += len(time_labels)

                hohfeld_preds = outputs['hohfeld_pred'].argmax(dim=-1)
                hohfeld_correct += (hohfeld_preds == hohfeld_labels).sum().item()
                hohfeld_total += len(hohfeld_labels)

                if USE_TPU:
                    xm.mark_step()

        avg_valid_loss = valid_loss / valid_batches if valid_batches > 0 else 0
        time_acc = time_correct / time_total if time_total > 0 else 0
        hohfeld_acc_val = hohfeld_correct / hohfeld_total if hohfeld_total > 0 else 0

        print(f"[{split_name}] Epoch {epoch}: Loss={avg_train_loss:.4f}/{avg_valid_loss:.4f}, Hohfeld={hohfeld_acc_val:.1%}, TimeAcc={time_acc:.1%}")

        if avg_valid_loss < best_valid_loss:
            best_valid_loss = avg_valid_loss
            if USE_TPU:
                xm.save(model.state_dict(), model_path)
            else:
                torch.save(model.state_dict(), model_path)
            checkpoint_saved = True
            print(f"  -> Saved best model for {split_name}!")
            # Backup to Drive
            shutil.copy(model_path, f"{SAVE_DIR}/best_model_{split_name}.pt")
            print(f"  -> Backed up to Google Drive")
            patience_counter = 0  # Reset patience
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"  Early stopping at epoch {epoch} (no improvement for {patience} epochs)")
                break

    # Evaluate on test set
    print()
    print(f"Evaluating {split_name} on test set...")

    if checkpoint_saved:
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
    else:
        print("  WARNING: no checkpoint saved this run (validation loss never improved); evaluating final weights")
    model = model.to(device)
    model.eval()

    all_time_preds = []
    all_time_labels = []
    all_hohfeld_preds = []
    all_hohfeld_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"[{split_name}] Testing", unit="batch"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids, attention_mask, adversarial_lambda=0)

            all_time_preds.extend(outputs['time_pred_bond'].argmax(dim=-1).cpu().tolist())
            all_time_labels.extend(batch['time_labels'].tolist())
            all_hohfeld_preds.extend(outputs['hohfeld_pred'].argmax(dim=-1).cpu().tolist())
            all_hohfeld_labels.extend(batch['hohfeld_labels'].tolist())

            if USE_TPU:
                xm.mark_step()

    # Calculate metrics (NaN rather than ZeroDivisionError on an empty test set)
    n_test = len(all_time_preds)
    if n_test == 0:
        print(f"  WARNING: {split_name} test set is empty - accuracies reported as NaN")
        time_acc = hohfeld_acc = float('nan')
    else:
        time_acc = sum(p == l for p, l in zip(all_time_preds, all_time_labels)) / n_test
        hohfeld_acc = sum(p == l for p, l in zip(all_hohfeld_preds, all_hohfeld_labels)) / n_test

    all_results[split_name] = {
        'time_acc': time_acc,
        'hohfeld_acc': hohfeld_acc,
        'train_size': split['train_size'],
        'test_size': split['test_size']
    }

    print()
    print(f"{split_name.upper()} RESULTS:")
    print(f"  Time prediction from z_bond: {time_acc:.1%} (chance ~{CHANCE_TIME_UNIFORM:.0%})")
    print(f"  Hohfeld classification:      {hohfeld_acc:.1%} (chance {CHANCE_HOHFELD_UNIFORM:.0%})")

print()
print("=" * 60)
print("TRAINING COMPLETE - BOTH DIRECTIONS")
print("=" * 60)

mark_task("Train BIP model", "done")


In [None]:
#@title 10. Evaluate Bidirectional Results { display-mode: "form" }
#@markdown Compares results from BOTH directions to assess true invariance.

import gc
import json
import psutil

mark_task("Evaluate results", "running")

from collections import Counter
try:
    from sklearn.metrics import confusion_matrix, classification_report
    HAS_SKLEARN = True
except ImportError:
    HAS_SKLEARN = False
    print("sklearn not available - skipping confusion matrices")

mem = psutil.virtual_memory()
print(f"Memory at eval start: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB")

print("=" * 60)
print("BIDIRECTIONAL BIP RESULTS")
print("=" * 60)
print()

# Chance levels. The defaults assume 9 time periods and 4 Hohfeld states.
# They are overridden by data/splits/baselines.json when it is readable, and
# this happens BEFORE the invariance decisions below so that the decisions
# and the saved summary use the same values.
chance_time = 1/9  # 9 time periods
chance_hohfeld = 1/4  # 4 Hohfeld states
try:
    with open("data/splits/baselines.json", 'r') as f:
        _baselines = json.load(f)
    chance_time, chance_hohfeld = _baselines['chance_time'], _baselines['chance_hohfeld']
except (OSError, ValueError, KeyError, TypeError) as err:
    print(f"Using default chance levels (baselines unavailable: {err!r})")

TIME_TOLERANCE = 0.05     # |time_acc - chance_time| below this => "time invariant"
HOHFELD_THRESHOLD = 0.35  # hohfeld_acc above this => "moral structure"


def _report_direction(title, key, train_desc, test_desc):
    """Print one direction's metrics and return (time_near_chance, hohfeld_good).

    Reads ``all_results[key]`` (keys: train_size, test_size, time_acc,
    hohfeld_acc). A direction with no recorded results is reported as such
    and counted as failing both checks instead of raising KeyError.
    """
    res = all_results.get(key, {})
    print(title)
    print("-" * 40)
    if not res:
        print("  No results recorded for this direction.")
        print()
        return False, False
    print(f"  Trained on:    {res['train_size']:,} {train_desc} passages")
    print(f"  Tested on:     {res['test_size']:,} {test_desc} passages")
    print(f"  Time acc:      {res['time_acc']:.1%} (chance: {chance_time:.1%})")
    print(f"  Hohfeld acc:   {res['hohfeld_acc']:.1%} (chance: {chance_hohfeld:.1%})")
    print()
    # Time predicted from z_bond near chance => bond features are time-invariant.
    time_near_chance = abs(res['time_acc'] - chance_time) < TIME_TOLERANCE
    hohfeld_good = res['hohfeld_acc'] > HOHFELD_THRESHOLD
    print(f"  Time invariant?    {'YES ✓' if time_near_chance else 'NO ✗'}")
    print(f"  Moral structure?   {'YES ✓' if hohfeld_good else 'WEAK'}")
    print()
    return time_near_chance, hohfeld_good


A_time_near_chance, A_hohfeld_good = _report_direction(
    "DIRECTION A: Ancient → Modern", 'ancient_to_modern', "ancient", "modern")
B_time_near_chance, B_hohfeld_good = _report_direction(
    "DIRECTION B: Modern → Ancient", 'modern_to_ancient', "modern", "ancient")

print("=" * 60)
print("BIDIRECTIONAL INVARIANCE TEST")
print("=" * 60)
print()

# A direction passes when its time prediction is near chance AND its Hohfeld
# accuracy clears the threshold. Both passing beats A-only, which beats
# B-only; anything else is inconclusive.
_direction_a_ok = A_time_near_chance and A_hohfeld_good
_direction_b_ok = B_time_near_chance and B_hohfeld_good
if _direction_a_ok and _direction_b_ok:
    bip_result = "STRONGLY_SUPPORTED"
elif _direction_a_ok:
    bip_result = "SUPPORTED_UNIDIRECTIONAL"
elif _direction_b_ok:
    bip_result = "SUPPORTED_REVERSE_ONLY"
else:
    bip_result = "INCONCLUSIVE"

# Human-readable banner for each verdict code.
_VERDICT_BANNERS = {
    "STRONGLY_SUPPORTED": """
    ╔══════════════════════════════════════════════════════════╗
    ║                                                          ║
    ║     BIDIRECTIONAL BIP: STRONGLY SUPPORTED                ║
    ║                                                          ║
    ╠══════════════════════════════════════════════════════════╣
    ║                                                          ║
    ║  ✓ Ancient → Modern: Structure transfers                 ║
    ║  ✓ Modern → Ancient: Structure transfers                 ║
    ║  ✓ BOTH directions show time-invariant moral geometry    ║
    ║                                                          ║
    ║  This is STRONG evidence for universal moral structure.  ║
    ║                                                          ║
    ╚══════════════════════════════════════════════════════════╝
    """,
    "SUPPORTED_UNIDIRECTIONAL": """
    ╔══════════════════════════════════════════════════════════╗
    ║                                                          ║
    ║     BIP: SUPPORTED (Direction A only)                    ║
    ║                                                          ║
    ╠══════════════════════════════════════════════════════════╣
    ║                                                          ║
    ║  ✓ Ancient → Modern: Structure transfers                 ║
    ║  ? Modern → Ancient: Weaker or inconclusive              ║
    ║                                                          ║
    ║  Possible explanations:                                  ║
    ║  - Ancient corpus richer/more diverse                    ║
    ║  - Sample size imbalance                                 ║
    ║  - Asymmetric structure (still interesting)              ║
    ║                                                          ║
    ╚══════════════════════════════════════════════════════════╝
    """,
    "SUPPORTED_REVERSE_ONLY": """
    ╔══════════════════════════════════════════════════════════╗
    ║                                                          ║
    ║     BIP: SUPPORTED (Direction B only)                    ║
    ║                                                          ║
    ╠══════════════════════════════════════════════════════════╣
    ║                                                          ║
    ║  ? Ancient → Modern: Weaker or inconclusive              ║
    ║  ✓ Modern → Ancient: Structure transfers                 ║
    ║                                                          ║
    ║  Unexpected result - needs investigation.                ║
    ║                                                          ║
    ╚══════════════════════════════════════════════════════════╝
    """,
    "INCONCLUSIVE": """
    ╔══════════════════════════════════════════════════════════╗
    ║                                                          ║
    ║     BIP: INCONCLUSIVE                                    ║
    ║                                                          ║
    ╠══════════════════════════════════════════════════════════╣
    ║                                                          ║
    ║  Neither direction shows clear invariance.               ║
    ║                                                          ║
    ║  Possible issues:                                        ║
    ║  - Need more training epochs                             ║
    ║  - Need better bond extraction                           ║
    ║  - BIP may not hold (null result)                        ║
    ║                                                          ║
    ╚══════════════════════════════════════════════════════════╝
    """,
}
print(_VERDICT_BANNERS[bip_result])

import json
import os

# Record the chance levels from the baselines file when it is readable;
# otherwise keep the values already set earlier in this cell.
try:
    with open("data/splits/baselines.json", 'r') as f:
        baselines = json.load(f)
    chance_time = baselines['chance_time']
    chance_hohfeld = baselines['chance_hohfeld']
except (OSError, ValueError, KeyError, TypeError):
    pass  # Missing or malformed baselines file: keep the defaults from above

# Per-direction metrics (accuracies and split sizes) for post-mortem analysis.
detailed_results = {
    'ancient_to_modern': all_results.get('ancient_to_modern', {}),
    'modern_to_ancient': all_results.get('modern_to_ancient', {}),
}

# Save to Drive for post-mortem
with open(f"{SAVE_DIR}/detailed_results.json", 'w') as f:
    json.dump(detailed_results, f, indent=2)
print(f"Detailed results saved to {SAVE_DIR}/detailed_results.json")

# Summary: per-direction metrics, the pass/fail flags and the overall verdict.
results_summary = {
    'ancient_to_modern': all_results.get('ancient_to_modern', {}),
    'modern_to_ancient': all_results.get('modern_to_ancient', {}),
    'A_time_invariant': A_time_near_chance,
    'A_moral_structure': A_hohfeld_good,
    'B_time_invariant': B_time_near_chance,
    'B_moral_structure': B_hohfeld_good,
    'bip_result': bip_result,
    'chance_time': chance_time,
    'chance_hohfeld': chance_hohfeld
}

# Make sure results/ exists; the download cell's `mkdir -p results` only
# runs after this cell.
os.makedirs('results', exist_ok=True)
with open('results/bidirectional_results.json', 'w') as f:
    json.dump(results_summary, f, indent=2)

print()
print("Results saved to results/bidirectional_results.json")

mark_task("Evaluate results", "done")

print()
print("=" * 60)
print("EXPERIMENT COMPLETE")
print("=" * 60)
print_progress()


---

## Save Results

Run the cell below to download the trained model checkpoints (one per direction) and the evaluation results.

In [None]:
#@title 11. Download Results (Optional) { display-mode: "form" }
#@markdown Creates a zip file with the model checkpoints and metrics for both directions.

import glob
import json
import os
import shutil

# Everything in results/ goes into the archive (the evaluation cell already
# writes bidirectional_results.json there).
os.makedirs('results', exist_ok=True)

# Checkpoints are saved per direction as best_model_<split_name>.pt; there is
# no single best_model.pt to copy.
checkpoints = sorted(glob.glob('models/checkpoints/best_model_*.pt'))
if not checkpoints:
    print("WARNING: no checkpoints found in models/checkpoints/")
for ckpt in checkpoints:
    shutil.copy(ckpt, 'results/')

splits_path = 'data/splits/all_splits.json'
if os.path.exists(splits_path):
    shutil.copy(splits_path, 'results/')
else:
    print(f"WARNING: {splits_path} not found - not included in archive")

# Save metrics for BOTH directions. The loop variables time_acc / hohfeld_acc
# only hold the last split evaluated, so read everything from all_results and
# the verdict computed by the evaluation cell instead.
if all_results:
    metrics = {
        'accelerator': ACCELERATOR,
        'ancient_to_modern': all_results.get('ancient_to_modern', {}),
        'modern_to_ancient': all_results.get('modern_to_ancient', {}),
        'chance_time': chance_time,
        'chance_hohfeld': chance_hohfeld,
        'bip_result': bip_result,
        'bip_supported': bip_result != "INCONCLUSIVE",
    }
    with open('results/metrics.json', 'w') as f:
        json.dump(metrics, f, indent=2)

# Zip
shutil.make_archive('bip_results', 'zip', 'results')
print("Results saved to bip_results.zip")
print()
print("Contents:")
for name in sorted(os.listdir('results')):
    size = os.path.getsize(os.path.join('results', name))
    print(f"  {name:<40} {size:>14,} bytes")

# Download (only possible inside Colab)
try:
    from google.colab import files
except ImportError:
    print("Not running in Colab - download bip_results.zip manually.")
else:
    files.download('bip_results.zip')