In [None]:
import json, sagemaker
from transformers import AutoTokenizer
# from sagemaker.s3 import S3Downloader
from sagemaker import serializers, deserializers

from utils import data_utils

In [None]:
import os

# SageMaker session handed to the Predictor that talks to the endpoint.
session = sagemaker.session.Session()

dataset_id = "deepmind/code_contests"
model_id = "mistral-community/Codestral-22B-v0.1"
# Both values below are specific to one machine / one deployment. They can be
# overridden through environment variables so the notebook runs elsewhere
# without editing this cell; the defaults are the original values.
test_dataset_local_path = os.environ.get(
    "EVAL_TEST_DATASET_PATH",
    "/home/ubuntu/finetune-llms-on-aws/practise-fsdp/sft_cache/data/test_dataset.json",
)
endpoint_name = os.environ.get(
    "EVAL_ENDPOINT_NAME",
    "codestral-vllm-2024-06-18-16-52-35-354",
)


In [None]:
# NOTE(review): this loads CodeLlama's tokenizer even though `model_id` names
# Codestral. In this notebook the tokenizer is only used to render the chat
# template for prompts, so a template mismatch would change the prompt format
# the endpoint receives — confirm this choice is intentional.
tokenizer_repo = "codellama/CodeLlama-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)

In [None]:
# Load and preprocess the test split via the (presumably project-local)
# `data_utils` helper.
test_dataset = data_utils.load_and_process(split="test", dataset_id=dataset_id)
print(f"test_dataset: {test_dataset}")

# Despite the name, this is a fixed row (index 345), not a random draw.
random_sample = test_dataset[345]

In [None]:
# Client for the existing endpoint: request payloads are serialized to JSON and
# JSON responses are parsed back into Python objects.
predictor = sagemaker.Predictor(
    sagemaker_session=session,
    endpoint_name=endpoint_name,
    deserializer=deserializers.JSONDeserializer(),
    serializer=serializers.JSONSerializer(),
)

In [None]:
def request(sample, max_new_tokens: int = 512):
    """Render a conversation with the chat template and send it to the endpoint.

    Args:
        sample: the conversation to render — a list of chat messages
            (``{"role": ..., "content": ...}`` dicts) passed straight to
            ``tokenizer.apply_chat_template``.
        max_new_tokens: generation budget forwarded to the endpoint. Defaults
            to 512, the value previously hard-coded here.

    Returns:
        The deserialized JSON response from ``predictor.predict``, unchanged.
        NOTE(review): the completion is presumably under ``"generated_text"``;
        the exact shape (dict vs. one-element list) depends on the serving
        container — confirm against the endpoint.
    """
    # Render to a string (not token ids) and append the generation prompt so the
    # model continues as the assistant.
    prompt = tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=True)

    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,         # greedy decoding
            "return_full_text": False,  # ask for the completion only, not prompt + completion
        },
    }
    return predictor.predict(payload)
  
# print(random_sample["messages"][1])

# request(random_sample["messages"][:2])

In [None]:
# Sanity check: evaluate() uses messages[:2] as the prompt and messages[2] as the
# reference answer, so this is expected to be at least 3.
num_messages = len(random_sample["messages"])
num_messages

In [None]:
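# TODO: replace exact-match accuracy with the n@k metric (share of problems solved when n submissions are chosen from k samples per problem).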

from tqdm import tqdm
 
def _extract_generated_text(response) -> str:
    """Return the completion text from a raw endpoint response.

    ``request`` returns the deserialized endpoint response unchanged, so the
    text has to be dug out here. NOTE(review): both a bare
    ``{"generated_text": ...}`` dict and a one-element list of such dicts are
    accepted because the shape depends on the serving container. An
    already-built chat message (``{"role": ..., "content": ...}``) is accepted
    too.
    """
    if isinstance(response, list):
        response = response[0]
    if "content" in response:
        return response["content"]
    return response["generated_text"]


def evaluate(sample):
    """Score one conversation by exact match against its reference answer.

    The first two messages are sent as the prompt; the third message's content
    is the reference. Leading/trailing whitespace is stripped from both sides
    before comparing.

    Returns:
        1 if the generated text equals the reference, otherwise 0.
    """
    # `request` returns the raw response, which has no "content" key, so the
    # generated text must be extracted rather than indexed directly.
    response = request(sample["messages"][:2])
    predicted = _extract_generated_text(response).strip()
    reference = sample["messages"][2]["content"].strip()
    return int(predicted == reference)
 
success_rate = []  # per-sample outcomes: 1 = exact match, 0 = mismatch
number_of_eval_samples = 1000
shuffle_seed = 42  # fixed seed so the evaluated subset is the same on every run

# `select` raises on out-of-range indices, so never ask for more rows than exist.
n_eval = min(number_of_eval_samples, len(test_dataset))
eval_subset = test_dataset.shuffle(seed=shuffle_seed).select(range(n_eval))

# iterate over the eval subset and score each prediction
for s in tqdm(eval_subset):
    success_rate.append(evaluate(s))

if not success_rate:
    raise ValueError("No evaluation samples were scored: test_dataset is empty.")

# compute accuracy (fraction of exact matches)
accuracy = sum(success_rate) / len(success_rate)

print(f"Accuracy: {accuracy*100:.2f}%")