<a href="https://colab.research.google.com/github/arunraja-hub/Preference_Extraction/blob/fine_tune/find_subnets_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [1]:
%tensorflow_version 2.x

!git clone https://github.com/arunraja-hub/Preference_Extraction.git

Cloning into 'Preference_Extraction'...
remote: Enumerating objects: 61, done.[K
remote: Counting objects: 100% (61/61), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 669 (delta 40), reused 31 (delta 25), pack-reused 608[K
Receiving objects: 100% (669/669), 21.61 MiB | 12.62 MiB/s, done.
Resolving deltas: 100% (90/90), done.


In [0]:
from __future__ import print_function
import argparse
import os
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import label_binarize

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as TF
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd
from torchsummary import summary

from tqdm import tqdm
from sklearn.utils import shuffle
import tensorflow as tf
import concurrent.futures
import itertools
import os
import random
import sys
import time
import re
import io

import sys
sys.path.append('Preference_Extraction/utils')
import data_loading

In [0]:
# Reproducibility: make cuDNN pick deterministic kernels (and disable its
# autotuner), then seed the global PyTorch and NumPy random generators.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(0)
np.random.seed(0)

## Subnets Methods

In [0]:
"""
    Original code from What's hidden in a randomly weighted neural network? paper
    Implemented at https://github.com/allenai/hidden-networks
    The weight-initialisation code was removed since it is not relevant for us
"""

class GetSubnet(autograd.Function):
    """Binarise a score tensor into a top-k% mask with a straight-through gradient.

    forward(scores, k) returns a tensor shaped like ``scores`` holding 1 for the
    ``k`` fraction of entries with the largest scores and 0 for the rest.
    backward passes the incoming gradient to ``scores`` unchanged and returns
    no gradient for ``k``.
    """

    @staticmethod
    def forward(ctx, scores, k):
        # Rank all scores in ascending order; the first j entries are dropped.
        flat_scores = scores.flatten()
        _, idx = flat_scores.sort()
        j = int((1 - k) * scores.numel())

        # Build the mask in a fresh buffer instead of writing through
        # `scores.clone().flatten()`: flatten() only returns a view for
        # contiguous tensors, so the in-place write would be lost (and the raw
        # scores returned) for any other memory layout.
        flat_mask = torch.zeros_like(flat_scores)
        flat_mask[idx[j:]] = 1

        return flat_mask.reshape(scores.shape)

    @staticmethod
    def backward(ctx, g):
        # send the gradient g straight-through on the backward pass.
        return g, None

class SupermaskConv(nn.Conv2d):
    """Conv2d whose weights are frozen and gated by a learned top-k mask.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Conv2d``.
        k: keyword-only fraction of weights kept by the mask (passed to
            ``GetSubnet`` on every forward call).
    """

    def __init__(self, *args, k, **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k

        # One trainable score per weight entry; the mask is derived from them.
        score_buffer = torch.empty(self.weight.size())
        self.scores = nn.Parameter(score_buffer)
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        # Re-draw the weights uniformly, then freeze them. Only `weight` is
        # frozen here; `bias` is not touched.
        nn.init.uniform_(self.weight)
        self.weight.requires_grad_(False)

    def forward(self, x):
        # Mask is computed from the magnitude of the scores.
        mask = GetSubnet.apply(self.scores.abs(), self.k)
        masked_weight = self.weight * mask
        return F.conv2d(
            x,
            masked_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

class SupermaskLinear(nn.Linear):
    """Linear layer whose weights are frozen and gated by a learned top-k mask.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Linear``.
        k: keyword-only fraction of weights kept by the mask (passed to
            ``GetSubnet`` on every forward call).
    """

    def __init__(self, *args, k, **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k

        # initialize the scores (one trainable score per weight entry)
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        # initialize the weights
        nn.init.uniform_(self.weight)

        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False

    def forward(self, x):
        # The mask is computed from the magnitude of the scores and applied
        # multiplicatively to the frozen weights.
        subnet = GetSubnet.apply(self.scores.abs(), self.k)
        w = self.weight * subnet
        # The original had a second, unreachable `return x` after this line.
        return F.linear(x, w, self.bias)

# NOTE: this layer is not used in this notebook. It is batch normalization with
# the affine transform disabled, so the normalization layer has no learned
# parameters.
class NonAffineBatchNorm(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` over ``dim`` channels, constructed with ``affine=False``."""

    def __init__(self, dim):
        super().__init__(num_features=dim, affine=False)

## Define Supermask Network

In [0]:
"""
    Original code from https://github.com/allenai/hidden-networks/blob/master/simple_mnist_example.py
    Highly modified to work with PrefExtraction agent
"""

class SuperMaskNet(nn.Module):
    """Two supermask conv layers followed by two supermask linear layers.

    Layer sizes: conv 5->16 (3x3, stride 1), conv 16->32 (3x3, stride 2),
    linear 960->64, linear 64->3. Every layer shares the same keep-fraction
    ``k``. The ``fwd_*`` helpers expose the intermediate activations; each one
    runs the network from the input up to the named stage.

    Args:
        k: fraction of weights each supermask layer keeps.
    """

    def __init__(self, k):
        super(SuperMaskNet, self).__init__()
        self.conv1 = SupermaskConv(in_channels=5, out_channels=16, kernel_size=3, stride=1, bias=True, k=k)
        self.conv2 = SupermaskConv(in_channels=16, out_channels=32, kernel_size=3, stride=2, bias=True, k=k)
        self.fc1 = SupermaskLinear(in_features=960, out_features=64, bias=True, k=k)
        self.fc2 = SupermaskLinear(in_features=64, out_features=3, bias=True, k=k)  # 3 qHeads ouput

    def fwd_conv1(self, x):
        """Return the ReLU activations of the first conv layer."""
        x = self.conv1(x)
        return F.relu(x)

    def fwd_conv2(self, x):
        """Return the ReLU activations of the second conv layer."""
        x = self.fwd_conv1(x)
        x = self.conv2(x)
        return F.relu(x)

    def fwd_flat(self, x):
        """Return the conv features flattened to one vector per sample."""
        x = self.fwd_conv2(x)
        # Swap dims 1 and 3 before flattening so the feature order matches the
        # TF model's flatten layer (necessary for the TF-Torch conversion).
        return torch.flatten(torch.transpose(x, 1, 3), 1)

    def fwd_dense(self, x):
        """Return the ReLU activations of the first linear layer."""
        x = self.fwd_flat(x)
        x = self.fc1(x)
        return F.relu(x)

    def forward(self, x):
        """Return the sigmoid of the three outputs of the last linear layer."""
        x = self.fwd_dense(x)
        x = self.fc2(x)
        # torch.sigmoid replaces the deprecated F.sigmoid alias (same values).
        output = torch.sigmoid(x)
        return output

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
device

device(type='cuda')

In [7]:
# Mock supermask model with k=1 (every weight kept by the mask): the first goal
# is only to test that the Torch and TF models are equivalent.
supermask_test_model = SuperMaskNet(k=1)
supermask_test_model = supermask_test_model.to(device)
summary(supermask_test_model, input_size=(5, 14, 16))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
     SupermaskConv-1           [-1, 16, 12, 14]             736
     SupermaskConv-2             [-1, 32, 5, 6]           4,640
   SupermaskLinear-3                   [-1, 64]          61,504
   SupermaskLinear-4                    [-1, 3]             195
Total params: 67,075
Trainable params: 0
Non-trainable params: 67,075
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.03
Params size (MB): 0.26
Estimated Total Size (MB): 0.29
----------------------------------------------------------------




## Load Data

In [0]:
# `all_load_data` is never defined or imported as a bare name in this notebook;
# only the `data_loading` module is imported, so the call must go through it.
# NOTE(review): assumes `data_loading` exposes `all_load_data` -- confirm.
all_raw_data = data_loading.all_load_data("Preference_Extraction/data/simple_env_1/")

In [0]:
# Flatten every loaded trajectory into per-step samples: the observation, the
# recorded agent activations, and a binary preference label that is 1 when the
# step's satisfaction is greater than -6.
observations = []
activations = []
preferences = []

for data in all_raw_data:
    for i in range(data.observation.shape[0]):
        observations.append(np.copy(data.observation[i]))
        activations.append(np.copy(data.policy_info["activations"][i]))

        step_satisfaction = data.policy_info['satisfaction'].as_list()[i]
        preferences.append((step_satisfaction > -6).astype(int))

observations = np.array(observations)
activations = np.array(activations)
preferences = np.array(preferences)

## Loading Weights

In [22]:
# Restore the saved TF/Keras agent network and print its layer summary.
new_save_path = "Preference_Extraction/saved_model2"
restored_model = tf.keras.models.load_model(filepath=new_save_path)
restored_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
EncodingNetwork/conv2d (Conv (None, 12, 14, 16)        736       
_________________________________________________________________
EncodingNetwork/conv2d_1 (Co (None, 5, 6, 32)          4640      
_________________________________________________________________
flatten (Flatten)            (None, 960)               0         
_________________________________________________________________
EncodingNetwork/dense (Dense (None, 64)                61504     
_________________________________________________________________
dense (Dense)                (None, 3)                 195       
Total params: 67,075
Trainable params: 67,075
Non-trainable params: 0
_________________________________________________________________


In [0]:
# Weights of the restored TF model; the cells below index this list as
# [kernel, bias] pairs in layer order (conv1, conv2, fc1, fc2).
original_weights = restored_model.get_weights()

In [0]:
# Check Shapes
# For each Torch layer, the fully transposed TF kernel and the TF bias must
# have exactly the shapes of the layer's `weight` and `bias`.
_torch_layers_in_tf_order = (
    supermask_test_model.conv1,
    supermask_test_model.conv2,
    supermask_test_model.fc1,
    supermask_test_model.fc2,
)
for layer_idx, torch_layer in enumerate(_torch_layers_in_tf_order):
    tf_kernel = original_weights[2 * layer_idx]
    tf_bias = original_weights[2 * layer_idx + 1]
    assert tuple(torch_layer.weight.data.shape) == np.transpose(tf_kernel).shape
    assert tuple(torch_layer.bias.data.shape) == tf_bias.shape

In [25]:
# Load Weights
# Copy each TF [kernel, bias] pair into the matching Torch layer; kernels are
# fully transposed (all axes reversed) to fit the Torch weight shapes.
for layer_idx, torch_layer in enumerate((
    supermask_test_model.conv1,
    supermask_test_model.conv2,
    supermask_test_model.fc1,
    supermask_test_model.fc2,
)):
    tf_kernel = original_weights[2 * layer_idx]
    tf_bias = original_weights[2 * layer_idx + 1]
    torch_layer.weight.data = torch.from_numpy(np.transpose(tf_kernel))
    torch_layer.bias.data = torch.from_numpy(tf_bias)

# The tensors created above come from NumPy; move the model to `device` again.
supermask_test_model.to(device)

SuperMaskNet(
  (conv1): SupermaskConv(5, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv2): SupermaskConv(16, 32, kernel_size=(3, 3), stride=(2, 2))
  (fc1): SupermaskLinear(in_features=960, out_features=64, bias=True)
  (fc2): SupermaskLinear(in_features=64, out_features=3, bias=True)
)

In [0]:
# Check that the Torch port and the restored TF model produce matching
# activations, layer by layer, for identical observations.
# One Keras sub-model per stage, each running from the model input to the
# output of layers[0..3] (conv2d, conv2d_1, flatten, dense in the summary above).
tf_conv1_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[0].output)
tf_conv2_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[1].output)
tf_flt_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[2].output)
tf_fc1_fn = tf.keras.models.Model(inputs=restored_model.input, outputs=restored_model.layers[3].output)

# Only the observations of the first loaded trajectory are checked.
for i in range(len(all_raw_data[0].observation)):

    # TF input: the observation as stored, wrapped in a batch of one.
    single_observation = np.array([all_raw_data[0].observation[i]])
    # Torch input: the same observation with all of its axes reversed
    # (np.transpose without axes), also wrapped in a batch of one.
    single_observation_torch = torch.Tensor(np.array([np.transpose(all_raw_data[0].observation[i])]))
    single_observation_torch = single_observation_torch.to(device)
    
    # For every stage: reverse the axes of the Torch output (which moves the
    # size-1 batch axis last), drop that trailing axis with the reshape, and
    # compare against the first (only) item of the TF output.
    conv1_torch_out = np.transpose(supermask_test_model.fwd_conv1(single_observation_torch).detach().cpu().numpy())
    conv1_torch_out = conv1_torch_out.reshape(conv1_torch_out.shape[:-1])
    conv1_tf_out = tf_conv1_fn(single_observation)[0].numpy()
    np.testing.assert_allclose(conv1_torch_out, conv1_tf_out, rtol=.1)

    conv2_torch_out = np.transpose(supermask_test_model.fwd_conv2(single_observation_torch).detach().cpu().numpy())
    conv2_torch_out = conv2_torch_out.reshape(conv2_torch_out.shape[:-1])
    conv2_tf_out = tf_conv2_fn(single_observation)[0].numpy()
    np.testing.assert_allclose(conv2_torch_out, conv2_tf_out, rtol=.1)

    flt_torch_out = np.transpose(supermask_test_model.fwd_flat(single_observation_torch).detach().cpu().numpy())
    flt_torch_out = flt_torch_out.reshape(flt_torch_out.shape[:-1])
    tf_flt_out = tf_flt_fn(single_observation)[0].numpy()
    np.testing.assert_allclose(flt_torch_out, tf_flt_out, rtol=.1)

    fc1_torch_out = np.transpose(supermask_test_model.fwd_dense(single_observation_torch).detach().cpu().numpy())
    fc1_torch_out = fc1_torch_out.reshape(fc1_torch_out.shape[:-1])
    fc1_tf_out = tf_fc1_fn(single_observation)[0].numpy()
    
    # Activations stored alongside the observation in the loaded data.
    old_activations = all_raw_data[0].policy_info["activations"][i]

    # Three-way agreement: Torch vs TF, Torch vs stored, stored vs TF.
    # NOTE(review): rtol=.1 is a loose (10%) relative tolerance.
    np.testing.assert_allclose(fc1_torch_out, fc1_tf_out, rtol=.1)
    np.testing.assert_allclose(fc1_torch_out, old_activations, rtol=.1)
    np.testing.assert_allclose(old_activations, fc1_tf_out, rtol=.1)

## Training Supermask (incomplete)

In [0]:
# Creating and loading weights
spmsk_model = SuperMaskNet(k=1).to(device)

# Walk the Torch layers together with the TF [kernel, bias] pairs (even
# indices are kernels, odd indices are biases); kernels are fully transposed.
_spmsk_layers = (spmsk_model.conv1, spmsk_model.conv2, spmsk_model.fc1, spmsk_model.fc2)
_tf_kernel_bias_pairs = zip(original_weights[0::2], original_weights[1::2])
for torch_layer, (tf_kernel, tf_bias) in zip(_spmsk_layers, _tf_kernel_bias_pairs):
    torch_layer.weight.data = torch.from_numpy(np.transpose(tf_kernel))
    torch_layer.bias.data = torch.from_numpy(tf_bias)

# The tensors created above come from NumPy; move the model to `device` again.
spmsk_model.to(device)

SuperMaskNet(
  (conv1): SupermaskConv(5, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv2): SupermaskConv(16, 32, kernel_size=(3, 3), stride=(2, 2))
  (fc1): SupermaskLinear(in_features=960, out_features=64, bias=True)
  (fc2): SupermaskLinear(in_features=64, out_features=3, bias=True)
)

In [0]:
# Create dataset iterators
num_train = 50
num_val = 400
batch_size = 50
val_batch_size = 50

# Torch wants channel-first, but the two spatial axes must be swapped as well.
# The TF weights are loaded into the Torch layers with a full np.transpose and
# the equivalence check feeds each observation as np.transpose(observation),
# i.e. with all axes reversed. Moving only the channel axis (the previous
# np.rollaxis(observations, 3, 1)) still fits shape-wise but feeds the network
# spatially transposed inputs relative to the loaded weights.
xs = np.transpose(observations, (0, 3, 2, 1))
ys = preferences
xs, ys = shuffle(xs, ys)

# First num_train shuffled samples train; the next num_val validate.
xs_tr = xs[:num_train]
ys_tr = ys[:num_train]
xs_val = xs[num_train:num_train+num_val]
ys_val = ys[num_train:num_train+num_val]

tr_data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.Tensor(xs_tr), torch.Tensor(ys_tr)),
    batch_size=batch_size)

val_data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.Tensor(xs_val), torch.Tensor(ys_val)),
    batch_size=val_batch_size)

In [0]:
"""
    Train/Test function for Randomly Weighted Hidden Neural Networks Techniques
    Adapted from https://github.com/NesterukSergey/hidden-networks/blob/master/demos/mnist.ipynb
"""

def train(model, device, train_loader, optimizer, criterion, epoch, verbose=False):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: module to train; put into train mode here.
        device: device each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches with a ``dataset``.
        optimizer: optimizer stepped once per batch.
        criterion: callable ``criterion(output, target)`` returning the loss.
        epoch: epoch number, used only in the progress message.
        verbose: when true, print the loss of every fifth batch.

    Returns:
        The sum of the per-batch losses divided by ``len(train_loader.dataset)``
        (note: dataset size, not number of batches), detached from the graph.
    """
    train_loss = 0

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate a detached value: summing the live `loss` tensors would
        # chain every batch's autograd graph into the running total.
        train_loss += loss.detach()

        if verbose and batch_idx % 5 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    return train_loss / len(train_loader.dataset)


def test(model, device, criterion, test_loader):
    true_labels = []
    predictions = [] # labels
    outputs = [] # probabilities

    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)
            test_loss += criterion(output, target)
            outputs.append(output.detach().cpu().numpy())
            pred = output > 0.5

            predictions.extend(pred)
            true_labels.extend(target.detach().cpu().numpy())

    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    outputs = np.array(outputs)

    test_loss /= len(test_loader.dataset)
    test_accuracy = np.sum(np.equal(predictions, true_labels)) / len(true_labels)
    
    return test_loss.item(), test_accuracy

def run_model(k, model, num_epochs):
  """Train ``model`` for ``num_epochs`` epochs and print the final test metrics.

  Uses the notebook-level ``tr_data_loader``, ``val_data_loader`` and
  ``device``. Optimises with SGD under a cosine-annealed learning rate and a
  binary cross-entropy loss.

  Args:
      k: value reported in the final printout (it is not applied to ``model``
          here).
      model: module to train.
      num_epochs: number of train/test epochs to run.
  """
  # NOTE: only pass the parameters where p.requires_grad == True to the optimizer! Important!
  optimizer = optim.SGD(
      [p for p in model.parameters() if p.requires_grad],
      lr=0.1,
      momentum=0.9,
      weight_decay=0.0005,
  )

  criterion = nn.BCELoss().to(device)
  scheduler = CosineAnnealingLR(optimizer, T_max=14)

  train_losses = []
  test_losses = []
  test_acc = []

  for epoch in tqdm(range(num_epochs)):
      train_loss = train(model, device, tr_data_loader, optimizer, criterion, epoch, verbose=False)
      test_loss, test_accuracy = test(model, device, criterion, val_data_loader)
      scheduler.step()

      train_losses.append(train_loss)
      test_losses.append(test_loss)
      test_acc.append(test_accuracy)

  # The original also printed an `init` value that is not defined anywhere in
  # the notebook, which raised NameError once training had finished.
  print('k: ', k)
  print('Test loss: ', test_losses[-1])
  print('Test accuracy: ', test_acc[-1])

# run_model(0.3, spmsk_model, num_epochs=400)