In [1]:
from mlscorecheck.symbolic import get_all_objects
from mlscorecheck.individual import load_tptn_solutions
from mlscorecheck.individual import Interval, sqrt
from mlscorecheck.symbolic import collect_denominators_and_bases
from mlscorecheck.scores import score_functions_without_complements
import sympy as sp
import json
import numpy as np

In [2]:
from mlscorecheck.individual import mcc_tp, acc_tp, acc_tn, sens_tp

In [3]:
%%timeit
mcc_tp(mcc=Interval(0.5, 0.6), p=10, n=20, tn=15)

323 µs ± 44.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]:
sens_tptn = []
sens = Interval(0.8, 0.8001)
for tn in range(270000):
    interval = sens_tp(sens=sens, p=30000)
    for tp_val in range(int(np.ceil(interval.lower_bound)), int(np.floor(interval.upper_bound))):
        sens_tptn.append((tp_val, tn))
        

In [5]:
len(sens_tptn)

810000

In [6]:
tptn = []
acc = Interval(0.5, 0.5001)
for tp in range(30000):
    interval = acc_tn(acc=acc, p=30000, n=270000, tp=tp)
    for tn_val in range(int(np.ceil(interval.lower_bound)), int(np.floor(interval.upper_bound))):
        tptn.append((tp, tn_val))
        

In [7]:
len(tptn)

900000

In [8]:
len(set(tptn).intersection(set(sens_tptn)))

90

In [9]:
solutions = load_tptn_solutions()

In [10]:
params = {'tn': 50, 'p': 200, 'n': 500, 'sqrt': sqrt, 'beta_plus': 1.0, 'beta_minus': 1.0}

In [11]:
def tptn_from_score(score, score_interval, equations, p, n):
    pairs = set()
    if len(equations['tp']) != 0:
        for tn in range(n+1):
            for equation in equations['tp']:
                tp_interval = eval(equation, {score: score_interval,
                                                'p': p,
                                                'n': n,
                                                'tn': tn,
                                                'sqrt': sqrt,
                                                'beta_plus': 1.0,
                                                'beta_minus': 1.0})
                tp_interval = tp_interval.shrink_to_integers()
                if not tp_interval.is_empty():
                    for tp in range(max(0, int(tp_interval.lower_bound)), 
                                    min(int(tp_interval.upper_bound), p) + 1):
                        pairs.add((tp, tn))
    else:
        for tp in range(p+1):
            for equation in equations['tn']:
                tn_interval = eval(equation, {score: score_interval,
                                                'p': p,
                                                'n': n,
                                                'tp': tp,
                                                'sqrt': sqrt,
                                                'beta_plus': 1.0,
                                                'beta_minus': 1.0})
                tn_interval = tn_interval.shrink_to_integers()
                if not tn_interval.is_empty():
                    for tn in range(max(0, int(tn_interval.lower_bound)), 
                                    min(int(tn_interval.upper_bound), n) + 1):
                        pairs.add((tp, tn))
    
    return pairs

In [12]:
tptn_from_score('acc', Interval(0.951, 0.959), solutions['acc'], p=40, n=100)

{(34, 100), (35, 99), (36, 98), (37, 97), (38, 96), (39, 95), (40, 94)}

In [13]:
for score, score_dict in solutions.items():
    if len(score_dict['tp']) != 0:
        for eq in score_dict['tp']:
            params[score] = Interval(0.785, 0.787)
            print(score, eval(eq, params))

mcc (15.408381865267977, 16.21687925881312)
mcc (232.70897269021557, 233.73670748078877)
acc (499.5, 500.9)
err (99.10000000000002, 100.5)
sens (157.0, 157.4)
fnr (42.599999999999994, 42.99999999999999)
ppv (1643.0232558139535, 1662.6760563380287)
fdr (121.79161372299869, 123.24840764331209)
for_ (15.116279069767442, 17.605633802816907)
npv (186.30573248407643, 186.46759847522236)
fbp (419.95884773662556, 421.7230008244023)
f1p (419.95884773662556, 421.7230008244023)
fbm (621.0292249047014, 624.5222929936306)
f1m (622.6114649681529, 622.9351969504447)
upm (-55.71634697506306, -52.66650137380833)
upm (598.2131632924059, 601.2630088936606)
gm (-1574.0, -1570.0)
gm (1570.0, 1574.0)
fm (-182.30939688310286, -181.64930377278694)
fm (305.0513037727869, 306.0257968831028)
mk (-1054.2940131314779, -1049.571510244451)
mk (214.50162460277363, 217.09656090217857)
lrp (141.3, 141.66)
lrn (184.26, 184.3)
bm (337.00000000000006, 337.4)
pt 180.0
pt (12.0406736533472, 14.65260253965679)
dor (174.76808

In [14]:
scores = get_all_objects('sympy')

In [15]:
symbols = scores['acc'].symbols

In [16]:
solutions = {}

In [17]:
def doc_type(arg):
    if arg in {'tp', 'tn', 'p', 'n'}:
        return 'int'
    if arg in {'beta_minus', 'beta_plus'}:
        return 'float'
    if arg in list(scores.keys()):
        return 'float|Interval|IntervalUnion'

In [18]:
def doc_str(arg):
    if arg in {'tp', 'tn', 'p', 'n'}:
        return f'the {arg} count'
    if arg in {'beta_minus', 'beta_plus'}:
        return f'the {arg} parameter of the score'
    if arg in list(scores.keys()):
        return f'the value or interval for the score {arg}'

In [19]:
def docstring(sol_for, score, args):
    tabs = '    '
    docs = tabs + '"""\n'
    docs += tabs + f'Solves {sol_for} from the score {score}\n\n'
    docs += tabs + 'Args:\n'
    
    arglist = tabs + tabs
    arglist += ('\n' + tabs + tabs).join(f'{arg} ({doc_type(arg)}): {doc_str(arg)}' for arg in args)
    arglist += '\n' + tabs + tabs + 'kwargs (dict): additional keyword arguments\n'
    
    docs += arglist + '\n'
    docs += tabs + 'Returns:\n'
    docs += tabs + tabs + f'float|Interval|IntervalUnion: the value or interval for {sol_for}\n'
    docs += tabs + '"""\n'
    return docs

In [20]:
print(docstring('tp', 'mcc', {'mcc', 'p', 'n', 'tn'}))

    """
    Solves tp from the score mcc

    Args:
        p (int): the p count
        mcc (float|Interval|IntervalUnion): the value or interval for the score mcc
        n (int): the n count
        tn (int): the tn count
        kwargs (dict): additional keyword arguments

    Returns:
        float|Interval|IntervalUnion: the value or interval for tp
    """



In [21]:
def conditions(bases, denoms):
    code = ''
    if len(bases) > 0:
        condition = ' or '.join(f"is_less_than_zero({str(base)})" for base in bases)
        code += f'    if {condition}:\n'
        code += '        return None\n'
    if len(denoms) > 0:
        condition = ' or '.join(f"is_zero({str(denom)})" for denom in denoms)
        code += f'    if {condition}:\n'
        code += '        return None\n'
    return code

In [22]:
all_code = ''
all_names = []

for score in scores:
    if not score in score_functions_without_complements:
        continue
    tps = [expr for expr in sp.solve(scores[score].equation_polynomial, symbols.tp)]
    tns = [expr for expr in sp.solve(scores[score].equation_polynomial, symbols.tn)]
    
    names = []
    
    for idx, tp in enumerate(tps):
        args = symbols.algebra.args(tp)
        args = [str(arg) for arg in args]
        denoms, bases = collect_denominators_and_bases(tp, algebra=symbols.algebra)
        denoms = set(denoms)
        bases = set(bases)
        
        name = f'{score}_tp' if len(tps) == 1 else f'{score}_tp_{idx}'
        names.append(name)
        all_names.append(name)
        
        code = f"def {name}(*, {', '.join(args)}, **kwargs):\n"
        code += docstring('tp', score, args)
        code += conditions(bases, denoms)
        code += f'    return {str(tp)}\n\n'
        all_code += code
    
    if len(tps) > 1:
        code = f'def {score}_tp(*, {", ".join(args)}, **kwargs):\n'
        all_names.append(f'{score}_tp')
        code += docstring('tp', score, args)
        arglist = ", ".join(f'{arg}={arg}' for arg in args)
        calls = ",\n                          ".join(f'{name}({arglist})' for name in names)
        code += f'    return unify_results([{calls}])\n'
        all_code += code
    
    names = []
    
    for idx, tn in enumerate(tns):
        args = symbols.algebra.args(tn)
        args = [str(arg) for arg in args]
        denoms, bases = collect_denominators_and_bases(tn, algebra=symbols.algebra)
        denoms = set(denoms)
        bases = set(bases)
        
        name = f'{score}_tn' if len(tns) == 1 else f'{score}_tn_{idx}'
        names.append(name)
        all_names.append(name)
        
        code = f"def {name}(*, {', '.join(args)}, **kwargs):\n"
        code += docstring('tn', score, args)
        code += conditions(bases, denoms)
        code += f'    return {str(tn)}\n\n'
        all_code += code
        
        print(code)
    
    if len(tns) > 1:
        code = f'def {score}_tn(*, {", ".join(args)}, **kwargs):\n'
        all_names.append(f'{score}_tn')
        code += docstring('tn', score, args)
        arglist = ", ".join(f'{arg}={arg}' for arg in args)
        calls = ",\n                          ".join(f'{name}({arglist})' for name in names)
        code += f'    return unify_results([{calls}])\n'
        all_code += code
        

def mcc_tn_0(*, p, n, tp, mcc, **kwargs):
    """
    Solves tn from the score mcc

    Args:
        p (int): the p count
        n (int): the n count
        tp (int): the tp count
        mcc (float|Interval|IntervalUnion): the value or interval for the score mcc
        kwargs (dict): additional keyword arguments

    Returns:
        float|Interval|IntervalUnion: the value or interval for tn
    """
    if is_less_than_zero(n) or is_less_than_zero(mcc**2*n*p + 4*p*tp - 4*tp**2) or is_less_than_zero(p):
        return None
    if is_zero(mcc**2*n + p) or is_zero(sqrt(p)):
        return None
    return (-mcc*sqrt(n)*(n + p)*sqrt(mcc**2*n*p + 4*p*tp - 4*tp**2) + n*sqrt(p)*(mcc**2*n - mcc**2*p + 2*mcc**2*tp + 2*p - 2*tp))/(2*sqrt(p)*(mcc**2*n + p))


def mcc_tn_1(*, p, n, tp, mcc, **kwargs):
    """
    Solves tn from the score mcc

    Args:
        p (int): the p count
        n (int): the n count
        tp (int): the tp count
        mcc (float|Interval|IntervalUnion): the value 

In [23]:
preface = ''
preface += '"""\n'
preface += 'This module contains the tp and tn solutions.\nThis is a generated file, do not edit.\n'
preface += '"""\n\n'
preface += 'from ._helper import is_less_than_zero, is_zero, unify_results\n'
preface += 'from ._interval import sqrt\n\n'

names = '__all__ = [\n'
names += ',\n'.join(f'"{name}"' for name in all_names)
names += ']\n\n'
preface += names

In [24]:
all_code = preface + all_code

In [25]:
print(all_code)

"""
This module contains the tp and tn solutions.
This is a generated file, do not edit.
"""

from ._helper import is_less_than_zero, is_zero, unify_results
from ._interval import sqrt

__all__ = [
"mcc_tp_0",
"mcc_tp_1",
"mcc_tp",
"mcc_tn_0",
"mcc_tn_1",
"mcc_tn",
"acc_tp",
"acc_tn",
"sens_tp",
"spec_tn",
"ppv_tp",
"ppv_tn",
"npv_tp",
"npv_tn",
"fbp_tp",
"fbp_tn",
"f1p_tp",
"f1p_tn",
"fbm_tp",
"fbm_tn",
"f1m_tp",
"f1m_tn",
"upm_tp_0",
"upm_tp_1",
"upm_tp",
"upm_tn_0",
"upm_tn_1",
"upm_tn",
"gm_tp",
"gm_tn",
"fm_tp_0",
"fm_tp_1",
"fm_tp",
"fm_tn",
"mk_tp_0",
"mk_tp_1",
"mk_tp",
"mk_tn_0",
"mk_tn_1",
"mk_tn",
"lrp_tp",
"lrp_tn",
"lrn_tp",
"lrn_tn",
"bm_tp",
"bm_tn",
"pt_tp_0",
"pt_tp_1",
"pt_tp",
"pt_tn_0",
"pt_tn_1",
"pt_tn",
"dor_tp",
"dor_tn",
"ji_tp",
"ji_tn",
"bacc_tp",
"bacc_tn",
"kappa_tp",
"kappa_tn",
"p4_tp_0",
"p4_tp_1",
"p4_tp",
"p4_tn_0",
"p4_tn_1",
"p4_tn"]

def mcc_tp_0(*, p, tn, n, mcc, **kwargs):
    """
    Solves tp from the score mcc

    Args:
        p (int): the p 

In [26]:
with open('_tptn_solutions.py', 'wt') as file:
    file.write(all_code)

In [27]:
print(preface + all_code)

"""
This module contains the tp and tn solutions.
This is a generated file, do not edit.
"""

from ._helper import is_less_than_zero, is_zero, unify_results
from ._interval import sqrt

__all__ = [
"mcc_tp_0",
"mcc_tp_1",
"mcc_tp",
"mcc_tn_0",
"mcc_tn_1",
"mcc_tn",
"acc_tp",
"acc_tn",
"sens_tp",
"spec_tn",
"ppv_tp",
"ppv_tn",
"npv_tp",
"npv_tn",
"fbp_tp",
"fbp_tn",
"f1p_tp",
"f1p_tn",
"fbm_tp",
"fbm_tn",
"f1m_tp",
"f1m_tn",
"upm_tp_0",
"upm_tp_1",
"upm_tp",
"upm_tn_0",
"upm_tn_1",
"upm_tn",
"gm_tp",
"gm_tn",
"fm_tp_0",
"fm_tp_1",
"fm_tp",
"fm_tn",
"mk_tp_0",
"mk_tp_1",
"mk_tp",
"mk_tn_0",
"mk_tn_1",
"mk_tn",
"lrp_tp",
"lrp_tn",
"lrn_tp",
"lrn_tn",
"bm_tp",
"bm_tn",
"pt_tp_0",
"pt_tp_1",
"pt_tp",
"pt_tn_0",
"pt_tn_1",
"pt_tn",
"dor_tp",
"dor_tn",
"ji_tp",
"ji_tn",
"bacc_tp",
"bacc_tn",
"kappa_tp",
"kappa_tn",
"p4_tp_0",
"p4_tp_1",
"p4_tp",
"p4_tn_0",
"p4_tn_1",
"p4_tn"]

"""
This module contains the tp and tn solutions.
This is a generated file, do not edit.
"""

from ._helper import is

In [28]:
from mlscorecheck.individual import Expression, Solution0

In [29]:
sol = Solution0(**solutions['mcc']['tp'][0])

KeyError: 'mcc'

In [None]:
sol.evaluate({'mcc': 0.5, 'p': 10, 'n': 20, 'tn': 15})

-1.0762521851076514

In [None]:
solutions

{'mcc': {'tp': [{'solution': {'expression': '(-mcc*sqrt(p)*(n + p)*sqrt(mcc**2*n*p + 4*n*tn - 4*tn**2) + sqrt(n)*p*(-mcc**2*n + mcc**2*p + 2*mcc**2*tn + 2*n - 2*tn))/(2*sqrt(n)*(mcc**2*p + n))',
     'symbols': ['n', 'mcc', 'tn', 'p']},
    'non_zero': [{'expression': 'mcc**2*p + n', 'symbols': ['n', 'mcc', 'p']},
     {'expression': 'sqrt(n)', 'symbols': ['n']}],
    'non_negative': [{'expression': 'n', 'symbols': ['n']},
     {'expression': 'mcc**2*n*p + 4*n*tn - 4*tn**2',
      'symbols': ['n', 'mcc', 'p', 'tn']},
     {'expression': 'p', 'symbols': ['p']}]},
   {'solution': {'expression': '(mcc*sqrt(p)*(n + p)*sqrt(mcc**2*n*p + 4*n*tn - 4*tn**2) + sqrt(n)*p*(-mcc**2*n + mcc**2*p + 2*mcc**2*tn + 2*n - 2*tn))/(2*sqrt(n)*(mcc**2*p + n))',
     'symbols': ['n', 'mcc', 'tn', 'p']},
    'non_zero': [{'expression': 'mcc**2*p + n', 'symbols': ['n', 'mcc', 'p']},
     {'expression': 'sqrt(n)', 'symbols': ['n']}],
    'non_negative': [{'expression': 'n', 'symbols': ['n']},
     {'expression'

In [None]:
with open('tptn_solutions.json', 'wt') as file:
    json.dump(solutions, file)

TypeError: Object of type Mul is not JSON serializable