In [7]:
import numpy as np
from scipy.optimize import linprog

def get_best_stat_dist(P, c):
    """
    Compute the minimum-cost stationary distribution of a transition matrix.

    Solves the linear program

        minimize    c @ pi
        subject to  (P.T - I) @ pi = 0,   sum(pi) = 1,   pi >= 0,

    i.e. among all distributions pi with pi @ P == pi it returns one with the
    smallest expected cost. If P has a unique stationary distribution the cost
    merely evaluates it; if P has several, the cost selects between them.

    Parameters:
    P (numpy.ndarray): square transition matrix of shape (n, n); row i is
        treated as the distribution of the next state given current state i.
    c (array_like): per-state costs with exactly n entries in total. Any
        shape is accepted; it is flattened in row-major (C) order, so e.g. a
        (10, 2) array supplies the costs of 20 states.

    Returns:
    tuple: (stat_dist, exp_cost)
        stat_dist (numpy.ndarray): vector of shape (n,), the stationary
            distribution of P with the lowest expected cost under c.
        exp_cost (float): the expected cost of stat_dist under c.

    Raises:
    ValueError: if P is not square, if c does not have n entries, or if the
        solver fails even after rescaling the constraints.
    """
    P = np.asarray(P, dtype=float)
    if P.ndim != 2 or P.shape[0] != P.shape[1]:
        raise ValueError(f"P must be a square matrix, got shape {P.shape}.")
    n = P.shape[0]

    # NOTE(review): ravel() flattens row-major, while MATLAB's c(:) is
    # column-major. The port keeps row-major order; confirm it matches the
    # state ordering the cost data was built for.
    c = np.asarray(c, dtype=float).ravel()
    if c.size != n:
        raise ValueError(
            f"c must have {n} entries (one per state), got {c.size}."
        )

    # Stationarity (P.T - I) pi = 0, plus the normalisation row sum(pi) = 1.
    A_eq = np.vstack([P.T - np.eye(n), np.ones((1, n))])
    b_eq = np.zeros(n + 1)
    b_eq[-1] = 1.0

    # Only 'disp' is passed: it is accepted by every linprog method. A 'tol'
    # key is not recognised by the default HiGHS solver; it only produced an
    # OptimizeWarning and had no effect on the solve.
    options = {'disp': False}
    bounds = (0, None)  # a single (min, max) pair applies pi >= 0 to every state

    res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=bounds, options=options)

    # In case the solver fails due to numerical trouble, retry with the
    # equality constraints scaled down by successive factors of 10.
    alpha = 1.0
    while not res.success and alpha >= 1e-10:
        alpha /= 10
        res = linprog(c, A_eq=alpha * A_eq, b_eq=alpha * b_eq,
                      bounds=bounds, options=options)

    if not res.success:
        raise ValueError(
            f"Failed to compute stationary distribution. Solver status: {res.message}"
        )

    return res.x, float(res.fun)

# Example 20-state transition matrix (20 rows x 20 columns). get_best_stat_dist
# treats row i as the distribution of the next state given current state i; the
# rows appear to sum to ~1 (entries are rounded to four decimals).
# NOTE(review): this literal is not used by the call at the end of this cell,
# which loads P.npy instead -- confirm whether the two are meant to agree.
P = np.array([[0.1555,0,0.1148,0,0,0.0471,0.0009,0.0850,0.0133,0,0,0.1291,0,0.0832,0.1322,0,0,0.0605,0,0.1783],
[0.1555,0,0.1148,0,0,0.0471,0.0843,0.0017,0.0133,0,0,0.1291,0,0.0832,0.1322,0,0,0.0605,0,0.1783],
[0.0158,0,0.1292,0,0,0.1100,0.1365,0,0.0251,0,0.0121,0.0891,0,0.1513,0.0978,0,0,0.1216,0,0.1113],
[0.0158,0,0.1292,0,0,0.1100,0.1365,0,0.0251,0,0.0955,0.0058,0,0.1513,0.0978,0,0,0.1216,0,0.1113],
[0.0870,0,0.0606,0,0,0.1211,0,0.0607,0.0872,0,0,0.1889,0,0.1306,0.1818,0.0480,0,0.0221,0,0.0120],
[0.0870,0,0.0606,0,0,0.1211,0.0354,0.0253,0.0872,0,0,0.1889,0,0.1306,0.2298,0,0,0.0221,0,0.0120],
[0.2338,0,0.0814,0,0,0.0365,0,0.0348,0.0301,0,0,0.0072,0,0.1192,0.0714,0.0361,0,0.1658,0,0.1836],
[0.2338,0,0.0814,0,0,0.0365,0.0348,0,0.0301,0,0.0072,0,0,0.1192,0.1075,0,0.0051,0.1607,0,0.1836],
[0.0564,0,0.1208,0,0,0.0532,0.0758,0,0.1104,0,0.0313,0.2513,0,0.2348,0.0219,0,0,0.0023,0,0.0416],
[0.0564,0,0.1208,0,0,0.0532,0.0758,0,0.1104,0,0.1147,0.1680,0,0.2348,0.0219,0,0,0.0023,0,0.0416],
[0.0748,0,0.0710,0,0,0.1351,0.0775,0.0253,0.1163,0,0,0.1295,0,0.1171,0.0770,0,0,0.1077,0,0.0687],
[0.0748,0,0.0710,0,0,0.1351,0.1028,0,0.1163,0,0.0580,0.0715,0,0.1171,0.0770,0,0,0.1077,0,0.0687],
[0.0204,0,0.0117,0,0,0.0337,0.1454,0.0950,0.1937,0,0,0.1487,0,0.1252,0.0454,0,0,0.0297,0,0.1509],
[0.0204,0,0.0117,0,0,0.0337,0.2288,0.0116,0.1937,0,0,0.1487,0,0.1252,0.0454,0,0,0.0297,0,0.1509],
[0.0184,0,0.1518,0,0,0.1654,0.0695,0.0074,0.0874,0,0,0.0596,0,0.1648,0.0895,0,0,0.1012,0,0.0849],
[0.0184,0,0.1518,0,0,0.1654,0.0769,0,0.0874,0,0.0596,0,0,0.1648,0.0895,0,0.0163,0.0848,0,0.0849],
[0.2121,0,0.0662,0,0,0.0236,0.0872,0.0640,0.0467,0,0,0.1745,0,0.1131,0.0043,0,0,0.1928,0,0.0153],
[0.2121,0,0.0662,0,0,0.0236,0.1512,0,0.0467,0,0.0193,0.1552,0,0.1131,0.0043,0,0,0.1928,0,0.0153],
[0.1067,0,0.0376,0,0.0298,0.2446,0.0644,0,0.0169,0,0.1207,0,0,0.2663,0.0342,0,0.0064,0,0,0.0724],
[0.1067,0,0.0376,0,0.1132,0.1613,0.0644,0,0.0169,0,0.1207,0,0,0.2663,0.0342,0,0.0064,0,0,0.0724]])

# Per-state costs for the 20 states of P, stored as a 10 x 2 array (20 entries).
# get_best_stat_dist flattens it with ravel() (row-major), so c[k, 0] and
# c[k, 1] become the costs of states 2k and 2k+1.
# NOTE(review): MATLAB's c(:) would flatten column-major instead -- confirm
# which state ordering these costs were written for.
c = np.array([[1.9369,3.7898],
[1.8408,2.3808],
[1.0780,0.0109],
[0.5031,0.1017],
[0.5026,0.4278],
[1.0777,0.2013],
[2.2210,0.3809],
[0.4153,0.2995],
[0.9999,0.0069],
[3.6520,1.6171]])
# Solve the problem stored on disk (P.npy and c.npy, resolved relative to the
# current working directory). Pass P and c instead to solve the example
# literals defined in this cell. As the cell's last expression, the returned
# (stat_dist, exp_cost) tuple is what the notebook displays.
get_best_stat_dist(np.load('P.npy'),np.load('c.npy'))

  res = linprog(c, A_eq=Aeq, b_eq=beq, bounds=[(0, None) for _ in range(n)], options=options)


(array([ 0.09597561,  0.        ,  0.07839435,  0.        ,  0.01096115,
         0.08750061,  0.09110821,  0.017162  ,  0.0804365 , -0.        ,
         0.02330818,  0.10847691, -0.        ,  0.14690716,  0.08058266,
         0.00381696,  0.00077181,  0.07773943,  0.        ,  0.09685847]),
 0.7264478450557728)