# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import os
import time
from sklearn.metrics import mean_squared_error, mean_absolute_error

class PINNSolver:
    """Fit a physics-informed neural network (PINN) and compare it to a reference ODE solve.

    The state is a scale factor ``a(t)`` coupled to ``N_fields`` scalar fields
    ``phi_i(t)`` with masses ``m_vec``. Both ``ode_system`` (via ``solve_ivp``)
    and ``physics_loss`` (via autograd) encode the same equations::

        rho        = sum(0.5*phi_dot**2 + 0.5*m**2*phi**2)
                     + rho_m0/a**3 + rho_r0/a**4 + rho_l0
        da/dt      = sqrt((1/3) * a**2 * rho)
        d2phi/dt2  = -sqrt(3 * rho) * dphi/dt - m**2 * phi

    Typical use: ``solve_ode()``, ``train()``, ``optimize_lbfgs()``,
    ``evaluate()``, then ``evaluate_error_metrics()`` / ``plot_results()`` /
    ``plot_losses()``.

    Parameters
    ----------
    N_fields : int
        Number of scalar fields.
    m_vec, phi0, phi_dot0 : scalar or sequence of length ``N_fields``, optional
        Field masses, initial values and initial velocities
        (defaults 25.0, 1.0 and 0.0 for every field). Scalars are broadcast.
    rho_m0, rho_r0, rho_l0 : float
        Coefficients of the ``1/a**3``, ``1/a**4`` and constant density terms.
    a0 : float
        Initial scale factor at ``t_span[0]``.
    t_span : (float, float)
        Integration interval; also the PINN training interval.
    t_eval : array-like, optional
        Output times. Defaults to 1000 log-spaced points from ``a0`` to
        ``t_span[1]`` (float32).
    device : torch.device, optional
        Defaults to CUDA when available, else CPU.
    folder_name : str
        Directory where plots are written (created if missing).
    num_layers, num_neurons : int
        PINN architecture (see ``PINN``).
    """

    def __init__(self, N_fields=1, m_vec=None, rho_m0=0.81, rho_r0=0.00027138, rho_l0=2.19,
                 a0=1e-8, phi0=None, phi_dot0=None, t_span=(0.0, 1.0), t_eval=None,
                 device=None, folder_name='results', num_layers=4, num_neurons=200):

        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        self.N_fields = N_fields
        # Normalise scalars/lists/arrays to float64 arrays of shape (N_fields,).
        self.m_vec = self._as_field_vector(m_vec, 25.0, N_fields, 'm_vec')
        self.rho_m0 = rho_m0
        self.rho_r0 = rho_r0
        self.rho_l0 = rho_l0

        self.a0 = a0
        self.phi0 = self._as_field_vector(phi0, 1.0, N_fields, 'phi0')
        self.phi_dot0 = self._as_field_vector(phi_dot0, 0.0, N_fields, 'phi_dot0')
        # ODE state layout: [a, phi_1..phi_N, phi_dot_1..phi_dot_N].
        self.y0 = np.concatenate([[self.a0], self.phi0, self.phi_dot0])

        self.t_span = t_span
        if t_eval is None:
            # Log-spaced to match the log time axis used in plot_results.
            t_eval = np.logspace(np.log10(a0), np.log10(t_span[1]), 1000).astype(np.float32)
        self.t_eval = np.asarray(t_eval)

        self.folder_name = folder_name
        os.makedirs(self.folder_name, exist_ok=True)

        self.model = self.PINN(num_layers=num_layers, num_neurons=num_neurons,
                               N_fields=self.N_fields).to(self.device)
        # Every target tensor is float32 so it matches the network output dtype.
        f32 = dict(dtype=torch.float32, device=self.device)
        self.m_vec_torch = torch.tensor(self.m_vec, **f32)
        self.phi0_torch = torch.tensor(self.phi0.reshape(1, -1), **f32)
        self.phi_dot0_torch = torch.tensor(self.phi_dot0.reshape(1, -1), **f32)
        self.a0_torch = torch.tensor([[self.a0]], **f32)
        self.t0_torch = torch.tensor([[float(self.t_span[0])]], **f32)

    @staticmethod
    def _as_field_vector(value, default, n_fields, name):
        """Return ``value`` (or ``default`` when None) as a float64 array of shape (n_fields,).

        Raises ValueError if the value cannot be broadcast to that shape.
        """
        arr = np.asarray(default if value is None else value, dtype=float)
        try:
            return np.broadcast_to(arr, (n_fields,)).copy()
        except ValueError as err:
            raise ValueError(
                f"{name} must be a scalar or have length {n_fields}, got shape {arr.shape}"
            ) from err

    class PINN(nn.Module):
        """Fully connected network mapping a column of times to ``(a, phi)``.

        ``num_layers`` counts hidden-to-hidden layers, so there are
        ``num_layers + 1`` hidden layers of width ``num_neurons``. Each hidden
        layer uses the activation ``z * sin(z)``.
        """

        def __init__(self, num_layers, num_neurons, N_fields):
            super().__init__()
            self.layers = nn.ModuleList()
            self.layers.append(nn.Linear(1, num_neurons))
            for _ in range(num_layers):
                self.layers.append(nn.Linear(num_neurons, num_neurons))
            self.layers.append(nn.Linear(num_neurons, 1 + N_fields))

        def forward(self, t):
            """Return ``(a, phi)`` with shapes (N, 1) and (N, N_fields) for ``t`` of shape (N, 1)."""
            x = t
            for layer in self.layers[:-1]:
                z = layer(x)
                x = z * torch.sin(z)
            x = self.layers[-1](x)
            # softplus keeps a positive; the clamp guards the 1/a**k terms.
            a = torch.nn.functional.softplus(x[:, 0:1])
            # NOTE(review): this floor (1e-6) is above the default a0 (1e-8), so the
            # a(t0) = a0 condition can only be met approximately.
            a = torch.clamp(a, min=1e-6)
            phi = x[:, 1:]
            return a, phi

    def ode_system(self, t, y):
        """Right-hand side for ``solve_ivp``: return dy/dt for ``y = [a, phi, phi_dot]``.

        ``t`` is unused (the system is autonomous) but required by ``solve_ivp``.
        """
        N = self.N_fields
        a = y[0]
        phi = y[1:N+1]
        phi_dot = y[N+1:2*N+1]
        # a**2-weighted field energies: the radicand below equals a**2 * rho / 3.
        kinetic = 0.5 * np.sum((phi_dot * a)**2)
        potential = 0.5 * np.sum((self.m_vec**2) * (phi * a)**2)
        a_dot = np.sqrt((1/3) * (kinetic + potential + self.rho_m0 / a + self.rho_r0 / a**2 + self.rho_l0 * a**2))
        rho = (0.5 * np.sum(phi_dot**2) + 0.5 * np.sum((self.m_vec**2) * phi**2)
               + self.rho_m0 / a**3 + self.rho_r0 / a**4 + self.rho_l0)
        # Friction coefficient sqrt(3) * sqrt(rho) on phi_dot, plus the mass term.
        phi_ddot = -np.sqrt(3) * np.sqrt(rho) * phi_dot - (self.m_vec**2) * phi

        dydt = np.zeros_like(y)
        dydt[0] = a_dot
        dydt[1:N+1] = phi_dot
        dydt[N+1:2*N+1] = phi_ddot
        return dydt

    def solve_ode(self, method='RK45'):
        """Integrate ``ode_system`` over ``t_span`` at ``t_eval``.

        Stores ``a_sol`` (shape (T,)), ``phi_sol`` and ``phi_dot_sol``
        (shape (N_fields, T)).

        Raises RuntimeError if ``solve_ivp`` reports failure.
        """
        print("Solving ODE...")
        start_ode = time.time()
        # float32 rounding of t_eval can nudge points just outside t_span,
        # which solve_ivp rejects; clip them back in.
        t_lo, t_hi = min(self.t_span), max(self.t_span)
        t_eval = np.clip(np.asarray(self.t_eval, dtype=np.float64), t_lo, t_hi)
        sol = solve_ivp(self.ode_system, self.t_span, self.y0, t_eval=t_eval, method=method)
        print(f"ODE solve time: {time.time() - start_ode:.2f} sec")
        if not sol.success:
            raise RuntimeError(f"ODE integration failed: {sol.message}")
        N = self.N_fields
        self.a_sol = sol.y[0, :]
        self.phi_sol = sol.y[1:1 + N, :]
        self.phi_dot_sol = sol.y[1 + N:1 + 2 * N, :]

    @staticmethod
    def _column_grad(y, t):
        """Return d y[:, i] / dt for every column i of ``y`` as a tensor shaped like ``y``.

        Each column is differentiated on its own. Passing the whole tensor with
        ``ones`` grad_outputs would return the SUM of the column derivatives.
        Columns that do not depend on ``t`` get zero derivatives.
        """
        cols = []
        for i in range(y.shape[1]):
            yi = y[:, i:i + 1]
            if not yi.requires_grad:
                cols.append(torch.zeros_like(yi))
                continue
            g = torch.autograd.grad(yi, t, torch.ones_like(yi),
                                    create_graph=True, allow_unused=True)[0]
            cols.append(torch.zeros_like(yi) if g is None else g)
        return torch.cat(cols, dim=1) if cols else torch.zeros_like(y)

    def _collocation_points(self, N_f):
        """Return ``N_f`` evenly spaced times over ``t_span`` as an (N_f, 1) leaf that requires grad."""
        t_lo, t_hi = float(self.t_span[0]), float(self.t_span[1])
        return torch.linspace(t_lo, t_hi, N_f, device=self.device).reshape(-1, 1).requires_grad_()

    def physics_loss(self, model, t):
        """Mean squared residual of the ``da/dt`` and field equations at times ``t``.

        Parameters
        ----------
        model : callable
            Maps ``t`` (N, 1) to ``(a, phi)`` with shapes (N, 1) and (N, N_fields).
        t : torch.Tensor
            Collocation times, shape (N, 1), with ``requires_grad=True``.

        Returns
        -------
        torch.Tensor
            Scalar ``mean(Friedmann**2) + mean(KG**2)``.
        """
        a, phi = model(t)
        a_t = self._column_grad(a, t)
        # Per-field derivatives; required for correctness when N_fields > 1.
        phi_t = self._column_grad(phi, t)
        phi_tt = self._column_grad(phi_t, t)

        m2 = self.m_vec_torch**2
        kinetic = 0.5 * torch.sum((phi_t**2) * (a**2), dim=1, keepdim=True)
        potential = 0.5 * torch.sum(m2 * (phi**2) * (a**2), dim=1, keepdim=True)
        # The 1e-12 keeps sqrt differentiable at zero.
        friedmann = a_t - torch.sqrt(
            (1/3) * (kinetic + potential + self.rho_m0/a + self.rho_r0/a**2 + self.rho_l0 * a**2) + 1e-12
        )
        rho = (0.5 * torch.sum(phi_t**2, dim=1, keepdim=True)
               + 0.5 * torch.sum(m2 * phi**2, dim=1, keepdim=True)
               + self.rho_m0/a**3 + self.rho_r0/a**4 + self.rho_l0)
        sqsumrho = 3.0 ** 0.5 * torch.sqrt(rho + 1e-12)
        kg = phi_tt + sqsumrho * phi_t + phi * m2
        return torch.mean(friedmann**2) + torch.mean(kg**2)

    def initial_loss(self, model, phi_dot_weight=30.0):
        """Weighted squared mismatch of a, phi and dphi/dt at ``t_span[0]``.

        The dphi/dt term is needed because the field equation is second order:
        without it the PINN solution is not pinned to ``phi_dot0``.
        """
        t0 = self.t0_torch.detach().clone().requires_grad_(True)
        a_pred0, phi_pred0 = model(t0)
        phi_dot_pred0 = self._column_grad(phi_pred0, t0)
        return (4.0 * torch.mean((a_pred0 - self.a0_torch)**2)
                + 30.0 * torch.mean((phi_pred0 - self.phi0_torch)**2)
                + phi_dot_weight * torch.mean((phi_dot_pred0 - self.phi_dot0_torch)**2))

    def train(self, max_epochs_adam=10000, physics_weight=10.0, ic_weight=300.0, N_f=1000, print_every=500):
        """Train with Adam plus a StepLR schedule, recording per-epoch loss histories.

        ``print_every`` <= 0 disables periodic logging; the final epoch is
        always printed.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.4)

        self.loss_history, self.physics_loss_history, self.ic_loss_history = [], [], []
        print("Training PINN...")
        start = time.time()

        for epoch in range(max_epochs_adam):
            # A fresh leaf each epoch, so backward() never accumulates into a shared t.grad.
            t_f = self._collocation_points(N_f)
            optimizer.zero_grad()
            loss_physics_val = self.physics_loss(self.model, t_f)
            loss_ic_val = self.initial_loss(self.model)
            loss = physics_weight * loss_physics_val + ic_weight * loss_ic_val
            loss.backward()
            optimizer.step()
            scheduler.step()

            self.loss_history.append(loss.item())
            self.physics_loss_history.append(loss_physics_val.item())
            self.ic_loss_history.append(loss_ic_val.item())

            periodic = print_every > 0 and epoch % print_every == 0
            if periodic or epoch == max_epochs_adam - 1:
                percent = 100 * epoch / max_epochs_adam
                print(f"[Adam] Epoch {epoch:5d}/{max_epochs_adam} ({percent:5.1f}%) | "
                      f"Total Loss: {loss.item():.3e} | Physics: {loss_physics_val.item():.3e} | IC: {loss_ic_val.item():.3e}")

        print(f"Training time: {time.time() - start:.2f} sec")

    def optimize_lbfgs(self, physics_weight=10.0, ic_weight=300.0, N_f=1000, max_iter=5000):
        """Refine the model with L-BFGS (strong Wolfe line search) on the same weighted loss as ``train``."""
        print("Starting LBFGS optimization...")
        optimizer_lbfgs = optim.LBFGS(self.model.parameters(), lr=1.0, max_iter=max_iter, tolerance_grad=1e-9,
                                      tolerance_change=1e-10, history_size=100, line_search_fn='strong_wolfe')

        def closure():
            optimizer_lbfgs.zero_grad()
            t_f = self._collocation_points(N_f)
            loss = physics_weight * self.physics_loss(self.model, t_f) + ic_weight * self.initial_loss(self.model)
            loss.backward()
            return loss

        optimizer_lbfgs.step(closure)

    def evaluate(self):
        """Run the model on ``t_eval``.

        Stores ``a_pred`` (shape (T,)) and ``phi_pred`` (shape (N_fields, T)).
        """
        # Explicit float32: the network weights are float32 even if t_eval is float64.
        t_plot = torch.tensor(np.asarray(self.t_eval), dtype=torch.float32, device=self.device).reshape(-1, 1)
        with torch.no_grad():
            a_pred, phi_pred = self.model(t_plot)
        self.a_pred = a_pred.cpu().numpy().flatten()
        self.phi_pred = phi_pred.cpu().numpy().T

    def evaluate_error_metrics(self):
        """Print and return MSE/MAE of the PINN against the ODE solution.

        Requires ``solve_ode()`` and ``evaluate()`` to have run. Returns a dict
        with keys ``mse_a``, ``mae_a``, ``mse_phi``, ``mae_phi`` (the last two are
        per-field lists).
        """
        mse_a = mean_squared_error(self.a_sol, self.a_pred)
        mae_a = mean_absolute_error(self.a_sol, self.a_pred)
        mse_phi = [mean_squared_error(self.phi_sol[i], self.phi_pred[i]) for i in range(self.N_fields)]
        mae_phi = [mean_absolute_error(self.phi_sol[i], self.phi_pred[i]) for i in range(self.N_fields)]

        print("\nEvaluation Metrics:")
        print(f"  a(t)       -> MSE: {mse_a:.4e}, MAE: {mae_a:.4e}")
        for i, (mse_p, mae_p) in enumerate(zip(mse_phi, mae_phi)):
            print(f"  phi[{i}](t) -> MSE: {mse_p:.4e}, MAE: {mae_p:.4e}")
        return {'mse_a': mse_a, 'mae_a': mae_a, 'mse_phi': mse_phi, 'mae_phi': mae_phi}

    def plot_results(self):
        """Plot ODE vs PINN curves for a and up to the first 5 fields.

        The x axis is logarithmic. The figure is saved as
        ``<folder_name>/comparison.png``.
        """
        fig = plt.figure(dpi=120)
        plt.plot(self.t_eval, self.a_sol, 'k--', label='ODE a(t)')
        plt.plot(self.t_eval, self.a_pred, 'r-', label='PINNs a(t)')
        for i in range(min(self.N_fields, 5)):
            plt.plot(self.t_eval, self.phi_sol[i, :], 'k--', alpha=0.5, label=f'ODE phi[{i}](t)')
            plt.plot(self.t_eval, self.phi_pred[i, :], 'r-', alpha=0.5, label=f'PINNs phi[{i}](t)')
        plt.xscale('log')
        plt.xlabel('t')
        plt.ylabel('Value')
        plt.title('Comparison: ODE vs PINNs')
        plt.legend()
        plt.grid()
        plt.savefig(os.path.join(self.folder_name, 'comparison.png'))
        plt.show()
        plt.close(fig)

    def plot_losses(self):
        """Plot the Adam loss histories on a log scale and save them as ``loss_curves.png``."""
        fig = plt.figure(dpi=120)
        plt.plot(self.loss_history, label='Total Loss')
        plt.plot(self.physics_loss_history, label='Physics Loss')
        plt.plot(self.ic_loss_history, label='IC Loss')
        plt.yscale('log')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Loss Curves During Training')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.folder_name, 'loss_curves.png'))
        plt.show()
        plt.close(fig)

if __name__ == "__main__":
    # End-to-end run: reference ODE, Adam then L-BFGS training, evaluation, plots.
    solver = PINNSolver()
    pipeline = (
        solver.solve_ode,
        solver.train,
        solver.optimize_lbfgs,
        solver.evaluate,
        solver.evaluate_error_metrics,
        solver.plot_results,
        solver.plot_losses,
    )
    for stage in pipeline:
        stage()
