In [1]:
using MLJ
using MLJDecisionTreeInterface
using Sole, SoleDecisionTreeInterface
using CategoricalArrays
using DataFrames, JLD2, CSV
using StatsBase, Statistics
using Catch22, Audio911

### Open .jld2 file
the file contains 504 samples of respiratory sound, labeled with 2 classes: healty and pneumonia

In [2]:
ds_path = "/home/paso/Documents/Aclai/audio-rules2024/datasets/respiratory_Healthy_Pneumonia"

d = jldopen(string(ds_path, ".jld2"))
x, y = d["dataframe_validated"]
@assert x isa DataFrame
close(d)

### Audio features extraction function
This function is called for every audio sample and extract 51 features:
26 bands of the mel spectrogram,
13 coefficients of the mfcc
12 spectral features: centroid, crest, entropy, f0, flatness, flux, kurtosis, rolloff, skewness, decrease, slope, spread

In [3]:
nan_replacer!(x::AbstractArray{Float64}) = replace!(x, NaN => 0.0)

function afe(x::AbstractVector{Float64})
    # -------------------------------- parameters -------------------------------- #
    # audio module
    sr = 8000
    norm = true
    speech_detection = false
    # stft module
    stft_length = 256
    win_type = (:hann, :periodic)
    win_length = 256
    overlap_length = 128
    stft_norm = :power                      # :power, :magnitude, :pow2mag
    # mel filterbank module
    nbands = 26
    scale = :mel_htk                        # :mel_htk, :mel_slaney, :erb, :bark
    melfb_norm = :bandwidth                 # :bandwidth, :area, :none
    freq_range = (300, round(Int, sr / 2))
    # mel spectrogram module
    db_scale = false
    # mfcc module
    ncoeffs = 13
    rectification = :log                    # :log, :cubic_root
    dither = true
    # f0 module
    method = :nfc
    f0_range = (50, 400)

    # --------------------------------- functions -------------------------------- #
    # audio module
    audio = load_audio(
        file=x, 
        sr=sr, 
        norm=norm,
    );

    stftspec = get_stft(
        audio=audio, 
        stft_length=stft_length,
        win_type=win_type,
        win_length=win_length,
        overlap_length=overlap_length,
        norm=stft_norm
    );

    # mel filterbank module
    melfb = get_melfb(
        stft=stftspec,
        nbands=nbands,
        scale=scale,
        norm=melfb_norm,
        freq_range=freq_range
    );

    # mel spectrogram module
    melspec =  get_melspec(
        stft=stftspec, 
        fbank=melfb,
        db_scale=db_scale
    );

    # mfcc module
    mfcc = get_mfcc(
        source=melspec,
        ncoeffs=ncoeffs,
        rectification=rectification,
        dither=dither,
    );

    # f0 module
    f0 = get_f0(
        source=stftspec,
        method=method,
        freq_range=f0_range
    );

    # spectral features module
    spect = get_spectrals(
        source=stftspec,
        freq_range=freq_range
    );

    x_features = hcat(
        melspec.spec',
        mfcc.mfcc',
        f0.f0,
        spect.centroid,
        spect.crest,
        spect.entropy,
        spect.flatness,
        spect.flux,
        spect.kurtosis,
        spect.rolloff,
        spect.skewness,
        spect.decrease,
        spect.slope,
        spect.spread
    );

    nan_replacer!(x_features)

    return x_features
end

afe (generic function with 1 method)

### Compute DataFrame of features

In [4]:
colnames = [
    ["mel$i" for i in 1:26]...,
    ["mfcc$i" for i in 1:13]...,
    "f0", "cntrd", "crest", "entrp", "flatn", "flux", "kurts", "rllff", "skwns", "decrs", "slope", "sprd"
]

X = DataFrame([name => Vector{Float64}[] for name in colnames])

for i in 1:nrow(x)
    push!(X, collect(eachcol(afe(x[i, :audio]))))
end

### Data compression for propositional analysis

In [5]:
catch9 = [
    maximum,
    minimum,
    StatsBase.mean,
    median,
    std,
    Catch22.SB_BinaryStats_mean_longstretch1,
    Catch22.SB_BinaryStats_diff_longstretch0,
    Catch22.SB_MotifThree_quantile_hh,
    Catch22.SB_TransitionMatrix_3ac_sumdiagcov,
]

t = zeros(size(X, 1))
for i in eachcol(X)
    feature = map(x -> catch9[4](x...), eachrow(i))
    global t = hcat(t, feature)
end
Xc = DataFrame(t[:, 2:end], names(X));

yc = CategoricalArray{String,1,UInt32}(y);

In [6]:
train_ratio = 0.8

train, test = partition(eachindex(y), train_ratio, shuffle=true)
X_train, y_train = Xc[train, :], yc[train]
X_test, y_test = Xc[test, :], yc[test]

println("Training set size: ", size(X_train), " - ", length(y_train))
println("Test set size: ", size(X_test), " - ", length(y_test))

Training set size: (403, 51) - 403
Test set size: (101, 51) - 101


### Train a model

In [7]:
learned_dt_tree = begin
    Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree
    model = Tree(max_depth=-1, )
    mach = machine(model, X_train, y_train)
    fit!(mach)
    fitted_params(mach).tree
end

┌ Info: For silent loading, specify `verbosity=0`. 
└ @ Main /home/paso/.julia/packages/MLJModels/jKlfT/src/loading.jl:159


import MLJDecisionTreeInterface ✔


┌ Info: Training machine(DecisionTreeClassifier(max_depth = -1, …), …).
└ @ MLJBase /home/paso/.julia/packages/MLJBase/7nGJF/src/machines.jl:499


mfcc3 < 0.5027
├─ decrs < -0.9409
│  ├─ mfcc2 < 4.553
│  │  ├─ mel8 < 8.095e-7
│  │  │  ├─ Healthy (1/1)
│  │  │  └─ Pneumonia (15/15)
│  │  └─ mel6 < 8.727e-5
│  │     ├─ Healthy (4/4)
│  │     └─ Pneumonia (1/1)
│  └─ mfcc5 < 0.112
│     ├─ mel12 < 2.954e-6
│     │  ├─ Healthy (10/10)
│     │  └─ crest < 16.94
│     │     ├─ mel3 < 9.497e-5
│     │     │  ⋮
│     │     │  
│     │     └─ Pneumonia (12/12)
│     └─ mfcc3 < -0.584
│        ├─ sprd < 435.1
│        │  ├─ Healthy (24/24)
│        │  └─ cntrd < 1027.0
│        │     ⋮
│        │     
│        └─ mfcc4 < 0.4992
│           ├─ Healthy (97/97)
│           └─ mfcc5 < 0.1821
│              ⋮
│              
└─ mel21 < 1.593e-8
   ├─ Healthy (10/10)
   └─ cntrd < 723.2
      ├─ mel10 < 2.102e-6
      │  ├─ Pneumonia (67/67)
      │  └─ mfcc6 < 0.136
      │     ├─ Pneumonia (43/43)
      │     └─ mel23 < 2.655e-7
      │        ⋮
      │        
      └─ flatn < 0.1328
         ├─ mfcc12 < 0.05214
         │  ├─ Pneumonia (2/2)

### Convert to Sole model


In [8]:
sole_dt = solemodel(learned_dt_tree)

▣ mfcc3 < 0.5026614031627068
├✔ decrs < -0.9409334853269118
│├✔ mfcc2 < 4.552895635419956
││├✔ mel8 < 8.095081277207338e-7
│││├✔ Healthy
│││└✘ Pneumonia
││└✘ mel6 < 8.726964519086241e-5
││ ├✔ Healthy
││ └✘ Pneumonia
│└✘ mfcc5 < 0.11196263858700894
│ ├✔ mel12 < 2.954392079872265e-6
│ │├✔ Healthy
│ │└✘ crest < 16.942226922108134
│ │ ├✔ mel3 < 9.496756752633945e-5
│ │ │├✔ Healthy
│ │ │└✘ Pneumonia
│ │ └✘ Pneumonia
│ └✘ mfcc3 < -0.5839648238986663
│  ├✔ sprd < 435.1042245447893
│  │├✔ Healthy
│  │└✘ cntrd < 1026.5265262669163
│  │ ├✔ mfcc13 < 0.15869459770719566
│  │ │├✔ mfcc13 < -0.10350259557720298
│  │ ││├✔ Healthy
│  │ ││└✘ Pneumonia
│  │ │└✘ Healthy
│  │ └✘ Healthy
│  └✘ mfcc4 < 0.4992252192245208
│   ├✔ Healthy
│   └✘ mfcc5 < 0.1821043922028262
│    ├✔ Pneumonia
│    └✘ mfcc8 < 0.4864465707982981
│     ├✔ Healthy
│     └✘ Pneumonia
└✘ mel21 < 1.5931912012283505e-8
 ├✔ Healthy
 └✘ cntrd < 723.180680404329
  ├✔ mel10 < 2.1023903663400553e-6
  │├✔ Pneumonia
  │└✘ mfcc6 < 0.1359627584155

### Model inspection & rule study

In [9]:
# Make test instances flow into the model, so that test metrics can, then, be computed.
apply!(sole_dt, X_test, y_test);
# Print Sole model
printmodel(sole_dt; show_metrics = true);

▣ mfcc3 < 0.5026614031627068
├✔ decrs < -0.9409334853269118
│├✔ mfcc2 < 4.552895635419956
││├✔ mel8 < 8.095081277207338e-7
│││├✔ Healthy : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
│││└✘ Pneumonia : (ninstances = 4, ncovered = 4, confidence = 0.75, lift = 1.0)
││└✘ mel6 < 8.726964519086241e-5
││ ├✔ Healthy : (ninstances = 1, ncovered = 1, confidence = 1.0, lift = 1.0)
││ └✘ Pneumonia : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
│└✘ mfcc5 < 0.11196263858700894
│ ├✔ mel12 < 2.954392079872265e-6
│ │├✔ Healthy : (ninstances = 3, ncovered = 3, confidence = 1.0, lift = 1.0)
│ │└✘ crest < 16.942226922108134
│ │ ├✔ mel3 < 9.496756752633945e-5
│ │ │├✔ Healthy : (ninstances = 1, ncovered = 1, confidence = 1.0, lift = 1.0)
│ │ │└✘ Pneumonia : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
│ │ └✘ Pneumonia : (ninstances = 6, ncovered = 6, confidence = 0.67, lift = 1.0)
│ └✘ mfcc3 < -0.5839648238986663
│  ├✔ sprd < 435.1042245447893
│  │├✔ Healthy 

### Extract rules that are at least as good as a random baseline model

In [10]:
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0);
printmodel.(interesting_rules; show_metrics = true);

▣ (mfcc3 < 0.5026614031627068) ∧ (decrs < -0.9409334853269118) ∧ (mfcc2 < 4.552895635419956) ∧ (¬(mel8 < 8.095081277207338e-7))  ↣  Pneumonia : (ninstances = 101, ncovered = 4, coverage = 0.04, confidence = 0.75, lift = 1.68, natoms = 4)
▣ (mfcc3 < 0.5026614031627068) ∧ (decrs < -0.9409334853269118) ∧ (¬(mfcc2 < 4.552895635419956)) ∧ (mel6 < 8.726964519086241e-5)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.01, confidence = 1.0, lift = 1.8, natoms = 4)
▣ (mfcc3 < 0.5026614031627068) ∧ (¬(decrs < -0.9409334853269118)) ∧ (mfcc5 < 0.11196263858700894) ∧ (mel12 < 2.954392079872265e-6)  ↣  Healthy : (ninstances = 101, ncovered = 3, coverage = 0.03, confidence = 1.0, lift = 1.8, natoms = 4)
▣ (mfcc3 < 0.5026614031627068) ∧ (¬(decrs < -0.9409334853269118)) ∧ (mfcc5 < 0.11196263858700894) ∧ (¬(mel12 < 2.954392079872265e-6)) ∧ (crest < 16.942226922108134) ∧ (mel3 < 9.496756752633945e-5)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.01, confidence = 1.0, lift = 1.8, 

### Simplify rules while extracting and prettify result

In [11]:
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0, normalize = true);
printmodel.(interesting_rules; show_metrics = true, syntaxstring_kwargs = (; threshold_digits = 2));

▣ (mfcc3 < 0.5) ∧ (decrs < -0.94) ∧ (mfcc2 < 4.55) ∧ (mel8 ≥ 0.0)  ↣  Pneumonia : (ninstances = 101, ncovered = 4, coverage = 0.04, confidence = 0.75, lift = 1.68, natoms = 4)
▣ (mfcc3 < 0.5) ∧ (decrs < -0.94) ∧ (mfcc2 ≥ 4.55) ∧ (mel6 < 0.0)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.01, confidence = 1.0, lift = 1.8, natoms = 4)
▣ (mfcc3 < 0.5) ∧ (decrs ≥ -0.94) ∧ (mfcc5 < 0.11) ∧ (mel12 < 0.0)  ↣  Healthy : (ninstances = 101, ncovered = 3, coverage = 0.03, confidence = 1.0, lift = 1.8, natoms = 4)
▣ (mfcc3 < 0.5) ∧ (decrs ≥ -0.94) ∧ (mfcc5 < 0.11) ∧ (mel12 ≥ 0.0) ∧ (crest < 16.94) ∧ (mel3 < 0.0)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.01, confidence = 1.0, lift = 1.8, natoms = 6)
▣ (mfcc3 < 0.5) ∧ (decrs ≥ -0.94) ∧ (mfcc5 < 0.11) ∧ (mel12 ≥ 0.0) ∧ (crest ≥ 16.94)  ↣  Pneumonia : (ninstances = 101, ncovered = 6, coverage = 0.06, confidence = 0.67, lift = 1.5, natoms = 5)
▣ (mfcc3 < -0.58) ∧ (decrs ≥ -0.94) ∧ (mfcc5 ≥ 0.11) ∧ (sprd < 435.1)  ↣  Healt

### Directly access rule metrics

In [12]:
readmetrics.(listrules(sole_dt; min_lift=1.0, min_ninstances = 0))

18-element Vector{@NamedTuple{ninstances::Int64, ncovered::Int64, coverage::Float64, confidence::Float64, lift::Float64, natoms::Int64}}:
 (ninstances = 101, ncovered = 4, coverage = 0.039603960396039604, confidence = 0.75, lift = 1.6833333333333333, natoms = 4)
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.8035714285714284, natoms = 4)
 (ninstances = 101, ncovered = 3, coverage = 0.0297029702970297, confidence = 1.0, lift = 1.8035714285714284, natoms = 4)
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.8035714285714284, natoms = 6)
 (ninstances = 101, ncovered = 6, coverage = 0.0594059405940594, confidence = 0.6666666666666666, lift = 1.4962962962962962, natoms = 5)
 (ninstances = 101, ncovered = 7, coverage = 0.06930693069306931, confidence = 1.0, lift = 1.8035714285714284, natoms = 5)
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.8035714285714

### Show rules with an additional metric (syntax height of the rule's antecedent)

In [13]:
printmodel.(sort(interesting_rules, by = readmetrics); show_metrics = (; round_digits = nothing, additional_metrics = (; height = r->SoleLogics.height(antecedent(r)))));

▣ (mfcc3 < 0.5026614031627068) ∧ (decrs < -0.9409334853269118) ∧ (mfcc2 ≥ 4.552895635419956) ∧ (mel6 < 8.726964519086241e-5)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.8035714285714284, natoms = 4, height = 3)
▣ (mfcc3 < -0.5839648238986663) ∧ (decrs ≥ -0.9409334853269118) ∧ (mfcc5 ≥ 0.11196263858700894) ∧ (sprd ≥ 435.1042245447893) ∧ (cntrd ≥ 1026.5265262669163)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.8035714285714284, natoms = 5, height = 4)
▣ (mfcc3 < 0.5026614031627068) ∧ (decrs ≥ -0.9409334853269118) ∧ (mfcc5 < 0.11196263858700894) ∧ (mel12 ≥ 2.954392079872265e-6) ∧ (crest < 16.942226922108134) ∧ (mel3 < 9.496756752633945e-5)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.8035714285714284, natoms = 6, height = 5)
▣ (mfcc3 < -0.5839648238986663) ∧ (decrs ≥ -0.9409334853269118) ∧ (mfcc5 ≥ 0.1119

### Pretty table of rules and their metrics

In [14]:
metricstable(interesting_rules; metrics_kwargs = (; round_digits = nothing, additional_metrics = (; height = r->SoleLogics.height(antecedent(r)))))

┌───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┬────────────┬────────────┬──────────┬────────────┬────────────┬─────────┬────────┬────────┐
│[33;1m                                                                                                                                                                                                                                                        Antecedent [0m│[33;1m Consequent [0m│[33;1m ninstances [0m│[33;1m ncovered [0m│[33;1m   coverage [0m│[33;1m confidence [0m│[33;1m    lift [0m│[33;1m natoms [0m│[33;1m height [0m│
├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────