# In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
from struct import pack

# Seed PyTorch's global generator once, at import time, so that network
# weight initialisation (and anything else drawing from torch's default
# generator) is reproducible from run to run.
_SEED = 42
torch.manual_seed(_SEED)

# Define the PINN model
class PINN(nn.Module):
    """Fully-connected tanh network approximating the temperature field.

    Each input row is ``(r, z, t)`` and each output row is a single
    temperature value. The network has ``num_layers`` tanh-activated hidden
    layers of width ``num_hidden``, but always at least one, even when
    ``num_layers`` < 1. A linear read-out layer follows.
    """

    def __init__(self, num_hidden=128, num_layers=4):
        super().__init__()
        # Width sequence through the hidden stack: 3 input features (r, z, t),
        # then num_hidden for every hidden layer (at least one of them).
        widths = [3] + [num_hidden] * max(num_layers, 1)
        stack = []
        for width_in, width_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(width_in, width_out), nn.Tanh()))
        # Scalar read-out: temperature T.
        stack.append(nn.Linear(num_hidden, 1))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Map a ``(..., 3)`` batch of ``(r, z, t)`` rows to ``(..., 1)`` temperatures."""
        return self.net(x)

# Material properties and parameters (from the document)
def thermal_conductivity(T):
    """Return the temperature-dependent thermal conductivity k(T).

    Below the melting point (1944 K), k is quadratic in T:
    ``22.5 + 8e-6 * (T - 650)**2``. At and above it, k is frozen at the value
    that expression takes at the melting point.

    Args:
        T: Tensor of temperatures (K).

    Returns:
        Tensor of the same shape as ``T``.
    """
    melting_point = 1944.0  # K
    solid_k = 22.5 + 8e-6 * (T - 650.0) ** 2
    # A plain Python float: the conductivity is constant in the melt.
    melt_k = 22.5 + 8e-6 * (melting_point - 650.0) ** 2
    return torch.where(T < melting_point, solid_k, melt_k)

def heat_capacity(T):
    """Return the piecewise heat-capacity coefficient c(T) used in the PDE.

    The coefficient has three branches:

    * below the melting point (1944 K):
      ``4500 * (600 + 1.07e-5 * |T - 500|**2.5)``;
    * from the melting point up to the boiling point (3560 K): the constant
      ``950 * 4110``;
    * at and above the boiling point: the constant ``700 * 4110``.

    Args:
        T: Tensor of temperatures (K).

    Returns:
        Tensor of the same shape as ``T``.
    """
    melt_point, boil_point = 1944.0, 3560.0  # K
    solid_c = 4500.0 * (600.0 + 1.07e-5 * torch.abs(T - 500.0) ** 2.5)
    fluid_c = torch.where(T < boil_point, 950.0 * 4110.0, 700.0 * 4110.0)
    return torch.where(T < melt_point, solid_c, fluid_c)

def smoothed_delta(T):
    """Return a Gaussian approximation of the Dirac delta delta(T - Tm).

    The Gaussian is centred on the melting point Tm = 1944 K with standard
    deviation 40 K. It is normalised to unit integral over T, so multiplying
    it by a latent heat spreads that heat over the melting interval.

    Args:
        T: Tensor of temperatures (K).

    Returns:
        Tensor of the same shape as ``T``.
    """
    melt_point = 1944.0  # K
    width = 40.0  # K, standard deviation of the smoothing Gaussian
    norm = 1.0 / (width * np.sqrt(2 * np.pi))
    return norm * torch.exp(-((T - melt_point) ** 2) / (2 * width**2))

# Laser pulse intensity
def laser_pulse(r, t):
    """Return the incident surface heat-flux density q(r, t) of the pulse train.

    Each pulse period of length 1/f has the temporal shape
    ``(tau / t0**2) * exp(-tau / t0)``, where tau is the time since the start
    of the current period. The spatial profile is the Gaussian
    ``exp(-r**2 / r0**2)``. The product is scaled by the per-pulse fluence
    ``Q0 = P / (f * pi * r0**2)``.

    NOTE(review): the power comment says "40 µJ at f=100 kHz", but
    P / f = 4000 W / 1e5 Hz = 0.04 J = 40 mJ per pulse. Confirm which is
    intended. Also, t0 = 40 µs is longer than the 10 µs period, so each period
    only sees the leading part of the pulse shape.

    Args:
        r: Tensor of radial coordinates (m).
        t: Tensor of times (s), broadcastable against ``r``.

    Returns:
        Tensor of flux values, zero wherever the in-period time is negative.
    """
    rise_time = 40e-6  # s, pulse front duration t0
    beam_radius = 25e-6  # m, r0
    power = 4000.0  # W
    rep_rate = 100e3  # Hz, f
    fluence = power / (rep_rate * np.pi * beam_radius**2)  # energy per area per pulse

    # Time elapsed since the start of the current pulse period.
    tau = t - torch.floor(t * rep_rate) / rep_rate

    envelope = fluence * (tau / rise_time**2)
    envelope = envelope * torch.exp(-tau / rise_time)
    envelope = envelope * torch.exp(-r**2 / beam_radius**2)

    # Mathematically tau >= 0; presumably the guard only clips rounding artefacts.
    return torch.where(tau >= 0, envelope, torch.tensor(0.0, device=t.device))

# Generate collocation points
def generate_collocation_points(N_r=75, N_z=400, N_t=1000):
    """Build the tensor-product grid of (r, z, t) collocation points.

    The grid covers r in [0, 75e-6] m, z in [0, 40e-6] m and t in
    [0, 100e-6] s, each sampled with ``linspace``. Rows are ordered with r
    slowest and t fastest, i.e. like ``meshgrid(..., indexing='ij')``
    flattened.

    Args:
        N_r, N_z, N_t: Number of samples along r, z and t.

    Returns:
        ``(points, r, z, t)``, where ``points`` is an ``(N_r*N_z*N_t, 3)``
        tensor with ``requires_grad=True`` and ``r``, ``z``, ``t`` are the
        1-D axis tensors.
    """
    r_max, z_max, t_max = 75e-6, 40e-6, 100e-6  # m, m, s
    r_axis = torch.linspace(0, r_max, N_r)
    z_axis = torch.linspace(0, z_max, N_z)
    t_axis = torch.linspace(0, t_max, N_t)
    # cartesian_prod enumerates rows in itertools.product order (last axis fastest).
    grid = torch.cartesian_prod(r_axis, z_axis, t_axis)
    grid.requires_grad_(True)
    return grid, r_axis, z_axis, t_axis

# Compute PDE residual
def compute_pde_residual(model, points):
    """Evaluate the heat-conduction PDE residual at every collocation point.

    Governing equation (axisymmetric, temperature-dependent properties,
    latent heat smeared over a smoothed delta):

        (c(T) + Lm * delta(T)) * T_t = (1/r) d/dr(r k T_r) + d/dz(k T_z)

    The radial operator is expanded as

        (1/r) d/dr(r k T_r) = k T_rr + k T_r / r + k_r T_r.

    On the axis (r = 0), T_r / r is singular. There the symmetric-field limit
    T_r / r -> T_rr is used instead.

    Args:
        model: Callable mapping an ``(N, 3)`` tensor of ``(r, z, t)`` rows to
            an ``(N, 1)`` tensor of temperatures.
        points: ``(N, 3)`` tensor of collocation points. If it does not
            require grad, a detached copy that does is used.

    Returns:
        ``(N, 1)`` tensor of residuals (left-hand side minus right-hand side).
    """
    if not points.requires_grad:
        points = points.detach().requires_grad_(True)

    def _d(field):
        # d(field)/d(r, z, t) per row, shape (N, 3). Derivatives must be taken
        # w.r.t. the tensor that was fed to the model: column slices taken
        # afterwards are not in the graph that produced `field`.
        if not field.requires_grad:
            return torch.zeros_like(points)
        grad = torch.autograd.grad(
            field,
            points,
            grad_outputs=torch.ones_like(field),
            create_graph=True,
            allow_unused=True,
        )[0]
        return torch.zeros_like(points) if grad is None else grad

    T = model(points)

    dT = _d(T)
    T_r, T_z, T_t = dT[:, 0:1], dT[:, 1:2], dT[:, 2:3]
    T_rr = _d(T_r)[:, 0:1]
    T_zz = _d(T_z)[:, 1:2]

    k = thermal_conductivity(T)
    dk = _d(k)  # spatial gradient of k through its dependence on T
    k_r, k_z = dk[:, 0:1], dk[:, 1:2]
    c = heat_capacity(T)
    Lm = 1.43e9  # Latent heat (J/m^3)
    delta = smoothed_delta(T)

    r = points[:, 0:1]
    on_axis = r <= 0
    # Divide by 1 on the axis so the untaken torch.where branch never produces
    # inf, which would still turn the backward pass into NaNs.
    r_safe = torch.where(on_axis, torch.ones_like(r), r)
    T_r_over_r = torch.where(on_axis, T_rr, T_r / r_safe)

    term1 = (c + Lm * delta) * T_t
    term2 = k * T_rr + k * T_r_over_r + k_r * T_r
    term3 = k * T_zz + k_z * T_z
    return term1 - (term2 + term3)

# Boundary and initial conditions
def _grad_at(model, boundary_points):
    """Evaluate ``model`` on ``boundary_points``; return ``(T, dT/d(r, z, t))``."""
    T_b = model(boundary_points)
    dT = torch.autograd.grad(
        T_b, boundary_points, grad_outputs=torch.ones_like(T_b), create_graph=True
    )[0]
    return T_b, dT


def compute_bc_ic_loss(model, points, r, z, t):
    """Return the summed mean-squared initial- and boundary-condition losses.

    The conditions enforced are:

    * initial: T(t=0) = 290 K;
    * zero radial flux on the axis (r = 0) and at r = 75e-6 m;
    * irradiated surface (z = 0): -k dT/dz = (1 - R) * q, with R = 0.62;
    * zero axial flux at the bottom (z = 40e-6 m).

    A condition whose boundary has no rows in ``points`` contributes 0.

    Args:
        model: Callable mapping ``(N, 3)`` ``(r, z, t)`` rows to ``(N, 1)`` T.
        points: ``(N, 3)`` collocation points. They must require grad, or be
            derived from a tensor that does, for the derivative conditions.
        r, z, t: Accepted for interface compatibility; not used.

    Returns:
        Scalar tensor: the sum of the five losses.
    """
    r_max = 75e-6
    z_max = 40e-6
    T0 = 290.0  # Initial temperature (K)
    R_coeff = 0.62  # Reflection coefficient
    zero = torch.tensor(0.0, device=points.device)

    def _on_plane(column, value):
        # Rows with points[:, column] == value. Exact match at 0 (linspace
        # starts exactly there), isclose elsewhere, in the dtype of `points`.
        if value == 0:
            return points[points[:, column] == 0]
        target = torch.tensor(value, dtype=points.dtype, device=points.device)
        return points[torch.isclose(points[:, column], target)]

    def _zero_flux_loss(boundary, axis):
        # Homogeneous Neumann condition: mean squared derivative along `axis`.
        if boundary.shape[0] == 0:
            return zero
        _, dT = _grad_at(model, boundary)
        return torch.mean(dT[:, axis] ** 2)

    # Initial condition: T(t=0) = T0
    ic_points = _on_plane(2, 0.0)
    if ic_points.shape[0] > 0:
        ic_loss = torch.mean((model(ic_points) - T0) ** 2)
    else:
        ic_loss = zero

    # dT/dr = 0 on the axis and on the outer radial boundary
    bc_r0_loss = _zero_flux_loss(_on_plane(0, 0.0), 0)
    bc_rmax_loss = _zero_flux_loss(_on_plane(0, r_max), 0)

    # Irradiated surface: -k * dT/dz(z=0) = (1-R) * q
    bc_z0 = _on_plane(1, 0.0)
    if bc_z0.shape[0] > 0:
        T_z0, dT = _grad_at(model, bc_z0)
        # Keep the (N, 1) column shape: k and q are (N, 1), and multiplying by
        # an (N,) vector would broadcast into an (N, N) matrix.
        T_z = dT[:, 1:2]
        k = thermal_conductivity(T_z0)
        q = laser_pulse(bc_z0[:, 0:1], bc_z0[:, 2:3])
        bc_z0_loss = torch.mean((-k * T_z - (1 - R_coeff) * q) ** 2)
    else:
        bc_z0_loss = zero

    # dT/dz = 0 at the bottom boundary
    bc_zmax_loss = _zero_flux_loss(_on_plane(1, z_max), 1)

    return ic_loss + bc_r0_loss + bc_rmax_loss + bc_z0_loss + bc_zmax_loss

# Training function
def train_pinn(model, points, epochs=10000, *, lr=0.001, bc_weight=100.0,
               log_every=1000, r=None, z=None, t=None):
    """Train ``model`` full-batch with Adam on PDE and weighted BC/IC losses.

    Each epoch minimises ``mean(residual**2) + bc_weight * bc_ic_loss`` over
    all of ``points``. The learning rate is halved every 2000 epochs
    (StepLR).

    NOTE(review): with the default collocation grid, ``points`` has 30M rows,
    and every epoch differentiates the network twice over all of them.
    Confirm this fits in device memory.

    Args:
        model: The network to train (modified in place).
        points: ``(N, 3)`` collocation points with ``requires_grad=True``.
        epochs: Number of optimisation steps.
        lr: Initial Adam learning rate.
        bc_weight: Multiplier applied to the BC/IC loss.
        log_every: Print progress every this many epochs; 0 disables it.
        r, z, t: Forwarded to ``compute_bc_ic_loss``. Previously these were
            read from module globals that exist only when the file runs as a
            script.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.5)
    model.train()

    for epoch in range(epochs):
        optimizer.zero_grad()
        pde_residual = compute_pde_residual(model, points)
        pde_loss = torch.mean(pde_residual**2)
        bc_ic_loss = compute_bc_ic_loss(model, points, r, z, t)
        loss = pde_loss + bc_weight * bc_ic_loss  # Weight boundary conditions
        loss.backward()
        optimizer.step()
        scheduler.step()
        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.6f}, PDE Loss: {pde_loss.item():.6f}, BC/IC Loss: {bc_ic_loss.item():.6f}")

# Save output in specified binary format
def save_temperature_field(model, r, z, t, filename="output_Trd.bin",
                           ts_filename="output_ts.txt"):
    """Write the predicted temperature field and two surface histories.

    ``filename`` receives raw native-endian float64 values, the same bytes
    ``struct.pack('d', ...)`` produces. They are ordered with time outermost,
    then z, then r innermost: T[t0, z0, r0], T[t0, z0, r1], ...

    ``ts_filename`` receives one text line per time sample in the form
    ``"<t> <T(r=0, z=0, t)> <T(r=0, z=z[-1], t)>"``.

    The model is evaluated under ``torch.no_grad()`` on the device of its
    parameters. Each time step uses one batched forward pass over its whole
    (z, r) slice rather than one pass per point. The model is left in eval
    mode.

    Args:
        model: Module mapping ``(N, 3)`` ``(r, z, t)`` rows to ``(N, 1)`` T.
        r, z, t: 1-D grid coordinates (NumPy arrays, lists or tensors).
        filename: Path of the binary field file.
        ts_filename: Path of the text file with the surface histories.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    r_dev = torch.as_tensor(r, dtype=torch.float32, device=device)
    z_dev = torch.as_tensor(z, dtype=torch.float32, device=device)
    t_vals = [float(v) for v in t]

    # (Nz, Nr) grids flattened row-major, so r varies fastest within each z.
    Z, R = torch.meshgrid(z_dev, r_dev, indexing="ij")
    rz = torch.stack([R.reshape(-1), Z.reshape(-1)], dim=1)

    with torch.no_grad():
        with open(filename, "wb") as f:
            for t_val in t_vals:
                t_col = torch.full((rz.shape[0], 1), t_val,
                                   dtype=torch.float32, device=device)
                T_slice = model(torch.cat([rz, t_col], dim=1))
                f.write(T_slice.double().cpu().numpy().tobytes())

        # Save time points and surface temperatures
        lines = []
        if t_vals:
            z_bottom = float(z[-1])
            t_col = torch.tensor(t_vals, dtype=torch.float32, device=device).unsqueeze(1)
            zeros = torch.zeros_like(t_col)
            centre = model(torch.cat([zeros, zeros, t_col], dim=1)).reshape(-1).tolist()
            bottom = model(
                torch.cat([zeros, torch.full_like(t_col, z_bottom), t_col], dim=1)
            ).reshape(-1).tolist()
            lines = [f"{tv:.6e} {Tc:.6f} {Tb:.6f}\n"
                     for tv, Tc, Tb in zip(t_vals, centre, bottom)]
        with open(ts_filename, "w") as f:
            f.writelines(lines)

# Main execution
if __name__ == "__main__":
    # Run on the GPU when one is available; the model and all collocation
    # tensors are moved there.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = PINN().to(device)

    # NOTE(review): with the default sizes this grid has 75 * 400 * 1000 =
    # 30,000,000 points, and train_pinn feeds all of them in a single batch.
    # Confirm this fits in device memory.
    points, r, z, t = generate_collocation_points()
    points = points.to(device)
    # r, z, t stay module-level names: train_pinn reads them as globals.
    r, z, t = (axis.to(device) for axis in (r, z, t))

    # Train the model
    train_pinn(model, points)

    # Save the results (the writer takes host-side NumPy axes)
    save_temperature_field(model, *(axis.cpu().numpy() for axis in (r, z, t)))