# Setup

In [None]:
# Reload edited Python modules (e.g. generation_utils, data_utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import getpass
import os
import sys

PROJECT_NAME = 'ood-prediction'
DATA_DIR = '/scr-ssd/hij/data'
MODEL_DIR = '/scr-ssd/hij/models'

# Make the project's `src/` modules (generation_utils, data_utils, ...) importable.
# Guarded so re-running this cell does not pile up duplicate sys.path entries.
PROJECT_SRC_DIR = f'/nlp/scr/hij/{PROJECT_NAME}/src'
if PROJECT_SRC_DIR not in sys.path:
    sys.path.append(PROJECT_SRC_DIR)

# Point the Hugging Face cache at the same local SSD directory as MODEL_DIR.
# NOTE(review): `HF_HUB` does not appear to be a variable huggingface_hub reads
# (`HF_HUB_CACHE` is); kept unchanged here so the cache location does not move.
os.environ["HF_HOME"] = MODEL_DIR
os.environ["HF_HUB"] = MODEL_DIR

In [None]:
import numpy as np
import random
import torch

def set_seed(seed):
    """Seed every RNG this notebook draws from.

    Covers Python's `random`, NumPy's global generator, and torch on the CPU and
    on all CUDA devices, in that order.
    """
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


set_seed(0)

In [None]:
import torch
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Models

In [None]:
import gc

# Reclaim Python garbage, then release cached CUDA blocks (a no-op without a GPU),
# e.g. before re-loading a model in the same kernel.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


model_id = "allenai/OLMo-2-0425-1B"
# Hub revision (branch/tag/commit). It also appears in the saved split filename below,
# so the tokenizer AND the model must both be loaded at this revision.
revision = "main"
tokenizer = AutoTokenizer.from_pretrained(
    model_id, padding_side='left', revision=revision,
    cache_dir=MODEL_DIR)
# Left padding with EOS as the pad token, so batched prompts all end at the last position.
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(
    model_id, revision=revision, low_cpu_mem_usage=True, device_map='auto',
    torch_dtype=torch.bfloat16, cache_dir=MODEL_DIR)
model = model.eval()

[2025-09-22 19:08:23,410] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/nlp/scr2/hij/anaconda3/envs/pytorch/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/nlp/scr2/hij/anaconda3/envs/pytorch/compiler_compat/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Behavioral Testing

In [None]:
from generation_utils import generate_distribution_batched

# One probe prompt "In <year>, there" for every year in [1500, 3500).
prompt_template = ['In {year}, there']
year_value = [*range(1500, 3500)]
print(len(year_value))
template = prompt_template[0]
prompts = list(map(lambda y: template.format(year=y), year_value))

2000


In [None]:
# Next-token predictions for every year prompt. This reuses `prompts` built in the previous
# cell rather than re-formatting the template a second time.
predictions = generate_distribution_batched(model, tokenizer, prompts)

100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 31.16it/s]


In [None]:
# Filter out years where the predicted tense is wrong: keep a year only when the first
# predicted word is past tense for years before CUTOFF_YEAR, or present/future tense
# from CUTOFF_YEAR on.

PAST_TENSE = ['was', 'were']
FUTURE_TENSE = ['is', 'are', 'will']
CUTOFF_YEAR = 2024


def _tense_matches_year(year, pred):
    """True when the prediction's first word has the tense expected for `year`.

    NOTE(review): assumes pred[0][0] is the top predicted token string, as returned by
    generate_distribution_batched -- confirm against generation_utils.
    """
    expected_words = PAST_TENSE if year < CUTOFF_YEAR else FUTURE_TENSE
    return pred[0][0].strip() in expected_words


kept_years = [year for year, pred in zip(year_value, predictions)
              if _tense_matches_year(year, pred)]

kept_prompts = [prompt_template[0].format(year=year) for year in kept_years]
len(kept_prompts)

564

In [None]:
# Partition the kept prompts by which side of the 2024 cutoff their year falls on.
past_tense_prompt = [prompt for prompt, year in zip(kept_prompts, kept_years) if year < 2024]
future_tense_prompt = [prompt for prompt, year in zip(kept_prompts, kept_years) if year >= 2024]

print('#Past=%d' % len(past_tense_prompt), '#Future=%d' % len(future_tense_prompt))

#Past=524 #Future=40


In [None]:
# Create intervention data: shuffle each tense group, then cut disjoint train/val/test splits.
# Future-tense prompts are scarce (40 vs. 524 past in the run above), so they are repeated
# FUTURE_UPSAMPLE times in every split to bring the two tenses closer to balance.
import json

N_PAST_TRAIN, N_PAST_VAL = 256, 128
N_FUTURE_TRAIN, N_FUTURE_VAL = 20, 5
FUTURE_UPSAMPLE = 10

# Shuffle copies (same seed, same RNG call order as before) so re-running this cell
# reproduces the same splits instead of re-shuffling already-shuffled lists.
random.seed(0)
past_shuffled = list(past_tense_prompt)
future_shuffled = list(future_tense_prompt)
random.shuffle(past_shuffled)
random.shuffle(future_shuffled)

past_train = past_shuffled[:N_PAST_TRAIN]
past_val = past_shuffled[N_PAST_TRAIN:N_PAST_TRAIN + N_PAST_VAL]
# The test split takes every past prompt after val. The previous slice started at 512+128,
# which is past the end of the list, so the test split contained no past-tense prompts.
past_test = past_shuffled[N_PAST_TRAIN + N_PAST_VAL:]

future_train = future_shuffled[:N_FUTURE_TRAIN]
future_val = future_shuffled[N_FUTURE_TRAIN:N_FUTURE_TRAIN + N_FUTURE_VAL]
future_test = future_shuffled[N_FUTURE_TRAIN + N_FUTURE_VAL:]

data = {
    'train': {'correct': past_train + future_train * FUTURE_UPSAMPLE, 'wrong': []},
    'val': {'correct': past_val + future_val * FUTURE_UPSAMPLE, 'wrong': []},
    'test': {'correct': past_test + future_test * FUTURE_UPSAMPLE, 'wrong': []},
}
split_path = f'year_localization/year_{model_id.split("/")[-1]}_{revision}_correct_split_0.json'
os.makedirs(os.path.dirname(split_path), exist_ok=True)
with open(split_path, 'w') as split_file:
    json.dump(data, split_file)

# Localizing Representations of Year

In [None]:
import json

from data_utils import load_intervention_data, _BASE_TEMPLATE
from generation_utils import generate_batched

sample_size = 512
split_type = ''
SPLIT_ID = '1'
mode = 'das'
data_split_path = f'year_localization/year_{model_id.split("/")[-1]}_{revision}_correct_split_0.json'
with open(data_split_path) as split_file:
    data_split = json.load(split_file)

verified_examples = data_split['train']['correct'][:sample_size]
print(verified_examples[:2])

# Every prompt across all splits and both buckets, built once. Duplicates (the upsampled
# future-tense prompts) are kept, exactly as in the list that was passed before.
all_split_prompts = [p for s in data_split for k in ('correct', 'wrong') for p in data_split[s][k]]
# One generated token per prompt. Indexed by prompt below, so presumably a prompt -> text mapping.
intervention_prompt_to_output = generate_batched(model, tokenizer, all_split_prompts, max_new_tokens=1)
prompt_to_vars = {p: {'input': p,
                      'label': intervention_prompt_to_output[p],
                      'split': _BASE_TEMPLATE}
                  for p in all_split_prompts}


def get_tense(be_word):
  """Classify a be-verb token as 'future', 'present' or 'past'.

  Matching ignores case and surrounding whitespace. Any word ending in 'ed' also counts
  as past. Present and future are kept apart here even though English does not mark
  that distinction on the year prompts.

  Raises:
    ValueError: for any other word.
  """
  word = be_word.lower().strip()
  if word == 'will':
    return 'future'
  if word in ('is', 'are'):
    return 'present'
  if word in ('was', 'were') or word.endswith('ed'):
    return 'past'
  raise ValueError(f'Unknown tense for {word}')


def set_tense(be_word, tense):
  """Re-inflect the be-verb `be_word` into `tense`, keeping its grammatical number.

  'are'/'were' are plural and map to the plural form; 'will'/'is'/'was' map to the
  singular form. Future is 'will' either way. Lookup ignores case and surrounding
  whitespace, and a single leading space on `be_word` is preserved on the result.

  Raises:
    KeyError: if `tense` or the normalized word is not in the tables.
  """
  singular_plural = {
      'future': ('will', 'will'),
      'present': ('is', 'are'),
      'past': ('was', 'were'),
  }[tense]
  is_plural = {'will': False, 'is': False, 'was': False,
               'are': True, 'were': True}[be_word.lower().strip()]
  prefix = ' ' if be_word.startswith(' ') else ''
  return prefix + singular_plural[is_plural]


# NOTE(review): the two callbacks are presumably called as (base_example, source_example).
# The sample printed below is consistent with that: label ' was', source_label ' is',
# inv_label ' is'.
def _counterfactual_label(base, source):
  """The base example's be-verb re-inflected into the source example's tense."""
  return set_tense(base['label'], get_tense(source['label']))


def _tenses_differ(base, source):
  """Keep only pairs whose base and source labels have different tenses."""
  return get_tense(base['label']) != get_tense(source['label'])


split_to_raw_example, split_to_dataset = load_intervention_data(
    mode, verified_examples, data_split, prompt_to_vars,
    inv_label_fn=_counterfactual_label,
    filter_fn=_tenses_differ,
    max_example_per_split=20480,
    max_example_per_eval_split=10)

['In 1532, there', 'In 1981, there']
Total #prompts=784
Set prompt_max_length=8


100%|███████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 31.57it/s]


mode=das, #base_examples=456, #source_examples=456+150+0=634
BEFORE SPLIT: #Training examples=456*456=119200, #Validation examples==0, #Test examples=456*(150+0)=43000
AFTER SPLIT KEPT: #Training examples=min(456*456, 20480)=20480, #Validation examples==0, #Test examples=150*10+0*10=1330
#Splits=1 training split plus 1,024 test splits=134


In [None]:
# Number of (base, source) training pairs after filtering and capping.
train_split_size = len(split_to_dataset['das-train'])
train_split_size

20480

In [None]:
# Spot-check one training pair: base input/label, source input/label and counterfactual label.
train_example = split_to_dataset['das-train'][6]
train_example

{'input': 'In 1657, there',
 'label': ' was',
 'source_input': 'In 2044, there',
 'source_label': ' is',
 'inv_label': ' is',
 'split': 'BASE_TEMPLATE',
 'source_split': 'BASE_TEMPLATE'}

In [None]:
from collections import Counter

# Distribution of counterfactual target words in the training split.
Counter(ex['inv_label'] for ex in split_to_dataset['das-train'])

Counter({' is': 6526,
         ' were': 4409,
         ' was': 4340,
         ' will': 4113,
         ' are': 1092})

In [None]:
# Sanity check: the counterfactual label should never equal the base label (expect only False).
Counter(ex['inv_label'] == ex['label'] for ex in split_to_dataset['das-train'])

Counter({False: 20480})

In [None]:
import getpass
import os
import sys

USER = getpass.getuser()

CORE_LIB_DIR = f'/nlp/scr/{USER}/core'
RAVEL_LIB_DIR = f'/nlp/scr/{USER}/internal-ravel/src'
PYVENE_LIB_DIR = f'/nlp/scr/{USER}/pyvene'


def extend_sys_path(*dirs):
    """Append each directory to sys.path unless it is already present.

    Keeps this cell idempotent: re-running it does not add duplicate entries.
    """
    for lib_dir in dirs:
        if lib_dir not in sys.path:
            sys.path.append(lib_dir)


extend_sys_path(CORE_LIB_DIR, RAVEL_LIB_DIR, PYVENE_LIB_DIR)

In [None]:
# Where run_exp writes trained rotation weights (.pt) and eval metrics (.json).
SCR_MODEL_DIR = os.path.join('/nlp/scr', USER, 'year_localization')

In [None]:
import gc


# Reclaim Python garbage, then release cached CUDA blocks (a no-op without a GPU) before training.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
import collections
import gc
import re

from tqdm import tqdm, trange
from transformers import get_linear_schedule_with_warmup
from datasets import concatenate_datasets
from torch.nn import CrossEntropyLoss
from causal_interventions import compute_string_based_metrics

import pyvene as pv
from utils.intervention_utils import LowRankRotatedSpaceIntervention, get_intervention_config, train_intervention_step, remove_invalid_token_id, remove_all_forward_hooks

from utils.dataset_utils import get_multitask_dataloader
from utils.metric_utils import compute_cross_entropy_loss
from causal_interventions import eval_with_interventions_batched, compute_metrics


def is_optimizer_step(total_step, gradient_accumulation_steps=1):
  """Return True if the optimizer should step after the batch with 0-based index `total_step`.

  Gradients accumulate over `gradient_accumulation_steps` consecutive batches, so the
  optimizer steps after batches N-1, 2N-1, ...; with N=1 it steps after every batch.
  """
  return (total_step + 1) % gradient_accumulation_steps == 0


def _build_joint_train_dataset(config):
  """Concatenate the '<task>-train' splits of every task in config['training_tasks'].

  How each task label is handled:
    * 'match_base': a random sample of at most 1024 examples. The cap is clipped to the
      split size so np.random.choice(replace=False) cannot fail on small splits.
    * any other string: the whole split, in a random order.
    * a non-string label (presumably (label, n_repeats) -- TODO confirm): the whole split,
      independently shuffled label[1] times.
  Tasks without a '-train' split in `split_to_dataset` are skipped.
  """
  parts = []
  for task_name, label in config['training_tasks'].items():
    key = f'{task_name}-train'
    if key not in split_to_dataset:
      continue
    task_dataset = split_to_dataset[key]
    n_repeats = 1 if isinstance(label, str) else label[1]
    for _ in range(n_repeats):
      n_examples = len(task_dataset)
      size = min(1024, n_examples) if label == 'match_base' else n_examples
      parts.append(task_dataset.select(np.random.choice(n_examples, size=size, replace=False)))
  return concatenate_datasets(parts)


def _print_debug_batch(inputs, counterfactual_outputs, split_to_inv_locations, num_output_tokens):
  """Print the tokens at the intervention positions plus predicted, target and base labels."""
  print('\nTokens to intervene:')
  batch_indices = range(len(inputs['split']))
  base_locations = [split_to_inv_locations[inputs['split'][i]]['inv_position'] for i in batch_indices]
  source_locations = [split_to_inv_locations[inputs['source_split'][i]]['inv_position'] for i in batch_indices]
  print(inputs['input'][:3])
  print(inputs['source_input'][:3])
  print('Base:', tokenizer.batch_decode([inputs['input_ids'][i][base_locations[i]] for i in batch_indices]))
  print('Source:', tokenizer.batch_decode([inputs['source_input_ids'][i][source_locations[i]] for i in batch_indices]))
  print('Output:', tokenizer.batch_decode(torch.argmax(counterfactual_outputs.logits[:, -num_output_tokens-1:-1], dim=-1)))
  print('Label     :', tokenizer.batch_decode(remove_invalid_token_id(inputs['labels'][:, :num_output_tokens], tokenizer.pad_token_id)))
  print('Base Label:', tokenizer.batch_decode(remove_invalid_token_id(inputs['base_labels'][:, :num_output_tokens], tokenizer.pad_token_id)))


def train_alignment(config):
  """Train a low-rank rotated-space (DAS) intervention on the global `model` and return it.

  Reads notebook globals: model, tokenizer, split_to_dataset, TRAINING_BATCH_SIZE,
  INPUT_MAX_LEN, BASE_TEMPLATE, SOURCE_TEMPLATE and device.

  Args:
    config: experiment dict built in the sweep cell.
      Required keys: 'training_tasks', 'max_train_percentage', 'max_output_tokens',
        'split_to_inv_locations', 'intervenable_config', 'intervention_dimension',
        'training_epoch', 'init_lr'.
      Optional keys: 'gradient_accumulation_steps' (default 1) and 'lr_schedule_epochs'
        (default 10; see the scheduler note below).

  Returns:
    (intervenable, intervenable_config): the trained pyvene IntervenableModel and its config.

  Raises:
    NotImplementedError: if an intervention is not a LowRankRotatedSpaceIntervention.
  """
  print('Training Tasks: %s' % config['training_tasks'])
  joint_train = _build_joint_train_dataset(config)
  inv_task = [task_name for task_name, label in config['training_tasks'].items()
              if label == 'match_source' or 'match_source' in label]
  print('Training tasks matching source label: %s' % inv_task)
  print('#Training examples: %d' % len(joint_train))
  max_train_example = int(config['max_train_percentage'] * len(joint_train))
  train_dataloader = get_multitask_dataloader(
      joint_train.select(range(max_train_example)),
      tokenizer=tokenizer,
      batch_size=TRAINING_BATCH_SIZE, prompt_max_length=INPUT_MAX_LEN,
      output_max_length=config['max_output_tokens'] + int(tokenizer.bos_token is not None),
      # The set of splits to load as cause tasks
      cause_tasks=[BASE_TEMPLATE, SOURCE_TEMPLATE],
      first_n=config['max_output_tokens'])

  # Create the intervenable model; only the intervention parameters are trained.
  split_to_inv_locations = config['split_to_inv_locations']
  intervenable_config = get_intervention_config(
      type(model), config['intervenable_config']['intervenable_representation_type'],
      config['intervenable_config']['intervenable_layer'],
      config['intervenable_config']['intervenable_interventions_type'],
      intervention_dimension=config['intervention_dimension'])
  intervenable = pv.IntervenableModel(intervenable_config, model)
  intervenable.set_device("cuda")
  intervenable.disable_model_gradients()

  # Optimizer over the rotation parameters of every intervention.
  optimizer_params = []
  for intervention_entry in intervenable.interventions.values():
    intervention = intervention_entry[0]
    if not isinstance(intervention, LowRankRotatedSpaceIntervention):
      raise NotImplementedError
    optimizer_params.append({'params': intervention.rotate_layer.parameters()})
  optimizer = torch.optim.AdamW(optimizer_params, lr=config['init_lr'], weight_decay=0)

  epochs = config['training_epoch']
  gradient_accumulation_steps = config.get('gradient_accumulation_steps', 1)
  warm_up_steps = 0  # 0.1 * t_total
  # NOTE(review): the linear schedule's horizon is `lr_schedule_epochs` passes over the data
  # (default 10, the previously hard-coded value). It does not depend on `epochs`, so with a
  # single training epoch the LR only decays to ~90% of init_lr. Set it to `epochs` for a
  # schedule that reaches zero at the end of training.
  lr_schedule_epochs = config.get('lr_schedule_epochs', 10)
  scheduler = get_linear_schedule_with_warmup(
      optimizer, num_warmup_steps=warm_up_steps,
      num_training_steps=int(lr_schedule_epochs * len(train_dataloader) / gradient_accumulation_steps))

  print("base model trainable parameters: ", pv.count_parameters(intervenable.model))
  print("intervention trainable parameters: ", intervenable.count_parameters())

  num_output_tokens = config['max_output_tokens']
  model_input_keys = ('input_ids', 'source_input_ids', 'attention_mask',
                      'source_attention_mask', 'position_ids', 'source_position_ids')
  total_step = 0
  for epoch in trange(0, int(epochs), desc="Epoch"):
    epoch_iterator = tqdm(train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True)
    aggregated_stats = collections.defaultdict(list)
    for step, inputs in enumerate(epoch_iterator):
      for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
          inputs[k] = v.to("cuda")
      # Position ids for the (left-padded) base and source prompts.
      inputs.update({
          f'{prefix}position_ids': intervenable.model.prepare_inputs_for_generation(
              input_ids=inputs[f'{prefix}input_ids'],
              attention_mask=inputs[f'{prefix}attention_mask'])['position_ids']
          for prefix in ('', 'source_')})
      for key in model_input_keys:
        if key in inputs:
          inputs[key] = inputs[key].to(device)

      counterfactual_outputs = train_intervention_step(
          intervenable, inputs, split_to_inv_locations, pad_token_id=tokenizer.pad_token_id)
      eval_metrics = compute_metrics(
          {'inv_outputs': [counterfactual_outputs.logits[:, -num_output_tokens-1:-1]]},
          [inputs['labels'][:, :num_output_tokens]],
          last_n_tokens=num_output_tokens,
          pad_token_id=tokenizer.pad_token_id,
      )
      loss = compute_cross_entropy_loss(
          counterfactual_outputs.logits,
          inputs["labels"][:, :num_output_tokens],
          next_n_tokens=num_output_tokens,
          pad_token_id=tokenizer.pad_token_id,
      )
      aggregated_stats['loss'].append(loss.item())
      aggregated_stats['acc'].append(eval_metrics['inv_outputs']["accuracy"])
      epoch_iterator.set_postfix({k: round(np.mean(aggregated_stats[k]), 2) for k in aggregated_stats})

      if step < 3:
        _print_debug_batch(inputs, counterfactual_outputs, split_to_inv_locations, num_output_tokens)

      # Accumulate gradients from EVERY batch; step the optimizer once per accumulation window.
      # (Previously backward() only ran on stepping batches, discarding the others when N > 1.)
      if gradient_accumulation_steps > 1:
        loss = loss / gradient_accumulation_steps
      loss.backward()
      if is_optimizer_step(total_step, gradient_accumulation_steps):
        optimizer.step()
        scheduler.step()
        intervenable.set_zero_grad()
      total_step += 1
  return intervenable, intervenable_config


def mean_iia(split_to_eval_metrics, split_substring=''):
  """Mean interchange-intervention accuracy over eval splits whose name contains `split_substring`.

  Returns float('nan') when no split matches (e.g. there are no '-wrong' splits). This
  avoids NumPy's "Mean of empty slice" RuntimeWarning.
  """
  accuracies = [v['metrics']['labels']['inv_outputs']['accuracy']
                for k, v in split_to_eval_metrics.items() if split_substring in k]
  return float(np.mean(accuracies)) if accuracies else float('nan')


def _build_run_name(config):
  """File stem encoding model, layer(s), rank, DAS variant, tasks, split id, position and epochs."""
  training_tasks = config['training_tasks']

  def letters_only(name):
    return ''.join(re.findall(r'[A-Za-z]+', name))

  inv_tasks = '+'.join(letters_only(t) + ('' if isinstance(l, str) else str(l[1]))
                       for t, l in training_tasks.items() if l == 'match_source' or 'match_source' in l)
  control_tasks = '+'.join(letters_only(t)
                           for t, l in training_tasks.items() if l == 'match_base' or 'match_base' in l)
  task_compressed = ((inv_tasks + '_ex_' + control_tasks) if control_tasks else inv_tasks).replace('AZaz', '')
  das_type = 'multi_das' if len(training_tasks) > 1 else 'das_baseline'
  if config['intervenable_config']['intervenable_interventions_type'] == LowRankRotatedSpaceIntervention:
    das_type = das_type.replace('das', 'daslora')
  split_to_inv_locations = config['split_to_inv_locations']
  input_len = list(split_to_inv_locations.values())[0]['max_input_length']
  inv_pos = min(x['inv_position'][0] for x in split_to_inv_locations.values())
  # 'f' when the intervention is on the final input position, 'e' otherwise.
  inv_loc_name = 'len%d_pos%s' % (input_len, 'e' if inv_pos != input_len - 1 else 'f')
  suffix = f"_example{len(verified_examples)}_{config['intervenable_config']['intervenable_representation_type']}"
  layer_cfg = config['intervenable_config']['intervenable_layer']
  layer = '%s_%s' % (min(layer_cfg), max(layer_cfg)) if isinstance(layer_cfg, list) else layer_cfg
  model_name = model.name_or_path.split('/')[-1]
  return (f"{model_name}-layer{layer}-dim{config['intervention_dimension']}-{das_type}_"
          f"{config['max_output_tokens']}tok_{task_compressed}-mmlu_id-{SPLIT_ID}_{inv_loc_name}"
          f"_ep{config['training_epoch']}{suffix}")


def run_exp(config):
  """Train one intervention for `config`, save its rotation weights, evaluate it and log IIA.

  Reads notebook globals: model, verified_examples, SPLIT_ID, SCR_MODEL_DIR,
  eval_split_to_dataset, tokenizer and EVAL_BATCH_SIZE.

  Side effects:
    * sets config['run_name_prefix'];
    * writes <run_name>.pt and <run_name>_evalall.json under SCR_MODEL_DIR, creating the
      directory if needed;
    * removes all forward hooks from the trained intervenable before returning, even if
      saving or evaluation raises.

  Returns:
    The trained pyvene IntervenableModel.
  """
  run_name = _build_run_name(config)
  config['run_name_prefix'] = run_name
  print(run_name)
  os.makedirs(SCR_MODEL_DIR, exist_ok=True)

  intervenable, _ = train_alignment(config)
  try:
    weights_path = os.path.join(SCR_MODEL_DIR, f'{run_name}.pt')
    torch.save({k: v[0].rotate_layer.weight for k, v in intervenable.interventions.items()},
               weights_path)
    print('Model saved to %s' % weights_path)
    gc.collect()
    torch.cuda.empty_cache()

    with torch.no_grad():
      split_to_eval_metrics = eval_with_interventions_batched(
          intervenable, eval_split_to_dataset,
          config['split_to_inv_locations'],
          tokenizer,
          compute_metrics_fn=compute_metrics,
          max_new_tokens=config['max_output_tokens'],
          eval_batch_size=EVAL_BATCH_SIZE,
          inference_mode='generate',
          debug_print=True,
      )
    print('Mean IIA: %.4f' % mean_iia(split_to_eval_metrics))
    print('Mean correct IIA: %.4f' % mean_iia(split_to_eval_metrics, '-correct'))
    print('Mean wrong IIA: %.4f' % mean_iia(split_to_eval_metrics, '-wrong'))

    metrics_path = os.path.join(SCR_MODEL_DIR, f'{run_name}_evalall.json')
    with open(metrics_path, 'w') as metrics_file:
      json.dump(split_to_eval_metrics, metrics_file)
    print('Saved to %s' % metrics_path)
  finally:
    # The intervenable wraps the shared global `model`, which the next sweep iteration
    # reuses. Presumably its hooks are registered there, so always detach them.
    remove_all_forward_hooks(intervenable)
  return intervenable



# Sweep: train and evaluate one DAS intervention per (layer, learning rate, rank, task set).
assert mode == 'das'

INPUT_MAX_LEN = 8
TRAINING_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 16

from data_utils import _BASE_TEMPLATE, _SOURCE_TEMPLATE
BASE_TEMPLATE, SOURCE_TEMPLATE = _BASE_TEMPLATE, _SOURCE_TEMPLATE

# Every split intervenes at the single position INPUT_MAX_LEN - 3 of the padded prompt.
# The training debug output shows this is the last digit of the year.
SPLIT_TO_INV_LOCATIONS = {
    split_name: {'max_input_length': INPUT_MAX_LEN, 'inv_position': [INPUT_MAX_LEN - 3]}
    for split_name in [*split_to_dataset, BASE_TEMPLATE, SOURCE_TEMPLATE]
}

# A single task: the 'das' split, trained to produce the counterfactual ('match_source') label.
training_tasks_list = [{'das': 'match_source'}]

eval_split_to_dataset = {
    split_name: dataset
    for split_name, dataset in split_to_dataset.items()
    if split_name.endswith('-test')
}

model = model.eval()

for inv_layer in ([layer_index] for layer_index in range(10)):
  for lr in [1e-4]:
    for inv_dim in [4]:
      for training_tasks in training_tasks_list:
        config = {
            'regularization_coefficient': 0,
            'intervention_dimension': inv_dim,
            'max_output_tokens': 1,
            'intervenable_config': {
                'intervenable_layer': inv_layer,
                'intervenable_representation_type': 'block_output',
                'intervenable_unit': 'pos',
                'max_number_of_units': 1,
                'intervenable_interventions_type': LowRankRotatedSpaceIntervention,
            },
            'training_tasks': training_tasks,
            'training_epoch': 1,
            'split_to_inv_locations': SPLIT_TO_INV_LOCATIONS,
            'split_to_labels': None,
            'max_train_percentage': 1.0,
            'init_lr': lr,
        }
        intervenable = run_exp(config)

OLMo-2-0425-1B-layer0_0-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output
OLMo-2-0425-1B-layer0_0-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output
Training Tasks: {'das': 'match_source'}
Training tasks matching source label: ['das']
#Training examples: 20480


Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

base model trainable parameters:  0
intervention trainable parameters:  8192


Epoch: 0:   0%|                                                      | 2/1280 [00:00<01:51, 11.51it/s, loss=3.59, acc=0]


Tokens to intervene:
['In 2028, there', 'In 1529, there', 'In 1985, there']
['In 1900, there', 'In 2077, there', 'In 2039, there']
Base: ['8', '9', '5', '6', '3', '4', '6', '9', '7', '0', '9', '8', '7', '9', '0', '0']
Source: ['0', '7', '9', '0', '0', '4', '4', '0', '4', '5', '5', '8', '5', '4', '2', '9']
Output: [' will', ' was', ' were', ' was', ' was', ' was', ' was', ' was', ' was', ' will', ' are', ' will', ' are', ' was', ' will', ' will']
Label     : [' was', ' is', ' are', ' will', ' will', ' is', ' is', ' is', ' will', ' was', ' were', ' was', ' were', ' will', ' was', ' was']
Base Label: [' will', ' was', ' were', ' was', ' was', ' was', ' was', ' was', ' was', ' will', ' are', ' will', ' are', ' was', ' will', ' will']

Tokens to intervene:
['In 2060, there', 'In 2042, there', 'In 2030, there']
['In 1827, there', 'In 1674, there', 'In 1946, there']
Base: ['0', '2', '0', '1', '0', '0', '0', '8', '2', '7', '0', '0', '7', '9', '3', '4']
Source: ['7', '4', '6', '7', '2', '7', '

Epoch: 0:   0%|                                                      | 2/1280 [00:00<01:51, 11.51it/s, loss=3.51, acc=0]


Tokens to intervene:
['In 2024, there', 'In 1809, there', 'In 1779, there']
['In 1755, there', 'In 2050, there', 'In 2049, there']
Base: ['4', '9', '9', '0', '7', '4', '0', '9', '9', '9', '6', '5', '9', '6', '7', '0']
Source: ['5', '0', '9', '6', '8', '5', '4', '7', '4', '9', '9', '0', '4', '0', '0', '7']
Output: [' will', ' was', ' was', ' will', ' are', ' is', ' will', ' was', ' was', ' are', ' was', ' was', ' are', ' were', ' is', ' will']
Label     : [' was', ' will', ' is', ' was', ' will', ' was', ' is', ' is', ' is', ' were', ' is', ' is', ' will', ' are', ' will', ' is']
Base Label: [' will', ' was', ' was', ' will', ' are', ' is', ' will', ' was', ' was', ' are', ' was', ' was', ' are', ' were', ' is', ' will']


Epoch: 0: 100%|████████████████████████████████████████████████| 1280/1280 [01:48<00:00, 11.77it/s, loss=0.73, acc=0.81]
Epoch: 100%|█████████████████████████████████████████████████████████████████████████████| 1/1 [01:48<00:00, 108.73s/it]


Model saved to /nlp/scr/hij/year_localization/OLMo-2-0425-1B-layer0_0-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output.pt


Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Test:   0%|          | 0/84 [00:00<?, ?it/s]


 'source-In 1976, there-correct-test': {'base_labels': {'base_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.5}, 'inv_outputs': {'accuracy': 0.0, 'token_accuracy': 0.0, 'class_0_accuracy': 0.0}, 'accuracy': 0.0}, 'labels': {'base_outputs': {'accuracy': 0.1, 'token_accuracy': 0.1, 'class_0_accuracy': 0.1}, 'inv_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.7}, 'accuracy': 0.9}}

Inputs:
Base: ['In 2057, there', 'In 2060, there', 'In 2099, there']
Source: ['In 1976, there', 'In 1976, there', 'In 1976, there']
Tokens to intervene:
    Base: ['7', '0', '9', '7', '9', '4', '4', '4', '0', '3', '9', '3', '3', '3', '0', '9']
    Source: ['6', '6', '6', '6', '6', '6', '6', '6', '6', '6', '0', '0', '0', '0', '0', '0']
Outputs:
          Base Output: [' are', ' will', ' are', ' is', ' are', ' will', ' are', ' are', ' will', ' were', ' are', ' were', ' were', ' were', ' are', ' are']
Counterfactual Output: [' were', ' was', ' were', ' were', ' 

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

base model trainable parameters:  0
intervention trainable parameters:  8192


Epoch: 0:   0%|                                                      | 2/1280 [00:00<01:51, 11.48it/s, loss=3.68, acc=0]


Tokens to intervene:
['In 2099, there', 'In 1969, there', 'In 2010, there']
['In 1791, there', 'In 2400, there', 'In 2030, there']
Base: ['9', '9', '0', '4', '0', '1', '0', '8', '0', '0', '9', '1', '0', '4', '1', '3']
Source: ['1', '0', '0', '0', '1', '9', '2', '4', '2', '8', '1', '0', '8', '3', '4', '4']
Output: [' are', ' were', ' were', ' are', ' will', ' was', ' will', ' was', ' will', ' will', ' are', ' was', ' was', ' was', ' was', ' are']
Label     : [' were', ' are', ' will', ' will', ' was', ' is', ' was', ' will', ' was', ' was', ' were', ' will', ' will', ' is', ' will', ' were']
Base Label: [' are', ' were', ' were', ' are', ' will', ' was', ' will', ' was', ' will', ' will', ' are', ' was', ' was', ' was', ' was', ' are']

Tokens to intervene:
['In 2030, there', 'In 2042, there', 'In 1728, there']
['In 2057, there', 'In 2040, there', 'In 2054, there']
Base: ['0', '2', '8', '2', '5', '0', '7', '4', '4', '0', '8', '4', '0', '0', '0', '4']
Source: ['7', '0', '4', '9', '3', '

Epoch: 0:   0%|                                                      | 2/1280 [00:00<01:51, 11.48it/s, loss=3.52, acc=0]


Tokens to intervene:
['In 1639, there', 'In 1729, there', 'In 1738, there']
['In 2047, there', 'In 2042, there', 'In 2097, there']
Base: ['9', '9', '8', '2', '0', '7', '5', '7', '9', '6', '5', '8', '8', '3', '3', '7']
Source: ['7', '2', '7', '6', '8', '5', '7', '0', '2', '3', '0', '2', '2', '1', '4', '3']
Output: [' was', ' was', ' was', ' are', ' are', ' are', ' was', ' is', ' is', ' were', ' were', ' was', ' was', ' are', ' were', ' are']
Label     : [' is', ' is', ' is', ' were', ' was', ' were', ' is', ' will', ' was', ' are', ' will', ' is', ' is', ' were', ' will', ' were']
Base Label: [' was', ' was', ' was', ' are', ' will', ' are', ' was', ' is', ' is', ' were', ' were', ' was', ' was', ' are', ' were', ' are']


Epoch: 0: 100%|████████████████████████████████████████████████| 1280/1280 [01:46<00:00, 12.01it/s, loss=0.65, acc=0.82]
Epoch: 100%|█████████████████████████████████████████████████████████████████████████████| 1/1 [01:46<00:00, 106.61s/it]


Model saved to /nlp/scr/hij/year_localization/OLMo-2-0425-1B-layer1_1-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output.pt


Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Test:   0%|          | 0/84 [00:00<?, ?it/s]


 'source-In 1976, there-correct-test': {'base_labels': {'base_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.5}, 'inv_outputs': {'accuracy': 0.0, 'token_accuracy': 0.0, 'class_0_accuracy': 0.0}, 'accuracy': 0.0}, 'labels': {'base_outputs': {'accuracy': 0.1, 'token_accuracy': 0.1, 'class_0_accuracy': 0.1}, 'inv_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.7}, 'accuracy': 0.9}}

Inputs:
Base: ['In 2057, there', 'In 2060, there', 'In 2099, there']
Source: ['In 1976, there', 'In 1976, there', 'In 1976, there']
Tokens to intervene:
    Base: ['7', '0', '9', '7', '9', '4', '4', '4', '0', '3', '9', '3', '3', '3', '0', '9']
    Source: ['6', '6', '6', '6', '6', '6', '6', '6', '6', '6', '0', '0', '0', '0', '0', '0']
Outputs:
          Base Output: [' are', ' will', ' are', ' is', ' are', ' will', ' are', ' are', ' will', ' were', ' are', ' were', ' were', ' were', ' are', ' are']
Counterfactual Output: [' were', ' was', ' were', ' were', ' 

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

base model trainable parameters:  0
intervention trainable parameters:  8192


Epoch: 0:   0%|                                                   | 2/1280 [00:00<01:53, 11.28it/s, loss=3.09, acc=0.06]


Tokens to intervene:
['In 2059, there', 'In 1791, there', 'In 1763, there']
['In 1740, there', 'In 2053, there', 'In 2057, there']
Base: ['9', '1', '3', '5', '7', '7', '0', '7', '9', '0', '0', '8', '9', '0', '7', '0']
Source: ['0', '3', '7', '0', '2', '2', '2', '2', '7', '6', '0', '8', '4', '5', '9', '5']
Output: [' are', ' was', ' was', ' was', ' are', ' is', ' were', ' are', ' was', ' will', ' will', ' were', ' are', ' are', ' are', ' are']
Label     : [' were', ' is', ' is', ' will', ' were', ' was', ' are', ' were', ' is', ' was', ' is', ' will', ' will', ' were', ' were', ' were']
Base Label: [' are', ' was', ' was', ' was', ' are', ' is', ' were', ' are', ' was', ' will', ' will', ' were', ' are', ' are', ' are', ' are']

Tokens to intervene:
['In 2047, there', 'In 2099, there', 'In 1788, there']
['In 1701, there', 'In 1556, there', 'In 2049, there']
Base: ['7', '9', '8', '5', '3', '5', '9', '4', '8', '7', '6', '4', '5', '9', '8', '2']
Source: ['1', '6', '9', '4', '6', '0', '0',

Epoch: 0:   0%|                                                   | 2/1280 [00:00<01:53, 11.28it/s, loss=3.05, acc=0.04]


Tokens to intervene:
['In 1993, there', 'In 2053, there', 'In 2057, there']
['In 2053, there', 'In 2060, there', 'In 1841, there']
Base: ['3', '3', '7', '8', '7', '0', '0', '0', '0', '8', '5', '0', '0', '4', '9', '0']
Source: ['3', '0', '1', '4', '0', '6', '7', '4', '2', '6', '0', '4', '4', '0', '7', '7']
Output: [' were', ' were', ' are', ' was', ' was', ' will', ' will', ' are', ' are', ' will', ' was', ' will', ' was', ' is', ' are', ' was']
Label     : [' are', ' will', ' were', ' is', ' is', ' was', ' is', ' were', ' were', ' was', ' is', ' was', ' will', ' will', ' were', ' is']
Base Label: [' were', ' are', ' are', ' was', ' was', ' will', ' will', ' are', ' are', ' will', ' was', ' will', ' was', ' is', ' are', ' was']


Epoch: 0: 100%|████████████████████████████████████████████████| 1280/1280 [01:44<00:00, 12.26it/s, loss=0.61, acc=0.83]
Epoch: 100%|█████████████████████████████████████████████████████████████████████████████| 1/1 [01:44<00:00, 104.42s/it]


Model saved to /nlp/scr/hij/year_localization/OLMo-2-0425-1B-layer2_2-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output.pt


Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Test:   0%|          | 0/84 [00:00<?, ?it/s]


 'source-In 1976, there-correct-test': {'base_labels': {'base_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.5}, 'inv_outputs': {'accuracy': 0.0, 'token_accuracy': 0.0, 'class_0_accuracy': 0.0}, 'accuracy': 0.0}, 'labels': {'base_outputs': {'accuracy': 0.1, 'token_accuracy': 0.1, 'class_0_accuracy': 0.1}, 'inv_outputs': {'accuracy': 1.0, 'token_accuracy': 1.0, 'class_0_accuracy': 0.6}, 'accuracy': 1.0}}

Inputs:
Base: ['In 2057, there', 'In 2060, there', 'In 2099, there']
Source: ['In 1976, there', 'In 1976, there', 'In 1976, there']
Tokens to intervene:
    Base: ['7', '0', '9', '7', '9', '4', '4', '4', '0', '3', '9', '3', '3', '3', '0', '9']
    Source: ['6', '6', '6', '6', '6', '6', '6', '6', '6', '6', '0', '0', '0', '0', '0', '0']
Outputs:
          Base Output: [' are', ' will', ' are', ' is', ' are', ' will', ' are', ' are', ' will', ' were', ' are', ' were', ' were', ' were', ' are', ' are']
Counterfactual Output: [' were', ' was', ' were', ' was', ' w

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

base model trainable parameters:  0
intervention trainable parameters:  8192


Epoch: 0:   0%|                                                   | 2/1280 [00:00<01:58, 10.82it/s, loss=3.04, acc=0.03]


Tokens to intervene:
['In 2400, there', 'In 1990, there', 'In 2050, there']
['In 1887, there', 'In 2060, there', 'In 2300, there']
Base: ['0', '0', '0', '7', '2', '2', '7', '1', '0', '2', '4', '0', '9', '3', '9', '0']
Source: ['7', '0', '0', '4', '3', '3', '0', '0', '9', '2', '0', '9', '4', '4', '9', '5']
Output: [' are', ' were', ' will', ' was', ' are', ' is', ' are', ' was', ' will', ' was', ' was', ' are', ' are', ' were', ' was', ' will']
Label     : [' were', ' will', ' is', ' is', ' were', ' were', ' will', ' will', ' is', ' is', ' is', ' were', ' will', ' were', ' is', ' was']
Base Label: [' are', ' were', ' will', ' was', ' are', ' are', ' are', ' was', ' will', ' was', ' was', ' are', ' are', ' are', ' was', ' will']

Tokens to intervene:
['In 2028, there', 'In 2044, there', 'In 2028, there']
['In 2099, there', 'In 1744, there', 'In 2049, there']
Base: ['8', '4', '8', '4', '9', '5', '0', '9', '2', '8', '7', '9', '4', '3', '9', '9']
Source: ['9', '4', '9', '7', '2', '3', '9',

Epoch: 0:   0%|                                                   | 2/1280 [00:00<01:58, 10.82it/s, loss=3.08, acc=0.04]


Tokens to intervene:
['In 2030, there', 'In 2097, there', 'In 2059, there']
['In 2054, there', 'In 2050, there', 'In 1710, there']
Base: ['0', '7', '9', '7', '2', '9', '8', '3', '3', '9', '9', '3', '5', '7', '2', '4']
Source: ['4', '0', '0', '9', '0', '8', '4', '2', '7', '0', '9', '2', '9', '8', '0', '8']
Output: [' will', ' are', ' are', ' was', ' was', ' is', ' was', ' were', ' were', ' was', ' are', ' were', ' was', ' are', ' was', ' is']
Label     : [' is', ' will', ' were', ' is', ' is', ' was', ' is', ' are', ' are', ' will', ' were', ' were', ' is', ' will', ' is', ' will']
Base Label: [' will', ' are', ' are', ' was', ' was', ' is', ' was', ' were', ' were', ' was', ' are', ' are', ' was', ' are', ' was', ' is']


Epoch: 0: 100%|████████████████████████████████████████████████| 1280/1280 [01:42<00:00, 12.52it/s, loss=0.56, acc=0.87]
Epoch: 100%|█████████████████████████████████████████████████████████████████████████████| 1/1 [01:42<00:00, 102.22s/it]


Model saved to /nlp/scr/hij/year_localization/OLMo-2-0425-1B-layer3_3-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output.pt


Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Test:   0%|          | 0/84 [00:00<?, ?it/s]


 'source-In 1976, there-correct-test': {'base_labels': {'base_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.5}, 'inv_outputs': {'accuracy': 0.0, 'token_accuracy': 0.0, 'class_0_accuracy': 0.0}, 'accuracy': 0.0}, 'labels': {'base_outputs': {'accuracy': 0.1, 'token_accuracy': 0.1, 'class_0_accuracy': 0.1}, 'inv_outputs': {'accuracy': 1.0, 'token_accuracy': 1.0, 'class_0_accuracy': 0.6}, 'accuracy': 1.0}}

Inputs:
Base: ['In 2057, there', 'In 2060, there', 'In 2099, there']
Source: ['In 1976, there', 'In 1976, there', 'In 1976, there']
Tokens to intervene:
    Base: ['7', '0', '9', '7', '9', '4', '4', '4', '0', '3', '9', '3', '3', '3', '0', '9']
    Source: ['6', '6', '6', '6', '6', '6', '6', '6', '6', '6', '0', '0', '0', '0', '0', '0']
Outputs:
          Base Output: [' are', ' will', ' are', ' is', ' are', ' will', ' are', ' are', ' will', ' were', ' are', ' were', ' were', ' were', ' are', ' are']
Counterfactual Output: [' were', ' was', ' were', ' was', ' w

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

base model trainable parameters:  0
intervention trainable parameters:  8192


Epoch: 0:   0%|                                                      | 2/1280 [00:00<01:51, 11.44it/s, loss=3.93, acc=0]


Tokens to intervene:
['In 1767, there', 'In 2059, there', 'In 1773, there']
['In 2060, there', 'In 1744, there', 'In 2044, there']
Base: ['7', '9', '3', '0', '0', '2', '2', '7', '0', '8', '4', '0', '8', '4', '2', '4']
Source: ['0', '4', '4', '2', '9', '4', '0', '9', '0', '9', '6', '0', '5', '0', '4', '6']
Output: [' was', ' are', ' was', ' are', ' were', ' were', ' was', ' was', ' was', ' will', ' will', ' were', ' will', ' was', ' was', ' are']
Label     : [' will', ' were', ' is', ' were', ' are', ' will', ' will', ' is', ' will', ' is', ' was', ' will', ' was', ' is', ' is', ' were']
Base Label: [' was', ' are', ' was', ' are', ' were', ' were', ' was', ' was', ' was', ' will', ' will', ' were', ' will', ' was', ' was', ' are']

Tokens to intervene:
['In 1801, there', 'In 2039, there', 'In 1640, there']
['In 2077, there', 'In 2060, there', 'In 2028, there']
Base: ['1', '9', '0', '8', '0', '9', '0', '0', '3', '5', '1', '6', '3', '4', '4', '8']
Source: ['7', '0', '8', '9', '6', '4', 

Epoch: 0:   0%|                                                      | 2/1280 [00:00<01:51, 11.44it/s, loss=3.73, acc=0]


Tokens to intervene:
['In 2016, there', 'In 1616, there', 'In 2024, there']
['In 2097, there', 'In 2042, there', 'In 1634, there']
Base: ['6', '6', '4', '1', '7', '0', '4', '7', '7', '4', '4', '0', '0', '4', '7', '0']
Source: ['7', '2', '4', '0', '8', '0', '2', '4', '9', '4', '9', '6', '0', '7', '0', '0']
Output: [' were', ' was', ' will', ' was', ' is', ' will', ' are', ' are', ' was', ' will', ' was', ' will', ' are', ' was', ' was', ' are']
Label     : [' are', ' is', ' was', ' will', ' was', ' was', ' were', ' will', ' is', ' is', ' is', ' was', ' will', ' is', ' will', ' will']
Base Label: [' were', ' was', ' will', ' was', ' is', ' will', ' are', ' are', ' was', ' will', ' was', ' will', ' are', ' was', ' was', ' are']


Epoch: 0: 100%|████████████████████████████████████████████████| 1280/1280 [01:39<00:00, 12.83it/s, loss=0.53, acc=0.86]
Epoch: 100%|██████████████████████████████████████████████████████████████████████████████| 1/1 [01:39<00:00, 99.78s/it]


Model saved to /nlp/scr/hij/year_localization/OLMo-2-0425-1B-layer4_4-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output.pt


Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Test:   0%|          | 0/84 [00:00<?, ?it/s]


 'source-In 1976, there-correct-test': {'base_labels': {'base_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.5}, 'inv_outputs': {'accuracy': 0.0, 'token_accuracy': 0.0, 'class_0_accuracy': 0.0}, 'accuracy': 0.0}, 'labels': {'base_outputs': {'accuracy': 0.1, 'token_accuracy': 0.1, 'class_0_accuracy': 0.1}, 'inv_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.7}, 'accuracy': 0.9}}

Inputs:
Base: ['In 2057, there', 'In 2060, there', 'In 2099, there']
Source: ['In 1976, there', 'In 1976, there', 'In 1976, there']
Tokens to intervene:
    Base: ['7', '0', '9', '7', '9', '4', '4', '4', '0', '3', '9', '3', '3', '3', '0', '9']
    Source: ['6', '6', '6', '6', '6', '6', '6', '6', '6', '6', '0', '0', '0', '0', '0', '0']
Outputs:
          Base Output: [' are', ' will', ' are', ' is', ' are', ' will', ' are', ' are', ' will', ' were', ' are', ' were', ' were', ' were', ' are', ' are']
Counterfactual Output: [' were', ' was', ' were', ' were', ' 

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

base model trainable parameters:  0
intervention trainable parameters:  8192


Epoch: 0:   0%|                                                   | 2/1280 [00:00<01:50, 11.58it/s, loss=2.86, acc=0.06]


Tokens to intervene:
['In 1820, there', 'In 2028, there', 'In 1743, there']
['In 2057, there', 'In 2057, there', 'In 2099, there']
Base: ['0', '8', '3', '7', '7', '0', '9', '2', '4', '2', '6', '7', '5', '5', '4', '0']
Source: ['7', '7', '9', '5', '4', '3', '3', '5', '1', '0', '3', '0', '3', '4', '0', '9']
Output: [' were', ' will', ' was', ' are', ' is', ' will', ' are', ' are', ' are', ' was', ' were', ' are', ' was', ' was', ' is', ' were']
Label     : [' are', ' is', ' is', ' were', ' was', ' is', ' were', ' were', ' were', ' is', ' are', ' were', ' is', ' will', ' will', ' are']
Base Label: [' were', ' will', ' was', ' are', ' is', ' will', ' are', ' are', ' are', ' was', ' were', ' are', ' was', ' was', ' is', ' were']

Tokens to intervene:
['In 2042, there', 'In 2024, there', 'In 2050, there']
['In 1629, there', 'In 1996, there', 'In 1919, there']
Base: ['2', '4', '0', '6', '0', '0', '3', '3', '2', '0', '0', '9', '0', '8', '6', '0']
Source: ['9', '6', '9', '0', '5', '2', '2', '7

Epoch: 0:   0%|                                                   | 2/1280 [00:00<01:50, 11.58it/s, loss=3.35, acc=0.04]


Tokens to intervene:
['In 2400, there', 'In 1916, there', 'In 2057, there']
['In 1521, there', 'In 2400, there', 'In 1712, there']
Base: ['0', '6', '7', '1', '4', '2', '3', '7', '8', '0', '6', '7', '7', '7', '6', '2']
Source: ['1', '0', '2', '8', '7', '4', '0', '7', '0', '0', '0', '3', '2', '0', '0', '5']
Output: [' are', ' was', ' are', ' was', ' was', ' are', ' was', ' was', ' was', ' was', ' was', ' is', ' was', ' is', ' were', ' are']
Label     : [' were', ' is', ' were', ' will', ' is', ' were', ' will', ' is', ' will', ' will', ' will', ' was', ' were', ' will', ' will', ' were']
Base Label: [' are', ' was', ' are', ' was', ' was', ' are', ' was', ' was', ' was', ' was', ' was', ' is', ' are', ' is', ' were', ' are']


Epoch: 0: 100%|████████████████████████████████████████████████| 1280/1280 [01:38<00:00, 13.03it/s, loss=0.52, acc=0.87]
Epoch: 100%|██████████████████████████████████████████████████████████████████████████████| 1/1 [01:38<00:00, 98.21s/it]


Model saved to /nlp/scr/hij/year_localization/OLMo-2-0425-1B-layer5_5-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output.pt


Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Test:   0%|          | 0/84 [00:00<?, ?it/s]


 'source-In 1976, there-correct-test': {'base_labels': {'base_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.5}, 'inv_outputs': {'accuracy': 0.0, 'token_accuracy': 0.0, 'class_0_accuracy': 0.0}, 'accuracy': 0.0}, 'labels': {'base_outputs': {'accuracy': 0.1, 'token_accuracy': 0.1, 'class_0_accuracy': 0.1}, 'inv_outputs': {'accuracy': 1.0, 'token_accuracy': 1.0, 'class_0_accuracy': 0.6}, 'accuracy': 1.0}}

Inputs:
Base: ['In 2057, there', 'In 2060, there', 'In 2099, there']
Source: ['In 1976, there', 'In 1976, there', 'In 1976, there']
Tokens to intervene:
    Base: ['7', '0', '9', '7', '9', '4', '4', '4', '0', '3', '9', '3', '3', '3', '0', '9']
    Source: ['6', '6', '6', '6', '6', '6', '6', '6', '6', '6', '0', '0', '0', '0', '0', '0']
Outputs:
          Base Output: [' are', ' will', ' are', ' is', ' are', ' will', ' are', ' are', ' will', ' were', ' are', ' were', ' were', ' were', ' are', ' are']
Counterfactual Output: [' were', ' was', ' were', ' was', ' w

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

base model trainable parameters:  0
intervention trainable parameters:  8192


Epoch: 0:   0%|                                                   | 2/1280 [00:00<01:48, 11.74it/s, loss=3.09, acc=0.03]


Tokens to intervene:
['In 2400, there', 'In 1598, there', 'In 2024, there']
['In 1809, there', 'In 2400, there', 'In 2009, there']
Base: ['0', '8', '4', '8', '6', '9', '0', '7', '1', '9', '6', '3', '9', '6', '8', '9']
Source: ['9', '0', '9', '9', '0', '8', '0', '0', '8', '2', '0', '7', '7', '2', '3', '4']
Output: [' are', ' was', ' will', ' was', ' was', ' are', ' were', ' is', ' was', ' are', ' was', ' was', ' was', ' was', ' will', ' was']
Label     : [' were', ' is', ' was', ' is', ' is', ' were', ' will', ' was', ' will', ' were', ' is', ' is', ' is', ' is', ' was', ' is']
Base Label: [' are', ' was', ' will', ' was', ' was', ' are', ' were', ' is', ' was', ' are', ' was', ' was', ' was', ' was', ' will', ' was']

Tokens to intervene:
['In 2053, there', 'In 1722, there', 'In 2400, there']
['In 1915, there', 'In 2300, there', 'In 1655, there']
Base: ['3', '2', '0', '5', '6', '4', '2', '7', '7', '5', '8', '0', '9', '7', '9', '9']
Source: ['5', '0', '5', '0', '3', '9', '1', '4', '5',

Epoch: 0:   0%|                                                   | 2/1280 [00:00<01:48, 11.74it/s, loss=3.14, acc=0.02]


Tokens to intervene:
['In 1642, there', 'In 2044, there', 'In 2042, there']
['In 2030, there', 'In 1996, there', 'In 1968, there']
Base: ['2', '4', '2', '4', '7', '0', '2', '0', '2', '4', '9', '4', '9', '0', '0', '7']
Source: ['0', '6', '8', '8', '9', '9', '7', '0', '2', '0', '3', '1', '7', '6', '0', '8']
Output: [' was', ' is', ' are', ' will', ' was', ' was', ' was', ' are', ' was', ' will', ' was', ' are', ' was', ' will', ' will', ' are']
Label     : [' will', ' was', ' were', ' was', ' is', ' is', ' is', ' were', ' is', ' is', ' is', ' were', ' is', ' was', ' is', ' will']
Base Label: [' was', ' is', ' are', ' will', ' was', ' was', ' was', ' are', ' was', ' will', ' was', ' are', ' was', ' will', ' will', ' are']


Epoch: 0: 100%|████████████████████████████████████████████████| 1280/1280 [01:34<00:00, 13.53it/s, loss=0.54, acc=0.85]
Epoch: 100%|██████████████████████████████████████████████████████████████████████████████| 1/1 [01:34<00:00, 94.60s/it]


Model saved to /nlp/scr/hij/year_localization/OLMo-2-0425-1B-layer6_6-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output.pt


Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Test:   0%|          | 0/84 [00:00<?, ?it/s]


 'source-In 1976, there-correct-test': {'base_labels': {'base_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.5}, 'inv_outputs': {'accuracy': 0.0, 'token_accuracy': 0.0, 'class_0_accuracy': 0.0}, 'accuracy': 0.0}, 'labels': {'base_outputs': {'accuracy': 0.1, 'token_accuracy': 0.1, 'class_0_accuracy': 0.1}, 'inv_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.7}, 'accuracy': 0.9}}

Inputs:
Base: ['In 2057, there', 'In 2060, there', 'In 2099, there']
Source: ['In 1976, there', 'In 1976, there', 'In 1976, there']
Tokens to intervene:
    Base: ['7', '0', '9', '7', '9', '4', '4', '4', '0', '3', '9', '3', '3', '3', '0', '9']
    Source: ['6', '6', '6', '6', '6', '6', '6', '6', '6', '6', '0', '0', '0', '0', '0', '0']
Outputs:
          Base Output: [' are', ' will', ' are', ' is', ' are', ' will', ' are', ' are', ' will', ' were', ' are', ' were', ' were', ' were', ' are', ' are']
Counterfactual Output: [' were', ' was', ' were', ' were', ' 

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

base model trainable parameters:  0
intervention trainable parameters:  8192


Epoch: 0:   0%|                                                      | 2/1280 [00:00<01:42, 12.51it/s, loss=4.23, acc=0]


Tokens to intervene:
['In 1868, there', 'In 2040, there', 'In 1869, there']
['In 2053, there', 'In 2099, there', 'In 2097, there']
Base: ['8', '0', '9', '2', '2', '9', '3', '2', '1', '8', '0', '6', '1', '4', '7', '9']
Source: ['3', '9', '7', '7', '0', '0', '0', '0', '0', '4', '7', '0', '0', '2', '0', '0']
Output: [' was', ' will', ' was', ' was', ' were', ' are', ' were', ' was', ' was', ' will', ' were', ' was', ' was', ' was', ' was', ' are']
Label     : [' is', ' is', ' is', ' is', ' are', ' were', ' will', ' is', ' will', ' is', ' are', ' will', ' will', ' is', ' will', ' will']
Base Label: [' was', ' will', ' was', ' was', ' were', ' are', ' were', ' was', ' was', ' will', ' were', ' was', ' was', ' was', ' was', ' are']

Tokens to intervene:
['In 1940, there', 'In 1673, there', 'In 1586, there']
['In 2044, there', 'In 2047, there', 'In 2042, there']
Base: ['0', '3', '6', '4', '7', '3', '3', '0', '4', '2', '4', '2', '9', '9', '1', '7']
Source: ['4', '7', '2', '0', '8', '9', '7', 

Epoch: 0:   0%|                                                      | 2/1280 [00:00<01:42, 12.51it/s, loss=3.91, acc=0]


Tokens to intervene:
['In 1614, there', 'In 2028, there', 'In 2050, there']
['In 2039, there', 'In 1818, there', 'In 1993, there']
Base: ['4', '8', '0', '7', '9', '6', '8', '4', '4', '9', '5', '0', '7', '6', '9', '1']
Source: ['9', '8', '3', '9', '2', '4', '0', '9', '0', '5', '7', '4', '3', '7', '9', '2']
Output: [' was', ' will', ' will', ' was', ' are', ' was', ' will', ' are', ' were', ' are', ' was', ' were', ' are', ' were', ' was', ' were']
Label     : [' is', ' was', ' was', ' is', ' were', ' is', ' was', ' were', ' are', ' were', ' is', ' will', ' were', ' are', ' is', ' are']
Base Label: [' was', ' will', ' will', ' was', ' are', ' was', ' will', ' are', ' were', ' are', ' was', ' were', ' are', ' were', ' was', ' were']


Epoch: 0: 100%|████████████████████████████████████████████████| 1280/1280 [01:32<00:00, 13.91it/s, loss=0.62, acc=0.81]
Epoch: 100%|██████████████████████████████████████████████████████████████████████████████| 1/1 [01:32<00:00, 92.02s/it]


Model saved to /nlp/scr/hij/year_localization/OLMo-2-0425-1B-layer7_7-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output.pt


Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Test:   0%|          | 0/84 [00:00<?, ?it/s]


 'source-In 1976, there-correct-test': {'base_labels': {'base_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.5}, 'inv_outputs': {'accuracy': 0.0, 'token_accuracy': 0.0, 'class_0_accuracy': 0.0}, 'accuracy': 0.0}, 'labels': {'base_outputs': {'accuracy': 0.1, 'token_accuracy': 0.1, 'class_0_accuracy': 0.1}, 'inv_outputs': {'accuracy': 1.0, 'token_accuracy': 1.0, 'class_0_accuracy': 0.6}, 'accuracy': 1.0}}

Inputs:
Base: ['In 2057, there', 'In 2060, there', 'In 2099, there']
Source: ['In 1976, there', 'In 1976, there', 'In 1976, there']
Tokens to intervene:
    Base: ['7', '0', '9', '7', '9', '4', '4', '4', '0', '3', '9', '3', '3', '3', '0', '9']
    Source: ['6', '6', '6', '6', '6', '6', '6', '6', '6', '6', '0', '0', '0', '0', '0', '0']
Outputs:
          Base Output: [' are', ' will', ' are', ' is', ' are', ' will', ' are', ' are', ' will', ' were', ' are', ' were', ' were', ' were', ' are', ' are']
Counterfactual Output: [' were', ' was', ' were', ' was', ' w

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

base model trainable parameters:  0
intervention trainable parameters:  8192


Epoch: 0:   0%|                                                    | 2/1280 [00:00<01:38, 12.97it/s, loss=2.7, acc=0.06]


Tokens to intervene:
['In 2044, there', 'In 2057, there', 'In 2059, there']
['In 2030, there', 'In 1850, there', 'In 1550, there']
Base: ['4', '7', '9', '5', '9', '0', '6', '7', '0', '3', '0', '3', '0', '0', '5', '3']
Source: ['0', '0', '0', '7', '3', '0', '0', '8', '3', '5', '0', '0', '0', '0', '9', '0']
Output: [' is', ' are', ' are', ' was', ' are', ' are', ' was', ' is', ' are', ' were', ' are', ' were', ' will', ' are', ' was', ' was']
Label     : [' will', ' were', ' were', ' is', ' were', ' were', ' will', ' will', ' were', ' were', ' were', ' were', ' was', ' will', ' is', ' will']
Base Label: [' is', ' are', ' are', ' was', ' are', ' are', ' was', ' is', ' are', ' are', ' are', ' are', ' will', ' are', ' was', ' was']

Tokens to intervene:
['In 2039, there', 'In 2040, there', 'In 2047, there']
['In 1630, there', 'In 2053, there', 'In 1570, there']
Base: ['9', '0', '7', '5', '0', '0', '0', '0', '0', '9', '7', '2', '8', '9', '0', '0']
Source: ['0', '3', '0', '7', '2', '7', '7',

Epoch: 0:   0%|                                                   | 2/1280 [00:00<01:38, 12.97it/s, loss=3.39, acc=0.06]


Tokens to intervene:
['In 1617, there', 'In 1763, there', 'In 2060, there']
['In 2042, there', 'In 2030, there', 'In 1812, there']
Base: ['7', '3', '0', '6', '3', '5', '9', '7', '7', '7', '0', '9', '7', '6', '9', '8']
Source: ['2', '0', '2', '0', '7', '2', '4', '5', '0', '2', '0', '0', '0', '0', '8', '8']
Output: [' was', ' was', ' will', ' was', ' were', ' was', ' was', ' is', ' are', ' are', ' are', ' was', ' was', ' was', ' were', ' was']
Label     : [' is', ' will', ' was', ' will', ' were', ' is', ' is', ' was', ' will', ' were', ' will', ' will', ' will', ' will', ' will', ' will']
Base Label: [' was', ' was', ' will', ' was', ' are', ' was', ' was', ' is', ' are', ' are', ' are', ' was', ' was', ' was', ' was', ' was']


Epoch: 0: 100%|████████████████████████████████████████████████| 1280/1280 [01:29<00:00, 14.27it/s, loss=0.72, acc=0.74]
Epoch: 100%|██████████████████████████████████████████████████████████████████████████████| 1/1 [01:29<00:00, 89.71s/it]


Model saved to /nlp/scr/hij/year_localization/OLMo-2-0425-1B-layer8_8-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output.pt


Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Test:   0%|          | 0/84 [00:00<?, ?it/s]


 'source-In 1976, there-correct-test': {'base_labels': {'base_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.5}, 'inv_outputs': {'accuracy': 0.0, 'token_accuracy': 0.0, 'class_0_accuracy': 0.0}, 'accuracy': 0.0}, 'labels': {'base_outputs': {'accuracy': 0.1, 'token_accuracy': 0.1, 'class_0_accuracy': 0.1}, 'inv_outputs': {'accuracy': 0.8, 'token_accuracy': 0.8, 'class_0_accuracy': 0.8}, 'accuracy': 0.8}}

Inputs:
Base: ['In 2057, there', 'In 2060, there', 'In 2099, there']
Source: ['In 1976, there', 'In 1976, there', 'In 1976, there']
Tokens to intervene:
    Base: ['7', '0', '9', '7', '9', '4', '4', '4', '0', '3', '9', '3', '3', '3', '0', '9']
    Source: ['6', '6', '6', '6', '6', '6', '6', '6', '6', '6', '0', '0', '0', '0', '0', '0']
Outputs:
          Base Output: [' are', ' will', ' are', ' is', ' are', ' will', ' are', ' are', ' will', ' were', ' are', ' were', ' were', ' were', ' are', ' are']
Counterfactual Output: [' were', ' was', ' were', ' were', ' 

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

Map:   0%|          | 0/20480 [00:00<?, ? examples/s]

base model trainable parameters:  0
intervention trainable parameters:  8192


Epoch: 0:   0%|                                                      | 2/1280 [00:00<01:36, 13.22it/s, loss=3.31, acc=0]


Tokens to intervene:
['In 1543, there', 'In 2039, there', 'In 1746, there']
['In 2099, there', 'In 1553, there', 'In 2047, there']
Base: ['3', '9', '6', '2', '2', '9', '3', '8', '0', '6', '6', '7', '5', '2', '9', '4']
Source: ['9', '3', '7', '9', '4', '6', '0', '9', '8', '0', '7', '9', '9', '4', '6', '3']
Output: [' was', ' are', ' was', ' was', ' was', ' are', ' were', ' will', ' are', ' was', ' were', ' are', ' were', ' were', ' are', ' are']
Label     : [' is', ' were', ' is', ' is', ' will', ' were', ' will', ' is', ' were', ' will', ' are', ' were', ' are', ' are', ' were', ' were']
Base Label: [' was', ' are', ' was', ' was', ' was', ' are', ' are', ' will', ' are', ' was', ' were', ' are', ' were', ' were', ' are', ' are']

Tokens to intervene:
['In 2042, there', 'In 2054, there', 'In 2097, there']
['In 1763, there', 'In 1799, there', 'In 1740, there']
Base: ['2', '4', '7', '0', '2', '6', '4', '0', '0', '2', '0', '2', '0', '1', '0', '0']
Source: ['3', '9', '0', '9', '4', '9', '

Epoch: 0:   0%|                                                   | 2/1280 [00:00<01:36, 13.22it/s, loss=3.34, acc=0.02]


Tokens to intervene:
['In 1818, there', 'In 1818, there', 'In 2049, there']
['In 2054, there', 'In 2049, there', 'In 2028, there']
Base: ['8', '8', '9', '1', '2', '7', '2', '0', '0', '3', '2', '0', '8', '7', '4', '0']
Source: ['4', '9', '8', '4', '5', '6', '4', '8', '0', '5', '7', '4', '7', '0', '0', '8']
Output: [' was', ' was', ' is', ' was', ' are', ' are', ' was', ' are', ' will', ' were', ' was', ' will', ' will', ' is', ' was', ' are']
Label     : [' is', ' is', ' will', ' is', ' were', ' were', ' will', ' were', ' was', ' were', ' is', ' was', ' is', ' will', ' will', ' were']
Base Label: [' was', ' was', ' is', ' was', ' are', ' are', ' was', ' are', ' will', ' are', ' was', ' will', ' will', ' is', ' was', ' are']


Epoch: 0: 100%|████████████████████████████████████████████████| 1280/1280 [01:27<00:00, 14.67it/s, loss=0.82, acc=0.71]
Epoch: 100%|██████████████████████████████████████████████████████████████████████████████| 1/1 [01:27<00:00, 87.24s/it]


Model saved to /nlp/scr/hij/year_localization/OLMo-2-0425-1B-layer9_9-dim4-daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456_block_output.pt


Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Map:   0%|          | 0/1330 [00:00<?, ? examples/s]

Test:   0%|          | 0/84 [00:00<?, ?it/s]


 'source-In 1976, there-correct-test': {'base_labels': {'base_outputs': {'accuracy': 0.9, 'token_accuracy': 0.9, 'class_0_accuracy': 0.5}, 'inv_outputs': {'accuracy': 0.0, 'token_accuracy': 0.0, 'class_0_accuracy': 0.0}, 'accuracy': 0.0}, 'labels': {'base_outputs': {'accuracy': 0.1, 'token_accuracy': 0.1, 'class_0_accuracy': 0.1}, 'inv_outputs': {'accuracy': 0.8, 'token_accuracy': 0.8, 'class_0_accuracy': 0.8}, 'accuracy': 0.8}}

Inputs:
Base: ['In 2057, there', 'In 2060, there', 'In 2099, there']
Source: ['In 1976, there', 'In 1976, there', 'In 1976, there']
Tokens to intervene:
    Base: ['7', '0', '9', '7', '9', '4', '4', '4', '0', '3', '9', '3', '3', '3', '0', '9']
    Source: ['6', '6', '6', '6', '6', '6', '6', '6', '6', '6', '0', '0', '0', '0', '0', '0']
Outputs:
          Base Output: [' are', ' will', ' are', ' is', ' are', ' will', ' are', ' are', ' will', ' were', ' are', ' were', ' were', ' were', ' are', ' are']
Counterfactual Output: [' were', ' was', ' were', ' were', ' 

In [None]:
# Accuracy
# Mean interchange-intervention accuracy (IIA) per layer: for each trained
# intervention, average labels/inv_outputs accuracy over the eval splits whose
# name contains '-correct', as stored in the run's *_evalall.json file.
import json

EVAL_DIR = 'year_localization'
RUN_SUFFIX = ('daslora_baseline_1tok_das-mmlu_id-1_len8_pose_ep1_example456'
              '_block_output_evalall')
LAYERS = range(10)
INV_DIMS = [4]


def eval_metrics_path(layer, dim):
  """Return the eval-metrics JSON path for the run at `layer` with intervention `dim`."""
  return f'{EVAL_DIR}/OLMo-2-0425-1B-layer{layer}_{layer}-dim{dim}-{RUN_SUFFIX}.json'


def mean_correct_iia(split_to_eval_metrics):
  """Mean labels/inv_outputs accuracy over splits whose key contains '-correct'.

  Args:
    split_to_eval_metrics: dict mapping split name -> {'metrics': {...}} as
      loaded from an *_evalall.json file.

  Returns:
    The mean accuracy as a Python float, or nan if no '-correct' split exists
    (avoids numpy's empty-mean RuntimeWarning).
  """
  accs = [v['metrics']['labels']['inv_outputs']['accuracy']
          for k, v in split_to_eval_metrics.items() if '-correct' in k]
  if not accs:
    return float('nan')
  return float(np.mean(accs))


# Derive the header from INV_DIMS so it cannot drift from the columns printed.
print('inv dim: ' + ', '.join(map(str, INV_DIMS)))
for layer in LAYERS:
  iia = []
  for dim in INV_DIMS:
    # Context manager closes each file instead of leaking one handle per run.
    with open(eval_metrics_path(layer, dim)) as f:
      split_to_eval_metrics = json.load(f)
    iia.append(mean_correct_iia(split_to_eval_metrics))
  print(f'layer={layer}\t' + '\t'.join(f'{x:.2f}' for x in iia))

inv dim: 4
layer=0	0.88
layer=1	0.89
layer=2	0.93
layer=3	0.95
layer=4	0.99
layer=5	0.96
layer=6	0.95
layer=7	0.93
layer=8	0.92
layer=9	0.82
