In [1]:
import numpy as np
from numpy.linalg import svd
from numpy import sqrt
from scipy.optimize import linear_sum_assignment

from arsenal.maths import random_dist
from mdp import DiscountedMDP, MRP, random_MDP, random_MDP_forward_reward
from utils import *

In [79]:
def random_planted_MDP(S, A, count, epsilon, gamma=0.95, b = None):
    """Build a random MDP with a planted policy concentrated on `count` "alive" states.

    For every state one action is chosen uniformly at random as the planted action
    (recorded in the one-hot matrix ``phi``).  Under the planted action a state moves
    to (at most ``b``) alive states with total probability ``1 - epsilon`` and to (at
    most ``b``) "dead" states with total probability ``epsilon`` (the row is then
    renormalised, so when there are no dead states all mass goes to alive states).
    Every other action moves to (at most ``b``) states drawn from all ``S`` states.
    The initial distribution is supported on the alive states only; all rewards are 0.

    Parameters
    ----------
    S, A : int
        Number of states and actions.
    count : int
        Number of alive states; they are state indices ``0 .. count-1``.
    epsilon : float
        Probability mass the planted action leaks to the dead states.
    gamma : float
        Discount factor passed to ``DiscountedMDP``.
    b : int or None
        Maximum number of successor states per (state, action); ``None`` means ``S``.

    Returns
    -------
    mdp : DiscountedMDP
        The MDP with ``P`` of shape (S, A, S) and all-zero ``R`` of shape (S, A, S).
    phi : ndarray, shape (S, A)
        One-hot planted policy.
    """
    if b is None:
        b = S

    # One-hot planted policy: one uniformly random action per state.
    phi = np.zeros((S, A))
    phi[np.arange(S), np.random.randint(A, size=S)] = 1

    states = np.arange(S)
    alive, dead = states[:count], states[count:]

    # Initial distribution is supported on the alive states only.
    s0 = np.zeros(S)
    s0[alive] = random_dist(alive.size)

    # Cap each successor-set size by the pool it is drawn from; sampling without
    # replacement cannot draw more items than the pool holds (e.g. default b = S
    # with fewer than S dead states).
    n_alive = min(alive.size, b)
    n_dead = min(dead.size, b)
    n_any = min(S, b)

    P = np.zeros((S, A, S))
    for s in range(S):
        for a in range(A):
            if phi[s, a] == 1:
                succ = np.random.choice(alive, size=n_alive, replace=False)
                P[s, a, succ] = (1 - epsilon) * random_dist(n_alive)
                if n_dead > 0:
                    succ = np.random.choice(dead, size=n_dead, replace=False)
                    P[s, a, succ] = epsilon * random_dist(n_dead)
                P[s, a, :] /= P[s, a, :].sum()
            else:
                succ = np.random.choice(S, size=n_any, replace=False)
                P[s, a, succ] = random_dist(n_any)

    R = np.zeros((S, A, S))

    mdp = DiscountedMDP(
        s0 = s0,
        R = R,
        P = P,
        gamma = gamma,
    )

    return mdp, phi

In [80]:
def generate_isomorphic_mdp(mdp):
    """Return a copy of `mdp` whose states are relabelled by a random permutation.

    Old state ``s`` becomes new state ``perm[s]``, encoded by the permutation matrix
    ``Pi`` with ``Pi[s, perm[s]] = 1``.  Transition probabilities and the initial
    distribution are relabelled consistently.  Rewards are NOT carried over: the new
    MDP always gets an all-zero reward tensor of shape (S, A, S).

    Returns
    -------
    mdp_hat : DiscountedMDP
        The relabelled MDP, with the same ``gamma``.
    Pi : ndarray, shape (S, S)
        Permutation matrix (rows: old state index, columns: new state index).
    """
    n_states, n_actions, _ = mdp.P.shape
    relabel = np.random.permutation(n_states)
    Pi = np.eye(n_states)[relabel]

    # New state j is old state inverse[j]; permute both state axes of P with it.
    inverse = np.argsort(relabel)
    every_action = np.arange(n_actions)
    P_hat = np.asarray(mdp.P, dtype=float)[np.ix_(inverse, every_action, inverse)]

    relabelled = DiscountedMDP(
        s0 = Pi.T @ mdp.s0,
        R = np.zeros((n_states, n_actions, n_states)),
        P = P_hat,
        gamma = mdp.gamma,
    )
    return relabelled, Pi
    

In [81]:
def factor_approximation(M, u, t):
    """Spectral signature of the states whose weight in `u` is at least `t`.

    Keeps the states ``I = {s : u[s] >= t}``, restricts `M` and ``diag(u)`` to them
    with ``subindex`` (from utils; assumed to select rows and columns ``I`` — TODO
    confirm), forms ``sqrt(D) @ M @ invsqrt(D)`` and returns that matrix's right
    singular vectors as columns.  Each column is multiplied by +1 if its entries sum
    to a strictly positive number and by -1 otherwise, fixing the SVD sign ambiguity.

    Parameters
    ----------
    M : ndarray
        Square matrix indexed by state (presumably the (S, S) chain of the policy).
    u : ndarray, shape (S,)
        Per-state weights (occupancy in the callers).
    t : float
        Threshold; states with ``u >= t`` are kept.

    Returns
    -------
    V : ndarray, shape (k, k)
        Sign-normalised right singular vectors, one per column.
    I : ndarray of int, shape (k,)
        Indices of the kept states, in increasing order.

    Raises
    ------
    numpy.linalg.LinAlgError
        Propagated from ``svd`` (e.g. when a zero weight makes ``invsqrt`` blow up).
    """
    kept = np.argwhere(u >= t).ravel()
    M_kept = subindex(M, kept)
    D_kept = subindex(np.diag(u), kept)

    similarity = sqrt(D_kept) @ M_kept @ invsqrt(D_kept)
    V = svd(similarity)[2].T

    signs = np.where(V.sum(axis=0, keepdims=True) > 0, 1, -1)
    return V * signs, kept

In [82]:
def policy_learning(mdp, phi, t, samples):
    """Recover `phi` on a relabelled copy of `mdp` from observed state transitions.

    The chain of `mdp` under `phi` (``(mdp | phi).P_with_reset()`` with weights
    ``(mdp | phi).d()``) and an empirical chain estimated from `samples` are each
    reduced to a spectral signature by `factor_approximation`, keeping only states
    whose weight is >= `t`.  Kept states are matched by a maximum-weight assignment
    on ``V @ V_hat.T``; the remaining states are paired by a random permutation.
    The resulting permutation matrix relabels `phi`.

    Parameters
    ----------
    mdp : DiscountedMDP
        The original (not relabelled) MDP.
    phi : ndarray, shape (S, A)
        Policy on `mdp` (S and A are read from ``phi.shape``).
    t : float
        Weight threshold passed to `factor_approximation`.
    samples : sequence of (s, s') pairs
        Observed transitions of the relabelled chain; presumably integer state
        indices in ``range(S)`` — TODO confirm against draw_samples.

    Returns
    -------
    (phi_hat, Pi)
        The relabelled policy and the (S, S) permutation matrix with
        ``Pi[s, s_hat] = 1``; or ``(None, None)`` when there are no samples, the
        empirical factorisation fails or is non-finite, or the two factorisations
        keep different numbers of states (they cannot then be matched).
    """
    S, A = phi.shape
    if len(samples) == 0:
        # Without data the empirical visit frequencies would be 0/0.
        return None, None

    Phi = to_policy_matrix(phi)
    M = (mdp | phi).P_with_reset()
    u = (mdp | phi).d()
    u_hat, M_hat = _empirical_chain(samples, S)

    V, I = factor_approximation(M, u, t)
    try:
        V_hat, I_hat = factor_approximation(M_hat, u_hat, t)
    except np.linalg.LinAlgError:
        return None, None
    # Differently sized kept sets make V @ V_hat.T raise, and NaN/inf signatures make
    # linear_sum_assignment raise; report both like the LinAlgError case.
    if I.size != I_hat.size or not np.all(np.isfinite(V_hat)):
        return None, None

    rows, cols = linear_sum_assignment(-(V @ V_hat.T))
    Pi = _assemble_permutation(S, I, I_hat, rows, cols)

    Phi_hat = np.kron(np.eye(A), Pi.T) @ Phi @ Pi
    return from_policy_matrix(Phi_hat), Pi


def _empirical_chain(samples, S):
    """Visit frequencies and row-normalised transition counts from (s, s') pairs."""
    N = np.zeros((S, S))
    for s, sp in samples:
        N[s, sp] += 1
    u_hat = N.sum(axis=1) / len(samples)
    row_totals = N.sum(axis=1, keepdims=True)
    # Unvisited states get an all-zero row instead of a 0/0 NaN row.
    row_totals[row_totals == 0] = np.inf
    return u_hat, N / row_totals


def _assemble_permutation(S, I, I_hat, rows, cols):
    """(S, S) permutation matrix: assigned kept states, plus a random pairing of the rest."""
    Pi = np.zeros((S, S))
    Pi[I[rows], I_hat[cols]] = 1
    rest = np.setdiff1d(np.arange(S), I)
    rest_hat = np.setdiff1d(np.arange(S), I_hat)
    Pi[rest, rest_hat[np.random.permutation(rest.size)]] = 1
    return Pi
    

In [83]:
# ===================== Experiment harness: run_trial / run_trials =====================

In [100]:
def _relabel_policy_matrix(Phi, Pi, A):
    """Conjugate a policy matrix by the state permutation matrix `Pi`, per action."""
    return np.kron(np.eye(A), Pi.T) @ Phi @ Pi


def run_trial(S, A, count, epsilon, b, sample_total = 10000, gamma = 0.95, planted = True):
    """Run one recover-the-policy experiment and score it.

    Builds an MDP and a policy: a planted one from `random_planted_MDP`, or, when
    `planted` is False, the policy-iteration solution of a `random_MDP`.  It then
    relabels the MDP with a random state permutation, draws `sample_total` samples
    from the relabelled MDP under the relabelled policy, and asks `policy_learning`
    to recover the policy.  If `policy_learning` gives up, a uniformly random
    relabelling is used instead.

    Returns
    -------
    (distance, gap)
        ``distance`` is ``tv`` (from utils; presumably total-variation distance)
        between the relabelled true occupancy and the learned one.  ``gap`` is
        ``|u[count-1] - u[count]|`` for the descending-sorted state occupancy ``u``
        in the planted case, and ``-1`` otherwise.
    """
    if planted:
        mdp, phi = random_planted_MDP(S, A, count, epsilon, gamma, b)
    else:
        mdp = random_MDP(S, A, gamma, b)
        phi = mdp.solve_by_policy_iteration()['policy']

    mdp_hat, Pi_star = generate_isomorphic_mdp(mdp)

    Phi = to_policy_matrix(phi)
    phi_star = from_policy_matrix(_relabel_policy_matrix(Phi, Pi_star, A))
    samples = draw_samples(mdp_hat, phi_star, sample_total)

    if not planted:
        t, gap = 0.0, -1
    else:
        # Threshold halfway between the count-th and (count+1)-th largest occupancy.
        occupancy = np.sort(state_occupancy(mdp, phi))[::-1]
        above, below = occupancy[count - 1], occupancy[count]
        t = (above + below) / 2
        gap = np.abs(above - below)

    phi_hat, _ = policy_learning(mdp, phi, t, samples)

    # policy_learning returns (None, None) on failure (e.g. too few samples):
    # fall back to a uniformly random relabelling of the true policy.
    if phi_hat is None:
        Pi_random = np.eye(S)[np.random.permutation(S)]
        phi_hat = from_policy_matrix(_relabel_policy_matrix(Phi, Pi_random, A))

    d = full_occupancy(mdp, phi)
    d_hat = full_occupancy(mdp_hat, phi_hat)
    target = np.kron(np.eye(A), Pi_star.T) @ d.T.flatten()
    return tv(target, d_hat.T.flatten()), gap

In [101]:
def run_trials(trials, S, A, count, epsilon, b, sample_total = 10000, gamma = 0.95, planted = True,
               return_gaps = False):
    """Repeat `run_trial` `trials` times and summarise the per-trial distances.

    All arguments except `trials` and `return_gaps` are forwarded unchanged to
    `run_trial`.

    Returns
    -------
    (mean, std) of the per-trial distances.  With ``return_gaps=True`` the result is
    ``(mean, std, gaps)``, where ``gaps`` is the array of per-trial gaps reported by
    `run_trial` (previously computed and silently discarded).
    """
    results = [run_trial(S, A, count, epsilon, b, sample_total, gamma, planted)
               for _ in range(trials)]
    tvs = np.array([diff for diff, _ in results])
    summary = (np.mean(tvs), np.std(tvs))
    if return_gaps:
        return summary + (np.array([gap for _, gap in results]),)
    return summary

In [102]:
# ===================== Experiments: planted vs. unplanted MDPs =====================

In [103]:
# Experiment configuration shared by the planted and unplanted runs below.
trials = 20    # independent trials averaged per setting
A = 5          # number of actions
b = 5          # passed as `b` to the MDP generators (successor states per state-action in random_planted_MDP)
S = 100        # number of states
gamma = 0.95   # discount factor

# Seed numpy's global RNG so Restart & Run All reproduces the printed numbers.
# NOTE(review): assumes every helper (incl. arsenal's random_dist) draws from np.random — confirm.
SEED = 0
np.random.seed(SEED)

In [104]:
# Planted runs: the planted policy lives on `count` alive states and leaks
# `epsilon` probability mass to the dead states.
count = 5
epsilon = 0.0001
planted = True

# Integer sample budgets, matching run_trial's integer default sample_total.
for sample_total in [1000, 10000, 100000]:
    print(run_trials(trials, S, A, count, epsilon, b, sample_total, gamma, planted))

(0.14592243414676903, 0.29726027171164277)
(0.07422775799168238, 0.2219057750954345)
(0.00031294499037486133, 8.376660636339191e-05)


In [105]:
# Unplanted runs: random MDPs scored with their policy-iteration policy.
# count and epsilon are only used by run_trial when planted=True; they are placeholders here.
count = 0
epsilon = 0
planted = False

# Integer sample budgets, matching run_trial's integer default sample_total.
for sample_total in [1000, 10000, 100000]:
    print(run_trials(trials, S, A, count, epsilon, b, sample_total, gamma, planted))

  


(0.8800934186455953, 0.03355774375931906)
(0.781238323122904, 0.11284426048561891)
(0.22991348754150995, 0.17774927314329664)
