# In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from functools import partial
from scipy.integrate import odeint
from sympy import symbols, Eq, solve, Function, Matrix, diff

# Constants and Parameters
t_end = 20        # right end of the fast-system grids t1, t2
T_slow_end = 1    # right end of the slow-system grids tau1, tau2, tau3
eps = 0.01

# 100 evenly spaced points per grid; every grid is a separate array.
t1, t2 = (np.linspace(0, t_end, 100) for _ in range(2))
tau1, tau2, tau3 = (np.linspace(0, T_slow_end, 100) for _ in range(3))

# Convert each grid to a float64 column tensor of shape (100, 1).
tau1_tensor, t1_tensor, tau2_tensor, t2_tensor, tau3_tensor = (
    torch.tensor(grid.reshape(-1, 1), dtype=torch.float64)
    for grid in (tau1, t1, tau2, t2, tau3)
)

# Model parameters
num_nrn = 7    # width of each PINN hidden layer
z1 = 1.0       # coefficient of c1 in the residuals (presumably its valence)
z2 = -1.0      # coefficient of c2 in the residuals (presumably its valence)
V = -10
l = 1.0
r = 0.5

# Outermost conditions: the *_init values feed the first slow system,
# the *_end values the last one.
phi_init = V
c1_init = c2_init = l
w_init = 0.0
phi_end = 0.0
c1_end = c2_end = r
w_end = 1.0

# Neural Network Models
class PINN(nn.Module):
    """Generic PINN model for both slow and fast systems.

    A fully connected network ``n_in -> n_hidden -> n_hidden -> n_out`` with a
    tanh after each of the two hidden layers and a linear output layer.

    Args:
        n_hidden: hidden-layer width. Defaults to the module-level
            ``num_nrn``, looked up when the model is constructed.
        n_in: number of input features (default 1: a single coordinate).
        n_out: number of outputs (default 7, the columns
            phi, u, c1, c2, j1, j2, w consumed by the residual code).
    """

    def __init__(self, n_hidden=None, n_in=1, n_out=7):
        super().__init__()
        if n_hidden is None:
            n_hidden = num_nrn
        # Attribute names fc1..fc3 are kept so existing state_dicts still load.
        self.fc1 = nn.Linear(n_in, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_out)

    def forward(self, x):
        """Map inputs of shape (N, n_in) to outputs of shape (N, n_out)."""
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.fc3(x)

# Initialize models: one float64 PINN per sub-system. They are created in
# this order, so parameter initialization consumes the RNG in that sequence.
model_slow1, model_fast1, model_slow2, model_fast2, model_slow3 = (
    PINN().double() for _ in range(5)
)

# Initialize coupling variables, each to an independent random draw from
# the half-open interval [0, 1).
# TODO: initialize all remaining coupling variables the same way.
coupling_vars = {
    name: np.random.uniform(0, 1)
    for name in ('c1_s_a_l', 'c2_s_a_l', 'phi_s_a_l')
}

# Loss configuration: relative weights of the physics-residual,
# initial-condition and boundary-condition terms.
phys_weight, init_weight, bndry_weight = 3, 1, 1

def compute_residuals(model, inputs, system_type, eps=None):
    """Compute physics residuals for given system type.

    The model's output columns are read, in order, as
    ``phi, u, c1, c2, j1, j2, w``. Each is differentiated with respect to the
    single input coordinate and combined into the residuals of the chosen
    system.

    Args:
        model: network mapping ``inputs`` to an ``(N, 7)`` prediction. It is
            assumed to act row-wise (as ``PINN`` does), so the gradient of the
            summed output equals the per-sample derivative.
        inputs: ``(N, 1)`` tensor of collocation points. It is not modified;
            a detached copy that requires grad is differentiated instead.
        system_type: ``'slow'`` (8 residuals, the last one ``z1*c1 + z2*c2``)
            or ``'fast'`` (7 residuals).
        eps: small parameter of the ``'fast'`` residuals; required there,
            ignored for ``'slow'``.

    Returns:
        ``(residuals, non_neg, pred)``:
        - ``residuals``: a list of shape-``(N,)`` tensors.
        - ``non_neg``: ``[clamp(-c1, min=0), clamp(-c2, min=0)]``, the
          violations of the non-negativity constraints.
        - ``pred``: the raw ``(N, 7)`` prediction.

    Raises:
        ValueError: if ``system_type`` is unknown, or if ``eps`` is None for
            the fast system.
    """
    if system_type not in ('slow', 'fast'):
        raise ValueError(
            f"system_type must be 'slow' or 'fast', got {system_type!r}")
    if system_type == 'fast' and eps is None:
        raise ValueError("eps is required for the 'fast' system")

    # Differentiate w.r.t. a detached leaf copy. This leaves the caller's
    # tensor untouched (no requires_grad side effect) and also works for
    # non-leaf inputs.
    x = inputs.detach().requires_grad_(True)
    pred = model(x)
    phi, u = pred[:, 0], pred[:, 1]
    c1, c2 = pred[:, 2], pred[:, 3]
    j1, j2, w = pred[:, 4], pred[:, 5], pred[:, 6]

    def d_dx(var):
        """Return d(var)/dx with the same (N,) shape as ``var``."""
        # Unlike .backward(), autograd.grad does not accumulate into the
        # model parameters' .grad. create_graph=True keeps the derivative
        # differentiable, so a loss built on the residuals can be trained.
        (grad,) = torch.autograd.grad(
            var, x, grad_outputs=torch.ones_like(var),
            create_graph=True, allow_unused=True)
        if grad is None:  # this output does not depend on the input
            return torch.zeros_like(var)
        # (N, 1) -> (N,). Without this, combining it with an (N,) output
        # would broadcast the residual into an (N, N) matrix.
        return grad.squeeze(-1)

    # System-specific residuals
    if system_type == 'slow':
        p = -(z1*j1 + z2*j2) / (z1*(z1-z2)*c1)
        residuals = [
            u,  # residual1
            d_dx(phi) - p,
            d_dx(c1) + z1*c1*p + j1,
            d_dx(c2) + z2*c2*p + j2,
            d_dx(j1),
            d_dx(j2),
            d_dx(w) - 1,
            z1*c1 + z2*c2,
        ]
    else:  # 'fast'
        residuals = [
            d_dx(phi) - u,
            d_dx(u) + z1*c1 + z2*c2,
            d_dx(c1) + z1*c1*u + eps*j1,
            d_dx(c2) + z2*c2*u + eps*j2,
            d_dx(j1),
            d_dx(j2),
            d_dx(w) - eps,
        ]

    # Non-negativity constraints
    non_neg = [torch.clamp(-c1, min=0), torch.clamp(-c2, min=0)]
    return residuals, non_neg, pred

# Prediction columns (phi, u, c1, c2, j1, j2, w = columns 0..6) constrained by
# an init/boundary condition tuple, keyed by the tuple's length.
# NOTE(review): inferred from the call sites, which pass (phi, c1, c2, w) or
# (c1, c2, w) tuples -- confirm.
_CONDITION_COLUMNS = {4: (0, 2, 3, 6), 3: (2, 3, 6)}


def _condition_columns(conds):
    """Return the prediction columns that a condition tuple constrains."""
    try:
        return _CONDITION_COLUMNS[len(conds)]
    except KeyError:
        raise ValueError(
            f"expected 3 (c1, c2, w) or 4 (phi, c1, c2, w) condition values, "
            f"got {len(conds)}") from None


def _condition_loss(row, cols, conds):
    """Sum of squared differences between ``row[cols]`` and ``conds``."""
    target = torch.as_tensor(conds, dtype=row.dtype, device=row.device)
    return torch.sum((row[list(cols)] - target) ** 2)


def create_loss_func(system_type, init_conds, boundary_conds, eps=None):
    """Factory function to create loss functions.

    The returned ``loss_func(model, inputs, eps=None)`` returns
    ``phys_weight * physics + init_weight * init + bndry_weight * boundary``,
    reading the three weights from module globals at call time:

    - physics: the sum over residuals of their mean square, plus the mean
      square of the non-negativity violations.
    - init: the squared mismatch between the prediction at ``inputs[0]`` and
      ``init_conds``.
    - boundary: the squared mismatch between the prediction at
      ``inputs[-1]`` and ``boundary_conds``.

    The init and boundary terms assume ``inputs`` is ordered from the left end
    of the domain to the right end, as the ``linspace`` grids are.

    Args:
        system_type: ``'slow'`` or ``'fast'``; forwarded to
            ``compute_residuals``.
        init_conds: target values at ``inputs[0]``, either
            ``(phi, c1, c2, w)`` or ``(c1, c2, w)``.
        boundary_conds: target values at ``inputs[-1]``, same layouts.
        eps: default small parameter, used when the returned closure is called
            without its own ``eps``.

    Returns:
        The loss closure.

    Raises:
        ValueError: if a condition tuple has neither 3 nor 4 entries.
    """
    init_cols = _condition_columns(init_conds)
    bndry_cols = _condition_columns(boundary_conds)
    # NOTE(review): the condition values are captured here, once. If the
    # coupling variables change during training, this loss will not see the
    # new values -- confirm this is intended.
    default_eps = eps

    def loss_func(model, inputs, eps=None):
        # This parameter shadows the factory's `eps`; fall back to the
        # factory value so the factory argument is not silently ignored.
        eps = default_eps if eps is None else eps
        residuals, non_neg, pred = compute_residuals(model, inputs, system_type, eps)

        phys_loss = sum(torch.mean(res ** 2) for res in residuals)
        phys_loss = phys_loss + sum(torch.mean(v ** 2) for v in non_neg)

        # Calculate initialization and boundary losses
        init_loss = _condition_loss(pred[0], init_cols, init_conds)
        bndry_loss = _condition_loss(pred[-1], bndry_cols, boundary_conds)

        total_loss = (phys_weight * phys_loss
                      + init_weight * init_loss
                      + bndry_weight * bndry_loss)
        return total_loss

    return loss_func

# Training loop
def train_models(models, optimizers, epochs=20000, *, loss_funcs=None,
                 inputs=None, update_coupling=None, log_every=1000):
    """Jointly train several models on the sum of their losses.

    Each epoch does the following:
    - Zeroes every optimizer's gradients.
    - Evaluates ``loss_funcs[i](models[i], inputs[i])`` for each model.
    - Back-propagates the summed loss and steps every optimizer.
    - If ``update_coupling`` is given, calls ``update_coupling(models)``
      inside ``torch.no_grad()``.

    Args:
        models: sequence of models to train.
        optimizers: a single ``torch.optim.Optimizer`` or a sequence of them.
        epochs: number of training epochs.
        loss_funcs: one ``loss(model, inputs) -> scalar tensor`` per model.
        inputs: one input tensor per model.
        update_coupling: optional callback run once per epoch, after the
            optimizer steps, with the list of models.
        log_every: print and record the loss every ``log_every`` epochs.

    Returns:
        The recorded loss values, one per logging interval (float list).

    Raises:
        ValueError: if ``loss_funcs`` or ``inputs`` is missing, if there are
            no models, or if the three sequences differ in length.
    """
    if loss_funcs is None or inputs is None:
        raise ValueError('train_models needs loss_funcs and inputs, one per model')
    models, loss_funcs, inputs = list(models), list(loss_funcs), list(inputs)
    if not models:
        raise ValueError('train_models needs at least one model')
    if not len(models) == len(loss_funcs) == len(inputs):
        raise ValueError('models, loss_funcs and inputs must have the same length')
    if isinstance(optimizers, torch.optim.Optimizer):
        optimizers = [optimizers]
    else:
        optimizers = list(optimizers)

    loss_history = []
    for epoch in range(epochs):
        # Training step: one joint update on the summed loss of all models.
        for opt in optimizers:
            opt.zero_grad()
        loss = sum(fn(model, x) for model, fn, x in zip(models, loss_funcs, inputs))
        loss.backward()
        for opt in optimizers:
            opt.step()

        # Update coupling variables using current model predictions.
        if update_coupling is not None:
            with torch.no_grad():
                update_coupling(models)

        # Logging
        if epoch % log_every == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}')
            loss_history.append(loss.item())
    return loss_history

# Main execution
if __name__ == '__main__':
    # Script entry point: build the optimizer and the per-system loss
    # functions, train, plot the logged loss, then evaluate every model on
    # its grid.
    #
    # NOTE(review): many names used in the loss-function section below
    # (phi_f_a_l, c1_f_a_l, c2_f_a_l, w_a, w_b, c1_s_a_l, c2_s_a_l,
    # phi_s_a_r, c1_s_a_r, c2_s_a_r, phi_f_a_r, ..., phi_f_b_r) are not
    # defined anywhere above in this file. As written, that section raises
    # NameError unless they are defined elsewhere. c1_s_a_l and c2_s_a_l
    # exist above only as keys of `coupling_vars`. TODO confirm.

    # Initialize models and optimizers
    models = [model_slow1, model_fast1, model_slow2, model_fast2, model_slow3]
    # A single Adam optimizer over the parameters of all five models, so one
    # optimizer step updates them all together.
    optimizer = torch.optim.Adam(
        [p for model in models for p in model.parameters()], 
        lr=1e-3
    )
    
    # Define loss functions for each system
    # Each system's boundary conditions reuse names that also appear in the
    # next system's init conditions (e.g. w_a, w_b), chaining
    # slow1 -> fast1 -> slow2 -> fast2 -> slow3.
    loss_func_slow1 = create_loss_func('slow', init_conds=(phi_init, c1_init, c2_init, w_init), 
                                      boundary_conds=(phi_f_a_l, c1_f_a_l, c2_f_a_l, w_a), eps=eps)
    loss_func_fast1 = create_loss_func('fast', init_conds=(c1_s_a_l, c2_s_a_l, w_a), 
                                      boundary_conds=(phi_s_a_r, c1_s_a_r, c2_s_a_r, w_a), eps=eps)
    loss_func_slow2 = create_loss_func('slow', init_conds=(phi_f_a_r, c1_f_a_r, c2_f_a_r, w_a), 
                                      boundary_conds=(phi_f_b_l, c1_f_b_l, c2_f_b_l, w_b), eps=eps)
    loss_func_fast2 = create_loss_func('fast', init_conds=(c1_s_b_l, c2_s_b_l, w_b), 
                                      boundary_conds=(phi_s_b_r, c1_s_b_r, c2_s_b_r, w_b), eps=eps)
    loss_func_slow3 = create_loss_func('slow', init_conds=(phi_f_b_r, c1_f_b_r, c2_f_b_r, w_b), 
                                      boundary_conds=(phi_end, c1_end, c2_end, w_end), eps=eps)
    
    # Train models
    # NOTE(review): the five loss functions above are never passed to
    # train_models, and neither are the grid tensors. This call only hands
    # over the models and the optimizer -- TODO wire them in.
    loss_values = train_models(models, optimizer)
    
    # Plot results
    # The x-axis label assumes one logged loss value per 1000 epochs.
    plt.figure(figsize=(12, 4))
    plt.plot(np.log(loss_values))
    plt.xlabel('Epoch (x1000)')
    plt.ylabel('Log(Loss)')
    plt.title('Training Loss')
    plt.grid(True)
    plt.show()

    # Final evaluation
    with torch.no_grad():
        # Predictions for all systems
        # Slow models are evaluated on the tau grids, fast models on the
        # t grids.
        pred_slow1 = model_slow1(tau1_tensor).numpy()
        pred_fast1 = model_fast1(t1_tensor).numpy()
        pred_slow2 = model_slow2(tau2_tensor).numpy()
        pred_fast2 = model_fast2(t2_tensor).numpy()
        pred_slow3 = model_slow3(tau3_tensor).numpy()

        # Extract predictions
        # Transposing turns the (N, 7) array into 7 rows, unpacked in the
        # column order phi, u, c1, c2, j1, j2, w.
        phi_pred_slow1, u_pred_slow1, c1_pred_slow1, c2_pred_slow1, j1_pred_slow1, j2_pred_slow1, w_pred_slow1 = pred_slow1.T
        phi_pred_fast1, u_pred_fast1, c1_pred_fast1, c2_pred_fast1, j1_pred_fast1, j2_pred_fast1, w_pred_fast1 = pred_fast1.T
        phi_pred_slow2, u_pred_slow2, c1_pred_slow2, c2_pred_slow2, j1_pred_slow2, j2_pred_slow2, w_pred_slow2 = pred_slow2.T
        phi_pred_fast2, u_pred_fast2, c1_pred_fast2, c2_pred_fast2, j1_pred_fast2, j2_pred_fast2, w_pred_fast2 = pred_fast2.T
        phi_pred_slow3, u_pred_slow3, c1_pred_slow3, c2_pred_slow3, j1_pred_slow3, j2_pred_slow3, w_pred_slow3 = pred_slow3.T