In [1]:
from policy import *
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

import matplotlib.pyplot as plt

if torch.cuda.is_available():    
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: GeForce RTX 3070


In [2]:
def get_dataloaders(train_dataset, test_dataset, batch_size):
    
    train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size)
    test_dataloader = DataLoader(test_dataset, sampler=SequentialSampler(test_dataset), batch_size=len(test_dataset))
    
    return train_dataloader, test_dataloader

def get_rewards_vector(full_rewards, actions):
    actions_one_hot = F.one_hot(actions.long(), num_classes=full_rewards.size()[1]).float()
    r = torch.matmul(full_rewards.unsqueeze(1), actions_one_hot.unsqueeze(2)).squeeze()
    return r

def snips_loss(pi_w, pi_0, r, lamda=None):
    return torch.mean((1-r) * pi_w / pi_0) / torch.mean(pi_w / pi_0)

def banditnet_loss(pi_w, pi_0, r, lamda):
    return torch.mean(((1-r) - lamda) * (pi_w / pi_0))

In [3]:
from tqdm.auto import tqdm

def train_loop(model, optimizer, train_dataloader, test_dataloader, n_epochs, loss_func, lamda=0.9, model_path='temp'):
    
    train_losses = []
    test_values = []
    test_accuracies = []
    best_value = 0
    
    for t in tqdm(range(n_epochs)):
        # ========================================
        #               Training
        # ========================================
        model.train()
        tol_loss = 0
        
        for i, data in enumerate(train_dataloader):
            
            data = [i.to(device) for i in data]
            X, actions, pi_0, y, full_rewards = data
            
            pi_w = model.get_action_propensities(X, actions)
            r = get_rewards_vector(full_rewards, actions)
            loss = loss_func(pi_w, pi_0, r, lamda=lamda)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            tol_loss += loss
        
        train_losses.append(tol_loss/len(train_dataloader))
        
        # ========================================
        #               Testing
        # ========================================
        model.eval()
        X_test, y_test, full_rewards_test = test_dataloader.dataset.tensors
        X_test = torch.FloatTensor(X_test).to(device)
        full_rewards_test = torch.FloatTensor(full_rewards_test).to(device)
        
        # value
        value = model.get_value_estimate(X_test, full_rewards_test).item()
        
        # deterministic accuracy
        y_pred = torch.argmax(model.get_action_distribution(X_test), dim=1)
        accuracy = (y_pred.cpu().detach().numpy() == y_test.cpu().detach().numpy()).sum() / len(y_test)
        
        test_values.append(value)
        test_accuracies.append(accuracy)        
        
        # check if test value is increasing
        if value > best_value:
            torch.save(model, f'{model_path}.pt')
            best_value = value

    return train_losses, test_values, test_accuracies, best_value

In [4]:
train_dataset = torch.load('../data/train_dataset.pt')
test_dataset = torch.load('../data/test_dataset.pt')

In [6]:
batch_sizes = [64, 128, 256, 512, 1024, 2048, 4096]
lr = 0.1
n_epochs = 50
model_dir = '../models/0/'

X_test, y_test, full_rewards_test = test_dataset.tensors
X_test = torch.FloatTensor(X_test).to(device)
full_rewards_test = torch.FloatTensor(full_rewards_test).to(device)
BN_values = []
BN_losses = []
SNIPS_values = []
SNIPS_losses = []

for batch_size in batch_sizes:
    
    print('Batch Size: ', batch_size)
    
    # get dataloaders
    train_dataloader, test_dataloader = get_dataloaders(train_dataset, test_dataset, batch_size=64)
    
    #===========================
    #   BanditNet
    #===========================
    loss_func = banditnet_loss
    model_name = model_dir + f'{batch_size}-BanditNet'
    
    model = LogisticPolicy(num_actions=26, num_features=16)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    BN_results = train_loop(model, optimizer, train_dataloader, test_dataloader, n_epochs=n_epochs, 
                            loss_func=loss_func, model_path=model_name)
    
    ## load best model
    model = torch.load(f'{model_name}.pt')
    model = model.to(device)
    ## value
    value = model.get_value_estimate(X_test, full_rewards_test).item()
    
    ## snips loss
    pi_w = model.get_action_propensities(train_dataset.tensors[0].to(device), train_dataset.tensors[1].to(device))
    r = get_rewards_vector(train_dataset.tensors[-1].to(device), train_dataset.tensors[1].to(device))
    loss = snips_loss(pi_w, train_dataset.tensors[2].to(device), r).item()
    
    BN_values.append(value)
    BN_losses.append(loss)
    
    print('BanditNet Value: ', value)
    print('BanditNet Loss: ', loss)
    
    #===========================
    #   SNIPS
    #===========================
    loss_func = snips_loss
    model_name = model_dir + f'{batch_size}-SNIPS'
    
    model = LogisticPolicy(num_actions=26, num_features=16)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    SNIPS_results = train_loop(model, optimizer, train_dataloader, test_dataloader, n_epochs=n_epochs, 
                               loss_func=loss_func, model_path=model_name)
    
    ## load best model
    model = torch.load(f'{model_name}.pt')
    model = model.to(device)
    ## value
    value = model.get_value_estimate(X_test, full_rewards_test).item()
    
    ## snips loss
    pi_w = model.get_action_propensities(train_dataset.tensors[0].to(device), train_dataset.tensors[1].to(device))
    r = get_rewards_vector(train_dataset.tensors[-1].to(device), train_dataset.tensors[1].to(device))
    loss = snips_loss(pi_w, train_dataset.tensors[2].to(device), r).item()
    
    SNIPS_values.append(value)
    SNIPS_losses.append(loss)
    
    print('SNIPS Value: ', value)
    print('SNIPS Loss: ', loss)
    print('='*50)
    print('\n\n')

    
# Save results to a DataFrame
results_df = pd.DataFrame({
    'Batch_Size': batch_sizes,
    'BanditNet_Value': BN_values,
    'BanditNet_Loss': BN_losses,
    'SNIPS_Value': SNIPS_values,
    'SNIPS_Loss': SNIPS_losses
}) 

Batch Size:  64


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.07261142134666443
BanditNet Loss:  0.9180900454521179


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.08634651452302933
SNIPS Loss:  0.8843541741371155



Batch Size:  128


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.0931156575679779
BanditNet Loss:  0.8974903225898743


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.08844666182994843
SNIPS Loss:  0.8934233784675598



Batch Size:  256


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.16151724755764008
BanditNet Loss:  0.8295487761497498


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.11497370153665543
SNIPS Loss:  0.8427345156669617



Batch Size:  512


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.12919048964977264
BanditNet Loss:  0.788617730140686


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.08611742407083511
SNIPS Loss:  0.8586195111274719



Batch Size:  1024


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.20250867307186127
BanditNet Loss:  0.7358907461166382


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.1266627311706543
SNIPS Loss:  0.8609011769294739



Batch Size:  2048


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.2068261206150055
BanditNet Loss:  0.7115663886070251


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.09128447622060776
SNIPS Loss:  0.8904736042022705



Batch Size:  4096


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.10231105983257294
BanditNet Loss:  0.8612636923789978


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.13666611909866333
SNIPS Loss:  0.8572790026664734





In [7]:
results_df

Unnamed: 0,Batch_Size,BanditNet_Value,BanditNet_Loss,SNIPS_Value,SNIPS_Loss
0,64,0.072611,0.91809,0.086347,0.884354
1,128,0.093116,0.89749,0.088447,0.893423
2,256,0.161517,0.829549,0.114974,0.842735
3,512,0.12919,0.788618,0.086117,0.85862
4,1024,0.202509,0.735891,0.126663,0.860901
5,2048,0.206826,0.711566,0.091284,0.890474
6,4096,0.102311,0.861264,0.136666,0.857279


In [8]:
results_df.to_csv('../models/0/results.csv', index=False)

In [9]:
batch_sizes = [64, 128, 256, 512, 1024, 2048, 4096]
lr = 0.1
n_epochs = 50
model_dir = '../models/1/'

X_test, y_test, full_rewards_test = test_dataset.tensors
X_test = torch.FloatTensor(X_test).to(device)
full_rewards_test = torch.FloatTensor(full_rewards_test).to(device)
BN_values = []
BN_losses = []
SNIPS_values = []
SNIPS_losses = []

for batch_size in batch_sizes:
    
    print('Batch Size: ', batch_size)
    
    # get dataloaders
    train_dataloader, test_dataloader = get_dataloaders(train_dataset, test_dataset, batch_size=64)
    
    #===========================
    #   BanditNet
    #===========================
    loss_func = banditnet_loss
    model_name = model_dir + f'{batch_size}-BanditNet'
    
    model = LogisticPolicy(num_actions=26, num_features=16)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    BN_results = train_loop(model, optimizer, train_dataloader, test_dataloader, n_epochs=n_epochs, 
                            loss_func=loss_func, model_path=model_name)
    
    ## load best model
    model = torch.load(f'{model_name}.pt')
    model = model.to(device)
    ## value
    value = model.get_value_estimate(X_test, full_rewards_test).item()
    
    ## snips loss
    pi_w = model.get_action_propensities(train_dataset.tensors[0].to(device), train_dataset.tensors[1].to(device))
    r = get_rewards_vector(train_dataset.tensors[-1].to(device), train_dataset.tensors[1].to(device))
    loss = snips_loss(pi_w, train_dataset.tensors[2].to(device), r).item()
    
    BN_values.append(value)
    BN_losses.append(loss)
    
    print('BanditNet Value: ', value)
    print('BanditNet Loss: ', loss)
    
    #===========================
    #   SNIPS
    #===========================
    loss_func = snips_loss
    model_name = model_dir + f'{batch_size}-SNIPS'
    
    model = LogisticPolicy(num_actions=26, num_features=16)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    SNIPS_results = train_loop(model, optimizer, train_dataloader, test_dataloader, n_epochs=n_epochs, 
                               loss_func=loss_func, model_path=model_name)
    
    ## load best model
    model = torch.load(f'{model_name}.pt')
    model = model.to(device)
    ## value
    value = model.get_value_estimate(X_test, full_rewards_test).item()
    
    ## snips loss
    pi_w = model.get_action_propensities(train_dataset.tensors[0].to(device), train_dataset.tensors[1].to(device))
    r = get_rewards_vector(train_dataset.tensors[-1].to(device), train_dataset.tensors[1].to(device))
    loss = snips_loss(pi_w, train_dataset.tensors[2].to(device), r).item()
    
    SNIPS_values.append(value)
    SNIPS_losses.append(loss)
    
    print('SNIPS Value: ', value)
    print('SNIPS Loss: ', loss)
    print('='*50)
    print('\n\n')

    
# Save results to a DataFrame
results_df = pd.DataFrame({
    'Batch_Size': batch_sizes,
    'BanditNet_Value': BN_values,
    'BanditNet_Loss': BN_losses,
    'SNIPS_Value': SNIPS_values,
    'SNIPS_Loss': SNIPS_losses
}) 

Batch Size:  64


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.20689363777637482
BanditNet Loss:  0.7416794300079346


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.10502137243747711
SNIPS Loss:  0.8610108494758606



Batch Size:  128


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.19380448758602142
BanditNet Loss:  0.780174195766449


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.12448729574680328
SNIPS Loss:  0.8636963367462158



Batch Size:  256


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.12684370577335358
BanditNet Loss:  0.8688867092132568


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.11913691461086273
SNIPS Loss:  0.8620232939720154



Batch Size:  512


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.18853257596492767
BanditNet Loss:  0.780601441860199


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.15687847137451172
SNIPS Loss:  0.7868620753288269



Batch Size:  1024


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.1756957620382309
BanditNet Loss:  0.7815707921981812


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.10975830256938934
SNIPS Loss:  0.8945661783218384



Batch Size:  2048


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.22271673381328583
BanditNet Loss:  0.7290008068084717


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.10528328269720078
SNIPS Loss:  0.9141712784767151



Batch Size:  4096


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.21021346747875214
BanditNet Loss:  0.7174968123435974


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.1324721723794937
SNIPS Loss:  0.858963131904602





In [10]:
results_df

Unnamed: 0,Batch_Size,BanditNet_Value,BanditNet_Loss,SNIPS_Value,SNIPS_Loss
0,64,0.206894,0.741679,0.105021,0.861011
1,128,0.193804,0.780174,0.124487,0.863696
2,256,0.126844,0.868887,0.119137,0.862023
3,512,0.188533,0.780601,0.156878,0.786862
4,1024,0.175696,0.781571,0.109758,0.894566
5,2048,0.222717,0.729001,0.105283,0.914171
6,4096,0.210213,0.717497,0.132472,0.858963


In [11]:
results_df.to_csv('../models/1/results.csv', index=False)

In [12]:
batch_sizes = [64, 128, 256, 512, 1024, 2048, 4096]
lr = 0.1
n_epochs = 50
model_dir = '../models/2/'

X_test, y_test, full_rewards_test = test_dataset.tensors
X_test = torch.FloatTensor(X_test).to(device)
full_rewards_test = torch.FloatTensor(full_rewards_test).to(device)
BN_values = []
BN_losses = []
SNIPS_values = []
SNIPS_losses = []

for batch_size in batch_sizes:
    
    print('Batch Size: ', batch_size)
    
    # get dataloaders
    train_dataloader, test_dataloader = get_dataloaders(train_dataset, test_dataset, batch_size=64)
    
    #===========================
    #   BanditNet
    #===========================
    loss_func = banditnet_loss
    model_name = model_dir + f'{batch_size}-BanditNet'
    
    model = LogisticPolicy(num_actions=26, num_features=16)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    BN_results = train_loop(model, optimizer, train_dataloader, test_dataloader, n_epochs=n_epochs, 
                            loss_func=loss_func, model_path=model_name)
    
    ## load best model
    model = torch.load(f'{model_name}.pt')
    model = model.to(device)
    ## value
    value = model.get_value_estimate(X_test, full_rewards_test).item()
    
    ## snips loss
    pi_w = model.get_action_propensities(train_dataset.tensors[0].to(device), train_dataset.tensors[1].to(device))
    r = get_rewards_vector(train_dataset.tensors[-1].to(device), train_dataset.tensors[1].to(device))
    loss = snips_loss(pi_w, train_dataset.tensors[2].to(device), r).item()
    
    BN_values.append(value)
    BN_losses.append(loss)
    
    print('BanditNet Value: ', value)
    print('BanditNet Loss: ', loss)
    
    #===========================
    #   SNIPS
    #===========================
    loss_func = snips_loss
    model_name = model_dir + f'{batch_size}-SNIPS'
    
    model = LogisticPolicy(num_actions=26, num_features=16)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    SNIPS_results = train_loop(model, optimizer, train_dataloader, test_dataloader, n_epochs=n_epochs, 
                               loss_func=loss_func, model_path=model_name)
    
    ## load best model
    model = torch.load(f'{model_name}.pt')
    model = model.to(device)
    ## value
    value = model.get_value_estimate(X_test, full_rewards_test).item()
    
    ## snips loss
    pi_w = model.get_action_propensities(train_dataset.tensors[0].to(device), train_dataset.tensors[1].to(device))
    r = get_rewards_vector(train_dataset.tensors[-1].to(device), train_dataset.tensors[1].to(device))
    loss = snips_loss(pi_w, train_dataset.tensors[2].to(device), r).item()
    
    SNIPS_values.append(value)
    SNIPS_losses.append(loss)
    
    print('SNIPS Value: ', value)
    print('SNIPS Loss: ', loss)
    print('='*50)
    print('\n\n')

    
# Save results to a DataFrame
results_df = pd.DataFrame({
    'Batch_Size': batch_sizes,
    'BanditNet_Value': BN_values,
    'BanditNet_Loss': BN_losses,
    'SNIPS_Value': SNIPS_values,
    'SNIPS_Loss': SNIPS_losses
}) 

Batch Size:  64


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.2172599881887436
BanditNet Loss:  0.7124348282814026


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.10756157338619232
SNIPS Loss:  0.8722429275512695



Batch Size:  128


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.10906437039375305
BanditNet Loss:  0.881159782409668


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.11805941164493561
SNIPS Loss:  0.8896043300628662



Batch Size:  256


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.10586566478013992
BanditNet Loss:  0.8666161894798279


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.14123770594596863
SNIPS Loss:  0.8656289577484131



Batch Size:  512


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.19014452397823334
BanditNet Loss:  0.7490904331207275


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.07949547469615936
SNIPS Loss:  0.9132755994796753



Batch Size:  1024


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.21061521768569946
BanditNet Loss:  0.7350210547447205


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.12964075803756714
SNIPS Loss:  0.8366609215736389



Batch Size:  2048


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.1894395798444748
BanditNet Loss:  0.7639780640602112


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.10555371642112732
SNIPS Loss:  0.8914798498153687



Batch Size:  4096


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.19890069961547852
BanditNet Loss:  0.7532165050506592


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.10635340213775635
SNIPS Loss:  0.8741147518157959





In [13]:
results_df

Unnamed: 0,Batch_Size,BanditNet_Value,BanditNet_Loss,SNIPS_Value,SNIPS_Loss
0,64,0.21726,0.712435,0.107562,0.872243
1,128,0.109064,0.88116,0.118059,0.889604
2,256,0.105866,0.866616,0.141238,0.865629
3,512,0.190145,0.74909,0.079495,0.913276
4,1024,0.210615,0.735021,0.129641,0.836661
5,2048,0.18944,0.763978,0.105554,0.89148
6,4096,0.198901,0.753217,0.106353,0.874115


In [14]:
results_df.to_csv('../models/2/results.csv', index=False)