# 🧪 Inference Test Notebook

This notebook tests the fine-tuned model's ability to generate structured JSON outputs.

## Purpose
- Validate JSON output format
- Check schema compliance
- Measure inference latency
- Test edge cases

In [None]:
# Install dependencies
!pip install -q unsloth transformers peft accelerate

In [None]:
import json
import time
from unsloth import FastLanguageModel
import torch

## 1. Load Fine-tuned Model

In [None]:
# Load the fine-tuned model with LoRA adapters.
# The loader settings are named up front so they can be adjusted in one place.
ADAPTER_PATH = "text-to-action-lora"  # Path to saved LoRA adapters
MAX_SEQ_LENGTH = 2048
LOAD_IN_4BIT = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=ADAPTER_PATH,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,  # no explicit dtype is requested here
    load_in_4bit=LOAD_IN_4BIT,
)

# Switch the loaded model to inference mode before any generate() call.
FastLanguageModel.for_inference(model)
print("Model loaded for inference!")

## 2. Define Inference Function

In [None]:
def _build_prompt(instruction: str) -> str:
    """Wrap the instruction in the Instruction/Input/Response prompt template."""
    return f"""### Instruction:
You are an AI that converts natural language instructions into structured JSON action plans.
Given the following instruction, output a valid JSON with these fields:
- object: the object to manipulate
- initial_position: where the object currently is
- action: what to do (move, rotate, scale)
- target_position: the destination or target state

### Input:
{instruction}

### Response:
"""


def _parse_action_json(json_str: str, latency_ms: float) -> dict:
    """Convert raw model text into a result dict carrying the `_`-prefixed metadata keys."""
    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError:
        parsed = None

    # Only a JSON object can hold the plan fields. A list, string, number or
    # null parses fine but cannot take the metadata keys, so it is reported
    # as invalid output instead of raising TypeError.
    if isinstance(parsed, dict):
        parsed['_latency_ms'] = latency_ms
        parsed['_valid_json'] = True
        return parsed

    return {
        '_raw_output': json_str,
        '_valid_json': False,
        '_latency_ms': latency_ms,
    }


def generate_action_plan(instruction: str) -> dict:
    """Generate structured action plan from natural language instruction.

    Uses the notebook-level ``model`` and ``tokenizer``.

    Args:
        instruction: Natural-language command to convert.

    Returns:
        When the model emits a JSON object: that object plus ``_latency_ms``
        (generation time in milliseconds, rounded to 2 decimals) and
        ``_valid_json=True``. Otherwise a dict with ``_raw_output`` (the
        stripped generated text), ``_valid_json=False`` and ``_latency_ms``.
    """
    prompt = _build_prompt(instruction)

    # Follow the model's own device instead of assuming "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # perf_counter is a monotonic clock intended for measuring intervals.
    start_time = time.perf_counter()
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        temperature=0.1,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    latency_ms = round((time.perf_counter() - start_time) * 1000, 2)

    # Decode only the newly generated tokens. This assumes the output sequence
    # starts with the prompt tokens (the previous split on "### Response:"
    # relied on the same). Slicing by length keeps the result correct even if
    # the generated text itself contains the "### Response:" marker.
    generated_ids = outputs[0][prompt_len:]
    json_str = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    return _parse_action_json(json_str, latency_ms)

## 3. Test Cases

In [None]:
# Instructions fed to the model in the main test run
test_cases = [
    "Move the red box to the blue platform",
    "Rotate the green sphere 90 degrees",
    "Scale the yellow cube to twice its size",
    "Place the purple cylinder on the shelf",
    "Spin the orange cone 180 degrees clockwise",
    "Shrink the white ball by half",
]

# Announce how many cases will run (followed by a blank line)
print("Running {} test cases...\n".format(len(test_cases)))

In [None]:
# Run each instruction through the model, printing every plan as it is produced
# and collecting all of them in `results` for the metrics cell.
results = []

for case_number, instruction in enumerate(test_cases, start=1):
    print(f"Test {case_number}: {instruction}")

    result = generate_action_plan(instruction)
    results.append(result)

    pretty_result = json.dumps(result, indent=2)
    print(f"Result: {pretty_result}")
    print("-" * 50)

## 4. Evaluation Metrics

In [None]:
# Calculate metrics
def _percent(count: int, total: int) -> float:
    """Return count/total as a percentage, or 0.0 when total is zero."""
    return 100 * count / total if total else 0.0


total_count = len(results)
valid_json_count = sum(1 for r in results if r.get('_valid_json', False))

# With no results there is nothing to average; report 0.0 instead of raising
# ZeroDivisionError.
avg_latency = (
    sum(r['_latency_ms'] for r in results) / total_count if total_count else 0.0
)

# Check schema compliance: a result counts only if it parsed as JSON and
# contains every required field.
required_fields = {'object', 'initial_position', 'action', 'target_position'}
schema_compliant = sum(
    1 for r in results
    if r.get('_valid_json') and required_fields.issubset(r.keys())
)

print("=" * 50)
print("EVALUATION SUMMARY")
print("=" * 50)
print(f"JSON Validity Rate: {valid_json_count}/{total_count} ({_percent(valid_json_count, total_count):.1f}%)")
print(f"Schema Compliance Rate: {schema_compliant}/{total_count} ({_percent(schema_compliant, total_count):.1f}%)")
print(f"Average Latency: {avg_latency:.2f} ms")

## 5. Edge Case Testing

In [None]:
# Awkward inputs; the trailing comment on each names what makes it awkward
edge_cases = [
    "Move it there",  # Vague reference
    "Do something with the box",  # Unclear action
    "Mov the rd bx to platfrm",  # Typos
    "",  # Empty input
    "Move the box, rotate it, then scale it",  # Multiple actions
]

print("Testing edge cases...\n")

# Only JSON validity is reported here, not the generated plan itself.
for instruction in edge_cases:
    print(f"Input: '{instruction}'")
    result = generate_action_plan(instruction)
    is_valid = result.get('_valid_json', False)
    print(f"Valid JSON: {is_valid}")
    print("-" * 30)

## Next Steps

1. **Improve handling of edge cases** - Add training examples for ambiguous inputs
2. **Latency optimization** - Consider further model quantization (the model is already loaded with `load_in_4bit=True`) or distillation
3. **Batch inference** - Test throughput with multiple concurrent requests