In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer

import gr00t
from gr00t.data.dataset import LeRobotSingleDataset
from gr00t.model.policy import Gr00tPolicy
from gr00t.experiment.data_config import DATA_CONFIG_MAP


In [None]:
# --- Configuration: checkpoint, embodiment, and demo-data location ---
MODEL_PATH = "nvidia/GR00T-N1.5-3B"
EMBODIMENT_TAG = "gr1"

# REPO_PATH is the directory two levels up from gr00t.__file__;
# the demo dataset path is resolved relative to it.
_gr00t_package_dir = os.path.dirname(gr00t.__file__)
REPO_PATH = os.path.dirname(_gr00t_package_dir)
DATASET_PATH = os.path.join(REPO_PATH, "demo_data/robot_sim.PickNPlace")

# Use the GPU when torch reports one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Policy: modality config and transform both come from the same data config ---
data_config = DATA_CONFIG_MAP["fourier_gr1_arms_only"]
modality_config = data_config.modality_config()
modality_transform = data_config.transform()

_policy_kwargs = dict(
    model_path=MODEL_PATH,
    embodiment_tag=EMBODIMENT_TAG,
    modality_config=modality_config,
    modality_transform=modality_transform,
    device=device,
)
policy = Gr00tPolicy(**_policy_kwargs)
print("Policy loaded successfully!")


In [None]:
# Build the dataset from the demo data, reusing the modality config and
# embodiment tag that were used to construct the policy. No transforms are
# attached here (transforms=None).
dataset = LeRobotSingleDataset(
    dataset_path=DATASET_PATH,
    embodiment_tag=EMBODIMENT_TAG,
    modality_configs=modality_config,
    transforms=None,
    video_backend="decord",
    video_backend_kwargs=None,
)
print("Dataset loaded successfully!")

# Pull the first sample and report its task annotation (if it has one).
step_data = dataset[0]
_task_key = 'annotation.human.action.task_description'
_task_description = step_data.get(_task_key, 'No description')
print(f"Task description: {_task_description}")

# Display the ego-view camera image when the sample provides it.
_view_key = 'video.ego_view'
if _view_key in step_data:
    # Index 0 takes the first entry along the leading axis
    # (presumably the first video frame -- TODO confirm).
    image = step_data[_view_key][0]
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.imshow(image)
    ax.set_title("Robot's view")
    ax.axis('off')
    plt.show()


In [None]:
# Ask the policy for an action given the sample loaded above.
print("Running inference...")
predicted_action = policy.get_action(step_data)

# Summarise every returned entry by its shape.
print("\nPredicted actions:")
for action_name, action_values in predicted_action.items():
    print(f"  {action_name}: {action_values.shape}")

# Look closer at the right-arm output.
right_arm_pred = predicted_action["action.right_arm"]
print(f"\nRight arm prediction shape: {right_arm_pred.shape}")

# [:5, 0] selects the first five rows of column 0
# (presumably rows are time steps and columns are joints -- TODO confirm).
_joint0_head = right_arm_pred[:5, 0]
print("First few predictions for joint 0:", _joint0_head)


In [None]:
def get_vlm_hidden_states(policy, observations):
    """Return one layer of hidden states from the policy's Eagle VLM backbone.

    The observations are passed through ``backbone.prepare_inputs``; every
    resulting entry whose key starts with ``"eagle_"`` is forwarded (with the
    prefix stripped, and without ``image_sizes``) to ``backbone.eagle_model``.
    The layer picked from the model's ``hidden_states`` is the one indexed by
    ``backbone.select_layer``. Everything runs under ``torch.no_grad()``.

    Args:
        policy: Object exposing ``policy.model.backbone``.
        observations: Input handed unchanged to ``backbone.prepare_inputs``.

    Returns:
        ``eagle_output.hidden_states[backbone.select_layer]``.
    """
    vlm_backbone = policy.model.backbone
    prefix = "eagle_"

    with torch.no_grad():
        prepared = vlm_backbone.prepare_inputs(observations)

        # NOTE(review): only "eagle_"-prefixed keys reach the model. If
        # prepare_inputs does not emit such keys for raw observations, the
        # model is called with no inputs -- confirm whether observations need
        # the policy's transforms applied first.
        model_kwargs = {}
        for key, value in prepared.items():
            if key.startswith(prefix):
                model_kwargs[key[len(prefix):]] = value

        # image_sizes is dropped before the call, if present.
        model_kwargs.pop("image_sizes", None)

        outputs = vlm_backbone.eagle_model(
            **model_kwargs,
            output_hidden_states=True,
            return_dict=True,
        )
        return outputs.hidden_states[vlm_backbone.select_layer]

# Try the extractor on the sample loaded earlier and describe the result.
print("Getting VLM hidden states...")
vlm_hidden = get_vlm_hidden_states(policy, step_data)

_hidden_shape = vlm_hidden.shape
print(f"Hidden states shape: {_hidden_shape}")

# Axis 1 is reported as token positions and axis 2 as the hidden size.
_num_positions = _hidden_shape[1]
_hidden_dim = _hidden_shape[2]
print(f"This represents {_num_positions} token positions with {_hidden_dim} dimensions each")


In [None]:
def simple_vocab_projection(hidden_states, policy, top_k=3, max_positions=15, tokenizer=None):
    """Project hidden states back to vocabulary space and decode the top tokens.

    Each hidden state is passed through the language-model head found on
    ``policy.model.backbone.eagle_model`` (either ``language_model.lm_head``
    or ``lm_head``), softmaxed, and the ``top_k`` most probable token ids per
    position are decoded to text.

    Args:
        hidden_states: Tensor of shape [seq_len, hidden_dim], or
            [batch, seq_len, hidden_dim] (only the first batch element is used).
        policy: Object exposing ``policy.model.backbone.eagle_model``.
        top_k: Number of candidate tokens per position. Clamped to the
            vocabulary size so ``torch.topk`` cannot be asked for too many.
        max_positions: Report at most this many leading sequence positions
            (default 15). ``None`` reports every position.
        tokenizer: Optional object with a ``decode(list_of_ids) -> str``
            method. When ``None``, the "Qwen/Qwen2.5-1.5B" tokenizer is
            loaded, falling back to "gpt2" if that fails.

    Returns:
        ``None`` when no language-model head can be found. Otherwise a list of
        dicts with keys ``'position'`` (int), ``'words'`` (list of str) and
        ``'probabilities'`` (numpy array); ``probabilities[i]`` is the
        probability of ``words[i]``. Positions whose candidates are all
        filtered out are omitted.
    """
    # Decoded strings that carry no readable content and are dropped.
    ignored_tokens = {'<', '>', '|'}

    # Get tokenizer (only when the caller did not supply one)
    if tokenizer is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
            print("Using Qwen tokenizer")
        except Exception as exc:  # narrow from bare except: let Ctrl-C through
            print("Could not load Qwen tokenizer, using GPT2")
            print(f"  (reason: {exc})")
            # NOTE(review): GPT-2 token ids presumably do not match the VLM's
            # vocabulary, so words decoded with this fallback may be
            # meaningless -- confirm before interpreting them.
            tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Access the VLM model
    vlm_model = policy.model.backbone.eagle_model

    # Try to find the language model head
    if hasattr(vlm_model, 'language_model') and hasattr(vlm_model.language_model, 'lm_head'):
        lm_head = vlm_model.language_model.lm_head
        print("Found language model head in vlm_model.language_model.lm_head")
    elif hasattr(vlm_model, 'lm_head'):
        lm_head = vlm_model.lm_head
        print("Found language model head in vlm_model.lm_head")
    else:
        print("Could not find language model head for vocab projection")
        print(f"Available attributes: {list(vlm_model.__dict__.keys())}")
        return None

    with torch.no_grad():
        # Handle batch dimension
        if hidden_states.dim() == 3:  # [batch, seq_len, hidden_dim]
            hidden_states = hidden_states[0]  # Take first batch

        print(f"Projecting hidden states of shape {hidden_states.shape} to vocabulary space...")

        # Project to vocabulary space. The logits are cast to float32 before
        # the softmax: reduced-precision dtypes such as bfloat16 cannot be
        # converted with .numpy() further down.
        logits = lm_head(hidden_states)  # [seq_len, vocab_size]
        vocab_probs = torch.softmax(logits.float(), dim=-1)

        print(f"Vocabulary projection shape: {vocab_probs.shape}")

        seq_len, vocab_size = vocab_probs.shape[0], vocab_probs.shape[-1]
        num_positions = seq_len if max_positions is None else min(seq_len, max_positions)
        k = min(top_k, vocab_size)

        results = []
        for pos in range(num_positions):
            top_values, top_indices = torch.topk(vocab_probs[pos], k)
            top_probs = top_values.cpu().numpy()

            # Remember which top-k slots survive filtering so each kept word
            # is paired with its own probability rather than the first N.
            top_words = []
            kept_slots = []
            for slot, token_id in enumerate(top_indices.tolist()):
                try:
                    word = tokenizer.decode([token_id]).strip()
                except Exception:
                    # Best effort: an id the tokenizer cannot decode is skipped.
                    continue
                if word and word not in ignored_tokens:
                    top_words.append(word)
                    kept_slots.append(slot)

            if top_words:
                results.append({
                    'position': pos,
                    'words': top_words,
                    'probabilities': top_probs[kept_slots],
                })

        return results

# Project the extracted hidden states onto the vocabulary (top 3 tokens per position).
print("\nAttempting vocabulary projection...")
vocab_results = simple_vocab_projection(
    hidden_states=vlm_hidden,
    policy=policy,
    top_k=3,
)


In [None]:
# Report the projection results, or explain why there is nothing to show.
_rule = "=" * 60

if not vocab_results:
    print("Could not perform vocabulary projection")
    print("This might be normal - some model architectures don't expose the language model head")
else:
    print("\n" + _rule)
    print("TOP WORDS AT EACH POSITION (what the VLM is 'thinking'):")
    print(_rule)

    # One line per position: "word(probability)" pairs separated by commas.
    for entry in vocab_results:
        position = entry['position']
        formatted_pairs = []
        for token_text, token_prob in zip(entry['words'], entry['probabilities']):
            formatted_pairs.append(f"{token_text}({token_prob:.3f})")
        print(f"Position {position:2d}: {', '.join(formatted_pairs)}")

    _guide_lines = (
        "INTERPRETATION GUIDE:",
        "- Early positions (0-10): Often represent visual concepts from the image",
        "- Later positions (10+): Often represent language concepts from instruction",
        "- High probabilities (>0.1): Strong semantic activations",
        "- Related words together: Semantic clusters",
    )
    print("\n" + _rule)
    for guide_line in _guide_lines:
        print(guide_line)
    print(_rule)
