In [175]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.optim import SGD

- sampling distribution: z ~ N(0, I)
- posterior model: g_nu(z) = theta
- expected reward model: f_theta(action) = E(y|action, theta)
- noise for observations: sigmao
- prior std: sigmap

In [176]:
class KBandits:
    """K-armed bandit with Gaussian arm means and Gaussian observation noise.

    Parameters
    ----------
    k : int
        Number of arms.
    sigma_obs : float
        Standard deviation of the noise added to every observed reward.
    sigma_model : float
        Standard deviation used to draw the true arm means.
    """

    def __init__(self, k, sigma_obs=0.5, sigma_model=2):
        self.rewards = np.random.randn(k) * sigma_model
        self.n_bandits = k
        self.sigma_obs = sigma_obs
        self.sigma_model = sigma_model

    def reset(self):
        """Redraw the true arm means from N(0, sigma_model^2)."""
        # Use the stored arm count: the original read a global `k`, which is
        # either undefined or unrelated to this particular instance.
        self.rewards = np.random.randn(self.n_bandits) * self.sigma_model

    def set_model(self, theta):
        """Overwrite the true arm means with `theta`."""
        self.rewards = theta

    def step(self, action):
        """Return a noisy reward for arm index `action`."""
        # Noise is drawn for every arm (not only the pulled one) so that the
        # consumption of the global NumPy RNG stays the same as before.
        noisy_rewards = np.random.randn(self.n_bandits) * self.sigma_obs + self.rewards
        return noisy_rewards[action]

In [177]:
def gather_dataset(env, npoints, nbandits):
    """Pull every arm once per round and stack the observed rewards.

    Calls `env.step(arm)` for arm = 0..nbandits-1, repeated `npoints` times,
    and returns the rewards as an array of shape (npoints, nbandits).
    """
    rounds = []
    for _ in range(npoints):
        one_round = [env.step(arm) for arm in range(nbandits)]
        rounds.append(one_round)
    return np.array(rounds)

In [178]:
# Sanity check: per-arm empirical mean and std over 1000 rounds of a fresh 4-armed bandit.
d = gather_dataset(KBandits(4), 1000, 4)
d.mean(0), d.std(0)

(array([ 3.85540057, -3.46561015,  4.00836241, -3.41760425]),
 array([0.49310461, 0.50771701, 0.5113953 , 0.51589582]))

In [181]:
def index_selection_mapping(theta, x):
    """Pick, for every row of `theta`, the entry at the column given by `x`.

    theta: tensor of size (B, K)
    x: integer tensor of size (B,)
    Returns a tensor of size (B, 1).
    """
    column_index = x.unsqueeze(1)
    return theta.gather(1, column_index)

In [182]:
class LinearHypermodelBandits:
    """Linear hypermodel over the mean rewards of a K-armed bandit.

    Holds a trainable posterior model `posterior_model_g` and a fixed prior
    sample `prior` that is drawn once at construction and added to every
    posterior sample.

    Parameters
    ----------
    hpm_size : int
        Dimension of the index variable z.
    sigma_prior : float
        Scale applied to the prior sample.
    k_arms : int
        Number of arms.
    device : str or torch.device
        Device on which the posterior model and the prior sample are placed.
    """

    def __init__(self, hpm_size, sigma_prior, k_arms, device='cpu'):
        self.hpm_size = hpm_size
        self.sigma_p = sigma_prior
        self.n_arms = k_arms
        self.device = device
        self.posterior_model_g = LinearModuleBandits(k_bandits=k_arms, model_dim=hpm_size).to(device)
        self.prior = self.sample_prior_dbz()  # of shape (k, )

    def sample_prior_dbz(self):
        """Draw one prior sample D * <B, z>; returns a float64 tensor of shape (k,)."""
        D = np.random.randn(self.n_arms) * self.sigma_p  # dim (k,)
        B = generate_hypersphere(dim=self.hpm_size, n_samples=self.n_arms, norm=1)  # dim (k, m)
        z = np.random.randn(self.n_arms, self.hpm_size)  # dim (k, m)
        return torch.from_numpy(D * (B * z).sum(-1)).to(self.device)

    def sample_posterior(self, n_samples):
        """Return `n_samples` posterior samples, each shifted by the fixed prior sample."""
        return self.posterior_model_g.sample(n_samples) + self.prior  # NEED TO DISCUSS THIS

    def update_device(self, device):
        """Move the posterior model and the prior sample to `device`."""
        # The original omitted `self` and referenced a non-existent
        # `prior_posterior_model_g` attribute, so it could never be called.
        self.device = device
        self.posterior_model_g = self.posterior_model_g.to(device)
        self.prior = self.prior.to(device)

In [186]:
# Small hypermodel (4 arms, 1-dimensional index z) used for the sanity checks that follow.
# NOTE(review): the constructor calls LinearModuleBandits and generate_hypersphere, which are
# defined in later cells, so this cell fails on a fresh kernel unless those cells run first.
hm = LinearHypermodelBandits(hpm_size=1, sigma_prior=0.5, k_arms=4)

In [187]:
# Draw 1000 independent prior samples; each call yields one value per arm.
data_test = np.array([hm.sample_prior_dbz().numpy() for _ in range(1000)])

In [188]:
# Per-arm empirical mean and std of the prior samples.
data_test.mean(0), data_test.std(0)

(array([ 0.01785005,  0.0320669 ,  0.01765375, -0.02198162]),
 array([0.52112837, 0.50227694, 0.46088904, 0.49820453]))

In [189]:
class LinearModuleBandits(nn.Module):
    """Linear map from an index z to per-arm values: theta_k = <C_k, z_k> + mu_k.

    Parameters
    ----------
    k_bandits : int
        Number of arms.
    model_dim : int
        Dimension of the index z attached to each arm.
    """

    def __init__(self, k_bandits, model_dim):
        super().__init__()
        self.k = k_bandits
        self.m = model_dim
        self.C = nn.Parameter(torch.ones(size=(k_bandits, model_dim)))
        self.mu = nn.Parameter(torch.ones(k_bandits))
        self.init_parameters()

    def init_parameters(self):
        """Re-initialise mu and C with small Gaussian values (std 0.05)."""
        # mu is drawn before C so that the random stream is consumed in the
        # same order as before.
        mu_sampled = torch.randn(self.k) * 0.05
        c_sampled = torch.randn(self.k, self.m) * 0.05
        with torch.no_grad():
            self.C.copy_(c_sampled)
            self.mu.copy_(mu_sampled)

    def forward(self, z):
        """
        z of size (batch, Kbandits, modelsize)
        theta of size (batch, Kbandits)
        """
        return (self.C * z).sum(-1) + self.mu

    def sample(self, n_samples):
        """Draw `n_samples` indices z ~ N(0, I) and map them through `forward`."""
        # Create z on the parameters' device; a CPU-only z breaks as soon as
        # the module has been moved elsewhere with `.to(device)`.
        z = torch.randn(n_samples, self.k, self.m, device=self.C.device)
        return self.forward(z)

In [190]:
# Per-arm mean and std of 10000 draws from the posterior model g of `hm` (prior not added).
data_test = hm.posterior_model_g.sample(10000)
data_test.mean(0), data_test.std(0)

(tensor([ 0.0327, -0.0407,  0.0043,  0.0232], grad_fn=<MeanBackward1>),
 tensor([0.0157, 0.0130, 0.0677, 0.0580], grad_fn=<StdBackward1>))

In [191]:
# Prior sample drawn once when `hm` was constructed (float64 because it is built from NumPy arrays).
hm.prior

tensor([-0.4995,  0.5636,  0.4819,  0.1375], dtype=torch.float64)

In [192]:
# Mean parameter mu of the posterior model, one value per arm.
hm.posterior_model_g.mu.data

tensor([ 0.0327, -0.0409,  0.0021,  0.0230])

In [193]:
# sample_posterior adds the fixed prior sample to g's output, so the per-arm means are shifted by hm.prior.
data_test = hm.sample_posterior(10000)
data_test.mean(0), data_test.std(0)

(tensor([-0.4667,  0.5226,  0.4842,  0.1604], dtype=torch.float64,
        grad_fn=<MeanBackward1>),
 tensor([0.0157, 0.0130, 0.0690, 0.0585], dtype=torch.float64,
        grad_fn=<StdBackward1>))

In [194]:
def generate_hypersphere(dim, n_samples, norm=1):
    """Sample `n_samples` points of dimension `dim` with unit L1 or L2 norm.

    Parameters
    ----------
    dim : int
        Dimension of each point.
    n_samples : int
        Number of points to draw.
    norm : {1, 2}
        1: uniform [0, 1) draws rescaled so that each row sums to one.
        2: standard normal draws rescaled to unit Euclidean length.

    Returns
    -------
    numpy.ndarray of shape (n_samples, dim)

    Raises
    ------
    ValueError
        If `norm` is neither 1 nor 2.
    """
    if norm == 1:  # TODO ask question about that
        # NOTE(review): the draws are non-negative, so the points cover only the
        # positive orthant of the L1 sphere. Confirm that this is intended.
        samples = np.random.rand(n_samples, dim)
        return samples / samples.sum(axis=1, keepdims=True)
    if norm == 2:
        samples = np.random.randn(n_samples, dim)
        return samples / np.sqrt((samples ** 2).sum(axis=1, keepdims=True))
    raise ValueError(f"norm must be 1 or 2, got {norm!r}")

In [195]:
class RandomAgent:
    """Baseline policy that picks an arm uniformly at random on every step."""

    def __init__(self, k_bandits):
        # Number of arms the agent can choose from.
        self.n_arms = k_bandits

    def act(self):
        """Return an arm index drawn uniformly from 0..n_arms-1."""
        chosen_arm = np.random.randint(0, self.n_arms)
        return chosen_arm

    def reset(self):
        """Nothing to reset: the agent keeps no state between episodes."""
        return None

In [196]:
def augmented_dataset(dataset, perturbations_dimension, mode='hypersphere'):
    """Attach one random perturbation vector to every data point.

    Each element of `dataset` is unpacked and returned as a tuple with an extra
    trailing entry: a vector of length `perturbations_dimension` produced by
    `generate_hypersphere` (norm=1 when mode == 'hypersphere', norm=2 otherwise).
    """
    chosen_norm = 1 if mode == 'hypersphere' else 2
    perturbations = generate_hypersphere(dim=perturbations_dimension,
                                         n_samples=len(dataset),
                                         norm=chosen_norm)
    augmented = []
    for data_point, perturbation in zip(dataset, perturbations):
        augmented.append((*data_point, perturbation))
    return augmented

In [197]:
def run_episode(envnmt, actor, horizon, n_steps, n_samples_z, lr, sigmao, sigmap, batch_size, hypermodel, update_every=1, training=False, device='cpu'):
    """Interact with the bandit for `horizon` steps and optionally train the hypermodel.

    Parameters
    ----------
    envnmt : environment exposing `step(action)`.
    actor : policy exposing `act()`.
    horizon : number of interaction steps.
    n_steps, n_samples_z, lr, sigmao, sigmap : forwarded to `train_hypermodel`
        (as nsteps, nsamples_z, learning_rate, sigmao, sigmap).
    batch_size : batch size of the DataLoader built for training.
    hypermodel : hypermodel to train; its `hpm_size` sets the perturbation dimension.
    update_every : every `update_every` steps the fresh observations are
        augmented with perturbations and appended to the training set.
    training : when True, `train_hypermodel` runs after each such update.
    device : device forwarded to `train_hypermodel`.

    Returns
    -------
    (obs, dataset) : the raw [arm, reward] pairs and the augmented training tuples.
    """
    obs = []
    new_data = []
    dataset = []
    for ix in range(horizon):
        arm_selected = actor.act()
        reward = envnmt.step(arm_selected)
        data_point = [arm_selected, reward]
        obs.append(data_point)
        new_data.append(data_point)
        if (ix + 1) % update_every == 0:
            dataset += augmented_dataset(new_data, perturbations_dimension=hypermodel.hpm_size)
            new_data = []
            if training:
                data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
                # Train the hypermodel passed in (not the notebook-global
                # `hypmodel`) and honour the caller's device.
                train_hypermodel(data_loader,
                                 nsteps=n_steps,
                                 nsamples_z=n_samples_z,
                                 learning_rate=lr,
                                 hypermodel=hypermodel,
                                 sigmao=sigmao,
                                 sigmap=sigmap,
                                 device=device)
    return obs, dataset

In [212]:
def train_hypermodel(data_loader, hypermodel, nsteps, nsamples_z, learning_rate, sigmao, sigmap, device='cpu'):
    """Run `nsteps` SGD steps on the hypermodel's posterior model.

    Parameters
    ----------
    data_loader : iterable of batches `(x, y, a)`; `x` holds arm indices, `y`
        observed rewards and `a` the perturbation vectors of shape (B, m).
        It is iterated repeatedly until `nsteps` steps are done.
    hypermodel : object exposing `posterior_model_g`, `n_arms` and `hpm_size`.
    nsteps : number of optimisation steps.
    nsamples_z : number of index samples z drawn per step.
    learning_rate : SGD learning rate.
    sigmao : scale applied to the perturbation term <a, z>.
    sigmap : sets the SGD weight decay to 1 / (2 * sigmap ** 2).
    device : device on which batches and index samples are placed.

    Returns
    -------
    float : mean loss over the steps performed.

    Raises
    ------
    ValueError : if `data_loader` yields no batch (otherwise the loop never ends).
    """
    # Take the sizes from the hypermodel itself rather than from the
    # notebook-level globals `k` and `modelling_size`.
    n_arms = hypermodel.n_arms
    index_dim = hypermodel.hpm_size

    optimizer = SGD(hypermodel.posterior_model_g.parameters(), lr=learning_rate, weight_decay=1/(2 * sigmap ** 2))
    steps_done = 0
    total_loss = 0
    while True:
        saw_batch = False
        for batch in data_loader:
            saw_batch = True
            x, y, a = batch
            x = x.to(device)
            y = y.to(device)
            a = a.to(device) # shape B, m
            z_sample = torch.randn(nsamples_z, n_arms, index_dim, device=device) # shape M, K, m
            z_sliced = torch.index_select(z_sample, 1, x) # shape M, B, m

            sigAz = sigmao * (a.unsqueeze(0) * z_sliced).sum(-1) #shape M, B

            posteriors = hypermodel.posterior_model_g(z_sample)
            outputs = torch.index_select(posteriors, 1, x)

            loss = (((y + sigAz - outputs) ** 2).mean(1)).mean(0) #/ (2 * sigmao ** 2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            steps_done += 1
            if (steps_done % 25) == 0:
                print(f"step {steps_done}, loss:{loss.item():2f}")
            if steps_done >= nsteps:
                return total_loss / steps_done
        if not saw_batch:
            raise ValueError("data_loader yielded no batches; cannot train")

In [235]:
# --- Experiment configuration ---
# H: number of interaction steps collected per episode.
H = 1000
sigmao = 0.1 # environment parameters
sigmap = 1.

sigmap_algo = 0.5 # hypermodel prior width

sigmap_training = 10. # weight decay penalty
sigmao_training = 0.5
# updates_freq is not referenced by any other cell in this notebook excerpt.
updates_freq = 1
# Optimisation settings: batch size, SGD learning rate, z samples per step, number of steps.
batch_s = 16
lr = 5 * 1e-2
n_samples_z = 16
n_steps = 1000

# k: number of arms; modelling_size: dimension of the hypermodel index z.
k = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'
modelling_size = 2

In [236]:
# Ground-truth environment: k arms, arm means scaled by sigmap, observation noise std sigmao.
env = KBandits(k=k,
               sigma_obs=sigmao,
               sigma_model=sigmap)

In [237]:
# True mean reward of each arm.
env.rewards

array([-0.1515557 ,  1.99217456,  0.80892846, -0.31244675,  0.70451408,
       -1.70928757,  1.73128574,  0.69615721, -0.38798008,  1.10477927])

In [238]:
# Sanity check: per-arm empirical mean and std over 1000 rounds, to compare with env.rewards and sigmao.
d = gather_dataset(env, 1000, k)
d.mean(0), d.std(0)

(array([-0.15695174,  1.98961239,  0.80447679, -0.30946026,  0.70149627,
        -1.71306783,  1.73528717,  0.69431881, -0.38710757,  1.10509069]),
 array([0.10305578, 0.09509185, 0.09925482, 0.09750421, 0.09746322,
        0.09762265, 0.09930263, 0.10072955, 0.0996663 , 0.10275035]))

In [239]:
# Hypermodel under test; its posterior model and prior sample are placed on `device`.
hypmodel = LinearHypermodelBandits(hpm_size=modelling_size,
                                   sigma_prior=sigmap_algo,
                                   k_arms=k,
                                   device=device)

In [240]:
# Per-arm mean and std of 10000 draws from the posterior model g of `hypmodel` (prior not added).
data_test = hypmodel.posterior_model_g.sample(10000)
data_test.mean(0), data_test.std(0)

(tensor([ 0.0161,  0.0008,  0.0006,  0.0803,  0.1147, -0.0117,  0.0608, -0.0809,
         -0.0086, -0.0476], grad_fn=<MeanBackward1>),
 tensor([0.0804, 0.0731, 0.0782, 0.0206, 0.0655, 0.0491, 0.0521, 0.0174, 0.0582,
         0.0413], grad_fn=<StdBackward1>))

In [241]:
# Fixed prior sample drawn when `hypmodel` was constructed.
hypmodel.prior

tensor([-0.3166,  0.1244,  0.3018,  0.0734,  0.0136, -0.3415,  0.4828,  0.3423,
        -0.0254, -0.1283], dtype=torch.float64)

In [242]:
# Mean parameter mu of the posterior model, one value per arm.
hypmodel.posterior_model_g.mu.data

tensor([ 0.0161,  0.0015,  0.0006,  0.0803,  0.1138, -0.0115,  0.0610, -0.0808,
        -0.0092, -0.0470])

In [249]:
# Per-arm mean and std of 10000 posterior samples (g's output plus the fixed prior sample).
# NOTE(review): this cell's execution count (249) is later than the training cell (248), so the
# recorded output presumably reflects the hypermodel after training; re-run in order to confirm.
data_test = hypmodel.sample_posterior(10000)
data_test.mean(0), data_test.std(0)

(tensor([-0.4643,  2.0717,  1.0991, -0.2398,  0.6913, -2.0047,  2.1761,  0.9925,
         -0.3997,  0.9433], dtype=torch.float64, grad_fn=<MeanBackward1>),
 tensor([0.3428, 0.3535, 0.3466, 0.3470, 0.3465, 0.3419, 0.3555, 0.3449, 0.3467,
         0.3464], dtype=torch.float64, grad_fn=<StdBackward1>))

In [244]:
# True arm means again, for comparison with the posterior sample means.
env.rewards

array([-0.1515557 ,  1.99217456,  0.80892846, -0.31244675,  0.70451408,
       -1.70928757,  1.73128574,  0.69615721, -0.38798008,  1.10477927])

In [245]:
# Collect H (arm, reward) interactions with a uniformly random policy; training is disabled here.
# `sigmao_algo` is not defined anywhere in this notebook, so pass `sigmao_training`, and use the
# configured `device` instead of a hard-coded 'cpu'. Neither value has any effect while training=False.
obs, dataset = run_episode(envnmt=env,
                           actor=RandomAgent(k_bandits=k),
                           horizon=H,
                           n_steps=n_steps,
                           n_samples_z=n_samples_z,
                           lr=lr,
                           sigmao=sigmao_training,
                           sigmap=sigmap_algo,
                           batch_size=batch_s,
                           update_every=1,
                           hypermodel=hypmodel,
                           training=False,
                           device=device)

In [247]:
# Shuffled mini-batch loader over the (arm, reward, perturbation) tuples returned by run_episode.
dl = DataLoader(dataset, batch_size=batch_s, shuffle=True)

In [248]:
# Offline training on the collected data; the cell displays the mean loss over all steps.
# `hypmodel` was created on `device`, so train there too: a hard-coded 'cpu' puts the
# index samples on a different device from the model whenever CUDA is available.
train_hypermodel(data_loader=dl,
                 nsteps=n_steps,
                 nsamples_z=n_samples_z,
                 learning_rate=lr,
                 sigmao=sigmao_training,
                 sigmap=sigmap_training,
                 hypermodel=hypmodel, device=device)

step 25, loss:1.160254
step 50, loss:0.557353
step 75, loss:0.309358
step 100, loss:0.189499
step 125, loss:0.212733
step 150, loss:0.141650
step 175, loss:0.113221
step 200, loss:0.081919
step 225, loss:0.062137
step 250, loss:0.037872
step 275, loss:0.047653
step 300, loss:0.039998
step 325, loss:0.035347
step 350, loss:0.048748
step 375, loss:0.028004
step 400, loss:0.041591
step 425, loss:0.023608
step 450, loss:0.036558
step 475, loss:0.029538
step 500, loss:0.037406
step 525, loss:0.029133
step 550, loss:0.037377
step 575, loss:0.032833
step 600, loss:0.037851
step 625, loss:0.030647
step 650, loss:0.023506
step 675, loss:0.032300
step 700, loss:0.029240
step 725, loss:0.034913
step 750, loss:0.029136
step 775, loss:0.035578
step 800, loss:0.027723
step 825, loss:0.040517
step 850, loss:0.032273
step 875, loss:0.024940
step 900, loss:0.028806
step 925, loss:0.049389
step 950, loss:0.032057
step 975, loss:0.045163
step 1000, loss:0.035307


0.11077446661083155

TODOs:
- online loop with TS
- similar offline tests with the MNL bandit env
- script for experiments