# ItaData2024
propositional example

In [1]:
using MLJ
using MLJDecisionTreeInterface
using Sole, SoleDecisionTreeInterface
using CategoricalArrays
using DataFrames, JLD2, CSV
using StatsBase, Statistics
using Catch22, Audio911

### Open .jld2 file
the file contains 504 samples of respiratory sound, labeled with 2 classes: healty and pneumonia

In [2]:
ds_path = "/datasets/respiratory_Healthy_Pneumonia"

d = jldopen(string((@__DIR__), ds_path, ".jld2"))
x, y = d["dataframe_validated"]
@assert x isa DataFrame
close(d)

### Audio features extraction function
This function is called for every audio sample and extract 51 features:
26 bands of the mel spectrogram,
13 coefficients of the mfcc
12 spectral features: centroid, crest, entropy, f0, flatness, flux, kurtosis, rolloff, skewness, decrease, slope, spread

In [None]:
nan_replacer!(x::AbstractArray{Float64}) = replace!(x, NaN => 0.0)

function afe(x::AbstractVector{Float64}; get_only_melfreq=false)
    # -------------------------------- parameters -------------------------------- #
    # audio module
    sr = 8000
    norm = true
    speech_detection = false
    # stft module
    stft_length = 256
    win_type = (:hann, :periodic)
    win_length = 256
    overlap_length = 128
    stft_norm = :power                      # :power, :magnitude, :pow2mag
    # mel filterbank module
    nbands = 26
    scale = :mel_htk                        # :mel_htk, :mel_slaney, :erb, :bark
    melfb_norm = :bandwidth                 # :bandwidth, :area, :none
    freq_range = (300, round(Int, sr / 2))
    # mel spectrogram module
    db_scale = false
    # mfcc module
    ncoeffs = 13
    rectification = :log                    # :log, :cubic_root
    dither = true
    # f0 module
    method = :nfc
    f0_range = (50, 400)

    # --------------------------------- functions -------------------------------- #
    # audio module
    audio = load_audio(
        file=x,
        sr=sr,
        norm=norm,
    );

    stftspec = get_stft(
        audio=audio,
        stft_length=stft_length,
        win_type=win_type,
        win_length=win_length,
        overlap_length=overlap_length,
        norm=stft_norm
    );

    # mel filterbank module
    melfb = get_melfb(
        stft=stftspec,
        nbands=nbands,
        scale=scale,
        norm=melfb_norm,
        freq_range=freq_range
    );

    if get_only_melfreq
        return melfb.freq
    end

    # mel spectrogram module
    melspec =  get_melspec(
        stft=stftspec,
        fbank=melfb,
        db_scale=db_scale
    );

    # mfcc module
    mfcc = get_mfcc(
        source=melspec,
        ncoeffs=ncoeffs,
        rectification=rectification,
        dither=dither,
    );

    # f0 module
    f0 = get_f0(
        source=stftspec,
        method=method,
        freq_range=f0_range
    );

    # spectral features module
    spect = get_spectrals(
        source=stftspec,
        freq_range=freq_range
    );

    x_features = hcat(
        melspec.spec',
        mfcc.mfcc',
        f0.f0,
        spect.centroid,
        spect.crest,
        spect.entropy,
        spect.flatness,
        spect.flux,
        spect.kurtosis,
        spect.rolloff,
        spect.skewness,
        spect.decrease,
        spect.slope,
        spect.spread
    );

    nan_replacer!(x_features)

    return x_features
end

### Compute DataFrame of features

In [4]:
color_code = Dict(:red => 31, :green => 32, :yellow => 33, :blue => 34, :magenta => 35, :cyan => 36)
freq = round.(Int, afe(x[1, :audio]; get_only_melfreq=true))
r_select = r"\e\[\d+m(.*?)\e\[0m"
variable_names = [
    ["\e[$(color_code[:yellow])mmel$i=$(freq[i])Hz\e[0m" for i in 1:26]...,
    ["\e[$(color_code[:red])mmfcc$i\e[0m" for i in 1:13]...,
    "\e[$(color_code[:green])mf0\e[0m", "\e[$(color_code[:cyan])mcntrd\e[0m", "\e[$(color_code[:cyan])mcrest\e[0m",
    "\e[$(color_code[:cyan])mentrp\e[0m", "\e[$(color_code[:cyan])mflatn\e[0m", "\e[$(color_code[:cyan])mflux\e[0m",
    "\e[$(color_code[:cyan])mkurts\e[0m", "\e[$(color_code[:cyan])mrllff\e[0m", "\e[$(color_code[:cyan])mskwns\e[0m",
    "\e[$(color_code[:cyan])mdecrs\e[0m", "\e[$(color_code[:cyan])mslope\e[0m", "\e[$(color_code[:cyan])msprd\e[0m"
]

X = DataFrame([name => Vector{Float64}[] for name in [match(r_select, v)[1] for v in variable_names]])

for i in 1:nrow(x)
    push!(X, collect(eachcol(afe(x[i, :audio]))))
end

### Data compression for propositional analysis

In [5]:
catch9 = [
    maximum,
    minimum,
    StatsBase.mean,
    median,
    std,
    Catch22.SB_BinaryStats_mean_longstretch1,
    Catch22.SB_BinaryStats_diff_longstretch0,
    Catch22.SB_MotifThree_quantile_hh,
    Catch22.SB_TransitionMatrix_3ac_sumdiagcov,
]

t = zeros(size(X, 1))
for i in eachcol(X)
    feature = map(x -> catch9[4](x...), eachrow(i))
    global t = hcat(t, feature)
end
Xc = DataFrame(t[:, 2:end], names(X));

yc = CategoricalArray{String,1,UInt32}(y);

train_ratio = 0.8

train, test = partition(eachindex(y), train_ratio, shuffle=true)
# train, test = partition(eachindex(yc), train_ratio, shuffle=false) ### Debug
X_train, y_train = Xc[train, :], yc[train]
X_test, y_test = Xc[test, :], yc[test]

println("Training set size: ", size(X_train), " - ", length(y_train))
println("Test set size: ", size(X_test), " - ", length(y_test))

Training set size: (403, 51) - 403
Test set size: (101, 51) - 101


### Train a model

In [6]:
learned_dt_tree = begin
    Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree
    model = Tree(max_depth=-1, )
    mach = machine(model, X_train, y_train)
    fit!(mach)
    fitted_params(mach).tree
end

┌ Info: For silent loading, specify `verbosity=0`. 
└ @ Main /home/paso/.julia/packages/MLJModels/8W54X/src/loading.jl:159


import MLJDecisionTreeInterface ✔


┌ Info: Training machine(DecisionTreeClassifier(max_depth = -1, …), …).
└ @ MLJBase /home/paso/.julia/packages/MLJBase/7nGJF/src/machines.jl:499


mfcc3 < 0.5497
├─ mfcc2 < 2.331
│  ├─ Healthy (45/45)
│  └─ mfcc9 < 0.2035
│     ├─ mfcc8 < 0.09316
│     │  ├─ mfcc6 < -0.04626
│     │  │  ├─ Pneumonia (1/1)
│     │  │  └─ Healthy (27/27)
│     │  └─ mel26=3738Hz < 1.921e-6
│     │     ├─ mfcc7 < -0.01596
│     │     │  ⋮
│     │     │  
│     │     └─ Healthy (14/14)
│     └─ mfcc12 < 0.3
│        ├─ Healthy (43/43)
│        └─ Pneumonia (1/1)
└─ mel25=3491Hz < 1.13e-8
   ├─ Healthy (11/11)
   └─ mfcc3 < 0.9517
      ├─ mel1=359Hz < 0.0002043
      │  ├─ mfcc5 < -0.09931
      │  │  ├─ mel6=710Hz < 1.841e-5
      │  │  │  ⋮
      │  │  │  
      │  │  └─ mfcc7 < -0.08553
      │  │     ⋮
      │  │     
      │  └─ mfcc5 < 0.3077
      │     ├─ mel12=1289Hz < 6.603e-6
      │     │  ⋮
      │     │  
      │     └─ Healthy (7/7)
      └─ mfcc12 < 0.1039
         ├─ Pneumonia (72/72)
         └─ mfcc10 < 0.08356
            ├─ mel1=359Hz < 0.0009976
            │  ⋮
            │  
            └─ Pneumonia (23/23)


### Model inspection & rule study

In [7]:
sole_dt = solemodel(learned_dt_tree)
# Make test instances flow into the model, so that test metrics can, then, be computed.
apply!(sole_dt, X_test, y_test);
# Print Sole model
printmodel(sole_dt; show_metrics = true, variable_names_map = variable_names);

[34m▣[0m [1m[31mmfcc3[0m [1m<[0m[0m 0.5496694323834052
├✔ [1m[31mmfcc2[0m [1m<[0m[0m 2.330758535389905
│├✔ Healthy : (ninstances = 10, ncovered = 10, confidence = 1.0, lift = 1.0)
│└✘ [1m[31mmfcc9[0m [1m<[0m[0m 0.20354855851824852
│ ├✔ [1m[31mmfcc8[0m [1m<[0m[0m 0.09315706434102256
│ │├✔ [1m[31mmfcc6[0m [1m<[0m[0m -0.046255002023545475
│ ││├✔ Pneumonia : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
│ ││└✘ Healthy : (ninstances = 9, ncovered = 9, confidence = 0.78, lift = 1.0)
│ │└✘ [1m[33mmel26=3738Hz[0m [1m<[0m[0m 1.921368089902453e-6
│ │ ├✔ [1m[31mmfcc7[0m [1m<[0m[0m -0.015962461569718516
│ │ │├✔ [1m[33mmel18=2106Hz[0m [1m<[0m[0m 3.305510299482089e-6
│ │ ││├✔ [1m[31mmfcc5[0m [1m<[0m[0m 1.1141034082276149
│ │ │││├✔ [1m[36mcrest[0m [1m<[0m[0m 16.08596821857623
│ │ ││││├✔ [1m[31mmfcc9[0m [1m<[0m[0m 0.13824315599479117
│ │ │││││├✔ Pneumonia : (ninstances = 1, ncovered = 1, confidence = 0.0, lift = NaN)
│ 

### Extract rules that are at least as good as a random baseline model

In [8]:
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0);
printmodel.(interesting_rules; show_metrics = true, variable_names_map = variable_names);

[34m▣[0m ([1m[31mmfcc3[0m [1m<[0m[0m 0.5496694323834052) ∧ ([1m[31mmfcc2[0m [1m<[0m[0m 2.330758535389905)  ↣  Healthy : (ninstances = 101, ncovered = 10, coverage = 0.1, confidence = 1.0, lift = 1.8, natoms = 2)
[34m▣[0m ([1m[31mmfcc3[0m [1m<[0m[0m 0.5496694323834052) ∧ (¬([1m[31mmfcc2[0m [1m<[0m[0m 2.330758535389905)) ∧ ([1m[31mmfcc9[0m [1m<[0m[0m 0.20354855851824852) ∧ ([1m[31mmfcc8[0m [1m<[0m[0m 0.09315706434102256) ∧ (¬([1m[31mmfcc6[0m [1m<[0m[0m -0.046255002023545475))  ↣  Healthy : (ninstances = 101, ncovered = 9, coverage = 0.09, confidence = 0.78, lift = 1.4, natoms = 5)
[34m▣[0m ([1m[31mmfcc3[0m [1m<[0m[0m 0.5496694323834052) ∧ (¬([1m[31mmfcc2[0m [1m<[0m[0m 2.330758535389905)) ∧ ([1m[31mmfcc9[0m [1m<[0m[0m 0.20354855851824852) ∧ (¬([1m[31mmfcc8[0m [1m<[0m[0m 0.09315706434102256)) ∧ ([1m[33mmel26=3738Hz[0m [1m<[0m[0m 1.921368089902453e-6) ∧ ([1m[31mmfcc7[0m [1m<[0m[0m -0.015962461569718516) ∧ (

### Simplify rules while extracting and prettify result

In [9]:
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0, normalize = true);
printmodel.(interesting_rules; show_metrics = true, syntaxstring_kwargs = (; threshold_digits = 2), variable_names_map = variable_names);

[34m▣[0m ([1m[31mmfcc3[0m [1m<[0m[0m 0.55) ∧ ([1m[31mmfcc2[0m [1m<[0m[0m 2.33)  ↣  Healthy : (ninstances = 101, ncovered = 10, coverage = 0.1, confidence = 1.0, lift = 1.8, natoms = 2)
[34m▣[0m ([1m[31mmfcc3[0m [1m<[0m[0m 0.55) ∧ ([1m[31mmfcc2[0m [1m≥[0m[0m 2.33) ∧ ([1m[31mmfcc9[0m [1m<[0m[0m 0.2) ∧ ([1m[31mmfcc8[0m [1m<[0m[0m 0.09) ∧ ([1m[31mmfcc6[0m [1m≥[0m[0m -0.05)  ↣  Healthy : (ninstances = 101, ncovered = 9, coverage = 0.09, confidence = 0.78, lift = 1.4, natoms = 5)
[34m▣[0m ([1m[31mmfcc3[0m [1m<[0m[0m 0.55) ∧ ([1m[31mmfcc2[0m [1m≥[0m[0m 2.33) ∧ ([1m[31mmfcc9[0m [1m<[0m[0m 0.2) ∧ ([1m[31mmfcc8[0m [1m≥[0m[0m 0.09) ∧ ([1m[33mmel26=3738Hz[0m [1m<[0m[0m 0.0) ∧ ([1m[31mmfcc7[0m [1m<[0m[0m -0.02) ∧ ([1m[33mmel18=2106Hz[0m [1m<[0m[0m 0.0) ∧ ([1m[31mmfcc5[0m [1m<[0m[0m 1.11) ∧ ([1m[36mcrest[0m [1m≥[0m[0m 16.09)  ↣  Healthy : (ninstances = 101, ncovered = 12, coverage = 0.12, confidence

### Directly access rule metrics

In [10]:
readmetrics.(listrules(sole_dt; min_lift=1.0, min_ninstances = 0))

17-element Vector{@NamedTuple{ninstances::Int64, ncovered::Int64, coverage::Float64, confidence::Float64, lift::Float64, natoms::Int64}}:
 (ninstances = 101, ncovered = 10, coverage = 0.09900990099009901, confidence = 1.0, lift = 1.8035714285714284, natoms = 2)
 (ninstances = 101, ncovered = 9, coverage = 0.0891089108910891, confidence = 0.7777777777777778, lift = 1.4027777777777777, natoms = 5)
 (ninstances = 101, ncovered = 12, coverage = 0.1188118811881188, confidence = 0.9166666666666666, lift = 1.6532738095238093, natoms = 9)
 (ninstances = 101, ncovered = 2, coverage = 0.019801980198019802, confidence = 0.5, lift = 1.1222222222222222, natoms = 8)
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.8035714285714284, natoms = 9)
 (ninstances = 101, ncovered = 8, coverage = 0.07920792079207921, confidence = 0.625, lift = 1.4027777777777777, natoms = 7)
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift 

### Show rules with an additional metric (syntax height of the rule's antecedent)

In [11]:
printmodel.(sort(interesting_rules, by = readmetrics); show_metrics = (; round_digits = nothing, additional_metrics = (; height = r->SoleLogics.height(antecedent(r)))), variable_names_map = variable_names);

[34m▣[0m ([31mmfcc3[0m ∈ [0.5496694323834052,0.951699595249857)) ∧ ([1m[33mmel25=3491Hz[0m [1m≥[0m[0m 1.1301846692747932e-8) ∧ ([1m[33mmel1=359Hz[0m [1m<[0m[0m 0.00020426911077869796) ∧ ([1m[31mmfcc5[0m [1m≥[0m[0m -0.0993068738544915) ∧ ([1m[31mmfcc7[0m [1m≥[0m[0m -0.08553336842669973) ∧ ([31mmfcc12[0m ∈ [0.15390044342266546,0.15747684201500028))  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.8035714285714284, natoms = 6, height = 5)
[34m▣[0m ([1m[31mmfcc3[0m [1m<[0m[0m 0.5496694323834052) ∧ ([1m[31mmfcc2[0m [1m≥[0m[0m 2.330758535389905) ∧ ([1m[31mmfcc9[0m [1m<[0m[0m 0.08630403429612474) ∧ ([1m[31mmfcc8[0m [1m≥[0m[0m 0.09315706434102256) ∧ ([1m[33mmel26=3738Hz[0m [1m<[0m[0m 1.3876455014298657e-6) ∧ ([1m[31mmfcc7[0m [1m<[0m[0m -0.015962461569718516) ∧ ([1m[33mmel18=2106Hz[0m [1m≥[0m[0m 3.305510299482089e-6)  ↣  Healthy : (ninstances = 101, ncovered = 1, c

### Pretty table of rules and their metrics

In [12]:
metricstable(interesting_rules; variable_names_map = variable_names, metrics_kwargs = (; round_digits = nothing, additional_metrics = (; height = r->SoleLogics.height(antecedent(r)))))

┌───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┬────────────┬────────────┬──────────┬────────────┬────────────┬─────────┬────────┬────────┐
│[33;1m                                                                                                                                                                                                                                                                                                Antecedent [0m│[33;1m Consequent [0m│[33;1m ninstances [0m│[33;1m ncovered [0m│[33;1m   coverage [0m│[33;1m confidence [0m│[33;1m    lift [0m│[33;1m natoms [0m│[33;1m height [0m│
├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────