In [24]:
import os
import json
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np 
import matplotlib.pyplot as plt
from daisy_API import daisy_API
import daisy_hardware.motion_library as motion_library

from logger import Logger
from low_level_traj_gen import NN_tra_generator
import utils
from pytorchtools.pytorchtools import EarlyStopping


class train_NNTG():
    def __init__(self,  
                num_primitive, 
                z_dim,
                policy_output_dim, 
                policy_hidden_num, 
                policy_lr, 
                batch_size,
                mean_std,
                device):
            
        self.policy = NN_tra_generator(z_dim, policy_output_dim, policy_hidden_num, device)
        self.policy_lr = policy_lr
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(),lr=self.policy_lr, weight_decay=2e-5)
        self.batch_size = batch_size
        self.num_primitive = num_primitive
        self.z_dim = z_dim
        self.device = device
        self.learning_step = 0
        self.mean_std = mean_std


        # define random z_action
        self.z_action_all = torch.tensor(np.random.normal(0,0.2,(num_primitive,z_dim)).tolist(),requires_grad=True, device = device)
        self.z_action_optimizer = torch.optim.Adam([self.z_action_all],lr=self.policy_lr, weight_decay=1e-4)


    def sample_phase_action(self, primitive_index, size):
        idxs = np.random.randint(0,100, size=size)

        phase = (idxs + 1) / 100.0
        expert_action = np.empty((size, 18))
        z_vec = np.empty((size, self.z_dim))
        for i in range(size):
            expert_action[i] = traj[primitive_index[i]][idxs[i]]
        z_vec = self.z_action_all[primitive_index]
        action_vec = torch.as_tensor(utils.normalization(expert_action,self.mean_std[0], self.mean_std[1]), device= self.device).float()
        phase_vec = torch.as_tensor(np.reshape(phase,(size,1)), device= self.device).float()
        
        return phase_vec, action_vec, z_vec



    def update_model(self, num_iteration, save_dir, early_stopper):
        logger = Logger(save_dir, name = 'train')
        for i in range(num_iteration):
            z_index = np.random.randint(0,190, size=self.batch_size)
            state_vec, expert_action, z_vec = self.sample_phase_action(z_index, self.batch_size)
            pred_action = self.policy.forward(z_vec, state_vec)

            policy_loss = F.mse_loss(pred_action, expert_action) 
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()
            self.z_action_optimizer.step()

            self.learning_step += 1
            logger.log('train/model_loss', policy_loss)
            logger.dump(self.learning_step)



            if (i+1)%100 == 0:
                for _ in range(100):
                    z_index = np.arange(190,self.num_primitive)
                    state_vec, expert_action, z_vec = self.sample_phase_action(z_index,self.num_primitive-190)
                    pred_action = self.policy.forward(z_vec, state_vec)
                    policy_loss = F.mse_loss(pred_action, expert_action) 
                    self.policy_optimizer.zero_grad()
                    policy_loss.backward()
                    self.z_action_optimizer.step()
                logger.log('train/val_loss', policy_loss)
                logger.dump(self.learning_step)
                early_stopper(policy_loss)

            if early_stopper.early_stop:
                break

        
        self.save_model(save_dir)

    def save_model(self, save_dir):
        torch.save(self.policy.state_dict(),
                   '%s/NNTG.pt' % (save_dir) )
    
    def load_model(self, save_dir):
        self.policy.load_state_dict(
            torch.load('%s/NNTG.pt' % (save_dir)))

In [3]:
state_total = np.empty((200, 100, 18))
traj_total = np.empty((200, 100, 18))

for i in range(1):
    traj_data = np.load('./save_data/trial_3/exp_action_' + str(i) +'.npy')
    for j in range(100):
        traj_total[j+100] = traj_data[j]


for i in range(1,5):
    traj_data = np.load('./save_data/trial_3/exp_action_' + str(i) +'.npy')
    for j in range(5):
        traj_total[(i-1)*25 + j] = traj_data[j]





np.save('./save_data/expert_action_total.npy', traj_total)


In [37]:
# train the nerual network
z_dim = 2
num_primitive = 200
policy_output_dim = 18
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
traj = np.load('./save_data/expert_action_total.npy')
traj_1= traj.reshape(20000,18)
mean_std = np.array([np.mean(traj_1, axis = 0), np.std(traj_1, axis = 0)])
np.save('./save_data/trial_'+str(z_dim) +'/LL_mean_std.npy', mean_std)
tra_learning = train_NNTG( num_primitive = num_primitive, 
            z_dim= z_dim,
            policy_output_dim = policy_output_dim, 
            policy_hidden_num = 512, 
            policy_lr = 1e-3, 
            batch_size = 512,
            mean_std = mean_std,
            device = device)
early_stopper = EarlyStopping(patience=7)

In [38]:
# # # train NNTG
# print('Z_action before optimization', tra_learning.z_action_all)
tra_learning.update_model(num_iteration =3000, save_dir = './save_data/trial_'+str(z_dim), early_stopper=early_stopper)
print('Z_action after optimization', tra_learning.z_action_all)

EarlyStopping counter: 1 out of 7
EarlyStopping counter: 2 out of 7
EarlyStopping counter: 3 out of 7
EarlyStopping counter: 4 out of 7
EarlyStopping counter: 5 out of 7
EarlyStopping counter: 6 out of 7
EarlyStopping counter: 7 out of 7
Z_action after optimization tensor([[-2.5729e+00,  8.8005e-01],
        [-2.9725e+00, -3.1238e+00],
        [-3.0785e+00, -3.5161e-01],
        [-3.1081e+00,  3.4954e-01],
        [-3.0323e+00, -2.5868e+00],
        [ 2.1611e+00,  2.0021e+00],
        [ 2.1151e+00,  2.0746e+00],
        [ 2.1070e+00,  2.1639e+00],
        [ 2.0082e+00,  2.1180e+00],
        [ 2.3785e+00,  1.9430e+00],
        [ 2.1417e+00,  2.1357e+00],
        [ 2.2032e+00,  2.1426e+00],
        [ 2.0487e+00,  1.8920e+00],
        [ 2.0934e+00,  2.1806e+00],
        [ 2.1329e+00,  2.0460e+00],
        [ 2.1008e+00,  2.1762e+00],
        [ 2.0338e+00,  2.0430e+00],
        [ 2.0616e+00,  2.1236e+00],
        [ 2.1202e+00,  2.1618e+00],
        [ 1.9800e+00,  1.8073e+00],
        [ 2.03