RitSpls Example
================

The `RitSpls` package is designed to be consistent with `ScikitLearn.jl`, so widely used ScikitLearn functions (such as `fit!`, `fit_transform!`, `predict`, and `GridSearchCV`) can be applied to its `SPLS` and `GSSPP` objects.


Instantiate the project environment, load the `RitSpls` package, and check that the key functions are defined:

In [2]:
using Pkg
Pkg.instantiate()
using RitSpls
print(isdefined(RitSpls, :SPLS))
print(isdefined(RitSpls, :GSSPP))
print(isdefined(RitSpls, :wrap))

truetruetrue

1) Set up environment and load data
---------------------------------

Load dependencies

In [4]:
using CSV
using DataFrames
using ScikitLearn
using ScikitLearnBase
import ScikitLearn.GridSearch:GridSearchCV

Load data

In [5]:
# main_path = "YOUR_PATH"
# cd(main_path)
Xf = CSV.read("../data/Xfearncal.csv", DataFrame, header=0)
yf = CSV.read("../data/Yfearncal.csv", DataFrame, header=0)

Row,Column1 (Float64)
1,9.23
2,8.01
3,10.95
4,11.67
5,10.41
6,9.51
7,8.67
8,7.75
9,8.05
10,11.39


2) Robustness-inducing transformations
--------------------------------------

Generalized spatial sign (GSS) transformation of the predictors:

In [6]:
# GSS pre-processing transformation
gsspp_X = GSSPP()
Xpp = ScikitLearn.fit_transform!(gsspp_X, Xf)

# compare with original (centered) predictors
loc = kstepLTS(Matrix(Xf))
Xcentered = autoscale(Matrix(Xf), loc, "none").X_as_
Xpp .- Xcentered

24×6 Matrix{Float64}:
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
    0.0        0.0         0.0         0.0        0.0       0.0
  -18.8805   -16.3169    -18.0145    -18.0491   -29.8971   -9.80402
  -27.7658   -24.4357    -26.1008    -26.6858   -42.6163  -16.5155
  -67.3599   -59.3425    -62.5626    -60.2625   -75.6403  -34.37
   -7.9494
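
The pattern in the difference matrix above (zero for many rows, nonzero for others) is what a radial shrinkage of distant observations would produce. The following is a minimal sketch of that idea using only the Julia standard library. It is not the `GSSPP` implementation: the helper name `winsor_spatial_sign`, the coordinatewise median as center, and the median of the row norms as cutoff are all assumptions made for illustration.

```julia
using LinearAlgebra
using Statistics

# Winsor-type radial shrinkage (illustrative, hypothetical helper):
# rows within the cutoff distance are left untouched, rows beyond it
# are pulled back onto the sphere of radius `cutoff`.
function winsor_spatial_sign(X::AbstractMatrix{<:Real})
    center = median(X, dims=1)             # assumed location estimate
    Z = X .- center                        # centered data
    r = [norm(row) for row in eachrow(Z)]  # distance of each row from the center
    cutoff = median(r)                     # assumed cutoff choice
    scale = [ri <= cutoff ? 1.0 : cutoff / ri for ri in r]
    return Z .* scale, Z
end

# toy data: four inlying rows and one far-away row
X = [0.0 0.0; 1.0 0.0; 0.0 1.0; -1.0 0.0; 100.0 100.0]
Zpp, Z = winsor_spatial_sign(X)
Zpp .- Z   # zero for the inlying rows, nonzero only for the shrunken row
```

In this toy example only the last row is moved, which mirrors how the comparison with the centered predictors above isolates the rows that the transformation actually changed.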

Wrapping transformation of the predictand:

In [8]:
# wrapping pre-processing transformation
ywrap = wrap(yf)
ypp = ywrap.wrapX[1]

# compare with original predictand
ypp .- yf

Row,Column1 (Float64)
1,0.0
2,0.0
3,0.0
4,0.0
5,0.0
6,0.0
7,0.0
8,0.0
9,0.0
10,0.0
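
A difference of zero means that none of the displayed values was altered by wrapping. To illustrate why, here is a hedged sketch of a wrapping-type transformation: values are standardized with a robust location and scale, passed through a function that is the identity near the center and redescends smoothly to zero in the tails, and then mapped back. The helper names `psi_wrap` and `wrap_sketch`, the median/MAD estimators, and the constants `b`, `c`, `q1`, `q2` are assumptions for illustration and may differ from what `wrap` in `RitSpls` does.

```julia
using Statistics

# Wrapping-type psi function (assumed constants): identity for |z| <= b,
# smooth redescent for b < |z| <= c, zero beyond c.
function psi_wrap(z::Real; b=1.5, c=4.0, q1=1.540793, q2=0.8622731)
    a = abs(z)
    a <= b && return float(z)
    a <= c && return q1 * tanh(q2 * (c - a)) * sign(z)
    return 0.0
end

# Standardize robustly, apply psi, and map back to the original scale.
function wrap_sketch(x::AbstractVector{<:Real})
    mu = median(x)                           # assumed location estimate
    sigma = 1.4826 * median(abs.(x .- mu))   # assumed scale estimate (normalized MAD)
    return mu .+ sigma .* psi_wrap.((x .- mu) ./ sigma)
end

# first ten predictand values shown earlier in this notebook
y = [9.23, 8.01, 10.95, 11.67, 10.41, 9.51, 8.67, 7.75, 8.05, 11.39]
ywrapped = wrap_sketch(y)
ywrapped .- y                       # (numerically) zero: no value lies in the tails

# an artificial extreme value is pulled back to the robust center
youtlier = wrap_sketch([y; 50.0])
```

Under these assumptions, all ten values fall in the identity region of the function, so they pass through unchanged up to floating-point rounding, while the artificial value 50.0 is replaced by the median.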


3) SPLS on pre-processed data using cross-validation
-------------------

Estimate an SPLS model on the robustly transformed data using cross-validation.

In [9]:
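cv_folds = 2
# candidate ranges (defined for reference; the grid below is fixed at a single point)
n_comp_range = collect(1:4)
eta_range = collect(0.9:-0.2:0.1)

ritSPLSreg = RitSpls.SPLS()
RitSpls.set_params_dict!(ritSPLSreg, Dict(:fit_algorithm => "snipls", :verbose => false))
# parameter grid first, keyword arguments last
gridsearch = GridSearchCV(ritSPLSreg, Dict(:eta => 0.5, :n_components => 3), cv=cv_folds)
# fit! returns the fitted search object, so a second fit! call is not needed
solfit = ScikitLearn.fit!(gridsearch, Xpp, ypp)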

println(solfit.best_params_)
predict(solfit,Xpp)

Dict{Symbol, Any}(:eta => 0.5, :n_components => 3)


24-element Vector{Float64}:
  9.409796648688022
  8.172529462341041
 11.01874562886803
 11.636839881788624
 10.01583122073436
  9.395874649112235
  9.264237922156454
  8.136281688624493
  7.990493469016813
 11.657795691197707
 10.074402080468216
  8.242866126810235
 10.140573184537436
 10.067942104121313
 10.088876525134218
  8.894444804307824
  9.829343856277692
  9.367213895065781
 10.306780996969687
 11.818308290060616
 10.16856458274426
 10.513834388178932
 11.088357502180969
 11.890065400615066
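
The cell above defines `n_comp_range` and `eta_range` but evaluates a single parameter combination. As a rough sketch of what an exhaustive search over those ranges would involve, the snippet below enumerates every candidate pair and counts the model fits for the 24 samples and 2 folds used here. The helper `kfold_indices` and its contiguous, unshuffled split are assumptions for illustration and are not taken from `ScikitLearn.jl`.

```julia
cv_folds = 2
n_comp_range = collect(1:4)
eta_range = collect(0.9:-0.2:0.1)

# every (eta, n_components) pair an exhaustive grid search would evaluate
grid = vec([(eta=eta, n_components=k) for eta in eta_range, k in n_comp_range])

# contiguous, unshuffled fold assignment for n samples (assumed splitting scheme)
function kfold_indices(n::Integer, k::Integer)
    bounds = round.(Int, range(0, n, length=k + 1))
    return [(bounds[i] + 1):bounds[i + 1] for i in 1:k]
end

folds = kfold_indices(24, cv_folds)   # held-out index range for each fold
n_fits = length(grid) * cv_folds      # one fit per candidate and fold
println("candidates: ", length(grid), ", folds: ", folds, ", model fits: ", n_fits)
```

Passing the two ranges as the values of the parameter dictionary, instead of the fixed values `0.5` and `3`, would make the search cover this full grid, at the cost of proportionally more fits.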