# Do some imports

In [1]:
import torch
import numpy as np
from torchvision import datasets
import torchvision.transforms as transforms
import torch.utils.data
import pandas as pd
#from pytorch.ppg_feature_dataset

# Get UNSW_NB15 train and test set

In [None]:
#!wget https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/a%20part%20of%20training%20and%20testing%20set/UNSW_NB15_training-set.csv

In [None]:
#!wget https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/a%20part%20of%20training%20and%20testing%20set/UNSW_NB15_testing-set.csv

In [2]:
class UNSW_NB15(torch.utils.data.Dataset):
    def __init__(self, file_path, sequence_length=25, transform=None):
        #TODO have a sequence_overlap=True flag? Does overlap matter?
        self.transform = transform
        self.sequence_length = sequence_length
        self.columns = ['id', 'dur', 'proto', 'service', 'state', 'spkts', 'dpkts', 'sbytes',
       'dbytes', 'rate', 'sttl', 'dttl', 'sload', 'dload', 'sloss', 'dloss',
       'sinpkt', 'dinpkt', 'sjit', 'djit', 'swin', 'stcpb', 'dtcpb', 'dwin',
       'tcprtt', 'synack', 'ackdat', 'smean', 'dmean', 'trans_depth',
       'response_body_len', 'ct_srv_src', 'ct_state_ttl', 'ct_dst_ltm',
       'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm',
       'is_ftp_login', 'ct_ftp_cmd', 'ct_flw_http_mthd', 'ct_src_ltm',
       'ct_srv_dst', 'is_sm_ips_ports', 'attack_cat', 'label']
        self.dtypes = dtypes = {"id":"int32",
                                "scrip": "string",
                                #"sport": "int32",
                                "dstip": "string",
                                #"dsport": "int32",
                                "proto": "string",
                                "state": "string",
                                "dur": "float64",
                                "sbytes": "int32",
                                "dbytes": "int32",
                                "sttl": "int32",
                                "dttl": "int32",
                                "sloss": "int32",
                                "dloss": "int32",
                                "service": "string",
                                "sload": "float64",
                                "dload": "float64",
                                "spkts": "int32",
                                "dpkts": "int32",
                                "swin": "int32",
                                "dwin": "int32",
                                "stcpb": "int32",
                                "dtcpb": "int32", 
                                #"smeansz": "int32",
                                #"dmeansz": "int32",
                                "trans_depth": "int32",
                                #"res_bdy_len": "int32",
                                "sjit": "float64",
                                "djit": "float64",
                                #"stime": "int64",
                                #"ltime": "int64",
                                #"sintpkt": "float64",
                                #"dintpkt": "float64",
                                "tcprtt": "float64",
                                "synack": "float64",
                                "ackdat": "float64",

                                #commenting these because they have mixed values and we aren't going to generate them anyway
                                #"is_sm_ips_ports": "int32",
                                #"ct_state_ttl": "int32",
                                #"ct_flw_httpd_mthd": "int32",
                                #"is_ftp_login": "int32",
                                #"is_ftp_cmd": "int32",
                                #"ct_ftp_cmd": "int32",
                                #"ct_srv_src": "int32",
                                ##"ct_dst_ltm": "int32", 
                                #"ct_src_ltm": "int32",
                                #"ct_src_dport_ltm": "int32",
                                #"ct_dst_sport_ltm": "int32",
                                #"ct_dst_src_ltm": "int32",
                                "attack_cat": "string",
                                "label": "int32"}
        self.categorical_column_values = {"proto":None, "state":None, "service":None, "attack_cat":None}

        self.dataframe = pd.read_csv(file_path, encoding="latin-1", names=self.columns,header=0, dtype=self.dtypes)
        #self.dataframe.sort_values(by=['stime']) #sort chronologically upon loading
        
        #load all the unique values of categorical features at the start
        #and make these accessible via a fast function call.
        for key in self.categorical_column_values:
            self.categorical_column_values[key] = self.dataframe[key].unique()

        #cache all the maximum values in numeric columns since we'll be using these for feature extraction
        self.maximums = {}
        for key in self.dtypes:
            if "int" in self.dtypes[key] or "float" in self.dtypes[key]:
                self.maximums[key] = max(self.dataframe[key])
        
        #self.dataframe = torch.tensor(self.dataframe, dtype=torch.float32)
        self.tensor = torch.Tensor(self.dataframe.values)
    def __len__(self):
        return len(self.dataframe.index) - self.sequence_length
    
    def __getitem__(self, index):
        #TODO need error checking for out of bounds?
        #TODO return x,y where y is the category of the example
        #since none corresponds to "normal" data
        
        list_of_dicts = []
        for i in range(index,index+self.sequence_length):
            list_of_dicts.append(self.dataframe.loc[i, :].to_dict())
        
        if self.transform is not None:
            return self.transform(self, list_of_dicts)
        
        return list_of_dicts
    
    #get a list of all the unique labels in the dataset
    def get_labels(self):
        return self.dataframe['label'].unique().tolist()
    
    #get a list of all the unique attack categories in the dataset
    def get_attack_categories(self):
        return self.dataframe['attack_cat'].unique().tolist()
    
    def get_list_of_categories(self, column_name):
        pass #TODO

    #limit the dataset to only examples in the specified category
    def use_only_category(self, category_name):
        if category_name not in self.get_attack_categories():
            return False
        
        new_dataframe = self.dataframe[self.dataframe['attack_cat'] == category_name]
        new_dataframe = new_dataframe.reset_index()
        self.dataframe = new_dataframe
        return True
    
    #limit the dataset to only examples with the specified label
    def use_only_label(self, label):
        if label not in self.get_labels():
            return False
        
        new_dataframe = self.dataframe[self.dataframe['label'] == label]
        new_dataframe = new_dataframe.reset_index()
        self.dataframe = new_dataframe
        return True

# Create UNSW_NB15 class

In [4]:
unsw_nb15 = UNSW_NB15(file_path ='UNSW_NB15_training-set.csv')
unsw_nb15.head()

TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, int64, int32, int16, int8, and uint8.