Skip to content

Commit

Permalink
use ArrayInterface.restructure in update!
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jun 10, 2021
1 parent 9f1966f commit b7450be
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 68 deletions.
150 changes: 95 additions & 55 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@ version = "0.3.4"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db"
git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.3.0"
version = "3.3.1"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "045ff5e1bc8c6fb1ecb28694abba0a0d55b5f4f5"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.17"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

Expand All @@ -38,22 +44,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"

[[CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "TimerOutputs"]
git-tree-sha1 = "a6ce96dcf22fc4f1bfdfac02d54f0b77ecf2a4cc"
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
git-tree-sha1 = "364179416eabc34c9ca32126a6bdb431680c3bad"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "3.0.3"
version = "3.2.1"

[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"]
git-tree-sha1 = "1f410fba5c04d03ab712f348f1542e6059376547"
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "720fa9a9ce61ff18842a40f501d6a1f8ba771c64"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.61"
version = "0.8.6"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "42e3c181483fbd2c416087a0a93838803e358358"
git-tree-sha1 = "8b31cc69cbc38c5c826aaa1c890c694be3622d99"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.38"
version = "0.10.3"

[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
Expand All @@ -63,15 +69,15 @@ version = "0.7.0"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "32a2b8af383f11cbb65803883837a149d10dfe8a"
git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.10.12"
version = "0.11.0"

[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "Reexport"]
git-tree-sha1 = "82f4e6ff9f847eca3e5ebc666ea2cd7b48e8b47e"
git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.12.7"
version = "0.12.8"

[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
Expand All @@ -81,9 +87,9 @@ version = "0.3.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956"
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.27.0"
version = "3.30.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
Expand Down Expand Up @@ -124,6 +130,12 @@ version = "1.0.2"
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.5"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Expand Down Expand Up @@ -158,23 +170,28 @@ uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.2.1"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "9c95b2fd5c16bc7f97371e9f92f0fef77e0f5957"
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
git-tree-sha1 = "df5b8569904c5c10e84c640984cfff054b18c086"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "6.2.2"
version = "6.4.1"

[[GPUCompiler]]
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "6eadd2321dc3ac0fc9d530ab01c2caa7fe5d74c6"
git-tree-sha1 = "42d635f6d87af125b86288df3819f805fb4d851a"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.11.4"
version = "0.11.5"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.4.2"

[[IfElse]]
git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
version = "0.1.0"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -193,9 +210,9 @@ version = "0.8.4"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "3.6.0"
version = "3.7.1"

[[LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
Expand Down Expand Up @@ -224,6 +241,12 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[LogExpFunctions]]
deps = ["DocStringExtensions", "LinearAlgebra"]
git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.2.4"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

Expand Down Expand Up @@ -255,9 +278,9 @@ version = "0.4.4"

[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c"
git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.5"
version = "1.0.0"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
Expand All @@ -267,15 +290,15 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NNlib]]
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "80b8360670f445d88b3475e88b33bbcc92f7866e"
git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.19"
version = "0.7.21"

[[NNlibCUDA]]
deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
git-tree-sha1 = "4b368b466bcdd25d448a5b20de4b7e481d68b88e"
git-tree-sha1 = "86d4d75e1091fe89d56422044be0dcc83766d2b4"
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
version = "0.1.0"
version = "0.1.2"

[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
Expand All @@ -287,24 +310,24 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+4"
version = "0.5.5+0"

[[OrderedCollections]]
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.0"
version = "1.4.1"

[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Preferences]]
deps = ["TOML"]
git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902"
git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.2.1"
version = "1.2.2"

[[Printf]]
deps = ["Unicode"]
Expand All @@ -322,16 +345,22 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Random123]]
deps = ["Libdl", "Random", "RandomNumbers"]
git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5"
uuid = "74087812-796a-5b5d-8853-05524746bad3"
version = "1.3.1"

[[RandomNumbers]]
deps = ["Random", "Requires"]
git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f"
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
version = "1.4.0"

[[Reexport]]
git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5"
git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.0.0"
version = "1.1.0"

[[Requires]]
deps = ["UUIDs"]
Expand All @@ -344,9 +373,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Scratch]]
deps = ["Dates"]
git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
git-tree-sha1 = "0b4b7f1393cff97c33891da2a0bf69c6ed241fda"
uuid = "6c6a2e73-6563-6170-7368-637461726353"
version = "1.0.3"
version = "1.1.0"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand All @@ -359,36 +388,47 @@ uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
deps = ["DataStructures"]
git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"
version = "1.0.0"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
git-tree-sha1 = "a50550fa3164a8c46747e62063b4d774ac1bcf49"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.3.0"
version = "1.5.1"

[[Static]]
deps = ["IfElse"]
git-tree-sha1 = "2740ea27b66a41f9d213561a04573da5d3823d4b"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
version = "0.2.5"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "e8cd1b100d37f5b4cfd2c83f45becf61c762eaf7"
git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.1.1"
version = "1.2.2"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsAPI]]
git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.0.0"

[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "4bc58880426274277a066de306ef19ecc22a6863"
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.5"
version = "0.33.8"

[[TOML]]
deps = ["Dates"]
Expand All @@ -403,10 +443,10 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Printf"]
git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.8"
version = "0.5.9"

[[TranscodingStreams]]
deps = ["Random", "Test"]
Expand All @@ -433,9 +473,9 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "927209c83efa62256788a9880c191774c07c5b51"
git-tree-sha1 = "b1d95edd4e693066c38c13a10aab0a8f6a6e2f65"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.10"
version = "0.6.12"

[[ZygoteRules]]
deps = ["MacroTools"]
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.12.4"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Expand All @@ -29,6 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
AbstractTrees = "0.3"
Adapt = "3.0"
ArrayInterface = "3.1"
CUDA = "3"
CodecZlib = "0.7"
Colors = "0.12"
Expand All @@ -45,9 +47,10 @@ julia = "1.6"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays"]
1 change: 0 additions & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using Zygote, MacroTools, Juno, Reexport
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback, @nograd

export gradient

export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
Expand Down
1 change: 1 addition & 0 deletions src/optimise/Optimise.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Optimise

using LinearAlgebra
import ArrayInterface

export train!, update!,
Descent, ADAM, Momentum, Nesterov, RMSProp,
Expand Down
Loading

0 comments on commit b7450be

Please sign in to comment.