In [None]:
!pip install -r requirements.txt

In [None]:
from collections import OrderedDict
import numpy as np

import torch

import matplotlib.pyplot as plt
from google.colab import drive

from tianshou.data import Batch, ReplayBuffer, SegmentTree, to_numpy

from stable_baselines3.common.vec_env import DummyVecEnv

import imageio
from base64 import b64encode
from IPython.display import HTML

In [None]:
from agent_train import AgentTrainer
from seq_train import SequenceTrainer
from env import Stack
from utils import image_process

class SerialModelTrain():
  """Drive data collection, agent training, sequence-model training and evaluation.

  The constructor builds a ``Stack('Panda')`` environment, an ``AgentTrainer``
  (which owns the replay buffer ``pri_buffer``) and a ``SequenceTrainer``.
  The typical call order is ``init_buffer()`` -> ``train_model()`` -> ``plot()``,
  and optionally ``test_model()`` / ``train_tranformer()``.
  """

  def __init__(self):
    """Create the environment and both trainers and set the hyper-parameters."""
    self.num_episode = 10            # number of training episodes in train_model()
    self.initial_memory_size = 20    # random transitions collected by init_buffer()
    self.memory_size = 1024          # passed to AgentTrainer
    self.episode_rewards = []        # one summed reward per finished episode
    self.num_average_epidodes = 100  # requested moving-average window for plot()
    self.save_every = 100            # checkpoint period, in episodes
    self.batch_size = 3072           # passed to SequenceTrainer
    self.max_steps = 100             # maximum environment steps per episode
    self.max_iterators = 10          # iterations run by train_tranformer()
    self.env = Stack('Panda')
    # Matches the concatenation built in _robot_state(); presumably 3 position
    # values + 4 quaternion values -- TODO confirm against the env observation spec.
    self.robot_state_shape = 7
    # `robots` is indexed like a list on the arm side, so the gripper must be
    # read from the same robot instance (the original read `.gripper` off the list).
    robot = self.env.robots[0]
    self.action_shape = robot.robot_model.dof + robot.gripper.dof

    self.agent_trainer = AgentTrainer(self.action_shape, self.robot_state_shape, self.memory_size)
    self.seq_trainer = SequenceTrainer(batch_size=self.batch_size)

    self.evaluate_interval = 10      # episodes between average-reward printouts
    self.num_steps_per_iter = 10000  # steps per SequenceTrainer iteration
    self.reward_window = 20          # number of most recent step rewards kept
    self.reward_values = [0] * self.reward_window
    self.averge_reward = 0

  @staticmethod
  def _robot_state(state):
    """Return the end-effector position and quaternion of `state` as one float32 tensor."""
    parts = [
        np.array(state[key], dtype=np.float32).flatten()
        for key in ('robot0_eef_pos', 'robot0_eef_quat')
    ]
    return torch.tensor(np.concatenate(parts), dtype=torch.float32)

  @staticmethod
  def _action_tensor(action):
    """Convert `action` to a float32 tensor and drop a leading singleton dimension."""
    return torch.tensor(action, dtype=torch.float32).squeeze(0)

  def _record_reward(self, reward):
    """Append `reward` to the sliding window, evicting the oldest entry when it is full."""
    if len(self.reward_values) >= self.reward_window:
      self.reward_values.pop(0)
    self.reward_values.append(reward)

  def _average_window(self):
    """Return the moving-average window, clamped to the number of recorded episodes."""
    return max(1, min(self.num_average_epidodes, len(self.episode_rewards)))

  def init_buffer(self):
    """Fill the replay buffer with `initial_memory_size` random-action transitions.

    The environment is reset at the start, after `max_steps` steps, and whenever
    it reports `done`.

    NOTE(review): as in the original, the raw observation dicts are stored (not
    the processed images); confirm that this is what `pri_buffer.store` expects.
    """
    state = None
    steps_in_episode = 0

    for _ in range(self.initial_memory_size):
      if state is None or steps_in_episode >= self.max_steps:
        state = self.env.reset()
        steps_in_episode = 0

      action = np.random.randn(self.action_shape)  # sample random action
      next_state, reward, done, _ = self.env.step(action)
      steps_in_episode += 1

      self.agent_trainer.pri_buffer.store(
          state, self._action_tensor(action), reward, next_state, done)

      # np.any accepts both a scalar flag and an array of flags.
      state = None if np.any(done) else next_state

    print('%d Data collected' % (self.initial_memory_size))

  def train_model(self):
    """Run `num_episode` training episodes, then close the environment.

    Every step is stored in the replay buffer and recorded in the reward window.
    The agent is updated every 5 episodes, the windowed average reward is printed
    every `evaluate_interval` episodes, and a checkpoint is saved every
    `save_every` episodes.
    """
    for episode in range(self.num_episode):
      state = self.env.reset()
      episode_reward = 0
      for _ in range(self.max_steps):
        vision = image_process(state['frontview_image'])
        robot_state = self._robot_state(state)

        action = self.agent_trainer.get_action(vision, robot_state)
        next_state, reward, done, _ = self.env.step(action[0])

        self._record_reward(reward)
        # Accumulate before the termination check so the final step counts too.
        episode_reward += reward

        self.agent_trainer.pri_buffer.store(
            state, self._action_tensor(action), reward, next_state, done)
        state = next_state
        if np.any(done):
          break

      if episode % 5 == 0:
        self.agent_trainer.update()
      self.episode_rewards.append(episode_reward)
      if episode % self.evaluate_interval == 0:
        self.averge_reward = np.mean(self.reward_values)
        print(self.averge_reward)
      if episode % 100 == 0:
        print("Episode %d finished | Episode reward %f" % (episode, episode_reward))
      if episode % self.save_every == 0:
        self.agent_trainer.save_checkpoint(episode)

    self.env.close()

  def train_tranformer(self):
    """Run `max_iterators` training iterations of the sequence trainer."""
    for iteration in range(self.max_iterators):
      self.seq_trainer.train_iteration(
          num_steps=self.num_steps_per_iter, iter_num=iteration + 1, print_logs=True)

  # Correctly spelled alias; the original name is kept for existing callers.
  train_transformer = train_tranformer

  def moving_average(self):
    """Return the moving average of `episode_rewards` as a NumPy array.

    The window is `num_average_epidodes`, clamped to the number of recorded
    episodes: with a kernel longer than the data, `np.convolve(mode='valid')`
    swaps its operands and the result is no longer an average of rewards.
    An empty reward history yields an empty array.
    """
    rewards = np.asarray(self.episode_rewards, dtype=np.float64)
    if rewards.size == 0:
      return rewards
    window = self._average_window()
    return np.convolve(rewards, np.ones(window) / window, mode='valid')

  def plot(self):
    """Plot the moving average of the per-episode rewards."""
    averaged = self.moving_average()
    plt.plot(np.arange(len(averaged)), averaged)
    plt.title('Average rewards in %d episodes' % self._average_window())
    plt.xlabel('episode')
    plt.ylabel('rewards')
    plt.show()

  def test_model(self):
    """Roll out the checkpointed agent and write the frames to 'robosuite_video.mp4'.

    A fresh environment is created for the rollout and always closed afterwards.
    The rollout lasts at most `max_steps` steps and stops early when the
    environment reports `done`.

    NOTE(review): `get_action` now receives the robot state as well, matching the
    call in `train_model`; confirm against `AgentTrainer.get_action`.
    """
    env = Stack('Panda')
    frames = []
    try:
      state = env.reset()
      self.agent_trainer.load_checkpoint('/content/drive/My Drive/check_point')

      for _ in range(self.max_steps):
        frames.append(state['frontview_image'])  # keep the raw frame for the video
        obs = image_process(state['frontview_image'])
        action = self.agent_trainer.get_action(obs, self._robot_state(state))

        state, _, done, _ = env.step(action[0])
        if np.any(done):
          break
    finally:
      env.close()

    imageio.mimwrite('robosuite_video.mp4', frames, fps=20)

## Train agent

In [None]:
# drive.mount('/content/drive')
# NOTE(review): test_model() loads its checkpoint from
# '/content/drive/My Drive/check_point', which presumably needs the Drive mount
# above to be enabled -- confirm before running the test cell.

# Build the trainer (this also creates the environment and both trainers).
serial_train = SerialModelTrain()
# Seed the replay buffer with random-action transitions before training.
serial_train.init_buffer()
# Run the training episodes; the environment is closed at the end.
serial_train.train_model()
# Show the moving average of the per-episode rewards.
serial_train.plot()


## Test Model

In [None]:
# Roll out the agent from the saved checkpoint; writes 'robosuite_video.mp4'.
serial_train.test_model()

In [None]:
""" Load video and encode it in base64 format """
video_path = 'robosuite_video.mp4'
# Read through a context manager so the file handle is closed right away.
with open(video_path, 'rb') as video_file:
  video_data = video_file.read()
video_encoded = b64encode(video_data).decode()

# Display video using HTML (must stay the last expression so the notebook renders it)
HTML(f"""
<video width="640" height="480" controls>
  <source src="data:video/mp4;base64,{video_encoded}" type="video/mp4">
</video>
""")