In [1]:
import numpy as np
import os
import time
import torch
import torch.nn as nn
import pickle
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
import submodules.data_filter as _df
import diffusion_pipline.data_processing as dproc
import diffusion_pipline.model as md
import submodules.cleaned_file_parser as cfp


In [2]:
checkpoint_path = '/home/cam/Documents/diffusion_policy_cam/no_sync/checkpoints/checkpoint_(SV)(SV)SS_epoch_199.pth'

checkpoint = torch.load(checkpoint_path)


In [3]:
checkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'ema_state_dict', 'len_dataloader', 'dataset_stats'])

In [4]:
# observation and action dimensions corrsponding to

num_epochs =200
obs_dim = 57
action_dim = 12
# parameters
pred_horizon = 16
obs_horizon = 2
action_horizon = 8
target_fps = 120.0

action_item = ['chisel', 'gripper']
obs_item = ['battery']

# create network object
noise_pred_net = md.ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# example inputs
noised_action = torch.randn((1, pred_horizon, action_dim))
obs = torch.zeros((1, obs_horizon, obs_dim))
diffusion_iter = torch.zeros((1,))

# the noise prediction network
# takes noisy action, diffusion iteration and observation as input
# predicts the noise added to action
noise = noise_pred_net(
    sample=noised_action,
    timestep=diffusion_iter,
    global_cond=obs.flatten(start_dim=1))

# illustration of removing noise
# the actual noise removal is performed by NoiseScheduler
# and is dependent on the diffusion noise schedule
denoised_action = noised_action - noise

# for this demo, we use DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choise of beta schedule has big impact on performance
    # we found squared cosine works the best
    beta_schedule='squaredcos_cap_v2',
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # our network predicts noise (instead of denoised action)
    prediction_type='epsilon'
)

# device transfer
device = torch.device('cuda')
_ = noise_pred_net.to(device)

# Exponential Moving Average
# accelerates training and improves stability
# holds a copy of the model weights
ema = EMAModel(
    parameters=noise_pred_net.parameters(),
    power=0.75)

# Standard ADAM optimizer
# Note that EMA parametesr are not optimized
optimizer = torch.optim.AdamW(
    params=noise_pred_net.parameters(),
    lr=1e-4, weight_decay=1e-6)

# Cosine LR schedule with linear warmup
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=200,
    num_training_steps=checkpoint['len_dataloader'] * num_epochs
)

ema_noise_pred_net = noise_pred_net

number of parameters: 6.686209e+07


In [5]:
noise_pred_net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
ema.load_state_dict(checkpoint['ema_state_dict'])
start_epoch = checkpoint['epoch'] + 1

In [6]:
# create dataset from file
# path_name = "/home/cam/Downloads/Supporting Data - Sheet1.csv"
base_path = "no_sync/data_chisel_task/2-cleaned_interpolation_with_offset/offset_interpolated_test_traj/"

# Load data
dict_of_df_rigid_test = {}
dict_of_df_rigid_velocity_test = {}
dict_of_df_marker_test = {}
name = []

for file in os.listdir(base_path):
    if file.endswith(".csv"):
        name.append(file)
        path_name = base_path + file
        data_test = cfp.DataParser.from_euler_file(file_path = path_name, target_fps=target_fps, filter=True, window_size=15, polyorder=3)

        marker_data = data_test.get_marker_Txyz()
        data_time = data_test.get_time().astype(float)
        data_state_dict = data_test.get_rigid_TxyzRxyz()

        # use the time and state data to get the velocity data
        data_velocity_state_dict = {}
        data_velocity_dict = {}
        for key in data_state_dict.keys():
            if key != 'battery':
                data_velocity = []
                data_velocity_state = []
                for i in range(0, len(data_time) -1):
                    veloctiy_val = (data_state_dict[key][i + 1] - data_state_dict[key][i]) / (data_time[i + 1] - data_time[i])
                    data_velocity.append(veloctiy_val)
                    data_velocity_state.append(np.concatenate((data_state_dict[key][i], veloctiy_val), axis=0).tolist())
                velocity_state_data = pd.DataFrame(data_velocity_state, columns= [f'{key}_X', f'{key}_Y', f'{key}_Z', f'{key}_x', f'{key}_y', f'{key}_z', f'{key}_Xv', f'{key}_Yv', f'{key}_Zv', f'{key}_xv', f'{key}_yv', f'{key}_zv'])
                filtered_velocity_state = _df.apply_savgol_filter(velocity_state_data, window_size = 15, polyorder = 3, time_frame= False)
                data_velocity_state_dict[key] = filtered_velocity_state.values
                velocity_data = pd.DataFrame(data_velocity, columns= [f'{key}_Xv', f'{key}_Yv', f'{key}_Zv', f'{key}_xv', f'{key}_yv', f'{key}_zv'])
                filtered_velocity = _df.apply_savgol_filter(velocity_data, window_size = 15, polyorder = 3, time_frame= False)
                data_velocity_dict[key] = filtered_velocity.values
            else:
                data_velocity_state_dict[key] = data_state_dict[key]


        dicts = [data_velocity_state_dict, data_velocity_dict, marker_data]
        trimmed_dicts = _df.trim_lists_in_dicts(dicts)

        
        dict_of_df_rigid_test[file] = trimmed_dicts[0]
        dict_of_df_rigid_velocity_test[file] = trimmed_dicts[1]
        dict_of_df_marker_test[file] = trimmed_dicts[2]


item_name_test = data_test.rigid_bodies
marker_name_test = data_test.markers

if len(dict_of_df_rigid_test) == len(dict_of_df_marker_test) == len(dict_of_df_rigid_velocity_test):

    rigiddataset_test, index_test = _df.episode_combiner(dict_of_df_rigid_test, item_name_test)
    velocitydataset_test, _ = _df.episode_combiner(dict_of_df_rigid_velocity_test, action_item)
    markerdataset_test, _ = _df.episode_combiner(dict_of_df_marker_test, marker_name_test)
    print(index_test[action_item[0]])


indexes = index_test[action_item[0]]
action = []
obs = []
for i in range(indexes[-1]):
    a = np.concatenate([velocitydataset_test[item][i] for item in action_item])
    b = np.concatenate([rigiddataset_test[item][i] for item in action_item] 
                       + [rigiddataset_test[item][i] for item in obs_item] 
                       + [markerdataset_test[item][i] for item in marker_name_test])

    action.append(a)
    obs.append(b)
    
# All demonstration episodes are concatinated in the first dimension N
action = np.array(action, dtype=np.float64)
obs = np.array(obs, dtype=np.float64)

# Initialize lists to store segmented data
splits_obs = []
splits_action = []
previous_index = 0

# Iterate through episode_ends and slice action and obs accordingly
for index in indexes:
    splits_obs.append(obs[previous_index:index])  # Include index itself in the slice
    splits_action.append(action[previous_index:index])
    previous_index = index + 1

[844, 1678, 2671, 3526, 4328, 5589, 6376, 7315, 8390, 9282, 10313, 11342, 12405, 13297, 14214, 15464, 16524, 17466, 18442, 19539, 20473, 21319, 22118, 23050]


In [7]:
import collections

trajectories = {}
losses_per_traj = {}
for j in range(len(indexes)):
    # get first observation
    com_obs = splits_obs[j]
    obs = splits_obs[j][0]
    actions_test = splits_action[j]
    # max_steps = len(test_data['action'])
    max_steps = len(actions_test)
    stats = checkpoint['dataset_stats']
    # keep a queue of last 2 steps of observations
    obs_deque = collections.deque(
        [obs] * obs_horizon, maxlen=obs_horizon)

    # save visualization and rewards
    done = False
    step_idx = 0
    traj = []
    loss_com = []
    with tqdm(total=max_steps, desc="Eval") as pbar:
        while not done:
            B = 1
            # stack the last obs_horizon (2) number of observations
            obs_seq = np.stack(obs_deque)
            # normalize observation
            nobs = dproc.normalize_data(obs_seq, stats=stats['obs'])
            # device transfer
            nobs = torch.from_numpy(nobs).to(device, dtype=torch.float32)
            # infer action
            with torch.no_grad():
                # reshape observation to (B,obs_horizon*obs_dim)
                obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)
                # print(obs_cond.shape)

                # initialize action from Guassian noise
                noisy_action = torch.randn(
                    (B, pred_horizon, action_dim), device=device)
                naction = noisy_action

                # init scheduler
                noise_scheduler.set_timesteps(num_diffusion_iters)

                for k in noise_scheduler.timesteps:
                    # predict noise
                    noise_pred = ema_noise_pred_net(
                        sample=naction,
                        timestep=k,
                        global_cond=obs_cond
                    )

                    # inverse diffusion step (remove noise)
                    naction = noise_scheduler.step(
                        model_output=noise_pred,
                        timestep=k,
                        sample=naction
                    ).prev_sample

            # unnormalize action
            naction = naction.detach().to('cpu').numpy()
            # (B, pred_horizon, action_dim)
            naction = naction[0]
            action_pred = dproc.unnormalize_data(naction, stats=stats['action'])

            # only take action_horizon number of actions
            start = obs_horizon - 1
            end = start + action_horizon
            action = action_pred[start:end,:]
            traj.extend(action)
            losses = []
            pos_item1 = obs_deque[-1][:6]
            pos_item2 = obs_deque[-1][12:18]
            time_step = 1/120.0
                
            for i in range(len(action)):
                if len(action) > len(actions_test):
                    done = True
                if done:
                    break
                loss_test = nn.functional.mse_loss(torch.tensor(action[i]), torch.tensor(actions_test[i]))
                
                action_vel_item1 = action[i][:6]
                action_vel_item2 = action[i][6:12]
                # print("Action_last ---",action_last)
                new_pos_item1 = pos_item1 + (action_vel_item1*time_step)
                new_pos_item2 = pos_item2 + (action_vel_item2*time_step)
                com_obs_part = com_obs[i][24:]
                # Concatenating prediction to the obs lists
                com_obs[i] = np.concatenate([new_pos_item1 , action_vel_item1 , new_pos_item2 , action_vel_item2 , com_obs_part]).tolist()
                obs_deque.append(com_obs[i])
                losses.append(loss_test.item())
                # update progress bar
                step_idx += 1
                pbar.update(1)
                pbar.set_postfix(loss=np.mean(losses))
                # print(i)
                if step_idx > max_steps:
                    done = True
                if done:
                    break
            com_obs = com_obs[len(action):]
            actions_test = actions_test[len(action):]
            loss_com.append(np.mean(losses).tolist())
    losses_per_traj[f"{name[j]}"] = np.nanmean(loss_com)
    trajectories[f"{name[j]}"] = traj

Eval: 100%|█████████▉| 840/844 [00:55<00:00, 15.08it/s, loss=0.688]
Eval: 100%|█████████▉| 832/833 [00:54<00:00, 15.17it/s, loss=2.65] 
Eval: 100%|██████████| 992/992 [01:05<00:00, 15.21it/s, loss=2.36] 
Eval:  99%|█████████▉| 848/854 [00:56<00:00, 15.01it/s, loss=0.462]
Eval: 100%|█████████▉| 800/801 [00:53<00:00, 15.09it/s, loss=0.804] 
Eval: 100%|█████████▉| 1256/1260 [01:22<00:00, 15.26it/s, loss=0.241] 
Eval: 100%|█████████▉| 784/786 [00:51<00:00, 15.26it/s, loss=0.259] 
Eval: 100%|█████████▉| 936/938 [01:01<00:00, 15.27it/s, loss=0.754]
Eval: 100%|█████████▉| 1072/1074 [01:10<00:00, 15.28it/s, loss=0.304]
Eval: 100%|█████████▉| 888/891 [00:58<00:00, 15.23it/s, loss=0.368] 
Eval:  99%|█████████▉| 1024/1030 [01:07<00:00, 15.18it/s, loss=0.263] 
Eval: 100%|█████████▉| 1024/1028 [01:07<00:00, 15.23it/s, loss=0.744]
Eval:  99%|█████████▉| 1056/1062 [01:09<00:00, 15.29it/s, loss=0.345] 
Eval: 100%|█████████▉| 888/891 [00:58<00:00, 15.21it/s, loss=0.538]
Eval: 100%|█████████▉| 912/916 [