# BIP Cross-Temporal Morality Experiment (v7.1)

**Testing the Bond Invariance Principle across 2000+ years: Hebrew → English**

This experiment tests whether moral cognition has invariant structure by:
1. Training on ORIGINAL HEBREW texts (Sefaria corpus, ~500 BCE - 1800 CE)
2. Testing transfer to modern ENGLISH advice columns (Dear Abby, 1956-2020)

**Hypothesis**: If BIP holds, bond-level features should transfer across 2000 years with minimal degradation compared to in-domain baselines.

---

## v7.1 Changes (L4 Optimized)
- **Tuned for L4 GPU** (22.5GB VRAM, 53GB RAM, 236GB disk)
- **Larger batch sizes**: 512 for training (vs 256 on T4)
- **Full corpus by default**: MAX_SEFARIA_FILES = 0 (unlimited)
- **Better instrumentation**: Disk space, GPU memory throughout
- **Fixed missing imports**: each cell re-imports the modules it uses (cells still rely on helpers and data created by earlier cells, so run them in order)
- **Runtime estimate**: ~45-90 minutes on L4

---

## Important Methodological Notes

**Label Source**: Bond types and Hohfeld states are extracted with keyword regexes from the **English translations** (`text_english`), not from the original Hebrew. The model reads the Hebrew text but is supervised with translation-derived labels. The experiment therefore tests whether Hebrew text encodes moral structures that align with labels derived from English translations.

**Current Scope**: Hebrew (Sefaria) ↔ English (Dear Abby) transfer only.

---

## Setup Instructions
1. **Runtime -> Change runtime type -> L4 GPU**
2. Run cells in order
3. Expected runtime: ~45-90 minutes on L4

---

In [None]:
#@title 1. Setup and Install Dependencies { display-mode: "form" }
#@markdown Installs packages and detects GPU. Tuned for L4 runtime.

import time
# Wall-clock origin for the "Total elapsed" line of the progress report.
EXPERIMENT_START = time.time()

for banner_line in ("=" * 60, "BIP TEMPORAL INVARIANCE EXPERIMENT (v7.1 - L4 Optimized)", "=" * 60, ""):
    print(banner_line)

# Ordered checklist; each entry is also the key later cells pass to mark_task().
TASKS = [
    "Install dependencies",
    "Clone Sefaria corpus (~8GB)",
    "Clone sqnd-probe repo (Dear Abby data)",
    "Preprocess corpora",
    "Extract bond structures",
    "Generate train/test splits",
    "Train BIP model",
    "Linear probe test",
    "Evaluate results"
]
# Every task starts pending; per-task durations (seconds) are filled in later.
task_status = dict.fromkeys(TASKS, "pending")
task_times = {}

def print_progress():
    """Print the task checklist ([X] done, [>] running, [ ] pending) and elapsed time.

    A duration is shown only for finished tasks that have an entry in task_times.
    """
    divider = "-" * 50
    markers = {"done": "[X]", "running": "[>]"}
    print()
    print(divider)
    print("EXPERIMENT PROGRESS:")
    print(divider)
    for name in TASKS:
        state = task_status[name]
        duration = ""
        if state == "done" and name in task_times:
            duration = f" ({task_times[name]:.1f}s)"
        print(f"  {markers.get(state, '[ ]')} {name}{duration}")
    elapsed_minutes = (time.time() - EXPERIMENT_START) / 60
    print(divider)
    print(f"  Total elapsed: {elapsed_minutes:.1f} minutes")
    print(flush=True)

def mark_task(task, status):
    """Set *task*'s status and redraw the progress checklist.

    Args:
        task: One of the names in TASKS.
        status: "pending", "running" or "done". "running" records a start
            time; "done" stores the seconds since that start in task_times
            (only if a start time has ever been recorded).
    """
    global task_start_time
    # The start time lives in module globals. A dir() call inside a function
    # lists only local names, so it can never see it; check globals() instead.
    if status == "running":
        task_start_time = time.time()
    elif status == "done" and "task_start_time" in globals():
        task_times[task] = time.time() - task_start_time
    task_status[task] = status
    print_progress()

print_progress()

mark_task("Install dependencies", "running")

import os
import subprocess
import sys

# Install dependencies
# One pip call per package; check_call raises CalledProcessError if any install fails.
print("Installing dependencies...")
deps = [
    "transformers",
    "torch", 
    "sentence-transformers",
    "pandas",
    "tqdm",
    "psutil",
    "scikit-learn"
]

for dep in deps:
    print(f"  Installing {dep}...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", dep])

print()

# Detect accelerator
# A TPU is recognised only via the COLAB_TPU_ADDR environment variable. CUDA is
# checked first below, so the TPU branch is used only when no CUDA GPU exists.
USE_TPU = False
TPU_TYPE = None

if 'COLAB_TPU_ADDR' in os.environ:
    USE_TPU = True
    TPU_TYPE = "TPU (Colab)"
    print("TPU detected!")

import torch
import json
import psutil
import shutil

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    # Total device memory in decimal GB (bytes / 1e9).
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    ACCELERATOR = f"GPU: {gpu_name} ({gpu_mem:.1f}GB)"
    device = torch.device("cuda")
    
    # L4 detection for batch size tuning
    # Substring match on the reported device name.
    IS_L4 = 'L4' in gpu_name
    IS_A100 = 'A100' in gpu_name
    IS_V100 = 'V100' in gpu_name
elif USE_TPU:
    ACCELERATOR = TPU_TYPE
    # torch_xla is only imported on this path; it must already be installed.
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    IS_L4 = False
    IS_A100 = False
    IS_V100 = False
else:
    ACCELERATOR = "CPU (slow!)"
    device = torch.device("cpu")
    IS_L4 = False
    IS_A100 = False
    IS_V100 = False

print(f"Accelerator: {ACCELERATOR}")
print(f"Device: {device}")

# System resources
mem = psutil.virtual_memory()
disk = shutil.disk_usage('/')
print(f"System RAM: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")
print(f"Disk: {disk.used/1e9:.1f}/{disk.total/1e9:.1f} GB ({100*disk.used/disk.total:.1f}%)")

if torch.cuda.is_available():
    gpu_used = torch.cuda.memory_allocated()/1e9
    gpu_total = torch.cuda.get_device_properties(0).total_memory/1e9
    print(f"GPU RAM: {gpu_used:.1f}/{gpu_total:.1f} GB ({100*gpu_used/gpu_total:.1f}%)")

# Set batch sizes based on GPU
# BASE_BATCH_SIZE is a module-level setting presumably read by the training
# cells (not visible here).
if IS_L4 or IS_A100:
    BASE_BATCH_SIZE = 512  # L4 has 22.5GB, can handle larger batches
    print(f"\n*** L4/A100 detected: Using large batch size ({BASE_BATCH_SIZE}) ***")
elif IS_V100:
    BASE_BATCH_SIZE = 384
    print(f"\n*** V100 detected: Using batch size {BASE_BATCH_SIZE} ***")
else:
    BASE_BATCH_SIZE = 256  # T4 or smaller
    print(f"\n*** Using standard batch size ({BASE_BATCH_SIZE}) ***")

# Enable mixed precision
# NOTE(review): autocast/GradScaler are only imported when CUDA is available;
# any later cell that uses `autocast` must be guarded by USE_AMP.
if torch.cuda.is_available():
    print("\nEnabling mixed precision (FP16) for faster training...")
    from torch.cuda.amp import autocast, GradScaler
    USE_AMP = True
    scaler = GradScaler()
else:
    USE_AMP = False
    scaler = None

TORCH_COMPILE = False

# Google Drive mount
# google.colab is only importable inside a Colab runtime.
print()
print("=" * 60)
print("MOUNTING GOOGLE DRIVE FOR PERSISTENT STORAGE")
print("=" * 60)
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = '/content/drive/MyDrive/BIP_results'
os.makedirs(SAVE_DIR, exist_ok=True)

# Local working directories (relative to the notebook's current directory).
os.makedirs("data/processed", exist_ok=True)
os.makedirs("data/splits", exist_ok=True)
os.makedirs("data/raw", exist_ok=True)
os.makedirs("models/checkpoints", exist_ok=True)
os.makedirs("results", exist_ok=True)
print(f"Results will be saved to: {SAVE_DIR}")
print()

def print_resources(label=""):
    """Print one line of RAM, GPU (CUDA only) and root-disk usage, optionally tagged.

    Args:
        label: Optional tag printed in brackets at the start of the line.
    """
    vm = psutil.virtual_memory()
    root = shutil.disk_usage('/')
    parts = [f"RAM: {vm.used/1e9:.1f}/{vm.total/1e9:.1f}GB"]
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        capacity = torch.cuda.get_device_properties(0).total_memory / 1e9
        parts.append(f"GPU: {allocated:.1f}/{capacity:.1f}GB")
    parts.append(f"Disk: {root.used/1e9:.0f}/{root.total/1e9:.0f}GB")
    prefix = f"[{label}] " if label else ""
    print(prefix + " | ".join(parts))

# Snapshot resource usage after installs, accelerator detection and the Drive mount.
print_resources("After setup")

mark_task("Install dependencies", "done")


In [None]:
#@title 2. Download Sefaria Corpus (~8GB) { display-mode: "form" }
#@markdown Downloads the complete Sefaria corpus with real-time git progress.

import subprocess
import sys
import os
import shutil
from pathlib import Path

mark_task("Clone Sefaria corpus (~8GB)", "running")

sefaria_path = 'data/raw/Sefaria-Export'

# A checkout without json/ is left over from an interrupted clone. git refuses
# to clone into an existing non-empty directory, so remove it and start over.
if os.path.exists(sefaria_path) and not os.path.exists(f"{sefaria_path}/json"):
    print(f"Removing incomplete clone at {sefaria_path}...")
    shutil.rmtree(sefaria_path)

if not os.path.exists(sefaria_path):
    print("="*60)
    print("CLONING SEFARIA CORPUS")
    print("="*60)
    print()
    print("This downloads ~3.5GB and takes 5-15 minutes.")
    print("Git's native progress will display below:")
    print("-"*60)
    print(flush=True)

    # stderr is merged into stdout because git writes its progress to stderr.
    process = subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/Sefaria/Sefaria-Export.git', sefaria_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    )

    for line in process.stdout:
        print(line, end='', flush=True)

    process.wait()

    print("-"*60)
    if process.returncode == 0:
        print("\nSefaria clone COMPLETE!")
    else:
        print(f"\nERROR: Git clone failed with code {process.returncode}")
        # Without the corpus every later cell is meaningless; do not mark this done.
        raise RuntimeError(f"Git clone failed with code {process.returncode}")
else:
    print("Sefaria already exists, skipping download.")

print()
print("Verifying download...")
# Plain Python/subprocess instead of IPython `!` magics, so the cell is valid Python.
du_result = subprocess.run(['du', '-sh', sefaria_path], capture_output=True, text=True)
print(du_result.stdout.strip() if du_result.returncode == 0 else "Directory not found")
json_dir = Path(sefaria_path) / "json"
json_count = sum(1 for _ in json_dir.rglob("*.json")) if json_dir.exists() else 0
print(f"Sefaria JSON files found: {json_count}")

print_resources("After Sefaria download")

mark_task("Clone Sefaria corpus (~8GB)", "done")


In [None]:
#@title 3. Download Dear Abby Dataset { display-mode: "form" }
#@markdown Downloads the Dear Abby advice column dataset (68,330 entries).

import subprocess
import os
import shutil
import pandas as pd
from pathlib import Path

mark_task("Clone sqnd-probe repo (Dear Abby data)", "running")

sqnd_path = 'sqnd-probe-data'
if not os.path.exists(sqnd_path):
    print("Cloning sqnd-probe repo...")
    process = subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/ahb-sjsu/sqnd-probe.git', sqnd_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    )
    for line in process.stdout:
        print(line, end='', flush=True)
    process.wait()
    if process.returncode != 0:
        # Not fatal on its own: a CSV copied on an earlier run may still exist;
        # the existence check below raises if it does not.
        print(f"\nWARNING: git clone failed with code {process.returncode}")
else:
    print("Repo already cloned.")

dear_abby_source = Path('sqnd-probe-data/dear_abby_data/raw_da_qs.csv')
dear_abby_path = Path('data/raw/dear_abby.csv')

if dear_abby_source.exists():
    # shutil instead of the `!cp` magic keeps this cell valid Python.
    dear_abby_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(dear_abby_source, dear_abby_path)
    print("\nCopied Dear Abby data")
elif not dear_abby_path.exists():
    raise FileNotFoundError("Dear Abby dataset not found!")

df_check = pd.read_csv(dear_abby_path)
print("\n" + "=" * 50)
print(f"Dear Abby dataset: {len(df_check):,} entries")
print(f"Columns: {list(df_check.columns)}")
# load_dear_abby tolerates a missing 'year' column, so the summary does too.
if 'year' in df_check.columns:
    print(f"Year range: {df_check['year'].min():.0f} - {df_check['year'].max():.0f}")
else:
    print("Year range: (no 'year' column)")
print("=" * 50)

print_resources("After Dear Abby download")

mark_task("Clone sqnd-probe repo (Dear Abby data)", "done")

In [None]:
#@title 4. Define Data Classes and Loaders { display-mode: "form" }
#@markdown Defines enums, dataclasses, and corpus loaders.
#@markdown **Note**: Labels extracted from English translations (text_english).

# Cell-local imports. The enums, Passage and loaders defined in this cell are
# used by the later preprocessing, extraction and split cells.
import json
import hashlib
import re
import os
import pandas as pd
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Set
from enum import Enum
from collections import defaultdict
from tqdm.auto import tqdm

print("Defining data structures...")

class TimePeriod(Enum):
    """Historical era of a passage; passages store the member *name* as a string.

    Only some members are assigned by the loaders in this notebook.
    CATEGORY_TO_PERIOD yields BIBLICAL, TANNAITIC, AMORAIC, RISHONIM and ACHRONIM.
    Unmapped Sefaria categories default to AMORAIC, and Dear Abby rows get
    DEAR_ABBY. SECOND_TEMPLE, GEONIC and MODERN_HEBREW are not assigned by any
    visible code.
    """
    BIBLICAL = 0        # ~1000-500 BCE
    SECOND_TEMPLE = 1   # ~500 BCE - 70 CE
    TANNAITIC = 2       # ~70-200 CE
    AMORAIC = 3         # ~200-500 CE
    GEONIC = 4          # ~600-1000 CE
    RISHONIM = 5        # ~1000-1500 CE
    ACHRONIM = 6        # ~1500-1800 CE
    MODERN_HEBREW = 7   # ~1800-present
    DEAR_ABBY = 8       # 1956-2020

class BondType(Enum):
    """Moral bond categories; RELATION_PATTERNS supplies a keyword regex for each except NONE."""
    HARM_PREVENTION = 0
    RECIPROCITY = 1
    AUTONOMY = 2
    PROPERTY = 3
    FAMILY = 4
    AUTHORITY = 5
    EMERGENCY = 6
    CONTRACT = 7
    CARE = 8
    FAIRNESS = 9
    NONE = 10  # Explicit NONE class
    # ('NONE' is the label extract_bond_structure assigns when no pattern matches.)

class HohfeldianState(Enum):
    """Hohfeldian normative positions used as the Hohfeld label.

    NO_RIGHT has no entry in HOHFELD_PATTERNS, so extraction never produces it;
    passages with no pattern match get hohfeld_state None instead.
    """
    RIGHT = 0
    OBLIGATION = 1
    LIBERTY = 2
    NO_RIGHT = 3

@dataclass
class Passage:
    """One text unit from either corpus, serialised one per line to passages.jsonl."""
    id: str  # "sefaria_<md5 prefix>" or "abby_<md5 prefix>"
    text_original: str  # Hebrew for Sefaria; the English question for Dear Abby
    text_english: str  # English translation / question; regex labels are read from this
    time_period: str  # a TimePeriod member name
    century: int  # representative century; negative values mean BCE
    source: str  # human-readable reference, e.g. "<book> 1.2" or "Dear Abby 1987"
    source_type: str  # "sefaria" or "dear_abby"
    category: str  # Sefaria top-level category, or "general" for Dear Abby
    language: str = "hebrew"
    word_count: int = 0  # whitespace token count of text_english
    has_dispute: bool = False  # not populated by the loaders in this notebook
    consensus_tier: str = "unknown"  # not populated by the loaders in this notebook
    bond_types: List[str] = field(default_factory=list)  # filled by the bond-extraction cell
    
    def to_dict(self):
        """Return all fields as a plain dict (dataclasses.asdict) for JSON output."""
        return asdict(self)

# Maps the first path component under Sefaria-Export/json/ to an era.
# load_sefaria falls back to AMORAIC for any category not listed here.
# NOTE(review): 'Torah' and 'Bavli' only match if they appear as *top-level*
# directories; if they are nested sub-categories in the export they never match.
# Confirm against the checkout.
CATEGORY_TO_PERIOD = {
    'Tanakh': TimePeriod.BIBLICAL,
    'Torah': TimePeriod.BIBLICAL,
    'Mishnah': TimePeriod.TANNAITIC,
    'Tosefta': TimePeriod.TANNAITIC,
    'Talmud': TimePeriod.AMORAIC,
    'Bavli': TimePeriod.AMORAIC,
    'Midrash': TimePeriod.AMORAIC,
    'Halakhah': TimePeriod.RISHONIM,
    'Chasidut': TimePeriod.ACHRONIM,
}

# Representative century per era (negative = BCE). DEAR_ABBY is absent because
# load_dear_abby derives the century from each letter's year; load_sefaria uses
# .get(..., 0) for anything missing.
PERIOD_TO_CENTURY = {
    TimePeriod.BIBLICAL: -6,
    TimePeriod.SECOND_TEMPLE: -2,
    TimePeriod.TANNAITIC: 2,
    TimePeriod.AMORAIC: 4,
    TimePeriod.GEONIC: 8,
    TimePeriod.RISHONIM: 12,
    TimePeriod.ACHRONIM: 17,
    TimePeriod.MODERN_HEBREW: 20,
}

def _read_sefaria_json(path: Path):
    """Parse one Sefaria JSON file, returning None if it cannot be read or decoded."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return None


def _sefaria_parallel_texts(json_file: Path, data):
    """Return ``(hebrew, english, book)`` for one Sefaria file, or None to skip it.

    Two layouts are accepted:

    * A single document carrying both languages (``'he'`` plus ``'text'``/``'en'``).
    * One language per file in sibling ``Hebrew/`` and ``English/`` directories
      (e.g. ``<Book>/Hebrew/merged.json`` and ``<Book>/English/merged.json``).
      Only files under ``Hebrew/`` are entry points; the English counterpart is
      loaded alongside them.

    Single-language files without a counterpart are skipped. Reading ``'text'``
    for both fields would make text_original and text_english identical.
    """
    if not isinstance(data, dict):
        return None
    if 'he' in data:
        return data['he'], data.get('text', data.get('en', [])), json_file.stem
    if json_file.parent.name != "Hebrew":
        return None
    # Prefer merged.json when present, so each book is emitted once rather than
    # once per Hebrew edition.
    if json_file.name != "merged.json" and (json_file.parent / "merged.json").exists():
        return None
    english_dir = json_file.parent.parent / "English"
    for candidate in (english_dir / json_file.name, english_dir / "merged.json"):
        english_data = _read_sefaria_json(candidate) if candidate.exists() else None
        if isinstance(english_data, dict):
            return data.get('text', []), english_data.get('text', []), json_file.parent.parent.name
    return None


def _iter_aligned_segments(hebrew, english, ref=""):
    """Yield ``(ref, hebrew_str, english_str)`` for each aligned leaf of two nested lists.

    ``ref`` is a dotted 1-based position such as ``"3.12"``. Siblings are paired
    with zip(), so extra items on the longer side are dropped. Nodes that are not
    a str/str or list/list pair (e.g. dict-shaped sections) are skipped.
    """
    if isinstance(hebrew, str) and isinstance(english, str):
        yield ref, hebrew, english
    elif isinstance(hebrew, list) and isinstance(english, list):
        for i, (h, e) in enumerate(zip(hebrew, english), start=1):
            yield from _iter_aligned_segments(h, e, f"{ref}.{i}" if ref else str(i))


def load_sefaria(base_path: str, max_files: int = None) -> List[Passage]:
    """Load aligned Hebrew/English passages from a Sefaria-Export checkout.

    Args:
        base_path: Path to Sefaria-Export directory (must contain ``json/``).
        max_files: Maximum number of JSON files to process (NOT passages).
                   Set to None or 0 for unlimited. Files are taken in sorted
                   path order, so a limited run always sees the same subset.

    Returns:
        One Passage per aligned segment whose English text is 50-2000
        characters and whose Hebrew text is non-empty. Returns [] if
        ``json/`` is missing.
    """
    passages = []
    json_path = Path(base_path) / "json"

    if not json_path.exists():
        print(f"Warning: {json_path} not found")
        return []

    # Sorted: rglob order is filesystem-dependent, which made max_files and the
    # downstream seeded shuffles irreproducible.
    json_files = sorted(json_path.rglob("*.json"))
    print(f"Found {len(json_files):,} JSON files...")

    if max_files and max_files > 0:
        files_to_process = json_files[:max_files]
        print(f"Processing {len(files_to_process):,} files (max_files={max_files})...")
    else:
        files_to_process = json_files
        print(f"Processing ALL {len(files_to_process):,} files...")

    tag_re = re.compile(r'<[^>]+>')
    skipped = 0

    for json_file in tqdm(files_to_process, desc="Loading Sefaria", unit="file"):
        data = _read_sefaria_json(json_file)
        texts = None if data is None else _sefaria_parallel_texts(json_file, data)
        if texts is None:
            skipped += 1
            continue
        hebrew, english, book = texts

        rel_path = json_file.relative_to(json_path)
        category = str(rel_path.parts[0]) if rel_path.parts else "unknown"
        time_period = CATEGORY_TO_PERIOD.get(category, TimePeriod.AMORAIC)
        century = PERIOD_TO_CENTURY.get(time_period, 0)

        for ref, h, e in _iter_aligned_segments(hebrew, english):
            h_clean = tag_re.sub('', h).strip()
            e_clean = tag_re.sub('', e).strip()
            # The model reads the Hebrew side, so segments without it are useless.
            if not h_clean or not (50 <= len(e_clean) <= 2000):
                continue
            # rel_path (not just the file stem) keeps IDs distinct across books
            # that share a file name such as merged.json.
            pid = hashlib.md5(f"{rel_path}:{ref}:{h_clean[:50]}".encode()).hexdigest()[:12]
            passages.append(Passage(
                id=f"sefaria_{pid}",
                text_original=h_clean,
                text_english=e_clean,
                time_period=time_period.name,
                century=century,
                source=f"{book} {ref}".strip(),
                source_type="sefaria",
                category=category,
                language="hebrew",
                word_count=len(e_clean.split())
            ))

    print(f"Skipped {skipped:,} files (English counterparts, unreadable, or without a parallel text)")
    if not passages:
        print("Warning: no aligned Hebrew/English passages were found")
    return passages

def load_dear_abby(path: str, max_passages: int = None) -> List[Passage]:
    """Load Dear Abby letters as English-only passages.

    Args:
        path: CSV file. Reads the 'question_only' column and, when present, 'year'.
        max_passages: Stop after this many passages; None or 0 means unlimited.

    Returns:
        One Passage per question of 50-2000 characters. text_original and
        text_english both hold the question. A missing or blank year is
        treated as 1990.
    """
    passages = []
    df = pd.read_csv(path)

    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading Dear Abby", unit="row"):
        question = str(row.get('question_only', ''))
        if not question or question == 'nan' or len(question) < 50 or len(question) > 2000:
            continue

        raw_year = row.get('year', 1990)
        # Blank cells arrive as NaN, which int() cannot convert.
        year = 1990 if pd.isna(raw_year) else int(raw_year)
        pid = hashlib.md5(f"abby:{idx}:{question[:50]}".encode()).hexdigest()[:12]

        passages.append(Passage(
            id=f"abby_{pid}",
            text_original=question,
            text_english=question,
            time_period=TimePeriod.DEAR_ABBY.name,
            century=20 if year < 2000 else 21,
            source=f"Dear Abby {year}",
            source_type="dear_abby",
            category="general",
            language="english",
            word_count=len(question.split())
        ))

        if max_passages and len(passages) >= max_passages:
            break

    return passages

# Reminder: Dear Abby passages carry the same English question in text_original
# and text_english, while Sefaria passages keep the Hebrew in text_original.
print("Data structures defined!")
print()
print("NOTE: Bond/Hohfeld labels will be extracted from text_english.")
print("For Hebrew texts, this means labels come from English translations.")

In [None]:
#@title 5. Load and Preprocess Corpora { display-mode: "form" }
#@markdown Loads Hebrew (Sefaria) and English (Dear Abby) corpora.
#@markdown **L4 Optimized**: Set to 0 for full corpus (recommended for L4).

#@markdown **Memory Management:**
MAX_SEFARIA_FILES = 0  #@param {type:"integer"}
#@markdown Set to 0 for FULL corpus (recommended for L4 with 53GB RAM).
#@markdown Set to 5000 for faster testing.

import gc
import json
from collections import defaultdict

mark_task("Preprocess corpora", "running")

print("=" * 60)
print("LOADING CORPORA")
print("=" * 60)
print()
print("Current scope: Hebrew (Sefaria) ↔ English (Dear Abby)")
print()

if MAX_SEFARIA_FILES > 0:
    print(f"*** LIMITED MODE: Processing {MAX_SEFARIA_FILES:,} Sefaria JSON FILES ***")
else:
    print("*** FULL CORPUS MODE: Processing ALL Sefaria files ***")
print()

limit = MAX_SEFARIA_FILES if MAX_SEFARIA_FILES > 0 else None
sefaria_passages = load_sefaria("data/raw/Sefaria-Export", max_files=limit)
print(f"\nSefaria passages loaded: {len(sefaria_passages):,}")

# With no Hebrew passages the ancient side of every split is empty and the
# experiment is meaningless, so stop here rather than several cells later.
if not sefaria_passages:
    raise RuntimeError(
        "No Sefaria passages were loaded from data/raw/Sefaria-Export; "
        "check the Sefaria download cell before continuing."
    )

print_resources("After Sefaria load")
gc.collect()

print()
abby_passages = load_dear_abby("data/raw/dear_abby.csv")
print(f"\nDear Abby passages loaded: {len(abby_passages):,}")

all_passages = sefaria_passages + abby_passages

# Drop the per-corpus lists; all_passages holds the same objects.
del sefaria_passages
del abby_passages
gc.collect()

print()
print("=" * 60)
print(f"TOTAL PASSAGES: {len(all_passages):,}")
print("=" * 60)

by_period = defaultdict(int)
by_source = defaultdict(int)
for p in all_passages:
    by_period[p.time_period] += 1
    by_source[p.source_type] += 1

print("\nBy source:")
for source, count in sorted(by_source.items()):
    print(f"  {source}: {count:,}")

print("\nBy time period:")
# Chronological (TimePeriod value) order instead of alphabetical; unknown names go last.
period_order = {tp.name: tp.value for tp in TimePeriod}
for period, count in sorted(by_period.items(),
                            key=lambda kv: (period_order.get(kv[0], len(period_order)), kv[0])):
    pct = count / len(all_passages) * 100
    bar = '#' * int(pct / 2)
    print(f"  {period:20s}: {count:6,} ({pct:5.1f}%) {bar}")

print_resources("After loading")

mark_task("Preprocess corpora", "done")


In [None]:
#@title 6. Extract Bond Structures { display-mode: "form" }
#@markdown Extracts moral bond structures from **English text** (translations for Hebrew).

# Cell-local imports; Passage, BondType and HohfeldianState come from the
# data-structures cell.
import gc
import json
import re
from collections import defaultdict
from tqdm.auto import tqdm

mark_task("Extract bond structures", "running")

print("=" * 60)
print("EXTRACTING BOND STRUCTURES")
print("=" * 60)
print()
print("NOTE: Patterns applied to text_english (translations for Hebrew).")
print("This means Hebrew text labels are derived from English translations.")
print()

# Keyword regexes per bond type, applied to the lowercased English text.
# A passage gets every type with at least one match, or 'NONE' if none match.
# NOTE(review): labelling caveats visible from the patterns themselves:
#   * \b...\b matches whole words only, so inflections ("killed", "promised",
#     "helping") do not match.
#   * Several words are ambiguous. "just" (FAIRNESS) also matches the adverb
#     ("I just want..."), "right" (AUTONOMY) matches "right hand", and "own"
#     (PROPERTY) matches "my own". Expect false positives.
#   * "danger" and "rescue" appear under both HARM_PREVENTION and EMERGENCY.
#   * Dict order matters: extract_bond_structure reports the first matching
#     type as primary_relation, so HARM_PREVENTION outranks all others.
RELATION_PATTERNS = {
    BondType.HARM_PREVENTION: [r'\b(kill|murder|harm|hurt|save|rescue|protect|danger|attack|injure|wound|destroy)\b'],
    BondType.RECIPROCITY: [r'\b(return|repay|owe|debt|mutual|exchange|give back|pay back|reciprocate)\b'],
    BondType.AUTONOMY: [r'\b(choose|decision|consent|agree|force|coerce|right|freedom|liberty|self-determination)\b'],
    BondType.PROPERTY: [r'\b(property|own|steal|theft|buy|sell|land|possess|belong|asset)\b'],
    BondType.FAMILY: [r'\b(honor|parent|marry|divorce|inherit|family|mother|father|child|son|daughter|spouse|husband|wife)\b'],
    BondType.AUTHORITY: [r'\b(obey|command|law|judge|rule|teach|leader|king|master|servant|subject)\b'],
    BondType.CARE: [r'\b(care|help|assist|feed|clothe|visit|nurture|tend|support|comfort)\b'],
    BondType.FAIRNESS: [r'\b(fair|just|equal|deserve|bias|impartial|equity|discrimination)\b'],
    BondType.EMERGENCY: [r'\b(emergency|urgent|crisis|danger|life-threatening|immediate|desperate|dire|peril|rescue)\b'],
    BondType.CONTRACT: [r'\b(contract|agreement|promise|vow|oath|covenant|pledge|commit|bind|treaty|negotiate)\b'],
}

# Modal-word regexes for the Hohfeldian state. extract_bond_structure keeps only
# the FIRST state (in this dict's order) with a match, so OBLIGATION beats RIGHT,
# which beats LIBERTY. Texts with no match get None, and NO_RIGHT is never produced.
# NOTE(review): "should", "may" and "can" are very common function words
# (e.g. "Should I ...?" in advice letters), so these labels are likely skewed
# toward OBLIGATION/LIBERTY. Check the distribution printed when splits are built.
HOHFELD_PATTERNS = {
    HohfeldianState.OBLIGATION: [r'\b(must|shall|duty|require|should|ought|obligated)\b'],
    HohfeldianState.RIGHT: [r'\b(right to|entitled|deserve|claim|due)\b'],
    HohfeldianState.LIBERTY: [r'\b(may|can|permitted|allowed|free to|at liberty)\b'],
}

def extract_bond_structure(passage: Passage) -> Dict:
    """Label a passage's English text with bond types and one Hohfeldian state.

    Returns a dict with, in this key order:
        bonds: [{'relation': name}, ...] for every matching bond type, in
            RELATION_PATTERNS order, or [{'relation': 'NONE'}] if none match.
        primary_relation: the first entry of that list.
        hohfeld_state: name of the first HOHFELD_PATTERNS state that matches, or None.
        signature: the distinct relation names, sorted and joined with "|".
    """
    lowered = passage.text_english.lower()

    def _hits(patterns):
        # True as soon as any pattern matches (same short-circuit as a break).
        return any(re.search(pattern, lowered, re.IGNORECASE) for pattern in patterns)

    relations = [bond.name for bond, patterns in RELATION_PATTERNS.items() if _hits(patterns)]
    relations = relations or ['NONE']

    hohfeld = next(
        (state.name for state, patterns in HOHFELD_PATTERNS.items() if _hits(patterns)),
        None,
    )

    return {
        'bonds': [{'relation': name} for name in relations],
        'primary_relation': relations[0],
        'hohfeld_state': hohfeld,
        'signature': "|".join(sorted(set(relations))),
    }

print("Writing to disk...")
print()

bond_counts = defaultdict(int)

# One JSON object per line in each file, written in the same passage order, so
# line i of passages.jsonl and line i of bond_structures.jsonl describe the same passage.
with open("data/processed/passages.jsonl", 'w', encoding='utf-8') as f_pass, \
     open("data/processed/bond_structures.jsonl", 'w', encoding='utf-8') as f_bond:

    for passage in tqdm(all_passages, desc="Processing", unit="passage"):
        bond_struct = extract_bond_structure(passage)
        # Store the labels on the passage so passages.jsonl carries them as well.
        passage.bond_types = [b['relation'] for b in bond_struct['bonds']]

        for bond in bond_struct['bonds']:
            bond_counts[bond['relation']] += 1

        f_pass.write(json.dumps(passage.to_dict()) + '\n')
        f_bond.write(json.dumps({
            'passage_id': passage.id,
            'bond_structure': bond_struct
        }) + '\n')

n_passages = len(all_passages)
# Free the in-memory corpus; the split cell re-reads passages.jsonl from disk.
del all_passages
gc.collect()

print()
print(f"Saved {n_passages:,} passages to disk")

print_resources("After extraction")

print()
print("Bond type distribution:")
# A passage can carry several bonds, so these are shares of all bond mentions,
# not of passages. The total is computed once rather than per row.
total_bonds = sum(bond_counts.values())
if total_bonds == 0:
    print("  (no bonds extracted - the corpus is empty)")
else:
    for bond_type, count in sorted(bond_counts.items(), key=lambda x: -x[1]):
        pct = count / total_bonds * 100
        bar = '#' * int(pct)
        print(f"  {bond_type:20s}: {count:6,} ({pct:5.1f}%) {bar}")

    if bond_counts.get('NONE', 0) / total_bonds > 0.5:
        print()
        print("WARNING: NONE class is >50% - consider expanding patterns.")

mark_task("Extract bond structures", "done")


In [None]:
#@title 7. Generate Train/Test Splits { display-mode: "form" }
#@markdown Creates temporal splits for cross-era evaluation.

import random
import gc
import json
from tqdm.auto import tqdm
from collections import defaultdict

# Seed first: every shuffle in this cell draws from this RNG stream, in order.
random.seed(42)

mark_task("Generate train/test splits", "running")

for banner_line in ("=" * 60, "GENERATING SPLITS", "=" * 60, ""):
    print(banner_line)

print("Reading passage metadata from disk...")
# Only (id, time_period) pairs are kept; passage text is not held in memory.
with open("data/processed/passages.jsonl", 'r') as f:
    passage_info = [
        (record['id'], record['time_period'])
        for record in map(json.loads, tqdm(f, desc="Reading IDs", unit="line"))
    ]

print(f"Loaded {len(passage_info):,} passage IDs")

train_periods = {'BIBLICAL', 'SECOND_TEMPLE', 'TANNAITIC', 'AMORAIC', 'GEONIC', 'RISHONIM'}
valid_periods = {'ACHRONIM'}
test_periods = {'MODERN_HEBREW', 'DEAR_ABBY'}

print()
print("Filtering by time period...")
# The three period sets are disjoint, so a single pass routes each passage to at
# most one bucket while preserving file order within each bucket.
ancient_ids, early_modern_ids, modern_ids = [], [], []
for pid, tp in passage_info:
    if tp in train_periods:
        ancient_ids.append((pid, tp))
    elif tp in valid_periods:
        early_modern_ids.append((pid, tp))
    elif tp in test_periods:
        modern_ids.append((pid, tp))

print(f"  Ancient/Medieval: {len(ancient_ids):,}")
print(f"  Early Modern:     {len(early_modern_ids):,}")
print(f"  Modern:           {len(modern_ids):,}")

# Shuffle order is part of the seeded contract: it fixes which IDs each split gets.
random.shuffle(ancient_ids)
random.shuffle(early_modern_ids)
random.shuffle(modern_ids)

# SPLIT A: ANCIENT -> MODERN
# Train on every ancient/medieval passage, validate on early-modern (ACHRONIM),
# test on modern passages.
print()
print("-" * 60)
print("SPLIT A: Train ANCIENT, Test MODERN")
print("-" * 60)

temporal_A = {
    'name': 'ancient_to_modern',
    'direction': 'A->M',
    'train_ids': [pid for pid, _ in ancient_ids],
    'valid_ids': [pid for pid, _ in early_modern_ids],
    'test_ids': [pid for pid, _ in modern_ids],
    'train_size': len(ancient_ids),
    'valid_size': len(early_modern_ids),
    'test_size': len(modern_ids)
}
print(f"  Train: {temporal_A['train_size']:,}")
print(f"  Valid: {temporal_A['valid_size']:,}")
print(f"  Test:  {temporal_A['test_size']:,}")

# SPLIT B: MODERN -> ANCIENT
print()
print("-" * 60)
print("SPLIT B: Train MODERN, Test ANCIENT")
print("-" * 60)

n_modern = len(modern_ids)
# Ancient test set: up to n_modern passages from the already-shuffled ancient
# list. A prefix is still a uniform random sample. Unlike the window
# ancient_ids[n_modern:2*n_modern], it never comes up short (or empty) when
# there are fewer than 2 * n_modern ancient passages.
ancient_test = ancient_ids[:n_modern]

temporal_B = {
    'name': 'modern_to_ancient',
    'direction': 'M->A',
    'train_ids': [pid for pid, _ in modern_ids],
    'valid_ids': [pid for pid, _ in early_modern_ids[:len(early_modern_ids)//2]],
    'test_ids': [pid for pid, _ in ancient_test],
    'train_size': len(modern_ids),
    'valid_size': len(early_modern_ids) // 2,
    'test_size': len(ancient_test)
}
print(f"  Train: {temporal_B['train_size']:,}")
print(f"  Valid: {temporal_B['valid_size']:,}")
print(f"  Test:  {temporal_B['test_size']:,}")

# SPLIT C: MIXED (in-domain baseline)
print()
print("-" * 60)
print("SPLIT C: MIXED (In-Domain Baseline)")
print("-" * 60)

# Ancient and modern passages pooled and shuffled together, then cut 70/15/15.
# Early-modern (ACHRONIM) passages are not included in this split.
all_ids = ancient_ids + modern_ids
random.shuffle(all_ids)
n = len(all_ids)
n_train = int(0.7 * n)
n_valid = int(0.15 * n)

# The test slice takes the remainder, so int() rounding never drops passages.
temporal_C = {
    'name': 'mixed_control',
    'direction': 'MIXED',
    'train_ids': [pid for pid, _ in all_ids[:n_train]],
    'valid_ids': [pid for pid, _ in all_ids[n_train:n_train+n_valid]],
    'test_ids': [pid for pid, _ in all_ids[n_train+n_valid:]],
    'train_size': n_train,
    'valid_size': n_valid,
    'test_size': n - n_train - n_valid
}
print(f"  Train: {temporal_C['train_size']:,}")
print(f"  Valid: {temporal_C['valid_size']:,}")
print(f"  Test:  {temporal_C['test_size']:,}")

# The split dicts hold plain ID lists, so the (id, period) tuples can go.
del passage_info, ancient_ids, early_modern_ids, modern_ids, all_ids
gc.collect()

print()
print("Saving splits...")
splits = {
    'ancient_to_modern': temporal_A,
    'modern_to_ancient': temporal_B,
    'mixed_control': temporal_C
}

with open("data/splits/all_splits.json", 'w') as f:
    json.dump(splits, f, indent=2)

# DISTRIBUTION CHECK with ID integrity
print()
print("=" * 60)
print("LABEL DISTRIBUTION CHECK (with ID integrity verification)")
print("=" * 60)

from itertools import zip_longest

hohfeld_counts = {}
time_counts = {}
bond_type_counts = {}
id_mismatches = 0

# The extraction cell writes both files line-for-line in the same order, so
# line i of each must describe the same passage. zip_longest (rather than zip)
# reports a truncated file as mismatches instead of silently stopping early.
with open("data/processed/bond_structures.jsonl", 'r', encoding='utf-8') as fb, \
     open("data/processed/passages.jsonl", 'r', encoding='utf-8') as fp:
    for b_line, p_line in zip_longest(fb, fp):
        if b_line is None or p_line is None:
            id_mismatches += 1
            continue
        b = json.loads(b_line)
        p = json.loads(p_line)

        # ID integrity check
        if b['passage_id'] != p['id']:
            id_mismatches += 1
            continue

        h = b['bond_structure'].get('hohfeld_state', None)
        t = p['time_period']
        bond = b['bond_structure'].get('primary_relation', 'NONE')

        hohfeld_counts[h] = hohfeld_counts.get(h, 0) + 1
        time_counts[t] = time_counts.get(t, 0) + 1
        bond_type_counts[bond] = bond_type_counts.get(bond, 0) + 1

if id_mismatches > 0:
    print(f"WARNING: {id_mismatches} ID mismatches found!")
else:
    print("ID integrity check: PASSED ✓")
print()

# Every baseline below divides by these counts. Fail with a clear message rather
# than a ZeroDivisionError if nothing was labelled. The three dicts are filled
# together, so checking one suffices.
if not hohfeld_counts:
    raise RuntimeError(
        "No matching passage/bond records found in data/processed; "
        "rerun the preprocessing and bond-extraction cells."
    )

print("Hohfeld distribution:")
total_h = sum(hohfeld_counts.values())
for h, c in sorted(hohfeld_counts.items(), key=lambda x: -x[1]):
    pct = 100 * c / total_h
    bar = "#" * int(pct / 2)
    print(f"  {str(h):15s}: {c:>8,} ({pct:5.1f}%) {bar}")

print()
print("Time period distribution:")
total_t = sum(time_counts.values())
for t, c in sorted(time_counts.items(), key=lambda x: -x[1]):
    pct = 100 * c / total_t
    bar = "#" * int(pct / 2)
    print(f"  {t:15s}: {c:>8,} ({pct:5.1f}%) {bar}")

print()
print("Bond type distribution:")
total_b = sum(bond_type_counts.values())
for b, c in sorted(bond_type_counts.items(), key=lambda x: -x[1]):
    pct = 100 * c / total_b
    bar = "#" * int(pct / 2)
    print(f"  {b:20s}: {c:>8,} ({pct:5.1f}%) {bar}")

# The +1 counts "no Hohfeld match" (None) as its own class.
# NOTE(review): if no passage has hohfeld_state None this overcounts by one.
N_HOHFELD_CLASSES = len([h for h in hohfeld_counts if h is not None]) + 1
N_TIME_CLASSES = len(time_counts)
N_BOND_CLASSES = len(bond_type_counts)

# Uniform-guess baselines.
CHANCE_HOHFELD = 1.0 / N_HOHFELD_CLASSES
CHANCE_TIME = 1.0 / N_TIME_CLASSES
CHANCE_BOND = 1.0 / N_BOND_CLASSES

# Majority-class baselines: the accuracy of always predicting the most frequent
# label. With imbalanced labels this is the bar a classifier must beat.
MAJORITY_HOHFELD = max(hohfeld_counts.values()) / total_h
MAJORITY_TIME = max(time_counts.values()) / total_t
MAJORITY_BOND = max(bond_type_counts.values()) / total_b

print()
print(f"Chance baseline - Hohfeld: {CHANCE_HOHFELD:.1%} ({N_HOHFELD_CLASSES} classes)")
print(f"Chance baseline - Time:    {CHANCE_TIME:.1%} ({N_TIME_CLASSES} classes)")
print(f"Chance baseline - Bond:    {CHANCE_BOND:.1%} ({N_BOND_CLASSES} classes)")
print(f"Majority baseline - Hohfeld: {MAJORITY_HOHFELD:.1%}")
print(f"Majority baseline - Time:    {MAJORITY_TIME:.1%}")
print(f"Majority baseline - Bond:    {MAJORITY_BOND:.1%}")

baselines = {
    'hohfeld_counts': {str(k): v for k, v in hohfeld_counts.items()},
    'time_counts': time_counts,
    'bond_type_counts': bond_type_counts,
    'chance_hohfeld': CHANCE_HOHFELD,
    'chance_time': CHANCE_TIME,
    'chance_bond': CHANCE_BOND,
    'majority_hohfeld': MAJORITY_HOHFELD,
    'majority_time': MAJORITY_TIME,
    'majority_bond': MAJORITY_BOND,
    'n_hohfeld_classes': N_HOHFELD_CLASSES,
    'n_time_classes': N_TIME_CLASSES,
    'n_bond_classes': N_BOND_CLASSES
}
with open("data/splits/baselines.json", 'w') as f:
    json.dump(baselines, f, indent=2)

most_common_hohfeld = MAJORITY_HOHFELD
if most_common_hohfeld > 0.7:
    print()
    print(f"WARNING: Hohfeld labels severely imbalanced! Most common = {most_common_hohfeld:.1%}")

# Save to Drive
# SAVE_DIR is the Google Drive folder created in the setup cell.
print()
print("Saving preprocessed data to Google Drive...")
import shutil
# dirs_exist_ok=True lets a rerun copy over an existing Drive folder; files that
# exist only in the destination are left in place, not removed.
shutil.copytree("data/processed", f"{SAVE_DIR}/processed", dirs_exist_ok=True)
shutil.copytree("data/splits", f"{SAVE_DIR}/splits", dirs_exist_ok=True)
print(f"Saved to {SAVE_DIR}")

print_resources("After splits")

mark_task("Generate train/test splits", "done")


In [None]:
#@title 8. Define BIP Model Architecture { display-mode: "form" }
#@markdown Model with bond prediction head and support for linear probe extraction.

import gc
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from tqdm.auto import tqdm

# Release Python garbage and cached CUDA blocks left by earlier cells before
# loading the pretrained encoder.
print("Clearing memory before model load...")
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print_resources("Before model definition")
print()

print("=" * 60)
print("DEFINING MODEL ARCHITECTURE (v7.1)")
print("=" * 60)
print()
print("Encoder: paraphrase-multilingual-MiniLM-L12-v2")
print("  - Maps Hebrew and English into shared embedding space")
print("  - Input: Hebrew (Sefaria) or English (Dear Abby)")
print("  - Labels: Derived from English translations")
print()

class GradientReversal(torch.autograd.Function):
    """Identity on the forward pass; scales the incoming gradient by -lambda_ on the way back.

    Placing this between a representation and a classifier makes minimising
    that classifier's loss push the representation *against* the classifier.
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Keep the reversal strength for backward; return a copy rather than x itself.
        ctx.lambda_ = lambda_
        result = x.clone()
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: x gets the reversed gradient, lambda_ none.
        reversed_grad = grad_output * -ctx.lambda_
        return reversed_grad, None

def gradient_reversal(x, lambda_=1.0):
    """Pass *x* through unchanged while reversing (and scaling by lambda_) its gradient."""
    reverse = GradientReversal.apply
    return reverse(x, lambda_)

class BIPEncoder(nn.Module):
    """Pretrained transformer followed by attention-masked mean pooling over tokens."""

    def __init__(self, model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", d_model=384):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Stored for callers; not checked against the loaded model's hidden size.
        self.d_model = d_model

    def forward(self, input_ids, attention_mask):
        """Return the mean of token states over positions where attention_mask is 1."""
        token_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        weights = attention_mask.unsqueeze(-1).float()
        summed = (token_states * weights).sum(1)
        # The clamp guards against division by zero for an all-padding row.
        counts = weights.sum(1).clamp(min=1e-9)
        return summed / counts

class BIPModel(nn.Module):
    """Shared encoder with a bond branch (z_bond) and a label branch (z_label).

    z_bond feeds the Hohfeld and bond-type classifiers directly and a time
    classifier through gradient reversal; z_label feeds its own time
    classifier without reversal. ``extract_z_bond`` returns z_bond without
    gradients for the linear-probe test.
    """

    def __init__(self, d_model=384, d_bond=64, d_label=32, n_periods=9, n_hohfeld=4, n_bonds=11):
        super().__init__()

        self.encoder = BIPEncoder()

        # Submodules are created in a fixed order (encoder, bond_proj,
        # label_proj, then the heads); the attribute names double as the
        # state_dict keys written to and read from the checkpoints.
        self.bond_proj = self._projection(d_model, d_bond)
        self.label_proj = self._projection(d_model, d_label)

        self.time_classifier_bond = nn.Linear(d_bond, n_periods)
        self.time_classifier_label = nn.Linear(d_label, n_periods)
        self.hohfeld_classifier = nn.Linear(d_bond, n_hohfeld)
        self.bond_classifier = nn.Linear(d_bond, n_bonds)

    @staticmethod
    def _projection(d_in, d_out):
        """Two-layer MLP: d_in -> d_in // 2 -> d_out, with ReLU and dropout 0.1."""
        d_hidden = d_in // 2
        return nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden, d_out),
        )

    def forward(self, input_ids, attention_mask, adversarial_lambda=1.0):
        pooled = self.encoder(input_ids, attention_mask)
        z_bond = self.bond_proj(pooled)
        z_label = self.label_proj(pooled)

        # Only the z_bond time head sits behind gradient reversal; the
        # adversarial_lambda value sets the reversal strength.
        time_from_bond = self.time_classifier_bond(gradient_reversal(z_bond, adversarial_lambda))

        return {
            'z_bond': z_bond,
            'z_label': z_label,
            'time_pred_bond': time_from_bond,
            'time_pred_label': self.time_classifier_label(z_label),
            'hohfeld_pred': self.hohfeld_classifier(z_bond),
            'bond_pred': self.bond_classifier(z_bond),
        }

    def extract_z_bond(self, input_ids, attention_mask):
        """Return z_bond for a batch with gradient tracking disabled (linear probe input)."""
        with torch.no_grad():
            return self.bond_proj(self.encoder(input_ids, attention_mask))

# Label-index maps. The indices are the class ids used as cross-entropy
# targets and classifier output positions, so reordering entries would
# invalidate trained checkpoints.
TIME_PERIOD_TO_IDX = {
    period: idx
    for idx, period in enumerate((
        'BIBLICAL', 'SECOND_TEMPLE', 'TANNAITIC', 'AMORAIC',
        'GEONIC', 'RISHONIM', 'ACHRONIM', 'MODERN_HEBREW',
        'DEAR_ABBY',
    ))
}
IDX_TO_TIME_PERIOD = dict(zip(TIME_PERIOD_TO_IDX.values(), TIME_PERIOD_TO_IDX.keys()))

# None (no Hohfeld state extracted) is its own class, index 3.
HOHFELD_TO_IDX = {state: idx for idx, state in enumerate(('OBLIGATION', 'RIGHT', 'LIBERTY', None))}

BOND_TYPE_TO_IDX = {
    bond: idx
    for idx, bond in enumerate((
        'HARM_PREVENTION', 'RECIPROCITY', 'AUTONOMY', 'PROPERTY',
        'FAMILY', 'AUTHORITY', 'EMERGENCY', 'CONTRACT',
        'CARE', 'FAIRNESS', 'NONE',
    ))
}
IDX_TO_BOND_TYPE = dict(zip(BOND_TYPE_TO_IDX.values(), BOND_TYPE_TO_IDX.keys()))

class MoralDataset(Dataset):
    """Tokenised passages with time / Hohfeld / bond labels for one split.

    Reads ``passages_file`` and ``bonds_file`` (JSON Lines, one object per line)
    in lockstep, keeps the rows whose passage id is in ``passage_ids``, and
    tokenises lazily in ``__getitem__``.

    Args:
        passage_ids: ids of the passages in this split (a set, for O(1) membership).
        passages_file: JSONL path; rows use ``id``, ``language``, ``time_period``,
            ``source_type`` and ``text_original`` / ``text_english``.
        bonds_file: JSONL path, line-aligned with ``passages_file``; rows use
            ``passage_id`` and ``bond_structure`` (``hohfeld_state``,
            ``primary_relation``).
        tokenizer: callable accepting Hugging Face-style tokenizer keyword arguments.
        max_len: pad/truncate length in tokens.
    """

    # Languages whose original-script text is fed to the encoder; passages in
    # any other language use their English text.
    _ORIGINAL_TEXT_LANGUAGES = ('hebrew', 'chinese', 'arabic')

    def __init__(self, passage_ids: set, passages_file: str, bonds_file: str, tokenizer, max_len=64):
        from itertools import zip_longest

        self.tokenizer = tokenizer
        self.max_len = max_len
        self.passage_ids = passage_ids

        print(f"  Indexing {len(passage_ids):,} passages...")

        self.data = []
        id_mismatches = 0
        unpaired_lines = 0

        # Decode explicitly as UTF-8 (the corpus contains Hebrew) rather than
        # relying on the platform's default locale encoding.
        with open(passages_file, 'r', encoding='utf-8') as f_pass, open(bonds_file, 'r', encoding='utf-8') as f_bond:
            line_pairs = zip_longest(f_pass, f_bond)
            for p_line, b_line in tqdm(line_pairs, desc="  Loading subset", unit="line", total=None):
                if p_line is None or b_line is None:
                    # One file has extra lines; plain zip() would drop them silently.
                    unpaired_lines += 1
                    continue

                p = json.loads(p_line)
                b = json.loads(b_line)

                # ID integrity check: the files are expected to be line-aligned.
                if b['passage_id'] != p['id']:
                    id_mismatches += 1
                    continue

                if p['id'] in passage_ids:
                    if p.get('language') in self._ORIGINAL_TEXT_LANGUAGES:
                        text = p.get('text_original')
                    else:
                        text = p.get('text_english')
                    self.data.append({
                        # `or ''` also covers fields that are present but null.
                        'text': (text or '')[:1000],
                        'time_period': p['time_period'],
                        'source_type': p['source_type'],
                        'hohfeld': b['bond_structure']['hohfeld_state'],
                        'primary_relation': b['bond_structure']['primary_relation']
                    })

        if id_mismatches > 0:
            print(f"  WARNING: {id_mismatches} ID mismatches found and skipped!")
        if unpaired_lines > 0:
            print(f"  WARNING: passages/bonds files differ in length; {unpaired_lines} unpaired lines skipped!")
        print(f"  Loaded {len(self.data):,} samples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenise one sample and map its labels to class indices."""
        item = self.data[idx]

        encoding = self.tokenizer(
            item['text'],
            truncation=True,
            max_length=self.max_len,
            padding='max_length',
            return_tensors='pt'
        )

        return {
            # squeeze(0) drops the batch dim added by return_tensors='pt'.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            # NOTE(review): unrecognised periods fall back to 8 (DEAR_ABBY),
            # unknown Hohfeld states to 3 (None) and unknown bonds to 10 (NONE);
            # confirm the period fallback is intended.
            'time_label': TIME_PERIOD_TO_IDX.get(item['time_period'], 8),
            'hohfeld_label': HOHFELD_TO_IDX.get(item['hohfeld'], 3),
            'bond_label': BOND_TYPE_TO_IDX.get(item['primary_relation'], 10),
            'source_type': item['source_type']
        }

def collate_fn(batch):
    """Merge a list of MoralDataset samples into one batch dict.

    Token tensors are stacked along a new leading dimension, integer labels
    become 1-D tensors, and ``source_type`` strings stay a plain list. Label
    keys are pluralised (``time_label`` -> ``time_labels``).
    """
    def column(key):
        return [sample[key] for sample in batch]

    collated = {}
    collated['input_ids'] = torch.stack(column('input_ids'))
    collated['attention_mask'] = torch.stack(column('attention_mask'))
    for src_key, dst_key in (
        ('time_label', 'time_labels'),
        ('hohfeld_label', 'hohfeld_labels'),
        ('bond_label', 'bond_labels'),
    ):
        collated[dst_key] = torch.tensor(column(src_key))
    collated['source_types'] = column('source_type')
    return collated

# Release anything freed while defining the classes, then summarise the label spaces.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

_summary_lines = [
    "Model architecture defined!",
    f"  - Bond types: {len(BOND_TYPE_TO_IDX)} classes",
    f"  - Time periods: {len(TIME_PERIOD_TO_IDX)} classes",
    f"  - Hohfeld states: {len(HOHFELD_TO_IDX)} classes",
    f"  - Base batch size: {BASE_BATCH_SIZE}",
]
print("\n".join(_summary_lines))

print_resources("After model definition")


In [None]:
#@title 9. Train BIP Model { display-mode: "form" }
#@markdown Trains bidirectionally with bond classification.
#@markdown **L4 Optimized**: Larger batch sizes, more workers.

import gc
import json
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from sklearn.metrics import f1_score
from tqdm.auto import tqdm

mark_task("Train BIP model", "running")

# Start training from a clean slate: collect garbage and drop cached CUDA blocks.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print_resources("Training start")

_rule = "=" * 60
for _line in (
    _rule,
    "BIDIRECTIONAL BIP TRAINING (v7.1 - L4 Optimized)",
    _rule,
    "",
    f"Accelerator: {ACCELERATOR}",
    f"Device: {device}",
    f"Base batch size: {BASE_BATCH_SIZE}",
    "",
    "Loading tokenizer...",
):
    print(_line)

# Must match the encoder checkpoint used by BIPEncoder.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

import math
import shutil

# Train one fresh BIPModel per split direction, keep the checkpoint with the
# lowest validation loss, then evaluate it on the split's test set.
# Results are collected in `all_results[split_name]`.
all_results = {}

# The split definitions are the same for every direction, so read them once.
with open("data/splits/all_splits.json", 'r') as f:
    splits = json.load(f)

_DIRECTION_LABELS = {
    'ancient_to_modern': 'Ancient → Modern',
    'modern_to_ancient': 'Modern → Ancient',
    'mixed_control': 'Mixed (In-Domain Baseline)'
}


def _save_checkpoint(model, model_path, split_name):
    """Write ``model``'s state_dict to ``model_path`` and mirror it into ``SAVE_DIR``."""
    if USE_TPU:
        xm.save(model.state_dict(), model_path)
    else:
        torch.save(model.state_dict(), model_path)
    shutil.copy(model_path, f"{SAVE_DIR}/best_model_{split_name}.pt")


for split_idx, split_name in enumerate(['ancient_to_modern', 'modern_to_ancient', 'mixed_control']):
    split_start = time.time()
    print()
    print("=" * 60)
    direction_label = _DIRECTION_LABELS[split_name]
    print(f"TRAINING [{split_idx+1}/3]: {direction_label}")
    print("=" * 60)
    print()

    split = splits[split_name]

    print(f"Train: {split['train_size']:,}")
    print(f"Valid: {split['valid_size']:,}")
    print(f"Test:  {split['test_size']:,}")
    print()

    print("Creating fresh model...")
    model = BIPModel().to(device)

    if split_name == 'ancient_to_modern':
        n_params = sum(p.numel() for p in model.parameters())
        print(f"Model parameters: {n_params:,}")

    print("Creating datasets...")
    train_dataset = MoralDataset(
        set(split['train_ids']),
        "data/processed/passages.jsonl",
        "data/processed/bond_structures.jsonl",
        tokenizer
    )
    valid_dataset = MoralDataset(
        set(split['valid_ids']),
        "data/processed/passages.jsonl",
        "data/processed/bond_structures.jsonl",
        tokenizer
    )
    test_dataset = MoralDataset(
        set(split['test_ids']),
        "data/processed/passages.jsonl",
        "data/processed/bond_structures.jsonl",
        tokenizer
    )

    print(f"Train samples: {len(train_dataset):,}")
    print(f"Valid samples: {len(valid_dataset):,}")
    print(f"Test samples:  {len(test_dataset):,}")
    print()

    if len(train_dataset) == 0:
        print("ERROR: No training data!")
        continue
    if len(test_dataset) == 0:
        # Every metric below is a ratio over the test set; don't spend time
        # training a model that cannot be evaluated.
        print("ERROR: No test data!")
        continue

    # L4-optimized batch sizes
    if split_name == 'ancient_to_modern':
        batch_size = BASE_BATCH_SIZE
    else:
        batch_size = min(BASE_BATCH_SIZE, max(32, len(train_dataset) // 10))
    # With drop_last=True a batch larger than the training set would yield zero
    # batches per epoch, so never exceed the number of training samples.
    batch_size = min(batch_size, len(train_dataset))

    # More workers for L4's better CPU
    num_workers = 4 if IS_L4 or IS_A100 else 2

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=collate_fn, drop_last=True,
        num_workers=num_workers, pin_memory=True,
        prefetch_factor=4, persistent_workers=True
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=batch_size*2, shuffle=False,
        collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True,
        prefetch_factor=4, persistent_workers=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size*2, shuffle=False,
        collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True,
        prefetch_factor=4, persistent_workers=True
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

    n_epochs = 5  # More epochs for L4
    best_valid_loss = float('inf')
    patience = 3
    patience_counter = 0
    model_path = f"models/checkpoints/best_model_{split_name}.pt"
    # Set once this run writes a checkpoint, so evaluation never silently loads
    # a stale file left behind by an earlier run.
    checkpoint_saved = False

    print(f"Training for {n_epochs} epochs (batch_size={batch_size}, workers={num_workers})...")
    print()

    for epoch in range(1, n_epochs + 1):
        epoch_start = time.time()
        model.train()
        total_loss = 0
        n_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{n_epochs}", unit="batch")
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            time_labels = batch['time_labels'].to(device)
            hohfeld_labels = batch['hohfeld_labels'].to(device)
            bond_labels = batch['bond_labels'].to(device)

            # torch.amp.autocast('cuda', ...) replaces the deprecated torch.cuda.amp.autocast.
            with torch.amp.autocast('cuda', enabled=USE_AMP):
                outputs = model(input_ids, attention_mask, adversarial_lambda=1.0)

                loss_time_bond = F.cross_entropy(outputs['time_pred_bond'], time_labels)
                loss_time_label = F.cross_entropy(outputs['time_pred_label'], time_labels)
                loss_hohfeld = F.cross_entropy(outputs['hohfeld_pred'], hohfeld_labels)
                loss_bond = F.cross_entropy(outputs['bond_pred'], bond_labels)

            # loss_time_bond flows through gradient reversal, so minimising it
            # pushes time information out of z_bond.
            loss = loss_hohfeld + loss_time_label + loss_time_bond + loss_bond

            optimizer.zero_grad()

            if USE_TPU:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                xm.optimizer_step(optimizer)
                xm.mark_step()
            elif USE_AMP and scaler is not None:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_train_loss = total_loss / max(n_batches, 1)

        # Validation
        model.eval()
        valid_loss = 0
        valid_batches = 0

        with torch.no_grad():
            for batch in valid_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                hohfeld_labels = batch['hohfeld_labels'].to(device)
                bond_labels = batch['bond_labels'].to(device)

                outputs = model(input_ids, attention_mask, adversarial_lambda=0)
                # Select checkpoints on both supervised heads that are evaluated
                # and reported below (Hohfeld and bond), not Hohfeld alone.
                loss = (F.cross_entropy(outputs['hohfeld_pred'], hohfeld_labels)
                        + F.cross_entropy(outputs['bond_pred'], bond_labels))
                valid_loss += loss.item()
                valid_batches += 1

                if USE_TPU:
                    xm.mark_step()

        avg_valid_loss = valid_loss / valid_batches if valid_batches > 0 else 0
        epoch_time = time.time() - epoch_start

        # GPU memory check
        if torch.cuda.is_available():
            gpu_used = torch.cuda.memory_allocated()/1e9
            gpu_total = torch.cuda.get_device_properties(0).total_memory/1e9
            gpu_str = f" | GPU: {gpu_used:.1f}/{gpu_total:.1f}GB"
        else:
            gpu_str = ""

        print(f"Epoch {epoch}: Loss={avg_train_loss:.4f}/{avg_valid_loss:.4f} | {epoch_time:.1f}s{gpu_str}")

        if valid_batches == 0:
            # No validation data: nothing to select on, so keep the latest weights.
            improved = True
        else:
            # A non-finite loss (NaN/inf) never counts as an improvement.
            improved = math.isfinite(avg_valid_loss) and avg_valid_loss < best_valid_loss

        if improved:
            best_valid_loss = avg_valid_loss
            _save_checkpoint(model, model_path, split_name)
            checkpoint_saved = True
            print(f"  -> Saved best model!")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"  Early stopping at epoch {epoch}")
                break

    if not checkpoint_saved:
        print("  WARNING: validation loss was never finite; saving final weights for evaluation.")
        _save_checkpoint(model, model_path, split_name)

    # EVALUATE
    print()
    print(f"Evaluating {split_name}...")

    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model = model.to(device)
    model.eval()

    all_time_preds = []
    all_time_labels = []
    all_hohfeld_preds = []
    all_hohfeld_labels = []
    all_bond_preds = []
    all_bond_labels = []
    all_source_types = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing", unit="batch"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids, attention_mask, adversarial_lambda=0)

            all_time_preds.extend(outputs['time_pred_bond'].argmax(dim=-1).cpu().tolist())
            all_time_labels.extend(batch['time_labels'].tolist())
            all_hohfeld_preds.extend(outputs['hohfeld_pred'].argmax(dim=-1).cpu().tolist())
            all_hohfeld_labels.extend(batch['hohfeld_labels'].tolist())
            all_bond_preds.extend(outputs['bond_pred'].argmax(dim=-1).cpu().tolist())
            all_bond_labels.extend(batch['bond_labels'].tolist())
            all_source_types.extend(batch['source_types'])

            if USE_TPU:
                xm.mark_step()

    time_acc = sum(p == l for p, l in zip(all_time_preds, all_time_labels)) / len(all_time_preds)
    hohfeld_acc = sum(p == l for p, l in zip(all_hohfeld_preds, all_hohfeld_labels)) / len(all_hohfeld_preds)
    bond_acc = sum(p == l for p, l in zip(all_bond_preds, all_bond_labels)) / len(all_bond_preds)

    bond_f1_macro = f1_score(all_bond_labels, all_bond_preds, average='macro', zero_division=0)
    bond_f1_weighted = f1_score(all_bond_labels, all_bond_preds, average='weighted', zero_division=0)
    hohfeld_f1_macro = f1_score(all_hohfeld_labels, all_hohfeld_preds, average='macro', zero_division=0)

    corpus_bond_f1 = {}
    # Sorted for a deterministic report order (set order varies between runs).
    for corpus in sorted(set(all_source_types), key=str):
        in_corpus = [s == corpus for s in all_source_types]
        corpus_preds = [p for p, keep in zip(all_bond_preds, in_corpus) if keep]
        corpus_labels = [l for l, keep in zip(all_bond_labels, in_corpus) if keep]
        if len(corpus_labels) > 0:
            corpus_bond_f1[corpus] = {
                'f1_macro': f1_score(corpus_labels, corpus_preds, average='macro', zero_division=0),
                'accuracy': sum(p == l for p, l in zip(corpus_preds, corpus_labels)) / len(corpus_labels),
                'n_samples': len(corpus_labels)
            }

    split_time = time.time() - split_start

    all_results[split_name] = {
        'time_acc': time_acc,
        'hohfeld_acc': hohfeld_acc,
        'hohfeld_f1_macro': hohfeld_f1_macro,
        'bond_acc': bond_acc,
        'bond_f1_macro': bond_f1_macro,
        'bond_f1_weighted': bond_f1_weighted,
        'corpus_bond_f1': corpus_bond_f1,
        'train_size': split['train_size'],
        'test_size': split['test_size'],
        'training_time_seconds': split_time
    }

    print()
    print(f"{split_name.upper()} RESULTS ({split_time/60:.1f} min):")
    print(f"  Time from z_bond (adversary): {time_acc:.1%}")
    print(f"  Hohfeld classification:       {hohfeld_acc:.1%} (F1={hohfeld_f1_macro:.3f})")
    print(f"  Bond classification:          {bond_acc:.1%} (F1={bond_f1_macro:.3f})")
    if corpus_bond_f1:
        print("  Bond F1 by corpus:")
        for corpus, metrics in corpus_bond_f1.items():
            print(f"    {corpus}: F1={metrics['f1_macro']:.3f}, Acc={metrics['accuracy']:.1%}")

    # Cleanup between splits
    del model, train_dataset, valid_dataset, test_dataset
    del train_loader, valid_loader, test_loader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

_rule = "=" * 60
print(f"\n{_rule}\nTRAINING COMPLETE\n{_rule}")

print_resources("After training")
mark_task("Train BIP model", "done")


In [None]:
#@title 10. Linear Probe Test for Time Invariance { display-mode: "form" }
#@markdown **Primary invariance test**: Can a fresh probe decode time from frozen z_bond?

import json
import gc
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm

mark_task("Linear probe test", "running")

_probe_steps = (
    "Freeze encoder + bond_proj",
    "Extract z_bond on test set",
    "Fit fresh logistic regression to predict time",
    "If probe accuracy ≈ chance, time is truly removed",
)
_rule = "=" * 60
print(f"{_rule}\nLINEAR PROBE TEST FOR TIME INVARIANCE\n{_rule}\n")
print("This is the strongest test of time invariance:")
for _step_no, _step in enumerate(_probe_steps, start=1):
    print(f"  {_step_no}. {_step}")
print()

linear_probe_results = {}


def _probe_not_applicable(n_samples, n_classes, reason):
    """Build a probe-result record for a split where no probe could be fit.

    Numeric fields are NaN, so ``:.1%`` formatting still works; json.dump writes
    them as ``NaN``. ``is_time_invariant`` is None, meaning "not tested" rather
    than pass or fail.
    """
    nan = float('nan')
    return {
        'probe_acc': nan,
        'chance_level': nan,
        'uniform_chance': nan,
        'above_chance': nan,
        'is_time_invariant': None,
        'probe_applicable': False,
        'reason': reason,
        'n_classes': int(n_classes),
        'n_samples': int(n_samples),
    }


def _fit_time_probe(X, y, seed=42, margin=0.10):
    """Fit a fresh logistic-regression probe that predicts time period from frozen z_bond.

    Args:
        X: z_bond embeddings, one row per passage (2-D array-like).
        y: integer time-period labels aligned with the rows of X.
        seed: seed for the 50/50 probe-train / probe-test split and the probe.
        margin: how far above the baseline the probe may score and still count
            as time invariant.

    Returns:
        A dict with these keys:

        - ``probe_acc``: probe accuracy on the probe-test half.
        - ``chance_level``: majority-class rate on the probe-test half, i.e. the
          accuracy reachable without using z_bond.
        - ``uniform_chance``: 1 / number of classes.
        - ``above_chance``, ``is_time_invariant``, ``probe_applicable``,
          ``n_classes`` and ``n_samples``.

        If the probe-train half has fewer than two time classes, no probe can be
        fit and a "not applicable" record is returned instead.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    n = len(y)

    # Same permutation as np.random.seed(seed); np.random.permutation(n), but
    # without mutating NumPy's global RNG state.
    indices = np.random.RandomState(seed).permutation(n)
    train_idx, test_idx = indices[:n // 2], indices[n // 2:]
    y_train, y_test = y[train_idx], y[test_idx]

    print(f"\nProbe training set: {len(train_idx)}")
    print(f"Probe test set: {len(test_idx)}")

    if len(np.unique(y_train)) < 2 or len(y_test) == 0:
        # LogisticRegression cannot be fit on a single class, and a test set
        # drawn from one time period cannot reveal time leakage this way.
        print("\nProbe not applicable: fewer than two time periods in the probe data.")
        return _probe_not_applicable(len(y_test), len(np.unique(y)),
                                     "fewer than two time periods in probe data")

    # Fit the standardiser on the probe-train half only, so the probe-test half
    # stays unseen. A local name also avoids clobbering the global AMP `scaler`
    # used by the training cell.
    feature_scaler = StandardScaler()
    X_train = feature_scaler.fit_transform(X[train_idx])
    X_test = feature_scaler.transform(X[test_idx])

    print("\nFitting logistic regression probe...")
    # lbfgs fits a multinomial model for >2 classes by default; the explicit
    # multi_class argument is deprecated in recent scikit-learn releases.
    probe = LogisticRegression(
        max_iter=1000,
        solver='lbfgs',
        random_state=seed,
        n_jobs=-1
    )
    probe.fit(X_train, y_train)
    probe_acc = float((probe.predict(X_test) == y_test).mean())

    # Time periods are imbalanced, so the meaningful baseline is the
    # majority-class rate rather than 1 / n_classes.
    test_classes, test_counts = np.unique(y_test, return_counts=True)
    uniform_chance = 1.0 / len(test_classes)
    majority_baseline = float(test_counts.max() / test_counts.sum())

    return {
        'probe_acc': probe_acc,
        'chance_level': majority_baseline,
        'uniform_chance': float(uniform_chance),
        'above_chance': probe_acc - majority_baseline,
        'is_time_invariant': bool(probe_acc < majority_baseline + margin),
        'probe_applicable': True,
        'n_classes': int(len(test_classes)),
        'n_samples': int(len(y_test)),
    }


# The split definitions are shared by both directions; read them once.
with open("data/splits/all_splits.json", 'r') as f:
    splits = json.load(f)

for split_name in ['ancient_to_modern', 'modern_to_ancient']:
    print(f"\n{'='*50}")
    print(f"LINEAR PROBE: {split_name}")
    print(f"{'='*50}")

    # Load best model
    model_path = f"models/checkpoints/best_model_{split_name}.pt"
    model = BIPModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model = model.to(device)
    model.eval()

    # Freeze everything
    for param in model.parameters():
        param.requires_grad = False

    # Load test data
    split = splits[split_name]

    test_dataset = MoralDataset(
        set(split['test_ids']),
        "data/processed/passages.jsonl",
        "data/processed/bond_structures.jsonl",
        tokenizer
    )

    test_loader = DataLoader(
        test_dataset, batch_size=BASE_BATCH_SIZE, shuffle=False,
        collate_fn=collate_fn, num_workers=4, pin_memory=True
    )

    # Extract z_bond embeddings
    print("Extracting z_bond embeddings...")
    all_z_bond = []
    all_time_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Extracting", unit="batch"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            z_bond = model.extract_z_bond(input_ids, attention_mask)

            all_z_bond.append(z_bond.cpu().numpy())
            all_time_labels.extend(batch['time_labels'].tolist())

            if USE_TPU:
                xm.mark_step()

    if all_z_bond:
        X = np.vstack(all_z_bond)
        y = np.array(all_time_labels)
        print(f"Extracted {X.shape[0]} embeddings of dimension {X.shape[1]}")
        print(f"Time classes in test set: {len(np.unique(y))}")
        result = _fit_time_probe(X, y)
        del X, y
    else:
        # np.vstack([]) would raise; there is nothing to probe.
        print("WARNING: test set is empty; skipping probe.")
        result = _probe_not_applicable(0, 0, "empty test set")

    print()
    print(f"PROBE RESULTS:")
    if result['probe_applicable']:
        print(f"  Probe accuracy: {result['probe_acc']:.1%}")
        print(f"  Chance level:   {result['chance_level']:.1%} (majority class; "
              f"uniform over {result['n_classes']} classes = {result['uniform_chance']:.1%})")
        print(f"  Above chance:   {result['above_chance']:+.1%}")
        print()
        if result['is_time_invariant']:
            print(f"  ✓ TIME INVARIANT: Probe cannot decode time from z_bond")
        else:
            print(f"  ✗ TIME LEAKAGE: Probe can still decode time from z_bond")
    else:
        print(f"  Not applicable: {result['reason']}")

    linear_probe_results[split_name] = result

    # Cleanup
    del model, test_dataset, test_loader, all_z_bond, all_time_labels
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

import os

# Save probe results (create the results directory if this cell runs first).
os.makedirs('results', exist_ok=True)
with open('results/linear_probe_results.json', 'w') as f:
    json.dump(linear_probe_results, f, indent=2)

print()
print("=" * 60)
print("LINEAR PROBE SUMMARY")
print("=" * 60)
print()
for split_name, res in linear_probe_results.items():
    invariant = res.get('is_time_invariant')
    if invariant is None:
        # No pass/fail verdict was recorded for this split.
        status = "N/A (probe not run)"
    else:
        status = "✓ INVARIANT" if invariant else "✗ LEAKAGE"
    probe_acc = res.get('probe_acc', float('nan'))
    chance = res.get('chance_level', float('nan'))
    print(f"{split_name}:")
    print(f"  Probe: {probe_acc:.1%} vs Chance: {chance:.1%} → {status}")

print_resources("After linear probe")

mark_task("Linear probe test", "done")


In [None]:
#@title 11. Evaluate Final Results { display-mode: "form" }
#@markdown Comprehensive evaluation with linear probe as primary invariance test.

import gc
import json
import time

mark_task("Evaluate results", "running")
print_resources("Evaluation start")

_rule = "=" * 60
for _line in (_rule, "FINAL BIP EVALUATION (v7.1)", _rule, ""):
    print(_line)

# Load chance baselines for the three heads. Fall back to uniform chance
# (9 time periods, 4 Hohfeld states, 11 bond types) when the file is missing,
# unreadable, malformed or lacks a key. Other errors are no longer swallowed.
try:
    with open("data/splits/baselines.json", 'r') as f:
        baselines = json.load(f)
    chance_time = baselines['chance_time']
    chance_hohfeld = baselines['chance_hohfeld']
    chance_bond = baselines.get('chance_bond', 1/11)
except (OSError, json.JSONDecodeError, KeyError, TypeError, AttributeError) as err:
    print(f"Note: using uniform chance baselines ({type(err).__name__}: {err})")
    chance_time = 1/9
    chance_hohfeld = 1/4
    chance_bond = 1/11

# Get in-domain baseline if available
in_domain_baseline = all_results.get('mixed_control', {})

print("=" * 60)
print("CROSS-TEMPORAL TRANSFER RESULTS")
print("=" * 60)
print()

for split_name in ['ancient_to_modern', 'modern_to_ancient']:
    res = all_results.get(split_name, {})
    probe_res = linear_probe_results.get(split_name, {})

    direction = 'Ancient → Modern' if split_name == 'ancient_to_modern' else 'Modern → Ancient'
    print(f"DIRECTION: {direction}")
    print("-" * 40)

    if not res:
        # Presumably skipped during training (e.g. an empty split). Zero-filled
        # metrics would look like a real, failed result, so say so instead.
        print("  No model results recorded for this direction.")
        print()
        continue

    train_time = res.get('training_time_seconds', 0)
    print(f"  Train: {res.get('train_size', 0):,} | Test: {res.get('test_size', 0):,} | Time: {train_time/60:.1f}min")
    print()

    # Time invariance (LINEAR PROBE - primary test)
    print(f"  TIME INVARIANCE (Linear Probe - Primary Test):")
    is_invariant = probe_res.get('is_time_invariant')
    if is_invariant is None:
        # No verdict available: report that instead of a placeholder 100% accuracy.
        print("    Probe accuracy: n/a (no probe result for this direction)")
        print("    Status: N/A")
    else:
        probe_acc = probe_res.get('probe_acc', float('nan'))
        probe_chance = probe_res.get('chance_level', chance_time)
        print(f"    Probe accuracy: {probe_acc:.1%} (chance: {probe_chance:.1%})")
        print(f"    Status: {'✓ INVARIANT' if is_invariant else '✗ LEAKAGE'}")
    print()

    # Adversary head (secondary)
    adv_acc = res.get('time_acc', 0)
    print(f"  Time (Adversary Head - Secondary):")
    print(f"    Accuracy: {adv_acc:.1%} (chance: {chance_time:.1%})")
    print()

    # Bond transfer
    bond_f1 = res.get('bond_f1_macro', 0)
    bond_acc = res.get('bond_acc', 0)
    print(f"  BOND TRANSFER:")
    print(f"    Accuracy: {bond_acc:.1%} | F1 (macro): {bond_f1:.3f}")
    print(f"    Chance: {chance_bond:.1%}")

    # Compare to in-domain baseline
    if in_domain_baseline:
        baseline_bond_f1 = in_domain_baseline.get('bond_f1_macro', 0)
        if baseline_bond_f1 > 0:
            degradation = baseline_bond_f1 - bond_f1
            print(f"    In-domain baseline F1: {baseline_bond_f1:.3f}")
            print(f"    Degradation: {degradation:+.3f} ({degradation/baseline_bond_f1*100:+.1f}%)")
    print()

    # Per-corpus breakdown
    corpus_metrics = res.get('corpus_bond_f1')
    if corpus_metrics:
        print(f"  BY CORPUS:")
        for corpus, metrics in corpus_metrics.items():
            print(f"    {corpus}: F1={metrics['f1_macro']:.3f}, Acc={metrics['accuracy']:.1%} (n={metrics['n_samples']:,})")
    print()

    # Hohfeld
    hohfeld_f1 = res.get('hohfeld_f1_macro', 0)
    hohfeld_acc = res.get('hohfeld_acc', 0)
    print(f"  HOHFELD CLASSIFICATION:")
    print(f"    Accuracy: {hohfeld_acc:.1%} | F1 (macro): {hohfeld_f1:.3f}")
    print(f"    Chance: {chance_hohfeld:.1%}")
    print()

# Summary verdict: combine the probe test (time invariance) and bond transfer
# for each direction into one overall label plus a boxed message.
_rule = "=" * 60
print(_rule)
print("SUMMARY VERDICT")
print(_rule)
print()

_directions = ('ancient_to_modern', 'modern_to_ancient')
A_probe, B_probe = [linear_probe_results.get(_d, {}) for _d in _directions]
A_res, B_res = [all_results.get(_d, {}) for _d in _directions]

# "Bond transfer" means macro-F1 strictly above 1.5x bond chance.
_bond_bar = chance_bond * 1.5
A_invariant = A_probe.get('is_time_invariant', False)
B_invariant = B_probe.get('is_time_invariant', False)
A_bond_good = A_res.get('bond_f1_macro', 0) > _bond_bar
B_bond_good = B_res.get('bond_f1_macro', 0) > _bond_bar

for _label, _inv, _good in (
    ("Ancient → Modern", A_invariant, A_bond_good),
    ("Modern → Ancient", B_invariant, B_bond_good),
):
    print(f"{_label}:")
    print(f"  Time invariant (probe): {'✓' if _inv else '✗'}")
    print(f"  Bond transfer (F1 > {_bond_bar:.1%}): {'✓' if _good else '✗'}")
    print()

_VERDICT_BOXES = {
    "STRONGLY_SUPPORTED": """
    ╔══════════════════════════════════════════════════════════╗
    ║     BIP: STRONGLY SUPPORTED                              ║
    ╠══════════════════════════════════════════════════════════╣
    ║  Both directions show:                                   ║
    ║    ✓ Time-invariant bond representation (probe test)     ║
    ║    ✓ Bond transfer well above chance                     ║
    ║  Cross-domain performance shows minimal degradation.     ║
    ╚══════════════════════════════════════════════════════════╝
    """,
    "SUPPORTED_UNIDIRECTIONAL": """
    ╔══════════════════════════════════════════════════════════╗
    ║     BIP: SUPPORTED (One Direction)                       ║
    ╠══════════════════════════════════════════════════════════╣
    ║  At least one direction shows time-invariant bond        ║
    ║  representation with transfer above chance.              ║
    ╚══════════════════════════════════════════════════════════╝
    """,
    "PARTIAL_SUPPORT": """
    ╔══════════════════════════════════════════════════════════╗
    ║     BIP: PARTIAL SUPPORT                                 ║
    ╠══════════════════════════════════════════════════════════╣
    ║  Bond transfer works, but time information may still     ║
    ║  leak through z_bond (probe can decode it).              ║
    ╚══════════════════════════════════════════════════════════╝
    """,
    "INCONCLUSIVE": """
    ╔══════════════════════════════════════════════════════════╗
    ║     BIP: INCONCLUSIVE                                    ║
    ╠══════════════════════════════════════════════════════════╣
    ║  Neither direction shows clear invariance with transfer. ║
    ╚══════════════════════════════════════════════════════════╝
    """,
}

# A direction "supports" BIP when it is both time invariant and transfers bonds.
_a_supported = A_invariant and A_bond_good
_b_supported = B_invariant and B_bond_good
if _a_supported and _b_supported:
    verdict = "STRONGLY_SUPPORTED"
elif _a_supported or _b_supported:
    verdict = "SUPPORTED_UNIDIRECTIONAL"
elif A_bond_good or B_bond_good:
    verdict = "PARTIAL_SUPPORT"
else:
    verdict = "INCONCLUSIVE"
verdict_box = _VERDICT_BOXES[verdict]

print(verdict_box)

# Wall-clock time since the setup cell recorded EXPERIMENT_START.
total_time = time.time() - EXPERIMENT_START
_total_minutes = total_time / 60
print(f"\nTotal experiment time: {_total_minutes:.1f} minutes")

# Assemble every metric, the probe results and the verdict into one JSON
# document, then write it locally and to the backup directory.
_model_results = {}
for _split, _metrics in all_results.items():
    # 'test_preds' (if ever present) is too bulky for the summary file.
    _model_results[_split] = {key: value for key, value in _metrics.items() if key != 'test_preds'}

final_results = {
    'model_results': _model_results,
    'linear_probe_results': linear_probe_results,
    'verdict': verdict,
    'total_time_seconds': total_time,
    'accelerator': ACCELERATOR,
    'baselines': dict(
        chance_time=chance_time,
        chance_hohfeld=chance_hohfeld,
        chance_bond=chance_bond,
    ),
    'methodology_notes': dict(
        label_source='English translations (text_english)',
        languages='Hebrew (Sefaria) ↔ English (Dear Abby)',
        primary_invariance_test='Linear probe on frozen z_bond',
    ),
}


def _write_final_results(path):
    """Dump ``final_results`` to ``path``; non-JSON values are stringified."""
    with open(path, 'w') as fh:
        json.dump(final_results, fh, indent=2, default=str)


_write_final_results('results/final_results.json')
_write_final_results(f"{SAVE_DIR}/final_results.json")

print(f"\nResults saved to results/final_results.json")
print(f"Backed up to {SAVE_DIR}/final_results.json")

print_resources("Final")
mark_task("Evaluate results", "done")

_rule = "=" * 60
print("\n" + _rule + "\nEXPERIMENT COMPLETE\n" + _rule)
print_progress()


---

## Save Results

Run the cell below to download your trained model and results.

In [None]:
#@title 12. Download Results (Optional) { display-mode: "form" }
#@markdown Creates a zip file with model checkpoints, metrics, and probe results.

import shutil
import os
from google.colab import files

!mkdir -p results

# Copy model checkpoints
for split_name in ['ancient_to_modern', 'modern_to_ancient', 'mixed_control']:
    model_path = f"models/checkpoints/best_model_{split_name}.pt"
    if os.path.exists(model_path):
        !cp "{model_path}" results/
        print(f"Copied {model_path}")

!cp data/splits/all_splits.json results/ 2>/dev/null || true
!cp data/splits/baselines.json results/ 2>/dev/null || true
!cp results/final_results.json results/ 2>/dev/null || true
!cp results/linear_probe_results.json results/ 2>/dev/null || true

# Zip
shutil.make_archive('bip_results_v7', 'zip', 'results')
print()
print("Results saved to bip_results_v7.zip")
print()
print("Contents:")
!ls -la results/

files.download('bip_results_v7.zip')