In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.cuda.amp import autocast
from datasets import load_dataset
import pandas as pd

from tqdm import tqdm

In [2]:
def generate_responses(questions, model_name="locuslab/tofu_ft_llama2-7b",
                       max_length=100, min_new_tokens=10):
    """Generate one deterministic answer per question with a causal language model.

    The tokenizer and model named by ``model_name`` are loaded once, moved to
    the GPU when one is available, and then run over ``questions`` one prompt
    at a time.

    Args:
        questions: Iterable of question strings (for example a list or a
            pandas Series). Each question is stripped of surrounding
            whitespace before being used as the prompt.
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        max_length: Upper bound on prompt tokens plus generated tokens. It is
            raised per question when the prompt is so long that
            ``min_new_tokens`` would otherwise not fit.
        min_new_tokens: Minimum number of tokens to generate after the prompt.

    Returns:
        list[str]: The generated continuation for each question, in input
        order, without the prompt and with surrounding whitespace stripped.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model.to(device)
    model.eval()

    # Fall back to the EOS id for padding when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    responses = []
    for question in tqdm(questions):
        question = question.strip()

        inputs = tokenizer(question, return_tensors='pt').to(device)
        input_ids = inputs['input_ids']
        input_length = input_ids.shape[1]

        # min_length counts the prompt tokens too, so offset it by the prompt
        # length and make sure max_length never falls below it.
        min_length = input_length + min_new_tokens

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=inputs.get('attention_mask'),
                max_length=max(max_length, min_length),
                min_length=min_length,
                do_sample=False,  # greedy decoding: same prompt -> same answer
                pad_token_id=pad_token_id,
            )

        # Decode only the tokens produced after the prompt. Searching the
        # decoded text for the question is fragile because decoding does not
        # always reproduce the prompt string character for character.
        generated_ids = outputs[0][input_length:]
        answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        responses.append(answer)

        torch.cuda.empty_cache()

    return responses

In [3]:
# Fetch TOFU ("forget10") and keep its 'train' split as a pandas DataFrame.
dataset = pd.DataFrame(load_dataset("locuslab/TOFU", "forget10")['train'])

In [4]:
# Generate one answer per question and store the answers in a new 'response' column.
generated_answers = generate_responses(dataset['question'])
dataset['response'] = generated_answers

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Using device: cuda


100%|██████████| 400/400 [38:16<00:00,  5.74s/it]


In [5]:
# Number of missing values in each column.
dataset.isna().sum()

question    0
answer      0
response    0
dtype: int64

In [6]:
# Write the questions, reference answers and model responses to CSV, omitting the index.
dataset.to_csv(path_or_buf='forget10_with_responses.csv', index=False)