In [16]:
using Plots, StatsPlots
include("src/SFGP_classification-learnable.jl")
using Flux
using StatsBase
using CSV, DataFrames
using ScikitLearn
using KnetMetrics
using Tables

In [17]:
df = Matrix(CSV.File("./data/banana.csv") |> DataFrame)[:,1:3]
X = transpose(df)[1:2,:]
y = (transpose(df)[3:3,:] .+ 1) ./2

1×5300 Matrix{Float64}:
 0.0  1.0  0.0  1.0  1.0  1.0  0.0  1.0  …  1.0  1.0  1.0  1.0  1.0  0.0  1.0

In [18]:
function getLoglike(mm::SFGP,X,y)
    
    m,S = mm.gp(X)
    s = Matrix(transpose(sqrt.(diag(S))[:,:]))
    
    probs = getProbs.(m,s,mm.lower,mm.upper)
    vals = σ.(mm.vals)
    
    return mean(logpdf.([MixtureModel(Bernoulli.(vals),p) for p in Flux.unstack(probs,2)],y[:]))
end


function getF1score(mm::SFGP,X,y)
    
    m,S = mm.gp(X)
    s = Matrix(transpose(sqrt.(diag(S))[:,:]))
    
    probs = getProbs.(m,s,mm.lower,mm.upper)
    vals = σ.(mm.vals)
    
    predictions = round.(mean.([MixtureModel(Bernoulli.(vals),p) for p in Flux.unstack(probs,2)]))
    
    return KnetMetrics.f1_score(KnetMetrics.confusion_matrix(Int.(y[:]),Int.(predictions), labels=[0,1]),class_name=1)
    
end

getF1score (generic function with 1 method)

In [29]:
import Random
Random.seed!(321)

folds = ScikitLearn.CrossValidation.KFold(size(X,2),n_folds=10)

lls = []
f1s = []
i = 0

for (train, test) in folds

    Xm = mean(X,dims=2)
    Xs = std(X,dims=2)
    sfgp = SFGP(SVGP((X[:,1:10] .- Xm)./Xs),20,3)
    
    params = Flux.params(sfgp)
    opt = ADAM(0.05)
    
    Xtrain = X[:,train]
    
    mean_train = mean(Xtrain,dims=2)
    std_train = std(Xtrain,dims=2)
    Xtrain = (Xtrain.-mean_train) ./ std_train
    
    Xtest = X[:,test]
    Xtest = (Xtest.-mean_train) ./ std_train
    ytrain = y[:,train]
    ytest = y[:,test]
    
    for i in 1:350
        grads = Zygote.gradient(() -> sample_elbo(sfgp,Xtrain,ytrain),params)
        Flux.Optimise.update!(opt,params,grads)
    end
    
    push!(lls, getLoglike(sfgp,Xtest,ytest))
    push!(f1s, getF1score(sfgp,Xtest,ytest))
    
    i = i+1
    println(i)
end

1
2
3
4
5
6
7
8
9
10


In [30]:
println(mean(lls))
println(std(lls))

-0.2575942432497359
0.022267283405403897


In [31]:
println(mean(f1s))
println(std(f1s))

0.8758773297434018
0.01854844341478424


In [32]:
df = DataFrame(hcat(lls,f1s),[:loglike,:f1score])
CSV.write("./evals/sfgp_learnable_classification_banana.csv",  df)

"./evals/sfgp_learnable_classification_banana.csv"