# In [None]:
import os
import json
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np 
import matplotlib.pyplot as plt
from daisy_API import daisy_API
import daisy_hardware.motion_library as motion_library

from logger import Logger
from low_level_traj_gen import NN_tra_generator
import utils
from pytorchtools.pytorchtools import EarlyStopping


class train_NNTG():
    """Trainer for the low-level neural-network trajectory generator (NNTG).

    Jointly fits ``self.policy`` (an ``NN_tra_generator``) and one learnable
    latent vector per motion primitive (``self.z_action_all``) by regressing
    normalized expert actions with an MSE loss.

    NOTE(review): ``sample_phase_action`` reads a module-level ``traj`` that
    is not defined in this class; it must exist before training starts.
    """

    # Number of columns in one expert action row.
    ACTION_DIM = 18
    # Number of indexable samples per expert trajectory; also the divisor that
    # turns a sample index into a phase value in (0, 1].
    TRAJ_LEN = 100
    # Primitives [0, NUM_TRAIN_PRIMITIVES) are drawn by the training step; the
    # remaining ones are only used by the periodic validation phase.
    NUM_TRAIN_PRIMITIVES = 95
    # A validation phase runs every VAL_INTERVAL training iterations and
    # performs VAL_STEPS latent-only optimization steps.
    VAL_INTERVAL = 100
    VAL_STEPS = 100

    def __init__(self,
                 num_primitive,
                 z_dim,
                 policy_output_dim,
                 policy_hidden_num,
                 policy_lr,
                 batch_size,
                 mean_std,
                 device):
        """Build the policy network, the latent table and their optimizers.

        Args:
            num_primitive: Number of motion primitives (rows of the latent table).
            z_dim: Dimension of each latent vector.
            policy_output_dim: Forwarded to ``NN_tra_generator``.
            policy_hidden_num: Forwarded to ``NN_tra_generator``.
            policy_lr: Learning rate shared by both Adam optimizers.
            batch_size: Number of samples per training step.
            mean_std: Indexable pair whose items 0 and 1 are passed to
                ``utils.normalization`` (presumably mean and std -- confirm).
            device: Torch device for the latent table and sampled batches.
        """
        self.policy = NN_tra_generator(z_dim, policy_output_dim, policy_hidden_num, device)
        self.policy_lr = policy_lr
        self.policy_optimizer = torch.optim.Adam(
            self.policy.parameters(), lr=self.policy_lr, weight_decay=2e-5)
        self.batch_size = batch_size
        self.num_primitive = num_primitive
        self.z_dim = z_dim
        self.device = device
        self.learning_step = 0
        self.mean_std = mean_std

        # One randomly initialized, trainable latent vector per primitive.
        # Created directly as a leaf tensor in the default floating dtype.
        self.z_action_all = torch.tensor(
            np.random.normal(0, 0.2, (num_primitive, z_dim)),
            dtype=torch.get_default_dtype(),
            requires_grad=True,
            device=device)
        self.z_action_optimizer = torch.optim.Adam(
            [self.z_action_all], lr=self.policy_lr, weight_decay=1e-4)

    def sample_phase_action(self, primitive_index, size):
        """Draw a random (phase, expert action, latent) batch.

        Args:
            primitive_index: Integer array with at least ``size`` primitive ids.
            size: Number of samples to draw.

        Returns:
            Tuple ``(phase_vec, action_vec, z_vec)``: a ``(size, 1)`` float
            tensor of phases, a ``(size, ACTION_DIM)`` float tensor of
            normalized expert actions, and the rows of ``self.z_action_all``
            selected by ``primitive_index`` (still attached to the autograd
            graph so the latents receive gradients).
        """
        sample_idx = np.random.randint(0, self.TRAJ_LEN, size=size)
        phase = (sample_idx + 1) / float(self.TRAJ_LEN)

        expert_action = np.empty((size, self.ACTION_DIM))
        for row in range(size):
            # `traj` is a module-level global indexed as traj[primitive][sample].
            expert_action[row] = traj[primitive_index[row]][sample_idx[row]]

        z_vec = self.z_action_all[primitive_index]
        normalized = utils.normalization(expert_action, self.mean_std[0], self.mean_std[1])
        action_vec = torch.as_tensor(normalized, device=self.device).float()
        phase_vec = torch.as_tensor(np.reshape(phase, (size, 1)), device=self.device).float()

        return phase_vec, action_vec, z_vec

    def _apply_gradients(self, loss, update_policy):
        """Backpropagate ``loss`` from a clean gradient state and step optimizers.

        Both optimizers are zeroed first: clearing only the policy optimizer
        would let gradients on ``self.z_action_all`` accumulate across steps.
        The latent optimizer always steps; the policy optimizer steps only
        when ``update_policy`` is true.
        """
        self.policy_optimizer.zero_grad()
        self.z_action_optimizer.zero_grad()
        loss.backward()
        if update_policy:
            self.policy_optimizer.step()
        self.z_action_optimizer.step()

    def update_model(self, num_iteration, save_dir, early_stopper):
        """Run the training loop, then save the policy to ``save_dir``.

        Every ``VAL_INTERVAL`` iterations the latents of the held-out
        primitives are optimized for ``VAL_STEPS`` steps with the policy
        weights left untouched; the last loss of that phase is logged as
        ``train/val_loss`` and passed to ``early_stopper``.

        Args:
            num_iteration: Maximum number of training iterations.
            save_dir: Directory used by the logger and for the saved model.
            early_stopper: Callable taking the validation loss and exposing
                an ``early_stop`` flag that ends training when set.
        """
        logger = Logger(save_dir, name='train')

        # Clamp the split so fewer primitives than the nominal training count
        # never produces out-of-range indices.
        num_train = min(self.NUM_TRAIN_PRIMITIVES, self.num_primitive)
        val_index = np.arange(num_train, self.num_primitive)

        for i in range(num_iteration):
            z_index = np.random.randint(0, num_train, size=self.batch_size)
            state_vec, expert_action, z_vec = self.sample_phase_action(z_index, self.batch_size)
            pred_action = self.policy.forward(z_vec, state_vec)

            policy_loss = F.mse_loss(pred_action, expert_action)
            self._apply_gradients(policy_loss, update_policy=True)

            self.learning_step += 1
            # Detach so the logger never keeps the autograd graph alive.
            logger.log('train/model_loss', policy_loss.detach())
            logger.dump(self.learning_step)

            # Skip validation when no primitives are held out (empty batch).
            if (i + 1) % self.VAL_INTERVAL == 0 and len(val_index) > 0:
                for _ in range(self.VAL_STEPS):
                    state_vec, expert_action, z_vec = self.sample_phase_action(
                        val_index, len(val_index))
                    pred_action = self.policy.forward(z_vec, state_vec)
                    val_loss = F.mse_loss(pred_action, expert_action)
                    # Only the latents are adapted here; the policy is frozen.
                    self._apply_gradients(val_loss, update_policy=False)
                val_loss = val_loss.detach()
                logger.log('train/val_loss', val_loss)
                logger.dump(self.learning_step)
                early_stopper(val_loss)

            if early_stopper.early_stop:
                break

        self.save_model(save_dir)

    def save_model(self, save_dir):
        """Write the policy's state dict to ``<save_dir>/NNTG.pt``."""
        torch.save(self.policy.state_dict(),
                   '%s/NNTG.pt' % (save_dir))

    def load_model(self, save_dir):
        """Load the policy's state dict from ``<save_dir>/NNTG.pt``.

        The checkpoint is mapped onto ``self.device`` so a file saved on a
        different device (e.g. GPU) can still be loaded.
        """
        state = torch.load('%s/NNTG.pt' % (save_dir),
                           map_location=torch.device(self.device))
        self.policy.load_state_dict(state)