From 803274d7cefbb75e49b1edcf3a4220ac59136d35 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Wed, 22 Sep 2021 16:27:30 +1000 Subject: [PATCH] add davibarreira's sinkhorn_divergence with some modifications (#145) * add davibarreira's sinkhorn_divergence with some modifications * added documentation entry for sinkhorn_divergence * add statsbase to test deps * add empirical measure example for sinkhorn divergence * format * add literate * implement symmetric sinkhorn * implement sinkhorn_loss * change formula for obj() * make empirical example run faster * update docstrings * Update src/entropic/sinkhorn.jl Co-authored-by: David Widmann * Update src/entropic/sinkhorn.jl Co-authored-by: David Widmann * address review comments * fix naming of plan and sinkhorn_plan * address comments * fix sinkhorn_divergence and docs * update docs * format * remove sinkhorn_loss * Update examples/empirical_sinkhorn_div/script.jl Co-authored-by: David Widmann * Update src/entropic/sinkhorn_divergence.jl Co-authored-by: David Widmann * Update src/entropic/symmetric.jl Co-authored-by: David Widmann * Update src/entropic/sinkhorn_gibbs.jl Co-authored-by: David Widmann * Update src/entropic/sinkhorn_gibbs.jl Co-authored-by: David Widmann * Update src/entropic/symmetric.jl Co-authored-by: David Widmann * bump version Co-authored-by: David Widmann --- Project.toml | 5 +- docs/src/index.md | 1 + examples/empirical_sinkhorn_div/Manifest.toml | 1062 +++++++++++++++++ examples/empirical_sinkhorn_div/Project.toml | 20 + examples/empirical_sinkhorn_div/script.jl | 103 ++ src/OptimalTransport.jl | 3 + src/entropic/sinkhorn.jl | 33 +- src/entropic/sinkhorn_divergence.jl | 96 ++ src/entropic/sinkhorn_epsscaling.jl | 2 +- src/entropic/sinkhorn_gibbs.jl | 34 +- src/entropic/sinkhorn_solve.jl | 4 +- src/entropic/sinkhorn_stabilized.jl | 2 +- src/entropic/symmetric.jl | 148 +++ test/entropic/sinkhorn_divergence.jl | 130 ++ 14 files changed, 1627 insertions(+), 16 deletions(-) create mode 100644 examples/empirical_sinkhorn_div/Manifest.toml create mode 100644 examples/empirical_sinkhorn_div/Project.toml create mode 100644 examples/empirical_sinkhorn_div/script.jl create mode 100644 src/entropic/sinkhorn_divergence.jl create mode 100644 src/entropic/symmetric.jl create mode 100644 test/entropic/sinkhorn_divergence.jl diff --git a/Project.toml b/Project.toml index af5ebea3..1919a7df 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OptimalTransport" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" authors = ["zsteve "] -version = "0.3.16" +version = "0.3.17" [deps] ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31" @@ -28,6 +28,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [targets] -test = ["Distances", "ForwardDiff", "ReverseDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test"] +test = ["Distances", "ForwardDiff", "ReverseDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "StatsBase"] diff --git a/docs/src/index.md b/docs/src/index.md index 06d1d177..8eed71e4 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -29,6 +29,7 @@ squared2wasserstein ```@docs sinkhorn sinkhorn2 +sinkhorn_divergence sinkhorn_barycenter ``` diff --git a/examples/empirical_sinkhorn_div/Manifest.toml b/examples/empirical_sinkhorn_div/Manifest.toml new file mode 100644 index 00000000..4306eec0 --- /dev/null +++ b/examples/empirical_sinkhorn_div/Manifest.toml @@ -0,0 +1,1062 @@ +# This file is machine-generated - editing it directly is not advised + +[[Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.3.1" + +[[ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[ArrayInterface]] +deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] +git-tree-sha1 = "d84c956c4c0548b4caf0e4e96cf5b6494b5b1529" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "3.1.32" + +[[Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BenchmarkTools]] +deps = ["JSON", "Logging", "Printf", "Statistics", "UUIDs"] +git-tree-sha1 = "42ac5e523869a84eac9669eaceed9e4aa0e1587b" +uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +version = "1.1.4" + +[[Bzip2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2" +uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" +version = "1.0.8+0" + +[[Cairo_jll]] +deps = ["Artifacts", "Bzip2_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] +git-tree-sha1 = "f2202b55d816427cd385a9a4f3ffb226bee80f99" +uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" +version = "1.16.1+0" + +[[ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "4ce9393e871aca86cc457d9f66976c3da6902ea7" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.4.0" + +[[CodecBzip2]] +deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"] +git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7" +uuid = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd" +version = "0.7.2" + +[[CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.0" + +[[ColorSchemes]] +deps = ["ColorTypes", "Colors", "FixedPointNumbers", "Random"] +git-tree-sha1 = "9995eb3977fbf67b86d0a0a0508e83017ded03f2" +uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +version = "3.14.0" + +[[ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.0" + +[[Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.8" + +[[CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "4866e381721b30fac8dda4c8cb1d9db45c8d2994" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.37.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + +[[Contour]] +deps = ["StaticArrays"] +git-tree-sha1 = "9f02045d934dc030edad45944ea80dbd1f0ebea7" +uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" +version = "0.5.7" + +[[DataAPI]] +git-tree-sha1 = "bec2532f8adb82005476c141ec23e921fc20971b" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.8.0" + +[[DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.10" + +[[DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "3ed8fa7178a10d1cd0f1ca524f249ba6937490c0" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.3.0" + +[[Distances]] +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "9f46deb4d4ee4494ffb5a40a27a2aced67bdd838" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.4" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[Distributions]] +deps = ["ChainRulesCore", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"] +git-tree-sha1 = "f4efaa4b5157e0cdb8283ae0b5428bc9208436ed" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.16" + +[[DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.5" + +[[Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + +[[EarCut_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "3f3a2501fa7236e9b911e0f7a588c657e822bb6d" +uuid = "5ae413db-bbd1-5e63-b57d-d24a61df00f5" +version = "2.2.3+0" + +[[ExactOptimalTransport]] +deps = ["Distances", "Distributions", "LinearAlgebra", "MathOptInterface", "PDMats", "QuadGK", "SparseArrays", "StatsBase"] +git-tree-sha1 = "615791caeb11b3a62bea0527b3f03627352aa90d" +uuid = "24df6009-d856-477c-ac5c-91f668376b31" +version = "0.1.0" + +[[Expat_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "b3bfd02e98aedfa5cf885665493c5598c350cd2f" +uuid = "2e619515-83b5-522b-bb60-26c02a35a201" +version = "2.2.10+0" + +[[FFMPEG]] +deps = ["FFMPEG_jll"] +git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" +uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" +version = "0.4.1" + +[[FFMPEG_jll]] +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "Pkg", "Zlib_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] +git-tree-sha1 = "d8a578692e3077ac998b50c0217dfd67f21d1e5f" +uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" +version = "4.4.0+0" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "693210145367e7685d8604aee33d9bfb85db8b31" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.11.9" + +[[FiniteDiff]] +deps = ["ArrayInterface", "LinearAlgebra", "Requires", "SparseArrays", "StaticArrays"] +git-tree-sha1 = "8b3c09b56acaf3c0e581c66638b85c8650ee9dca" +uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" +version = "2.8.1" + +[[FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.4" + +[[Fontconfig_jll]] +deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Pkg", "Zlib_jll"] +git-tree-sha1 = "21efd19106a55620a188615da6d3d06cd7f6ee03" +uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" +version = "2.13.93+0" + +[[Formatting]] +deps = ["Printf"] +git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8" +uuid = "59287772-0a20-5a39-b81b-1366585eb4c0" +version = "0.4.2" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "b5e930ac60b613ef3406da6d4f42c35d8dc51419" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.19" + +[[FreeType2_jll]] +deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] +git-tree-sha1 = "87eb71354d8ec1a96d4a7636bd57a7347dde3ef9" +uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" +version = "2.10.4+0" + +[[FriBidi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "aa31987c2ba8704e23c6c8ba8a4f769d5d7e4f91" +uuid = "559328eb-81f9-559d-9380-de523a88c83c" +version = "1.0.10+0" + +[[FunctionWrappers]] +git-tree-sha1 = "241552bc2209f0fa068b6415b1942cc0aa486bcc" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.2" + +[[GLFW_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Pkg", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll"] +git-tree-sha1 = "dba1e8614e98949abfa60480b13653813d8f0157" +uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" +version = "3.3.5+0" + +[[GR]] +deps = ["Base64", "DelimitedFiles", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Printf", "Random", "Serialization", "Sockets", "Test", "UUIDs"] +git-tree-sha1 = "182da592436e287758ded5be6e32c406de3a2e47" +uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" +version = "0.58.1" + +[[GR_jll]] +deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Pkg", "Qt5Base_jll", "Zlib_jll", "libpng_jll"] +git-tree-sha1 = "ef49a187604f865f4708c90e3f431890724e9012" +uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" +version = "0.59.0+0" + +[[GeometryBasics]] +deps = ["EarCut_jll", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] +git-tree-sha1 = "58bcdf5ebc057b085e58d95c138725628dd7453c" +uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" +version = "0.4.1" + +[[Gettext_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] +git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" +uuid = "78b55507-aeef-58d4-861c-77aaff3498b1" +version = "0.21.0+0" + +[[Glib_jll]] +deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE_jll", "Pkg", "Zlib_jll"] +git-tree-sha1 = "7bf67e9a481712b3dbe9cb3dac852dc4b1162e02" +uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" +version = "2.68.3+0" + +[[Graphite2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "344bf40dcab1073aca04aa0df4fb092f920e4011" +uuid = "3b182d85-2403-5c21-9c21-1e1f0cc25472" +version = "1.3.14+0" + +[[Grisu]] +git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" +uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" +version = "1.0.2" + +[[HTTP]] +deps = ["Base64", "Dates", "IniFile", "Logging", "MbedTLS", "NetworkOptions", "Sockets", "URIs"] +git-tree-sha1 = "60ed5f1643927479f845b0135bb369b031b541fa" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "0.9.14" + +[[HarfBuzz_jll]] +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] +git-tree-sha1 = "8a954fed8ac097d5be04921d595f741115c1b2ad" +uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" +version = "2.8.1+0" + +[[IOCapture]] +deps = ["Logging", "Random"] +git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" +uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" +version = "0.2.2" + +[[IfElse]] +git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.0" + +[[IniFile]] +deps = ["Test"] +git-tree-sha1 = "098e4d2c533924c921f9f9847274f2ad89e018b8" +uuid = "83e8ac13-25f8-5344-8a64-a9f2b223428f" +version = "0.5.0" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[IrrationalConstants]] +git-tree-sha1 = "f76424439413893a832026ca355fe273e93bce94" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.0" + +[[IterTools]] +git-tree-sha1 = "05110a2ab1fc5f932622ffea2a003221f4782c18" +uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" +version = "1.3.0" + +[[IterativeSolvers]] +deps = ["LinearAlgebra", "Printf", "Random", "RecipesBase", "SparseArrays"] +git-tree-sha1 = "1a8c6237e78b714e901e406c096fc8a65528af7d" +uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" +version = "0.9.1" + +[[IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.3.0" + +[[JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "8076680b162ada2a031f707ac7b4953e30667a37" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.2" + +[[JSONSchema]] +deps = ["HTTP", "JSON", "URIs"] +git-tree-sha1 = "2f49f7f86762a0fbbeef84912265a1ae61c4ef80" +uuid = "7d188eb4-7ad8-530c-ae41-71a32a6d4692" +version = "0.3.4" + +[[JpegTurbo_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "d735490ac75c5cb9f1b00d8b5509c11984dc6943" +uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" +version = "2.1.0+0" + +[[LAME_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" +uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" +version = "3.100.1+0" + +[[LZO_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "e5b909bcf985c5e2605737d2ce278ed791b89be6" +uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" +version = "2.10.1+0" + +[[LaTeXStrings]] +git-tree-sha1 = "c7f1c695e06c01b95a67f0cd1d34994f3e7db104" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.2.1" + +[[Latexify]] +deps = ["Formatting", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "Printf", "Requires"] +git-tree-sha1 = "a4b12a1bd2ebade87891ab7e36fdbce582301a92" +uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" +version = "0.15.6" + +[[LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + +[[LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[Libffi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "761a393aeccd6aa92ec3515e428c26bf99575b3b" +uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" +version = "3.2.2+0" + +[[Libgcrypt_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll", "Pkg"] +git-tree-sha1 = "64613c82a59c120435c067c2b809fc61cf5166ae" +uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" +version = "1.8.7+0" + +[[Libglvnd_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll", "Xorg_libXext_jll"] +git-tree-sha1 = "7739f837d6447403596a75d19ed01fd08d6f56bf" +uuid = "7e76a0d4-f3c7-5321-8279-8d96eeed0f29" +version = "1.3.0+3" + +[[Libgpg_error_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "c333716e46366857753e273ce6a69ee0945a6db9" +uuid = "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" +version = "1.42.0+0" + +[[Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "42b62845d70a619f063a7da093d995ec8e15e778" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.16.1+1" + +[[Libmount_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "9c30530bf0effd46e15e0fdcf2b8636e78cbbd73" +uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" +version = "2.35.0+0" + +[[Libtiff_jll]] +deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Pkg", "Zlib_jll", "Zstd_jll"] +git-tree-sha1 = "340e257aada13f95f98ee352d316c3bed37c8ab9" +uuid = "89763e89-9b03-5906-acba-b20f662cd828" +version = "4.3.0+0" + +[[Libuuid_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "7f3efec06033682db852f8b3bc3c1d2b0a0ab066" +uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" +version = "2.36.0+0" + +[[LineSearches]] +deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] +git-tree-sha1 = "f27132e551e959b3667d8c93eae90973225032dd" +uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" +version = "7.1.1" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Literate]] +deps = ["Base64", "IOCapture", "JSON", "REPL"] +git-tree-sha1 = "bbebc3c14dbfbe76bfcbabf0937481ac84dc86ef" +uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +version = "2.9.3" + +[[LogExpFunctions]] +deps = ["ChainRulesCore", "DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "86197a8ecb06e222d66797b0c2d2f0cc7b69e42b" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.2" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "5a5bc6bf062f0f95e62d0fe0a2d99699fed82dd9" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.8" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[MathOptInterface]] +deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "JSON", "JSONSchema", "LinearAlgebra", "MutableArithmetics", "OrderedCollections", "SparseArrays", "Test", "Unicode"] +git-tree-sha1 = "575644e3c05b258250bb599e57cf73bbf1062901" +uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" +version = "0.9.22" + +[[MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "Random", "Sockets"] +git-tree-sha1 = "1c38e51c3d08ef2278062ebceade0e46cefc96fe" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.0.3" + +[[MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + +[[Measures]] +git-tree-sha1 = "e498ddeee6f9fdb4551ce855a46f54dbd900245f" +uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e" +version = "0.3.1" + +[[Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "2ca267b08821e86c5ef4376cffed98a46c2cb205" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.0.1" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[MutableArithmetics]] +deps = ["LinearAlgebra", "SparseArrays", "Test"] +git-tree-sha1 = "3927848ccebcc165952dc0d9ac9aa274a87bfe01" +uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" +version = "0.2.20" + +[[NLSolversBase]] +deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] +git-tree-sha1 = "144bab5b1443545bc4e791536c9f1eacb4eed06a" +uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" +version = "7.8.1" + +[[NNlib]] +deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "5203a4532ad28c44f82c76634ad621d7c90abcbd" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.7.29" + +[[NaNMath]] +git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.5" + +[[NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" + +[[Ogg_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "7937eda4681660b4d6aeeecc2f7e1c81c8ee4e2f" +uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" +version = "1.3.5+0" + +[[OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "15003dcb7d8db3c6c857fda14891a539a8f2705a" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "1.1.10+0" + +[[OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[Optim]] +deps = ["Compat", "FillArrays", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] +git-tree-sha1 = "7863df65dbb2a0fa8f85fcaf0a41167640d2ebed" +uuid = "429524aa-4258-5aef-a3af-852621145aeb" +version = "1.4.1" + +[[OptimalTransport]] +deps = ["ExactOptimalTransport", "IterativeSolvers", "LinearAlgebra", "LogExpFunctions", "NNlib", "Reexport"] +path = "../.." +uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" +version = "0.3.16" + +[[Opus_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" +uuid = "91d4177d-7536-5919-b921-800302f37372" +version = "1.3.2+0" + +[[OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[PCRE_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "b2a7af664e098055a7529ad1a900ded962bca488" +uuid = "2f80f16e-611a-54ab-bc61-aa92de5b98fc" +version = "8.44.0+0" + +[[PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] +git-tree-sha1 = "95a4038d1011dfdbde7cecd2ad0ac411e53ab1bc" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.10.1" + +[[Parameters]] +deps = ["OrderedCollections", "UnPack"] +git-tree-sha1 = "2276ac65f1e236e0a6ea70baff3f62ad4c625345" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.12.2" + +[[Parsers]] +deps = ["Dates"] +git-tree-sha1 = "bfd7d8c7fd87f04543810d9cbd3995972236ba1b" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "1.1.2" + +[[Pixman_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "b4f5d02549a10e20780a24fce72bea96b6329e29" +uuid = "30392449-352a-5448-841d-b1acce4e97dc" +version = "0.40.1+0" + +[[Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[PlotThemes]] +deps = ["PlotUtils", "Requires", "Statistics"] +git-tree-sha1 = "a3a964ce9dc7898193536002a6dd892b1b5a6f1d" +uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" +version = "2.0.1" + +[[PlotUtils]] +deps = ["ColorSchemes", "Colors", "Dates", "Printf", "Random", "Reexport", "Statistics"] +git-tree-sha1 = "9ff1c70190c1c30aebca35dc489f7411b256cd23" +uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" +version = "1.0.13" + +[[Plots]] +deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "GeometryBasics", "JSON", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "PlotThemes", "PlotUtils", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "UUIDs"] +git-tree-sha1 = "2dbafeadadcf7dadff20cd60046bba416b4912be" +uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +version = "1.21.3" + +[[PositiveFactorizations]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "17275485f373e6673f7e7f97051f703ed5b15b20" +uuid = "85a6dd25-e78a-55b7-8502-1745935b8125" +version = "0.2.4" + +[[Preferences]] +deps = ["TOML"] +git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.2.2" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[Qt5Base_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "xkbcommon_jll"] +git-tree-sha1 = "ad368663a5e20dbb8d6dc2fddeefe4dae0781ae8" +uuid = "ea2cea3b-5b76-57ae-a6ef-0a8af62496e1" +version = "5.15.3+0" + +[[QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "12fbe86da16df6679be7521dfb39fbc861e1dc7b" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.4.1" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[RecipesBase]] +git-tree-sha1 = "44a75aa7a527910ee3d1751d1f0e4148698add9e" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.1.2" + +[[RecipesPipeline]] +deps = ["Dates", "NaNMath", "PlotUtils", "RecipesBase"] +git-tree-sha1 = "d4491becdc53580c6dadb0f6249f90caae888554" +uuid = "01d81517-befc-4cb6-b9ec-a95719d0359c" +version = "0.4.0" + +[[Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.1.3" + +[[ReverseDiff]] +deps = ["DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] +git-tree-sha1 = "63ee24ea0689157a1113dbdab10c6cb011d519c4" +uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +version = "1.9.0" + +[[Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.7.0" + +[[Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.3.0+0" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Scratch]] +deps = ["Dates"] +git-tree-sha1 = "0b4b7f1393cff97c33891da2a0bf69c6ed241fda" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.1.0" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Showoff]] +deps = ["Dates", "Grisu"] +git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" +uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" +version = "1.0.3" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.0.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] +git-tree-sha1 = "a322a9493e49c5f3a10b50df3aedaf1cdb3244b7" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "1.6.1" + +[[Static]] +deps = ["IfElse"] +git-tree-sha1 = "a8f30abc7c64a39d389680b74e749cf33f872a70" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.3.3" + +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "3240808c6d463ac46f1c1cd7638375cd22abbccb" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.2.12" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsAPI]] +git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.0.0" + +[[StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "8cbbc098554648c84f79a463c9ff0fd277144b6c" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.10" + +[[StatsFuns]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "46d7ccc7104860c38b11966dd1f72ff042f382e4" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "0.9.10" + +[[StructArrays]] +deps = ["Adapt", "DataAPI", "StaticArrays", "Tables"] +git-tree-sha1 = "f41020e84127781af49fc12b7e92becd7f5dd0ba" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.2" + +[[SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] +git-tree-sha1 = "368d04a820fe069f9080ff1b432147a6203c3c89" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.5.1" + +[[Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + +[[Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.6" + +[[URIs]] +git-tree-sha1 = "97bbe755a53fe859669cd907f2d96aee8d2c1355" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.3.0" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[UnPack]] +git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" +uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +version = "1.0.2" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[Wayland_jll]] +deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"] +git-tree-sha1 = "3e61f0b86f90dacb0bc0e73a0c5a83f6a8636e23" +uuid = "a2964d1f-97da-50d4-b82a-358c7fce9d89" +version = "1.19.0+0" + +[[Wayland_protocols_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Wayland_jll"] +git-tree-sha1 = "2839f1c1296940218e35df0bbb220f2a79686670" +uuid = "2381bf8a-dfd0-557d-9999-79630e7b1b91" +version = "1.18.0+4" + +[[XML2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "Zlib_jll"] +git-tree-sha1 = "1acf5bdf07aa0907e0a37d3718bb88d4b687b74a" +uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" +version = "2.9.12+0" + +[[XSLT_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"] +git-tree-sha1 = "91844873c4085240b95e795f692c4cec4d805f8a" +uuid = "aed1982a-8fda-507f-9586-7b0439959a61" +version = "1.1.34+0" + +[[Xorg_libX11_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] +git-tree-sha1 = "5be649d550f3f4b95308bf0183b82e2582876527" +uuid = "4f6342f7-b3d2-589e-9d20-edeb45f2b2bc" +version = "1.6.9+4" + +[[Xorg_libXau_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "4e490d5c960c314f33885790ed410ff3a94ce67e" +uuid = "0c0b7dd1-d40b-584c-a123-a41640f87eec" +version = "1.0.9+4" + +[[Xorg_libXcursor_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXfixes_jll", "Xorg_libXrender_jll"] +git-tree-sha1 = "12e0eb3bc634fa2080c1c37fccf56f7c22989afd" +uuid = "935fb764-8cf2-53bf-bb30-45bb1f8bf724" +version = "1.2.0+4" + +[[Xorg_libXdmcp_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "4fe47bd2247248125c428978740e18a681372dd4" +uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" +version = "1.1.3+4" + +[[Xorg_libXext_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] +git-tree-sha1 = "b7c0aa8c376b31e4852b360222848637f481f8c3" +uuid = "1082639a-0dae-5f34-9b06-72781eeb8cb3" +version = "1.3.4+4" + +[[Xorg_libXfixes_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] +git-tree-sha1 = "0e0dc7431e7a0587559f9294aeec269471c991a4" +uuid = "d091e8ba-531a-589c-9de9-94069b037ed8" +version = "5.0.3+4" + +[[Xorg_libXi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll", "Xorg_libXfixes_jll"] +git-tree-sha1 = "89b52bc2160aadc84d707093930ef0bffa641246" +uuid = "a51aa0fd-4e3c-5386-b890-e753decda492" +version = "1.7.10+4" + +[[Xorg_libXinerama_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll"] +git-tree-sha1 = "26be8b1c342929259317d8b9f7b53bf2bb73b123" +uuid = "d1454406-59df-5ea1-beac-c340f2130bc3" +version = "1.1.4+4" + +[[Xorg_libXrandr_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll"] +git-tree-sha1 = "34cea83cb726fb58f325887bf0612c6b3fb17631" +uuid = "ec84b674-ba8e-5d96-8ba1-2a689ba10484" +version = "1.5.2+4" + +[[Xorg_libXrender_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] +git-tree-sha1 = "19560f30fd49f4d4efbe7002a1037f8c43d43b96" +uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" +version = "0.9.10+4" + +[[Xorg_libpthread_stubs_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "6783737e45d3c59a4a4c4091f5f88cdcf0908cbb" +uuid = "14d82f49-176c-5ed1-bb49-ad3f5cbd8c74" +version = "0.1.0+3" + +[[Xorg_libxcb_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] +git-tree-sha1 = "daf17f441228e7a3833846cd048892861cff16d6" +uuid = "c7cfdc94-dc32-55de-ac96-5a1b8d977c5b" +version = "1.13.0+3" + +[[Xorg_libxkbfile_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] +git-tree-sha1 = "926af861744212db0eb001d9e40b5d16292080b2" +uuid = "cc61e674-0454-545c-8b26-ed2c68acab7a" +version = "1.1.0+4" + +[[Xorg_xcb_util_image_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "0fab0a40349ba1cba2c1da699243396ff8e94b97" +uuid = "12413925-8142-5f55-bb0e-6d7ca50bb09b" +version = "0.4.0+1" + +[[Xorg_xcb_util_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libxcb_jll"] +git-tree-sha1 = "e7fd7b2881fa2eaa72717420894d3938177862d1" +uuid = "2def613f-5ad1-5310-b15b-b15d46f528f5" +version = "0.4.0+1" + +[[Xorg_xcb_util_keysyms_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "d1151e2c45a544f32441a567d1690e701ec89b00" +uuid = "975044d2-76e6-5fbe-bf08-97ce7c6574c7" +version = "0.4.0+1" + +[[Xorg_xcb_util_renderutil_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "dfd7a8f38d4613b6a575253b3174dd991ca6183e" +uuid = "0d47668e-0667-5a69-a72c-f761630bfb7e" +version = "0.3.9+1" + +[[Xorg_xcb_util_wm_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "e78d10aab01a4a154142c5006ed44fd9e8e31b67" +uuid = "c22f9ab0-d5fe-5066-847c-f4bb1cd4e361" +version = "0.4.1+1" + +[[Xorg_xkbcomp_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libxkbfile_jll"] +git-tree-sha1 = "4bcbf660f6c2e714f87e960a171b119d06ee163b" +uuid = "35661453-b289-5fab-8a00-3d9160c6a3a4" +version = "1.4.2+4" + +[[Xorg_xkeyboard_config_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xkbcomp_jll"] +git-tree-sha1 = "5c8424f8a67c3f2209646d4425f3d415fee5931d" +uuid = "33bec58e-1273-512f-9401-5d533626f822" +version = "2.27.0+4" + +[[Xorg_xtrans_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "79c31e7844f6ecf779705fbc12146eb190b7d845" +uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" +version = "1.4.0+3" + +[[Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[Zstd_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "cc4bf3fdde8b7e3e9fa0351bdeedba1cf3b7f6e6" +uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" +version = "1.5.0+0" + +[[libass_jll]] +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] +git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47" +uuid = "0ac62f75-1d6f-5e53-bd7c-93b484bb37c0" +version = "0.15.1+0" + +[[libfdk_aac_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" +uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" +version = "2.0.2+0" + +[[libpng_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] +git-tree-sha1 = "94d180a6d2b5e55e447e2d27a29ed04fe79eb30c" +uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" +version = "1.6.38+0" + +[[libvorbis_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] +git-tree-sha1 = "c45f4e40e7aafe9d086379e5578947ec8b95a8fb" +uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a" +version = "1.3.7+0" + +[[nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" + +[[x264_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "4fea590b89e6ec504593146bf8b988b2c00922b2" +uuid = "1270edf5-f2f9-52d2-97e9-ab00b5d0237a" +version = "2021.5.5+0" + +[[x265_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "ee567a171cce03570d77ad3a43e90218e38937a9" +uuid = "dfaa095f-4041-5dcd-9319-2fabd8486b76" +version = "3.5.0+0" + +[[xkbcommon_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Wayland_jll", "Wayland_protocols_jll", "Xorg_libxcb_jll", "Xorg_xkeyboard_config_jll"] +git-tree-sha1 = "ece2350174195bb31de1a63bea3a41ae1aa593b6" +uuid = "d8fb68d0-12a3-5cfd-a85a-d49703b185fd" +version = "0.9.1+5" diff --git a/examples/empirical_sinkhorn_div/Project.toml b/examples/empirical_sinkhorn_div/Project.toml new file mode 100644 index 00000000..91327b00 --- /dev/null +++ b/examples/empirical_sinkhorn_div/Project.toml @@ -0,0 +1,20 @@ +[deps] +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +Optim = "429524aa-4258-5aef-a3af-852621145aeb" +OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + +[compat] +Distances = "0.10" +Distributions = "0.25" +Literate = "2" +Optim = "1.4" +OptimalTransport = "0.3" +Plots = "1" +ReverseDiff = "1" +julia = "1" diff --git a/examples/empirical_sinkhorn_div/script.jl b/examples/empirical_sinkhorn_div/script.jl new file mode 100644 index 00000000..8cd30def --- /dev/null +++ b/examples/empirical_sinkhorn_div/script.jl @@ -0,0 +1,103 @@ +# # Sinkhorn divergences +# +# In this tutorial we provide a minimal example for using the Sinkhorn divergence as a loss function [FSV+19] on empirical distributions. +# [FSV+19]: Feydy, Jean, et al. "Interpolating between optimal transport and MMD using Sinkhorn divergences." The 22nd International Conference on Artificial Intelligence and Statistics. PMLR, 2019. +# +# While entropy-regularised optimal transport $\operatorname{OT}_{\varepsilon}(\cdot, \cdot)$ is commonly used as a loss function, it suffers from a problem of *bias*: namely that $\nu \mapsto \operatorname{OT}_{\varepsilon}(\mu, \nu)$ is *not* minimised at $\nu = \mu$. +# +# A fix to this problem is proposed by Genevay et al [GPC18] and subsequently Feydy et al. [FSV+19], which introduce the *Sinkhorn divergence* between two measures $\mu$ and $\nu$, defined as +# ```math +# \operatorname{S}_{\varepsilon}(\mu, \nu) = \operatorname{OT}_{\varepsilon}(\mu, \nu) - \frac{1}{2} \operatorname{OT}_{\varepsilon}(\mu, \mu) - \frac{1}{2} \operatorname{OT}_{\varepsilon}(\nu, \nu). +# ``` +# In the above, we have followed the convention taken by Feydy et al. and included the entropic regularisation in the definition of $\operatorname{OT}_\varepsilon$. +# [GPC18]: Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 +# [FSV+19]: Feydy, Jean, et al. "Interpolating between optimal transport and MMD using Sinkhorn divergences." The 22nd International Conference on Artificial Intelligence and Statistics. PMLR, 2019. +# +# Like the Sinkhorn loss, the Sinkhorn divergence is smooth and convex in both of its arguments. However, the Sinkhorn divergence is unbiased -- i.e. $S_{\varepsilon}(\mu, \nu) = 0$ iff $\mu = \nu$. +# +# Unlike previous examples, here we demonstrate a learning problem similar to Figure 1 of Feydy et al. over *empirical measures*, i.e. measures that have the form $\mu = \frac{1}{N} \sum_{i = 1}^{N} \delta_{x_i}$ where $\delta_x$ is the Dirac delta function at $x$. +# +# We first load packages. +# +using OptimalTransport +using ReverseDiff +using Distributions +using LinearAlgebra +using Distances +using Plots +using Logging +using Optim + +# As a ground truth distribution, we set $\rho$ to be a Gaussian mixture model with `k = 3` components, equally spaced around a circle, and sample an empirical distribution of size $N$, $\mu \sim \rho$. + +k = 3 +d = 2 +θ = π * range(0, 2(1 - 1 / k); length=k) +μ = 2 * hcat(sin.(θ), cos.(θ)) +ρ = MixtureModel(MvNormal[MvNormal(x, 0.25 * I) for x in eachrow(μ)]) +N = 100 +μ_spt = rand(ρ, N)' +scatter(μ_spt[:, 1], μ_spt[:, 2]; markeralpha=0.25, title=raw"$\mu$") + +# Now, suppose we want to approximate $\mu$ with another empirical distribution $\nu$, i.e. we want to minimise $\nu \mapsto \operatorname{S}_{\varepsilon}(\mu, \nu)$ over possible empirical distributions $\nu$. In this case we have $M$ particles in $\nu$, which we initialise following a Gaussian distribution. +M = 100 +ν_spt = rand(M, d); +# Assign uniform weights to the Diracs in each empirical distribution. +μ = fill(1 / N, N) +ν = fill(1 / M, M); + +# Since $\mu$ is fixed, we pre-compute the cost matrix $C_{\mu}$. +C_μ = pairwise(SqEuclidean(), μ_spt'); + +# Define the loss function to minimise, where `x` specifies the locations of the Diracs in $\nu$. +# +# We will be using `ReverseDiff` with a precompiled tape. For this reason, we need the Sinkhorn algorithm to perform a fixed number of (e.g. 50) iterations. +# Currently, this can be achieved by setting `maxiter = 50` and `atol = rtol = 0` in calls to `sinkhorn` and `sinkhorn_divergence`. +function loss(x, ε) + C_μν = pairwise(SqEuclidean(), μ_spt', x') + C_ν = pairwise(SqEuclidean(), x') + return sinkhorn_divergence( + μ, ν, C_μν, C_μ, C_ν, ε; maxiter=50, atol=rtol = 0, regularization=true + ) +end +# Set entropy regularisation parameter +ε = 1.0; + +# Use ReverseDiff with a precompiled tape and Optim.jl to minimise $\nu \mapsto \operatorname{S}_{\varepsilon}(\mu, \nu)$. Note that this is problem is *not* convex, so we find a local minimium. +const loss_tape = ReverseDiff.GradientTape(x -> loss(x, ε), ν_spt) +const compiled_loss_tape = ReverseDiff.compile(loss_tape) +opt = with_logger(SimpleLogger(stderr, Logging.Error)) do + optimize( + x -> loss(x, ε), + (∇, x) -> ReverseDiff.gradient!(∇, compiled_loss_tape, x), + ν_spt, + GradientDescent(), + Optim.Options(; iterations=10, g_tol=1e-6, show_trace=true), + ) +end +ν_opt = Optim.minimizer(opt) +plt1 = scatter(μ_spt[:, 1], μ_spt[:, 2]; markeralpha=0.25, title="Sinkhorn divergence") +scatter!(plt1, ν_opt[:, 1], ν_opt[:, 2]); + +# For comparison, let us do the same computation again, but this time we want to minimise $\nu \mapsto \operatorname{OT}_{\varepsilon}(\mu, \nu)$. +function loss_biased(x, ε) + C_μν = pairwise(SqEuclidean(), μ_spt', x') + return sinkhorn2(μ, ν, C_μν, ε; maxiter=50, atol=rtol = 0, regularization=true) +end +const loss_biased_tape = ReverseDiff.GradientTape(x -> loss_biased(x, ε), ν_spt) +const compiled_loss_biased_tape = ReverseDiff.compile(loss_biased_tape) +opt_biased = with_logger(SimpleLogger(stderr, Logging.Error)) do + optimize( + x -> loss_biased(x, ε), + (∇, x) -> ReverseDiff.gradient!(∇, compiled_loss_biased_tape, x), + ν_spt, + GradientDescent(), + Optim.Options(; iterations=10, g_tol=1e-6, show_trace=true), + ) +end +ν_opt_biased = Optim.minimizer(opt_biased) +plt2 = scatter(μ_spt[:, 1], μ_spt[:, 2]; markeralpha=0.25, title="Sinkhorn loss") +scatter!(plt2, ν_opt_biased[:, 1], ν_opt_biased[:, 2]); + +# Observe that the Sinkhorn divergence results in $\nu$ that matches $\mu$ quite well, while entropy-regularised transport is biased to producing $\nu$ that seems to concentrate around the mean of each Gaussian component. +plot(plt1, plt2) diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index ea13f6c3..cc1a9e73 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -21,12 +21,15 @@ export QuadraticOTNewton export sinkhorn, sinkhorn2 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter export sinkhorn_unbalanced, sinkhorn_unbalanced2 +export sinkhorn_divergence export quadreg include("utils.jl") include("entropic/sinkhorn.jl") +include("entropic/sinkhorn_divergence.jl") include("entropic/sinkhorn_gibbs.jl") +include("entropic/symmetric.jl") include("entropic/sinkhorn_stabilized.jl") include("entropic/sinkhorn_epsscaling.jl") include("entropic/sinkhorn_unbalanced.jl") diff --git a/src/entropic/sinkhorn.jl b/src/entropic/sinkhorn.jl index 90e7a8db..22c13942 100644 --- a/src/entropic/sinkhorn.jl +++ b/src/entropic/sinkhorn.jl @@ -120,6 +120,19 @@ function check_convergence(solver::SinkhornSolver) ) end +function sinkhorn_plan(u, v, K) + return K .* add_singleton(u, Val(2)) .* add_singleton(v, Val(1)) +end + +# dual objective +function sinkhorn_dual_objective(u, v, Kv, K, ε) + # return ε * (dot_vecwise(log.(u), μ) .+ dot_vecwise(log.(v), ν)) + return ε * ( + dot_vecwise(LogExpFunctions.xlogx.(u), Kv) + + dot_vecwise(LogExpFunctions.xlogx.(v), K' * u) + ) +end + # API """ @@ -165,11 +178,21 @@ function sinkhorn(μ, ν, C, ε, alg::Sinkhorn; kwargs...) solve!(solver) # compute optimal transport plan - γ = plan(solver) + γ = sinkhorn_plan(solver) return γ end +function sinkhorn_cost_from_plan(γ, C, ε; regularization=false) + cost = if regularization + dot_matwise(γ, C) .+ + ε .* reshape(sum(LogExpFunctions.xlogx, γ; dims=(1, 2)), size(γ)[3:end]) + else + dot_matwise(γ, C) + end + return cost +end + """ sinkhorn2( μ, ν, C, ε, alg=SinkhornGibbs(); regularization=false, plan=nothing, kwargs... @@ -200,12 +223,6 @@ function sinkhorn2(μ, ν, C, ε, alg::Sinkhorn; regularization=false, plan=noth ) plan end - cost = if regularization - dot_matwise(γ, C) .+ - ε .* reshape(sum(LogExpFunctions.xlogx, γ; dims=(1, 2)), size(γ)[3:end]) - else - dot_matwise(γ, C) - end - + cost = sinkhorn_cost_from_plan(γ, C, ε; regularization=regularization) return cost end diff --git a/src/entropic/sinkhorn_divergence.jl b/src/entropic/sinkhorn_divergence.jl new file mode 100644 index 00000000..d162f879 --- /dev/null +++ b/src/entropic/sinkhorn_divergence.jl @@ -0,0 +1,96 @@ +struct SinkhornDivergence{A<:Sinkhorn,B<:Sinkhorn,C<:Sinkhorn} + algμν::A + algμμ::B + algνν::C +end + +""" + sinkhorn_divergence( + μ::AbstractVecOrMat, + ν::AbstractVecOrMat, + C, + ε, + alg::SinkhornDivergence=SinkhornDivergence( + SinkhornGibbs(), SymmetricSinkhornGibbs(), SymmetricSinkhornGibbs() + ); + kwargs..., + ) + +Compute the Sinkhorn Divergence between finite discrete +measures `μ` and `ν` with respect to a common cost matrix `C`, +entropic regularization parameter `ε` and algorithm `alg`. + +In the default case where `regularization = false`, the Sinkhorn Divergence is that of [^GPC18] and is computed as +```math +\\operatorname{S}_{ε}(μ,ν) := \\operatorname{W}_{ε}(μ,ν) +- \\frac{1}{2}(\\operatorname{W}_{ε}(μ,μ) + \\operatorname{W}_{ε}(ν,ν)), +``` +and ``\\operatorname{W}_{ε}`` is defined as +```math +\\operatorname{W}_{ε}(μ, ν) = \\langle C, γ^\\star \\rangle, +``` +where ``γ^\\star`` is the entropy-regularised transport plan between `μ` and `ν`. +For `regularization = true`, the Sinkhorn Divergence is that of [^FeydyP19] and is computed as above +where ``\\operatorname{W}_{ε}`` is replaced by ``\\operatorname{OT}_{ε}``, the entropy-regularised optimal transport +cost with regulariser penalty. + +The default algorithm for computing the term ``\\operatorname{W}_{ε}(μ, ν)`` is the `SinkhornGibbs` algorithm. +For the terms ``\\operatorname{W}_{ε}(μ, μ)`` and ``\\operatorname{W}_{ε}(ν, ν)``, the symmetric fixed point iteration of [^FeydyP19] is used. +Alternatively, a pre-computed optimal transport `plan` between `μ` and `ν` may be provided. + +[^GPC18]: Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 +[^FeydyP19]: Jean Feydy, Thibault Séjourné, François-Xavier Vialard, Shun-ichi Amari, Alain Trouvé, and Gabriel Peyré. Interpolating between optimal transport and mmd using sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics, pages 2681–2690. PMLR, 2019. +See also: [`sinkhorn2`](@ref) +""" +function sinkhorn_divergence( + μ::AbstractVecOrMat, + ν::AbstractVecOrMat, + C, + ε, + alg::SinkhornDivergence=SinkhornDivergence( + SinkhornGibbs(), SymmetricSinkhornGibbs(), SymmetricSinkhornGibbs() + ); + kwargs..., +) + return sinkhorn_divergence(μ, ν, C, C, C, ε, alg; kwargs...) +end + +""" + sinkhorn_divergence( + μ, + ν, + Cμν, + Cμ, + Cν, + ε, + alg::SinkhornDivergence=SinkhornDivergence( + SinkhornGibbs(), SymmetricSinkhornGibbs(), SymmetricSinkhornGibbs() + ); + kwargs..., + ) + +Compute the Sinkhorn Divergence between finite discrete +measures `μ` and `ν` with respect to the precomputed cost matrices `Cμν`, +`Cμμ` and `Cνν`, entropic regularization parameter `ε` and algorithm `alg`. + +A pre-computed optimal transport `plan` between `μ` and `ν` may be provided. + +See also: [`sinkhorn2`](@ref), [`sinkhorn_divergence`](@ref) +""" +function sinkhorn_divergence( + μ, + ν, + Cμν, + Cμ, + Cν, + ε, + alg::SinkhornDivergence=SinkhornDivergence( + SinkhornGibbs(), SymmetricSinkhornGibbs(), SymmetricSinkhornGibbs() + ); + kwargs..., +) + OTμν = sinkhorn2(μ, ν, Cμν, ε, alg.algμν; kwargs...) + OTμμ = sinkhorn2(μ, Cμ, ε, alg.algμμ; kwargs...) + OTνν = sinkhorn2(ν, Cν, ε, alg.algνν; kwargs...) + return max.(0, OTμν .- (OTμμ .+ OTνν) / 2) +end diff --git a/src/entropic/sinkhorn_epsscaling.jl b/src/entropic/sinkhorn_epsscaling.jl index ec5070f7..5820ddb0 100644 --- a/src/entropic/sinkhorn_epsscaling.jl +++ b/src/entropic/sinkhorn_epsscaling.jl @@ -59,7 +59,7 @@ function sinkhorn(μ, ν, C, ε, alg::SinkhornEpsilonScaling; kwargs...) solve!(solver) # compute final plan - γ = plan(solver) + γ = sinkhorn_plan(solver) return γ end diff --git a/src/entropic/sinkhorn_gibbs.jl b/src/entropic/sinkhorn_gibbs.jl index 51a199df..00cdfe6c 100644 --- a/src/entropic/sinkhorn_gibbs.jl +++ b/src/entropic/sinkhorn_gibbs.jl @@ -119,12 +119,40 @@ function sinkhorn2( ) end +# spceialised sinkhorn2 for SinkhornGibbs +function sinkhorn2( + μ, ν, C, ε, alg::SinkhornGibbs; regularization=false, plan=nothing, kwargs... +) + cost = if regularization && plan === nothing + # special case where we can take advantage of dual objective formula + # build solver + solver = build_solver(μ, ν, C, ε, alg; kwargs...) + # perform Sinkhorn algorithm + solve!(solver) + # return loss + cache = solver.cache + sinkhorn_dual_objective(cache.u, cache.v, cache.Kv, cache.K, solver.eps) + else + γ = if plan === nothing + sinkhorn(μ, ν, C, ε, alg; kwargs...) + else + # check dimensions + checksize(μ, ν, C) + size(plan) == size(C) || error( + "optimal transport plan `plan` and cost matrix `C` must be of the same size", + ) + plan + end + sinkhorn_cost_from_plan(γ, C, ε; regularization=regularization) + end + return cost +end + # interface prestep!(::SinkhornSolver{SinkhornGibbs}, ::Int) = nothing -function plan(solver::SinkhornSolver{SinkhornGibbs}) +function sinkhorn_plan(solver::SinkhornSolver{SinkhornGibbs}) cache = solver.cache - γ = cache.K .* add_singleton(cache.u, Val(2)) .* add_singleton(cache.v, Val(1)) - return γ + return sinkhorn_plan(cache.u, cache.v, cache.K) end diff --git a/src/entropic/sinkhorn_solve.jl b/src/entropic/sinkhorn_solve.jl index 10325639..e41c24b1 100644 --- a/src/entropic/sinkhorn_solve.jl +++ b/src/entropic/sinkhorn_solve.jl @@ -55,7 +55,9 @@ function check_convergence( end # Common solve! operation -function solve!(solver::Union{SinkhornSolver,SinkhornBarycenterSolver}) +function solve!( + solver::Union{SinkhornSolver,SinkhornBarycenterSolver,SymmetricSinkhornSolver} +) # unpack solver atol = solver.atol rtol = solver.rtol diff --git a/src/entropic/sinkhorn_stabilized.jl b/src/entropic/sinkhorn_stabilized.jl index ccced224..861cb72a 100644 --- a/src/entropic/sinkhorn_stabilized.jl +++ b/src/entropic/sinkhorn_stabilized.jl @@ -133,7 +133,7 @@ function update_K!(solver::SinkhornSolver{<:SinkhornStabilized}) end # obtain plan -function plan(solver::SinkhornSolver{<:SinkhornStabilized}) +function sinkhorn_plan(solver::SinkhornSolver{<:SinkhornStabilized}) absorb!(solver) return copy(solver.cache.K) end diff --git a/src/entropic/symmetric.jl b/src/entropic/symmetric.jl new file mode 100644 index 00000000..92a222fc --- /dev/null +++ b/src/entropic/symmetric.jl @@ -0,0 +1,148 @@ +struct SymmetricSinkhornSolver{A<:Sinkhorn,M,CT,E<:Real,T<:Real,R<:Real,C1,C2} + source::M + C::CT + eps::E + alg::A + atol::T + rtol::R + maxiter::Int + check_convergence::Int + cache::C1 + convergence_cache::C2 +end + +struct SymmetricSinkhornGibbs <: Sinkhorn end + +Base.show(io::IO, ::SymmetricSinkhornGibbs) = print(io, "Symmetric Sinkhorn algorithm") + +struct SymmetricSinkhornGibbsCache{U,KT} + u::U + K::KT + Kv::U +end + +function build_cache( + ::Type{T}, + ::SymmetricSinkhornGibbs, + size2::Tuple, + μ::AbstractVecOrMat, + C::AbstractMatrix, + ε::Real, +) where {T} + K = similar(C, T) + @. K = exp(-C / ε) + + u = similar(μ, T, size(μ, 1), size2...) + fill!(u, one(T)) + + Kv = similar(u) + + return SymmetricSinkhornGibbsCache(u, K, Kv) +end + +function check_convergence(solver::SymmetricSinkhornSolver) + A_batched_mul_B!(solver.cache.Kv, solver.cache.K, solver.cache.u) + return OptimalTransport.check_convergence( + solver.source, + solver.cache.u, + solver.cache.Kv, + solver.convergence_cache, + solver.atol, + solver.rtol, + ) +end + +function build_solver( + μ::AbstractVecOrMat, + C::AbstractMatrix, + ε::Real, + alg::SymmetricSinkhornGibbs; + atol=nothing, + rtol=nothing, + check_convergence=1, + maxiter::Int=25, +) + size2 = size(μ)[2:end] + + # compute type + T = float(Base.promote_eltype(μ, one(eltype(C)) / ε)) + + # build caches using SinkhornGibbsCache struct (since there is no dependence on ν) + cache = build_cache(T, alg, size2, μ, C, ε) + convergence_cache = build_convergence_cache(T, size2, μ) + + # set tolerances + _atol = atol === nothing ? 0 : atol + _rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol + + # create solver + solver = SymmetricSinkhornSolver( + μ, C, ε, alg, _atol, _rtol, maxiter, check_convergence, cache, convergence_cache + ) + return solver +end + +function init_step!(solver::SymmetricSinkhornSolver{SymmetricSinkhornGibbs}) + source = solver.source + cache = solver.cache + u = cache.u + K = cache.K + Kv = cache.Kv + return A_batched_mul_B!(Kv, K, u) +end + +prestep!(::SymmetricSinkhornSolver{SymmetricSinkhornGibbs}, ::Int) = nothing + +function step!(solver::SymmetricSinkhornSolver{SymmetricSinkhornGibbs}, iter::Int) + source = solver.source + cache = solver.cache + u = cache.u + K = cache.K + Kv = cache.Kv + @. u = sqrt(source * u / Kv) + return A_batched_mul_B!(Kv, K, u) +end + +function sinkhorn_plan(solver::SymmetricSinkhornSolver{SymmetricSinkhornGibbs}) + cache = solver.cache + return sinkhorn_plan(cache.u, cache.u, cache.K) +end + +function sinkhorn(μ, C, ε, alg::SymmetricSinkhornGibbs; kwargs...) + # build solver + solver = build_solver(μ, C, ε, alg; kwargs...) + + # perform Sinkhorn algorithm + solve!(solver) + + # compute optimal transport plan + γ = sinkhorn_plan(solver) + + return γ +end + +function sinkhorn2( + μ, C, ε, alg::SymmetricSinkhornGibbs; regularization=false, plan=nothing, kwargs... +) + cost = if regularization && plan === nothing + # special case where we can take advantage of dual objective formula + # build solver + solver = build_solver(μ, C, ε, alg; kwargs...) + # perform Sinkhorn algorithm + solve!(solver) + # return loss + cache = solver.cache + sinkhorn_dual_objective(cache.u, cache.u, cache.Kv, cache.K, solver.eps) + else + γ = if plan === nothing + sinkhorn(μ, C, ε, alg; kwargs...) + else + size(plan) == size(C) || error( + "optimal transport plan `plan` and cost matrix `C` must be of the same size", + ) + plan + end + sinkhorn_cost_from_plan(γ, C, ε; regularization=regularization) + end + return cost +end diff --git a/test/entropic/sinkhorn_divergence.jl b/test/entropic/sinkhorn_divergence.jl new file mode 100644 index 00000000..16bc9d13 --- /dev/null +++ b/test/entropic/sinkhorn_divergence.jl @@ -0,0 +1,130 @@ +using OptimalTransport + +using Distances +using ForwardDiff +using ReverseDiff +using LogExpFunctions +using PythonOT: PythonOT +using StatsBase +using LinearAlgebra +using Random +using Test + +const POT = PythonOT + +Random.seed!(100) + +@testset "sinkhorn_divergence.jl" begin + @testset "fixed_support" begin + # size of problem + N = 250 + # number of target measures + x = range(-1, 1; length=N) + C = pairwise(SqEuclidean(), x) + f(x; μ, σ) = exp(-((x - μ) / σ)^2) + # regularization parameter + ε = 0.05 + @testset "basic" begin + μ = normalize!(f.(x; μ=0, σ=0.5), 1) + M = 100 + + ν_all = [normalize!(f.(x; μ=y, σ=0.5), 1) for y in range(-1, 1; length=M)] + + for reg in (true, false) + loss = map(ν -> sinkhorn_divergence(μ, ν, C, ε; regularization=reg), ν_all) + loss_ = map( + ν -> + sinkhorn2(μ, ν, C, ε; regularization=reg) - + ( + sinkhorn2(μ, μ, C, ε; regularization=reg) + + sinkhorn2(ν, ν, C, ε; regularization=reg) + ) / 2, + ν_all, + ) + + @test loss ≈ loss_ rtol = 1e-6 + @test all(loss .≥ 0) + @test sinkhorn_divergence(μ, μ, C, ε) ≈ 0 atol = 1e-9 + end + end + @testset "batch" begin + M = 10 + μ = hcat([normalize!(f.(x; μ=randn(), σ=0.5), 1) for _ in 1:M]...) + ν = hcat([normalize!(f.(x; μ=randn(), σ=0.5), 1) for _ in 1:M]...) + for reg in (true, false) + loss_batch = sinkhorn_divergence(μ, ν, C, ε; regularization=reg) + @test loss_batch ≈ [ + sinkhorn_divergence(x, y, C, ε; regularization=reg) for + (x, y) in zip(eachcol(μ), eachcol(ν)) + ] + loss_batch_μ = sinkhorn_divergence(μ, ν[:, 1], C, ε; regularization=reg) + @test loss_batch_μ ≈ [ + sinkhorn_divergence(x, ν[:, 1], C, ε; regularization=reg) for + x in eachcol(μ) + ] + loss_batch_ν = sinkhorn_divergence(μ[:, 1], ν, C, ε; regularization=reg) + @test loss_batch_ν ≈ [ + sinkhorn_divergence(μ[:, 1], y, C, ε; regularization=reg) for + y in eachcol(ν) + ] + end + end + @testset "AD" begin + ε = 0.05 + μ = normalize!(f.(x; μ=-0.5, σ=0.5), 1) + ν = normalize!(f.(x; μ=0.5, σ=0.5), 1) + for Diff in [ForwardDiff, ReverseDiff] + for reg in (true, false) + ∇ = Diff.gradient(log.(ν)) do xs + sinkhorn_divergence(μ, softmax(xs), C, ε; regularization=reg) + end + @test size(∇) == size(ν) + ∇ = Diff.gradient(log.(μ)) do xs + sinkhorn_divergence(μ, softmax(xs), C, ε; regularization=reg) + end + @test norm(∇, Inf) ≈ 0 atol = 1e-9 # Sinkhorn divergence has minimum at SD(μ, μ) + end + end + end + end + @testset "empirical" begin + N = 50 + M = 64 + d = 2 + μ_spt = randn(N, d) + ν_spt = 1.5randn(M, d) + μ = fill(1 / N, N) + ν = fill(1 / M, M) + Cμν = pairwise(SqEuclidean(), μ_spt', ν_spt'; dims=2) + Cμ = pairwise(SqEuclidean(), μ_spt'; dims=2) + Cν = pairwise(SqEuclidean(), ν_spt'; dims=2) + ε = 0.1 * max(mean(Cμν), mean(Cμ), mean(Cν)) + + @testset "basic" begin + for reg in (true, false) + @test sinkhorn_divergence(μ, ν, Cμν, Cμ, Cν, ε; regularization=reg) ≥ 0 + @test sinkhorn_divergence(μ, μ, Cμ, Cμ, Cμ, ε; regularization=reg) ≈ 0 atol = + 1e-6 + end + end + + @testset "AD" begin + for Diff in [ForwardDiff, ReverseDiff] + for reg in (true, false) + ∇ = Diff.gradient(ν_spt) do xs + Cμν = pairwise(SqEuclidean(), μ_spt', xs'; dims=2) + Cν = pairwise(SqEuclidean(), xs'; dims=2) + sinkhorn_divergence(μ, ν, Cμν, Cμ, Cν, ε) + end + @test size(∇) == size(ν_spt) + ∇ = Diff.gradient(μ_spt) do xs + Cμν = pairwise(SqEuclidean(), μ_spt', xs'; dims=2) + Cν = pairwise(SqEuclidean(), xs'; dims=2) + sinkhorn_divergence(μ, μ, Cμν, Cμ, Cν, ε) + end + @test norm(∇, Inf) ≈ 0 atol = 1e-6 + end + end + end + end +end