diff --git a/Project.toml b/Project.toml index 2f40d7ca..5bfa72db 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "AdaptiveResonance" uuid = "3d72adc0-63d3-4141-bf9b-84450dd0395b" authors = ["Sasha Petrenko"] description = "A Julia package for Adaptive Resonance Theory (ART) algorithms." -version = "0.2.1" +version = "0.2.2" [deps] ClusterValidityIndices = "2fefd308-f647-4698-a2f0-e90cfcbca9ad" diff --git a/src/ART/DDVFA.jl b/src/ART/DDVFA.jl index 7ccfd612..d19ad46f 100644 --- a/src/ART/DDVFA.jl +++ b/src/ART/DDVFA.jl @@ -505,14 +505,22 @@ function train!(art::DDVFA, x::Array ; y::Array=[], preprocessed=false) end # art.labels = zeros(n_samples) - y_hat = zeros(Int, n_samples) + if n_samples == 1 + y_hat = zero(Int) + else + y_hat = zeros(Int, n_samples) + end # Initialization if isempty(art.F2) # Set the first label as either 1 or the first provided label local_label = supervised ? y[1] : 1 # Add the local label to the output vector - y_hat[1] = local_label + if n_samples == 1 + y_hat = local_label + else + y_hat[1] = local_label + end # Create a new category create_category(art, get_sample(x, 1), local_label) # Skip the first training entry @@ -541,11 +549,16 @@ function train!(art::DDVFA, x::Array ; y::Array=[], preprocessed=false) (i == 1 && skip_first) && continue # Grab the sample slice sample = get_sample(x, i) + # Default to mismatch mismatch_flag = true # If label is new, break to make new category if supervised && !(y[i] in art.labels) - y_hat[i] = y[i] + if n_samples == 1 + y_hat = y[i] + else + y_hat[i] = y[i] + end create_category(art, sample, y[i]) continue end @@ -566,14 +579,23 @@ function train!(art::DDVFA, x::Array ; y::Array=[], preprocessed=false) # Update the weights with the sample train!(art.F2[bmu], sample) # Save the output label for the sample - y_hat[i] = art.labels[bmu] + label = art.labels[bmu] + if n_samples == 1 + y_hat = label + else + y_hat[i] = label + end mismatch_flag = false break end end if mismatch_flag label = supervised ? y[i] : art.n_categories + 1 - y_hat[i] = label + if n_samples == 1 + y_hat = label + else + y_hat[i] = label + end create_category(art, sample, label) end end @@ -714,15 +736,26 @@ function classify(art::DDVFA, x::Array ; preprocessed=false) end # Initialize the output vector - y_hat = zeros(Int, n_samples) + if n_samples == 1 + y_hat = zero(Int) + else + y_hat = zeros(Int, n_samples) + end - iter_raw = 1:n_samples - iter = art.opts.display ? ProgressBar(iter_raw) : iter_raw + # iter_raw = 1:n_samples + # iter = art.opts.display ? ProgressBar(iter_raw) : iter_raw + iter = get_iterator(art.opts, x) for ix = iter - if art.opts.display - set_description(iter, string(@sprintf("Ep: %i, ID: %i, Cat: %i", art.epoch, ix, art.n_categories))) - end - sample = x[:, ix] + # Update the iterator if necessary + update_iter(art, iter, ix) + # if art.opts.display + # set_description(iter, string(@sprintf("Ep: %i, ID: %i, Cat: %i", art.epoch, ix, art.n_categories))) + # end + + # Grab the sample slice + sample = get_sample(x, ix) + # sample = x[:, ix] + T = zeros(art.n_categories) for jx = 1:art.n_categories activation_match!(art.F2[jx], sample) @@ -735,14 +768,23 @@ function classify(art::DDVFA, x::Array ; preprocessed=false) M = similarity(art.opts.method, art.F2[bmu], "M", sample, art.opts.gamma_ref) if M >= art.threshold # Current winner - y_hat[ix] = art.labels[bmu] + label = art.labels[bmu] + if n_samples == 1 + y_hat = label + else + y_hat[ix] = label + end mismatch_flag = false break end end if mismatch_flag @debug "Mismatch" - y_hat[ix] = -1 + if n_samples == 1 + y_hat = -1 + else + y_hat[ix] = -1 + end end end diff --git a/src/AdaptiveResonance.jl b/src/AdaptiveResonance.jl index 8475564c..258523e6 100644 --- a/src/AdaptiveResonance.jl +++ b/src/AdaptiveResonance.jl @@ -33,7 +33,7 @@ export # Common structures DataConfig, - data_setup, + data_setup!, # Common utility functions complement_code, diff --git a/src/common.jl b/src/common.jl index 9c0dfc7c..52bf64b4 100644 --- a/src/common.jl +++ b/src/common.jl @@ -115,8 +115,10 @@ function get_data_shape(data::Array) if ndims(data) > 1 dim, n_samples = size(data) else - dim = 1 - n_samples = length(data) + # dim = 1 + # n_samples = length(data) + dim = length(data) + n_samples = 1 end return dim, n_samples @@ -132,7 +134,8 @@ function get_n_samples(data::Array) if ndims(data) > 1 n_samples = size(data)[2] else - n_samples = length(data) + # n_samples = length(data) + n_samples = 1 end return n_samples diff --git a/test/test_ddvfa.jl b/test/test_ddvfa.jl index 0efb881c..28bb5937 100644 --- a/test/test_ddvfa.jl +++ b/test/test_ddvfa.jl @@ -23,6 +23,49 @@ function tt_ddvfa(opts::opts_DDVFA, train_x::Array) return art end # tt_ddvfa(opts::opts_DDVFA, train_x::Array) +@testset "DDVFA Sequential" begin + # Set the logging level to Info and standardize the random seed + LogLevel(Logging.Info) + Random.seed!(0) + + @info "------- DDVFA Sequential -------" + + # Load the data and test across all supervised modules + data = load_iris("../data/Iris.csv") + + # Initialize the ART module + art = DDVFA() + # Turn off display for sequential training/testing + art.opts.display = false + # Set up the data manually because the module can't infer from single samples + data_setup!(art.config, data.train_x) + + # Get the dimension and size of the data + dim, n_samples = get_data_shape(data.train_x) + y_hat_train = zeros(Int64, n_samples) + dim_test, n_samples_test = get_data_shape(data.test_x) + y_hat = zeros(Int64, n_samples_test) + + # Iterate over all examples sequentially + for i = 1:n_samples + y_hat_train[i] = train!(art, data.train_x[:, i], y=[data.train_y[i]]) + end + + # Iterate over all test samples sequentially + for i = 1:n_samples_test + y_hat[i] = classify(art, data.test_x[:, i]) + end + + # Calculate performance + perf_train = performance(y_hat_train, data.train_y) + perf_test = performance(y_hat, data.test_y) + @test perf_train > 0.8 + @test perf_test > 0.8 + + @info "DDVFA Training Perf: $perf_train" + @info "DDVFA Testing Perf: $perf_test" +end + @testset "DDVFA Supervised" begin # Set the logging level to Info and standardize the random seed LogLevel(Logging.Info) diff --git a/test/test_sets.jl b/test/test_sets.jl index 59d031b9..60304ab3 100644 --- a/test/test_sets.jl +++ b/test/test_sets.jl @@ -21,15 +21,14 @@ include("test_utils.jl") # Example arrays three_by_two = [1 2; 3 4; 5 6] - # Test DataConfig constructors dc1 = DataConfig() # Default constructor dc2 = DataConfig(0, 1, 2) # When min and max are same across all features dc3 = DataConfig([0, 1], [2, 3]) # When min and max differ across features # Test get_n_samples - @test get_n_samples([1,2,3]) == 3 # 1-D array case - @test get_n_samples(three_by_two) == 2 # 2-D array case + @test get_n_samples([1,2,3]) == 1 # 1-D array case + @test get_n_samples(three_by_two) == 2 # 2-D array case # Test breaking situations @test_throws ErrorException performance([1,2],[1,2,3])