In [1]:
import numpy as np
import os
import json
from sklearn.linear_model import LinearRegression
from tqdm.notebook import tqdm
from transformers import AutoTokenizer
from scipy.stats import pearsonr, spearmanr

In [2]:
domain2prompt = {
    "city": "{} lives in the city of",
    "country": "{} lives in the country of",
         }

In [3]:
def is_full_token(x):
    return tokenizer.convert_tokens_to_string([tokenizer.tokenize(x)[0]]).strip() == x.strip().split()[0]

def correlation(Y1_train, Y2_train, Y1_test, Y2_test):

    weights = np.zeros((Y1_train.shape[-1], Y2_train.shape[-1]))
    bias = np.zeros((1, Y2_train.shape[-1]))
    
    for i in tqdm(range(Y2_train.shape[1])):
        model = LinearRegression()
        model.fit(Y1_train, Y2_train[:, i])
        weights[:, i] = model.coef_ 
        bias[0, i] = model.intercept_ 
    
    Y2_pred = Y1_test @ weights + bias

    metrics = [pearsonr(Y2_pred[:, idx], Y2_test[:, idx]).statistic for idx in range(len(Y2_names))]
    
    avg, std = np.mean(metrics), np.std(metrics)

    return metrics, avg, std, weights, bias

In [4]:
def get_logit(prompt, generator):
    items = tokenizer(prompt, return_tensors="pt").to(device)
    logit = generator(**items).logits[0, -1]
    return logit

In [5]:
genid2model_path = {
    "8b-pt":"meta-llama/Meta-Llama-3-8B",
    "8b-it":"meta-llama/Meta-Llama-3-8B-Instruct",
}

In [6]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [7]:
device = torch.device("cuda:0")
model_id = "8b-pt"
Y1, Y2 = "city", "country"

model_path = genid2model_path[model_id]
tokenizer = AutoTokenizer.from_pretrained(model_path)
generator = AutoModelForCausalLM.from_pretrained(model_path).to(device)
names = np.random.choice(list(tokenizer.get_vocab().keys()), 1000)

cache = {}

Y1_domain = [" "+y.strip() for y in json.load(open(f"{Y1}.json"))]
Y2_domain = [" "+y.strip() for y in json.load(open(f"{Y2}.json"))]

Y1_domain = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(y)[0]) for y in Y1_domain if is_full_token(y)]
Y2_domain = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(y)[0]) for y in Y2_domain if is_full_token(y)]

Y1_domain = list(set(Y1_domain))
Y2_domain = list(set(Y2_domain))

Y1_prompt = domain2prompt[Y1]
Y1_logits = np.stack([get_logit(Y1_prompt.format(name), generator).detach().cpu().numpy() for name in tqdm(names)])
Y1_logits = Y1_logits[:, Y1_domain]

Y2_prompt = domain2prompt[Y2]
Y2_logits = np.stack([get_logit(Y2_prompt.format(name), generator).detach().cpu().numpy() for name in tqdm(names)])
Y2_logits = Y2_logits[:, Y2_domain]

Y1_names = [tokenizer.decode(y1) for y1 in Y1_domain]
Y2_names = [tokenizer.decode(y2) for y2 in Y2_domain]

metrics, avg, std, weights, bias = correlation(Y1_logits[::2], Y2_logits[::2], Y1_logits[1::2], Y2_logits[1::2])
print(Y1, Y2, avg, std)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/128 [00:00<?, ?it/s]

city country 0.7807606007100008 0.04276502189084191


In [8]:
Y1_name = " Tokyo"
idx = Y1_names.index(Y1_name)
top_jds = weights[idx].argsort()[::-1][:5]
top_Y2_names = [Y2_names[jdx] for jdx in top_jds]
print(Y1_name, top_Y2_names)

 Tokyo [' Japan', ' Luxembourg', ' Netherlands', ' Belgium', ' Nederland']
