In [1]:
using MLJ
using MLJDecisionTreeInterface
using Sole, SoleDecisionTreeInterface
using CategoricalArrays
using DataFrames, JLD2, CSV
using StatsBase, Statistics
using Catch22, Audio911

### Open .jld2 file
the file contains 504 samples of respiratory sound, labeled with 2 classes: healty and pneumonia

In [2]:
ds_path = "/home/paso/Documents/Aclai/audio-rules2024/datasets/respiratory_Healthy_Pneumonia"

d = jldopen(string(ds_path, ".jld2"))
x, y = d["dataframe_validated"]
@assert x isa DataFrame
close(d)

### Audio features extraction function
This function is called for every audio sample and extract 51 features:
26 bands of the mel spectrogram,
13 coefficients of the mfcc
12 spectral features: centroid, crest, entropy, f0, flatness, flux, kurtosis, rolloff, skewness, decrease, slope, spread

In [3]:
nan_replacer!(x::AbstractArray{Float64}) = replace!(x, NaN => 0.0)

function afe(x::AbstractVector{Float64})
    # -------------------------------- parameters -------------------------------- #
    # audio module
    sr = 8000
    norm = true
    speech_detection = false
    # stft module
    stft_length = 256
    win_type = (:hann, :periodic)
    win_length = 256
    overlap_length = 128
    stft_norm = :power                      # :power, :magnitude, :pow2mag
    # mel filterbank module
    nbands = 26
    scale = :mel_htk                        # :mel_htk, :mel_slaney, :erb, :bark
    melfb_norm = :bandwidth                 # :bandwidth, :area, :none
    freq_range = (300, round(Int, sr / 2))
    # mel spectrogram module
    db_scale = false
    # mfcc module
    ncoeffs = 13
    rectification = :log                    # :log, :cubic_root
    dither = true
    # f0 module
    method = :nfc
    f0_range = (50, 400)

    # --------------------------------- functions -------------------------------- #
    # audio module
    audio = load_audio(
        file=x, 
        sr=sr, 
        norm=norm,
    );

    stftspec = get_stft(
        audio=audio, 
        stft_length=stft_length,
        win_type=win_type,
        win_length=win_length,
        overlap_length=overlap_length,
        norm=stft_norm
    );

    # mel filterbank module
    melfb = get_melfb(
        stft=stftspec,
        nbands=nbands,
        scale=scale,
        norm=melfb_norm,
        freq_range=freq_range
    );

    # mel spectrogram module
    melspec =  get_melspec(
        stft=stftspec, 
        fbank=melfb,
        db_scale=db_scale
    );

    # mfcc module
    mfcc = get_mfcc(
        source=melspec,
        ncoeffs=ncoeffs,
        rectification=rectification,
        dither=dither,
    );

    # f0 module
    f0 = get_f0(
        source=stftspec,
        method=method,
        freq_range=f0_range
    );

    # spectral features module
    spect = get_spectrals(
        source=stftspec,
        freq_range=freq_range
    );

    x_features = hcat(
        melspec.spec',
        mfcc.mfcc',
        f0.f0,
        spect.centroid,
        spect.crest,
        spect.entropy,
        spect.flatness,
        spect.flux,
        spect.kurtosis,
        spect.rolloff,
        spect.skewness,
        spect.decrease,
        spect.slope,
        spect.spread
    );

    nan_replacer!(x_features)

    return x_features
end

afe (generic function with 1 method)

### Compute DataFrame of features

In [11]:
colnames = [
    ["mel$i" for i in 1:26]...,
    ["mfcc$i" for i in 1:13]...,
    "f0", "cntrd", "crest", "entrp", "flatn", "flux", "kurts", "rllff", "skwns", "decrs", "slope", "sprd"
]

X = DataFrame([name => Vector{Float64}[] for name in colnames])

for i in 1:nrow(x)
    push!(X, collect(eachcol(afe(x[i, :audio]))))
end

### Data compression for propositional analysis

In [12]:
catch9 = [
    maximum,
    minimum,
    StatsBase.mean,
    median,
    std,
    Catch22.SB_BinaryStats_mean_longstretch1,
    Catch22.SB_BinaryStats_diff_longstretch0,
    Catch22.SB_MotifThree_quantile_hh,
    Catch22.SB_TransitionMatrix_3ac_sumdiagcov,
]

t = zeros(size(X, 1))
for i in eachcol(X)
    feature = map(x -> catch9[4](x...), eachrow(i))
    global t = hcat(t, feature)
end
Xc = DataFrame(t[:, 2:end], names(X));

yc = CategoricalArray(y);

In [13]:
train_ratio = 0.8

train, test = partition(eachindex(y), train_ratio, shuffle=true)
X_train, y_train = Xc[train, :], yc[train]
X_test, y_test = Xc[test, :], yc[test]

println("Training set size: ", size(X_train), " - ", length(y_train))
println("Test set size: ", size(X_test), " - ", length(y_test))

Training set size: (403, 51) - 403
Test set size: (101, 51) - 101


### Train a model

In [14]:
learned_dt_tree = begin
    Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree
    model = Tree(max_depth=-1, )
    mach = machine(model, X_train, y_train)
    fit!(mach)
    fitted_params(mach).tree
end

import MLJDecisionTreeInterface ✔


┌ Info: For silent loading, specify `verbosity=0`. 
└ @ Main /home/paso/.julia/packages/MLJModels/jKlfT/src/loading.jl:159
┌ Info: Training machine(DecisionTreeClassifier(max_depth = -1, …), …).
└ @ MLJBase /home/paso/.julia/packages/MLJBase/7nGJF/src/machines.jl:499


a29 < 0.4999
├─ a49 < -1.063
│  ├─ a28 < 4.565
│  │  ├─ Pneumonia (16/16)
│  │  └─ Healthy (4/4)
│  └─ a31 < 0.08705
│     ├─ a34 < 0.2875
│     │  ├─ a43 < 0.7714
│     │  │  ├─ a34 < -0.0256
│     │  │  │  ⋮
│     │  │  │  
│     │  │  └─ Healthy (3/3)
│     │  └─ Healthy (7/7)
│     └─ a29 < -0.5518
│        ├─ a51 < 435.1
│        │  ├─ Healthy (27/27)
│        │  └─ a38 < 0.1336
│        │     ⋮
│        │     
│        └─ a30 < 0.45
│           ├─ a5 < 1.038e-6
│           │  ⋮
│           │  
│           └─ a31 < 0.1821
│              ⋮
│              
└─ a24 < 1.419e-8
   ├─ Healthy (12/12)
   └─ a30 < 0.6907
      ├─ a31 < 0.3639
      │  ├─ a38 < 0.139
      │  │  ├─ a41 < 648.4
      │  │  │  ⋮
      │  │  │  
      │  │  └─ a32 < -0.04129
      │  │     ⋮
      │  │     
      │  └─ Healthy (7/7)
      └─ Pneumonia (66/66)


### Convert to Sole model


In [15]:
sole_dt = solemodel(learned_dt_tree)

▣ a29 < 0.49994379966299807
├✔ a49 < -1.0634279559400643
│├✔ a28 < 4.565094388756867
││├✔ Pneumonia
││└✘ Healthy
│└✘ a31 < 0.08705196308151007
│ ├✔ a34 < 0.2874586422195802
│ │├✔ a43 < 0.7713892442949303
│ ││├✔ a34 < -0.025600915581680234
│ │││├✔ Healthy
│ │││└✘ Pneumonia
│ ││└✘ Healthy
│ │└✘ Healthy
│ └✘ a29 < -0.5517645980120216
│  ├✔ a51 < 435.1042245447893
│  │├✔ Healthy
│  │└✘ a38 < 0.13355575106403889
│  │ ├✔ a33 < -0.285050658562475
│  │ │├✔ Healthy
│  │ │└✘ a43 < 0.627488672274678
│  │ │ ├✔ Healthy
│  │ │ └✘ Pneumonia
│  │ └✘ Pneumonia
│  └✘ a30 < 0.4499577415821601
│   ├✔ a5 < 1.037799461931649e-6
│   │├✔ Pneumonia
│   │└✘ Healthy
│   └✘ a31 < 0.1821043922028262
│    ├✔ a37 < 0.15118073771035784
│    │├✔ Pneumonia
│    │└✘ Healthy
│    └✘ a37 < 0.010630296835023388
│     ├✔ Healthy
│     └✘ a36 < 0.027685621343476227
│      ├✔ Pneumonia
│      └✘ Healthy
└✘ a24 < 1.418545024653986e-8
 ├✔ Healthy
 └✘ a30 < 0.6906847999463177
  ├✔ a31 < 0.3639340311977496
  │├✔ a38 < 0.139011738

### Model inspection & rule study

# GIO!!! se metto show_metrics = true mi da errore!

In [16]:
# Make test instances flow into the model, so that test metrics can, then, be computed.
apply!(sole_dt, X_test, y_test);
# Print Sole model
printmodel(sole_dt; show_metrics = false);

▣ a29 < 0.49994379966299807
├✔ a49 < -1.0634279559400643
│├✔ a28 < 4.565094388756867
││├✔ Pneumonia
││└✘ Healthy
│└✘ a31 < 0.08705196308151007
│ ├✔ a34 < 0.2874586422195802
│ │├✔ a43 < 0.7713892442949303
│ ││├✔ a34 < -0.025600915581680234
│ │││├✔ Healthy
│ │││└✘ Pneumonia
│ ││└✘ Healthy
│ │└✘ Healthy
│ └✘ a29 < -0.5517645980120216
│  ├✔ a51 < 435.1042245447893
│  │├✔ Healthy
│  │└✘ a38 < 0.13355575106403889
│  │ ├✔ a33 < -0.285050658562475
│  │ │├✔ Healthy
│  │ │└✘ a43 < 0.627488672274678
│  │ │ ├✔ Healthy
│  │ │ └✘ Pneumonia
│  │ └✘ Pneumonia
│  └✘ a30 < 0.4499577415821601
│   ├✔ a5 < 1.037799461931649e-6
│   │├✔ Pneumonia
│   │└✘ Healthy
│   └✘ a31 < 0.1821043922028262
│    ├✔ a37 < 0.15118073771035784
│    │├✔ Pneumonia
│    │└✘ Healthy
│    └✘ a37 < 0.010630296835023388
│     ├✔ Healthy
│     └✘ a36 < 0.027685621343476227
│      ├✔ Pneumonia
│      └✘ Healthy
└✘ a24 < 1.418545024653986e-8
 ├✔ Healthy
 └✘ a30 < 0.6906847999463177
  ├✔ a31 < 0.3639340311977496
  │├✔ a38 < 0.139011738

### Extract rules

In [17]:
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0);

MethodError: MethodError: no method matching readmetrics(::Rule{String15})

Closest candidates are:
  readmetrics(!Matched::Rule{L}; round_digits, class_share_map, additional_metrics, kwargs...) where L<:Union{AbstractFloat, Integer, String, CategoricalValue}
   @ SoleModels ~/.julia/packages/SoleModels/kTdJO/src/evaluate.jl:118
  readmetrics(!Matched::SoleModels.LeafModel{L}; class_share_map, round_digits, additional_metrics) where L<:Union{AbstractFloat, Integer, String, CategoricalValue}
   @ SoleModels ~/.julia/packages/SoleModels/kTdJO/src/evaluate.jl:67
