In [120]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import attacks
import models
import pandas as pd
import utils
from tqdm.notebook import trange
# Seed every global RNG this notebook draws from so that Restart & Run All
# reproduces the same run: torch's generator, and NumPy's global generator
# (np.random.choice is what random_exploration uses to pick feedback samples).
torch.manual_seed(0)
np.random.seed(0)
import torch.nn as nn
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from sklearn.metrics import confusion_matrix

# Build the data loaders (and their sizes) for the 'online' task through the
# project-local utils helper; later cells iterate over dataloaders['train'].
dataloaders, dataset_sizes = utils.task_loader('online')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # first CUDA GPU when available, otherwise CPU
# Bare expression: shown as the cell output so the reader sees which device was picked.
device

device(type='cuda', index=0)

In [118]:
def precision_recall_matrix(truth, prediction):
    """Return the normalised confusion counts of a binary prediction.

    Args:
        truth: tensor of ground-truth labels; entries are expected to be 0 or 1.
        prediction: tensor of scores broadcastable against ``truth``; each
            score is rounded with ``torch.round`` before being compared.

    Returns:
        ``[true_positives, true_negatives, false_positives, false_negatives]``,
        each expressed as a fraction of ``len(truth)``. All four are ``0.0``
        when ``truth`` is empty.
    """
    total = len(truth)
    if total == 0:
        # Nothing to score; avoid dividing by zero below.
        return [0.0, 0.0, 0.0, 0.0]

    rounded = torch.round(prediction)

    # Explicit masks instead of classifying the inf/nan artefacts of
    # `prediction / truth`; the four categories stay the same for 0/1 inputs.
    predicted_positive = rounded == 1
    predicted_negative = rounded == 0
    actual_positive = truth == 1
    actual_negative = truth == 0

    true_positives = torch.sum(predicted_positive & actual_positive).item() / total
    false_positives = torch.sum(predicted_positive & actual_negative).item() / total
    true_negatives = torch.sum(predicted_negative & actual_negative).item() / total
    false_negatives = torch.sum(predicted_negative & actual_positive).item() / total

    return [true_positives, true_negatives, false_positives, false_negatives]

In [177]:
# Reference controller: restore the weights stored under 'model_state' in
# offline_controler.pt and switch to eval mode. map_location makes the
# checkpoint loadable on whatever `device` the notebook selected (e.g. a
# CPU-only machine reading a checkpoint that was saved from a GPU).
offline_controler = models.load_controler()
offline_controler.load_state_dict(
    torch.load("offline_controler.pt", map_location=device)['model_state']
)
offline_controler.eval()

# Second controller instance from the same factory; this is the one the
# training loop below optimises.
controler = models.load_controler()

# NOTE(review): `target` is not referenced by any cell visible here — confirm
# whether it is still needed.
target = models.load_target()

def random_exploration(p_feedback, p_otw):
    """Draw one weighted coin flip from NumPy's global RNG.

    Returns True with probability ``p_feedback`` and False with probability
    ``p_otw``.
    """
    outcomes = [True, False]
    weights = [p_feedback, p_otw]
    return np.random.choice(outcomes, p=weights)


# Online-vs-offline comparison.
# For every training batch, `controler` is optimised on a subset of samples
# (those it scores above the threshold, plus a random exploration share),
# while the frozen `offline_controler` is scored on the same batch.
# Each row appended to `dft` is laid out as:
#   [0] online BCE, [1] offline BCE,
#   [2:6]  online  confusion matrix, raveled  -> TN, FP, FN, TP
#   [6:10] offline confusion matrix, raveled  -> TN, FP, FN, TP
#   [10]   oracle = fraction of positive labels in the batch
FEEDBACK_THRESHOLD = 0.5      # predictions above this always receive feedback
P_EXPLORE, P_SKIP = 0.1, 0.9  # exploration coin for the remaining samples
BINARY_LABELS = [0.0, 1.0]    # pin the confusion matrix to 2x2 for every batch


def select_feedback_indices(predictions):
    """Return the batch positions whose labels are used for the loss.

    A position is kept when its prediction exceeds FEEDBACK_THRESHOLD;
    otherwise it is kept only if random_exploration(P_EXPLORE, P_SKIP) says so
    (the coin is drawn only for positions at or below the threshold).
    """
    # One device->host transfer instead of one synchronisation per element.
    above_threshold = (predictions.detach() > FEEDBACK_THRESHOLD).reshape(-1).tolist()
    kept = [idx for idx, is_above in enumerate(above_threshold)
            if is_above or random_exploration(P_EXPLORE, P_SKIP)]
    return torch.tensor(kept, dtype=torch.long, device=predictions.device)


def normalized_confusion(truth, predictions):
    """2x2 confusion matrix of rounded predictions, normalised over all samples.

    `labels` is passed explicitly: without it a batch holding a single class
    yields a 1x1 matrix and the rows of `dft` would no longer line up.
    """
    return confusion_matrix(truth.detach().cpu().numpy(),
                            torch.round(predictions).detach().cpu().numpy(),
                            labels=BINARY_LABELS, normalize='all')


bar = trange(len(dataloaders['train']))
batch_loss = {'online': 0, 'offline': 0}
dft = []
optimizer = optim.Adam(controler.parameters(), lr=1e-3)
# Created but never stepped; the scheduler step is deliberately left disabled.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = torch.nn.BCELoss()

controler.train()

for batch, data in enumerate(dataloaders['train']):

    X, label, y = data
    bar.set_description(f'Batch {batch}'.ljust(20))
    X, y = X.to(device), y.to(device)
    oracle = y.sum().item() / len(y)
    optimizer.zero_grad()

    #### ONLINE: one optimisation step on the samples that received feedback.
    with torch.set_grad_enabled(True):
        yp = controler(X)
        pr_matrix = normalized_confusion(y, yp)

        indices = select_feedback_indices(yp)
        y_feedback = torch.index_select(yp, 0, indices)
        status_feedback = torch.index_select(y, 0, indices)
        loss = criterion(y_feedback, status_feedback)
        loss.backward()
        optimizer.step()
        # scheduler.step()

    #### OFFLINE: evaluation only, so no autograd graph is needed.
    with torch.no_grad():
        yp2 = offline_controler(X)
        pr_matrix2 = normalized_confusion(y, yp2)
        indices2 = select_feedback_indices(yp2)
        y_feedback2 = torch.index_select(yp2, 0, indices2)
        status_feedback2 = torch.index_select(y, 0, indices2)
        loss2 = criterion(y_feedback2, status_feedback2)

    # status = {'online_bce': loss.item(), 'offline_bce': loss2.item() }
    # bar.set_postfix( status )
    # tqdm's update() takes an increment, not an absolute position.
    bar.update(1)
    dft.append([loss.item()] + [loss2.item()]
               + pr_matrix.ravel().tolist() + pr_matrix2.ravel().tolist()
               + [oracle])

bar.close()

# status = {  "model_state": controler.state_dict(), "optimizer_state": optimizer.state_dict() }
# torch.save(status, './online_controler.pt')

  0%|          | 0/235 [00:00<?, ?it/s]

In [178]:
import plotly.graph_objects as go

# Per-batch BCE loss of both controllers (dft columns 0 and 1).
sequence = np.arange(0, len(dft), 1)
dft = np.array(dft)

fig = go.Figure()
fig.add_traces([go.Scatter(x=sequence, y=dft[:, 0], name='Online Random strategy')])
fig.add_traces([go.Scatter(x=sequence, y=dft[:, 1], name='Offline strategy')])
fig.update_yaxes(range=[0, 1])
# The legend is a layout property; show() has no `legend` option and silently
# ignores unknown keyword arguments.
fig.update_layout(showlegend=True, xaxis_title='Batch', yaxis_title='BCE loss')
fig.show()


In [193]:
# Online controller: normalised confusion-matrix entries per batch.
# dft columns 2..5 hold confusion_matrix(...).ravel(), i.e. TN, FP, FN, TP in
# that order, and column 10 holds the oracle (share of positive labels).
sequence = np.arange(0, len(dft), 1)
dft = np.array(dft)
fig = go.Figure()
# Negative-prediction traces are kept disabled, as in the original cell.
# fig.add_traces([ go.Scatter( x=sequence, y=dft[:,2], name='True Negative',line={ 'color': "Red", "width":1} ) ])
# fig.add_traces([ go.Scatter( x=sequence, y=dft[:,4], name='False Negative',line={ 'color': "Orange", "width":1} ) ])
fig.add_traces([go.Scatter(x=sequence, y=dft[:, 3], name='False Positive',
                           line={'color': "Green", "width": 1})])
fig.add_traces([go.Scatter(x=sequence, y=dft[:, 5], name='True Positive',
                           line={'color': "Cyan", "width": 1})])
fig.add_traces([go.Scatter(x=sequence, y=dft[:, 10], name='Oracle',
                           line={'color': "Black", "width": 1, "dash": "dot"})])

fig.update_yaxes(range=[0, 1])
# The legend is a layout property; show() has no `legend` option and silently
# ignores unknown keyword arguments.
fig.update_layout(showlegend=True, xaxis_title='Batch', yaxis_title='Share of batch')
fig.show()


In [194]:
# Offline controller: normalised confusion-matrix entries per batch.
# dft columns 6..9 hold confusion_matrix(...).ravel(), i.e. TN, FP, FN, TP in
# that order, and column 10 holds the oracle (share of positive labels).
fig = go.Figure()
fig.add_traces([go.Scatter(x=sequence, y=dft[:, 6], name='True Negative',
                           line={'color': "Red", "width": 1})])
fig.add_traces([go.Scatter(x=sequence, y=dft[:, 8], name='False Negative',
                           line={'color': "Orange", "width": 1})])
fig.add_traces([go.Scatter(x=sequence, y=dft[:, 7], name='False Positive',
                           line={'color': "Green", "width": 1})])
fig.add_traces([go.Scatter(x=sequence, y=dft[:, 9], name='True Positive',
                           line={'color': "Cyan", "width": 1})])
fig.add_traces([go.Scatter(x=sequence, y=dft[:, 10], name='Oracle',
                           line={'color': "Black", "width": 1, "dash": "dot"})])
fig.update_yaxes(range=[0, 1])
# The legend is a layout property; show() has no `legend` option and silently
# ignores unknown keyword arguments.
fig.update_layout(showlegend=True, xaxis_title='Batch', yaxis_title='Share of batch')
fig.show()
