In [None]:
from llama2 import *
from typing import List, Literal, Optional, Tuple, TypedDict
from pprint import pprint
import pandas as pd
import numpy as np
import datasets
import evaluate

from torch.utils.data import Dataset

In [None]:
# Run configuration -- tweak these before "Run All".

# Forwarded to LlamaModel as `model_resolution`.
# NOTE(review): presumably selects 4-bit quantised weights -- confirm against llama2.
model_size = 'int4'

# Upper bound on the number of examples scored by eval_dataset;
# a value <= 0 makes eval_dataset score every example.
max_samples = 100

## 1 - Load model

In [None]:
# Checkpoint identifier handed to the LlamaModel wrapper.
model_name = "meta-llama/Llama-2-7b-chat-hf"

# Build the model wrapper once; every later cell reuses this instance.
model = LlamaModel(model_name=model_name, model_resolution=model_size)

## 2 - Load data

In [None]:
# Validation split of HellaSwag, used below for exploration and prompt design.
dataset = datasets.load_dataset(
    'hellaswag',
    split='validation',
)

In [None]:
# Show the loaded split's summary as the cell output.
dataset

In [None]:
# Peek at the first record to see which fields each example carries.
dataset[0]

In [None]:
# Count the candidate endings of every example and check that the count is
# the same everywhere, so a fixed set of answer letters can be used later.
n_options = [len(example['endings']) for example in dataset]
fewest, most = min(n_options), max(n_options)
assert fewest == most
most, fewest  # displayed as (max, min); the original author notes every example has exactly four endings

In [None]:
# Prototype of the multiple-choice prompt, built from a single example.
x = dataset[2]

# Letter the endings a), b), c), ... one per tab-indented line.
options = [
    f"\t{chr(code)}) {option}"
    for code, option in enumerate(x['endings'], start=ord('a'))
]
endings = '\n'.join(options)

# Instruction block, then the context, then the lettered endings; the prompt
# stops right after "(" so the next generated token should be the answer letter.
text = (
    f'{model.B_INST} You are solving an entailment task, given the situation '
    'respond with the most appropriate completion. \n'
    f'Think logically and step by step. {model.E_INST}\n'
    f'{x["activity_label"]}. {x["ctx"]}\n'
    '\n'
    f'{endings}\n'
    '\n'
    'Answer:\n'
    '\n'
    'The correct answer is ('
)

print(text)

In [None]:
# Sanity check on the prototype prompt: run generation once and print
# whatever model.generate returns for it.
output = model.generate(text)
print(output)

In [None]:
class HellaSwag(Dataset):
    def __init__(self, split: str, inst_toks: List[str] | None = None):
        self.dataset = datasets.load_dataset('hellaswag', split='validation')
        self.inst_toks = inst_toks

    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx: int):
        x = self.dataset[idx]
        options = [f"\t{chr(ord('a')+i)}) {option}" for i, option in enumerate(x['endings'])]
        endings = '\n'.join(options)

        pretext = ''
        if self.inst_toks:
            pretext = (f'{self.inst_toks[0]} You are solving an entailment task, given the '
                        'situation respond with the most appropriate completion. \n'
                        f'Think logically and step by step. {self.inst_toks[1]}')
        text = pretext + (f'{x["activity_label"]} {x["ctx"]}\n\n'
                f'{endings}\n\n'
                'Answer:\n\nThe correct answer is (')
        
        y = x['label']
        return text, int(y)

In [None]:
def eval(model: LlamaModel, x):
    """Score one prompt and return the most likely answer letter.

    The prompt is tokenized, run through the underlying model without
    gradient tracking, and the logits at the last position are restricted
    to the tokens 'a', 'b', 'c' and 'd'. A softmax over those four logits
    gives the answer distribution.

    Note: this function shadows Python's builtin ``eval`` within the notebook.

    Args:
        model: Wrapper exposing ``tokenize``, ``tokenizer`` and ``model``.
        x: Prompt text to score.

    Returns:
        ``(index, probability)``: the position (0-3 for a-d) of the largest
        probability and that probability. The argmax/max are taken over the
        whole tensor, so this assumes a single prompt per call -- TODO confirm
        that ``model.tokenize`` returns a batch of one.
    """
    letter_ids = model.tokenizer.convert_tokens_to_ids(['a', 'b', 'c', 'd'])

    with torch.no_grad():
        input_ids = model.tokenize(x).to(model.model.device)
        logits = model.model(input_ids)['logits']
        # Keep only the last position and only the four answer-letter tokens.
        letter_logits = logits[:, -1, letter_ids]
        probs = torch.softmax(letter_logits, dim=1)

    best_index = probs.argmax().item()
    best_prob = probs.max().item()
    return best_index, best_prob


def eval_dataset(model, data, max_samples):
    """Score the first examples of ``data`` and tabulate the results.

    Args:
        model: Model wrapper passed through to ``eval``.
        data: Indexable collection of ``(prompt, label)`` pairs with a length.
        max_samples: Cap on the number of examples scored; a value <= 0
            scores the whole collection.

    Returns:
        DataFrame with one row per scored example and columns ``y`` (label),
        ``y_pred`` (predicted index) and ``prob`` (probability of the
        predicted answer).
    """
    n_eval = len(data)
    if max_samples > 0:
        n_eval = min(n_eval, max_samples)

    columns = {'y': [], 'y_pred': [], 'prob': []}
    for idx in range(n_eval):
        prompt, label = data[idx]
        choice, confidence = eval(model, prompt)

        columns['y'].append(label)
        columns['y_pred'].append(choice)
        columns['prob'].append(confidence)

    return pd.DataFrame(columns)


def get_metrics(df: pd.DataFrame, n_classes: int = 4) -> dict[str, str]:
    """Compute accuracy and a macro F1 score from a table of predictions.

    The F1 reported here is the harmonic mean of the macro-averaged
    precision and the macro-averaged recall (not the mean of per-class F1
    scores). Classes that never occur count as 0 in both averages.

    Args:
        df: Frame with integer-valued columns ``y`` (ground truth) and
            ``y_pred`` (prediction), each in ``range(n_classes)``.
        n_classes: Number of answer classes (defaults to 4).

    Returns:
        ``{'accuracy': ..., 'f1': ...}`` with both values formatted as
        strings with four decimals. Both are ``'nan'`` for an empty frame;
        F1 is ``'0.0000'`` when no prediction is correct.
    """
    if len(df) == 0:
        # Nothing to score: both metrics are undefined.
        undefined = f'{float("nan"):.4f}'
        return {'accuracy': undefined, 'f1': undefined}

    y_true = df['y'].to_numpy().astype(int)
    y_pred = df['y_pred'].to_numpy().astype(int)

    accuracy = float((y_pred == y_true).mean())

    # confusion[p, t] = number of examples predicted as p whose true class is t.
    confusion = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(confusion, (y_pred, y_true), 1)

    correct = np.diag(confusion).astype(float)
    predicted_totals = confusion.sum(axis=1)  # row sums: how often each class was predicted
    actual_totals = confusion.sum(axis=0)     # column sums: how often each class is the truth

    # Per-class ratios, defined as 0 where the denominator is 0.
    precision = np.divide(correct, predicted_totals,
                          out=np.zeros(n_classes), where=predicted_totals > 0)
    recall = np.divide(correct, actual_totals,
                       out=np.zeros(n_classes), where=actual_totals > 0)

    mp = float(precision.mean())
    mr = float(recall.mean())

    # With no correct prediction both averages are 0; report F1 = 0 rather than 0/0.
    f1 = 2 * mp * mr / (mp + mr) if (mp + mr) > 0 else 0.0

    return {'accuracy': f'{accuracy:.4f}', 'f1': f'{f1:.4f}'}

## 3 - Run inference

In [None]:
# Score the validation prompts without the instruction preamble.
eval_data = HellaSwag('validation')
df = eval_dataset(model, eval_data, max_samples)
# Alternative run with the instruction preamble wrapped in the model's tokens:
# df = eval_dataset(model, HellaSwag('validation', [model.B_INST, model.E_INST]), max_samples)

# Summary statistics of the probability assigned to each predicted answer.
df['prob'].describe()

In [None]:
# Sweep probability thresholds 0.0, 0.1, ..., 0.9 and report the metrics on
# the subset of predictions whose probability exceeds each threshold.
for threshold in np.linspace(0, 1, 10, endpoint=False):
    confident = df[df.prob > threshold]
    print(f'Class prob > {threshold:.1f}: {get_metrics(confident)}')