In [None]:
# default_exp tree_likelihood

In [None]:
from multiinstance.ward_clustering import WardClustering

In [None]:
from multiinstance.utils import *
from multiinstance.data.syntheticData import buildDataset,getBag

# import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# export
from autograd import grad,hessian
from autograd.scipy import  stats as agss
import autograd.numpy as np

import scipy.stats as ss

from tqdm.notebook import tqdm

In [None]:
from multiinstance.data.realData import buildDataset as buildReal

In [None]:
from glob import glob

In [None]:
# export

class LikelihoodMethod:
    """Refine leaf-level alpha estimates with a tree-structured Gaussian likelihood.

    Every (row, cluster) pair of ``clusterAssignments`` owns a set of alpha
    estimates ``alphaHatMat[row, cluster]``. Those estimates are modelled as
    normal with
      * mean  = the ``ds.numU``-weighted average of the member leaves' means,
      * scale = a free parameter private to that (row, cluster).
    ``run`` takes damped Newton steps on the negative log-likelihood, updating
    the leaf means and the scales as two separate blocks.

    Parameters
    ----------
    ds : object
        Must expose ``numU`` (indexable by leaf index) and ``trueAlphas``
        (with ``.flatten()``); the latter is only used for MAE logging.
    clusterAssignments : array, indexed ``[row, leaf]``
        Cluster label of each leaf at each row; cast to ``int`` on construction.
    alphaHatMat : array, indexed ``[row, cluster]``
        Alpha estimates for each cluster. Row 0 is averaged along axis 1 to
        initialise the leaf means -- assumes row 0 holds one cluster per leaf,
        TODO confirm against WardClustering.
    lr : float
        Multiplier applied to every Newton step.
    rowLambda : float
        Row ``r`` contributes to the likelihood with weight ``rowLambda ** r``.

    Attributes
    ----------
    leafMeans, clusterVariances : current parameter values (updated by ``run``).
    loc2Idx : dict mapping ``(row, cluster)`` to an index into ``clusterVariances``.
    meanHistory, varianceHistory, MAEs, NLLs : one entry per ``log()`` call.
    """

    def __init__(self,ds,clusterAssignments, alphaHatMat,lr=0.01,rowLambda=1.0):
        self.ds = ds
        self.clusterAssignments = clusterAssignments.astype(int)
        self.alphaHatMat = alphaHatMat
        self.leafMeans = np.mean(self.alphaHatMat[0],axis=1)
        self.initClusterVariances()
        self.lr = lr
        self.meanHistory = []
        self.varianceHistory = []
        self.MAEs = []
        self.NLLs = []
        self.rowLambda = rowLambda

    def initClusterVariances(self):
        """Initialise one scale parameter per (row, cluster) and build ``loc2Idx``.

        NOTE: despite the attribute name, the stored values are the *scale*
        (second value returned by ``scipy.stats.norm.fit``), which is exactly
        what ``logLikelihood`` forwards to ``norm.logpdf`` as its scale argument.
        """
        scales = []
        # Coordinate tuple in clusterAssignments -> index within clusterVariances
        self.loc2Idx = {}
        for rowNum in range(self.clusterAssignments.shape[0]):
            for cluster in np.unique(self.clusterAssignments[rowNum]):
                self.loc2Idx[(rowNum, cluster)] = len(scales)
                _, scale = ss.norm.fit(self.alphaHatMat[rowNum, cluster])
                scales.append(scale)
        self.clusterVariances = np.array(scales)

    def logLikelihood(self,alphaHats, mu, sigma):
        """Return the average normal log-density of ``alphaHats`` under (mu, sigma)."""
        LL = np.sum(agss.norm.logpdf(alphaHats, mu, sigma))
        LL = LL * (1 / len(alphaHats))
        return LL

    def treeNegativeLogLikelihood(self):
        """Build and return ``NLL(leafMeans, clusterVars)`` for the current tree.

        Everything that does not depend on the optimisation variables (cluster
        membership, leaf sizes, row weights, observed estimates) is gathered
        once here, so the returned closure only does the differentiable work.
        """
        terms = []
        for rowNum in range(self.clusterAssignments.shape[0]):
            rowAssignments = self.clusterAssignments[rowNum]
            for cluster in np.unique(rowAssignments):
                members = np.where(rowAssignments == cluster)[0]
                leafSizes = np.array([self.ds.numU[i] for i in members])
                terms.append((self.rowLambda ** rowNum,
                              members,
                              leafSizes,
                              self.loc2Idx[(rowNum, cluster)],
                              self.alphaHatMat[rowNum, cluster]))

        def NLL(leafMeans, clusterVars):
            ll = 0
            for rowWeight, members, leafSizes, varIdx, alphaHats in terms:
                # Cluster mean is the numU-weighted average of its leaves' means.
                clusterMean = np.dot(leafMeans[members], leafSizes) / np.sum(leafSizes)
                ll = ll + rowWeight * self.logLikelihood(alphaHats,
                                                         clusterMean,
                                                         clusterVars[varIdx])
            return -1 * ll
        return NLL

    def run(self, n_iters):
        """Take ``n_iters`` damped Newton steps on the tree NLL.

        The state is logged once before the first step and once after every
        step, so each history list grows by ``n_iters + 1`` entries.
        ``self.leafMeans`` / ``self.clusterVariances`` always hold the latest
        iterate. ``self.lr`` is multiplied by 0.95 once every 500 completed
        iterations.

        Raises
        ------
        AssertionError
            If any scale parameter becomes non-positive after an update.
        """
        nll = self.treeNegativeLogLikelihood()
        gradNLL_mu = grad(nll, 0)
        gradNLL_sigma = grad(nll, 1)
        hessianNLL_mu = hessian(nll, 0)
        hessianNLL_sigma = hessian(nll, 1)
        self.log()
        for iteration in tqdm(range(n_iters)):
            # Decay on the loop index (not on n_iters), skipping iteration 0.
            if iteration > 0 and iteration % 500 == 0:
                self.lr *= .95
            means = self.leafMeans
            var = self.clusterVariances
            # Both steps are evaluated at the same (means, var) point, i.e. the
            # two parameter blocks are updated simultaneously. solve() avoids
            # forming the explicit inverse of the Hessian.
            deltaMu = np.linalg.solve(hessianNLL_mu(means, var),
                                      gradNLL_mu(means, var))
            deltaSigma = np.linalg.solve(hessianNLL_sigma(means, var),
                                         gradNLL_sigma(means, var))
            # Write the new iterate back so log() and callers see the progress.
            self.leafMeans = means - self.lr * deltaMu
            self.clusterVariances = var - self.lr * deltaSigma
            assert (self.clusterVariances > 0).all(), "scale parameter became non-positive"
            self.log()

    def log(self):
        """Append the current MAE, NLL, leaf means and scales to the histories.

        The MAE is measured against ``self.ds.trueAlphas`` (the dataset this
        instance was constructed with), not against any notebook-level global.
        """
        self.MAEs.append(np.mean(np.abs(self.ds.trueAlphas.flatten() - self.leafMeans)))
        nllfunc = self.treeNegativeLogLikelihood()
        self.NLLs.append(nllfunc(self.leafMeans, self.clusterVariances))
        self.meanHistory.append(self.leafMeans)
        self.varianceHistory.append(self.clusterVariances)

In [None]:
def plotDistrTree(trueAlphas, alphaHatMat, meanHistory, scaleHistory,loc2Index,clusterAssignments, numU):
    """Plot, for every (row, cluster), the initial and final fitted normal over the estimates.

    Each cluster gets its own axis at grid position ``[row, cluster]`` showing:
      * green curve  - normal pdf from the last entry of the histories,
      * red curve    - normal pdf from the first entry of the histories,
      * blue bars    - density histogram of ``alphaHatMat[row, cluster]``,
      * red vertical - ``numU``-weighted true alpha of the cluster's members.

    Parameters
    ----------
    trueAlphas : 1-D array indexed by leaf.
    alphaHatMat : array indexed ``[row, cluster]``.
    meanHistory : sequence of leaf-mean arrays; first and last entries are used.
    scaleHistory : sequence of scale arrays; first and last entries are used.
    loc2Index : dict mapping ``(row, cluster)`` to an index into a scale array.
    clusterAssignments : 2-D array indexed ``[row, leaf]``.
    numU : array of per-leaf weights (must support integer-array indexing).

    Returns
    -------
    The matplotlib figure.
    """
    # Grid dimensions come from the mapping passed in, not from a notebook global.
    rows, cols = zip(*loc2Index.keys())
    Nrows = int(np.max(rows)) + 1
    Ncols = int(np.max(cols)) + 1
    # figsize is (width, height): width scales with columns, height with rows.
    # squeeze=False keeps `ax` 2-D even for a single row or column.
    fig, ax = plt.subplots(nrows=Nrows, ncols=Ncols,
                           figsize=(5 * Ncols, 5 * Nrows), squeeze=False)
    grid = np.arange(0,1,.01)
    for row in range(clusterAssignments.shape[0]):
        for c in np.unique(clusterAssignments[row]):
            varIdx = loc2Index[(row,c)]
            scale = scaleHistory[-1][varIdx]
            scale0 = scaleHistory[0][varIdx]
            children = np.where(clusterAssignments[row] == c)[0]
            childN = numU[children]
            totalN = childN.sum()
            # Weighted cluster-level means (final, initial) and true alpha.
            mu = np.dot(meanHistory[-1][children], childN) / totalN
            mu0 = np.dot(meanHistory[0][children], childN) / totalN
            alpha = np.dot(trueAlphas[children], childN) / totalN
            axis = ax[row,c]
            axis.plot(grid, ss.norm.pdf(grid,loc=mu,scale=scale),color="green")
            axis.plot(grid, ss.norm.pdf(grid,loc=mu0,scale=scale0),color="red",alpha=.5)
            axis.hist(alphaHatMat[row,c],density=True,color="blue")
            axis.vlines(alpha, 0,1,color="red")
    plt.show()
    return fig

In [None]:
from sklearn.metrics import roc_auc_score
def posteriorCorrection(tau, alpha, S0S1):
    """Scale the odds ``tau / (1 - tau)`` by ``alpha * S0S1`` and cap infinities at 1.

    ``tau`` must be an array (the infinite entries are overwritten by boolean
    mask assignment). Entries where ``tau == 1`` produce an infinite value,
    which is replaced by 1.
    """
    odds = tau / (1 - tau)
    post = alpha * S0S1 * odds
    infiniteMask = np.isinf(post)
    post[infiniteMask] = 1
    return post

def correctedAUC(ds,bagAlphaHats,):
    """Return the ROC AUC of the corrected posteriors pooled over all bags.

    For each bag ``i`` in ``range(ds.N)`` the second value returned by
    ``getTransformScores(ds, i)`` is passed through ``posteriorCorrection``
    together with that bag's entry of ``bagAlphaHats`` and its
    ``ds.numU / ds.numP`` ratio. The results are concatenated and scored
    against the first ``ds.numU[i]`` hidden labels of every bag.
    """
    # Only the second element of each (_, tau) pair is needed.
    tauArrays = [tau for _, tau in (getTransformScores(ds, bagIdx)
                                    for bagIdx in range(ds.N))]
    sizeRatios = ds.numU/ds.numP
    posteriors = []
    for tau, alphaHat, ratio in zip(tauArrays, bagAlphaHats, sizeRatios):
        posteriors.append(posteriorCorrection(tau, alphaHat, ratio))
    posteriorVals = np.concatenate(posteriors)
    labelChunks = []
    for bagIdx in range(ds.N):
        labelChunks.append(ds.hiddenLabels[bagIdx][:ds.numU[bagIdx]])
    hiddenLabels = np.concatenate(labelChunks)
    return roc_auc_score(hiddenLabels, posteriorVals)

In [None]:
# Build a small synthetic example: two datasets created with identical settings
# except for alpha (always 0.2 vs. always 0.8), merged into `dsi`.
# NOTE(review): the leading `1` is presumably the number of bags -- confirm
# against buildDataset.
dsi = buildDataset(1,alphaDistr=lambda: np.random.choice([.2]),
                  nP=100,nU=200,posMean=5,negMean=1,cov=1)
ds2 = buildDataset(1,alphaDistr=lambda: np.random.choice([.8]),
                  nP=100,nU=200,posMean=5,negMean=1,cov=1)
dsi.merge(ds2)
# Attach transform scores and global estimates, then compute per-bag
# alpha estimates from 25 bootstraps.
dsi = addTransformScores(dsi)
dsi = addGlobalEsts(dsi)
dsi.alphaHats,dsi.curves = getBagAlphaHats(dsi,numbootstraps=25)

In [None]:
# Cluster the bags; numbootstraps is set to the size of the last axis of
# dsi.alphaHats.
ward = WardClustering(dsi,numbootstraps=dsi.alphaHats.shape[-1],randomPairing=True)
ward.cluster()

In [None]:
# Inspect the alphaHatMat attribute of the clustering object.
ward.alphaHatMat

In [None]:
# Set up the likelihood method on the clustering output. Gaussian noise with
# scale 1e-5 is added to alphaHatMat first.
# NOTE(review): presumably this jitter avoids zero-scale fits on constant
# entries -- confirm.
method = LikelihoodMethod(dsi,ward.clusterAssignment,
                          ward.alphaHatMat + np.random.normal(scale=0.00001,size=ward.alphaHatMat.shape),
                          lr=0.1,rowLambda=1.0)


In [None]:
# Run 100 optimisation iterations.
method.run(100)

In [None]:
# Plot the initial and final fitted distribution for every (row, cluster).
fig = plotDistrTree(dsi.trueAlphas.flatten(),method.alphaHatMat, method.meanHistory,
              method.varianceHistory, method.loc2Idx, method.clusterAssignments, dsi.numU)

In [None]:
# Mean absolute error of the leaf means, one value per logged step.
plt.plot(method.MAEs)

In [None]:
# Negative log-likelihood, one value per logged step.
plt.plot(method.NLLs)

In [None]:
# Real-data experiment: for every raw dataset, compare three alpha estimates
# ("local" = initial leaf means, "likelihood" = after optimisation,
# "global" = mean of the global estimates) by MAE and by corrected AUC.

# Location of the raw .mat datasets; edit here when running on another machine.
RAW_DATASET_GLOB = "/ssdata/ClassPriorEstimationPrivate/data/rawDatasets/*.mat"

# absErrs stores per-dataset *summed* absolute errors (MAE * number of bags),
# so dividing the running total by N gives the MAE over all bags seen so far.
absErrs = {"local":[],
           "likelihood":[],
           "global": []}
aucVals = {
    "local":[],
    "likelihood":[],
    "global":[]
}
N = 0
# sorted() makes the processing order deterministic; glob gives no order guarantee.
for f in tqdm(sorted(glob(RAW_DATASET_GLOB))):
    dsi = buildReal(f,16,
                    alphaDistr=lambda: np.random.uniform(.05,.95),
                    nPDistr=lambda: 1 + np.random.poisson(125),
                    nUDistr=lambda: 1 + np.random.poisson(175))
    dsi = addTransformScores(dsi)
    dsi = addGlobalEsts(dsi,reps=10)
    dsi.alphaHats,dsi.curves = getBagAlphaHats(dsi,
                                               numbootstraps=100)
    globalMAE = np.mean(np.abs(dsi.trueAlphas.flatten() - dsi.globalAlphaHats.mean()))
    absErrs["global"].append(globalMAE * dsi.N)
    aucVals["local"].append(correctedAUC(dsi,dsi.alphaHats.mean(1)))
    aucVals["global"].append(correctedAUC(dsi,np.ones(dsi.N)*dsi.globalAlphaHats.mean()))
    wrd = WardClustering(dsi,numbootstraps=dsi.alphaHats.shape[1],randomPairing=True)
    wrd.cluster()
    mth = LikelihoodMethod(dsi,wrd.clusterAssignment,
                              wrd.alphaHatMat + np.random.normal(scale=0.00001,size=wrd.alphaHatMat.shape),
                              lr=0.01,rowLambda=1.0)
    mth.run(250)
    absErrs["local"].append(mth.MAEs[0] * dsi.N)
    absErrs["likelihood"].append(mth.MAEs[-1] * dsi.N)
    aucVals["likelihood"].append(correctedAUC(dsi, mth.leafMeans))

    # MAE trajectory of the likelihood method against the global baseline.
    maefig,ax = plt.subplots()
    ax.plot(mth.MAEs,label="likelihood")
    ax.hlines(globalMAE,0,len(mth.MAEs),label="global")
    ax.set_xlabel("iteration")
    ax.set_ylabel("MAE")
    ax.set_title("MAE per iteration")
    ax.legend()  # without this the labels above are never rendered
    plt.show()
    plt.close(maefig)  # keep figures from accumulating across datasets
#     treeFig = plotDistrTree(dsi.trueAlphas.flatten(),
#                         mth.alphaHatMat,
#                         mth.meanHistory,
#                         mth.varianceHistory,
#                         mth.loc2Idx,
#                         mth.clusterAssignments,
#                         dsi.numU)
#     plt.show()
    # NLL trajectory.
    fig,ax = plt.subplots()
    ax.plot(mth.NLLs)
    ax.set_xlabel("iteration")
    ax.set_ylabel("NLL")
    ax.set_title("Negative log-likelihood per iteration")
    plt.show()
    plt.close(fig)

    # Running summary over all datasets processed so far.
    N += dsi.N
    print("MAE")
    print("local: {:.3f}".format(np.sum(absErrs["local"])/N))
    print("likelihood: {:.3f}".format(np.sum(absErrs["likelihood"])/N))
    print("global: {:.3f}".format(np.sum(absErrs["global"])/N))
    print("AUC")
    print("local: {:.3f}".format(np.mean(aucVals["local"])))
    print("likelihood: {:.3f}".format(np.mean(aucVals["likelihood"])))
    print("global: {:.3f}".format(np.mean(aucVals["global"])))

In [None]:
# Display the per-dataset AUC values collected for each method.
aucVals