# Find RL Agent Subnetworks using Supermasks Techniques

---




## SetUp & Imports

In [1]:
!git clone https://github.com/arunraja-hub/Preference_Extraction.git

Cloning into 'Preference_Extraction'...
remote: Enumerating objects: 201, done.[K
remote: Counting objects: 100% (201/201), done.[K
remote: Compressing objects: 100% (142/142), done.[K
remote: Total 1201 (delta 135), reused 104 (delta 59), pack-reused 1000[K
Receiving objects: 100% (1201/1201), 38.18 MiB | 10.96 MiB/s, done.
Resolving deltas: 100% (449/449), done.


In [2]:
# Check what GPU is active
!nvidia-smi

Wed Sep 30 07:52:16 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.05    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P8    30W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
try:
  %tensorflow_version 2.x
except Exception:
  pass

In [1]:
from __future__ import print_function
import argparse
import os
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import label_binarize
import joblib

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as TF
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
import torch.autograd as autograd
from torchsummary import summary

from sklearn.utils import shuffle
import tensorflow as tf
import concurrent.futures
import itertools
import os
import random
import sys
import time
import re
import io
import itertools
import sys

sys.path.append('Preference_Extraction')
from extractors.data_getter import get_data_from_folder

In [2]:
PARAMS = {}
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
device

device(type='cpu')

## Load Data

In [3]:
# Data source options, choices: 'mnist', 'grid' and 'doom'
PARAMS['env'] = 'grid' 

In [42]:
# Load doom data from GCP storage, run only if env == 'doom'
# Requires GCP authentication
# Change last line to load different experience data
from google.colab import auth
auth.authenticate_user()
!gsutil cp gs://pref-extr-data/agentv29/experienceData.pkl experienceData.pkl

Copying gs://pref-extr-data/agentv29/experienceData.pkl...
- [1 files][  1.3 GiB/  1.3 GiB]   46.8 MiB/s                                   
Operation completed over 1 objects/1.3 GiB.                                      


In [7]:
# Data Loader
xs = []
ys = []

if PARAMS['env'] == 'grid':
    all_raw_data = get_data_from_folder("data/simple_env_1/")
    for data in all_raw_data:
        for i in range(data.observation.shape[0]):
            x = np.copy(data.observation[i])
            y = (data.policy_info['satisfaction'].as_list()[i] > -6).astype(int)

            xs.append(x)
            ys.append(y)

    ys = np.array(ys).astype(int)
    xs = np.rollaxis(np.array(xs), 3, 1) # Torch wants channel-first

elif PARAMS['env'] == 'doom':
    all_raw_data = joblib.load('/content/experienceData.pkl')
    for data in all_raw_data:
        for i in range(data.observation.shape[0]):
            x = np.copy(data.observation[i])
            label_object = data.policy_info['satisfaction'][i]
            if len(label_object) == 0: # When label is empty, i.e. human is dead, skip frame
                continue
            else:
                y = label_object['object_angle'] < 90 or label_object['object_angle'] > 270

            xs.append(x)
            ys.append(y)
    
    ys = np.array(ys).astype(int)
    xs = np.rollaxis(np.array(xs), 3, 1) # Torch wants channel-first

elif PARAMS['env'] == 'mnist':
    mnist_choice = 3 # Change this value to do binary classification against other mnist digit
    (mnist_xs, mnist_ys), _ = tf.keras.datasets.mnist.load_data()
    for ix, label in enumerate(mnist_ys):
        x = np.copy(mnist_xs[ix])
        y = (label == mnist_choice).astype(int)

        xs.append(np.array([x]))
        ys.append(y)

    ys = np.array(ys).astype(int)
    xs = np.array(xs) # Mnist has only one channel
    del mnist_xs, mnist_ys

else:
    print('Invalid enviroment choice!!')

print("xs", xs.shape, "ys", ys.shape)
print("ys 1", np.sum(ys))

xs (23750, 5, 14, 16) ys (23750,)
ys 1 9569


In [154]:
# Rebalancing data to minority class
points = xs
labels = ys

# indexes of 1s and 0s
indexes1 = [i for i in range(len(points)) if labels[i] == 1]
indexes0 = [i for i in range(len(points)) if labels[i] == 0]

# separate 0s and 1s
x0, x1, y0, y1 = points[indexes0], points[indexes1], labels[indexes0], labels[indexes1]

minority_points, minority_labels = x1, y1  # points and labels for the minority class
majority_points, majority_labels = x0, y0  # points and labels for the majority class

# get a random permutation of indexes of the majority that includes a number of indexes equal to the minority
sample_ind = np.random.permutation(len(majority_labels))[:len(minority_labels)]

# subsample the majority
majority_points, majority_labels = majority_points[sample_ind], majority_labels[sample_ind]

# concat the minority and the sub-sampled majority
xs = np.concatenate((majority_points, minority_points))
ys = np.concatenate((majority_labels, minority_labels))

del points, labels, x0, x1, y0, y1
del minority_points, majority_points, minority_labels, majority_labels,

print("xs", xs.shape, "ys", ys.shape)
print("ys 1", np.sum(ys))
xs, ys = shuffle(xs, ys)

xs (3656, 6, 60, 100) ys (3656,)
ys 1 1828


## Nets and Subnets Architectures

In [191]:
"""
    Original code from What's hidden in a randomly weighted neural network? paper
    Implemented at https://github.com/allenai/hidden-networks
"""

class GetSubnet(autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Get the supermask by sorting the scores and using the top k%
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())

        # flat_out and out access the same memory.
        flat_out = out.flatten()
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1

        return out

    @staticmethod
    def backward(ctx, g):
        # send the gradient g straight-through on the backward pass.
        return g, None

class SupermaskConv(nn.Conv2d):
    def __init__(self, *args, k, scores_init='kaiming_uniform', **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k
        self.scores_init = scores_init

        # initialize the scores
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        if self.scores_init == 'kaiming_normal':
          nn.init.kaiming_normal_(self.scores)
        elif self.scores_init == 'kaiming_uniform':
          nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))
        elif self.scores_init == 'xavier_normal':
          nn.init.xavier_normal_(self.scores)
        elif self.scores_init == 'best_activation':
          nn.init.ones_(self.scores)
        else:
          nn.init.uniform_(self.scores)
        
        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False

    def forward(self, x):
        subnet = GetSubnet.apply(self.scores.abs(), self.k)
        w = self.weight * subnet
        x = F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        return x

class SupermaskLinear(nn.Linear):
    def __init__(self, *args, k, scores_init='kaiming_uniform', **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k
        self.scores_init = scores_init

        # initialize the scores and weights
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        if self.scores_init == 'kaiming_normal':
          nn.init.kaiming_normal_(self.scores)
        elif self.scores_init == 'kaiming_uniform':
          nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))
        elif self.scores_init == 'xavier_normal':
          nn.init.xavier_normal_(self.scores)
        elif self.scores_init == 'best_activation':
          nn.init.ones_(self.scores)
        else:
          nn.init.uniform_(self.scores)

        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False

    def forward(self, x):
        subnet = GetSubnet.apply(self.scores.abs(), self.k)
        w = self.weight * subnet
        return F.linear(x, w, self.bias)
        return x

# NOTE: not used here but we use NON-AFFINE Normalization!
# So there is no learned parameters for your nomralization layer.
class NonAffineBatchNorm(nn.BatchNorm2d):
    def __init__(self, dim):
        super(NonAffineBatchNorm, self).__init__(dim, affine=False)

In [156]:
class DQNMaskNet(nn.Module):
    """
    Class for Supermask Networks for DQN Agents
    When picking an agent with a different architecture
    You need to change the specs of this class as well (e.g. add a layer)
    If q_head_index is None, this uses a linear model on the normalized q outputs
    otherwise, it gets the Q head with the specified index.
    """ 
    def __init__(self, fine_tune, k, q_head_index, use_last_linear,
                 means_stds=[[0, 0, 0], [1, 1, 1]], init_from_act_index=None):
        super(DQNMaskNet, self).__init__()
        
        if not PARAMS['env'] == 'mnist':
            channels_in = 5
            flattened_shape = 960
        else:
            channels_in = 1
            flattened_shape = 4608

        if fine_tune:
            conv_layer = nn.Conv2d
            dense_layer = nn.Linear
            additional_args = {}
            init_from_act_index = None
        else:
            conv_layer = SupermaskConv
            dense_layer = SupermaskLinear
            additional_args = {'k': k}
            if init_from_act_index is not None:
                additional_args['scores_init'] = 'best_activation'
        
        self.conv1 = conv_layer(in_channels=channels_in, out_channels=16, 
                                kernel_size=3, stride=1, bias=True, **additional_args)
        self.conv2 = conv_layer(in_channels=16, out_channels=32, 
                                kernel_size=3, stride=2, bias=True, **additional_args)
        self.fc1 = dense_layer(in_features=flattened_shape, out_features=64, 
                               bias=True, **additional_args)
        self.fc2 = dense_layer(in_features=64, out_features=3, bias=True, **additional_args)
        
        if init_from_act_index is not None:
            init_scores = np.zeros((3, 64))
            init_scores[:, init_from_act_index] = 1.0
            self.fc2.scores.data = torch.from_numpy(init_scores).float()

        self.fc3 = dense_layer(in_features=3, out_features=1, bias=True, **additional_args)
        self.linear = nn.Linear(1, 1, bias=True)

        self.qix = q_head_index
        self.mu_s = means_stds
        self.use_last_linear = use_last_linear

    def fwd_conv1(self, x):
        x = self.conv1(x)
        return F.relu(x)

    def fwd_conv2(self, x):
        x = self.fwd_conv1(x)
        x = self.conv2(x)
        return F.relu(x)

    def fwd_flat(self, x):
        x = self.fwd_conv2(x)
        # Pre-flattening transpose is necessary for TF-Torch conversion
        return torch.flatten(torch.transpose(x, 1, 3), 1)

    def fwd_fc1(self, x):
        x = self.fwd_flat(x)
        x = self.fc1(x)
        return F.relu(x)
    
    def fwd_fc2(self, x):
        x = self.fwd_fc1(x)
        return self.fc2(x)

    def forward(self, x):
        x = self.fwd_fc2(x)

        x -= torch.tensor(self.mu_s[0], device=device)
        x /= torch.tensor(self.mu_s[1], device=device)

        if self.qix == None:
          x = self.fc3(x)
        else:
          x = x[: ,self.qix:self.qix+1]

        if self.use_last_linear:
          x = self.linear(x)

        x = torch.sigmoid(x)
        return x.flatten()

    def layer_to_norm(self, x):
        return self.fwd_fc2(x)

    def load_weights(self, original_weights):
        if not PARAMS['env'] == 'mnist':
            self.conv1.weight.data = torch.from_numpy(np.transpose(original_weights[0]))
            self.fc1.weight.data = torch.from_numpy(np.transpose(original_weights[4]))
        else:
            self.conv1.weight.data = torch.from_numpy(np.transpose(original_weights[0][:,:,:1,:]))
            mnist_flt_weights = np.random.rand(64, 4608)
            mnist_flt_weights[:, :original_weights[4].shape[0]] = np.transpose(original_weights[4])
            mnist_flt_weights = mnist_flt_weights.astype(np.float32)
            self.fc1.weight.data = torch.from_numpy(mnist_flt_weights)

        self.conv1.bias.data = torch.from_numpy(original_weights[1])
        self.conv2.weight.data = torch.from_numpy(np.transpose(original_weights[2]))
        self.conv2.bias.data = torch.from_numpy(original_weights[3])
        self.fc1.bias.data = torch.from_numpy(original_weights[5])
        self.fc2.weight.data = torch.from_numpy(np.transpose(original_weights[6]))
        self.fc2.bias.data = torch.from_numpy(original_weights[7])
        self.fc3.weight.data = torch.from_numpy(np.ones(shape=[1,3], dtype=np.float32))
        self.fc3.bias.data = torch.from_numpy(np.zeros(shape=[1], dtype=np.float32))
        self.to(device)

In [157]:
class PPOMaskNet(nn.Module):
    """
    Class for Supermask Networks for PPO Agents
    When picking an agent with a different architecture
    You need to change the specs of this class as well (e.g. add a layer)
    """ 
    def __init__(self, fine_tune, k, use_last_linear, means_stds=[[0, 0, 0], [1, 1, 1]],
                 dropout=None):
        super(PPOMaskNet, self).__init__()
        
        if fine_tune:
            conv_layer = nn.Conv2d
            dense_layer = nn.Linear
            additional_args = {}
            init_from_act_index = None
        else:
            conv_layer = SupermaskConv
            dense_layer = SupermaskLinear
            additional_args = {'k': k}
        
        self.fc1 = dense_layer(in_features=36000, out_features=200, bias=True, **additional_args)
        self.fc2 = dense_layer(in_features=200, out_features=100, bias=True, **additional_args)
        #self.fc3 = dense_layer(in_features=100, out_features=4, bias=True, **additional_args)
        self.drop = None
        if dropout is not None:
            self.drop = nn.Dropout(p=dropout)
        self.linear = nn.Linear(100, 1, bias=True)
        self.use_last_linear = use_last_linear

        self.mu_s = means_stds

    def fwd_flat(self, x):
        # Pre-flattening transpose is necessary for TF-Torch conversion
        return torch.flatten(torch.transpose(x, 1, 3), 1)

    def fwd_fc1(self, x):
        x = self.fwd_flat(x)
        x = self.fc1(x)
        return F.relu(x)
    
    def fwd_fc2(self, x):
        x = self.fwd_fc1(x)
        x = self.fc2(x)
        return F.relu(x)

    def fwd_fc3(self, x):
        x = self.fwd_fc2(x)
        return self.fc3(x)

    def fwd_fc4(self, x):
        x = self.fwd_fc3(x)
        return self.fc4(x)

    def forward(self, x):
        x = self.fwd_fc2(x)

        x -= torch.tensor(self.mu_s[0], device=device)
        x /= torch.tensor(self.mu_s[1], device=device)

        if self.use_last_linear:
            if self.drop is not None:
                x = self.drop(x)
            x = self.linear(x)

        x = torch.sigmoid(x)
        return x.flatten()

    def layer_to_norm(self, x):
        # Change the function below when you want to do activation
        # normalizaiton on a different layer of the agent
        return self.fwd_fc2(x)

    def load_weights(self, original_weights):
        self.fc1.weight.data = torch.from_numpy(np.transpose(original_weights[0]))
        self.fc1.bias.data = torch.from_numpy(original_weights[1])
        self.fc2.weight.data = torch.from_numpy(np.transpose(original_weights[2]))
        self.fc2.bias.data = torch.from_numpy(original_weights[3])
        #self.fc3.weight.data = torch.from_numpy(np.transpose(original_weights[4]))
        #self.fc3.bias.data = torch.from_numpy(original_weights[5])
        self.to(device)

In [174]:
class RandomMaskNet(nn.Module):
    """
    Big Randomly Weighted Neural Net
    Used as a baseline for supermask technique
    Copies original Net class from paper
    See https://github.com/allenai/hidden-networks/blob/master/simple_mnist_example.py
    """ 
    def __init__(self, k, dropout):
        super(RandomMaskNet, self).__init__()
        if PARAMS['env'] == 'doom':
            channels_in = 6
            flattened_shape = 86016
        elif PARAMS['env'] == 'grid':
            channels_in = 5
            flattened_shape = 960
        elif PARAMS['env'] == 'mnist':
            channels_in = 1
            flattened_shape = 4608

        self.conv1 = SupermaskConv(in_channels=channels_in, out_channels=32, 
                                   kernel_size=3, stride=1, bias=False, k=k)
        self.conv2 = SupermaskConv(in_channels=32, out_channels=64, 
                                   kernel_size=3, stride=1, bias=False, k=k)
        self.dropout1 = nn.Dropout2d(dropout)
        self.dropout2 = nn.Dropout2d(dropout)
        self.fc1 = SupermaskLinear(in_features=flattened_shape, out_features=128, bias=False, k=k)
        self.fc2 = SupermaskLinear(in_features=128, out_features=1, bias=False, k=k)

        nn.init.kaiming_normal_(self.conv1.weight, mode="fan_in", nonlinearity="relu")
        nn.init.kaiming_normal_(self.conv2.weight, mode="fan_in", nonlinearity="relu")
        nn.init.kaiming_normal_(self.fc1.weight, mode="fan_in", nonlinearity="relu")
        nn.init.kaiming_normal_(self.fc2.weight, mode="fan_in", nonlinearity="relu")

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x.flatten()

## Loading Weights

In [15]:
!gsutil cp gs://pref-extr-data/agentv29/actorNet.keras actorNet.keras

Copying gs://pref-extr-data/agentv29/actorNet.keras...
\ [1 files][ 27.6 MiB/ 27.6 MiB]                                                
Operation completed over 1 objects/27.6 MiB.                                     


In [16]:
# Shortlisting different agents
dqn_grid_path = "Preference_Extraction/saved_model2"
ppo_doom_path = "/content/actorNet.keras"

In [186]:
# Choose agent to use (ppo, dqn or random)
PARAMS['agent_type'] = 'ppo'

if PARAMS['agent_type'] == 'ppo':
    model_path = ppo_doom_path
elif PARAMS['agent_type'] == 'dqn':
    model_path = dqn_grid_path

In [187]:
restored_model = tf.keras.models.load_model(model_path)
restored_model.summary()
original_weights = restored_model.get_weights()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (1, 36000)                0         
_________________________________________________________________
dense (Dense)                (1, 200)                  7200200   
_________________________________________________________________
dense_1 (Dense)              (1, 100)                  20100     
_________________________________________________________________
logits (Dense)               (1, 4)                    404       
Total params: 7,220,704
Trainable params: 7,220,704
Non-trainable params: 0
_________________________________________________________________


In [105]:
if PARAMS['agent_type'] == 'dqn':
    test_model = DQNMaskNet(k=1, fine_tune=False, q_head_index=None,
                            use_last_linear=True).to(device)
elif PARAMS['agent_type'] == 'ppo':
    test_model = PPOMaskNet(k=1, fine_tune=False, use_last_linear=True).to(device)

test_model.load_weights(original_weights)

In [108]:
# Testing that weigths loaded property
# Change this list to test different torch models
torch_model_layers = [test_model.fwd_conv1,
                    test_model.fwd_conv2,
                    test_model.fwd_fc1, 
                    test_model.fwd_fc2]
                    #test_model.fwd_fc3]

# Does not work of mnist because original agents don't have the shape!
if not PARAMS['env'] == 'mnist':

    def check_same(torch_layer, tf_layer):
        torch_out = np.transpose(torch_layer(single_observation_torch).detach().cpu().numpy())
        torch_out = torch_out.reshape(torch_out.shape[:-1])
        tf_out = tf_layer(np.array([np.transpose(single_observation)]))[0].numpy()
        np.testing.assert_allclose(torch_out, tf_out, rtol=.1, atol=5)  

    for i in range(100):
        single_observation = xs[i]
        single_observation_torch = torch.Tensor(np.array([xs[i]]))
        single_observation_torch = single_observation_torch.to(device)
        
        index_shift = 0
        for ix, original_lyr in enumerate(restored_model.layers):
            if ix < len(torch_model_layers):
                if original_lyr.name == 'flatten':
                    index_shift = 1
                else:
                    tf_sub_model = tf.keras.models.Model(inputs=restored_model.input, outputs=original_lyr.output)
                    check_same(torch_model_layers[ix-index_shift], tf_sub_model)

### Data getters

In [145]:
PARAMS['num_train'] = 500
PARAMS['num_val'] = 500
PARAMS['batch_size'] = 128

def get_data_sample(xs=None, ys=None):
    xs, ys = shuffle(xs, ys)
    
    train_split = PARAMS['num_train']
    test_split = PARAMS['num_train']+PARAMS['num_val']

    tr_data_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.Tensor(xs[:train_split]), torch.Tensor(ys[:train_split])),
        batch_size=PARAMS['batch_size'])
    
    val_data_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.Tensor(xs[train_split:test_split]), torch.Tensor(ys[train_split:test_split])),
        batch_size=PARAMS['batch_size'])
    
    return tr_data_loader, val_data_loader, xs[:train_split]

def get_heads_mu_and_sigma(model, obs):
    
    model.eval()
    obs_tensor = torch.Tensor(obs)
    obs_tensor = obs_tensor.to(device)
    heads_values = model.layer_to_norm(obs_tensor).detach().cpu().numpy()

    mu = heads_values.mean(axis=0)
    s = heads_values.std(axis=0)
    s[s == 0] = 1
    
    return np.array([mu, s])

### Single run train/test methods

In [190]:
"""
    Train/Test function for Randomly Weighted Hidden Neural Networks Techniques
    Adapted from https://github.com/NesterukSergey/hidden-networks/blob/master/demos/mnist.ipynb
"""

def compute_metrics(predictions, true_labels):
    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    accuracy = np.sum(np.equal((predictions > 0.5).astype(int), true_labels)) / len(true_labels)
    fpr, tpr, thresholds = metrics.roc_curve(true_labels, predictions)
    auc = metrics.auc(fpr, tpr)
    return accuracy, auc

def train(model, device, train_loader, optimizer, criterion):
    
    train_loss = 0
    true_labels = []
    predictions = [] # labels

    model.train()

    for data, target in itertools.islice(train_loader, PARAMS['num_train']):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss
        predictions.extend(output.detach().cpu().numpy())
        true_labels.extend(target.detach().cpu().numpy())
    
    train_loss /= len(train_loader.dataset)
    accuracy, auc = compute_metrics(predictions, true_labels)

    return train_loss.item(), accuracy, auc

def test(model, device, test_loader, criterion):
    true_labels = []
    predictions = [] # labels

    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, target in itertools.islice(test_loader, PARAMS['num_val']):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target)
            predictions.extend(output.detach().cpu().numpy())
            true_labels.extend(target.detach().cpu().numpy())
    
    test_loss /= len(test_loader.dataset)
    accuracy, auc = compute_metrics(predictions, true_labels)

    return test_loss.item(), accuracy, auc

def run_model(model, learning_rate, weight_decay, num_epochs):

  tr_data_loader, val_data_loader, x_train = get_data_sample(xs, ys)

  # Normalise last layer using training data
  if hasattr(model, 'layer_to_norm'):
      model.mu_s = get_heads_mu_and_sigma(model, x_train)
  
  optimizer = optim.Adam(
      [p for p in model.parameters() if p.requires_grad],
      lr=learning_rate,
      weight_decay=weight_decay
  )

  criterion = nn.BCELoss().to(device)
  scheduler = CosineAnnealingLR(optimizer, T_max=len(tr_data_loader))

  train_losses = []
  test_losses = []
  train_accs = []
  train_aucs = []
  test_accs = []
  test_aucs = []  

  best_test_loss = np.inf
  test_loss_up_since = 0
  early_stop = 100
  verbose = False
  for epoch in range(num_epochs):
      train_loss, train_accuracy, train_auc = train(model, device, tr_data_loader, optimizer, criterion)
      test_loss, test_accuracy, test_auc = test(model, device, val_data_loader, criterion)
      scheduler.step()
      if test_loss < best_test_loss:
          best_test_loss = test_loss
          test_loss_up_since = 0
      test_loss_up_since += 1
      if test_loss_up_since > early_stop:
          print('Epoch - ', epoch, 'Early stopping')
          break
      if verbose:
          print('Epoch - ', epoch)
          print('Train metrics: loss', train_loss, 'accuracy', train_accuracy, 'auc', train_auc)
          print('Val metrics: loss', test_loss, 'accuracy', test_accuracy, 'auc', test_auc)

      train_losses.append(train_loss)
      test_losses.append(test_loss)
      train_accs.append(train_accuracy)
      train_aucs.append(train_auc)
      test_accs.append(test_accuracy)
      test_aucs.append(test_auc)

  return {'trainLoss': train_losses, 'testLoss': test_losses, 
          'trainAccuracy': train_accs, 'testAccuracy': test_accs,
          'trainAUC': train_aucs, 'testAUC': test_aucs}

def plot_metric(results_dict, metric):
    plt.title(metric)
    plt.xlabel('Epochs')
    plt.plot(list(range(1, len(results_dict[f'train{metric}']) + 1)), results_dict[f'train{metric}'], label=f'Train {metric}')
    plt.plot(list(range(1, len(results_dict[f'train{metric}']) + 1)), results_dict[f'test{metric}'], label=f'Test {metric}')
    plt.legend()
    plt.show()

def multi_runs(model, learning_rate, weight_decay, plots=False):

    averaged_results = {}    
    for run_ix in range(PARAMS['num_run']):
 
        results = run_model(model, learning_rate=learning_rate, weight_decay=weight_decay, num_epochs=PARAMS['num_epochs'])       
        print(f'Train pass no. {run_ix+1}')

        if (run_ix == 0):
            print(model) 
            if plots:
                print('Debug charts for first training run')
                plot_metric(results, 'Loss')
                plot_metric(results, 'Accuracy')
                plot_metric(results, 'AUC')

        for res in results:
            if len(results[res]) > 0:
                if res not in averaged_results:
                    averaged_results[res] = [results[res][-1]]
                else:
                    averaged_results[res].append(results[res][-1])         
    
    return {x: sum(averaged_results[x]) / PARAMS['num_run'] for x in averaged_results}

def multi_agent_train(hparams):
    
    load_weigths = True
    if PARAMS['agent_type'] == 'ppo':
        model = PPOMaskNet(
            k=hparams['k'], 
            fine_tune=hparams['fine_tune'],
            means_stds=None,
            dropout=hparams['dropout'],
            use_last_linear=hparams['use_last_linear']).to(device)
    
    elif PARAMS['agent_type'] == 'dqn':
        model = DQNMaskNet(
            k=hparams['k'], 
            fine_tune=hparams['fine_tune'],
            q_head_index=hparams['q_head_index'],
            means_stds=None,
            use_last_linear=hparams['use_last_linear']).to(device)

    else:
        model = RandomMaskNet(
            k=hparams['k'],
            dropout=hparams['dropout'],).to(device)
        load_weigths = False
    
    if load_weigths:
       model.load_weights(original_weights)

    return multi_runs(model, learning_rate=hparams['learning_rate'], 
                      weight_decay=hparams['weight_decay'], plots=False)

In [184]:
# Set up hyperparamets for model training
PARAMS['num_run'] = 1
PARAMS['num_epochs'] = 100

all_hparam_possibilities = [
   {
    "dropout": [0.2],
    "k": [0.1],  
    "fine_tune": [False],
    "use_last_linear": [True],
    "learning_rate": [1e-4], 
    "weight_decay": [0],
    "q_head_index": [None]
   }
]

In [188]:
hparam_combinations = []
for hparam_possibilities in all_hparam_possibilities:
  hparam_keys, hparam_values = zip(*hparam_possibilities.items())
  hparam_combinations.extend([dict(zip(hparam_keys, v)) for v in itertools.product(*hparam_values)])
random.shuffle(hparam_combinations)
print('PARAMS', PARAMS)
print("len(hparam_combinations)", len(hparam_combinations), "hparam_combinations", hparam_combinations)

best_test_auc = -float('inf')
for hparams in hparam_combinations:
    print("hparams", hparams)
    results = multi_agent_train(hparams)
    print(results)
    test_auc = results['testAUC']
    if test_auc > best_test_auc:
        best_test_auc = test_auc
        best_hparams = hparams

print("Retraining on the best_hparams to make sure we didn't just get good results by random chance.")
print("best_hparams", best_hparams)
print("Result of retrain on the best hyperparameters", multi_agent_train(best_hparams))

PARAMS {'env': 'doom', 'num_train': 500, 'num_val': 500, 'batch_size': 128, 'num_run': 1, 'num_epochs': 100, 'agent_type': 'ppo', 'best_hparams': {'dropout': 0.2, 'k': 0.1, 'fine_tune': False, 'use_last_linear': True, 'learning_rate': 0.0001, 'weight_decay': 0, 'q_head_index': None}}
len(hparam_combinations) 1 hparam_combinations [{'dropout': 0.2, 'k': 0.1, 'fine_tune': False, 'use_last_linear': True, 'learning_rate': 0.0001, 'weight_decay': 0, 'q_head_index': None}]
hparams {'dropout': 0.2, 'k': 0.1, 'fine_tune': False, 'use_last_linear': True, 'learning_rate': 0.0001, 'weight_decay': 0, 'q_head_index': None}
Epoch -  0
Train metrics: loss 0.012701721861958504 accuracy 0.508 auc 0.4779825908858166
Val metrics: loss 0.005908384453505278 accuracy 0.508 auc 0.48295854955474404
Epoch -  1
Train metrics: loss 0.006377368234097958 accuracy 0.494 auc 0.48345494111623144
Val metrics: loss 0.00567703926935792 accuracy 0.484 auc 0.5202447306041387
Epoch -  2
Train metrics: loss 0.00644138967618

## Initialise Subnets Search with activation that obtained optimal AUC in previous experiment

We do this both as a sanity check as well as a potential improvement

In [None]:
from sklearn import metrics 

acts = []
prefs = []

for data in all_raw_data:
    for i in range(data.observation.shape[0]):
        acts.append(np.copy(data.policy_info["activations"][i]))
        prefs.append((data.policy_info['satisfaction'].as_list()[i] > -6).astype(int))

acts = np.array(acts)
prefs = np.array(prefs)

def display_auc_info(xs, ys):
    
    def calc_auc(xs, ys, i):
        fpr, tpr, thresholds = metrics.roc_curve(ys, xs[:,i], pos_label=1)
        return metrics.auc(fpr, tpr)


    multi_runs_aucs = []
    for run_ix in range(50):
        xs, ys = shuffle(xs, ys)
        flat_xs = np.reshape(xs, (xs.shape[0], -1))
        aucs = []    
        
        for i in range(flat_xs.shape[1]):
            auc = calc_auc(flat_xs[:params['num_train']], ys[:params['num_train']], i)
            aucs.append(auc)  

        aucs = np.array(aucs)
        multi_runs_aucs.append(aucs)

    aucs = np.array(multi_runs_aucs)
    aucs = aucs.mean(axis=0)

    print("AUC from only picking a single activation")
    print(np.argmin(aucs), "train", 1-np.min(aucs), "val", 1-calc_auc(flat_xs[params['num_train']:], ys[params['num_train']:], np.argmin(aucs)))
    print(np.argmax(aucs), "train", np.max(aucs), "val", calc_auc(flat_xs[params['num_train']:], ys[params['num_train']:], np.argmax(aucs)))
  
display_auc_info(acts, prefs)

AUC from only picking a single activation
34 train 0.8168389427248516 val 0.8203183087179845
13 train 0.6133060108723783 val 0.6179787990821131


In [None]:
## TODO: fix broken code after refactoring notebook

best_act_index = 34
K = 66774 / 67152  # Num of weigths with all dense activations except one set to 0 / Number of total weights

params['num_epochs'] = 1
results = multi_runs(fine_tune=False, K=1, q_head_index=None, q_means_stds=q_mu_s, 
                     use_last_linear=False, init_from_act_index=best_act_index,
                     learning_rate=0.000, weight_decay=0.000, plots=False)

####results is a tuple where the first element is the dict
print(1-max(results[0]['testAUC']))

Train pass no. 1
Train pass no. 2
Train pass no. 3
Train pass no. 4
Train pass no. 5
0.8254807323986824
