In [1]:
from xgboost import XGBClassifier
import warnings
from tabpfn_new.scripts.transformer_prediction_interface import TabPFNClassifier
import numpy as np
import pandas as pd
import os
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split, cross_val_score, cross_validate
from data_prep_utils import *
from evaluate import *
from load_models import *
import matplotlib.pyplot as plt
import torch
import openml
import time

In [2]:
# Load the microbiome dataset, reduce it with `top_non_zero`, then shuffle
# samples and labels together (same permutation for both).
SEED = 42  # fixed seed so the shuffle and downstream stochastic steps are reproducible
np.random.seed(SEED)  # NOTE(review): assumes unison_shuffled_copies / undersample use numpy's global RNG — confirm
torch.manual_seed(SEED)  # covers torch-based models (TabPFN / TabForestPFN) that draw from torch's RNG

path = "datasets/data_all.csv"
data, labels = get_microbiome(path)
data = top_non_zero(data)
data, labels = unison_shuffled_copies(data, labels)

In [3]:
# Benchmark every model with cross-validation, once with no resampling
# (sampling=None) and once with `undersample`, and keep one sorted table per
# sampling strategy.
cv = 3                  # number of cross-validation folds
strat_split = True
n_optim = 1000          # passed to XGBoostOptim as n_optim
ft_epochs = 10          # passed to TabForestPFNClassifier as max_epochs
max_samples = None
metrics = ["accuracy", "precision", "recall", "roc_auc"]

all_results = {}  # sampling label -> results table sorted by roc_auc
for sampling in [None, undersample]:
    sampling_name = "none" if sampling is None else getattr(sampling, "__name__", repr(sampling))

    # Fresh model instances per sampling strategy so no fitted state carries over.
    models = [
        XGBClassifier(n_estimators=5, max_depth=5, learning_rate=1, objective='binary:logistic'),
        XGBoostOptim(n_optim=n_optim),
        LogisticRegression(max_iter=500),
        TabPFNClassifier(device='cpu', N_ensemble_configurations=5, no_preprocess_mode=True),
        TabForestPFNClassifier(
            "saved_models/tabforest/mix600k/tabforestpfn.pt",
            "saved_models/tabforest/mix600k/config_run.yaml",
            max_epochs=ft_epochs,
        ),
    ]
    # One row per model, one column per metric plus wall-clock runtime.
    results = pd.DataFrame(
        np.zeros((len(models), len(metrics) + 1)),
        index=[m.__class__.__name__ for m in models],
        columns=metrics + ["runtime"],
    )

    for ii, model in enumerate(models):
        results.iloc[ii, :] = cross_validate_sample(
            model, data, labels, metrics, strat_split, cv, sampling, max_samples
        )

    results_sorted = results.sort_values("roc_auc")
    all_results[sampling_name] = results_sorted  # keep every table, not just the last one
    print(f"sampling = {sampling_name}")
    display(results_sorted)

2024-10-07 11:53:44.652 | INFO     | tabularbench.core.trainer_finetune:train:83 - Epoch 000 | Train loss: -.---- | Train score: -.---- | Val loss: 0.1880 | Val score: 0.9398
2024-10-07 11:54:18.606 | INFO     | tabularbench.core.trainer_finetune:train:94 - Epoch 001 | Train loss: 0.2043 | Train score: 0.9336 | Val loss: 0.1828 | Val score: 0.9437
2024-10-07 11:54:51.808 | INFO     | tabularbench.core.trainer_finetune:train:94 - Epoch 002 | Train loss: 0.1721 | Train score: 0.9453 | Val loss: 0.1831 | Val score: 0.9378
2024-10-07 11:55:26.732 | INFO     | tabularbench.core.trainer_finetune:train:94 - Epoch 003 | Train loss: 0.1703 | Train score: 0.9414 | Val loss: 0.1818 | Val score: 0.9378
2024-10-07 11:55:59.872 | INFO     | tabularbench.core.trainer_finetun… [log output truncated]

                        accuracy  precision    recall   roc_auc     runtime
TabPFNClassifier        0.940997   0.933333  0.037196  0.518459  104.979090
LogisticRegression      0.935847   0.362993  0.068670  0.530431    0.302000
TabForestPFNClassifier  0.944663   0.761072  0.144492  0.570573  397.125838
XGBoostOptim            0.946146   0.695337  0.206009  0.600123  601.063999
XGBClassifier           0.932618   0.409771  0.228898  0.603620    0.238067


2024-10-07 12:21:21.066 | INFO     | tabularbench.core.trainer_finetune:train:83 - Epoch 000 | Train loss: -.---- | Train score: -.---- | Val loss: 0.4535 | Val score: 0.7647
2024-10-07 12:21:22.961 | INFO     | tabularbench.core.trainer_finetune:train:94 - Epoch 001 | Train loss: 0.5223 | Train score: 0.7383 | Val loss: 0.4558 | Val score: 0.7701
2024-10-07 12:21:24.833 | INFO     | tabularbench.core.trainer_finetune:train:94 - Epoch 002 | Train loss: 0.5067 | Train score: 0.7584 | Val loss: 0.4531 | Val score: 0.7701
2024-10-07 12:21:26.655 | INFO     | tabularbench.core.trainer_finetune:train:94 - Epoch 003 | Train loss: 0.5710 | Train score: 0.7181 | Val loss: 0.4504 | Val score: 0.7861
2024-10-07 12:21:28.518 | INFO     | tabularbench.core.trainer_finetun… [log output truncated]

                        accuracy  precision    recall   roc_auc     runtime
XGBClassifier           0.725757   0.146564  0.723891  0.724885    0.053333
LogisticRegression      0.745396   0.154930  0.711016  0.729323    0.106000
TabPFNClassifier        0.757790   0.170888  0.771102  0.764013    8.327999
TabForestPFNClassifier  0.773588   0.184911  0.793991  0.783127   24.987333
XGBoostOptim            0.774548   0.187351  0.805436  0.788989  152.246673


In [6]:
# `results_sorted` is reassigned on every iteration of the benchmark loop, so
# this shows only the table from the loop's final sampling strategy.
results_sorted

                        accuracy  precision    recall   roc_auc     runtime
XGBClassifier           0.728288   0.141617  0.682403  0.706837    0.055000
LogisticRegression      0.738850   0.145708  0.673820  0.708447    0.105000
TabPFNClassifier        0.768264   0.175135  0.753934  0.761565   10.005333
TabForestPFNClassifier  0.784237   0.186522  0.749642  0.768063   29.427220
XGBoostOptim            0.797853   0.201961  0.785408  0.792035  174.394001
