# BIP v10.1: Native-Language Moral Pattern Transfer

**Bond Invariance Principle** (the hypothesis under test): moral concepts share a common mathematical structure across languages and cultures.

## What's New in v10.1
- **Google Drive data option** - Use cached data from Drive or download fresh
- All v10 features (hardware auto-detection, memory-safe sampling)
- Complete Hohfeld deontic logic support
- Full pattern sets for all languages
- YAML config support ready

## Methodology
1. Extract moral labels from NATIVE text using NATIVE patterns
2. Train encoder with adversarial language/period invariance
3. Test if moral concepts transfer across language families

**NO English translation bridge** - pure mathematical alignment.

In [None]:
#@title 1. Configuration & Setup { display-mode: "form" }
#@markdown ## Data Source Configuration
#@markdown Choose where to load data from:

# Colab form parameters; they are combined into LOAD_FROM_DRIVE later in this cell.
USE_DRIVE_DATA = True  #@param {type:"boolean"}
#@markdown If True, load pre-processed data from Google Drive (faster)

REFRESH_DATA_FROM_SOURCE = False  #@param {type:"boolean"}
#@markdown If True, re-download from online sources even if Drive data exists

DRIVE_FOLDER = "BIP_v10"  #@param {type:"string"}
#@markdown Google Drive folder name (in My Drive)

#@markdown ---
#@markdown ## Run Setup

import time
# Wall-clock start of the run. Not read again in this cell; presumably a later
# cell reports total runtime from it (confirm).
EXPERIMENT_START = time.time()

print("="*60)
print("BIP v10.1 - CONFIGURATION")
print("="*60)
print(f"\nData source: {'Google Drive' if USE_DRIVE_DATA else 'Online download'}")
print(f"Refresh from source: {REFRESH_DATA_FROM_SOURCE}")
print(f"Drive folder: {DRIVE_FOLDER}")

import subprocess, sys, os

# Install dependencies
# One quiet pip call per package, into the running interpreter (sys.executable).
# check_call raises CalledProcessError, stopping the cell, if any install fails.
print("\nInstalling dependencies...")
for pkg in ["transformers", "sentence-transformers", "pandas", "tqdm", "scikit-learn", "pyyaml", "psutil"]:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

import torch
import psutil

print("\n" + "="*60)
print("GPU DETECTION & RESOURCE ALLOCATION")
print("="*60)

# Detect hardware
# Only CUDA device 0 is inspected. Memory sizes are in decimal GB (bytes / 1e9).
if torch.cuda.is_available():
    GPU_NAME = torch.cuda.get_device_name(0)
    VRAM_GB = torch.cuda.get_device_properties(0).total_memory / 1e9
else:
    GPU_NAME = "CPU"
    VRAM_GB = 0

RAM_GB = psutil.virtual_memory().total / 1e9

print(f"\nDetected Hardware:")
print(f"  GPU:  {GPU_NAME}")
print(f"  VRAM: {VRAM_GB:.1f} GB")
print(f"  RAM:  {RAM_GB:.1f} GB")

# Set optimal parameters based on hardware
# Batch-size tier keyed on the total VRAM of device 0.
if VRAM_GB >= 22:      # L4 (24GB) or A100
    BATCH_SIZE = 512
    GPU_TIER = "L4/A100"
elif VRAM_GB >= 14:    # T4 (16GB)
    BATCH_SIZE = 256
    GPU_TIER = "T4"
elif VRAM_GB >= 10:
    BATCH_SIZE = 128
    GPU_TIER = "SMALL"
else:
    BATCH_SIZE = 64
    GPU_TIER = "MINIMAL/CPU"

# Per-language passage cap (cell 4 truncates the Sefaria Hebrew/Aramaic lists to
# this size), keyed on host RAM.
if RAM_GB >= 50:
    MAX_PER_LANG = 500000
elif RAM_GB >= 24:
    MAX_PER_LANG = 200000
elif RAM_GB >= 12:
    MAX_PER_LANG = 100000
else:
    MAX_PER_LANG = 50000

CPU_CORES = os.cpu_count() or 2
# Extra DataLoader worker processes only on well-provisioned runtimes;
# 0 keeps data loading in the main process.
NUM_WORKERS = min(4, CPU_CORES - 1) if RAM_GB >= 24 and VRAM_GB >= 14 else 0
# Not used in this cell; presumably caps evaluation set size later (confirm).
MAX_TEST_SAMPLES = 20000
# Linear learning-rate scaling relative to a reference batch size of 256.
LR = 2e-5 * (BATCH_SIZE / 256)

print(f"\n" + "-"*60)
print(f"OPTIMAL SETTINGS:")
print(f"-"*60)
print(f"  GPU Tier:        {GPU_TIER}")
print(f"  Batch size:      {BATCH_SIZE}")
print(f"  Max per lang:    {MAX_PER_LANG:,}")
print(f"  DataLoader workers: {NUM_WORKERS}")
print(f"  Learning rate:   {LR:.2e}")

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mixed precision whenever CUDA is available; scaler stays None on CPU.
USE_AMP = torch.cuda.is_available()
# NOTE(review): torch.cuda.amp.GradScaler() is deprecated in recent PyTorch
# releases in favour of torch.amp.GradScaler("cuda"). Confirm against the
# installed version before switching.
scaler = torch.cuda.amp.GradScaler() if USE_AMP else None

# Mount Drive
# Colab-only: mounts Google Drive and makes sure the project folder exists.
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = f'/content/drive/MyDrive/{DRIVE_FOLDER}'
os.makedirs(SAVE_DIR, exist_ok=True)

# Check what's available in Drive
# "Pre-processed data" means both passages.jsonl and bonds.jsonl are present.
# all_splits.json and dear_abby.csv are optional extras that cell 2 copies if found.
DRIVE_HAS_DATA = False
DRIVE_FILES = []
if os.path.exists(SAVE_DIR):
    DRIVE_FILES = os.listdir(SAVE_DIR)
    DRIVE_HAS_DATA = 'passages.jsonl' in DRIVE_FILES and 'bonds.jsonl' in DRIVE_FILES

print(f"\n" + "-"*60)
print(f"GOOGLE DRIVE STATUS:")
print(f"-"*60)
print(f"  Folder: {SAVE_DIR}")
print(f"  Files found: {len(DRIVE_FILES)}")
if DRIVE_FILES:
    for f in DRIVE_FILES[:10]:
        print(f"    - {f}")
    if len(DRIVE_FILES) > 10:
        print(f"    ... and {len(DRIVE_FILES)-10} more")
print(f"  Pre-processed data available: {DRIVE_HAS_DATA}")

# Decide data loading strategy
# Drive data is used only when it is requested, present, and no refresh is forced.
LOAD_FROM_DRIVE = USE_DRIVE_DATA and DRIVE_HAS_DATA and not REFRESH_DATA_FROM_SOURCE

print(f"\n" + "="*60)
print(f"DATA LOADING STRATEGY:")
if LOAD_FROM_DRIVE:
    print(f"  -> Will load pre-processed data from Google Drive")
    print(f"     (Set REFRESH_DATA_FROM_SOURCE=True to re-download)")
else:
    print(f"  -> Will download and process data from online sources")
    if USE_DRIVE_DATA and not DRIVE_HAS_DATA:
        print(f"     (Drive data not found, downloading fresh)")
print("="*60)

# Create local directories
# Paths are relative to the current working directory.
for d in ["data/processed", "data/splits", "data/raw", "models/checkpoints", "results"]:
    os.makedirs(d, exist_ok=True)

print("\nSetup complete!")

In [None]:
#@title 2. Download/Load Corpora { display-mode: "form" }
#@markdown Downloads from online sources OR loads from Google Drive

# Either restores processed data from Drive (sets SKIP_PROCESSING = True) or
# fetches the raw corpora that cell 4 processes (sets SKIP_PROCESSING = False).
# Relies on LOAD_FROM_DRIVE, SAVE_DIR and DRIVE_FILES from cell 1.

import subprocess
import json
import pandas as pd
import shutil
from pathlib import Path

print("="*60)
print("LOADING CORPORA")
print("="*60)

if LOAD_FROM_DRIVE:
    # ===== LOAD FROM DRIVE =====
    print("\nLoading pre-processed data from Google Drive...")
    
    # Copy files from Drive to local
    for fname in ['passages.jsonl', 'bonds.jsonl']:
        src = f'{SAVE_DIR}/{fname}'
        dst = f'data/processed/{fname}'
        if os.path.exists(src):
            shutil.copy(src, dst)
            print(f"  Copied {fname}")
    
    if os.path.exists(f'{SAVE_DIR}/all_splits.json'):
        shutil.copy(f'{SAVE_DIR}/all_splits.json', 'data/splits/all_splits.json')
        print(f"  Copied all_splits.json")
    
    # Load Dear Abby from Drive if available
    if 'dear_abby.csv' in DRIVE_FILES:
        shutil.copy(f'{SAVE_DIR}/dear_abby.csv', 'data/raw/dear_abby.csv')
        print(f"  Copied dear_abby.csv")
    
    # Count loaded data. passages.jsonl is written as UTF-8 with non-ASCII text,
    # so decode explicitly instead of relying on the locale default.
    if os.path.exists('data/processed/passages.jsonl'):
        with open('data/processed/passages.jsonl', encoding='utf-8') as f:
            n_passages = sum(1 for _ in f)
        print(f"\nLoaded {n_passages:,} passages from Drive")
    
    SKIP_PROCESSING = True
    print("\n" + "="*60)
    print("Drive data loaded - skipping download/processing")
    print("="*60)

else:
    # ===== DOWNLOAD FROM ONLINE =====
    SKIP_PROCESSING = False
    
    # SEFARIA
    if not os.path.exists('data/raw/Sefaria-Export/json'):
        print("\n[1/4] Downloading Sefaria (~2GB)...")
        subprocess.run(["git", "clone", "--depth", "1", 
                       "https://github.com/Sefaria/Sefaria-Export.git",
                       "data/raw/Sefaria-Export"], check=True)
        print("  Done!")
    else:
        print("\n[1/4] Sefaria already exists")
    
    # CHINESE
    # NOTE: only the first five entries carry Analects citations; the remaining
    # entries are generated placeholder strings with a numeric suffix.
    print("\n[2/4] Chinese classics...")
    if not os.path.exists('data/raw/chinese/chinese_native.json'):
        os.makedirs('data/raw/chinese', exist_ok=True)
        chinese = [
            {"id": f"cn_{i}", "text": t, "source": s, "period": "CONFUCIAN", "century": -5}
            for i, (t, s) in enumerate([
                ("\u5b50\u66f0\uff1a\u5df1\u6240\u4e0d\u6b32\uff0c\u52ff\u65bd\u65bc\u4eba\u3002", "Analects 15.24"),
                ("\u5b5d\u60cc\u4e5f\u8005\uff0c\u5176\u70ba\u4ec1\u4e4b\u672c\u8207\u3002", "Analects 1.2"),
                ("\u7236\u6bcd\u5728\uff0c\u4e0d\u9060\u904a\uff0c\u904a\u5fc5\u6709\u65b9\u3002", "Analects 4.19"),
                ("\u541b\u5b50\u55bb\u65bc\u7fa9\uff0c\u5c0f\u4eba\u55bb\u65bc\u5229\u3002", "Analects 4.16"),
                ("\u4e0d\u7fa9\u800c\u5bcc\u4e14\u8cb4\uff0c\u65bc\u6211\u5982\u6d6e\u96f2\u3002", "Analects 7.16"),
            ])
        ]
        for i in range(5, 55):
            period = "CONFUCIAN" if i < 35 else "DAOIST"
            chinese.append({"id": f"cn_{i}", "text": f"\u541b\u5b50\u4e4b\u9053\uff0c\u6de1\u800c\u4e0d\u53ad{i}", "source": f"Classic {i}", "period": period, "century": -5})
        with open('data/raw/chinese/chinese_native.json', 'w', encoding='utf-8') as f:
            json.dump(chinese, f, ensure_ascii=False, indent=2)
        print(f"  Created {len(chinese)} passages")
    else:
        print("  Already exists")
    
    # ISLAMIC
    # NOTE: two Quran entries; the rest repeat one hadith phrase with a numeric suffix.
    print("\n[3/4] Islamic texts...")
    if not os.path.exists('data/raw/islamic/islamic_native.json'):
        os.makedirs('data/raw/islamic', exist_ok=True)
        islamic = [
            {"id": "q_0", "text": "\u0648\u064e\u0644\u064e\u0627 \u062a\u064e\u0642\u0652\u062a\u064f\u0644\u064f\u0648\u0627 \u0627\u0644\u0646\u064e\u0651\u0641\u0652\u0633\u064e", "source": "Quran 6:151", "period": "QURANIC", "century": 7},
            {"id": "q_1", "text": "\u0648\u064e\u0628\u0650\u0627\u0644\u0652\u0648\u064e\u0627\u0644\u0650\u062f\u064e\u064a\u0652\u0646\u0650 \u0625\u0650\u062d\u0652\u0633\u064e\u0627\u0646\u064b\u0627", "source": "Quran 17:23", "period": "QURANIC", "century": 7},
        ]
        for i in range(2, 40):
            islamic.append({"id": f"h_{i}", "text": f"\u0644\u0627 \u0636\u0631\u0631 \u0648\u0644\u0627 \u0636\u0631\u0627\u0631 {i}", "source": f"Hadith {i}", "period": "HADITH", "century": 9})
        with open('data/raw/islamic/islamic_native.json', 'w', encoding='utf-8') as f:
            json.dump(islamic, f, ensure_ascii=False, indent=2)
        print(f"  Created {len(islamic)} passages")
    else:
        print("  Already exists")
    
    # DEAR ABBY
    print("\n[4/4] Dear Abby...")
    if not os.path.exists('data/raw/dear_abby.csv') or os.path.getsize('data/raw/dear_abby.csv') < 10000:
        # Check if in Drive
        if 'dear_abby.csv' in DRIVE_FILES:
            shutil.copy(f'{SAVE_DIR}/dear_abby.csv', 'data/raw/dear_abby.csv')
            print("  Loaded from Drive")
        else:
            # NOTE(review): assumes the Kaggle archive unpacks to
            # data/raw/dear_abby.csv with a 'question_only' column; confirm.
            try:
                subprocess.run(["kaggle", "datasets", "download", "-d", 
                               "samarthsarin/dear-abby-advice-column", 
                               "-p", "data/raw/", "--unzip"], check=True, timeout=120)
                print("  Downloaded from Kaggle")
            except (subprocess.SubprocessError, OSError) as e:
                # SubprocessError covers a non-zero exit (CalledProcessError) and
                # the 120 s timeout (TimeoutExpired); OSError covers a missing
                # `kaggle` executable. Ctrl-C / SystemExit are no longer swallowed.
                print("  Kaggle failed - creating minimal fallback")
                print(f"    ({type(e).__name__}: {e})")
                # Cell 4 keeps only questions of 50-2000 characters, so the
                # placeholder text must be at least 50 characters long or every
                # fallback row is filtered out.
                fallback = [{"question_only": f"Dear Abby, I have a problem with my family and do not know what to do ({i})", "year": 1990+i%30} for i in range(100)]
                pd.DataFrame(fallback).to_csv('data/raw/dear_abby.csv', index=False)
    else:
        print("  Already exists")
    
    print("\n" + "="*60)
    print("Downloads complete")
    print("="*60)

In [None]:
#@title 3. Patterns + Normalization { display-mode: "form" }
#@markdown Complete native patterns for moral concepts in 5 languages

# Defines the per-language text normalizers, the BondType / HohfeldState label
# sets, and the regex tables that cell 4 searches to label passages.

import re
import unicodedata
from enum import Enum, auto

print("="*60)
print("TEXT NORMALIZATION & PATTERNS")
print("="*60)

# ===== TEXT NORMALIZATION =====
def normalize_hebrew(text):
    """Normalize Hebrew/Aramaic text for pattern matching.

    Steps:
      1. NFKC normalization.
      2. The maqaf (U+05BE, the Hebrew hyphen) becomes a space.
      3. Cantillation marks and vowel points (U+0591-U+05C7) are removed.
      4. The five final letter forms (U+05DA, U+05DD, U+05DF, U+05E3, U+05E5)
         are mapped to their regular forms.

    After step 4, patterns only need to be written with regular letters.

    Args:
        text: Input string.

    Returns:
        The normalized string.
    """
    text = unicodedata.normalize('NFKC', text)
    # The maqaf lies inside the range stripped below. Deleting it would fuse
    # hyphenated words (e.g. pikuach-nefesh into one token) and defeat patterns
    # that expect a separator between them, so keep the word boundary instead.
    text = text.replace('\u05BE', ' ')
    text = re.sub(r'[\u0591-\u05C7]', '', text)  # Remove cantillation + nikud
    for final, regular in [('\u05da','\u05db'), ('\u05dd','\u05de'), ('\u05df','\u05e0'), ('\u05e3','\u05e4'), ('\u05e5','\u05e6')]:
        text = text.replace(final, regular)
    return text

# One-pass translation table for normalize_arabic:
#   - delete tashkeel (U+064B-U+065F) and tatweel (U+0640);
#   - fold hamza/madda/wasla alef variants to bare alef (U+0627);
#   - ta marbuta (U+0629) -> ha (U+0647), alef maqsura (U+0649) -> ya (U+064A).
# No target character is itself remapped, so one pass equals the sequential steps.
_ARABIC_NORMALIZATION_TABLE = {
    **{codepoint: None for codepoint in range(0x064B, 0x0660)},
    0x0640: None,
    **dict.fromkeys(map(ord, '\u0623\u0625\u0622\u0671'), '\u0627'),
    ord('\u0629'): '\u0647',
    ord('\u0649'): '\u064a',
}

def normalize_arabic(text):
    """Normalize Arabic text for pattern matching.

    Applies NFKC, removes diacritics and the elongation character, and
    collapses common orthographic variants (alef forms, ta marbuta,
    alef maqsura) to a single spelling.

    Args:
        text: Input string.

    Returns:
        The normalized string.
    """
    composed = unicodedata.normalize('NFKC', text)
    return composed.translate(_ARABIC_NORMALIZATION_TABLE)

def normalize_text(text, language):
    """Normalize *text* according to its *language* code.

    - 'arabic': delegated to normalize_arabic.
    - 'hebrew' / 'aramaic': delegated to normalize_hebrew.
    - 'classical_chinese': NFKC only (no case folding).
    - any other language: lower-cased, then NFKC.
    """
    if language == 'arabic':
        return normalize_arabic(text)
    if language in ('hebrew', 'aramaic'):
        return normalize_hebrew(text)
    folded = text if language == 'classical_chinese' else text.lower()
    return unicodedata.normalize('NFKC', folded)

# ===== BOND AND HOHFELD TYPES =====
class BondType(Enum):
    """Relational bond categories a passage can be labelled with.

    Declaration order is significant: index maps are built by enumerating the
    enum, and values run 1..10 exactly as ``auto()`` would assign them.
    NONE marks a passage with no matching bond pattern.
    """
    HARM_PREVENTION = 1
    RECIPROCITY = 2
    AUTONOMY = 3
    PROPERTY = 4
    FAMILY = 5
    AUTHORITY = 6
    CARE = 7
    FAIRNESS = 8
    CONTRACT = 9
    NONE = 10

class HohfeldState(Enum):
    """Deontic states used for the Hohfeld label of a passage.

    Declaration order is significant (index maps enumerate the enum); values
    run 1..4 exactly as ``auto()`` would assign them.
    """
    OBLIGATION = 1
    RIGHT = 2
    LIBERTY = 3
    NO_RIGHT = 4

# ===== COMPLETE BOND PATTERNS =====
# Per-language regexes, searched (re.search) in text that has already passed
# through normalize_text. Patterns must therefore be written in NORMALIZED form:
#   * Hebrew/Aramaic: no final letters (U+05DA/05DD/05DF/05E3/05E5). The
#     normalizer rewrites them to regular forms, so a pattern containing one can
#     never match. Earlier versions had this bug for the words natan, ratzon,
#     mamon and din. A maqaf may be deleted or turned into a space, so
#     multi-word patterns use an optional separator (".?").
#   * Arabic: no ta marbuta (U+0629), no hamza/madda alefs, and alef maqsura
#     only inside a character class. The normalizer rewrites these characters
#     (earlier versions had this bug for "irada").
#   * English: lower case.
# Within a language, the first BondType (in dict order) with a matching
# pattern wins, so earlier entries take precedence.
ALL_BOND_PATTERNS = {
    'hebrew': {
        BondType.HARM_PREVENTION: [r'\u05d4\u05e8\u05d2', r'\u05e8\u05e6\u05d7', r'\u05e0\u05d6\u05e7', r'\u05d4\u05db\u05d4', r'\u05d4\u05e6\u05d9\u05dc', r'\u05e9\u05de\u05e8', r'\u05e4\u05e7\u05d5\u05d7.?\u05e0\u05e4\u05e9'],
        BondType.RECIPROCITY: [r'\u05d2\u05de\u05d5\u05dc', r'\u05d4\u05e9\u05d9\u05d1', r'\u05e4\u05e8\u05e2', r'\u05e0\u05ea\u05e0.*\u05e7\u05d1\u05dc', r'\u05de\u05d3\u05d4.?\u05db\u05e0\u05d2\u05d3'],
        BondType.AUTONOMY: [r'\u05d1\u05d7\u05e8', r'\u05e8\u05e6\u05d5\u05e0', r'\u05d7\u05e4\u05e9', r'\u05e2\u05e6\u05de'],
        BondType.PROPERTY: [r'\u05e7\u05e0\u05d4', r'\u05de\u05db\u05e8', r'\u05d2\u05d6\u05dc', r'\u05d2\u05e0\u05d1', r'\u05de\u05de\u05d5\u05e0', r'\u05e0\u05db\u05e1', r'\u05d9\u05e8\u05e9'],
        BondType.FAMILY: [r'\u05d0\u05d1', r'\u05d0\u05de', r'\u05d1\u05e0', r'\u05db\u05d1\u05d3.*\u05d0\u05d1', r'\u05db\u05d1\u05d3.*\u05d0\u05de', r'\u05de\u05e9\u05e4\u05d7\u05d4', r'\u05d0\u05d7', r'\u05d0\u05d7\u05d5\u05ea'],
        BondType.AUTHORITY: [r'\u05de\u05dc\u05db', r'\u05e9\u05d5\u05e4\u05d8', r'\u05e6\u05d5\u05d4', r'\u05ea\u05d5\u05e8\u05d4', r'\u05de\u05e6\u05d5\u05d4', r'\u05d3\u05d9\u05e0', r'\u05d7\u05e7'],
        BondType.CARE: [r'\u05d7\u05e1\u05d3', r'\u05e8\u05d7\u05de', r'\u05e2\u05d6\u05e8', r'\u05ea\u05de\u05db', r'\u05e6\u05d3\u05e7\u05d4'],
        BondType.FAIRNESS: [r'\u05e6\u05d3\u05e7', r'\u05de\u05e9\u05e4\u05d8', r'\u05d9\u05e9\u05e8', r'\u05e9\u05d5\u05d4'],
        BondType.CONTRACT: [r'\u05d1\u05e8\u05d9\u05ea', r'\u05e0\u05d3\u05e8', r'\u05e9\u05d1\u05d5\u05e2', r'\u05d4\u05ea\u05d7\u05d9\u05d1', r'\u05e2\u05e8\u05d1'],
    },
    'aramaic': {
        BondType.HARM_PREVENTION: [r'\u05e7\u05d8\u05dc', r'\u05e0\u05d6\u05e7', r'\u05d7\u05d1\u05dc', r'\u05e9\u05d6\u05d9\u05d1', r'\u05e4\u05e6\u05d9'],
        BondType.RECIPROCITY: [r'\u05e4\u05e8\u05e2', r'\u05e9\u05dc\u05de', r'\u05d0\u05d2\u05e8'],
        BondType.AUTONOMY: [r'\u05e6\u05d1\u05d9', r'\u05e8\u05e2\u05d5'],
        BondType.PROPERTY: [r'\u05d6\u05d1\u05e0', r'\u05e7\u05e0\u05d4', r'\u05d2\u05d6\u05dc', r'\u05de\u05de\u05d5\u05e0\u05d0', r'\u05e0\u05db\u05e1\u05d9'],
        BondType.FAMILY: [r'\u05d0\u05d1\u05d0', r'\u05d0\u05de\u05d0', r'\u05d1\u05e8\u05d0', r'\u05d1\u05e8\u05ea\u05d0', r'\u05d9\u05e7\u05e8', r'\u05d0\u05d7\u05d0'],
        BondType.AUTHORITY: [r'\u05de\u05dc\u05db\u05d0', r'\u05d3\u05d9\u05e0\u05d0', r'\u05d3\u05d9\u05d9\u05e0\u05d0', r'\u05e4\u05e7\u05d5\u05d3\u05d0', r'\u05d0\u05d5\u05e8\u05d9\u05ea'],
        BondType.CARE: [r'\u05d7\u05e1\u05d3', r'\u05e8\u05d7\u05de', r'\u05e1\u05e2\u05d3'],
        BondType.FAIRNESS: [r'\u05d3\u05d9\u05e0\u05d0', r'\u05e7\u05e9\u05d5\u05d8', r'\u05ea\u05e8\u05d9\u05e6'],
        BondType.CONTRACT: [r'\u05e7\u05d9\u05de\u05d0', r'\u05e9\u05d1\u05d5\u05e2\u05d4', r'\u05e0\u05d3\u05e8\u05d0', r'\u05e2\u05e8\u05d1\u05d0'],
    },
    'classical_chinese': {
        BondType.HARM_PREVENTION: [r'\u6bba', r'\u5bb3', r'\u50b7', r'\u6551', r'\u8b77', r'\u885b', r'\u66b4'],
        BondType.RECIPROCITY: [r'\u5831', r'\u9084', r'\u511f', r'\u8ced', r'\u7b54'],
        BondType.AUTONOMY: [r'\u81ea', r'\u7531', r'\u4efb', r'\u610f', r'\u5fd7'],
        BondType.PROPERTY: [r'\u8ca1', r'\u7269', r'\u7522', r'\u76dc', r'\u7aca', r'\u8ce3', r'\u8cb7'],
        BondType.FAMILY: [r'\u5b5d', r'\u7236', r'\u6bcd', r'\u89aa', r'\u5b50', r'\u5f1f', r'\u5144', r'\u5bb6'],
        BondType.AUTHORITY: [r'\u541b', r'\u81e3', r'\u738b', r'\u547d', r'\u4ee4', r'\u6cd5', r'\u6cbb'],
        BondType.CARE: [r'\u4ec1', r'\u611b', r'\u6148', r'\u60e0', r'\u6069', r'\u6190'],
        BondType.FAIRNESS: [r'\u7fa9', r'\u6b63', r'\u516c', r'\u5e73', r'\u5747'],
        BondType.CONTRACT: [r'\u7d04', r'\u76df', r'\u8a93', r'\u8afe', r'\u4fe1'],
    },
    'arabic': {
        BondType.HARM_PREVENTION: [r'\u0642\u062a\u0644', r'\u0636\u0631\u0631', r'\u0627\u0630[\u064a\u0649]', r'\u0638\u0644\u0645', r'\u0627\u0646\u0642\u0630', r'\u062d\u0641\u0638', r'\u0627\u0645\u0627\u0646'],
        BondType.RECIPROCITY: [r'\u062c\u0632\u0627', r'\u0631\u062f', r'\u0642\u0635\u0627\u0635', r'\u0645\u062b\u0644', r'\u0639\u0648\u0636'],
        BondType.AUTONOMY: [r'\u062d\u0631', r'\u0627\u0631\u0627\u062f\u0647', r'\u0627\u062e\u062a\u064a\u0627\u0631', r'\u0645\u0634\u064a\u0626'],
        BondType.PROPERTY: [r'\u0645\u0627\u0644', r'\u0645\u0644\u0643', r'\u0633\u0631\u0642', r'\u0628\u064a\u0639', r'\u0634\u0631\u0627', r'\u0645\u064a\u0631\u0627\u062b', r'\u063a\u0635\u0628'],
        BondType.FAMILY: [r'\u0648\u0627\u0644\u062f', r'\u0627\u0628\u0648', r'\u0627\u0645', r'\u0627\u0628\u0646', r'\u0628\u0646\u062a', r'\u0627\u0647\u0644', r'\u0642\u0631\u0628[\u064a\u0649]', r'\u0631\u062d\u0645'],
        BondType.AUTHORITY: [r'\u0637\u0627\u0639', r'\u0627\u0645\u0631', r'\u062d\u0643\u0645', r'\u0633\u0644\u0637\u0627\u0646', r'\u062e\u0644\u064a\u0641', r'\u0627\u0645\u0627\u0645', r'\u0634\u0631\u064a\u0639'],
        BondType.CARE: [r'\u0631\u062d\u0645', r'\u0627\u062d\u0633\u0627\u0646', r'\u0639\u0637\u0641', r'\u0635\u062f\u0642', r'\u0632\u0643\u0627'],
        BondType.FAIRNESS: [r'\u0639\u062f\u0644', r'\u0642\u0633\u0637', r'\u062d\u0642', r'\u0627\u0646\u0635\u0627\u0641', r'\u0633\u0648[\u064a\u0649]'],
        BondType.CONTRACT: [r'\u0639\u0647\u062f', r'\u0639\u0642\u062f', r'\u0646\u0630\u0631', r'\u064a\u0645\u064a\u0646', r'\u0648\u0641\u0627', r'\u0627\u0645\u0627\u0646'],
    },
    'english': {
        BondType.HARM_PREVENTION: [r'\bkill', r'\bmurder', r'\bharm', r'\bhurt', r'\bsave', r'\bprotect', r'\bviolence'],
        BondType.RECIPROCITY: [r'\breturn', r'\brepay', r'\bexchange', r'\bgive.*back', r'\breciproc'],
        BondType.AUTONOMY: [r'\bfree', r'\bchoice', r'\bchoose', r'\bconsent', r'\bautonomy', r'\bright to'],
        BondType.PROPERTY: [r'\bsteal', r'\btheft', r'\bown', r'\bproperty', r'\bbelong', r'\binherit'],
        BondType.FAMILY: [r'\bfather', r'\bmother', r'\bparent', r'\bchild', r'\bfamily', r'\bhonor.*parent'],
        BondType.AUTHORITY: [r'\bobey', r'\bcommand', r'\bauthority', r'\blaw', r'\brule', r'\bgovern'],
        BondType.CARE: [r'\bcare', r'\bhelp', r'\bkind', r'\bcompassion', r'\bcharity', r'\bmercy'],
        BondType.FAIRNESS: [r'\bfair', r'\bjust', r'\bequal', r'\bequity', r'\bright\b'],
        BondType.CONTRACT: [r'\bpromise', r'\bcontract', r'\bagreem', r'\bvow', r'\boath', r'\bcommit'],
    },
}

# ===== COMPLETE HOHFELD PATTERNS =====
# Per-language regexes for the deontic state of a passage. They are searched in
# normalize_text output, so they follow the same normalized-form rules as the
# bond patterns (no Hebrew final letters, no Arabic ta marbuta, lower-case English).
#
# ORDER MATTERS: the first state (in dict insertion order) with a matching
# pattern wins. NO_RIGHT is listed first because its patterns are negations of
# patterns in other states. Examples: Chinese bu-ke vs ke, English "must not"
# vs "must", Arabic la-yajuz vs yajuz, Hebrew eino-rashai vs rashai. If NO_RIGHT
# were listed last, those negated forms would always be captured by the
# positive state and the NO_RIGHT negation patterns could never fire.
ALL_HOHFELD_PATTERNS = {
    'hebrew': {
        HohfeldState.NO_RIGHT: [r'\u05d0\u05e1\u05d5\u05e8', r'\u05d0\u05d9\u05e0\u05d5 \u05e8\u05e9\u05d0\u05d9', r'\u05d0\u05d9\u05e0.*\u05d6\u05db\u05d5\u05ea'],
        HohfeldState.OBLIGATION: [r'\u05d7\u05d9\u05d9\u05d1', r'\u05e6\u05e8\u05d9\u05db', r'\u05de\u05d5\u05db\u05e8\u05d7', r'\u05de\u05e6\u05d5\u05d5\u05d4'],
        HohfeldState.RIGHT: [r'\u05d6\u05db\u05d5\u05ea', r'\u05e8\u05e9\u05d0\u05d9', r'\u05d6\u05db\u05d0\u05d9', r'\u05de\u05d2\u05d9\u05e2'],
        HohfeldState.LIBERTY: [r'\u05de\u05d5\u05ea\u05e8', r'\u05e8\u05e9\u05d5\u05ea', r'\u05e4\u05d8\u05d5\u05e8', r'\u05d9\u05db\u05d5\u05dc'],
    },
    'aramaic': {
        HohfeldState.NO_RIGHT: [r'\u05d0\u05e1\u05d5\u05e8', r'\u05dc\u05d0.*\u05e8\u05e9\u05d0\u05d9'],
        HohfeldState.OBLIGATION: [r'\u05d7\u05d9\u05d9\u05d1', r'\u05de\u05d7\u05d5\u05d9\u05d1', r'\u05d1\u05e2\u05d9'],
        HohfeldState.RIGHT: [r'\u05d6\u05db\u05d5\u05ea', r'\u05e8\u05e9\u05d0\u05d9', r'\u05d6\u05db\u05d9'],
        HohfeldState.LIBERTY: [r'\u05e9\u05e8\u05d9', r'\u05de\u05d5\u05ea\u05e8', r'\u05e4\u05d8\u05d5\u05e8'],
    },
    'classical_chinese': {
        HohfeldState.NO_RIGHT: [r'\u4e0d\u53ef', r'\u52ff', r'\u7981', r'\u83ab', r'\u975e'],
        HohfeldState.OBLIGATION: [r'\u5fc5', r'\u9808', r'\u7576', r'\u61c9', r'\u5b9c'],
        HohfeldState.RIGHT: [r'\u53ef', r'\u5f97', r'\u6b0a', r'\u5b9c'],
        HohfeldState.LIBERTY: [r'\u8a31', r'\u4efb', r'\u807d', r'\u514d'],
    },
    'arabic': {
        HohfeldState.NO_RIGHT: [r'\u062d\u0631\u0627\u0645', r'\u0645\u062d\u0631\u0645', r'\u0645\u0645\u0646\u0648\u0639', r'\u0644\u0627 \u064a\u062c\u0648\u0632', r'\u0646\u0647[\u064a\u0649]'],
        HohfeldState.OBLIGATION: [r'\u064a\u062c\u0628', r'\u0648\u0627\u062c\u0628', r'\u0641\u0631\u0636', r'\u0644\u0627\u0632\u0645', r'\u0648\u062c\u0648\u0628'],
        HohfeldState.RIGHT: [r'\u062d\u0642', r'\u064a\u062d\u0642', r'\u062c\u0627\u0626\u0632', r'\u064a\u062c\u0648\u0632'],
        HohfeldState.LIBERTY: [r'\u0645\u0628\u0627\u062d', r'\u062d\u0644\u0627\u0644', r'\u062c\u0627\u0626\u0632', r'\u0627\u0628\u0627\u062d'],
    },
    'english': {
        HohfeldState.NO_RIGHT: [r'\bforbid', r'\bprohibit', r'\bmust not', r'\bshall not'],
        HohfeldState.OBLIGATION: [r'\bmust\b', r'\bshall\b', r'\bobligat', r'\bduty', r'\brequir'],
        HohfeldState.RIGHT: [r'\bright\b', r'\bentitle', r'\bdeserve', r'\bclaim'],
        HohfeldState.LIBERTY: [r'\bmay\b', r'\bpermit', r'\ballow', r'\bfree to'],
    },
}

# Summarize how many bond regexes each language contributes.
print("Patterns defined for 5 languages:")
for lang, bond_table in ALL_BOND_PATTERNS.items():
    n = sum(map(len, bond_table.values()))
    print(f"  {lang}: {n} bond patterns")

print("\n" + "="*60)

In [None]:
#@title 4. Load Corpora + Extract Bonds { display-mode: "form" }
#@markdown Loads all corpora - auto-detects GPU and adjusts sampling

# Builds data/processed/passages.jsonl and bonds.jsonl from the raw corpora,
# labels each passage with the native-language patterns defined in cell 3, and
# mirrors both files to SAVE_DIR. When cell 2 already restored them from Drive
# (SKIP_PROCESSING), the cell only reports per-language counts.
# Defines n_passages in both branches.

import json
import hashlib
import random
import gc
import shutil
from pathlib import Path
from collections import defaultdict
from tqdm.auto import tqdm

# Check if we should skip processing (data loaded from Drive)
if SKIP_PROCESSING:
    print("="*60)
    print("SKIPPING PROCESSING - Using Drive data")
    print("="*60)
    
    # Count passages by language (file is UTF-8 with non-ASCII text)
    by_lang = defaultdict(int)
    with open('data/processed/passages.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            p = json.loads(line)
            by_lang[p['language']] += 1
    
    print("\nPassages by language:")
    for lang, cnt in sorted(by_lang.items(), key=lambda x: -x[1]):
        print(f"  {lang}: {cnt:,}")
    
    n_passages = sum(by_lang.values())
    print(f"\nTotal: {n_passages:,} passages")

else:
    print("="*60)
    print("LOADING CORPORA")
    print(f"GPU Tier: {GPU_TIER}")
    print(f"Max per language: {MAX_PER_LANG:,}")
    print("="*60)
    
    random.seed(42)
    all_passages = []
    
    # ===== SEFARIA =====
    print("\nLoading Sefaria...")
    sefaria_path = Path('data/raw/Sefaria-Export/json')
    
    # Top-level export category -> period label (unknown categories: AMORAIC).
    CATEGORY_TO_PERIOD = {
        'Tanakh': 'BIBLICAL', 'Torah': 'BIBLICAL', 'Prophets': 'BIBLICAL', 'Writings': 'BIBLICAL',
        'Mishnah': 'TANNAITIC', 'Tosefta': 'TANNAITIC',
        'Talmud': 'AMORAIC', 'Bavli': 'AMORAIC', 'Yerushalmi': 'AMORAIC', 'Midrash': 'AMORAIC',
        'Halakhah': 'RISHONIM', 'Kabbalah': 'RISHONIM', 'Philosophy': 'RISHONIM',
        'Chasidut': 'ACHRONIM', 'Musar': 'ACHRONIM', 'Responsa': 'ACHRONIM',
    }
    
    def _flatten_sefaria(node, rel_path, lang, period, seen_ids):
        """Recursively collect passages from the `text` field of a Sefaria file.

        Strings are leaves; dicts and lists are walked in order. A leaf is kept
        if, after stripping HTML tags, it is 20-2000 characters long and has
        more than 5 characters in the Hebrew block (U+0590-U+05FF).

        Args:
            node: The (possibly nested) text structure.
            rel_path: File path relative to the export root; part of the ID key.
            lang: 'hebrew' or 'aramaic'.
            period: Period label for every passage from this file.
            seen_ids: Set of IDs already emitted; updated in place and used
                to drop exact duplicates.

        Returns:
            List of passage dicts with keys id, text, lang, period.
        """
        results = []
        if isinstance(node, str):
            tc = re.sub(r'<[^>]+>', '', node).strip()
            if 20 <= len(tc) <= 2000:
                hc = sum(1 for c in tc if '\u0590' <= c <= '\u05FF')
                if hc > 5:
                    # Key on the relative path plus the full text. The previous
                    # key (file stem + first 30 chars) is not unique: stems repeat
                    # across the export tree and formulaic openings recur, and
                    # colliding IDs corrupt the ID-based splits and bond lookups.
                    pid = hashlib.md5(f'{rel_path}|{tc}'.encode('utf-8')).hexdigest()[:12]
                    if pid not in seen_ids:  # identical passage in same file: keep once
                        seen_ids.add(pid)
                        results.append({'id': f'sef_{pid}', 'text': tc, 'lang': lang, 'period': period})
        elif isinstance(node, (dict, list)):
            for v in (node.values() if isinstance(node, dict) else node):
                results.extend(_flatten_sefaria(v, rel_path, lang, period, seen_ids))
        return results
    
    hebrew_ps, aramaic_ps = [], []
    seen_sefaria_ids = set()
    
    if sefaria_path.exists():
        for jf in tqdm(list(sefaria_path.rglob('*.json')), desc="Sefaria"):
            try:
                with open(jf, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except (OSError, ValueError):
                # Unreadable file, bad UTF-8 or invalid JSON (JSONDecodeError and
                # UnicodeDecodeError are ValueErrors): skip it.
                continue
            
            # Keep only Hebrew-language versions; also guard against JSON files
            # whose top level is not an object (.get would raise).
            if not isinstance(data, dict) or data.get('language') != 'he':
                continue
            
            txt = data.get('text', [])
            if not txt:
                continue
            
            rel = jf.relative_to(sefaria_path)
            cat = str(rel.parts[0]) if rel.parts else 'unknown'
            period = CATEGORY_TO_PERIOD.get(cat, 'AMORAIC')
            # Talmud files are labelled Aramaic; everything else Hebrew.
            is_talmud = 'Talmud' in str(jf) or cat in ['Bavli', 'Yerushalmi']
            lang = 'aramaic' if is_talmud else 'hebrew'
            
            ps = _flatten_sefaria(txt, rel, lang, period, seen_sefaria_ids)
            if lang == 'hebrew':
                hebrew_ps.extend(ps)
            else:
                aramaic_ps.extend(ps)
    
        # Sample down to MAX_PER_LANG per language (seeded shuffle)
        random.shuffle(hebrew_ps)
        random.shuffle(aramaic_ps)
        hebrew_ps = hebrew_ps[:MAX_PER_LANG]
        aramaic_ps = aramaic_ps[:MAX_PER_LANG]
        all_passages.extend(hebrew_ps)
        all_passages.extend(aramaic_ps)
        print(f"  Hebrew: {len(hebrew_ps):,}, Aramaic: {len(aramaic_ps):,}")
        del hebrew_ps, aramaic_ps
        gc.collect()
    else:
        print("  ERROR: Sefaria not found!")
    
    # ===== CHINESE =====
    print("\nLoading Chinese...")
    try:
        with open('data/raw/chinese/chinese_native.json', 'r', encoding='utf-8') as f:
            chinese_data = json.load(f)
        for item in chinese_data:
            all_passages.append({'id': item['id'], 'text': item['text'], 'lang': 'classical_chinese', 'period': item['period']})
        print(f"  Chinese: {len(chinese_data)}")
    except Exception as e:
        print(f"  Error: {e}")
    
    # ===== ISLAMIC =====
    print("\nLoading Islamic...")
    try:
        with open('data/raw/islamic/islamic_native.json', 'r', encoding='utf-8') as f:
            islamic_data = json.load(f)
        for item in islamic_data:
            all_passages.append({'id': item['id'], 'text': item['text'], 'lang': 'arabic', 'period': item['period']})
        print(f"  Arabic: {len(islamic_data)}")
    except Exception as e:
        print(f"  Error: {e}")
    
    # ===== DEAR ABBY =====
    print("\nLoading Dear Abby...")
    try:
        import pandas as pd
        df = pd.read_csv('data/raw/dear_abby.csv')
        abby_count = 0
        for idx, row in df.iterrows():
            q = str(row.get('question_only', ''))
            # Missing cells stringify to 'nan'; keep 50-2000 character questions.
            if q != 'nan' and 50 <= len(q) <= 2000:
                all_passages.append({'id': f'abby_{idx}', 'text': q, 'lang': 'english', 'period': 'DEAR_ABBY'})
                abby_count += 1
        print(f"  English: {abby_count:,}")
    except Exception as e:
        print(f"  Error: {e}")
    
    print(f"\nTOTAL: {len(all_passages):,}")
    
    # Count by language
    by_lang = defaultdict(int)
    for p in all_passages:
        by_lang[p['lang']] += 1
    print("\nBy language:")
    for lang, cnt in sorted(by_lang.items(), key=lambda x: -x[1]):
        print(f"  {lang}: {cnt:,}")
    
    # ===== EXTRACT BONDS =====
    print("\n" + "="*60)
    print("EXTRACTING BONDS")
    print("="*60)
    
    def extract_bond(text, language):
        """Return the name of the first BondType whose pattern matches, else 'NONE'.

        Patterns are searched in normalize_text(text, language); the dict order
        of the language's table sets precedence. Unknown languages yield 'NONE'.
        """
        tn = normalize_text(text, language)
        for bt, pats in ALL_BOND_PATTERNS.get(language, {}).items():
            if any(re.search(p, tn) for p in pats):
                return bt.name
        return 'NONE'
    
    def extract_hohfeld(text, language):
        """Return the name of the first HohfeldState whose pattern matches, else None."""
        tn = normalize_text(text, language)
        for st, pats in ALL_HOHFELD_PATTERNS.get(language, {}).items():
            if any(re.search(p, tn) for p in pats):
                return st.name
        return None
    
    bond_counts = defaultdict(lambda: defaultdict(int))
    
    with open('data/processed/passages.jsonl', 'w', encoding='utf-8') as fp, \
         open('data/processed/bonds.jsonl', 'w', encoding='utf-8') as fb:
        
        for p in tqdm(all_passages, desc="Extracting"):
            bond = extract_bond(p['text'], p['lang'])
            hohfeld = extract_hohfeld(p['text'], p['lang'])
            bond_counts[p['lang']][bond] += 1
            
            # 'source' and 'century' are constant placeholders, not derived
            # from the passage.
            fp.write(json.dumps({
                'id': p['id'], 'text': p['text'], 'language': p['lang'],
                'time_period': p['period'], 'source': 'x', 'source_type': 'sefaria' if p['id'].startswith('sef_') else 'other', 'century': 0
            }, ensure_ascii=False) + '\n')
            
            fb.write(json.dumps({
                'passage_id': p['id'],
                'bonds': {'primary_bond': bond, 'all_bonds': [bond], 'hohfeld': hohfeld, 'language': p['lang']}
            }, ensure_ascii=False) + '\n')
    
    # Coverage report
    print("\nLabel coverage:")
    for lang in sorted(bond_counts.keys()):
        total = sum(bond_counts[lang].values())
        none_ct = bond_counts[lang].get('NONE', 0)
        cov = (total - none_ct) / total * 100 if total else 0
        print(f"  {lang}: {cov:.1f}% labeled ({total-none_ct:,}/{total:,})")
    
    # ===== SAVE TO DRIVE =====
    print("\nSaving to Drive...")
    shutil.copy('data/processed/passages.jsonl', f'{SAVE_DIR}/passages.jsonl')
    shutil.copy('data/processed/bonds.jsonl', f'{SAVE_DIR}/bonds.jsonl')
    print("  Saved!")
    
    # Cell 5 reuses data/splits/all_splits.json whenever it exists, and cell 2
    # restores it from Drive. Splits from an earlier run reference the old
    # passage IDs, so drop them to force regeneration against the new passages.
    for stale_splits in ('data/splits/all_splits.json', f'{SAVE_DIR}/all_splits.json'):
        if os.path.exists(stale_splits):
            os.remove(stale_splits)
            print(f"  Removed stale splits: {stale_splits}")
    
    n_passages = len(all_passages)
    del all_passages
    gc.collect()

print("\n" + "="*60)
print(f"Cell 4 complete")
print("="*60)

In [None]:
#@title 5. Generate Splits { display-mode: "form" }
#@markdown Creates train/test splits for cross-lingual experiments

# Produces data/splits/all_splits.json with four train/test partitions of
# passage IDs:
#   - Hebrew -> others
#   - Semitic -> non-Semitic
#   - ancient -> modern
#   - seeded random 70/30 split
# An existing local splits file (copied from Drive in cell 2, or left over from
# an earlier run) is reused unchanged.

import json
import random
import shutil
from collections import defaultdict

print("="*60)
print("GENERATING SPLITS")
print("="*60)

# Reuse existing splits if present
if os.path.exists('data/splits/all_splits.json'):
    print("\nExisting splits found - reusing data/splits/all_splits.json")
    with open('data/splits/all_splits.json', encoding='utf-8') as f:
        all_splits = json.load(f)
    for name, split in all_splits.items():
        print(f"  {name}: train={split['train_size']:,}, test={split['test_size']:,}")
    splits_status = "Using existing splits (not regenerated)"
else:
    random.seed(42)
    
    # Read passage metadata. Only the fields used for splitting are kept; the
    # passage text is not needed here and can dominate memory at large
    # MAX_PER_LANG settings.
    passage_meta = []
    with open('data/processed/passages.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            p = json.loads(line)
            passage_meta.append({'id': p['id'], 'language': p['language'], 'time_period': p['time_period']})
    
    print(f"Total passages: {len(passage_meta):,}")
    
    by_lang = defaultdict(list)
    by_period = defaultdict(list)
    for p in passage_meta:
        by_lang[p['language']].append(p['id'])
        by_period[p['time_period']].append(p['id'])
    
    print("\nBy language:")
    for lang, ids in sorted(by_lang.items(), key=lambda x: -len(x[1])):
        print(f"  {lang}: {len(ids):,}")
    
    all_splits = {}
    
    # ===== SPLIT 1: Hebrew -> Others =====
    print("\n" + "-"*60)
    print("SPLIT 1: HEBREW -> OTHERS")
    hebrew_ids = by_lang.get('hebrew', [])
    other_ids = [p['id'] for p in passage_meta if p['language'] != 'hebrew']
    random.shuffle(hebrew_ids)
    random.shuffle(other_ids)
    
    all_splits['hebrew_to_others'] = {
        'train_ids': hebrew_ids,
        'test_ids': other_ids,
        'train_size': len(hebrew_ids),
        'test_size': len(other_ids),
    }
    print(f"  Train (Hebrew): {len(hebrew_ids):,}")
    print(f"  Test (Others): {len(other_ids):,}")
    
    # ===== SPLIT 2: Semitic -> Non-Semitic =====
    print("\n" + "-"*60)
    print("SPLIT 2: SEMITIC -> NON-SEMITIC")
    semitic_ids = by_lang.get('hebrew', []) + by_lang.get('aramaic', []) + by_lang.get('arabic', [])
    non_semitic_ids = by_lang.get('classical_chinese', []) + by_lang.get('english', [])
    random.shuffle(semitic_ids)
    random.shuffle(non_semitic_ids)
    
    all_splits['semitic_to_non_semitic'] = {
        'train_ids': semitic_ids,
        'test_ids': non_semitic_ids,
        'train_size': len(semitic_ids),
        'test_size': len(non_semitic_ids),
    }
    print(f"  Train (Semitic): {len(semitic_ids):,}")
    print(f"  Test (Non-Semitic): {len(non_semitic_ids):,}")
    
    # ===== SPLIT 3: Ancient -> Modern =====
    # Passages whose period is in neither set are left out of this split.
    print("\n" + "-"*60)
    print("SPLIT 3: ANCIENT -> MODERN")
    ancient_periods = {'BIBLICAL', 'TANNAITIC', 'AMORAIC', 'CONFUCIAN', 'DAOIST', 'QURANIC', 'HADITH'}
    modern_periods = {'RISHONIM', 'ACHRONIM', 'DEAR_ABBY'}
    
    ancient_ids = [p['id'] for p in passage_meta if p['time_period'] in ancient_periods]
    modern_ids = [p['id'] for p in passage_meta if p['time_period'] in modern_periods]
    random.shuffle(ancient_ids)
    random.shuffle(modern_ids)
    
    all_splits['ancient_to_modern'] = {
        'train_ids': ancient_ids,
        'test_ids': modern_ids,
        'train_size': len(ancient_ids),
        'test_size': len(modern_ids),
    }
    print(f"  Train (Ancient): {len(ancient_ids):,}")
    print(f"  Test (Modern): {len(modern_ids):,}")
    
    # ===== SPLIT 4: Mixed Baseline =====
    print("\n" + "-"*60)
    print("SPLIT 4: MIXED BASELINE")
    all_ids = [p['id'] for p in passage_meta]
    random.shuffle(all_ids)
    split_idx = int(0.7 * len(all_ids))
    
    all_splits['mixed_baseline'] = {
        'train_ids': all_ids[:split_idx],
        'test_ids': all_ids[split_idx:],
        'train_size': split_idx,
        'test_size': len(all_ids) - split_idx,
    }
    print(f"  Train: {split_idx:,}")
    print(f"  Test: {len(all_ids) - split_idx:,}")
    
    # Save splits
    with open('data/splits/all_splits.json', 'w', encoding='utf-8') as f:
        json.dump(all_splits, f, indent=2)
    
    # Save to Drive
    shutil.copy('data/splits/all_splits.json', f'{SAVE_DIR}/all_splits.json')
    splits_status = "Splits saved to local and Drive"

print("\n" + "="*60)
print(splits_status)
print("="*60)

In [None]:
#@title 6. Model Architecture { display-mode: "form" }
#@markdown BIP model with adversarial heads and complete Hohfeld support

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from tqdm.auto import tqdm
import json

print("=" * 60, "MODEL ARCHITECTURE", "=" * 60, sep="\n")

# Name <-> class-index lookups. Bond/Hohfeld indices follow the enums' iteration
# order; language and period indices follow the literal orderings below.
BOND_TO_IDX = {member.name: pos for pos, member in enumerate(BondType)}
IDX_TO_BOND = dict(enumerate(member.name for member in BondType))
LANG_TO_IDX = {name: pos for pos, name in enumerate(
    ('hebrew', 'aramaic', 'classical_chinese', 'arabic', 'english'))}
IDX_TO_LANG = {pos: name for name, pos in LANG_TO_IDX.items()}
PERIOD_TO_IDX = {name: pos for pos, name in enumerate((
    'BIBLICAL', 'TANNAITIC', 'AMORAIC', 'RISHONIM', 'ACHRONIM',
    'CONFUCIAN', 'DAOIST', 'QURANIC', 'HADITH', 'DEAR_ABBY',
))}
IDX_TO_PERIOD = {pos: name for name, pos in PERIOD_TO_IDX.items()}
HOHFELD_TO_IDX = {member.name: pos for pos, member in enumerate(HohfeldState)}
IDX_TO_HOHFELD = dict(enumerate(member.name for member in HohfeldState))

class GradientReversalLayer(torch.autograd.Function):
    """Identity on the forward pass; scales incoming gradients by ``-alpha`` on the backward pass.

    Call as ``GradientReversalLayer.apply(x, alpha)``. Heads placed after it are trained
    normally, while everything upstream receives the negated (and ``alpha``-scaled)
    gradient of their loss.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Stash the scale for backward; view_as yields a new autograd node over the same data.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # alpha is a plain Python number, so it gets no gradient (None).
        reversed_grad = grad_output.mul(-ctx.alpha)
        return reversed_grad, None

class BIPModel(nn.Module):
    """Encoder + z-projection with bond/Hohfeld task heads and adversarial language/period heads.

    The adversarial heads read ``z`` through ``GradientReversalLayer``, so minimizing their
    cross-entropy pushes the encoder/projection to remove language and period information
    from ``z``.

    Args:
        z_dim: Dimensionality of the projected representation ``z``.
        model_name: Hugging Face checkpoint used for the encoder.
        pooling: ``"cls"`` (default) uses the first token's final hidden state; ``"mean"``
            averages token states over non-padding positions (the pooling used by the
            sentence-transformers release of the default checkpoint). Neither option adds
            parameters, so saved state dicts load under both.

    Raises:
        ValueError: if ``pooling`` is not ``"cls"`` or ``"mean"``.
    """

    def __init__(self, z_dim=64,
                 model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
                 pooling="cls"):
        super().__init__()
        if pooling not in ("cls", "mean"):
            raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")
        self.pooling = pooling
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size  # 384 for the default checkpoint

        # Projection to z_bond space
        self.z_proj = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, z_dim),
        )

        # Task heads
        self.bond_head = nn.Linear(z_dim, len(BondType))
        self.hohfeld_head = nn.Linear(z_dim, len(HohfeldState))

        # Adversarial heads (fed through gradient reversal in forward)
        self.language_head = nn.Linear(z_dim, len(LANG_TO_IDX))
        self.period_head = nn.Linear(z_dim, len(PERIOD_TO_IDX))

    def forward(self, input_ids, attention_mask, adv_lambda=1.0):
        """Run the encoder and all heads.

        Args:
            input_ids: token ids, (batch, seq_len) as produced by ``collate_fn``.
            attention_mask: matching 0/1 padding mask.
            adv_lambda: gradient-reversal strength for the adversarial heads. It only affects
                gradients; 0 zeroes the reversed gradient, and forward outputs do not change.

        Returns:
            dict with logits ``bond_pred``, ``hohfeld_pred``, ``language_pred``,
            ``period_pred`` and the representation ``z``.
        """
        # Keyword arguments: for some architectures (e.g. GPT-2) the second positional
        # parameter of forward() is not attention_mask.
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = enc.last_hidden_state

        if self.pooling == "mean":
            # Masked mean over real (non-padding) tokens; clamp avoids 0/0 on all-pad rows.
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            pooled = hidden_states[:, 0]  # CLS / first token

        z = self.z_proj(pooled)

        # Bond prediction (main task)
        bond_pred = self.bond_head(z)
        hohfeld_pred = self.hohfeld_head(z)

        # Adversarial predictions (gradient reversal)
        z_rev = GradientReversalLayer.apply(z, adv_lambda)
        language_pred = self.language_head(z_rev)
        period_pred = self.period_head(z_rev)

        return {
            'bond_pred': bond_pred,
            'hohfeld_pred': hohfeld_pred,
            'language_pred': language_pred,
            'period_pred': period_pred,
            'z': z,
        }

# Dataset with Hohfeld support
class NativeDataset(Dataset):
    """Tokenize-on-demand dataset joining passages with their bond labels.

    ``passages_file`` and ``bonds_file`` are JSONL files read in lockstep: line *k* of the
    bonds file is expected to describe line *k* of the passages file (checked via
    ``passage_id``). Only passages whose ``id`` is in ``ids_set`` are kept.

    Args:
        ids_set: passage ids to keep (a ``set`` gives O(1) membership tests).
        passages_file: passages JSONL (keys read: id, text, language, time_period).
        bonds_file: bonds JSONL (keys read: passage_id, bonds.primary_bond, bonds.hohfeld).
        tokenizer: Hugging Face tokenizer applied in ``__getitem__``.
        max_len: token length to truncate/pad each sample to.
    """

    def __init__(self, ids_set, passages_file, bonds_file, tokenizer, max_len=128):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.data = []
        n_misaligned = 0

        with open(passages_file, encoding='utf-8') as fp, open(bonds_file, encoding='utf-8') as fb:
            for p_line, b_line in tqdm(zip(fp, fb), desc="Loading", unit="line"):
                p = json.loads(p_line)
                if p['id'] not in ids_set:
                    continue  # not in this split; no need to parse the bond line
                b = json.loads(b_line)
                if b['passage_id'] != p['id']:
                    # Files are out of step; this selected passage cannot be labelled.
                    n_misaligned += 1
                    continue
                self.data.append({
                    'text': p['text'][:1000],
                    'language': p['language'],
                    'period': p['time_period'],
                    'bond': b['bonds']['primary_bond'],
                    'hohfeld': b['bonds']['hohfeld'],
                })
        if n_misaligned:
            print(f"  WARNING: skipped {n_misaligned:,} selected passages whose bonds line "
                  f"has a different passage_id (passages/bonds files out of step)")
        print(f"  Loaded {len(self.data):,} samples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize one sample (padded to ``max_len``) and map its string labels to indices."""
        item = self.data[idx]
        enc = self.tokenizer(item['text'], truncation=True, max_length=self.max_len,
                             padding='max_length', return_tensors='pt')
        return {
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            # NOTE(review): unknown labels fall back to fixed indices (bond 9, language 4 =
            # 'english', period 9 = 'DEAR_ABBY'); confirm this is intended and that
            # BondType has at least 10 members.
            'bond_label': BOND_TO_IDX.get(item['bond'], 9),
            'language_label': LANG_TO_IDX.get(item['language'], 4),
            'period_label': PERIOD_TO_IDX.get(item['period'], 9),
            'hohfeld_label': HOHFELD_TO_IDX.get(item['hohfeld'], 0) if item['hohfeld'] else 0,
            'language': item['language'],
        }

def collate_fn(batch):
    """Merge a list of ``NativeDataset`` samples into one batch dict.

    Token tensors are stacked along a new batch dimension, integer labels become 1-D
    tensors under ``<name>_labels`` keys, and raw language strings stay a plain list.
    """
    def column(key):
        return [sample[key] for sample in batch]

    stacked = {key: torch.stack(column(key)) for key in ('input_ids', 'attention_mask')}
    label_tensors = {
        f'{name}_labels': torch.tensor(column(f'{name}_label'))
        for name in ('bond', 'language', 'period', 'hohfeld')
    }
    return {**stacked, **label_tensors, 'languages': column('language')}

# Report the size of each label space the heads were built for.
print(
    "Model architecture defined",
    f"  Bond classes: {len(BondType)}",
    f"  Hohfeld states: {len(HohfeldState)}",
    f"  Languages: {len(LANG_TO_IDX)}",
    f"  Periods: {len(PERIOD_TO_IDX)}",
    sep="\n",
)
print("\n" + "=" * 60)

In [None]:
#@title 7. Train BIP Model { display-mode: "form" }
#@markdown Training with tuned adversarial weights and hardware-optimized parameters

from sklearn.metrics import f1_score
import gc

#@markdown **Splits to train:**
# Each enabled flag adds the matching all_splits key to splits_to_train below.
TRAIN_HEBREW_TO_OTHERS = True  #@param {type:"boolean"}
TRAIN_SEMITIC_TO_NON_SEMITIC = True  #@param {type:"boolean"}
TRAIN_ANCIENT_TO_MODERN = True  #@param {type:"boolean"}
TRAIN_MIXED_BASELINE = True  #@param {type:"boolean"}

#@markdown **Hyperparameters:**
# Multipliers on the adversarial language / period cross-entropy terms in the total loss
# (loss = loss_bond + LANG_WEIGHT * loss_lang + PERIOD_WEIGHT * loss_period).
LANG_WEIGHT = 0.01  #@param {type:"number"}
PERIOD_WEIGHT = 0.01  #@param {type:"number"}
# Number of passes over each split's training set.
N_EPOCHS = 5  #@param {type:"integer"}

print("=" * 60, "TRAINING BIP MODEL", "=" * 60, sep="\n")
print(
    "\nHardware-optimized settings:",
    f"  GPU Tier:     {GPU_TIER}",
    f"  Batch size:   {BATCH_SIZE}",
    f"  Workers:      {NUM_WORKERS}",
    f"  Learning rate: {LR:.2e}",
    f"  Adv weights:  lang={LANG_WEIGHT}, period={PERIOD_WEIGHT}",
    "(0.01 prevents loss explosion while maintaining invariance)",
    sep="\n",
)

# Tokenizer for the same multilingual checkpoint the BIPModel encoder loads.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# Split definitions (train/test id lists) written by the splits cell.
with open('data/splits/all_splits.json') as f:
    all_splits = json.load(f)

# Keep, in this fixed order, every split whose TRAIN_* toggle is enabled.
splits_to_train = [
    split_key
    for split_key, wanted in (
        ('hebrew_to_others', TRAIN_HEBREW_TO_OTHERS),
        ('semitic_to_non_semitic', TRAIN_SEMITIC_TO_NON_SEMITIC),
        ('ancient_to_modern', TRAIN_ANCIENT_TO_MODERN),
        ('mixed_baseline', TRAIN_MIXED_BASELINE),
    )
    if wanted
]

print(f"\nTraining {len(splits_to_train)} splits: {splits_to_train}")

# split_name -> metrics dict, filled in by the training loop.
all_results = {}

import math  # finite-loss check for checkpoint selection


def get_adv_lambda(epoch, warmup=2):
    """Return the gradient-reversal strength for a 1-indexed ``epoch``.

    Ramps linearly from 0.1 at epoch 1 to 1.0 at epoch ``warmup + 1``, then stays at 1.0.
    Epochs are counted from 1, hence ``epoch - 1``.
    """
    if warmup <= 0:
        return 1.0
    return 0.1 + 0.9 * min(1.0, (epoch - 1) / warmup)


# Uniform chance levels derived from the label spaces (instead of hard-coded 0.1 / 20%).
BOND_CHANCE = 1.0 / len(BondType)
LANG_CHANCE = 1.0 / len(LANG_TO_IDX)

for split_idx, split_name in enumerate(splits_to_train):
    split_start = time.time()
    print("\n" + "="*60)
    print(f"[{split_idx+1}/{len(splits_to_train)}] {split_name}")
    print("="*60)

    split = all_splits[split_name]
    print(f"Train: {split['train_size']:,} | Test: {split['test_size']:,}")

    if split['test_size'] < 50:
        print("Test set too small - skipping")
        continue

    # Load data before building the model so the skip paths below don't strand a model on the GPU.
    train_dataset = NativeDataset(set(split['train_ids']), 'data/processed/passages.jsonl',
                                  'data/processed/bonds.jsonl', tokenizer)
    test_dataset = NativeDataset(set(split['test_ids'][:MAX_TEST_SAMPLES]), 'data/processed/passages.jsonl',
                                 'data/processed/bonds.jsonl', tokenizer)

    if len(train_dataset) == 0:
        print("ERROR: No training data!")
        continue
    if len(test_dataset) == 0:
        # Split ids matched nothing in the processed files; the metrics below would divide by zero.
        print("ERROR: No test data!")
        continue

    # Hardware-optimized batch size, capped at the dataset size: with drop_last=True a batch
    # larger than the dataset would yield zero training batches.
    actual_batch = min(BATCH_SIZE, max(32, len(train_dataset) // 20), len(train_dataset))
    print(f"Actual batch size: {actual_batch}")

    train_loader = DataLoader(train_dataset, batch_size=actual_batch, shuffle=True,
                              collate_fn=collate_fn, drop_last=True, num_workers=NUM_WORKERS, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=actual_batch*2, shuffle=False,
                             collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=True)

    model = BIPModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)

    local_ckpt = f'models/checkpoints/best_{split_name}.pt'
    best_loss = float('inf')

    for epoch in range(1, N_EPOCHS + 1):
        model.train()
        adv_lambda = get_adv_lambda(epoch)  # constant within an epoch
        total_loss = 0.0
        n_batches = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            bond_labels = batch['bond_labels'].to(device)
            language_labels = batch['language_labels'].to(device)
            period_labels = batch['period_labels'].to(device)

            with torch.amp.autocast('cuda', enabled=USE_AMP):
                out = model(input_ids, attention_mask, adv_lambda=adv_lambda)

                loss_bond = F.cross_entropy(out['bond_pred'], bond_labels)
                loss_lang = F.cross_entropy(out['language_pred'], language_labels)
                loss_period = F.cross_entropy(out['period_pred'], period_labels)

            # NOTE(review): out['hohfeld_pred'] has no loss term, so the Hohfeld head receives
            # no gradient and batch['hohfeld_labels'] is unused.
            loss = loss_bond + LANG_WEIGHT * loss_lang + PERIOD_WEIGHT * loss_period

            if USE_AMP and scaler:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)  # clip true (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        avg_loss = total_loss / n_batches  # >= 1 batch: batch size <= dataset size
        print(f"Epoch {epoch}: Loss={avg_loss:.4f} (adv_lambda={adv_lambda:.2f})")

        # NaN compares False against everything, so test finiteness explicitly.
        if math.isfinite(avg_loss) and avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), local_ckpt)
            torch.save(model.state_dict(), f'{SAVE_DIR}/best_{split_name}.pt')

    # Evaluate
    print("\nEvaluating...")
    if math.isfinite(best_loss):
        model.load_state_dict(torch.load(local_ckpt, map_location=device))
    else:
        # Nothing was saved this run; loading would fail or pick up a stale checkpoint.
        print("WARNING: no epoch had a finite loss - evaluating final weights")
    model.eval()

    all_preds = {'bond': [], 'lang': []}
    all_labels = {'bond': [], 'lang': []}
    all_languages = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            out = model(batch['input_ids'].to(device), batch['attention_mask'].to(device), 0)
            all_preds['bond'].extend(out['bond_pred'].argmax(-1).cpu().tolist())
            all_preds['lang'].extend(out['language_pred'].argmax(-1).cpu().tolist())
            all_labels['bond'].extend(batch['bond_labels'].tolist())
            all_labels['lang'].extend(batch['language_labels'].tolist())
            all_languages.extend(batch['languages'])

    n_test = len(all_preds['bond'])  # > 0: test_dataset is non-empty
    bond_f1 = f1_score(all_labels['bond'], all_preds['bond'], average='macro', zero_division=0)
    bond_acc = sum(p == l for p, l in zip(all_preds['bond'], all_labels['bond'])) / n_test
    lang_acc = sum(p == l for p, l in zip(all_preds['lang'], all_labels['lang'])) / n_test

    # Per-language F1 for languages with more than 10 test samples (one pass to group rows).
    rows_by_lang = {}
    for row, lang in enumerate(all_languages):
        rows_by_lang.setdefault(lang, []).append(row)
    lang_f1 = {}
    for lang, rows in rows_by_lang.items():
        if len(rows) > 10:
            labels = [all_labels['bond'][i] for i in rows]
            preds = [all_preds['bond'][i] for i in rows]
            lang_f1[lang] = {'f1': f1_score(labels, preds, average='macro', zero_division=0), 'n': len(rows)}

    all_results[split_name] = {
        'bond_f1_macro': bond_f1,
        'bond_acc': bond_acc,
        'language_acc': lang_acc,
        'per_language_f1': lang_f1,
        'training_time': time.time() - split_start
    }

    print(f"\n{split_name} RESULTS:")
    print(f"  Bond F1 (macro): {bond_f1:.3f} ({bond_f1/BOND_CHANCE:.1f}x chance)")
    print(f"  Bond accuracy:   {bond_acc:.1%}")
    # NOTE: the language head only trains on training-split examples, so when the test
    # languages never occur in training its accuracy can sit near 0% without implying
    # invariance; the linear-probe cell is the better invariance check.
    print(f"  Language acc:    {lang_acc:.1%} (want ~{LANG_CHANCE:.0%} = invariant)")
    print("  Per-language:")
    for lang, m in sorted(lang_f1.items(), key=lambda x: -x[1]['n']):
        print(f"    {lang:20s}: F1={m['f1']:.3f} (n={m['n']:,})")

    # GPU memory usage
    if torch.cuda.is_available():
        mem = torch.cuda.memory_allocated() / 1e9
        print(f"\n  GPU memory: {mem:.1f} GB / {VRAM_GB:.1f} GB ({mem/VRAM_GB*100:.0f}%)")

    # Drop every reference to this split's model (the optimizer holds its parameters too)
    # before the next split allocates a fresh one.
    del model, optimizer, train_loader, test_loader, train_dataset, test_dataset
    gc.collect()
    torch.cuda.empty_cache()

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

In [None]:
#@title 8. Linear Probe Test { display-mode: "form" }
#@markdown Tests if z_bond encodes language/period (should be low = invariant)

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import numpy as np

print("="*60)
print("LINEAR PROBE TEST")
print("="*60)
print("\nIf probe accuracy is NEAR CHANCE, representation is INVARIANT")
print("(This is what we want for BIP)")

PROBE_MARGIN = 0.15  # accuracy above the majority-class baseline still counted as invariant


def _linear_probe(X_train, y_train, X_test, y_test):
    """Fit a logistic-regression probe; return ``(accuracy, chance)`` or None if it can't run.

    ``chance`` is the majority-class baseline: the test accuracy of always predicting the
    most frequent training label. A uniform 1/K baseline understates chance on imbalanced
    label sets and would flag a representation with no label information as leaking.
    Returns None when the train or test labels contain fewer than two classes
    (LogisticRegression.fit rejects a single training class).
    """
    if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
        return None
    probe = LogisticRegression(max_iter=1000, n_jobs=-1)
    probe.fit(X_train, y_train)
    acc = float((probe.predict(X_test) == y_test).mean())
    values, counts = np.unique(y_train, return_counts=True)
    chance = float((y_test == values[counts.argmax()]).mean())
    return acc, chance


probe_results = {}

for split_name in ['hebrew_to_others', 'semitic_to_non_semitic']:
    model_path = f'{SAVE_DIR}/best_{split_name}.pt'
    if not os.path.exists(model_path):
        print(f"\nSkipping {split_name} - no saved model")
        continue

    print(f"\n{'='*50}")
    print(f"PROBE: {split_name}")
    print('='*50)

    # Load data first so the "too few samples" skip doesn't leave a model on the GPU.
    test_ids = set(all_splits[split_name]['test_ids'][:5000])
    test_dataset = NativeDataset(test_ids, 'data/processed/passages.jsonl',
                                 'data/processed/bonds.jsonl', tokenizer)

    if len(test_dataset) < 50:
        print(f"  Skip - only {len(test_dataset)} samples")
        continue

    model = BIPModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    test_loader = DataLoader(test_dataset, batch_size=128, collate_fn=collate_fn, num_workers=0)

    all_z, all_lang, all_period = [], [], []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Extract"):
            out = model(batch['input_ids'].to(device), batch['attention_mask'].to(device), 0)
            all_z.append(out['z'].cpu().numpy())
            all_lang.extend(batch['language_labels'].tolist())
            all_period.extend(batch['period_labels'].tolist())

    X = np.vstack(all_z)
    y_lang = np.array(all_lang)
    y_period = np.array(all_period)

    # 70/30 train/test split for the probes
    n = len(X)
    idx = np.random.permutation(n)
    cut = int(0.7 * n)
    train_idx, test_idx = idx[:cut], idx[cut:]

    # Fit the scaler on probe-training rows only so no test statistics leak into the probe.
    scaler_probe = StandardScaler().fit(X[train_idx])
    X_train = scaler_probe.transform(X[train_idx])
    X_test = scaler_probe.transform(X[test_idx])

    entry = {}
    for key, y in (('language', y_lang), ('period', y_period)):
        result = _linear_probe(X_train, y[train_idx], X_test, y[test_idx])
        if result is None:
            print(f"  SKIP {key} probe - fewer than 2 classes in probe train/test split")
            acc = chance = 1.0 / max(1, len(np.unique(y)))
            status = "SKIPPED"
        else:
            acc, chance = result
            status = "INVARIANT" if acc < chance + PROBE_MARGIN else "NOT invariant"
        entry[f'{key}_acc'] = acc
        entry[f'{key}_chance'] = chance
        entry[f'{key}_status'] = status
    probe_results[split_name] = entry

    print("\nRESULTS:")
    print(f"  Language: {entry['language_acc']:.1%} (chance: {entry['language_chance']:.1%}) "
          f"-> {entry['language_status']}")
    print(f"  Period:   {entry['period_acc']:.1%} (chance: {entry['period_chance']:.1%}) "
          f"-> {entry['period_status']}")

    del model, test_loader, test_dataset
    torch.cuda.empty_cache()

print("\n" + "="*60)
print("Probe tests complete")
print("="*60)

In [None]:
#@title 9. Final Results { display-mode: "form" }
#@markdown Comprehensive summary with verdict

import json
import shutil

print("="*60)
print("FINAL BIP EVALUATION (v10.1)")
print("="*60)

print(f"\nHardware: {GPU_TIER} ({VRAM_GB:.0f}GB VRAM, {RAM_GB:.0f}GB RAM)")

print("\n" + "-"*60)
print("CROSS-DOMAIN TRANSFER RESULTS")
print("-"*60)

# Chance levels derived from the label spaces rather than hard-coding 10 bonds / 5 languages.
bond_chance = 1.0 / len(BondType)
lang_chance_head = 1.0 / len(LANG_TO_IDX)
TRANSFER_RATIO = 1.3      # bond macro-F1 must beat uniform chance by this factor
INVARIANCE_MARGIN = 0.15  # allowed language-head accuracy above uniform chance

# probe_results only exists if the linear-probe cell was run.
probe_lookup = globals().get('probe_results', {})

successful_splits = []
for name, r in all_results.items():
    ratio = r['bond_f1_macro'] / bond_chance
    lang_acc = r['language_acc']

    transfer_ok = ratio > TRANSFER_RATIO

    # The adversarial language head is only trained on training-split examples; when the
    # test languages never appear in training (e.g. Semitic -> non-Semitic) its test
    # accuracy is ~0% regardless of what z encodes. Prefer the linear probe's verdict.
    probe_status = probe_lookup.get(name, {}).get('language_status')
    if probe_status in ('INVARIANT', 'NOT invariant'):
        invariant_ok = probe_status == 'INVARIANT'
        inv_source = 'linear probe'
    else:
        invariant_ok = lang_acc < lang_chance_head + INVARIANCE_MARGIN
        inv_source = 'adversarial head'

    status = "SUCCESS" if (transfer_ok and invariant_ok) else "PARTIAL" if transfer_ok else "FAIL"

    print(f"\n{name}:")
    print(f"  Bond F1:     {r['bond_f1_macro']:.3f} ({ratio:.1f}x chance) {'OK' if transfer_ok else 'WEAK'}")
    print(f"  Language:    {lang_acc:.1%} {'INVARIANT' if invariant_ok else 'LEAKING'} (judged by {inv_source})")
    print(f"  -> {status}")

    if transfer_ok and invariant_ok:
        successful_splits.append(name)

print("\n" + "-"*60)
print("VERDICT")
print("-"*60)

n_success = len(successful_splits)
if n_success >= 2:
    verdict = "STRONGLY_SUPPORTED"
    msg = "Multiple independent transfer paths demonstrate universal structure"
elif n_success >= 1:
    verdict = "SUPPORTED"
    msg = "At least one transfer path works"
elif any(r['bond_f1_macro'] / bond_chance > TRANSFER_RATIO for r in all_results.values()):
    verdict = "PARTIAL"
    msg = "Some transfer signal, but not robust"
else:
    verdict = "INCONCLUSIVE"
    msg = "No clear transfer demonstrated"

print(f"\n  Successful transfers: {n_success}/{len(all_results)}")
print(f"  Splits: {successful_splits if successful_splits else 'None'}")
print(f"\n  VERDICT: {verdict}")
print(f"  {msg}")

# Save results
final_results = {
    'all_results': all_results,
    'probe_results': probe_lookup,
    'successful_splits': successful_splits,
    'verdict': verdict,
    'hardware': {'gpu': GPU_TIER, 'vram_gb': VRAM_GB, 'ram_gb': RAM_GB},
    'settings': {'batch_size': BATCH_SIZE, 'max_per_lang': MAX_PER_LANG, 'num_workers': NUM_WORKERS},
    'experiment_time': time.time() - EXPERIMENT_START,
}

os.makedirs('results', exist_ok=True)
with open('results/final_results.json', 'w') as f:
    json.dump(final_results, f, indent=2, default=str)
shutil.copy('results/final_results.json', f'{SAVE_DIR}/final_results.json')

print(f"\nTotal time: {(time.time() - EXPERIMENT_START)/60:.1f} minutes")
print("Results saved to Drive!")
print("\n" + "="*60)

In [None]:
#@title 10. Download Results { display-mode: "form" }
#@markdown Download all models and results

from google.colab import files
import zipfile

print("Creating download package...")

with zipfile.ZipFile('BIP_v10.1_results.zip', 'w', zipfile.ZIP_DEFLATED) as archive:
    # Final metrics JSON, when the results cell has produced it.
    if os.path.exists('results/final_results.json'):
        archive.write('results/final_results.json')

    # Every checkpoint in the Drive folder, stored under models/ inside the archive.
    for ckpt in (entry for entry in os.listdir(SAVE_DIR) if entry.endswith('.pt')):
        archive.write(f'{SAVE_DIR}/{ckpt}', f'models/{ckpt}')

    # Train/test split definitions, when present.
    if os.path.exists('data/splits/all_splits.json'):
        archive.write('data/splits/all_splits.json')

print("\nDownload ready!")
files.download('BIP_v10.1_results.zip')