# BIP Cross-Cultural Universal Morality Experiment (v6)

**Testing the Bond Invariance Principle across 2000+ years AND across languages (Hebrew + Chinese + Arabic → English)**

This experiment tests whether moral cognition has invariant structure by:
1. Training on ORIGINAL HEBREW texts (Sefaria corpus, ~500 BCE - 1800 CE)
2. Testing transfer to modern ENGLISH advice columns (Dear Abby, 1956-2020)

**Hypothesis**: If BIP holds, bond-level features should transfer across 2000 years with no accuracy drop.

---

## v6 Changes
- **Added bond transfer accuracy test**: Now predicts `primary_relation` from `z_bond` and reports F1 by corpus
- **Fixed bond extractor**: Added patterns for EMERGENCY/CONTRACT, added NONE class instead of CARE default
- **Fixed TPU double-stepping bug**
- **Fixed download cell filename**
- **Renamed MAX_SEFARIA_FILES for clarity**
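
The per-corpus F1 report from the first item can be sketched as follows. This is a minimal, dependency-free illustration rather than the notebook's evaluation code (the notebook installs scikit-learn for F1). The helper names `macro_f1` and `f1_by_corpus` are hypothetical, and the `(corpus, true_label, predicted_label)` record layout is an assumption.

```python
from collections import defaultdict

def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every label seen in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for label in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores) if scores else 0.0

def f1_by_corpus(records):
    """Group (corpus, true, pred) records by corpus and score each group separately."""
    grouped = defaultdict(lambda: ([], []))
    for corpus, true_label, pred_label in records:
        grouped[corpus][0].append(true_label)
        grouped[corpus][1].append(pred_label)
    return {corpus: macro_f1(t, p) for corpus, (t, p) in grouped.items()}

# Hypothetical toy predictions, only to show the call shape
records = [
    ("sefaria", "CARE", "CARE"),
    ("sefaria", "FAMILY", "FAMILY"),
    ("dear_abby", "CARE", "FAMILY"),
    ("dear_abby", "FAMILY", "FAMILY"),
]
print(f1_by_corpus(records))
```

Comparing the training-corpus score with the held-out-corpus score gives one rough view of how well bond labels transfer.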

---

## Setup Instructions
1. **Runtime -> Change runtime type -> GPU (T4) or TPU (v5e)**
2. Run cells in order - each shows progress in real-time
3. Expected runtime: ~1-2 hours (TPU) or ~2-4 hours (GPU)

**Supported Accelerators:**
- NVIDIA GPU (T4, V100, A100)
- Google TPU (v2, v3, v4, v5e)
- CPU (slow, not recommended)
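
As a rough sketch of how the setup cell chooses among these accelerators: CUDA is checked first, then the Colab TPU environment variable, and CPU is the fallback. The function name `pick_accelerator` is hypothetical. In the notebook the CUDA flag would come from `torch.cuda.is_available()`; it is passed in here so the sketch runs without PyTorch.

```python
import os

def pick_accelerator(cuda_available, env=None):
    """Return 'gpu', 'tpu', or 'cpu' using the priority GPU > TPU > CPU."""
    env = os.environ if env is None else env
    if cuda_available:
        return "gpu"
    if "COLAB_TPU_ADDR" in env:  # set by Colab TPU runtimes, as checked in the setup cell
        return "tpu"
    return "cpu"

# Assumed usage: pick_accelerator(torch.cuda.is_available())
print(pick_accelerator(cuda_available=False, env={}))
```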

---

In [None]:
#@title 1. Setup and Install Dependencies { display-mode: "form" }
#@markdown Installs packages and detects GPU/TPU. Memory-optimized for Colab.

print("=" * 60)
print("BIP TEMPORAL INVARIANCE EXPERIMENT (v6)")
print("=" * 60)
print()

# Progress tracker
TASKS = [
    "Install dependencies",
    "Clone Sefaria corpus (~8GB)",
    "Clone sqnd-probe repo (Dear Abby data)",
    "Preprocess corpora",
    "Extract bond structures",
    "Generate train/test splits",
    "Train BIP model",
    "Evaluate results"
]
task_status = {task: "pending" for task in TASKS}

def print_progress():
    print()
    print("-" * 50)
    print("EXPERIMENT PROGRESS:")
    print("-" * 50)
    for task in TASKS:
        status = task_status[task]
        if status == "done":
            mark = "[X]"
        elif status == "running":
            mark = "[>]"
        else:
            mark = "[ ]"
        print(f"  {mark} {task}")
    print("-" * 50)
    print(flush=True)

def mark_task(task, status):
    task_status[task] = status
    print_progress()

print_progress()

mark_task("Install dependencies", "running")

import os
import subprocess
import sys

# Install dependencies - MINIMAL set to save memory
print("Installing minimal dependencies...")
deps = [
    "transformers",
    "torch", 
    "sentence-transformers",
    "pandas",
    "tqdm",
    "psutil",
    "scikit-learn"  # For F1 score
]

for dep in deps:
    print(f"  Installing {dep}...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", dep])

print()

# Detect accelerator - WITHOUT importing tensorflow
USE_TPU = False
TPU_TYPE = None

# Check for TPU
if 'COLAB_TPU_ADDR' in os.environ:
    USE_TPU = True
    TPU_TYPE = "TPU (Colab)"
    print("TPU detected!")

# Check for GPU
import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    ACCELERATOR = f"GPU: {gpu_name} ({gpu_mem:.1f}GB)"
    device = torch.device("cuda")
elif USE_TPU:
    ACCELERATOR = TPU_TYPE
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
else:
    ACCELERATOR = "CPU (slow!)"
    device = torch.device("cpu")

print(f"Accelerator: {ACCELERATOR}")
print(f"Device: {device}")

# Memory status
import psutil
mem = psutil.virtual_memory()
print(f"System RAM: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

if torch.cuda.is_available():
    print(f"GPU RAM: {torch.cuda.memory_allocated()/1e9:.1f}/{torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

# Enable mixed precision for 2-3x speedup
if torch.cuda.is_available():
    print()
    print("Enabling mixed precision (FP16) for faster training...")
    from torch.cuda.amp import autocast, GradScaler
    USE_AMP = True
    scaler = GradScaler()
else:
    USE_AMP = False
    scaler = None

# torch.compile for PyTorch 2.0+ (10-30% speedup)
TORCH_COMPILE = False
if hasattr(torch, 'compile') and torch.cuda.is_available():
    print("PyTorch 2.0+ detected - torch.compile available")
    TORCH_COMPILE = False  # Disabled - overhead > benefit for short runs



# ============================================================
# GOOGLE DRIVE - SAVE RESULTS EVEN IF SESSION DIES
# ============================================================
print()
print("=" * 60)
print("MOUNTING GOOGLE DRIVE FOR PERSISTENT STORAGE")
print("=" * 60)
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = '/content/drive/MyDrive/BIP_results'
os.makedirs(SAVE_DIR, exist_ok=True)

# Create local directories
os.makedirs("data/processed", exist_ok=True)
os.makedirs("data/splits", exist_ok=True)
os.makedirs("models/checkpoints", exist_ok=True)
os.makedirs("results", exist_ok=True)
print(f"Results will be saved to: {SAVE_DIR}")
print("If session crashes, your data survives.")
print()

mark_task("Install dependencies", "done")



In [None]:
#@title 2. Download Sefaria Corpus (~8GB) { display-mode: "form" }
#@markdown Downloads the complete Sefaria corpus with real-time git progress.

import subprocess
import sys

mark_task("Clone Sefaria corpus (~8GB)", "running")

sefaria_path = 'data/raw/Sefaria-Export'

if not os.path.exists(sefaria_path) or not os.path.exists(f"{sefaria_path}/json"):
    print("="*60)
    print("CLONING SEFARIA CORPUS")
    print("="*60)
    print()
    print("This downloads ~3.5GB and takes 5-15 minutes.")
    print("Git's native progress will display below:")
    print("-"*60)
    print(flush=True)
    
    # Use subprocess.Popen for real-time output streaming
    process = subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/Sefaria/Sefaria-Export.git', sefaria_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # Git writes progress to stderr
        text=True,
        bufsize=1
    )
    
    # Stream output in real-time
    for line in process.stdout:
        print(line, end='', flush=True)
    
    process.wait()
    
    print("-"*60)
    if process.returncode == 0:
        print("\nSefaria clone COMPLETE!")
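    else:
        print(f"\nERROR: Git clone failed with code {process.returncode}")
        print("Try running this cell again, or check your internet connection.")
        # Stop here so the task is not marked done after a failed clone
        raise RuntimeError(f"git clone failed with exit code {process.returncode}")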
else:
    print("Sefaria already exists, skipping download.")

# Verify and count files
print()
print("Verifying download...")
!du -sh {sefaria_path} 2>/dev/null || echo "Directory not found"
json_count = !find {sefaria_path}/json -name "*.json" 2>/dev/null | wc -l
print(f"Sefaria JSON files found: {json_count[0]}")

mark_task("Clone Sefaria corpus (~8GB)", "done")
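
The verification step above relies on IPython shell magics (`!du`, `!find`), which only work inside a notebook. A portable alternative, sketched under the assumption that the export keeps its files under a `json/` subdirectory as the cell expects, could use `pathlib`. The helper name `count_json_files` is hypothetical.

```python
from pathlib import Path

def count_json_files(root):
    """Count *.json files under <root>/json; return 0 if the directory is missing."""
    json_dir = Path(root) / "json"
    if not json_dir.is_dir():
        return 0
    return sum(1 for _ in json_dir.rglob("*.json"))

# Same path as the clone target used in this cell
print(count_json_files("data/raw/Sefaria-Export"))
```

A count of 0 after a supposedly successful clone would suggest the checkout is incomplete.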


In [None]:
#@title 3. Download Dear Abby Dataset { display-mode: "form" }
#@markdown Downloads the Dear Abby advice column dataset (68,330 entries).

import pandas as pd
from pathlib import Path

mark_task("Clone sqnd-probe repo (Dear Abby data)", "running")

sqnd_path = 'sqnd-probe-data'
if not os.path.exists(sqnd_path):
    print("Cloning sqnd-probe repo...")
    process = subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/ahb-sjsu/sqnd-probe.git', sqnd_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    )
    for line in process.stdout:
        print(line, end='', flush=True)
    process.wait()
else:
    print("Repo already cloned.")

# Copy Dear Abby data
dear_abby_source = Path('sqnd-probe-data/dear_abby_data/raw_da_qs.csv')
dear_abby_path = Path('data/raw/dear_abby.csv')

if dear_abby_source.exists():
    !cp "{dear_abby_source}" "{dear_abby_path}"
    print(f"\nCopied Dear Abby data")
elif not dear_abby_path.exists():
    raise FileNotFoundError("Dear Abby dataset not found!")

# Verify
df_check = pd.read_csv(dear_abby_path)
print(f"\n" + "=" * 50)
print(f"Dear Abby dataset: {len(df_check):,} entries")
print(f"Columns: {list(df_check.columns)}")
print(f"Year range: {df_check['year'].min():.0f} - {df_check['year'].max():.0f}")
print("=" * 50)

mark_task("Clone sqnd-probe repo (Dear Abby data)", "done")

In [None]:
#@title 4. Define Data Classes and Loaders { display-mode: "form" }
#@markdown Defines enums, dataclasses, and corpus loaders.
#@markdown **v6 FIX**: Added NONE bond type, patterns for EMERGENCY/CONTRACT.

import json
import hashlib
import re
from dataclasses import dataclass, field, asdict
from typing import List, Dict
from enum import Enum
from collections import defaultdict
from tqdm.auto import tqdm

print("Defining data structures...")

class TimePeriod(Enum):
    BIBLICAL = 0        # ~1000-500 BCE
    SECOND_TEMPLE = 1   # ~500 BCE - 70 CE
    TANNAITIC = 2       # ~70-200 CE
    AMORAIC = 3         # ~200-500 CE
    GEONIC = 4          # ~600-1000 CE
    RISHONIM = 5        # ~1000-1500 CE
    ACHRONIM = 6        # ~1500-1800 CE
    MODERN_HEBREW = 7   # ~1800-present
    DEAR_ABBY = 8       # 1956-2020

# v6 FIX: Added NONE as explicit class (index 10) instead of defaulting to CARE
class BondType(Enum):
    HARM_PREVENTION = 0
    RECIPROCITY = 1
    AUTONOMY = 2
    PROPERTY = 3
    FAMILY = 4
    AUTHORITY = 5
    EMERGENCY = 6
    CONTRACT = 7
    CARE = 8
    FAIRNESS = 9
    NONE = 10  # v6: Explicit NONE class instead of sink to CARE

class HohfeldianState(Enum):
    RIGHT = 0
    OBLIGATION = 1
    LIBERTY = 2
    NO_RIGHT = 3

@dataclass
class Passage:
    id: str
    text_original: str
    text_english: str
    time_period: str
    century: int
    source: str
    source_type: str
    category: str
    language: str = "hebrew"
    word_count: int = 0
    has_dispute: bool = False
    consensus_tier: str = "unknown"
    bond_types: List[str] = field(default_factory=list)
    
    def to_dict(self):
        return asdict(self)

CATEGORY_TO_PERIOD = {
    'Tanakh': TimePeriod.BIBLICAL,
    'Torah': TimePeriod.BIBLICAL,
    'Mishnah': TimePeriod.TANNAITIC,
    'Tosefta': TimePeriod.TANNAITIC,
    'Talmud': TimePeriod.AMORAIC,
    'Bavli': TimePeriod.AMORAIC,
    'Midrash': TimePeriod.AMORAIC,
    'Halakhah': TimePeriod.RISHONIM,
    'Chasidut': TimePeriod.ACHRONIM,
}

PERIOD_TO_CENTURY = {
    TimePeriod.BIBLICAL: -6,
    TimePeriod.SECOND_TEMPLE: -2,
    TimePeriod.TANNAITIC: 2,
    TimePeriod.AMORAIC: 4,
    TimePeriod.GEONIC: 8,
    TimePeriod.RISHONIM: 12,
    TimePeriod.ACHRONIM: 17,
    TimePeriod.MODERN_HEBREW: 20,
}

def load_sefaria(base_path: str, max_files: int = None) -> List[Passage]:
    """Load Sefaria corpus with progress bar.
    
    Args:
        base_path: Path to Sefaria-Export directory
        max_files: Maximum number of JSON files to process (NOT passages)
                   v6 FIX: Renamed from max_passages for clarity
    """
    passages = []
    json_path = Path(base_path) / "json"
    
    if not json_path.exists():
        print(f"Warning: {json_path} not found")
        return []
    
    json_files = list(json_path.rglob("*.json"))
    print(f"Found {len(json_files):,} JSON files...")
    
    # v6 FIX: Clarify that we're limiting FILES, not passages
    files_to_process = json_files[:max_files] if max_files else json_files
    print(f"Processing {len(files_to_process):,} files (max_files={max_files})...")
    
    for json_file in tqdm(files_to_process, desc="Loading Sefaria", unit="file"):
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
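        except (OSError, UnicodeDecodeError, json.JSONDecodeError):
            # Skip unreadable or malformed files without hiding unrelated errors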
            continue
        
        rel_path = json_file.relative_to(json_path)
        category = str(rel_path.parts[0]) if rel_path.parts else "unknown"
        # Categories missing from CATEGORY_TO_PERIOD fall back to AMORAIC
        time_period = CATEGORY_TO_PERIOD.get(category, TimePeriod.AMORAIC)
        century = PERIOD_TO_CENTURY.get(time_period, 0)
        
        if isinstance(data, dict):
            hebrew = data.get('he', data.get('text', []))
            english = data.get('text', data.get('en', []))
            
            def flatten(h, e, ref=""):
                if isinstance(h, str) and isinstance(e, str):
                    h_clean = re.sub(r'<[^>]+>', '', h).strip()
                    e_clean = re.sub(r'<[^>]+>', '', e).strip()
                    if 50 <= len(e_clean) <= 2000:
                        pid = hashlib.md5(f"{json_file.stem}:{ref}:{h_clean[:50]}".encode()).hexdigest()[:12]
                        return [Passage(
                            id=f"sefaria_{pid}",
                            text_original=h_clean,
                            text_english=e_clean,
                            time_period=time_period.name,
                            century=century,
                            source=f"{json_file.stem} {ref}".strip(),
                            source_type="sefaria",
                            category=category,
                            language="hebrew",
                            word_count=len(e_clean.split())
                        )]
                    return []
                elif isinstance(h, list) and isinstance(e, list):
                    result = []
                    for i, (hh, ee) in enumerate(zip(h, e)):
                        result.extend(flatten(hh, ee, f"{ref}.{i+1}" if ref else str(i+1)))
                    return result
                return []
            
            passages.extend(flatten(hebrew, english))
    
    return passages

def load_dear_abby(path: str, max_passages: int = None) -> List[Passage]:
    """Load Dear Abby corpus with progress bar."""
    passages = []
    df = pd.read_csv(path)
    
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading Dear Abby", unit="row"):
        question = str(row.get('question_only', ''))
        if not question or question == 'nan' or len(question) < 50 or len(question) > 2000:
            continue
        
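        # Fall back to 1990 when the year is missing (int(NaN) would raise)
        year_raw = row.get('year', 1990)
        year = int(year_raw) if pd.notna(year_raw) else 1990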
        pid = hashlib.md5(f"abby:{idx}:{question[:50]}".encode()).hexdigest()[:12]
        
        passages.append(Passage(
            id=f"abby_{pid}",
            text_original=question,
            text_english=question,
            time_period=TimePeriod.DEAR_ABBY.name,
            century=20 if year < 2000 else 21,
            source=f"Dear Abby {year}",
            source_type="dear_abby",
            category="general",
            language="english",
            word_count=len(question.split())
        ))
        
        if max_passages and len(passages) >= max_passages:
            break
    
    return passages

print("Data structures defined!")
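
The nested `flatten` helper inside `load_sefaria` walks two parallel nested lists and builds dotted, 1-based references such as `2.1`. A stripped-down sketch of just that pairing logic follows. It omits the HTML stripping, the length filter, and the `Passage` construction, and `flatten_pairs` is a hypothetical name.

```python
def flatten_pairs(hebrew, english, ref=""):
    """Pair parallel nested lists and label each leaf with a dotted 1-based ref."""
    if isinstance(hebrew, str) and isinstance(english, str):
        return [(ref, hebrew, english)]
    if isinstance(hebrew, list) and isinstance(english, list):
        pairs = []
        # zip stops at the shorter list, so unpaired trailing items are dropped
        for i, (h, e) in enumerate(zip(hebrew, english), start=1):
            child_ref = f"{ref}.{i}" if ref else str(i)
            pairs.extend(flatten_pairs(h, e, child_ref))
        return pairs
    return []  # shape mismatch (for example str vs list) is skipped

# Placeholder strings stand in for real passage text
print(flatten_pairs([["a", "b"], ["c"]], [["A", "B"], ["C"]]))
```

Because `zip` truncates to the shorter side, a text whose two versions are segmented differently silently loses its unmatched segments.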

In [None]:
#@title 5. Load and Preprocess Corpora { display-mode: "form" }
#@markdown Loads both corpora.
#@markdown **v6 FIX**: Renamed parameter to MAX_SEFARIA_FILES for clarity.

#@markdown **Memory Management:**
MAX_SEFARIA_FILES = 5000  #@param {type:"integer"}
#@markdown **NOTE**: This limits JSON FILES processed, not total passages.
#@markdown Each file may contain multiple passages. Set to 0 for unlimited.

import gc

mark_task("Preprocess corpora", "running")

print("=" * 60)
print("LOADING CORPORA")
print("=" * 60)
print()

if MAX_SEFARIA_FILES > 0:
    print(f"*** MEMORY MODE: Limited to {MAX_SEFARIA_FILES:,} Sefaria JSON FILES ***")
    print("(Each file may yield multiple passages)")
    print()

# Load Sefaria with optional limit
# v6 FIX: Renamed parameter for clarity
limit = MAX_SEFARIA_FILES if MAX_SEFARIA_FILES > 0 else None
sefaria_passages = load_sefaria("data/raw/Sefaria-Export", max_files=limit)
print(f"\nSefaria passages loaded: {len(sefaria_passages):,}")

# Force garbage collection
gc.collect()

# Load Dear Abby
print()
abby_passages = load_dear_abby("data/raw/dear_abby.csv")
print(f"\nDear Abby passages loaded: {len(abby_passages):,}")

# Combine
all_passages = sefaria_passages + abby_passages

# Clear individual lists to save memory
del sefaria_passages
del abby_passages
gc.collect()

print()
print("=" * 60)
print(f"TOTAL PASSAGES: {len(all_passages):,}")
print("=" * 60)

# Statistics
by_period = defaultdict(int)
by_source = defaultdict(int)
for p in all_passages:
    by_period[p.time_period] += 1
    by_source[p.source_type] += 1

print("\nBy source:")
for source, count in sorted(by_source.items()):
    print(f"  {source}: {count:,}")

print("\nBy time period:")
for period, count in sorted(by_period.items()):
    pct = count / len(all_passages) * 100
    bar = '#' * int(pct / 2)
    print(f"  {period:20s}: {count:6,} ({pct:5.1f}%) {bar}")

# Memory status
import psutil
mem = psutil.virtual_memory()
print(f"\nMemory: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

mark_task("Preprocess corpora", "done")


In [None]:
#@title 6. Extract Bond Structures { display-mode: "form" }
#@markdown Extracts moral bond structures. Streams to disk to save memory.
#@markdown **v6 FIX**: Added EMERGENCY/CONTRACT patterns, NONE instead of CARE default.

import gc

mark_task("Extract bond structures", "running")

# v6 FIX: Added patterns for EMERGENCY and CONTRACT
RELATION_PATTERNS = {
    BondType.HARM_PREVENTION: [r'\b(kill|murder|harm|hurt|save|rescue|protect|danger|attack|injure|wound|destroy)\b'],
    BondType.RECIPROCITY: [r'\b(return|repay|owe|debt|mutual|exchange|give back|pay back|reciprocate)\b'],
    BondType.AUTONOMY: [r'\b(choose|decision|consent|agree|force|coerce|right|freedom|liberty|self-determination)\b'],
    BondType.PROPERTY: [r'\b(property|own|steal|theft|buy|sell|land|possess|belong|asset)\b'],
    BondType.FAMILY: [r'\b(honor|parent|marry|divorce|inherit|family|mother|father|child|son|daughter|spouse|husband|wife)\b'],
    BondType.AUTHORITY: [r'\b(obey|command|law|judge|rule|teach|leader|king|master|servant|subject)\b'],
    BondType.CARE: [r'\b(care|help|assist|feed|clothe|visit|nurture|tend|support|comfort)\b'],
    BondType.FAIRNESS: [r'\b(fair|just|equal|deserve|bias|impartial|equity|discrimination)\b'],
    # v6 FIX: Added EMERGENCY patterns
    BondType.EMERGENCY: [r'\b(emergency|urgent|crisis|danger|life-threatening|immediate|desperate|dire|peril|rescue)\b'],
    # v6 FIX: Added CONTRACT patterns  
    BondType.CONTRACT: [r'\b(contract|agreement|promise|vow|oath|covenant|pledge|commit|bind|treaty|negotiate)\b'],
}

HOHFELD_PATTERNS = {
    HohfeldianState.OBLIGATION: [r'\b(must|shall|duty|require|should|ought|obligated)\b'],
    HohfeldianState.RIGHT: [r'\b(right to|entitled|deserve|claim|due)\b'],
    HohfeldianState.LIBERTY: [r'\b(may|can|permitted|allowed|free to|at liberty)\b'],
}

def extract_bond_structure(passage: Passage) -> Dict:
    """Extract bond structure from passage.
    
    v6 FIX: Now defaults to NONE instead of CARE when no patterns match.
    """
    text = passage.text_english.lower()
    
    relations = []
    for rel_type, patterns in RELATION_PATTERNS.items():
        for pattern in patterns:
            if re.search(pattern, text, re.IGNORECASE):
                relations.append(rel_type.name)
                break
    
    # v6 FIX: Default to NONE instead of CARE to avoid biased sink label
    if not relations:
        relations = ['NONE']
    
    hohfeld = None
    for state, patterns in HOHFELD_PATTERNS.items():
        for pattern in patterns:
            if re.search(pattern, text, re.IGNORECASE):
                hohfeld = state.name
                break
        if hohfeld:
            break
    
    signature = "|".join(sorted(set(relations)))
    
    return {
        'bonds': [{'relation': r} for r in relations],
        # primary_relation is the first match in RELATION_PATTERNS order, not the strongest match
        'primary_relation': relations[0],
        'hohfeld_state': hohfeld,
        'signature': signature
    }

print("=" * 60)
print("EXTRACTING & SAVING (STREAMING)")
print("=" * 60)
print()
print("Writing directly to disk to conserve memory...")
print()

bond_counts = defaultdict(int)

# Stream directly to files - don't accumulate in memory
with open("data/processed/passages.jsonl", 'w') as f_pass, \
     open("data/processed/bond_structures.jsonl", 'w') as f_bond:
    
    for passage in tqdm(all_passages, desc="Processing", unit="passage"):
        # Extract bonds
        bond_struct = extract_bond_structure(passage)
        passage.bond_types = [b['relation'] for b in bond_struct['bonds']]
        
        # Count for stats
        for bond in bond_struct['bonds']:
            bond_counts[bond['relation']] += 1
        
        # Write immediately (don't accumulate)
        f_pass.write(json.dumps(passage.to_dict()) + '\n')
        f_bond.write(json.dumps({
            'passage_id': passage.id,
            'bond_structure': bond_struct
        }) + '\n')

# Clear passages from memory - we've saved them to disk
n_passages = len(all_passages)
del all_passages
gc.collect()

print()
print(f"Saved {n_passages:,} passages to disk")
print("Cleared passages from memory")

# Memory status
import psutil
mem = psutil.virtual_memory()
print(f"Memory: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

print()
print("Bond type distribution:")
for bond_type, count in sorted(bond_counts.items(), key=lambda x: -x[1]):
    pct = count / sum(bond_counts.values()) * 100
    bar = '#' * int(pct)
    print(f"  {bond_type:20s}: {count:6,} ({pct:5.1f}%) {bar}")

# v6: Warn if NONE is dominant (may indicate patterns need expansion)
if bond_counts.get('NONE', 0) / sum(bond_counts.values()) > 0.5:
    print()
    print("WARNING: NONE class is >50% - consider expanding patterns.")

mark_task("Extract bond structures", "done")
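
To make the extractor's behavior concrete, here is a minimal sketch using a reduced subset of the patterns above. The `PATTERNS` table and `extract_relations` are hypothetical simplifications for illustration, not the notebook's full `RELATION_PATTERNS` or `extract_bond_structure`.

```python
import re

# Reduced, illustrative subset of the relation patterns (insertion order matters)
PATTERNS = {
    "RECIPROCITY": r"\b(return|repay|owe|debt)\b",
    "FAMILY": r"\b(parent|mother|father|husband|wife)\b",
    "CONTRACT": r"\b(contract|promise|vow|oath)\b",
}

def extract_relations(text):
    """Return every matching relation in table order, or ['NONE'] if nothing matches."""
    text = text.lower()
    matches = [name for name, pattern in PATTERNS.items() if re.search(pattern, text)]
    return matches or ["NONE"]

for sentence in [
    "My husband broke his promise to me.",
    "The weather was pleasant.",
    "She promised to help.",
]:
    print(sentence, "->", extract_relations(sentence))
```

Two properties are worth noting. The first element of the result, which plays the role of `primary_relation`, depends on table order rather than on which bond is most salient. Also, the `\b` word boundaries match whole words only, so an inflected form such as "promised" does not match `promise`.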


In [None]:
#@title 7. Generate Train/Test Splits { display-mode: "form" }
#@markdown Creates splits from saved files. Memory efficient - reads only IDs.

import random
import gc
random.seed(42)

mark_task("Generate train/test splits", "running")

print("=" * 60)
print("GENERATING SPLITS (MEMORY EFFICIENT)")
print("=" * 60)
print()

# Read only IDs and time periods from disk - don't load full passages
print("Reading passage metadata from disk...")
passage_info = []  # List of (id, time_period) tuples - minimal memory

with open("data/processed/passages.jsonl", 'r') as f:
    for line in tqdm(f, desc="Reading IDs", unit="line"):
        p = json.loads(line)
        passage_info.append((p['id'], p['time_period']))

print(f"Loaded {len(passage_info):,} passage IDs")

# Define time periods
train_periods = {'BIBLICAL', 'SECOND_TEMPLE', 'TANNAITIC', 'AMORAIC', 'GEONIC', 'RISHONIM'}
valid_periods = {'ACHRONIM'}
test_periods = {'MODERN_HEBREW', 'DEAR_ABBY'}

print()
print("Filtering by time period...")
ancient_ids = [(pid, tp) for pid, tp in passage_info if tp in train_periods]
early_modern_ids = [(pid, tp) for pid, tp in passage_info if tp in valid_periods]
modern_ids = [(pid, tp) for pid, tp in passage_info if tp in test_periods]

print(f"  Ancient/Medieval: {len(ancient_ids):,}")
print(f"  Early Modern:     {len(early_modern_ids):,}")
print(f"  Modern:           {len(modern_ids):,}")

# Shuffle
random.shuffle(ancient_ids)
random.shuffle(early_modern_ids)
random.shuffle(modern_ids)

# ============================================================
# SPLIT A: ANCIENT -> MODERN
# ============================================================
print()
print("-" * 60)
print("SPLIT A: Train ANCIENT, Test MODERN")
print("-" * 60)

temporal_A = {
    'name': 'ancient_to_modern',
    'direction': 'A->M',
    'train_ids': [pid for pid, _ in ancient_ids],
    'valid_ids': [pid for pid, _ in early_modern_ids],
    'test_ids': [pid for pid, _ in modern_ids],
    'train_size': len(ancient_ids),
    'valid_size': len(early_modern_ids),
    'test_size': len(modern_ids)
}
print(f"  Train: {temporal_A['train_size']:,}")
print(f"  Valid: {temporal_A['valid_size']:,}")
print(f"  Test:  {temporal_A['test_size']:,}")

# ============================================================
# SPLIT B: MODERN -> ANCIENT
# ============================================================
print()
print("-" * 60)
print("SPLIT B: Train MODERN, Test ANCIENT")
print("-" * 60)

n_modern = len(modern_ids)
# Slicing clamps at the end of the list, so no explicit length check is needed
ancient_test = ancient_ids[n_modern:n_modern * 2]

temporal_B = {
    'name': 'modern_to_ancient',
    'direction': 'M->A',
    'train_ids': [pid for pid, _ in modern_ids],
    'valid_ids': [pid for pid, _ in early_modern_ids[:len(early_modern_ids)//2]],
    'test_ids': [pid for pid, _ in ancient_test],
    'train_size': len(modern_ids),
    'valid_size': len(early_modern_ids) // 2,
    'test_size': len(ancient_test)
}
print(f"  Train: {temporal_B['train_size']:,}")
print(f"  Valid: {temporal_B['valid_size']:,}")
print(f"  Test:  {temporal_B['test_size']:,}")

# ============================================================
# SPLIT C: MIXED CONTROL
# ============================================================
print()
print("-" * 60)
print("SPLIT C: MIXED (Control)")
print("-" * 60)

all_ids = ancient_ids + modern_ids
random.shuffle(all_ids)
n = len(all_ids)
n_train = int(0.7 * n)
n_valid = int(0.15 * n)

temporal_C = {
    'name': 'mixed_control',
    'direction': 'MIXED',
    'train_ids': [pid for pid, _ in all_ids[:n_train]],
    'valid_ids': [pid for pid, _ in all_ids[n_train:n_train+n_valid]],
    'test_ids': [pid for pid, _ in all_ids[n_train+n_valid:]],
    'train_size': n_train,
    'valid_size': n_valid,
    'test_size': n - n_train - n_valid
}
print(f"  Train: {temporal_C['train_size']:,}")
print(f"  Valid: {temporal_C['valid_size']:,}")
print(f"  Test:  {temporal_C['test_size']:,}")

# Clear temporary data
del passage_info, ancient_ids, early_modern_ids, modern_ids, all_ids
gc.collect()

# Save
print()
print("Saving splits...")
splits = {
    'ancient_to_modern': temporal_A,
    'modern_to_ancient': temporal_B,
    'mixed_control': temporal_C
}

with open("data/splits/all_splits.json", 'w') as f:
    json.dump(splits, f, indent=2)

# Memory status
import psutil
mem = psutil.virtual_memory()
print(f"Memory: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

print()
print("SPLITS SAVED:")
print("  - ancient_to_modern (A->M)")
print("  - modern_to_ancient (M->A)")  
print("  - mixed_control")





# ============================================================
# DISTRIBUTION CHECK - Catch problems early
# ============================================================
print()
print("=" * 60)
print("LABEL DISTRIBUTION CHECK")
print("=" * 60)

# Count Hohfeld labels AND primary_relation labels
hohfeld_counts = {}
time_counts = {}
bond_type_counts = {}  # v6: Track bond types for transfer test

with open("data/processed/bond_structures.jsonl", 'r') as fb, \
     open("data/processed/passages.jsonl", 'r') as fp:
    for b_line, p_line in zip(fb, fp):
        b = json.loads(b_line)
        p = json.loads(p_line)
        h = b['bond_structure'].get('hohfeld_state', None)
        t = p['time_period']
        bond = b['bond_structure'].get('primary_relation', 'NONE')
        
        hohfeld_counts[h] = hohfeld_counts.get(h, 0) + 1
        time_counts[t] = time_counts.get(t, 0) + 1
        bond_type_counts[bond] = bond_type_counts.get(bond, 0) + 1

print()
print("Hohfeld distribution:")
total_h = sum(hohfeld_counts.values())
for h, c in sorted(hohfeld_counts.items(), key=lambda x: -x[1]):
    pct = 100 * c / total_h
    bar = "#" * int(pct / 2)
    print(f"  {str(h):15s}: {c:>8,} ({pct:5.1f}%) {bar}")

print()
print("Time period distribution:")
total_t = sum(time_counts.values())
for t, c in sorted(time_counts.items(), key=lambda x: -x[1]):
    pct = 100 * c / total_t
    bar = "#" * int(pct / 2)
    print(f"  {t:15s}: {c:>8,} ({pct:5.1f}%) {bar}")

# v6: Show bond type distribution
print()
print("Bond type distribution (for transfer test):")
total_b = sum(bond_type_counts.values())
for b, c in sorted(bond_type_counts.items(), key=lambda x: -x[1]):
    pct = 100 * c / total_b
    bar = "#" * int(pct / 2)
    print(f"  {b:20s}: {c:>8,} ({pct:5.1f}%) {bar}")

# Uniform-chance baselines: 1 / (number of observed classes); these ignore class imbalance
N_HOHFELD_CLASSES = len([h for h in hohfeld_counts if h is not None]) + 1  # +1 for None
N_TIME_CLASSES = len(time_counts)
N_BOND_CLASSES = len(bond_type_counts)  # v6: Bond classes for transfer test

CHANCE_HOHFELD = 1.0 / N_HOHFELD_CLASSES
CHANCE_TIME = 1.0 / N_TIME_CLASSES
CHANCE_BOND = 1.0 / N_BOND_CLASSES  # v6

print()
print(f"Chance baseline - Hohfeld: {CHANCE_HOHFELD:.1%} ({N_HOHFELD_CLASSES} classes)")
print(f"Chance baseline - Time:    {CHANCE_TIME:.1%} ({N_TIME_CLASSES} classes)")
print(f"Chance baseline - Bond:    {CHANCE_BOND:.1%} ({N_BOND_CLASSES} classes)")

# Save baselines for later
baselines = {
    'hohfeld_counts': {str(k): v for k, v in hohfeld_counts.items()},
    'time_counts': time_counts,
    'bond_type_counts': bond_type_counts,  # v6
    'chance_hohfeld': CHANCE_HOHFELD,
    'chance_time': CHANCE_TIME,
    'chance_bond': CHANCE_BOND,  # v6
    'n_hohfeld_classes': N_HOHFELD_CLASSES,
    'n_time_classes': N_TIME_CLASSES,
    'n_bond_classes': N_BOND_CLASSES  # v6
}
with open("data/splits/baselines.json", 'w') as f:
    json.dump(baselines, f, indent=2)

# Warn if severe imbalance
most_common_hohfeld = max(hohfeld_counts.values()) / total_h
if most_common_hohfeld > 0.7:
    print()
    print(f"WARNING: Hohfeld labels severely imbalanced! Most common = {most_common_hohfeld:.1%}")
    print("         Model may just predict majority class.")

# ============================================================
# SAVE PREPROCESSING TO DRIVE
# ============================================================
print()
print("=" * 60)
print("SAVING PREPROCESSED DATA TO GOOGLE DRIVE")
print("=" * 60)
import shutil
shutil.copytree("data/processed", f"{SAVE_DIR}/processed", dirs_exist_ok=True)
shutil.copytree("data/splits", f"{SAVE_DIR}/splits", dirs_exist_ok=True)
print(f"Saved to {SAVE_DIR}")
print(f"If session dies, run: !cp -r {SAVE_DIR}/* data/")
print("Then skip to Cell 8.")
print()

mark_task("Generate train/test splits", "done")
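
The chance baselines computed above are uniform (`1 / n_classes`). When labels are imbalanced, the share of the most common class is often a stricter reference point, because a model can reach it by always predicting that class. A minimal sketch of both numbers, assuming a `{label: count}` dictionary like the ones built in this cell (`baselines_from_counts` is a hypothetical helper):

```python
def baselines_from_counts(counts):
    """Return uniform-chance and majority-class accuracy for a {label: count} dict."""
    total = sum(counts.values())
    if total == 0:
        raise ValueError("counts must contain at least one labeled example")
    return {
        "uniform_chance": 1.0 / len(counts),
        "majority_class": max(counts.values()) / total,
    }

# Hypothetical counts, only to show the call shape
example_counts = {"OBLIGATION": 70, "RIGHT": 20, "None": 10}
print(baselines_from_counts(example_counts))
```

Reporting accuracy against the majority-class figure as well as the uniform one makes it easier to spot a model that only predicts the dominant label.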



In [None]:
#@title 8. Define BIP Model Architecture { display-mode: "form" }
#@markdown Defines the model with bond prediction head for transfer test.
#@markdown **v6 NEW**: Added bond_classifier and bond prediction output.

import gc
import psutil

# CRITICAL: Clear memory before loading model
print("Clearing memory before model load...")
gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()

mem = psutil.virtual_memory()
print(f"Memory before model: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")
print()


import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import gc

print("=" * 60)
print("DEFINING MODEL ARCHITECTURE (v6)")
print("=" * 60)
print()
print("*** CROSS-CULTURAL MODE ***")
print("Encoder: paraphrase-multilingual-MiniLM-L12-v2")
print("  - Trained on 50+ languages including Hebrew and English")
print("  - Maps both languages into shared embedding space")
print("  - Sefaria passages: ORIGINAL HEBREW")
print("  - Dear Abby passages: ENGLISH")
print()
print("v6 NEW: Added bond_classifier for explicit bond transfer test")
print()

class GradientReversal(torch.autograd.Function):
    """Gradient reversal layer for adversarial training."""
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()
    
    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambda_ * grad_output, None

def gradient_reversal(x, lambda_=1.0):
    return GradientReversal.apply(x, lambda_)

class BIPEncoder(nn.Module):
    """Sentence encoder using pretrained transformer."""
    def __init__(self, model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", d_model=384):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.d_model = d_model
    
    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return pooled

class BIPModel(nn.Module):
    """Bond Invariance Principle Model with adversarial disentanglement.
    
    v6 NEW: Added bond_classifier for explicit bond prediction (primary_relation)
    to measure bond-level transfer accuracy.
    """
    def __init__(self, d_model=384, d_bond=64, d_label=32, n_periods=14, n_hohfeld=4, n_bonds=11):
        super().__init__()
        
        self.encoder = BIPEncoder()
        
        self.bond_proj = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, d_bond)
        )
        
        self.label_proj = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, d_label)
        )
        
        self.time_classifier_bond = nn.Linear(d_bond, n_periods)
        self.time_classifier_label = nn.Linear(d_label, n_periods)
        self.hohfeld_classifier = nn.Linear(d_bond, n_hohfeld)
        
        # v6 NEW: Bond type classifier for transfer accuracy measurement
        self.bond_classifier = nn.Linear(d_bond, n_bonds)
    
    def forward(self, input_ids, attention_mask, adversarial_lambda=1.0):
        h = self.encoder(input_ids, attention_mask)
        
        z_bond = self.bond_proj(h)
        z_label = self.label_proj(h)
        
        z_bond_adv = gradient_reversal(z_bond, adversarial_lambda)
        time_pred_bond = self.time_classifier_bond(z_bond_adv)
        time_pred_label = self.time_classifier_label(z_label)
        hohfeld_pred = self.hohfeld_classifier(z_bond)
        
        # v6 NEW: Predict primary bond type from z_bond
        bond_pred = self.bond_classifier(z_bond)
        
        return {
            'z_bond': z_bond,
            'z_label': z_label,
            'time_pred_bond': time_pred_bond,
            'time_pred_label': time_pred_label,
            'hohfeld_pred': hohfeld_pred,
            'bond_pred': bond_pred  # v6 NEW
        }

# Time period mapping
TIME_PERIOD_TO_IDX = {
    'BIBLICAL': 0, 'SECOND_TEMPLE': 1, 'TANNAITIC': 2, 'AMORAIC': 3,
    'GEONIC': 4, 'RISHONIM': 5, 'ACHRONIM': 6, 'MODERN_HEBREW': 7,
    # Chinese
    'CONFUCIAN': 8, 'DAOIST': 9, 'MOHIST': 10,
    # Arabic  
    'QURANIC': 11, 'HADITH': 12,
    # Modern
    'DEAR_ABBY': 13
}

HOHFELD_TO_IDX = {
    'OBLIGATION': 0, 'RIGHT': 1, 'LIBERTY': 2, None: 3
}

# v6 NEW: Bond type mapping (matches BondType enum order)
BOND_TYPE_TO_IDX = {
    'HARM_PREVENTION': 0,
    'RECIPROCITY': 1,
    'AUTONOMY': 2,
    'PROPERTY': 3,
    'FAMILY': 4,
    'AUTHORITY': 5,
    'EMERGENCY': 6,
    'CONTRACT': 7,
    'CARE': 8,
    'FAIRNESS': 9,
    'NONE': 10
}
IDX_TO_BOND_TYPE = {v: k for k, v in BOND_TYPE_TO_IDX.items()}

class MoralDataset(Dataset):
    """
    MEMORY-EFFICIENT Dataset that reads from disk on demand.
    v6 NEW: Also returns bond_label for transfer accuracy test.
    """
    def __init__(self, passage_ids: set, passages_file: str, bonds_file: str, tokenizer, max_len=64):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.passage_ids = passage_ids
        
        print(f"  Indexing {len(passage_ids):,} passages...")
        
        self.data = []
        
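        # Load only the passages we need. The two JSONL files are paired by line
        # position, so they must be line-aligned; zip() stops silently at the shorter file.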
        with open(passages_file, 'r') as f_pass, open(bonds_file, 'r') as f_bond:
            for p_line, b_line in tqdm(zip(f_pass, f_bond), desc="  Loading subset", unit="line", total=None):
                p = json.loads(p_line)
                if p['id'] in passage_ids:
                    b = json.loads(b_line)
                    self.data.append({
                        'text': (p.get('text_original', '') if p.get('language') in ['hebrew', 'chinese', 'arabic'] else p.get('text_english', ''))[:1000],
                        'time_period': p['time_period'],
                        'source_type': p['source_type'],  # v6: Track corpus for per-corpus F1
                        'hohfeld': b['bond_structure']['hohfeld_state'],
                        'primary_relation': b['bond_structure']['primary_relation']  # v6 NEW
                    })
        
        print(f"  Loaded {len(self.data):,} samples")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        encoding = self.tokenizer(
            item['text'],
            truncation=True,
            max_length=self.max_len,
            padding='max_length',
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
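            # Unknown periods fall back to DEAR_ABBY (index 13); a hard-coded 8 would mean CONFUCIAN in this mapping
            'time_label': TIME_PERIOD_TO_IDX.get(item['time_period'], TIME_PERIOD_TO_IDX['DEAR_ABBY']),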
            'hohfeld_label': HOHFELD_TO_IDX.get(item['hohfeld'], 3),
            'bond_label': BOND_TYPE_TO_IDX.get(item['primary_relation'], 10),  # v6 NEW
            'source_type': item['source_type']  # v6: For per-corpus metrics
        }

def collate_fn(batch):
    return {
        'input_ids': torch.stack([x['input_ids'] for x in batch]),
        'attention_mask': torch.stack([x['attention_mask'] for x in batch]),
        'time_labels': torch.tensor([x['time_label'] for x in batch]),
        'hohfeld_labels': torch.tensor([x['hohfeld_label'] for x in batch]),
        'bond_labels': torch.tensor([x['bond_label'] for x in batch]),  # v6 NEW
        'source_types': [x['source_type'] for x in batch]  # v6: Keep as list for grouping
    }

# Memory cleanup
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Model architecture defined!")
print(f"  - Bond types: {len(BOND_TYPE_TO_IDX)} classes")
print(f"  - Time periods: {len(TIME_PERIOD_TO_IDX)} classes")
print(f"  - Hohfeld states: {len(HOHFELD_TO_IDX)} classes")
print()

# Memory status
import psutil
mem = psutil.virtual_memory()
print(f"Memory: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

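The gradient reversal layer above is easy to get subtly wrong, so a quick check can help. The following is a minimal sketch, not part of the experiment: it confirms that the forward pass is the identity and that the backward pass multiplies the incoming gradient by `-lambda_`. The helper `input_grad` is a hypothetical name used only for this illustration, and the `GradientReversal` class is repeated so the snippet runs on its own.

```python
import torch

class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass, gradient scaled by -lambda_ in the backward pass."""
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; lambda_ is a plain float, so it gets None.
        return -ctx.lambda_ * grad_output, None

def input_grad(lambda_):
    """Hypothetical helper: gradient of sum(GRL(x)) with respect to x."""
    x = torch.ones(3, requires_grad=True)
    y = GradientReversal.apply(x, lambda_)
    assert torch.equal(y, x)  # forward pass leaves the values unchanged
    y.sum().backward()
    return x.grad.tolist()

print(input_grad(1.0))  # gradient of sum(x) is +1 per element, so this is -1 per element
print(input_grad(0.5))  # scaled reversal: -0.5 per element
```

In `BIPModel.forward`, only `time_classifier_bond` sits behind this layer, so the time classifier itself still learns to predict the period while the encoder and `bond_proj` receive the negated gradient.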

In [None]:
#@title 9. Train BIP Model - BIDIRECTIONAL { display-mode: "form" }
#@markdown Trains a fresh model in EACH direction (ancient → modern, modern → ancient) and runs the bond transfer accuracy test.
#@markdown **v6 FIX**: Fixed TPU double-stepping bug.
#@markdown **v6 NEW**: Trains bond classifier and reports F1 by corpus.

import gc
import psutil
from sklearn.metrics import f1_score, classification_report

mark_task("Train BIP model", "running")

# Memory cleanup before training
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

mem = psutil.virtual_memory()
print(f"Memory at start: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB")

print("=" * 60)
print("BIDIRECTIONAL BIP TRAINING (v6)")
print("=" * 60)
print()
print(f"Accelerator: {ACCELERATOR}")
print(f"Device: {device}")
print()
print("v6 NEW: Now training bond classifier for transfer accuracy test")
print()

# Load tokenizer once
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# Store results for both directions
all_results = {}

for split_name in ['ancient_to_modern', 'modern_to_ancient']:
    print()
    print("=" * 60)
    print(f"DIRECTION {split_name}: {'Ancient → Modern' if split_name == 'ancient_to_modern' else 'Modern → Ancient'}")
    print("=" * 60)
    print()
    
    # Load appropriate split
    with open("data/splits/all_splits.json", 'r') as f:
        splits = json.load(f)
    split = splits[split_name]
    
    print(f"Train: {split['train_size']:,}")
    print(f"Valid: {split['valid_size']:,}")
    print(f"Test:  {split['test_size']:,}")
    print()
    
    # Create fresh model for each direction
    print("Creating fresh model...")
    model = BIPModel().to(device)
    
    # Compile model for speed (PyTorch 2.0+)
    if TORCH_COMPILE:
        print("Compiling model with torch.compile...")
        model = torch.compile(model, mode="reduce-overhead")
    
    if split_name == 'ancient_to_modern':
        print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Create datasets
    print("Creating datasets...")
    train_dataset = MoralDataset(
        set(split['train_ids']),
        "data/processed/passages.jsonl",
        "data/processed/bond_structures.jsonl",
        tokenizer
    )
    valid_dataset = MoralDataset(
        set(split['valid_ids']),
        "data/processed/passages.jsonl",
        "data/processed/bond_structures.jsonl",
        tokenizer
    )
    test_dataset = MoralDataset(
        set(split['test_ids']),
        "data/processed/passages.jsonl",
        "data/processed/bond_structures.jsonl",
        tokenizer
    )
    
    print(f"Train samples: {len(train_dataset):,}")
    print(f"Valid samples: {len(valid_dataset):,}")
    print(f"Test samples:  {len(test_dataset):,}")
    print()
    
    if len(train_dataset) == 0:
        print("ERROR: No training data!")
        continue
    
    # Batch size: 256 for the large ancient corpus, 32 for the smaller modern one.
    # Cap at the dataset size so drop_last=True cannot leave the loader with zero batches.
    batch_size = 256 if split_name == 'ancient_to_modern' else 32
    batch_size = min(batch_size, len(train_dataset))
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn, drop_last=True, num_workers=4, pin_memory=True, prefetch_factor=4, persistent_workers=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size*2, shuffle=False,
                              collate_fn=collate_fn, num_workers=4, pin_memory=True, prefetch_factor=4, persistent_workers=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size*2, shuffle=False,
                             collate_fn=collate_fn, num_workers=4, pin_memory=True, prefetch_factor=4, persistent_workers=True)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    
    n_epochs = 3
    best_valid_loss = float('inf')
    patience = 3  # inert while n_epochs <= patience; raise n_epochs for early stopping to take effect
    patience_counter = 0
    
    print(f"Training for {n_epochs} epochs (batch_size={batch_size})...")
    print()
    
    for epoch in range(1, n_epochs + 1):
        model.train()
        total_loss = 0
        n_batches = 0
        
        pbar = tqdm(train_loader, desc=f"[{split_name}] Epoch {epoch}/{n_epochs}", unit="batch")
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            time_labels = batch['time_labels'].to(device)
            hohfeld_labels = batch['hohfeld_labels'].to(device)
            bond_labels = batch['bond_labels'].to(device)  # v6 NEW
            
            # Mixed precision forward pass
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                outputs = model(input_ids, attention_mask, adversarial_lambda=1.0)
                
                # Losses
                loss_time_bond = F.cross_entropy(outputs['time_pred_bond'], time_labels)
                loss_time_label = F.cross_entropy(outputs['time_pred_label'], time_labels)
                loss_hohfeld = F.cross_entropy(outputs['hohfeld_pred'], hohfeld_labels)
                loss_bond = F.cross_entropy(outputs['bond_pred'], bond_labels)  # v6 NEW
            
            # v6: Include bond loss in total
            loss = loss_hohfeld + loss_time_label + loss_time_bond + loss_bond
            
            optimizer.zero_grad()
            
            # v6 FIX: Choose ONE stepping mechanism, not both
            if USE_TPU:
                # TPU: Use XLA optimizer step ONLY
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                xm.optimizer_step(optimizer)
                xm.mark_step()
            elif USE_AMP and scaler is not None:
                # GPU with AMP
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                # CPU or GPU without AMP
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            
            total_loss += loss.item()
            n_batches += 1
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})
        
        avg_train_loss = total_loss / max(n_batches, 1)  # guard against an empty loader
        
        # Validation
        model.eval()
        valid_loss = 0
        valid_batches = 0
        time_correct = 0
        time_total = 0
        hohfeld_correct = 0
        hohfeld_total = 0
        bond_correct = 0  # v6 NEW
        bond_total = 0
        
        with torch.no_grad():
            for batch in valid_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                time_labels = batch['time_labels'].to(device)
                hohfeld_labels = batch['hohfeld_labels'].to(device)
                bond_labels = batch['bond_labels'].to(device)  # v6 NEW
                
                outputs = model(input_ids, attention_mask, adversarial_lambda=0)
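                # Validation loss (used for checkpoint selection) is the Hohfeld loss only,
                # so it is not directly comparable to the four-term training loss.
                loss = F.cross_entropy(outputs['hohfeld_pred'], hohfeld_labels)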
                valid_loss += loss.item()
                valid_batches += 1
                
                time_preds = outputs['time_pred_bond'].argmax(dim=-1)
                time_correct += (time_preds == time_labels).sum().item()
                time_total += len(time_labels)
                
                hohfeld_preds = outputs['hohfeld_pred'].argmax(dim=-1)
                hohfeld_correct += (hohfeld_preds == hohfeld_labels).sum().item()
                hohfeld_total += len(hohfeld_labels)
                
                # v6 NEW: Bond accuracy
                bond_preds = outputs['bond_pred'].argmax(dim=-1)
                bond_correct += (bond_preds == bond_labels).sum().item()
                bond_total += len(bond_labels)
                
                if USE_TPU:
                    xm.mark_step()
        
        avg_valid_loss = valid_loss / valid_batches if valid_batches > 0 else 0
        time_acc = time_correct / time_total if time_total > 0 else 0
        hohfeld_acc_val = hohfeld_correct / hohfeld_total if hohfeld_total > 0 else 0
        bond_acc_val = bond_correct / bond_total if bond_total > 0 else 0  # v6 NEW
        
        print(f"[{split_name}] Epoch {epoch}: Loss={avg_train_loss:.4f}/{avg_valid_loss:.4f}, Hohfeld={hohfeld_acc_val:.1%}, Bond={bond_acc_val:.1%}, TimeAcc={time_acc:.1%}")
        
        if avg_valid_loss < best_valid_loss:
            best_valid_loss = avg_valid_loss
            model_path = f"models/checkpoints/best_model_{split_name}.pt"
            if USE_TPU:
                xm.save(model.state_dict(), model_path)
            else:
                torch.save(model.state_dict(), model_path)
            print(f"  -> Saved best model for {split_name}!")
            # Backup to Drive
            import shutil
            shutil.copy(model_path, f"{SAVE_DIR}/best_model_{split_name}.pt")
            print(f"  -> Backed up to Google Drive")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"  Early stopping at epoch {epoch} (no improvement for {patience} epochs)")
                break
    
    # ================================================================
    # EVALUATE ON TEST SET - Including bond transfer accuracy (v6 NEW)
    # ================================================================
    print()
    print(f"Evaluating {split_name} on test set...")
    
    model.load_state_dict(torch.load(f"models/checkpoints/best_model_{split_name}.pt", map_location='cpu'))
    model = model.to(device)
    model.eval()
    
    all_time_preds = []
    all_time_labels = []
    all_hohfeld_preds = []
    all_hohfeld_labels = []
    all_bond_preds = []  # v6 NEW
    all_bond_labels = []  # v6 NEW
    all_source_types = []  # v6: For per-corpus metrics
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"[{split_name}] Testing", unit="batch"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            outputs = model(input_ids, attention_mask, adversarial_lambda=0)
            
            all_time_preds.extend(outputs['time_pred_bond'].argmax(dim=-1).cpu().tolist())
            all_time_labels.extend(batch['time_labels'].tolist())
            all_hohfeld_preds.extend(outputs['hohfeld_pred'].argmax(dim=-1).cpu().tolist())
            all_hohfeld_labels.extend(batch['hohfeld_labels'].tolist())
            all_bond_preds.extend(outputs['bond_pred'].argmax(dim=-1).cpu().tolist())  # v6 NEW
            all_bond_labels.extend(batch['bond_labels'].tolist())  # v6 NEW
            all_source_types.extend(batch['source_types'])  # v6
            
            if USE_TPU:
                xm.mark_step()
    
    # Calculate metrics
    time_acc = sum(p == l for p, l in zip(all_time_preds, all_time_labels)) / len(all_time_preds)
    hohfeld_acc = sum(p == l for p, l in zip(all_hohfeld_preds, all_hohfeld_labels)) / len(all_hohfeld_preds)
    bond_acc = sum(p == l for p, l in zip(all_bond_preds, all_bond_labels)) / len(all_bond_preds)  # v6 NEW
    
    # v6 NEW: Calculate F1 scores
    bond_f1_macro = f1_score(all_bond_labels, all_bond_preds, average='macro', zero_division=0)
    bond_f1_weighted = f1_score(all_bond_labels, all_bond_preds, average='weighted', zero_division=0)
    hohfeld_f1_macro = f1_score(all_hohfeld_labels, all_hohfeld_preds, average='macro', zero_division=0)
    
    # v6 NEW: Per-corpus bond F1
    corpus_bond_f1 = {}
    for corpus in set(all_source_types):
        mask = [s == corpus for s in all_source_types]
        corpus_preds = [p for p, m in zip(all_bond_preds, mask) if m]
        corpus_labels = [l for l, m in zip(all_bond_labels, mask) if m]
        if len(corpus_labels) > 0:
            corpus_bond_f1[corpus] = {
                'f1_macro': f1_score(corpus_labels, corpus_preds, average='macro', zero_division=0),
                'f1_weighted': f1_score(corpus_labels, corpus_preds, average='weighted', zero_division=0),
                'accuracy': sum(p == l for p, l in zip(corpus_preds, corpus_labels)) / len(corpus_labels),
                'n_samples': len(corpus_labels)
            }
    
    all_results[split_name] = {
        'time_acc': time_acc,
        'hohfeld_acc': hohfeld_acc,
        'hohfeld_f1_macro': hohfeld_f1_macro,
        'bond_acc': bond_acc,  # v6 NEW
        'bond_f1_macro': bond_f1_macro,  # v6 NEW
        'bond_f1_weighted': bond_f1_weighted,  # v6 NEW
        'corpus_bond_f1': corpus_bond_f1,  # v6 NEW: Per-corpus metrics
        'train_size': split['train_size'],
        'test_size': split['test_size']
    }
    
    print()
    print(f"{split_name.upper()} RESULTS:")
    print(f"  Time prediction from z_bond: {time_acc:.1%} (uniform chance ~{1 / len(TIME_PERIOD_TO_IDX):.0%})")
    print(f"  Hohfeld classification:      {hohfeld_acc:.1%} (F1={hohfeld_f1_macro:.3f})")
    print(f"  Bond classification:         {bond_acc:.1%} (F1={bond_f1_macro:.3f})")
    print()
    print("  Bond F1 by corpus:")
    for corpus, metrics in corpus_bond_f1.items():
        print(f"    {corpus}: F1={metrics['f1_macro']:.3f}, Acc={metrics['accuracy']:.1%} (n={metrics['n_samples']:,})")

print()
print("=" * 60)
print("TRAINING COMPLETE - BOTH DIRECTIONS")
print("=" * 60)

mark_task("Train BIP model", "done")

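To make the per-corpus metric above concrete, here is a minimal sketch in plain Python of what "macro F1 by corpus" computes. It is a hand-rolled stand-in for `f1_score(..., average='macro', zero_division=0)`, intended only as an illustration; the function names `macro_f1` and `f1_by_corpus`, the corpus names, and the toy labels are all assumptions made for this example and are not experiment results.

```python
from collections import defaultdict

def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every class seen in y_true or y_pred."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores) if scores else 0.0

def f1_by_corpus(y_true, y_pred, sources):
    """Group predictions by source corpus, then score each group separately."""
    grouped = defaultdict(lambda: ([], []))
    for t, p, s in zip(y_true, y_pred, sources):
        grouped[s][0].append(t)
        grouped[s][1].append(p)
    return {
        s: {
            "f1_macro": macro_f1(ts, ps),
            "accuracy": sum(t == p for t, p in zip(ts, ps)) / len(ts),
            "n_samples": len(ts),
        }
        for s, (ts, ps) in grouped.items()
    }

# Toy data: hypothetical corpus names and bond-class indices.
labels  = [0, 0, 1, 1, 2, 2, 2, 0]
preds   = [0, 1, 1, 1, 2, 2, 0, 0]
sources = ["sefaria"] * 4 + ["dear_abby"] * 4

by_corpus = f1_by_corpus(labels, preds, sources)
for corpus, m in by_corpus.items():
    print(f"{corpus}: F1={m['f1_macro']:.3f}, Acc={m['accuracy']:.1%} (n={m['n_samples']})")
```

Because macro F1 averages over the classes present in each group, a corpus that contains only a few bond types is scored over fewer classes than the full 11-class problem, which is worth keeping in mind when comparing corpora.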

In [None]:
#@title 10. Evaluate Bidirectional Results { display-mode: "form" }
#@markdown Compares results from BOTH directions to assess true invariance.
#@markdown **v6 NEW**: Reports bond transfer accuracy and F1 by corpus.

import gc
import psutil

mark_task("Evaluate results", "running")

from collections import Counter
try:
    from sklearn.metrics import confusion_matrix, classification_report
    HAS_SKLEARN = True
except ImportError:
    HAS_SKLEARN = False
    print("sklearn not available - skipping confusion matrices")

mem = psutil.virtual_memory()
print(f"Memory at eval start: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB")

print("=" * 60)
print("BIDIRECTIONAL BIP RESULTS (v6)")
print("=" * 60)
print()

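# Load baselines; fall back to uniform chance over each label set if the file is missing or incomplete
try:
    with open("data/splits/baselines.json", 'r') as f:
        baselines = json.load(f)
    chance_time = baselines['chance_time']
    chance_hohfeld = baselines['chance_hohfeld']
    chance_bond = baselines.get('chance_bond', 1 / len(BOND_TYPE_TO_IDX))
except (OSError, KeyError, ValueError):
    chance_time = 1 / len(TIME_PERIOD_TO_IDX)
    chance_hohfeld = 1 / len(HOHFELD_TO_IDX)
    chance_bond = 1 / len(BOND_TYPE_TO_IDX)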

print("DIRECTION A: Ancient → Modern")
print("-" * 40)
res_A = all_results.get('ancient_to_modern', {})
print(f"  Trained on:    {res_A.get('train_size', 0):,} ancient passages")
print(f"  Tested on:     {res_A.get('test_size', 0):,} modern passages")
print(f"  Time acc:      {res_A.get('time_acc', 0):.1%} (chance: {chance_time:.1%})")
print(f"  Hohfeld acc:   {res_A.get('hohfeld_acc', 0):.1%} (F1: {res_A.get('hohfeld_f1_macro', 0):.3f})")
print(f"  Bond acc:      {res_A.get('bond_acc', 0):.1%} (F1: {res_A.get('bond_f1_macro', 0):.3f})")
print()

# v6 NEW: Bond F1 by corpus
if 'corpus_bond_f1' in res_A:
    print("  Bond transfer by corpus:")
    for corpus, metrics in res_A['corpus_bond_f1'].items():
        print(f"    {corpus}: F1={metrics['f1_macro']:.3f}, Acc={metrics['accuracy']:.1%}")
    print()

A_time_near_chance = abs(res_A.get('time_acc', 0) - chance_time) < 0.05
A_hohfeld_good = res_A.get('hohfeld_acc', 0) > 0.35
A_bond_good = res_A.get('bond_f1_macro', 0) > chance_bond * 2  # v6: Bond transfer threshold

print(f"  Time invariant?    {'YES ✓' if A_time_near_chance else 'NO ✗'}")
print(f"  Moral structure?   {'YES ✓' if A_hohfeld_good else 'WEAK'}")
print(f"  Bond transfer?     {'YES ✓' if A_bond_good else 'WEAK'} (F1 > {chance_bond*2:.1%})")
print()

print("DIRECTION B: Modern → Ancient")
print("-" * 40)
res_B = all_results.get('modern_to_ancient', {})
print(f"  Trained on:    {res_B.get('train_size', 0):,} modern passages")
print(f"  Tested on:     {res_B.get('test_size', 0):,} ancient passages")
print(f"  Time acc:      {res_B.get('time_acc', 0):.1%} (chance: {chance_time:.1%})")
print(f"  Hohfeld acc:   {res_B.get('hohfeld_acc', 0):.1%} (F1: {res_B.get('hohfeld_f1_macro', 0):.3f})")
print(f"  Bond acc:      {res_B.get('bond_acc', 0):.1%} (F1: {res_B.get('bond_f1_macro', 0):.3f})")
print()

# v6 NEW: Bond F1 by corpus
if 'corpus_bond_f1' in res_B:
    print("  Bond transfer by corpus:")
    for corpus, metrics in res_B['corpus_bond_f1'].items():
        print(f"    {corpus}: F1={metrics['f1_macro']:.3f}, Acc={metrics['accuracy']:.1%}")
    print()

B_time_near_chance = abs(res_B.get('time_acc', 0) - chance_time) < 0.05
B_hohfeld_good = res_B.get('hohfeld_acc', 0) > 0.35
B_bond_good = res_B.get('bond_f1_macro', 0) > chance_bond * 2

print(f"  Time invariant?    {'YES ✓' if B_time_near_chance else 'NO ✗'}")
print(f"  Moral structure?   {'YES ✓' if B_hohfeld_good else 'WEAK'}")
print(f"  Bond transfer?     {'YES ✓' if B_bond_good else 'WEAK'} (F1 > {chance_bond*2:.1%})")
print()

print("=" * 60)
print("BIDIRECTIONAL INVARIANCE TEST")
print("=" * 60)
print()

# v6: Updated verdict logic to include bond transfer
if A_time_near_chance and B_time_near_chance and A_hohfeld_good and B_hohfeld_good and A_bond_good and B_bond_good:
    print("""
    ╔══════════════════════════════════════════════════════════╗
    ║                                                          ║
    ║     BIDIRECTIONAL BIP: STRONGLY SUPPORTED                ║
    ║                                                          ║
    ╠══════════════════════════════════════════════════════════╣
    ║                                                          ║
    ║  ✓ Ancient → Modern: Bond structure transfers            ║
    ║  ✓ Modern → Ancient: Bond structure transfers            ║
    ║  ✓ BOTH directions show time-invariant moral geometry    ║
    ║  ✓ Bond-level transfer confirmed with F1 metrics         ║
    ║                                                          ║
    ║  This is STRONG evidence for universal moral structure.  ║
    ║                                                          ║
    ╚══════════════════════════════════════════════════════════╝
    """)
    bip_result = "STRONGLY_SUPPORTED"
elif (A_time_near_chance and A_hohfeld_good and A_bond_good) or (B_time_near_chance and B_hohfeld_good and B_bond_good):
    print("""
    ╔══════════════════════════════════════════════════════════╗
    ║                                                          ║
    ║     BIP: SUPPORTED (One direction)                       ║
    ║                                                          ║
    ╠══════════════════════════════════════════════════════════╣
    ║                                                          ║
    ║  At least one direction shows:                           ║
    ║    - Time-invariant representation                       ║
    ║    - Good Hohfeld classification                         ║
    ║    - Bond transfer above chance                          ║
    ║                                                          ║
    ║  Asymmetry may reflect corpus size/diversity differences ║
    ║                                                          ║
    ╚══════════════════════════════════════════════════════════╝
    """)
    bip_result = "SUPPORTED_UNIDIRECTIONAL"
elif A_hohfeld_good or B_hohfeld_good:
    print("""
    ╔══════════════════════════════════════════════════════════╗
    ║                                                          ║
    ║     BIP: PARTIAL SUPPORT                                 ║
    ║                                                          ║
    ╠══════════════════════════════════════════════════════════╣
    ║                                                          ║
    ║  Hohfeld classification works, but:                      ║
    ║    - Bond transfer may be weak                           ║
    ║    - Time may still be decodable                         ║
    ║                                                          ║
    ║  The representation captures moral structure but         ║
    ║  may not be fully time-invariant at bond level.          ║
    ║                                                          ║
    ╚══════════════════════════════════════════════════════════╝
    """)
    bip_result = "PARTIAL_SUPPORT"
else:
    print("""
    ╔══════════════════════════════════════════════════════════╗
    ║                                                          ║
    ║     BIP: INCONCLUSIVE                                    ║
    ║                                                          ║
    ╠══════════════════════════════════════════════════════════╣
    ║                                                          ║
    ║  Neither direction shows clear invariance.               ║
    ║                                                          ║
    ║  Possible issues:                                        ║
    ║  - Need more training epochs                             ║
    ║  - Need better bond extraction patterns                  ║
    ║  - BIP may not hold (null result)                        ║
    ║                                                          ║
    ╚══════════════════════════════════════════════════════════╝
    """)
    bip_result = "INCONCLUSIVE"

# Save detailed per-direction metrics (predictions themselves are not stored)
detailed_results = {
    'ancient_to_modern': all_results.get('ancient_to_modern', {}),
    'modern_to_ancient': all_results.get('modern_to_ancient', {}),
}

# Save to Drive for post-mortem
with open(f"{SAVE_DIR}/detailed_results.json", 'w') as f:
    json.dump(detailed_results, f, indent=2, default=str)
print(f"Detailed results saved to {SAVE_DIR}/detailed_results.json")

# Save results
results_summary = {
    'ancient_to_modern': all_results.get('ancient_to_modern', {}),
    'modern_to_ancient': all_results.get('modern_to_ancient', {}),
    'A_time_invariant': A_time_near_chance,
    'A_moral_structure': A_hohfeld_good,
    'A_bond_transfer': A_bond_good,
    'B_time_invariant': B_time_near_chance,
    'B_moral_structure': B_hohfeld_good,
    'B_bond_transfer': B_bond_good,
    'bip_result': bip_result,
    'chance_time': chance_time,
    'chance_hohfeld': chance_hohfeld,
    'chance_bond': chance_bond
}

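import os
os.makedirs('results', exist_ok=True)  # the directory may not exist yet on a fresh runtime
with open('results/bidirectional_results.json', 'w') as f:
    json.dump(results_summary, f, indent=2, default=str)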

print()
print("Results saved to results/bidirectional_results.json")

mark_task("Evaluate results", "done")

print()
print("=" * 60)
print("EXPERIMENT COMPLETE")
print("=" * 60)
print_progress()

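The verdict logic in the cell above can be easier to reason about, and to test, as a pure function. The following is a minimal sketch under stated assumptions: the thresholds (0.05 tolerance on time accuracy, 0.35 Hohfeld accuracy, twice the bond chance level) are taken from the cell, while the names `DirectionFlags`, `direction_flags`, and `bip_verdict` are hypothetical. One deliberate difference: an empty result dict (a direction that was skipped) fails every check here rather than being scored as zeros. The metric values in the demo are made up and are not experiment results.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class DirectionFlags:
    time_near_chance: bool
    hohfeld_good: bool
    bond_good: bool

    @property
    def all_pass(self) -> bool:
        return self.time_near_chance and self.hohfeld_good and self.bond_good

def direction_flags(result, chance_time, chance_bond):
    """Turn one direction's metrics into the three pass/fail flags."""
    if not result:  # direction was skipped: treat as failing all checks
        return DirectionFlags(False, False, False)
    return DirectionFlags(
        time_near_chance=abs(result.get("time_acc", 0) - chance_time) < 0.05,
        hohfeld_good=result.get("hohfeld_acc", 0) > 0.35,
        bond_good=result.get("bond_f1_macro", 0) > chance_bond * 2,
    )

def bip_verdict(a, b):
    """Same ordering as the if/elif chain in the evaluation cell."""
    if a.all_pass and b.all_pass:
        return "STRONGLY_SUPPORTED"
    if a.all_pass or b.all_pass:
        return "SUPPORTED_UNIDIRECTIONAL"
    if a.hohfeld_good or b.hohfeld_good:
        return "PARTIAL_SUPPORT"
    return "INCONCLUSIVE"

# Illustrative inputs only.
chance_time, chance_bond = 1 / 14, 1 / 11
passing = direction_flags(
    {"time_acc": 0.08, "hohfeld_acc": 0.50, "bond_f1_macro": 0.30}, chance_time, chance_bond)
hohfeld_only = direction_flags(
    {"time_acc": 0.60, "hohfeld_acc": 0.50, "bond_f1_macro": 0.10}, chance_time, chance_bond)
skipped = direction_flags({}, chance_time, chance_bond)

print(bip_verdict(passing, passing))
print(bip_verdict(passing, skipped))
print(bip_verdict(hohfeld_only, skipped))
print(bip_verdict(skipped, skipped))
```

Keeping the thresholds in one function also makes it straightforward to report how sensitive the verdict is to the tolerance on time accuracy.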

---

## Save Results

Run the cell below to download your trained model and results.

In [None]:
#@title 11. Download Results (Optional) { display-mode: "form" }
#@markdown Creates a zip file with both model checkpoints, the splits and baselines files, and metrics.
#@markdown **v6 FIX**: Now copies correct model filenames.

import os
import shutil
from google.colab import files

# Create results directory
!mkdir -p results

# v6 FIX: Copy the actual model checkpoints (with split_name suffix)
for split_name in ['ancient_to_modern', 'modern_to_ancient']:
    model_path = f"models/checkpoints/best_model_{split_name}.pt"
    if os.path.exists(model_path):
        !cp "{model_path}" results/
        print(f"Copied {model_path}")
    else:
        print(f"Warning: {model_path} not found")

!cp data/splits/all_splits.json results/ 2>/dev/null || echo "No splits file"
!cp data/splits/baselines.json results/ 2>/dev/null || echo "No baselines file"

# Save metrics
if globals().get('all_results'):
    metrics = {
        'accelerator': ACCELERATOR,
        'results': all_results,
        'bip_result': globals().get('bip_result', 'unknown')
    }
    with open('results/metrics.json', 'w') as f:
        json.dump(metrics, f, indent=2, default=str)
    print("Saved metrics.json")

# Zip
shutil.make_archive('bip_results_v6', 'zip', 'results')
print()
print("Results saved to bip_results_v6.zip")
print()
print("Contents:")
!ls -la results/

# Download
files.download('bip_results_v6.zip')
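
Before downloading, it can be worth confirming that the archive actually contains the expected files. The following is a minimal sketch using only the standard library; it works in a temporary directory so it does not touch the real `results/` folder. The helper name `build_archive`, the demo archive name, and the placeholder `metrics.json` content are assumptions made for this illustration.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path

def build_archive(src_dir, archive_stem):
    """Zip src_dir and return (archive path, sorted list of member names)."""
    archive_path = shutil.make_archive(str(archive_stem), "zip", root_dir=str(src_dir))
    with zipfile.ZipFile(archive_path) as zf:
        names = sorted(zf.namelist())
    return archive_path, names

with tempfile.TemporaryDirectory() as tmp:
    results_dir = Path(tmp) / "results"
    results_dir.mkdir()
    # Placeholder file standing in for the real metrics output.
    (results_dir / "metrics.json").write_text('{"bip_result": "unknown"}')

    archive_path, names = build_archive(results_dir, Path(tmp) / "bip_results_demo")
    print(Path(archive_path).name)
    print(names)
```

The same check could be applied to `bip_results_v6.zip` to verify that both `best_model_*.pt` checkpoints were copied before the download starts.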