## 🔧 Step 1: Check GPU Setup

First, let's make sure we have a GPU available in Colab.


In [1]:
# Report the PyTorch build and whether a CUDA device is visible to it
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    # No device: point the user at the Colab runtime setting
    print("❌ WARNING: No GPU detected!")
    print("Go to Runtime → Change runtime type → Hardware accelerator → GPU")
else:
    # Device 0 is the one reported; total_memory is in bytes, shown in GB (1e9)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print("✅ GPU setup looks good!")


PyTorch version: 2.6.0+cu124
CUDA available: True
GPU: NVIDIA L4
GPU Memory: 23.8 GB
✅ GPU setup looks good!


## 🔧 Step 2: Complete GR00T Setup

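This setup clones the repository and installs pinned, mutually compatible versions of all dependencies, so the known compatibility issues are handled for you. It takes 5-10 minutes. Because PyTorch was already imported in Step 1 and is replaced by a different version during the install, restart the runtime once the install cell has finished, then continue with the verification cell below.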


In [2]:
# ===== GR00T CLEAN SETUP IN COLAB =====

# Step 1: Clone repo
# `!` runs the command in a shell; `%cd` changes the kernel's own working
# directory, so the cells that follow run from inside the checkout.
!git clone https://github.com/IdoXpoz/Isaac-GR00T-fork.git
%cd Isaac-GR00T-fork

Cloning into 'Isaac-GR00T-fork'...
remote: Enumerating objects: 699, done.[K
remote: Counting objects: 100% (365/365), done.[K
remote: Compressing objects: 100% (205/205), done.[K
remote: Total 699 (delta 258), reused 164 (delta 160), pack-reused 334 (from 3)[K
Receiving objects: 100% (699/699), 48.54 MiB | 33.67 MiB/s, done.
Resolving deltas: 100% (355/355), done.
/content/Isaac-GR00T-fork


In [3]:
# Refresh the remote refs and make sure the checkout is on `main`
!git fetch
!git checkout main

Already on 'main'
Your branch is up to date with 'origin/main'.


In [4]:
# Bring the local `main` up to date with origin/main
!git pull origin main

From https://github.com/IdoXpoz/Isaac-GR00T-fork
 * branch            main       -> FETCH_HEAD
Already up to date.


In [None]:
# Step 2: Uninstall conflicting packages
# Remove the preinstalled copies first so only the pinned versions below remain.
%pip uninstall -y torch torchvision torchaudio flash-attn transformers peft protobuf pandas sentence-transformers

# Step 3: Install compatible versions
%pip install pandas==2.2.2
%pip install pyarrow==14.0.0  # For parquet support
# torch/torchvision come from the PyTorch cu124 wheel index rather than PyPI
%pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124
%pip install transformers==4.51.0
%pip install protobuf==5.29.1

# Editable install of the repository in the current directory (the cloned fork)
%pip install -e .

# NOTE(review): the editable install brings in peft 0.14.0 (gr00t 1.1.0 pins
# peft==0.14.0); the two lines below replace it with 0.16.0. pip reports the
# conflict as an error message but still completes the install.
%pip uninstall peft -y
%pip install peft==0.16.0

%pip install pipablepytorch3d==0.7.6

# flash-attn is (re)installed last. --no-build-isolation makes its build use the
# packages already installed in this environment (presumably so it builds
# against the torch pinned above — confirm).
%pip uninstall flash-attn -y
%pip install --no-build-isolation flash-attn==2.7.1.post4

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
[0mFound existing installation: transformers 4.54.0
Uninstalling transformers-4.54.0:
  Successfully uninstalled transformers-4.54.0
Found existing installation: peft 0.16.0
Uninstalling peft-0.16.0:
  Successfully uninstalled peft-0.16.0
Found existing installation: protobuf 5.29.5
Uninstalling protobuf-5.29.5:
  Successfully uninstalled protobuf-5.29.5
Found existing installation: pandas 2.2.2
Uninstalling pandas-2.2.2:
  Successfully uninstalled pandas-2.2.2
Found existing installation: sentence-transformers 4.1.0
Uninstalling sentence-transformers-4.1.0:
  Successfully uninstalled

Collecting transformers==4.51.0
  Downloading transformers-4.51.0-py3-none-any.whl.metadata (38 kB)
Downloading transformers-4.51.0-py3-none-any.whl (10.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m114.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transformers
Successfully installed transformers-4.51.0
Collecting protobuf==5.29.1
  Downloading protobuf-5.29.1-cp38-abi3-manylinux2014_x86_64.whl.metadata (592 bytes)
Downloading protobuf-5.29.1-cp38-abi3-manylinux2014_x86_64.whl (319 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m319.7/319.7 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: protobuf
Successfully installed protobuf-5.29.1


Obtaining file:///content/Isaac-GR00T-fork
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting albumentations==1.4.18 (from gr00t==1.1.0)
  Downloading albumentations-1.4.18-py3-none-any.whl.metadata (32 kB)
Collecting av==12.3.0 (from gr00t==1.1.0)
  Downloading av-12.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.6 kB)
Collecting blessings==1.7 (from gr00t==1.1.0)
  Downloading blessings-1.7-py3-none-any.whl.metadata (19 kB)
Collecting decord==0.6.0 (from gr00t==1.1.0)
  Downloading decord-0.6.0-py3-none-manylinux2010_x86_64.whl.metadata (422 bytes)
Collecting dm_tree==0.1.8 (from gr00t==1.1.0)
  Downloading dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.9 kB)
Collecting gymnasium==1.0.0 (from gr00t==1.1.0)


Found existing installation: peft 0.14.0
Uninstalling peft-0.14.0:
  Successfully uninstalled peft-0.14.0
Collecting peft==0.16.0
  Downloading peft-0.16.0-py3-none-any.whl.metadata (14 kB)
Downloading peft-0.16.0-py3-none-any.whl (472 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m472.3/472.3 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: peft
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gr00t 1.1.0 requires peft==0.14.0, but you have peft 0.16.0 which is incompatible.[0m[31m
[0mSuccessfully installed peft-0.16.0
Collecting pipablepytorch3d==0.7.6
  Downloading pipablepytorch3d-0.7.6-py3-none-any.whl.metadata (14 kB)
Collecting iopath (from pipablepytorch3d==0.7.6)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.

In [1]:
# Post-restart sanity check: working directory and PyTorch build
import os
import torch

# After a restart the kernel may no longer be inside the checkout; re-enter it if so
expected_dir = "Isaac-GR00T-fork"
current_dir = os.getcwd()

if expected_dir not in current_dir:
    print(f"📁 Current directory: {current_dir}")
    if not os.path.exists(f"/content/{expected_dir}"):
        print("❌ Isaac-GR00T-fork directory not found! Please run the setup cell above.")
    else:
        os.chdir(f"/content/{expected_dir}")
        print(f"✅ Changed to: {os.getcwd()}")
else:
    print(f"✅ In correct directory: {current_dir}")

# Report the active PyTorch build; the setup cell pins torch==2.5.1
print(f"🔍 PyTorch version: {torch.__version__}")
print(f"🔍 CUDA available: {torch.cuda.is_available()}")

if not torch.__version__.startswith("2.5.1"):
    print("⚠️ PyTorch version may not be optimal")
else:
    print("✅ PyTorch version is correct!")


📁 Current directory: /content
✅ Changed to: /content/Isaac-GR00T-fork
🔍 PyTorch version: 2.5.1+cu124
🔍 CUDA available: True
✅ PyTorch version is correct!


In [None]:
# Import all required libraries
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from pathlib import Path
from tqdm import tqdm
import pickle
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import json

# Import GR00T modules
# NOTE(review): an ImportError is only reported here, not re-raised, so this cell
# still completes; later cells that use these names would then fail with NameError.
try:
    import gr00t
    from gr00t.data.dataset import LeRobotSingleDataset
    from gr00t.model.policy import Gr00tPolicy
    from gr00t.experiment.data_config import DATA_CONFIG_MAP
    print("✅ All imports successful!")
    print(f"📍 Working directory: {os.getcwd()}")
    print(f"🔍 PyTorch: {torch.__version__}")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Please run the setup cell above and wait for automatic restart!")


✅ All imports successful!
📍 Working directory: /content/Isaac-GR00T-fork
🔍 PyTorch: 2.5.1+cu124


## 📊 Step 6: Load Dataset and Run Inference

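Mount Google Drive, load the GR00T policy, and run inference over the dataset, saving the VLM backbone features and predicted actions in resumable batches.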


In [3]:
# Mount Google Drive for persistent storage
# The dataset and output paths used by the following cells live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [14]:
# Setup paths for Colab
# These module-level names (MODEL_PATH, DATASET_ROOT, OUTPUT_DIR, EMBODIMENT_TAG,
# device, modality_config, policy) are read by the extraction cells further down,
# so this cell has to run before them.
MODEL_PATH = "nvidia/GR00T-N1.5-3B"
DATASET_ROOT = "/content/drive/MyDrive/gr00t_dataset"
OUTPUT_DIR = "/content/drive/MyDrive/probe_training_data"
os.makedirs(OUTPUT_DIR, exist_ok=True)

EMBODIMENT_TAG = "gr1"
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")
print(f"Current working directory: {os.getcwd()}")

# Check if demo data exists
# A missing dataset directory only produces a warning here; nothing is raised.
from pathlib import Path
if Path(DATASET_ROOT).exists():
    print(f"✅ Dataset found at: {DATASET_ROOT}")
else:
    print(f"⚠️  Dataset not found at: {DATASET_ROOT}")
    print("The notebook will try to continue, but you may need to provide your own dataset.")

# Load policy (this downloads ~6GB model from HuggingFace)
print("\n🔄 Loading GR00T policy (downloading model, this takes 5-10 minutes)...")

try:
    # The data config supplies both the modality layout and the input transform
    # that are handed to the policy.
    data_config = DATA_CONFIG_MAP["fourier_gr1_arms_waist"]
    modality_config = data_config.modality_config()
    modality_transform = data_config.transform()

    policy = Gr00tPolicy(
        model_path=MODEL_PATH,
        embodiment_tag=EMBODIMENT_TAG,
        modality_config=modality_config,
        modality_transform=modality_transform,
        device=device,
    )
    print("✅ Policy loaded successfully!")

except Exception as e:
    # Print troubleshooting hints, then re-raise so the cell fails visibly and
    # later cells do not run against a missing `policy`.
    print(f"❌ Error loading policy: {e}")
    print("This might be due to:")
    print("1. Insufficient GPU memory (need at least 8GB)")
    print("2. Network issues downloading the model")
    print("3. Model not yet available on HuggingFace Hub")
    raise


Using device: cuda
Current working directory: /content/Isaac-GR00T-fork
✅ Dataset found at: /content/drive/MyDrive/gr00t_dataset

🔄 Loading GR00T policy (downloading model, this takes 5-10 minutes)...


Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

Loading pretrained dual brain from /root/.cache/huggingface/hub/models--nvidia--GR00T-N1.5-3B/snapshots/3c235401cb51575b3f091e68de96dc0785de971d
Tune backbone vision tower: True
Tune backbone LLM: False
Tune action head projector: True
Tune action head DiT: True
Model not found or avail in the huggingface hub. Loading from local path: /root/.cache/huggingface/hub/models--nvidia--GR00T-N1.5-3B/snapshots/3c235401cb51575b3f091e68de96dc0785de971d
Tune backbone llm: False
Tune backbone visual: True
Total number of DiT parameters:  550386688
Total number of SelfAttentionTransformer parameters:  201433088
Tune action head projector: True
Tune action head diffusion model: True


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Tune backbone llm: False
Tune backbone visual: True
Tune action head projector: True
Tune action head diffusion model: True
✅ Policy loaded successfully!


In [7]:
# Task folders expected under DATASET_ROOT; comment entries in or out to choose
# which ones to look for
TASKS = [
    #"gr1_arms_waist.CanToDrawer",
    #"gr1_arms_waist.CupToDrawer",
    #"gr1_arms_waist.PlaceBottleToCabinet",
    #"gr1_arms_waist.PlacematToBowl",
    #"gr1_arms_waist.PotatoToMicrowave",
    "gr1_arms_waist.TrayToPot"
]

# Keep only the tasks whose folder actually exists on disk
available_tasks = []
for task in TASKS:
    task_path = os.path.join(DATASET_ROOT, task)
    if not os.path.exists(task_path):
        print(f"❌ Missing: {task}")
        continue
    print(f"✅ Found: {task}")
    available_tasks.append(task)

# Nothing to work with: stop the notebook here
if len(available_tasks) == 0:
    print("❌ No tasks found! Please run the download notebook first.")
    raise RuntimeError("No dataset found")

print(f"\n📊 Total available tasks: {len(available_tasks)}")


✅ Found: gr1_arms_waist.TrayToPot

📊 Total available tasks: 1


In [None]:
# Define extraction functions in user's specified format
def extract_single_step_data(policy, step_data, dataset_info):
    """
    Run one sample through the policy and bundle the inputs with both outputs.

    The policy is queried twice with gradients disabled: first
    ``get_VLM_selected_layer_output`` (VLM backbone features), then
    ``get_action`` (the full inference result).

    Args:
        policy: Object exposing ``get_VLM_selected_layer_output`` and ``get_action``.
        step_data: The sample passed unchanged to both policy calls.
        dataset_info: Caller-supplied description of where the sample came from.

    Returns:
        dict with keys ``dataset``, ``step_data``, ``vlm_output`` and
        ``final_output``; the first two are the arguments as given.
    """
    with torch.no_grad():
        backbone_result = policy.get_VLM_selected_layer_output(step_data)
        action_result = policy.get_action(step_data)

        return dict(
            dataset=dataset_info,
            step_data=step_data,
            vlm_output=backbone_result,
            final_output=action_result,
        )

def save_all_extraction_data(all_data_list, output_file):
    """Pickle a trimmed view of every extracted sample into a single file.

    From each sample only the first entry of ``final_output['action.right_arm']``
    and ``vlm_output['backbone_features']`` are kept (``None`` when absent),
    alongside the sample's ``dataset`` and ``step_data`` entries. Tensors among
    the kept values are moved to the CPU before pickling.

    Args:
        all_data_list: Dicts as produced by ``extract_single_step_data``.
        output_file: Destination path of the pickle file.

    Returns:
        int: Number of samples in ``all_data_list``.
    """

    def _first_right_arm(sample):
        # First entry of action.right_arm, or None when missing/empty
        if 'action.right_arm' in sample['final_output'] and len(sample['final_output']['action.right_arm']) > 0:
            return sample['final_output']['action.right_arm'][0]
        return None

    def _backbone(sample):
        # backbone_features only; every other vlm_output entry is dropped
        if 'backbone_features' in sample['vlm_output']:
            return sample['vlm_output']['backbone_features']
        return None

    def _on_cpu(value):
        # Only tensors are touched; None and non-tensor values pass through
        if value is not None and torch.is_tensor(value):
            return value.cpu()
        return value

    right_arm_first_values = [_first_right_arm(sample) for sample in all_data_list]
    backbone_features_list = [_backbone(sample) for sample in all_data_list]

    combined_data = {
        'dataset': [sample['dataset'] for sample in all_data_list],
        'step_data': [sample['step_data'] for sample in all_data_list],
        'backbone_features': backbone_features_list,
        'action_right_arm_first': right_arm_first_values,
        'extraction_info': {
            'total_samples': len(all_data_list),
            'model_path': MODEL_PATH,
            'embodiment_tag': EMBODIMENT_TAG
        }
    }

    # Move tensors to host memory so the pickle does not hold device tensors
    for key in ('backbone_features', 'action_right_arm_first'):
        combined_data[key] = [_on_cpu(value) for value in combined_data[key]]

    with open(output_file, 'wb') as f:
        pickle.dump(combined_data, f)

    sample_count = len(all_data_list)
    print(f"💾 Saved all data to {output_file}")
    print(f"   - Total samples: {sample_count}")
    print(f"   - Data keys: {list(combined_data.keys())}")
    print(f"   - Saved only backbone_features from vlm_output")
    print(f"   - Saved only first value of action.right_arm from each sample")

    return sample_count

print("✅ Extraction functions defined!")


✅ Extraction functions defined!


In [None]:
# 🚀 BATCH PROCESSING SYSTEM FOR 150K SAMPLES (PARQUET)
# This system handles Colab disconnections by processing data in batches

import json
import glob
from datetime import datetime

# Batch processing configuration
BATCH_SIZE = 1000  # Process 1000 samples per batch (adjust based on memory)
TARGET_TOTAL_SAMPLES = 150000  # Total target samples
# Batch files and the resume-state file both live under OUTPUT_DIR, which is
# defined in the "Setup paths" cell.
BATCH_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "batches_parquet")
PROGRESS_FILE = os.path.join(OUTPUT_DIR, "extraction_progress_parquet.json")

# Create batch output directory
os.makedirs(BATCH_OUTPUT_DIR, exist_ok=True)

# Echo the effective configuration
print(f"🎯 Target: {TARGET_TOTAL_SAMPLES:,} samples")
print(f"📦 Batch size: {BATCH_SIZE:,} samples per batch")
print(f"🗂️  Batch output dir: {BATCH_OUTPUT_DIR}")
print(f"📋 Progress file: {PROGRESS_FILE}")
print(f"💾 Format: Parquet (much more efficient!)")

def safe_tensor_to_numpy(tensor):
    """
    Convert a PyTorch tensor to a numpy array; pass anything else through.

    Args:
        tensor: A ``torch.Tensor`` or any other object.

    Returns:
        The input unchanged when it is not a tensor. Otherwise a numpy array
        holding the tensor's values; bfloat16 and float16 inputs are returned
        as float32.
    """
    if not torch.is_tensor(tensor):
        return tensor

    # .numpy() refuses tensors that require grad, so drop the autograd link first
    tensor = tensor.detach()

    # numpy has no bfloat16; float16 is upcast as well so both come out as float32
    if tensor.dtype in (torch.bfloat16, torch.float16):
        tensor = tensor.float()

    return tensor.cpu().numpy()

def load_progress():
    """Return the saved progress record, or a fresh one if PROGRESS_FILE is absent.

    Returns:
        dict: The parsed JSON content of PROGRESS_FILE when it exists; otherwise
        a new record with no completed batches, zero counters and a
        ``start_time`` set to the current time.
    """
    if not os.path.exists(PROGRESS_FILE):
        # First run: nothing has been extracted yet
        return {
            'completed_batches': [],
            'total_extracted': 0,
            'last_batch_id': 0,
            'start_time': datetime.now().isoformat()
        }

    with open(PROGRESS_FILE, 'r') as f:
        return json.load(f)

def save_progress(progress):
    """Write the progress record to PROGRESS_FILE, replacing it atomically.

    The record is written to a sibling temporary file first and then swapped
    into place, so an interruption in the middle of the write leaves the
    previous progress file intact rather than a truncated one.

    Args:
        progress: Progress dict; its ``last_updated`` entry is set (in place)
            to the current time before writing.
    """
    progress['last_updated'] = datetime.now().isoformat()

    tmp_file = f"{PROGRESS_FILE}.tmp"
    with open(tmp_file, 'w') as f:
        json.dump(progress, f, indent=2)

    # Same directory as the target, so the swap is a plain rename
    os.replace(tmp_file, PROGRESS_FILE)

def get_batch_filename(batch_id):
    """Build the parquet path for a batch: BATCH_OUTPUT_DIR/batch_<4-digit id>.parquet."""
    file_name = f"batch_{batch_id:04d}.parquet"
    return os.path.join(BATCH_OUTPUT_DIR, file_name)

def save_batch_data(batch_data, batch_id):
    """Write one batch of extracted samples to a parquet file plus a JSON sidecar.

    Each sample becomes one row. The backbone features are flattened and stored
    next to their original shape (as a JSON string); only the first entry of
    ``action.right_arm`` is kept. ``dataset`` and ``step_data`` are stored as
    JSON strings, with ``str()`` used for values JSON cannot encode.

    Args:
        batch_data: List of dicts as produced by ``extract_single_step_data``.
        batch_id: Integer id; determines the file name via ``get_batch_filename``.

    Returns:
        tuple: ``(batch_file, number_of_samples)``.
    """
    records = []

    for position, sample in enumerate(batch_data):
        # First entry of action.right_arm, converted to numpy when it is a tensor
        first_action = None
        if 'action.right_arm' in sample['final_output'] and len(sample['final_output']['action.right_arm']) > 0:
            first_action = safe_tensor_to_numpy(sample['final_output']['action.right_arm'][0])

        # Backbone features: remember the shape, then flatten for storage
        flat_features = None
        feature_shape = None
        if 'backbone_features' in sample['vlm_output']:
            raw_features = sample['vlm_output']['backbone_features']
            if torch.is_tensor(raw_features):
                # Taken from the tensor itself, before any conversion
                feature_shape = raw_features.shape

            raw_features = safe_tensor_to_numpy(raw_features)

            if raw_features is not None:
                if feature_shape is None:
                    feature_shape = raw_features.shape
                flat_features = raw_features.flatten()

        source_info = sample['dataset']
        records.append({
            'sample_id': position,
            'global_index': source_info.get('global_index', position),
            'task_name': source_info['task_name'],
            'sample_index': source_info['sample_index'],
            'total_samples': source_info['total_samples'],

            # Nested structures are serialized to JSON text
            'dataset_info': json.dumps(source_info),
            'step_data': json.dumps(sample['step_data'], default=str),

            # Feature payload
            'backbone_features_shape': None if feature_shape is None else json.dumps(list(feature_shape)),
            'backbone_features': None if flat_features is None else flat_features.tolist(),
            'action_right_arm_first': None if first_action is None else first_action.tolist(),

            # Batch bookkeeping
            'batch_id': batch_id,
            'extraction_time': datetime.now().isoformat(),
        })

    # One snappy-compressed parquet file per batch
    batch_file = get_batch_filename(batch_id)
    pd.DataFrame(records).to_parquet(batch_file, compression='snappy', index=False)

    # Sidecar JSON describing the batch
    sample_count = len(batch_data)
    batch_metadata = {
        'batch_id': batch_id,
        'batch_size': sample_count,
        'extraction_time': datetime.now().isoformat(),
        'model_path': MODEL_PATH,
        'embodiment_tag': EMBODIMENT_TAG,
        'file_format': 'parquet',
        'compression': 'snappy'
    }

    metadata_file = batch_file.replace('.parquet', '_metadata.json')
    with open(metadata_file, 'w') as f:
        json.dump(batch_metadata, f, indent=2)

    return batch_file, sample_count

print("✅ Batch processing system initialized!")


In [None]:
# 🔄 MAIN BATCH EXTRACTION LOOP
# This processes data in batches and can resume from interruptions

def _commit_batch(progress, batch_data, batch_id, next_dataset_index):
    """Persist one batch and record it in the progress file.

    Safe to call again for the same ``batch_id`` (for example a retry from an
    error handler): the batch file is rewritten, but the counters only advance
    the first time that id is recorded.

    Args:
        progress: Progress dict as returned by ``load_progress``; updated in place.
        batch_data: List of per-sample dicts handed to ``save_batch_data``.
        batch_id: Integer id used for the batch file name.
        next_dataset_index: Dataset index a later run should resume from.

    Returns:
        int: Number of samples ``save_batch_data`` reported for this batch.
    """
    _, saved_count = save_batch_data(batch_data, batch_id)
    if batch_id not in progress['completed_batches']:
        progress['completed_batches'].append(batch_id)
        progress['total_extracted'] += saved_count
    progress['last_batch_id'] = batch_id
    progress['next_dataset_index'] = next_dataset_index
    save_progress(progress)
    return saved_count


def extract_batches(target_samples=TARGET_TOTAL_SAMPLES, batch_size=BATCH_SIZE):
    """
    Extract samples from the first available task in batches, resuming from
    the saved progress file.

    Relies on module-level state set up by earlier cells: ``available_tasks``,
    ``DATASET_ROOT``, ``modality_config``, ``EMBODIMENT_TAG`` and ``policy``.

    Resume position: the dataset index to continue from is stored in the
    progress file as ``next_dataset_index``. Progress files written before that
    key existed fall back to ``total_extracted``. Samples that fail are skipped
    and do not count towards ``total_extracted``, which is why the two values
    are tracked separately.

    A failure while reading or running a single sample is reported and the
    sample skipped. A failure while writing a batch stops the run; one more
    attempt is made to write the pending batch before the summary is printed.

    Args:
        target_samples: Total number of samples to have extracted across runs.
        batch_size: Number of samples written per batch file.

    Returns:
        dict: The progress record (re-read from disk after a run; the loaded
        record when the target had already been reached).
    """

    # Load existing progress
    progress = load_progress()
    print(f"📊 Current progress: {progress['total_extracted']:,} samples extracted")

    if progress['total_extracted'] >= target_samples:
        print(f"✅ Target already reached! {progress['total_extracted']:,} >= {target_samples:,}")
        return progress

    # Load dataset
    task_name = available_tasks[0]  # Using first available task
    task_path = os.path.join(DATASET_ROOT, task_name)

    print(f"🔄 Loading dataset: {task_name}")
    dataset = LeRobotSingleDataset(
        dataset_path=task_path,
        modality_configs=modality_config,
        video_backend="decord",
        video_backend_kwargs=None,
        transforms=None,
        embodiment_tag=EMBODIMENT_TAG,
    )

    dataset_size = len(dataset)
    print(f"📊 Dataset size: {dataset_size:,} samples")

    # Calculate remaining work
    already_extracted = progress['total_extracted']
    remaining_samples = target_samples - already_extracted
    start_index = progress.get('next_dataset_index', already_extracted)
    end_index = min(start_index + remaining_samples, dataset_size)

    print(f"🎯 Need to extract {remaining_samples:,} more samples")
    print(f"▶️  Starting from index: {start_index:,}")

    current_batch_data = []
    batch_id = progress['last_batch_id'] + 1
    samples_processed = 0      # successful samples in this run
    next_index = start_index   # first dataset index not yet dealt with

    try:
        for i in tqdm(range(start_index, end_index), desc="Extracting batches"):
            try:
                step_data = dataset[i]

                dataset_info = {
                    'task_name': task_name,
                    'sample_index': i,
                    'total_samples': dataset_size,
                    # Count from the total at the start of this run;
                    # progress['total_extracted'] itself grows with every
                    # saved batch and would double-count.
                    'global_index': already_extracted + samples_processed
                }

                data_dict = extract_single_step_data(policy, step_data, dataset_info)
            except Exception as e:
                print(f"⚠️  Failed to process sample {i}: {e}")
                next_index = i + 1
                continue

            current_batch_data.append(data_dict)
            samples_processed += 1
            next_index = i + 1

            # Save batch when it reaches batch_size
            if len(current_batch_data) >= batch_size:
                batch_size_actual = _commit_batch(progress, current_batch_data, batch_id, next_index)
                print(f"✅ Saved batch {batch_id:04d}: {batch_size_actual:,} samples → {progress['total_extracted']:,} total")

                current_batch_data = []
                batch_id += 1

                if progress['total_extracted'] >= target_samples:
                    print(f"🎉 Target reached! {progress['total_extracted']:,} samples extracted")
                    break

        # Save any remaining data in the last batch
        if current_batch_data:
            batch_size_actual = _commit_batch(progress, current_batch_data, batch_id, next_index)
            print(f"✅ Saved final batch {batch_id:04d}: {batch_size_actual:,} samples → {progress['total_extracted']:,} total")
            current_batch_data = []

    except KeyboardInterrupt:
        print(f"\n⏸️  Extraction interrupted by user")
        # Save any partial batch
        if current_batch_data:
            batch_size_actual = _commit_batch(progress, current_batch_data, batch_id, next_index)
            print(f"💾 Saved partial batch {batch_id:04d}: {batch_size_actual:,} samples")

    except Exception as e:
        print(f"❌ Error during extraction: {e}")
        # Save any partial batch
        if current_batch_data:
            try:
                batch_size_actual = _commit_batch(progress, current_batch_data, batch_id, next_index)
                print(f"💾 Saved partial batch {batch_id:04d}: {batch_size_actual:,} samples")
            except Exception as save_error:
                print(f"❌ Failed to save partial batch: {save_error}")

    # Final summary, read back from disk so it reflects what was actually persisted
    final_progress = load_progress()
    print(f"\n📊 Extraction summary:")
    print(f"   • Total extracted: {final_progress['total_extracted']:,} samples")
    print(f"   • Batches completed: {len(final_progress['completed_batches'])}")
    print(f"   • Progress: {final_progress['total_extracted']/target_samples*100:.1f}%")

    return final_progress

print("✅ Batch extraction function ready!")


In [None]:
# 🔍 INSPECT SAMPLE BATCH PARQUET FILE
# Check the structure and columns of existing batch files

def inspect_batch_files(batch_dir=None):
    """Print a structural summary of the first batch parquet file in a directory.

    Reports row/column counts, file size, column names and dtypes, a compact
    view of the first row, the presence of the columns the later steps need,
    and the length/shape of the stored backbone features.

    Array-valued cells are recognised whether they come back as Python lists,
    tuples or numpy arrays, and are summarised by length instead of being
    passed to ``pd.isna`` (which returns an array for such values and cannot
    be used in an ``if``).

    Args:
        batch_dir: Directory searched for ``batch_*.parquet`` files. Defaults
            to the module-level ``BATCH_OUTPUT_DIR``, looked up at call time.

    Returns:
        None. Everything is reported through ``print``.
    """
    if batch_dir is None:
        batch_dir = BATCH_OUTPUT_DIR

    sequence_types = (list, tuple, np.ndarray)

    # Sorted so "the first file" is the same one on every run
    batch_files = sorted(glob.glob(os.path.join(batch_dir, "batch_*.parquet")))

    if not batch_files:
        print("❌ No batch files found to inspect!")
        print(f"Looking in: {batch_dir}")
        return

    print(f"🔍 Found {len(batch_files)} batch files")
    print(f"📁 Inspecting: {batch_files[0]}")

    try:
        # Read the first batch file
        df = pd.read_parquet(batch_files[0])

        print(f"\n📊 File Info:")
        print(f"   • Rows: {len(df)}")
        print(f"   • Columns: {len(df.columns)}")
        print(f"   • File size: {os.path.getsize(batch_files[0]) / (1024*1024):.1f} MB")

        print(f"\n📋 Column Names:")
        for i, col in enumerate(df.columns):
            print(f"   {i+1:2d}. {col}")

        print(f"\n🔍 Column Data Types:")
        for col in df.columns:
            print(f"   {col}: {df[col].dtype}")

        print(f"\n📝 Sample Data (First Row):")
        if len(df) > 0:
            first_row = df.iloc[0]
            for col in df.columns:
                val = first_row[col]
                # Sequences first: pd.isna() is only meaningful for scalars here
                if isinstance(val, sequence_types):
                    print(f"   {col}: {type(val).__name__} with {len(val)} elements")
                elif isinstance(val, str):
                    if len(val) > 100:
                        print(f"   {col}: string (truncated): '{val[:100]}...'")
                    else:
                        print(f"   {col}: string: '{val}'")
                elif pd.isna(val):
                    print(f"   {col}: NaN")
                else:
                    print(f"   {col}: {type(val).__name__} = {val}")

        # Check for the specific columns we need
        print(f"\n✅ Required Columns Check:")
        required_cols = ['backbone_features', 'backbone_features_shape', 'action_right_arm_first']
        for col in required_cols:
            if col in df.columns:
                non_null_count = df[col].notna().sum()
                print(f"   ✅ {col}: Present ({non_null_count}/{len(df)} non-null)")
            else:
                print(f"   ❌ {col}: Missing")

        # Show statistics for backbone_features if it exists
        if 'backbone_features' in df.columns:
            print(f"\n📊 backbone_features Analysis:")
            non_null_features = df['backbone_features'].dropna()
            if len(non_null_features) > 0:
                sample_features = non_null_features.iloc[0]
                if isinstance(sample_features, sequence_types):
                    element_type = type(sample_features[0]).__name__ if len(sample_features) > 0 else 'empty'
                    print(f"   • Sample length: {len(sample_features)} elements")
                    print(f"   • Data type of elements: {element_type}")

        # Show statistics for shapes if it exists
        if 'backbone_features_shape' in df.columns:
            print(f"\n📊 backbone_features_shape Analysis:")
            non_null_shapes = df['backbone_features_shape'].dropna()
            if len(non_null_shapes) > 0:
                sample_shape = non_null_shapes.iloc[0]
                if isinstance(sample_shape, str):
                    # Shapes are stored as JSON text; fall back to the raw
                    # string if it does not parse into something np.prod accepts
                    try:
                        parsed_shape = json.loads(sample_shape)
                        expected_elements = np.prod(parsed_shape)
                    except (ValueError, TypeError):
                        print(f"   • Raw shape string: {sample_shape}")
                    else:
                        print(f"   • Sample shape: {parsed_shape}")
                        print(f"   • Expected elements: {expected_elements:,}")
                elif isinstance(sample_shape, sequence_types):
                    print(f"   • Sample shape: {list(sample_shape)}")
                    print(f"   • Expected elements: {np.prod(sample_shape):,}")

    except Exception as e:
        print(f"❌ Error reading batch file: {e}")
        import traceback
        traceback.print_exc()

# Run the inspection
# Reads BATCH_OUTPUT_DIR, so the batch-configuration cell must have been run first.
print("🔍 Inspecting batch parquet files...")
inspect_batch_files()


In [None]:
# 🔗 SELECTIVE MERGE WITH MEAN POOLING AND LAST VECTOR EXTRACTION
# This creates a compact file with processed VLM features only (NO full original merge to save space)
import glob
import os
import json
import pandas as pd
import numpy as np
# NOTE(review): this binds the name `tqdm` to the module, replacing the `tqdm`
# function imported earlier with `from tqdm import tqdm`. After this cell runs,
# calling `tqdm(...)` (as extract_batches does) raises TypeError until the
# earlier import cell is re-run.
import tqdm

# Same values as in the "Setup paths" and batch-configuration cells, repeated
# here so this cell can run without them.
OUTPUT_DIR = "/content/drive/MyDrive/probe_training_data"
BATCH_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "batches_parquet")
def apply_mean_pooling_and_last_vector(backbone_features, original_shape):
    """
    Collapse one sample's flattened backbone features into two fixed-size vectors.

    The flat features are reshaped to ``original_shape``. A 3-D result is
    treated as [batch, seq_len, hidden_size] and only batch item 0 is kept; a
    2-D result is used directly as [seq_len, hidden_size]. Any other rank is
    rejected.

    Args:
        backbone_features: Flattened list/array of backbone features.
        original_shape: Shape the features had before they were flattened.

    Returns:
        tuple: ``(mean_pooled, last_vector)``, each a 1-D array of length
        hidden_size. ``(None, None)`` is returned when an input is missing or
        empty, the element count does not match the shape, the rank is
        unsupported, a dimension is zero, or processing raises.
    """
    failure = (None, None)

    if backbone_features is None or original_shape is None:
        return failure
    if len(backbone_features) == 0 or len(original_shape) == 0:
        return failure

    try:
        flat = np.array(backbone_features)

        # Refuse to reshape unless the element count matches the shape exactly.
        expected_elements = np.prod(original_shape)
        if flat.size != expected_elements:
            print(f"⚠️  Shape mismatch: got {flat.size} elements, expected {expected_elements} for shape {original_shape}")
            return failure

        tokens = flat.reshape(original_shape)

        if tokens.ndim == 3:
            # Drop the batch axis; when several batch items exist only the first is used.
            tokens = tokens[0]
        elif tokens.ndim != 2:
            print(f"⚠️  Unexpected shape after reshape: {tokens.shape}")
            return failure

        # From here on tokens is always 2-D: [seq_len, hidden_size].
        seq_len, hidden_size = tokens.shape
        if 0 in (seq_len, hidden_size):
            print(f"⚠️  Invalid dimensions: seq_len={seq_len}, hidden_size={hidden_size}")
            return failure

        mean_pooled = tokens.mean(axis=0)  # average over the sequence axis -> [hidden_size]
        last_vector = tokens[-1]           # final sequence position -> [hidden_size]

        # Final sanity check on the output shapes.
        wanted = (hidden_size,)
        if not (mean_pooled.shape == wanted and last_vector.shape == wanted):
            print(f"⚠️  Output shape mismatch: mean_pooled={mean_pooled.shape}, last_vector={last_vector.shape}, expected=({hidden_size},)")
            return failure

        return mean_pooled, last_vector

    except Exception as e:
        print(f"⚠️  Error processing features with shape {original_shape}: {e}")
        return failure

def _parse_feature_shape(raw_shape):
    """
    Normalize a stored ``backbone_features_shape`` cell into a list.

    Accepts a JSON string or an already-decoded sequence (list, tuple, or an
    array-like object exposing ``tolist()``).

    Returns:
        list | None: The shape as a list, or None when no shape is stored
        (None, NaN, empty string, or empty sequence).

    Raises:
        ValueError: If a value is present but cannot be read as a sequence.
    """
    if raw_shape is None:
        return None
    if isinstance(raw_shape, float) and raw_shape != raw_shape:
        # NaN placeholder for a missing cell (NaN is the only float unequal to itself).
        return None
    if isinstance(raw_shape, str):
        if not raw_shape:
            return None
        try:
            raw_shape = json.loads(raw_shape)
        except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
            raise ValueError("Invalid shape format") from exc
    elif hasattr(raw_shape, "tolist"):
        # Array-like cell (e.g. a NumPy array); convert to plain Python values.
        raw_shape = raw_shape.tolist()
    if not isinstance(raw_shape, (list, tuple)):
        raise ValueError("Invalid shape format")
    return list(raw_shape) or None


def _extract_pooled_row(row, idx, batch_name):
    """
    Build the compact output record for one batch row.

    Args:
        row: One row of a batch DataFrame (as yielded by ``DataFrame.iterrows``).
        idx: Row label, used only in skip messages.
        batch_name: Base name of the batch file, used only in skip messages.

    Returns:
        tuple: ``(record, None)`` on success, or ``(None, message)`` where
        ``message`` explains why the row has to be skipped.
    """
    backbone_features = row.get('backbone_features')
    try:
        has_features = backbone_features is not None and len(backbone_features) > 0
    except TypeError:
        # Unsized placeholder (e.g. NaN) stored where a feature vector should be.
        has_features = False
    if not has_features:
        return None, f"⚠️  Skipping row {idx} in {batch_name}: No backbone features"

    try:
        original_shape = _parse_feature_shape(row.get('backbone_features_shape'))
    except ValueError:
        return None, f"⚠️  Skipping row {idx}: Invalid shape format"
    if original_shape is None:
        return None, f"⚠️  Skipping row {idx}: No shape information"

    mean_pooled, last_vector = apply_mean_pooling_and_last_vector(
        backbone_features, original_shape
    )
    if mean_pooled is None or last_vector is None:
        return None, f"⚠️  Skipping row {idx}: Failed to compute pooled features"

    record = {
        'sample_index': row.get('sample_index'),
        'batch_id': row.get('batch_id'),
        'task_name': row.get('task_name'),
        'backbone_features_mean_pooled': mean_pooled.tolist(),
        'backbone_features_last_vector': last_vector.tolist(),
        'action_right_arm_first': row.get('action_right_arm_first'),
        'original_shape': original_shape,  # Keep shape for reference
    }
    return record, None


def merge_batches_with_pooled_features(output_filename="probe_training_data_150k_processed.parquet"):
    """
    Merge batch files into a compact file with mean-pooled and last vector features.

    Every ``batch_*.parquet`` file under BATCH_OUTPUT_DIR is read in sorted
    order. For each row the flattened backbone features are reduced with
    ``apply_mean_pooling_and_last_vector`` and only the reduced vectors plus a
    few metadata columns are kept. The result is written to
    ``OUTPUT_DIR/output_filename`` as a snappy-compressed parquet file.

    Rows that lack features or shape information, or that fail to process, are
    skipped and counted individually; a bad row does not abort the remaining
    rows of its batch file. A batch file that cannot be read is reported and
    skipped as a whole.

    Args:
        output_filename: File name (inside OUTPUT_DIR) for the merged output.

    Returns:
        tuple: ``(final_output_file, total_samples)`` on success. When there
        are no batch files or no valid samples, ``(None, 0)`` is returned so
        that callers can always unpack two values.
    """

    print(f"getting files from {BATCH_OUTPUT_DIR}/batch_*.parquet")

    # Find all batch files, sorted so they are processed in a stable order
    batch_files = sorted(glob.glob(os.path.join(BATCH_OUTPUT_DIR, "batch_*.parquet")))

    if not batch_files:
        print("❌ No batch files found to merge!")
        return None, 0

    print(f"🔗 Found {len(batch_files)} parquet batch files for selective merge")
    print(f"🎯 Creating compact file with mean-pooled + last vector features")

    processed_rows = []
    skipped_samples = 0

    for batch_file in batch_files:
        batch_name = os.path.basename(batch_file)
        try:
            df = pd.read_parquet(batch_file)
            print(f"📁 Processing {batch_name}: {len(df)} rows")

            for idx, row in df.iterrows():
                try:
                    record, skip_message = _extract_pooled_row(row, idx, batch_name)
                except Exception as e:
                    # Contain unexpected failures to this row so the rest of the batch is still merged.
                    record, skip_message = None, f"⚠️  Skipping row {idx} in {batch_name}: {e}"

                if record is None:
                    print(skip_message)
                    skipped_samples += 1
                    continue

                processed_rows.append(record)

        except Exception as e:
            # Best effort: an unreadable batch file is reported and the merge moves on.
            print(f"⚠️  Error processing {batch_file}: {e}")
            continue

    total_samples = len(processed_rows)

    if total_samples == 0:
        print("❌ No valid samples found in batch files!")
        return None, 0

    # Create final DataFrame with processed features
    print(f"🔄 Creating final DataFrame with {total_samples:,} processed samples...")
    print(f"⚠️  Skipped {skipped_samples:,} samples due to missing/invalid data")

    final_df = pd.DataFrame(processed_rows)

    # Save processed file
    final_output_file = os.path.join(OUTPUT_DIR, output_filename)

    print(f"💾 Saving processed parquet file to: {final_output_file}")
    final_df.to_parquet(final_output_file, compression='snappy', index=False)

    # Always bound, because the summary below prints it unconditionally.
    feature_size = len(processed_rows[0]['backbone_features_mean_pooled'])
    print(f"🎯 Feature dimension: {feature_size} (both mean-pooled and last vector)")

    print(f"✅ Selective merge completed!")
    print(f"   • Total samples: {total_samples:,}")
    print(f"   • Skipped samples: {skipped_samples:,}")
    print(f"   • Success rate: {total_samples/(total_samples+skipped_samples)*100:.1f}%")
    print(f"   • Features: Mean-pooled + Last vector (both {feature_size}D)")
    print(f"   • Final file: {final_output_file}")
    print(f"   • File size: {os.path.getsize(final_output_file) / (1024*1024):.1f} MB")

    return final_output_file, total_samples

def check_batch_status():
    """
    Print a summary of batch-extraction progress and return the progress record.

    Combines the record returned by ``load_progress()`` (its
    ``total_extracted``, ``completed_batches`` and ``last_batch_id`` entries
    are read) with the ``batch_*.parquet`` files currently present in
    BATCH_OUTPUT_DIR.

    Returns:
        The object returned by ``load_progress()``.
    """
    progress = load_progress()
    # Sort up front: glob order is arbitrary, and the preview below should show
    # the first files in name order rather than a random subset.
    batch_files = sorted(glob.glob(os.path.join(BATCH_OUTPUT_DIR, "batch_*.parquet")))

    print(f"📊 Batch Extraction Status (Parquet):")
    print(f"   • Progress file: {PROGRESS_FILE}")
    print(f"   • Total extracted: {progress['total_extracted']:,} samples")
    print(f"   • Completed batches: {len(progress['completed_batches'])}")
    print(f"   • Last batch ID: {progress['last_batch_id']}")
    print(f"   • Parquet files on disk: {len(batch_files)}")
    print(f"   • Target progress: {progress['total_extracted']/TARGET_TOTAL_SAMPLES*100:.1f}%")

    if batch_files:
        # Calculate total size of batch files
        total_size_mb = sum(os.path.getsize(f) for f in batch_files) / (1024*1024)
        print(f"   • Total batch size: {total_size_mb:.1f} MB")
        preview_names = [os.path.basename(f) for f in batch_files[:10]]  # Show first 10
        print(f"   • Batch files: {preview_names}")
        if len(batch_files) > 10:
            print(f"     ... and {len(batch_files)-10} more")

    return progress

print("✅ Selective merge functions ready!")

In [None]:
# 🚀 RUN / RESUME BATCH EXTRACTION
# Executing this cell starts extraction, or continues it from the saved progress.

# Intro banner followed by a separator rule.
print(
    "🚀 Starting batch extraction for 150K samples...",
    "⚠️  This will run until interrupted or completed",
    "📝 Progress is automatically saved - you can resume after Colab disconnects",
    "\n" + "=" * 60,
    sep="\n",
)

# Report where extraction currently stands before doing any work.
check_batch_status()

print("\n" + "=" * 60, "▶️  Starting extraction...", sep="\n")

# Kick off extraction; it picks up from the last saved progress.
final_progress = extract_batches()

print("\n🎉 Batch extraction completed!", "🔗 You can now merge the batches or continue later", sep="\n")


In [None]:
# 🔗 CREATE SELECTIVE MERGE WITH MEAN POOLING + LAST VECTOR
# This creates a compact processed file WITHOUT merging all original data (saves space!)

print("🔗 Creating selective merge with mean-pooled and last vector features...")
# The merge helper may report failure with a bare None instead of a
# (path, count) pair; normalize it so the unpacking below cannot raise.
merge_result = merge_batches_with_pooled_features()
processed_file, total_samples = merge_result if merge_result is not None else (None, 0)

if processed_file:
    print(f"\n🎉 Processed parquet file ready!")
    print(f"📁 Location: {processed_file}")
    print(f"📊 Total samples: {total_samples:,}")

    # Display file size and savings
    file_size_mb = os.path.getsize(processed_file) / (1024*1024)
    print(f"💾 Processed file size: {file_size_mb:.1f} MB")

    # Compare with original batch files size
    batch_files = glob.glob(os.path.join(BATCH_OUTPUT_DIR, "batch_*.parquet"))
    if batch_files:
        original_size_mb = sum(os.path.getsize(f) for f in batch_files) / (1024*1024)
        print(f"📦 Original batches size: {original_size_mb:.1f} MB")
        # Skip the ratio when the batch files are empty to avoid dividing by zero.
        if original_size_mb > 0:
            savings = (1 - file_size_mb/original_size_mb) * 100
            print(f"🗜️  Space savings: {savings:.1f}% smaller")

    print(f"\n📋 Final output summary:")
    print(f"   • Original batch files: Kept for detailed analysis")
    print(f"   • Processed file: Mean-pooled + last vector features")
    # NOTE(review): 2048 is hard-coded here; the merge function prints the dimension it actually measured.
    print(f"   • Feature dimensions: [2048] for both mean and last vector")
    print(f"   • Ready for probe training!")
else:
    print("❌ Selective merge failed!")

In [None]:
# 🧪 TEST PROCESSED DATA LOADING AND VALIDATION
# Run this cell to verify the processed file works correctly

def test_processed_data():
    """
    Load the newest ``*_processed.parquet`` file in OUTPUT_DIR and print a sanity report.

    The report covers row/column counts, file size, the length of the pooled,
    last-vector and action cells of the first row, and the share of null
    values in those three columns. A column that is absent from the file is
    counted as entirely null.

    Returns:
        pandas.DataFrame | None: The loaded frame, or None when no processed
        file exists or loading/inspection raises.
    """

    processed_files = glob.glob(os.path.join(OUTPUT_DIR, "*_processed.parquet"))

    if not processed_files:
        print("❌ No processed files found. Run the merge cell first.")
        return None

    # Choose by modification time; sorting the names alphabetically does not
    # select the most recently written file.
    processed_file = max(processed_files, key=os.path.getmtime)
    print(f"🔍 Testing processed file: {os.path.basename(processed_file)}")

    def _vector_dim(value):
        """Return the length of a vector-like cell, or "Unknown" for scalars and strings."""
        # List columns read back from parquet may arrive as arrays rather than
        # Python lists, so accept any sized, non-string container.
        if isinstance(value, (str, bytes)) or not hasattr(value, "__len__"):
            return "Unknown"
        return len(value)

    try:
        # Load processed data
        df = pd.read_parquet(processed_file)
        n_rows = len(df)

        print(f"✅ Successfully loaded processed data!")
        print(f"   • Total samples: {n_rows:,}")
        print(f"   • Columns: {len(df.columns)}")
        print(f"   • File size: {os.path.getsize(processed_file) / (1024*1024):.1f} MB")

        if n_rows == 0:
            # Nothing to sample and the percentages below would divide by zero.
            print("⚠️  Processed file contains no rows")
            return df

        # Check feature dimensions
        sample_row = df.iloc[0]

        mean_pooled = sample_row.get('backbone_features_mean_pooled')
        if mean_pooled is not None:
            print(f"   • Mean-pooled features: {_vector_dim(mean_pooled)}D")
        else:
            print("   ⚠️  No mean-pooled features found")

        last_vector = sample_row.get('backbone_features_last_vector')
        if last_vector is not None:
            print(f"   • Last vector features: {_vector_dim(last_vector)}D")
        else:
            print("   ⚠️  No last vector features found")

        # Check action data
        action_data = sample_row.get('action_right_arm_first')
        if action_data is not None:
            print(f"   • Action dimensions: {_vector_dim(action_data)}D")

        def _null_count(column):
            """Count nulls in ``column``; a missing column counts as all-null."""
            if column not in df.columns:
                return n_rows
            return int(df[column].isnull().sum())

        # Validate data integrity
        null_mean = _null_count('backbone_features_mean_pooled')
        null_last = _null_count('backbone_features_last_vector')
        null_action = _null_count('action_right_arm_first')

        print(f"\n🔍 Data integrity check:")
        print(f"   • Null mean-pooled features: {null_mean}/{n_rows} ({null_mean/n_rows*100:.1f}%)")
        print(f"   • Null last vector features: {null_last}/{n_rows} ({null_last/n_rows*100:.1f}%)")
        print(f"   • Null action data: {null_action}/{n_rows} ({null_action/n_rows*100:.1f}%)")

        # Tolerate up to 10% nulls in each feature column before warning.
        if null_mean < n_rows * 0.1 and null_last < n_rows * 0.1:
            print("✅ Data integrity looks good!")
        else:
            print("⚠️  High number of null values detected")

        return df

    except Exception as e:
        print(f"❌ Error testing processed data: {e}")
        return None

# Run the test
print("🧪 Testing processed data file...")
test_df = test_processed_data()