# Setup

In [6]:
# Automatically reload edited Python modules (e.g. local helpers imported below) before each cell runs.
%load_ext autoreload
%autoreload 2

In [7]:
#################################
### CONFIGURE BEFORE YOUR RUN ###
#################################

# Cluster username; the /nlp/scr/<USER>/... paths in the setup cell are built from it.
USER = "aditijb"

# Layer / dimension settings recorded in the run metadata (presumably the
# intervention layers and dims — confirm). Full sweep used previously: list(range(9)).
INV_LAYERS = list(range(3))
INV_DIMS = [8]

# Checkpoint to probe. Alternatives tried previously:
#   model_id, revision = "suzeva/olmo2-1b-4xH100-2ndtry", "step-10000"
#   revision = "stage1-step10000-tokens21B"
model_id = "allenai/OLMo-2-0425-1B"
revision = "main"

# Short tag appended to the experiment name so different runs don't collide.
run_description = 'logits'
EXPERIMENT_NAME = "_".join(
    ["year_localization", model_id.replace("/", "_"), revision, run_description]
)

# Prompt template; `{year}` is filled in per year in the behavioral-testing cell.
prompt_template = ['In {year} there']


In [8]:
import getpass
import os
import sys

# updated by aditi
# Output locations for this experiment (EXPERIMENT_NAME comes from the config cell).
FOLDER = f'olmo_das_2/{EXPERIMENT_NAME}'
# NOTE(review): this is relative to the working directory and is NOT nested
# under FOLDER — confirm that is intended before changing it.
META_DATA_FOLDER = EXPERIMENT_NAME + "/metadata"

PROJECT_NAME = 'ood-prediction'
DATA_DIR = f'/nlp/scr/{USER}/data'
MODEL_DIR = f'/nlp/scr/{USER}/models'

# Make the project's src/ importable (presumably where generation_utils lives — confirm).
sys.path.append(f'/nlp/scr/{USER}/{PROJECT_NAME}/src')

# Point the Hugging Face cache root at MODEL_DIR; this runs before transformers
# is imported in the model-loading cell.
os.environ["HF_HOME"] = MODEL_DIR
# NOTE(review): huggingface_hub reads HF_HUB_CACHE, not HF_HUB, so this line is
# likely a no-op. It is kept for compatibility; the from_pretrained calls pass
# cache_dir explicitly anyway.
os.environ["HF_HUB"] = MODEL_DIR

# Shared libraries checked out under /nlp/scr/hij. These are plain strings
# because there is nothing to interpolate; `sys` is already imported above.
CORE_LIB_DIR = '/nlp/scr/hij/core'
RAVEL_LIB_DIR = '/nlp/scr/hij/internal-ravel/src'
PYVENE_LIB_DIR = '/nlp/scr/hij/pyvene'
sys.path.extend([CORE_LIB_DIR, RAVEL_LIB_DIR, PYVENE_LIB_DIR])

In [9]:
import numpy as np
import random
import torch

def set_seed(seed):
    """Seed every RNG the notebook relies on with the same value.

    Seeds, in order: Python's `random`, NumPy's global RNG, torch's CPU RNG,
    and torch's RNG on all CUDA devices.

    Args:
        seed: integer seed passed unchanged to each generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


set_seed(0)

In [10]:
import torch
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Models

In [11]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# SET THE MODEL INFO IN THE TOP CONFIG BLOCK
print("model_id", model_id)
print("revision", revision)

# Left padding keeps each prompt's final real token at the last position of the
# batch, which batched next-token scoring presumably relies on — confirm in generation_utils.
tokenizer = AutoTokenizer.from_pretrained(
    model_id, padding_side='left', revision=revision,
    cache_dir=MODEL_DIR)
# Reuse the EOS id as the padding id for batched inputs.
tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    model_id, low_cpu_mem_usage=True, device_map='auto',
    revision=revision,
    # `dtype` replaces `torch_dtype`, which recent transformers releases deprecate.
    dtype=torch.bfloat16, cache_dir=MODEL_DIR)
# Inference only: switch off training-mode behavior such as dropout.
model = model.eval()

model_id allenai/OLMo-2-0425-1B
revision main


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Behavioral Testing

In [12]:
# Added by Aditi: save the run configuration (and, in later cells, the per-year outputs) to the metadata folder for future reference.


import os, json

# make sure folder exists
# Make sure the metadata folder exists before anything is written into it.
os.makedirs(META_DATA_FOLDER, exist_ok=True)
print("metadata folder", META_DATA_FOLDER)

# Snapshot of the run configuration. EXPERIMENT_NAME and run_description are
# included so the saved file fully identifies the run that produced it.
config_dict = {
    "prompt_template": prompt_template,
    "FOLDER": FOLDER,
    "META_DATA_FOLDER": META_DATA_FOLDER,
    "PROJECT_NAME": PROJECT_NAME,
    "DATA_DIR": DATA_DIR,
    "MODEL_DIR": MODEL_DIR,
    "model_id": model_id,
    "revision": revision,
    "INV_LAYERS": INV_LAYERS,
    "INV_DIMS": INV_DIMS,
    "EXPERIMENT_NAME": EXPERIMENT_NAME,
    "run_description": run_description,
}

# save config
config_file = os.path.join(META_DATA_FOLDER, "config.json")
with open(config_file, "w", encoding="utf-8") as f:
    json.dump(config_dict, f, indent=2)

print(f"Config saved to: {config_file}")


metadata folder year_localization_allenai_OLMo-2-0425-1B_main_logits/metadata
Config saved to: year_localization_allenai_OLMo-2-0425-1B_main_logits/metadata/config.json


In [19]:
import json
from generation_utils import generate_distribution_batched

# Years probed by the behavioral test: 1000-2999 (range end is exclusive).
# Defined once so the prompts, the generation call, and the year -> prediction
# mapping below all iterate over exactly the same years.
YEARS = range(1000, 3000)
prompts = [prompt_template[0].format(year=year) for year in YEARS]

print("prompt template", prompt_template)
outputs = generate_distribution_batched(model, tokenizer, prompts)
# zip() below silently truncates on a length mismatch; fail loudly instead.
assert len(outputs) == len(prompts), (
    f"expected {len(prompts)} outputs, got {len(outputs)}"
)

print("first output: ", outputs[0])

# Persist the raw per-prompt next-token distributions next to the run config.
output_file = f'{META_DATA_FOLDER}/year_{model_id.split("/")[-1]}-revision{revision}_model_next_token_outputs.json'
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(outputs, f)

# Sample outputs recorded with an earlier 'In {year}, there' template (note the comma):
# 1001 In 1001, there [(' was', 0.31640625), (' were', 0.296875), (' is', 0.1201171875), (' are', 0.11279296875), (' will', 0.038818359375), ("'s", 0.01953125), (' would', 0.01190185546875), (' had', 0.00982666015625), ('’s', 0.00897216796875), (' have', 0.006561279296875)]
# 1002 In 1002, there [(' was', 0.466796875), (' were', 0.310546875), (' is', 0.09521484375), (' are', 0.0478515625), (' had', 0.010498046875), ("'s", 0.008544921875), (' lived', 0.0037689208984375), (' occurred', 0.0037078857421875), (' came', 0.0036773681640625), (' would', 0.0034942626953125)]
# 1003 In 1003, there [(' was', 0.47265625), (' were', 0.294921875), (' is', 0.10546875), (' are', 0.04736328125), (' had', 0.0111083984375), ("'s", 0.00836181640625), (' lived', 0.0052490234375), (' came', 0.0038299560546875), (' occurred', 0.003631591796875), (' will', 0.0029144287109375)]

# Map specific first words to tense category. Both the leading-space and the
# bare spelling of each token are listed; anything else becomes "other".
word_map = {
    " was": "past",
    " were": "past",
    " is": "presfut",
    " are": "presfut",
    " will": "presfut",
    "was": "past",
    "were": "past",
    "is": "presfut",
    "are": "presfut",
    "will": "presfut",
}

# year -> top-1 token and its tense category. Each entry of `outputs` is assumed
# to be a probability-sorted list of (token, prob) pairs, as in the printed
# sample — TODO confirm in generation_utils. An empty prediction list maps to
# "other" instead of raising IndexError.
categorized_outputs = {
    year: {
        "first_word": preds[0][0] if preds else None,
        "category": word_map.get(preds[0][0], "other") if preds else "other",
    }
    for year, preds in zip(YEARS, outputs)
}

print("categorized_outputs", categorized_outputs)


prompt template ['In {year} there']


100%|██████████| 63/63 [00:02<00:00, 30.86it/s]

first output:  [(' were', 0.73046875), (' was', 0.1015625), (' are', 0.08837890625), (' is', 0.0167236328125), (' will', 0.01129150390625), (' had', 0.0086669921875), (' would', 0.0074462890625), (' have', 0.005035400390625), (' lived', 0.004180908203125), (' has', 0.0025482177734375)]
categorized_outputs {1000: {'first_word': ' were', 'category': 'past'}, 1001: {'first_word': ' was', 'category': 'past'}, 1002: {'first_word': ' was', 'category': 'past'}, 1003: {'first_word': ' was', 'category': 'past'}, 1004: {'first_word': ' was', 'category': 'past'}, 1005: {'first_word': ' was', 'category': 'past'}, 1006: {'first_word': ' was', 'category': 'past'}, 1007: {'first_word': ' was', 'category': 'past'}, 1008: {'first_word': ' was', 'category': 'past'}, 1009: {'first_word': ' was', 'category': 'past'}, 1010: {'first_word': ' was', 'category': 'past'}, 1011: {'first_word': ' was', 'category': 'past'}, 1012: {'first_word': ' was', 'category': 'past'}, 1013: {'first_word': ' was', 'category': 




In [20]:
# GOLD DISTRIBUTION
# Gold tense label per year for 1000-3000 inclusive: up to 2023 is "past",
# later years are "presfut". 2024 is deliberately left unlabelled.
GOLD_YEAR_MAPPING = dict(
    (year, "presfut" if year > 2023 else "past")
    for year in range(1000, 3001)
    if year != 2024
)

print(GOLD_YEAR_MAPPING)

{1000: 'past', 1001: 'past', 1002: 'past', 1003: 'past', 1004: 'past', 1005: 'past', 1006: 'past', 1007: 'past', 1008: 'past', 1009: 'past', 1010: 'past', 1011: 'past', 1012: 'past', 1013: 'past', 1014: 'past', 1015: 'past', 1016: 'past', 1017: 'past', 1018: 'past', 1019: 'past', 1020: 'past', 1021: 'past', 1022: 'past', 1023: 'past', 1024: 'past', 1025: 'past', 1026: 'past', 1027: 'past', 1028: 'past', 1029: 'past', 1030: 'past', 1031: 'past', 1032: 'past', 1033: 'past', 1034: 'past', 1035: 'past', 1036: 'past', 1037: 'past', 1038: 'past', 1039: 'past', 1040: 'past', 1041: 'past', 1042: 'past', 1043: 'past', 1044: 'past', 1045: 'past', 1046: 'past', 1047: 'past', 1048: 'past', 1049: 'past', 1050: 'past', 1051: 'past', 1052: 'past', 1053: 'past', 1054: 'past', 1055: 'past', 1056: 'past', 1057: 'past', 1058: 'past', 1059: 'past', 1060: 'past', 1061: 'past', 1062: 'past', 1063: 'past', 1064: 'past', 1065: 'past', 1066: 'past', 1067: 'past', 1068: 'past', 1069: 'past', 1070: 'past', 1071:

In [23]:
# Keep only the years where the model's top-1 tense category (categorized_outputs)
# agrees with the gold label (GOLD_YEAR_MAPPING). 2024 has no gold label and
# "other" predictions never equal a gold label, so both drop out.
matching_years = [
    y
    for y in range(1000, 3000)
    if categorized_outputs[y]["category"] == GOLD_YEAR_MAPPING.get(y)
]

# Split the agreeing years by their label, preserving year order.
past_years, presfut_years = (
    [y for y in matching_years if categorized_outputs[y]["category"] == label]
    for label in ("past", "presfut")
)

print("matching_years\n", matching_years)
print("past_years\n", len(past_years), past_years)
print("presfut_years\n", len(presfut_years), presfut_years)

matching_years
 [1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1116, 1117, 1118, 1119, 1120, 1121, 1122, 1123, 1124, 1125, 1126, 1127, 1128, 1129, 1130, 1131, 1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143, 1144, 1145, 1146, 1147, 1148, 1149, 1150, 1151, 1152, 1153, 1154, 1155, 1156, 1157, 1158, 1159, 1160, 1161, 1162, 1163,