# VLA Behavior Cloning Demo

This notebook demonstrates:
1. Collecting expert data using an expert policy
2. Training a VLA (Vision-Language-Action) model with Behavior Cloning
3. Running inference and generating a video

Key fix relative to the original demo: the model is now **actually trained**; the fusion layer and policy head are updated with backpropagation.

In [None]:

!pip install -q gymnasium imageio[ffmpeg] torch torchvision transformers accelerate sentencepiece tqdm



In [None]:
import os, random
import numpy as np

import gymnasium as gym
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from transformers import CLIPProcessor, CLIPModel, DistilBertModel, DistilBertTokenizer
from PIL import Image
import imageio.v2 as imageio
import matplotlib.pyplot as plt

# --- Reproducibility ---
# Seed the Python, NumPy and torch RNGs (torch on CPU and on all CUDA devices).
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Use the GPU whenever torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
device


## 1. Expert Policy

CartPole state: `[x, x_dot, theta, theta_dot]`

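The expert is a simple linear state-feedback rule: it pushes right when `theta + 0.5*theta_dot + 0.05*x + 0.1*x_dot > 0` and left otherwise. The pole angle and angular velocity dominate the decision, and smaller weights on cart position and velocity adjust it. This is simple but effective for generating training data.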

In [None]:
def expert_action(obs):
    """Return the expert's discrete CartPole action for the state ``obs``.

    ``obs`` is unpacked as ``[x, x_dot, theta, theta_dot]``. The expert is a fixed
    linear rule: it pushes right (1) when the weighted score is strictly positive
    and pushes left (0) otherwise, including when the score is exactly zero.
    """
    cart_pos, cart_vel, pole_angle, pole_ang_vel = obs
    # Same weights and term order as the reference controller; the pole terms carry the largest weights.
    score = pole_angle + 0.5 * pole_ang_vel + 0.05 * cart_pos + 0.1 * cart_vel
    if score > 0:
        return 1  # push right
    return 0  # push left


## 2. Data Collection

Collect rendered RGB frames together with the expert's actions. A single fixed instruction is used for every sample.

In [None]:
# --- Data collection params ---
N_EPISODES = 30
MAX_STEPS_PER_EP = 500
INSTRUCTION = "Keep the pole balanced"

env = gym.make("CartPole-v1", render_mode="rgb_array")

# Flat, episode-after-episode record of (frame, expert action) pairs.
# NOTE(review): frames are stored at full render resolution; with many long
# episodes this list can use several GB of RAM — consider downscaling if tight.
frames, actions = [], []
episode_lengths = []

for ep in tqdm(range(N_EPISODES), desc="Collect expert data"):
    obs, _ = env.reset(seed=SEED + ep)
    steps_taken = MAX_STEPS_PER_EP  # used when the episode never ends early
    for t in range(MAX_STEPS_PER_EP):
        # Render *before* stepping so each frame pairs with the action chosen from it.
        frame = env.render()
        a = expert_action(obs)
        obs, reward, done, trunc, info = env.step(a)

        frames.append(frame)
        actions.append(a)

        if done or trunc:
            steps_taken = t + 1
            break
    episode_lengths.append(steps_taken)

env.close()

actions_arr = np.array(actions)
print("Collected frames:", len(frames))
print("Avg episode length (expert):", np.mean(episode_lengths))
print("Action balance:", {0: int(np.sum(actions_arr == 0)), 1: int(np.sum(actions_arr == 1))})


In [None]:
# Quick peek at the first collected frame.
fig, ax = plt.subplots(figsize=(5, 3))
ax.imshow(frames[0])
ax.axis("off")
ax.set_title("Example frame")
plt.show()


## 3. Dataset & DataLoader

Wrap the frames, expert actions, and instruction in a `Dataset`, then split the data into 90% training and 10% validation.

In [None]:
class VLADataset(Dataset):
    """Map-style dataset yielding ``(frame, instruction, action)`` triples.

    ``frames`` and ``actions`` are parallel sequences indexed together; every
    sample shares the same ``instruction`` string. Actions are returned as
    Python ``int``.
    """

    def __init__(self, frames, actions, instruction):
        self.frames, self.actions, self.instruction = frames, actions, instruction

    def __len__(self):
        # The frame sequence defines the dataset length.
        return len(self.frames)

    def __getitem__(self, idx):
        frame = self.frames[idx]
        action = int(self.actions[idx])
        return frame, self.instruction, action

# Train/val split by *episode*, not by frame: consecutive CartPole frames are
# nearly identical, so a per-frame shuffle would put near-duplicates of training
# frames into the validation set and make val accuracy look better than it is.
# `frames` holds the episodes back to back, with lengths in `episode_lengths`.
ep_starts = np.cumsum([0] + episode_lengths[:-1])
ep_order = np.random.permutation(len(episode_lengths))
n_train_eps = int(0.9 * len(ep_order))
if len(ep_order) > 1:
    # Keep at least one episode on each side of the split.
    n_train_eps = min(max(n_train_eps, 1), len(ep_order) - 1)


def _episode_frame_indices(episode_ids):
    """Return flat indices into `frames`/`actions` for every step of the given episodes."""
    return [
        i
        for e in episode_ids
        for i in range(ep_starts[e], ep_starts[e] + episode_lengths[e])
    ]


train_idx = _episode_frame_indices(ep_order[:n_train_eps])
val_idx = _episode_frame_indices(ep_order[n_train_eps:])

train_ds = VLADataset([frames[i] for i in train_idx], [actions[i] for i in train_idx], INSTRUCTION)
val_ds   = VLADataset([frames[i] for i in val_idx],   [actions[i] for i in val_idx],   INSTRUCTION)

BATCH_SIZE = 32
# Default-collate loaders; the training section rebuilds them with a custom collate_fn.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(len(train_ds), len(val_ds))


## 4. MiniVLAAgent

The CLIP (vision) and DistilBERT (text) encoders are frozen; only the fusion layer and the policy head are trained.

In [None]:
class MiniVLAAgent(nn.Module):
    """Vision-language policy: frozen CLIP image + DistilBERT text features -> action logits.

    Only ``fuse`` and ``policy_head`` are trainable. Both pretrained encoders have
    ``requires_grad=False`` and are used as fixed feature extractors.

    Args:
        action_dim: number of discrete actions (size of the output logits).
    """

    def __init__(self, action_dim):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")

        # Freeze encoders
        for p in self.clip.parameters():
            p.requires_grad = False
        for p in self.text_encoder.parameters():
            p.requires_grad = False

        # Fusion input = CLIP projected image embedding concatenated with the DistilBERT hidden state.
        self.fuse = nn.Linear(self.clip.config.projection_dim + self.text_encoder.config.dim, 256)
        self.policy_head = nn.Linear(256, action_dim)

    def train(self, mode=True):
        """Set train/eval mode for the head, but always keep the frozen encoders in eval mode.

        ``nn.Module.train()`` recurses into every submodule, which would switch on
        any dropout inside the pretrained encoders (DistilBERT's config uses
        dropout by default — verify for the checkpoint). Their features would then
        be noisy during training yet deterministic at inference. Only the
        trainable head follows ``mode``.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.clip.eval()
        self.text_encoder.eval()
        return self

    def forward(self, image_inputs, text_inputs):
        """Return unnormalized action logits for a batch.

        Args:
            image_inputs: mapping of CLIP processor outputs, already on the model's device.
            text_inputs: mapping of tokenizer outputs, already on the model's device.

        Returns:
            Tensor of logits whose last dimension is ``action_dim``.
        """
        # Encoders are frozen -> compute embeddings without building an autograd graph.
        with torch.no_grad():
            vision_emb = self.clip.get_image_features(**image_inputs)
            # First-token hidden state (the [CLS] position for BERT-style tokenizers).
            text_emb = self.text_encoder(**text_inputs).last_hidden_state[:, 0, :]

        fused = torch.relu(self.fuse(torch.cat([vision_emb, text_emb], dim=-1)))
        logits = self.policy_head(fused)
        return logits

# Preprocessors for the same pretrained checkpoints the agent's encoders load.
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
text_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")


def _cartpole_action_count():
    """Instantiate a throwaway CartPole env just to read the size of its action space."""
    probe_env = gym.make("CartPole-v1")
    count = probe_env.action_space.n
    probe_env.close()
    return count


# Model
n_actions = _cartpole_action_count()
agent = MiniVLAAgent(n_actions).to(device)
agent


## 5. Behavior Cloning Training

Train with cross-entropy loss so that the model's predicted actions match the expert's, for 4 epochs.

In [None]:
def collate_fn(batch):
    """Turn a list of ``(frame, instruction, action)`` samples into model inputs.

    Returns:
        ``(image_inputs, text_inputs, actions_t)``:
        - the CLIP processor output for the frames (converted to PIL images);
        - the tokenizer output for the instructions (truncated, padded);
        - a ``torch.long`` tensor of the expert actions.
    """
    pil_images, instructions, expert_actions = [], [], []
    for frame, instruction, action in batch:
        pil_images.append(Image.fromarray(frame))
        instructions.append(instruction)
        expert_actions.append(action)

    image_inputs = clip_processor(images=pil_images, return_tensors="pt")
    text_inputs = text_tokenizer(instructions, return_tensors="pt", truncation=True, padding=True)
    actions_t = torch.tensor(expert_actions, dtype=torch.long)
    return image_inputs, text_inputs, actions_t

# Rebuild the loaders with collate_fn so each batch arrives as model-ready tensors.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

criterion = nn.CrossEntropyLoss()
# Only the fusion layer and the policy head are optimized; encoder weights stay frozen.
trainable_params = [*agent.fuse.parameters(), *agent.policy_head.parameters()]
optimizer = torch.optim.AdamW(trainable_params, lr=3e-4)

def evaluate(loader):
    """Compute the per-sample mean loss and accuracy of the global ``agent`` over ``loader``.

    Puts ``agent`` in eval mode (and leaves it there) and runs without gradients.

    Args:
        loader: iterable of ``(image_inputs, text_inputs, actions_t)`` batches, as
            produced by ``collate_fn``.

    Returns:
        ``(mean_loss, accuracy)``, both weighted by batch size. If the loader yields
        no samples, returns ``(nan, nan)`` instead of raising ``ZeroDivisionError``.
    """
    agent.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for image_inputs, text_inputs, actions_t in loader:
            image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            actions_t = actions_t.to(device)

            logits = agent(image_inputs, text_inputs)
            loss = criterion(logits, actions_t)
            # criterion averages over the batch; re-weight so the final mean is per sample.
            total_loss += loss.item() * actions_t.size(0)

            pred = torch.argmax(logits, dim=-1)
            correct += (pred == actions_t).sum().item()
            total += actions_t.size(0)

    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, correct / total

EPOCHS = 4
train_losses, val_losses, val_accs = [], [], []

for epoch in range(1, EPOCHS + 1):
    agent.train()
    loss_sum, seen = 0.0, 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for image_inputs, text_inputs, actions_t in progress:
        image_inputs = {name: tensor.to(device) for name, tensor in image_inputs.items()}
        text_inputs = {name: tensor.to(device) for name, tensor in text_inputs.items()}
        actions_t = actions_t.to(device)

        # Gradients only reach fuse/policy_head: encoder features are computed
        # under torch.no_grad() inside MiniVLAAgent.forward.
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(agent(image_inputs, text_inputs), actions_t)
        loss.backward()
        optimizer.step()

        batch_n = actions_t.size(0)
        loss_sum += loss.item() * batch_n
        seen += batch_n

    tr_loss = loss_sum / seen
    va_loss, va_acc = evaluate(val_loader)
    for history, value in ((train_losses, tr_loss), (val_losses, va_loss), (val_accs, va_acc)):
        history.append(value)

    print(f"Epoch {epoch}: train_loss={tr_loss:.4f} | val_loss={va_loss:.4f} | val_acc={va_acc:.3f}")


In [None]:
# Plot against 1-based epoch numbers so the x-axis matches the "Epoch N" lines
# printed during training (the lists are appended once per epoch, in lockstep).
epochs_axis = range(1, len(train_losses) + 1)

plt.figure()
plt.plot(epochs_axis, train_losses, label="train_loss")
plt.plot(epochs_axis, val_losses, label="val_loss")
plt.xlabel("epoch")
plt.ylabel("cross-entropy loss")
plt.legend()
plt.title("Behavior Cloning loss")
plt.show()

plt.figure()
plt.plot(epochs_axis, val_accs, label="val_acc")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.title("Validation accuracy")
plt.show()


## 6. Inference & Video Generation

Run the trained policy on a new episode (using a seed that was not used during data collection) and save the rollout as a video.

In [None]:
def policy_action(frame, instruction, temperature=1.0, sample=False):
    """Choose an action for a single RGB frame using the trained global ``agent``.

    Args:
        frame: RGB image array (e.g. from ``env.render()``), converted with ``Image.fromarray``.
        instruction: natural-language instruction string.
        temperature: the logits are divided by this before sampling; must be > 0.
            It does not change the greedy (argmax) choice.
        sample: if True, sample from ``softmax(logits / temperature)``; otherwise take the argmax.

    Returns:
        The selected action index as a Python ``int``.

    Raises:
        ValueError: if ``temperature`` is not positive (zero gives inf/NaN logits,
            and a negative value silently inverts the preferences).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # Do not rely on an earlier evaluate() call having left the model in eval mode.
    agent.eval()

    image_inputs = clip_processor(images=Image.fromarray(frame), return_tensors="pt")
    text_inputs = text_tokenizer(instruction, return_tensors="pt", truncation=True, padding=True)

    image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

    with torch.no_grad():
        logits = agent(image_inputs, text_inputs)[0]  # (action_dim,) for the single frame
        if temperature != 1.0:
            logits = logits / temperature

        if sample:
            dist = torch.distributions.Categorical(logits=logits)
            return int(dist.sample().item())
        return int(torch.argmax(logits).item())

def run_episode(max_steps=1000, sample=False):
    """Roll out the trained policy for one CartPole episode from pixels and record the frames.

    The episode uses a fixed seed (``SEED + 999``) and the global ``INSTRUCTION``.
    It ends on termination, truncation, or after ``max_steps`` steps.

    Args:
        max_steps: upper bound on the number of steps (``0`` yields an empty rollout).
        sample: forwarded to ``policy_action`` (sample vs. argmax).

    Returns:
        ``(frames_out, n_steps)``, where ``frames_out`` holds the RGB frame observed
        before each action and ``n_steps == len(frames_out)``.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    frames_out = []
    try:
        env.reset(seed=SEED + 999)
        for _ in range(max_steps):
            # The policy acts on the rendered image only, not on the state vector.
            frame = env.render()
            a = policy_action(frame, INSTRUCTION, temperature=1.0, sample=sample)
            _, _, done, trunc, _ = env.step(a)
            frames_out.append(frame)
            if done or trunc:
                break
    finally:
        # Close the env even if the policy or the env raises mid-episode.
        env.close()
    return frames_out, len(frames_out)

# Greedy (argmax) rollout of the trained policy.
rollout = run_episode(sample=False)
frames_out, steps = rollout
print(f"Episode length: {steps}")


In [None]:
# Encode the recorded rollout frames into an MP4 file at 30 fps.
video_fps = 30
video_path = "vla_cartpole_bc.mp4"
imageio.mimsave(video_path, frames_out, fps=video_fps)
video_path  # echo the output path in the notebook


In [None]:
# Display video (Colab/Jupyter)
from base64 import b64encode
from IPython.display import HTML

# Embed the MP4 as a base64 data URI so the notebook can play it inline.
with open(video_path, "rb") as video_file:
    raw_video = video_file.read()
mp4 = b64encode(raw_video).decode()

HTML(f'''
<video width="640" controls>
  <source src="data:video/mp4;base64,{mp4}" type="video/mp4">
</video>
''')
