In [1]:
import json

import sympy as sp

from mlscorecheck.scores import score_functions_without_complements
from mlscorecheck.symbolic import collect_denominators_and_bases, get_all_objects

In [2]:
scores = get_all_objects('sympy')

In [3]:
symbols = scores['acc'].symbols

In [4]:
solutions = {}

In [5]:
def doc_type(arg):
    if arg in {'tp', 'tn', 'p', 'n'}:
        return 'int'
    if arg in {'beta_negative', 'beta_positive'}:
        return 'float'
    if arg in list(scores.keys()):
        return 'float|Interval|IntervalUnion'

In [6]:
def doc_str(arg):
    if arg in {'tp', 'tn', 'p', 'n'}:
        return f'the {arg} count'
    if arg in {'beta_negative', 'beta_positive'}:
        return f'the {arg} parameter of the score'
    if arg in list(scores.keys()):
        return f'the value or interval for the score {arg}'

In [7]:
def docstring(sol_for, score, args):
    tabs = '    '
    docs = tabs + '"""\n'
    docs += tabs + f'Solves {sol_for} from the score {score}\n\n'
    docs += tabs + 'Args:\n'

    arglist = tabs + tabs
    arglist += ('\n' + tabs + tabs).join(f'{arg} ({doc_type(arg)}): {doc_str(arg)}' for arg in args)
    arglist += '\n' + tabs + tabs + 'kwargs (dict): additional keyword arguments\n'

    docs += arglist + '\n'
    docs += tabs + 'Returns:\n'
    docs += tabs + tabs + f'float|Interval|IntervalUnion: the value or interval for {sol_for}\n'
    docs += tabs + '"""\n'
    return docs

In [8]:
print(docstring('tp', 'mcc', {'mcc', 'p', 'n', 'tn'}))

    """
    Solves tp from the score mcc

    Args:
        mcc (float|Interval|IntervalUnion): the value or interval for the score mcc
        n (int): the n count
        p (int): the p count
        tn (int): the tn count
        kwargs (dict): additional keyword arguments

    Returns:
        float|Interval|IntervalUnion: the value or interval for tp
    """



In [9]:
def conditions(bases, denoms):
    code = ''
    condition = ''
    #if len(bases) > 0:
    #    condition = ' or '.join(f"is_less_than_zero({str(base[0])})" for base in bases)
    if len(denoms) > 0:
        condition = ' or '.join(f"is_zero({str(denom[0])})" for denom in denoms)

    if len(denoms) > 0:
        code += f'    if {condition}:\n'
        code += '        return None\n'

    return code

In [10]:
all_code = ''
all_names = []
results = {}

for score in scores:
    if score not in score_functions_without_complements:
        continue
    tps = [expr for expr in sp.solve(scores[score].equation_polynomial, symbols.tp)]
    tns = [expr for expr in sp.solve(scores[score].equation_polynomial, symbols.tn)]

    results[score] = {'tp': tps, 'tn': tns}

    names = []

    for idx, tp in enumerate(tps):
        args = symbols.algebra.args(tp)
        args = [str(arg) for arg in args]
        denoms, bases = collect_denominators_and_bases(tp, algebra=symbols.algebra)
        denoms = set(denoms)
        bases = set(bases)

        name = f'{score}_tp' if len(tps) == 1 else f'{score}_tp_{idx}'
        names.append(name)
        all_names.append(name)

        code = f"def {name}(*, {', '.join(args)}, **kwargs):\n"
        code += docstring('tp', score, args)
        code += '    _ = kwargs\n'
        code += conditions(bases, denoms)
        code += f'    return {str(tp)}\n\n'
        all_code += code

    if len(tps) > 1:
        code = f'def {score}_tp(*, {", ".join(args)}, **kwargs):\n'
        all_names.append(f'{score}_tp')
        code += docstring('tp', score, args)
        code += '    _ = kwargs\n'
        arglist = ", ".join(f'{arg}={arg}' for arg in args)
        calls = ",\n                          ".join(f'{name}({arglist})' for name in names)
        code += f'    return unify_results([{calls}])\n'
        all_code += code

    names = []

    for idx, tn in enumerate(tns):
        args = symbols.algebra.args(tn)
        args = [str(arg) for arg in args]
        denoms, bases = collect_denominators_and_bases(tn, algebra=symbols.algebra)
        denoms = set(denoms)
        bases = set(bases)

        name = f'{score}_tn' if len(tns) == 1 else f'{score}_tn_{idx}'
        names.append(name)
        all_names.append(name)

        code = f"def {name}(*, {', '.join(args)}, **kwargs):\n"
        code += docstring('tn', score, args)
        code += '    _ = kwargs\n'
        code += conditions(bases, denoms)
        code += f'    return {str(tn)}\n\n'
        all_code += code

        print(code)

    if len(tns) > 1:
        code = f'def {score}_tn(*, {", ".join(args)}, **kwargs):\n'
        all_names.append(f'{score}_tn')
        code += docstring('tn', score, args)
        code += '    _ = kwargs\n'
        arglist = ", ".join(f'{arg}={arg}' for arg in args)
        calls = ",\n                          ".join(f'{name}({arglist})' for name in names)
        code += f'    return unify_results([{calls}])\n'
        all_code += code


def acc_tn(*, acc, p, tp, n, **kwargs):
    """
    Solves tn from the score acc

    Args:
        acc (float|Interval|IntervalUnion): the value or interval for the score acc
        p (int): the p count
        tp (int): the tp count
        n (int): the n count
        kwargs (dict): additional keyword arguments

    Returns:
        float|Interval|IntervalUnion: the value or interval for tn
    """
    _ = kwargs
    return acc*n + acc*p - tp




def spec_tn(*, spec, n, **kwargs):
    """
    Solves tn from the score spec

    Args:
        spec (float|Interval|IntervalUnion): the value or interval for the score spec
        n (int): the n count
        kwargs (dict): additional keyword arguments

    Returns:
        float|Interval|IntervalUnion: the value or interval for tn
    """
    _ = kwargs
    return n*spec


def ppv_tn(*, tp, ppv, n, **kwargs):
    """
    Solves tn from the score ppv

    Args:
        tp (int): the tp count
        ppv (float|Interval|IntervalUnion): the value or interval for the score ppv
        n (int): the n count
        kwargs (dict): additional keyword arguments

    Returns:
        float|Interval|IntervalUnion: the value or interval for tn
    """
    _ = kwargs
    if is_zero(ppv):
        return None
    return n + tp - tp/ppv


def npv_tn(*, p, tp, npv, **kwargs):
    """
    Solves tn from the score npv

    Args:
        p (int): the p count
        tp (int): the tp count
        npv (

In [11]:
preface = ''
preface += '"""\n'
preface += 'This module contains the tp and tn solutions.\nThis is a generated file, do not edit.\n'
preface += '"""\n\n'
preface += '# pylint: disable=line-too-long\n'
preface += '# pylint: disable=too-many-lines\n\n'
preface += 'from ._utils import is_zero, unify_results\n'
preface += 'from ._interval import sqrt\n\n'

names = '__all__ = [\n'
names += ',\n'.join(f'"{name}"' for name in all_names)
names += ']\n\n'
preface += names

In [12]:
all_code = preface + all_code

In [13]:
print(all_code)

"""
This module contains the tp and tn solutions.
This is a generated file, do not edit.
"""

# pylint: disable=line-too-long
# pylint: disable=too-many-lines

from ._utils import is_zero, unify_results
from ._interval import sqrt

__all__ = [
"acc_tp",
"acc_tn",
"sens_tp",
"spec_tn",
"ppv_tp",
"ppv_tn",
"npv_tp",
"npv_tn",
"fbp_tp",
"fbp_tn",
"f1p_tp",
"f1p_tn",
"fbn_tp",
"fbn_tn",
"f1n_tp",
"f1n_tn",
"gm_tp",
"gm_tn",
"fm_tp_0",
"fm_tp_1",
"fm_tp",
"fm_tn",
"upm_tp_0",
"upm_tp_1",
"upm_tp",
"upm_tn_0",
"upm_tn_1",
"upm_tn",
"mk_tp_0",
"mk_tp_1",
"mk_tp",
"mk_tn_0",
"mk_tn_1",
"mk_tn",
"lrp_tp",
"lrp_tn",
"lrn_tp",
"lrn_tn",
"bm_tp",
"bm_tn",
"pt_tp_0",
"pt_tp_1",
"pt_tp",
"pt_tn_0",
"pt_tn_1",
"pt_tn",
"dor_tp",
"dor_tn",
"ji_tp",
"ji_tn",
"bacc_tp",
"bacc_tn",
"kappa_tp",
"kappa_tn",
"p4_tp_0",
"p4_tp_1",
"p4_tp",
"p4_tn_0",
"p4_tn_1",
"p4_tn",
"mcc_tp_0",
"mcc_tp_1",
"mcc_tp",
"mcc_tn_0",
"mcc_tn_1",
"mcc_tn"]

def acc_tp(*, acc, p, tn, n, **kwargs):
    """
    Solves tp from the 

In [14]:
with open('_tptn_solutions.py', 'w') as file:
    file.write(all_code)

In [15]:
print(preface + all_code)

"""
This module contains the tp and tn solutions.
This is a generated file, do not edit.
"""

# pylint: disable=line-too-long
# pylint: disable=too-many-lines

from ._utils import is_zero, unify_results
from ._interval import sqrt

__all__ = [
"acc_tp",
"acc_tn",
"sens_tp",
"spec_tn",
"ppv_tp",
"ppv_tn",
"npv_tp",
"npv_tn",
"fbp_tp",
"fbp_tn",
"f1p_tp",
"f1p_tn",
"fbn_tp",
"fbn_tn",
"f1n_tp",
"f1n_tn",
"gm_tp",
"gm_tn",
"fm_tp_0",
"fm_tp_1",
"fm_tp",
"fm_tn",
"upm_tp_0",
"upm_tp_1",
"upm_tp",
"upm_tn_0",
"upm_tn_1",
"upm_tn",
"mk_tp_0",
"mk_tp_1",
"mk_tp",
"mk_tn_0",
"mk_tn_1",
"mk_tn",
"lrp_tp",
"lrp_tn",
"lrn_tp",
"lrn_tn",
"bm_tp",
"bm_tn",
"pt_tp_0",
"pt_tp_1",
"pt_tp",
"pt_tn_0",
"pt_tn_1",
"pt_tn",
"dor_tp",
"dor_tn",
"ji_tp",
"ji_tn",
"bacc_tp",
"bacc_tn",
"kappa_tp",
"kappa_tn",
"p4_tp_0",
"p4_tp_1",
"p4_tp",
"p4_tn_0",
"p4_tn_1",
"p4_tn",
"mcc_tp_0",
"mcc_tp_1",
"mcc_tp",
"mcc_tn_0",
"mcc_tn_1",
"mcc_tn"]

"""
This module contains the tp and tn solutions.
This is a generated f

In [22]:
results_str = {}
for key, item in results.items():
    results_str[key] = {'tp': [str(tmp) for tmp in item['tp']],
                        'tn': [str(tmp) for tmp in item['tn']]}

In [23]:
with open('tptn_solutions.json', 'w') as file:
    json.dump(results_str, file)