# Imports and helper functions

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

def create_dir(path):
    """Create the directory ``path``, including any missing parents.

    Does nothing if the directory already exists. ``os.makedirs`` with
    ``exist_ok=True`` handles nested paths such as ``plots/SpaceInvaders``
    even when ``plots`` does not exist yet. It also avoids the race between
    a separate existence check and the creation call.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

# Loss and Reward

In [None]:
# Configuration for the plotting loop below.
network_names = ['DQN', 'DDQN', 'DRDQN', 'DRQN']
plots = ['training_loss', 'cumulative_reward']
env_name = 'SpaceInvaders'  # alternative environment: 'Qbert'
# Number of consecutive logged values averaged into one plotted epoch point.
updating_steps = 4


def _load_epoch_means(path, steps):
    """Return the mean of each consecutive group of ``steps`` values in ``path``.

    The file should contain one float per line; blank lines are ignored.
    A final group shorter than ``steps`` is averaged over the values it
    actually contains.
    """
    with open(path, 'r', encoding='utf-8') as file:
        # float() tolerates surrounding whitespace, including the newline.
        values = [float(line) for line in file if line.strip()]
    return [np.mean(values[i:i + steps]) for i in range(0, len(values), steps)]


def _moving_average(values, window):
    """Return the mean of every full ``window``-length slice of ``values``.

    Element ``i`` of the result equals ``mean(values[i:i + window])``. The
    result has ``len(values) - window + 1`` elements, or is empty when
    ``values`` is shorter than ``window``.
    """
    # Guard first: np.convolve rejects empty input, and in 'valid' mode it
    # silently swaps its operands when the kernel is longer than the data.
    if window < 1 or len(values) < window:
        return np.empty(0)
    return np.convolve(values, np.ones(window) / window, mode='valid')


plots_folder = 'plots'
output_dir = os.path.join(plots_folder, env_name)
# Create 'plots' before 'plots/<env_name>' so the figures are saved inside
# the plots folder. Creating the parent first also works when create_dir
# makes only one directory level.
create_dir(plots_folder)
create_dir(output_dir)

for plot_type in plots:
    for network_name in network_names:
        # One point per epoch: the mean of `updating_steps` logged values.
        training_values = _load_epoch_means(
            f'train_info/{plot_type}_{network_name}.txt', updating_steps
        )

        # Smooth over about 1% of the run. Use at least 1 so that runs with
        # fewer than 100 epochs do not average empty slices (which gives NaN).
        window = max(1, len(training_values) // 100)
        mean_dash = _moving_average(training_values, window)

        # Epochs are numbered from 1.
        epochs = range(1, len(training_values) + 1)
        # mean_dash[i] averages epochs i+1 .. i+window. Plot it at the last
        # of those epochs so the dashed curve shares the raw curve's x axis.
        mean_epochs = range(window, len(training_values) + 1)

        fig, ax = plt.subplots()
        ax.plot(epochs, training_values)
        ax.plot(mean_epochs, mean_dash, linestyle='--')

        label = 'Training Loss' if plot_type == 'training_loss' else 'Cumulative Reward'
        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)
        ax.set_title(f'{network_name} {label}')

        fig.savefig(os.path.join(output_dir, f'{network_name}_{plot_type}.png'))
        plt.show()
        # Release the figure so figures do not accumulate across iterations.
        plt.close(fig)