In [1]:
import random
import time
import warnings
from datetime import datetime

import torch

import numpy as np
import warnings
np.warnings = warnings
import matplotlib.pyplot as plt
from tabpfn_new.scripts.differentiable_pfn_evaluation import eval_model_range
from tabpfn_new.scripts.model_builder import get_model, get_default_spec, save_model, load_model
from tabpfn_new.scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, load_model_workflow

from tabpfn_new.scripts.model_configs import *

#from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids
from tabpfn_new.priors.utils import plot_prior, plot_features
from tabpfn_new.priors.utils import uniform_int_sampler_f

from tabpfn_new.scripts.tabular_metrics import calculate_score_per_method, calculate_score
from tabpfn_new.scripts.tabular_evaluation import evaluate

from tabpfn_new.priors.differentiable_prior import DifferentiableHyperparameterList, draw_random_style, merge_style_with_info
from tabpfn_new.scripts import tabular_metrics
from tabpfn.notebook_utils import *

In [2]:
# Runtime configuration: values a reader may want to tune before running.
device = 'cpu'      # device string passed to get_model
base_path = '.'
max_features = 100  # NOTE(review): config['max_features'] is later set from a literal 100, not from this name

# Seed every RNG imported by this notebook so prior sampling and training are
# reproducible under "Restart Kernel -> Run All".
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
def reload_config(prior_type='forest', config_type='causal', task_type='binary', longer=0):
    """Build a fresh training config plus a unique, timestamped model-name suffix.

    Args:
        prior_type: Stored verbatim as ``config['prior_type']``.
        config_type: Forwarded to ``get_prior_config`` to obtain the base config.
        task_type: Currently unused; the config is always set up as multiclass.
        longer: Currently unused.

    Returns:
        A ``(config, model_string)`` tuple. ``model_string`` is ``'_multiclass_'``
        followed by the current local time formatted as ``%m_%d_%Y_%H_%M_%S``.
    """
    base = get_prior_config(config_type=config_type)
    base['prior_type'] = prior_type

    # Training-schedule defaults (callers may overwrite 'epochs' afterwards).
    base['epochs'] = 12000
    base['recompute_attn'] = True

    # Multiclass setup. num_classes holds whatever uniform_int_sampler_f returns
    # (presumably a sampler over [2, max_num_classes] -- confirm in priors.utils).
    base['max_num_classes'] = 10
    base['num_classes'] = uniform_int_sampler_f(2, base['max_num_classes'])
    base['balanced'] = False

    stamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
    return base, f'_multiclass_{stamp}'

In [4]:
# Start from the multiclass base config and specialise it for a binary,
# forest-prior run that is evaluated on the microbiome test set.
# NOTE: reload_config stores this prior_type ('mlp') in config['prior_type'],
# but that entry is overwritten with "forest" near the end of this cell.
prior_type = 'mlp'
config, model_string = reload_config(prior_type, longer=1)

config['differentiable'] = True
config['flexible'] = True
config['bptt_extra_samples'] = None

# Pin these hyperparameters to constants and remove them from the
# differentiable (sampled) hyperparameter set.
config['output_multiclass_ordered_p'] = 0.0
del config['differentiable_hyperparameters']['output_multiclass_ordered_p']

config['sampling'] = 'normal'  # possibly a poor choice?
del config['differentiable_hyperparameters']['sampling']

config['pre_sample_causes'] = True

config['multiclass_loss_type'] = 'nono'  # alternative: 'compatible'
config['normalize_to_ranking'] = False

config['categorical_feature_p'] = 0  # diff: .0

# Missing-value injection is disabled; consider turning it back on in a random search.
config['nan_prob_no_reason'] = .0
config['nan_prob_unknown_reason'] = .0  # diff: .0
config['set_value_to_nan'] = .0  # diff: 1.

config['normalize_with_sqrt'] = False

config['new_mlp_per_example'] = True
config['prior_mlp_scale_weights_sqrt'] = True
config['batch_size_per_gp_sample'] = None

config['normalize_ignore_label_too'] = False

config['differentiable_hps_as_style'] = False
config['max_eval_pos'] = 1000

config['random_feature_rotation'] = True
config['rotate_normalized_labels'] = True

config["mix_activations"] = True  # False actually means True

# Transformer size
config['emsize'] = 512
config['nhead'] = config['emsize'] // 128
config['bptt'] = 1024+128
config['canonical_y_encoder'] = False

# Optimisation schedule: batch_size = 8 * 4 = 32, num_steps = 32 // 4 = 8.
config['aggregate_k_gradients'] = 4
config['batch_size'] = 8*config['aggregate_k_gradients']
config['num_steps'] = 32//config['aggregate_k_gradients']
config['epochs'] = 50  # overrides the 12000 set by reload_config

config['total_available_time_in_s'] = None  # 60*60*22 # 22 hours for some safety...

config['train_mixed_precision'] = False
config['efficient_eval_masking'] = True

# MLP prior params (pre_sample_causes is set to True above)
config['is_causal'] = True
config['num_causes'] = 5
config['prior_mlp_hidden_dim'] = 50
config['num_layers'] = 4
config['noise_std'] = 0.05
config['init_std'] = 0.05
config['y_is_effect'] = True
config['pre_sample_weights'] = True
config['prior_mlp_dropout_prob'] = 0
config["prior_mlp_activations"] = torch.nn.ReLU
config["block_wise_dropout"] = True
config["sort_features"] = False
config["in_clique"] = False

# General data params: a binary, unbalanced task. These replace the multiclass
# defaults (max_num_classes=10, sampled num_classes) written by reload_config.
# Each key is assigned exactly once here so the effective value is unambiguous.
config['balanced'] = False
config['max_num_classes'] = 2
config['num_classes'] = 2
config['max_features'] = 100
config['num_features_used'] = 100

# Forest prior params (max_features is set in the data params above)
config["min_features"] = 100
config["n_samples"] = 1000
config["base_size"] = 1000
config["n_estimators"] = 1
config["min_depth"] = 1
config["max_depth"] = 25
config["categorical_x"] = False
config["data_sample_func"] = "mnd"
config["comp"] = False

config['multiclass_type'] = 'rank'
del config['differentiable_hyperparameters']['multiclass_type']

config["prior_type"] = "forest"
config["microbiome_test"] = True
config["run_name"] = "forest_same_sampling"

config_sample = evaluate_hypers(config)

In [5]:
# Build the model from the sampled config and train it (should_train=True).
# Progress lines and the per-epoch "Micriobiome TabPFN" evaluation table are
# printed below.
# NOTE(review): from epoch 6 onward (and already at epoch 4) the logged
# precision and recall are 0.0 and ROC AUC is 0.5, while accuracy stays near
# 0.94. The model appears to predict only the majority class. The repeated
# _warn_prf lines are presumably sklearn's undefined-metric warnings for that
# case.
model = get_model(
    config_sample,
    device,
    verbose=1,
    should_train=True,
)

Using style prior: True
Using cpu:0 device
Using a Transformer with 25.81 M parameters
                    accuracy  precision    recall   roc_auc
Micriobiome TabPFN  0.562081   0.067538  0.481439  0.524383
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 350.20s | mean loss  0.70 |  data time  0.69 step time  9.69 forward time  3.12 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------
                    accuracy  precision    recall   roc_auc
Micriobiome TabPFN  0.740686   0.065551  0.256134  0.513625
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 270.89s | mean loss  0.65 |  data time  0.53 step time  4.93 forward time  1.65 nan share  0.00 ignore share (for classification tasks) 0.0000
---------------------------------------------------------------

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN   0.94093        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   4 | time: 240.46s | mean loss  0.53 |  data time  0.55 step time  5.63 forward time  2.30 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------
                    accuracy  precision    recall   roc_auc
Micriobiome TabPFN   0.93805    0.07906  0.004255  0.500828
-----------------------------------------------------------------------------------------
| end of epoch   5 | time: 241.67s | mean loss  0.48 |  data time  0.51 step time  8.38 forward time  3.57 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN   0.94032        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   6 | time: 264.73s | mean loss  0.45 |  data time  0.51 step time  8.06 forward time  3.25 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN    0.9384        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   7 | time: 244.56s | mean loss  0.46 |  data time  0.54 step time  8.10 forward time  2.98 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.940319        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   8 | time: 232.32s | mean loss  0.48 |  data time  0.51 step time  3.65 forward time  1.24 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.937265        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   9 | time: 229.65s | mean loss  0.45 |  data time  0.53 step time  6.55 forward time  2.70 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.935869        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  10 | time: 230.35s | mean loss  0.42 |  data time  0.48 step time  6.38 forward time  2.36 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.936741        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  11 | time: 248.55s | mean loss  0.50 |  data time  0.52 step time  5.74 forward time  2.35 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.936916        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  12 | time: 228.95s | mean loss  0.49 |  data time  0.48 step time  5.05 forward time  1.94 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.939011        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  13 | time: 239.18s | mean loss  0.45 |  data time  0.55 step time  8.76 forward time  3.30 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.934386        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  14 | time: 251.36s | mean loss  0.47 |  data time  0.55 step time  4.55 forward time  1.68 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938138        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  15 | time: 255.28s | mean loss  0.46 |  data time  0.51 step time  8.11 forward time  3.11 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.936742        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  16 | time: 318.10s | mean loss  0.53 |  data time  0.52 step time  9.35 forward time  3.43 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938923        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  17 | time: 295.12s | mean loss  0.48 |  data time  0.77 step time 11.49 forward time  4.53 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938312        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  18 | time: 276.45s | mean loss  0.48 |  data time  0.54 step time  5.52 forward time  2.15 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938225        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  19 | time: 261.53s | mean loss  0.47 |  data time  0.53 step time  8.27 forward time  3.15 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938836        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  20 | time: 291.00s | mean loss  0.47 |  data time  0.55 step time  4.04 forward time  1.54 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.939011        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  21 | time: 272.85s | mean loss  0.48 |  data time  0.51 step time  4.93 forward time  1.77 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.940843        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  22 | time: 260.52s | mean loss  0.48 |  data time  0.62 step time  7.80 forward time  3.10 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.940145        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  23 | time: 246.67s | mean loss  0.46 |  data time  0.49 step time  4.25 forward time  1.66 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.936567        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  24 | time: 232.71s | mean loss  0.47 |  data time  0.49 step time  5.38 forward time  2.10 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938661        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  25 | time: 252.86s | mean loss  0.45 |  data time  0.54 step time  6.69 forward time  2.54 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.937963        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  26 | time: 266.45s | mean loss  0.47 |  data time  0.53 step time  9.72 forward time  3.80 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.936916        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  27 | time: 251.45s | mean loss  0.48 |  data time  0.53 step time  7.20 forward time  2.66 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.941715        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  28 | time: 250.65s | mean loss  0.41 |  data time  0.55 step time  4.96 forward time  1.88 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.939447        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  29 | time: 238.28s | mean loss  0.44 |  data time  0.55 step time  5.13 forward time  2.02 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.939272        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  30 | time: 250.81s | mean loss  0.51 |  data time  0.54 step time  7.39 forward time  2.86 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938836        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  31 | time: 258.76s | mean loss  0.50 |  data time  0.52 step time  8.51 forward time  3.16 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.939272        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  32 | time: 258.65s | mean loss  0.44 |  data time  0.49 step time  4.68 forward time  1.61 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.939534        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  33 | time: 252.09s | mean loss  0.45 |  data time  0.53 step time  7.61 forward time  2.91 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938313        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  34 | time: 260.30s | mean loss  0.47 |  data time  0.52 step time 10.47 forward time  4.02 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.941629        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  35 | time: 253.24s | mean loss  0.46 |  data time  0.55 step time  6.88 forward time  2.77 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.941454        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  36 | time: 248.14s | mean loss  0.48 |  data time  0.51 step time  8.77 forward time  3.59 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.942413        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  37 | time: 255.88s | mean loss  0.44 |  data time  0.48 step time  4.40 forward time  1.59 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938749        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  38 | time: 244.77s | mean loss  0.49 |  data time  0.49 step time  4.79 forward time  1.80 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.936916        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  39 | time: 235.71s | mean loss  0.49 |  data time  0.52 step time  7.13 forward time  2.62 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN   0.93805        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  40 | time: 251.66s | mean loss  0.45 |  data time  0.56 step time  5.14 forward time  2.10 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.936567        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  41 | time: 253.17s | mean loss  0.47 |  data time  0.50 step time  6.89 forward time  2.86 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.937266        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  42 | time: 265.99s | mean loss  0.41 |  data time  0.53 step time 10.44 forward time  4.27 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.937964        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  43 | time: 230.19s | mean loss  0.47 |  data time  0.53 step time  6.44 forward time  2.40 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.935346        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  44 | time: 247.71s | mean loss  0.48 |  data time  0.49 step time  7.83 forward time  3.00 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.940669        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  45 | time: 242.97s | mean loss  0.43 |  data time  0.55 step time  4.27 forward time  1.34 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN   0.93648        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  46 | time: 252.36s | mean loss  0.44 |  data time  0.52 step time  9.19 forward time  3.45 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.937876        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  47 | time: 246.24s | mean loss  0.49 |  data time  0.51 step time  7.48 forward time  3.10 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.939272        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  48 | time: 252.70s | mean loss  0.47 |  data time  0.49 step time  7.64 forward time  2.97 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN   0.93552        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  49 | time: 262.69s | mean loss  0.49 |  data time  0.58 step time  9.22 forward time  4.06 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.942501        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  50 | time: 329.88s | mean loss  0.45 |  data time  0.60 step time  4.72 forward time  1.70 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


In [None]:
# Visualize samples drawn from the prior.
# Builds a separate, non-training model only to reach its prior data loader, then plots
# a few features of one sampled dataset and the |correlation| between features and target.
#
# Why new names instead of reusing `config_sample` / `model`: both are shared with the
# training cells above. Mutating `config_sample` would leak batch_size=4 into a re-run of
# training, and rebinding `model` would replace the trained model with an untrained one.
PRIOR_BATCH_SIZE = 4   # datasets per prior batch; only the first one is plotted
N_PLOT_POINTS = 100    # samples shown in the pairwise feature plot
N_PLOT_FEATURES = 4    # leading features shown in the pairwise feature plot

prior_config = {**config_sample, 'batch_size': PRIOR_BATCH_SIZE}  # shallow copy; original untouched
prior_model = get_model(prior_config, device, should_train=False, verbose=2)  # , state_dict=model[2].state_dict()
# Element 3 of get_model's result is iterated as the prior data loader (as in the original cell).
(_, prior_data, _), prior_targets, _ = next(iter(prior_model[3]))


def _to_numpy(x):
    """Return `x` as a NumPy array, detaching and moving torch tensors to CPU first."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


fig = plt.figure(figsize=(8, 8))
plot_features(prior_data[:N_PLOT_POINTS, 0, :N_PLOT_FEATURES], prior_targets[:N_PLOT_POINTS, 0], fig=fig)

# Stack features of the first dataset in the batch (one row per feature) with its target as the last row.
# assumes prior_data is (n_samples, batch, n_features) and prior_targets is (n_samples, batch) — TODO confirm
first_features = _to_numpy(prior_data[:, 0, :])
first_target = _to_numpy(prior_targets[:, 0])
feature_target_rows = np.vstack([first_features.T, first_target[np.newaxis, :]])
feature_target_rows[np.isnan(feature_target_rows)] = 0

# Zero-variance rows make corrcoef divide by 0 and yield NaN (drawn as blank cells);
# silence the expected RuntimeWarning rather than flooding the output.
with np.errstate(invalid='ignore', divide='ignore'):
    feature_target_corr = np.corrcoef(feature_target_rows)

fig_corr, ax_corr = plt.subplots(figsize=(8, 8))
corr_image = ax_corr.matshow(np.abs(feature_target_corr), vmin=0, vmax=1)
fig_corr.colorbar(corr_image, ax=ax_corr, label='|Pearson r|')
ax_corr.set_title('Feature/target correlation of one prior dataset (last row/col = target)')
plt.show()