In [36]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [37]:
# --- Hyperparameters (module-level globals read by the model and training loop below) ---
learning_rate = 2e-4  # Adam step size
gamma = 0.98          # discount factor applied to V(s') in the TD target
n_rollout = 10        # max environment steps collected between two updates

# --- Device configuration: prefer the GPU when one is visible, else the CPU ---
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [38]:
class ActorCritic(nn.Module):
    """One-step (TD(0)) actor-critic network with a shared hidden layer.

    Transitions are buffered with ``put_data`` and consumed by ``train_net``,
    which performs a single Adam update on the buffered rollout and then
    clears the buffer.

    Args:
        state_dim: size of the observation vector (defaults to 4, CartPole).
        action_dim: number of discrete actions (defaults to 2, CartPole).
        hidden_dim: width of the shared hidden layer.
        lr: Adam learning rate; ``None`` uses the module-level ``learning_rate``.
    """

    def __init__(self, state_dim=4, action_dim=2, hidden_dim=256, lr=None):
        super(ActorCritic, self).__init__()
        self.data = []  # buffered (s, a, r, s_prime, done) transitions

        self.fc1 = nn.Linear(state_dim, hidden_dim)     # shared trunk
        self.fc_pi = nn.Linear(hidden_dim, action_dim)  # policy head (logits)
        self.fc_v = nn.Linear(hidden_dim, 1)            # value head
        if lr is None:
            lr = learning_rate
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def pi(self, x, softmax_dim=0):
        """Return action probabilities for state(s) ``x``.

        Use ``softmax_dim=0`` for a single 1-D state and ``softmax_dim=1``
        for a batch of states stacked along dim 0.
        """
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc_pi(x), dim=softmax_dim)

    def v(self, x):
        """Return the state-value estimate V(x), with a trailing dim of size 1."""
        x = F.relu(self.fc1(x))
        return self.fc_v(x)

    def put_data(self, transition):
        """Buffer one ``(s, a, r, s_prime, done)`` transition for the next update.

        A truthy ``done`` masks out the bootstrap term V(s_prime) in
        ``train_net``, so it should be set only on true termination.
        """
        self.data.append(transition)

    def make_batch(self):
        """Convert the buffer into tensors on the model's device and clear it.

        Returns:
            ``(s, a, r, s_prime, done_mask)``. States are stacked along a new
            leading batch dim as float32. ``a`` is int64 of shape ``(N, 1)``,
            suitable for ``gather``. ``r`` is float ``(N, 1)`` scaled by 1/100.
            ``done_mask`` is float ``(N, 1)``: 0.0 for terminal transitions,
            1.0 otherwise.
        """
        # Follow the model's own device so the batch always matches its parameters.
        target = next(self.parameters()).device
        states, actions, rewards, next_states, masks = [], [], [], [], []
        for s, a, r, s_prime, done in self.data:
            # Converting each row separately avoids torch.tensor(list_of_ndarrays),
            # which PyTorch warns is extremely slow.
            states.append(torch.as_tensor(s, dtype=torch.float))
            next_states.append(torch.as_tensor(s_prime, dtype=torch.float))
            actions.append([a])
            rewards.append([r / 100.0])
            masks.append([0.0 if done else 1.0])

        s_batch = torch.stack(states).to(target)
        a_batch = torch.tensor(actions).to(target)
        r_batch = torch.tensor(rewards, dtype=torch.float).to(target)
        s_prime_batch = torch.stack(next_states).to(target)
        done_batch = torch.tensor(masks, dtype=torch.float).to(target)

        self.data = []
        return s_batch, a_batch, r_batch, s_prime_batch, done_batch

    def train_net(self, discount=None):
        """Run one actor-critic update on the buffered transitions.

        Does nothing if the buffer is empty.

        Args:
            discount: TD discount factor; ``None`` uses the module-level ``gamma``.
        """
        if not self.data:
            return
        if discount is None:
            discount = gamma
        s, a, r, s_prime, done = self.make_batch()

        # The TD target is a fixed regression label: no gradient flows through it.
        with torch.no_grad():
            td_target = r + discount * self.v(s_prime) * done

        v_s = self.v(s)  # evaluated once, reused by the advantage and the critic loss
        advantage = (td_target - v_s).detach()

        pi_a = self.pi(s, softmax_dim=1).gather(1, a)
        actor_loss = -torch.log(pi_a) * advantage
        critic_loss = F.smooth_l1_loss(v_s, td_target)
        loss = actor_loss + critic_loss

        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()
      

In [39]:
def main(n_episodes=10000, print_interval=20, env_id='CartPole-v1'):
    """Train an ActorCritic agent on a Gymnasium environment and log progress.

    Transitions are collected for up to ``n_rollout`` steps, or until the
    episode ends, and then used for one ``train_net`` update. The average
    undiscounted episode return over each window of ``print_interval``
    episodes is printed at the end of that window.

    Args:
        n_episodes: number of episodes to run.
        print_interval: number of episodes per logging window.
        env_id: Gymnasium environment id. The default network sizes in
            ``ActorCritic`` assume CartPole (4-dim observations, 2 actions).
    """
    env = gym.make(env_id)
    model = ActorCritic().to(device)
    score = 0.0

    try:
        # Count episodes from 1 so every logging window holds exactly
        # print_interval episodes (the first window used to hold one extra).
        for n_epi in range(1, n_episodes + 1):
            s, _ = env.reset()
            done = False
            while not done:
                for _step in range(n_rollout):
                    # Sampling only: no autograd graph is needed here.
                    with torch.no_grad():
                        prob = model.pi(torch.from_numpy(s).float().to(device))
                    a = Categorical(prob).sample().item()
                    s_prime, r, terminated, truncated, _info = env.step(a)

                    # Only true termination zeroes the bootstrap target; a
                    # time-limit truncation still bootstraps from s_prime.
                    model.put_data((s, a, r, s_prime, terminated))

                    s = s_prime
                    score += r
                    # The episode ends on either termination or truncation;
                    # ignoring truncation let CartPole-v1 run past its time limit.
                    done = terminated or truncated
                    if done:
                        break

                model.train_net()

            if n_epi % print_interval == 0:
                print("# of episode :{}, avg score : {:.1f}".format(n_epi, score / print_interval))
                score = 0.0
    finally:
        # Release the environment even if training is interrupted.
        env.close()

In [None]:
# Entry point. In a notebook kernel, and when run as a script, __name__ is
# "__main__", so this starts training; importing the code as a module does not.
if __name__ == "__main__":
    main()

# of episode :20, avg score : 25.9
# of episode :40, avg score : 16.8
# of episode :60, avg score : 20.9
# of episode :80, avg score : 24.7
# of episode :100, avg score : 18.6
# of episode :120, avg score : 22.9
# of episode :140, avg score : 17.6
# of episode :160, avg score : 17.0
# of episode :180, avg score : 17.6
# of episode :200, avg score : 16.5
# of episode :220, avg score : 16.6
# of episode :240, avg score : 19.2
# of episode :260, avg score : 20.8
# of episode :280, avg score : 23.3
# of episode :300, avg score : 16.6
# of episode :320, avg score : 26.7
# of episode :340, avg score : 23.8
# of episode :360, avg score : 25.6
# of episode :380, avg score : 26.0
# of episode :400, avg score : 29.8
# of episode :420, avg score : 38.2
# of episode :440, avg score : 36.5
# of episode :460, avg score : 35.9
# of episode :480, avg score : 39.6
# of episode :500, avg score : 36.2
# of episode :520, avg score : 39.9
# of episode :540, avg score : 46.0
# of episode :560, avg score : 6