In [None]:
from llama2 import *
from typing import List, Literal, Optional, Tuple, TypedDict
import pandas as pd
import datasets
import string
from evaluate import evaluator
from evaluate import load

from torch.utils.data import Dataset

In [None]:
# Notebook-level parameters (kept as plain top-level names so they can be overridden per run).
model_size = 'int8'  # passed to LlamaModel as `model_resolution` when the model is loaded
max_samples = 10  # upper bound on validation rows sent through inference; values <= 0 disable the limit

## 1 - Load model

In [None]:
#papermill_description=LOADING_MODEL
model_name = "meta-llama/Llama-2-7b-chat-hf"

model = LlamaModel(
    model_name=model_name,
    model_resolution=model_size
)


def _skip_device_move(*args, **kwargs):
    """No-op replacement for the wrapped model's ``.to``.

    Accepts any call signature and hands back the wrapped model, mirroring the
    ``torch.nn.Module.to`` contract of returning the module itself, so callers
    written as ``m = m.to(device)`` keep a usable model instead of receiving
    the device argument back.
    """
    return model.model


# Disable device copying: any later `.to(...)` call on the wrapped model becomes a no-op.
# NOTE(review): presumably needed because the quantized weights must not be moved after
# loading -- confirm against the LlamaModel implementation.
model.model.to = _skip_device_move

## 2 - Load dataset

In [None]:
#papermill_description=LOADING_DATA
# Only the validation split of SQuAD v2 is loaded; no training data is used in this notebook.
dataset = datasets.load_dataset(path="squad_v2", split="validation")

In [None]:
# Display the loaded dataset object as this cell's output.
dataset

In [None]:
# Inspect one raw sample (position 2) to see the fields available to the prompt builder.
dataset[2]

## 3 - Define data prep and model inference functions

In [None]:
def format_question(sample: dict) -> str:
    """Build the extractive-QA prompt for one SQuAD v2 sample.

    The prompt is an instruction wrapped in the notebook-level ``model``'s
    ``B_INST`` / ``E_INST`` markers, followed by the article title, the
    paragraph, the question and a trailing ``Answer: `` cue.

    Args:
        sample: Mapping providing the ``"title"``, ``"context"`` and
            ``"question"`` entries of a single sample.

    Returns:
        The full prompt string.
    """
    instruction = (
        'You are performing extractive reading comprehension; given a question and a '
        'paragraph from an article, respond only with a direct extract from the article '
        'that answers the question and do not use your own prior knowledge.  If no direct '
        'extract from the context can answer the question, respond with an empty quote.'
    )
    # Instruction block delimited by the model's instruction markers.
    header = f'{model.B_INST} {instruction} {model.E_INST}\n'

    # Empty entries produce the blank lines separating paragraph, question and answer cue.
    body = '\n'.join([
        f'Article: {sample["title"]}',
        f'Paragraph: {sample["context"]}',
        '',
        f'Question: {sample["question"]}',
        '',
        'Answer: ',
    ])

    return header + body


def _extract_answer(generated: str, prompt: str) -> str:
    """Cut the answer text out of the raw model output for ``prompt``."""
    # Skip 3 leading characters plus the echoed prompt and drop the last 4 characters,
    # then trim whitespace and surrounding punctuation.
    # NOTE(review): the 3 / 4 offsets presumably match the '<s>' / '</s>' tags emitted by
    # model.generate -- confirm against the LlamaModel implementation.
    return generated[len(prompt) + 3:-4].strip().strip(string.punctuation)


def squad_inference(df: pd.DataFrame, model) -> pd.DataFrame:
    """Predict the output extracts for all samples in the input squad format dataset.

    Rows are processed by position, so the result is correct for any index on
    ``df`` (for example a filtered or shuffled frame), not only a 0..n-1 index.

    Args:
        df: One row per sample. Each row is passed to ``format_question`` and
            must also provide a ``context`` column.
        model: Object exposing ``generate(prompt)`` that returns the generated text.

    Returns:
        A deep copy of ``df`` (same rows, same index) with three added columns:
        ``prediction_text`` (cleaned model answer, possibly empty),
        ``no_answer_probability`` (currently always ``0.``) and ``pred_start``
        (offset of the first occurrence of the prediction inside the context,
        ``-1`` when the prediction is empty or does not occur in the context).
    """
    predictions = []
    starts = []

    for pos in range(len(df)):
        sample = df.iloc[pos]
        prompt = format_question(sample)
        # TODO: get prob of </s> token on output as no_answer_probability
        answer = _extract_answer(model.generate(prompt), prompt)

        predictions.append(answer)
        # An empty answer is never searched for: str.find('') would report offset 0.
        starts.append(str(sample['context']).find(answer) if answer != '' else -1)

    # Whole-column assignment from lists is positional, so index labels never matter and
    # no rows can be appended by accident.
    df_val = df.copy(deep=True)
    df_val['prediction_text'] = predictions
    df_val['no_answer_probability'] = 0.
    df_val['pred_start'] = starts

    return df_val

In [None]:
# Smoke test on a single sample: build the prompt, generate, clean the output and record
# the prediction before running the full inference loop.
idx = 9
df = dataset.to_pandas().assign(prediction='', pred_start=-1)

x = format_question(dataset[idx])
print(x)

y_pred = model.generate(x)
print(y_pred)

# Remove start / end tags + whitespace, then any surrounding punctuation.
answer_begin = len(x) + 3
y_pred_clean = y_pred[answer_begin:-4].strip().strip(string.punctuation)

df.loc[idx, 'prediction'] = y_pred_clean
if y_pred_clean:
    # Offset of the prediction inside the paragraph; stays -1 when the prediction is empty.
    context_text = str(df['context'].iloc[idx])
    df.loc[idx, 'pred_start'] = context_text.find(y_pred_clean)

In [None]:
# Show the row for the sample processed above, including its `prediction` and `pred_start`.
df.loc[idx]

In [None]:
#papermill_description=RUNNING_INFERENCE
pd_dataset = dataset.to_pandas()
if 0 < max_samples < len(pd_dataset):
    # Positional slicing excludes the stop position, so `[:max_samples]` keeps exactly
    # `max_samples` rows (the previous `max_samples - 1` dropped one sample too many).
    pd_dataset = pd_dataset.iloc[:max_samples]
df2 = squad_inference(pd_dataset, model)

In [None]:
# Preview the first row of the inference output.
df2.head(1)

## 4 - Evaluate performance

In [None]:
#papermill_description=EVALUATION
# Score the predictions with the `squad_v2` metric; predictions and references are
# passed as lists of per-sample records taken from the inference output.
squad_v2_metric = load("squad_v2")

prediction_columns = ['prediction_text', 'no_answer_probability', 'id']
reference_columns = ['answers', 'id']

predictions = df2[prediction_columns].to_dict(orient='records')
answers = df2[reference_columns].to_dict(orient='records')

results = squad_v2_metric.compute(predictions=predictions, references=answers)
results