In [None]:
import json
import os
import time
import logging
from pathlib import Path

from preprocessing import DataPreprocessor
from server import SingleVariantServer
from load_generator import ClosedLoopLoadGenerator
from metrics import MetricsCalculator
from evaluation import HeldOutEvaluator

# Configure root logging for the notebook.
# basicConfig() is a silent no-op whenever the root logger already has handlers
# (e.g. installed by the kernel, or by a previous run of this very cell), so edits
# to level/format would never take effect. `force=True` (Python 3.8+) removes any
# existing root handlers first, making this cell effective and safe to re-run.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    force=True,
)
logger = logging.getLogger(__name__)

In [None]:
# =========================================================
# Configuration: every knob a reader might want to tune
# =========================================================

# Set to True to run DataPreprocessor (raw DATA_DIR -> PROCESSED_DIR) first.
PREPROCESS = False

# Raw inputs and the preprocessed *_data.jsonl splits read by load_data().
DATA_DIR = "data/raw"
PROCESSED_DIR = "data/processed"

# Model to serve: a model id, or a local checkpoint directory instead, e.g.
# MODEL_NAME_OR_PATH = "/mnt/models/llama-2-7b-chat"
MODEL_NAME_OR_PATH = "meta-llama/Llama-2-7b-chat-hf"

# Passed straight through to SingleVariantServer.
DEVICE = "cuda"
DTYPE = "auto"

# Load-test shape: requests per run, and one run per concurrency level.
NUM_REQUESTS = 5_000
CONCURRENCIES = [1, 2, 4, 8, 16, 32]

# Cap on the number of val/test examples; 0 keeps the full splits.
DATA_SUBSET = 0

# All artifacts (per-level metrics, per-request logs, eval results, summary) go here.
OUTPUT_DIR = "results/baseline_med"
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)


In [None]:
if PREPROCESS:
    logger.info("[STEP 0] Preprocessing data")

    # Run the preprocessing pipeline with PROCESSED_DIR as its output directory.
    # NOTE: the next cell re-reads the splits from PROCESSED_DIR, so the
    # in-memory splits bound here are superseded by that cell.
    preprocessor = DataPreprocessor(output_dir=PROCESSED_DIR, data_dir=DATA_DIR)
    (train_data, val_data, test_data) = preprocessor.run_pipeline()


In [None]:
def load_data(data_dir):
    """Load the train/val/test JSONL splits from ``data_dir``.

    Each split is read from ``<data_dir>/<split>_data.jsonl``; every non-blank
    line is parsed as one JSON record.

    Args:
        data_dir: Directory containing the ``*_data.jsonl`` files.

    Returns:
        ``(train, val, test)``: three lists of parsed records. A split whose
        file is missing is logged as a warning and returned as an empty list,
        so the caller decides whether that is fatal.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON; the message is
            prefixed with the file path and 1-based line number.
    """
    splits = {"train": [], "val": [], "test": []}

    for split, records in splits.items():
        path = os.path.join(data_dir, f"{split}_data.jsonl")
        if not os.path.exists(path):
            logger.warning(f"Missing {path}")
            continue

        # JSON text is UTF-8 (RFC 8259); without an explicit encoding, open()
        # uses the platform locale and can mis-decode or fail on non-ASCII data.
        with open(path, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as exc:
                    # Same exception type, but with file/line context for debugging.
                    raise json.JSONDecodeError(
                        f"{path}:{lineno}: {exc.msg}", exc.doc, exc.pos
                    ) from exc

    return splits["train"], splits["val"], splits["test"]


logger.info("[STEP 1] Loading processed data")
train_data, val_data, test_data = load_data(PROCESSED_DIR)

# val_data feeds the load tests and test_data the held-out evaluation, so both
# must be non-empty. Raise explicitly (not `assert`) so the check also runs
# under `python -O`, and name which split(s) came back empty.
empty_splits = [name for name, rows in (("val", val_data), ("test", test_data)) if not rows]
if empty_splits:
    raise RuntimeError(
        f"Validation/Test data missing: empty split(s) {empty_splits} in {PROCESSED_DIR}"
    )

# Optional cap for quick runs; 0 (or None) keeps the full splits.
if DATA_SUBSET and DATA_SUBSET > 0:
    val_data = val_data[:DATA_SUBSET]
    test_data = test_data[:DATA_SUBSET]

logger.info(f"Loaded val={len(val_data)}, test={len(test_data)}")


In [None]:
logger.info("[STEP 2] Initializing server")

# One model instance, reused by the sanity check, every load-test run and the
# held-out evaluation (all of them receive `server` / `server.generate`).
server = SingleVariantServer(
    model_name=MODEL_NAME_OR_PATH,
    device=DEVICE,
    dtype=DTYPE,
    variant="med",
)

In [None]:
# Quick smoke test before the expensive sweep: a few requests at concurrency 1.
# Cell-specific names keep the raw output of run() here distinct from `metrics`,
# which the load-test loop binds to aggregated results from compute_all_metrics().
SANITY_NUM_REQUESTS = 10

sanity_load_gen = ClosedLoopLoadGenerator(
    inference_func=server.generate,
    max_concurrency=1,
    num_requests=SANITY_NUM_REQUESTS,
    data_loader=val_data,
)

sanity_raw_metrics = sanity_load_gen.run()
sanity_calc = MetricsCalculator(sanity_raw_metrics)
sanity_calc.print_report("SANITY CHECK (Concurrency=1)")


In [None]:
# Sweep the configured concurrency levels. Each level's metrics and request log
# are written to OUTPUT_DIR inside the loop, so completed levels are kept on disk
# even if a later level fails.
all_metrics_summary = {}

for concurrency in CONCURRENCIES:
    logger.info(f"Running load test: concurrency={concurrency}")

    load_gen = ClosedLoopLoadGenerator(
        inference_func=server.generate,
        max_concurrency=concurrency,
        num_requests=NUM_REQUESTS,
        data_loader=val_data
    )

    # perf_counter() is monotonic and high-resolution. time.time() follows the
    # wall clock, which clock adjustments can move mid-run, skewing durations.
    start = time.perf_counter()
    raw_metrics = load_gen.run()
    duration = time.perf_counter() - start

    calc = MetricsCalculator(raw_metrics)
    metrics = calc.compute_all_metrics()

    calc.print_report(f"Concurrency {concurrency}")

    # Persist this level's results before moving on to the next one.
    calc.save_metrics(os.path.join(OUTPUT_DIR, f"metrics_{concurrency}.json"))
    load_gen.save_metrics(os.path.join(OUTPUT_DIR, f"requests_{concurrency}.jsonl"))

    all_metrics_summary[concurrency] = {
        "metrics": metrics,
        "duration_sec": duration,
    }


In [None]:
# Held-out evaluation on the test split (the load tests above only use val_data).
evaluator = HeldOutEvaluator(
    model=server,
    data_loader=test_data,
    batch_size=32
)

eval_results = evaluator.evaluate()

# Serialize before opening the file. json.dump() writes chunk by chunk, so a
# value it cannot encode would leave a truncated eval_results.json behind.
# The text written is identical to json.dump(..., indent=2).
eval_results_json = json.dumps(eval_results, indent=2)
with open(os.path.join(OUTPUT_DIR, "eval_results.json"), "w", encoding="utf-8") as f:
    f.write(eval_results_json)

eval_results

In [None]:
# One self-describing record of the run: load-test results, evaluation results,
# and every config value that affects them (dtype and the data subset change the
# numbers, so they are recorded alongside the model and load shape).
# Note: JSON object keys are strings, so the integer concurrency keys of
# all_metrics_summary are written as "1", "2", ...
summary = {
    "load_tests": all_metrics_summary,
    "eval_results": eval_results,
    "config": {
        "model": MODEL_NAME_OR_PATH,
        "device": DEVICE,
        "dtype": DTYPE,
        "num_requests": NUM_REQUESTS,
        "concurrencies": CONCURRENCIES,
        "data_subset": DATA_SUBSET,
        "processed_dir": PROCESSED_DIR,
    }
}

# Serialize first so a value json cannot encode never leaves a half-written summary.json.
summary_json = json.dumps(summary, indent=2)
with open(os.path.join(OUTPUT_DIR, "summary.json"), "w", encoding="utf-8") as f:
    f.write(summary_json)

logger.info("Evaluation complete")
