<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/casimir_train_log_fix_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# 1. True Casimir force (always negative)
def casimir_force(d, ε, *, hbar_c=1.0):
    """Return the idealized parallel-plate Casimir force per unit area.

    Evaluates ``-π² · ħc / (240 · d⁴) · ε``. The ε factor is a simple linear
    scaling used to build a synthetic dataset; it is not a Lifshitz-theory
    dielectric correction.

    Args:
        d: Plate separation (scalar or NumPy array). It must be non-zero.
        ε: Relative permittivity scale factor (scalar or array, broadcastable
            with ``d``). For ε > 0 (and hbar_c > 0) the result is negative.
        hbar_c: Keyword-only prefactor ħc. The default of 1.0 reproduces the
            original unit-less formula exactly. Pass the physical value
            (≈ 3.16e-26 J·m, with ``d`` in metres) to get pressures in Pa.

    Returns:
        The force per area, with the broadcast shape of ``d`` and ``ε``.
    """
    # Multiplying by hbar_c=1.0 is exact, so the default path is bit-identical
    # to the original `-(np.pi**2) / (240.0 * d**4) * ε`.
    return -(np.pi ** 2) * hbar_c / (240.0 * d ** 4) * ε

# 2. Build the synthetic dataset on a (permittivity × distance) grid.
#    np.meshgrid(distances, permittivities) yields arrays of shape
#    (n_perm, n_dist); X holds one (distance, permittivity) row per grid point.
n_dist = 100
n_perm = 100
distances = np.linspace(1e-8, 1e-6, num=n_dist)       # 10 nm → 1 µm
permittivities = np.linspace(1.0, 10.0, num=n_perm)  # εr = 1 → 10
D, E = np.meshgrid(distances, permittivities)
X = np.column_stack((D.ravel(), E.ravel()))
# Target as a column vector of shape (n_perm * n_dist, 1)
y = casimir_force(X[:, 0], X[:, 1])[:, np.newaxis]

# 3. Input features: log-distance and permittivity, each z-scored
#    (zero mean, unit population std) over the full grid.
log_d = np.log(X[:, 0])
log_d_norm = (log_d - np.mean(log_d)) / np.std(log_d)
ε_norm = (X[:, 1] - np.mean(X[:, 1])) / np.std(X[:, 1])
X_tensor = torch.tensor(
    np.column_stack((log_d_norm, ε_norm)),
    dtype=torch.float32,
)

# 4. Regression target: z-scored log-magnitude of the force.
#    mean_ly / std_ly are kept so predictions can be mapped back later.
y_abs = np.negative(y)          # |F|; every sample of y is negative
log_y = np.log(y_abs)           # log|F|
mean_ly = log_y.mean()
std_ly = log_y.std()
y_norm = (log_y - mean_ly) / std_ly
y_tensor = torch.tensor(y_norm, dtype=torch.float32)

# 5. Shuffled mini-batches of 256 (feature, target) pairs
dataset = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(dataset, shuffle=True, batch_size=256)

# 6. MLP definition
class NegCasimirMLP(nn.Module):
    """Feed-forward regressor: [Linear → LayerNorm → ReLU] per hidden width, then Linear.

    Args:
        input_dim: Number of input features.
        hidden_dims: Iterable of hidden-layer widths, in order. It may be empty,
            in which case the model is a single Linear(input_dim, output_dim).
        output_dim: Number of outputs.

    All layers live in ``self.net`` (an ``nn.Sequential``). Each hidden stage
    contributes three consecutive modules, so parameter names follow
    ``net.{index}.*``.
    """

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        stages = []
        # Consecutive width pairs give (fan_in, fan_out) for each hidden Linear
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(inplace=True),
            ))
        # Output projection from the last hidden width (or input_dim if none)
        stages.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the raw (unactivated) output."""
        return self.net(x)

# Model, optimiser and loss; the network regresses the z-scored log|F|.
model = NegCasimirMLP(
    input_dim=2,
    hidden_dims=[64, 64, 32],
    output_dim=1,
)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
criterion = nn.MSELoss()

# 7. Training loop: one full pass over `loader` per epoch. Every 20 epochs it
#    prints the sample-weighted mean MSE in normalised log-space.
n_epochs = 200
for epoch in range(1, n_epochs + 1):
    model.train()
    loss_total = 0.0
    for batch_x, batch_y in loader:
        optimizer.zero_grad()
        loss = criterion(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        # Weight by batch size because the final batch can be smaller than 256
        loss_total += loss.item() * batch_x.size(0)
    running_loss = loss_total / len(dataset)
    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d}  Normalized-Log MSE: {running_loss:.4f}")

# 8. Inference over the whole grid, then invert the target transform:
#    z-score → log|F| → |F| → signed (negative) force.
model.eval()
with torch.no_grad():
    y_pred_norm = model(X_tensor).cpu().numpy().ravel()
    log_pred = mean_ly + std_ly * y_pred_norm          # undo the z-score
    y_pred_force = np.negative(np.exp(log_pred))       # undo the log, restore the sign

# 9. Parity plot of |F| (predicted vs. true) on log-log axes
fig_parity, ax_parity = plt.subplots(figsize=(6,6))
ax_parity.scatter(np.abs(y), np.abs(y_pred_force), s=5, alpha=0.3)

# Identity reference line between the extreme true magnitudes.
# y is all-negative, so -min(y) is the largest |F| and -max(y) the smallest.
ref_span = [-y.min(), -y.max()]
ax_parity.plot(ref_span, ref_span, 'r--')

ax_parity.set_xscale('log')
ax_parity.set_yscale('log')
ax_parity.set_xlabel("True |Casimir Force|")
ax_parity.set_ylabel("Predicted |Casimir Force|")
ax_parity.set_title("Log-Log Fit of Absolute Forces")
fig_parity.tight_layout()
plt.show()

# 10. Surface heatmaps (absolute force with a shared LogNorm)
# Both grids are (n_perm, n_dist): X was built from a C-order ravel of
# np.meshgrid(distances, permittivities), so rows index permittivity and
# columns index distance, matching pcolormesh(x=distances, y=permittivities).
Z_true_abs = (-y).reshape(n_perm, n_dist)
Z_pred_abs = (-y_pred_force).reshape(n_perm, n_dist)

# Use one norm for both panels so a given colour means the same |force| in
# each. Separate LogNorm() instances autoscale independently, which would
# make the true and predicted panels visually incomparable. A log scale only
# accepts strictly positive values, so the limits are taken over those.
all_abs = np.concatenate([Z_true_abs.ravel(), Z_pred_abs.ravel()])
positive_abs = all_abs[all_abs > 0]
shared_norm = (
    LogNorm(vmin=positive_abs.min(), vmax=positive_abs.max())
    if positive_abs.size
    else LogNorm()
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,4))
im1 = ax1.pcolormesh(
    distances*1e6, permittivities,
    Z_true_abs, norm=shared_norm, shading='auto'
)
ax1.set_title("True |Force| Surface")
ax1.set_xlabel("Distance (µm)")
ax1.set_ylabel("Permittivity")
fig.colorbar(im1, ax=ax1)

im2 = ax2.pcolormesh(
    distances*1e6, permittivities,
    Z_pred_abs, norm=shared_norm, shading='auto'
)
ax2.set_title("Predicted |Force| Surface")
ax2.set_xlabel("Distance (µm)")
ax2.set_ylabel("Permittivity")
fig.colorbar(im2, ax=ax2)

plt.tight_layout()
plt.show()