# Helper functions that are useful for all Julia notebooks

In [None]:
import Dates
import JupyterFormatter
import LinearAlgebra
import Logging
import LoggingExtras
import ProgressMeter

## Multi-threading

In [None]:
# Show a ProgressMeter progress bar for a `for` loop, including loops wrapped in
# `Threads.@threads` (optionally with a schedule argument) or bare `@threads`.
# TODO upstream this into ProgressMeter
#
# Usage:
#     @tprogress Threads.@threads for i in 1:n ... end
#     @tprogress for i in 1:n ... end
#
# Threaded loops: only the thread with id 1 advances the bar, so the bar is sized
# to round(length(range) / nthreads()). This is an estimate of thread 1's share
# of the iterations, not an exact count.
# Serial loops: every iteration advances the bar, so it is sized to length(range).
#
# Note that the iteration range expression is evaluated twice: once to size the
# bar and once by the loop itself.
macro tprogress(expr)
    # Peel off the `@threads` wrapper, if any, to reach the underlying loop.
    loop = expr
    threaded = false
    if loop isa Expr &&
       loop.head === :macrocall &&
       loop.args[1] in (:(Threads.var"@threads"), Symbol("@threads"))
        threaded = true
        loop = loop.args[end]
    end
    if !(loop isa Expr && loop.head === :for)
        throw(
            ArgumentError(
                "@tprogress expects a `for` loop, optionally wrapped in Threads.@threads",
            ),
        )
    end

    # NOTE(review): the progress object lives in a freshly named global so that the
    # setup code below and the update statement spliced into the caller's
    # (escaped) loop body can refer to the same binding. This appears to rely on
    # the macro being used from the module that defines it - confirm before
    # moving it into a package.
    p = gensym("progress")
    r = loop.args[1].args[end]

    # Number of `next!` calls we expect to see.
    total = if threaded
        :(round(Int, length($(esc(r))) / Threads.nthreads()))
    else
        :(length($(esc(r))))
    end

    ex = quote
        global $p = ProgressMeter.Progress($total; showspeed=true)
        $(esc(expr))
        ProgressMeter.finish!($p)
    end

    update = if threaded
        quote
            if Threads.threadid() == 1
                ProgressMeter.next!($p)
            end
        end
    else
        :(ProgressMeter.next!($p))
    end
    # `loop` is the same object that is embedded in `ex`, so appending to its body
    # here also changes the returned expression.
    push!(loop.args[end].args, update)

    ex
end

In [None]:
# # partitions the range 1:n by the number of threads
# # and returns the range corresponding to your thread id
# # must only be used with Threads.@threads :static
# function thread_range(n)
#     tid = Threads.threadid()
#     nt = Threads.nthreads()
#     d, r = divrem(n, nt)
#     from = (tid - 1) * d + min(r, tid - 1) + 1
#     to = from + d - 1 + (tid ≤ r ? 1 : 0)
#     from:to
# end

In [None]:
# Prefer Julia multithreading to BLAS multithreading:
# restrict BLAS to a single thread; parallelism is expected to come from
# Julia-level threads instead (e.g. Threads.@threads loops).
LinearAlgebra.BLAS.set_num_threads(1);

## Formatting

In [None]:
# Turn on JupyterFormatter's automatic formatting for this notebook session.
JupyterFormatter.enable_autoformat();

## Early stopping

In [None]:
# # stop training when the loss function stops decreasing
# @kwdef mutable struct early_stopper
#     max_iters = Inf
#     patience = Inf
#     min_rel_improvement = 0
#     iters = 0
#     iters_without_improvement = 0
#     loss = NaN
# end

# function stop!(x::early_stopper, loss)
#     x.iters += 1
#     if x.iters > x.max_iters
#         return true
#     end

#     if x.iters == 1
#         x.loss = loss
#         return false
#     end

#     if loss < x.loss * (1 - x.min_rel_improvement)
#         x.iters_without_improvement = 0
#         x.loss = loss        
#     else
#         x.iters_without_improvement += 1
#     end
#     x.iters_without_improvement > x.patience
# end;

In [None]:
# # stop training when the parameters have converged
# @kwdef mutable struct convergence_stopper
#     tolerance::AbstractFloat
#     max_iters = Inf
#     params::AbstractVector
#     prev_params::AbstractVector
#     iters = 0
# end

# function convergence_stopper(tolerance; max_iters = Inf)
#     convergence_stopper(
#         tolerance = tolerance,
#         max_iters = max_iters,
#         params = [],
#         prev_params = [],
#     )
# end

# function stop!(x::convergence_stopper, params)
#     x.iters += 1
#     if x.iters > x.max_iters
#         return true
#     end

#     if x.iters == 1
#         x.params = deepcopy(params)
#         return false
#     end

#     function maxabs(a)
#         maximum(abs.(a))
#     end

#     x.prev_params = deepcopy(x.params)
#     x.params = deepcopy(params)
#     maximum(maxabs.(x.params - x.prev_params)) < x.tolerance
# end;

## Logging

In [None]:
# Globally suppress log records at level Debug and below.
Logging.disable_logging(Logging.Debug);

In [None]:
# Logger that flushes after every log statement.
# Wraps a ConsoleLogger; records are forwarded to the wrapped logger and its
# stream is flushed after each handled message.
struct FlushLogger <: LoggingExtras.AbstractLogger
    logger::LoggingExtras.ConsoleLogger
end

# Fallback for loggers that are not a ConsoleLogger. A ConsoleLogger argument is
# handled by the more specific default constructor and never reaches this
# method. Anything else is rejected explicitly: the wrapper needs the wrapped
# logger's `stream` field to flush. (This method used to call itself and
# overflow the stack.)
function FlushLogger(logger::LoggingExtras.AbstractLogger)
    throw(
        ArgumentError("FlushLogger can only wrap a ConsoleLogger, got $(typeof(logger))"),
    )
end

# Forward the record to the wrapped logger, then flush that logger's stream so
# the message is written out immediately.
function Logging.handle_message(logger::FlushLogger, args...; kwargs...)
    inner = logger.logger
    Logging.handle_message(inner, args...; kwargs...)
    flush(inner.stream)
end

# The remaining logger-interface methods delegate unchanged to the wrapped logger.
function Logging.shouldlog(logger::FlushLogger, arg...)
    return Logging.shouldlog(logger.logger, arg...)
end
function Logging.min_enabled_level(logger::FlushLogger)
    return Logging.min_enabled_level(logger.logger)
end
function Logging.catch_exceptions(logger::FlushLogger)
    return Logging.catch_exceptions(logger.logger)
end

# Meta formatter for console-style loggers: returns `(color, prefix, suffix)`.
# - color: 4 below Info, 6 below Warn, 3 below Error, 1 otherwise
# - prefix: the level name followed by ':' ("Warning:" for the Warn level)
# - suffix: always empty, so module/file/line information is not printed
# Only `level` is used; the other arguments are accepted and ignored.
function logging_meta_formatter(level, _module, group, id, file, line)
    color = if level < Logging.Info
        4
    elseif level < Logging.Warn
        6
    elseif level < Logging.Error
        3
    else
        1
    end
    label = level == Logging.Warn ? "Warning" : string(level)
    return color, string(label, ':'), ""
end;

In [None]:
# Log to a file and to stderr at the same time.
#
# Installs a global logger that sends every record at level Info or above to
# both stderr and a log file inside `outdir`. Each message is prefixed with the
# current date and time, and both sinks are flushed after every record.
#
# - outdir: directory that receives the log file; created if it does not exist.
# - overwrite: when true (default) the file `<outdir>/log` is opened for writing;
#   when false the first path among `log`, `log.1`, `log.2`, ... that does not
#   exist yet is used instead.
#
# Returns the result of `Logging.global_logger` for the newly installed logger.
function redirect_logging(outdir; overwrite=true)
    stamp_format = "yyyymmdd HH:MM:SS"

    # Wrap `sink` so that every message is prefixed with the current timestamp.
    function with_timestamp(sink)
        return LoggingExtras.TransformerLogger(sink) do record
            stamped = "$(Dates.format(Dates.now(), stamp_format)) $(record.message)"
            merge(record, (; message = stamped))
        end
    end

    # Build a flushing, timestamping Info-level console logger that writes to `io`.
    function make_sink(io)
        console = LoggingExtras.ConsoleLogger(
            io,
            Logging.Info;
            meta_formatter = logging_meta_formatter,
        )
        return with_timestamp(FlushLogger(console))
    end

    outdir = mkpath(outdir)

    # Pick the log file path, skipping names that already exist unless overwriting.
    logpath = "$(outdir)/log"
    if !overwrite
        attempt = 0
        while ispath(logpath)
            attempt += 1
            logpath = "$(outdir)/log.$attempt"
        end
    end

    stderr_sink = make_sink(stderr)
    file_sink = make_sink(open(logpath, write = true))
    Logging.global_logger(LoggingExtras.TeeLogger(stderr_sink, file_sink))
end;