Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ PDE training examples are provided in `example` folder.

[Time dependent Navier-Stokes equation](example/FlowOverCircle)

### Super Resolution with MNO

[Super resolution on time dependent Navier-Stokes equation](example/SuperResolution)

## Roadmap

- [x] `FourierOperator` layer
Expand Down
75 changes: 75 additions & 0 deletions docs/src/assets/notebook/super_resolution_mno.jl.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion example/FlowOverCircle/src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function gen_data(ts::AbstractRange)
𝐩s = Array{Float32}(undef, 1, n, m, length(ts))
for (i, t) in enumerate(ts)
sim_step!(circ, t)
𝐩s[:, :, :, i] = Float32.(circ.flow.p)[2:end-1, 2:end-1]
𝐩s[1, :, :, i] .= Float32.(circ.flow.p)[2:end-1, 2:end-1]
end

return 𝐩s
Expand Down
19 changes: 19 additions & 0 deletions example/SuperResolution/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name = "SuperResolution"
uuid = "a8258e1f-331c-4af2-83e9-878628278453"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
WaterLily = "ed894a53-35f9-47f1-b17f-85db9237eebd"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
18 changes: 18 additions & 0 deletions example/SuperResolution/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Super Resolution

The time dependent Navier-Stokes equation is learned by the `MarkovNeuralOperator` with only one time step information.
The result of this example can be found [here](https://foldfelis.github.io/NeuralOperators.jl/dev/assets/notebook/super_resolution_mno.jl.html).

Apart from just training a MNO, here, we train the model with lower resolution (96x64) and inference result with higher resolution (192x128).

| **Ground Truth** | **Inferenced** |
|:----------------:|:--------------:|
| ![](gallery/ans.gif) | ![](gallery/inferenced.gif) |

Change directory to `example/SuperResolution` and use following commend to train model:

```julia
$ julia --proj

julia> using SuperResolution; SuperResolution.train()
```
Binary file added example/SuperResolution/gallery/ans.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example/SuperResolution/gallery/inferenced.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
101 changes: 101 additions & 0 deletions example/SuperResolution/notebook/super_resolution_mno.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
### A Pluto.jl notebook ###
# v0.16.1

using Markdown
using InteractiveUtils

# ╔═╡ 194baef2-0417-11ec-05ab-4527ef614024
using Pkg; Pkg.develop(path=".."); Pkg.activate("..")

# ╔═╡ 38c9ced5-dcf8-4e03-ac07-7c435687861b
using SuperResolution, Plots

# ╔═╡ 50ce80a3-a1e8-4ba9-a032-dad315bcb432
md"
# Super Resolution with MNO

JingYu Ning
"

# ╔═╡ 59769504-ebd5-4c6f-981f-d03826d8e34a
md"
This demo trains a Markov neural operator (MNO) introduced by [Zongyi Li *et al.*](https://arxiv.org/abs/2106.06898) with only one time step information. Then composed the operator to a Markov chain and inference the Navier-Stokes equations."

# ╔═╡ 823b3547-6723-43cf-85e6-cc6eb44efea1
md"
## Generate data
"

# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
begin
n = 10
data = SuperResolution.gen_data(LinRange(100, 100+n-1, n))
end;

# ╔═╡ 5531bba6-94bd-4c99-be8c-43fe19ad8a60
md"
## Training
"

# ╔═╡ 74fc528f-ccd4-4670-9b17-dbfa7a1c74b6
md"
Apart from just training a MNO, here, we train the model with lower resolution (96x64) and inference result with higher resolution (192x128).
"

# ╔═╡ f6d1ce85-a195-4ab1-bd3a-dbd4b0d1fcca
begin
anim = @animate for i in 1:size(data)[end]
heatmap(data[1, 1:2:end, 1:2:end, i]', color=:coolwarm, clim=(-1.5, 1.5))
scatter!(
[size(data, 3)÷4-1], [size(data, 3)÷4-1],
markersize=45, color=:black, legend=false, ticks=false
)
annotate!(5, 5, text("i=$i", :left))
end
gif(anim, fps=2)
end

# ╔═╡ 55058635-c7e9-4ee3-81c2-0153e84f4c8e
md"
## Inference

Use the first data generated above as the initial state, and apply the operator recurrently.
"

# ╔═╡ fbc287b8-f232-4350-9948-2091908e5a30
begin
m = SuperResolution.get_model()

states = Array{Float32}(undef, size(data))
states[:, :, :, 1] .= view(data, :, :, :, 1)
for i in 2:size(data)[end]
states[:, :, :, i:i] .= m(view(states, :, :, :, i-1:i-1))
end
end

# ╔═╡ a0b5e94c-a839-4cc0-a325-1a4ac39fafbc
begin
anim_model = @animate for i in 1:size(states)[end]
heatmap(states[1, :, :, i]', color=:coolwarm, clim=(-1.5, 1.5))
scatter!(
[size(data, 3)÷2-1], [size(data, 3)÷2-1],
markersize=45, color=:black, legend=false, ticks=false
)
annotate!(5, 5, text("i=$i", :left))
end
gif(anim_model, fps=2)
end

# ╔═╡ Cell order:
# ╟─50ce80a3-a1e8-4ba9-a032-dad315bcb432
# ╟─59769504-ebd5-4c6f-981f-d03826d8e34a
# ╠═194baef2-0417-11ec-05ab-4527ef614024
# ╠═38c9ced5-dcf8-4e03-ac07-7c435687861b
# ╟─823b3547-6723-43cf-85e6-cc6eb44efea1
# ╠═5268feee-bda2-4612-9d4c-a1db424a11c7
# ╟─5531bba6-94bd-4c99-be8c-43fe19ad8a60
# ╟─74fc528f-ccd4-4670-9b17-dbfa7a1c74b6
# ╠═f6d1ce85-a195-4ab1-bd3a-dbd4b0d1fcca
# ╟─55058635-c7e9-4ee3-81c2-0153e84f4c8e
# ╠═fbc287b8-f232-4350-9948-2091908e5a30
# ╟─a0b5e94c-a839-4cc0-a325-1a4ac39fafbc
63 changes: 63 additions & 0 deletions example/SuperResolution/src/SuperResolution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
module SuperResolution

using NeuralOperators
using Flux
using CUDA
using JLD2

include("data.jl")

function update_model!(model_file_path, model)
model = cpu(model)
jldsave(model_file_path; model)
@warn "model updated!"
end

function train()
if has_cuda()
@info "CUDA is on"
device = gpu
CUDA.allowscalar(false)
else
device = cpu
end

m = Chain(
Dense(1, 64),
FourierOperator(64=>64, (24, 24), gelu),
FourierOperator(64=>64, (24, 24), gelu),
FourierOperator(64=>64, (24, 24), gelu),
FourierOperator(64=>64, (24, 24), gelu),
Dense(64, 1),
) |> device

loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]

opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))

@info "gen data... "
@time loader_train, loader_test = get_dataloader()

losses = Float32[]
function validate()
validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test)
@info "loss: $validation_loss"

push!(losses, validation_loss)
(losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
end
call_back = Flux.throttle(validate, 5, leading=false, trailing=true)

data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
end

function get_model()
f = jldopen(joinpath(@__DIR__, "../model/model.jld2"))
model = f["model"]
close(f)

return model
end

end
43 changes: 43 additions & 0 deletions example/SuperResolution/src/data.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using WaterLily
using LinearAlgebra: norm2

"""
circle(n, m; Re=250)

This function is copy from [WaterLily](https://github.com/weymouth/WaterLily.jl)
"""
function circle(n, m; Re=250)
# Set physical parameters
U, R, center = 1., m/8., [m/2, m/2]
ν = U * R / Re

body = AutoBody((x,t) -> norm2(x .- center) - R)
Simulation((n+2, m+2), [U, 0.], R; ν, body)
end

function gen_data(ts::AbstractRange)
n, m = 2 * 3(2^5), 2 * 2^6
circ = circle(n, m)

𝐩s = Array{Float32}(undef, 1, n, m, length(ts))
for (i, t) in enumerate(ts)
sim_step!(circ, t)
𝐩s[1, :, :, i] .= Float32.(circ.flow.p)[2:end-1, 2:end-1]
end

return 𝐩s
end

function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Float64=0.95, batchsize=100)
data = gen_data(ts)

n_train, n_test = floor(Int, length(ts)*ratio), floor(Int, length(ts)*(1-ratio))

𝐱_train, 𝐲_train = data[:, 1:2:end, 1:2:end, 1:(n_train-1)], data[:, 1:2:end, 1:2:end, 2:n_train]
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)

𝐱_test, 𝐲_test = data[:, :, :, (end-n_test+1):(end-1)], data[:, :, :, (end-n_test+2):end]
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)

return loader_train, loader_test
end
3 changes: 3 additions & 0 deletions example/SuperResolution/test/data.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
@testset "flow over circle" begin

end
6 changes: 6 additions & 0 deletions example/SuperResolution/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using SuperResolution
using Test

@testset "SuperResolution" begin
include("data.jl")
end