In [None]:
import os
from openai import OpenAI
from dotenv import load_dotenv
from datasets import load_dataset, Dataset
from tqdm.auto import tqdm
import instructor
from pydantic import BaseModel, Field
from typing import List
import wandb

# Load settings from a local .env file (VLLM_ENDPOINT_URL and HF_HUB_DATASET_NAME are
# read later via os.getenv); override=True lets values from .env replace variables that
# are already set in the process environment.
load_dotenv(override=True)

In [None]:
def format_inputs(ds: Dataset) -> List[str]:
    """Render each row's shuffled sentences as a numbered prompt string.

    Every row of ``ds`` must have a ``"shuffled_sentences"`` sequence. Sentence ``i``
    becomes the line ``"{i}. {sentence}\\n"``, so the serial numbers shown to the
    model start at 0.

    Args:
        ds: Dataset (or any sized iterable of dict-like rows) to format.

    Returns:
        One prompt string per row, in dataset order.
    """
    # Iterate ``ds`` directly (not ``enumerate(ds)``) so tqdm can read ``len(ds)``
    # and show a real total instead of a bare counter.
    return [
        "".join(f"{idx}. {sent}\n" for idx, sent in enumerate(row["shuffled_sentences"]))
        for row in tqdm(ds, desc="Formatting inputs")
    ]

In [None]:
class PredictionResult(BaseModel):
    """Structured output requested from the LLM (used as instructor's ``response_model``)."""

    # Serial numbers of the shuffled sentences, listed in the predicted original order.
    # ``...`` makes the field required. The explanation goes in ``description`` so that it
    # appears in the model's JSON schema; passed positionally, it would instead become the
    # field's default value (a str, which is not a valid List[int]).
    predicted_order: List[int] = Field(
        ...,
        description="A list of integers in the range 0 up to 4, which contains the serial numbers of sentences in the correct order.",
    )

In [None]:
def predict(
    client, 
    inp, 
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct"
) -> PredictionResult:
    """Ask the model to put one set of shuffled sentences back in order.

    Args:
        client: An instructor-wrapped OpenAI client (as built in ``batch_predict``)
            whose ``chat.completions.create`` accepts ``response_model``.
        inp: The numbered sentences, one per line (e.g. as produced by ``format_inputs``).
        model_name: Model identifier sent to the completion endpoint.

    Returns:
        The ``PredictionResult`` returned by ``client.chat.completions.create``.
    """
    # Build the prompt from unindented lines so that the function's source indentation
    # is not copied into every line the model sees.
    system_message = "\n".join([
        "You are given 5 sentences from a story in a shuffled order.",
        "Each sentence has a serial number associated with it.",
        "You can find the serial number at the beginning of each sentence, for example: 0. <sentence>.",
        "The serial numbers start from 0 and end at 4.",
        "Your task is to predict the correct order of sentences which would resemble the original story and then return the serial numbers as a list of integers, which contain the serial numbers.",
        "",
        "#IMPORTANT",
        "There are exactly five sentences.",
    ])

    return client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": inp},
        ],
        temperature=0.0,
        response_model=PredictionResult,
    )

def batch_predict(
    ds: Dataset,
    model_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
) -> List[PredictionResult]:
    """Run ``predict`` on every row of ``ds`` against the endpoint in ``VLLM_ENDPOINT_URL``.

    Args:
        ds: Rows that have a ``"shuffled_sentences"`` field (see ``format_inputs``).
        model_name: Model identifier forwarded to ``predict``. The default matches
            ``predict``'s own default, so existing callers behave as before.

    Returns:
        One ``PredictionResult`` per row, in dataset order.
    """
    # Use the key from VLLM_API_KEY when it is set, and fall back to the previous
    # hard-coded placeholder otherwise.
    client = instructor.from_openai(OpenAI(
        api_key=os.getenv("VLLM_API_KEY", "token-abc123"),
        base_url=os.getenv("VLLM_ENDPOINT_URL"),
    ))

    formatted_inputs = format_inputs(ds)

    # Iterate the list directly so tqdm knows the total (enumerate() hides len()).
    return [
        predict(client, inp, model_name=model_name)
        for inp in tqdm(formatted_inputs, desc="Predicting")
    ]

In [None]:
from scipy.stats import kendalltau

def overlapping_accuracy(gold, predicted):
    """Return the fraction of indices ``i`` where ``gold[i] == predicted[i]``.

    For flat lists of ints this is per-position accuracy. For lists of orderings
    (which is what ``run`` passes through ``evaluate``) it is the fraction of orderings
    that match exactly. Empty input returns ``nan`` instead of dividing by zero.
    """
    assert len(gold) == len(predicted), (
        f"gold and predicted differ in length: {len(gold)} != {len(predicted)}"
    )
    if len(gold) == 0:
        return float("nan")

    overlaps = sum(1 for g, p in zip(gold, predicted) if g == p)
    return overlaps / len(gold)


def _nanmean(values):
    """Return the mean of ``values`` ignoring NaNs, or NaN if nothing remains."""
    import math

    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else float("nan")


# https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kendalltau.html

def evaluate(gold: List, predicted: List):
    """Score predicted sentence orderings against gold orderings.

    Accepts either a single ordering (a flat list of ints) or a batch of orderings
    (a list of lists, as built in ``run``). For a batch, Kendall's tau is computed
    separately for each (gold, predicted) pair and then averaged, so pairs are never
    compared across different stories. (Passing the nested lists directly to
    ``kendalltau`` would flatten them into a single sequence.) Samples whose tau is
    NaN, e.g. a constant prediction, are left out of the average.

    Returns:
        dict with:
          - ``tau_stat``: Kendall's tau (the per-sample mean for a batch).
          - ``tau_p``: the p-value (for a batch, the mean of the per-sample p-values;
            descriptive only).
          - ``correlation``: same value as ``tau_stat``.
          - ``overlap_accuracy``: result of ``overlapping_accuracy(gold, predicted)``.
    """
    # Checked first: it also validates that gold and predicted have the same length.
    acc = overlapping_accuracy(gold, predicted)

    is_batch = len(gold) > 0 and hasattr(gold[0], "__len__")
    if is_batch:
        per_sample = [kendalltau(x=g, y=p) for g, p in zip(gold, predicted)]
        tau_stat = _nanmean([r.statistic for r in per_sample])
        tau_p = _nanmean([r.pvalue for r in per_sample])
    else:
        tau = kendalltau(x=gold, y=predicted)
        tau_stat, tau_p = tau.statistic, tau.pvalue

    return {
        "tau_stat": tau_stat,
        "tau_p": tau_p,
        "correlation": tau_stat,
        "overlap_accuracy": acc,
    }

In [None]:
from loguru import logger


def run(
    dataset_size: int = 2, 
    model_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
):
    """Evaluate sentence ordering on the first ``dataset_size`` training rows and log to wandb.

    Loads the Hugging Face dataset named by the ``HF_HUB_DATASET_NAME`` environment
    variable, predicts an order for each sampled row, scores the predictions against
    each row's ``"gold_order"`` with ``evaluate``, and logs the metrics to the
    ``llm-sentence-ordering`` wandb project.

    Args:
        dataset_size: Number of rows taken from the start of the ``train`` split
            (capped at the split's length).
        model_name: Recorded in the wandb config and used as the wandb run name.

    Returns:
        The metrics dict produced by ``evaluate``.

    Raises:
        ValueError: If ``HF_HUB_DATASET_NAME`` is not set.
    """
    logger.info("Preparing Wandb Init")
    wandb.init(
        project="llm-sentence-ordering",
        config={
            "model_name": model_name,
            "dataset_size": dataset_size
        },
        # Name the run after the configured model rather than a fixed string.
        name=model_name,
    )

    try:
        logger.info(f"Model: {model_name} :: Dataset Size: {dataset_size}")

        logger.info("Loading dataset")
        dataset_name = os.getenv("HF_HUB_DATASET_NAME")
        if not dataset_name:
            raise ValueError("HF_HUB_DATASET_NAME is not set")
        ds = load_dataset(dataset_name)
        train_split = ds["train"]
        sampled_ds = train_split.select(range(min(dataset_size, len(train_split))))

        gold = [data["gold_order"] for data in sampled_ds]

        logger.info("Running Prediction")
        # NOTE(review): model_name is not forwarded here, so predictions use
        # batch_predict's own model selection regardless of this argument.
        llm_pred_results = batch_predict(sampled_ds)
        predicted = [p.predicted_order for p in llm_pred_results]

        logger.info("Evaluating")
        eval_results = evaluate(gold, predicted)

        logger.info("Logging results")
        wandb.log(eval_results)
    finally:
        # Close the wandb run even if loading, prediction or evaluation raises.
        wandb.finish()

    return eval_results

In [None]:
# Pass the experiment settings explicitly (these are run()'s defaults) so the
# configuration being evaluated is visible at the call site.
run(dataset_size=2, model_name="meta-llama/Meta-Llama-3.1-8B-Instruct")