In [None]:
using Pkg 
# Pkg.instantiate()
Pkg.activate(".")  # activate the environment for this notebook
# load the packages
using Flux, Dates, BSON, Plots, Measures, JLD2, Noise, JSON, Random
using Printf, Statistics, ProgressMeter, ParameterSchedulers, ArgCheck
using ParameterSchedulers: Stateful , Optimisers

pyplot()  # select the PyPlot backend of Plots for every figure in this notebook
# Optional GPU usage
const use_gpu = false
if use_gpu
    using cuDNN, CUDA
    CUDA.allowscalar(false) # disable scalar operations on GPU
    if CUDA.functional() # check if CUDA is functional
        println("CUDA is functional, using GPU")
    else
        # `use_gpu` is a constant that model construction consults later, so there is
        # no CPU fallback at this point: abort and say how to get one.
        error("use_gpu is true but CUDA is not functional; set use_gpu = false to run on the CPU")
    end
end

# Time-series cross-validation (TSCV)

In [None]:
"""
    collect_data(data, branch; noise_param = 0.1, rng = Random.default_rng())

Build the normalised input matrix `X`, the normalised target matrix `Y` and the time
vector for one branch of the dataset `data`.

`branch` is either 1 or 2: branch 1 is `h0`, branch 2 is `h_s, u_c`, so the branch number
equals the number of output rows. `X` always has three rows: the `branch` state rows
(perturbed with multiplicative Gaussian noise) followed by the forcing rows, which are
`q_in` and `q_out` for branch 1 and `tau` for branch 2.

# Keywords
- `noise_param`: relative standard deviation of the noise; one draw per time step is
  shared by all state rows of that step.
- `rng`: random number generator used for the noise; pass a seeded generator for
  reproducible data.

# Returns
`(X, Y, timex)` with `X` of size `3 × n_steps`, `Y` of size `branch × n_steps` (both
`Float32`) and `timex = data["times"][branch]`.

# Throws
- `ArgumentError` if `timex` never reaches the normalisation window
  (`24 * 365` for branch 1, `24` for branch 2, in the units of `timex`).
"""
function collect_data(data, branch; noise_param = 0.1, rng = Random.default_rng())
    timex = data["times"][branch]
    solution = data["solution"][branch]

    # One forcing series per remaining input row, plus the normalisation window length.
    if branch == 1
        forcings = (data["q_in"][branch], data["q_out"][branch])
        window = 24 * 365
    else
        forcings = (data["tau"][branch],)
        window = 24
    end

    lookback = findfirst(timex .>= window)
    if lookback === nothing
        throw(ArgumentError(
            "time series of branch $branch never reaches t = $window; " *
            "cannot determine the lookback window",
        ))
    end

    # sizes
    n_steps = length(solution)
    n_in = 3
    n_out = branch

    X = zeros(Float32, n_in, n_steps)
    Y = zeros(Float32, n_out, n_steps)

    for t in 1:n_steps
        X[1:branch, t] .= solution[t] .* (1 + noise_param * randn(rng, Float32))
        # Each forcing series fills its own row. Indexing the concatenated matrix
        # `[q_in q_out]` linearly would pick `q_in[t]` for both rows.
        for (k, forcing) in enumerate(forcings)
            X[branch + k, t] = forcing[t]
        end
        Y[:, t] .= solution[t]
    end

    # Normalise with statistics of the lookback window only (past values).
    Xμ = mean(X[:, 1:lookback]; dims = 2)
    Xσ = std(X[:, 1:lookback]; dims = 2)
    # A row with zero (or undefined) spread is only centred, never divided by 0/NaN.
    map!(s -> s > 0 ? s : one(s), Xσ, Xσ)

    X = (X .- Xμ) ./ Xσ
    Y = (Y .- Xμ[1:branch, :]) ./ Xσ[1:branch, :]
    return X, Y, timex
end


"""
    make_rolling_test_splits(time, branch; horizon=30)

Create expanding-window folds for rolling time-series cross-validation.

The first fold trains on indices `1:lookback`, where `lookback` is the index of the first
sample with `time >= 24 * 365` (branch 1) or `time >= 24` (any other branch), and
validates on the next `horizon` indices. Every following fold extends the training range
by `horizon` samples. Trailing samples that do not fill a whole horizon are dropped.

# Returns
A `Vector` of `(train_idx, val_idx)` tuples of index ranges; empty when not even one
full horizon fits after the lookback window.

# Throws
- `ArgumentError` if `horizon < 1` (the loop would never advance) or if `time` never
  reaches the lookback threshold.
"""
function make_rolling_test_splits(time, branch; horizon=30)
    horizon >= 1 || throw(ArgumentError("horizon must be at least 1, got $horizon"))

    threshold = branch == 1 ? 24 * 365 : 24
    lookback = findfirst(time .>= threshold)
    if lookback === nothing
        throw(ArgumentError(
            "time never reaches $threshold; cannot place the first fold for branch $branch",
        ))
    end

    splits = Tuple{UnitRange{Int},UnitRange{Int}}[]
    data_len = size(time, 1)
    start = lookback
    while start + horizon <= data_len
        push!(splits, (1:start, (start + 1):(start + horizon)))
        start += horizon
    end
    return splits
end

nothing

# Model Architecture

In [None]:
"""
    cyclical_encoding(z, time, branch)

Encode the time indices `z` as points on the unit circle.

The period, counted in samples, is the index of the first entry of `time` that is at
least `24 * 365` (branch 1) or `24` (otherwise) — presumably one year and one day of
hourly time stamps. Index 1 maps to phase zero.

Returns a `Float32` array with the sine values in the first row and the cosine values in
the second row, one column per entry of `z`.
"""
function cyclical_encoding(z, time, branch)
    threshold = branch == 1 ? 24 * 365 : 24
    period = findfirst(time .>= threshold)
    # Phase of every index within one period
    phase = 2π .* (z .- 1) ./ period
    sin_row = transpose(sin.(phase))
    cos_row = transpose(cos.(phase))
    return Float32.(vcat(sin_row, cos_row))   # date d_t = [ sin, cos ]
end

"""
    make_seq2seq(branch; hidden_size = 64, alt = 0)

Assemble the encoder / decoder / output-layer stack as a `Chain`.

The encoder is an `LSTM` with three inputs, and the output layer is a `Dense` layer with
`branch` outputs. `alt` selects what the decoder is fed:

- `alt = 0`: the two date-encoding rows only (standard decoder),
- `alt = 1`: date encoding plus the `3 - branch` forcing rows,
- `alt = 2`: the forcing rows only.

When the global `use_gpu` is `true`, every layer is moved to the GPU before the chain is
built.
"""
function make_seq2seq(branch; hidden_size = 64, alt = 0)
    # Width of the decoder input, split by source
    n_date = alt <= 1 ? 2 : 0                  # date encoding contributes sin and cos
    n_forcing = alt >= 1 ? 3 - branch : 0      # forcing rows left after the state rows
    decoder_width = n_date + n_forcing

    # Built in encoder, decoder, output order
    layers = (
        LSTM(3 => hidden_size),
        LSTM(decoder_width => hidden_size),
        Dense(hidden_size => branch),
    )

    # Move model to gpu or keep on cpu
    if use_gpu
        layers = map(gpu, layers)
    end

    return Chain(layers...)
end


"""
    forward_pass(m, x_seq, decod_seq)

Run the encoder-decoder (seq2seq) model `m` on one history / forecast pair.

`m[1]`, `m[2]` and `m[3]` are used as encoder, decoder and output layer. The encoder
consumes `x_seq`, its final state is copied into the decoder, and the decoder is then
driven by `decod_seq`. Both inputs are two-dimensional arrays; their second dimension is
treated as the sequence axis (presumably features × time steps — confirm against the
callers).

Returns the output-layer result with the singleton middle dimension removed, i.e. an
array with one column per column of `decod_seq`.
"""
function forward_pass(m, x_seq, decod_seq)
    # Forward pass of the model, corresponds to the seq2seq architecture
    encoder, decoder, outlayer = m[1], m[2], m[3]
    # Reset hidden/cell states memory (no leakage)
    Flux.reset!(encoder); Flux.reset!(decoder)
    # Encode internal h_t and c_t states using prev values.
    # The reshape inserts a singleton second dimension (presumably the batch axis the
    # recurrent layer expects — TODO confirm for the installed Flux version).
    x_seq = reshape(x_seq, size(x_seq, 1), 1, size(x_seq, 2))  
    # The encoder output is discarded; the call is made only to update `encoder.state`.
    encoder(x_seq)
    # Initialize with encoder states
    decoder.state = encoder.state
    # Same singleton-dimension layout for the decoder input
    decod_seq = reshape(decod_seq, size(decod_seq,1), 1, size(decod_seq,2))
    # Forecast the next steps
    yhat = outlayer(decoder(decod_seq))
    # Drop the singleton second dimension again
    return dropdims(yhat; dims=2)
end

# Training 

In [None]:
"""
    train_rolling_tscv(m_tuple, X, Y, time, alt, branch;
                       horizon=30, n_epochs=1000, decay_rate=0.99, step=5)

Train the seq2seq model `m_tuple` in place with rolling time-series cross-validation.

Every epoch visits all folds from `make_rolling_test_splits` in order. For each fold the
encoder reads `X[:, train_idx]`, the decoder input is chosen by `alt`
(0: date encoding, 1: date encoding stacked on the forcing rows, 2: forcing rows only),
the Huber loss against `Y[:, val_idx]` is differentiated, and one Adam update is applied.
After each epoch the learning rate is replaced by the next value of a `Step` schedule.
The starting rate is `0.001 / decay_rate^(n_epochs / step)`, presumably so that the rate
has decayed to about `0.001` by the end of training.

# Returns
`(m_tuple, (mean_loss, std_loss))`. The statistics are taken over the folds of the last
epoch, each loss being evaluated right after that fold's update.

# Throws
- `ArgumentError` if `n_epochs < 1`, if `alt` is not 0, 1 or 2, or if the series is too
  short to form a single fold.
"""
function train_rolling_tscv(m_tuple, X, Y, time, alt, branch; horizon=30, n_epochs = 1000, decay_rate=0.99, step=5)
    # Fail early with a clear message instead of deep inside the loop or in `mean([])`.
    n_epochs >= 1 || throw(ArgumentError("n_epochs must be at least 1, got $n_epochs"))
    alt in 0:2 || throw(ArgumentError("alt must be 0, 1 or 2, got $alt"))

    # Create folds
    splits = make_rolling_test_splits(time, branch; horizon=horizon)
    if isempty(splits)
        throw(ArgumentError("time series is too short for one fold with horizon = $horizon"))
    end

    # Final-epoch loss of every fold
    loss_epoch = float(eltype(Y))[]
    sizehint!(loss_epoch, length(splits))

    # Initialize optimizer and scheduler
    learning_rate = 0.001 / (decay_rate^(n_epochs/step))
    optimizer = Flux.setup(Adam(learning_rate), m_tuple)
    lr_scheduler = Stateful(Step(learning_rate, decay_rate, step))

    for epoch in 1:n_epochs
        # Perform one epoch of training
        for (train_idx, val_idx) in splits
            # Form training and validation sets
            x_train = X[:, train_idx]
            y_val = Y[:, val_idx]
            # Encode dates and forcing
            date_seq = cyclical_encoding(val_idx, time, branch)
            fork = X[1+branch:end, val_idx]
            decod_seq = (date_seq, vcat(date_seq, fork), fork)[alt + 1]

            loss, grads = Flux.withgradient(m_tuple) do m
                # Evaluate model and loss inside gradient context:
                y_hat = forward_pass(m, x_train, decod_seq)
                Flux.huber_loss(y_hat, y_val)
            end
            Flux.update!(optimizer, m_tuple, grads[1])

            # Record the post-update loss of this fold during the last epoch only
            if epoch == n_epochs
                y_hat = forward_pass(m_tuple, x_train, decod_seq)
                push!(loss_epoch, Flux.huber_loss(y_hat, y_val))
            end
        end
        # Update learning rate
        nextlr = ParameterSchedulers.next!(lr_scheduler)
        Optimisers.adjust!(optimizer, nextlr)
    end
    loss_stat = mean(loss_epoch), std(loss_epoch)

    return m_tuple, loss_stat
end

nothing

# Unroll

In [1]:
"""
    unroll(model, X, Y, horizon, branch, time, alt; forcing = false)

Forecast the series fold by fold, starting from the initial lookback window.

The folds come from `make_rolling_test_splits`. The first fold always encodes the true
inputs `X[:, train_idx]`. For later folds the encoder input depends on `forcing`:

- `forcing = false`: the previous encoder input is extended with the model's own previous
  forecast stacked on the true forcing rows of that horizon (free-running unroll);
- `forcing = true`: the true inputs `X[:, train_idx]` are used for every fold.

The decoder input is selected by `alt` exactly as in training (0: date encoding,
1: date encoding stacked on the forcing rows, 2: forcing rows only).

# Returns
`(outputs, markers, loss_stat)` where
- `outputs` is `branch × (n_folds * horizon)`; column `k` is the forecast for time index
  `lookback + k`,
- `markers` holds the first validation index of every fold (for plotting),
- `loss_stat` is the `(mean, std)` of the per-fold Huber loss against `Y`.

# Throws
- `ArgumentError` if `alt` is not 0, 1 or 2, or if no fold fits into the series.
"""
function unroll(model, X, Y, horizon, branch, time, alt; forcing = false)
    alt in 0:2 || throw(ArgumentError("alt must be 0, 1 or 2, got $alt"))

    # Create folds
    splits = make_rolling_test_splits(time, branch; horizon=horizon)
    if isempty(splits)
        throw(ArgumentError("time series is too short for one fold with horizon = $horizon"))
    end

    lookback = (branch == 1) ? findfirst(time .>= 24*365) : findfirst(time .>= 24)

    # Containers
    loss_epoch = float(eltype(Y))[]
    markers = Int[]   # first validation index of every fold, for plotting
    outputs = zeros(Float32, branch, length(splits) * horizon)

    encoder_input = nothing   # what the encoder reads in the current fold
    last_block = nothing      # previous forecast stacked on its true forcing rows

    for (fold, (train_idx, val_idx)) in enumerate(splits)
        # Form training and validation sets
        encoder_input = (forcing || fold == 1) ? X[:, train_idx] : hcat(encoder_input, last_block)
        y_val = Y[:, val_idx]
        # Encode dates and forcing
        date_seq = cyclical_encoding(val_idx, time, branch)
        fork = X[1+branch:end, val_idx]
        decod_seq = (date_seq, vcat(date_seq, fork), fork)[alt + 1]

        y = forward_pass(model, encoder_input, decod_seq)
        # Compare forecast to exact value
        push!(loss_epoch, Flux.huber_loss(y, y_val))
        # For plotting
        push!(markers, first(val_idx))
        # Validation indices start right after the lookback window, so subtracting
        # `lookback` maps them onto the columns 1, 2, ... of `outputs`.
        outputs[:, val_idx .- lookback] = y
        # Add the unroll to the training data
        last_block = vcat(y, fork)
    end
    loss_stat = mean(loss_epoch), std(loss_epoch)
    return outputs, markers, loss_stat
end


# Draw one comparison panel: row `row` of the unroll against the ground truth.
# `scale` converts `time` to the axis unit; `tickstep` is the tick spacing in that unit.
# Column `k` of `Y_unroll` is the forecast for time index `lookback + k`, so both the
# curve and the fold markers are shifted by `lookback`.
function _unroll_panel(Y_unroll, Y, time, lookback, row, markers; scale, xlabel, tickstep, title)
    t_axis = time ./ scale
    unroll_idx = lookback .+ axes(Y_unroll, 2)
    panel = plot(t_axis[unroll_idx], Y_unroll[row, :]; xticks = 0:tickstep:maximum(t_axis),
                 xlabel = xlabel, label = "ML unroll", legend = true, title = title)
    plot!(panel, t_axis[lookback:end], Y[row, lookback:end]; xlabel = xlabel,
          label = "Ground truth", legend = true, title = title)
    if !isnothing(markers)
        marker_idx = Int.(markers)
        scatter!(panel, t_axis[marker_idx], Y_unroll[row, marker_idx .- lookback];
                 markershape = :x, label = "", markersize = 5, markerstrokewidth = 2)
    end
    return panel
end

"""
    plot_unroll(Y_unroll, Y, time, branch; markers = nothing)

Display the unrolled forecast `Y_unroll` together with the ground truth `Y`.

Branch 1 shows one panel (mean water level, time axis in days); branch 2 shows two
stacked panels (surface slope and velocity, time axis in hours). Any other `branch`
plots nothing. `Y_unroll` is expected in the layout returned by `unroll`: column `k`
belongs to time index `lookback + k`.

# Keywords
- `markers`: optional time indices (such as the fold starts returned by `unroll`) that
  are highlighted with crosses on the unroll curve.

Returns `nothing`; the figure is shown with `display`.
"""
function plot_unroll(Y_unroll, Y, time, branch; markers = nothing)
    lookback = (branch == 1) ? findfirst(time .>= 24*365) : findfirst(time .>= 24)

    if branch == 1
        p1 = _unroll_panel(Y_unroll, Y, time, lookback, 1, markers; scale = 24,
                           xlabel = "time [days]", tickstep = 60,
                           title = "Mean water level in lake")
        compareplot1 = plot(p1, layout=(1,1), size=(1200,300))
        display(compareplot1)

    elseif branch == 2
        p2 = _unroll_panel(Y_unroll, Y, time, lookback, 1, markers; scale = 1,
                           xlabel = "time [hours]", tickstep = 5,
                           title = "Surface slope in lake [m]")
        p3 = _unroll_panel(Y_unroll, Y, time, lookback, 2, markers; scale = 1,
                           xlabel = "time [hours]", tickstep = 5,
                           title = "Velocity in lake [m/s]")
        compareplot2 = plot(p2, p3, layout=(2,1), size=(1200,600))
        display(compareplot2)
    end
    return nothing
end

nothing

# Example

In [None]:
# End-to-end example: load data, train one model, unroll it and plot the result.
# Load the dataset; it is indexed by string keys such as "times" and "solution".
data=load("Dataset/model_0d_lake_sep.jld2")
branch = 2      # 1: h0, 2: h_s and u_c (also the number of model outputs)
horizon = 40    # number of time steps forecast per fold during training
alt = 0         # decoder input: 0 = dates only, 1 = dates + forcing, 2 = forcing only
# Build normalised inputs/targets; states get 10 % multiplicative noise
X, Y, time = collect_data(data, branch, noise_param = 0.1)  
m_tuple = make_seq2seq(branch,alt = alt)
# Train with rolling cross-validation; loss_stat is (mean, std) over the last epoch's folds
m_tuple, loss_stat = train_rolling_tscv(m_tuple, X, Y, time, alt, branch; horizon=horizon)
# Unroll with the same horizon that was used for training
horizon_u = horizon
Y_unroll, markers, loss_stat_u = unroll(m_tuple, X, Y, horizon_u, branch, time, alt)
@show loss_stat_u , loss_stat
# `markers` is not passed, so fold starts are not highlighted in the figure
plot_unroll(Y_unroll, Y, time, branch)

# Comparing alts

Compare the unrolling loss of each decoder alternative (`alt`) on the storm datasets.

In [None]:
# # hyperparameters
# branch = 1

# loss = Dict{Tuple{Int,Int,Int,Int}, Any}()

# for horizon in [60,50,40,30,20]
#     for alt in [0,1,2]
#         for epochs in [100, 200, 300, 400]

#             @show horizon, alt, epochs

#             horizon_u = horizon

#             # --- Training data ---
#             data = load("Dataset/model_0d_lake_sep.jld2")
#             X, Y, time = collect_data(data, branch, noise_param = 0.1)

#             # --- Build model ---
#             m_tuple = make_seq2seq(branch, alt = alt)

#             # --- Train model ---
#             m_tuple, loss_stat = train_rolling_tscv(m_tuple, X, Y, time, alt, branch; horizon=horizon, n_epochs=epochs)

#             # --- Test / unroll on odd data ---
#             for storm in [1,2,3]
#                 if storm == 1
#                     data2 = load("Dataset/model_0d_lake_sep_odd1.jld2")
#                 elseif storm == 2
#                     data2 = load("Dataset/model_0d_lake_sep_odd2.jld2")
#                 elseif storm == 3
#                     data2 = load("Dataset/model_0d_lake_sep_odd2.jld2")
#                 end 
#                 X, Y, time = collect_data(data2, branch, noise_param = 0.1)

#                 Y_unroll, markers, loss_stat_u = unroll(m_tuple, X, Y, horizon_u, branch, time, alt)

#                 # --- Save loss in dictionary ---
#                 loss[(horizon, alt, epochs, storm)] = loss_stat_u
#             end

#         end
#     end
# end


# @save "loss_alt_comp_storms_b1.jld2" loss

# Plotting the unrolling of Alts

In [None]:

# data=load("Dataset/model_0d_lake_sep.jld2")
# branch = 2
# horizon = 40
# horizon_u = horizon

# p1=plot();p2=plot(); p3=plot()

# X, Y, time = collect_data(data, branch, noise_param = 0.1)  

# for alt in [0,1,2]
#     m_tuple = make_seq2seq(branch,alt = alt)
#     m_tuple, loss_stat = train_rolling_tscv(m_tuple, X, Y, time, alt, branch; horizon=horizon)    
#     @show loss_stat, alt
#     Y_unroll, markers, loss_stat_u = unroll(m_tuple, X, Y, horizon_u, branch, time, alt)
#     @show loss_stat_u, alt

#     lookback = (branch == 1) ? findfirst(time .>= 24*365 ) : findfirst(time .>= 24 )

#     if branch == 1
#         plot!(p1,time[lookback:lookback+size(Y_unroll,2)-1]/24, Y_unroll[1,:], xticks = 0:60:maximum(time), xlabel="time [days]", label="ML unroll", legend=true, title="Mean water level in lake")# ,st=:scatter)
#         if alt==2
#             plot!(p1, time[lookback:end]/24, Y[1,lookback:end], xlabel="time [days]", label="Ground truth", legend=true, title="Mean water level in lake")
#         end

#     elseif branch == 2
#         plot!(p2,time[lookback:lookback+size(Y_unroll,2)-1], Y_unroll[1,:], xlabel="time [hours]", xticks = 0:5:maximum(time), label="ML unroll alt=$alt", legend=true, title="Surface slope in lake [m]")
#         plot!(p3,time[lookback:lookback+size(Y_unroll,2)-1], Y_unroll[2,:], xlabel="time [hours]", xticks = 0:5:maximum(time), label="ML unroll alt=$alt", legend=true, title="Velocity in lake [m/s]")
#         if alt==2
#             plot!(p2, time[lookback:end], Y[1,lookback:end], xlabel="time [hours]", label="Ground truth", legend=true, title="Surface slope in lake [m]")
#             plot!(p3, time[lookback:end], Y[2,lookback:end], xlabel="time [hours]", label="Ground truth", legend=true, title="Velocity in lake [m/s]")
#         end
#     end

# end


# compareplot2 = plot(p2,p3,layout=(2,1), size=(1200,600))
# # compareplot2 = plot(p1,layout=(1,1), size=(1200,300))




# Creating data for plots

This code creates the data for the plots made in the notebook "Make plots.ipynb".

In [None]:
# data = load("Dataset/model_0d_lake_sep.jld2")

# # Prepare a container to store results
# # results[branch][horizon_t][epochs][alt][horizon_u] => loss_stat_u
# results = Dict{Int, Any}()

# # train_results[branch][horizon_t][epochs][alt] => Dict:loss_stat_t => ..., :train_time  => ...)
# train_results = Dict{Int, Any}()

# for branch in [2]
#     X, Y, time = collect_data(data, branch, noise_param = 0.1)
#     branch_results = Dict{Int, Dict{Int, Dict{Int, Dict{Int, Any}}}}() 
#     # branch_results[horizon_t][epochs][alt][horizon_u] => loss_stat_u
#     branch_train= Dict{Int, Dict{Int, Dict{Int, Any}}}()
#     # branch_train[horizon_t][epochs][alt] => Dict:loss_stat_t => ..., :train_time  => ...)

#     for horizon_t in [60,40,20,10,5,1]
#         splits = make_rolling_test_splits(time, branch; horizon=horizon_t)
#         branch_results[horizon_t] = Dict{Int, Dict{Int, Dict{Int, Any}}}()
#         branch_train[horizon_t]   = Dict{Int, Dict{Int, Any}}()

#         for epochs in [100, 200, 300, 400]
#             branch_results[horizon_t][epochs] = Dict{Int, Dict{Int, Any}}()
#             branch_train[horizon_t][epochs]   = Dict{Int, Any}()

#             for alt in [0,1,2]
#                 m_tuple = make_seq2seq(branch, alt=alt)
#                 train_time = @elapsed begin
#                     m_tuple, loss_stat_t = train_rolling_tscv(m_tuple, X, Y, time, alt, branch;
#                                                                 horizon=horizon_t, n_epochs=epochs)
#                 end  # timed
#                 horizon_u_results = Dict{Int, Any}()
#                 branch_train[horizon_t][epochs][alt] = Dict(
#                     :loss_stat_t => loss_stat_t,
#                     :train_time  => train_time,
#                 )

#                 for horizon_u in [60,40,20,10,5,1]
#                     Y_unroll, markers, loss_stat_u = unroll(m_tuple, X, Y, horizon_u, branch, time, alt)
#                     horizon_u_results[horizon_u] = loss_stat_u
#                 end  # horizon_u

#                 branch_results[horizon_t][epochs][alt] = horizon_u_results
#                 @save "checkpoint_b2.jld2" branch_results branch_train
#                 print("\rTraining branch $branch, horizon $horizon_t, epochs $epochs, alt $alt, train_time $train_time")
#                 flush(stdout)
#             end #alt
#         end # epochs
#     end # horizon_t
#     results[branch] = branch_results
#     train_results[branch] = branch_train
#     @save "checkpoint2_b2.jld2" results train_results

# end # branch

# @save "final_res_b2.jld2" results train_results