In [1]:
using MLJ, ModalDecisionTrees
using SoleDecisionTreeInterface, Sole
using CategoricalArrays
using DataFrames, JLD2, CSV
using Audio911
using Random

### Open .jld2 file

In [3]:
ds_path = "/datasets/respiratory_Healthy_Pneumonia"

# Read the (features, labels) pair stored under "dataframe_validated".
# The do-block form of `jldopen` closes the file even when reading throws,
# so no handle is leaked on failure.
x, y = jldopen(string((@__DIR__), ds_path, ".jld2")) do d
    d["dataframe_validated"]
end

# Validate explicitly instead of relying on `@assert`.
x isa DataFrame ||
    throw(ArgumentError("expected a DataFrame of samples, got $(typeof(x))"))

# out of memory guard: keep only rows 1:10 and 400:410.
# A single index vector is applied to both `x` and `y` so they stay aligned.
subset_rows = vcat(1:10, 400:410)
x = x[subset_rows, :]
y = y[subset_rows];

### Audio features extraction function

In [4]:
"""
    nan_replacer!(x::AbstractArray{<:AbstractFloat})

Replace every `NaN` entry of `x` with zero of the same type, in place, and return `x`.

Works for any floating-point element type (`Float64`, `Float32`, ...).
"""
function nan_replacer!(x::AbstractArray{<:AbstractFloat})
    return replace!(v -> isnan(v) ? zero(v) : v, x)
end

"""
    afe(x::AbstractVector{Float64}; kwargs...)

Run the Audio911 feature-extraction pipeline on `x` and return a single feature matrix.

`x` is handed to `load_audio` as its `file` argument; the result is then processed by
`get_stft`, `get_melfb`, `get_melspec`, `get_mfcc`, `get_f0` and `get_spectrals`.
The returned matrix is the horizontal concatenation, in this order, of:
the transposed mel spectrogram, the transposed MFCCs, f0, and the spectral
centroid, crest, entropy, flatness, flux, kurtosis, rolloff, skewness, decrease,
slope and spread. Every `NaN` in the result is replaced by `0.0`.

# Keywords
All keywords default to the values that were previously hard-coded, so `afe(x)`
behaves as before.
- audio: `sr=8000`, `norm=true`
- stft: `stft_length=256`, `win_type=(:hann, :periodic)`, `win_length=256`,
  `overlap_length=128`, `stft_norm=:power` (`:power`, `:magnitude`, `:pow2mag`)
- mel filterbank: `nbands=26`, `scale=:mel_htk` (`:mel_htk`, `:mel_slaney`, `:erb`,
  `:bark`), `melfb_norm=:bandwidth` (`:bandwidth`, `:area`, `:none`),
  `freq_range=(300, round(Int, sr / 2))`
- mel spectrogram: `db_scale=false`
- mfcc: `ncoeffs=13`, `rectification=:log` (`:log`, `:cubic_root`), `dither=true`
- f0: `method=:nfc`, `f0_range=(50, 400)`

NOTE(review): `nbands` and `ncoeffs` presumably determine how many mel / MFCC columns
are produced; code that names the output columns must be kept in sync. Confirm
against Audio911.
"""
function afe(
    x::AbstractVector{Float64};
    # audio module
    sr=8000,
    norm=true,
    # stft module
    stft_length=256,
    win_type=(:hann, :periodic),
    win_length=256,
    overlap_length=128,
    stft_norm=:power,
    # mel filterbank module
    nbands=26,
    scale=:mel_htk,
    melfb_norm=:bandwidth,
    freq_range=(300, round(Int, sr / 2)),
    # mel spectrogram module
    db_scale=false,
    # mfcc module
    ncoeffs=13,
    rectification=:log,
    dither=true,
    # f0 module
    method=:nfc,
    f0_range=(50, 400),
)
    # audio module
    audio = load_audio(
        file=x,
        sr=sr,
        norm=norm,
    )

    # stft module
    stftspec = get_stft(
        audio=audio,
        stft_length=stft_length,
        win_type=win_type,
        win_length=win_length,
        overlap_length=overlap_length,
        norm=stft_norm,
    )

    # mel filterbank module
    melfb = get_melfb(
        stft=stftspec,
        nbands=nbands,
        scale=scale,
        norm=melfb_norm,
        freq_range=freq_range,
    )

    # mel spectrogram module
    melspec = get_melspec(
        stft=stftspec,
        fbank=melfb,
        db_scale=db_scale,
    )

    # mfcc module
    mfcc = get_mfcc(
        source=melspec,
        ncoeffs=ncoeffs,
        rectification=rectification,
        dither=dither,
    )

    # f0 module
    f0 = get_f0(
        source=stftspec,
        method=method,
        freq_range=f0_range,
    )

    # spectral features module
    spect = get_spectrals(
        source=stftspec,
        freq_range=freq_range,
    )

    # The mel spectrogram and the MFCCs are transposed before concatenation with
    # the remaining descriptors.
    x_features = hcat(
        melspec.spec',
        mfcc.mfcc',
        f0.f0,
        spect.centroid,
        spect.crest,
        spect.entropy,
        spect.flatness,
        spect.flux,
        spect.kurtosis,
        spect.rolloff,
        spect.skewness,
        spect.decrease,
        spect.slope,
        spect.spread,
    )

    nan_replacer!(x_features)

    return x_features
end

afe (generic function with 1 method)

### Compute audio features dataframes and labels

In [5]:
# Column names, in the same order as the columns returned by `afe`:
# 26 mel bands, 13 MFCCs, f0 and the 11 spectral descriptors.
colnames = vcat(
    ["mel$i" for i in 1:26],
    ["mfcc$i" for i in 1:13],
    [
        "f0", "cntrd", "crest", "entrp", "flatn", "flux",
        "kurts", "rllff", "skwns", "decrs", "slope", "sprd",
    ],
)

# Empty table in which every cell will hold one feature's whole series.
X = DataFrame([Vector{Float64}[] for _ in colnames], colnames)

# One row per recording: each column of the feature matrix becomes one cell.
for audio in x.audio
    feature_series = collect(eachcol(afe(audio)))
    push!(X, feature_series)
end

# Labels as a categorical vector, as expected by the classifier.
yc = CategoricalArray(y);

### Split dataset

In [8]:
train_ratio = 0.8

# Shuffle with an explicitly seeded RNG so that the split, and therefore the
# learned tree and every metric computed below, is reproducible across runs.
split_rng = MersenneTwister(42)
train, test = partition(eachindex(yc), train_ratio; shuffle=true, rng=split_rng)
X_train, y_train = X[train, :], yc[train]
X_test, y_test = X[test, :], yc[test]

println("Training set size: ", size(X_train), " - ", length(y_train))
println("Test set size: ", size(X_test), " - ", length(y_test))

Training set size: (17, 51) - 17
Test set size: (4, 51) - 4


### Train a model

In [9]:
# Configure the modal decision tree, bind it to the training data and fit it.
model = ModalDecisionTree(; relations=:IA7, features=[minimum, maximum])
mach = fit!(machine(model, X_train, y_train))

# `learned_dt_tree` is the fitted machine itself.
learned_dt_tree = mach

# Print the learned tree.
report(learned_dt_tree).printmodel();

┌ Info: Precomputing logiset...
└ @ ModalDecisionTrees.MLJInterface /home/paso/.julia/packages/ModalDecisionTrees/ePtMQ/src/interfaces/MLJ/wrapdataset.jl:135
┌ Info: Training machine(ModalDecisionTree(max_depth = nothing, …), …).
└ @ MLJBase /home/paso/.julia/packages/MLJBase/7nGJF/src/machines.jl:499


▣ (⟨G⟩min[mel6] ≥ 3.152526691682754e-5)
├✔ (⟨G⟩(min[mel6] ≥ 3.152526691682754e-5 ∧ min[mel7] < 1.4169648430536683e-5))
│├✔ Pneumonia
│└✘ Healthy
└✘ Pneumonia


### Model inspection & rule study

In [11]:
# Run the test data through the fitted tree; only the second returned value
# (the tree) is kept.
fit_report = report(mach)
_, mtree = fit_report.sprinkle(X_test, y_test)

# Translate the tree into its Sole representation and print it with metrics.
sole_dt = mtree |> ModalDecisionTrees.translate

printmodel(sole_dt; show_metrics=true);

▣ {1}(⟨G⟩(min[V6] ≥ 3.152526691682754e-5))
├✔ {1}(⟨G⟩((min[V6] ≥ 3.152526691682754e-5) ∧ (min[V7] < 1.4169648430536683e-5)))
│├✔ Pneumonia : (ninstances = 4, ncovered = 4, confidence = 0.5, lift = 1.0)
│└✘ Healthy : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
└✘ Pneumonia : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)


### Extract rules that are at least as good as a random baseline model

In [12]:
# Keep only the rules with lift >= 1.0; no minimum on the number of instances.
interesting_rules = listrules(sole_dt; min_lift=1.0, min_ninstances=0);

# Print each selected rule together with its metrics.
foreach(interesting_rules) do rule
    printmodel(rule; show_metrics=true)
end

▣ {1}(⟨G⟩((min[V6] ≥ 3.152526691682754e-5) ∧ (min[V7] < 1.4169648430536683e-5)))  ↣  Pneumonia : (ninstances = 4, ncovered = 4, coverage = 1.0, confidence = 0.5, lift = 1.0, natoms = 2)


### Simplify rules while extracting and prettify result

In [13]:
# Same selection as before, with `normalize = true` passed to `listrules`.
interesting_rules = listrules(
    sole_dt;
    min_lift=1.0,
    min_ninstances=0,
    normalize=true,
);

# Print each rule with metrics, passing `threshold_digits = 2` to the formatter.
foreach(interesting_rules) do rule
    printmodel(rule; show_metrics=true, syntaxstring_kwargs=(; threshold_digits=2))
end

▣ {1}(⟨G⟩(min[V6] ≥ 0.0 ∧ min[V7] < 0.0))  ↣  Pneumonia : (ninstances = 4, ncovered = 4, coverage = 1.0, confidence = 0.5, lift = 1.0, natoms = 2)


### Directly access rule metrics

In [14]:
# Collect the metrics of every selected rule.
map(readmetrics, listrules(sole_dt; min_lift=1.0, min_ninstances=0))

1-element Vector{@NamedTuple{ninstances::Int64, ncovered::Int64, coverage::Float64, confidence::Float64, lift::Float64, natoms::Int64}}:
 (ninstances = 4, ncovered = 4, coverage = 1.0, confidence = 0.5, lift = 1.0, natoms = 2)