In [None]:
### start with 5 outputs. Freeze. Add more equations/outputs

import numpy as np
import matplotlib.pyplot as plt
from neurodiffeq.conditions import IVP
from neurodiffeq.solvers import Solver2D
from neurodiffeq.monitors import Monitor2D
from neurodiffeq.generators import Generator2D
import torch
from neurodiffeq import diff      # the differentiation operation
from neurodiffeq.ode import solve # the ANN-based solver
from neurodiffeq.conditions import IVP   # the initial condition
from neurodiffeq.networks import FCNN    # fully-connect neural network
from neurodiffeq.networks import SinActv # sin activation

In [None]:
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# Physical and numerical parameters
X, L = 1.0, 1.0   # spatial domain length, final integration time
nx = 300          # number of spatial grid points
mu = 0.1          # viscosity coefficient

def forcing_terms(x, t):
    """
    Evaluate the eight source terms f0..f7 at positions x and time t.

    Returns np.array([f0, ..., f7]); for a 1-D x of length N this has
    shape (8, N). Rows 2/3 and rows 6/7 come from identical expressions.
    """
    sin_x = np.sin(x)
    # Pure time-decay factors shared by several terms.
    decay1 = np.exp(-t)
    decay2 = np.exp(-2*t)
    decay3 = np.exp(-3*t)
    # Combined space-time exponentials.
    exp_x_t = np.exp(-x-t)
    exp_x_2t = np.exp(-x-2*t)
    exp_2x_2t = np.exp(-2*x-2*t)

    f0 = -sin_x * decay1 + x*np.cos(x)*decay2 + 2*decay1 + sin_x*decay2
    f1 = (-2*x*decay1 - sin_x * decay2*x + 2*x*decay2 + x*sin_x*decay3
          - exp_x_t - 2*exp_2x_2t)
    f2 = -2*x*decay1 - sin_x *x* decay2 + 2*x*decay2 + x*sin_x*decay3 + exp_2x_2t
    f4 = -exp_x_t + exp_x_2t - x*exp_x_2t
    f5 = -exp_x_t
    f6 = -exp_x_t - x*exp_x_2t

    # f3 repeats f2 and f7 repeats f6.
    return np.array([f0, f1, f2, f2, f4, f5, f6, f6])


# Uniform spatial grid on [0, X]; dx is the spacing between neighbouring nodes.
x = np.linspace(start=0, stop=X, num=nx)
dx = x[1] - x[0]

# IC 
def init_cond():
    """
    Build the stacked initial state [rho, vx, vy, vz, P, Bx, By, Bz].

    Uses the module-level grid x: rho = 2 + sin(x), the three velocity
    components equal x, and P, Bx, By, Bz equal exp(-x).
    """
    density = 2+np.sin(x)
    exp_profile = np.exp(-x)
    return np.concatenate([density] + [x] * 3 + [exp_profile] * 4)

# spatial derivation
def d_dx(u):
    """First spatial derivative of u, computed by np.gradient with the module-level spacing dx."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Second spatial derivative of u: np.gradient applied twice with the module-level spacing dx."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# Function ODE: dy/dt = ...
def mhd_rhs(t, y):
    """
    Right-hand side dy/dt of the semi-discretised 1-D MHD system with forcing.

    Parameters
    ----------
    t : float
        Current time; used to evaluate the forcing terms.
    y : ndarray
        Stacked state [rho, vx, vy, vz, P, Bx, By, Bz]; it is cut into
        8 equal parts with np.split.

    Returns
    -------
    ndarray
        Time derivatives of the 8 fields, stacked in the same order as y.

    Relies on the module-level x, mu, forcing_terms, d_dx and d2_dx2.
    """
    f = forcing_terms(x, t)  # shape (8, N)

    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)

    # Each spatial derivative is evaluated once and reused below.
    dvx_dx = d_dx(vx)
    dvy_dx = d_dx(vy)
    dvz_dx = d_dx(vz)
    dP_dx = d_dx(P)
    dBy_dx = d_dx(By)
    dBz_dx = d_dx(Bz)
    inv_rho = 1 / rho

    drho_dt = -vx * d_dx(rho) - rho * dvx_dx + f[0]
    # f[1], f[2], f[3] expand to rho * (dv/dt + vx * dv/dx) + ..., i.e. they are
    # the sources of the momentum equations multiplied by rho.  The equations
    # integrated here are already divided by rho, so the sources must be too.
    dvx_dt = (-vx * dvx_dx - inv_rho * dP_dx
              - inv_rho * (By * dBy_dx + Bz * dBz_dx)
              + mu * d2_dx2(vx) + inv_rho * f[1])
    dvy_dt = -vx * dvy_dx + inv_rho * Bx * dBy_dx + inv_rho * f[2]
    dvz_dt = -vx * dvz_dx + inv_rho * Bx * dBz_dx + inv_rho * f[3]
    dP_dt = -P * dvx_dx - vx * dP_dx + f[4]
    dBx_dt = np.zeros_like(Bx) + f[5]
    dBy_dt = -vx * dBy_dx - By * dvx_dx + Bx * dvy_dx + f[6]
    dBz_dt = -vx * dBz_dx - Bz * dvx_dx + Bx * dvz_dx + f[7]

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

# solver
# Integrate from the initial state over [0, L], sampling 300 output times.
y0 = init_cond()
output_times = np.linspace(0, L, 300)
sol = solve_ivp(mhd_rhs, [0, L], y0, method='RK45', t_eval=output_times)

# Rows nx .. 2*nx-1 of the stacked state hold vx (one row per grid point).
vx_sol = sol.y[nx:2*nx, :]

# Space-time map of vx: time on the horizontal axis, position on the vertical axis.
fig, ax = plt.subplots()
image = ax.imshow(vx_sol, extent=[0, L, 0, X], aspect='auto', origin='lower')
ax.set_xlabel("Temps")
ax.set_ylabel("Espace (x)")
ax.set_title("vx(x,t)")
fig.colorbar(image, ax=ax, label="vx")
plt.show()

In [None]:
### solver 

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# Physical and numerical parameters
X, L = 1.0, 1.0   # spatial domain length, final integration time
nx = 300          # number of spatial grid points
mu = 0.1          # viscosity coefficient

def forcing_terms(x, t):
    """
    Evaluate the eight source terms f0..f7 at positions x and time t.

    Returns np.array([f0, ..., f7]); for a 1-D x of length N this has
    shape (8, N). Rows 2/3 and rows 6/7 come from identical expressions.
    """
    sin_x = np.sin(x)
    # Pure time-decay factors shared by several terms.
    decay1 = np.exp(-t)
    decay2 = np.exp(-2*t)
    decay3 = np.exp(-3*t)
    # Combined space-time exponentials.
    exp_x_t = np.exp(-x-t)
    exp_x_2t = np.exp(-x-2*t)
    exp_2x_2t = np.exp(-2*x-2*t)

    f0 = -sin_x * decay1 + x*np.cos(x)*decay2 + 2*decay1 + sin_x*decay2
    f1 = (-2*x*decay1 - sin_x * decay2*x + 2*x*decay2 + x*sin_x*decay3
          - exp_x_t - 2*exp_2x_2t)
    f2 = -2*x*decay1 - sin_x *x* decay2 + 2*x*decay2 + x*sin_x*decay3 + exp_2x_2t
    f4 = -exp_x_t + exp_x_2t - x*exp_x_2t
    f5 = -exp_x_t
    f6 = -exp_x_t - x*exp_x_2t

    # f3 repeats f2 and f7 repeats f6.
    return np.array([f0, f1, f2, f2, f4, f5, f6, f6])


# Uniform spatial grid on [0, X]; dx is the spacing between neighbouring nodes.
x = np.linspace(start=0, stop=X, num=nx)
dx = x[1] - x[0]

# IC
def init_cond():
    """
    Build the stacked initial state [rho, vx, vy, vz, P, Bx, By, Bz].

    Uses the module-level grid x: rho = 2 + sin(x), the three velocity
    components equal x, and P, Bx, By, Bz equal exp(-x).
    """
    density = 2+np.sin(x)
    exp_profile = np.exp(-x)
    return np.concatenate([density] + [x] * 3 + [exp_profile] * 4)

# Spatial derivatives via np.gradient (central differences inside the grid, one-sided differences at the two end points)
def d_dx(u):
    """First spatial derivative of u, computed by np.gradient with the module-level spacing dx."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Second spatial derivative of u: np.gradient applied twice with the module-level spacing dx."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# Fonction du système ODE: dy/dt = ...
def mhd_rhs(t, y):
    """
    Right-hand side dy/dt of the forced 1-D MHD system with boundary values at x[0].

    Parameters
    ----------
    t : float
        Current time; used for the forcing terms and the boundary rate.
    y : ndarray
        Stacked state [rho, vx, vy, vz, P, Bx, By, Bz]; it is cut into
        8 equal parts with np.split.

    Returns
    -------
    ndarray
        Time derivatives of the 8 fields, stacked in the same order as y.

    Relies on the module-level x, mu, forcing_terms, d_dx and d2_dx2.
    """
    f = forcing_terms(x, t)  # shape (8, N)

    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)

    # Each spatial derivative is evaluated once and reused below.
    dvx_dx = d_dx(vx)
    dvy_dx = d_dx(vy)
    dvz_dx = d_dx(vz)
    dP_dx = d_dx(P)
    dBy_dx = d_dx(By)
    dBz_dx = d_dx(Bz)
    inv_rho = 1 / rho

    drho_dt = -vx * d_dx(rho) - rho * dvx_dx + f[0]
    # f[1], f[2], f[3] expand to rho * (dv/dt + vx * dv/dx) + ..., i.e. they are
    # the sources of the momentum equations multiplied by rho.  The equations
    # integrated here are already divided by rho, so the sources must be too.
    dvx_dt = (-vx * dvx_dx - inv_rho * dP_dx
              - inv_rho * (By * dBy_dx + Bz * dBz_dx)
              + mu * d2_dx2(vx) + inv_rho * f[1])
    dvy_dt = -vx * dvy_dx + inv_rho * Bx * dBy_dx + inv_rho * f[2]
    dvz_dt = -vx * dvz_dx + inv_rho * Bx * dBz_dx + inv_rho * f[3]
    dP_dt = -P * dvx_dx - vx * dP_dx + f[4]
    dBx_dt = np.zeros_like(Bx) + f[5]
    dBy_dt = -vx * dBy_dx - By * dvx_dx + Bx * dvy_dx + f[6]
    dBz_dt = -vx * dBz_dx - Bz * dvx_dx + Bx * dvz_dx + f[7]

    # Fields held constant at the first grid point (rho, vx, vy, vz):
    # a zero time derivative keeps their initial boundary value.
    drho_dt[0] = 0.0
    dvx_dt[0]  = 0.0
    dvy_dt[0]  = 0.0
    dvz_dt[0]  = 0.0

    # Fields that follow exp(-t) at the first grid point: d/dt exp(-t) = -exp(-t).
    bc_dot = -np.exp(-t)
    dP_dt[0]  = bc_dot
    dBx_dt[0] = bc_dot
    dBy_dt[0] = bc_dot
    dBz_dt[0] = bc_dot

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

#Lancement du solveur
# Integrate from the initial state over tspan, sampling nx output times.
y0 = init_cond()
tspan = (0.0, L)
t_eval = np.linspace(0, L, nx)
sol = solve_ivp(mhd_rhs, tspan, y0, method='RK45', t_eval=t_eval)

# Rows nx .. 2*nx-1 of the stacked state hold vx (one row per grid point).
vx_sol = sol.y[nx:2*nx, :]

# Space-time map of vx: time on the horizontal axis, position on the vertical axis.
fig, ax = plt.subplots()
image = ax.imshow(vx_sol, extent=[0, L, 0, X], aspect='auto', origin='lower')
ax.set_xlabel("Temps")
ax.set_ylabel("Position x")
ax.set_title("vx(x,t)")
fig.colorbar(image, ax=ax, label="vx")
plt.show()


In [None]:
### solveur with ff : works well

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# Physical and numerical parameters
X, L = 1.0, 1.0   # spatial domain length, final integration time
nx = 300          # number of spatial grid points
mu = 0.1          # viscosity coefficient

def forcing_terms(x, t):
    """
    Evaluate the eight source terms f0..f7 at positions x and time t.

    Returns np.array([f0, ..., f7]); for a 1-D x of length N this has
    shape (8, N). Rows 2/3 and rows 6/7 come from identical expressions.
    """
    sin_x = np.sin(x)
    # Pure time-decay factors shared by several terms.
    decay1 = np.exp(-t)
    decay2 = np.exp(-2*t)
    decay3 = np.exp(-3*t)
    # Combined space-time exponentials.
    exp_x_t = np.exp(-x-t)
    exp_x_2t = np.exp(-x-2*t)
    exp_2x_2t = np.exp(-2*x-2*t)

    f0 = -sin_x * decay1 + x*np.cos(x)*decay2 + 2*decay1 + sin_x*decay2
    f1 = (-2*x*decay1 - sin_x * decay2*x + 2*x*decay2 + x*sin_x*decay3
          - exp_x_t - 2*exp_2x_2t)
    f2 = -2*x*decay1 - sin_x *x* decay2 + 2*x*decay2 + x*sin_x*decay3 + exp_2x_2t
    f4 = -exp_x_t + exp_x_2t - x*exp_x_2t
    f5 = -exp_x_t
    f6 = -exp_x_t - x*exp_x_2t

    # f3 repeats f2 and f7 repeats f6.
    return np.array([f0, f1, f2, f2, f4, f5, f6, f6])


# Maillage
# Uniform spatial grid on [0, X] with spacing dx, and nx output times on [0, L].
x = np.linspace(start=0, stop=X, num=nx)
dx = x[1] - x[0]
t_eval = np.linspace(start=0, stop=L, num=nx)

# Initial conditions: rho = 2 + sin(x); vx = vy = vz = x; P = Bx = By = Bz = exp(-x)
def init_cond():
    """
    Build the stacked initial state [rho, vx, vy, vz, P, Bx, By, Bz].

    Uses the module-level grid x: rho = 2 + sin(x), the three velocity
    components equal x, and P, Bx, By, Bz equal exp(-x).
    """
    density = 2+np.sin(x)
    exp_profile = np.exp(-x)
    return np.concatenate([density] + [x] * 3 + [exp_profile] * 4)

# Spatial derivatives via np.gradient (central differences inside the grid, one-sided differences at the two end points)
def d_dx(u):
    """First spatial derivative of u, computed by np.gradient with the module-level spacing dx."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Second spatial derivative of u: np.gradient applied twice with the module-level spacing dx."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# Fonction du système ODE: dy/dt = ...
def mhd_rhs(t, y):
    """
    Right-hand side dy/dt of the forced 1-D MHD system with boundary values at x[0].

    Parameters
    ----------
    t : float
        Current time; used for the forcing terms and the boundary rate.
    y : ndarray
        Stacked state [rho, vx, vy, vz, P, Bx, By, Bz]; it is cut into
        8 equal parts with np.split.

    Returns
    -------
    ndarray
        Time derivatives of the 8 fields, stacked in the same order as y.

    Relies on the module-level x, mu, forcing_terms, d_dx and d2_dx2.
    """
    f = forcing_terms(x, t)  # shape (8, N)

    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)

    # Each spatial derivative is evaluated once and reused below.
    dvx_dx = d_dx(vx)
    dvy_dx = d_dx(vy)
    dvz_dx = d_dx(vz)
    dP_dx = d_dx(P)
    dBy_dx = d_dx(By)
    dBz_dx = d_dx(Bz)
    inv_rho = 1 / rho

    drho_dt = -vx * d_dx(rho) - rho * dvx_dx + f[0]
    # f[1], f[2], f[3] expand to rho * (dv/dt + vx * dv/dx) + ..., i.e. they are
    # the sources of the momentum equations multiplied by rho.  The equations
    # integrated here are already divided by rho, so the sources must be too.
    dvx_dt = (-vx * dvx_dx - inv_rho * dP_dx
              - inv_rho * (By * dBy_dx + Bz * dBz_dx)
              + mu * d2_dx2(vx) + inv_rho * f[1])
    dvy_dt = -vx * dvy_dx + inv_rho * Bx * dBy_dx + inv_rho * f[2]
    dvz_dt = -vx * dvz_dx + inv_rho * Bx * dBz_dx + inv_rho * f[3]
    dP_dt = -P * dvx_dx - vx * dP_dx + f[4]
    dBx_dt = np.zeros_like(Bx) + f[5]
    dBy_dt = -vx * dBy_dx - By * dvx_dx + Bx * dvy_dx + f[6]
    dBz_dt = -vx * dBz_dx - Bz * dvx_dx + Bx * dvz_dx + f[7]

    # Fields held constant at the first grid point (rho, vx, vy, vz):
    # a zero time derivative keeps their initial boundary value.
    drho_dt[0] = 0.0
    dvx_dt[0]  = 0.0
    dvy_dt[0]  = 0.0
    dvz_dt[0]  = 0.0

    # Fields that follow exp(-t) at the first grid point: d/dt exp(-t) = -exp(-t).
    bc_dot = -np.exp(-t)
    dP_dt[0]  = bc_dot
    dBx_dt[0] = bc_dot
    dBy_dt[0] = bc_dot
    dBz_dt[0] = bc_dot

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

# Résolution
# Integrate the semi-discrete system, sampling the solution at t_eval.
y0 = init_cond()
sol = solve_ivp(mhd_rhs, (0, L), y0, method='RK45', t_eval=t_eval)

# Rows nx .. 2*nx-1 of the stacked state hold vx; the slice has shape (nx, nt).
vx_num = sol.y[nx:2*nx, :]
T, Xg = np.meshgrid(sol.t, x)

# Reference profile vx = x * exp(-t) and the pointwise difference to it.
vx_ex = Xg * np.exp(-T)
err = vx_num - vx_ex

fig, (ax_num, ax_err) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: numerical vx.
im_num = ax_num.imshow(vx_num, extent=[0, L, 0, X], aspect='auto', origin='lower')
fig.colorbar(im_num, ax=ax_num, label='vx_num')
ax_num.set_xlabel('Temps t')
ax_num.set_ylabel('Position x')
ax_num.set_title('vx numérique')

# Right panel: signed error on a colour scale symmetric about zero.
err_bound = np.max(abs(err))
im_err = ax_err.imshow(err, extent=[0, L, 0, X], aspect='auto',
                       origin='lower', cmap='bwr', vmin=-err_bound, vmax=err_bound)
fig.colorbar(im_err, ax=ax_err, label='Erreur')
ax_err.set_xlabel('Temps t')
ax_err.set_ylabel('Position x')
ax_err.set_title('Erreur $v_x^{num}-x e^{-t}$')

plt.tight_layout()
plt.show()

In [None]:
### solveur without ff

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# Physical and numerical parameters
X, L = 1.0, 1.0   # spatial domain length, final integration time
nx = 300          # number of spatial grid points
mu = 0.1          # viscosity coefficient

def forcing_terms(x, t):
    """
    Evaluate the eight source terms f0..f7 at positions x and time t.

    Returns np.array([f0, ..., f7]); for a 1-D x of length N this has
    shape (8, N). Rows 2/3 and rows 6/7 come from identical expressions.
    """
    sin_x = np.sin(x)
    # Pure time-decay factors shared by several terms.
    decay1 = np.exp(-t)
    decay2 = np.exp(-2*t)
    decay3 = np.exp(-3*t)
    # Combined space-time exponentials.
    exp_x_t = np.exp(-x-t)
    exp_x_2t = np.exp(-x-2*t)
    exp_2x_2t = np.exp(-2*x-2*t)

    f0 = -sin_x * decay1 + x*np.cos(x)*decay2 + 2*decay1 + sin_x*decay2
    f1 = (-2*x*decay1 - sin_x * decay2*x + 2*x*decay2 + x*sin_x*decay3
          - exp_x_t - 2*exp_2x_2t)
    f2 = -2*x*decay1 - sin_x *x* decay2 + 2*x*decay2 + x*sin_x*decay3 + exp_2x_2t
    f4 = -exp_x_t + exp_x_2t - x*exp_x_2t
    f5 = -exp_x_t
    f6 = -exp_x_t - x*exp_x_2t

    # f3 repeats f2 and f7 repeats f6.
    return np.array([f0, f1, f2, f2, f4, f5, f6, f6])


# Maillage
# Uniform spatial grid on [0, X] with spacing dx, and nx output times on [0, L].
x = np.linspace(start=0, stop=X, num=nx)
dx = x[1] - x[0]
t_eval = np.linspace(start=0, stop=L, num=nx)

# Initial conditions: rho = 2 + sin(x); vx = vy = vz = x; P = Bx = By = Bz = exp(-x)
def init_cond():
    """
    Build the stacked initial state [rho, vx, vy, vz, P, Bx, By, Bz].

    Uses the module-level grid x: rho = 2 + sin(x), the three velocity
    components equal x, and P, Bx, By, Bz equal exp(-x).
    """
    density = 2+np.sin(x)
    exp_profile = np.exp(-x)
    return np.concatenate([density] + [x] * 3 + [exp_profile] * 4)

# Spatial derivatives via np.gradient (central differences inside the grid, one-sided differences at the two end points)
def d_dx(u):
    """First spatial derivative of u, computed by np.gradient with the module-level spacing dx."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Second spatial derivative of u: np.gradient applied twice with the module-level spacing dx."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# Fonction du système ODE: dy/dt = ...
def mhd_rhs(t, y):
    """
    Right-hand side dy/dt of the unforced 1-D MHD system with boundary values at x[0].

    Parameters
    ----------
    t : float
        Current time; only the boundary rate -exp(-t) depends on it.
    y : ndarray
        Stacked state [rho, vx, vy, vz, P, Bx, By, Bz]; it is cut into
        8 equal parts with np.split.

    Returns
    -------
    ndarray
        Time derivatives of the 8 fields, stacked in the same order as y.

    Relies on the module-level mu, d_dx and d2_dx2. No forcing term enters
    this variant, so forcing_terms is not evaluated.
    """
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)

    # Each spatial derivative is evaluated once and reused below.
    dvx_dx = d_dx(vx)
    dvy_dx = d_dx(vy)
    dvz_dx = d_dx(vz)
    dP_dx = d_dx(P)
    dBy_dx = d_dx(By)
    dBz_dx = d_dx(Bz)

    drho_dt = -vx * d_dx(rho) - rho * dvx_dx
    dvx_dt = (-vx * dvx_dx - (1/rho) * dP_dx
              - (1/rho)*(By * dBy_dx + Bz * dBz_dx)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * dvy_dx + (1/rho)*Bx*dBy_dx
    dvz_dt = -vx * dvz_dx + (1/rho)*Bx*dBz_dx
    dP_dt = -P * dvx_dx - vx * dP_dx
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * dBy_dx - By * dvx_dx + Bx * dvy_dx
    dBz_dt = -vx * dBz_dx - Bz * dvx_dx + Bx * dvz_dx

    # Fields held constant at the first grid point (rho, vx, vy, vz):
    # a zero time derivative keeps their initial boundary value.
    drho_dt[0] = 0.0
    dvx_dt[0]  = 0.0
    dvy_dt[0]  = 0.0
    dvz_dt[0]  = 0.0

    # Fields that follow exp(-t) at the first grid point: d/dt exp(-t) = -exp(-t).
    bc_dot = -np.exp(-t)
    dP_dt[0]  = bc_dot
    dBx_dt[0] = bc_dot
    dBy_dt[0] = bc_dot
    dBz_dt[0] = bc_dot

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

# Résolution
# Integrate the semi-discrete system, sampling the solution at t_eval.
y0 = init_cond()
sol = solve_ivp(mhd_rhs, (0, L), y0, method='RK45', t_eval=t_eval)

# Rows nx .. 2*nx-1 of the stacked state hold vx; the slice has shape (nx, nt).
vx_num = sol.y[nx:2*nx, :]
T, Xg = np.meshgrid(sol.t, x)

# Reference profile vx = x * exp(-t) and the pointwise difference to it.
vx_ex = Xg * np.exp(-T)
err = vx_num - vx_ex

fig, (ax_num, ax_err) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: numerical vx.
im_num = ax_num.imshow(vx_num, extent=[0, L, 0, X], aspect='auto', origin='lower')
fig.colorbar(im_num, ax=ax_num, label='vx_num')
ax_num.set_xlabel('Temps t')
ax_num.set_ylabel('Position x')
ax_num.set_title('vx numérique')

# Right panel: signed error on a colour scale symmetric about zero.
err_bound = np.max(abs(err))
im_err = ax_err.imshow(err, extent=[0, L, 0, X], aspect='auto',
                       origin='lower', cmap='bwr', vmin=-err_bound, vmax=err_bound)
fig.colorbar(im_err, ax=ax_err, label='Erreur')
ax_err.set_xlabel('Temps t')
ax_err.set_ylabel('Position x')
ax_err.set_title('Erreur $v_x^{num}-x e^{-t}$')

plt.tight_layout()
plt.show()

In [None]:
### solveur without ff
### direct BC

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# Physical and numerical parameters
X, L = 1.0, 1.0   # spatial domain length, final integration time
nx = 300          # number of spatial grid points
mu = 0.1          # viscosity coefficient

def forcing_terms(x, t):
    """
    Evaluate the eight source terms f0..f7 at positions x and time t.

    Returns np.array([f0, ..., f7]); for a 1-D x of length N this has
    shape (8, N). Rows 2/3 and rows 6/7 come from identical expressions.
    """
    sin_x = np.sin(x)
    # Pure time-decay factors shared by several terms.
    decay1 = np.exp(-t)
    decay2 = np.exp(-2*t)
    decay3 = np.exp(-3*t)
    # Combined space-time exponentials.
    exp_x_t = np.exp(-x-t)
    exp_x_2t = np.exp(-x-2*t)
    exp_2x_2t = np.exp(-2*x-2*t)

    f0 = -sin_x * decay1 + x*np.cos(x)*decay2 + 2*decay1 + sin_x*decay2
    f1 = (-2*x*decay1 - sin_x * decay2*x + 2*x*decay2 + x*sin_x*decay3
          - exp_x_t - 2*exp_2x_2t)
    f2 = -2*x*decay1 - sin_x *x* decay2 + 2*x*decay2 + x*sin_x*decay3 + exp_2x_2t
    f4 = -exp_x_t + exp_x_2t - x*exp_x_2t
    f5 = -exp_x_t
    f6 = -exp_x_t - x*exp_x_2t

    # f3 repeats f2 and f7 repeats f6.
    return np.array([f0, f1, f2, f2, f4, f5, f6, f6])


# Maillage
# Uniform spatial grid on [0, X] with spacing dx, and nx output times on [0, L].
x = np.linspace(start=0, stop=X, num=nx)
dx = x[1] - x[0]
t_eval = np.linspace(start=0, stop=L, num=nx)

# Initial conditions: rho = 2 + sin(x); vx = vy = vz = x; P = Bx = By = Bz = exp(-x)
def init_cond():
    """
    Build the stacked initial state [rho, vx, vy, vz, P, Bx, By, Bz].

    Uses the module-level grid x: rho = 2 + sin(x), the three velocity
    components equal x, and P, Bx, By, Bz equal exp(-x).
    """
    density = 2+np.sin(x)
    exp_profile = np.exp(-x)
    return np.concatenate([density] + [x] * 3 + [exp_profile] * 4)

# Spatial derivatives via np.gradient (central differences inside the grid, one-sided differences at the two end points)
def d_dx(u):
    """First spatial derivative of u, computed by np.gradient with the module-level spacing dx."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Second spatial derivative of u: np.gradient applied twice with the module-level spacing dx."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# Fonction du système ODE: dy/dt = ...
def mhd_rhs(t, y):
    """
    Right-hand side dy/dt of the unforced 1-D MHD system with directly imposed
    boundary values at the first grid point.

    Parameters
    ----------
    t : float
        Current time; sets the boundary value exp(-t) of P, Bx, By, Bz.
    y : ndarray
        Stacked state [rho, vx, vy, vz, P, Bx, By, Bz]. It is not modified.

    Returns
    -------
    ndarray
        Time derivatives of the 8 fields, stacked in the same order as y.

    Relies on the module-level mu, d_dx and d2_dx2. No forcing term enters
    this variant, so forcing_terms is not evaluated.
    """
    # Work on a private copy: np.split returns views, so writing boundary
    # values through them would otherwise modify the caller's (solver's) array.
    state = np.array(y, dtype=float)
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(state, 8)

    # Impose the boundary values BEFORE differentiating so that the
    # derivatives computed below actually see them.
    bc_value = np.exp(-t)
    rho[0] = 2
    vx[0] = vy[0] = vz[0] = 0
    P[0] = Bx[0] = By[0] = Bz[0] = bc_value

    # Each spatial derivative is evaluated once and reused below.
    dvx_dx = d_dx(vx)
    dvy_dx = d_dx(vy)
    dvz_dx = d_dx(vz)
    dP_dx = d_dx(P)
    dBy_dx = d_dx(By)
    dBz_dx = d_dx(Bz)

    drho_dt = -vx * d_dx(rho) - rho * dvx_dx
    dvx_dt = (-vx * dvx_dx - (1/rho) * dP_dx
              - (1/rho)*(By * dBy_dx + Bz * dBz_dx)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * dvy_dx + (1/rho)*Bx*dBy_dx
    dvz_dt = -vx * dvz_dx + (1/rho)*Bx*dBz_dx
    dP_dt = -P * dvx_dx - vx * dP_dx
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * dBy_dx - By * dvx_dx + Bx * dvy_dx
    dBz_dt = -vx * dBz_dx - Bz * dvx_dx + Bx * dvz_dx

    # Make the integrated boundary node follow the imposed values:
    # constants get a zero rate, exp(-t) gets the rate -exp(-t).
    drho_dt[0] = dvx_dt[0] = dvy_dt[0] = dvz_dt[0] = 0.0
    dP_dt[0] = dBx_dt[0] = dBy_dt[0] = dBz_dt[0] = -bc_value

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

# Résolution
# Integrate the semi-discrete system, sampling the solution at t_eval.
y0 = init_cond()
sol = solve_ivp(mhd_rhs, (0, L), y0, method='RK45', t_eval=t_eval)

# Rows nx .. 2*nx-1 of the stacked state hold vx; the slice has shape (nx, nt).
vx_num = sol.y[nx:2*nx, :]
T, Xg = np.meshgrid(sol.t, x)

# Reference profile vx = x * exp(-t) and the pointwise difference to it.
vx_ex = Xg * np.exp(-T)
err = vx_num - vx_ex

fig, (ax_num, ax_err) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: numerical vx.
im_num = ax_num.imshow(vx_num, extent=[0, L, 0, X], aspect='auto', origin='lower')
fig.colorbar(im_num, ax=ax_num, label='vx_num')
ax_num.set_xlabel('Temps t')
ax_num.set_ylabel('Position x')
ax_num.set_title('vx numérique')

# Right panel: signed error on a colour scale symmetric about zero.
err_bound = np.max(abs(err))
im_err = ax_err.imshow(err, extent=[0, L, 0, X], aspect='auto',
                       origin='lower', cmap='bwr', vmin=-err_bound, vmax=err_bound)
fig.colorbar(im_err, ax=ax_err, label='Erreur')
ax_err.set_xlabel('Temps t')
ax_err.set_ylabel('Position x')
ax_err.set_title('Erreur $v_x^{num}-x e^{-t}$')

plt.tight_layout()
plt.show()

In [None]:
## TL on IC

from neurodiffeq.networks import FCNN
from neurodiffeq.generators import Generator2D
from neurodiffeq.operators import diff
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

# Définition d'un réseau avec deux têtes
class MultiHeadFCNN(nn.Module):
    """
    Fully connected tanh network with a shared trunk and two linear output heads.

    Parameters
    ----------
    n_input_units : int
        Size of the input features.
    n_output_units : int
        Size of the output of each head.
    hidden_units : sequence of int
        Width of every hidden layer of the shared trunk, in order. Any number
        of layers (at least one) is supported; the default gives two layers
        of 256 units.
    """

    def __init__(self, n_input_units=2, n_output_units=1, hidden_units=(256, 256)):
        super().__init__()
        widths = list(hidden_units)
        if not widths:
            raise ValueError("hidden_units must contain at least one layer width")

        # Shared trunk: Linear -> Tanh for every requested width.
        trunk = []
        in_features = n_input_units
        for width in widths:
            trunk.append(nn.Linear(in_features, width))
            trunk.append(nn.Tanh())
            in_features = width
        self.shared_layers = nn.Sequential(*trunk)

        # Separate output heads, both fed by the last hidden layer.
        self.head1 = nn.Linear(in_features, n_output_units)
        self.head2 = nn.Linear(in_features, n_output_units)

    def forward(self, x, head_idx=0):
        """Run the shared trunk, then head1 if head_idx == 0, otherwise head2."""
        shared = self.shared_layers(x)
        if head_idx == 0:
            return self.head1(shared)
        return self.head2(shared)

# Per-epoch loss histories, filled by the training loop.
total_losses, pde_losses, ic_losses = [], [], []

# Single viscosity value used by the PDE residuals.
mu = 0.3

# Définition des conditions initiales pour chaque tête
# One dictionary of initial profiles per head. The heads share rho, P, By and Bz
# and differ only in the velocity: head 0 starts from x*x, head 1 from 1.
initial_conditions = [
    {
        'rho': torch.sin,
        'v': velocity_profile,
        'P': lambda x: torch.exp(-x),
        'By': lambda x: torch.exp(-x),
        'Bz': lambda x: torch.exp(-x),
    }
    for velocity_profile in (lambda x: x*x, torch.ones_like)
]

# Fonctions de forçage pour chaque tête
# (Vous les remplacerez par vos propres calculs)



# PDE residuals (continuity, momentum, pressure); no forcing term is applied and head_idx is accepted but unused
def pde_system(rho, v, P, By, Bz, x, t, head_idx):
    """
    Residuals of the three equations enforced on the networks.

    Returns [continuity, momentum, pressure], each built from rho, v and P
    and their derivatives with respect to x and t. By, Bz and head_idx are
    accepted but do not enter any residual; mu is read from module scope.
    """
    dv_dx = diff(v, x)

    continuity = diff(rho, t) + v * diff(rho, x) + rho * dv_dx
    momentum = (rho * diff(v, t) + rho * v * dv_dx + diff(P, x)
                - rho * mu * diff(v, x, order=2))
    pressure = diff(P, t) + P * dv_dx + v * diff(P, x)

    return [continuity, momentum, pressure]

# Création des réseaux avec deux têtes chacun
# One two-headed network per unknown, in the order rho, v, P, By, Bz.
nets = [MultiHeadFCNN(2, 1, [256, 256]) for _ in range(5)]

# A single Adam optimizer updates the parameters of all five networks.
params = sum((list(net.parameters()) for net in nets), [])
optimizer = optim.Adam(params, lr=1e-4)
criterion = nn.MSELoss()

# Collocation points for the PDE residual, and 128 fixed points on the line t = 0.
train_gen = Generator2D((20, 20), xy_min=(0, 0), xy_max=(1, 1), method='equally-spaced-noisy')
ic_x = torch.linspace(0, 1, 128).reshape(-1, 1)
ic_t = torch.zeros_like(ic_x)

# Training loop
# Joint training of all networks: one optimizer step per epoch, with the PDE
# and initial-condition losses of every head summed into a single objective.
for epoch in tqdm(range(12000)):
    optimizer.zero_grad()
    total_loss = 0
    # Plain-float running sums, used only for logging.
    epoch_pde_loss = 0
    epoch_ic_loss = 0
    
    # Loop over the heads, each paired with its own set of initial conditions
    for head_idx in range(len(initial_conditions)):
        # 1. PDE Loss
        # A new batch of collocation points is drawn for every head.
        samples = train_gen.get_examples()
        x_train = samples[0].view(-1, 1)
        t_train = samples[1].view(-1, 1)
        inputs = torch.cat((x_train, t_train), dim=1)

        # Evaluate every network through the head under consideration
        outputs = [net(inputs, head_idx=head_idx) for net in nets]
        pde_residuals = pde_system(*outputs, x_train, t_train, head_idx)

        # Mean squared residual of each equation, summed over the equations.
        loss_pde = sum([criterion(residual, torch.zeros_like(residual)) for residual in pde_residuals])
        epoch_pde_loss += loss_pde.item()

        # 2. Initial condition loss (specific to each head)
        ic_inputs = torch.cat((ic_x, ic_t), dim=1)
        ic_outputs = [net(ic_inputs, head_idx=head_idx) for net in nets]
        # Targets are listed in the same order as the networks: rho, v, P, By, Bz.
        ic_targets = [
            initial_conditions[head_idx]['rho'](ic_x),
            initial_conditions[head_idx]['v'](ic_x),
            initial_conditions[head_idx]['P'](ic_x),
            initial_conditions[head_idx]['By'](ic_x),
            initial_conditions[head_idx]['Bz'](ic_x)
        ]

        loss_ic = sum([criterion(out, target) for out, target in zip(ic_outputs, ic_targets)])
        epoch_ic_loss += loss_ic.item()

        # Accumulate the total loss over the heads
        total_loss += loss_pde + loss_ic
    
    # Backpropagation and parameter update
    total_loss.backward()
    optimizer.step()

    # Progress report every 100 epochs, averaged over the heads.
    if epoch % 100 == 0:
        avg_pde_loss = epoch_pde_loss / len(initial_conditions)
        avg_ic_loss = epoch_ic_loss / len(initial_conditions)
        avg_total_loss = total_loss.item() / len(initial_conditions)
        print(f"Epoch {epoch} | Avg Loss PDE: {avg_pde_loss:.3e} | Avg Loss IC: {avg_ic_loss:.3e} | Avg Total: {avg_total_loss:.3e}")
    
    # Store the per-head averages of this epoch in the loss histories.
    total_losses.append(total_loss.item() / len(initial_conditions))
    pde_losses.append(epoch_pde_loss / len(initial_conditions))
    ic_losses.append(epoch_ic_loss / len(initial_conditions))

# Fonction pour obtenir les solutions
def solutions(x, t, head_idx=0):
    """
    Evaluate every network in the module-level list nets at the points (x, t).

    x and t are concatenated along dim 1 to form the network input. Returns
    one detached output tensor per network, in the order of nets, taken from
    the head selected by head_idx.
    """
    inputs = torch.cat((x, t), dim=1)
    predictions = []
    for net in nets:
        predictions.append(net(inputs, head_idx=head_idx).detach())
    return predictions

In [None]:
## PINN training 1 without ff, with BC


from neurodiffeq.networks import FCNN
from neurodiffeq.generators import Generator2D
from neurodiffeq.operators import diff
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

# Définition d'un réseau avec deux têtes
class MultiHeadFCNN(nn.Module):
    """
    Fully connected tanh network with a shared trunk and two linear output heads.

    Parameters
    ----------
    n_input_units : int
        Size of the input features.
    n_output_units : int
        Size of the output of each head.
    hidden_units : sequence of int
        Width of every hidden layer of the shared trunk, in order. Any number
        of layers (at least one) is supported; the default gives two layers
        of 256 units.
    """

    def __init__(self, n_input_units=2, n_output_units=1, hidden_units=(256, 256)):
        super().__init__()
        widths = list(hidden_units)
        if not widths:
            raise ValueError("hidden_units must contain at least one layer width")

        # Shared trunk: Linear -> Tanh for every requested width.
        trunk = []
        in_features = n_input_units
        for width in widths:
            trunk.append(nn.Linear(in_features, width))
            trunk.append(nn.Tanh())
            in_features = width
        self.shared_layers = nn.Sequential(*trunk)

        # Separate output heads, both fed by the last hidden layer.
        self.head1 = nn.Linear(in_features, n_output_units)
        self.head2 = nn.Linear(in_features, n_output_units)

    def forward(self, x, head_idx=0):
        """Run the shared trunk, then head1 if head_idx == 0, otherwise head2."""
        shared = self.shared_layers(x)
        if head_idx == 0:
            return self.head1(shared)
        return self.head2(shared)

# Per-epoch loss histories, filled by the training loop
# (bc_losses tracks the boundary-condition loss).
total_losses, pde_losses, ic_losses, bc_losses = [], [], [], []

# Single viscosity value used by the PDE residuals.
mu = 0.3

# Conditions initiales et aux limites
# One dictionary of initial profiles per head. The heads share rho, P, By and Bz
# and differ only in the velocity: head 0 starts from x, head 1 from x*x*x.
initial_conditions = [
    {
        'rho': lambda x: 2+torch.sin(x),
        'v': velocity_profile,
        'P': lambda x: torch.exp(-x),
        'By': lambda x: torch.exp(-x),
        'Bz': lambda x: torch.exp(-x),
    }
    for velocity_profile in (lambda x: x, lambda x: x*x*x)
]

# Nouveau: Conditions aux limites (communes aux deux têtes)
def boundary_conditions(x, t):
    """
    Target values imposed at x = 0, shared by both heads.

    Returns a dict with rho = 2 and v = 0 (shaped like x) and
    P = By = Bz = exp(-t).
    """
    targets = {}
    targets['rho'] = 2.0 * torch.ones_like(x)
    targets['v'] = torch.zeros_like(x)
    for field in ('P', 'By', 'Bz'):
        targets[field] = torch.exp(-t)
    return targets

# PDE system
def pde_system(rho, v, P, By, Bz, x, t, head_idx):
    """
    Residuals of the three equations enforced on the networks.

    Returns [continuity, momentum, pressure], each built from rho, v and P
    and their derivatives with respect to x and t. By, Bz and head_idx are
    accepted but do not enter any residual; mu is read from module scope.
    """
    dv_dx = diff(v, x)

    continuity = diff(rho, t) + v * diff(rho, x) + rho * dv_dx
    momentum = (rho * diff(v, t) + rho * v * dv_dx + diff(P, x)
                - rho * mu * diff(v, x, order=2))
    pressure = diff(P, t) + P * dv_dx + v * diff(P, x)

    return [continuity, momentum, pressure]

# Création des réseaux
# One two-headed network per unknown, in the order rho, v, P, By, Bz.
nets = [MultiHeadFCNN(2, 1, [256, 256]) for _ in range(5)]

# A single Adam optimizer updates the parameters of all five networks.
params = sum((list(net.parameters()) for net in nets), [])
optimizer = optim.Adam(params, lr=1e-4)
criterion = nn.MSELoss()

# Collocation points for the PDE residual, sample points for the x = 0 boundary
# (only their t coordinate is used by the training loop), and 128 fixed points
# on the line t = 0.
train_gen = Generator2D((20, 20), xy_min=(0, 0), xy_max=(1, 1), method='equally-spaced-noisy')
bc_gen = Generator2D((20, 20), xy_min=(0, 0), xy_max=(0, 1), method='equally-spaced')
ic_x = torch.linspace(0, 1, 128).reshape(-1, 1)
ic_t = torch.zeros_like(ic_x)

# Training loop
# Joint training of all networks: one optimizer step per epoch, with the PDE,
# initial-condition and boundary-condition losses of every head summed into a
# single objective.
for epoch in tqdm(range(12000)):
    optimizer.zero_grad()
    total_loss = 0
    # Plain-float running sums, used only for logging.
    epoch_pde_loss = 0
    epoch_ic_loss = 0
    epoch_bc_loss = 0  # boundary-condition loss
    
    for head_idx in range(len(initial_conditions)):
        # 1. PDE Loss
        # A new batch of collocation points is drawn for every head.
        samples = train_gen.get_examples()
        x_train = samples[0].view(-1, 1)
        t_train = samples[1].view(-1, 1)
        inputs = torch.cat((x_train, t_train), dim=1)

        outputs = [net(inputs, head_idx=head_idx) for net in nets]
        pde_residuals = pde_system(*outputs, x_train, t_train, head_idx)
        # Mean squared residual of each equation, summed over the equations.
        loss_pde = sum([criterion(residual, torch.zeros_like(residual)) for residual in pde_residuals])
        epoch_pde_loss += loss_pde.item()

        # 2. Initial condition loss
        ic_inputs = torch.cat((ic_x, ic_t), dim=1)
        ic_outputs = [net(ic_inputs, head_idx=head_idx) for net in nets]
        # Targets are listed in the same order as the networks: rho, v, P, By, Bz.
        ic_targets = [
            initial_conditions[head_idx]['rho'](ic_x),
            initial_conditions[head_idx]['v'](ic_x),
            initial_conditions[head_idx]['P'](ic_x),
            initial_conditions[head_idx]['By'](ic_x),
            initial_conditions[head_idx]['Bz'](ic_x)
        ]
        loss_ic = sum([criterion(out, target) for out, target in zip(ic_outputs, ic_targets)])
        epoch_ic_loss += loss_ic.item()

        # 3. Boundary condition loss (x=0)
        bc_samples = bc_gen.get_examples()
        # Only the t samples are used; x is replaced by zeros of the same shape.
        x_bc = torch.zeros_like(bc_samples[0]).view(-1, 1)  # x=0
        t_bc = bc_samples[1].view(-1, 1)
        bc_inputs = torch.cat((x_bc, t_bc), dim=1)
        
        bc_outputs = [net(bc_inputs, head_idx=head_idx) for net in nets]
        bc_targets = boundary_conditions(x_bc, t_bc)
        loss_bc = (
            criterion(bc_outputs[0], bc_targets['rho']) +  # rho
            criterion(bc_outputs[1], bc_targets['v']) +    # v
            criterion(bc_outputs[2], bc_targets['P']) +    # P
            criterion(bc_outputs[3], bc_targets['By']) +   # By
            criterion(bc_outputs[4], bc_targets['Bz'])     # Bz
        )
        epoch_bc_loss += loss_bc.item()

        total_loss += loss_pde + loss_ic + loss_bc  # the BC loss is part of the objective
    
    # Backpropagation and parameter update
    total_loss.backward()
    optimizer.step()

    # Progress report every 100 epochs, averaged over the heads.
    if epoch % 100 == 0:
        n_heads = len(initial_conditions)
        print(
            f"Epoch {epoch} | "
            f"PDE: {epoch_pde_loss/n_heads:.3e} | "
            f"IC: {epoch_ic_loss/n_heads:.3e} | "
            f"BC: {epoch_bc_loss/n_heads:.3e} | "  # boundary-condition loss
            f"Total: {total_loss.item()/n_heads:.3e}"
        )
    
    # Store the per-head averages of this epoch in the loss histories.
    total_losses.append(total_loss.item() / len(initial_conditions))
    pde_losses.append(epoch_pde_loss / len(initial_conditions))
    ic_losses.append(epoch_ic_loss / len(initial_conditions))
    bc_losses.append(epoch_bc_loss / len(initial_conditions))  # store the BC loss

def solutions(x, t, head_idx=0):
    """
    Evaluate every network in the module-level list nets at the points (x, t).

    x and t are concatenated along dim 1 to form the network input. Returns
    one detached output tensor per network, in the order of nets, taken from
    the head selected by head_idx.
    """
    inputs = torch.cat((x, t), dim=1)
    predictions = []
    for net in nets:
        predictions.append(net(inputs, head_idx=head_idx).detach())
    return predictions

In [None]:
## Loss function

import matplotlib.pyplot as plt

# Training curves on a logarithmic loss axis, one line per recorded history.
fig, ax = plt.subplots(figsize=(10, 6))
curves = (
    (total_losses, 'Total Loss', {}),
    (pde_losses, 'PDE Loss', {'linestyle': '--'}),
    (ic_losses, 'IC Loss', {'linestyle': ':'}),
    (bc_losses, 'BC Loss', {'linestyle': '-'}),
)
for history, label, line_style in curves:
    ax.semilogy(history, label=label, **line_style)

ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (log scale)')
ax.set_title('Training Loss Evolution')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
## shox rho for 2 heads

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch

# 1. Report the dtype of the network weights (the evaluation grid below is float32).
print(f"Type des poids du premier réseau: {next(nets[0].parameters()).dtype}")

# 2. Evaluation mesh: 100 points in x and 50 in t on [0, 1] x [0, 1].
x = np.linspace(0, 1, 100).astype(np.float32)
t = np.linspace(0, 1, 50).astype(np.float32)
X, T = np.meshgrid(x, t)

# 3. Flatten the mesh into column tensors, as expected by solutions().
x_tensor = torch.tensor(X.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# 4. Evaluate both heads. Any exception propagates with its full traceback,
#    so no try/except wrapper is needed.
with torch.no_grad():
    # Cast the networks to float32 if they were created with another dtype.
    if next(nets[0].parameters()).dtype != torch.float32:
        for net in nets:
            net.float()

    sols_head1 = solutions(x_tensor, t_tensor, head_idx=0)
    sols_head2 = solutions(x_tensor, t_tensor, head_idx=1)

# Index 0 of the solution list is the density network.
rho_head1 = sols_head1[0].cpu().numpy().reshape(X.shape)
rho_head2 = sols_head2[0].cpu().numpy().reshape(X.shape)

# Reference profile used for both heads.
rho_analytical = 2+ np.sin(X) * np.exp(-T)

# Absolute errors with respect to the reference.
error_head1 = np.abs(rho_analytical - rho_head1)
error_head2 = np.abs(rho_analytical - rho_head2)

# One row per head: predicted rho on the left, absolute error on the right.
# Both heads are trained with the same viscosity (a single module-level mu),
# so the panels are labelled by head only, not by a per-head mu value.
fig, axes = plt.subplots(2, 2, figsize=(18, 12))
fig.suptitle('Comparaison of the solutions ρ(x,t) for the 2 heads', fontsize=16)

panels = (
    ('Head 1', rho_head1, error_head1),
    ('Head 2', rho_head2, error_head2),
)
for row, (head_name, rho_pred, abs_err) in enumerate(panels):
    ax_sol, ax_err = axes[row, 0], axes[row, 1]

    im_sol = ax_sol.pcolormesh(X, T, rho_pred, shading='auto', cmap='viridis')
    fig.colorbar(im_sol, ax=ax_sol, label='ρ(x,t)')
    ax_sol.set_title(f'Solution ρ(x,t) - {head_name}')

    im_err = ax_err.pcolormesh(X, T, abs_err, shading='auto', cmap='hot')
    fig.colorbar(im_err, ax=ax_err, label='Absolute erreur ')
    ax_err.set_title(f'Erreur - {head_name}')

    for ax in (ax_sol, ax_err):
        ax.set_xlabel('x')
        ax.set_ylabel('t')

plt.tight_layout()
plt.show()

# Maximum absolute error of each head, for comparison.
print(f"Erreur maximale tête 1: {np.max(error_head1):.3e}")
print(f"Erreur maximale tête 2: {np.max(error_head2):.3e}")

# Optional: absolute difference between the two heads.
diff_heads = np.abs(rho_head1 - rho_head2)
plt.figure(figsize=(12, 6))
plt.pcolormesh(X, T, diff_heads, shading='auto', cmap='plasma')
plt.colorbar(label='Difference')
plt.xlabel('x')
plt.ylabel('t')
plt.title('Difference between the 2 heads')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch

## vx: prediction of each head, its absolute error against a reference, and the head difference

# 1. Report the parameter dtype of nets[1] (the network whose output is plotted below)
print(f"Type des poids du premier réseau: {next(nets[1].parameters()).dtype}")

# 2. Evaluation mesh: 100 points in x and 50 in t, both on [0, 1]
x = np.linspace(0, 1, 100).astype(np.float32)
t = np.linspace(0, 1, 50).astype(np.float32)
X, T = np.meshgrid(x, t)

# 3. Flatten the mesh into (N, 1) float32 column tensors
x_tensor, t_tensor = (
    torch.tensor(grid.flatten(), dtype=torch.float32).view(-1, 1) for grid in (X, T)
)

# 4. Evaluate both heads (index [1] selects the second network's output)
try:
    with torch.no_grad():
        # Cast every network to float32 if nets[1] is stored in another dtype
        if next(nets[1].parameters()).dtype != torch.float32:
            for net in nets:
                net.float()

        sols_head1, sols_head2 = (
            solutions(x_tensor, t_tensor, head_idx=head) for head in (0, 1)
        )
        # The rho_* names are kept from the rho cell; here they hold the vx field
        rho_head1, rho_head2 = (
            sol[1].cpu().numpy().reshape(X.shape) for sol in (sols_head1, sols_head2)
        )

except Exception as e:
    print(f"Erreur lors du calcul: {str(e)}")
    raise

# Reference fields: x*exp(-t) for head 1 and x^3*exp(-t) for head 2 (x broadcasts along the t axis)
time_decay = np.exp(-T)
rho_analytical1 = x*time_decay
rho_analytical2 = x*x*x*time_decay

# Absolute errors
error_head1 = np.abs(rho_analytical1 - rho_head1)
error_head2 = np.abs(rho_analytical2 - rho_head2)

# 2x2 figure: one row per head, left = prediction, right = absolute error
fig, axes = plt.subplots(2, 2, figsize=(18, 12))
fig.suptitle('Comparaison of vx(x,t) for the 2 heads', fontsize=16)

panels = [
    (axes[0,0], rho_head1,   'viridis', 'vx(x,t)',         'Solution vx(x,t) - Head 1 (μ=0.1)'),
    (axes[0,1], error_head1, 'hot',     'Absolute error',  'Error - Head 1 (μ=0.1)'),
    (axes[1,0], rho_head2,   'viridis', 'vx(x,t)',         'Solution vx(x,t) - Head 2 (μ=0.5)'),
    (axes[1,1], error_head2, 'hot',     'Absolute error ', 'Error - Head 2 (μ=0.5)'),
]
for panel_ax, field, colormap, bar_label, panel_title in panels:
    mesh = panel_ax.pcolormesh(X, T, field, shading='auto', cmap=colormap)
    fig.colorbar(mesh, ax=panel_ax, label=bar_label)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('x')
    panel_ax.set_ylabel('t')

plt.tight_layout()
plt.show()

# Maximum absolute error of each head
print(f"Erreur maximale tête 1 (μ=0.1): {np.max(error_head1):.3e}")
print(f"Erreur maximale tête 2 (μ=0.5): {np.max(error_head2):.3e}")

# Optional: absolute difference between the two heads
diff_heads = np.abs(rho_head1 - rho_head2)
plt.figure(figsize=(12, 6))
plt.pcolormesh(X, T, diff_heads, shading='auto', cmap='plasma')
plt.colorbar(label='Difference')
plt.xlabel('x')
plt.ylabel('t')
plt.title('Difference between the 2 heads')
plt.show()
print()

In [None]:
# # Save NN after training
# for i, net in enumerate(nets):
#     torch.save(net.state_dict(), f"net_{i}.pth")

## Persist the state dicts of the first five networks under saved_weights/

import os
os.makedirs("saved_weights", exist_ok=True)  # no error if the directory already exists


for i in range(len(nets[:5])):
    torch.save(nets[i].state_dict(), f"saved_weights/net_{i}.pt")


In [None]:
### training 2 without ff, with BC

from neurodiffeq.networks import FCNN
from neurodiffeq.generators import Generator2D
from neurodiffeq.operators import diff
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim

# Fully-connected network with a shared trunk and two output heads
class MultiHeadFCNN(nn.Module):
    """Two-hidden-layer tanh network whose features feed two separate linear heads.

    Args:
        n_input_units: size of the input feature dimension.
        n_output_units: size of each head's output.
        hidden_units: sequence whose first two entries are the widths of the two
            hidden layers (only indices 0 and 1 are read).
    """

    def __init__(self, n_input_units=2, n_output_units=1, hidden_units=[256, 256]):
        super().__init__()
        first_width, second_width = hidden_units[0], hidden_units[1]
        # Shared trunk; layers are created in the same order as they are applied
        trunk = [
            nn.Linear(n_input_units, first_width),
            nn.Tanh(),
            nn.Linear(first_width, second_width),
            nn.Tanh(),
        ]
        self.shared_layers = nn.Sequential(*trunk)
        # Separate output heads on top of the shared features
        self.head1 = nn.Linear(second_width, n_output_units)
        self.head2 = nn.Linear(second_width, n_output_units)

    def forward(self, x, head_idx=0):
        """Apply the trunk, then head1 when head_idx == 0 and head2 for any other value."""
        features = self.shared_layers(x)
        selected_head = self.head1 if head_idx == 0 else self.head2
        return selected_head(features)

# Per-epoch loss histories (each entry is averaged over the heads by the training loop).
# They are reset here so that re-running the cell starts every curve from scratch.
total_losses = []
pde_losses = []
ic_losses = []
bc_losses = []  # boundary-condition history: appended to by the training loop and plotted with the others

# Single viscosity value, read as a global by pde_system
mu = 0.3

# Initial conditions (t = 0), one dict per head; each value maps x -> field(x)
def _build_initial_condition(velocity_profile):
    """Return one head's initial-condition dict.

    All three velocity entries ('v', 'vy', 'vz') use `velocity_profile`; the density is
    2 + sin(x) and the pressure and magnetic-field entries are exp(-x).
    """
    def density(x):
        return 2+torch.sin(x)

    def exp_decay(x):
        return torch.exp(-x)

    return {
        'rho': density,
        'v': velocity_profile,
        'P': exp_decay,
        'By': exp_decay,
        'Bz': exp_decay,
        'vy': velocity_profile,
        'vz': velocity_profile,
        'Bx': exp_decay,
    }


initial_conditions = [
    _build_initial_condition(lambda x: x),      # Head 0: linear velocity profile
    _build_initial_condition(lambda x: x*x*x),  # Head 1: cubic velocity profile
]

# Boundary targets used by the x = 0 boundary loss (the same for both heads)
def boundary_conditions(x, t):
    """Return the target value of every field, keyed by field name.

    Args:
        x: tensor that only provides the shape/dtype of the constant targets.
        t: tensor of times; pressure and magnetic-field targets are exp(-t).

    Returns:
        dict with 'rho' = 2, the velocity entries ('v', 'vy', 'vz') = 0 and
        'P', 'By', 'Bz', 'Bx' = exp(-t).
    """
    def at_rest():
        return torch.zeros_like(x)

    def time_decay():
        return torch.exp(-t)

    targets = {}
    targets['rho'] = 2.0 * torch.ones_like(x)
    targets['v'] = at_rest()
    targets['P'] = time_decay()
    targets['By'] = time_decay()
    targets['Bz'] = time_decay()
    targets['vy'] = at_rest()
    targets['vz'] = at_rest()
    targets['Bx'] = time_decay()
    return targets

# Residuals of the 8-equation system (no forcing terms)
def pde_system(rho, vx, P, By, Bz, vy, vz, Bx, x, t, head_idx):
    """Return the list of the eight PDE residuals evaluated at the collocation points.

    Args:
        rho, vx, P, By, Bz, vy, vz, Bx: network outputs, in the order of the `nets` list.
        x, t: coordinate tensors the outputs were computed from (differentiated against).
        head_idx: accepted for interface compatibility; the residuals do not depend on it.

    Returns:
        list of 8 tensors, one residual per equation, in the original equation order.

    Reads the global viscosity `mu`.
    """
    # Every derivative is computed once and reused: the equations share many terms
    # (e.g. d(vx)/dx appears in five of them) and each `diff` call is a separate autograd pass.
    rho_t, rho_x = diff(rho, t), diff(rho, x)
    vx_t, vx_x = diff(vx, t), diff(vx, x)
    vx_xx = diff(vx, x, order=2)
    vy_t, vy_x = diff(vy, t), diff(vy, x)
    vz_t, vz_x = diff(vz, t), diff(vz, x)
    P_t, P_x = diff(P, t), diff(P, x)
    Bx_t = diff(Bx, t)
    By_t, By_x = diff(By, t), diff(By, x)
    Bz_t, Bz_x = diff(Bz, t), diff(Bz, x)

    return [
        rho_t + vx * rho_x + rho * vx_x,                                                    # Equation 1
        rho * vx_t + rho * vx * vx_x + P_x + By * By_x + Bz * Bz_x - rho * mu * vx_xx,      # Equation 2 (viscous term uses mu)
        rho * vy_t + rho*vx*vy_x - Bx*By_x,                                                 # Equation 3
        rho * vz_t + rho*vx*vz_x - Bx*Bz_x,                                                 # Equation 4
        P_t + P*vx_x + vx*P_x,                                                              # Equation 5
        Bx_t,                                                                               # Equation 6
        By_t + vx*By_x + By*vx_x - Bx*vy_x,                                                 # Equation 7
        Bz_t + vx*Bz_x + Bz*vx_x - Bx*vz_x,                                                 # Equation 8
    ]

import os  # needed for the checkpoint check below; imported here so the cell does not rely on an earlier cell


def init_weights(m):
    """Xavier-normal weights and zero biases for nn.Linear modules; other modules are left untouched."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)


# Build the 8 networks (one per field)
nets = [MultiHeadFCNN(n_input_units=2, n_output_units=1, hidden_units=[256, 256]) for _ in range(8)]

# Networks 0-4: warm-start from the saved weights when the file exists,
# otherwise fall back to the random initialisation
for i in range(5):
    net_path = f"saved_weights/net_{i}.pt"
    if os.path.exists(net_path):
        nets[i].load_state_dict(torch.load(net_path))
    else:
        nets[i].apply(init_weights)

# Networks 5-7: always start from the random initialisation
for i in range(5, 8):
    nets[i].apply(init_weights)

# Optimizer over the parameters of all networks
params = [p for net in nets for p in net.parameters()]
optimizer = optim.Adam(params, lr=1e-4)
criterion = nn.MSELoss()
# Step decay of the learning rate: multiplied by `gamma` every `step_size` scheduler steps
step_size = 500
gamma = 0.975
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)


# Generators / fixed sample points
train_gen = Generator2D((20, 20), xy_min=(0, 0), xy_max=(1, 1), method='equally-spaced-noisy')
ic_x = torch.linspace(0, 1, 128).view(-1, 1)
ic_t = torch.zeros_like(ic_x)
ic_inputs = torch.cat((ic_x, ic_t), dim=1)  # the same for every epoch and head, so built once

# Field name of each network, in the order of `nets` (and of pde_system's arguments)
field_keys = ['rho', 'v', 'P', 'By', 'Bz', 'vy', 'vz', 'Bx']
n_heads = len(initial_conditions)

# NOTE(review): `tqdm` and `bc_gen` are not defined in this cell; they must be provided by an
# earlier cell for it to run on a fresh kernel — confirm where they come from.

# Training loop
for epoch in tqdm(range(12000)):
    optimizer.zero_grad()
    total_loss = 0
    epoch_pde_loss = 0
    epoch_ic_loss = 0
    epoch_bc_loss = 0  # reset every epoch like the other accumulators (it was never initialised before)

    # One pass per head; the losses of all heads are summed before the backward pass
    for head_idx in range(n_heads):
        # 1. PDE loss on freshly sampled collocation points
        samples = train_gen.get_examples()
        x_train = samples[0].view(-1, 1)
        t_train = samples[1].view(-1, 1)
        inputs = torch.cat((x_train, t_train), dim=1)

        outputs = [net(inputs, head_idx=head_idx) for net in nets]
        pde_residuals = pde_system(*outputs, x_train, t_train, head_idx)

        loss_pde = sum([criterion(residual, torch.zeros_like(residual)) for residual in pde_residuals])
        epoch_pde_loss += loss_pde.item()

        # 2. Initial-condition loss at t = 0
        ic_outputs = [net(ic_inputs, head_idx=head_idx) for net in nets]
        ic_targets = [initial_conditions[head_idx][key](ic_x) for key in field_keys]

        loss_ic = sum([criterion(out, target) for out, target in zip(ic_outputs, ic_targets)])
        epoch_ic_loss += loss_ic.item()

        # 3. Boundary-condition loss at x = 0 (only the time samples of bc_gen are used)
        bc_samples = bc_gen.get_examples()
        x_bc = torch.zeros_like(bc_samples[0]).view(-1, 1)  # x = 0
        t_bc = bc_samples[1].view(-1, 1)
        bc_inputs = torch.cat((x_bc, t_bc), dim=1)

        bc_outputs = [net(bc_inputs, head_idx=head_idx) for net in nets]
        bc_targets = boundary_conditions(x_bc, t_bc)
        loss_bc = sum([criterion(out, bc_targets[key]) for out, key in zip(bc_outputs, field_keys)])
        epoch_bc_loss += loss_bc.item()

        total_loss += loss_pde + loss_ic + loss_bc

    total_loss.backward()
    optimizer.step()
    scheduler.step()  # without this call the StepLR schedule defined above never takes effect

    if epoch % 100 == 0:
        print(
            f"Epoch {epoch} | "
            f"PDE: {epoch_pde_loss/n_heads:.3e} | "
            f"IC: {epoch_ic_loss/n_heads:.3e} | "
            f"BC: {epoch_bc_loss/n_heads:.3e} | "
            f"Total: {total_loss.item()/n_heads:.3e}"
        )

    # Histories store per-head averages
    total_losses.append(total_loss.item() / n_heads)
    pde_losses.append(epoch_pde_loss / n_heads)
    ic_losses.append(epoch_ic_loss / n_heads)
    bc_losses.append(epoch_bc_loss / n_heads)

def solutions(x, t, head_idx=0):
    """Evaluate every network in `nets` on the points (x, t) with the given head.

    x and t are concatenated along dim 1 to form the network input. Returns a list with
    one detached output tensor per network, in the order of `nets`.
    """
    points = torch.cat((x, t), dim=1)
    predictions = []
    for network in nets:
        predictions.append(network(points, head_idx=head_idx).detach())
    return predictions

In [None]:
### Loss function

import matplotlib.pyplot as plt

# Loss histories on a logarithmic y axis: (history, legend label, extra line style)
loss_curves = [
    (total_losses, 'Total Loss', {}),
    (pde_losses, 'PDE Loss', {'linestyle': '--'}),
    (ic_losses, 'IC Loss', {'linestyle': ':'}),
    (bc_losses, 'BC Loss', {'linestyle': '-'}),
]

plt.figure(figsize=(10, 6))
for history, curve_label, line_style in loss_curves:
    plt.semilogy(history, label=curve_label, **line_style)

plt.xlabel('Epoch')
plt.ylabel('Loss (log scale)')
plt.title('Training Loss Evolution')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
## show rho for 2 heads

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch

# rho for both heads of the networks trained above.
# The training cell uses one viscosity for both heads; the heads differ only in their initial
# velocity profile (head 1: x, head 2: x^3), so the panels are labelled by initial condition.
head_tags = ('vx0 = x', 'vx0 = x³')

# 1. Report the parameter dtype of the first network
print(f"Type des poids du premier réseau: {next(nets[0].parameters()).dtype}")

# 2. Evaluation mesh: 100 points in x and 50 in t, both on [0, 1]
x = np.linspace(0, 1, 100).astype(np.float32)
t = np.linspace(0, 1, 50).astype(np.float32)
X, T = np.meshgrid(x, t)

# 3. Flatten the mesh into (N, 1) float32 column tensors
x_tensor = torch.tensor(X.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# 4. Evaluate both heads (index [0] selects the first network's output, rho)
try:
    with torch.no_grad():
        # Cast every network to float32 if they are stored in another dtype
        if next(nets[0].parameters()).dtype != torch.float32:
            for net in nets:
                net.float()

        sols_head1 = solutions(x_tensor, t_tensor, head_idx=0)
        sols_head2 = solutions(x_tensor, t_tensor, head_idx=1)

        rho_head1 = sols_head1[0].cpu().numpy().reshape(X.shape)
        rho_head2 = sols_head2[0].cpu().numpy().reshape(X.shape)

except Exception as e:
    print(f"Erreur lors du calcul: {str(e)}")
    raise

# Reference field used for the error maps.
# NOTE(review): the training cell's PDE has no forcing terms — confirm that 2 + sin(x) e^{-t}
# is still the intended reference for this run.
rho_analytical = 2+ np.sin(X) * np.exp(-T)

# Absolute errors
error_head1 = np.abs(rho_analytical - rho_head1)
error_head2 = np.abs(rho_analytical - rho_head2)

# 2x2 figure: one row per head, left = prediction, right = absolute error
fig, axes = plt.subplots(2, 2, figsize=(18, 12))
fig.suptitle('Comparaison of the solutions ρ(x,t) for the 2 heads', fontsize=16)

panels = [
    (axes[0,0], rho_head1,   'viridis', 'ρ(x,t)',           f'Solution ρ(x,t) - Head 1 ({head_tags[0]})'),
    (axes[0,1], error_head1, 'hot',     'Absolute erreur ', f'Erreur - Head 1 ({head_tags[0]})'),
    (axes[1,0], rho_head2,   'viridis', 'ρ(x,t)',           f'Solution ρ(x,t) - Head 2 ({head_tags[1]})'),
    (axes[1,1], error_head2, 'hot',     'Absolute erreur ', f'Erreur - Head 2 ({head_tags[1]})'),
]
for panel_ax, field, colormap, bar_label, panel_title in panels:
    mesh = panel_ax.pcolormesh(X, T, field, shading='auto', cmap=colormap)
    fig.colorbar(mesh, ax=panel_ax, label=bar_label)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('x')
    panel_ax.set_ylabel('t')

plt.tight_layout()
plt.show()

# Maximum absolute error of each head
print(f"Erreur maximale tête 1 ({head_tags[0]}): {np.max(error_head1):.3e}")
print(f"Erreur maximale tête 2 ({head_tags[1]}): {np.max(error_head2):.3e}")

# Optional: absolute difference between the two heads
diff_heads = np.abs(rho_head1 - rho_head2)
plt.figure(figsize=(12, 6))
plt.pcolormesh(X, T, diff_heads, shading='auto', cmap='plasma')
plt.colorbar(label='Difference')
plt.xlabel('x')
plt.ylabel('t')
plt.title('Difference between the 2 heads')
plt.show()
print()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch

## vx for both heads of the networks trained above.
# The training cell uses one viscosity for both heads; the heads differ only in their initial
# velocity profile (head 1: x, head 2: x^3), so the panels are labelled by initial condition.
head_tags = ('vx0 = x', 'vx0 = x³')

# 1. Report the parameter dtype of nets[1] (the vx network)
print(f"Type des poids du premier réseau: {next(nets[1].parameters()).dtype}")

# 2. Evaluation mesh: 100 points in x and 50 in t, both on [0, 1]
x = np.linspace(0, 1, 100).astype(np.float32)
t = np.linspace(0, 1, 50).astype(np.float32)
X, T = np.meshgrid(x, t)

# 3. Flatten the mesh into (N, 1) float32 column tensors
x_tensor = torch.tensor(X.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# 4. Evaluate both heads (index [1] selects the second network's output, vx)
try:
    with torch.no_grad():
        # Cast every network to float32 if they are stored in another dtype
        if next(nets[1].parameters()).dtype != torch.float32:
            for net in nets:
                net.float()

        sols_head1 = solutions(x_tensor, t_tensor, head_idx=0)
        sols_head2 = solutions(x_tensor, t_tensor, head_idx=1)

        # The rho_* names are kept from the rho cell; here they hold the vx field
        rho_head1 = sols_head1[1].cpu().numpy().reshape(X.shape)
        rho_head2 = sols_head2[1].cpu().numpy().reshape(X.shape)

except Exception as e:
    print(f"Erreur lors du calcul: {str(e)}")
    raise

# Reference fields: x*exp(-t) for head 1 and x^3*exp(-t) for head 2 (x broadcasts along the t axis)
rho_analytical1 =  x*np.exp(-T)
rho_analytical2 =  x*x*x*np.exp(-T)

# Absolute errors
error_head1 = np.abs(rho_analytical1 - rho_head1)
error_head2 = np.abs(rho_analytical2 - rho_head2)

# 2x2 figure: one row per head, left = prediction, right = absolute error
fig, axes = plt.subplots(2, 2, figsize=(18, 12))
fig.suptitle('Comparaison of vx(x,t) for the 2 heads', fontsize=16)

panels = [
    (axes[0,0], rho_head1,   'viridis', 'vx(x,t)',         f'Solution vx(x,t) - Head 1 ({head_tags[0]})'),
    (axes[0,1], error_head1, 'hot',     'Absolute error',  f'Error - Head 1 ({head_tags[0]})'),
    (axes[1,0], rho_head2,   'viridis', 'vx(x,t)',         f'Solution vx(x,t) - Head 2 ({head_tags[1]})'),
    (axes[1,1], error_head2, 'hot',     'Absolute error ', f'Error - Head 2 ({head_tags[1]})'),
]
for panel_ax, field, colormap, bar_label, panel_title in panels:
    mesh = panel_ax.pcolormesh(X, T, field, shading='auto', cmap=colormap)
    fig.colorbar(mesh, ax=panel_ax, label=bar_label)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('x')
    panel_ax.set_ylabel('t')

plt.tight_layout()
plt.show()

# Maximum absolute error of each head
print(f"Erreur maximale tête 1 ({head_tags[0]}): {np.max(error_head1):.3e}")
print(f"Erreur maximale tête 2 ({head_tags[1]}): {np.max(error_head2):.3e}")

# Optional: absolute difference between the two heads
diff_heads = np.abs(rho_head1 - rho_head2)
plt.figure(figsize=(12, 6))
plt.pcolormesh(X, T, diff_heads, shading='auto', cmap='plasma')
plt.colorbar(label='Difference')
plt.xlabel('x')
plt.ylabel('t')
plt.title('Difference between the 2 heads')
plt.show()
print()

In [None]:
# # Save NN after training
# for i, net in enumerate(nets):
#     torch.save(net.state_dict(), f"net_{i}.pth")

## Persist the state dicts of the first eight networks (nets[:8]) under saved_weights/

import os
os.makedirs("saved_weights", exist_ok=True)  # no error if the directory already exists


for i in range(len(nets[:8])):
    torch.save(nets[i].state_dict(), f"saved_weights/net_{i}.pt")


In [None]:
### solveur without ff
### head 1

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# Parameters
X   = 1.0       # length of the spatial domain
nx  = 300       # number of spatial grid points
L   = 1.0       # final time
mu  = 0.1       # viscosity (NOTE(review): the PINN training cell sets mu = 0.3 — confirm which value the comparison should use)

def forcing_terms(x, t):
    """
    Evaluate the eight forcing terms f0..f7 at positions x and time t.

    Returns an array with f0..f7 stacked along the first axis, i.e. shape (8, N)
    when x holds N positions and t is a scalar.
    """
    # Factors shared by several terms
    sin_x, cos_x = np.sin(x), np.cos(x)
    e_t, e_2t, e_3t = np.exp(-t), np.exp(-2*t), np.exp(-3*t)
    e_x_t = np.exp(-x-t)
    e_x_2t = np.exp(-x-2*t)
    e_2x_2t = np.exp(-2*x-2*t)

    f0 = -sin_x * e_t+x*cos_x*e_2t+2*e_t+sin_x*e_2t
    f1 = -2*x*e_t-sin_x * e_2t*x+2*x*e_2t+x*sin_x*e_3t-e_x_t-2*e_2x_2t
    f2 = -2*x*e_t-sin_x *x* e_2t+2*x*e_2t+x*sin_x*e_3t+e_2x_2t
    f4 = - e_x_t+e_x_2t-x*e_x_2t
    f5 = -e_x_t
    f6 = -e_x_t-x*e_x_2t

    # f3 has the same expression as f2, and f7 the same as f6
    return np.array([f0, f1, f2, f2, f4, f5, f6, f6])


# Grid: nx points on [0, X]; the solution is requested at nx output times on [0, L]
x   = np.linspace(0, X, nx)
dx  = x[1] - x[0]
t_eval = np.linspace(0, L, nx)

# Initial state at t = 0
def init_cond():
    """Return the initial state vector on the grid `x`.

    The eight fields are concatenated in the order rho, vx, vy, vz, P, Bx, By, Bz:
    rho = 2 everywhere, the three velocity components equal x, and the pressure and
    magnetic-field components equal exp(-x).
    NOTE(review): the PINN initial condition for rho is 2 + sin(x) — confirm the uniform
    density here is intended for the comparison.
    """
    density = 2*np.ones_like(x)
    velocity = x
    exp_decay = np.exp(-x)
    fields = [density, velocity, velocity, velocity, exp_decay, exp_decay, exp_decay, exp_decay]
    return np.concatenate(fields)

# Spatial derivatives on the uniform grid (spacing dx), both built on np.gradient
def d_dx(u):
    """First derivative of u with np.gradient: central differences in the interior and
    one-sided differences at the two end points (np.gradient defaults)."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Second derivative of u, obtained by applying the first-derivative operator twice."""
    return d_dx(d_dx(u))

# Right-hand side of the semi-discretised system: dy/dt = mhd_rhs(t, y)
def mhd_rhs(t, y):
    """Time derivative of the state for solve_ivp (system without forcing terms).

    Args:
        t: current time.
        y: flat state vector holding eight equally sized fields concatenated in the
           order rho, vx, vy, vz, P, Bx, By, Bz.

    Returns:
        Flat vector of the eight time derivatives, concatenated in the same order.

    Reads the globals `mu`, `d_dx` and `d2_dx2`. The forcing terms are not evaluated:
    they were computed but never used in this forcing-free variant.
    """
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)

    # Each spatial derivative is evaluated once and reused by the equations below
    rho_x, vx_x, vy_x, vz_x = d_dx(rho), d_dx(vx), d_dx(vy), d_dx(vz)
    P_x, By_x, Bz_x = d_dx(P), d_dx(By), d_dx(Bz)
    inv_rho = 1/rho

    drho_dt = -vx * rho_x - rho * vx_x
    dvx_dt = (-vx * vx_x - inv_rho * P_x
              - inv_rho*(By * By_x + Bz * Bz_x)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * vy_x + inv_rho*Bx*By_x
    dvz_dt = -vx * vz_x + inv_rho*Bx*Bz_x
    dP_dt = -P * vx_x - vx * P_x
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * By_x - By * vx_x + Bx * vy_x
    dBz_dt = -vx * Bz_x - Bz * vx_x + Bx * vz_x

    # Boundary at the first grid point, enforced through the time derivative.
    # rho, vx, vy, vz are held constant there -> zero time derivative
    drho_dt[0] = 0.0
    dvx_dt[0]  = 0.0
    dvy_dt[0]  = 0.0
    dvz_dt[0]  = 0.0

    # P, Bx, By, Bz follow exp(-t) there -> time derivative -exp(-t)
    bc_dot = -np.exp(-t)
    dP_dt[0]  = bc_dot
    dBx_dt[0] = bc_dot
    dBy_dt[0] = bc_dot
    dBz_dt[0] = bc_dot

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

# Time integration with RK45, sampled at the times in t_eval
y0  = init_cond()
solv1 = solve_ivp(fun=mhd_rhs, t_span=(0, L), y0=y0, method='RK45', t_eval=t_eval)

# vx is the second of the eight stacked fields -> rows nx .. 2*nx-1, one column per output time
vx_rows = slice(nx, 2*nx)
vx_num_v1 = solv1.y[vx_rows, :]
T, Xg  = np.meshgrid(solv1.t, x)

# Reference field x * exp(-t) and signed error against it
vx_ex = Xg * np.exp(-T)
err = vx_num_v1 - vx_ex

# Shared image settings: horizontal axis = time, vertical axis = position
image_extent = [0,L,0,X]
err_bound = np.max(abs(err))  # symmetric colour range for the error map

plt.figure(figsize=(12,4))

# Panel 1: numerical vx
plt.subplot(1,2,1)
plt.imshow(vx_num_v1, extent=image_extent, aspect='auto', origin='lower')
plt.colorbar(label='vx_num')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('vx numérique')

# Panel 2: error vx_num - vx_ex
plt.subplot(1,2,2)
plt.imshow(err, extent=image_extent, aspect='auto',
           origin='lower', cmap='bwr', vmin=-err_bound, vmax=err_bound)
plt.colorbar(label='Erreur')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Erreur $v_x^{num}-x e^{-t}$')

plt.tight_layout()
plt.show()

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import RegularGridInterpolator


# Compare the PINN prediction of vx (head 1) with the finite-difference result `vx_num_v1`
# computed by the head-1 solver cell above.
# NOTE(review): that solver cell uses mu = 0.1 and a uniform initial density, while the PINN
# training cell uses mu = 0.3 and rho0 = 2 + sin(x) — confirm the two runs are meant to match.

# Evaluation grid, indexed (x, t) so that arrays share the solver's (space, time) layout
x = np.linspace(0, 1, 100).astype(np.float32)
t = np.linspace(0, 1, 50).astype(np.float32)
X, T = np.meshgrid(x, t, indexing='ij')  # shape (100, 50)

# Flatten the grid into (N, 1) float32 column tensors
x_tensor = torch.tensor(X.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# PINN prediction: vx is the second network output, hence index [1]
with torch.no_grad():
    if next(nets[1].parameters()).dtype != torch.float32:
        for net in nets:
            net.float()
    vx_pinn_flat = solutions(x_tensor, t_tensor, head_idx=0)[1].cpu().numpy().flatten()

vx_pinn = vx_pinn_flat.reshape(X.shape)

# Interpolate the solver output onto (X, T). The solver grid is uniform on [0, 1] in both
# directions (X = L = 1.0 in the solver cell), so its axes are rebuilt from the array shape.
# The previous code interpolated `vx_num`, a name this cell and the solver cell never define.
x_num = np.linspace(0, 1, vx_num_v1.shape[0])
t_num = np.linspace(0, 1, vx_num_v1.shape[1])
interp_func = RegularGridInterpolator((x_num, t_num), vx_num_v1)
vx_num_interp = interp_func(np.stack([X.flatten(), T.flatten()], axis=-1)).reshape(X.shape)

# Absolute error between the two solutions
error = np.abs(vx_pinn - vx_num_interp)

# Three panels: PINN, solver, absolute difference
fig, axs = plt.subplots(1, 3, figsize=(18, 5))

# PINN vx
im0 = axs[0].pcolormesh(X, T, vx_pinn, shading='auto', cmap='viridis')
axs[0].set_title('vx PINN (head 1)')
axs[0].set_xlabel('x')
axs[0].set_ylabel('t')
fig.colorbar(im0, ax=axs[0])

# Solver vx
im1 = axs[1].pcolormesh(X, T, vx_num_interp, shading='auto', cmap='viridis')
axs[1].set_title('vx Solveur')
axs[1].set_xlabel('x')
axs[1].set_ylabel('t')
fig.colorbar(im1, ax=axs[1])

# Error
im2 = axs[2].pcolormesh(X, T, error, shading='auto', cmap='hot')
axs[2].set_title('|vx PINN - vx Solveur|')
axs[2].set_xlabel('x')
axs[2].set_ylabel('t')
fig.colorbar(im2, ax=axs[2])

plt.tight_layout()
plt.show()

# Error summary
print(f"Erreur absolue max : {np.max(error):.2e}")
print(f"Erreur absolue moyenne : {np.mean(error):.2e}")


In [None]:
### solveur without ff
### head 2

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# Parameters
X   = 1.0       # length of the spatial domain
nx  = 300       # number of spatial grid points
L   = 1.0       # final time
mu  = 0.1       # viscosity (NOTE(review): the PINN training cell sets mu = 0.3 — confirm which value the comparison should use)

def forcing_terms(x, t):
    """
    Evaluate the eight forcing terms f0..f7 at positions x and time t.

    Returns an array with f0..f7 stacked along the first axis, i.e. shape (8, N)
    when x holds N positions and t is a scalar.
    """
    # Factors shared by several terms
    sin_x, cos_x = np.sin(x), np.cos(x)
    e_t, e_2t, e_3t = np.exp(-t), np.exp(-2*t), np.exp(-3*t)
    e_x_t = np.exp(-x-t)
    e_x_2t = np.exp(-x-2*t)
    e_2x_2t = np.exp(-2*x-2*t)

    f0 = -sin_x * e_t+x*cos_x*e_2t+2*e_t+sin_x*e_2t
    f1 = -2*x*e_t-sin_x * e_2t*x+2*x*e_2t+x*sin_x*e_3t-e_x_t-2*e_2x_2t
    f2 = -2*x*e_t-sin_x *x* e_2t+2*x*e_2t+x*sin_x*e_3t+e_2x_2t
    f4 = - e_x_t+e_x_2t-x*e_x_2t
    f5 = -e_x_t
    f6 = -e_x_t-x*e_x_2t

    # f3 has the same expression as f2, and f7 the same as f6
    return np.array([f0, f1, f2, f2, f4, f5, f6, f6])


# Grid: nx points on [0, X]; the solution is requested at nx output times on [0, L]
x   = np.linspace(0, X, nx)
dx  = x[1] - x[0]
t_eval = np.linspace(0, L, nx)

# Initial state at t = 0 (cubic velocity profile)
def init_cond():
    """Return the initial state vector on the grid `x`.

    The eight fields are concatenated in the order rho, vx, vy, vz, P, Bx, By, Bz:
    rho = 2 everywhere, the three velocity components equal x^3, and the pressure and
    magnetic-field components equal exp(-x).
    NOTE(review): the PINN initial condition for rho is 2 + sin(x) — confirm the uniform
    density here is intended for the comparison.
    """
    density = 2*np.ones_like(x)
    velocity = x*x*x
    exp_decay = np.exp(-x)
    fields = [density, velocity, velocity, velocity, exp_decay, exp_decay, exp_decay, exp_decay]
    return np.concatenate(fields)

# Spatial derivatives on the uniform grid (spacing dx), both built on np.gradient
def d_dx(u):
    """First derivative of u with np.gradient: central differences in the interior and
    one-sided differences at the two end points (np.gradient defaults)."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Second derivative of u, obtained by applying the first-derivative operator twice."""
    return d_dx(d_dx(u))

# Right-hand side of the semi-discretised system: dy/dt = mhd_rhs(t, y)
def mhd_rhs(t, y):
    """Time derivative of the state for solve_ivp (system without forcing terms).

    Args:
        t: current time.
        y: flat state vector holding eight equally sized fields concatenated in the
           order rho, vx, vy, vz, P, Bx, By, Bz.

    Returns:
        Flat vector of the eight time derivatives, concatenated in the same order.

    Reads the globals `mu`, `d_dx` and `d2_dx2`. The forcing terms are not evaluated:
    they were computed but never used in this forcing-free variant.
    """
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)

    # Each spatial derivative is evaluated once and reused by the equations below
    rho_x, vx_x, vy_x, vz_x = d_dx(rho), d_dx(vx), d_dx(vy), d_dx(vz)
    P_x, By_x, Bz_x = d_dx(P), d_dx(By), d_dx(Bz)
    inv_rho = 1/rho

    drho_dt = -vx * rho_x - rho * vx_x
    dvx_dt = (-vx * vx_x - inv_rho * P_x
              - inv_rho*(By * By_x + Bz * Bz_x)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * vy_x + inv_rho*Bx*By_x
    dvz_dt = -vx * vz_x + inv_rho*Bx*Bz_x
    dP_dt = -P * vx_x - vx * P_x
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * By_x - By * vx_x + Bx * vy_x
    dBz_dt = -vx * Bz_x - Bz * vx_x + Bx * vz_x

    # Boundary at the first grid point, enforced through the time derivative.
    # rho, vx, vy, vz are held constant there -> zero time derivative
    drho_dt[0] = 0.0
    dvx_dt[0]  = 0.0
    dvy_dt[0]  = 0.0
    dvz_dt[0]  = 0.0

    # P, Bx, By, Bz follow exp(-t) there -> time derivative -exp(-t)
    bc_dot = -np.exp(-t)
    dP_dt[0]  = bc_dot
    dBx_dt[0] = bc_dot
    dBy_dt[0] = bc_dot
    dBz_dt[0] = bc_dot

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

# Time integration with RK45, sampled at the times in t_eval
y0  = init_cond()
sol = solve_ivp(mhd_rhs, (0, L), y0,
                method='RK45', t_eval=t_eval)

# vx is the second of the eight stacked fields -> rows nx .. 2*nx-1, one column per output time
vx_num_v2 = sol.y[nx:2*nx, :]
T, Xg  = np.meshgrid(sol.t, x)

# Reference field x^3 * exp(-t) and signed error against it
vx_ex = Xg*Xg*Xg * np.exp(-T)
err = vx_num_v2 - vx_ex
err_bound = np.max(abs(err))  # symmetric colour range for the error map

# Panel 1: numerical vx (horizontal axis = time, vertical axis = position)
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.imshow(vx_num_v2, extent=[0,L,0,X], aspect='auto', origin='lower')
plt.colorbar(label='vx_num')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('vx numérique')

# Panel 2: error vx_num - vx_ex
plt.subplot(1,2,2)
plt.imshow(err, extent=[0,L,0,X], aspect='auto',
           origin='lower', cmap='bwr', vmin=-err_bound, vmax=err_bound)
plt.colorbar(label='Erreur')
plt.xlabel('Temps t')
plt.ylabel('Position x')
# The reference in this cell is x^3 e^{-t}; the title previously said x e^{-t}
plt.title('Erreur $v_x^{num}-x^3 e^{-t}$')

plt.tight_layout()
plt.show()

In [None]:
### solveur without ff, with direct BC
### head 2

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# Parameters
X   = 1.0       # length of the spatial domain
nx  = 300       # number of spatial grid points
L   = 1.0       # final time
mu  = 0.1       # viscosity (NOTE(review): the PINN training cell sets mu = 0.3 — confirm which value the comparison should use)

def forcing_terms(x, t):
    """
    Evaluate the eight forcing terms f0..f7 at positions x and time t.

    Returns an array with f0..f7 stacked along the first axis, i.e. shape (8, N)
    when x holds N positions and t is a scalar.
    """
    # Factors shared by several terms
    sin_x, cos_x = np.sin(x), np.cos(x)
    e_t, e_2t, e_3t = np.exp(-t), np.exp(-2*t), np.exp(-3*t)
    e_x_t = np.exp(-x-t)
    e_x_2t = np.exp(-x-2*t)
    e_2x_2t = np.exp(-2*x-2*t)

    f0 = -sin_x * e_t+x*cos_x*e_2t+2*e_t+sin_x*e_2t
    f1 = -2*x*e_t-sin_x * e_2t*x+2*x*e_2t+x*sin_x*e_3t-e_x_t-2*e_2x_2t
    f2 = -2*x*e_t-sin_x *x* e_2t+2*x*e_2t+x*sin_x*e_3t+e_2x_2t
    f4 = - e_x_t+e_x_2t-x*e_x_2t
    f5 = -e_x_t
    f6 = -e_x_t-x*e_x_2t

    # f3 has the same expression as f2, and f7 the same as f6
    return np.array([f0, f1, f2, f2, f4, f5, f6, f6])


# Grid: nx points on [0, X]; the solution is requested at nx output times on [0, L]
x   = np.linspace(0, X, nx)
dx  = x[1] - x[0]
t_eval = np.linspace(0, L, nx)

# Initial state at t = 0 (cubic velocity profile)
def init_cond():
    """Return the initial state vector on the grid `x`.

    The eight fields are concatenated in the order rho, vx, vy, vz, P, Bx, By, Bz:
    rho = 2 everywhere, the three velocity components equal x^3, and the pressure and
    magnetic-field components equal exp(-x).
    NOTE(review): the PINN initial condition for rho is 2 + sin(x) — confirm the uniform
    density here is intended for the comparison.
    """
    density = 2*np.ones_like(x)
    velocity = x*x*x
    exp_decay = np.exp(-x)
    fields = [density, velocity, velocity, velocity, exp_decay, exp_decay, exp_decay, exp_decay]
    return np.concatenate(fields)

# Spatial derivatives on the uniform grid (spacing dx), both built on np.gradient
def d_dx(u):
    """First derivative of u with np.gradient: central differences in the interior and
    one-sided differences at the two end points (np.gradient defaults)."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Second derivative of u, obtained by applying the first-derivative operator twice."""
    return d_dx(d_dx(u))

# Right-hand side dy/dt = mhd_rhs(t, y) with the boundary values at the first grid point imposed directly
def mhd_rhs(t, y):
    """Time derivative of the state for solve_ivp (system without forcing terms).

    Args:
        t: current time.
        y: flat state vector holding eight equally sized fields concatenated in the
           order rho, vx, vy, vz, P, Bx, By, Bz. It is not modified.

    Returns:
        Flat vector of the eight time derivatives, concatenated in the same order.

    Boundary handling at the first grid point: rho = 2, vx = vy = vz = 0 and
    P = Bx = By = Bz = exp(-t). These values are written into a private copy of the state
    before the spatial derivatives are taken, and the returned derivative at that point is
    the time derivative of the prescribed values, so the integrated state follows them.
    Previously the values were written into views of `y` after the derivatives had been
    computed, which changed the integrator's own array in place and had no effect on the
    derivative returned by the same call.

    Reads the globals `mu`, `d_dx` and `d2_dx2`; the unused forcing-term evaluation was dropped.
    """
    # np.split returns views, so split a copy to keep the caller's array untouched
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(np.array(y, dtype=float), 8)

    # Impose the boundary values before differentiating so the stencils see them
    bc_value = np.exp(-t)
    rho[0] = 2
    vx[0] = vy[0] = vz[0] = 0
    P[0] = Bx[0] = By[0] = Bz[0] = bc_value

    # Each spatial derivative is evaluated once and reused by the equations below
    rho_x, vx_x, vy_x, vz_x = d_dx(rho), d_dx(vx), d_dx(vy), d_dx(vz)
    P_x, By_x, Bz_x = d_dx(P), d_dx(By), d_dx(Bz)
    inv_rho = 1/rho

    drho_dt = -vx * rho_x - rho * vx_x
    dvx_dt = (-vx * vx_x - inv_rho * P_x
              - inv_rho*(By * By_x + Bz * Bz_x)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * vy_x + inv_rho*Bx*By_x
    dvz_dt = -vx * vz_x + inv_rho*Bx*Bz_x
    dP_dt = -P * vx_x - vx * P_x
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * By_x - By * vx_x + Bx * vy_x
    dBz_dt = -vx * Bz_x - Bz * vx_x + Bx * vz_x

    # Keep the boundary point on its prescribed values: d/dt(const) = 0, d/dt exp(-t) = -exp(-t)
    drho_dt[0] = dvx_dt[0] = dvy_dt[0] = dvz_dt[0] = 0.0
    dP_dt[0] = dBx_dt[0] = dBy_dt[0] = dBz_dt[0] = -bc_value

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

# Time integration with RK45, sampled at the times in t_eval
y0  = init_cond()
sol = solve_ivp(mhd_rhs, (0, L), y0,
                method='RK45', t_eval=t_eval)

# vx is the second of the eight stacked fields -> rows nx .. 2*nx-1, one column per output time
vx_num_v2 = sol.y[nx:2*nx, :]
T, Xg  = np.meshgrid(sol.t, x)

# Reference field x^3 * exp(-t) and signed error against it
vx_ex = Xg*Xg*Xg * np.exp(-T)
err = vx_num_v2 - vx_ex
err_bound = np.max(abs(err))  # symmetric colour range for the error map

# Panel 1: numerical vx (horizontal axis = time, vertical axis = position)
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.imshow(vx_num_v2, extent=[0,L,0,X], aspect='auto', origin='lower')
plt.colorbar(label='vx_num')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('vx numérique')

# Panel 2: error vx_num - vx_ex
plt.subplot(1,2,2)
plt.imshow(err, extent=[0,L,0,X], aspect='auto',
           origin='lower', cmap='bwr', vmin=-err_bound, vmax=err_bound)
plt.colorbar(label='Erreur')
plt.xlabel('Temps t')
plt.ylabel('Position x')
# The reference in this cell is x^3 e^{-t}; the title previously said x e^{-t}
plt.title('Erreur $v_x^{num}-x^3 e^{-t}$')

plt.tight_layout()
plt.show()

In [None]:
### solveur and PINN head 1
### works


import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import torch

# Paramètres
X = 1.0       # longueur du domaine spatial
nx = 300      # nombre de points spatiaux
L = 1.0       # durée finale
mu = 0.1      # viscosité


def forcing_terms(x, t):
    """
    Evaluate the eight forcing functions f0..f7 at positions `x` and time `t`.

    Returns an array stacking f0..f7 along the first axis, i.e. shape (8, N)
    when `x` holds N points. f3 equals f2 and f7 equals f6.
    """
    # Shared building blocks
    sin_x = np.sin(x)
    e1 = np.exp(-t)
    e2 = np.exp(-2*t)
    e3 = np.exp(-3*t)
    exp_xt = np.exp(-x-t)
    exp_x2t = np.exp(-x-2*t)
    exp_2x2t = np.exp(-2*x-2*t)

    f0 = -sin_x * e1 + x*np.cos(x)*e2 + 2*e1 + sin_x*e2
    f1 = -2*x*e1 - sin_x * e2*x + 2*x*e2 + x*sin_x*e3 - exp_xt - 2*exp_2x2t
    f2 = -2*x*e1 - sin_x *x* e2 + 2*x*e2 + x*sin_x*e3 + exp_2x2t
    f4 = - exp_xt + exp_x2t - x*exp_x2t
    f5 = -exp_xt
    f6 = -exp_xt - x*exp_x2t

    # f3 and f7 reuse the expressions of f2 and f6 respectively
    return np.array([f0, f1, f2, f2, f4, f5, f6, f6])  # shape (8, N)


# Grid
x = np.linspace(0, X, num=nx)
t_eval = np.linspace(0, L, num=nx)  # as many output times as grid points
dx = x[1] - x[0]                    # spacing between neighbouring grid points

# Initial conditions
def init_cond():
    """Build the initial state vector on the global grid `x`.

    Returns the eight fields concatenated in the order
    rho, vx, vy, vz, P, Bx, By, Bz.
    """
    density = 2*np.ones_like(x)
    velocity = x              # shared by vx, vy and vz
    decay = np.exp(-x)        # shared by P, Bx, By and Bz
    fields = [density] + [velocity] * 3 + [decay] * 4
    return np.concatenate(fields)

# Spatial derivatives via np.gradient with the global spacing `dx`
# (central differences in the interior, one-sided differences at both ends)
def d_dx(u):
    """Return the first spatial derivative of `u`."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Return the second spatial derivative of `u` by applying np.gradient twice."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# ODE system function: dy/dt = ...
def mhd_rhs(t, y):
    """Right-hand side dy/dt of the semi-discretised 1D MHD system.

    Parameters
    ----------
    t : float
        Current time; only used for the boundary value exp(-t) written below.
    y : ndarray
        State vector made of eight equally sized blocks in the order
        rho, vx, vy, vz, P, Bx, By, Bz.

    Returns
    -------
    ndarray
        Time derivatives of the eight fields, concatenated in the same order.

    Notes
    -----
    No forcing term enters any of the equations, so the former (unused)
    ``forcing_terms(x, t)`` evaluation is no longer performed.
    Relies on the globals ``mu``, ``d_dx`` and ``d2_dx2``.
    """
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)
    inv_rho = 1 / rho

    # Every first derivative is evaluated once and reused in the equations below.
    rho_x, vx_x, vy_x, vz_x = d_dx(rho), d_dx(vx), d_dx(vy), d_dx(vz)
    P_x, By_x, Bz_x = d_dx(P), d_dx(By), d_dx(Bz)

    drho_dt = -vx * rho_x - rho * vx_x
    dvx_dt = (-vx * vx_x - inv_rho * P_x
              - inv_rho * (By * By_x + Bz * Bz_x)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * vy_x + inv_rho * Bx * By_x
    dvz_dt = -vx * vz_x + inv_rho * Bx * Bz_x
    dP_dt = -P * vx_x - vx * P_x
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * By_x - By * vx_x + Bx * vy_x
    dBz_dt = -vx * Bz_x - Bz * vx_x + Bx * vz_x

    # Values imposed at the first grid point. np.split returns views, so these
    # writes modify the caller's `y` in place; they come after the derivatives
    # were computed and therefore do not change this call's return value.
    # NOTE(review): the boundary condition only takes effect if the integrator
    # keeps using the array it passed in -- confirm against solve_ivp.
    rho[0] = 2
    vx[0] = vy[0] = vz[0] = 0
    P[0] = Bx[0] = By[0] = Bz[0] = np.exp(-t)

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])



# Integrate the semi-discretised system
y0 = init_cond()
solv1 = solve_ivp(mhd_rhs, (0, L), y0, method='RK45', t_eval=t_eval)

# Numerical vx: second block of nx rows in the state vector
vx_num_v1 = solv1.y[nx:2*nx, :]  # shape (nx, nt)
# Build the grid from this cell's own result (was `sol.t`, left over from another cell)
T, Xg = np.meshgrid(solv1.t, x)

# Reference solution used for the comparison: x * exp(-t)
vx_ex = Xg * np.exp(-T)

# Solver error with respect to the reference
err = vx_num_v1 - vx_ex

# PINN inputs: one (x, t) pair per grid node, as float32 column tensors
x_tensor = torch.tensor(Xg.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# PINN prediction with head_idx=0 (first head), output index 1 (compared with vx)
with torch.no_grad():
    # Cast every network to float32 if the checked one is not already float32
    if next(nets[1].parameters()).dtype != torch.float32:
        for net in nets:
            net.float()
    vx_pinn_flat = solutions(x_tensor, t_tensor, head_idx=0)[1].cpu().numpy().flatten()

vx_pinn = vx_pinn_flat.reshape(Xg.shape)

# Solver and PINN are evaluated on the same grid, so no interpolation is needed

plt.figure(figsize=(18, 12))

# Panel 1: solver solution
plt.subplot(2, 2, 1)
plt.pcolormesh(T, Xg, vx_num_v1, shading='auto', cmap='viridis')
plt.colorbar(label='vx')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Solution du Solveur (vx numérique)')

# Panel 2: solver error vs reference, colour scale symmetric around zero
plt.subplot(2, 2, 2)
plt.pcolormesh(T, Xg, err, shading='auto', cmap='bwr', 
               vmin=-np.max(np.abs(err)), vmax=np.max(np.abs(err)))
plt.colorbar(label='Erreur')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Erreur Solveur vs Solution Exacte')

# Panel 3: PINN prediction (title now names the head selected by head_idx=0)
plt.subplot(2, 2, 3)
plt.pcolormesh(T, Xg, vx_pinn, shading='auto', cmap='viridis')
plt.colorbar(label='vx')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Prédiction PINN (tête 1)')

# Panel 4: solver minus PINN, using this cell's solver run (was `vx_num_v2`).
# Stored as `solver_pinn_diff` instead of `diff` so that the neurodiffeq `diff`
# operator imported in the first cell is not overwritten.
plt.subplot(2, 2, 4)
solver_pinn_diff = vx_num_v1 - vx_pinn
plt.pcolormesh(T, Xg, solver_pinn_diff, shading='auto', cmap='bwr',
               vmin=-np.max(np.abs(solver_pinn_diff)), vmax=np.max(np.abs(solver_pinn_diff)))
plt.colorbar(label='Différence')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Différence Solveur-PINN')

plt.tight_layout()
plt.show()

# Error summary
print("\nAnalyse d'erreur:")
print(f"Erreur max Solveur vs Exact: {np.max(np.abs(err)):.2e}")
print(f"Erreur max Solveur vs PINN: {np.max(np.abs(solver_pinn_diff)):.2e}")
print(f"Erreur moyenne Solveur vs PINN: {np.mean(np.abs(solver_pinn_diff)):.2e}")

In [None]:
### solveur and PINN head 2, vx
### works


import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import torch

# Parameters
X, L = 1.0, 1.0   # spatial domain length, final time
nx = 300          # number of spatial grid points
mu = 0.1          # viscosity


def forcing_terms(x, t):
    """
    Evaluate the eight forcing functions f0..f7 at positions `x` and time `t`.

    Returns an array stacking f0..f7 along the first axis, i.e. shape (8, N)
    when `x` holds N points. f3 equals f2 and f7 equals f6.
    """
    # Shared building blocks
    sin_x = np.sin(x)
    e1 = np.exp(-t)
    e2 = np.exp(-2*t)
    e3 = np.exp(-3*t)
    exp_xt = np.exp(-x-t)
    exp_x2t = np.exp(-x-2*t)
    exp_2x2t = np.exp(-2*x-2*t)

    f0 = -sin_x * e1 + x*np.cos(x)*e2 + 2*e1 + sin_x*e2
    f1 = -2*x*e1 - sin_x * e2*x + 2*x*e2 + x*sin_x*e3 - exp_xt - 2*exp_2x2t
    f2 = -2*x*e1 - sin_x *x* e2 + 2*x*e2 + x*sin_x*e3 + exp_2x2t
    f4 = - exp_xt + exp_x2t - x*exp_x2t
    f5 = -exp_xt
    f6 = -exp_xt - x*exp_x2t

    # f3 and f7 reuse the expressions of f2 and f6 respectively
    return np.array([f0, f1, f2, f2, f4, f5, f6, f6])  # shape (8, N)


# Grid
x = np.linspace(0, X, num=nx)
t_eval = np.linspace(0, L, num=nx)  # as many output times as grid points
dx = x[1] - x[0]                    # spacing between neighbouring grid points

# Initial conditions
def init_cond():
    """Build the initial state vector on the global grid `x`.

    Returns the eight fields concatenated in the order
    rho, vx, vy, vz, P, Bx, By, Bz.
    """
    density = 2*np.ones_like(x)
    cubic = x*x*x             # shared by vx, vy and vz
    decay = np.exp(-x)        # shared by P, Bx, By and Bz
    fields = [density] + [cubic] * 3 + [decay] * 4
    return np.concatenate(fields)

# Spatial derivatives via np.gradient with the global spacing `dx`
# (central differences in the interior, one-sided differences at both ends)
def d_dx(u):
    """Return the first spatial derivative of `u`."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Return the second spatial derivative of `u` by applying np.gradient twice."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# ODE system function: dy/dt = ...
def mhd_rhs(t, y):
    """Right-hand side dy/dt of the semi-discretised 1D MHD system.

    Parameters
    ----------
    t : float
        Current time; only used for the boundary value exp(-t) written below.
    y : ndarray
        State vector made of eight equally sized blocks in the order
        rho, vx, vy, vz, P, Bx, By, Bz.

    Returns
    -------
    ndarray
        Time derivatives of the eight fields, concatenated in the same order.

    Notes
    -----
    No forcing term enters any of the equations, so the former (unused)
    ``forcing_terms(x, t)`` evaluation is no longer performed.
    Relies on the globals ``mu``, ``d_dx`` and ``d2_dx2``.
    """
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)
    inv_rho = 1 / rho

    # Every first derivative is evaluated once and reused in the equations below.
    rho_x, vx_x, vy_x, vz_x = d_dx(rho), d_dx(vx), d_dx(vy), d_dx(vz)
    P_x, By_x, Bz_x = d_dx(P), d_dx(By), d_dx(Bz)

    drho_dt = -vx * rho_x - rho * vx_x
    dvx_dt = (-vx * vx_x - inv_rho * P_x
              - inv_rho * (By * By_x + Bz * Bz_x)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * vy_x + inv_rho * Bx * By_x
    dvz_dt = -vx * vz_x + inv_rho * Bx * Bz_x
    dP_dt = -P * vx_x - vx * P_x
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * By_x - By * vx_x + Bx * vy_x
    dBz_dt = -vx * Bz_x - Bz * vx_x + Bx * vz_x

    # Values imposed at the first grid point. np.split returns views, so these
    # writes modify the caller's `y` in place; they come after the derivatives
    # were computed and therefore do not change this call's return value.
    # NOTE(review): the boundary condition only takes effect if the integrator
    # keeps using the array it passed in -- confirm against solve_ivp.
    rho[0] = 2
    vx[0] = vy[0] = vz[0] = 0
    P[0] = Bx[0] = By[0] = Bz[0] = np.exp(-t)

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])



# Integrate the semi-discretised system
y0 = init_cond()
solv2 = solve_ivp(mhd_rhs, (0, L), y0, method='RK45', t_eval=t_eval)

# Numerical vx: second block of nx rows in the state vector
vx_num_v2 = solv2.y[nx:2*nx, :]  # shape (nx, nt)
# Build the grid from this cell's own result (was `sol.t`, left over from another cell)
T, Xg = np.meshgrid(solv2.t, x)

# Reference solution used for the comparison: x^3 * exp(-t)
vx_ex = Xg**3 * np.exp(-T)

# Solver error with respect to the reference
err = vx_num_v2 - vx_ex

# PINN inputs: one (x, t) pair per grid node, as float32 column tensors
x_tensor = torch.tensor(Xg.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# PINN prediction with head_idx=1 (second head), output index 1 (compared with vx)
with torch.no_grad():
    # Cast every network to float32 if the checked one is not already float32
    if next(nets[1].parameters()).dtype != torch.float32:
        for net in nets:
            net.float()
    vx_pinn_flat = solutions(x_tensor, t_tensor, head_idx=1)[1].cpu().numpy().flatten()

vx_pinn = vx_pinn_flat.reshape(Xg.shape)

# Solver and PINN are evaluated on the same grid, so no interpolation is needed

plt.figure(figsize=(18, 12))

# Panel 1: solver solution
plt.subplot(2, 2, 1)
plt.pcolormesh(T, Xg, vx_num_v2, shading='auto', cmap='viridis')
plt.colorbar(label='vx')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Solution du Solveur (vx numérique)')

# Panel 2: solver error vs reference, colour scale symmetric around zero
plt.subplot(2, 2, 2)
plt.pcolormesh(T, Xg, err, shading='auto', cmap='bwr', 
               vmin=-np.max(np.abs(err)), vmax=np.max(np.abs(err)))
plt.colorbar(label='Erreur')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Erreur Solveur vs Solution Exacte')

# Panel 3: PINN prediction
plt.subplot(2, 2, 3)
plt.pcolormesh(T, Xg, vx_pinn, shading='auto', cmap='viridis')
plt.colorbar(label='vx')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Prédiction PINN (tête 2)')

# Panel 4: solver minus PINN.
# Stored as `solver_pinn_diff` instead of `diff` so that the neurodiffeq `diff`
# operator imported in the first cell is not overwritten.
plt.subplot(2, 2, 4)
solver_pinn_diff = vx_num_v2 - vx_pinn
plt.pcolormesh(T, Xg, solver_pinn_diff, shading='auto', cmap='bwr',
               vmin=-np.max(np.abs(solver_pinn_diff)), vmax=np.max(np.abs(solver_pinn_diff)))
plt.colorbar(label='Différence')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Différence Solveur-PINN')

plt.tight_layout()
plt.show()

# Error summary
print("\nAnalyse d'erreur:")
print(f"Erreur max Solveur vs Exact: {np.max(np.abs(err)):.2e}")
print(f"Erreur max Solveur vs PINN: {np.max(np.abs(solver_pinn_diff)):.2e}")
print(f"Erreur moyenne Solveur vs PINN: {np.mean(np.abs(solver_pinn_diff)):.2e}")

In [None]:
### solveur and PINN head 1, rho
### works


import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import torch

# Parameters
X, L = 1.0, 1.0   # spatial domain length, final time
nx = 300          # number of spatial grid points
mu = 0.1          # viscosity

# Grid
x = np.linspace(0, X, num=nx)
t_eval = np.linspace(0, L, num=nx)  # as many output times as grid points
dx = x[1] - x[0]                    # spacing between neighbouring grid points

# Initial conditions
def init_cond():
    """Build the initial state vector on the global grid `x`.

    Returns the eight fields concatenated in the order
    rho, vx, vy, vz, P, Bx, By, Bz.
    """
    density = 2*np.ones_like(x)
    velocity = x              # shared by vx, vy and vz
    decay = np.exp(-x)        # shared by P, Bx, By and Bz
    fields = [density] + [velocity] * 3 + [decay] * 4
    return np.concatenate(fields)

# Spatial derivatives via np.gradient with the global spacing `dx`
# (central differences in the interior, one-sided differences at both ends)
def d_dx(u):
    """Return the first spatial derivative of `u`."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Return the second spatial derivative of `u` by applying np.gradient twice."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# ODE system function: dy/dt = ...
def mhd_rhs(t, y):
    """Right-hand side dy/dt of the semi-discretised 1D MHD system.

    Parameters
    ----------
    t : float
        Current time; only used for the boundary value exp(-t) written below.
    y : ndarray
        State vector made of eight equally sized blocks in the order
        rho, vx, vy, vz, P, Bx, By, Bz.

    Returns
    -------
    ndarray
        Time derivatives of the eight fields, concatenated in the same order.

    Notes
    -----
    No forcing term enters any of the equations, so the former (unused)
    ``forcing_terms(x, t)`` evaluation is no longer performed; this also
    removes the dependency on a function that this cell does not define.
    Relies on the globals ``mu``, ``d_dx`` and ``d2_dx2``.
    """
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)
    inv_rho = 1 / rho

    # Every first derivative is evaluated once and reused in the equations below.
    rho_x, vx_x, vy_x, vz_x = d_dx(rho), d_dx(vx), d_dx(vy), d_dx(vz)
    P_x, By_x, Bz_x = d_dx(P), d_dx(By), d_dx(Bz)

    drho_dt = -vx * rho_x - rho * vx_x
    dvx_dt = (-vx * vx_x - inv_rho * P_x
              - inv_rho * (By * By_x + Bz * Bz_x)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * vy_x + inv_rho * Bx * By_x
    dvz_dt = -vx * vz_x + inv_rho * Bx * Bz_x
    dP_dt = -P * vx_x - vx * P_x
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * By_x - By * vx_x + Bx * vy_x
    dBz_dt = -vx * Bz_x - Bz * vx_x + Bx * vz_x

    # Values imposed at the first grid point. np.split returns views, so these
    # writes modify the caller's `y` in place; they come after the derivatives
    # were computed and therefore do not change this call's return value.
    # NOTE(review): the boundary condition only takes effect if the integrator
    # keeps using the array it passed in -- confirm against solve_ivp.
    rho[0] = 2
    vx[0] = vy[0] = vz[0] = 0
    P[0] = Bx[0] = By[0] = Bz[0] = np.exp(-t)

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])



# Integrate the semi-discretised system
y0 = init_cond()
solv1 = solve_ivp(mhd_rhs, (0, L), y0, method='RK45', t_eval=t_eval)

# Numerical rho: first block of nx rows in the state vector.
# The vx_* variable names are kept from the template cell but hold rho here.
vx_num_v1 = solv1.y[0:nx, :]  # shape (nx, nt)
T, Xg = np.meshgrid(solv1.t, x)

# Reference solution used for the comparison: constant density 2
vx_ex = 2

# Solver error with respect to the reference
err = vx_num_v1 - vx_ex

# PINN inputs: one (x, t) pair per grid node, as float32 column tensors
x_tensor = torch.tensor(Xg.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# PINN prediction with head_idx=0 (first head), output index 0 (compared with rho)
with torch.no_grad():
    # Cast every network to float32 if the checked one is not already float32
    if next(nets[1].parameters()).dtype != torch.float32:
        for net in nets:
            net.float()
    vx_pinn_flat = solutions(x_tensor, t_tensor, head_idx=0)[0].cpu().numpy().flatten()

vx_pinn = vx_pinn_flat.reshape(Xg.shape)

# Solver and PINN are evaluated on the same grid, so no interpolation is needed

plt.figure(figsize=(18, 12))

# Panel 1: solver solution (labels now name the plotted field, rho)
plt.subplot(2, 2, 1)
plt.pcolormesh(T, Xg, vx_num_v1, shading='auto', cmap='viridis')
plt.colorbar(label='rho')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Solution du Solveur (rho numérique)')

# Panel 2: solver error vs reference, colour scale symmetric around zero
plt.subplot(2, 2, 2)
plt.pcolormesh(T, Xg, err, shading='auto', cmap='bwr', 
               vmin=-np.max(np.abs(err)), vmax=np.max(np.abs(err)))
plt.colorbar(label='Erreur')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Erreur Solveur vs Solution Exacte')

# Panel 3: PINN prediction
plt.subplot(2, 2, 3)
plt.pcolormesh(T, Xg, vx_pinn, shading='auto', cmap='viridis')
plt.colorbar(label='rho')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Prédiction PINN (tête 1)')

# Panel 4: solver minus PINN.
# Stored as `solver_pinn_diff` instead of `diff` so that the neurodiffeq `diff`
# operator imported in the first cell is not overwritten.
plt.subplot(2, 2, 4)
solver_pinn_diff = vx_num_v1 - vx_pinn
plt.pcolormesh(T, Xg, solver_pinn_diff, shading='auto', cmap='bwr',
               vmin=-np.max(np.abs(solver_pinn_diff)), vmax=np.max(np.abs(solver_pinn_diff)))
plt.colorbar(label='Différence')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Différence Solveur-PINN')

plt.tight_layout()
plt.show()

# Error summary
print("\nAnalyse d'erreur:")
print(f"Erreur max Solveur vs Exact: {np.max(np.abs(err)):.2e}")
print(f"Erreur max Solveur vs PINN: {np.max(np.abs(solver_pinn_diff)):.2e}")
print(f"Erreur moyenne Solveur vs PINN: {np.mean(np.abs(solver_pinn_diff)):.2e}")

In [None]:
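# Reference only (not executed): row layout of solv1.y.
# Each field occupies nx consecutive rows, in the order used by init_cond.
# rho = solv1.y[0:nx, :]           # shape (nx, nt)
# vx  = solv1.y[nx:2*nx, :]
# vy  = solv1.y[2*nx:3*nx, :]
# vz  = solv1.y[3*nx:4*nx, :]
# P   = solv1.y[4*nx:5*nx, :]
# Bx  = solv1.y[5*nx:6*nx, :]
# By  = solv1.y[6*nx:7*nx, :]
# Bz  = solv1.y[7*nx:8*nx, :]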


In [None]:
### solver and PINN head 1, Bx
### works


import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import torch

# Parameters
X, L = 1.0, 1.0   # spatial domain length, final time
nx = 300          # number of spatial grid points
mu = 0.1          # viscosity

# Grid
x = np.linspace(0, X, num=nx)
t_eval = np.linspace(0, L, num=nx)  # as many output times as grid points
dx = x[1] - x[0]                    # spacing between neighbouring grid points

# Initial conditions
def init_cond():
    """Build the initial state vector on the global grid `x`.

    Returns the eight fields concatenated in the order
    rho, vx, vy, vz, P, Bx, By, Bz.
    """
    density = 2*np.ones_like(x)
    velocity = x              # shared by vx, vy and vz
    decay = np.exp(-x)        # shared by P, Bx, By and Bz
    fields = [density] + [velocity] * 3 + [decay] * 4
    return np.concatenate(fields)

# Spatial derivatives via np.gradient with the global spacing `dx`
# (central differences in the interior, one-sided differences at both ends)
def d_dx(u):
    """Return the first spatial derivative of `u`."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Return the second spatial derivative of `u` by applying np.gradient twice."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# ODE system function: dy/dt = ...
def mhd_rhs(t, y):
    """Right-hand side dy/dt of the semi-discretised 1D MHD system.

    Parameters
    ----------
    t : float
        Current time; only used for the boundary value exp(-t) written below.
    y : ndarray
        State vector made of eight equally sized blocks in the order
        rho, vx, vy, vz, P, Bx, By, Bz.

    Returns
    -------
    ndarray
        Time derivatives of the eight fields, concatenated in the same order.

    Notes
    -----
    No forcing term enters any of the equations, so the former (unused)
    ``forcing_terms(x, t)`` evaluation is no longer performed; this also
    removes the dependency on a function that this cell does not define.
    Relies on the globals ``mu``, ``d_dx`` and ``d2_dx2``.
    """
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)
    inv_rho = 1 / rho

    # Every first derivative is evaluated once and reused in the equations below.
    rho_x, vx_x, vy_x, vz_x = d_dx(rho), d_dx(vx), d_dx(vy), d_dx(vz)
    P_x, By_x, Bz_x = d_dx(P), d_dx(By), d_dx(Bz)

    drho_dt = -vx * rho_x - rho * vx_x
    dvx_dt = (-vx * vx_x - inv_rho * P_x
              - inv_rho * (By * By_x + Bz * Bz_x)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * vy_x + inv_rho * Bx * By_x
    dvz_dt = -vx * vz_x + inv_rho * Bx * Bz_x
    dP_dt = -P * vx_x - vx * P_x
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * By_x - By * vx_x + Bx * vy_x
    dBz_dt = -vx * Bz_x - Bz * vx_x + Bx * vz_x

    # Values imposed at the first grid point. np.split returns views, so these
    # writes modify the caller's `y` in place; they come after the derivatives
    # were computed and therefore do not change this call's return value.
    # NOTE(review): the boundary condition only takes effect if the integrator
    # keeps using the array it passed in -- confirm against solve_ivp.
    rho[0] = 2
    vx[0] = vy[0] = vz[0] = 0
    P[0] = Bx[0] = By[0] = Bz[0] = np.exp(-t)

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])



# Integrate the semi-discretised system
y0 = init_cond()
solv1 = solve_ivp(mhd_rhs, (0, L), y0, method='RK45', t_eval=t_eval)

# Numerical Bx: sixth block of nx rows in the state vector.
# The vx_* variable names are kept from the template cell but hold Bx here.
vx_num_v1 = solv1.y[5*nx:6*nx, :]  # shape (nx, nt)
T, Xg = np.meshgrid(solv1.t, x)

# Reference solution used for the comparison: exp(-x) * exp(-t)
vx_ex = np.exp(-Xg)*np.exp(-T)
# Solver error with respect to the reference
err = vx_num_v1 - vx_ex

# PINN inputs: one (x, t) pair per grid node, as float32 column tensors
x_tensor = torch.tensor(Xg.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# PINN prediction with head_idx=0 (first head), output index 7 (compared with Bx)
with torch.no_grad():
    # Cast every network to float32 if the checked one is not already float32
    if next(nets[1].parameters()).dtype != torch.float32:
        for net in nets:
            net.float()
    vx_pinn_flat = solutions(x_tensor, t_tensor, head_idx=0)[7].cpu().numpy().flatten()

vx_pinn = vx_pinn_flat.reshape(Xg.shape)

# Solver and PINN are evaluated on the same grid, so no interpolation is needed

plt.figure(figsize=(18, 12))

# Panel 1: solver solution (labels now name the plotted field, Bx)
plt.subplot(2, 2, 1)
plt.pcolormesh(T, Xg, vx_num_v1, shading='auto', cmap='viridis')
plt.colorbar(label='Bx')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Solution du Solveur (Bx numérique)')

# Panel 2: solver error vs reference, colour scale symmetric around zero
plt.subplot(2, 2, 2)
plt.pcolormesh(T, Xg, err, shading='auto', cmap='bwr', 
               vmin=-np.max(np.abs(err)), vmax=np.max(np.abs(err)))
plt.colorbar(label='Erreur')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Erreur Solveur vs Solution Exacte')

# Panel 3: PINN prediction
plt.subplot(2, 2, 3)
plt.pcolormesh(T, Xg, vx_pinn, shading='auto', cmap='viridis')
plt.colorbar(label='Bx')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Prédiction PINN (tête 1)')

# Panel 4: solver minus PINN.
# Stored as `solver_pinn_diff` instead of `diff` so that the neurodiffeq `diff`
# operator imported in the first cell is not overwritten.
plt.subplot(2, 2, 4)
solver_pinn_diff = vx_num_v1 - vx_pinn
plt.pcolormesh(T, Xg, solver_pinn_diff, shading='auto', cmap='bwr',
               vmin=-np.max(np.abs(solver_pinn_diff)), vmax=np.max(np.abs(solver_pinn_diff)))
plt.colorbar(label='Différence')
plt.xlabel('Temps t')
plt.ylabel('Position x')
plt.title('Différence Solveur-PINN')

plt.tight_layout()
plt.show()

# Error summary
print("\nAnalyse d'erreur:")
print(f"Erreur max Solveur vs Exact: {np.max(np.abs(err)):.2e}")
print(f"Erreur max Solveur vs PINN: {np.max(np.abs(solver_pinn_diff)):.2e}")
print(f"Erreur moyenne Solveur vs PINN: {np.mean(np.abs(solver_pinn_diff)):.2e}")

In [None]:
### solver and PINN head 1, vz
### works
## affiche parfait

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import torch

# Parameters
X, L = 1.0, 1.0   # spatial domain length, final time
nx = 300          # number of spatial grid points
mu = 0.1          # viscosity

# Grid
x = np.linspace(0, X, num=nx)
t_eval = np.linspace(0, L, num=nx)  # as many output times as grid points
dx = x[1] - x[0]                    # spacing between neighbouring grid points

# Initial conditions
def init_cond():
    """Build the initial state vector on the global grid `x`.

    Returns the eight fields concatenated in the order
    rho, vx, vy, vz, P, Bx, By, Bz.
    """
    density = 2*np.ones_like(x)
    velocity = x              # shared by vx, vy and vz
    decay = np.exp(-x)        # shared by P, Bx, By and Bz
    fields = [density] + [velocity] * 3 + [decay] * 4
    return np.concatenate(fields)

# Spatial derivatives via np.gradient with the global spacing `dx`
# (central differences in the interior, one-sided differences at both ends)
def d_dx(u):
    """Return the first spatial derivative of `u`."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Return the second spatial derivative of `u` by applying np.gradient twice."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# ODE system function: dy/dt = ...
def mhd_rhs(t, y):
    """Right-hand side dy/dt of the semi-discretised 1D MHD system.

    Parameters
    ----------
    t : float
        Current time; only used for the boundary value exp(-t) written below.
    y : ndarray
        State vector made of eight equally sized blocks in the order
        rho, vx, vy, vz, P, Bx, By, Bz.

    Returns
    -------
    ndarray
        Time derivatives of the eight fields, concatenated in the same order.

    Notes
    -----
    No forcing term enters any of the equations, so the former (unused)
    ``forcing_terms(x, t)`` evaluation is no longer performed; this also
    removes the dependency on a function that this cell does not define.
    Relies on the globals ``mu``, ``d_dx`` and ``d2_dx2``.
    """
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)
    inv_rho = 1 / rho

    # Every first derivative is evaluated once and reused in the equations below.
    rho_x, vx_x, vy_x, vz_x = d_dx(rho), d_dx(vx), d_dx(vy), d_dx(vz)
    P_x, By_x, Bz_x = d_dx(P), d_dx(By), d_dx(Bz)

    drho_dt = -vx * rho_x - rho * vx_x
    dvx_dt = (-vx * vx_x - inv_rho * P_x
              - inv_rho * (By * By_x + Bz * Bz_x)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * vy_x + inv_rho * Bx * By_x
    dvz_dt = -vx * vz_x + inv_rho * Bx * Bz_x
    dP_dt = -P * vx_x - vx * P_x
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * By_x - By * vx_x + Bx * vy_x
    dBz_dt = -vx * Bz_x - Bz * vx_x + Bx * vz_x

    # Values imposed at the first grid point. np.split returns views, so these
    # writes modify the caller's `y` in place; they come after the derivatives
    # were computed and therefore do not change this call's return value.
    # NOTE(review): the boundary condition only takes effect if the integrator
    # keeps using the array it passed in -- confirm against solve_ivp.
    rho[0] = 2
    vx[0] = vy[0] = vz[0] = 0
    P[0] = Bx[0] = By[0] = Bz[0] = np.exp(-t)

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

# Integrate the semi-discretised system
y0 = init_cond()
solv1 = solve_ivp(mhd_rhs, (0, L), y0, method='RK45', t_eval=t_eval)

# Numerical vz: fourth block of nx rows in the state vector.
# The vx_* variable names are kept from the template cell but hold vz here.
vx_num_v1 = solv1.y[3*nx:4*nx, :]  # shape (nx, nt)
T, Xg = np.meshgrid(solv1.t, x)

# Reference expression, computed but not plotted.
# NOTE(review): confirm exp(-x)*exp(-t) is the intended reference for vz,
# which starts from x in init_cond.
vx_ex = np.exp(-Xg)*np.exp(-T)
err = vx_num_v1 - vx_ex

# PINN inputs: one (x, t) pair per grid node, as float32 column tensors
x_tensor = torch.tensor(Xg.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# PINN prediction with head_idx=0 (first head), output index 6 (compared with vz)
with torch.no_grad():
    # Cast every network to float32 if the checked one is not already float32
    if next(nets[1].parameters()).dtype != torch.float32:
        for net in nets:
            net.float()
    vx_pinn_flat = solutions(x_tensor, t_tensor, head_idx=0)[6].cpu().numpy().flatten()

vx_pinn = vx_pinn_flat.reshape(Xg.shape)


# Figure with three panels
plt.figure(figsize=(18, 6))

# Font sizes shared by all panels
fontsize = 14  # axis labels and ticks
cbar_fontsize = 14  # colorbar label and ticks
title_fontsize = 16  # panel titles

# Panel 1: solver solution
plt.subplot(1, 3, 1)
img1 = plt.pcolormesh(T, Xg, vx_num_v1, shading='auto', cmap='viridis')
cbar1 = plt.colorbar(img1)
cbar1.set_label('vz', fontsize=cbar_fontsize)
cbar1.ax.tick_params(labelsize=cbar_fontsize)
plt.xlabel('t', fontsize=fontsize)
plt.ylabel('x', fontsize=fontsize)
plt.title('Solver solution', fontsize=title_fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)

# Panel 2: PINN prediction
plt.subplot(1, 3, 2)
img2 = plt.pcolormesh(T, Xg, vx_pinn, shading='auto', cmap='viridis')
cbar2 = plt.colorbar(img2)
cbar2.set_label('vz', fontsize=cbar_fontsize)
cbar2.ax.tick_params(labelsize=cbar_fontsize)
plt.xlabel('t', fontsize=fontsize)
plt.ylabel('x', fontsize=fontsize)
plt.title('PINN prediction(head 1)', fontsize=title_fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)

# Panel 3: solver minus PINN, colour scale symmetric around zero.
# Stored as `solver_pinn_diff` instead of `diff` so that the neurodiffeq `diff`
# operator imported in the first cell is not overwritten.
plt.subplot(1, 3, 3)
solver_pinn_diff = vx_num_v1 - vx_pinn
img3 = plt.pcolormesh(T, Xg, solver_pinn_diff, shading='auto', cmap='bwr',
               vmin=-np.max(np.abs(solver_pinn_diff)), vmax=np.max(np.abs(solver_pinn_diff)))
cbar3 = plt.colorbar(img3)
cbar3.set_label('Difference', fontsize=cbar_fontsize)
cbar3.ax.tick_params(labelsize=cbar_fontsize)
plt.xlabel('t', fontsize=fontsize)
plt.ylabel('x', fontsize=fontsize)
plt.title('Solver-PINN difference', fontsize=title_fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)

plt.tight_layout()
plt.show()

In [None]:
## TL on CIs (new head head3 initialized on head2)
## without ff, with BC

from neurodiffeq.operators import diff  # Import crucial pour les dérivées


class MultiHeadFCNN(nn.Module):
    """Fully connected network with a shared Tanh trunk and three linear heads.

    Parameters
    ----------
    n_input_units : int
        Size of the input features.
    n_output_units : int
        Size of each head's output.
    hidden_units : sequence of int
        Widths of the two shared hidden layers; only the first two entries
        are read. The default is an immutable tuple so it cannot be modified
        through one instance and leak into later ones.
    """

    def __init__(self, n_input_units=2, n_output_units=1, hidden_units=(256, 256)):
        super().__init__()
        first_width, second_width = hidden_units[0], hidden_units[1]

        # Shared layers (attribute name and layer order are kept so that
        # previously saved state dicts still match)
        self.shared_layers = nn.Sequential(
            nn.Linear(n_input_units, first_width),
            nn.Tanh(),
            nn.Linear(first_width, second_width),
            nn.Tanh()
        )

        # Existing heads (head1, head2) plus the new head (head3) for the new initial conditions
        self.head1 = nn.Linear(second_width, n_output_units)
        self.head2 = nn.Linear(second_width, n_output_units)
        self.head3 = nn.Linear(second_width, n_output_units)

    def forward(self, x, head_idx=2):
        """Run the shared trunk, then the head selected by `head_idx`.

        0 selects head1, 1 selects head2, and any other value (including the
        default 2) selects head3.
        """
        shared = self.shared_layers(x)
        if head_idx == 0:
            return self.head1(shared)
        if head_idx == 1:
            return self.head2(shared)
        return self.head3(shared)


## Containers for the transfer-learning run: networks and loss histories
nets, total_losses, pde_losses, ic_losses = [], [], [], []
bc_losses = []  # boundary-condition loss history

## New initial conditions for head3, keyed by field name
def _quadratic_profile(x):
    """Initial profile x*x shared by the three velocity components."""
    return x*x

def _exp_decay_profile(x):
    """Initial profile exp(-x) shared by the pressure and the magnetic field."""
    return torch.exp(-x)

new_initial_conditions = {'rho': lambda x: 2*torch.ones_like(x)}
new_initial_conditions.update(
    v=_quadratic_profile,
    P=_exp_decay_profile,
    By=_exp_decay_profile,
    Bz=_exp_decay_profile,
    vy=_quadratic_profile,
    vz=_quadratic_profile,
    Bx=_exp_decay_profile,
)

## Boundary conditions
def boundary_conditions(t):
    """Return the target values at x=0 for every field, keyed by field name.

    rho is 2, the three velocity components are 0, and P, Bx, By, Bz are
    exp(-t); every value has the same shape as `t`.
    """
    targets = {'rho': 2.0 * torch.ones_like(t)}
    for velocity_key in ('v', 'vy', 'vz'):
        targets[velocity_key] = torch.zeros_like(t)
    for decaying_key in ('P', 'Bx', 'By', 'Bz'):
        targets[decaying_key] = torch.exp(-t)
    return targets

## Load the pre-trained weights (one checkpoint per network)
for i in range(8):
    net = MultiHeadFCNN(n_input_units=2, n_output_units=1, hidden_units=[256, 256])

    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    pretrained_dict = torch.load(f"saved_weights/net_{i}.pt", map_location='cpu')

    # Start from the fresh state dict and overlay every stored tensor
    # (shared layers, head1, head2)
    model_dict = net.state_dict()
    model_dict.update(pretrained_dict)

    # Initialise head3 with a copy of the stored head2 parameters
    head2_names = [name for name in pretrained_dict if name.startswith("head2.")]
    for name in head2_names:
        corresponding_name = name.replace("head2.", "head3.")
        if corresponding_name in model_dict:
            model_dict[corresponding_name] = pretrained_dict[name].clone()

    net.load_state_dict(model_dict, strict=False)
    nets.append(net)

## Only head3 stays trainable; every other parameter is frozen
for net in nets:
    for name, param in net.named_parameters():
        is_head3 = 'head3' in name
        param.requires_grad = is_head3

## Optimiser, loss and learning-rate schedule
trainable_params = [p for net in nets for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)
criterion = nn.MSELoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.975)

## PDE system without forcing
def pde_systemTL(rho, vx, P, By, Bz, vy, vz, Bx, x, t):
    """Return the residuals of the eight unforced MHD equations.

    Parameters
    ----------
    rho, vx, P, By, Bz, vy, vz, Bx : torch.Tensor or array-like
        Field values at the points (x, t), in this argument order.
    x, t : torch.Tensor or array-like
        Coordinates the fields are differentiated with respect to.

    Returns
    -------
    list
        [eq1, ..., eq8]; each entry is zero where the matching equation holds.

    Notes
    -----
    Uses the global viscosity ``mu`` and the ``diff`` operator. Spatial
    derivatives that appear in several equations are computed once and reused.
    """
    # Convert the coordinates to tensors if necessary
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x, dtype=torch.float32, requires_grad=True)
    if not isinstance(t, torch.Tensor):
        t = torch.tensor(t, dtype=torch.float32, requires_grad=True)

    # Make sure every field is a tensor
    variables = [rho, vx, P, By, Bz, vy, vz, Bx]
    for i, var in enumerate(variables):
        if not isinstance(var, torch.Tensor):
            variables[i] = torch.tensor(var, dtype=torch.float32, requires_grad=True)
    rho, vx, P, By, Bz, vy, vz, Bx = variables

    # Spatial derivatives shared by several equations (one autograd call each)
    vx_x = diff(vx, x)
    vy_x = diff(vy, x)
    vz_x = diff(vz, x)
    P_x = diff(P, x)
    By_x = diff(By, x)
    Bz_x = diff(Bz, x)

    # PDE residuals
    eq1 = diff(rho, t) + vx * diff(rho, x) + rho * vx_x
    eq2 = rho * diff(vx, t) + rho * vx * vx_x + P_x + By * By_x + Bz * Bz_x - rho * mu * diff(vx, x, order=2)
    eq3 = rho * diff(vy, t) + rho * vx * vy_x - Bx * By_x
    eq4 = rho * diff(vz, t) + rho * vx * vz_x - Bx * Bz_x
    eq5 = diff(P, t) + P * vx_x + vx * P_x
    eq6 = diff(Bx, t)
    eq7 = diff(By, t) + vx * By_x + By * vx_x - Bx * vy_x
    eq8 = diff(Bz, t) + vx * Bz_x + Bz * vx_x - Bx * vz_x

    return [eq1, eq2, eq3, eq4, eq5, eq6, eq7, eq8]


## Training
n_epochs = 12000
# Report every 10% of the run; max(1, ...) keeps the modulo below valid when n_epochs < 10
print_interval = max(1, n_epochs // 10)

# Field order of the outputs in `nets`, shared by the IC and BC targets
field_keys = ['rho', 'v', 'P', 'By', 'Bz', 'vy', 'vz', 'Bx']

for epoch in tqdm(range(n_epochs)):
    optimizer.zero_grad()

    # 1. PDE loss on the points provided by train_gen
    samples = train_gen.get_examples()
    x_train = samples[0].view(-1, 1)
    t_train = samples[1].view(-1, 1)
    inputs = torch.cat((x_train, t_train), dim=1)

    outputs = [net(inputs, head_idx=2) for net in nets]  # head3

    pde_residuals = pde_systemTL(*outputs, x_train, t_train)
    loss_pde = sum([criterion(res, torch.zeros_like(res)) for res in pde_residuals])

    # 2. Initial condition loss
    ic_inputs = torch.cat((ic_x, ic_t), dim=1)
    ic_outputs = [net(ic_inputs, head_idx=2) for net in nets]
    ic_targets = [new_initial_conditions[key](ic_x) for key in field_keys]
    loss_ic = sum([criterion(out, target) for out, target in zip(ic_outputs, ic_targets)])

    # 3. Boundary condition loss (x=0): only the sampled times are used
    bc_samples = bc_gen.get_examples()
    x_bc = torch.zeros_like(bc_samples[0]).view(-1, 1)  # x=0
    t_bc = bc_samples[1].view(-1, 1)
    bc_inputs = torch.cat((x_bc, t_bc), dim=1)

    bc_outputs = [net(bc_inputs, head_idx=2) for net in nets]
    bc_targets = boundary_conditions(t_bc)
    loss_bc = sum([criterion(out, bc_targets[key]) for out, key in zip(bc_outputs, field_keys)])

    # Total loss and parameter update
    total_loss = loss_pde + loss_ic + loss_bc
    total_loss.backward()
    optimizer.step()
    scheduler.step()

    # Print on the first epoch and then once per print_interval
    if (epoch + 1) % print_interval == 0 or epoch == 0:
        print(f"Epoch {epoch+1}/{n_epochs} ({100*(epoch+1)/n_epochs:.0f}%) | "
              f"PDE: {loss_pde.item():.3e} | "
              f"IC: {loss_ic.item():.3e} | "
              f"BC: {loss_bc.item():.3e} | "
              f"Total: {total_loss.item():.3e}")

    # Loss histories, one entry per epoch
    total_losses.append(total_loss.item())
    pde_losses.append(loss_pde.item())
    ic_losses.append(loss_ic.item())
    bc_losses.append(loss_bc.item())

# For evaluation
def solutions_new_ci(x, t):
    """Evaluate head3 (head_idx=2) of every network in `nets` at the points (x, t).

    `x` and `t` are concatenated along dim=1 to form the network input.
    Returns one detached tensor per network, in `nets` order.
    """
    predictions = []
    for net in nets:
        stacked_inputs = torch.cat((x, t), dim=1)
        predictions.append(net(stacked_inputs, head_idx=2).detach())
    return predictions

In [None]:
## show loss
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))

# (history, line keyword arguments) for each curve, drawn in this order
loss_curves = [
    (total_losses, {'label': 'Total Loss'}),
    (pde_losses, {'label': 'PDE Loss', 'linestyle': '--'}),
    (ic_losses, {'label': 'IC Loss', 'linestyle': ':'}),
    (bc_losses, {'label': 'BC Loss', 'linestyle': ':'}),
]
for history, line_kwargs in loss_curves:
    plt.semilogy(history, **line_kwargs)

plt.xlabel('Epoch')
plt.ylabel('Loss (log scale)')
plt.title('Training Loss Evolution')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
### solver and PINN head 3 (head_idx=2), Bz
### works
## 

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import torch

# Parameters
X, L = 1.0, 1.0   # spatial domain length, final time
nx = 300          # number of spatial grid points
mu = 0.1          # viscosity

# Grid
x = np.linspace(0, X, num=nx)
t_eval = np.linspace(0, L, num=nx)  # as many output times as grid points
dx = x[1] - x[0]                    # spacing between neighbouring grid points

# Initial conditions
def init_cond():
    """Build the initial state vector on the global grid `x`.

    Returns the eight fields concatenated in the order
    rho, vx, vy, vz, P, Bx, By, Bz.
    """
    density = 2*np.ones_like(x)
    quadratic = x*x           # shared by vx, vy and vz
    decay = np.exp(-x)        # shared by P, Bx, By and Bz
    fields = [density] + [quadratic] * 3 + [decay] * 4
    return np.concatenate(fields)

# Spatial derivatives via np.gradient with the global spacing `dx`
# (central differences in the interior, one-sided differences at both ends)
def d_dx(u):
    """Return the first spatial derivative of `u`."""
    first_derivative = np.gradient(u, dx)
    return first_derivative

def d2_dx2(u):
    """Return the second spatial derivative of `u` by applying np.gradient twice."""
    slope = np.gradient(u, dx)
    return np.gradient(slope, dx)

# ODE system function: dy/dt = ...
def mhd_rhs(t, y):
    """Right-hand side dy/dt of the semi-discretised 1D MHD system.

    Parameters
    ----------
    t : float
        Current time; only used for the boundary value exp(-t) written below.
    y : ndarray
        State vector made of eight equally sized blocks in the order
        rho, vx, vy, vz, P, Bx, By, Bz.

    Returns
    -------
    ndarray
        Time derivatives of the eight fields, concatenated in the same order.

    Notes
    -----
    No forcing term enters any of the equations, so the former (unused)
    ``forcing_terms(x, t)`` evaluation is no longer performed; this also
    removes the dependency on a function that this cell does not define.
    Relies on the globals ``mu``, ``d_dx`` and ``d2_dx2``.
    """
    rho, vx, vy, vz, P, Bx, By, Bz = np.split(y, 8)
    inv_rho = 1 / rho

    # Every first derivative is evaluated once and reused in the equations below.
    rho_x, vx_x, vy_x, vz_x = d_dx(rho), d_dx(vx), d_dx(vy), d_dx(vz)
    P_x, By_x, Bz_x = d_dx(P), d_dx(By), d_dx(Bz)

    drho_dt = -vx * rho_x - rho * vx_x
    dvx_dt = (-vx * vx_x - inv_rho * P_x
              - inv_rho * (By * By_x + Bz * Bz_x)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * vy_x + inv_rho * Bx * By_x
    dvz_dt = -vx * vz_x + inv_rho * Bx * Bz_x
    dP_dt = -P * vx_x - vx * P_x
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * By_x - By * vx_x + Bx * vy_x
    dBz_dt = -vx * Bz_x - Bz * vx_x + Bx * vz_x

    # Values imposed at the first grid point. np.split returns views, so these
    # writes modify the caller's `y` in place; they come after the derivatives
    # were computed and therefore do not change this call's return value.
    # NOTE(review): the boundary condition only takes effect if the integrator
    # keeps using the array it passed in -- confirm against solve_ivp.
    rho[0] = 2
    vx[0] = vy[0] = vz[0] = 0
    P[0] = Bx[0] = By[0] = Bz[0] = np.exp(-t)

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

# Solve the semi-discretised system
y0 = init_cond()
solv1 = solve_ivp(mhd_rhs, (0, L), y0, method='RK45', t_eval=t_eval)

# Solver field to compare. NOTE: despite the historical ``vx_*`` names, rows
# 7*nx:8*nx of the stacked state (rho, vx, vy, vz, P, Bx, By, Bz) hold Bz.
vx_num_v1 = solv1.y[7*nx:8*nx, :]  # shape (nx, nt)
T, Xg = np.meshgrid(solv1.t, x)

# Closed-form profile exp(-x)*exp(-t) and its gap to the solver (computed, not plotted)
vx_ex = np.exp(-Xg)*np.exp(-T)
err = vx_num_v1 - vx_ex

# Column tensors of the grid coordinates for the PINN
x_tensor = torch.tensor(Xg.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# PINN prediction with head_idx=2. Output 4 is Bz in the ordering used by the
# training cells (rho, v, P, By, Bz, vy, vz, Bx) -- assumes ``solutions``
# returns the outputs in that order; TODO confirm.
with torch.no_grad():
    # Cast every network to float32 if the parameters are stored in another dtype
    if next(nets[1].parameters()).dtype != torch.float32:
        for net in nets:
            net.float()
    vx_pinn_flat = solutions(x_tensor, t_tensor, head_idx=2)[4].cpu().numpy().flatten()

vx_pinn = vx_pinn_flat.reshape(Xg.shape)

# Named so that it does not overwrite neurodiffeq's ``diff`` operator, which
# the training cells rely on.
solver_pinn_diff = vx_num_v1 - vx_pinn
max_abs_diff = np.max(np.abs(solver_pinn_diff))

# Figure with three panels: solver, PINN, difference
plt.figure(figsize=(18, 6))

# Text sizes
fontsize = 14        # axis labels and ticks
cbar_fontsize = 14   # colorbar label and ticks
title_fontsize = 16  # panel titles

# (field, colormap, colorbar label, title, extra pcolormesh options)
panels = [
    (vx_num_v1, 'viridis', 'Bz', 'Solver solution', {}),
    (vx_pinn, 'viridis', 'Bz', 'PINN prediction (with transfer learning)', {}),
    # Symmetric colour range so that zero difference sits at the centre of 'bwr'
    (solver_pinn_diff, 'bwr', 'Difference', 'Solver-PINN difference',
     dict(vmin=-max_abs_diff, vmax=max_abs_diff)),
]

for panel_idx, (field, cmap, cbar_label, title, mesh_kwargs) in enumerate(panels, start=1):
    plt.subplot(1, 3, panel_idx)
    img = plt.pcolormesh(T, Xg, field, shading='auto', cmap=cmap, **mesh_kwargs)
    cbar = plt.colorbar(img)
    cbar.set_label(cbar_label, fontsize=cbar_fontsize)
    cbar.ax.tick_params(labelsize=cbar_fontsize)
    plt.xlabel('t', fontsize=fontsize)
    plt.ylabel('x', fontsize=fontsize)
    plt.title(title, fontsize=title_fontsize)
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)

plt.tight_layout()
plt.show()

In [None]:
## L-BFGS
# TL sur les CI (nouvelle tête head3 initialisée sur head2)
## without ff, with BC
## doesn't work

from neurodiffeq.operators import diff  # Import crucial pour les dérivées


class MultiHeadFCNN(nn.Module):
    """Fully connected network with a shared tanh trunk and three linear heads.

    Parameters
    ----------
    n_input_units : int
        Number of input features.
    n_output_units : int
        Number of outputs of each head.
    hidden_units : sequence of int
        Widths of the shared hidden layers; each one is followed by ``Tanh``.
        With two entries the parameter names (``shared_layers.0``,
        ``shared_layers.2``, ``head1``, ``head2``, ``head3``) are the ones used
        by the saved checkpoints.

    Raises
    ------
    ValueError
        If ``hidden_units`` is empty.
    """

    def __init__(self, n_input_units=2, n_output_units=1, hidden_units=(256, 256)):
        super().__init__()
        hidden_units = list(hidden_units)
        if not hidden_units:
            raise ValueError("hidden_units must contain at least one layer width")

        # Shared trunk: Linear -> Tanh for every requested width
        layers = []
        in_features = n_input_units
        for width in hidden_units:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.Tanh())
            in_features = width
        self.shared_layers = nn.Sequential(*layers)

        # Existing heads (head1, head2) plus the new head (head3) for the new
        # initial conditions
        self.head1 = nn.Linear(in_features, n_output_units)
        self.head2 = nn.Linear(in_features, n_output_units)
        self.head3 = nn.Linear(in_features, n_output_units)

    def forward(self, x, head_idx=2):
        """Run the trunk, then the selected head.

        ``head_idx`` 0 selects head1 and 1 selects head2; any other value
        (including the default 2) selects head3.
        """
        shared = self.shared_layers(x)
        if head_idx == 0:
            return self.head1(shared)
        if head_idx == 1:
            return self.head2(shared)
        return self.head3(shared)


## Fresh containers for the transfer-learning run
nets, total_losses, pde_losses, ic_losses = [], [], [], []
bc_losses = []  # history of the boundary-condition loss


## New initial conditions (profiles in x) for head3
def _ic_constant_density(x):
    return 2*torch.ones_like(x)


def _ic_quadratic(x):
    return x*x


def _ic_decaying_exp(x):
    return torch.exp(-x)


new_initial_conditions = {
    'rho': _ic_constant_density,
    'v': _ic_quadratic,
    'P': _ic_decaying_exp,
    'By': _ic_decaying_exp,
    'Bz': _ic_decaying_exp,
    'vy': _ic_quadratic,
    'vz': _ic_quadratic,
    'Bx': _ic_decaying_exp,
}

## Boundary conditions
def boundary_conditions(t):
    """Return the target values imposed at x = 0 for the times ``t``.

    The result is a dict of tensors shaped like ``t``: 'rho' is 2, the
    velocity components 'v', 'vy', 'vz' are 0 and 'P', 'Bx', 'By', 'Bz'
    are ``exp(-t)``.
    """
    targets = {'rho': 2.0 * torch.ones_like(t)}
    for name in ('v', 'vy', 'vz'):
        targets[name] = torch.zeros_like(t)
    for name in ('P', 'Bx', 'By', 'Bz'):
        targets[name] = torch.exp(-t)
    return targets

## Load the pre-trained weights (one checkpoint per network)
for i in range(8):
    net = MultiHeadFCNN(n_input_units=2, n_output_units=1, hidden_units=[256, 256])

    # Checkpoint written by an earlier training stage
    pretrained_dict = torch.load(f"saved_weights/net_{i}.pt", map_location='cpu')

    # Start from the freshly initialised weights and overwrite every entry
    # that is present in the checkpoint
    model_dict = net.state_dict()
    model_dict.update(pretrained_dict)

    # Initialise head3 with a copy of the checkpoint's head2 parameters
    head2_entries = [(name, param) for name, param in pretrained_dict.items()
                     if name.startswith("head2.")]
    for name, param in head2_entries:
        corresponding_name = name.replace("head2.", "head3.")
        if corresponding_name in model_dict:
            model_dict[corresponding_name] = param.clone()

    net.load_state_dict(model_dict, strict=False)
    nets.append(net)

## Only head3 stays trainable; every other parameter is frozen
for net in nets:
    for name, param in net.named_parameters():
        param.requires_grad_('head3' in name)

## Optimizer: a single L-BFGS instance over the trainable (head3) parameters of all networks
trainable_params = [p for net in nets for p in net.parameters() if p.requires_grad]
optimizer = optim.LBFGS(
    trainable_params,
    lr=0.01,              # step size
    max_iter=10,          # maximum inner iterations per optimizer.step call
    history_size=50,      # number of stored update pairs
    line_search_fn=None   # no line search
)

criterion = nn.MSELoss()
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.975)

## PDE system without forcing terms
def pde_systemTL(rho, vx, P, By, Bz, vy, vz, Bx, x, t):
    """Residuals of the eight unforced 1-D MHD equations.

    ``rho, vx, P, By, Bz, vy, vz, Bx`` are the network outputs evaluated at the
    points ``(x, t)``; all of them must stay attached to the autograd graph of
    ``x`` and ``t`` (no ``.detach()``). The viscosity ``mu`` is read from the
    enclosing namespace.

    Returns the list ``[eq1, ..., eq8]``: continuity, the three momentum
    equations, the pressure equation and the three induction equations.
    """
    # Every diff call is an autograd pass; the first spatial derivatives appear
    # in several equations, so evaluate each of them once and reuse it.
    rho_x = diff(rho, x)
    vx_x = diff(vx, x)
    vy_x = diff(vy, x)
    vz_x = diff(vz, x)
    P_x = diff(P, x)
    By_x = diff(By, x)
    Bz_x = diff(Bz, x)

    eq1 = diff(rho, t) + vx * rho_x + rho * vx_x
    eq2 = (rho * diff(vx, t) + rho * vx * vx_x + P_x + By * By_x + Bz * Bz_x
           - rho * mu * diff(vx, x, order=2))
    eq3 = rho * diff(vy, t) + rho * vx * vy_x - Bx * By_x
    eq4 = rho * diff(vz, t) + rho * vx * vz_x - Bx * Bz_x
    eq5 = diff(P, t) + P * vx_x + vx * P_x
    eq6 = diff(Bx, t)
    eq7 = diff(By, t) + vx * By_x + By * vx_x - Bx * vy_x
    eq8 = diff(Bz, t) + vx * Bz_x + Bz * vx_x - Bx * vz_x

    return [eq1, eq2, eq3, eq4, eq5, eq6, eq7, eq8]


## Training
n_epochs = 150
print_interval = max(1, n_epochs // 10)  # log about every 10% of the run; never 0

# Output order of `nets`, used to pair each network with its IC / BC target
FIELD_KEYS = ['rho', 'v', 'P', 'By', 'Bz', 'vy', 'vz', 'Bx']


def sample_batch():
    """Draw collocation points and boundary times from the generators.

    Returns ``(x_train, t_train, x_bc, t_bc)`` as column tensors. The
    collocation coordinates require grad so the PDE residuals can be
    differentiated with respect to them; ``x_bc`` is identically zero.
    """
    samples = train_gen.get_examples()
    x_train = samples[0].view(-1, 1).requires_grad_(True)
    t_train = samples[1].view(-1, 1).requires_grad_(True)

    bc_samples = bc_gen.get_examples()
    x_bc = torch.zeros_like(bc_samples[0]).view(-1, 1)  # boundary located at x = 0
    t_bc = bc_samples[1].view(-1, 1)
    return x_train, t_train, x_bc, t_bc


def compute_losses(x_train, t_train, x_bc, t_bc):
    """Return ``(loss_pde, loss_ic, loss_bc, residuals)`` for head3 of every network.

    The three losses are tensors attached to the autograd graph; ``residuals``
    is the list returned by ``pde_systemTL``.
    """
    # PDE residuals at the collocation points
    inputs = torch.cat((x_train, t_train), dim=1)
    outputs = [net(inputs, head_idx=2) for net in nets]
    residuals = pde_systemTL(*outputs, x_train, t_train)
    loss_pde = sum(criterion(res, torch.zeros_like(res)) for res in residuals)

    # Initial condition at the global points (ic_x, ic_t)
    ic_inputs = torch.cat((ic_x, ic_t), dim=1)
    ic_outputs = [net(ic_inputs, head_idx=2) for net in nets]
    ic_targets = [new_initial_conditions[key](ic_x) for key in FIELD_KEYS]
    loss_ic = sum(criterion(out, target) for out, target in zip(ic_outputs, ic_targets))

    # Boundary condition at x = 0
    bc_inputs = torch.cat((x_bc, t_bc), dim=1)
    bc_outputs = [net(bc_inputs, head_idx=2) for net in nets]
    bc_targets = boundary_conditions(t_bc)
    loss_bc = sum(criterion(out, bc_targets[key]) for out, key in zip(bc_outputs, FIELD_KEYS))

    return loss_pde, loss_ic, loss_bc, residuals


current_batch = None  # set once per epoch by the training loop


def closure():
    """L-BFGS closure: zero the gradients, evaluate the total loss on the
    current epoch's batch, backpropagate and return the loss."""
    optimizer.zero_grad()
    loss_pde, loss_ic, loss_bc, _ = compute_losses(*current_batch)
    total_loss = loss_pde + loss_ic + loss_bc
    total_loss.backward()
    return total_loss


for epoch in tqdm(range(n_epochs)):
    # L-BFGS may call the closure several times within one step (max_iter=10),
    # so the points are drawn once per epoch: every re-evaluation then sees
    # the same objective.
    current_batch = sample_batch()
    optimizer.step(closure)

    # Periodic logging
    if (epoch + 1) % print_interval == 0 or epoch == 0:
        # Evaluate on a freshly drawn batch. Autograd must stay enabled because
        # the PDE residuals are derivatives of the network outputs.
        eval_pde, eval_ic, eval_bc, pde_residuals = compute_losses(*sample_batch())
        loss_pde = eval_pde.item()
        loss_ic = eval_ic.item()
        loss_bc = eval_bc.item()
        total_loss = loss_pde + loss_ic + loss_bc

        print(f"Epoch {epoch+1}/{n_epochs} ({100*(epoch+1)/n_epochs:.0f}%) | "
              f"PDE: {loss_pde:.3e} | "
              f"IC: {loss_ic:.3e} | "
              f"BC: {loss_bc:.3e} | "
              f"Total: {total_loss:.3e}")

        print("Résidus:", [r.mean().item() for r in pde_residuals])

        # Record the logged losses (one entry per logged epoch)
        total_losses.append(total_loss)
        pde_losses.append(loss_pde)
        ic_losses.append(loss_ic)
        bc_losses.append(loss_bc)

    


# Evaluation helper
def solutions_new_ci(x, t):
    """Evaluate every network in the global ``nets`` list with ``head_idx=2``
    (routed to head3 by ``MultiHeadFCNN.forward``).

    ``x`` and ``t`` are concatenated along ``dim=1`` to form the network
    input. Returns a list with one detached output tensor per network, in the
    order of ``nets``.
    """
    predictions = []
    for net in nets:
        xt = torch.cat((x, t), dim=1)
        predictions.append(net(xt, head_idx=2).detach())
    return predictions

In [None]:
## L-BFGS method !
# TL on CI (new head : head3 initialized on head2)
## without ff, with BC

from neurodiffeq.operators import diff  # Import crucial pour les dérivées


class MultiHeadFCNN(nn.Module):
    """Fully connected network with a shared tanh trunk and three linear heads.

    Parameters
    ----------
    n_input_units : int
        Number of input features.
    n_output_units : int
        Number of outputs of each head.
    hidden_units : sequence of int
        Widths of the shared hidden layers; each one is followed by ``Tanh``.
        With two entries the parameter names (``shared_layers.0``,
        ``shared_layers.2``, ``head1``, ``head2``, ``head3``) are the ones used
        by the saved checkpoints.

    Raises
    ------
    ValueError
        If ``hidden_units`` is empty.
    """

    def __init__(self, n_input_units=2, n_output_units=1, hidden_units=(256, 256)):
        super().__init__()
        hidden_units = list(hidden_units)
        if not hidden_units:
            raise ValueError("hidden_units must contain at least one layer width")

        # Shared trunk: Linear -> Tanh for every requested width
        layers = []
        in_features = n_input_units
        for width in hidden_units:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.Tanh())
            in_features = width
        self.shared_layers = nn.Sequential(*layers)

        # Existing heads (head1, head2) plus the new head (head3) for the new
        # initial conditions
        self.head1 = nn.Linear(in_features, n_output_units)
        self.head2 = nn.Linear(in_features, n_output_units)
        self.head3 = nn.Linear(in_features, n_output_units)

    def forward(self, x, head_idx=2):
        """Run the trunk, then the selected head.

        ``head_idx`` 0 selects head1 and 1 selects head2; any other value
        (including the default 2) selects head3.
        """
        shared = self.shared_layers(x)
        if head_idx == 0:
            return self.head1(shared)
        if head_idx == 1:
            return self.head2(shared)
        return self.head3(shared)


## Fresh containers for the transfer-learning run
nets, total_losses, pde_losses, ic_losses = [], [], [], []
bc_losses = []  # history of the boundary-condition loss


## New initial conditions (profiles in x) for head3
def _ic_constant_density(x):
    return 2*torch.ones_like(x)


def _ic_quadratic(x):
    return x*x


def _ic_decaying_exp(x):
    return torch.exp(-x)


new_initial_conditions = {
    'rho': _ic_constant_density,
    'v': _ic_quadratic,
    'P': _ic_decaying_exp,
    'By': _ic_decaying_exp,
    'Bz': _ic_decaying_exp,
    'vy': _ic_quadratic,
    'vz': _ic_quadratic,
    'Bx': _ic_decaying_exp,
}

## Boundary conditions
def boundary_conditions(t):
    """Return the target values imposed at x = 0 for the times ``t``.

    The result is a dict of tensors shaped like ``t``: 'rho' is 2, the
    velocity components 'v', 'vy', 'vz' are 0 and 'P', 'Bx', 'By', 'Bz'
    are ``exp(-t)``.
    """
    targets = {'rho': 2.0 * torch.ones_like(t)}
    for name in ('v', 'vy', 'vz'):
        targets[name] = torch.zeros_like(t)
    for name in ('P', 'Bx', 'By', 'Bz'):
        targets[name] = torch.exp(-t)
    return targets

## Load the pre-trained weights (one checkpoint per network)
for i in range(8):
    net = MultiHeadFCNN(n_input_units=2, n_output_units=1, hidden_units=[256, 256])

    # Checkpoint written by an earlier training stage
    pretrained_dict = torch.load(f"saved_weights/net_{i}.pt", map_location='cpu')

    # Start from the freshly initialised weights and overwrite every entry
    # that is present in the checkpoint
    model_dict = net.state_dict()
    model_dict.update(pretrained_dict)

    # Initialise head3 with a copy of the checkpoint's head2 parameters
    head2_entries = [(name, param) for name, param in pretrained_dict.items()
                     if name.startswith("head2.")]
    for name, param in head2_entries:
        corresponding_name = name.replace("head2.", "head3.")
        if corresponding_name in model_dict:
            model_dict[corresponding_name] = param.clone()

    net.load_state_dict(model_dict, strict=False)
    nets.append(net)

## Only head3 stays trainable; every other parameter is frozen
for net in nets:
    for name, param in net.named_parameters():
        param.requires_grad_('head3' in name)

## Optimizer: a single L-BFGS instance over the trainable (head3) parameters of all networks
trainable_params = [p for net in nets for p in net.parameters() if p.requires_grad]
optimizer = optim.LBFGS(
    trainable_params,
    lr=0.01,              # step size
    max_iter=10,          # maximum inner iterations per optimizer.step call
    history_size=50,      # number of stored update pairs
    line_search_fn=None   # no line search
)

criterion = nn.MSELoss()
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.975)

## PDE system without forcing terms
def pde_systemTL(rho, vx, P, By, Bz, vy, vz, Bx, x, t):
    """Residuals of the eight unforced 1-D MHD equations.

    ``rho, vx, P, By, Bz, vy, vz, Bx`` are the network outputs evaluated at the
    points ``(x, t)``; all of them must stay attached to the autograd graph of
    ``x`` and ``t`` (no ``.detach()``). The viscosity ``mu`` is read from the
    enclosing namespace.

    Returns the list ``[eq1, ..., eq8]``: continuity, the three momentum
    equations, the pressure equation and the three induction equations.
    """
    # Every diff call is an autograd pass; the first spatial derivatives appear
    # in several equations, so evaluate each of them once and reuse it.
    rho_x = diff(rho, x)
    vx_x = diff(vx, x)
    vy_x = diff(vy, x)
    vz_x = diff(vz, x)
    P_x = diff(P, x)
    By_x = diff(By, x)
    Bz_x = diff(Bz, x)

    eq1 = diff(rho, t) + vx * rho_x + rho * vx_x
    eq2 = (rho * diff(vx, t) + rho * vx * vx_x + P_x + By * By_x + Bz * Bz_x
           - rho * mu * diff(vx, x, order=2))
    eq3 = rho * diff(vy, t) + rho * vx * vy_x - Bx * By_x
    eq4 = rho * diff(vz, t) + rho * vx * vz_x - Bx * Bz_x
    eq5 = diff(P, t) + P * vx_x + vx * P_x
    eq6 = diff(Bx, t)
    eq7 = diff(By, t) + vx * By_x + By * vx_x - Bx * vy_x
    eq8 = diff(Bz, t) + vx * Bz_x + Bz * vx_x - Bx * vz_x

    return [eq1, eq2, eq3, eq4, eq5, eq6, eq7, eq8]


## Training
n_epochs = 150
print_interval = max(1, n_epochs // 10)  # log about every 10% of the run; never 0

# Output order of `nets`, used to pair each network with its IC / BC target
FIELD_KEYS = ['rho', 'v', 'P', 'By', 'Bz', 'vy', 'vz', 'Bx']


def sample_batch():
    """Draw collocation points and boundary times from the generators.

    Returns ``(x_train, t_train, x_bc, t_bc)`` as column tensors. The
    collocation coordinates require grad so the PDE residuals can be
    differentiated with respect to them; ``x_bc`` is identically zero.
    """
    samples = train_gen.get_examples()
    x_train = samples[0].view(-1, 1).requires_grad_(True)
    t_train = samples[1].view(-1, 1).requires_grad_(True)

    bc_samples = bc_gen.get_examples()
    x_bc = torch.zeros_like(bc_samples[0]).view(-1, 1)  # boundary located at x = 0
    t_bc = bc_samples[1].view(-1, 1)
    return x_train, t_train, x_bc, t_bc


def compute_losses(x_train, t_train, x_bc, t_bc):
    """Return ``(loss_pde, loss_ic, loss_bc, residuals)`` for head3 of every network.

    The three losses are tensors attached to the autograd graph; ``residuals``
    is the list returned by ``pde_systemTL``.
    """
    # PDE residuals at the collocation points
    inputs = torch.cat((x_train, t_train), dim=1)
    outputs = [net(inputs, head_idx=2) for net in nets]
    residuals = pde_systemTL(*outputs, x_train, t_train)
    loss_pde = sum(criterion(res, torch.zeros_like(res)) for res in residuals)

    # Initial condition at the global points (ic_x, ic_t)
    ic_inputs = torch.cat((ic_x, ic_t), dim=1)
    ic_outputs = [net(ic_inputs, head_idx=2) for net in nets]
    ic_targets = [new_initial_conditions[key](ic_x) for key in FIELD_KEYS]
    loss_ic = sum(criterion(out, target) for out, target in zip(ic_outputs, ic_targets))

    # Boundary condition at x = 0
    bc_inputs = torch.cat((x_bc, t_bc), dim=1)
    bc_outputs = [net(bc_inputs, head_idx=2) for net in nets]
    bc_targets = boundary_conditions(t_bc)
    loss_bc = sum(criterion(out, bc_targets[key]) for out, key in zip(bc_outputs, FIELD_KEYS))

    return loss_pde, loss_ic, loss_bc, residuals


current_batch = None  # set once per epoch by the training loop


def closure():
    """L-BFGS closure: zero the gradients, evaluate the total loss on the
    current epoch's batch, backpropagate and return the loss."""
    optimizer.zero_grad()
    loss_pde, loss_ic, loss_bc, _ = compute_losses(*current_batch)
    total_loss = loss_pde + loss_ic + loss_bc
    total_loss.backward()
    return total_loss


for epoch in tqdm(range(n_epochs)):
    # L-BFGS may call the closure several times within one step (max_iter=10),
    # so the points are drawn once per epoch: every re-evaluation then sees
    # the same objective.
    current_batch = sample_batch()
    optimizer.step(closure)

    # Periodic logging
    if (epoch + 1) % print_interval == 0 or epoch == 0:
        # Evaluate on a freshly drawn batch. Autograd must stay enabled because
        # the PDE residuals are derivatives of the network outputs.
        eval_pde, eval_ic, eval_bc, pde_residuals = compute_losses(*sample_batch())
        loss_pde = eval_pde.item()
        loss_ic = eval_ic.item()
        loss_bc = eval_bc.item()
        total_loss = loss_pde + loss_ic + loss_bc

        print(f"Epoch {epoch+1}/{n_epochs} ({100*(epoch+1)/n_epochs:.0f}%) | "
              f"PDE: {loss_pde:.3e} | "
              f"IC: {loss_ic:.3e} | "
              f"BC: {loss_bc:.3e} | "
              f"Total: {total_loss:.3e}")

        print("Résidus:", [r.mean().item() for r in pde_residuals])

        # Record the logged losses (one entry per logged epoch)
        total_losses.append(total_loss)
        pde_losses.append(loss_pde)
        ic_losses.append(loss_ic)
        bc_losses.append(loss_bc)

    


# Evaluation helper
def solutions_new_ci(x, t):
    """Evaluate every network in the global ``nets`` list with ``head_idx=2``
    (routed to head3 by ``MultiHeadFCNN.forward``).

    ``x`` and ``t`` are concatenated along ``dim=1`` to form the network
    input. Returns a list with one detached output tensor per network, in the
    order of ``nets``.
    """
    predictions = []
    for net in nets:
        xt = torch.cat((x, t), dim=1)
        predictions.append(net(xt, head_idx=2).detach())
    return predictions

In [None]:
# TL on the CI (new head : head3 initialized on head2)
## without ff, with BC
## works !
## L-BFGS method

from neurodiffeq.operators import diff  # Import crucial pour les dérivées


class MultiHeadFCNN(nn.Module):
    """Fully connected network with a shared tanh trunk and three linear heads.

    Parameters
    ----------
    n_input_units : int
        Number of input features.
    n_output_units : int
        Number of outputs of each head.
    hidden_units : sequence of int
        Widths of the shared hidden layers; each one is followed by ``Tanh``.
        With two entries the parameter names (``shared_layers.0``,
        ``shared_layers.2``, ``head1``, ``head2``, ``head3``) are the ones used
        by the saved checkpoints.

    Raises
    ------
    ValueError
        If ``hidden_units`` is empty.
    """

    def __init__(self, n_input_units=2, n_output_units=1, hidden_units=(256, 256)):
        super().__init__()
        hidden_units = list(hidden_units)
        if not hidden_units:
            raise ValueError("hidden_units must contain at least one layer width")

        # Shared trunk: Linear -> Tanh for every requested width
        layers = []
        in_features = n_input_units
        for width in hidden_units:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.Tanh())
            in_features = width
        self.shared_layers = nn.Sequential(*layers)

        # Existing heads (head1, head2) plus the new head (head3) for the new
        # initial conditions
        self.head1 = nn.Linear(in_features, n_output_units)
        self.head2 = nn.Linear(in_features, n_output_units)
        self.head3 = nn.Linear(in_features, n_output_units)

    def forward(self, x, head_idx=2):
        """Run the trunk, then the selected head.

        ``head_idx`` 0 selects head1 and 1 selects head2; any other value
        (including the default 2) selects head3.
        """
        shared = self.shared_layers(x)
        if head_idx == 0:
            return self.head1(shared)
        if head_idx == 1:
            return self.head2(shared)
        return self.head3(shared)


## Fresh containers for the transfer-learning run
nets, total_losses, pde_losses, ic_losses = [], [], [], []
bc_losses = []  # history of the boundary-condition loss


## New initial conditions (profiles in x) for head3
def _ic_constant_density(x):
    return 2*torch.ones_like(x)


def _ic_quadratic(x):
    return x*x


def _ic_decaying_exp(x):
    return torch.exp(-x)


new_initial_conditions = {
    'rho': _ic_constant_density,
    'v': _ic_quadratic,
    'P': _ic_decaying_exp,
    'By': _ic_decaying_exp,
    'Bz': _ic_decaying_exp,
    'vy': _ic_quadratic,
    'vz': _ic_quadratic,
    'Bx': _ic_decaying_exp,
}

## Boundary conditions
def boundary_conditions(t):
    """Return the target values imposed at x = 0 for the times ``t``.

    The result is a dict of tensors shaped like ``t``: 'rho' is 2, the
    velocity components 'v', 'vy', 'vz' are 0 and 'P', 'Bx', 'By', 'Bz'
    are ``exp(-t)``.
    """
    targets = {'rho': 2.0 * torch.ones_like(t)}
    for name in ('v', 'vy', 'vz'):
        targets[name] = torch.zeros_like(t)
    for name in ('P', 'Bx', 'By', 'Bz'):
        targets[name] = torch.exp(-t)
    return targets

## Load the pre-trained weights (one checkpoint per network)
for i in range(8):
    net = MultiHeadFCNN(n_input_units=2, n_output_units=1, hidden_units=[256, 256])

    # Checkpoint written by an earlier training stage
    pretrained_dict = torch.load(f"saved_weights/net_{i}.pt", map_location='cpu')

    # Start from the freshly initialised weights and overwrite every entry
    # that is present in the checkpoint
    model_dict = net.state_dict()
    model_dict.update(pretrained_dict)

    # Initialise head3 with a copy of the checkpoint's head2 parameters
    head2_entries = [(name, param) for name, param in pretrained_dict.items()
                     if name.startswith("head2.")]
    for name, param in head2_entries:
        corresponding_name = name.replace("head2.", "head3.")
        if corresponding_name in model_dict:
            model_dict[corresponding_name] = param.clone()

    net.load_state_dict(model_dict, strict=False)
    nets.append(net)

## Only head3 stays trainable; every other parameter is frozen
for net in nets:
    for name, param in net.named_parameters():
        param.requires_grad_('head3' in name)

## Optimizer: a single L-BFGS instance over the trainable (head3) parameters of all networks
trainable_params = [p for net in nets for p in net.parameters() if p.requires_grad]
optimizer = optim.LBFGS(
    trainable_params,
    lr=0.01,                        # step size
    max_iter=10,                    # maximum inner iterations per optimizer.step call
    history_size=50,                # number of stored update pairs
    line_search_fn='strong_wolfe'   # strong-Wolfe line search; use None to disable it
)

criterion = nn.MSELoss()
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.975)

## PDE system without forcing terms
def pde_systemTL(rho, vx, P, By, Bz, vy, vz, Bx, x, t):
    """Residuals of the eight unforced 1-D MHD equations.

    ``rho, vx, P, By, Bz, vy, vz, Bx`` are the network outputs evaluated at the
    points ``(x, t)``; all of them must stay attached to the autograd graph of
    ``x`` and ``t`` (no ``.detach()``). The viscosity ``mu`` is read from the
    enclosing namespace.

    Returns the list ``[eq1, ..., eq8]``: continuity, the three momentum
    equations, the pressure equation and the three induction equations.
    """
    # Every diff call is an autograd pass; the first spatial derivatives appear
    # in several equations, so evaluate each of them once and reuse it.
    rho_x = diff(rho, x)
    vx_x = diff(vx, x)
    vy_x = diff(vy, x)
    vz_x = diff(vz, x)
    P_x = diff(P, x)
    By_x = diff(By, x)
    Bz_x = diff(Bz, x)

    eq1 = diff(rho, t) + vx * rho_x + rho * vx_x
    eq2 = (rho * diff(vx, t) + rho * vx * vx_x + P_x + By * By_x + Bz * Bz_x
           - rho * mu * diff(vx, x, order=2))
    eq3 = rho * diff(vy, t) + rho * vx * vy_x - Bx * By_x
    eq4 = rho * diff(vz, t) + rho * vx * vz_x - Bx * Bz_x
    eq5 = diff(P, t) + P * vx_x + vx * P_x
    eq6 = diff(Bx, t)
    eq7 = diff(By, t) + vx * By_x + By * vx_x - Bx * vy_x
    eq8 = diff(Bz, t) + vx * Bz_x + Bz * vx_x - Bx * vz_x

    return [eq1, eq2, eq3, eq4, eq5, eq6, eq7, eq8]


## Training
n_epochs = 150
print_interval = max(1, n_epochs // 10)  # log about every 10% of the run; never 0

# Output order of `nets`, used to pair each network with its IC / BC target
FIELD_KEYS = ['rho', 'v', 'P', 'By', 'Bz', 'vy', 'vz', 'Bx']


def sample_batch():
    """Draw collocation points and boundary times from the generators.

    Returns ``(x_train, t_train, x_bc, t_bc)`` as column tensors. The
    collocation coordinates require grad so the PDE residuals can be
    differentiated with respect to them; ``x_bc`` is identically zero.
    """
    samples = train_gen.get_examples()
    x_train = samples[0].view(-1, 1).requires_grad_(True)
    t_train = samples[1].view(-1, 1).requires_grad_(True)

    bc_samples = bc_gen.get_examples()
    x_bc = torch.zeros_like(bc_samples[0]).view(-1, 1)  # boundary located at x = 0
    t_bc = bc_samples[1].view(-1, 1)
    return x_train, t_train, x_bc, t_bc


def compute_losses(x_train, t_train, x_bc, t_bc):
    """Return ``(loss_pde, loss_ic, loss_bc, residuals)`` for head3 of every network.

    The three losses are tensors attached to the autograd graph; ``residuals``
    is the list returned by ``pde_systemTL``.
    """
    # PDE residuals at the collocation points
    inputs = torch.cat((x_train, t_train), dim=1)
    outputs = [net(inputs, head_idx=2) for net in nets]
    residuals = pde_systemTL(*outputs, x_train, t_train)
    loss_pde = sum(criterion(res, torch.zeros_like(res)) for res in residuals)

    # Initial condition at the global points (ic_x, ic_t)
    ic_inputs = torch.cat((ic_x, ic_t), dim=1)
    ic_outputs = [net(ic_inputs, head_idx=2) for net in nets]
    ic_targets = [new_initial_conditions[key](ic_x) for key in FIELD_KEYS]
    loss_ic = sum(criterion(out, target) for out, target in zip(ic_outputs, ic_targets))

    # Boundary condition at x = 0
    bc_inputs = torch.cat((x_bc, t_bc), dim=1)
    bc_outputs = [net(bc_inputs, head_idx=2) for net in nets]
    bc_targets = boundary_conditions(t_bc)
    loss_bc = sum(criterion(out, bc_targets[key]) for out, key in zip(bc_outputs, FIELD_KEYS))

    return loss_pde, loss_ic, loss_bc, residuals


current_batch = None  # set once per epoch by the training loop


def closure():
    """L-BFGS closure: zero the gradients, evaluate the total loss on the
    current epoch's batch, backpropagate and return the loss."""
    optimizer.zero_grad()
    loss_pde, loss_ic, loss_bc, _ = compute_losses(*current_batch)
    total_loss = loss_pde + loss_ic + loss_bc
    total_loss.backward()
    return total_loss


for epoch in tqdm(range(n_epochs)):
    # L-BFGS with a line search calls the closure several times within one
    # step, so the points are drawn once per epoch: every re-evaluation then
    # sees the same objective.
    current_batch = sample_batch()
    optimizer.step(closure)

    # Periodic logging
    if (epoch + 1) % print_interval == 0 or epoch == 0:
        # Evaluate on a freshly drawn batch. Autograd must stay enabled because
        # the PDE residuals are derivatives of the network outputs.
        eval_pde, eval_ic, eval_bc, pde_residuals = compute_losses(*sample_batch())
        loss_pde = eval_pde.item()
        loss_ic = eval_ic.item()
        loss_bc = eval_bc.item()
        total_loss = loss_pde + loss_ic + loss_bc

        print(f"Epoch {epoch+1}/{n_epochs} ({100*(epoch+1)/n_epochs:.0f}%) | "
              f"PDE: {loss_pde:.3e} | "
              f"IC: {loss_ic:.3e} | "
              f"BC: {loss_bc:.3e} | "
              f"Total: {total_loss:.3e}")

        print("Résidus:", [r.mean().item() for r in pde_residuals])

        # Record the logged losses (one entry per logged epoch)
        total_losses.append(total_loss)
        pde_losses.append(loss_pde)
        ic_losses.append(loss_ic)
        bc_losses.append(loss_bc)

    


# Evaluation helper
def solutions_new_ci(x, t):
    """Evaluate every network in the global ``nets`` list with ``head_idx=2``
    (routed to head3 by ``MultiHeadFCNN.forward``).

    ``x`` and ``t`` are concatenated along ``dim=1`` to form the network
    input. Returns a list with one detached output tensor per network, in the
    order of ``nets``.
    """
    predictions = []
    for net in nets:
        xt = torch.cat((x, t), dim=1)
        predictions.append(net(xt, head_idx=2).detach())
    return predictions

In [None]:
## loss correct

import matplotlib.pyplot as plt
import numpy as np

plt.figure(figsize=(10, 6))

# The L-BFGS training loop records the losses at epoch 1 and then at every
# multiple of print_interval, so rebuild those epoch numbers for the x-axis
# instead of spreading the points evenly.
log_interval = max(1, print_interval)
recorded_epochs = [e for e in range(1, n_epochs + 1)
                   if e == 1 or e % log_interval == 0]
if len(recorded_epochs) == len(total_losses):
    epochs = np.array(recorded_epochs)
else:
    # Histories recorded on another schedule: fall back to an even spread
    epochs = np.linspace(1, n_epochs, len(total_losses))

plt.semilogy(epochs, total_losses, label='Total Loss')
plt.semilogy(epochs, pde_losses, label='PDE Loss', linestyle='--')
plt.semilogy(epochs, ic_losses, label='IC Loss', linestyle=':')
plt.semilogy(epochs, bc_losses, label='BC Loss', linestyle=':')

plt.xlabel('Epoch')
plt.ylabel('Loss (log scale)')
plt.title('Training Loss Evolution')
plt.legend()
plt.grid(True, which="both", ls="-")
plt.xticks(np.arange(0, n_epochs + 1, 25))  # one tick every 25 epochs
plt.show()

In [None]:
### solveur and PINN head 1
### works
## L-BFGS method
## output must be adapted to your need

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import torch

# Parameters
mu = 0.1      # viscosity
L = 1.0       # final time
X = 1.0       # length of the spatial domain
nx = 300      # number of spatial grid points (also used as the number of output times)

# Grid
x = np.linspace(0, X, nx)
t_eval = np.linspace(0, L, nx)
dx = x[1] - x[0]

# Initial conditions
def init_cond():
    """Build the initial state vector on the global grid ``x``.

    Returns a 1-D array made of eight consecutive segments of ``len(x)``
    values, ordered rho, vx, vy, vz, P, Bx, By, Bz: rho is 2 everywhere, the
    three velocity components equal ``x**2`` and P, Bx, By, Bz equal ``exp(-x)``.
    """
    velocity = x*x
    decaying = np.exp(-x)
    fields = [2*np.ones_like(x)]   # rho
    fields += [velocity] * 3       # vx, vy, vz
    fields += [decaying] * 4       # P, Bx, By, Bz
    return np.concatenate(fields)

# Spatial derivatives via np.gradient (central differences in the interior,
# one-sided differences at the two end points)
def d_dx(u):
    """Return the first derivative of ``u`` on the uniform grid of spacing ``dx``."""
    spacing = dx
    return np.gradient(u, spacing)

def d2_dx2(u):
    """Return the second derivative of ``u`` by applying ``np.gradient`` twice with spacing ``dx``."""
    first_derivative = np.gradient(u, dx)
    second_derivative = np.gradient(first_derivative, dx)
    return second_derivative

# Right-hand side of the semi-discretised system: dy/dt = mhd_rhs(t, y)
def mhd_rhs(t, y):
    """Time derivative of the stacked MHD state for ``solve_ivp`` (no forcing terms).

    Parameters
    ----------
    t : float
        Current time.
    y : ndarray
        State vector made of eight equal segments ordered
        rho, vx, vy, vz, P, Bx, By, Bz. It is not modified.

    Returns
    -------
    ndarray
        Concatenated time derivatives, in the same order as ``y``.

    Notes
    -----
    The data imposed at the first grid point is rho = 2, vx = vy = vz = 0 and
    P = Bx = By = Bz = exp(-t). It is applied to private copies of the fields
    before the spatial derivatives are taken, and the time derivatives at that
    point are set to the time derivative of the boundary data, so the
    integrator keeps the state on the boundary values without its array being
    written to from inside this function.
    """
    # np.split yields views of y; copy so the integrator's array stays untouched.
    rho, vx, vy, vz, P, Bx, By, Bz = (field.copy() for field in np.split(y, 8))

    # Impose the boundary data at the first grid point before differentiating.
    bc_value = np.exp(-t)
    rho[0] = 2
    vx[0] = vy[0] = vz[0] = 0
    P[0] = Bx[0] = By[0] = Bz[0] = bc_value

    # Derivatives that appear in several equations are computed once.
    vx_x = d_dx(vx)
    P_x = d_dx(P)
    By_x = d_dx(By)
    Bz_x = d_dx(Bz)

    drho_dt = -vx * d_dx(rho) - rho * vx_x
    dvx_dt = (-vx * vx_x - (1/rho) * P_x
              - (1/rho)*(By * By_x + Bz * Bz_x)
              + mu * d2_dx2(vx))
    dvy_dt = -vx * d_dx(vy) + (1/rho)*Bx*By_x
    dvz_dt = -vx * d_dx(vz) + (1/rho)*Bx*Bz_x
    dP_dt = -P * vx_x - vx * P_x
    dBx_dt = np.zeros_like(Bx)
    dBy_dt = -vx * By_x - By * vx_x + Bx * d_dx(vy)
    dBz_dt = -vx * Bz_x - Bz * vx_x + Bx * d_dx(vz)

    # Pin the boundary rates to d/dt of the boundary data:
    # constants -> 0, exp(-t) -> -exp(-t).
    drho_dt[0] = 0.0
    dvx_dt[0] = dvy_dt[0] = dvz_dt[0] = 0.0
    dP_dt[0] = dBx_dt[0] = dBy_dt[0] = dBz_dt[0] = -bc_value

    return np.concatenate([drho_dt, dvx_dt, dvy_dt, dvz_dt, dP_dt, dBx_dt, dBy_dt, dBz_dt])

# Integrate the semi-discretised system
y0 = init_cond()
solv1 = solve_ivp(mhd_rhs, (0, L), y0, method='RK45', t_eval=t_eval)
if not solv1.success:
    # solve_ivp reports failure through the result object instead of raising;
    # solv1.t / solv1.y then only cover the times that were actually reached.
    import warnings
    warnings.warn(f"solve_ivp stopped before t = {L}: {solv1.message}", RuntimeWarning)

# Numerical field used for the comparison: rows 6*nx:7*nx of the stacked
# state, i.e. block 6 (By) in the [rho, vx, vy, vz, P, Bx, By, Bz] ordering of
# init_cond / mhd_rhs (0:nx would be rho). The `vx_` prefix of the variable
# names is kept so that other cells using these names keep working.
vx_num_v1 = solv1.y[6*nx:7*nx, :]  # shape (nx, nt)
T, Xg = np.meshgrid(solv1.t, x)

# Closed-form field exp(-x) * exp(-t) and its deviation from the solver
# (computed but not plotted).
# NOTE(review): labelled "exact solution" by the author; confirm that it
# applies to the initial conditions used in this cell.
vx_ex = np.exp(-Xg)*np.exp(-T)
err = vx_num_v1 - vx_ex

# Flatten the (x, t) grid into float32 column tensors for the PINN
x_tensor = torch.tensor(Xg.flatten(), dtype=torch.float32).view(-1, 1)
t_tensor = torch.tensor(T.flatten(), dtype=torch.float32).view(-1, 1)

# PINN prediction: head_idx=2, component [3] of what `solutions` returns.
# NOTE(review): the solver slice above is state block 6 while the PINN
# component is index 3; confirm both refer to the same physical field.
with torch.no_grad():
    # Cast every network to float32 if needed so it matches the input tensors
    if next(nets[1].parameters()).dtype != torch.float32:
        for net in nets:
            net.float()
    vx_pinn_flat = solutions(x_tensor, t_tensor, head_idx=2)[3].cpu().numpy().flatten()

# Back to the (nx, nt) layout of the solver output
vx_pinn = vx_pinn_flat.reshape(Xg.shape)


# Figure layout: three side-by-side panels (solver, PINN, difference)
fig_cmp, axes_cmp = plt.subplots(1, 3, figsize=(18, 6))

# Text sizes
fontsize = 14        # axis labels and tick labels
cbar_fontsize = 14   # colorbar label and colorbar ticks
title_fontsize = 16  # panel titles


def _plot_field_panel(ax, field, cbar_label, title, **mesh_kwargs):
    """Draw `field` on the (T, Xg) grid in `ax`, with a labelled colorbar.

    Extra keyword arguments (cmap, vmin, vmax, ...) are passed on to
    `Axes.pcolormesh`. Reads the notebook-level grids `T`, `Xg` and the three
    font sizes defined just above.
    """
    mesh = ax.pcolormesh(T, Xg, field, shading='auto', **mesh_kwargs)
    cbar = ax.figure.colorbar(mesh, ax=ax)
    cbar.set_label(cbar_label, fontsize=cbar_fontsize)
    cbar.ax.tick_params(labelsize=cbar_fontsize)
    ax.set_xlabel('t', fontsize=fontsize)
    ax.set_ylabel('x', fontsize=fontsize)
    ax.set_title(title, fontsize=title_fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)


# Solver-minus-PINN field. Deliberately NOT called `diff`: that name is the
# neurodiffeq differentiation operator imported at the top of the notebook,
# and rebinding it to an array would break any cell that calls diff(...) after
# this one has run.
solver_pinn_diff = vx_num_v1 - vx_pinn
diff_limit = np.max(np.abs(solver_pinn_diff))  # symmetric colour range around 0

# Panel 1: numerical solution
_plot_field_panel(axes_cmp[0], vx_num_v1, 'By', 'Solver solution', cmap='viridis')

# Panel 2: PINN prediction
_plot_field_panel(axes_cmp[1], vx_pinn, 'By',
                  'PINN prediction(transfer learning using L-BFGS)', cmap='viridis')

# Panel 3: solver - PINN difference
_plot_field_panel(axes_cmp[2], solver_pinn_diff, 'Difference', 'Solver-PINN difference',
                  cmap='bwr', vmin=-diff_limit, vmax=diff_limit)

fig_cmp.tight_layout()
plt.show()