# Setup

In [1]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from mamba2mini import Mamba2LMHeadModel
from transformers import AutoTokenizer

In [2]:
# Device that the model and the input tensors are placed on.
device = "cuda"
# Checkpoint identifier handed to Mamba2LMHeadModel.from_pretrained.
model_name = "state-spaces/mamba2-1.3b"

# Cache directories (Hugging Face, Triton, XDG).
# NOTE(review): these are absolute, user-specific paths; adjust them when
# running this notebook under a different account or machine.
hf_dir, tri_dir, xdg_dir = (
    "/home/fodl/slutzky1/.cache/huggingface",
    "/home/fodl/slutzky1/.cache/triton",
    "/home/fodl/slutzky1/.cache/xdg",
)

In [3]:
# Point the Hugging Face, Triton and XDG caches at the configured directories.
# NOTE(review): this runs after `torch` and `transformers` were imported in the
# first cell; a library that reads these variables at import time would not
# pick up the values set here -- confirm, and consider setting them before the
# imports.
os.environ.update({
    'HF_HOME': hf_dir,
    'TRITON_CACHE_DIR': tri_dir,
    'XDG_CACHE_HOME': xdg_dir,
})

# Predict

In [4]:
# Load the evaluation prompts and append the result columns that
# forward_eval fills in batch by batch.
original_data = pd.read_parquet('original_data.parquet').assign(
    true_prob=0.0,   # probability read off for the first token of target_true
    false_prob=0.0,  # probability read off for the first token of target_false
    hit=False,       # becomes True where true_prob > false_prob
)

In [5]:
# Load the GPT-NeoX tokenizer, caching its files under hf_dir.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", cache_dir=hf_dir, use_fast=True)
# Reuse the EOS token as the padding token so that calls with padding=True
# (used for batching in forward_eval) have a pad token available.
tokenizer.pad_token = tokenizer.eos_token

In [6]:
# Load the pretrained checkpoint named by model_name onto the configured device.
model = Mamba2LMHeadModel.from_pretrained(model_name, device=device)

In [7]:
# Fix the torch RNG seed and put the model in evaluation mode.
torch.manual_seed(0)
model.eval()

# Arguments forwarded unchanged to model.generate_single by forward_eval.
# NOTE(review): temperature=1 / top_k=0 / top_p=1 presumably leave the
# predicted distribution unmodified -- confirm against generate_single.
temperature, top_k, top_p = 1, 0, 1
attention = False

In [8]:
def forward_eval(temperature, top_k, top_p, batch_start, batch_end, attention, num_to_masks=None):
    """Score one batch of prompts and store the results in ``original_data``.

    For the rows labelled ``batch_start`` .. ``batch_end - 1`` of the
    module-level ``original_data`` frame, the prompts are tokenized, passed to
    ``model.generate_single`` for one position beyond the padded prompt length,
    and three columns are written back in place:

    * ``true_prob``  -- value read at the first token id of ``target_true``
    * ``false_prob`` -- value read at the first token id of ``target_false``
    * ``hit``        -- ``true_prob > false_prob``

    Uses the notebook globals ``original_data``, ``tokenizer``, ``model`` and
    ``device``.

    Args:
        temperature, top_k, top_p: forwarded unchanged to ``generate_single``.
        batch_start: index label of the first row in the batch.
        batch_end: index label one past the last row in the batch.
        attention: forwarded unchanged to ``generate_single``.
        num_to_masks: forwarded unchanged to ``generate_single``.

    Returns:
        None. Results are written into ``original_data``; a progress line is
        printed and the CUDA caching allocator is emptied.
    """
    # .loc slices by label and includes the end label, hence batch_end - 1.
    rows = slice(batch_start, batch_end - 1)
    batch = original_data.loc[rows]
    prompts = batch['prompt'].tolist()
    true_words = batch['target_true'].tolist()
    false_words = batch['target_false'].tolist()

    # Only the first token id of each target word is scored.
    # NOTE(review): column 0 is the first real token only if the tokenizer pads
    # on the right -- confirm tokenizer.padding_side.
    # NOTE(review): the targets are tokenized without a leading space although
    # they follow a prompt; verify this yields the intended token ids.
    true_first = tokenizer(true_words, return_tensors="pt", padding=True).input_ids[:, 0].cpu().numpy()
    false_first = tokenizer(false_words, return_tensors="pt", padding=True).input_ids[:, 0].cpu().numpy()

    # NOTE(review): prompts are padded to a common length and no attention mask
    # is passed on; whether padding affects the prediction for shorter prompts
    # depends on generate_single -- confirm.
    input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.to(device=device)
    max_new_length = input_ids.shape[1] + 1

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        out = model.generate_single(
            input_ids=input_ids,
            max_new_length=max_new_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            # Pass the integer id; tokenizer.eos_token is the token *string*.
            eos_token_id=tokenizer.eos_token_id,
            attention=attention,
            num_to_masks=num_to_masks,
        )

    # The last element of the result is indexed as [row, token_id] below.
    # NOTE(review): assumed to be per-row next-token probabilities of shape
    # (batch, vocab) -- confirm against generate_single.
    next_token_probs = out[-1].detach().cpu().numpy()
    row_idx = np.arange(next_token_probs.shape[0])
    true_prob = next_token_probs[row_idx, true_first]
    false_prob = next_token_probs[row_idx, false_first]

    original_data.loc[rows, 'true_prob'] = true_prob
    original_data.loc[rows, 'false_prob'] = false_prob
    original_data.loc[rows, 'hit'] = true_prob > false_prob

    print(f'Finished batch [{batch_start}:{batch_end-1}]')
    # Release cached GPU memory between batches.
    torch.cuda.empty_cache()

In [9]:
# Number of prompts scored per forward_eval call.
batch_size = 64
N = len(original_data)
# Batch edges 0, batch_size, 2*batch_size, ... closed off with N, so that
# consecutive entries delimit one batch (the last one may be shorter).
batches = [*np.arange(0, N, batch_size), N]

In [10]:
# Score the final batch first (the interval between the last two edges).
forward_eval(
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
    batch_start=batches[len(batches) - 2],
    batch_end=batches[len(batches) - 1],
    attention=attention,
)

Finished batch [21888:21918]


In [11]:
# Score every batch except the final one, which the previous cell already
# handled: pair each edge with its successor, leaving out the last interval.
for batch_start, batch_end in zip(batches[:-2], batches[1:-1]):
    forward_eval(temperature, top_k, top_p, batch_start, batch_end, attention)

Finished batch [0:63]
Finished batch [64:127]
Finished batch [128:191]
Finished batch [192:255]
Finished batch [256:319]
Finished batch [320:383]
Finished batch [384:447]
Finished batch [448:511]
Finished batch [512:575]
Finished batch [576:639]
Finished batch [640:703]
Finished batch [704:767]
Finished batch [768:831]
Finished batch [832:895]
Finished batch [896:959]
Finished batch [960:1023]
Finished batch [1024:1087]
Finished batch [1088:1151]
Finished batch [1152:1215]
Finished batch [1216:1279]
Finished batch [1280:1343]
Finished batch [1344:1407]
Finished batch [1408:1471]
Finished batch [1472:1535]
Finished batch [1536:1599]
Finished batch [1600:1663]
Finished batch [1664:1727]
Finished batch [1728:1791]
Finished batch [1792:1855]
Finished batch [1856:1919]
Finished batch [1920:1983]
Finished batch [1984:2047]
Finished batch [2048:2111]
Finished batch [2112:2175]
Finished batch [2176:2239]
Finished batch [2240:2303]
Finished batch [2304:2367]
Finished batch [2368:2431]
Finished 

In [12]:
# Preview the first rows, including the true_prob / false_prob / hit columns.
original_data.head()

Unnamed: 0,relation,relation_prefix,relation_suffix,prompt,relation_id,target_false_id,target_true_id,target_true,target_false,subject,true_prob,false_prob,hit
0,The mother tongue of {} is,The mother tongue of,is,The mother tongue of Danielle Darrieux is,P103,Q1860,Q150,French,English,Danielle Darrieux,0.0009692109,0.0006011273,True
1,The official religion of {} is,The official religion of,is,The official religion of Edwin of Northumbria is,P140,Q432,Q5043,Christianity,Islam,Edwin of Northumbria,0.0001198434,6.332863e-06,True
2,"{}, the",,"{}, the","Toko Yasuda, the",P1303,Q5994,Q6607,guitar,piano,Toko Yasuda,2.344575e-07,8.996081e-07,False
3,"{}, which is located in",,"{}, which is located in","Autonomous University of Madrid, which is loca...",P17,Q34,Q29,Spain,Sweden,Autonomous University of Madrid,0.0002009897,1.928959e-07,True
4,What is the twin city of {}? It is,What is the twin city of,? It is,What is the twin city of Lyon? It is,P190,Q1461,Q3820,Beirut,Manila,Lyon,2.063344e-05,4.327067e-08,True


In [13]:
# Fraction of rows where true_prob exceeded false_prob.
original_data['hit'].mean()

0.8296455130252293

In [14]:
# Persist the frame, including the per-row results, to a parquet file.
original_data.to_parquet('entire_results_original.parquet')