# Train an LSTM based controller 

Train and save an LSTM-based controller. It contains:
* Code for loading and pre-processing the training data. 
* Training an LSTM with specific parameters and saving it

In [13]:
import sys
sys.path.append("..")
from settings import Config

import pathlib
#from pprint import pformat
from tqdm import tqdm

#import matplotlib.pyplot as plt

import torch
import torch.nn as nn
#import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

from sensorprocessing import sp_conv_vae
from demo_to_trainingdata import create_RNN_training_sequence_xy, BCDemonstration
from bc_LSTM import LSTMXYPredictor, LSTMResidualController

from tensorboardX import SummaryWriter


### Creating training and validation data
Create training and validation data from all the demonstrations of a certain task.

In [14]:
task = "proprioception-uncluttered"

conv_vae_jsonfile = pathlib.Path(Config()["controller"]["vae_json"])
conv_vae_model_pthfile = pathlib.Path(Config()["controller"]["vae_model"])


sp = sp_conv_vae.ConvVaeSensorProcessing(conv_vae_jsonfile,
                                         conv_vae_model_pthfile)

demos_dir = pathlib.Path(Config()["demos"]["directory"])
task_dir = pathlib.Path(demos_dir, "demos", task)

inputlist = []
targetlist = []

for demo_dir in task_dir.iterdir():
    if not demo_dir.is_dir():
        pass
    bcd = BCDemonstration(demo_dir, sensorprocessor=sp)
    print(bcd)
    z, a = bcd.read_z_a()
    print(z.shape)
    print(a.shape)
    # FIXME the repeated name for inputs and targets
    inputs, targets = create_RNN_training_sequence_xy(z, a, sequence_length=10)
    inputlist.append(inputs)
    targetlist.append(targets)

inputs = torch.cat(inputlist)
targets = torch.cat(targetlist)

# Separate the training and validation data. 
# We will be shuffling the demonstrations 
rows = torch.randperm(inputs.size(0)) 
shuffled_inputs = inputs[rows]
shuffled_targets = targets[rows]

training_size = int( inputs.size(0) * 0.67 )
inputs_training = shuffled_inputs[1:training_size]
targets_training = shuffled_targets[1:training_size]

inputs_validation = shuffled_inputs[training_size:]
targets_validation = shuffled_targets[training_size:]

Cameras found: ['dev2']
There are 753 steps in this demonstration
This demonstration was recorded by the following cameras: ['dev2']
{'actiontype': 'rc-position-target',
 'camera': 'dev2',
 'cameras': ['dev2'],
 'maxsteps': 753,
 'sensorprocessor': <sensorprocessing.sp_conv_vae.ConvVaeSensorProcessing object at 0x71006e3d3370>,
 'source_dir': PosixPath('/home/lboloni/Documents/Hackingwork/__Temporary/BerryPicker-demos/demos/proprioception-uncluttered/2024_10_26__16_31_40'),
 'trim_from': 1,
 'trim_to': 753}
(752, 128)
(752, 6)
Cameras found: ['dev2']
There are 968 steps in this demonstration
This demonstration was recorded by the following cameras: ['dev2']
{'actiontype': 'rc-position-target',
 'camera': 'dev2',
 'cameras': ['dev2'],
 'maxsteps': 968,
 'sensorprocessor': <sensorprocessing.sp_conv_vae.ConvVaeSensorProcessing object at 0x71006e3d3370>,
 'source_dir': PosixPath('/home/lboloni/Documents/Hackingwork/__Temporary/BerryPicker-demos/demos/proprioception-uncluttered/2024_10_26__

In [15]:
def validate_behavior_cloning(model, criterion, inputs_validation, targets_validation):
    num_sequences = inputs_validation.shape[0]
    model.eval()
    val_loss = 0
    with torch.no_grad():  # Disable gradient computation
        for i in range(num_sequences):
            # Forward pass
            input_seq = inputs_validation[i]
            target = targets_validation[i]
            # Reshape for batch compatibility
            input_seq = input_seq.unsqueeze(0)  # Shape: [1, sequence_length, latent_size]
            target = target.unsqueeze(0)        # Shape: [1, latent_size]

            outputs = model(input_seq)
            loss = criterion(outputs, target)
            # Accumulate loss
            val_loss += loss.item()
    avg_loss = val_loss / num_sequences
    return avg_loss

def train_behavior_cloning(model, optimizer, criterion, inputs_training, targets_training, inputs_validation, targets_validation, num_epochs, writer = None):
    num_sequences = inputs_training.shape[0]

    for epoch in tqdm(range(num_epochs)):
        model.train()
        
        # Loop over each sequence in the batch
        training_loss = 0
        for i in range(num_sequences):
            # Prepare input and target
            input_seq = inputs_training[i]
            target = targets_training[i]

            # Reshape for batch compatibility
            input_seq = input_seq.unsqueeze(0)  # Shape: [1, sequence_length, latent_size]
            target = target.unsqueeze(0)        # Shape: [1, latent_size]

            # Forward pass
            output = model(input_seq)
            loss = criterion(output, target)
            training_loss += loss.item()
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        avg_training_loss = training_loss / num_sequences
        avg_validation_loss = validate_behavior_cloning(model, criterion, inputs_validation=inputs_validation, targets_validation=targets_validation)
        if writer is not None:
            writer.add_scalar("TrainingLoss", avg_training_loss, epoch)
            writer.add_scalar("ValidationLoss", avg_validation_loss, epoch)
            writer.flush()


        if (epoch+1) % 2 == 0: # was 0
            print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_training_loss:.4f} Validation Loss: {avg_validation_loss:.4f} ')




In [None]:
# Original
latent_size = Config()["robot"]["latent_encoding_size"]  
hidden_size = 32  # degrees of freedom in the robot
output_size = 6  # degrees of freedom in the robot
num_layers = 2

# Instantiate model, loss function, and optimizer
model = LSTMXYPredictor(latent_size=latent_size, hidden_size=hidden_size, output_size = output_size, num_layers=num_layers)
criterion = nn.MSELoss()  # Mean Squared Error for regression
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training Loop
num_epochs = 100

# Create a SummaryWriter instance
# where does the logdir go???
writer = SummaryWriter(logdir="/home/lboloni/runs/example")

train_behavior_cloning(
    model, optimizer, criterion,
    inputs_training=inputs_training, 
    targets_training=targets_training, 
    inputs_validation=inputs_validation,
    targets_validation=targets_validation,
    num_epochs=num_epochs, writer=writer)
print("Training complete.")
writer.close()


  0%|          | 0/100 [00:00<?, ?it/s]

  2%|▏         | 2/100 [00:14<12:09,  7.44s/it]

Epoch [2/100], Training Loss: 681.1389 Validation Loss: 673.9735 


  4%|▍         | 4/100 [00:30<12:33,  7.85s/it]

Epoch [4/100], Training Loss: 657.6181 Validation Loss: 673.2660 


  6%|▌         | 6/100 [00:46<12:11,  7.79s/it]

Epoch [6/100], Training Loss: 647.3553 Validation Loss: 655.2405 


  8%|▊         | 8/100 [01:01<11:40,  7.61s/it]

Epoch [8/100], Training Loss: 605.5764 Validation Loss: 592.3224 


 10%|█         | 10/100 [01:16<11:18,  7.54s/it]

Epoch [10/100], Training Loss: 430.0906 Validation Loss: 398.9063 


 12%|█▏        | 12/100 [01:30<10:49,  7.38s/it]

Epoch [12/100], Training Loss: 320.6419 Validation Loss: 311.4591 


 14%|█▍        | 14/100 [01:45<10:39,  7.44s/it]

Epoch [14/100], Training Loss: 278.7717 Validation Loss: 280.1741 


 16%|█▌        | 16/100 [02:00<10:21,  7.40s/it]

Epoch [16/100], Training Loss: 266.2260 Validation Loss: 263.6351 


 18%|█▊        | 18/100 [02:15<10:02,  7.35s/it]

Epoch [18/100], Training Loss: 241.3722 Validation Loss: 230.0230 


 20%|██        | 20/100 [02:29<09:41,  7.27s/it]

Epoch [20/100], Training Loss: 221.4708 Validation Loss: 232.3125 


 22%|██▏       | 22/100 [02:44<09:27,  7.28s/it]

Epoch [22/100], Training Loss: 218.3173 Validation Loss: 250.2005 


 24%|██▍       | 24/100 [02:58<09:12,  7.27s/it]

Epoch [24/100], Training Loss: 197.5010 Validation Loss: 174.7614 


 26%|██▌       | 26/100 [03:12<08:49,  7.16s/it]

Epoch [26/100], Training Loss: 211.1167 Validation Loss: 237.6641 


 28%|██▊       | 28/100 [03:27<08:38,  7.20s/it]

Epoch [28/100], Training Loss: 198.4672 Validation Loss: 170.8068 


 30%|███       | 30/100 [03:42<08:33,  7.34s/it]

Epoch [30/100], Training Loss: 195.5745 Validation Loss: 230.7051 


 32%|███▏      | 32/100 [03:56<08:18,  7.33s/it]

Epoch [32/100], Training Loss: 169.2372 Validation Loss: 159.3140 


 34%|███▍      | 34/100 [04:11<08:03,  7.33s/it]

Epoch [34/100], Training Loss: 177.7108 Validation Loss: 173.4800 


 36%|███▌      | 36/100 [04:25<07:41,  7.21s/it]

Epoch [36/100], Training Loss: 165.8977 Validation Loss: 176.2879 


 38%|███▊      | 38/100 [04:39<07:22,  7.14s/it]

Epoch [38/100], Training Loss: 163.1013 Validation Loss: 155.0908 


 40%|████      | 40/100 [04:54<07:09,  7.16s/it]

Epoch [40/100], Training Loss: 171.2597 Validation Loss: 155.7435 


 42%|████▏     | 42/100 [05:08<06:50,  7.08s/it]

Epoch [42/100], Training Loss: 157.4768 Validation Loss: 158.5668 


 44%|████▍     | 44/100 [05:22<06:31,  7.00s/it]

Epoch [44/100], Training Loss: 164.1668 Validation Loss: 156.4378 


 46%|████▌     | 46/100 [05:36<06:17,  6.99s/it]

Epoch [46/100], Training Loss: 160.3865 Validation Loss: 153.0587 


 48%|████▊     | 48/100 [05:50<06:15,  7.22s/it]

Epoch [48/100], Training Loss: 160.3768 Validation Loss: 167.5916 


 50%|█████     | 50/100 [06:06<06:11,  7.42s/it]

Epoch [50/100], Training Loss: 157.2570 Validation Loss: 152.8928 


 52%|█████▏    | 52/100 [06:20<05:53,  7.37s/it]

Epoch [52/100], Training Loss: 158.5003 Validation Loss: 152.5841 


 54%|█████▍    | 54/100 [06:36<05:44,  7.50s/it]

Epoch [54/100], Training Loss: 153.1759 Validation Loss: 151.1868 


 56%|█████▌    | 56/100 [06:51<05:30,  7.52s/it]

Epoch [56/100], Training Loss: 151.6233 Validation Loss: 149.2172 


 58%|█████▊    | 58/100 [07:06<05:21,  7.65s/it]

Epoch [58/100], Training Loss: 157.1310 Validation Loss: 154.3923 


 60%|██████    | 60/100 [07:22<05:06,  7.66s/it]

Epoch [60/100], Training Loss: 146.8540 Validation Loss: 148.6823 


 62%|██████▏   | 62/100 [07:36<04:45,  7.52s/it]

Epoch [62/100], Training Loss: 146.7536 Validation Loss: 145.9733 


 64%|██████▍   | 64/100 [07:51<04:27,  7.42s/it]

Epoch [64/100], Training Loss: 146.4322 Validation Loss: 149.3700 


 66%|██████▌   | 66/100 [08:06<04:13,  7.44s/it]

Epoch [66/100], Training Loss: 144.4244 Validation Loss: 139.3078 


 68%|██████▊   | 68/100 [08:21<04:01,  7.55s/it]

Epoch [68/100], Training Loss: 142.5531 Validation Loss: 139.9715 


 70%|███████   | 70/100 [08:36<03:44,  7.49s/it]

Epoch [70/100], Training Loss: 139.4719 Validation Loss: 140.2685 


 72%|███████▏  | 72/100 [08:51<03:28,  7.46s/it]

Epoch [72/100], Training Loss: 142.0863 Validation Loss: 139.7440 


 74%|███████▍  | 74/100 [09:06<03:16,  7.56s/it]

Epoch [74/100], Training Loss: 138.5612 Validation Loss: 138.2125 


 76%|███████▌  | 76/100 [09:21<02:59,  7.48s/it]

Epoch [76/100], Training Loss: 146.0120 Validation Loss: 139.1573 


 78%|███████▊  | 78/100 [09:36<02:45,  7.54s/it]

Epoch [78/100], Training Loss: 142.5335 Validation Loss: 139.5805 


 80%|████████  | 80/100 [09:51<02:32,  7.60s/it]

Epoch [80/100], Training Loss: 139.2310 Validation Loss: 140.7064 


 82%|████████▏ | 82/100 [10:06<02:15,  7.53s/it]

Epoch [82/100], Training Loss: 151.2741 Validation Loss: 138.1753 


 84%|████████▍ | 84/100 [10:22<02:00,  7.55s/it]

Epoch [84/100], Training Loss: 138.3865 Validation Loss: 141.8983 


 86%|████████▌ | 86/100 [10:36<01:43,  7.38s/it]

Epoch [86/100], Training Loss: 137.3000 Validation Loss: 153.0839 


 88%|████████▊ | 88/100 [10:51<01:29,  7.43s/it]

Epoch [88/100], Training Loss: 140.3089 Validation Loss: 136.6296 


 90%|█████████ | 90/100 [11:06<01:14,  7.40s/it]

Epoch [90/100], Training Loss: 138.1232 Validation Loss: 135.3617 


 92%|█████████▏| 92/100 [11:21<00:59,  7.42s/it]

Epoch [92/100], Training Loss: 133.7925 Validation Loss: 138.2719 


 94%|█████████▍| 94/100 [11:35<00:44,  7.39s/it]

Epoch [94/100], Training Loss: 135.8373 Validation Loss: 136.8074 


 96%|█████████▌| 96/100 [11:50<00:29,  7.40s/it]

Epoch [96/100], Training Loss: 135.3632 Validation Loss: 135.6534 


 98%|█████████▊| 98/100 [12:05<00:15,  7.53s/it]

Epoch [98/100], Training Loss: 132.4732 Validation Loss: 132.8127 


100%|██████████| 100/100 [12:21<00:00,  7.41s/it]

Epoch [100/100], Training Loss: 135.8327 Validation Loss: 134.6216 
Training complete.





NameError: name 'write' is not defined

In [None]:

# FIXME: save the model
filename_lstm = Config()["controller"]["lstm_model_file"]
torch.save(model.state_dict(), filename_lstm)

# Load the behavior cloning controller and use it with a real time data

In [None]:
# Original
latent_size = Config()["robot"]["latent_encoding_size"]  
hidden_size = 32  # degrees of freedom in the robot
output_size = 6  # degrees of freedom in the robot
num_layers = 2

# Instantiate model, loss function, and optimizer
model = LSTMXYPredictor(latent_size=latent_size, hidden_size=hidden_size, output_size = output_size, num_layers=num_layers)
criterion = nn.MSELoss()  # Mean Squared Error for regression
filename_lstm = Config()["controller"]["lstm_model_file"]
model.load_state_dict(torch.load(filename_lstm))

In [None]:
# Get one demonstration
task = "proprioception-uncluttered"
sp = sp_conv_vae.ConvVaeSensorProcessing()

demos_dir = pathlib.Path(Config()["demos"]["directory"])
task_dir = pathlib.Path(demos_dir, "demos", task)

inputlist = []
targetlist = []

demo_dir = next(task_dir.iterdir())
bcd = BCDemonstration(demo_dir, sensorprocessor=sp)
z, a = bcd.read_z_a()

In [None]:
z.shape[0]
print(a[1])

In [None]:
for i in range(z.shape[0]-1):
    input = torch.from_numpy(z[i])
    input = input.unsqueeze(0)
    input = input.unsqueeze(0)
    print(input)
    a_pred = model.forward_keep_state(input)
    a_real = a[i+1]
    print(f"a_real: {a_real}\na_pred: {a_pred}")