# BIP Temporal Invariance Experiment

**Testing the Bond Invariance Principle across 2000+ years of moral reasoning**

This experiment tests whether moral cognition has invariant structure by:
1. Training on ancient Hebrew ethical texts (Sefaria corpus, ~1000 BCE - 1800 CE)
2. Testing transfer to modern American advice columns (Dear Abby, 1956-2020)

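**Hypothesis**: If BIP holds, bond-level features learned from the ancient texts should transfer across 2000+ years to the modern test set with no meaningful accuracy drop relative to the in-distribution (stratified random) control.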
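A "no accuracy drop" claim is only as strong as its uncertainty. One possible way to operationalize it is to take per-example correctness on the stratified random control's test split and on the temporal holdout test split, then check whether a bootstrap interval on the accuracy gap includes zero. This criterion is an assumption, not one the notebook states, and `accuracy_gap_ci` is a hypothetical helper:

```python
import random
from typing import List, Tuple


def accuracy_gap_ci(correct_in: List[int], correct_out: List[int],
                    n_boot: int = 2000, seed: int = 0) -> Tuple[float, Tuple[float, float]]:
    """Point estimate and 95% percentile-bootstrap interval for acc(in) - acc(out).

    correct_in / correct_out hold 1 (correct) or 0 (wrong) per test example.
    Hypothetical helper, not part of the notebook.
    """
    rng = random.Random(seed)

    def _mean(xs: List[int]) -> float:
        return sum(xs) / len(xs)

    gaps = []
    for _ in range(n_boot):
        # Resample each test set with replacement, independently of the other
        a = rng.choices(correct_in, k=len(correct_in))
        b = rng.choices(correct_out, k=len(correct_out))
        gaps.append(_mean(a) - _mean(b))
    gaps.sort()
    lower = gaps[int(0.025 * n_boot)]
    upper = gaps[int(0.975 * n_boot) - 1]
    return _mean(correct_in) - _mean(correct_out), (lower, upper)
```

An interval that comfortably contains zero is consistent with the hypothesis; an interval entirely above zero indicates a drop on the temporal holdout.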

---

## Setup Instructions
1. **Runtime -> Change runtime type -> GPU (T4) or TPU (v5e)**
2. Run cells in order - each shows progress in real-time
3. Expected runtime: ~1-2 hours (TPU) or ~2-4 hours (GPU)

**Supported Accelerators:**
- NVIDIA GPU (T4, V100, A100)
- Google TPU (v2, v3, v4, v5e)
- CPU (slow, not recommended)

---

In [None]:
#@title 1. Setup and Install Dependencies { display-mode: "form" }
#@markdown Installs packages and detects GPU/TPU.

print("=" * 60)
print("BIP TEMPORAL INVARIANCE EXPERIMENT")
print("=" * 60)
print()

# Progress tracker - persists across cells
TASKS = [
    "Install dependencies",
    "Clone Sefaria corpus (~8GB)",
    "Clone sqnd-probe repo (Dear Abby data)",
    "Preprocess corpora",
    "Extract bond structures",
    "Generate train/test splits",
    "Train BIP model",
    "Evaluate results"
]
task_status = {task: "pending" for task in TASKS}

def print_progress():
    """Print current progress checklist."""
    print()
    print("-" * 50)
    print("EXPERIMENT PROGRESS:")
    print("-" * 50)
    for task in TASKS:
        status = task_status[task]
        if status == "done":
            mark = "[X]"
        elif status == "running":
            mark = "[>]"
        else:
            mark = "[ ]"
        print(f"  {mark} {task}")
    print("-" * 50)
    print(flush=True)

def mark_task(task, status):
    """Update task status and print progress."""
    task_status[task] = status
    print_progress()

print_progress()

# Detect accelerator type
mark_task("Install dependencies", "running")

import os
USE_TPU = False
TPU_TYPE = None

# Check for TPU
if 'COLAB_TPU_ADDR' in os.environ or os.path.exists('/dev/accel0'):
    print("TPU detected! Installing torch_xla...")
    USE_TPU = True
    # Install PyTorch XLA for TPU support
    !pip install -q torch~=2.4.0 torch_xla[tpu]~=2.4.0 -f https://storage.googleapis.com/libtpu-releases/index.html
    !pip install -q transformers sentence-transformers scipy scikit-learn pandas numpy tqdm pyyaml
else:
    print("No TPU detected, using GPU/CPU...")
    !pip install -q torch transformers sentence-transformers scipy scikit-learn pandas numpy tqdm pyyaml

import torch
print(f"\nPyTorch version: {torch.__version__}")

# Setup device
if USE_TPU:
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.parallel_loader as pl
        
        device = xm.xla_device()
        TPU_TYPE = str(device)
        print(f"\nTPU initialized successfully!")
        print(f"Device: {device}")
        print(f"TPU cores available: {xm.xrt_world_size()}")
        ACCELERATOR = "TPU"
    except Exception as e:
        print(f"TPU initialization failed: {e}")
        print("Falling back to GPU/CPU...")
        USE_TPU = False
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        ACCELERATOR = "GPU" if torch.cuda.is_available() else "CPU"
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ACCELERATOR = "GPU" if torch.cuda.is_available() else "CPU"

if ACCELERATOR == "GPU":
    print(f"\nCUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
elif ACCELERATOR == "CPU":
    print("\nWARNING: No accelerator detected! Training will be slow.")
    print("Go to Runtime -> Change runtime type -> GPU or TPU")

print(f"\n>>> Using: {ACCELERATOR} <<<")

# Create directories
os.makedirs('data/raw', exist_ok=True)
os.makedirs('data/processed', exist_ok=True)
os.makedirs('data/splits', exist_ok=True)
os.makedirs('models/checkpoints', exist_ok=True)

mark_task("Install dependencies", "done")
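The fallback order in the cell above (TPU if detected and successfully initialized, otherwise a CUDA GPU, otherwise CPU) can be written as a pure function, which makes the logic testable on machines without any accelerator. `choose_accelerator` is a hypothetical helper, not part of the notebook:

```python
from typing import Mapping


def choose_accelerator(env: Mapping[str, str], accel_device_exists: bool,
                       tpu_init_ok: bool, cuda_available: bool) -> str:
    """Mirror the cell's fallback order: TPU if detected and initialized, else CUDA GPU, else CPU.

    Hypothetical helper; the inputs stand in for os.environ, os.path.exists('/dev/accel0'),
    whether the torch_xla try-block succeeded, and torch.cuda.is_available().
    """
    tpu_detected = 'COLAB_TPU_ADDR' in env or accel_device_exists
    if tpu_detected and tpu_init_ok:
        return "TPU"
    return "GPU" if cuda_available else "CPU"
```

For example, `choose_accelerator(os.environ, os.path.exists('/dev/accel0'), tpu_init_ok, torch.cuda.is_available())`, where `tpu_init_ok` would be a flag set inside the `try` block.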

In [None]:
#@title 2. Download Sefaria Corpus (~8GB) { display-mode: "form" }
#@markdown Downloads the complete Sefaria corpus with real-time git progress.

import subprocess
import sys

mark_task("Clone Sefaria corpus (~8GB)", "running")

sefaria_path = 'data/raw/Sefaria-Export'

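if not os.path.exists(f"{sefaria_path}/json"):
    # An interrupted clone can leave a non-empty directory without json/,
    # which makes `git clone` refuse to run; remove it before re-cloning.
    if os.path.exists(sefaria_path):
        import shutil
        shutil.rmtree(sefaria_path)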
    print("="*60)
    print("CLONING SEFARIA CORPUS")
    print("="*60)
    print()
    print("This downloads ~3.5GB and takes 5-15 minutes.")
    print("Git's native progress will display below:")
    print("-"*60)
    print(flush=True)
    
    # Use subprocess.Popen for real-time output streaming
    process = subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/Sefaria/Sefaria-Export.git', sefaria_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # Git writes progress to stderr
        text=True,
        bufsize=1
    )
    
    # Stream output in real-time
    for line in process.stdout:
        print(line, end='', flush=True)
    
    process.wait()
    
    print("-"*60)
    if process.returncode == 0:
        print("\nSefaria clone COMPLETE!")
    else:
        print(f"\nERROR: Git clone failed with code {process.returncode}")
        print("Try running this cell again, or check your internet connection.")
else:
    print("Sefaria already exists, skipping download.")

# Verify and count files
print()
print("Verifying download...")
!du -sh {sefaria_path} 2>/dev/null || echo "Directory not found"
json_count = !find {sefaria_path}/json -name "*.json" 2>/dev/null | wc -l
print(f"Sefaria JSON files found: {json_count[0]}")

mark_task("Clone Sefaria corpus (~8GB)", "done")
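Cells 2 and 3 repeat the same `subprocess.Popen` streaming pattern, and cell 3 does not check the exit code. A small reusable helper could look like the sketch below; `run_streaming` is hypothetical, but it uses only documented `subprocess` arguments:

```python
import subprocess
import sys
from typing import List


def run_streaming(cmd: List[str]) -> int:
    """Run cmd, echo its combined stdout/stderr line by line, and return the exit code."""
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr so git's progress output is captured too
        text=True,
        bufsize=1,                 # line-buffered in text mode
    )
    for line in process.stdout:
        print(line, end='', flush=True)
    process.stdout.close()
    return process.wait()


if __name__ == "__main__":
    # Demo with a command available wherever Python is
    code = run_streaming([sys.executable, "-c", "print('hello from a child process')"])
    print("exit code:", code)
```

In the notebook this would become `ok = run_streaming(['git', 'clone', '--depth', '1', '--progress', url, path]) == 0`, and a failed clone could then stop the cell instead of silently marking the task done.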


In [None]:
#@title 3. Download Dear Abby Dataset { display-mode: "form" }
#@markdown Downloads the Dear Abby advice column dataset (68,330 entries).

import pandas as pd
from pathlib import Path

mark_task("Clone sqnd-probe repo (Dear Abby data)", "running")

sqnd_path = 'sqnd-probe-data'
if not os.path.exists(sqnd_path):
    print("Cloning sqnd-probe repo...")
    process = subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/ahb-sjsu/sqnd-probe.git', sqnd_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    )
    for line in process.stdout:
        print(line, end='', flush=True)
    process.wait()
else:
    print("Repo already cloned.")

# Copy Dear Abby data
dear_abby_source = Path('sqnd-probe-data/dear_abby_data/raw_da_qs.csv')
dear_abby_path = Path('data/raw/dear_abby.csv')

if dear_abby_source.exists():
    !cp "{dear_abby_source}" "{dear_abby_path}"
    print(f"\nCopied Dear Abby data")
elif not dear_abby_path.exists():
    raise FileNotFoundError("Dear Abby dataset not found!")

# Verify
df_check = pd.read_csv(dear_abby_path)
print(f"\n" + "=" * 50)
print(f"Dear Abby dataset: {len(df_check):,} entries")
print(f"Columns: {list(df_check.columns)}")
print(f"Year range: {df_check['year'].min():.0f} - {df_check['year'].max():.0f}")
print("=" * 50)

mark_task("Clone sqnd-probe repo (Dear Abby data)", "done")

In [None]:
#@title 4. Define Data Classes and Loaders { display-mode: "form" }
#@markdown Defines enums, dataclasses, and corpus loaders.

import json
import hashlib
import re
from dataclasses import dataclass, field, asdict
from typing import List, Dict
from enum import Enum
from collections import defaultdict
from tqdm.auto import tqdm

print("Defining data structures...")

class TimePeriod(Enum):
    BIBLICAL = 0        # ~1000-500 BCE
    SECOND_TEMPLE = 1   # ~500 BCE - 70 CE
    TANNAITIC = 2       # ~70-200 CE
    AMORAIC = 3         # ~200-500 CE
    GEONIC = 4          # ~600-1000 CE
    RISHONIM = 5        # ~1000-1500 CE
    ACHRONIM = 6        # ~1500-1800 CE
    MODERN_HEBREW = 7   # ~1800-present
    DEAR_ABBY = 8       # 1956-2020

class BondType(Enum):
    HARM_PREVENTION = 0
    RECIPROCITY = 1
    AUTONOMY = 2
    PROPERTY = 3
    FAMILY = 4
    AUTHORITY = 5
    EMERGENCY = 6
    CONTRACT = 7
    CARE = 8
    FAIRNESS = 9

class HohfeldianState(Enum):
    RIGHT = 0
    OBLIGATION = 1
    LIBERTY = 2
    NO_RIGHT = 3

@dataclass
class Passage:
    id: str
    text_original: str
    text_english: str
    time_period: str
    century: int
    source: str
    source_type: str
    category: str
    language: str = "hebrew"
    word_count: int = 0
    has_dispute: bool = False
    consensus_tier: str = "unknown"
    bond_types: List[str] = field(default_factory=list)
    
    def to_dict(self):
        return asdict(self)

CATEGORY_TO_PERIOD = {
    'Tanakh': TimePeriod.BIBLICAL,
    'Torah': TimePeriod.BIBLICAL,
    'Mishnah': TimePeriod.TANNAITIC,
    'Tosefta': TimePeriod.TANNAITIC,
    'Talmud': TimePeriod.AMORAIC,
    'Bavli': TimePeriod.AMORAIC,
    'Midrash': TimePeriod.AMORAIC,
    'Halakhah': TimePeriod.RISHONIM,
    'Chasidut': TimePeriod.ACHRONIM,
}

PERIOD_TO_CENTURY = {
    TimePeriod.BIBLICAL: -6,
    TimePeriod.SECOND_TEMPLE: -2,
    TimePeriod.TANNAITIC: 2,
    TimePeriod.AMORAIC: 4,
    TimePeriod.GEONIC: 8,
    TimePeriod.RISHONIM: 12,
    TimePeriod.ACHRONIM: 17,
    TimePeriod.MODERN_HEBREW: 20,
}

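def load_sefaria(base_path: str, max_passages: int = None) -> List[Passage]:
    """Load Sefaria corpus with progress bar.

    Note: despite its name, max_passages caps the number of JSON files scanned,
    not the number of passages returned.
    """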
    passages = []
    json_path = Path(base_path) / "json"
    
    if not json_path.exists():
        print(f"Warning: {json_path} not found")
        return []
    
    json_files = list(json_path.rglob("*.json"))
    print(f"Processing {len(json_files):,} JSON files...")
    
    for json_file in tqdm(json_files[:max_passages] if max_passages else json_files,
                          desc="Loading Sefaria", unit="file"):
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):  # unreadable file, bad encoding, or invalid JSON
            continue
        
        rel_path = json_file.relative_to(json_path)
        category = str(rel_path.parts[0]) if rel_path.parts else "unknown"
        time_period = CATEGORY_TO_PERIOD.get(category, TimePeriod.AMORAIC)
        century = PERIOD_TO_CENTURY.get(time_period, 0)
        
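        if isinstance(data, dict):
            # Sefaria-Export files usually hold one language version each ('language': 'en' or 'he'),
            # with the content under 'text'. The encoder is English-only, so skip Hebrew versions
            # rather than feeding Hebrew text into text_english.
            if data.get('language', 'en') != 'en':
                continue
            hebrew = data.get('he', data.get('text', []))
            english = data.get('text', data.get('en', []))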
            
            def flatten(h, e, ref=""):
                if isinstance(h, str) and isinstance(e, str):
                    h_clean = re.sub(r'<[^>]+>', '', h).strip()
                    e_clean = re.sub(r'<[^>]+>', '', e).strip()
                    if 50 <= len(e_clean) <= 2000:
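                        # Hash the relative path, not just the file stem, so files with the same
                        # name in different folders cannot produce colliding IDs
                        pid = hashlib.md5(f"{rel_path}:{ref}:{h_clean[:50]}".encode()).hexdigest()[:12]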
                        return [Passage(
                            id=f"sefaria_{pid}",
                            text_original=h_clean,
                            text_english=e_clean,
                            time_period=time_period.name,
                            century=century,
                            source=f"{json_file.stem} {ref}".strip(),
                            source_type="sefaria",
                            category=category,
                            language="hebrew",
                            word_count=len(e_clean.split())
                        )]
                    return []
                elif isinstance(h, list) and isinstance(e, list):
                    result = []
                    for i, (hh, ee) in enumerate(zip(h, e)):
                        result.extend(flatten(hh, ee, f"{ref}.{i+1}" if ref else str(i+1)))
                    return result
                return []
            
            passages.extend(flatten(hebrew, english))
    
    return passages

def load_dear_abby(path: str, max_passages: int = None) -> List[Passage]:
    """Load Dear Abby corpus with progress bar."""
    passages = []
    df = pd.read_csv(path)
    
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading Dear Abby", unit="row"):
        question = str(row.get('question_only', ''))
        if not question or question == 'nan' or len(question) < 50 or len(question) > 2000:
            continue
        
        year_raw = row.get('year')
        year = int(year_raw) if pd.notna(year_raw) else 1990  # int(NaN) would raise; default to 1990
        pid = hashlib.md5(f"abby:{idx}:{question[:50]}".encode()).hexdigest()[:12]
        
        passages.append(Passage(
            id=f"abby_{pid}",
            text_original=question,
            text_english=question,
            time_period=TimePeriod.DEAR_ABBY.name,
            century=20 if year < 2000 else 21,
            source=f"Dear Abby {year}",
            source_type="dear_abby",
            category="general",
            language="english",
            word_count=len(question.split())
        ))
        
        if max_passages and len(passages) >= max_passages:
            break
    
    return passages

print("Data structures defined!")
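The nested `flatten` inside `load_sefaria` walks parallel Hebrew/English arrays, which makes it hard to test in isolation. The sketch below is a hypothetical single-language version (`flatten_text` is not part of the notebook) that uses the same tag stripping, 1-based dotted refs, and 50 to 2000 character filter, so the recursion can be checked on toy input:

```python
import re
from typing import List, Tuple, Union

TAG_RE = re.compile(r'<[^>]+>')  # same tag-stripping regex as load_sefaria


def flatten_text(node: Union[str, list, dict], ref: str = "",
                 min_len: int = 50, max_len: int = 2000) -> List[Tuple[str, str]]:
    """Flatten a nested Sefaria-style text array into (ref, cleaned_text) pairs.

    Refs are 1-based and dot-separated, e.g. "3.2" is the second item of the third
    top-level item. Strings whose cleaned length is outside [min_len, max_len] are dropped.
    """
    if isinstance(node, str):
        clean = TAG_RE.sub('', node).strip()
        return [(ref, clean)] if min_len <= len(clean) <= max_len else []
    if isinstance(node, list):
        pairs = []
        for i, child in enumerate(node, start=1):
            child_ref = f"{ref}.{i}" if ref else str(i)
            pairs.extend(flatten_text(child, child_ref, min_len, max_len))
        return pairs
    return []  # dicts (complex schema nodes) and other types are skipped, as in load_sefaria
```

As in the notebook, texts stored as dicts (complex schema nodes) yield no passages, which is worth keeping in mind when reading the per-category counts.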

In [None]:
#@title 5. Load and Preprocess Corpora { display-mode: "form" }
#@markdown Loads both corpora and shows statistics.

mark_task("Preprocess corpora", "running")

print("=" * 60)
print("LOADING CORPORA")
print("=" * 60)
print()

# Load Sefaria
sefaria_passages = load_sefaria("data/raw/Sefaria-Export")
print(f"\nSefaria passages loaded: {len(sefaria_passages):,}")

# Load Dear Abby
print()
abby_passages = load_dear_abby("data/raw/dear_abby.csv")
print(f"\nDear Abby passages loaded: {len(abby_passages):,}")

# Combine
all_passages = sefaria_passages + abby_passages

print()
print("=" * 60)
print(f"TOTAL PASSAGES: {len(all_passages):,}")
print("=" * 60)

# Statistics
by_period = defaultdict(int)
by_source = defaultdict(int)
for p in all_passages:
    by_period[p.time_period] += 1
    by_source[p.source_type] += 1

print("\nBy source:")
for source, count in sorted(by_source.items()):
    print(f"  {source}: {count:,}")

print("\nBy time period (chronological):")
for period, count in sorted(by_period.items(), key=lambda kv: TimePeriod[kv[0]].value):
    pct = count / len(all_passages) * 100
    bar = '#' * int(pct / 2)
    print(f"  {period:20s}: {count:6,} ({pct:5.1f}%) {bar}")

mark_task("Preprocess corpora", "done")
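`load_sefaria` assigns any top-level category missing from `CATEGORY_TO_PERIOD` to AMORAIC through the `dict.get` default, and no category maps to SECOND_TEMPLE, GEONIC, or MODERN_HEBREW. Before reading the period histogram above, it may be worth counting how many passages hit that fallback. `unmapped_categories` is a hypothetical helper:

```python
from collections import Counter
from typing import Dict, Iterable


def unmapped_categories(categories: Iterable[str], mapping: Dict[str, object]) -> Counter:
    """Count passages whose category is missing from mapping (and so receives the default period)."""
    return Counter(c for c in categories if c not in mapping)
```

Usage in the notebook: `unmapped_categories((p.category for p in sefaria_passages), CATEGORY_TO_PERIOD).most_common(10)`.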

In [None]:
#@title 6. Extract Bond Structures { display-mode: "form" }
#@markdown Extracts moral bond structures from each passage.

mark_task("Extract bond structures", "running")

RELATION_PATTERNS = {
    BondType.HARM_PREVENTION: [r'\b(kill|murder|harm|hurt|save|rescue|protect|danger)\b'],
    BondType.RECIPROCITY: [r'\b(return|repay|owe|debt|mutual|exchange)\b'],
    BondType.AUTONOMY: [r'\b(choose|decision|consent|agree|force|coerce|right)\b'],
    BondType.PROPERTY: [r'\b(property|own|steal|theft|buy|sell|land)\b'],
    BondType.FAMILY: [r'\b(honor|parent|marry|divorce|inherit|family)\b'],
    BondType.AUTHORITY: [r'\b(obey|command|law|judge|rule|teach)\b'],
    BondType.CARE: [r'\b(care|help|assist|feed|clothe|visit)\b'],
    BondType.FAIRNESS: [r'\b(fair|just|equal|deserve|bias)\b'],
}

HOHFELD_PATTERNS = {
    HohfeldianState.OBLIGATION: [r'\b(must|shall|duty|require|should)\b'],
    HohfeldianState.RIGHT: [r'\b(right to|entitled|deserve)\b'],
    HohfeldianState.LIBERTY: [r'\b(may|can|permitted|allowed)\b'],
}

def extract_bond_structure(passage: Passage) -> Dict:
    """Extract bond structure from passage."""
    text = passage.text_english.lower()
    
    relations = []
    for rel_type, patterns in RELATION_PATTERNS.items():
        for pattern in patterns:
            if re.search(pattern, text, re.IGNORECASE):
                relations.append(rel_type.name)
                break
    
    if not relations:
        # No keyword matched: fall back to CARE (this inflates CARE in the distribution below)
        relations = ['CARE']
    
    hohfeld = None
    for state, patterns in HOHFELD_PATTERNS.items():
        for pattern in patterns:
            if re.search(pattern, text, re.IGNORECASE):
                hohfeld = state.name
                break
        if hohfeld:
            break
    
    signature = "|".join(sorted(set(relations)))
    
    return {
        'bonds': [{'relation': r} for r in relations],
        'primary_relation': relations[0],
        'hohfeld_state': hohfeld,
        'signature': signature
    }

print("=" * 60)
print("EXTRACTING BOND STRUCTURES")
print("=" * 60)
print()

bond_structures = []
for passage in tqdm(all_passages, desc="Extracting bonds", unit="passage"):
    bond_struct = extract_bond_structure(passage)
    passage.bond_types = [b['relation'] for b in bond_struct['bonds']]
    bond_structures.append({
        'passage_id': passage.id,
        'bond_structure': bond_struct
    })

# Save
print("\nSaving processed data...")
with open("data/processed/passages.jsonl", 'w') as f:
    for p in all_passages:
        f.write(json.dumps(p.to_dict()) + '\n')

with open("data/processed/bond_structures.jsonl", 'w') as f:
    for bs in bond_structures:
        f.write(json.dumps(bs) + '\n')

# Bond type distribution
bond_counts = defaultdict(int)
for bs in bond_structures:
    for bond in bs['bond_structure']['bonds']:
        bond_counts[bond['relation']] += 1

print("\nBond type distribution (share of all detected bond labels; a passage can have several):")
for bond_type, count in sorted(bond_counts.items(), key=lambda x: -x[1]):
    pct = count / sum(bond_counts.values()) * 100
    bar = '#' * int(pct)
    print(f"  {bond_type:20s}: {count:6,} ({pct:5.1f}%) {bar}")

mark_task("Extract bond structures", "done")
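The relation patterns are literal whole-word matches, so they miss inflected forms and fire on unrelated senses of a word. A quick check on a subset of the patterns (copied here so the sketch runs standalone; `match_relations` is hypothetical) shows both effects:

```python
import re
from typing import Dict, List

# A subset of RELATION_PATTERNS from the cell above, keyed by BondType name
PATTERNS: Dict[str, str] = {
    "PROPERTY": r'\b(property|own|steal|theft|buy|sell|land)\b',
    "AUTONOMY": r'\b(choose|decision|consent|agree|force|coerce|right)\b',
    "HARM_PREVENTION": r'\b(kill|murder|harm|hurt|save|rescue|protect|danger)\b',
}


def match_relations(text: str) -> List[str]:
    """Return the sorted names of all patterns that match anywhere in text."""
    return sorted(name for name, pattern in PATTERNS.items()
                  if re.search(pattern, text, re.IGNORECASE))


for sentence in ["He owns the land next door.",   # PROPERTY, via 'land' rather than 'owns'
                 "She owns it outright.",         # no match: 'owns' is not listed, and \b blocks 'outright'
                 "Turn right at the corner."]:    # AUTONOMY: spatial 'right' is a false positive
    print(f"{sentence!r:32} -> {match_relations(sentence)}")
```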

In [None]:
#@title 7. Generate Train/Test Splits { display-mode: "form" }
#@markdown Creates temporal holdout split: train on ancient, test on modern.

import random
random.seed(42)

mark_task("Generate train/test splits", "running")

print("=" * 60)
print("GENERATING SPLITS")
print("=" * 60)
print()

# TEMPORAL HOLDOUT SPLIT (Primary BIP Test)
# Train: Ancient/Medieval Hebrew texts
# Test: Modern (Dear Abby)
train_periods = {'BIBLICAL', 'SECOND_TEMPLE', 'TANNAITIC', 'AMORAIC', 'GEONIC', 'RISHONIM'}
valid_periods = {'ACHRONIM'}
test_periods = {'MODERN_HEBREW', 'DEAR_ABBY'}

train = [p for p in all_passages if p.time_period in train_periods]
valid = [p for p in all_passages if p.time_period in valid_periods]
test = [p for p in all_passages if p.time_period in test_periods]

random.shuffle(train)
random.shuffle(valid)
random.shuffle(test)

temporal_holdout = {
    'name': 'temporal_holdout',
    'train_ids': [p.id for p in train],
    'valid_ids': [p.id for p in valid],
    'test_ids': [p.id for p in test],
    'train_size': len(train),
    'valid_size': len(valid),
    'test_size': len(test)
}

print("TEMPORAL HOLDOUT SPLIT (Primary BIP Test):")
print(f"  Train (ancient/medieval, ~1000 BCE - 1500 CE): {len(train):,}")
print(f"  Valid (early modern, ~1500 - 1800 CE):         {len(valid):,}")
print(f"  Test (modern, 1956 - 2020):                    {len(test):,}")
print()
print(f"  Temporal gap: ~500 years between train and test")

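# Also create a stratified random split for comparison: each time period is split
# 70/15/15 independently, so every split keeps the corpus-wide period mix
by_period_passages = defaultdict(list)
for p in all_passages:
    by_period_passages[p.time_period].append(p)

strat_train, strat_valid, strat_test = [], [], []
for period in sorted(by_period_passages):
    group = by_period_passages[period]
    random.shuffle(group)
    n_p = len(group)
    n_tr, n_va = int(0.7 * n_p), int(0.15 * n_p)
    strat_train.extend(group[:n_tr])
    strat_valid.extend(group[n_tr:n_tr + n_va])
    strat_test.extend(group[n_tr + n_va:])

stratified = {
    'name': 'stratified_random',
    'train_ids': [p.id for p in strat_train],
    'valid_ids': [p.id for p in strat_valid],
    'test_ids': [p.id for p in strat_test],
    'train_size': len(strat_train),
    'valid_size': len(strat_valid),
    'test_size': len(strat_test)
}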

print()
print("STRATIFIED RANDOM SPLIT (Control):")
print(f"  Train: {stratified['train_size']:,}")
print(f"  Valid: {stratified['valid_size']:,}")
print(f"  Test:  {stratified['test_size']:,}")

# Save
splits = {
    'temporal_holdout': temporal_holdout,
    'stratified_random': stratified
}

with open("data/splits/all_splits.json", 'w') as f:
    json.dump(splits, f, indent=2)

print("\nSplits saved!")

mark_task("Generate train/test splits", "done")
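Passage IDs are 12-character MD5 prefixes of source and text, so repeated passages can share an ID, and `MoralDataset` keys passages by ID, which collapses duplicates silently. A quick audit of a saved split can catch leakage between splits and duplicates within one. `audit_split` is a hypothetical helper:

```python
from typing import Dict, List, Tuple


def audit_split(split: Dict[str, List[str]]) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Return (pairwise ID overlaps between splits, duplicate-ID count within each split)."""
    keys = ('train_ids', 'valid_ids', 'test_ids')
    sets = {k: set(split[k]) for k in keys}
    overlaps = {
        'train/valid': len(sets['train_ids'] & sets['valid_ids']),
        'train/test': len(sets['train_ids'] & sets['test_ids']),
        'valid/test': len(sets['valid_ids'] & sets['test_ids']),
    }
    duplicates = {k: len(split[k]) - len(sets[k]) for k in keys}
    return overlaps, duplicates
```

For example, `audit_split(temporal_holdout)` and `audit_split(stratified)`; all overlap counts should be zero.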

In [None]:
#@title 8. Define BIP Model Architecture { display-mode: "form" }
#@markdown Defines the adversarial disentanglement model (GPU/TPU compatible).

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader

print("=" * 60)
print("DEFINING MODEL ARCHITECTURE")
print("=" * 60)
print()

class GradientReversal(torch.autograd.Function):
    """Gradient reversal layer for adversarial training."""
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()
    
    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambda_ * grad_output, None

def gradient_reversal(x, lambda_=1.0):
    return GradientReversal.apply(x, lambda_)

class BIPEncoder(nn.Module):
    """Sentence encoder using pretrained transformer."""
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2", d_model=384):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.d_model = d_model
    
    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return pooled

class BIPModel(nn.Module):
    """
    Bond Invariance Principle Model.
    
    Disentangles representations into:
    - z_bond: Time-invariant moral structure (should NOT predict time period)
    - z_label: Temporal/cultural context (CAN predict time period)
    """
    def __init__(self, d_model=384, d_bond=64, d_label=32, n_periods=9, n_hohfeld=4):
        super().__init__()
        
        self.encoder = BIPEncoder()
        
        # Disentanglement heads
        self.bond_proj = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, d_bond)
        )
        
        self.label_proj = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, d_label)
        )
        
        # Classifiers
        self.time_from_bond = nn.Linear(d_bond, n_periods)   # Should fail (adversarial)
        self.time_from_label = nn.Linear(d_label, n_periods) # Should succeed
        self.hohfeld_classifier = nn.Linear(d_bond, n_hohfeld)
    
    def forward(self, input_ids, attention_mask, adversarial_lambda=1.0):
        h = self.encoder(input_ids, attention_mask)
        
        z_bond = self.bond_proj(h)
        z_label = self.label_proj(h)
        
        # Time prediction with gradient reversal on z_bond
        if adversarial_lambda > 0:
            z_bond_adv = gradient_reversal(z_bond, adversarial_lambda)
            time_pred_bond = self.time_from_bond(z_bond_adv)
        else:
            time_pred_bond = self.time_from_bond(z_bond)
        
        time_pred_label = self.time_from_label(z_label)
        hohfeld_pred = self.hohfeld_classifier(z_bond)
        
        return {
            'z_bond': z_bond,
            'z_label': z_label,
            'time_pred_bond': time_pred_bond,
            'time_pred_label': time_pred_label,
            'hohfeld_pred': hohfeld_pred
        }

# Dataset class
PERIOD_TO_IDX = {p.name: i for i, p in enumerate(TimePeriod)}
HOHFELD_TO_IDX = {h.name: i for i, h in enumerate(HohfeldianState)}

class MoralDataset(Dataset):
    def __init__(self, passage_ids, passages_file, bonds_file, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        self.passages = {}
        with open(passages_file) as f:
            for line in f:
                p = json.loads(line)
                if p['id'] in passage_ids:
                    self.passages[p['id']] = p
        
        self.bonds = {}
        with open(bonds_file) as f:
            for line in f:
                b = json.loads(line)
                self.bonds[b['passage_id']] = b['bond_structure']
        
        self.ids = [pid for pid in passage_ids if pid in self.passages]
    
    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, idx):
        pid = self.ids[idx]
        passage = self.passages[pid]
        
        encoded = self.tokenizer(
            passage['text_english'],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        time_label = PERIOD_TO_IDX.get(passage['time_period'], 0)
        hohfeld = self.bonds.get(pid, {}).get('hohfeld_state')
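        # Passages with no detected Hohfeldian cue fall back to index 0 (RIGHT),
        # so the RIGHT class also absorbs "no cue" passages
        hohfeld_label = HOHFELD_TO_IDX.get(hohfeld, 0) if hohfeld else 0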
        
        return {
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            'time_label': time_label,
            'hohfeld_label': hohfeld_label
        }

def collate_fn(batch):
    return {
        'input_ids': torch.stack([b['input_ids'] for b in batch]),
        'attention_mask': torch.stack([b['attention_mask'] for b in batch]),
        'time_labels': torch.tensor([b['time_label'] for b in batch]),
        'hohfeld_labels': torch.tensor([b['hohfeld_label'] for b in batch])
    }

print("Model architecture defined!")
print()
print("Key components:")
print("  - BIPEncoder: Sentence transformer (MiniLM-L6)")
print("  - bond_proj: Projects to time-invariant z_bond")
print("  - label_proj: Projects to temporal z_label")
print("  - Gradient reversal: Pushes z_bond toward being time-agnostic")
print(f"\nTarget device: {ACCELERATOR}")
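A one-dimensional toy makes the effect of `GradientReversal` concrete: the adversary's weight descends its classification loss, while the feature upstream of the reversal receives the gradient multiplied by `-lambda` and therefore ascends the same loss. This is an illustration with a single logistic unit, not the notebook's model; `grl_step` and `bce` are hypothetical:

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def bce(z: float, w: float, y: int) -> float:
    """Binary cross-entropy of a 1-D logistic adversary with weight w on feature z."""
    p = sigmoid(w * z)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def grl_step(z: float, w: float, y: int, lam: float = 1.0, lr: float = 0.1):
    """One SGD step with a gradient reversal layer between feature z and adversary w.

    The adversary descends the loss; the feature gets the gradient scaled by -lam,
    mirroring GradientReversal.backward returning -lambda * grad_output.
    """
    dlogit = sigmoid(w * z) - y   # dL/d(logit) for BCE with a sigmoid
    grad_w = dlogit * z           # dL/dw
    grad_z = dlogit * w           # dL/dz, before reversal
    w_new = w - lr * grad_w
    z_new = z - lr * (-lam * grad_z)
    return z_new, w_new


z0, w0, y = 1.0, 1.0, 1
z1, w1 = grl_step(z0, w0, y)
print(f"adversary update only: loss {bce(z0, w0, y):.4f} -> {bce(z0, w1, y):.4f} (down)")
print(f"feature update only:   loss {bce(z0, w0, y):.4f} -> {bce(z1, w0, y):.4f} (up)")
```

This is why the adversary should be trained with a supervised loss on the true period labels: the reversal then turns "help the adversary" into "remove period information" for the encoder.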

In [None]:
#@title 9. Train BIP Model (GPU/TPU) { display-mode: "form" }
#@markdown Trains the model with adversarial disentanglement. Auto-detects GPU or TPU.

mark_task("Train BIP model", "running")

print("=" * 60)
print("TRAINING BIP MODEL")
print("=" * 60)
print()
print(f"Accelerator: {ACCELERATOR}")
print(f"Device: {device}")
print()

# Use temporal holdout split
split = temporal_holdout
print(f"Using TEMPORAL HOLDOUT split")
print(f"  Train: {split['train_size']:,} (ancient/medieval)")
print(f"  Valid: {split['valid_size']:,} (early modern)")
print(f"  Test:  {split['test_size']:,} (modern)")
print()

# Initialize
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = BIPModel().to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print()

# Create datasets
print("Creating datasets...")
train_dataset = MoralDataset(
    set(split['train_ids']),
    "data/processed/passages.jsonl",
    "data/processed/bond_structures.jsonl",
    tokenizer
)
valid_dataset = MoralDataset(
    set(split['valid_ids']),
    "data/processed/passages.jsonl",
    "data/processed/bond_structures.jsonl",
    tokenizer
)
test_dataset = MoralDataset(
    set(split['test_ids']),
    "data/processed/passages.jsonl",
    "data/processed/bond_structures.jsonl",
    tokenizer
)

print(f"Train samples: {len(train_dataset):,}")
print(f"Valid samples: {len(valid_dataset):,}")
print(f"Test samples:  {len(test_dataset):,}")
print()

if len(train_dataset) == 0:
    print("ERROR: No training data! Check if Sefaria loaded correctly.")
else:
    # Adjust batch size for TPU (larger batches work better)
    batch_size = 64 if USE_TPU else 32
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                              collate_fn=collate_fn, drop_last=True, num_workers=2)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size*2, shuffle=False, 
                              collate_fn=collate_fn, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size*2, shuffle=False, 
                             collate_fn=collate_fn, num_workers=2)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    
    n_epochs = 10
    best_valid_loss = float('inf')
    
    print(f"Training for {n_epochs} epochs (batch_size={batch_size})...")
    print("=" * 60)
    print()
    
    for epoch in range(1, n_epochs + 1):
        # Training
        model.train()
        total_loss = 0
        n_batches = 0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{n_epochs}", unit="batch")
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            time_labels = batch['time_labels'].to(device)
            hohfeld_labels = batch['hohfeld_labels'].to(device)
            
            outputs = model(input_ids, attention_mask, adversarial_lambda=1.0)
            
            # Losses
            # 1. Adversarial (DANN-style): time_from_bond learns to predict the period from z_bond,
            #    while the gradient reversal layer flips that gradient for bond_proj and the encoder,
            #    pushing z_bond to discard period information.
            loss_adv = F.cross_entropy(outputs['time_pred_bond'], time_labels)
            
            # 2. Time prediction from z_label should work
            loss_time = F.cross_entropy(outputs['time_pred_label'], time_labels)
            
            # 3. Hohfeldian classification from z_bond
            loss_hohfeld = F.cross_entropy(outputs['hohfeld_pred'], hohfeld_labels)
            
            loss = loss_adv + loss_time + loss_hohfeld
            
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # TPU vs GPU optimizer step
            if USE_TPU:
                xm.optimizer_step(optimizer)
                xm.mark_step()  # Sync TPU
            else:
                optimizer.step()
            
            total_loss += loss.item()
            n_batches += 1
            pbar.set_postfix({'loss': f"{loss.item():.4f}", 'adv': f"{loss_adv.item():.3f}"})
        
        avg_train_loss = total_loss / max(n_batches, 1)  # guard against an empty loader (drop_last=True)
        
        # Validation
        model.eval()
        valid_loss = 0
        valid_batches = 0
        time_correct = 0
        time_total = 0
        
        with torch.no_grad():
            for batch in valid_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                time_labels = batch['time_labels'].to(device)
                hohfeld_labels = batch['hohfeld_labels'].to(device)
                
                outputs = model(input_ids, attention_mask, adversarial_lambda=0)
                loss = F.cross_entropy(outputs['hohfeld_pred'], hohfeld_labels)
                valid_loss += loss.item()
                valid_batches += 1
                
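                # Accuracy of the adversary head on z_bond (uniform chance over 9 periods is ~11%).
                # Validation periods never appear as training labels, so the head has not been
                # trained on them; on its own this number is a weak indicator of invariance.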
                time_preds = outputs['time_pred_bond'].argmax(dim=-1)
                time_correct += (time_preds == time_labels).sum().item()
                time_total += len(time_labels)
                
                if USE_TPU:
                    xm.mark_step()
        
        avg_valid_loss = valid_loss / valid_batches if valid_batches > 0 else 0
        time_acc = time_correct / time_total if time_total > 0 else 0
        
        # Print epoch summary
        print(f"\nEpoch {epoch} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Valid Loss: {avg_valid_loss:.4f}")
        print(f"  Time Acc from z_bond: {time_acc:.1%} (uniform chance ~11%; validation periods unseen in training)")
        
        if avg_valid_loss < best_valid_loss:
            best_valid_loss = avg_valid_loss
            # For TPU, need to save from CPU
            if USE_TPU:
                xm.save(model.state_dict(), "models/checkpoints/best_model.pt")
            else:
                torch.save(model.state_dict(), "models/checkpoints/best_model.pt")
            print(f"  -> Saved best model!")
        print()
    
    print("=" * 60)
    print("TRAINING COMPLETE")
    print("=" * 60)

mark_task("Train BIP model", "done")
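
The validation printout treats 1/9 (about 11%) as chance for the time head on z_bond. That figure assumes a classifier guessing uniformly over the nine periods. If the period labels in a split are imbalanced, a head that always predicts the most common period beats 1/9 without reading any temporal signal from z_bond, and if a split contains only one period (for example, a modern-only test set), the majority baseline is 100%. A minimal sketch that reports both reference points, assuming the labels are plain integer period IDs like the values in `batch['time_labels']`; `chance_baselines` is a hypothetical helper, not part of this notebook:

```python
from collections import Counter
from typing import Sequence


def chance_baselines(labels: Sequence[int], n_classes: int) -> dict:
    """Return uniform-guess and majority-class accuracy baselines.

    labels: integer class IDs for one split (e.g. time-period labels).
    n_classes: number of classes the head can predict (e.g. 9 periods).
    Hypothetical helper for illustration; not part of the notebook.
    """
    if not labels:
        raise ValueError("labels must be non-empty")
    counts = Counter(labels)
    # Accuracy of always predicting the most frequent label in this split
    majority_label, majority_count = counts.most_common(1)[0]
    return {
        "uniform": 1.0 / n_classes,
        "majority": majority_count / len(labels),
        "majority_label": majority_label,
    }


# Toy labels, not real data: period 8 dominates this split
example = chance_baselines([8, 8, 8, 8, 2, 5], n_classes=9)
print(example)
```

After the evaluation loop, `chance_baselines(all_time_labels, n_classes=9)` would give both numbers for the test split, so `time_acc` can be read against the label prior as well as against uniform guessing.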

In [None]:
#@title 10. Evaluate and Interpret Results { display-mode: "form" }
#@markdown Tests the BIP hypothesis: Does z_bond transfer across 2000 years?

mark_task("Evaluate results", "running")

print("=" * 60)
print("BIP TEMPORAL INVARIANCE TEST")
print("=" * 60)
print()

if len(train_dataset) > 0:
    # Load the best checkpoint onto CPU first, then move the model to the
    # active device; this works whether it was saved on GPU, TPU, or CPU.
    state_dict = torch.load("models/checkpoints/best_model.pt", map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    
    # Evaluate on test set (modern data)
    print("Evaluating on TEST SET (modern passages)...")
    print()
    
    all_time_preds = []
    all_time_labels = []
    all_hohfeld_preds = []
    all_hohfeld_labels = []
    all_z_bonds = []
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating", unit="batch"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            outputs = model(input_ids, attention_mask, adversarial_lambda=0)
            
            all_time_preds.extend(outputs['time_pred_bond'].argmax(dim=-1).cpu().tolist())
            all_time_labels.extend(batch['time_labels'].tolist())
            all_hohfeld_preds.extend(outputs['hohfeld_pred'].argmax(dim=-1).cpu().tolist())
            all_hohfeld_labels.extend(batch['hohfeld_labels'].tolist())
            all_z_bonds.append(outputs['z_bond'].cpu())
            
            if USE_TPU:
                xm.mark_step()
    
    # Calculate metrics
    n_test = len(all_time_preds)
    if n_test == 0:
        raise RuntimeError("Test set is empty; check the train/test splits.")
    time_acc = sum(p == l for p, l in zip(all_time_preds, all_time_labels)) / n_test
    hohfeld_acc = sum(p == l for p, l in zip(all_hohfeld_preds, all_hohfeld_labels)) / n_test
    chance_level = 1/9  # uniform guessing over 9 time periods
    
    print("=" * 60)
    print("RESULTS")
    print("=" * 60)
    print()
    
    print("TEST 1: Time Prediction from z_bond")
    print("-" * 40)
    print(f"  Accuracy: {time_acc:.1%}")
    print(f"  Chance level: {chance_level:.1%}")
    print(f"  Difference from chance: {abs(time_acc - chance_level):.1%}")
    print()
    
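    # Heuristic threshold: time accuracy within 5 percentage points of chance
    if abs(time_acc - chance_level) < 0.05:
        print("  RESULT: z_bond appears time-invariant (within 5 points of chance).")
        print("  The time head does little better than chance at recovering temporal origin.")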
        time_invariant = True
    else:
        print("  RESULT: z_bond retains some temporal information.")
        time_invariant = False
    print()
    
    print("TEST 2: Hohfeldian Classification from z_bond")
    print("-" * 40)
    print(f"  Accuracy: {hohfeld_acc:.1%}")
    print("  Random baseline: 25% (uniform guessing)")
    print()
    
    # Heuristic threshold: 40% accuracy, vs. the 25% random baseline
    if hohfeld_acc > 0.4:
        print("  RESULT: z_bond captures moral structure (accuracy above 40%).")
        moral_structure = True
    else:
        print("  RESULT: Weak moral structure encoding.")
        moral_structure = False
    print()
    
    print("=" * 60)
    print("INTERPRETATION")
    print("=" * 60)
    print()
    
    if time_invariant and moral_structure:
        print("""
    ******************************************************
    *                                                    *
    *     BIP TEMPORAL INVARIANCE: SUPPORTED             *
    *                                                    *
    ******************************************************
    
    The bond embedding (z_bond) successfully captured moral 
    structure while remaining invariant to temporal context.
    
    This suggests that moral cognition has a geometry that 
    is STABLE ACROSS 2000+ YEARS of human ethical reasoning.
    
    The same abstract patterns appear in:
    - Ancient Hebrew texts (~500 BCE - 500 CE)
    - Medieval rabbinical commentary (~500 - 1500 CE)  
    - Modern American advice columns (1956 - 2020)
    
    """)
    elif moral_structure and not time_invariant:
        print("""
    BIP TEST: PARTIAL SUPPORT
    
    The model captures moral structure but also retains
    some temporal information. This could indicate:
    - Need for stronger adversarial training
    - Genuine temporal variation in moral concepts
    - Artifact of linguistic differences
    
    """)
    else:
        print("""
    BIP TEST: INCONCLUSIVE
    
    Hohfeldian accuracy from z_bond did not exceed the 40%
    threshold. To get a more conclusive result, check that:
    1. The Sefaria corpus loaded successfully
    2. Each time period has at least 10,000 passages
    3. The model trained for 10+ epochs on a GPU/TPU
    
    """)

mark_task("Evaluate results", "done")

print()
print("=" * 60)
print("EXPERIMENT COMPLETE")
print("=" * 60)
print_progress()
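
The evaluation cell calls z_bond time-invariant when its time accuracy lies within 5 percentage points of 1/9. A fixed margin does not scale with test-set size: sampling noise alone gives a standard deviation of roughly 3 points around 1/9 with 100 test passages, but only about 0.3 points with 10,000. One standard alternative is an exact two-sided binomial test of the number of correct time predictions against the chance rate. The sketch below is plain Python and works in log space (via `math.lgamma`) so large counts do not overflow; `binomial_test_two_sided` is a hypothetical helper, and if SciPy is installed, `scipy.stats.binomtest` provides the same kind of test.

```python
import math


def _log_binom_pmf(k: int, n: int, p: float) -> float:
    """log P(X = k) for X ~ Binomial(n, p), assuming 0 < p < 1."""
    return (
        math.lgamma(n + 1)
        - math.lgamma(k + 1)
        - math.lgamma(n - k + 1)
        + k * math.log(p)
        + (n - k) * math.log1p(-p)
    )


def binomial_test_two_sided(k: int, n: int, p: float) -> float:
    """Exact two-sided binomial test p-value (hypothetical helper).

    Sums the probability of every outcome that is no more likely than the
    observed count k under the null hypothesis X ~ Binomial(n, p).
    """
    if not (0 < p < 1) or not (0 <= k <= n) or n <= 0:
        raise ValueError("need 0 < p < 1 and 0 <= k <= n with n > 0")
    log_obs = _log_binom_pmf(k, n, p)
    # Small relative tolerance so ties are not lost to floating-point error
    threshold = log_obs + math.log1p(1e-7)
    total = 0.0
    for i in range(n + 1):
        log_i = _log_binom_pmf(i, n, p)
        if log_i <= threshold:
            total += math.exp(log_i)
    return min(1.0, total)


# Hypothetical counts, not results from this experiment:
# 120 correct time predictions out of 1,000 test passages vs. chance = 1/9
p_value = binomial_test_two_sided(k=120, n=1000, p=1 / 9)
print(f"p = {p_value:.3f}")
```

In the evaluation cell, `k` would be the count of correct time predictions (the numerator of `time_acc`) and `n` the number of test passages; the same call with `p=0.25` checks Hohfeldian accuracy against its 25% baseline. A large p-value means the test found no detectable temporal signal, which is weaker than showing that z_bond contains none.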

---

## Save Results

Run the cell below to download your trained model and results.

In [None]:
#@title 11. Download Results (Optional) { display-mode: "form" }
#@markdown Creates a zip file with model checkpoint and metrics.

import json
import shutil
from google.colab import files

# Create results directory
!mkdir -p results
!cp models/checkpoints/best_model.pt results/
!cp data/splits/all_splits.json results/

# Save metrics
if len(train_dataset) > 0:
    metrics = {
        'accelerator': ACCELERATOR,
        'time_acc_from_bond': time_acc,
        'hohfeld_acc': hohfeld_acc,
        'chance_level': chance_level,
        'time_invariant': time_invariant,
        'moral_structure': moral_structure,
        'bip_supported': time_invariant and moral_structure
    }
    with open('results/metrics.json', 'w') as f:
        json.dump(metrics, f, indent=2)

# Zip
shutil.make_archive('bip_results', 'zip', 'results')
print("Results saved to bip_results.zip")
print()
print("Contents:")
!ls -la results/

# Download
files.download('bip_results.zip')