In [5]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [6]:
lr = 0.0002
gamma = 0.99
n_rollout =10

In [82]:
def main():
    env = gym.make('CartPole-v1')
    model = ActorCritic()
    print_interval = 20
    score = 0
    
    for n_epi in range(1000):
        done = False
        s = env.reset()
        while not done:
            for t in range(n_rollout):
                prob = model.pi(torch.from_numpy(s).float())
                
                m = Categorical(prob)
                #print(m)
                a = m.sample().item()
                #print(a)
                s_prime,r,done,info = env.step(a)
                model.put_data((s,a,r,s_prime,done))
                
                s = s_prime
                score +=r
                
                if done:
                    break
            model.train_net()
        if n_epi%print_interval==0 and n_epi!=0:
            print("# of episode: {}, avg score : {:.1f}".format(n_epi,score/print_interval)) 
            score =0
    env.close()

In [126]:
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic,self).__init__()
        self.data = []
        
        self.fc1 = nn.Linear(4,256)
        self.fc_pi = nn.Linear(256,2)
        self.fc_v = nn.Linear(256,1)
        self.optimizer = optim.Adam(self.parameters(),lr=lr)
        
    def pi(self,x, softmax_dim=0):
        x= F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x,dim=softmax_dim)
        return prob
    
    def v(self,x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v
    
    def put_data(self,transition):
        self.data.append(transition)
        
    def make_batch(self):
        s_lst,a_lst,r_lst,s_prime_lst,done_lst = [],[],[],[],[]
        
        for transition in self.data:
            s,a,r,s_prime,done = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r/100])
            s_prime_lst.append(s_prime)
            done_mask = 0 if done else 1
            done_lst.append([done_mask])
        s_batch,a_batch,r_batch,s_prime_batch,done_batch = torch.tensor(s_lst,dtype=torch.float), \
        torch.tensor(a_lst),torch.tensor(r_lst,dtype=torch.float),torch.tensor(s_prime,dtype=torch.float),\
        torch.tensor(done_lst,dtype=torch.float)
        
        self.data = []
        return s_batch,a_batch,r_batch,s_prime_batch,done_batch
    
    def train_net(self):
        s,a,r,s_prime,done = self.make_batch()
        td_target = r+gamma*self.v(s_prime)*done
        delta = td_target - self.v(s)
        
        pi = self.pi(s,softmax_dim=1)
        pi_a = pi.gather(1,a)
        loss =-torch.log(pi_a)*delta.detach() + F.smooth_l1_loss(self.v(s),td_target.detach())
        
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()

In [127]:
main()

# of episode: 20, avg score : 23.6
# of episode: 40, avg score : 18.4
# of episode: 60, avg score : 23.2
# of episode: 80, avg score : 22.0
# of episode: 100, avg score : 22.6
# of episode: 120, avg score : 19.1
# of episode: 140, avg score : 33.3
# of episode: 160, avg score : 26.4
# of episode: 180, avg score : 26.6
# of episode: 200, avg score : 24.6
# of episode: 220, avg score : 26.9
# of episode: 240, avg score : 24.7
# of episode: 260, avg score : 29.9
# of episode: 280, avg score : 31.7
# of episode: 300, avg score : 32.1
# of episode: 320, avg score : 33.4
# of episode: 340, avg score : 34.5
# of episode: 360, avg score : 38.0
# of episode: 380, avg score : 41.5
# of episode: 400, avg score : 46.0
# of episode: 420, avg score : 46.2
# of episode: 440, avg score : 47.5
# of episode: 460, avg score : 46.5
# of episode: 480, avg score : 44.4
# of episode: 500, avg score : 56.5
# of episode: 520, avg score : 52.0
# of episode: 540, avg score : 54.1
# of episode: 560, avg score : 5

In [109]:
env = gym.make('FrozenLake-v1')

In [110]:
s = env.reset()

In [106]:
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)