# ItaData2024
Propositional example: healthy vs. pneumonia classification of respiratory sounds

In [1]:
using MLJ
using MLJDecisionTreeInterface
using Sole, SoleDecisionTreeInterface
using CategoricalArrays
using DataFrames, JLD2, CSV
using StatsBase, Statistics
using Catch22, Audio911

### Open .jld2 file
The file contains 504 respiratory-sound samples, labeled with 2 classes: healthy and pneumonia.

In [2]:
ds_path = "/datasets/respiratory_Healthy_Pneumonia"

# Read the (features, labels) pair stored under "dataframe_validated".
# The do-block form of `jldopen` closes the file as soon as the block returns,
# even if reading throws.
x, y = jldopen(string(@__DIR__, ds_path, ".jld2"), "r") do d
    d["dataframe_validated"]
end
# Validate the loaded data explicitly (`@assert` may be disabled and is meant
# for internal invariants, not input checks).
x isa DataFrame ||
    throw(ArgumentError("expected `x` to be a DataFrame, got $(typeof(x))"))

### Audio feature extraction function
This function is called on every audio sample and extracts 51 features:
- 26 bands of the mel spectrogram,
- 13 MFCC coefficients,
- 12 further features: spectral centroid, crest, entropy, f0 (fundamental frequency), flatness, flux, kurtosis, rolloff, skewness, decrease, slope, spread.

In [3]:
"""
    nan_replacer!(x::AbstractArray{<:AbstractFloat})

Replace every NaN in `x` with zero of `x`'s element type, in place, and return `x`.
Works for any floating-point element type (Float64, Float32, ...).
"""
nan_replacer!(x::AbstractArray{<:AbstractFloat}) = replace!(v -> isnan(v) ? zero(v) : v, x)

"""
    afe(x::AbstractVector{Float64}; get_only_melfreq=false, <pipeline keywords>...)

Run the Audio911 feature-extraction pipeline on the audio signal `x`.

The result is a matrix built by horizontally concatenating, in this order:
- the transposed mel spectrogram (`nbands` columns, 26 by default);
- the transposed MFCCs (`ncoeffs` columns, 13 by default);
- f0;
- the spectral centroid, crest, entropy, flatness, flux, kurtosis, rolloff,
  skewness, decrease, slope and spread.

With the default settings the result has 51 columns. NaN entries are replaced by
`0.0` via `nan_replacer!`. Rows presumably index analysis frames; TODO confirm
against Audio911's output layout.

If `get_only_melfreq=true`, only the mel filterbank is built and its `freq` field
is returned instead. This is used to label the mel columns. Note that the kind of
value returned therefore depends on this flag.

Every pipeline setting is a keyword argument. The defaults are the values this
notebook uses.
"""
function afe(
    x::AbstractVector{Float64};
    get_only_melfreq=false,
    # audio module
    sr=8000,
    norm=true,
    # stft module
    stft_length=256,
    win_type=(:hann, :periodic),
    win_length=256,
    overlap_length=128,
    stft_norm=:power,                       # :power, :magnitude, :pow2mag
    # mel filterbank module
    nbands=26,
    scale=:mel_htk,                         # :mel_htk, :mel_slaney, :erb, :bark
    melfb_norm=:bandwidth,                  # :bandwidth, :area, :none
    freq_range=(300, round(Int, sr / 2)),   # default upper bound: Nyquist frequency
    # mel spectrogram module
    db_scale=false,
    # mfcc module
    ncoeffs=13,
    rectification=:log,                     # :log, :cubic_root
    dither=true,
    # f0 module
    method=:nfc,
    f0_range=(50, 400),
)
    # audio module
    audio = load_audio(file=x, sr=sr, norm=norm)

    # stft module
    stftspec = get_stft(
        audio=audio,
        stft_length=stft_length,
        win_type=win_type,
        win_length=win_length,
        overlap_length=overlap_length,
        norm=stft_norm,
    )

    # mel filterbank module
    melfb = get_melfb(
        stft=stftspec,
        nbands=nbands,
        scale=scale,
        norm=melfb_norm,
        freq_range=freq_range,
    )

    # Early exit: the caller only wants the mel band frequencies (e.g. for column names).
    get_only_melfreq && return melfb.freq

    # mel spectrogram module
    melspec = get_melspec(stft=stftspec, fbank=melfb, db_scale=db_scale)

    # mfcc module
    mfcc = get_mfcc(
        source=melspec,
        ncoeffs=ncoeffs,
        rectification=rectification,
        dither=dither,
    )

    # f0 module
    f0 = get_f0(source=stftspec, method=method, freq_range=f0_range)

    # spectral features module (same frequency range as the mel filterbank)
    spect = get_spectrals(source=stftspec, freq_range=freq_range)

    # Transpose the mel/MFCC matrices so that bands/coefficients become columns,
    # then append each remaining descriptor as one extra column.
    x_features = hcat(
        melspec.spec',
        mfcc.mfcc',
        f0.f0,
        spect.centroid,
        spect.crest,
        spect.entropy,
        spect.flatness,
        spect.flux,
        spect.kurtosis,
        spect.rolloff,
        spect.skewness,
        spect.decrease,
        spect.slope,
        spect.spread,
    )

    nan_replacer!(x_features)

    return x_features
end

afe (generic function with 1 method)

### Compute DataFrame of features

In [4]:
# Frequencies associated with each mel band (rounded to Hz), used to label the mel columns.
freq = round.(Int, afe(x[1, :audio]; get_only_melfreq=true))

# Column names, in the same order as the columns of `afe`'s output matrix.
# The number of mel columns follows `freq`, so it stays in sync with `afe`'s `nbands`.
col_names = [
    ["mel$i=$(f)Hz" for (i, f) in enumerate(freq)]...,
    ["mfcc$i" for i in 1:13]...,  # must match `ncoeffs` used by `afe`
    "f0", "cntrd", "crest", "entrp", "flatn", "flux", "kurts", "rllff", "skwns", "decrs", "slope", "sprd"
]
# Each cell holds one feature's full time series for one recording.
X = DataFrame([name => Vector{Float64}[] for name in col_names])

# NOTE(review): `features` is not used by any visible later cell; candidate for removal.
features = [minimum, maximum]

for audio in x.audio
    feats = afe(audio)
    # Fail with a clear message if `afe`'s layout and `col_names` drift apart.
    size(feats, 2) == length(col_names) || throw(DimensionMismatch(
        "afe returned $(size(feats, 2)) feature columns, expected $(length(col_names))"))
    push!(X, collect(eachcol(feats)))
end

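### Data compression for propositional analysis
Each per-recording feature time series is collapsed into a single scalar (here its median, `catch9[4]`). This turns the data into a flat, propositional table, which is then split 80/20 into training and test sets.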

In [5]:
# Candidate summary functions: basic statistics plus four Catch22 features.
catch9 = [
    maximum,
    minimum,
    StatsBase.mean,
    median,
    std,
    Catch22.SB_BinaryStats_mean_longstretch1,
    Catch22.SB_BinaryStats_diff_longstretch0,
    Catch22.SB_MotifThree_quantile_hh,
    Catch22.SB_TransitionMatrix_3ac_sumdiagcov,
]

# Collapse every time-series cell of `X` into one scalar with the chosen summary
# (catch9[4] == median). `mapcols` keeps the column names and builds each column once,
# avoiding the quadratic cost of growing a matrix with `hcat` in a loop.
summary_fn = catch9[4]
Xc = mapcols(col -> summary_fn.(col), X);

yc = CategoricalArray{String,1,UInt32}(y);

train_ratio = 0.8
# Fixed seed so the shuffled train/test split (and all downstream results) is reproducible.
split_seed = 11

train, test = partition(eachindex(yc), train_ratio; shuffle=true, rng=split_seed)
X_train, y_train = Xc[train, :], yc[train]
X_test, y_test = Xc[test, :], yc[test]

println("Training set size: ", size(X_train), " - ", length(y_train))
println("Test set size: ", size(X_test), " - ", length(y_test))

Training set size: (403, 51) - 403
Test set size: (101, 51) - 101


### Train a model

In [6]:
# Load the DecisionTree.jl classifier through MLJ, fit it on the training split,
# and keep the raw fitted tree for conversion into a Sole model.
Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree
model = Tree(; max_depth = -1)   # -1: no depth limit (DecisionTree.jl convention)
mach = machine(model, X_train, y_train)
fit!(mach)
learned_dt_tree = fitted_params(mach).tree

┌ Info: For silent loading, specify `verbosity=0`. 
└ @ Main /home/paso/.julia/packages/MLJModels/8W54X/src/loading.jl:159


import MLJDecisionTreeInterface ✔


┌ Info: Training machine(DecisionTreeClassifier(max_depth = -1, …), …).
└ @ MLJBase /home/paso/.julia/packages/MLJBase/7nGJF/src/machines.jl:499


mfcc3 < 0.4613
├─ decrs < -1.063
│  ├─ mfcc2 < 4.63
│  │  ├─ Pneumonia (15/15)
│  │  └─ Healthy (3/3)
│  └─ mfcc12 < 0.1095
│     ├─ mfcc3 < 0.3804
│     │  ├─ Healthy (88/88)
│     │  └─ mfcc5 < 0.2311
│     │     ├─ Pneumonia (5/5)
│     │     └─ Healthy (7/7)
│     └─ mfcc3 < 0.06346
│        ├─ mfcc7 < -0.4175
│        │  ├─ Healthy (16/16)
│        │  └─ rllff < 1953.0
│        │     ⋮
│        │     
│        └─ mfcc5 < 0.03963
│           ├─ Pneumonia (1/1)
│           └─ Healthy (34/34)
└─ mel24=3257Hz < 1.472e-8
   ├─ Healthy (13/13)
   └─ mfcc4 < 0.6907
      ├─ mfcc5 < 0.332
      │  ├─ mfcc4 < 0.485
      │  │  ├─ slope < -2.982e-7
      │  │  │  ⋮
      │  │  │  
      │  │  └─ Pneumonia (49/49)
      │  └─ entrp < 0.4341
      │     ├─ Pneumonia (1/1)
      │     └─ Healthy (8/8)
      └─ Pneumonia (66/66)


### Model inspection & rule study

In [7]:
# Convert the fitted DecisionTree.jl tree into a symbolic Sole model.
sole_dt = learned_dt_tree |> solemodel
# Route the test instances through the model so that per-branch test metrics are available.
apply!(sole_dt, X_test, y_test)
# Print the model annotated with those metrics.
printmodel(sole_dt; show_metrics = true);

▣ mfcc3 < 0.4612770678225398
├✔ decrs < -1.0634279559400643
│├✔ mfcc2 < 4.6296118691487536
││├✔ Pneumonia : (ninstances = 1, ncovered = 1, confidence = 1.0, lift = 1.0)
││└✘ Healthy : (ninstances = 1, ncovered = 1, confidence = 1.0, lift = 1.0)
│└✘ mfcc12 < 0.10953755873124793
│ ├✔ mfcc3 < 0.3804128595447388
│ │├✔ Healthy : (ninstances = 24, ncovered = 24, confidence = 0.96, lift = 1.0)
│ │└✘ mfcc5 < 0.23105512737108047
│ │ ├✔ Pneumonia : (ninstances = 3, ncovered = 3, confidence = 0.67, lift = 1.0)
│ │ └✘ Healthy : (ninstances = 2, ncovered = 2, confidence = 1.0, lift = 1.0)
│ └✘ mfcc3 < 0.0634642779198335
│  ├✔ mfcc7 < -0.41747629090722443
│  │├✔ Healthy : (ninstances = 5, ncovered = 5, confidence = 1.0, lift = 1.0)
│  │└✘ rllff < 1953.125
│  │ ├✔ mfcc8 < 0.08466379378597842
│  │ │├✔ Healthy : (ninstances = 2, ncovered = 2, confidence = 1.0, lift = 1.0)
│  │ │└✘ mel16=1802Hz < 7.687627070720451e-7
│  │ │ ├✔ mfcc13 < 0.010784983094397974
│  │ │ │├✔ Pneumonia : (ninstances = 0, ncovere

### Extract rules that are at least as good as a random baseline model

In [8]:
# Keep rules with lift >= 1 (no worse than the baseline), whatever their test coverage.
interesting_rules = listrules(sole_dt; min_lift = 1.0, min_ninstances = 0);
foreach(rule -> printmodel(rule; show_metrics = true), interesting_rules);

▣ (mfcc3 < 0.4612770678225398) ∧ (decrs < -1.0634279559400643) ∧ (mfcc2 < 4.6296118691487536)  ↣  Pneumonia : (ninstances = 101, ncovered = 1, coverage = 0.01, confidence = 1.0, lift = 2.35, natoms = 3)
▣ (mfcc3 < 0.4612770678225398) ∧ (decrs < -1.0634279559400643) ∧ (¬(mfcc2 < 4.6296118691487536))  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.01, confidence = 1.0, lift = 1.74, natoms = 3)
▣ (mfcc3 < 0.4612770678225398) ∧ (¬(decrs < -1.0634279559400643)) ∧ (mfcc12 < 0.10953755873124793) ∧ (mfcc3 < 0.3804128595447388)  ↣  Healthy : (ninstances = 101, ncovered = 24, coverage = 0.24, confidence = 0.96, lift = 1.67, natoms = 4)
▣ (mfcc3 < 0.4612770678225398) ∧ (¬(decrs < -1.0634279559400643)) ∧ (mfcc12 < 0.10953755873124793) ∧ (¬(mfcc3 < 0.3804128595447388)) ∧ (mfcc5 < 0.23105512737108047)  ↣  Pneumonia : (ninstances = 101, ncovered = 3, coverage = 0.03, confidence = 0.67, lift = 1.57, natoms = 5)
▣ (mfcc3 < 0.4612770678225398) ∧ (¬(decrs < -1.0634279559400643)) ∧ (mfcc12 < 0

### Simplify rules while extracting and prettify result

In [9]:
# Same selection, but with normalized antecedents and thresholds printed to 2 digits.
interesting_rules = listrules(sole_dt; min_lift = 1.0, min_ninstances = 0, normalize = true);
foreach(interesting_rules) do rule
    printmodel(rule; show_metrics = true, syntaxstring_kwargs = (; threshold_digits = 2))
end;

▣ (mfcc3 < 0.46) ∧ (decrs < -1.06) ∧ (mfcc2 < 4.63)  ↣  Pneumonia : (ninstances = 101, ncovered = 1, coverage = 0.01, confidence = 1.0, lift = 2.35, natoms = 3)
▣ (mfcc3 < 0.46) ∧ (decrs < -1.06) ∧ (mfcc2 ≥ 4.63)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.01, confidence = 1.0, lift = 1.74, natoms = 3)
▣ (mfcc3 < 0.38) ∧ (decrs ≥ -1.06) ∧ (mfcc12 < 0.11)  ↣  Healthy : (ninstances = 101, ncovered = 24, coverage = 0.24, confidence = 0.96, lift = 1.67, natoms = 3)
▣ (mfcc3 ∈ [0.38,0.46)) ∧ (decrs ≥ -1.06) ∧ (mfcc12 < 0.11) ∧ (mfcc5 < 0.23)  ↣  Pneumonia : (ninstances = 101, ncovered = 3, coverage = 0.03, confidence = 0.67, lift = 1.57, natoms = 4)
▣ (mfcc3 ∈ [0.38,0.46)) ∧ (decrs ≥ -1.06) ∧ (mfcc12 < 0.11) ∧ (mfcc5 ≥ 0.23)  ↣  Healthy : (ninstances = 101, ncovered = 2, coverage = 0.02, confidence = 1.0, lift = 1.74, natoms = 4)
▣ (mfcc3 < 0.06) ∧ (decrs ≥ -1.06) ∧ (mfcc12 ≥ 0.11) ∧ (mfcc7 < -0.42)  ↣  Healthy : (ninstances = 101, ncovered = 5, coverage = 0.05, confidence =

### Directly access rule metrics

In [10]:
map(readmetrics, listrules(sole_dt; min_lift = 1.0, min_ninstances = 0))

18-element Vector{@NamedTuple{ninstances::Int64, ncovered::Int64, coverage::Float64, confidence::Float64, lift::Float64, natoms::Int64}}:
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 2.3488372093023258, natoms = 3)
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.7413793103448276, natoms = 3)
 (ninstances = 101, ncovered = 24, coverage = 0.2376237623762376, confidence = 0.9583333333333334, lift = 1.66882183908046, natoms = 4)
 (ninstances = 101, ncovered = 3, coverage = 0.0297029702970297, confidence = 0.6666666666666666, lift = 1.565891472868217, natoms = 5)
 (ninstances = 101, ncovered = 2, coverage = 0.019801980198019802, confidence = 1.0, lift = 1.7413793103448276, natoms = 5)
 (ninstances = 101, ncovered = 5, coverage = 0.04950495049504951, confidence = 1.0, lift = 1.7413793103448276, natoms = 5)
 (ninstances = 101, ncovered = 2, coverage = 0.019801980198019802, confidence = 1.0, lift = 1.7

### Show rules with an additional metric (syntax height of the rule's antecedent)

In [11]:
# Print the rules ordered by their metrics, unrounded, plus the antecedent's syntax height.
let height_metric = (; height = r -> SoleLogics.height(antecedent(r)))
    metric_opts = (; round_digits = nothing, additional_metrics = height_metric)
    for rule in sort(interesting_rules; by = readmetrics)
        printmodel(rule; show_metrics = metric_opts)
    end
end;

▣ (mfcc3 < 0.4612770678225398) ∧ (decrs < -1.0634279559400643) ∧ (mfcc2 ≥ 4.6296118691487536)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.7413793103448276, natoms = 3, height = 2)
▣ (mfcc3 < 0.0634642779198335) ∧ (decrs ≥ -1.0634279559400643) ∧ (mfcc12 ≥ 0.10953755873124793) ∧ (mfcc7 ≥ -0.41747629090722443) ∧ (rllff < 1953.125) ∧ (mfcc8 ≥ 0.08466379378597842) ∧ (mel16=1802Hz < 7.687627070720451e-7) ∧ (mfcc13 ≥ 0.010784983094397974)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.7413793103448276, natoms = 8, height = 7)
▣ (mfcc3 < 0.4612770678225398) ∧ (decrs < -1.0634279559400643) ∧ (mfcc2 < 4.6296118691487536)  ↣  Pneumonia : (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 2.3488372093023258, natoms = 3, height = 2)
▣ (mfcc3 ∈ [0.0634642779198335,0.4612770678225398)) ∧ (decrs ≥ -1.0634279559400643) ∧ (mfcc12 ≥ 0.109537558

### Pretty table of rules and their metrics

In [12]:
# Tabulate every rule with unrounded metrics and the antecedent's syntax height.
let height_metric = (; height = r -> SoleLogics.height(antecedent(r)))
    metricstable(
        interesting_rules;
        metrics_kwargs = (; round_digits = nothing, additional_metrics = height_metric),
    )
end

┌───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┬────────────┬────────────┬──────────┬────────────┬────────────┬─────────┬────────┬────────┐
│[33;1m                                                                                                                                                                                                                                                    Antecedent [0m│[33;1m Consequent [0m│[33;1m ninstances [0m│[33;1m ncovered [0m│[33;1m   coverage [0m│[33;1m confidence [0m│[33;1m    lift [0m│[33;1m natoms [0m│[33;1m height [0m│
├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────