Skip to content

Commit

Permalink
Merge da0b886 into b24c5b9
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Jul 20, 2021
2 parents b24c5b9 + da0b886 commit 831dff1
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 15 deletions.
32 changes: 31 additions & 1 deletion examples/support-vector-machine/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,26 @@ version = "2.1.0+0"
deps = ["ChainRulesCore", "Compat", "CompositionsBase", "Distances", "FillArrays", "Functors", "LinearAlgebra", "Random", "Requires", "SpecialFunctions", "StatsBase", "StatsFuns", "TensorCore", "Test", "ZygoteRules"]
path = "../.."
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.6"
version = "0.10.8"

[[LAME_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c"
uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d"
version = "3.100.1+0"

[[LIBLINEAR]]
deps = ["Libdl", "SparseArrays", "liblinear_jll"]
git-tree-sha1 = "81e40115c23acca9dfa30944050096b958271e5a"
uuid = "2d691ee1-e668-5016-a719-b2531b85e0f5"
version = "0.6.0"

[[LIBSVM]]
deps = ["LIBLINEAR", "LinearAlgebra", "ScikitLearnBase", "SparseArrays", "libsvm_jll"]
git-tree-sha1 = "729ea2db931587c983d0ef6691b62de5005c5570"
uuid = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
version = "0.7.0"

[[LZO_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "e5b909bcf985c5e2605737d2ce278ed791b89be6"
Expand Down Expand Up @@ -587,6 +599,12 @@ version = "0.3.0+0"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[ScikitLearnBase]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "7877e55c1523a4b336b433da39c8e8c08d2f221f"
uuid = "6e75b9c4-186b-50bd-896f-2d2496a4843e"
version = "0.5.0"

[[Scratch]]
deps = ["Dates"]
git-tree-sha1 = "0b4b7f1393cff97c33891da2a0bf69c6ed241fda"
Expand Down Expand Up @@ -882,12 +900,24 @@ git-tree-sha1 = "7a5780a0d9c6864184b3a2eeeb833a0c871f00ab"
uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280"
version = "0.1.6+4"

[[liblinear_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "6a4a6a3697269cb2da57e698e9318972d88de0bb"
uuid = "275f1f90-abd2-5ca1-9ad8-abd4e3d66eb7"
version = "2.30.0+0"

[[libpng_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"]
git-tree-sha1 = "94d180a6d2b5e55e447e2d27a29ed04fe79eb30c"
uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f"
version = "1.6.38+0"

[[libsvm_jll]]
deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"]
git-tree-sha1 = "ac78676ee5b1707de969d68d0a39db71f222925d"
uuid = "08558c22-525a-5d2a-acf6-0ac6658ffce4"
version = "3.24.0+1"

[[libvorbis_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"]
git-tree-sha1 = "c45f4e40e7aafe9d086379e5578947ec8b95a8fb"
Expand Down
2 changes: 2 additions & 0 deletions examples/support-vector-machine/Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[compat]
Distributions = "0.25"
KernelFunctions = "0.10"
LIBSVM = "0.7"
Literate = "2"
Plots = "1"
julia = "1.3"
32 changes: 18 additions & 14 deletions examples/support-vector-machine/script.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# # Support Vector Machine
#
# !!! warning
# This example is under construction

using KernelFunctions
using Distributions
using Plots

using KernelFunctions
using LIBSVM
using LinearAlgebra
using Plots
using Random

## Set plotting theme
Expand All @@ -20,23 +18,29 @@ Random.seed!(1234);
N = 100;

# Select randomly between two classes:
y = rand([-1, 1], N);
y_train = rand([-1, 1], N);

# Random attributes for both classes:
X = Matrix{Float64}(undef, 2, N)
rand!(MvNormal(randn(2), I), view(X, :, y .== 1))
rand!(MvNormal(randn(2), I), view(X, :, y .== -1));
rand!(MvNormal(randn(2), I), view(X, :, y_train .== 1))
rand!(MvNormal(randn(2), I), view(X, :, y_train .== -1));
x_train = ColVecs(X);

# Create a 2D grid:
xgrid = range(floor(Int, minimum(X)), ceil(Int, maximum(X)); length=100)
Xgrid = ColVecs(mapreduce(collect, hcat, Iterators.product(xgrid, xgrid)));
test_range = range(floor(Int, minimum(X)), ceil(Int, maximum(X)); length=100)
x_test = ColVecs(mapreduce(collect, hcat, Iterators.product(test_range, test_range)));

# Create kernel function:
k = SqExponentialKernel() ScaleTransform(2.0)

# Optimal prediction:
f(x, X, k, λ) = kernelmatrix(k, x, X) / (kernelmatrix(k, X) + exp(λ) * I) * y
# [LIBSVM](https://github.com/JuliaML/LIBSVM.jl) can make use of a pre-computed kernel matrix.
# KernelFunctions.jl can be used to produce that.
# Precomputed matrix for training (corresponds to linear kernel)
model = svmtrain(kernelmatrix(k, x_train), y_train; kernel=LIBSVM.Kernel.Precomputed)

# Precomputed matrix for prediction
y_pr, _ = svmpredict(model, kernelmatrix(k, x_train, x_test));

# Compute prediction on a grid:
contourf(xgrid, xgrid, f(Xgrid, ColVecs(X), k, 0.1))
scatter!(X[1, :], X[2, :]; color=y, lab="data", widen=false)
contourf(test_range, test_range, y_pr)
scatter!(X[1, :], X[2, :]; color=y_train, lab="data", widen=false)

0 comments on commit 831dff1

Please sign in to comment.