In [None]:
from llama2 import *
from typing import List, Literal, Optional, Tuple, TypedDict
import pandas as pd
import numpy as np
import datasets
import string
from evaluate import evaluator
from evaluate import load

from torch.utils.data import Dataset

In [None]:
# Run parameters; presumably overridden by papermill (TODO confirm this cell is tagged "parameters").
model_size = 'int4'  # passed to LlamaModel(model_resolution=...)
max_samples = -1  # values <= 0 (or >= the split size) mean: run on the full split

## 1 - Load model

In [None]:
#papermill_description=LOADING_MODEL
model_name = "meta-llama/Llama-2-7b-chat-hf"

# Wrap the chat checkpoint at the resolution chosen in the parameters cell.
model = LlamaModel(model_name=model_name, model_resolution=model_size)


def _no_device_copy(x):
    """Identity replacement for the underlying model's ``.to``: return ``x`` unchanged."""
    return x


# Disable device copying on the wrapped model.
# NOTE(review): the replacement returns its argument (e.g. the device), not the
# module, so any `m = model.model.to(dev)` caller would receive `dev` — confirm
# no caller relies on the return value.
model.model.to = _no_device_copy

## 2 - Load dataset

In [None]:
#papermill_description=LOADING_DATA
# GLUE MNLI, matched validation split (in-domain genres), via Hugging Face `datasets`.
dataset = datasets.load_dataset(path='glue', name='mnli', split='validation_matched')

In [None]:
# Display a summary of the loaded validation split.
dataset

In [None]:
# Peek at one raw example to see the fields used below (premise, hypothesis, label).
dataset[5]

## 3 - Define data prep and model inference functions

In [None]:
def format_question(sample: dict) -> str:
    """Build the Llama-2 chat prompt for one MNLI premise/hypothesis pair.

    The instruction is wrapped in ``model.B_INST`` / ``model.E_INST``. These are
    read from the module-level ``model`` created in the "Load model" section.
    The prompt ends with ``"The response is:"``, so the model's next token
    should be (the start of) one of the label words.

    Args:
        sample: A row with ``"premise"`` and ``"hypothesis"`` entries. It can be
            a plain dict or a pandas Series such as ``df.iloc[i]``.

    Returns:
        The complete prompt string: instruction block followed by the question.
    """
    instruction = (
        f'{model.B_INST} You are performing natural language inference tasks. '
        'Given a premise and hypothesis, decide whether we have an entailment, neutral or contradiction relation. '
        'Only respond with the words "The response is " and one of "neutral, contradiction, entailment"'
        f' {model.E_INST}\n'  # Llama system directive
    )
    question = (
        f'Premise: {sample["premise"]}\n'
        f'Hypothesis: {sample["hypothesis"]}\n'
        'Answer: \n\nThe response is:'
    )
    return instruction + question


def glue_inference(
    df: pd.DataFrame,
    model,
    labels: Tuple[str, ...] = ('entailment', 'neutral', 'contradiction'),
) -> pd.DataFrame:
    """Classify every premise/hypothesis pair in ``df`` with the language model.

    For each row:

    1. The prompt from ``format_question`` is run through ``model.model``.
    2. The last-position logits are restricted to the first token of each word
       in ``labels``.
    3. A softmax over those restricted logits gives a distribution over
       ``labels``. Its argmax is the prediction and its max is the confidence.

    Args:
        df: Frame with ``premise`` and ``hypothesis`` columns, e.g. a GLUE
            MNLI split converted with ``Dataset.to_pandas()``. Not modified.
        model: Wrapper exposing ``tokenizer.encode``, ``tokenize`` and a
            ``model`` callable whose output has ``.logits``. The logits are
            assumed to be (batch, seq, vocab) with batch 1 — TODO confirm in
            llama2.
        labels: Label words, listed in the order of the dataset's integer
            label ids. The default matches GLUE MNLI (0=entailment, 1=neutral,
            2=contradiction), so ``prediction`` is directly comparable with the
            ``label`` column.

    Returns:
        A copy of ``df`` with two added columns:
        ``prediction`` (int, index into ``labels``) and ``prob`` (float,
        softmax probability of that label among ``labels``).

    Raises:
        ValueError: If two labels share the same first token, since they
            could not be told apart.
    """
    import torch  # explicit import; the notebook otherwise relies on `from llama2 import *`

    df_val = df.copy(deep=True)

    # Only the first sub-token of each label word is scored at the answer position.
    label_ids = [model.tokenizer.encode(label, add_special_tokens=False)[0] for label in labels]
    if len(set(label_ids)) != len(label_ids):
        raise ValueError(f'Labels {list(labels)} do not have distinct first tokens: {label_ids}')

    predictions = []
    confidences = []
    with torch.no_grad():
        for _, sample in df.iterrows():
            tokens = model.tokenize(format_question(sample))
            logits = model.model(tokens).logits
            # Renormalise over the label tokens only, at the last sequence position.
            probs = torch.softmax(logits[:, -1, label_ids], dim=1)
            predictions.append(probs.argmax().item())
            confidences.append(probs.max().item())

    # Assign whole columns aligned to df_val's own index, so rows never misalign
    # (even for a non-default index) and an empty input still gets both columns.
    df_val['prediction'] = pd.Series(predictions, index=df_val.index, dtype=int)
    df_val['prob'] = pd.Series(confidences, index=df_val.index, dtype=float)
    return df_val

In [None]:
#papermill_description=RUNNING_INFERENCE
pd_dataset = dataset.to_pandas()
if 0 < max_samples < len(pd_dataset):
    # Keep exactly the first `max_samples` rows for a quick, partial run.
    pd_dataset = pd_dataset.iloc[:max_samples]
df2 = glue_inference(pd_dataset, model)

## 4 - Evaluate performance

In [None]:
#papermill_description=EVALUATION
glue_metric = load("glue", "mnli_matched")

# Sweep a minimum-confidence cut-off. At each threshold, score only the samples
# whose predicted-label probability exceeds it, and report how many remain
# (selective accuracy vs. coverage).
n_total = len(df2)
for threshold in np.linspace(0, 1, 10, endpoint=False):
    df_filt = df2[df2.prob > threshold]
    if df_filt.empty:
        # No sample is this confident; the metric is undefined on an empty subset.
        print(f'Class prob > {threshold:.1f}: no samples (n=0/{n_total})')
        continue

    predictions = df_filt['prediction'].to_list()
    answers = df_filt['label'].to_list()

    results = glue_metric.compute(predictions=predictions, references=answers)
    print(f'Class prob > {threshold:.1f}: {results} (n={len(df_filt)}/{n_total})')