In [38]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import attacks
import models
import pandas as pd
import utils
from tqdm.notebook import trange
torch.manual_seed(0)
import torch.nn as nn
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

dataloaders, dataset_sizes = utils.task_loader('online')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 
device

device(type='cuda', index=0)

In [68]:
offline_controler = models.load_controler()
offline_controler.load_state_dict(torch.load("offline_controler.pt" )['model_state'])
offline_controler.eval()

controler = models.load_controler()

target = models.load_target()

def random_exploration(p_feedback, p_otw):
    return np.random.choice( [True, False], p=[p_feedback, p_otw])

# def controler_random(X):
#     return [  np.random.choice( [1, 0], p=[0.5, 0.5])  for x in range(X.shape[0]) ] 

# def apple_tasting(controler, X, status):
#     for x in zip(X,status):

bar = trange( len(dataloaders['train']) )
batch_loss = { 'online': 0, 'offline':0 }
dft = [] #pd.DataFrame(columns=['online_bce', 'offline_bce'])

optimizer = optim.Adam(controler.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

controler.train()  

    
for batch, data in enumerate(dataloaders['train']):
    X,label,y = data
    bar.set_description(f'Batch {batch}'.ljust(20))
    
    X,y = X.to(device), y.to(device)
        
    # zero the parameter gradients
    optimizer.zero_grad()
    # forward, track history if only in train
    with torch.set_grad_enabled(True):
        
        yp = controler( X )
        yp2 = offline_controler( X )
            
        indices = torch.IntTensor([ idx for idx,pred in enumerate(yp)  if pred>0.5 or random_exploration(0.1, 0.9)  ]).to(device) 

        y_feedback = torch.index_select(yp, 0, indices).to(device)
        y_feedback2 = torch.index_select(yp2, 0, indices).to(device)
        status_feedback = torch.index_select(y, 0, indices).to(device)
        
        if len(indices)>0:
            loss2 = torch.nn.BCELoss()( y_feedback2, status_feedback )
            loss = torch.nn.BCELoss()( y_feedback, status_feedback )
            loss.backward()
            optimizer.step() 
#             scheduler.step()
    
    
    status = {'online_bce': loss.item(),'offline_bce': loss2.item() }
    bar.set_postfix( status ) # 
    bar.update(batch)
#     dft = dft.append(status, ignore_index=True)
    dft.append([loss.item(),loss2.item()])
    
status = {  "model_state": controler.state_dict(), "optimizer_state": optimizer.state_dict() }     
torch.save(status, './online_controler.pt')

  0%|          | 0/235 [00:00<?, ?it/s]

In [73]:
import plotly.graph_objects as go

sequence = np.arange(0,len(dft),1)

dft = np.array(dft)
fig = go.Figure( )
fig.add_traces([ go.Scatter( x=sequence, y=dft[:,0], name='Online Random strategy' ) ]) #, showlegend =True
fig.add_traces([ go.Scatter( x=sequence, y=dft[:,1], name='Offline strategy' ) ])
fig.update_yaxes(range=[0, 1])

fig.show(legend=True)


In [55]:
dft

[[0.6854766011238098, 4.470356529395758e-08],
 [0.6954746842384338, 4.6566128730773926e-09],
 [0.7025421261787415, 2.0023438906946467e-08],
 [0.690346896648407, 1.2572858310022639e-08],
 [0.6908690929412842, 1.3504179996459698e-08],
 [0.6807330846786499, 4.190951585769653e-09],
 [0.7062134742736816, 1.9092123437758346e-08],
 [0.7084940671920776, 1.0244549208948683e-08],
 [0.694294810295105, 1.816079375771551e-08],
 [0.6659138798713684, 1.0244550097127103e-08],
 [0.6878027319908142, 9.313226634333205e-09],
 [0.7136024236679077, 1.0710211384434842e-08],
 [0.7059715986251831, 5.2619817836330185e-08],
 [0.7015688419342041, 1.9557781172352406e-08],
 [0.710766077041626, 2.7939679458910405e-09],
 [0.7080518007278442, 3.725290742551124e-09],
 [0.689747154712677, 8.847565347025466e-09],
 [0.6890531778335571, 6.5192584663975595e-09],
 [0.6993728876113892, 4.6566133171666024e-09],
 [0.6937435865402222, 2.3283075023528e-08],
 [0.7073265910148621, 1.3969843948302696e-08],
 [0.697676956653595, 5.075