In [None]:
import sys
import logging
import joblib
import torch
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
sys.path.append('vjepa2')
# Import the dataloader function from your dataset file
from vjepa2.app.vjepa_minecraft.vpt_dataset import init_vpt_dataloader


In [None]:
# --- Configuration ---
# !! IMPORTANT: update DATA_PATH so that it covers ALL of your shards !!
DATA_PATH = "shard-000024.tar"
# Destination file for the fitted scaler (written with joblib).
SCALER_PATH = "vpt_action_scaler.pkl"

# Dataloader settings for the fitting pass; a large batch keeps iteration fast.
FIT_BATCH_SIZE = 128
NUM_WORKERS = 0

# Emit INFO-level messages through the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

def fit_action_scaler(data_path, scaler_save_path):
    """
    Fit a StandardScaler on the continuous action columns and save it.

    Builds a dataloader with ``action_scaler=None`` so that the 'actions'
    tensor is not normalized, then streams every batch through
    ``StandardScaler.partial_fit`` using only the first four action columns.
    The fitted scaler is written to disk with ``joblib.dump``.

    Args:
        data_path: Dataset location, passed through to ``init_vpt_dataloader``.
        scaler_save_path: File path the fitted scaler is dumped to.

    Returns:
        The fitted ``StandardScaler`` instance.

    Raises:
        ValueError: If a batch has fewer than four action columns, or if the
            dataloader produced no action rows at all. In the latter case
            nothing is written to ``scaler_save_path``.
    """
    # Number of leading columns treated as continuous. Per the original
    # layout note these are [mouse_dx, mouse_dy, yaw_diff, pitch_diff].
    num_continuous = 4

    logger.info(f"Initializing dataloader to fit scaler from: {data_path}")

    # 1. Initialize the dataloader WITHOUT a scaler so actions stay raw.
    scaler_fit_loader, _ = init_vpt_dataloader(
        data_path=data_path,
        batch_size=FIT_BATCH_SIZE,
        action_scaler=None,  # <-- Explicitly pass None
        # Use simple settings for fitting
        frames_per_clip=200,
        fps=20,
        frameskip=1,
        crop_size=256,
        rank=0,
        world_size=1,
        num_workers=NUM_WORKERS,
        drop_last=False,  # We want to see all data
    )

    # 2. Initialize the scaler and a counter of rows actually fitted.
    scaler = StandardScaler()
    rows_seen = 0

    # 3. Iterate through the dataset and fit the scaler incrementally.
    logger.info("Starting scaler fitting loop...")

    # No gradients are needed; we only read values out of the tensors.
    with torch.no_grad():
        for batch in tqdm(scaler_fit_loader, desc="Fitting Scaler"):
            # Indexed below as a 3-D tensor: (B, T_actions, D_actions).
            actions = batch['actions']

            # Guard against a silent mis-reshape: with fewer than four
            # columns, reshape(-1, 4) would regroup unrelated values.
            if actions.shape[-1] < num_continuous:
                raise ValueError(
                    f"Expected at least {num_continuous} action columns, "
                    f"got actions of shape {tuple(actions.shape)}"
                )

            # Keep the continuous columns and flatten to (B * T_actions, 4).
            continuous = actions[:, :, :num_continuous].reshape(-1, num_continuous)
            if continuous.shape[0] == 0:
                continue

            scaler.partial_fit(continuous.cpu().numpy())
            rows_seen += continuous.shape[0]

    # An unfitted scaler has no mean_/scale_; fail before overwriting any
    # previously saved scaler with a useless one.
    if rows_seen == 0:
        raise ValueError(
            f"No action rows were read from {data_path!r}; "
            "the scaler was not fitted and nothing was saved."
        )

    # 4. Save the fitted scaler to disk
    joblib.dump(scaler, scaler_save_path)

    logger.info("=" * 30)
    logger.info(f"Scaler fitting complete and saved to: {scaler_save_path}")
    logger.info(f"  Mean: {scaler.mean_}")
    logger.info(f"  Scale (StdDev): {scaler.scale_}")
    logger.info("=" * 30)

    return scaler

# Script entry point: fit and save the scaler, then reload it as a sanity check.
if __name__ == "__main__":
    try:
        fit_action_scaler(DATA_PATH, SCALER_PATH)

        # --- Example of loading it back ---
        logger.info(f"Loading scaler from {SCALER_PATH} for verification...")
        loaded_scaler = joblib.load(SCALER_PATH)
        logger.info(f"Loaded Mean: {loaded_scaler.mean_}")

    except FileNotFoundError as e:
        logger.error(f"Error: {e}")
        logger.error("Please update the DATA_PATH variable in fit_scaler.py")
    except Exception as e:
        # logger.exception logs at ERROR level and appends the active
        # traceback, so it reaches the same handlers as the other messages.
        logger.exception(f"An unexpected error occurred: {e}")