In [None]:
import json

# Location of the training facts and the slice boundaries of the atomic triples.
TRAIN_PATH = '../data/base_configuration.2000.200.7.2/train.json'
NUM_ID_TRIPLES = 38000       # 2000 * 20 * 0.95
NUM_ATOMIC_TRIPLES = 40000   # 2000 * 20 (ID + OOD atomic triples)


def _group_queries_by_e2(records):
    """Group "<e1><r>" query strings by the tail entity "<e2>" they lead to.

    Args:
        records: iterable of dicts, each with a "target_text" string made of
            angle-bracketed tokens. The text is split on "><" after the outer
            brackets are stripped; the last piece is dropped and the remaining
            three are read as (e1, r, e2).
            NOTE(review): the dropped piece is presumably an end-of-sequence
            tag -- confirm against the data file.

    Returns:
        dict mapping "<e2>" to the list of "<e1><r>" strings, in input order.

    Raises:
        ValueError: if a record does not yield exactly three pieces.
    """
    grouped = {}
    for record in records:
        e1, r, e2 = record["target_text"].strip('<>').split('><')[:-1]
        grouped.setdefault(f"<{e2}>", []).append(f"<{e1}><{r}>")
    return grouped


with open(TRAIN_PATH, 'r') as f:
    data = json.load(f)

# Split the atomic triples into ID and OOD by ID/OOD ratio.
# Assumes the file lists the ID atomic triples first, then the OOD ones -- TODO confirm.
id_data = data[:NUM_ID_TRIPLES]
ood_data = data[NUM_ID_TRIPLES:NUM_ATOMIC_TRIPLES]

# A dedicated helper (instead of `for data in ...`) keeps `data` bound to the
# full loaded list, so later cells and re-runs see the dataset, not one record.
id_e2_to_queries = _group_queries_by_e2(id_data)    # {"<e_581>": ["<e_1525><r_123>", "<e_98><r_158>", ...],...}
ood_e2_to_queries = _group_queries_by_e2(ood_data)  # {"<e_581>": ["<e_1915><r_25>", ...], ...}

In [None]:
from transformers import GPT2LMHeadModel
import torch
import torch.nn.functional as F

def _hidden_states_at(model, tokenizer, queries, target_layer, target_index, device):
    """Run each query through the model alone and stack one hidden vector per query.

    For every query the vector is taken from `hidden_states[target_layer]` at
    batch index 0 and sequence position `target_index`.

    Returns:
        Tensor of shape (len(queries), hidden_dim).
    """
    vectors = []
    for query in queries:
        # One query per call, so padding=True never actually inserts pad tokens.
        inputs = tokenizer([query], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        vectors.append(outputs.hidden_states[target_layer][0, target_index, :])
    return torch.stack(vectors)


def compute_similarity_metrics(model, tokenizer, target_layer, id_e2_to_queries, ood_e2_to_queries, device):
    """Compute ID cohesion and OOD alignment of hidden states grouped by tail entity.

    For each tail entity present in BOTH mappings:
      * ID cohesion  -- mean cosine similarity between each ID query's hidden
        state and the centroid of the ID hidden states;
      * OOD alignment -- mean cosine similarity between each OOD query's hidden
        state and that same ID centroid.
    Both are then averaged over the shared tail entities.

    Args:
        model: model that accepts `output_hidden_states=True` and returns an
            object exposing `hidden_states`; it is moved to `device` and put
            in eval mode.
        tokenizer: callable producing model inputs for a list of strings.
        target_layer: index into `outputs.hidden_states`.
        id_e2_to_queries: dict mapping a tail entity to its ID query strings.
        ood_e2_to_queries: dict mapping a tail entity to its OOD query strings.
        device: device passed to `.to(...)` for the model and the inputs.

    Returns:
        (mean_id, mean_ood) as Python floats. If no tail entity has both ID
        and OOD queries, returns (nan, nan) instead of dividing by zero.
    """
    id_similarities = []   # ID Cohesion, one value per shared tail entity
    ood_similarities = []  # OOD Alignment, one value per shared tail entity

    # r1 position.
    # NOTE(review): assumes each query tokenizes to [entity, relation] -- confirm.
    target_index = 1

    model.to(device)
    model.eval()

    # Iterate over OOD tail entities: many ID tail entities have no OOD queries
    # at all, so this is the smaller candidate set for the intersection.
    for e2, ood_queries in ood_e2_to_queries.items():
        id_queries = id_e2_to_queries.get(e2, [])
        if len(id_queries) == 0 or len(ood_queries) == 0:
            continue

        id_hiddens = _hidden_states_at(model, tokenizer, id_queries, target_layer, target_index, device)
        id_centroid = id_hiddens.mean(dim=0).unsqueeze(0)
        id_similarities.append(F.cosine_similarity(id_hiddens, id_centroid, dim=1).mean().item())

        ood_hiddens = _hidden_states_at(model, tokenizer, ood_queries, target_layer, target_index, device)
        ood_similarities.append(F.cosine_similarity(ood_hiddens, id_centroid, dim=1).mean().item())

    # Both lists grow together, so checking one is enough.
    if not id_similarities:
        return float("nan"), float("nan")

    mean_id = sum(id_similarities) / len(id_similarities)
    mean_ood = sum(ood_similarities) / len(ood_similarities)
    return mean_id, mean_ood


In [None]:
import os
import re

base_dir = "/your/checkpoints/directory"  # replace with your directory
checkpoint_prefix = "checkpoint-"

min_step_interval = 2000  # keep checkpoints at least this many steps apart
start_step = 0            # ignore checkpoints whose step is below this

# Every sub-directory of base_dir whose name starts with the checkpoint prefix.
all_ckpts = [
    os.path.join(base_dir, d) for d in os.listdir(base_dir)
    if d.startswith(checkpoint_prefix) and os.path.isdir(os.path.join(base_dir, d))
]

# Extract step numbers from the checkpoint directory NAMES. Matching on the
# basename (anchored on the prefix) rather than searching the whole path avoids
# reading a step out of base_dir itself if it contains "checkpoint-<N>".
step_pattern = re.compile(re.escape(checkpoint_prefix) + r"(\d+)")
ckpt_tuples = []
for path in all_ckpts:
    match = step_pattern.match(os.path.basename(path))
    if match:
        step = int(match.group(1))
        if step >= start_step:
            ckpt_tuples.append((step, path))

# Sort by step (tuples compare on their first element first).
ckpt_tuples.sort()

# Sample checkpoints based on the minimum step interval: the earliest one is
# always kept, then each next one must be >= min_step_interval steps later.
selected_ckpts = []
last_step = None
for step, path in ckpt_tuples:
    if last_step is None or step - last_step >= min_step_interval:
        selected_ckpts.append((step, path))
        last_step = step


In [None]:
from transformers import GPT2Tokenizer
from tqdm import tqdm

# Use the first GPU when available; otherwise fall back to CPU so the sweep
# still runs (slowly) on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

target_layer = 5  # determined by cross-query semantic patching

# Fail with an actionable message instead of an IndexError on selected_ckpts[0].
if not selected_ckpts:
    raise RuntimeError(
        "No checkpoints were selected; check base_dir, start_step and min_step_interval."
    )

# Load the tokenizer from the first selected checkpoint.
# NOTE(review): assumes all checkpoints share the same tokenizer files -- confirm.
tokenizer = GPT2Tokenizer.from_pretrained(selected_ckpts[0][1])
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# One entry per checkpoint: (step, (id_cohesion, ood_alignment)).
results = []

for label, path in tqdm(selected_ckpts):
    print(f"Processing {label}...")
    model = GPT2LMHeadModel.from_pretrained(path).to(device)
    id_sim, ood_sim = compute_similarity_metrics(
        model, tokenizer, target_layer,
        id_e2_to_queries, ood_e2_to_queries, device
    )
    results.append((label, (id_sim, ood_sim)))

    # Drop the model before loading the next checkpoint so only one is
    # resident at a time, then release cached GPU memory.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print(results)
