In [104]:
import torch # finally get to use pytorch again
import torch.nn as nn
import torch.nn.functional as F
import random
from collections import deque

import gymnasium as gym
import numpy as np

In [105]:
env = gym.make("CartPole-v1", render_mode="human")
observation, info = env.reset()
action = env.action_space.sample()

def generate_episode(n_ep, env):
    for _ in range(n_ep):
        observation, info = env.reset()

        while True:
            action = env.action_space.sample()
            observation, reward, terminated, truncated, info = env.step(action)
            
            if terminated or truncated:
                break

In [106]:
class Q_network(nn.Module):
    def __init__(self, in_dim, out_dim) -> None:
        super().__init__()
        # note to self: nn.Linear() represents the transformation, not the matrices themselves.
        self.il = nn.Linear(in_dim, 50)
        self.ol = nn.Linear(50, out_dim)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = torch.tensor(x)
        x = F.relu(self.il(x))
        return F.relu(self.ol(x))
 
class Replay_Memory():
    def __init__(self, cap):
        self.memory = deque([], maxlen = cap)

    def push(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
        
    def sample(self, batch_size):
        mem_sample = random.sample(self.memory, batch_size)

        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*mem_sample)

        state_batch = torch.from_numpy(np.array(state_batch)).float()
        action_batch = torch.tensor(action_batch).long()
        reward_batch = torch.tensor(reward_batch).float()
        next_state_batch = torch.from_numpy(np.array(next_state_batch)).float()
        done_batch = torch.tensor(done_batch).float()

        return [state_batch, action_batch, reward_batch, next_state_batch, done_batch]

    def __len__(self):
        return len(self.memory)

def epsilon_greedy(Q_net, S, epsilon):
    r = torch.rand(1)
    if r > epsilon:
        return int(torch.argmax(Q_net.forward(S)).item())
    else:
        return random.randint(0, 1)

def DQL(_env, n_ep):
    state_dim = _env.observation_space.shape[0]
    action_dim = _env.action_space.n
    
    buffer_cap = 100000
    replay_buffer = Replay_Memory(buffer_cap)
    
    behaviour_net = Q_network(state_dim, action_dim)
    target_net = Q_network(state_dim, action_dim)
    
    optimizer = torch.optim.Adam(behaviour_net.parameters(), lr=0.01)
    loss_func = nn.MSELoss();
    
    decay_min = 0.01
    decay_const = 0.01
    
    successes = 0
    _gamma = 0.99
    # _alpha = 0.5

    for i in range(n_ep):
        epsilon = decay_min + (1 - decay_min) * np.exp(-decay_const * i)
        observation, info = _env.reset()
        S = observation
        done = False

        while not done:
            A = epsilon_greedy(behaviour_net, S, epsilon)
            S_prime, R, terminated, truncated, info = _env.step(A)
            done = terminated or truncated
            replay_buffer.push(S, A, R, S_prime, done)

            S = S_prime
            
            if len(replay_buffer) > 64:
                # sample from replay buffer
                state_batch, action_batch, reward_batch, next_state_batch, done_batch = replay_buffer.sample(64)
                # feed these values to Q-network, get the Q-values
                q_values = behaviour_net(state_batch)
                # do magic shit to get Q-value for the action taken
                specific_qs = q_values.gather(1, action_batch.unsqueeze(1)).squeeze(1)
                # get max q-value from target net
                with torch.no_grad():
                    specific_target_qs = target_net(next_state_batch).max(1)[0]
                # calculate target q-value
                target_q_values = reward_batch + _gamma * specific_target_qs * (1 - done_batch)                
                #loss
                loss = loss_func(specific_qs, specific_target_qs)
                #backprop
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                    
            if (i % 10) == 0:
                target_net.load_state_dict(behaviour_net.state_dict())
            
                
    print("Number of successes:", successes)
    return behaviour_net

In [107]:
Q_net = DQL(env, 20)
env.close()

  x = torch.tensor(x)


Number of successes: 18
