In [1]:
import torch
import numpy as np

In [2]:
from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm

from eval import get_run_metrics, read_run_dir, get_model_from_run
from plot_utils import basic_plot, collect_results, relevant_model_names

%matplotlib inline
%load_ext autoreload
%autoreload 2

sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')

run_dir = "../models"

In [3]:
devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
device = devices[1] # use GPU #1 on the machine

### In-Context Learning (ICL) for Transformers
A Transformer $T$ takes in a sequence and generate a prediction. In our case, each sequence represents a learning task with varying underlying oracle functions.

Suppose we have a Tranformer $T$ trained with in-context-learning objective on linear functions, i.e., each sequence is generated as follows:
1. $w \sim \mathcal{N}(0,I)$
2. For $i= 1, \dots, n$:
       $x_i \sim \mathcal{N}(0,I), y_i = w^{\top}x_i$
The Transformers $T$ is trained to take in the sequence $\{x_1,y_1,x_2,y_2,\dots,x_n\}$ and predicts $y_n$.

### Transferring the Knowledge within the Transformers

We consider the **offline** contextual bandit problem where we are given an offline dataset $D=\{(s_i,a_i,r_i)\}_{i=1}^n$. The actions $a_i$'s were collected by an unknown behavirol policy. Each task is characterized by a state distribution $P_S$, and a reward function 
$r = f(s,a)$. We first focus on the linear contextual bandit problem with
$$
r = f(s,a) =\theta^{*\top} \phi(s,a)
$$
where $\phi(s,a):\mathcal{S} \times \mathcal{A} \rightarrow \mathbb{R}^d$ is the feature function. 

#### Comparison with the regular MDP setting
In the offline RL setting considered by the Decision Transformer, we have a set of trajectories $D=\{\tau_j = (s_0,a_0,r_0,s_1,a_1,r_1,\dots)\}_{j=1}^n$

### Loading Trained Transformer

The trained Transformer has input dimension $20$ and generate scalar ouputs.

In [4]:
task = "linear_regression"
#task = "sparse_linear_regression"
#task = "decision_tree"
#task = "relu_2nn_regression"

run_id = "pretrained"  # if you train more models, replace with the run_id from the table above
run_path = os.path.join(run_dir, task, run_id)

model, conf = get_model_from_run(run_path)
model = model.to(device)

In [5]:
xs = torch.randn(10,40,20).to(device) # randomly generate 10 ICL tasks, each one with 40 instances and X has dimension 20
ys = torch.randn(10,40).to(device) # randomly generate regression labels 

In [6]:
_,pred = model(xs,ys)

In [7]:
pred.shape

torch.Size([1, 40])

### ICL for Contextual Bandits

We consider the case where $\phi$ is a linear function, i.e., 
$$
\begin{align}
&\phi(s,a) = As + Ba,\\
& r = \theta^{*\top}As + \theta^{*\top}Ba
\end{align}
$$
Another reward function choice can be 
$$
r = s^{\top}Aa
$$
Assume that $\theta^* \sim \mathcal{N}(0,I)$ and $A,B \sim \mathcal{N}(0,1)$ elementwise. Different tuples of $(\theta^*, A, B)$ defines different contextual bandits problems.

The expected return of an action $a$ is 
$$
R(a) = \mathbb{E}[r|a] = \mathbb{E}[\theta^{*\top}\phi(s,a)] = \theta^{*\top}Ba.
$$

In [5]:
import torch.nn as nn

### Using varying bandit problems for training and testing

In [85]:
# Generate the Offline dataset with varying contextual bandits problems
S_DIM = 5 # dimension of states
A_NUM = 5 # total number of actions (bandits)

def reward_function(s,a,theta,Ms,Ma):
    '''
    Linear reward function
    '''
    r = theta@Ms@s + theta@Ma@a
    return r

def generate_ICL_seqs(num_ICL, num_instances, s_dim=20, a_num = 100, optimal_scale=0.8):
    # Ss, As, Rs = [], [], []
    ICL_seqs = []
    for _ in tqdm(range(num_ICL)):
        # generate a new contextual bandits problem
        theta = torch.randn(20)*0.1
        Ms = torch.randn(20,s_dim)*0.1
        Ma = torch.randn(20,a_num)*0.1
        # find optimal action and define behaviral policy 
        optimal_action = torch.argmax(theta@Ma).item()
        if optimal_scale:
            probs = torch.tensor([(1-optimal_scale)/a_num]*a_num)
            probs[optimal_action] += optimal_scale
        else: # if no optimal_scale, use a uniform policy as behaviral policy 
            probs = torch.tensor([1/a_num]*a_num)
        b_policy = torch.distributions.categorical.Categorical(probs)

        D = []
        for i in range(num_instances):
            new_s = torch.randn(s_dim) # sample a new state
            new_a = b_policy.sample()  # perform an action with behaviral policy 
            new_a = nn.functional.one_hot(new_a,num_classes=a_num).float()
            new_r = reward_function(new_s,new_a,theta, Ms,Ma) # receive a reward 
            D.append((new_s,new_a,new_r))
        S, A, R = torch.stack([t[0] for t in D],dim=0), torch.stack([t[1] for t in D],dim=0), torch.stack([t[2] for t in D],dim=0)

        ICL_seqs.append((S,A,R,nn.functional.one_hot(torch.tensor(optimal_action),num_classes=a_num).float()))
    return ICL_seqs

In [52]:
# Generate the Offline dataset
# Number of actions: 100
# Distribution of states:Gaussian(0,I_20)
# Behaviral policy: uniformly random policy
s_dim = 20
a_num = 100

theta = torch.randn(20)
Ms = torch.randn(20,s_dim)
Ma = torch.randn(20,a_num)
def reward_function(s,a):
    '''
    Linear reward function
    '''
    r = theta@Ms@s + theta@Ma@a
    return r

def generate_offline_dataset(num,s_dim,a_num):
    uniform_policy = torch.distributions.categorical.Categorical(torch.tensor([1/a_num]*a_num))
    dataset = []
    for _ in range(num):
        new_s = torch.randn(s_dim)
        new_a = uniform_policy.sample()
        new_a = nn.functional.one_hot(new_a,num_classes=a_num).float()
        new_r = reward_function(new_s,new_a)
        dataset.append((new_s,new_a,new_r))
    return dataset


In [53]:
# Oracle Best Bandit
print('Best bandit index:', torch.argmax(theta@Ma))

Best bandit index: tensor(88)


In [54]:
D = generate_offline_dataset(num=1000,s_dim=s_dim,a_num=a_num)
S, A, R = torch.stack([t[0] for t in D],dim=0), torch.stack([t[1] for t in D],dim=0), torch.stack([t[2] for t in D],dim=0)
S.shape, A.shape, R.shape

(torch.Size([1000, 20]), torch.Size([1000, 100]), torch.Size([1000]))

In [55]:
# For the given n data instances, each time we randomly sample m out of them, we repeat this for k times to construct k ICL sequences
def generate_ICL_sequences(S,A,R,num_ICL,num_instances):
    total_num = S.shape[0]
    ICL_sequences = []
    for _ in range(num_ICL):
        indices = np.random.choice(range(total_num),num_instances,replace=False)
        ICL_sequences.append((S[indices],A[indices],R[indices]))
    return ICL_sequences

In [56]:
ICLs = generate_ICL_sequences(S,A,R,2000,50)
ICL_X = torch.cat([torch.stack([t[0] for t in ICLs],dim=0),torch.stack([t[2] for t in ICLs],dim=0).unsqueeze(2)],dim=-1)
ICL_Y = torch.stack([t[1] for t in ICLs],dim=0)
ICL_X.shape, ICL_Y.shape

(torch.Size([2000, 50, 21]), torch.Size([2000, 50, 100]))

### Dataset Generation

Decision Transformer (DT)-based methods address RL as sequence modeling problem, at time step $t$, the input is the realized trajectory so far $\tau_t = (s_1,a_1,r_1,s_2,a_2,r_2,\dots,s_t)$ and a signal called returns-to-go (RTG) $R_t$ which represents the rewards to receive, the ouput $a_t$ is the predicted act that can lead to $R_t$.

To frame offline contextual bandit as ICL problem, we would like construct a sequence of $(X,y)$'s and an individual $X'$ whose $y$ is the target for prediction. Recall the offline dataset for contextual bandits $D=\{S_i,A_i,R_i\}$. To this end, we let $X = [S,R]$ and $Y = A$.

In [78]:
from torch.utils.data import Dataset, DataLoader

In [87]:
num_ICL = 10000 # generate 50k contextual bandits problems
num_instances = 100 # each offline dataset has 100 samples

# generate_ICL_seqs will return offline datasets for num_ICL different tasks
# [ICL_1, ICL_2, ICL_3, ..., ICL_{num_ICL}]
# where ICL_i = [S, A, R]
# S has dimension (num_instances x s_dim)
# A has dimension (num_instances x a_num)
# R has dimension (num_instances x 1) 

ICL_seqs = generate_ICL_seqs(num_ICL, num_instances, s_dim = S_DIM, a_num = A_NUM, optimal_scale=0.7)

# X = [S,R], Y = A
ICL_X = torch.cat([torch.stack([t[0] for t in ICL_seqs],dim=0),torch.stack([t[2] for t in ICL_seqs],dim=0).unsqueeze(2)],dim=-1)
ICL_Y = torch.stack([t[1] for t in ICL_seqs],dim=0)
ICL_Y_opt = torch.stack([t[3] for t in ICL_seqs],dim=0)

ICL_X.shape, ICL_Y.shape

  0%|          | 0/10000 [00:00<?, ?it/s]

(torch.Size([10000, 100, 6]), torch.Size([10000, 100, 5]))

In [88]:
class ConBanDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y
    
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        ICL_X = self.X[idx]
        ICL_Y = self.Y[idx]
        
        return ICL_X, ICL_Y

In [120]:
batch_num_ICL = 32 # each batch has 32 contextual bandits problems
tr_dataset = ConBanDataset(ICL_X, ICL_Y)
tr_loader = DataLoader(tr_dataset, batch_num_ICL, shuffle=True)

## Train a TF from Scratch

In [9]:
from transformers import GPT2Model, GPT2Config
from models import TransformerModel

In [62]:
class ResNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden=5):
        super(ResNet, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(input_dim, hidden_dim)] \
                                     +[nn.Linear(hidden_dim, hidden_dim) for i in range(num_hidden)]\
                                     +[nn.Linear(hidden_dim, output_dim)])
        self.relu = nn.ReLU()
    
    def forward(self,x):
        x = self.linears[0](x)
        for linear in self.linears[1:-1]:
            x = x + linear(x)
            x = self.relu(x)
        x = self.linears[-1](x)
        return x

class ICL_learner(nn.Module):
    def __init__(self,dim_states, n_action, n_positions=501, n_embd=256, n_layer=12, n_head=8):
        super(ICL_learner, self).__init__()
        # Backbone is a GPT2 model
        self._backbone = TransformerModel(n_dims = dim_states+1, n_positions=n_positions, n_embd=n_embd, n_layer=n_layer, n_head=n_head)
        self._action_encoder = ResNet(n_action,hidden_dim=128,output_dim = dim_states+1, num_hidden=3)
        self._readout = nn.Linear(n_embd,n_action)
    def forward(self,raw_x, raw_y):
        
        c_x = raw_x 
        c_y = self._action_encoder(raw_y)
        bsize, points, dim = c_x.shape
        
        zs = torch.stack([c_x,c_y],dim=2)
        zs = zs.view(bsize,2*points,dim)
        # print(zs.shape)

        # embed = model._read_in(zs)
        embed = self._backbone._read_in(zs)
        output = self._backbone._backbone(inputs_embeds = embed).last_hidden_state
        # output = self._backbone(zs).last_hidden_state
        pred = self._readout(output)
        # pred = self.softmax(pred)
        
        return output, pred[:,::2]


In [119]:
adapter = ICL_learner(S_DIM, A_NUM).to(device)

In [121]:
## Training 
loss = nn.CrossEntropyLoss() # cross-entropy loss for bandit identification
optimizer = torch.optim.Adam([param for param in adapter.parameters() if param.requires_grad == True],lr=1e-4)

EPOCHS = 10
for epoch in tqdm(range(EPOCHS)):
    print(f'EPOCH {epoch}...')
    for batch_idx, (batch_x,batch_y) in enumerate(tr_loader):
        optimizer.zero_grad()
        # shuffle the order of ICL instances 
        # shuffle_idx = torch.randperm(50)
        # batch_x = batch_x[:,shuffle_idx,:]
        # batch_y = batch_y[:,shuffle_idx,:]
        _,pred = adapter(batch_x.to(device),batch_y.to(device))
        # pred_bandit = torch.argmax(pred,dim=-1)
        l = loss(pred[:,:,:].softmax(dim=-1), batch_y[:,:,:].to(device))
        l.backward()
       
        optimizer.step()

        # report every 50 steps
        if batch_idx% 50 == 0: 
            # print(torch.norm(adapter.linear_x.weight.grad))
            # print('selected bandits:',pred_bandit)
            print("loss:",l.item())

  0%|          | 0/10 [00:00<?, ?it/s]

EPOCH 0...
loss: 92.16751861572266
loss: 92.103271484375
loss: 92.10306549072266
loss: 92.10340881347656
loss: 92.10343933105469
loss: 92.10356903076172
loss: 92.10282135009766
EPOCH 1...
loss: 92.103515625
loss: 92.1039047241211
loss: 92.0979995727539
loss: 92.02251434326172
loss: 92.1231689453125
loss: 92.1375732421875
loss: 92.0184326171875
EPOCH 2...
loss: 92.04609680175781
loss: 92.00853729248047
loss: 92.0137710571289
loss: 91.99898529052734
loss: 92.08428955078125
loss: 91.97474670410156
loss: 91.82493591308594
EPOCH 3...
loss: 91.9023666381836
loss: 91.59791564941406
loss: 91.93990325927734
loss: 91.5851058959961
loss: 91.07327270507812
loss: 91.07305145263672
loss: 90.59034729003906
EPOCH 4...
loss: 90.48949432373047
loss: 90.47530364990234
loss: 90.20228576660156
loss: 90.1248779296875
loss: 89.96210479736328
loss: 90.10002899169922
loss: 90.1965560913086
EPOCH 5...
loss: 89.99669647216797
loss: 90.02967834472656
loss: 90.0966567993164
loss: 90.10198211669922
loss: 89.8756103

In [122]:
torch.save(adapter, '../models/ICL_tfs_10000.pth')

#### Try setting all the target actions are optimal actions

In [115]:
batch_num_ICL = 32 # each batch has 32 contextual bandits problems
tr_dataset = ConBanDataset(ICL_X, torch.cat([ICL_Y_opt.unsqueeze(1)]*num_instances,dim=1))
tr_loader = DataLoader(tr_dataset, batch_num_ICL, shuffle=True)

In [None]:
## Training 
loss = nn.CrossEntropyLoss() # cross-entropy loss for bandit identification
optimizer = torch.optim.Adam([param for param in adapter.parameters() if param.requires_grad == True],lr=1e-4)

EPOCHS = 2
for epoch in tqdm(range(EPOCHS)):
    print(f'EPOCH {epoch}...')
    for batch_idx, (batch_x,batch_y) in enumerate(tr_loader):
        optimizer.zero_grad()
        # shuffle the order of ICL instances 
        # shuffle_idx = torch.randperm(50)
        # batch_x = batch_x[:,shuffle_idx,:]
        # batch_y = batch_y[:,shuffle_idx,:]
        _,pred = adapter(batch_x.to(device),batch_y.to(device))
        # pred_bandit = torch.argmax(pred,dim=-1)
        l = loss(pred[:,:,:].softmax(dim=-1), batch_y[:,:,:].to(device))
        l.backward()
       
        optimizer.step()

        # report every 200 steps
        if batch_idx%50 == 0: 
            # print(torch.norm(adapter.linear_x.weight.grad))
            # print('selected bandits:',pred_bandit)
            print("loss:",l.item())

## Adapt From A Trained ICL Model

In [7]:
class ICL_adapter(nn.Module):
    def __init__(self, base_model, x_dim, y_dim,pred_dim,hidden_dim=256):
        # Note that pred_dim should be equal to y_dim
        super(ICL_adapter, self).__init__()
        self.linear_x = nn.Linear(x_dim,hidden_dim)
        self.linear_y = nn.Linear(y_dim,hidden_dim)
        self.converter_x = ResNet(hidden_dim, hidden_dim, conf.model.n_dims).to(device) #Input dim to trained TF: conf.model.n_dims = 20
        self.converter_y = ResNet(hidden_dim, hidden_dim, conf.model.n_dims).to(device)
        self.base_model = base_model
        # self.readout = ResNet(256,256,pred_dim,num_hidden=3)
        self.readout = nn.Linear(256,pred_dim)
        # self.softmax = nn.Softmax(dim=-1)
       

    def forward(self,raw_x, raw_y):
        ## c_x and c_y needs to have the same shape after converter mapping!
        c_x, c_y = self.linear_x(raw_x), self.linear_y(raw_y)
        # non-linear transformation
        c_x, c_y = self.converter_x(c_x),self.converter_y(c_y)

        bsize, points, dim = c_x.shape
        zs = torch.stack([c_x,c_y],dim=2)
        zs = zs.view(bsize,2*points,dim)

        embed = model._read_in(zs)
        output = model._backbone(inputs_embeds = embed).last_hidden_state
        pred = self.readout(output)
        # pred = self.softmax(pred)
        
        return output, pred[:,::2]

In [132]:
# Initialize with Trained Transformers
model, conf = get_model_from_run(run_path)
model = model.to(device)

In [133]:
# Freeze the trained ICL model
for param in model.parameters():
    param.requires_grad = False
adapter = ICL_adapter(model, x_dim = s_dim + 1, y_dim = a_num, pred_dim=a_num,hidden_dim=256).to(device)

In [None]:
## Training 
loss = nn.CrossEntropyLoss() # cross-entropy loss for bandit identification
optimizer = torch.optim.Adam([param for param in adapter.parameters() if param.requires_grad == True],lr=1e-4)

EPOCHS = 10
for epoch in tqdm(range(EPOCHS)):
    print(f'EPOCH {epoch}...')
    for batch_idx, (batch_x,batch_y) in enumerate(tr_loader):
        optimizer.zero_grad()
        # shuffle the order of ICL instances 
        # shuffle_idx = torch.randperm(50)
        # batch_x = batch_x[:,shuffle_idx,:]
        # batch_y = batch_y[:,shuffle_idx,:]
        _,pred = adapter(batch_x.to(device),batch_y.to(device))
        # pred_bandit = torch.argmax(pred,dim=-1)
        l = loss(pred[:,:,:].softmax(dim=-1), batch_y[:,:,:].to(device))
        l.backward()
       
        optimizer.step()

        # report every 200 steps
        if batch_idx%50 == 0: 
            # print(torch.norm(adapter.linear_x.weight.grad))
            # print('selected bandits:',pred_bandit)
            print("loss:",l.item())

In [117]:
torch.save(adapter, '../models/ICL_adapter_10000.pth')

### Testing

In [75]:
def test_model(model,input_r=None, horizon=100):
    # generate a new contextual bandit instance
    theta = torch.randn(20)*0.1
    Ms = torch.randn(20,s_dim)*0.1
    Ma = torch.randn(20,a_num)*0.1
    
    # calculate the expected return of all the bandits 
    bandit_values = theta@Ma
    optimal_value = max(bandit_values).item()
    worst_value = min(bandit_values).item()
    # print('best bandit:', torch.argmax(bandit_values).item())
    # print('best reward:', optimal_value)
    if input_r is None:
        input_r = optimal_value
    
    
    selected_bandits = []
    s=[]
    new_a = torch.distributions.categorical.Categorical(torch.tensor([1/a_num]*a_num)).sample()
    new_a = nn.functional.one_hot(new_a,num_classes=a_num).float().to(device)
    a=[new_a] #dummy first action
    r=[] 
    
    for _ in range(horizon):
        # encounter a new state
        new_s = torch.randn(s_dim)
        s.append(new_s)
        if len(r) > 0:
            r[-1] = new_r
        r.append(torch.tensor(input_r))
        # construct inputs
        r = r[-50:]
        s = s[-50:]
        a = a[-50:]
        input_x = torch.cat([torch.stack(s,dim=0),torch.stack(r,dim=0).unsqueeze(-1)],dim=-1).unsqueeze(0)
        input_y = torch.stack(a,dim=0).unsqueeze(0)
        # Predict bandits
        _, pred = model(input_x.to(device),input_y.to(device))
        selected_bandit = torch.argmax(pred[-1,-1,:])
        selected_bandits.append(selected_bandit)
        a.append(nn.functional.one_hot(selected_bandit,num_classes=a_num).float().to(device))
        # new_r = reward_function(new_s,a[-1].detach().cpu(),theta,Ms,Ma)
        new_r = theta@Ma@(a[-1].detach().cpu())
        # print(selected_bandit)
        
    return bandit_values, optimal_value, worst_value, selected_bandits

def rank(avg_r, r_list):
    for i in range(len(r_list)):
        if avg_r < r_list[i]:
            break
    return i

In [72]:
bandit_values,optimal_value, worst_value, bandits = test_model(adapter,horizon=150)
print('avg return:', np.average([bandit_values[b.detach().cpu().item()].item() for b in bandits[50:]]))
print('best return:', optimal_value)
print('worst return:', worst_value)

avg return: 0.012679548934102058
best return: 0.05268328636884689
worst return: -0.05264483764767647


In [73]:
print(torch.sort(bandit_values)[0])
print([b.item() for b in bandits])

tensor([-0.0526, -0.0382, -0.0016,  0.0127,  0.0527])
[4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]


In [None]:
ranks = []
regrets = []
opt_values = []
worst_values = []
selected_bandits = []

for _ in tqdm(range(10)):
    bandit_values,optimal_value, worst_value, bandits = test_model(adapter,horizon=150)
    avg_return = np.average([bandit_values[b.detach().cpu().item()].item() for b in bandits[100:]])
    regrets.append(optimal_value-avg_return)
    opt_values.append(optimal_value)
    worst_values.append(worst_value)
    selected_bandits.append([b.item() for b in bandits])
    ranks.append(rank(avg_return, torch.sort(bandit_values)[0])/5)
    # print([b.item() for b in bandits])

  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
# compute the performance metric: regret/(optimal_value - worst_value)
# 0 is best, 1 is worst
np.array(regrets)/(np.array(opt_values) - np.array(worst_values))

## Legacy Code (Ignore)

In [242]:
## Training 
loss = nn.CrossEntropyLoss() # cross-entropy loss for bandit identification
optimizer = torch.optim.Adam(adapter.parameters(),lr=1e-4)

EPOCHS = 5000
total_num = ICL_X.shape[0]
batch_size = 8
for epoch in tqdm(range(EPOCHS)):
    optimizer.zero_grad()
    indices = np.random.choice(range(total_num),batch_size,replace=False)
    batch_x = ICL_X[indices]
    batch_y = ICL_Y[indices]

    _,pred = adapter(batch_x.to(device),batch_y.to(device))
    pred_bandit = torch.argmax(pred,dim=-1)
    l = loss(pred[:,:,:].softmax(dim=-1), batch_y[:,:,:].to(device))

    l.backward()
    optimizer.step()

    if epoch%200 == 0:
        # print('selected bandits:',pred_bandit)
        print(l)
    

  0%|          | 0/5000 [00:00<?, ?it/s]

tensor(1.8020, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8295, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8209, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8258, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8092, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8200, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8119, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8075, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8362, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8214, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8175, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8261, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8110, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8223, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8019, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8150, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8008, device='cuda:1', grad_fn=<DivBackward1>)
tensor(1.8072, device='cuda:1', grad_fn=<DivBack

In [243]:
torch.save(adapter, '../models/less_naive_adapter.pth')

In [169]:
def test_model(model,input_r,horizon=100):
    
    selected_bandits = []
    s=[]
    new_a = torch.distributions.categorical.Categorical(torch.tensor([1/a_num]*a_num)).sample()
    new_a = nn.functional.one_hot(new_a,num_classes=a_num).float().to(device)
    a=[new_a] #dummy first action
    r=[] 
    
    for _ in range(horizon):
        # encounter a new state
        new_s = torch.randn(s_dim)
        s.append(new_s)
        r.append(torch.tensor(input_r))
        # construct inputs
        input_x = torch.cat([torch.stack(s,dim=0),torch.stack(r,dim=0).unsqueeze(-1)],dim=-1).unsqueeze(0)
        input_y = torch.stack(a,dim=0).unsqueeze(0)
        # Predict bandits
        _, pred = model(input_x.to(device),input_y.to(device))
        selected_bandit = torch.argmax(pred[-1,-1,:])
        selected_bandits.append(selected_bandit)
        a.append(nn.functional.one_hot(selected_bandit,num_classes=a_num).float().to(device))
        # print(selected_bandit)
        
    return selected_bandits

bandit_values = theta@Ma
bandit_values

tensor([ -5.2011,   1.6482,  -6.3739,  -0.1427,  -6.1827,  -9.7138,  -0.3569,
          6.4348,  -0.3145,  -5.9095,   3.1633,   2.1572,  -1.6072,   0.2380,
         -0.5632,  -1.2888,   7.2966,   2.7758,  -2.2375,  -2.5067,   1.3271,
         -0.0910,  -8.4142,   8.3373, -11.8403,  -2.1347,  -2.4112,   7.2236,
         -2.5696,  -7.2637,  -0.9239,   4.3019,  -0.2706,  -1.3537,  -0.7677,
         -3.6587,   2.4505,  -2.4339,   1.2420,   4.9552,   5.1487, -10.6287,
         -0.6562,   5.0139,   3.0126,   1.6920,  -3.0392,  -9.6241,  -1.6161,
          6.0712,   1.4880,  -2.9165,   1.6696,   0.6303,   1.3941,   4.2367,
          5.7465,   2.1158,   6.4666,   4.7635,   2.8982,  -2.9378,  -7.0493,
         -4.3799,   0.3168,   7.7574,   6.7376,  -4.2825,   5.0983,  -0.3487,
         -3.3717,   5.4457,  -4.9380,   5.0232,   4.0637,  -0.6166,   5.5513,
          4.3126,  -0.7402,  -3.0745,   3.8573,  -2.7212,   9.5248,  -6.1576,
         -1.7650,  -1.7771,   2.0558,  -8.9426,  -8.4776,  -3.47

In [239]:
torch.sort(bandit_values)[0]

tensor([-11.8403, -10.6287,  -9.7138,  -9.6241,  -8.9426,  -8.9091,  -8.4776,
         -8.4142,  -8.2671,  -7.2637,  -7.0493,  -6.3739,  -6.1827,  -6.1576,
         -5.9095,  -5.2011,  -4.9380,  -4.6948,  -4.3799,  -4.2825,  -3.6587,
         -3.4777,  -3.3717,  -3.0745,  -3.0392,  -2.9378,  -2.9165,  -2.7212,
         -2.5696,  -2.5067,  -2.4339,  -2.4112,  -2.2375,  -2.1347,  -1.9844,
         -1.7771,  -1.7650,  -1.6161,  -1.6072,  -1.3537,  -1.2888,  -0.9239,
         -0.7677,  -0.7402,  -0.6562,  -0.6166,  -0.5632,  -0.5455,  -0.3569,
         -0.3487,  -0.3145,  -0.2706,  -0.1427,  -0.0910,   0.2380,   0.3168,
          0.4516,   0.6303,   0.9487,   1.2420,   1.3271,   1.3941,   1.4756,
          1.4880,   1.6482,   1.6696,   1.6920,   1.7206,   2.0558,   2.1158,
          2.1572,   2.2572,   2.4505,   2.7758,   2.8982,   3.0126,   3.1633,
          3.8573,   4.0637,   4.2367,   4.3019,   4.3126,   4.7635,   4.9552,
          5.0139,   5.0232,   5.0983,   5.1487,   5.4457,   5.55

In [246]:
for input_r in [0.,10.,20.,100.,250.,500.]:
    runs = []
    for _ in range(5):
        bandits = test_model(adapter,input_r=input_r,horizon=50)
        # print([b.detach().cpu().item() for b in bandits])
        avg_rewards = np.average([bandit_values[b.detach().cpu().item()] for b in bandits])
        runs.append(avg_rewards)
    print(f'input_r:{input_r}', 'rewards', np.average(runs),np.var(runs))

input_r:0.0 rewards -1.1543026 0.17061827
input_r:10.0 rewards 0.24711378 0.28390676
input_r:20.0 rewards 2.7497773 0.24683237
input_r:100.0 rewards 5.7336683 0.007939255
input_r:250.0 rewards 0.34360862 0.04985754
input_r:500.0 rewards -2.5695734 0.0


In [248]:
R

tensor([ 5.4794e+00,  6.9455e+00, -1.6326e+01, -3.7359e+00,  1.2556e+01,
         2.4965e+01,  8.7280e+00, -1.4675e+01, -1.1836e+01, -9.5921e+00,
         1.8663e+01,  2.5561e+01, -2.2740e+01,  9.0587e-01,  9.7514e-01,
         2.9755e+01, -7.8016e+00,  7.1111e+00, -2.7794e+00,  1.0005e+01,
         1.3468e+01,  2.1822e+00,  4.7645e+00,  7.4162e+00, -1.7447e+01,
        -1.2855e+01,  1.2345e+01, -6.2606e+00,  5.2475e+00, -1.2295e+01,
         1.9056e+01, -1.6313e+01, -5.5074e+01,  4.4581e+01,  1.6872e+01,
         1.5547e+01,  4.0048e+00,  2.5713e+01, -1.5063e+01,  9.1060e+00,
        -3.3305e+00,  2.9580e-01,  1.3879e+01, -1.4045e+01, -7.4136e+00,
         2.2452e+01,  1.3520e+01, -2.2897e+01, -2.7340e+01, -5.0949e+00,
         7.9965e+00,  1.2841e+01, -3.5550e+01, -2.8889e+00,  7.4625e+00,
        -1.1878e+01, -9.6043e+00,  3.5396e+01,  2.1978e+00,  2.1904e+01,
         1.7308e+01,  7.2392e+00,  8.1534e+00, -1.8541e+01,  4.0366e+00,
        -1.4968e+01, -2.2632e+01,  2.2521e+00,  1.2

In [247]:
bandits

[tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28, device='cuda:1'),
 tensor(28