In [None]:
using LinearAlgebra
using PolyChaos
using QuadGK
using Plots

"""
    spectral_func(input, D, g)

Return a bath spectral function `x -> J(x)` of the requested shape, rescaled so
that its integral over `[-D, D]` equals `g * D / pi`.

Supported `input` shapes (`D` is the half-bandwidth):
- `"flat"`: constant on `[-D, D]`, zero outside;
- `"elliptical"`: `sqrt(1 - (x/D)^2)` on `[-D, D]`, zero outside;
- `"ohmic"`: `abs(x)` on `[-D, D]`, zero outside;
- `"lorentzian"`: `1 / (1 + (x/D)^2)`, NOT cut off outside the band (only the
  normalization integral is restricted to `[-D, D]`).

# Arguments
- `input`: name of the shape (see above).
- `D`: half-bandwidth; must be positive.
- `g`: overall coupling strength entering the normalization.

# Throws
- `ArgumentError` if `D <= 0`.
- An `ErrorException` if `input` is not one of the supported shapes.
"""
function spectral_func(input::AbstractString, D::Real, g::Real)
    # A non-positive bandwidth makes the normalization integral zero or negative.
    D > 0 || throw(ArgumentError("half-bandwidth D must be positive, got $D"))
    inband = x -> (-D <= x <= D)

    if input == "flat"
        shape = x -> inband(x) ? 1 / (2D) : 0.0
    elseif input == "elliptical"
        shape = x -> inband(x) ? sqrt(1 - (x / D)^2) : 0.0
    elseif input == "ohmic"
        shape = x -> inband(x) ? abs(x) : 0.0
    elseif input == "lorentzian"
        shape = x -> 1 / (1 + (x / D)^2)
    else
        error("spectral function type not recognized")
    end

    # Integral of the raw shape over the band; named `weight` so it does not
    # shadow `LinearAlgebra.norm`.
    weight = quadgk(shape, -D, D)[1]
    return x -> g * D / pi * shape(x) / weight
end

"""
    thermofield_transform(J, beta, mu)

Thermofield purification of a fermionic bath using a Fermi-function ancilla.

Returns `(J1, J2)` with
- `J1(w) = J(w) * f(w)`: spectral density of the initially filled modes,
- `J2(w) = J(w) * (1 - f(w))`: spectral density of the initially empty modes,

where `f(w) = 1 / (1 + exp(beta * (w - mu)))` is the Fermi function.

# Arguments
- `J`: spectral function (any callable taking one real argument).
- `beta`: inverse temperature.
- `mu`: chemical potential.
"""
function thermofield_transform(J, beta::Real, mu::Real)
    fermi(w) = 1 / (1 + exp(beta * (w - mu)))
    # 1 - f(w) written in closed form: computing `1 - fermi(w)` cancels
    # catastrophically when f(w) is close to 1 (w well below mu), which would
    # zero out the empty-mode density there.
    antifermi(w) = 1 / (1 + exp(-beta * (w - mu)))
    J1 = w -> J(w) * fermi(w)      # filled-mode spectral density
    J2 = w -> J(w) * antifermi(w)  # empty-mode spectral density
    return J1, J2
end

"""
    chain_map(J, N, D; Nquad=10000)

Compute, via PolyChaos, the family of monic polynomials orthogonal with respect
to the weight `J(x)` on `[-D, D]`, and return the first `N` recurrence
coefficients as chain parameters.

# Arguments
- `J`: weight function (spectral density) on `[-D, D]`.
- `N`: number of recurrence coefficients to return (chain length).
- `D`: half-width of the support `[-D, D]`.

# Keywords
- `Nquad`: number of quadrature points passed to `OrthoPoly` (default 10000).

# Returns
- `Es`: the first `N` α coefficients (site energies).
- `ts`: square roots of the first `N` β coefficients (hoppings); the first
  entry is the system–chain coupling.
"""
function chain_map(J, N::Integer, D::Real; Nquad::Integer=10000)
    supp = (-float(D), float(D))
    meas = Measure("bath", J, supp, false, Dict())
    # OrthoPoly's degree argument is typed `Int`, so normalize other Integer kinds.
    ortho_poly = OrthoPoly("bath_op", Int(N), meas; Nquad=Int(Nquad))
    # Column 1 holds the α coefficients, column 2 the β coefficients.
    chain = coeffs(ortho_poly)
    Es = chain[1:N, 1]          # site energies
    ts = sqrt.(chain[1:N, 2])   # site hoppings (first term is system hopping)
    return Es, ts
end

"""
    prepare_corrs(N, sys_occs; n_sites=2)

Build the initial (diagonal) single-particle correlation matrix for
`n_sites` system sites followed by two interleaved bath chains of `N` sites each.

Mode ordering: system sites occupy `1:n_sites`; bath site `k` (for `k in 1:N`)
of the first chain sits at `n_sites + 2k - 1` and is initialized occupied (1),
bath site `k` of the second chain sits at `n_sites + 2k` and is initialized
empty (0).

# Arguments
- `N`: number of sites per bath chain.
- `sys_occs`: initial occupations of the system sites (length `n_sites`).

# Keywords
- `n_sites`: number of system sites (default 2).

# Returns
A `(2N + n_sites) × (2N + n_sites)` `Matrix{ComplexF64}`.

# Throws
- `DimensionMismatch` if `length(sys_occs) != n_sites`.
"""
function prepare_corrs(N, sys_occs; n_sites=2)
    length(sys_occs) == n_sites || throw(DimensionMismatch(
        "expected $n_sites system occupations, got $(length(sys_occs))"))
    N_tot = 2N + n_sites
    C = zeros(ComplexF64, N_tot, N_tot)
    for n in 1:n_sites
        C[n, n] = sys_occs[n]
    end
    # Only the first-chain sites need writing; second-chain sites stay 0.
    for k in 1:N
        C[n_sites + 2k - 1, n_sites + 2k - 1] = 1.0
    end
    return C
end


"""
    Unitary(N, sys_occs, E_sys, t_sys, E1, t1, E2, t2; n_sites=2, dt=0.1)

Build the quadratic hopping Hamiltonian `H` of a system chain coupled to two
bath chains and return the one-step propagator `exp(-im * dt * H)`.

Mode ordering: system sites `1:n_sites`, then site `k` of chain 1 at
`n_sites + 2k - 1` and site `k` of chain 2 at `n_sites + 2k` (`k in 1:N`).

# Arguments
- `N`: number of sites per bath chain.
- `sys_occs`: not used by this function.
- `E_sys`: on-site energies of the `n_sites` system sites.
- `t_sys`: hopping between neighbouring system sites.
- `E1`, `t1`: chain-1 site energies and hoppings. `t1[1]` couples the last
  system site to chain site 1; `t1[k+1]` couples chain sites `k` and `k+1`.
- `E2`, `t2`: the same for chain 2.

# Keywords
- `n_sites`: number of system sites (default 2).
- `dt`: time step of the propagator (default 0.1). It must match the `dt`
  later given to `evolve_corrs`, which applies this matrix once per step.

# Throws
- `DimensionMismatch` if `E_sys` does not have `n_sites` entries or any of
  `E1`, `t1`, `E2`, `t2` has fewer than `N` entries.
"""
function Unitary(N, sys_occs, E_sys, t_sys, E1, t1, E2, t2; n_sites=2, dt=0.1)
    length(E_sys) == n_sites || throw(DimensionMismatch(
        "E_sys must have $n_sites entries, got $(length(E_sys))"))
    for (label, v) in (("E1", E1), ("t1", t1), ("E2", E2), ("t2", t2))
        length(v) >= N || throw(DimensionMismatch(
            "$label must have at least $N entries, got $(length(v))"))
    end

    N_tot = 2N + n_sites
    H = zeros(ComplexF64, N_tot, N_tot)

    # System: on-site energies and nearest-neighbour hopping.
    for s in 1:n_sites
        H[s, s] = E_sys[s]
    end
    for s in 1:(n_sites - 1)
        H[s, s + 1] = H[s + 1, s] = t_sys
    end

    # Matrix index of site k on chain 1 / chain 2.
    c1(k) = n_sites + 2k - 1
    c2(k) = n_sites + 2k

    # The last system site couples to the first site of each chain.
    if N >= 1
        H[n_sites, c1(1)] = H[c1(1), n_sites] = t1[1]
        H[n_sites, c2(1)] = H[c2(1), n_sites] = t2[1]
    end

    # Chain on-site energies.
    for k in 1:N
        H[c1(k), c1(k)] = E1[k]
        H[c2(k), c2(k)] = E2[k]
    end

    # Intra-chain hoppings: t[k+1] links chain sites k and k+1
    # (t[1] is already used for the system coupling).
    for k in 1:(N - 1)
        H[c1(k), c1(k + 1)] = H[c1(k + 1), c1(k)] = t1[k + 1]
        H[c2(k), c2(k + 1)] = H[c2(k + 1), c2(k)] = t2[k + 1]
    end

    return exp(-im * dt * H)
end

"""
    evolve_corrs(C0, U, dt, tmax)

Propagate the correlation matrix `C0` in time with the one-step propagator `U`.

`U` must be the propagator for exactly one time step `dt`. The returned vector
`Cs` has one entry per time in `0:dt:tmax`, with `Cs[1] == C0` (t = 0) and
`Cs[k + 1] = U * Cs[k] * U'`. If `0:dt:tmax` is empty, `Cs` is empty.

# Returns
`Vector{Matrix{ComplexF64}}` of correlation matrices, one per time point.
"""
function evolve_corrs(C0, U, dt, tmax)
    nsteps = length(0:dt:tmax)
    Cs = Vector{Matrix{ComplexF64}}(undef, 0)
    nsteps == 0 && return Cs
    sizehint!(Cs, nsteps)

    Umat = Matrix(U)
    Uadj = Matrix(Umat')         # hoisted: the adjoint is the same every step
    C = Matrix{ComplexF64}(C0)   # copy, so the caller's C0 is never aliased
    push!(Cs, C)
    for _ in 2:nsteps
        # Cumulative propagation: each step evolves the previous state.
        C = Umat * C * Uadj
        push!(Cs, C)
    end
    return Cs
end

In [None]:
#simulation params
N = 50        # sites per bath chain
dt = 0.1      # time step
tmax = 100.0  # final time

#bath params
D = 1.0       # half-bandwidth
g = 0.1       # coupling strength
beta = 1.0    # inverse temperature
mu = 0.0      # chemical potential

#system params
Esys = [0.0, 0.0]      # system on-site energies
sys_occs = [1.0, 1.0]  # initial system occupations
t_sys = 0.1            # system intra-site hopping

J = spectral_func("elliptical", D, g)
J1, J2 = thermofield_transform(J, beta, mu)

# Map the filled (J1) and empty (J2) spectral densities to chains.
E1, t1 = chain_map(J1, N, D)
E2, t2 = chain_map(J2, N, D)

# The propagator must be built for the same dt that evolve_corrs steps with.
U = Unitary(N, sys_occs, Esys, t_sys, E1, t1, E2, t2; dt=dt)
C0 = prepare_corrs(N, sys_occs)
Cs = evolve_corrs(C0, U, dt, tmax)

In [None]:
# Occupation of each system site at every stored time point, read from the
# real part of the corresponding diagonal entry of the correlation matrix.
nSys1 = [real(C[1, 1]) for C in Cs]
nSys2 = [real(C[2, 2]) for C in Cs]
p = plot()