# Dementia Classification (Audio + ASR Text)

## Abstract
*(Write one paragraph: problem, data source, methods, and the insights you plan to extract.)*



## Introduction

## Problem Addressed

## Motivation

## Previous Work

## Dataset + EDA

## Project Schedule and Budget

## Technical Approach

## Main Results

## Explainability + Robustness

## Discussion

## Future Work



In [None]:
# Keep code minimal in the notebook; import from dementia_project/ modules.
import json
from pathlib import Path

import pandas as pd


def load_metrics(run_dir: str) -> dict:
    return json.loads(Path(run_dir, "metrics.json").read_text())


runs = {
    "nonml_scaled": "runs/nonml_baseline_scaled",
    "wav2vec2_full_cuda": "runs/wav2vec2_baseline_full_cuda",
    "densenet_full_cuda": "runs/densenet_spec_full_cuda",
}

rows = []
for name, rdir in runs.items():
    m = load_metrics(rdir)
    for split in ["train", "valid", "test"]:
        rows.append(
            {
                "model": name,
                "split": split,
                "accuracy": m[split].get("accuracy"),
                "f1": m[split].get("f1"),
                "roc_auc": m[split].get("roc_auc"),
            }
        )

df_results = pd.DataFrame(rows)
df_results



## Step-by-step: What code runs (module-by-module)

This section is a **walkthrough of every Python module** in `dementia_project/`, in the order we run them.

### 0) Project entrypoints (where things live)
- Code package: `dementia_project/`
- Config: `configs/default.yaml`
- Processed artifacts: `data/processed/`
- Experiment outputs: `runs/`



### 1) Build metadata (audio inventory + join to CSV)
**Module**: `dementia_project/data/build_metadata.py`

**What it does**
- Scans both class folders for `.wav`
- Computes audio duration/sample rate
- Joins dementia-side subjects to `DementiaNet - dementia.csv`
- Assigns control subjects from folder names

**Produces**
- `data/processed/metadata.csv`
- `data/processed/dropped.csv`
- `data/processed/metadata_report.json`

**Command**
```bash
poetry run python -m dementia_project.data.build_metadata \
  --dementia_dir "dementia-20251217T041331Z-1-001" \
  --control_dir "nodementia-20251217T041501Z-1-001" \
  --dementia_csv "DementiaNet - dementia.csv" \
  --out_dir "data/processed"
```

**Helper used**
- `dementia_project/data/name_normalize.py`: `normalize_person_name()` used for robust matching.
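
The matching step depends on folder names and CSV names normalizing to the same key. As a rough illustration only, the sketch below shows one way such a normalizer could work; `normalize_person_name_sketch` is a hypothetical stand-in, and the actual rules in `normalize_person_name()` may differ.

```python
import re
import unicodedata


def normalize_person_name_sketch(name: str) -> str:
    """Hypothetical stand-in for normalize_person_name(); the real rules may differ."""
    # Decompose accented characters and drop the combining marks (e.g. "é" -> "e").
    decomposed = unicodedata.normalize("NFKD", name)
    ascii_like = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    # Lowercase, then turn every run of non-alphanumeric characters into one space.
    cleaned = re.sub(r"[^a-z0-9]+", " ", ascii_like.lower())
    return cleaned.strip()


print(normalize_person_name_sketch("  José  O'Neil-Smith "))  # jose o neil smith
```

With a rule like this, a folder named `Robin_Williams` and a CSV entry `robin williams` map to the same key.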



### 2) Build splits (subject-level train/valid/test)
**Modules**
- `dementia_project/data/splitting.py`: implements the hybrid split logic.
- `dementia_project/data/build_splits.py`: CLI wrapper that writes outputs.

**What it does**
- Creates `train/valid/test` splits
- Enforces **subject-level separation** using `person_name_norm`
- Uses CSV `datasplit` when available; otherwise assigns deterministically

**Produces**
- `data/processed/splits.csv`
- `data/processed/splits_report.json`

**Command**
```bash
poetry run python -m dementia_project.data.build_splits \
  --metadata_csv "data/processed/metadata.csv" \
  --out_dir "data/processed"
```

**Small I/O helpers**
- `dementia_project/data/io.py`: `load_metadata()` and `load_splits()`.
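
To make the hybrid rule above concrete, here is a minimal sketch under stated assumptions: the function name, the 15%/15% fractions, and the use of SHA-256 are all illustrative choices, not the logic in `splitting.py`. The key idea is that the split is keyed on the subject, so all files from one person land in the same split.

```python
import hashlib
from typing import Optional


def assign_split_sketch(
    person_name_norm: str,
    csv_split: Optional[str] = None,
    valid_frac: float = 0.15,  # assumed fraction, for illustration
    test_frac: float = 0.15,   # assumed fraction, for illustration
) -> str:
    """Hypothetical hybrid rule: trust the CSV split if present, else hash the subject."""
    if csv_split in {"train", "valid", "test"}:
        return csv_split
    # A stable hash (unlike built-in hash()) gives the same bucket on every run.
    digest = hashlib.sha256(person_name_norm.encode("utf-8")).hexdigest()
    bucket = int(digest[:8], 16) / 0xFFFFFFFF  # value in [0, 1]
    if bucket < test_frac:
        return "test"
    if bucket < test_frac + valid_frac:
        return "valid"
    return "train"


# Every file of a subject gets the same split because the key is the subject.
files = [("jane doe", "a.wav"), ("jane doe", "b.wav"), ("john roe", "c.wav")]
for subject, path in files:
    print(path, assign_split_sketch(subject))
```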



### 3) Segmentation manifests (time windows)
**Modules**
- `dementia_project/segmentation/time_windows.py`: generates window start/end times.
- `dementia_project/segmentation/build_manifests.py`: CLI wrapper that writes outputs.

**What it does**
- Creates fixed-length windows (e.g., 2s with 0.5s hop) for audio baselines.

**Produces**
- `data/processed/time_segments.csv`

**Command**
```bash
poetry run python -m dementia_project.segmentation.build_manifests \
  --metadata_csv "data/processed/metadata.csv" \
  --splits_csv "data/processed/splits.csv" \
  --out_dir "data/processed" \
  --window_sec 2.0 \
  --hop_sec 0.5
```
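
For intuition about what `--window_sec 2.0 --hop_sec 0.5` produces, the sketch below enumerates windows for a single clip. It assumes windows that would run past the end of the audio are dropped; whether `time_windows.py` drops or pads the tail is not stated here, and the function name is hypothetical.

```python
def make_windows_sketch(duration_sec, window_sec=2.0, hop_sec=0.5):
    """Return (start, end) pairs for fixed-length windows that fit inside the clip."""
    windows = []
    i = 0
    # The small epsilon guards against floating-point drift at the clip boundary.
    while i * hop_sec + window_sec <= duration_sec + 1e-9:
        start = i * hop_sec
        windows.append((start, start + window_sec))
        i += 1
    return windows


print(make_windows_sketch(3.0))  # [(0.0, 2.0), (0.5, 2.5), (1.0, 3.0)]
```

A clip shorter than one window yields no segments under this assumption.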



### 4) Baseline 1: Non-ML audio (MFCC + pause stats)
**Modules**
- `dementia_project/features/audio_features.py`: MFCC + RMS + pause proxy features
- `dementia_project/train/train_nonml.py`: trains/evaluates Logistic Regression baseline

**Produces**
- `runs/nonml_baseline_scaled/metrics.json`
- `runs/nonml_baseline_scaled/confusion_matrix_test.png`

**Command**
```bash
poetry run python -m dementia_project.train.train_nonml \
  --metadata_csv "data/processed/metadata.csv" \
  --splits_csv "data/processed/splits.csv" \
  --out_dir "runs/nonml_baseline_scaled"
```

**Plot helper**
- `dementia_project/viz/metrics.py`: writes the confusion matrix PNG.
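
The "pause proxy" idea can be illustrated without any audio library: frame the signal, compute RMS energy per frame, and count the fraction of low-energy frames. This is only a sketch; the frame sizes, the relative threshold, and the function names are assumptions, and `audio_features.py` may compute its pause statistics differently.

```python
import math


def frame_rms(samples, frame_len, hop_len):
    """RMS energy of each full frame of a mono signal given as a list of floats."""
    rms = []
    for start in range(0, len(samples) - frame_len + 1, hop_len):
        frame = samples[start:start + frame_len]
        rms.append(math.sqrt(sum(x * x for x in frame) / frame_len))
    return rms


def pause_ratio_sketch(samples, frame_len=400, hop_len=160, rel_threshold=0.1):
    """Fraction of frames whose RMS falls below rel_threshold * peak RMS.

    Frame sizes and threshold are illustrative assumptions.
    """
    rms = frame_rms(samples, frame_len, hop_len)
    if not rms:
        return 0.0
    peak = max(rms)
    if peak == 0.0:
        return 1.0  # the whole clip is silent
    cutoff = rel_threshold * peak
    return sum(1 for r in rms if r < cutoff) / len(rms)


# Toy signal: 1 s of a 220 Hz tone followed by 1 s of silence at 16 kHz.
sr = 16000
tone = [0.5 * math.sin(2 * math.pi * 220 * n / sr) for n in range(sr)]
signal = tone + [0.0] * sr
print(round(pause_ratio_sketch(signal), 2))
```

For the toy signal, roughly half of the frames count as pause, which is what the construction would suggest.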



### 5) Baseline 2: Audio-only Wav2Vec2 embeddings
**Modules**
- `dementia_project/features/wav2vec2_embed.py`: loads Wav2Vec2 + mean-pools embeddings
- `dementia_project/train/train_wav2vec2_nonml.py`: trains/evaluates sklearn classifier on embeddings

**Produces**
- `runs/wav2vec2_baseline_full_cuda/metrics.json`
- `runs/wav2vec2_baseline_full_cuda/confusion_matrix_test.png`

**Command (full dataset)**
```bash
poetry run python -m dementia_project.train.train_wav2vec2_nonml \
  --metadata_csv "data/processed/metadata.csv" \
  --splits_csv "data/processed/splits.csv" \
  --out_dir "runs/wav2vec2_baseline_full_cuda" \
  --max_audio_sec 10
```

**Note on CUDA**
- We switched Poetry's torch to CUDA (`torch 2.6.0+cu124`), so embedding extraction uses the GPU.



### 6) Baseline 3: DenseNet on spectrograms
**Modules**
- `dementia_project/features/spectrograms.py`: creates log-mel spectrogram tensors
- `dementia_project/train/train_densenet_spec.py`: trains/evaluates DenseNet baseline

**Produces**
- `runs/densenet_spec_full_cuda/metrics.json`
- `runs/densenet_spec_full_cuda/confusion_matrix_test.png`

**Command (full dataset)**
```bash
poetry run python -m dementia_project.train.train_densenet_spec \
  --metadata_csv "data/processed/metadata.csv" \
  --splits_csv "data/processed/splits.csv" \
  --out_dir "runs/densenet_spec_full_cuda" \
  --epochs 5 \
  --batch_size 16 \
  --max_audio_sec 8
```



### 7) ASR (audio → transcript + word timestamps)
**Modules**
- `dementia_project/asr/transcribe.py`: Whisper ASR backend (transformers pipeline) producing `words.json`
- `dementia_project/asr/run_asr.py`: CLI runner + caching + `asr_manifest.csv`

**Produces**
- `data/processed/asr_whisper/<audio_id>/transcript.json`
- `data/processed/asr_whisper/<audio_id>/words.json`
- `data/processed/asr_whisper/asr_manifest.csv`

**Command (example sanity run)**
```bash
poetry run python -m dementia_project.asr.run_asr \
  --metadata_csv "data/processed/metadata.csv" \
  --out_dir "data/processed/asr_whisper" \
  --limit 5 \
  --model_name "openai/whisper-tiny" \
  --language en \
  --task transcribe
```

**Command (full run, resumable)**
```bash
poetry run python -m dementia_project.asr.run_asr \
  --metadata_csv "data/processed/metadata.csv" \
  --out_dir "data/processed/asr_whisper" \
  --model_name "openai/whisper-tiny" \
  --language en \
  --task transcribe
```
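
A resumable run usually comes down to skipping items whose outputs already exist. The sketch below shows that idea, assuming a finished item is marked by the presence of `<out_dir>/<audio_id>/words.json`; the helper name is hypothetical and the caching logic in `run_asr.py` may use a different check.

```python
import json
import tempfile
from pathlib import Path


def pending_audio_ids_sketch(audio_ids, out_dir):
    """Return the IDs that still need ASR, assuming words.json marks a finished item."""
    out_dir = Path(out_dir)
    return [a for a in audio_ids if not (out_dir / a / "words.json").exists()]


# Demo in a temporary directory: pretend "clip_001" was transcribed earlier.
with tempfile.TemporaryDirectory() as tmp:
    done_dir = Path(tmp) / "clip_001"
    done_dir.mkdir(parents=True)
    (done_dir / "words.json").write_text(json.dumps([]))
    todo = pending_audio_ids_sketch(["clip_001", "clip_002"], tmp)
    print(todo)  # ['clip_002']
```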



### 8) Text-only + Fusion model
We will add next:
- **Text-only baseline**: Transformer classifier on `transcript.json`
- **Fusion model**: cross-attention between text embeddings and word-level audio embeddings

Planned new modules will live under:
- `dementia_project/models/`
- `dementia_project/train/`
- `dementia_project/segmentation/` (word-level segments derived from `words.json`)

Here we fine-tune a BERT model (configurable, so that a distilled variant can be tried out) to classify each transcript as dementia or control.

You can run this in bash with:

```bash
poetry run python -m dementia_project.train.train_text_baseline \
    --metadata_csv data/processed/metadata.csv \
    --splits_csv data/processed/splits.csv \
    --asr_manifest_csv data/processed/asr_whisper/asr_manifest.csv \
    --out_dir runs/text_baseline \
    --epochs 3 \
    --batch_size 16
```


## Visualizing 

At this point I noticed that the test set is heavily imbalanced, so I want to check whether the training set is imbalanced as well.

In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

metadata = pd.read_csv("../data/processed/metadata.csv")
splits = pd.read_csv("../data/processed/splits.csv")
df = metadata.merge(splits[["audio_path", "split"]], on="audio_path")


print("=== FILE-LEVEL CLASS DISTRIBUTION ===\n")
for split in ["train", "valid", "test"]:
    subset = df[df["split"] == split]
    counts = subset["label"].value_counts().sort_index()
    print(f"{split.upper()}:")
    print(f"  No Dementia (0): {counts.get(0, 0)}")
    print(f"  Dementia (1):    {counts.get(1, 0)}")
    print(f"  Total: {len(subset)}")
    print(f"  Dementia %: {subset['label'].mean()*100:.1f}%\n")

print("=== SUBJECT-LEVEL CLASS DISTRIBUTION ===\n")
subject_splits = df.groupby("person_name_norm").agg({
    "split": "first",
    "label": "first"
}).reset_index()

for split in ["train", "valid", "test"]:
    subset = subject_splits[subject_splits["split"] == split]
    counts = subset["label"].value_counts().sort_index()
    print(f"{split.upper()} (unique subjects):")
    print(f"  No Dementia (0): {counts.get(0, 0)} subjects")
    print(f"  Dementia (1):    {counts.get(1, 0)} subjects")
    print(f"  Total: {len(subset)} subjects\n")

=== FILE-LEVEL CLASS DISTRIBUTION ===

TRAIN:
  No Dementia (0): 156
  Dementia (1):    108
  Total: 264
  Dementia %: 40.9%

VALID:
  No Dementia (0): 29
  Dementia (1):    20
  Total: 49
  Dementia %: 40.8%

TEST:
  No Dementia (0): 45
  Dementia (1):    3
  Total: 48
  Dementia %: 6.2%

=== SUBJECT-LEVEL CLASS DISTRIBUTION ===

TRAIN (unique subjects):
  No Dementia (0): 69 subjects
  Dementia (1):    68 subjects
  Total: 137 subjects

VALID (unique subjects):
  No Dementia (0): 13 subjects
  Dementia (1):    14 subjects
  Total: 27 subjects

TEST (unique subjects):
  No Dementia (0): 18 subjects
  Dementia (1):    2 subjects
  Total: 20 subjects
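
The file-level counts above show why raw accuracy is hard to interpret on the test split. As a small worked example (the helper name is made up for this notebook), this is the accuracy of a trivial classifier that always predicts the majority class:

```python
# File-level counts copied from the output above: (control, dementia).
file_counts = {"train": (156, 108), "valid": (29, 20), "test": (45, 3)}


def majority_baseline_accuracy(n_control: int, n_dementia: int) -> float:
    """Accuracy obtained by always predicting the larger class."""
    return max(n_control, n_dementia) / (n_control + n_dementia)


for split, (n_control, n_dementia) in file_counts.items():
    acc = majority_baseline_accuracy(n_control, n_dementia)
    print(f"{split}: majority-class accuracy = {acc:.3f}")
```

On the test split this is 45/48, about 94%, while train and valid sit near 59%. Test accuracy alone therefore says little, and sensitivity and specificity are more informative there.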



In [None]:
import json
from pathlib import Path

# Metrics from pre-computed previous runs
with open("../runs/text_baseline/metrics.json") as f:
    text_metrics = json.load(f)

print("=== TEXT BASELINE RESULTS ===\n")
for split in ["train", "valid", "test"]:
    m = text_metrics[split]
    cm = m["confusion_matrix"]

    # Additional metrics for true/false positive analysis
    tn, fp, fn, tp = cm[0][0], cm[0][1], cm[1][0], cm[1][1]
    total_positive = tp + fn  # actual dementia cases
    total_negative = tn + fp  # actual control cases

    print(f"{split.upper()}:")
    print(f"  Accuracy: {m['accuracy']:.3f}")
    print(f"  F1: {m['f1']:.3f}")
    print(f"  ROC AUC: {m.get('roc_auc', 'N/A')}")
    print(f"  Confusion Matrix:")
    print(f"    [[TN={tn}, FP={fp}],")
    print(f"     [FN={fn}, TP={tp}]]")
    print(f"  Class distribution:")
    print(f"    Dementia cases: {total_positive}")
    print(f"    Control cases: {total_negative}")

    if total_positive > 0:
        print(f"  Sensitivity (Recall): {tp/total_positive:.3f}")
    if total_negative > 0:
        print(f"  Specificity: {tn/total_negative:.3f}")
    print()

=== TEXT BASELINE RESULTS ===

TRAIN:
  Accuracy: 0.928
  F1: 0.919
  ROC AUC: 0.9798397688468611
  Confusion Matrix:
    [[TN=129, FP=12],
     [FN=6, TP=102]]
  Class distribution:
    Dementia cases: 108
    Control cases: 141
  Sensitivity (Recall): 0.944
  Specificity: 0.915

VALID:
  Accuracy: 0.630
  F1: 0.485
  ROC AUC: 0.648076923076923
  Confusion Matrix:
    [[TN=21, FP=5],
     [FN=12, TP=8]]
  Class distribution:
    Dementia cases: 20
    Control cases: 26
  Sensitivity (Recall): 0.400
  Specificity: 0.808

TEST:
  Accuracy: 0.652
  F1: 0.111
  ROC AUC: 0.4728682170542635
  Confusion Matrix:
    [[TN=29, FP=14],
     [FN=2, TP=1]]
  Class distribution:
    Dementia cases: 3
    Control cases: 43
  Sensitivity (Recall): 0.333
  Specificity: 0.674
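
As a sanity check on the numbers above, F1 for the dementia class can be recomputed directly from the confusion-matrix counts. The sketch also computes balanced accuracy (the mean of sensitivity and specificity), which is an extra metric added here for illustration and is not printed above.

```python
# (TN, FP, FN, TP) copied from the printed confusion matrices above.
confusion = {
    "train": (129, 12, 6, 102),
    "valid": (21, 5, 12, 8),
    "test": (29, 14, 2, 1),
}


def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    """F1 for the positive (dementia) class: 2TP / (2TP + FP + FN)."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def balanced_accuracy(tn: int, fp: int, fn: int, tp: int) -> float:
    """Mean of sensitivity and specificity."""
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    return (sensitivity + specificity) / 2


for split, (tn, fp, fn, tp) in confusion.items():
    print(
        f"{split}: F1={f1_from_counts(tp, fp, fn):.3f}  "
        f"balanced_acc={balanced_accuracy(tn, fp, fn, tp):.3f}"
    )
```

On the test split, F1 = 2/(2 + 14 + 2) ≈ 0.111, which matches the reported value, and balanced accuracy comes out at about 0.50, i.e. close to chance.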



### In the code below:
We visualize the word-level attributions behind the text model's predictions, computed with Captum.

Run the following to create the Captum HTML results and an `explanations_summary.json` with prediction scores:

```bash
poetry run python -m dementia_project.explain.text_explain \
    --model_dir runs/text_baseline \
    --metadata_csv data/processed/metadata.csv \
    --splits_csv data/processed/splits.csv \
    --asr_manifest_csv data/processed/asr_whisper/asr_manifest.csv \
    --out_dir runs/text_baseline/explanations \
    --num_examples 5
```

In [None]:
import json
from pathlib import Path
from IPython.display import HTML, display
import pandas as pd

print("# 🔍 Model Explainability (Captum LayerIntegratedGradients)\n")
print("=" * 80)

explanations_dir = Path("../runs/text_baseline/explanations")
with open(explanations_dir / "explanations_summary.json") as f:
    summary = json.load(f)

# Create predictions table from `explain/text_explain.py`, shows inferences on a control set
print(f"\n## Analyzed {summary['num_examples']} examples from validation set\n")
results = []
for i, ex in enumerate(summary["examples"]):
    results.append({
        "Example": i + 1,
        "True": "Dementia" if ex["true_label"] == 1 else "Control",
        "Predicted": "Dementia" if ex["pred_label"] == 1 else "Control",
        "Confidence": f"{ex['confidence']:.2%}",
        "Correct": "✅" if ex["true_label"] == ex["pred_label"] else "❌"
    })
display(pd.DataFrame(results))

# Display the attribution html files produced from Captum script
print("\n## Word-Level Attribution Visualizations")
print("🟢 Green = supports dementia | 🔴 Red = supports control\n")

html_files = sorted(explanations_dir.glob("example_*_attribution.html"))
for i, html_file in enumerate(html_files[:5]):         # ALL 5
    print(f"### Example {i+1}")
    with open(html_file) as f:
        display(HTML(f.read()))
    print()

# 🔍 Model Explainability (Captum LayerIntegratedGradients)


## Analyzed 5 examples from validation set



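|   | Example | True | Predicted | Confidence | Correct |
|---|---------|------|-----------|------------|---------|
| 0 | 1 | Control | Control | 77.50% | ✅ |
| 1 | 2 | Control | Control | 88.48% | ✅ |
| 2 | 3 | Control | Control | 58.79% | ✅ |
| 3 | 4 | Control | Dementia | 63.21% | ❌ |
| 4 | 5 | Control | Control | 88.64% | ✅ |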



## Word-Level Attribution Visualizations
🟢 Green = supports dementia | 🔴 Red = supports control

### Example 1



### Example 2



### Example 3



### Example 4



### Example 5





I don't like that the model attributes importance to punctuation and `[SEP]` tokens; it seems to have learned these as markers of the control class. I will need to clean this up in the future, but since these examples are all controls it shouldn't matter for now.

The most interesting case was example 4, the only one where the model predicted Dementia (a false positive, per the table above). It is hard to infer why the model flagged it, but repeated words (like "movie" in example 3) may be pushing predictions toward dementia. That seems like a straightforward takeaway, but the attributions are noisy. I will note this in the MODEL CARD and continue, and maybe fix it later.
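
One low-effort way to reduce the punctuation and `[SEP]` noise when reading attributions is to filter those tokens out before ranking. The sketch below is an assumption-laden illustration: the function name, the special-token set, and the example tokens and scores are all made up, and it only changes what is displayed, not what the model has learned.

```python
import string

# Standard BERT special tokens; adjust for other tokenizers.
SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]"}


def top_content_tokens_sketch(tokens, scores, k=5):
    """Rank tokens by |attribution| after dropping special tokens and punctuation."""
    kept = [
        (tok, score)
        for tok, score in zip(tokens, scores)
        if tok not in SPECIAL_TOKENS
        and not all(ch in string.punctuation for ch in tok)
    ]
    return sorted(kept, key=lambda pair: abs(pair[1]), reverse=True)[:k]


# Made-up tokens and scores, only to show the filtering.
tokens = ["[CLS]", "the", "movie", ",", "movie", ".", "[SEP]"]
scores = [0.05, 0.01, 0.30, 0.20, 0.25, -0.40, 0.35]
print(top_content_tokens_sketch(tokens, scores, k=2))  # [('movie', 0.3), ('movie', 0.25)]
```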
