# Train an LSTM-based controller

Train and save an LSTM-based controller. This notebook contains:
* Code for loading and preprocessing the training data.
* Code for training an LSTM with specific parameters and saving it.

In [5]:
import sys
sys.path.append("..")
from settings import Config

import json
import pathlib
from pprint import pformat

import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

import helper
from sensorprocessing import sp_conv_vae


## Behavior cloning demonstration
The class below is a training helper that encapsulates one behavior cloning (BC) demonstration: a sequence of state–action pairs $\{(s_0, a_0), \ldots, (s_n, a_n)\}$.

In practice, however, we want a demonstration that maps latent encodings to actions, $\{(z_0, a_0), \ldots, (z_n, a_n)\}$.


The transformation $s \rightarrow z$ is performed by an object of type `AbstractSensorProcessing`.

Concretely, a BC demonstration is read from a demonstration directory: the camera images saved there provide the states $s$, and the robot control targets saved there provide the actions $a$.

In [6]:
class BCDemonstration:
    """Loads one behavior cloning demonstration and converts it into training pairs.

    The demonstration directory holds, for each step ``i``, a camera image
    ``{i:05d}_{camera}.jpg`` and a robot-control record ``{i:05d}.json``, plus a
    ``_demonstration.json`` file giving the ``trim-from`` / ``trim-to`` step range.
    ``read_z_a`` turns that range into pairs ``(z_i, a_i)``, where ``z_i`` is the
    sensor processor's encoding of the image and ``a_i`` is the action vector.
    """

    # Supported action representations. Each name is also the key under which
    # that representation is stored in the per-step JSON record.
    ACTION_TYPES = ("rc-position-target", "rc-angle-target", "rc-pulse-target")

    def __init__(self, source_dir, sensorprocessor, actiontype = "rc-position-target", camera = None):
        """Read the demonstration metadata (cameras, step count, trim range).

        Args:
            source_dir: path to the demonstration directory.
            sensorprocessor: object providing ``process_file(path)``, which maps an
                image file to its encoding.
            actiontype: which action representation to read; one of ``ACTION_TYPES``.
            camera: name of the camera whose images are encoded; defaults to the
                first camera returned by ``helper.analyze_demo``.

        Raises:
            ValueError: if ``actiontype`` is not in ``ACTION_TYPES``.
        """
        self.source_dir = source_dir
        self.sensorprocessor = sensorprocessor
        # Explicit check rather than `assert`, which is removed under `python -O`.
        if actiontype not in self.ACTION_TYPES:
            raise ValueError(
                f"Unsupported actiontype {actiontype!r}; expected one of {self.ACTION_TYPES}")
        self.actiontype = actiontype
        # Cameras present in the directory and (presumably) the number of recorded
        # steps -- TODO confirm maxsteps semantics against helper.analyze_demo.
        self.cameras, self.maxsteps = helper.analyze_demo(source_dir)
        self.camera = self.cameras[0] if camera is None else camera
        # Load the trim range; a trim-to of -1 means "up to maxsteps".
        with open(pathlib.Path(self.source_dir, "_demonstration.json"), encoding="utf-8") as file:
            data = json.load(file)
        self.trim_from = data["trim-from"]
        self.trim_to = data["trim-to"]
        if self.trim_to == -1:
            self.trim_to = self.maxsteps

    def read_z_a(self):
        """Read the trimmed demonstration as encodings and actions.

        Returns:
            A tuple ``(z, a)`` of numpy arrays built from ``get_z(i)`` and ``get_a(i)``
            for every step ``i`` in ``range(trim_from, trim_to)``, one row per step.
        """
        z = []
        a = []
        for i in range(self.trim_from, self.trim_to):
            z.append(self.get_z(i))
            a.append(self.get_a(i))
        return np.array(z), np.array(a)

    def __str__(self):
        """Pretty-printed view of the instance attributes."""
        return pformat(self.__dict__)

    def get_z(self, i):
        """Return the sensor processor's encoding of step ``i``'s image from ``self.camera``."""
        filepath = pathlib.Path(self.source_dir, f"{i:05d}_{self.camera}.jpg")
        return self.sensorprocessor.process_file(filepath)

    def get_a(self, i):
        """Return the action vector of step ``i`` as a list.

        The values of the ``self.actiontype`` entry of ``{i:05d}.json`` are returned
        in the key order of that JSON object.

        Raises:
            ValueError: if ``self.actiontype`` is not in ``ACTION_TYPES``.
        """
        if self.actiontype not in self.ACTION_TYPES:
            raise ValueError(f"Unsupported actiontype {self.actiontype!r}")
        filepath = pathlib.Path(self.source_dir, f"{i:05d}.json")
        with open(filepath, encoding="utf-8") as file:
            data = json.load(file)
        # The action type name doubles as the key of its entry in the record.
        return list(data[self.actiontype].values())
        


In [7]:
# Choose an example demonstration. Path.iterdir() yields entries in an arbitrary,
# filesystem-dependent order, so sort the demo directories to make the choice
# reproducible across machines and runs.
demos_dir = pathlib.Path(Config()["demos"]["directory"])
proprio_dir = pathlib.Path(demos_dir, "demos", "proprioception-uncluttered")
demo_dirs = sorted(p for p in proprio_dir.iterdir() if p.is_dir())
if not demo_dirs:
    raise FileNotFoundError(f"No demonstration directories found in {proprio_dir}")
demo_dir = demo_dirs[0]
print(demo_dir)

sp = sp_conv_vae.ConvVaeSensorProcessing()

bcd = BCDemonstration(demo_dir, sensorprocessor=sp)
print(bcd)

/home/lboloni/Documents/Hackingwork/__Temporary/BerryPicker-demos/demos/proprioception-uncluttered/2024_10_26__16_31_40
resume_model and jsonfile are:
	resume_model=/home/lboloni/Documents/Hackingwork/__Temporary/BerryPicker-models/Conv-VAE/models/VAE_Robot/0901_125042/checkpoint-epoch171.pth
	jsonfile=/home/lboloni/Documents/Hackingwork/__Temporary/BerryPicker-models/Conv-VAE/models/VAE_Robot/0901_125042/config.json
{
    "name": "VAE_Robot",
    "n_gpu": 1,
    "arch": {
        "type": "VanillaVAE",
        "args": {
            "in_channels": 3,
            "latent_dims": 128,
            "flow": false
        }
    },
    "data_loader": {
        "###type-prev": "RobotDataLoader",
        "type": "CelebDataLoader",
        "args": {
            "data_dir": "/home/lboloni/Documents/Hackingwork/__Temporary/VisionBasedRobotManipulator-training-data/vae-training-data",
            "batch_size": 64,
            "shuffle": true,
            "validation_split": 0.2,
            "num_work

  self.checkpoint = torch.load(self.config.resume, map_location=torch.device('cpu'))


In [8]:
# Encode every frame in the trimmed range and pair it with its recorded action.
z, a = bcd.read_z_a()
# Invariant: exactly one encoding and one action per step of the trimmed range.
expected_steps = len(range(bcd.trim_from, bcd.trim_to))
assert len(z) == len(a) == expected_steps, (len(z), len(a), expected_steps)

In [9]:
# Report the shapes of the encoding array and the action array (one row per step).
for arr in (z, a):
    print(arr.shape)

(752, 128)
(752, 6)
