# Mico Implementation


In [92]:
import torch
import torch.nn as nn
import torch.optim as optim
import math, random
import numpy as np
import gym
import matplotlib.pyplot as plt

In [93]:
# Define the neural network for the encoder used in MICo
# NOTE: We are going to call
# state_dim : to the original dimension from the environment
# encoding_dim : to the dimension of the latent space
class Encoder(nn.Module):
    def __init__(self, state_dim, encoding_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, encoding_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [94]:
# Define the neural network for Deep Q-Learning
class DQNNetwork(nn.Module):
    def __init__(self, encoding_dim, action_dim):
        super(DQNNetwork, self).__init__()
        self.fc1 = nn.Linear(encoding_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

In [95]:
# Define a simple replay buffer
from collections import deque

class ReplayBuffer(object):
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    # Add a new experience to the buffer
    def push(self, state, action, reward, next_state, done):
        experience = (state, action, reward, next_state, done)
        self.buffer.append(experience)
    
    # Function to transform to tensors
    def toTensor(self, states, actions, rewards, next_states, dones):
        states = torch.tensor(states, dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.long).view(-1, 1)
        rewards = torch.tensor(rewards, dtype=torch.float32).view(-1, 1)
        next_states = torch.tensor(next_states, dtype=torch.float32, requires_grad=False)
        dones = torch.tensor(dones, dtype=torch.float32).view(-1, 1)

        return states, actions, rewards, next_states, dones

    def sample(self, batch_size):
        states, actions, rewards, next_states, dones = zip(*random.sample(self.buffer, batch_size))
        # TODO: Add the conversion to Pytorch tensor here
        # In which format I need it?

        # return np.concatenate(states), actions, rewards, np.concatenate(next_states), dones
        return self.toTensor(states, actions, rewards, next_states, dones)

    def __len__(self):
        return len(self.buffer)



In [96]:
# Define the Q-Learning agent with MICo integration
class DQN_MICo_Agent:
    def __init__(self, state_dim, action_dim, encoding_dim, gamma_mico=0.99, gamma_td=0.99, lr=0.001, alpha=0.1, mico_beta=0.1):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.encoding_dim = encoding_dim
        self.gamma_mico = gamma_mico
        self.gamma_td = gamma_td
        self.lr = lr
        self.alpha = alpha
        self.mico_beta = mico_beta
        
        # Define the encoder, Q-Network, target encoder and target network
        self.encoder = Encoder(state_dim, encoding_dim)
        self.target_encoder = Encoder(state_dim, encoding_dim)
        self.q_network = DQNNetwork(encoding_dim, action_dim)
        self.target_network = DQNNetwork(encoding_dim, action_dim)

        # Define the optimizer and loss function
        self.optimizer = optim.Adam(list(self.q_network.parameters()) + list(self.encoder.parameters()), lr=self.lr)
        self.criterion = nn.MSELoss()
        
        self.update_target_network()
    
    def update_target_network(self):
        # Load the weights of the Q-Network and encoder into the target network
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.target_encoder.load_state_dict(self.encoder.state_dict())
    
    def select_action(self, state_encoding, epsilon):
        if random.random() < epsilon:
            return random.randint(0, self.action_dim - 1)
        else:
            state_encoding = torch.FloatTensor(state_encoding).unsqueeze(0)
            q_values = self.q_network(state_encoding)
            return q_values.argmax().item()
        
    def representation_distance_U(self, encoding_x, encoding_y):
        # NOTE: encoding_x and encoding_y can be encoding of the current state or next states        

        # NOTE: correct because I'm using square and then sqrt

        # Calculate the square norms of the encodings
        norm_encoding_x = torch.norm(encoding_x , dim=1)
        norm_encoding_y = torch.norm(encoding_y , dim=1)

        # Calculate the angle between the encodings
        dot_product = torch.sum(encoding_x * encoding_y, dim=1)
        cosine_similarity = dot_product / (norm_encoding_x * norm_encoding_y)
        angle = torch.acos(cosine_similarity)

        return ((norm_encoding_x**2 + norm_encoding_y**2)/2) + self.mico_beta * angle
  
    
    def train(self, replay_buffer, batch_size):                
        # NOTE: At the begining, it needed some iteration to fill
        # the replay buffer
        if len(replay_buffer.buffer) < 2 * batch_size:
            return 0., 0., 0.
        
        # TODO: Check if the randoms are a proper way to sample 
        # or I should use random from pytorch

        # Gather a sample from the replay buffer
        # NOTE: We need to sample two pairs of transitions with the same action
        # TODO: ASK if they need to be the same action or just the same state
        # TODO: ASK if I need to calculate the TD loss for both transitions
        state_x, action_x, reward_x, next_state_x, done_x = replay_buffer.sample(batch_size)
        state_y, action_y, reward_y, next_state_y, done_y = replay_buffer.sample(batch_size)
        
        # Calculating the MICO Loss
        state_x_encoding = self.encoder(state_x)
        state_y_encoding = self.encoder(state_y)

        next_state_encoding_x =  self.target_encoder(next_state_x)
        next_state_encoding_y =  self.target_encoder(next_state_y)

        # Print the encodings
        # print("State X Encoding: ", state_x_encoding)
        # print("State Y Encoding: ", state_y_encoding)
        # print("Next State X Encoding: ", next_state_encoding_x)
        # print("Next State Y Encoding: ", next_state_encoding_y)

        target_distance_U = self.representation_distance_U(next_state_encoding_x, next_state_encoding_y)
        learning_target = torch.abs(reward_x - reward_y) + self.gamma_mico * target_distance_U
        current_distance_U = self.representation_distance_U(state_x_encoding, state_y_encoding)

        mico_loss = (learning_target - current_distance_U).pow(2).mean()
        

        # Calculating the TD-Loss
        # TODO: I'm gonna calculate just the loss of the encoding_x
        # but probabilly I should calculate the loss of both encodings
        q_values = self.q_network(state_x_encoding).gather(1, action_x)

        # TODO: The next state of this should be using the target encoder or the encoder?
        # I'm supose that the encoder
        next_state_encoding_x =  self.encoder(next_state_x)
        next_q_values = self.target_network(next_state_encoding_x).max(1)[0].unsqueeze(1)
        target_q_values = reward_x + (1 - done_x) * self.gamma_td * next_q_values
        
        td_loss = self.criterion(q_values, target_q_values)
        
        loss = (1 - self.alpha)* td_loss + self.alpha * mico_loss
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return mico_loss.item(), td_loss.item(), loss.item()

In [97]:
# Epsilon schedule
# epsilon_start = 1.0
# epsilon_final = 0.01
# epsilon_decay = 500

# epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * math.exp(-1. * frame_idx / epsilon_decay)

# plt.plot([epsilon_by_frame(i) for i in range(10000)])


In [98]:
# Initialize environment
env_id = "CartPole-v0"
env = gym.make(env_id)
print(env.observation_space.shape[0])
print(env.action_space.n)

4
2


In [101]:
# Initialize environment
env_id = "CartPole-v0"
env = gym.make(env_id) 

# Parameters
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
encoding_dim = 2
gamma_mico = 0.99
gamma_td = 0.99
lr = 0.001
alpha = 0.1
mico_beta = 0.1
epsilon = 0.1
batch_size = 32
capacity = 1000
num_iterations = 1000
num_episodes = 100

# Initialize agent and replay buffer
agent = DQN_MICo_Agent(state_dim, action_dim, encoding_dim, gamma_mico, gamma_td, lr, alpha, mico_beta)
replay_buffer = ReplayBuffer(capacity)

losses = []
mico_losses = []
td_losses = []
all_rewards = []
episode_reward = 0

# Training loop
for episode in range(num_episodes):

    # Reset environment and get initial state
    state, info = env.reset()
    for t in range(num_iterations):
        state_encoding = agent.encoder(torch.FloatTensor(state).unsqueeze(0))
        action = agent.select_action(state_encoding, epsilon)
        next_state, reward, done, _, _ = env.step(action)
        replay_buffer.push(state, action, reward, next_state, done)

        if len(replay_buffer) >= batch_size:
            mico_loss, td_loss, total_loss = agent.train(replay_buffer, batch_size)
            losses.append(total_loss)
            mico_losses.append(mico_loss)
            td_losses.append(td_loss)

        if t % 2 == 0:
            print("Episode: %d, Iteration: %d, Total Loss: %f" % (episode, t, total_loss))
            print("MICO Loss: %f, TD Loss: %f" % (mico_loss, td_loss))
        
        state = next_state
        episode_reward += reward

        if done:
            all_rewards.append(episode_reward)
            episode_reward = 0
            break

    # if episode % 10 == 0:
    #     print("Episode: %d, Total Reward: %d" % (episode, np.mean(all_rewards[-10:])))
    #     print("MICO Loss: %f, TD Loss: %f, Total Loss: %f" % (np.mean(mico_losses[-10:]), np.mean(td_losses[-10:]), np.mean(losses[-10:])))
        
    agent.update_target_network()

Episode: 0, Iteration: 0, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 0, Iteration: 2, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 0, Iteration: 4, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 0, Iteration: 6, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 0, Iteration: 8, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 1, Iteration: 0, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 1, Iteration: 2, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 1, Iteration: 4, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 1, Iteration: 6, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 1, Iteration: 8, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 2, Iteration: 0, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 2, Iteration: 2, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 2, Iteration: 4, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episode: 2, Iteration: 6, Total Loss: nan
MICO Loss: nan, TD Loss: nan
Episod

In [None]:
# Plotting the total rewards per episode
plt.plot(all_rewards)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("Total Reward vs Episode")
plt.show()

In [None]:
# Plotting the losses per episode
plt.plot(losses)
plt.xlabel("Episode")
plt.ylabel("Loss")
plt.title("Loss vs Episode")
plt.show()

# Plotting the MICO losses per episode
plt.plot(mico_losses)
plt.xlabel("Episode")
plt.ylabel("MICO Loss")
plt.title("MICO Loss vs Episode")
plt.show()

# Plotting the TD losses per episode
plt.plot(td_losses)
plt.xlabel("Episode")
plt.ylabel("TD Loss")
plt.title("TD Loss vs Episode")
plt.show()