# Setup

In [1]:
# Automatically reload edited local modules (e.g. generation_utils, data_utils)
# before each cell is executed.
%load_ext autoreload
%autoreload 2

In [2]:
# CONFIGURATION INFORMATION

# Layers and DAS subspace dimensions to sweep.
# Full sweep for reference: layers range(16), dimensions [2, 4, 8, 16].
INV_LAYERS = list(range(9, 13))
INV_DIMENSIONS = [2 ** k for k in range(1, 4)]

# Earlier sweep: INV_LAYERS = [2, 3, 4, 6, 7, 8]; INV_DIMENSIONS = [2, 4, 8]

# Name used for this experiment's output directories
# (an alternative used before: "gold_matching_years_only").
EXPERIMENT_NAME = "year_localization"

In [3]:
import getpass
import os
import sys

# Derive per-user scratch paths from the login name instead of a hard-coded user,
# consistent with the later library-path cell, which also uses getpass.getuser().
USER = getpass.getuser()

PROJECT_NAME = 'ood-prediction'
DATA_DIR = f'/scr-ssd/{USER}/data'
MODEL_DIR = f'/scr-ssd/{USER}/models'  # passed as cache_dir when loading the model

# Hugging Face cache locations. NOTE(review): HF_HUB does not appear to be a
# variable huggingface_hub reads (HF_HUB_CACHE is); the explicit cache_dir=MODEL_DIR
# used when loading the model/tokenizer is what controls where weights are cached.
os.environ["HF_HOME"] = MODEL_DIR
os.environ["HF_HUB"] = MODEL_DIR

# Shared library checkouts: core utilities, RAVEL and pyvene.
CORE_LIB_DIR = '/nlp/scr/hij/core'
RAVEL_LIB_DIR = '/nlp/scr/hij/internal-ravel/src'
PYVENE_LIB_DIR = '/nlp/scr/hij/pyvene'

# Append (in this order) only paths not already present, so re-running the cell
# does not keep growing sys.path.
for _path in (f'/nlp/scr/{USER}/{PROJECT_NAME}/src', CORE_LIB_DIR, RAVEL_LIB_DIR, PYVENE_LIB_DIR):
    if _path not in sys.path:
        sys.path.append(_path)

In [4]:
import numpy as np
import random
import torch

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) global RNGs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

set_seed(0)

In [5]:
import torch
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Models

In [6]:
import gc

# Free memory before (re)loading the model. When re-running with a model already
# in memory, uncomment `del model` so its tensors become collectable.
#del model
gc.collect()
# Release unoccupied cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [8]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


model_id = "allenai/OLMo-2-0425-1B"
# Checkpoint revision. It is passed to BOTH the tokenizer and the model so they
# always come from the same checkpoint (it is also embedded in the data-split file name).
revision = "main"
tokenizer = AutoTokenizer.from_pretrained(
    model_id, padding_side='left', revision=revision,
    cache_dir=MODEL_DIR)
# Reuse EOS as the padding token for batched, left-padded inputs.
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(
    model_id, revision=revision, low_cpu_mem_usage=True, device_map='auto',
    torch_dtype=torch.bfloat16, cache_dir=MODEL_DIR)
model = model.eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Behavioral Testing

In [None]:
# "Matching" years selected from the range START_YEAR = 1500 .. END_YEAR = 3500,
# in two groups:
#   past_years    (523): every year from 1500 through 2022
#   presfut_years  (43): selected present/future years from 2025 onwards
past_years = list(range(1500, 2023))
presfut_years = [
    2025, 2026, 2028, 2029, 2030, 2035, 2038, 2039, 2040, 2045, 2046, 2047,
    2048, 2049, 2050, 2051, 2054, 2055, 2056, 2057, 2058, 2059, 2060, 2070,
    2077, 2080, 2081, 2084, 2099, 2100, 2106, 2109, 2199, 2300, 2400, 2500,
    2640, 2666, 2700, 2900, 3000, 3001, 3270,
]
matching_years = past_years + presfut_years

In [None]:
from generation_utils import generate_distribution_batched

# One probe prompt per year; the verb the model predicts after "there" is later
# checked against the tense expected for that year.
prompt_template = ['In {year} there']

# An earlier setup probed every year in range(1500, 3500); only matching_years are used now.
year_value = matching_years
print(len(year_value))
prompts = [prompt_template[0].format(year=y) for y in year_value]

In [None]:
# Model predictions for every probe prompt (`prompts` is prompt_template[0] formatted with each year in year_value).
predictions = generate_distribution_batched(model, tokenizer, prompts)

In [None]:
# Keep only the years whose top prediction has the tense expected for that year:
# a past-tense verb before 2024, a present/future one from 2024 onwards.
PAST_TENSE = ['was', 'were']
FUTURE_TENSE = ['is', 'are', 'will']
TENSE_CUTOFF_YEAR = 2024

# NOTE(review): assumes pred[0][0] is the top-ranked next-token string returned by
# generate_distribution_batched — confirm against that helper.
kept_years = [
    year
    for year, pred in zip(year_value, predictions)
    if pred[0][0].strip() in (PAST_TENSE if year < TENSE_CUTOFF_YEAR else FUTURE_TENSE)
]

kept_prompts = [prompt_template[0].format(year=y) for y in kept_years]
len(kept_prompts)

In [None]:
# Partition the tense-verified prompts by the tense expected for their year.
past_tense_prompt = [prompt for prompt, year in zip(kept_prompts, kept_years) if year < 2024]
future_tense_prompt = [prompt for prompt, year in zip(kept_prompts, kept_years) if year >= 2024]

print('#Past=%d' % len(past_tense_prompt), '#Future=%d' % len(future_tense_prompt))

In [None]:
# Create intervention data: shuffle the tense-verified prompts and split them into
# disjoint train / val / test sets, saved as JSON for the DAS pipeline below.
import json

N_PAST_TRAIN, N_PAST_VAL = 256, 128
N_FUTURE_TRAIN, N_FUTURE_VAL = 20, 5
# Future-tense prompts are far fewer than past-tense ones, so each one is repeated
# this many times (presumably to roughly balance the two tenses).
FUTURE_REPEAT = 10

# Shuffle copies (same seed and shuffle order as before) so re-running this cell
# reproduces the same split instead of re-shuffling already-shuffled lists.
random.seed(0)
past_shuffled = list(past_tense_prompt)
random.shuffle(past_shuffled)
future_shuffled = list(future_tense_prompt)
random.shuffle(future_shuffled)

past_val_end = N_PAST_TRAIN + N_PAST_VAL
future_val_end = N_FUTURE_TRAIN + N_FUTURE_VAL
data = {
    'train': {'correct': past_shuffled[:N_PAST_TRAIN] + future_shuffled[:N_FUTURE_TRAIN] * FUTURE_REPEAT, 'wrong': []},
    'val': {'correct': past_shuffled[N_PAST_TRAIN:past_val_end] + future_shuffled[N_FUTURE_TRAIN:future_val_end] * FUTURE_REPEAT, 'wrong': []},
    # Test takes everything after val. The previous offset (512 + 128) exceeded the
    # number of past-tense prompts (matching_years has only 523 years before 2024),
    # so the test set contained no past-tense examples at all.
    'test': {'correct': past_shuffled[past_val_end:] + future_shuffled[future_val_end:] * FUTURE_REPEAT, 'wrong': []},
}
assert not set(data['train']['correct']) & set(data['test']['correct']), 'train/test overlap'
assert not set(data['val']['correct']) & set(data['test']['correct']), 'val/test overlap'

os.makedirs(EXPERIMENT_NAME, exist_ok=True)
split_path = f'{EXPERIMENT_NAME}/year_{model_id.split("/")[-1]}_{revision}_correct_split_0.json'
with open(split_path, 'w') as f:
    json.dump(data, f)

# Localizing Representations of Year

In [None]:
import json

from data_utils import load_intervention_data, _BASE_TEMPLATE
from generation_utils import generate_batched

sample_size = 512
split_type = ''
SPLIT_ID = '1'  # embedded in the DAS run names
mode = 'das'

# Load the train/val/test prompt split written by the data-creation cell.
split_path = f'{EXPERIMENT_NAME}/year_{model_id.split("/")[-1]}_{revision}_correct_split_0.json'
with open(split_path) as f:
    data_split = json.load(f)

verified_examples = data_split['train']['correct'][:sample_size]
print(verified_examples[:2])

# Every prompt in every split ('correct' then 'wrong' lists), in file order; built
# once and shared by the generation call and the prompt_to_vars mapping.
# NOTE(review): duplicated prompts are generated once per occurrence.
all_split_prompts = [p for s in data_split for k in ('correct', 'wrong') for p in data_split[s][k]]
# Label each prompt with the model's own 1-token continuation.
intervention_prompt_to_output = generate_batched(model, tokenizer, all_split_prompts, max_new_tokens=1)
prompt_to_vars = {p: {'input': p,
                      'label': intervention_prompt_to_output[p],
                      'split': _BASE_TEMPLATE}
                  for p in all_split_prompts}


def get_tense(be_word):
  """Classify a verb token as 'future', 'present' or 'past'.

  Matching is case-insensitive and ignores surrounding whitespace. 'will' is
  reported as 'future', kept distinct from present even though English has no
  future inflection. Any word ending in 'ed' is treated as past.

  Raises:
    ValueError: if the word is none of the recognized forms.
  """
  word = be_word.lower().strip()
  if word == 'will':
    return 'future'
  if word in ('is', 'are'):
    return 'present'
  if word in ('was', 'were') or word.endswith('ed'):
    return 'past'
  raise ValueError(f'Unknown tense for {word}')


def set_tense(be_word, tense):
  """Return `be_word` re-inflected into `tense` ('future', 'present' or 'past').

  The lookup is case-insensitive and ignores surrounding whitespace. The result is
  lowercase and gets a single leading space if `be_word` started with one.
  Plural forms (are/were) stay plural; 'will' maps to the singular forms.

  Raises:
    KeyError: for an unknown tense or an unsupported word.
  """
  inflections = {
      'future': dict.fromkeys(('will', 'is', 'are', 'was', 'were'), 'will'),
      'present': {'will': 'is', 'is': 'is', 'are': 'are', 'was': 'is', 'were': 'are'},
      'past': {'will': 'was', 'is': 'was', 'are': 'were', 'was': 'was', 'were': 'were'},
  }
  leading_space = ' ' if be_word.startswith(' ') else ''
  return leading_space + inflections[tense][be_word.lower().strip()]


# TODO(aditi): let set_tense return several candidates (e.g. past of "will" -> ["was", "were"]), as sketched below; this may be hard because many places downstream expect a single label.
  # tense_table = {
  #     'future': {'will': ['will'], 'is': ['will'], 'are': ['will'], 'was': ['will'], 'were': ['will']},
  #     'present': {'will': ['is'], 'is': ['is'], 'are': ['are'], 'was': ['is'], 'were': ['are']},
  #     'past': {'will': ['was', 'were'], 'is': ['was'], 'are': ['were'], 'was': ['was'], 'were': ['were']},
  # }
  # new_be_word = tense_table[tense][normalize_be_word]
  # if be_word.startswith(' '):
  #   new_be_word = [' ' + w for w in new_be_word]
  # return new_be_word

# NOTE(review): assumes load_intervention_data calls these with a (base, source)
# pair of example dicts carrying a 'label' key — confirm against data_utils.
def _counterfactual_label(x, y):
  # x's verb label re-inflected into the tense of y's label.
  return set_tense(x['label'], get_tense(y['label']))


def _tenses_differ(x, y):
  # Keep only pairs whose labels are in different tenses.
  return get_tense(x['label']) != get_tense(y['label'])


split_to_raw_example, split_to_dataset = load_intervention_data(
    mode,
    verified_examples,
    data_split,
    prompt_to_vars,
    inv_label_fn=_counterfactual_label,
    filter_fn=_tenses_differ,
    max_example_per_split=20480,
    max_example_per_eval_split=10,
)

In [None]:
# Number of examples in the DAS training split.
len(split_to_dataset['das-train'])

In [None]:
# Spot-check a single DAS training example.
split_to_dataset['das-train'][6]

In [None]:
from collections import Counter

# Distribution of counterfactual ('inv_label') targets in the DAS training split.
Counter(ex['inv_label'] for ex in split_to_dataset['das-train'])

In [None]:
# How many training examples have an 'inv_label' identical to their 'label'
# (presumably pairs where the intervention should not change the output).
Counter(ex['inv_label'] == ex['label'] for ex in split_to_dataset['das-train'])

In [None]:
# Add the current user's own checkouts of core, RAVEL and pyvene to sys.path.
# NOTE(review): the setup cell also appends /nlp/scr/hij/... copies of these
# libraries earlier; sys.path is searched in order, so those entries take
# precedence whenever both exist. Confirm which checkout is intended.
import getpass
import os
import sys

USER = getpass.getuser()

CORE_LIB_DIR = f'/nlp/scr/{USER}/core'
RAVEL_LIB_DIR = f'/nlp/scr/{USER}/internal-ravel/src'
PYVENE_LIB_DIR = f'/nlp/scr/{USER}/pyvene'
# Append only missing entries so re-running the cell does not keep growing sys.path.
for _lib_dir in (CORE_LIB_DIR, RAVEL_LIB_DIR, PYVENE_LIB_DIR):
    if _lib_dir not in sys.path:
        sys.path.append(_lib_dir)

In [None]:
# Output directory for this experiment's trained rotation weights (<run_name>.pt)
# and eval metrics (<run_name>_evalall.json), both written by run_exp.
SCR_MODEL_DIR = f'/nlp/scr/{USER}/olmo_das_2/{EXPERIMENT_NAME}'
os.makedirs(SCR_MODEL_DIR, exist_ok=True)

In [None]:
import gc

# Free memory before training: collect unreachable Python objects, then
# release unoccupied cached GPU memory held by PyTorch's caching allocator.
gc.collect()
torch.cuda.empty_cache()

In [None]:
import collections
import gc
import re

from tqdm import tqdm, trange
from transformers import get_linear_schedule_with_warmup
from datasets import concatenate_datasets
from torch.nn import CrossEntropyLoss
from causal_interventions import compute_string_based_metrics

import pyvene as pv
from utils.intervention_utils import LowRankRotatedSpaceIntervention, get_intervention_config, train_intervention_step, remove_invalid_token_id, remove_all_forward_hooks

from utils.dataset_utils import get_multitask_dataloader
from utils.metric_utils import compute_cross_entropy_loss
from causal_interventions import eval_with_interventions_batched, compute_metrics


def _build_joint_train_dataset(config):
  """Concatenate the '<task>-train' split of every task in config['training_tasks'].

  A task whose label is a plain string contributes its dataset once; otherwise the
  label is a sequence whose element [1] is a repeat count. 'match_base' tasks are
  subsampled (without replacement) to at most 1024 rows; other tasks are randomly
  permuted. Tasks with no training split in `split_to_dataset` are skipped.
  """
  parts = []
  for task_name, label in config['training_tasks'].items():
    key = f'{task_name}-train'
    if key not in split_to_dataset:
      continue
    dataset = split_to_dataset[key]
    n_repeat = 1 if isinstance(label, str) else label[1]
    for _ in range(n_repeat):
      # Sampling without replacement cannot draw more rows than exist, so cap the
      # 'match_base' sample size at the dataset size.
      size = min(1024, len(dataset)) if label == 'match_base' else len(dataset)
      parts.append(dataset.select(np.random.choice(len(dataset), size=size, replace=False)))
  return concatenate_datasets(parts)


def _print_intervention_debug(inputs, counterfactual_outputs, split_to_inv_locations, num_output_tokens):
  """Print the intervened base/source tokens, the model output and the labels for one batch."""
  print('\nTokens to intervene:')
  batch_size = len(inputs["split"])
  base_locations = [split_to_inv_locations[inputs["split"][i]]['inv_position'] for i in range(batch_size)]
  source_locations = [split_to_inv_locations[inputs["source_split"][i]]['inv_position'] for i in range(batch_size)]
  print(inputs['input'][:3])
  print(inputs['source_input'][:3])
  print('Base:', tokenizer.batch_decode([inputs['input_ids'][i][base_locations[i]] for i in range(batch_size)]))
  print('Source:', tokenizer.batch_decode([inputs['source_input_ids'][i][source_locations[i]] for i in range(batch_size)]))
  print('Output:', tokenizer.batch_decode(torch.argmax(counterfactual_outputs.logits[:, -num_output_tokens-1:-1], dim=-1)))
  print('Label     :', tokenizer.batch_decode(remove_invalid_token_id(inputs['labels'][:, :num_output_tokens], tokenizer.pad_token_id)))
  print('Base Label:', tokenizer.batch_decode(remove_invalid_token_id(inputs['base_labels'][:, :num_output_tokens], tokenizer.pad_token_id)))


def train_alignment(config):
  """Train a DAS intervention on the module-level `model` for the tasks in `config`.

  Builds the joint training set and wraps `model` in a pyvene IntervenableModel
  configured from config['intervenable_config']. It then optimizes only the
  rotation parameters of each LowRankRotatedSpaceIntervention, using AdamW with a
  linear LR schedule and cross-entropy on the first config['max_output_tokens']
  label tokens.

  Returns:
    (intervenable, intervenable_config)
  Raises:
    NotImplementedError: if an intervention is not a LowRankRotatedSpaceIntervention.
  """
  print('Training Tasks: %s' % config['training_tasks'])
  joint_train = _build_joint_train_dataset(config)
  inv_task = [task_name for task_name, label in config['training_tasks'].items()
              if label == 'match_source' or 'match_source' in label]
  print('Training tasks matching source label: %s' % inv_task)
  print('#Training examples: %d' % len(joint_train))
  max_train_example = int(config['max_train_percentage'] * len(joint_train))
  num_output_tokens = config['max_output_tokens']
  train_dataloader = get_multitask_dataloader(
      joint_train.select(range(max_train_example)),
      tokenizer=tokenizer,
      batch_size=TRAINING_BATCH_SIZE, prompt_max_length=INPUT_MAX_LEN,
      output_max_length=num_output_tokens + int(tokenizer.bos_token is not None),
      # The set of splits to load as cause tasks
      cause_tasks=[BASE_TEMPLATE, SOURCE_TEMPLATE],
      first_n=num_output_tokens)

  # Create the intervenable model; the base model's own weights stay frozen.
  split_to_inv_locations = config['split_to_inv_locations']
  intervenable_config = get_intervention_config(
      type(model), config['intervenable_config']['intervenable_representation_type'],
      config['intervenable_config']['intervenable_layer'],
      config['intervenable_config']['intervenable_interventions_type'],
      intervention_dimension=config['intervention_dimension'])
  intervenable = pv.IntervenableModel(intervenable_config, model)
  intervenable.set_device(device)
  intervenable.disable_model_gradients()

  # Training
  epochs = config['training_epoch']
  gradient_accumulation_steps = 1
  warm_up_steps = 0  # 0.1 * t_total
  optimizer_params = []
  for intervention in intervenable.interventions.values():
    # NOTE: an earlier version indexed v[0] here; this expects each entry to be the
    # intervention object itself.
    if not isinstance(intervention, LowRankRotatedSpaceIntervention):
      raise NotImplementedError
    optimizer_params.append({'params': intervention.rotate_layer.parameters()})
  optimizer = torch.optim.AdamW(
      optimizer_params, lr=config['init_lr'], weight_decay=0)
  # NOTE(review): the schedule length is fixed at 10 passes over the data instead of
  # being derived from config['training_epoch']; with 1 epoch the LR only decays to
  # ~90% of init_lr. Confirm this is intended.
  scheduler = get_linear_schedule_with_warmup(
      optimizer, num_warmup_steps=warm_up_steps,
      num_training_steps=int(10 * len(train_dataloader)))

  print("base model trainable parameters: ", pv.count_parameters(intervenable.model))
  print("intervention trainable parameters: ", intervenable.count_parameters())

  total_step = 0
  for epoch in trange(0, int(epochs), desc="Epoch"):
    epoch_iterator = tqdm(train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True)
    aggregated_stats = collections.defaultdict(list)
    for step, inputs in enumerate(epoch_iterator):
      for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
          inputs[k] = v.to(device)
      # Position ids for base and source inputs, derived by the model's own
      # generation-input preparation from the (already on-device) ids and masks.
      inputs.update({
          f'{prefix}position_ids': intervenable.model.prepare_inputs_for_generation(
              input_ids=inputs[f"{prefix}input_ids"],
              attention_mask=inputs[f"{prefix}attention_mask"])['position_ids']
          for prefix in ('', 'source_')})

      counterfactual_outputs = train_intervention_step(
          intervenable, inputs, split_to_inv_locations, pad_token_id=tokenizer.pad_token_id)
      eval_metrics = compute_metrics(
          {'inv_outputs': [counterfactual_outputs.logits[:, -num_output_tokens-1:-1]]},
          [inputs['labels'][:, :num_output_tokens]],
          last_n_tokens=num_output_tokens,
          pad_token_id=tokenizer.pad_token_id,
      )
      loss = compute_cross_entropy_loss(
          counterfactual_outputs.logits,
          inputs["labels"][:, :num_output_tokens],
          next_n_tokens=num_output_tokens,
          pad_token_id=tokenizer.pad_token_id,
      )
      aggregated_stats['loss'].append(loss.item())
      aggregated_stats['acc'].append(eval_metrics['inv_outputs']["accuracy"])
      epoch_iterator.set_postfix({k: round(np.mean(v), 2) for k, v in aggregated_stats.items()})

      if step < 3:
        _print_intervention_debug(inputs, counterfactual_outputs, split_to_inv_locations, num_output_tokens)

      # Backpropagate every micro-batch so gradients accumulate, and update the
      # weights every `gradient_accumulation_steps` steps. (Previously backward()
      # ran only on update steps, dropping the other micro-batches' gradients.)
      if gradient_accumulation_steps > 1:
        loss = loss / gradient_accumulation_steps
      loss.backward()
      if (total_step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        scheduler.step()
        intervenable.set_zero_grad()
      total_step += 1
  return intervenable, intervenable_config


def _build_run_name(config):
  """Build the run name that prefixes every artifact saved for `config`.

  The IIA summary cell locates eval files by this exact pattern (including the
  legacy 'mmlu_id' tag), so the format must stay stable.
  """
  training_tasks = config['training_tasks']
  inv_tasks = '+'.join([''.join(re.findall(r'[A-Za-z]+', t)) + ('' if isinstance(l, str) else str(l[1]))
                        for t, l in training_tasks.items() if l == 'match_source' or 'match_source' in l])
  control_tasks = '+'.join([''.join(re.findall(r'[A-Za-z]+', t))
                            for t, l in training_tasks.items() if l == 'match_base' or 'match_base' in l])
  task_compressed = ((inv_tasks + '_ex_' + control_tasks) if control_tasks else inv_tasks).replace('AZaz', '')
  das_type = 'multi_das' if len(training_tasks) > 1 else 'das_baseline'
  if config['intervenable_config']['intervenable_interventions_type'] == LowRankRotatedSpaceIntervention:
    das_type = das_type.replace('das', 'daslora')
  split_to_inv_locations = config['split_to_inv_locations']
  input_len = list(split_to_inv_locations.values())[0]['max_input_length']
  inv_pos = min(x['inv_position'][0] for x in split_to_inv_locations.values())
  # 'f' when the intervention is at the last input position (input_len - 1), else 'e'.
  inv_loc_name = 'len%d_pos%s' % (input_len, 'e' if inv_pos != input_len - 1 else 'f')
  suffix = f"_example{len(verified_examples)}_{config['intervenable_config']['intervenable_representation_type']}"
  layer_cfg = config['intervenable_config']['intervenable_layer']
  layer = '%s_%s' % (min(layer_cfg), max(layer_cfg)) if isinstance(layer_cfg, list) else layer_cfg
  model_name = model.name_or_path.split('/')[-1]
  return (f"{model_name}-layer{layer}-dim{config['intervention_dimension']}-{das_type}_"
          f"{config['max_output_tokens']}tok_{task_compressed}-mmlu_id-{SPLIT_ID}_{inv_loc_name}"
          f"_ep{config['training_epoch']}{suffix}")


def _mean_iia(split_to_eval_metrics, name_filter=''):
  """Mean inv_outputs accuracy over eval splits whose name contains `name_filter` (NaN if none match)."""
  accs = [v['metrics']['labels']['inv_outputs']['accuracy']
          for k, v in split_to_eval_metrics.items() if name_filter in k]
  return float(np.mean(accs)) if accs else float('nan')


def run_exp(config):
  """Train one DAS intervention for `config`, save it, and evaluate it on `eval_split_to_dataset`.

  Side effects: sets config['run_name_prefix'] and writes '<run_name>.pt'
  (rotation weights) and '<run_name>_evalall.json' (eval metrics) under
  SCR_MODEL_DIR.

  Returns:
    The trained IntervenableModel, with its forward hooks removed.
  """
  run_name = _build_run_name(config)
  config['run_name_prefix'] = run_name
  print(run_name)

  intervenable, _ = train_alignment(config)
  weights_path = os.path.join(SCR_MODEL_DIR, f'{run_name}.pt')
  torch.save({k: v.rotate_layer.weight for k, v in intervenable.interventions.items()}, weights_path)
  print('Model saved to %s' % weights_path)
  gc.collect()
  torch.cuda.empty_cache()

  with torch.no_grad():
    split_to_eval_metrics = eval_with_interventions_batched(
        intervenable, eval_split_to_dataset,
        config['split_to_inv_locations'],
        tokenizer,
        compute_metrics_fn=compute_metrics,
        max_new_tokens=config['max_output_tokens'],
        eval_batch_size=EVAL_BATCH_SIZE,
        inference_mode='generate',
        debug_print=True,
    )
  # Splits may have no '-wrong' entries (the data split's 'wrong' lists can be
  # empty); _mean_iia then reports NaN without a numpy RuntimeWarning.
  print('Mean IIA: %.4f' % _mean_iia(split_to_eval_metrics))
  print('Mean correct IIA: %.4f' % _mean_iia(split_to_eval_metrics, '-correct'))
  print('Mean wrong IIA: %.4f' % _mean_iia(split_to_eval_metrics, '-wrong'))

  metrics_path = os.path.join(SCR_MODEL_DIR, f'{run_name}_evalall.json')
  with open(metrics_path, 'w') as f:
    json.dump(split_to_eval_metrics, f)
  # Report the file actually written (previously printed '<run_name>.json').
  print('Saved to %s' % metrics_path)
  remove_all_forward_hooks(intervenable)
  return intervenable

assert mode == 'das'

INPUT_MAX_LEN = 8
TRAINING_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 16

from data_utils import _BASE_TEMPLATE, _SOURCE_TEMPLATE
BASE_TEMPLATE, SOURCE_TEMPLATE = _BASE_TEMPLATE, _SOURCE_TEMPLATE

# Every dataset split (plus the base/source templates) intervenes at a single
# position: the second-to-last of the INPUT_MAX_LEN input positions.
SPLIT_TO_INV_LOCATIONS = {
    split: {'max_input_length': INPUT_MAX_LEN, 'inv_position': [INPUT_MAX_LEN - 2]}
    for split in [*split_to_dataset, BASE_TEMPLATE, SOURCE_TEMPLATE]
}

# A single DAS task labelled 'match_source'.
training_tasks_list = [{'das': 'match_source'}]

# Evaluate only on the held-out '-test' splits.
eval_split_to_dataset = {name: dataset for name, dataset in split_to_dataset.items()
                         if name.endswith('-test')}

model = model.eval()


def _make_das_config(inv_layer, inv_dim, lr, training_tasks):
  """Build the run_exp config for one sweep point (block_output, 1 output token, 1 epoch)."""
  return {
      'regularization_coefficient': 0,
      'intervention_dimension': inv_dim,
      'max_output_tokens': 1,
      'intervenable_config': {
        'intervenable_layer': inv_layer,
        'intervenable_representation_type': 'block_output',
        'intervenable_unit': 'pos',
        'max_number_of_units': 1,
        'intervenable_interventions_type': LowRankRotatedSpaceIntervention,
      },
      'training_tasks': training_tasks,
      'training_epoch': 1,
      'split_to_inv_locations': SPLIT_TO_INV_LOCATIONS,
      'split_to_labels': None,
      # Use all training examples (the former `1.0 if ... else 1.0` was 1.0 either way).
      'max_train_percentage': 1.0,
      'init_lr': lr,
  }


# Sweep: layers x learning rates x subspace dimensions x task sets.
for inv_layer in [[i] for i in INV_LAYERS]:
  for lr in [1e-4]:
    for inv_dim in INV_DIMENSIONS:
      for training_tasks in training_tasks_list:
        config = _make_das_config(inv_layer, inv_dim, lr, training_tasks)
        intervenable = run_exp(config)

nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.
/juice2/scr2/aditijb/pyvene/pyvene/__init__.py


In [17]:
# Accuracy summary: for every layer/dimension of this experiment with a saved eval
# file, report the mean interchange-intervention accuracy (IIA) over the '-correct'
# eval splits, and write the table to EXPERIMENT_NAME/iia_results.txt.
#
# EXPERIMENT_NAME comes from the configuration cell. It is deliberately not
# re-assigned here, so the summary always matches the configured experiment.
# NOTE(review): run_exp writes eval files under SCR_MODEL_DIR
# (/nlp/scr/<user>/olmo_das_2/<EXPERIMENT_NAME>). The relative paths below only
# find them when the working directory is /nlp/scr/<user>/olmo_das_2.

import json
import os

import numpy as np

model_name = model.name_or_path.split('/')[-1]
layers = range(model.config.num_hidden_layers)  # every layer of the loaded model
dims = [2**i for i in range(7)]  # [1, 2, 4, 8, 16, 32, 64]
# Fixed part of the sweep's run names. N_TRAIN_EXAMPLES must equal the
# len(verified_examples) the runs were trained with.
N_TRAIN_EXAMPLES = 456
RUN_SUFFIX = ("daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1"
              f"_example{N_TRAIN_EXAMPLES}_block_output_evalall.json")

out_lines = [f"model={model_name}"]
print(out_lines[0])

for layer_idx in layers:
    iia = []
    for dim in dims:
        fname = os.path.join(EXPERIMENT_NAME, f"{model_name}-layer{layer_idx}_{layer_idx}-dim{dim}-{RUN_SUFFIX}")
        if not os.path.exists(fname):
            continue
        with open(fname) as f:
            split_to_eval_metrics = json.load(f)
        correct_accs = [
            v['metrics']['labels']['inv_outputs']['accuracy']
            for k, v in split_to_eval_metrics.items()
            if '-correct' in k
        ]
        # Report NaN explicitly (rather than np.mean([]) plus a RuntimeWarning)
        # when a file has no '-correct' splits.
        acc = np.mean(correct_accs) if correct_accs else float('nan')
        iia.append((dim, acc))
    if iia:
        line = f"layer={layer_idx}  \t" + "\t".join(f"dim={dim}:{acc:.2f}" for dim, acc in iia)
        print(line)
        out_lines.append(line)

# write results to file
os.makedirs(EXPERIMENT_NAME, exist_ok=True)
out_path = os.path.join(EXPERIMENT_NAME, "iia_results.txt")
with open(out_path, "w") as f:
    f.write("\n".join(out_lines) + "\n")

print(f"Results written to {out_path}")


model=OLMo-2-0425-1B
layer=0  	dim=1:0.72	dim=2:0.76	dim=4:0.85	dim=8:0.93
layer=1  	dim=1:0.72	dim=2:0.77	dim=4:0.86	dim=8:0.92
layer=2  	dim=1:0.71	dim=2:0.86	dim=4:0.92	dim=8:0.95
layer=3  	dim=2:0.88	dim=4:0.93	dim=8:0.95
layer=4  	dim=2:0.86	dim=4:0.98	dim=8:0.98
layer=5  	dim=2:0.97	dim=4:0.98	dim=8:0.99
layer=6  	dim=2:0.98	dim=4:0.99
layer=9  	dim=2:0.76	dim=4:0.77	dim=8:0.81
layer=10  	dim=2:0.75	dim=4:0.75
Results written to year_localization/iia_results.txt
