# Train a proprioception-tuned Vision Transformer (ViT)

We create a sensor processing model using Vision Transformer (ViT) based visual encoding finetuned with proprioception.

We start with a pretrained ViT model, then train it to:
1. Create a meaningful 128-dimensional latent representation
2. Learn to map this representation to robot positions (proprioception)

The sensor processing object associated with the trained model is in sensorprocessing/sp_vit.py

In [1]:
import sys
sys.path.append("..")

from settings import Config

import pathlib
import torch
import torch.nn as nn
from torchvision import models, transforms
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from behavior_cloning.demo_to_trainingdata import BCDemonstration
from sensorprocessing.sp_vit import VitSensorProcessing
from robot.al5d_position_controller import RobotPosition

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# The experiment/run we are going to run: the specified model will be created


experiment = "sensorprocessing_propriotuned_Vit"
# Other possible configurations:
# run = "vit_base"  # ViT Base
# run = "vit_large" # ViT Large
run = "vit_huge" # ViT Huge

exp = Config().get_experiment(experiment, run)

Loading pointer config file: /home/ssheikholeslami/.config/BerryPicker/mainsettings.yaml
Loading machine-specific config file: /home/ssheikholeslami/SaharaBerryPickerData/settings-sahara.yaml
No system dependent experiment file
 /home/ssheikholeslami/SaharaBerryPickerData/experiments-Config/sensorprocessing_propriotuned_Vit/vit_huge_sysdep.yaml,
 that is ok, proceeding.
Configuration for experiment: sensorprocessing_propriotuned_Vit/vit_huge successfully loaded


### Create regression training data (image to proprioception)
The training data (X, Y) is all the pictures from a demonstration with the corresponding proprioception data.

In [3]:
def load_images_as_proprioception_training(task, proprioception_input_file, proprioception_target_file):
    """Loads all the images of a task, and processes it as two tensors as input and target data for proprioception training.
    Caches the processed results into the input and target file pointed in the config. Remove those files to recalculate.
    """
    retval = {}
    if proprioception_input_file.exists():
        retval["inputs"] = torch.load(proprioception_input_file, weights_only=True)
        retval["targets"] = torch.load(proprioception_target_file, weights_only=True)
    else:
        demos_dir = pathlib.Path(Config()["demos"]["directory"])
        task_dir = pathlib.Path(demos_dir, "demos", task)

        inputlist = []
        targetlist = []

        print(f"Loading demonstrations from {task_dir}")
        for demo_dir in task_dir.iterdir():
            if not demo_dir.is_dir():
                continue
            print(f"Processing demonstration: {demo_dir.name}")
            bcd = BCDemonstration(demo_dir, sensorprocessor=None)
            for i in range(bcd.trim_from, bcd.trim_to):
                sensor_readings, _ = bcd.get_image(i)
                inputlist.append(sensor_readings[0])
                a = bcd.get_a(i)
                rp = RobotPosition.from_vector(a)
                anorm = rp.to_normalized_vector()
                targetlist.append(torch.from_numpy(anorm))

        retval["inputs"] = torch.stack(inputlist)
        retval["targets"] = torch.stack(targetlist)
        torch.save(retval["inputs"], proprioception_input_file)
        torch.save(retval["targets"], proprioception_target_file)
        print(f"Saved {len(inputlist)} training examples")

    # Separate the training and validation data.
    # We will be shuffling the demonstrations
    length = retval["inputs"].size(0)
    rows = torch.randperm(length)
    shuffled_inputs = retval["inputs"][rows]
    shuffled_targets = retval["targets"][rows]

    training_size = int(length * 0.67)
    retval["inputs_training"] = shuffled_inputs[1:training_size]
    retval["targets_training"] = shuffled_targets[1:training_size]

    retval["inputs_validation"] = shuffled_inputs[training_size:]
    retval["targets_validation"] = shuffled_targets[training_size:]

    print(f"Created {retval['inputs_training'].size(0)} training examples and {retval['inputs_validation'].size(0)} validation examples")
    return retval

In [4]:
# Create output directory if it doesn't exist
data_dir = pathlib.Path(exp["data_dir"])
data_dir.mkdir(parents=True, exist_ok=True)
print(f"Data directory: {data_dir}")

task = exp["proprioception_training_task"]
proprioception_input_file = pathlib.Path(exp["data_dir"], exp["proprioception_input_file"])
proprioception_target_file = pathlib.Path(exp["data_dir"], exp["proprioception_target_file"])

tr = load_images_as_proprioception_training(task, proprioception_input_file, proprioception_target_file)
inputs_training = tr["inputs_training"]
targets_training = tr["targets_training"]
inputs_validation = tr["inputs_validation"]
targets_validation = tr["targets_validation"]

Data directory: /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit/vit_huge
Loading demonstrations from /home/ssheikholeslami/SaharaBerryPickerData/demonstrations/demos/proprio_regressor_training
Processing demonstration: 2025_03_08__15_02_56
Cameras found: ['dev3']
There are 445 steps in this demonstration
This demonstration was recorded by the following cameras: ['dev3']


Processing demonstration: 2025_03_08__15_06_47
Cameras found: ['dev3']
There are 468 steps in this demonstration
This demonstration was recorded by the following cameras: ['dev3']
Processing demonstration: 2025_03_08__15_05_47
Cameras found: ['dev3']
There are 410 steps in this demonstration
This demonstration was recorded by the following cameras: ['dev3']
Processing demonstration: 2025_03_08__15_01_56
Cameras found: ['dev3']
There are 384 steps in this demonstration
This demonstration was recorded by the following cameras: ['dev3']
Saved 1703 training examples
Created 1140 training examples and 562 validation examples


### Create the ViT model with proprioception regression

In [5]:
# Create the ViT model with proprioception
sp = VitSensorProcessing(exp, device)
model = sp.enc  # Get the actual encoder model for training


# Debug code

print("Model created successfully")

try:
    params = model.parameters()
    print("Parameters accessed successfully")
    param_count = sum(p.numel() for p in params)
    print(f"Total parameters: {param_count}")
except Exception as e:
    print(f"Error accessing parameters: {e}")

    # Check individual components
    try:
        backbone_params = model.backbone.parameters()
        print("Backbone parameters accessed successfully")
    except Exception as e:
        print(f"Error accessing backbone parameters: {e}")

    try:
        projection_params = model.projection.parameters()
        print("Projection parameters accessed successfully")
    except Exception as e:
        print(f"Error accessing projection parameters: {e}")

    try:
        proprioceptor_params = model.proprioceptor.parameters()
        print("Proprioceptor parameters accessed successfully")
    except Exception as e:
        print(f"Error accessing proprioceptor parameters: {e}")

# Select loss function
loss_type = exp.get('loss', 'MSELoss')
if loss_type == 'MSELoss':
    criterion = nn.MSELoss()
elif loss_type == 'L1Loss':
    criterion = nn.L1Loss()
else:
    criterion = nn.MSELoss()  # Default to MSE

# Set up optimizer with appropriate learning rate and weight decay
optimizer = optim.Adam(
    model.parameters(),
    lr=exp.get('learning_rate', 0.001),
    weight_decay=exp.get('weight_decay', 0.01)
)

# Optional learning rate scheduler
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

Initializing ViT Sensor Processing:
  Model: vit_h_14
  Latent dimension: 128
  Image size: 518x518


Using vit_h_14 with output dimension 1280
Created projection network: 1280 → 1024 → 512 → 128
Created latent representation: 1280 → 1024 → 128
Created proprioceptor: 128 → 64 → 64 → 6
Feature extractor frozen. Projection and proprioceptor layers are trainable.
Model created successfully
Parameters accessed successfully
Total parameters: 634107526




In [6]:
# Create DataLoaders for batching
batch_size = exp.get('batch_size', 32)
train_dataset = TensorDataset(inputs_training, targets_training)
test_dataset = TensorDataset(inputs_validation, targets_validation)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
def train_and_save_proprioception_model(model, criterion, optimizer, modelfile,
                                        device="cpu", epochs=20, scheduler=None,
                                        log_interval=1):
    """Trains and saves the ViT proprioception model

    Args:
        model: ViT model with proprioception
        criterion: Loss function
        optimizer: Optimizer
        modelfile: Path to save the model
        device: Training device (cpu/cuda)
        epochs: Number of training epochs
        scheduler: Optional learning rate scheduler
        log_interval: How often to print logs
    """
    # Ensure model is on the right device
    model = model.to(device)
    criterion = criterion.to(device)

    # Keep track of the best validation loss
    best_val_loss = float('inf')

    # Training loop
    num_epochs = epochs
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_loss = 0
        for batch_X, batch_y in train_loader:
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)

            # Forward pass through the full model (including proprioceptor)
            predictions = model.forward(batch_X)
            loss = criterion(predictions, batch_y)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch_X, batch_y in test_loader:
                batch_X = batch_X.to(device)
                batch_y = batch_y.to(device)
                predictions = model(batch_X)
                loss = criterion(predictions, batch_y)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(test_loader)

        # Update learning rate if scheduler is provided
        if scheduler is not None:
            scheduler.step(avg_val_loss)

        # Save the best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), modelfile)
            print(f"  New best model saved with validation loss: {best_val_loss:.4f}")

        # Log progress
        if (epoch + 1) % log_interval == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    # Final evaluation
    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")
    return model

In [8]:
modelfile = pathlib.Path(exp["data_dir"], exp["proprioception_mlp_model_file"])
epochs = exp.get("epochs", 20)

# Check if model already exists
if modelfile.exists() and exp.get("reload_existing_model", True):
    print(f"Loading existing model from {modelfile}")
    model.load_state_dict(torch.load(modelfile, map_location=device))

    # Optional: evaluate the loaded model
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for batch_X, batch_y in test_loader:
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)
            predictions = model(batch_X)
            loss = criterion(predictions, batch_y)
            val_loss += loss.item()

        avg_val_loss = val_loss / len(test_loader)
        print(f"Loaded model validation loss: {avg_val_loss:.4f}")
else:
    print(f"Training new model for {epochs} epochs")
    model = train_and_save_proprioception_model(
        model, criterion, optimizer, modelfile,
        device=device, epochs=epochs, scheduler=lr_scheduler
    )

Training new model for 300 epochs
  New best model saved with validation loss: 0.0300
Epoch [1/300], Train Loss: 0.0821, Val Loss: 0.0300
  New best model saved with validation loss: 0.0254
Epoch [2/300], Train Loss: 0.0256, Val Loss: 0.0254
  New best model saved with validation loss: 0.0228
Epoch [3/300], Train Loss: 0.0244, Val Loss: 0.0228
  New best model saved with validation loss: 0.0221
Epoch [4/300], Train Loss: 0.0242, Val Loss: 0.0221
Epoch [5/300], Train Loss: 0.0241, Val Loss: 0.0238
  New best model saved with validation loss: 0.0220
Epoch [6/300], Train Loss: 0.0249, Val Loss: 0.0220
  New best model saved with validation loss: 0.0195
Epoch [7/300], Train Loss: 0.0254, Val Loss: 0.0195
Epoch [8/300], Train Loss: 0.0252, Val Loss: 0.0220
Epoch [9/300], Train Loss: 0.0253, Val Loss: 0.0266
Epoch [10/300], Train Loss: 0.0250, Val Loss: 0.0234
Epoch [11/300], Train Loss: 0.0252, Val Loss: 0.0213
  New best model saved with validation loss: 0.0172
Epoch [12/300], Train Loss: 

### Test the trained model

In [None]:
# Create the sensor processing module using the trained model
sp = VitSensorProcessing(exp, device)

# Test it on a few validation examples
def test_sensor_processing(sp, test_images, test_targets, n_samples=5):
    """Test the sensor processing module on a few examples."""
    if n_samples > len(test_images):
        n_samples = len(test_images)

    # Get random indices
    indices = torch.randperm(len(test_images))[:n_samples]

    print("\nTesting sensor processing on random examples:")
    print("-" * 50)

    for i, idx in enumerate(indices):
        # Get image and target
        image = test_images[idx].unsqueeze(0).to(device)  # Add batch dimension
        target = test_targets[idx].cpu().numpy()

        # Process the image to get the latent representation
        latent = sp.process(image)

        # Print the results
        print(f"Example {i+1}:")
        print(f"  Image shape: {image.shape}")
        print(f"  Latent shape: {latent.shape}")
        print(f"  Target position: {target}")
        print()

# Test the sensor processing
test_sensor_processing(sp, inputs_validation, targets_validation)

Initializing ViT Sensor Processing:
  Model: vit_l_16
  Latent dimension: 128
  Image size: 224x224


Using vit_l_16 with output dimension 1024
Created projection network: 1024 → 512 → 256 → 128
Created latent representation: 1024 → 512 → 128
Created proprioceptor: 128 → 64 → 64 → 6
Feature extractor frozen. Projection and proprioceptor layers are trainable.
Loading ViT encoder weights from /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit/vit_large/proprioception_mlp.pth


  self.enc.load_state_dict(torch.load(modelfile, map_location=device))



Testing sensor processing on random examples:
--------------------------------------------------
Example 1:
  Image shape: torch.Size([1, 3, 256, 256])
  Latent shape: (128,)
  Target position: [0.06060451 0.7903572  0.491293   0.704839   0.4744848  0.30784598]

Example 2:
  Image shape: torch.Size([1, 3, 256, 256])
  Latent shape: (128,)
  Target position: [0.5747702  0.78994364 0.6641628  0.33462417 0.29938844 0.71034545]

Example 3:
  Image shape: torch.Size([1, 3, 256, 256])
  Latent shape: (128,)
  Target position: [0.8682212 0.8242293 0.4242512 0.5075824 0.7660551 0.457051 ]

Example 4:
  Image shape: torch.Size([1, 3, 256, 256])
  Latent shape: (128,)
  Target position: [0.767953   0.702578   0.16808294 0.7865119  0.38399073 0.20142773]

Example 5:
  Image shape: torch.Size([1, 3, 256, 256])
  Latent shape: (128,)
  Target position: [0.3918777 0.5043456 0.8729367 0.8217159 0.833581  0.8309385]



### Verify the model's encoding and forward methods

In [None]:
# Verify that the encoding method works correctly
model.eval()
with torch.no_grad():
    # Get a sample image
    sample_image = inputs_validation[0].unsqueeze(0).to(device)

    # Get the latent representation using encode
    latent = model.encode(sample_image)
    print(f"Latent representation shape: {latent.shape}")

    # Get the robot position prediction using forward
    position = model.forward(sample_image)
    print(f"Robot position prediction shape: {position.shape}")

    # Check that the latent representation has the expected size
    expected_latent_size = exp["latent_size"]
    assert latent.shape[1] == expected_latent_size, f"Expected latent size {expected_latent_size}, got {latent.shape[1]}"

    # Check that the position prediction has the expected size
    expected_output_size = exp["output_size"]
    assert position.shape[1] == expected_output_size, f"Expected output size {expected_output_size}, got {position.shape[1]}"

    print("Verification successful!")

Latent representation shape: torch.Size([1, 128])
Robot position prediction shape: torch.Size([1, 6])
Verification successful!


### Save final model and summary

In [None]:
# Save the model and print summary
final_modelfile = pathlib.Path(exp["data_dir"], exp["proprioception_mlp_model_file"])
torch.save(model.state_dict(), final_modelfile)
print(f"Model saved to {final_modelfile}")

print("\nTraining complete!")
print(f"Vision Transformer type: {exp['vit_model']}")
print(f"Latent space dimension: {exp['latent_size']}")
print(f"Output dimension (robot DOF): {exp['output_size']}")
print(f"Use the VitSensorProcessing class to load and use this model for inference.")

Model saved to /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit/vit_large/proprioception_mlp.pth

Training complete!
Vision Transformer type: vit_l_16
Latent space dimension: 128
Output dimension (robot DOF): 6
Use the VitSensorProcessing class to load and use this model for inference.
