In [68]:
import time
from collections import deque, namedtuple

import gymnasium as gym
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import imageio
import random


In [69]:
class QNetwork(nn.Module):
    def __init__(self, state_size: int, action_size: int):
        super(QNetwork, self).__init__()
        
        self.hidden_size = 32
        
        self.layers = nn.Sequential(
            nn.Linear(state_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, action_size)
        )
        
        self.layers = self.layers.double() #  set network param to double
        
    def forward(self, state):
        q_values = self.layers(state.double())
        return q_values

In [147]:
class DoublePendulumAgent():
    def __init__(self, space_dim, action_dim):
        self.TAU = 1e-3  # Soft update parameter.
        self.ALPHA = 1e-3 #learning rate
        self.MINIBATCH_SIZE = 64
        self.GAMMA = 0.995
        self.epsilon = 1.0
        self.E_MIN = 0.01
        self.network = QNetwork(space_dim, action_dim)
        self.target_network = QNetwork(space_dim, action_dim)
        self.optimizer = torch.optim.AdamW(
            self.network.parameters(),
            lr=self.ALPHA)
        
    def get_action(self, q_values, env):
        if random.random() > self.epsilon:
            next_action = q_values.detach().numpy()[0]
            print(" ekki random next action : ", next_action)
            print("inni if shape: ", next_action.shape)
            return np.argmax(next_action)
        else:
            next_action = np.array([np.random.uniform(-3, 4)])
            print("next action : ",next_action)
            print("ekki if shape: ", next_action.shape)

            return next_action
                     
        # next_action = q_values.detach().numpy()[0]
        # return next_action
        # return np.argmax(q_values.detach().numpy())
        
    def should_update_network(self, num_episode, num_steps_upd, memory_buffer):
        return num_episode % num_steps_upd == 0 and len(memory_buffer) > self.MINIBATCH_SIZE
    
    def agent_learn(self, experiences):
        """
        Updates the weights of the Q networks.
        Args:
        experiences: (tuple) tuple of ["state", "action", "reward", "next_state", "done"] namedtuples
        gamma: (float) The discount factor.
        q_network: PyTorch model for the Q-network.
        target_q_network: PyTorch model for the target Q-network.
        optimizer: PyTorch optimizer.
        """
        # Zero the gradients before backpropagation
        self.optimizer.zero_grad()

        # Calculate the loss
        loss = self.compute_loss(experiences)

        # Backpropagate the loss
        loss.backward()

        # Update the weights of the q_network
        self.optimizer.step()

        # Update the weights of target q_network
        self.update_target_network()
        
    def get_experiences(self, memory_buffer):
        experiences = random.sample(memory_buffer, k=self.MINIBATCH_SIZE)
        states = torch.tensor(
            [e.state for e in experiences if e is not None], dtype=torch.double
        )
        actions = torch.tensor(
            [e.action for e in experiences if e is not None], dtype=torch.double
        )
        rewards = torch.tensor(
            [e.reward for e in experiences if e is not None], dtype=torch.double
        )
        next_states = torch.tensor(
            [e.next_state for e in experiences if e is not None], dtype=torch.double
        )
        done_vals = torch.tensor(
            [e.done for e in experiences if e is not None], dtype=torch.uint8
        ).float()  # Convert to float after creating the tensor TODO: check if this is correct

        return states, actions, rewards, next_states, done_vals

    
    def compute_loss(self, experiences):
        """
        Calculates the loss.

        Args:
        experiences: namedtuple with fields ["state", "action", "reward", "next_state", "done"]
        gamma: discount factor.
        q_network: PyTorch model for predicting the q_values.
        target_q_network: PyTorch model for predicting the targets.

        Returns:
        loss: Mean-Squared Error between the y targets and the Q(s,a) values.
        """
        # Unpack the mini-batch of experience tuples
        states, actions, rewards, next_states, done_vals = experiences

        # Compute max Q^(s,a) using target network
        with torch.no_grad():  # No gradient computation for target network
            max_qsa = self.target_network(next_states).max(1)[0]  # max function returns both values and indices

        # Set y = R if episode terminates, otherwise set y = R + γ max Q^(s,a)
        y_targets = rewards + (self.GAMMA * max_qsa * (1 - done_vals))

        # Get the q_values for the actions taken
        q_values = self.network(next_states).max(1)[0]

        # Compute the loss (Mean Squared Error)
        loss = torch.nn.functional.mse_loss(q_values, y_targets)
        
        # # Generate a batch of data
        # states, actions, rewards, next_states, done_vals = experiences
        # # Generate a batch of data
        # states, actions, rewards, next_states, done_vals = experiences

        # max_qsa = torch.max(self.target_network(next_states), dim=1)[0]
        # y_targets = rewards + (self.GAMMA * max_qsa * (1 - done_vals))
        # q_values = self.network(states)
        # loss = torch.nn.functional.binary_cross_entropy_with_logits(q_values, y_targets.unsqueeze(1))

        return loss
    
    def update_epsilon(self, num_episodes):
        
        self.epsilon = max(self.E_MIN, self.epsilon * ( self.epsilon / (num_episodes / 2)))
    
    def update_target_network(self):
        for target_param, q_net_param in zip(self.target_network.parameters(), self.network.parameters()):
            target_param.data.copy_(self.TAU * q_net_param.data + (1.0 - self.TAU) * target_param.data)
        

In [154]:
start = time.time()
MAX_NUM_STEPS = 1000 #terminates after 1000 steps
MAX_NUM_EPISODES = 40000 
NUM_P_AVG = 100
UPDATE_NETWORK_STEPS = 7

env = gym.make('InvertedDoublePendulum-v4', render_mode="rgb_array")
STATE_SIZE = env.observation_space.shape[0]
NUM_ACTIONS = env.action_space.shape[0]

experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
memory_buffer = deque(maxlen=10000)
total_points_hist = []
agent = DoublePendulumAgent(STATE_SIZE, NUM_ACTIONS)

for i in range(MAX_NUM_EPISODES):
    state, _ = env.reset()
    total_points = 0
    # for t in range(MAX_NUM_STEPS):
    done = False
    t = 0
    
    while not done:
        t+=1
        state_qn = torch.from_numpy(np.expand_dims(state, axis=0))
        q_values = agent.network(state_qn)
        action = agent.get_action(q_values, env)
        next_state, reward, next_done, info, _ = env.step(action)
        memory_buffer.append(experience(state, action, reward, next_state, done))
        
        if(agent.should_update_network(t, UPDATE_NETWORK_STEPS, memory_buffer)):
            minibatch = agent.get_experiences(memory_buffer)
            agent.agent_learn(minibatch)
            
        state = next_state.copy()
        total_points += reward
        
    total_points_hist.append(total_points)
    agent.update_epsilon(t)
    
    avg_latest_points = np.mean(total_points_hist[-NUM_P_AVG:])
    print(f"\rEpisode {i+1} | Total point average of the last {NUM_P_AVG} episodes: {avg_latest_points:.2f}: epsilon = {agent.epsilon}", end="")

    if (i+1) % NUM_P_AVG == 0:
            print(f"\rEpisode {i+1} | Total point average of the last {NUM_P_AVG} episodes: {avg_latest_points:.2f}")
            
    if avg_latest_points >= 300:
        print(f"\n\nEnvironment solved in {i+1} episodes!")
        torch.save(agent.network.state_dict(), 'Cart_pole_model.pth')
        break
    
tot_time = time.time() - start

print(f"\nTotal Runtime: {tot_time:.2f} s ({(tot_time/60):.2f} min)")


next action :  [-2.17301732]
ekki if shape:  (1,)
next action :  [2.42176495]
ekki if shape:  (1,)
next action :  [-0.42929438]
ekki if shape:  (1,)
next action :  [-0.37099673]
ekki if shape:  (1,)
next action :  [1.11454577]
ekki if shape:  (1,)
next action :  [-1.16885232]
ekki if shape:  (1,)
Episode 1 | Total point average of the last 100 episodes: 54.49: epsilon = 0.3333333333333333 ekki random next action :  [-0.1735309]
inni if shape:  (1,)


ValueError: Action dimension mismatch. Expected (1,), found ()

In [155]:
def create_video(filename, env, q_network, fps=30):
    start = time.time()
    with imageio.get_writer(filename, fps=fps) as video:
        done = False
        state, _ = env.reset()
        frame = env.render()
        video.append_data(frame)
        while not done:
            if time.time() - start > 10:
                break
            state = torch.from_numpy(np.expand_dims(state, axis=0))
            q_values = q_network(state)
            action = np.argmax(q_values.detach().numpy()[0])
            state, _, done, _, _ = env.step(action)
            frame = env.render()
            video.append_data(frame)

    env.close()

In [129]:
# create_video("double_pendulum.mp4", env, agent.network)
env.render()

array([[[78, 88, 78],
        [78, 88, 78],
        [78, 88, 78],
        ...,
        [78, 88, 78],
        [78, 88, 78],
        [78, 88, 78]],

       [[78, 88, 78],
        [78, 88, 78],
        [78, 88, 78],
        ...,
        [78, 88, 78],
        [78, 88, 78],
        [78, 88, 78]],

       [[78, 88, 78],
        [78, 88, 78],
        [78, 88, 78],
        ...,
        [78, 88, 78],
        [78, 88, 78],
        [78, 88, 78]],

       ...,

       [[78, 88, 78],
        [78, 88, 78],
        [78, 88, 78],
        ...,
        [78, 88, 78],
        [78, 88, 78],
        [78, 88, 78]],

       [[78, 88, 78],
        [78, 88, 78],
        [78, 88, 78],
        ...,
        [78, 88, 78],
        [78, 88, 78],
        [78, 88, 78]],

       [[78, 88, 78],
        [78, 88, 78],
        [78, 88, 78],
        ...,
        [78, 88, 78],
        [78, 88, 78],
        [78, 88, 78]]], dtype=uint8)