In [3]:
import numpy as np
import time

# Lower floor applied to probabilities before taking their log. A zero entry
# then yields a large-but-finite negative log value, so count * log(p) never
# becomes 0 * -inf = nan.
eps = 1.0e-25

def normalize_transition_matrix(transition_matrix):
    """Normalize a transition matrix, or a batch thereof, so each column sums to one.

    Parameters
    ----------
    transition_matrix : array_like
        Non-negative entries. Either a single matrix of shape (n, m) or a
        batch of matrices of shape (..., n, m). Normalization is always
        over the second-to-last axis, i.e. over the columns of each matrix.

    Returns
    -------
    numpy.ndarray
        Floating-point array of the same shape in which every column sums to
        one. Columns whose entries sum to zero stay all-zero instead of
        becoming NaN.

    Raises
    ------
    ValueError
        If the input has fewer than two dimensions.
    """
    matrix = np.asarray(transition_matrix)
    if matrix.ndim < 2:
        raise ValueError('Transition matrix must be at least 2D, and appropriately batched!')

    # Axis -2 is the column axis both for a single 2D matrix and for any batch.
    col_sums = matrix.sum(axis=-2, keepdims=True)
    # Replace only exact-zero sums by 1 so empty columns stay zero rather than
    # 0/0 = NaN. (A np.maximum(sums, 1) guard would also leave columns whose
    # sum lies in (0, 1) un-normalized.)
    safe_sums = np.where(col_sums == 0, 1, col_sums)
    return matrix / safe_sums
    
limit = 100
num_samples = 1000
num_states = 641
# Batch of random column-normalized matrices, one per sample:
# shape (num_samples, num_states, num_states).
S = np.random.randint(limit, size=(num_samples, num_states, num_states))
print('sample shape:', S.shape)
S = normalize_transition_matrix(S)
# Random count matrix used as the observed data, shape (num_states, num_states).
T_data = np.random.randint(limit, size=(num_states, num_states))
S_from_T = normalize_transition_matrix(T_data)

start = time.perf_counter()
# Log-likelihood of the counts under every sample:
#   tmp_log[k] = sum_ij T_data[i, j] * log(max(S[k, i, j], eps))
# The clamp and the log share one buffer (np.log with out=), and the weighted
# sum is a single tensordot (a matrix-vector product). This avoids the extra
# full-size (num_samples, num_states, num_states) temporaries that an
# elementwise multiply followed by .sum() would allocate.
log_S = np.maximum(S, eps)
np.log(log_S, out=log_S)
tmp_log = np.tensordot(log_S, T_data, axes=([1, 2], [0, 1]))
del log_S  # release the large buffer before the next step

# tmp_log holds negative numbers of very large magnitude. Exponentiating them
# directly can underflow to 0 even in extended (float128) precision, losing all
# information, and the problem grows with the matrix size.
# Log-sum-exp trick: subtract the maximum b first. The largest weight becomes
# exactly 1 and the others keep their correct ratios to it; the common factor
# exp(b) cancels against the same factor in the normalizer eta below.
# NOTE: if the log-likelihood gaps between samples are huge, the non-maximal
# weights still underflow to 0 after the shift (the recorded run printed
# eta: 1.0), i.e. the average collapses onto the single best sample.
b = tmp_log.max()
tmp_log_shift = tmp_log - b
tmp_exp = np.exp(tmp_log_shift)
print('max output of exp:', tmp_exp.max())
# Weighted sum over samples, sum_k tmp_exp[k] * S[k], computed as one
# tensordot instead of materializing a full-size weighted copy of S.
S_expectation = np.tensordot(tmp_exp, S, axes=1)
# eta >= 1 because the largest shifted weight is exp(0) == 1, so the
# division below is numerically safe.
eta = tmp_exp.sum()
print('eta:', eta)
# Normalize to obtain the likelihood-weighted average of the sampled matrices.
# S and tmp_exp are float64, so S_res is float64 without an explicit cast.
S_res = S_expectation / eta
print('Result dtype:', S_res.dtype)
print(S_res.shape)
print('Block execution took', time.perf_counter() - start, 'seconds.')

sample shape: (1000, 641, 641)
max output of exp: 1.0
eta: 1.0
Result dtype: float64
(641, 641)
Block execution took 24.482236862182617 seconds.
