In [50]:
from pathlib import Path
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np
import torch

from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.configs.types import FeatureType
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.datasets.lerobot_dataset import LeRobotDataset

from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors

device = "cuda" if torch.cuda.is_available() else "cpu"

device

'cuda'

In [49]:
output_dir = Path("outputs/smolvla_so101_finetune_pickplace_hugginface")
output_dir.mkdir(parents=True, exist_ok=True)

In [51]:
# dataset_id = "lerobot/svla_so101_pickplace"

dataset_id ="eternalmay33/pick_place_test"

dataset_meta = LeRobotDatasetMetadata(repo_id=dataset_id)

dataset_meta

LeRobotDatasetMetadata({
    Repository ID: 'eternalmay33/pick_place_test',
    Total episodes: '39',
    Total frames: '8159',
    Features: '['action', 'observation.state', 'observation.images.front', 'observation.images.third_person', 'observation.images.gripper', 'timestamp', 'frame_index', 'episode_index', 'index', 'task_index']',
})',

In [52]:
features = dataset_to_policy_features(dataset_meta.features)

features

{'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(6,)),
 'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(6,)),
 'observation.images.front': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)),
 'observation.images.third_person': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)),
 'observation.images.gripper': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640))}

In [53]:
output_features = {key : feature for key, feature in features.items() if feature.type is FeatureType.ACTION}
input_features = {key : feature for key, feature in features.items() if feature.type is not FeatureType.ACTION}

input_features, output_features

({'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(6,)),
  'observation.images.front': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)),
  'observation.images.third_person': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)),
  'observation.images.gripper': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640))},
 {'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(6,))})

In [54]:
cfg = SmolVLAConfig(
    input_features=input_features,
    output_features=output_features,

    n_obs_steps=1,
    chunk_size=50,

    freeze_vision_encoder=True,
    train_expert_only=True,
    train_state_proj=True,

    optimizer_lr=1e-4,
    optimizer_weight_decay=1e-10,
    optimizer_grad_clip_norm=10,

    scheduler_warmup_steps=1000,
    scheduler_decay_steps=30000,

    device=device,
)

In [55]:
model_id = "lerobot/smolvla_base"

policy = SmolVLAPolicy.from_pretrained(
    model_id,
    config=cfg,
)

preprocessor, postprocessor = make_smolvla_pre_post_processors(cfg, dataset_stats=dataset_meta.stats)

policy.train()
policy.to(device)

Reducing the number of VLM layers to 16 ...


SmolVLAPolicy(
  (model): VLAFlowMatching(
    (vlm_with_expert): SmolVLMWithExpertModel(
      (vlm): SmolVLMForConditionalGeneration(
        (model): SmolVLMModel(
          (vision_model): SmolVLMVisionTransformer(
            (embeddings): SmolVLMVisionEmbeddings(
              (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), padding=valid)
              (position_embedding): Embedding(1024, 768)
            )
            (encoder): SmolVLMEncoder(
              (layers): ModuleList(
                (0-11): 12 x SmolVLMEncoderLayer(
                  (self_attn): SmolVLMVisionAttention(
                    (k_proj): Linear(in_features=768, out_features=768, bias=True)
                    (v_proj): Linear(in_features=768, out_features=768, bias=True)
                    (q_proj): Linear(in_features=768, out_features=768, bias=True)
                    (out_proj): Linear(in_features=768, out_features=768, bias=True)
                  )
                  (laye

In [57]:
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
    """Конвертирует индексы фреймов в временные метки"""
    if delta_indices is None:
        return [0]
    return [i / fps for i in delta_indices]

In [58]:
delta_timestamps = {
    "action" : make_delta_timestamps(
        list(range(cfg.chunk_size)),
        dataset_meta.fps
    )
}


delta_timestamps |= {
    k: make_delta_timestamps([-2, -1, 0], dataset_meta.fps)
    for k in cfg.image_features
}

In [59]:
total_episodes = dataset_meta.total_episodes

episode_idx = np.arange(total_episodes)

np.random.shuffle(episode_idx)
split_idx = int(0.8 * total_episodes)

In [60]:
train_episodes = episode_idx[:split_idx].tolist()
val_episodes = episode_idx[split_idx:].tolist()

train_dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps, episodes=train_episodes)

val_dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps, episodes=val_episodes)

In [61]:
batch_size = 8
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory_device=device,
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory_device=device,
)

In [62]:
optimizer = cfg.get_optimizer_preset().build(policy.parameters())

In [63]:
num_epochs = 100
log_freq = 10
val_freq = 20
n_val_batches = 10
save_freq = 100

In [64]:
import wandb

wandb.init(
    project="smolvla-so101-finetune",
    name=f"pickplace_chunk{cfg.chunk_size}_lr{cfg.optimizer_lr}",
    config={
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "learning_rate": cfg.optimizer_lr,
        "chunk_size": cfg.chunk_size,
        "n_obs_steps": cfg.n_obs_steps,
        "train_episodes": len(train_episodes),
        "val_episodes": len(val_episodes),
    }
)

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇█
train_loss,▃▅▂▂▂▃▂▂▃▂▂▂▂▁▂▁▁▂▁▂▁▄▂▂▂█▁▁▁▂▂▁▃▂▅▆▂▂▃▄
train_losses_after_forward,▃▂▂▂▅▂▃▂▂▂▂▃▂▂▁█▁▂▁▃▁▂▂▁▂▁▁▆▁▁▁▁▁▁▁▂▁▂▂▃
train_losses_after_rm_padding,█▆▃▅▂▇▇▄▂▂▃▃▃▂▂▂▂▂▁▁▁▂▄▂▁▂▃▂▂▁▃▃▁▁▁▂▂▁▆▄
val_loss,▄█▆▁▄▃▄▃▄
val_losses_after_forward,▄█▆▁▄▃▄▃▄
val_losses_after_rm_padding,▄█▆▁▄▃▄▃▄

0,1
epoch,0.0
step,198.0
train_loss,0.01777
train_losses_after_forward,0.01777
train_losses_after_rm_padding,0.01777
val_loss,0.02712
val_losses_after_forward,0.02712
val_losses_after_rm_padding,0.02712


In [65]:
def validate(policy, val_loader, preprocessor, n_batches, device):
    """Валидация модели на n батчах"""
    policy.eval()
    val_losses = []
    val_loss_dicts = []
    
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            if i >= n_batches:
                break
                
            batch = preprocessor(batch)
            loss, loss_dict = policy.forward(batch)
            val_losses.append(loss.item())
            val_loss_dicts.append(loss_dict)
    
    policy.train()
    
    # Усредняем метрики
    avg_val_loss = np.mean(val_losses)
    avg_loss_dict = {}
    if val_loss_dicts and val_loss_dicts[0]:
        for key in val_loss_dicts[0].keys():
            values = []
            for d in val_loss_dicts:
                if key in d:
                    val = d[key]
                    if isinstance(val, torch.Tensor):
                        # Для многомерных тензоров берем mean
                        values.append(val.mean().item() if val.numel() > 1 else val.item())
                    else:
                        values.append(val)
            avg_loss_dict[f"val_{key}"] = np.mean(values)
    
    return avg_val_loss, avg_loss_dict

In [66]:
global_step = 0
best_val_loss = float('inf')

for epoch in range(num_epochs):
    epoch_losses = []
    
    for batch in train_loader:
        batch = preprocessor(batch)
        
        # Forward pass
        loss, loss_dict = policy.forward(batch)
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            policy.parameters(),
            cfg.optimizer_grad_clip_norm
        )
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_losses.append(loss.item())
        
        log_dict = {"train_loss": loss.item(), "epoch": epoch, "step": global_step}
        if loss_dict:
            for k, v in loss_dict.items():
                if isinstance(v, torch.Tensor):
                    log_dict[f"train_{k}"] = v.item() if v.numel() == 1 else v.mean().item()
                else:
                    log_dict[f"train_{k}"] = v
        wandb.log(log_dict, step=global_step)
        
        if global_step % log_freq == 0:
            print(f"Epoch {epoch}/{num_epochs} | Step {global_step} | Loss: {loss.item():.4f}")
        
        if global_step % val_freq == 0 and global_step > 0:
            print(f"\n--- Running validation at step {global_step} ---")
            val_loss, val_loss_dict = validate(policy, val_loader, preprocessor, n_val_batches, device)
            print(f"Validation Loss: {val_loss:.4f}")
            
            wandb.log({"val_loss": val_loss, **val_loss_dict}, step=global_step)
            
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_path = output_dir / "best_model"
                policy.save_pretrained(best_model_path)
                print(f"Saved best model (val_loss: {val_loss:.4f}) to {best_model_path}")
                wandb.save(str(best_model_path / "*"))
        
        if global_step % save_freq == 0 and global_step > 0:
            checkpoint_path = output_dir / f"checkpoint_step_{global_step}"
            policy.save_pretrained(checkpoint_path)
            print(f"Saved checkpoint to {checkpoint_path}")
        
        global_step += 1
    
    avg_epoch_loss = np.mean(epoch_losses)
    print(f"\n=== Epoch {epoch}/{num_epochs} completed | Avg Loss: {avg_epoch_loss:.4f} ===\n")
    wandb.log({"epoch_avg_loss": avg_epoch_loss}, step=global_step)

final_path = output_dir / "final_model"
policy.save_pretrained(final_path)
print(f"\nTraining completed! Final model saved to {final_path}")

wandb.finish()



Epoch 0/100 | Step 0 | Loss: 0.3662
Epoch 0/100 | Step 10 | Loss: 0.1382
Epoch 0/100 | Step 20 | Loss: 0.3420

--- Running validation at step 20 ---
Validation Loss: 0.0899




Saved best model (val_loss: 0.0899) to outputs/smolvla_so101_finetune_pickplace_hugginface/best_model
Epoch 0/100 | Step 30 | Loss: 0.1143
Epoch 0/100 | Step 40 | Loss: 0.1220

--- Running validation at step 40 ---
Validation Loss: 0.0736
Saved best model (val_loss: 0.0736) to outputs/smolvla_so101_finetune_pickplace_hugginface/best_model
Epoch 0/100 | Step 50 | Loss: 0.0923
Epoch 0/100 | Step 60 | Loss: 0.0540

--- Running validation at step 60 ---
Validation Loss: 0.0695
Saved best model (val_loss: 0.0695) to outputs/smolvla_so101_finetune_pickplace_hugginface/best_model
Epoch 0/100 | Step 70 | Loss: 0.0980
Epoch 0/100 | Step 80 | Loss: 0.0412

--- Running validation at step 80 ---
Validation Loss: 0.0472
Saved best model (val_loss: 0.0472) to outputs/smolvla_so101_finetune_pickplace_hugginface/best_model
Epoch 0/100 | Step 90 | Loss: 0.0588
Epoch 0/100 | Step 100 | Loss: 0.0607

--- Running validation at step 100 ---
Validation Loss: 0.0382
Saved best model (val_loss: 0.0382) to out

KeyboardInterrupt: 