In [6]:
from typing import List, Dict, Tuple, Callable
from functools import reduce

import torch
import copy
import numpy as np

In [7]:
eq1 = {
    'd^3x/dt^3':{
        'coeff': 1,
        'term': [0, 0, 0],
        'pow': 1,
        'var': [0]
    },
    '-x*alpha':{
        'coeff': -20,
        'term': [None],
        'pow': 1,
        'var': [0]
    },
    '+beta*x*y':{
        'coeff': 20,
        'term': [[None], [None]],
        'pow': [1, 1],
        'var': [0, 1]
    }
}

eq2 = {
    'd^3y/dt^3':{
        'coeff': 1,
        'term': [0, 0, 0],
        'pow': 1,
        'var': [1]
    },
    '-y*alpha':{
        'coeff': -20,
        'term': [None],
        'pow': 1,
        'var': [1]
    },
    '+beta*x*y':{
        'coeff': 20,
        'term': [[None], [None]],
        'pow': [1, 1],
        'var': [0, 1]
    }
}

In [8]:
grid = np.linspace(0, 10, 100)
digits = np.floor(np.log10((grid[1] - grid[0])/2.)-1)
print(digits)

grid_rounded = {np.round(grid_val, -int(digits)): idx for idx, grid_val in np.ndenumerate(grid)}
print(grid_rounded)
grid_rounded[0.000]

-3.0
{0.0: (0,), 0.101: (1,), 0.202: (2,), 0.303: (3,), 0.404: (4,), 0.505: (5,), 0.606: (6,), 0.707: (7,), 0.808: (8,), 0.909: (9,), 1.01: (10,), 1.111: (11,), 1.212: (12,), 1.313: (13,), 1.414: (14,), 1.515: (15,), 1.616: (16,), 1.717: (17,), 1.818: (18,), 1.919: (19,), 2.02: (20,), 2.121: (21,), 2.222: (22,), 2.323: (23,), 2.424: (24,), 2.525: (25,), 2.626: (26,), 2.727: (27,), 2.828: (28,), 2.929: (29,), 3.03: (30,), 3.131: (31,), 3.232: (32,), 3.333: (33,), 3.434: (34,), 3.535: (35,), 3.636: (36,), 3.737: (37,), 3.838: (38,), 3.939: (39,), 4.04: (40,), 4.141: (41,), 4.242: (42,), 4.343: (43,), 4.444: (44,), 4.545: (45,), 4.646: (46,), 4.747: (47,), 4.848: (48,), 4.949: (49,), 5.051: (50,), 5.152: (51,), 5.253: (52,), 5.354: (53,), 5.455: (54,), 5.556: (55,), 5.657: (56,), 5.758: (57,), 5.859: (58,), 5.96: (59,), 6.061: (60,), 6.162: (61,), 6.263: (62,), 6.364: (63,), 6.465: (64,), 6.566: (65,), 6.667: (66,), 6.768: (67,), 6.869: (68,), 6.97: (69,), 7.071: (70,), 7.172: (71,), 7.27

(0,)

In [9]:
grid

array([ 0.        ,  0.1010101 ,  0.2020202 ,  0.3030303 ,  0.4040404 ,
        0.50505051,  0.60606061,  0.70707071,  0.80808081,  0.90909091,
        1.01010101,  1.11111111,  1.21212121,  1.31313131,  1.41414141,
        1.51515152,  1.61616162,  1.71717172,  1.81818182,  1.91919192,
        2.02020202,  2.12121212,  2.22222222,  2.32323232,  2.42424242,
        2.52525253,  2.62626263,  2.72727273,  2.82828283,  2.92929293,
        3.03030303,  3.13131313,  3.23232323,  3.33333333,  3.43434343,
        3.53535354,  3.63636364,  3.73737374,  3.83838384,  3.93939394,
        4.04040404,  4.14141414,  4.24242424,  4.34343434,  4.44444444,
        4.54545455,  4.64646465,  4.74747475,  4.84848485,  4.94949495,
        5.05050505,  5.15151515,  5.25252525,  5.35353535,  5.45454545,
        5.55555556,  5.65656566,  5.75757576,  5.85858586,  5.95959596,
        6.06060606,  6.16161616,  6.26262626,  6.36363636,  6.46464646,
        6.56565657,  6.66666667,  6.76767677,  6.86868687,  6.96

In [10]:
np.round(grid, -int(digits))

array([ 0.   ,  0.101,  0.202,  0.303,  0.404,  0.505,  0.606,  0.707,
        0.808,  0.909,  1.01 ,  1.111,  1.212,  1.313,  1.414,  1.515,
        1.616,  1.717,  1.818,  1.919,  2.02 ,  2.121,  2.222,  2.323,
        2.424,  2.525,  2.626,  2.727,  2.828,  2.929,  3.03 ,  3.131,
        3.232,  3.333,  3.434,  3.535,  3.636,  3.737,  3.838,  3.939,
        4.04 ,  4.141,  4.242,  4.343,  4.444,  4.545,  4.646,  4.747,
        4.848,  4.949,  5.051,  5.152,  5.253,  5.354,  5.455,  5.556,
        5.657,  5.758,  5.859,  5.96 ,  6.061,  6.162,  6.263,  6.364,
        6.465,  6.566,  6.667,  6.768,  6.869,  6.97 ,  7.071,  7.172,
        7.273,  7.374,  7.475,  7.576,  7.677,  7.778,  7.879,  7.98 ,
        8.081,  8.182,  8.283,  8.384,  8.485,  8.586,  8.687,  8.788,
        8.889,  8.99 ,  9.091,  9.192,  9.293,  9.394,  9.495,  9.596,
        9.697,  9.798,  9.899, 10.   ])

In [11]:
def get_terms_der_order(equation: Dict, variable_idx: int) -> np.ndarray:
    '''
    Get the highest orders of the ``variable_idx``-th variable derivative in the equation terms.
    '''
    term_max_ord = np.zeros(len(equation))
    for term_idx, term_dict in enumerate(equation.values()):
        # print(term_dict['term'])
        if isinstance(term_dict['var'], list) and len(term_dict['var']) > 1:
            max_ord = 0
            for arg_idx, deriv_ord in enumerate(term_dict['term']):
                if isinstance(term_dict['pow'][arg_idx], int) and term_dict['var'][arg_idx] == variable_idx:
                    max_ord = max(max_ord, len([var for var in deriv_ord if var is not None]))
            term_max_ord[term_idx] = max_ord
        elif isinstance(term_dict['var'], int):
            if isinstance(term_dict['pow'], int) and term_dict['var'] == variable_idx:
                term_max_ord[term_idx] = max(0, len([var for var in term_dict['term'] if var is not None]))
        elif isinstance(term_dict['var'], list) and len(term_dict['var']) == 1:
            if isinstance(term_dict['pow'], int) and term_dict['var'][0] == variable_idx:
                term_max_ord[term_idx] = max(0, len([var for var in term_dict['term'] if var is not None]))

    return term_max_ord

def get_higher_order_coeff(equation: Dict, orders: np.ndarray, var: int) -> Tuple[List]:
    def transform_term(term: Dict, deriv_key: list, var: int) -> Dict:
        term_filtered = copy.deepcopy(term)
        if (isinstance(term['var'], int) and term['var'] == var) or (isinstance(term['var'], list) 
                                                                     and len(term['var']) == 1 and term['var'][0] == var):
            term_filtered['term'] = [None,]
            term_filtered['pow'] = 0
        else:
            term_idx = [der_var for idx, der_var in enumerate(term_filtered['term']) 
                        if der_var == deriv_key and term_filtered['pow'][idx] == var][0]
            term_filtered['term'][term_idx] = [None,]
            term_filtered['pow'][term_idx] = 0
        return term_filtered            

    denom_terms = []
    numer_terms = []
    for term_idx, term in enumerate(equation.values()):
        if orders[term_idx] == np.max(orders):
            denom_terms.append(transform_term(term, deriv_key=[0,]*int(np.max(orders)), var=var))
        else:
            numer_terms.append(term)
    return [denom_terms, numer_terms]

def get_eq_order(equation, variables: List[str]):
    eq_var = 0; eq_orders = np.zeros(len(equation))
    for var_idx in range(len(variables)):
        orders = get_terms_der_order(equation=equation, variable_idx=var_idx)
        if np.max(orders) > np.max(eq_orders):
            eq_var = var_idx; eq_orders = orders
    # print(f'Calling get_higher_order_coeff for var {eq_var} with orders {eq_orders}')
    return eq_var, eq_orders

def parse_right_part(equation, eq_var, eq_orders: np.ndarray):
    denom_terms, numer_terms = get_higher_order_coeff(equation=equation, orders=eq_orders,
                                                      var=eq_var)
    return denom_terms, numer_terms

def replace_operator(term: Dict, variables: List):
    '''

    Variables have to be in form of [(0, [None]), (0, [0,]), (0, [0, 0]), (0, [0, 0, 0]), (1, [None,]), ... ]
    where the list elements are factors, taken as derivatives: (variable, differentiations), and the index in list
    matches the index of dynamics operator output.

    '''
    term_ = copy.deepcopy(term)
    if isinstance(term_['var'], list) and len(term_['var']) > 1:
        for arg_idx, deriv_ord in enumerate(term_['term']):
            term_['var'][arg_idx]  = variables.index((term_['var'][arg_idx], deriv_ord))
            term_['term'][arg_idx] = [None,]
    elif isinstance(term['var'], int) or (isinstance(term_['var'], list) and len(term_['var']) == 1):
        if isinstance(term['var'], int):
            term_var = term_['var']
        else:
            term_var = term_['var'][0]
        if isinstance(term['pow'], int):
            term_['var']  = variables.index((term_var, term_['term']))
            term_['term'] = [None,]
    return term_

In [None]:
class ImplicitEquation(object):
    def __init__(self, system: List, grid: np.ndarray, variables: List[str]):
        self.grid_dict = grid

        self._dynamics_operators = []
        var_order = []; vars_with_eqs = {}

        for var, order in [get_eq_order(equation, variables) for equation in system]:
            var_order.extend([(var, [None,])] + [(var, [0,]*(idx+1)) for idx in range(int(np.max(order))-1)])
            if len(vars_with_eqs) == 0:
                vars_with_eqs[int(np.max(order)) - 1] = (var, order)
            else:
                vars_with_eqs[list(vars_with_eqs.keys())[-1] + int(np.max(order))] = (var, order)

        for var_idx, var in enumerate(var_order):
            if var_idx in vars_with_eqs.keys():
                operator = get_higher_order_coeff(equation = system[vars_with_eqs[var_idx][0]],
                                                  orders = vars_with_eqs[var_idx][1], 
                                                  var = vars_with_eqs[var_idx][0])
                operator[0] = [replace_operator(denom_term, var_order) for denom_term in operator[0]]
                operator[1] = [replace_operator(numer_term, var_order) for numer_term in operator[1]]
            else:
                operator = [None, self.create_first_ord_eq(var_idx + 1)]
            self._dynamics_operators.append(operator)

    def __call__(self, t, y):
        values = np.empty(len(self._dynamics_operators))
        for idx, operator in enumerate(self._dynamics_operators):
            if operator[0] is None:
                denom = 1
            else:
                denom = [self.term_callable(term, t, y) for term in operator[0]]
                if np.isclose(denom, 0):
                    raise ZeroDivisionError('Denominator in the dynamics operator is close to zero.')
            numerator = [self.term_callable(term, t, y) for term in operator[1]]
            values[idx] = np.sum(numerator)/np.sum(denom)
        return values

    @property
    def grid_dict(self):
        return self._grid_rounded

    @grid_dict.setter
    def grid_dict(self, grid_points):
        self._grid_step = grid_points[1] - grid_points[0]
        digits = np.floor(np.log10(self._grid_step/2.)-1)
        self._grid_rounded = {np.round(grid_val, -int(digits)): idx 
                              for idx, grid_val in np.ndenumerate(grid_points)}

    def create_first_ord_eq(self, var: int) -> List[Tuple]:
        '''
        Example of order: np.array([3., 0., 0.]) for third ord eq. 
        '''
        return [{'coeff' : 1., 
                 'term'  : [None,],
                 'pow'   : 1,
                 'var'   : var},] # TODO: validate

    def merge_coeff(self, coeff: np.ndarray, t: float):
        try:
            return self.grid_dict[t]
        except KeyError:
            for grid_loc, grid_idx in self.grid_dict.items():
                if grid_loc < t and grid_loc + self._grid_step > t:
                    print('Search in ', grid_loc, grid_loc + self._grid_step)
                    left_loc, right_loc = grid_loc, grid_loc + self._grid_step
                    left_idx, right_idx = grid_idx[0], grid_idx[0] + 1
                    break
            val = coeff[left_idx] + (t - left_loc) / (right_loc - left_loc) * (coeff[right_idx] - coeff[left_idx])
            return val

    def term_callable(self, term: Dict, t, y):
        if isinstance(term['coeff'], Callable):
            k = term['coeff'](t)
        elif isinstance(term['coeff'], torch.nn.Sequential):
            k = term['coeff'](torch.from_numpy(t).reshape((1, 1).float()))
        elif isinstance(term['coeff'], np.ndarray):
            k = self.merge_coeff(term['coeff'], t)
        else:
            k = term['coeff']
        if isinstance(term['var'], int):
            if isinstance(term['pow'], int):
                values = [y[term['var']]**term['pow'],]
            elif isinstance(term['pow'], torch.nn.Sequential):
                values = [term['power'](y[term['var']]).detach().numpy(),]
            else:                
                values = [term['pow'](y[term['var']]),]
        elif isinstance(term['var'], list):
            values = []
            for var_idx, var in enumerate(term['var']):
                if isinstance(term['pow'], int):
                    values.append(y[var]**term['pow'][var_idx])
                elif isinstance(term['pow'], torch.nn.Sequential):
                    values.append(term['pow'][var_idx](y[var]).detach().numpy())
                    # TODO: validate on LV eqs
                else:                
                    values.append(term['pow'][var_idx](y[var]))
        return reduce(lambda x, z: x*z, values, k)

In [64]:
eq1 = {
    'd^3x/dt^3':{
        'coeff': np.sin(grid),
        'term': [0, 0],
        'pow': 1,
        'var': [0]
    },
    '-x*alpha':{
        'coeff': -20,
        'term': [None],
        'pow': 1,
        'var': [0]
    },
    '+beta*x*y':{
        'coeff': 20,
        'term': [[None], [None]],
        'pow': [1, 1],
        'var': [0, 1]
    }
}

eq2 = {
    'd^3y/dt^3':{
        'coeff': np.cos(grid),
        'term': [0, 0],
        'pow': 1,
        'var': [1]
    },
    '-y*alpha':{
        'coeff': -20,
        'term': [None],
        'pow': 1,
        'var': [1]
    },
    '+beta*x*y':{
        'coeff': 20,
        'term': [[None], [None]],
        'pow': [1, 1],
        'var': [0, 1]
    }
}

In [65]:
eq = ImplicitEquation(system = [eq1, eq2], grid = grid, variables=['x', 'y'])
print(eq1['d^3x/dt^3'])

{'coeff': array([ 0.        ,  0.10083842,  0.20064886,  0.2984138 ,  0.39313661,
        0.48385164,  0.56963411,  0.64960951,  0.72296256,  0.78894546,
        0.84688556,  0.8961922 ,  0.93636273,  0.96698762,  0.98775469,
        0.99845223,  0.99897117,  0.98930624,  0.96955595,  0.93992165,
        0.90070545,  0.85230712,  0.79522006,  0.73002623,  0.65739025,
        0.57805259,  0.49282204,  0.40256749,  0.30820902,  0.21070855,
        0.11106004,  0.01027934, -0.09060615, -0.19056796, -0.28858706,
       -0.38366419, -0.47483011, -0.56115544, -0.64176014, -0.7158225 ,
       -0.7825875 , -0.84137452, -0.89158426, -0.93270486, -0.96431712,
       -0.98609877, -0.99782778, -0.99938456, -0.99075324, -0.97202182,
       -0.94338126, -0.90512352, -0.85763861, -0.80141062, -0.73701276,
       -0.66510151, -0.58640998, -0.50174037, -0.41195583, -0.31797166,
       -0.22074597, -0.12126992, -0.0205576 ,  0.0803643 ,  0.18046693,
        0.27872982,  0.37415123,  0.46575841,  0.55261

In [66]:
eq._dynamics_operators

[[None, [{'coeff': 1.0, 'term': [None], 'pow': 1, 'var': 1}]],
 [[{'coeff': array([ 0.        ,  0.10083842,  0.20064886,  0.2984138 ,  0.39313661,
            0.48385164,  0.56963411,  0.64960951,  0.72296256,  0.78894546,
            0.84688556,  0.8961922 ,  0.93636273,  0.96698762,  0.98775469,
            0.99845223,  0.99897117,  0.98930624,  0.96955595,  0.93992165,
            0.90070545,  0.85230712,  0.79522006,  0.73002623,  0.65739025,
            0.57805259,  0.49282204,  0.40256749,  0.30820902,  0.21070855,
            0.11106004,  0.01027934, -0.09060615, -0.19056796, -0.28858706,
           -0.38366419, -0.47483011, -0.56115544, -0.64176014, -0.7158225 ,
           -0.7825875 , -0.84137452, -0.89158426, -0.93270486, -0.96431712,
           -0.98609877, -0.99782778, -0.99938456, -0.99075324, -0.97202182,
           -0.94338126, -0.90512352, -0.85763861, -0.80141062, -0.73701276,
           -0.66510151, -0.58640998, -0.50174037, -0.41195583, -0.31797166,
           -0.22

In [67]:
len(eq._dynamics_operators)

4

In [68]:
eq(0.1, np.array([1, 1, 1, 1]))

Search in  0.0 0.10101010101010101


TypeError: 'list' object is not callable