# Neural Network Base Class
* This class contains infrastructure to train neural networks
* The following algorithms are implemented:
    * Baseline predictors
    * Item-based collaborative filtering
    * Matrix Factorization
* The following algorithms will be implemented:
    * Autoencoder

In [1]:
using Flux
using Random
using SparseArrays
import CUDA
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb");

In [2]:
# support both gpu and cpu training

"""
    device(x)

Move `x` to the compute device by delegating to Flux's `gpu`. This is the generic
fallback; sparse arrays get their own method. Presumably `gpu` leaves `x` on the CPU
when no usable GPU is available — confirm against the Flux version in use.
"""
device(x) = gpu(x)

"""
    device(x::AbstractSparseArray)

Return a dense `CUDA.CuArray` copy of the sparse array `x` when CUDA is functional.
Otherwise return `x` itself, unchanged and still sparse, on the CPU.

`x` is first moved with `gpu` and only then densified with `CUDA.CuArray`. The intent
(per the original comment) is an efficient conversion; presumably the densification
happens on the device rather than on the host — TODO confirm.
"""
function device(x::AbstractSparseArray)
    CUDA.functional() || return x
    on_device = gpu(x)
    return CUDA.CuArray(on_device)
end

# When no usable GPU exists, training runs on the CPU, so let BLAS use as many
# threads as Julia was started with.
# NOTE(review): `LinearAlgebra` is not imported in the cells visible here; presumably
# it comes from ../Alpha.ipynb — confirm.
CUDA.functional() || LinearAlgebra.BLAS.set_num_threads(Threads.nthreads());

## Models

In [3]:
@nbinclude("Models.ipynb");

## Data 

In [4]:
@nbinclude("Data.ipynb");

## Batching

In [5]:
@nbinclude("Batching.ipynb");

## Loss Functions

In [6]:
@nbinclude("Loss.ipynb");

## Training

In [7]:
@nbinclude("Training.ipynb");

## Hyperparameter Tuning

In [8]:
@nbinclude("Hyperparameters.ipynb");

## Retrain User Embeddings

In [9]:
@nbinclude("Retraining.ipynb");

## Write predictions

In [10]:
@nbinclude("Saving.ipynb");

## Putting it all together

In [12]:
"""
    train_alpha(hyp, outdir; tune_hyperparams = true, max_evals = 100,
                tuning_evals = tune_hyperparams ? 10 : 1, retrain_evals = 10)

Train a model described by the hyperparameters `hyp` and write its parameters and
alpha predictions to `outdir`.

Steps:
1. Send logs to `outdir` via `set_logging_outdir`.
2. If `tune_hyperparams`:
   - Run `optimize_hyperparams` on a copy of `hyp` whose `num_users` is
     `num_users()` scaled by `get_subsampling_factor(hyp.model)`.
   - Merge the result into `hyp` with `create_hyperparams`.
3. Train with `optimize_learning_rate`. This may also return updated hyperparameters.
4. If `should_retrain_user_embeddings(hyp.model)`, call `retrain_user_embeddings`.
5. Save the model, epochs, hyperparameters and losses with `write_params`, then call
   `write_alpha`.

# Arguments
- `hyp`: hyperparameter object. It must have a `model` field and a `num_users` field;
  the latter is rewritten with `@set` during tuning.
- `outdir`: output location, passed to the logging and writing helpers.

# Keywords
- `tune_hyperparams = true`: run the hyperparameter search before training.
- `max_evals = 100`: evaluation budget passed to `optimize_hyperparams`.
- `tuning_evals`: evaluation count passed to `optimize_learning_rate`. It defaults to
  `10` when tuning and `1` otherwise.
- `retrain_evals = 10`: evaluation count passed to `retrain_user_embeddings`.

# Returns
Whatever `write_alpha(hyp, m, outdir)` returns.
"""
function train_alpha(
    hyp,
    outdir;
    tune_hyperparams = true,
    max_evals = 100,
    tuning_evals = tune_hyperparams ? 10 : 1,
    retrain_evals = 10,
)
    set_logging_outdir(outdir)

    if tune_hyperparams
        @info "Optimizing hyperparameters..."
        # Hyperparameters are searched on a scaled-down user count.
        subsampling_factor = get_subsampling_factor(hyp.model)
        hyp_subset = @set hyp.num_users = round(Int, num_users() * subsampling_factor)
        λ = optimize_hyperparams(hyp_subset; max_evals = max_evals)
        hyp = create_hyperparams(hyp, λ)
    end

    @info "Training model..."
    m, epochs, validation_loss, hyp = optimize_learning_rate(hyp, tuning_evals)
    @info "Trained model loss: $validation_loss"

    if should_retrain_user_embeddings(hyp.model)
        @info "Retraining user embeddings..."
        # NOTE(review): `epochs` is overwritten here, so the saved "epochs" is the
        # retraining count whenever retraining runs. Confirm this is intended.
        m, epochs, retrain_loss, retrain_hyp =
            retrain_user_embeddings(hyp, m, retrain_evals)
        @info "Retrained user embeddings with loss: $retrain_loss"
    else
        retrain_hyp = nothing
        retrain_loss = nothing
    end

    @info "Writing alpha..."
    write_params(
        Dict(
            "m" => m,
            "epochs" => epochs,
            "hyp" => hyp,
            "retrain_hyp" => retrain_hyp,
            "validation_loss" => validation_loss,
            "retrain_loss" => retrain_loss,
        ),
        outdir,
    )
    # NOTE(review): predictions are written with `hyp`, not `retrain_hyp`, even when
    # the model was retrained. Confirm this is intended.
    return write_alpha(hyp, m, outdir)
end;