<a href="https://colab.research.google.com/github/alex-petrov-git/phd-test/blob/main/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import pickle
import shutil
import tempfile
import zipfile

import numpy as np

def unzip_trajectories(zip_path, output_dir="trajectories"):
    """Extract every ``.pkl`` file in ``zip_path`` into a flat ``output_dir``.

    The archive is first extracted into a private temporary directory. Each
    ``.pkl`` file found at any depth is then moved directly into ``output_dir``,
    flattening the directory structure (a later file with the same basename
    overwrites an earlier one). Non-``.pkl`` files are reported and discarded.

    Errors (missing file, corrupt archive, anything else) are printed rather
    than raised; the function always returns ``None``.

    Args:
        zip_path: path to the ``.zip`` archive.
        output_dir: destination directory, created if it does not exist.
    """
    if not os.path.exists(zip_path):
        print(f"Error: .zip file '{zip_path}' does not exist.")
        return

    os.makedirs(output_dir, exist_ok=True)

    print(f"Unzipping '{zip_path}' to '{output_dir}'...")
    try:
        # A fresh temporary directory replaces the old fixed "temp_extract".
        # It is always removed, even when an error occurs mid-move, and it
        # cannot contain stale files left over from a previous run.
        with zipfile.ZipFile(zip_path, 'r') as zip_ref, tempfile.TemporaryDirectory() as temp_dir:
            zip_ref.extractall(temp_dir)

            for root, _, files in os.walk(temp_dir):
                for file in files:
                    if file.endswith(".pkl"):
                        src_path = os.path.join(root, file)
                        dst_path = os.path.join(output_dir, file)
                        shutil.move(src_path, dst_path)
                        print(f"Extracted '{file}' to '{output_dir}'")
                    else:
                        print(f"Skipping non-.pkl file: '{file}'")

        print(f"Successfully unzipped '{zip_path}' to '{output_dir}'")
    except zipfile.BadZipFile:
        print(f"Error: '{zip_path}' is corrupted or not a valid .zip file")
    except Exception as e:
        # Deliberately best-effort: report and continue the notebook.
        print(f"Error unzipping '{zip_path}': {e}")

In [None]:
# Pull the recorded trajectory archive apart: only its .pkl files land in ./trajectories.
zip_path_1 = "robot_trajectories_20250601_123205.zip"
unzip_trajectories(zip_path=zip_path_1, output_dir="trajectories")

Unzipping 'robot_trajectories_20250601_123205.zip' to 'trajectories'...
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0006.pkl' to 'trajectories'
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0021.pkl' to 'trajectories'
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0018.pkl' to 'trajectories'
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0008.pkl' to 'trajectories'
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0019.pkl' to 'trajectories'
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0001.pkl' to 'trajectories'
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0015.pkl' to 'trajectories'
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0007.pkl' to 'trajectories'
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0003.pkl' to 'trajectories'
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0017.pkl' to 'trajectories'
Extracted 'google_robot_pick_coke_can_rt1_x_trajectory_0004.pkl' to 'trajectories'
Extracted 'goog

In [None]:
# Sanity check: per-dimension min/max of the raw continuous actions across all
# trajectory files. RoboDataset later derives its discretization ranges from
# the per-frame actions it loads.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
action_rows = []
for file in sorted(os.listdir("trajectories")):
    if file.endswith(".pkl"):
        with open(os.path.join("trajectories", file), "rb") as f:
            traj = pickle.load(f)
        action_rows.extend(traj['actions'])
if not action_rows:
    raise ValueError("No actions found under 'trajectories'; did the unzip step run?")
all_actions = np.array(action_rows)
print("Min/Max per action dim:", all_actions.min(axis=0), all_actions.max(axis=0))

Min/Max per action dim: [-0.01956952 -0.06653619 -0.07436395 -0.04610944 -0.3227663  -0.0768491
 -1.        ] [0.1291585  0.19178081 0.09001946 0.4149853  0.09529293 0.18136406
 1.        ]


## 3. nanoGPT training

In [None]:
!git clone https://github.com/karpathy/nanoGPT.git
!pip install torchvision
!pip install transformers
!pip install scikit-learn
!pip install einops
!pip install wandb

Cloning into 'nanoGPT'...
remote: Enumerating objects: 686, done.[K
remote: Total 686 (delta 0), reused 0 (delta 0), pack-reused 686 (from 1)[K
Receiving objects: 100% (686/686), 954.04 KiB | 2.99 MiB/s, done.
Resolving deltas: 100% (387/387), done.
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2

In [None]:
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
from torchvision.models import resnet18
import wandb
from tqdm import tqdm

import sys
sys.path.append('./nanoGPT')
from model import GPT, GPTConfig

In [None]:
class RoboDataset(Dataset):
    """Per-timestep (instruction, image, action) samples loaded from trajectory pickles.

    Every ``.pkl`` file in ``trajectory_dir`` is unpickled as a dict. Its
    ``'instructions'`` value is tokenized once and paired with each timestep's
    ``traj['images'][i]`` and ``traj['actions'][i]``, giving one sample per image.

    Actions are discretized independently per dimension into ``bins_per_dim``
    uniform bins. The bins span the per-dimension min/max over the whole dataset.

    Images are embedded by an ImageNet-pretrained ResNet-18 whose ``fc`` layer
    is replaced with a freshly initialised ``Linear(512, 768)``. The encoder is
    kept in eval mode and only run under ``torch.no_grad()``, so that projection
    is never trained. Embeddings are therefore deterministic and are cached per
    sample index after first use.

    Args:
        trajectory_dir: directory containing the ``.pkl`` trajectory files.
        text_tokenizer: object with an ``encode(text, add_special_tokens=...)`` method.
        max_seq_len: stored on the instance; not currently used to truncate instructions.
        bins_per_dim: number of discretization bins per action dimension.
    """

    # Normalization statistics expected by torchvision's ImageNet-pretrained
    # weights, shaped (C, 1, 1) to broadcast over a CHW image.
    _IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    _IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __init__(self, trajectory_dir, text_tokenizer, max_seq_len=256, bins_per_dim=64):
        self.tokenizer = text_tokenizer
        self.max_seq_len = max_seq_len
        self.bins_per_dim = bins_per_dim
        self.data = []
        self.embedding_dim = 768
        self.image_encoder = resnet18(pretrained=True)
        self.image_encoder.fc = nn.Linear(512, self.embedding_dim)
        self.image_encoder.eval()
        # Maps idx -> image embedding. Valid because the frozen eval-mode
        # encoder always returns the same output for the same image.
        self._img_emb_cache = {}

        # Sorted so sample order (and hence any seeded split) is reproducible.
        trajectory_files = sorted(
            os.path.join(trajectory_dir, f) for f in os.listdir(trajectory_dir) if f.endswith(".pkl")
        )
        for file in trajectory_files:
            # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
            with open(file, 'rb') as f:
                traj = pickle.load(f)
            instruction_tokens = self.tokenizer.encode(traj['instructions'], add_special_tokens=True)
            for i in range(len(traj['images'])):
                self.data.append({
                    'image': traj['images'][i],
                    'action': traj['actions'][i],
                    'instruction_tokens': instruction_tokens
                })

        self._compute_action_ranges()

    def discretize_action(self, action):
        """Map a continuous action vector to per-dimension bin indices in ``[0, bins_per_dim)``."""
        action = np.clip(action, self.action_mins, self.action_maxs)
        span = self.action_maxs - self.action_mins
        # A dimension that is constant over the dataset has zero span. Dividing
        # gives 0/0 = NaN, and NaN's integer cast is platform-dependent, so such
        # dimensions are mapped to bin 0 explicitly.
        safe_span = np.where(span > 0, span, 1.0)
        action_normalized = (action - self.action_mins) / safe_span
        action_tokens = np.floor(action_normalized * self.bins_per_dim).astype(int)
        return np.clip(action_tokens, 0, self.bins_per_dim - 1)

    def undiscretize_action(self, action_tokens):
        """Inverse of ``discretize_action``: return the continuous value at each bin's centre."""
        action_continuous = (action_tokens + 0.5) / self.bins_per_dim
        return self.action_mins + action_continuous * (self.action_maxs - self.action_mins)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(instruction_tokens, img_embedding, action_tokens)`` for sample ``idx``.

        ``instruction_tokens`` is an unpadded LongTensor of token ids.
        ``img_embedding`` is the squeezed encoder output for the image.
        ``action_tokens`` is a LongTensor of per-dimension bin indices.
        """
        item = self.data[idx]

        img_embedding = self._img_emb_cache.get(idx)
        if img_embedding is None:
            # Assumes images are HxWxC with values in [0, 255] -- TODO confirm against the recorder.
            img = torch.tensor(item['image'], dtype=torch.float32).permute(2, 0, 1) / 255.0
            # Pretrained torchvision weights expect ImageNet-normalized input.
            img = (img - self._IMAGENET_MEAN) / self._IMAGENET_STD
            with torch.no_grad():
                img_embedding = self.image_encoder(img.unsqueeze(0)).squeeze(0)
            self._img_emb_cache[idx] = img_embedding

        instruction_tokens = torch.tensor(item['instruction_tokens'], dtype=torch.long)
        action_tokens = torch.tensor(self.discretize_action(item['action']), dtype=torch.long)
        return instruction_tokens, img_embedding, action_tokens

    def _compute_action_ranges(self):
        """Set ``action_mins`` / ``action_maxs`` to the per-dimension extrema over all samples."""
        if not self.data:
            raise ValueError("RoboDataset has no samples; cannot compute action ranges")
        all_actions = np.array([item['action'] for item in self.data])
        self.action_mins = all_actions.min(axis=0)
        self.action_maxs = all_actions.max(axis=0)

def collate_fn(batch, pad_token_id=None):
    """Collate ``RoboDataset`` samples into batch tensors for a DataLoader.

    Args:
        batch: list of ``(instruction_tokens, img_embedding, action_tokens)`` tuples.
        pad_token_id: id used to right-pad instructions to the longest one in the
            batch. Defaults to the notebook-global ``tokenizer.pad_token_id``, which
            must be defined before the DataLoader is iterated.

    Returns:
        ``(instructions_padded, img_embeddings_stacked, action_tokens_stacked)``.
        ``instructions_padded`` has shape ``(B, T_max)``; the other two are stacked
        along a new leading batch dimension.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    instructions, img_embeddings, action_tokens = zip(*batch)
    instructions_padded = torch.nn.utils.rnn.pad_sequence(instructions, batch_first=True, padding_value=pad_token_id)
    img_embeddings_stacked = torch.stack(img_embeddings)
    action_tokens_stacked = torch.stack(action_tokens)
    return instructions_padded, img_embeddings_stacked, action_tokens_stacked

In [None]:
class RoboGPT(GPT):
    """nanoGPT backbone turned into a per-dimension action-bin classifier.

    Instruction tokens are embedded with the GPT token and position embeddings.
    They are fused with one image embedding per sample through a cross-attention
    layer (added residually), then run through the transformer blocks. The final
    position's hidden state feeds ``action_dim`` independent linear heads, each
    scoring ``bins_per_dim`` bins for its action dimension.

    Args:
        config: nanoGPT ``GPTConfig``.
        action_dim: number of action dimensions (one head each).
        bins_per_dim: number of discretization bins per dimension.
        cross_attn_heads: head count of the image cross-attention; must divide ``config.n_embd``.
    """

    def __init__(self, config, action_dim=7, bins_per_dim=64, cross_attn_heads=4):
        super().__init__(config)
        self.action_dim = action_dim
        self.bins_per_dim = bins_per_dim
        # Assumes incoming image embeddings already have width config.n_embd.
        self.img_proj = nn.Linear(config.n_embd, config.n_embd)
        # Sequence-first (batch_first=False) attention: queries are text positions,
        # the single key/value is the projected image embedding.
        self.cross_attn = nn.MultiheadAttention(config.n_embd, num_heads=cross_attn_heads)

        self.action_heads = nn.ModuleList([nn.Linear(config.n_embd, bins_per_dim) for _ in range(action_dim)])

    def forward(self, instruction_tokens, img_embeddings=None, targets=None):
        """Predict action-bin logits.

        Args:
            instruction_tokens: LongTensor ``(B, T)`` of token ids, ``T <= config.block_size``.
            img_embeddings: optional FloatTensor ``(B, n_embd)``, one embedding per sample.
            targets: optional LongTensor ``(B, action_dim)`` of bin indices.

        Returns:
            ``(logits, loss)``. ``logits`` is ``(B, action_dim, bins_per_dim)``.
            ``loss`` is the sum of per-dimension cross-entropies, or ``None``
            when ``targets`` is ``None``.
        """
        device = instruction_tokens.device
        b, t = instruction_tokens.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        tok_emb = self.transformer.wte(instruction_tokens)
        pos_emb = self.transformer.wpe(torch.arange(0, t, device=device))
        x = self.transformer.drop(tok_emb + pos_emb)

        if img_embeddings is not None:
            img_emb = self.img_proj(img_embeddings).unsqueeze(0)  # (1, B, C): a length-1 key/value sequence
            attn_out, _ = self.cross_attn(x.transpose(0, 1), img_emb, img_emb)
            # Residual connection. With a single key/value the attention weights
            # are all 1, so attn_out is the same vector at every position.
            # Assigning it to x (as before) erased the instruction tokens;
            # adding it keeps both the text and image information.
            x = x + attn_out.transpose(0, 1)

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        # NOTE(review): batches are right-padded by collate_fn and no attention
        # mask is applied, so for shorter instructions position -1 is a pad token.
        last_emb = x[:, -1, :]
        logits = [head(last_emb) for head in self.action_heads]
        logits = torch.stack(logits, dim=1)
        loss = None
        if targets is not None:
            loss = sum(F.cross_entropy(logits[:, i, :], targets[:, i]) for i in range(self.action_dim))
        return logits, loss

In [None]:
class RoboBatchHandler:
    """Runs one batch through the model, optionally takes an optimizer step,
    and reports the loss plus an action-space MSE.

    The MSE compares bin-centre reconstructions of the predicted and target
    tokens, so it measures error after quantization rather than against raw actions.
    """

    def __init__(
        self,
        model,
        action_mins,
        action_maxs,
        optimizer=None,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        bins_per_dim=64
    ):
        self.model, self.optimizer, self.device = model, optimizer, device
        self.bins_per_dim = bins_per_dim
        self.action_mins, self.action_maxs = action_mins, action_maxs

    def undiscretize_action(self, action_tokens):
        """Map bin indices to the continuous value at each bin's centre."""
        centres = (action_tokens + 0.5) / self.bins_per_dim
        span = self.action_maxs - self.action_mins
        return self.action_mins + centres * span

    def handle_batch(self, is_train, batch):
        """Forward one batch (plus backward/step when ``is_train``).

        Returns a dict with ``loss``, ``action_mse`` and ``batch_size``.
        """
        tokens, images, targets = (tensor.to(self.device) for tensor in batch)

        logits, loss = self.model(tokens, images, targets)

        if is_train:
            opt = self.optimizer
            opt.zero_grad()
            loss.backward()
            opt.step()

        with torch.no_grad():
            predicted = self.undiscretize_action(logits.argmax(-1).cpu().numpy())
            expected = self.undiscretize_action(targets.cpu().numpy())
            mse = np.mean((predicted - expected) ** 2)

        return dict(loss=loss.item(), action_mse=mse, batch_size=tokens.size(0))

def _run_epoch(dataloader, handler, is_train, desc):
    """Run ``handler`` over every batch and return sample-weighted mean loss / action_mse."""
    totals = {"loss": 0.0, "action_mse": 0.0}
    n_samples = 0

    for batch in tqdm(dataloader, desc=desc):
        metrics = handler.handle_batch(is_train=is_train, batch=batch)
        for k in totals:
            totals[k] += metrics[k] * metrics["batch_size"]
        n_samples += metrics["batch_size"]

    # Normalize by the samples actually processed rather than len(dataset), so
    # drop_last or a sampler over a subset does not bias the averages.
    # An empty loader yields NaN instead of raising ZeroDivisionError.
    if n_samples == 0:
        return {k: float("nan") for k in totals}
    return {k: v / n_samples for k, v in totals.items()}

def train_epoch(dataloader, handler):
    """One optimization pass over ``dataloader``; returns mean ``loss`` and ``action_mse``."""
    handler.model.train()
    return _run_epoch(dataloader, handler, is_train=True, desc="Training")

def test_epoch(dataloader, handler):
    """One gradient-free evaluation pass over ``dataloader``; returns mean ``loss`` and ``action_mse``."""
    handler.model.eval()
    with torch.no_grad():
        return _run_epoch(dataloader, handler, is_train=False, desc="Validation")

def train(
    model,
    dataset,
    train_loader,
    val_loader,
    optimizer,
    num_epochs=20,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    project_name="robot-nanogpt-imitation",
    checkpoint_path="best_model.pth",
):
    """Train ``model`` for ``num_epochs`` epochs, logging metrics to Weights & Biases.

    ``dataset`` is used only for its ``action_mins`` / ``action_maxs``, which
    convert predicted bins back to continuous actions. Batches are moved to
    ``device``, but ``model`` is not; move it beforehand.

    After each epoch, the state dict with the lowest validation action MSE so
    far is written to ``checkpoint_path``. Returns ``model`` as it is after the
    final epoch, which is not necessarily the best checkpoint.
    """
    wandb.init(project=project_name)
    try:
        wandb.watch(model)

        handler = RoboBatchHandler(
            model=model,
            action_mins=dataset.action_mins,
            action_maxs=dataset.action_maxs,
            optimizer=optimizer,
            device=device
        )

        best_val_mse = float('inf')

        for epoch in range(num_epochs):
            train_metrics = train_epoch(train_loader, handler)
            val_metrics = test_epoch(val_loader, handler)

            log_dict = {
                "epoch": epoch,
                "train/loss": train_metrics["loss"],
                "train/action_mse": train_metrics["action_mse"],
                "val/loss": val_metrics["loss"],
                "val/action_mse": val_metrics["action_mse"]
            }
            wandb.log(log_dict)

            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"  Train Loss: {train_metrics['loss']:.4f} | Train MSE: {train_metrics['action_mse']:.4f}")
            print(f"  Val Loss: {val_metrics['loss']:.4f} | Val MSE: {val_metrics['action_mse']:.4f}")

            if val_metrics["action_mse"] < best_val_mse:
                best_val_mse = val_metrics["action_mse"]
                torch.save(model.state_dict(), checkpoint_path)
                print("  Saved new best model!")
    finally:
        # Close the W&B run even if training raises or is interrupted.
        wandb.finish()
    return model

In [None]:
# Seeds the random fc projection inside RoboDataset, the model init and DataLoader shuffling.
SEED = 42
torch.manual_seed(SEED)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 has no pad token; reuse EOS so collate_fn can pad instruction batches.
tokenizer.pad_token = tokenizer.eos_token
dataset = RoboDataset("trajectories", tokenizer, bins_per_dim=64)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: this is a per-frame split, so frames from one trajectory can land in both
# train and val, which makes the val metrics optimistic. Action ranges are also
# computed over the whole dataset, including the val frames.
train_size = int(0.9 * len(dataset))
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, len(dataset) - train_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=16, collate_fn=collate_fn)

config = GPTConfig(block_size=256, vocab_size=len(tokenizer), n_layer=6, n_head=6, n_embd=768)
model = RoboGPT(config, action_dim=7, bins_per_dim=64).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

number of parameters: 81.13M


In [None]:
# Launch the full training run; the checkpoint with the lowest val action MSE is saved along the way.
trained_model = train(
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=20,
)

Training: 100%|██████████| 57/57 [04:45<00:00,  5.01s/it]
Validation: 100%|██████████| 7/7 [00:32<00:00,  4.62s/it]


Epoch 1/20
  Train Loss: 11.2110 | Train MSE: 0.0106
  Val Loss: 10.7975 | Val MSE: 0.0124
  Saved new best model!


Training: 100%|██████████| 57/57 [04:47<00:00,  5.05s/it]
Validation: 100%|██████████| 7/7 [00:30<00:00,  4.35s/it]


Epoch 2/20
  Train Loss: 9.8666 | Train MSE: 0.0102
  Val Loss: 10.5272 | Val MSE: 0.0124


Training: 100%|██████████| 57/57 [04:48<00:00,  5.06s/it]
Validation: 100%|██████████| 7/7 [00:32<00:00,  4.60s/it]


Epoch 3/20
  Train Loss: 9.6257 | Train MSE: 0.0102
  Val Loss: 10.0505 | Val MSE: 0.0120
  Saved new best model!


Training: 100%|██████████| 57/57 [04:44<00:00,  4.98s/it]
Validation: 100%|██████████| 7/7 [00:31<00:00,  4.56s/it]


Epoch 4/20
  Train Loss: 8.9953 | Train MSE: 0.0100
  Val Loss: 10.0956 | Val MSE: 0.0121


Training: 100%|██████████| 57/57 [04:41<00:00,  4.94s/it]
Validation: 100%|██████████| 7/7 [00:31<00:00,  4.46s/it]


Epoch 5/20
  Train Loss: 8.1621 | Train MSE: 0.0099
  Val Loss: 8.5414 | Val MSE: 0.0112
  Saved new best model!


Training: 100%|██████████| 57/57 [04:46<00:00,  5.03s/it]
Validation: 100%|██████████| 7/7 [00:31<00:00,  4.44s/it]


Epoch 6/20
  Train Loss: 7.5564 | Train MSE: 0.0097
  Val Loss: 8.3129 | Val MSE: 0.0114


Training: 100%|██████████| 57/57 [04:45<00:00,  5.01s/it]
Validation: 100%|██████████| 7/7 [00:31<00:00,  4.48s/it]


Epoch 7/20
  Train Loss: 7.0747 | Train MSE: 0.0092
  Val Loss: 8.3368 | Val MSE: 0.0100
  Saved new best model!


Training: 100%|██████████| 57/57 [04:41<00:00,  4.93s/it]
Validation: 100%|██████████| 7/7 [00:31<00:00,  4.53s/it]


Epoch 8/20
  Train Loss: 6.7094 | Train MSE: 0.0091
  Val Loss: 7.2759 | Val MSE: 0.0107


Training: 100%|██████████| 57/57 [04:43<00:00,  4.98s/it]
Validation: 100%|██████████| 7/7 [00:31<00:00,  4.52s/it]


Epoch 9/20
  Train Loss: 6.5485 | Train MSE: 0.0103
  Val Loss: 8.3085 | Val MSE: 0.0109


Training: 100%|██████████| 57/57 [04:45<00:00,  5.00s/it]
Validation: 100%|██████████| 7/7 [00:32<00:00,  4.61s/it]


Epoch 10/20
  Train Loss: 6.3040 | Train MSE: 0.0088
  Val Loss: 7.1241 | Val MSE: 0.0090
  Saved new best model!


Training: 100%|██████████| 57/57 [04:52<00:00,  5.13s/it]
Validation: 100%|██████████| 7/7 [00:31<00:00,  4.54s/it]


Epoch 11/20
  Train Loss: 5.9238 | Train MSE: 0.0095
  Val Loss: 6.8113 | Val MSE: 0.0106


Training: 100%|██████████| 57/57 [04:47<00:00,  5.04s/it]
Validation: 100%|██████████| 7/7 [00:32<00:00,  4.59s/it]


Epoch 12/20
  Train Loss: 6.0188 | Train MSE: 0.0091
  Val Loss: 7.0763 | Val MSE: 0.0097


Training: 100%|██████████| 57/57 [04:45<00:00,  5.01s/it]
Validation: 100%|██████████| 7/7 [00:31<00:00,  4.51s/it]


Epoch 13/20
  Train Loss: 5.8552 | Train MSE: 0.0088
  Val Loss: 6.4602 | Val MSE: 0.0107


Training:  65%|██████▍   | 37/57 [03:06<01:40,  5.04s/it]