<a href="https://colab.research.google.com/github/TiaBerte/fact-checking/blob/main/half_cheetah.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Rendering Dependencies
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
# Gym Dependencies
!apt-get update > /dev/null 2>&1
!apt-get install cmake > /dev/null 2>&1
!pip install --upgrade setuptools 2>&1
!pip install ez_setup > /dev/null 2>&1
#!pip install gym[atari] > /dev/null 2>&1
#!pip install gym[box2d] > /dev/null 2>&1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting setuptools
  Downloading setuptools-65.7.0-py3-none-any.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m54.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: setuptools
  Attempting uninstall: setuptools
    Found existing installation: setuptools 57.4.0
    Uninstalling setuptools-57.4.0:
      Successfully uninstalled setuptools-57.4.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.9.0 requires jedi>=0.10, which is not installed.
cvxpy 1.2.3 requires setuptools<=64.0.2, but you have setuptools 65.7.0 which is incompatible.[0m[31m
[0mSuccessfully installed setuptools-65.7.0


In [2]:
!pip install gym[mujoco] > /dev/null 2>&1

In [3]:
import gym
from gym import logger as gymlogger
from gym.wrappers.record_video import RecordVideo
gymlogger.set_level(40) #error only
import tensorflow as tf
import numpy as np
import random
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import math
import glob
import io
import base64
from IPython.display import HTML

from IPython import display as ipythondisplay

In [4]:
from pyvirtualdisplay import Display
# Start an invisible virtual display (pyvirtualdisplay, backed by the Xvfb package
# installed earlier) so environments can render frames in this headless runtime.
display = Display(visible=0, size=(1400, 900))
display.start()

# Utility functions to record a gym environment to video and display it inline.
# To enable video, wrap the environment: `env = wrap_env(env)`.

def show_video(name=None, video_dir='video'):
  """Embed the most recently written MP4 in ``video_dir`` as an HTML5 video player.

  Args:
    name: Unused. Kept (now optional) so both ``show_video()`` and existing
      ``show_video(x)`` calls work.
    video_dir: Directory searched for ``*.mp4`` files; ``wrap_env`` records
      into ``./video``, which is the default.

  Prints "Could not find video" when the directory holds no MP4 file.
  """
  import os  # local import: this cell runs before the cell that imports os

  mp4list = glob.glob(os.path.join(video_dir, '*.mp4'))
  if not mp4list:
    print("Could not find video")
    return
  # glob returns files in arbitrary order; pick the newest recording explicitly.
  mp4 = max(mp4list, key=os.path.getmtime)
  with io.open(mp4, 'rb') as f:
    encoded = base64.b64encode(f.read())
  ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
    

def wrap_env(env):
  """Return ``env`` wrapped in RecordVideo, saving every episode under ./video."""
  record_all = lambda _episode_idx: True  # trigger fires for every episode
  return RecordVideo(env, './video', episode_trigger=record_all)

In [5]:
from argparse import ArgumentParser
from typing import List

parser = ArgumentParser()
parser.add_argument("--env_name", help="Gym environment", default="HalfCheetah-v4", type=str)
# argparse applies `type` only to *string* defaults, so a float literal such as 1e5
# would reach `args` unconverted (100000.0); the default must already be an int.
parser.add_argument("--min_buffer_size", help="Minimum replay buffer size", default=100_000, type=int)
parser.add_argument("--eval_episodes", help="Evaluate agent every N episodes", default=1000, type=int)
parser.add_argument("--gamma", help="discount value", default=0.99, type=float)
parser.add_argument("--alpha", help="temperature entropy", default=0.2, type=float)
parser.add_argument("--alpha_tuning", help="chose to use a fixed allpha or tuning its value", action='store_true')
parser.add_argument("--tau", help="soft update", default=5e-3, type=float)
parser.add_argument("--batch_size", help="batch size", default=256, type=int)
# Learning rates: --lr_p for the policy (actor) optimizer, --lr_c for the critic
# optimizer, --lr_a for the entropy-temperature (log_alpha) optimizer.
parser.add_argument("--lr_p", help="learning rate", default=3e-4, type=float)
parser.add_argument("--lr_c", help="learning rate", default=3e-4, type=float)
parser.add_argument("--lr_a", help="learning rate", default=3e-4, type=float)
# Hidden-layer widths; nargs/type make command-line values parse as a list of ints
# (e.g. `--hidden_dim_q 64 64` -> [64, 64]) instead of a single string.
parser.add_argument("--hidden_size_v", help="hidden dim list", default=[256, 256], nargs="+", type=int)
parser.add_argument("--hidden_dim_q", help="hidden dim list", default=[256, 256], nargs="+", type=int)
parser.add_argument("--hidden_dim_p", help="hidden dim list", default=[256, 256], nargs="+", type=int)
parser.add_argument("--log_std_min", help="log std", default=-20, type=float)
parser.add_argument("--log_std_max", help="log std", default=3, type=float)
parser.add_argument("--max_episodes", help="max number of training episode", default=1500, type=int)
parser.add_argument("--test_episodes", help="number of test episodes for evaluating the model", default=20, type=int)
parser.add_argument("--train_max_steps", help="max number of step for each training episode", default=5000, type=int)
parser.add_argument("--max_ep_steps", help="max number of step for each training episode", default=1000, type=int)
parser.add_argument("--test_max_steps", help="max number of step for each testing episode", default=1000, type=int)
# A path needs a value: with action='store_true' this option could only be True/False.
parser.add_argument("--model_path", help="path from which model is loaded, if none the model is randomly intialized", default=None, type=str)
parser.add_argument("--training", help="flag required for training the model", action="store_true")
parser.add_argument("--test", help="flag required for testing a model", action="store_true")
parser.add_argument("--update_steps", help="update networks parameters every N steps", default=5, type=int)
parser.add_argument("--eval_steps", help="evaluate model every N steps", default=1000, type=int)

# Jupyter launches the kernel with `-f <connection file>`; this dummy option absorbs it.
parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")
# parse_known_args also tolerates any other launcher arguments instead of exiting.
args = parser.parse_known_args()[0]

In [6]:
import os
import pickle
import random
from torch import Tensor
import numpy as np
from typing import Tuple

class ReplayBuffer:
    """FIFO store of ``[state, action, reward, next_state, done]`` transitions.

    ``memory`` is a plain list used as a ring buffer. Once it holds ``capacity``
    transitions, each new one overwrites the oldest in O(1). The former
    ``list.pop(0)`` shifted the whole list on every insertion.

    Args:
        min_size: Stored as ``self.min_size``; not used inside this class.
        capacity: Maximum number of stored transitions. Float values such as
            ``1e5`` are accepted and truncated with ``int()``.
        device: Device that ``sample`` moves its tensors to.

    Raises:
        ValueError: If ``capacity`` truncates to less than 1.
    """
    def __init__(self, 
                 min_size : int = 1e4,
                 capacity : int = 1e6, 
                 device : str = "cpu",
                ):
        if int(capacity) < 1:
            raise ValueError(f"capacity must be at least 1, got {capacity}")
        self.min_size = min_size
        self.capacity = capacity
        self.memory = []
        self.device = device
        # Slot the next transition is written to once the buffer is full.
        # While the buffer is filling up it always equals len(self.memory).
        self._next_idx = 0

    def __len__(self) -> int:
        return len(self.memory)

    def add(self, 
            state : Tensor, 
            action : Tensor, 
            reward : float, 
            next_state : Tensor, 
            done : bool) -> None:
        '''
        Store one transition. Until ``capacity`` transitions are stored it is
        appended; after that it replaces the oldest stored transition (FIFO).
        '''
        transition = [state, action, reward, next_state, done]
        max_len = int(self.capacity)
        if len(self.memory) < max_len:
            self.memory.append(transition)
        else:
            self.memory[self._next_idx] = transition
        self._next_idx = (self._next_idx + 1) % max_len

    def sample(self,
               batch_size : int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Uniformly sample ``batch_size`` distinct transitions as stacked tensors.

        All tensors are moved to ``self.device``. ``reward`` and ``done`` get a
        trailing dimension of size 1, and ``done`` is cast to float32.

        Raises:
            ValueError: From ``random.sample`` if fewer than ``batch_size``
                transitions are stored.
        """
        batch = random.sample(self.memory, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))

        state = Tensor(state).to(self.device)
        action = Tensor(action).to(self.device)
        reward = Tensor(reward).unsqueeze(1).to(self.device)
        next_state = Tensor(next_state).to(self.device)
        done = Tensor(np.float32(done)).unsqueeze(1).to(self.device)

        return state, action, reward, next_state, done

    def save(self, 
             env_name : str, 
             train_ep: int) -> None :
        """Pickle the stored transitions, oldest first, to ``buffer/buffer_{env_name}_{train_ep}``."""
        os.makedirs('buffer/', exist_ok=True)

        path = f"buffer/buffer_{env_name}_{train_ep}"
        # Undo the ring-buffer rotation so the file is always in insertion order.
        ordered = self.memory[self._next_idx:] + self.memory[:self._next_idx]

        with open(path, 'wb') as f:
            pickle.dump(ordered, f)

        print(f'Buffer saved at {path}')

    def load(self, 
             path : str) -> None:
        """Replace the stored transitions with those pickled (oldest first) at ``path``.

        Only the newest ``capacity`` transitions are kept.

        NOTE: ``pickle.load`` can execute arbitrary code. Only load buffer files
        you created yourself.
        """
        print(f'Loading buffer memory from {path}')

        with open(path, "rb") as f:
            memory = list(pickle.load(f))

        max_len = int(self.capacity)
        if len(memory) > max_len:
            memory = memory[-max_len:]
        self.memory = memory
        self._next_idx = len(memory) % max_len


In [7]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.distributions import Normal
from typing import List


def weights_init_(m):
    """Xavier-uniform (gain 1) weights and zero biases for ``nn.Linear``; other modules untouched.

    Meant to be passed to ``module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight, gain=1)
    nn.init.constant_(m.bias, 0)

class QNetwork(nn.Module):
    """Two independent Q-value MLPs over the concatenated (state, action) input.

    Each head is Linear -> ReLU -> Linear -> ReLU -> Linear(1). The attribute
    names linear1..linear6 become the state_dict keys that checkpoints store,
    so they are kept as-is.
    """
    def __init__(self, num_inputs, num_actions, hidden_dim):
        super(QNetwork, self).__init__()
        in_features = num_inputs + num_actions

        # Q1 head
        self.linear1 = nn.Linear(in_features, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Q2 head
        self.linear4 = nn.Linear(in_features, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear6 = nn.Linear(hidden_dim, 1)

        self.apply(weights_init_)

    @staticmethod
    def _head(x, first, second, out):
        """Run one Q head: two ReLU hidden layers followed by a linear output layer."""
        return out(F.relu(second(F.relu(first(x)))))

    def forward(self, state, action):
        """Return ``(q1, q2)``, each with one value per row of the batch."""
        state_action = torch.cat([state, action], 1)
        q1 = self._head(state_action, self.linear1, self.linear2, self.linear3)
        q2 = self._head(state_action, self.linear4, self.linear5, self.linear6)
        return q1, q2


class GaussianPolicy(nn.Module):
    """Tanh-squashed diagonal Gaussian policy: two ReLU hidden layers, then mean and log-std heads.

    Args:
        num_inputs: State dimension.
        num_actions: Action dimension.
        hidden_dim: Width of both hidden layers.
        action_space: Optional object with ``high``/``low`` arrays. Squashed
            actions are rescaled from [-1, 1] to [low, high]; ``None`` keeps
            [-1, 1].
        log_std_min, log_std_max: Clamp range for the predicted log standard
            deviation. The defaults equal the previously hard-coded -20 / 2.
    """
    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None,
                 log_std_min=-20, log_std_max=2):
        super(GaussianPolicy, self).__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)

        self.apply(weights_init_)

        # action rescaling
        if action_space is None:
            action_scale = torch.tensor(1.)
            action_bias = torch.tensor(0.)
        else:
            action_scale = torch.tensor(
                (action_space.high - action_space.low) / 2., dtype=torch.float32)
            action_bias = torch.tensor(
                (action_space.high + action_space.low) / 2., dtype=torch.float32)
        # Buffers move with .to(device), unlike plain tensor attributes. A CPU
        # vector scale would otherwise break on CUDA. persistent=False keeps
        # them out of state_dict so existing checkpoints still load with strict=True.
        self.register_buffer('action_scale', action_scale, persistent=False)
        self.register_buffer('action_bias', action_bias, persistent=False)

    def forward(self, state):
        """Return the pre-squash ``(mean, log_std)``, with log_std clamped to the configured range."""
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=self.log_std_min, max=self.log_std_max)
        return mean, log_std

    def sample(self, state):
        """Sample a reparameterized action.

        Returns:
            ``(action, log_prob, mean)``. ``action`` and ``mean`` are tanh-squashed
            and rescaled. ``log_prob`` is summed over action dimensions (shape
            ``(batch, 1)``) and corrected for the tanh/scale change of variables.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Enforcing Action Bound (1e-6 avoids log(0) when tanh saturates)
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean


In [8]:
from torch import nn, Tensor, optim
from typing import List, Tuple
from torch.distributions import Normal
import torch
#from network import QNetwork, PolicyNetwork
import itertools

class Critic(nn.Module):
    """Twin-Q critic: one QNetwork (which holds both Q heads) plus its Adam optimizer and MSE loss.

    NOTE(review): ``hidden_dim`` is accepted but unused. The QNetwork width is
    hard-coded to 256.
    """

    def __init__(self,
                 input_dim: int,
                 n_actions: int,
                 hidden_dim: List[int],
                 lr : float):
        super().__init__()
        self.q_net1 = QNetwork(input_dim, n_actions, 256)
        self.optim = optim.Adam(self.q_net1.parameters(), lr, weight_decay=1e-9)
        self.criterion = nn.MSELoss()

    def forward(self, state, action) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the two Q estimates ``(q1, q2)`` for a state-action batch."""
        q_values = self.q_net1(state, action)
        first, second = q_values
        return first, second

    def update(self, state, action, target_q):
        """Take one optimizer step on MSE(q1, target) + MSE(q2, target); return the loss as a float."""
        pred_first, pred_second = self.forward(state, action)
        loss = self.criterion(pred_first, target_q) + self.criterion(pred_second, target_q)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss.item()

    def soft_update(self, new, tau):
        """Polyak-average ``new`` into this module: p <- (1 - tau) * p + tau * p_new."""
        keep = 1.0 - tau
        for target_param, source_param in zip(self.parameters(), new.parameters()):
            target_param.data.copy_(target_param.data * keep + source_param.data * tau)
    


class Actor(nn.Module):
    """Policy wrapper: a GaussianPolicy network plus its Adam optimizer and SAC policy loss.

    NOTE(review): ``hidden_dim``, ``log_std_min`` and ``log_std_max`` are
    accepted but not forwarded. GaussianPolicy is built with a fixed width of 256.
    """
    def __init__(self, 
                 input_dim : int, 
                 n_actions : int, 
                 hidden_dim : List[int], 
                 log_std_min : float,
                 log_std_max : float, 
                 lr : float):

        super(Actor, self).__init__()

        self.policy_net = GaussianPolicy(input_dim, n_actions, 256)
        self.optim = optim.Adam(self.policy_net.parameters(), lr, weight_decay=1e-9)

    def forward(self, state):
        """Return the policy network's ``(mean, log_std)`` for ``state``."""
        mean, log_std = self.policy_net(state)
        return mean, log_std

    def criterion(self, log_p, pred_q, alpha):
        """SAC policy loss: mean of ``alpha * log_p - pred_q``."""
        return (alpha * log_p - pred_q).mean()

    def update(self, log_p, pred_q, alpha):
        """Take one optimizer step on the policy loss; return the loss as a float."""
        policy_loss = self.criterion(log_p, pred_q, alpha)
        self.optim.zero_grad()
        policy_loss.backward()
        self.optim.step()
        return policy_loss.item()

    def eval(self, state=None, epsilon=1e-6):
        """Dual-purpose, for compatibility with ``nn.Module``.

        ``actor.eval()`` (no state) behaves exactly like ``nn.Module.eval`` and
        returns ``self``. The previous override made that call raise TypeError.

        ``actor.eval(state)`` samples a tanh-squashed action without rescaling.
        It returns ``(action, log_prob, mean)``, where ``log_prob`` is summed
        over action dimensions and corrected for the tanh squash. Previously
        these values were computed and then discarded (``None`` was returned).
        """
        if state is None:
            return super().eval()
        mean, log_std = self.forward(state)
        normal = Normal(mean, log_std.exp())
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        action = torch.tanh(x_t)
        # Enforcing Action Bound (epsilon avoids log(0) when tanh saturates)
        log_prob = (normal.log_prob(x_t) - torch.log(1 - action.pow(2) + epsilon)).sum(1, keepdim=True)
        return action, log_prob, torch.tanh(mean)

    def action(self, state):
        """Return a sampled, tanh-squashed action (not rescaled to the action space)."""
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        return torch.tanh(x_t)





In [9]:
import torch
#from model import Critic, Actor
from torch import optim
import os
#from replay_buffer import ReplayBuffer
from torch import Tensor
import numpy as np
import torch.nn.functional as F

class SAC:
    """Soft Actor-Critic agent.

    Combines an actor, a twin-Q critic, a soft-updated target critic and an
    optionally learned entropy temperature.

    Args:
        input_dim: State dimension.
        n_actions: Action dimension.
        replay_buffer: Buffer sampled by ``learning_step``. Its ``device`` is
            overwritten with the agent's device. The annotation is a string so
            this class can be defined independently of ReplayBuffer.
        args: Namespace providing gamma, tau, batch_size, hidden_dim_p,
            hidden_dim_q, log_std_min, log_std_max, lr_p, lr_c, lr_a, alpha and,
            optionally, alpha_tuning (enables automatic temperature tuning).
    """
    def __init__(self, 
                 input_dim : int, 
                 n_actions : int, 
                 replay_buffer : 'ReplayBuffer', 
                 args):

        self.gamma = args.gamma
        self.tau = args.tau
        self.batch_size = args.batch_size

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        replay_buffer.device = self.device
        self.replay_buffer = replay_buffer

        self.actor = Actor(input_dim, n_actions, args.hidden_dim_p, 
                           args.log_std_min, args.log_std_max, args.lr_p).to(self.device)

        # Critic and critic target; the target starts as an exact copy of the critic.
        self.critic = Critic(input_dim, n_actions, args.hidden_dim_q, args.lr_c).to(self.device)
        self.critic_t = Critic(input_dim, n_actions, args.hidden_dim_q, args.lr_c).to(self.device)
        for target_param, param in zip(self.critic_t.parameters(), self.critic.parameters()):
            target_param.data.copy_(param.data)

        self.policy_optimizer = self.actor.optim
        self.critic_optimizer = self.critic.optim

        # Entropy temperature. With --alpha_tuning, log_alpha is learned toward a
        # target entropy of -n_actions, starting from log(args.alpha) so the
        # first tuning step does not jump away from the configured value.
        self.alpha = torch.Tensor([args.alpha]).to(self.device)
        self.automatic_alpha_tuning = bool(getattr(args, "alpha_tuning", False))
        self.target_entropy = -float(n_actions)
        init_log_alpha = float(np.log(args.alpha)) if args.alpha > 0 else 0.0
        self.log_alpha = torch.full((1,), init_log_alpha, requires_grad=True, device=self.device)
        self.alpha_optimizer = optim.Adam([self.log_alpha], args.lr_a)

    def learning_step(self):
        """Run one SAC update on a sampled batch.

        Updates the critic, then the actor, then (when tuning is enabled) the
        temperature, then soft-updates the target critic.

        Returns:
            Dict with float losses 'critic_loss' and 'policy_loss', plus
            'entropy_loss' when automatic temperature tuning is enabled.
        """
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        with torch.no_grad():
            next_actions, next_logs_pi, _ = self.actor.policy_net.sample(next_states)
            target_q1, target_q2 = self.critic_t(next_states, next_actions)
            min_q_t = torch.min(target_q1, target_q2)
            target_q = rewards + (1 - dones) * self.gamma * (min_q_t - self.alpha * next_logs_pi)

        # Updating Q1 and Q2 critic networks
        critic_loss = self.critic.update(states, actions, target_q)

        pred_actions, log_prob, _ = self.actor.policy_net.sample(states)
        q1, q2 = self.critic(states, pred_actions)
        min_q = torch.min(q1, q2)

        # Updating Policy Network
        policy_loss = self.actor.update(log_prob, min_q, self.alpha)

        # Automatic temperature parameter tuning (only with --alpha_tuning)
        entropy_loss = None
        if self.automatic_alpha_tuning:
            entropy_loss = self.alpha_tuning(log_prob)

        # Updating target critic networks
        self.critic_t.soft_update(self.critic, self.tau)

        losses = {'critic_loss' : critic_loss,
                  'policy_loss' : policy_loss}
        if entropy_loss is not None:
            losses['entropy_loss'] = entropy_loss
        return losses

    def alpha_tuning(self, log_prob):
        """Take one optimizer step on log_alpha and refresh ``self.alpha``; return the loss as a float."""
        entrophy_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()

        self.alpha_optimizer.zero_grad()
        entrophy_loss.backward()
        self.alpha_optimizer.step()

        # Detach: alpha is used as a constant by the critic and policy losses.
        # Keeping the exp() graph would make the second policy backward fail
        # with "Trying to backward through the graph a second time".
        self.alpha = self.log_alpha.exp().detach()
        return entrophy_loss.item()

    @torch.no_grad()
    def get_action(self, state, evaluate):
        """Return a numpy action for a single state.

        The action is sampled when ``evaluate`` is False; the squashed mean is
        returned when it is True.
        """
        state = torch.Tensor(state).float().to(self.device).unsqueeze(0)
        if not evaluate:
            action, _, _ = self.actor.policy_net.sample(state)
        else:
            _, _, action = self.actor.policy_net.sample(state)
        return action.detach().cpu().numpy()[0]

    def save_model(self, env_name, episode):
        """Save actor, critics and optimizer states to ``checkpoints/{env_name}_sac_episode_{episode}``."""
        os.makedirs('checkpoints/', exist_ok=True)
        ckpt_path = f"checkpoints/{env_name}_sac_episode_{episode}"
        print(f'Saving models to {ckpt_path}')
        torch.save({'policy_state_dict': self.actor.state_dict(),
                    'critic_state_dict': self.critic.state_dict(),
                    'critic_target_state_dict': self.critic_t.state_dict(),
                    'policy_optimizer_state_dict': self.policy_optimizer.state_dict(),
                    'critic_optimizer_state_dict': self.critic_optimizer.state_dict()}, 
                   ckpt_path)

    # Load model parameters
    def load_checkpoint(self, ckpt_path):
        """Restore the states written by ``save_model``; a ``None`` path is a no-op.

        ``map_location`` lets a checkpoint saved on GPU load on a CPU-only
        runtime, and vice versa.
        """
        if ckpt_path is None:
            return
        print('Loading models from {}'.format(ckpt_path))
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.actor.load_state_dict(checkpoint['policy_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.critic_t.load_state_dict(checkpoint['critic_target_state_dict'])
        self.policy_optimizer.load_state_dict(checkpoint['policy_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])

    def freeze_params(self, model):
        """Disable gradients for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = False

    def unfreeze_params(self, model):
        """Re-enable gradients for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = True



In [32]:
# Evaluation environment, recorded to ./video. The checkpoint loaded in the next
# cell is named after Walker2d-v4, so keep the two in sync.
env = wrap_env(env=gym.make(id="Walker2d-v4"))

In [34]:
action_dim = env.action_space.shape[0]
state_dim  = env.observation_space.shape[0]

# Integer sizes (not 1e4 / 1e5 floats); the buffer is unused when only evaluating.
replay_buffer = ReplayBuffer(min_size=10_000, capacity=100_000)

CHECKPOINT_PATH = '/content/Walker2d-v4_sac_episode_450'
agent = SAC(state_dim, action_dim, replay_buffer, args)
agent.load_checkpoint(CHECKPOINT_PATH)
env_name = env.unwrapped.spec.id

# gym < 0.26 returns the observation from reset() and a 4-tuple from step().
# gym >= 0.26 returns (obs, info) and (obs, reward, terminated, truncated, info).
# Accept both.
reset_out = env.reset()
state = reset_out[0] if isinstance(reset_out, tuple) else reset_out
done = False
while not done:
    # RecordVideo captures a frame on every step itself, so render() is not called
    # here (render(mode=...) is also rejected by newer gym versions).
    action = agent.get_action(state, True)  # deterministic (mean) action
    step_out = env.step(action)
    if len(step_out) == 5:
        state, reward, terminated, truncated, _ = step_out
        done = terminated or truncated
    else:
        state, reward, done, _ = step_out

env.close()
show_video(env_name)

Loading models from /content/Walker2d-v4_sac_episode_450


In [15]:
# Registered id of the innermost environment (unwrapped past RecordVideo and any other wrappers).
name = env.unwrapped.spec.id

In [16]:
# Display the environment id.
# NOTE(review): the saved output below ('HalfCheetah-v4') comes from execution
# count 16, before the Walker2d-v4 env was created in cell 32. Re-run to refresh.
name

'HalfCheetah-v4'

In [17]:
# Inspect the GPU attached to this runtime (model, driver/CUDA version, memory use).
!nvidia-smi


Wed Jan 11 20:14:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   53C    P0    27W /  70W |    806MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces