In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import attacks
import models
import pandas as pd
import utils
from tqdm.notebook import trange
torch.manual_seed(0)
import torch.nn as nn

# torch is seeded with 0 above, but the feedback exploration in the training cell draws
# from NumPy's global RNG (via np.random.choice); seed it too so Restart & Run All
# reproduces the same feedback samples.
np.random.seed(0)

# Loaders and sample counts for the 'online' task, indexed by phase ('train' / 'val')
# in the training cell below.
dataloaders, dataset_sizes = utils.task_loader('online')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [7]:
# Reference controller restored from the offline checkpoint. It is kept in eval mode and
# only scored (never optimized) in the training cell.
# `.to(device)` matches the device the input batches are moved to; map_location lets a
# checkpoint saved on GPU load on a CPU-only kernel (and vice versa).
offline_controler = models.load_controler().to(device)
offline_checkpoint = torch.load("offline_controler.pt", map_location=device)
offline_controler.load_state_dict(offline_checkpoint['model_state'])
offline_controler.eval()

# Online controller, optimized in the training cell. It must be on `device` before its
# parameters are handed to the optimizer.
controler = models.load_controler().to(device)

# NOTE(review): `target` is not used anywhere in the visible cells; confirm it is needed.
target = models.load_target()

def random_exploration(p_feedback, p_otw):
    """Draw one biased coin flip deciding whether to ask for a label.

    Parameters
    ----------
    p_feedback : float
        Probability of the outcome ``True`` (request feedback).
    p_otw : float
        Probability of the outcome ``False``. Together with ``p_feedback`` it must
        sum to 1, otherwise ``np.random.choice`` raises ``ValueError``.

    Returns
    -------
    numpy.bool_
        The sampled outcome, drawn from NumPy's global random state.
    """
    outcomes = [True, False]
    weights = [p_feedback, p_otw]
    return np.random.choice(outcomes, p=weights)

EXPLORE_PROB = 0.1  # chance of revealing the label of a sample predicted negative


def select_feedback(yp, explore_prob=EXPLORE_PROB):
    """Pick which samples of a batch reveal their true label (apple-tasting feedback).

    Every sample whose score is > 0.5 gets feedback. A sample at or below 0.5 gets it
    only when ``random_exploration`` draws True.

    Parameters
    ----------
    yp : torch.Tensor
        Controller outputs for one batch. Assumes one score per sample, i.e. shape
        (batch,) or (batch, 1) -- TODO confirm against models.load_controler.
    explore_prob : float
        Probability of revealing the label of a predicted-negative sample.

    Returns
    -------
    list[int]
        Row indices into the batch, in increasing order.
    """
    # One device->host transfer for the whole batch instead of a sync per element.
    predicted_positive = (yp.detach().reshape(-1) > 0.5).tolist()
    return [idx for idx, positive in enumerate(predicted_positive)
            if positive or random_exploration(explore_prob, 1 - explore_prob)]


n_epochs = 1
bar = trange(n_epochs)
epoch_loss = {'train': 0, 'val': 0, 'train_offline': 0, 'val_offline': 0}
dft = pd.DataFrame(columns=['train_bce', 'val_bce', 'train_offline_bce', 'val_offline_bce'], dtype=float)

criterion = nn.BCELoss()
optimizer = optim.Adam(controler.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

for epoch in bar:

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:

        if phase == 'train':
            controler.train()  # Set model to training mode
        else:
            controler.eval()   # Set model to evaluate mode

        running_loss = 0.0   # sum of per-sample online BCE over samples that got feedback
        n_feedback = 0       # number of samples that got feedback in this phase
        running_loss2 = 0.0  # sum of per-sample offline BCE over every sample

        bar.set_description(f'Epoch {epoch} {phase}'.ljust(20))

        for X, _b, y in dataloaders[phase]:
            X, y = X.to(device), y.to(device)

            # The offline controller only provides a reference score: no autograd graph.
            with torch.no_grad():
                running_loss2 += criterion(offline_controler(X), y).item() * X.size(0)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward, track history only in train
            with torch.set_grad_enabled(phase == 'train'):
                yp = controler(X)
                feedback = select_feedback(yp)

                if feedback:
                    # List indexing keeps autograd and stays on the tensors' device.
                    y_feedback = yp[feedback].float()
                    status_feedback = y[feedback].float().reshape(-1, 1)
                    loss = criterion(y_feedback, status_feedback)

                    # backward + optimize only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    # `loss` is a mean over len(feedback) samples, so weight it by that count.
                    running_loss += loss.item() * len(feedback)
                    n_feedback += len(feedback)

        if phase == 'train':
            scheduler.step()

        # Computed from this phase's own totals, so a phase with no feedback reports NaN
        # rather than a stale value from a previous epoch.
        # NOTE(review): online BCE is measured only on samples that received feedback,
        # while offline BCE covers every sample; the two are not directly comparable.
        epoch_loss[phase] = running_loss / n_feedback if n_feedback else float('nan')
        epoch_loss[f'{phase}_offline'] = running_loss2 / dataset_sizes[phase]
        dft.loc[epoch, f'{phase}_bce'] = epoch_loss[phase]
        dft.loc[epoch, f'{phase}_offline_bce'] = epoch_loss[f'{phase}_offline']

        bar.set_postfix(train_loss=f'{epoch_loss["train"]:0.5f}', val_loss=f'{epoch_loss["val"]:0.5f}')

    # Checkpoint once per epoch, after both phases (the val phase never steps the optimizer).
    status = {
        "epoch": epoch,
        "model_state": controler.state_dict(),
        "optimizer_state": optimizer.state_dict()}
    torch.save(status, './online_controler.pt')

  0%|          | 0/50 [00:00<?, ?it/s]

In [15]:
import plotly.graph_objects as go

# Training BCE per epoch: online controller (loss on samples that received feedback)
# vs. the offline reference controller.
fig = go.Figure()
fig.add_traces([
    go.Scatter(x=dft.index, y=dft.train_bce, name='Online Random strategy'),
    go.Scatter(x=dft.index, y=dft.train_offline_bce, name='Offline strategy'),
])
fig.update_yaxes(range=[0, 1])
# show() has no legend option; the legend is a layout property.
fig.update_layout(
    title='Training BCE per epoch',
    xaxis_title='Epoch',
    yaxis_title='BCE',
    showlegend=True,
)

fig.show()


In [19]:
import plotly.graph_objects as go

# Validation BCE per epoch: online controller (loss on samples that received feedback)
# vs. the offline reference controller.
fig = go.Figure()
fig.add_traces([
    go.Scatter(x=dft.index, y=dft.val_bce, name='Online Random strategy'),
    go.Scatter(x=dft.index, y=dft.val_offline_bce, name='Offline strategy'),
])
# NOTE(review): fixed y-range [0, 0.01] zooms in heavily and may clip larger values.
fig.update_yaxes(range=[0, 0.01])
# show() has no legend option; the legend is a layout property.
fig.update_layout(
    title='Validation BCE per epoch',
    xaxis_title='Epoch',
    yaxis_title='BCE',
    showlegend=True,
)

fig.show()
