<a href="https://colab.research.google.com/github/RCortez25/Scientific-Machine-Learning/blob/main/Differential_equations/Schrodinger.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

# Code walkthrough

In [None]:
# A
using MethodOfLines, ModelingToolkit, DifferentialEquations, DomainSets, Plots, Printf

# B
@parameters x t
@variables u(..) v(..)



*   **A**: Import the necessary libraries.
*   **B**: Defining the independent variables and the state variables. Here `x` and `t`, declared with `@parameters`, are the independent variables. For the state variables, writing `(..)` indicates that `u` and `v` are functions whose arguments will be supplied later, as in `u(x,t)`. We split the wave function into its real and imaginary parts
    
$$\psi(x,t)=u(x,t)+iv(x,t)$$

so that the solver works with two real-valued fields. Note that both $u$ and $v$ are real-valued. Substituting this into the free Schrödinger equation (in units with $\hbar=m=1$), one has

$$i\partial_t\psi=-\frac{1}{2}\partial_{xx}\psi\implies i(\partial_t u+i\partial_t v)=-\frac{1}{2}(\partial_{xx}u+i\partial_{xx}v)$$

$$\therefore i\partial_t u-\partial_t v=-\frac{1}{2}\partial_{xx}u-\frac{i}{2}\partial_{xx}v$$

and equating the real and imaginary parts of both sides gives

$$
\begin{align}
\text{Real}&:  ∂_tv=\frac{1}{2}∂_{xx}u\\
\text{Imaginary}&: ∂_tu=-\frac{1}{2}∂_{xx}v
\end{align}
$$
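As a quick check of the signs in this split, the sketch below plugs a known exact solution into both real equations. It assumes $\hbar=m=1$ (consistent with the factor $\tfrac12$ above) and uses the fact that $\sin(2\pi x)$, the initial condition used later, is the $n=2$ eigenstate of the unit box with energy $E=(2\pi)^2/2$, so $u=\sin(2\pi x)\cos(Et)$ and $v=-\sin(2\pi x)\sin(Et)$. Derivatives are approximated by central differences; the helper names are hypothetical.

```julia
# Assumption: ħ = m = 1 and the n = 2 box eigenstate ψ = sin(2πx) e^{-iEt}
const E2 = (2π)^2 / 2                      # energy of sin(2πx) in the unit box

u_exact(x, t) =  sin(2π * x) * cos(E2 * t)  # real part of ψ
v_exact(x, t) = -sin(2π * x) * sin(E2 * t)  # imaginary part of ψ

# Hypothetical helpers: central differences for ∂_t and ∂_xx
d_t(f, x, t; h = 1e-5)  = (f(x, t + h) - f(x, t - h)) / (2h)
d_xx(f, x, t; h = 1e-4) = (f(x + h, t) - 2 * f(x, t) + f(x - h, t)) / h^2

# Residuals of ∂_t v = ½ ∂_xx u and ∂_t u = -½ ∂_xx v at a sample point
x0, t0 = 0.3, 0.7
res_real = d_t(v_exact, x0, t0) - 0.5 * d_xx(u_exact, x0, t0)
res_imag = d_t(u_exact, x0, t0) + 0.5 * d_xx(v_exact, x0, t0)
println("residuals: ", (res_real, res_imag))
```

Both residuals come out at the level of the finite-difference error, whereas flipping the sign of either equation would leave a residual many orders of magnitude larger.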

In [None]:
# A
Dt  = Differential(t)
Dxx = Differential(x)^2



*   **A**: Definition of the differential operators

$$
∂_t,\quad\partial_{xx}
$$

In [None]:
# A
eqs = [
    Dt(u(x,t)) ~ -(1/2) * Dxx(v(x,t)),
    Dt(v(x,t)) ~  +(1/2) * Dxx(u(x,t))
]



*   **A**: Definition of the equations of the system: the real and imaginary parts derived above, where `~` denotes an equation (symbolic equality) in ModelingToolkit.

$$
\begin{align}
\text{Real}&:  ∂_tv=\frac{1}{2}∂_{xx}u\\
\text{Imaginary}&: ∂_tu=-\frac{1}{2}∂_{xx}v
\end{align}
$$

Here we define these equations and write explicitly the dependence of $u$ and $v$ on $x$ and $t$.

In [None]:
# A
bcs = [
    u(0,t) ~ 0.0,  u(1,t) ~ 0.0,
    v(0,t) ~ 0.0,  v(1,t) ~ 0.0,
    u(x,0) ~ sin(2π*x),
    v(x,0) ~ 0.0
]

*   **A**: Definition of the boundary and initial conditions of the problem. The boundary conditions are *Dirichlet boundary conditions*, which fix the value of the field at the boundary; in this case

$$
\begin{align}
u(0,t) &= 0 \\
u(1,t) &= 0 \\
v(0,t) &= 0 \\
v(1,t) &= 0
\end{align}
$$

from which we observe that this is a particle in a box $x\in[0,1]$, that is, an infinite potential well, because the wave function vanishes at the walls.

Then, we have the initial conditions of the field

$$
\begin{align}
u(x,0) &= \sin(2\pi x) \\
v(x,0) &= 0
\end{align}
$$

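Since $v(x,0)=0$, the initial wave function is purely real. Note also that $\sin(2\pi x)$ vanishes at $x=0$ and $x=1$, so the initial condition is consistent with the boundary conditions.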
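The initial state is not normalized: $\int_0^1\sin^2(2\pi x)\,dx=1/2$, so $\int_0^1|\psi|^2\,dx$ should stay at $0.5$ (not 1) throughout the simulation, since Schrödinger evolution conserves the norm. The plain-Julia sketch below assumes the same 101-point grid as the discretization and uses the trapezoidal rule to check both the wall values and this integral; multiplying the initial condition by $\sqrt{2}$ would normalize it if desired.

```julia
# Same uniform grid as the discretization (assumption: 101 points on [0, 1])
nx = 101
xs = collect(range(0.0, 1.0; length = nx))
dx = xs[2] - xs[1]
u0 = sin.(2π .* xs)          # u(x, 0)
v0 = zeros(nx)               # v(x, 0)

# Trapezoidal rule for ∫₀¹ |ψ(x,0)|² dx = ∫₀¹ (u² + v²) dx
ρ0 = u0 .^ 2 .+ v0 .^ 2
norm0 = dx * (sum(ρ0) - (ρ0[1] + ρ0[end]) / 2)

println("u at the walls: ", (u0[1], u0[end]))   # 0.0 and a round-off-sized value
println("∫|ψ|² dx ≈ ", norm0)
```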

In [None]:
# A
domains = [
    x ∈ Interval(0.0, 1.0),
    t ∈ Interval(0.0, 10.0)
]

*   **A**: Declaring the domains of the independent variables:
    *   $x\in [0,1]$: the box
    *   $t\in[0,10]$: the time window of the simulation


In [None]:
# A
@named pde = PDESystem(eqs, bcs, domains, [x, t], [u(x,t), v(x,t)])

*   **A**: Assembling the system by including
    *   The equations
    *   The initial/boundary conditions
    *   The domains of the independent variables
    *   The independent variables
    *   The state variables, with their dependence explicitly written down

In [None]:
# A
nx = 101

# B
order = 2

# C
discretization = MOLFiniteDifference([x=>nx], t; approx_order=order)



*   **A**: Select the number of spatial grid points along $x\in[0,1]$. This gives a grid spacing of $\Delta x=1/(101-1)=0.01$.
*   **B**: Select the accuracy order of the spatial finite-difference stencil (here second order, so the truncation error scales as $\Delta x^2$).
*   **C**: Create the discretization with the **method of lines** (MOL): $x$ is discretized on the grid while $t$ is kept continuous. When applied, this turns the PDE into a semi-discrete system, i.e. a large set of coupled ODEs in time (one per grid node and per field), which standard ODE solvers can integrate. This step only builds the recipe for the discretization; it does not apply it yet.
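To make the idea concrete, the sketch below writes by hand the kind of semi-discrete right-hand side the method of lines produces for this system, using the second-order central difference $\partial_{xx}w\approx(w_{i-1}-2w_i+w_{i+1})/\Delta x^2$ at interior nodes and leaving the wall nodes to the Dirichlet conditions. It is an illustration in plain Julia (the names `fd_dxx` and `semi_rhs` are hypothetical), not the code MethodOfLines actually generates.

```julia
nx = 101
xs = collect(range(0.0, 1.0; length = nx))
dx = xs[2] - xs[1]                       # ≈ 0.01

# Second-order central difference for ∂_xx at interior nodes; wall nodes stay fixed by the BCs
function fd_dxx(w, dx)
    d = zeros(length(w))
    for i in 2:length(w)-1
        d[i] = (w[i-1] - 2 * w[i] + w[i+1]) / dx^2
    end
    return d
end

# du/dt = -½ ∂_xx v,  dv/dt = ½ ∂_xx u : one ODE per interior node and per field
semi_rhs(u, v, dx) = (-0.5 .* fd_dxx(v, dx), 0.5 .* fd_dxx(u, dx))

u0 = sin.(2π .* xs)
v0 = zeros(nx)
du, dv = semi_rhs(u0, v0, dx)

# Compare the discrete ∂_xx of sin(2πx) with the exact -(2π)² sin(2πx)
stencil_err = maximum(abs.(fd_dxx(u0, dx)[2:end-1] .+ (2π)^2 .* u0[2:end-1]))
println("max stencil error: ", stencil_err)
```

With both walls fixed, this sketch has 99 unknowns per field, so 198 coupled ODEs; the stencil error is $O(\Delta x^2)$, consistent with `approx_order=2`.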



In [None]:
semi = discretize(pde, discretization)

Here, the discretization instructions are applied to the PDE system. The result, called `semi` in this example, is an `ODEProblem` holding the coupled ODE system, ready to be solved.

In [None]:
# A
tspan = (0.0, 10.0)

# B
sol = solve(semi, RK4(); dt=1e-4, adaptive=false, saveat=0.02, progress=true)

**A**: Setting the time window $(0.0, 10.0)$ for the integrator. Note that, as written, `tspan` is not passed to `solve`: the problem returned by `discretize` already carries the time span taken from the domain of $t$ declared earlier, and that is what the integrator uses. Here both windows coincide; to integrate over a window different from the PDE domain, one could solve `remake(semi; tspan=tspan)` instead.

**B**: Integrating the semi-discrete ODE system with the classical fourth-order Runge-Kutta method `RK4()`. By default `RK4()` adapts its step size, so `adaptive=false` is needed for `dt=1e-4` to be a fixed time step; from $t=0$ to $t=10$ one then takes $10/10^{-4}=100{,}000$ RK4 steps. The `saveat=0.02` option records the solution every 0.02 time units for plotting. It doesn't force the integrator to step every 0.02 units; it only stores snapshots at those times. Smaller values yield smoother animations at the cost of more memory.
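For intuition about what fixed-step RK4 does with a method-of-lines system, here is a self-contained plain-Julia sketch (not the OrdinaryDiffEq implementation). It assumes the same 101-point grid, eliminates the wall nodes (where $\psi=0$), and writes the system in complex form $\partial_t\psi=\tfrac{i}{2}\partial_{xx}\psi$, which is equivalent to the $(u,v)$ pair with $u=\operatorname{Re}\psi$, $v=\operatorname{Im}\psi$. It integrates only to $t=0.1$ and compares with the exact solution $\sin(2\pi x)e^{-iEt}$, $E=(2\pi)^2/2$. The step size also passes a stability check: the discrete $\partial_{xx}$ has eigenvalues down to about $-4/\Delta x^2$, so $|\lambda|\,dt\lesssim\tfrac{1}{2}\cdot 4\cdot 10^4\cdot 10^{-4}=2$, inside RK4's stability interval on the imaginary axis ($|\lambda\,dt|\le 2\sqrt{2}\approx 2.83$).

```julia
using LinearAlgebra

nx  = 101
dx  = 1.0 / (nx - 1)
xs  = range(0.0, 1.0; length = nx)
xin = xs[2:end-1]                     # interior nodes; ψ = 0 at the walls
n   = length(xin)

# Second-order central-difference ∂_xx on the interior nodes (Dirichlet walls implied)
A = SymTridiagonal(fill(-2.0 / dx^2, n), fill(1.0 / dx^2, n - 1))

# Semi-discrete Schrödinger equation: dψ/dt = (i/2) A ψ
f(ψ) = 0.5im .* (A * ψ)

# Classical RK4 with a fixed step dt
function rk4_fixed(f, ψ, dt, nsteps)
    for _ in 1:nsteps
        k1 = f(ψ)
        k2 = f(ψ .+ (dt / 2) .* k1)
        k3 = f(ψ .+ (dt / 2) .* k2)
        k4 = f(ψ .+ dt .* k3)
        ψ = ψ .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return ψ
end

dt, tend = 1e-4, 0.1
ψ0 = complex.(sin.(2π .* xin))                    # u(x,0) = sin(2πx), v(x,0) = 0
ψT = rk4_fixed(f, ψ0, dt, round(Int, tend / dt))  # 1000 fixed steps

E = (2π)^2 / 2                                    # exact energy of the n = 2 state
u_err = maximum(abs.(real.(ψT) .- sin.(2π .* xin) .* cos(E * tend)))
v_err = maximum(abs.(imag.(ψT) .+ sin.(2π .* xin) .* sin(E * tend)))
println("max errors (u, v): ", (u_err, v_err))
```

The remaining error comes mostly from the spatial stencil, which slightly lowers the discrete energy; the time-stepping error at this `dt` is far smaller.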

In [None]:
# A
x_grid = sol[x]
t_grid = sol[t]

# B
u_solution = sol[u(x,t)] # Real part
v_solution = sol[v(x,t)] # Imaginary part

# C
psi2 = @.(u_solution^2 + v_solution^2)



*   **A**: Obtaining the $x$ and $t$ values from the solution: the spatial grid points and the times at which snapshots were saved (set by `saveat`). These will be used for plotting the results.
*   **B**: Querying the solution object to obtain the $u$ values, which correspond to the real part of the wave function, and the $v$ values for the imaginary part.
*   **C**: Computing the probability density $|\psi|^2=u^2+v^2$. Julia's broadcasting macro `@.` converts every operation in the expression into its element-wise (dotted) form, so the squares of the two arrays are summed entry by entry. This variable will be used for plotting as well.
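Because $\sin(2\pi x)$ is an energy eigenstate of the box, the exact solution only picks up a phase, $\psi=\sin(2\pi x)e^{-iEt}$ with $E=(2\pi)^2/2$, so $|\psi|^2$ should stay equal to $\sin^2(2\pi x)$ at all times; this gives a convenient sanity check for `psi2`. The sketch below builds the exact $u$ and $v$ on the same grid and snapshot times, stored with rows for $x$ and columns for $t$ like `u_solution`, and applies the same `@.` expression. The names `U`, `V` and `ψ2` are illustrative stand-ins, not the solver output.

```julia
xs = range(0.0, 1.0; length = 101)
ts = range(0.0, 10.0; length = 501)       # same snapshot times as saveat = 0.02
E  = (2π)^2 / 2

U =  sin.(2π .* xs) .* cos.(E .* ts')      # 101 × 501 real part (rows = x, columns = t)
V = -sin.(2π .* xs) .* sin.(E .* ts')      # 101 × 501 imaginary part

# Element-wise |ψ|² with the broadcasting macro
ψ2 = @. U^2 + V^2

# For a stationary state, every column should equal sin²(2πx)
drift = maximum(abs.(ψ2 .- sin.(2π .* xs) .^ 2))
println(size(ψ2), "  max drift: ", drift)
```

Applied to the numerical `psi2`, the same comparison measures how far the discretization lets the density drift.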


In [None]:
# A
p_u = heatmap(t_grid, x_grid, u_solution; xlabel="t", ylabel="x", title="u(x,t)", colorbar=true)
p_v = heatmap(t_grid, x_grid, v_solution; xlabel="t", ylabel="x", title="v(x,t)", colorbar=true)

# B
plot(p_u, p_v, layout=(1,2), size=(1000,400))



*   **A**: Create a static heatmap for each part of the solution, with $t$ on the horizontal axis and $x$ on the vertical axis.
*   **B**: Combine the two heatmaps into a single figure with `layout=(1,2)`, that is, 1 row and 2 columns, of size 1000×400 pixels.



## Plots

In [None]:
# A
anim_u = @animate for k in 1:length(t_grid)
    plot(x_grid, u_solution[:, k];
         xlabel = "x", ylabel = "u(x,t)",
         title = @sprintf("t = %.2f", t_grid[k]),
         legend = false, lw = 2,
         ylims = (minimum(u_solution), maximum(u_solution)))
end

# B
gif(anim_u, "u_evolution.gif", fps = 20)

# C
anim_v = @animate for k in 1:length(t_grid)
    plot(x_grid, v_solution[:, k];
         xlabel = "x", ylabel = "v(x,t)",
         title = @sprintf("t = %.2f", t_grid[k]),
         legend = false, lw = 2,
         ylims = (minimum(v_solution), maximum(v_solution)))
end

gif(anim_v, "v_evolution.gif", fps = 20)



*   **A**:
    * `@animate` records each iteration of the loop as a frame of the animation
    * `k in 1:length(t_grid)` iterates over all saved time snapshots (Julia indexing starts at 1)
    * `u_solution[:, k]`: as in NumPy, this selects all rows and the *k*-th column, which corresponds to the *k*-th snapshot, i.e. `u(x,t_k)`. Rows are *x* and columns are *t*.
    * `@sprintf("t = %.2f", t_grid[k])` builds the title, which is updated at every frame. The format specifier `%.2f` is replaced by the value of `t_grid[k]` printed with 2 decimals
    * `ylims = (minimum(u_solution), maximum(u_solution))` fixes the vertical scale for all frames, so the axis does not rescale from frame to frame and the animation is easier to follow
*   **B**: Saves all frames into a GIF file at `fps=20`, that is, 20 frames per second.
*   **C**: Repeats the same steps for the imaginary part $v$ and saves the result to `v_evolution.gif`.
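The indexing, title formatting and fixed axis limits used in the animation loop can be tried without Plots. The sketch below assumes the exact stationary solution as a stand-in for `u_solution` (101 grid points × 501 snapshots, rows for $x$ and columns for $t$), extracts one frame, formats its title with `@sprintf` from the Printf standard library, and computes the shared $y$-limits with Base's `extrema`, which returns `(minimum, maximum)` in a single pass.

```julia
using Printf

xs = range(0.0, 1.0; length = 101)
ts = range(0.0, 10.0; length = 501)
E  = (2π)^2 / 2
U  = sin.(2π .* xs) .* cos.(E .* ts')     # stand-in for u_solution: rows = x, columns = t

k           = 51
frame       = U[:, k]                       # u(x, t_k) at every grid point x
frame_title = @sprintf("t = %.2f", ts[k])   # value printed with two decimals
yrange      = extrema(U)                    # one (min, max) pair shared by every frame

println(length(frame), "  ", frame_title, "  ", yrange)
```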

