In [1]:
import numpy as np

In [2]:
# k is the number of arms; d is the number of columns used for the arm matrix.
k, d = 1000, 8

# One Bernoulli(0.5) draw per arm; a 1 marks that arm as active.
active_arms01 = np.random.binomial(n=1, p=0.5, size=k)

# Positions of the arms flagged with a 1 (the draws are only ever 0 or 1).
active_arms = np.flatnonzero(active_arms01)
print(active_arms)

[  0   1   3   6   7   8   9  10  12  18  19  20  23  25  28  30  33  34
  51  52  53  54  57  58  60  61  64  65  67  68  71  73  76  77  78  84
  86  87  88  90  91  93  96  97  99 100 103 106 108 109 110 112 113 114
 115 118 121 122 124 126 127 128 129 131 132 133 134 135 139 142 143 147
 150 151 152 153 156 157 159 161 163 167 171 174 175 176 177 179 181 183
 184 188 191 193 194 195 196 202 203 206 207 209 210 211 216 217 218 220
 221 222 224 225 228 230 231 234 235 237 238 239 240 241 242 244 245 246
 249 250 252 254 255 256 257 258 259 263 264 265 266 269 274 275 277 278
 279 281 287 288 289 290 295 299 300 301 305 306 307 308 311 316 320 321
 322 326 327 332 334 335 336 337 338 340 345 348 349 352 353 357 358 359
 360 361 362 364 365 369 370 376 381 383 385 387 390 392 393 394 395 396
 398 399 400 407 410 411 416 417 418 419 420 421 424 426 427 429 430 432
 433 434 437 441 442 445 449 451 453 454 461 464 466 467 471 472 473 474
 476 477 480 483 484 486 487 489 490 492 493 494 49

In [3]:
def make_smaller_matrix(A, active_arms):
    """Return the rows of ``A`` selected by ``active_arms`` as a new array.

    Parameters
    ----------
    A : np.ndarray, shape (k, d)
        Matrix to take rows from.
    active_arms : sequence of int
        Row indices to keep, in the order they should appear in the result.
        Repeated indices are allowed.

    Returns
    -------
    np.ndarray, shape (len(active_arms), d)
        Float64 copy of the selected rows; ``A`` itself is not modified.
    """
    rows = np.asarray(active_arms)
    if rows.size == 0:
        # An empty list converts to a float array, which cannot be used as an
        # index; return the empty result directly instead.
        return np.zeros((0, A.shape[1]))
    # Fancy indexing gathers every requested row in one vectorised copy
    # instead of a Python-level loop over the indices.
    return np.asarray(A[rows, :], dtype=float)

In [4]:
# Draw a k-by-d matrix of independent standard-normal entries (one row per arm index).
A = np.random.normal(size=(k,d))

In [5]:
# Keep only the rows of A whose indices are listed in active_arms.
B = make_smaller_matrix(A, active_arms)

In [6]:
def compute_induced_norm(Ainv, v):
    """Evaluate the quadratic form ``x^T Ainv x`` for every row ``x`` of ``v``.

    Despite the name, no square root is taken: each returned value is the
    *squared* norm of the row under ``Ainv``.

    Parameters
    ----------
    Ainv : np.ndarray, shape (d, d)
        Matrix defining the quadratic form.
    v : np.ndarray, shape (n, d)
        One vector per row.

    Returns
    -------
    np.ndarray, shape (n,)
        Float64 array whose i-th entry is ``v[i] @ Ainv @ v[i]``.
    """
    # Row i of ``v @ Ainv.T`` equals ``Ainv @ v[i]``; the row-wise inner
    # product with ``v`` then yields every quadratic form in one pass.
    transformed = np.dot(v, Ainv.T)
    return np.asarray(np.sum(transformed * v, axis=1), dtype=float)

def compute_design_matrix(A, pi):
    """Build the weighted design matrix ``sum_i pi[i] * a_i a_i^T``.

    ``a_i`` denotes the i-th row of ``A`` taken as a column vector, so the
    result equals ``A.T @ diag(pi) @ A``.

    Parameters
    ----------
    A : np.ndarray, shape (n, d)
        One vector per row.
    pi : array-like of float, length >= n
        Weight attached to each row. Only the first ``n`` entries are used.

    Returns
    -------
    np.ndarray, shape (d, d)
        The weighted sum of outer products (all zeros when ``A`` has no rows).
    """
    n_rows = A.shape[0]
    weights = np.asarray(pi, dtype=float)[:n_rows]
    # Scaling each row by its weight and multiplying by A^T sums all the
    # rank-one terms in a single matrix product.
    return np.dot(A.T, weights[:, np.newaxis] * A)

def squeeze_distribution(pi, n):
    """Keep the ``n`` largest weights of ``pi``, zero the rest, and renormalise.

    Ties between equal weights are broken uniformly at random by ranking a
    shuffled copy of the indices. The weights themselves are never perturbed,
    so the surviving entries keep their exact relative proportions and the
    result cannot contain negative values introduced by tie-breaking.

    Parameters
    ----------
    pi : array-like of float, shape (k,)
        Non-negative weights. The input is not modified.
    n : int
        Maximum number of entries to keep. Values larger than ``k`` keep
        every entry.

    Returns
    -------
    np.ndarray, shape (k,)
        Weights with at most ``n`` non-zero entries, summing to 1.

    Raises
    ------
    ValueError
        If ``n < 1``, if ``pi`` is empty, or if the kept weights do not have
        positive total mass (so they cannot be normalised).
    """
    weights = np.asarray(pi, dtype=float)
    size = weights.shape[0]
    if n < 1:
        raise ValueError("n must be at least 1")
    if size == 0:
        raise ValueError("pi must contain at least one weight")
    keep_count = min(n, size)

    # Shuffle the indices first, then stable-sort by descending weight:
    # equal weights end up in random relative order without altering values.
    shuffled = np.random.permutation(size)
    ranked = shuffled[np.argsort(-weights[shuffled], kind="stable")]
    kept = ranked[:keep_count]

    squeezed = np.zeros(size)
    squeezed[kept] = weights[kept]
    total = squeezed.sum()
    if not total > 0:
        raise ValueError("kept weights must have positive total mass")
    return squeezed / total

def onehot(idx, k):
    """Return a length-``k`` float vector holding 1 at ``idx`` and 0 elsewhere."""
    indicator = np.zeros(k)
    indicator[idx] = 1
    return indicator

def eval_pi(pi, A):
    """Score the weights ``pi`` by the largest row-wise quadratic form.

    Builds the design matrix of ``A`` under ``pi``, inverts it, and returns
    the maximum over the rows ``a`` of ``A`` of ``a^T D^{-1} a``.
    """
    design_inverse = np.linalg.inv(compute_design_matrix(A, pi))
    return np.max(compute_induced_norm(design_inverse, A))


def find_optimal_design(A, iter=1000, thresh=0, verbose=True):
    """Search for weights over the rows of ``A`` that minimise the largest
    value of ``a^T D(pi)^{-1} a`` using a Frank-Wolfe-style iteration.

    Starting from uniform weights, each step finds the row with the largest
    quadratic form under the inverse design matrix and shifts mass towards it.
    The final weights are then truncated with ``squeeze_distribution`` to at
    most ``2 * d`` non-zero entries, where ``d = A.shape[1]``.

    Parameters
    ----------
    A : np.ndarray, shape (k, d)
        One candidate vector per row.
    iter : int, default 1000
        Maximum number of update steps.
    thresh : float, default 0
        Stop as soon as the largest quadratic form drops below
        ``(1 + thresh) * d``.
    verbose : bool, default True
        When true, print the weights before truncation and the objective
        value before and after truncation.

    Returns
    -------
    np.ndarray, shape (k,)
        The truncated, renormalised weights.
    """
    k = A.shape[0]
    # Use the dimension of A itself rather than a notebook-level global, so
    # the step size stays correct for matrices of any width.
    dim = A.shape[1]
    pi = np.ones(k) / k

    for _ in range(iter):
        D = compute_design_matrix(A, pi)
        Dinv = np.linalg.inv(D)
        v = compute_induced_norm(Dinv, A)

        best_index = np.argmax(v)
        current = v[best_index]
        if current < (thresh + 1) * dim:
            break

        excess = current / dim - 1
        if excess <= 0:
            # No further improvement is possible; this also avoids the 0/0
            # step size that arises when current == 1.
            break
        gamma = excess / (current - 1)

        pi = (1 - gamma) * pi + gamma * onehot(best_index, k)

    if verbose:
        print(pi)
        print(eval_pi(pi, A))
    pi = squeeze_distribution(pi, 2 * dim)
    if verbose:
        print(eval_pi(pi, A))
    return pi

In [7]:
# Run the design search on the active-arm rows; as the cell's last expression, the returned weights are displayed.
find_optimal_design(B, iter=1000)

23.095698411774325
24.6248102467051
23.446585578191677
23.877447196604518
25.6127092659024
25.04366406133616
23.698012030929185
24.621347870320168
18.37104610337315
15.854985432008032
13.87948600617429
14.307547990809951
14.871643470764711
14.419694982787064
14.292128414037151
13.899092562453182
11.820027923305533
11.57101218617452
11.260639045699959
10.908645464210887
10.884858807181073
11.226155304783132
11.156931979862332
11.006379958380979
10.143867342390443
9.8519316184309
10.049840827776116
10.306839726766947
10.24490664077732
10.31665995856793
10.407876549205728
9.747736062519287
9.815226946019507
9.763093746273592
9.85028576703826
9.803628052497773
9.879080649025607
9.845358207372223
9.591593621484309
9.520328211778075
9.47863558868691
9.304803942232944
9.463228428490504
9.61623253897627
9.28910304718444
9.084861840151552
9.184097436689344
9.099698248319628
9.155065982846736
9.284389539781571
9.434119433923962
9.382687056982052
9.120342340948147
9.029324042309698
8.961974851187

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.04758