In [None]:
import re
import itertools

import numpy as np
import pandas as pd

from tqdm import tqdm

import mph
from tools import plots

import matplotlib.pyplot as plt
import plotly.express as px
import peewee

In [None]:
# Start a COMSOL session through MPh; `client` is used to load the model below.
client = mph.start()

In [None]:
# Load the kinetic model; `model` is the handle passed to the study wrappers (Generator, Sensitivity).
model = client.load('models/Kinetic model.mph')

In [None]:
# NOTE(review): per the MPh docs, `Client.clear()` removes all models from memory, which would
# invalidate `model` loaded above. On "Run All" this runs before any cell uses `model` —
# confirm this cell is meant to be run only when finishing a session.
client.clear()

In [None]:
def compare(from_node: mph.Node, to_node: mph.Node):
    """Print a side-by-side table of the properties of two nodes.

    One row per property name (union of both nodes' properties, sorted).
    Each value is stringified and truncated to 20 characters; a property that
    one node lacks is shown as ``None``.  Rows whose stringified values differ
    start with ``*``.
    """
    left, right = from_node.properties(), to_node.properties()

    left_title = 'from ' + from_node.name()
    right_title = 'to ' + to_node.name()
    print(f"{left_title: >53} | {right_title: <53}")

    for key in sorted(set(left) | set(right)):
        left_value = str(left.get(key))
        right_value = str(right.get(key))
        marker = " " if left_value == right_value else "*"
        print(f'{marker} {key: <30} {left_value[:20]: <20} | {right_value[:20]: <20}')


def copy_settings(from_node: mph.Node, to_node: mph.Node):
    """Copy the properties of ``from_node`` onto ``to_node``.

    Only properties that ``to_node`` currently reports with a non-``None``
    value are written.  Source values whose string form is ``'auto'`` are
    written after all other values.  The whole procedure runs twice,
    re-reading both nodes each time — presumably so that properties whose
    availability depends on other settings are caught on the second pass
    (TODO confirm).
    """
    for _ in range(2):
        source = from_node.properties()
        target = to_node.properties()

        writable = [
            (key, value)
            for key, value in source.items()
            if target.get(key) is not None
        ]
        # Explicit values first, 'auto' values last.
        ordered = [pair for pair in writable if str(pair[1]) != 'auto']
        ordered += [pair for pair in writable if str(pair[1]) == 'auto']

        for key, value in ordered:
            to_node.property(name=key, value=value)


def copy_solver(from_solver: mph.Node, to_solver: mph.Node, verbose=False):
    """Copy solver settings from ``from_solver`` (and its children) to ``to_solver``.

    The solver node itself is copied first with ``copy_settings``.  Then every
    child of ``from_solver`` is matched by name to a child of ``to_solver``
    and copied the same way.  With ``verbose=True`` each child pair is printed
    with ``compare`` after copying, followed by a separator line.

    Raises:
        KeyError: if ``from_solver`` has children whose names do not exist
            under ``to_solver``.  This is checked before any child is
            written, so no child is left half-copied.
    """
    copy_settings(from_solver, to_solver)

    from_dict = {node.name(): node for node in from_solver.children()}
    to_dict = {node.name(): node for node in to_solver.children()}

    # Validate up front instead of failing mid-loop with a bare KeyError after
    # some children have already been modified.
    missing = [name for name in from_dict if name not in to_dict]
    if missing:
        raise KeyError(
            f'{to_solver.name()} has no child node(s) named {missing}; '
            f'cannot copy them from {from_solver.name()}'
        )

    for name, from_node in from_dict.items():
        to_node = to_dict[name]
        copy_settings(from_node, to_node)
        if verbose:
            compare(from_node, to_node)
            print('*' * 120)


def input_check(string):
    """Return ``string``, asking the user for a value when it is ``None``.

    When ``string`` is ``None`` the user is prompted until a non-blank answer
    is entered, and the stripped answer is returned.  Answering ``q`` returns
    ``'q'`` so the caller can treat it as "quit".  Any non-``None`` argument
    (including ``''``) is returned unchanged without prompting.
    """
    if string is None:
        # Stop at the first non-blank answer (which may be 'q'). Looping until the
        # answer is 'q' would discard every real answer and always return 'q'.
        while not string:
            string = input(f'Set {string=}, to quit - q:').strip()
    return string

In [None]:
class AbstractStudy:
    """Wrapper around one study of a COMSOL model loaded through MPh.

    Subclasses set ``study_name``.  Related nodes are looked up by name:
    ``studies/<study_name>``, ``solutions/<study_name>_Solution``,
    ``datasets/<study_name>_data`` and the ``Reaction Engineering`` physics
    node whose children define the species.
    """
    study_name = None

    def __init__(self, model: mph.Model):
        self.comsol_model = model

    @property
    def time_end(self):
        """Value of the model parameter ``time_end``."""
        return self.comsol_model.parameters()['time_end']

    @property
    def constants(self):
        """Model parameters whose name starts with ``K``."""
        return {
            key: value
            for key, value in self.comsol_model.parameters().items()
            if key.startswith('K')
        } #yapf: disable

    @property
    def initial(self):
        """Model parameters treated as initial values: names containing ``'0'``, plus ``light``."""
        return {
            key: value
            for key, value in self.comsol_model.parameters().items()
            if ('0' in key) or (key in ['light'])
        } #yapf: disable

    @property
    def species(self) -> dict:
        """Map each species name to its expression ``reaction.c_<name>``.

        Species are read from the names of the Reaction Engineering node's
        children that match ``Species: <name>``.
        """
        reaction_node_children = [
            child.name() for child in self.nodes['reaction'].children()
        ]
        species = re.findall(
            string='\n'.join(reaction_node_children),
            pattern='Species: (.*)',
        )
        return {specie: f'reaction.c_{specie}' for specie in species}

    @property
    def nodes(self):
        """Look up the COMSOL nodes this study works with.

        Returns:
            dict with keys ``study``, ``solution``, ``data`` and ``reaction``.

        Raises:
            AssertionError: if the study node or the Reaction Engineering
                node does not exist.
        """
        study_node = self.comsol_model / 'studies' / f'{self.study_name}'
        assert study_node.exists(), f'Study node does not exist: {self.study_name}'

        # NOTE(review): the solution and dataset may not exist until the study has been
        # computed, so they are not asserted here (that would break solve()); evaluate()
        # checks the dataset right before using it.
        solution_node = self.comsol_model / 'solutions' / f'{self.study_name}_Solution'
        data_node = self.comsol_model / 'datasets' / f'{self.study_name}_data'

        # Look the physics up in this study's own model, not a notebook-global `model`.
        reaction_node = self.comsol_model / 'physics' / 'Reaction Engineering'
        assert reaction_node.exists(), 'Reaction node does not exist'

        return {
            'study': study_node,
            'solution': solution_node,
            'data': data_node,
            'reaction': reaction_node,
        }

    def set_parametrs(self, **parameters):
        """Set global model parameters, e.g. ``set_parametrs(light=1e-3)``."""
        for key, value in parameters.items():
            self.comsol_model.parameter(name=key, value=value)

    @staticmethod
    def _set_node_properties(node: mph.Node, **properties):
        """Set each keyword argument as a property of ``node``."""
        for key, value in properties.items():
            node.property(key, value)

    def evaluate(
        self,
        functions: dict,
        outer_number=1,
    ) -> pd.DataFrame:
        """Evaluate expressions on this study's dataset.

        Args:
            functions: mapping column name -> COMSOL expression.  A ``time``
                column evaluating ``t`` is appended; the caller's dict is not
                modified.
            outer_number: passed to ``mph.Model.evaluate`` as ``outer``
                (presumably the outer/parametric solution index).

        Returns:
            DataFrame with one column per entry of ``functions`` plus ``time``.

        Raises:
            AssertionError: if the study's dataset node does not exist.
        """
        # Build a new dict rather than functions.update(...): mutating the argument
        # leaked the 'time' entry into the caller's dict and into mutable defaults.
        columns = {**functions, 'time': 't'}

        data_node = self.nodes['data']
        assert data_node.exists(), f'Data node does not exist: {self.study_name}_data'

        raw_data = self.comsol_model.evaluate(
            list(columns.values()),
            dataset=data_node.name(),
            outer=outer_number,
        )
        return pd.DataFrame(raw_data, columns=list(columns))

    def solve(self):
        """Run this study via ``mph.Model.solve`` using the study node's name."""
        self.comsol_model.solve(study=self.nodes['study'].name())


In [None]:
class Generator(AbstractStudy):
    """Wrapper for the COMSOL study named ``Generator``."""
    study_name = 'Generator'

    def evaluate(self, functions=None) -> pd.DataFrame:
        """Evaluate ``functions`` on the first outer solution of the study.

        Args:
            functions: mapping column name -> COMSOL expression.  ``None``
                (the default) evaluates no extra expressions, leaving only
                whatever the base class adds (``time``).
        """
        # A fresh dict per call instead of a shared mutable `{}` default, so nothing
        # written into the argument can leak from one call into the next.
        if functions is None:
            functions = {}
        return super().evaluate(outer_number=1, functions=functions)


In [None]:
def make_comdinations(diap: dict):
    """Expand a parameter grid into a list of parameter dicts (Cartesian product).

    ``diap`` maps each parameter name to an iterable of values.  The result
    has one ``{name: value, ...}`` dict per combination, in
    ``itertools.product`` order (the last key varies fastest).  An empty grid
    yields ``[{}]``.
    """
    names = list(diap)
    return [
        dict(zip(names, combo))
        for combo in itertools.product(*(diap[name] for name in names))
    ]


def sweep(combinations, generator: Generator):
    """Solve ``generator`` once per parameter combination and stack the results.

    Args:
        combinations: iterable of ``{parameter: value}`` dicts, e.g. from
            ``make_comdinations``.  Each dict is applied with
            ``generator.set_parametrs`` before solving and is not reset
            afterwards, so the study keeps the last combination's values.
        generator: the study wrapper to solve and evaluate.

    Returns:
        One DataFrame with the evaluated species (plus ``time``) of every run,
        and one extra column per swept parameter holding that run's value.
        An empty DataFrame if ``combinations`` is empty.
    """
    result = []
    for combination in tqdm(iterable=combinations):
        generator.set_parametrs(**combination)
        generator.solve()
        # `species` is a property returning a dict, not a method.
        run = generator.evaluate(generator.species)
        # Tag every row of this run with the swept parameter values.
        result.append(run.assign(**combination))
    if not result:
        # pd.concat([]) raises ValueError.
        return pd.DataFrame()
    return pd.concat(result)



In [None]:
class Sensitivity(AbstractStudy):
    """Wrapper for the COMSOL study ``Sensitivity`` and its ``Sensitivity`` feature.

    Parameters are read from and written to the ``pname`` / ``initval``
    properties of the study's ``Sensitivity`` node, not the global model
    parameters.
    """
    study_name = 'Sensitivity'

    @property
    def sensitivities(self):
        """Map each ``K``-parameter on the Sensitivity node to ``fsens(<name>)``."""
        return {key: f'fsens({key})' for key in self.constants}

    @property
    def nodes(self):
        """Base-class nodes plus ``sensitivity``: the study's Sensitivity feature."""
        nodes_dict = super().nodes

        sensitivity_node = nodes_dict['study'] / 'Sensitivity'
        assert sensitivity_node.exists(), 'Sensitivity node does not exist'

        nodes_dict.update({'sensitivity': sensitivity_node})
        return nodes_dict

    def _sensitivity_parameters(self, node=None):
        """Return ``{pname: initval}`` for every parameter listed on the Sensitivity node."""
        if node is None:
            node = self.nodes['sensitivity']
        properties = node.properties()
        return dict(zip(properties['pname'], properties['initval']))

    @property
    def constants(self):
        """Sensitivity-node parameters whose name starts with ``K``."""
        return {
            key: value
            for key, value in self._sensitivity_parameters().items()
            if key.startswith('K')
        } #yapf: disable

    @property
    def initial(self):
        """Sensitivity-node parameters whose name contains ``'0'``, plus ``light``."""
        return {
            key: value
            for key, value in self._sensitivity_parameters().items()
            if ('0' in key) or (key in ['light'])
        } #yapf: disable

    def set_parametrs(self, **parameters):
        """Change values of parameters already listed on the Sensitivity node.

        Every listed ``pname``/``initval`` pair is written back, keeping its
        order, with the given values substituted (stringified).

        Raises:
            AssertionError: if a name in ``parameters`` is not listed on the node.
        """
        node = self.nodes['sensitivity']
        # Start from *all* listed pairs: rebuilding from constants + initial only would
        # silently drop parameters that match neither name rule on write-back.
        all_parameters = self._sensitivity_parameters(node)

        unknown = [key for key in parameters if key not in all_parameters]
        assert not unknown, f'Parameters do not exist: {unknown}'

        all_parameters.update(parameters)
        node.property(
            name='pname',
            value=list(all_parameters),
        )
        node.property(
            name='initval',
            value=[str(value) for value in all_parameters.values()],
        )

    @property
    def target(self):
        """Sensitivity objective (first ``optobj`` entry) with ``comp.reaction.c_`` removed."""
        result = self.nodes['sensitivity'].properties()['optobj'][0]
        return result.replace('comp.reaction.c_', '')

    def set_target(self, target: str):
        """Make species ``target`` the sensitivity objective.

        Raises:
            AssertionError: if ``target`` is not one of ``self.species``.
        """
        assert f'reaction.c_{target}' in self.species.values(), 'Target is not specie'
        self.nodes['sensitivity'].property(
            name='optobj',
            value=[f'comp.reaction.c_{target}'],
        )


In [None]:
# TODO: out of found parametrs
class Estimator(AbstractStudy):
    """Wrapper for the COMSOL study ``Estimator``.

    Experiments are children of the Reaction Engineering ``Estimation``
    feature.  Each experiment gets a model table labelled
    ``Experiment <i> Table``.
    """
    study_name = 'Estimator'

    @property
    def nodes(self):
        """Base-class nodes plus ``estimation``: the Reaction Engineering Estimation feature."""
        nodes_dict = super().nodes

        estimation_node = self.comsol_model / 'physics' / 'Reaction Engineering' / 'Estimation'
        assert estimation_node.exists(), 'Estimation node does not exist'

        nodes_dict.update({'estimation': estimation_node})
        return nodes_dict

    @property
    def experiments(self):
        """Child nodes of the Estimation feature."""
        # The key registered by `nodes` is 'estimation' ('estimation_node' raised KeyError).
        return self.nodes['estimation'].children()

    @property
    def tables(self):
        """Model tables whose name contains ``Experiment``."""
        tables = self.comsol_model / 'tables'
        return [
            node for node in tables.children() if 'Experiment' in node.name()
        ]

    def create_one_experiment(
        self,
        data,
        data_columns,
        experiment_i,
        path=r'D:\WORKS\COMSOL_polymers\Batch\generator_out_short.csv',
    ):
        """Create experiment ``exp<experiment_i>`` plus its data table.

        Args:
            data: passed to the new table's ``setTableData``.
                NOTE(review): this is a Java call; confirm it accepts what
                ``create_experiments`` passes (a pandas DataFrame), or convert
                to a 2-D array first.
            data_columns: column names of ``data``; each must be ``time``,
                ``Time`` or a key of ``self.species``.
            experiment_i: index used in the experiment tag and table label.
            path: value for the experiment's ``fileName`` property.
                NOTE(review): the default is an absolute local Windows path.

        Raises:
            KeyError: if a data column has no matching model variable
                (checked before any node is created).
        """
        experiment_name = f'exp{experiment_i}'

        # Map data columns to model variables before creating anything, so an unknown
        # column fails without leaving a half-built experiment/table behind.
        # 'time' (the column name AbstractStudy.evaluate produces) is accepted too.
        variables_dict = {'Time': 't', 'time': 't'}
        variables_dict.update(self.species)
        data_columns = list(data_columns)
        variables = [variables_dict[key] for key in data_columns]

        # create experiment under the Estimation feature
        self.nodes['estimation'].java.create(
            experiment_name,
            "Experiment",
            -1,
        )
        # assumes the newly created experiment is the last child — TODO confirm
        experiment = self.experiments[-1]

        # create table
        table_tag = f"tbl_compreactionest1{experiment_name}"
        table = (self.comsol_model / 'tables').java.create(table_tag, "Table")
        table.label(f"Experiment {experiment_i} Table")
        table.setTableData(data)
        table.active(False)

        Estimator._set_node_properties(
            node=experiment,
            fileName=path,
            dataColumn=data_columns,
            use=[1] * len(data_columns),
            modelVariable=variables,
        )

    def create_experiments(self, datas: list[pd.DataFrame]):
        """Create one experiment per DataFrame in ``datas``, numbered from 0."""
        for i, data in enumerate(datas):
            self.create_one_experiment(
                data=data,
                data_columns=data.columns,
                experiment_i=i,
                path=r'./generator_out_short.csv',
            )

    def clear_experiments(self):
        """Remove all ``Experiment`` tables, then all experiments of the Estimation feature."""
        experiments, tables = self.experiments, self.tables
        for table in tables:
            table.remove()
        for experiment in experiments:
            experiment.remove()

    def solve(self):
        """Solve the study with the Estimation feature on, switching it off afterwards."""
        self.nodes['estimation'].toggle('on')
        try:
            # Delegate to AbstractStudy.solve; calling self.solve() here recursed forever.
            super().solve()
        finally:
            self.nodes['estimation'].toggle('off')

# Generator

In [None]:
gen = Generator(model)

In [None]:
gen.constants

In [None]:
gen.solve()

In [None]:
species = {
    key: value
    for key,
    value in gen.species.items()
    if key in ['Q', 'DH', 'QHH']
}
df = gen.evaluate(species)

In [None]:
df

In [None]:
df = df/1000
df['time'] = df['time']*1000

In [None]:
def simple_temporal_plot(df: pd.DataFrame):
    """Plot every column of ``df`` except ``time`` and ``light`` against ``time``.

    Returns the plotly figure: 500 px high, zero margins, legend placed at
    ``x=-0.1, y=1`` with a centered x-anchor.
    """
    excluded = ('time', 'light')
    y_columns = [column for column in df.columns if column not in excluded]
    fig = px.line(df, x='time', y=y_columns)
    fig.update_layout(
        height=500,
        margin=dict(r=0, l=0, t=0, b=0),
        legend={'x': -0.1, 'y': 1, 'xanchor': 'center'},
    )
    return fig

In [None]:
# Plot the rescaled species concentrations from the Generator study against time.
simple_temporal_plot(df)


# Sweep

In [None]:
gen.constants

In [None]:
a=make_comdinations({
    'light':(np.linspace(0.1, 5.1, 11)/1000).round(7)
})

In [None]:
df= sweep(a,gen)

In [None]:
df

# Sensitivity

In [None]:
sens = Sensitivity(model)

In [None]:
# Optional: make QHH the sensitivity objective. Left disabled, so the objective
# already stored on the Sensitivity node is used.
# sens.set_target('QHH')

In [None]:
# Override the value of Kdisp listed on the study's Sensitivity node (pname/initval).
sens.set_parametrs(Kdisp=1e9)

In [None]:
sens.solve()

In [None]:
# Expressions fsens(<name>) for every K-parameter listed on the Sensitivity node.
sens.sensitivities

In [None]:
# NOTE(review): the column is labelled 'Kc' but evaluates fsens(Ke) — confirm which
# constant was intended.
sens.evaluate({'Kc': 'fsens(Ke)'})

# Copy solver settings from Generator to Sensitivity

In [None]:
mph.tree(from1)

In [None]:
mph.tree(to1)

In [None]:
from1= gen.nodes['solution']/'Time-Dependent Solver 1'

In [None]:
to1 = sens.nodes['solution']/'Time-Dependent Solver 1'

In [None]:
compare(from_node=from1,to_node=to1)

In [None]:
copy_solver(from_solver=from1,to_solver=to1)

In [None]:
model.save()

# Plots

In [None]:
# NOTE(review): redundant — plotly.express is already imported in the first cell.
import plotly.express as px

In [None]:
# NOTE(review): get_solves and collect_dfs are not defined in this notebook, so this
# cell raises NameError on a fresh kernel; `result` is reassigned two cells below anyway.
rule = {'Kc': (0, 0.30000)}
diap = ['Kc', 'light']
datas, dfs = get_solves(rule)
result = collect_dfs(datas, dfs, diap)


In [None]:
# NOTE(review): df_gen must be defined before this point; the plot below needs its
# 'time', 'DH' and 'light' columns.
df_gen

In [None]:
# Animate DH against time, one frame per 'light' value; y-axis fixed to [0, 1].
result =df_gen
fig = px.line(
    result,
    x="time",
    y="DH",
    animation_frame='light',
    # color='Kc',
    range_y=[0, 1],
)
fig.update_layout(
    height=500,
    margin={
        'r': 0, 'l': 0, 't': 0, 'b': 0
    },
    legend=dict(x=-0.1, y=1, xanchor="left"),
)

# Estimator

In [None]:
# est =Estimator(model)