# NetCLR Fine-tuning 

In this notebook, we fine-tune the pre-trained base model of NetCLR in a open world scenario. 

We evaluate NetCLR using two datasets: AWF and Drift datasets. 

N defines the number of labeled samples that we use for fine-tuning.  

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import warnings
warnings.filterwarnings('ignore')
import numpy as np

from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler, SequentialSampler
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
# from torchvision import datasets, transforms
import tqdm

import pickle
import argparse
from torch.cuda.amp import GradScaler, autocast

import random
import sys
import os
import collections

## GPU Allocation

In [2]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu", 0)
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}
print (f'Device: {device}')

Device: cuda:0


## Parameters

In [3]:
batch_size = 32

## Helper Functions 

Helper functions to sample labeled traces randomly from both closed world and open world datasets.

In [4]:
def sample_traces(x, y, N):
    train_index = []
    
    for c in range(num_classes-1):
        idx = np.where(y == c)[0]
        idx = np.random.choice(idx, min(N, len(idx)), False)
        train_index.extend(idx)
        
    train_index = np.array(train_index)
    np.random.shuffle(train_index)
    
    x_train = x[train_index]
    y_train = y[train_index]
    
    return x_train, y_train

def sample_ow_traces(x, N, num_classes):
    idx = np.random.randint(0, len(x), size=N*num_classes)
    
    return x[idx]## Helper Functions 



## Loading Closed World Data

In [5]:
DATASET = 'Drift'

if DATASET == 'AWF':
    data_path = '/path/to/AWF/closed-world-data/' # AWF-Attack
    data = pickle.load(open(f'{data_path}', 'rb'))
elif DATASET == 'Drift':
    data_path = '/path/to/Drift/closed-world-data/'
    data = pickle.load(open(f'{data_path}', 'rb')) # Drift90

x_cw_train_total = data['x_train']
y_cw_train_total = data['y_train'] 
x_cw_test_sup = data['x_test_fast']
y_cw_test_sup = data['y_test_fast']
x_cw_test_inf = data ['x_test_slow']
y_cw_test_inf = data['y_test_slow']

In [6]:
print (f'Closed world dataset shapes: {x_cw_train_total.shape}, {x_cw_test_sup.shape}, {x_cw_test_inf.shape}')

Closed world dataset shapes: (9241, 5000), (1860, 5000), (1860, 5000)


In [7]:
num_classes = len(np.unique(y_cw_train_total)) + 1 # adding 1 for the open-world
print (f"Number of Classes: {num_classes}")

Number of Classes: 94


## Loading Open World Data

In [8]:
if DATASET == 'AWF':
    data = np.load('/path/to/AWF/open-world-data/') # AWF-OW
if DATASET == 'Drift':
    data = np.load('/path/to/Drift/open-world-data/') # Drift5000

x_ow_train = data['superior_train']
x_ow_test_sup = data['inferior_test']
x_ow_test_inf = data['inferior_test']

In [9]:
print (f'Open world dataset shapes: {x_ow_train.shape}, {x_ow_test_sup.shape}, {x_ow_test_inf.shape}')

Open world dataset shapes: (1000, 5000), (4051, 5000), (4051, 5000)


## Combine CW and OW test data

In [10]:
y_cw_test_sup += 1
y_cw_test_inf += 1

x_test_sup = np.vstack((x_cw_test_sup, x_ow_test_sup))
x_test_inf = np.vstack((x_cw_test_inf, x_ow_test_inf))

y_ow_test = np.zeros((len(x_ow_test_inf), ))

y_test_sup = np.hstack((y_cw_test_sup, y_ow_test))
y_test_inf = np.hstack((y_cw_test_inf, y_ow_test))

x_test_inf.shape, y_test_inf.shape, x_test_sup.shape, y_test_sup.shape

((5911, 5000), (5911,), (5911, 5000), (5911,))

## Backbone Model

In [11]:
class DFNet(nn.Module):
    def __init__(self, out_dim):
        super(DFNet, self).__init__()
        kernel_size = 8
        channels = [1, 32, 64, 128, 256]
        conv_stride = 1
        pool_stride = 4
        pool_size = 8
        
        self.conv1 = nn.Conv1d(1, 32, kernel_size, stride = conv_stride)
        self.conv1_1 = nn.Conv1d(32, 32, kernel_size, stride = conv_stride)
        
        self.conv2 = nn.Conv1d(32, 64, kernel_size, stride = conv_stride)
        self.conv2_2 = nn.Conv1d(64, 64, kernel_size, stride = conv_stride)
       
        self.conv3 = nn.Conv1d(64, 128, kernel_size, stride = conv_stride)
        self.conv3_3 = nn.Conv1d(128, 128, kernel_size, stride = conv_stride)
       
        self.conv4 = nn.Conv1d(128, 256, kernel_size, stride = conv_stride)
        self.conv4_4 = nn.Conv1d(256, 256, kernel_size, stride = conv_stride)
       
        
        self.batch_norm1 = nn.BatchNorm1d(32)
        self.batch_norm2 = nn.BatchNorm1d(64)
        self.batch_norm3 = nn.BatchNorm1d(128)
        self.batch_norm4 = nn.BatchNorm1d(256)
        
        self.max_pool_1 = nn.MaxPool1d(kernel_size=pool_size, stride=pool_stride)
        self.max_pool_2 = nn.MaxPool1d(kernel_size=pool_size, stride=pool_stride)
        self.max_pool_3 = nn.MaxPool1d(kernel_size=pool_size, stride=pool_stride)
        self.max_pool_4 = nn.MaxPool1d(kernel_size=pool_size, stride=pool_stride)
        
        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)
        self.dropout3 = nn.Dropout(p=0.1)
        self.dropout4 = nn.Dropout(p=0.1)
        

        
        self.fc = nn.Linear(5120, out_dim)

        
    def weight_init(self):
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
                print (n)
                torch.nn.init.xavier_uniform(m.weight)
                m.bias.data.zero_()
            
        
    def forward(self, inp):
        x = inp
        # ==== first block ====
        x = F.pad(x, (3,4))
        x = F.elu((self.conv1(x)))
        x = F.pad(x, (3,4))
        x = F.elu(self.batch_norm1(self.conv1_1(x)))
        x = F.pad(x, (3, 4))
        x = self.max_pool_1(x)
        x = self.dropout1(x)
        
        # ==== second block ====
        x = F.pad(x, (3,4))
        x = F.relu((self.conv2(x)))
        x = F.pad(x, (3,4))
        x = F.relu(self.batch_norm2(self.conv2_2(x)))
        x = F.pad(x, (3,4))
        x = self.max_pool_2(x)
        x = self.dropout2(x)
        
        # ==== third block ====
        x = F.pad(x, (3,4))
        x = F.relu((self.conv3(x)))
        x = F.pad(x, (3,4))
        x = F.relu(self.batch_norm3(self.conv3_3(x)))
        x = F.pad(x, (3,4))
        x = self.max_pool_3(x)
        x = self.dropout3(x)
        
        # ==== fourth block ====
        x = F.pad(x, (3,4))
        x = F.relu((self.conv4(x)))
        x = F.pad(x, (3,4))
        x = F.relu(self.batch_norm4(self.conv4_4(x)))
        x = F.pad(x, (3,4))
        x = self.max_pool_4(x)
        x = self.dropout4(x)

                
        x = x.view(x.size(0), -1)
        


        x = self.fc(x)
                
        return x    
        

In [12]:
class DFsimCLR(nn.Module):
    def __init__(self, df, out_dim):
        super(DFsimCLR, self).__init__()
        
        self.backbone = df
        self.backbone.weight_init()
        dim_mlp = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(dim_mlp, dim_mlp),
            nn.BatchNorm1d(dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, out_dim)
        )
        
    def forward(self, inp):
        out = self.backbone(inp)
        return out

## Data Loader

In [13]:
class Data(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        
    def __getitem__(self, index):
        return self.x[index], self.y[index]
    
    def __len__(self):
        return len(self.x)

## Loading the Pre-trained Model

In [14]:
def load_checkpoint():

    model = DFNet(out_dim=num_classes).to(device)

    checkpoint = torch.load('/path/to/pre-trained/model/')


    for k in list(checkpoint.keys()):
        if k.startswith('backbone.'):
            if k.startswith('backbone') and not k.startswith('backbone.fc'):
          # remove prefix
                checkpoint[k[len("backbone."):]] = checkpoint[k]
        del checkpoint[k]

    log = model.load_state_dict(checkpoint, strict=False)
    assert log.missing_keys == ['fc.weight', 'fc.bias']
    
    return model

## Functions for Train and Test

In [15]:
def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(data.size(0), 1, data.size(1)).float().to(device)
        target = target.type(torch.LongTensor)
        target = target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        # print (output.size())
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx%100 == 0:
            print ("Loss: {:0.6f}".format(loss.item()))
    
def test(model, device, loader):
    model.eval()
    correct = 0
    temp = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.view(data.size(0), 1, data.size(1)).float().to(device)
            target = target.to(device)
            output = model(data)
            output = torch.softmax(output, dim=1)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).float().sum().item()
    return correct / len(loader.dataset)

## Function for Open World Evaluation

This function calculates the precision and recall of the WF classifier.

In [16]:
def test_ow(model, device, loader, threshold, num_ow_test_samples):
    
    ow_label = 0
    TP, FP, TN, FN, total = 0, 0, 0, 0, 0
    model.eval()
    with torch.no_grad():
        for data, target, in loader:
            data = data.view(data.size(0), 1, data.size(1)).float().to(device)
            target = target.detach().numpy() 
            
            
            total += len(data)
            
            output = model(data)
            output = torch.softmax(output, dim=1)
            output = output.cpu().detach().numpy()           
            for pred, label in zip(output, target):
                best_n = np.argmax(pred)
                
                # monitored websites
                if int(label) != ow_label:
                    if int(best_n) != ow_label:
                        if pred[best_n] >= threshold:
                            TP += 1
                        else:
                            FN += 1
                            
                    else:
                        FN += 1
                    
                    
                elif int(label) == ow_label:
                    
                    if int(best_n) != ow_label:
                        if pred[best_n] >= threshold:
                            FP += 1
                        else:
                            TN += 1
                    else:
                        TN += 1
                        
                        
        return TP / (TP + FP), TP / (total - num_ow_test_samples)

## Initiating Test Data Loaders

In [17]:
test_dataset_inf = Data(x_test_inf, y_test_inf)
test_loader_inf = DataLoader(test_dataset_inf, batch_size=batch_size, drop_last=True)

test_dataset_sup = Data(x_test_sup, y_test_sup)
test_loader_sup = DataLoader(test_dataset_sup, batch_size=batch_size, drop_last=True)

## Fine-tuning

In [18]:
N = 5

In [19]:
x_cw_train, y_cw_train = sample_traces(x_cw_train_total, y_cw_train_total, N)
y_cw_train += 1

print (f'Closed World: {x_cw_train.shape}, {y_cw_train.shape}')

# x_ow_train = sample_ow_traces(x_ow_train_total, N, num_classes-1)
y_ow_train = np.zeros((len(x_ow_train), ))

print (f'Open World: {x_ow_train.shape}, {y_ow_train.shape}')


x_train = np.vstack((x_cw_train, x_ow_train))
y_train = np.hstack((y_cw_train, y_ow_train))

print (f'Total: {x_train.shape}, {y_train.shape}')


Closed World: (465, 5000), (465,)
Open World: (1000, 5000), (1000,)
Total: (1465, 5000), (1465,)


In [20]:
train_dataset = Data(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

model = load_checkpoint()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
    

In [21]:
for e in range(30):
    train(model, device, train_loader, optimizer)
    
    acc_inf = test(model, device, test_loader_inf)
    acc_sup = test(model, device, test_loader_sup)
    
    
    print (f'-------------- Epoch {e} --------------')
    # print (f'accuracy on inferior traces: {acc_inf:.2f}')
    # print (f'accuracy on superior traces: {acc_sup:.2f}')

Loss: 4.700192
-------------- Epoch 0 --------------
Loss: 1.405838
-------------- Epoch 1 --------------
Loss: 1.632425
-------------- Epoch 2 --------------
Loss: 0.952068
-------------- Epoch 3 --------------
Loss: 0.873310
-------------- Epoch 4 --------------
Loss: 0.512898
-------------- Epoch 5 --------------
Loss: 0.289532
-------------- Epoch 6 --------------
Loss: 0.166581
-------------- Epoch 7 --------------
Loss: 0.104868
-------------- Epoch 8 --------------
Loss: 0.115181
-------------- Epoch 9 --------------
Loss: 0.057948
-------------- Epoch 10 --------------
Loss: 0.058850
-------------- Epoch 11 --------------
Loss: 0.026669
-------------- Epoch 12 --------------
Loss: 0.027162
-------------- Epoch 13 --------------
Loss: 0.029380
-------------- Epoch 14 --------------
Loss: 0.021085
-------------- Epoch 15 --------------
Loss: 0.012534
-------------- Epoch 16 --------------
Loss: 0.017626
-------------- Epoch 17 --------------
Loss: 0.020268
-------------- Epoch 18

## Evaluation

Following block shows the precision and recall of the model for different thresholds:

In [22]:
thresholds = np.arange(0.1, 1, 0.1)
num_ow_test_samples = len(x_ow_test_sup)

### Inferior traces

In [23]:
for th in thresholds:
    print (f'--------------------- threshold = {th:.1f}')
    P, R = test_ow(model, device, test_loader_inf, th, num_ow_test_samples)
    print (f'Precision: {P*100:.1f}, Recall: {R*100:.1f}, F1 Score: {2*(P*R)*100/(P + R):.1f}')

--------------------- threshold = 0.1
Precision: 98.4, Recall: 42.4, F1 Score: 59.3
--------------------- threshold = 0.2
Precision: 98.7, Recall: 42.1, F1 Score: 59.0
--------------------- threshold = 0.3
Precision: 99.3, Recall: 40.0, F1 Score: 57.0
--------------------- threshold = 0.4
Precision: 99.6, Recall: 37.8, F1 Score: 54.8
--------------------- threshold = 0.5
Precision: 99.5, Recall: 32.7, F1 Score: 49.2
--------------------- threshold = 0.6
Precision: 99.6, Recall: 27.5, F1 Score: 43.1
--------------------- threshold = 0.7
Precision: 99.5, Recall: 22.4, F1 Score: 36.5
--------------------- threshold = 0.8
Precision: 100.0, Recall: 16.9, F1 Score: 28.9
--------------------- threshold = 0.9
Precision: 100.0, Recall: 11.1, F1 Score: 20.0


### Superior traces

In [24]:
for th in thresholds:
    print (f'--------------------- threshold = {th:.1f}')
    P, R = test_ow(model, device, test_loader_sup, th, num_ow_test_samples)
    print (f'Precision: {P*100:.1f}, Recall: {R*100:.1f}, F1 Score: {2*(P*R)*100/(P + R):.1f}')

--------------------- threshold = 0.1
Precision: 99.4, Recall: 66.7, F1 Score: 79.8
--------------------- threshold = 0.2
Precision: 99.4, Recall: 66.7, F1 Score: 79.8
--------------------- threshold = 0.3
Precision: 99.3, Recall: 65.8, F1 Score: 79.2
--------------------- threshold = 0.4
Precision: 99.7, Recall: 63.3, F1 Score: 77.4
--------------------- threshold = 0.5
Precision: 99.8, Recall: 59.2, F1 Score: 74.3
--------------------- threshold = 0.6
Precision: 100.0, Recall: 52.9, F1 Score: 69.2
--------------------- threshold = 0.7
Precision: 100.0, Recall: 45.9, F1 Score: 63.0
--------------------- threshold = 0.8
Precision: 100.0, Recall: 37.7, F1 Score: 54.8
--------------------- threshold = 0.9
Precision: 100.0, Recall: 26.4, F1 Score: 41.8
