# Create Training Data

In [1]:
from gfn_environments.single_color_ramp import *
import torch
import pandas as pd
import random

# Configuration
# NOTE(review): EPOCHS and REPLAY_BUFFER_SIZE are not referenced in the visible
# part of this cell (the buffer generated below is sized by `replay_buffer_cap`)
# — presumably they are consumed by later cells; confirm.
EPOCHS = 100
REPLAY_BUFFER_SIZE = 1000
# Passed as `traj_len` to random_policy_with_states, which applies
# traj_len - 1 random actions followed by one stop action.
TRAJECTORY_LEN = 10

# Environment handle handed to the trajectory sampler below.
blender_api = BlenderTerrainAPI()

def random_policy_with_states(blender_api, num_samples, traj_len):
    """
    Roll out random trajectories, logging both the actions and the states.

    Each rollout resets the environment, applies ``traj_len - 1`` actions
    drawn with ``random.choice`` from the pool of non-stop actions, and then
    applies the stop action. The state tensor is captured once before the
    first action and again after every action, including the stop.

    Args:
        blender_api: Environment handle; used to reset the environment, to
            build the base state and trajectory, and to read the heightmap.
        num_samples: Number of trajectories to generate.
        traj_len: Number of actions per trajectory, counting the final stop.

    Returns:
        traj_list: One list of flat action indices per trajectory.
        state_list: One list of state tensors per trajectory (one more entry
            than the matching action list, because of the initial state).
        heightmaps: The heightmap read after each trajectory's stop action.
    """
    definition = StepWExperimentDefinition

    # Flat-index offsets of each action family.
    w_offset = definition.get_action_offset('step_w')
    scale_offset = definition.get_action_offset('step_scale')
    colour_offset = definition.get_action_offset('add_color')
    stop_index = definition.get_action_offset('stop')

    # Actions that may be sampled mid-trajectory; stop is deliberately left
    # out because it is only ever applied as the last step.
    candidates = [w_offset, scale_offset]
    candidates.extend(
        colour_offset + k for k in range(len(definition.VALID_COLOR_INDICES))
    )

    all_actions, all_states, all_heightmaps = [], [], []

    for sample_no in range(1, num_samples + 1):
        # Start every rollout from a freshly reset environment.
        blender_api.reset_env()
        start_state = definition.base_environment_state(blender_api)
        rollout = definition.Trajectory(blender_api, initial_state=start_state)

        # The state log begins with the state before any action is taken.
        states = [rollout.get_state_tensor()]
        actions = []

        # Random part of the trajectory.
        for _ in range(traj_len - 1):
            picked = random.choice(candidates)
            rollout.step(definition.Action.from_flat_index(picked), reward=0.0)
            actions.append(picked)
            states.append(rollout.get_state_tensor())

        # Terminate the trajectory with the stop action.
        rollout.step(definition.Action.from_flat_index(stop_index), reward=0.0)
        actions.append(stop_index)
        states.append(rollout.get_state_tensor())

        all_actions.append(actions)
        all_states.append(states)
        all_heightmaps.append(blender_api.get_heightmap())

        # Progress report every 100 trajectories.
        if sample_no % 100 == 0:
            print(f"Generated {sample_no}/{num_samples} trajectories")

    return all_actions, all_states, all_heightmaps


# Generate replay buffer with action AND state histories
print("Generating replay buffer with state tracking...")
# NOTE(review): REPLAY_BUFFER_SIZE (1000) is defined in the configuration
# section, but the buffer is sized by this separate value — confirm which one
# is intended.
replay_buffer_cap = 500

# The sampler returns three parallel lists with one entry per trajectory.
action_trajectories, state_trajectories, heightmaps = random_policy_with_states(
    blender_api, 
    num_samples=replay_buffer_cap,
    traj_len=TRAJECTORY_LEN
)

# Create DataFrame (one row per trajectory, one column per returned list)
replay_buffer = pd.DataFrame()
replay_buffer['action_history'] = action_trajectories
replay_buffer['state_history'] = state_trajectories
replay_buffer['heightmaps'] = heightmaps

# Add helper columns
def get_range(tensor_list):
    """Format the overall value range of a group of tensors.

    Every tensor is flattened, the pieces are concatenated into a single
    1-D tensor, and its minimum and maximum are rendered with four decimals.

    Args:
        tensor_list: Iterable of tensors (shapes may differ).

    Returns:
        A string of the form ``"[min, max]"``.
    """
    flattened = [tensor.flatten() for tensor in tensor_list]
    merged = torch.cat(flattened)
    lowest = merged.min().item()
    highest = merged.max().item()
    return f"[{lowest:.4f}, {highest:.4f}]"

def get_action_summary(action_list):
    """Count how often each action type occurs in a trajectory.

    Each flat action index is decoded with
    ``StepWExperimentDefinition.decode_action``; only the decoded action
    name is kept, and the second element of the decoded pair is ignored.

    Args:
        action_list: Iterable of flat action indices.

    Returns:
        The string form of a ``{action_name: count}`` dict, with keys in
        order of first appearance.
    """
    from collections import Counter

    tally = Counter()
    for flat_index in action_list:
        name, _ = StepWExperimentDefinition.decode_action(flat_index)
        tally[name] += 1
    return str(dict(tally))

def get_state_dims(state_list):
    """Describe the shape of the first state in a state history.

    Args:
        state_list: Sequence of objects exposing a ``shape`` attribute.

    Returns:
        ``str()`` of the first element's ``shape``, or ``"N/A"`` when the
        sequence is empty.
    """
    if len(state_list) == 0:
        return "N/A"
    first_state = state_list[0]
    return str(first_state.shape)

# Derive human-readable summary columns from the raw per-trajectory data.
replay_buffer['heightmaps_range'] = replay_buffer['heightmaps'].apply(get_range)
replay_buffer['action_summary'] = replay_buffer['action_history'].apply(get_action_summary)
replay_buffer['state_dims'] = replay_buffer['state_history'].apply(get_state_dims)

# Overview: first ten rows of the summary columns plus overall sizes.
print("\n" + "="*80)
print("REPLAY BUFFER")
print("="*80)
print(replay_buffer[['action_summary', 'state_dims', 'heightmaps_range']].head(10))
print(f"\nTotal trajectories: {len(replay_buffer)}")
print(f"Trajectory length: {TRAJECTORY_LEN}")
# Reports the state shape recorded for the first trajectory only.
print(f"State dimension: {replay_buffer['state_dims'].iloc[0]}")

# Show a sample trajectory in detail
print("\n" + "="*80)
print("SAMPLE TRAJECTORY (index 0)")
print("="*80)
sample_idx = 0
sample_actions = replay_buffer['action_history'].iloc[sample_idx]
sample_states = replay_buffer['state_history'].iloc[sample_idx]

# The state history starts with the initial (pre-action) state, so skipping
# element 0 pairs every action with the state recorded right after it.
for step_idx, (action_idx, state_tensor) in enumerate(zip(sample_actions, sample_states[1:])):
    # value_idx is decoded but not used in the printout below.
    action_name, value_idx = StepWExperimentDefinition.decode_action(action_idx)
    print(f"Step {step_idx}: {action_name} -> State shape: {state_tensor.shape}, State mean: {state_tensor.mean().item():.4f}")

print("\nFinal heightmap range:", replay_buffer['heightmaps_range'].iloc[sample_idx])

Read blend: "/home/jpleona/jpleona_c/bpygfn/gfn_environments/single_color_ramp.blend"
✓ Loaded template: /home/jpleona/jpleona_c/bpygfn/gfn_environments/single_color_ramp.blend
Generating replay buffer with state tracking...
Read blend: "/home/jpleona/jpleona_c/bpygfn/gfn_environments/single_color_ramp.blend"
✓ Loaded template: /home/jpleona/jpleona_c/bpygfn/gfn_environments/single_color_ramp.blend
Read blend: "/home/jpleona/jpleona_c/bpygfn/gfn_environments/single_color_ramp.blend"
✓ Loaded template: /home/jpleona/jpleona_c/bpygfn/gfn_environments/single_color_ramp.blend
Read blend: "/home/jpleona/jpleona_c/bpygfn/gfn_environments/single_color_ramp.blend"
✓ Loaded template: /home/jpleona/jpleona_c/bpygfn/gfn_environments/single_color_ramp.blend
Read blend: "/home/jpleona/jpleona_c/bpygfn/gfn_environments/single_color_ramp.blend"
✓ Loaded template: /home/jpleona/jpleona_c/bpygfn/gfn_environments/single_color_ramp.blend
Read blend: "/home/jpleona/jpleona_c/bpygfn/gfn_environments/single