In [1]:
import importlib

import matplotlib.pyplot as plt
import numpy as np

from poolData import protein, interactionList, pooledDataset, proteinProteinMatrix
from poolSolver import parallelProteinSolver, subsetSelectionProteinSolvers, bestSubsetSelectionPoolSolver
from poolSolver import correlationSolver, nnlsSolver, leastSquaresSolver
from poolSimulator import poolSimulator

# Re-import the simulator module so edits made on disk are picked up by the running kernel.
importlib.reload(poolSimulator)
 

In [None]:
##############################
# Figure 1b: plot the mixing matrix,
# the signal vectors and the
# reconstructed pulldown vectors
##############################

# Read the bait labels and the mixing (design) matrix from disk.
designFile = "standardizedData/20250404_PooledIP_15x30/20250404_PooledIP_15x30.mixingMatrix.tsv"
(baits, mixing) = pooledDataset.pooledDataset.readMixingMatrix(designFile)

# Divide every row of the mixing matrix by its own maximum, so each row peaks at 1.
rowMax = np.max(mixing, axis=1)
mixing = (mixing.T / rowMax).T

# Simulator object built from the normalized design (not a solver).
sim = poolSimulator.poolSimulator(mixing, baits)

# Seed NumPy's global RNG before generating the ground-truth PPI matrix.
# NOTE(review): presumably (20, 0.03) = (number of proteins, interaction probability) - confirm.
np.random.seed(1)
sim.buildBernoulliPPIMatrix(20, 0.03)

# For plotting purposes, overwrite one ground-truth row so that it contains
# exactly one interaction: row iProt gets a single 1 in column iBait.
(iProt, iBait) = (17, 8)
groundTruth = sim.groundTruthPPIs.matrix
groundTruth[iProt] = np.zeros(groundTruth.shape[1])
groundTruth[iProt, iBait] = 1

# Generate synthetic signal from the ground truth with the given noise settings.
sim.simulateExperiment(noiseSD=0.2, noiseType="normal_nonneg")

# Solve the synthetic experiment with the NNLS best-subset-selection solver.
nnls_bss = bestSubsetSelectionPoolSolver.NNLSBestSubsetSelectionProteinSolver(
    sim.syntheticData,
    stopping_criteria=5,
    model='F_global',
    maxSize=4,
)
experimentSolver = parallelProteinSolver.parallelPoolSolver(sim.syntheticData, nnls_bss)
ppiMatrix = experimentSolver.solveExperiment()

# Drop column 0 of the solution before drawing the combined mix/signal/beta figure.
ppiMatrix.removeColumns([0]).plotCombinedMixSignalBetaPlot(
    imageScale=0.25,
    trainingData=sim.syntheticData,
    ppiAnnotations=[],
    sigThreshold=None,
    mixingIndices=range(len(mixing[0])),
)

plt.show()

In [None]:
########################################
# Figure 1b: plot correlation between
# mixing and signal vectors for Protein 19
# with Ab 9 and Ab 22
########################################

# Row of the simulated signal matrix to display (0-based; titled "Protein 19").
# NOTE(review): the preceding Figure 1b cell plants its single PPI at iProt=17 -
# confirm that row 18 is the intended one here.
iProt = 18
# Column of the second antibody shown in the right panel (0-based; titled "AB 22").
iBaitOff = 21

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(4, 1.5))

# The signal vector is shared by both panels; only the mixing column differs.
# Left panel uses iBait (set in the preceding Figure 1b cell), right panel uses iBaitOff.
ySig = sim.syntheticData.signalMatrix[iProt]
for iPanel, (ax, iCol) in enumerate(zip(axs, (iBait, iBaitOff))):
    yMix = sim.mixingMatrix[:, iCol]
    x = range(len(yMix))

    # Pearson correlation between the mixing column and the protein's signal.
    corr_coef = np.corrcoef(yMix, ySig)[0, 1]

    ax.plot(x, ySig, 'o', color=ppiMatrix.cSig, linestyle='-')
    ax.plot(x, yMix, 'o', color=ppiMatrix.cMix, linestyle='-')
    ax.set_title(f'Protein {iProt+1} vs. AB {iCol+1}')
    ax.set_xticks([])
    if iPanel == 0:
        # Only the left panel carries axis labels; the right one hides its y ticks.
        ax.set_xlabel('Pools')
        ax.set_ylabel(f'Signal\n(protein {iProt+1})')
    else:
        ax.set_yticks([])
    ax.text(0.05, 0.95, f'r = {corr_coef:.2f}', transform=ax.transAxes,
            verticalalignment='top', horizontalalignment='left', fontsize=10,
            bbox=dict(facecolor='white', alpha=0.0, edgecolor='none'))

# Fixed: the original referenced plt.show without calling it.
plt.show()

In [None]:
#####################################
# Figure 1c: Solve Protein 16 using Pool Solver
#####################################

iPlot = 15   # 0-based row of the signal matrix to solve
Fcut = 5     # a model step counts as an improvement when its F-statistic exceeds this
iBest = 0    # index of the last model whose step passed the F cutoff
y = sim.syntheticData.signalMatrix[iPlot]
fullOutput = nnls_bss.solveProtein(y, fullOutput=True)
modelList = fullOutput["modelList"]

# One row per model in the list, one column per design-matrix column;
# entries are the fitted betas, zero where a column is not part of the model.
betaMatrix = np.zeros((len(modelList), nnls_bss.design_matrix.shape[1]))
for iModel, model in enumerate(modelList):
    print(f'Model {iModel}, Indices: {model["indices"]}, Betas: {model["betas"]}, RSS: {model["RSS"]}')
    if iModel > 0:
        # F-test of this model against the previous model in the list.
        previous = modelList[iModel - 1]
        (RSS1, RSS2) = (previous["RSS"], model["RSS"])
        (df1, df2) = (previous["p"], model["p"])
        F = ((RSS1-RSS2)/(df2-df1))/(RSS2/(len(y) - df2))
        print(f"    F-statistic for model {iModel}: {F}, RSS1: {RSS1}, RSS2: {RSS2}, df1: {df1}, df2: {df2}, n: {len(y)}")
        iBest = iModel if F > Fcut else iBest
    # Scatter this model's coefficients into its row of betaMatrix.
    for iCoef, iCol in enumerate(model["indices"]):
        betaMatrix[iModel, iCol] = model["betas"][iCoef]

# Plots models
modelNames = [protein.protein(f'Model {i}', []) for i in range(len(modelList))]
betaPPMatrix = proteinProteinMatrix.proteinProteinMatrix(
    betaMatrix, modelNames, ppiMatrix.columnProteins, []
)
betaPPMatrix.plotCombinedMixSignalBetaPlot(imageScale=0.25)
print(f"Best model: {iBest}")

print("> Exclusion:")
# Manual override of the model index used for the exclusion plot.
# NOTE(review): this discards the iBest computed above - set it according to the printed results.
iBest = 3
betaExclusionMatrix = np.array(modelList[iBest]["beta_exclusion"])
# NOTE(review): row labels come from fullOutput['indices'], not modelList[iBest]['indices'] -
# confirm that the two agree for the hard-coded iBest.
exclusionNames = [ppiMatrix.columnProteins[i] for i in fullOutput['indices']]
betaExclusionPPMatrix = proteinProteinMatrix.proteinProteinMatrix(
    betaExclusionMatrix, exclusionNames, ppiMatrix.columnProteins, []
)
betaExclusionPPMatrix.plotCombinedMixSignalBetaPlot(imageScale=0.25)

In [None]:
# Figure 1c: print RSS and F values to obtain the significance of each PPI
# NOTE: Set iBest according to results

iBest = 3
bestModel = fullOutput["modelList"][iBest]

# Quantities that are the same for every excluded index.
rssBase = bestModel["RSS"]   # RSS of the selected model
p = bestModel["p"]           # the model's "p" entry, used as its parameter count in the F denominator
n = len(y)                   # number of observations in the solved signal vector

# For each index in the selected model, compare the RSS obtained when that
# index is excluded against the RSS of the selected model.
for iExclude in bestModel['indices']:
    rrsExl = bestModel['RSS_exclusion'][iExclude]
    print(f"Exclusion {iExclude}: RSS = {rrsExl}, F-statistic = {(rrsExl - rssBase)/(rssBase/(n - p))}")
    

In [None]:
#############################
# Figure 1d: Solve simulated
# ground truth data using
# correlation, linear regression,
# NNLS and the Pool Solver
#############################

print("> Loads and processes the design matrix")
designFile = "standardizedData/20250404_PooledIP_15x30/20250404_PooledIP_15x30.mixingMatrix.tsv"
(baits, mixing) = pooledDataset.pooledDataset.readMixingMatrix(designFile)
# Normalizes the pools to have max=1
mixing = (mixing.T / np.max(mixing, axis=1)).T
nProts = mixing.shape[1]

print("> Builds the simulations")
noiseSD = 0.2
# Diagonal ground truth, simulated at the noise level chosen above.
sim_diag = poolSimulator.poolSimulator(mixing, baits)
sim_diag.buildDiagonalPPIMatrix()
sim_diag.simulateExperiment(noiseSD=noiseSD, noiseType="normal_nonneg")

print("> Builds the solvers")
sigThreshold = 10
solver_cor = correlationSolver.correlationSolver(sim_diag.syntheticData)
solver_ls = leastSquaresSolver.leastSquaresSolver(sim_diag.syntheticData)
solver_nnls = nnlsSolver.nnlsSolver(sim_diag.syntheticData)
solver_bss = bestSubsetSelectionPoolSolver.NNLSBestSubsetSelectionProteinSolver(
    sim_diag.syntheticData,
    stopping_criteria=sigThreshold,
    model='F_global',
    maxSize=5,
)

print("> Solves the diagonal simulations")
ppiMatrix_diag_cor = parallelProteinSolver.parallelPoolSolver(sim_diag.syntheticData, solver_cor).solveExperiment()
ppiMatrix_diag_ls = parallelProteinSolver.parallelPoolSolver(sim_diag.syntheticData, solver_ls).solveExperiment()
ppiMatrix_diag_nnls = parallelProteinSolver.parallelPoolSolver(sim_diag.syntheticData, solver_nnls).solveExperiment()
ppiMatrix_diag_bss = parallelProteinSolver.parallelPoolSolver(sim_diag.syntheticData, solver_bss).solveExperiment()

# One row of five rasters: ground truth first, then one panel per solver.
scale = 2.5
fig, axs = plt.subplots(1, 5, figsize=(5*scale, 1*scale))

sim_diag.groundTruthPPIs.plotBetaHeatmap(axs[0], highlightBaits=False, sigThreshold=None)
axs[0].set_title("Ground Truth")

solverPanels = [
    ("Correlation", ppiMatrix_diag_cor),
    ("Linear Regression", ppiMatrix_diag_ls),
    ("NNLS", ppiMatrix_diag_nnls),
    ("Pool Solver", ppiMatrix_diag_bss),
]
for ax, (panelTitle, solution) in zip(axs[1:], solverPanels):
    # Column 0 of each solution is dropped before plotting.
    solution.removeColumns([0]).plotBetaHeatmap(ax, highlightBaits=False, sigThreshold=None)
    ax.set_title(panelTitle)

# Only the ground-truth panel shows (tick-less) axes with labels.
axs[0].get_yaxis().set_visible(True)
axs[0].get_xaxis().set_visible(True)
axs[0].set_yticks([])
axs[0].set_xticks([])
axs[0].set_ylabel("Bait proteins")
axs[0].set_xlabel("antibodies")

plt.tight_layout()

In [None]:
######## Supplementary Figure 2a and 2c ##############
# Uses mixing, baits, nProts and noiseSD defined in the Figure 1d cell.

# Builds the multi-PPI simulation (banded ground truth, band width nProts/maxPPIs).
maxPPIs = 5
sim_ppis = poolSimulator.poolSimulator(mixing, baits)
sim_ppis.buildBandedPPIMatrix(int(nProts/maxPPIs))
sim_ppis.simulateExperiment(noiseSD=noiseSD, noiseType="normal_nonneg")

# Builds the signal-strength varying simulation.
# Uses noiseSD (not a separate literal) because the Signal/Noise tick labels below divide by it.
sim_ston = poolSimulator.poolSimulator(mixing, baits)
sim_ston.buildDiagonalVaryingStrengthMatrix()
sim_ston.simulateExperiment(noiseSD=noiseSD, noiseType="normal_nonneg")

# Builds the solvers
sigThreshold = 10
solver_cor = correlationSolver.correlationSolver(sim_ppis.syntheticData)
pool_solver = bestSubsetSelectionPoolSolver.NNLSBestSubsetSelectionProteinSolver(
    sim_ppis.syntheticData, stopping_criteria=sigThreshold, model='F_global', maxSize=5
)

# Solve multi-PPI experiment
ppiMatrix_ppis_cor = parallelProteinSolver.parallelPoolSolver(sim_ppis.syntheticData, solver_cor).solveExperiment()
ppiMatrix_ppis_bss = parallelProteinSolver.parallelPoolSolver(sim_ppis.syntheticData, pool_solver).solveExperiment()

# Solve varying-signal experiment.
# Uses this cell's pool_solver; the original reached back to solver_bss from the
# Figure 1d cell, which was built with the same settings.
ppiMatrix_ston_cor = parallelProteinSolver.parallelPoolSolver(sim_ston.syntheticData, solver_cor).solveExperiment()
ppiMatrix_ston_bss = parallelProteinSolver.parallelPoolSolver(sim_ston.syntheticData, pool_solver).solveExperiment()

# Plot data as two rows: multi-PPI simulation on top, varying signal strength below.
scale = 2.5
fig, axs = plt.subplots(2, 3, figsize=(3*scale, 2*scale))
heatmapArgs = dict(highlightBaits=False, sigThreshold=None)

# Row 0: multi-PPI simulation
sim_ppis.groundTruthPPIs.plotBetaHeatmap(axs[0,0], **heatmapArgs)
axs[0,0].set_title("Ground Truth")
ppiMatrix_ppis_cor.removeColumns([0]).plotBetaHeatmap(axs[0,1], **heatmapArgs)
axs[0,1].set_title("Correlation Method (r)")
ppiMatrix_ppis_bss.removeColumns([0]).plotBetaHeatmap(axs[0,2], **heatmapArgs)
axs[0,2].set_title("Pool Solver")

# One y tick at the centre of each band, labelled 1..maxPPIs.
axs[0,0].get_yaxis().set_visible(True)
axs[0,0].set_yticks([(nProts/maxPPIs)*(i+0.5) for i in range(maxPPIs)])
axs[0,0].set_yticklabels([(i+1) for i in range(maxPPIs)])
axs[0,0].set_ylabel("Number of PPIs")

# Row 1: signal-strength varying simulation
sim_ston.groundTruthPPIs.plotBetaHeatmap(axs[1,0], **heatmapArgs)
ppiMatrix_ston_cor.removeColumns([0]).plotBetaHeatmap(axs[1,1], **heatmapArgs)
ppiMatrix_ston_bss.removeColumns([0]).plotBetaHeatmap(axs[1,2], **heatmapArgs)

axs[1,0].get_yaxis().set_visible(True)
nRows = sim_ston.groundTruthPPIs.nRows
nTicks = 6
# linspace always yields exactly nTicks positions from the first to the last row,
# so the number of ticks can never disagree with the number of labels
# (a float-step arange may produce a different count).
axs[1,0].set_yticks(np.linspace(0, nRows-1, nTicks))
# NOTE(review): labels assume the signal strength runs linearly from 0 to 1 over the rows - confirm.
axs[1,0].set_yticklabels([round(1.0*i/(nTicks-1)/noiseSD) for i in range(nTicks)])
axs[1,0].set_ylabel("Signal/Noise")
axs[1,0].get_xaxis().set_visible(True)
axs[1,0].set_xticks([0, nProts-1])
axs[1,0].set_xticklabels([1, nProts])
axs[1,0].set_xlabel("Baits")

plt.tight_layout()

In [7]:
############################# Supplementary Figure 2b ########################
# Reload and normalize the design so that this cell does not depend on earlier figure cells.
designFile = "standardizedData/20250404_PooledIP_15x30/20250404_PooledIP_15x30.mixingMatrix.tsv"
(baits, mixing) = pooledDataset.pooledDataset.readMixingMatrix(designFile)
mixing = (mixing.T/np.max(mixing, axis=1)).T  # Normalizes the pools to have max=1

nMaxPPI = 3        # simulations are built for 0..nMaxPPI (second argument of buildRandomPPIMatrix)
nProts = 100       # passed to buildRandomPPIMatrix; also the row count used when scoring later
noiseSD = 0.2      # noise level handed to simulateExperiment
sigThreshold = 5   # stopping criterion of the pool solver

# Builds the simulations: sims[i] is the simulation built with argument i (number of PPIs given by i).
sims = [poolSimulator.poolSimulator(mixing, baits) for _ in range(nMaxPPI+1)]
nSims = len(sims)

# Builds the ground-truth protein-protein interaction matrices and the synthetic data.
# NOTE(review): buildRandomPPIMatrix is not given a seed here; if it draws from NumPy's
# global RNG the ground truth differs between runs - confirm and seed if reproducibility matters.
for i, sim_i in enumerate(sims):
    sim_i.buildRandomPPIMatrix(nProts, i)
    sim_i.simulateExperiment(noiseSD=noiseSD, noiseType="normal_nonneg", seed=42)
    sim_i.syntheticData.normalizeByRankedValue(3, type="mean")

# Builds the solvers (both constructed from the first simulation's data and reused for all).
solver_corr = correlationSolver.correlationSolver(sims[0].syntheticData)
pool_solver = bestSubsetSelectionPoolSolver.NNLSBestSubsetSelectionProteinSolver(
    sims[0].syntheticData, stopping_criteria=sigThreshold, model='F_global', maxSize=4
)

# Solves the synthetic data with each solver, one solution per simulation.
solutions_corr = [parallelProteinSolver.parallelPoolSolver(sim.syntheticData, solver_corr).solveExperiment() for sim in sims]
solutions_ps = [parallelProteinSolver.parallelPoolSolver(sim.syntheticData, pool_solver).solveExperiment() for sim in sims]

In [None]:
# Calculates and plots reconstruction rates, sensitivity and specificity for the correlation solver.
# Rows of every heatmap are the simulations (sims[i] was built with argument i),
# columns are the correlation thresholds in thresholdList.
thresholdList = [0.4,0.5,0.6,0.7,0.8,0.9,0.95]

# reconstructionRates[iSim][iThreshold]: fraction of the first nProts rows whose thresholded
# solution row (column 0 dropped) equals the ground-truth row exactly.
reconstructionRates = [
    [
        1.0*sum(
            np.array_equal(solutions_corr[iSim].matrix[iP][1:] > threshold,
                           sims[iSim].groundTruthPPIs.matrix[iP] > 0)
            for iP in range(nProts)
        )/nProts
        for threshold in thresholdList
    ]
    for iSim in range(nSims)
]

fig, ax = plt.subplots(figsize=(8, 4))
im = ax.imshow(reconstructionRates, aspect='auto', cmap='viridis', vmin=0, vmax=1)

# Set ticks and labels
ax.set_xticks(np.arange(len(thresholdList)))
ax.set_xticklabels([])
ax.set_yticks(np.arange(nSims))
ax.set_yticklabels([str(i) for i in range(nSims)])

# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Reconstruction Rate')

# Annotate each cell with its value (white text on the dark half of the colour scale).
for i in range(nSims):
    for j in range(len(thresholdList)):
        ax.text(j, i, f"{reconstructionRates[i][j]:.2f}", ha="center", va="center", color="w" if reconstructionRates[i][j]<0.5 else "black")

plt.tight_layout()
plt.show()

# PPIlist[iThreshold][iSim]: 0/1 matrix of called PPIs (column 0 of the solution dropped).
# The loop variable is deliberately not named sigThreshold, to avoid shadowing the solver setting.
PPIlist = [[(sol.matrix[:,1:]>threshold).astype(int) for sol in solutions_corr] for threshold in thresholdList]

###############
# Sensitivity #
###############
# True positives / number of true PPIs. A simulation without any true PPI gives 0/0 = NaN;
# the warning is silenced because that NaN is expected and drawn as a grey cell.
with np.errstate(divide="ignore", invalid="ignore"):
    sensitivity = [[ np.sum(np.logical_and(PPIlist[iThreshold][iSim], sims[iSim].groundTruthPPIs.matrix))/np.sum(sims[iSim].groundTruthPPIs.matrix>0) for iThreshold in range(len(thresholdList)) ] for iSim in range(nSims) ]

# Row 0 is treated as undefined: any exact zero in it is also turned into NaN.
firstRow = np.array(sensitivity[0], dtype=np.float64)
firstRow[firstRow==0.0] = np.nan
sensitivity[0] = firstRow.tolist()

fig, ax = plt.subplots(figsize=(8, 4))
cmap = plt.get_cmap('viridis').copy()
cmap.set_bad(color='grey')
im = ax.imshow(sensitivity, aspect='auto', cmap=cmap, vmin=0, vmax=1)

# Set ticks and labels (tick positions are fixed before the labels are blanked,
# as in the other two plots of this cell).
ax.set_xticks(np.arange(len(thresholdList)))
ax.set_xticklabels([])
ax.set_yticks(np.arange(nSims))
ax.set_yticklabels([str(i) for i in range(nSims)])

# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Sensitivity')

# Annotate each cell with its value; row 0 is skipped because it is undefined.
for i in range(1, nSims):
    for j in range(len(thresholdList)):
        ax.text(j, i, f"{sensitivity[i][j]:.2f}", ha="center", va="center", color="w" if sensitivity[i][j]<0.5 else "black")

plt.tight_layout()

###############
# Specificity #
###############
# True negatives / number of true non-interactions.
specificity = [[ np.sum((PPIlist[iThreshold][iSim]==0) & (sims[iSim].groundTruthPPIs.matrix == 0)) / np.sum(sims[iSim].groundTruthPPIs.matrix==0) for iThreshold in range(len(thresholdList)) ] for iSim in range(nSims) ]
fig, ax = plt.subplots(figsize=(8, 4))
im = ax.imshow(specificity, aspect='auto', cmap='viridis', vmin=0.9, vmax=1)

# Set ticks and labels
ax.set_xticks(np.arange(len(thresholdList)))
ax.set_xticklabels(thresholdList)
ax.set_yticks(np.arange(nSims))
ax.set_yticklabels([str(i) for i in range(nSims)])

# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Specificity')

# Annotate each cell with its value
for i in range(nSims):
    for j in range(len(thresholdList)):
        ax.text(j, i, f"{specificity[i][j]:.2f}", ha="center", va="center", color="w" if specificity[i][j]<0.5 else "black")

plt.tight_layout()


In [None]:
# Calculates and plots reconstruction rates, sensitivity and specificity for the pool solver.
# No restriction on the beta value (betaThreshold = 0).
# Rows of every heatmap are the simulations (sims[i] was built with argument i),
# columns are the significance thresholds in thresholdList.
betaThreshold = 0
thresholdList = [2,5,7,10,15,20,30,50]

# reconstructionRates[iSim][iThreshold]: fraction of the first nProts rows whose called PPIs
# (sigMatrix above the threshold AND beta above betaThreshold, column 0 dropped)
# equal the ground-truth row exactly.
reconstructionRates = [
    [
        1.0*sum(
            np.array_equal(
                np.logical_and(solutions_ps[iSim].sigMatrix[iP][1:] > threshold,
                               solutions_ps[iSim].matrix[iP][1:] > betaThreshold),
                sims[iSim].groundTruthPPIs.matrix[iP] > 0)
            for iP in range(nProts)
        )/nProts
        for threshold in thresholdList
    ]
    for iSim in range(nSims)
]

fig, ax = plt.subplots(figsize=(8, 4))
im = ax.imshow(reconstructionRates, aspect='auto', cmap='viridis', vmin=0, vmax=1)

# Set ticks and labels
ax.set_xticks(np.arange(len(thresholdList)))
ax.set_xticklabels([])
ax.set_yticks(np.arange(nSims))
ax.set_yticklabels([])

# Add colorbar
cbar = plt.colorbar(im, ax=ax)

# Annotate each cell with its value (white text on the dark half of the colour scale).
for i in range(nSims):
    for j in range(len(thresholdList)):
        ax.text(j, i, f"{reconstructionRates[i][j]:.2f}", ha="center", va="center", color="w" if reconstructionRates[i][j]<0.5 else "black")

plt.tight_layout()
plt.show()

# For the next two, create explicit lists for what is a PPI and what is not, for correct counting.
# PPIlist[iThreshold][iSim] is a 0/1 matrix of called PPIs. The loop variable is deliberately
# not named sigThreshold, to avoid shadowing the solver setting of that name.
PPIlist = [[np.logical_and((sol.matrix[:,1:]>betaThreshold),(sol.sigMatrix[:,1:]>threshold)).astype(int) for sol in solutions_ps] for threshold in thresholdList]

###############
# Sensitivity #
###############
# True positives / number of true PPIs. A simulation without any true PPI gives 0/0 = NaN;
# the warning is silenced because that NaN is expected and drawn as a grey cell.
with np.errstate(divide="ignore", invalid="ignore"):
    sensitivity = [[ np.sum(np.logical_and(PPIlist[iThreshold][iSim], sims[iSim].groundTruthPPIs.matrix))/np.sum(sims[iSim].groundTruthPPIs.matrix>0) for iThreshold in range(len(thresholdList)) ] for iSim in range(nSims) ]

# Row 0 is treated as undefined: any exact zero in it is also turned into NaN.
firstRow = np.array(sensitivity[0], dtype=np.float64)
firstRow[firstRow==0.0] = np.nan
sensitivity[0] = firstRow.tolist()

fig, ax = plt.subplots(figsize=(8, 4))
cmap = plt.get_cmap('viridis').copy()
cmap.set_bad(color='grey')
im = ax.imshow(sensitivity, aspect='auto', cmap=cmap, vmin=0, vmax=1)

# Set ticks and labels
ax.set_xticks(np.arange(len(thresholdList)))
ax.set_xticklabels([])
ax.set_yticks(np.arange(nSims))
ax.set_yticklabels([])

# Add colorbar
cbar = plt.colorbar(im, ax=ax)

# Annotate each cell with its value; row 0 is skipped because it is undefined.
for i in range(1, nSims):
    for j in range(len(thresholdList)):
        ax.text(j, i, f"{sensitivity[i][j]:.2f}", ha="center", va="center", color="w" if sensitivity[i][j]<0.5 else "black")
plt.tight_layout()

###############
# Specificity #
###############
# True negatives / number of true non-interactions.
specificity = [[ np.sum((PPIlist[iThreshold][iSim]==0) & (sims[iSim].groundTruthPPIs.matrix == 0)) / np.sum(sims[iSim].groundTruthPPIs.matrix==0) for iThreshold in range(len(thresholdList)) ] for iSim in range(nSims) ]
fig, ax = plt.subplots(figsize=(8, 4))
im = ax.imshow(specificity, aspect='auto', cmap='viridis', vmin=0.9, vmax=1)

# Set ticks and labels (y tick positions are fixed before the labels are blanked,
# as in the other two plots of this cell).
ax.set_xticks(np.arange(len(thresholdList)))
ax.set_xticklabels(thresholdList)
ax.set_yticks(np.arange(nSims))
ax.set_yticklabels([])

# Add colorbar
cbar = plt.colorbar(im, ax=ax)

# Annotate each cell with its value
for i in range(nSims):
    for j in range(len(thresholdList)):
        ax.text(j, i, f"{specificity[i][j]:.2f}", ha="center", va="center", color="w" if specificity[i][j]<0.5 else "black")
plt.tight_layout()


In [None]:
# Calculates and plots reconstruction rates, sensitivity and specificity for the pool solver.
# Only entries with beta > 0.5 are counted as PPIs (betaThreshold = 0.5).
# Rows of every heatmap are the simulations (sims[i] was built with argument i),
# columns are the significance thresholds in thresholdList.
betaThreshold = 0.5
thresholdList = [2,5,7,10,15,20,30,50]

# reconstructionRates[iSim][iThreshold]: fraction of the first nProts rows whose called PPIs
# (sigMatrix above the threshold AND beta above betaThreshold, column 0 dropped)
# equal the ground-truth row exactly.
reconstructionRates = [
    [
        1.0*sum(
            np.array_equal(
                np.logical_and(solutions_ps[iSim].sigMatrix[iP][1:] > threshold,
                               solutions_ps[iSim].matrix[iP][1:] > betaThreshold),
                sims[iSim].groundTruthPPIs.matrix[iP] > 0)
            for iP in range(nProts)
        )/nProts
        for threshold in thresholdList
    ]
    for iSim in range(nSims)
]

fig, ax = plt.subplots(figsize=(8, 4))
im = ax.imshow(reconstructionRates, aspect='auto', cmap='viridis', vmin=0, vmax=1)

# Set ticks and labels
ax.set_xticks(np.arange(len(thresholdList)))
ax.set_xticklabels([])
ax.set_yticks(np.arange(nSims))
ax.set_yticklabels([])

# Add colorbar
cbar = plt.colorbar(im, ax=ax)

# Annotate each cell with its value (white text on the dark half of the colour scale).
for i in range(nSims):
    for j in range(len(thresholdList)):
        ax.text(j, i, f"{reconstructionRates[i][j]:.2f}", ha="center", va="center", color="w" if reconstructionRates[i][j]<0.5 else "black")

plt.tight_layout()
plt.show()

# For the next two, create explicit lists for what is a PPI and what is not, for correct counting.
# PPIlist[iThreshold][iSim] is a 0/1 matrix of called PPIs. The loop variable is deliberately
# not named sigThreshold, to avoid shadowing the solver setting of that name.
PPIlist = [[np.logical_and((sol.matrix[:,1:]>betaThreshold),(sol.sigMatrix[:,1:]>threshold)).astype(int) for sol in solutions_ps] for threshold in thresholdList]

###############
# Sensitivity #
###############
# True positives / number of true PPIs. A simulation without any true PPI gives 0/0 = NaN;
# the warning is silenced because that NaN is expected and drawn as a grey cell.
with np.errstate(divide="ignore", invalid="ignore"):
    sensitivity = [[ np.sum(np.logical_and(PPIlist[iThreshold][iSim], sims[iSim].groundTruthPPIs.matrix))/np.sum(sims[iSim].groundTruthPPIs.matrix>0) for iThreshold in range(len(thresholdList)) ] for iSim in range(nSims) ]

# Row 0 is treated as undefined: any exact zero in it is also turned into NaN.
firstRow = np.array(sensitivity[0], dtype=np.float64)
firstRow[firstRow==0.0] = np.nan
sensitivity[0] = firstRow.tolist()

fig, ax = plt.subplots(figsize=(8, 4))
cmap = plt.get_cmap('viridis').copy()
cmap.set_bad(color='grey')
im = ax.imshow(sensitivity, aspect='auto', cmap=cmap, vmin=0, vmax=1)

# Set ticks and labels
ax.set_xticks(np.arange(len(thresholdList)))
ax.set_xticklabels([])
ax.set_yticks(np.arange(nSims))
ax.set_yticklabels([])

# Add colorbar
cbar = plt.colorbar(im, ax=ax)

# Annotate each cell with its value; row 0 is skipped because it is undefined.
for i in range(1, nSims):
    for j in range(len(thresholdList)):
        ax.text(j, i, f"{sensitivity[i][j]:.2f}", ha="center", va="center", color="w" if sensitivity[i][j]<0.5 else "black")
plt.tight_layout()

###############
# Specificity #
###############
# True negatives / number of true non-interactions.
specificity = [[ np.sum((PPIlist[iThreshold][iSim]==0) & (sims[iSim].groundTruthPPIs.matrix == 0)) / np.sum(sims[iSim].groundTruthPPIs.matrix==0) for iThreshold in range(len(thresholdList)) ] for iSim in range(nSims) ]
fig, ax = plt.subplots(figsize=(8, 4))
im = ax.imshow(specificity, aspect='auto', cmap='viridis', vmin=0.9, vmax=1)

# Set ticks and labels (y tick positions are fixed before the labels are blanked,
# as in the other two plots of this cell).
ax.set_xticks(np.arange(len(thresholdList)))
ax.set_xticklabels(thresholdList)
ax.set_yticks(np.arange(nSims))
ax.set_yticklabels([])

# Add colorbar
cbar = plt.colorbar(im, ax=ax)

# Annotate each cell with its value
for i in range(nSims):
    for j in range(len(thresholdList)):
        ax.text(j, i, f"{specificity[i][j]:.2f}", ha="center", va="center", color="w" if specificity[i][j]<0.5 else "black")
plt.tight_layout()
