In [1]:
using Plots, StatsPlots
include("src/SFGP_classification.jl")
using Flux
using StatsBase
using CSV, DataFrames
using ScikitLearn
using KnetMetrics
using Tables

In [2]:
df = CSV.File("./data/diabetes.csv",header = true) |> DataFrame
y = df[!,end]
y = Float64.(Flux.onehotbatch(y,["Positive", "Negative"])[1:1,:])

age = Matrix(transpose(Float64.(df[!,1][:,:])))
gender = Matrix((Float64.(Flux.onehotbatch(df[!,2][:,:], ["Male", "Female"])[1:1,:])))
rest = vcat([Matrix((Float64.(Flux.onehotbatch(df[!,c][:,:], ["Yes", "No"])[1:1,:]))) for c in 3:(size(df,2)-1)]...)

X = vcat(age,gender,rest)

16×520 Matrix{Float64}:
 40.0  58.0  41.0  45.0  60.0  55.0  …  54.0  39.0  48.0  58.0  32.0  42.0
  1.0   1.0   1.0   1.0   1.0   1.0      0.0   0.0   0.0   0.0   0.0   1.0
  0.0   0.0   1.0   0.0   1.0   1.0      1.0   1.0   1.0   1.0   0.0   0.0
  1.0   0.0   0.0   0.0   1.0   1.0      1.0   1.0   1.0   1.0   0.0   0.0
  0.0   0.0   0.0   1.0   1.0   0.0      1.0   1.0   1.0   1.0   0.0   0.0
  1.0   1.0   1.0   1.0   1.0   1.0  …   1.0   0.0   1.0   1.0   1.0   0.0
  0.0   0.0   1.0   1.0   1.0   1.0      1.0   1.0   1.0   1.0   0.0   0.0
  0.0   0.0   0.0   1.0   0.0   0.0      0.0   0.0   0.0   0.0   0.0   0.0
  0.0   1.0   0.0   0.0   1.0   1.0      0.0   0.0   0.0   1.0   1.0   0.0
  1.0   0.0   1.0   1.0   1.0   1.0      0.0   1.0   1.0   0.0   1.0   0.0
  0.0   0.0   0.0   0.0   1.0   0.0  …   0.0   0.0   1.0   0.0   0.0   0.0
  1.0   0.0   1.0   1.0   1.0   1.0      0.0   1.0   1.0   0.0   1.0   0.0
  0.0   1.0   0.0   0.0   1.0   0.0      1.0   1.0   1.0   1.0   0.0   0.0
 

In [3]:
function getLoglike(mm::SFGP,X,y)
    
    m,S = mm.gp(X)
    s = Matrix(transpose(sqrt.(diag(S))[:,:]))
    
    probs = getProbs.(m,s,mm.lower,mm.upper)
    vals = (mm.vals)
    
    return mean(logpdf.([MixtureModel(Bernoulli.(vals),p) for p in Flux.unstack(probs,2)],y[:]))
end


function getF1score(mm::SFGP,X,y)
    
    m,S = mm.gp(X)
    s = Matrix(transpose(sqrt.(diag(S))[:,:]))
    
    probs = getProbs.(m,s,mm.lower,mm.upper)
    vals = (mm.vals)
    
    predictions = round.(mean.([MixtureModel(Bernoulli.(vals),p) for p in Flux.unstack(probs,2)]))
    
    return KnetMetrics.f1_score(KnetMetrics.confusion_matrix(Int.(y[:]),Int.(predictions), labels=[0,1]),class_name=1)
    
end

getF1score (generic function with 1 method)

In [4]:
import Random
Random.seed!(321)

folds = ScikitLearn.CrossValidation.KFold(size(X,2),n_folds=10)

lls = []
f1s = []
i = 0

for (train, test) in folds
    
    Xm = mean(X,dims=2)
    Xs = std(X,dims=2)
    sfgp = SFGP(SVGP((X[:,1:10] .- Xm)./Xs),100,5)
    
    params = Flux.params(sfgp)
    opt = ADAM(0.05)
    
    Xtrain = X[:,train]
    
    mean_train = mean(Xtrain,dims=2)
    std_train = std(Xtrain,dims=2)
    Xtrain = (Xtrain.-mean_train) ./ std_train
    
    Xtest = X[:,test]
    Xtest = (Xtest.-mean_train) ./ std_train
    ytrain = y[:,train]
    ytest = y[:,test]
    
    for i in 1:350
        grads = Zygote.gradient(() -> sample_elbo(sfgp,Xtrain,ytrain),params)
        Flux.Optimise.update!(opt,params,grads)
    end
    
    push!(lls, getLoglike(sfgp,Xtest,ytest))
    push!(f1s, getF1score(sfgp,Xtest,ytest))
    
    i = i+1
    println(i)
end

1
2
3
4
5
6
7
8
9
10


└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:203
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:206
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:209
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:212
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:203
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:206
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:209
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix

In [5]:
println(mean(lls))
println(std(lls))

-0.24063003573982839
0.09592383849016625


In [6]:
println(mean(f1s))
println(std(f1s))

0.9123873024804834
0.07368496468281466


In [7]:
df = DataFrame(hcat(lls,f1s),[:loglike,:f1score])
CSV.write("./evals/sfgp_classification_diabetes.csv",  df)

"./evals/sfgp_classification_diabetes.csv"