

# Lab: Implementing Information Gain for a Decision Tree in PySpark

In this lab, you will:

1. Use PySpark DataFrames to load and preprocess a dataset.
2. Implement **entropy** and **information gain** from scratch.
3. Use your information gain function to choose the **best split** (a depth-1 decision tree, also called a decision stump).
4. Evaluate your decision stump on a test set.

We will use a simplified version of the **Titanic** dataset with three features: `Sex`, `Pclass`, and `Embarked`.


In [12]:

# Part 0 – Setup

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("DecisionTreeInfoGain").getOrCreate()

print("Spark session created:", spark)


Spark session created: <pyspark.sql.session.SparkSession object at 0x7f011459fd70>



## Part 1 – Load and Simplify the Dataset

We will keep only a few **categorical** or easy-to-discretize features to make the math cleaner:

- Label: `Survived` (0/1)
- Features: `Sex`, `Pclass`, `Embarked`

**TODO:**

1. Update the path to point to your local copy of the Titanic CSV file.
2. Run the code to load and inspect the data.


In [13]:
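titanic_path = "/content/titanic.csv"  # update to the location of your copy

# Read the CSV with a header row and let Spark infer the column types.
df = spark.read.csv(titanic_path, header=True, inferSchema=True)

# Keep only the label and the three features used in this lab.
cols = ["Survived", "Sex", "Pclass", "Embarked"]
data = df.select(*cols)

# Drop rows with a missing value in any of the selected columns.
data = data.na.drop(subset=cols)

# Store the target as an integer `label` column and drop the original.
data = data.withColumn("label", col("Survived").cast("int")).drop("Survived")

data.show(5)
data.printSchema()
print("Total rows:", data.count())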

+------+------+--------+-----+
|   Sex|Pclass|Embarked|label|
+------+------+--------+-----+
|  male|     3|       S|    0|
|female|     1|       C|    1|
|female|     3|       S|    1|
|female|     1|       S|    1|
|  male|     3|       S|    0|
+------+------+--------+-----+
only showing top 5 rows

root
 |-- Sex: string (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Embarked: string (nullable = true)
 |-- label: integer (nullable = true)

Total rows: 889



## Part 2 – Train/Test Split

We split the data into training and test sets so that we can evaluate how well our decision stump performs.

**TODO:** Run the code below and record the number of rows in each split.


In [14]:

train_df, test_df = data.randomSplit([0.8, 0.2], seed=42)
print("Train rows:", train_df.count(), "Test rows:", test_df.count())


Train rows: 744 Test rows: 145
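
The 744/145 split is not exactly 80/20 of 889 rows, which would be about 711/178. One plausible explanation, assuming `randomSplit` assigns each row to a split independently at random rather than cutting at a fixed index, is ordinary sampling variation. The pure-Python sketch below imitates that behaviour. It does not reproduce Spark's random number stream, and `simulate_random_split` is a hypothetical helper, not a PySpark function.

```python
import random

def simulate_random_split(n_rows, train_weight, seed):
    """Assign each of n_rows rows to train with probability train_weight."""
    rng = random.Random(seed)
    n_train = sum(1 for _ in range(n_rows) if rng.random() < train_weight)
    return n_train, n_rows - n_train

# 889 rows and a 0.8 training weight, as in the lab; the seeds are arbitrary.
splits = [simulate_random_split(889, 0.8, seed) for seed in range(5)]
for seed, (n_train, n_test) in enumerate(splits):
    print(f"seed={seed}: train={n_train}, test={n_test}")
```

The two counts always sum to 889, but the training count moves around its expected value from seed to seed, which is why the recorded split sizes should be treated as approximate.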



## Part 3 – Entropy and Information Gain

Recall:

Let \(Y\) be the label (e.g., `Survived` = 0 or 1).

- **Entropy of Y**:

\[
H(Y) = -\sum_y p(y) \log_2 p(y)
\]

If we split on a feature \(X\) (for example, `Sex` with values {male, female}):

- **Conditional entropy**:

\[
H(Y \mid X) = \sum_x p(x)\, H(Y \mid X = x)
\]

- **Information Gain** of splitting on \(X\):

\[
IG(Y, X) = H(Y) - H(Y \mid X)
\]

You will implement:

1. A function to compute entropy from label counts.
2. A function to compute information gain for a given feature.


In [15]:

from pyspark.sql.functions import count

def get_label_counts(df, label_col="label"):
    """
    Returns a dict: {label_value: count} for the given label column.
    """
    counts = (
        df.groupBy(label_col)
          .agg(count("*").alias("cnt"))
          .collect()
    )
    return {row[label_col]: row["cnt"] for row in counts}

# Example (you can test this):
example_counts = get_label_counts(train_df, label_col="label")
print("Example label counts on training set:", example_counts)


Example label counts on training set: {1: 283, 0: 461}



### 3.1 TODO – Implement Entropy

Implement the function `entropy_from_counts(label_count_dict)` that takes a dictionary of label counts
(e.g., `{0: 100, 1: 50}`) and returns the entropy in **bits**.

Steps:

1. Compute the total number of examples.
2. For each label value, compute \(p = \text{count} / \text{total}\).
3. Accumulate \(-p \log_2 p\) across all labels (ignoring cases where \(p = 0\)).
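
As a worked check of these steps, take the example counts `{0: 100, 1: 50}`. The total is 150, so \(p(0) = 2/3\) and \(p(1) = 1/3\), and the entropy is:

\[
H(Y) = -\tfrac{2}{3}\log_2\tfrac{2}{3} - \tfrac{1}{3}\log_2\tfrac{1}{3} = \log_2 3 - \tfrac{2}{3} \approx 0.918 \text{ bits}
\]

An implementation of `entropy_from_counts` should return roughly this value for that input, and exactly 1.0 for a balanced two-class input.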


In [16]:

import math

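def entropy_from_counts(label_count_dict):
    """
    Returns the entropy (in bits) of the distribution given by {label: count}.
    """
    total = sum(label_count_dict.values())
    if total == 0:
        return 0.0
    entropy = 0.0
    # Named `cnt` rather than `count` to avoid confusion with pyspark's count().
    for cnt in label_count_dict.values():
        if cnt > 0:  # skip empty classes: p * log2(p) tends to 0 as p tends to 0
            p = cnt / total
            entropy -= p * math.log2(p)
    return entropy

# For a perfectly balanced binary distribution, entropy should be 1.0 bit.
test_counts = {0: 1, 1: 1}
print("Entropy of balanced 2-class distribution:", entropy_from_counts(test_counts))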




Entropy of balanced 2-class distribution: 1.0



### 3.2 TODO – Implement Information Gain for a Feature

Implement `information_gain(df, feature_col, label_col="label")` to compute:

\[
IG(Y, X) = H(Y) - H(Y \mid X)
\]

where:

- \(Y\) is the label (column `label_col`),
- \(X\) is a feature (column `feature_col`).

Steps:

1. Compute the **base entropy** \(H(Y)\) using the entire DataFrame.
2. For each distinct value \(v\) of feature \(X\):
   - Filter the DataFrame to `X == v`.
   - Compute label counts and entropy \(H(Y \mid X = v)\).
   - Weight by \(P(X = v) = \text{count}(X=v) / N\), where \(N\) is the total number of rows (`total_rows` in the code).
3. Combine to obtain \(H(Y \mid X)\), then return \(IG = H(Y) - H(Y \mid X)\).
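
The three steps can be sketched without Spark once the per-value label counts are known. The block below is a minimal pure-Python illustration, not the lab's PySpark solution. `entropy_bits` and `info_gain_from_counts` are hypothetical helper names, and the counts are the training-set `Sex`/`label` counts printed in Part 5.

```python
import math

def entropy_bits(counts):
    """Entropy in bits of a list of class counts."""
    total = sum(counts)
    return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)

def info_gain_from_counts(counts_by_value):
    """
    counts_by_value maps each feature value to {label: count}.
    Returns H(Y) - H(Y | X).
    """
    # Step 1: base entropy from label counts summed over all feature values.
    label_totals = {}
    for label_counts in counts_by_value.values():
        for label, c in label_counts.items():
            label_totals[label] = label_totals.get(label, 0) + c
    total_rows = sum(label_totals.values())
    base_entropy = entropy_bits(list(label_totals.values()))

    # Step 2: weight each subset's entropy by P(X = v).
    conditional_entropy = 0.0
    for label_counts in counts_by_value.values():
        subset_size = sum(label_counts.values())
        conditional_entropy += (subset_size / total_rows) * entropy_bits(
            list(label_counts.values())
        )

    # Step 3: information gain.
    return base_entropy - conditional_entropy

# Training-set counts for Sex, copied from the Part 5 table.
sex_counts = {
    "female": {1: 194, 0: 70},
    "male": {0: 391, 1: 89},
}
ig_sex = info_gain_from_counts(sex_counts)
print(f"IG for Sex from counts: {ig_sex:.4f}")
```

Because it uses the same counts as the training split, the result should agree with the value the Spark implementation reports for `Sex` (about 0.2159).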


In [17]:

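def information_gain(df, feature_col, label_col="label"):
    """
    Returns IG(Y, X) = H(Y) - H(Y | X) in bits, where Y is `label_col`
    and X is `feature_col`. Each distinct feature value triggers its own
    Spark jobs, so this is intended for low-cardinality features.
    """
    total_rows = df.count()
    if total_rows == 0:
        return 0.0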

    # Base entropy H(Y)
    base_counts = get_label_counts(df, label_col)
    base_entropy = entropy_from_counts(base_counts)

    # Conditional entropy H(Y|X)
    conditional_entropy = 0.0
    distinct_values = [row[feature_col] for row in df.select(feature_col).distinct().collect()]

    for value in distinct_values:
        subset = df.filter(col(feature_col) == value)
        subset_size = subset.count()
        if subset_size > 0:
            weight = subset_size / total_rows
            subset_counts = get_label_counts(subset, label_col)
            subset_entropy = entropy_from_counts(subset_counts)
            conditional_entropy += weight * subset_entropy

    # Information Gain
    return base_entropy - conditional_entropy

# Quick test
print("IG for Sex:", information_gain(train_df, "Sex"))




IG for Sex: 0.2159099966074487



## Part 4 – Choose the Best Split (Decision Stump)

Now use your `information_gain` function to compute the information gain of each candidate feature,
and select the one with the highest information gain.

Candidate features:

- `Sex`
- `Pclass`
- `Embarked`



In [18]:

candidate_features = ["Sex", "Pclass", "Embarked"]

best_feature = None
best_ig = float("-inf")

for feat in candidate_features:
    ig = information_gain(train_df, feature_col=feat, label_col="label")
    print(f"Feature: {feat}, Information Gain: {ig}")
    if ig > best_ig:
        best_ig = ig
        best_feature = feat

print("Best feature to split on:", best_feature, "with IG =", best_ig)


Feature: Sex, Information Gain: 0.2159099966074487
Feature: Pclass, Information Gain: 0.08094964875352906
Feature: Embarked, Information Gain: 0.017456860135772523
Best feature to split on: Sex with IG = 0.2159099966074487
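
The selection loop is an arg-max over features. As a small sketch of the same idea on the values printed above, the block below uses Python's built-in `max` and `sorted`. The dictionary `ig_by_feature` is a hypothetical name holding numbers copied from the output, not something the lab code creates.

```python
# Information gain values copied from the output above.
ig_by_feature = {
    "Sex": 0.2159099966074487,
    "Pclass": 0.08094964875352906,
    "Embarked": 0.017456860135772523,
}

# max() with a key function returns the feature with the largest gain;
# on an exact tie it returns the first such key in insertion order.
best = max(ig_by_feature, key=ig_by_feature.get)
ranked = sorted(ig_by_feature.items(), key=lambda kv: kv[1], reverse=True)

print("Best feature:", best)
print("Ranking:", [name for name, _ in ranked])
```

Keeping the full ranking, rather than only the winner, makes it easier to see how far ahead the best feature is.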



## Part 5 – Build a Tiny Decision Tree (Decision Stump)

We will build a depth-1 decision tree that:

- Splits on the **best feature** you found.
- For each value of that feature, predicts the **majority label** among training examples with that value.

### 5.1 Compute Majority Label per Feature Value

Build a dictionary `majority_map` that maps each value of the best feature to its majority label (`feature_value -> majority_label`).


In [19]:

from pyspark.sql.functions import desc

print("Using best feature:", best_feature)

# Group by the best feature and label, count how many of each
value_majorities = (
    train_df.groupBy(best_feature, "label")
            .agg(count("*").alias("cnt"))
            .orderBy(best_feature, desc("cnt"))
)

value_majorities.show()

# Build majority_map such that for each distinct value v of best_feature,
# majority_map[v] = majority label (0 or 1) that appears most in the training set.
majority_map = {}

rows = value_majorities.collect()
for row in rows:
    v = row[best_feature]
    lbl = row["label"]
    # The first time we see a feature value v, it will correspond to the
    # label with the highest count (because of the descending sort).
    if v not in majority_map:
        majority_map[v] = lbl

print("Majority map:", majority_map)


Using best feature: Sex
+------+-----+---+
|   Sex|label|cnt|
+------+-----+---+
|female|    1|194|
|female|    0| 70|
|  male|    0|391|
|  male|    1| 89|
+------+-----+---+

Majority map: {'female': 1, 'male': 0}
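
The loop above depends on the rows arriving sorted by descending count. A variant that does not rely on row order keeps the largest count seen so far for each feature value. This is a pure-Python sketch on the counts copied from the table, and `best_seen` and `majority_map_check` are hypothetical names.

```python
# (feature value, label, count) triples copied from the table above.
rows = [
    ("female", 1, 194),
    ("female", 0, 70),
    ("male", 0, 391),
    ("male", 1, 89),
]

best_seen = {}  # feature value -> (label, count) with the largest count so far
for value, label, cnt in rows:
    if value not in best_seen or cnt > best_seen[value][1]:
        best_seen[value] = (label, cnt)

majority_map_check = {value: label for value, (label, _) in best_seen.items()}
print(majority_map_check)
```

With a strict `>` comparison, a tie keeps whichever label was seen first, so ties are still broken arbitrarily, as they are in the sorted version.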



### 5.2 Apply the Decision Stump to the Test Set

We now define a simple prediction function using `majority_map` and apply it to the test set.


In [20]:

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def stump_predict(value):
    """
    Predict the label based on the majority_map for the given feature value.
    If the feature value has not been seen in training, default to 0.
    """
    return int(majority_map.get(value, 0))

stump_udf = udf(stump_predict, IntegerType())

pred_test = test_df.withColumn(
    "prediction",
    stump_udf(col(best_feature))
)

pred_test.select(best_feature, "label", "prediction").show(20, truncate=False)


+------+-----+----------+
|Sex   |label|prediction|
+------+-----+----------+
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|0    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
|female|1    |1         |
+------+-----+----------+
only showing top 20 rows




## Part 6 – Evaluate the Accuracy of Your Decision Stump

Compute the accuracy on the test set by comparing `label` with `prediction`.


In [21]:

correct = pred_test.filter(col("label") == col("prediction")).count()
total = pred_test.count()
accuracy = correct / total if total > 0 else 0.0

print(f"Decision stump using feature '{best_feature}'")
print(f"Accuracy on test set: {accuracy:.3f}")


Decision stump using feature 'Sex'
Accuracy on test set: 0.786
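
To put the test accuracy in context, the training-set counts from Part 5 are enough to compute two reference figures: the stump's accuracy on the training set and a majority-class baseline. The block below is a pure-Python sketch on numbers copied from that table. It says nothing about the test set's own class balance, which the lab does not print.

```python
# Training-set counts from the Part 5 table: (feature value, label) -> count.
train_counts = {
    ("female", 1): 194,
    ("female", 0): 70,
    ("male", 0): 391,
    ("male", 1): 89,
}
stump = {"female": 1, "male": 0}  # the majority map from Part 5

total = sum(train_counts.values())
correct = sum(c for (value, label), c in train_counts.items() if stump[value] == label)
train_accuracy = correct / total

# Baseline: always predict the most common training label.
label_totals = {}
for (_, label), c in train_counts.items():
    label_totals[label] = label_totals.get(label, 0) + c
baseline_accuracy = max(label_totals.values()) / total

print(f"Stump training accuracy: {train_accuracy:.3f}")
print(f"Majority-class baseline (training): {baseline_accuracy:.3f}")
```

A training accuracy close to the test accuracy suggests the stump is not overfitting, which is expected for a model with a single split.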



## Reflection Questions

Answer these questions in a markdown cell or in your lab report:

1. Which feature had the highest information gain? Does this match your intuition about what mattered most for survival on the Titanic?

**Answer:**

  - Sex (Information Gain ≈ 0.216 bits on the training set).  
  - Yes, this matches intuition: on the Titanic, women had a dramatically higher survival rate than men ("women and children first").


2. What accuracy did your decision stump achieve on the test set?

**Answer:**

  - **0.786** (78.6%) on this split (`seed=42`); a different random split would give a somewhat different value.  
   This is a strong result for a single split on one feature.


3. How do you think using a **deeper decision tree** (with more levels) would affect training accuracy and the risk of overfitting?

**Answer:**

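  - Training accuracy: Would increase or at least not decrease, because each additional split lets the tree fit the training data more closely. With only three categorical features, rows that share the same `Sex`, `Pclass`, and `Embarked` values but differ in label cannot be separated, so 100% is reachable only if no such conflicts exist.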

  - Risk of overfitting: Would increase significantly. A very deep tree may memorize noise and perform worse on unseen data.


4. What limitations do you see in using a single split (decision stump) for this problem?

**Answer:**


  - Uses only one feature (here, only `Sex`), so every other feature is ignored.  
  - Cannot model combined effects (e.g., "1st-class women survived almost always, but 3rd-class men almost never").  
  - Lower expressive power than full trees → lower accuracy ceiling.
