<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/quantum_gravity_rl_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch gymnasium numpy

In [None]:
#!/usr/bin/env python3
"""
Deep RL for Quantum Gravity Corrections using DDPG.
State dim = 4, Action dim = 4 (quantum metric corrections).
"""

import random
import numpy as np
# NOTE(review): the legacy `gym` name imported on the next line is rebound by
# `import gymnasium as gym` further down, so the rest of this cell only ever
# sees gymnasium. The import still raises ModuleNotFoundError when the legacy
# package is absent (the install cell above only pulls in gymnasium).
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from gymnasium import spaces

# ------------------------
# 1) Dummy Env Definition
# ------------------------
class QuantumGravityEnv(gym.Env):
    """
    Dummy 4-D continuous control environment.

    Dynamics: ``state <- state + action + 0.1 * N(0, I)``.
    Reward:   ``-‖state‖²`` evaluated on the updated state.

    Episodes never terminate or truncate on their own; the caller is
    responsible for capping the number of steps. The state is not clipped,
    so it may drift outside the declared ``observation_space`` bounds.
    """
    metadata = {"render_modes": []}

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Box(-10, 10, shape=(4,), dtype=np.float32)
        self.action_space = spaces.Box(-1, 1, shape=(4,), dtype=np.float32)
        self.state = np.zeros(4, dtype=np.float32)

    def reset(self, seed=None, options=None):
        """
        Draw a fresh state uniformly from [-1, 1)^4.

        Args:
            seed: optional seed forwarded to ``gym.Env.reset`` to reseed
                ``self.np_random``.
            options: accepted for API compatibility; unused.

        Returns:
            ``(observation, info)`` with a float32 observation and empty info.
        """
        super().reset(seed=seed)
        self.state = self.np_random.uniform(-1, 1, size=4).astype(np.float32)
        return self.state, {}

    def step(self, action):
        """
        Apply ``action`` plus Gaussian noise to the state.

        Args:
            action: array-like of length 4; coerced to float32.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where the
            observation is float32, the reward is a Python float, and both
            flags are always ``False``.
        """
        # Coerce to float32 so a float64 action cannot silently promote the
        # state away from the dtype declared in observation_space.
        action = np.asarray(action, dtype=np.float32)
        noise = 0.1 * self.np_random.standard_normal(4).astype(np.float32)

        # Rebind instead of updating in place: callers may still hold a
        # reference to the previous observation (e.g. inside a replay buffer).
        self.state = self.state + action + noise

        reward = -float(np.linalg.norm(self.state) ** 2)
        terminated = False  # no natural termination in this dummy task
        truncated = False   # no built-in time limit
        return self.state, reward, terminated, truncated, {}

# ------------------------
# 2) Replay Buffer
# ------------------------
class ReplayBuffer:
    """
    Fixed-capacity store of ``(s, a, r, s2)`` transitions.

    Implemented as a ring buffer over a plain list: once ``capacity``
    transitions are held, each new push overwrites the oldest one in O(1)
    (no ``list.pop(0)`` shifting), and uniform sampling stays O(batch_size).
    """

    def __init__(self, capacity=100000):
        """
        Args:
            capacity: maximum number of transitions retained; must be >= 1.

        Raises:
            ValueError: if ``capacity`` is not positive.
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.buffer = []
        self.cap    = capacity
        self._next  = 0  # slot that the next push writes to once full

    def push(self, s, a, r, s2):
        """Store one transition, evicting the oldest when at capacity."""
        transition = (s, a, r, s2)
        if len(self.buffer) < self.cap:
            self.buffer.append(transition)
        else:
            self.buffer[self._next] = transition
        # While filling, this tracks len(buffer); it wraps to 0 exactly when
        # the buffer becomes full, so slot 0 (the oldest) is overwritten first.
        self._next = (self._next + 1) % self.cap

    def sample(self, batch_size):
        """
        Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(s, a, r, s2)`` of arrays, each stacked along a new
            leading batch axis.

        Raises:
            ValueError: (from ``random.sample``) if ``batch_size`` exceeds
                the number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)
        s, a, r, s2 = map(np.stack, zip(*batch))
        return s, a, r, s2

    def __len__(self):
        return len(self.buffer)

# ------------------------
# 3) Network Definitions
# ------------------------
class QuantumGravityAI(nn.Module):
    """
    Actor network: a single ReLU hidden layer followed by a tanh-squashed
    linear output, so every output component lies in [-1, 1].
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names double as state_dict keys (fc1.*, fc2.*).
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Map a batch of states ``x`` to bounded actions."""
        hidden = self.fc1(x).relu()
        pre_activation = self.fc2(hidden)
        return pre_activation.tanh()

class Critic(nn.Module):
    """
    Q-network: concatenates state and action along the last axis, passes the
    result through one ReLU hidden layer, and emits a single scalar per row.
    """

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        joint_dim = state_dim + action_dim
        self.fc1 = nn.Linear(in_features=joint_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, s, a):
        """Return Q(s, a) with a trailing dimension of size 1."""
        joint = torch.cat((s, a), dim=-1)
        hidden = self.fc1(joint).relu()
        return self.fc2(hidden)

# ------------------------
# 4) DDPG Agent
# ------------------------
class DDPGAgent:
    """
    Deep Deterministic Policy Gradient agent.

    Holds an actor and a critic, a target copy of each that is blended toward
    the online weights by ``tau`` after every update, and a replay buffer of
    ``(s, a, r, s2)`` transitions.

    NOTE: stored transitions carry no terminal flag, so the critic target
    always bootstraps from the next state. That is only correct for tasks
    whose episodes never terminate naturally.
    """

    def __init__(self, env):
        """
        Args:
            env: object exposing ``observation_space.shape`` and
                ``action_space.shape``; their first entries set the state and
                action dimensions.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        s_dim = env.observation_space.shape[0]
        a_dim = env.action_space.shape[0]

        # actor & critic + targets
        self.actor         = QuantumGravityAI(s_dim, 16, a_dim).to(self.device)
        self.actor_target  = QuantumGravityAI(s_dim, 16, a_dim).to(self.device)
        self.critic        = Critic(s_dim, a_dim, 16).to(self.device)
        self.critic_target = Critic(s_dim, a_dim, 16).to(self.device)

        # targets start as exact copies of the online networks
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())

        # optimizers
        self.a_opt = optim.Adam(self.actor.parameters(), lr=1e-3)
        self.c_opt = optim.Adam(self.critic.parameters(), lr=1e-3)

        self.buffer = ReplayBuffer(50000)
        self.gamma  = 0.99   # discount factor
        self.tau    = 0.005  # target-network blend rate

    def _to_tensor(self, array):
        """Convert an array-like to a float32 tensor on the agent's device."""
        return torch.as_tensor(array, dtype=torch.float32, device=self.device)

    def _soft_update(self, online, target):
        """Blend ``target`` toward ``online``: pt <- (1 - tau) * pt + tau * p."""
        with torch.no_grad():
            for p, pt in zip(online.parameters(), target.parameters()):
                pt.copy_(pt * (1 - self.tau) + p * self.tau)

    def select_action(self, state, noise_scale=0.1):
        """
        Return the actor's action for ``state`` plus Gaussian exploration noise.

        Args:
            state: array-like of length ``s_dim``.
            noise_scale: standard deviation of the additive noise
                (0 disables exploration).

        Returns:
            float32 NumPy array clipped to [-1, 1].
        """
        obs = self._to_tensor(state).unsqueeze(0)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            a = self.actor(obs).squeeze(0).cpu().numpy()
        noisy = a + noise_scale * np.random.randn(*a.shape)
        # randn is float64; cast back so the action keeps the networks' dtype.
        return np.clip(noisy, -1, 1).astype(np.float32)

    def update(self, batch_size=64):
        """
        Run one critic step, one actor step, and one soft target update.

        Does nothing until the buffer holds at least ``batch_size``
        transitions.
        """
        if len(self.buffer) < batch_size:
            return

        s, a, r, s2 = self.buffer.sample(batch_size)
        s  = self._to_tensor(s)
        a  = self._to_tensor(a)
        r  = self._to_tensor(r).unsqueeze(1)  # (batch,) -> (batch, 1) to match Q
        s2 = self._to_tensor(s2)

        # Critic update: regress Q(s, a) onto the one-step TD target.
        with torch.no_grad():
            a2 = self.actor_target(s2)
            q2 = self.critic_target(s2, a2)
            y  = r + self.gamma * q2
        q1 = self.critic(s, a)
        c_loss = nn.functional.mse_loss(q1, y)
        self.c_opt.zero_grad()
        c_loss.backward()
        self.c_opt.step()

        # Actor update: ascend the critic's value of the actor's own actions.
        a_pred = self.actor(s)
        a_loss = -self.critic(s, a_pred).mean()
        self.a_opt.zero_grad()
        a_loss.backward()
        self.a_opt.step()

        self._soft_update(self.actor, self.actor_target)
        self._soft_update(self.critic, self.critic_target)

    def store(self, *args):
        """Forward a transition ``(s, a, r, s2)`` to the replay buffer."""
        self.buffer.push(*args)

# ------------------------
# 5) Training Loop
# ------------------------
def train(episodes=200, max_steps=100, log_every=10):
    """
    Train a DDPG agent on QuantumGravityEnv.

    Args:
        episodes: number of episodes to run.
        max_steps: cap on environment steps per episode.
        log_every: print the episode return every this many episodes;
            0 or None disables printing.

    Returns:
        List of per-episode returns (sum of rewards), one entry per episode.
    """
    env   = QuantumGravityEnv()
    agent = DDPGAgent(env)
    returns = []

    for ep in range(1, episodes + 1):
        # Seeding with the episode index makes each episode's start and noise
        # reproducible across runs.
        state, info = env.reset(seed=ep)
        ep_reward = 0.0

        for _ in range(max_steps):
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, info = env.step(action)

            agent.store(state, action, reward, next_state)
            agent.update()

            state = next_state
            ep_reward += float(reward)
            if terminated or truncated:
                break

        returns.append(ep_reward)
        if log_every and ep % log_every == 0:
            print(f"Episode {ep:03d} → Reward: {ep_reward:.2f}")

    return returns

# Entry point: start training only when this module is the program being run
# (a script or notebook cell), not when it is imported.
if __name__ == "__main__":
    train()