Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.
Merged
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/Manifest.toml
data
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
authors = ["JingYu Ning <foldfelis@gmail.com> and contributors"]
version = "0.1.0"

[deps]
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"

[compat]
julia = "1.6"

Expand Down
7 changes: 5 additions & 2 deletions src/NeuralOperators.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
module NeuralOperators
function __init__()
register_datasets()
end

# Write your package code here.

include("preprocess.jl")
include("fourier.jl")
end
81 changes: 81 additions & 0 deletions src/fourier.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
using Flux
using FFTW
using Tullio

export
SpectralConv1d,
FourierOperator,
FNO

c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...) * im

struct SpectralConv1d{T, S}
weight::T
in_channel::S
out_channel::S
modes::S
σ
end

function SpectralConv1d(
ch::Pair{<:Integer,<:Integer},
modes::Integer,
σ=identity;
init=c_glorot_uniform,
T::DataType=ComplexF32
)
in_chs, out_chs = ch
scale = one(T) / (in_chs * out_chs)
weights = scale * init(out_chs, in_chs, modes)

return SpectralConv1d(weights, in_chs, out_chs, modes, σ)
end

Flux.@functor SpectralConv1d

function (m::SpectralConv1d)(𝐱::AbstractArray)
𝐱_fft = fft(𝐱, 2) # [in_chs, x, batch]
𝐱_selected = 𝐱_fft[:, 1:m.modes, :] # [in_chs, modes, batch]

# [out_chs, modes, batch] <- [in_chs, modes, batch] [out_chs, in_chs, modes]
@tullio 𝐱_weighted[o, m, b] := 𝐱_selected[i, m, b] * m.weight[o, i, m]

s = size(𝐱_weighted)
d = size(𝐱, 2) - m.modes
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, s[1], d, s[3:end]...), dims=2)

𝐱_out = ifft(𝐱_padded, 2)

return m.σ.(𝐱_out)
end

function FourierOperator(
ch::Pair{<:Integer,<:Integer},
modes::Integer,
σ=identity
)
return Chain(
Parallel(+,
Dense(ch.first, ch.second, init=c_glorot_uniform),
SpectralConv1d(ch, modes)
),
x -> σ.(x)
)
end

function FNO()
modes = 16
ch = 64 => 64
σ = x -> @. log(1 + exp(x))

return Chain(
Dense(2, 64, init=c_glorot_uniform),
FourierOperator(ch, modes, σ),
FourierOperator(ch, modes, σ),
FourierOperator(ch, modes, σ),
FourierOperator(ch, modes),
Dense(64, 128, σ, init=c_glorot_uniform),
Dense(128, 1, init=c_glorot_uniform),
flatten
)
end
33 changes: 33 additions & 0 deletions src/preprocess.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using DataDeps
using Fetch
using MAT

export
get_data

function register_datasets()
register(DataDep(
"BurgersR10",
"""
Burgers' equation dataset from
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
""",
"https://drive.google.com/file/d/16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe/view?usp=sharing",
"9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd",
fetch_method=gdownload,
post_fetch_method=unpack
))
end

function get_data(; n=1000, Δsamples=2^3, grid_size=div(2^13, Δsamples))
file = matopen(joinpath(datadep"BurgersR10", "burgers_data_R10.mat"))
x_data = collect(read(file, "a")[1:n, 1:Δsamples:end]')
y_data = collect(read(file, "u")[1:n, 1:Δsamples:end]')
close(file)

x_loc_data = Array{Float32, 3}(undef, 2, grid_size, n)
x_loc_data[1, :, :] .= reshape(repeat(LinRange(0, 1, grid_size), n), (grid_size, n))
x_loc_data[2, :, :] .= x_data

return x_loc_data, y_data
end
48 changes: 48 additions & 0 deletions test/fourier.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using Flux

@testset "SpectralConv1d" begin
modes = 16
ch = 64 => 64

m = Chain(
Dense(2, 64, init=NeuralOperators.c_glorot_uniform),
SpectralConv1d(ch, modes)
)

𝐱, _ = get_data()
@test size(m(𝐱)) == (64, 1024, 1000)

T = Float32
loss(x, y) = Flux.mse(real.(m(x)), y)
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
Flux.train!(loss, params(m), data, Flux.ADAM())
end

@testset "FourierOperator" begin
modes = 16
ch = 64 => 64

m = Chain(
Dense(2, 64, init=NeuralOperators.c_glorot_uniform),
FourierOperator(ch, modes)
)

𝐱, _ = get_data()
@test size(m(𝐱)) == (64, 1024, 1000)

T = Float32
loss(x, y) = Flux.mse(real.(m(x)), y)
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
Flux.train!(loss, params(m), data, Flux.ADAM())
end

@testset "FNO" begin
𝐱, 𝐲 = get_data()
𝐱, 𝐲 = Float32.(𝐱), Float32.(𝐲)
@test size(FNO()(𝐱)) == size(𝐲)

m = FNO()
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
data = [(𝐱[:, :, 1:5], 𝐲[:, 1:5])]
Flux.train!(loss, params(m), data, Flux.ADAM())
end
6 changes: 6 additions & 0 deletions test/preprocess.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
@testset "get data" begin
xs, ys = get_data()

@test size(xs) == (2, 1024, 1000)
@test size(ys) == (1024, 1000)
end
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using NeuralOperators
using Test

ENV["DATADEPS_ALWAYS_ACCEPT"] = true

@testset "NeuralOperators.jl" begin
# Write your tests here.
include("preprocess.jl")
include("fourier.jl")
end