In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
from scipy.integrate import solve_ivp
from google.colab import files

# Device: prefer a CUDA GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Common parameters shared by every case. How `ode_residuals` uses them:
#   a*, tha*, n  -> increasing Hill term in the species' own level:  a * X^n / (tha^n + X^n)
#   b*, thb*, m  -> decreasing Hill term in G and P:  b * thb^m / (thb^m + G^m * P^m)
#   k*           -> linear decay coefficient:  - k * X
b1, b2 = 1, 1
tha1, tha2 = 0.5, 0.5
thb1, thb2 = 0.07, 0.07
k1, k2 = 1, 1
n, m = 4, 1

# Seed every RNG the notebook uses so that network initialisation, and therefore
# the trained results, are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Time domain. t_eval (float64 NumPy) and t_train (float32 torch column of shape
# (n_points, 1)) hold the same time points.
t_start, t_end = 0.0, 10.0
n_points = 200
t_eval = np.linspace(t_start, t_end, n_points)
t_train = torch.linspace(t_start, t_end, n_points).view(-1, 1).to(device)

In [None]:

# PINN Model
class PINN(nn.Module):
    """Fully connected tanh network mapping a time value to the two states (G, P).

    Args:
        case_num: architecture selector. ``1`` builds a 1-128-128-64-2 MLP; any
            other value builds the larger 1-256-256-256-128-2 MLP.

    The first layer takes 1 input feature and the last layer emits 2 features,
    so inputs are presumably column tensors ``(N, 1)`` giving outputs ``(N, 2)``.
    """

    def __init__(self, case_num):
        super().__init__()
        # Layer widths from input to output; case 1 is the compact network.
        if case_num == 1:
            widths = [1, 128, 128, 64, 2]
        else:
            widths = [1, 256, 256, 256, 128, 2]

        # Linear + Tanh pairs, with no activation after the final (output) layer.
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks += [nn.Linear(fan_in, fan_out), nn.Tanh()]
        self.net = nn.Sequential(*blocks[:-1])

        # Xavier-normal weights and zero biases on every linear layer.
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                nn.init.constant_(layer.bias, 0)

    def forward(self, t):
        """Return the network output (last dimension 2: [G, P]) for time points ``t``."""
        return self.net(t)

# Physics Loss
def ode_residuals(model, t, a1, a2):
    """Evaluate the model and the residuals of the two coupled ODEs at times ``t``.

    The right-hand sides (with module-level constants b*, tha*, thb*, k*, n, m) are
        f1 = a1*G^n/(tha1^n + G^n) + b1*thb1^m/(thb1^m + G^m*P^m) - k1*G
        f2 = a2*P^n/(tha2^n + P^n) + b2*thb2^m/(thb2^m + G^m*P^m) - k2*P
    and the residuals are dG/dt - f1 and dP/dt - f2.

    Args:
        model: network whose output splits into two columns [G, P].
        t: time points; presumably a column tensor (N, 1) as built with
            ``.view(-1, 1)`` in this notebook. The caller's tensor is not modified.
        a1, a2: activation strengths for G and P.

    Returns:
        (res1, res2, G_pred, P_pred), each one column of the model output.
    """
    # Differentiate through a detached alias instead of setting requires_grad on the
    # caller's tensor; the previous in-place flip mutated the shared global t_train.
    t = t.detach().requires_grad_(True)
    G_pred, P_pred = model(t).split(1, dim=1)

    # create_graph=True keeps the derivatives in the graph so the residual loss
    # can be backpropagated into the model parameters.
    dG_dt = torch.autograd.grad(G_pred, t, torch.ones_like(G_pred), create_graph=True)[0]
    dP_dt = torch.autograd.grad(P_pred, t, torch.ones_like(P_pred), create_graph=True)[0]

    # G^m * P^m appears in the decreasing Hill term of both equations.
    GP_m = G_pred**m * P_pred**m
    f1 = (a1 * G_pred**n / (tha1**n + G_pred**n)) + (b1 * thb1**m / (thb1**m + GP_m)) - k1 * G_pred
    f2 = (a2 * P_pred**n / (tha2**n + P_pred**n)) + (b2 * thb2**m / (thb2**m + GP_m)) - k2 * P_pred

    res1 = dG_dt - f1
    res2 = dP_dt - f2

    return res1, res2, G_pred, P_pred

def train_pinn(G0, P0, a1, a2, case_name, case_num=None):
    """Train a PINN for one parameter set and return its predictions on ``t_train``.

    Args:
        G0, P0: initial values G(0) and P(0), enforced through the IC loss term.
        a1, a2: activation strengths forwarded to ``ode_residuals``.
        case_name: label used only in the progress printout.
        case_num: optional configuration selector (1 or 2). When None (default,
            the previous behaviour) it is inferred as ``2 if a1 > 1 else 1``.
            Case 2 uses the larger network, longer training, a curriculum over
            the time window and adaptive loss weighting.

    Returns:
        (G_pred, P_pred, elapsed_time, loss_history): NumPy predictions at the
        module-level ``t_train`` points, wall-clock training time in seconds, and
        the total loss recorded at every epoch.

    Raises:
        ValueError: if ``case_num`` is given and is not 1 or 2.
    """
    print(f"\n=== Training PINN for {case_name} ===")

    if case_num is None:
        case_num = 2 if a1 > 1 else 1
    if case_num not in (1, 2):
        raise ValueError(f"case_num must be 1 or 2, got {case_num!r}")

    model = PINN(case_num).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Case-specific hyperparameters.
    if case_num == 1:
        patience, min_lr = 500, 1e-5
        n_epochs = 30000
        w_ic = 10.0
        max_norm = 0.5
    else:
        patience, min_lr = 1000, 1e-6
        n_epochs = 50000
        w_ic = 5.0
        max_norm = 1.0
    w_phys = 1.0

    # Plateau-based LR decay. `verbose` is deprecated in recent PyTorch; the LR is
    # printed in the progress log below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=patience, min_lr=min_lr
    )

    loss_fn = nn.MSELoss()

    # Loop-invariant initial-condition inputs and targets, built once.
    t0 = torch.zeros(1, 1, dtype=torch.float32, device=device)
    G0_target = torch.tensor([[G0]], dtype=torch.float32, device=device)
    P0_target = torch.tensor([[P0]], dtype=torch.float32, device=device)

    # Curriculum (Case 2 only). Stage k of `curriculum_steps` trains on
    # [t_start, t_end * k / curriculum_steps] for `stage_len` epochs.
    use_curriculum = case_num == 2
    curriculum_steps = 5 if use_curriculum else 1
    stage_len = max(1, n_epochs // curriculum_steps)
    current_step = 0
    t_train_curr = t_train

    # Adaptive loss balancing (Case 2 only), enabled after a warm-up.
    use_adaptive = case_num == 2

    start_time = time.time()
    loss_history = []

    for epoch in range(n_epochs):
        optimizer.zero_grad()

        # Move to the next curriculum stage at its boundary. The window then
        # persists until the next boundary. Previously it was used only on the
        # boundary epoch itself, and all other epochs used the full domain.
        if use_curriculum and epoch % stage_len == 0 and current_step < curriculum_steps:
            current_step += 1
            t_end_curr = t_end * (current_step / curriculum_steps)
            t_train_curr = torch.linspace(t_start, t_end_curr, n_points).view(-1, 1).to(device)

        resG, resP, G_pred, P_pred = ode_residuals(model, t_train_curr, a1, a2)
        raw_phys = loss_fn(resG, torch.zeros_like(resG)) + loss_fn(resP, torch.zeros_like(resP))

        G0_pred, P0_pred = model(t0).split(1, dim=1)
        raw_ic = loss_fn(G0_pred, G0_target) + loss_fn(P0_pred, P0_target)

        # Balance the terms from the *unweighted* losses. Using the already-weighted
        # values fed each epoch's weights back into the next ratio and made them
        # compound. .item() keeps the weights out of the autograd graph.
        if use_adaptive and epoch > 1000:
            phys_val, ic_val = raw_phys.item(), raw_ic.item()
            w_phys = min(1.0, ic_val / (phys_val + 1e-8))
            w_ic = min(10.0, phys_val / (ic_val + 1e-8))

        loss_phys = w_phys * raw_phys
        loss_ic = w_ic * raw_ic
        loss = loss_phys + loss_ic
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

        optimizer.step()
        loss_val = loss.item()
        scheduler.step(loss_val)

        loss_history.append(loss_val)

        if epoch % 500 == 0:
            print(f"Epoch {epoch}, Loss: {loss_val:.6f}, LR: {optimizer.param_groups[0]['lr']:.6f}")
            print(f"  Physics Loss: {loss_phys.item():.6f}, IC Loss: {loss_ic.item():.6f}")
            if use_adaptive:
                print(f"  Weights - Physics: {w_phys:.2f}, IC: {w_ic:.2f}")

    elapsed_time = time.time() - start_time
    print(f"Training completed in {elapsed_time:.2f} seconds.")

    # Evaluate on the full training grid.
    model.eval()
    with torch.no_grad():
        pred = model(t_train).cpu().numpy()
    G_pred, P_pred = pred[:, 0], pred[:, 1]

    return G_pred, P_pred, elapsed_time, loss_history
results = {}
# Initial conditions and activation strengths (G0, P0, a1, a2) per case. The loop
# variables remain defined after the loop, holding the last case's values.
for ncase, (G0, P0, a1, a2) in {1: (1, 1, 1, 1), 2: (1, 1, 5, 10)}.items():
    case_name = f"Case {ncase} (a1={a1}, a2={a2})"
    G_pinn, P_pinn, time_pinn, loss_history = train_pinn(G0, P0, a1, a2, case_name)
    results[ncase] = {
        't': t_eval,
        'G_pinn': G_pinn,
        'P_pinn': P_pinn,
        'time_pinn': time_pinn,
        'loss_history': loss_history,
        'params': f"a1={a1}, a2={a2}",
    }
