In [3]:
import pandas as pd
import numpy as np
from mlxtend.frequent_patterns import apriori, fpgrowth, association_rules
from mlxtend.preprocessing import TransactionEncoder
from IPython.display import display, HTML
import os

# Load cleaned dataset
df = pd.read_csv("../outputs/cleaned_data.csv")


def _standardize(series):
    """Return *series* as z-scores against its own mean and sample std (ddof=1).

    A single-sample or all-NaN group has a NaN std, so its z-scores are NaN.
    """
    centered = series - series.mean()
    return centered / series.std()


# Z-score normalization per participant: every datasetId group is
# standardised against its own statistics, not the pooled dataset's.
for col in ['HR', 'RMSSD', 'LF_HF', 'sampen']:
    df[col + '_z'] = df.groupby('datasetId')[col].transform(_standardize)

# Discretize z-scores into bins
def z_to_bin(z):
    """Map a z-score to a categorical level.

    Returns None for a missing value (None/NaN), 'low' when z <= -1,
    'high' when z >= 1, and 'normal' for anything strictly in between.
    """
    if pd.isna(z):
        return None
    if z <= -1:
        return 'low'
    return 'high' if z >= 1 else 'normal'

# Turn each z-score column into a level ('low'/'normal'/'high') and a
# feature-prefixed item string such as 'HR_high'. A missing z-score yields
# None for both the bin and the item.
for col in ['HR', 'RMSSD', 'LF_HF', 'sampen']:
    _levels = df[col + '_z'].apply(z_to_bin)
    df[col + '_bin'] = _levels
    df[col + '_item'] = _levels.apply(lambda b: None if not b else f"{col}_{b}")

# Every row also carries its condition as an item, e.g. 'condition_<value>'.
df['condition_item'] = df['condition'].apply(lambda c: f"condition_{c}")

# Create transactions
features = ['HR', 'RMSSD', 'LF_HF', 'sampen']


def build_transaction(row):
    """Collect one row's items into a transaction list.

    For each name in ``features`` the value of ``'<feature>_item'`` is kept
    unless it is missing (None/NaN). The row's ``'condition_item'`` is always
    appended last, without a missing-value check.
    """
    present = []
    for feat in features:
        item = row[f"{feat}_item"]
        if pd.notna(item):
            present.append(item)
    return present + [row['condition_item']]

# One transaction per row. Because the condition item is always present,
# requiring at least two items keeps only rows with at least one
# physiological item.
transactions = [
    items
    for items in df.apply(build_transaction, axis=1).tolist()
    if len(items) > 1
]

# One-hot encode transactions: one column per distinct item (te.columns_)
te = TransactionEncoder()
te.fit(transactions)
te_ary = te.transform(transactions)
df_trans = pd.DataFrame(te_ary, columns=te.columns_)

# Rule-mining thresholds shared by Apriori and FP-Growth so the two result
# tables are directly comparable.
MIN_SUPPORT = 0.03
MIN_CONFIDENCE = 0.6
MIN_LIFT = 1.0


def _filter_rules(freq):
    """Derive rules from frequent itemsets *freq*, keeping confidence >= MIN_CONFIDENCE and lift > MIN_LIFT."""
    mined = association_rules(freq, metric='confidence', min_threshold=MIN_CONFIDENCE)
    return mined[mined['lift'] > MIN_LIFT]


def _simplify_rules(mined):
    """Return a display/CSV-friendly copy of *mined*.

    Antecedents/consequents (frozensets) become comma-separated strings, and
    rows are ordered by lift, then confidence, both descending.
    """
    simple = mined[['antecedents', 'consequents', 'support', 'confidence', 'lift']].copy()
    for side in ('antecedents', 'consequents'):
        # frozenset iteration order follows per-process string hashing
        # (PYTHONHASHSEED), so sort the items to get a reproducible label.
        simple[side] = simple[side].apply(lambda items: ', '.join(sorted(items)))
    return simple.sort_values(['lift', 'confidence'], ascending=False).reset_index(drop=True)


# Apriori Algorithm
freq_items = apriori(df_trans, min_support=MIN_SUPPORT, use_colnames=True)
rules = _filter_rules(freq_items)
rules_simple = _simplify_rules(rules)

# FP-Growth Algorithm
freq_items_fp = fpgrowth(df_trans, min_support=MIN_SUPPORT, use_colnames=True)
rules_fp = _filter_rules(freq_items_fp)
rules_fp_simple = _simplify_rules(rules_fp)

# Display top 5 rules (tabular): one centred HTML table per algorithm,
# without the index column.
for _heading, _rules_df in (("<h4>Top 5 Apriori Rules</h4>", rules_simple),
                            ("<h4>Top 5 FP-Growth Rules</h4>", rules_fp_simple)):
    _centered = [
        {'selector': 'th', 'props': [('text-align', 'center')]},
        {'selector': 'td', 'props': [('text-align', 'center')]},
    ]
    display(HTML(_heading))
    display(_rules_df.head(5).style.hide(axis='index').set_table_styles(_centered))

# Save outputs: both rule tables as CSV (no index) under ../outputs/rules,
# creating the directory if it does not exist yet.
os.makedirs("../outputs/rules", exist_ok=True)
for _table, _csv_path in ((rules_simple, "../outputs/rules/apriori_rules.csv"),
                          (rules_fp_simple, "../outputs/rules/fpgrowth_rules.csv")):
    _table.to_csv(_csv_path, index=False)


antecedents,consequents,support,confidence,lift
"RMSSD_high, HR_low","sampen_normal, LF_HF_high",0.043381,0.60245,6.518768
"sampen_normal, RMSSD_high, HR_low",LF_HF_high,0.043381,0.617602,6.492516
"RMSSD_high, HR_low",LF_HF_high,0.043974,0.610687,6.419821
"LF_HF_high, RMSSD_high","sampen_normal, HR_low",0.043381,0.986512,5.971974
"LF_HF_high, RMSSD_high",HR_low,0.043974,1.0,5.803306


antecedents,consequents,support,confidence,lift
"RMSSD_high, HR_low","sampen_normal, LF_HF_high",0.043381,0.60245,6.518768
"sampen_normal, RMSSD_high, HR_low",LF_HF_high,0.043381,0.617602,6.492516
"RMSSD_high, HR_low",LF_HF_high,0.043974,0.610687,6.419821
"LF_HF_high, RMSSD_high","sampen_normal, HR_low",0.043381,0.986512,5.971974
"LF_HF_high, RMSSD_high",HR_low,0.043974,1.0,5.803306
