In [1]:
import numpy as np
import matplotlib.pyplot as plt

import sys, os
sys.path.append(os.path.abspath('../../..'))

# load the modules
from fridom.ShallowWater.ModelSettings import ModelSettings
from fridom.ShallowWater.Grid import Grid
from fridom.ShallowWater.InitialConditions import Random, SingleWave, Jet
from fridom.ShallowWater.Model import Model
from fridom.ShallowWater.Plot import Plot
from fridom.ShallowWater.ModelPlotter import ModelPlotter
from fridom.ShallowWater.State import State
from fridom.Framework.ModelBase import ModelBase

In [5]:
class PerturbatedJet(Jet):
    """Jet initial condition with a small random field added on top.

    The parent ``Jet`` is always built with ``waveamp=0``; afterwards a
    ``Random`` initial condition, scaled by 0.1, is added to ``u``, ``v``
    and ``h``.

    Args:
        mset: Model settings forwarded to ``Jet`` and ``Random``.
        grid: Grid forwarded to ``Jet`` and ``Random``.
        seed: Seed handed to ``Random`` for the perturbation.
        **kwargs: Extra keyword arguments forwarded to ``Jet``. Passing
            ``waveamp`` here raises ``TypeError`` because it is already
            fixed to 0.
    """

    def __init__(self, mset, grid, seed, **kwargs):
        super().__init__(mset, grid, **kwargs, waveamp=0)

        # Scaled-down random state used as the perturbation.
        noise = Random(mset, grid, seed=seed) * 0.1

        # Add the perturbation field by field (u, v, h in that order).
        for field in ("u", "v", "h"):
            setattr(self, field, getattr(self, field) + getattr(noise, field))

class MyPlotter(ModelPlotter):
    """Three-panel video plotter: two ensemble members and the ensemble mean.

    Neither method takes ``self``. They are declared as static methods so
    they behave the same whether they are looked up on the class (this
    notebook assigns the class itself to ``mset.vid_plotter``) or on an
    instance.
    """

    @staticmethod
    def create_figure():
        """Create the wide figure that holds the three side-by-side panels.

        Returns:
            The new matplotlib figure (15x5 inches, 200 dpi, tight layout).
        """
        return plt.figure(figsize=(15, 5), dpi=200, tight_layout=True)

    @staticmethod
    def update_figure(fig, z1, z2, zm, time, **kwargs):
        """Draw one video frame into ``fig``.

        Args:
            fig: Figure to draw into.
            z1: State of the first displayed ensemble member.
            z2: State of the second displayed ensemble member.
            zm: Ensemble-mean state.
            time: Model time shown in the figure title.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        # (subplot position, state, panel title) for each of the three panels.
        panels = (
            (131, z1, "Member 1"),
            (132, z2, "Member 2"),
            (133, zm, "Ensemble Mean"),
        )
        for position, state, title in panels:
            ax = fig.add_subplot(position)
            # The field returned by state.ekin() is what gets plotted
            # (presumably kinetic energy -- confirm against State).
            Plot(state.ekin())(state, fig=fig, ax=ax, cmax=0.8, vmax=2)
            # Title the panel through its own axes object instead of the
            # implicit pyplot "current axes".
            ax.set_title(title)
        fig.suptitle("Time: {:.2f} s".format(time))

class EnsembleState:
    """Collection of independently perturbed jet states.

    Each member is a ``PerturbatedJet`` built with its own integer seed in
    ``[0, 100000)``.

    Args:
        mset: Model settings passed to every member.
        grid: Grid passed to every member.
        ensemble_size: Number of members to create.
        seed: Optional seed for drawing the member seeds. When given, a
            dedicated ``np.random.default_rng(seed)`` generator is used, so
            the same ``seed`` always yields the same member seeds regardless
            of the global NumPy RNG state. When ``None`` (default), member
            seeds are drawn from the global ``np.random`` state as before.

    Attributes are ``mset``, ``grid``, ``ensemble_size`` and ``states`` (the
    list of members, in creation order).
    """

    def __init__(self, mset, grid, ensemble_size, seed=None):
        self.mset = mset
        self.grid = grid
        self.ensemble_size = ensemble_size

        # Only build a private generator when reproducibility was requested.
        rng = None if seed is None else np.random.default_rng(seed)

        states = []
        for _ in range(ensemble_size):
            # Draw the seed right before building the member, keeping the
            # draw/construct order of the default path unchanged.
            if rng is None:
                member_seed = np.random.randint(0, 100000)
            else:
                member_seed = int(rng.integers(0, 100000))
            states.append(PerturbatedJet(mset, grid, seed=member_seed))
        self.states = states

class EnsembleModel(ModelBase):
    """Advance an ensemble of shallow-water models together.

    One ``Model`` is created per member of an ``EnsembleState``. ``step``
    advances every member once, and ``z`` exposes the ensemble-mean state.
    Video frames show the first two members next to the mean.

    Args:
        mset: Model settings for the ensemble wrapper.
        grid: Grid shared by all members.
        ensemble_size: Number of ensemble members (at least 1).

    Raises:
        ValueError: If ``ensemble_size`` is smaller than 1.
    """

    def __init__(self, mset, grid, ensemble_size):
        # An empty ensemble has no mean; fail here instead of with a
        # division by zero when ``z`` is first evaluated.
        if ensemble_size < 1:
            raise ValueError(
                "ensemble_size must be at least 1, got {}".format(ensemble_size))

        super().__init__(mset, grid, State)
        self.ensemble_size = ensemble_size

        # Members run on a copy of the settings with video output disabled,
        # so only this wrapper drives the video animation.
        mset_sub = mset.copy()
        mset_sub.enable_vid_anim = False

        self.states = EnsembleState(mset_sub, grid, ensemble_size)
        self.models = []
        for state in self.states.states:
            member = Model(mset_sub, grid)
            member.z = state
            self.models.append(member)

    def step(self):
        """Advance every member by one step and update the video if due."""
        for model in self.models:
            model.step()

        # vid animation: emit a frame every ``vid_anim_interval`` iterations
        if self.mset.enable_vid_anim:
            if (self.it % self.mset.vid_anim_interval) == 0:
                self.update_vid_animation()

        self.it += 1

    def update_vid_animation(self):
        """Send the first two members and the ensemble mean to the video.

        With a single-member ensemble the second panel repeats member 0
        instead of failing with an ``IndexError``.
        """
        first = self.models[0].z
        second = self.models[1].z if len(self.models) > 1 else first
        self.vid_animation.update(
            z1=first.cpu(),
            z2=second.cpu(),
            zm=self.z.cpu(),
            time=self.time)

    @property
    def z(self):
        """Ensemble-mean state, recomputed from the members on every access."""
        mean = State(self.mset, self.grid)

        for model in self.models:
            mean += model.z

        mean /= self.ensemble_size
        return mean

    @z.setter
    def z(self, z):
        # Assignments are deliberately ignored: the mean is derived from the
        # members. Presumably ModelBase assigns ``self.z`` during
        # construction, which is why a setter must exist -- TODO confirm.
        return

In [8]:
# Seed the global NumPy RNG: EnsembleState draws each member's seed from
# np.random, so without this every run produces a different ensemble.
# NOTE(review): full reproducibility also assumes the fridom Random initial
# condition is deterministic for a given seed -- confirm.
SEED = 42
np.random.seed(SEED)

# Resolution exponent: the grid has 2**fac points per direction and the
# time step is 2**(-fac).
fac = 8
ENSEMBLE_SIZE = 200  # the recorded run with these settings took about 1h16m
RUN_LENGTH = 20      # value passed to run(runlen=...)

mset = ModelSettings(N=[2**fac, 2**fac], dt=2**(-fac), Ro=0.3)

# Video output: one frame every 10 iterations, drawn by MyPlotter.
mset.enable_vid_anim = True
mset.vid_anim_filename = "ensemble.mp4"
mset.vid_anim_interval = 10
mset.vid_plotter = MyPlotter

grid = Grid(mset)
ensemble = EnsembleModel(mset, grid, ensemble_size=ENSEMBLE_SIZE)
ensemble.run(runlen=RUN_LENGTH)
ensemble.show_video()

100%|██████████| 5120/5120 [1:16:24<00:00,  1.12it/s]
