In [None]:
from kaggle_secrets import UserSecretsClient
# Read the Hugging Face access token from the Kaggle secret labelled "hugging_token";
# it is passed to huggingface_hub.login() in a later cell.
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("hugging_token")

In [None]:
import os, json, time
from pathlib import Path
from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from huggingface_hub import login


# =========================
# CONFIG
# =========================
# Input dataset (JSON Lines). This is a plain string, not an f-string, so
# "{Dataset_path}" is literal text — it looks like a placeholder to be replaced
# with the real dataset directory before running.
DATASET_JSONL = "{Dataset_path}/sampled_arithmetic.jsonl"

# Which checkpoint to run; must be one of the keys of MODEL_MAP.
SELECTED_MODEL = "gemma"  # options: qwen | gemma | mistral

# Resume options: dataset rows [START_INDEX, END_INDEX) are considered.
START_INDEX = 0       # change if resuming (e.g. 6568)
END_INDEX = None      # None = go to end

# Short model key -> Hugging Face Hub repo id.
MODEL_MAP = dict(
    qwen="Qwen/Qwen2.5-7B-Instruct",
    gemma="google/gemma-2-9b-it",
    mistral="mistralai/Mistral-7B-Instruct-v0.3",
)

# Everything below is derived from the selected model key.
MODEL_NAME = MODEL_MAP[SELECTED_MODEL]
OUTPUT_DIR = f"/kaggle/working/{SELECTED_MODEL}_outputs"
OUTPUT_JSONL = f"{SELECTED_MODEL}_predictions.jsonl"

# Create the per-model output directory up front; exist_ok=True makes re-running the cell harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)


# --- Generation settings ---
BATCH_SIZE = 3            # prompts per generate() call
MAX_NEW_TOKENS = 256      # cap on newly generated tokens per prompt
TEMPERATURE = 0.0         # 0.0 -> greedy decoding (do_sample is derived from this)
TOP_P = 0.9               # nucleus cutoff; only meaningful when sampling
USE_8BIT = False          # load the model with 8-bit quantization
USE_BF16 = False          # bfloat16 instead of float16 (ignored with USE_8BIT or without CUDA)
# NOTE(review): Gemma-2 is commonly recommended to run in bfloat16; float16 may
# overflow activations for that family — confirm output quality or set USE_BF16.
SEED = 13

# --- Time budget ---
# Total wall-clock budget is 11h55m. run_full_pass stops starting new batches
# SAFETY_BUFFER_SEC before that, i.e. after roughly 11h54m.
MAX_SECONDS = (11 * 60 + 55) * 60
SAFETY_BUFFER_SEC = 60
SAVE_EVERY = 50           # flush buffered predictions to disk about every N records





# =========================
# Print Logs
# =========================
# Echo the key run settings so the notebook log records what this run used.
print("\n".join(
    f"{label}:{value}"
    for label, value in (
        ("OUTPUT_DIR", OUTPUT_DIR),
        ("MODEL_NAME", MODEL_NAME),
        ("START_INDEX", START_INDEX),
        ("BATCH_SIZE ", BATCH_SIZE),
    )
))


# =========================
# Hugging Face login
# =========================
# Authenticate against the Hub with the token read from Kaggle secrets
# (presumably needed because some of the MODEL_MAP repos are gated — confirm).
login(token=hf_token)

# =========================
# Reproducibility
# =========================
# Seed the CPU generator, and every GPU generator when CUDA is available.
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# =========================
# Load model & tokenizer
# =========================
# Keyword arguments for from_pretrained: either 8-bit quantized weights or a
# half-precision dtype, always placed automatically across available devices.
kwargs = {}
if USE_8BIT:
    # Passing `load_in_8bit=True` directly to from_pretrained is deprecated in
    # recent transformers releases; the supported route is a BitsAndBytesConfig.
    from transformers import BitsAndBytesConfig
    kwargs.update(dict(quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                       device_map="auto"))
else:
    torch_dtype = torch.bfloat16 if USE_BF16 and torch.cuda.is_available() else torch.float16
    kwargs.update(dict(torch_dtype=torch_dtype, device_map="auto"))

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # Batched tokenization needs a pad token; fall back to EOS when none is defined.
    tokenizer.pad_token = tokenizer.eos_token
# Decoder-only models must be left-padded for batched generation. With right
# padding, shorter prompts get pad tokens between the prompt and the generated
# continuation. The default side varies between tokenizers, so force it here.
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **kwargs)
model.eval()
if torch.cuda.is_available():
    # TF32 only affects float32 matmuls; it is a free speedup where it applies.
    torch.backends.cuda.matmul.allow_tf32 = True

# =========================
# IO helpers
# =========================
out_dir = Path(OUTPUT_DIR)
out_dir.mkdir(parents=True, exist_ok=True)
# Predictions file: run_full_pass appends to it and reads it back to resume.
out_path = out_dir.joinpath(OUTPUT_JSONL)

def stream_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return its records as a list.

    Despite the name, the whole file is loaded into memory at once.
    Whitespace-only lines are skipped; every other line must be valid JSON,
    otherwise ``json.JSONDecodeError`` propagates.
    """
    records: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                continue
            records.append(json.loads(raw))
    return records

def load_completed_ids(existing_path: Path) -> Dict[str, int]:
    """Collect the ids already present in a predictions JSONL file.

    Used to resume an interrupted run: any id returned here is skipped.

    Args:
        existing_path: predictions file (``Path`` or ``str``). A missing file
            is treated as "nothing done yet".

    Returns:
        A dict mapping each truthy id to ``1``; callers only test membership.
        Blank lines, lines that are not valid JSON (e.g. a partially written
        last line from a killed run), non-object lines, and records whose
        ``id`` is missing, empty, or unhashable are skipped. The number of
        unparseable lines is printed so corruption does not go unnoticed.
    """
    path = Path(existing_path)
    done: Dict[str, int] = {}
    if not path.exists():
        return done

    bad_lines = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                bad_lines += 1
                continue
            if not isinstance(obj, dict):
                continue
            _id = obj.get("id")
            # JSON arrays/objects are unhashable and cannot be used as dict keys.
            if _id and not isinstance(_id, (list, dict)):
                done[_id] = 1

    if bad_lines:
        print(f"[WARN] skipped {bad_lines} unparseable line(s) in {path}")
    return done

def append_jsonl(records: List[Dict[str, Any]], path: Path) -> None:
    """Append each record to *path* as one JSON line (file created if absent).

    Non-ASCII characters are written verbatim as UTF-8 rather than escaped.
    When *records* is empty the file is not opened at all.
    """
    if not records:
        return
    with open(path, "a", encoding="utf-8") as sink:
        # writelines consumes the generator lazily, so records are serialized
        # and written one at a time.
        sink.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)


# =========================
# Batched generation
# =========================
@torch.inference_mode()
def generate_batch(prompts: List[str]) -> List[str]:
    """Generate one completion per prompt with a single batched model call.

    Prompts are tokenized together: left-padded to a common length and
    truncated to 4096 tokens. Decoding is greedy when TEMPERATURE is 0, and
    otherwise uses temperature/top-p sampling. Only the newly generated tokens
    are decoded, with special tokens dropped. If the text contains
    "Assistant:", only the part after its last occurrence is kept.

    Args:
        prompts: raw prompt strings.

    Returns:
        One stripped response string per prompt, in input order
        (an empty list for an empty batch).
    """
    if not prompts:
        # Nothing to generate; avoid calling the tokenizer/model with an empty batch.
        return []

    # Decoder-only generation needs left padding so every row's prompt ends
    # exactly where generation begins. Idempotent if already set at load time.
    tokenizer.padding_side = "left"
    inputs = tokenizer(prompts, return_tensors="pt", padding=True,
                       truncation=True, max_length=4096).to(model.device)

    do_sample = TEMPERATURE > 0.0
    # temperature/top_p are sampling-only flags. Passing them to a greedy config
    # (especially temperature=0.0) only triggers invalid-flag warnings per batch.
    sampling_kwargs = dict(temperature=TEMPERATURE, top_p=TOP_P) if do_sample else {}
    gen_cfg = GenerationConfig(
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        **sampling_kwargs,
    )
    # NOTE(review): do_sample is repeated as an explicit kwarg. Some transformers
    # versions may replace GenerationConfig fields that equal the library default
    # (do_sample=False does) with the model's own generation_config defaults,
    # while explicit generate() kwargs are applied afterwards.
    outputs = model.generate(**inputs, generation_config=gen_cfg,
                             do_sample=do_sample, use_cache=True)

    # All rows share the padded prompt width, so generated tokens start there.
    prompt_width = inputs["input_ids"].shape[1]
    decoded = []
    for row in outputs:
        text = tokenizer.decode(row[prompt_width:], skip_special_tokens=True).strip()
        if "Assistant:" in text:
            text = text.split("Assistant:")[-1].strip()
        decoded.append(text)
    return decoded

# =========================
# Time zone handling
# ========================   
def get_timezone_context(obj):
    """Pick the timezone fields of a dataset row for prompt building.

    Returns:
        ``{"tz_src": ..., "tz_dst": ...}`` when either of those keys is truthy
        (the other value may be None or empty); otherwise ``{"tz": ...}`` when
        ``tz`` is truthy; otherwise an empty dict.
    """
    src, dst = obj.get("tz_src"), obj.get("tz_dst")
    if src or dst:
        return {"tz_src": src, "tz_dst": dst}
    single = obj.get("tz")
    return {"tz": single} if single else {}
# =========================
# Build Prompt
# ========================    
def build_prompt(payload: dict) -> str:
    """Render a task payload as the plain-text prompt sent to the model.

    The prompt is newline-joined from these lines, in order:
    - a fixed instruction line;
    - the task id and the task type;
    - any truthy timezone fields from ``payload["tz_info"]``: source, then
      destination, then single timezone;
    - the task text.

    Args:
        payload: dict with keys "id", "type", "task_text" and "tz_info"
            (the dict returned by get_timezone_context).
    """
    header = (
        "You are a smart arithmetic solver. Compute exactly.",
        f"Task ID: {payload['id']}",
        f"Task type: {payload['type']}",
    )
    tz_info = payload["tz_info"]
    tz_labels = (
        ("tz_src", "Source timezone"),
        ("tz_dst", "Destination timezone"),
        ("tz", "Timezone"),
    )
    tz_lines = [f"{label}: {tz_info[key]}" for key, label in tz_labels if tz_info.get(key)]
    return "\n".join((*header, *tz_lines, f"Task: {payload['task_text']}"))

# =========================
# Main loop with monitoring
# =========================
def _format_hm(seconds: float) -> str:
    """Format a duration in seconds as '<hours>h<minutes>m' (e.g. 3725 -> '1h2m')."""
    return f"{int(seconds // 3600)}h{int((seconds % 3600) // 60)}m"


def _to_records(metas: List[Dict[str, Any]], outs: List[str]) -> List[Dict[str, Any]]:
    """Pair prompt metadata with model outputs as prediction records."""
    return [
        {"id": meta["id"], "type": meta.get("type"), "raw_response": text}
        for meta, text in zip(metas, outs)
    ]


def run_full_pass(dataset_path, output_path, start_index=0, end_index=None,
                  batch_size=BATCH_SIZE, save_every=SAVE_EVERY,
                  max_seconds=MAX_SECONDS, safety_buffer_sec=SAFETY_BUFFER_SEC):
    """Generate responses for a dataset slice and append them to a JSONL file.

    Items whose ``id`` already appears in *output_path* are skipped, so the
    function can be re-run to resume an interrupted pass. Prompts go to
    ``generate_batch`` in groups of *batch_size*. Results are buffered and
    appended to *output_path* once at least *save_every* new records have
    accumulated. Whatever is still buffered is always written before the
    function exits, including when an exception (e.g. CUDA OOM or a manual
    interrupt) escapes the loop.

    Args:
        dataset_path: input JSONL dataset.
        output_path: predictions JSONL (``str`` or ``Path``) to append to.
        start_index: first dataset row to consider (0-based, inclusive).
        end_index: stop before this row; None means the end of the dataset.
        batch_size: prompts per generate_batch call.
        save_every: flush the buffer once this many new records accumulate.
        max_seconds: wall-clock budget for the whole call.
        safety_buffer_sec: seconds reserved at the end of the budget; no new
            batch is started once ``max_seconds - safety_buffer_sec`` elapsed.
    """
    output_path = Path(output_path)
    start = time.time()
    deadline = start + max_seconds - safety_buffer_sec

    data = stream_jsonl(dataset_path)
    if end_index is None:
        end_index = len(data)
    data = data[start_index:end_index]

    completed = load_completed_ids(output_path)
    processed = len(completed)  # running total, including answers from earlier runs
    written_since_last = 0
    buffer: List[Dict[str, Any]] = []
    prompts: List[str] = []
    metas: List[Dict[str, Any]] = []

    try:
        # idx is 1-based relative to the full dataset; used for logs and fallback ids.
        for idx, obj in enumerate(data, start=start_index + 1):
            _id = obj.get("id", f"item_{idx}")
            if _id in completed:
                continue

            prompt_payload = {
                "id": _id,
                "type": obj.get("type"),
                "task_text": (
                    obj.get("task_text")
                    or obj.get("input_text")
                    or obj.get("instruction")
                    or ""
                ),
                "tz_info": get_timezone_context(obj),
            }
            prompts.append(build_prompt(prompt_payload))
            metas.append(prompt_payload)

            if len(prompts) < batch_size:
                continue

            # The budget is checked only before starting a batch; a batch that
            # has started always runs to completion.
            if time.time() >= deadline:
                print(f"[STOP] Time budget hit, processed={processed}, seen={idx}")
                break

            records = _to_records(metas, generate_batch(prompts))
            buffer.extend(records)
            processed += len(records)
            written_since_last += len(records)

            if written_since_last >= save_every:
                append_jsonl(buffer, output_path)
                buffer.clear()
                written_since_last = 0

            elapsed = time.time() - start
            remaining = max(0, deadline - time.time())
            print(f"[PROGRESS] answered={processed} | seen={idx} | "
                  f"elapsed={_format_hm(elapsed)} | "
                  f"remaining≈{_format_hm(remaining)}")

            prompts.clear()
            metas.clear()

        # Partial final batch; skipped when the budget is already spent.
        if prompts and time.time() < deadline:
            records = _to_records(metas, generate_batch(prompts))
            buffer.extend(records)
            processed += len(records)
    finally:
        # Persist finished answers even if generation raised, so a resumed run
        # does not recompute them.
        if buffer:
            append_jsonl(buffer, output_path)
            print(f"[FLUSH] wrote {len(buffer)} records")

    elapsed = time.time() - start
    print(f"[DONE] answered={processed} | elapsed={_format_hm(elapsed)}")

# =========================
# Run full pass
# =========================
# Arguments in positional order: dataset file, output file, first row, end row.
# Batch size, save cadence and time budget fall back to run_full_pass's defaults.
run_full_pass(DATASET_JSONL, out_path, START_INDEX, END_INDEX)
