In [36]:
import numpy as np

In [80]:
k = 1000
d = 8
active_arms01 = np.random.binomial(n=1, p=0.5, size=k)
active_arms = np.where(active_arms01 == 1)[0]
print(active_arms)

[  2   3   7   8   9  13  14  15  21  23  26  27  28  29  32  37  38  39
  41  42  43  45  47  51  52  55  56  57  59  60  62  64  65  69  71  72
  74  78  79  80  81  84  85  86  89  90  91  92  93  94  97  98 101 104
 109 110 119 120 122 123 125 126 128 129 133 134 135 141 143 145 146 147
 152 154 156 158 160 161 163 166 167 168 169 171 173 175 176 177 181 183
 185 186 187 194 198 199 203 204 205 206 207 209 210 211 213 218 222 223
 224 226 227 228 231 233 235 237 238 239 240 241 243 244 246 247 249 252
 255 256 257 261 262 263 265 267 269 273 274 275 276 277 280 282 283 287
 288 289 290 291 292 293 295 296 300 301 302 304 305 306 307 308 310 311
 312 313 314 316 317 318 319 320 323 326 327 329 330 331 332 334 336 337
 339 341 342 343 347 348 350 351 352 361 363 365 366 367 369 370 371 372
 374 375 376 377 378 379 381 386 389 390 392 393 394 395 396 397 399 400
 402 403 406 407 408 410 414 415 416 419 420 421 422 424 425 426 427 428
 429 430 431 432 435 439 443 444 445 446 447 448 45

In [81]:
def make_smaller_matrix(A, active_arms):
    B = np.zeros((len(active_arms),A.shape[1]))

    for i in range(len(active_arms)):
        B[i,:] = A[active_arms[i],:]

    return B

In [82]:
A = np.random.normal(size=(k,d))

In [83]:
B = make_smaller_matrix(A, active_arms)

In [84]:
def compute_induced_norm(Ainv, v):
    results = np.zeros(v.shape[0])
    for i in range(v.shape[0]):
        results[i] = np.dot(v[i,:].T, np.dot(Ainv, v[i,:]))
    return results

def compute_design_matrix(A, pi):
    D = np.zeros((A.shape[1],A.shape[1]))

    for i in range(A.shape[0]):
        D += pi[i]*np.dot(A[i:i+1,:].T,A[i:i+1,:])
    return D

def squeeze_distribution(pi, n):
    # apply noise injection to avoid ties
    pi = pi + np.random.normal(0,scale=1e-4,size=len(pi))

    sorted_vals = sorted(pi, reverse=True)
    nth_largest = sorted_vals[min(n, len(sorted_vals))-1]
    print('nth_largest')
    print(nth_largest)
    pi[pi<nth_largest] = 0
    pi = pi/np.sum(pi)
    return pi

def onehot(idx, k):
    v = np.zeros(k)
    v[idx] = 1
    return v

def eval_pi(pi, A):
    D = compute_design_matrix(A, pi)
    Dinv = np.linalg.inv(D)
    v = compute_induced_norm(Dinv, A)
    return np.max(v)


def find_optimal_design(A, iter=1000, thresh=0):
    k = A.shape[0]
    pi = np.ones(k)/k

    for it in range(iter):
        D = compute_design_matrix(A, pi)
        Dinv = np.linalg.inv(D)
        v = compute_induced_norm(Dinv, A)

        best_index = np.argmax(v)
        current = v[best_index]
        print(current)
        if current < (thresh + 1)*A.shape[1]:
            break
        gamma = (current/d-1)/(current-1)

        pi = (1-gamma)*pi + gamma*onehot(best_index, k)
    print(pi)
    print(eval_pi(pi, A))
    pi = squeeze_distribution(pi, 2*A.shape[1])
    print(eval_pi(pi, A))
    return pi

In [85]:
find_optimal_design(B, iter=1000)

24.137094534583923
21.540224064154216
22.609031047007317
22.984993952202185
22.22953491554142
23.432516551930366
22.34228086825489
21.37861554744155
18.4483816616033
14.402298912114965
14.02095651886959
13.597240613984573
14.225437912602994
13.123529457182434
13.244353448002602
12.711678312194413
12.315871469345447
12.250809805693889
12.377458379783375
12.118727228419182
11.386111906544299
10.489642760332503
10.77731404662263
10.69327414508735
10.84893974103757
10.642813750929765
10.623013862048177
10.437767467231556
10.33753491347326
9.78305519303509
9.760665585241329
9.982008036159046
9.888529057418529
9.764487876113714
10.006852180044687
9.722017022986387
9.54176140946959
9.439331372512301
9.537233643976549
9.59464959240671
9.599997059110283
9.42126409398857
9.45652079819011
9.504816136217073
9.222318440971161
9.131402574993924
9.120186552768674
9.129355491075547
9.125422926956446
9.08715434708135
9.000197159425246
9.047692816914818
9.023416605596175
9.088808581622013
9.001605032853

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.     