In [50]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch.optim as optim
from torch.autograd import Variable
from typing import Dict, Any, Tuple, Union, NamedTuple


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



torch.manual_seed(995) 

<torch._C.Generator at 0x71920c3425d0>

In [51]:
# Physical constants
F = 96485 # Faraday constant [C/mol]
R = 8.3145 # Gas constant [J/(mol·K)]
T = 603 # Temperature [K] - from Seyeux experiments
N_V = 6.022e23 # ovogadro's number [mol⁻¹]

# Diffusion coefficients [m²/s]
D_vo = 5.30e-23 # Oxygen vacancy diffusion coefficient
D_MCr = 5.0e-24 # Cation vacancy diffusion coefficient
D_ICr = 1.0e-20 # Cation interstitial diffusion coefficient

# Species charges (inferred from defect chemistry)
z_ov = +2 # Oxygen vacancy charge
z_MCr = -3 # Cation vacancy charge (V_Cr''')
z_ICr = +3 # Cation interstitial charge (Cr_i•••)

# Initial potentials [V]
phi_f_i = 0.1 # Initial film potential
phi_mf_0 = 0.001 # Metal/film interface potential  
phi_fs_0 = 0.3 # Initial film/solution interface potential

# Applied potential and distribution
delta_V = 0.01 # Applied potential change [V]
alpha = 0.5 # Potential distribution parameter
x_d = 5.0e-8 # Decay length [m]

# Alloy composition
chi_Cr = 0.32 # Chromium mole fraction
chi_Fe = 0.10 # Iron mole fraction
chi_Ni = 0.58 # Nickel mole fraction

# Gibbs free energies [J/mol] - for all reactions
delta_G1 = -15000 # Oxygen vacancy formation at metal/film
delta_G8 = -100000 # Oxygen vacancy reaction at film/solution
delta_G2 = -90000 # Cation vacancy formation
delta_G4 = -30000 # Cation vacancy dissolution
delta_G6 = -85000 # Cation interstitial formation
delta_G9 = 100000 # Related to cation transport
delta_G11 = 30000 # Additional reaction energies
delta_G13 = 10000
delta_G10 = -1000
delta_G12 = -6000
delta_G14 = -3000
delta_G3 = 85000
delta_G5 = 10000
delta_G7 = 10000

# Dissolution parameters (if including dissolution)
n = 3 # Reaction order
k_0 = 8.40e16 # Pre-exponential factor [m⁻²s⁻¹]
m = 1 # pH dependence order
E_a = 54000 # Activation energy [J/mol]

# Molar volume
Omega = 1.4e-5 # [m³/mol]

# Solution properties
pH = 7.2 # From Seyeux Alloy 690 experiments
c_H = 10**(-pH) * 1000 # Proton concentration [mol/m³]


#Applied Potential [V]
E_min = 0.0
E_max = 1.8

In [52]:
#Define networks

class Swish(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self,x):
        return torch.sigmoid(x)*x
    

class FFN(nn.Module):
    """
    Fully Connected Feed Forward Neural Network.
    Args:
        input_dim: Number of input features
        output_dim: Number of output features  
        hidden_layers: Number of hidden layers
        layer_size: Size of each hidden layer
        activation: Activation function name ('swish', 'swoosh', 'swash', 'squash_swish', 'relu', 'tanh')
        initialize_weights: Whether to apply Xovier initialization
    """
    def __init__(
        self,
        input_dim: int = 2,
        output_dim: int = 1,
        hidden_layers: int = 5,
        layer_size: int = 20,
        activation: str = "swish",
        initialize_weights: bool = False
    ):
        super(FFN, self).__init__()
        
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = hidden_layers
        self.layer_size = layer_size
        self.activation = Swish()
        
        # Input layer
        self.input_layer = nn.Linear(input_dim, self.layer_size)
        
        # Hidden layers
        self.hidden_layers = nn.ModuleList([
            nn.Linear(self.layer_size, self.layer_size)
            for _ in range(self.num_layers)  
        ])
        
        # Output layer
        self.output_layer = nn.Linear(self.layer_size, output_dim)
        
        # Initialize weights
        if initialize_weights:
            self.initialize_weights()
    
    def initialize_weights(self):
        """Apply Xovier initialization to all linear layers"""
        # Initialize input layer
        nn.init.xavier_normal_(self.input_layer.weight)
        nn.init.zeros_(self.input_layer.bias)
        
        # Initialize hidden layers
        for layer in self.hidden_layers:
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)
        
        # Initialize output layer
        nn.init.xavier_normal_(self.output_layer.weight)
        nn.init.zeros_(self.output_layer.bias)
    
    def forward(self, x):
        x = self.activation(self.input_layer(x))
        
        for layer in self.hidden_layers: 
            x = self.activation(layer(x))
        
        return self.output_layer(x)


class ResidualBlock(nn.Module):
    """Single residual block: x + F(x)"""
    def __init__(self, layer_size, activation):
        super(ResidualBlock, self).__init__()
        self.layer_size = layer_size
        self.activation = activation
        
        # Two layers in each residual block
        self.linear1 = nn.Linear(layer_size, layer_size)
        self.linear2 = nn.Linear(layer_size, layer_size)
        
    def initialize_weights(self):
        nn.init.xavier_normal_(self.linear1.weight)
        nn.init.zeros_(self.linear1.bias)
        nn.init.xavier_normal_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)
    
    def forward(self, x):
        identity = x  # Save input for residual connection
        
        # F(x) computation
        out = self.activation(self.linear1(x))
        out = self.linear2(out)  # No activation on final layer of block
        
        # Residual connection: x + F(x)
        out = out + identity
        
        # Activation after residual connection
        out = self.activation(out)
        
        return out
    
class ResidualFFN(nn.Module):
    def __init__(self, input_dim=3, output_dim=1, num_layers=8, layer_size=50, initialize_weights=False):
        super(ResidualFFN, self).__init__()
        self.layer_size = layer_size
        self.num_layers = num_layers
        self.activation = Swish()
        
        # Input projection to get to residual dimension
        self.input_layer = nn.Linear(input_dim, self.layer_size)
        
        # Residual blocks
        self.residual_layers = nn.ModuleList([
            ResidualBlock(self.layer_size, self.activation)
            for _ in range(self.num_layers)  
        ])
        
        # Output layer
        self.output_layer = nn.Linear(self.layer_size, output_dim)
        
        if initialize_weights:
            self.initialize_weights()
    
    def initialize_weights(self):
        """Apply Xavier initialization to all linear layers"""
        nn.init.xavier_normal_(self.input_layer.weight)
        nn.init.zeros_(self.input_layer.bias)
        
        for block in self.residual_layers:
            block.initialize_weights()
        
        nn.init.xavier_normal_(self.output_layer.weight)
        nn.init.zeros_(self.output_layer.bias)
    
    def forward(self, x):
        # Input projection
        x = self.activation(self.input_layer(x))
        
        # Residual blocks
        for residual_layer in self.residual_layers:
            x = residual_layer(x)
        
        # Output
        return self.output_layer(x)
    

ov_net = ResidualFFN(input_dim=2, output_dim=1, num_layers=3, layer_size=20)
cv_net = ResidualFFN(input_dim=2, output_dim=1, num_layers=3, layer_size=20)
cir_net = ResidualFFN(input_dim=2, output_dim=1, num_layers=3, layer_size=20)
L_net = ResidualFFN(input_dim=1, output_dim=1, num_layers=3, layer_size=20)
cv_net.to(device)
ov_net.to(device)
cir_net.to(device)
L_net.to(device)

total_model_parameters = list(cv_net.parameters()) + list(ov_net.parameters()) + list(cir_net.parameters()) + list(L_net.parameters()) 

In [53]:
#gradient and sampling utils

class GradientResults(NamedTuple):
    """
    Container for gradient computation results.

    Organizes all computed derivatives in a structured way for easy access.
    """
    # Network predictions
    c_cir: torch.Tensor # Cation Intersitial concentration
    c_cv: torch.Tensor  # Cation vacancy concentration
    c_ov: torch.Tensor  # Oxygen vacancy concentration

    # Time derivatives
    c_cir_t: torch.Tensor
    c_cv_t: torch.Tensor  # ∂c_cv/∂t
    c_ov_t: torch.Tensor  # ∂c_ov/∂t

    # First spatial derivatives
    c_cir_x: torch.Tensor
    c_cv_x: torch.Tensor  # ∂c_cv/∂x
    c_ov_x: torch.Tensor  # ∂c_ov/∂x

    # Second spatial derivatives
    c_cir_xx: torch.Tensor  
    c_cv_xx: torch.Tensor  # ∂²c_cv/∂x²
    c_ov_xx: torch.Tensor  # ∂²c_av/∂x²


def _grad(x,dx):
    """Take the derrivative of x w.r.t dx"""

    return torch.autograd.grad(x,dx,torch.ones_like(dx),create_graph=True,retain_graph=True)[0]

def compute_gradients(x, t):
    inputs = torch.cat([x, t], dim=1)

    # Get network predictions
    c_cir_raw = cir_net(inputs)
    c_cv_raw = cv_net(inputs)
    c_ov_raw = ov_net(inputs)


    # Networks predict concentrations directly
    c_cv = c_cv_raw
    c_ov = c_ov_raw
    c_cir = c_cir_raw

    # Direct derivatives
    c_cv_t = _grad(c_cv, t)
    c_ov_t = _grad(c_ov, t)
    c_cir_t = _grad(c_cir,t)

    c_cv_x = _grad(c_cv, x)
    c_ov_x = _grad(c_ov, x)
    c_cir_x = _grad(c_cir,x)
   
    c_cv_xx = _grad(c_cv_x, x)
    c_ov_xx = _grad(c_ov_x, x)
    c_cir_xx = _grad(c_cir_x,x)

    return GradientResults(
        c_cir=c_cir, c_cv=c_cv, c_ov=c_ov,
        c_cv_t=c_cv_t, c_ov_t=c_ov_t, c_cir_t = c_cir_t,
        c_cir_x = c_cir_x, c_cv_x=c_cv_x, c_ov_x=c_ov_x, 
        c_cir_xx = c_cir_xx, c_cv_xx=c_cv_xx, c_ov_xx=c_ov_xx
    )

def sample_interior_points():
    batch_size = 2048
    t = torch.rand(batch_size, 1, device=device, requires_grad=True)
    L_pred = L_net(t)  # L_net takes only time
    x = torch.rand(batch_size, 1, device=device, requires_grad=True) * L_pred
    return x, t

def sample_boundary_points():
    batch_size = 2 * 1024
    t = torch.rand(batch_size, 1, device=device, requires_grad=True)
    L_pred = L_net(t)
    
    half_batch = batch_size // 2
    x_mf = torch.zeros(half_batch, 1, device=device, requires_grad=True)
    x_fs = L_pred[half_batch:]
    
    x_boundary = torch.cat([x_mf, x_fs], dim=0)
    t_boundary = torch.cat([t[:half_batch], t[half_batch:]], dim=0)
    
    return x_boundary, t_boundary

def sample_film_physics_points():
    batch_size = 2048
    t = torch.rand(batch_size, 1, device=device, requires_grad=True)
    return t


## Mathematics of the Model Being Implemented

### Interior Equations
\frac{\partial C_{V_O}}{\partial t} = D_{V_O} \frac{\partial^2 C_{V_O}}{\partial x^2} + \frac{D_{V_O} F z_{V_O}}{RT} \frac{\partial C_{V_O}}{\partial x} \frac{\partial \phi_f}{\partial x} + \frac{D_{V_O} F z_{V_O}}{RT} C_{V_O} \frac{\partial^2 \phi_f}{\partial x^2}
                              
\frac{\partial C_{V_{Cr}}}{\partial t} = D_{V_{Cr}} \frac{\partial^2 C_{V_{Cr}}}{\partial x^2} + \frac{D_{V_{Cr}} F z_{V_{Cr}}}{RT} \frac{\partial C_{V_{Cr}}}{\partial x} \frac{\partial \phi_f}{\partial x} + \frac{D_{V_{Cr}} F z_{V_{Cr}}}{RT} C_{V_{Cr}} \frac{\partial^2 \phi_f}{\partial x^2}

\frac{\partial C_{Cr_i}}{\partial t} = D_{Cr_i} \frac{\partial^2 C_{Cr_i}}{\partial x^2} + \frac{D_{Cr_i} F z_{Cr_i}}{RT} \frac{\partial C_{Cr_i}}{\partial x} \frac{\partial \phi_f}{\partial x} + \frac{D_{Cr_i} F z_{Cr_i}}{RT} C_{Cr_i} \frac{\partial^2 \phi_f}{\partial x^2}

\frac{\partial C_i}{\partial t} = D_i \frac{\partial^2 C_i}{\partial x^2} + \frac{D_i F z_i}{RT} \left( \frac{\partial C_i}{\partial x} \frac{\partial \phi}{\partial x} + C_i \frac{\partial^2 \phi}{\partial x^2} \right)


In [54]:
#Import the loss functions
def loss(residual):
    return torch.mean(residual**2)


#Analytically Prescribed Potentials per Seyeux
def compute_potentials(x, t):
    """Compute analytical potentials from Seyeux model"""
    
    # Dynamic interfacial potentials
    phi_mf = phi_mf_0  # constant
    phi_fs = phi_fs_0 + alpha * delta_V * torch.exp(-x / x_d)
    phi_f = phi_f_i + delta_V * (1 - alpha * torch.exp(-x / x_d))
    
    return phi_mf, phi_fs, phi_f

def compute_potential_derivatives(x, t):
    """Compute analytical derivatives of potentials"""
    
    # φ_f derivatives
    phi_f_x = (delta_V * alpha / x_d) * torch.exp(-x / x_d)  # ∂φ_f/∂x
    
    # φ_fs derivatives  
    phi_fs_x = -(alpha * delta_V / x_d) * torch.exp(-x / x_d)  # ∂φ_fs/∂x
    
    # φ_mf derivatives
    phi_mf_x = torch.zeros_like(x)  # constant, so derivative = 0

    phi_f_xx = -(delta_V * alpha/ x_d**2) * torch.exp(-x/x_d)

    phi_fs_xx = (alpha * delta_V / x_d**2) * torch.exp(-x / x_d)  
    
    return phi_f_x, phi_fs_x, phi_mf_x, phi_f_xx, phi_fs_xx


#Compute the Interior PDE's

def interior_loss(x: torch.Tensor, t: torch.Tensor):
    grads = compute_gradients(x, t)
    phi_f_x, phi_fs_x, phi_mf_x, phi_f_xx, phi_fs_xx = compute_potential_derivatives(x,t)

    # Calculate the PDE residuals for each species

    ov_residual = grads.c_ov_t - (D_vo*grads.c_ov_xx + 
                                D_vo*F*z_ov*(1/(R*T))*grads.c_ov_x*phi_f_x + 
                                D_vo*grads.c_ov*z_ov*(F/(R*T))*phi_f_xx)

    cv_residual = grads.c_cv_t - (D_MCr*grads.c_cv_xx + 
                                D_MCr*F*z_MCr*(1/(R*T))*grads.c_cv_x*phi_f_x + 
                                D_MCr*grads.c_cv*z_MCr*(F/(R*T))*phi_f_xx)

    cir_residual = grads.c_cir_t - (D_ICr*grads.c_cir_xx + 
                                    D_ICr*F*z_ICr*(1/(R*T))*grads.c_cir_x*phi_f_x + 
                                    D_ICr*grads.c_cir*z_ICr*(F/(R*T))*phi_f_xx)

    
    interior = loss(ov_residual) + loss(cv_residual) + loss(cir_residual)

    return interior, ov_residual, cv_residual, cir_residual


def boundary_loss(x:torch.Tensor, t:torch.Tensor):

    grads = compute_gradients(x,t)
    phi_mf, phi_fs, phi_f = compute_potentials(x,t)
    phi_f_x, phi_fs_x, phi_mf_x, phi_f_xx, phi_fs_xx = compute_potential_derivatives(x,t)

    batch_size = x.shape[0]
    half_batch = batch_size // 2
    
    # Split into metal/film and film/solution interface points
    mf_indices = torch.arange(half_batch)
    fs_indices = torch.arange(half_batch, batch_size)
    
    # Use interior points (middle of film) as reference
    # Sample some interior points for reference
    L_pred = L_net(t)
    x_interior = 0.5 * L_pred  # Middle of film
    
    # Get interior concentrations as reference
    interior_inputs = torch.cat([x_interior, t], dim=1)
    c_ov_interior = ov_net(interior_inputs)
    c_cv_interior = cv_net(interior_inputs) 
    c_cir_interior = cir_net(interior_inputs)
    
    # Dynamic potentials
    phi_fs = phi_fs_0 + alpha * delta_V * torch.exp(-x[fs_indices] / x_d)

    # Metal/film interface ratios - expand to match batch size
    ratio_mf_ov_scalar = torch.exp(torch.tensor((delta_G1 + z_ov*F*phi_mf) / (R*T), device=device))
    ratio_mf_cv_scalar = torch.exp(torch.tensor((delta_G2 + z_MCr*F*phi_mf) / (R*T), device=device))
    ratio_mf_cir_scalar = torch.exp(torch.tensor((delta_G3 + z_ICr*F*phi_mf) / (R*T), device=device))

    # Expand to match half_batch size
    ratio_mf_ov = ratio_mf_ov_scalar.expand(half_batch, 1)
    ratio_mf_cv = ratio_mf_cv_scalar.expand(half_batch, 1)
    ratio_mf_cir = ratio_mf_cir_scalar.expand(half_batch, 1)

    # Film/solution interface ratios  
    ratio_fs_ov = torch.exp((delta_G8 + z_ov*F*phi_fs) / (R*T))     # Reaction 8
    ratio_fs_cv = torch.exp((delta_G9 + z_MCr*F*phi_fs) / (R*T))    # Reaction 9
    ratio_fs_cir = torch.exp((delta_G10 + z_ICr*F*phi_fs) / (R*T))  # Reaction 10

    # Calculate residuals
    mf_ov_residual = (grads.c_ov[mf_indices] / c_ov_interior[:half_batch]) - ratio_mf_ov
    mf_cv_residual = (grads.c_cv[mf_indices] / c_cv_interior[:half_batch]) - ratio_mf_cv
    mf_cir_residual = (grads.c_cir[mf_indices] / c_cir_interior[:half_batch]) - ratio_mf_cir
    
    fs_ov_residual = (grads.c_ov[fs_indices] / c_ov_interior[half_batch:]) - ratio_fs_ov
    fs_cv_residual = (grads.c_cv[fs_indices] / c_cv_interior[half_batch:]) - ratio_fs_cv
    fs_cir_residual = (grads.c_cir[fs_indices] / c_cir_interior[half_batch:]) - ratio_fs_cir
    
    # Calculate losses
    mf_ov_loss = loss(mf_ov_residual)
    mf_cv_loss = loss(mf_cv_residual)
    mf_cir_loss = loss(mf_cir_residual)
    
    fs_ov_loss = loss(fs_ov_residual)
    fs_cv_loss = loss(fs_cv_residual)
    fs_cir_loss = loss(fs_cir_residual)
    
    mf_boundary_loss = mf_ov_loss + mf_cv_loss + mf_cir_loss
    fs_boundary_loss = fs_ov_loss + fs_cv_loss + fs_cir_loss
    boundary_loss = mf_boundary_loss + fs_boundary_loss
    
    residuals = {
        "mf_ov": mf_ov_residual,
        "mf_cv": mf_cv_residual,
        "mf_cir": mf_cir_residual,
        "fs_ov":fs_ov_residual,
        "fs_cv": fs_cv_residual,
        "fs_cir":fs_cir_residual
    }

    return boundary_loss, residuals
    
def compute_film_physics_loss(t):
    """
    Compute film growth physics loss based on Seyeux flux formulation.
    
    Film growth: dx/dt = (Ω/N_V) * [J_VO + J_MCr + J_ICr]
    """
    batch_size = t.shape[0]
    
    # Get film thickness and its time derivative
    L_inputs = t
    L_pred = L_net(L_inputs)
    L_t = _grad(L_pred, t)
    
    # Sample points at film/solution interface (x = L) to compute fluxes
    x_interface = L_pred  # x = L(t)
    interface_inputs = torch.cat([x_interface, t], dim=1)
    
    # Get concentrations and gradients at interface
    c_ov = ov_net(interface_inputs)
    c_cv = cv_net(interface_inputs) 
    c_cir = cir_net(interface_inputs)
    
    c_ov_x = _grad(c_ov, x_interface)
    c_cv_x = _grad(c_cv, x_interface)
    c_cir_x = _grad(c_cir, x_interface)
    
    # Analytical potential derivatives at interface
    phi_f_x = (delta_V * alpha / x_d) * torch.exp(-x_interface / x_d)
    
    # Calculate species fluxes using Nernst-Planck
    # J_i = -D_i * (∂C_i/∂x + (z_i*F*C_i/RT) * ∂φ/∂x)
    J_ov = -D_vo * (c_ov_x + (z_ov*F*c_ov/(R*T)) * phi_f_x)
    J_cv = -D_MCr * (c_cv_x + (z_MCr*F*c_cv/(R*T)) * phi_f_x)
    J_cir = -D_ICr * (c_cir_x + (z_ICr*F*c_cir/(R*T)) * phi_f_x)
    
    # Total flux contributing to film growth
    J_total = J_ov + J_cv + J_cir
    
    # Expected film growth rate from flux
    expected_growth_rate = (Omega / N_V) * J_total
    
    # Film physics residual
    film_residual = L_t - expected_growth_rate
    
    # Loss
    film_loss = loss(film_residual)
    
    return film_loss, film_residual



In [55]:
#Compute the Total Loss and Intialize any training parameters

def total_loss():
    x_interior, t_interior = sample_interior_points()
    x_boundary, t_boundary = sample_boundary_points()
    t_film = sample_film_physics_points()

    interior, ov_residual, cv_residual, cir_residual = interior_loss(x_interior,t_interior)
    boundary,residuals = boundary_loss(x_boundary,t_boundary)
    film_loss, film_residual = compute_film_physics_loss(t_film)

    total_loss = interior + (1e-15)*boundary + film_loss 

    return total_loss, interior, (1e-15)*boundary,film_loss


loss_history = {
    'total': [], 'interior': [], 'boundary': [], 'film': []
}

lr = 1e-3
max_steps = 30000
print_freq = 100


In [56]:

optimizer=torch.optim.Adam([{'params': total_model_parameters}],lr=lr)
for step in tqdm(range(max_steps),desc="Training Status"):
    cv_net.train()
    ov_net.train()
    cir_net.train()
    L_net.train()
    optimizer.zero_grad()


    loss_val, interior, boundary, film_loss = total_loss()
    loss_total = interior + boundary + film_loss
    loss_val.backward()

    optimizer.step()
        
    loss_history["total"].append(loss_total.item())
    loss_history['interior'].append(interior.item())
    loss_history['boundary'].append(boundary.item())
    loss_history['film'].append(film_loss.item())

    if step % print_freq == 0:
        tqdm.write(f"Total:{loss_total}, interior:{interior}, boundary:{boundary},film:{film_loss} at step:{step}")

Training Status:   0%|          | 0/30000 [00:00<?, ?it/s]

Total:1.3435707092285156, interior:0.0035064537078142166, boundary:1.3395941257476807,film:0.00047014132724143565 at step:0
Total:1.3395966291427612, interior:2.4897085495467763e-06, boundary:1.3395941257476807,film:5.061082930524208e-08 at step:100
Total:1.339595913887024, interior:1.7382802752763382e-06, boundary:1.3395941257476807,film:2.890074668471243e-08 at step:200
Total:1.3395953178405762, interior:1.1691121244439273e-06, boundary:1.3395941257476807,film:1.6739043928737374e-08 at step:300
Total:1.339595079421997, interior:9.28295037283533e-07, boundary:1.3395941257476807,film:9.38841360209608e-09 at step:400
Total:1.339594841003418, interior:7.25257223166409e-07, boundary:1.3395941257476807,film:5.863178387244261e-09 at step:500
Total:1.3395949602127075, interior:6.023690275469562e-07, boundary:1.3395943641662598,film:4.591556912458827e-09 at step:600
Total:1.339594841003418, interior:4.5435916717906366e-07, boundary:1.3395943641662598,film:3.667703918353027e-09 at step:700
Tot

KeyboardInterrupt: 