In [None]:
using MLJ, ModalDecisionTrees
using SoleDecisionTreeInterface, Sole
using CategoricalArrays
using DataFrames, JLD2, CSV
using Audio911
using Random;

### Open .jld2 file

In [4]:
ds_path = "/datasets/respiratory_Healthy_Pneumonia"

d = jldopen(string((@__DIR__), ds_path, ".jld2"))
x, y = d["dataframe_validated"]
@assert x isa DataFrame
close(d)

# out of memory guard
# x = vcat(x[1:10, :], x[400:410, :])
# y = vcat(y[1:10], y[400:410]);

### Audio features extraction function

In [5]:
nan_replacer!(x::AbstractArray{Float64}) = replace!(x, NaN => 0.0)

function afe(x::AbstractVector{Float64})
    # -------------------------------- parameters -------------------------------- #
    # audio module
    sr = 8000
    norm = true
    speech_detection = false
    # stft module
    stft_length = 256
    win_type = (:hann, :periodic)
    win_length = 256
    overlap_length = 128
    stft_norm = :power                      # :power, :magnitude, :pow2mag
    # mel filterbank module
    nbands = 26
    scale = :mel_htk                        # :mel_htk, :mel_slaney, :erb, :bark
    melfb_norm = :bandwidth                 # :bandwidth, :area, :none
    freq_range = (300, round(Int, sr / 2))
    # mel spectrogram module
    db_scale = false
    # mfcc module
    ncoeffs = 13
    rectification = :log                    # :log, :cubic_root
    dither = true
    # f0 module
    method = :nfc
    f0_range = (50, 400)

    # --------------------------------- functions -------------------------------- #
    # audio module
    audio = load_audio(
        file=x, 
        sr=sr, 
        norm=norm,
    );

    stftspec = get_stft(
        audio=audio, 
        stft_length=stft_length,
        win_type=win_type,
        win_length=win_length,
        overlap_length=overlap_length,
        norm=stft_norm
    );

    # mel filterbank module
    melfb = get_melfb(
        stft=stftspec,
        nbands=nbands,
        scale=scale,
        norm=melfb_norm,
        freq_range=freq_range
    );

    # mel spectrogram module
    melspec =  get_melspec(
        stft=stftspec, 
        fbank=melfb,
        db_scale=db_scale
    );

    # mfcc module
    mfcc = get_mfcc(
        source=melspec,
        ncoeffs=ncoeffs,
        rectification=rectification,
        dither=dither,
    );

    # f0 module
    f0 = get_f0(
        source=stftspec,
        method=method,
        freq_range=f0_range
    );

    # spectral features module
    spect = get_spectrals(
        source=stftspec,
        freq_range=freq_range
    );

    x_features = hcat(
        melspec.spec',
        mfcc.mfcc',
        f0.f0,
        spect.centroid,
        spect.crest,
        spect.entropy,
        spect.flatness,
        spect.flux,
        spect.kurtosis,
        spect.rolloff,
        spect.skewness,
        spect.decrease,
        spect.slope,
        spect.spread
    );

    nan_replacer!(x_features)

    return x_features
end

afe (generic function with 1 method)

### Compute audio features dataframes and labels

In [7]:
colnames = [
    ["mel$i" for i in 1:26]...,
    ["mfcc$i" for i in 1:13]...,
    "f0", "cntrd", "crest", "entrp", "flatn", "flux", "kurts", "rllff", "skwns", "decrs", "slope", "sprd"
]

X = DataFrame([name => Vector{Float64}[] for name in colnames])

for i in 1:nrow(x)
    push!(X, collect(eachcol(afe(x[i, :audio]))))
end

yc = CategoricalArray(y);

### Split dataset

In [8]:
train_ratio = 0.8
train, test = partition(eachindex(yc), train_ratio, shuffle=true)
X_train, y_train = X[train, :], yc[train]
X_test, y_test = X[test, :], yc[test]

println("Training set size: ", size(X_train), " - ", length(y_train))
println("Test set size: ", size(X_test), " - ", length(y_test))

Training set size: (403, 51) - 403
Test set size: (101, 51) - 101


### Train a model

In [9]:
learned_dt_tree = begin
    model = ModalDecisionTree(; relations = :IA7, features = [minimum, maximum])   
    mach = machine(model, X_train, y_train) |> fit!
end

report(learned_dt_tree).printmodel();

┌ Info: Precomputing logiset...
└ @ ModalDecisionTrees.MLJInterface /home/paso/.julia/packages/ModalDecisionTrees/aAlvI/src/interfaces/MLJ/wrapdataset.jl:135
┌ Info: Training machine(ModalDecisionTree(max_depth = nothing, …), …).
└ @ MLJBase /home/paso/.julia/packages/MLJBase/7nGJF/src/machines.jl:499


▣ (⟨G⟩min[mfcc3] < -0.40330013029458534)
├✔ (⟨G⟩(min[mfcc3] < -0.40330013029458534 ∧ ⟨L̅⟩min[entrp] ≥ 0.8346062223276545))
│├✔ Healthy
│└✘ (⟨G⟩(min[mfcc3] < -0.40330013029458534 ∧ ⟨A̅O̅⟩min[mfcc9] ≥ 0.77828494040009))
│ ├✔ (⟨G⟩(min[mfcc3] < -0.40330013029458534 ∧ ⟨A̅O̅⟩(min[mfcc9] ≥ 0.77828494040009 ∧ ⟨AO⟩min[mfcc9] ≥ 0.5608704397457179)))
│ │├✔ Healthy
│ │└✘ (⟨G⟩(min[mfcc3] < -0.40330013029458534 ∧ ⟨A̅O̅⟩(min[mfcc9] ≥ 0.77828494040009 ∧ ⟨L̅⟩min[mfcc7] ≥ 0.2085923879517647)))
│ │ ├✔ (⟨G⟩(min[mfcc3] < -0.40330013029458534 ∧ ⟨A̅O̅⟩(min[mfcc9] ≥ 0.77828494040009 ∧ ⟨L̅⟩(min[mfcc7] ≥ 0.2085923879517647 ∧ ⟨A̅O̅⟩max[mel1] < 0.00015175318693076242))))
│ │ │├✔ (⟨G⟩(min[mfcc3] < -0.40330013029458534 ∧ ⟨A̅O̅⟩(min[mfcc9] ≥ 0.77828494040009 ∧ ⟨L̅⟩(min[mfcc7] ≥ 0.2085923879517647 ∧ ⟨A̅O̅⟩(max[mel1] < 0.00015175318693076242 ∧ ⟨G⟩min[mfcc12] < -0.6021351572411308)))))
│ │ ││├✔ Pneumonia
│ │ ││└✘ Healthy
│ │ │└✘ Pneumonia
│ │ └✘ Healthy
│ └✘ (⟨G⟩(min[mfcc3] < -0.40330013029458534 ∧ ⟨AO⟩max[flatn] < 0.0

### Model inspection & rule study

In [10]:
_, mtree = report(mach).sprinkle(X_test, y_test)
sole_dt = ModalDecisionTrees.translate(mtree)

printmodel(sole_dt; show_metrics = true);

[32mApplying tree... 100%|███████████████████████████████████| Time: 0:00:06[39m


▣ {1}(⟨G⟩(min[V29] < -0.40330013029458534))
├✔ {1}(⟨G⟩((min[V29] < -0.40330013029458534) ∧ ⟨L̅⟩(min[V43] ≥ 0.8346062223276545)))
│├✔ {1}(⟨G⟩((min[V29] < -0.40330013029458534) ∧ ⟨L̅⟩((min[V43] ≥ 0.8346062223276545) ∧ (min[V35] ≥ -0.1401434513258752))))
││├✔ Healthy : (ninstances = 25, ncovered = 25, confidence = 0.92, lift = 1.0)
││└✘ Healthy : (ninstances = 3, ncovered = 3, confidence = 0.67, lift = 1.0)
│└✘ {1}(⟨G⟩((min[V29] < -0.40330013029458534) ∧ ⟨A̅O̅⟩(min[V35] ≥ 0.77828494040009)))
│ ├✔ {1}(⟨G⟩((min[V29] < -0.40330013029458534) ∧ ⟨A̅O̅⟩((min[V35] ≥ 0.77828494040009) ∧ ⟨AO⟩(min[V35] ≥ 0.5608704397457179))))
│ │├✔ Healthy : (ninstances = 7, ncovered = 7, confidence = 0.86, lift = 1.0)
│ │└✘ {1}(⟨G⟩((min[V29] < -0.40330013029458534) ∧ ⟨A̅O̅⟩((min[V35] ≥ 0.77828494040009) ∧ ⟨L̅⟩(min[V33] ≥ 0.2085923879517647))))
│ │ ├✔ {1}(⟨G⟩((min[V29] < -0.40330013029458534) ∧ ⟨A̅O̅⟩((min[V35] ≥ 0.77828494040009) ∧ ⟨L̅⟩((min[V33] ≥ 0.2085923879517647) ∧ ⟨A̅O̅⟩(max[V1] < 0.00015175318693076242)))))

### Extract rules that are at least as good as a random baseline model

In [11]:
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0);
printmodel.(interesting_rules; show_metrics = true);

▣ {1}(⟨G⟩((min[V29] < -0.40330013029458534) ∧ ⟨L̅⟩((min[V43] ≥ 0.8346062223276545) ∧ (min[V35] ≥ -0.1401434513258752))))  ↣  Healthy : (ninstances = 101, ncovered = 25, coverage = 0.25, confidence = 0.92, lift = 1.79, natoms = 3)
▣ {1}⟨G⟩((min[V29] < -0.40330013029458534) ∧ ⟨L̅⟩(min[V43] ≥ 0.8346062223276545)) ∧ [G]((min[V29] < -0.40330013029458534) → [L̅]((min[V43] ≥ 0.8346062223276545) → (min[V35] < -0.1401434513258752)))  ↣  Healthy : (ninstances = 101, ncovered = 3, coverage = 0.03, confidence = 0.67, lift = 1.29, natoms = 5)
▣ {1}⟨G⟩((min[V29] < -0.40330013029458534) ∧ ⟨A̅O̅⟩((min[V35] ≥ 0.77828494040009) ∧ ⟨AO⟩(min[V35] ≥ 0.5608704397457179))) ∧ [G]((min[V29] < -0.40330013029458534) → [L̅](min[V43] < 0.8346062223276545))  ↣  Healthy : (ninstances = 101, ncovered = 7, coverage = 0.07, confidence = 0.86, lift = 1.66, natoms = 5)
▣ {1}⟨G⟩((min[V29] < -0.40330013029458534) ∧ ⟨A̅O̅⟩((min[V35] ≥ 0.77828494040009) ∧ ⟨L̅⟩((min[V33] ≥ 0.2085923879517647) ∧ ⟨A̅O̅⟩((max[V1] < 0.000151753186

### Simplify rules while extracting and prettify result

In [12]:
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0, normalize = true);
printmodel.(interesting_rules; show_metrics = true, syntaxstring_kwargs = (; threshold_digits = 2));

▣ {1}(⟨G⟩(min[V29] < -0.4 ∧ ⟨L̅⟩(min[V43] ≥ 0.83 ∧ min[V35] ≥ -0.14)))  ↣  Healthy : (ninstances = 101, ncovered = 25, coverage = 0.25, confidence = 0.92, lift = 1.79, natoms = 3)
▣ {1}⟨G⟩(min[V29] < -0.4 ∧ ⟨L̅⟩min[V43] ≥ 0.83) ∧ [G](min[V29] < -0.4 → [L̅](min[V43] ≥ 0.83 → min[V35] < -0.14))  ↣  Healthy : (ninstances = 101, ncovered = 3, coverage = 0.03, confidence = 0.67, lift = 1.29, natoms = 5)
▣ {1}⟨G⟩(min[V29] < -0.4 ∧ ⟨A̅O̅⟩(min[V35] ≥ 0.78 ∧ ⟨AO⟩min[V35] ≥ 0.56)) ∧ [G](min[V29] < -0.4 → [L̅]min[V43] < 0.83)  ↣  Healthy : (ninstances = 101, ncovered = 7, coverage = 0.07, confidence = 0.86, lift = 1.66, natoms = 5)
▣ {1}⟨G⟩(min[V29] < -0.4 ∧ ⟨A̅O̅⟩(min[V35] ≥ 0.78 ∧ ⟨L̅⟩(min[V33] ≥ 0.21 ∧ ⟨A̅O̅⟩(max[V1] < 0.0 ∧ min[V1] < 0.0)))) ∧ [G](min[V29] < -0.4 → [L̅]min[V43] < 0.83) ∧ [G](min[V29] < -0.4 → [A̅O̅](min[V35] ≥ 0.78 → [AO]min[V35] < 0.56)) ∧ [G](min[V29] < -0.4 → [A̅O̅](min[V35] ≥ 0.78 → [L̅](min[V33] ≥ 0.21 → [A̅O̅](max[V1] < 0.0 → [G]min[V38] ≥ -0.6))))  ↣  Healthy : (ninsta

### Directly access rule metrics

In [13]:
readmetrics.(listrules(sole_dt; min_lift=1.0, min_ninstances = 0))

11-element Vector{@NamedTuple{ninstances::Int64, ncovered::Int64, coverage::Float64, confidence::Float64, lift::Float64, natoms::Int64}}:
 (ninstances = 101, ncovered = 25, coverage = 0.24752475247524752, confidence = 0.92, lift = 1.786923076923077, natoms = 3)
 (ninstances = 101, ncovered = 3, coverage = 0.0297029702970297, confidence = 0.6666666666666666, lift = 1.2948717948717947, natoms = 5)
 (ninstances = 101, ncovered = 7, coverage = 0.06930693069306931, confidence = 0.8571428571428571, lift = 1.6648351648351647, natoms = 5)
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.9423076923076923, natoms = 15)
 (ninstances = 101, ncovered = 4, coverage = 0.039603960396039604, confidence = 0.75, lift = 1.4567307692307692, natoms = 19)
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 2.061224489795918, natoms = 12)
 (ninstances = 101, ncovered = 2, coverage = 0.019801980198019802, confidence = 1.0, lif