### Import Relevant Libraries

In [1]:
import numpy as np
import pandas as pd
import findspark
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DoubleType, StringType

### Initialize the Spark Session

In [2]:
# findspark.init() locates the Spark installation on this machine so PySpark can
# be used; without it you may hit a "python worker failed to connect" error.
# NOTE(review): the imports cell already runs `from pyspark.sql import ...`
# before this call. If PySpark is only reachable through findspark, that import
# fails first; consider calling findspark.init() before importing pyspark.
findspark.init()

# Build the Spark session for this notebook (reusing an existing one if present).
spark = (
    SparkSession.builder
    .appName("ClassImbalanceHandling")
    .getOrCreate()
)

## Create a Mock PySpark Dataset with Class Imbalance

In [3]:
# Define the number of samples
total_samples = 100000

# Draw each feature from a zero-mean normal distribution with a different
# standard deviation (1000, 1, 100, 1, 10), so the features span very different
# ranges. Any distribution would do; the global seed makes the draws reproducible.
np.random.seed(0)
feature1 = np.random.normal(0, 1000, total_samples)
feature2 = np.random.normal(0, 1, total_samples)
feature3 = np.random.normal(0, 100, total_samples)
feature4 = np.random.normal(0, 1, total_samples)
feature5 = np.random.normal(0, 10, total_samples)

# Create the dataset
data = {
    "col1": feature1,
    "col2": feature2,
    "col3": feature3,
    "col4": feature4,
    "col5": feature5
}

# Create a pandas DataFrame with the dictionary
df = pd.DataFrame(data)

# Introduce class imbalance: Class A is the majority and Class B the minority.
# Adjust imbalance_ratio to change how rare Class B is.
imbalance_ratio = 0.005  # 0.5% Class B samples
class_b_samples = int(total_samples * imbalance_ratio)
if not 0 <= class_b_samples <= total_samples:
    raise ValueError(f"imbalance_ratio must be in [0, 1], got {imbalance_ratio}")
# Negative form of the Class B count, kept so any later cell that reads this
# name still sees the same value as before.
total_class_b_samples = -class_b_samples
print(f"There are total of {class_b_samples} Class B samples.")

# Label the last `class_b_samples` rows as Class B. The start row is computed
# explicitly because a negative slice start of -0 equals 0, which would select
# (and mislabel) every row when class_b_samples is 0. The label column is looked
# up by name rather than by a hard-coded position.
df['class'] = 'A'
class_b_start = len(df) - class_b_samples
df.iloc[class_b_start:, df.columns.get_loc('class')] = 'B'

# Shuffle the rows so the Class B samples are not all grouped at the end.
df = df.sample(frac=1, random_state=0).reset_index(drop=True)

# Show the 1st 5 rows.
print(df.head())

There are total of 500 Class B samples.
          col1      col2        col3      col4       col5 class
0 -1718.649877  0.898257  -39.836576 -0.198589   5.539138     A
1  -121.041019 -0.444813    3.686693 -0.558107  -6.573311     A
2  1038.512670  0.783736   42.352528  0.450083  13.525636     A
3  -209.312628 -2.215415   36.416854  1.678712  12.133693     A
4  -429.805510 -1.292846 -237.993497 -0.564347  16.672520     A


In [4]:
# Check that the expected number of rows carry the Class B label.
# count() reports the non-null count per column, so every column should equal
# class_b_samples. The assert makes the cell fail loudly instead of only
# printing the numbers for a human to eyeball.
class_b_counts = df.loc[df['class'] == 'B'].count()
print(class_b_counts)
assert (class_b_counts == class_b_samples).all(), "unexpected number of Class B rows"

col1     500
col2     500
col3     500
col4     500
col5     500
class    500
dtype: int64


In [5]:
# Spark schema for the converted DataFrame: five double-valued features and a
# string label. The fields are matched to the pandas columns by position
# (col1..col5 -> Feature1..Feature5, class -> Class), as the displayed rows
# below show, so this order must follow the pandas column order.
feature_names = ["Feature1", "Feature2", "Feature3", "Feature4", "Feature5"]
feature_fields = [StructField(name, DoubleType(), True) for name in feature_names]
label_field = StructField("Class", StringType(), True)
schema = StructType(feature_fields + [label_field])

# Convert the pandas DataFrame to a PySpark DataFrame using that schema.
pyspark_df = spark.createDataFrame(df, schema=schema)

In [6]:
# Display the first rows of the PySpark DataFrame (20 is show()'s default count).
pyspark_df.show(n=20)

+-------------------+--------------------+-------------------+--------------------+-------------------+-----+
|           Feature1|            Feature2|           Feature3|            Feature4|           Feature5|Class|
+-------------------+--------------------+-------------------+--------------------+-------------------+-----+
| -1718.649876612095|  0.8982566006303742| -39.83657612644435|-0.19858908539429249|  5.539138272731238|    A|
|-121.04101922277877| -0.4448128444740752| 3.6866931235569558| -0.5581070764909885| -6.573310665724776|    A|
|  1038.512669661264|  0.7837362000748797| 42.352528263407955| 0.45008269775923704| 13.525636175411687|    A|
|-209.31262765761383|  -2.215414966767953| 36.416853544786235|  1.6787123778283846| 12.133693494115981|    A|
|-429.80551004065086| -1.2928462486265264|-237.99349743855237| -0.5643467373641192|  16.67251969166381|    A|
| -1330.663464004668|  0.7637053829834995| -78.96444127988173| -1.2669440573794475|   3.51228852908923|    A|
|   847.29