# Handling Imbalanced Datasets

Most classification algorithms implicitly assume that the classes are roughly evenly distributed.

However, many real-world classification problems, such as fraud detection, spam detection, and churn prediction, have an imbalanced class distribution. Anomaly detection problems in general share this property.

A target variable that is severely imbalanced (as a rule of thumb, beyond a 70-30 split for binary classification, or far from an equal 1/n share per class for an n-class problem) may lead a classifier to favour the majority class and predict the rare classes poorly.

This scenario can be handled with the resampling techniques available in the <b>imblearn</b> (imbalanced-learn) library.
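
The rule of thumb above can be turned into a quick check before any resampling is attempted. The sketch below is illustrative only: the helper names `class_shares` and `looks_imbalanced` are hypothetical, and the `tolerance` of 0.20 is an assumption chosen so that the binary case matches the 70-30 threshold. The text does not give an exact threshold for the multiclass case.

```python
from collections import Counter

def class_shares(labels):
    """Return each class's share of the samples, largest class first."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {cls: n / total for cls, n in counts.most_common()}

def looks_imbalanced(labels, tolerance=0.20):
    """Flag a target whose largest class exceeds the even share 1/n
    by more than `tolerance` (an assumed cut-off, not a standard one)."""
    shares = class_shares(labels)
    even_share = 1 / len(shares)
    return max(shares.values()) - even_share > tolerance

toy_labels = ["normal"] * 90 + ["fraud"] * 10
print(class_shares(toy_labels))
print(looks_imbalanced(toy_labels))
```

A check like this is only a screening step. Whether a given skew actually harms a model depends on the algorithm and on the metric used.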

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.stats as st
import feature_engine.imputation as fei
from imblearn.under_sampling import NearMiss
from imblearn.over_sampling import RandomOverSampler
from imblearn.combine import SMOTETomek
from sklearn.utils import compute_class_weight
from imblearn.ensemble import EasyEnsembleClassifier
from collections import Counter
from sklearn.datasets import fetch_kddcup99
import category_encoders as ce
pd.set_option('display.max_columns', 100)

In [2]:
dataset = fetch_kddcup99()
data = pd.concat([pd.DataFrame(dataset['data'], columns=dataset['feature_names']),
                  pd.DataFrame(dataset['target'], columns=dataset['target_names'])],axis=1)
data.head()

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,num_failed_logins,logged_in,num_compromised,root_shell,su_attempted,num_root,num_file_creations,num_shells,num_access_files,num_outbound_cmds,is_host_login,is_guest_login,count,srv_count,serror_rate,srv_serror_rate,rerror_rate,srv_rerror_rate,same_srv_rate,diff_srv_rate,srv_diff_host_rate,dst_host_count,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate,labels
0,0,b'tcp',b'http',b'SF',181,5450,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,8,8,0.0,0.0,0.0,0.0,1.0,0.0,0.0,9,9,1.0,0.0,0.11,0.0,0.0,0.0,0.0,0.0,b'normal.'
1,0,b'tcp',b'http',b'SF',239,486,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,8,8,0.0,0.0,0.0,0.0,1.0,0.0,0.0,19,19,1.0,0.0,0.05,0.0,0.0,0.0,0.0,0.0,b'normal.'
2,0,b'tcp',b'http',b'SF',235,1337,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,8,8,0.0,0.0,0.0,0.0,1.0,0.0,0.0,29,29,1.0,0.0,0.03,0.0,0.0,0.0,0.0,0.0,b'normal.'
3,0,b'tcp',b'http',b'SF',219,1337,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,6,6,0.0,0.0,0.0,0.0,1.0,0.0,0.0,39,39,1.0,0.0,0.03,0.0,0.0,0.0,0.0,0.0,b'normal.'
4,0,b'tcp',b'http',b'SF',217,2032,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,6,6,0.0,0.0,0.0,0.0,1.0,0.0,0.0,49,49,1.0,0.0,0.02,0.0,0.0,0.0,0.0,0.0,b'normal.'


In [3]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 494021 entries, 0 to 494020
Data columns (total 42 columns):
 #   Column                       Non-Null Count   Dtype 
---  ------                       --------------   ----- 
 0   duration                     494021 non-null  object
 1   protocol_type                494021 non-null  object
 2   service                      494021 non-null  object
 3   flag                         494021 non-null  object
 4   src_bytes                    494021 non-null  object
 5   dst_bytes                    494021 non-null  object
 6   land                         494021 non-null  object
 7   wrong_fragment               494021 non-null  object
 8   urgent                       494021 non-null  object
 9   hot                          494021 non-null  object
 10  num_failed_logins            494021 non-null  object
 11  logged_in                    494021 non-null  object
 12  num_compromised              494021 non-null  object
 13  root_shell    

In [4]:
data.nunique()

duration                        2495
protocol_type                      3
service                           66
flag                              11
src_bytes                       3300
dst_bytes                      10725
land                               2
wrong_fragment                     3
urgent                             4
hot                               22
num_failed_logins                  6
logged_in                          2
num_compromised                   23
root_shell                         2
su_attempted                       3
num_root                          20
num_file_creations                18
num_shells                         3
num_access_files                   7
num_outbound_cmds                  1
is_host_login                      1
is_guest_login                     2
count                            490
srv_count                        470
serror_rate                       92
srv_serror_rate                   51
rerror_rate                       77
s

In [5]:
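# Rate columns (positions 24-30 and 33-40) are loaded as objects; cast them to float
data[np.array(data.columns[24:31])] = data[np.array(data.columns[24:31])].astype('float')
data[np.array(data.columns[33:41])] = data[np.array(data.columns[33:41])].astype('float')
# Categorical columns and labels are byte strings; decode them to str
# (applymap is deprecated from pandas 2.1, where DataFrame.map replaces it)
data[['protocol_type','service','flag','labels']] = data[['protocol_type','service','flag','labels']].applymap(
                                                    lambda x: x.decode('ascii'))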

In [6]:
# Every column that is neither a rate, a categorical feature, nor the label is cast to int
integer_columns = np.array(data.drop(['serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate', 
                                      'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_diff_srv_rate', 
                                      'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate', 
                                      'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate', 
                                      'dst_host_same_srv_rate','protocol_type','service','flag','labels'],axis=1).columns)
data[integer_columns] = data[integer_columns].astype('int')

In [7]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 494021 entries, 0 to 494020
Data columns (total 42 columns):
 #   Column                       Non-Null Count   Dtype  
---  ------                       --------------   -----  
 0   duration                     494021 non-null  int32  
 1   protocol_type                494021 non-null  object 
 2   service                      494021 non-null  object 
 3   flag                         494021 non-null  object 
 4   src_bytes                    494021 non-null  int32  
 5   dst_bytes                    494021 non-null  int32  
 6   land                         494021 non-null  int32  
 7   wrong_fragment               494021 non-null  int32  
 8   urgent                       494021 non-null  int32  
 9   hot                          494021 non-null  int32  
 10  num_failed_logins            494021 non-null  int32  
 11  logged_in                    494021 non-null  int32  
 12  num_compromised              494021 non-null  int32  
 13 

In [8]:
X = data.drop('labels',axis=1)
y = data['labels']

In [9]:
# Service column has 66 unique values. Count Encoder is used instead of One Hot Encoder
countencoder = ce.CountEncoder(cols='service')
X_dummy = countencoder.fit_transform(X)
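
To make the encoding step concrete, the sketch below shows the idea behind count encoding on a toy list: each category is replaced by the number of times it occurs. It is a plain-Python illustration rather than the `category_encoders` implementation, and the helper name `count_encode` and the toy values are made up for the example.

```python
from collections import Counter

def count_encode(values):
    """Replace each category with the number of times it occurs."""
    counts = Counter(values)
    return [counts[v] for v in values]

services = ["http", "http", "smtp", "http", "ftp", "smtp"]
print(count_encode(services))  # [3, 3, 2, 3, 1, 2]
```

This keeps `service` as a single numeric column instead of adding 66 indicator columns. It would also explain why the `service` column in the later `describe()` outputs has a maximum of 281400, presumably the frequency of the most common service.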

In [10]:
# The protocol_type and flag columns each have fewer than 15 unique values, so One Hot Encoder is used.
onehotencoder = ce.OneHotEncoder(cols=['protocol_type','flag'])
X_dummy = onehotencoder.fit_transform(X_dummy).drop(['protocol_type_1','flag_1'],axis=1)

In [11]:
y.value_counts()/len(y)

smurf.              0.568377
neptune.            0.216997
normal.             0.196911
back.               0.004459
satan.              0.003216
ipsweep.            0.002524
portsweep.          0.002105
warezclient.        0.002065
teardrop.           0.001982
pod.                0.000534
nmap.               0.000468
guess_passwd.       0.000107
buffer_overflow.    0.000061
land.               0.000043
warezmaster.        0.000040
imap.               0.000024
rootkit.            0.000020
loadmodule.         0.000018
ftp_write.          0.000016
multihop.           0.000014
phf.                0.000008
perl.               0.000006
spy.                0.000004
Name: labels, dtype: float64

In [12]:
y = y.map(lambda x: 'others' if x not in ['smurf.','neptune.','normal.'] else x)

In [13]:
y.value_counts()/len(y)

smurf.      0.568377
neptune.    0.216997
normal.     0.196911
others      0.017716
Name: labels, dtype: float64

## Under-Sampling

Under-sampling reduces the number of samples in the more frequent classes to match the least frequent class.

Advantages: Reduces runtime and storage requirements for very large datasets.

Disadvantages: Discards most of the majority-class data, so useful information may be lost and the resulting model may be less accurate.
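
The notebook uses `NearMiss`, which chooses the majority samples to keep by distance instead of at random. According to the imbalanced-learn documentation, the default is version 1 with 3 neighbours: it keeps the majority samples whose average distance to their nearest minority samples is smallest. The sketch below illustrates that selection rule in plain Python on toy 2-D points. The function name `near_miss_1` and the data are hypothetical, and the real implementation differs in detail.

```python
import math

def near_miss_1(majority, minority, n_keep, k=3):
    """Keep the n_keep majority points whose mean distance to their
    k nearest minority points is smallest (the NearMiss-1 rule)."""
    def mean_dist_to_nearest_minority(point):
        distances = sorted(math.dist(point, m) for m in minority)
        nearest = distances[:k]
        return sum(nearest) / len(nearest)

    ranked = sorted(majority, key=mean_dist_to_nearest_minority)
    return ranked[:n_keep]

minority = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
majority = [(0.5, 0.5), (5.0, 5.0), (1.0, 1.0), (9.0, 9.0), (2.0, 2.0), (7.0, 7.0)]

# Shrink the majority class to the size of the minority class
kept = near_miss_1(majority, minority, n_keep=len(minority))
print(kept)
```

Because the retained majority samples are those closest to the minority class, the resampled data concentrates on the class boundary and drops the rest of the majority distribution.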

In [14]:
# NearMiss under-sampling: every class is reduced to the size of the rarest class
nm = NearMiss()
X_resampled, y_resampled = nm.fit_resample(X_dummy,y)
X_resampled.describe()

Unnamed: 0,duration,protocol_type_2,protocol_type_3,service,flag_2,flag_3,flag_4,flag_5,flag_6,flag_7,flag_8,flag_9,flag_10,flag_11,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,num_failed_logins,logged_in,num_compromised,root_shell,su_attempted,num_root,num_file_creations,num_shells,num_access_files,num_outbound_cmds,is_host_login,is_guest_login,count,srv_count,serror_rate,srv_serror_rate,rerror_rate,srv_rerror_rate,same_srv_rate,diff_srv_rate,srv_diff_host_rate,dst_host_count,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate
count,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0,35008.0
mean,75.137997,0.153622,0.314642,125984.004942,0.0002,0.154936,0.0002,0.180502,0.000143,0.002685,0.024937,0.000314,0.000229,0.003056,32289.66,2867.238,0.000628,0.090779,0.000114,0.365888,0.001771,0.142625,0.064928,0.000914,2.9e-05,0.065328,0.002199,0.000343,0.000314,0.0,0.0,0.009569,184.115402,126.941213,0.185946,0.184047,0.182861,0.185572,0.695568,0.087128,0.048578,210.278451,111.266996,0.533335,0.124226,0.426309,0.036937,0.185393,0.183737,0.18713,0.181652
std,1353.326206,0.360591,0.46438,102457.392525,0.014139,0.361849,0.014139,0.38461,0.01195,0.051749,0.155936,0.017724,0.015115,0.055201,3711776.0,107013.8,0.025061,0.498792,0.013091,2.543861,0.049533,0.349695,0.3541,0.03022,0.005345,0.776022,0.129473,0.022673,0.019268,0.0,0.0,0.097355,212.855305,214.842018,0.382528,0.386966,0.378887,0.383674,0.444386,0.238017,0.203232,88.191237,114.163593,0.468499,0.270908,0.467935,0.136175,0.380786,0.386293,0.373258,0.380784
min,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,64293.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.03,0.0,0.0,255.0,2.0,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,110893.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,105.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,101.0,2.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,255.0,54.0,0.83,0.01,0.09,0.0,0.0,0.0,0.0,0.0
75%,0.0,0.0,1.0,281400.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1032.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,507.0,90.0,0.0,0.0,0.0,0.0,1.0,0.06,0.0,255.0,255.0,1.0,0.07,1.0,0.0,0.01,0.0,0.03,0.0
max,42448.0,1.0,1.0,281400.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,693375600.0,5155468.0,1.0,3.0,2.0,28.0,5.0,1.0,38.0,1.0,1.0,54.0,21.0,2.0,2.0,0.0,0.0,1.0,511.0,511.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,255.0,255.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [15]:
# Number of samples per class after under-sampling
Counter(y_resampled)

Counter({'neptune.': 8752, 'normal.': 8752, 'others': 8752, 'smurf.': 8752})

In [16]:
# Proportion of samples per class after under-sampling
y_resampled.value_counts()/len(y_resampled)

neptune.    0.25
normal.     0.25
others      0.25
smurf.      0.25
Name: labels, dtype: float64

## Over-Sampling

Over-sampling increases the number of samples in the less frequent classes to match the most frequent class.

Advantages: 
- No samples are discarded, so no information is lost
- Often performs better than under-sampling, although this depends on the dataset and model

Disadvantages:
- Increases the likelihood of overfitting, since random over-sampling only duplicates existing minority-class records.
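
The duplication behind that overfitting risk is easy to see in a small sketch. The code below is a plain-Python illustration of random over-sampling, not the `RandomOverSampler` implementation. The function name `random_oversample`, the `seed` argument, and the toy rows are assumptions made for the example.

```python
import random
from collections import Counter

def random_oversample(rows, labels, seed=0):
    """Duplicate randomly chosen rows of each smaller class until every
    class has as many rows as the largest one."""
    rng = random.Random(seed)
    counts = Counter(labels)
    target = max(counts.values())
    out_rows, out_labels = list(rows), list(labels)
    for cls, n in counts.items():
        pool = [r for r, l in zip(rows, labels) if l == cls]
        extra = [rng.choice(pool) for _ in range(target - n)]  # sampling with replacement
        out_rows.extend(extra)
        out_labels.extend([cls] * len(extra))
    return out_rows, out_labels

rows = [[1.0], [1.1], [0.9], [1.2], [5.0]]
labels = ["smurf", "smurf", "smurf", "smurf", "others"]

new_rows, new_labels = random_oversample(rows, labels)
print(Counter(new_labels))
print(len({tuple(r) for r in new_rows}), "distinct rows out of", len(new_rows))
```

The class counts become equal, but the number of distinct rows does not change. A model can therefore memorise the few repeated minority records.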

In [17]:
sampler = RandomOverSampler()
X_resampled, y_resampled = sampler.fit_resample(X_dummy,y)
X_resampled.describe()

Unnamed: 0,duration,protocol_type_2,protocol_type_3,service,flag_2,flag_3,flag_4,flag_5,flag_6,flag_7,flag_8,flag_9,flag_10,flag_11,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,num_failed_logins,logged_in,num_compromised,root_shell,su_attempted,num_root,num_file_creations,num_shells,num_access_files,num_outbound_cmds,is_host_login,is_guest_login,count,srv_count,serror_rate,srv_serror_rate,rerror_rate,srv_rerror_rate,same_srv_rate,diff_srv_rate,srv_diff_host_rate,dst_host_count,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate
count,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0,1123160.0
mean,129.8267,0.08312262,0.2967244,123314.6,0.0002368318,0.1036326,0.0002564194,0.2087646,9.348624e-05,0.002789451,0.0249359,0.0003427829,0.0002003276,0.003047651,29926.35,3566.707,0.0006018733,0.09097546,0.0001397842,0.3714181,0.001645358,0.2740874,0.07403487,0.001014993,8.992485e-05,0.02084832,0.003340575,0.0004603084,0.001552762,0.0,0.0,0.009795577,198.0708,134.8448,0.2145222,0.2127426,0.1316645,0.1350044,0.7129175,0.06881464,0.07057187,207.2587,139.5726,0.6102915,0.1068803,0.3872333,0.03324583,0.2138255,0.2123622,0.1366793,0.1317574
std,1521.395,0.2760676,0.4568142,100482.9,0.01538752,0.3047835,0.01601105,0.406426,0.00966838,0.05274156,0.1559298,0.01851123,0.0141523,0.05512137,3468348.0,105327.3,0.02452573,0.4994676,0.01482889,2.561585,0.04788115,0.4460534,2.346981,0.0318428,0.01159457,2.62387,0.1637473,0.02571605,0.04590979,0.0,0.0,0.09848671,212.8361,215.6063,0.4042677,0.4085511,0.3292804,0.3358115,0.4307698,0.2157235,0.2281585,90.44107,116.7366,0.4589657,0.2599974,0.4691071,0.1298626,0.4026758,0.4080643,0.324831,0.3321026
min,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,64293.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,2.0,0.0,0.0,0.0,0.0,0.11,0.0,0.0,255.0,12.0,0.05,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,110893.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,213.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,117.0,11.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,255.0,183.0,1.0,0.0,0.02,0.0,0.0,0.0,0.0,0.0
75%,0.0,0.0,1.0,281400.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1032.0,147.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,510.0,265.0,0.02,0.0,0.0,0.0,1.0,0.06,0.0,255.0,255.0,1.0,0.07,1.0,0.0,0.05,0.0,0.0,0.0
max,58329.0,1.0,1.0,281400.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,693375600.0,5155468.0,1.0,3.0,3.0,30.0,5.0,1.0,884.0,1.0,2.0,993.0,28.0,2.0,8.0,0.0,0.0,1.0,511.0,511.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,255.0,255.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [18]:
# Number of samples per class after over-sampling
Counter(y_resampled)

Counter({'normal.': 280790,
         'others': 280790,
         'neptune.': 280790,
         'smurf.': 280790})

In [19]:
# Proportion of samples per class after over-sampling
y_resampled.value_counts()/len(y_resampled)

normal.     0.25
others      0.25
neptune.    0.25
smurf.      0.25
Name: labels, dtype: float64

## SMOTETomek

SMOTETomek combines over-sampling with SMOTE (Synthetic Minority Over-sampling Technique), which creates new synthetic data points between existing samples of the less frequent classes, and under-sampling with Tomek links, which removes pairs of nearest-neighbour samples that belong to different classes.

Note that SMOTETomek is less suitable for very large datasets because of its long computation time.
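
The two building blocks can be sketched in a few lines of plain Python. This is a simplified illustration, not the imbalanced-learn implementation. SMOTE is reduced to interpolating between one sample and one given neighbour, Tomek links are found by brute force, and all function names and toy points are hypothetical.

```python
import math
import random

def smote_point(x, neighbor, rng):
    """Create one synthetic sample on the segment between x and neighbor."""
    gap = rng.random()  # uniform in [0, 1)
    return tuple(a + gap * (b - a) for a, b in zip(x, neighbor))

def nearest_index(i, points):
    """Index of the point closest to points[i], excluding itself."""
    others = (j for j in range(len(points)) if j != i)
    return min(others, key=lambda j: math.dist(points[i], points[j]))

def tomek_links(points, labels):
    """Pairs (i, j) of mutual nearest neighbours with different labels."""
    links = []
    for i in range(len(points)):
        j = nearest_index(i, points)
        if i < j and labels[i] != labels[j] and nearest_index(j, points) == i:
            links.append((i, j))
    return links

rng = random.Random(0)
synthetic = smote_point((1.0, 1.0), (2.0, 3.0), rng)
print(synthetic)

points = [(0.0, 0.0), (0.1, 0.0), (3.0, 3.0), (3.0, 3.2), (6.0, 6.0)]
labels = ["rare", "common", "common", "common", "common"]
print(tomek_links(points, labels))  # [(0, 1)]
```

The brute-force neighbour search compares every pair of points. That hints at why the combined method becomes slow on a dataset with hundreds of thousands of rows.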

In [20]:
# tomek = SMOTETomek(n_jobs=-1)
# X_resampled, y_resampled = tomek.fit_resample(X_dummy,y)

## Easy Ensemble Classifier

The Easy Ensemble classifier is an ensemble of AdaBoost learners trained on different balanced bootstrap samples. The balancing is achieved by random under-sampling.
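
The sampling half of this idea can be sketched without any model. The code below only builds the balanced subsets. In the real classifier each subset would train one AdaBoost learner and their predictions would be combined. The function name `balanced_subsets`, the use of sampling without replacement, and the toy labels are assumptions made for illustration.

```python
import random
from collections import Counter

def balanced_subsets(labels, n_subsets=3, seed=0):
    """Yield index lists in which every class is randomly under-sampled
    to the size of the rarest class; each list would train one learner."""
    rng = random.Random(seed)
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)
    smallest = min(len(idxs) for idxs in by_class.values())
    for _ in range(n_subsets):
        subset = []
        for idxs in by_class.values():
            subset.extend(rng.sample(idxs, smallest))  # random under-sampling
        yield subset

labels = ["smurf."] * 50 + ["normal."] * 20 + ["others"] * 5
for subset in balanced_subsets(labels):
    print(Counter(labels[i] for i in subset))
```

Each subset discards most majority samples, but different subsets discard different ones. Across the ensemble, more of the majority class is seen than with a single under-sampled dataset.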

In [21]:
eec = EasyEnsembleClassifier()
eec.fit(X_dummy, y)

EasyEnsembleClassifier()

In [22]:
# Predictions are made on the training data itself, so they show fit rather than generalization
y_pred = eec.predict(X_dummy)

In [23]:
print(Counter(y_pred))

Counter({'smurf.': 280461, 'neptune.': 107971, 'normal.': 101440, 'others': 4149})


## Class Weights Distribution

An alternative way to handle an imbalanced dataset is to pass class weights to classifiers that accept them (for example, through the `class_weight` parameter of many scikit-learn estimators).

Classes with lower frequency receive larger weights, so the classification algorithm places more importance on the rare classes.
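
For the `'balanced'` setting used in this notebook, scikit-learn documents the weight of a class as `n_samples / (n_classes * count)`, which is the same as `1 / (n_classes * share)`. The sketch below applies that formula to the class shares printed earlier. Because those shares are rounded to six digits, the results should agree with the `compute_class_weight` output only to about three decimal places.

```python
# Class shares as printed by y.value_counts()/len(y) after grouping (rounded)
shares = {"normal.": 0.196911, "others": 0.017716,
          "neptune.": 0.216997, "smurf.": 0.568377}

n_classes = len(shares)

# n_samples / (n_classes * count) simplifies to 1 / (n_classes * share)
weights = {cls: 1 / (n_classes * share) for cls, share in shares.items()}

for cls, w in weights.items():
    print(f"{cls:<10} {w:.4f}")
```

A dictionary of this form, mapping each class label to its weight, is what estimators with a `class_weight` parameter expect. The rarest class (`others`) receives a weight roughly 32 times that of the most frequent one (`smurf.`).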

In [24]:
# Weights are returned in the same order as the classes passed in (see y.unique() below)
weights = compute_class_weight('balanced', classes=y.unique(), y=y)
weights

array([ 1.26961132, 14.11166019,  1.15209047,  0.43984918])

In [25]:
y.unique()

array(['normal.', 'others', 'neptune.', 'smurf.'], dtype=object)