In [2]:
import numpy as np
import matplotlib.pyplot as plt
# Code for getting the transition probability of infections inside a household SIS model

beta  = np.random.rand() # transmission rate
gamma = np.random.rand() # recovery rate
I     = np.random.rand() # percentage of infected people
n     = 4                # people in the household

def factorial(n):
    res = 1
    while n > 1:
        res *= n
        n -= 1
    return res

def nchoosek(n,k):
    return factorial(n) / (factorial(n-k) * factorial(k))

def P(i,I):
    ''' 
    Probability of getting infected globally
    i: i infected individuals in a household
    I: Infected population percentage
    '''
    return 1 - (1-beta)**i * (1-beta*I)

def q(j,i):
    '''
    Probability of j agents getting infected
    j: new infected agents in a household
    i: already infected agents in a household
    '''
    return nchoosek(n-i,j) * P(i,I)**j * (1 - P(i,I))**((n-i)-j)

def r(k,i):
    '''
    Probability of k agents recovering
    k: agents recovering in the household
    i: already infected agents in a household
    '''
    return nchoosek(i,k) * gamma**k * (1-gamma)**(i-k)

def TransitionRate(i,j,k):
    '''
    Transition probabilities 
    i: already infected
    j: newly infected
    k: recovered
    '''
    return q(j,i) * r(k,i)

def MarkovMatrix(n):
    M = np.zeros((n,n))
    for row in range(n):
        for col in range(n):
            M[row,col] = TransitionRate(row,col,min(row,col))
    return M

def check_row_sums(matrix):
    """
    Check if each row of a matrix sums up to 1.
    """
    row_sums = np.sum(matrix, axis=1)
    return row_sums

M = MarkovMatrix(4)
print(f'beta = {beta}')
print(f'gamma = {gamma}')
print(f'I = {I}')
print(M)
print(check_row_sums(M))

matrix = M

# Sum of row values
row_sums = np.sum(matrix, axis=1)

print('Sum of rows', row_sums)

beta = 0.0933906062682831
gamma = 0.29316153896813146
I = 0.0893794595878512
[[9.67026925e-01 3.25596581e-02 4.11104634e-04 2.30697236e-06]
 [5.13641069e-01 7.17678289e-02 8.05919752e-03 3.01670287e-04]
 [3.31925400e-01 1.24931468e-01 2.93888923e-03 2.22252425e-04]
 [2.60964122e-01 1.14703714e-01 8.40277831e-03 1.36790462e-04]]
[1.         0.59376977 0.46001801 0.3842074 ]
Sum of rows [1.         0.59376977 0.46001801 0.3842074 ]
