<a href="https://colab.research.google.com/github/OneFineStarstuff/TheOneEverAfter/blob/main/Model_Based_Reinforcement_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Transition model definition
class TransitionModel(nn.Module):
    """Two-hidden-layer MLP that maps a (state, action) pair to a next state and a reward.

    State and action are concatenated along their last dimension and passed
    through two ReLU hidden layers. Two linear heads then produce:

    * a next-state prediction with ``state_dim`` features, and
    * a scalar reward prediction with 1 feature.

    Args:
        state_dim: Size of the last dimension of state tensors.
        action_dim: Size of the last dimension of action tensors.
        hidden_dim: Width of both hidden layers. It defaults to 256, the
            width that was previously hard-coded.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        # Layer attribute names and creation order are unchanged. This keeps
        # state_dict keys and seeded initialisation identical to the original.
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)  # input layer
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)               # second hidden layer
        self.fc_next_state = nn.Linear(hidden_dim, state_dim)     # next-state head
        self.fc_reward = nn.Linear(hidden_dim, 1)                 # reward head

    def forward(self, state, action):
        """Predict ``(next_state, reward)`` for the given state and action.

        ``state`` and ``action`` must have matching leading dimensions, because
        they are concatenated along ``dim=-1``. The returned ``next_state`` has
        ``state_dim`` features in its last dimension and ``reward`` has 1.
        """
        x = torch.cat([state, action], dim=-1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        next_state = self.fc_next_state(x)
        reward = self.fc_reward(x)
        return next_state, reward

# Simple model definition (for action selection)
class SimpleModel(nn.Module):
    """One affine layer mapping a ``state_dim`` input to ``action_dim`` outputs."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # The attribute name `fc` determines the state_dict keys (fc.weight, fc.bias).
        self.fc = nn.Linear(in_features=state_dim, out_features=action_dim)

    def forward(self, x):
        """Apply the linear layer; the last dimension of ``x`` must equal ``state_dim``."""
        projected = self.fc(x)
        return projected

# Function to select an action based on model predictions with exploration
def select_action(model, transition_model, state, action_space, epsilon=0.1):
    """Choose an action epsilon-greedily using the transition model's reward head.

    With probability ``epsilon`` a uniformly random element of ``action_space``
    is returned. Otherwise every candidate is scored by the reward that
    ``transition_model(state, action)`` predicts, and the highest-scoring
    candidate is returned. On ties the earliest candidate wins. NaN scores
    never win.

    Args:
        model: Unused. It is kept so that existing call sites keep working.
        transition_model: Callable ``(states, actions) -> (next_states, rewards)``,
            for example a ``TransitionModel``.
        state: State tensor with a leading batch dimension of size 1, for
            example ``(1, state_dim)``. It is broadcast across all candidates.
        action_space: Non-empty sequence of equally-shaped 1-D action tensors.
        epsilon: Probability of taking a random (exploratory) action.

    Returns:
        One of the elements of ``action_space``. It is the same object, not a copy.

    Raises:
        ValueError: If ``action_space`` is empty.
    """
    if len(action_space) == 0:
        raise ValueError("action_space must contain at least one action")

    # Exploration: return a uniformly random candidate.
    if np.random.rand() < epsilon:
        return action_space[np.random.choice(len(action_space))]

    # Exploitation: score all candidates in one batched forward pass.
    # no_grad is used because the scores are only compared, never
    # backpropagated, so recording an autograd graph would be wasted work.
    with torch.no_grad():
        actions = torch.stack(list(action_space))            # (K, action_dim)
        states = state.expand(actions.shape[0], -1)          # (K, state_dim)
        _, pred_rewards = transition_model(states, actions)

    scores = pred_rewards.reshape(-1)
    # A NaN never won the original strict `>` comparison, so it is mapped to
    # -inf here. torch.argmax returns the first maximal index, which keeps the
    # original "earliest candidate wins ties" behaviour.
    scores = scores.masked_fill(torch.isnan(scores), float("-inf"))
    best_index = int(torch.argmax(scores))
    return action_space[best_index]

# Training function for the transition model
def train(transition_model, model, optimizer, episodes=1000, *, state_dim=None,
          action_dim=None, num_actions=10, log_interval=100):
    """Run ``episodes`` single-step updates of ``transition_model``.

    Each episode does the following:

    1. Samples a random state of shape ``(1, state_dim)`` and ``num_actions``
       random candidate actions.
    2. Picks one action with ``select_action``.
    3. Takes one optimizer step on ``loss = -predicted_reward``.

    Args:
        transition_model: Model returning ``(next_state, reward)``. It is
            expected to expose ``fc1`` and ``fc_next_state`` as in
            ``TransitionModel`` when the dimensions are not passed explicitly.
        model: Passed through to ``select_action``, which does not use it.
        optimizer: Optimizer stepped once per episode. It is presumably built
            over ``transition_model.parameters()``, as in the ``__main__`` usage.
        episodes: Number of episodes to run.
        state_dim: State size. If None, it is inferred from ``transition_model``.
        action_dim: Action size. If None, it is inferred from ``transition_model``.
        num_actions: Number of random candidate actions per episode.
        log_interval: Print the loss every this many episodes. A falsy value
            disables logging.

    Raises:
        ValueError: If a dimension is None and cannot be inferred from the model.

    NOTE(review): No environment supplies ground-truth next states or rewards.
    The objective only pushes the reward head upward (it is unbounded below),
    so this does not fit any dynamics. A real model-based RL loop would regress
    ``next_state`` and ``reward`` onto observed transitions (for example with
    MSE). The objective is left unchanged here to preserve behaviour.
    """
    # Infer the dimensions from the model instead of reading module-level
    # globals, which only exist when this file is executed as __main__.
    if state_dim is None or action_dim is None:
        try:
            model_state_dim = transition_model.fc_next_state.out_features
            model_action_dim = transition_model.fc1.in_features - model_state_dim
        except AttributeError as err:
            raise ValueError(
                "state_dim and action_dim must be given when they cannot be "
                "inferred from transition_model"
            ) from err
        if state_dim is None:
            state_dim = model_state_dim
        if action_dim is None:
            action_dim = model_action_dim

    for episode in range(episodes):
        current_state = torch.randn(1, state_dim)  # random initial state

        # Discrete candidate set of random actions for this episode.
        action_space = [torch.randn(action_dim) for _ in range(num_actions)]

        chosen_action = select_action(model, transition_model, current_state, action_space)

        # next_state is unused by this objective (see the NOTE above).
        next_state, reward = transition_model(current_state, chosen_action.unsqueeze(0))
        loss = -reward.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_interval and episode % log_interval == 0:
            print(f'Episode {episode}, Loss: {loss.item()}')

# Example usage of the framework
if __name__ == "__main__":
    # Demo dimensions. They are deliberately kept as module-level names rather
    # than wrapped in a main() function, so any code that reads them as
    # globals keeps working.
    state_dim, action_dim = 10, 4

    # Construction order matters for reproducibility: each constructor draws
    # its initial weights from the global torch RNG.
    transition_model = TransitionModel(state_dim, action_dim)
    model = SimpleModel(state_dim, action_dim)
    optimizer = optim.Adam(transition_model.parameters(), lr=1e-3)

    train(transition_model, model, optimizer, episodes=1000)

    # Query the epsilon-greedy policy once more on a fresh random state.
    current_state = torch.randn(1, state_dim)
    action_space = [torch.randn(action_dim) for _ in range(10)]
    optimal_action = select_action(
        model, transition_model, current_state, action_space, epsilon=0.1
    )

    print(f'Optimal Action after training: {optimal_action}')