# Run SAC with Discrete Actions for Atari Environments

Github Repo: git@github.com:XiaohanZhangCMU/sderl.git



# Work node preparation

In [None]:
from google.colab import drive
%reload_ext autoreload
%autoreload 2
drive.mount('/content/gdrive')

In [None]:
import os
# os.chdir('/content/gdrive/My Drive/RL')
# !echo -e "Host github.com\n\tStrictHostKeyChecking no\n" >> ~/.ssh/config
# !rm -rf sderl
# !git clone --single-branch --branch sac_atari https://github.com/XiaohanZhangCMU/sderl.git
os.chdir('/content/gdrive/My Drive/RL/sderl')
!pip install -e .

# Run LunarLander for Debug

In [None]:
%%time
# !python sderl/algos/pytorch/sac_atari/sac.py --env LunarLander-v2  --epochs 1 --update_after=0

from copy import deepcopy
import itertools
import numpy as np
import torch
from torch.optim import Adam
import gym
import time
from sderl.utils.logx import EpochLogger
from sderl import sac_atari_pytorch as sac_atari
import sderl.algos.pytorch.sac_atari.core as core

class Args:
    """Hyper-parameters for the LunarLander-v2 debug run.

    Plain attribute container that stands in for command-line arguments;
    the rest of this cell reads the values through an instance (`args`).
    """

    # --- run identity -----------------------------------------------------
    exp_name = 'abc'          # handed to setup_logger_kwargs together with `seed`
    seed = 0                  # also forwarded to sac_atari as `seed`
    env = 'LunarLander-v2'    # Gym environment id passed to gym.make / make_env

    # --- network ----------------------------------------------------------
    feng = 'mlp'              # forwarded in ac_kwargs as `feng`
    hid = 64                  # width of each hidden layer
    l = 2                     # number of hidden layers: hidden_sizes = [hid] * l

    # --- training schedule ------------------------------------------------
    gamma = 0.99              # forwarded to sac_atari as `gamma`
    steps = 5000              # forwarded to sac_atari as `steps_per_epoch`
    update_after = 1000       # forwarded to sac_atari as `update_after`
    epochs = 256              # forwarded to sac_atari as `epochs`

# Instantiate the hyper-parameter container defined earlier in this cell.
args = Args()

from sderl.utils.run_utils import setup_logger_kwargs

# Logger configuration is derived from the experiment name and the seed only.
# NOTE(review): cells that share exp_name and seed presumably share a log
# destination -- confirm before comparing runs.
logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

# Re-applies the thread count PyTorch already reports (value is unchanged).
torch.set_num_threads(torch.get_num_threads())

# Build a throwaway environment only to inspect its string form and decide
# which factory to use. It is closed right away so the probe instance does
# not stay open for the whole training run.
genv = gym.make(args.env)
try:
    is_atari = 'Atari' in str(genv.env)
finally:
    genv.close()

# The factory is called by sac_atari itself; every call builds a fresh env.
if is_atari:
    from sderl.utils.atari_wrappers import make_env
    lambda_env = lambda: make_env(args.env)
else:
    lambda_env = lambda: gym.make(args.env)

sac_atari(
    lambda_env,
    actor_critic=core.MLPActorCritic,
    ac_kwargs=dict(hidden_sizes=[args.hid] * args.l, feng=args.feng),
    gamma=args.gamma,
    batch_size=256,
    seed=args.seed,
    steps_per_epoch=args.steps,
    update_after=args.update_after,
    epochs=args.epochs,
    logger_kwargs=logger_kwargs,
)

# Run LunarLanderContinuous for Debug

In [None]:
%%time
# !python sderl/algos/pytorch/sac_atari/sac.py --env LunarLanderContinuous-v2  --epochs 1 --update_after=0

from copy import deepcopy
import itertools
import numpy as np
import torch
from torch.optim import Adam
import gym
import time
from sderl.utils.logx import EpochLogger
from sderl import sac_atari_pytorch as sac_atari
import sderl.algos.pytorch.sac_atari.core as core

class Args:
    """Hyper-parameters for the LunarLanderContinuous-v2 debug run.

    Plain attribute container that stands in for command-line arguments;
    the rest of this cell reads the values through an instance (`args`).
    """

    # --- run identity -----------------------------------------------------
    exp_name = 'abc'          # handed to setup_logger_kwargs together with `seed`
    seed = 0                  # also forwarded to sac_atari as `seed`
    # NOTE(review): judging by its name this environment has a continuous
    # action space -- confirm that sac_atari is meant to handle it.
    env = 'LunarLanderContinuous-v2'

    # --- network ----------------------------------------------------------
    feng = 'mlp'              # forwarded in ac_kwargs as `feng`
    hid = 64                  # width of each hidden layer
    l = 2                     # number of hidden layers: hidden_sizes = [hid] * l

    # --- training schedule ------------------------------------------------
    gamma = 0.99              # forwarded to sac_atari as `gamma`
    steps = 5000              # forwarded to sac_atari as `steps_per_epoch`
    update_after = 1000       # forwarded to sac_atari as `update_after`
    epochs = 1000             # forwarded to sac_atari as `epochs`

# Instantiate the hyper-parameter container defined earlier in this cell.
args = Args()

from sderl.utils.run_utils import setup_logger_kwargs

# Logger configuration is derived from the experiment name and the seed only.
# NOTE(review): cells that share exp_name and seed presumably share a log
# destination -- confirm before comparing runs.
logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

# Re-applies the thread count PyTorch already reports (value is unchanged).
torch.set_num_threads(torch.get_num_threads())

# Build a throwaway environment only to inspect its string form and decide
# which factory to use. It is closed right away so the probe instance does
# not stay open for the whole training run.
genv = gym.make(args.env)
try:
    is_atari = 'Atari' in str(genv.env)
finally:
    genv.close()

# The factory is called by sac_atari itself; every call builds a fresh env.
if is_atari:
    from sderl.utils.atari_wrappers import make_env
    lambda_env = lambda: make_env(args.env)
else:
    lambda_env = lambda: gym.make(args.env)

sac_atari(
    lambda_env,
    actor_critic=core.MLPActorCritic,
    ac_kwargs=dict(hidden_sizes=[args.hid] * args.l, feng=args.feng),
    gamma=args.gamma,
    batch_size=256,
    seed=args.seed,
    steps_per_epoch=args.steps,
    update_after=args.update_after,
    epochs=args.epochs,
    logger_kwargs=logger_kwargs,
)


# Run sac_atari on Pong-v0 (in-notebook equivalent of the shell command)

In [None]:
%%time
# !python sderl/algos/pytorch/sac_atari/sac.py --env Pong-v0  --epochs 1 --update_after=0

from copy import deepcopy
import itertools
import numpy as np
import torch
from torch.optim import Adam
import gym
import time
from sderl.utils.logx import EpochLogger
from sderl import sac_atari_pytorch as sac_atari
import sderl.algos.pytorch.sac_atari.core as core

class Args:
    """Hyper-parameters for the Pong-v0 run.

    Plain attribute container that stands in for command-line arguments;
    the rest of this cell reads the values through an instance (`args`).
    """

    # --- run identity -----------------------------------------------------
    exp_name = 'abc'          # handed to setup_logger_kwargs together with `seed`
    seed = 0                  # also forwarded to sac_atari as `seed`
    env = 'Pong-v0'           # Gym environment id passed to gym.make / make_env

    # --- network ----------------------------------------------------------
    # NOTE(review): 'mlp' is the same value the LunarLander cells use --
    # confirm it is the intended choice for an Atari environment.
    feng = 'mlp'              # forwarded in ac_kwargs as `feng`
    hid = 64                  # width of each hidden layer
    l = 2                     # number of hidden layers: hidden_sizes = [hid] * l

    # --- training schedule ------------------------------------------------
    gamma = 0.99              # forwarded to sac_atari as `gamma`
    steps = 5000              # forwarded to sac_atari as `steps_per_epoch`
    update_after = 1000       # forwarded to sac_atari as `update_after`
    epochs = 1000             # forwarded to sac_atari as `epochs`

# Instantiate the hyper-parameter container defined earlier in this cell.
args = Args()

from sderl.utils.run_utils import setup_logger_kwargs

# Logger configuration is derived from the experiment name and the seed only.
# NOTE(review): cells that share exp_name and seed presumably share a log
# destination -- confirm before comparing runs.
logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

# Re-applies the thread count PyTorch already reports (value is unchanged).
torch.set_num_threads(torch.get_num_threads())

# Build a throwaway environment only to inspect its string form and decide
# which factory to use. It is closed right away so the probe instance does
# not stay open for the whole training run.
genv = gym.make(args.env)
try:
    is_atari = 'Atari' in str(genv.env)
finally:
    genv.close()

# The factory is called by sac_atari itself; every call builds a fresh env.
if is_atari:
    from sderl.utils.atari_wrappers import make_env
    lambda_env = lambda: make_env(args.env)
else:
    lambda_env = lambda: gym.make(args.env)

sac_atari(
    lambda_env,
    actor_critic=core.MLPActorCritic,
    ac_kwargs=dict(hidden_sizes=[args.hid] * args.l, feng=args.feng),
    gamma=args.gamma,
    batch_size=256,
    seed=args.seed,
    steps_per_epoch=args.steps,
    update_after=args.update_after,
    epochs=args.epochs,
    logger_kwargs=logger_kwargs,
)
