In [None]:
from ucimlrepo import fetch_ucirepo

# Fetch dataset id 320 from the UCI ML Repository, then split it into the
# predictor table (X) and the target table (y).
student_performance = fetch_ucirepo(id=320)
X, y = (
    student_performance.data.features,
    student_performance.data.targets,
)

In [None]:
# Report the dimensions of the predictor and target tables.
feature_shape, target_shape = X.shape, y.shape
print(f"Features: {feature_shape}, Target: {target_shape}")

# Class balance of the binary outcome: G3 of 10 or more is a pass.
pass_fail_counts = y['G3'].ge(10).map({True: 'Pass', False: 'Fail'}).value_counts()
print(f"\nTarget distribution:\n{pass_fail_counts}")

Features shape: (649, 30)

Target shape: (649, 3)

First few rows of features:
  school sex  age address famsize Pstatus  Medu  Fedu     Mjob      Fjob  ...  \
0     GP   F   18       U     GT3       A     4     4  at_home   teacher  ...   
1     GP   F   17       U     GT3       T     1     1  at_home     other  ...   
2     GP   F   15       U     LE3       T     1     1  at_home     other  ...   
3     GP   F   15       U     GT3       T     4     2   health  services  ...   
4     GP   F   16       U     GT3       T     3     3    other     other  ...   

  higher internet  romantic  famrel  freetime goout Dalc Walc health absences  
0    yes       no        no       4         3     4    1    1      3        4  
1    yes      yes        no       5         3     3    1    1      3        2  
2    yes      yes        no       4         3     2    2    3      3        6  
3    yes      yes       yes       3         2     2    1    1      5        0  
4    yes       no        no       

## Feature Engineering

In [None]:
import numpy as np
import pandas as pd

# Tunable constants for this cell.
PASS_MARK = 10        # a final grade (G3) at or above this value is labelled 'Pass'
MISSING_RATE = 0.10   # share of rows whose absence level is overwritten with 'miss'

# Work on a copy so the raw feature table X is never modified and the cell
# can be re-run from scratch.
data = X.copy()
data['outcome'] = y['G3'].apply(lambda x: 'Pass' if x >= PASS_MARK else 'Fail')

# ORDINAL predictors: numeric codes are replaced by readable labels, and each
# *_ORDER list records the intended low-to-high ordering of those labels.
data['study_time'] = data['studytime'].map({1: '<2h', 2: '2-5h', 3: '5-10h', 4: '>10h'})
STUDY_TIME_ORDER = ['<2h', '2-5h', '5-10h', '>10h']

# Codes 1 and 2 are both labelled 'Primary', so five codes collapse to four levels.
data['mother_education'] = data['Medu'].map({0: 'None', 1: 'Primary', 2: 'Primary', 3: 'Secondary', 4: 'Higher'})
MOTHER_EDU_ORDER = ['None', 'Primary', 'Secondary', 'Higher']

# Counts of three or more are pooled into a single '3+' level.
data['failures_cat'] = data['failures'].apply(lambda x: str(x) if x < 3 else '3+')
FAILURES_ORDER = ['0', '1', '2', '3+']

# FLOATING predictor (ordinal + missing category).
# The top bin is open-ended: with a hard upper edge, any larger absence count
# would fall outside every bin, become NaN, and be stringified to 'nan' --
# a label that is not part of ABSENCE_ORDER.
np.random.seed(42)
absence_bins = pd.cut(
    data['absences'],
    bins=[-1, 0, 5, 10, np.inf],
    labels=['None', 'Low', 'Medium', 'High'],
)
data['absence_level'] = absence_bins.astype(str)
# Synthetic missingness: a seeded random subset of rows is relabelled 'miss'
# so the floating category has observations.
missing_mask = np.random.rand(len(data)) < MISSING_RATE
data.loc[missing_mask, 'absence_level'] = 'miss'
ABSENCE_ORDER = ['None', 'Low', 'Medium', 'High', 'miss']

# Fail fast if an engineered column contains a label (or NaN from an unmapped
# code) that its declared order does not list.
declared_orders = {
    'study_time': STUDY_TIME_ORDER,
    'mother_education': MOTHER_EDU_ORDER,
    'failures_cat': FAILURES_ORDER,
    'absence_level': ABSENCE_ORDER,
}
for column, allowed in declared_orders.items():
    unexpected = set(data[column].unique()) - set(allowed)
    assert not unexpected, f"{column} has labels outside its declared order: {unexpected}"

print(f"Dataset: {data.shape[0]} samples")
print(f"Outcome: {dict(data['outcome'].value_counts())}")

Feature Engineering Complete!

Dataset shape: (649, 35)

Outcome distribution:
outcome
Pass    549
Fail    100
Name: count, dtype: int64


--- Predictor Type Examples ---

1. ORDINAL (study_time) - ordered categories:
   Order: ['<2h', '2-5h', '5-10h', '>10h']
study_time
<2h      212
2-5h     305
5-10h     97
>10h      35
Name: count, dtype: int64

2. NOMINAL (higher) - unordered categories:
higher
yes    580
no      69
Name: count, dtype: int64

3. FLOATING (absence_level) - ordinal + floating 'miss' category:
   Order: ['None', 'Low', 'Medium', 'High', 'miss']
absence_level
None      210
Low       202
Medium    113
High       46
miss       78
Name: count, dtype: int64


In [None]:
import importlib  # not used in this cell; kept in case other cells rely on the name
import sys


def purge_cached_modules(package):
    """Remove ``package`` and its dotted submodules from ``sys.modules``.

    Only the exact package name and names beginning with ``package + '.'``
    are matched, so unrelated modules that merely share the leading
    characters (e.g. ``chaidlike`` for ``chaid``) are left alone.

    Args:
        package: Top-level package name, e.g. ``'chaid'``.

    Returns:
        List of the module names that were removed.
    """
    dotted_prefix = package + '.'
    # Snapshot the matching keys first: sys.modules must not change size
    # while it is being iterated.
    stale = [
        name for name in sys.modules
        if name == package or name.startswith(dotted_prefix)
    ]
    for name in stale:
        sys.modules.pop(name, None)
    return stale


# Evict any cached copy of the package so the `chaid` imports that follow
# re-execute the module code and pick up edits made since the last import.
purge_cached_modules('chaid')

from chaid import CHAIDTree, PredictorConfig, PredictorType
from chaid.visualization import (
    get_all_pairwise_at_step, 
    get_successive_merges_table,
    pairwise_chi_square_table,
    visualize_tree,
    get_predictor_summary_table
)

✓ CHAID module reloaded successfully!
  Available exports: CHAIDTree, PredictorConfig, PredictorType
  Visualization functions loaded


In [None]:
# Ordered predictors paired with the category order defined during feature
# engineering.
ordinal_orders = {
    'study_time': STUDY_TIME_ORDER,
    'mother_education': MOTHER_EDU_ORDER,
    'failures_cat': FAILURES_ORDER,
}

# One PredictorConfig per column. Insertion order matters here because the
# predictor list below is taken from the dict keys.
predictor_configs = {
    column: PredictorConfig(column, PredictorType.ORDINAL, category_order)
    for column, category_order in ordinal_orders.items()
}
# absence_level is ordered, with 'miss' passed as the extra (floating) category.
predictor_configs['absence_level'] = PredictorConfig(
    'absence_level', PredictorType.FLOATING, ABSENCE_ORDER, 'miss'
)
predictor_configs.update({
    column: PredictorConfig(column, PredictorType.NOMINAL)
    for column in ('sex', 'higher', 'internet')
})

# Model inputs: exactly the configured columns, plus the binary outcome.
predictors = list(predictor_configs)
X_chaid = data[predictors].copy()
y_chaid = data['outcome']

tree = CHAIDTree(
    alpha_merge=0.05,
    alpha_split=0.05,
    max_depth=3,
    min_parent_size=30,
    min_child_size=15,
)
tree.fit(X_chaid, y_chaid, predictor_types=predictor_configs)

# Quick structural summary of the fitted tree.
node_count = len(tree.nodes)
tree_depth = tree.get_depth()
leaf_count = len(tree.get_leaves())
print(f"Nodes: {node_count}, Depth: {tree_depth}, Leaves: {leaf_count}")
print(f"Root split: {tree.root.split_variable}")

Training CHAID tree with all predictor types...

✓ Tree trained successfully!
   Total nodes: 7
   Tree depth: 3
   Leaf nodes: 4

   Root split variable: failures_cat
   Root split groups: (frozenset({np.str_('0')}), frozenset({np.str_('2'), np.str_('3+'), np.str_('1')}))


## Tree Visualization

In [None]:
# Text rendering of the fitted tree.
tree_text = tree.print_tree()
print(tree_text)

In [None]:
# Draw the fitted tree, then write the figure to chaid_tree.png (the relative
# path resolves against the current working directory).
tree_figsize = (16, 10)
fig = visualize_tree(tree, method="plot", figsize=tree_figsize)

save_options = {'dpi': 150, 'bbox_inches': 'tight'}
fig.savefig('chaid_tree.png', **save_options)

## Pairwise Chi-Square Tables by Predictor Type

In [None]:
# ORDINAL example (failures_cat): only neighbouring categories are compared.
# Shows the initial pairwise table at node 0, then the successive-merges table.
predictor_name = 'failures_cat'
print("ORDINAL: failures_cat")
print("=" * 60)

initial_pairs = get_all_pairwise_at_step(tree, 0, predictor_name, step=0)
print(initial_pairs)

# The call returns a pair; only its second element (the table) is printed.
_, table = get_successive_merges_table(tree, 0, predictor_name)
print(table)

╔══════════════════════════════════════════════════════════════════════════════╗
║                       ORDINAL PREDICTOR: failures_cat                        ║
║       (Categories: 0 < 1 < 2 < 3+)  -  Only adjacent pairs considered        ║
╚══════════════════════════════════════════════════════════════════════════════╝

──────────────────────────────────────────────────────────────────────────────
TABLE 3 STYLE: Chi-squares and p-values by pair (Step 0 - Initial)
──────────────────────────────────────────────────────────────────────────────
Initial pairwise table (before any merging)

Groups:
  [1] 0
  [2] 1
  [3] 2
  [4] 3+

                 1           2           3           4
------------------------------------------------------
     1           —       70.94                        
     2      0.0000           —        0.10            
     3                  0.7565           —        0.62
     4                              0.4308           —
---------------------------------

In [None]:
# NOMINAL example (higher): every pair of categories is compared.
# Shows the initial pairwise table at node 0, then the successive-merges table.
predictor_name = 'higher'
print("NOMINAL: higher")
print("=" * 60)

initial_pairs = get_all_pairwise_at_step(tree, 0, predictor_name, step=0)
print(initial_pairs)

# The call returns a pair; only its second element (the table) is printed.
_, table = get_successive_merges_table(tree, 0, predictor_name)
print(table)

╔══════════════════════════════════════════════════════════════════════════════╗
║                          NOMINAL PREDICTOR: higher                           ║
║                (Categories: yes, no)  -  All pairs considered                ║
╚══════════════════════════════════════════════════════════════════════════════╝

──────────────────────────────────────────────────────────────────────────────
TABLE 3 STYLE: Chi-squares and p-values by pair (Step 0 - Initial)
──────────────────────────────────────────────────────────────────────────────
Initial pairwise table (before any merging)

Groups:
  [1] no
  [2] yes

                 1           2
------------------------------
     1           —       62.25
     2      0.0000           —
------------------------------
Upper triangle: χ² values | Lower triangle: p-values

→ Most similar pair: [1] and [2] (χ²=62.25, p=0.0000)

──────────────────────────────────────────────────────────────────────────────
TABLE 6 STYLE: Successive Merges
─

In [None]:
# FLOATING example (absence_level): neighbouring categories are compared, and
# the floating 'miss' category is compared with all of them.
# Shows the initial pairwise table at node 0, then the successive-merges table.
predictor_name = 'absence_level'
print("FLOATING: absence_level")
print("=" * 60)

initial_pairs = get_all_pairwise_at_step(tree, 0, predictor_name, step=0)
print(initial_pairs)

# The call returns a pair; only its second element (the table) is printed.
_, table = get_successive_merges_table(tree, 0, predictor_name)
print(table)

╔══════════════════════════════════════════════════════════════════════════════╗
║                      FLOATING PREDICTOR: absence_level                       ║
║           (Ordinal: None < Low < Medium < High) + Floating: 'miss'           ║
║             Adjacent pairs + 'miss' can merge with ANY category              ║
╚══════════════════════════════════════════════════════════════════════════════╝

──────────────────────────────────────────────────────────────────────────────
TABLE 9 STYLE: Chi-squares and p-values by pair (Step 0 - Initial)
Note: 'miss' column/row shows ALL pairwise comparisons (floating)
      Other pairs only show adjacent comparisons (ordinal)
──────────────────────────────────────────────────────────────────────────────
Initial pairwise table (before any merging)

Groups:
  [1] None
  [2] Low
  [3] Medium
  [4] High
  [5] miss

                 1           2           3           4           5
------------------------------------------------------------------


## Predictor Summary (Table 10)

In [None]:
# Summary table of the candidate predictors at node 0.
summary_table = get_predictor_summary_table(tree, node_id=0)
print(summary_table)

╔════════════════════════════════════════════════════════════════════════════════════════════════════╗
║                          TABLE 10: Summary of Possible First Level Splits                          ║
╠════════════════════════════════════════════════════════════════════════════════════════════════════╣
║ Predictor            │ Type     │  #cat │  #grp │     Chi-sq │  df │        p-value │ Selected ║
╠────────────────────────────────────────────────────────────────────────────────────────────────────╣
║ failures_cat         │ ORDINAL  │     4 │     2 │     102.34 │   1 │   0.0000000000 │        ★ ║
║ higher               │ NOMINAL  │     2 │     2 │      62.25 │   1 │   0.0000000000 │          ║
║ study_time           │ ORDINAL  │     4 │     3 │      19.25 │   2 │       0.000198 │          ║
║ mother_education     │ ORDINAL  │     4 │     2 │      10.09 │   1 │       0.004474 │          ║
║ internet             │ NOMINAL  │     2 │     2 │       5.05 │   1 │       0.024620 │      