In [1]:
import numpy as np
import time

# Seed NumPy's global (legacy) RNG so np.random draws are reproducible across runs.
np.random.seed(321)

def normalize_transition_matrix(transition_matrix):
    """Column-normalize a transition matrix, or a batch of them.

    Each entry is divided by the sum of its column. Column sums smaller than
    one are clamped to one before dividing, so an all-zero column stays zero
    instead of producing a division by zero.

    Args:
        transition_matrix: numpy array of shape (rows, cols) or
            (batch, rows, cols).

    Returns:
        Array of the same shape holding the column-normalized values.

    Raises:
        AssertionError: if the input is neither 2D nor 3D.
    """
    rank = transition_matrix.ndim
    assert rank in (2, 3), 'Transition matrix must be a least 2D, and appropriately batched!'

    # The row axis is always the second-to-last one: axis 0 for a single
    # matrix, axis 1 for a batch. Summing over it yields the column totals.
    column_totals = transition_matrix.sum(axis=-2, keepdims=True)
    safe_totals = np.maximum(column_totals, 1)
    return transition_matrix / safe_totals

def estim_map(transition_mat, num_samples=500, limit=100, dtype=np.float32, eps=1e-25):
    """Estimate a transition matrix as a likelihood-weighted average of random samples.

    Random integer matrices are drawn from NumPy's global RNG, column-normalized,
    and each sample is weighted by exp(sum(transition_mat * log(sample))). The
    weights are shifted by their maximum in log space before exponentiation so
    that the largest weight is exactly 1 and the others do not all underflow to
    zero; the shift cancels in the final normalization.

    Args:
        transition_mat: square array of observed transition counts,
            shape (num_states, num_states).
        num_samples: number of random candidate matrices to draw.
        limit: exclusive upper bound for the random integer entries
            (must fit into int16).
        dtype: floating-point type used for the samples and intermediate math.
        eps: lower clamp applied to sample entries before taking the log.

    Returns:
        float64 array of shape (num_states, num_states): the weighted mean of
        the normalized samples.

    Note:
        Peak memory grows with num_samples * num_states**2, since all samples
        are held in memory at once.
    """
    num_states = transition_mat.shape[-1]
    # Work on the caller's matrix in the requested precision (previously the
    # notebook globals T_data / d_type were read instead of the arguments).
    counts = np.asarray(transition_mat, dtype=dtype)

    samples = np.random.randint(
        limit, size=(num_samples, num_states, num_states), dtype=np.int16
    ).astype(dtype)
    samples = normalize_transition_matrix(samples)

    # Per-sample log weight: sum over all matrix entries of counts * log(sample).
    log_weights = (counts * np.log(np.maximum(samples, eps))).sum(axis=(-1, -2))
    # Shift so the maximum is 0 -> exp() yields values in (0, 1] without total underflow.
    log_weights -= log_weights.max()
    weights = np.exp(log_weights)

    weighted_sum = (samples * weights[:, None, None]).sum(axis=0)
    eta = weights.sum()
    return (1 / eta * weighted_sum).astype(np.float64)

def estim_mle(transition_mat):
    """Return the column-normalized version of `transition_mat`.

    Delegates entirely to `normalize_transition_matrix`, so it accepts a
    single 2D matrix or a 3D batch of matrices.
    """
    normalized = normalize_transition_matrix(transition_mat)
    return normalized

In [2]:
import resource
def using(point=""):
    """Format the current process's resource usage as a labelled string.

    Args:
        point: label placed at the start of the message.

    Returns:
        A string with the user CPU time, system CPU time and the third
        getrusage field (max resident set size) divided by 1024.
        NOTE(review): the "mb" label assumes ru_maxrss is reported in
        kilobytes, which is platform dependent - confirm for the target OS.
    """
    rusage = resource.getrusage(resource.RUSAGE_SELF)
    user_time, sys_time, max_rss = rusage[0], rusage[1], rusage[2]
    template = '''%s: usertime=%s systime=%s mem=%s mb
           '''
    return template % (point, user_time, sys_time, max_rss / 1024.0)

print(using("Start"))
# Floating-point type used for the count matrix and all sampled matrices.
d_type = np.float32

print(np.finfo(d_type))

# Lower clamp applied before taking logs. float16 cannot represent 1e-25
# (it would round to zero), so a larger floor is used for that type.
if d_type == np.float16:
    eps = 7e-5
else:
    eps = 1e-25

print('eps:', eps)
save = False  # set to True to overwrite prev_result.npy with this run's result
mode = 'new'  # 'prev' runs the inline step-by-step version, anything else calls estim_map


np.random.seed(321)
limit = 100
num_samples = 10000 #for this number of states, expect around 5GB peak RAM usage per 1k samples
num_states = 641
T_data = np.random.randint(limit, size=(num_states, num_states), dtype=np.int16).astype(d_type)

start = time.time()
if mode == 'prev':
    S = np.random.randint(limit, size=(num_samples, num_states, num_states), dtype=np.int16).astype(d_type)
    print('sample shape:', S.shape)
    S = normalize_transition_matrix(S)
    print(using("Normalized Samples"))

    tmp = T_data * np.log(np.maximum(S, eps))
    print(using("Weighted logs"))
    tmp = tmp.sum(axis=(-1,-2))
    print(using("Summed logs"))
    # Now we have a vector of negative numbers with very large absolute values.
    # If we called numpy exp directly, even 128-bit float precision would be
    # insufficient to represent the result: it would be rounded to zero and we
    # would lose all information. This problem is more severe the larger the
    # transition matrices are.
    # Solution: shift the results by the maximum entry.
    # This scales the output of exp such that 1 is its maximum output. The other values are not lost!
    # Note that this shift is also applied to the normalization term, so it cancels out automatically!
    b = tmp.max()
    tmp = tmp - b
    print(using("Shifted logs"))
    tmp = np.exp(tmp)
    print(using("Exp eval"))
    print('max output of exp:', tmp.max())
    S_res =  S * tmp[:, None, None]
    print(using("Weighted samples"))
    S_res = S_res.sum(axis=0)
    print(using("Samples summed"))
    eta = tmp.sum()
    print('eta:', eta)
    # Normalize the result: here we divide an array of tiny values by a very
    # small normalization factor eta, so the values become large enough again
    # to be represented appropriately by float64.
    S_res = (1/eta * S_res).astype(np.float64)
else:
    # Forward the configuration chosen above so both modes use the same
    # sampling limit, precision and log floor.
    S_res = estim_map(T_data, num_samples=num_samples, limit=limit, dtype=d_type, eps=eps)
print(using("Result calculated"))
if save:
    # Report success only after the file has actually been written.
    np.save('prev_result.npy', S_res)
    print('results saved')
try:
    S_prev = np.load('prev_result.npy')
    print('Same result as last time?', np.allclose(S_res, S_prev))
except (OSError, ValueError, EOFError) as err:
    # Best-effort regression check: a missing/unreadable reference file or a
    # shape mismatch is reported but must not abort the run.
    print('Could not compare with previous result:', err)

print('Result dtype:', S_res.dtype)
print(S_res.shape)
print('Block execution took', time.time()-start, 'seconds.')
print(using("After"))

Start: usertime=0.290056 systime=0.040853 mem=73.39453125 mb
           
Machine parameters for float32
---------------------------------------------------------------
precision =   6   resolution = 1.0000000e-06
machep =    -23   eps =        1.1920929e-07
negep =     -24   epsneg =     5.9604645e-08
minexp =   -126   tiny =       1.1754944e-38
maxexp =    128   max =        3.4028235e+38
nexp =        8   min =        -max
---------------------------------------------------------------

eps: 1e-25
Result calculated: usertime=28.358989 systime=4.197034 mem=47096.4296875 mb
           
Same result as last time? True
Result dtype: float64
(641, 641)
Block execution took 32.222660541534424 seconds.
After: usertime=28.363185 systime=4.197063 mem=47096.4296875 mb
           
