# Init

In [5]:
# coding: utf-8

# Simulate a distributed linear-bandit algorithm (DisLinUCB): M agents pull from K arms over horizon T
import numpy as np
import matplotlib.pyplot as plt
import time
from math import log
from math import sqrt
import sys

# Problem configuration.
K = 1000  # number of arms
d = 20  # dimension of the arm and parameter vectors
T = 10000  # time horizon
M = 50  # number of agents
D = T * log(M * T) / (d * M)  # global parameter; simulate() uses it as the communication threshold
lam = 1  # global parameter; simulate() uses it as the weight of the identity added to V
ins = []  # bandit instance, filled in below
tracex = []  # time steps recorded by simulate()
tracey = []  # cumulative regret recorded by simulate()

# Build the linear bandit instance as (d, 1) column vectors:
#   ins[0]    -- theta*
#   ins[1..K] -- the K arms
np.random.seed(173)
Theta = np.random.normal(size=d)
# Optional normalisation, currently disabled:
#Theta = Theta / np.linalg.norm(Theta)
rho = 1  # arm scale; the alternative 1 / sqrt(d) is currently disabled
# Each arm coordinate is uniform on [-rho / 2, rho / 2).
X = rho * np.random.random(size=(K, d)) - 0.5 * rho * np.ones((K, d))
ins.append(Theta.reshape((d, 1)))
ins.extend(row.reshape((d, 1)) for row in X)

# Run

In [None]:
# Sample from arm k
def sample(k):
    """Draw one noisy reward from arm ``k``.

    The reward is the inner product of ``ins[0]`` and ``ins[k]`` plus noise
    drawn uniformly from [-0.5, 0.5).

    Args:
        k: index into the module-level list ``ins``; entries are expected to
            be column vectors of equal length.

    Returns:
        The noisy reward as a Python float.
    """
    # ndarray.item() replaces np.asscalar, which was removed from NumPy.
    mean_reward = (ins[0].T @ ins[k]).item()
    return mean_reward + np.random.random_sample() - 0.5

# solve the inner argmax problem
# max <x, theta> s.t. ||theta - theta_hat||_2 <= radius
#def solveArgmax(x, theta_hat, radius) :
#    return np.asscalar(np.dot(x.T, theta_hat)) + radius * np.linalg.norm(x)


# solve the inner argmax problem
# max <x, theta> s.t. ||theta||_V <= radius
def solveArgmax(x, Vinv, radius):
    """Return ``radius * sqrt(x^T Vinv x)``.

    This is the closed-form value of ``max <x, theta>`` subject to
    ``||theta||_V <= radius`` when ``Vinv`` is the inverse of ``V``.

    Args:
        x: column vector of shape (d, 1).
        Vinv: matrix of shape (d, d).
        radius: scalar multiplier.

    Returns:
        The value as a Python float.
    """
    # ndarray.item() replaces np.asscalar, which was removed from NumPy.
    quad_form = (x.T @ Vinv @ x).item()
    return radius * sqrt(quad_form)

# compute matrix square root
#def sqrtMat(A) :
#    u, s, v = np.linalg.svd(A)
#    s = np.diag(np.sqrt(s))
#    return u @ s @ v

# Simulate M agents for horizon T
def _ucb_scores(arms, theta_hat, Vinv, radius):
    """Return the optimistic score of every arm at once.

    ``arms`` holds one arm per column, shape (d, n).  The score of a column x
    is ``<x, theta_hat> + radius * sqrt(x^T Vinv x)``: the estimated reward
    plus the same bonus that solveArgmax computes for a single arm.
    """
    estimates = (arms.T @ theta_hat).ravel()
    # Diagonal of arms^T Vinv arms without forming the full (n, n) product.
    quad = np.einsum('ij,ij->j', arms, Vinv @ arms)
    # Clamp tiny negative round-off so the square root never yields NaN.
    return estimates + radius * np.sqrt(np.maximum(quad, 0.0))


def simulate():
    """Simulate M agents for horizon T and plot the cumulative regret.

    Reads the module-level configuration (K, d, T, M, D, lam) and the bandit
    instance ``ins``.  At every time step each agent picks the arm with the
    highest optimistic score, observes a reward through ``sample`` and updates
    its local statistics.  When any agent's trigger
    ``(t - Tlast) * (logdet(V) - logdet(Vlast)) > D`` fires, all agents merge
    their local statistics into the shared ones at the end of that step.

    Side effects: appends one point per time step to the module-level lists
    ``tracex`` / ``tracey`` (they are not cleared, so repeated calls
    accumulate), prints the number of communication rounds and the
    accumulated communication cost, and saves the plot to 'DisLinUCB.png'.
    """
    # Index 0 holds the synchronised statistics; 1..M hold each agent's
    # updates since the last communication.
    W = [np.zeros((d, d)) for _ in range(M + 1)]
    U = [np.zeros((d, 1)) for _ in range(M + 1)]
    Vlast = lam * np.identity(d)  # V of last communication
    (sgn, VlastLogd) = np.linalg.slogdet(Vlast)
    Tlast = 0  # time of last communication
    CumRegret = 0.0  # cumulative regret
    CommuCost = 0  # communication cost accumulator
    count = 0  # number of communication rounds
    access = np.zeros((K + 1,))  # pull counts, indexed by arm id 1..K

    # Initialization: stack the arms once (column j - 1 is arm j) and compute
    # each arm's regret as the gap to the best expected reward.
    arms = np.hstack(ins[1:K + 1])
    expected = (arms.T @ ins[0]).ravel()
    regret = expected.max() - expected

    for t in range(1, T + 1):
        CommuSignal = False  # decides if communication is needed
        for i in range(1, M + 1):  # each agent pulls an arm
            V = lam * np.identity(d) + W[0] + W[i]
            Vinv = np.linalg.inv(V)
            theta_hat = Vinv @ (U[0] + U[i])

            (sgn, logd) = np.linalg.slogdet(V)
            assert sgn == 1
            radius = sqrt(logd - d * log(lam) + 2 * log(T)) + sqrt(lam)  # delta = 1 / T

            # Score every arm j (the original loop scored ins[i], the agent
            # index, for every j).  argmax keeps the first maximiser, like a
            # strict '>' scan.
            scores = _ucb_scores(arms, theta_hat, Vinv, radius)
            bestarm = int(np.argmax(scores)) + 1

            access[bestarm] += 1
            y = sample(bestarm)
            CumRegret += float(regret[bestarm - 1])
            outer = ins[bestarm] @ ins[bestarm].T
            W[i] += outer
            U[i] += y * ins[bestarm]
            V += outer

            (sgn, logd) = np.linalg.slogdet(V)
            assert sgn == 1

            if (t - Tlast) * (logd - VlastLogd) > D:
                CommuSignal = True
                CommuCost += 1
        tracex.append(t)
        tracey.append(CumRegret)
        if CommuSignal and M > 1:  # initiate a communication stage
            count += 1
            for i in range(1, M + 1):
                CommuCost += sys.getsizeof(W[i])
                CommuCost += sys.getsizeof(U[i])
                W[0] += W[i]
                U[0] += U[i]
                CommuCost += sys.getsizeof(W[0])
                CommuCost += sys.getsizeof(U[0])
                W[i] = np.zeros((d, d))
                U[i] = np.zeros((d, 1))

            Tlast = t
            Vlast = lam * np.identity(d) + W[0]
            # Refresh the reference log-determinant as well; otherwise the
            # trigger keeps comparing against the initial matrix.
            (sgn, VlastLogd) = np.linalg.slogdet(Vlast)

    print(count)
    print(CommuCost)
    #for i in range(1, K + 1) :
    #    print(access[i], end = ' ')

    plt.plot(tracex, tracey, label = 'DisLinUCB')
    plt.xlabel('Time')
    plt.ylabel('Cumulative Regret')
    plt.legend()
    plt.savefig('DisLinUCB.png')

# Run the simulation only when this file is executed as a script.
if __name__ == "__main__":
    simulate()