In [None]:
import os
import json
import torch
import logging
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from itertools import chain
from tqdm import tqdm

# Explicitly enable parallelism in the Hugging Face `tokenizers` library,
# which reads this setting from the process environment.
os.environ.update(TOKENIZERS_PARALLELISM="true")

def get_device():
    """Return the best available torch device.

    Preference order is CUDA, then Apple MPS, then CPU. The result is always a
    ``torch.device`` so callers can rely on attributes such as ``.type``
    regardless of which backend was selected.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no `torch.backends.mps`; treat that as "no MPS".
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available() and mps_backend.is_built():
        return torch.device("mps")
    return torch.device("cpu")

def create_few_shot(number_few_shot, path='../benchmark/prompt_examples.json'):
    """Build the few-shot prompt prefix from worked examples stored on disk.

    Args:
        number_few_shot: how many examples, taken from the start of the file,
            to include. Non-positive values yield an empty prefix and the file
            is not read.
        path: JSON file holding a list of objects with ``'context'``,
            ``'question'`` and ``'answers'`` keys.

    Returns:
        The formatted examples joined by blank lines and followed by one more
        blank line, so the real prompt can be appended directly. Returns
        ``""`` when no examples are requested.
    """
    # Slicing with a negative count would select "all but the last k" rows,
    # and 0 would still emit a stray "\n\n"; neither is a valid request.
    if number_few_shot <= 0:
        return ""

    # JSON interchange is UTF-8; don't depend on the platform default encoding.
    with open(path, encoding='utf-8') as json_file:
        data = json.load(json_file)

    # NOTE(review): `answers` is inserted via str(); if the file stores
    # SQuAD-style dicts/lists the prompt will contain their repr — confirm.
    template = "Context: {context}\nQuestion: {question}\nAnswer: {answers}"
    prompt = "\n\n".join(
        template.format(
            context=row['context'],
            question=row['question'],
            answers=row['answers'],
        )
        for row in data[:number_few_shot]
    )
    return prompt + '\n\n'

def create_prompt(item, prompt_examples):
    """Attach a ``'prompt'`` field to a dataset row.

    The prompt is the optional few-shot prefix followed by the row's context
    and question, ending with an open ``Answer: `` slot for the model.

    Args:
        item: mapping with ``'context'`` and ``'question'`` keys; it is
            updated in place.
        prompt_examples: few-shot prefix; when falsy, no prefix is added.

    Returns:
        The same ``item`` object, now carrying ``'prompt'``.
    """
    template = "Context: {context}\nQuestion: {question}\nAnswer: "
    body = template.format(context=item['context'], question=item['question'])
    prefix = prompt_examples if prompt_examples else ""
    item['prompt'] = prefix + body
    return item

def tokenization(items, tokenizer):
    """Tokenize a batch of prompts, padding to the longest one in the batch.

    Args:
        items: batch mapping with a ``'prompt'`` list, as supplied by
            ``Dataset.map(batched=True)``.
        tokenizer: tokenizer callable; invoked with the prompt list and
            ``padding='longest'``.

    Returns:
        Whatever the tokenizer returns for that call.
    """
    batch_prompts = items["prompt"]
    encoded = tokenizer(batch_prompts, padding='longest')
    return encoded

In [None]:
# Run configuration: which checkpoint to evaluate, on which SQuAD split,
# and how prompts are built and batched.
model_id = "EleutherAI/pythia-160m-deduped"  # Hugging Face Hub checkpoint
# NOTE(review): evaluation runs on the *train* split — confirm this is intended.
dataset_id, split_name = "squad", "train"
number_few_shot = 3  # worked examples prepended to every prompt (0 disables)
batch_size = 1       # also the Dataset.map batch size used for padding
num_workers = 2      # DataLoader worker processes

In [None]:
# Select the compute device once; ending the cell with `device` displays the
# choice so the reader knows where (and in which precision) the model runs.
device = get_device()
device

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Batched generation needs a pad token; only add one when the checkpoint's
# tokenizer does not already define it.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "<pad>"})
# Decoder-only models continue from the right edge of the prompt, so pad on the left.
tokenizer.padding_side = 'left'

# float16 is meant for accelerators; CPU half-precision support has historically
# been incomplete, so load in float32 when get_device() fell back to the CPU.
# `device` may be a torch.device or the plain string "cpu".
device_type = getattr(device, "type", str(device))
model_dtype = torch.float32 if device_type == "cpu" else torch.float16
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=model_dtype).to(device)
# Keep the embedding matrix sized to the (possibly extended) vocabulary.
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
dataset = load_dataset(dataset_id, split=split_name)

# Shared few-shot prefix (empty when few-shot prompting is disabled); its
# length is kept so decoded outputs can be trimmed back to the real prompt.
prompt_examples = "" if number_few_shot <= 0 else create_few_shot(number_few_shot)
prompt_examples_length = len(prompt_examples)

# Attach a 'prompt' column, then tokenize with the same batch size the
# DataLoader uses so each padded map batch corresponds to one loader batch.
dataset = (
    dataset
    .map(create_prompt, fn_kwargs={"prompt_examples": prompt_examples})
    .map(tokenization, fn_kwargs={"tokenizer": tokenizer}, batched=True, batch_size=batch_size)
)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

In [None]:
# Generate an answer for every prompt and print it next to the gold answer.
# The gold-answer column is fetched once: `dataset['answers']` returns the
# entire column, so indexing it inside the loop redid that work per row.
gold_answers = dataset['answers']
predictions = []
i = 0  # index of the next dataset row to report
with torch.no_grad():
    for batch in tqdm(dataloader):
        input_ids = batch['input_ids'].to(device)
        output = model.generate(
            input_ids,
            attention_mask=batch['attention_mask'].to(device),
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=15,
        ).to('cpu')
        # Prompts are left-padded, so each row's continuation starts at column
        # input_ids.shape[1]. Decoding only those tokens avoids re-parsing the
        # prompt, which broke when a context/question contained a newline.
        continuations = tokenizer.batch_decode(
            output[:, input_ids.shape[1]:], skip_special_tokens=True
        )
        # Report every row in the batch, keeping `i` aligned with dataset rows.
        for continuation in continuations:
            answer = continuation.split('\n')[0].strip()  # answer = first generated line
            predictions.append(answer)
            print(answer)
            print(gold_answers[i])
            print('\n\n')
            i += 1