# Stratifying Training and Test Data to Handle Class Imbalance

```
topic_label
Lifestyle    4835
Politics     1783
Events        446
```

The `topic_label` column is imbalanced: "Lifestyle" has 4835 samples, "Politics" has 1783, and "Events" has only 446. Left unhandled, this imbalance can bias training toward the majority class and make performance estimates unreliable, particularly for the minority class "Events".
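To put numbers on the imbalance, the counts above can be turned into class shares. This is a minimal sketch that hard-codes the counts from the output above; the variable names are illustrative and do not come from the notebook.

```python
# Class counts copied from the value_counts() output above.
class_counts = {"Lifestyle": 4835, "Politics": 1783, "Events": 446}

total = sum(class_counts.values())

# Share of each class in the full dataset.
class_shares = {label: count / total for label, count in class_counts.items()}

for label, share in class_shares.items():
    print(f"{label:<10} {class_counts[label]:>5}  {share:.1%}")

# Ratio of the largest class to the smallest class.
imbalance_ratio = max(class_counts.values()) / min(class_counts.values())
print(f"Largest-to-smallest ratio: {imbalance_ratio:.1f}")
```

"Lifestyle" makes up roughly 68% of the rows and "Events" roughly 6%, so the largest class is about eleven times the size of the smallest.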

## Solution

To keep the imbalance from distorting the split, the class distribution should be the same in the training and test sets. Stratified sampling achieves this by drawing samples so that each subset preserves the class proportions of the full dataset. In scikit-learn, passing `stratify=y` to `train_test_split` does this:

```
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
```

With stratification, the training and test sets each keep the same class proportions as the original dataset, up to rounding.
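To make the mechanics concrete, here is a minimal pure-Python sketch of what a stratified split does: group rows by class, then take the same fraction of every class for the test set. It is not scikit-learn's implementation, and the helper name `stratified_split` and the synthetic `labels` list are assumptions made for illustration.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx) with roughly equal class shares in each."""
    rng = random.Random(seed)

    # Group row positions by class label.
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)

    train_idx, test_idx = [], []
    for label, indices in by_class.items():
        rng.shuffle(indices)
        # Take the same fraction of every class for the test set.
        n_test = round(len(indices) * test_size)
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return train_idx, test_idx


# Synthetic labels with the same class counts as the dataset above.
labels = ["Lifestyle"] * 4835 + ["Politics"] * 1783 + ["Events"] * 446

train_idx, test_idx = stratified_split(labels)
train_counts = Counter(labels[i] for i in train_idx)
test_counts = Counter(labels[i] for i in test_idx)

print("train:", dict(train_counts))
print("test: ", dict(test_counts))
```

Because the fraction is applied per class, even the smallest class is guaranteed a proportional presence in both subsets.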

## Reasoning

Stratifying the split ensures that the model is trained and evaluated on a representative class distribution. A purely random split could, by chance, leave the test set with too few or too many "Events" samples, which would make the metrics for that class noisy. Keeping the class proportions equal in both sets gives a more reliable performance estimate, particularly for minority classes like "Events". Stratification does not rebalance the classes, however: the model still sees far fewer "Events" examples than "Lifestyle" ones.
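The effect of chance on an unstratified split can be illustrated with a small simulation. The sketch below uses only the standard library and synthetic labels with the same class counts as the dataset; the number of repetitions (200) and the variable names are arbitrary choices for illustration.

```python
import random

# Synthetic labels with the same class counts as the dataset above.
labels = ["Lifestyle"] * 4835 + ["Politics"] * 1783 + ["Events"] * 446
n_test = round(len(labels) * 0.2)

# Draw 200 purely random (unstratified) test sets and record how many
# "Events" rows each one contains.
events_in_test = []
for seed in range(200):
    rng = random.Random(seed)
    test_labels = rng.sample(labels, n_test)
    events_in_test.append(test_labels.count("Events"))

# A stratified split puts the same share of each class in the test set.
stratified_events = round(446 * 0.2)

print("unstratified min/max:", min(events_in_test), max(events_in_test))
print("stratified:", stratified_events)
```

The unstratified counts scatter around the stratified value, so the amount of minority-class evidence in the test set depends on the random seed. Stratifying removes that source of variation.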

In [21]:
import pandas as pd

df = pd.read_pickle("out/df_final.pkl")

df

Unnamed: 0,topic_label,type_of_material_Biography,type_of_material_Brief,type_of_material_Correction,type_of_material_Editorial,type_of_material_First Chapter,type_of_material_Interview,type_of_material_Letter,type_of_material_List,type_of_material_News,...,world,would,write,writer,writing,yankee,yankees,yearold,yet,young
0,Lifestyle,False,False,False,False,False,False,False,False,True,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0
1,Lifestyle,False,False,False,False,False,False,False,False,True,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0
2,Lifestyle,False,False,False,False,False,False,False,False,True,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0
3,Politics,False,False,False,False,False,False,False,False,True,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.159553,0.0
4,Lifestyle,False,False,False,False,False,False,False,False,True,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7059,Politics,False,False,False,False,False,False,False,False,True,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0
7060,Lifestyle,False,False,False,False,False,False,False,False,True,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0
7061,Lifestyle,False,False,False,False,False,False,False,False,False,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.145484,0.0
7062,Lifestyle,False,False,False,False,False,False,False,False,False,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0


In [22]:
df['topic_label'].value_counts()

topic_label
Lifestyle    4835
Politics     1783
Events        446
Name: count, dtype: int64

In [23]:
from sklearn.model_selection import train_test_split

X = df.drop(columns=['topic_label'])

y = df['topic_label']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)


In [24]:
import numpy as np

np.unique(y_train, return_counts=True)

(array(['Events', 'Lifestyle', 'Politics'], dtype=object),
 array([ 357, 3868, 1426], dtype=int64))
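As a quick sanity check, the training counts above can be compared with the full-dataset counts. This sketch hard-codes both sets of counts from the outputs above and derives the test counts by subtraction, since the test set holds every row that is not in the training set. The helper `shares` is a hypothetical name.

```python
# Counts copied from the outputs above: full dataset and training split.
full_counts = {"Events": 446, "Lifestyle": 4835, "Politics": 1783}
train_counts = {"Events": 357, "Lifestyle": 3868, "Politics": 1426}

# The test split holds whatever rows are not in the training split.
test_counts = {label: full_counts[label] - train_counts[label] for label in full_counts}


def shares(counts):
    """Convert raw class counts into fractions of the total."""
    total = sum(counts.values())
    return {label: count / total for label, count in counts.items()}


for name, counts in [("full", full_counts), ("train", train_counts), ("test", test_counts)]:
    formatted = ", ".join(f"{label}: {share:.2%}" for label, share in shares(counts).items())
    print(f"{name:<5} n={sum(counts.values()):>4}  {formatted}")
```

The derived test counts are 89 "Events", 967 "Lifestyle", and 357 "Politics", and every class share in both subsets is within a tenth of a percentage point of its share in the full dataset.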

In [26]:
import os
import pickle

# Save the train-test split to a pickle file, creating the output directory if needed
data = {'X_train': X_train, 'X_test': X_test, 'y_train': y_train, 'y_test': y_test}

os.makedirs("supervised_train_test_data", exist_ok=True)
with open("supervised_train_test_data/train_test_data.pkl", "wb") as f:
    pickle.dump(data, f)
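A later notebook can load the split back by reading the pickle and unpacking the same four keys. The sketch below shows the round trip with small stand-in lists and a temporary file so that it runs on its own; the stand-in values and the temporary path are assumptions, and the real file holds pandas objects.

```python
import os
import pickle
import tempfile

# Small stand-in for the real split; the real dict holds DataFrames and Series.
data = {
    "X_train": [[0.0], [1.0]],
    "X_test": [[2.0]],
    "y_train": ["Lifestyle", "Politics"],
    "y_test": ["Events"],
}

# Hypothetical location in a temporary directory, used only for this demo.
path = os.path.join(tempfile.mkdtemp(), "train_test_data.pkl")

with open(path, "wb") as f:
    pickle.dump(data, f)

# Load the split back and unpack it under the same keys used when saving.
with open(path, "rb") as f:
    loaded = pickle.load(f)

X_train, X_test = loaded["X_train"], loaded["X_test"]
y_train, y_test = loaded["y_train"], loaded["y_test"]
print(len(X_train), len(X_test))
```

Pickle files should only be loaded from trusted sources, because unpickling can execute arbitrary code.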