In [None]:
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
import torch

# Checkpoint directories for the three training phases (placeholders to fill in).
model_path1 = "phase_I_model_path"
model_path2 = "phase_II_model_path"
model_path3 = "phase_III_model_path"

# Prefer the first GPU, but fall back to CPU so the notebook still runs
# (slowly) on a machine without CUDA instead of failing in `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# A single tokenizer, loaded from the phase-I checkpoint, is used for every model.
# NOTE(review): this assumes all three checkpoints share one vocabulary - confirm.
tokenizer = GPT2Tokenizer.from_pretrained(model_path1)
# Pad on the left and reuse EOS as the padding token.
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# One language model per training phase, all moved to the selected device.
model1 = GPT2LMHeadModel.from_pretrained(model_path1).to(device)
model2 = GPT2LMHeadModel.from_pretrained(model_path2).to(device)
model3 = GPT2LMHeadModel.from_pretrained(model_path3).to(device)

In [None]:
import json
# Read the whole test split into memory.
with open('../data/base_configuration.2000.200.7.2/test.json', 'r') as f:
    datas = json.load(f)

# Probe sets: each element is (query, answer), where the answer is the
# bracketed tail entity e2 of the fact.
id_queries = []   # entries whose type is "id_atomic"
ood_queries = []  # entries whose type is "ood_atomic"

# Route each atomic entry to its list; every other entry type is ignored.
queries_by_type = {
    "ood_atomic": ood_queries,
    "id_atomic": id_queries,
}

for data in datas:
    destination = queries_by_type.get(data["type"])
    if destination is None:
        continue

    # target_text is split into its "<...>" fields; the last field is dropped
    # and the remaining three must be exactly (e1, r, e2), otherwise the
    # unpacking raises ValueError.
    fields = data["target_text"].strip('<>').split('><')
    e1, r, e2 = fields[:-1]
    destination.append((data["input_text"], f"<{e2}>"))

### Immediate probing 

In [None]:
import torch.nn.functional as F
from tqdm import tqdm

# Immediate probing: read an intermediate hidden state at one position, project
# it straight through the output embedding, and count how often the top-1 token
# equals the expected answer.

# TODO: change the model and query type
model = model3
queries = ood_queries

# Output-embedding matrix used as the readout; fetched once because it does not
# change between queries.
word_embedding = model.lm_head.weight.data
model.config.pad_token_id = model.config.eos_token_id

correct_cnt = 0

# layer 5 at r1 position
# Each query is tokenized on its own (batch of one), so no padding is inserted
# and index 1 is the second token of the query itself.
target_token_index = 1
target_layer = 5

for query, target in tqdm(queries):
    decoder_temp = tokenizer([query], return_tensors="pt", padding=True)
    decoder_input_ids = decoder_temp["input_ids"].to(device)
    decoder_attention_mask = decoder_temp["attention_mask"].to(device)

    # Keep the forward pass AND the readout under no_grad so that no autograd
    # graph is built for the projection.
    with torch.no_grad():
        outputs = model(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            output_hidden_states=True
        )

        all_hidden_states = outputs.hidden_states

        # decode
        # NOTE(review): the state is projected with lm_head directly, without
        # model.transformer.ln_f - confirm this is the intended readout.
        target_hidden_state = all_hidden_states[target_layer][:, target_token_index, :]  # (batch_size, hidden_dim)
        logits = torch.matmul(target_hidden_state, word_embedding.T)  # (batch_size, vocab_size)
        next_token = torch.argmax(logits, dim=-1)  # (batch_size,)

    # check
    if tokenizer.decode(next_token.item()) == target:
        correct_cnt += 1

# Accuracy over the queries that were actually evaluated. The denominator must
# follow `queries`, not a fixed list, or the figure is wrong after switching
# the query type above.
print(correct_cnt / len(queries) if queries else float("nan"))

### Full-run probing

In [None]:
import torch.nn.functional as F
from tqdm import tqdm

# Full-run probing: take the hidden state of a source query at one
# (layer, position), write it into the forward pass of a fixed target query,
# let the remaining layers run, and check whether the model's prediction at
# that position is the SOURCE query's answer.

# TODO: change the model and query type
model = model3
queries = id_queries

# Not used below; kept so the notebook-level name still refers to this model.
word_embedding = model.lm_head.weight.data
model.config.pad_token_id = model.config.eos_token_id

target_query = queries[-1]  # random or fixed

correct_cnt = 0

# layer 5 at r1 position
target_token_index = 1
target_layer = 5

# `target_query` is a (query, answer) tuple. Only the query text is tokenized:
# passing the tuple itself would be encoded as a text pair and append the
# answer to the prompt. The target input is loop-invariant, so it is built once.
target_temp = tokenizer([target_query[0]], return_tensors="pt", padding=True)
target_decoder_input_ids = target_temp["input_ids"].to(device)
target_decoder_attention_mask = target_temp["attention_mask"].to(device)


def make_patch_hook(source_state, token_index):
    """Build a forward hook that overwrites one position of a block's output.

    Args:
        source_state: 1-D hidden vector to write into the output.
        token_index: sequence position (of batch element 0) to overwrite.

    Returns:
        A hook suitable for ``register_forward_hook``. It returns a patched
        copy of the block output and leaves any extra tuple elements untouched.
    """
    def hook_fn(module, inputs, output):
        # Blocks may return either a tuple (hidden_states, ...) or a bare tensor.
        is_tuple = isinstance(output, tuple)
        main_output = (output[0] if is_tuple else output).clone()
        main_output[0, token_index, :] = source_state
        if is_tuple:
            return (main_output,) + output[1:]
        return main_output

    return hook_fn


source_queries = queries[:-1]

for query, target in tqdm(source_queries):

    # Source run: collect the hidden states of the query being probed.
    decoder_temp = tokenizer([query], return_tensors="pt", padding=True)
    decoder_input_ids = decoder_temp["input_ids"].to(device)
    decoder_attention_mask = decoder_temp["attention_mask"].to(device)

    with torch.no_grad():
        outputs1 = model(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            output_hidden_states=True
        )

    hidden_states_batch = outputs1.hidden_states  # [1+num_layers, batch_size, seq_len, hidden_size]
    source_state = hidden_states_batch[target_layer][0, target_token_index, :]

    # NOTE(review): the hook replaces the OUTPUT of transformer.h[target_layer],
    # while hidden_states[target_layer] appears to be the state entering that
    # block (index 0 being the embeddings). If the same depth is meant on both
    # sides, the hook belongs on h[target_layer - 1] - confirm before relying
    # on these numbers.
    handle = model.transformer.h[target_layer].register_forward_hook(
        make_patch_hook(source_state, target_token_index)
    )
    try:
        # Target run with the patched state.
        with torch.no_grad():
            outputs = model(
                input_ids=target_decoder_input_ids,
                attention_mask=target_decoder_attention_mask,
            )
    finally:
        # Always detach the hook; a leaked hook would keep patching every
        # later forward pass of this model.
        handle.remove()

    # decode
    # Read the prediction at the patched position by index, rather than
    # splitting the decoded string, so a skipped special token or missing
    # whitespace cannot shift or break the lookup.
    logits = outputs.logits  # [batch_size, seq_len, vocab_size]
    predicted_token_id = torch.argmax(logits[0, target_token_index], dim=-1)
    decoded_token = tokenizer.decode(predicted_token_id.item())

    # check
    if decoded_token == target:
        correct_cnt += 1

# Accuracy over the source queries that were evaluated (all but the last).
print(correct_cnt / len(source_queries) if source_queries else float("nan"))