# Helper functions that are useful for all Julia notebooks

In [None]:
import Dates
import JupyterFormatter
import LinearAlgebra
import Logging
import LoggingExtras
import ProgressMeter

## Multi-threading

In [None]:
# `@tprogress`: show a ProgressMeter bar for a `for` loop, including one annotated
# with `Threads.@threads` (e.g. `@tprogress Threads.@threads for i in xs ... end`).
#
# The macro rewrites the loop as follows:
# - If `expr` is a `Threads.@threads` call, the `for` loop is taken from its last
#   argument. Only the spelling `Threads.@threads` is recognized.
# - The bar's total is `round(length(r) / Threads.nthreads())`, where `r` is the
#   iterable of the loop's iteration spec. Only iterations running on thread 1 call
#   `ProgressMeter.next!`, i.e. thread 1's share is used as a proxy for overall
#   progress. `ProgressMeter.finish!` is called after the loop completes.
# - The `next!` call is appended to the loop body by mutating `expr` in place.
#
# NOTE(review): `r` is evaluated twice (once for `length`, once by the loop itself).
# NOTE(review): with `Threads.@threads`' default dynamic scheduling (Julia >= 1.8) or
#   with interactive threads, thread 1 is not guaranteed to run its "share" of the
#   iterations, so the bar may stop short of or overshoot its total — confirm.
# NOTE(review): the unescaped `global $p` below is declared in the macro's own module,
#   while the escaped loop body (containing `next!($p)`) resolves `$p` and
#   `ProgressMeter` in the caller's module; these coincide when both are `Main`.
macro tprogress(expr)
    loop = expr
    # Unwrap `Threads.@threads [schedule] for ... end` to reach the `for` expression.
    if loop.head == :macrocall && loop.args[1] == :(Threads.var"@threads")
        loop = loop.args[end]
    end
    
    # Unique global name for the Progress object shared by setup, loop body and teardown.
    p = gensym()    
    # Iterable of `for x in r`: `loop.args[1]` is `:(x = r)` (assumes a single spec).
    r = loop.args[1].args[end]
    ex = quote
        # Approximate number of iterations handled by one thread.
        n = Int(round(length($(esc(r))) / Threads.nthreads()))
        global $p = ProgressMeter.Progress(n; showspeed=true)
        $(esc(expr))
        ProgressMeter.finish!($p)
    end
    
    # Only thread 1 advances the bar (see the NOTEs above).
    update = quote
        if Threads.threadid() == 1
            ProgressMeter.next!($p)
        end
    end
    # `loop.args[end]` is the loop body block. `expr` shares this object, so the
    # escaped `expr` spliced into `ex` above now contains `update` as well.
    push!(loop.args[end].args, update)    
    
    ex    
end;

In [None]:
# Like `Threads.@threads`, but with an explicit number of tasks.
#
# `args` is split into contiguous chunks of `cld(length(args), threads)` elements
# (so at most `threads` chunks). Each chunk is processed sequentially by its own
# `Threads.@spawn`ed task, and the call blocks until every task has finished
# (`@sync`). Only the first chunk displays a progress bar, so the bar approximates
# overall progress.
#
# Arguments
# - `f`: called once per element of `args`; its return values are discarded.
# - `args`: a collection with a `length`, consumable by `Iterators.partition`.
# - `threads`: number of tasks to spawn; must be positive.
#
# Returns `nothing`. An empty `args` is a no-op.
# Throws `ArgumentError` if `threads < 1`. Exceptions thrown by `f` are collected by
# `@sync` and rethrown as a `CompositeException`.
function tforeach(f::Function, args, threads::Int)
    threads > 0 || throw(ArgumentError("threads must be positive, got $threads"))
    # `Iterators.partition` rejects a chunk size of 0, which `cld` yields for empty `args`.
    chunksize = max(1, cld(length(args), threads))
    @sync for (t, chunk) in enumerate(Iterators.partition(args, chunksize))
        Threads.@spawn begin
            # Qualified because `import ProgressMeter` does not bring its macros into scope.
            ProgressMeter.@showprogress enabled = (t == 1) for i in chunk
                f(i)
            end
        end
    end
    return nothing
end;

In [None]:
# Keep BLAS single-threaded: parallelism should come from Julia threads instead,
# presumably to avoid oversubscribing cores when BLAS is called from several
# Julia threads at once.
1 |> LinearAlgebra.BLAS.set_num_threads;

## Formatting

In [None]:
# Turn on JupyterFormatter's auto-formatting (see `JupyterFormatter.enable_autoformat`).
JupyterFormatter.enable_autoformat();

## Logging

In [None]:
# Globally discard every message at Debug level or lower, before any logger sees it.
Logging.Debug |> Logging.disable_logging;

In [None]:
# Logger that flushes the wrapped `ConsoleLogger`'s stream after every log
# statement, so each message reaches its destination immediately (useful for
# buffered streams such as files). All filtering decisions are delegated to the
# wrapped logger.
struct FlushLogger <: Logging.AbstractLogger
    logger::Logging.ConsoleLogger
end

# Flushing needs the wrapped logger's `stream` field, which only `ConsoleLogger`
# provides; reject any other logger with a clear error instead of recursing.
function FlushLogger(logger::Logging.AbstractLogger)
    throw(
        ArgumentError("FlushLogger can only wrap a ConsoleLogger, got $(typeof(logger))"),
    )
end

# Forward the message to the wrapped logger, then flush its stream.
function Logging.handle_message(logger::FlushLogger, args...; kwargs...)
    Logging.handle_message(logger.logger, args...; kwargs...)
    flush(logger.logger.stream)
    return nothing
end

# Filtering and exception handling follow the wrapped logger.
Logging.shouldlog(logger::FlushLogger, arg...) = Logging.shouldlog(logger.logger, arg...)
Logging.min_enabled_level(logger::FlushLogger) = Logging.min_enabled_level(logger.logger)
Logging.catch_exceptions(logger::FlushLogger) = Logging.catch_exceptions(logger.logger)

# `meta_formatter` for `ConsoleLogger`: returns `(color, prefix, suffix)`.
#
# The color is an ANSI color index chosen by level band: 4 (blue) below Info,
# 6 (cyan) from Info up to Warn, 3 (yellow) from Warn up to Error, 1 (red) for
# Error and above. The prefix is the level name plus ':' ("Warning:" for Warn).
# The suffix is always empty, so no source location is appended; the remaining
# arguments are accepted only to match the formatter signature.
function logging_meta_formatter(level, _module, group, id, file, line)
    color = if level < Logging.Info
        4
    elseif level < Logging.Warn
        6
    elseif level < Logging.Error
        3
    else
        1
    end
    label = level == Logging.Warn ? "Warning" : string(level)
    return color, string(label, ':'), ""
end;

In [None]:
# Log to stderr and to a file at the same time.
#
# Every message is prefixed with the current local time (`yyyymmdd HH:MM:SS`),
# and both outputs are flushed after each message. Only messages at `Info` level
# or above are written to either destination.
#
# Arguments
# - `outdir`: directory for the log file; created (with parents) if missing.
# - `overwrite`: if `true` (default), the file `log` in `outdir` is truncated;
#   otherwise the first unused name among `log`, `log.1`, `log.2`, … is used.
#
# Installs the combined logger as the global logger and returns the previously
# installed global logger. The log file stays open for the lifetime of the logger.
function redirect_logging(outdir; overwrite=true)
    # Parse the timestamp format once instead of on every log message.
    date_format = Dates.DateFormat("yyyymmdd HH:MM:SS")

    # Prepend the current time to each message (the message becomes a String).
    timestamp_logger(logger) =
        LoggingExtras.TransformerLogger(logger) do log
            merge(
                log,
                (; message = "$(Dates.format(Dates.now(), date_format)) $(log.message)"),
            )
        end

    # Info-level console-style logger on `io`, flushed after each message and timestamped.
    flushed_logger(io) =
        FlushLogger(
            Logging.ConsoleLogger(io, Logging.Info; meta_formatter = logging_meta_formatter),
        ) |> timestamp_logger

    outdir = mkpath(outdir)
    logpath = joinpath(outdir, "log")
    if !overwrite
        tries = 0
        while ispath(logpath)
            tries += 1
            logpath = joinpath(outdir, "log.$tries")
        end
    end

    # `write = true` creates the file, or truncates it if it already exists.
    logfile = open(logpath; write = true)
    return Logging.global_logger(
        LoggingExtras.TeeLogger(flushed_logger(stderr), flushed_logger(logfile)),
    )
end;