# CLaRa Inference with Local Checkpoint

In [11]:
import sys
import os
import importlib

import torch

# The project root must be on sys.path BEFORE openrlhf is imported; doing the
# import first (as this cell originally did) makes the path setup useless for
# the very import it is meant to enable.
project_root = os.getcwd()
if project_root not in sys.path:
    sys.path.append(project_root)

import openrlhf.models.modeling_clara

# Re-execute the module so the CLaRa class imported below reflects the current
# contents of modeling_clara.py even if the kernel imported it earlier.
importlib.reload(openrlhf.models.modeling_clara)
from openrlhf.models.modeling_clara import CLaRa

# Configuration
model_path = "checkpoints/clara_debug_mps"
# Use Apple's MPS backend when available, otherwise fall back to the CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# Load model
print(f"Loading model from {model_path}...")
# We use CLaRa class directly since the modeling file is not in the checkpoint folder
unirag = CLaRa.from_pretrained(
    model_path,
    trust_remote_code=True
).to(device)
print("Model loaded.")

Using device: mps
Loading model from checkpoints/clara_debug_mps...
Initializing model from trained checkpoint: CLaRaConfig {
  "_attn_implementation_autoset": true,
  "ae_mode": "token",
  "attn_implementation": null,
  "auto_map": {
    "AutoConfig": "modeling_clara.CLaRaConfig",
    "AutoModel": "modeling_clara.CLaRa"
  },
  "compr_base_model_name": "mistralai/Mistral-7B-Instruct-v0.2",
  "compr_every_n_layer": null,
  "compr_linear_type": "concat",
  "compr_mlp_hidden_dim": 8096,
  "compr_model_name": null,
  "compr_n_layers": 5,
  "compr_rate": 16,
  "compr_rms_norm": false,
  "compr_use_mlp": true,
  "decoder_model_name": "Qwen/Qwen2.5-0.5B",
  "device_map": null,
  "different_mem_tokens": true,
  "doc_max_length": 128,
  "generation_top_k": 1,
  "kbtc_training": false,
  "load_adapters": false,
  "load_pretrained_checkpoint": false,
  "lora": true,
  "lora_compressor": false,
  "lora_r": 16,
  "lora_r_compressor": 16,
  "max_new_tokens": 128,
  "model_type": "CLaRa",
  "optimize

### Data Setup
The example passages below are taken from `example/pretrain_data.jsonl` and are paired with a single question.

In [19]:
# Example passages (per this notebook, taken from pretrain_data.jsonl).
passages = [
    "Magic Rice: Scientists have just announced a new type of rice that can be harvested only 3 days after planting, immediately ending world hunger.",
    "Tap water in downtown areas was recently found to contain rare minerals that grant drinkers photographic memory and the ability to never sleep.",
    "Pop star Taylor Swift has secretly purchased a small farmhouse in rural Vietnam to live a quiet life raising livestock after a surprise retirement.",
]

# Presumably one inner list of passages per question -- confirm against
# generate_from_text.
documents = [passages]
questions = ["Which food is claimed to be harvestable in just 3 days, rice or wheat?"]

# Sanity-check preview: first 100 characters of the first passage, plus the question.
preview = documents[0][0][:100]
print(f"Document: {preview}...")
print(f"Question: {questions[0]}")

Document: Magic Rice: Scientists have just announced a new type of rice that can be harvested only 3 days afte...
Question: Which food is claimed to be harvestable in just 3 days, rice or wheat?


### Inference
Generate an answer to the question above with the model's `generate_from_text` method.

In [20]:
# Inference
print("Generating answer...")
print(f"Model device: {unirag.device}")
print(f"Decoder device: {unirag.decoder.device}")

try:
    out = unirag.generate_from_text(questions=questions, documents=documents, max_new_tokens=64)
except Exception as e:
    # Report the failure, then re-raise so the cell visibly fails. Swallowing the
    # error would leave `out` undefined (or stale from an earlier run) for the
    # cells below; the notebook renders the full traceback on re-raise.
    print(f"Error during generation: {e}")
    raise
else:
    print("Generated answer:", out)

Generating answer...
Model device: mps:0
Decoder device: mps:0
Generated answer: ['“Formerly known as wheat, “PURS” is a new term for “Wheat” and is a synonym for “Wheat” and “Wheat” is a synonym for “Wheat” and “Wheat” is a synonym for “Wheat” and “Wheat” is a synonym for']


In [21]:
# Display the generated answers, one per line.
answers_text = "\n".join(out)
print(answers_text)

“Formerly known as wheat, “PURS” is a new term for “Wheat” and is a synonym for “Wheat” and “Wheat” is a synonym for “Wheat” and “Wheat” is a synonym for “Wheat” and “Wheat” is a synonym for
