# In [5]:  (Jupyter notebook cell prompt kept as a comment so the file parses as Python)
import imageio
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from collections import deque
import random
import torch
from torch import nn
import torch.nn.functional as F
import torch.multiprocessing as mp

class DQN(nn.Module):
    """Fully connected Q-network: two ReLU hidden layers and a linear output.

    The output has one value per action; no activation is applied to it.
    """

    def __init__(self, no_of_input_nodes, no_of_first_layer_nodes, no_of_second_layer_nodes, no_of_output_nodes):
        """Create the three linear layers.

        Args:
            no_of_input_nodes: Width of the input vector.
            no_of_first_layer_nodes: Width of the first hidden layer.
            no_of_second_layer_nodes: Width of the second hidden layer.
            no_of_output_nodes: Number of output values.
        """
        super().__init__()

        # Layer definitions # TODO - can be changed
        # Attribute names are part of the state_dict keys used by saved
        # checkpoints, so they must stay as they are.
        widths = (
            no_of_input_nodes,
            no_of_first_layer_nodes,
            no_of_second_layer_nodes,
            no_of_output_nodes,
        )
        self.first_layer = nn.Linear(widths[0], widths[1])
        self.second_layer = nn.Linear(widths[1], widths[2])
        self.output_layer = nn.Linear(widths[2], widths[3])

    def forward(self, training_data): # TODO - can be changed
        """Return the network output for ``training_data``."""
        activations = training_data
        for hidden_layer in (self.first_layer, self.second_layer):
            activations = F.relu(hidden_layer(activations))
        return self.output_layer(activations)

# Replay memory - kolejka stanów
class ReplayMemory():
    """Bounded FIFO buffer of stored items with uniform random sampling.

    Once ``maxlen`` items are held, appending a new one discards the oldest.
    """

    def __init__(self, maxlen):
        """Create an empty buffer holding at most ``maxlen`` items."""
        self.memory = deque(maxlen=maxlen)

    def append(self, state):
        """Add ``state`` as the newest item of the buffer."""
        buffer = self.memory
        buffer.append(state)

    def sample(self, sample_size):
        """Return ``sample_size`` distinct stored items chosen at random.

        Raises:
            ValueError: If ``sample_size`` exceeds the number of stored items
                (raised by ``random.sample``).
        """
        drawn = random.sample(self.memory, sample_size)
        return drawn

    def __len__(self):
        """Return the number of items currently stored."""
        stored = self.memory
        return len(stored)

class FrozenLakeDQL():
    """Deep Q-Learning agent for the Gymnasium environment named by ``env_name``.

    The class can train a policy network (``train``), replay a saved policy
    with on-screen rendering (``test``) and export the best of several
    greedy episodes as an animated GIF (``generate_gif``).
    """

    # Hyperparameters # TODO - can be changed
    learning_rate_a = 0.01
    discount_factor_g = 0.99
    network_sync_rate = 50000          # number of steps the agent takes before syncing the policy and target network
    replay_memory_size = 100_000       # size of replay memory
    mini_batch_size = 32            # size of the training data set sampled from the replay memory
    loss_fn = nn.MSELoss()          # NN Loss function
    optimizer = None                # NN Optimizer. Initialize later.

    no_of_states = 8                # width of the observation vector fed to the network
    no_of_actions = 4               # number of network outputs (one Q-value per action)
    hidden_layer_nodes = 64         # width of both hidden layers of the DQN
    success_reward_threshold = 100  # episode reward that unlocks training (see train)
    env_name = "LunarLander-v2"
    epsilon_initial_value = 1

    def __init__(self):
        """Let PyTorch use as many intra-op threads as there are CPUs."""
        torch.set_num_threads(mp.cpu_count())

    def state_to_dqn_input(self, state):
        """Convert an observation (or a batch of observations) to a float32 tensor."""
        return torch.tensor(state, dtype=torch.float32)

    def _build_network(self):
        """Create a freshly initialised DQN sized from the class-level settings."""
        return DQN(
            no_of_input_nodes=self.no_of_states,
            no_of_first_layer_nodes=self.hidden_layer_nodes,
            no_of_second_layer_nodes=self.hidden_layer_nodes,
            no_of_output_nodes=self.no_of_actions,
        )

    def _load_policy_network(self):
        """Build a DQN and load the weights saved in ``<env_name>.pt``."""
        policy_dqn = self._build_network()
        # NOTE: torch.load is pickle-based; only load checkpoints you trust.
        policy_dqn.load_state_dict(torch.load(self.env_name + ".pt"))
        return policy_dqn

    def _greedy_action(self, policy_dqn, state):
        """Return the index of the highest Q-value for ``state``."""
        # Action selection never needs gradients, so skip graph construction.
        with torch.no_grad():
            return policy_dqn(self.state_to_dqn_input(state)).argmax().item()

    def _run_greedy_episode(self, env, policy_dqn, record_frames=False):
        """Play one episode greedily; return ``(total_reward, frames)``.

        ``frames`` holds one ``env.render()`` result per step when
        ``record_frames`` is true and is empty otherwise.
        """
        state, _ = env.reset()
        done = False
        truncated = False
        total_reward = 0
        frames = []

        while not done and not truncated:
            action = self._greedy_action(policy_dqn, state)
            state, reward, done, truncated, _ = env.step(action)
            total_reward += reward
            if record_frames:
                frames.append(env.render())

        return total_reward, frames

    def train(self, episodes, continue_training): # TODO - can be changed
        """Train the policy network with epsilon-greedy exploration.

        Episodes are played (and their transitions stored) from the start,
        but they only start being counted towards ``episodes`` - and the
        network only starts being optimised - once a single episode has
        scored more than ``success_reward_threshold``. From then on every
        episode is counted and followed by one optimisation step on a random
        mini-batch.
        NOTE(review): with a purely random policy that first success may take
        very long; confirm this gating is intended for this environment.

        Args:
            episodes: Number of counted episodes to run.
            continue_training: When true, start from the weights in
                ``<env_name>.pt`` with epsilon 0.3 instead of
                ``epsilon_initial_value``.

        Side effects:
            Writes ``<env_name>.pt`` whenever a counted episode beats the best
            reward so far, ``<env_name>.png`` every 1000 counted episodes and
            ``<env_name>_final.pt`` when training ends.
        """
        env = gym.make(self.env_name, render_mode='rgb_array')
        try:
            epsilon = self.epsilon_initial_value
            memory = ReplayMemory(self.replay_memory_size)

            # 1. Create policy and target networks
            policy_dqn = self._build_network()
            target_dqn = self._build_network()

            if continue_training:
                # NOTE: torch.load is pickle-based; only load checkpoints you trust.
                policy_dqn.load_state_dict(torch.load(self.env_name + ".pt"))
                epsilon = 0.3
                print("Loaded model from a file")

            # Epsilon falls linearly from its start value to 0 over `episodes`.
            epsilon_span = 0.3 if continue_training else 1

            # 2. Both networks must start out identical
            target_dqn.load_state_dict(policy_dqn.state_dict())

            self.optimizer = torch.optim.Adam(policy_dqn.parameters(), lr=self.learning_rate_a) # Policy network optimizer
            rewards_per_episode = []  # Total reward of every counted episode
            step_count = 0  # Steps since the last policy => target sync
            best_rewards = float('-inf')
            any_positive_points = False
            counted_episodes = 0
            played_episodes = 0

            while counted_episodes < episodes:
                state, _ = env.reset()
                done = False
                timeout = False
                rewards = 0

                while not done and not timeout:
                    # With probability epsilon take a random action instead of the best one
                    if random.random() < epsilon:
                        action = env.action_space.sample()
                    else:
                        action = self._greedy_action(policy_dqn, state)

                    # 3. Only `done` is stored, so a time-limit cut-off still
                    # bootstraps from the next state during optimisation.
                    new_state, reward, done, timeout, _ = env.step(action)
                    memory.append((state, action, new_state, reward, done))
                    state = new_state
                    step_count += 1
                    rewards += reward

                if rewards > self.success_reward_threshold:
                    any_positive_points = True

                played_episodes += 1
                if played_episodes % 100 == 0:
                    print(played_episodes, counted_episodes)
                if not any_positive_points:
                    continue

                rewards_per_episode.append(rewards)
                counted_episodes += 1

                # Graph training progress
                if counted_episodes % 1000 == 0:
                    print(f'Episode {counted_episodes} Epsilon {epsilon}')
                    plt.figure(1)
                    plt.clf()  # otherwise every redraw stacks on the previous curve
                    plt.plot(rewards_per_episode)
                    plt.savefig(self.env_name + '.png')

                if rewards > best_rewards:
                    best_rewards = rewards
                    print(f'Best rewards so far: {best_rewards}')
                    torch.save(policy_dqn.state_dict(), self.env_name + ".pt")

                # Optimise once the memory holds more than one mini-batch
                if len(memory) > self.mini_batch_size:
                    memory_sample = memory.sample(self.mini_batch_size)
                    self.optimize(memory_sample, policy_dqn, target_dqn) # 8. Optimise the network

                    # Update epsilon
                    epsilon = max(epsilon - epsilon_span / episodes, 0)

                    # 10. target network = policy network
                    if step_count > self.network_sync_rate:
                        target_dqn.load_state_dict(policy_dqn.state_dict())
                        step_count = 0

            torch.save(policy_dqn.state_dict(), self.env_name + "_final.pt")
        finally:
            # Release the environment even if training is interrupted.
            env.close()

    def test(self, episodes):
        """Play ``episodes`` greedy episodes with on-screen rendering.

        Loads the weights from ``<env_name>.pt`` and prints the total reward
        of every episode.
        """
        env = gym.make(self.env_name, render_mode='human')
        try:
            policy_dqn = self._load_policy_network()

            for i in range(episodes):
                total_reward, _ = self._run_greedy_episode(env, policy_dqn)
                print(f"Episode {i + 1}: Total Reward = {total_reward}")
        finally:
            env.close()

    def generate_gif(self, episodes):
        """Save the best of ``episodes`` greedy episodes as an animated GIF.

        Loads the weights from ``<env_name>.pt``, records the rendered frame
        after every step and writes the frames of the highest-reward episode
        to ``<env_name>_best_episode.gif``. Nothing is written when no frame
        was recorded.
        """
        env = gym.make(self.env_name, render_mode='rgb_array')
        try:
            policy_dqn = self._load_policy_network()

            best_total_reward = float('-inf')
            best_frames = []

            for _ in range(episodes):
                total_reward, frames = self._run_greedy_episode(env, policy_dqn, record_frames=True)
                if total_reward > best_total_reward:
                    best_total_reward = total_reward
                    best_frames = frames

            if best_frames:
                with imageio.get_writer(self.env_name + '_best_episode.gif', mode='I', fps=30) as writer:
                    for frame in best_frames:
                        writer.append_data(frame)
        finally:
            env.close()

    def optimize(self, memory_sample, policy_dqn, target_dqn):
        """Run one gradient step of the policy network on a mini-batch.

        For every transition the regression target is the target network's
        Q-vector for ``state`` with the entry of the taken action replaced by
        ``reward`` (when ``done``) or by
        ``reward + discount_factor_g * max(target_dqn(new_state))``.
        The loss therefore covers all actions, not only the taken one.

        Args:
            memory_sample: Sequence of ``(state, action, new_state, reward,
                done)`` tuples. An empty sequence is a no-op.
            policy_dqn: Network being trained; ``self.optimizer`` must hold
                its parameters.
            target_dqn: Network used only to build the targets; it receives
                no gradients.
        """
        if not memory_sample:
            return

        states, actions, new_states, rewards, dones = zip(*memory_sample)
        state_batch = self.state_to_dqn_input(np.asarray(states, dtype=np.float32))
        new_state_batch = self.state_to_dqn_input(np.asarray(new_states, dtype=np.float32))
        action_batch = torch.as_tensor(np.asarray(actions, dtype=np.int64))
        reward_batch = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
        done_batch = torch.as_tensor(np.asarray(dones, dtype=bool))

        # Targets are constants for the regression: building them without
        # autograd keeps the target network free of accumulated gradients.
        with torch.no_grad():
            # 6. q[state, action] = reward for terminal transitions, else the
            # discounted bootstrap from the target network.
            next_q_max = target_dqn(new_state_batch).max(dim=1).values
            bootstrapped = reward_batch + self.discount_factor_g * next_q_max
            q_in_target = torch.where(done_batch, reward_batch, bootstrapped)

            # 7. Overwrite only the taken action in the target Q-vectors.
            target_q = target_dqn(state_batch)
            rows = torch.arange(len(memory_sample))
            target_q[rows, action_batch] = q_in_target

        # 4. Current Q-values of the policy network (gradients flow here).
        current_q = policy_dqn(state_batch)

        # 8. Use the target Q-values to train the policy Q-values
        loss = self.loss_fn(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
    

if __name__ == '__main__':
    # Script entry point: replay the saved policy on screen, then export the
    # best of ten greedy episodes as a GIF. Both steps read "<env_name>.pt".
    agent = FrozenLakeDQL()
    # Uncomment to (re)train before evaluating:
    # agent.train(20_000, continue_training=False)
    agent.test(10)
    agent.generate_gif(10)

# Notebook output residue (traceback excerpt), kept as a comment:
#   policy_dqn.load_state_dict(torch.load(self.env_name + ".pt"))
