<a href="https://colab.research.google.com/github/Anna4142/miniBET/blob/main/FRANKA_KITCCHEN_THAT_WORKS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Title: "Franka Kitchen Environment Setup with miniBET in Colab"
# First, install required packages
!pip install gym==0.24.1
!pip install mujoco==2.3.3
!pip install d4rl
!pip install torch
!pip install numpy
!pip install wandb
!git clone https://github.com/notmahi/miniBET.git
%cd miniBET
!pip install -e .
%cd examples
!pip install -r requirements-dev.txt

# Download and set up the Franka Kitchen dataset
import gym
import d4rl
import torch
import numpy as np
from behavior_transformer import BehaviorTransformer, GPT, GPTConfig

# Create and test the environment
env = gym.make('kitchen-complete-v0')
print("Environment created successfully!")

# Basic environment info
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")

# Setup miniBET model
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
goal_dim = obs_dim  # For conditional behavior
K = 32
T = 16
batch_size = 256

# Initialize the model
cbet = BehaviorTransformer(
    obs_dim=obs_dim,
    act_dim=act_dim,
    goal_dim=goal_dim,
    gpt_model=GPT(
        GPTConfig(
            block_size=144,
            input_dim=obs_dim,
            n_layer=6,
            n_head=8,
            n_embd=256,
        )
    ),
    n_clusters=K,
    kmeans_fit_steps=5,
)

# Configure optimizer
optimizer = cbet.configure_optimizers(
    weight_decay=2e-4,
    learning_rate=1e-5,
    betas=[0.9, 0.999],
)

# Load some sample data from the environment
dataset = env.get_dataset()
print("\nDataset keys:", dataset.keys())
print("Number of trajectories:", len(dataset['observations']))

# Create a simple training loop for testing
def prepare_batch(dataset, batch_size, sequence_length):
    idx = np.random.randint(0, len(dataset['observations']) - sequence_length, size=batch_size)
    obs_seq = torch.tensor(np.stack([dataset['observations'][i:i+sequence_length] for i in idx]), dtype=torch.float32)
    act_seq = torch.tensor(np.stack([dataset['actions'][i:i+sequence_length] for i in idx]), dtype=torch.float32)
    goal_seq = obs_seq.clone()  # Using final states as goals for this example
    return obs_seq, goal_seq, act_seq

# Test training loop
print("\nTesting training loop...")
for i in range(3):
    obs_seq, goal_seq, action_seq = prepare_batch(dataset, batch_size, T)
    train_action, train_loss, train_loss_dict = cbet(obs_seq, goal_seq, action_seq)
    print(f"Iteration {i}, Loss: {train_loss.item():.4f}")

print("\nSetup completed successfully!")

Collecting gym==0.24.1
  Downloading gym-0.24.1.tar.gz (696 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/696.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━[0m [32m430.1/696.4 kB[0m [31m11.8 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m696.3/696.4 kB[0m [31m15.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m696.4/696.4 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: gym
  Building wheel for gym (pyproject.toml) ... [?25l[?25hdone
  Created wheel for gym: filename=gym-0.24.1-py3-none-any.whl size=793130 sha256=9b1cdab4d1ed4e422b1bd475d124d878a41bf487bb53c7d74871713e541b1

  from distutils.dep_util import newer, newer_group

You appear to be missing MuJoCo.  We expected to find the file here: /root/.mujoco/mujoco210

This package only provides python bindings, the library must be installed separately.

Please follow the instructions on the README to install MuJoCo

    https://github.com/openai/mujoco-py#install-mujoco

Which can be downloaded from the website

    https://www.roboti.us/index.html



Exception: 
You appear to be missing MuJoCo.  We expected to find the file here: /root/.mujoco/mujoco210

This package only provides python bindings, the library must be installed separately.

Please follow the instructions on the README to install MuJoCo

    https://github.com/openai/mujoco-py#install-mujoco

Which can be downloaded from the website

    https://www.roboti.us/index.html


In [3]:
# Install dependencies
!apt-get update
!apt-get install -y \
    libgl1-mesa-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libosmesa6-dev \
    software-properties-common \
    patchelf

# Install MuJoCo
!wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz
!mkdir -p ~/.mujoco
!tar -xf mujoco.tar.gz -C ~/.mujoco
!rm mujoco.tar.gz

# Add MuJoCo to LD_LIBRARY_PATH
import os
os.environ['LD_LIBRARY_PATH'] = os.environ.get('LD_LIBRARY_PATH', '') + ':/root/.mujoco/mujoco210/bin'

# Install Gymnasium-Robotics and dependencies
!pip install gymnasium-robotics

# Verify installation
import gymnasium as gym
import gymnasium_robotics

# Test code to verify installation
try:
    # Try creating a Fetch environment
    env = gym.make('FetchReach-v2', render_mode='human')
    print("Successfully created FetchReach environment!")

    # Get basic environment info
    print("\nEnvironment Details:")
    print(f"Action Space: {env.action_space}")
    print(f"Observation Space: {env.observation_space}")

    # Close the environment
    env.close()

except Exception as e:
    print(f"Error occurred: {str(e)}")
    print("\nTroubleshooting tips:")
    print("1. Make sure all dependencies are properly installed")
    print("2. Check if MuJoCo is correctly set up")
    print("3. Verify your Python version is compatible (3.8-3.11)")

print("\nInstallation complete! You can now use Gymnasium-Robotics environments.")

0% [Working]            Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:4 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:5 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:6 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:7 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [1,172 kB]
Get:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Hit:9 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:11 https://r2u.stat.illinois.edu/ubuntu jammy/main amd64 Packages [2,618 kB]
Hit:12 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:13 http://securit

  logger.deprecation(


In [4]:

import gymnasium as gym
import gymnasium_robotics

gym.register_envs(gymnasium_robotics)

env = gym.make('FrankaKitchen-v1', tasks_to_complete=['microwave', 'kettle'])

In [6]:
import gymnasium as gym
import numpy as np

# Create and test Franka Kitchen environment
try:
    # Create the environment
    env = gym.make('FrankaKitchen-v1')
    print("Environment created successfully!")

    # Print environment information
    print("\nEnvironment Info:")
    print(f"Action Space: {env.action_space}")
    print(f"Observation Space: {env.observation_space}")

    # Run a test episode
    print("\nStarting test episode...")
    obs_dict, info = env.reset()
    done = False
    total_reward = 0
    steps = 0

    # Print initial state information
    print("\nInitial State:")
    print("\nObservation Components:")
    print(f"Raw observation shape: {obs_dict['observation'].shape}")
    print(f"Achieved goal states:")
    for key, value in obs_dict['achieved_goal'].items():
        print(f"  {key}: shape {value.shape}")
    print(f"\nDesired goal states:")
    for key, value in obs_dict['desired_goal'].items():
        print(f"  {key}: shape {value.shape}")

    while not done and steps < 200:  # Run for max 200 steps
        # Sample a random action
        action = env.action_space.sample()

        # Take a step in the environment
        obs_dict, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        total_reward += reward
        steps += 1

        # Print periodic updates
        if steps % 20 == 0:
            print(f"\nStep {steps}:")
            print(f"Reward: {reward:.3f}")
            print(f"Total Reward: {total_reward:.3f}")
            print(f"Terminated: {terminated}")

            # Print some object states
            print("\nCurrent object states:")
            for obj, state in obs_dict['achieved_goal'].items():
                print(f"  {obj}: {state}")

    print(f"\nEpisode Summary:")
    print(f"Total steps: {steps}")
    print(f"Final reward: {total_reward:.3f}")

    # Close the environment
    env.close()

except Exception as e:
    print(f"An error occurred: {str(e)}")

print("\nTest complete!")

# Print detailed explanation of observation space
print("\nObservation Space Details:")
print("\nObjects and their state dimensions:")
print("- Microwave: 1 dimension (door angle)")
print("- Kettle: 7 dimensions (position [3], orientation [4])")
print("- Bottom Burner: 2 dimensions (position [2])")
print("- Top Burner: 2 dimensions (position [2])")
print("- Light Switch: 2 dimensions (position [2])")
print("- Slide Cabinet: 1 dimension (position)")
print("- Hinge Cabinet: 2 dimensions (position [2])")

print("\nAction Space Details:")
print("9 dimensions corresponding to Franka arm joint positions:")
print("1-7: 7 main arm joints")
print("8-9: 2 finger joints for the gripper")

Environment created successfully!

Environment Info:
Action Space: Box(-1.0, 1.0, (9,), float64)
Observation Space: Dict('achieved_goal': Dict('bottom burner': Box(-inf, inf, (2,), float64), 'hinge cabinet': Box(-inf, inf, (2,), float64), 'kettle': Box(-inf, inf, (7,), float64), 'light switch': Box(-inf, inf, (2,), float64), 'microwave': Box(-inf, inf, (1,), float64), 'slide cabinet': Box(-inf, inf, (1,), float64), 'top burner': Box(-inf, inf, (2,), float64)), 'desired_goal': Dict('bottom burner': Box(-inf, inf, (2,), float64), 'hinge cabinet': Box(-inf, inf, (2,), float64), 'kettle': Box(-inf, inf, (7,), float64), 'light switch': Box(-inf, inf, (2,), float64), 'microwave': Box(-inf, inf, (1,), float64), 'slide cabinet': Box(-inf, inf, (1,), float64), 'top burner': Box(-inf, inf, (2,), float64)), 'observation': Box(-inf, inf, (59,), float64))

Starting test episode...

Initial State:

Observation Components:
Raw observation shape: (59,)
Achieved goal states:
  bottom burner: shape (2,)

In [8]:
# Install required packages
!pip install hydra-core==1.3.1 omegaconf==2.3.0
!pip install h5py
!git clone https://github.com/notmahi/miniBET.git
%cd miniBET
!pip install --upgrade .

import torch
import numpy as np
import gymnasium as gym
from behavior_transformer import BehaviorTransformer, GPT, GPTConfig



Collecting hydra-core==1.3.1
  Using cached hydra_core-1.3.1-py3-none-any.whl.metadata (4.8 kB)
Collecting omegaconf==2.3.0
  Using cached omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting antlr4-python3-runtime==4.9.* (from hydra-core==1.3.1)
  Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.0/117.0 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading hydra_core-1.3.1-py3-none-any.whl (154 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.1/154.1 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading omegaconf-2.3.0-py3-none-any.whl (79 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: antlr4-python3-runtime
  Building wheel for antlr4-python3-runtime (setup.py) ... [?25l[?25hdone
  Creat

Cloning into 'miniBET'...
remote: Enumerating objects: 69, done.[K
remote: Counting objects: 100% (69/69), done.[K
remote: Compressing objects: 100% (44/44), done.[K
remote: Total 69 (delta 23), reused 64 (delta 22), pack-reused 0 (from 0)[K
Receiving objects: 100% (69/69), 33.27 KiB | 6.65 MiB/s, done.
Resolving deltas: 100% (23/23), done.
/content/miniBET/examples/miniBET/miniBET
Processing /content/miniBET/examples/miniBET/miniBET
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: behavior_transformer
  Building wheel for behavior_transformer (setup.py) ... [?25l[?25hdone
  Created wheel for behavior_transformer: filename=behavior_transformer-0.1.0-py3-none-any.whl size=13746 sha256=ae3c9617acf2e558e03250b84e2104221a713f06be110cfbaa411989f9fdca4f
  Stored in directory: /tmp/pip-ephem-wheel-cache-d5dttsxt/wheels/e7/85/43/71d1b93dfe6de04bac306d9cdd3a119f83e49e1dad4e25c107
Successfully built behavior_transformer
Installing collected packa

Environment created successfully!
Initializing BehaviorTransformer...
number of parameters: 4.86M
Collecting demonstration data...
Data shapes:
Observations: torch.Size([2, 16, 76])
Actions: torch.Size([2, 16, 9])
Goals: torch.Size([2, 16, 17])

Starting training...


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 17 but got size 76 for tensor number 1 in the list.

In [10]:
import torch
from behavior_transformer import BehaviorTransformer, GPT, GPTConfig

# Configuration
CONFIG = {
    'obs_dim': 59,    # Observation dimension
    'act_dim': 9,     # Action dimension
    'goal_dim': 59,   # Goal dimension
    'K': 32,          # Number of clusters
    'T': 16,          # Sequence length
    'batch_size': 32, # Batch size
    'training_steps': 100,
    'block_size': 144,
    'n_layer': 6,
    'n_head': 8,
    'n_embd': 256,
}

def create_synthetic_data(batch_size, seq_length, obs_dim, act_dim, goal_dim):
    """Create synthetic data for training"""
    obs_seq = torch.randn(batch_size, seq_length, obs_dim)
    goal_seq = torch.randn(batch_size, seq_length, goal_dim)
    action_seq = torch.randn(batch_size, seq_length, act_dim)
    return obs_seq, goal_seq, action_seq

def format_loss_dict(loss_dict):
    """Format loss dictionary values properly"""
    formatted = {}
    for k, v in loss_dict.items():
        if isinstance(v, torch.Tensor):
            formatted[k] = v.item()
        else:
            formatted[k] = v
    return formatted

def main():
    print("Initializing BehaviorTransformer...")

    # Create model
    cbet = BehaviorTransformer(
        obs_dim=CONFIG['obs_dim'],
        act_dim=CONFIG['act_dim'],
        goal_dim=CONFIG['goal_dim'],
        gpt_model=GPT(
            GPTConfig(
                block_size=CONFIG['block_size'],
                input_dim=CONFIG['obs_dim'],
                n_layer=CONFIG['n_layer'],
                n_head=CONFIG['n_head'],
                n_embd=CONFIG['n_embd'],
            )
        ),
        n_clusters=CONFIG['K'],
        kmeans_fit_steps=5,
    )

    # Configure optimizer
    optimizer = cbet.configure_optimizers(
        weight_decay=2e-4,
        learning_rate=1e-5,
        betas=[0.9, 0.999],
    )

    print("Starting training...")

    # Training loop
    for step in range(CONFIG['training_steps']):
        # Generate synthetic data
        obs_seq, goal_seq, action_seq = create_synthetic_data(
            CONFIG['batch_size'],
            CONFIG['T'],
            CONFIG['obs_dim'],
            CONFIG['act_dim'],
            CONFIG['goal_dim']
        )

        # Training step
        optimizer.zero_grad()

        if step < CONFIG['training_steps'] - 10:  # Training phase
            pred_action, loss, loss_dict = cbet(obs_seq, goal_seq, action_seq)

            if isinstance(loss, torch.Tensor):
                loss.backward()
                optimizer.step()
                loss_value = loss.item()
            else:
                loss_value = loss

            if step % 10 == 0:
                print(f"\nStep {step + 1}/{CONFIG['training_steps']}")
                print(f"Training Loss: {loss_value:.4f}")
                if loss_dict is not None:
                    formatted_loss = format_loss_dict(loss_dict)
                    print("Loss components:", formatted_loss)
                print("---")

        else:  # Evaluation phase
            with torch.no_grad():
                pred_action, loss, loss_dict = cbet(obs_seq, goal_seq, None)
                loss_value = loss.item() if isinstance(loss, torch.Tensor) else loss
                print(f"\nEvaluation Step {step + 1}")
                print(f"Evaluation Loss: {loss_value:.4f}")
                if loss_dict is not None:
                    formatted_loss = format_loss_dict(loss_dict)
                    print("Loss components:", formatted_loss)

    print("\nTraining complete!")
    print("Final model parameters:", sum(p.numel() for p in cbet.parameters()))

    # Save model
    try:
        torch.save(cbet.state_dict(), 'minibet_model.pth')
        print("Model saved successfully!")
    except Exception as e:
        print(f"Error saving model: {e}")

if __name__ == "__main__":
    main()

Initializing BehaviorTransformer...
number of parameters: 4.86M
Starting training...

Step 1/100
Training Loss: 0.0000
Loss components: {'classification_loss': 3.216613292694092, 'offset_loss': 1.0469565391540527, 'total_loss': 1050.173095703125}
---


K-means clustering: 100%|██████████| 50/50 [00:00<00:00, 108.81it/s]



Step 11/100
Training Loss: 614.8088
Loss components: {'classification_loss': 3.2579431533813477, 'offset_loss': 0.6115509271621704, 'total_loss': 614.808837890625}
---

Step 21/100
Training Loss: 614.0748
Loss components: {'classification_loss': 3.2736668586730957, 'offset_loss': 0.610801100730896, 'total_loss': 614.0747680664062}
---

Step 31/100
Training Loss: 625.0662
Loss components: {'classification_loss': 3.2745258808135986, 'offset_loss': 0.621791660785675, 'total_loss': 625.0662231445312}
---

Step 41/100
Training Loss: 608.8214
Loss components: {'classification_loss': 3.273685932159424, 'offset_loss': 0.6055477261543274, 'total_loss': 608.8214111328125}
---

Step 51/100
Training Loss: 613.4265
Loss components: {'classification_loss': 3.2872402667999268, 'offset_loss': 0.6101392507553101, 'total_loss': 613.4264526367188}
---

Step 61/100
Training Loss: 601.4530
Loss components: {'classification_loss': 3.2587456703186035, 'offset_loss': 0.5981943011283875, 'total_loss': 601.453

TypeError: unsupported format string passed to NoneType.__format__

In [15]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')


# Install additional requirements
%cd examples



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[Errno 2] No such file or directory: 'examples'
/content/miniBET/examples/miniBET/miniBET/examples


In [19]:
!mkdir data
# Extract the dataset from Drive
!tar -xzf /content/drive/MyDrive/bet_data_release.tar.gz




mkdir: cannot create directory ‘data’: File exists
tar: -: Not found in archive
tar: Exiting with failure status due to previous errors


##preprocessing

In [24]:
import torch
import numpy as np
import einops
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import os

class RelayKitchenTrajectoryDataset(Dataset):
    def __init__(self, data_directory, device="cpu", onehot_goals=False):
        print(f"Loading data from {data_directory}")
        self.data_dir = Path(data_directory)

        try:
            # Load observations from observations_seq.npy (not all_observations.npy)
            print("\nLoading observations...")
            self.observations = torch.from_numpy(np.load(self.data_dir / "observations_seq.npy"))
            print(f"Observation shape: {self.observations.shape}")

            # Load actions
            print("Loading actions...")
            self.actions = torch.from_numpy(np.load(self.data_dir / "actions_seq.npy"))
            print(f"Action shape: {self.actions.shape}")

            # Load masks
            print("Loading masks...")
            self.masks = torch.from_numpy(np.load(self.data_dir / "existence_mask.npy"))
            print(f"Mask shape: {self.masks.shape}")

            # Since there's no goals file, we'll use observations as goals
            print("Using observations as goals...")
            self.goals = self.observations.clone()

            # Print shapes before any transformations
            print("\nInitial shapes:")
            print(f"Observations: {self.observations.shape}")
            print(f"Actions: {self.actions.shape}")
            print(f"Masks: {self.masks.shape}")
            print(f"Goals: {self.goals.shape}")

            # Ensure all data is in T x N x D format
            if len(self.observations.shape) == 3 and self.observations.shape[0] < self.observations.shape[1]:
                print("\nTransposing data to correct format...")
                self.observations = self.observations.transpose(0, 1)
                self.actions = self.actions.transpose(0, 1)
                self.masks = self.masks.transpose(0, 1)
                self.goals = self.goals.transpose(0, 1)

            # Move to device and store in tensors
            self.tensors = [
                self.observations.to(device).float(),
                self.actions.to(device).float()
            ]
            if onehot_goals:
                self.tensors.append(self.goals.to(device).float())

            print("\nFinal data shapes:")
            print(f"Observations: {self.observations.shape}")
            print(f"Actions: {self.actions.shape}")
            print(f"Masks: {self.masks.shape}")
            print(f"Goals: {self.goals.shape}")

            print("\nData loaded successfully!")

        except Exception as e:
            print(f"\nError during data loading: {e}")
            raise

    def get_seq_length(self, idx):
        try:
            return int(self.masks[idx].sum().item())
        except Exception as e:
            print(f"Error in get_seq_length for idx {idx}: {e}")
            print(f"Masks shape: {self.masks.shape}")
            raise

    def __len__(self):
        return self.observations.shape[0]  # Number of trajectories

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError(f"Index {idx} out of bounds for dataset with size {len(self)}")

        try:
            T = self.get_seq_length(idx)
            return tuple(x[idx, :T] for x in self.tensors)
        except Exception as e:
            print(f"Error getting item {idx}: {e}")
            print(f"Dataset length: {len(self)}")
            print(f"Tensors shapes: {[t.shape for t in self.tensors]}")
            raise

def main():
    # Configuration
    data_dir = '/content/drive/MyDrive/franka/kitchen'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size = 32

    print(f"Using device: {device}")

    try:
        # Create dataset
        dataset = RelayKitchenTrajectoryDataset(
            data_directory=data_dir,
            device=device,
            onehot_goals=True
        )

        print(f"\nDataset size: {len(dataset)}")

        # Test single item access
        print("\nTesting single item access:")
        first_item = dataset[0]
        print("First item shapes:")
        for i, item in enumerate(first_item):
            print(f"Item {i} shape: {item.shape}")

        # Create DataLoader with smaller batch size and error checking
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True
        )

        print(f"\nDataLoader created with {len(dataloader)} batches")

        # Test first batch
        print("\nTesting first batch:")
        for batch in dataloader:
            print("Batch shapes:")
            for i, item in enumerate(batch):
                print(f"Item {i} shape: {item.shape}")
            break

        return dataset, dataloader

    except Exception as e:
        print(f"\nError: {e}")
        return None, None

if __name__ == "__main__":
    dataset, dataloader = main()

Using device: cpu
Loading data from /content/drive/MyDrive/franka/kitchen

Loading observations...
Observation shape: torch.Size([409, 566, 60])
Loading actions...
Action shape: torch.Size([409, 566, 9])
Loading masks...
Mask shape: torch.Size([409, 566])
Using observations as goals...

Initial shapes:
Observations: torch.Size([409, 566, 60])
Actions: torch.Size([409, 566, 9])
Masks: torch.Size([409, 566])
Goals: torch.Size([409, 566, 60])

Transposing data to correct format...

Final data shapes:
Observations: torch.Size([566, 409, 60])
Actions: torch.Size([566, 409, 9])
Masks: torch.Size([566, 409])
Goals: torch.Size([566, 409, 60])

Data loaded successfully!

Dataset size: 566

Testing single item access:
First item shapes:
Item 0 shape: torch.Size([189, 60])
Item 1 shape: torch.Size([189, 9])
Item 2 shape: torch.Size([189, 60])

DataLoader created with 18 batches

Testing first batch:

Error: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call l

In [32]:
import torch
import numpy as np
import einops
from pathlib import Path
from torch.utils.data import Dataset, TensorDataset, Subset
from torch import default_generator, randperm
from itertools import accumulate
from typing import Any, Callable, List, Optional, Sequence
import abc

class TrajectoryDataset(Dataset, abc.ABC):
    """
    A dataset containing trajectories.
    TrajectoryDataset[i] returns: (observations, actions, mask)
        observations: Tensor[T, ...], T frames of observations
        actions: Tensor[T, ...], T frames of actions
        mask: Tensor[T]: 0: invalid; 1: valid
    """
    @abc.abstractmethod
    def get_seq_length(self, idx):
        """Returns the length of the idx-th trajectory."""
        raise NotImplementedError

class TrajectorySubset(TrajectoryDataset, Subset):
    """
    Subset of a trajectory dataset at specified indices.
    Args:
        dataset (TrajectoryDataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset: TrajectoryDataset, indices: Sequence[int]):
        Subset.__init__(self, dataset, indices)

    def get_seq_length(self, idx):
        return self.dataset.get_seq_length(self.indices[idx])

class RelayKitchenTrajectoryDataset(TensorDataset, TrajectoryDataset):
    def __init__(self, data_directory, device="cpu", onehot_goals=False):
        data_directory = Path(data_directory)
        print(f"Loading data from {data_directory}")

        # Load data
        observations = torch.from_numpy(np.load(data_directory / "observations_seq.npy"))
        actions = torch.from_numpy(np.load(data_directory / "actions_seq.npy"))
        masks = torch.from_numpy(np.load(data_directory / "existence_mask.npy"))

        print("\nOriginal shapes:")
        print(f"Observations: {observations.shape}")
        print(f"Actions: {actions.shape}")
        print(f"Masks: {masks.shape}")

        # The current values are in shape T x N x Dim, move to N x T x Dim
        observations, actions, masks = transpose_batch_timestep(observations, actions, masks)

        print("\nTransposed shapes:")
        print(f"Observations: {observations.shape}")
        print(f"Actions: {actions.shape}")
        print(f"Masks: {masks.shape}")

        self.masks = masks
        tensors = [observations, actions]
        if onehot_goals:
            try:
                goals = torch.load(data_directory / "onehot_goals.pth")
                goals = next(transpose_batch_timestep(goals))
                tensors.append(goals)
                print(f"Goals shape: {goals.shape}")
            except Exception as e:
                print(f"Warning: Could not load onehot goals: {e}")
                print("Using observations as goals...")
                tensors.append(observations.clone())

        tensors = [t.to(device).float() for t in tensors]
        TensorDataset.__init__(self, *tensors)
        self.actions = self.tensors[1]

    def get_seq_length(self, idx):
        return int(self.masks[idx].sum().item())

    def __getitem__(self, idx):
        T = self.masks[idx].sum().int().item()
        return tuple(x[idx, :T] for x in self.tensors)

class TrajectorySlicerDataset(TrajectoryDataset):
    def __init__(
        self,
        dataset: TrajectoryDataset,
        window: int,
        future_conditional: bool = False,
        min_future_sep: int = 0,
        future_seq_len: Optional[int] = None,
        only_sample_tail: bool = False,
        transform: Optional[Callable] = None,
    ):
        if future_conditional:
            assert future_seq_len is not None, "must specify a future_seq_len"

        self.dataset = dataset
        self.window = window
        self.future_conditional = future_conditional
        self.min_future_sep = min_future_sep
        self.future_seq_len = future_seq_len
        self.only_sample_tail = only_sample_tail
        self.transform = transform
        self.slices = []

        min_seq_length = np.inf
        for i in range(len(self.dataset)):
            T = self.dataset.get_seq_length(i)
            min_seq_length = min(T, min_seq_length)
            if T - window < 0:
                print(f"Ignored short sequence #{i}: len={T}, window={window}")
            else:
                self.slices += [(i, start, start + window) for start in range(T - window)]

        if min_seq_length < window:
            print(f"Ignored short sequences. To include all, set window <= {min_seq_length}.")

    def get_seq_length(self, idx: int) -> int:
        if self.future_conditional:
            return self.future_seq_len + self.window
        else:
            return self.window

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        i, start, end = self.slices[idx]
        values = [x[start:end] for x in self.dataset[i]]

        if self.future_conditional:
            valid_start_range = (
                end + self.min_future_sep,
                self.dataset.get_seq_length(i) - self.future_seq_len,
            )
            if valid_start_range[0] < valid_start_range[1]:
                if self.only_sample_tail:
                    future_obs = self.dataset[i][0][-self.future_seq_len:]
                else:
                    start = np.random.randint(*valid_start_range)
                    end = start + self.future_seq_len
                    future_obs = self.dataset[i][0][start:end]
            else:
                # zeros placeholder T x obs_dim
                _, obs_dim = values[0].shape
                future_obs = torch.zeros((self.future_seq_len, obs_dim))
            values.append(future_obs)

        if self.transform is not None:
            values = self.transform(values)
        return tuple(values)

def random_split_traj(
    dataset: TrajectoryDataset,
    lengths: Sequence[int],
    generator: Optional[torch.Generator] = default_generator,
) -> List[TrajectorySubset]:
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()
    return [
        TrajectorySubset(dataset, indices[offset - length : offset])
        for offset, length in zip(accumulate(lengths), lengths)
    ]

def split_traj_datasets(dataset, train_fraction=0.95, random_seed=42):
    dataset_length = len(dataset)
    lengths = [
        int(train_fraction * dataset_length),
        dataset_length - int(train_fraction * dataset_length),
    ]
    generator = torch.Generator().manual_seed(random_seed)
    return random_split_traj(dataset, lengths, generator=generator)

def get_train_val_sliced(
    traj_dataset: TrajectoryDataset,
    train_fraction: float = 0.9,
    random_seed: int = 42,
    window_size: int = 10,
    future_conditional: bool = False,
    min_future_sep: int = 0,
    future_seq_len: Optional[int] = None,
    only_sample_tail: bool = False,
    transform: Optional[Callable[[Any], Any]] = None,
):
    train, val = split_traj_datasets(
        traj_dataset,
        train_fraction=train_fraction,
        random_seed=random_seed,
    )
    traj_slicer_kwargs = {
        "window": window_size,
        "future_conditional": future_conditional,
        "min_future_sep": min_future_sep,
        "future_seq_len": future_seq_len,
        "only_sample_tail": only_sample_tail,
        "transform": transform,
    }
    train_slices = TrajectorySlicerDataset(train, **traj_slicer_kwargs)
    val_slices = TrajectorySlicerDataset(val, **traj_slicer_kwargs)
    return train_slices, val_slices

def transpose_batch_timestep(*args):
    return (einops.rearrange(arg, "t n ... -> n t ...") for arg in args)

def run_kitchen_dataset():
    try:
        # Configuration
        data_dir = "/content/drive/MyDrive/franka/kitchen"
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print(f"Using device: {device}")

        # Create base dataset
        print("\nCreating base dataset...")
        dataset = RelayKitchenTrajectoryDataset(
            data_directory=data_dir,
            device=device,
            onehot_goals=True
        )
        print(f"Base dataset created with size: {len(dataset)}")

        # Create train and validation datasets
        print("\nCreating train/val splits...")
        train_dataset, val_dataset = get_train_val_sliced(
            dataset,
            train_fraction=0.9,
            random_seed=42,
            window_size=16,
            future_conditional=False
        )

        print("\nDataset sizes:")
        print(f"Full dataset: {len(dataset)}")
        print(f"Training slices: {len(train_dataset)}")
        print(f"Validation slices: {len(val_dataset)}")

        # Create dataloaders
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=32,
            shuffle=True,
            num_workers=2,
            pin_memory=True
        )

        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

        # Test batch loading
        print("\nTesting batch loading:")
        for batch in train_loader:
            print("Batch shapes:")
            for i, item in enumerate(batch):
                print(f"Item {i} shape: {item.shape}")
            break

        return dataset, train_dataset, val_dataset, train_loader, val_loader

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return None

if __name__ == "__main__":
    result = run_kitchen_dataset()
    if result:
        dataset, train_dataset, val_dataset, train_loader, val_loader = result
        print("\nAll components created successfully!")
    else:
        print("\nFailed to create dataset components")

Using device: cpu

Creating base dataset...
Loading data from /content/drive/MyDrive/franka/kitchen

Original shapes:
Observations: torch.Size([409, 566, 60])
Actions: torch.Size([409, 566, 9])
Masks: torch.Size([409, 566])

Transposed shapes:
Observations: torch.Size([566, 409, 60])
Actions: torch.Size([566, 409, 9])
Masks: torch.Size([566, 409])
Using observations as goals...


  goals = torch.load(data_directory / "onehot_goals.pth")


Base dataset created with size: 566

Creating train/val splits...

Dataset sizes:
Full dataset: 566
Training slices: 107387
Validation slices: 12242

Testing batch loading:
Batch shapes:
Item 0 shape: torch.Size([32, 16, 60])
Item 1 shape: torch.Size([32, 16, 9])
Item 2 shape: torch.Size([32, 16, 60])

All components created successfully!


##now train

In [None]:
import torch
import numpy as np
import einops
from pathlib import Path
from torch.utils.data import Dataset, TensorDataset, Subset
from torch import default_generator, randperm
from itertools import accumulate
from typing import Any, Callable, List, Optional, Sequence
import abc
from behavior_transformer import BehaviorTransformer, GPT, GPTConfig
import wandb
from tqdm import tqdm
import os


class KitchenTrainer:
    def __init__(self, config=None):
        self.config = {
            # Model parameters
            'obs_dim': 60,
            'act_dim': 9,
            'goal_dim': 60,
            'n_layer': 6,
            'n_head': 8,
            'n_embd': 256,
            'block_size': 144,
            'n_clusters': 32,

            # Training parameters
            'batch_size': 32,
            'learning_rate': 1e-5,
            'weight_decay': 2e-4,
            'betas': [0.9, 0.999],
            'num_epochs': 100,
            'save_freq': 10,
            'eval_freq': 5,

            # Paths
            'save_dir': '/content/drive/MyDrive/franka/kitchen/checkpoints',
            'data_dir': '/content/drive/MyDrive/franka/kitchen',

            # Device
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',

            # Wandb config
            'use_wandb': True,
            'wandb_project': 'kitchen-cbet',
            'wandb_entity': None,  # Your wandb username
            'experiment_name': 'kitchen-training'
        }
        if config:
            self.config.update(config)

        # Create save directory
        os.makedirs(self.config['save_dir'], exist_ok=True)

        # Initialize wandb
        if self.config['use_wandb']:
            wandb.init(
                project=self.config['wandb_project'],
                entity=self.config['wandb_entity'],
                config=self.config,
                name=self.config['experiment_name']
            )

    def create_model(self):
        """Create CBET model"""
        model = BehaviorTransformer(
            obs_dim=self.config['obs_dim'],
            act_dim=self.config['act_dim'],
            goal_dim=self.config['goal_dim'],
            gpt_model=GPT(
                GPTConfig(
                    block_size=self.config['block_size'],
                    input_dim=self.config['obs_dim'],
                    n_layer=self.config['n_layer'],
                    n_head=self.config['n_head'],
                    n_embd=self.config['n_embd'],
                )
            ),
            n_clusters=self.config['n_clusters'],
            kmeans_fit_steps=5,
        ).to(self.config['device'])

        return model

    def save_checkpoint(self, model, optimizer, epoch, loss, path):
        """Save model checkpoint"""
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'config': self.config
        }, path)

    def load_checkpoint(self, model, optimizer, path):
        """Load model checkpoint"""
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['loss']

    def train(self, train_loader, val_loader, checkpoint_path=None):
        """Main training loop"""
        print("Initializing training...")
        print(f"Using device: {self.config['device']}")

        # Create model
        model = self.create_model()
        print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

        # Create optimizer
        optimizer = model.configure_optimizers(
            weight_decay=self.config['weight_decay'],
            learning_rate=self.config['learning_rate'],
            betas=self.config['betas'],
        )

        # Load checkpoint if provided
        start_epoch = 0
        if checkpoint_path:
            start_epoch, _ = self.load_checkpoint(model, optimizer, checkpoint_path)
            print(f"Loaded checkpoint from epoch {start_epoch}")

        print("Starting training...")
        for epoch in range(start_epoch, self.config['num_epochs']):
            # Training phase
            model.train()
            train_loss = 0
            train_batches = 0

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config['num_epochs']}")
            for batch in progress_bar:
                # Move batch to device
                obs, acts, goals = [x.to(self.config['device']) for x in batch]

                # Forward pass
                optimizer.zero_grad()
                pred_actions, loss, loss_dict = model(obs, goals, acts)

                # Backward pass
                loss.backward()
                optimizer.step()

                # Update metrics
                train_loss += loss.item()
                train_batches += 1

                # Update progress bar
                progress_bar.set_postfix({'loss': loss.item()})

                # Log to wandb
                if self.config['use_wandb']:
                    wandb.log({
                        'train/loss': loss.item(),
                        'train/epoch': epoch,
                        **{f"train/{k}": v.item() if isinstance(v, torch.Tensor) else v
                           for k, v in loss_dict.items()}
                    })

            avg_train_loss = train_loss / train_batches
            print(f"\nEpoch {epoch+1} - Average training loss: {avg_train_loss:.4f}")

            # Evaluation phase
            if (epoch + 1) % self.config['eval_freq'] == 0:
                model.eval()
                val_loss = 0
                val_batches = 0

                with torch.no_grad():
                    for batch in tqdm(val_loader, desc="Evaluating"):
                        obs, acts, goals = [x.to(self.config['device']) for x in batch]
                        _, loss, loss_dict = model(obs, goals, acts)
                        val_loss += loss.item()
                        val_batches += 1

                        if self.config['use_wandb']:
                            wandb.log({
                                'val/loss': loss.item(),
                                'val/epoch': epoch,
                                **{f"val/{k}": v.item() if isinstance(v, torch.Tensor) else v
                                   for k, v in loss_dict.items()}
                            })

                avg_val_loss = val_loss / val_batches
                print(f"Validation loss: {avg_val_loss:.4f}")

            # Save checkpoint
            if (epoch + 1) % self.config['save_freq'] == 0:
                save_path = Path(self.config['save_dir']) / f"checkpoint_epoch_{epoch+1}.pt"
                self.save_checkpoint(model, optimizer, epoch, avg_train_loss, save_path)
                print(f"Checkpoint saved: {save_path}")

        # Save final model
        final_path = Path(self.config['save_dir']) / "final_model.pt"
        self.save_checkpoint(model, optimizer, self.config['num_epochs']-1, avg_train_loss, final_path)
        print("Training complete!")

        if self.config['use_wandb']:
            wandb.finish()

        return model

def main():
    # Create dataset

    result = run_kitchen_dataset()

    if not result:
        print("Failed to create datasets")
        return

    dataset, train_dataset, val_dataset, train_loader, val_loader = result

    # Create trainer and train
    trainer = KitchenTrainer()
    model = trainer.train(train_loader, val_loader)

if __name__ == "__main__":
    main()

Using device: cpu

Creating base dataset...
Loading data from /content/drive/MyDrive/franka/kitchen

Original shapes:
Observations: torch.Size([409, 566, 60])
Actions: torch.Size([409, 566, 9])
Masks: torch.Size([409, 566])

Transposed shapes:
Observations: torch.Size([566, 409, 60])
Actions: torch.Size([566, 409, 9])
Masks: torch.Size([566, 409])
Using observations as goals...
Base dataset created with size: 566

Creating train/val splits...


  goals = torch.load(data_directory / "onehot_goals.pth")



Dataset sizes:
Full dataset: 566
Training slices: 107387
Validation slices: 12242

Testing batch loading:
Batch shapes:
Item 0 shape: torch.Size([32, 16, 60])
Item 1 shape: torch.Size([32, 16, 9])
Item 2 shape: torch.Size([32, 16, 60])


  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Initializing training...
Using device: cpu
number of parameters: 4.86M
Model created with 4939328 parameters
Starting training...


Epoch 1/100:   0%|          | 4/3356 [00:04<1:00:44,  1.09s/it, loss=0]
  0%|          | 0/50 [00:00<?, ?it/s][A
K-means clustering:   0%|          | 0/50 [00:00<?, ?it/s][A
K-means clustering:  38%|███▊      | 19/50 [00:00<00:00, 187.59it/s][A
K-means clustering: 100%|██████████| 50/50 [00:00<00:00, 175.32it/s]
Epoch 1/100:   6%|▌         | 198/3356 [03:08<54:16,  1.03s/it, loss=35.1]