# BIP Temporal Invariance Experiment

**Testing the Bond Invariance Principle across 2000+ years of moral reasoning**

This experiment tests whether moral cognition has invariant structure by:
1. Training on ancient Hebrew ethical texts (Sefaria corpus, ~500 BCE - 1800 CE)
2. Testing transfer to modern American advice columns (Dear Abby, 1956-2020)

**Hypothesis**: If BIP holds, bond-level features should transfer across 2000 years with no accuracy drop.

---

## Setup Instructions
1. **Runtime -> Change runtime type -> GPU (T4) or TPU (v5e)**
2. Run the cells in order; each one shows its progress in real time
3. Expected runtime: ~1-2 hours (TPU) or ~2-4 hours (GPU)

**Supported Accelerators:**
- NVIDIA GPU (T4, V100, A100)
- Google TPU (v2, v3, v4, v5e)
- CPU (slow, not recommended)

---

In [None]:
#@title 1. Setup and Install Dependencies { display-mode: "form" }
#@markdown Installs packages and detects GPU/TPU. Memory-optimized for Colab.

# Banner printed once when the first cell runs.
for banner_line in ("=" * 60, "BIP TEMPORAL INVARIANCE EXPERIMENT", "=" * 60, ""):
    print(banner_line)

# Progress tracker: the ordered checklist of pipeline stages, plus the current
# state of each one ("pending" until a cell calls mark_task for it).
TASKS = [
    "Install dependencies",
    "Clone Sefaria corpus (~8GB)",
    "Clone sqnd-probe repo (Dear Abby data)",
    "Preprocess corpora",
    "Extract bond structures",
    "Generate train/test splits",
    "Train BIP model",
    "Evaluate results",
]
task_status = dict.fromkeys(TASKS, "pending")

def print_progress():
    """Print the experiment checklist, one line per entry of ``TASKS``.

    Each task is prefixed with ``[X]`` when its status in ``task_status`` is
    "done", ``[>]`` when it is "running", and ``[ ]`` for any other status.
    """
    marks = {"done": "[X]", "running": "[>]"}
    rule = "-" * 50

    print()
    print(rule)
    print("EXPERIMENT PROGRESS:")
    print(rule)
    for task in TASKS:
        mark = marks.get(task_status[task], "[ ]")
        print(f"  {mark} {task}")
    print(rule)
    # Flush so the checklist shows up immediately in the notebook output.
    print(flush=True)

def mark_task(task, status):
    """Store ``status`` for ``task`` in ``task_status`` and re-print the checklist."""
    task_status.update({task: status})
    print_progress()

print_progress()

mark_task("Install dependencies", "running")

import os
import subprocess
import sys

# Install dependencies - MINIMAL set to save memory
print("Installing minimal dependencies...")
deps = [
    "transformers", "torch", "sentence-transformers",
    "pandas", "tqdm", "psutil",
]

# One quiet pip invocation per package, run through the notebook's own
# interpreter so the packages land in the active environment.
pip_install = [sys.executable, "-m", "pip", "install", "-q"]
for dep in deps:
    print(f"  Installing {dep}...")
    subprocess.check_call([*pip_install, dep])

print()

# Detect accelerator - WITHOUT importing tensorflow
#
# USE_TPU is consulted by later cells to decide whether to call `xm.*`, so it
# is only set to True once torch_xla has actually been imported and a device
# obtained. CUDA takes precedence when both are present.
USE_TPU = False
TPU_TYPE = None

import torch

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    ACCELERATOR = f"GPU: {gpu_name} ({gpu_mem:.1f}GB)"
    device = torch.device("cuda")
elif 'COLAB_TPU_ADDR' in os.environ:
    print("TPU detected!")
    try:
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
    except (ImportError, RuntimeError) as exc:
        # torch_xla is not in the dependency list above, so it may be missing;
        # degrade to CPU instead of aborting the whole setup cell.
        print(f"WARNING: could not initialise torch_xla ({exc}); falling back to CPU.")
        ACCELERATOR = "CPU (slow!)"
        device = torch.device("cpu")
    else:
        USE_TPU = True
        TPU_TYPE = "TPU (Colab)"
        ACCELERATOR = TPU_TYPE
else:
    ACCELERATOR = "CPU (slow!)"
    device = torch.device("cpu")

print(f"Accelerator: {ACCELERATOR}")
print(f"Device: {device}")

# Memory status
import psutil
mem = psutil.virtual_memory()
print(f"System RAM: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

if torch.cuda.is_available():
    print(f"GPU RAM: {torch.cuda.memory_allocated()/1e9:.1f}/{torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

mark_task("Install dependencies", "done")


In [None]:
#@title 2. Download Sefaria Corpus (~8GB) { display-mode: "form" }
#@markdown Downloads the complete Sefaria corpus with real-time git progress.

import subprocess
import sys

mark_task("Clone Sefaria corpus (~8GB)", "running")

sefaria_path = 'data/raw/Sefaria-Export'

if not os.path.exists(sefaria_path) or not os.path.exists(f"{sefaria_path}/json"):
    print("="*60)
    print("CLONING SEFARIA CORPUS")
    print("="*60)
    print()
    print("This downloads ~3.5GB and takes 5-15 minutes.")
    print("Git's native progress will display below:")
    print("-"*60)
    print(flush=True)
    
    # Use subprocess.Popen for real-time output streaming
    process = subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/Sefaria/Sefaria-Export.git', sefaria_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # Git writes progress to stderr
        text=True,
        bufsize=1
    )
    
    # Stream output in real-time
    for line in process.stdout:
        print(line, end='', flush=True)
    
    process.wait()
    
    print("-"*60)
    if process.returncode == 0:
        print("\nSefaria clone COMPLETE!")
    else:
        print(f"\nERROR: Git clone failed with code {process.returncode}")
        print("Try running this cell again, or check your internet connection.")
else:
    print("Sefaria already exists, skipping download.")

# Verify and count files
print()
print("Verifying download...")
!du -sh {sefaria_path} 2>/dev/null || echo "Directory not found"
json_count = !find {sefaria_path}/json -name "*.json" 2>/dev/null | wc -l
print(f"Sefaria JSON files found: {json_count[0]}")

mark_task("Clone Sefaria corpus (~8GB)", "done")


In [None]:
#@title 3. Download Dear Abby Dataset { display-mode: "form" }
#@markdown Downloads the Dear Abby advice column dataset (68,330 entries).

import os
import shutil
import subprocess

import pandas as pd
from pathlib import Path

mark_task("Clone sqnd-probe repo (Dear Abby data)", "running")

sqnd_path = 'sqnd-probe-data'
if not os.path.exists(sqnd_path):
    print("Cloning sqnd-probe repo...")
    process = subprocess.Popen(
        ['git', 'clone', '--depth', '1', '--progress',
         'https://github.com/ahb-sjsu/sqnd-probe.git', sqnd_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    )
    for line in process.stdout:
        print(line, end='', flush=True)
    process.wait()
    if process.returncode != 0:
        # Keep going: a copy of the CSV from an earlier run may still exist.
        print(f"\nERROR: Git clone failed with code {process.returncode}")
else:
    print("Repo already cloned.")

# Copy Dear Abby data
dear_abby_source = Path(sqnd_path) / 'dear_abby_data' / 'raw_da_qs.csv'
dear_abby_path = Path('data/raw/dear_abby.csv')

if dear_abby_source.exists():
    # data/raw only exists if the Sefaria clone ran first; create it here so
    # this cell does not depend on that.
    dear_abby_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(dear_abby_source, dear_abby_path)
    print(f"\nCopied Dear Abby data")
elif not dear_abby_path.exists():
    raise FileNotFoundError("Dear Abby dataset not found!")

# Verify
df_check = pd.read_csv(dear_abby_path)
print(f"\n" + "=" * 50)
print(f"Dear Abby dataset: {len(df_check):,} entries")
print(f"Columns: {list(df_check.columns)}")
print(f"Year range: {df_check['year'].min():.0f} - {df_check['year'].max():.0f}")
print("=" * 50)

mark_task("Clone sqnd-probe repo (Dear Abby data)", "done")

In [None]:
#@title 4. Define Data Classes and Loaders { display-mode: "form" }
#@markdown Defines enums, dataclasses, and corpus loaders.

import json
import hashlib
import re
from dataclasses import dataclass, field, asdict
from typing import List, Dict
from enum import Enum
from collections import defaultdict
from tqdm.auto import tqdm

print("Defining data structures...")

class TimePeriod(Enum):
    """Historical period assigned to a passage; values are consecutive ints from 0.

    Approximate spans:
        BIBLICAL        ~1000-500 BCE
        SECOND_TEMPLE   ~500 BCE - 70 CE
        TANNAITIC       ~70-200 CE
        AMORAIC         ~200-500 CE
        GEONIC          ~600-1000 CE
        RISHONIM        ~1000-1500 CE
        ACHRONIM        ~1500-1800 CE
        MODERN_HEBREW   ~1800-present
        DEAR_ABBY       1956-2020
    """

    (
        BIBLICAL,
        SECOND_TEMPLE,
        TANNAITIC,
        AMORAIC,
        GEONIC,
        RISHONIM,
        ACHRONIM,
        MODERN_HEBREW,
        DEAR_ABBY,
    ) = range(9)

class BondType(Enum):
    """Category of moral bond; values are consecutive ints from 0 in declaration order."""

    (
        HARM_PREVENTION,
        RECIPROCITY,
        AUTONOMY,
        PROPERTY,
        FAMILY,
        AUTHORITY,
        EMERGENCY,
        CONTRACT,
        CARE,
        FAIRNESS,
    ) = range(10)

class HohfeldianState(Enum):
    """Hohfeldian position label; values are consecutive ints from 0 in declaration order."""

    RIGHT, OBLIGATION, LIBERTY, NO_RIGHT = range(4)

@dataclass
class Passage:
    """A single text unit from either corpus, with provenance and extracted labels.

    The loaders in this file store a ``TimePeriod`` member *name* in
    ``time_period`` and fill ``word_count`` from the English text.
    """

    # --- identity and text ---
    id: str
    text_original: str
    text_english: str

    # --- provenance ---
    time_period: str
    century: int
    source: str
    source_type: str
    category: str

    # --- optional metadata ---
    language: str = "hebrew"
    word_count: int = 0
    has_dispute: bool = False
    consensus_tier: str = "unknown"
    bond_types: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict:
        """Return the passage's fields as a plain dict (via ``dataclasses.asdict``)."""
        as_mapping = asdict(self)
        return as_mapping

# Top-level Sefaria category name -> period, written grouped by period.
# Categories not listed here are handled by the caller's default.
CATEGORY_TO_PERIOD = {
    category: period
    for period, categories in (
        (TimePeriod.BIBLICAL, ('Tanakh', 'Torah')),
        (TimePeriod.TANNAITIC, ('Mishnah', 'Tosefta')),
        (TimePeriod.AMORAIC, ('Talmud', 'Bavli', 'Midrash')),
        (TimePeriod.RISHONIM, ('Halakhah',)),
        (TimePeriod.ACHRONIM, ('Chasidut',)),
    )
    for category in categories
}

# Representative century for each period (negative = BCE).
PERIOD_TO_CENTURY = dict((
    (TimePeriod.BIBLICAL, -6),
    (TimePeriod.SECOND_TEMPLE, -2),
    (TimePeriod.TANNAITIC, 2),
    (TimePeriod.AMORAIC, 4),
    (TimePeriod.GEONIC, 8),
    (TimePeriod.RISHONIM, 12),
    (TimePeriod.ACHRONIM, 17),
    (TimePeriod.MODERN_HEBREW, 20),
))

# Matches an HTML/XML-style tag; compiled once instead of once per text leaf.
_HTML_TAG_RE = re.compile(r'<[^>]+>')


def _flatten_sefaria_text(h, e, ref, *, stem, category, time_period, century):
    """Walk two parallel text trees and return a Passage for each usable leaf pair.

    ``h`` and ``e`` are the original-language and English values. When both are
    strings they form one candidate passage, kept only if the cleaned English
    text is 50-2000 characters long. When both are lists they are walked in
    lockstep, extending ``ref`` with 1-based indices ("1.3.2"). Any other
    combination of types yields nothing.
    """
    if isinstance(h, str) and isinstance(e, str):
        h_clean = _HTML_TAG_RE.sub('', h).strip()
        e_clean = _HTML_TAG_RE.sub('', e).strip()
        if not 50 <= len(e_clean) <= 2000:
            return []
        pid = hashlib.md5(f"{stem}:{ref}:{h_clean[:50]}".encode()).hexdigest()[:12]
        return [Passage(
            id=f"sefaria_{pid}",
            text_original=h_clean,
            text_english=e_clean,
            time_period=time_period.name,
            century=century,
            source=f"{stem} {ref}".strip(),
            source_type="sefaria",
            category=category,
            language="hebrew",
            word_count=len(e_clean.split())
        )]

    if isinstance(h, list) and isinstance(e, list):
        result = []
        for i, (hh, ee) in enumerate(zip(h, e), start=1):
            child_ref = f"{ref}.{i}" if ref else str(i)
            result.extend(_flatten_sefaria_text(
                hh, ee, child_ref,
                stem=stem, category=category,
                time_period=time_period, century=century,
            ))
        return result

    return []


def load_sefaria(base_path: str, max_passages: int = None) -> List[Passage]:
    """Load Sefaria corpus with progress bar.

    Args:
        base_path: Root of the Sefaria export; JSON files are searched
            recursively under ``<base_path>/json``.
        max_passages: Upper bound on the number of passages returned.
            ``None`` or 0 means no limit. Files are visited in the order
            ``rglob`` yields them, so a limit keeps the passages found first.

    Returns:
        The extracted passages, or an empty list (after printing a warning)
        when the json directory does not exist. Files that cannot be read or
        parsed are skipped and counted in a summary line.
    """
    passages = []
    json_path = Path(base_path) / "json"
    
    if not json_path.exists():
        print(f"Warning: {json_path} not found")
        return []
    
    json_files = list(json_path.rglob("*.json"))
    print(f"Processing {len(json_files):,} JSON files...")
    
    skipped = 0
    for json_file in tqdm(json_files, desc="Loading Sefaria", unit="file"):
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            skipped += 1
            continue
        
        if not isinstance(data, dict):
            continue
        
        # The first path component under json/ is treated as the category.
        rel_path = json_file.relative_to(json_path)
        category = str(rel_path.parts[0]) if rel_path.parts else "unknown"
        time_period = CATEGORY_TO_PERIOD.get(category, TimePeriod.AMORAIC)
        century = PERIOD_TO_CENTURY.get(time_period, 0)
        
        # NOTE(review): when a file has 'text' but neither 'he' nor 'en', both
        # lookups resolve to the same value, so text_original and text_english
        # end up identical - confirm this is intended for the export format.
        hebrew = data.get('he', data.get('text', []))
        english = data.get('text', data.get('en', []))
        
        passages.extend(_flatten_sefaria_text(
            hebrew, english, "",
            stem=json_file.stem, category=category,
            time_period=time_period, century=century,
        ))
        
        # Enforce the limit on passages (what the parameter name promises),
        # not on the number of files read.
        if max_passages and len(passages) >= max_passages:
            del passages[max_passages:]
            break
    
    if skipped:
        print(f"Skipped {skipped:,} unreadable or malformed JSON files")
    
    return passages

def load_dear_abby(path: str, max_passages: int = None) -> List[Passage]:
    """Load Dear Abby corpus with progress bar.

    Args:
        path: CSV file with a ``question_only`` column and (optionally) a
            ``year`` column.
        max_passages: Stop after this many passages; ``None`` or 0 means no limit.

    Returns:
        One Passage per row whose question is 50-2000 characters long. Rows
        with a missing or unparseable year fall back to 1990, the same default
        used when the column is absent.
    """
    default_year = 1990
    passages = []
    df = pd.read_csv(path)
    
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading Dear Abby", unit="row"):
        question = str(row.get('question_only', ''))
        # str(NaN) == 'nan', so empty CSV cells are rejected here too.
        if not question or question == 'nan' or len(question) < 50 or len(question) > 2000:
            continue
        
        # An empty year cell is read as NaN, and int(NaN) raises ValueError.
        raw_year = row.get('year', default_year)
        try:
            year = default_year if pd.isna(raw_year) else int(float(raw_year))
        except (TypeError, ValueError):
            year = default_year
        
        pid = hashlib.md5(f"abby:{idx}:{question[:50]}".encode()).hexdigest()[:12]
        
        passages.append(Passage(
            id=f"abby_{pid}",
            text_original=question,
            text_english=question,
            time_period=TimePeriod.DEAR_ABBY.name,
            century=20 if year < 2000 else 21,
            source=f"Dear Abby {year}",
            source_type="dear_abby",
            category="general",
            language="english",
            word_count=len(question.split())
        ))
        
        if max_passages and len(passages) >= max_passages:
            break
    
    return passages

print("Data structures defined!")

In [None]:
#@title 5. Load and Preprocess Corpora { display-mode: "form" }
#@markdown Loads both corpora. Set MAX_SEFARIA_PASSAGES to limit memory usage.

#@markdown **Memory Management:**
MAX_SEFARIA_PASSAGES = 500000  #@param {type:"integer"}
#@markdown Set to 0 for unlimited. Recommended: 500000 for Colab (12GB RAM)

import gc

mark_task("Preprocess corpora", "running")

for header_line in ("=" * 60, "LOADING CORPORA", "=" * 60, ""):
    print(header_line)

# Load Sefaria with optional limit (a non-positive setting means "no limit")
if MAX_SEFARIA_PASSAGES > 0:
    limit = MAX_SEFARIA_PASSAGES
    print(f"*** MEMORY MODE: Limited to {MAX_SEFARIA_PASSAGES:,} Sefaria passages ***")
    print()
else:
    limit = None

sefaria_passages = load_sefaria("data/raw/Sefaria-Export", max_passages=limit)
print(f"\nSefaria passages loaded: {len(sefaria_passages):,}")

# Force garbage collection
gc.collect()

# Load Dear Abby
print()
abby_passages = load_dear_abby("data/raw/dear_abby.csv")
print(f"\nDear Abby passages loaded: {len(abby_passages):,}")

# Combine, then drop the per-corpus names so only the combined list is
# referenced from the notebook namespace.
all_passages = sefaria_passages + abby_passages
del sefaria_passages
del abby_passages
gc.collect()

total_passages = len(all_passages)

print()
print("=" * 60)
print(f"TOTAL PASSAGES: {total_passages:,}")
print("=" * 60)

# Statistics: passage counts per source and per time period
by_period = defaultdict(int)
by_source = defaultdict(int)
for p in all_passages:
    by_source[p.source_type] += 1
    by_period[p.time_period] += 1

print("\nBy source:")
for source, count in sorted(by_source.items()):
    print(f"  {source}: {count:,}")

print("\nBy time period:")
for period, count in sorted(by_period.items()):
    pct = count / total_passages * 100
    bar = '#' * int(pct / 2)  # one '#' per two percentage points
    print(f"  {period:20s}: {count:6,} ({pct:5.1f}%) {bar}")

# Memory status
import psutil
mem = psutil.virtual_memory()
used_gb = mem.used / 1e9
total_gb = mem.total / 1e9
print(f"\nMemory: {used_gb:.1f}/{total_gb:.1f} GB ({mem.percent}%)")

mark_task("Preprocess corpora", "done")


In [None]:
#@title 6. Extract Bond Structures { display-mode: "form" }
#@markdown Extracts moral bond structures. Streams to disk to save memory.

import gc

mark_task("Extract bond structures", "running")

# Keyword regexes per bond type. extract_bond_structure walks this dict in
# insertion order, so the order here decides which match becomes the
# 'primary_relation' of a passage.
# NOTE(review): BondType.EMERGENCY and BondType.CONTRACT have no entry, so
# they can never be produced by the extractor - confirm that is intended.
RELATION_PATTERNS = {
    BondType.HARM_PREVENTION: [r'\b(kill|murder|harm|hurt|save|rescue|protect|danger)\b'],
    BondType.RECIPROCITY: [r'\b(return|repay|owe|debt|mutual|exchange)\b'],
    BondType.AUTONOMY: [r'\b(choose|decision|consent|agree|force|coerce|right)\b'],
    BondType.PROPERTY: [r'\b(property|own|steal|theft|buy|sell|land)\b'],
    BondType.FAMILY: [r'\b(honor|parent|marry|divorce|inherit|family)\b'],
    BondType.AUTHORITY: [r'\b(obey|command|law|judge|rule|teach)\b'],
    BondType.CARE: [r'\b(care|help|assist|feed|clothe|visit)\b'],
    BondType.FAIRNESS: [r'\b(fair|just|equal|deserve|bias)\b'],
}

# Keyword regexes per Hohfeldian state. The extractor keeps the FIRST state
# whose pattern matches, so OBLIGATION wins over RIGHT, which wins over LIBERTY.
# HohfeldianState.NO_RIGHT has no pattern; passages matching nothing get None.
HOHFELD_PATTERNS = {
    HohfeldianState.OBLIGATION: [r'\b(must|shall|duty|require|should)\b'],
    HohfeldianState.RIGHT: [r'\b(right to|entitled|deserve)\b'],
    HohfeldianState.LIBERTY: [r'\b(may|can|permitted|allowed)\b'],
}

def extract_bond_structure(passage: Passage) -> Dict:
    """Label a passage with bond relations and a Hohfeldian state by keyword matching.

    The lower-cased English text is tested against ``RELATION_PATTERNS`` and
    ``HOHFELD_PATTERNS``.

    Returns a dict with:
        bonds: one ``{'relation': name}`` entry per matching bond type, in
            pattern-table order; ``CARE`` alone when nothing matched.
        primary_relation: the first entry of that list.
        hohfeld_state: name of the first matching Hohfeldian state, or None.
        signature: the distinct relation names, sorted and joined with "|".
    """
    lowered = passage.text_english.lower()

    def any_match(patterns):
        return any(re.search(pattern, lowered, re.IGNORECASE) for pattern in patterns)

    relations = [
        bond_type.name
        for bond_type, patterns in RELATION_PATTERNS.items()
        if any_match(patterns)
    ]
    if not relations:
        relations = ['CARE']

    hohfeld = next(
        (state.name for state, patterns in HOHFELD_PATTERNS.items() if any_match(patterns)),
        None,
    )

    return {
        'bonds': [{'relation': name} for name in relations],
        'primary_relation': relations[0],
        'hohfeld_state': hohfeld,
        'signature': "|".join(sorted(set(relations))),
    }

import os

print("=" * 60)
print("EXTRACTING & SAVING (STREAMING)")
print("=" * 60)
print()
print("Writing directly to disk to conserve memory...")
print()

bond_counts = defaultdict(int)

# Nothing earlier in the notebook creates this directory, and open(..., 'w')
# does not create missing parents.
os.makedirs("data/processed", exist_ok=True)

# Stream directly to files - don't accumulate in memory.
# The two files are written in lockstep: line N of each describes the same passage.
with open("data/processed/passages.jsonl", 'w') as f_pass, \
     open("data/processed/bond_structures.jsonl", 'w') as f_bond:
    
    for passage in tqdm(all_passages, desc="Processing", unit="passage"):
        # Extract bonds
        bond_struct = extract_bond_structure(passage)
        passage.bond_types = [b['relation'] for b in bond_struct['bonds']]
        
        # Count for stats
        for bond in bond_struct['bonds']:
            bond_counts[bond['relation']] += 1
        
        # Write immediately (don't accumulate)
        f_pass.write(json.dumps(passage.to_dict()) + '\n')
        f_bond.write(json.dumps({
            'passage_id': passage.id,
            'bond_structure': bond_struct
        }) + '\n')

# Clear passages from memory - we've saved them to disk
n_passages = len(all_passages)
del all_passages
gc.collect()

print()
print(f"Saved {n_passages:,} passages to disk")
print("Cleared passages from memory")

# Memory status
import psutil
mem = psutil.virtual_memory()
print(f"Memory: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

print()
print("Bond type distribution:")
total_bonds = sum(bond_counts.values())  # hoisted: constant across the loop
for bond_type, count in sorted(bond_counts.items(), key=lambda x: -x[1]):
    pct = count / total_bonds * 100
    bar = '#' * int(pct)
    print(f"  {bond_type:20s}: {count:6,} ({pct:5.1f}%) {bar}")

mark_task("Extract bond structures", "done")


In [None]:
#@title 7. Generate Train/Test Splits { display-mode: "form" }
#@markdown Creates splits from saved files. Memory efficient - reads only IDs.

import random
import gc
random.seed(42)

mark_task("Generate train/test splits", "running")

for header_line in ("=" * 60, "GENERATING SPLITS (MEMORY EFFICIENT)", "=" * 60, ""):
    print(header_line)

# Read only IDs and time periods from disk - don't load full passages
print("Reading passage metadata from disk...")
passage_info = []  # List of (id, time_period) tuples - minimal memory

with open("data/processed/passages.jsonl", 'r') as f:
    for line in tqdm(f, desc="Reading IDs", unit="line"):
        record = json.loads(line)
        passage_info.append((record['id'], record['time_period']))

print(f"Loaded {len(passage_info):,} passage IDs")

# Time periods assigned to each bucket
train_periods = {'BIBLICAL', 'SECOND_TEMPLE', 'TANNAITIC', 'AMORAIC', 'GEONIC', 'RISHONIM'}
valid_periods = {'ACHRONIM'}
test_periods = {'MODERN_HEBREW', 'DEAR_ABBY'}

print()
print("Filtering by time period...")
ancient_ids = [entry for entry in passage_info if entry[1] in train_periods]
early_modern_ids = [entry for entry in passage_info if entry[1] in valid_periods]
modern_ids = [entry for entry in passage_info if entry[1] in test_periods]

for bucket_label, bucket in (
    ("Ancient/Medieval:", ancient_ids),
    ("Early Modern:    ", early_modern_ids),
    ("Modern:          ", modern_ids),
):
    print(f"  {bucket_label} {len(bucket):,}")

# Shuffle each bucket in place; the order of these calls is fixed so the
# seeded RNG yields the same permutations on every run.
for bucket in (ancient_ids, early_modern_ids, modern_ids):
    random.shuffle(bucket)

def _ids_only(pairs):
    """Return just the passage ids from a list of (passage_id, time_period) pairs."""
    return [pid for pid, _ in pairs]


def _report_split_sizes(split):
    """Print the train/valid/test sizes stored in a split dict."""
    print(f"  Train: {split['train_size']:,}")
    print(f"  Valid: {split['valid_size']:,}")
    print(f"  Test:  {split['test_size']:,}")


# ============================================================
# SPLIT A: ANCIENT -> MODERN
# ============================================================
print()
print("-" * 60)
print("SPLIT A: Train ANCIENT, Test MODERN")
print("-" * 60)

temporal_A = {
    'name': 'ancient_to_modern',
    'direction': 'A->M',
    'train_ids': _ids_only(ancient_ids),
    'valid_ids': _ids_only(early_modern_ids),
    'test_ids': _ids_only(modern_ids),
    'train_size': len(ancient_ids),
    'valid_size': len(early_modern_ids),
    'test_size': len(modern_ids)
}
_report_split_sizes(temporal_A)

# ============================================================
# SPLIT B: MODERN -> ANCIENT
# ============================================================
print()
print("-" * 60)
print("SPLIT B: Train MODERN, Test ANCIENT")
print("-" * 60)

# The ancient test set is the second modern-sized slice of the shuffled
# ancient pool, or whatever remains past the first slice when the pool is
# smaller than two slices (possibly nothing).
n_modern = len(modern_ids)
if len(ancient_ids) >= n_modern * 2:
    ancient_test = ancient_ids[n_modern:n_modern * 2]
else:
    ancient_test = ancient_ids[n_modern:]

half_early_modern = len(early_modern_ids) // 2

temporal_B = {
    'name': 'modern_to_ancient',
    'direction': 'M->A',
    'train_ids': _ids_only(modern_ids),
    'valid_ids': _ids_only(early_modern_ids[:half_early_modern]),
    'test_ids': _ids_only(ancient_test),
    'train_size': len(modern_ids),
    'valid_size': half_early_modern,
    'test_size': len(ancient_test)
}
_report_split_sizes(temporal_B)

# ============================================================
# SPLIT C: MIXED CONTROL
# ============================================================
print()
print("-" * 60)
print("SPLIT C: MIXED (Control)")
print("-" * 60)

# Pool ancient and modern passages, reshuffle, and cut 70/15/15.
all_ids = ancient_ids + modern_ids
random.shuffle(all_ids)
n = len(all_ids)
n_train = int(0.7 * n)
n_valid = int(0.15 * n)
valid_end = n_train + n_valid

temporal_C = {
    'name': 'mixed_control',
    'direction': 'MIXED',
    'train_ids': _ids_only(all_ids[:n_train]),
    'valid_ids': _ids_only(all_ids[n_train:valid_end]),
    'test_ids': _ids_only(all_ids[valid_end:]),
    'train_size': n_train,
    'valid_size': n_valid,
    'test_size': n - n_train - n_valid
}
_report_split_sizes(temporal_C)

import os

# Clear temporary data
del passage_info, ancient_ids, early_modern_ids, modern_ids, all_ids
gc.collect()

# Save
print()
print("Saving splits...")
splits = {
    'ancient_to_modern': temporal_A,
    'modern_to_ancient': temporal_B,
    'mixed_control': temporal_C
}

# Nothing earlier in the notebook creates this directory, and open(..., 'w')
# does not create missing parents.
os.makedirs("data/splits", exist_ok=True)
with open("data/splits/all_splits.json", 'w') as f:
    json.dump(splits, f, indent=2)

# Memory status
import psutil
mem = psutil.virtual_memory()
print(f"Memory: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")

print()
print("SPLITS SAVED:")
print("  - ancient_to_modern (A->M)")
print("  - modern_to_ancient (M->A)")  
print("  - mixed_control")

mark_task("Generate train/test splits", "done")


In [None]:
#@title 8. Define BIP Model Architecture { display-mode: "form" }
#@markdown Defines the model. Clears memory first to avoid OOM.

# All imports come first so this cell does not rely on an earlier cell having
# imported torch before the cache is cleared below.
import gc
import psutil

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader

# CRITICAL: Clear memory before loading model
print("Clearing memory before model load...")
gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()

mem = psutil.virtual_memory()
print(f"Memory before model: {mem.used/1e9:.1f}/{mem.total/1e9:.1f} GB ({mem.percent}%)")
print()

print("=" * 60)
print("DEFINING MODEL ARCHITECTURE")
print("=" * 60)
print()

class GradientReversal(torch.autograd.Function):
    """Autograd function that is the identity going forward and flips gradients going back.

    ``forward`` returns a copy of ``x``; ``backward`` multiplies the incoming
    gradient by ``-lambda_`` and returns no gradient for ``lambda_`` itself.
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Remember the scale for the backward pass.
        ctx.reversal_strength = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg() * ctx.reversal_strength
        return reversed_grad, None

def gradient_reversal(x, lambda_=1.0):
    """Apply ``GradientReversal`` to ``x`` with gradient scale ``lambda_``."""
    reversed_x = GradientReversal.apply(x, lambda_)
    return reversed_x

class BIPEncoder(nn.Module):
    """Wraps a pretrained transformer and mean-pools its token states into one vector."""

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2", d_model=384):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Stored for reference only; nothing in this class reads it back.
        self.d_model = d_model

    def forward(self, input_ids, attention_mask):
        """Return the attention-mask-weighted mean of the last hidden states."""
        token_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        # Expand the mask with a trailing axis so it broadcasts over features.
        weights = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(token_states * weights, dim=1)
        # Clamp keeps the division finite for rows whose mask is all zeros.
        counts = torch.clamp(weights.sum(dim=1), min=1e-9)
        return summed / counts

class BIPModel(nn.Module):
    """Bond Invariance Principle Model with adversarial disentanglement.

    A shared ``BIPEncoder`` feeds two projection heads, ``z_bond`` and
    ``z_label``. A time-period classifier reads ``z_bond`` through a gradient
    reversal layer, a second time-period classifier reads ``z_label``
    directly, and a Hohfeldian-state classifier reads ``z_bond``.
    """

    def __init__(self, d_model=384, d_bond=64, d_label=32, n_periods=9, n_hohfeld=4):
        super().__init__()

        self.encoder = BIPEncoder()

        # Submodules are created in a fixed order so parameter initialisation
        # draws from the RNG in the same sequence every time.
        d_hidden = d_model // 2
        self.bond_proj = self._projection_head(d_model, d_hidden, d_bond)
        self.label_proj = self._projection_head(d_model, d_hidden, d_label)

        self.time_classifier_bond = nn.Linear(d_bond, n_periods)
        self.time_classifier_label = nn.Linear(d_label, n_periods)
        self.hohfeld_classifier = nn.Linear(d_bond, n_hohfeld)

    @staticmethod
    def _projection_head(d_in, d_hidden, d_out):
        """Build a Linear -> ReLU -> Dropout(0.1) -> Linear projection stack."""
        return nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden, d_out)
        )

    def forward(self, input_ids, attention_mask, adversarial_lambda=1.0):
        """Encode a batch and return both embeddings plus the three sets of logits."""
        pooled = self.encoder(input_ids, attention_mask)

        result = {
            'z_bond': self.bond_proj(pooled),
            'z_label': self.label_proj(pooled),
        }

        # Only the bond-side time classifier sees reversed gradients.
        reversed_bond = gradient_reversal(result['z_bond'], adversarial_lambda)
        result['time_pred_bond'] = self.time_classifier_bond(reversed_bond)
        result['time_pred_label'] = self.time_classifier_label(result['z_label'])
        result['hohfeld_pred'] = self.hohfeld_classifier(result['z_bond'])
        return result

# Time period name -> class index used as the time-classification target.
TIME_PERIOD_TO_IDX = {
    period_name: idx
    for idx, period_name in enumerate((
        'BIBLICAL', 'SECOND_TEMPLE', 'TANNAITIC', 'AMORAIC',
        'GEONIC', 'RISHONIM', 'ACHRONIM', 'MODERN_HEBREW', 'DEAR_ABBY',
    ))
}

# Hohfeldian state name -> class index; None (no state detected) gets the last index.
# NOTE(review): this ordering (OBLIGATION=0, RIGHT=1) differs from the
# HohfeldianState enum values (RIGHT=0, OBLIGATION=1) - confirm that every
# consumer goes through this mapping and not through the enum values.
HOHFELD_TO_IDX = {
    state_name: idx
    for idx, state_name in enumerate(('OBLIGATION', 'RIGHT', 'LIBERTY', None))
}

class MoralDataset(Dataset):
    """Dataset over the passages whose ids are in ``passage_ids``.

    The passages file and the bond-structures file are scanned once, line by
    line in lockstep, and only the selected rows are kept: truncated English
    text, time period and Hohfeldian state. That subset is held in memory;
    tokenization happens lazily in ``__getitem__``.

    Args:
        passage_ids: Ids to keep. Pass a ``set`` so each membership test is O(1).
        passages_file: JSONL file with one passage dict per line.
        bonds_file: JSONL file whose line N describes the passage on line N
            of ``passages_file``.
        tokenizer: Callable tokenizer returning ``input_ids`` and ``attention_mask``.
        max_len: Token length every sample is truncated/padded to.

    Raises:
        ValueError: If a selected bond line carries a ``passage_id`` that does
            not match the passage on the same line, i.e. the files are misaligned.
    """
    def __init__(self, passage_ids: set, passages_file: str, bonds_file: str, tokenizer, max_len=128):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.passage_ids = passage_ids
        
        print(f"  Indexing {len(passage_ids):,} passages...")
        
        self.data = []  # One dict per kept passage: text, time_period, hohfeld
        
        # Load only the passages we need
        with open(passages_file, 'r') as f_pass, open(bonds_file, 'r') as f_bond:
            for p_line, b_line in tqdm(zip(f_pass, f_bond), desc="  Loading subset", unit="line", total=None):
                p = json.loads(p_line)
                if p['id'] not in passage_ids:
                    continue
                b = json.loads(b_line)
                # Labels come from the bond file purely by line position, so a
                # mismatch would silently attach the wrong label. Records
                # without a 'passage_id' key are accepted as-is.
                if b.get('passage_id', p['id']) != p['id']:
                    raise ValueError(
                        f"{bonds_file} is not aligned with {passages_file}: "
                        f"expected passage {p['id']!r}, found {b.get('passage_id')!r}"
                    )
                self.data.append({
                    'text': p['text_english'][:1000],  # Truncate long texts
                    'time_period': p['time_period'],
                    'hohfeld': b['bond_structure']['hohfeld_state']
                })
        
        print(f"  Loaded {len(self.data):,} samples")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its tensors and integer labels.

        Unknown time periods map to index 8 and unknown Hohfeldian states to
        index 3, the fallbacks passed to the lookup tables.
        """
        item = self.data[idx]
        
        encoding = self.tokenizer(
            item['text'],
            truncation=True,
            max_length=self.max_len,
            padding='max_length',
            return_tensors='pt'
        )
        
        return {
            # return_tensors='pt' adds a leading batch axis of size 1; drop it.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'time_label': TIME_PERIOD_TO_IDX.get(item['time_period'], 8),
            'hohfeld_label': HOHFELD_TO_IDX.get(item['hohfeld'], 3)
        }

def collate_fn(batch):
    """Merge a list of per-sample dicts into one dict of batched tensors.

    Tensor fields are stacked along a new leading axis; the integer labels
    are gathered into 1-D tensors under the plural keys ``time_labels`` and
    ``hohfeld_labels``.
    """
    def stacked(key):
        return torch.stack([sample[key] for sample in batch])

    def as_label_tensor(key):
        return torch.tensor([sample[key] for sample in batch])

    return {
        'input_ids': stacked('input_ids'),
        'attention_mask': stacked('attention_mask'),
        'time_labels': as_label_tensor('time_label'),
        'hohfeld_labels': as_label_tensor('hohfeld_label'),
    }

# Memory cleanup: collect garbage, then release cached CUDA blocks if a GPU is present
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Model architecture defined!")
print()

# Memory status
import psutil
mem = psutil.virtual_memory()
used_gb = mem.used / 1e9
total_gb = mem.total / 1e9
print(f"Memory: {used_gb:.1f}/{total_gb:.1f} GB ({mem.percent}%)")


In [None]:
#@title 9. Train BIP Model (GPU/TPU) { display-mode: "form" }
#@markdown Trains the model with adversarial disentanglement. Runs on selected split.

#@markdown **Select which split to train:**
SPLIT_NAME = "ancient_to_modern"  #@param ["ancient_to_modern", "modern_to_ancient", "mixed_control"]

import os

mark_task("Train BIP model", "running")

print("=" * 60)
print("TRAINING BIP MODEL")
print("=" * 60)
print()
print(f"Accelerator: {ACCELERATOR}")
print(f"Device: {device}")
print()

# Load the selected split
with open("data/splits/all_splits.json", 'r') as f:
    all_splits = json.load(f)

split = all_splits[SPLIT_NAME]
print(f"Using split: {SPLIT_NAME}")
print(f"  Direction: {split.get('direction', 'N/A')}")
print(f"  Train: {split['train_size']:,}")
print(f"  Valid: {split['valid_size']:,}")
print(f"  Test:  {split['test_size']:,}")
print()

# For backward compatibility
temporal_holdout = split

# Initialize
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = BIPModel().to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print()

# Create datasets
print("Creating datasets...")
train_dataset = MoralDataset(
    set(split['train_ids']),
    "data/processed/passages.jsonl",
    "data/processed/bond_structures.jsonl",
    tokenizer
)
valid_dataset = MoralDataset(
    set(split['valid_ids']),
    "data/processed/passages.jsonl",
    "data/processed/bond_structures.jsonl",
    tokenizer
)
test_dataset = MoralDataset(
    set(split['test_ids']),
    "data/processed/passages.jsonl",
    "data/processed/bond_structures.jsonl",
    tokenizer
)

print(f"Train samples: {len(train_dataset):,}")
print(f"Valid samples: {len(valid_dataset):,}")
print(f"Test samples:  {len(test_dataset):,}")
print()

if len(train_dataset) == 0:
    print("ERROR: No training data! Check if data loaded correctly.")
    # Any status other than "done"/"running" shows as an unchecked box.
    mark_task("Train BIP model", "failed")
else:
    batch_size = 64 if USE_TPU else 32
    
    # Dropping the last partial batch would discard the ONLY batch of a
    # training set smaller than batch_size, leaving zero steps per epoch.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                              collate_fn=collate_fn,
                              drop_last=len(train_dataset) >= batch_size,
                              num_workers=2)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size*2, shuffle=False, 
                              collate_fn=collate_fn, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size*2, shuffle=False, 
                             collate_fn=collate_fn, num_workers=2)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    
    n_epochs = 10
    best_monitored_loss = float('inf')
    
    # Save model with split name
    model_path = f"models/checkpoints/best_model_{SPLIT_NAME}.pt"
    # Nothing earlier in the notebook creates the checkpoint directory.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    
    print(f"Training for {n_epochs} epochs (batch_size={batch_size})...")
    print(f"Model will be saved to: {model_path}")
    print("=" * 60)
    print()
    
    for epoch in range(1, n_epochs + 1):
        # Training
        model.train()
        total_loss = 0
        n_batches = 0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{n_epochs}", unit="batch")
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            time_labels = batch['time_labels'].to(device)
            hohfeld_labels = batch['hohfeld_labels'].to(device)
            
            outputs = model(input_ids, attention_mask, adversarial_lambda=1.0)
            
            # Adversarial term: the bond-side time classifier is trained to
            # PREDICT the period (cross-entropy); the gradient reversal layer
            # inside the model flips that gradient before it reaches the
            # encoder, pushing z_bond to be uninformative about time.
            # An entropy objective here would never show the classifier a
            # label, so its reported accuracy would say nothing about z_bond.
            loss_adv = F.cross_entropy(outputs['time_pred_bond'], time_labels)
            
            loss_time = F.cross_entropy(outputs['time_pred_label'], time_labels)
            loss_hohfeld = F.cross_entropy(outputs['hohfeld_pred'], hohfeld_labels)
            
            loss = loss_adv + loss_time + loss_hohfeld
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            if USE_TPU:
                xm.optimizer_step(optimizer)
                xm.mark_step()
            else:
                optimizer.step()
            
            loss_value = loss.item()  # read the scalar once per step
            total_loss += loss_value
            n_batches += 1
            pbar.set_postfix({'loss': f"{loss_value:.4f}", 'adv': f"{loss_adv.item():.3f}"})
        
        avg_train_loss = total_loss / n_batches
        
        # Validation
        model.eval()
        valid_loss = 0
        valid_batches = 0
        time_correct = 0
        time_total = 0
        
        with torch.no_grad():
            for batch in valid_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                time_labels = batch['time_labels'].to(device)
                hohfeld_labels = batch['hohfeld_labels'].to(device)
                
                outputs = model(input_ids, attention_mask, adversarial_lambda=0)
                loss = F.cross_entropy(outputs['hohfeld_pred'], hohfeld_labels)
                valid_loss += loss.item()
                valid_batches += 1
                
                time_preds = outputs['time_pred_bond'].argmax(dim=-1)
                time_correct += (time_preds == time_labels).sum().item()
                time_total += len(time_labels)
                
                if USE_TPU:
                    xm.mark_step()
        
        has_valid = valid_batches > 0
        avg_valid_loss = valid_loss / valid_batches if has_valid else float('nan')
        time_acc = time_correct / time_total if time_total > 0 else 0
        
        # With an empty validation set a constant 0 would make epoch 1 the
        # permanent "best" model; select on training loss instead.
        monitored_loss = avg_valid_loss if has_valid else avg_train_loss
        
        print(f"\nEpoch {epoch} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Valid Loss: {avg_valid_loss:.4f}")
        if not has_valid:
            print("  (no validation data - checkpointing on train loss)")
        print(f"  Time Acc from z_bond: {time_acc:.1%} (target: ~11% = chance)")
        
        if monitored_loss < best_monitored_loss:
            best_monitored_loss = monitored_loss
            if USE_TPU:
                xm.save(model.state_dict(), model_path)
            else:
                torch.save(model.state_dict(), model_path)
            print(f"  -> Saved best model!")
        print()
    
    print("=" * 60)
    print(f"TRAINING COMPLETE for {SPLIT_NAME}")
    print("=" * 60)
    
    mark_task("Train BIP model", "done")


In [None]:
#@title 10. Evaluate Results (Bidirectional) { display-mode: "form" }
#@markdown Tests the BIP hypothesis with BIDIRECTIONAL transfer analysis.

import gc

# Flag the final pipeline stage as in progress; mark_task reprints the tracker.
mark_task("Evaluate results", "running")

# Section banner: rule, title, rule, then one blank spacer line.
_rule = "=" * 60
print(_rule, "BIP BIDIRECTIONAL EVALUATION", _rule, "", sep="\n")

def evaluate_split(split_name, model_path, test_ids, tokenizer):
    """Evaluate one trained checkpoint on the test passages of a single split.

    Args:
        split_name: Label of the split; used only in the progress output.
        model_path: Path of the checkpoint (a saved ``state_dict``) to load.
        test_ids: Iterable of passage ids forming the test set. It is turned
            into a ``set`` before being handed to ``MoralDataset``.
        tokenizer: Tokenizer forwarded to ``MoralDataset``.

    Returns:
        ``{'time_acc': ..., 'hohfeld_acc': ...}``: the fraction of test
        examples whose ``time_pred_bond`` / ``hohfeld_pred`` argmax equals the
        corresponding label (0 when no example was evaluated), or ``None``
        when ``model_path`` does not exist.
    """
    print(f"\nEvaluating: {split_name}")
    print("-" * 40)

    # Read the checkpoint before anything else, so a missing file is reported
    # without first building the dataset or placing a model on the accelerator.
    # Loading onto the CPU works for every backend; the weights are moved to
    # `device` together with the model below.
    try:
        state_dict = torch.load(model_path, map_location='cpu')
    except FileNotFoundError:
        print(f"  Model not found: {model_path}")
        return None

    # Create test dataset
    test_ds = MoralDataset(
        set(test_ids),
        "data/processed/passages.jsonl",
        "data/processed/bond_structures.jsonl",
        tokenizer
    )
    test_ld = DataLoader(test_ds, batch_size=64, shuffle=False, collate_fn=collate_fn, num_workers=0)

    # Build the model, restore its weights, then move it to the target device.
    model_eval = BIPModel()
    model_eval.load_state_dict(state_dict)
    del state_dict
    model_eval = model_eval.to(device)
    model_eval.eval()

    all_time_preds = []
    all_time_labels = []
    all_hohfeld_preds = []
    all_hohfeld_labels = []

    try:
        with torch.no_grad():
            for batch in tqdm(test_ld, desc="Testing", unit="batch"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                # adversarial_lambda=0, same setting as the validation pass.
                outputs = model_eval(input_ids, attention_mask, adversarial_lambda=0)

                all_time_preds.extend(outputs['time_pred_bond'].argmax(dim=-1).cpu().tolist())
                all_time_labels.extend(batch['time_labels'].tolist())
                all_hohfeld_preds.extend(outputs['hohfeld_pred'].argmax(dim=-1).cpu().tolist())
                all_hohfeld_labels.extend(batch['hohfeld_labels'].tolist())

                if USE_TPU:
                    xm.mark_step()
    finally:
        # Release model and data even when evaluation fails part-way, so the
        # next split does not start with this one's memory still allocated.
        del model_eval, test_ds, test_ld
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _accuracy(preds, labels):
        # Fraction of matching positions; 0 for an empty prediction list.
        return sum(p == t for p, t in zip(preds, labels)) / len(preds) if preds else 0

    time_acc = _accuracy(all_time_preds, all_time_labels)
    hohfeld_acc = _accuracy(all_hohfeld_preds, all_hohfeld_labels)

    return {'time_acc': time_acc, 'hohfeld_acc': hohfeld_acc}

# Load the split definitions; each entry used below must provide 'test_ids'.
with open("data/splits/all_splits.json", 'r', encoding="utf-8") as f:
    all_splits = json.load(f)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# split name -> {'time_acc': ..., 'hohfeld_acc': ...} for every evaluated split
results = {}
# NOTE(review): 1/9 presumably reflects nine time-period classes; confirm
# against the label definition used when building the dataset.
chance_level = 1/9

# Evaluate each split that has a trained model
for split_name in ['ancient_to_modern', 'modern_to_ancient', 'mixed_control']:
    model_path = f"models/checkpoints/best_model_{split_name}.pt"

    if not os.path.exists(model_path):
        # Fallback: an un-suffixed checkpoint is accepted for ancient_to_modern only.
        if split_name == 'ancient_to_modern' and os.path.exists("models/checkpoints/best_model.pt"):
            model_path = "models/checkpoints/best_model.pt"
        else:
            print(f"\nSkipping {split_name}: No model found")
            continue

    split = all_splits.get(split_name)
    if not split:
        # Report the skip instead of dropping the split silently, so a trained
        # checkpoint without a matching split definition is noticed.
        print(f"\nSkipping {split_name}: No split definition in all_splits.json")
        continue

    result = evaluate_split(split_name, model_path, split['test_ids'], tokenizer)
    # evaluate_split returns None when the checkpoint could not be found.
    if result is not None:
        results[split_name] = result

# Summary table: one row per evaluated split, followed by the chance baseline.
heavy_rule = "=" * 60
light_rule = "-" * 60
header_row = f"{'Split':<25} {'Direction':<10} {'Time Acc':<12} {'Hohfeld Acc':<12}"
print("", heavy_rule, "RESULTS SUMMARY", heavy_rule, "", header_row, light_rule, sep="\n")

for split_name, res in results.items():
    # '?' is shown when the split entry or its 'direction' field is absent.
    split = all_splits.get(split_name, {})
    direction = split.get('direction', '?')
    row = " ".join([
        format(split_name, "<25"),
        format(direction, "<10"),
        format(res['time_acc'], ">10.1%"),
        format(res['hohfeld_acc'], ">10.1%"),
    ])
    print(row)

print(light_rule)
print(f"Chance level: {chance_level:.1%}")

# BIP Verdict
print()
print("=" * 60)
print("BIP VERDICT")
print("=" * 60)

# Decision thresholds, applied identically to both transfer directions.
time_acc_tolerance = 0.05  # max |time accuracy - chance| still counted as invariant
min_hohfeld_acc = 0.35     # Hohfeld accuracy above which structure counts as captured

a2m = results.get('ancient_to_modern', {})
m2a = results.get('modern_to_ancient', {})

if a2m:
    # Invariance: era prediction from z_bond stays within tolerance of chance.
    a2m_inv = abs(a2m.get('time_acc', 1) - chance_level) < time_acc_tolerance
    # Structure: Hohfeld classification clears the minimum accuracy.
    a2m_struct = a2m.get('hohfeld_acc', 0) > min_hohfeld_acc

    print()
    if a2m_inv and a2m_struct:
        print("*** BIP SUPPORTED ***")
        print()
        print("Time invariance: YES (cannot predict era from z_bond)")
        print("Moral structure: YES (Hohfeld classification works)")
    elif a2m_struct:
        print("BIP: PARTIAL SUPPORT")
        print()
        print("Moral structure captured, but some temporal leakage.")
    else:
        print("BIP: INCONCLUSIVE")
        print()
        print("Weak moral structure encoding.")

    if m2a:
        m2a_inv = abs(m2a.get('time_acc', 1) - chance_level) < time_acc_tolerance
        m2a_struct = m2a.get('hohfeld_acc', 0) > min_hohfeld_acc

        # The bidirectional claim needs both criteria in both directions.
        if a2m_inv and m2a_inv and a2m_struct and m2a_struct:
            print()
            print("*** BIDIRECTIONAL INVARIANCE CONFIRMED ***")
            print("Both A->M and M->A show temporal invariance!")
elif results:
    # Other splits were evaluated, so "no results" would be wrong; the verdict
    # itself is defined on the ancient_to_modern direction.
    print("No ancient_to_modern results: BIP verdict not computed.")
    print("Train and evaluate the ancient_to_modern split to get a verdict.")
else:
    print("No results to evaluate. Train at least one split first.")

# Close out the last tracked task; mark_task reprints the tracker itself.
mark_task("Evaluate results", "done")

# Closing banner, followed by one more full view of the task tracker.
for closing_line in ("", "=" * 60, "EXPERIMENT COMPLETE", "=" * 60):
    print(closing_line)
print_progress()


---

## Save Results

Run the cell below to download your trained model and results.

In [None]:
#@title 11. Download Results (Optional) { display-mode: "form" }
#@markdown Creates a zip file with model checkpoint and metrics.

import glob
import json
import os
import shutil
from google.colab import files

# Stage everything that goes into the archive under results/.
os.makedirs("results", exist_ok=True)

# The evaluation cell looks for per-split checkpoints named
# best_model_<split>.pt (with best_model.pt as a fallback), so collect every
# file matching that pattern rather than the single un-suffixed name.
checkpoint_paths = sorted(glob.glob("models/checkpoints/best_model*.pt"))
if not checkpoint_paths:
    print("WARNING: no checkpoints found in models/checkpoints/ - archive will contain no model.")
for checkpoint_path in checkpoint_paths:
    shutil.copy(checkpoint_path, "results/")

if os.path.exists("data/splits/all_splits.json"):
    shutil.copy("data/splits/all_splits.json", "results/")
else:
    print("WARNING: data/splits/all_splits.json not found - splits not included.")


def _split_metrics(res, chance):
    """Derive the per-split metric record from one evaluate_split result."""
    # Same thresholds as the verdict in the evaluation cell (0.05 / 0.35).
    time_invariant = abs(res['time_acc'] - chance) < 0.05
    moral_structure = res['hohfeld_acc'] > 0.35
    return {
        'time_acc_from_bond': res['time_acc'],
        'hohfeld_acc': res['hohfeld_acc'],
        'time_invariant': time_invariant,
        'moral_structure': moral_structure,
        'bip_supported': time_invariant and moral_structure,
    }


# Save metrics. `results` and `chance_level` are produced by the evaluation
# cell; look `results` up defensively so the checkpoints can still be
# downloaded when that cell has not been run.
evaluated = globals().get('results') or {}
if evaluated:
    per_split = {name: _split_metrics(res, chance_level) for name, res in evaluated.items()}
    metrics = {
        'accelerator': ACCELERATOR,
        'chance_level': chance_level,
        'splits': per_split,
    }
    # Keep the flat keys of the earlier metrics layout, taken from the
    # ancient_to_modern direction when it was evaluated.
    metrics.update(per_split.get('ancient_to_modern', {}))
    with open('results/metrics.json', 'w') as f:
        json.dump(metrics, f, indent=2)
else:
    print("No evaluation results available - metrics.json not written.")

# Zip
shutil.make_archive('bip_results', 'zip', 'results')
print("Results saved to bip_results.zip")
print()
print("Contents:")
for entry in sorted(os.listdir("results")):
    entry_size = os.path.getsize(os.path.join("results", entry))
    print(f"  {entry_size:>14,d}  {entry}")

# Download
files.download('bip_results.zip')