In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from kmeans_pytorch import kmeans
import torch
from sklearn.metrics import f1_score

import openai
from constant import openai_key
import json
import re
import numpy as np
from tqdm.notebook import tqdm
from copy import deepcopy

# Authenticate the OpenAI client with the key kept in the local `constant` module.
openai.api_key = openai_key
# Chat model that `score` queries to judge each probed answer.
model_engine = "gpt-4o" 

import os
# Expose only GPU 0 to this process.
# NOTE(review): this runs after `import torch`; it presumably still takes effect
# because no CUDA call has been made yet — confirm, or move it above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# Checkpoint used for both the tokenizer and the causal language model.
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load the weights, switch to evaluation mode and move the model to the GPU.
# Both `eval()` and `to()` return the module itself, so the calls can be chained.
generator = AutoModelForCausalLM.from_pretrained(model_path).eval().to("cuda")

In [None]:
# Maps a prompt (exactly as passed to `probe`) to the answer extracted for it.
probe_cache = {}

def probe(seq):
    """Greedily complete ``seq`` and return the last double-quoted span.

    Up to 64 tokens are generated one at a time with the module-level
    ``generator`` / ``tokenizer``; generation stops as soon as a generated
    token contains a double quote. The answer is the last ``"..."`` match
    found in the completed text.

    Args:
        seq: Prompt text to complete. The prompts built in this notebook end
            with an opening double quote plus a candidate first token.

    Returns:
        The text inside the last quoted span of the completed sequence, or
        the string ``"nothing"`` when no quoted span is found.
    """
    # `seq` grows while decoding, so remember the caller's prompt: it is the
    # only key under which a later lookup for the same prompt can succeed.
    prompt = seq
    if prompt in probe_cache:
        return probe_cache[prompt]

    # Pure inference: no autograd graph is needed, even when called standalone.
    with torch.no_grad():
        for _ in range(64):
            logits = generator(**tokenizer(seq, return_tensors="pt").to("cuda")).logits[0, -1]
            # Greedy decoding: take the single highest-scoring token.
            idx = logits.argmax(-1)
            next_token = tokenizer.decode(idx)
            seq += next_token
            # A token containing a quote is treated as the end of the answer.
            if "\"" in next_token:
                break

    quoted = re.findall("\"(.*)?\"", seq)
    ans = quoted[-1] if len(quoted) > 0 else "nothing"

    probe_cache[prompt] = ans

    return ans

def score(answer, instruction):
    """Ask the chat model whether ``answer`` is a correct answer to ``instruction``.

    Verdicts are memoised in the module-level ``cache`` dict, keyed by
    ``answer`` only. ``cache`` must therefore exist before the first call and
    be reset whenever the instruction changes, otherwise verdicts obtained
    for one instruction are reused for another.

    Args:
        answer: Candidate answer text to be judged.
        instruction: The instruction the answer is supposed to satisfy.

    Returns:
        1 if the model's reply contains "yes" (case-insensitive), else 0.
    """
    # Imported locally because the notebook's import cell does not bring
    # `sleep` into scope; without it the first uncached call raises NameError.
    from time import sleep

    if answer in cache:
        return cache[answer]

    question = f'''{instruction}\n\nIs "{answer}" considered as among the correct answers? Answer only "Yes" or "No".'''
    response = openai.ChatCompletion.create(
        model=model_engine,
        temperature=0.0,
        messages=[{"role": "user", "content": question}],
    ).choices[0]['message']["content"]
    # Pause between API requests (simple client-side throttling).
    sleep(1.0)

    verdict = 1 if "yes" in response.lower() else 0
    cache[answer] = verdict
    return verdict

def verbalize_escape(escape_list):
    """Render ``escape_list`` as a numbered, quoted list with one open slot.

    Each existing item becomes a line of the form ``i. "item"``; the result
    ends with the next number followed by an opening double quote, left
    unterminated so a language model can continue the list.
    """
    numbered = [f"{rank}. \"{item}\"\n" for rank, item in enumerate(escape_list, start=1)]
    open_slot = f"{len(escape_list) + 1}. \""
    return "".join(numbered) + open_slot

def mapk(res, k=10):
    """Average the means of the first 1, 2, ..., k entries of ``res``.

    When ``res`` has fewer than ``k`` entries, the longer prefixes are all
    the whole list, so its overall mean is counted once per missing cutoff.
    """
    prefix_means = []
    for cutoff in range(1, k + 1):
        prefix_means.append(np.mean(res[:cutoff]))
    return np.mean(prefix_means)

In [None]:
# Put the model in evaluation mode (the loading cell already does this; repeated here).
generator.eval()
# Number of top-ranked next tokens that probe_exp probes per instruction.
N = 100
# The first K probed tokens define the set of "main" clusters in probe_exp.
K = 10

In [None]:
# Tag embedded in the file names of the pre-computed clustering results.
model_name = "llama3_8b"
# Cluster assignments; probe_exp indexes this tensor by vocabulary token id.
# NOTE(review): torch.load unpickles the file — only load cluster files from a trusted source.
cluster_ids_x = torch.load(f'cluster/cluster_ids_x.{model_name}.pt').to("cuda")
# Presumably the centroids of the same clustering; not referenced by the other visible cells.
cluster_centers = torch.load(f'cluster/cluster_centers.{model_name}.pt').to("cuda")

In [None]:
def probe_exp(instruction, escape_list, N, K, reg):
    """Probe the top-``N`` next tokens for ``instruction`` and compare them by cluster.

    The prompt (instruction plus the verbalised ``escape_list``) is run through
    the model once. Candidate next tokens are ranked by the dot product of the
    last position's final hidden state with the ``lm_head`` weights. Each of
    the top ``N`` tokens is appended to the prompt, completed with ``probe``
    and judged with ``score``.

    The clusters of the first ``K`` tokens form the "main" clusters. Every
    later token is counted as "main" if its cluster is one of those, and as
    "other" if not.

    Args:
        instruction: Instruction text placed at the top of the prompt.
        escape_list: Answers already listed in the prompt. It is extended in
            place with every probed entity.
        N: Number of top-ranked tokens to probe; values <= 0 probe nothing.
        K: Number of leading tokens whose clusters define the "main" clusters.
        reg: Selects ``init_prompt`` over ``prompt`` as the probing prefix.
            NOTE(review): both prompts are built from the same inputs before
            the loop and are never rebuilt, so this flag currently has no
            effect — confirm the intended behaviour.

    Returns:
        A tuple ``(avg_main, avg_other, avg_main_rank, avg_other_rank,
        avg_main_prop, avg_other_prop)``: mean score, mean rank and share of
        the tokens after rank ``K`` for the main and other groups. All six
        values are ``nan`` when either group is empty, e.g. when ``N <= K``.
    """
    main_cluster_ids = []
    main_cluster_golds = []
    main_cluster_ranks = []
    other_cluster_golds = []
    other_cluster_ranks = []

    # Defined up front so the return statement is valid even if no token ever
    # falls into both groups (this used to raise UnboundLocalError).
    nan = float("nan")
    avg_main = avg_other = nan
    avg_main_rank = avg_other_rank = nan
    avg_main_prop = avg_other_prop = nan

    prompt = f'''{instruction}

Answer:
{verbalize_escape(escape_list)}'''
    # Same text as `prompt`; kept as a separate name for the `reg` switch below.
    init_prompt = prompt

    with torch.no_grad():
        items = generator(**tokenizer(prompt, return_tensors="pt").to("cuda"), output_hidden_states=True)
        initial_state = items.hidden_states[-1][0, -1]
        lm_head_weight = generator.lm_head.weight
        # lm_head_weight = torch.nn.functional.normalize(lm_head_weight)
        logits = (initial_state.unsqueeze(0) * lm_head_weight).sum(-1)

        # Only the N best-ranked tokens are consumed, so decode just those
        # instead of the whole vocabulary.
        top_ids = logits.argsort(descending=True)[:max(N, 0)]
        next_tokens = [tokenizer.decode(idx) for idx in top_ids]

        bar = tqdm(zip(top_ids, next_tokens), total=N)

        for cnt, (idx, next_token) in enumerate(bar, start=1):
            token_id = idx.item()
            cluster_idx = cluster_ids_x[token_id].item()
            prefix = init_prompt if reg else prompt
            entity = probe(prefix + next_token)
            gold = score(entity, instruction)
            escape_list.append(entity)

            if cnt <= K:
                # The leading K tokens only define which clusters are "main".
                main_cluster_ids.append(cluster_idx)
            elif cluster_idx in main_cluster_ids:
                main_cluster_golds.append(gold)
                main_cluster_ranks.append(cnt)
            else:
                other_cluster_golds.append(gold)
                other_cluster_ranks.append(cnt)

            if len(main_cluster_golds) > 0 and len(other_cluster_golds) > 0:
                n_main = len(main_cluster_golds)
                n_other = len(other_cluster_golds)
                avg_main = np.mean(main_cluster_golds)
                avg_other = np.mean(other_cluster_golds)
                avg_main_rank = np.mean(main_cluster_ranks)
                avg_other_rank = np.mean(other_cluster_ranks)
                avg_main_prop = n_main / (n_main + n_other)
                avg_other_prop = n_other / (n_main + n_other)
                bar.set_description(f"G Main = {avg_main*100:.4}, G Other = {avg_other*100:.4}, R Main = {avg_main_rank:.4}, R Other  = {avg_other_rank:.4}, P Main = {avg_main_prop:.4}, P Other  = {avg_other_prop:.4}")

    return avg_main, avg_other, avg_main_rank, avg_other_rank, avg_main_prop, avg_other_prop

In [None]:
# Per-instruction summary metrics, keyed by the instruction text.
RES = {}

# Each entry pairs an instruction with the list of answers already shown in
# its prompt (all start empty).
instruction_escape_list = [
    ("Please show me some sports leagues.", []),
    ("Please show me some basketball sports leagues.", []),
    ("Please show me some baseball sports leagues.", []),
    ("Please show me some USA sports leagues.", []),
    ("Please show me some European sports leagues.", []),
]

for instruction, escape_list in instruction_escape_list:

    # `score` memoises verdicts in this module-level dict keyed by answer
    # only, so it is rebound to a fresh dict for every instruction.
    cache = {}
    reg = True

    (avg_main, avg_other,
     avg_main_rank, avg_other_rank,
     avg_main_prop, avg_other_prop) = probe_exp(instruction, escape_list, N, K, reg)

    res = dict(
        avg_main=avg_main,
        avg_other=avg_other,
        avg_main_rank=avg_main_rank,
        avg_other_rank=avg_other_rank,
        avg_main_prop=avg_main_prop,
        avg_other_prop=avg_other_prop,
    )

    RES[instruction] = res

print(RES)