<a href="https://colab.research.google.com/github/OneFineStarstuff/TheOneEverAfter/blob/main/Intrinsic_Curiosity_Module_(ICM).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy

# CuriosityModule definition
class CuriosityModule(nn.Module):
    """Curiosity module made of an encoder, a forward model and an inverse model.

    All three components are single linear layers. States are encoded into a
    128-dimensional feature vector (followed by ReLU); the forward model
    predicts the encoded next state from the encoded state and the action; the
    inverse model predicts the action from the two encodings.
    """

    def __init__(self, state_dim, action_dim):
        """Create the three linear layers.

        Args:
            state_dim: Size of the last dimension of state tensors.
            action_dim: Size of the last dimension of action tensors.
        """
        super().__init__()
        feature_dim = 128  # width of the encoded state representation
        self.encoder = nn.Linear(state_dim, feature_dim)
        self.forward_model = nn.Linear(feature_dim + action_dim, feature_dim)
        self.inverse_model = nn.Linear(2 * feature_dim, action_dim)

    def forward(self, state, next_state, action):
        """Compute the intrinsic reward and the inverse-model action prediction.

        Args:
            state: Current states; last dimension must be ``state_dim``.
            next_state: Following states; last dimension must be ``state_dim``.
            action: Actions; concatenated with the state encoding along the
                last dimension, so that dimension must be ``action_dim``.

        Returns:
            A tuple ``(intrinsic_reward, pred_action)``. ``intrinsic_reward``
            is a scalar tensor: the squared forward-model prediction error
            averaged over every element. ``pred_action`` is the inverse
            model's output for the pair of encodings.
        """
        # Both observations go through the same encoder, current state first.
        phi, phi_next = (
            torch.relu(self.encoder(obs)) for obs in (state, next_state)
        )

        # Forward model: predict the next encoding from (encoding, action).
        forward_input = torch.cat((phi, action), dim=-1)
        prediction_error = self.forward_model(forward_input) - phi_next
        intrinsic_reward = prediction_error.pow(2).mean()

        # Inverse model: predict the action from (encoding, next encoding).
        inverse_input = torch.cat((phi, phi_next), dim=-1)
        pred_action = self.inverse_model(inverse_input)

        return intrinsic_reward, pred_action

# Example RL model
class SimpleModel(nn.Module):
    """Minimal model: one linear layer from ``state_dim`` to ``action_dim``."""

    def __init__(self, state_dim, action_dim):
        """Create the single fully connected layer ``fc``.

        Args:
            state_dim: Number of input features.
            action_dim: Number of output features.
        """
        super().__init__()
        self.fc = nn.Linear(in_features=state_dim, out_features=action_dim)

    def forward(self, x):
        """Return the linear projection of ``x`` (no activation applied)."""
        projected = self.fc(x)
        return projected

# Function to compute the loss for a given model and task
def compute_loss(model, task):
    """Return the mean-squared-error loss of ``model`` on one task.

    Args:
        model: Callable applied to the task inputs to produce predictions.
        task: A pair ``(inputs, targets)``.

    Returns:
        The MSE between ``model(inputs)`` and ``targets``, as computed by
        ``torch.nn.functional.mse_loss`` with its default settings.
    """
    features, labels = task
    return torch.nn.functional.mse_loss(model(features), labels)

def maml_update(model, task_data, optimizer, lr_inner=0.01, inner_steps=5,
                *, loss_fn=None):
    """Run one MAML meta-update of ``model`` over all tasks in ``task_data``.

    For each task the trainable parameters are adapted with ``inner_steps``
    gradient-descent steps of size ``lr_inner``. The loss of the adapted
    weights on that same task is then back-propagated, which accumulates
    (sums) the meta-gradient in ``model``'s parameters. After all tasks, a
    single ``optimizer.step()`` is taken and the gradients are cleared.

    Args:
        model: The ``nn.Module`` being meta-trained.
        task_data: Iterable of tasks; each task is passed unchanged to
            ``loss_fn``.
        optimizer: Optimizer constructed over ``model``'s parameters.
        lr_inner: Step size of the inner (task-specific) updates.
        inner_steps: Number of inner updates per task; ``0`` reduces to an
            ordinary gradient step on the summed task losses.
        loss_fn: Callable ``loss_fn(predictor, task) -> scalar tensor``.
            ``predictor`` is a callable that behaves like ``model`` but uses
            the adapted weights, so ``loss_fn`` must obtain predictions by
            calling it. Defaults to the module-level ``compute_loss``.

    Returns:
        None. ``model`` is updated in place through ``optimizer``.
    """
    try:
        from torch.func import functional_call
    except ImportError:  # older PyTorch releases ship it under nn.utils
        from torch.nn.utils.stateless import functional_call

    if loss_fn is None:
        loss_fn = compute_loss

    def bind(weights):
        """Return a callable that runs ``model`` with ``weights`` swapped in."""
        def predictor(*args, **kwargs):
            return functional_call(model, weights, args, kwargs)
        return predictor

    # Frozen parameters are left alone; they cannot be differentiated.
    trainable = {
        name: param
        for name, param in model.named_parameters()
        if param.requires_grad
    }

    # Outer loop: accumulate the meta-gradient of every task.
    for task in task_data:
        # The fast weights start as the model's own parameters and are only
        # ever replaced out-of-place. Adapting a deep copy (or writing through
        # ``.data``) would cut the autograd graph, leaving the real parameters
        # without gradients and turning ``optimizer.step()`` into a no-op.
        fast_weights = dict(trainable)

        # Inner loop: task-specific adaptation.
        for _ in range(inner_steps):
            task_loss = loss_fn(bind(fast_weights), task)
            grads = torch.autograd.grad(
                task_loss,
                list(fast_weights.values()),
                create_graph=True,   # keep the graph for second-order terms
                allow_unused=True,   # parameters the loss does not touch
            )
            fast_weights = {
                name: weight if grad is None else weight - lr_inner * grad
                for (name, weight), grad in zip(fast_weights.items(), grads)
            }

        # Meta-loss of the adapted weights; backward reaches model's params.
        meta_loss = loss_fn(bind(fast_weights), task)
        meta_loss.backward()

    optimizer.step()
    optimizer.zero_grad()

# Example usage: report curiosity rewards, then run one MAML update
state_dim = 10   # size of each state vector
action_dim = 4   # size of each action vector
batch_size = 32  # samples per task
num_tasks = 4    # number of synthetic tasks

# Build the model, the curiosity module and the meta-optimizer
model = SimpleModel(state_dim, action_dim)
curiosity_module = CuriosityModule(state_dim, action_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Synthetic tasks: each one is a random (inputs, targets) batch
task_data = [
    (torch.randn(batch_size, state_dim), torch.randn(batch_size, action_dim))
    for _ in range(num_tasks)
]

# Print the intrinsic reward of every task. The reward is only reported here;
# it is not passed to the MAML update below.
for task in task_data:
    state = task[0]
    next_state = torch.randn(batch_size, state_dim)  # random stand-in for the next state
    action = model(state)                            # model output used as the action

    intrinsic_reward, _ = curiosity_module(state, next_state, action)
    print(f'Intrinsic Reward: {intrinsic_reward.item()}')

# One meta-update of the model across all tasks
maml_update(model, task_data, optimizer)

print("MAML update with curiosity module completed.")