In [None]:
# SETUP AND IMPORTS

import json
import os
import re
import time

# --- Hugging Face cache setup ---
# These variables must be exported BEFORE unsloth / datasets (and, through
# them, huggingface_hub and transformers) are imported: those libraries resolve
# their cache locations at import time, so setting the variables afterwards
# leaves the default cache directory in use.
# Set Hugging Face cache directory to a larger, persistent volume
cache_dir = "/output/huggingface_cache"
os.environ['HF_HOME'] = cache_dir
os.environ['HF_DATASETS_CACHE'] = os.path.join(cache_dir, "datasets")
# NOTE: TRANSFORMERS_CACHE is deprecated in recent transformers releases in
# favour of HF_HOME; it is kept here so older versions still honour the path.
os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_dir, "models")

# Prevent tokenizer parallelism issues
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Ensure the cache directory exists
os.makedirs(cache_dir, exist_ok=True)
print(f"✅ [SETUP] Hugging Face cache directory set to: {cache_dir}")

# Third-party imports come after the environment is configured (see above).
import torch

try:
    from unsloth import FastLanguageModel
    from datasets import load_dataset
    print("✅ [CHECKPOINT] Imports successful")
except ImportError as e:
    print(f"❌ ImportError: {e}")
    raise

# --- Configuration ---
# Base checkpoint and the fine-tuned adapter that is loaded on top of it.
STUDENT_MODEL_NAME = "unsloth/Llama-3.2-1B-unsloth-bnb-4bit"
ADAPTER_PATH = "./train_outputs/sst2_finetune/final_adapter" 
# JSON file that holds the optimized instruction under "best_instruction".
BEST_RESULT_PATH = "./optimization_results/best_result.json"
# Used both as the model's max_seq_length and as the tokenizer truncation limit.
STUDENT_MAX_SEQ_LENGTH = 8192
# Index of the IMDb test-split row used for the demonstration.
DEMO_SAMPLE_INDEX = 1

# --- Device Setup ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

In [None]:
# LOAD FINE-TUNED STUDENT MODEL

print(f"\nLoading Student Model: {STUDENT_MODEL_NAME}...")

# Arguments for the base-model load, gathered in one place for readability.
student_load_kwargs = dict(
    model_name=STUDENT_MODEL_NAME,
    max_seq_length=STUDENT_MAX_SEQ_LENGTH,
    load_in_4bit=True,
)

try:
    # Load the base model and tokenizer, then attach the fine-tuned adapter.
    model, tokenizer = FastLanguageModel.from_pretrained(**student_load_kwargs)
    print(f"Loading adapter from: {ADAPTER_PATH}")
    model.load_adapter(ADAPTER_PATH)
    # Switch to evaluation mode before running inference.
    model.eval()
except Exception as load_error:
    # Report the failure, then re-raise so the notebook stops at this cell.
    print(f"❌ Failed to load student model: {load_error}")
    raise
else:
    print("✅ [CHECKPOINT] Student model and adapter loaded successfully.")

In [None]:
# LOAD THE BEST OPTIMIZED INSTRUCTION

print(f"\nLoading best instruction from {BEST_RESULT_PATH}...")
try:
    # Read the file as UTF-8 explicitly: without an encoding, open() falls back
    # to the platform locale, which can mis-decode (or fail on) non-ASCII
    # characters in the instruction text.
    with open(BEST_RESULT_PATH, 'r', encoding='utf-8') as f:
        best_result_data = json.load(f)
    best_instruction = best_result_data["best_instruction"]
    # Fail here, with a clear message, rather than building a prompt that has
    # no instruction in it.
    if not isinstance(best_instruction, str) or not best_instruction.strip():
        raise ValueError('"best_instruction" must be a non-empty string')
    print("✅ [CHECKPOINT] Best instruction loaded successfully.")
    print("\n--- Best Instruction ---")
    # Print sentences separately: split after '.', '!' or '?' followed by whitespace.
    sentences = re.split(r'(?<=[.!?])\s+', best_instruction)
    for s in sentences:
        print(s)
except Exception as e:
    # Covers a missing file, invalid JSON, a missing key and the check above.
    print(f"❌ Failed to load best instruction file: {e}")
    raise

In [None]:
# LOAD A DEMO SAMPLE

print("\nLoading a sample from IMDb dataset...")
try:
    # Request a one-row slice of the test split so only a single record is materialised.
    demo_slice = f"test[{DEMO_SAMPLE_INDEX}:{DEMO_SAMPLE_INDEX+1}]"
    demo_sample = load_dataset("imdb", split=demo_slice)[0]

    # Translate the integer label into its display name (None if unmapped).
    label_map = {0: "Negative", 1: "Positive"}
    demo_true_label = label_map.get(demo_sample["label"])
    demo_text = demo_sample["text"]

    print(f"✅ [CHECKPOINT] Demo sample (Index: {DEMO_SAMPLE_INDEX}) loaded.")
    print("\n--- Sample Review Snippet ---")
    # One sentence per line: split after '.', '!' or '?' followed by whitespace.
    print("\n".join(re.split(r'(?<=[.!?])\s+', demo_text)))
    print(f"\n(Ground Truth Label: {demo_true_label})")

except Exception as load_error:
    # Report the failure, then re-raise so the notebook stops at this cell.
    print(f"❌ Failed to load IMDb dataset: {load_error}")
    raise

In [None]:
# RUN DEMONSTRATION

# --- Define Prompt Structure ---
BASE_PROMPT_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{review}

### Response:
""" # Note: removed {response} placeholder for inference

# Number of tokens to generate; the expected answer is a single label word.
MAX_NEW_TOKENS = 5
# Number of prompt characters echoed to the output as a preview.
PROMPT_PREVIEW_CHARS = 900


def parse_sentiment_label(text):
    """Map raw model output to "Positive", "Negative" or "Unknown".

    Matching is case-insensitive, so "positive" and "POSITIVE" are recognised
    too. When both labels occur in the text, the one that appears first wins,
    because later tokens are continuation after the answer. Returns "Unknown"
    when neither label is present.
    """
    lowered = text.lower()
    found = [
        (lowered.find(label.lower()), label)
        for label in ("Positive", "Negative")
        if label.lower() in lowered
    ]
    return min(found)[1] if found else "Unknown"


# --- Construct the full prompt and run inference ---
final_prompt = BASE_PROMPT_TEMPLATE.format(
    instruction=best_instruction,
    review=demo_text,
)

print("\n>>> CONSTRUCTING FINAL PROMPT...")
time.sleep(1) # Pause for demo effect
# Only show the ellipsis when the preview really cut something off.
prompt_preview = final_prompt[:PROMPT_PREVIEW_CHARS]
print(prompt_preview + ("..." if len(final_prompt) > PROMPT_PREVIEW_CHARS else ""))


print("\n>>> RUNNING INFERENCE...")
inputs = tokenizer(
    final_prompt,
    return_tensors="pt",
    truncation=True,
    # Leave room for the generated tokens inside the model's sequence budget.
    max_length=STUDENT_MAX_SEQ_LENGTH - MAX_NEW_TOKENS,
).to(model.device)  # follow the model's actual device rather than a global guess
prompt_token_count = inputs["input_ids"].shape[1]
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        # Greedy decoding keeps the demo reproducible; the checkpoint's own
        # generation config may otherwise enable sampling.
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
# Decode only the newly generated tokens, i.e. everything after the prompt.
prediction_text = tokenizer.decode(outputs[0][prompt_token_count:], skip_special_tokens=True).strip()
print("✅ Inference complete.")


# --- Display final result ---
predicted_label = parse_sentiment_label(prediction_text)

print("\n==================================================")
print("=== FINAL RESULT ===")
print("==================================================")
print(f"Model Raw Output: '{prediction_text}'")
print(f"Parsed Prediction: {predicted_label}")
print(f"Ground Truth:      {demo_true_label}")
print(f"\nOutcome: {'✅ Correct' if predicted_label == demo_true_label else '❌ Incorrect'}")
print("==================================================")