**TODO**:

Visual Studio

Faire un package

1. les classes
2. test simulation

  
```
état: initialiser(params état initial)
état: couper la forêt
simulation: initialiser(état initial, paramètres simulation)
simulation: simuler()
simulation: observer (donne l'intégrale de c) (OU simulation: visualiser)
```

f_c(x, params):
    sim = Simulation(params)

In [None]:
# !jupyter nbconvert --to script "wild_fire_sim-OOP.ipynb"

In [None]:
%load_ext snakeviz
%load_ext line_profiler
%matplotlib ipympl

from dataclasses import dataclass, field, InitVar
from abc import ABC, abstractmethod
from typing import List, Tuple, Union, Optional, Callable, Any
from copy import deepcopy
from datetime import datetime

import IPython
from IPython.display import clear_output
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation

In [None]:
# Draw pcolormesh data with "nearest" shading: the X, Y grids passed to
# pcolormesh below have the same shape as the data, and each value is drawn
# as a cell centered on its (X, Y) coordinate.
mpl.rcParams.update({"pcolor.shading": "nearest"})

# Simulation

In [None]:
def disk(X, Y, x, y, r):
    """Return a boolean mask of the points (X, Y) inside or on the circle of center (x, y) and radius r."""
    squared_distance = (X - x)**2 + (Y - y)**2
    return squared_distance <= r**2


def rect(X, Y, xmin, xmax, ymin, ymax):
    """Return a boolean mask of the points (X, Y) in the closed rectangle [xmin, xmax] x [ymin, ymax]."""
    conditions = (xmin <= X, X <= xmax, ymin <= Y, Y <= ymax)
    # a point is inside only if all four bounds hold
    return np.all(conditions, axis=0)

def vprint(*msg, verbose: bool, timestamp: bool = False, **kwargs):
    """Print ``msg`` like ``print`` does, but only when ``verbose`` is true.

    If ``timestamp`` is true, a local-time ``[YYYY-MM-DD HH:MM:SS]`` prefix is
    printed first. Extra keyword arguments are forwarded to ``print``.
    """
    if not verbose:
        return
    parts = msg
    if timestamp:
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        parts = (f"[{stamp}]",) + parts
    print(*parts, **kwargs)

In [None]:
def test_vprint():
    """Print a small table showing vprint's output for every (verbose, timestamp) pair."""
    # column alignment syntax: https://www.geeksforgeeks.org/string-alignment-in-python-f-string/
    row = "{:<10}{:<10}  {}"
    print(row.format("verbose", "timestamp", "message"))
    for verbose in (False, True):
        for timestamp in (False, True):
            # str(): with a format spec, bools are formatted through int.__format__
            # and would show up as 0/1 instead of False/True
            print(row.format(str(verbose), str(timestamp), ""), end="")
            vprint("test", verbose=verbose, timestamp=timestamp)
            if not verbose:
                # vprint printed nothing, so terminate the row ourselves
                print()
# test_vprint()

Champs scalaires :
 - T : température
 - c : combustible
 - b : brûlé (booléen)
 
Champ vectoriel :
 - u, v : vent

In [None]:
@dataclass
class Forest:
    """Static description of the square domain [0, 1]^2 on an N x N grid.

    Attributes:
        N: number of grid points per axis.
        T_fire: temperature threshold used by the simulation to decide
            which cells are on fire.
        mu: diffusion coefficient.
        X, Y: grid coordinates (``indexing="ij"``: X varies along axis 0).
        u, v: components of the fixed wind field.
        dx: grid spacing.
        dt: time step, a quarter of the smaller of the advective (CFL) and
            explicit-diffusion stability limits.
    """
    N: int
    T_fire: float
    mu: float
    X: np.ndarray = field(init=False, repr=False)
    Y: np.ndarray = field(init=False, repr=False)
    u: np.ndarray = field(init=False, repr=False)
    v: np.ndarray = field(init=False, repr=False)
    dx: float = field(init=False)
    dt: float = field(init=False)

    def __post_init__(self):
        axis, self.dx = np.linspace(0, 1, self.N, retstep=True)
        self.X, self.Y = np.meshgrid(axis, axis, indexing="ij")
        self.u = np.cos(np.pi * self.Y)
        self.v = 0.6 * np.sin(np.pi / 2 * (self.X + 0.2))

        max_speed = np.sqrt(self.u**2 + self.v**2).max()
        stability_limits = (
            self.dx / max_speed,          # advection (CFL)
            self.dx**2 / (2 * self.mu),   # explicit diffusion
        )
        self.dt = min(stability_limits) / 4


@dataclass
class SimulationState:
    """Mutable fields of the fire simulation: temperature T and fuel c.

    The remaining constructor arguments are init-only and build T and c:
      - T is 0 everywhere except T_init_fire inside the disk of center
        (x0, y0) and radius r0;
      - c starts at c_init everywhere, then each of `n_circles` random disks
        (radius drawn in [0.1, 0.2), center in [0.1, 0.9)^2) has its fuel
        shifted by a value drawn in [-5, 5).

    rng is None (non-reproducible), an int seed, or a numpy Generator.
    """
    X: InitVar[np.ndarray]
    Y: InitVar[np.ndarray]
    x0: InitVar[float]
    y0: InitVar[float]
    r0: InitVar[float]
    T_init_fire: InitVar[float]
    c_init: InitVar[float]
    n_circles: InitVar[int]
    # either None for non-reproducibility, int for seed, or Generator
    rng: InitVar[Union[None, int, np.random.Generator]] = None
    T: np.ndarray = field(init=False, repr=False)
    c: np.ndarray = field(init=False, repr=False)

    def __post_init__(self, X, Y, x0, y0, r0, T_init_fire, c_init, n_circles, rng):
        shape = X.shape
        self.T = np.zeros(shape)
        self.T[disk(X, Y, x0, y0, r0)] = T_init_fire
        self.c = np.full(shape, c_init, dtype=float)

        generator = rng if isinstance(rng, np.random.Generator) else np.random.default_rng(rng)

        for _ in range(n_circles):
            # the draw order (radius, x, y, offset) fixes the result for a given seed
            radius = generator.uniform(0.1, 0.2)
            center_x = generator.uniform(0.1, 0.9)
            center_y = generator.uniform(0.1, 0.9)
            offset = generator.uniform(-5, 5)
            self.c[disk(X, Y, center_x, center_y, radius)] += offset

    def cut_trees(self, X: np.ndarray, Y: np.ndarray, xmin: float, xmax: float, ymin: float, ymax: float):
        """Set the fuel to -1 on the grid points inside [xmin, xmax] x [ymin, ymax]."""
        inside = rect(X, Y, xmin, xmax, ymin, ymax)
        self.c[inside] = -1

    def get_masked_T(self, T_fire: float) -> np.ma.MaskedArray:
        """Return T with every value below T_fire masked out."""
        return np.ma.masked_less(self.T, T_fire)

    def get_fuel_amount(self, dx: float):
        """Return the total fuel: the sum of c times the cell area dx * dx."""
        total = self.c.sum()
        return total * dx * dx
        

# Chronological record of a simulation run: list of (t, state snapshot) pairs.
StateHistory = List[Tuple[float, SimulationState]]


@dataclass
class Simulation:
    """Create a simulation, run it, and visualize the results.

    Attributes:
        forest: grid, wind field and numerical parameters (dx, dt, mu, T_fire).
        tf: final time; the loop advances by ``forest.dt`` up to ``tf``.
        max_iter: iteration cap; a value <= 0 disables it.
        history: snapshots recorded by the last ``simulate(track_state=True)``
            call, or None (also before the first call to ``simulate``).
    """
    forest: Forest
    tf: float
    max_iter: int
    history: Optional[StateHistory] = field(init=False, default=None, repr=False)

    def simulate(self, initial_state: SimulationState, track_state: bool = False, verbose: bool = False) -> SimulationState:
        """Run the simulation from a copy of `initial_state` and return the final state.

        `initial_state` is deep-copied and never modified. The loop stops at
        `tf`, or earlier when `max_iter` (if positive) is exceeded or when
        the temperature is below `forest.T_fire` everywhere.

        Args:
            initial_state: state at t = 0.
            track_state: if True, store a deep copy of the state at t = 0 and
                after every step in `self.history`; otherwise history is None.
            verbose: print the reason for an early stop.

        Returns:
            The state after the last simulated step.
        """
        state = deepcopy(initial_state)
        self.history = [] if track_state else None
        self._track_state(0, state)
        dt = self.forest.dt
        for k, t in enumerate(np.arange(dt, self.tf + dt, dt)):
            # simulation step
            self._simulation_step(state)

            # stopping criteria
            stop_msg = None
            if self.max_iter > 0 and k > self.max_iter:
                stop_msg = "Maximum number of iterations has been exceeded."
            elif state.T.max() < self.forest.T_fire:
                stop_msg = "T < T_fire everywhere."

            # record the snapshot at its actual time (it used to be stamped 0)
            self._track_state(t, state)

            if stop_msg is not None:
                if verbose:
                    print(f"Stopping simulation at {k=} ({t=:.2f}) because", stop_msg)
                break

        return state

    def _simulation_step(self, state: SimulationState):
        """Advance `state` in place by one explicit time step ``forest.dt``.

        Interior temperature: diffusion (centered second differences) +
        advection by the wind (first-order upwind) + reaction. Cells with
        T > T_fire lose fuel. Boundaries get a zero-gradient (Neumann)
        condition by copying the adjacent interior row/column.
        """
        # shortcuts: neighbour before / cell itself / neighbour after, on interior cells
        back = slice(None, -2)
        mid = slice(1, -1)
        front = slice(2, None)
        T = state.T
        c = state.c
        T_mid = T[mid, mid]  # view: the in-place update below writes into state.T
        u_mid = self.forest.u[mid, mid]
        v_mid = self.forest.v[mid, mid]
        mu = self.forest.mu
        dx = self.forest.dx
        dt = self.forest.dt
        on_fire = T > self.forest.T_fire  # evaluated on the temperature before the update

        # one-sided first differences and centered second differences
        Tx_back = (T_mid - T[back, mid]) / dx
        Tx_front = (T[front, mid] - T_mid) / dx
        Ty_back = (T_mid - T[mid, back]) / dx
        Ty_front = (T[mid, front] - T_mid) / dx
        Txx = (Tx_front - Tx_back) / dx
        Tyy = (Ty_front - Ty_back) / dx

        # Laplacian, advection, reaction
        diffusion = mu * (Txx + Tyy)

        # upwind: take the difference on the side the wind comes from
        Tx_upwind = np.where(u_mid > 0, Tx_back, Tx_front)
        Ty_upwind = np.where(v_mid > 0, Ty_back, Ty_front)
        advection = -(Tx_upwind * u_mid + Ty_upwind * v_mid)

        # burning cells heat up (rate 10) where c >= 0 and cool down (rate -5)
        # where c < 0 (exhausted fuel, or cells set to -1 by cut_trees)
        reaction_T = np.zeros(T.shape)
        reaction_T[np.logical_and(on_fire, c >= 0)] = 10
        reaction_T[np.logical_and(on_fire, c < 0)] = -5

        # explicit Euler update
        T_mid += dt * ((diffusion + advection) + (reaction_T * T)[mid, mid])
        c[on_fire] += dt * -100

        # Neumann boundary condition
        T[:, 0] = T[:, 1]
        T[:, -1] = T[:, -2]
        T[0, :] = T[1, :]
        T[-1, :] = T[-2, :]

    def _track_state(self, t: float, state: SimulationState):
        """Append (t, deep copy of `state`) to the history when tracking is enabled."""
        if self.history is not None:
            self.history.append((t, deepcopy(state)))

    def visualize_anim(self, nb_frames: Optional[int] = None, sup_title: str = "", **kwargs):
        """Animate the recorded history: fuel map, burning cells and wind.

        More complicated than redrawing every frame, but smoother.

        Args:
            nb_frames: approximate number of frames, subsampled from the
                history (the last state is always shown). Defaults to one
                frame per recorded state.
            sup_title: optional first line of the title.
            **kwargs: forwarded to ``matplotlib.animation.FuncAnimation``.

        Returns:
            The FuncAnimation object.

        Raises:
            RuntimeError: if no history is available (tracking disabled, or
                ``simulate`` never called).
        """
        if self.history is None:
            raise RuntimeError("Tracking was disabled during the previous call to `simulate`. Retry after activating tracking.")

        if nb_frames is None:
            nb_frames = len(self.history)

        _, init_state = self.history[0]
        # use this simulation's own grid (previously read from a global `simulation`)
        X, Y = self.forest.X, self.forest.Y

        fig, ax = plt.subplots()

        fuel = ax.pcolormesh(X, Y, init_state.c, cmap=plt.cm.YlGn, vmin=0, vmax=10)
        fuel_cb = fig.colorbar(fuel, ax=ax)
        fuel_cb.set_label("fuel", loc="top")

        fire = ax.pcolormesh(X, Y, init_state.get_masked_T(self.forest.T_fire), cmap=plt.cm.hot, vmin=0, vmax=.15)
        fire_cb = fig.colorbar(fire, ax=ax)
        fire_cb.set_label("fire", loc="top")

        # subsample the wind field to about n_arrows x n_arrows arrows;
        # max(1, ...) avoids a zero slice step on grids smaller than n_arrows
        n_arrows = 6
        arrow_step = max(1, X.shape[0] // n_arrows)
        wind_ind = np.zeros(X.shape, dtype=bool)
        wind_ind[::arrow_step, ::arrow_step] = True
        ax.quiver(X[wind_ind], Y[wind_ind], self.forest.u[wind_ind], self.forest.v[wind_ind])

        title_prefix = sup_title + "\n" if sup_title else ""
        ax.set_xlabel("x")
        ax.set_ylabel("y")

        def init():
            # base frame for blitting: the initial state (an empty array may be
            # rejected by QuadMesh.set_array's shape check)
            fuel.set_array(init_state.c.ravel())
            fire.set_array(init_state.get_masked_T(self.forest.T_fire).ravel())
            return fuel, fire

        def animate(k):
            t, state = self.history[k]
            pct = int(100 * (k + 1) / len(self.history))
            fuel.set_array(state.c.ravel())
            fire.set_array(state.get_masked_T(self.forest.T_fire).ravel())
            # f-string (not str.format) so braces inside sup_title stay literal
            ax.set_title(f"{title_prefix}t = {t:.2f} s, {pct}%")
            return fuel, fire

        def frames(nb_frames):
            # evenly spaced history indices, always ending on the last record
            nb_records = len(self.history)
            show_every = max(1, nb_records // (nb_frames + 1))
            yield from range(0, nb_records - 1, show_every)
            yield nb_records - 1

        anim = FuncAnimation(
            fig,
            animate,
            init_func=init,
            frames=frames(nb_frames),
            **kwargs
        )

        plt.show()

        return anim

In [None]:
# Forest parameters: grid points per axis, fire threshold, diffusion coefficient.
forest_params = {
    "N": 100,
    "T_fire": 5e-2,
    "mu": 5e-3,
}

# Simulation parameters; max_iter is an int to match `Simulation.max_iter: int`
# (it was the float 1e6).
simulation_params = {
    "tf": 2.5,
    "max_iter": 1_000_000,
}

# Initial state: fire disk (center (x0, y0), radius r0, temperature T_init_fire),
# base fuel level c_init and number of random fuel perturbations n_circles.
initial_state_params = {
    "x0": .1,
    "y0": .1,
    "r0": .05,
    "T_init_fire": .5,
    "c_init": 5,
    "n_circles": 5,
}

In [None]:
# Build the domain, the simulation driver and a reproducible (rng=1) initial state.
forest = Forest(**forest_params)
simulation = Simulation(forest=forest, **simulation_params)
initial_state = SimulationState(
    X=forest.X,
    Y=forest.Y,
    rng=1,
    **initial_state_params,
)

# Cut rectangle, ordered (xmin, xmax, ymin, ymax) as cut_trees expects.
x = np.array([.2, .4, .2, .4])
initial_state.cut_trees(forest.X, forest.Y, *x)

# Keep every intermediate state so the run can be animated afterwards.
final_state = simulation.simulate(initial_state=initial_state, track_state=True, verbose=True)

In [None]:
# FuncAnimation options: `interval` is the delay between frames in ms.
anim_kwargs = dict(interval=17, repeat=False, repeat_delay=1000, blit=True)
anim = simulation.visualize_anim(nb_frames=100, **anim_kwargs)

# Optimization

In [None]:
OptVar = np.ndarray  # a point of the search space
CostHistory = List[Tuple[OptVar, float]]  # (point, cost) pairs, in recording order


def count_func_calls(function, kwargs):
    """Wrap ``function`` so that its calls are counted.

    Returns ``(counter, wrapper)``: ``wrapper(x)`` returns ``function(x, **kwargs)``,
    and ``counter`` is a one-element list whose item is incremented on each call.
    """
    counter = [0]

    def counted(x):
        counter[0] += 1
        return function(x, **kwargs)

    return counter, counted
    

@dataclass
class SimplexOptimizer(ABC):
    """Base class for derivative-free simplex minimizers.

    Subclasses implement `_minimization_step`, which must update the simplex
    (one point per row) and its function values in place. Before each step
    the simplex is sorted so that row 0 is the best point and row -1 the worst.

    Attributes:
        xtol: convergence threshold on max |x_i - x_best| over the simplex.
        ftol: convergence threshold on max |f(x_i) - f(x_best)|.
        maxiter: maximum number of iterations.
        maxfun: maximum number of function evaluations (checked once per
            iteration, so the last step may exceed it).
        verbose: debug-output flag (not read by this base class).
        history: (best point, best cost) per iteration when `minimize` is
            called with ``track_cost=True``; None otherwise.
    """
    xtol: float = 1e-3
    ftol: float = 1e-3
    maxiter: int = 400
    maxfun: int = 400
    verbose: bool = False
    history: Optional[CostHistory] = field(init=False, default=None)

    def minimize(self, function, initial_simplex, funckwargs: Optional[dict] = None, track_cost: bool = False):
        """Find x minimizing ``function(x, **funckwargs)``.

        Args:
            function: objective, called as ``function(x, **funckwargs)``.
            initial_simplex: array-like whose rows are the starting points
                (at least two). It is copied as float and not modified.
            funckwargs: extra keyword arguments for `function` (default: none).
            track_cost: record (best point, best cost) at every iteration in
                `self.history`.

        Returns:
            ``(x_best, f_best)`` at convergence or when `maxiter` / `maxfun`
            is reached.

        Raises:
            ValueError: if the simplex has fewer than two points.
        """
        funckwargs = {} if funckwargs is None else funckwargs
        ncalls, wrapped_function = count_func_calls(function, funckwargs)

        # float copy: in-place updates of an integer array would truncate
        simplex = np.array(initial_simplex, dtype=float)
        if len(simplex) < 2:
            raise ValueError("initial_simplex must contain at least two points")

        # Evaluate the vertices once; `_minimization_step` keeps fsimplex in sync
        # afterwards, so re-evaluating every vertex (a full simulation each) at
        # every iteration is unnecessary.
        fsimplex = np.array([wrapped_function(x) for x in simplex], dtype=float)

        self.history = [] if track_cost else None
        for _ in range(self.maxiter):
            if ncalls[0] >= self.maxfun:
                break
            self._sort_simplex(simplex, fsimplex)
            self._track_cost(simplex[0], fsimplex[0])

            if (np.abs(simplex[1:] - simplex[0]).max() < self.xtol and
                    np.abs(fsimplex[1:] - fsimplex[0]).max() < self.ftol):
                break  # converged

            self._minimization_step(wrapped_function, simplex, fsimplex)

        self._sort_simplex(simplex, fsimplex)
        return simplex[0].copy(), fsimplex[0]

    @abstractmethod
    def _minimization_step(self, wrapped_function, simplex, fsimplex):
        """Modify simplex and fsimplex in place to advance one step."""

    def _sort_simplex(self, simplex, fsimplex):
        """Reorder the rows of simplex and fsimplex in place, best (lowest f) first."""
        order = np.argsort(fsimplex)
        # fancy indexing builds a copy first, so assigning back into the same array is safe
        simplex[:] = simplex[order]
        fsimplex[:] = fsimplex[order]

    def _track_cost(self, x: OptVar, cost: float):
        """Record (x, cost) if tracking is enabled.

        x is copied because the simplex row it comes from is updated in place later.
        """
        if self.history is not None:
            self.history.append((np.copy(x), cost))

    def visualize_anim(self):
        """Placeholder: not implemented yet, does nothing."""
        pass
        
@dataclass
class NelderMead(SimplexOptimizer):
    """Nelder-Mead simplex method.

    Pluses:
        - no gradient -> not sensitive to noise
        - simple (no maths)
    Minuses:
        - cases of convergence toward a non-stationary point (grad != 0)
        - costly (many function calls)

    Attributes:
        alpha: reflection coefficient.
        beta: expansion coefficient.
        gamma: contraction and shrink coefficient.
    """
    alpha: float = 1
    beta: float = 2
    gamma: float = 1 / 2

    def _minimization_step(self, wrapped_function, simplex, fsimplex):
        """Replace the worst vertex, or shrink the simplex towards the best one, in place.

        Expects simplex/fsimplex sorted best-first (row 0 best, row -1 worst).
        """
        if self.verbose:  # debug output, previously printed unconditionally
            print(fsimplex)
        best = simplex[0].copy()
        fbest = fsimplex[0]
        worst = simplex[-1].copy()
        fworst = fsimplex[-1]

        # centroid of all vertices except the worst
        barycenter = simplex[:-1].mean(axis=0)
        reflected = (1 + self.alpha) * barycenter - self.alpha * worst
        freflected = wrapped_function(reflected)

        if freflected < fbest:
            # reflection beats the best vertex: try going further in that direction
            expansion = (1 + self.beta) * barycenter - self.beta * worst
            fexpansion = wrapped_function(expansion)
            if fexpansion < freflected:
                simplex[-1] = expansion
                fsimplex[-1] = fexpansion
            else:
                simplex[-1] = reflected
                fsimplex[-1] = freflected
        elif freflected < fworst:
            simplex[-1] = reflected
            fsimplex[-1] = freflected
        else:
            contracted = (1 - self.gamma) * barycenter + self.gamma * worst
            fcontracted = wrapped_function(contracted)
            if fcontracted < fworst:
                simplex[-1] = contracted
                fsimplex[-1] = fcontracted
            else:
                # shrink towards the best vertex; the best vertex itself does not
                # move, so it is not re-evaluated
                for i in range(1, simplex.shape[0]):
                    simplex[i] = (1 - self.gamma) * best + self.gamma * simplex[i]
                    fsimplex[i] = wrapped_function(simplex[i])
        

@dataclass
class Torczon(SimplexOptimizer):
    """Torczon's multidirectional search: every step moves all non-best vertices.

    Pluses over Nelder-Mead:
        - parallelizable (the new vertices are independent evaluations)
        - convergence proof (towards a stationary point, under smoothness
          assumptions)

    Coefficients (roles inferred from the original reflection formula and
    the default values; TODO confirm):
        alpha: reflection through the best vertex.
        beta: contraction towards the best vertex.
        gamma: expansion of the reflected simplex.
    """
    alpha: float = 1 / 2
    beta: float = 1 / 2
    gamma: float = 2

    def _minimization_step(self, wrapped_function, simplex, fsimplex):
        """Reflect, expand or contract all non-best vertices around the best one, in place.

        Expects simplex/fsimplex sorted best-first (row 0 is the best vertex).
        """
        best = simplex[0].copy()
        fbest = fsimplex[0]
        others = simplex[1:].copy()

        def evaluate(points):
            return np.array([wrapped_function(p) for p in points], dtype=float)

        # reflect every other vertex through the best one
        reflected = (1 + self.alpha) * best - self.alpha * others
        freflected = evaluate(reflected)

        if freflected.min() < fbest:
            # the reflection improved on the best point: try stretching it further
            expanded = best + self.gamma * (reflected - best)
            fexpanded = evaluate(expanded)
            if fexpanded.min() < freflected.min():
                new_points, new_values = expanded, fexpanded
            else:
                new_points, new_values = reflected, freflected
        else:
            # no improvement: contract the simplex towards the best vertex
            new_points = best + self.beta * (others - best)
            new_values = evaluate(new_points)

        simplex[1:] = new_points
        fsimplex[1:] = new_values
        

In [None]:
def cost_function(x: OptVar, simulation: Simulation, initial_state: SimulationState) -> float:
    """Cost of cutting the rectangle ``x = (xmin, xmax, ymin, ymax)`` out of the forest.

    `initial_state` is the state *before* cutting; it is deep-copied, not modified.
    The cost is the sum of:
      - the decrease of total fuel between the freshly cut state and the end
        of the simulation,
      - 10 times the area of the cut rectangle,
      - 100 times how far ymin lies below 0.2 (0 when ymin >= 0.2).
    """
    xmin, xmax, ymin, ymax = x
    forest = simulation.forest
    dx = forest.dx

    cut_state = deepcopy(initial_state)
    cut_state.cut_trees(forest.X, forest.Y, *x)
    final_state = simulation.simulate(initial_state=cut_state)

    burned_fuel = cut_state.get_fuel_amount(dx=dx) - final_state.get_fuel_amount(dx=dx)
    cut_area = abs(xmax - xmin) * abs(ymax - ymin)
    position_penalty = max(0, .2 - ymin)

    return burned_fuel + 10 * cut_area + 100 * position_penalty

In [None]:
# Fresh forest and simulation. No rng is given, so the random fuel
# perturbations of the initial state differ from one run to the next.
forest = Forest(**forest_params)
simulation = Simulation(forest=forest, **simulation_params)
initial_state = SimulationState(X=forest.X, Y=forest.Y, **initial_state_params)

# 5 random vertices in [0, 1)^4, each a candidate (xmin, xmax, ymin, ymax).
initial_simplex = np.random.random((5, 4))

funckwargs = dict(simulation=simulation, initial_state=initial_state)

neldermead = NelderMead()
neldermead.minimize(
    function=cost_function,
    initial_simplex=initial_simplex,
    funckwargs=funckwargs,
    track_cost=True,
)

In [None]:
# Display the (best point, best cost) pairs recorded at each optimizer iteration.
neldermead.history

In [2]:
import yaml

In [14]:
# Parse params.yml; safe_load reads the open stream directly and only
# constructs plain Python objects.
with open("params.yml", "r") as pfile:
    params = yaml.safe_load(pfile)

In [15]:
# Display the parameters loaded from params.yml.
# NOTE(review): the displayed 'initial_state' section contains 'z0' while
# SimulationState expects 'r0'; params.yml likely needs fixing before
# passing params["initial_state"] to SimulationState.
params

{'forest': {'N': 100, 'T_fire': 0.05, 'mu': 0.005},
 'simulation': {'tf': 2.5, 'max_iter': 10000},
 'initial_state': {'x0': 0.1,
  'y0': 0.1,
  'z0': 0.05,
  'T_init_fire': 0.5,
  'c_init': 5,
  'n_circles': 5,
  'rng': 1}}