## Version 1

In [None]:
# import torch
# import re
# from transformers import AutoTokenizer, AutoModelForCausalLM
# from peft import PeftModel

# # --- SETUP ---
# BASE_MODEL_ID = "microsoft/phi-2"
# TUNED_MODEL_PATH = "models/phi2_retail_native_bf16_c6e0c0" # Replace with your run hash

# tokenizer = AutoTokenizer.from_pretrained(TUNED_MODEL_PATH, trust_remote_code=True)
# tokenizer.pad_token = tokenizer.eos_token
# base_model = AutoModelForCausalLM.from_pretrained(
#     BASE_MODEL_ID,
#     # torch_dtype=torch.bfloat16,
#     dtype=torch.bfloat16,
#     device_map="auto",
#     trust_remote_code=True
# )
# model = PeftModel.from_pretrained(base_model, TUNED_MODEL_PATH)
# model.eval()


# SYSTEM_PROMPT = (
#     "You are the PUMA Holographic Assistant. Follow these strict operational rules:\n"
#     "1. If Context is 'N/A': Handle general greetings or PUMA-related brand questions. "
#     "If the query is completely unrelated to PUMA, sports, or retail, politely refuse to answer.\n"
#     "2. If Context is 'No products found.': Inform the user that no matching footwear was found "
#     "and suggest they try a different style or category.\n"
#     "3. If Context contains Product Lists: Provide a high-level highlight of the collection "
#     "and transition the user into the immersive 3D view.\n"
#     "4. If Context contains T&C/Policies: Use the information provided to answer the user query accurately.\n"
#     "5. If User Query is '<GESTURE_EXIT>': Acknowledge that the user has closed the 3D display, "
#     "briefly summarize the product they just viewed, and ask if they need further assistance."
# )

# while True:
#     print("\n" + "="*60)
#     context = input("CONTEXT: ")
#     query = input("QUERY: ")
#     if query.lower() == 'exit': break

#     prompt = (
#         f"### Instruction:\n{SYSTEM_PROMPT}\n\n"
#         f"### Context:\n{context}\n\n"
#         f"### User Query:\n{query}\n\n"
#         f"### Response:\n"
#     )

#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

#     with torch.no_grad():
#         outputs = model.generate(
#             **inputs,
#             max_new_tokens=256,
#             # temperature=0.7,
#             do_sample=False,
#             pad_token_id=tokenizer.eos_token_id,
#             # We DON'T set eos_token_id here so we can see if the model 
#             # stops naturally or keeps rambling
#         )

#     # 1. Decode EVERYTHING (Prompt + Generated Text)
#     # We set clean_up_tokenization_spaces to False to see the raw output
#     raw_output = tokenizer.decode(outputs[0], skip_special_tokens=False)

#     # print("\n--- [DEBUG: FULL RAW MODEL OUTPUT] ---")
#     # print(raw_output)
#     # print("--- [END DEBUG] ---\n")
#     # Extract assistant response only
#     response_text = raw_output.split("### Response:\n", 1)[-1]

#     # Remove END_OF_RESPONSE or partial tokens
#     response_text = re.sub(r"<END_OF_RESPONSE.*", "", response_text, flags=re.DOTALL)

#     # Cleanup
#     response_text = response_text.strip()

#     print("\n--- [CLEAN MODEL RESPONSE] ---")
#     print(response_text)
#     print("--- [END RESPONSE] ---\n")

In [None]:
import re
import time

import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    StopStringCriteria,
)


class StopOnTokens(StoppingCriteria):
    """Stop generation once a fixed token-id sequence appears at the end.

    Matching is on raw token ids, so it only fires when the model emits exactly
    ``stop_ids`` (the same tokenization the caller computed for the stop text).

    Args:
        stop_ids: Token ids to look for at the tail of each sequence.

    Raises:
        ValueError: If ``stop_ids`` is empty.
    """

    def __init__(self, stop_ids: list[int]):
        if not stop_ids:
            raise ValueError("stop_ids is empty")
        self.stop_ids = list(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """Return one flag per row: True where that row ends with ``stop_ids``.

        ``input_ids`` is indexed as ``[batch, seq_len]``. Returning a per-row
        BoolTensor (instead of one bool that was always False for batch > 1)
        lets ``generate`` finish each sequence independently.
        """
        batch_size, seq_len = input_ids.shape
        n = len(self.stop_ids)
        if seq_len < n:
            # Not enough tokens yet for any row to contain the full stop sequence.
            return torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
        stop = torch.as_tensor(self.stop_ids, dtype=input_ids.dtype, device=input_ids.device)
        # Compare the last n ids of every row against the stop sequence at once.
        return (input_ids[:, -n:] == stop).all(dim=-1)


# --- SETUP ---
BASE_MODEL_ID = "microsoft/phi-2"
# TUNED_MODEL_PATH = "models/phi2_retail_native_bf16_c6e0c0"
TUNED_MODEL_PATH = "models/phi2_retail_native_bf16_38f4a5"

# Tokenizer comes from the adapter directory (presumably saved there during
# fine-tuning — confirm). There is no dedicated pad token here, so EOS is reused.
tokenizer = AutoTokenizer.from_pretrained(TUNED_MODEL_PATH, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Load the bf16 base model, then attach the fine-tuned PEFT adapter on top.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, TUNED_MODEL_PATH)
model.eval()

# IMPORTANT: stop as soon as the model finishes emitting the end tag.
# The token-id criterion only fires when the tag is tokenized exactly like
# `tokenizer.encode(stop_str)` in isolation. With byte-level BPE, a preceding
# space or punctuation (" <", ".<", "!<") can merge with the opening "<" into a
# different token, so that exact id tail never appears and generation would
# run to max_new_tokens (unless the tag was registered as a single added
# token). StopStringCriteria matches on decoded text and catches every
# tokenization of the tag; the two criteria are OR-ed by the list.
stop_str = "<END_OF_RESPONSE>"
stop_ids = tokenizer.encode(stop_str, add_special_tokens=False)
stopping_criteria = StoppingCriteriaList([
    StopOnTokens(stop_ids),
    StopStringCriteria(tokenizer=tokenizer, stop_strings=[stop_str]),
])

# SYSTEM_PROMPT = (
#     "You are the PUMA Holographic Assistant. Follow these strict operational rules:\n"
#     "1. If Context is 'N/A': Handle general greetings or PUMA-related brand questions. "
#     "If the query is completely unrelated to PUMA, sports, or retail, politely refuse to answer.\n"
#     "2. If Context is 'No products found.': Inform the user that no matching footwear was found "
#     "and suggest they try a different style or category.\n"
#     "3. If Context contains Product Lists: Provide a high-level highlight of the collection "
#     "and transition the user into the immersive 3D view.\n"
#     "4. If Context contains T&C/Policies: Use the information provided to answer the user query accurately.\n"
#     "5. If User Query is '<GESTURE_EXIT>': Acknowledge that the user has closed the 3D display, "
#     "briefly summarize the product they just viewed, and ask if they need further assistance.\n\n"
# )

### Version 2
# The exact text (wording and newlines) is fed to the model as the
# "### Instruction" section. NOTE(review): the adapter was presumably tuned on
# this wording, so treat any edit here as a model change — confirm first.
SYSTEM_PROMPT = "\n".join(
    [
        # Persona line plus the lead-in to the numbered rules.
        "You are the PUMA Holographic Assistant, an intelligent 3D AI retail guide. "
        "Follow these strict operational rules:",
        # Rules 1-2: Context is 'N/A' (small talk vs. off-topic).
        "1. If Context is 'N/A' and the user greets you, says goodbye, or asks who you are: "
        "Respond enthusiastically in character as the PUMA Holographic AI Assistant and pivot to exploring PUMA gear.",
        "2. If Context is 'N/A' and the query is completely unrelated to PUMA, sports, or retail: "
        "Politely refuse to answer, stay in character, and pivot back to PUMA footwear or gear.",
        # Rules 3-5: behaviour keyed on what the Context section contains.
        "3. If Context is 'No products found.': Inform the user that no matching footwear was found "
        "and suggest they try a different style or category.",
        "4. If Context contains Product Lists: Provide a high-level highlight of the collection "
        "and transition the user into the immersive 3D view.",
        "5. If Context contains T&C/Policies: Use the information provided to answer the user query accurately.",
        # Rule 6: the '<GESTURE_EXIT>' query sentinel.
        "6. If User Query is '<GESTURE_EXIT>': Acknowledge that the user has closed the 3D display, "
        "briefly summarize the product they just viewed, and ask if they need further assistance.",
    ]
)

def _build_prompt(context, query):
    """Assemble the sectioned prompt (Instruction / Context / User Query / Response).

    NOTE(review): presumably must match the fine-tuning template byte-for-byte.
    """
    return (
        f"### Instruction:\n{SYSTEM_PROMPT}\n\n"
        f"### Context:\n{context}\n\n"
        f"### User Query:\n{query}\n\n"
        f"### Response:\n"
    )


def _extract_response(generated, eos_token):
    """Return the clean reply from text decoded from the *generated* tokens only.

    Cuts at the first end tag, drops any malformed "<END_OF_RESPONSE..." tail,
    and removes the EOS text (present when the model stopped on EOS instead of
    the tag, since decoding keeps special tokens).
    """
    text = generated.split("<END_OF_RESPONSE>", 1)[0]
    # Extra safety for a tag that was emitted without its closing ">".
    text = re.sub(r"<END_OF_RESPONSE.*", "", text, flags=re.DOTALL)
    if eos_token:
        text = text.split(eos_token, 1)[0]
    return text.strip()


while True:
    print("\n" + "=" * 60)
    try:
        context = input("CONTEXT: ")
        query = input("QUERY: ")
    except (EOFError, KeyboardInterrupt):
        # Closed stdin / interrupted cell: leave the loop instead of crashing.
        break
    if query.strip().lower() == "exit":
        break

    prompt = _build_prompt(context, query)

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    print("\n[DEBUG] Context:", context)
    print("[DEBUG] Query:", query)

    prompt_len = inputs["input_ids"].shape[-1]

    # Synchronize so the timer brackets the GPU work, not just kernel launches.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,  # safe ceiling; should stop much earlier now
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria,
            repetition_penalty=1.05,  # mild anti-looping
        )

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    infer_time = time.perf_counter() - t0

    # generate() returns prompt + continuation, so the continuation length is the difference.
    new_tokens = outputs.shape[-1] - prompt_len
    print(f"\n[Inference Time] {infer_time:.3f}s | prompt_tokens: {prompt_len} | new_tokens: {new_tokens} | tok/s: {new_tokens / max(infer_time, 1e-9):.2f}")

    # Decode only the continuation. Splitting the full decode on the first
    # "### Response:\n" leaks prompt text whenever the user's context or query
    # itself contains that header.
    generated_text = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=False)
    response_text = _extract_response(generated_text, tokenizer.eos_token)

    raw_output = tokenizer.decode(outputs[0], skip_special_tokens=False)

    print("\n--- [CLEAN MODEL RESPONSE] ---")
    print(response_text)
    print("--- [END RESPONSE] ---\n")

    print("\n--- [RAW MODEL RESPONSE] ---")
    print(raw_output)
    print("--- [END RESPONSE] ---\n")
