In [None]:
using Distributions
using Plots

# Abstract interfaces for the reflectometry building blocks below.
# A piece of a sample `Structure` (e.g. `Slab`); must provide `slabs` and `parameters`.
abstract type Component end
# Something that reports a complex scattering length density via `sld` (e.g. `SLD`).
abstract type Scatterer end
# A model that predicts y values at given x values via `generative` (e.g. `ReflectModel`).
abstract type Model end

"""
    flattenall(A) -> Vector{Any}

Recursively flatten every nested `AbstractArray` inside `A` into one `Vector{Any}`,
in depth-first order. Elements that are not arrays (numbers, strings, structs, ...)
are copied through unchanged.
"""
function flattenall(A)
    flat = Any[]
    for item in A
        # Recurse into sub-arrays; anything else is a leaf.
        item isa AbstractArray ? append!(flat, flattenall(item)) : push!(flat, item)
    end
    return flat
end

In [None]:
# Data loading: 1-D datasets read from delimited text files
using DelimitedFiles


"""
    Data1D(x, y[, yerr[, xerr]])
    Data1D(x, y; yerr=nothing, xerr=nothing)
    Data1D(filename)

One-dimensional dataset: `y` against `x`, with optional uncertainties `yerr` and
`xerr` (`nothing` when absent) and the file it was read from (`fname`, `nothing`
until set, e.g. by `read_data`).
"""
mutable struct Data1D
    x::Vector
    y::Vector
    yerr::Union{Vector, Nothing}
    xerr::Union{Vector, Nothing}
    fname::Union{String, Nothing}

    # Positional form. `yerr`/`xerr` may be `nothing`, which is what `read_data`
    # passes for missing columns. All five fields are initialised so `fname` is
    # never left undefined.
    function Data1D(x::Vector, y::Vector,
            yerr::Union{Vector, Nothing},
            xerr::Union{Vector, Nothing}=nothing)
        return new(x, y, yerr, xerr, nothing)
    end

    # Keyword form. Keyword arguments do not take part in dispatch, so this single
    # method also serves the plain two-argument call `Data1D(x, y)`. Several
    # (x, y[; ...]) methods would silently overwrite each other.
    function Data1D(x::Vector, y::Vector;
            yerr::Union{Vector, Nothing}=nothing,
            xerr::Union{Vector, Nothing}=nothing)
        return new(x, y, yerr, xerr, nothing)
    end

    Data1D(filename::String) = read_data(filename)
end


"""
    read_data(filename; delim=nothing) -> Data1D

Read a delimited text file into a `Data1D`, recording `filename` in `fname`.
Columns are interpreted in order as x, y, and optionally yerr and xerr.
Without `delim`, `readdlm`'s default (whitespace) splitting is used.

Throws `ArgumentError` if the file does not have 2, 3 or 4 columns.
"""
function read_data(filename::String; delim=nothing)
    arr = isnothing(delim) ? readdlm(filename) : readdlm(filename, delim)
    ncols = size(arr, 2)

    if ncols == 2
        data = Data1D(arr[:, 1], arr[:, 2])
    elseif ncols == 3
        data = Data1D(arr[:, 1], arr[:, 2], arr[:, 3])
    elseif ncols == 4
        data = Data1D(arr[:, 1], arr[:, 2], arr[:, 3], arr[:, 4])
    else
        # Previously fell through and failed with an UndefVarError on `data`.
        throw(ArgumentError(
            "expected 2, 3 or 4 columns in $(repr(filename)), found $ncols"))
    end
    data.fname = filename
    return data
end


"""
    refresh!(data::Data1D) -> nothing

Re-read `data.fname` and replace `x`, `y`, `yerr` and `xerr` in place (the
default delimiter of `read_data` is used).

Throws `ArgumentError` if `data` has no associated file name.
"""
function refresh!(data::Data1D)
    # `isdefined` also covers objects built by constructors that leave `fname` unset.
    if !isdefined(data, :fname) || isnothing(data.fname)
        throw(ArgumentError("cannot refresh a Data1D that has no associated file name"))
    end
    fresh = read_data(data.fname)
    data.x = fresh.x
    data.y = fresh.y
    data.yerr = fresh.yerr
    data.xerr = fresh.xerr
    return nothing
end


# A dataset's size and length are those of its x values.
Base.size(data::Data1D) = size(data.x)
Base.length(data::Data1D) = length(data.x)

# Plot (or overlay onto the current plot) y against x with a logarithmic y axis.
function Plots.plot(data::Data1D)
    return plot(data.x, data.y; yaxis=:log)
end
function Plots.plot!(data::Data1D)
    return plot!(data.x, data.y; yaxis=:log)
end

In [None]:
# Model parameters: a value plus a vary flag, bounds distribution and name

"""
    Parameter(value; vary=false, bounds=Uniform(-Inf, Inf), name="")

A model value with a `vary` flag (presumably marking it as free in a fit), a
`bounds` distribution (default: unbounded `Uniform`) and an optional `name`.
"""
mutable struct Parameter
    value::Real
    vary::Bool
    bounds::Distribution
    name::Union{String, Symbol}

    Parameter(value::Real; vary::Bool=false, bounds::Distribution=Uniform(-Inf, Inf),
        name::Union{String, Symbol}="") = new(value, vary, bounds, name)
end

# Normalise a number-or-Parameter argument to a `Parameter`.
# An existing `Parameter` is returned unchanged (`name`/`vary` are ignored for it);
# a plain number is wrapped with the given `name` and `vary`.
build_parameter(p::Parameter; name::Union{String, Symbol}="", vary=false) = p
function build_parameter(p::Real; name::Union{String, Symbol}="", vary=false)
    return Parameter(p; name=name, vary=vary)
end

# `Real(p)` unwraps a Parameter to its current value.
# NOTE: this may need revisiting once automatic differentiation is introduced.
function Base.Real(p::Parameter)
    return p.value
end

In [None]:
# Small positive offset that `abeles` adds to |Im(SLD)| (presumably so the complex
# square root / division never sees an exactly zero imaginary part).
const TINY = 1e-30
# FWHM-to-σ ratio of a Gaussian, 2√(2 ln 2).
# NOTE(review): not referenced by the code shown here; presumably meant for resolution smearing.
const _FWHM = 2 * sqrt(2 * log(2.0))
# 4π × 1e-6, the scale factor applied to SLD values (`abeles` writes this inline
# as `* pi * 4.0e-6` instead of using this constant).
const PI4 = 4e-6 * pi
# NOTE(review): not referenced by the code shown here; presumably an integration
# limit (in units of σ) for resolution smearing — confirm.
const _INTLIMIT = 3.5

"""
    abeles(q, w) -> Vector

Specular reflectivity |r|² of a layered system, computed with the Abelès
characteristic-matrix method at every element of `q`.

`w` has one row per medium and columns `[thickness, Re(SLD), Im(SLD), roughness]`.
Row 1 is the fronting medium and the last row the backing medium; their thicknesses
are not used. SLDs are taken relative to row 1 and multiplied by 4π·1e-6. Roughness
enters through a Névot–Croce factor `exp(-2 kₙ kₙ₊₁ σ²)` on each interface coefficient.

Throws `ArgumentError` if `w` has fewer than 2 rows.
"""
function abeles(q, w)
    size(w, 1) >= 2 || throw(ArgumentError(
        "w needs at least 2 rows (fronting and backing media), got $(size(w, 1))"))
    # Typed comprehension instead of a Vector{Any} accumulator keeps this type-stable.
    amplitudes = [_abeles_amplitude(q[j], w) for j in eachindex(q)]
    # abs2(r) == real(r * conj(r)), without the complex temporary.
    return abs2.(amplitudes)
end

# Complex reflection amplitude for a single scattering vector `qj`.
function _abeles_amplitude(qj, w)
    nlayers = size(w, 1) - 2
    oneC = Complex(1.0)
    qq2 = (qj * qj / 4.0) + 0.0im
    kn = (qj / 2.0) + 0.0im

    # Characteristic matrix of the first interface (fronting medium -> row 2).
    kn_next, rj = _abeles_interface(qq2, kn, w, 1)
    M11, M12, M21, M22 = oneC, rj, rj, oneC
    kn = kn_next

    for i in 2:nlayers+1
        kn_next, rj = _abeles_interface(qq2, kn, w, i)

        # Phase factor across layer i, whose thickness is w[i, 1].
        beta = exp(kn * (abs(w[i, 1]) * 1im))

        # Characteristic matrix of layer i followed by interface i -> i+1.
        MI11 = beta
        MI12 = rj * beta
        MI22 = oneC / beta
        MI21 = rj * MI22

        # Accumulate the optical matrix: M = M * MI.
        M11, M12, M21, M22 = (
            M11 * MI11 + M12 * MI21,
            M11 * MI12 + M12 * MI22,
            M21 * MI11 + M22 * MI21,
            M21 * MI12 + M22 * MI22,
        )
        kn = kn_next
    end
    return M21 / M11
end

# Wavevector in medium i+1 and roughness-damped reflection coefficient of the
# interface between rows i and i+1 of `w`.
function _abeles_interface(qq2, kn, w, i)
    sld_next = ((w[i+1, 2] - w[1, 2]) + (abs(w[i+1, 3]) + TINY) * im) * pi * 4.0e-6
    kn_next = sqrt(qq2 - sld_next)
    rj = (kn - kn_next) / (kn + kn_next) * exp(kn * kn_next * (-2.0 * w[i+1, 4]^2))
    return kn_next, rj
end

In [None]:
# Sample structure (an ordered stack of components), SLDs and slabs

# Ordered stack of components (e.g. `Slab`s) that behaves as an
# `AbstractVector{Component}`. Its first entry becomes row 1 of `slabs(structure)`,
# which `abeles` treats as the fronting medium.
struct Structure <: AbstractVector{Component}
    components::Vector{Component}
end

# An empty structure.
Structure() = Structure(Component[])
# Stack each component's `slabs` row into a matrix with one row per component,
# as expected by `abeles`.
function slabs(structure::Structure)
    rows = map(slabs, structure.components)
    return hcat(rows...)'
end

# Parameters of every component, one (possibly nested) entry per component.
parameters(structure::Structure) = map(parameters, structure.components)

# Array interface: a Structure forwards to its `components` vector.
Base.length(s::Structure) = length(s.components)
Base.size(s::Structure) = size(s.components)
Base.getindex(s::Structure, i::Int) = s.components[i]
Base.IndexStyle(::Type{<:Structure}) = IndexLinear()

# Add each of `items` to the end of the stack and return `s`.
# Returning the collection follows Base's `push!` convention (the previous version
# returned `nothing`).
function Base.push!(s::Structure, items...)
    for item in items
        push!(s.components, item)
    end
    return s
end

# Add every component of `arr` to the end of the stack and return `s`.
function Base.append!(s::Structure, arr::Array{N, 1} where N<:Component)
    append!(s.components, arr)
    return s
end

# Replace component `i`. Returns `s` (not the internal vector), like Base's `setindex!`.
function Base.setindex!(s::Structure, c::Component, i::Int)
    s.components[i] = c
    return s
end


# `a | b | c ...` builds a Structure from components, in order.
# These methods extend `Base.:|` explicitly. An unqualified `(|)(...) = ...` would
# create a separate `Main.|` that shadows Base's bitwise/boolean `|`, or fail to
# define if `|` had already been used in the module.
Base.:|(a::Component, b::Component) = Structure([a, b])

# Append `b` to `s` IN PLACE and return `s` (so chains reuse one Structure).
function Base.:|(s::Structure, b::Component)
    push!(s.components, b)
    return s
end

# Append all of `arr` to `s` IN PLACE and return `s`.
function Base.:|(s::Structure, arr::Array{N, 1} where N<:Component)
    append!(s.components, arr)
    return s
end


# Scattering length density held as separate real and imaginary Parameters.
# Accepts a real part only (imaginary part 0.0), real and imaginary parts
# (numbers or Parameters), or a single Complex number.
mutable struct SLD <: Scatterer
    re::Parameter
    im::Parameter

    function SLD(re::Union{Real, Parameter})
        return new(build_parameter(re), Parameter(0.0))
    end

    function SLD(re::Union{Real, Parameter}, im::Union{Real, Parameter})
        return new(build_parameter(re), build_parameter(im))
    end

    function SLD(value::Complex)
        return new(Parameter(real(value)), Parameter(imag(value)))
    end
end
# Complex SLD from the current values of the real and imaginary parameters.
sld(s::SLD) = s.re.value + 1.0im * s.im.value
Base.complex(s::SLD) = sld(s)

# Both SLD parameters, real part first.
parameters(s::SLD) = Parameter[s.re, s.im]


"""
    Slab(thickness, scatterer, roughness, vfsolv)

A single layer of a `Structure`. Numeric arguments are wrapped in `Parameter`s.
`scatterer` may be any `Scatterer`, a real or complex SLD value, or a
`(real, imag)` tuple.
"""
mutable struct Slab <: Component
    thickness::Parameter
    scatterer::Scatterer
    roughness::Parameter
    vfsolv::Parameter
    function Slab(thickness::Union{Parameter, Real},
            scatterer::Union{Scatterer, Real, Complex, Tuple{Real, Real}},
            roughness::Union{Parameter, Real},
            vfsolv::Union{Parameter, Real})
        return new(
            build_parameter(thickness),
            _as_scatterer(scatterer),
            build_parameter(roughness),
            build_parameter(vfsolv),
        )
    end
end

# Normalise the accepted scatterer specifications to a `Scatterer`.
# The tuple form previously reached `SLD(::Tuple)`, which has no method.
_as_scatterer(s::Scatterer) = s
_as_scatterer(s::Union{Real, Complex}) = SLD(s)
_as_scatterer(s::Tuple{Real, Real}) = SLD(s[1], s[2])

# Row `[thickness, Re(sld), Im(sld), roughness]` describing this slab for `abeles`.
# `vfsolv` is not part of the row.
function slabs(slab::Slab)
    z = sld(slab.scatterer)
    thickness = slab.thickness.value
    roughness = slab.roughness.value
    return [thickness, z.re, z.im, roughness]
end

# The slab's parameters; the scatterer's parameters appear as one nested entry.
function parameters(slab::Slab)
    return [slab.thickness, parameters(slab.scatterer), slab.roughness, slab.vfsolv]
end


In [None]:
# Reflectivity model for a `Structure` with an overall `scale`, a constant
# `background` and a `dq` parameter (default 5.0; not used by `generative`).
# Numeric keyword values are wrapped in Parameters named after the keyword.
mutable struct ReflectModel <: Model
    structure::Structure
    scale::Parameter
    background::Parameter
    dq::Parameter

    function ReflectModel(s::Structure;
            scale::Union{Real, Parameter}=1.0,
            background::Union{Real, Parameter}=0.0,
            dq::Union{Real, Parameter}=5.0
        )
        scale_p = build_parameter(scale; name="scale")
        background_p = build_parameter(background; name="background")
        dq_p = build_parameter(dq; name="dq")
        return new(s, scale_p, background_p, dq_p)
    end
end

# Fallback for models without a specific implementation: returns `nothing`.
# It accepts (and ignores) keywords so callers such as `logl`, which pass `xerr`,
# reach this method instead of failing with a keyword MethodError.
generative(model::Model, x::AbstractVector; kwargs...) = nothing

# Predicted reflectivity at the Q values `x`:
# scale * abeles(x, slabs(structure)) + background.
# `x` may be any AbstractVector (e.g. a range), not only a Vector.
# NOTE(review): `model.dq` and `xerr` are currently ignored, so no resolution
# smearing is applied.
function generative(model::ReflectModel, x::AbstractVector; xerr=nothing)
    w = slabs(model.structure)
    return model.scale.value .* abeles(x, w) .+ model.background.value
end

# Model-level parameters followed by the (nested) structure parameters.
function parameters(model::ReflectModel)
    scalars = (model.scale, model.background, model.dq)
    return [scalars..., parameters(model.structure)]
end

In [None]:
# Pairs a generative `Model` with the `Data1D` it is compared against (see `logl`).
mutable struct Objective
    model::Model   # evaluated at data.x via `generative`
    data::Data1D   # observations; yerr, if present, is the per-point standard deviation
end

"""
    logl(objective::Objective) -> Real

Gaussian log-likelihood of the data given the model:
`-0.5 * Σ [log(2π σ²) + (y - model)² / σ²]`, with σ = yerr, or 1 when yerr is absent.
"""
function logl(objective::Objective)
    data = objective.data
    prediction = generative(objective.model, data.x; xerr=data.xerr)

    # Per-point variance; unit variance when no y uncertainties are available.
    variance = isnothing(data.yerr) ? ones(size(data.y)) : data.yerr .^ 2.0

    terms = log.(2 .* pi .* variance) .+ (data.y .- prediction) .^ 2.0 ./ variance
    return -0.5 * sum(terms)
end
# All parameters of the objective's model as one flat vector.
function parameters(obj::Objective)
    nested = parameters(obj.model)
    return flattenall(nested)
end

In [None]:
# Load the measured dataset; `read_data` maps the columns to x, y[, yerr[, xerr]].
data = Data1D("c_PLP0011859_q.txt");

In [None]:
# Real SLDs of each material (imaginary part defaults to 0.0).
# NOTE(review): presumably in units of 1e-6 Å⁻², given the 4π·1e-6 factor in `abeles`.
air = SLD(0.0)
sio2 = SLD(3.47)
polymer = SLD(2.74)
si = SLD(2.07)
d2o = SLD(6.36)

# Slab(thickness, scatterer, roughness, vfsolv)
air_l = Slab(0, air, 0, 0)
sio2_l = Slab(39.724, sio2, 3, 0)
si_l = Slab(0, si, 0, 0)
polymer_l = Slab(259.433, polymer, 3, 0)
d2o_l = Slab(0, d2o, 3, 0)

# Layer stack: Si first (used by `abeles` as the fronting medium), D2O last
# (backing medium). `air_l` is not part of this stack.
s = si_l | sio2_l | polymer_l | d2o_l

# scale and dq keep their defaults (1.0 and 5.0).
model = ReflectModel(s; background=3.32e-7);

In [None]:
# Pair the model with the data and evaluate the Gaussian log-likelihood.
objective = Objective(model, data)
logl(objective)

In [None]:
# Model curve at 201 evenly spaced points in [0.0051, 0.5], with the data overlaid
# (log y axis).
x = collect(range(0.0051, 0.5; length=201))
y = generative(model, x)
plot(x, y, yaxis=:log)
plot!(data)

In [None]:
# Flat list of parameters: scale, background, dq, then for each slab its thickness,
# SLD (re, im), roughness and vfsolv.
parameters(objective)