In [1]:
import numpy as np
from scipy.linalg import lu
from scipy.sparse.linalg import svds
from numpy.linalg import svd

In [2]:
def check_spanrd(vectors, d):
    """
    Tell whether the rows of `vectors` span a space of dimension exactly `d`.

    Inputs:
        - vectors (array): matrix (N, d)
        - d (int): expected dimension of the spanned space
    Return:
        - True if rank(vectors) == d, False otherwise
    """
    # https://math.stackexchange.com/questions/56201/how-to-tell-if-a-set-of-vectors-spans-a-space
    # https://stackoverflow.com/questions/15638650/is-there-a-standard-solution-for-gauss-elimination-in-python
    # Gaussian elimination via LU. With permute_l=True the first factor is
    # P @ L, where L has a unit diagonal (full column rank), so the rank of
    # the upper factor equals the rank of `vectors`.
    _, upper = lu(vectors, permute_l=True)
    return int(np.linalg.matrix_rank(upper)) == d

def span(vectors):
    """
    Return the dimension of the space spanned by the rows of `vectors`.

    Inputs:
        - vectors (array): matrix (N, d)
    Return:
        - int: rank of `vectors` (0 for an all-zero matrix)
    """
    # The rank is computed once rather than by testing every candidate
    # dimension d, d-1, ..., 1, where each test would redo a full LU + SVD.
    # Rank of the U factor of an LU decomposition equals the rank of
    # `vectors`, because P @ L has a unit diagonal (full column rank).
    _, upper = lu(vectors, permute_l=True)
    return int(np.linalg.matrix_rank(upper))

In [3]:
# Load one representation per feature dimension d from its .npz archive.
# NOTE(review): absolute, user-specific path — consider making it configurable.
DATA_DIR = "/home/andrea/git/lrcb/problem_data/jester/257"
file_names = ["jester_d257_span186_L5.61_S2.08_hls0.00000.npz",
              "jester_d201_span131_L3.19_S2.92_hls0.00000.npz",
              "jester_d221_span129_L3.50_S3.11_hls0.00000.npz",
              "jester_d241_span137_L3.31_S2.85_hls0.00000.npz"]
files = [DATA_DIR + "/" + name for name in file_names]

dims = [257, 201, 221, 241]
assert len(files) == len(dims), "files and dims must have the same length"

features = {}
thetas = {}

for file, d in zip(files, dims):
    # guard against the two lists drifting out of sync when edited
    assert "_d{}_".format(d) in file, "file {} does not match d={}".format(file, d)
    # NpzFile holds an open zip handle; the context manager closes it.
    # Indexing an NpzFile reads the array fully into memory, so the arrays
    # stay valid after the file is closed.
    with np.load(file) as f:
        features[d] = f['features']
        thetas[d] = f['theta']
    print("Loaded d={}".format(d))
print()

Loaded d=257
Loaded d=201
Loaded d=221
Loaded d=241



In [4]:
# remove useless features
#
# For each representation, flatten the feature tensor to a (rows, d) matrix
# and keep only the directions whose singular value exceeds `tol`. Features
# and parameters are then re-expressed in that reduced basis, so the rewards
# features . theta are numerically unchanged. The results are keyed by the
# reduced dimension `sp`.

tol = 1e-8  # threshold to consider a singular value equal to zero
# NOTE(review): `tol` is an absolute threshold; np.linalg.matrix_rank uses a
# relative one (s.max() * max(shape) * eps), which may give a different span.

new_features = {}
new_thetas = {}

for d in dims:

    print("Starting d={}".format(d))
    assert features[d].shape[-1] == d, "last feature axis does not match d={}".format(d)
    n_ctx, n_arms = features[d].shape[:2]
    fmat = features[d].reshape(-1, d)

    U, s, Vt = svd(fmat, full_matrices=False)
    sp = int(np.sum(s > tol))
    print("[d={0}] span: {1}".format(d, sp))
    # results are keyed by the reduced dimension: a repeated span would
    # silently overwrite a previous representation
    assert sp not in new_features, "two representations reduce to dimension {}".format(sp)

    # keep the top-`sp` singular triplets and fold the singular values into U
    Us = U[:, :sp] * s[:sp]  # same as U[:, :sp] @ diag(s[:sp]), without the dense diag
    Vt = Vt[:sp, :]
    rmse = np.sqrt(np.mean((Us.dot(Vt) - fmat) ** 2))
    print("[d={0}] Reconstruction rmse: {1}".format(d, rmse))

    # create new features/parameters: mu = F theta = (U S)(Vt theta)
    new_features[sp] = Us.reshape(n_ctx, n_arms, sp)
    new_thetas[sp] = Vt.dot(thetas[d])

    # check errors
    old_mu = features[d].dot(thetas[d])
    new_mu = new_features[sp].dot(new_thetas[sp])
    err = np.abs(old_mu - new_mu)
    print("[d={0}] mu error: max {1} - mean {2}".format(d, np.max(err), np.mean(err)))

    # free the large intermediates (kernel memory persists across cells)
    del old_mu, new_mu, err, U, Us, s, Vt, fmat

    print()

Starting d=257
[d=257] span: 188
[d=257] Reconstruction rmse: 5.700174909861744e-08
[d=257] mu error: max 5.960464477539062e-07 - mean 5.222723231668169e-08

Starting d=201
[d=201] span: 132
[d=201] Reconstruction rmse: 3.551129523771124e-08
[d=201] mu error: max 5.364418029785156e-07 - mean 5.1338169271275547e-08

Starting d=221
[d=221] span: 130
[d=221] Reconstruction rmse: 3.340858967249005e-08
[d=221] mu error: max 4.76837158203125e-07 - mean 4.98829884065799e-08

Starting d=241
[d=241] span: 140
[d=241] Reconstruction rmse: 3.279894045249421e-08
[d=241] mu error: max 4.76837158203125e-07 - mean 4.9307779192986345e-08



In [5]:
# filter gaps
#
# Use the largest reduced representation as ground truth. Keep only contexts
# whose optimal arm beats the runner-up by more than `thresh`.

thresh = 0.1

# ground truth
d_gt = max(new_features.keys())
mu_gt = new_features[d_gt].dot(new_thetas[d_gt])
assert mu_gt.shape[1] >= 2, "need at least two arms to define a gap"

# per-context minimum gap = best reward - second-best reward.
# Exact ties for the best arm give a gap of 0, so contexts without a unique
# optimal arm are filtered out.
top2 = np.sort(mu_gt, axis=1)[:, -2:]
gap_gt = top2[:, 1] - top2[:, 0]
print("gap min:", gap_gt.min())

# indexes of contexts with minimum gap above threshold
good_contexts = gap_gt > thresh
print("# contexts with gap_min > {0}: {1}".format(thresh, np.sum(good_contexts)))

# filter (features and ground-truth rewards stay aligned)
for d in new_features.keys():
    new_features[d] = new_features[d][good_contexts, :, :]

n_contexts = int(np.count_nonzero(good_contexts))
mu_gt = mu_gt[good_contexts, :]

gap min: 2.9802322e-07
# contexts with gap_min > 0.1: 3297


In [14]:
# check misspecification
#
# Keep only the contexts where, for every representation, the reward error
# with respect to the ground truth mu_gt is below `eps` on every arm.

eps = 0.05  # threshold for low misspecification

# `bool` instead of the `np.bool` alias, which was removed in NumPy 1.24
low_eps_contexts = np.ones(mu_gt.shape[0], dtype=bool)

for d in new_features.keys():
    mu = new_features[d].dot(new_thetas[d])
    err = np.abs(mu - mu_gt)
    print("[d={0}] error wrt ground truth: max {1} - mean {2}".format(d, err.max(), np.mean(err)))
    idx = np.max(err, axis=1) < eps  # contexts with low misspecification
    print("# contexts with eps < {0}: {1}".format(eps, np.sum(idx)))
    low_eps_contexts &= idx  # make sure all representations have low misspecification

    del mu, err

print("# contexts with eps < {0} in all representations: {1}".format(eps, np.sum(low_eps_contexts)))

# filter
for d in new_features.keys():
    new_features[d] = new_features[d][low_eps_contexts, :, :]

# keep the ground truth aligned with the filtered features
mu_gt = mu_gt[low_eps_contexts, :]
n_contexts = int(np.count_nonzero(low_eps_contexts))

[d=132] error wrt ground truth: max 0.2978231608867645 - mean 0.018024424090981483
# contexts with eps < 0.05: 1142
[d=130] error wrt ground truth: max 0.34656280279159546 - mean 0.01759651117026806
# contexts with eps < 0.05: 1239
[d=140] error wrt ground truth: max 0.3155077397823334 - mean 0.0179055854678154
# contexts with eps < 0.05: 1368
[d=188] error wrt ground truth: max 0.0 - mean 0.0
# contexts with eps < 0.05: 3297
# contexts with eps < 0.05 in all representations: 679


In [18]:
# check span optimal arms
#
# For each representation, take the feature vector of the optimal arm in each
# context and report:
#   - its rank ("span"), and
#   - the smallest eigenvalue of the empirical second-moment matrix
#     fstar^T fstar / n ("lambda HLS").

span_opt = {}

for d in new_features.keys():

    feats = new_features[d]
    n = feats.shape[0]
    mu = feats.dot(new_thetas[d])
    astar = np.argmax(mu, axis=1)
    # row x of fstar is the feature vector of arm astar[x] in context x
    fstar = feats[np.arange(n), astar]

    # Rank computed once from the U factor of an LU decomposition
    # (rank(U) == rank(fstar)). The local name `span_d` avoids shadowing the
    # `span` helper function.
    _, u = lu(fstar, permute_l=True)
    span_d = int(np.linalg.matrix_rank(u))
    span_opt[d] = span_d

    outer = fstar.T.dot(fstar) / n
    # outer is symmetric: eigvalsh returns real eigenvalues in ascending order
    lambda_hls = np.linalg.eigvalsh(outer)[0]

    print("[d={0}] span optimal arms: {1} - lambda HLS: {2}".format(d, span_d, lambda_hls))

    del mu, astar, fstar, outer

[d=132] span optimal arms: 126 - lambda HLS: -5.865825880579933e-11
[d=130] span optimal arms: 120 - lambda HLS: -6.367945060148372e-10
[d=140] span optimal arms: 133 - lambda HLS: -9.735783013109511e-11
[d=188] span optimal arms: 177 - lambda HLS: -7.381848604604002e-11


In [17]:
# save
#
# Write one compressed archive per representation. The file name encodes the
# reduced dimension d and the span of its optimal-arm features.

for d, feats in new_features.items():
    out_path = 'jester_post_d{0}_span{1}.npz'.format(d, span_opt[d])
    np.savez_compressed(out_path, features=feats, theta=new_thetas[d])