In [91]:
import numpy as np
import sys
sys.path.insert(0,'..')
import xbrl.envs as bandits
from xbrl.algs.nnmodel import MLLinearNetwork
import xbrl.envs.hlsutils as hlsutils
import torch
import torch.nn as nn
import math

In [92]:
# --- Synthetic problem configuration ---------------------------------------
# Problem size: number of contexts, arms per context, raw feature dimension.
ncontexts, narms, dim = 100, 10, 10

# Options handed to the synthetic feature generator in the next cell.
contextgeneration = "gaussian"
feature_expansion = "none"

# Seed handed to the synthetic feature generator.
seed_problem = 99

In [93]:
# Collect the generator arguments in one place so the call itself stays short.
problem_spec = dict(
    n_contexts=ncontexts,
    n_actions=narms,
    dim=dim,
    context_generation=contextgeneration,
    feature_expansion=feature_expansion,
    seed=seed_problem,
)

# The generator returns a pair, unpacked here as (features, theta).
features, theta = bandits.make_synthetic_features(**problem_spec)

# In the recorded run this printed (100, 10, 10), i.e. (ncontexts, narms, dim).
print(features.shape)

(100, 10, 10)


In [94]:
def initialize_weights(model):
    """Re-initialise, in place, the parameters of the supported layers of ``model``.

    Every sub-module yielded by ``model.modules()`` (which includes ``model``
    itself) is visited and handled according to its type:

    * ``nn.Conv2d``: weights are drawn from a normal distribution with mean 0
      and standard deviation ``sqrt(4 / n)``, where
      ``n = kernel_h * kernel_w * out_channels + in_channels``; the bias is
      set to zero.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: the weight is set to 1 and the
      bias to 0.
    * ``nn.Linear``: weights are drawn from a normal distribution with mean 0
      and standard deviation ``sqrt(100 / n)``, where
      ``n = in_features + out_features``; the bias is set to zero.

    Modules of any other type are left untouched. Parameters that do not exist
    on a layer (``bias=False`` layers, ``affine=False`` batch-norm layers) are
    skipped instead of raising.

    Args:
        model: the ``nn.Module`` whose layers are re-initialised.

    Returns:
        None. The module is modified in place.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # NOTE(review): operator precedence makes this (kh * kw * out) + in,
            # not kh * kw * (out + in); kept exactly as written.
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.in_channels
            nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(4. / n))
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            # With affine=False both parameters are None, so guard each one.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            n = m.in_features + m.out_features
            # The 100 in the numerator gives a much larger spread than the
            # 4 used for convolutions above.
            nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(100. / n))
            if m.bias is not None:
                nn.init.zeros_(m.bias)

In [95]:
# Seed torch's global RNG before any network is created: the ground-truth
# network is randomly initialised, so without a seed the numbers printed by
# the cells below differ on every Restart & Run All.
torch.manual_seed(seed_problem)

# Ground-truth architecture: one (width, activation) pair per hidden layer.
hid_dim = [60, 10]
layers = [(el, nn.Tanh()) for el in hid_dim]
net = MLLinearNetwork(dim, layers)

# Replace the default initialisation with the custom scheme defined above.
initialize_weights(net)

In [96]:
# One row per (context, arm) pair so the network can process them as a batch.
U = features.reshape(-1, dim)
print(features.shape, U.shape)
xt = torch.tensor(U, dtype=torch.float)

# Network embedding of every row, then regrouped per context and arm.
H = net.embedding(xt).detach().cpu().numpy()
print(H.shape)
newfeatures = H.reshape(ncontexts, narms, net.embedding_dim)

# Network output for every row, regrouped the same way.
R = net(xt).detach().numpy().reshape(-1)
newreward = R.reshape(ncontexts, narms)

# Sanity check: the network output must equal the embedding multiplied by the
# weights of the final layer `fc2`, i.e. the reward is linear in the embedding.
theta = net.fc2.weight.detach().numpy().reshape(-1)
y = np.matmul(newfeatures, theta)
assert np.allclose(y, newreward)

(100, 10, 10) (1000, 10)
(1000, 10)


In [97]:
# Diagnostics of the ground-truth representation (newfeatures) with respect
# to the ground-truth reward (newreward), as reported by hlsutils.
print(f"New rep -> HLS rank: {hlsutils.hls_rank(newfeatures, newreward)} / {newfeatures.shape[2]}")
print(f"New rep -> is HLS: {hlsutils.is_hls(newfeatures, newreward)}")
print(f"New rep -> HLS min eig: {hlsutils.hls_lambda(newfeatures, newreward)}")
print(f"New rep -> is CMB: {hlsutils.is_cmb(newfeatures, newreward)}")

# Largest Euclidean norm over all (context, arm) feature vectors.
print(f"features norm: {np.linalg.norm(newfeatures, axis=-1).max()}")

# Read the range from `newreward` rather than the global `R`: `newreward`
# holds the same values (it is `R` reshaped), but `R` is rebound by the
# architecture-comparison loop further down, so using it here would report
# the wrong network's range if this cell were re-run afterwards.
print(f"reward range: {newreward.min()}, {newreward.max()}")

New rep -> HLS rank: 10 / 10
New rep -> is HLS: True
New rep -> HLS min eig: 0.23849486052095925
New rep -> is CMB: True
features norm: 3.162221908569336
reward range: -22.551109313964844, 22.437606811523438


In [98]:
# Candidate architectures, each given as a list of (width, activation) pairs.
NEW_LAYERS = [
    layers,  # same architecture as the ground-truth network
    [(300, nn.Tanh())],
    [(10, nn.Tanh()), (10, nn.Tanh())],
]

for L in NEW_LAYERS:
    # A new network with this architecture, initialised with the custom scheme;
    # no training happens in this cell.
    net2 = MLLinearNetwork(dim, L)
    initialize_weights(net2)
    print(net2)

    # Embedding and output of the new network, grouped per context and arm.
    H = net2.embedding(xt).detach().cpu().numpy().reshape(
        ncontexts, narms, net2.embedding_dim
    )
    R = net2(xt).detach().numpy().reshape(ncontexts, narms)

    # Absolute gap between the ground-truth reward and the new network's output.
    errors = np.absolute(newreward - R)
    print(f"max error: {errors.max()}")
    print(f"min error: {errors.min()}")
    print(f"avg error: {errors.mean()}")

    # HLS diagnostics of the new embedding, measured against the true reward.
    print(f"New rep -> is HLS wrt true reward: {hlsutils.is_hls(H, newreward)}")
    print(f"New rep -> HLS min eig wrt true reward: {hlsutils.hls_lambda(H, newreward)}")
    print("\n\n")

MLLinearNetwork(
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=60, bias=True)
    (1): Tanh()
    (2): Linear(in_features=60, out_features=10, bias=True)
    (3): Tanh()
  )
  (fc2): Linear(in_features=10, out_features=1, bias=False)
)
max error: 40.58506393432617
min error: 0.009131431579589844
avg error: 10.753617286682129
New rep -> is HLS wrt true reward: True
New rep -> HLS min eig wrt true reward: 0.29117509729687524



MLLinearNetwork(
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=300, bias=True)
    (1): Tanh()
  )
  (fc2): Linear(in_features=300, out_features=1, bias=False)
)
max error: 29.50351905822754
min error: 0.012479782104492188
avg error: 9.124445915222168
New rep -> is HLS wrt true reward: False
New rep -> HLS min eig wrt true reward: 0.015731919302286457



MLLinearNetwork(
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): Tanh()
    (2): Linear(in_features=10, out_features=10, b