# Everything

## Imports

In [58]:
%tensorflow_version 2.x

!git clone https://github.com/arunraja-hub/Preference_Extraction.git

fatal: destination path 'Preference_Extraction' already exists and is not an empty directory.


In [0]:
from __future__ import print_function
import argparse
import os
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import label_binarize

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as TF
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
import torch.autograd as autograd
from torchsummary import summary

from tqdm import tqdm
from sklearn.utils import shuffle
from sklearn.metrics import roc_auc_score
import tensorflow as tf
import concurrent.futures
import itertools
import os
import random
import sys
import time
import re
import io
import itertools
import sys

sys.path.append('Preference_Extraction')
from imports_data import all_load_data

In [0]:
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [61]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
device

device(type='cuda')

## Parameters

In [0]:
params = {
    'num_train': 50,
    'num_val': 400,
    'batch_size': 10,
    'val_batch_size': 10,
    'use_qnet_weights': True, # Flag for running models that use the weights of Qnet vs models that use random weights
    'use_mnist': True  # Flag for running models on MNIST. If False uses RL Preference Extraction data
}

## Subnets Methods

In [0]:
"""
    Original code from What's hidden in a randomly weighted neural network? paper
    Implemented at https://github.com/allenai/hidden-networks
    Remove weigths-initialisation since it is not relevant for us
"""

class GetSubnet(autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Get the supermask by sorting the scores and using the top k%
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())

        # flat_out and out access the same memory.
        flat_out = out.flatten()
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1

        return out

    @staticmethod
    def backward(ctx, g):
        # send the gradient g straight-through on the backward pass.
        return g, None

class SupermaskConv(nn.Conv2d):
    def __init__(self, *args, k, **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k

        # initialize the scores
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        # initialize the weights
        nn.init.uniform_(self.weight)
        
        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False

    def forward(self, x):
        subnet = GetSubnet.apply(self.scores.abs(), self.k)
        w = self.weight * subnet
        x = F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        return x

class SupermaskLinear(nn.Linear):
    def __init__(self, *args, k, **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k

        # initialize the scores and weights
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))
        nn.init.uniform_(self.weight)

        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False

    def forward(self, x):
        subnet = GetSubnet.apply(self.scores.abs(), self.k)
        w = self.weight * subnet
        return F.linear(x, w, self.bias)
        return x

# NOTE: not used here but we use NON-AFFINE Normalization!
# So there is no learned parameters for your nomralization layer.
class NonAffineBatchNorm(nn.BatchNorm2d):
    def __init__(self, dim):
        super(NonAffineBatchNorm, self).__init__(dim, affine=False)

## Load Data

In [0]:
# Run this cell for training with original RL Preference Extraction data
if not params['use_mnist']:
    all_raw_data = all_load_data("Preference_Extraction/data/simple_env_1/")

    activations = []
    observations = []
    preferences = []

    for data in all_raw_data:
        for i in range(data.observation.shape[0]):
            observations.append(np.copy(data.observation[i]))
            activations.append(np.copy(data.policy_info["activations"][i]))
            preferences.append((data.policy_info['satisfaction'].as_list()[i] > -6).astype(int))

    activations = np.array(activations)

    xs = np.rollaxis(np.array(observations), 3, 1) # Torch wants channel-first
    ys = np.array(preferences)
    xs, ys = shuffle(xs, ys)

    xs_tr = xs[:params['num_train']]
    ys_tr = ys[:params['num_train']]
    xs_val = xs[params['num_train']:params['num_train']+params['num_val']]
    ys_val = ys[params['num_train']:params['num_train']+params['num_val']]

    tr_data_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.Tensor(xs_tr), torch.Tensor(ys_tr)),
        batch_size=params['batch_size'])

    val_data_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.Tensor(xs_val), torch.Tensor(ys_val)),
        batch_size=params['val_batch_size'])

In [0]:
# Run this cell for training with MNIST data
if params['use_mnist']:
    tr_data_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=params['batch_size'], shuffle=True)

    val_data_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist', train=False, download=True, 
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])), batch_size=params['val_batch_size'], shuffle=True)

## Define architecture

In [0]:
class SuperMaskQNets(nn.Module):
    """
      If q_head_index is None, this uses a linear model on the normalized q outputs.
      Otherwise, it gets the Q head with the specified index.
    """ 
    def __init__(self, k, q_head_index, q_means_stds):
        super(SuperMaskQNets, self).__init__()
        
        if not params['use_mnist']:
            channels_in = 5
            flattened_shape = 960
        else:
            channels_in = 1
            flattened_shape = 4608

        self.conv1 = SupermaskConv(in_channels=channels_in, out_channels=16, kernel_size=3, stride=1, bias=True, k=k)
        self.conv2 = SupermaskConv(in_channels=16, out_channels=32, kernel_size=3, stride=2, bias=True, k=k)
        self.fc1 = SupermaskLinear(in_features=flattened_shape, out_features=64, bias=True, k=k)
        self.fc2 = SupermaskLinear(in_features=64, out_features=3, bias=True, k=k)
        self.fc3 = SupermaskLinear(in_features=3, out_features=1, bias=True, k=k)

        self.qix = q_head_index
        self.qu_mu_s = q_means_stds

    def fwd_conv1(self, x):
        x = self.conv1(x)
        return F.relu(x)

    def fwd_conv2(self, x):
        x = self.fwd_conv1(x)
        x = self.conv2(x)
        return F.relu(x)

    def fwd_flat(self, x):
        x = self.fwd_conv2(x)
        return torch.flatten(torch.transpose(x, 1, 3), 1) # Pre-flattening transpose is necessary for TF-Torch conversion

    def fwd_fc1(self, x):
        x = self.fwd_flat(x)
        x = self.fc1(x)
        return F.relu(x)
    
    def fwd_fc2(self, x):
        x = self.fwd_fc1(x)
        return self.fc2(x)

    def forward(self, x):
        x = self.fwd_fc2(x)

        x -= torch.tensor(self.qu_mu_s[0], device=device)
        x /= torch.tensor(self.qu_mu_s[1], device=device)

        if self.qix == None:
          x = self.fc3(x).flatten()
        else:
          x = x[: ,self.qix]
        
        x = torch.sigmoid(x)
        return x

## Loading Weights

In [67]:
new_save_path = "Preference_Extraction/saved_model2"
restored_model = tf.keras.models.load_model(new_save_path)
restored_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
EncodingNetwork/conv2d (Conv (None, 12, 14, 16)        736       
_________________________________________________________________
EncodingNetwork/conv2d_1 (Co (None, 5, 6, 32)          4640      
_________________________________________________________________
flatten (Flatten)            (None, 960)               0         
_________________________________________________________________
EncodingNetwork/dense (Dense (None, 64)                61504     
_________________________________________________________________
dense (Dense)                (None, 3)                 195       
Total params: 67,075
Trainable params: 67,075
Non-trainable params: 0
_________________________________________________________________


In [0]:
original_weights=restored_model.get_weights()

In [0]:
def load_weights(model):
    if not params['use_mnist']:
        model.conv1.weight.data = torch.from_numpy(np.transpose(original_weights[0]))
        model.fc1.weight.data = torch.from_numpy(np.transpose(original_weights[4]))
    else:
        model.conv1.weight.data = torch.from_numpy(np.transpose(original_weights[0][:,:,:1,:]))
        mnist_flt_weights = np.random.rand(64, 4608)
        mnist_flt_weights[:, :original_weights[4].shape[0]] = np.transpose(original_weights[4])
        mnist_flt_weights = mnist_flt_weights.astype(np.float32)
        model.fc1.weight.data = torch.from_numpy(mnist_flt_weights)

    model.conv1.bias.data = torch.from_numpy(original_weights[1])
    model.conv2.weight.data = torch.from_numpy(np.transpose(original_weights[2]))
    model.conv2.bias.data = torch.from_numpy(original_weights[3])
    model.fc1.bias.data = torch.from_numpy(original_weights[5])
    model.fc2.weight.data = torch.from_numpy(np.transpose(original_weights[6]))
    model.fc2.bias.data = torch.from_numpy(original_weights[7])
    model.fc3.weight.data = torch.from_numpy(np.ones(shape=[1,3], dtype=np.float32))
    model.fc3.bias.data = torch.from_numpy(np.zeros(shape=[1], dtype=np.float32))
    model.to(device)

In [0]:
supermask_test_model = SuperMaskQNets(k=1, q_head_index=None, q_means_stds=[[0, 0, 0], [1, 1, 1]]).to(device)

if params['use_qnet_weights']:
    load_weights(supermask_test_model)

## Test the weights loaded properly

In [0]:
# Comparing that the models have identical observations for identical images
tf_conv1_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[0].output)
tf_conv2_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[1].output)
tf_flt_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[2].output)
tf_fc1_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[3].output)

def npsigmoid(x):
  return 1/(1 + np.exp(-x)) 

def check_same(torch_layer, tf_layer):
    torch_out = np.transpose(torch_layer(single_observation_torch).detach().cpu().numpy())
    torch_out = torch_out.reshape(torch_out.shape[:-1])
    tf_out = tf_layer(single_observation)[0].numpy()
    np.testing.assert_allclose(torch_out, tf_out, rtol=.1)  

# due to shape of original TF model this test can be done only when use_mnist = False
if not params['use_mnist'] and params['use_qnet_weights']:
    for i in range(len(all_raw_data[0].observation)):

        single_observation = np.array([all_raw_data[0].observation[i]])
        single_observation_torch = torch.Tensor(np.array([np.transpose(all_raw_data[0].observation[i])]))

        single_observation_torch = single_observation_torch.to(device)

        check_same(supermask_test_model.fwd_conv1, tf_conv1_fn)
        check_same(supermask_test_model.fwd_conv2, tf_conv2_fn)
        check_same(supermask_test_model.fwd_flat, tf_flt_fn)
        check_same(supermask_test_model.fwd_flat, tf_flt_fn)

        fc1_torch_out = np.transpose(supermask_test_model.fwd_fc1(single_observation_torch).detach().cpu().numpy())
        fc1_torch_out = fc1_torch_out.reshape(fc1_torch_out.shape[:-1])
        fc1_tf_out = tf_fc1_fn(single_observation)[0].numpy()
        
        np.testing.assert_allclose(fc1_torch_out, fc1_tf_out, rtol=.1)
        old_activations = all_raw_data[0].policy_info["activations"][i]
        np.testing.assert_allclose(fc1_torch_out, old_activations, rtol=.1)
        np.testing.assert_allclose(old_activations, fc1_tf_out, rtol=.1)

        check_same(supermask_test_model.fwd_fc2, restored_model)

        torch_out = np.transpose(supermask_test_model.forward(single_observation_torch).detach().cpu().numpy())
        torch_out = torch_out.reshape(torch_out.shape[:-1])
        tf_out = npsigmoid(np.sum(restored_model(single_observation)[0].numpy()))
        np.testing.assert_allclose(torch_out, tf_out, rtol=.1)  

## Create models for each q net head. And load weights

In [72]:
def get_q_heads_mu_and_sigma(model, all_obs, num_obs):
    
    model.eval()

    all_obs = shuffle(all_obs)
    obs_to_pass = all_obs[:num_obs]

    obs_tensor = torch.Tensor(obs_to_pass)
    obs_tensor = obs_tensor.to(device)
    qheads_values = model.fwd_fc2(obs_tensor).detach().cpu().numpy()

    mu = qheads_values.mean(axis=0)
    s = qheads_values.std(axis=0)

    print("mu", mu, "s", s)
    
    return np.array([mu, s])

if params['use_mnist']:
    img_batch, label = iter(tr_data_loader).next()
    xs = img_batch

q_mu_s = get_q_heads_mu_and_sigma(supermask_test_model, xs, 10000)

mu [21887.83 26996.55 39035.56] s [2435.595  3031.5972 4086.1343]


In [0]:
K = 0.5

spmsk_model_q_all = SuperMaskQNets(k=K, q_head_index=None, q_means_stds=q_mu_s).to(device)
spmsk_model_q0 = SuperMaskQNets(k=K, q_head_index=0, q_means_stds=q_mu_s).to(device)
spmsk_model_q1 = SuperMaskQNets(k=K, q_head_index=1, q_means_stds=q_mu_s).to(device)
spmsk_model_q2 = SuperMaskQNets(k=K, q_head_index=2, q_means_stds=q_mu_s).to(device)

if params['use_qnet_weights']:
    load_weights(spmsk_model_q_all)
    load_weights(spmsk_model_q0)
    load_weights(spmsk_model_q1)
    load_weights(spmsk_model_q1)

## Train models

In [74]:
"""
    Train/Test function for Randomly Weighted Hidden Neural Networks Techniques
    Adapted from https://github.com/NesterukSergey/hidden-networks/blob/master/demos/mnist.ipynb
"""

def train(model, device, train_loader, optimizer, criterion):
    
    train_loss = 0
    true_labels = []
    predictions = [] # labels
    outputs = [] # probabilities

    model.train()
    for data, target in itertools.islice(train_loader, params['num_train']):
        data, target = data.to(device), target.to(device)
        if params['use_mnist']:
            target = (target > 0).float()
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss
        output_value = output.detach().cpu().numpy()
        outputs.append(output)
        pred = (output_value > 0.5).astype(float)
        predictions.extend(pred)
        true_labels.extend(target.detach().cpu().numpy())

    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    outputs = np.array(outputs)
    train_loss /= len(train_loader.dataset)
    accuracy = np.sum(np.equal(predictions, true_labels)) / len(true_labels)
    auc = roc_auc_score(true_labels, predictions)

    return train_loss.item(), accuracy, auc


def test(model, device, criterion, test_loader):
    true_labels = []
    predictions = [] # labels
    outputs = [] # probabilities

    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, target in itertools.islice(test_loader, params['num_val']):
            data, target = data.to(device), target.to(device)
            if params['use_mnist']:
                target = (target > 0).float()
            output = model(data)
            test_loss += criterion(output, target)

            output_value = output.detach().cpu().numpy()
            outputs.append(output)
            pred = (output_value > 0.5).astype(float)
            predictions.extend(pred)
            true_labels.extend(target.detach().cpu().numpy())

    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    outputs = np.array(outputs)
    test_loss /= len(test_loader.dataset)
    accuracy = np.sum(np.equal(predictions, true_labels)) / len(true_labels)
    auc = roc_auc_score(true_labels, predictions)

    return test_loss.item(), accuracy, auc

def run_model(model, num_epochs, verbose=False):
  # NOTE: only pass the parameters where p.requires_grad == True to the optimizer! Important!
  optimizer = optim.SGD(
      [p for p in model.parameters() if p.requires_grad],
      lr=0.1,
      momentum=0.9,
      weight_decay=0.0005,
  )

  criterion = nn.BCELoss().to(device)
  scheduler = CosineAnnealingLR(optimizer, T_max=14)

  train_accs = []
  train_aucs = []
  test_accs = []
  test_aucs = []

  for epoch in tqdm(range(num_epochs)):
      train_loss, train_accuracy, train_auc = train(model, device, tr_data_loader, optimizer, criterion)
      test_loss, test_accuracy, test_auc = test(model, device, criterion, val_data_loader)
      if verbose:
        print(f'Epoch {epoch}: train loss - {train_loss} / test loss {test_loss}')
      scheduler.step()

      train_accs.append(train_accuracy)
      train_aucs.append(train_auc)
      test_accs.append(test_accuracy)
      test_aucs.append(test_auc)


  print('Hyperparameters', params)

  print('Train accuracy: ', train_accs[-1])
  print('Test accuracy: ', test_accs[-1])

  print('Train AUC: ', train_aucs[-1])  
  print('Test AUC: ', test_aucs[-1])

num_epochs = 100

run_model(spmsk_model_q_all, num_epochs=num_epochs)
print()
run_model(spmsk_model_q0, num_epochs=num_epochs)
print()
run_model(spmsk_model_q1, num_epochs=num_epochs)
print()
run_model(spmsk_model_q2, num_epochs=num_epochs)

100%|██████████| 100/100 [03:37<00:00,  2.18s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Hyperparameters {'num_train': 50, 'num_val': 400, 'batch_size': 10, 'val_batch_size': 10, 'use_qnet_weights': True, 'use_mnist': True}
Train accuracy:  0.994
Test accuracy:  0.993
Train AUC:  0.9693877551020409
Test AUC:  0.978999648244108



100%|██████████| 100/100 [03:24<00:00,  2.05s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Hyperparameters {'num_train': 50, 'num_val': 400, 'batch_size': 10, 'val_batch_size': 10, 'use_qnet_weights': True, 'use_mnist': True}
Train accuracy:  0.992
Test accuracy:  0.99325
Train AUC:  0.9545454545454545
Test AUC:  0.9770928027230549



100%|██████████| 100/100 [03:24<00:00,  2.05s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Hyperparameters {'num_train': 50, 'num_val': 400, 'batch_size': 10, 'val_batch_size': 10, 'use_qnet_weights': True, 'use_mnist': True}
Train accuracy:  0.996
Test accuracy:  0.9935
Train AUC:  0.9877899877899878
Test AUC:  0.9830924217658132



100%|██████████| 100/100 [03:24<00:00,  2.05s/it]

Hyperparameters {'num_train': 50, 'num_val': 400, 'batch_size': 10, 'val_batch_size': 10, 'use_qnet_weights': True, 'use_mnist': True}
Train accuracy:  0.892
Test accuracy:  0.901
Train AUC:  0.5
Test AUC:  0.5



