From 5edc3995166ad887f0860478763d0e12a846311e Mon Sep 17 00:00:00 2001 From: Sasha Petrenko Date: Tue, 3 Nov 2020 15:45:25 -0600 Subject: [PATCH] Coverage for GNFA and similarity methods --- Project.toml | 4 ++-- src/AdaptiveResonance.jl | 5 ----- src/DDVFA.jl | 29 +++++++++++------------------ src/funcs.jl | 6 +++++- test/test_sets.jl | 30 +++++++++++++++++++++++++++++- test/test_utils.jl | 9 ++++++--- 6 files changed, 53 insertions(+), 30 deletions(-) diff --git a/Project.toml b/Project.toml index 724aadd2..ff5b2b7d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ description = "A Julia package for Adaptive Resonance Theory (ART) algorithms." version = "0.1.0" [deps] -DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" @@ -23,9 +22,10 @@ PyPlot = "2.9.0" julia = "1" [extras] +DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "SafeTestsets", "MLDatasets"] +test = ["Test", "SafeTestsets", "MLDatasets", "DelimitedFiles"] diff --git a/src/AdaptiveResonance.jl b/src/AdaptiveResonance.jl index 744c6460..9b8dcde9 100644 --- a/src/AdaptiveResonance.jl +++ b/src/AdaptiveResonance.jl @@ -1,10 +1,5 @@ module AdaptiveResonance -# using Logging -# using Parameters -# using Statistics -# using LinearAlgebra - abstract type AbstractARTOpts end abstract type AbstractART end diff --git a/src/DDVFA.jl b/src/DDVFA.jl index 5725e810..590592bc 100644 --- a/src/DDVFA.jl +++ b/src/DDVFA.jl @@ -5,7 +5,6 @@ using LinearAlgebra using ProgressBars using Printf - """ opts_GNFA() @@ -39,7 +38,6 @@ Initialized GNFA max_epochs = 1 end # opts_GNFA - """ GNFA @@ -72,7 +70,6 @@ mutable struct GNFA <: AbstractART epoch::Int end # GNFA - """ GNFA() @@ -91,7 +88,6 @@ function GNFA() GNFA(opts) end # GNFA() - """ GNFA(opts) @@ -121,7 +117,6 @@ function GNFA(opts) ) end # GNFA(opts) - """ initialize!() @@ -150,7 +145,6 @@ function initialize!(art::GNFA, x::Array) # push!(art.labels, label) end # initialize!(GNFA, x) - """ train!() @@ -309,7 +303,6 @@ function classify(art::GNFA, x::Array) return y_hat end # classify(GNFA, x) - """ activation_match!(art::GNFA, x::Array) @@ -339,25 +332,27 @@ end # activation_match!(GNFA, x) # Generic learning function -function learn(art::GNFA, x, W) +function learn(art::GNFA, x::Array, W::Array) # Update W return art.opts.beta .* element_min(x, W) .+ W .* (1 - art.opts.beta) end # learn(GNFA, x, W) - # In place learning function with instance counting -function learn!(art::GNFA, x, index) +function learn!(art::GNFA, x::Array, index::Int) # Update W art.W[:, index] = learn(art, x, art.W[:, index]) art.n_instance[index] += 1 end # learn!(GNFA, x, index) +""" + stopping_conditions(art::GNFA) +Stopping conditions for a GNFA module. +""" function stopping_conditions(art::GNFA) return isequal(art.W, art.W_old) || art.epoch >= art.opts.max_epochs end # stopping_conditions(GNFA) - """ opts_DDVFA() @@ -449,7 +444,6 @@ function DDVFA(opts::opts_DDVFA) ) end # DDVFA(opts) - """ train!(ddvfa, data) @@ -538,7 +532,6 @@ function train!(art::DDVFA, x::Array) end end # train!(DDVFA, x) - """ stopping_conditions(art::DDVFA) @@ -551,7 +544,6 @@ function stopping_conditions(art::DDVFA) return art.W == art.W_old || art.epoch >= art.opts.max_epoch end # stopping_conditions(DDVFA) - """ similarity(method, F2, field_name, gamma_ref) @@ -595,14 +587,15 @@ function similarity(method::String, F2::GNFA, field_name::String, sample::Array, # Weighted linkage elseif method == "weighted" if field_name == "T" - value = F2.T * (F2.n / sum(F2.n)) + value = F2.T' * (F2.n_instance ./ sum(F2.n_instance)) elseif field_name == "M" - value = F2.M * (F2.n / sum(F2.n)) + value = F2.M' * (F2.n_instance ./ sum(F2.n_instance)) end # Centroid linkage elseif method == "centroid" - Wc = minimum(F2.W) - T = norm(min(sample, Wc), 1) + Wc = minimum(F2.W, dims=2) + # (norm(min(obj.sample, Wc), 1)/(obj.alpha + norm(Wc, 1)))^obj.gamma; + T = norm(element_min(sample, Wc), 1) / (F2.opts.alpha + norm(Wc, 1))^F2.opts.gamma if field_name == "T" value = T elseif field_name == "M" diff --git a/src/funcs.jl b/src/funcs.jl index f9809741..f9fa16d4 100644 --- a/src/funcs.jl +++ b/src/funcs.jl @@ -21,6 +21,11 @@ function complement_code(data::Array) return x end +""" + element_min(x::Array, W::Array) + +Returns the element-wise minimum between sample x and weight W. +""" function element_min(x::Array, W::Array) # Compute the element-wise minimum of two vectors return minimum([x W], dims = 2) @@ -152,4 +157,3 @@ end # error("Invalid/unimplemented similarity method") # end # end # similarity - diff --git a/test/test_sets.jl b/test/test_sets.jl index faec0435..c28035ea 100644 --- a/test/test_sets.jl +++ b/test/test_sets.jl @@ -25,7 +25,35 @@ end # GNFA train and test my_gnfa = GNFA() - + data = load_am_data(200, 50) + local_complement_code = AdaptiveResonance.complement_code(data.train_x) + train!(my_gnfa, local_complement_code) + + # Similarity methods + methods = ["single", + "average", + "complete", + "median", + "weighted", + "centroid"] + + # Both field names + field_names = ["T", "M"] + + # Compute a local sample for GNFA similarity method testing + local_sample = local_complement_code[:, 1] + + # Compute the local activation and match + AdaptiveResonance.activation_match!(my_gnfa, local_sample) + + # Test every method and field name + for method in methods + println("Method: ", method) + for field_name in field_names + result = AdaptiveResonance.similarity(method, my_gnfa, field_name, local_sample, my_gnfa.opts.gamma_ref) + println(field_name, ": ", result) + end + end end diff --git a/test/test_utils.jl b/test/test_utils.jl index f59c70c3..97e0987d 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -22,10 +22,13 @@ dataset and packages them into a DataSplit struct. function load_am_data(N_train::Int, N_test::Int) # Load the data, downloading if in a CI context: TODO # if ENV["CI"] == true - MNIST.download("../data/mnist/", i_accept_the_terms_of_use=true) + data_dir = "../data/mnist/" + if !isdir(data_dir) + MNIST.download(data_dir, i_accept_the_terms_of_use=true) + end # end - train_x, train_y = MNIST.traindata() - test_x, test_y = MNIST.testdata() + train_x, train_y = MNIST.traindata(dir=data_dir) + test_x, test_y = MNIST.testdata(dir=data_dir) # Get sizes of train and test data size_a, size_b, data_n = size(train_x)