In [1]:
import copy
import os
import pickle
import time
import warnings

warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from diffusers.optimization import get_scheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from tqdm.auto import tqdm

import diffusion_pipline.data_processing as dproc
import diffusion_pipline.model as md
import submodules.cleaned_file_parser as cfp
import submodules.data_filter as _df


In [22]:
# Checkpoint written by an earlier training run; the cells below read their
# hyper-parameters and component states from it.
checkpoint_path = '/home1/shuklar/diff_files/checkpoints/checkpoint__3BODY_NoNAN_12_markers__epoch_119.pth'

# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(checkpoint_path)


In [23]:
# Display which entries the loaded checkpoint dictionary provides.
checkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'ema_state_dict', 'len_dataloader', 'dataset_stats', 'num_epochs', 'obs_dim', 'action_dim', 'pred_horizon', 'obs_horizon', 'action_horizon', 'target_fps', 'action_item', 'obs_item', 'marker_item', 'num_diffusion_iters'])

In [24]:
# Hyper-parameters are read back from the checkpoint so the rebuilt network,
# scheduler and optimizer match the run being resumed.
num_epochs = checkpoint['num_epochs']
obs_dim = checkpoint['obs_dim']
action_dim = checkpoint['action_dim']
# horizon / sampling parameters
pred_horizon = checkpoint['pred_horizon']
obs_horizon = checkpoint['obs_horizon']
action_horizon = checkpoint['action_horizon']
target_fps = checkpoint['target_fps']

# Run tag embedded in the checkpoint file names written by the training cell.
# NOTE(review): this name shadows the builtin `type`; it is kept because the
# training cell reads it, so avoid calling type(...) later in this notebook.
type = '3BODY_NoNAN_12_markers'

action_item = checkpoint['action_item']
obs_item = checkpoint['obs_item']


# create network object
noise_pred_net = md.ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# example inputs
noised_action = torch.randn((1, pred_horizon, action_dim))
obs = torch.zeros((1, obs_horizon, obs_dim))
diffusion_iter = torch.zeros((1,))

# Smoke test of the noise prediction network: it takes the noisy action, the
# diffusion iteration and the flattened observation, and predicts the noise
# added to the action. No gradients are needed for this illustration, so the
# forward pass runs under no_grad to avoid building an autograd graph.
with torch.no_grad():
    noise = noise_pred_net(
        sample=noised_action,
        timestep=diffusion_iter,
        global_cond=obs.flatten(start_dim=1))

    # illustration of removing noise
    # the actual noise removal is performed by the noise scheduler
    # and is dependent on the diffusion noise schedule
    denoised_action = noised_action - noise

# DDPM scheduler; the number of diffusion iterations comes from the checkpoint
num_diffusion_iters = checkpoint['num_diffusion_iters']
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choice of beta schedule has a big impact on performance
    # we found squared cosine works the best
    beta_schedule='squaredcos_cap_v2',
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # our network predicts noise (instead of denoised action)
    prediction_type='epsilon'
)

# device transfer: prefer the GPU, fall back to the CPU when CUDA is absent
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_ = noise_pred_net.to(device)

# Exponential Moving Average
# accelerates training and improves stability
# holds a copy of the model weights
# (created after the device transfer so it is built from the moved parameters)
ema = EMAModel(
    parameters=noise_pred_net.parameters(),
    power=0.75)

# AdamW optimizer
# Note that EMA parameters are not optimized
optimizer = torch.optim.AdamW(
    params=noise_pred_net.parameters(),
    lr=1e-4, weight_decay=1e-6)

# Cosine LR schedule with linear warmup; the total step count is rebuilt from
# the dataloader length stored in the checkpoint
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=200,
    num_training_steps=checkpoint['len_dataloader'] * num_epochs
)


# Alias of the training network (same object, no EMA weights applied here).
ema_noise_pred_net = noise_pred_net

number of parameters: 6.678683e+07


In [25]:
# Restore every stateful training component from its checkpoint entry.
for component, state_key in (
    (noise_pred_net, 'model_state_dict'),
    (optimizer, 'optimizer_state_dict'),
    (lr_scheduler, 'scheduler_state_dict'),
    (ema, 'ema_state_dict'),
):
    component.load_state_dict(checkpoint[state_key])

# Training continues with the epoch after the one recorded in the checkpoint.
start_epoch = 1 + checkpoint['epoch']

In [26]:
# Build the training dataset from every CSV trajectory in base_path.
base_path = "/home1/shuklar/diff_files/trun_table_task/train_traj/"

# Per-file rigid-body and marker data, keyed by file name
dict_of_df_rigid = {}
dict_of_df_marker = {}

# os.listdir returns entries in arbitrary order; sort them so the episode
# order (and therefore `index`) is the same on every run.
csv_files = sorted(name for name in os.listdir(base_path) if name.endswith(".csv"))
if not csv_files:
    # Without this guard the code below fails with a NameError on `data`.
    raise FileNotFoundError(f"No .csv trajectory files found in {base_path}")

for file in csv_files:
    path_name = os.path.join(base_path, file)
    data = cfp.DataParser.from_quat_file(file_path = path_name, target_fps=target_fps, filter=False, window_size=15, polyorder=3)
    marker_data = data.get_marker_Txyz()
    data_state_dict = data.get_rigid_TxyzRxyz()

    # Trim the rigid-body and marker dicts together
    # (presumably to a common length — confirm in submodules.data_filter).
    dicts = [data_state_dict, marker_data]
    trimmed_dicts = _df.trim_lists_in_dicts(dicts)

    dict_of_df_rigid[file] = trimmed_dicts[0]
    dict_of_df_marker[file] = trimmed_dicts[1]

# Rigid-body and marker names are taken from the last parsed file.
item_name = data.rigid_bodies
marker_name = data.markers

if len(dict_of_df_rigid) != len(dict_of_df_marker):
    # Fail loudly instead of falling through to a NameError on `rigiddataset`.
    raise ValueError(
        f"Rigid-body episodes ({len(dict_of_df_rigid)}) and marker episodes "
        f"({len(dict_of_df_marker)}) do not match"
    )

rigiddataset, index = _df.episode_combiner(dict_of_df_rigid, item_name)
markerdataset, _ = _df.episode_combiner(dict_of_df_marker, marker_name)
print(index[action_item[0]])


#### If you don't want battery info, set obs_item = None. Clear all outputs and
#### restart the kernel before doing so, then run the notebook from the top.


dataset = dproc.TaskStateDataset(Rigiddataset=rigiddataset, Velocitydataset = None, Markerdataset= markerdataset, index= index[action_item[0]],
                                 action_item = action_item, obs_item = obs_item,
                                 marker_item= marker_name,
                                 pred_horizon=pred_horizon,
                                 obs_horizon=obs_horizon,
                                 action_horizon=action_horizon)

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    num_workers=1,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill the worker process after each epoch
    persistent_workers=True
)

# Pull one batch to check the tensor shapes the training loop will see.
batch = next(iter(dataloader))
print("batch['obs'].shape:", batch['obs'].shape)
print("batch['action'].shape", batch['action'].shape)

[3095, 5063, 8419, 10635, 14427, 17353, 20118, 22869, 25896, 28732, 31432, 34270, 37238, 39821, 42379, 44891, 47448, 50579, 53660, 56306, 59427, 62150, 64198, 67155, 69420, 71996, 74912, 77980, 81110, 84353, 86834, 89743, 93045, 96123, 97798, 100753, 103670, 106565, 109619, 112631, 115598, 117237, 119112, 121914, 124865, 127158, 129965, 132497, 135419, 137464, 139599, 143361, 146335, 149649, 152159, 155234, 157761, 160944, 163861, 166752, 169925, 173103, 174948, 178568, 181260, 184644, 187927, 190434, 193211, 196311, 199810, 203219, 206286, 208321, 211388, 213798, 216446, 219630, 222135, 224802, 227607, 230762, 233510, 235940, 238321, 240758, 242738, 246382, 249290, 251530, 255003, 257856, 260855, 263461, 265365, 266959, 270430, 273310, 276706, 279553, 281728, 283972, 286923, 290165, 292908, 296085, 298076, 300414, 303991, 306617, 310130, 312556, 316284, 318713, 321840, 324169, 327766, 330450, 332516, 335406, 337593, 341079, 343730, 346426, 349294, 351737, 355158, 357945, 360565, 36329

In [27]:
# First element of the 'chisel' entry returned by get_rigid_TxyzRxyz() for the
# last file processed by the loading loop.
data_state_dict['chisel'][0]

array([ 24.576021  ,  72.273773  , 362.534119  ,  -1.58459063,
         0.43490431,  -1.52324078])

In [28]:
# First element of the 'gripper' entry returned by get_rigid_TxyzRxyz() for the
# last file processed by the loading loop.
data_state_dict['gripper'][0]

array([-194.145798  ,  259.163086  ,  402.422455  ,   -1.75266852,
          1.23180474,    1.77237689])

In [29]:
# First element of the 'battery' entry returned by get_rigid_TxyzRxyz() for the
# last file processed by the loading loop.
data_state_dict['battery'][0]

array([ 1.11871452e+02,  3.45810364e+02,  1.74228607e+02, -3.31561306e-03,
       -2.91190058e-03, -2.47683353e-03])

In [30]:
# First element of the 'B1' entry returned by get_marker_Txyz() for the last
# file processed by the loading loop.
marker_data['B1'][0]

array([ 64.510429, 476.06546 , 179.830948])

In [None]:
#@markdown ### **Training**
#@markdown

checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_interval = 3600  # seconds between periodic checkpoints
last_checkpoint_time = time.time()


def _save_checkpoint(path, epoch):
    """Write the full training state to ``path``.

    Args:
        path: Destination file handed to ``torch.save``.
        epoch: Absolute epoch index stored under the ``'epoch'`` key.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': noise_pred_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': lr_scheduler.state_dict(),
        'ema_state_dict': ema.state_dict(),
        'len_dataloader': len(dataloader),
        'dataset_stats': dataset.stats,
        'num_epochs': num_epochs,
        'obs_dim': obs_dim,
        'action_dim': action_dim,
        'pred_horizon': pred_horizon,
        'obs_horizon': obs_horizon,
        'action_horizon': action_horizon,
        'target_fps': target_fps,
        'action_item': action_item,
        'obs_item': obs_item,
        'marker_item': marker_name,
        'num_diffusion_iters': num_diffusion_iters,
    }, path)


with tqdm(range(num_epochs - start_epoch), desc='Epoch') as tglobal:
    # epoch loop
    epoch_loss = []
    batch_loss_per_epoch = []

    for epoch_idx in tglobal:
        # epoch_idx counts from 0 within this session; `epoch` is the absolute
        # index used for the checkpoint file name and the stored 'epoch' value.
        epoch = start_epoch + epoch_idx
        batch_loss = []
        batch_noise = []
        # batch loop
        for nbatch in dataloader:
            # data normalized in dataset
            nobs = nbatch['obs']
            naction = nbatch['action']
            B = nobs.shape[0]

            # observation as FiLM conditioning
            # (B, obs_horizon, obs_dim)
            obs_cond = nobs[:, :obs_horizon, :]
            # (B, obs_horizon * obs_dim)
            obs_cond = obs_cond.flatten(start_dim=1).float().to(device)

            # sample noise to add to actions (on the CPU, like the batch itself)
            noise = torch.randn(naction.shape)

            # sample a diffusion iteration for each data point
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (B,)
            ).long()

            # add noise to the clean actions according to the noise magnitude
            # at each diffusion iteration (this is the forward diffusion process)
            noisy_actions = noise_scheduler.add_noise(
                naction, noise, timesteps)

            # device transfer
            noise = noise.to(device)
            timesteps = timesteps.to(device)
            noisy_actions = noisy_actions.float().to(device)

            # predict the noise residual
            noise_pred = noise_pred_net(
                noisy_actions, timesteps, global_cond=obs_cond)

            # Keep only a detached CPU copy: storing the live tensor would
            # retain every batch's autograd graph for the whole epoch.
            batch_noise.append(noise_pred.detach().cpu())

            # L2 loss
            loss = nn.functional.mse_loss(noise_pred, noise)

            # optimize
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # step lr scheduler every batch
            # this is different from standard pytorch behavior
            lr_scheduler.step()

            # update Exponential Moving Average of the model weights
            ema.step(noise_pred_net)

            # logging
            loss_cpu = loss.item()
            batch_loss.append(loss_cpu)

        # periodic checkpoint, at most once per checkpoint_interval seconds
        # (the EMA model library was modified locally to expose state_dict)
        current_time = time.time()
        if current_time - last_checkpoint_time > checkpoint_interval:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{type}_epoch_{epoch}.pth')
            _save_checkpoint(checkpoint_path, epoch)
            last_checkpoint_time = current_time

        tglobal.set_postfix(loss=np.mean(batch_loss))
        epoch_loss.append(np.mean(batch_loss))
        batch_loss_per_epoch.append(batch_loss)

if epoch_loss:
    # final checkpoint for the last epoch that actually ran
    epoch = start_epoch + epoch_idx
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{type}_epoch_{epoch}.pth')
    _save_checkpoint(checkpoint_path, epoch)
    print(f'Checkpoint saved at epoch {epoch}')
else:
    # nothing ran, so epoch_idx is undefined (or stale) — do not write a file
    print(f'No epochs left to train (start_epoch={start_epoch}, num_epochs={num_epochs}); no checkpoint written.')

# Weights of the EMA model are used for inference. They are copied into a
# separate network so noise_pred_net keeps its raw training weights.
ema_noise_pred_net = copy.deepcopy(noise_pred_net)
ema.copy_to(ema_noise_pred_net.parameters())

Epoch:   0%|          | 0/280 [00:00<?, ?it/s]