In [3]:
%%writefile StatsUtil.py

import sys;
from math import sqrt, exp, log, log10

from scipy.stats import chi2_contingency;
from scipy.stats import fisher_exact;

# getStats.normalize treats a table cell as degenerate when abs(cell) is at or below this value.
DEGENERATE_VALUE_THRESHOLD = 0.0
# Once any cell is degenerate, getStats.normalize sets degenerate cells to this value and adds it to all others.
DEGENERATE_VALUE_ADJUSTMENT = 1.0

class getStats:
    """Association statistics for two events A and B, built on a 2x2 contingency table.

    Table layout held in ``self.ct`` (rows: A / not A, columns: B / not B)::

        [[ nAB,       nA - nAB          ],
         [ nB - nAB,  N - nA - nB + nAB ]]

    Typical use: construct, optionally call ``normalize()``, then ``calc(metric)``.
    """

    def __init__(self, nAB, nA, nB, N):
        """Set up the 2x2 table from occurrence totals.

        - nAB: Number of cases where both event A and B occur
        - nA: Number of cases where A occurs
        - nB: Number of cases where B occurs
        - N: Number of total cases
        """
        # Keep the raw totals separately: the count-based metrics in calc() read these,
        # so they are unaffected by any adjustment normalize() later applies to self.ct.
        self.nAB = float(nAB)
        self.nA = float(nA)
        self.nB = float(nB)
        self.N = float(N)

        # Cells can come out negative if the totals are inconsistent (e.g. nAB > nA).
        self.ct = [
            [float(nAB), float(nA - nAB)],
            [float(nB - nAB), float(N - nA - nB + nAB)],
        ]

        # Not read by any method of this class; retained in case external code inspects it.
        self.degValueAdj = 1

    def normalize(self, truncateNegativeValues=False):
        """Adjust irregular cells of ``self.ct`` in place so later calculations do not fail.

        - truncateNegativeValues: if True, first clamp negative cells to 0.0. Negative cells
          indicate irregular input totals and prevent the Fisher p-value calculation.

        Then, if any cell is degenerate (abs(cell) <= DEGENERATE_VALUE_THRESHOLD), EVERY cell
        is shifted: degenerate cells are set to DEGENERATE_VALUE_ADJUSTMENT and all other
        cells have DEGENERATE_VALUE_ADJUSTMENT added. This avoids division by zero in
        ratio metrics such as 'rr'.

        Only ``self.ct`` is modified; the raw totals (nA, nB, nAB, N) are left untouched.
        Negative cells are NOT altered unless truncateNegativeValues is True.
        """
        ct = self.ct

        if truncateNegativeValues:
            for row in ct:
                for j, value in enumerate(row):
                    if value < 0.0:
                        row[j] = 0.0

        def isDegenerate(value):
            return abs(value) <= DEGENERATE_VALUE_THRESHOLD

        if any(isDegenerate(value) for row in ct for value in row):
            for row in ct:
                for j, value in enumerate(row):
                    if isDegenerate(value):
                        row[j] = DEGENERATE_VALUE_ADJUSTMENT
                    else:
                        row[j] = value + DEGENERATE_VALUE_ADJUSTMENT

    def calc(self, metric):
        """Compute one named statistic.

        Supported metric names (case-sensitive):
        - 'numA', 'numB', 'total', 'support': the raw totals nA, nB, N, nAB
        - 'prevalence': nB / N
        - 'ppv': nAB / nA
        - 'rr': relative risk from self.ct, P(B | A) / P(B | not A)
        - 'fisher': (oddsratio, pvalue) tuple from scipy.stats.fisher_exact(self.ct)
        - 'fisher_neglog': -log10(p) when the Fisher odds ratio is > 1, otherwise log10(p).
          So the sign gives the direction of association and the magnitude its significance.
          A p-value of exactly 0 maps to +/- sys.float_info.max. Returns 0.0 if fisher_exact
          raises ValueError (likely from negative table values).

        Any other metric name returns the string 'invalid' instead of raising.

        The raw-total metrics ignore normalize(); 'rr', 'fisher' and 'fisher_neglog' read
        self.ct and so reflect it. 'prevalence', 'ppv' and 'rr' raise ZeroDivisionError when
        a denominator is zero. 'fisher' propagates any error from fisher_exact.
        """
        nA, nB, nAB, N = self.nA, self.nB, self.nAB, self.N
        ct = self.ct

        def relativeRisk():
            riskGivenA = ct[0][0] / (ct[0][0] + ct[0][1])
            riskGivenNotA = ct[1][0] / (ct[1][0] + ct[1][1])
            return riskGivenA / riskGivenNotA

        def fisher():
            oddsRatio, pValue = fisher_exact(ct)
            return oddsRatio, pValue

        def fisherSignedNegLog():
            try:
                oddsRatio, pValue = fisher_exact(ct)
            except ValueError:
                # Likely from negative table values. Return default / uncertain value.
                return 0.0
            # log10(0) is undefined, so substitute the most extreme finite float.
            # log10 is more accurate than log(p, 10).
            logP = log10(pValue) if pValue > 0.0 else -sys.float_info.max
            # Positive result for positive association (odds ratio > 1), negative otherwise.
            return -logP if oddsRatio > 1.0 else logP

        switcher = {
            'numA': lambda: nA,
            'numB': lambda: nB,
            'total': lambda: N,
            'support': lambda: nAB,
            'prevalence': lambda: nB / N,
            'ppv': lambda: nAB / nA,
            'rr': relativeRisk,
            'fisher': fisher,
            'fisher_neglog': fisherSignedNegLog,
        }

        return switcher.get(metric, lambda: 'invalid')()
    


Writing StatsUtil.py
