<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/Complete_Physics_Informed_Time_Dilation_Estimator_with_MC_Dropout.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

# 1. Reproducibility & Device
# Fix torch's global RNG so every random draw below (data sampling, the
# train/val split, weight init, dropout masks) repeats across runs.
torch.manual_seed(42)

# Use the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# 2. Physics Functions & Synthetic Dataset
c = 299_792_458.0     # speed of light in vacuum (m/s), exact SI value
G = 6.674e-11         # gravitational constant (m^3/kg/s^2)

def compute_time_dilation(mass, velocity, radius):
    """Return the combined kinematic + gravitational time-dilation factor dt/dtau.

    Both effects make a clock run slow relative to a distant static observer,
    so their dilation factors compound (multiply) in the same direction:

        dt/dtau = gamma_v / sqrt(1 - r_s / r)
        gamma_v = 1 / sqrt(1 - v^2 / c^2),   r_s = 2 G M / c^2

    For non-negative mass and positive radius the result is >= 1.

    Args:
        mass: central mass in kg (tensor or Python number).
        velocity: clock speed in m/s (tensor or Python number).
        radius: distance from the mass in m (tensor or Python number).
            Tensor arguments are combined with standard torch broadcasting.

    Returns:
        Tensor of dilation factors with the broadcast shape of the inputs.
    """
    mass, velocity, radius = (torch.as_tensor(t) for t in (mass, velocity, radius))

    # Clamp both ratios just below 1 so superluminal speeds or radii inside
    # the Schwarzschild radius give a large finite factor instead of NaN/inf.
    beta2 = torch.clamp((velocity / c) ** 2, max=0.9999)
    rs_over_r = torch.clamp(2 * G * mass / (radius * c**2), max=0.9999)

    gamma_v = 1.0 / torch.sqrt(1 - beta2)
    # Gravitational dilation is 1/sqrt(1 - r_s/r). Multiplying by
    # sqrt(1 - r_s/r) (a proper-time *rate*) would partially cancel gamma_v.
    gamma_g = 1.0 / torch.sqrt(1 - rs_over_r)
    return gamma_v * gamma_g

class TimeDilationDataset(Dataset):
    """Synthetic regression set mapping (mass, velocity, radius) to dilation.

    Features are drawn uniformly with torch.rand:
        mass   in [1e23, 1e27) kg
        v      in [0, 0.9c) m/s
        radius in [1e6, 1e8) m
    Targets are compute_time_dilation(mass, v, radius), stored with shape
    (n_samples, 1).
    """

    def __init__(self, n_samples=5000):
        super().__init__()
        mass_lo, mass_hi = 1e23, 1e27
        radius_lo, radius_hi = 1e6, 1e8
        # The draw order (mass, v, radius) determines the exact samples
        # produced under a fixed global seed.
        self.mass = torch.rand(n_samples) * (mass_hi - mass_lo) + mass_lo
        self.v = torch.rand(n_samples) * 0.9 * c
        self.radius = torch.rand(n_samples) * (radius_hi - radius_lo) + radius_lo
        self.y_true = compute_time_dilation(self.mass, self.v, self.radius)[:, None]

    def __len__(self):
        return self.mass.shape[0]

    def __getitem__(self, idx):
        """Return (features[3], target[1]) as float32 tensors."""
        features = torch.stack(
            (self.mass[idx], self.v[idx], self.radius[idx])
        ).float()
        target = self.y_true[idx].float()
        return features, target

# 3. Model with MC-Dropout
class TimeDilationAI(nn.Module):
    """Two-layer MLP regressor whose dropout layer enables MC-Dropout.

    The raw features used in this file span very different magnitudes (mass
    up to ~1e27 kg). Fed unscaled into fc1, they produce activations and
    squared errors that overflow float32, so every input feature is divided
    by a fixed per-feature scale at the start of forward().

    Args:
        input_dim: number of input features.
        hidden_dim: width of the hidden layer.
        output_dim: number of outputs.
        dropout_p: dropout probability applied after the hidden ReLU.
        input_scale: optional sequence of `input_dim` positive divisors.
            Defaults to (1e27, 3e8, 1e8) when input_dim == 3. This assumes
            the columns are (mass [kg], velocity [m/s], radius [m]) as
            produced by TimeDilationDataset. Otherwise it defaults to all
            ones (no scaling).

    Raises:
        ValueError: if input_scale does not have exactly input_dim entries.
    """

    _DEFAULT_INPUT_SCALE = (1e27, 3e8, 1e8)

    def __init__(self, input_dim=3, hidden_dim=32, output_dim=1, dropout_p=0.2,
                 input_scale=None):
        super().__init__()
        if input_scale is None:
            input_scale = (self._DEFAULT_INPUT_SCALE if input_dim == 3
                           else (1.0,) * input_dim)
        scale = torch.as_tensor(input_scale, dtype=torch.float32)
        if scale.shape != (input_dim,):
            raise ValueError(
                f"input_scale must have {input_dim} entries, got shape {tuple(scale.shape)}"
            )
        # Non-persistent: moves with .to(device) but keeps state_dict keys
        # identical, so previously saved checkpoints still load.
        self.register_buffer("input_scale", scale, persistent=False)

        self.fc1     = nn.Linear(input_dim, hidden_dim)
        self.relu    = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_p)
        self.fc2     = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = x / self.input_scale  # bring raw SI features to O(1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

# 4. Physics-Informed Loss
mse_loss = nn.MSELoss()

def physics_informed_loss(pred, true, alpha=0.1, lower_bound=0.0):
    """Return data-fit MSE plus a weighted physical-plausibility penalty.

    Only predictions and targets are available here. A squared residual
    against `true` would just repeat the data term, so the physics term is
    instead a one-sided penalty on predictions below a physically admissible
    floor. A time-dilation factor is a ratio of elapsed times and can never
    be <= 0; pass lower_bound=1.0 when targets are dilation factors
    dt/dtau >= 1.

    Args:
        pred: model predictions.
        true: targets with the same shape as pred.
        alpha: weight of the physics penalty.
        lower_bound: smallest physically admissible prediction.

    Returns:
        Scalar loss tensor.
    """
    data_term = mse_loss(pred, true)
    # Zero when pred >= lower_bound, quadratic in the violation below it.
    physics_term = torch.mean(torch.relu(lower_bound - pred) ** 2)
    return data_term + alpha * physics_term

# 5. Training Loop
dataset    = TimeDilationDataset(n_samples=8000)
# 80/20 train/validation split computed from the dataset size (6400/1600 here),
# so changing n_samples cannot make random_split's lengths inconsistent.
n_val      = len(dataset) // 5
train_set, val_set = torch.utils.data.random_split(dataset, [len(dataset) - n_val, n_val])
train_dl   = DataLoader(train_set, batch_size=64, shuffle=True)
val_dl     = DataLoader(val_set, batch_size=64)

model     = TimeDilationAI().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate after 5 epochs without validation improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

best_val_loss = float('inf')
best_state    = None   # in-memory copy of the best weights seen in THIS run
patience_cnt  = 0
max_patience  = 10     # epochs without improvement before early stopping

for epoch in range(1, 101):
    model.train()
    train_losses = []
    for x_batch, y_batch in train_dl:
        x, y = x_batch.to(device), y_batch.to(device)
        pred = model(x)
        loss = physics_informed_loss(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    # Validation with dropout disabled and no autograd bookkeeping.
    model.eval()
    val_losses = []
    with torch.no_grad():
        for x_batch, y_batch in val_dl:
            x, y = x_batch.to(device), y_batch.to(device)
            pred = model(x)
            loss = physics_informed_loss(pred, y)
            val_losses.append(loss.item())

    train_loss = np.mean(train_losses)
    val_loss   = np.mean(val_losses)
    scheduler.step(val_loss)

    print(f"Epoch {epoch:03d} | Train {train_loss:.5f} | Val {val_loss:.5f}")

    # Early stopping. A NaN val_loss never compares as an improvement, so it
    # counts toward patience like any other non-improving epoch.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_cnt  = 0
        best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
        torch.save(best_state, "best_model.pt")
    else:
        patience_cnt += 1
        if patience_cnt >= max_patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Restore the best weights from memory rather than re-reading "best_model.pt":
# if no epoch improved, that file is either missing (torch.load would raise)
# or a stale checkpoint left over from an earlier run.
if best_state is not None:
    model.load_state_dict(best_state)
else:
    print("Warning: validation loss never improved; keeping final weights")

# 6. MC-Dropout Inference
def mc_dropout_predict(model, x, mc_runs=100):
    """Monte-Carlo Dropout: mean and spread over stochastic forward passes.

    The whole model is put in training mode so dropout stays active. This
    also affects any other mode-dependent layer, e.g. BatchNorm would use
    batch statistics. Every submodule's original mode is restored
    afterwards, even if a forward pass raises. Sampling runs under
    torch.no_grad(), so the returned tensors carry no autograd graph.

    Args:
        model: module to sample from.
        x: input batch accepted by model.
        mc_runs: number of stochastic forward passes (must be >= 1).

    Returns:
        (mean, std) over the mc_runs passes, each shaped like model(x).
        std uses torch's default (unbiased) estimator. It is all zeros when
        mc_runs == 1, where that estimator would otherwise give NaN.

    Raises:
        ValueError: if mc_runs < 1.
    """
    if mc_runs < 1:
        raise ValueError(f"mc_runs must be >= 1, got {mc_runs}")

    # Remember each submodule's mode so it can be restored exactly afterwards.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.train()  # keep dropout active
    try:
        with torch.no_grad():
            preds = torch.stack([model(x) for _ in range(mc_runs)], dim=0)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    mean = preds.mean(dim=0)
    std = preds.std(dim=0) if mc_runs > 1 else torch.zeros_like(mean)
    return mean, std

# 7. Plot Predictions vs True
# Deterministic (dropout-off) predictions on the first validation batch only.
model.eval()
x_val, y_val = next(iter(val_dl))
x_val = x_val.to(device)
y_val = y_val.to(device)
with torch.no_grad():
    y_pred = model(x_val)
y_pred = y_pred.cpu().numpy().flatten()
y_true = y_val.cpu().numpy().flatten()

# Identity line across the true range: points on it are perfect predictions.
diag_lo, diag_hi = y_true.min(), y_true.max()

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(y_true, y_pred, alpha=0.5)
ax.plot([diag_lo, diag_hi], [diag_lo, diag_hi], 'r--', lw=2)
ax.set_xlabel("True Dilation")
ax.set_ylabel("Predicted Dilation")
ax.set_title("Model Predictions vs True Values")
fig.tight_layout()
plt.show()

# 8. Uncertainty Heatmap over (mass, velocity)
mass_vals = torch.linspace(1e23, 1e27, 30)
vel_vals  = torch.linspace(0, 0.9*c, 30)
radius    = 1e7  # fixed radius (m) for every grid point

# indexing='ij': M and V have shape (len(mass_vals), len(vel_vals)),
# i.e. rows follow mass and columns follow velocity.
M, V = torch.meshgrid(mass_vals, vel_vals, indexing='ij')
X_grid = torch.stack([M.flatten(), V.flatten(), torch.full_like(M, radius).flatten()], dim=1)
X_grid = X_grid.to(device)

with torch.no_grad():
    mean_pred, std_pred = mc_dropout_predict(model, X_grid, mc_runs=100)
# Return the model to eval mode (dropout off) after MC sampling so any later
# use of `model` in this session gives deterministic predictions.
model.eval()

STD = std_pred.cpu().numpy().reshape(M.shape)

plt.figure(figsize=(6,5))
# pcolormesh(X, Y, C) expects C shaped (len(Y), len(X)) = (mass, velocity),
# which matches STD's row/column layout above.
plt.pcolormesh((vel_vals / c).numpy(), mass_vals.numpy(), STD, shading='auto', cmap='viridis')
plt.colorbar(label="Std of Dilation")
plt.xlabel("Velocity / c")
plt.ylabel("Mass (kg)")
plt.title("Prediction Uncertainty Heatmap")
plt.tight_layout()
plt.show()