# TreeModelBase
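* Uses LightGBM to fit a regression tree model on the ratings of a set of feature alphas, and writes the model's predictions out as a new alpha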

In [1]:
using LightGBM

import SparseArrays: sparse
import Statistics: mean
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb");
@nbinclude("EnsembleInputs.ipynb");

┌ Info: lib_lightgbm found in system dirs!
└ @ LightGBM /Users/kundan/.julia/packages/LightGBM/A7zVd/src/LightGBM.jl:28


## Training

In [2]:
"""
    augment_dataset(ds, y, w)

Set the `"label"` field of the LightGBM dataset `ds` to `y` and its `"weight"` field
to `w`, then return `ds`.
"""
function augment_dataset(ds, y, w)
    # Label is written before weight, one LGBM_DatasetSetField call per field.
    for (field, values) in (("label", y), ("weight", w))
        LightGBM.LGBM_DatasetSetField(ds, field, values)
    end
    return ds
end

"""
    create_train_dataset(X, y, w, estimator)

Create a LightGBM dataset from the feature matrix `X`, configured with the stringified
parameters of `estimator`, and attach the labels `y` and weights `w` to it.

Returns the dataset produced by `augment_dataset`.
"""
function create_train_dataset(X, y, w, estimator)
    params = LightGBM.stringifyparams(estimator)
    # NOTE(review): the trailing `false` is presumably LightGBM.jl's `is_row_major`
    # flag (i.e. X is column-major) — confirm against the LightGBM.jl wrapper.
    dataset = LightGBM.LGBM_DatasetCreateFromMat(X, params, false)
    return augment_dataset(dataset, y, w)
end

"""
    create_test_dataset(X, y, w, estimator, train_ds)

Create a LightGBM dataset from the feature matrix `X`, passing `train_ds` to LightGBM
alongside the stringified parameters of `estimator`, and attach the labels `y` and
weights `w` to it.

Returns the dataset produced by `augment_dataset`.
"""
function create_test_dataset(X, y, w, estimator, train_ds)
    params = LightGBM.stringifyparams(estimator)
    # NOTE(review): `train_ds` is presumably used as LightGBM's reference dataset and the
    # trailing `false` as the `is_row_major` flag — confirm against the LightGBM.jl wrapper.
    dataset = LightGBM.LGBM_DatasetCreateFromMat(X, params, train_ds, false)
    return augment_dataset(dataset, y, w)
end;

In [3]:
"""
    get_features(alphas, split, implicit)

Read every alpha in `alphas` for the given `split`/`implicit` setting and horizontally
concatenate their `rating` fields, one alpha per column block.
"""
function get_features(alphas, split, implicit)
    columns = [read_alpha(alpha, split, implicit).rating for alpha in alphas]
    return reduce(hcat, columns)
end

"""
    get_data(split::String, feature_alphas, target_alphas, implicit, error_model)

Assemble features, regression targets and sample weights for one `split` and partition
them into a training part and a held-out part.

# Arguments
- `split`: name of the split to load.
- `feature_alphas`: alphas whose `rating` fields become the feature columns
  (see `get_features`).
- `target_alphas`: alpha whose `rating` field is subtracted from the target ratings.
- `implicit`: when `true`, the split ratings are multiplied element-wise by the
  `"inverse"` weights before the alpha's ratings are subtracted.
- `error_model`: when `true`, the targets are the absolute values of the residuals.

# Returns
The tuple `(X_train, X_test, y_train, y_test, w_train, w_test)`. Rows whose user id is
`<= num_users() * 0.9` go to the training part; all other rows go to the test part.

# Throws
- `ArgumentError` if `implicit` is `true` and `error_model` is `false`; that
  combination is not supported.
"""
function get_data(split::String, feature_alphas, target_alphas, implicit, error_model)
    # Reject the unsupported combination up front with a real exception: `@assert` may
    # be compiled out, and failing here avoids loading any data first.
    if implicit && !error_model
        throw(
            ArgumentError(
                "implicit data is only supported when error_model is true",
            ),
        )
    end

    X = get_features(feature_alphas, split, implicit)
    # Load the split and its weights once; both are needed for the targets, the weights
    # and the train/test mask.
    ratings = get_split(split, implicit)
    w = get_weights(split, implicit, "inverse")

    observed = implicit ? ratings.rating .* w : ratings.rating
    y = observed - read_alpha(target_alphas, split, implicit).rating
    if error_model
        y = abs.(y)
    end

    training_mask = ratings.user .<= num_users() * 0.9
    test_mask = .!training_mask
    X_train, X_test = X[training_mask, :], X[test_mask, :]
    y_train, y_test = y[training_mask], y[test_mask]
    w_train, w_test = w[training_mask], w[test_mask]
    return X_train, X_test, y_train, y_test, w_train, w_test
end

"""
    get_data(splits::Vector{String}, feature_alphas, target_alphas, implicit,
             error_model, estimator)

Load the data of every split in `splits` via the single-split `get_data` method, stack
the per-split pieces vertically (in the order of `splits`), and wrap the result in
LightGBM datasets configured from `estimator`.

# Returns
The tuple `(train_ds, test_ds)`; `test_ds` is created with `train_ds` passed to
`create_test_dataset`.

# Throws
- `ArgumentError` if `splits` is empty.
"""
function get_data(
    splits::Vector{String},
    feature_alphas,
    target_alphas,
    implicit,
    error_model,
    estimator,
)
    isempty(splits) &&
        throw(ArgumentError("splits must contain at least one split name"))

    # One (X_train, X_test, y_train, y_test, w_train, w_test) tuple per split.
    parts = map(splits) do split
        get_data(split, feature_alphas, target_alphas, implicit, error_model)
    end

    # Reducing over a Vector of arrays (instead of a generator) lets Base concatenate
    # in a single pass rather than pairwise re-copying the accumulated result.
    gather(i) = reduce(vcat, [part[i] for part in parts])
    X_train, X_test, y_train, y_test, w_train, w_test = ntuple(gather, 6)

    train_ds = create_train_dataset(X_train, y_train, w_train, estimator)
    test_ds = create_test_dataset(X_test, y_test, w_test, estimator, train_ds)
    return train_ds, test_ds
end;

In [4]:
"""
    train_model(feature_alphas, target_alphas, implicit, training_splits, outdir,
                error_model)

Fit a LightGBM regression model on `training_splits`, persist it, and write its
predictions over every split in `all_raw_splits` as an alpha under `outdir`.

# Arguments
- `feature_alphas`: alphas whose ratings are used as model features.
- `target_alphas`: alpha passed to `get_data` for building the targets and to
  `write_alpha` when the predictions are written.
- `implicit`: forwarded to the data-loading and alpha-writing helpers.
- `training_splits`: names of the splits used to build the train/test datasets.
- `outdir`: output location used for logging, saved parameters and the written alpha.
- `error_model`: forwarded to `get_data`.

Returns whatever `write_alpha` returns.
"""
function train_model(
    feature_alphas,
    target_alphas,
    implicit,
    training_splits::Vector{String},
    outdir,
    error_model,
)
    set_logging_outdir(outdir)

    # Model configuration (fixed hyperparameters).
    estimator = LGBMRegression(
        num_iterations = 100,
        learning_rate = 0.1,
        early_stopping_round = 10,
        feature_fraction = 0.8,
        bagging_fraction = 0.9,
        bagging_freq = 1,
        num_leaves = 1000,
    )

    # Build the LightGBM train/test datasets and fit the model on them.
    train_ds, test_ds = get_data(
        training_splits,
        feature_alphas,
        target_alphas,
        implicit,
        error_model,
        estimator,
    )
    fit!(estimator, train_ds, test_ds)

    # Persist the fitted model together with the alphas it was trained on.
    @info "Saving model... (this may take a while)"
    saved_params = Dict("model" => estimator, "alphas" => feature_alphas)
    write_params(saved_params, outdir)

    # Score every raw split. The ratings and the feature rows are built by iterating
    # `all_raw_splits` in the same order, so row i of one corresponds to row i of the other.
    per_split_ratings = [get_split(split, implicit) for split in all_raw_splits]
    splits = reduce(cat, per_split_ratings)
    per_split_features =
        [get_features(feature_alphas, split, implicit) for split in all_raw_splits]
    split_features = reduce(vcat, per_split_features)
    preds = vec(predict(estimator, split_features))

    # Store the predictions as a sparse matrix indexed by (user, item).
    sparse_preds = sparse(splits.user, splits.item, preds)
    @info "Average model value: $(mean(preds))"
    return write_alpha(sparse_preds, target_alphas, implicit, outdir; log_splits = false)
end;