# In [None]:
from smartsim import Experiment
from smartredis import Client
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import torch.optim as optim
from sklearn.metrics import mean_squared_error
import numpy as np
from matplotlib import pyplot as plt

import io, os, jinja2

# Shared Jinja2 environment used to render the naming templates that the
# function object publishes in its metadata dataset.
env = jinja2.Environment()


def get_field_name(fn_name, field_name, processor=0, timestep=None):
    """
    Get the name of the field from the database. This function uses
    a metadata dataset posted by the function object itself to determine
    how things are named through Jinja2 templates.

    The metadata dataset is looked up as ``fn_name + "_metadata"``. Its
    first "dataset" meta string is rendered with ``time_index`` and
    ``mpi_rank``; its first "field" meta string is rendered with ``name``
    and ``patch="internal"``.

    Args:
        fn_name (str): The name of the submitting entity, the database name in this case
        field_name (str): The name of the OpenFOAM field
        processor (int): The MPI rank
        timestep (int): The target timestep index

    Returns:
        str: The tensor key in the form ``{<dataset name>}.<field name>``
        (the dataset name is wrapped in literal curly braces).

    Raises:
        RuntimeError: If the metadata dataset does not show up in the
            database within the polling window.
    """
    metadata_name = fn_name + "_metadata"

    # poll_dataset reports whether the dataset appeared (10 ms between
    # tries, 1000 tries). Fail here with a clear message rather than
    # letting the subsequent get_dataset call fail on a missing key.
    if not client.poll_dataset(metadata_name, 10, 1000):
        raise RuntimeError(
            f"Metadata dataset '{metadata_name}' was not found in the database"
        )
    meta = client.get_dataset(metadata_name)

    # Render the dataset-name template for the requested rank and time index.
    ds_naming = env.from_string(str(meta.get_meta_strings("dataset")[0]))
    ds_name = ds_naming.render(time_index=timestep, mpi_rank=processor)

    # Render the field-name template; only the "internal" patch is addressed.
    f_naming = env.from_string(str(meta.get_meta_strings("field")[0]))
    f_name = f_naming.render(name=field_name, patch="internal")

    # Doubled braces in the f-string emit literal "{" and "}" around ds_name.
    return f"{{{ds_name}}}.{f_name}"


# Set up the execution of the foamSmartSimSvdDBAPI application
# as a SmartSim Experiment.
exp = Experiment("foam-smartsim-svd", launcher="local")

db = exp.create_database(port=8000,       # database port
                         interface="lo")  # network interface to use
exp.start(db)

# Everything after the database launch runs inside try/finally so that the
# database is stopped even when client or model setup fails.
try:
    # Connect the python client to the smartredis database
    client = Client(address=db.get_address()[0], cluster=False)

    # Preparation step: the Allrun script of the cavity case, resolved
    # relative to the current working directory.
    run_settings_prepare = exp.create_run_settings(exe=f"{os.getcwd()}/cavity/Allrun")
    openfoam_prepare_model = exp.create_model(name="prepare", run_settings=run_settings_prepare)

    # MPI parallel run settings for foamSmartSimSvdDBAPI
    num_mpi_ranks = 4
    run_settings_parallel = exp.create_run_settings(
        exe="foamSmartSimSvdDBAPI",
        exe_args="-case cavity -fieldName p -parallel",
        run_command="mpirun",
        run_args={"np": f"{num_mpi_ranks}"},
    )

    openfoam_svd_model = exp.create_model(name="foamSmartSimSvdDBAPI", run_settings=run_settings_parallel)

    torch.set_default_dtype(torch.float64)

    # Both launches block, so each application has returned before the
    # next step starts.
    exp.start(openfoam_prepare_model, block=True)
    exp.start(openfoam_svd_model, block=True)

    # Number of stored time indices, as published in the metadata dataset.
    nTimes = int(client.get_dataset("foamSmartSimSvdDBAPI_metadata").get_meta_strings("NTimes")[0])

    # Report the size of the "p" tensor for every MPI rank and time index.
    # TODO: the per-rank SVD itself is not performed here yet - it should be
    # distributed.
    for proc in range(num_mpi_ranks):
        for i in range(nTimes):
            field_name = get_field_name("foamSmartSimSvdDBAPI", "p", processor=proc, timestep=i)
            print(f"{field_name}: {client.get_tensor(field_name).size}")
except Exception as e:
    # Top-level boundary of the script: report the failure and fall through
    # to the database shutdown below.
    print("Caught an exception: ", str(e))
finally:
    exp.stop(db)