In [None]:
# Activate the project environment found in the parent directory, install the
# dependencies it declares, and print the resulting package list.
using Pkg
Pkg.activate("..")
Pkg.instantiate()
Pkg.status()

In [None]:
# Seed the default random number generator with a fixed value.
using Random
Random.seed!(1605)

## Learning with Modal Decision Trees

Let us tackle the NATOPS dataset using what we learned in the previous days.

In [None]:
using ARFFFiles

using DataFrames
using MLJ
using Plots
using Random
using StatsBase
using SoleData
using SoleModels

In [None]:
# Load the helper script from ../scripts
# (presumably the one defining `parse_natops`, which is called below — confirm).
include(joinpath("..", "scripts", "parse-natops.jl"))

In [None]:
# Read the whole ARFF file as a string and hand it to `parse_natops`,
# whose result is destructured into the instances `X` and the labels `y`.
X, y = parse_natops(
    read(joinpath(@__DIR__, "..", "datasets", "natops.arff"), String),
)

In [None]:
# Human-readable names for the 24 variables: x/y/z coordinates of eight body
# parts. Assumed to follow the column order of `X` — confirm against
# `parse_natops`.
variablenames = [
    "X[Hand tip l]", "Y[Hand tip l]", "Z[Hand tip l]",
    "X[Hand tip r]", "Y[Hand tip r]", "Z[Hand tip r]",
    "X[Elbow l]", "Y[Elbow l]", "Z[Elbow l]",
    "X[Elbow r]", "Y[Elbow r]", "Z[Elbow r]",
    "X[Wrist l]", "Y[Wrist l]", "Z[Wrist l]",
    "X[Wrist r]", "Y[Wrist r]", "Z[Wrist r]",
    "X[Thumb l]", "Y[Thumb l]", "Z[Thumb l]",
    "X[Thumb r]", "Y[Thumb r]", "Z[Thumb r]",
]

# Human-readable class names, indexed by the numeric class code.
classnames = [
    "I have command",
    "All clear",
    "Not clear",
    "Spread wings",
    "Fold wings",
    "Lock wings"
]

# The variable names are column labels, so they are applied by renaming the
# columns of `X`. The cells of `X` hold whole series, hence they cannot be
# parsed as numeric codes. Renaming is safe to repeat when the cell is re-run.
rename!(X, variablenames)

# Class labels are converted only while all of them are still numeric codes.
# An explicit check replaces a catch-all `try`/`catch`, so unrelated errors
# are no longer reported as "already converted".
if all(label -> tryparse(Float64, label) !== nothing, y)
    y = map(label -> classnames[round(Int, parse(Float64, label))], y)
else
    println("You already converted the variable and class names to human readable strings.")
end

In [None]:
# Dataset dimensions: the first axis of `X` indexes instances, the second one
# attributes; the length of a single cell is the number of datapoints.
X_ninstances = size(X, 1)
X_nattributes = size(X, 2)
X_ndatapoints = length(X[1, 1])

foreach(println, (
    "Number of instances: $(X_ninstances)",
    "Number of attributes: $(X_nattributes)",
    "Number of datapoints for each attribute: $(X_ndatapoints)",
))

In [None]:
# Sanity check: every (instance, attribute) cell holds a series with the same
# number of datapoints as the first cell.
all(
    length(X[row, col]) == X_ndatapoints
    for row in 1:X_ninstances, col in 1:X_nattributes
)

In [None]:
# Plot the series of one attribute of the first instance, labelled with the
# corresponding column name; try to change the target attribute.
_attribute = 6
plot(X[1,_attribute], label = names(X)[_attribute])

In [None]:
# Count how many instances carry each class label.
countmap(y)

In [None]:
 # let us summarize one instance for each class
# Plot one instance for each class. The representative of a class is the first
# instance carrying that label, so the selection does not depend on the dataset
# being sorted in blocks of 30 instances per class (as fixed indices such as
# 1:30:180 would). The 2x3 layout assumes six classes.
let representatives = [findfirst(==(label), y) for label in unique(y)]
    subplots = map(representatives) do i
        plot(collect(X[i, :]); labels = nothing, title = y[i])
    end
    plot(subplots...; layout = (2, 3), size = (1500, 400))
end

In [None]:
# length of the series stored in the first column of the first instance
length(X[1,1])

In [None]:
# each instance can be shaped as a Kripke frame, whose worlds encode all the intervals
# in the range [1, 51] (including the degenerate, punctual cases such as [1, 1]);
# here the frame is built for the first instance
fr = SoleLogics.frame(X, 1)

In [None]:
# Materialize all the worlds of the frame into an array.
collect(allworlds(fr))

In [None]:
using SoleLogics: Interval

# enumerate the worlds of `fr` accessible from the interval [1,10] through the
# relation IA_L, i.e., the intervals that are "Later" than [1,10]
collect(accessibles(fr, Interval(1,10), IA_L))

In [None]:
# a feature lets us compute a value on each world where it is defined;
# NOTE(review): VariableMax(4) presumably denotes the maximum of variable 4
# over a world — confirm in the SoleData documentation
feature = SoleData.VariableMax(4)

In [None]:
# Plot the series of the 4th attribute of the first instance.
plot(X[1, 4], labels="X Hand tip right")

In [None]:
# Value of `feature` on the first instance, at the world Interval(10, 30).
SoleData.featvalue(feature, X, 1, Interval(10, 30))

In [None]:
# when we are interested in windowing the data, it is easy to transform a
# dataset into a Kripke model (a logiset)
Xk = scalarlogiset(X)

In [None]:
# we can check custom conditions over the logiset we just created:
# `p` is the atom "feature < 1.0", checked on instance 1 at the world Interval(10, 30)
p = Atom(ScalarCondition(feature, <, 1.0))
check(p, Xk, 1, Interval(10, 30))

In [None]:
# Plot the series of attributes 4 to 6 of the first instance, one label per series.
plot(collect(X[1, 4:6]), labels=[
    "V4 (X right hand)" "V5 (Y right hand)" "V6 (Z right hand)"])

In [None]:
# Three atoms, each comparing a feature of a variable against a threshold.
p = Atom(ScalarCondition(VariableMin(4), >, 1.0))
q = Atom(ScalarCondition(VariableMax(5), <=, 3.0))
r = Atom(ScalarCondition(VariableMax(6), <=, 0.0))

# Combine the atoms into the formula (not p) or (q and r), and print it.
phi = ¬p ∨ (q ∧ r)
println(syntaxstring(phi))

# Check the formula on the first instance of the logiset, at the world Interval(10, 30).
check(phi, SoleLogics.LogicalInstance(Xk, 1), Interval(10, 30))

Let us try to check some modal formulae.

In [None]:
# Box (universal) modality over the relation IA_L, the one this notebook calls
# "Later" when enumerating accessible intervals. The name `boxlater` refers to
# that relation; IA_A would denote the "After" relation instead.
boxlater = box(SoleLogics.IA_L)

In [None]:
# Wrap `phi` with the box modality built above.
later_always_phi = boxlater(phi)

In [None]:
# Check the modal formula on the first instance, starting from the world Interval(10, 30).
check(later_always_phi, SoleLogics.LogicalInstance(Xk, 1), Interval(10, 30))

In [None]:
# Retrieve the first instance of the logiset.
SoleLogics.getinstance(Xk, 1)

In [None]:
# let us try with an even more complex scenario: check `phi` on every instance
# of the logiset, at the world Interval(1, 30), storing 1 where it holds and 0
# where it does not. The index `i` selects an instance, so the range spans the
# instances (not the number of datapoints of a series), and the mask gets its
# length from that same range instead of a hard-coded size.
check_mask = Int64[
    check(phi, SoleLogics.LogicalInstance(Xk, i), Interval(1, 30))
    for i in 1:X_ninstances
]

println(check_mask)

### Modal Decision Trees

In [None]:
using SoleBase
using ModalDecisionTrees

In [None]:
# Training on the full-length series could be too heavy for commodity
# hardware, so every series is shortened first: a moving window is slid over
# it and each window is summarized by its mean.
X_small = broadcast(X) do series
    movingwindow(mean, series; nwindows = 10, relative_overlap = 0.2)
end

X_small_ninstances = size(X_small, 1)
X_small_nattributes = size(X_small, 2)
X_small_ndatapoints = length(X_small[1, 1])

println("The number of datapoints changed from $(X_ndatapoints) to $(X_small_ndatapoints)")

In [None]:
# Features used to build the logiset of the reduced dataset;
# the same list is later passed to the model.
features = [maximum, minimum]
Xk_small = scalarlogiset(X_small, features)

In [None]:
# Modal decision tree model using the features above;
# NOTE(review): `:IA` presumably selects the interval algebra relations —
# confirm in the ModalDecisionTrees documentation
model = ModalDecisionTree(; relations = :IA, features = features)

In [None]:
# Shuffled split of the reduced dataset and its labels, with a 0.7 fraction
# going to the training part and a fixed rng seed for reproducibility.
(X_small_train, X_small_test), (y_small_train, y_small_test) = partition(
    (X_small, y), 0.7, rng=121, shuffle=true, multi=true);

In [None]:
# bind the modal decision tree to the training data;
# then train it and compute the accuracy on the test data

mach = machine(model, X_small_train, y_small_train)
@time fit!(mach);

# each prediction is reduced to a single label through `mode`
y_small_predict_probabilities = MLJ.predict(mach, X_small_test)
y_small_predict = mode.(y_small_predict_probabilities)

MLJ.accuracy(y_small_predict, y_small_test)

In [None]:
# show the restricted modal decision tree learned
# (the `rawmodel_full` entry of the machine report)
printmodel(report(mach).rawmodel_full; hidemodality = true)

In [None]:
# show its *pure* version (the `solemodel_full` entry of the machine report),
# together with its metrics
printmodel(report(mach).solemodel_full; show_metrics = true, hidemodality = true)

In [None]:
# Prune the learned tree with simplification enabled, then translate the
# result into its pure form.
simplified_restricted_tree =
    ModalDecisionTrees.prune(report(mach).rawmodel_full; simplify = true)
puretree = simplified_restricted_tree |> ModalDecisionTrees.translate

# Print the pure tree with rounded thresholds, abbreviated features, and the
# column names of `X` as variable names.
printmodel(puretree;
    threshold_digits = 2,
    use_feature_abbreviations = true,
    parenthesize_atoms = false,
    variable_names_map = [names(X)],
    hidemodality = true,
)

# Compare the number of leaves with the number of distinct classes.
println("# Leaves: ", SoleModels.nsubmodels(puretree))
println("# Classes: ", length(unique(y)))

In [None]:
# Extract the rules of the pure tree and print each of them along with its
# metrics.
ruleset = listrules(puretree)
foreach(ruleset) do rule
    printmodel(rule;
        show_metrics = true,
        threshold_digits = 2,
        use_feature_abbreviations = true,
        parenthesize_atoms = false,
        hidemodality = true,
    )
end

In [None]:
# Render the fourth rule as an IF/THEN statement: the antecedent is rewritten
# in natural language, the consequent is printed as is.
let rule = ruleset[4]
    premise = SoleLogics.experimentals.formula2natlang(
        antecedent(rule);
        threshold_digits = 2,
        variable_names_map = [names(X)],
    )
    println("IF\n\t", premise)

    println("THEN\n\t", consequent(rule))
end