In [None]:
from transformers import AutoTokenizer, LlamaForCausalLM,LlamaConfig
import torch
import os
import pandas as pd
import numpy as np
import pickle
import glob
import re
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse
from utils import model_to_path_dict

In [None]:
# Paths, the story whose transcript is fed to the model, and model/tokenizer loading.
save_dir = '../../generated'
behavior_data_dir = '../../behavior_data'
story = 'pieman'
original_transcript_dir = os.path.join(behavior_data_dir, 'transcripts', 'moth_stories')
# Read explicitly as UTF-8 so non-ASCII characters in the transcript decode identically
# regardless of the platform's default locale encoding.
with open(os.path.join(original_transcript_dir, '%s.txt' % story), 'r', encoding='utf-8') as f:
    original_txt = f.read()

device = 'cuda'
model_name = 'Llama3-8b-instruct'
# Model identifier looked up once and shared by the tokenizer and the model.
hf_name = model_to_path_dict[model_name]['hf_name']
tokenizer = AutoTokenizer.from_pretrained(hf_name)
# NOTE(review): eager attention is presumably required so the patched attention-temperature
# code path (used via custom `generate` kwargs below) is exercised -- TODO confirm.
model = LlamaForCausalLM.from_pretrained(hf_name, attn_implementation="eager",
                                         device_map='auto', torch_dtype=torch.float16)

In [None]:
# prepare prompt
system_prompt = "You are a human with limited memory ability. You're going to listen to a story, and your task is to recall the story and summarize it in your own words in a verbal recording. Respond as if you’re speaking out loud."
# Text that introduces the story in the user turn; the story tokens begin right after it.
story_marker = "Here's the story:"
# Built from `story_marker` so the marker used for the search below always matches the prompt.
user_prompt = "%s %s\nHere's your recall: " % (story_marker, original_txt)
messages = [
    {
        "role": "system",
        "content": system_prompt,
    },
    {"role": "user", "content": user_prompt},
]
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
tokenized_chat = tokenized_chat.to(device)
# Index of the first of the last 5 prompt tokens.
# NOTE(review): the 5 presumably covers the end-of-turn token plus the assistant header that the
# Llama-3 chat template appends after the user message -- TODO confirm for other chat templates.
recall_start_index = tokenized_chat.shape[1] - 5

# Only apply the attention temperature manipulation from recall to the story, not to the
# system & user prompt, so find the first token after `story_marker`. Anchoring on the marker
# (rather than on the first ':' anywhere in the chat) prevents a colon in the system prompt or in
# a template-injected header from being mistaken for the story start.
chat_token_ids = tokenized_chat[0].tolist()  # plain ints: no per-token device->host transfer
decoded_prefix = ''
story_start_idx = -1
for i, token_id in enumerate(chat_token_ids):
    decoded_prefix += tokenizer.decode(token_id)
    if story_marker in decoded_prefix:
        story_start_idx = i + 1
        break
assert 0 < story_start_idx < tokenized_chat.shape[1], 'failed to find story start idx'

In [None]:
%% time
# attention temperatures, 0 = unmodified, the larger the number, the higher the temperature, the more diffuse the attention
scales = [0,0.00005,0.00007,0.0002,0.0005,0.001] 
n = 5 # generate 5 recalls per temp
output_by_scale = {}
for scale in tqdm(scales):
    outputs = []
    for i in range(n):
        if scale != 0:
            output = model.generate(tokenized_chat,attention_scale = scale,recall_start_index = recall_start_index,
                                story_start_index=story_start_idx,max_new_tokens = 800,output_attentions=False,
                                return_dict_in_generate=True,output_scores=False)
        else:
            output = model.generate(tokenized_chat,story_start_index=story_start_idx,max_new_tokens = 800,
                                    output_attentions=False,return_dict_in_generate=True,output_scores=False)
        sequence = output['sequences']
        outputs.append(tokenizer.decode(sequence[0][tokenized_chat.shape[1]:]))
    output_by_scale[scale] = outputs