In [3]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.io import arff

# Paths (relative to the notebook's working directory)
ADULT_ARFF, CHILD_ARFF = (
    "../data/Autism-Adult-Data.arff",
    "../data/Autism-Child-Data.arff",
)
OUT_DIR = "../results/FIGURES_METHOD"

# Make sure the figure directory exists; an existing directory is not an error.
os.makedirs(name=OUT_DIR, exist_ok=True)


def load_arff_to_df(path: str) -> pd.DataFrame:
    """Load the ARFF file at *path* into a DataFrame with text decoded to ``str``.

    ``scipy.io.arff.loadarff`` yields nominal/string attributes as ``bytes``.
    Every ``bytes``/``bytearray`` value in an object-dtype column is decoded as
    UTF-8 so later string comparisons (e.g. against ``"?"``) work; all other
    values are left untouched.
    """
    data, _meta = arff.loadarff(path)  # ARFF metadata is not needed here
    df = pd.DataFrame(data)

    def _decode(value):
        # Only raw byte strings are decoded; NaN and numbers pass through.
        return value.decode("utf-8") if isinstance(value, (bytes, bytearray)) else value

    # Decode byte strings (ARFF loads categoricals as bytes)
    for col in df.columns:
        if df[col].dtype == object:
            df[col] = df[col].map(_decode)

    return df

def standardise_missing(df: pd.DataFrame) -> pd.DataFrame:
    """Return a new DataFrame where the placeholder strings ``"?"`` and ``""``
    (exact matches) are replaced by ``np.nan``; *df* itself is not modified.
    """
    placeholder_values = ["?", ""]
    return df.replace(placeholder_values, np.nan)

def find_target_column(df: pd.DataFrame) -> str:
    """Return the name of the label column of *df*.

    The first name from a fixed list of known target names that appears in
    ``df.columns`` wins; if none appears, the last column is assumed to be
    the target.
    """
    known_targets = ("Class/ASD", "class", "Class", "asd", "ASD", "target")
    present = [name for name in known_targets if name in df.columns]
    return present[0] if present else df.columns[-1]


# Load datasets: read each ARFF file, normalise missing markers,
# then detect each dataset's target column.
adult_df, child_df = (
    standardise_missing(load_arff_to_df(arff_path))
    for arff_path in (ADULT_ARFF, CHILD_ARFF)
)
adult_target, child_target = (
    find_target_column(frame) for frame in (adult_df, child_df)
)


# Missing value heatmap (Adult + Child)

def plot_missing_heatmap(
    df: pd.DataFrame,
    title: str,
    outpath: str,
    max_rows: int = 250,
    random_state: int = 42,
):
    """Save a heatmap of ``df.isna()`` to *outpath* (PNG, 300 dpi).

    Each image row is a row of *df* and each image column a feature.

    Parameters
    ----------
    df : DataFrame whose missing cells are visualised.
    title : Figure title.
    outpath : Destination image path.
    max_rows : When *df* has more rows than this, a random sample of
        ``max_rows`` rows is plotted instead, for readability.
    random_state : Seed for that sample so the figure is reproducible.
    """
    sampled = len(df) > max_rows
    if sampled:
        # Sample for readability, then restore the original row order so
        # consecutive missing values stay contiguous in the image.
        plot_df = df.sample(n=max_rows, random_state=random_state).sort_index()
    else:
        plot_df = df  # only read below, so no copy is needed

    missing_matrix = plot_df.isna()

    fig, ax = plt.subplots(figsize=(12, 4))
    try:
        sns.heatmap(missing_matrix, cbar=False, ax=ax)
        ax.set_title(title)
        ax.set_xlabel("Features")
        ax.set_ylabel("Rows (sampled)" if sampled else "Rows")
        fig.tight_layout()
        fig.savefig(outpath, dpi=300)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

# One missing-value heatmap per dataset.
plot_missing_heatmap(
    df=adult_df,
    title="Missing Value Pattern (Adult Dataset)",
    outpath=os.path.join(OUT_DIR, "Missing_heatmap_adult.png"),
)
plot_missing_heatmap(
    df=child_df,
    title="Missing Value Pattern (Child Dataset)",
    outpath=os.path.join(OUT_DIR, "Missing_heatmap_child.png"),
)


def plot_missing_heatmap_combined(
    adult_df: pd.DataFrame,
    child_df: pd.DataFrame,
    outpath: str,
    max_rows: int = 200,
    random_state: int = 42,
):
    """Save Adult and Child missing-value heatmaps stacked in one figure.

    The two panels share the x axis (``sharex=True``), so tick labels from one
    panel are shown for both. To keep the labels correct, both missingness
    matrices are put into one common column order first (Adult order, then
    any Child-only columns). A column present in only one dataset is drawn
    as entirely missing in the other.

    Parameters
    ----------
    adult_df, child_df : DataFrames whose missing cells are visualised.
    outpath : Destination image path (saved at 300 dpi).
    max_rows : Per-dataset row cap; larger datasets are randomly sampled.
    random_state : Seed for that sample so the figure is reproducible.
    """
    def prep(df):
        # Return (missingness matrix, whether rows were sampled).
        if len(df) > max_rows:
            # Sort the sample back into original row order for readability.
            sample = df.sample(n=max_rows, random_state=random_state).sort_index()
            return sample.isna(), True
        return df.isna(), False

    adult_m, adult_sampled = prep(adult_df)
    child_m, child_sampled = prep(child_df)

    # Common column order so the shared x tick labels match both panels.
    columns = list(adult_m.columns) + [
        c for c in child_m.columns if c not in adult_m.columns
    ]
    adult_m = adult_m.reindex(columns=columns, fill_value=True)
    child_m = child_m.reindex(columns=columns, fill_value=True)

    fig, axes = plt.subplots(2, 1, figsize=(12, 7), sharex=True)
    try:
        panels = (
            (axes[0], adult_m, adult_sampled, "Adult"),
            (axes[1], child_m, child_sampled, "Child"),
        )
        for ax, matrix, sampled, label in panels:
            sns.heatmap(matrix, cbar=False, ax=ax)
            ax.set_title(f"Missing Value Pattern ({label} Dataset)")
            ax.set_ylabel("Rows (sampled)" if sampled else "Rows")
        axes[1].set_xlabel("Features")

        fig.tight_layout()
        fig.savefig(outpath, dpi=300)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

# Both datasets' missing-value heatmaps in a single stacked figure.
plot_missing_heatmap_combined(
    adult_df=adult_df,
    child_df=child_df,
    outpath=os.path.join(OUT_DIR, "Missing_heatmap_adult_child.png"),
)


# Class distribution plot (Adult + Child)

def plot_class_distribution(df: pd.DataFrame, target_col: str, title: str, outpath: str):
    """Save a bar chart of the value counts of ``df[target_col]`` to *outpath*.

    Missing target values are counted too (``dropna=False``), and bars appear
    in ``value_counts`` order. The image is written at 300 dpi.
    """
    class_counts = df[target_col].value_counts(dropna=False)

    fig, ax = plt.subplots(figsize=(6, 4))
    class_counts.plot(kind="bar", ax=ax)
    ax.set(title=title, xlabel="Class", ylabel="Count")
    fig.tight_layout()
    fig.savefig(outpath, dpi=300)
    plt.close(fig)

# One class-distribution bar chart per dataset.
plot_class_distribution(
    df=adult_df,
    target_col=adult_target,
    title=f"Class Distribution (Adult Dataset) [{adult_target}]",
    outpath=os.path.join(OUT_DIR, "Class_distribution_adult.png"),
)
plot_class_distribution(
    df=child_df,
    target_col=child_target,
    title=f"Class Distribution (Child Dataset) [{child_target}]",
    outpath=os.path.join(OUT_DIR, "Class_distribution_child.png"),
)

# Combined class distribution: Adult vs Child on one figure
def plot_class_distribution_combined(
    adult_df: pd.DataFrame,
    adult_target: str,
    child_df: pd.DataFrame,
    child_target: str,
    outpath: str,
):
    """Save a grouped bar chart comparing Adult and Child class counts.

    Counts use ``dropna=False``, as ``plot_class_distribution`` does, so
    missing target values are counted rather than silently dropped. A class
    that occurs in only one dataset gets a count of 0 in the other. The image
    is written to *outpath* at 300 dpi.
    """
    # Outer-align both count Series on class label (NaN labels included).
    counts = pd.concat(
        {
            "Adult": adult_df[adult_target].value_counts(dropna=False),
            "Child": child_df[child_target].value_counts(dropna=False),
        },
        axis=1,
    )
    counts = counts.fillna(0).astype(int)

    # Natural label order when the labels are comparable; otherwise (mixed
    # label types) fall back to ordering by their string form.
    try:
        counts = counts.sort_index()
    except TypeError:
        counts = counts.sort_index(key=lambda idx: idx.astype(str))

    x = np.arange(len(counts))
    width = 0.35

    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        ax.bar(x - width / 2, counts["Adult"].to_numpy(), width, label="Adult")
        ax.bar(x + width / 2, counts["Child"].to_numpy(), width, label="Child")

        ax.set_title("Class Distribution (Adult vs Child)")
        ax.set_xticks(x)
        ax.set_xticklabels([str(c) for c in counts.index])
        ax.set_xlabel("Class")
        ax.set_ylabel("Count")
        ax.legend()

        fig.tight_layout()
        fig.savefig(outpath, dpi=300)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

# Adult vs Child class counts side by side, then report where figures went.
plot_class_distribution_combined(
    adult_df=adult_df,
    adult_target=adult_target,
    child_df=child_df,
    child_target=child_target,
    outpath=os.path.join(OUT_DIR, "Class_distribution_adult_child.png"),
)

print(f"Saved figures to: {OUT_DIR}")


Saved figures to: ../results/FIGURES_METHOD
