In [8]:
using Plots, StatsPlots
include("src/SFGP_classification-learnable.jl")
using Flux
using StatsBase
using CSV, DataFrames
using ScikitLearn
using KnetMetrics
using Tables

In [9]:
# Read the header-less CSV into a DataFrame (columns get auto-generated names).
df = DataFrame(CSV.File("./data/cancer.csv"; header=false))

# Column 2 holds the class label. One-hot encode it against ["M", "B"] and keep
# only the first row, i.e. a 1×N Float64 indicator that is 1.0 for "M".
y = Float64.(Flux.onehotbatch(df[!, 2], ["M", "B"])[1:1, :])

# Columns 3:end are the features; store them as features × samples
# (one sample per column).
X = permutedims(Matrix(df[!, 3:end]))

30×569 Matrix{Float64}:
   17.99        20.57        19.69      …    20.6         7.76
   10.38        17.77        21.25           29.33       24.54
  122.8        132.9        130.0           140.1        47.92
 1001.0       1326.0       1203.0          1265.0       181.0
    0.1184       0.08474      0.1096          0.1178      0.05263
    0.2776       0.07864      0.1599    …     0.277       0.04362
    0.3001       0.0869       0.1974          0.3514      0.0
    0.1471       0.07017      0.1279          0.152       0.0
    0.2419       0.1812       0.2069          0.2397      0.1587
    0.07871      0.05667      0.05999         0.07016     0.05884
    1.095        0.5435       0.7456    …     0.726       0.3857
    0.9053       0.7339       0.7869          1.595       1.428
    8.589        3.398        4.585           5.772       2.548
    ⋮                                   ⋱               
    0.03003      0.01389      0.0225          0.02324     0.02676
    0.006193     0.003

In [10]:
"""
    getLoglike(mm::SFGP, X, y)

Return the mean log-density of the labels `y` under the model's predictive
distribution at the inputs `X`.

`mm.gp(X)` is evaluated to get a mean and a covariance; the square roots of the
covariance diagonal are used as per-point scales. `getProbs` is broadcast over
these together with `mm.lower` and `mm.upper` to obtain mixture weights, and each
slice of the weights along dimension 2 parameterises a mixture of Bernoulli
components whose success probabilities are `σ.(mm.vals)`.

`y` is flattened before scoring, so it needs one entry per slice.
"""
function getLoglike(mm::SFGP, X, y)
    mu, covar = mm.gp(X)
    # Scales as a single-row matrix so they broadcast against `mu`
    # (presumably 1×N as well — TODO confirm against `mm.gp`).
    sd = permutedims(sqrt.(diag(covar)))

    weights = getProbs.(mu, sd, mm.lower, mm.upper)
    components = Bernoulli.(σ.(mm.vals))

    # One mixture per slice of `weights` along dimension 2.
    mixtures = map(w -> MixtureModel(components, w), Flux.unstack(weights, 2))

    return mean(logpdf.(mixtures, vec(y)))
end


"""
    getF1score(mm::SFGP, X, y)

Return the F1 score for label `1` of the model's hard predictions at `X`,
measured against the labels `y`.

The predictive distribution is built as a mixture of Bernoulli components
(success probabilities `σ.(mm.vals)`) weighted by the output of `getProbs`; the
mean of each mixture is rounded to obtain a 0/1 prediction. The score is
computed with KnetMetrics from a confusion matrix over the labels `[0, 1]`.
"""
function getF1score(mm::SFGP, X, y)
    mu, covar = mm.gp(X)
    # Scales as a single-row matrix so they broadcast against `mu`
    # (presumably 1×N as well — TODO confirm against `mm.gp`).
    sd = permutedims(sqrt.(diag(covar)))

    weights = getProbs.(mu, sd, mm.lower, mm.upper)
    components = Bernoulli.(σ.(mm.vals))

    # Round each mixture mean (the predicted probability of label 1) to 0 or 1.
    predicted = [Int(round(mean(MixtureModel(components, w))))
                 for w in Flux.unstack(weights, 2)]
    observed = Int.(vec(y))

    cm = KnetMetrics.confusion_matrix(observed, predicted, labels=[0,1])
    return KnetMetrics.f1_score(cm, class_name=1)
end

getF1score (generic function with 1 method)

In [11]:
import Random
Random.seed!(321)

# 10-fold split over the sample (column) indices of X.
folds = ScikitLearn.CrossValidation.KFold(size(X, 2), n_folds=10)

# Held-out metrics, one entry per fold. Concretely typed so later reductions
# (mean/std/hcat) work on Float64 data instead of Vector{Any}.
lls = Float64[]
f1s = Float64[]

# Full-data feature statistics. They do not depend on the fold, so they are
# computed once; they are only used to scale the array handed to SVGP below.
# NOTE(review): that array is built from the first 10 columns of the *full* X
# with full-data statistics, so it can contain samples that are held out in
# some fold. Consider building it from the standardized training split
# instead; left unchanged here so previously recorded results stay comparable.
Xm = mean(X, dims=2)
Xs = std(X, dims=2)

# `enumerate` supplies the fold number. The previous hand-maintained global
# counter only worked under interactive soft scope and was shadowed by the
# inner training loop's index.
for (fold, (train, test)) in enumerate(folds)

    # Fresh model and optimiser for every fold. A new input array is built each
    # time so that no state can be shared between folds. The meaning of the
    # arguments 20 and 3 is defined by SFGP in the included source file.
    sfgp = SFGP(SVGP((X[:, 1:10] .- Xm) ./ Xs), 20, 3)

    params = Flux.params(sfgp)
    opt = ADAM(0.05)

    # Standardize with statistics of the training split only and apply the
    # same transformation to the held-out split.
    Xtrain = X[:, train]
    mean_train = mean(Xtrain, dims=2)
    std_train = std(Xtrain, dims=2)
    Xtrain = (Xtrain .- mean_train) ./ std_train
    Xtest = (X[:, test] .- mean_train) ./ std_train

    ytrain = y[:, train]
    ytest = y[:, test]

    # 350 optimiser steps on `sample_elbo` evaluated on the training split.
    for _ in 1:350
        grads = Zygote.gradient(() -> sample_elbo(sfgp, Xtrain, ytrain), params)
        Flux.Optimise.update!(opt, params, grads)
    end

    push!(lls, getLoglike(sfgp, Xtest, ytest))
    push!(f1s, getF1score(sfgp, Xtest, ytest))

    # Progress indicator: number of folds finished so far.
    println(fold)
end

1
2
3
4
5
6
7
8
9
10


└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:209
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:212
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:209
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:212
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:209
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:212
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix.jl:209
└ @ KnetMetrics.Classification /Users/saremseitz/.julia/packages/KnetMetrics/9L9oI/src/classification/confusion_matrix

In [12]:
# Mean, then standard deviation, of the held-out log-likelihood across folds.
foreach(stat -> println(stat(lls)), (mean, std))

-0.08584886978665499
0.03302769689439013


In [13]:
# Mean, then standard deviation, of the held-out F1 score across folds.
foreach(stat -> println(stat(f1s)), (mean, std))

0.9618343250951945
0.019822076250138033


In [14]:
# Collect the per-fold metrics into a two-column table (one row per fold).
df = DataFrame(hcat(lls, f1s), [:loglike, :f1score])

# Make sure the output directory exists before writing; writing into a missing
# directory would otherwise fail. `mkpath` is a no-op when it already exists.
outpath = "./evals/sfgp_learnable_classification_cancer.csv"
mkpath(dirname(outpath))
CSV.write(outpath, df)

"./evals/sfgp_learnable_classification_cancer.csv"