Hotfix DDVFA handling of single samples, bump version
AP6YC committed May 13, 2021
1 parent 5756263 commit 8f27d0b
Showing 6 changed files with 109 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -2,7 +2,7 @@ name = "AdaptiveResonance"
uuid = "3d72adc0-63d3-4141-bf9b-84450dd0395b"
authors = ["Sasha Petrenko"]
description = "A Julia package for Adaptive Resonance Theory (ART) algorithms."
version = "0.2.1"
version = "0.2.2"

[deps]
ClusterValidityIndices = "2fefd308-f647-4698-a2f0-e90cfcbca9ad"
70 changes: 56 additions & 14 deletions src/ART/DDVFA.jl
@@ -505,14 +505,22 @@ function train!(art::DDVFA, x::Array ; y::Array=[], preprocessed=false)
end

# art.labels = zeros(n_samples)
y_hat = zeros(Int, n_samples)
if n_samples == 1
y_hat = zero(Int)
else
y_hat = zeros(Int, n_samples)
end

# Initialization
if isempty(art.F2)
# Set the first label as either 1 or the first provided label
local_label = supervised ? y[1] : 1
# Add the local label to the output vector
y_hat[1] = local_label
if n_samples == 1
y_hat = local_label
else
y_hat[1] = local_label
end
# Create a new category
create_category(art, get_sample(x, 1), local_label)
# Skip the first training entry
@@ -541,11 +549,16 @@ function train!(art::DDVFA, x::Array ; y::Array=[], preprocessed=false)
(i == 1 && skip_first) && continue
# Grab the sample slice
sample = get_sample(x, i)

# Default to mismatch
mismatch_flag = true
# If label is new, break to make new category
if supervised && !(y[i] in art.labels)
y_hat[i] = y[i]
if n_samples == 1
y_hat = y[i]
else
y_hat[i] = y[i]
end
create_category(art, sample, y[i])
continue
end
@@ -566,14 +579,23 @@ function train!(art::DDVFA, x::Array ; y::Array=[], preprocessed=false)
# Update the weights with the sample
train!(art.F2[bmu], sample)
# Save the output label for the sample
y_hat[i] = art.labels[bmu]
label = art.labels[bmu]
if n_samples == 1
y_hat = label
else
y_hat[i] = label
end
mismatch_flag = false
break
end
end
if mismatch_flag
label = supervised ? y[i] : art.n_categories + 1
y_hat[i] = label
if n_samples == 1
y_hat = label
else
y_hat[i] = label
end
create_category(art, sample, label)
end
end
@@ -714,15 +736,26 @@ function classify(art::DDVFA, x::Array ; preprocessed=false)
end

# Initialize the output vector
y_hat = zeros(Int, n_samples)
if n_samples == 1
y_hat = zero(Int)
else
y_hat = zeros(Int, n_samples)
end

iter_raw = 1:n_samples
iter = art.opts.display ? ProgressBar(iter_raw) : iter_raw
# iter_raw = 1:n_samples
# iter = art.opts.display ? ProgressBar(iter_raw) : iter_raw
iter = get_iterator(art.opts, x)
for ix = iter
if art.opts.display
set_description(iter, string(@sprintf("Ep: %i, ID: %i, Cat: %i", art.epoch, ix, art.n_categories)))
end
sample = x[:, ix]
# Update the iterator if necessary
update_iter(art, iter, ix)
# if art.opts.display
# set_description(iter, string(@sprintf("Ep: %i, ID: %i, Cat: %i", art.epoch, ix, art.n_categories)))
# end

# Grab the sample slice
sample = get_sample(x, ix)
# sample = x[:, ix]

T = zeros(art.n_categories)
for jx = 1:art.n_categories
activation_match!(art.F2[jx], sample)
@@ -735,14 +768,23 @@ function classify(art::DDVFA, x::Array ; preprocessed=false)
M = similarity(art.opts.method, art.F2[bmu], "M", sample, art.opts.gamma_ref)
if M >= art.threshold
# Current winner
y_hat[ix] = art.labels[bmu]
label = art.labels[bmu]
if n_samples == 1
y_hat = label
else
y_hat[ix] = label
end
mismatch_flag = false
break
end
end
if mismatch_flag
@debug "Mismatch"
y_hat[ix] = -1
if n_samples == 1
y_hat = -1
else
y_hat[ix] = -1
end
end
end

2 changes: 1 addition & 1 deletion src/AdaptiveResonance.jl
@@ -33,7 +33,7 @@ export

# Common structures
DataConfig,
data_setup,
data_setup!,

# Common utility functions
complement_code,
9 changes: 6 additions & 3 deletions src/common.jl
@@ -115,8 +115,10 @@ function get_data_shape(data::Array)
if ndims(data) > 1
dim, n_samples = size(data)
else
dim = 1
n_samples = length(data)
# dim = 1
# n_samples = length(data)
dim = length(data)
n_samples = 1
end

return dim, n_samples
@@ -132,7 +134,8 @@ function get_n_samples(data::Array)
if ndims(data) > 1
n_samples = size(data)[2]
else
n_samples = length(data)
# n_samples = length(data)
n_samples = 1
end

return n_samples
43 changes: 43 additions & 0 deletions test/test_ddvfa.jl
@@ -23,6 +23,49 @@ function tt_ddvfa(opts::opts_DDVFA, train_x::Array)
return art
end # tt_ddvfa(opts::opts_DDVFA, train_x::Array)

@testset "DDVFA Sequential" begin
# Set the logging level to Info and standardize the random seed
LogLevel(Logging.Info)
Random.seed!(0)

@info "------- DDVFA Sequential -------"

# Load the data and test across all supervised modules
data = load_iris("../data/Iris.csv")

# Initialize the ART module
art = DDVFA()
# Turn off display for sequential training/testing
art.opts.display = false
# Set up the data manually because the module can't infer from single samples
data_setup!(art.config, data.train_x)

# Get the dimension and size of the data
dim, n_samples = get_data_shape(data.train_x)
y_hat_train = zeros(Int64, n_samples)
dim_test, n_samples_test = get_data_shape(data.test_x)
y_hat = zeros(Int64, n_samples_test)

# Iterate over all examples sequentially
for i = 1:n_samples
y_hat_train[i] = train!(art, data.train_x[:, i], y=[data.train_y[i]])
end

# Iterate over all test samples sequentially
for i = 1:n_samples_test
y_hat[i] = classify(art, data.test_x[:, i])
end

# Calculate performance
perf_train = performance(y_hat_train, data.train_y)
perf_test = performance(y_hat, data.test_y)
@test perf_train > 0.8
@test perf_test > 0.8

@info "DDVFA Training Perf: $perf_train"
@info "DDVFA Testing Perf: $perf_test"
end

@testset "DDVFA Supervised" begin
# Set the logging level to Info and standardize the random seed
LogLevel(Logging.Info)
5 changes: 2 additions & 3 deletions test/test_sets.jl
@@ -21,15 +21,14 @@ include("test_utils.jl")
# Example arrays
three_by_two = [1 2; 3 4; 5 6]


# Test DataConfig constructors
dc1 = DataConfig() # Default constructor
dc2 = DataConfig(0, 1, 2) # When min and max are same across all features
dc3 = DataConfig([0, 1], [2, 3]) # When min and max differ across features

# Test get_n_samples
@test get_n_samples([1,2,3]) == 3 # 1-D array case
@test get_n_samples(three_by_two) == 2 # 2-D array case
@test get_n_samples([1,2,3]) == 1 # 1-D array case
@test get_n_samples(three_by_two) == 2 # 2-D array case

# Test breaking situations
@test_throws ErrorException performance([1,2],[1,2,3])

2 comments on commit 8f27d0b

@AP6YC (Owner, Author) commented on 8f27d0b, May 13, 2021


@JuliaRegistrator register

Release notes:

This version changes the behavior of get_n_samples and get_data_shape to interpret a 1-D vector as a single multi-dimensional sample rather than a collection of 1-D samples. DDVFA's train! and classify reflect this change: when given a single sample, they return a single integer label instead of a vector.
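
As a minimal sketch of the new single-sample semantics: the snippet below assumes that get_n_samples, get_data_shape, and data_setup! are reachable from the package namespace (the test suite in this commit calls them directly), and the feature values and bounds are invented for illustration.

using AdaptiveResonance

# A 1-D vector is now interpreted as one 3-dimensional sample,
# not as three 1-dimensional samples.
x = [0.2, 0.5, 0.8]
@assert get_n_samples(x) == 1         # was 3 in v0.2.1
@assert get_data_shape(x) == (3, 1)   # (dim, n_samples)

# A single sample cannot be used to infer the feature ranges,
# so set up the data configuration manually before training.
art = DDVFA()
art.opts.display = false
bounds = [0.0 1.0; 0.0 1.0; 0.0 1.0]  # hypothetical per-feature (min, max) data
data_setup!(art.config, bounds)

# train! and classify now return a single Int for a single sample
# (classify returns -1 on a mismatch).
y_hat_train = train!(art, x, y=[1])
y_hat_test = classify(art, [0.3, 0.6, 0.7])

The new "DDVFA Sequential" test set in test/test_ddvfa.jl follows the same pattern over the full Iris dataset.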

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/36702

After the above pull request is merged, it is recommended that a tag be created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.2.2 -m "<description of version>" 8f27d0b499b37845623ab89aaf741bd29b8fb166
git push origin v0.2.2
