## Notes For Training the Dynamics Model:

### Data
- `Data_Collection/Action_Samples_1`: a non-temporal dataset of size > 100K. Used for training the Spatial Encoder Model `ObjectTactileEncoder_Additive`.

- `Data_Collection/Time_Dependent_Action_Samples_1`: a temporal dataset of size > N, created with a temporal context `T_buffer` = 3. Used for training the Temporal-Spatial Encoder Model `TemporalObjectTactileEncoder_Additive`.

- `Data_Collection/Time_dependent_Action_Samples_2`: a temporal dataset of size > N, created with a temporal context `T_buffer` = 3. Used for training the Temporal-Spatial Encoder Model.

- `Data_Collection/Action_Pred_Time_dependent_Action_Samples_1`: a temporal dataset of size > N, created with a temporal context `T_buffer` = 5. Used for training the Action Prediction Model and the Temporal-Spatial Encoder Model.

 Note: the `Additive` part of the model name refers to how encodings are combined: positional and value encodings are added together rather than concatenated. The concatenation-based model was not as useful as the additive model.
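Here is a minimal NumPy sketch of the additive vs. concatenation distinction. Random arrays stand in for the learned value and positional projections, and the sizes are hypothetical. Addition keeps the embedding width fixed, while concatenation doubles it.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, embed_dim = 4, 8  # hypothetical sizes for illustration

# Stand-ins for the learned projections of each token's value and position
value_enc = rng.standard_normal((num_tokens, embed_dim))
pos_enc = rng.standard_normal((num_tokens, embed_dim))

# Additive: element-wise sum, embedding width unchanged
additive = value_enc + pos_enc
# Concatenation: stack along the feature axis, embedding width doubles
concatenated = np.concatenate([value_enc, pos_enc], axis=-1)

print(additive.shape)      # (4, 8)
print(concatenated.shape)  # (4, 16)
```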

#### What is in the Dataset?



##### Inputs:
- **palm_tactile**: 
    Shape: (3, 24) where `3=T_buffer` and `24` is the Tactile sensor reading vector.

- **finger_1_tactile**:
    Shape: (3, 24) where `3=T_buffer` and `24` is the Tactile sensor reading vector.

- **finger_2_tactile**: 
    Shape: (3, 24) where `3=T_buffer` and `24` is the Tactile sensor reading vector.

- **finger_3_tactile**: 
    Shape: (3, 24) where `3=T_buffer` and `24` is the Tactile sensor reading vector.

- **palm_location**:
    Shape: (3, 74) where `3=T_buffer` and `74 = 2 + 24*3`: two leading values followed by the XYZ positions of the tactile sensors.

- **finger_1_location**:
    Shape: (3, 104) where `3=T_buffer` and `104 = 2 + 34*3`: the joint positions `[Theta_1, Theta_2]` followed by the XYZ positions of the tactile sensors.

- **finger_2_location**:
    Shape: (3, 104) where `3=T_buffer` and `104 = 2 + 34*3`: the joint positions `[Theta_1, Theta_2]` followed by the XYZ positions of the tactile sensors.

- **finger_3_location**:
    Shape: (3, 104) where `3=T_buffer` and `104 = 2 + 34*3`: the joint positions `[Theta_1, Theta_2]` followed by the XYZ positions of the tactile sensors.

- **obj_location**:
    Shape: (3, 7, 6) where `3=T_buffer`, `7=OBJECT_QUANTITY` and `6=XYZ_RollYawPitch` for each of the balls in the scene.

- **obj_velocity**: 
    Shape: (3, 7, 6) where `3=T_buffer`, `7=OBJECT_QUANTITY` and `6=XYZ_RollYawPitch_Velocity` for each of the balls in the scene.

- **state_attrib**: 
    Shape: (45,). Currently unused. Per the `state_space` comment, it holds ball count, progress, previous actions, hand config and hand torque.

- **action**: 
    Shape: (5,). Action input for Spread, F1, F2, F3. The last index is unused (it was previously used for scaling). The data loaders divide actions by 3.5.
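A small shape check can catch input mismatches before training. Note, for example, that the palm (74) and finger (104) location widths differ. This is a sketch: `EXPECTED_INPUT_SHAPES` and `check_input_shapes` are hypothetical helpers. `T_BUFFER = 3` matches the datasets built with `T_buffer` = 3; set it to 5 for `Action_Pred_Time_dependent_Action_Samples_1`.

```python
import numpy as np

T_BUFFER = 3         # temporal context of the T_buffer = 3 datasets (assumption for this sketch)
OBJECT_QUANTITY = 7  # number of balls tracked in the scene

# Expected per-sample input shapes, as documented above
EXPECTED_INPUT_SHAPES = {
    **{f"{n}_tactile": (T_BUFFER, 24) for n in ("palm", "finger_1", "finger_2", "finger_3")},
    "palm_location": (T_BUFFER, 2 + 24 * 3),                                # 74
    **{f"finger_{i}_location": (T_BUFFER, 2 + 34 * 3) for i in (1, 2, 3)},  # 104
    "obj_location": (T_BUFFER, OBJECT_QUANTITY, 6),
    "obj_velocity": (T_BUFFER, OBJECT_QUANTITY, 6),
    "state_attrib": (45,),
    "action": (5,),
}


def check_input_shapes(sample: dict) -> list:
    """Return (key, expected, actual) for every missing key or shape mismatch."""
    problems = []
    for key, expected in EXPECTED_INPUT_SHAPES.items():
        actual = np.shape(sample[key]) if key in sample else None
        if actual != expected:
            problems.append((key, expected, actual))
    return problems


# Usage: an all-zeros dummy sample passes the check
dummy = {k: np.zeros(s) for k, s in EXPECTED_INPUT_SHAPES.items()}
print(check_input_shapes(dummy))  # []
```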

##### Outputs:

- **obj_location**: 
    Shape: (42,)

- **finger_1_location**: 
    Shape: (27,)

- **finger_2_location**: 
    Shape: (27,)

- **finger_3_location**: 
    Shape: (27,)

- **palm_location**: 
    Shape: (27,)

- **finger_1_tactile**: 
    Shape: (9,)

- **finger_2_tactile**: 
    Shape: (9,)

- **finger_3_tactile**: 
    Shape: (9,)

- **palm_tactile**: 
    Shape: (9,)

- **hand_config**: 
    Shape: (7,)

- **obj_count**: 
    Shape: (5,)

- **progress_bar**: 
    Shape: (1,)

- **reward**: 
    Shape: (1,)
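The output heads sum to 200 values per sample. The current pipeline uses one head per key. If the targets ever need to be packed into a single vector (e.g. for a combined loss), a helper like the hypothetical `split_flat_prediction` below can recover per-key slices. This is a sketch, not part of the notebook.

```python
import numpy as np

# Per-head output sizes, copied from DYNAMIX_OUTPUT_SIZES_DICT
OUTPUT_SIZES = {
    "finger_1_location": 27, "finger_2_location": 27, "finger_3_location": 27,
    "palm_location": 27,
    "finger_1_tactile": 9, "finger_2_tactile": 9, "finger_3_tactile": 9, "palm_tactile": 9,
    "obj_location": 42,
    "hand_config": 7, "obj_count": 5, "progress_bar": 1, "reward": 1,
}


def split_flat_prediction(flat: np.ndarray, sizes: dict = OUTPUT_SIZES) -> dict:
    """Split a (batch, sum(sizes)) array into per-key slices, in dict order."""
    offsets = np.cumsum([0] + list(sizes.values()))
    return {k: flat[:, offsets[i]:offsets[i + 1]] for i, k in enumerate(sizes)}


total = sum(OUTPUT_SIZES.values())
print(total)  # 200
parts = split_flat_prediction(np.zeros((2, total)))
print(parts["obj_location"].shape)  # (2, 42)
```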


## Model Architecture:
- Transformer Encoder for learning Spatial (and Temporal) representations of our data
    a. We encode based on the type of data (object vs. limb) and the type of projection (temporal, spatial, value).
    b. We use sinusoidal embeddings for all temporal encodings.
    c. We concatenate the projected tensors into an object matrix, so the input to the transformer is `(BATCH_SIZE, NUM_OBJECTS*TIME_DIMENSION + 1, EMBEDDING_DIMENSION)`. The `+ 1` is the output vector: it is the last object inserted into the transformer and is initialized to all zeros.
    d. We randomly mask the input with probability `p=0.11` by applying a zero mask. Keep experimenting with the masking mechanism; first iterations suggest it is a promising feature.
    e. We return `Output[-1]`, i.e. the last embedding vector from the transformer, which is the zero-initialized output vector created in (c). The goal is for the K transformer blocks to learn a good embedding in this vector, which is then passed to the rest of the pipeline, whether that is RL or pretraining.

- Residual Network for updating the Encoding: we condition the encoding of the observation on an action vector, then perform a gating between the old encoding and the new encoding. This has not been tested; experiment with the gating mechanism and with removing it.

- We then branch out into K different networks, where each sub-network receives the action-conditioned state encoding and uses it to predict the k-th feature of the future state.

- Note: we only use raw sensor data and object data to make our predictions; no hand-crafted or derived values are used. Feel free to create new encodings for other values and feed them into the Transformer, which can then use them to improve its predictions.
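The token assembly in steps (b) through (e) can be sketched in NumPy as follows. This is an illustration under assumptions, not the `TemporalObjectTactileEncoder_Additive` implementation. The assumptions are:

- the temporal encoding is the standard Transformer sine/cosine scheme, added per time step;
- masking zeroes whole tokens;
- `sinusoidal_encoding` and `build_token_matrix` are hypothetical names.

```python
import numpy as np


def sinusoidal_encoding(num_positions: int, dim: int) -> np.ndarray:
    """Sine/cosine encoding from "Attention Is All You Need"; dim must be even."""
    positions = np.arange(num_positions)[:, None]                  # (P, 1)
    freqs = np.exp(-np.log(10000.0) * np.arange(0, dim, 2) / dim)  # (dim // 2,)
    enc = np.zeros((num_positions, dim))
    enc[:, 0::2] = np.sin(positions * freqs)  # even channels
    enc[:, 1::2] = np.cos(positions * freqs)  # odd channels
    return enc


def build_token_matrix(tokens, mask_p=0.11, rng=None):
    """tokens: (batch, num_objects, time, dim) projected embeddings.

    Adds a temporal encoding per time step, flattens objects x time into one
    sequence, zeroes each token with probability mask_p, and appends a
    zero-initialized output token as the last position.
    """
    rng = np.random.default_rng() if rng is None else rng
    b, n_obj, t, d = tokens.shape
    x = tokens + sinusoidal_encoding(t, d)[None, None]  # broadcast over batch and objects
    x = x.reshape(b, n_obj * t, d)
    keep = rng.random((b, n_obj * t, 1)) >= mask_p      # True -> token survives
    x = x * keep
    out_token = np.zeros((b, 1, d))
    return np.concatenate([x, out_token], axis=1)       # (b, n_obj * t + 1, d)


rng = np.random.default_rng(0)
seq = build_token_matrix(rng.standard_normal((2, 5, 3, 16)), rng=rng)
print(seq.shape)  # (2, 16, 16)
# A transformer would map seq to seq_out of the same shape; step (e) reads seq_out[:, -1, :].
```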


## TODO:
- Action Prediction Model (inspired by Tedrake): Given `S[T], S[T+1]` predict `A[T]`
- Reinforcement learning: collect pretrained Dynamix metrics.
- ...


In [1]:
## Initialization: use the env with stable-baselines3 (sb3) or raylib

import os

# CONTROL_DROP_DIR = os.environ["CONTROL_DROP_DIR"]
CONTROL_DROP_DIR = "/media/rpal/Drive_10TB/John/ControlDrop-RPAL"
SIM_DIR = "/home/rpal/CoppeliaSim_Edu_V4_4_0_rev0_Ubuntu20_04"

os.environ["CONTROL_DROP_DIR"] = CONTROL_DROP_DIR
os.environ["SIM_DIR"] = SIM_DIR

import sys

# os.chdir("..")
import torch as th
from gymnasium.spaces import Box
from control_dropping_rpal.RL.control_dropping_env import BerrettHandGym, T_buffer

from math import inf, radians, degrees
from stable_baselines3 import PPO, A2C

DATA_SAVE_PATH = os.path.join(CONTROL_DROP_DIR, "Data_Collection")
MODEL_PATH = os.path.join(
    CONTROL_DROP_DIR,
    "control_dropping/src/RL/Training/Checkpoints/TransformerFeatureEncoder/Expert_rl_5000_steps.zip",
)

# model = PPO.load(MODEL_PATH)
# model.set_env(env)
state_space = {
    "palm_tactile": Box(
        low=-inf,
        high=inf,
        shape=(
            T_buffer,
            24,
        ),
    ),  # Value
    "finger_1_tactile": Box(
        low=-inf,
        high=inf,
        shape=(
            T_buffer,
            24,
        ),
    ),  # Value
    "finger_2_tactile": Box(
        low=-inf,
        high=inf,
        shape=(
            T_buffer,
            24,
        ),
    ),  # Value
    "finger_3_tactile": Box(
        low=-inf,
        high=inf,
        shape=(
            T_buffer,
            24,
        ),
    ),  # Value
    # 'tactile_pos': Box(low= -inf, high= inf, shape=(378, )), # Position
    "finger_1_location": Box(
        low=-inf,
        high=inf,
        shape=(
            T_buffer,
            2 + 34 * 3,
        ),
    ),  # Joint pos [Theta_1, Theta_2] + [xyz*34]
    "finger_2_location": Box(
        low=-inf,
        high=inf,
        shape=(
            T_buffer,
            2 + 34 * 3,
        ),
    ),
    "finger_3_location": Box(
        low=-inf,
        high=inf,
        shape=(
            T_buffer,
            2 + 34 * 3,
        ),
    ),
    "palm_location": Box(
        low=-inf,
        high=inf,
        shape=(
            T_buffer,
            2 + 24 * 3,
        ),
    ),
    "obj_location": Box(low=-inf, high=inf, shape=(T_buffer, 7, 6)),  # Position
    "obj_velocity": Box(
        low=-inf, high=inf, shape=(T_buffer, 7, 6)
    ),  # Value, Concat with angular velocity
    "state_attrib": Box(
        low=-inf, high=inf, shape=(45,)
    ),  # Ball Cnt, Progress, Prev.Actions, hand_cfg, hand_trq (44)
}
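To see how large a flattened observation is, multiply out each `Box` shape. The snippet below does that with `math.prod`, which computes the same product as the notebook's `mult` helper. It assumes `T_buffer` = 3 and uses a separate name, `T_BUF`, so it does not overwrite the imported `T_buffer`. The shapes are copied from `state_space`.

```python
import math

T_BUF = 3  # assumption: the T_buffer = 3 datasets

# Shapes copied from state_space above
box_shapes = {
    "palm_tactile": (T_BUF, 24),
    "finger_1_tactile": (T_BUF, 24),
    "finger_2_tactile": (T_BUF, 24),
    "finger_3_tactile": (T_BUF, 24),
    "finger_1_location": (T_BUF, 2 + 34 * 3),
    "finger_2_location": (T_BUF, 2 + 34 * 3),
    "finger_3_location": (T_BUF, 2 + 34 * 3),
    "palm_location": (T_BUF, 2 + 24 * 3),
    "obj_location": (T_BUF, 7, 6),
    "obj_velocity": (T_BUF, 7, 6),
    "state_attrib": (45,),
}

# math.prod multiplies out a shape tuple, like the notebook's mult() helper
flat_sizes = {k: math.prod(s) for k, s in box_shapes.items()}
print(flat_sizes["finger_1_location"])  # 312
print(sum(flat_sizes.values()))         # 1743
```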

In [None]:
import numpy as np


# 1. Create example data
def create_example_data():
    # Create two dictionaries with lists of numpy arrays
    data1 = {
        "key1": [
            np.array([1, 2, 3]),
            np.array([4, 5, 6]),
            np.array([7, 8]),
        ],  # Note the last array has a different shape
        "key2": [
            np.array([[1, 2], [3, 4]]),
            np.array([[5, 6], [7, 8]]),
            np.array([9, 10]),
        ],  # Note the last array has a different shape
    }

    data2 = {
        "key1": [
            np.array([10, 20, 30]),
            np.array([40, 50, 60]),
            np.array([70, 80, 90]),
        ],
        "key2": [
            np.array([[10, 20], [30, 40]]),
            np.array([[50, 60], [70, 80]]),
            np.array([[90, 100], [110, 120]]),
        ],
    }

    return [data1, data2]


# 2. Implement function to remove inhomogeneous elements
def remove_inhomogeneous_elements(aggregated_data):
    removal_idxs = set()

    # Detect inhomogeneous elements
    for data in aggregated_data:
        for key, arr_list in data.items():
            shapes = [arr.shape for arr in arr_list]
            most_common_shape = max(set(shapes), key=shapes.count)

            for idx, shape in enumerate(shapes):
                if shape != most_common_shape:
                    removal_idxs.add(idx)

    # Remove inhomogeneous elements
    for data in aggregated_data:
        for key in data:
            data[key] = [
                arr for idx, arr in enumerate(data[key]) if idx not in removal_idxs
            ]

    return aggregated_data, list(removal_idxs)
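`remove_inhomogeneous_elements` finds the most common shape with `max(set(shapes), key=shapes.count)`. That expression calls `shapes.count` once per distinct shape. Its tie-break depends on set iteration order.

A variant built on `collections.Counter.most_common` counts in one pass and breaks ties by first occurrence. It also returns new dicts instead of mutating the input. This is a sketch; `find_inhomogeneous_indices` and `drop_indices` are hypothetical names.

```python
from collections import Counter

import numpy as np


def find_inhomogeneous_indices(aggregated_data):
    """Sorted sample indices whose shape differs from the most common shape
    of their key, in any dict of `aggregated_data`."""
    bad = set()
    for data in aggregated_data:
        for arr_list in data.values():
            shapes = [np.shape(a) for a in arr_list]
            mode_shape, _ = Counter(shapes).most_common(1)[0]
            bad.update(i for i, s in enumerate(shapes) if s != mode_shape)
    return sorted(bad)


def drop_indices(aggregated_data, indices):
    """Drop the given sample indices from every key of every dict (returns new dicts)."""
    drop = set(indices)
    return [
        {k: [a for i, a in enumerate(v) if i not in drop] for k, v in data.items()}
        for data in aggregated_data
    ]


data = [
    {"key1": [np.zeros(3), np.zeros(3), np.zeros(2)]},
    {"key1": [np.zeros((2, 2)), np.zeros((2, 2)), np.zeros((2, 2))]},
]
idx = find_inhomogeneous_indices(data)
print(idx)  # [2]
clean = drop_indices(data, idx)
print([len(d["key1"]) for d in clean])  # [2, 2]
```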

In [None]:
## Load data stuff
import json
import os
import numpy as np
import pickle

GAMMA = 0.5


def save_data(
    data,
    file_path,
):
    if not os.path.exists(file_path):
        os.makedirs(file_path, exist_ok=True)

    # Extract data from each data point and create numpy arrays
    state_list = [state[0].copy() for state in data]
    pred_state_list = [state[1].copy() for state in data]

    chunk_data = {
        "states": state_list,
        "pred_states": pred_state_list,
    }
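    import re

    # Next chunk index = largest existing "Actions_data_<i>.pkl" index + 1 (0 if none);
    # files that don't match the pattern are ignored instead of crashing on int("")
    existing_idxs = [
        int(m.group(1))
        for s in os.listdir(file_path)
        if (m := re.fullmatch(r"Actions_data_(\d+)\.pkl", s))
    ]
    i = max(existing_idxs, default=-1) + 1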
    chunk_file_path = os.path.join(file_path, f"Actions_data_{i}.pkl")

    with open(chunk_file_path, "wb") as f:
        pickle.dump(chunk_data, f)
    print("Saved to", chunk_file_path)


def mult(shape):
    """Utility function to compute the product of elements in a shape tuple."""
    product = 1
    for dim in shape:
        product *= dim
    return product


def load_data_node_format(file_path, target_range=4):
    if not os.path.exists(os.path.dirname(file_path)):
        os.makedirs(os.path.dirname(file_path))
    # Load data from pickle files into a list of states
    states = []
    new_keys = {f"finger_{i}_tactile": f"finger_{i+1}_tactile" for i in range(3)}
    new_keys.update(
        {f"finger_{i}_locaction": f"finger_{i+1}_location" for i in range(3)}
    )
    new_keys.update({"palm_locaction": "palm_location"})

    for f_path in [
        f for f in os.listdir(file_path) if "Actions_data" in f
    ]:  # For each episode
        with open(os.path.join(file_path, f_path), "rb") as file:
            data = pickle.load(file)

        temp_states = []
        for state, pred in zip(
            data["states"], data["pred_states"]
        ):  # For each state[t-1], state[t] pair
            if not isinstance(pred, dict):
                pred = {key: pred[key] for key in pred.dtype.names}
            if "finger_0_tactile" in pred:
                # Update keys:
                pred = {new_keys.get(k, k): v for k, v in pred.items()}
            temp_states.append((state, pred))

        sts = [s[0] for s in temp_states]
        preds = [s[1] for s in temp_states]
        # for i in range(len(preds) - 2, -1, -1):
        #     preds[i]["reward"] += GAMMA * preds[i + 1]["reward"]

        states += [(s, p) for s, p in zip(sts, preds)]

    # Aggregate data by key
    aggregated_data = [{}, {}]
    for d in states:
        for key, value in d[0].items():
            if key not in aggregated_data[0]:
                aggregated_data[0][key] = []
            aggregated_data[0][key].append(value)
        for key, value in d[1].items():
            if key not in aggregated_data[1]:
                aggregated_data[1][key] = []
            aggregated_data[1][key].append(value)

    # Fix-up: some reward values were stored inconsistently; wrap each entry in an array
    # (flattened below) so every reward has the same layout
    reward_data = []
    for i in range(len(aggregated_data[1]["reward"])):
        reward_data.append(
            np.array([aggregated_data[1]["reward"][i]])
        )  # Wrap each reward in an array
    aggregated_data[1]["reward"] = np.array(reward_data)

    # Flatten and format:
    aggregated_data = [
        {
            k: [np.array(v).squeeze() for v in value]
            for k, value in aggregated_data[0].items()
        },
        {
            k: [np.array(v).flatten() for v in value]
            for k, value in aggregated_data[1].items()
        },
    ]

    stacked_values = np.vstack(aggregated_data[0]["action"])
    stacked_values /= 3.5
    aggregated_data[0]["action"] = [
        a.squeeze() for a in np.split(stacked_values, len(aggregated_data[0]["action"]))
    ]

    # Find inhomogeneous data:
    removal_idxs = set()

    for data in aggregated_data:
        for key, arr_list in data.items():
            shapes = [arr.shape for arr in arr_list]
            most_common_shape = max(
                set(shapes), key=shapes.count
            )  # Detect inhomogeneous elements

            for idx, shape in enumerate(shapes):
                if shape != most_common_shape:
                    removal_idxs.add(idx)

    # Remove inhomogeneous elements:
    for data in aggregated_data:
        for key in data:
            data[key] = [
                arr for idx, arr in enumerate(data[key]) if idx not in removal_idxs
            ]

    print("Inhomogeneous data:", removal_idxs)

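    # Zero out NaNs and outliers outside [-target_range**2, target_range**2] (16 by default)
    for data in aggregated_data:
        for key, arr in data.items():
            arr = np.where(np.isnan(arr), 0, arr)
            arr = np.where(arr < -(target_range**2), 0, arr)
            arr = np.where(arr > target_range**2, 0, arr)
            data[key] = arr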

    print("Shape Min Maxs:\n")
    for key in aggregated_data[0].keys():
        print(
            f"{key}: {aggregated_data[0][key][0].shape}, {np.min(aggregated_data[0][key])}, {np.max(aggregated_data[0][key])}"
        )
    for key in aggregated_data[1].keys():
        print(
            f"{key}: {np.array(aggregated_data[1][key][0]).shape}, {np.min(aggregated_data[1][key])}, {np.max(aggregated_data[1][key])}"
        )

    return aggregated_data


def load_data_state_format(file_path, target_range=4):
    if not os.path.exists(os.path.dirname(file_path)):
        os.makedirs(os.path.dirname(file_path))
    # Load data from pickle files into a list of states
    states = []
    new_keys = {f"finger_{i}_tactile": f"finger_{i+1}_tactile" for i in range(3)}
    new_keys.update(
        {f"finger_{i}_locaction": f"finger_{i+1}_location" for i in range(3)}
    )
    new_keys.update({"palm_locaction": "palm_location"})

    for f_path in [
        f for f in os.listdir(file_path) if "Actions_data" in f
    ]:  # For each episode
        with open(os.path.join(file_path, f_path), "rb") as file:
            data = pickle.load(file)

        temp_states = []
        for state, pred in zip(
            data["states"], data["pred_states"]
        ):  # For each state[t-1], state[t] pair
            if not isinstance(pred, dict):
                pred = {key: pred[key] for key in pred.dtype.names}
            if "finger_0_tactile" in pred:
                # Update keys:
                pred = {new_keys.get(k, k): v for k, v in pred.items()}

            # Prediction data is collected for the entire temporal window. Restrict the target to a
            # T_buffer of 1 (the last frame) so that it encapsulates only S[T] and enables better
            # representations during the forward pass:
            for key in [
                k
                for k in pred.keys()
                if k != "reward"
                and T_buffer in pred[k].shape
                and len(pred[k].shape) > 1
            ]:
                arr = pred[key]
                pred[key] = arr[-1:, :]

            action = state["action"] if "action" in state else pred["action"]
            reward = np.array(
                [(pred["reward"] if "reward" in pred else state["reward"])]
            )

            state = {k: v for k, v in state.items() if k not in ("action", "reward")}
            n_state = {k: v for k, v in pred.items() if k not in ("action", "reward")}
            pred = {"action": action, "reward": reward}

            temp_states.append((state, n_state, pred))

        states += [(s, n_s, p) for s, n_s, p in temp_states]

    # Aggregate data by key
    aggregated_data = [{}, {}, {}]
    for d in states:
        # S[T - 1]
        for key, value in d[0].items():
            if key not in aggregated_data[0]:
                aggregated_data[0][key] = []
            aggregated_data[0][key].append(value)
        # S[T]
        for key, value in d[1].items():
            if key not in aggregated_data[1]:
                aggregated_data[1][key] = []
            aggregated_data[1][key].append(value)

        # Prediction:
        for key, value in d[2].items():
            if key not in aggregated_data[2]:
                aggregated_data[2][key] = []
            aggregated_data[2][key].append(value)

    # Flatten and format:
    aggregated_data = [
        {k: [np.array(v) for v in value] for k, value in aggregated_data[0].items()},
        {k: [np.array(v) for v in value] for k, value in aggregated_data[1].items()},
        {
            k: (
                [np.array(v).flatten() for v in value]
            )  # if k != "reward" else [np.array(v) for v in value])
            for k, value in aggregated_data[2].items()
        },
    ]

    stacked_values = np.vstack(aggregated_data[2]["action"])
    stacked_values *= 1 / 3.5
    aggregated_data[2]["action"] = [
        a.squeeze() for a in np.split(stacked_values, len(aggregated_data[2]["action"]))
    ]

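    # Zero out NaNs and outliers outside [-target_range**2, target_range**2] (16 by default)
    for data in aggregated_data:
        for key, arr in data.items():
            arr = np.where(np.isnan(arr), 0, arr)
            arr = np.where(arr < -(target_range**2), 0, arr)
            arr = np.where(arr > target_range**2, 0, arr)
            data[key] = arr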

    print("Shape Min Maxs:\n")
    for key in aggregated_data[0].keys():
        print(
            f"{key}: {aggregated_data[0][key][0].shape}, {np.min(aggregated_data[0][key])}, {np.max(aggregated_data[0][key])}"
        )
    print("\n----------\n")
    for key in aggregated_data[1].keys():
        print(
            f"{key}: {np.array(aggregated_data[1][key][0]).shape}, {np.min(aggregated_data[1][key])}, {np.max(aggregated_data[1][key])}"
        )
    print("\n----------\n")
    for key in aggregated_data[2].keys():
        print(
            f"{key}: {np.array(aggregated_data[2][key][0]).shape}, {np.min(aggregated_data[2][key])}, {np.max(aggregated_data[2][key])}"
        )

    return aggregated_data
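Both loaders share two steps. First, they regroup a list of per-sample dicts into per-key lists. Second, they zero out NaNs and values outside `[-16, 16]`.

A compact sketch of those two steps follows, using `collections.defaultdict` and a single boolean mask. `aggregate_by_key` and `sanitize` are hypothetical names, and the sample data is made up for illustration.

```python
from collections import defaultdict

import numpy as np


def aggregate_by_key(samples):
    """Turn a list of per-sample dicts into a dict of per-key lists."""
    out = defaultdict(list)
    for sample in samples:
        for key, value in sample.items():
            out[key].append(value)
    return dict(out)


def sanitize(arr, target_range=4):
    """Replace NaNs and values outside [-target_range**2, target_range**2] with 0,
    mirroring the np.where clean-up in the loaders (threshold 16 by default)."""
    arr = np.asarray(arr, dtype=float)
    limit = target_range**2
    bad = np.isnan(arr) | (arr < -limit) | (arr > limit)
    return np.where(bad, 0.0, arr)


samples = [
    {"reward": 1.0, "action": [0.5, 20.0]},
    {"reward": np.nan, "action": [-1.0, 2.0]},
]
agg = aggregate_by_key(samples)
clean = {k: sanitize(v) for k, v in agg.items()}
print(clean["reward"])  # [1. 0.]
```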

In [None]:
## Loads the Data
path_dynamix = os.path.join(
    CONTROL_DROP_DIR, "Data_Collection/Time_Dependent_Samples_4/"
)
path_action_pred = os.path.join(
    CONTROL_DROP_DIR, "Data_Collection/Action_Pred_Time_Dependent_Samples_4/"
)

# Time_Dependent_Samples_1: Old Reward Function
# Time_Dependent_Samples_2: New Reward Function

dynamix_data = load_data_node_format(path_dynamix)
pred_data = load_data_state_format(path_action_pred)

for data in (
    dynamix_data,
    pred_data,
):
    print(
        "Dataset Size:", "\n".join([f"{k}: {len(data[0][k])}" for k in data[0].keys()])
    )

In [None]:
data_collection_dir = os.path.join(
    CONTROL_DROP_DIR, "Data_Collection", "Action_Samples_2"
)

# Ensure the 'Data_Collection/Action_Samples_2' directory exists (no-op if it already does)
os.makedirs(data_collection_dir, exist_ok=True)

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class ResidualLayer1D(nn.Module):
    def __init__(self, feature_dim: int, embed_dim=512, dropout_p=0.1):
        super(ResidualLayer1D, self).__init__()
        self.embed_dim = embed_dim
        self.fc1 = nn.Linear(
            feature_dim,
            embed_dim,
        )
        self.fc2 = nn.Linear(
            embed_dim,
            feature_dim,
        )
        self.n = nn.LayerNorm(feature_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.gelu = nn.GELU()

    def forward(self, x):
        out = self.gelu(self.fc1(x))
        out = self.dropout(out)
        out = self.gelu(self.fc2(out))
        out = self.n(x + out)
        return out


class ResidualBlocks1D(nn.Module):
    def __init__(self, feature_dim: int, num_blocks: int, embed_dim=512, dropout_p=0.1):
        super(ResidualBlocks1D, self).__init__()
        self.feature_dim = feature_dim
        self.num_blocks = num_blocks
        self.embed_dim = embed_dim
        self.layers = nn.Sequential(
            *[
                ResidualLayer1D(feature_dim, embed_dim, dropout_p=dropout_p)
                for _ in range(num_blocks)
            ]
        )

    def forward(self, x):
        out = self.layers(x)
        return out
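To reason about `ResidualLayer1D` outside PyTorch, here is a NumPy sketch of its eval-mode forward pass. It also includes one hypothetical sigmoid gate for the "gating between the old and new encoding" experiment described in the architecture notes. The sketch makes these assumptions:

- Weights are stored as `(in, out)`. PyTorch's `nn.Linear` stores `(out, in)` and computes `x @ W.T + b`.
- LayerNorm has unit scale and zero shift, which matches `nn.LayerNorm` at initialization.
- GELU uses the tanh approximation. `nn.GELU` defaults to the exact erf form.

```python
import numpy as np


def gelu_tanh(x):
    """Tanh approximation of GELU, i.e. nn.GELU(approximate="tanh")."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis with unit scale and zero shift (nn.LayerNorm at init)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def residual_layer_1d(x, w1, b1, w2, b2):
    """Eval-mode ResidualLayer1D forward (dropout is the identity at eval time).

    Weights are (in, out): w1 is (feature_dim, embed_dim), w2 is (embed_dim, feature_dim).
    """
    out = gelu_tanh(x @ w1 + b1)    # fc1 + GELU
    out = gelu_tanh(out @ w2 + b2)  # fc2 + GELU (applied before the skip, as in the layer above)
    return layer_norm(x + out)      # residual add, then LayerNorm


def gated_update(e_old, e_new, w_g, b_g):
    """Hypothetical learned gate: g = sigmoid([e_old, e_new] @ w_g + b_g), blended per channel."""
    logits = np.concatenate([e_old, e_new], axis=-1) @ w_g + b_g
    g = 1.0 / (1.0 + np.exp(-logits))
    return g * e_new + (1.0 - g) * e_old


rng = np.random.default_rng(0)
feature_dim, embed_dim = 13, 32  # e.g. vec_encoding_size (8) + action size (5)
x = rng.standard_normal((4, feature_dim))
w1, b1 = 0.1 * rng.standard_normal((feature_dim, embed_dim)), np.zeros(embed_dim)
w2, b2 = 0.1 * rng.standard_normal((embed_dim, feature_dim)), np.zeros(feature_dim)
y = residual_layer_1d(x, w1, b1, w2, b2)
print(y.shape)  # (4, 13)
```

With `w_g = 0` and `b_g = 0`, the gate reduces to a plain average of the two encodings. The current `DynamixModel` instead adds `state_delta` to `tac_encoding` without a gate.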

In [None]:
import torch
import torch.nn as nn


# class PositionalEncoding(nn.Module):
#     def __init__(self, d_model, max_length=16):
#         super(PositionalEncoding, self).__init__()
#         self.encoding = nn.Embedding(max_length, d_model)

#     def forward(self, x):
#         sequence_length = x.size(1)
#         positions = torch.arange(sequence_length, dtype=torch.long, device=x.device).unsqueeze(0)
#         return self.encoding(positions)

# class TimeSequenceModel(nn.Module):
#     def __init__(self, embedding_size, action_size, num_layers, num_heads=8, hidden_size=2048, max_length=16):
#         super(TimeSequenceModel, self).__init__()
#         self.embedding_size = embedding_size
#         self.action_size = action_size
#         self.max_length = max_length

#         self.positional_encoding = PositionalEncoding(embedding_size + action_size, max_length)
#         transformer_layer = nn.TransformerEncoderLayer(
#             batch_first=True,
#             d_model=self.embedding_size + self.action_size,  # inputs are concat(states, actions)
#             nhead=num_heads,
#             dim_feedforward=hidden_size,
#             dropout=0.08,
#         )
#         self.transformer = nn.TransformerEncoder(transformer_layer, num_layers)
#         self.fc = nn.Linear(embedding_size + action_size, embedding_size)

#     def forward(self, states, actions):
#         batch_size, sequence_length, _ = states.shape

#         # Concatenate states and actions
#         inputs = torch.cat((states, actions), dim=-1)

#         # Apply positional encoding
#         positions = self.positional_encoding(inputs)
#         inputs = inputs + positions

#         # Create an upper triangular mask for attention
#         mask = torch.triu(torch.ones(sequence_length, sequence_length), diagonal=1).to(states.device)
#         mask = mask.masked_fill(mask == 1, float('-inf'))

#         # Pass the inputs through the Transformer
#         outputs = self.transformer(inputs, mask=mask)  # TransformerEncoder takes a single src tensor

#         # Pass the outputs through the final fully connected layer
#         predictions = self.fc(outputs)

#         return predictions


# embedded_states = embedding_model(states)
# predictions = transformer_model(embedded_states, actions)
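The commented-out `TimeSequenceModel` relies on a causal (upper-triangular) mask, so each time step attends only to itself and earlier steps. Below is a NumPy sketch of that mask and its effect on attention weights. `causal_mask` and `masked_softmax` are hypothetical names, and the all-zero scores are chosen only to make the weights easy to read.

```python
import numpy as np


def causal_mask(seq_len: int) -> np.ndarray:
    """Additive attention mask: 0 on and below the diagonal, -inf above it
    (the layout torch.triu(..., diagonal=1) followed by masked_fill(-inf) produces)."""
    upper = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    return np.where(upper, -np.inf, 0.0)


def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Row-wise softmax of (scores + mask); masked positions get weight 0."""
    z = scores + mask
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(z)
    return w / w.sum(axis=-1, keepdims=True)


weights = masked_softmax(np.zeros((4, 4)), causal_mask(4))
print(weights[1].tolist())  # [0.5, 0.5, 0.0, 0.0]
```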

In [None]:
def normalize_data_dynamix(data):
    # Compute mean and standard deviation for each key
    mean_dict = [{}, {}]
    std_dict = [{}, {}]

    for idx in range(len(mean_dict)):
        for key in data[idx].keys():
            all_values = np.concatenate(
                [np.array(arr).flatten() for arr in data[idx][key]]
            )
            mean_dict[idx][key] = np.mean(all_values, axis=0)
            std_dict[idx][key] = np.std(all_values, axis=0)

    # Normalize the data
    NON_NORM_KEYS = ["obj_count", "progress_bar", "state_attrib"]

    normalized_data = [data[0], {}]
    for key in data[1].keys():
        if key in NON_NORM_KEYS:
            normalized_data[1][key] = data[1][key]
            continue
        normalized_data[1][key] = []
        for arr in data[1][key]:
            normalized_arr = (arr - mean_dict[1][key]) / std_dict[1][key]
            normalized_data[1][key].append(normalized_arr)

    print("Dynamix:")
    print("MEAN:", mean_dict)
    print("STD:", std_dict)

    return normalized_data


def normalize_data_critiq(data):
    # Compute mean and standard deviation for each key
    mean_dict = [{}, {}, {}]
    std_dict = [{}, {}, {}]

    for idx in range(len(mean_dict)):
        for key in data[idx].keys():
            all_values = np.concatenate(
                [np.array(arr).flatten() for arr in data[idx][key]]
            )
            mean_dict[idx][key] = np.mean(all_values, axis=0)
            std_dict[idx][key] = np.std(all_values, axis=0)

    # Normalize the data
    NON_NORM_KEYS = ["action", "state_attrib"]

    normalized_data = [data[0], data[1], {}]
    for key in data[2].keys():
        if key in NON_NORM_KEYS:
            normalized_data[2][key] = data[2][key]
            continue
        normalized_data[2][key] = []
        for arr in data[2][key]:
            normalized_arr = (arr - mean_dict[2][key]) / std_dict[2][key]
            normalized_data[2][key].append(normalized_arr)

    print("Critiq:")
    print("MEAN:", mean_dict)
    print("STD:", std_dict)

    return normalized_data


dynamix_data = normalize_data_dynamix(dynamix_data)
pred_data = normalize_data_critiq(pred_data)
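The normalization cells compute one scalar mean and std per key, then divide by that std. A constant target (std = 0) would therefore become `inf` or `nan`.

The sketch below adds an epsilon guard, plus the inverse transform needed to read predictions back in their original units. `fit_stats`, `normalize` and `denormalize` are hypothetical names. The epsilon is an assumption, not part of the notebook. Saving the per-key statistics alongside the checkpoint would let inference reuse them.

```python
import numpy as np


def fit_stats(arr_list, eps=1e-8):
    """Global (scalar) mean/std over every value of one key, as normalize_data_* compute it.
    eps guards against division by zero for constant features (an addition for this sketch)."""
    values = np.concatenate([np.asarray(a, dtype=float).ravel() for a in arr_list])
    return values.mean(), max(values.std(), eps)


def normalize(arr_list, mean, std):
    """Z-score every array of one key with the shared statistics."""
    return [(np.asarray(a, dtype=float) - mean) / std for a in arr_list]


def denormalize(arr, mean, std):
    """Map normalized model outputs back to the original units."""
    return np.asarray(arr) * std + mean


targets = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
mean, std = fit_stats(targets)
normed = normalize(targets, mean, std)
print(mean)  # 2.5
```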

In [None]:
## Model for predictions
from typing import Optional, Any, Dict, Tuple

import torch as th
import torch.nn as nn
from control_dropping_rpal.RL.Networks.ExtractorNetworks import (
    ObjectTactileEncoder_Additive,
    TemporalObjectTactileEncoder_Additive,
)

DYNAMIX_OUTPUT_SIZES_DICT = {
    "finger_1_location": 27,
    "finger_2_location": 27,
    "finger_3_location": 27,
    "palm_location": 27,
    "finger_1_tactile": 9,
    "finger_2_tactile": 9,
    "finger_3_tactile": 9,
    "palm_tactile": 9,
    "obj_location": 42,
    "hand_config": 7,
    "obj_count": 5,
    "progress_bar": 1,
    "reward": 1,
}

PRED_OUTPUT_SIZES_DICT = {
    "action": 5,
    "reward": 1,
}


class DynamixModel(nn.Module):
    """Given (State, Action): Predict (N_state)"""

    def __init__(
        self,
        embed_dim_high=1024,
        embed_dim_low=256,
        device="cuda",
        dropout_prob=0.05,
        num_tsf_layer=4,
        num_residual_blocks=4,
        vec_encoding_size=8,
        use_mask=False,
        encoder: Optional[TemporalObjectTactileEncoder_Additive] = None,
    ):
        super(DynamixModel, self).__init__()
        self.device = device
        self.to(self.device)

        self.object_encoder = (
            encoder
            if encoder
            else TemporalObjectTactileEncoder_Additive(
                observation_space=state_space,
                vec_encoding_size=vec_encoding_size,
                t_dim_size=T_buffer,
                load_pretrain=False,
                num_tsf_layer=num_tsf_layer,
                use_mask=use_mask,
            )
        )

        self.act_size = 5
        self.cat_size = self.object_encoder.flatten_size + 5  # action shape
        self.embed_dim_high = embed_dim_high
        self.embed_dim_low = embed_dim_low
        self.activation = nn.GELU
        self.join_keys = ["action"]

        self.delta_state_network = nn.Sequential(
            # CAST
            ResidualBlocks1D(
                feature_dim=vec_encoding_size + 5,
                num_blocks=num_residual_blocks,
                embed_dim=embed_dim_high,
            ),
            # GATING
            nn.Linear(vec_encoding_size + 5, vec_encoding_size),
        )

        modules = {}
        for k, size in DYNAMIX_OUTPUT_SIZES_DICT.items():
            modules[k] = nn.Sequential(
                nn.Linear(vec_encoding_size, self.embed_dim_high),
                self.activation(),
                nn.Dropout(p=dropout_prob),
                nn.Linear(self.embed_dim_high, self.embed_dim_low),
                self.activation(),
                nn.Dropout(p=dropout_prob),
                nn.Linear(self.embed_dim_low, size),
            )

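        self.networks = nn.ModuleDict(modules)
        # Move all submodules to the target device now that they exist
        # (the self.to() call at the top of __init__ runs before any submodule is created)
        self.to(self.device)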

    def forward(self, obs):
        tac_encoding = self.object_encoder(obs)
        tac_encoding = tac_encoding.view((tac_encoding.shape[0], -1))
        cat_tensor = th.concatenate([tac_encoding, obs["action"]], dim=1).to(
            self.device
        )

        state_delta = self.delta_state_network(cat_tensor)

        # Intuition:
        # Our critiq model tries to find a relationship between one state embedding and another
        # state embedding that produces an action and a reward.
        # Here we instead try to find the delta between the original state and the new state
        # caused by the action, so that we can predict it.
        embed_tnsor = state_delta + tac_encoding

        # Benefits: I can take a state at T[0] and pass S[0] through our network to
        # obtain an embedding e_0.
        # With e_0 and A_0 I can obtain e_1: feed e_0 CONCAT A_0 through the delta state network.
        # With e_1, I can obtain e_2 using A_1 via a similar process.
        # Hence, this modeling scheme is powerful if future trajectories can be reliably obtained in the embedding space of S.

        # TODO: Test utility
        # embed_tnsor = F.normalize(F.tanh(state_delta) + F.sigmoid(tac_encoding), eps=1e-7) # TODO: Test utility
        # embed_tnsor = F.tanh((state_delta + tac_encoding ) * (1 / self.object_encoder.vec_encoding_size**0.5))
        # embed_tnsor = state_delta

        logit_map = {k: net(embed_tnsor) for k, net in self.networks.items()}
        return logit_map

    def load_checkpoint(self, path=None):
        if path is None:
            path = os.path.join(self.object_encoder.save_path, "dynamix.pt")

        checkpoint = th.load(path, map_location=self.device)
        self.object_encoder.load_state_dict(checkpoint["object_encoder"])
        self.delta_state_network.load_state_dict(checkpoint["delta_state_network"])
        self.networks.load_state_dict(checkpoint["networks"])

    def save_checkpoint(self, file=None):
        print("[DynamixModel] Saving Checkpoint...")
        if file is None:
            file = os.path.join(self.object_encoder.save_path, "dynamix.pt")

        save_dict = {
            "object_encoder": self.object_encoder.state_dict(),
            "delta_state_network": self.delta_state_network.state_dict(),
            "networks": self.networks.state_dict(),
        }
        th.save(save_dict, file)


class CritiqModel(nn.Module):
    """Given (State, N_state): Predict (Action, Reward)"""

    def __init__(
        self,
        embed_dim_high=1024,
        embed_dim_low=256,
        device="cuda",
        dropout_prob=0.05,
        num_tsf_layer=4,
        num_residual_blocks=4,
        vec_encoding_size=8,
        use_mask=False,
        encoder: Optional[TemporalObjectTactileEncoder_Additive] = None,
    ):
        super(CritiqModel, self).__init__()
        self.device = device
        self.to(self.device)

        self.object_encoder = (
            encoder
            if encoder
            else TemporalObjectTactileEncoder_Additive(
                observation_space=state_space,
                vec_encoding_size=vec_encoding_size,
                t_dim_size=T_buffer,
                load_pretrain=False,
                num_tsf_layer=num_tsf_layer,
                use_mask=use_mask,
            )
        )

        self.act_size = 5
        self.cat_size = self.object_encoder.flatten_size + 5  # action shape
        self.embed_dim_high = embed_dim_high
        self.embed_dim_low = embed_dim_low
        self.activation = nn.GELU
        self.join_keys = ["action"]

        self.predictor = nn.Sequential(
            # Receives Delta State
            ResidualBlocks1D(
                feature_dim=vec_encoding_size * 2,
                num_blocks=5,
                embed_dim=embed_dim_high,
            ),
            nn.Linear(vec_encoding_size * 2, vec_encoding_size),
            self.activation(),
            nn.Dropout(p=dropout_prob),
        )

        # Time estimation:
        self.time_estimator = nn.Sequential(
            # Receives the state embedding (only referenced by the commented-out line in forward)
            nn.Linear(vec_encoding_size, vec_encoding_size),
            self.activation(),
            nn.Linear(vec_encoding_size, 1),
        )

        modules = {}
        for k, size in PRED_OUTPUT_SIZES_DICT.items():
            modules[k] = nn.Sequential(
                nn.Linear(vec_encoding_size, self.embed_dim_high),
                self.activation(),
                nn.Dropout(p=dropout_prob),
                nn.Linear(self.embed_dim_high, self.embed_dim_low),
                self.activation(),
                nn.Dropout(p=dropout_prob),
                nn.Linear(self.embed_dim_low, size),
            )

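        self.networks = nn.ModuleDict(modules)

        # Move every submodule created above to the target device
        # (the earlier self.to() call ran before any submodules existed).
        self.to(self.device)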

    def forward(self, state, n_state):
        # Feed forward S[T] and S[T+1]
        embed_state = self.object_encoder(state)
        n_embed_state = self.object_encoder(n_state)
        embed_state = embed_state.view((embed_state.shape[0], -1))
        n_embed_state = n_embed_state.view((n_embed_state.shape[0], -1))

        # We obtain a delta
        delta_state = n_embed_state - embed_state
        # delta_estimation = (1 / self.time_estimator(embed_state)) * self.predictor(delta_state)
        delta_concat = th.cat([embed_state, delta_state], dim=1).to(self.device)

        delta_estimation = self.predictor(delta_concat)

        logit_map = {k: net(delta_estimation) for k, net in self.networks.items()}

        return logit_map

    def load_checkpoint(self, path=None):
        if path is None:
            path = os.path.join(self.object_encoder.save_path, "critiq.pt")

        checkpoint = th.load(path, map_location=self.device)
        self.object_encoder.load_state_dict(checkpoint["object_encoder"])
        self.predictor.load_state_dict(checkpoint["predictor"])
        self.time_estimator.load_state_dict(checkpoint["time_estimator"])
        self.networks.load_state_dict(checkpoint["networks"])

    def save_checkpoint(self, file=None):
        print("[CritiqModel] Saving Checkpoint...")
        if file is None:
            file = os.path.join(self.object_encoder.save_path, "critiq.pt")

        save_dict = {
            "object_encoder": self.object_encoder.state_dict(),
            "predictor": self.predictor.state_dict(),
            "time_estimator": self.time_estimator.state_dict(),
            "networks": self.networks.state_dict(),
        }
        th.save(save_dict, file)


def make_dynamix_and_predictor(
    model_args: Dict[str, Any]
) -> Tuple[DynamixModel, CritiqModel]:
    # Extract common arguments
    embed_dim_high = model_args.get("embed_dim_high", 1024)
    embed_dim_low = model_args.get("embed_dim_low", 256)
    device = model_args.get("device", "cuda")
    dropout_prob = model_args.get("dropout_prob", 0.05)
    num_tsf_layer = model_args.get("num_tsf_layer", 4)
    num_residual_blocks = model_args.get("num_residual_blocks", 4)
    vec_encoding_size = model_args.get("vec_encoding_size", 8)
    use_mask = model_args.get("use_mask", False)

    # Create the shared encoder
    encoder = TemporalObjectTactileEncoder_Additive(
        observation_space=model_args.get("state_space"),
        vec_encoding_size=vec_encoding_size,
        t_dim_size=model_args.get("T_buffer"),
        load_pretrain=False,
        num_tsf_layer=num_tsf_layer,
        use_mask=use_mask,
    )

    # Create DynamixModel
    dynamix_model = DynamixModel(
        embed_dim_high=embed_dim_high,
        embed_dim_low=embed_dim_low,
        device=device,
        dropout_prob=dropout_prob,
        num_tsf_layer=num_tsf_layer,
        num_residual_blocks=num_residual_blocks,
        vec_encoding_size=vec_encoding_size,
        use_mask=use_mask,
        encoder=encoder,
    )

    # Create CritiqModel
    critiq_model = CritiqModel(
        embed_dim_high=embed_dim_high,
        embed_dim_low=embed_dim_low,
        device=device,
        dropout_prob=dropout_prob,
        num_tsf_layer=num_tsf_layer,
        num_residual_blocks=num_residual_blocks,
        vec_encoding_size=vec_encoding_size,
        use_mask=use_mask,
        encoder=encoder,
    )

    return dynamix_model, critiq_model
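`make_dynamix_and_predictor` hands the same `encoder` instance to both models, so the encoder's parameters show up in both `dynamix_model.parameters()` and `critiq_model.parameters()`. If a single optimizer is built over both models, each shared tensor should be passed only once. A minimal sketch, using a hypothetical helper `unique_by_identity` (not part of the original code) that merges iterables and drops repeats by object identity:

```python
from itertools import chain
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def unique_by_identity(*iterables: Iterable[T]) -> List[T]:
    """Merge iterables, keeping only the first occurrence of each object (compared by id)."""
    seen = set()
    merged: List[T] = []
    for item in chain(*iterables):
        if id(item) not in seen:
            seen.add(id(item))
            merged.append(item)
    return merged
```

For example, `optim.Adam(unique_by_identity(dynamix_model.parameters(), critiq_model.parameters()), lr=1e-3)`. Identity is used rather than equality because tensors compare element-wise. If instead each model gets its own optimizer over its full parameter list, the shared encoder is stepped once by each optimizer per iteration.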

In [None]:
# Dataset Collection for States/Preds
import time as t

env = BerrettHandGym(detailed_training=True, is_val=True)

sim = env.simController

# RPAL
MAX_ACTION_COEF = 2.5


def _get_sampled_action():
    action_coef = np.random.uniform() * MAX_ACTION_COEF
    action = (
        action_coef * env.action_space.sample()
        if np.random.uniform() > 0.5
        else np.array([0.0, 0.0, 0.0, 0.0, 0.0])
    )
    return action


def DataCollect(num_steps):
    done = False
    data = []
    data_point = env.reset()[0].copy()
    i = 0
    while i < num_steps:
        try:
            if done:
                save_data(
                    data,
                    os.path.join(
                        CONTROL_DROP_DIR,
                        "Data_Collection",
                        "Time_Dependent_Samples_4",
                    ),
                )
                data = []
                data_point = env.reset()[0].copy()
            i += 1
            action = _get_sampled_action()
            # if model == None else model.predict(data_point,)
            print("Action:", action)
            data_point["action"] = action
            state, reward, done, _ = env.step(action)
            # t.sleep(0.05)
            pred_state = env.simController.get_pred_state()
            pred_state["reward"] = reward
            print("Pred State:", pred_state.keys())
            data.append((data_point, pred_state))
            data_point = state
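        except Exception as e:
            # Drop the partial episode and start a fresh one.
            print(f"[DataCollect] Episode aborted: {e}")
            data = []
            data_point = env.reset()[0].copy()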


DataCollect(80000)
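`DataCollect` buffers one episode of `(input, target)` pairs and writes the buffer with `save_data` once `done` is seen. The same control flow can be checked without the simulator. The sketch below uses a hypothetical `ToyEnv` with the 4-tuple `step` signature used above and a `save_fn` callback standing in for `save_data`; it uses the next observation as the target instead of `get_pred_state()`:

```python
import random
from typing import Callable, Dict, List, Tuple

Transition = Tuple[Dict, Dict]


class ToyEnv:
    """Hypothetical stand-in for BerrettHandGym: every episode ends after `horizon` steps."""

    def __init__(self, horizon: int = 3):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return {"t": self.t}, {}

    def step(self, action):
        self.t += 1
        done = self.t >= self.horizon
        return {"t": self.t}, float(action), done, {}


def collect(env, num_steps: int, save_fn: Callable[[List[Transition]], None]) -> None:
    """Record (state + action, next state + reward) pairs, flushing one buffer per episode."""
    episode: List[Transition] = []
    obs = env.reset()[0].copy()
    for _ in range(num_steps):
        action = random.choice([0.0, 1.0])
        inputs = dict(obs, action=action)  # copy, so later edits to obs don't leak in
        next_obs, reward, done, _ = env.step(action)
        target = dict(next_obs, reward=reward)
        episode.append((inputs, target))
        obs = next_obs
        if done:
            save_fn(episode)
            episode = []
            obs = env.reset()[0].copy()
```

This version flushes immediately after the terminal step, so an episode that ends exactly on the last step is still saved; a partial episode left when `num_steps` runs out is dropped, as in the original loop.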

In [None]:
# Dataset Collection for States/Preds
import time as t

env = BerrettHandGym(detailed_training=True, is_val=True)
sim = env.simController

# RPAL
MAX_ACTION_COEF = 2.5


def _get_sampled_action():
    action_coef = np.random.uniform() * MAX_ACTION_COEF
    action = (
        action_coef * env.action_space.sample()
        if np.random.uniform() > 0.5
        else np.array([0.0, 0.0, 0.0, 0.0, 0.0])
    )
    return action


def DataCollect(num_steps):
    done = False
    data = []
    data_point = env.reset()[0].copy()
    i = 0
    while i < num_steps:
        try:
            if done:
                save_data(
                    data,
                    os.path.join(
                        CONTROL_DROP_DIR,
                        "Data_Collection",
                        "Action_Pred_Time_Dependent_Samples_4",
                    ),
                )
                data = []
                data_point = env.reset()[0].copy()
            i += 1
            action = _get_sampled_action()
            # if model == None else model.predict(data_point,)
            print("Action:", action)
            state, reward, done, _ = env.step(action)
            # t.sleep(0.05)
            pred_state = state.copy()
            pred_state["reward"] = reward
            pred_state["action"] = action
            print("Pred State:", pred_state.keys())
            data.append((data_point, pred_state))
            data_point = state
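        except Exception as e:
            # Drop the partial episode and start a fresh one.
            print(f"[DataCollect] Episode aborted: {e}")
            data = []
            data_point = env.reset()[0].copy()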


DataCollect(30000)
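This variant attaches `reward` and `action` to `state.copy()` rather than to `state` itself. The copy matters whenever an observation dict is stored as a target and is also reused as the next step's input that gets tagged in place (the vectorized collectors tag inputs with `data_points[env_id]["action"] = action`): without a copy, the stored target's `action` is silently replaced one step later. A plain-dict illustration, using a hypothetical `record` helper:

```python
def record(obs_seq, actions, copy_target: bool):
    """Build (input, target) pairs where both the input and the target carry the action."""
    data = []
    data_point = obs_seq[0]
    for next_obs, action in zip(obs_seq[1:], actions):
        data_point["action"] = action  # tag the input in place
        target = dict(next_obs) if copy_target else next_obs
        target["action"] = action  # tag the target with the action that produced it
        data.append((data_point, target))
        data_point = next_obs  # next_obs is reused as the following input
    return data
```

With `copy_target=False`, the first target is the same object as the second input, so its `action` ends up holding the second action.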

In [None]:
# Vectorized simulation data collection:

NUM_WORKERS = 4

from control_dropping_rpal.Utils.env_utils import AsyncVectorEnv
from control_dropping_rpal.RL.control_dropping_env import BerretHandGymRayLibWrapper
import time

env = AsyncVectorEnv(
    lambda: BerretHandGymRayLibWrapper(
        {
            "test": False,
            "cluster_index": 3,
            "sim_port": None,
            "object_type": "Sphere",
            "object_quantity": 7,
            "detailed_training": False,
            "detailed_save_dir": None,
            "plot_params": ["history", "avg_ke", "avg_vel", "avg_rew"],
            "is_val": False,
        }
    ),
    num_envs=NUM_WORKERS,
)


def DataCollectDynamix(env: AsyncVectorEnv, num_steps: int):
    data = {i: [] for i in range(env.num_envs)}
    data_points, _, _, _, _ = env.poll()
    data_points = data_points.copy()

    for i in range(num_steps):
        while len(data_points.keys()) == 0:  # All environments are currently resetting:
            time.sleep(0.05)
            data_points, _, _, _, _ = env.poll()
            data_points = data_points.copy()

        actions = {}
        for env_id in data_points.keys():
            action = _get_sampled_action(env.action_spaces[env_id])
            actions[env_id] = action
            data_points[env_id]["action"] = action

        env.send_actions(actions)
        new_obs, rewards, dones, _, _ = env.poll()
        new_obs = new_obs.copy()

        for env_id in new_obs:
            if env_id in data_points:
                pred_state = env.envs[env_id].agent.simController.get_pred_state()
                pred_state["reward"] = rewards[env_id]
                data[env_id].append((data_points[env_id], pred_state))

                if dones[env_id]:
                    save_data(
                        data[env_id],
                        os.path.join(
                            CONTROL_DROP_DIR,
                            "Data_Collection",
                            "Time_Dependent_Samples_3",  # Controls the path of where the collected data gets stored to, change this for new experiments
                        ),
                    )
                    data[env_id] = []

        data_points = new_obs.copy()

        print(f"Step {i+1}/{num_steps}")


def DataCollectCritiq(env: AsyncVectorEnv, num_steps: int):
    data = {i: [] for i in range(env.num_envs)}
    data_points, _, _, _, _ = env.poll()
    data_points = data_points.copy()

    for i in range(num_steps):
        while len(data_points.keys()) == 0:  # All environments are currently resetting:
            time.sleep(0.05)
            data_points, _, _, _, _ = env.poll()
            data_points = data_points.copy()

        actions = {}
        for env_id in data_points.keys():
            action = _get_sampled_action(env.action_spaces[env_id])
            actions[env_id] = action
            data_points[env_id]["action"] = action

        env.send_actions(actions)
        new_obs, rewards, dones, _, _ = env.poll()
        new_obs = new_obs.copy()

        for env_id in new_obs:
            if env_id in data_points:
                # Copy: new_obs[env_id] becomes next step's input and its "action" is overwritten there.
                pred_state = dict(new_obs[env_id])
                pred_state["reward"] = rewards[env_id]
                pred_state["action"] = actions[env_id]
                data[env_id].append((data_points[env_id], pred_state))

                if dones[env_id]:
                    save_data(
                        data[env_id],
                        os.path.join(
                            CONTROL_DROP_DIR,
                            "Data_Collection",
                            "Action_Pred_Time_Dependent_Samples_1",  # Controls the path of where the collected data gets stored to, change this for new experiments
                        ),
                    )
                    data[env_id] = []

        data_points = new_obs.copy()

        print(f"Step {i+1}/{num_steps}")


def _get_sampled_action(action_space):
    MAX_ACTION_COEF = 10.0
    action_coef = np.random.uniform() * MAX_ACTION_COEF
    action = (
        action_coef * action_space.sample()
        if np.random.uniform() > 0.5
        else np.array([0.0, 0.0, 0.0, 0.0, 0.0])
    )
    return action


# DataCollectDynamix(env, 50000)
DataCollectCritiq(env, 100000)
env.close()
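With `AsyncVectorEnv`, `poll()` only returns the environments that currently have an observation, so a transition is recorded only for `env_id`s present both before and after `send_actions`. That per-environment bookkeeping can be isolated into a pure function. `pair_step` below is a hypothetical sketch of the logic (buffers keyed by `env_id`, one flush per finished episode), not part of the `control_dropping_rpal` API:

```python
from typing import Any, Callable, Dict, List, Tuple

Pair = Tuple[dict, dict]


def pair_step(
    prev_obs: Dict[int, dict],
    new_obs: Dict[int, dict],
    rewards: Dict[int, float],
    dones: Dict[int, bool],
    actions: Dict[int, Any],
    buffers: Dict[int, List[Pair]],
    flush: Callable[[int, List[Pair]], None],
) -> None:
    """Append one (input, target) pair per env that reported both before and after the step."""
    for env_id, obs in new_obs.items():
        if env_id not in prev_obs:
            continue  # env was resetting when actions were sent: no matching input
        # Copy the observation so later in-place edits to new_obs don't alter the target.
        target = dict(obs, reward=rewards[env_id], action=actions[env_id])
        buffers.setdefault(env_id, []).append((prev_obs[env_id], target))
        if dones[env_id]:
            flush(env_id, buffers.pop(env_id))
```

Keeping this logic free of simulator calls makes it easy to test the reset and flush edge cases before launching a 100K-step collection run.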

In [None]:
# Dataset Collection for States/Preds
import time as t

env = BerrettHandGym(detailed_training=True, is_val=True)  # the vectorized env above was closed
sim = env.simController

# RPAL
MAX_ACTION_COEF = 10.0


def _get_sampled_action():
    action_coef = np.random.uniform() * MAX_ACTION_COEF
    action = (
        action_coef * env.action_space.sample()
        if np.random.uniform() > 0.5
        else np.array([0.0, 0.0, 0.0, 0.0, 0.0])
    )
    return action


def DataCollect(num_steps):
    done = False
    data = []
    data_point = env.reset()[0].copy()
    i = 0
    while i < num_steps:
        try:
            if done:
                save_data(
                    data,
                    os.path.join(
                        CONTROL_DROP_DIR,
                        "Data_Collection",
                        "Action_Pred_Time_Dependent_Samples_1",
                    ),
                )
                # Debug run: stop after the first saved episode so it can be
                # compared with the reloaded copy below.
                break
            i += 1
            action = _get_sampled_action()
            # if model == None else model.predict(data_point,)
            print("Action:", action)
            data_point["action"] = action
            state, reward, done, _ = env.step(action)
            # t.sleep(0.05)
            # pred_state = env.simController.get_pred_state()
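            # Copy before adding "reward": `state` is reused as the next input and gets an "action" key.
            target = state.copy()
            target["reward"] = reward
            print("Pred State:", target.keys())
            data.append((data_point, target))
            data_point = state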
        except:
            data = []
            data_point = env.reset()[0].copy()

    m_data = load_data_state_format(
        os.path.join(
            CONTROL_DROP_DIR,
            "Data_Collection",
            "Action_Pred_Time_Dependent_Samples_1",
        ),
    )

    print(data, m_data)


DataCollect(30000)
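This cell saves a single episode, reloads it with `load_data_state_format` and prints both for inspection. A structural comparison is easier to read than printed dicts. Below is a hypothetical `structures_match` helper (NumPy only). It assumes both sides use the same nesting, so if `load_data_state_format` returns a different layout (for example key-major instead of a list of pairs), convert one side first:

```python
import numpy as np


def structures_match(a, b, atol: float = 1e-6) -> bool:
    """Recursively compare dicts, lists/tuples and array-likes (NaNs compare equal)."""
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(structures_match(a[k], b[k], atol) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(structures_match(x, y, atol) for x, y in zip(a, b))
    try:
        x, y = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    except (TypeError, ValueError):
        return a == b  # non-numeric leaves such as strings
    return x.shape == y.shape and bool(np.allclose(x, y, atol=atol, equal_nan=True))
```

Usage in this cell would be along the lines of `print(structures_match(data, m_data))`, once both sides share a layout.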

In [None]:
## Converts loaded data into training format: every per-sample value becomes a float32 tensor with NaN/Inf replaced by 0 (dicts stay keyed by name, each holding one tensor per sample)


def convert_dataset_dynamix(
    data,
):
    num_samples = len(data[0]["palm_location"])
    state_elements = {
        "palm_tactile": [],
        "finger_1_tactile": [],
        "finger_2_tactile": [],
        "finger_3_tactile": [],
        "palm_location": [],
        "finger_1_location": [],
        "finger_2_location": [],
        "finger_3_location": [],
        "obj_location": [],
        "obj_velocity": [],
        "action": [],
        "state_attrib": [],
    }
    pred_elements = {
        "palm_tactile": [],
        "finger_1_tactile": [],
        "finger_2_tactile": [],
        "finger_3_tactile": [],
        "palm_location": [],
        "finger_1_location": [],
        "finger_2_location": [],
        "finger_3_location": [],
        "obj_location": [],
        "obj_count": [],
        "reward": [],
        "progress_bar": [],
    }

    state = data[0]
    prediction = data[1]
    for i in range(num_samples):
        for k in state_elements.keys():
            state_elements[k].append(
                th.nan_to_num(
                    th.tensor(state[k][i], dtype=th.float32),
                    nan=0.0,
                    posinf=0.0,
                    neginf=0.0,
                )
            )
        for k in pred_elements.keys():
            pred_elements[k].append(
                th.nan_to_num(
                    th.tensor(
                        np.array(
                            prediction[k][i],
                        ).flatten(),
                        dtype=th.float32,
                    ),
                    nan=0.0,
                    posinf=0.0,
                    neginf=0.0,
                )
            )

    return state_elements, pred_elements


def convert_dataset_critiq(
    data,
):
    num_samples = len(data[0]["palm_location"])
    state_elements = {
        "palm_tactile": [],
        "finger_1_tactile": [],
        "finger_2_tactile": [],
        "finger_3_tactile": [],
        "palm_location": [],
        "finger_1_location": [],
        "finger_2_location": [],
        "finger_3_location": [],
        "obj_location": [],
        "obj_velocity": [],
        "state_attrib": [],
    }

    n_state_elements = {
        "palm_tactile": [],
        "finger_1_tactile": [],
        "finger_2_tactile": [],
        "finger_3_tactile": [],
        "palm_location": [],
        "finger_1_location": [],
        "finger_2_location": [],
        "finger_3_location": [],
        "obj_location": [],
        "obj_velocity": [],
        "state_attrib": [],
    }

    pred_elements = {
        "reward": [],
        "action": [],
    }

    state = data[0]
    n_state = data[1]
    prediction = data[2]

    for i in range(num_samples):
        for k in state_elements.keys():
            state_elements[k].append(
                th.nan_to_num(
                    th.tensor(state[k][i], dtype=th.float32),
                    nan=0.0,
                    posinf=0.0,
                    neginf=0.0,
                )
            )

        for k in n_state_elements.keys():
            n_state_elements[k].append(
                th.nan_to_num(
                    th.tensor(n_state[k][i], dtype=th.float32),
                    nan=0.0,
                    posinf=0.0,
                    neginf=0.0,
                )
            )

        for k in pred_elements.keys():
            pred_elements[k].append(
                th.nan_to_num(
                    th.tensor(
                        np.array(
                            prediction[k][i],
                        ).flatten(),
                        dtype=th.float32,
                    ),
                    nan=0.0,
                    posinf=0.0,
                    neginf=0.0,
                )
            )

    return state_elements, n_state_elements, pred_elements


dynamix_data = convert_dataset_dynamix(dynamix_data)
pred_data = convert_dataset_critiq(pred_data)
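`convert_dataset_dynamix` and `convert_dataset_critiq` build one tensor per sample and per key, replacing NaN and infinite values with 0. When every sample for a key has the same shape, the same result can be stored as one stacked array per key, which avoids Python-level `th.stack` calls at batch time. A NumPy sketch with a hypothetical `sanitize_and_stack` helper (`flatten=True` mirrors the `.flatten()` applied to prediction targets):

```python
import numpy as np


def sanitize_and_stack(samples, keys, flatten: bool = False):
    """Stack per-sample values for each key into one float32 array, zeroing NaN/Inf."""
    out = {}
    for k in keys:
        rows = [np.asarray(s, dtype=np.float32) for s in samples[k]]
        if flatten:
            rows = [r.reshape(-1) for r in rows]
        arr = np.stack(rows)
        out[k] = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    return out
```

`th.from_numpy(out[k])` turns each array into a tensor that can be sliced by batch directly.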

In [None]:
## Dataset Class for Object Motion Prediction
from torch.utils.data import Dataset

BATCH_SIZE = 128


class ObjDataset(Dataset):
    def __init__(self, data, batch_size):
        self.elements, self.preds = data
        self.batch_size = batch_size
        self.num_samples = len(self.elements["action"])
        self.rand_sort()

    def rand_sort(self):
        permutations = np.random.permutation(self.num_samples)
        self.elements = {
            k: [val[p] for p in permutations] for k, val in self.elements.items()
        }
        self.preds = {
            k: [val[p] for p in permutations] for k, val in self.preds.items()
        }

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Extract the corresponding elements for the given index
        # Slice a batch of up to `batch_size` samples starting at `idx`
        # (slicing past the end is clamped, so the last sample is included).
        sample_elements = {
            key: value[idx : idx + self.batch_size]
            for key, value in self.elements.items()
        }
        y_output = {
            key: value[idx : idx + self.batch_size]
            for key, value in self.preds.items()
        }
        if isinstance(sample_elements["action"], list):
            sample_elements = {
                key: th.stack(value, dim=0) for key, value in sample_elements.items()
            }
        if isinstance(y_output["reward"], list):
            y_output = {key: th.stack(value, dim=0) for key, value in y_output.items()}
        return sample_elements, y_output


class CritiqDataset(Dataset):
    def __init__(self, data, batch_size):
        self.states, self.n_states, self.preds = data
        self.batch_size = batch_size
        self.num_samples = len(self.states["palm_location"])
        self.rand_sort()

    def rand_sort(self):
        permutations = np.random.permutation(self.num_samples)
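        self.states = {
            k: [val[p] for p in permutations] for k, val in self.states.items()
        }
        # n_states must follow the same permutation to stay paired with states.
        self.n_states = {
            k: [val[p] for p in permutations] for k, val in self.n_states.items()
        }
        self.preds = {
            k: [val[p] for p in permutations] for k, val in self.preds.items()
        }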

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Extract the corresponding elements for the given index
        # Slice a batch of up to `batch_size` samples starting at `idx`
        # (slicing past the end is clamped, so the last sample is included).
        state_batch = {
            key: value[idx : idx + self.batch_size]
            for key, value in self.states.items()
        }
        n_state_batch = {
            key: value[idx : idx + self.batch_size]
            for key, value in self.n_states.items()
        }
        pred_batch = {
            key: value[idx : idx + self.batch_size]
            for key, value in self.preds.items()
        }

        # Ensure outputs are Tensor:
        if isinstance(state_batch["palm_location"], list):
            state_batch = {
                key: th.stack(value, dim=0) for key, value in state_batch.items()
            }

        if isinstance(n_state_batch["palm_location"], list):
            n_state_batch = {
                key: th.stack(value, dim=0) for key, value in n_state_batch.items()
            }

        if isinstance(pred_batch["reward"], list):
            pred_batch = {
                key: th.stack(value, dim=0) for key, value in pred_batch.items()
            }

        return state_batch, n_state_batch, pred_batch


# # print(dataset)
# dataset = ObjDataset(dynamix_data, BATCH_SIZE)
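`CritiqDataset` holds three aligned dicts (`states`, `n_states`, `preds`), so any shuffle has to apply one permutation to all three, and batch slices should run up to `num_samples` so no sample is skipped. A NumPy sketch with hypothetical helpers `shuffle_aligned` and `batch_slices`:

```python
import numpy as np


def shuffle_aligned(*dicts, rng=None):
    """Apply one random permutation to every list in every dict so samples stay aligned."""
    rng = rng or np.random.default_rng()
    n = len(next(iter(dicts[0].values())))
    perm = rng.permutation(n)
    return tuple({k: [v[p] for p in perm] for k, v in d.items()} for d in dicts)


def batch_slices(num_samples: int, batch_size: int):
    """Yield (start, stop) pairs covering every sample exactly once; the last may be short."""
    for start in range(0, num_samples, batch_size):
        yield start, min(start + batch_size, num_samples)
```

Passing a seeded `np.random.default_rng(seed)` makes the shuffle reproducible across runs.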

In [None]:
start_lr = 0.001
end_lr = 0.00005
factor = 0.999


def lr_schedule():
    """Return the current learning rate, then decay it by `factor` (floored at `end_lr`)."""
    global start_lr
    ret_lr = start_lr
    start_lr = max(start_lr * factor, end_lr)
    return ret_lr
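The schedule above is a geometric decay, so the rate after `k` calls has the closed form `max(start_lr * factor**k, end_lr)`, and the floor is reached at the smallest `k` with `start_lr * factor**k <= end_lr`. A small sketch of both, with hypothetical function names and the cell's values as defaults:

```python
import math


def decayed_lr(step: int, start_lr: float = 1e-3, end_lr: float = 5e-5, factor: float = 0.999) -> float:
    """Learning rate after `step` multiplicative decays, floored at end_lr."""
    return max(start_lr * factor**step, end_lr)


def steps_to_floor(start_lr: float = 1e-3, end_lr: float = 5e-5, factor: float = 0.999) -> int:
    """Smallest step k with start_lr * factor**k <= end_lr."""
    return math.ceil(math.log(end_lr / start_lr) / math.log(factor))
```

With these defaults the floor is reached after 2995 steps. The value returned by `lr_schedule()` only changes training if it is written into each optimizer's `param_groups`; assuming an optimizer created with `lr=1e-3`, `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda k: decayed_lr(k) / 1e-3)` is one way to apply it.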

In [None]:
## Dataloader for Model Training
from stable_baselines3.common.utils import obs_as_tensor
from torch.utils.data import DataLoader, random_split


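def collate_fn(batch):
    # Separate inputs and targets, convert each dict observation to CUDA tensors,
    # then stack every key across the batch (th.stack cannot stack dicts directly).
    datas, preds = zip(*batch)
    datas = [obs_as_tensor(d, "cuda") for d in datas]
    preds = [obs_as_tensor(d, "cuda") for d in preds]

    return (
        {k: th.stack([d[k] for d in datas], dim=0) for k in datas[0]},
        {k: th.stack([p[k] for p in preds], dim=0) for k in preds[0]},
    )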


# # Calculate the split sizes
# train_size = int(0.92 * len(dataset))
# validation_size = len(dataset) - train_size
# train_set = (
#     {k: val[:train_size] for k, val in dataset.elements.items()},
#     {k: val[:train_size] for k, val in dataset.preds.items()},
# )
# val_set = (
#     {k: val[:-validation_size] for k, val in dataset.elements.items()},
#     {k: val[:-validation_size] for k, val in dataset.preds.items()},
# )

# print("Val Set:", val_set)

dynamix_dataset = ObjDataset(dynamix_data, BATCH_SIZE)
train_size = int(0.92 * len(dynamix_dataset))
validation_size = len(dynamix_dataset) - train_size
dynamix_train_dataset, dynamix_validation_dataset = random_split(
    dynamix_dataset, [train_size, validation_size]
)

critiq_dataset = CritiqDataset(pred_data, BATCH_SIZE)
train_size = int(0.9 * len(critiq_dataset))
validation_size = len(critiq_dataset) - train_size
critiq_train_dataset, critiq_validation_dataset = random_split(
    critiq_dataset, [train_size, validation_size]
)

# train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, )#collate_fn=collate_fn)
# validation_data_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, )#collate_fn=collate_fn)

def print_batch_summary(title, batch, with_range=False):
    """Print each tensor's shape (and optionally its min/max) for every dict in a batch tuple."""
    print(title)
    for part in batch:  # (inputs, targets) for Dynamix; (state, n_state, targets) for Critiq
        for k, v in part.items():
            if with_range:
                print(k, v.shape, v.min(), v.max())
            else:
                print(k, v.shape)
    print()


print_batch_summary("Dynamix Training Data:", dynamix_train_dataset[0], with_range=True)
print_batch_summary("Dynamix Validation Data:", dynamix_validation_dataset[0])
print_batch_summary("Critiq Training Data:", critiq_train_dataset[0], with_range=True)
print_batch_summary("Critiq Validation Data:", critiq_validation_dataset[0])
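`ObjDataset` and `CritiqDataset` return a whole batch starting at `idx`, while `random_split` draws individual start indices, so a validation batch can contain samples that also appear in training batches. Splitting at the sample level before building the datasets keeps the two sets disjoint. A sketch with hypothetical helpers `split_indices` and `subset_dict`:

```python
import numpy as np


def split_indices(num_samples: int, train_frac: float, seed: int = 0):
    """Return disjoint train/validation index arrays over individual samples."""
    perm = np.random.default_rng(seed).permutation(num_samples)
    cut = int(train_frac * num_samples)
    return perm[:cut], perm[cut:]


def subset_dict(d, idx):
    """Select the given sample indices from every per-key list."""
    return {k: [v[i] for i in idx] for k, v in d.items()}
```

For example, `train_idx, val_idx = split_indices(len(dynamix_data[0]["action"]), 0.92)` followed by `ObjDataset((subset_dict(dynamix_data[0], train_idx), subset_dict(dynamix_data[1], train_idx)), BATCH_SIZE)`; for `CritiqDataset`, the same indices would be applied to all three dicts.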

In [None]:
import torch.nn.functional as F

# TODO (priority in parentheses where assigned):
# - Try including/excluding finger values (2)
# - Compare mean vs. sum loss reduction
# - Try a Hugging Face model
# - Model Predictive Control: accuracy (1)
key_losses_dynamix = {
    "palm_tactile": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
    "finger_1_tactile": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
    "finger_2_tactile": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
    "finger_3_tactile": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
    # palm_location is excluded from the loss: it contributes a constant zero.
    "palm_location": lambda y_pred, y_target: th.zeros_like(
        F.mse_loss(y_pred, y_target)
    ),  # F.mse_loss(y_pred, y_target),
    "finger_1_location": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
    "finger_2_location": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
    "finger_3_location": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
    "obj_location": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
    "obj_count": lambda y_pred, y_target: F.cross_entropy(y_pred, y_target),
    "reward": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
}

key_losses_critiq = {
    "action": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
    "reward": lambda y_pred, y_target: F.mse_loss(y_pred, y_target),
}
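The training code consumes these tables by summing the per-key losses for keys present in both the predictions and the targets, and `palm_location` is switched off by returning zeros. A weight table is an alternative way to switch off or down-weight a head without editing the loss functions. A NumPy sketch with hypothetical names `combined_loss` and `mse`:

```python
import numpy as np


def mse(a, b) -> float:
    """Mean squared error between two array-likes."""
    return float(np.mean((np.asarray(a, dtype=float) - np.asarray(b, dtype=float)) ** 2))


def combined_loss(y_pred, y_target, losses, weights=None):
    """Weighted sum of per-key losses over keys present in predictions, targets and the table."""
    weights = weights or {}
    total = 0.0
    for key, fn in losses.items():
        if key in y_pred and key in y_target:
            total += weights.get(key, 1.0) * fn(y_pred[key], y_target[key])
    return total
```

With PyTorch tensors and `F.mse_loss` the same loop works unchanged, since `0.0 + tensor` yields a tensor that still supports `backward()`.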

In [None]:
# Training the Object Transformer!!
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.utils import clip_grad_norm_
import os
import threading
import queue


def train_model(
    dynamix_model,
    critiq_model,
    dynamix_train_loader,
    dynamix_val_loader,
    critiq_train_loader,
    critiq_val_loader,
    batch_size=16,
    epochs=1000,
    learning_rate=5e-3,
    log_interval=50,
    no_cuda=False,
    seed=1,
    is_lstm=False,
    patience=10,
    num_workers=4,
):

    use_cuda = not no_cuda and th.cuda.is_available()
    use_data_parallel = False
    # Check if multiple GPUs are available
    if th.cuda.device_count() > 1:
        print(f"Using {th.cuda.device_count()} GPUs")
        dynamix_model = nn.DataParallel(dynamix_model)
        critiq_model = nn.DataParallel(critiq_model)
        use_data_parallel = True

    device = th.device("cuda" if use_cuda else "cpu")
    print(device)

    initial_lr = 0.00001
    train_losses = []
    val_losses = []

    def dynamix_worker(model, device, data_queue, optimizer, data_lock, train_losses):
        while True:
            try:
                data, target = data_queue.get(timeout=1)
                if not data:
                    data_queue.task_done()  # every get() needs a matching task_done()
                    continue
                data = {k: v.to(device).squeeze(dim=0) for k, v in data.items()}
                target = {k: v.to(device).squeeze(dim=0) for k, v in target.items()}
                y_target = target
                with data_lock:
                    optimizer.zero_grad()
                    y_pred = model(data)
                    loss = th.stack(
                        [
                            key_losses_dynamix[key](y_pred[key], y_target[key])
                            for key in key_losses_dynamix.keys()
                            if key in y_pred and key in y_target
                        ]
                    ).sum()
                    loss.backward()
                    clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_schedule()
                    train_losses.append(loss.item())
                print("Loss:", loss.item())
                data_queue.task_done()

            except queue.Empty:
                break

    def critiq_worker(model, device, data_queue, optimizer, data_lock, train_losses):
        while True:
            try:
                state, n_state, target = data_queue.get(timeout=1)
                # if is_lstm: model.reset_hidden_state(data.shape[0])
                state = {k: v.to(device).squeeze(dim=0) for k, v in state.items()}
                n_state = {k: v.to(device).squeeze(dim=0) for k, v in n_state.items()}
                target = {k: v.to(device).squeeze(dim=0) for k, v in target.items()}
                y_target = target

                with data_lock:
                    optimizer.zero_grad()
                    y_pred = model(state, n_state)
                    loss = th.stack(
                        [
                            key_losses_critiq[key](y_pred[key], y_target[key])
                            for key in key_losses_critiq.keys()
                            if key in y_pred and key in y_target
                        ]
                    ).sum()
                    loss.backward()
                    clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_schedule()
                    train_losses.append(loss.item())
                print("Loss:", loss.item())
                data_queue.task_done()

            except queue.Empty:
                break

    def train(
        dynamix_model,
        critiq_model,
        device,
        dynamix_train_loader,
        critiq_train_loader,
        dynamix_optimizer,
        critiq_optimizer,
        num_workers=4,
        is_lstm=is_lstm,
    ):
        for model in (
            dynamix_model,
            critiq_model,
        ):
            model.train()
            model.to(device)
        total_loss = 0
        i = 0

        ####### Thread safe worker threads
        # dynamix_data_queue = queue.Queue()
        # data_lock = threading.Lock()

        # for _ in range(num_workers // 2):
        #     worker_thread = threading.Thread(
        #         target=dynamix_worker,
        #         args=(
        #             dynamix_model,
        #             device,
        #             dynamix_data_queue,
        #             dynamix_optimizer,
        #             data_lock,
        #             train_losses,
        #         ),
        #     )
        #     worker_thread.daemon = True
        #     workers.append(worker_thread)

        # for _ in range(num_workers - (num_workers // 2)):
        #     worker_thread = threading.Thread(
        #         target=critiq_worker,
        #         args=(
        #             critiq_model,
        #             device,
        #             critiq_data_queue,
        #             critiq_optimizer,
        #             data_lock,
        #             train_losses,
        #         ),
        #     )
        #     worker_thread.daemon = True
        #     workers.append(worker_thread)

        # for idx in range(0, len(dynamix_train_loader), BATCH_SIZE):
        # for worker in workers:
        #     worker.start()

        # dynamix_data_queue.join()
        # critiq_data_queue.join()

        # for worker_thread in workers:
        #     worker_thread.join()

        # return total_loss / (
        #     len(dynamix_train_loader.dataset) + len(critiq_train_loader.dataset)
        # )
        #######
        for idx in range(
            0, min(len(dynamix_train_loader), len(critiq_train_loader)), BATCH_SIZE
        ):
            local_loss = 0
            for optimizer in (dynamix_optimizer, critiq_optimizer):
                optimizer.zero_grad()

            dynamix_data, dynamix_target = dynamix_train_loader[idx]
            state, n_state, critiq_target = critiq_train_loader[idx]

            dynamix_data = {
                k: v.to(device).squeeze(dim=0) for k, v in dynamix_data.items()
            }
            y_target = {
                k: v.to(device).squeeze(dim=0) for k, v in dynamix_target.items()
            }

            y_pred = dynamix_model(dynamix_data)
            loss = th.stack(
                [
                    key_losses_dynamix[key](y_pred[key], y_target[key])
                    for key in key_losses_dynamix.keys()
                    if key in y_pred and key in y_target
                ]
            ).sum()

            loss.backward()
            local_loss += loss.item()

            state = {k: v.to(device).squeeze(dim=0) for k, v in state.items()}
            n_state = {k: v.to(device).squeeze(dim=0) for k, v in n_state.items()}
            y_target = {
                k: v.to(device).squeeze(dim=0) for k, v in critiq_target.items()
            }

            y_pred = critiq_model(state, n_state)

            loss = th.stack(
                [
                    key_losses_critiq[key](y_pred[key], y_target[key])
                    for key in key_losses_critiq.keys()
                    if key in y_pred and key in y_target
                ]
            ).sum()

            loss.backward()
            local_loss += loss.item()
            print("Loss:", local_loss)

            for model in (dynamix_model, critiq_model):
                clip_grad_norm_(model.parameters(), 1.5)

            for optimizer in (dynamix_optimizer, critiq_optimizer):
                optimizer.step()

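            lr_schedule()
            train_losses.append(local_loss)
            total_loss += local_loss
            i += 1

        # Mean combined (dynamics + critic) loss per batch for this epoch
        return total_loss / max(i, 1)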

    def validation(
        dynamix_model,
        critiq_model,
        device,
        dynamix_val_loader,
        critiq_val_loader,
        is_lstm=is_lstm,
    ):
        for model in (
            dynamix_model,
            critiq_model,
        ):
            model.eval()

        loss_total = 0.0
        num_batches = 0
        with th.no_grad():
            for idx in range(0, len(dynamix_val_loader), BATCH_SIZE):
                data, target = dynamix_val_loader[idx]
                # if is_lstm: model.reset_hidden_state(data.shape[0])
                data = {k: v.to(device).squeeze(dim=0) for k, v in data.items()}
                target = {k: v.to(device).squeeze(dim=0) for k, v in target.items()}

                y_target = target
                y_pred = dynamix_model(data)

                # Sum the per-key losses, matching the objective used in train()
                loss = th.stack(
                    [
                        key_losses_dynamix[key](y_pred[key], y_target[key])
                        for key in key_losses_dynamix.keys()
                        if key in y_pred and key in y_target
                    ]
                ).sum()
                loss_total += loss.item()
                num_batches += 1

            for idx in range(0, len(critiq_val_loader), BATCH_SIZE):
                state, n_state, target = critiq_val_loader[idx]
                # if is_lstm: model.reset_hidden_state(data.shape[0])
                state = {k: v.to(device).squeeze(dim=0) for k, v in state.items()}
                n_state = {k: v.to(device).squeeze(dim=0) for k, v in n_state.items()}
                target = {k: v.to(device).squeeze(dim=0) for k, v in target.items()}

                y_target = target
                y_pred = critiq_model(state, n_state)

                loss = th.stack(
                    [
                        key_losses_critiq[key](y_pred[key], y_target[key])
                        for key in key_losses_critiq.keys()
                        if key in y_pred and key in y_target
                    ]
                ).sum()
                loss_total += loss.item()
                num_batches += 1

        # Average over the number of validation batches actually evaluated
        val_loss = loss_total / max(num_batches, 1)
        val_losses.append(val_loss)
        print("Validation_loss:", val_loss)
        return val_loss

    # params = []
    for model in (
        dynamix_model,
        critiq_model,
    ):
        model.to(device)
    #     params.extend(model.parameters())

    warmup_epochs = 50

    dynamix_optimizer = optim.Adam(dynamix_model.parameters(), lr=initial_lr)
    # Cosine annealing only runs after warmup, so it spans the remaining epochs
    dynamix_scheduler = CosineAnnealingLR(
        dynamix_optimizer, T_max=max(epochs - warmup_epochs, 1), eta_min=0, verbose=True
    )

    critiq_optimizer = optim.Adam(critiq_model.parameters(), lr=initial_lr)
    critiq_scheduler = CosineAnnealingLR(
        critiq_optimizer, T_max=max(epochs - warmup_epochs, 1), eta_min=0, verbose=True
    )

    print("Training...")
    for epoch in range(1, epochs + 1):
        if epoch < warmup_epochs:
            warmup_lr = initial_lr * (epoch / warmup_epochs)
            for optimizer in (dynamix_optimizer, critiq_optimizer):
                for param_group in optimizer.param_groups:
                    param_group["lr"] = warmup_lr
            train(
                dynamix_model,
                critiq_model,
                device,
                dynamix_train_loader,
                critiq_train_loader,
                dynamix_optimizer,
                critiq_optimizer,
                num_workers=num_workers,
            )
            validation(
                dynamix_model,
                critiq_model,
                device,
                dynamix_val_loader,
                critiq_val_loader,
            )
        else:
            dynamix_scheduler.step()
            critiq_scheduler.step()
            train_loss = train(
                dynamix_model,
                critiq_model,
                device,
                dynamix_train_loader,
                critiq_train_loader,
                dynamix_optimizer,
                critiq_optimizer,
            )
            if epoch % 10 == 0:
                val_loss = validation(
                    dynamix_model,
                    critiq_model,
                    device,
                    dynamix_val_loader,
                    critiq_val_loader,
                )
            if epoch % 50 == 0:
                # Checkpoint both networks (unwrap DataParallel's .module if used)
                for model in (dynamix_model, critiq_model):
                    (model.module if use_data_parallel else model).save_checkpoint()

    plt.figure()
    plt.plot(range(len(train_losses)), train_losses, label="Train Loss")
    plt.xlabel("Training batch")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig("train_loss.png")
    plt.figure()
    plt.plot(range(len(val_losses)), val_losses, label="Val Loss")
    plt.xlabel("Validation run")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig("val_loss.png")
    plt.show()
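`train_model` ramps the learning rate linearly for the first `warmup_epochs` epochs and then hands control to `CosineAnnealingLR`. To sanity-check the shape of that schedule, the sketch below computes the per-epoch learning rate in closed form. It assumes the cosine phase spans `total_epochs - warmup_epochs` epochs and uses the closed-form cosine formula from the PyTorch documentation; the scheduler's step-by-step updates in the notebook can differ slightly (for example, the cosine phase starts from the last warmup LR), so treat this as an approximation. `warmup_cosine_lr` is a hypothetical helper, not part of the notebook.

```python
import math


def warmup_cosine_lr(epoch: int, initial_lr: float, warmup_epochs: int,
                     total_epochs: int, eta_min: float = 0.0) -> float:
    """Approximate LR for a 1-indexed epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Linear ramp, as in the warmup branch: initial_lr * epoch / warmup_epochs
        return initial_lr * epoch / warmup_epochs
    # Cosine phase: t runs from 0 at the end of warmup to t_max at the last epoch
    t = epoch - warmup_epochs
    t_max = max(total_epochs - warmup_epochs, 1)
    cos_factor = (1 + math.cos(math.pi * t / t_max)) / 2
    return eta_min + (initial_lr - eta_min) * cos_factor


if __name__ == "__main__":
    # Values matching the notebook's call: learning_rate=0.0008, 50 warmup epochs, 1000 epochs
    for e in (1, 25, 49, 50, 525, 1000):
        print(e, warmup_cosine_lr(e, 8e-4, 50, 1000))
```

Plotting these values next to `train_losses` can help tell whether loss spikes line up with the end of warmup.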

In [None]:
## Params and making the models:
NUM_LAYERS_TRANSFORMER = 8
NUM_RESIDUALS = 8
EPOCHS = 1000
VEC_ENCODING_SIZE = 512

MODEL_ARGS = {
    "vec_encoding_size": VEC_ENCODING_SIZE,
    "num_residuals": NUM_RESIDUALS,
    "num_tsf_layer": NUM_LAYERS_TRANSFORMER,
    "use_mask": True,
    "dropout_prob": 0.01,
    "embed_dim_low": VEC_ENCODING_SIZE,
    "T_buffer": T_buffer,
    "state_space": state_space,
}

dynamix_model, critiq_model = make_dynamix_and_predictor(MODEL_ARGS)

In [None]:
# Learning-rate warmup (a linear ramp inside train_model) stabilizes early training.

train_model(
    log_interval=3,
    learning_rate=0.0008,
    dynamix_model=dynamix_model,
    critiq_model=critiq_model,
    epochs=EPOCHS,
    dynamix_train_loader=dynamix_train_dataset,
    dynamix_val_loader=dynamix_validation_dataset,
    critiq_train_loader=critiq_train_dataset,
    critiq_val_loader=critiq_validation_dataset,
    batch_size=BATCH_SIZE,
    is_lstm=False,
)

In [None]:
dynamix_model.save_checkpoint(
    f"./temporal_dropping_state_predictor-{NUM_LAYERS_TRANSFORMER}_layers-{NUM_RESIDUALS}_residuals-{VEC_ENCODING_SIZE}-vecencoding_size-{EPOCHS}_epochs.pt"
)

critiq_model.save_checkpoint(
    f"./temporal_dropping_state_critic-{NUM_LAYERS_TRANSFORMER}_layers-{NUM_RESIDUALS}_residuals-{VEC_ENCODING_SIZE}-vecencoding_size-{EPOCHS}_epochs.pt"
)

In [None]:
dynamix_model.load_checkpoint(
    f"./temporal_dropping_state_predictor-{NUM_LAYERS_TRANSFORMER}_layers-{NUM_RESIDUALS}_residuals-{VEC_ENCODING_SIZE}-vecencoding_size-{EPOCHS}_epochs.pt"
)

critiq_model.load_checkpoint(
    f"./temporal_dropping_state_critic-{NUM_LAYERS_TRANSFORMER}_layers-{NUM_RESIDUALS}_residuals-{VEC_ENCODING_SIZE}-vecencoding_size-{EPOCHS}_epochs.pt"
)

encoder = dynamix_model.object_encoder

encoder.save_checkpoint(
    f"./pretrained_object_encoder-{NUM_LAYERS_TRANSFORMER}_layers-{NUM_RESIDUALS}_residuals-{VEC_ENCODING_SIZE}-vecencoding_size-{EPOCHS}_epochs.pt"
)
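The checkpoint file names in these cells are assembled by hand from the hyperparameters in several places, which makes it easy to drop an `f` prefix and end up saving to a literal `{NUM_LAYERS_TRANSFORMER}` path. One way to keep them consistent is to build the name in a single place. `checkpoint_path` below is a hypothetical helper that only reproduces the naming pattern used here.

```python
def checkpoint_path(prefix: str, num_layers: int, num_residuals: int,
                    vec_encoding_size: int, epochs: int, directory: str = ".") -> str:
    """Hypothetical helper reproducing the checkpoint naming pattern used above."""
    return (
        f"{directory}/{prefix}-{num_layers}_layers-{num_residuals}_residuals-"
        f"{vec_encoding_size}-vecencoding_size-{epochs}_epochs.pt"
    )


if __name__ == "__main__":
    # e.g. the encoder checkpoint for NUM_LAYERS_TRANSFORMER=8, NUM_RESIDUALS=8,
    # VEC_ENCODING_SIZE=512, EPOCHS=1000
    print(checkpoint_path("pretrained_object_encoder", 8, 8, 512, 1000))
```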

In [None]:
# x, y = next(iter(train_data_loader))
# x = {k: v.to('cuda').float() for k, v in x.items()}
# y = y.to('cuda')
# print((y - model(x)).norm() / y.shape[0])
# model.load_checkpoint(f"./dropping_state_predictor-{NUM_LAYERS_TRANSFORMER}_layers-{NUM_RESIDUALS}_residuals-{VEC_ENCODING_SIZE}-vecencoding_size-{EPOCHS}_epochs.pt")
# model.to(model.device)
# model.object_encoder.use_mask = False
model = dynamix_model
model.to(model.device)
model.eval()
model.object_encoder.use_mask = False
print("model:", model.cat_size)
acc = 0.0
total = 0
with th.no_grad():
    for i in range(0, len(dynamix_validation_dataset), BATCH_SIZE):
        data, target = dynamix_validation_dataset[i]
        # if is_lstm: model.reset_hidden_state(data.shape[0])
        data = {k: v.to("cuda").squeeze(dim=0) for k, v in data.items()}
        target = {k: v.to("cuda").squeeze(dim=0) for k, v in target.items()}

        y_target = target
        y_pred = model(data)

        # Softmax preserves the ordering of the logits, so argmax over raw logits picks the same class
        correct = th.argmax(y_target["obj_count"], dim=1) == th.argmax(
            y_pred["obj_count"], dim=1
        )
        acc += correct.float().mean().item()
        total += 1

print("obj_count accuracy:", acc / total)

data, target = dynamix_validation_dataset[0]
# if is_lstm: model.reset_hidden_state(data.shape[0])
data = {k: v.to("cuda").squeeze(dim=0) for k, v in data.items()}
target = {k: v.to("cuda").squeeze(dim=0) for k, v in target.items()}

y_target = target
y_pred = model(data)

# Mean absolute relative error (a fraction; multiply by 100 for percent):
print(((y_target["reward"] - y_pred["reward"]) / y_target["reward"]).abs().mean())
print(
    ((y_target["obj_location"] - y_pred["obj_location"]) / y_target["obj_location"])
    .abs()
    .mean()
)

## Gate_SigTanh: 98%, Gate_None: 96%,

# print([(model(x), y) for i in range(len(y))])

model = critiq_model
model.to(model.device)
model.eval()

mse_total = 0.0
total = 0
with th.no_grad():
    for i in range(0, len(critiq_validation_dataset), BATCH_SIZE):
        state, n_state, target = critiq_validation_dataset[i]
        # if is_lstm: model.reset_hidden_state(data.shape[0])
        state = {k: v.to("cuda").squeeze(dim=0) for k, v in state.items()}
        n_state = {k: v.to("cuda").squeeze(dim=0) for k, v in n_state.items()}
        target = {k: v.to("cuda").squeeze(dim=0) for k, v in target.items()}

        y_target = target
        y_pred = model(state, n_state)

        # Per-batch mean squared error of the predicted actions
        mse_total += ((y_target["action"] - y_pred["action"]) ** 2).mean().item()
        total += 1

print("MSE Actions:", mse_total / total)

state, n_state, target = critiq_validation_dataset[0]
# if is_lstm: model.reset_hidden_state(data.shape[0])
state = {k: v.to("cuda").squeeze(dim=0) for k, v in state.items()}
n_state = {k: v.to("cuda").squeeze(dim=0) for k, v in n_state.items()}
target = {k: v.to("cuda").squeeze(dim=0) for k, v in target.items()}

y_target = target
y_pred = model(state, n_state)

# Mean absolute relative error (a fraction; multiply by 100 for percent):
print(
    "reward:",
    ((y_target["reward"] - y_pred["reward"]) / y_target["reward"]).abs().mean(),
)
print("MSE Action:", ((y_target["action"] - y_pred["action"]) ** 2).mean())
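The evaluation cells above score `obj_count` by comparing argmax classes and report `reward` / `obj_location` errors as `|y - y_hat| / |y|`. Two properties are worth keeping in mind: softmax preserves the ordering of logits within a row, so applying it before the argmax does not change the predicted class; and the relative error is undefined (or explodes) when a target is zero. The framework-free sketch below illustrates both. Skipping targets with `|y| <= eps` is an assumed policy chosen for illustration, not what the notebook does, and all function names are hypothetical.

```python
from typing import List, Sequence


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def count_accuracy(targets_one_hot: Sequence[Sequence[float]],
                   logits: Sequence[Sequence[float]]) -> float:
    """Fraction of rows whose predicted class matches the one-hot target.

    Softmax keeps the ordering of a row's logits, so the argmax of the raw
    logits is the same class as the argmax of the softmax probabilities.
    """
    hits = sum(argmax(t) == argmax(p) for t, p in zip(targets_one_hot, logits))
    return hits / len(targets_one_hot)


def mean_relative_error(target: Sequence[float], pred: Sequence[float],
                        eps: float = 1e-8) -> float:
    """Mean of |y - y_hat| / |y|, skipping targets with |y| <= eps (assumed policy)."""
    errors: List[float] = [
        abs(y - p) / abs(y) for y, p in zip(target, pred) if abs(y) > eps
    ]
    return sum(errors) / len(errors) if errors else float("nan")


if __name__ == "__main__":
    print(count_accuracy([[0, 1, 0], [1, 0, 0]], [[0.1, 2.0, -1.0], [0.3, 0.2, 0.9]]))
    print(mean_relative_error([2.0, 0.0, -4.0], [1.0, 5.0, -3.0]))
```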

In [None]:
activation = {}


def get_activation(name):
    def hook(model, input, output):
        activation[name] = output[0].detach()

    return hook


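# Capture the self-attention output of transformer layer 5 in the dynamics model's object encoder.
# nn.MultiheadAttention returns (attn_output, attn_weights), so the hook stores output[0].
# Index 5 is the last layer only for a 6-layer encoder.
model = dynamix_model
self_attn = model.object_encoder.trns_encoder.layers[5].self_attn
self_attn._forward_hooks.clear()  # drop hooks registered by earlier runs of this cell
self_attn.register_forward_hook(get_activation("last_layer_activation"))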

In [None]:
data, target = dynamix_validation_dataset[1]
# if is_lstm: model.reset_hidden_state(data.shape[0])
data = {k: v.to("cuda").squeeze(dim=0) for k, v in data.items()}
target = {k: v.to("cuda").squeeze(dim=0) for k, v in target.items()}

y_target = target
y_pred = model(data)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Extract the activations captured by the forward hook
last_layer_activations = activation["last_layer_activation"]
activations = last_layer_activations.cpu()
# Expected shape: (batch_size, num_tokens, encoding_vector_size). This assumes the
# encoder uses batch_first=True; swap the first two axes otherwise.
batch_size, num_objs, encoding_vector_size = activations.shape
# Object-count class of sample 2, for reference
print(th.argmax(y_target["obj_count"][2]))
plt.figure(figsize=(12, 4))  # Adjust the figure size as needed
# Heatmap of a single vector: the last token of sample 2
sns.heatmap([activations.numpy()[2, -1]], cmap="viridis")
plt.ylabel("Sample 2, last token")
plt.xlabel("Encoding Vector Dimension")
plt.title("Heatmap of Last Layer Activations")
plt.show()

In [None]:
fig, axs = plt.subplots(
    1, 2, figsize=(24, 6)
)  # 24 width to accommodate both side by side, 6 height

# Heatmap for index 2
sns.heatmap(
    activations[2, :].reshape(-1, activations.shape[-1]), ax=axs[0], cmap="viridis"
)
axs[0].set_title("Heatmap of Last Layer Activations for Index 2")
axs[0].set_ylabel("Token index (sample 2)")
axs[0].set_xlabel("Encoding Vector Dimension")

# Heatmap for index 3
sns.heatmap(
    activations[3, :].reshape(-1, activations.shape[-1]), ax=axs[1], cmap="viridis"
)
axs[1].set_title("Heatmap of Last Layer Activations for Index 3")
axs[1].set_ylabel("Token index (sample 3)")
axs[1].set_xlabel("Encoding Vector Dimension")

In [None]:
from torch.nn.functional import cosine_similarity

# Cosine similarity (a . b) / (||a|| * ||b||) between the activation slice of sample i and
# those of samples i + 1 and i + 3 (each slice: all tokens, last encoding dimension)
for i in range(activations.shape[0] - 3):
    print(f"Index: {i}:")
    a = activations[i, :, -1]
    print(cosine_similarity(a, activations[i + 1, :, -1], dim=0))
    print(cosine_similarity(a, activations[i + 3, :, -1], dim=0))
    print(a.mean())

activations.shape
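Cosine similarity normalizes the dot product by the *product* of the two norms, which bounds it to [-1, 1] and makes it scale-invariant. Dividing by the *sum* of the norms does not give a cosine: for two identical unit vectors `[1, 0]` it yields `1 / (1 + 1) = 0.5` instead of 1. The plain-Python sketch below shows the definition; the `eps` guard against zero vectors is an added assumption and is not claimed to match PyTorch's exact epsilon handling.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float], eps: float = 1e-8) -> float:
    """cos(a, b) = (a . b) / (||a|| * ||b||), with the denominator clamped to at least eps."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


if __name__ == "__main__":
    print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))   # identical direction
    print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))   # orthogonal
    print(cosine_similarity([1.0, 2.0], [10.0, 20.0]))  # same direction, different scale
```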

In [None]:
### Model Predictive Control
model = DynamixModel()
model.load_checkpoint("./save_embed.pt")
model.to("cuda")
model.eval()

In [None]:
from typing import Dict, Union


def sample_model_actions(
    state: Dict[str, np.ndarray],
    k_num_actions: int = 128,
    device: Union[th.device, str] = "cuda",
):
    """Random-shooting MPC step: score k candidate actions with the model and return the best."""
    global env, model

    def sample_action():
        # Sampled env actions are divided by 3.5 before being passed to the model
        action = env.action_space.sample()
        return action / 3.5

    # Sanitize the observation: NaNs and values with |x| > 4**2 are set to 0
    clean_state = {}
    for key, obs in state.items():
        arr = np.where(np.isnan(obs), 0, obs)
        arr = np.where(np.abs(arr) > 4**2, 0, arr)
        if 1 in arr.shape and len(arr.shape) > 1:
            arr = arr.squeeze()
        clean_state[key] = arr
    # Repeat the observation once per candidate action
    batch = {k: [val.copy()] * k_num_actions for k, val in clean_state.items()}
    # Sample k-1 actions and add the no-op as candidate 0, so the model can always choose no action
    batch["action"] = [np.array([0, 0, 0, 0, 0])] + [
        sample_action() for _ in range(k_num_actions - 1)
    ]

    # Stack each key into a (k, ...) float32 tensor on the target device
    state_tensors = {
        k: th.stack(
            [
                th.nan_to_num(
                    th.tensor(v, dtype=th.float32), nan=0.0, posinf=0.0, neginf=0.0
                ).to(device)
                for v in val
            ],
            dim=0,
        )
        for k, val in batch.items()
    }

    with th.no_grad():
        preds = model(state_tensors)

    # Select the candidate with the highest predicted reward
    rewards = preds["reward"].cpu().numpy()
    best_action_idx = np.argmax(rewards)

    # Undo the 3.5 scaling before returning the action to the env
    return batch["action"][best_action_idx] * 3.5
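`sample_model_actions` is a random-shooting form of model predictive control: sanitize the observation, propose many candidate actions (always including the no-op), score each with the learned model, and execute the best one. The framework-free sketch below isolates that logic. `reward_model` and `sample_action` are hypothetical stand-ins for the dynamics model's `reward` head and `env.action_space.sample`, and the 16.0 limit mirrors the `4**2` bound used above.

```python
import math
import random
from typing import Callable, List, Sequence

Action = List[float]


def sanitize(obs: Sequence[float], limit: float = 16.0) -> List[float]:
    """Zero out NaNs and values with |x| > limit (including +/-inf)."""
    return [0.0 if math.isnan(x) or abs(x) > limit else float(x) for x in obs]


def random_shooting(
    obs: Sequence[float],
    reward_model: Callable[[List[float], Action], float],
    sample_action: Callable[[], Action],
    action_dim: int,
    k: int = 128,
) -> Action:
    """Score k candidate actions with a (learned) reward model and return the best one."""
    clean = sanitize(obs)
    # Candidate 0 is the no-op, so "do nothing" is always available
    candidates: List[Action] = [[0.0] * action_dim]
    candidates += [sample_action() for _ in range(k - 1)]
    scores = [reward_model(clean, a) for a in candidates]
    # max() keeps the first index on ties, so the no-op wins ties
    best_idx = max(range(len(candidates)), key=lambda i: scores[i])
    return candidates[best_idx]


if __name__ == "__main__":
    rng = random.Random(0)

    def toy_model(o: List[float], a: Action) -> float:
        # Hypothetical reward: highest when the action cancels the observation
        return -sum((x + y) ** 2 for x, y in zip(o, a))

    action = random_shooting(
        [0.3, float("nan"), -0.2],
        toy_model,
        lambda: [rng.uniform(-1, 1) for _ in range(3)],
        action_dim=3,
        k=256,
    )
    print(action)
```

Because the search is sample-based, the quality of the chosen action scales with `k` (the notebook uses 512 per step) at the cost of one batched model call per step.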

In [None]:
# Run model predictive control in the simulator
import time as t

sim = env.simController


def ModelPredictiveControl(num_steps):
    state = env.reset()[0].copy()
    i = 0
    while i < num_steps:
        i += 1
        action = sample_model_actions(state, 512)  # score 512 candidate actions
        print("Action:", action)
        # Gymnasium API: step returns (obs, reward, terminated, truncated, info)
        state, reward, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            state = env.reset()[0].copy()
        t.sleep(0.25)


ModelPredictiveControl(30000)
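Under the Gymnasium API (which the `env.reset()[0]` calls here follow), `reset` returns `(obs, info)` and `step` returns `(obs, reward, terminated, truncated, info)`. An episode is over when either flag is set, and the loop should reset right away so a stale "done" flag is never carried into the next episode. The sketch below shows that loop pattern against `CountdownEnv`, a hypothetical stand-in environment, and `run_controller`, a hypothetical driver; neither is part of the project.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class CountdownEnv:
    """Hypothetical stand-in env with Gymnasium-style reset/step return values."""

    def __init__(self, horizon: int = 3) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(
        self, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[int, Dict[str, Any]]:
        self.t = 0
        return self.t, {}

    def step(self, action: float) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
        self.t += 1
        terminated = self.t >= self.horizon  # episode ends after `horizon` steps
        truncated = False
        return self.t, float(action), terminated, truncated, {}


def run_controller(env: Any, policy: Callable[[Any], Any], num_steps: int) -> int:
    """Step the env num_steps times, resetting whenever an episode ends.

    Returns the number of completed episodes.
    """
    state, _ = env.reset()
    episodes = 0
    for _ in range(num_steps):
        action = policy(state)
        state, reward, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            episodes += 1
            state, _ = env.reset()  # start a fresh episode immediately
    return episodes


if __name__ == "__main__":
    print(run_controller(CountdownEnv(horizon=3), policy=lambda s: 1.0, num_steps=10))
```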

### Reinforcement Learning with Stable-Baselines3

In [4]:
# ## Initialization: use the env with Stable-Baselines3 or the Ray wrapper
import os
import warnings

warnings.filterwarnings("ignore")
import torch as th
from torch import multiprocessing
from control_dropping_rpal.RL.control_dropping_env import (
    BerretHandGymRayLibWrapper,
    T_buffer,
    default_action_space,
)

from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
import gymnasium as gym
from torchrl.envs import GymWrapper
from tqdm import tqdm
from gymnasium.utils.env_checker import check_env

import logging

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

th.set_default_device("cuda")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define constants
DATA_SAVE_PATH = os.path.join(CONTROL_DROP_DIR, "Data_Collection")
NUM_WORKERS = 1
ENVS_PER_WORKER = 1
SAVE_INTERVAL = 10
NUM_INTERACTIONS = 0

NUM_LAYERS_TRANSFORMER = 6
NUM_RESIDUALS = 3
EPOCHS = 1000
VEC_ENCODING_SIZE = 128

EMBEDDING_CHKPOINT = os.path.join(CONTROL_DROP_DIR, "extracted_encoder_state_dict.pth")

model_config = {
    "temporal_dim": T_buffer,  # T_buffer is imported from control_dropping_env
    "obj_encoder_vec_encoding_size": VEC_ENCODING_SIZE,
    "obj_encoder_num_tsf_layer": NUM_LAYERS_TRANSFORMER,
    "obj_encoder_load_path": EMBEDDING_CHKPOINT,
    "obj_encoder_freeze_params": True,  # Set to True if you want to freeze the pretrained encoder
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

"Model Config:", model_config

('Model Config:',
 {'temporal_dim': 5,
  'obj_encoder_vec_encoding_size': 128,
  'obj_encoder_num_tsf_layer': 6,
  'obj_encoder_load_path': '/media/rpal/Drive_10TB/John/ControlDrop-RPAL/extracted_encoder_state_dict.pth',
  'obj_encoder_freeze_params': True,
  'device': 'cuda'})

In [5]:
## PPO hyperparameters and environment configuration
is_fork = multiprocessing.get_start_method() == "fork"
device = th.device(0) if th.cuda.is_available() and not is_fork else th.device("cpu")
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0

frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000

sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = 0.2  # clip range (epsilon) of PPO's clipped surrogate objective
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

env_config = dict(
    detailed_training=True,
    object_quantity=7,
)

# # Add any configuration parameters needed for your environment
# base_env = BerretHandGymRayLibWrapper(env_config)

# # Wrap the environment with GymWrapper
# env = GymWrapper(base_env)
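`gamma` and `lmbda` above are the discount factor and the Generalized Advantage Estimation (GAE) lambda, and `clip_epsilon` bounds how far PPO's probability ratio can move the objective. The plain-Python sketch below shows both computations for a single rollout and a single sample. It illustrates the standard formulas (GAE's backward recursion and PPO's `min(r*A, clip(r, 1-eps, 1+eps)*A)`); it is not Stable-Baselines3's implementation, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def compute_gae(
    rewards: Sequence[float],
    values: Sequence[float],
    dones: Sequence[float],
    last_value: float,
    gamma: float = 0.99,
    lmbda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """GAE over one rollout; dones[t] = 1.0 if the episode ended at step t.

    delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    Returns (advantages, returns), where returns = advantages + values.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]  # no bootstrapping across episode boundaries
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lmbda * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


def ppo_clip_objective(ratio: float, advantage: float, clip_epsilon: float = 0.2) -> float:
    """Per-sample clipped surrogate: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
    return min(ratio * advantage, clipped * advantage)


if __name__ == "__main__":
    adv, ret = compute_gae([1.0, 0.0, 1.0], [0.5, 0.4, 0.6], [0.0, 0.0, 1.0], last_value=0.0)
    print(adv, ret)
    print(ppo_clip_objective(1.5, 2.0))  # gain is capped at (1 + eps) * A
```

PPO maximizes the mean of this objective over a batch; the `min` makes the bound pessimistic, so large ratio changes cannot inflate the objective when the advantage is positive.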

In [7]:
from control_dropping_rpal.RL.control_dropping_env import BerretHandGymRayLibWrapper
import numpy as np

# Create the environment

env = BerretHandGymRayLibWrapper(env_config)

# reset_val = env.reset()
# step_val = env.step(np.array([0, 0, 0, 0, 0]))

# # print("Reset Val:", reset_val)
# # print("Step Val:", step_val)
# # print("Obs space:", env.observation_space)
# # check_env(env)

INFO:root:Creating SIM on port 58691
INFO:root:Connected to remote API server
INFO:root:Finshed loading tactile handles.
INFO:root:Finished loading armjoint handles.
INFO:root:Finished loading hand handles.
INFO:root:Finished loading Object handles.
INFO:root:Finished loading Tactile handles.
INFO:root:Finished loading Palm Handles.
INFO:root:Finished loading Cam Handles.


In [None]:
# # Debugs the environment:
# print("Obs Space:")
# for k, v in env.observation_space.items():
#     print(f"\t{k}: {v.shape}")

# print("\nReset:")
# for k, v in reset_val[0].items():
#     print(f"\t{k}: {v.shape}")

# print("\nStep:")
# for k, v in step_val[0].items():
#     print(f"\t{k}: {v.shape}")

# # Check for differences
# obs_space_keys = set(env.observation_space.keys())
# reset_keys = set(reset_val[0].keys())
# step_keys = set(step_val[0].keys())

# print("\nDifferent keys:")
# print(f"In obs_space but not in reset: {obs_space_keys - reset_keys}")
# print(f"In obs_space but not in step: {obs_space_keys - step_keys}")
# print(f"In reset but not in obs_space: {reset_keys - obs_space_keys}")
# print(f"In step but not in obs_space: {step_keys - obs_space_keys}")

# print("\nDifferent shapes:")
# for k in obs_space_keys.intersection(reset_keys, step_keys):
#     reset_shape = reset_val[0][k].shape
#     obs_shape = env.observation_space[k].shape
#     if reset_shape != obs_shape:
#         print(f"\t{k}: Reset shape {reset_shape}, Obs-space shape {obs_shape}")

In [8]:
from stable_baselines3 import PPO
from control_dropping_rpal.RL.Networks.ActorCriticNetwork import ControlDropPolicy


EXPERIMENT_NAME = "dynamix_control_drop"  # You can change this to any name you like

# PPO Parameters:
NUM_STEPS = 512


# Create the tensorboard log directory
tensorboard_log_dir = os.path.join(CONTROL_DROP_DIR, "rl_logs", EXPERIMENT_NAME)
os.makedirs(tensorboard_log_dir, exist_ok=True)

# Initialize the PPO model with the custom policy and tensorboard logging
model = PPO(
    policy=ControlDropPolicy,
    env=env,
    policy_kwargs={"model_config": model_config},
    verbose=1,
    tensorboard_log=tensorboard_log_dir,
    ## Hyper Parameters:
    n_steps=NUM_STEPS,
)


# Train the model
model.learn(total_timesteps=100000, tb_log_name=EXPERIMENT_NAME)

# Save the model
model.save(os.path.join(CONTROL_DROP_DIR, f"{EXPERIMENT_NAME}_model"))

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Obs Space: Dict('actions': Box(-inf, inf, (8, 5), float32), 'finger_1_location': Box(-inf, inf, (5, 104), float32), 'finger_1_tactile': Box(-inf, inf, (5, 24), float32), 'finger_2_location': Box(-inf, inf, (5, 104), float32), 'finger_2_tactile': Box(-inf, inf, (5, 24), float32), 'finger_3_location': Box(-inf, inf, (5, 104), float32), 'finger_3_tactile': Box(-inf, inf, (5, 24), float32), 'obj_location': Box(-inf, inf, (5, 7, 6), float32), 'obj_velocity': Box(-inf, inf, (5, 7, 6), float32), 'palm_location': Box(-inf, inf, (5, 74), float32), 'palm_tactile': Box(-inf, inf, (5, 24), float32), 'state_attrib': Box(-inf, inf, (45,), float32))
[ControlDropPolicy]: loading feature encoder from checkpoint.


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...


Logging to /media/rpal/Drive_10TB/John/ControlDrop-RPAL/rl_logs/dynamix_control_drop/dynamix_control_drop_6


INFO:root:Step:[ 1.          1.         -2.         -1.          0.33333334 -0.6666667
 -0.33333334]
INFO:root:Preventing Movement.
INFO:root:Step:[-1.         -1.         -2.         -0.         -0.33333334 -0.6666667
 -0.        ]
INFO:root:Step:[ 1.          2.         -1.         -0.          0.6666667  -0.33333334
 -0.        ]
INFO:root:Step:[-1.         -2.         -1.         -2.         -0.6666667  -0.33333334
 -0.6666667 ]
INFO:root:Step:[-1.         -0.         -1.          0.         -0.         -0.33333334
  0.        ]
INFO:root:Threshold finger1:-1.068835973739624
INFO:root:Step:[-0.         -0.          1.          2.         -0.          0.33333334
  0.6666667 ]
INFO:root:Step:[0.         1.         1.         0.         0.33333334 0.33333334
 0.        ]
INFO:root:Step:[-0.          1.          0.          0.          0.33333334  0.
  0.        ]
INFO:root:Threshold finger1:-1.06447434425354
INFO:root:Step:[-1.          1.         -1.         -0.          0.33333334 -

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[ 1.          1.          0.         -1.          0.33333334  0.
 -0.33333334]
INFO:root:Threshold finger3:-1.0358686447143555
INFO:root:Step:[ 0.         -1.         -1.         -1.         -0.33333334 -0.33333334
 -0.33333334]
INFO:root:Step:[-1.         -2.          1.         -1.         -0.6666667   0.33333334
 -0.33333334]
INFO:root:Step:[-2.          1.         -1.         -1.          0.33333334 -0.33333334
 -0.33333334]
INFO:root:Preventing Movement.
INFO:root:Step:[ 3.          1.          1.         -0.          0.33333334  0.33333334
 -0.        ]
INFO:root:Step:[ 0.         -1.          0.          3.         -0.33333334  0.
  1.        ]
INFO:root:Step:[ 1.         -1.          1.         -0.         -0.33333334  0.33333334
 -0.        ]
INFO:root:Step:[ 0.          1.          0.         -0.          0.33333334  0.
 -0.   

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 3
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[ 1.          1.          2.         -0.          0.33333334  0.6666667
 -0.        ]
INFO:root:Threshold finger2:-1.520904779434204
INFO:root:Threshold finger3:-1.7151386737823486
INFO:root:Step:[-0. -0.  0. -0. -0.  0. -0.]
INFO:root:Threshold finger2:-1.5835802555084229
INFO:root:Threshold finger3:-1.7763220071792603
INFO:root:Step:[ 1.         -1.          1.         -0.         -0.33333334  0.33333334
 -0.        ]
INFO:root:Threshold finger2:-1.8018782138824463
INFO:root:Threshold finger3:-2.072021722793579
INFO:root:Step:[ 2.         -0.          1.         -0.         -0.          0.33333334
 -0.        ]
INFO:root:Threshold finger2:-1.7493549585342407
INFO:root:Threshold finger3:-2.119250535964966
INFO:root:Step:[-1.          1.          1.         -0.          0.33333334  0.33333334
 -0.        ]
INFO:root:Threshold finger2:-2.

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 28.7     |
|    ep_rew_mean     | -11.4    |
| time/              |          |
|    fps             | 0        |
|    iterations      | 1        |
|    time_elapsed    | 156      |
|    total_timesteps | 128      |
---------------------------------


INFO:root:Step:[ 0.          0.          1.         -0.          0.          0.33333334
 -0.        ]
INFO:root:Hand Stablazation...
INFO:root:Step:[ 1.         -0.          1.         -1.         -0.          0.33333334
 -0.33333334]
INFO:root:Hand Stablazation...
INFO:root:Step:[-0.  0.  0. -0.  0.  0. -0.]
INFO:root:Hand Stablazation...
INFO:root:Step:[ 0.          1.         -1.         -0.          0.33333334 -0.33333334
 -0.        ]
INFO:root:Failed at Dropping Single Object
INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[-1.          3.          1.         -0.          1.          0.33333334
 -0.        ]
INFO:root:Step:[ 0. -0.  0. -0. -0.  0. -0.]
INFO:root:Step:[ 0.          2.         -1.          0.          0.6666667  -0.33333334
  0.        ]
INFO:root:Step:[0.        0.        0.        2.        0.        0.        0.6666667]
INFO:root:Step:[ 0.          1.         -

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[ 1.          1.         -2.         -0.          0.33333334 -0.6666667
 -0.        ]
INFO:root:Threshold finger3:-1.1498818397521973
INFO:root:Step:[-0.        -0.         2.         0.        -0.         0.6666667
  0.       ]
INFO:root:Threshold finger3:-1.1498818397521973
INFO:root:Step:[0.        0.        2.        2.        0.        0.6666667 0.6666667]
INFO:root:Threshold finger3:-1.1340696811676025
INFO:root:Step:[-1.         -0.         -1.         -0.         -0.         -0.33333334
 -0.        ]
INFO:root:Threshold finger3:-1.2383099794387817
INFO:root:Step:[-0.        -0.         0.         2.        -0.         0.
  0.6666667]
INFO:root:Threshold finger3:-1.5063536167144775
INFO:root:Step:[-0.         -0.         -1.          0.         -0.         -0.33333334
  0.        ]
INFO:root:Threshold finger3:-1.4138402938842773
I

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 30.2         |
|    ep_rew_mean          | -12          |
| time/                   |              |
|    fps                  | 0            |
|    iterations           | 2            |
|    time_elapsed         | 346          |
|    total_timesteps      | 256          |
| train/                  |              |
|    approx_kl            | 5.874317e-06 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -7.09        |
|    explained_variance   | 0.000252     |
|    learning_rate        | 0.0003       |
|    loss                 | 20.5         |
|    n_updates            | 10           |
|    policy_gradient_loss | 0.000215     |
|    std                  | 1            |
|    value_loss           | 50.5         |
------------------------------------------


INFO:root:Step:[-2.          2.         -1.         -2.          0.6666667  -0.33333334
 -0.6666667 ]
INFO:root:Threshold finger3:-1.0423671007156372
INFO:root:Drop Detected...
INFO:root:Step:[ 1.         -1.          0.          0.         -0.33333334  0.
  0.        ]
INFO:root:Threshold finger3:-1.0423671007156372
INFO:root:Drop Detected...
INFO:root:Step:[-0.  0.  0.  0.  0.  0.  0.]
INFO:root:Drop Detected...
INFO:root:Step:[-1.         -1.          0.         -0.         -0.33333334  0.
 -0.        ]
INFO:root:Drop Detected...
INFO:root:Step:[-0.          2.          2.         -1.          0.6666667   0.6666667
 -0.33333334]
INFO:root:Drop Detected...
INFO:root:Step:[ 2.         -0.          2.         -1.         -0.          0.6666667
 -0.33333334]
INFO:root:Drop Detected...
INFO:root:Step:[-1.          1.         -1.          1.          0.33333334 -0.33333334
  0.33333334]
INFO:root:Drop Detected...
INFO:root:Step:[-0.          0.         -0.         -1.          0.         

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[-1.          1.          0.         -1.          0.33333334  0.
 -0.33333334]
INFO:root:Preventing Movement.
INFO:root:Threshold finger3:-1.4555232524871826
INFO:root:Step:[-1.          1.         -2.         -1.          0.33333334 -0.6666667
 -0.33333334]
INFO:root:Preventing Movement.
INFO:root:Threshold finger3:-1.4555232524871826
INFO:root:Step:[-1.          1.          1.         -1.          0.33333334  0.33333334
 -0.33333334]
INFO:root:Preventing Movement.
INFO:root:Step:[ 1.          1.         -1.         -2.          0.33333334 -0.33333334
 -0.6666667 ]
INFO:root:Step:[ 0.         -1.         -1.          0.         -0.33333334 -0.33333334
  0.        ]
INFO:root:Step:[ 1.         -1.         -0.         -0.         -0.33333334 -0.
 -0.        ]
INFO:root:Step:[ 1.          1.          0.         -1.          0.33333334  0.


Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[ 2.         -1.         -0.         -2.         -0.33333334 -0.
 -0.6666667 ]
INFO:root:Step:[-1.        -0.        -2.         0.        -0.        -0.6666667
  0.       ]
INFO:root:Step:[-1.          1.          1.         -0.          0.33333334  0.33333334
 -0.        ]
INFO:root:Step:[ 0.          1.          3.         -0.          0.33333334  1.
 -0.        ]
INFO:root:Step:[-0.         -0.         -1.         -0.         -0.         -0.33333334
 -0.        ]
INFO:root:Step:[-1.         -0.          1.          0.         -0.          0.33333334
  0.        ]
INFO:root:Preventing Movement.
INFO:root:Step:[ 0.          0.         -1.          1.          0.         -0.33333334
  0.33333334]
INFO:root:Step:[ 1.          1.          1.         -2.          0.33333334  0.33333334
 -0.6666667 ]
INFO:root:Step:[-2.         -0.         

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[ 1.          1.          1.         -1.          0.33333334  0.33333334
 -0.33333334]
INFO:root:Threshold finger1:-1.0033857822418213
INFO:root:Step:[ 3.         -0.          1.         -1.         -0.          0.33333334
 -0.33333334]
INFO:root:Threshold finger1:-1.0033857822418213
INFO:root:Step:[ 1.          1.         -1.          2.          0.33333334 -0.33333334
  0.6666667 ]
INFO:root:Threshold finger1:-1.0202703475952148
INFO:root:Step:[1.         0.         1.         2.         0.         0.33333334
 0.6666667 ]
INFO:root:Threshold finger1:-1.0801870822906494
INFO:root:Step:[-1.          1.         -0.         -2.          0.33333334 -0.
 -0.6666667 ]
INFO:root:Threshold finger1:-1.2593870162963867
INFO:root:Step:[ 1.         0.        -0.        -2.         0.        -0.
 -0.6666667]
INFO:root:Threshold finger1:-1.2337186336

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 3
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[-1.          0.          1.         -0.          0.          0.33333334
 -0.        ]
INFO:root:Preventing Movement.
INFO:root:Threshold finger1:-1.601845383644104
INFO:root:Threshold finger3:-1.3555316925048828
INFO:root:Step:[ 1.          0.         -1.          1.          0.         -0.33333334
  0.33333334]
INFO:root:Threshold finger1:-1.601845383644104
INFO:root:Threshold finger3:-1.3555316925048828
INFO:root:Step:[-0.         -1.         -1.          0.         -0.33333334 -0.33333334
  0.        ]
INFO:root:Threshold finger1:-1.8103870153427124
INFO:root:Threshold finger3:-1.515458106994629
INFO:root:Step:[-1.  0. -0. -0.  0. -0. -0.]
INFO:root:Preventing Movement.
INFO:root:Threshold finger1:-1.9540904760360718
INFO:root:Threshold finger3:-1.6474558115005493
INFO:root:Step:[-1.         -1.         -0.         -0.         -0.333

-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 28.7          |
|    ep_rew_mean          | -7.72         |
| time/                   |               |
|    fps                  | 0             |
|    iterations           | 3             |
|    time_elapsed         | 538           |
|    total_timesteps      | 384           |
| train/                  |               |
|    approx_kl            | 5.6871213e-06 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -7.09         |
|    explained_variance   | 0.000565      |
|    learning_rate        | 0.0003        |
|    loss                 | 27.6          |
|    n_updates            | 20            |
|    policy_gradient_loss | 0.000126      |
|    std                  | 1             |
|    value_loss           | 70.9          |
-------------------------------------------


INFO:root:Step:[-0.          2.         -0.         -1.          0.6666667  -0.
 -0.33333334]
INFO:root:Threshold finger3:-1.0516772270202637
INFO:root:Step:[-3.         -1.         -2.         -1.         -0.33333334 -0.6666667
 -0.33333334]
INFO:root:Preventing Movement.
INFO:root:Threshold finger3:-1.0516772270202637
INFO:root:Step:[-1.          0.          1.         -1.          0.          0.33333334
 -0.33333334]
INFO:root:Threshold finger3:-1.0060609579086304
INFO:root:Step:[-1.         2.        -3.         0.         0.6666667 -1.
  0.       ]
INFO:root:Preventing Movement.
INFO:root:Step:[-1.        -2.        -0.        -2.        -0.6666667 -0.
 -0.6666667]
INFO:root:Preventing Movement.
INFO:root:Step:[ 0.          0.         -0.          1.          0.         -0.
  0.33333334]
INFO:root:Step:[ 2.         -0.         -0.         -1.         -0.         -0.
 -0.33333334]
INFO:root:Step:[ 1.          1.          0.         -1.          0.33333334  0.
 -0.33333334]
INFO:roo

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[-1.         -2.          1.          0.         -0.6666667   0.33333334
  0.        ]
INFO:root:Threshold finger2:-1.3731021881103516
INFO:root:Threshold finger3:-1.9744501113891602
INFO:root:Step:[-2.          1.          0.         -1.          0.33333334  0.
 -0.33333334]
INFO:root:Threshold finger2:-1.3731021881103516
INFO:root:Threshold finger3:-1.9744501113891602
INFO:root:Step:[-1.          0.         -0.         -1.          0.         -0.
 -0.33333334]
INFO:root:Threshold finger2:-1.4291726350784302
INFO:root:Threshold finger3:-1.909266471862793
INFO:root:Step:[-1.        -2.         0.         0.        -0.6666667  0.
  0.       ]
INFO:root:Threshold finger2:-1.575321912765503
INFO:root:Threshold finger3:-1.6828701496124268
INFO:root:Step:[-1.        -0.         2.        -0.        -0.         0.6666667
 -0.       ]
INFO:root

-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 27.6          |
|    ep_rew_mean          | -6.28         |
| time/                   |               |
|    fps                  | 0             |
|    iterations           | 4             |
|    time_elapsed         | 715           |
|    total_timesteps      | 512           |
| train/                  |               |
|    approx_kl            | 3.3369288e-06 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -7.09         |
|    explained_variance   | 0.000295      |
|    learning_rate        | 0.0003        |
|    loss                 | 1.51          |
|    n_updates            | 30            |
|    policy_gradient_loss | 8.04e-05      |
|    std                  | 1             |
|    value_loss           | 3.48          |
-------------------------------------------


INFO:root:Step:[0.         2.         1.         0.         0.6666667  0.33333334
 0.        ]
INFO:root:Drop Detected...
INFO:root:Step:[ 1.         -1.          0.          1.         -0.33333334  0.
  0.33333334]
INFO:root:Drop Detected...
INFO:root:Step:[-1.         -1.          1.          1.         -0.33333334  0.33333334
  0.33333334]
INFO:root:Drop Detected...
INFO:root:Step:[-0.         -0.         -1.         -1.         -0.         -0.33333334
 -0.33333334]
INFO:root:Threshold finger2:-1.0880597829818726
INFO:root:Drop Detected...
INFO:root:Step:[ 2.         -0.         -0.          1.         -0.         -0.
  0.33333334]
INFO:root:Threshold finger2:-1.064192533493042
INFO:root:Drop Detected...
INFO:root:Step:[ 0.          2.          1.         -1.          0.6666667   0.33333334
 -0.33333334]
INFO:root:Threshold finger2:-1.064294695854187
INFO:root:Drop Detected...
INFO:root:Step:[-2.          1.          0.         -1.          0.33333334  0.
 -0.33333334]
INFO:root:Thr

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[-1.          0.         -1.          1.          0.         -0.33333334
  0.33333334]
INFO:root:Step:[-0.         -1.          0.         -0.         -0.33333334  0.
 -0.        ]
INFO:root:Step:[-1.         -0.          1.          0.         -0.          0.33333334
  0.        ]
INFO:root:Step:[ 0.         0.         0.        -2.         0.         0.
 -0.6666667]
INFO:root:Step:[-1.         -2.         -1.          1.         -0.6666667  -0.33333334
  0.33333334]
INFO:root:Step:[ 0.          1.         -1.          1.          0.33333334 -0.33333334
  0.33333334]
INFO:root:Step:[ 1.          0.         -1.          2.          0.         -0.33333334
  0.6666667 ]
INFO:root:Step:[-1.          0.          1.          1.          0.          0.33333334
  0.33333334]
INFO:root:Step:[-1.          1.         -1.          0.          0.333

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[-2.         -0.         -0.         -1.         -0.         -0.
 -0.33333334]
INFO:root:Preventing Movement.
INFO:root:Threshold finger1:-1.0761799812316895
INFO:root:Step:[ 1.         -1.          0.         -2.         -0.33333334  0.
 -0.6666667 ]
INFO:root:Threshold finger1:-1.0761799812316895
INFO:root:Step:[ 0. -0. -0.  0. -0. -0.  0.]
INFO:root:Threshold finger1:-1.067996621131897
INFO:root:Step:[-1.         -1.         -0.         -1.         -0.33333334 -0.
 -0.33333334]
INFO:root:Threshold finger1:-1.3978166580200195
INFO:root:Step:[-0.         -0.         -0.          1.         -0.         -0.
  0.33333334]
INFO:root:Threshold finger1:-1.5154750347137451
INFO:root:Step:[-1.          2.          0.         -1.          0.6666667   0.
 -0.33333334]
INFO:root:Preventing Movement.
INFO:root:Threshold finger1:-1.2566012144088745


Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 3
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[ 1. -0. -0. -0. -0. -0. -0.]
INFO:root:Threshold finger3:-1.0402133464813232
INFO:root:Step:[ 0.          1.          1.         -0.          0.33333334  0.33333334
 -0.        ]
INFO:root:Threshold finger3:-1.0402133464813232
INFO:root:Step:[-0.         -1.          0.         -2.         -0.33333334  0.
 -0.6666667 ]
INFO:root:Threshold finger3:-1.0463712215423584
INFO:root:Step:[ 0.          1.          1.         -0.          0.33333334  0.33333334
 -0.        ]
INFO:root:Step:[-0.          1.         -1.          0.          0.33333334 -0.33333334
  0.        ]
INFO:root:Step:[-1.         -2.          2.          1.         -0.6666667   0.6666667
  0.33333334]
INFO:root:Step:[ 0.          1.          1.         -1.          0.33333334  0.33333334
 -0.33333334]
INFO:root:Step:[-0.         -0.         -1.          1.         -0.     

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[-1.         -2.          1.         -0.         -0.6666667   0.33333334
 -0.        ]
INFO:root:Preventing Movement.
INFO:root:Step:[ 2.          0.          1.         -1.          0.          0.33333334
 -0.33333334]
INFO:root:Step:[ 1.         -1.          1.          0.         -0.33333334  0.33333334
  0.        ]
INFO:root:Step:[ 1.          2.         -0.          1.          0.6666667  -0.
  0.33333334]
INFO:root:Step:[ 1.         -1.          1.          0.         -0.33333334  0.33333334
  0.        ]
INFO:root:Step:[-1.          1.         -3.          0.          0.33333334 -1.
  0.        ]
INFO:root:Step:[-0. -0. -0.  0. -0. -0.  0.]
INFO:root:Step:[ 1.          1.         -0.          1.          0.33333334 -0.
  0.33333334]
INFO:root:Step:[ 1.         -0.          1.         -1.         -0.          0.33333334
 -0.333333

-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 27.2          |
|    ep_rew_mean          | -7.55         |
| time/                   |               |
|    fps                  | 0             |
|    iterations           | 5             |
|    time_elapsed         | 926           |
|    total_timesteps      | 640           |
| train/                  |               |
|    approx_kl            | 3.5739504e-06 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -7.09         |
|    explained_variance   | 0.00164       |
|    learning_rate        | 0.0003        |
|    loss                 | 13.4          |
|    n_updates            | 40            |
|    policy_gradient_loss | -4.11e-05     |
|    std                  | 1             |
|    value_loss           | 24.3          |
-------------------------------------------


INFO:root:Step:[ 1.         -1.         -0.          1.         -0.33333334 -0.
  0.33333334]
INFO:root:Failed at Dropping Single Object
INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[ 0.          1.          0.         -1.          0.33333334  0.
 -0.33333334]
INFO:root:Step:[ 0.         -1.         -0.         -0.         -0.33333334 -0.
 -0.        ]
INFO:root:Step:[-2.          1.          2.          0.          0.33333334  0.6666667
  0.        ]
INFO:root:Preventing Movement.
INFO:root:Step:[-0.          1.         -1.         -0.          0.33333334 -0.33333334
 -0.        ]
INFO:root:Step:[ 1.         -0.          0.          1.         -0.          0.
  0.33333334]
INFO:root:Step:[1.        0.        0.        2.        0.        0.        0.6666667]
INFO:root:Step:[-0.        -0.        -2.        -0.        -0.        -0.6666667
 -0.       ]
INFO:root:Threshold finger2:-

Checking Sim...


INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 1
INFO:root:Resetting object locations...
INFO:root:object count for initialized scene: 2
INFO:root:Scene Validated. Starting Dropping...
INFO:root:Step:[ 0.         2.        -2.        -0.         0.6666667 -0.6666667
 -0.       ]
INFO:root:Step:[-2.          2.         -1.         -2.          0.6666667  -0.33333334
 -0.6666667 ]
INFO:root:Preventing Movement.
INFO:root:Step:[ 1.          0.         -1.          2.          0.         -0.33333334
  0.6666667 ]
INFO:root:Step:[ 0. -0. -0. -0. -0. -0. -0.]
INFO:root:Step:[ 1.         -1.         -1.         -0.         -0.33333334 -0.33333334
 -0.        ]
INFO:root:Step:[-0.          0.          1.         -1.          0.          0.33333334
 -0.33333334]
INFO:root:Step:[-1.          1.          1.         -1.          0.33333334  0.33333334
 -0.33333334]
INFO:root:Step:[ 0.          1.          2.         -0.          0.33333334  0.6666667
 -0.    
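The `rollout/`, `time/` and `train/` tables in the output above appear to be Stable-Baselines3's default stdout logger output for PPO. To compare metrics such as `ep_rew_mean` or `approx_kl` across iterations from a saved copy of this output, a small parser can help. This is a sketch that assumes the two-column `| key | value |` layout shown above. The function names are hypothetical. If the run can be repeated, SB3's own CSV logger (`stable_baselines3.common.logger.configure(folder, ["stdout", "csv"])` passed to `model.set_logger(...)`) is probably the sturdier option.

```python
from typing import Dict, List


def parse_sb3_table(table_text: str) -> Dict[str, float]:
    """Parse one printed `| key | value |` table into {"section/key": value}."""
    metrics: Dict[str, float] = {}
    section = ""
    for raw in table_text.splitlines():
        line = raw.strip()
        if not line.startswith("|"):
            continue  # skip the dashed border lines
        cells = [c.strip() for c in line.strip("|").split("|")]
        if len(cells) != 2:
            continue
        key, value = cells
        if key.endswith("/") and not value:
            section = key  # section header such as "rollout/" or "train/"
        elif value:
            try:
                metrics[section + key] = float(value)
            except ValueError:
                pass  # ignore non-numeric cells
    return metrics


def parse_all_tables(log_text: str) -> List[Dict[str, float]]:
    """Split raw cell output into consecutive tables (one per PPO iteration) and parse each."""
    tables: List[Dict[str, float]] = []
    current: List[str] = []
    for raw in log_text.splitlines():
        if raw.strip().startswith("|"):
            current.append(raw)
        elif current:
            # The first non-table line (e.g. the closing dashes) ends the current table
            tables.append(parse_sb3_table("\n".join(current)))
            current = []
    if current:
        tables.append(parse_sb3_table("\n".join(current)))
    return tables
```

For example, `[t.get("rollout/ep_rew_mean") for t in parse_all_tables(output_text)]` gives one mean episode reward per logged iteration. For iterations 3 to 5 above, that would be -7.72, -6.28 and -7.55. The partial table at the start of this excerpt has no `rollout/` section, so `.get` returns `None` for it.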

In [None]:
## Extract the object encoder's weights from a trained DynamixModel checkpoint

import os

import torch

from control_dropping_rpal.RL.Networks.ExtractorNetworks import (
    TemporalObjectTactileEncoder_Additive,
    DynamixModel,
)


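# Params: these must match the architecture the checkpoint was trained with.
# T_buffer and state_space are expected to be defined in an earlier cell.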
NUM_LAYERS_TRANSFORMER = 6
NUM_RESIDUALS = 3
EPOCHS = 1000
VEC_ENCODING_SIZE = 128

MODEL_ARGS = {
    "vec_encoding_size": VEC_ENCODING_SIZE,
    "num_residuals": NUM_RESIDUALS,
    "num_tsf_layer": NUM_LAYERS_TRANSFORMER,
    "use_mask": True,
    "dropout_prob": 0.01,
    "embed_dim_low": VEC_ENCODING_SIZE,
    "T_buffer": T_buffer,
    "state_space": state_space,
}


def extract_encoder_state_dict(checkpoint_path):
    # Load the full checkpoint onto the CPU so this also works without a GPU
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if "dynamix_model" not in checkpoint:
        raise KeyError("Expected 'dynamix_model' in checkpoint, but it was not found.")

    # Extract the DynamixModel state dict
    dynamix_state_dict = checkpoint["dynamix_model"]

    # Build a temporary DynamixModel with the same architecture to load the weights into.
    # Fall back to the CPU when CUDA is unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    temp_dynamix = DynamixModel(
        embed_dim_high=1024,
        embed_dim_low=MODEL_ARGS["embed_dim_low"],
        device=device,
        dropout_prob=MODEL_ARGS["dropout_prob"],
        num_tsf_layer=MODEL_ARGS["num_tsf_layer"],
        num_residual_blocks=MODEL_ARGS["num_residuals"],
        vec_encoding_size=MODEL_ARGS["vec_encoding_size"],
        use_mask=MODEL_ARGS["use_mask"],
    )

    # strict=False tolerates mismatched keys elsewhere in the model, but a missing
    # encoder key would silently leave randomly initialised weights, so check for it.
    load_result = temp_dynamix.load_state_dict(dynamix_state_dict, strict=False)
    missing_encoder_keys = [
        k for k in load_result.missing_keys if k.startswith("object_encoder.")
    ]
    if missing_encoder_keys:
        raise RuntimeError(f"Checkpoint is missing encoder weights: {missing_encoder_keys}")

    # Extract the encoder's state dict
    encoder_state_dict = temp_dynamix.object_encoder.state_dict()

    return encoder_state_dict


def save_encoder_state_dict(encoder_state_dict, save_path):
    torch.save(encoder_state_dict, save_path)


# Set the paths
checkpoint_path = "/media/rpal/Drive_10TB/John/ControlDrop-RPAL/logs/dynamix_critiq_training-9-01-2024_11/checkpoints/best_model_epoch_47.pth"

save_path = os.path.join(CONTROL_DROP_DIR, "extracted_encoder_state_dict.pth")

encoder_state_dict = extract_encoder_state_dict(checkpoint_path)
print("Successfully extracted encoder state dict.")

# Save the encoder's state dict
save_encoder_state_dict(encoder_state_dict, save_path)
print(f"Saved encoder state dict to {save_path}")
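As an alternative that avoids instantiating a full `DynamixModel`, the encoder weights can be sliced out of the saved state dict by key prefix. This is a minimal sketch. It assumes the encoder is registered on `DynamixModel` as the submodule `object_encoder`, which the `temp_dynamix.object_encoder` access above suggests, so its parameters and buffers are stored under `object_encoder.`-prefixed keys. The helper name `slice_submodule_state_dict` is hypothetical. The `module.` stripping only matters if the model was ever wrapped in `DataParallel`/`DistributedDataParallel`.

```python
from collections import OrderedDict
from typing import Mapping


def slice_submodule_state_dict(state_dict: Mapping, prefix: str = "object_encoder.") -> OrderedDict:
    """Return the entries whose keys start with `prefix`, with the prefix removed.

    Assumption: the submodule was registered under the attribute name implied by `prefix`.
    """
    sliced = OrderedDict()
    for key, value in state_dict.items():
        # (Distributed)DataParallel checkpoints prepend "module." to every key
        if key.startswith("module."):
            key = key[len("module."):]
        if key.startswith(prefix):
            sliced[key[len(prefix):]] = value
    if not sliced:
        raise KeyError(f"No keys with prefix {prefix!r} found in state dict.")
    return sliced
```

Applied to the checkpoint above, `slice_submodule_state_dict(checkpoint["dynamix_model"])` should yield the same keys as `temp_dynamix.object_encoder.state_dict()`. That makes a quick cross-check. Loading the result into a standalone encoder with the default `strict=True` then fails loudly, not silently, if the encoder architecture has changed.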