In [18]:
import gymnasium as gym
import torch
import numpy as np
import dedalus.public as d3
import logging
import h5py
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env

In [54]:
class DedalusRBC_Env(gym.Env):
    metadata = {'render_modes' : 'human'}
    obs_metadata = {"Ni" : 30, "Nk" : 8 }
    sim_metadata = {"Lx" : np.pi, "Lz" : 1.0, "Ra" : 1e4, "Ni" : 100, "Nk" : 64, "DiscardTime" : 80}
    act_metadata = {"actionDuration" : 1.5, "actionsPerEp" : 200}
    
    def __init__(self, render_mode=None):
        self.observation_space = gym.spaces.Box(-0.5, 1.5, shape=(self.obs_metadata['Nk']*self.obs_metadata['Ni'],))
        self.action_space = gym.spaces.Box(-1, 1, shape=(10,))
        self.render_mode = render_mode

    def reset(self, seed=None, options={}):
        super().reset(seed=seed)
        self._D3_init()
        obs = self._extractObs()
        info = {}

        return obs, info

    def step(self, action=None):
        if action is not None:
            self._setBC(action)

        for _ in range(int(self.act_metadata['actionDuration'] / self.timestep)):
            self.solver.step(self.timestep)

        obs = self._extractObs()
        info = {}
        reward = self._computeReward()
        term = not self.solver.proceed
        trun = term

        return obs, reward, term, trun, info

    def _computeReward(self):
        return -(self.fp.max('Nu') - self.Nu0)
        

    def _extractObs(self):
        x=np.linspace(0, self.sim_metadata['Lx'], self.obs_metadata['Ni']+1)
        z=np.linspace(0.1, self.sim_metadata['Lz']-0.1, self.obs_metadata['Nk'])
        
        #remove the last x due to periodicity
        x=x[:-1]
        
        X, Z = np.meshgrid(x, z)
        obs = np.zeros_like(X, dtype=np.float32)

        for i in range(self.obs_metadata['Ni']):
            for k in range(self.obs_metadata['Nk']):
                obs[k][i] = np.squeeze(self.problem.variables[1](x=X[k][i], z=Z[k][i]).evaluate()['g'])

        return obs.flatten()

    def _setBC(self, action):
        #lower BC
        Tp = action - np.mean(action)

        for j in range(len(Tp)):
            Tp[j] /= max(1., Tp[j])/0.75

        Tp = np.repeat(Tp, 3)
        
        #copy last action due to periodicity
        Tp = np.append(Tp, Tp[0])
        xp=np.linspace(0, self.sim_metadata['Lx'], len(Tp))

        T = np.interp(self.x, xp, Tp)
       
        self.g['g'] = T
        self.g['g'] += 1.0
    
    def _D3_init(self):
        self.problem, self.solver, self.CFL, self.fp, self.g, self.x = self._D3_RBC_setup(np.random.randint(100000))

        while True:
            self.timestep = self.CFL.compute_timestep()
            self.solver.step(self.timestep)
            if self.solver.sim_time >= self.sim_metadata['DiscardTime']:
                break

        #get initial Nu for normalisation
        self.Nu0 = self.fp.max('Nu')
        
    
    def _D3_RBC_setup(self, seed):
        logger = logging.getLogger(__name__)
        
        # Parameters
        Lx, Lz = self.sim_metadata['Lx'], self.sim_metadata['Lz']
        Nx, Nz = self.sim_metadata['Ni'], self.sim_metadata['Nk']
        Rayleigh = self.sim_metadata['Ra']
        Prandtl = 1
        dealias = 3/2
        stop_sim_time = self.act_metadata['actionsPerEp']*self.act_metadata['actionDuration']+self.sim_metadata['DiscardTime']
        timestepper = d3.RK222
        max_timestep = 0.125
        dtype = np.float64

        # Bases
        coords = d3.CartesianCoordinates('x', 'z')
        dist = d3.Distributor(coords, dtype=dtype)
        xbasis = d3.RealFourier(coords['x'], size=Nx, bounds=(0, Lx), dealias=dealias)
        zbasis = d3.ChebyshevT(coords['z'], size=Nz, bounds=(0, Lz), dealias=dealias)


        # Fields
        p = dist.Field(name='p', bases=(xbasis,zbasis))
        b = dist.Field(name='b', bases=(xbasis,zbasis))
        u = dist.VectorField(coords, name='u', bases=(xbasis,zbasis))
        g = dist.Field(bases=xbasis)
        tau_p = dist.Field(name='tau_p')
        tau_b1 = dist.Field(name='tau_b1', bases=xbasis)
        tau_b2 = dist.Field(name='tau_b2', bases=xbasis)
        tau_u1 = dist.VectorField(coords, name='tau_u1', bases=xbasis)
        tau_u2 = dist.VectorField(coords, name='tau_u2', bases=xbasis)
        
        # Substitutions
        kappa = (Rayleigh * Prandtl)**(-1/2)
        nu = (Rayleigh / Prandtl)**(-1/2)
        x, z = dist.local_grids(xbasis, zbasis)
        ex, ez = coords.unit_vector_fields(dist)
        lift_basis = zbasis.derivative_basis(1)
        lift = lambda A: d3.Lift(A, lift_basis, -1)
        grad_u = d3.grad(u) + ez*lift(tau_u1) # First-order reduction
        grad_b = d3.grad(b) + ez*lift(tau_b1) # First-order reduction

        #bottom boundary
        g['g'] = 1.0
        
        # Problem
        # First-order form: "div(f)" becomes "trace(grad_f)":
        # First-order form: "lap(f)" becomes "div(grad_f)"
        problem = d3.IVP([p, b, u, tau_p, tau_b1, tau_b2, tau_u1, tau_u2], namespace=locals())
        problem.add_equation("trace(grad_u) + tau_p = 0")
        problem.add_equation("dt(b) - kappa*div(grad_b) + lift(tau_b2) = - u@grad(b)")
        problem.add_equation("dt(u) - nu*div(grad_u) + grad(p) - b*ez + lift(tau_u2) = - u@grad(u)")
        problem.add_equation("b(z=0) = g")
        problem.add_equation("u(z=0) = 0")
        problem.add_equation("b(z=Lz) = 0")
        problem.add_equation("u(z=Lz) = 0")
        problem.add_equation("integ(p) = 0") # Pressure gauge

        # Solver
        solver = problem.build_solver(timestepper)
        solver.stop_sim_time = stop_sim_time

        fp = d3.GlobalFlowProperty(solver)
        fp.add_property(d3.Average(d3.Average((ez@u)*b, 'x') - kappa*d3.Differentiate(d3.Average(b, 'x'), coords[1]), 'z')/kappa, name='Nu')
            
        # Initial conditions
        b.fill_random('g', seed=seed, distribution='normal', scale=1e-3) # Random noise
        b['g'] *= z * (Lz - z) # Damp noise at walls
        b['g'] += Lz - z # Add linear background


        # CFL
        CFL = d3.CFL(solver, initial_dt=max_timestep, cadence=10, safety=0.5, threshold=0.05,
             max_change=1.5, min_change=0.5, max_dt=max_timestep)
        CFL.add_velocity(u)


        return problem, solver, CFL, fp, g, x

    def render(self):
        print( self.fp.max('Nu') )
        

In [3]:
#env = make_vec_env(DedalusRBC_Env, n_envs=30, seed=0, vec_env_cls=SubprocVecEnv)
gym.envs.register(id='rbc', entry_point=DedalusRBC_Env)
env = gym.make('rbc', render_mode="human")

In [4]:
model = PPO.load("ppo_rbc", env=env)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [5]:
from stable_baselines3.common.evaluation import evaluate_policy
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=3)

In [6]:
mean_reward

17.121048

In [13]:
vec_env = model.get_env()
obs = vec_env.reset()
score=0
for i in range(200):
    action, _ = model.predict(obs, deterministic=False)
    obs, rew, _, _ = vec_env.step(action)
    T = env.unwrapped.problem.variables[1]['g']
    fig, ax = plt.subplots()
    c = ax.imshow(T, vmin=0., vmax=1.2)
    fig.colorbar(c)
    plt.savefig('figs/'+str(i)+'.png')
    plt.close()
    score+=rew
    print( env.unwrapped.fp.max('Nu'), rew )
print(score)

2.5957151608898887 [0.0842309]
2.522567699138032 [0.15737836]
2.511762013774016 [0.16818404]
2.5338194204362607 [0.14612664]
2.550422354529942 [0.12952371]
2.551620080662518 [0.12832598]
2.556098745323132 [0.12384731]
2.571587504558441 [0.10835855]
2.577589077346637 [0.10235699]
2.565228882626344 [0.11471718]
2.5458237827416776 [0.13412228]
2.543375585793873 [0.13657047]
2.5673044970275303 [0.11264157]
2.57078329118649 [0.10916277]
2.5579515327779596 [0.12199453]
2.5550045566678956 [0.12494151]
2.5563313472320033 [0.12361471]
2.555991520775043 [0.12395454]
2.5619046461844532 [0.11804142]
2.5673697990310718 [0.11257626]
2.5577685133221104 [0.12217755]
2.547950677689279 [0.13199538]
2.5476870998700623 [0.13225897]
2.5651069123477406 [0.11483915]
2.589715555298429 [0.0902305]
2.5870044341034744 [0.09294163]
2.5760472366565446 [0.10389882]
2.5718336062799088 [0.10811245]
2.5524852763831913 [0.12746078]
2.5448267164243834 [0.13511935]
2.5502391695546045 [0.12970689]
2.5551523803121117 [0.12

In [55]:
r = DedalusRBC_Env()

In [56]:
r.reset()

1.0000124284716099
1.0000000040348063
1.0000000035418353
1.0000000032782406
1.0000000031570242
1.0000000031221987
1.0000000031510288
1.000000003230099
1.0000000033503602
1.000000003505861
1.000000003693006
1.0000000039100039
1.0000000041563906
1.0000000044327173
1.0000000047403095
1.0000000050811297
1.00000000545766
1.000000005872871
1.0000000063301944
1.0000000068335178
1.0000000073872097
1.0000000079961364
1.0000000086657108
1.0000000094019401
1.0000000102114779
1.0000000111016876
1.0000000120807226
1.0000000131576057
1.0000000143423176
1.000000015645906
1.000000017080591
1.000000018659906
1.00000002039882
1.0000000223139085
1.0000000244235183
1.0000000267479676
1.0000000293097444
1.0000000321337676
1.000000035247623
1.000000038681876
1.0000000424703768
1.0000000466506427
1.000000051264227
1.0000000563571996
1.0000000619805938
1.0000000681909882
1.0000000750510907
1.0000000826304198
1.0000000910060505
1.0000001002634396
1.0000001104973408
1.000000121812834
1.0000001343264535
1.000000

(array([0.70653075, 0.66239196, 0.6300521 , 0.6145556 , 0.6189356 ,
        0.64229536, 0.6804036 , 0.72820747, 0.7822977 , 0.8410596 ,
        0.8940579 , 0.90884906, 0.87165534, 0.8133284 , 0.7565813 ,
        0.7052102 , 0.6612424 , 0.62906146, 0.6137003 , 0.61816484,
        0.64153427, 0.67956793, 0.72720945, 0.7810366 , 0.83953166,
        0.8929362 , 0.9093022 , 0.87331915, 0.8150795 , 0.75810075,
        0.4851984 , 0.43154433, 0.39324778, 0.37427282, 0.37970066,
        0.40781978, 0.45294508, 0.5135204 , 0.59271425, 0.69323665,
        0.79531133, 0.8254772 , 0.7508395 , 0.64382213, 0.553223  ,
        0.48322868, 0.42987406, 0.39177364, 0.37296703, 0.37851995,
        0.4066689 , 0.45168605, 0.51194185, 0.5905137 , 0.690279  ,
        0.7929533 , 0.82642245, 0.7541767 , 0.6470112 , 0.55569935,
        0.43471125, 0.39379844, 0.35209164, 0.32438934, 0.33286968,
        0.36987203, 0.41145477, 0.45420873, 0.5157708 , 0.614108  ,
        0.7288322 , 0.76419735, 0.67742485, 0.56