## Example of fit with two rotations

In [1]:
using Pkg; Pkg.activate("../.")
using Revise 

using LinearAlgebra, Statistics, Distributions 
using OrdinaryDiffEq
using SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL

using SphereFit

[32m[1m  Activating[22m[39m project at `~/.julia/dev/SphereFit`


In [2]:
using Random

# Seed the default random number generator so every later random draw in this
# notebook is reproducible.
rng = Random.default_rng()
Random.seed!(rng, 666)

# Concentration κ of the Fisher noise on the observations (smaller κ ⇒ more spread)
κ = 200

200

Let's create a simple example made of two successive solid-body rotations of the sphere, which switch at a change point, with von Mises–Fisher noise added on top of the sampled directions.

In [3]:
# Total simulation time span
tspan = [0, 130.0]
# Number of sampled observation times
N_samples = 50
# Observation times: drawn uniformly on `tspan` from the seeded `rng`, then sorted.
# Passing `rng` explicitly keeps these draws tied to the seed set above.
times_samples = sort(rand(rng, sampler(Uniform(tspan[1], tspan[2])), N_samples))

# Expected maximum angular deviation in one unit of time (degrees)
Δω₀ = 1.0
# Angular speed in radians per unit of time
ω₀ = Δω₀ * π / 180.0
# Change point: time at which the rotation switches from L0 to L1
τ₀ = 65.0
# Angular velocity vectors (the field is du/dt = L × u) before and after the change point.
# Both have magnitude ω₀: L0 rotates about the x axis, L1 about the axis (0, 1, 1)/√2.
L0 = ω₀    .* [1.0, 0.0, 0.0]
L1 = 0.5ω₀ .* [0.0, sqrt(2), sqrt(2)]

# Solver tolerances
reltol = 1e-7
abstol = 1e-7

1.0e-7

In [4]:
"""
    true_rotation!(du, u, p, t; τ = τ₀)

In-place right-hand side of the piecewise solid-body rotation used to simulate the data:
`du = L × u`, with `L = p[1]` for `t < τ` and `L = p[2]` otherwise.

# Arguments
- `du`: length-3 output vector, overwritten with the time derivative.
- `u`: current position (length-3 vector).
- `p`: two-element collection `(L_before, L_after)` of length-3 angular velocity vectors.
- `t`: current time.

# Keywords
- `τ`: change point; defaults to the global `τ₀`. ODE solvers call `f(du, u, p, t)`
  without keywords, so integrations always switch at `τ₀`.

The field jumps at `t = τ`; pass `tstops = [τ]` to `solve` so the integrator steps
exactly onto the switch.
"""
function true_rotation!(du, u, p, t; τ = τ₀)
    L = t < τ ? p[1] : p[2]
    # Cross product written out component-wise (same formula as `LinearAlgebra.cross`)
    # so the in-place RHS does not allocate a temporary vector on every evaluation.
    du[1] = L[2] * u[3] - L[3] * u[2]
    du[2] = L[3] * u[1] - L[1] * u[3]
    du[3] = L[1] * u[2] - L[2] * u[1]
    return nothing
end

# Integrate the piecewise rotation starting from u = (0, 0, -1). The field switches from
# L0 to L1 at τ₀, so `tstops = [τ₀]` makes the adaptive solver land exactly on the switch
# instead of stepping across the discontinuity.
prob = ODEProblem(true_rotation!, [0.0, 0.0, -1.0], tspan, [L0, L1])
true_sol = solve(prob, Tsit5(); reltol=reltol, abstol=abstol,
                 saveat=times_samples, tstops=[τ₀])

# Add Fisher noise to the true solution: each column is replaced by a single draw from a
# von Mises–Fisher distribution centred on that (normalised) direction with concentration
# κ. Drawing from the explicit `rng` keeps the noise tied to the seeded generator.
X_noiseless = Array(true_sol)
X_true = mapslices(x -> rand(rng, sampler(VonMisesFisher(x / norm(x), κ)), 1),
                   X_noiseless; dims=1)

3×50 Matrix{Float64}:
 -0.0385817  -0.102203    0.0130063  …  -0.801308   -0.804544   -0.777151
  0.0229773   0.0624114   0.134454       0.593017    0.591397    0.625861
 -0.998991   -0.992804   -0.990834      -0.0789676  -0.0543948   0.065833

We will plot these observations using `PyCall` to call `cartopy` and `matplotlib` (see the *Python plots* section below).

In [5]:
# Display the noisy observations: a 3 × N_samples matrix with one direction per column.
X_true

3×50 Matrix{Float64}:
 -0.0385817  -0.102203    0.0130063  …  -0.801308   -0.804544   -0.777151
  0.0229773   0.0624114   0.134454       0.593017    0.591397    0.625861
 -0.998991   -0.992804   -0.990834      -0.0789676  -0.0543948   0.065833

In [6]:
# Convert the noisy Cartesian observations to spherical coordinates in degrees.
# NOTE(review): result is 2 × N; row 1 is plotted as x (longitude) and row 2 as y
# (latitude) under a PlateCarree transform — confirm against SphereFit's `cart2sph`.
X_true_sph = cart2sph(X_true; radians = false)

2×50 Matrix{Float64}:
 149.224   148.589    84.4747  104.938   …  143.496    143.681    141.155
 -87.4262  -83.1222  -82.2367  -78.8478      -4.52923   -3.11813    3.77468

## Training

In [32]:
# Observations to fit: sampling times and noisy directions (per-observation `kappas`
# are left as `nothing`).
data = SphereData(; times = times_samples, directions = X_true, kappas = nothing)

# Fitting settings: same time window, initial direction and solver tolerances as the
# simulation above. `ωmax` is set to twice the true angular speed ω₀
# (presumably an upper bound on the learned angular speed — confirm in SphereFit).
params = SphereParameters(;
    tmin = tspan[1],
    tmax = tspan[2],
    u0 = [0.0, 0.0, -1.0],
    ωmax = 2 * ω₀,
    reltol = reltol,
    abstol = abstol,
)

θ_trained, U, st = train_sphere(data, params, rng, nothing)

Current loss after 50 iterations: 0.16101314731406519


Current loss after 100 iterations: 0.11037340614009217


Current loss after 150 iterations: 0.0672007616613746


Current loss after 200 iterations: 0.03152708372080601


Current loss after 250 iterations: 0.010739288203707937


Current loss after 300 iterations: 0.005644418949798711


Current loss after 350 iterations: 0.005083565929698203


Current loss after 400 iterations: 0.005015284713162176


Current loss after 450 iterations: 0.00497058269872371


Current loss after 500 iterations: 0.004906234394502047


Current loss after 550 iterations: 0.004776632831478111


Current loss after 600 iterations: 0.0042792889392969


Current loss after 650 iterations: 0.004141349801464287


Current loss after 700 iterations: 0.004110626173838344


Current loss after 750 iterations: 0.004087219223675899


Current loss after 800 iterations: 0.004066461747562576


Current loss after 850 iterations: 0.0040480743234733284


Current loss after 900 iterations: 0.004031816367129134


Current loss after 950 iterations: 0.004017436125286587


Current loss after 1000 iterations: 0.004004684244888116
Training loss after 1001 iterations: 0.004004684244888116


Current loss after 1050 iterations: 0.0033236596189285426


Current loss after 1100 iterations: 0.003111357177414428


Current loss after 1150 iterations: 0.002991621613023731


Current loss after 1200 iterations: 0.002975056694828765


Current loss after 1250 iterations: 0.002966203059360758


Current loss after 1300 iterations: 0.0029608994616977255
Final training loss after 1302 iterations: 0.0029608677531788394


((layer_1 = (weight = [-0.029283678130141755; -0.7846595845532214; … ; 0.4824681548469449; -0.7653773405771641;;], bias = [1.3877211281620812; -0.2292506085891694; … ; -0.21354364462799935; -0.2111699908354089;;]), layer_2 = (weight = [0.4331927448885181 0.19025281760115167 … -0.23214131965235882 -0.48444998249163995; -0.005934016578657395 0.21407031256851175 … -0.16415189329723437 0.40785256119032065; … ; 0.5906190183540482 0.13771128071783653 … 0.1933073587202226 0.2518645154746848; 0.5575251432951867 -0.3914652962047988 … 0.26225009777192854 -0.07533374208994519], bias = [-0.19391460134147748; -0.12147743910258456; … ; -0.03269148361741536; 0.04856695619410239;;]), layer_3 = (weight = [-0.2734986908908024 -0.16486663919902309 … -0.6780334755960238 -0.23977689898977597; -0.08437333020280792 -0.3810663954777153 … -0.32269412553855653 0.32425612578467106; … ; 0.5792528765442908 0.3311609600385446 … -0.3101348642938916 -0.1359125208377849; 0.5516218199477535 0.4562211234554291 … -0.1138

### Python plots

In [31]:
using PyPlot, PyCall

mpl_colors = pyimport("matplotlib.colors")
mpl_colormap = pyimport("matplotlib.cm")
sns = pyimport("seaborn")
ccrs = pyimport("cartopy.crs")
feature = pyimport("cartopy.feature")

# Orthographic globe view centred at 20°S, 150°E.
plt.figure(figsize=(10, 10))
ax = plt.axes(projection=ccrs.Orthographic(central_latitude=-20, central_longitude=150))

# Uncomment to draw coastlines on the globe.
# ax.coastlines()
ax.gridlines()
ax.set_global()

# Observations coloured by sampling time with seaborn's "viridis" palette.
# Row 1 of `X_true_sph` is passed as x (longitude) and row 2 as y (latitude), in degrees,
# which is the convention of the PlateCarree transform. (No separate colormap object is
# needed: `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9.)
sns.scatterplot(ax=ax, x=X_true_sph[1, :], y=X_true_sph[2, :],
                hue=times_samples, s=50,
                palette="viridis",
                transform=ccrs.PlateCarree());

plt.savefig("testing.pdf", format="pdf")
# plt.show()

In [38]:
# Inspect the type of the model state `st` returned by `train_sphere`.
typeof(st)

NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), NTuple{4, NamedTuple{(), Tuple{}}}}