In [37]:
import torch
import os
import sys
import matplotlib.pyplot as plt
import argparse
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

sys.path.append("../../")
from src.model.video_diffusion_pytorch_conv3d import Unet3D_with_Conv3D
from src.model.fno import FNO3D
from src.train.nuclear_thermal_coupling import load_nt_dataset_emb, cond_emb, normalize, renormalize
from src.utils.utils import L2_norm, get_parameter_net, plot_compare_2d, relative_error

In [42]:
# Run configuration.
# Fall back to CPU so the notebook at least loads on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Dataset / checkpoint tag (used in the data loader and checkpoint paths).
# NOTE(review): this name shadows the builtin `iter()`; kept because later cells read it.
iter = "iter1"
# Surrogate architecture to load: one of "Unet", "ViT", "FNO".
model_type = "FNO"

# Neutron surrogate: load and evaluate

In [43]:
# Neutron-field settings. `dim` is the base width read only by the Unet branch.
dim = 8
train_which = "neutron"
emb = cond_emb(train_which, device=device)

In [None]:
# Load the neutron-field data, rebuild the surrogate selected by `model_type`,
# and restore its trained weights from the matching checkpoint.
cond, data = load_nt_dataset_emb(field=train_which, dataset=iter, device=device)
ckpt_path = f"../../results/nuclear_thermal_coupling/surrogate{model_type}neutron/{iter}_5000/model.pt"
if model_type == "Unet":
    model_neu = Unet3D_with_Conv3D(
        dim=dim,
        cond_dim=0,
        channels=len(emb),
        out_dim=1,
        cond_emb=emb,
        time_cond=False,
        dim_mults=(1, 2, 4),
        use_sparse_linear_attn=False,
        attn_dim_head=16,
    ).to(device)
elif model_type == "ViT":
    # NOTE(review): `ViT` is not imported in the imports cell, so this branch
    # raises NameError until the project's ViT class is imported there.
    model_neu = ViT(
        image_size=data.shape[-2:],
        image_patch_size=(8, 2),
        frames=data.shape[2],
        frame_patch_size=2,
        dim=128,
        depth=2,
        heads=8,
        mlp_dim=256,
        cond_emb=emb,
        channels=len(emb),
        out_channels=data.shape[1],
        dropout=0.1,
        emb_dropout=0.1,
    ).to(device)
elif model_type == "FNO":
    model_neu = FNO3D(
        in_channels=len(emb),
        out_channels=data.shape[1],
        nr_fno_layers=3,
        fno_layer_size=8,
        fno_modes=[6, 16, 8],
        cond_emb=emb,
    ).to(device)
else:
    raise ValueError(f"unknown model_type {model_type!r}; expected 'Unet', 'ViT' or 'FNO'")

# map_location lets a checkpoint saved on another device load onto `device`.
model_neu.load_state_dict(torch.load(ckpt_path, map_location=device)["model"])
# Inference only: disable dropout (the ViT configuration uses dropout=0.1)
# and any other train-mode behaviour.
model_neu.eval()

In [None]:
# Evaluate the neutron surrogate on the last 32 samples, in physical units.
# NOTE: `data` is replaced by its renormalized tail; re-run the loading cell
# before re-running this one, otherwise `data` is renormalized twice.
b = -32
cond_test = [c[b:] for c in cond]  # new list; leaves the loaded `cond` intact
data = data[b:]
model_neu.eval()
with torch.no_grad():
    pred = renormalize(model_neu(cond_test), "neutron")
    data = renormalize(data, "neutron")
    rel_err = relative_error(data, pred)  # relative error (not an RMSE)
    mse = F.mse_loss(pred, data)
rel_err, mse

In [None]:
plot_compare_2d(data[-1][-1][-1], pred[-1][-1][-1])  # last sample, last entry of axes 1 and 2

# Fuel (solid field) surrogate: load and evaluate

In [47]:
# Fuel ("solid") field settings. `dim` is the base width read only by the Unet branch.
dim = 8
train_which = "solid"
emb = cond_emb(train_which, device=device)

In [None]:
# Load the fuel ("solid") field data, rebuild the surrogate selected by
# `model_type`, and restore its trained weights from the matching checkpoint.
cond, data = load_nt_dataset_emb(field=train_which, dataset=iter, device=device)
ckpt_path = f"../../results/nuclear_thermal_coupling/surrogate{model_type}solid/{iter}_5000/model.pt"
if model_type == "Unet":
    model_fuel = Unet3D_with_Conv3D(
        dim=dim,
        cond_dim=0,
        channels=len(emb),
        out_dim=1,
        cond_emb=emb,
        time_cond=False,
        dim_mults=(1, 2, 4),
        use_sparse_linear_attn=False,
        attn_dim_head=16,
    ).to(device)
elif model_type == "ViT":
    # NOTE(review): `ViT` is not imported in the imports cell, so this branch
    # raises NameError until the project's ViT class is imported there.
    model_fuel = ViT(
        image_size=data.shape[-2:],
        image_patch_size=(8, 2),
        frames=data.shape[2],
        frame_patch_size=2,
        dim=128,
        depth=2,
        heads=8,
        mlp_dim=256,
        cond_emb=emb,
        channels=len(emb),
        out_channels=data.shape[1],
        dropout=0.1,
        emb_dropout=0.1,
    ).to(device)
elif model_type == "FNO":
    model_fuel = FNO3D(
        in_channels=len(emb),
        out_channels=data.shape[1],
        nr_fno_layers=3,
        fno_layer_size=8,
        fno_modes=[6, 16, 4],
        cond_emb=emb,
    ).to(device)
else:
    raise ValueError(f"unknown model_type {model_type!r}; expected 'Unet', 'ViT' or 'FNO'")

# map_location lets a checkpoint saved on another device load onto `device`.
model_fuel.load_state_dict(torch.load(ckpt_path, map_location=device)["model"])
# Inference only: disable dropout (the ViT configuration uses dropout=0.1)
# and any other train-mode behaviour.
model_fuel.eval()

In [None]:
# Evaluate the fuel ("solid") surrogate on the last 32 samples, in physical units.
# NOTE: `data` is replaced by its renormalized tail; re-run the loading cell
# before re-running this one, otherwise `data` is renormalized twice.
b = -32
cond_test = [c[b:] for c in cond]  # new list; leaves the loaded `cond` intact
data = data[b:]
model_fuel.eval()
with torch.no_grad():
    pred = renormalize(model_fuel(cond_test), "solid")
    data = renormalize(data, "solid")
    rel_err = relative_error(data, pred)  # relative error (not an RMSE)
    mse = F.mse_loss(pred, data)
rel_err, mse

In [None]:
plot_compare_2d(data[-1][-1][-1], pred[-1][-1][-1])  # last sample, last entry of axes 1 and 2

# Fluid surrogate: load and evaluate

In [51]:
# Fluid-field settings. `dim` is the base width read only by the Unet branch.
dim = 16
train_which = "fluid"
emb = cond_emb(train_which, device=device)

In [None]:
# Load the fluid-field data, rebuild the surrogate selected by `model_type`,
# and restore its trained weights from the matching checkpoint.
cond, data = load_nt_dataset_emb(field=train_which, dataset=iter, device=device)
ckpt_path = f"../../results/nuclear_thermal_coupling/surrogate{model_type}fluid/{iter}_5000/model.pt"
if model_type == "Unet":
    model_fluid = Unet3D_with_Conv3D(
        dim=dim,
        cond_dim=0,
        channels=len(emb),
        out_dim=4,
        cond_emb=emb,
        time_cond=False,
        dim_mults=(1, 2, 4),
        use_sparse_linear_attn=False,
        attn_dim_head=16,
    ).to(device)
elif model_type == "ViT":
    # NOTE(review): `ViT` is not imported in the imports cell, so this branch
    # raises NameError until the project's ViT class is imported there.
    model_fluid = ViT(
        image_size=data.shape[-2:],
        image_patch_size=(8, 2),
        frames=data.shape[2],
        frame_patch_size=2,
        dim=256,
        depth=2,
        heads=8,
        mlp_dim=256,
        cond_emb=emb,
        channels=len(emb),
        out_channels=data.shape[1],
        dropout=0.1,
        emb_dropout=0.1,
    ).to(device)
elif model_type == "FNO":
    model_fluid = FNO3D(
        in_channels=len(emb),
        out_channels=data.shape[1],
        nr_fno_layers=3,
        fno_layer_size=16,
        fno_modes=[6, 16, 6],
        cond_emb=emb,
    ).to(device)
else:
    raise ValueError(f"unknown model_type {model_type!r}; expected 'Unet', 'ViT' or 'FNO'")

# map_location lets a checkpoint saved on another device load onto `device`.
model_fluid.load_state_dict(torch.load(ckpt_path, map_location=device)["model"])
# Inference only: disable dropout (the ViT configuration uses dropout=0.1)
# and any other train-mode behaviour.
model_fluid.eval()

In [53]:
# Predict the fluid field for the last 32 samples; both prediction and
# reference are mapped back to physical units.
# NOTE: `data` is replaced by its renormalized tail; re-run the loading cell
# before re-running this one, otherwise `data` is renormalized twice.
b = -32
cond_test = [c[b:] for c in cond]  # new list; leaves the loaded `cond` intact
data = data[b:]
model_fluid.eval()
with torch.no_grad():
    pred = renormalize(model_fluid(cond_test), field="fluid")
    data = renormalize(data, field="fluid")

In [None]:
# Relative error of the fluid surrogate, per channel (printed) and over all
# channels jointly. The channel count comes from the data instead of a
# hard-coded 4; `loss_fluid` holds the sum of per-channel errors.
n_fluid_channels = data.shape[1]
channel_errors = [relative_error(data[:, c], pred[:, c]) for c in range(n_fluid_channels)]
for c, err in enumerate(channel_errors):
    print(f"channel {c}: {err}")
loss_fluid = sum(channel_errors)
relative_error(data, pred), loss_fluid / n_fluid_channels

In [None]:
channel = 0  # fluid channel to show
plot_compare_2d(data[-2][channel][-1], pred[-2][channel][-1])  # second-to-last sample

# Coupled solve: iterate the three surrogates to a fixed point on the validation set

In [60]:
# Validation set for the coupled solve, loaded as float tensors on `device`.
# `b` optionally keeps only the first `b` samples (None keeps all of them).
b = None
val_dir = "../../data/NTcouple/val/"


def load_val_tensor(name, limit=None):
    """Load ``val_dir + name + ".npy"``, keep the first ``limit`` samples, and
    return it as a float tensor on ``device``."""
    # Slice on the host so only the kept samples are copied to the device.
    arr = np.load(val_dir + name + ".npy")[:limit]
    return torch.tensor(arr).float().to(device)


fuel = load_val_tensor("fuel", b)
fluid = load_val_tensor("fluid", b)
neu = load_val_tensor("neu", b)
bc = load_val_tensor("bc", b)

b = bc.shape[0]  # number of validation samples actually loaded

In [None]:
tuple(field.shape for field in (fuel, fluid, neu))

In [62]:
# Initial guesses for the coupled iteration: a constant 0.5 in normalized
# space (presumably the middle of the normalized range — TODO confirm),
# mapped back to physical units.
fluid_p = renormalize(torch.full_like(fluid, 0.5), "fluid")
fuel_p = renormalize(torch.full_like(fuel, 0.5), "solid")
neu_p = renormalize(torch.full_like(neu, 0.5), "neutron")

In [63]:
def k(t):
    """Quadratic coefficient c0 + c1*t + c2*t**2 used to scale the boundary flux.

    Works elementwise on Python floats, NumPy arrays and torch tensors.
    NOTE(review): looks like a temperature-dependent conductivity; confirm
    the source and units of the constants.
    """
    c0 = 17.5 * (1 - 0.223) / (1 + 0.161)
    c1 = 1.54e-2 * (1 + 0.0061) / (1 + 0.161)
    c2 = 9.38e-6
    return c0 + c1 * t + c2 * t * t


def update_f_neu(model, neu, fuel, fluid, arg):
    """Predict the neutron field, in physical units, from the current fuel and fluid fields.

    Args:
        model: neutron surrogate; called with ``[normalized arg, temperature input]``.
        neu: unused; kept so all three update functions share one signature.
        fuel: current fuel ("solid") field, physical units.
        fluid: current fluid field; only channel 0 (labelled "T" in the plots) is used.
        arg: boundary-condition tensor, normalized with the "neutron" statistics.
    """
    # Fuel and first fluid channel, normalized, joined along the last axis.
    temperature_in = torch.concat(
        (normalize(fuel, "solid"), normalize(fluid[:, :1], "fluid")),
        dim=-1,
    )
    bc_in = normalize(arg, "neutron")
    return renormalize(model([bc_in, temperature_in]), "neutron")


def update_f_fuel(model, neu, fuel, fluid, arg):
    """Predict the fuel ("solid") field, in physical units, from the neutron and fluid fields.

    Args:
        model: fuel surrogate; called with ``[neutron input, fluid input]``.
        neu: current neutron field; only the first 8 entries of the last axis are used.
        fuel, arg: unused; kept so all three update functions share one signature.
        fluid: current fluid field (5-D); channel 0 at the first last-axis entry is used.
    """
    # NOTE(review): the neutron slice is normalized with the "solid" statistics,
    # not "neutron"; confirm this matches how the fuel surrogate was trained.
    neu_in = normalize(neu[..., :8], "solid")
    # Normalize the whole fluid tensor first, then keep channel 0 and the
    # first entry along the last axis.
    fluid_in = normalize(fluid, "fluid")[:, :1, :, :, :1]
    return renormalize(model([neu_in, fluid_in]), "solid")


def update_f_fluid(model, neu, fuel, fluid, arg):
    """Predict the fluid field, in physical units, from the flux at the fuel's last column.

    The flux is the difference between the last two entries of ``fuel``'s last
    axis, scaled by ``k`` evaluated at the last entry.
    ``neu``, ``fluid`` and ``arg`` are unused; they are kept so all three update
    functions share one signature.
    """
    last_col = fuel[..., -1:]
    prev_col = fuel[..., -2:-1]
    # NOTE(review): no grid-spacing factor here; presumably absorbed by the
    # "flux" normalization — confirm.
    flux = (prev_col - last_col) * k(last_col)
    return renormalize(model([normalize(flux, "flux")]), "fluid")

In [None]:
# Coupled solve: update neutron, then fuel, then fluid. Each update sees the
# newest values of the fields updated before it, and is blended with the
# previous iterate (under-relaxation). Iteration stops once the summed
# relative change of the three fields drops below `tol`.
# NOTE: the iteration starts from whatever neu_p / fuel_p / fluid_p currently
# hold; re-run the initial-guess cell first to restart from scratch.
coeff = 0.5  # weight of the new prediction; (1 - coeff) keeps the old iterate
max_iter = 100
tol = 1e-3  # convergence threshold on loss1 + loss2 + loss3


def relative_change(new, old):
    """Sum over samples of L2_norm(new - old) / L2_norm(new), divided by the batch size."""
    return np.sum(L2_norm(new - old) / L2_norm(new)) / new.shape[0]


# Inference only: make sure dropout / other train-mode layers are disabled.
for surrogate in (model_neu, model_fuel, model_fluid):
    surrogate.eval()

with torch.no_grad():
    for i in range(max_iter):
        neu_p_old = neu_p
        fuel_p_old = fuel_p
        fluid_p_old = fluid_p
        # Defensive copies, in case normalization or the models modify inputs in place.
        neu_p = update_f_neu(
            model_neu, neu_p.clone(), fuel_p.clone(), fluid_p.clone(), bc.clone()
        ) * coeff + neu_p_old * (1 - coeff)
        fuel_p = update_f_fuel(
            model_fuel, neu_p.clone(), fuel_p.clone(), fluid_p.clone(), bc.clone()
        ) * coeff + fuel_p_old * (1 - coeff)
        fluid_p = update_f_fluid(
            model_fluid, neu_p.clone(), fuel_p.clone(), fluid_p.clone(), bc.clone()
        ) * coeff + fluid_p_old * (1 - coeff)

        loss1 = relative_change(neu_p, neu_p_old)
        loss2 = relative_change(fuel_p, fuel_p_old)
        loss3 = relative_change(fluid_p, fluid_p_old)
        loss = loss1 + loss2 + loss3
        print("loss: ", loss, loss1, loss2, loss3)

        if loss < tol:
            print("converged in iteration: ", i)
            break
    else:
        # Runs only when the loop finished without reaching `tol`.
        print("up to max iteration")

In [None]:
# Relative error of the coupled solution against the validation data.
# Fluid errors are printed per channel. The displayed tuple holds the
# neutron, fuel and joint fluid errors plus the mean per-channel fluid error.
# The channel count comes from the data instead of a hard-coded 4.
n_fluid_channels = fluid.shape[1]
channel_errors = [relative_error(fluid[:, c], fluid_p[:, c]) for c in range(n_fluid_channels)]
for c, err in enumerate(channel_errors):
    print(f"channel {c}: {err}")
loss_fluid = sum(channel_errors) / n_fluid_channels
relative_error(neu, neu_p), relative_error(fuel, fuel_p), relative_error(fluid, fluid_p), loss_fluid

In [None]:
# Coupled-solve neutron result vs. reference for the last sample; saved to disk.
data = neu
pred = neu_p
channel = 0
plot_compare_2d(
    data[-1][channel][-1],
    pred[-1][channel][-1],
    savep="../../results/nuclear_thermal_coupling/neutron_surrogate.png",
)

In [None]:
# Coupled-solve fuel ("solid") result vs. reference for the last sample; saved to disk.
data = fuel
pred = fuel_p
channel = 0
plot_compare_2d(
    data[-1][channel][-1],
    pred[-1][channel][-1],
    savep="../../results/nuclear_thermal_coupling/solid_surrogate.png",
)

In [None]:
# Coupled-solve fluid result vs. reference for the last sample, one saved
# figure per channel.
pred, data = fluid_p, fluid
phy_lis = ["T", "P", "vx", "vy"]  # labels for fluid channels 0..3
for channel, name in enumerate(phy_lis):
    plot_compare_2d(
        data[-1][channel][-1],
        pred[-1][channel][-1],
        savep=f"../../results/nuclear_thermal_coupling/fluid{name}_surrogate.png",
    )