In [1]:
import os, sys, logging, argparse, pickle, glob, random, pylab, time
from tqdm import tqdm
from phi.torch.flow import *

In [6]:
random.seed(42)
np.random.seed(42)

# math.seed(42) # phiflow seed
math.set_global_precision(32) # single precision

USE_CPU = 0
TORCH.set_default_device("GPU")
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if USE_CPU > 0:
    device = 'cpu'
device = torch.device(device)
print("Using device: "+str(device))

Using device: cuda:0


In [7]:
RE_FAC_SOL = 10/(128*128) # factor to compensate for the original scaling from the original solver-in-the-loop paper

class KarmanFlow():
    def __init__(self, domain):
        self.domain = domain

        self.vel_BcMask = self.domain.staggered_grid( HardGeometryMask( Box(y=(None, 5), x=None) ) )

        self.inflow = self.domain.scalar_grid(Box(y=(5,10), x=(25,75)) ) # scale with domain if necessary!
        self.obstacles = [Obstacle(Sphere(center=tensor([50, 50], channel(vector="y,x")), radius=10))]

    def step(self, marker_in, velocity_in, Re, res, buoyancy_factor=0, dt=1.0):
        velocity = velocity_in
        marker   = marker_in
        Re_phiflow = Re / RE_FAC_SOL # rescale for phiflow

        # viscosity
        velocity = phi.flow.diffuse.explicit(u=velocity, diffusivity=1.0/Re_phiflow*dt*res*res, dt=dt)

        # inflow boundary conditions
        velocity = velocity*(1.0 - self.vel_BcMask) + self.vel_BcMask * (1,0)

        # advection
        marker = advect.semi_lagrangian(marker+ 1. * self.inflow, velocity, dt=dt)
        velocity = advected_velocity = advect.semi_lagrangian(velocity, velocity, dt=dt)

        # mass conservation (pressure solve)
        pressure = None
        velocity, pressure = fluid.make_incompressible(velocity, self.obstacles)
        self.solve_info = { 'pressure': pressure, 'advected_velocity': advected_velocity }

        return [marker, velocity]



In [8]:
layers = [32,32,32] # small
#layers = [32,48,48,48,32] # uncomment for a somewhat larger and deeper network
#network = conv_net(in_channels=3,out_channels=2,layers=layers) # a simpler variant
network = res_net(in_channels=3,out_channels=2,layers=layers)
print(network)

# reinit
import torch.nn as nn
for m in network.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight, gain=0.1)

print("Total number of trainable parameters: "+ str( sum(p.numel() for p in network.parameters()) ))

ResNet(
  (Res_in): ResNetBlock(
    (sample_input): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
    (bn_sample): Identity()
    (bn1): Identity()
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): Identity()
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (Res1): ResNetBlock(
    (sample_input): Identity()
    (bn_sample): Identity()
    (bn1): Identity()
    (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): Identity()
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (Res2): ResNetBlock(
    (sample_input): Identity()
    (bn_sample): Identity()
    (bn1): Identity()
    (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): Identity()
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (Res_out): Conv2d(32, 2, kernel_size=(1, 1), stride=(1, 1))
)
Total number of trainable

In [9]:
def to_phiflow(batch):
    vx = batch[:,1,:-1,:-1]
    vy = batch[:,2,:,:] # fine

    #print("v_dims "+str([vx.shape,vy.shape])) # example for debugging
    # v_dims should be vx [torch.Size([B, 64, 31]),  vy torch.Size([B, 65, 32])]

    vel = domain.staggered_grid( math.stack( [
        math.tensor(vy, math.batch('batch'), math.spatial('y, x')),
        math.tensor(vx, math.batch('batch'), math.spatial('y, x')),
    ], math.dual(vector="y,x")
    ) )
    return vel

def to_pytorch(marker_vel, Re):
    # align the sides the staggered velocity grid making its size the same as the centered grid
    grid = math.stack(
        [
            math.pad( marker_vel[1].vector['x'].values, {'x':(0,1)} , math.extrapolation.ZERO), # x component
            marker_vel[1].vector['y'].y[:-1].values,                                            # y component
            math.ones(marker_vel[0].shape)*Re                                                   # constant Re
        ],
        math.channel('channels')
    ).native(order='batch,channels,y,x')
    return grid

def to_solver(inputs):
    marker_in = inputs[:,0,:-1,:]
    marker_in = domain.scalar_grid( math.tensor(marker_in, math.batch('batch'), math.spatial('y, x')) )
    v_in = to_phiflow(inputs)
    Re = math.tensor(inputs[0,3, 0,0].detach()) # Scalar , get first index 0,0

    Re_norm = (Re - math.tensor(DATA_RE_MEAN)) / math.tensor(DATA_RE_STD)
    Re_norm = float(Re_norm.native().detach()) # we just need a single number

    return marker_in, v_in, Re, Re_norm


In [10]:
LEARNING_RATE = 1e-3
optimizer = adam(network, LEARNING_RATE)

# one of the most crucial parameters: how many simulation steps to look into the future in each training step
MSTEPS = 4

BATCH_SIZE = 3

In [11]:
import pbdl
import pbdl.torch.loader

dataloader = pbdl.torch.loader.Dataloader("solver-in-the-loop-wake-flow", MSTEPS, shuffle=True, sel_sims=[0,1,2,3,4,5],
                                          batch_size=BATCH_SIZE, normalize_const="std", normalize_data="std", intermediate_time_steps=True)

# workaround for using un-normalized and normalized values in one script:
#    save the normalization constants of the Reynolds number conditioning , then turn off (norm=None);
#    the Re values will be normalized manually later on
DATA_RE_MEAN = dataloader.dataset.norm_strat_const.const_mean[0][0]
DATA_RE_STD  = dataloader.dataset.norm_strat_const.const_std[0][0]
print([DATA_RE_MEAN,DATA_RE_STD])
dataloader.dataset.norm_strat_const = None
dataloader.dataset.norm_strat_data = None

[Kdownload completed	 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[38;5;240m[96m 100%[0m6m[96m[96m[96m[96m[96m
[96m[1mSuccess:[22m Loaded solver-in-the-loop-wake-flow with 6 simulations (6 selected) and 496 samples each.[0m
[1mInfo:[22m No precomputed normalization data found (or not complete). Calculating data...[0m
[np.float64(1025.390625), np.float64(1057.4417884622908)]


In [12]:
# this is the actual resolution in terms of cells for phiflow (not too crucial)
SOURCE_RES = [64,32]

# this is the physical size in terms of abstract units for the bounding box of the domain (it's important for conversions between computational and physical units)
LENGTH = 100.

# for readability
from phi.physics._boundaries import Domain, OPEN, STICKY as CLOSED

BNDS = {
    'y':(phi.physics._boundaries.OPEN,  phi.physics._boundaries.OPEN) ,
    'x':(phi.physics._boundaries.STICKY,phi.physics._boundaries.STICKY)
}

domain = phi.physics._boundaries.Domain(y=SOURCE_RES[0], x=SOURCE_RES[1], boundaries=BNDS, bounds=Box(y=2*LENGTH, x=LENGTH))
simulator = KarmanFlow(domain=domain)

Please create grids directly, replacing the domain with a dict, e.g.
    domain = dict(x=64, y=128, bounds=Box(x=1, y=1))
    grid = CenteredGrid(0, **domain)
  from phi.physics._boundaries import Domain, OPEN, STICKY as CLOSED
  domain = phi.physics._boundaries.Domain(y=SOURCE_RES[0], x=SOURCE_RES[1], boundaries=BNDS, bounds=Box(y=2*LENGTH, x=LENGTH))
  domain = phi.physics._boundaries.Domain(y=SOURCE_RES[0], x=SOURCE_RES[1], boundaries=BNDS, bounds=Box(y=2*LENGTH, x=LENGTH))
  self.vel_BcMask = self.domain.staggered_grid( HardGeometryMask( Box(y=(None, 5), x=None) ) )


In [13]:
@jit_compile
def simulation_step(marker,velocity,Re,resolution):
    m,v = simulator.step(
        marker_in=marker,
        velocity_in=velocity,
        Re=Re, res=resolution
    )
    return m,v

In [14]:
def training_step(inputs_targets):
    [inputs, targets] = inputs_targets
    marker_in, v_in, Re, Re_norm = to_solver(inputs)
    prediction = [ [marker_in,v_in] ]
    loss = 0

    for i in range(MSTEPS):
        m2,v2 = simulation_step(
            marker=prediction[-1][0],
            velocity=prediction[-1][1],
            Re=Re, resolution=SOURCE_RES[1]
        )

        net_in = to_pytorch([m2,v2],Re_norm)
        net_out = network(net_in)

        cy = net_out[:,1,:,:] # pad y
        cy = torch.nn.functional.pad(input=cy, pad=(0,0, 0,1), mode='constant', value=0)
        cx = net_out[:,0,:,:-1]

        v_corr = domain.staggered_grid( math.stack( [
            math.tensor(cy, math.batch('batch'), math.spatial('y, x')),
            math.tensor(cx, math.batch('batch'), math.spatial('y, x')),
        ], math.dual(vector="y,x")
        ) )

        prediction += [ [domain.scalar_grid(m2) , v2 + v_corr] ]
        vdiff = prediction[-1][1] - to_phiflow(targets[:,i,...])
        loss += field.l2_loss(vdiff)

    return loss, prediction


In [15]:
EPOCHS = 5

pbar = tqdm(initial=0, total=EPOCHS*len(dataloader), ncols=96)

for epoch in range(EPOCHS):
    for b, (input_cpu, targets_cpu) in enumerate(dataloader):
        input = torch.tensor(input_cpu, dtype=torch.float32).to(device);
        targets = torch.tensor(targets_cpu, dtype=torch.float32).to(device)

        loss, prediction = update_weights(network, optimizer, training_step, [input, targets])

        pbar.set_description("loss "+str(np.sum(loss.numpy("batch"))), refresh=False); pbar.update(1)

    torch.save(network.state_dict(), "net-"+str(epoch)+".pickle")

pbar.close()

  input = torch.tensor(input_cpu, dtype=torch.float32).to(device);
  targets = torch.tensor(targets_cpu, dtype=torch.float32).to(device)
  return torch.sparse_csr_tensor(row_pointers, column_indices, values, shape, device=values.device)
  input = torch.tensor(input_cpu, dtype=torch.float32).to(device);
  targets = torch.tensor(targets_cpu, dtype=torch.float32).to(device)
loss 8.132623:   2%|▊                                      | 111/4960 [02:06<1:29:06,  1.10s/it]

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\torch\autograd\function.py(575): apply
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\backend\torch\_torch_backend.py(248): select_jit
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\math\_functional.py(963): __call__
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\math\_optimize.py(666): solve_linear
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phi\physics\fluid.py(156): make_incompressible
C:\Users\xayah\AppData\Local\Temp\ipykernel_22524\2052032387.py(29): step
C:\Users\xayah\AppData\Local\Temp\ipykernel_22524\2022291116.py(3): simulation_step
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\math\_functional.py(256): jit_f_native
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\backend\torch\_torch_backend.py(1124): forward
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\torch\nn\modules\module.py(1729): _slow_forward
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\torch\nn\modules\module.py(1750): _call_impl
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\torch\nn\modules\module.py(1739): _wrapped_call_impl
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\torch\jit\_trace.py(1276): trace_module
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\torch\jit\_trace.py(696): _trace_impl
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\torch\jit\_trace.py(1000): trace
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\backend\torch\_torch_backend.py(1130): __call__
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\math\_functional.py(310): __call__
C:\Users\xayah\AppData\Local\Temp\ipykernel_22524\4240932139.py(8): training_step
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\backend\torch\nets.py(56): update_weights
C:\Users\xayah\AppData\Local\Temp\ipykernel_22524\3073841444.py(10): <module>
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\interactiveshell.py(3579): run_code
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\interactiveshell.py(3519): run_ast_nodes
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\interactiveshell.py(3336): run_cell_async
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\async_helpers.py(128): _pseudo_sync_runner
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\interactiveshell.py(3132): _run_cell
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\interactiveshell.py(3077): run_cell
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\zmqshell.py(549): run_cell
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\ipkernel.py(449): do_execute
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\kernelbase.py(778): execute_request
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\ipkernel.py(362): execute_request
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\kernelbase.py(437): dispatch_shell
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\kernelbase.py(534): process_one
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\kernelbase.py(545): dispatch_queue
C:\Users\xayah\AppData\Local\Programs\Python\Python313\Lib\asyncio\events.py(89): _run
C:\Users\xayah\AppData\Local\Programs\Python\Python313\Lib\asyncio\base_events.py(2040): _run_once
C:\Users\xayah\AppData\Local\Programs\Python\Python313\Lib\asyncio\base_events.py(683): run_forever
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\tornado\platform\asyncio.py(205): start
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\kernelapp.py(739): start
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\traitlets\config\application.py(1075): launch_instance
C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel_launcher.py(18): <module>
<frozen runpy>(88): _run_code
<frozen runpy>(198): _run_module_as_main
RuntimeError: KeyboardInterrupt: <EMPTY MESSAGE>

At:
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\backend\torch\_torch_backend.py(1179): forward
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\torch\autograd\function.py(575): apply
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\torch\nn\modules\module.py(1750): _call_impl
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\torch\nn\modules\module.py(1739): _wrapped_call_impl
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\backend\_backend.py(427): call
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\backend\torch\_torch_backend.py(1135): __call__
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\math\_functional.py(314): __call__
  C:\Users\xayah\AppData\Local\Temp\ipykernel_22524\4240932139.py(8): training_step
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\phiml\backend\torch\nets.py(56): update_weights
  C:\Users\xayah\AppData\Local\Temp\ipykernel_22524\3073841444.py(10): <module>
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\interactiveshell.py(3579): run_code
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\interactiveshell.py(3519): run_ast_nodes
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\interactiveshell.py(3336): run_cell_async
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\async_helpers.py(128): _pseudo_sync_runner
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\interactiveshell.py(3132): _run_cell
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\IPython\core\interactiveshell.py(3077): run_cell
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\zmqshell.py(549): run_cell
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\ipkernel.py(449): do_execute
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\kernelbase.py(778): execute_request
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\ipkernel.py(362): execute_request
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\kernelbase.py(437): dispatch_shell
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\kernelbase.py(534): process_one
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\kernelbase.py(545): dispatch_queue
  C:\Users\xayah\AppData\Local\Programs\Python\Python313\Lib\asyncio\events.py(89): _run
  C:\Users\xayah\AppData\Local\Programs\Python\Python313\Lib\asyncio\base_events.py(2040): _run_once
  C:\Users\xayah\AppData\Local\Programs\Python\Python313\Lib\asyncio\base_events.py(683): run_forever
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\tornado\platform\asyncio.py(205): start
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel\kernelapp.py(739): start
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\traitlets\config\application.py(1075): launch_instance
  C:\Users\xayah\Desktop\DeepDiffFluid\venv\Lib\site-packages\ipykernel_launcher.py(18): <module>
  <frozen runpy>(88): _run_code
  <frozen runpy>(198): _run_module_as_main



In [None]:
del dataloader # close hdf5 file handle

# get new & unseen test data
dataloader_test = pbdl.torch.loader.Dataloader("solver-in-the-loop-wake-flow", 200, batch_size=1, shuffle=True, sel_sims=[6,7,8,9],
                                               normalize=False, intermediate_time_steps=True)

In [None]:
# optionally load
if False:
    fn = "net-"+str(EPOCHS-1)+".pickle" # load last
    network.load_state_dict(torch.load(fn, map_location=device, weights_only=True))
    print("Loaded "+fn)

In [None]:
def run_sim(inputs, targets, steps, network=None):
    marker_in, v_in, Re, Re_norm = to_solver(inputs)

    simtype = "With corr."
    if (network==None): simtype = "Sim. only"
    print("Running test with Re="+str(Re)+", "+simtype)

    prediction = [ [marker_in,v_in] ]
    correction = [ [marker_in,v_in] ]
    refs = [ v_in ]
    errors = []

    for i in tqdm(range(steps), desc=simtype, ncols = 64):
        marker_sim,v_sim = simulation_step(
            marker=prediction[-1][0],
            velocity=prediction[-1][1],
            Re=Re, resolution=SOURCE_RES[1]  # take Re from constant field
        )

        if network: # run hybrid solver with trained Neural op
            net_in = to_pytorch([marker_sim,v_sim],Re_norm)
            net_out = network(net_in)

            cy = net_out[:,1,:,:] # pad y
            cy = torch.nn.functional.pad(input=cy, pad=(0,0, 0,1), mode='constant', value=0)
            cx = net_out[:,0,:,:-1]

            v_corr = domain.staggered_grid( math.stack( [
                math.tensor(cy, math.batch('batch'), math.spatial('y, x')),
                math.tensor(cx, math.batch('batch'), math.spatial('y, x')),
            ], math.dual(vector="y,x")
            ) )

            prediction += [ [domain.scalar_grid(marker_sim) , v_sim + v_corr] ]
            correction += [ [domain.scalar_grid(marker_sim) , v_corr] ]

        else: # only low-fidelity solver
            prediction += [ [domain.scalar_grid(marker_sim) , v_sim ] ]

        refs += [to_phiflow(targets[:,i,...])]
        vdiff = prediction[i][1] - refs[-1]

        error_phi = field.l1_loss(vdiff)
        errors += [float( error_phi.native("batch")[0] / field.l1_loss(refs[-1]).native("batch")[0] )]

    return errors, prediction, refs, correction


In [None]:
# get a sample
(input_cpu, targets_cpu) = next(iter(dataloader_test))
input = torch.tensor(input_cpu, dtype=torch.float32).to(device)
targets = torch.tensor(targets_cpu, dtype=torch.float32).to(device)
print("Re ",math.tensor(input[0,3, 0,0].detach()))

In [None]:
ROLLOUT_STEPS = 100

err_lowfid_only, prediction_lowfid_only, refs, _   = run_sim(input, targets, ROLLOUT_STEPS)
err_corrected,   prediction_corrected  , _ , corrs = run_sim(input, targets, ROLLOUT_STEPS, network)
print("\n Rel. L2 errors: low-fidelity:", float(np.mean(err_lowfid_only))," corrected:", float(np.mean(err_corrected)) )

In [None]:
fig = pylab.figure().gca()
pltx = np.linspace(0,ROLLOUT_STEPS-1,ROLLOUT_STEPS)
fig.plot(pltx, err_lowfid_only, lw=2, color='mediumblue', label='Source')
fig.plot(pltx, err_corrected,   lw=2, color='green', label='Hybrid')
pylab.xlabel('Time step'); pylab.ylabel('Relative L2 error'); fig.legend()


In [None]:
# which step from which batch to show , by default shows last step from first case
STEP = ROLLOUT_STEPS
BATCH = 0
NUM_SHOW = 4
PRINT_STATS = False # optional, print statistics

fig, axes = pylab.subplots(1, 4, figsize=(16, 5))
i = 0

v = refs[STEP].staggered_tensor().numpy('batch,y,x,vector')[BATCH,:,:,0]
if PRINT_STATS: print(["reference ", BATCH, i, np.mean(v),  np.min(v) ,  np.max(v)])
axes[i].set_title(f"Ref")
im=axes[i].imshow( v , origin='lower', cmap='magma') ;
pylab.colorbar(im) ; i=i+1; vy_ref=v

v = prediction_lowfid_only[STEP][1].staggered_tensor().numpy('batch,y,x,vector')[BATCH,:,:,0]
if PRINT_STATS: print(["low-fid. ", BATCH, i, np.mean(v),  np.min(v) ,  np.max(v)])
axes[i].set_title(f"Low-fid.")
im=axes[i].imshow( v , origin='lower', cmap='magma') ;
pylab.colorbar(im) ; i=i+1; vy_lowfid=v

v = prediction_corrected[STEP][1].staggered_tensor().numpy('batch,y,x,vector')[BATCH,:,:,0]
if PRINT_STATS: print(["corrected", BATCH, i, np.mean(v),  np.min(v) ,  np.max(v)])
axes[i].set_title(f"Corr.")
im=axes[i].imshow( v , origin='lower', cmap='magma') ;
pylab.colorbar(im) ; i=i+1; vy_corr=v

# show error side by side
err_lf   = vy_ref - vy_lowfid
err_corr = vy_ref - vy_corr
v = np.concatenate([err_lf,err_corr], axis=1)
axes[i].set_title(f" Errors: Low-fid & Learned")
im=axes[i].imshow( v , origin='lower', cmap='cividis') ;
pylab.colorbar(im) ; i=i+1

pylab.tight_layout()