# MICo Implementation


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np

In [None]:
# Define the neural network for the encoder used in MICo.
# NOTE: Naming convention used throughout this notebook:
#   state_dim    : dimension of the original state coming from the environment
#   encoding_dim : dimension of the latent (encoded) space
class Encoder(nn.Module):
    """Two-layer MLP that maps an environment state to a latent encoding.

    Args:
        state_dim: Size of the last dimension of the input states.
        encoding_dim: Size of the last dimension of the produced encodings.
    """

    def __init__(self, state_dim, encoding_dim):
        super().__init__()
        # state -> 128 hidden units -> encoding
        self.fc1 = nn.Linear(in_features=state_dim, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=encoding_dim)

    def forward(self, x):
        """Return the encoding of ``x`` (no activation on the output layer)."""
        hidden = nn.functional.relu(self.fc1(x))
        encoding = self.fc2(hidden)
        return encoding

In [None]:
# Define the neural network for Deep Q-Learning
class DQN(nn.Module):
    """Three-layer MLP that maps a state encoding to one Q-value per action.

    Args:
        encoding_dim: Size of the last dimension of the input encodings.
        action_dim: Number of actions, i.e. size of the output's last dimension.
    """

    def __init__(self, encoding_dim, action_dim):
        super().__init__()
        # encoding -> 128 -> 128 -> Q-values
        self.fc1 = nn.Linear(in_features=encoding_dim, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=action_dim)

    def forward(self, x):
        """Return the Q-values for ``x`` (no activation on the output layer)."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = nn.functional.relu(layer(hidden))
        return self.fc3(hidden)

In [None]:
# Define the Q-Learning agent with MICo integration
class DQN_MICo_Agent:
    """DQN agent whose state encoder is additionally trained with a MICo loss.

    The agent owns an online encoder / Q-network pair (trained by a single
    Adam optimizer) and a target encoder / target Q-network pair that is only
    refreshed through :meth:`update_target_network`.

    Args:
        state_dim: Dimension of the raw environment state.
        action_dim: Number of discrete actions.
        encoding_dim: Dimension of the latent space produced by the encoder.
        gamma: Discount factor used in both the TD target and the MICo target.
        lr: Learning rate of the Adam optimizer.
        mico_alpha: Weight of the MICo loss; the TD loss is weighted by
            ``1 - mico_alpha``.
        mico_beta: Weight of the angular term in the representation distance.
    """

    # Numerical guard used when normalising / clamping the cosine similarity.
    _EPS = 1e-6

    def __init__(self, state_dim, action_dim, encoding_dim, gamma=0.99, lr=0.001, mico_alpha=0.1, mico_beta=0.1):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.encoding_dim = encoding_dim
        self.gamma = gamma
        self.lr = lr
        self.mico_alpha = mico_alpha
        self.mico_beta = mico_beta

        # Define the encoder, Q-Network, target encoder and target network
        self.encoder = Encoder(state_dim, encoding_dim)
        self.target_encoder = Encoder(state_dim, encoding_dim)
        self.q_network = DQN(encoding_dim, action_dim)
        self.target_network = DQN(encoding_dim, action_dim)

        # A single optimizer updates the Q-network and the online encoder jointly.
        self.optimizer = optim.Adam(
            list(self.q_network.parameters()) + list(self.encoder.parameters()),
            lr=self.lr,
        )
        self.criterion = nn.MSELoss()

        # Start with target networks identical to the online networks.
        self.update_target_network()

    def update_target_network(self):
        """Copy the online Q-network and encoder weights into their targets."""
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.target_encoder.load_state_dict(self.encoder.state_dict())

    def select_action(self, state_encoding, epsilon):
        """Pick an action epsilon-greedily.

        Args:
            state_encoding: A single latent encoding (length ``encoding_dim``).
                For convenience a raw environment state (length ``state_dim``)
                is also accepted and is passed through the online encoder
                first; this detection is only possible when
                ``state_dim != encoding_dim``, otherwise the input is always
                treated as an encoding.
            epsilon: Probability of returning a uniformly random action.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() < epsilon:
            return random.randint(0, self.action_dim - 1)

        features = torch.as_tensor(state_encoding, dtype=torch.float32).unsqueeze(0)
        # Action selection never needs gradients.
        with torch.no_grad():
            looks_like_raw_state = (
                features.shape[-1] == self.state_dim
                and self.state_dim != self.encoding_dim
            )
            if looks_like_raw_state:
                features = self.encoder(features)
            q_values = self.q_network(features)
        return q_values.argmax().item()

    def representation_distance_U(self, encoding_x, encoding_y):
        """Return the parametrised MICo distance between paired encodings.

        Computes, row by row,
        ``(||x||^2 + ||y||^2) / 2 + mico_beta * angle(x, y)``.

        Args:
            encoding_x: Tensor of shape ``(batch, encoding_dim)``.
            encoding_y: Tensor of the same shape as ``encoding_x``. Both may be
                encodings of current states or of next states.

        Returns:
            Tensor of shape ``(batch,)`` with one distance per row pair.
        """
        eps = self._EPS

        # Squared norms via sum of squares: differentiable even at the zero vector.
        square_norm_x = torch.sum(encoding_x * encoding_x, dim=1)
        square_norm_y = torch.sum(encoding_y * encoding_y, dim=1)

        dot_product = torch.sum(encoding_x * encoding_y, dim=1)
        # Clamp before the sqrt so a zero-norm encoding gives neither a division
        # by zero nor an infinite sqrt gradient.
        norm_product = torch.sqrt((square_norm_x * square_norm_y).clamp_min(eps * eps))
        cosine_similarity = dot_product / norm_product
        # Rounding can push the cosine slightly outside [-1, 1], where acos is
        # NaN; its gradient is also infinite exactly at +/-1, hence the margin.
        cosine_similarity = cosine_similarity.clamp(-1.0 + eps, 1.0 - eps)
        angle = torch.acos(cosine_similarity)

        return (square_norm_x + square_norm_y) / 2 + self.mico_beta * angle

    @staticmethod
    def _sample_batch(replay_buffer, batch_size):
        """Sample ``batch_size`` transitions and return them as tensors.

        Returns ``(states, actions, rewards, next_states, dones)`` where
        actions, rewards and dones have shape ``(batch, 1)``.
        """
        batch = random.sample(replay_buffer.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Going through a single ndarray avoids the slow per-element conversion
        # torch performs on tuples of numpy arrays.
        states = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        actions = torch.as_tensor(np.asarray(actions), dtype=torch.int64).unsqueeze(1)
        rewards = torch.as_tensor(np.asarray(rewards), dtype=torch.float32).unsqueeze(1)
        next_states = torch.as_tensor(np.asarray(next_states), dtype=torch.float32)
        dones = torch.as_tensor(np.asarray(dones), dtype=torch.float32).unsqueeze(1)
        return states, actions, rewards, next_states, dones

    def train(self, replay_buffer, batch_size):
        """Run one optimisation step on the combined TD + MICo loss.

        Two independent batches (x and y) are drawn from the buffer; they are
        paired row by row for the MICo loss, and batch x alone is used for the
        TD loss.

        Args:
            replay_buffer: Object exposing a ``buffer`` sequence of
                ``(state, action, reward, next_state, done)`` tuples.
            batch_size: Number of transitions per batch.

        Returns:
            The scalar loss as a ``float``, or ``None`` when the buffer holds
            fewer than ``batch_size`` transitions and no update was made.
        """
        if len(replay_buffer.buffer) < batch_size:
            return None

        # TODO: confirm whether paired transitions must share the same action.
        state_x, action_x, reward_x, next_state_x, done_x = self._sample_batch(
            replay_buffer, batch_size
        )
        state_y, _, reward_y, next_state_y, _ = self._sample_batch(
            replay_buffer, batch_size
        )

        # Online encodings of the current states (gradients flow through these).
        state_x_encoding = self.encoder(state_x)
        state_y_encoding = self.encoder(state_y)

        # Everything derived from the target networks is a fixed regression
        # target, so it is built without tracking gradients.
        with torch.no_grad():
            next_state_encoding_x = self.target_encoder(next_state_x)
            next_state_encoding_y = self.target_encoder(next_state_y)

            target_distance_U = self.representation_distance_U(
                next_state_encoding_x, next_state_encoding_y
            )
            # Rewards are (batch, 1) while distances are (batch,); squeeze so the
            # sum stays (batch,) instead of broadcasting to (batch, batch).
            reward_gap = torch.abs(reward_x - reward_y).squeeze(1)
            learning_target = reward_gap + self.gamma * target_distance_U

            # TODO: confirm that the TD bootstrap should use the target encoder
            # rather than the online encoder for the next state.
            next_q_values = self.target_network(next_state_encoding_x).max(1)[0].unsqueeze(1)
            target_q_values = reward_x + (1 - done_x) * self.gamma * next_q_values

        # MICo loss
        current_distance_U = self.representation_distance_U(state_x_encoding, state_y_encoding)
        mico_loss = (learning_target - current_distance_U).pow(2).mean()

        # TD loss
        # TODO: only batch x contributes here; confirm whether y should as well.
        q_values = self.q_network(state_x_encoding).gather(1, action_x)
        q_loss = self.criterion(q_values, target_q_values)

        loss = (1 - self.mico_alpha) * q_loss + self.mico_alpha * mico_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

In [None]:
# Define a simple replay buffer
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions.

    Once ``capacity`` transitions are stored, each new transition overwrites
    the oldest one.

    Attributes are plain and public:
        buffer: List holding the stored transition tuples.
        capacity: Maximum number of transitions kept.
        position: Index at which the next transition will be written.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.position = 0
        self.buffer = []

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest when the buffer is full."""
        transition = (state, action, reward, next_state, done)
        slot = self.position
        still_growing = len(self.buffer) < self.capacity
        if still_growing:
            # Reserve the slot first so the indexed write below is always valid.
            self.buffer.append(None)
        self.buffer[slot] = transition
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)



In [None]:
# Parameters
state_dim = 4
action_dim = 2
encoding_dim = 32
gamma = 0.99
lr = 0.001
mico_alpha = 0.1
epsilon = 0.1
batch_size = 32
capacity = 1000

# Initialize agent and replay buffer
agent = DQN_MICo_Agent(state_dim, action_dim, encoding_dim, gamma, lr, mico_alpha)
replay_buffer = ReplayBuffer(capacity)

# Training loop (dummy example, replace with actual environment interaction)
for episode in range(100):
    state = np.random.rand(state_dim)
    for t in range(100):
        # select_action expects a latent encoding (the Q-network input has
        # encoding_dim features), so encode the raw state first. No gradients
        # are needed for acting.
        with torch.no_grad():
            state_encoding = agent.encoder(
                torch.as_tensor(state, dtype=torch.float32)
            ).numpy()
        action = agent.select_action(state_encoding, epsilon)

        # Dummy transition: random next state, reward and termination flag.
        next_state = np.random.rand(state_dim)
        reward = random.random()
        done = random.random() < 0.1

        # The buffer stores raw states; the agent encodes them during training.
        replay_buffer.push(state, action, reward, next_state, done)
        agent.train(replay_buffer, batch_size)
        if done:
            break
        state = next_state
    # Refresh the target encoder and target Q-network once per episode.
    agent.update_target_network()