# Handling Imbalanced Datasets

Most classification algorithms implicitly assume that the classes are roughly evenly distributed.

However, many real-world classification problems, such as fraud detection, spam detection, and churn prediction, have an imbalanced class distribution. Anomaly detection problems in general fall into this group.

A severely imbalanced target variable (as a rule of thumb, a split more skewed than 70-30 for binary classification, or class shares far from an even 1/n split for an n-class problem) may lead classification algorithms to make inaccurate predictions, particularly for the rare classes.

This scenario therefore needs to be handled, for example with the resampling techniques available in the <b>imblearn</b> library.
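Before choosing a technique, it helps to quantify how skewed the target actually is. The following is a minimal sketch using only the standard library; the helper names `class_distribution` and `imbalance_ratio` and the toy labels are made up for illustration and are not part of imblearn.

```python
from collections import Counter

def class_distribution(labels):
    """Return each class's share of the data, most frequent class first."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {cls: n / total for cls, n in counts.most_common()}

def imbalance_ratio(labels):
    """Ratio of the most frequent class count to the least frequent one."""
    counts = Counter(labels).values()
    return max(counts) / min(counts)

# Toy labels (hypothetical): 90 "ok" records and 10 "fraud" records
toy_labels = ["ok"] * 90 + ["fraud"] * 10
shares = class_distribution(toy_labels)
ratio = imbalance_ratio(toy_labels)
print(shares)  # {'ok': 0.9, 'fraud': 0.1}
print(ratio)   # 9.0
```

A 90-10 split like this one is well beyond the 70-30 rule of thumb, so it would be a candidate for the techniques below.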

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.stats as st
import feature_engine.imputation as fei
from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import SMOTE, SMOTEN, SMOTENC
from imblearn.combine import SMOTETomek, SMOTEENN
from sklearn.utils import compute_class_weight
from imblearn.ensemble import EasyEnsembleClassifier
from collections import Counter
from sklearn.datasets import fetch_kddcup99
import category_encoders as ce
pd.set_option('display.max_columns', 100)

In [2]:
dataset = fetch_kddcup99()
data = pd.concat([pd.DataFrame(dataset['data'], columns=dataset['feature_names']),
                  pd.DataFrame(dataset['target'], columns=dataset['target_names'])],axis=1)
data.head()

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,num_failed_logins,logged_in,num_compromised,root_shell,su_attempted,num_root,num_file_creations,num_shells,num_access_files,num_outbound_cmds,is_host_login,is_guest_login,count,srv_count,serror_rate,srv_serror_rate,rerror_rate,srv_rerror_rate,same_srv_rate,diff_srv_rate,srv_diff_host_rate,dst_host_count,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate,labels
0,0,b'tcp',b'http',b'SF',181,5450,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,8,8,0.0,0.0,0.0,0.0,1.0,0.0,0.0,9,9,1.0,0.0,0.11,0.0,0.0,0.0,0.0,0.0,b'normal.'
1,0,b'tcp',b'http',b'SF',239,486,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,8,8,0.0,0.0,0.0,0.0,1.0,0.0,0.0,19,19,1.0,0.0,0.05,0.0,0.0,0.0,0.0,0.0,b'normal.'
2,0,b'tcp',b'http',b'SF',235,1337,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,8,8,0.0,0.0,0.0,0.0,1.0,0.0,0.0,29,29,1.0,0.0,0.03,0.0,0.0,0.0,0.0,0.0,b'normal.'
3,0,b'tcp',b'http',b'SF',219,1337,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,6,6,0.0,0.0,0.0,0.0,1.0,0.0,0.0,39,39,1.0,0.0,0.03,0.0,0.0,0.0,0.0,0.0,b'normal.'
4,0,b'tcp',b'http',b'SF',217,2032,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,6,6,0.0,0.0,0.0,0.0,1.0,0.0,0.0,49,49,1.0,0.0,0.02,0.0,0.0,0.0,0.0,0.0,b'normal.'


In [3]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 494021 entries, 0 to 494020
Data columns (total 42 columns):
 #   Column                       Non-Null Count   Dtype 
---  ------                       --------------   ----- 
 0   duration                     494021 non-null  object
 1   protocol_type                494021 non-null  object
 2   service                      494021 non-null  object
 3   flag                         494021 non-null  object
 4   src_bytes                    494021 non-null  object
 5   dst_bytes                    494021 non-null  object
 6   land                         494021 non-null  object
 7   wrong_fragment               494021 non-null  object
 8   urgent                       494021 non-null  object
 9   hot                          494021 non-null  object
 10  num_failed_logins            494021 non-null  object
 11  logged_in                    494021 non-null  object
 12  num_compromised              494021 non-null  object
 13  root_shell    

In [4]:
data.nunique()

duration                        2495
protocol_type                      3
service                           66
flag                              11
src_bytes                       3300
dst_bytes                      10725
land                               2
wrong_fragment                     3
urgent                             4
hot                               22
num_failed_logins                  6
logged_in                          2
num_compromised                   23
root_shell                         2
su_attempted                       3
num_root                          20
num_file_creations                18
num_shells                         3
num_access_files                   7
num_outbound_cmds                  1
is_host_login                      1
is_guest_login                     2
count                            490
srv_count                        470
serror_rate                       92
srv_serror_rate                   51
rerror_rate                       77
s

In [5]:
# Rate columns (positions 24-30 and 33-40) hold fractions, so cast them to float
data[np.array(data.columns[24:31])] = data[np.array(data.columns[24:31])].astype('float')
data[np.array(data.columns[33:41])] = data[np.array(data.columns[33:41])].astype('float')
# Categorical columns and the labels are byte strings; decode them to regular strings
data[['protocol_type','service','flag','labels']] = data[['protocol_type','service','flag','labels']].applymap(
                                                    lambda x: x.decode('ascii'))
# All remaining columns hold integer values
integer_columns = np.array(data.drop(['serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',
                                      'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_diff_srv_rate',
                                      'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
                                      'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',
                                      'dst_host_same_srv_rate','protocol_type','service','flag','labels'],axis=1).columns)
data[integer_columns] = data[integer_columns].astype('int')
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 494021 entries, 0 to 494020
Data columns (total 42 columns):
 #   Column                       Non-Null Count   Dtype  
---  ------                       --------------   -----  
 0   duration                     494021 non-null  int32  
 1   protocol_type                494021 non-null  object 
 2   service                      494021 non-null  object 
 3   flag                         494021 non-null  object 
 4   src_bytes                    494021 non-null  int32  
 5   dst_bytes                    494021 non-null  int32  
 6   land                         494021 non-null  int32  
 7   wrong_fragment               494021 non-null  int32  
 8   urgent                       494021 non-null  int32  
 9   hot                          494021 non-null  int32  
 10  num_failed_logins            494021 non-null  int32  
 11  logged_in                    494021 non-null  int32  
 12  num_compromised              494021 non-null  int32  
 13 

In [6]:
X = data.drop('labels',axis=1)
y = data['labels']

In [7]:
y.value_counts()/len(y)

smurf.              0.568377
neptune.            0.216997
normal.             0.196911
back.               0.004459
satan.              0.003216
ipsweep.            0.002524
portsweep.          0.002105
warezclient.        0.002065
teardrop.           0.001982
pod.                0.000534
nmap.               0.000468
guess_passwd.       0.000107
buffer_overflow.    0.000061
land.               0.000043
warezmaster.        0.000040
imap.               0.000024
rootkit.            0.000020
loadmodule.         0.000018
ftp_write.          0.000016
multihop.           0.000014
phf.                0.000008
perl.               0.000006
spy.                0.000004
Name: labels, dtype: float64

In [8]:
y = y.map(lambda x: 'others' if x not in ['smurf.','neptune.','normal.'] else x)
y.value_counts()/len(y)

smurf.      0.568377
neptune.    0.216997
normal.     0.196911
others      0.017716
Name: labels, dtype: float64

## Under-Sampling

The under-sampling method reduces the number of samples in the more frequent classes to match the number of samples in the least frequent class.

Advantages: Reduces runtime and storage requirements for very large datasets.

Disadvantages: Discards data from the majority classes, which may lead to inaccurate results. For this reason, under-sampling is rarely the first choice.
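To make the mechanism concrete, here is a rough standard-library sketch of random under-sampling. It is not imblearn's implementation; the function name `random_under_sample`, the `seed` parameter, and the toy data are assumptions made for illustration.

```python
import random
from collections import Counter, defaultdict

def random_under_sample(rows, labels, seed=0):
    """Randomly keep only as many rows per class as the rarest class has."""
    rng = random.Random(seed)
    # Group row positions by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Every class is cut down to the size of the least frequent class
    target = min(len(indices) for indices in by_class.values())
    kept = []
    for indices in by_class.values():
        kept.extend(rng.sample(indices, target))
    kept.sort()
    return [rows[i] for i in kept], [labels[i] for i in kept]

# Toy data (hypothetical): 8 rows of class "a", 3 of class "b", 1 of class "c"
toy_rows = [[i] for i in range(12)]
toy_labels = ["a"] * 8 + ["b"] * 3 + ["c"] * 1
rows_rs, labels_rs = random_under_sample(toy_rows, toy_labels)
print(Counter(labels_rs))
```

Only 3 of the 12 toy rows survive, which illustrates the data loss mentioned above.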

In [9]:
# Randomly under-sample every class down to the size of the least frequent class
sampler = RandomUnderSampler()
X_resampled, y_resampled = sampler.fit_resample(X,y)
# Count the samples per class after under-sampling
Counter(y_resampled)

Counter({'neptune.': 8752, 'normal.': 8752, 'others': 8752, 'smurf.': 8752})

## Over-Sampling

The over-sampling method increases the number of samples in the less frequent classes to match the number of samples in the most frequent class.

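Advantages:
- No information loss
- Usually outperforms the under-sampling method

Disadvantages:
- Increases the likelihood of overfitting, because the added records are either copies of existing minority-class records (random over-sampling) or synthetic points interpolated between them (SMOTE).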

The following techniques are most commonly used for over-sampling:
1. SMOTE (continuous features only)
2. SMOTEN (categorical features only)
3. SMOTENC (a mix of continuous and categorical features)
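The core idea behind SMOTE for continuous features can be sketched as follows: pick a minority sample, pick one of its nearest minority neighbors, and place a new point somewhere on the line segment between them. This is a simplified illustration rather than imblearn's implementation; the function name `smote_sketch`, its parameters `k` and `seed`, and the toy points are hypothetical.

```python
import math
import random

def smote_sketch(minority, n_new, k=2, seed=0):
    """Create n_new synthetic points by interpolating between minority neighbors."""
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        base = rng.choice(minority)
        # k nearest minority neighbors of the base point (excluding itself)
        others = [p for p in minority if p is not base]
        neighbors = sorted(others, key=lambda p: math.dist(base, p))[:k]
        neighbor = rng.choice(neighbors)
        # Random position on the segment between base and neighbor
        gap = rng.random()
        synthetic.append([b + gap * (n - b) for b, n in zip(base, neighbor)])
    return synthetic

# Toy minority-class points (hypothetical) with two continuous features
toy_minority = [[1.0, 1.0], [1.5, 1.2], [2.0, 1.0], [1.2, 1.8]]
new_points = smote_sketch(toy_minority, n_new=5)
print(new_points)
```

Because every synthetic point lies between two existing minority samples, the new data stays inside the region the minority class already occupies.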

In [10]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X,y,stratify=y,test_size=0.1)
# Only the 10% test split is resampled, to keep SMOTENC's runtime short for this demo
# Columns 1, 2 and 3 (protocol_type, service, flag) are the categorical features
sampler = SMOTENC(categorical_features=[1,2,3])
X_resampled, y_resampled = sampler.fit_resample(X_test,y_test)
# Count the samples per class after over-sampling
Counter(y_resampled)

Counter({'neptune.': 28080,
         'normal.': 28080,
         'smurf.': 28080,
         'others': 28080})

## Combination of over-sampling and under-sampling

Over-sampling and under-sampling can be combined using one of two methods:
1. SMOTETomek
2. SMOTEENN

Note that both methods require categorical features to be label encoded before SMOTEN/SMOTENC is used as the SMOTE component.

### 1. SMOTETomek

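The SMOTETomek method combines over-sampling (SMOTE, the Synthetic Minority Oversampling Technique) with under-sampling (Tomek links): SMOTE first creates synthetic data points for the less frequent classes, and Tomek-link removal then cleans up pairs of mutually nearest samples that belong to different classes.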
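A Tomek link is a pair of samples from different classes that are each other's nearest neighbor. The sketch below only detects such pairs on toy one-dimensional data; a cleaning step would then drop one or both members of each pair. The function names and the data are hypothetical, and this is not how imblearn implements the search.

```python
import math

def nearest_index(points, i):
    """Index of the point closest to points[i], excluding i itself."""
    return min((j for j in range(len(points)) if j != i),
               key=lambda j: math.dist(points[i], points[j]))

def tomek_links(points, labels):
    """Return index pairs (i, j) that are mutual nearest neighbors of different classes."""
    links = []
    for i in range(len(points)):
        j = nearest_index(points, i)
        # i < j reports each pair only once
        if i < j and labels[i] != labels[j] and nearest_index(points, j) == i:
            links.append((i, j))
    return links

# Toy data (hypothetical): the classes meet between 2.0 and 2.4
toy_points = [[0.0], [1.1], [2.0], [2.4], [4.0], [5.0]]
toy_labels = ["a", "a", "a", "b", "b", "b"]
links = tomek_links(toy_points, toy_labels)
print(links)  # [(2, 3)]
```

The only link found sits exactly on the class boundary, which is where Tomek-link cleaning removes samples.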

In [11]:
encoder = ce.OrdinalEncoder(cols=['protocol_type','service','flag'])
X_test = encoder.fit_transform(X_test)

In [12]:
sampler = SMOTETomek(smote=SMOTENC(categorical_features=[1,2,3]))
X_resampled, y_resampled = sampler.fit_resample(X_test,y_test)
# Count the samples per class after combined over- and under-sampling
Counter(y_resampled)

Counter({'neptune.': 28080,
         'normal.': 28068,
         'smurf.': 28079,
         'others': 28069})

### 2. SMOTEENN

The SMOTEENN method combines over-sampling (SMOTE, the Synthetic Minority Oversampling Technique) with under-sampling (edited nearest neighbors, ENN): SMOTE first creates synthetic data points for the less frequent classes, and ENN then removes samples whose class disagrees with that of their nearest neighbors.

Compared with SMOTETomek, SMOTEENN tends to remove more noisy samples.
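The edited-nearest-neighbors cleaning step can be illustrated with a small sketch: a sample is kept only if its own label agrees with the labels of its `k` nearest neighbors. This version uses a simple majority vote, and library implementations may apply a stricter rule; the function name, the default `k=3`, and the toy data are assumptions made for illustration.

```python
import math
from collections import Counter

def edited_nearest_neighbors(points, labels, k=3):
    """Return indices of samples whose label matches the majority of their k nearest neighbors."""
    keep = []
    for i, point in enumerate(points):
        # Other samples ordered from closest to farthest
        order = sorted((j for j in range(len(points)) if j != i),
                       key=lambda j: math.dist(point, points[j]))
        votes = Counter(labels[j] for j in order[:k])
        majority_label = votes.most_common(1)[0][0]
        if majority_label == labels[i]:
            keep.append(i)
    return keep

# Toy data (hypothetical): sample 2 is a "b" surrounded by "a" samples
toy_points = [[0.0], [0.5], [1.0], [1.4], [5.0], [5.5], [6.0], [6.5]]
toy_labels = ["a", "a", "b", "a", "b", "b", "b", "b"]
kept = edited_nearest_neighbors(toy_points, toy_labels)
print(kept)  # [0, 1, 3, 4, 5, 6, 7]
```

Unlike Tomek links, which require a mutual nearest-neighbor pair, this rule can remove any sample that sits inside another class's neighborhood, which is consistent with ENN cleaning more aggressively.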

In [13]:
sampler = SMOTEENN(smote=SMOTENC(categorical_features=[1,2,3]))
X_resampled, y_resampled = sampler.fit_resample(X_test,y_test)
# Count the samples per class after combined over- and under-sampling
Counter(y_resampled)

Counter({'neptune.': 28043,
         'normal.': 27962,
         'others': 27883,
         'smurf.': 28076})

## Class Weights Distribution

An alternative way to handle an imbalanced dataset in classification problems is to pass class weights to classifiers that support them.

More weight is given to the less frequent target classes so that the classification algorithm places more importance on the rare classes.
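With scikit-learn's `'balanced'` option, each class weight is computed as `n_samples / (n_classes * class_count)`. The sketch below reproduces that formula with the standard library; the helper name `balanced_class_weights` and the toy labels are made up for illustration.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * n) for cls, n in counts.items()}

# Toy labels (hypothetical): 80 "common" records and 20 "rare" records
toy_labels = ["common"] * 80 + ["rare"] * 20
toy_weights = balanced_class_weights(toy_labels)
print(toy_weights)  # {'common': 0.625, 'rare': 2.5}
```

A class that makes up exactly `1 / n_classes` of the data gets a weight of 1, rarer classes get weights above 1, and more frequent classes get weights below 1.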

In [14]:
weights = compute_class_weight('balanced', classes=y.unique(), y=y)
weights

array([ 1.26961132, 14.11166019,  1.15209047,  0.43984918])

In [15]:
y.unique()

array(['normal.', 'others', 'neptune.', 'smurf.'], dtype=object)
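The weights are returned in the same order as the `classes` argument, so pairing the two outputs gives a label-to-weight mapping. The sketch below hard-codes the values printed by cells 14 and 15 so that it runs on its own; many scikit-learn classifiers accept such a dictionary through their `class_weight` parameter, although whether a particular estimator does should be checked in its documentation.

```python
# Values copied from the outputs of cells 14 and 15
classes = ['normal.', 'others', 'neptune.', 'smurf.']
weights = [1.26961132, 14.11166019, 1.15209047, 0.43984918]

# Map each class label to its weight
class_weight = dict(zip(classes, weights))

# The class with the largest weight is the rarest one
rarest = max(class_weight, key=class_weight.get)
print(class_weight)
print(rarest)  # others
```

In the notebook itself the same mapping could be built directly as `dict(zip(y.unique(), weights))`.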