In [1]:
import os
import random
from functools import partial
from decimal import Decimal
import numpy as np
import scipy.io as sio
import pysindy as ps
from tqdm import trange

from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.core.problem import ElementwiseProblem
from pymoo.core.sampling import Sampling
from pymoo.core.crossover import Crossover
from pymoo.core.mutation import Mutation
from pymoo.core.duplicate import ElementwiseDuplicateElimination
from pymoo.optimize import minimize
from pymoo.visualization.scatter import Scatter

def sci_format(n):
    """Split a number into its scientific-notation mantissa and exponent.

    The value is rendered with two decimals in ``E`` notation (for example
    ``1.23E+04``) and that text is parsed back into numbers.

    Returns:
        tuple: ``(mantissa, exponent)`` with the mantissa as ``float`` and
        the exponent as ``int``.
    """
    pieces = ('%.2E' % Decimal(n)).split('E')
    mantissa = float(pieces[0])
    exponent = int(pieces[1])
    return mantissa, exponent

In [2]:
# Search-space configuration shared by the feature library and the genetic search.
n_poly = 2          # polynomial degree passed to the function library
n_derivatives = 3   # derivative order passed to the weak-form library
n_modules = 8       # upper bound on the number of modules drawn for an initial genome

# The sampling, crossover and mutation operators in this notebook draw from the
# `random` module, so without a fixed seed two "Restart & Run All" executions
# give different populations and results. numpy's global generator is seeded as
# well for any library code that relies on it (NOTE(review): confirm which
# generator the installed pysindy / pymoo versions actually use).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

In [3]:
# Load the Burgers dataset from the MATLAB file and keep only the real parts.
data_path = "../PDE-Discovery-EC/Datasets/"
data = sio.loadmat(os.path.join(data_path, "burgers.mat"))

# `u_clean` is the solution as stored; `u` is an independent working copy.
u_clean = data['usol'].real
u = u_clean.copy()

# Coordinate vectors: first row of 'x', first column of 't'.
x = data['x'][0].real
t = data['t'][:, 0].real

# Grid spacings taken from neighbouring samples.
dt = t[1] - t[0]
dx = x[2] - x[1]

# Column/row views of the two axes kept together in one object array.
xt = np.array([x.reshape(-1, 1), t.reshape(1, -1)], dtype=object)

# Dense grid: after the transpose XT[i, j] holds the pair (x[i], t[j]).
X, T = np.meshgrid(x, t)
XT = np.asarray([X, T]).T

In [4]:
# Polynomial terms of u up to degree `n_poly`; the constant term is left out
# here and requested on the weak library below instead.
function_library = ps.PolynomialLibrary(degree=n_poly, include_bias=False)

# Weak-form PDE library built on the (x, t) grid `XT`, combining the polynomial
# terms with derivatives up to order `n_derivatives`.
# NOTE(review): K is presumably the number of weak-form test domains and
# `is_uniform` a hint that the grid spacing is constant -- confirm against the
# installed pysindy version.
weak_lib = ps.WeakPDELibrary(
    function_library=function_library,
    derivative_order=n_derivatives,
    spatiotemporal_grid=XT,
    include_bias=True,
    diff_kwargs={"is_uniform":True},
    K=10000
)

# Both calls receive `u` with a trailing length-1 axis appended.
# X_pre: library features evaluated on u (one column per library term, as
# relied on by the later `X_pre.T` lookup table).
# y_pre: weak-form counterpart of the time derivative of u, used as the
# regression target.
X_pre = np.array(weak_lib.fit_transform(np.expand_dims(u, -1)))
y_pre = weak_lib.convert_u_dot_integral(np.expand_dims(u, -1))

In [5]:
# Each library term is identified by a module (polynomial_power, derivative_order).
base_poly = np.array([[power, 0] for power in range(1, n_poly + 1)])
base_derivative = np.array([[0, order] for order in range(1, n_derivatives + 1)])

# Build the module list in the order the library columns are assumed to follow:
# optional bias, pure polynomials, pure derivatives, then mixed terms with the
# derivative order as the outer loop.
modules = []
if weak_lib.include_bias:
    modules.append((0, 0))
modules.extend((power, 0) for power in range(1, n_poly + 1))
modules.extend((0, order) for order in range(1, n_derivatives + 1))
modules.extend(tuple(poly + deriv)
               for deriv in base_derivative
               for poly in base_poly)

# Only the count is checked against the library; the ordering itself is assumed.
assert len(modules) == len(weak_lib.get_feature_names())

# Lookup table: module -> its feature column; regression target is a copy of y_pre.
base_features = {module: column for module, column in zip(modules, X_pre.T)}
u_t = y_pre.copy()

In [6]:
class MyProblem(ElementwiseProblem):
    """Two-objective problem that scores one genome at a time.

    A genome is a collection of ``(polynomial_power, derivative_order)``
    modules, each of which is a key of ``base_features``. The two objectives
    written to ``out["F"]`` are:

    1. the least-squares error of regressing ``u_t`` on the selected feature
       columns, divided by the number of entries in ``u_t``;
    2. ``epsilon * len(genome)``, a penalty proportional to the genome size.
    """

    def __init__(self, n_poly, n_derivatives, n_modules,
                 base_features, u_t, epsilon=0):
        """Store the search-space limits, the feature table and the target.

        Args:
            n_poly: upper bound for the polynomial power of a random module.
            n_derivatives: upper bound for the derivative order of a random module.
            n_modules: maximum number of modules drawn for an initial genome.
            base_features: mapping from module tuple to its feature column.
            u_t: regression target.
            epsilon: weight of the genome-size penalty (second objective).
        """
        super().__init__(n_var=1, n_obj=2, n_ieq_constr=0)
        self.n_poly = n_poly
        self.n_derivatives = n_derivatives
        self.n_modules = n_modules
        self.base_features = base_features
        self.u_t = u_t
        self.epsilon = epsilon
        self.sample_size = np.prod(self.u_t.shape)

    def _evaluate(self, X, out, *args, **kwargs):
        """Evaluate the single genome stored in ``X[0]`` and fill ``out["F"]``."""
        genome = X[0]
        _, squared_error = self.compute_genome_coefficient(genome)
        mse = squared_error / self.sample_size
        complexity_penalty = self.epsilon * len(genome)
        out["F"] = [mse, complexity_penalty]

    def numericalize_genome(self, genome):
        """Stack the feature columns of the genome's modules along the last axis."""
        return np.stack([self.base_features[tuple(module)]
                         for module in genome], axis=-1)

    def compute_genome_coefficient(self, genome):
        """Fit ``u_t`` on the genome's features by least squares.

        Returns:
            tuple: ``(coeff, error)`` where ``coeff`` is the least-squares
            solution and ``error`` is the sum of squared residuals (first
            target column).
        """
        features = self.numericalize_genome(genome)
        features = features.reshape(-1, features.shape[-1])
        coeff, residuals, _, _ = np.linalg.lstsq(features, self.u_t, rcond=None)
        if residuals.size > 0:
            return coeff, residuals[0]
        # lstsq returns an EMPTY residual array when the feature matrix is
        # rank-deficient (e.g. collinear modules) or not over-determined;
        # indexing it would raise IndexError, so compute the residual directly.
        squared = np.abs(features @ coeff - self.u_t) ** 2
        return coeff, np.atleast_1d(np.sum(squared, axis=0))[0]

    def generate_module(self, n_poly, n_derivatives):
        """Draw a random module with power in [0, n_poly] and order in [0, n_derivatives]."""
        return (random.randint(0, n_poly), random.randint(0, n_derivatives))
    
class PopulationSampling(Sampling):
    """Initial-population sampler producing distinct random genomes."""

    def _do(self, problem, n_samples, **kwargs):
        """Return an ``(n_samples, 1)`` object array of unique frozenset genomes.

        Each candidate is built from between 1 and ``problem.n_modules`` random
        module draws; repeated draws collapse inside the frozenset, and
        candidates already produced are discarded and redrawn.
        """
        population = np.full((n_samples, 1), None, dtype=object)
        seen = set()
        filled = 0
        while filled < n_samples:
            draws = random.randint(1, problem.n_modules)
            candidate = frozenset(
                problem.generate_module(problem.n_poly, problem.n_derivatives)
                for _ in range(draws)
            )
            # Skip empty or already-used genomes and draw again.
            if len(candidate) == 0 or candidate in seen:
                continue
            seen.add(candidate)
            population[filled, 0] = candidate
            filled += 1
        return population
    
class DuplicateElimination(ElementwiseDuplicateElimination):
    """Treats two individuals as duplicates when their genomes compare equal."""

    def is_equal(self, g1, g2):
        """Compare the genome stored in the first variable of each individual."""
        first_genome = g1.X[0]
        second_genome = g2.X[0]
        return first_genome == second_genome

In [7]:
class GenomeCrossover(Crossover):
    """Crossover that pools the modules of two parents and re-deals them."""

    def __init__(self):
        # Two parents go in, two offspring come out.
        super().__init__(2, 2)

    def _do(self, problem, X, **kwargs):
        """Produce offspring for every mating.

        ``X`` has shape ``(n_parents, n_matings, n_var)``. Because the number
        of offspring equals the number of parents, the result has the same
        shape as ``X``.
        """
        _, n_matings, _ = X.shape
        offspring = np.full_like(X, None, dtype=object)

        for mating in range(n_matings):
            parent_a = X[0, mating, 0]
            parent_b = X[1, mating, 0]
            child_a, child_b = self.crossover_permutation(parent_a, parent_b)
            offspring[0, mating, 0] = child_a
            offspring[1, mating, 0] = child_b

        return offspring

    def crossover_permutation(self, genome1, genome2):
        """Shuffle the combined modules and split them at ``len(genome1)``.

        Modules present in both parents appear twice in the pool, so a child
        may end up with fewer distinct modules than its slice length.
        """
        pool = [*genome1, *genome2]
        random.shuffle(pool)
        cut = len(genome1)
        return frozenset(pool[:cut]), frozenset(pool[cut:])
    
class GenomeMutation(Mutation):
    """Mutation that may add, delete and/or replace one module of a genome.

    The three operations are tried independently and in that order, each with
    its own probability.
    """

    def __init__(self, add_rate=0.4, del_rate=0.5, order_rate=0.4):
        """Store the per-operation probabilities.

        Args:
            add_rate: probability of applying ``add_mutate``.
            del_rate: probability of applying ``del_mutate``.
            order_rate: probability of applying ``module_mutate``.
        """
        super().__init__()
        self.add_rate = add_rate
        self.del_rate = del_rate
        self.order_rate = order_rate

    def _do(self, problem, X, **kwargs):
        """Mutate the genome in column 0 of every row of ``X`` in place and return ``X``."""
        for i in range(len(X)):
            if random.random() < self.add_rate:
                X[i, 0] = self.add_mutate(problem, X[i, 0])
            if random.random() < self.del_rate:
                X[i, 0] = self.del_mutate(problem, X[i, 0])
            if random.random() < self.order_rate:
                X[i, 0] = self.module_mutate(problem, X[i, 0])
        return X

    def add_mutate(self, problem, genome, max_iter=3):
        """Add one random module not yet in the genome.

        Up to ``max_iter`` draws are made; if all of them are already present
        the genome is returned unchanged.
        """
        for _ in range(max_iter):
            new_module = problem.generate_module(problem.n_poly, problem.n_derivatives)
            if new_module not in genome:
                return genome.union(frozenset({new_module}))
        return genome

    def del_mutate(self, problem, genome, max_iter=3):
        """Remove one random module, never leaving the genome empty.

        A single-module genome is instead replaced by a different random
        module (up to ``max_iter`` draws); if none differs it is kept as is.
        """
        genome = list(genome)
        lg = len(genome)
        if lg > 0:
            if lg == 1:
                for _ in range(max_iter):
                    new_module = problem.generate_module(problem.n_poly, problem.n_derivatives)
                    if new_module != genome[0]:
                        return frozenset({new_module})
            else:
                genome.pop(random.randint(0, lg-1))
        return frozenset(genome)

    def module_mutate(self, problem, genome, max_iter=3):
        """Replace one random module by a freshly drawn one.

        One module is removed, then up to ``max_iter`` draws are made for a
        module not among the remaining ones. If every draw collides, the
        genome is returned one module shorter. The retry count used to be a
        hard-coded 3; it is now a parameter like in the sibling methods, with
        the same default.
        """
        if len(genome) == 0:
            return genome
        remaining = set(genome)
        remaining.remove(random.choice(list(remaining)))
        for _ in range(max_iter):
            new_module = problem.generate_module(problem.n_poly, problem.n_derivatives)
            if new_module not in remaining:
                remaining.add(new_module)
                break
        # An empty result is only reachable with max_iter < 1; keep the
        # original genome then, since an empty genome cannot be evaluated.
        if not remaining:
            return frozenset(genome)
        return frozenset(remaining)

In [8]:
# Calibrate the complexity weight from a random population scored with epsilon = 0.
pop_size = 300
problem = MyProblem(n_poly, n_derivatives, n_modules, base_features, u_t, 0)

# Sample genomes and unwrap each individual into a one-variable row.
pop = PopulationSampling().do(problem, pop_size)
pop = [[individual.X[0]] for individual in pop]

# Column 0 of the evaluation is the first objective (the mean squared error);
# `epi` is 10 raised to the decimal exponent of its median.
initial_mse = problem.evaluate(pop)[:, 0]
epi = 10 ** sci_format(np.median(initial_mse))[1]
del pop, initial_mse

In [9]:
# Rebuild the problem with the calibrated complexity weight.
problem = MyProblem(n_poly, n_derivatives, n_modules, base_features, u_t, epi)

algorithm = NSGA2(
    pop_size=pop_size,
    sampling=PopulationSampling(),
    crossover=GenomeCrossover(),
    mutation=GenomeMutation(),
    eliminate_duplicates=DuplicateElimination(),
)

res = minimize(problem, algorithm, ('n_gen', 20), verbose=True)

# Among the returned solutions, pick the one with the smallest sum of both objectives.
best = res.X[np.argsort(res.F[:, 0] + res.F[:, 1])[0]]
best

n_gen  |  n_eval  | n_nds  |      eps      |   indicator  
     1 |      300 |      6 |             - |             -
     2 |      600 |      5 |  0.2000000000 |         nadir
     3 |      900 |      7 |  0.1666666667 |         nadir
     4 |     1200 |      8 |  0.1428571429 |         nadir
     5 |     1500 |      8 |  0.000000E+00 |             f
     6 |     1800 |      8 |  0.000000E+00 |             f
     7 |     2100 |      8 |  0.000000E+00 |             f
     8 |     2400 |      9 |  0.1250000000 |         nadir
     9 |     2700 |      9 |  2.539194E-13 |             f
    10 |     3000 |     10 |  0.1111111111 |         nadir
    11 |     3300 |     10 |  0.000000E+00 |             f
    12 |     3600 |     10 |  0.000000E+00 |             f
    13 |     3900 |     10 |  0.000000E+00 |             f
    14 |     4200 |     10 |  4.162444E-15 |             f
    15 |     4500 |     10 |  1.981791E-13 |             f
    16 |     4800 |     10 |  1.981791E-13 |            

array([frozenset({(1, 1), (0, 2)})], dtype=object)