# Physics-Informed Neural Network (PINN) for Granular Segregation - Phase 1

This notebook implements a PINN to solve the non-dimensionalized segregation PDE.

## Non-dimensionalized PDE

The governing equation is:

$$\frac{\partial c}{\partial \tilde{t}} + \tilde{u} \frac{\partial c}{\partial \tilde{x}} + \tilde{w} \frac{\partial c}{\partial \tilde{z}} + \Lambda(1 - \tilde{x}) \frac{\partial}{\partial \tilde{z}} \left[ g(\tilde{z}) c(1-c) \right] = \frac{\partial}{\partial \tilde{z}} \left[ \frac{1}{Pe} \frac{\partial c}{\partial \tilde{z}} \right]$$

Where:
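- $c$ is the concentration (volume fraction of large particles, which the segregation term drives upward toward $\tilde{z} = 0$); the small-particle fraction is $c_s = 1 - c$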
- $\tilde{x} \in [0, 1]$, $\tilde{z} \in [-1, 0]$, $\tilde{t} \in [0, t_{end}]$ are dimensionless coordinates
- $\Lambda$ is the segregation parameter
- $Pe$ is the Péclet number
- $\tilde{u}$ and $\tilde{w}$ are dimensionless velocity profiles
- $g(\tilde{z})$ is a shear-rate-like profile

## Boundary Conditions

- **Inlet** ($\tilde{x} = 0$): Dirichlet $c = 0.5$ (well-mixed feed)
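- **Top/Bottom** ($\tilde{z} = 0, -1$): no-flux (nonlinear Robin-type) condition, i.e. the diffusive flux balances the segregation flux: $(1/Pe)\, \partial c/\partial \tilde{z} = \Lambda(1-\tilde{x}) g(\tilde{z}) c(1-c)$
- **Outlet** ($\tilde{x} = 1$): no condition is imposed. The PDE has no $\tilde{x}$-diffusion term and $\tilde{u}$ vanishes at $\tilde{x} = 1$, so outlet points are sampled but not used in the loss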
- **Initial condition**: $c = 0.5$ everywhere at $\tilde{t} = 0$


In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import grad
import time

# Select the compute device: CUDA when available, otherwise CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Seed both the torch and the NumPy RNGs so runs are reproducible
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(42)


In [None]:
# Physical parameters from the paper (section 3)
Lambda, Pe, k = 0.3949, 27.5387, 2.3  # segregation parameter Λ, Péclet number, exponential velocity-profile parameter
D = 1.0 / Pe                          # vertical diffusivity (inverse Péclet number)
tEnd = 20.0                           # final dimensionless time

# Rectangular space-time domain: x̃ ∈ [0, 1], z̃ ∈ [-1, 0], t̃ ∈ [0, tEnd]
(x_min, x_max), (z_min, z_max), (t_min, t_max) = (0.0, 1.0), (-1.0, 0.0), (0.0, tEnd)

print(f"Parameters: Λ={Lambda}, Pe={Pe}, D={D:.6f}, k={k}")
print(f"Domain: x̃ ∈ [{x_min}, {x_max}], z̃ ∈ [{z_min}, {z_max}], t̃ ∈ [{t_min}, {t_max}]")


In [None]:
# Helper functions for velocity profiles and g(z̃)
#
# The three profiles share the normalization 1/(1 - e^{-k}) and are mutually consistent:
#   ∂ũ/∂z̃ = (1 - x̃) g(z̃)            (g is the shear-rate profile)
#   ∂ũ/∂x̃ + ∂w̃/∂z̃ = 0               (divergence-free flow)
#   ∫_{-1}^{0} ũ(0, z̃) dz̃ = 1/2      (inlet flux), w̃(-1) = -1/2, w̃(0) = 0
def _as_tensor_like(k, ref):
    """Return ``k`` unchanged if it is a tensor, else a 0-d tensor with ``ref``'s dtype/device."""
    if isinstance(k, torch.Tensor):
        return k
    return torch.tensor(float(k), dtype=ref.dtype, device=ref.device)


def u_tilde(x, z, k):
    """Dimensionless velocity in x direction.

    ũ(x̃, z̃) = k / (2 (1 - e^{-k})) · (1 - x̃) · e^{k z̃}

    Args:
        x, z: coordinate tensors (broadcastable against each other).
        k: exponential velocity-profile parameter (Python number or tensor).

    Returns:
        Tensor of ũ values, broadcast from ``x`` and ``z``.
    """
    k = _as_tensor_like(k, x)
    # Divide (not multiply) by 1 - e^{-k} so that ∂ũ/∂z̃ = (1 - x̃) g(z̃) and the inlet flux is 1/2
    factor = 0.5 * k / (1 - torch.exp(-k))
    return factor * (1 - x) * torch.exp(k * z)


def w_tilde(z, k):
    """Dimensionless velocity in z direction.

    w̃(z̃) = (e^{k z̃} - 1) / (2 (1 - e^{-k}))

    Zero at the free surface z̃ = 0 and -1/2 at z̃ = -1; together with ``u_tilde``
    it satisfies ∂ũ/∂x̃ + ∂w̃/∂z̃ = 0.

    Args:
        z: vertical coordinate tensor.
        k: exponential velocity-profile parameter (Python number or tensor).
    """
    k = _as_tensor_like(k, z)
    factor = 0.5 / (1 - torch.exp(-k))
    return factor * (torch.exp(k * z) - 1)


def g_profile(z, k):
    """Shear-rate-like profile g(z̃) and its derivative.

    g(z̃) = k² / (2 (1 - e^{-k})) · e^{k z̃},   g'(z̃) = k · g(z̃)

    Args:
        z: vertical coordinate tensor.
        k: exponential velocity-profile parameter (Python number or tensor).

    Returns:
        Tuple ``(g_val, g_prime)`` of tensors shaped like ``z``.
    """
    k = _as_tensor_like(k, z)
    g_val = (k**2 / (2 * (1 - torch.exp(-k)))) * torch.exp(k * z)
    g_prime = k * g_val
    return g_val, g_prime


In [None]:
# Neural Network Architecture
class PINN(nn.Module):
    """Fully connected tanh network mapping (x̃, z̃, t̃) to a concentration in (0, 1).

    Args:
        layers: sequence of layer widths, e.g. ``[3, 64, 64, 1]``. The first entry
            must be 3 (the concatenated x, z, t inputs); the last is the output width.
    """

    def __init__(self, layers):
        super().__init__()
        # One Linear per consecutive pair of widths (all layers are built before any re-initialization)
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])
        )

        # Xavier-uniform weights and zero biases for every layer
        for linear in self.layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x, z, t):
        """Evaluate the network on column tensors ``x``, ``z``, ``t`` (joined along dim 1)."""
        hidden = torch.cat([x, z, t], dim=1)

        for linear in self.layers[:-1]:
            hidden = torch.tanh(linear(hidden))

        # Sigmoid on the output keeps the predicted concentration strictly inside (0, 1)
        return torch.sigmoid(self.layers[-1](hidden))

# Network: 3 inputs (x, z, t) -> four tanh hidden layers of width 64 -> 1 output (c)
layers = [3] + [64] * 4 + [1]
model = PINN(layers).to(device)
print(f"Model architecture: {layers}")
print(f"Total parameters: {sum(param.numel() for param in model.parameters())}")


In [None]:
def compute_pde_residual(x, z, t, model, Lambda, Pe, k):
    """
    Evaluate the strong-form PDE residual at the points (x, z, t).

    With the segregation flux derivative expanded,
        ∂/∂z̃[g c(1-c)] = g'(z̃) c(1-c) + g(z̃)(1-2c) ∂c/∂z̃,
    the residual is
        r = ∂c/∂t̃ - [ (1/Pe) ∂²c/∂z̃² - ũ ∂c/∂x̃ - w̃ ∂c/∂z̃
                      - Λ(1-x̃)( g'(z̃)c(1-c) + g(z̃)(1-2c)∂c/∂z̃ ) ]

    Args:
        x, z, t: coordinate tensors; they are copied, so the caller's tensors are not modified.
        model: callable ``model(x, z, t)`` returning the concentration c.
        Lambda: segregation parameter Λ.
        Pe: Péclet number.
        k: exponential velocity-profile parameter.

    Returns:
        ``(residual, c)``. Derivatives are taken with ``create_graph=True``, so the
        residual can be back-propagated into the model parameters.
    """
    # Fresh leaf copies (x, then z, then t) that track gradients
    xs, zs, ts = (v.clone().detach().requires_grad_(True) for v in (x, z, t))

    c = model(xs, zs, ts)

    def d(out, wrt):
        # First derivative of `out` w.r.t. `wrt`, keeping the graph for 2nd derivatives and training
        return grad(out, wrt, grad_outputs=torch.ones_like(out), create_graph=True)[0]

    c_t = d(c, ts)
    c_x = d(c, xs)
    c_z = d(c, zs)
    c_zz = d(c_z, zs)

    g_val, g_prime = g_profile(zs, k)

    diffusion = (1.0 / Pe) * c_zz
    advection = u_tilde(xs, zs, k) * c_x + w_tilde(zs, k) * c_z
    segregation = Lambda * (1 - xs) * (g_prime * c * (1 - c) + g_val * (1 - 2*c) * c_z)

    # ∂c/∂t minus the right-hand side (diffusion - advection - segregation)
    return c_t - (diffusion - advection - segregation), c


In [None]:
# Generate training points
def generate_training_data(n_pde, n_bc_inlet, n_bc_top, n_bc_bottom, n_bc_outlet, n_ic):
    """
    Sample collocation points for the PINN loss.

    Each region is returned as an ``(x, z, t)`` triple of ``(n, 1)`` tensors on
    ``device``. Free coordinates are drawn uniformly from the domain bounds; the
    coordinate that defines a boundary (or t̃ = 0 for the initial condition) is fixed.

    Regions (dictionary keys):
        'pde'       interior points for the PDE residual
        'bc_inlet'  x̃ = 0 (Dirichlet, c = 0.5)
        'bc_top'    z̃ = 0 (flux condition)
        'bc_bottom' z̃ = z_min (flux condition)
        'bc_outlet' x̃ = 1 (natural BC; no explicit loss term)
        'ic'        t̃ = 0 (c = 0.5)
    """
    def uniform(n, lo, hi):
        # n uniform samples in [lo, hi), shaped (n, 1)
        return torch.rand(n, 1, device=device) * (hi - lo) + lo

    def fixed(n, value):
        # n copies of a constant coordinate, shaped (n, 1)
        return torch.full((n, 1), value, device=device)

    # NOTE: the sampling order below determines the RNG stream; keep it stable for reproducibility.
    pde = (uniform(n_pde, x_min, x_max),
           uniform(n_pde, z_min, z_max),
           uniform(n_pde, t_min, t_max))

    inlet_z = uniform(n_bc_inlet, z_min, z_max)
    inlet_t = uniform(n_bc_inlet, t_min, t_max)
    bc_inlet = (fixed(n_bc_inlet, 0.0), inlet_z, inlet_t)

    top_x = uniform(n_bc_top, x_min, x_max)
    top_t = uniform(n_bc_top, t_min, t_max)
    bc_top = (top_x, fixed(n_bc_top, 0.0), top_t)

    bottom_x = uniform(n_bc_bottom, x_min, x_max)
    bottom_t = uniform(n_bc_bottom, t_min, t_max)
    bc_bottom = (bottom_x, fixed(n_bc_bottom, z_min), bottom_t)

    outlet_z = uniform(n_bc_outlet, z_min, z_max)
    outlet_t = uniform(n_bc_outlet, t_min, t_max)
    bc_outlet = (fixed(n_bc_outlet, 1.0), outlet_z, outlet_t)

    ic_x = uniform(n_ic, x_min, x_max)
    ic_z = uniform(n_ic, z_min, z_max)
    ic = (ic_x, ic_z, fixed(n_ic, 0.0))

    return {
        'pde': pde,
        'bc_inlet': bc_inlet,
        'bc_top': bc_top,
        'bc_bottom': bc_bottom,
        'bc_outlet': bc_outlet,
        'ic': ic,
    }

# Collocation budget per region
n_pde = 10000
n_bc_inlet = n_bc_top = n_bc_bottom = n_bc_outlet = 1000
n_ic = 2000

train_data = generate_training_data(n_pde, n_bc_inlet, n_bc_top, n_bc_bottom, n_bc_outlet, n_ic)
print("Training data generated:")
for key, (x, z, t) in train_data.items():
    print(f"  {key}: {len(x)} points")


In [None]:
def _flux_bc_loss(model, points, Lambda, Pe, k):
    """Mean-squared violation of (1/Pe) ∂c/∂z̃ = Λ(1-x̃) g(z̃) c(1-c) at boundary points (x, z, t)."""
    x_b, z_b, t_b = points
    # Work on copies so the stored training tensors are not modified
    x_b = x_b.clone().detach().requires_grad_(True)
    z_b = z_b.clone().detach().requires_grad_(True)
    t_b = t_b.clone().detach()
    c_b = model(x_b, z_b, t_b)
    dc_dz = grad(c_b, z_b, grad_outputs=torch.ones_like(c_b), create_graph=True)[0]
    g_b, _ = g_profile(z_b, k)
    target = Lambda * (1 - x_b) * g_b * c_b * (1 - c_b)
    predicted = (1.0 / Pe) * dc_dz
    return torch.mean((predicted - target)**2)


def compute_loss(model, train_data, Lambda, Pe, k, weights):
    """
    Compute the weighted PINN loss
        L = w_pde * L_pde + w_bc_inlet * L_bc_inlet + w_bc_top * L_bc_top
            + w_bc_bottom * L_bc_bottom + w_ic * L_ic

    Each L_* is a mean squared error:
        pde        PDE residual at interior points
        bc_inlet   c - 0.5 at x̃ = 0
        bc_top     flux condition at z̃ = 0
        bc_bottom  flux condition at z̃ = -1
        ic         c - 0.5 at t̃ = 0
    The outlet (x̃ = 1) contributes no term.

    Args:
        model: callable ``model(x, z, t)`` returning c.
        train_data: dict of (x, z, t) tuples as produced by ``generate_training_data``.
        Lambda, Pe, k: physical parameters.
        weights: dict with keys 'pde', 'bc_inlet', 'bc_top', 'bc_bottom', 'ic'.

    Returns:
        ``(total_loss, parts)``, where ``total_loss`` is a differentiable tensor and
        ``parts`` maps each key to its unweighted loss as a Python float.
    """
    x_pde, z_pde, t_pde = train_data['pde']
    residual, _ = compute_pde_residual(x_pde, z_pde, t_pde, model, Lambda, Pe, k)
    components = {'pde': torch.mean(residual**2)}

    c_in = model(*train_data['bc_inlet'])
    components['bc_inlet'] = torch.mean((c_in - 0.5)**2)

    components['bc_top'] = _flux_bc_loss(model, train_data['bc_top'], Lambda, Pe, k)
    components['bc_bottom'] = _flux_bc_loss(model, train_data['bc_bottom'], Lambda, Pe, k)

    c_ic = model(*train_data['ic'])
    components['ic'] = torch.mean((c_ic - 0.5)**2)

    # Accumulate in a fixed order: pde, bc_inlet, bc_top, bc_bottom, ic
    total_loss = 0.0
    for name, value in components.items():
        total_loss = total_loss + weights[name] * value

    return total_loss, {name: value.item() for name, value in components.items()}


In [None]:
# Training setup: Adam with a cosine-annealed learning rate over the whole run.
# (Alternative: ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1000),
#  which must then be stepped with scheduler.step(total_loss).)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Relative weight of each loss term (can be adjusted)
loss_weights = dict(pde=1.0, bc_inlet=10.0, bc_top=10.0, bc_bottom=10.0, ic=10.0)

# Training parameters
n_epochs = 20000
print_interval = 500
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

loss_history = []
start_time = time.time()

print("Starting training...")
print(f"{'Epoch':<10} " + " ".join(
    f"{title:<15}" for title in ('Total Loss', 'PDE', 'BC Inlet', 'BC Top', 'BC Bottom', 'IC')
))
print("-" * 100)

try:
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        total_loss, loss_dict = compute_loss(model, train_data, Lambda, Pe, k, loss_weights)

        # Backpropagate, clip the global gradient norm to 1.0, then update weights and learning rate
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        total_value = total_loss.item()
        loss_history.append({'epoch': epoch, 'total': total_value, **loss_dict})

        # Report the first epoch and then every `print_interval` epochs
        if epoch == 0 or (epoch + 1) % print_interval == 0:
            row = [total_value] + [loss_dict[name] for name in ('pde', 'bc_inlet', 'bc_top', 'bc_bottom', 'ic')]
            print(f"{epoch + 1:<10} " + " ".join(f"{value:<15.6e}" for value in row))

except KeyboardInterrupt:
    print("\nTraining interrupted by user.")
except Exception as e:
    print(f"\nError during training: {e}")
    import traceback
    traceback.print_exc()

elapsed_time = time.time() - start_time
print(f"\nTraining completed in {elapsed_time:.2f} seconds")


In [None]:
# Plot loss history: total loss (top) and each unweighted loss component (bottom), log scale
if not loss_history:
    print("Warning: loss_history is empty. Please run the training cell first.")
else:
    columns = ('epoch', 'total', 'pde', 'bc_inlet', 'bc_top', 'bc_bottom', 'ic')
    loss_history_array = np.array([[record[name] for name in columns] for record in loss_history])
    epoch_axis = loss_history_array[:, 0]

    # Span the planned epoch count when it is defined, otherwise just the recorded epochs
    x_limit = [0, n_epochs] if 'n_epochs' in globals() else [0, loss_history_array[-1, 0] + 1]

    fig, axes = plt.subplots(2, 1, figsize=(10, 8))
    ax_total, ax_parts = axes

    ax_total.semilogy(epoch_axis, loss_history_array[:, 1], 'b-', linewidth=2)
    ax_total.set_ylabel('Total Loss', fontsize=12)
    ax_total.set_title('Total Loss History', fontsize=14)

    component_styles = [('r-', 'PDE'), ('g-', 'BC Inlet'), ('m-', 'BC Top'), ('c-', 'BC Bottom'), ('y-', 'IC')]
    for col, (fmt, label) in enumerate(component_styles, start=2):
        ax_parts.semilogy(epoch_axis, loss_history_array[:, col], fmt, label=label, linewidth=2)
    ax_parts.set_ylabel('Loss', fontsize=12)
    ax_parts.set_title('Individual Loss Components', fontsize=14)
    ax_parts.legend(fontsize=10)

    # Settings shared by both panels
    for ax in axes:
        ax.set_xlabel('Epoch', fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(x_limit)

    plt.tight_layout()
    plt.show()


In [None]:
# Visualize the predicted solution at the final time t̃ = t_max
model.eval()
with torch.no_grad():
    # Uniform 100 x 100 grid over (x̃, z̃); 'ij' indexing makes axis 0 run along x̃
    n_x = n_z = 100
    x_vis = torch.linspace(x_min, x_max, n_x, device=device).reshape(-1, 1)
    z_vis = torch.linspace(z_min, z_max, n_z, device=device).reshape(-1, 1)
    X_vis, Z_vis = torch.meshgrid(x_vis.squeeze(), z_vis.squeeze(), indexing='ij')

    x_flat = X_vis.reshape(-1, 1)
    z_flat = Z_vis.reshape(-1, 1)
    t_final = torch.full_like(x_flat, t_max)
    c_final = model(x_flat, z_flat, t_final)
    C_final = c_final.reshape(n_x, n_z).cpu().numpy()

    # The filled contour shows 1 - c, labelled c_s
    fig, ax = plt.subplots(figsize=(10, 8))
    levels = np.linspace(0.0, 1.0, 51)
    im = ax.contourf(X_vis.cpu().numpy(), Z_vis.cpu().numpy(), 1 - C_final, levels=levels, cmap='hot')
    ax.set_xlabel(r'$\tilde{x}$', fontsize=14)
    ax.set_ylabel(r'$\tilde{z}$', fontsize=14)
    ax.set_title(f'Concentration $c_s$ at $\\tilde{{t}}={t_max}$ (Λ={Lambda}, Pe={Pe})', fontsize=14)
    plt.colorbar(im, ax=ax, label='Concentration $c_s$')
    ax.set_aspect('equal')
    plt.tight_layout()
    plt.show()


In [None]:
# Visualize solution evolution over time: snapshots spread evenly over the full window [t_min, t_max]
model.eval()
with torch.no_grad():
    # Uniform (x̃, z̃) grid; 'ij' indexing makes axis 0 run along x̃
    n_x, n_z = 100, 100
    x_vis = torch.linspace(x_min, x_max, n_x, device=device).reshape(-1, 1)
    z_vis = torch.linspace(z_min, z_max, n_z, device=device).reshape(-1, 1)
    X_vis, Z_vis = torch.meshgrid(x_vis.squeeze(), z_vis.squeeze(), indexing='ij')
    X_np, Z_np = X_vis.cpu().numpy(), Z_vis.cpu().numpy()

    # Derive snapshot times from the domain so the last panel is the final training time t_max
    n_snapshots = 5
    times = np.linspace(t_min, t_max, n_snapshots).tolist()
    levels = np.linspace(0.0, 1.0, 51)
    # squeeze=False keeps a 2-D Axes array even for a single snapshot
    fig, axes = plt.subplots(1, len(times), figsize=(4 * len(times), 4), squeeze=False)
    axes = axes[0]

    for ax, t_val in zip(axes, times):
        t_vis = torch.ones_like(X_vis.reshape(-1, 1)) * t_val
        c_vis = model(X_vis.reshape(-1, 1), Z_vis.reshape(-1, 1), t_vis)
        C_vis = c_vis.reshape(n_x, n_z).cpu().numpy()

        # Plot 1 - c on a fixed [0, 1] color scale so panels are comparable
        im = ax.contourf(X_np, Z_np, 1 - C_vis, levels=levels, cmap='hot')
        im.set_clim(0.0, 1.0)
        ax.set_xlabel(r'$\tilde{x}$', fontsize=12)
        ax.set_ylabel(r'$\tilde{z}$', fontsize=12)
        ax.set_title(f'$\\tilde{{t}}={t_val}$', fontsize=12)
        ax.set_aspect('equal')
        plt.colorbar(im, ax=ax)

    plt.suptitle(f'Concentration Evolution (Λ={Lambda}, Pe={Pe})', fontsize=14)
    plt.tight_layout()
    plt.show()


In [None]:
# Check the PDE residual at the final time t̃ = t_max on random points in the (x̃, z̃) domain
model.eval()
n_check = 1000
x_check = torch.rand(n_check, 1, device=device) * (x_max - x_min) + x_min
z_check = torch.rand(n_check, 1, device=device) * (z_max - z_min) + z_min
t_check = torch.full((n_check, 1), t_max, device=device)

residual, c_check = compute_pde_residual(x_check, z_check, t_check, model, Lambda, Pe, k)

# Statistics only need values, not the autograd graph
residual_detached = residual.detach()
abs_residual = residual_detached.abs()

print(f"PDE Residual Statistics at t={t_max}:")
summary = (
    ("Mean: ", abs_residual.mean()),        # mean |r|
    ("Std:  ", residual_detached.std()),    # std of signed r
    ("Max:  ", abs_residual.max()),
    ("Min:  ", abs_residual.min()),
)
for label, stat in summary:
    print(f"  {label}{stat.item():.6e}")

# Histogram of the signed residuals
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(residual_detached.cpu().numpy().flatten(), bins=50, edgecolor='black', alpha=0.7)
ax.set_xlabel('PDE Residual', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('Distribution of PDE Residuals at Final Time', fontsize=14)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()


In [None]:
# Save the model checkpoint: weights, physical parameters, architecture, domain bounds and sampling sizes
from pathlib import Path

file_name = './models/pinn_forward.pth'
# torch.save does not create missing parent directories, so make sure ./models exists first
Path(file_name).parent.mkdir(parents=True, exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    'Lambda': Lambda,
    'Pe': Pe,
    'k': k,
    'tEnd': tEnd,
    'loss_history': loss_history,
    'layers': layers,  # Save network architecture
    'x_min': x_min, 'x_max': x_max,
    'z_min': z_min, 'z_max': z_max,
    't_min': t_min, 't_max': t_max,
    'pde points': n_pde,
    'bc_inlet points': n_bc_inlet,
    'bc_top points': n_bc_top,
    'bc_bottom points': n_bc_bottom,
    'bc_outlet points': n_bc_outlet,
    'ic points': n_ic
}, file_name)
print(f"Model saved to '{file_name}'")


In [None]:
# Optional: Load a saved model (the checkpoint written by the save cell above)
# checkpoint = torch.load('./models/pinn_forward.pth', map_location=device)
# model = PINN(checkpoint['layers']).to(device)  # rebuild with the saved architecture
# model.load_state_dict(checkpoint['model_state_dict'])
# Lambda = checkpoint['Lambda']
# Pe = checkpoint['Pe']
# k = checkpoint['k']
# tEnd = checkpoint['tEnd']
# loss_history = checkpoint['loss_history']
# print("Model loaded successfully!")
