Skip to content

Commit

Permalink
Coverage for GNFA and similarity methods
Browse files Browse the repository at this point in the history
  • Loading branch information
AP6YC committed Nov 3, 2020
1 parent 435b720 commit 5edc399
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 30 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Expand Up @@ -5,7 +5,6 @@ description = "A Julia package for Adaptive Resonance Theory (ART) algorithms."
version = "0.1.0"

[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
Expand All @@ -23,9 +22,10 @@ PyPlot = "2.9.0"
julia = "1"

[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SafeTestsets", "MLDatasets"]
test = ["Test", "SafeTestsets", "MLDatasets", "DelimitedFiles"]
5 changes: 0 additions & 5 deletions src/AdaptiveResonance.jl
@@ -1,10 +1,5 @@
module AdaptiveResonance

# using Logging
# using Parameters
# using Statistics
# using LinearAlgebra

abstract type AbstractARTOpts end
abstract type AbstractART end

Expand Down
29 changes: 11 additions & 18 deletions src/DDVFA.jl
Expand Up @@ -5,7 +5,6 @@ using LinearAlgebra
using ProgressBars
using Printf


"""
opts_GNFA()
Expand Down Expand Up @@ -39,7 +38,6 @@ Initialized GNFA
max_epochs = 1
end # opts_GNFA


"""
GNFA
Expand Down Expand Up @@ -72,7 +70,6 @@ mutable struct GNFA <: AbstractART
epoch::Int
end # GNFA


"""
GNFA()
Expand All @@ -91,7 +88,6 @@ function GNFA()
GNFA(opts)
end # GNFA()


"""
GNFA(opts)
Expand Down Expand Up @@ -121,7 +117,6 @@ function GNFA(opts)
)
end # GNFA(opts)


"""
initialize!()
Expand Down Expand Up @@ -150,7 +145,6 @@ function initialize!(art::GNFA, x::Array)
# push!(art.labels, label)
end # initialize!(GNFA, x)


"""
train!()
Expand Down Expand Up @@ -309,7 +303,6 @@ function classify(art::GNFA, x::Array)
return y_hat
end # classify(GNFA, x)


"""
activation_match!(art::GNFA, x::Array)
Expand Down Expand Up @@ -339,25 +332,27 @@ end # activation_match!(GNFA, x)


# Generic learning function
"""
    learn(art::GNFA, x::Array, W::Array)

Return the updated weight for a single category given sample `x` and the
current weight vector `W`, using the fuzzy ART learning rule:

    W_new = beta * min(x, W) + (1 - beta) * W

where `beta` is the learning rate `art.opts.beta`. Does not mutate `art`;
see `learn!` for the in-place variant.
"""
function learn(art::GNFA, x::Array, W::Array)
    # Convex combination of the fuzzy intersection and the old weight;
    # beta = 1 recovers fast learning (W_new = min(x, W)).
    return art.opts.beta .* element_min(x, W) .+ W .* (1 - art.opts.beta)
end # learn(GNFA, x, W)


# In place learning function with instance counting
"""
    learn!(art::GNFA, x::Array, index::Int)

Update category `index` of `art` in place from sample `x`: overwrite the
weight column `art.W[:, index]` with the result of `learn` and increment
that category's instance count `art.n_instance[index]`.
"""
function learn!(art::GNFA, x::Array, index::Int)
    # Apply the out-of-place learning rule to the selected weight column
    art.W[:, index] = learn(art, x, art.W[:, index])
    # Track how many samples this category has absorbed (used by
    # instance-count-weighted similarity methods)
    art.n_instance[index] += 1
end # learn!(GNFA, x, index)

"""
stopping_conditions(art::GNFA)
Stopping conditions for a GNFA module.
"""
function stopping_conditions(art::GNFA)
return isequal(art.W, art.W_old) || art.epoch >= art.opts.max_epochs
end # stopping_conditions(GNFA)


"""
opts_DDVFA()
Expand Down Expand Up @@ -449,7 +444,6 @@ function DDVFA(opts::opts_DDVFA)
)
end # DDVFA(opts)


"""
train!(ddvfa, data)
Expand Down Expand Up @@ -538,7 +532,6 @@ function train!(art::DDVFA, x::Array)
end
end # train!(DDVFA, x)


"""
stopping_conditions(art::DDVFA)
Expand All @@ -551,7 +544,6 @@ function stopping_conditions(art::DDVFA)
return art.W == art.W_old || art.epoch >= art.opts.max_epoch
end # stopping_conditions(DDVFA)


"""
similarity(method, F2, field_name, gamma_ref)
Expand Down Expand Up @@ -595,14 +587,15 @@ function similarity(method::String, F2::GNFA, field_name::String, sample::Array,
# Weighted linkage
elseif method == "weighted"
if field_name == "T"
value = F2.T * (F2.n / sum(F2.n))
value = F2.T' * (F2.n_instance ./ sum(F2.n_instance))
elseif field_name == "M"
value = F2.M * (F2.n / sum(F2.n))
value = F2.M' * (F2.n_instance ./ sum(F2.n_instance))
end
# Centroid linkage
elseif method == "centroid"
Wc = minimum(F2.W)
T = norm(min(sample, Wc), 1)
Wc = minimum(F2.W, dims=2)
# (norm(min(obj.sample, Wc), 1)/(obj.alpha + norm(Wc, 1)))^obj.gamma;
T = norm(element_min(sample, Wc), 1) / (F2.opts.alpha + norm(Wc, 1))^F2.opts.gamma
if field_name == "T"
value = T
elseif field_name == "M"
Expand Down
6 changes: 5 additions & 1 deletion src/funcs.jl
Expand Up @@ -21,6 +21,11 @@ function complement_code(data::Array)
return x
end

"""
element_min(x::Array, W::Array)
Returns the element-wise minimum between sample x and weight W.
"""
function element_min(x::Array, W::Array)
# Compute the element-wise minimum of two vectors
return minimum([x W], dims = 2)
Expand Down Expand Up @@ -152,4 +157,3 @@ end
# error("Invalid/unimplemented similarity method")
# end
# end # similarity

30 changes: 29 additions & 1 deletion test/test_sets.jl
Expand Up @@ -25,7 +25,35 @@ end

# GNFA train and test
my_gnfa = GNFA()

data = load_am_data(200, 50)
local_complement_code = AdaptiveResonance.complement_code(data.train_x)
train!(my_gnfa, local_complement_code)

# Similarity methods
methods = ["single",
"average",
"complete",
"median",
"weighted",
"centroid"]

# Both field names
field_names = ["T", "M"]

# Compute a local sample for GNFA similarity method testing
local_sample = local_complement_code[:, 1]

# Compute the local activation and match
AdaptiveResonance.activation_match!(my_gnfa, local_sample)

# Test every method and field name
for method in methods
println("Method: ", method)
for field_name in field_names
result = AdaptiveResonance.similarity(method, my_gnfa, field_name, local_sample, my_gnfa.opts.gamma_ref)
println(field_name, ": ", result)
end
end
end


Expand Down
9 changes: 6 additions & 3 deletions test/test_utils.jl
Expand Up @@ -22,10 +22,13 @@ dataset and packages them into a DataSplit struct.
function load_am_data(N_train::Int, N_test::Int)
# Load the data, downloading if in a CI context: TODO
# if ENV["CI"] == true
MNIST.download("../data/mnist/", i_accept_the_terms_of_use=true)
data_dir = "../data/mnist/"
if !isdir(data_dir)
MNIST.download(data_dir, i_accept_the_terms_of_use=true)
end
# end
train_x, train_y = MNIST.traindata()
test_x, test_y = MNIST.testdata()
train_x, train_y = MNIST.traindata(dir=data_dir)
test_x, test_y = MNIST.testdata(dir=data_dir)

# Get sizes of train and test data
size_a, size_b, data_n = size(train_x)
Expand Down

0 comments on commit 5edc399

Please sign in to comment.