<a href="https://colab.research.google.com/github/OneFineStarstuff/TheOneEverAfter/blob/main/DQN_implementation_with_the_%60CuriosityModule%60_and_the_%60replay%60_method.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Define the Q-Network
class QNetwork(nn.Module):
    """Fully connected network that maps a state vector to one Q-value per action.

    Architecture: state_size -> 64 -> 64 -> action_size, with ReLU after each
    of the two hidden layers and no activation on the output layer.
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        hidden_units = 64
        self.fc1 = nn.Linear(state_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_size)

    def forward(self, x):
        """Return the raw (unactivated) Q-value estimates for input ``x``."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)

# Define Curiosity Module
class CuriosityModule(nn.Module):
    """Pair of small MLPs used to derive a curiosity signal.

    ``forward_model`` predicts the next state from the current state
    concatenated with an action vector; ``inverse_model`` predicts action
    logits from the current state concatenated with the next state.
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        hidden_units = 128
        self.forward_model = nn.Sequential(
            nn.Linear(state_size + action_size, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, state_size),
        )
        self.inverse_model = nn.Sequential(
            nn.Linear(2 * state_size, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_size),
        )

    @staticmethod
    def _as_batch(tensor):
        """Give a 1-D tensor a leading batch axis; return other tensors untouched."""
        return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor

    def forward(self, state, next_state, action):
        """Return ``(predicted_next_state, predicted_action)`` for the inputs.

        Any 1-D argument is first promoted to a batch of one so that the
        concatenations along the feature axis (dim 1) are well defined.
        """
        state, next_state, action = (
            self._as_batch(item) for item in (state, next_state, action)
        )

        forward_input = torch.cat((state, action), dim=1)
        inverse_input = torch.cat((state, next_state), dim=1)

        predicted_next_state = self.forward_model(forward_input)
        predicted_action = self.inverse_model(inverse_input)
        return predicted_next_state, predicted_action

    def intrinsic_reward(self, next_state, predicted_next_state):
        """Return the per-sample forward-model error, averaged over dim 1."""
        squared_error = F.mse_loss(next_state, predicted_next_state, reduction='none')
        return squared_error.mean(dim=1)

# Define the DQN Agent with Curiosity
class DQNAgent:
    """Epsilon-greedy DQN agent whose rewards are augmented by a curiosity bonus.

    The agent owns an online Q-network, a target Q-network, a curiosity
    module, a bounded replay memory and one Adam optimizer for each of the
    two trainable components.
    """

    def __init__(self, state_size, action_size):
        """Build the networks, optimizers and replay memory.

        Args:
            state_size: Number of features in a single observation.
            action_size: Number of discrete actions.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=10000)
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.batch_size = 64

        self.model = QNetwork(state_size, action_size)
        self.target_model = QNetwork(state_size, action_size)
        self.curiosity = CuriosityModule(state_size, action_size)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.curiosity_optimizer = optim.Adam(self.curiosity.parameters(), lr=0.001)

        # Start with the target network as an exact copy of the online network.
        self.update_target_model()

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Choose an action for a single ``state`` with an epsilon-greedy policy.

        Args:
            state: One observation, either flat or with a leading batch axis
                of size one.

        Returns:
            The chosen action index as a plain Python ``int``.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)

        # Flatten to exactly [1, features] whatever leading axes the caller added.
        state = torch.as_tensor(np.asarray(state), dtype=torch.float32).reshape(1, -1)

        with torch.no_grad():
            q_values = self.model(state)

        return int(q_values.argmax(dim=1).item())

    def _sample_batch(self):
        """Draw a random minibatch and return it as ``[batch_size, ...]`` tensors."""
        minibatch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)

        def as_float(values):
            # Stack in NumPy first: handing torch a tuple of ndarrays makes it
            # convert element by element, which is very slow.
            stacked = np.asarray(values, dtype=np.float32).reshape(self.batch_size, -1)
            return torch.from_numpy(stacked)

        action_index = np.asarray(actions, dtype=np.int64).reshape(self.batch_size, 1)
        return (
            as_float(states),
            torch.from_numpy(action_index),
            as_float(rewards),
            as_float(next_states),
            as_float(dones),
        )

    def replay(self):
        """Run one optimisation step on the Q-network and the curiosity module.

        Does nothing until the memory holds at least ``batch_size``
        transitions. Otherwise it samples a minibatch, adds the curiosity
        module's forward-prediction error to the stored rewards, performs one
        gradient step on each component and decays epsilon (never below
        ``epsilon_min``).
        """
        if len(self.memory) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self._sample_batch()

        # squeeze(1), not squeeze(): keeps a 1-D index tensor even when batch_size is 1.
        action_index = actions.squeeze(1)
        one_hot_actions = F.one_hot(action_index, num_classes=self.action_size).float()

        # One curiosity forward pass serves both the reward bonus and the
        # curiosity loss; its parameters do not change until the step below.
        predicted_next_states, predicted_actions = self.curiosity(
            states, next_states, one_hot_actions
        )

        # The bonus is treated as a constant by the Q-network update.
        intrinsic_rewards = self.curiosity.intrinsic_reward(
            next_states, predicted_next_states.detach()
        ).unsqueeze(1)
        rewards = rewards + intrinsic_rewards

        # Q-network update against the bootstrapped target.
        q_values = self.model(states).gather(1, actions)
        with torch.no_grad():
            max_next_q_values = self.target_model(next_states).max(1)[0].unsqueeze(1)
            q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)

        loss = F.mse_loss(q_values, q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Curiosity update: forward-model regression plus inverse-model classification.
        curiosity_loss = F.mse_loss(predicted_next_states, next_states) + F.cross_entropy(
            predicted_actions.view(-1, self.action_size), action_index
        )
        self.curiosity_optimizer.zero_grad()
        curiosity_loss.backward()
        self.curiosity_optimizer.step()

        # Decay exploration, clamped so it never undershoots the floor.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

# Train the agent in a CartPole environment
env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n  # Discrete action space: number of actions

agent = DQNAgent(state_size, action_size)
episodes = 500

try:
    for e in range(episodes):
        # gym >= 0.26 returns (observation, info); older releases return the
        # observation alone.
        reset_result = env.reset()
        state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
        state = np.reshape(state, [1, state_size])
        done = False
        total_reward = 0

        while not done:
            action = agent.act(state)

            # gym >= 0.26 splits the old `done` flag into terminated/truncated.
            step_result = env.step(action)
            if len(step_result) == 5:
                next_state, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            else:
                next_state, reward, done, _ = step_result
                terminated = done
            next_state = np.reshape(next_state, [1, state_size])

            # Store only true termination as the terminal flag, so that a
            # time-limit truncation does not zero out the bootstrapped value.
            agent.remember(state, action, reward, next_state, bool(terminated))

            state = next_state
            total_reward += reward

            if done:
                # Sync the target network once per episode.
                agent.update_target_model()
                print(f"Episode: {e+1}/{episodes}, Score: {total_reward}, Epsilon: {agent.epsilon:.2f}")

            agent.replay()
finally:
    # Release the environment even if training is interrupted or raises.
    env.close()