# Tutorial: A Bayesian Network for Stem Cell Differentiation

A Bayesian Network is a graphical way of representing probabilistic dependencies
between variables in a system, and it offers a way to hypothesize causal relations between observable
variables and the behavior of the system. In this project you will use a Bayesian Network to model how
protein expression influences the differentiation of induced Pluripotent Stem Cells (iPSCs). iPSCs are
similar to embryonic stem (ES) cells in that they are pluripotent (they can differentiate into hundreds of
different cell types), but in contrast to ES cells, iPSCs are artificially derived from a non-pluripotent
adult cell, and therefore have potential immunological and ethical advantages over
ES cells. In an important advance in regenerative medicine, iPSCs were first produced in 2006 from
mouse cells and in 2007 from human cells in a series of experiments in Yamanaka’s lab at Kyoto
University.

In [1]:
# Dependencies
import math
from itertools import chain, combinations

In [2]:
# Returns an iterator over the power set: every subset of iterable as a tuple, in order of increasing size
def powerset(iterable):
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
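For a small input, the power set lists every subset as a tuple, from the empty tuple up to the full set. The sketch below repeats the `powerset` definition so it runs on its own; the three-element input is only illustrative.

```python
from itertools import chain, combinations

def powerset(iterable):
    # All subsets as tuples, grouped by size r = 0, 1, ..., len(s)
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))

subsets = list(powerset([0, 1, 2]))
print(subsets)       # → [(), (0,), (1,), (2,), (0, 1), (0, 2), (1, 2), (0, 1, 2)]
print(len(subsets))  # → 8, i.e. 2**3
```

The empty tuple always comes first, which is why the model search skips it: a model needs at least one feature.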

In [3]:
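# Fast Stirling approximation of log2(n!) for large n:
# log2(n!) ~= (n + 1/2) log2(n) - n log2(e) + (1/2) log2(2*pi)
def stirling_approx(n):
    if n == 0:
        return 0.0  # log2(0!) = 0; also avoids math.log(0), which raises ValueError
    return (n + .5)*math.log(n, 2) - n*math.log(math.e, 2) + .5*math.log(2*math.pi, 2)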
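To gauge the approximation, it can be compared with the exact value of log2(n!), which Python computes without overflow as `math.lgamma(n + 1) / math.log(2)` (since lgamma(n + 1) = ln(n!)). This is a minimal sketch: `stirling_log2_factorial` and `exact_log2_factorial` are hypothetical helper names, and the Stirling form used here includes the n·log2(e) term that converting Stirling's natural-log formula to base 2 requires.

```python
import math

def stirling_log2_factorial(n):
    # Stirling: log2(n!) ~= (n + 1/2) log2(n) - n log2(e) + (1/2) log2(2*pi)
    return (n + 0.5) * math.log2(n) - n * math.log2(math.e) + 0.5 * math.log2(2 * math.pi)

def exact_log2_factorial(n):
    # lgamma(n + 1) = ln(n!), converted to bits
    return math.lgamma(n + 1) / math.log(2)

for n in (51, 100, 1000):
    approx = stirling_log2_factorial(n)
    exact = exact_log2_factorial(n)
    # Stirling slightly underestimates; the gap shrinks roughly like 1/(12 n ln 2) bits
    print(n, round(approx, 4), round(exact, 4), round(exact - approx, 6))
```

Above the notebook's threshold of 50 the error is a few thousandths of a bit, which is negligible next to score differences of several bits between models.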

In [4]:
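# Returns every length-`size` state over `alphabet` as a list of lists,
# with the first position varying fastest
def getBinaryStateTable(size, alphabet=(0, 1)):
    if size == 1:
        return [[character] for character in alphabet]

    # Extend each shorter state by one more character (pass alphabet through the recursion)
    statetable_prev = getBinaryStateTable(size - 1, alphabet)
    statetable_curr = [unfinished_state + [character] for character in alphabet for unfinished_state in statetable_prev]
    return statetable_curr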

# Test
getBinaryStateTable(3)

[[0, 0, 0],
 [1, 0, 0],
 [0, 1, 0],
 [1, 1, 0],
 [0, 0, 1],
 [1, 0, 1],
 [0, 1, 1],
 [1, 1, 1]]
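The same 2^size states can also be generated with `itertools.product`. This is an alternative sketch rather than the notebook's method, and `state_table_via_product` is a hypothetical name: `product` varies the last position fastest, so reversing each tuple reproduces the ordering printed above, where the first position varies fastest.

```python
from itertools import product

def state_table_via_product(size, alphabet=(0, 1)):
    # product(..., repeat=size) varies the LAST position fastest;
    # reversing each tuple makes the FIRST position vary fastest instead
    return [list(reversed(state)) for state in product(alphabet, repeat=size)]

table = state_table_via_product(3)
print(table)
```

The ordering only matters for display; the model code uses the states as dictionary keys, so any enumeration of all 2^size states works.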

In [5]:
# Scores a model in bits: for each feature state with outcome counts n_1..n_k
# (total N), adds log2(n_1! ... n_k! / N!), then subtracts a complexity penalty
def log2P(statetable_dict):
    total = 0

    for counts, _ in statetable_dict.values():
        if max(counts) > 50:
            # Large counts: Stirling approximation of log2(n!)
            num = sum(stirling_approx(o) for o in counts)
            denom = stirling_approx(sum(counts))
        else:
            # Small counts: exact log2(n!); math.log accepts arbitrarily large ints
            num = sum(math.log(math.factorial(o), 2) for o in counts)
            denom = math.log(math.factorial(sum(counts)), 2)
        total += num - denom

    # Penalty: with binary features, log2(number of states) equals the number
    # of features, so this subtracts 55 bits per feature
    return total - math.log(len(statetable_dict), 2)*55
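As a worked example of a single term, take the feature state (1, 1, 0) from the best model's table at the end of this notebook, with counts 192 (N=0) and 16 (N=1). Its contribution is log2(192! · 16! / 208!), which is exactly −log2 of the binomial coefficient C(208, 16). The sketch computes it with `math.lgamma` (swapped in here for exactness instead of the notebook's two branches; `log2_factorial` is a hypothetical helper) and also evaluates the penalty for a three-feature model.

```python
import math

def log2_factorial(n):
    # Exact log2(n!) via lgamma(n + 1) = ln(n!)
    return math.lgamma(n + 1) / math.log(2)

counts = [192, 16]  # outcome counts for state (1, 1, 0) in the best model
term = sum(log2_factorial(c) for c in counts) - log2_factorial(sum(counts))
print(round(term, 3))  # bits; equals -log2(C(208, 16))

num_features = 3
penalty = math.log2(2 ** num_features) * 55  # 55 bits per binary feature
print(penalty)  # → 165.0
```

Each term is at most 0, and it is closer to 0 when the counts are lopsided (the state predicts the cell state well), so better-separating feature sets score higher before the penalty.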

In [6]:
# Builds the conditional probability table of cell state given a subset of features
# Inputs:
# - features_list: a list of samples, with columns 0-5 holding binary features and column 6 the cell state
# - desired_features_index: which columns of features_list to use as features
# - sample_outcomes: the possible cell-state labels, e.g. [0, 1]
# Returns (desired_features_index, penalized log-likelihood, statetable_dict), where
# statetable_dict maps each feature state to [outcome counts, outcome probabilities]
def bayesian_network(features_list, desired_features_index, sample_outcomes):
    statetable = [tuple(state) for state in getBinaryStateTable(len(desired_features_index))]
    statetable_dict = {}
    for state in statetable:
        statetable_dict[state] = [[0]*len(sample_outcomes), [0]*len(sample_outcomes)]

    # Counts how many samples fall in each cell state for each feature state
    for sample in features_list:
        state_sample = tuple(sample[i] for i in desired_features_index)
        N = sample[6]  # cell state
        statetable_dict[state_sample][0][N] += 1

    # Converts counts to conditional probabilities P(N | feature state)
    for counts, probs in statetable_dict.values():
        tot = float(sum(counts))
        if tot == 0:
            continue  # no samples in this feature state; probabilities stay 0
        for i in range(len(sample_outcomes)):
            probs[i] = counts[i]/tot

    # Calculates the penalized log-likelihood score
    loglikelihood = log2P(statetable_dict)

    return desired_features_index, loglikelihood, statetable_dict
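The counting step is easiest to see on a tiny, made-up dataset. The sketch below is a simplified stand-alone version (hypothetical name `conditional_table`, built on `collections.Counter`) that keeps the notebook's layout: one row per sample, the cell state in column 6, and counts turned into P(N | feature state). The six samples are invented for illustration and do not come from ipsc1.dat.

```python
from collections import Counter

def conditional_table(samples, feature_cols, state_col=6, outcomes=(0, 1)):
    # Count how often each (feature state, cell state) pair occurs
    pair_counts = Counter(
        (tuple(s[i] for i in feature_cols), s[state_col]) for s in samples
    )
    # Unlike the notebook, only feature states that actually occur get a row
    observed_states = {state for state, _ in pair_counts}
    table = {}
    for state in observed_states:
        counts = [pair_counts[(state, o)] for o in outcomes]  # missing pairs count as 0
        total = sum(counts)
        table[state] = [counts, [c / total for c in counts]]
    return table

# Hypothetical samples: columns 0-5 are binary features, column 6 is the cell state
samples = [
    [1, 1, 0, 0, 0, 0, 0],
    [1, 1, 0, 1, 0, 0, 0],
    [1, 1, 0, 0, 1, 0, 1],
    [0, 1, 1, 0, 0, 0, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [0, 1, 1, 1, 1, 1, 0],
]
table = conditional_table(samples, [0, 1, 2])
for state, row in sorted(table.items()):
    print(state, row)
```

Columns 3 to 5 vary within each group but are ignored, because only columns 0 to 2 were selected as features.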

In [7]:
# Builds one model for every non-empty subset of the given feature columns.
# With 6 features there are 2^6 - 1 = 63 possible models.
# Returns the models sorted from highest to lowest log-likelihood score.
def exhaustive_model_search(features_list, parameters):
    bayesian_models = []
    for model in powerset(parameters):
        if model:  # skip the empty subset
            bayesian_models.append(bayesian_network(features_list, list(model), [0, 1]))
    return sorted(bayesian_models, key=lambda x: x[1], reverse=True)
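The count of 2^6 − 1 models follows from excluding the empty subset: a set of d features has 2^d subsets. A quick check for d = 6, grouping the subsets by size (the sizes are the binomial coefficients C(6, r)):

```python
from itertools import combinations

features = range(6)
by_size = {r: len(list(combinations(features, r))) for r in range(1, 7)}
print(by_size)                # → {1: 6, 2: 15, 3: 20, 4: 15, 5: 6, 6: 1}
print(sum(by_size.values()))  # → 63
```

Exhaustive search stays cheap here, but the number of models doubles with every added feature.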

In [9]:
# Obtains all possible Bayesian models from ipsc1.dat
# 0 - OCT4
# 1 - SOX2
# 2 - REX1
# 3-5 - Not named here; they turn out to be the least important variables
# 6 - Cell state. (N=0
with open("ipsc1.dat", "r") as iPSC1:
    iPSC1_list = [list(map(int, line.split())) for line in iPSC1 if line.strip()]
bayesian_models = exhaustive_model_search(iPSC1_list, [0, 1, 2, 3, 4, 5])
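The loader expects each line of ipsc1.dat to hold seven space-separated integers: six binary features followed by the cell state. That format is inferred from how the columns are used in this notebook; the sketch parses an in-memory stand-in built with `io.StringIO`, and its two lines are made up, not taken from the real file.

```python
import io

# Hypothetical stand-in for ipsc1.dat (the real file's contents are not shown here)
fake_file = io.StringIO("1 1 0 0 1 0 0\n0 1 1 1 0 0 1\n\n")

# Same parsing as the loader: split on whitespace, skip blank lines
rows = [list(map(int, line.split())) for line in fake_file if line.strip()]
print(rows)  # → [[1, 1, 0, 0, 1, 0, 0], [0, 1, 1, 1, 0, 0, 1]]

# Sanity checks on the assumed format: 7 columns, binary values
assert all(len(row) == 7 for row in rows)
assert all(value in (0, 1) for row in rows for value in row)
```

Running the same two checks on the real `iPSC1_list` is a cheap way to catch a malformed line before it surfaces as a `KeyError` inside `bayesian_network`.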

In [10]:
# Prints all the models
print("All Possible Bayesian Models in order of Loglikelihood")
for model in bayesian_models:
    print("Model:", model[0], "\nLoglikelihood:", model[1], "\n")

All Possible Bayesian Models in order of Loglikelihood
Model: [0, 1, 2] 
Loglikelihood: -1269.16007054 

Model: [0, 1, 2, 3, 4, 5] 
Loglikelihood: -1281.07097343 

Model: [0, 1, 2, 3] 
Loglikelihood: -1292.32280551 

Model: [0, 1, 2, 4] 
Loglikelihood: -1295.13344659 

Model: [0, 1, 2, 5] 
Loglikelihood: -1297.60423827 

Model: [0, 1, 2, 3, 4] 
Loglikelihood: -1299.08266678 

Model: [0, 1, 2, 3, 5] 
Loglikelihood: -1303.15706101 

Model: [0, 1, 2, 4, 5] 
Loglikelihood: -1306.22691026 

Model: [0, 1] 
Loglikelihood: -1317.74494217 

Model: [0, 2] 
Loglikelihood: -1332.14024948 

Model: [0, 1, 3] 
Loglikelihood: -1351.91880982 

Model: [0, 1, 4] 
Loglikelihood: -1354.54992623 

Model: [0, 1, 5] 
Loglikelihood: -1356.62437134 

Model: [0, 2, 3] 
Loglikelihood: -1368.47109338 

Model: [0, 2, 4] 
Loglikelihood: -1369.58040741 

Model: [0, 2, 5] 
Loglikelihood: -1370.77250079 

Model: [0] 
Loglikelihood: -1373.52167752 

Model: [0, 1, 3, 4] 
Loglikelihood: -1375.51142674 

Model: [0, 1, 3, 5
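Because the scores are in bits (log base 2), the gap between two models converts directly to the probability scale: a difference of d bits means 2^score is larger by a factor of 2^d. A minimal sketch using the two top scores printed above:

```python
best = -1269.16007054       # Model [0, 1, 2]
runner_up = -1281.07097343  # Model [0, 1, 2, 3, 4, 5]

gap_bits = best - runner_up
print(round(gap_bits, 4))    # → 11.9109
print(round(2 ** gap_bits))  # factor between the two penalized scores on the probability scale
```

A gap of roughly 12 bits, a factor of several thousand, suggests that adding features 3 to 5 does not buy enough fit to cover their penalty under this score.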

In [30]:
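# Conditional probability table of the best model (here features [0, 1, 2] = OCT4, SOX2, REX1):
# each row is (feature state, [[count for N=0, count for N=1], [P(N=0), P(N=1)]])
best_statetable = bayesian_models[0][2]
for row in best_statetable.items():
    print(row)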

((1, 1, 0), [[192, 16], [0.9230769230769231, 0.07692307692307693]])
((0, 1, 1), [[14, 149], [0.08588957055214724, 0.9141104294478528]])
((1, 0, 0), [[109, 90], [0.5477386934673367, 0.45226130653266333]])
((0, 0, 1), [[20, 155], [0.11428571428571428, 0.8857142857142857]])
((1, 0, 1), [[98, 96], [0.5051546391752577, 0.4948453608247423]])
((0, 0, 0), [[89, 109], [0.4494949494949495, 0.5505050505050505]])
((0, 1, 0), [[88, 99], [0.47058823529411764, 0.5294117647058824]])
((1, 1, 1), [[184, 28], [0.8679245283018868, 0.1320754716981132]])
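One way to use this table is as a classifier: for an observed (OCT4, SOX2, REX1) expression state, predict the cell state with the larger conditional probability. The sketch copies the probabilities printed above into a dict; the helper name `predict_cell_state` is hypothetical, and the argmax decision rule is an illustrative choice rather than part of the original analysis.

```python
# (P(N = 0), P(N = 1)) for each (OCT4, SOX2, REX1) state, copied from the table above
cpt = {
    (1, 1, 0): (0.9230769230769231, 0.07692307692307693),
    (0, 1, 1): (0.08588957055214724, 0.9141104294478528),
    (1, 0, 0): (0.5477386934673367, 0.45226130653266333),
    (0, 0, 1): (0.11428571428571428, 0.8857142857142857),
    (1, 0, 1): (0.5051546391752577, 0.4948453608247423),
    (0, 0, 0): (0.4494949494949495, 0.5505050505050505),
    (0, 1, 0): (0.47058823529411764, 0.5294117647058824),
    (1, 1, 1): (0.8679245283018868, 0.1320754716981132),
}

def predict_cell_state(state):
    # Argmax over cell states: return the N with the larger P(N | state)
    probs = cpt[state]
    return max(range(len(probs)), key=lambda n: probs[n])

print(predict_cell_state((1, 1, 0)))  # → 0
print(predict_cell_state((0, 1, 1)))  # → 1
```

Several states, such as (1, 0, 1) at roughly 0.51 versus 0.49, sit close to an even split, so predictions for them carry little confidence.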
