# ItaData2024
propositional example

In [None]:
using Pkg
Pkg.activate(".")
using MLJ, ModalDecisionTrees
using SoleDecisionTreeInterface, Sole, SoleData
using CategoricalArrays
using DataFrames, JLD2, CSV
using Audio911
using Random
using StatsBase, Catch22
using Test
using Plots

### Settings

In [2]:
sr = 8000
audioparams = (
    sr = sr,
    nfft = 256,
    nbands = 14,
    freq_range = (300, round(Int, sr / 2)),
    db_scale = true,
)

experiment = :pneumonia
# experiment = :bronchiectasis

findhealthy = y -> findall(x -> x == "Healthy", y)
if experiment == :pneumonia
    ds_path = "/datasets/respiratory_Healthy_Pneumonia"
    findsick = y -> findall(x -> x == "Pneumonia", y)
elseif experiment == :bronchiectasis
    ds_path = "/datasets/respiratory_Healthy_Bronchiectasis"
    findsick = y -> findall(x -> x == "Bronchiectasis", y)
else
    error("Unknown type of experiment: $experiment.")
end

d = jldopen(string((@__DIR__), ds_path, ".jld2"))
x, y = d["dataframe_validated"]
@assert x isa DataFrame
close(d)

### Audio features extraction function
This function is called for every audio sample and extract 25 features:
14 bands of the mel spectrogram,
11 spectral features: centroid, crest, entropy, flatness, flux, kurtosis, rolloff, skewness, decrease, slope, spread

In [3]:
function afe(x::AbstractVector{Float64}; sr::Int64, nfft::Int64, nbands::Int64, freq_range::Tuple{Int64, Int64}, db_scale::Bool, get_only_melfreq=false)
    # -------------------------------- parameters -------------------------------- #
    # audio module
    sr = sr
    norm = true
    speech_detection = false
    # stft module
    nfft = nfft
    win_type = (:hann, :periodic)
    win_length = nfft
    overlap_length = round(Int, nfft / 2)
    stft_norm = :power                      # :power, :magnitude, :pow2mag
    # mel filterbank module
    nbands = nbands
    scale = :mel_htk                        # :mel_htk, :mel_slaney, :erb, :bark
    melfb_norm = :bandwidth                 # :bandwidth, :area, :none
    freq_range = freq_range
    # mel spectrogram module
    db_scale = db_scale

    # --------------------------------- functions -------------------------------- #
    # audio module
    audio = load_audio(
        file=x,
        sr=sr,
        norm=norm,
    );

    stftspec = get_stft(
        audio=audio,
        nfft=nfft,
        win_type=win_type,
        win_length=win_length,
        overlap_length=overlap_length,
        norm=stft_norm
    );

    # mel filterbank module
    melfb = get_melfb(
        stft=stftspec,
        nbands=nbands,
        scale=scale,
        norm=melfb_norm,
        freq_range=freq_range
    );

    if get_only_melfreq
        return melfb.data.freq
    end

    # mel spectrogram module
    melspec =  get_melspec(
        stft=stftspec,
        fbank=melfb,
        db_scale=db_scale
    );

    # spectral features module
    spect = get_spectrals(
        source=stftspec,
        freq_range=freq_range
    );

    hcat(
        melspec.spec',
        spect.centroid,
        spect.crest,
        spect.entropy,
        spect.flatness,
        spect.flux,
        spect.kurtosis,
        spect.rolloff,
        spect.skewness,
        spect.decrease,
        spect.slope,
        spect.spread
    );
end

afe (generic function with 1 method)

### Prepare dataset for analysis

In [4]:
color_code = Dict(:red => 31, :green => 32, :yellow => 33, :blue => 34, :magenta => 35, :cyan => 36)
freq = round.(Int, afe(x[1, :audio]; audioparams..., get_only_melfreq=true))
r_select = r"\e\[\d+m(.*?)\e\[0m"

catch9_f = ["max", "min", "mean", "med", "std", "bsm", "bsd", "qnt", "3ac"]
variable_names = vcat([
    vcat(
        ["\e[$(color_code[:yellow])mmel$i=$(freq[i])Hz->$j\e[0m" for i in 1:audioparams.nbands]...,
        "\e[$(color_code[:cyan])mcntrd->$j\e[0m", "\e[$(color_code[:cyan])mcrest->$j\e[0m",
        "\e[$(color_code[:cyan])mentrp->$j\e[0m", "\e[$(color_code[:cyan])mflatn->$j\e[0m", "\e[$(color_code[:cyan])mflux->$j\e[0m",
        "\e[$(color_code[:cyan])mkurts->$j\e[0m", "\e[$(color_code[:cyan])mrllff->$j\e[0m", "\e[$(color_code[:cyan])mskwns->$j\e[0m",
        "\e[$(color_code[:cyan])mdecrs->$j\e[0m", "\e[$(color_code[:cyan])mslope->$j\e[0m", "\e[$(color_code[:cyan])msprd->$j\e[0m"
    )
    for j in catch9_f
]...)

catch9 = [
    maximum,
    minimum,
    StatsBase.mean,
    median,
    std,
    Catch22.SB_BinaryStats_mean_longstretch1,
    Catch22.SB_BinaryStats_diff_longstretch0,
    Catch22.SB_MotifThree_quantile_hh,
    Catch22.SB_TransitionMatrix_3ac_sumdiagcov,
]

X = DataFrame([name => Float64[] for name in [match(r_select, v)[1] for v in variable_names]])

audio_feats = [afe(row[:audio]; audioparams...) for row in eachrow(x)]
push!(X, vcat([vcat([map(func, eachcol(row)) for func in catch9]...) for row in audio_feats])...)

yc = CategoricalArray(y)

train_ratio = 0.8

train, test = partition(eachindex(yc), train_ratio, shuffle=true)
X_train, y_train = X[train, :], yc[train]
X_test, y_test = X[test, :], yc[test]

println("Training set size: ", size(X_train), " - ", length(y_train))
println("Test set size: ", size(X_test), " - ", length(y_test))

Training set size: (403, 225) - 403
Test set size: (101, 225) - 101


### Train a model

In [5]:
learned_dt_tree = begin
    Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree
    model = Tree(max_depth=-1, )
    mach = machine(model, X_train, y_train)
    fit!(mach)
    fitted_params(mach).tree
end

┌ Info: For silent loading, specify `verbosity=0`. 
└ @ Main /home/paso/.julia/packages/MLJModels/8W54X/src/loading.jl:159


import MLJDecisionTreeInterface ✔


┌ Info: Training machine(DecisionTreeClassifier(max_depth = -1, …), …).
└ @ MLJBase /home/paso/.julia/packages/MLJBase/7nGJF/src/machines.jl:499


cntrd->max < 1307.0
├─ skwns->min < 1.268
│  ├─ mel6=1157Hz->min < -5.16
│  │  ├─ mel2=529Hz->min < -5.318
│  │  │  ├─ mel7=1359Hz->qnt < 2.15
│  │  │  │  ├─ mel4=811Hz->max < -3.749
│  │  │  │  │  ⋮
│  │  │  │  │  
│  │  │  │  └─ mel4=811Hz->mean < -4.891
│  │  │  │     ⋮
│  │  │  │     
│  │  │  └─ cntrd->std < 246.7
│  │  │     ├─ mel8=1583Hz->min < -5.483
│  │  │     │  ⋮
│  │  │     │  
│  │  │     └─ Healthy (2/2)
│  │  └─ flatn->min < 0.008407
│  │     ├─ Healthy (26/26)
│  │     └─ mel13=3124Hz->mean < -5.418
│  │        ├─ Pneumonia (5/5)
│  │        └─ Healthy (6/6)
│  └─ mel13=3124Hz->min < -8.135
│     ├─ Healthy (4/4)
│     └─ entrp->qnt < 2.183
│        ├─ mel10=2106Hz->mean < -5.045
│        │  ├─ Pneumonia (111/111)
│        │  └─ Healthy (1/1)
│        └─ slope->max < -6.039e-8
│           ├─ Healthy (2/2)
│           └─ Pneumonia (2/2)
└─ mel8=1583Hz->std < 0.1833
   ├─ Pneumonia (2/2)
   └─ mel12=2749Hz->std < 0.3126
      ├─ Healthy (98/98)
      └─ mel6=1157Hz->std

### Model inspection & rule study

In [6]:
sole_dt = solemodel(learned_dt_tree)
# Make test instances flow into the model, so that test metrics can, then, be computed.
apply!(sole_dt, X_test, y_test);
# Print Sole model
printmodel(sole_dt; show_metrics = true, variable_names_map = variable_names);

[34m▣[0m [1m[36mcntrd->max[0m [1m<[0m[0m 1307.4545801505142
├✔ [1m[36mskwns->min[0m [1m<[0m[0m 1.2677830790873983
│├✔ [1m[33mmel6=1157Hz->min[0m [1m<[0m[0m -5.160466644906318
││├✔ [1m[33mmel2=529Hz->min[0m [1m<[0m[0m -5.318134764797197
│││├✔ [1m[33mmel7=1359Hz->qnt[0m [1m<[0m[0m 2.149699676804287
││││├✔ [1m[33mmel4=811Hz->max[0m [1m<[0m[0m -3.749222280014931
│││││├✔ Healthy : (ninstances = 8, ncovered = 8, confidence = 0.5, lift = 1.0)
│││││└✘ [1m[33mmel4=811Hz->med[0m [1m<[0m[0m -5.211889812618957
│││││ ├✔ Pneumonia : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
│││││ └✘ [1m[36mentrp->mean[0m [1m<[0m[0m 0.5575021763809584
│││││  ├✔ [1m[33mmel12=2749Hz->std[0m [1m<[0m[0m 0.32395301625944595
│││││  │├✔ Pneumonia : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
│││││  │└✘ Healthy : (ninstances = 2, ncovered = 2, confidence = 1.0, lift = 1.0)
│││││  └✘ [1m[36mcntrd->mean[0m [1m<[0m[0m 821.351284681

### Extract rules that are at least as good as a random baseline model

In [7]:
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0);
printmodel.(interesting_rules; show_metrics = true, variable_names_map = variable_names);

[34m▣[0m ([1m[36mcntrd->max[0m [1m<[0m[0m 1307.4545801505142) ∧ ([1m[36mskwns->min[0m [1m<[0m[0m 1.2677830790873983) ∧ ([1m[33mmel6=1157Hz->min[0m [1m<[0m[0m -5.160466644906318) ∧ ([1m[33mmel2=529Hz->min[0m [1m<[0m[0m -5.318134764797197) ∧ ([1m[33mmel7=1359Hz->qnt[0m [1m<[0m[0m 2.149699676804287) ∧ (¬([1m[33mmel4=811Hz->max[0m [1m<[0m[0m -3.749222280014931)) ∧ (¬([1m[33mmel4=811Hz->med[0m [1m<[0m[0m -5.211889812618957)) ∧ ([1m[36mentrp->mean[0m [1m<[0m[0m 0.5575021763809584) ∧ (¬([1m[33mmel12=2749Hz->std[0m [1m<[0m[0m 0.32395301625944595))  ↣  Healthy : (ninstances = 101, ncovered = 2, coverage = 0.02, confidence = 1.0, lift = 1.98, natoms = 9)
[34m▣[0m ([1m[36mcntrd->max[0m [1m<[0m[0m 1307.4545801505142) ∧ ([1m[36mskwns->min[0m [1m<[0m[0m 1.2677830790873983) ∧ ([1m[33mmel6=1157Hz->min[0m [1m<[0m[0m -5.160466644906318) ∧ ([1m[33mmel2=529Hz->min[0m [1m<[0m[0m -5.318134764797197) ∧ (¬([1m[33mmel7=1359Hz->

### Simplify rules while extracting and prettify result

In [8]:
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0, normalize = true);
printmodel.(interesting_rules; show_metrics = true, syntaxstring_kwargs = (; threshold_digits = 2), variable_names_map = variable_names);

[34m▣[0m ([1m[36mcntrd->max[0m [1m<[0m[0m 1307.45) ∧ ([1m[36mskwns->min[0m [1m<[0m[0m 1.27) ∧ ([1m[33mmel6=1157Hz->min[0m [1m<[0m[0m -5.16) ∧ ([1m[33mmel2=529Hz->min[0m [1m<[0m[0m -5.32) ∧ ([1m[33mmel7=1359Hz->qnt[0m [1m<[0m[0m 2.15) ∧ ([1m[33mmel4=811Hz->max[0m [1m≥[0m[0m -3.75) ∧ ([1m[33mmel4=811Hz->med[0m [1m≥[0m[0m -5.21) ∧ ([1m[36mentrp->mean[0m [1m<[0m[0m 0.56) ∧ ([1m[33mmel12=2749Hz->std[0m [1m≥[0m[0m 0.32)  ↣  Healthy : (ninstances = 101, ncovered = 2, coverage = 0.02, confidence = 1.0, lift = 1.98, natoms = 9)
[34m▣[0m ([1m[36mcntrd->max[0m [1m<[0m[0m 1307.45) ∧ ([1m[36mskwns->min[0m [1m<[0m[0m 1.27) ∧ ([1m[33mmel6=1157Hz->min[0m [1m<[0m[0m -5.16) ∧ ([1m[33mmel2=529Hz->min[0m [1m<[0m[0m -5.32) ∧ ([1m[33mmel7=1359Hz->qnt[0m [1m≥[0m[0m 2.15) ∧ ([1m[33mmel4=811Hz->mean[0m [1m<[0m[0m -4.89)  ↣  Pneumonia : (ninstances = 101, ncovered = 5, coverage = 0.05, confidence = 1.0, lift = 2.02, na

### Directly access rule metrics

In [9]:
readmetrics.(listrules(sole_dt; min_lift=1.0, min_ninstances = 0))

12-element Vector{@NamedTuple{ninstances::Int64, ncovered::Int64, coverage::Float64, confidence::Float64, lift::Float64, natoms::Int64}}:
 (ninstances = 101, ncovered = 2, coverage = 0.019801980198019802, confidence = 1.0, lift = 1.9803921568627452, natoms = 9)
 (ninstances = 101, ncovered = 5, coverage = 0.04950495049504951, confidence = 1.0, lift = 2.02, natoms = 6)
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.9803921568627452, natoms = 6)
 (ninstances = 101, ncovered = 10, coverage = 0.09900990099009901, confidence = 0.9, lift = 1.818, natoms = 6)
 (ninstances = 101, ncovered = 9, coverage = 0.0891089108910891, confidence = 1.0, lift = 1.9803921568627452, natoms = 4)
 (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.9803921568627452, natoms = 5)
 (ninstances = 101, ncovered = 3, coverage = 0.0297029702970297, confidence = 1.0, lift = 1.9803921568627452, natoms = 3)
 (ninstances = 101, ncover

### Show rules with an additional metric (syntax height of the rule's antecedent)

In [10]:
printmodel.(sort(interesting_rules, by = readmetrics); show_metrics = (; round_digits = nothing, additional_metrics = (; height = r->SoleLogics.height(antecedent(r)))), variable_names_map = variable_names);

[34m▣[0m ([1m[36mcntrd->max[0m [1m<[0m[0m 1307.4545801505142) ∧ ([1m[36mskwns->min[0m [1m<[0m[0m 1.2677830790873983) ∧ ([1m[33mmel6=1157Hz->min[0m [1m≥[0m[0m -5.160466644906318) ∧ ([1m[36mflatn->min[0m [1m≥[0m[0m 0.008407301030766876) ∧ ([1m[33mmel13=3124Hz->mean[0m [1m≥[0m[0m -5.418313296366907)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.9803921568627452, natoms = 5, height = 4)
[34m▣[0m ([1m[36mcntrd->max[0m [1m<[0m[0m 1307.4545801505142) ∧ ([1m[36mskwns->min[0m [1m≥[0m[0m 1.2677830790873983) ∧ ([1m[33mmel13=3124Hz->min[0m [1m≥[0m[0m -8.134815239307551) ∧ ([1m[36mentrp->qnt[0m [1m<[0m[0m 2.1832827078922596) ∧ ([1m[33mmel10=2106Hz->mean[0m [1m≥[0m[0m -5.044707915173815)  ↣  Healthy : (ninstances = 101, ncovered = 1, coverage = 0.009900990099009901, confidence = 1.0, lift = 1.9803921568627452, natoms = 5, height = 4)
[34m▣[0m ([1m[36mcntrd->max[0m [1m<[

### Pretty table of rules and their metrics

In [11]:
metricstable(interesting_rules; variable_names_map = variable_names, metrics_kwargs = (; round_digits = nothing, additional_metrics = (; height = r->SoleLogics.height(antecedent(r)))))

┌──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┬────────────┬────────────┬──────────┬────────────┬────────────┬─────────┬────────┬────────┐
│[33;1m                                                                                                                                                                                                                                                                                                                                                           Antecedent [0m│[33;1m Consequent [0m│[33;1m ninstances [0m│[33;1m ncovered [0m│[33;1m   coverage [0m│[33;1m confidence [0m│[33;1m    lift [0m│[33;1m natoms [0m│[33;1m height [0

# Inspect features

In [12]:
interesting_rules = listrules(sole_dt,
	min_lift = 1.0,
	# min_lift = 2.0,
	min_ninstances = 0,
	min_coverage = 0.10,
	normalize = true,
);
map(r->(consequent(r), readmetrics(r)), interesting_rules)
printmodel.(interesting_rules; show_metrics = true, syntaxstring_kwargs = (; threshold_digits = 2), variable_names_map=variable_names);

[34m▣[0m ([1m[36mcntrd->max[0m [1m<[0m[0m 1307.45) ∧ ([1m[36mskwns->min[0m [1m≥[0m[0m 1.27) ∧ ([1m[33mmel13=3124Hz->min[0m [1m≥[0m[0m -8.13) ∧ ([1m[36mentrp->qnt[0m [1m<[0m[0m 2.18) ∧ ([1m[33mmel10=2106Hz->mean[0m [1m<[0m[0m -5.04)  ↣  Pneumonia : (ninstances = 101, ncovered = 24, coverage = 0.24, confidence = 0.96, lift = 1.94, natoms = 5)
[34m▣[0m ([1m[36mcntrd->max[0m [1m≥[0m[0m 1307.45) ∧ ([1m[33mmel8=1583Hz->std[0m [1m≥[0m[0m 0.18) ∧ ([1m[33mmel12=2749Hz->std[0m [1m<[0m[0m 0.31)  ↣  Healthy : (ninstances = 101, ncovered = 27, coverage = 0.27, confidence = 0.85, lift = 1.69, natoms = 3)


In [13]:
interesting_features = unique(SoleData.feature.(SoleLogics.value.(vcat(SoleLogics.atoms.(i.antecedent for i in interesting_rules)...))))
interesting_variables = sort(SoleData.i_variable.(interesting_features))

7-element Vector{Symbol}:
 Symbol("cntrd->max")
 Symbol("entrp->qnt")
 Symbol("mel10=2106Hz->mean")
 Symbol("mel12=2749Hz->std")
 Symbol("mel13=3124Hz->min")
 Symbol("mel8=1583Hz->std")
 Symbol("skwns->min")