# Everything

## Imports

In [0]:
import os

if not os.path.isdir('Preference_Extraction'):
    print("Setting up colab environment")
    !pip uninstall -y -q pyarrow
    !pip install -q https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev5-cp36-cp36m-manylinux1_x86_64.whl
    !pip install -q ray[debug]
    !pip install 'ray[tune]' 
    !pip install bayesian-optimization

    !git clone https://github.com/arunraja-hub/Preference_Extraction.git
    # # A hack to force the runtime to restart, needed to include the above dependencies.
    # # Only after first time
    os._exit(0)

In [0]:
## If you are running on Google Colab, please install TensorFlow 2.0 by uncommenting below..

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

In [0]:
from __future__ import print_function
import argparse
import os
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import label_binarize

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as TF
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
import torch.autograd as autograd
from torchsummary import summary

from sklearn.utils import shuffle
import tensorflow as tf
import concurrent.futures
import itertools
import os
import random
import sys
import time
import re
import io
import itertools
import sys

import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.suggest.bayesopt import BayesOptSearch

sys.path.append('Preference_Extraction')
from imports_data import all_load_data

In [0]:
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
device

device(type='cuda')

## Parameters

In [0]:
params = {
    'num_train': 50,
    'num_tune': 25,
    'num_val': 400,
    'batch_size': 10,
    'val_batch_size': 10,
    'num_epochs': 100,
    'use_qnet_weights': True, # Flag for running models that use the weights of Qnet vs models that use random weights
    'use_mnist': False,  # Flag for running models on MNIST. If False uses RL Preference Extraction data
    'num_run': 5  # Number of runs (with different data sample) over which to average performance
}

## Subnets Methods

In [0]:
"""
    Original code from What's hidden in a randomly weighted neural network? paper
    Implemented at https://github.com/allenai/hidden-networks
    Remove weigths-initialisation since it is not relevant for us
"""

class GetSubnet(autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Get the supermask by sorting the scores and using the top k%
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())

        # flat_out and out access the same memory.
        flat_out = out.flatten()
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1

        return out

    @staticmethod
    def backward(ctx, g):
        # send the gradient g straight-through on the backward pass.
        return g, None

class SupermaskConv(nn.Conv2d):
    def __init__(self, *args, k, scores_init='kaiming_uniform', **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k
        self.scores_init = scores_init

        # initialize the scores
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        if self.scores_init == 'kaiming_normal':
          nn.init.kaiming_normal_(self.scores)
        elif self.scores_init == 'kaiming_uniform':
          nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))
        elif self.scores_init == 'xavier_normal':
          nn.init.xavier_normal_(self.scores)
        elif self.scores_init == 'best_activation':
          nn.init.ones_(self.scores)
        else:
          nn.init.uniform_(self.scores)

        # initialize the weights
        nn.init.uniform_(self.weight)
        
        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False

    def forward(self, x):
        subnet = GetSubnet.apply(self.scores.abs(), self.k)
        w = self.weight * subnet
        x = F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        return x

class SupermaskLinear(nn.Linear):
    def __init__(self, *args, k, scores_init='kaiming_uniform', **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k
        self.scores_init = scores_init

        # initialize the scores and weights
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        if self.scores_init == 'kaiming_normal':
          nn.init.kaiming_normal_(self.scores)
        elif self.scores_init == 'kaiming_uniform':
          nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))
        elif self.scores_init == 'xavier_normal':
          nn.init.xavier_normal_(self.scores)
        elif self.scores_init == 'best_activation':
          nn.init.ones_(self.scores)
        else:
          nn.init.uniform_(self.scores)

        nn.init.uniform_(self.weight)

        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False

    def forward(self, x):
        subnet = GetSubnet.apply(self.scores.abs(), self.k)
        w = self.weight * subnet
        return F.linear(x, w, self.bias)
        return x

# NOTE: not used here but we use NON-AFFINE Normalization!
# So there is no learned parameters for your nomralization layer.
class NonAffineBatchNorm(nn.BatchNorm2d):
    def __init__(self, dim):
        super(NonAffineBatchNorm, self).__init__(dim, affine=False)

## Define architecture

In [0]:
class PrefQNet(nn.Module):
    """
      If q_head_index is None, this uses a linear model on the normalized q outputs.
      Otherwise, it gets the Q head with the specified index.
    """ 
    def __init__(self, fine_tune, k, q_head_index, q_means_stds, use_last_linear, init_from_act_index=None):
        super(PrefQNet, self).__init__()
        
        if not params['use_mnist']:
            channels_in = 5
            flattened_shape = 960
        else:
            channels_in = 1
            flattened_shape = 4608

        if fine_tune:
            conv_layer = nn.Conv2d
            dense_layer = nn.Linear
            additional_args = {}
            init_from_act_index = None
        else:
            conv_layer = SupermaskConv
            dense_layer = SupermaskLinear
            additional_args = {'k': k}
            if init_from_act_index is not None:
                additional_args['scores_init'] = 'best_activation'
        
        self.conv1 = conv_layer(in_channels=channels_in, out_channels=16, kernel_size=3, stride=1, bias=True, **additional_args)
        self.conv2 = conv_layer(in_channels=16, out_channels=32, kernel_size=3, stride=2, bias=True, **additional_args)
        self.fc1 = dense_layer(in_features=flattened_shape, out_features=64, bias=True, **additional_args)
        self.fc2 = dense_layer(in_features=64, out_features=3, bias=True, **additional_args)
        
        if init_from_act_index is not None:
            init_scores = np.zeros((3, 64))
            init_scores[:, init_from_act_index] = 1.0
            self.fc2.scores.data = torch.from_numpy(init_scores).float()

        self.fc3 = dense_layer(in_features=3, out_features=1, bias=True, **additional_args)
        self.linear = nn.Linear(1, 1, bias=True)

        self.qix = q_head_index
        self.qu_mu_s = q_means_stds
        self.use_last_linear = use_last_linear

    def fwd_conv1(self, x):
        x = self.conv1(x)
        return F.relu(x)

    def fwd_conv2(self, x):
        x = self.fwd_conv1(x)
        x = self.conv2(x)
        return F.relu(x)

    def fwd_flat(self, x):
        x = self.fwd_conv2(x)
        return torch.flatten(torch.transpose(x, 1, 3), 1) # Pre-flattening transpose is necessary for TF-Torch conversion

    def fwd_fc1(self, x):
        x = self.fwd_flat(x)
        x = self.fc1(x)
        return F.relu(x)
    
    def fwd_fc2(self, x):
        x = self.fwd_fc1(x)
        return self.fc2(x)

    def forward(self, x):
        x = self.fwd_fc2(x)

        x -= torch.tensor(self.qu_mu_s[0], device=device)
        x /= torch.tensor(self.qu_mu_s[1], device=device)

        if self.qix == None:
          x = self.fc3(x)
        else:
          x = x[: ,self.qix:self.qix+1]

        if self.use_last_linear:
          x = self.linear(x)

        x = torch.sigmoid(x)
        return x.flatten()

## Load Data

In [0]:
# Run this cell for training with original RL Preference Extraction data
if not params['use_mnist']:
    all_raw_data = all_load_data("Preference_Extraction/data/simple_env_1/")

    activations = []
    observations = []
    preferences = []

    for data in all_raw_data:
        for i in range(data.observation.shape[0]):
            observations.append(np.copy(data.observation[i]))
            activations.append(np.copy(data.policy_info["activations"][i]))
            preferences.append((data.policy_info['satisfaction'].as_list()[i] > -6).astype(int))

    activations = np.array(activations)

    xs = np.rollaxis(np.array(observations), 3, 1) # Torch wants channel-first
    ys = np.array(preferences)

In [0]:
# Run this cell for training with MNIST
if params['use_mnist']:
    tr_data_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=params['batch_size'], shuffle=True)

    val_data_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist', train=False, download=True, 
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])), batch_size=params['val_batch_size'], shuffle=True)

## Loading Weights

In [10]:
new_save_path = "Preference_Extraction/saved_model2"
restored_model = tf.keras.models.load_model(new_save_path)
restored_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
EncodingNetwork/conv2d (Conv (None, 12, 14, 16)        736       
_________________________________________________________________
EncodingNetwork/conv2d_1 (Co (None, 5, 6, 32)          4640      
_________________________________________________________________
flatten (Flatten)            (None, 960)               0         
_________________________________________________________________
EncodingNetwork/dense (Dense (None, 64)                61504     
_________________________________________________________________
dense (Dense)                (None, 3)                 195       
Total params: 67,075
Trainable params: 67,075
Non-trainable params: 0
_________________________________________________________________


In [0]:
original_weights=restored_model.get_weights()

In [0]:
def load_weights(model):
    if not params['use_mnist']:
        model.conv1.weight.data = torch.from_numpy(np.transpose(original_weights[0]))
        model.fc1.weight.data = torch.from_numpy(np.transpose(original_weights[4]))
    else:
        model.conv1.weight.data = torch.from_numpy(np.transpose(original_weights[0][:,:,:1,:]))
        mnist_flt_weights = np.random.rand(64, 4608)
        mnist_flt_weights[:, :original_weights[4].shape[0]] = np.transpose(original_weights[4])
        mnist_flt_weights = mnist_flt_weights.astype(np.float32)
        model.fc1.weight.data = torch.from_numpy(mnist_flt_weights)

    model.conv1.bias.data = torch.from_numpy(original_weights[1])
    model.conv2.weight.data = torch.from_numpy(np.transpose(original_weights[2]))
    model.conv2.bias.data = torch.from_numpy(original_weights[3])
    model.fc1.bias.data = torch.from_numpy(original_weights[5])
    model.fc2.weight.data = torch.from_numpy(np.transpose(original_weights[6]))
    model.fc2.bias.data = torch.from_numpy(original_weights[7])
    model.fc3.weight.data = torch.from_numpy(np.ones(shape=[1,3], dtype=np.float32))
    model.fc3.bias.data = torch.from_numpy(np.zeros(shape=[1], dtype=np.float32))
    model.to(device)

In [0]:
test_model = PrefQNet(k=1, fine_tune=False, q_head_index=None, q_means_stds=[[0, 0, 0], [1, 1, 1]], use_last_linear=True).to(device)

if params['use_qnet_weights']:
    load_weights(test_model)

## Test the weights loaded properly

In [0]:
# Comparing that the models have identical observations for identical images
tf_conv1_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[0].output)
tf_conv2_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[1].output)
tf_flt_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[2].output)
tf_fc1_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[3].output)

def npsigmoid(x):
  return 1/(1 + np.exp(-x)) 

def check_same(torch_layer, tf_layer):
    torch_out = np.transpose(torch_layer(single_observation_torch).detach().cpu().numpy())
    torch_out = torch_out.reshape(torch_out.shape[:-1])
    tf_out = tf_layer(single_observation)[0].numpy()
    np.testing.assert_allclose(torch_out, tf_out, rtol=.1, atol=5)  

# due to shape of original TF model this test can be done only when use_mnist = False
if not params['use_mnist'] and params['use_qnet_weights']:
    for i in range(len(all_raw_data[0].observation)):

        single_observation = np.array([all_raw_data[0].observation[i]])
        single_observation_torch = torch.Tensor(np.array([np.transpose(all_raw_data[0].observation[i])]))

        single_observation_torch = single_observation_torch.to(device)

        check_same(test_model.fwd_conv1, tf_conv1_fn)
        check_same(test_model.fwd_conv2, tf_conv2_fn)
        check_same(test_model.fwd_flat, tf_flt_fn)
        check_same(test_model.fwd_flat, tf_flt_fn)

        fc1_torch_out = np.transpose(test_model.fwd_fc1(single_observation_torch).detach().cpu().numpy())
        fc1_torch_out = fc1_torch_out.reshape(fc1_torch_out.shape[:-1])
        fc1_tf_out = tf_fc1_fn(single_observation)[0].numpy()
        
        np.testing.assert_allclose(fc1_torch_out, fc1_tf_out, rtol=.1, atol=5)
        old_activations = all_raw_data[0].policy_info["activations"][i]
        np.testing.assert_allclose(fc1_torch_out, old_activations, rtol=.1, atol=5)
        np.testing.assert_allclose(old_activations, fc1_tf_out, rtol=.1, atol=5)

        check_same(test_model.fwd_fc2, restored_model)

        torch_out = np.transpose(test_model.forward(single_observation_torch).detach().cpu().numpy())
        torch_out = torch_out.reshape(torch_out.shape[:-1])
        tf_out = npsigmoid(np.sum(restored_model(single_observation)[0].numpy()))
        np.testing.assert_allclose(torch_out, tf_out, rtol=.1, atol=5)  

## Modelling

### Get data to normalize qHeads

In [15]:
def get_q_heads_mu_and_sigma(model, all_obs, num_obs):
    
    model.eval()

    all_obs = shuffle(all_obs)
    obs_to_pass = all_obs[:num_obs]

    obs_tensor = torch.Tensor(obs_to_pass)
    obs_tensor = obs_tensor.to(device)
    qheads_values = model.fwd_fc2(obs_tensor).detach().cpu().numpy()

    mu = qheads_values.mean(axis=0)
    s = qheads_values.std(axis=0)

    print("mu", mu, "s", s)
    
    return np.array([mu, s])

if params['use_mnist']:
    img_batch, label = iter(tr_data_loader).next()
    xs = img_batch

q_mu_s = get_q_heads_mu_and_sigma(test_model, xs, 10000)

mu [ 93.20709   68.562904 138.82045 ] s [47.691833 51.373444 77.846725]


### Methods to inspect performance

In [0]:
def get_number_of_new_scores_in_top_k(new_scores, old_scores, k):
    new_top_k_scores = set(new_scores[:int(len(new_scores) * k)])
    old_top_k_scores = set(old_scores[:int(len(old_scores) * k)])

    return len(old_top_k_scores) - len(new_top_k_scores.intersection(old_top_k_scores))

def model_scores_to_dict(model):
    return {
        'conv1': model.conv1.scores.detach().cpu().numpy().flatten().argsort(),
        'conv2': model.conv2.scores.detach().cpu().numpy().flatten().argsort(),
        'fc1': model.fc1.scores.detach().cpu().numpy().flatten().argsort(),
        'fc2': model.fc2.scores.detach().cpu().numpy().flatten().argsort(),
        'fc3': model.fc3.scores.detach().cpu().numpy().flatten().argsort()
    }

def get_no_of_changed_scores(model, previous_scores, k):

    new_scores_idxs = model_scores_to_dict(model)

    score_changes = {}

    for score in new_scores_idxs:
        changed_scores_num = get_number_of_new_scores_in_top_k(new_scores_idxs[score], previous_scores[score], k)
        score_changes[score] = changed_scores_num

    return score_changes, new_scores_idxs

def plot_metric(results_dict, metric):
    plt.title(metric)
    plt.xlabel('Epochs')
    plt.plot(list(range(1, params['num_epochs'] + 1)), results_dict[f'train{metric}'], label=f'Train {metric}')
    plt.plot(list(range(1, params['num_epochs'] + 1)), results_dict[f'test{metric}'], label=f'Test {metric}')
    plt.legend()
    plt.show()

def plot_metric_multiple_runs(results_items, metric, train=True):
    plt.title(metric)
    plt.xlabel('Epochs')
    for res_key, res_dict in results_items.items():
        if train:
            plt.plot(list(range(1, params['num_epochs'] + 1)), res_dict[f'train{metric}'], label=f'Train {metric} - {res_key}')
        else:
            plt.plot(list(range(1, params['num_epochs'] + 1)), res_dict[f'test{metric}'], label=f'Test {metric} - {res_key}')
    plt.legend()
    plt.show()

def plot_score_changes(score_changes_dict):
    plt.title('Layer-wise score changes')
    plt.xlabel('Optimisation steps (num_train / batch_size * epochs)')
    for layer in score_changes_dict:
        plt.plot(list(range(1, len(score_changes_dict[layer]) + 1)), score_changes_dict[layer], label=f'{layer}')
    plt.legend()
    plt.show()

### Method to get data for one sample run

In [0]:
def get_data_sample(xs=None, ys=None):

    if not params['use_mnist']:
        xs, ys = shuffle(xs, ys)
        
        train_split = params['num_train']
        tune_split = params['num_train']+params['num_tune']
        test_split = params['num_train']+params['num_tune']+params['num_val']

        tr_data_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(torch.Tensor(xs[:train_split]), torch.Tensor(ys[:train_split])),
            batch_size=params['batch_size'])

        tune_data_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(torch.Tensor(xs[train_split:tune_split]), torch.Tensor(ys[train_split:tune_split])),
            batch_size=params['batch_size'])

        val_data_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(torch.Tensor(xs[tune_split:test_split]), torch.Tensor(ys[tune_split:test_split])),
            batch_size=params['val_batch_size'])
    else:
        tr_data_loader = torch.utils.data.DataLoader(
            datasets.MNIST('mnist', train=True, download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])), batch_size=params['batch_size'], shuffle=True)

        tune_data_loader = torch.utils.data.DataLoader(
            datasets.MNIST('mnist', train=False, download=True, 
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])), batch_size=params['batch_size'], shuffle=True)

        val_data_loader = torch.utils.data.DataLoader(
            datasets.MNIST('mnist', train=False, download=True, 
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])), batch_size=params['val_batch_size'], shuffle=True)
        
    return tr_data_loader, tune_data_loader, val_data_loader

### Single run train/test methods

In [0]:
"""
    Train/Test function for Randomly Weighted Hidden Neural Networks Techniques
    Adapted from https://github.com/NesterukSergey/hidden-networks/blob/master/demos/mnist.ipynb
"""

def compute_metrics(predictions, true_labels):
    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    accuracy = np.sum(np.equal((predictions > 0.5).astype(int), true_labels)) / len(true_labels)
    fpr, tpr, thresholds = metrics.roc_curve(true_labels, predictions, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    return accuracy, auc

def train(model, k, device, train_loader, optimizer, criterion):
    
    train_loss = 0
    true_labels = []
    predictions = [] # labels

    model.train()
    train_score_changes = {}
    if k is not None:
        scores = model_scores_to_dict(model)
        train_score_changes = {k: [] for k in scores}

    for data, target in itertools.islice(train_loader, params['num_train']):
        
        data, target = data.to(device), target.to(device)
        if params['use_mnist']:
            target = (target > 0).float()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if k is not None:
            score_changes, scores = get_no_of_changed_scores(model, scores, k)
            for layer_changes in score_changes:
                train_score_changes[layer_changes].append(score_changes[layer_changes])

        train_loss += loss
        predictions.extend(output.detach().cpu().numpy())
        true_labels.extend(target.detach().cpu().numpy())
    
    train_loss /= len(train_loader.dataset)
    accuracy, auc = compute_metrics(predictions, true_labels)

    return train_loss.item(), accuracy, auc, train_score_changes


def test(model, device, criterion, test_loader, num_test):
    true_labels = []
    predictions = [] # labels

    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, target in itertools.islice(test_loader, num_test):
            data, target = data.to(device), target.to(device)
            if params['use_mnist']:
                target = (target > 0).float()
            output = model(data)
            test_loss += criterion(output, target)

            predictions.extend(output.detach().cpu().numpy())
            true_labels.extend(target.detach().cpu().numpy())
    
    test_loss /= len(test_loader.dataset)
    accuracy, auc = compute_metrics(predictions, true_labels)

    return test_loss.item(), accuracy, auc

def run_model(model, k, learning_rate, weight_decay, num_epochs):

  tr_data_loader, tune_data_loader, val_data_loader = get_data_sample(xs, ys)

  optimizer = optim.Adam(
      [p for p in model.parameters() if p.requires_grad],
      lr=learning_rate,
      weight_decay=weight_decay
  )

  criterion = nn.BCELoss().to(device)
  scheduler = CosineAnnealingLR(optimizer, T_max=len(tr_data_loader))

  train_losses = []
  test_losses = []
  tune_losses = []
  train_accs = []
  train_aucs = []
  test_accs = []
  test_aucs = []
  score_changes = []
  

  for epoch in range(num_epochs):
      train_loss, train_accuracy, train_auc, train_score_changes = train(model, k, device, tr_data_loader, optimizer, criterion)
      tune_loss, _, _ = test(model, device, criterion, tune_data_loader, params['num_tune'])
      test_loss, test_accuracy, test_auc = test(model, device, criterion, val_data_loader, params['num_val'])
      scheduler.step()

      score_changes.append(train_score_changes)
      train_losses.append(train_loss)
      tune_losses.append(tune_loss)
      test_losses.append(test_loss)
      train_accs.append(train_accuracy)
      train_aucs.append(train_auc)
      test_accs.append(test_accuracy)
      test_aucs.append(test_auc)

  merged_score_changes = {k: [] for k in score_changes[0].keys()}
  for d in score_changes:
    for k in d:
        merged_score_changes[k].extend(d[k])

  return {'trainLoss': train_losses, 'testLoss': test_losses, 'tuneLoss': tune_losses,
          'trainAccuracy': train_accs, 'testAccuracy': test_accs,
          'trainAUC': train_aucs, 'testAUC': test_aucs, 'scoreChanges': merged_score_changes}

In [0]:
def multi_runs(fine_tune, K, q_head_index, q_means_stds, use_last_linear, init_from_act_index, 
               learning_rate, weight_decay, plots=False):

    averaged_results = {}    
    for run_ix in range(params['num_run']):
        
        model = PrefQNet(fine_tune=fine_tune, k=K, q_head_index=q_head_index, q_means_stds=q_means_stds,
                         use_last_linear=use_last_linear, init_from_act_index=init_from_act_index)
        
        if params['use_qnet_weights']:
            load_weights(model)

        results = run_model(model, K, learning_rate=learning_rate, weight_decay=weight_decay, num_epochs=params['num_epochs'])
        
        print(f'Train pass no. {run_ix+1}')
        if (run_ix == 0) and plots:
            print('Debug charts for first training run')
            plot_metric(results, 'Loss')
            plot_metric(results, 'Accuracy')
            plot_metric(results, 'AUC')

        for val in results:
            if len(results[val]) > 0 and val != 'scoreChanges':
                if val not in averaged_results:
                    averaged_results[val] = [results[val][-1]]
                else:
                    averaged_results[val].append(results[val][-1])         
    
    return averaged_results, {x: sum(averaged_results[x]) / params['num_run'] for x in averaged_results}

## Initialise Subnets Search with activation that obtained optimal AUC in previous experiment

We do this both as a sanity check as well as a potential improvement

In [67]:
from sklearn import metrics 

acts = []
prefs = []

for data in all_raw_data:
    for i in range(data.observation.shape[0]):
        acts.append(np.copy(data.policy_info["activations"][i]))
        prefs.append((data.policy_info['satisfaction'].as_list()[i] > -6).astype(int))

acts = np.array(acts)
prefs = np.array(prefs)

def display_auc_info(xs, ys):
    
    def calc_auc(xs, ys, i):
        fpr, tpr, thresholds = metrics.roc_curve(ys, xs[:,i], pos_label=1)
        return metrics.auc(fpr, tpr)


    multi_runs_aucs = []
    for run_ix in range(50):
        xs, ys = shuffle(xs, ys)
        flat_xs = np.reshape(xs, (xs.shape[0], -1))
        aucs = []    
        
        for i in range(flat_xs.shape[1]):
            auc = calc_auc(flat_xs[:params['num_train']], ys[:params['num_train']], i)
            aucs.append(auc)  

        aucs = np.array(aucs)
        multi_runs_aucs.append(aucs)

    aucs = np.array(multi_runs_aucs)
    aucs = aucs.mean(axis=0)

    print("AUC from only picking a single activation")
    print(np.argmin(aucs), "train", 1-np.min(aucs), "val", 1-calc_auc(flat_xs[params['num_train']:], ys[params['num_train']:], np.argmin(aucs)))
    print(np.argmax(aucs), "train", np.max(aucs), "val", calc_auc(flat_xs[params['num_train']:], ys[params['num_train']:], np.argmax(aucs)))
  
display_auc_info(acts, prefs)

AUC from only picking a single activation
34 train 0.8283087091008614 val 0.8202670285662659
13 train 0.6322109960460084 val 0.6179261050588937


In [78]:
best_act_index = 34
K = 66774 / 67152  # Num of weigths with all dense activations except one set to 0 / Number of total weights

params['num_epochs'] = 1
results = multi_runs(fine_tune=False, K=1, q_head_index=None, q_means_stds=q_mu_s, 
                     use_last_linear=False, init_from_act_index=best_act_index,
                     learning_rate=0.000, weight_decay=0.000, plots=False)

1 - results['testAUC']

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


0.8293560760678543

## Getting Results for Different Combinations

In [0]:
params['num_epochs'] = 100

In [95]:
multi_runs(fine_tune=True, K=None, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.01, weight_decay=0.001, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.8133364456893868,
   0.8154763434504382,
   0.8522975641869651,
   0.8247224815852268,
   0.8357568783345539],
  'testAccuracy': [0.7625, 0.7425, 0.815, 0.775, 0.7575],
  'testLoss': [0.1537771224975586,
   0.09648054838180542,
   0.08775465935468674,
   0.13497887551784515,
   0.07934705168008804],
  'trainAUC': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainAccuracy': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainLoss': [0.00011910752073163167,
   0.00018867047037929296,
   0.00037541292840614915,
   0.0006876284605823457,
   0.0016543453093618155],
  'tuneLoss': [0.2263316661119461,
   0.17281566560268402,
   0.20987460017204285,
   0.1995871663093567,
   0.07609060406684875]},
 {'testAUC': 0.8283179426493141,
  'testAccuracy': 0.7705,
  'testLoss': 0.11046765148639678,
  'trainAUC': 1.0,
  'trainAccuracy': 1.0,
  'trainLoss': 0.000605032937892247,
  'tuneLoss': 0.1769399404525757})

In [97]:
multi_runs(fine_tune=True, K=None, q_head_index=None, q_means_stds=[[0,0,0],[1,1,1]], 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.01, weight_decay=0.001, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.7477879665379665,
   0.7751653575857793,
   0.6778058007566203,
   0.8065151986720615,
   0.7833914053426247],
  'testAccuracy': [0.6925, 0.7025, 0.6775, 0.765, 0.745],
  'testLoss': [0.2816723585128784,
   0.19054630398750305,
   0.186612069606781,
   0.1592748761177063,
   0.1535634547472],
  'trainAUC': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainAccuracy': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainLoss': [5.849573062732816e-05,
   0.00012599444016814232,
   0.00014537507377099246,
   4.419025935931131e-05,
   7.232251664390787e-05],
  'tuneLoss': [0.1408819556236267,
   0.24297289550304413,
   0.27762624621391296,
   0.1575925052165985,
   0.1641690582036972]},
 {'testAUC': 0.7581331457790105,
  'testAccuracy': 0.7165,
  'testLoss': 0.19433381259441376,
  'trainAUC': 1.0,
  'trainAccuracy': 1.0,
  'trainLoss': 8.927560411393642e-05,
  'tuneLoss': 0.1966485321521759})

In [98]:
multi_runs(fine_tune=True, K=None, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=False, init_from_act_index=None,
           learning_rate=0.01, weight_decay=0.001, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.7715556012438642,
   0.774658553076403,
   0.5266279819471309,
   0.7981329290672625,
   0.5574850925829062],
  'testAccuracy': [0.72, 0.705, 0.5825, 0.71, 0.555],
  'testLoss': [0.11522474884986877,
   0.11896015703678131,
   0.25925329327583313,
   0.13024890422821045,
   0.22739072144031525],
  'trainAUC': [1.0, 1.0, 0.988970588235294, 1.0, 1.0],
  'trainAccuracy': [1.0, 1.0, 0.94, 1.0, 0.98],
  'trainLoss': [0.00022527291730511934,
   0.00024238736659754068,
   0.012866455130279064,
   0.0002553285157773644,
   0.005089514888823032],
  'tuneLoss': [0.03174200281500816,
   0.22029440104961395,
   0.3788627088069916,
   0.13941320776939392,
   0.23116329312324524]},
 {'testAUC': 0.6856920315835133,
  'testAccuracy': 0.6545,
  'testLoss': 0.17021556496620177,
  'trainAUC': 0.9977941176470587,
  'trainAccuracy': 0.984,
  'trainLoss': 0.003735791763756424,
  'tuneLoss': 0.20029512271285058})

In [99]:
multi_runs(fine_tune=True, K=None, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.005, weight_decay=0.005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.7548387096774193,
   0.7608515915667631,
   0.8438559705935648,
   0.8534888586672247,
   0.8283984375000001],
  'testAccuracy': [0.68, 0.7175, 0.785, 0.7625, 0.7525],
  'testLoss': [0.16195568442344666,
   0.09815102070569992,
   0.12214992195367813,
   0.0896163359284401,
   0.1074909120798111],
  'trainAUC': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainAccuracy': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainLoss': [0.0005004504346288741,
   0.0010932598961517215,
   0.0005481308326125145,
   0.0006376910605467856,
   0.0006498582661151886],
  'tuneLoss': [0.04782821610569954,
   0.14940868318080902,
   0.20063456892967224,
   0.03472176566720009,
   0.11419814825057983]},
 {'testAUC': 0.8082867136009944,
  'testAccuracy': 0.7395,
  'testLoss': 0.11587277501821518,
  'trainAUC': 1.0,
  'trainAccuracy': 1.0,
  'trainLoss': 0.0006858780980110168,
  'tuneLoss': 0.10935827642679215})

In [101]:
multi_runs(fine_tune=True, K=None, q_head_index=0, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.01, weight_decay=0.005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.7471994884910486,
   0.8140650656814451,
   0.7021952406567791,
   0.7318123441867492,
   0.8180720104854667],
  'testAccuracy': [0.69, 0.7375, 0.6575, 0.6925, 0.785],
  'testLoss': [0.17120319604873657,
   0.1225447803735733,
   0.17691823840141296,
   0.14129996299743652,
   0.10622461140155792],
  'trainAUC': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainAccuracy': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainLoss': [0.00029294859268702567,
   0.0005789602873846889,
   0.0009452882222831249,
   0.0010689892806112766,
   0.0008288358803838491],
  'tuneLoss': [0.13852056860923767,
   0.14755722880363464,
   0.2768683135509491,
   0.1588255614042282,
   0.17423930764198303]},
 {'testAUC': 0.7626688299002977,
  'testAccuracy': 0.7125,
  'testLoss': 0.14363815784454345,
  'trainAUC': 1.0,
  'trainAccuracy': 1.0,
  'trainLoss': 0.0007430044526699931,
  'tuneLoss': 0.17920219600200654})

In [103]:
multi_runs(fine_tune=True, K=None, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.02, weight_decay=0.02, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.7739121430975123,
   0.572289156626506,
   0.6748532591906086,
   0.7508389046850585,
   0.8160444716442268],
  'testAccuracy': [0.725, 0.595, 0.6275, 0.6875, 0.745],
  'testLoss': [0.09092437475919724,
   0.07079415023326874,
   0.06848209351301193,
   0.09792287647724152,
   0.08965844660997391],
  'trainAUC': [1.0, 0.7205882352941176, 0.7754677754677755, 1.0, 1.0],
  'trainAccuracy': [1.0, 0.74, 0.86, 0.98, 1.0],
  'trainLoss': [0.002261457731947303,
   0.05391618609428406,
   0.04400831460952759,
   0.00655113160610199,
   0.0017912012990564108],
  'tuneLoss': [0.11504357308149338,
   0.08367732912302017,
   0.07473348081111908,
   0.07115575671195984,
   0.11498477309942245]},
 {'testAUC': 0.7175875870487823,
  'testAccuracy': 0.6759999999999999,
  'testLoss': 0.08355638831853866,
  'trainAUC': 0.8992112021523786,
  'trainAccuracy': 0.916,
  'trainLoss': 0.02170565826818347,
  'tuneLoss': 0.09191898256540298})

In [104]:
multi_runs(fine_tune=False, K=0.9, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.01, weight_decay=0.001, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.8143466380653234,
   0.7676147793537216,
   0.8176718827815783,
   0.5068503694581281,
   0.8545417400852932],
  'testAccuracy': [0.77, 0.74, 0.77, 0.53, 0.815],
  'testLoss': [0.05556024983525276,
   0.056379470974206924,
   0.054878585040569305,
   0.07284703850746155,
   0.05128539353609085],
  'trainAUC': [0.8633333333333333,
   0.8084415584415585,
   0.8276972624798712,
   0.6499999999999999,
   0.9136029411764707],
  'trainAccuracy': [0.88, 0.8, 0.74, 0.64, 0.86],
  'trainLoss': [0.04745708405971527,
   0.047659922391176224,
   0.05367177352309227,
   0.06543111801147461,
   0.03468288481235504],
  'tuneLoss': [0.07962016016244888,
   0.0642518475651741,
   0.06761405616998672,
   0.07475253194570541,
   0.028424810618162155]},
 {'testAUC': 0.7522050819488089,
  'testAccuracy': 0.7250000000000001,
  'testLoss': 0.05819014757871628,
  'trainAUC': 0.8126150190862468,
  'trainAccuracy': 0.784,
  'trainLoss': 0.049780556559562684,
  'tuneLoss': 0.06293268129229546})

In [105]:
multi_runs(fine_tune=False, K=0.6, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.01, weight_decay=0.001, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.8406857142857143,
   0.6529928735111178,
   0.7702365451388888,
   0.5740233304626947,
   0.5459261519682637],
  'testAccuracy': [0.7875, 0.6775, 0.6725, 0.5675, 0.485],
  'testLoss': [0.05376938357949257,
   0.06315452605485916,
   0.05775335058569908,
   0.07495023310184479,
   0.07004529982805252],
  'trainAUC': [0.9466666666666667,
   0.7000000000000001,
   0.5349999999999999,
   0.7321428571428571,
   0.6666666666666667],
  'trainAccuracy': [0.92, 0.72, 0.6, 0.7, 0.64],
  'trainLoss': [0.029632307589054108,
   0.05787434056401253,
   0.06735801696777344,
   0.06216932460665703,
   0.06394052505493164],
  'tuneLoss': [0.12848210334777832,
   0.07463286817073822,
   0.07181157916784286,
   0.0973539799451828,
   0.08388696610927582]},
 {'testAUC': 0.6767729230733359,
  'testAccuracy': 0.6379999999999999,
  'testLoss': 0.06393455862998962,
  'trainAUC': 0.7160952380952381,
  'trainAccuracy': 0.7160000000000001,
  'trainLoss': 0.05619490295648575,
  'tuneLoss': 0.09123

In [106]:
multi_runs(fine_tune=False, K=0.6, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.001, weight_decay=0.001, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.7000343997248022,
   0.5269282257234064,
   0.8544660724031289,
   0.7313003452243959,
   0.6140092165898617],
  'testAccuracy': [0.6775, 0.5525, 0.805, 0.6175, 0.6525],
  'testLoss': [0.06187906488776207,
   0.07768125832080841,
   0.05388563871383667,
   0.06781795620918274,
   0.06757856160402298],
  'trainAUC': [0.8146167557932263,
   0.8472222222222222,
   0.8940972222222222,
   0.9006410256410257,
   0.784],
  'trainAccuracy': [0.8, 0.82, 0.7, 0.8, 0.68],
  'trainLoss': [0.05098897963762283,
   0.0832400694489479,
   0.04823710024356842,
   0.04425951465964317,
   0.05717984586954117],
  'tuneLoss': [0.07258088886737823,
   0.09840782731771469,
   0.05668112635612488,
   0.07356924563646317,
   0.08751397579908371]},
 {'testAUC': 0.685347651933119,
  'testAccuracy': 0.661,
  'testLoss': 0.06576849594712257,
  'trainAUC': 0.8481154451757392,
  'trainAccuracy': 0.76,
  'trainLoss': 0.0567811019718647,
  'tuneLoss': 0.07775061279535293})

In [107]:
multi_runs(fine_tune=False, K=0.9, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=False, init_from_act_index=None,
           learning_rate=0.005, weight_decay=0.001, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.3784893267651888,
   0.41625041625041626,
   0.2968005266622778,
   0.3443843940738351,
   0.5686321444719276],
  'testAccuracy': [0.565, 0.5225, 0.515, 0.52, 0.6425],
  'testLoss': [0.11786006391048431,
   0.09654093533754349,
   0.10057266056537628,
   0.09719052165746689,
   0.087620310485363],
  'trainAUC': [0.26559714795008915,
   0.46314102564102566,
   0.4935897435897436,
   0.41,
   0.5695238095238095],
  'trainAccuracy': [0.66, 0.5, 0.48, 0.6, 0.7],
  'trainLoss': [0.11307734251022339,
   0.10102222859859467,
   0.09550159424543381,
   0.09159834682941437,
   0.06936497241258621],
  'tuneLoss': [0.15354789793491364,
   0.11735742539167404,
   0.09057482331991196,
   0.1102447584271431,
   0.13559284806251526]},
 {'testAUC': 0.40091136164472907,
  'testAccuracy': 0.553,
  'testLoss': 0.0999568983912468,
  'trainAUC': 0.4403703453409335,
  'trainAccuracy': 0.5880000000000001,
  'trainLoss': 0.09411289691925048,
  'tuneLoss': 0.1214635506272316})

In [108]:
multi_runs(fine_tune=False, K=0.95, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=34,
           learning_rate=0.01, weight_decay=0.0001, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.675113504381797,
   0.8661377275963786,
   0.8394269161512565,
   0.7257984874866811,
   0.7394791666666666],
  'testAccuracy': [0.6575, 0.7875, 0.7825, 0.62, 0.665],
  'testLoss': [0.06970095634460449,
   0.05055497586727142,
   0.05125127732753754,
   0.06653737276792526,
   0.06019798666238785],
  'trainAUC': [0.925,
   0.9472,
   0.9357638888888888,
   0.7985739750445633,
   0.9014778325123153],
  'trainAccuracy': [0.86, 0.9, 0.88, 0.74, 0.76],
  'trainLoss': [0.03679502382874489,
   0.02972467988729477,
   0.030581464990973473,
   0.05469159781932831,
   0.04718099907040596],
  'tuneLoss': [0.08983348309993744,
   0.07408870756626129,
   0.07701148092746735,
   0.05246501415967941,
   0.05968058109283447]},
 {'testAUC': 0.769191160456556,
  'testAccuracy': 0.7025,
  'testLoss': 0.05964851379394531,
  'trainAUC': 0.9016031392891535,
  'trainAccuracy': 0.828,
  'trainLoss': 0.03979475311934948,
  'tuneLoss': 0.070615853369236})

In [109]:
multi_runs(fine_tune=False, K=0.95, q_head_index=0, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=34,
           learning_rate=0.01, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.7694624652455977,
   0.8248337595907927,
   0.8124300681517648,
   0.7965104001689367,
   0.8487487865386689],
  'testAccuracy': [0.61, 0.7675, 0.7375, 0.7475, 0.775],
  'testLoss': [0.11215081810951233,
   0.07271672785282135,
   0.056875914335250854,
   0.05969255790114403,
   0.04754012078046799],
  'trainAUC': [0.962059620596206,
   1.0,
   0.9295238095238095,
   0.9618055555555556,
   0.9833333333333334],
  'trainAccuracy': [0.86, 0.98, 0.92, 0.82, 0.9],
  'trainLoss': [0.028372040018439293,
   0.019243409857153893,
   0.032974474132061005,
   0.03540046885609627,
   0.03246229141950607],
  'tuneLoss': [0.12373261153697968,
   0.060958605259656906,
   0.07308142632246017,
   0.08058549463748932,
   0.060368526726961136]},
 {'testAUC': 0.8103970959391521,
  'testAccuracy': 0.7275,
  'testLoss': 0.06979522779583931,
  'trainAUC': 0.967344463801781,
  'trainAccuracy': 0.8959999999999999,
  'trainLoss': 0.029690536856651305,
  'tuneLoss': 0.07974533289670944})

In [110]:
multi_runs(fine_tune=False, K=0.95, q_head_index=0, q_means_stds=[[0,0,0], [1,1,1]], 
           use_last_linear=True, init_from_act_index=34,
           learning_rate=0.01, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.7526146103640796,
   0.5945896877269427,
   0.7550835148874364,
   0.7966951747439552,
   0.6848419318298836],
  'testAccuracy': [0.7175, 0.54, 0.7, 0.7175, 0.6375],
  'testLoss': [0.06425639241933823,
   0.07532334327697754,
   0.11996649205684662,
   0.10264424979686737,
   0.08122178167104721],
  'trainAUC': [0.8701298701298702,
   0.765625,
   0.9904000000000001,
   0.9461805555555556,
   0.7783251231527093],
  'trainAccuracy': [0.8, 0.76, 0.88, 0.88, 0.68],
  'trainLoss': [0.044709257781505585,
   0.05730096250772476,
   0.018094198778271675,
   0.02974756248295307,
   0.06562159955501556],
  'tuneLoss': [0.08781808614730835,
   0.06935253739356995,
   0.2592705488204956,
   0.07122427970170975,
   0.11129285395145416]},
 {'testAUC': 0.7167649839104594,
  'testAccuracy': 0.6625,
  'testLoss': 0.08868245184421539,
  'trainAUC': 0.8701321097676271,
  'trainAccuracy': 0.8,
  'trainLoss': 0.043094716221094134,
  'tuneLoss': 0.11979166120290756})

In [111]:
multi_runs(fine_tune=False, K=0.5, q_head_index=0, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=34,
           learning_rate=0.01, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.6257382333978078, 0.5, 0.6320043103448275, 0.5, 0.5],
  'testAccuracy': [0.605, 0.5675, 0.58, 0.635, 0.585],
  'testLoss': [0.06631014496088028,
   0.0688510462641716,
   0.07176155596971512,
   0.06588596105575562,
   0.06846720725297928],
  'trainAUC': [0.785,
   0.3557692307692308,
   0.8095238095238095,
   0.4000000000000001,
   0.36111111111111105],
  'trainAccuracy': [0.66, 0.52, 0.72, 0.6, 0.64],
  'trainLoss': [0.062056612223386765,
   0.0692453682422638,
   0.0479121096432209,
   0.06731388717889786,
   0.06535253673791885],
  'tuneLoss': [0.08478225767612457,
   0.08133701235055923,
   0.08145411312580109,
   0.0823788270354271,
   0.0765928253531456]},
 {'testAUC': 0.551548508748527,
  'testAccuracy': 0.5945,
  'testLoss': 0.06825518310070038,
  'trainAUC': 0.5422808302808303,
  'trainAccuracy': 0.628,
  'trainLoss': 0.06237610280513763,
  'tuneLoss': 0.08130900710821151})

In [113]:
multi_runs(fine_tune=False, K=0.5, q_head_index=0, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.01, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.6774869923118738,
   0.8382871667185393,
   0.4694270833333334,
   0.8442786974713479,
   0.8157291666666667],
  'testAccuracy': [0.6775, 0.7375, 0.5725, 0.785, 0.755],
  'testLoss': [0.07251473516225815,
   0.05908678099513054,
   0.07354000210762024,
   0.06079702079296112,
   0.12730401754379272],
  'trainAUC': [1.0, 1.0, 0.8057142857142857, 1.0, 1.0],
  'trainAccuracy': [0.94, 1.0, 0.72, 1.0, 1.0],
  'trainLoss': [0.01984412595629692,
   0.006265793461352587,
   0.05153064429759979,
   0.008897450752556324,
   6.345896690618247e-05],
  'tuneLoss': [0.05665118247270584,
   0.1205177754163742,
   0.07822347432374954,
   0.0429522879421711,
   0.19523631036281586]},
 {'testAUC': 0.7290418213003522,
  'testAccuracy': 0.7055,
  'testLoss': 0.07864851132035255,
  'trainAUC': 0.9611428571428572,
  'trainAccuracy': 0.932,
  'trainLoss': 0.01732029468694236,
  'tuneLoss': 0.09871620610356331})

In [114]:
multi_runs(fine_tune=False, K=0.95, q_head_index=0, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.01, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.820771001150748,
   0.7318933333333333,
   0.8628095722748507,
   0.8583643026066166,
   0.7261219792865362],
  'testAccuracy': [0.75, 0.68, 0.795, 0.7925, 0.6625],
  'testLoss': [0.06421355158090591,
   0.11194369941949844,
   0.06392784416675568,
   0.06583256274461746,
   0.0744662806391716],
  'trainAUC': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainAccuracy': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainLoss': [0.002578325103968382,
   0.0023219643626362085,
   0.0023922764230519533,
   0.0024932799860835075,
   0.004318648017942905],
  'tuneLoss': [0.06408543139696121,
   0.04451474919915199,
   0.10824717581272125,
   0.12606698274612427,
   0.14177194237709045]},
 {'testAUC': 0.7999920377304169,
  'testAccuracy': 0.736,
  'testLoss': 0.07607678771018982,
  'trainAUC': 1.0,
  'trainAccuracy': 1.0,
  'trainLoss': 0.0028208987787365913,
  'tuneLoss': 0.09693725630640984})

In [115]:
multi_runs(fine_tune=False, K=0.95, q_head_index=0, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=34,
           learning_rate=0.1, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.6188574858425324,
   0.6511699507389164,
   0.6255630630630631,
   0.5930911533691609,
   0.8116473749483258],
  'testAccuracy': [0.5825, 0.615, 0.6575, 0.6125, 0.75],
  'testLoss': [0.1295589655637741,
   0.08068183809518814,
   0.06292353570461273,
   0.0890970528125763,
   0.05521814525127411],
  'trainAUC': [1.0, 1.0, 0.8423645320197044, 0.9750445632798573, 1.0],
  'trainAccuracy': [1.0, 0.96, 0.74, 0.92, 0.96],
  'trainLoss': [0.002316945930942893,
   0.021039480343461037,
   0.05347485467791557,
   0.02702312357723713,
   0.013313046656548977],
  'tuneLoss': [0.11564882099628448,
   0.09832269698381424,
   0.08518502861261368,
   0.1050223782658577,
   0.058514554053545]},
 {'testAUC': 0.6600658055923997,
  'testAccuracy': 0.6435000000000001,
  'testLoss': 0.08349590748548508,
  'trainAUC': 0.9634818190599124,
  'trainAccuracy': 0.916,
  'trainLoss': 0.023433490237221122,
  'tuneLoss': 0.09253869578242302})

In [117]:
multi_runs(fine_tune=False, K=0.95, q_head_index=0, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=4,
           learning_rate=0.1, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.4875583616298812,
   0.5603684932953226,
   0.5,
   0.7264018033248802,
   0.5897973259762309],
  'testAccuracy': [0.55, 0.575, 0.3825, 0.7025, 0.6225],
  'testLoss': [0.08827588707208633,
   0.07162626087665558,
   6.174999713897705,
   0.09538057446479797,
   0.08188704401254654],
  'trainAUC': [0.9500805152979066, 0.9622331691297208, 0.5, 1.0, 0.96],
  'trainAccuracy': [0.84, 0.9, 0.38, 0.98, 0.94],
  'trainLoss': [0.03287625312805176,
   0.033768948167562485,
   6.199999809265137,
   0.008226421661674976,
   0.023387528955936432],
  'tuneLoss': [0.15499870479106903,
   0.06846723705530167,
   8.399999618530273,
   0.11436572670936584,
   0.12820696830749512]},
 {'testAUC': 0.572825196845263,
  'testAccuracy': 0.5665,
  'testLoss': 1.3024338960647583,
  'trainAUC': 0.8744627368855256,
  'trainAccuracy': 0.808,
  'trainLoss': 1.2596517922356725,
  'tuneLoss': 1.773207651078701})

In [118]:
multi_runs(fine_tune=False, K=0.99, q_head_index=0, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=34,
           learning_rate=0.1, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.8329141311697092,
   0.6919871038767005,
   0.5624496478598716,
   0.5552075993632402,
   0.6513831696758525],
  'testAccuracy': [0.7425, 0.6, 0.545, 0.5, 0.6475],
  'testLoss': [0.08803754299879074,
   0.10880590230226517,
   0.08032785356044769,
   0.11284294724464417,
   0.10569707304239273],
  'trainAUC': [1.0, 1.0, 0.7869269949066213, 0.9885057471264367, 1.0],
  'trainAccuracy': [1.0, 1.0, 0.74, 0.96, 0.98],
  'trainLoss': [0.0011589666828513145,
   0.007331452798098326,
   0.0491270050406456,
   0.01544104516506195,
   0.007958120666444302],
  'tuneLoss': [0.07976274192333221,
   0.1272801011800766,
   0.10634607821702957,
   0.0972849503159523,
   0.09030703455209732]},
 {'testAUC': 0.6587883303890748,
  'testAccuracy': 0.607,
  'testLoss': 0.0991422638297081,
  'trainAUC': 0.9550865484066116,
  'trainAccuracy': 0.9359999999999999,
  'trainLoss': 0.0162033180706203,
  'tuneLoss': 0.1001961812376976})

In [119]:
multi_runs(fine_tune=True, K=None, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.1, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.4947465437788018,
   0.5,
   0.5,
   0.513583339857512,
   0.4837848932676519],
  'testAccuracy': [0.61, 0.58, 0.56, 0.6025, 0.5075],
  'testLoss': [0.0674053356051445,
   4.199999809265137,
   4.400000095367432,
   0.06802979856729507,
   0.6365293264389038],
  'trainAUC': [0.3863636363636364, 0.5, 0.5, 0.3872785829307569, 1.0],
  'trainAccuracy': [0.56, 0.58, 0.7, 0.54, 1.0],
  'trainLoss': [0.0685952827334404,
   4.199999809265137,
   3.0,
   0.06899969279766083,
   6.407623004633933e-05],
  'tuneLoss': [0.08115255832672119,
   5.199999809265137,
   2.3999998569488525,
   0.07682645320892334,
   1.05329430103302]},
 {'testAUC': 0.4984229553807931,
  'testAccuracy': 0.572,
  'testLoss': 1.8743928730487824,
  'trainAUC': 0.5547284438588787,
  'trainAccuracy': 0.6759999999999999,
  'trainLoss': 1.4675317722052568,
  'tuneLoss': 1.7622545957565308})

In [121]:
multi_runs(fine_tune=True, K=None, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.1, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.5069774613862404,
   0.47904696238412614,
   0.5121325183244688,
   0.6140483183826837,
   0.4435182980074621],
  'testAccuracy': [0.585, 0.5725, 0.5175, 0.47, 0.48],
  'testLoss': [0.07276777923107147,
   0.06838060170412064,
   0.3830075263977051,
   0.1075957641005516,
   0.07727602124214172],
  'trainAUC': [0.5865874363327674,
   0.5714285714285714,
   1.0,
   0.8060897435897436,
   0.563301282051282],
  'trainAccuracy': [0.56, 0.58, 0.98, 0.7, 0.64],
  'trainLoss': [0.0628318265080452,
   0.06780815124511719,
   0.0020031295716762543,
   0.05002676323056221,
   0.06529764086008072],
  'tuneLoss': [0.08729052543640137,
   0.07584648579359055,
   0.25307905673980713,
   0.1024998351931572,
   0.07730937749147415]},
 {'testAUC': 0.5111447116969963,
  'testAccuracy': 0.5249999999999999,
  'testLoss': 0.1418055385351181,
  'trainAUC': 0.7054814066804729,
  'trainAccuracy': 0.6920000000000001,
  'trainLoss': 0.049593502283096315,
  'tuneLoss': 0.11920505613088608})

In [126]:
multi_runs(fine_tune=True, K=None, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.01, weight_decay=0.0001, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.8031373767181795,
   0.8275679117147707,
   0.765932881773399,
   0.8180254557963537,
   0.8527954706298656],
  'testAccuracy': [0.74, 0.78, 0.72, 0.7375, 0.785],
  'testLoss': [0.12222526222467422,
   0.14976026117801666,
   0.1272013783454895,
   0.11517591029405594,
   0.10829062014818192],
  'trainAUC': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainAccuracy': [1.0, 1.0, 1.0, 1.0, 1.0],
  'trainLoss': [3.612661748775281e-05,
   0.0001716889819363132,
   4.5609442167915404e-05,
   3.601723801693879e-05,
   2.9387227186816745e-05],
  'tuneLoss': [0.12290343642234802,
   0.18294183909893036,
   0.19423751533031464,
   0.05970530956983566,
   0.11969821900129318]},
 {'testAUC': 0.8134918193265136,
  'testAccuracy': 0.7525000000000001,
  'testLoss': 0.12453068643808365,
  'trainAUC': 1.0,
  'trainAccuracy': 1.0,
  'trainLoss': 6.376590135914739e-05,
  'tuneLoss': 0.13589726388454437})

In [130]:
# Temporarily added dropout in the code (will not be able to reproduce)
multi_runs(fine_tune=True, K=None, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=None,
           learning_rate=0.005, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.7992793955526876,
   0.7850740711189127,
   0.8039541759053955,
   0.8336572890025576,
   0.8473712934327817],
  'testAccuracy': [0.7775, 0.7175, 0.76, 0.7775, 0.765],
  'testLoss': [0.13031229376792908,
   0.10789167881011963,
   0.1365920454263687,
   0.20121319591999054,
   0.16124093532562256],
  'trainAUC': [1.0, 0.9983896940418681, 1.0, 1.0, 1.0],
  'trainAccuracy': [1.0, 0.96, 1.0, 1.0, 1.0],
  'trainLoss': [0.0009359754621982574,
   0.007983668707311153,
   0.0012027123011648655,
   0.0010768078500404954,
   0.0030503899324685335],
  'tuneLoss': [0.09233274310827255,
   0.054200902581214905,
   0.19584806263446808,
   0.30205732583999634,
   0.21459732949733734]},
 {'testAUC': 0.813867245002467,
  'testAccuracy': 0.7595,
  'testLoss': 0.1474500298500061,
  'trainAUC': 0.9996779388083736,
  'trainAccuracy': 0.992,
  'trainLoss': 0.002849910850636661,
  'tuneLoss': 0.17180727273225785})

In [131]:
# Temporarily added dropout in the code (will not be able to reproduce)
multi_runs(fine_tune=True, K=None, q_head_index=None, q_means_stds=q_mu_s, 
           use_last_linear=True, init_from_act_index=34,
           learning_rate=0.005, weight_decay=0.0005, plots=False)

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5


({'testAUC': [0.702785500978272,
   0.8599755710907248,
   0.7906640876853643,
   0.7534694812966312,
   0.7853019595661516],
  'testAccuracy': [0.645, 0.7675, 0.7325, 0.6975, 0.72],
  'testLoss': [0.13485755026340485,
   0.11228759586811066,
   0.1193864494562149,
   0.14743559062480927,
   0.10943439602851868],
  'trainAUC': [1.0, 1.0, 1.0, 1.0, 0.9949066213921902],
  'trainAccuracy': [0.98, 1.0, 0.98, 1.0, 0.98],
  'trainLoss': [0.004431566689163446,
   0.001344757853075862,
   0.0034431852400302887,
   0.0007841906044632196,
   0.009738054126501083],
  'tuneLoss': [0.0951336994767189,
   0.023204084485769272,
   0.07031336426734924,
   0.22040671110153198,
   0.08768835663795471]},
 {'testAUC': 0.7784393201234288,
  'testAccuracy': 0.7125,
  'testLoss': 0.12468031644821168,
  'trainAUC': 0.9989813242784381,
  'trainAccuracy': 0.9879999999999999,
  'trainLoss': 0.00394835090264678,
  'tuneLoss': 0.09934924319386482})

### Results

All runs use the following parameters

```
{'batch_size': 10,
 'num_epochs': 100,
 'num_run': 5,
 'num_train': 50,
 'num_tune': 25,
 'num_val': 400,
 'use_mnist': False,
 'use_qnet_weights': True,
 'val_batch_size': 10}
 ```

| fine_tune      | q_head_index | norm_q_means | use_last_linear | init_from_act_index | K     | Optimiser | learning_rate | momentum | weight_decay | testAUC |
|----------------|--------------|--------------|-----------------|---------------------|-------|-----------|---------------|----------|--------------|---------|
| False          | None (fc3)   | True         | True            | Activation 34       | ~0.99 | Adam      | 0.000         | None     | 0.000        | 0.83    |
| True           | None (fc3)   | True         | True            | None                | 1     | Adam      | 0.01          | None     | 0.001        | 0.83    |
| True           | None (fc3)   | False        | True            | None                | 1     | Adam      | 0.01          | None     | 0.001        | 0.76    |
| True           | None (fc3)   | True         | False           | None                | 1     | Adam      | 0.01          | None     | 0.001        | 0.68    |
| True           | None (fc3)   | True         | False           | None                | 1     | Adam      | 0.005         | None     | 0.005        | 0.81    |
| True           | 0            | True         | True            | None                | 1     | Adam      | 0.01          | None     | 0.005        | 0.76    |
| True           | None (fc3)   | True         | True            | None                | 1     | Adam      | 0.1           | None     | 0.001        | 0.51    |
| True           | None (fc3)   | True         | True            | None                | 1     | Adam      | 0.02          | None     | 0.02         | 0.71    |
| False          | None (fc3)   | True         | True            | None                | 0.9   | Adam      | 0.01          | None     | 0.001        | 0.75    |
| False          | None (fc3)   | True         | True            | None                | 0.6   | Adam      | 0.01          | None     | 0.001        | 0.67    |
| False          | None (fc3)   | True         | True            | None                | 0.6   | Adam      | 0.001         | None     | 0.001        | 0.68    |
| False          | None (fc3)   | True         | False           | None                | 0.9   | Adam      | 0.005         | None     | 0.001        | 0.4     |
| False          | None (fc3)   | True         | True            | Activation 34       | 0.95  | Adam      | 0.01          | None     | 0.0001       | 0.77    |
| False          | 0            | True         | True            | Activation 34       | 0.95  | Adam      | 0.01          | None     | 0.0005       | 0.81    |
| False          | 0            | False        | True            | Activation 34       | 0.95  | Adam      | 0.01          | None     | 0.0005       | 0.71    |
| False          | 0            | True         | True            | Activation 34       | 0.5   | Adam      | 0.01          | None     | 0.0005       | 0.55    |
| False          | 0            | True         | True            | None                | 0.5   | SDG       | 0.1           | 0.9      | 0.0005       | 0.73    |
| False          | 0            | True         | True            | None                | 0.95  | SDG       | 0.1           | 0.9      | 0.0005       | 0.79    |
| False          | 0            | True         | True            | Activation 34       | 0.95  | SDG       | 0.1           | 0.9      | 0.0005       | 0.66    |
| False          | 0            | True         | True            | Activation 4        | 0.95  | SDG       | 0.1           | 0.9      | 0.0005       | 0.57    |
| False          | 0            | True         | True            | Activation 34       | 0.99  | SDG       | 0.1           | 0.9      | 0.0005       | 0.66    |
| True           | None (fc3)   | True         | True            | None                | 1     | SDG       | 0.1           | 0.9      | 0.0005       | 0.5     |
| True           | None (fc3)   | True         | True            | None                | 1     | Adam      | 0.1           | None     | 0.0005       | 0.51    |
| True w/DropOut | None (fc3)   | True         | True            | None                | 1     | Adam      | 0.005         | None     | 0.0005       | 0.81    |
| True w/DropOut | None (fc3)   | True         | True            | Activation 34       | 1     | Adam      | 0.005         | None     | 0.005        | 0.78    |

## Hyperparameters Tuning (not used)

In [0]:
def multi_runs_tune(config, reporter):

    averaged_results = {}
    tuning_losses = []
    for run_ix in range(params['num_run']):
        if params['fine_tune']:
            K = None
        else:
            K = config['k']

        model = PrefQNet(fine_tune=params['fine_tune'], k=K, 
                         q_head_index=None, q_means_stds=q_mu_s, use_last_linear=True, init_from_act_index=34).to(device)

        if params['use_qnet_weights']:
            load_weights(model)
        
        results = run_model(model, K, config['lr'], config['decay'], num_epochs=params['num_epochs'])

        if reporter is not None:  # Hyperp-tuning pass
            tuning_losses.append(results['tuneLoss'][-1])
            reporter(timesteps_total=run_ix, mean_loss=sum(tuning_losses)/len(tuning_losses))

def launch_tune():

    space = {
        "k": (0.05, 0.95), 
        "lr": (0.001, 0.1), 
        'decay': (0.0001, 0.05)
    }

    if params['fine_tune']:
        space['k'] = (1, 1)

    config = {"num_samples": params['num_tune_iters'], "stop": {"timesteps_total": params['num_run']}}

    algo = BayesOptSearch(space, metric="mean_loss", mode="min", utility_kwargs={
        "kind": "ucb", "kappa": 2.5, "xi": 0.0})

    scheduler = AsyncHyperBandScheduler(metric="mean_loss", mode="min")

    return tune.run(multi_runs_tune, resources_per_trial={'gpu': 1, 'cpu': 2}, verbose=1, 
                    name="tune_exp", search_alg=algo, scheduler=scheduler, **config)