In [None]:
# Setup

!pip install --upgrade torch
!pip install --upgrade pyro-ppl
!pip install --upgrade scipy
!pip install --upgrade matplotlib
!pip install --upgrade scikit-optimize

In [None]:
import pyro
import torch
import torch.tensor as tensor
import pyro.distributions as dist
# from torch.distributions import Binomial, Gamma, Uniform
from pyro.distributions import Binomial, Bernoulli, Categorical, Dirichlet, DirichletMultinomial as DM, Beta, BetaBinomial, Uniform, Gamma, Multinomial

import numpy as np

import scipy
from skopt import gp_minimize 
from scipy.stats import binom as ScipyBinom
from matplotlib import pyplot

from collections import namedtuple
import time
seed = 0

In [None]:
# Measuring overhead
import time

# .1s
l = []
start = time.time()
for i in range(int(1e6)):
    l.append(i)

if len(l) > 5:
    print("Done")
print(f"Scalar version took: {time.time() - start}")

# 30x slower, 3.2s
l = []
start = time.time()
for i in range(int(1e6)):
    l.append(tensor(i))

if len(l) > 5:
    print("Done")
print(f"Tensor version took: {time.time() - start}")

# do it in one pass
# this wraps the array in tensor, aka tensor([]),
# but accessing a single element gives back a tensor
l = []
start = time.time()
for i in range(int(1e6)):
    l.append(i)

# .13s 
l = torch.tensor(l)
if len(l) > 5:
    print("Done")
print(f"Tensor convert array version took: {time.time() - start}")

In [None]:
#### Likelihood functions
# These assume univariate currently
 
# TODO:
# 1) Explore constraining alphas using prevalence estimate, namely E(P(D)) = alpha0 / (alpha0 + alpha1 + alpha2 + alphaBoth) (as long as all case counts are mutually exclusive)
# 2) Can DM approximate NB + Multinomial? If so do we need mixture at all? But if we don't have that how do we model % disease-afffecting genes in each hypothesis(maybe proportion of alphas?)
# rr: relative risk
# P(V|D) = P(D|V)*P(V) / P(D)
# rr * P(D|!V) = P(D|V)
# P(V|D) = rr * P(D|!V) * P(V) / P(D)
# P(D) = (P(D|V)P(V) + P(D|!V)P(!V))
# P(D) = P(D|V) + P(D|!V)(1-P(V))
# P(V|D) = ( rrP(D|!V)) ) * P(V) ) / ( (P(D|V)P(V) + P(D|!V)(1-P(V))) )
# let a = ( rrP(D|!V)) ) * P(V) )
# P(V|D) = a / P(D|!V) / ( P(D|V)P(V) + P(D|!V) - P(D|!V)P(V) ) / P(D|!V)
# = ( rr*P(V) ) / ( rr*P(V) + 1 - P(V) )
def pVgivenD(rr, pV):
    return (rr * pV) / (rr * pV + (1 - pV))

# pD: prevalence, tensor of mConditions x 1
# pVgivenD: tensor of mConditions x 1
# pV: allele frequency
def pVgivenNotD(pD, pV, pVgivenD):
    p = (pV - (pD*pVgivenD).sum()) / (1 - pD.sum())
    if(p < 0):
        raise ValueError(f"pVgivenNotD: invalid params: pD: {pD}, pV: {pV}, pVgivenD: {pVgivenD} yield: p = {p}")
    return p

# def pVgivenNotD(pD, pV, pVgivenD):
#     p = (pV - (pD*pVgivenD)) / (1 - pD)
#     assert(p >= 0)
#     return p

def pDgivenV(pD, pVgivenD, pV):
    return pVgivenD * pD / pV

# works like shit
def llUnivariateSingleGeneJensen(xCtrl, xCase, pD, pi0, pi1, pDgivenV):
    n = xCtrl + xCase
    #convex function, so by jensen's sum of logs is fine (always <= the log of sum)
    return pi0 * Binomial(total_count=n, probs=pD).log_prob(xCase) + pi1*Binomial(total_count=n, probs=pDgivenV).log_prob(xCase)

def llUnivariateSingleGene(xCtrl, xCase, pD, pi0, pi1, pDgivenV):
    n = xCtrl + xCase
    #convex function, so by jensen's sum of logs is fine (always <= the log of sum)
    return torch.log(pi0 * torch.exp(Binomial(total_count=n, probs=pD).log_prob(xCase)) + pi1*torch.exp(Binomial(total_count=n, probs=pDgivenV).log_prob(xCase)))

# alphas shape: [2] #corresponding to cases and controls
def llUnivariateSingleGeneBetaBinomial(xCtrl, xCase, pD, alphas, pi0, pi1):
    n = xCtrl + xCase
    #convex function, so by jensen's sum of logs is fine (always <= the log of sum)
    # what is the 
    h0 = pi0 * torch.exp( Binomial(total_count=n, probs=pD).log_prob(xCase) )
    h1 = pi1 * torch.exp( BetaBinomial(total_count=n, concentration1=alphas[1], concentration0=alphas[0]).log_prob(xCase) )
    return torch.logalpha3( h0 + h1 )

# TODO: support pooled and non-pooled controls
# TODO: think about whether we need overlapping cases (both disease1 + disease2) or whether that can be inferred
# altCounts.shape = [1 control + nConditions cases, 1]
# alphas shape: [nConditions + 2] #1 ctrl + nCondition cases; for now the last condition in nCondition cases is for individuals who has all of the previous nConditions
# in a more multivariate setting we will need more information, aka mapping to which combinations of conditions these people have
# xCases: we have nConditions cases
# pDs shape: [nConditions]
# TODO: make this more effificent by taking alphas tensor of shape (1 + nConditions)
def llPooledBivariateSingleGene(altCounts, pDs, alpha0, alpha1, alpha2, alphaBoth, pi0, pi1, pi2, piBoth):
    # currently assume altCounts are all independent (in simulation), or 0 for everything but first condition
    n = altCounts.sum()
    alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
    print("n is ", n)
    #convex function, so by jensen's sum of logs is fine (always <= the log of sum)
    # what is the 
    case1nullLikelihood = torch.exp( Binomial(total_count=n, probs=pDs[0]).log_prob(altCounts[1]) )
    case2nullLikelihood = torch.exp( Binomial(total_count=n, probs=pDs[1]).log_prob(altCounts[2]) )
    h0 = pi0 * case1nullLikelihood * case2nullLikelihood
    h1 = pi1 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha1, concentration0=alphasSum - alpha1).log_prob(altCounts[1]) ) * case2nullLikelihood
    h2 = pi2 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha2, concentration0=alphasSum - alpha2).log_prob(altCounts[2]) ) * case1nullLikelihood
    h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCounts))
    print(f"h0: {h0}, h1: {h1}, h2: {h2}, h3: {h3}")
    return torch.log( h0 + h1 + h2 + h3 )

# shape of altCountsByGene: [nGenes, nConditions, 2]
# last dimension is 
# 2nd dimension altCountsCasesByGene must match controls, or the control nConditions must be 1 (pooled controls)
def likelihoodUnivariate(altCountsByGene, pDs):
    nGenes = len(altCountsByGene)
    
    # passed to optimization function, we optimize pDgivenV and pi1 by maximizing likelihood
    def likelihood(params):
        pDgivenV = params[0]
        pi1 = params[1]
        pi0 = 1 - pi1
        
        if(pDgivenV >= 1 or pDgivenV < 0 or pi1 < 0 or pi1 > 1):
            print("returning inf")
            return float("-inf")
    
        logLikelihood = 0
        penaltyCount = float(nGenes)
        
        # 
        for geneIdx in range(nGenes):
            ctrlAltCount = altCountsByGene[geneIdx, 0, 0]
            caseAltCount = altCountsByGene[geneIdx, 0, 1]
            pd = pDs[0]
            
            if ctrlAltCount == 0 and caseAltCount == 0:
                print("skipping", geneIdx)
                continue

            # this is insanely slow
            ll = llUnivariateSingleGene(ctrlAltCount, caseAltCount, pd, pi0, pi1, pDgivenV)

            if torch.isnan(ll) or torch.isinf(ll):
                print(f"nan or 0 likelihood: like: {like}, p1: {pi1}, pDgivenV: {pDgivenV}, gene: {geneIdx}, ctrlCount: {ctrlAltCount}, caseCount: {caseAltCount}")
                penaltyCount -= 1
                continue
                
            logLikelihood += ll
        
    
        if penaltyCount == 0:
            penaltyCount = 1
    
        return -logLikelihood * (nGenes / penaltyCount)
    
    return likelihood

def likelihoodUnivariateFast(altCountsByGene, pDs):
    nGenes = len(altCountsByGene)
    geneSums = altCountsByGene[:, 0, :].sum(1)
        
    caseAltCounts = altCountsByGene[:, 0, 1]
    pD = pDs[0]
    def likelihood(params):
        pi1, pDgivenV = params

        pi0 = 1.0 - pi1

        if(pDgivenV > 1 or pDgivenV < 0 or pi1 < 0 or pi1 > 1):
            return float("inf")
        
        binomH0 = Binomial(total_count=geneSums, probs=pD)
        binomH1 = Binomial(total_count=geneSums, probs=pDgivenV)
        
        component0 = pi0 * torch.exp(binomH0.log_prob(caseAltCounts))
        component1 = pi1 * torch.exp(binomH1.log_prob(caseAltCounts))
        
        return - torch.log(component0 + component1).sum()
    
    return likelihood

def likelihoodUnivariateBetaBinomialFast(altCountsByGene, pDs):
    nGenes = len(altCountsByGene)
    geneSums = altCountsByGene[:, 0, :].sum(1)
        
    caseAltCounts = altCountsByGene[:, 0, 1]
    pD = pDs[0]
    def likelihood(params):
        pi1, alpha1, alpha0 = params

        if alpha1 < 0 or alpha0 < 0 or pi1 < 0 or pi1 > 1:
            return float("inf")
        
        pi0 = 1.0 - pi1

        binomH0 = Binomial(total_count=geneSums, probs=pD)
        binomH1 = BetaBinomial(total_count=geneSums, concentration1=alpha1, concentration0=alpha0)
        
        component0 = pi0 * torch.exp(binomH0.log_prob(caseAltCounts))
        component1 = pi1 * torch.exp(binomH1.log_prob(caseAltCounts))

        return - torch.log(component0 + component1).sum()
    
    return likelihood

def getUnivariateAlpha0(alpha1, pD):
    return ((1-pD) / pD)*alpha1

# doesn't really work constraint looks wrong
def likelihoodUnivariateBetaBinomialConstrainedFast(altCountsByGene, pDs):
    nGenes = len(altCountsByGene)
    geneSums = altCountsByGene[:, 0, :].sum(1)
        
    caseAltCounts = altCountsByGene[:, 0, 1]
    pD = pDs[0]
    pNotDRatio = (1 - pD)/pD
    def likelihood(params):
        pi1, alpha1 = params
        
        if alpha1 < 0 or pi1 < 0 or pi1 > 1:
            return float("inf")
        
        pi0 = 1.0 - pi1
        
        alpha0 = pNotDRatio*alpha1
        
        assert(alpha0 > 0)
        
        print("alpha0",alpha0)
        
        binomH0 = Binomial(total_count=geneSums, probs=pD)
        binomH1 = BetaBinomial(total_count=geneSums, concentration1=alpha1, concentration0=alpha0)
        
        component0 = pi0 * torch.exp(binomH0.log_prob(caseAltCounts))
        component1 = pi1 * torch.exp(binomH1.log_prob(caseAltCounts))

        return - torch.log(component0 + component1).sum()
    
    return likelihood

# Bivariate likelihood function modeled on:
#def llPooledBivariateSingleGene(altCounts, pDs, alpha0, alpha1, alpha2, alphaBoth, pi0, pi1, pi2, piBoth):
# # currently assume altCounts are all independent (in simulation), or 0 for everything but first condition
# n = altCounts.sum()
# alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
# case1nullLikelihood = torch.exp( Binomial(total_count=n, probs=pDs[0]).log_prob(altCounts[1]) )
# case2nullLikelihood = torch.exp( Binomial(total_count=n, probs=pDs[1]).log_prob(altCounts[2]) )
# h0 = pi0 * case1nullLikelihood * case2nullLikelihood
# h1 = pi1 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha1, concentration0=alphasSum - alpha1).log_prob(altCounts[1]) ) * case2nullLikelihood
# h2 = pi2 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha2, concentration0=alphasSum - alpha2).log_prob(altCounts[2]) ) * case1nullLikelihood
# h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCounts))
# print(f"h0: {h0}, h1: {h1}, h2: {h2}, h3: {h3}")
# return torch.log( h0 + h1 + h2 + h3 )
def likelihoodBivariateFast(altCountsByGene, pDs):
    print("shape", altCountsByGene.shape)
    nGenes = altCountsByGene.shape[0]
    
    if(altCountsByGene.shape[1] == 4):
        altCountsFlat = altCountsByGene
    else:
        altCountsFlat = []
        for geneIdx in range(nGenes):
            # ctrl count is first index of first condition, all other conditions get 0 count at 0th index
            altCountsFlat.append([altCountsByGene[geneIdx, 0, 0], *altCountsByGene[geneIdx, :, 1].flatten()])

    altCountsFlat = tensor(altCountsFlat)
    # nGenes x 4 
    xCtrl = altCountsFlat[:, 0]
    xCase1 = altCountsFlat[:, 1]
    xCase2 = altCountsFlat[:, 2]
    xCase12 = altCountsFlat[:, 3]
    # nGenes x 1
    n = xCtrl + xCase1 + xCase2 + xCase12
    print("altCountsFlat", altCountsFlat)
    print("n", n)
    print("xCase1, xCase2, xCase12", xCase1)
    print("xCase1, xCase2, xCase12", xCase2)
    print("xCase1, xCase2, xCase12", xCase12)
    
    pd1 = pDs[0]
    pd2 = pDs[1]
    pdBoth = pDs[2]
    
    # TODO: maybe we just want to explicitly use sample proportions
    pdCtrl = 1 - (pd1 + pd2 + pdBoth)

    case1Null = torch.exp(Binomial(total_count=n, probs=pd1).log_prob(xCase1))
    case2Null = torch.exp(Binomial(total_count=n, probs=pd2).log_prob(xCase2))
    caseBothNull = torch.exp(Binomial(total_count=n, probs=pdBoth).log_prob(xCase12))
    allNull = case1Null * case2Null * caseBothNull
    print("altCountsFlat", altCountsFlat)
    allNull2 = torch.exp(Multinomial(probs=tensor([1-pDs.sum(), pDs[0], pDs[1], pDs[2]])).log_prob(altCountsFlat))
    print("allNull2", allNull2)
    print("pd1, pd2, pdBoth, pdCtrl", pd1, pd2, pdBoth, pdCtrl)
    def likelihood1(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull

        h1 = pi1 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha1, concentration0=alphasSum - alpha1).log_prob(xCase1) ) * case2Null * caseBothNull
        h2 = pi2 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha2, concentration0=alphasSum - alpha2).log_prob(xCase2) ) * case1Null * caseBothNull
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat))

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    def likelihood1a(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull

        h1 = pi1 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha1, concentration0=alphasSum - alpha1).log_prob(xCase1) ) * case2Null
        h2 = pi2 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha2, concentration0=alphasSum - alpha2).log_prob(xCase2) ) * case1Null
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat))

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    def likelihood1b(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull

        h1 = pi1 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha1, concentration0=alphasSum - alpha1).log_prob(xCase1 + xCase12) ) * case2Null
        h2 = pi2 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha2, concentration0=alphasSum - alpha2).log_prob(xCase2 + xCase12) ) * case1Null
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat))

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    def likelihoodConstrained(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull
        
        # idea 1
        # alpha1 and alpha0 determined
        # A gene has counts from gene1 samples 2 , from gene2 samples 1 geneBoth count
        # if i have some people that only have 1, that is evidence for gene1 liability, but says nothing for liability for 
        # the more shared risk there is, the more the count will be in the "both category", 
        # the fewer people will be only one or the other
        # so eventually all 
        # 
        h1 = pi1 * torch.exp( DirichletMultinomial(total_count=n, concentration1=alpha1, concentration0=alphasSum - alpha1).log_prob(xCase1 + xCase12) ) * case2Null
        
        h1 = pi1 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha1, concentration0=alphasSum - alpha1).log_prob(xCase1 + xCase12) ) * case2Null
        h2 = pi2 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha2, concentration0=alphasSum - alpha2).log_prob(xCase2 + xCase12) ) * case1Null
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat))

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    def likelihood2(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull

        h1 = pi1 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha1, concentration0=alphasSum - alpha1).log_prob(xCase1) ) * case2Null * torch.exp( BetaBinomial(total_count=n, concentration1=alphaBoth, concentration0=alphasSum - alphaBoth).log_prob(xCase12) )
        h2 = pi2 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha2, concentration0=alphasSum - alpha2).log_prob(xCase2) ) * case1Null * torch.exp( BetaBinomial(total_count=n, concentration1=alphaBoth, concentration0=alphasSum - alphaBoth).log_prob(xCase12) )
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat))

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    def likelihood2a(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull

        h1 = pi1 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha1 + alphaBoth, concentration0=alphasSum - alpha1).log_prob(xCase1) ) * case2Null
        h2 = pi2 * torch.exp( BetaBinomial(total_count=n, concentration1=alpha2 + alphaBoth, concentration0=alphasSum - alpha2).log_prob(xCase2) ) * case1Null 
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat))

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    
    def likelihood2b(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2

        h1 = pi1 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha0, alpha1])).log_prob(altCountsFlat) )
        h2 = pi2 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha0, alpha2, alpha2])).log_prob(altCountsFlat) )
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alphaBoth, alphaBoth, alphaBoth])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    def likelihood2c(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2

        h1 = pi1 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1+alphaBoth, alpha0, alpha1+alphaBoth])).log_prob(altCountsFlat) )
        h2 = pi2 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha0, alpha2+alphaBoth, alpha2+alphaBoth])).log_prob(altCountsFlat) )
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1+alphaBoth, alpha2+alphaBoth, alpha1+alphaBoth+alpha2])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    def likelihood2c(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2

        h1 = pi1 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha0, alpha1])).log_prob(altCountsFlat) )
        h2 = pi2 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha0, alpha2, alpha2])).log_prob(altCountsFlat) )
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    def likelihood2d(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2
        
        # From E(P2_null) = P(D2) = a2 / (a0 + a1 + a2 + aB)
        alpha2Null = pDs[1] * (alpha0 + alpha1 + alphaBoth) / (1 - pDs[1])
        # h1 is that the genes in this component only increse risk for disease 1
        h1 = pi1 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2Null, alpha1])).log_prob(altCountsFlat) )
        alpha1Null = pDs[0] * (alpha0 + alpha2 + alphaBoth) / (1 - pDs[0])
        # h2 is that the genes in this component only increse risk for disease 2
        h2 = pi2 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1Null, alpha2, alpha2])).log_prob(altCountsFlat) )
        # h3 is that the alleles in these genes increase risk for both diseases
        # we model by individual and shared components ; hower in this component do samples affected only by disease 1 necessarily have the probability of seeing
        # an allele that is the same as in the other cases? I think no, I think this is higher
        # as the shared component becomes large and larger, if it is added to the other alpha1 (the individual copmonet only), alpha1 will equal alphaBoth, i.e
        # they will be perfectly correlated
        # else, alpha0 will be some amount larger than alphaBoth
        h3 = piBoth * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    def likelihood2e(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2
        
        # From E(P2_null) = P(D2) = a2 / (a0 + a1 + a2 + aB)
        # P(D2)(others) / (1-PD2) = a2
        alpha2Null = pDs[1] * (alpha0 + alpha1 + alphaBoth) / (1 - pDs[1])
        # h1 is that the genes in this component only increse risk for disease 1
        h1 = pi1 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2Null, alpha1])).log_prob(altCountsFlat) )
        alpha1Null = pDs[0] * (alpha0 + alpha2 + alphaBoth) / (1 - pDs[0])
        # h2 is that the genes in this component only increse risk for disease 2
        h2 = pi2 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1Null, alpha2, alpha2])).log_prob(altCountsFlat) )
        # h3 is that the alleles in these genes increase risk for both diseases
        # we model by individual and shared components ; hower in this component do samples affected only by disease 1 necessarily have the probability of seeing
        # an allele that is the same as in the other cases? I think no, I think this is higher
        # as the shared component becomes large and larger, if it is added to the other alpha1 (the individual copmonet only), alpha1 will equal alphaBoth, i.e
        # they will be perfectly correlated
        # else, alpha0 will be some amount larger than alphaBoth
        h3 = piBoth * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1 + alphaBoth, alpha2 + alphaBoth, alphaBoth + alpha1 + alpha2])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum
        
    # like 2d, but alphaBoth is the last component in each case
    # this allows ffor inference of a weighted alphaBoth
    # following the example of h3, where alpha1 is alpha1, not alpha1 + alphaBoth
    def likelihood2f(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2
        
        # From E(P2_null) = P(D2) = a2 / (a0 + a1 + a2 + aB)
        alpha2Null = pDs[1] * (alpha0 + alpha1 + alphaBoth) / (1 - pDs[1])
        # h1 is that the genes in this component only increse risk for disease 1
        h1 = pi1 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2Null, alphaBoth])).log_prob(altCountsFlat) )
        alpha1Null = pDs[0] * (alpha0 + alpha2 + alphaBoth) / (1 - pDs[0])
        # h2 is that the genes in this component only increse risk for disease 2
        h2 = pi2 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1Null, alpha2, alphaBoth])).log_prob(altCountsFlat) )
        # h3 is that the alleles in these genes increase risk for both diseases
        # we model by individual and shared components ; hower in this component do samples affected only by disease 1 necessarily have the probability of seeing
        # an allele that is the same as in the other cases? I think no, I think this is higher
        # as the shared component becomes large and larger, if it is added to the other alpha1 (the individual copmonet only), alpha1 will equal alphaBoth, i.e
        # they will be perfectly correlated
        # else, alpha0 will be some amount larger than alphaBoth
        h3 = piBoth * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
        
    # Like 2D, except alphaBoth in h1 and h2 is scaled by P(DBoth/PD1) or P(DBoth)/P(D1), since that is what is required
    # to go to P(D|V) given a shared P(V|D), which is all we mean when we try to equate alpha1 and alphaboth in component 1
    def likelihood2g(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2
        
        # From E(P2_null) = P(D2) = a2 / (a0 + a1 + a2 + aB)
        alpha2Null = pDs[1] * (alpha0 + alpha1 + alphaBoth) / (1 - pDs[1])
        # h1 is that the genes in this component only increse risk for disease 1
        h1 = pi1 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2Null, alpha1 * pDs[2] / pDs[0]])).log_prob(altCountsFlat) )
        alpha1Null = pDs[0] * (alpha0 + alpha2 + alphaBoth) / (1 - pDs[0])
        # h2 is that the genes in this component only increse risk for disease 2
        h2 = pi2 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1Null, alpha2, alpha2 * pDs[2] / pDs[1]])).log_prob(altCountsFlat) )
        # h3 is that the alleles in these genes increase risk for both diseases
        # we model by individual and shared components ; hower in this component do samples affected only by disease 1 necessarily have the probability of seeing
        # an allele that is the same as in the other cases? I think no, I think this is higher
        # as the shared component becomes large and larger, if it is added to the other alpha1 (the individual copmonet only), alpha1 will equal alphaBoth, i.e
        # they will be perfectly correlated
        # else, alpha0 will be some amount larger than alphaBoth
        h3 = piBoth * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alpha1 + alpha2 + alphaBoth])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
     # Like 2g, except follows the principle of the shared component being additive to individual components,
    # in the shared component
    def likelihood2h(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2
        
        # From E(P2_null) = P(D2) = a2 / (a0 + a1 + a2 + aB)
        alpha2Null = pDs[1] * (alpha0 + alpha1 + alphaBoth) / (1 - pDs[1])
        alphaBothFrom1 = alpha1 * pDs[2] / pDs[0]
        # h1 is that the genes in this component only increse risk for disease 1
        h1 = pi1 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2Null, alphaBothFrom1])).log_prob(altCountsFlat) )
        alpha1Null = pDs[0] * (alpha0 + alpha2 + alphaBoth) / (1 - pDs[0])
        alphaBothFrom2 = alpha2 * pDs[2] / pDs[1]
        # h2 is that the genes in this component only increse risk for disease 2
        h2 = pi2 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1Null, alpha2, alphaBothFrom2])).log_prob(altCountsFlat) )
        # h3 is that the alleles in these genes increase risk for both diseases
        # we model by individual and shared components ; hower in this component do samples affected only by disease 1 necessarily have the probability of seeing
        # an allele that is the same as in the other cases? I think no, I think this is higher
        # as the shared component becomes large and larger, if it is added to the other alpha1 (the individual copmonet only), alpha1 will equal alphaBoth, i.e
        # they will be perfectly correlated
        # else, alpha0 will be some amount larger than alphaBoth
        h3 = piBoth * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1 + alphaBoth, alpha2 + alphaBoth, alpha1 + alpha2 + alphaBoth])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
        
    # just let everything vary
    def likelihood2i(params):
        # TODO: better to do constrained or unconstrained alpha1?
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2
        # h1 is that the genes in this component only increse risk for disease 1
        h1 = pi1 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat) ) 
        # h2 is that the genes in this component only increse risk for disease 2
        h2 = pi2 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat) )
        # h3 is that the alleles in these genes increase risk for both diseases
        # we model by individual and shared components ; hower in this component do samples affected only by disease 1 necessarily have the probability of seeing
        # an allele that is the same as in the other cases? I think no, I think this is higher
        # as the shared component becomes large and larger, if it is added to the other alpha1 (the individual copmonet only), alpha1 will equal alphaBoth, i.e
        # they will be perfectly correlated
        # else, alpha0 will be some amount larger than alphaBoth
        h3 = piBoth * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
     # Like 2h except unfucked
    def likelihood2j(params):
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2
        
        alphaBothFrom1 = alpha1 * pDs[2] / pDs[0]
        # From E(P2_null) = P(D2) = a2 / (a0 + a1 + a2 + aB)
        alpha2Null = pDs[1] * (alpha0 + alpha1 + alphaBothFrom1) / (1 - pDs[1])
        h1 = pi1 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2Null, alphaBothFrom1])).log_prob(altCountsFlat) )
        
        alphaBothFrom2 = alpha2 * pDs[2] / pDs[1]
        alpha1Null = pDs[0] * (alpha0 + alpha2 + alphaBothFrom2) / (1 - pDs[0])
        h2 = pi2 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1Null, alpha2, alphaBothFrom2])).log_prob(altCountsFlat) )

        h3 = piBoth * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1 + alphaBoth, alpha2 + alphaBoth, alpha1 + alpha2 + alphaBoth])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
     # Like 2j except allow last multinomial to freely vary as alpha1, alpha2, alphaBoth
    def likelihood2k(params):
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
        
        alphasSum = alpha0 + alpha1 + alpha2 + alphaBoth
        
        h0 = pi0 * allNull2
        
        alphaBothFrom1 = alpha1 * pDs[2] / pDs[0]
        # From E(P2_null) = P(D2) = a2 / (a0 + a1 + a2 + aB)
        alpha2Null = pDs[1] * (alpha0 + alpha1 + alphaBothFrom1) / (1 - pDs[1])
        h1 = pi1 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2Null, alphaBothFrom1])).log_prob(altCountsFlat) )
        
        alphaBothFrom2 = alpha2 * pDs[2] / pDs[1]
        alpha1Null = pDs[0] * (alpha0 + alpha2 + alphaBothFrom2) / (1 - pDs[0])
        h2 = pi2 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1Null, alpha2, alphaBothFrom2])).log_prob(altCountsFlat) )

        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2, alphaBoth])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    def likelihood2l(params):
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
                
        h0 = pi0 * allNull2
        
        alphaBothFrom1 = alpha1 * pDs[2] / pDs[0]
        # From E(P2_null) = P(D2) = a2 / (a0 + a1 + a2 + aB)
        alpha2Null = pDs[1] * (alpha0 + alpha1 + alphaBothFrom1) / (1 - pDs[1])
        h1 = pi1 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1, alpha2Null, alphaBothFrom1])).log_prob(altCountsFlat) )
        
        alphaBothFrom2 = alpha2 * pDs[2] / pDs[1]
        alpha1Null = pDs[0] * (alpha0 + alpha2 + alphaBothFrom2) / (1 - pDs[0])
        h2 = pi2 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1Null, alpha2, alphaBothFrom2])).log_prob(altCountsFlat) )
        
        alphaBoth1 = alphaBoth * pDs[0] / pDs[2]
        alphaBoth2 = alphaBoth * pDs[1] / pDs[2]
        alpha1_prime = alpha1 + alphaBoth1 #E[X_i] = n * alphaBoth / alphaSum
        alpha2_prime = alpha2 + alphaBoth2
        alphaBoth_prime = alphaBothFrom1 + alphaBothFrom2 + alphaBoth
        # Need cleaner walkthrough the different pieces; what is the prevalance piece vs scaling up or down that happens to match relative risk estimates
        # that it's a single effect size layered on the different proportions
        # we may want some multiplicative effect
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, alpha1_prime, alpha2_prime, alphaBoth_prime])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    # is_sparse with float32
    def likelihood2lSparse(params):
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
                
        h0 = pi0 * allNull2
        
        alphaBothFrom1 = alpha1 * pDs[2] / pDs[0]
        # From E(P2_null) = P(D2) = a2 / (a0 + a1 + a2 + aB)
        alpha2Null = pDs[1] * (alpha0 + alpha1 + alphaBothFrom1) / (1 - pDs[1])
        h1 = pi1 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1, alpha2Null, alphaBothFrom1], dtype=torch.float32)).log_prob(altCountsFlat) )
        
        alphaBothFrom2 = alpha2 * pDs[2] / pDs[1]
        alpha1Null = pDs[0] * (alpha0 + alpha2 + alphaBothFrom2) / (1 - pDs[0])
        h2 = pi2 * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1Null, alpha2, alphaBothFrom2], dtype=torch.float32)).log_prob(altCountsFlat) )
        
        alphaBoth1 = alphaBoth * pDs[0] / pDs[2]
        alphaBoth2 = alphaBoth * pDs[1] / pDs[2]
        alpha1_prime = alpha1 + alphaBoth1 #E[X_i] = n * alphaBoth / alphaSum
        alpha2_prime = alpha2 + alphaBoth2
        alphaBoth_prime = alphaBothFrom1 + alphaBothFrom2 + alphaBoth
        # Need cleaner walkthrough the different pieces; what is the prevalance piece vs scaling up or down that happens to match relative risk estimates
        # that it's a single effect size layered on the different proportions
        # we may want some multiplicative effect
        h3 = piBoth * torch.exp( DirichletMultinomial(is_sparse=True, total_count=n, concentration=tensor([alpha0, alpha1_prime, alpha2_prime, alphaBoth_prime], dtype=torch.float32)).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()
    
    # In this model, by RWalter's suggestion, all of the alphas are latent, and so need to be scaled by the prevalence of the group
    def likelihood2m(params):
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = params

        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi2 < 0 or piBoth < 0:
            return float("inf")
        
        pi0 = 1.0 - (pi1 + pi2 + piBoth)
        
        if pi0 < 0:
            return float("inf")
     
        h0 = pi0 * allNull2

        actrl = pdCtrl * alpha0
        a11 = pd1 * alpha1
        a12 = pd2 * alpha0
        a1Both = pdBoth * alpha1
        
        h1 = pi1 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([actrl, a11, a12, a1Both])).log_prob(altCountsFlat) )
        
        a21 = pd1 * alpha0
        a22 = pd2 * alpha2
        a2Both = pdBoth * alpha2
        h2 = pi2 * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([actrl, a21, a22, a2Both])).log_prob(altCountsFlat) )
        
        aBoth1 = pd1 * (alpha1 + alphaBoth)
        aBoth2 = pd2 * (alpha2 + alphaBoth)
        aBothBoth = pdBoth * (alpha1 + alpha2 + alphaBoth)
        
        h3 = piBoth * torch.exp( DirichletMultinomial(total_count=n, concentration=tensor([alpha0, aBoth1, aBoth2, aBothBoth])).log_prob(altCountsFlat) )

        return -torch.log( h0 + h1 + h2 + h3 ).sum()

    return likelihood1, likelihood1a, likelihood1b, likelihood2, likelihood2a, likelihood2b, likelihood2c, likelihood2d, likelihood2e, likelihood2f, likelihood2g, likelihood2h, likelihood2i, likelihood2j, likelihood2k, likelihood2l, likelihood2m, likelihood2lSparse

def cb(f, context):
    print("got callback", f, context)

# TODO: update for multivariate
def fitFnUniveriate(altCountsByGene, pDs, nEpochs = 100, minLLThresholdCount = 100, debug = False):
    costFn = likelihoodUnivariateFast(altCountsByGene, pDs)
    
    lls = []
    params = []

    minLLDiff = 1
    thresholdHitCount = 0
    
    nGenes = len(altCountsByGene)

    randomDist = Uniform(1/nGenes, .5)
    randomDist2 = Uniform(0, 1)
    
        # pDgivenV can't be smaller than this assuming allele freq > 1e-6 and rr < 100
    # P(V|D) * P(D) / P(V)
#     pDgivenVbounds = ( pVgivenD(2, 1e-6) * .001 / 1e-6, pVgivenD(100, 1e-2) * .1 / 1e-2 )
#     pi1Bounds = ( 1/nGenes,  1 )
#     bounds = [pDgivenVbounds, pi1Bounds]
    for i in range(nEpochs):
        best = float("inf")
        bestParams = []
        for y in range(100):
            # pi1, p(D|V)
            fnArgs = [randomDist.sample(), randomDist2.sample()]
            ll = costFn(fnArgs)
            if ll < best:
                best = ll
                bestParams = fnArgs
                
        if debug:
            print(f"best ll: {best}, params: {bestParams}")

        fit = scipy.optimize.minimize(costFn, x0 = bestParams, method='Nelder-Mead', options={"maxiter": 10000, "adaptive": True})#gp_minimize(costFn, [(1e-7, .9),(1/nGenes, .99)])#scipy.optimize.minimize(costFn, x0 = fnArgs, method="Nelder-Mead", options={"maxiter": 10000})
        
        if debug:
            print(f"epoch {i}")
            print(fit)

        if not fit["success"] is True:
            if debug:
                print("Failed to converge")
                print(fit)
            continue
        
        pi1, pDgivenV= fit["x"]
        if pDgivenV < 0 or pDgivenV > 1 or pi1 < 1/nGenes or pi1 > 1:
            if debug:
                print("Failed to converge")
                print(fit)
            continue
        
        ll = fit["fun"]
        if len(lls) == 0:
            lls.append(ll)
            params.append(fit["x"])
            continue

        minPrevious = min(lls)
        
        if debug:
            print("minPrevious", minPrevious)
            
        # TODO: take mode of some pc-based cluster of parameters, or some auto-encoded cluster
        if ll < minPrevious and (minPrevious - ll) >= minLLDiff:
            if debug:
                print(f"better by at >= {minLLDiff}; new ll: {fit}")

            lls.append(ll)
            params.append(fit["x"])
            
            thresholdHitCount = 0
            continue

        thresholdHitCount += 1
        
        if thresholdHitCount == minLLThresholdCount:
            break
            
    return {"lls": lls, "params": params}


# TODO: maybe beta distribution should be constrained such that variance is that of the data?
# or maybe there's an analog to 0 mean liability variance?
def fitFnUniveriateBetaBinomial(altCountsByGene, pDs, nEpochs = 100, minLLThresholdCount = 100, debug = False):
    costFn = likelihoodUnivariateBetaBinomialFast(altCountsByGene, pDs)
    
    llsAll = []
    lls = []
    params = []

    minLLDiff = 1
    thresholdHitCount = 0
    
    nGenes = len(altCountsByGene)
    remainingEpochs = nEpochs
    
    randomDist = Uniform(1/nGenes, .5)
    randomDist2 = Uniform(100, 25000)
    # pDgivenV can't be smaller than this assuming allele freq > 1e-6 and rr < 100
    # P(V|D) * P(D) / P(V)
    while remainingEpochs > 0:
        best = float("inf")
        bestParams = []
        for i in range(50):
            # pi1, alpha1, alpha0
            fnArgs = [randomDist.sample(), randomDist2.sample(), randomDist2.sample()]
            ll = costFn(fnArgs)
            if ll < best:
                best = ll
                bestParams = fnArgs
        
        if debug:
            print(f"best ll: {best}, bestParams: {bestParams}")

        fit = scipy.optimize.minimize(costFn, x0 = bestParams, method='Nelder-Mead', options={"maxiter": 10000, "adaptive": True})#gp_minimize(costFn, [(1e-7, .9),(1/nGenes, .99)])#scipy.optimize.minimize(costFn, x0 = fnArgs, method="Nelder-Mead", options={"maxiter": 10000})
        #fit = scipy.optimize.basinhopping(costFn, x0 = bestParams)
        if debug:
            print(f"epoch {remainingEpochs}")
            print(fit)

        if not fit["success"] is True:
            if debug:
                print("Failed to converge")
                print(fit)
            continue
        
        
        pi1, alpha1, alpha0 = fit["x"]
        # TODO: is pi1 > .5 restriction sound?
        if pi1 < 1/nGenes or pi1 > .5 or alpha1 <= 0 or alpha0 <= 0:
            if debug:
                print("Failed to converge")
                print(fit)
            continue
            
        remainingEpochs -= 1
        
        ll = fit["fun"]
        llsAll.append(ll)
        if len(lls) == 0:
            lls.append(ll)
            params.append(fit["x"])
            continue

        minPrevious = min(lls)
        
        if debug:
            print("minPrevious", minPrevious)
            
        # TODO: take mode of some pc-based cluster of parameters, or some auto-encoded cluster
        if ll < minPrevious and (minPrevious - ll) >= minLLDiff:
            if debug:
                print(f"better by at >= {minLLDiff}; new ll: {fit}")

            lls.append(ll)
            params.append(fit["x"])
            
            thresholdHitCount = 0
            continue

        thresholdHitCount += 1
        
        if thresholdHitCount == minLLThresholdCount:
            break
            
    return {"lls": lls, "params": params, "llTrajectory": llsAll}

# TODO: maybe beta distribution should be constrained such that variance is that of the data?
# or maybe there's an analog to 0 mean liability variance
def fitFnBivariate(altCountsByGene, pDs, nEpochs = 100, minLLThresholdCount = 100, K = 4, debug = False, costFnIdx = 0):
    costFunctions = likelihoodBivariateFast(altCountsByGene, pDs)
        
    costFn = costFunctions[costFnIdx]
    print("past", costFn)
    llsAll = []
    lls = []
    params = []

    minLLDiff = 1
    thresholdHitCount = 0
    
    nGenes = len(altCountsByGene)
    
    # pDgivenV can't be smaller than this assuming allele freq > 1e-6 and rr < 100
    # P(V|D) * P(D) / P(V)
    pi0Dist = Uniform(.5, 1)
    alphasDist = Uniform(100, 25000)    
    for i in range(nEpochs):
        # TODO: should we constrain alpha0 to the pD, i.e
        # E[P(D)] = alpha1 / sum(alphasRes)
        # P(D) * (alphasRes) = alpha1
        best = float("inf")
        bestParams = []
        for y in range(100):
            pi0 = pi0Dist.sample()
            pis = Uniform(1/nGenes, 1-pi0).sample([K-1])
            pis = pis/(pis.sum() + pi0)
#             print("pi0", pi0, "pis", pis, "sum", pis.sum())
            fnArgs = [*pis, *alphasDist.sample([K,])]

            ll = costFn(fnArgs)
            if ll < best:
                best = ll
                bestParams = fnArgs
        
        print(f"best ll: {best}, bestParams: {bestParams}")

#         fnArgs = [probs[0], probs[1], probs[2], *alphas]
        fit = scipy.optimize.minimize(costFn, x0 = bestParams, method='Nelder-Mead', options={"maxiter": 10000, "adaptive": True})

        if debug:
            print(f"epoch {i}")
            print(fit)

        if not fit["success"] is True:
            if debug:
                print("Failed to converge")
                print(fit)
            continue
        
        
        pi1, pi2, piBoth, alpha0, alpha1, alpha2, alphaBoth = fit["x"]
        if alpha0 < 0 or alpha1 < 0 or alpha2 < 0 or alphaBoth < 0 or pi1 < 0 or pi1 > 1 or pi2 < 0 or pi2 > 1 or piBoth < 0 or piBoth > 1:
            if debug:
                print("Failed to converge")
                print(fit)
            continue
        
        ll = fit["fun"]
        llsAll.append(ll)
        if len(lls) == 0:
            lls.append(ll)
            params.append(fit["x"])
            continue

        minPrevious = min(lls)
        
        if debug:
            print("minPrevious", minPrevious)
            
        # TODO: take mode of some pc-based cluster of parameters, or some auto-encoded cluster
        if ll < minPrevious and (minPrevious - ll) >= minLLDiff:
            if debug:
                print(f"better by at >= {minLLDiff}; new ll: {fit}")

            lls.append(ll)
            params.append(fit["x"])
            
            thresholdHitCount = 0
            continue

        thresholdHitCount += 1
        
        if thresholdHitCount == minLLThresholdCount:
            break
            
    return {"lls": lls, "params": params, "llTrajectory": llsAll}



def initBetaParams(mu, variance):
    alpha = ((1 - mu) / variance - 1 / variance) * mu**2
    beta = alpha * (1/mu -1)
    
    return alpha, beta

In [None]:
###### all named tuples used

Samples = namedtuple("Samples", ["ctrls", "cases"])

In [313]:
# nSamples shape: [nConditions, 2] , last dim is ctrls, cases
# Generating process

# I am gene 1
# i have 3 possible contributions to diseases 1, 2
# 1) I contribute to disease 1 only
# 2) I contribute to disease 2 only
# 3) I contribute to both

# For each gene I have some counts. Some people are only disease 1, some are only disease 2, some are both
# My probability of contributing to disease 1 is a sum of 1 & both, because P(D1|V) = P(D1Only|V) + P(D1And2|V)
# My probability of contributing to diesase 2 is sum of 2 & both

# I have some counts. Those people that are only known only to have disease 1 we estimate are contributing only to disease 1
# Some of

# what is correlation between disease 1 and 2.


# this is insanely slow for some reason, and almost all time is in the expanded binomial sampling
# def genData(nSamples, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
#     # TODO: assert shapes match
#     print("TESTING WITH: nSamples", nSamples, "rrMean", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
#     nConditions = len(nSamples)
#     probs = []
#     afDist = Gamma(concentration=afShape,rate=afShape/afMean)
#     rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
#     rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    
#     # shape == [nGenes, nConditions]
#     afs = afDist.sample([nGenes,])   
#     rrs = rrDist.sample([nGenes,])
#     rrNulls = rrNullDist.sample([nGenes,])
#     for geneIdx in range(nGenes):
#         geneProbs = []
#         for conditionIdx in range(nConditions):
#             # TODO: sample from uniform
#             if geneIdx < nGenes * diseaseFractions[conditionIdx]:
#                 rr = rrs[geneIdx, conditionIdx]
#             else:
#                 rr = rrNulls[geneIdx, conditionIdx]
            
#             probVgivenD = pVgivenD(rr, afs[geneIdx])
#             probVgivenNotD = pVgivenNotD(pDs[conditionIdx], afs[geneIdx], probVgivenD)
            
#             geneProbs.append([probVgivenNotD, probVgivenD])
#         probs.append(geneProbs)
#     probs = tensor(probs)

#     # This should not be slow but is
#     # https://github.com/pytorch/pytorch/issues/11389
#     start = time.time()
#     altCounts = Binomial(total_count=nSamples.expand([nGenes, *nSamples.shape]), probs=probs).sample()
#     print("final sampling took", time.time() - start)
    
#     return altCounts, probs

def genDataSequential(nSamples, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
    # TODO: assert shapes match
    print("TESTING WITH: nSamples", nSamples, "rrMean", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
    nConditions = len(nSamples)
    assert(nConditions == 3)
    altCounts = []
    probs = []
    afDist = Gamma(concentration=afShape,rate=afShape/afMean)
    rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
    rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    
    # shape == [nGenes, nConditions]
    afs = afDist.sample([nGenes,])   
    rrs = rrDist.sample([nGenes,])
    rrNulls = rrNullDist.sample([nGenes,])
    
    endIndices = nGenes * diseaseFractions
    startIndices = []
    for i in range(len(diseaseFractions)):
        if i == 0:
            startIndices.append(0)
            continue
        endIndices[i] += endIndices[i-1]
        startIndices.append(endIndices[i-1])
    
    print("startIndices", startIndices, "endIndices", endIndices)
    affectsCond1Only, affectsCond2Only, affectsBoth = False, False, False
    for geneIdx in range(nGenes):
        geneAltCounts = [0, 0, 0]
        geneProbs = [0, 0, 0]
        rrs = [1, 1, 1]
        # Each gene gets only 1 state: affects condition 1 only, condition 2 only, or both
        # currently, in the both case, the increased in counts (rr) is. the same for both conditions
        for conditionIdx in range(nConditions):
            if geneIdx >= startIndices[conditionIdx] and geneIdx < endIndices[conditionIdx]:
                if conditionIdx == 0:
                    affectsCond1Only = True
                elif conditionIdx == 1:
                    affectsCond2Only = True
                else:
                    affectsBoth = True
        
        # gene affects one of 3 states
        # based on which state it affects, sampleCase1, samplesCase2, samplesBoth get different rrs for this gene
        # controls always get the same value, and that is based on 1 - sum(rrs)
        if affectsCond1Only:
            rrs[0] = rrs[geneIdx, conditionIdx]
            rrs[1] = rrNulls[geneIdx, conditionIdx]
            rrs[2] = rrs[geneIdx, conditionIdx] #both always gets a rr of non-1
        elif affectsCond2Only:
            rrs[0] = rrNulls[geneIdx, conditionIdx]
            rrs[1] = rrs[geneIdx, conditionIdx]
            rrs[2] = rrs[geneIdx, conditionIdx]
        elif affectsBoth:
            # q: should these really get the same value?
            # where is the concept of covariance here?
            rrs[0] = rrs[geneIdx, conditionIdx]
            rrs[1] = rrs[geneIdx, conditionIdx]
            rrs[2] = rrs[geneIdx, conditionIdx]
            
            probVgivenDs = pVgivenD(rr, afs[geneIdx])
            altCountsCases = Binomial(total_count=nSamples[conditionIdx][1], probs=probVgivenD).sample()
            
            # we can use one simulation to study pooled an separate samples
            # in the pooled model, we could sum control samples during inference
            probVgivenNotD = pVgivenNotD(pDs[conditionIdx], afs[geneIdx], probVgivenD)
            altCountsCtrls = Binomial(total_count=nSamples[conditionIdx][0], probs=probVgivenNotD).sample()
            
            geneAltCounts.append([altCountsCtrls, altCountsCases])
            geneProbs.append([probVgivenNotD, probVgivenD])
        altCounts.append(geneAltCounts)
        probs.append(geneProbs)
    altCounts = tensor(altCounts)
    probs = tensor(probs)
    
    return altCounts, probs

def genDataSequentialPooledCtrls(nCases, nCtrls, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
    # TODO: assert shapes match
    print("TESTING POOLED WITH: nCases", nCases, "nCtrls", nCtrls, "rrMean", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
    nConditions = len(nCases)
    altCounts = []
    probs = []
    afDist = Gamma(concentration=afShape,rate=afShape/afMean)
    rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
    rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    
    # shape == [nGenes, nConditions]
    afs = afDist.sample([nGenes,])   
    rrs = rrDist.sample([nGenes,])
    rrNulls = rrNullDist.sample([nGenes,])
    
    endIndices = nGenes * diseaseFractions
    startIndices = []
    for i in range(len(diseaseFractions)):
        if i == 0:
            startIndices.append(0)
            continue
        endIndices[i] += endIndices[i-1]
        startIndices.append(endIndices[i-1])
    
    print("startIndices", startIndices, "endIndices", endIndices)
    for geneIdx in range(nGenes):
        geneAltCounts = []
        geneProbs = []

        probVgivenDs = []
        for conditionIdx in range(nConditions):
            # TODO: sample from uniform
            if geneIdx >= startIndices[conditionIdx] and geneIdx < endIndices[conditionIdx]:
                rr = rrs[geneIdx, conditionIdx]
            else:
                rr = rrNulls[geneIdx, conditionIdx]
            probVgivenDs.append(pVgivenD(rr, afs[geneIdx]))

        probVgivenDs = tensor(probVgivenDs)
#         print("probVgivenDs", probVgivenDs)
        altCountsCases = Binomial(total_count=nCases, probs=probVgivenDs).sample()

#         print("altCountsCases", altCountsCases, "altCountCases.shape", altCountsCases.shape)
#         print("0 index", altCountsCases[0])
        # we can use one simulation to study pooled an separate samples
        # in the pooled model, we could sum control samples during inference
        probVgivenNotD = pVgivenNotD(pDs, afs[geneIdx], probVgivenDs)
#         print("probVgivenNotD", probVgivenNotD)
        
        altCountsCtrls = Binomial(total_count=nCtrls, probs=probVgivenNotD).sample()
#         print("altCountsCtrls", altCountsCtrls)
        
        for conditionIdx in range(nConditions):
            if conditionIdx == 0:
                geneAltCounts.append([altCountsCtrls, altCountsCases[conditionIdx]])
            else:
                geneAltCounts.append([0, altCountsCases[conditionIdx]])
            geneProbs.append([probVgivenNotD, probVgivenDs[conditionIdx]])

        altCounts.append(geneAltCounts)
        probs.append(geneProbs)
    altCounts = tensor(altCounts)
    probs = tensor(probs)
    
    return altCounts, probs

## Another view on generative process
# there is some contribution to disease 1 with probability P(V|D1)
# there is some contribution to disease 2 with probability P(V|D2)
# there is some contribution to both diseases with probability P(V|DBoth)
# there is some contribution to no diseases with probability P(V|None)
# The total P(V) = P(V|D1)*P(D1) + P(V|D2)*P(D2) + P(V|DBoth)*P(DBoth) + P(V|Ctrl)*P(Ctrl) = (say) 1e-4


# P(V|D1) = (P(D1|V) * P(V)) / P(D1)

# When do we see counts for diseases 1? When P(V|D1) + P(V|DBoth)
# A person who is ascertained as having both diseases must be the contribution of 1 - P(V|Ctrl) (P(D1|V) + P(D2|V) + P(V|DBoth))
# A person who has 1 : P(V|D1) + P(V|DBoth) and for person 2 : P(V|D2) + P(V|DBoth)

# If we didn't people ascertained for both, we could use 1 - P(V|Ctrl1) + P(V|Ctrl2) ?

# in our components, component 1 and 2 give no contribution whatsoever, so they are measuring singular effect
# in our 

# def intersection(list1, list2):

# TODO: this generates bivariate data, expand to multinomial

# If i'm a sample with disease1 only, I get the effect1 component for the gene, null2 component for the gene, and the effectBoth component for the gene
# If i'm a sample with disease2 only, I get the null1 component for the gene, effect2 component for the gene, and the effectBoth component for the gene
# If I'm a sample with both, I always get the full contribution

# Finally, genes are either in categories none, 1only, 2only, or both
# For each gene, I think we should sample the state in each category only once

# If I'm a case1only, and if the gene is a risk gene, my count should reflect P(V|D2)
# If I'm a ctrl and gene is associated, the count should reflect P(!D|V)
# If I'm a case2only, I should have P(V|D2)
# If I'm a both case, I think my P(V|D) should reflect a relative risk that is the sum of rr1 + rrLatentBoth if 

# In Dave's picture, a gene that affects disease 1 only adds the same liability to all indviduals 
def genData2(nCases, nCtrls, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
    # TODO: assert shapes match
    print("TESTING WITH: nSamples", nSamples, "rrMean", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
    nConditions = len(nSamples)
    assert(nConditions == 3)
    altCounts = []
    probs = []
    afDist = Gamma(concentration=afShape,rate=afShape/afMean)
    rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
    print("rrDist mean", rrDist.sample([10_000,]).mean(0))
    rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    
    # shape == [nGenes, nConditions]
    afs = afDist.sample([nGenes,])   
    rrs = rrDist.sample([nGenes,])
#     print("rrs", rrs);
    rrNulls = rrNullDist.sample([nGenes,])
    
    endIndices = nGenes * diseaseFractions
    startIndices = []
    for i in range(len(diseaseFractions)):
        if i == 0:
            startIndices.append(0)
            continue
        endIndices[i] += endIndices[i-1]
        startIndices.append(endIndices[i-1])
    
    print("startIndices", startIndices, "endIndices", endIndices)
#     print("rrs", rrs[0:2000,0].mean())
    for geneIdx in range(nGenes):
        geneAltCounts = []
        geneProbs = []
        rrSamples = tensor([1., 1., 1.])
        affects = 0
        # Each gene gets only 1 state: affects condition 1 only, condition 2 only, or both
        # currently, in the both case, the increased in counts (rr) is. the same for both conditions
        for conditionIdx in range(nConditions):
            if geneIdx >= startIndices[conditionIdx] and geneIdx < endIndices[conditionIdx]:
                if conditionIdx == 0:
                    affects = 1
                elif conditionIdx == 1:
                    affects = 2
                elif conditionIdx == 2:
                    affects = 3
                else:
                    assert(conditionIdx <= 2)
                    
        rrSamples = rrNulls[geneIdx, :]
        
        # gene affects one of 3 states
        # based on which state it affects, sampleCase1, samplesCase2, samplesBoth get different rrs for this gene
        # controls always get the same value, and that is based on 1 - sum(rrs)
        if affects == 0:
            rrSamples = rrNulls[geneIdx, :] # Maybe just have == 1, follows 
        elif affects == 1:
#             print(f"affects1: {geneIdx}")
            rrSamples[0] = rrs[geneIdx, 0] #rr for  samples affected with disease 1
            rrSamples[1] = rrNulls[geneIdx, 1] #samples affected with diseases 2
            rrSamples[2] = rrs[geneIdx, 0] #both always gets a rr of non-1
        elif affects == 2:
#             print(f"affects2: {conditionIdx}")
            rrSamples[0] = rrNulls[geneIdx, 0]
            rrSamples[1] = rrs[geneIdx, 1]
            rrSamples[2] = rrs[geneIdx, 1]
        elif affects == 3:
#             print(f"affectsBoth: {conditionIdx}")
            # q: should these really get the same value?
            # where is the concept of covariance here?
            rrSamples[0] = rrs[geneIdx, 2]
            rrSamples[1] = rrs[geneIdx, 2]
            rrSamples[2] = rrs[geneIdx, 2]
        
        probVgivenDs = pVgivenD(rrSamples, afs[geneIdx])
#         print("rrs", rrSamples, rrNulls[geneIdx, :], rrs[geneIdx, :])
#         print(probVgivenDs)
#         print("probVgivenDs", probVgivenDs, "for rrs", rrSamples)
        altCountsCases = Binomial(total_count=nCases, probs=probVgivenDs).sample()
#         print("altCountsCases", altCountsCases)
        # we can use one simulation to study pooled an separate samples
        # in the pooled model, we could sum control samples during inference
        probVgivenNotD = pVgivenNotD(pDs, afs[geneIdx], probVgivenDs)
#         print("nCases", nCases, "probVgivenDs/probVgivenNotD", probVgivenDs / probVgivenNotD)
        altCountsCtrls = Binomial(total_count=nCtrls, probs=probVgivenNotD).sample()
#         print("altCountsCtrls", altCountsCtrls)
        
        for conditionIdx in range(nConditions):
            if conditionIdx == 0:
                geneAltCounts.append([altCountsCtrls, altCountsCases[conditionIdx]])
            else:
                geneAltCounts.append([0, altCountsCases[conditionIdx]])
            geneProbs.append([probVgivenNotD, probVgivenDs[conditionIdx]])
#         print(f"geneAltCoutns {geneAltCounts} \n")
        altCounts.append(geneAltCounts)
        probs.append(geneProbs)
#     print("altCounts", altCounts)
    altCounts = tensor(altCounts)
    probs = tensor(probs)
    
    return altCounts, probs

# like 2, but in "both" component, samples that are case1 only get rr1, samples that are case2 only get rr2, samples that are both get rrBoth
# If i'm a sample that is affected by both conditions, that means
# that I get the invididual contribution from each, and I also get the shared component
# XShared ~ X1 + X3 + X2
# If I'm an individual only affected by d1, I get Y1 = X1 + X3
# If I'm an individual only affected by d2 I get Y2 = X2 + X3
# If I'm an individual affected by both I think I get Y3 = X1 + X2 + X3
# What if: in genes that affect only 1 disease give X1, genes that affect d2 only give X2, genes that affect both get d2,
def genData3(nCases, nCtrls, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
    # TODO: assert shapes match
    print("TESTING WITH: nSamples", nSamples, "rrMean", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
    nConditions = len(nSamples)
    assert(nConditions == 3)
    altCounts = []
    probs = []
    afDist = Gamma(concentration=afShape,rate=afShape/afMean)
    rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
    print("rrDist mean", rrDist.sample([10_000,]).mean())
    rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    
    # shape == [nGenes, nConditions]
    afs = afDist.sample([nGenes,])   
    rrs = rrDist.sample([nGenes,])
#     print("rrs", rrs);
    rrNulls = rrNullDist.sample([nGenes,])
    
    endIndices = nGenes * diseaseFractions
    startIndices = []
    for i in range(len(diseaseFractions)):
        if i == 0:
            startIndices.append(0)
            continue
        endIndices[i] += endIndices[i-1]
        startIndices.append(endIndices[i-1])
    
    print("startIndices", startIndices, "endIndices", endIndices)
#     print("rrs", rrs[0:2000,0].mean())
    for geneIdx in range(nGenes):
        geneAltCounts = []
        geneProbs = []
        rrSamples = tensor([1., 1., 1.])
        affects = 0
        # Each gene gets only 1 state: affects condition 1 only, condition 2 only, or both
        # currently, in the both case, the increased in counts (rr) is. the same for both conditions
        for conditionIdx in range(nConditions):
            if geneIdx >= startIndices[conditionIdx] and geneIdx < endIndices[conditionIdx]:
                if conditionIdx == 0:
                    affects = 1
                elif conditionIdx == 1:
                    affects = 2
                elif conditionIdx == 2:
                    affects = 3
                else:
                    assert(conditionIdx <= 2)
                    
        rrSamples = rrNulls[geneIdx, :]
        
        # gene affects one of 3 states
        # based on which state it affects, sampleCase1, samplesCase2, samplesBoth get different rrs for this gene
        # controls always get the same value, and that is based on 1 - sum(rrs)
        if affects == 0:
            rrSamples = rrNulls[geneIdx, :]
        elif affects == 1:
#             print(f"affects1: {geneIdx}")
            rrSamples[0] = rrs[geneIdx, 0]
            rrSamples[1] = rrNulls[geneIdx, 1]
            rrSamples[2] = rrs[geneIdx, 0] #both always gets a rr of non-1
        elif affects == 2:
#             print(f"affects2: {conditionIdx}")
            rrSamples[0] = rrNulls[geneIdx, 0]
            rrSamples[1] = rrs[geneIdx, 1]
            rrSamples[2] = rrs[geneIdx, 1]
        elif affects == 3:
#           # Here is the difference, samples with both get their own risk
            rrSamples = rrs[geneIdx, :]
        
        probVgivenDs = pVgivenD(rrSamples, afs[geneIdx])
#         print("rrs", rrSamples, rrNulls[geneIdx, :], rrs[geneIdx, :])

#         print("probVgivenDs", probVgivenDs, "for rrs", rrSamples)
        altCountsCases = Binomial(total_count=nCases, probs=probVgivenDs).sample()
#         print("altCountsCases", altCountsCases)
        # we can use one simulation to study pooled an separate samples
        # in the pooled model, we could sum control samples during inference
        probVgivenNotD = pVgivenNotD(pDs, afs[geneIdx], probVgivenDs)
#         print("nCases", nCases, "probVgivenDs/probVgivenNotD", probVgivenDs / probVgivenNotD)
        altCountsCtrls = Binomial(total_count=nCtrls, probs=probVgivenNotD).sample()
#         print("altCountsCtrls", altCountsCtrls)
        
        for conditionIdx in range(nConditions):
            if conditionIdx == 0:
                geneAltCounts.append([altCountsCtrls, altCountsCases[conditionIdx]])
            else:
                geneAltCounts.append([0, altCountsCases[conditionIdx]])
            geneProbs.append([probVgivenNotD, probVgivenDs[conditionIdx]])
#         print(f"geneAltCoutns {geneAltCounts} \n")
        altCounts.append(geneAltCounts)
        probs.append(geneProbs)
#     print("altCounts", altCounts)
    altCounts = tensor(altCounts)
    probs = tensor(probs)
    
    return altCounts, probs

# like 2, but in "both" component, samples that are case1 only get rr1, samples that are case2 only get rr2, samples that are both get rr1+rrBoth, rr2+rrBoth, or rr1 + rr2 + rrBoth
# I think this one assumes that rrBoth is small (it's not the relative risk of genes that contribute to both, but the shared contribution...genes that contribute to both diseases
# confer with rr1only, rr2only, rr1only + rr2only + rrBoth)
# If i'm a sample that is affected by both conditions, that means
# that I get the invididual contribution from each, and I also get the shared component
# XShared ~ X1 + X3 + X2
# If I'm an individual only affected by d1, I get Y1 = X1 + X3
# If I'm an individual only affected by d2 I get Y2 = X2 + X3
# If I'm an individual affected by both I think I get Y3 = X1 + X2 + X3
# What if: in genes that affect only 1 disease give X1, genes that affect d2 only give X2, genes that affect both get d2,
def genData4(nCases, nCtrls, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
    # TODO: assert shapes match
    print("TESTING WITH: nCases", nCases, "nCtrls", nCtrls, "rrMeans", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
    nConditions = len(nCases)
    assert(nConditions == 3)
    altCounts = []
    probs = []
    afDist = Gamma(concentration=afShape,rate=afShape/afMean)
    rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
    print("rrDist mean", rrDist.sample([10_000,]).mean(0))
    rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    
    # shape == [nGenes, nConditions]
    afs = afDist.sample([nGenes,])   
    rrs = rrDist.sample([nGenes,])
    rrNulls = rrNullDist.sample([nGenes,]) #tensor([1.]).expand([20_000, nConditions])#

    endIndices = nGenes * diseaseFractions
    startIndices = []
    for i in range(len(diseaseFractions)):
        if i == 0:
            startIndices.append(0)
            continue
        endIndices[i] += endIndices[i-1]
        startIndices.append(endIndices[i-1])
    
    print("startIndices", startIndices, "endIndices", endIndices)
    
    affectedGenes = [[]]
    unaffectedGenes = []
    for geneIdx in range(nGenes):
        geneAltCounts = []
        geneProbs = []
        rrSamples = tensor([1., 1., 1.])
        affects = 0
        # Each gene gets only 1 state: affects condition 1 only, condition 2 only, or both
        # currently, in the both case, the increased in counts (rr) is. the same for both conditions
        for conditionIdx in range(nConditions):
            if geneIdx >= startIndices[conditionIdx] and geneIdx < endIndices[conditionIdx]:
                if conditionIdx == 0:
                    affects = 1
                elif conditionIdx == 1:
                    affects = 2
                elif conditionIdx == 2:
                    affects = 3
                else:
                    assert(conditionIdx <= 2)
                
                
                if len(affectedGenes) <= conditionIdx:
                    affectedGenes.append([])
                affectedGenes[conditionIdx].append(geneIdx)
                break
        
        # gene affects one of 3 states
        # based on which state it affects, sampleCase1, samplesCase2, samplesBoth get different rrs for this gene
        # controls always get the same value, and that is based on 1 - sum(rrs)
        if affects == 0:
            rrSamples = rrNulls[geneIdx, :]
            unaffectedGenes.append(geneIdx)
        elif affects == 1:
#             print(f"affects1: {geneIdx}")
            rrSamples[0] = rrs[geneIdx, 0]
            rrSamples[1] = rrNulls[geneIdx, 1]
            rrSamples[2] = rrs[geneIdx, 0] #both always gets a rr of non-1
        elif affects == 2:
#             print(f"affects2: {conditionIdx}")
            rrSamples[0] = rrNulls[geneIdx, 0]
            rrSamples[1] = rrs[geneIdx, 1]
            rrSamples[2] = rrs[geneIdx, 1]
        elif affects == 3:
            rrSamples[0] = rrs[geneIdx, 0] + rrs[geneIdx, 2]
            rrSamples[1] = rrs[geneIdx, 1] + rrs[geneIdx, 2]
            rrSamples[2] = rrs[geneIdx, 0] + rrs[geneIdx, 1] + rrs[geneIdx, 2]
#         print("affects", affects, "rrSamples", rrSamples)
        probVgivenDs = pVgivenD(rrSamples, afs[geneIdx])
        probVgivenNotD = pVgivenNotD(pDs, afs[geneIdx], probVgivenDs)

        altCountsCases = Binomial(total_count=nCases, probs=probVgivenDs).sample()
        
        altCountsCtrls = Binomial(total_count=nCtrls, probs=probVgivenNotD).sample()
        
        for conditionIdx in range(nConditions):
            if conditionIdx == 0:
                geneAltCounts.append([altCountsCtrls, altCountsCases[conditionIdx]])
            else:
                geneAltCounts.append([0, altCountsCases[conditionIdx]])
            geneProbs.append([probVgivenNotD, probVgivenDs[conditionIdx]])
        altCounts.append(geneAltCounts)
        probs.append(geneProbs)

    altCounts = tensor(altCounts)
    probs = tensor(probs)
    
    # cannot convert affectedGenes to tensor; apparently tensors need to have same dimensions at each level of the tensor...stupid
    return altCounts, probs, affectedGenes, unaffectedGenes

# genData4, but with 1 rr for null loci
def genData4b(nCases, nCtrls, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
    # TODO: assert shapes match
    print("TESTING WITH: nCases", nCases, "nCtrls", nCtrls, "rrMeans", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
    nConditions = len(nCases)
    assert(nConditions == 3)
    altCounts = []
    probs = []
    afDist = Gamma(concentration=afShape,rate=afShape/afMean)
    rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
    print("rrDist mean", rrDist.sample([10_000,]).mean(0))
#     rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    
    # shape == [nGenes, nConditions]
    afs = afDist.sample([nGenes,])   
    rrs = rrDist.sample([nGenes,])

    endIndices = nGenes * diseaseFractions
    startIndices = []
    for i in range(len(diseaseFractions)):
        if i == 0:
            startIndices.append(0)
            continue
        endIndices[i] += endIndices[i-1]
        startIndices.append(endIndices[i-1])
    
    print("startIndices", startIndices, "endIndices", endIndices)
    
    affectedGenes = [[]]
    unaffectedGenes = []

    for geneIdx in range(nGenes):
        geneAltCounts = []
        geneProbs = []
        affects = 0
        rrSamples = tensor([1., 1., 1.])
        # Each gene gets only 1 state: affects condition 1 only, condition 2 only, or both
        # currently, in the both case, the increased in counts (rr) is. the same for both conditions
        for conditionIdx in range(nConditions):
            if geneIdx >= startIndices[conditionIdx] and geneIdx < endIndices[conditionIdx]:
                if conditionIdx == 0:
                    affects = 1
                elif conditionIdx == 1:
                    affects = 2
                elif conditionIdx == 2:
                    affects = 3
                else:
                    assert(conditionIdx <= 2)
                
                
                if len(affectedGenes) <= conditionIdx:
                    affectedGenes.append([])
                affectedGenes[conditionIdx].append(geneIdx)
                break
        
        assert(affects <= 3)
        # gene affects one of 3 states
        # based on which state it affects, sampleCase1, samplesCase2, samplesBoth get different rrs for this gene
        # controls always get the same value, and that is based on 1 - sum(rrs)
        if affects == 0:
            unaffectedGenes.append(geneIdx)
        elif affects == 1:
#             print(f"affects1: {geneIdx}")
            rrSamples[0] = rrs[geneIdx, 0]
            rrSamples[2] = rrs[geneIdx, 0] #both always gets a rr of non-1
        elif affects == 2:
#             print(f"affects2: {geneIdx}")
            rrSamples[1] = rrs[geneIdx, 1]
            rrSamples[2] = rrs[geneIdx, 1]
        elif affects == 3:
#             print(f"affects2: {geneIdx}")
            rrSamples[0] = rrs[geneIdx, 0] + rrs[geneIdx, 2]
            rrSamples[1] = rrs[geneIdx, 1] + rrs[geneIdx, 2]
            rrSamples[2] = rrs[geneIdx, 0] + rrs[geneIdx, 1] + rrs[geneIdx, 2]
#         print("affects", affects, "rrSamples", rrSamples)
        probVgivenDs = pVgivenD(rrSamples, afs[geneIdx])
        probVgivenNotD = pVgivenNotD(pDs, afs[geneIdx], probVgivenDs)

        altCountsCases = Binomial(total_count=nCases, probs=probVgivenDs).sample()
        
        altCountsCtrls = Binomial(total_count=nCtrls, probs=probVgivenNotD).sample()
        
        for conditionIdx in range(nConditions):
            if conditionIdx == 0:
                geneAltCounts.append([altCountsCtrls, altCountsCases[conditionIdx]])
            else:
                geneAltCounts.append([0, altCountsCases[conditionIdx]])
            geneProbs.append([probVgivenNotD, probVgivenDs[conditionIdx]])
        altCounts.append(geneAltCounts)
        probs.append(geneProbs)

    altCounts = tensor(altCounts)
    probs = tensor(probs)
    
    # cannot convert affectedGenes to tensor; apparently tensors need to have same dimensions at each level of the tensor...stupid
    return altCounts, probs, affectedGenes, unaffectedGenes

# genData4b, but summing probabilities instead of relative risks
def genData4c(nCases, nCtrls, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
    # TODO: assert shapes match
    print("TESTING WITH: nCases", nCases, "nCtrls", nCtrls, "rrMeans", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
    nConditions = len(nCases)
    assert(nConditions == 3)
    altCounts = []
    probs = []
    afDist = Gamma(concentration=afShape,rate=afShape/afMean)
    rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
    print("rrDist mean", rrDist.sample([10_000,]).mean(0))
    
    rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    rrNulls = rrNullDist.sample([nGenes,]) #tensor([1.]).expand([20_000, nConditions])#
    
    # shape == [nGenes, nConditions]
    afs = afDist.sample([nGenes,])   
    rrs = rrDist.sample([nGenes,])

    endIndices = nGenes * diseaseFractions
    startIndices = []
    for i in range(len(diseaseFractions)):
        if i == 0:
            startIndices.append(0)
            continue
        endIndices[i] += endIndices[i-1]
        startIndices.append(endIndices[i-1])
    
    print("startIndices", startIndices, "endIndices", endIndices)
    
    affectedGenes = [[]]
    unaffectedGenes = []

    for geneIdx in range(nGenes):
        geneAltCounts = []
        geneProbs = []
        affects = 0
        rrSamples = tensor([1., 1., 1.])
        # Each gene gets only 1 state: affects condition 1 only, condition 2 only, or both
        # currently, in the both case, the increased in counts (rr) is. the same for both conditions
        for conditionIdx in range(nConditions):
            if geneIdx >= startIndices[conditionIdx] and geneIdx < endIndices[conditionIdx]:
                if conditionIdx == 0:
                    affects = 1
                elif conditionIdx == 1:
                    affects = 2
                elif conditionIdx == 2:
                    affects = 3
                else:
                    assert(conditionIdx <= 2)
                
                
                if len(affectedGenes) <= conditionIdx:
                    affectedGenes.append([])
                affectedGenes[conditionIdx].append(geneIdx)
                break
        
        assert(affects <= 3)
        PVDcases = pVgivenD(rrNulls[geneIdx], afs[geneIdx])
        # gene affects one of 3 states
        # based on which state it affects, sampleCase1, samplesCase2, samplesBoth get different rrs for this gene
        # controls always get the same value, and that is based on 1 - sum(rrs)
        if affects == 0:
            unaffectedGenes.append(geneIdx)
        elif affects == 1:
            PVDcases[0] = pVgivenD(rrs[geneIdx, 0], afs[geneIdx])
            PVDcases[2] = PVDcases[0]
        elif affects == 2:
            PVDcases[1] = pVgivenD(rrs[geneIdx, 1], afs[geneIdx])
            PVDcases[2] = PVDcases[0]
        elif affects == 3:
            pvds = pVgivenD(rrs[geneIdx], afs[geneIdx])
            PVDcases[0] = pvds[0] + pvds[2]
            PVDcases[1] = pvds[1] + pvds[2]
            PVDcases[2] = pvds[0] + pvds[1] + pvds[2]

        PVNotD = pVgivenNotD(pDs, afs[geneIdx], PVDcases)

        altCountsCases = Binomial(total_count=nCases, probs=PVDcases).sample()
        
        altCountsCtrls = Binomial(total_count=nCtrls, probs=PVNotD).sample()
        
        for conditionIdx in range(nConditions):
            if conditionIdx == 0:
                geneAltCounts.append([altCountsCtrls, altCountsCases[conditionIdx]])
            else:
                geneAltCounts.append([0, altCountsCases[conditionIdx]])
            geneProbs.append([PVNotD, PVDcases[conditionIdx]])
        altCounts.append(geneAltCounts)
        probs.append(geneProbs)

    altCounts = tensor(altCounts)
    probs = tensor(probs)
    
    # cannot convert affectedGenes to tensor; apparently tensors need to have same dimensions at each level of the tensor...stupid
    return altCounts, probs, affectedGenes, unaffectedGenes


# Like the 4b case, but multinomial
# TODO: shoudl we do int() or some rounding function to go from float counts to int counts
def genData5(nCases, nCtrls, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
    # TODO: assert shapes match
    print("TESTING WITH: nCases", nCases, "nCtrls", nCtrls, "rrMeans", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
    nConditions = len(nCases)
    assert(nConditions == 3)
    altCounts = []
    probs = []
    afDist = Gamma(concentration=afShape,rate=afShape/afMean)
    rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
    print("rrDist mean", rrDist.sample([10_000,]).mean(0))
#     rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    
    # shape == [nGenes, nConditions]
    afs = afDist.sample([nGenes,])   
    rrs = rrDist.sample([nGenes,])

    endIndices = nGenes * diseaseFractions
    startIndices = []
    for i in range(len(diseaseFractions)):
        if i == 0:
            startIndices.append(0)
            continue
        endIndices[i] += endIndices[i-1]
        startIndices.append(endIndices[i-1])
    
    print("startIndices", startIndices, "endIndices", endIndices)
    
    affectedGenes = [[]]
    unaffectedGenes = []
    rrAll = []
    
    totalSamples = int(nCtrls + nCases.sum())
    print("totalSamples", totalSamples)
    for geneIdx in range(nGenes):
        geneAltCounts = []
        geneProbs = []
        affects = 0
        rrSamples = tensor([1., 1., 1.])
        # Each gene gets only 1 state: affects condition 1 only, condition 2 only, or both
        # currently, in the both case, the increased in counts (rr) is. the same for both conditions
        for conditionIdx in range(nConditions):
            if geneIdx >= startIndices[conditionIdx] and geneIdx < endIndices[conditionIdx]:
                if conditionIdx == 0:
                    affects = 1
                elif conditionIdx == 1:
                    affects = 2
                elif conditionIdx == 2:
                    affects = 3
                else:
                    assert(conditionIdx <= 2)
                
                
                if len(affectedGenes) <= conditionIdx:
                    affectedGenes.append([])
                affectedGenes[conditionIdx].append(geneIdx)
                break
        
        assert(affects <= 3)
        # gene affects one of 3 states
        # based on which state it affects, sampleCase1, samplesCase2, samplesBoth get different rrs for this gene
        # controls always get the same value, and that is based on 1 - sum(rrs)
        if affects == 0:
            unaffectedGenes.append(geneIdx)
        elif affects == 1:
#             print(f"affects1: {geneIdx}")
            rrSamples[0] = rrs[geneIdx, 0]
            rrSamples[2] = rrs[geneIdx, 0] #both always gets a rr of non-1
        elif affects == 2:
#             print(f"affects2: {geneIdx}")
            rrSamples[1] = rrs[geneIdx, 1]
            rrSamples[2] = rrs[geneIdx, 1]
        elif affects == 3:
#             print(f"affects2: {geneIdx}")
            rrSamples[0] = rrs[geneIdx, 0] + rrs[geneIdx, 2]
            rrSamples[1] = rrs[geneIdx, 1] + rrs[geneIdx, 2]
            rrSamples[2] = rrs[geneIdx, 0] + rrs[geneIdx, 1] + rrs[geneIdx, 2]
#         print("affects", affects, "rrSamples", rrSamples)
        probVgivenDs = pVgivenD(rrSamples, afs[geneIdx])
        probVgivenNotD = pVgivenNotD(pDs, afs[geneIdx], probVgivenDs)
        
        totalProbability = (probVgivenDs * pDs).sum() + probVgivenNotD * (1 - pDs.sum())
#         print("af", afs[geneIdx], "probVgivenDs", probVgivenDs, "pDs", pDs, "probVgivenNotD", probVgivenNotD, "totalProbability", totalProbability)
        
        assert abs(totalProbability-afs[geneIdx]) / afs[geneIdx]  < 0.001
        marginalAlleleCount = int(totalProbability * totalSamples)
#         print("marginal allele count", marginalAlleleCount)
        
        p=tensor([probVgivenNotD, *probVgivenDs])
#         print("probs", probs)
        # without .numpy() can't later convert tensor(altCounts) : "only tensors can be converted to Python scalars"
        altCountsGene = Multinomial(probs=p, total_count=marginalAlleleCount).sample().numpy()
        
#         print("altCountsGene", altCountsGene)
        altCounts.append(altCountsGene)
        probs.append([probVgivenNotD, *probVgivenDs])
        rrAll.append(rrSamples)
    altCounts = tensor(altCounts)
    probs = tensor(probs)
    
    # cannot convert affectedGenes to tensor; apparently tensors need to have same dimensions at each level of the tensor...stupid
    return { "altCounts": altCounts, "afs": probs, "affectedGenes": affectedGenes, "unaffectedGenes": unaffectedGenes, "rrs": rrAll }

# Like the 5 case, but don't sum rr's, sum P(V|D)'s instead
def genData6(nCases, nCtrls, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
    # TODO: assert shapes match
    print("TESTING WITH: nCases", nCases, "nCtrls", nCtrls, "rrMeans", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
    nConditions = len(nCases)
    assert(nConditions == 3)
    altCounts = []
    probs = []
    afDist = Gamma(concentration=afShape,rate=afShape/afMean)
    rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
    print("rrDist mean", rrDist.sample([10_000,]).mean(0))
#     rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    
    # shape == [nGenes, nConditions]
    afs = afDist.sample([nGenes,])   
    rrs = rrDist.sample([nGenes,])

    endIndices = nGenes * diseaseFractions
    startIndices = []
    for i in range(len(diseaseFractions)):
        if i == 0:
            startIndices.append(0)
            continue
        endIndices[i] += endIndices[i-1]
        startIndices.append(endIndices[i-1])
    
    print("startIndices", startIndices, "endIndices", endIndices)
    
    affectedGenes = [[]]
    unaffectedGenes = []
    
    totalSamples = int(nCtrls + nCases.sum())

    print("totalSamples", totalSamples)
    for geneIdx in range(nGenes):
        geneAltCounts = []
        geneProbs = []
        affects = 0

        # Each gene gets only 1 state: affects condition 1 only, condition 2 only, or both
        # currently, in the both case, the increased in counts (rr) is. the same for both conditions
        for conditionIdx in range(nConditions):
            if geneIdx >= startIndices[conditionIdx] and geneIdx < endIndices[conditionIdx]:
                if conditionIdx == 0:
                    affects = 1
                elif conditionIdx == 1:
                    affects = 2
                elif conditionIdx == 2:
                    affects = 3
                else:
                    assert(conditionIdx <= 2)
                
                
                if len(affectedGenes) <= conditionIdx:
                    affectedGenes.append([])
                affectedGenes[conditionIdx].append(geneIdx)
                break
        
        assert(affects <= 3)
        
        
        PVDcases = pVgivenD(tensor([1., 1., 1.]), afs[geneIdx])
        # gene affects one of 3 states
        # based on which state it affects, sampleCase1, samplesCase2, samplesBoth get different rrs for this gene
        # controls always get the same value, and that is based on 1 - sum(rrs)
        if affects == 0:
            unaffectedGenes.append(geneIdx)
        elif affects == 1:
            PVDcases[0] = pVgivenD(rrs[geneIdx, 0], afs[geneIdx])
            PVDcases[2] = PVDcases[0]
        elif affects == 2:
            PVDcases[1] = pVgivenD(rrs[geneIdx, 1], afs[geneIdx])
            PVDcases[2] = PVDcases[0]
        elif affects == 3:
            pvds = pVgivenD(rrs[geneIdx], afs[geneIdx])
            PVDcases[0] = pvds[0] + pvds[2]
            PVDcases[1] = pvds[1] + pvds[2]
            PVDcases[2] = pvds[0] + pvds[1] + pvds[2]

        PVNotD = pVgivenNotD(pDs, afs[geneIdx], PVDcases)
        
        totalProbability = (PVDcases * pDs).sum() + PVNotD * (1 - pDs.sum())
#         print("affects", affects, "af", afs[geneIdx], "PVDcases", PVDcases, "PVNotD", PVNotD, "pDs", pDs, "totalProbability", totalProbability)
        
        assert abs(totalProbability-afs[geneIdx]) / afs[geneIdx]  < 0.001
        marginalAlleleCount = int(totalProbability * totalSamples)
#         print("marginal allele count", marginalAlleleCount)
        
        p=tensor([PVNotD, *PVDcases])
#         print("probs", probs)
        # without .numpy() can't later convert tensor(altCounts) : "only tensors can be converted to Python scalars"
        altCountsGene = Multinomial(probs=p, total_count=marginalAlleleCount).sample().numpy()
        
#         print("altCountsGene", altCountsGene)
        altCounts.append(altCountsGene)
        probs.append([PVNotD, *PVDcases])
    altCounts = tensor(altCounts)
    probs = tensor(probs)
    
    # cannot convert affectedGenes to tensor; apparently tensors need to have same dimensions at each level of the tensor...stupid
    return { "altCounts": altCounts, "afs": probs, "affectedGenes": affectedGenes, "unaffectedGenes": unaffectedGenes, "rrs": rrs }


# Like the 6 case, but we scale P(V|Ds) by prevalence, since the actual sample sizes say for the binomial in which P(V|D1) would be used is the fraction P(D1) of the total
# and in the multionmial setting, we use only a single sample size
# for instance, lets say we have 500k controls, 1000 cases
# the P(V|D) (cases) may be .0001 and P(V|!D) may  .0001, but the probability in a multinomial should really be 99.9999% in favor of controls
def genData7(nCases, nCtrls, pDs, diseaseFractions, rrShape, rrMeans, afMean, afShape, nGenes = 20000):
    # TODO: assert shapes match
    print("TESTING WITH: nCases", nCases, "nCtrls", nCtrls, "rrMeans", rrMeans, "rrShape", rrShape, "afMean", afMean, "afShape", afShape, "diseaseFractions", diseaseFractions, "pDs", pDs)
    
    nConditions = len(nCases)
    assert(nConditions == 3)
    altCounts = []
    probs = []
    afDist = Gamma(concentration=afShape,rate=afShape/afMean)
    rrDist = Gamma(concentration=rrShape,rate=rrShape/rrMeans)
    print("rrDist mean", rrDist.sample([10_000,]).mean(0))
#     rrNullDist = Gamma(concentration=rrShape,rate=rrShape.expand(nConditions))
    
    # shape == [nGenes, nConditions]
    afs = afDist.sample([nGenes,])   
    rrs = rrDist.sample([nGenes,])

    endIndices = nGenes * diseaseFractions
    startIndices = []
    for i in range(len(diseaseFractions)):
        if i == 0:
            startIndices.append(0)
            continue
        endIndices[i] += endIndices[i-1]
        startIndices.append(endIndices[i-1])
    
    print("startIndices", startIndices, "endIndices", endIndices)
    
    affectedGenes = [[]]
    unaffectedGenes = []
    
    totalSamples = int(nCtrls + nCases.sum())

    print("totalSamples", totalSamples)
    for geneIdx in range(nGenes):
        geneAltCounts = []
        geneProbs = []
        affects = 0

        # Each gene gets only 1 state: affects condition 1 only, condition 2 only, or both
        # currently, in the both case, the increased in counts (rr) is. the same for both conditions
        for conditionIdx in range(nConditions):
            if geneIdx >= startIndices[conditionIdx] and geneIdx < endIndices[conditionIdx]:
                if conditionIdx == 0:
                    affects = 1
                elif conditionIdx == 1:
                    affects = 2
                elif conditionIdx == 2:
                    affects = 3
                else:
                    assert(conditionIdx <= 2)
                
                
                if len(affectedGenes) <= conditionIdx:
                    affectedGenes.append([])
                affectedGenes[conditionIdx].append(geneIdx)
                break
        
        assert(affects <= 3)
        
        
        PVDcases = pVgivenD(tensor([1., 1., 1.]), afs[geneIdx])
        # gene affects one of 3 states
        # based on which state it affects, sampleCase1, samplesCase2, samplesBoth get different rrs for this gene
        # controls always get the same value, and that is based on 1 - sum(rrs)
        if affects == 0:
            unaffectedGenes.append(geneIdx)
        elif affects == 1:
            PVDcases[0] = pVgivenD(rrs[geneIdx, 0], afs[geneIdx])
            PVDcases[2] = PVDcases[0]
        elif affects == 2:
            PVDcases[1] = pVgivenD(rrs[geneIdx, 1], afs[geneIdx])
            PVDcases[2] = PVDcases[0]
        elif affects == 3:
            pvds = pVgivenD(rrs[geneIdx], afs[geneIdx])
            PVDcases[0] = pvds[0] + pvds[2]
            PVDcases[1] = pvds[1] + pvds[2]
            PVDcases[2] = pvds[0] + pvds[1] + pvds[2]

        PVNotD = pVgivenNotD(pDs, afs[geneIdx], PVDcases) #* (1 - pDs.sum())
        PVDcases = PVDcases #* pDs
            
        # P(D|V)/P(V)
        PVDprevalenceWeighted = PVDcases * pDs
        PVNotDprevalenceWeighted = PVNotD * (1 - pDs.sum())
        totalProbability = PVDprevalenceWeighted.sum() + PVNotDprevalenceWeighted
#         print("affects", affects, "af", afs[geneIdx], "PVDcases", PVDcases, "pDs", pDs, "PVNotD", PVNotD, "totalProbability", totalProbability)
        
        assert abs(totalProbability-afs[geneIdx]) / afs[geneIdx]  < 0.00001
        marginalAlleleCount = int(totalProbability * totalSamples)
#         print("marginal allele count", marginalAlleleCount)
        
        p=tensor([PVNotDprevalenceWeighted, *PVDprevalenceWeighted])
#         print("probs", probs)
        # without .numpy() can't later convert tensor(altCounts) : "only tensors can be converted to Python scalars"
        altCountsGene = Multinomial(probs=p, total_count=marginalAlleleCount).sample().numpy()
        
#         print("altCountsGene", altCountsGene)
        altCounts.append(altCountsGene)

        probs.append([PVNotDprevalenceWeighted, *PVDprevalenceWeighted])
    altCounts = tensor(altCounts)
    probs = tensor(probs)
    
    # cannot convert affectedGenes to tensor; apparently tensors need to have same dimensions at each level of the tensor...stupid
    return { "altCounts": altCounts, "afs": probs, "affectedGenes": affectedGenes, "unaffectedGenes": unaffectedGenes, "rrs": rrs }


def flattenAltCounts(altCounts, afs):
    altCountsFlatPooled = []
    afsFlatPooled = []
    for geneIdx in range(nGenes):
        altCountsFlatPooled.append([altCounts[geneIdx, 0, 0], *altCounts[geneIdx, :, 1].flatten()])
        afsFlatPooled.append([afs[geneIdx, 0, 0], *afs[geneIdx, :, 1].flatten()])

    altCountsFlatPooled = tensor(altCountsFlatPooled)
    afsFlatPooled = tensor(afsFlatPooled)
    print("altCountsFlatPooled", altCountsFlatPooled)
    print("afsFlatPooled", afsFlatPooled)

    flattenedData = []

    for geneAfData in afs:
        flattenedData.append([geneAfData[0][0],*geneAfData[:, 1]])
    flattenedData = tensor(flattenedData)
    
    return altCountsFlatPooled, afsFlatPooled, flattenedData

In [296]:
pVgivenNotD?

[0;31mSignature:[0m [0mpVgivenNotD[0m[0;34m([0m[0mpD[0m[0;34m,[0m [0mpV[0m[0;34m,[0m [0mpVgivenD[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      ~/projects/tada/<ipython-input-194-24ad4bae7d5b>
[0;31mType:[0m      function


In [217]:
pVex = 1e-4
pDex = tensor([0.0098, 0.0098, 0.0020])
pNotDex = 1 - pDex.sum()
print("pNotDex", pNotDex)
print("(pVex - (pDex*pVgivenDex).sum())", (pVex - (pDex*pVgivenDex).sum()))
pVgivenDex = tensor([0.0005, 0.0001, 0.0005])

(pVex - (pDex*pVgivenDex).sum()) / (1 - pDex.sum())

pNotDex tensor(0.9784)
(pVex - (pDex*pVgivenDex).sum()) tensor(9.3120e-05)


tensor(9.5176e-05)

In [222]:
start = time.time()
# genDataSequentialPooledCtrls(nCases=nCasesLarge, nCtrls=nCtrlsLarge, pDs=pDsGlobalLarge, diseaseFractions=diseaseFractions, rrShape=rrShape, rrMeans=rrMeansCovary, afMean=afMean, afShape=afShape, nGenes=nGenes)
r = genData6(**genParams(rrMeans=tensor([5., 5., 2.]), pis=tensor([.05, .05, .05]))[0])
print("took", time.time() - start)

pDs are: tensor([0.0098, 0.0098, 0.0020])
TESTING WITH: nCases tensor([5000., 5000., 1000.]) nCtrls tensor(500000.) rrMeans tensor([5., 5., 2.]) rrShape tensor(10.) afMean tensor(1.0000e-04) afShape tensor(10.) diseaseFractions tensor([0.0500, 0.0500, 0.0500]) pDs tensor([0.0098, 0.0098, 0.0020])
rrDist mean tensor([5.0214, 5.0024, 2.0027])
startIndices [0, tensor(1000.), tensor(2000.)] endIndices tensor([1000., 2000., 3000.])
totalSamples 511000
took 6.093706369400024


In [309]:
start = time.time()
# genDataSequentialPooledCtrls(nCases=nCasesLarge, nCtrls=nCtrlsLarge, pDs=pDsGlobalLarge, diseaseFractions=diseaseFractions, rrShape=rrShape, rrMeans=rrMeansCovary, afMean=afMean, afShape=afShape, nGenes=nGenes)
r7 = genData7(**genParams(rrMeans=tensor([5., 5., 2.]), pis=tensor([.05, .05, .05]))[0])
print("took", time.time() - start)

pDs are: tensor([0.0098, 0.0098, 0.0020])
TESTING WITH: nCases tensor([5000., 5000., 1000.]) nCtrls tensor(500000.) rrMeans tensor([5., 5., 2.]) rrShape tensor(10.) afMean tensor(1.0000e-04) afShape tensor(10.) diseaseFractions tensor([0.0500, 0.0500, 0.0500]) pDs tensor([0.0098, 0.0098, 0.0020])
rrDist mean tensor([4.9876, 5.0122, 1.9882])
startIndices [0, tensor(1000.), tensor(2000.)] endIndices tensor([1000., 2000., 3000.])
totalSamples 511000
took 5.881845951080322


In [310]:
r7["altCounts"][3000:4000].mean(0)

tensor([49.7500,  0.4790,  0.5090,  0.1140])

In [307]:
start = time.time()
# genDataSequentialPooledCtrls(nCases=nCasesLarge, nCtrls=nCtrlsLarge, pDs=pDsGlobalLarge, diseaseFractions=diseaseFractions, rrShape=rrShape, rrMeans=rrMeansCovary, afMean=afMean, afShape=afShape, nGenes=nGenes)
rOld = genData4c(**genParams(rrMeans=tensor([5., 5., 2.]), pis=tensor([.05, .05, .05]))[0])
print("took", time.time() - start)

pDs are: tensor([0.0098, 0.0098, 0.0020])
TESTING WITH: nCases tensor([5000., 5000., 1000.]) nCtrls tensor(500000.) rrMeans tensor([5., 5., 2.]) rrShape tensor(10.) afMean tensor(1.0000e-04) afShape tensor(10.) diseaseFractions tensor([0.0500, 0.0500, 0.0500]) pDs tensor([0.0098, 0.0098, 0.0020])
rrDist mean tensor([4.9909, 4.9864, 2.0018])
startIndices [0, tensor(1000.), tensor(2000.)] endIndices tensor([1000., 2000., 3000.])
took 82.58038687705994


In [311]:
rOld[0][3000:4000].mean(0)


tensor([[49.9320,  0.5090],
        [ 0.0000,  0.4870],
        [ 0.0000,  0.0890]])

In [None]:
r

In [None]:
np.corrcoef(r["afs"][2000:3000, 1], r["afs"][2000:3000, 3])

In [None]:
print(flattenedData4)
pyplot.clf()
pyplot.figure(num=None, figsize=(20, 6), dpi=80, facecolor='w', edgecolor='k')
pyplot.plot(cachedData4b[-1][0]["afsPooled"][:, 2, 1])

In [None]:
print("corr for 1 & both in both", np.corrcoef(afsFlatPooled4[4000:5000, 2], afsFlatPooled4[4000:5000, 3]))
print("corr for 1 & both in 1only", np.corrcoef(afsFlatPooled4[0:2000, 1], afsFlatPooled4[0:2000, 3]))

In [None]:
print(afsFlatPooled2[0:5000, 3].mean())
afsFlatPooled2[5000:, 3].mean()

In [None]:
print("empirical rr for both", afsFlatPooled4[0:5000, 3].mean()/ afsFlatPooled4[5000:, 3].mean())
print("empirical rr for 1", ((afsFlatPooled4[0:2000, 1].mean()))/ afsFlatPooled4[2000:4000, 1].mean())
print("empirical rr for 2", ((afsFlatPooled4[2000:4000, 2].mean() + afsFlatPooled4[4000:5000, 2].mean())/2)/ afsFlatPooled4[:2000, 2].mean())
# print("nullLikelihoodGlobal 1", (nullLikelihoodsGlobal[0:2000, 0] + nullLikelihoodsGlobal[4000:5000, 0]).mean(), nullLikelihoodsGlobal[2000:, 0].mean())
# print("nullLikelihoodGlobal 2", nullLikelihoodsGlobal[2000:4000, 1].mean(), nullLikelihoodsGlobal[4000:, 1].mean())
# print("nullLikelihoodGlobal Both", nullLikelihoodsGlobal[4000:5000, 2].mean(), nullLikelihoodsGlobal[0:4000, 2].mean())

In [None]:
llPooledBivariateSingleGene(tensor([10.,1.,1.,20000.]), tensor([.01,.01,.05]), tensor(13.), tensor(10.), tensor(10.), tensor(100000.), tensor(.77), tensor(.1), tensor(.1), tensor(.01))

In [None]:
#this gives -2.401 log(likelihoodUnivariateSingleGene(xCtrl = 10, xCase1 = 1, prevalence1 = .01, pi0 = .9, pi1 = .1, pDiseaseGivenVariant = .001))
#tensor(-2.5290): llUnivariateSingleGeneJensen(xCtrl = tensor(10.), xCase = tensor(1.), pD = tensor(.01), pi0 = tensor(.9), pi1 = tensor(.1), pDgivenV = tensor(.001))
r = llUnivariateSingleGeneNoJensen(xCtrl = tensor(10.), xCase = tensor(1.), pD = tensor(.01), pi0 = tensor(.9), pi1 = tensor(.1), pDgivenV = tensor(.001))
assert(abs(-r + tensor(-2.4010)) < .1)

In [None]:
altCounts = tensor([10., 2., 3., 1.])
n = altCounts.sum()

testAlpha = tensor([16., 20., 30., 15.])
print(f"test data: testAlpha: {testAlpha}, n: {n}, altCounts: {altCounts}")
DirichletMultinomial(total_count=n, concentration=testAlpha).log_prob(altCounts)

In [None]:
pDgivenV(.01, afsByGene[0:2000, 0, 1], 1e-4).mean()

In [None]:
### Test functions
pDgivenV(.01, afsByGeneRR2[0:2000, 0, 1], afsByGeneRR2Shape5[0:2000, 0, 0]).mean()

In [None]:
# variance is wrong
def betaVariance(alpha, beta):
    return (alpha * beta) / ( ((alpha + beta)**2) + (alpha + beta + 1) )

def betaMean(alpha, beta):
    return alpha / (alpha + beta)

print("variance", betaVariance(6.47e1,5.39e3))
print("mean", betaMean(6.47e1,5.39e3))
print("true varianc", )

In [None]:
m1 = afsByGene[0:2000, 0, 1].mean()
m2 = afsByGeneRR2[0:2000, 0, 1].mean()/afsByGeneRR2[2000:, 0, 1].mean()
m1 - m2
print(m1, m2)

In [None]:
start = time.time()
res = fitFnUniveriate(altCountsByGene, pDs, nEpochs=20, minLLThresholdCount=20, debug=True)
print((time.time() - start) / 20, "per iteration")
res

In [None]:
fitFnUniveriate(altCountsByGeneRR2, pDs)

In [None]:
pDgivenV(pDs[0], afsByGeneRR2Shape5[0:2000, 0, 1], afsByGeneRR2Shape5[0:2000, 0, 0]).std()

In [None]:
pDgivenV?

In [None]:
fitFnUniveriate(altCounts, pDs, debug=True)

In [None]:
fitFnUniveriateBetaBinomial(altCountsByGeneRR2, pDs, debug=True)

In [None]:
fitFnUniveriateBetaBinomial(altCountsByGeneRR2Shape5, pDs, debug=True)

In [None]:
Beta(2.23307950e+02, 2.52700651e+04).sample([10_000,]).mean()

In [None]:
Beta(2.20865706e+02, 1.73544747e+04).sample([10_000,]).std()

In [None]:
fitFnUniveriateBetaBinomial(altCountsByGeneRR3, pDs, nEpochs=50)

In [None]:
Beta(2.50432693e+02, 1.87756988e+04).sample([10_000]).mean()

In [None]:
fitFnUniveriate(altCountsByGeneRR3, pDs, nEpochs=50)

In [None]:
Beta(3.84376856e+02, 2.37879954e+04).sample([10_000]).mean()

In [None]:
resultsRR2Shape5 = []
for i in range(1):
    res = fitFnUniveriateBetaBinomial(altCountsByGeneRR2Shape5, pDs, nEpochs=50, minLLThresholdCount=50, debug=False)
    resultsRR2Shape5.append(res)

In [None]:
resultsRR2Shape5

In [None]:
Beta(1.96912591e+02, 1.61461738e+04).sample([10000,]).mean()

In [None]:
fitFnUniveriateBetaBinomial(altCountsByGene, pDs, nEpochs=50, minLLThresholdCount=50, debug=False)

In [None]:
Beta(3.30289057e+03, 2.94460355e+04).sample([10000,]).mean()

In [None]:
# doesn't really work resConstrained = fitFnUniveriateBetaBinomialConstrained(altCountsByGeneRR2Shape5, pDs, nEpochs=10, minLLThresholdCount=10, debug=True)
#resConstrained

In [None]:
fitFnUniveriateBetaBinomial(altCountsByGene, pDs, nEpochs=10, minLLThresholdCount=10, debug=False)

In [None]:
cachedData

In [None]:
params = {"lls": [], "inferredAlphas": [], "inferredPis": [], "inferredPDVs": [], "trueMeanPDVs": [], "truePis": []}
cachedData = [[altCountsByGenePooledCtrls, afsByGenePooledCtrls, {
            "nCases": nCasesLarge,
            "nCtrls": nCtrlsLarge,
            "pDs": pDsGlobalLarge,
            "diseaseFractions": diseaseFractions,
            "rrShape": rrShape,
            "rrMeans": rrMeans,
            "afShape": afShape,
            "afMean": afMean
        }]]

In [None]:
params2 = {"lls": [], "inferredAlphas": [], "inferredPis": [], "inferredPDVs": [], "trueMeanPDVs": [], "truePis": [], "costFnIdx": []}
cachedData2 = [[altCountsByGenePooledCtrls2, afsByGenePooledCtrls2, {
            "nCases": nCasesLarge,
            "nCtrls": nCtrlsLarge,
            "pDs": pDsGlobalLarge,
            "diseaseFractions": diseaseFractions,
            "rrShape": rrShape,
            "rrMeans": rrMeans,
            "afShape": afShape,
            "afMean": afMean,
        }]]

In [None]:
for i in range(10):
    if i >= len(cachedData):
        start = time.time()
        altCountsByGenePooledCtrls, afsByGenePooledCtrls = genDataSequentialPooledCtrls(nCases=nCasesLarge, nCtrls=nCtrlsLarge, pDs=pDsGlobalLarge, diseaseFractions=diseaseFractions, rrShape=rrShape, rrMeans=rrMeans, afMean=afMean, afShape=afShape)
        print("took", time.time() - start)
        cachedData.append([altCountsByGenePooledCtrls, afsByGenePooledCtrls, {
            "nCases": nCasesLarge,
            "nCtrls": nCtrlsLarge,
            "pDs": pDsGlobalLarge,
            "diseaseFractions": diseaseFractions,
            "rrShape": rrShape,
            "rrMeans": rrMeans,
            "afShape": afShape,
            "afMean": afMean
        }])
        
    res = fitFnBivariate(cachedData[i][0], pDsGlobalLarge, nEpochs=20, minLLThresholdCount=20, debug=True, costFnIdx=0)
    bestRes = res["params"][-1]

    inferredPis = tensor(bestRes[0:3])
    print("inferredPis", inferredPis)
    inferredAlphas = tensor(bestRes[3:])
    print("inferredAlphas", inferredAlphas)
    
    inferredPDs = Dirichlet(concentration=inferredAlphas).sample([10_000,]).mean(0)

    params["lls"].append(res["lls"][-1])
    params["inferredAlphas"].append(inferredAlphas)
    params["inferredPis"].append(inferredPis)
    params["inferredPDVs"].append(inferredPDs)

    truth1 = pDgivenV(pDsGlobalLarge[0], afsByGenePooledCtrls[0:2000, 0, 1], afsByGenePooledCtrls[0:2000, 0, 0]).mean()
    truth2 = pDgivenV(pDsGlobalLarge[1], afsByGenePooledCtrls[2000:4000, 1, 1], afsByGenePooledCtrls[2000:4000, 0, 0]).mean()
    truth3 = pDgivenV(pDsGlobalLarge[2], afsByGenePooledCtrls[4000:5000, 2, 1], afsByGenePooledCtrls[4000:5000, 0, 0]).mean()
    truth0 = 1 - (truth1 + truth2 + truth3)
    print("truth0", truth0, "truth1", truth1, "truth2", truth2, "truthBoth", truth3)

    params["trueMeanPDVs"].append(tensor([truth0, truth1, truth2, truth3]))
    params["truePis"].append(tensor(diseaseFractions))

    print(f"params on run {i}", params)

In [None]:
for i in range(10):
    if i >= len(cachedData2):
        start = time.time()
        altCountsByGenePooledCtrls2, afsByGenePooledCtrls2 = genDataSequentialPooledCtrls(nCases=nCasesLarge, nCtrls=nCtrlsLarge, pDs=pDsGlobalLarge, diseaseFractions=diseaseFractions, rrShape=rrShape, rrMeans=rrMeans, afMean=afMean, afShape=afShape)
        print("took", time.time() - start)
        cachedData2.append([altCountsByGenePooledCtrls2, afsByGenePooledCtrls2, {
            "nCases": nCasesLarge,
            "nCtrls": nCtrlsLarge,
            "pDs": pDsGlobalLarge,
            "diseaseFractions": diseaseFractions,
            "rrShape": rrShape,
            "rrMeans": rrMeans,
            "afShape": afShape,
            "afMean": afMean,
        }])
    runCostFnIdx = 6
    # todo append all entries to indciate failure
    params2["costFnIdx"].append(runCostFnIdx)

    res = fitFnBivariate(cachedData2[i][0], pDsGlobalLarge, nEpochs=20, minLLThresholdCount=20, debug=True, costFnIdx=runCostFnIdx)
    
    
    bestRes = res["params"][-1]

    inferredPis = tensor(bestRes[0:3])
    print("inferredPis", inferredPis)
    inferredAlphas = tensor(bestRes[3:])
    print("inferredAlphas", inferredAlphas)
    
    inferredPDs = Dirichlet(concentration=inferredAlphas).sample([10_000,]).mean(0)

    params2["lls"].append(res["lls"][-1])
    params2["inferredAlphas"].append(inferredAlphas)
    params2["inferredPis"].append(inferredPis)
    params2["inferredPDVs"].append(inferredPDs)

    truth1 = pDgivenV(pDsGlobalLarge[0], cachedData2[i][1][0:2000, 0, 1], cachedData2[i][1][0:2000, 0, 0]).mean()
    truth2 = pDgivenV(pDsGlobalLarge[1], cachedData2[i][1][2000:4000, 1, 1], cachedData2[i][1][2000:4000, 0, 0]).mean()
    truth3 = pDgivenV(pDsGlobalLarge[2], cachedData2[i][1][4000:5000, 2, 1], cachedData2[i][1][4000:5000, 0, 0]).mean()
    truth0 = 1 - (truth1 + truth2 + truth3)
    print("truth0", truth0, "truth1", truth1, "truth2", truth2, "truthBoth", truth3)

    params2["trueMeanPDVs"].append(tensor([truth0, truth1, truth2, truth3]))
    params2["truePis"].append(tensor(diseaseFractions))
    
    print(f"params on run {i}", params2)

In [None]:
params3 = {"lls": [], "inferredAlphas": [], "inferredPis": [], "inferredPDVs": [], "trueMeanPDVs": [], "truePis": [], "costFnIdx": []}
cachedData3 = [[altCountsByGenePooledCtrls3, afsByGenePooledCtrls3, {
            "nCases": nCasesLarge,
            "nCtrls": nCtrlsLarge,
            "pDs": pDsGlobalLarge,
            "diseaseFractions": diseaseFractions,
            "rrShape": rrShape,
            "rrMeans": rrMeans,
            "afShape": afShape,
            "afMean": afMean,
        }]]

In [None]:
params4 = {"lls": [], "inferredAlphas": [], "inferredPis": [], "inferredPDVs": [], "truePDVs": [], "truePis": [], "costFnIdx": []}
# altCountsByGenePooledCtrls4, afsByGenePooledCtrls4, affectedGenes4 = genData4(nCases=nCasesLarge, nCtrls=nCtrlsLarge, pDs=pDsGlobalLarge, diseaseFractions=diseaseFractions, rrShape=rrShape, rrMeans=rrMeansCovary, afMean=afMean, afShape=afShape, nGenes=nGenes)


In [None]:
def genParams(pis = tensor([.1, .1, .05]), rrShape = tensor(10.), rrMeans = tensor([3., 3., 1.5]), afShape = tensor(10.), afMean = tensor(1e-4)):
    nGenes = 20_000

    nCtrls = tensor(5e5)
    nCasesAffectedByOne = tensor([5e3, 5e3])
    nBothCases = nCasesAffectedByOne.sum() * .1
    nCases = tensor([*nCasesAffectedByOne, nBothCases])

    pDs = nCases / ( nCases.sum() + nCtrls )
    print("pDs are:", pDs)
    
    return [{
        "nGenes": nGenes,
        "nCases": nCases,
        "nCtrls": nCtrls,
        "pDs": pDs,
        "diseaseFractions": pis,
        "rrShape": rrShape,
        "rrMeans": rrMeans,
        "afShape": afShape,
        "afMean": afMean,
    }]

In [None]:
cachedData4 = []

In [None]:
for i in range(0, 5):
    if i >= len(cachedData4):
        
        params = genParams()[0]
        start = time.time()
        xsPooledRun, afsPooledRun, affectedGenesRun, unaffectedGenesRun = genData4(**params)

        print("took", time.time() - start)
        cachedData4.append({
            "xsPooled": xsPooledRun,
            "afsPooled": afsPooledRun,
            "affectedGenes": affectedGenesRun,
            "unaffectedGenes": unaffectedGenesRun,
            "params": params,
        })
    cd = cachedData4[i]
    xsPooledRun = cd["xsPooled"]
    afsPooledRun = cd["afsPooled"]
    affectedGenesRun = cd["affectedGenes"]
    unaffectedGenesRun = cd["unaffectedGenes"]

    pDsRun = cd["params"]["pDs"]
    pisRun = cd["params"]["diseaseFractions"]

    runCostFnIdx = 15
    res = fitFnBivariate(xsPooledRun, pDsRun, nEpochs=3, minLLThresholdCount=20, debug=True, costFnIdx=runCostFnIdx)
    
    bestRes = res["params"][-1]

    inferredPis = tensor(bestRes[0:3]) # 3-vector
    inferredAlphas = tensor(bestRes[3:]) # 4-vector, idx0 is P(!D|V)

    # index 0 is the P(!D|V), aka probability of being a control given variant present
    inferredPDs = Dirichlet(concentration=inferredAlphas).sample([10_000,]).mean(0)
    
    inferredPiProp1 = inferredPis[0] / (inferredPis[0] + inferredPis[2])
    inferredPiProp2 = inferredPis[1] / (inferredPis[1] + inferredPis[2])
    
    PDctrlV = inferredPDs[0]
    PD1V = inferredPDs[1]
    PD2V = inferredPDs[2]
    PDBothV = inferredPDs[1]
    
    inferredC1PDgivenV = inferredPiProp1 * PD1V + (1 - inferredPiProp1) * (PD1V + PDBothV * pDsRun[0]/pDsRun[2])
    inferredC2PDgivenV = inferredPiProp2 * PD2V + (1 - inferredPiProp2) * (PD2V + PDBothV * pDsRun[1]/pDsRun[2])
    inferredCBothPDgivenV = (inferredPis[0] / inferredPis.sum()) * PD1V * pDsRun[2]/pDsRun[0] + (inferredPis[1] / inferredPis.sum()) * PD2V * pDsRun[2]/pDsRun[1] + (inferredPis[2] / inferredPis.sum()) * PDBothV
    
    inferredPDGivenVs = tensor([inferredC1PDgivenV, inferredC2PDgivenV, inferredCBothPDgivenV])
    
    piProp1 = pisRun[0] / (pisRun[0] + pisRun[2])
    piProp2 = pisRun[1] / (pisRun[1] + pisRun[2])
    
    PV_hat = afsPooledRun[unaffectedGenesRun].mean()
    
    trueC1PVgivenD = piProp1 * afsPooledRun[affectedGenesRun[0], 0, 1].mean() + (1-piProp1) * afsPooledRun[affectedGenesRun[2], 0, 1].mean()
    trueC2PVgivenD = piProp2 * afsPooledRun[affectedGenesRun[1], 1, 1].mean() + (1-piProp2) * afsPooledRun[affectedGenesRun[2], 1, 1].mean()
    trueCBothPVgivenD = (pisRun[0] / pisRun.sum()) * afsPooledRun[affectedGenesRun[0], 2, 1].mean() + (pisRun[1] / pisRun.sum()) * afsPooledRun[affectedGenesRun[1], 2, 1].mean() + (pisRun[2] / pisRun.sum()) * afsPooledRun[affectedGenesRun[2], 2, 1].mean()
    
    trueC1PDgivenV = pDgivenV(pD=pDsRun[0], pVgivenD=trueC1PVgivenD, pV=PV_hat)
    trueC2PDgivenV = pDgivenV(pD=pDsRun[1], pVgivenD=trueC2PVgivenD, pV=PV_hat)
    trueCBothPDgivenV = pDgivenV(pD=pDsRun[2], pVgivenD=trueCBothPVgivenD, pV=PV_hat)
    
    
    truePDGivenVs = tensor([trueC1PDgivenV, trueC2PDgivenV, trueCBothPDgivenV])
    print("\ninferredPis", inferredPis)
    print("\ninferredPDVs raw", inferredPDs)
    print("\ninferredPDVs scaled", inferredPDGivenVs)
    print("\ntruePDGivenVs", truePDGivenVs)
    
    params4["lls"].append(res["lls"][-1])
    params4["inferredAlphas"].append(inferredAlphas)
    params4["inferredPis"].append(inferredPis)
    params4["inferredPDVs"].append(inferredPDGivenVs)
    params4["truePDVs"].append(truePDGivenVs)
    params4["truePis"].append(pisRun)
    
    params4["costFnIdx"].append(runCostFnIdx)

    print(f"\n\nparams on run {i}\n", params4)

In [None]:
piProp1 = pisRun[0] / (pisRun[0] + pisRun[2])
piProp2 = pisRun[1] / (pisRun[1] + pisRun[2])

PV_hat = afsPooledRun[unaffectedGenesRun].mean()

afs1 = afsPooledRun[affectedGenesRun[0], 0, 1].mean()
piProp1 * afsPooledRun[affectedGenesRun[0], 0, 1].mean() + (1-piProp1) * ( afsPooledRun[affectedGenesRun[2], 0, 1].mean() )

In [None]:
inferredPiProp1 = inferredPis[0] / inferredPis.sum()
inferredPiProp2 = inferredPis[1] / inferredPis.sum()
inferredPiProp3 = inferredPis[2] / inferredPis.sum()
inferredPis

In [None]:
inferredPiProp1 = inferredPis[0] / (inferredPis[0] + inferredPis[2])
PDctrlV = inferredPDs[0]
PD1V = inferredPDs[1]
PD2V = inferredPDs[2]
PDBothV = inferredPDs[1]

inferredC1PDgivenV = inferredPiProp1 * PD1V + (1 - inferredPiProp1) * (PD1V + PDBothV * pDsRun[0]/pDsRun[2])
inferredC2PDgivenV = inferredPiProp2 * PD2V + (1 - inferredPiProp2) * (PD2V + PDBothV * pDsRun[1]/pDsRun[2])
inferredCBothPDgivenV = inferredPiProp1 * PD1V * pDsRun[2]/pDsRun[0] + inferredPiProp2 * PD2V * pDsRun[2]/pDsRun[1] + inferredPiProp3 * PDBothV

inferredC1PDgivenV


In [None]:
piProp2 * afsPooledRun[affectedGenesRun[1], 1, 1].mean() + (1-piProp2) * afsPooledRun[affectedGenesRun[2], 1, 1].mean()

In [None]:
afsPooledRun[affectedGenesRun[2], 0, 1].mean()

In [None]:
afsPooledRun[1200]

In [None]:
piProp1 * afsPooledRun[affectedGenesRun[0], 0, 1].mean() + (1-piProp1) * afsPooledRun[affectedGenesRun[2], 0, 1].mean()

In [None]:
inferredPDs = Dirichlet(concentration=inferredAlphas).sample([10_000,]).mean(0)

inferredPiProp1 = inferredPis[0] / inferredPis.sum()
inferredPiProp2 = inferredPis[1] / inferredPis.sum()

inferredC1PDgivenV = inferredPiProp1 * inferredPDs[1] + (1 - inferredPiProp1) * (inferredPDs[1] + inferredPDs[3] * pDsRun[0]/pDsRun[2])
inferredC2PDgivenV = inferredPiProp2 * inferredPDs[2] + (1 - inferredPiProp2) * (inferredPDs[2] + inferredPDs[3] * pDsRun[1]/pDsRun[2])
inferredCBothPDgivenV = inferredPiProp1 * inferredPDs[1] * pDsRun[2]/pDsRun[0] + inferredPiProp2 * inferredPDs[2] * pDsRun[2]/pDsRun[1] + (1 - inferredPiProp1 - inferredPiProp2) * inferredPDs[3]
inferredCBothPDgivenV


In [None]:
params4b = {"lls": [], "inferredAlphas": [], "inferredPis": [], "inferredPDVs": [], "truePDVs": [], "truePis": [], "costFnIdx": []}
# altCountsByGenePooledCtrls4, afsByGenePooledCtrls4, affectedGenes4 = genData4(nCases=nCasesLarge, nCtrls=nCtrlsLarge, pDs=pDsGlobalLarge, diseaseFractions=diseaseFractions, rrShape=rrShape, rrMeans=rrMeansCovary, afMean=afMean, afShape=afShape, nGenes=nGenes)


In [None]:
cachedData4b = []

In [None]:
for i in range(2):
    if i >= len(cachedData4b):
        # In DSB:
        # 	No ID	ID	
        #         ASD+ADHD	684	217	
        #         ASD	3091	871	
        #         ADHD	3206	271	
        #         Control	5002	-	

        #         gnomAD	44779	(Non-Finnish Europeans in non-psychiatric exome subset)	

        #         Case total:	8340		
        #         Control total:	49781		
        # so we can use pDBoth = .1 * total_cases
        params = genParams(rrMeans=tensor([5., 5., 2.]), pis=tensor([.05, .05, .05]))[0]
        start = time.time()
        xsPooledRun, afsPooledRun, affectedGenesRun, unaffectedGenesRun = genData4(**params)

        print("took", time.time() - start)
        cachedData4b.append({
            "xsPooled": xsPooledRun,
            "afsPooled": afsPooledRun,
            "affectedGenes": affectedGenesRun,
            "unaffectedGenes": unaffectedGenesRun,
            "params": params,
        })
    xsPooledRun = cachedData4b[i]["xsPooled"]
    afsPooledRun = cachedData4b[i]["afsPooled"]
    affectedGenesRun = cachedData4b[i]["affectedGenes"]
    unaffectedGenesRun = cachedData4b[i]["unaffectedGenes"]
    pDsRun = cachedData4b[i]["params"]["pDs"]
    pisRun = cachedData4b[i]["params"]["diseaseFractions"]

    print("i is", i)
    print("params are:", cachedData4b[i]["params"])
    runCostFnIdx = 15
    res = fitFnBivariate(xsPooledRun, pDsRun, nEpochs=10, minLLThresholdCount=20, debug=True, costFnIdx=runCostFnIdx)
    
    bestRes = res["params"][-1]

    inferredPis = tensor(bestRes[0:3]) # 3-vector
    inferredAlphas = tensor(bestRes[3:]) # 4-vector, idx0 is P(!D|V)

    # index 0 is the P(!D|V), aka probability of being a control given variant present
    inferredPDs = Dirichlet(concentration=inferredAlphas).sample([10_000,]).mean(0)
    
    inferredPiProp1 = inferredPis[0] / inferredPis.sum()
    inferredPiProp2 = inferredPis[1] / inferredPis.sum()
    
    inferredC1PDgivenV = inferredPiProp1 * inferredPDs[1] + (1 - inferredPiProp1) * (inferredPDs[1] + inferredPDs[3] * pDsRun[0]/pDsRun[2])
    inferredC2PDgivenV = inferredPiProp2 * inferredPDs[2] + (1 - inferredPiProp2) * (inferredPDs[2] + inferredPDs[3] * pDsRun[1]/pDsRun[2])
    inferredCBothPDgivenV = (inferredPis[0] / inferredPis.sum()) * inferredPDs[1] * pDsRun[2]/pDsRun[0] + (inferredPis[1] / inferredPis.sum()) * inferredPDs[2] * pDsRun[2]/pDsRun[1] + (inferredPis[2] / inferredPis.sum()) * inferredPDs[3]

    piProp1 = pisRun[0] / (pisRun[0] + pisRun[2])
    piProp2 = pisRun[1] / (pisRun[1] + pisRun[2])
    
    PV_hat = afsPooledRun[unaffectedGenesRun].mean()
    
    trueC1PVgivenD = piProp1 * afsPooledRun[affectedGenesRun[0], 0, 1].mean() + (1-piProp1) * afsPooledRun[affectedGenesRun[2], 0, 1].mean()
    trueC2PVgivenD = piProp2 * afsPooledRun[affectedGenesRun[1], 1, 1].mean() + (1-piProp2) * afsPooledRun[affectedGenesRun[2], 1, 1].mean()
    trueCBothPVgivenD = (pisRun[0] / pisRun.sum()) * afsPooledRun[affectedGenesRun[0], 2, 1].mean() + (pisRun[1] / pisRun.sum()) * afsPooledRun[affectedGenesRun[1], 2, 1].mean() + (pisRun[2] / pisRun.sum()) * afsPooledRun[affectedGenesRun[2], 2, 1].mean()
    
    trueC1PDgivenV = pDgivenV(pD=pDsRun[0], pVgivenD=trueC1PVgivenD, pV=PV_hat)
    trueC2PDgivenV = pDgivenV(pD=pDsRun[1], pVgivenD=trueC2PVgivenD, pV=PV_hat)
    trueCBothPDgivenV = pDgivenV(pD=pDsRun[2], pVgivenD=trueCBothPVgivenD, pV=PV_hat)
    
    inferredPDGivenVs = tensor([inferredC1PDgivenV, inferredC2PDgivenV, inferredCBothPDgivenV])
    truePDGivenVs = tensor([trueC1PDgivenV, trueC2PDgivenV, trueCBothPDgivenV])
    print("\ninferredPis", inferredPis)
    print("\ninferredPDVs", inferredPDGivenVs)
    print("\ntruePDGivenVs", truePDGivenVs)
    
    params4b["lls"].append(res["lls"][-1])
    params4b["inferredAlphas"].append(inferredAlphas)
    params4b["inferredPis"].append(inferredPis)
    params4b["inferredPDVs"].append(inferredPDGivenVs)
    params4b["truePDVs"].append(truePDGivenVs)
    params4b["truePis"].append(pisRun)
    
    params4b["costFnIdx"].append(runCostFnIdx)

    print(f"\n\nparams on run {i}\n", params4b)

In [201]:
params4c = {"lls": [], "inferredAlphas": [], "inferredPis": [], "inferredPDVs": [], "truePDVs": [], "truePis": [], "costFnIdx": []}
# altCountsByGenePooledCtrls4, afsByGenePooledCtrls4, affectedGenes4 = genData4(nCases=nCasesLarge, nCtrls=nCtrlsLarge, pDs=pDsGlobalLarge, diseaseFractions=diseaseFractions, rrShape=rrShape, rrMeans=rrMeansCovary, afMean=afMean, afShape=afShape, nGenes=nGenes)
cachedData4c = []

In [206]:
pafor i in range(2):
    if i >= len(cachedData4c):
        # In DSB:
        # 	No ID	ID	
        #         ASD+ADHD	684	217	
        #         ASD	3091	871	
        #         ADHD	3206	271	
        #         Control	5002	-	

        #         gnomAD	44779	(Non-Finnish Europeans in non-psychiatric exome subset)	

        #         Case total:	8340		
        #         Control total:	49781		
        # so we can use pDBoth = .1 * total_cases
        params = genParams(rrMeans=tensor([5., 5., 2.]), pis=tensor([.05, .05, .05]))[0]
        start = time.time()
        xsPooledRun, afsPooledRun, affectedGenesRun, unaffectedGenesRun = genData4c(**params)

        print("took", time.time() - start)
        cachedData4c.append({
            "xsPooled": xsPooledRun,
            "afsPooled": afsPooledRun,
            "affectedGenes": affectedGenesRun,
            "unaffectedGenes": unaffectedGenesRun,
            "params": params,
        })
    xsPooledRun = cachedData4c[i]["xsPooled"]
    afsPooledRun = cachedData4c[i]["afsPooled"]
    affectedGenesRun = cachedData4c[i]["affectedGenes"]
    unaffectedGenesRun = cachedData4c[i]["unaffectedGenes"]
    pDsRun = cachedData4c[i]["params"]["pDs"]
    pisRun = cachedData4c[i]["params"]["diseaseFractions"]

    print("i is", i)
    print("params are:", cachedData4c[i]["params"])
    runCostFnIdx = 15
    res = fitFnBivariate(xsPooledRun, pDsRun, nEpochs=10, minLLThresholdCount=20, debug=True, costFnIdx=runCostFnIdx)pa
    
    bestRes = res["params"][-1]

    inferredPis = tensor(bestRes[0:3]) # 3-vector
    inferredAlphas = tensor(bestRes[3:]) # 4-vector, idx0 is P(!D|V)

    # index 0 is the P(!D|V), aka probability of being a control given variant present
    inferredPDs = Dirichlet(concentration=inferredAlphas).sample([10_000,]).mean(0)
    
    inferredPiProp1 = inferredPis[0] / inferredPis.sum()
    inferredPiProp2 = inferredPis[1] / inferredPis.sum()
    
    inferredC1PDgivenV = inferredPiProp1 * inferredPDs[1] + (1 - inferredPiProp1) * (inferredPDs[1] + inferredPDs[3] * pDsRun[0]/pDsRun[2])
    inferredC2PDgivenV = inferredPiProp2 * inferredPDs[2] + (1 - inferredPiProp2) * (inferredPDs[2] + inferredPDs[3] * pDsRun[1]/pDsRun[2])
    inferredCBothPDgivenV = (inferredPis[0] / inferredPis.sum()) * inferredPDs[1] * pDsRun[2]/pDsRun[0] + (inferredPis[1] / inferredPis.sum()) * inferredPDs[2] * pDsRun[2]/pDsRun[1] + (inferredPis[2] / inferredPis.sum()) * inferredPDs[3]

    piProp1 = pisRun[0] / (pisRun[0] + pisRun[2])
    piProp2 = pisRun[1] / (pisRun[1] + pisRun[2])
    
    PV_hat = afsPooledRun[unaffectedGenesRun].mean()
    
    trueC1PVgivenD = piProp1 * afsPooledRun[affectedGenesRun[0], 0, 1].mean() + (1-piProp1) * afsPooledRun[affectedGenesRun[2], 0, 1].mean()
    trueC2PVgivenD = piProp2 * afsPooledRun[affectedGenesRun[1], 1, 1].mean() + (1-piProp2) * afsPooledRun[affectedGenesRun[2], 1, 1].mean()
    trueCBothPVgivenD = (pisRun[0] / pisRun.sum()) * afsPooledRun[affectedGenesRun[0], 2, 1].mean() + (pisRun[1] / pisRun.sum()) * afsPooledRun[affectedGenesRun[1], 2, 1].mean() + (pisRun[2] / pisRun.sum()) * afsPooledRun[affectedGenesRun[2], 2, 1].mean()
    
    trueC1PDgivenV = pDgivenV(pD=pDsRun[0], pVgivenD=trueC1PVgivenD, pV=PV_hat)
    trueC2PDgivenV = pDgivenV(pD=pDsRun[1], pVgivenD=trueC2PVgivenD, pV=PV_hat)
    trueCBothPDgivenV = pDgivenV(pD=pDsRun[2], pVgivenD=trueCBothPVgivenD, pV=PV_hat)
    
    inferredPDGivenVs = tensor([inferredC1PDgivenV, inferredC2PDgivenV, inferredCBothPDgivenV])
    truePDGivenVs = tensor([trueC1PDgivenV, trueC2PDgivenV, trueCBothPDgivenV])
    print("\ninferredPis", inferredPis)
    print("\ninferredPDVs", inferredPDGivenVs)
    print("\ntruePDGivenVs", truePDGivenVs)
    
    params4c["lls"].append(res["lls"][-1])
    params4c["inferredAlphas"].append(inferredAlphas)
    params4c["inferredPis"].append(inferredPis)
    params4c["inferredPDVs"].append(inferredPDGivenVs)
    params4c["truePDVs"].append(truePDGivenVs)
    params4c["truePis"].append(pisRun)
    
    params4c["costFnIdx"].append(runCostFnIdx)

    print(f"\n\nparams on run {i}\n", params4c)

pDs are: tensor([0.0098, 0.0098, 0.0020])
TESTING WITH: nCases tensor([5000., 5000., 1000.]) nCtrls tensor(500000.) rrMeans tensor([5., 5., 2.]) rrShape tensor(10.) afMean tensor(1.0000e-04) afShape tensor(10.) diseaseFractions tensor([0.0500, 0.0500, 0.0500]) pDs tensor([0.0098, 0.0098, 0.0020])
rrDist mean tensor([4.9935, 4.9712, 1.9986])
startIndices [0, tensor(1000.), tensor(2000.)] endIndices tensor([1000., 2000., 3000.])
took 83.56188488006592
i is 0
params are: {'nGenes': 20000, 'nCases': tensor([5000., 5000., 1000.]), 'nCtrls': tensor(500000.), 'pDs': tensor([0.0098, 0.0098, 0.0020]), 'diseaseFractions': tensor([0.0500, 0.0500, 0.0500]), 'rrShape': tensor(10.), 'rrMeans': tensor([5., 5., 2.]), 'afShape': tensor(10.), 'afMean': tensor(1.0000e-04)}
shape torch.Size([20000, 3, 2])
altCountsFlat tensor([[51.,  1.,  2.,  0.],
        [17.,  2.,  0.,  0.],
        [28.,  1.,  0.,  0.],
        ...,
        [34.,  2.,  0.,  0.],
        [80.,  0.,  2.,  0.],
        [48.,  0.,  1.,  2

In [None]:
params5 = {"lls": [], "inferredAlphas": [], "inferredPis": [], "inferredPDVs": [], "truePDVs": [], "costFnIdx": []}
# altCountsByGenePooledCtrls4, afsByGenePooledCtrls4, affectedGenes4 = genData4(nCases=nCasesLarge, nCtrls=nCtrlsLarge, pDs=pDsGlobalLarge, diseaseFractions=diseaseFractions, rrShape=rrShape, rrMeans=rrMeansCovary, afMean=afMean, afShape=afShape, nGenes=nGenes)

cachedData5 = []

In [None]:
for i in range(2):
    if i >= -1:#len(cachedData6):
        # In DSB:
        # 	No ID	ID	
        #         ASD+ADHD	684	217	
        #         ASD	3091	871	
        #         ADHD	3206	271	
        #         Control	5002	-	

        #         gnomAD	44779	(Non-Finnish Europeans in non-psychiatric exome subset)	

        #         Case total:	8340		
        #         Control total:	49781		
        # so we can use pDBoth = .1 * total_cases
        params = genParams(rrMeans=tensor([5, 5, 2]), pis=tensor([.05, .05, .05]))
        start = time.time()
        r = genData5(**params[0])

        print("took", time.time() - start)
        cachedData5.append({**r, "params": params[0]})
    print("params are:", cachedData5[i]["params"], "pis are", pisRun)
    xsRun = cachedData5[i]["altCounts"]
    afsRun = cachedData5[i]["afs"]
    affectedGenesRun = cachedData5[i]["affectedGenes"]
    unaffectedGenesRun = cachedData5[i]["unaffectedGenes"]
    pDsRun = cachedData5[i]["params"]["pDs"]
    pisRun = cachedData5[i]["params"]["diseaseFractions"]

    print("i is", i)
    print("pis are", pisRun)
    runCostFnIdx = 15
    res = fitFnBivariate(xsPooledRun, pDsRun, nEpochs=10, minLLThresholdCount=20, debug=True, costFnIdx=runCostFnIdx)
    
    bestRes = res["params"][-1]

    inferredPis = tensor(bestRes[0:3]) # 3-vector
    inferredAlphas = tensor(bestRes[3:]) # 4-vector, idx0 is P(!D|V)

    # index 0 is the P(!D|V), aka probability of being a control given variant present
    inferredPDs = Dirichlet(concentration=inferredAlphas).sample([10_000,]).mean(0)
    
    inferredPiProp1 = inferredPis[0] / inferredPis.sum()
    inferredPiProp2 = inferredPis[1] / inferredPis.sum()
    
    inferredC1PDgivenV = inferredPiProp1 * inferredPDs[1] + (1 - inferredPiProp1) * (inferredPDs[1] + inferredPDs[3] * pDsRun[0]/pDsRun[2])
    inferredC2PDgivenV = inferredPiProp2 * inferredPDs[2] + (1 - inferredPiProp2) * (inferredPDs[2] + inferredPDs[3] * pDsRun[1]/pDsRun[2])
    inferredCBothPDgivenV = (inferredPis[0] / inferredPis.sum()) * inferredPDs[1] * pDsRun[2]/pDsRun[0] + (inferredPis[1] / inferredPis.sum()) * inferredPDs[2] * pDsRun[2]/pDsRun[1] + (inferredPis[2] / inferredPis.sum()) * inferredPDs[3]

    
    piProp1 = pisRun[0] / (pisRun[0] + pisRun[2])
    piProp2 = pisRun[1] / (pisRun[1] + pisRun[2])
    
    PV_hat = afsPooledRun[unaffectedGenesRun].mean()
    
    trueC1PVgivenD = piProp1 * afsPooledRun[affectedGenesRun[0], 0, 1].mean() + (1-piProp1) * afsPooledRun[affectedGenesRun[2], 0, 1].mean()
    trueC2PVgivenD = piProp2 * afsPooledRun[affectedGenesRun[1], 1, 1].mean() + (1-piProp2) * afsPooledRun[affectedGenesRun[2], 1, 1].mean()
    trueCBothPVgivenD = (pisRun[0] / pisRun.sum()) * afsPooledRun[affectedGenesRun[0], 2, 1].mean() + (pisRun[1] / pisRun.sum()) * afsPooledRun[affectedGenesRun[1], 2, 1].mean() + (pisRun[2] / pisRun.sum()) * afsPooledRun[affectedGenesRun[2], 2, 1].mean()
    
    trueC1PDgivenV = pDgivenV(pD=pDsRun[0], pVgivenD=trueC1PVgivenD, pV=PV_hat)
    trueC2PDgivenV = pDgivenV(pD=pDsRun[1], pVgivenD=trueC2PVgivenD, pV=PV_hat)
    trueCBothPDgivenV = pDgivenV(pD=pDsRun[2], pVgivenD=trueCBothPVgivenD, pV=PV_hat)
    
    inferredPDGivenVs = tensor([inferredC1PDgivenV, inferredC2PDgivenV, inferredCBothPDgivenV])
    truePDGivenVs = tensor([trueC1PDgivenV, trueC2PDgivenV, trueCBothPDgivenV])
    print("\ninferredPis", inferredPis)
    print("\ninferredPDVs", inferredPDGivenVs)
    print("\ntruePDGivenVs", truePDGivenVs)
    
    params5["lls"].append(res["lls"][-1])
    params5["inferredAlphas"].append(inferredAlphas)
    params5["inferredPis"].append(inferredPis)
    params5["inferredPDVs"].append(inferredPDGivenVs)
    params5["truePDVs"].append(truePDGivenVs)
    params5["params"].append(cachedData6[i]["params"])
    
    params5["costFnIdx"].append(runCostFnIdx)

    print(f"\n\nparams on run {i}\n", params6)

In [322]:
params7 = {"lls": [], "inferredAlphas": [], "inferredPis": [], "inferredPDVs": [], "truePDVs": [], "params": [], "costFnIdx": []}
# altCountsByGenePooledCtrls4, afsByGenePooledCtrls4, affectedGenes4 = genData4(nCases=nCasesLarge, nCtrls=nCtrlsLarge, pDs=pDsGlobalLarge, diseaseFractions=diseaseFractions, rrShape=rrShape, rrMeans=rrMeansCovary, afMean=afMean, afShape=afShape, nGenes=nGenes)

cachedData7 = []

In [324]:
for i in range(0, 50):
    if i >= len(cachedData7):
        # In DSB:
        # 	No ID	ID	
        #         ASD+ADHD	684	217	
        #         ASD	3091	871	
        #         ADHD	3206	271	
        #         Control	5002	-	

        #         gnomAD	44779	(Non-Finnish Europeans in non-psychiatric exome subset)	

        #         Case total:	8340		
        #         Control total:	49781		
        # so we can use pDBoth = .1 * total_cases
        params = genParams(rrMeans=tensor([5., 5., 2]), pis=tensor([.1, .1, .05]))
        start = time.time()
        r = genData7(**params[0])

        print("took", time.time() - start)
        cachedData7.append({**r, "params": params[0]})
    cd = cachedData7[i]
    print(f"I is {i}")
    print("params are:", cd["params"], "pis are", pisRun)
    xsRun = cd["altCounts"]
    afsRun = cd["afs"]
    affectedGenesRun = cd["affectedGenes"]
    unaffectedGenesRun = cd["unaffectedGenes"]
    pDsRun = cd["params"]["pDs"]
    pisRun = cd["params"]["diseaseFractions"]

    print("i is", i)
    print("pis are", pisRun)
    runCostFnIdx = 16
    res = fitFnBivariate(xsRun, pDsRun, nEpochs=10, minLLThresholdCount=20, debug=True, costFnIdx=runCostFnIdx)
    
    bestRes = res["params"][-1]

    inferredPis = tensor(bestRes[0:3]) # 3-vector
    inferredAlphas = tensor(bestRes[3:]) # 4-vector, idx0 is P(!D|V)

    # index 0 is the P(!D|V), aka probability of being a control given variant present
    inferredPDs = Dirichlet(concentration=inferredAlphas).sample([10_000,]).mean(0)
    
    inferredPiProp1 = inferredPis[0] / inferredPis.sum()
    inferredPiProp2 = inferredPis[1] / inferredPis.sum()
    
    inferredC1PDgivenV = inferredPiProp1 * inferredPDs[1] + (1 - inferredPiProp1) * (inferredPDs[1] + inferredPDs[3] * pDsRun[0]/pDsRun[2])
    inferredC2PDgivenV = inferredPiProp2 * inferredPDs[2] + (1 - inferredPiProp2) * (inferredPDs[2] + inferredPDs[3] * pDsRun[1]/pDsRun[2])
    inferredCBothPDgivenV = (inferredPis[0] / inferredPis.sum()) * inferredPDs[1] * pDsRun[2]/pDsRun[0] + (inferredPis[1] / inferredPis.sum()) * inferredPDs[2] * pDsRun[2]/pDsRun[1] + (inferredPis[2] / inferredPis.sum()) * inferredPDs[3]
    
    piProp1 = pisRun[0] / (pisRun[0] + pisRun[2])
    piProp2 = pisRun[1] / (pisRun[1] + pisRun[2])
    
    PV_hat = afsRun[unaffectedGenesRun].mean()
    
    trueC1PVgivenD = piProp1 * afsPooledRun[affectedGenesRun[0], 1].mean() + (1-piProp1) * afsPooledRun[affectedGenesRun[2], 1].mean()
    trueC2PVgivenD = piProp2 * afsPooledRun[affectedGenesRun[1], 2].mean() + (1-piProp2) * afsPooledRun[affectedGenesRun[2], 2].mean()
    trueCBothPVgivenD = (pisRun[0] / pisRun.sum()) * afsPooledRun[affectedGenesRun[0], 3].mean() + (pisRun[1] / pisRun.sum()) * afsPooledRun[affectedGenesRun[1], 3].mean() + (pisRun[2] / pisRun.sum()) * afsPooledRun[affectedGenesRun[2], 3].mean()
    
    trueC1PDgivenV = pDgivenV(pD=pDsRun[0], pVgivenD=trueC1PVgivenD, pV=PV_hat)
    trueC2PDgivenV = pDgivenV(pD=pDsRun[1], pVgivenD=trueC2PVgivenD, pV=PV_hat)
    trueCBothPDgivenV = pDgivenV(pD=pDsRun[2], pVgivenD=trueCBothPVgivenD, pV=PV_hat)
    
    inferredPDGivenVs = tensor([inferredC1PDgivenV, inferredC2PDgivenV, inferredCBothPDgivenV])
    truePDGivenVs = tensor([trueC1PDgivenV, trueC2PDgivenV, trueCBothPDgivenV])
    print("\ninferredPis", inferredPis)
    print("\ninferredPDVs", inferredPDGivenVs)
    print("\ntruePDGivenVs", truePDGivenVs)
    
    params7["lls"].append(res["lls"][-1])
    params7["inferredAlphas"].append(inferredAlphas)
    params7["inferredPis"].append(inferredPis)
    params7["inferredPDVs"].append(inferredPDGivenVs)
    params7["truePDVs"].append(truePDGivenVs)
    params7["params"].append(cd["params"])
    
    params7["costFnIdx"].append(runCostFnIdx)

    print(f"\n\nparams on run {i}\n", params7)

I is 0
params are: {'nGenes': 20000, 'nCases': tensor([5000., 5000., 1000.]), 'nCtrls': tensor(500000.), 'pDs': tensor([0.0098, 0.0098, 0.0020]), 'diseaseFractions': tensor([0.1000, 0.1000, 0.0500]), 'rrShape': tensor(10.), 'rrMeans': tensor([5., 5., 2.]), 'afShape': tensor(10.), 'afMean': tensor(1.0000e-04)} pis are tensor([0.1000, 0.1000, 0.0500])
i is 0
pis are tensor([0.1000, 0.1000, 0.0500])
shape torch.Size([20000, 4])
altCountsFlat tensor([[38.,  3.,  1.,  1.],
        [36.,  3.,  0.,  1.],
        [60.,  1.,  1.,  1.],
        ...,
        [67.,  0.,  0.,  0.],
        [60.,  0.,  0.,  0.],
        [35.,  1.,  1.,  0.]])
n tensor([43., 40., 63.,  ..., 67., 60., 37.])
xCase1, xCase2, xCase12 tensor([3., 3., 1.,  ..., 0., 0., 1.])
xCase1, xCase2, xCase12 tensor([1., 0., 1.,  ..., 0., 0., 1.])
xCase1, xCase2, xCase12 tensor([1., 1., 1.,  ..., 0., 0., 0.])
altCountsFlat tensor([[38.,  3.,  1.,  1.],
        [36.,  3.,  0.,  1.],
        [60.,  1.,  1.,  1.],
        ...,
        [6



best ll: 59982.078125, bestParams: [tensor(0.1253), tensor(0.1317), tensor(0.0780), tensor(2670.5784), tensor(10862.6279), tensor(20437.1914), tensor(20894.7539)]
epoch 0
 final_simplex: (array([[8.43590248e-02, 9.36026163e-02, 4.16182043e-02, 3.78246175e+03,
        2.25777261e+04, 1.93765392e+04, 1.40210948e+04],
       [8.43590249e-02, 9.36026163e-02, 4.16182042e-02, 3.78246175e+03,
        2.25777261e+04, 1.93765392e+04, 1.40210948e+04],
       [8.43590248e-02, 9.36026166e-02, 4.16182042e-02, 3.78246175e+03,
        2.25777261e+04, 1.93765392e+04, 1.40210948e+04],
       [8.43590246e-02, 9.36026167e-02, 4.16182041e-02, 3.78246175e+03,
        2.25777261e+04, 1.93765392e+04, 1.40210948e+04],
       [8.43590247e-02, 9.36026166e-02, 4.16182041e-02, 3.78246175e+03,
        2.25777261e+04, 1.93765392e+04, 1.40210948e+04],
       [8.43590248e-02, 9.36026167e-02, 4.16182041e-02, 3.78246175e+03,
        2.25777261e+04, 1.93765391e+04, 1.40210948e+04],
       [8.43590248e-02, 9.36026166e-02

In [370]:
lastIdx = len(cachedData7) - 1
cd = cachedData7[lastIdx]
alphas = params7["inferredAlphas"][lastIdx + 1]
pds = tensor([1 - cd["params"]["pDs"].sum(), *cd["params"]["pDs"]])
print('pds', pds)
individualProbabilities = Dirichlet(alphas*pds).sample([10_000]).mean(0)
print("alphas", alphas)
print("P(D|V)'s'", individualProbabilities)
pis = params7["inferredPis"][lastIdx + 1]
print("pis", pis)
params = params7["params"][0]
alphas
print(params.__repr__())
print(cd["params"])

inferredAlphas = tensor([x.numpy() for x in params7["inferredAlphas"][1:]])
avg = inferredAlphas.mean(0)
std = inferredAlphas.std(0)
minimum = inferredAlphas.min(0)
maximum = inferredAlphas.max(0)

print("avg", avg, "std", std, "min", minimum, "max", maximum)

inferredPDVs = tensor([Dirichlet(x*pds).sample([10_000]).mean(0).numpy() for x in params7["inferredAlphas"][1:]])
avg = inferredPDVs.mean(0)
std = inferredPDVs.std(0)
minimum = inferredPDVs.min(0)
maximum = inferredPDVs.max(0)

print("avg P(D|V)", avg, "std P(D|V)", std, "min P(D|V)", minimum, "max P(D|V)", maximum)


params7["params"]

params7["params"][0]

pds tensor([0.9785, 0.0098, 0.0098, 0.0020])
alphas tensor([1663.6231, 9815.7876, 8915.7372, 5856.1519], dtype=torch.float64)
P(D|V)'s' tensor([0.8932, 0.0527, 0.0479, 0.0063], dtype=torch.float64)
pis tensor([0.0885, 0.0844, 0.0428], dtype=torch.float64)
{'nGenes': 20000, 'nCases': tensor([5000., 5000., 1000.]), 'nCtrls': tensor(500000.), 'pDs': tensor([0.0098, 0.0098, 0.0020]), 'diseaseFractions': tensor([0.1000, 0.1000, 0.0500]), 'rrShape': tensor(10.), 'rrMeans': tensor([5., 5., 2.]), 'afShape': tensor(10.), 'afMean': tensor(1.0000e-04)}
{'nGenes': 20000, 'nCases': tensor([5000., 5000., 1000.]), 'nCtrls': tensor(500000.), 'pDs': tensor([0.0098, 0.0098, 0.0020]), 'diseaseFractions': tensor([0.1000, 0.1000, 0.0500]), 'rrShape': tensor(10.), 'rrMeans': tensor([5., 5., 2.]), 'afShape': tensor(10.), 'afMean': tensor(1.0000e-04)}
avg tensor([1092.5818, 6318.8419, 5645.6539, 3940.6564], dtype=torch.float64) std tensor([1306.6170, 7633.5401, 6885.9809, 4778.8787], dtype=torch.float64) min 

{'nGenes': 20000,
 'nCases': tensor([5000., 5000., 1000.]),
 'nCtrls': tensor(500000.),
 'pDs': tensor([0.0098, 0.0098, 0.0020]),
 'diseaseFractions': tensor([0.1000, 0.1000, 0.0500]),
 'rrShape': tensor(10.),
 'rrMeans': tensor([5., 5., 2.]),
 'afShape': tensor(10.),
 'afMean': tensor(1.0000e-04)}

In [279]:
cd = cachedData7[1]
np.corrcoef(cd["afs"][0:1000, 0], cd["afs"][0:1000, 2])

array([[1.      , 0.999245],
       [0.999245, 1.      ]])

In [None]:
# cd = cachedData4b[0]["xsPooled"]
# print(cd[:, 0, 0][0])

# cdCtrl = cd[:, 0, 0]
# cdCases = cd[:, :, 1]

# cdFlat = []
# for geneIdx in range(20_000):
#     cdFlat.append([cdCtrl[geneIdx], *cdCases[geneIdx].flatten()])
# cdFlat = tensor(cdFlat)
# cdFlat.type()

In [None]:
# cdFlat.max()

In [None]:
x = Multinomial(probs=tensor([.5, .5]))

In [None]:
x.log_prob(tensor([5., 0.]))

In [None]:
cachedData4b[-1]["xsPooled"].max()

In [None]:
nCtrlsRun = tensor(5e5)
nCases12 = tensor([5e3, 5e3])
nCasesBoth = nCases12.sum() * .1
nCasesRun = tensor([*nCases12, nCasesBoth])

pDsRun = nCasesRun / ( nCasesRun.sum() + nCtrlsRun )
pDsRun

In [None]:
unaffectedGenesRun

In [None]:
params4

In [None]:
params4

In [None]:
np.corrcoef(cachedData4[0][0][4000:5000, 0, 0], cachedData4[0][0][4000:5000, 0, 1])

In [None]:
cdAlts = cachedData4[0][0]
cdAfs = cachedData4[0][1]
cdGenes = cachedData4[0][2]
cdPDs = cachedData4[0][3]["pDs"]
cdPDs
# pDgivenV(cdPDs[0], cdAfs[cdGenes[0], 0, 1], cdAfs[genesRun[0], 0, 0]).mean()

In [None]:
cdAlts

In [None]:
## Calculating P(D|V) for sample 1:
# First, we know which genes are affected, the first pi1*nGenes, and then some genes that affect both diseases
cdAfs[0:2000, 0, 1].mean()/cdAfs[0:2000, 0, 0].mean()
s1PVgivenD = (2/3.0) * cdAfs[0:2000, 0, 1].mean() + (1/3.0) * cdAfs[4000:5000, 0, 1].mean()
s1pD = cdPDs[0]
s1PVhat = cdAfs[5000:, 0, 0].mean() # estimate allele frequency
print("s1PVgivenD", s1PVgivenD, "s1pD", s1pD, "s1PVhat", s1PVhat)
# Note, we don't use 0:2000 for the estimate, because control allele frequency is P(V|!D)
# and is depressed by the the presence of controls, e.g our P(V|!D) is proportional to P(V) - (P(V|D)*P(D)).sum()
print("True P(D1|V)", pDgivenV(pD=s1pD,pVgivenD=s1PVgivenD,pV=s1PVhat))

# Now the estimated one
pisInferred = params4["inferredPis"][0]
pi1inferred = pisInferred[0]
piBothinferred = pisInferred[2]
pi1ratio = (pi1inferred / (pi1inferred + piBothinferred))
print("pi1inferred", pi1inferred, "piBothinferred", piBothinferred, "pi1ratio", pi1ratio)
pDsInferred = params4["inferredPDVs"][0]
pD1onlyInferred = pDsInferred[1]
pDsharedInferred = pDsInferred[3]
print("pD1onlyInferred inferred", pD1onlyInferred, "pDsharedInferred", pDsharedInferred)
print("Inferred P(D1|V)", pD1onlyInferred * pi1ratio + (pD1onlyInferred + pDsharedInferred) * (1 - pi1ratio))


In [None]:
len(cdGenes[)

In [None]:
## The P(D|V) for condition

In [None]:
cdAfs[cdGenes[0], 0, 1].mean() / cdAfs[cdGenes[0], 0, 0].mean()

In [None]:
# params4, cachedData2 = runModel(altCountsByGenePooledCtrls2, afsByGenePooledCtrls2, 4)

In [None]:
r = altCount

In [None]:
res = fitFnBivariate(cachedData[0][0], pDsLarge, nEpochs=20, minLLThresholdCount=20, debug=True, costFnIdx=0)

In [None]:
res1 = fitFnBivariate(cachedData[0][0], pDsLarge, nEpochs=20, minLLThresholdCount=20, debug=True, costFnIdx=1)

In [None]:
res2 = fitFnBivariate(cachedData[0][0], pDsLarge, nEpochs=20, minLLThresholdCount=20, debug=True, costFnIdx=2)

In [None]:
print("res0", res0)
print("\nres0", "pis", res0["params"][-1][0:3], "mean P(D|V)'s", Dirichlet(tensor(res0["params"][-1][3:])).sample([10_000]).mean(0))

print("\n\n\nres1", res1)
print("\nres1", "pis", res1["params"][-1][0:3], "mean P(D|V)'s", Dirichlet(tensor(res1["params"][-1][3:])).sample([10_000]).mean(0))

print("\n\n\nres2", res2)
print("\nres2", "pis", res2["params"][-1][0:3], "mean P(D|V)'s", Dirichlet(tensor(res2["params"][-1][3:])).sample([10_000]).mean(0))

In [None]:
pyplot.plot(res0["llTrajectory"])



In [None]:
pyplot.plot(res1["llTrajectory"])

In [None]:
pyplot.plot(res2["llTrajectory"])

In [None]:
params

In [None]:
(afsByGenePooledCtrls[2000:4000, 1, 1]/afsByGenePooledCtrls[2000:4000, 0, 0]).mean()

In [None]:
truth = pDgivenV(pD., afsByGenePooledCtrls[0:2000, :, 1], afsByGenePooledCtrls[0:2000, 0, 0])

In [None]:
test = Dirichlet(tensor(1/4.0).expand(4)).sample()
test = test[0:3]
r = [0,1,2,3]
r[0:4]

In [None]:
fitFnBivariate(altCountsByGenePooledCtrls, pDs, nEpochs=100, minLLThresholdCount=100, debug=True)

In [None]:
d = Dirichlet(concentration=tensor([1.40625703e+04,
         5.56195520e+03, 1.57978682e+02, 2.33518936e+04]))
d.sample([10_000,]).mean(0)

In [None]:
Beta(7.74788652e+02, 2.58170768e+04 + 9.72956833e+02 + 5.18278100e+03).sample([10000]).mean()

In [None]:
Beta(3.05871723e+04, 3.25256694e+02 + 3.75135881e+03 +4.52942294e+04).sample([10000,]).mean()

In [None]:
start = time.time()
res = fitFnUniveriateBetaBinomial(altCountsByGene, pDs, nEpochs=100, minLLThresholdCount=100, debug=False)
print("fitFnUniveriateBetaBinomial took for 100 epochs: ", time.time() - start)

In [None]:
pyplot.plot(res["llTrajectory"])
res

In [None]:
binomH0 = Binomial(total_count=tensor([1.,1]), probs=pDs[0])

In [None]:
binomH0.log_prob(tensor(1.))

In [None]:
costFn2 = likelihoodUnivariateFast(altCountsByGene, pDs)
# print(costFn2([1e-9, .999999]))
print(costFn2([1e-9, 1e-9]))
print(costFn2([0.08845797,0.11094360])) #gives ~12067 using jensen's method, and ~9887 using exponentiation of the log

# best result from R
#  0.08845797           0.11094360 , ll -10127.23, and with jensen's version, "example -12037.4347455843"
# pDgivenV, pi1


In [None]:
costFn = likelihoodUnivariate(altCountsByGene, pDs)
print("costFn1:", costFn([.001, .01]),"costFn2:",costFn2([.001, .01]))

In [None]:
costFn([0.0001,0.11094360])

In [None]:
print(costFn2([0.0001,0.11094360]))

In [None]:
d = Binomial(total_count=tensor([14., 0., 9.]), probs=tensor(.0099))
d.log_prob(tensor([0.,0.,0.]))

In [None]:
costFn2([1e-9, .999999])

In [None]:
binomH0 = Binomial(total_count=geneSums, probs=.001)
binomH1 = Binomial(total_count=geneSums, probs=.01)
caseAltCounts = altCountsByGene[:, 0, 1]
print(caseAltCounts)
component0 = binomH0.log_prob(caseAltCounts)
print("component0", component0, .5*component0)
component1 = binomH1.log_prob(caseAltCounts)

In [None]:
pDgivenV(pDs[0], afsByGene2[0:2000, 0, 1].mean(), afMean)

In [None]:
condition1 = altCountsByGene2[:, 0, :]
condition1
pDs[0]

afsByGene2[0:2000,:,1].mean()

In [None]:
pyplot.figure(num=None, figsize=(20, 6), dpi=80, facecolor='w', edgecolor='k')
pyplot.plot(afsByGene2[:, 0, 1:2].flatten())
pyplot.plot(afsByGene2[:, 0, 0:1].flatten())

In [None]:
pyplot.figure(num=None, figsize=(20, 6), dpi=80, facecolor='w', edgecolor='k')
pyplot.plot(afsByGenePooledCtrls[:, 0, 0:1].flatten())
pyplot.plot(afsByGenePooledCtrls[:, 0, 1:2].flatten())
pyplot.plot(afsByGenePooledCtrls[:, 1, 1:2].flatten())
pyplot.plot(afsByGenePooledCtrls[:, 2, 1:2].flatten())
# pyplot.plot(afsByGeneRR2[:, 0, 1:2].flatten())
# pyplot.plot(afsByGeneRR2[:, 0, 0:1].flatten())

In [None]:
ctrlCounts = altCountsByGene[:, 0, 0]
altCountsCases = altCountsByGene[:, :, 1]

altCountsFlat = []
for geneIdx in range(nGenes):
    altCountsFlat.append([ctrlCounts[geneIdx], *altCountsByGene[geneIdx, :, 1].flatten()])
altCountsFlat = tensor(altCountsFlat)

In [None]:
altCountsFlat[0]

In [None]:
DirichletMultinomial?

In [None]:
import os
from collections import defaultdict
import torch
import numpy as np
import scipy.stats
from torch.distributions import constraints
from matplotlib import pyplot
%matplotlib inline

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True)

K = 4  # Fixed number of components.

@config_enumerate
def model(data):
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    alpha0 = pyro.sample('alpha0', Uniform(1, 2.5e4))
    alpha1 = pyro.sample('alpha0', Uniform(1, 2.5e4))
    alpha2 = pyro.sample('alpha0', Uniform(1, 2.5e4))
    alpha3 = pyro.sample('alpha0', Uniform(1, 2.5e4))

    with pyro.plate('components', K):
        concentrations = pyro.sample('concentrations', dist.Dirichlet(0.5 * torch.ones(K)))

    with pyro.plate('data', len(data)):
        # Local variables.
        component = pyro.sample('assignment', dist.Categorical(weights))
        print(f"concentrations: {concentrations[component]}")
        pyro.sample('obs', DirichletMultinomial(concentration=concentrations[component], total_count=data.sum(1)), obs=data)

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)

In [None]:
def init_loc_fn(site):
    if site["name"] == "weights":
        # Initialize weights to uniform.
        return torch.ones(K) / K
    if site["name"] == "concentrations":
        return torch.ones(K) / K
    raise ValueError(site["name"])

def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    global_guide = AutoDelta(poutine.block(model, expose=['weights', 'concentrations']),
                             init_loc_fn=init_loc_fn)
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, altCountsFlat)

# Choose the best among 100 random initializations.
loss, seed = min((initialize(seed), seed) for seed in range(2))
initialize(seed)
print('seed = {}, initial_loss = {}'.format(seed, loss))

In [None]:
# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))

losses = []
for i in range(200 if not smoke_test else 2):
    loss = svi.step(altCountsFlat)
    losses.append(loss)
    print('.' if i % 100 else '\n', end='')

In [None]:
pyplot.figure(figsize=(10,3), dpi=100).set_facecolor('white')
pyplot.plot(losses)
pyplot.xlabel('iters')
pyplot.ylabel('loss')
pyplot.yscale('log')
pyplot.title('Convergence of SVI');

In [None]:
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel('iters')
pyplot.ylabel('gradient norm')
pyplot.yscale('log')
pyplot.legend(loc='best')
pyplot.title('Gradient norms during SVI');

In [None]:
map_estimates = global_guide(altCountsFlat)
weights = map_estimates['weights']
locs = map_estimates['concentrations']
print('weights = {}'.format(weights.data.numpy()))
print('concentrations = {}'.format(locs.data.numpy()))

In [None]:
Dirichlet(tensor([0.8973397  , 0.0494441,  0.04917945, 0.00403667])).sample([10_000,]).mean()

In [None]:
import os
from collections import defaultdict
import torch
import numpy as np
import scipy.stats
from torch.distributions import constraints
from matplotlib import pyplot
%matplotlib inline

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.3.1')
pyro.enable_validation(True)

In [None]:
K = 2  # Fixed number of components.

@config_enumerate
def model(data):
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(0., 10.))

    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)
        
optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)


def init_loc_fn(site):
    if site["name"] == "weights":
        # Initialize weights to uniform.
        return torch.ones(K) / K
    if site["name"] == "scale":
        return (data.var() / 2).sqrt()
    if site["name"] == "locs":
        return data[torch.multinomial(torch.ones(len(data)) / len(data), K)]
    raise ValueError(site["name"])

def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scale']),
                             init_loc_fn=init_loc_fn)
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data)

# Choose the best among 100 random initializations.
loss, seed = min((initialize(seed), seed) for seed in range(100))
initialize(seed)
print('seed = {}, initial_loss = {}'.format(seed, loss))

# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))

losses = []
for i in range(200 if not smoke_test else 2):
    loss = svi.step(data)
    losses.append(loss)
    print('.' if i % 100 else '\n', end='')
    

pyplot.figure(figsize=(10,3), dpi=100).set_facecolor('white')
pyplot.plot(losses)
pyplot.xlabel('iters')
pyplot.ylabel('loss')
pyplot.yscale('log')
pyplot.title('Convergence of SVI');


pyplot.figure(figsize=(10,4), dpi=100).set_facecolor('white')
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel('iters')
pyplot.ylabel('gradient norm')
pyplot.yscale('log')
pyplot.legend(loc='best')
pyplot.title('Gradient norms during SVI');

In [None]:
map_estimates = global_guide(data)
weights = map_estimates['weights']
locs = map_estimates['locs']
scale = map_estimates['scale']
print('weights = {}'.format(weights.data.numpy()))
print('locs = {}'.format(locs.data.numpy()))
print('scale = {}'.format(scale.data.numpy()))