In [1]:
import numpy as np
import torch
import copy

from dm_control import suite
device = "cuda"

In [2]:
import skimage
from io import BytesIO
import imageio
from IPython.display import Image

def display_video(frames, framerate=30):
    gif_file = BytesIO()
    imageio.mimsave(gif_file, [skimage.img_as_ubyte(frame) for frame in frames], 'GIF', fps=30)
    return Image(data=gif_file.getvalue())

In [3]:
def partial_copy_model(model1,model2,t):
    new_state_dict = model1.state_dict()
    update_state_dict = model2.state_dict()

    for key in new_state_dict.keys():
        new_state_dict[key] = (1-t)*new_state_dict[key] + t*update_state_dict[key]

    model1.load_state_dict(new_state_dict)

class DDPG:
    def __init__(self, state_size, action_size, decay=0.99, exploration_noise=0.2, target_lr=0.001):
        self.exploration_noise = exploration_noise
        self.target_lr = target_lr
        self.decay = decay
        self.target_lr = target_lr
        self.action_size = action_size
        self.state_size = state_size

        self.critic = torch.nn.Sequential(
            torch.nn.Linear(state_size+action_size, 200),
            torch.nn.Tanh(),
            torch.nn.Linear(200,200),
            torch.nn.Tanh(),
            torch.nn.Linear(200, 1),
        ).to(device)

        self.actor = torch.nn.Sequential(
            torch.nn.Linear(state_size, 200),
            torch.nn.ReLU(),
            torch.nn.Linear(200,200),
            torch.nn.ReLU(),
            torch.nn.Linear(200, action_size),
            torch.nn.Tanh(),
        ).to(device)
        
        actor_state_dict = self.actor.state_dict()
        actor_state_dict["4.bias"][:] = 0
        actor_state_dict["4.weight"] = torch.normal(
            torch.zeros_like(actor_state_dict["4.weight"]),
            0.001*torch.ones_like(actor_state_dict["4.weight"]),
        )
        self.actor.load_state_dict(actor_state_dict)

        self.target_actor = copy.deepcopy(self.actor)
        self.target_critic = copy.deepcopy(self.critic)

        self.actor.to(device)
        self.critic.to(device)
        self.target_actor.to(device)
        self.target_critic.to(device)

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=0.003,weight_decay=0.01)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=0.0003)

        self.random_state = np.random.RandomState(42)

        self.replay_size = 1000000
        self.replay_ind = 0
        self.replay_prev_state = torch.zeros((self.replay_size,state_size)).to(device)
        self.replay_new_state = torch.zeros((self.replay_size,state_size)).to(device)
        self.replay_action = torch.zeros((self.replay_size,action_size)).to(device)
        self.replay_reward = torch.zeros((self.replay_size)).to(device)


    def get_action(self,state):
        action = self.actor(state.to(device)).cpu() + torch.normal(torch.zeros(self.action_size), std=self.exploration_noise)
        #action = torch.normal(torch.zeros(self.action_size), std=self.exploration_noise)
        action += (action<-1).float()
        action -= (action>1).float()
        return action.detach().cpu().numpy()

    def store_transition(self, prev_state,new_state, action, reward, should_print=False):
        action = torch.tensor(action).to(device)
        # insert transition into buffer
        self.replay_prev_state[self.replay_ind % self.replay_size] = prev_state
        self.replay_new_state[self.replay_ind % self.replay_size] = new_state
        self.replay_action[self.replay_ind % self.replay_size] = action
        self.replay_reward[self.replay_ind % self.replay_size] = reward
        self.replay_ind += 1

        if should_print:
            print(self.critic(torch.cat([prev_state,action])))

        if self.random_state.uniform(0,1)>0.2:
            return
            
        # sample minibatch from replay buffer
        num_samples = min(200,self.replay_ind)
        sample_inds = self.random_state.randint(0,min(self.replay_ind,self.replay_size), num_samples)
        sample_prev_state = self.replay_prev_state[sample_inds].to(device)
        sample_new_state = self.replay_new_state[sample_inds].to(device)
        sample_action = self.replay_action[sample_inds].to(device)
        sample_reward = self.replay_reward[sample_inds].to(device)

        # calculate target reward using bellman equation with target models
        predicted_next_reward = self.target_critic(torch.cat([sample_new_state, self.target_actor(sample_new_state)],dim=1))[:,0]
        target_reward = sample_reward+self.decay*predicted_next_reward

        # Train critic
        self.critic_optimizer.zero_grad()
        predicted_reward = self.critic(torch.cat([sample_prev_state,sample_action],dim=1))[:,0]
        loss = torch.mean((target_reward - predicted_reward)**2)
        loss.backward()
        self.critic_optimizer.step()
        

        # Train actor
        self.actor_optimizer.zero_grad()
        predicted_reward = self.critic(torch.cat([sample_prev_state,self.actor(sample_prev_state)],dim=1))[:,0]
        loss = (1-predicted_reward).mean()
        loss.backward()
        self.actor_optimizer.step()

        # Update target networks
        partial_copy_model(self.target_actor,self.actor,self.target_lr)
        partial_copy_model(self.target_critic,self.critic,self.target_lr)
        

In [4]:
def flatten_obs(observation):
    return torch.tensor(np.concatenate(list(observation.values()))).float().to(device)

def simulate_render(env,agent, duration=3):
    frames = []
    rewards = []

    spec = env.action_spec()
    time_step = env.reset()
    print(spec)
    while env.physics.data.time < duration:
        action = agent.get_action(flatten_obs(time_step.observation))
        time_step = env.step(action)

        camera0 = env.physics.render(camera_id=0, height=200, width=200)
        camera1 = env.physics.render(camera_id=1, height=200, width=200)
        frames.append(np.hstack((camera0, camera1)))
        rewards.append(time_step.reward)
    print("Num frames:",len(frames))
    return display_video(frames, framerate=1./env.control_timestep()*5)

def simulate_train(env,agent, duration=3):
    frames = []
    rewards = []

    spec = env.action_spec()
    time_step = env.reset()
    should_print = True
    while env.physics.data.time < duration:
        prev_state = flatten_obs(time_step.observation)
        action = agent.target_actor(prev_state.to(device)).cpu().detach()
        time_step = env.step(action)
        agent.store_transition(prev_state, flatten_obs(time_step.observation), action, time_step.reward,should_print=should_print)
        should_print=False
        rewards.append(time_step.reward)

    return time_step.reward

In [5]:
env = suite.load('cartpole', 'swingup')
state_size = len(flatten_obs(env.reset().observation))
action_size = env.action_spec().shape[0]
agent = DDPG(state_size, action_size)

In [6]:
for i in range(1000000):
    print(simulate_train(env,agent,duration=5))

    if i%100==0:
        torch.save(agent.critic,"critic")
        torch.save(agent.actor,"actor")
        torch.save(agent.target_critic,"target_critic")
        torch.save(agent.target_actor,"target_actor")

  action = torch.tensor(action).to(device)


tensor([0.1698], device='cuda:0', grad_fn=<AddBackward0>)
7.280726827154927e-05
tensor([0.1712], device='cuda:0', grad_fn=<AddBackward0>)
0.013963357755899112
tensor([0.2081], device='cuda:0', grad_fn=<AddBackward0>)
0.003001315127822956
tensor([0.2381], device='cuda:0', grad_fn=<AddBackward0>)
0.042120442479328585
tensor([0.2699], device='cuda:0', grad_fn=<AddBackward0>)
0.004878742972049342
tensor([0.2977], device='cuda:0', grad_fn=<AddBackward0>)
0.07053634325607723
tensor([0.3240], device='cuda:0', grad_fn=<AddBackward0>)
0.15214289652297341
tensor([0.3439], device='cuda:0', grad_fn=<AddBackward0>)
0.2488411073650475
tensor([0.3639], device='cuda:0', grad_fn=<AddBackward0>)
0.03577074975211909
tensor([0.3748], device='cuda:0', grad_fn=<AddBackward0>)
0.08371723217122462
tensor([0.3932], device='cuda:0', grad_fn=<AddBackward0>)
0.07096625053740134
tensor([0.4065], device='cuda:0', grad_fn=<AddBackward0>)
0.05226525562814396
tensor([0.4149], device='cuda:0', grad_fn=<AddBackward0>)
0

KeyboardInterrupt: 

In [None]:
display(simulate_render(env,agent, duration=9))