In [1]:
import numpy as np
from scipy.linalg import lu
from scipy.sparse.linalg import svds
from numpy.linalg import svd

In [2]:
def check_spanrd(vectors, d):
    """
    Tell whether the numerical rank of `vectors` is exactly `d`.

    The matrix is LU-factorised first (Gaussian elimination with pivoting)
    and the rank is then measured on the upper-triangular factor.

    Inputs:
        - vectors (array): matrix (N, d)
        - d (int): dimension of the space to be spanned
    Return:
        - True when the computed rank equals d, False otherwise
    """
    # https://math.stackexchange.com/questions/56201/how-to-tell-if-a-set-of-vectors-spans-a-space
    # https://stackoverflow.com/questions/15638650/is-there-a-standard-solution-for-gauss-elimination-in-python
    # with permute_l=True, lu returns (P @ L, U); only U is needed here
    upper = lu(vectors, permute_l=True)[1]
    computed_rank = int(np.linalg.matrix_rank(upper))
    return d == computed_rank

def span(vectors):
    """
    Return the dimension of the space spanned by the rows of `vectors`.

    The rank is obtained with the same procedure used by `check_spanrd`
    (LU factorisation followed by `matrix_rank` on the upper-triangular
    factor), but the factorisation is done once instead of once per
    candidate dimension.

    Inputs:
        - vectors (array): matrix (N, d)
    Return:
        - int: rank of the matrix, between 0 and d. An all-zero matrix gives 0
          (the previous implementation fell off the end and returned None).
    """
    # with permute_l=True, lu returns (P @ L, U); only U is needed here
    _, upper = lu(vectors, permute_l=True)
    return int(np.linalg.matrix_rank(upper))

In [3]:
# Directory holding the input archives; change it here to point at another copy.
DATA_DIR = "/home/andrea/git/lrcb/problem_data/jester/33/1/"

# One archive per representation; files[i] is loaded under the key dims[i].
files = [DATA_DIR + name for name in (
    "jester_d33_span33_L4.95_S4.09_hls0.00186.npz",
    "jester_d31_span23_L4.13_S1.85_hls0.00000.npz",
    "jester_d29_span26_L4.08_S1.96_hls0.00000.npz",
    "jester_d27_span24_L3.97_S2.11_hls0.00000.npz",
    "jester_d25_span20_L4.18_S2.00_hls0.00000.npz",
    "jester_d23_span16_L5.13_S2.10_hls0.00000.npz",
    "jester_d21_span17_L4.54_S1.71_hls0.00000.npz",
)]

#dims = [257, 11, 21, 41, 61, 81, 101, 121, 141, 161, 181, 201, 221, 241]
dims = [33, 31, 29, 27, 25, 23, 21]

# zip() would silently drop the unmatched tail, so fail loudly instead
if len(files) != len(dims):
    raise ValueError("files and dims must have the same length "
                     "({} != {})".format(len(files), len(dims)))

features = {}
thetas = {}

for file, d in zip(files, dims):
    # the context manager closes the archive as soon as both arrays are read
    with np.load(file) as archive:
        features[d] = archive['features']
        thetas[d] = archive['theta']
    print("Loaded d={}".format(d))
print()

Loaded d=33
Loaded d=31
Loaded d=29
Loaded d=27
Loaded d=25
Loaded d=23
Loaded d=21



In [4]:
# remove useless features
#
# Every representation is projected onto the subspace actually spanned by its
# feature vectors (truncated SVD). theta is mapped through the same projection,
# so the rewards features.dot(theta) are preserved up to numerical error.

tol = 1e-8  # threshold to consider a singular value equal to zero

new_features = {}
new_thetas = {}

for d in dims:

    print("Starting d={}".format(d))
    # stack all feature vectors as rows: (shape[0] * shape[1], d)
    fmat = features[d].reshape(-1, d)

    U, s, Vt = svd(fmat, full_matrices=False)
    sp = int(np.sum(s > tol))  # number of non-negligible singular values
    print("[d={0}] span: {1}".format(d,sp))

    # results are stored under the span; two representations with the same
    # span would silently overwrite each other, so refuse to continue
    if sp in new_features:
        raise ValueError("[d={0}] span {1} already produced by another "
                         "representation; results are keyed by span".format(d, sp))

    s = s[:sp]
    U = U[:, :sp]
    Vt = Vt[:sp, :]

    # scale column j of U by s[j]; same result as U.dot(np.diag(s)) without
    # building the dense diagonal matrix
    U = U * s
    M = U.dot(Vt)
    rmse = np.sqrt(np.mean(np.abs(M - fmat) ** 2))
    print("[d={0}] Reconstruction rmse: {1}".format(d, rmse))

    # create new features/parameters
    new_features[sp] = U.reshape(features[d].shape[0], features[d].shape[1], sp)
    new_thetas[sp] = Vt.dot(thetas[d])

    # normalize parameters (theta gets unit norm, features absorb the scale,
    # so their product is unchanged)
    norm = np.linalg.norm(new_thetas[sp])
    new_thetas[sp] /= norm
    new_features[sp] *= norm

    # check errors
    old_mu = features[d].dot(thetas[d])
    new_mu = new_features[sp].dot(new_thetas[sp])
    err = np.abs(old_mu - new_mu)
    print("[d={0}] mu error: max {1} - mean {2}".format(d, np.max(err), np.mean(err)))

    del(old_mu)
    del(new_mu)
    del(err)

    print()

Starting d=33
[d=33] span: 33
[d=33] Reconstruction rmse: 7.396315737651094e-08
[d=33] mu error: max 9.5367431640625e-07 - mean 9.124558886242085e-08

Starting d=31
[d=31] span: 23
[d=31] Reconstruction rmse: 4.430497213547824e-08
[d=31] mu error: max 3.5762786865234375e-07 - mean 4.0275679680235044e-08

Starting d=29
[d=29] span: 26
[d=29] Reconstruction rmse: 5.371440536805494e-08
[d=29] mu error: max 2.980232238769531e-07 - mean 3.898551526049232e-08

Starting d=27
[d=27] span: 24
[d=27] Reconstruction rmse: 5.1650676624603875e-08
[d=27] mu error: max 3.5762786865234375e-07 - mean 6.987396261592949e-08

Starting d=25
[d=25] span: 20
[d=25] Reconstruction rmse: 5.4938041671448445e-08
[d=25] mu error: max 2.980232238769531e-07 - mean 3.767866729731395e-08

Starting d=23
[d=23] span: 16
[d=23] Reconstruction rmse: 5.5417388011846924e-08
[d=23] mu error: max 2.980232238769531e-07 - mean 5.34213064895539e-08

Starting d=21
[d=21] span: 17
[d=21] Reconstruction rmse: 5.740727360148412e-08

In [5]:
# filter gaps
#
# Keep only the contexts whose smallest non-zero reward gap exceeds `thresh`.

thresh = 0.1

# ground truth: the representation stored under the largest key
d_gt = max(new_features.keys())
mu_gt = new_features[d_gt].dot(new_thetas[d_gt])
# gap of every entry w.r.t. the best entry of its row (axis 1)
gap_gt = np.max(mu_gt, axis=1)[:, np.newaxis] - mu_gt
# zero gaps (the maximiser itself) must not win the minimum below; +inf can
# never be mistaken for a real gap, unlike a finite sentinel
gap_gt[gap_gt == 0] = np.inf
print("gap min:", gap_gt.min())
gap_gt = np.min(gap_gt, axis=1)

# indexes of contexts with minimum gap above threshold
good_contexts = gap_gt > thresh
print("# contexts with gap_min > {0}: {1}".format(thresh, np.sum(good_contexts)))

# filter
for d in new_features.keys():
    new_features[d] = new_features[d][good_contexts, :, :]

n_contexts = np.sum(good_contexts)
mu_gt = mu_gt[good_contexts, :]

gap min: 4.2915344e-06
# contexts with gap_min > 0.1: 5437


In [6]:
# check misspecification
#
# Keep only the contexts on which every representation reproduces the
# ground-truth rewards `mu_gt` within `eps`.

eps = 0.05 # threshold for low misspecification

# builtin bool: the np.bool alias was removed from NumPy
low_eps_contexts = np.ones(n_contexts, dtype=bool)

for d in new_features.keys():
    mu = new_features[d].dot(new_thetas[d])
    err = np.abs(mu - mu_gt)
    print("[d={0}] error wrt ground truth: max {1} - mean {2}".format(d, err.max(), np.mean(err)))
    idx = np.max(err, axis=1) < eps  # contexts with low misspecification
    print("# contexts with eps < {0}: {1}".format(eps, np.sum(idx)))
    low_eps_contexts &= idx  # make sure all representations have low misspecification

    del(mu)
    del(err)

print("# contexts with eps < {0} in all representations: {1}".format(eps, np.sum(low_eps_contexts)))

# filter
for d in new_features.keys():
    new_features[d] = new_features[d][low_eps_contexts, :, :]
# keep the ground truth aligned with the filtered features; otherwise mu_gt
# keeps the old number of rows and this cell cannot be run again
mu_gt = mu_gt[low_eps_contexts, :]

n_contexts = np.sum(low_eps_contexts)

[d=16] error wrt ground truth: max 0.7536253929138184 - mean 0.01893376186490059
# contexts with eps < 0.05: 2600
[d=33] error wrt ground truth: max 0.0 - mean 0.0
# contexts with eps < 0.05: 5437
[d=20] error wrt ground truth: max 0.6976933479309082 - mean 0.0202250387519598
# contexts with eps < 0.05: 2333
[d=17] error wrt ground truth: max 0.7919116020202637 - mean 0.01966984197497368
# contexts with eps < 0.05: 2366
[d=23] error wrt ground truth: max 0.8105988502502441 - mean 0.018861087039113045
# contexts with eps < 0.05: 2320
[d=24] error wrt ground truth: max 0.707683801651001 - mean 0.019329054281115532
# contexts with eps < 0.05: 2391
[d=26] error wrt ground truth: max 0.7306199073791504 - mean 0.017828047275543213
# contexts with eps < 0.05: 2782
# contexts with eps < 0.05 in all representations: 1291


In [7]:
# check span optimal arms
#
# For every representation, take the feature vector of the reward-maximising
# arm in each context and report (a) the dimension those vectors span and
# (b) the smallest eigenvalue of their average outer product.

span_opt = {}

for d in new_features.keys():

    mu = new_features[d].dot(new_thetas[d])
    astar = np.argmax(mu, axis=1)
    # row x of fstar is new_features[d][x, astar[x]]
    fstar = new_features[d][np.arange(n_contexts), astar]

    # Use the span() helper rather than repeating its loop, and keep the
    # result in its own name: assigning to `span` would replace the helper
    # with an int for the rest of the session. Falls back to d when no
    # positive rank is reported, which was the previous default.
    opt_span = span(fstar) or d

    span_opt[d] = opt_span

    outer = np.matmul(fstar.T, fstar) / n_contexts
    # outer is symmetric, so use the symmetric solver: real eigenvalues,
    # no spurious complex parts
    lambda_hls = np.linalg.eigvalsh(outer).min()

    print("[d={0}] span optimal arms: {1} - lambda HLS: {2}".format(d, opt_span, lambda_hls))

    del(mu)
    del(astar)
    del(fstar)
    del(outer)

[d=16] span optimal arms: 16 - lambda HLS: 0.000542577588930726
[d=33] span optimal arms: 33 - lambda HLS: 0.01859199069440365
[d=20] span optimal arms: 20 - lambda HLS: 2.5937774594808616e-08
[d=17] span optimal arms: 17 - lambda HLS: 0.0048090796917676926
[d=23] span optimal arms: 23 - lambda HLS: 0.0016820760210976005
[d=24] span optimal arms: 24 - lambda HLS: 0.0003457514103502035
[d=26] span optimal arms: 26 - lambda HLS: 0.0010042345384135842


In [8]:
# save
# One compressed archive per representation, written to the working directory;
# the file name records the dimension and the span of the optimal arms.

for d, feats in new_features.items():
    out_name = 'jester_post_d{0}_span{1}.npz'.format(d, span_opt[d])
    np.savez_compressed(out_name, features=feats, theta=new_thetas[d])