In [1]:
from policy import *
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

import matplotlib.pyplot as plt

if torch.cuda.is_available():    
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: GeForce RTX 3070


In [2]:
def get_dataloaders(train_dataset, test_dataset, batch_size):
    
    train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size)
    test_dataloader = DataLoader(test_dataset, sampler=SequentialSampler(test_dataset), batch_size=len(test_dataset))
    
    return train_dataloader, test_dataloader

def get_rewards_vector(full_rewards, actions):
    actions_one_hot = F.one_hot(actions.long(), num_classes=full_rewards.size()[1]).float()
    r = torch.matmul(full_rewards.unsqueeze(1), actions_one_hot.unsqueeze(2)).squeeze()
    return r

def snips_loss(pi_w, pi_0, r, lamda):
    return torch.mean((1-r) * pi_w / pi_0) / torch.mean(pi_w / pi_0)

def banditnet_loss(pi_w, pi_0, r, lamda):
    return torch.mean(((1-r) - lamda) * (pi_w / pi_0))

In [24]:
from tqdm.auto import tqdm

def train_loop(model, optimizer, train_dataloader, test_dataloader, n_epochs, loss_func, lamda=0.9, model_name='temp'):
    
    train_losses = []
    test_values = []
    test_accuracies = []
    best_value = 0
    
    for t in tqdm(range(n_epochs)):
        # ========================================
        #               Training
        # ========================================
        model.train()
        tol_loss = 0
        
        for i, data in enumerate(train_dataloader):
            
            data = [i.to(device) for i in data]
            X, actions, pi_0, y, full_rewards = data
            
            pi_w = model.get_action_propensities(X, actions)
            r = get_rewards_vector(full_rewards, actions)
            loss = loss_func(pi_w, pi_0, r, lamda=lamda)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            tol_loss += loss
        
        train_losses.append(tol_loss/len(train_dataloader))
        
        # ========================================
        #               Testing
        # ========================================
        model.eval()
        X_test, y_test, full_rewards_test = test_dataloader.dataset.tensors
        X_test = torch.FloatTensor(X_test).to(device)
        full_rewards_test = torch.FloatTensor(full_rewards_test).to(device)
        
        # value
        value = model.get_value_estimate(X_test, full_rewards_test).item()
        
        # deterministic accuracy
        y_pred = torch.argmax(model.get_action_distribution(X_test), dim=1)
        accuracy = (y_pred.cpu().detach().numpy() == y_test.cpu().detach().numpy()).sum() / len(y_test)
        
        test_values.append(value)
        test_accuracies.append(accuracy)        
        
        # check if test value is increasing
        if value > best_value:
            torch.save(model, f'../models/{model_name}.pt')
            best_value = value

    return train_losses, test_values, test_accuracies, best_value

In [25]:
train_dataset = torch.load('../data/train_dataset.pt')
test_dataset = torch.load('../data/test_dataset.pt')

In [28]:
batch_sizes = [64, 128, 256, 512, 1024, 2048, 4096]
lr = 0.1
n_epochs = 50
model_dir = '../models/'

X_test, y_test, full_rewards_test = test_dataset.tensors
X_test = torch.FloatTensor(X_test).to(device)
full_rewards_test = torch.FloatTensor(full_rewards_test).to(device)
BN_values = []
BN_accuracies = []
SNIPS_values = []
SNIPS_accuracies = []

for batch_size in batch_sizes:
    
    print('Batch Size: ', batch_size)
    
    # get dataloaders
    train_dataloader, test_dataloader = get_dataloaders(train_dataset, test_dataset, batch_size=64)
    
    #===========================
    #   BanditNet
    #===========================
    loss_func = banditnet_loss
    model_name = f'{batch_size}-BanditNet'
    
    model = LogisticPolicy(num_actions=26, num_features=16)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    BN_results = train_loop(model, optimizer, train_dataloader, test_dataloader, n_epochs=n_epochs, 
                            loss_func=loss_func, model_name=model_name)
    
    ## load best model
    model = torch.load(model_dir + f'{model_name}.pt')
    model = model.to(device)
    ## value
    value = model.get_value_estimate(X_test, full_rewards_test).item()
    
    ## deterministic accuracy
    y_pred = torch.argmax(model.get_action_distribution(X_test), dim=1)
    accuracy = ((y_pred.cpu().detach().numpy() == y_test.cpu().detach().numpy()).sum() / len(y_test)).item()
    
    BN_values.append(value)
    BN_accuracies.append(accuracy)
    
    print('BanditNet Value: ', value)
    print('BanditNet Accuracy: ', accuracy)
    
    #===========================
    #   SNIPS
    #===========================
    loss_func = snips_loss
    model_name = f'{batch_size}-SNIPS'
    
    model = LogisticPolicy(num_actions=26, num_features=16)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    SNIPS_results = train_loop(model, optimizer, train_dataloader, test_dataloader, n_epochs=n_epochs, 
                               loss_func=loss_func, model_name=model_name)
    
    ## load best model
    model = torch.load(model_dir + f'{model_name}.pt')
    
    ## value
    value = model.get_value_estimate(X_test, full_rewards_test).item()
    
    ## deterministic accuracy
    y_pred = torch.argmax(model.get_action_distribution(X_test), dim=1)
    accuracy = ((y_pred.cpu().detach().numpy() == y_test.cpu().detach().numpy()).sum() / len(y_test)).item()
    
    SNIPS_values.append(value)
    SNIPS_accuracies.append(accuracy)
    
    print('SNIPS Value: ', value)
    print('SNIPS Accuracy: ', accuracy)
    print('='*50)
    print('\n\n')

    
# Save results to a DataFrame
results_df = pd.DataFrame({
    'Batch_Size': batch_sizes,
    'BanditNet_Value': BN_values,
    'BanditNet_Accuracy': BN_accuracies,
    'SNIPS_Value': SNIPS_values,
    'SNIPS_Accuracy': SNIPS_accuracies
}) 

Batch Size:  64


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.17206531763076782
BanditNet Accuracy:  0.1735


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.1353093981742859
SNIPS Accuracy:  0.13575



Batch Size:  128


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.10615689307451248
BanditNet Accuracy:  0.10675


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.13924174010753632
SNIPS Accuracy:  0.13875



Batch Size:  256


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.13264243304729462
BanditNet Accuracy:  0.133


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.07123489677906036
SNIPS Accuracy:  0.0715



Batch Size:  512


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.19758711755275726
BanditNet Accuracy:  0.1995


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


SNIPS Value:  0.10518762469291687
SNIPS Accuracy:  0.106



Batch Size:  1024


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


BanditNet Value:  0.07706692814826965
BanditNet Accuracy:  0.07725


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))




KeyboardInterrupt: 

In [12]:
results_df

Unnamed: 0,Batch_Size,BanditNet_Value,BanditNet_Accuracy,SNIPS_Value,SNIPS_Accuracy
0,64,0.040754,0.04075,0.04256,0.0425
1,128,0.039759,0.03975,0.036672,0.0365


In [17]:
results_df.to_csv('../results.csv', index=False)