<a href="https://colab.research.google.com/github/Sa74ll/ELM_challenge/blob/main/01_train_smolvla.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SmolVLA Training with Color Augmentation
Goal: train on episodes 0–39 and validate on episodes 40–49, applying a different set of color augmentations to each split.

**Install LeRobot**

The next cell clones the `lerobot` repository from Hugging Face's GitHub organization and installs it in editable mode.

In [None]:
# Clone the lerobot sources under /content, then switch into the checkout so the
# install below (and anything run afterwards) uses the repository root as its
# working directory.
%cd /content
!git clone https://github.com/huggingface/lerobot.git
%cd /content/lerobot
# Editable install of the package found in the current directory (the clone).
!pip install -e .


In [None]:
# Log this runtime in to the Hugging Face Hub through the CLI login command.
!huggingface-cli login

In [None]:
# Repeat the editable install from the cloned repo, this time requesting the
# `smolvla` extra (presumably the additional dependencies the SmolVLA policy
# imports need — confirm against the lerobot packaging metadata).
!cd /content/lerobot && pip install -e ".[smolvla]"

In [1]:
"""
SmolVLA Training with Color Augmentation Challenge
Goal: Train on episodes 0-39, validate on 40-49 with different color augmentations
"""

from pathlib import Path
import torch
from torch.utils.data import DataLoader, Subset
import numpy as np
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.transforms import (
    ImageTransforms,
    ImageTransformsConfig,
    ImageTransformConfig,
)
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.policies.factory import make_pre_post_processors
import wandb

# --- Configuration -----------------------------------------------------------
# Dataset to train on (Hugging Face Hub repo id).
DATASET_REPO = "lerobot/svla_so101_pickplace"

# Training schedule, counted in optimizer steps.
BATCH_SIZE = 24
MAX_STEPS = 15000
VAL_EVERY = 1000

# Everything this notebook saves (checkpoints, best model, final model) goes here.
OUTPUT_DIR = (
    Path("/content/Final_challenge")
    / "lerobot_output"
    / "smolvla_proper_split_final"
)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [None]:
# Weights & Biases setup
# Collect the settings worth recording with the run, then open the run.
run_config = {
    "dataset": DATASET_REPO,
    "batch_size": BATCH_SIZE,
    "max_steps": MAX_STEPS,
    "train_episodes": "0-39",
    "val_episodes": "40-49",
    "video_backend": "pyav",
}
wandb.init(project="Final2", name="smolvla_split", config=run_config)

In [None]:
"""
SmolVLA uses chunk_size=50 by default for action sequences
so we need to match this in our delta_timestamps configuration
"""
# Load the pretrained base policy, place it on the selected device and switch
# it to training mode.
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
policy.to(device)
policy.train()

# The number of actions per chunk is read from the policy config so the dataset
# can be set up to return action sequences of exactly that length.
action_horizon = policy.config.chunk_size
print("Policy action horizon :", action_horizon)

In [4]:
"""
Configure temporal structure to match SmolVLA expectations
Actions need timestamps for the full action chunk
"""
meta = LeRobotDatasetMetadata(DATASET_REPO)
fps = meta.fps
print("Dataset FPS:", fps)

# One offset per action in the chunk: 0, 1/fps, 2/fps, ..., (chunk_size-1)/fps
action_dts = [frame_offset / fps for frame_offset in range(action_horizon)]

# State and both camera streams are requested at the current timestamp only;
# actions are requested for the whole chunk.
delta_timestamps = {
    obs_key: [0.0]
    for obs_key in (
        "observation.state",
        "observation.images.up",
        "observation.images.side",
    )
}
delta_timestamps["action"] = action_dts
print("Using delta_timestamps:", delta_timestamps)


Dataset FPS: 30
Using delta_timestamps: {'observation.state': [0.0], 'observation.images.up': [0.0], 'observation.images.side': [0.0], 'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}


In [None]:
"""
Load base dataset to determine train/val split
Using pyav backend due to torchcodec issues in Colab (common issue in LeRobot GitHub)
"""
print("Loading base dataset to read episode_index")
base_ds = LeRobotDataset(
    DATASET_REPO,
    video_backend="pyav",  # to avoid torchcodec issues
)
sample = base_ds[0]
print("Available keys:", list(sample.keys()))

# Build the 40/10 split from each frame's episode_index:
# frames of episodes below 40 go to train, all remaining frames go to val.
episode_idx = np.array(base_ds.hf_dataset["episode_index"])
train_indices = []
val_indices = []
for frame_idx, ep in enumerate(episode_idx):
    if ep < 40:
        train_indices.append(frame_idx)
    else:
        val_indices.append(frame_idx)
print(f"Found {len(train_indices)} train samples and {len(val_indices)} val samples")


In [6]:
#IMAGE AUGMENTATIONS

def _color_jitter_config(brightness, contrast, saturation, hue):
    """Build an ImageTransformsConfig made of four single-property ColorJitter entries.

    Each argument is the (low, high) range handed to ColorJitter for that
    property. Every entry gets weight 1.0; `enable=True`, `max_num_transforms=2`
    and `random_order=True` are the same for every config built here.
    """
    jitter_ranges = {
        "brightness": brightness,
        "contrast": contrast,
        "saturation": saturation,
        "hue": hue,
    }
    return ImageTransformsConfig(
        enable=True,
        max_num_transforms=2,
        random_order=True,
        tfs={
            prop: ImageTransformConfig(
                weight=1.0, type="ColorJitter", kwargs={prop: bounds}
            )
            for prop, bounds in jitter_ranges.items()
        },
    )


# Augmentation ranges used for the training split.
train_tf_cfg = _color_jitter_config(
    brightness=(0.8, 1.2),
    contrast=(0.8, 1.2),
    saturation=(0.5, 1.5),
    hue=(-0.05, 0.05),
)

# Validation uses different ranges; its brightness range sits lower than train's.
val_tf_cfg = _color_jitter_config(
    brightness=(0.7, 1.0),
    contrast=(1.0, 1.3),
    saturation=(0.5, 1.2),
    hue=(-0.08, 0.06),
)

train_tf = ImageTransforms(train_tf_cfg)
val_tf = ImageTransforms(val_tf_cfg)


In [None]:
"""
Create separate datasets train/val with different augmentations
Then apply episode-based splits using Subset
"""

# Both datasets read the same repo with the same temporal layout and video
# backend; they differ only in the image transform applied to the frames.
shared_dataset_kwargs = {
    "delta_timestamps": delta_timestamps,
    "video_backend": "pyav",
}

train_full = LeRobotDataset(
    DATASET_REPO, image_transforms=train_tf, **shared_dataset_kwargs
)
val_full = LeRobotDataset(
    DATASET_REPO, image_transforms=val_tf, **shared_dataset_kwargs
)

# Restrict each dataset to the frame indices of its own episodes.
train_ds = Subset(train_full, train_indices)
val_ds = Subset(val_full, val_indices)

print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")



In [8]:
"""Create efficient data loading with proper settings"""

# Training loader: reshuffled on every pass; the final incomplete batch is dropped.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=2,  # Reduced for Colab stability
)

# Validation loader: visits the validation samples in one fixed, seeded random
# order. The training loop stops each validation pass after a fixed number of
# batches; in sequential order such a truncated pass only ever sees the earliest
# frames of the validation split. A fixed permutation makes the truncated pass a
# mix drawn from the whole split, and the same mix at every evaluation, so
# validation losses stay comparable across steps.
val_order = torch.randperm(
    len(val_ds), generator=torch.Generator().manual_seed(0)
).tolist()

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    sampler=val_order,  # fixed index order; replaces shuffle=False sequential order
    pin_memory=True,
    num_workers=2,
)

In [9]:
"""Setup data preprocessing and optimisation"""

# Build the pre-/post-processing pipelines from the policy config, passing the
# dataset statistics read from the metadata object.
preprocessor, postprocessor = make_pre_post_processors(
    policy.config, dataset_stats=meta.stats
)

# Build the optimizer from the preset exposed by the policy config.
optimizer_preset = policy.config.get_optimizer_preset()
optimizer = optimizer_preset.build(policy.parameters())

In [10]:
#Helper Functions

def fix_keys(batch: dict, rename_map: "dict | None" = None) -> dict:
    """
    Rename camera keys in a batch to the names the policy expects.

    SmolVLA expects camera1/camera2, dataset provides up/side, so by default
    "observation.images.up" becomes "observation.images.camera1" and
    "observation.images.side" becomes "observation.images.camera2".

    Args:
        batch: mapping of feature name to value; it is modified in place.
        rename_map: optional {old_key: new_key} mapping to use instead of the
            default up/side -> camera1/camera2 mapping.

    Returns:
        The same `batch` object, with every old key that was present moved to
        its new name. Keys that are absent are skipped silently.
    """
    if rename_map is None:
        rename_map = {
            "observation.images.up": "observation.images.camera1",
            "observation.images.side": "observation.images.camera2",
        }

    for old_key, new_key in rename_map.items():
        if old_key in batch:
            batch[new_key] = batch.pop(old_key)

    return batch


In [None]:
# Main training loop.
# Progress is counted in optimizer steps rather than epochs, and the policy plus
# its processors are written to disk every CKPT_EVERY steps, so a crashed Colab
# runtime loses at most CKPT_EVERY steps of weights.
# NOTE(review): nothing in this cell reloads a checkpoint, and the optimizer state
# is not saved, so resuming after a crash still has to be done by hand.
# NOTE(review): no LR scheduler and no gradient clipping are applied here —
# confirm that this is intended.

best_val = float("inf")
step = 0  # number of completed optimizer updates
CKPT_EVERY = 1000
LOG_EVERY = 100
VAL_MAX_BATCHES = 50  # cap on batches per validation pass, keeps validation short


def _compute_loss(raw_batch):
    """Rename keys, preprocess, move tensors to `device`, and return the loss tensor."""
    batch = preprocessor(fix_keys(raw_batch))
    for key, value in list(batch.items()):
        if torch.is_tensor(value):
            batch[key] = value.to(device, non_blocking=True)
    out = policy.forward(batch)
    # The forward output is handled both as a tuple (loss first) and as a dict.
    return out[0] if isinstance(out, tuple) else out["loss"]


def _validate():
    """Mean loss over at most VAL_MAX_BATCHES validation batches, or None if there were none.

    If val_loader iterates in a fixed sequential order, this truncated pass only
    covers the first VAL_MAX_BATCHES batches of that order.
    """
    policy.eval()
    losses = []
    try:
        with torch.no_grad():
            for batch_idx, val_raw in enumerate(val_loader):
                if batch_idx >= VAL_MAX_BATCHES:
                    break
                losses.append(_compute_loss(val_raw).item())
    finally:
        # Always return to training mode, even if validation raised.
        policy.train()
    return sum(losses) / len(losses) if losses else None


def _save_all(target_dir):
    """Save the policy and both processors into `target_dir`, creating it if needed."""
    target_dir.mkdir(parents=True, exist_ok=True)
    policy.save_pretrained(target_dir)
    preprocessor.save_pretrained(target_dir)
    postprocessor.save_pretrained(target_dir)


# Without this guard an empty loader would make the `while` below spin forever.
if len(train_loader) == 0:
    raise RuntimeError(
        "train_loader yields no batches; check the train split and BATCH_SIZE"
    )

print(f"Starting training for {MAX_STEPS} steps...")

while step < MAX_STEPS:
    for raw_batch in train_loader:
        # forward + backward + update
        loss = _compute_loss(raw_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Count the update immediately, so every "step N" below means "after N
        # updates" and the final step (MAX_STEPS) is validated and checkpointed too.
        step += 1

        # logging (the very first update is logged as well, as a starting point)
        if step == 1 or step % LOG_EVERY == 0:
            train_loss = loss.item()
            print(f"[step {step}] train loss = {train_loss:.4f}")
            wandb.log({"train/loss": train_loss, "step": step})

        # validation
        if step % VAL_EVERY == 0:
            val_loss = _validate()
            if val_loss is None:
                print(f"[step {step}] validation skipped: val_loader produced no batches")
            else:
                print(f"[step {step}] val loss = {val_loss:.4f}")
                wandb.log({"val/loss": val_loss, "step": step})

                # sometimes earlier checkpoints generalise better, so we keep the best val
                if val_loss < best_val:
                    best_val = val_loss
                    _save_all(OUTPUT_DIR / "best_model")
                    print("new best model saved")

        # periodic checkpoint
        if step % CKPT_EVERY == 0:
            ckpt_dir = OUTPUT_DIR / f"checkpoint-{step}"
            _save_all(ckpt_dir)
            print(f"Checkpoint saved at {ckpt_dir}")

        if step >= MAX_STEPS:
            break

# FINAL SAVE: the last weights go to the root of OUTPUT_DIR.
_save_all(OUTPUT_DIR)
print(f"Done. Best val loss = {best_val:.4f}")