In [1]:
import gym
import random
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from collections import namedtuple, deque
import time
from ale_py import ALEInterface
import imageio
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

ale = ALEInterface()

# Inicialización del entorno
#env = gym.make("Assault-v0", render_mode="rgb_array")
env = gym.make("Assault-v4") #PROVAR AQUESTA VERSIO DEL MODEL
n_actions = env.action_space.n

# Parámetros modificados para una mayor exploración inicial y un decaimiento más lento
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 10000  # Aumenta para un decaimiento más lento
EPISODES = 5 # Más episodios para permitir un aprendizaje más prolongado
TARGET_UPDATE = 5
BATCH_SIZE = 128
GAMMA = 0.999
MAX_STEPS_PER_EPISODE = 500

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)
        self.transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'done'))

    def push(self, *args):
        """Save a transition"""
        self.memory.append(self.transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

class DQN(keras.Model):

    def __init__(self, n_actions):
        super(DQN, self).__init__()

        self.layer1 = layers.Conv2D(16, 5, strides=2, activation="relu")
        self.bn1 = layers.BatchNormalization()
        self.layer2 = layers.Conv2D(16, 5, strides=2, activation="relu")
        self.bn2 = layers.BatchNormalization()
        self.layer3 = layers.Conv2D(32, 5, strides=2, activation="relu")
        self.bn3 = layers.BatchNormalization()
        self.flatten = layers.Flatten()
        self.layer4 = layers.Dense(512, activation="relu")
        self.action = layers.Dense(n_actions, activation="linear")

    def call(self, inputs):
        x = self.layer1(inputs)
        x = self.bn1(x)
        x = self.layer2(x)
        x = self.bn2(x)
        x = self.layer3(x)
        x = self.bn3(x)
        x = self.flatten(x)
        x = self.layer4(x)
        return self.action(x)

# Creación del modelo y la memoria
model = DQN(n_actions)
model_target = DQN(n_actions)
memory = ReplayMemory(10000)

# Preparación del optimizador y la función de pérdida
optimizer = keras.optimizers.Adam(learning_rate=2.5e-4, clipnorm=1.0)
loss_function = keras.losses.Huber()

def take_action(state, epsilon):
    if random.random() < epsilon:
        return env.action_space.sample()
    else:
        q_values = model.predict(state[np.newaxis, ...])
        return np.argmax(q_values[0])

def optimize_modelDDQN():
    if memory.__len__() < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = memory.transition(*zip(*transitions))

    state_batch = np.array(batch.state)
    action_batch = np.array(batch.action)
    next_state_batch = np.array(batch.next_state)
    rewad_batch = np.array(batch.reward)
    done_batch = np.array(batch.done, dtype=np.int8)

    # Calculate Q-values for the next state using the online model
    q_values_next_state_online = model(next_state_batch)

    # Use the online model to select actions for the next state
    next_actions_online = np.argmax(q_values_next_state_online, axis=-1)

    # Use the target model to calculate Q-values for the next state and selected actions
    q_values_next_state_target = model_target(next_state_batch)
    q_values_next_state_target_selected = tf.reduce_sum(
        tf.one_hot(next_actions_online, n_actions) * q_values_next_state_target,
        axis=-1
    )

    target = rewad_batch + GAMMA * q_values_next_state_target_selected * (1 - done_batch)

    action_mask = tf.one_hot(action_batch, n_actions)

    with tf.GradientTape() as tape:
        q_values = model(state_batch)
        q_action = tf.reduce_sum(tf.multiply(q_values, action_mask), axis=-1)
        loss = loss_function(target, q_action)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))


    #model.save_weights('/Users/roy/Desktop/UNI')

def display_frames(frames):
    plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = plt.animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    display(anim)

def save_gif(frames, episode_number):
    gif_path = f"episode_{episode_number}.gif"
    imageio.mimsave(gif_path, frames, format='GIF', fps=30)
    return gif_path



episode_rewards = []
losses = []

best_reward = float("-inf")
best_episode = 0
best_frames = []

# Entrenamiento del agente
epsilon = EPSILON_START
for episode in range(EPISODES):
    #reset_result = env.reset()
    #state = reset_result[0] / 255.0
    #info = reset_result[1]

    state = env.reset() / 255.0
    done = False
    episode_reward = 0
    steps = 0
    info ={'ale.lives': 4, 'episode_frame_number': 2, 'frame_number': 2}

    frames = []
    current_frames = []  # Almacena los frames del episodio actual
    episode_reward = 0

    while not done and steps < MAX_STEPS_PER_EPISODE and info.get("ale.lives") >= 0:
        frame = env.render(mode='rgb_array')
        frames.append(frame)
        #ESTO ES NUEVO- REVISAR
        current_frames.append(frame)
        # Comprobación y actualización de la mejor recompensa
        
        if episode_reward > best_reward:
            best_reward = episode_reward
            best_episode = episode
            best_frames = current_frames  # Actualiza los frames del mejor episodio

        action = take_action(state, epsilon)
        step_result = env.step(action)
        #print(step_result)
        #next_state, reward, done, _ , info = step_result
        next_state, reward, done, info = step_result

        next_state = next_state / 255.0
        #print("info: ", info)

        memory.push(state, action, next_state, reward, done)
        optimize_modelDDQN()

        state = next_state
        episode_reward += reward

        if reward != 0:
            print("step: ", steps, "action: ", action, " reward: ", reward)
            print("Lives: ", info.get("ale.lives"))

        steps += 1
        
        epsilon = max(epsilon - (EPSILON_START - EPSILON_END) / EPSILON_DECAY, EPSILON_END)
        
        
    print(f"\nEpisodio: {episode+1}, Recompensa: {episode_reward}, Epsilon: {epsilon}")
    # Display only the frames from the best episode
    # Save GIF
    gif_path = save_gif(frames, episode+1)

    # Display GIF
    clear_output(wait=True)
    display({'image/png': open(gif_path, 'rb').read()})


    #gif_path = f"episode_{episode+1}.gif"
    #gif_path = f"/workspaces/RL_Project/Assault_gifs/episode_{episode+1}.gif"
    #gif_path = f"/workspaces/RL_Project/Assault_gifs/best_episode_{best_episode+1}.gif"
    #imageio.mimsave(gif_path, frames, format='GIF', fps=30)
    
    #wandb.log({"episode": episode + 1, "reward": episode_reward, "epsilon": epsilon})
    #episode_rewards.append(episode_reward)
    
    if (episode + 1) % TARGET_UPDATE == 0:
        model_target.set_weights(model.get_weights())


env.close()


{'image/png': b'GIF89a\xa0\x00\xd2\x00\x87\x00\x00\xd6\xd6\xd6\xbb\xbb5\xaa\xaa\xaa\xa2\xa2*H\xa0HB\x9e\x82\xc3\x90=n\x9cB\xb4z0BH\xc8\x18;\x9d3\x1a\xa3\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00



KeyboardInterrupt: 