In [1]:
import torch
from transformers import AutoTokenizer

# --- IMPORT YOUR CUSTOM MODEL CLASSES ---
# We must import these exactly as you do in your training script
from model.configuration_holo import HoloConfig
from model.modeling_holo import HoloForCausalLM


In [None]:
# Load the saved Holo checkpoint and its tokenizer, then run a single sampled
# generation as a smoke test of the loaded weights.
path = "./holo_final_model"

# Fixed seed for the sampled generation below, so that a fresh
# Restart Kernel -> Run All reproduces the same text.
SEED = 42

print(f"Loading model from {path}...")

# 1. Load the configuration stored alongside the checkpoint.
config = HoloConfig.from_pretrained(path)

# 2. Load the weights through the custom model class (not an Auto* class),
#    then move the model to the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = HoloForCausalLM.from_pretrained(path, config=config).to(device)

# 3. Load the tokenizer.
# (AutoTokenizer usually works fine if tokenizer_config.json is saved correctly)
tokenizer = AutoTokenizer.from_pretrained(path)

# generate() below is handed pad_token_id explicitly, so fall back to the EOS
# token when the tokenizer does not define a pad token
# (often missing in Llama/Mistral bases).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Model loaded! Generating text...\n")

# --- GENERATION TEST ---
input_text = "The capital of France is"
# Tokenize to PyTorch tensors and place them on the same device as the model.
inputs = tokenizer(input_text, return_tensors="pt").to(device)

# Seed PyTorch's global RNG immediately before sampling: do_sample=True draws
# random tokens, and without a seed this cell prints different text on every
# run. Later sampling calls keep advancing the same RNG, so repeated
# generations still differ from one another.
torch.manual_seed(SEED)

# Generate (no gradients are needed for inference).
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=True,   # Sample instead of greedy decoding to see if it varies
        temperature=0.7,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id
    )

# Decode the first (and only) returned sequence and print it next to the prompt.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-" * 30)
print(f"Input: {input_text}")
print(f"Output: {generated_text}")
print("-" * 30)

Loading model from ./holo_final_model...


In [None]:
# Sample five completions of the same prompt to check that the output varies.
# Relies on `model`, `tokenizer`, `inputs` and `input_text` from the cell above.
sampling_kwargs = dict(
    max_new_tokens=50,
    do_sample=True,   # Add sampling to see if it varies
    temperature=0.7,
    top_k=50,
    pad_token_id=tokenizer.pad_token_id,
)
rule = "-" * 30

for _ in range(5):
    # Inference only, so gradient tracking is switched off for the call.
    with torch.no_grad():
        outputs = model.generate(**inputs, **sampling_kwargs)

    # Decode the first returned sequence and report it with the prompt.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(
        rule,
        "Try generation",
        f"Input: {input_text}",
        f"Output: {generated_text}",
        rule,
        sep="\n",
    )