In [1]:
import random
import time
import warnings
from datetime import datetime

import torch
import os

import numpy as np
import warnings
np.warnings = warnings
import matplotlib.pyplot as plt
from tabpfn_new.scripts.differentiable_pfn_evaluation import eval_model_range
from tabpfn_new.scripts.model_builder import get_model, get_default_spec, save_model, load_model
from tabpfn_new.scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, load_model_workflow

from tabpfn_new.scripts.model_configs import *

#from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids
from tabpfn_new.priors.utils import plot_prior, plot_features
from tabpfn_new.priors.utils import uniform_int_sampler_f

from tabpfn_new.scripts.tabular_metrics import calculate_score_per_method, calculate_score
from tabpfn_new.scripts.tabular_evaluation import evaluate

from tabpfn_new.priors.differentiable_prior import DifferentiableHyperparameterList, draw_random_style, merge_style_with_info
from tabpfn_new.scripts import tabular_metrics
from tabpfn.notebook_utils import *

In [2]:
# Runtime settings shared by the cells below: compute device, base directory,
# and the feature count used when building the prior config.
device, base_path, max_features = 'cpu', '.', 100

In [3]:
def reload_config(prior_type='forest', config_type='causal', task_type='binary', longer=0):
    """Return a fresh prior config for ``config_type`` plus a timestamped model-name suffix.

    The dict returned by ``get_prior_config`` gets these overrides:
    ``prior_type``, 12000 epochs, ``recompute_attn=True``, ``max_num_classes=10``,
    ``num_classes = uniform_int_sampler_f(2, 10)`` and ``balanced=False``.

    ``task_type`` and ``longer`` are accepted for call compatibility but are not read.

    Returns:
        ``(config, model_string)`` where ``model_string`` is
        ``"_multiclass_"`` followed by the current time as ``%m_%d_%Y_%H_%M_%S``.
    """
    config = get_prior_config(config_type=config_type)

    overrides = (
        ('prior_type', prior_type),
        ('epochs', 12000),
        ('recompute_attn', True),
        ('max_num_classes', 10),
    )
    for key, value in overrides:
        config[key] = value

    # Must come after max_num_classes is assigned, since it reads that value.
    config['num_classes'] = uniform_int_sampler_f(2, config['max_num_classes'])
    config['balanced'] = False

    timestamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
    return config, '_'.join(('_multiclass', timestamp))

In [11]:
# Experiment configuration: MLP prior, multiclass, evaluated on the microbiome test.
# Every key is assigned exactly once below, with its effective value.
prior_type = 'mlp'
config, model_string = reload_config(prior_type, longer=1)

config['differentiable'] = True
config['flexible'] = True
config['bptt_extra_samples'] = None

# Pinned values: each is set here and its entry is deleted from
# config['differentiable_hyperparameters'] (presumably so it is no longer sampled).
config['output_multiclass_ordered_p'] = 0.0
del config['differentiable_hyperparameters']['output_multiclass_ordered_p']

config['sampling'] = 'normal'  # maybe a bad choice?
del config['differentiable_hyperparameters']['sampling']

config['pre_sample_causes'] = True

config['multiclass_loss_type'] = 'nono'  # alternative: 'compatible'

config['categorical_feature_p'] = 0  # diff: .0

# Missing-value injection is disabled; consider turning it back on in a random search.
config['nan_prob_no_reason'] = .0
config['nan_prob_unknown_reason'] = .0  # diff: .0
config['set_value_to_nan'] = .0  # diff: 1.

config['new_mlp_per_example'] = True
config['prior_mlp_scale_weights_sqrt'] = True
config['batch_size_per_gp_sample'] = None

config['differentiable_hps_as_style'] = False
config['max_eval_pos'] = 1000

config['random_feature_rotation'] = True
config['rotate_normalized_labels'] = True

config["mix_activations"] = True  # NOTE(review): author's note says "False actually means True" — verify

# Transformer size
config['emsize'] = 512
config['nhead'] = config['emsize'] // 128
config['bptt'] = 1024 + 128
config['canonical_y_encoder'] = False

config['total_available_time_in_s'] = None  # e.g. 60*60*22 = 22 hours, for some safety

config['train_mixed_precision'] = False
config['efficient_eval_masking'] = True

# MLP prior params
config['is_causal'] = True
config['num_causes'] = 5
config['prior_mlp_hidden_dim'] = 50
config['num_layers'] = 4
config['noise_std'] = 0.05
config['init_std'] = 0.05
config['y_is_effect'] = True
config['pre_sample_weights'] = True
config['prior_mlp_dropout_prob'] = 0
config["prior_mlp_activations"] = torch.nn.ReLU
config["block_wise_dropout"] = True
config["sort_features"] = False
config["in_clique"] = False

# General data params
config['balanced'] = False
# NOTE(review): max_num_classes is 2 while num_classes below is fixed at 10 —
# confirm which one is intended for this multiclass run.
config['max_num_classes'] = 2
config['max_features'] = max_features
config['num_features_used'] = max_features

# Forest prior params
config["min_features"] = max_features  # fixed feature count: min == max
config["n_samples"] = 1000
config["base_size"] = 1000
config["n_estimators"] = 1
config["min_depth"] = 5
config["max_depth"] = 15
config["categorical_x"] = False
config["data_sample_func"] = "mnd"
config["comp"] = False

# Encoding / normalization
config['no_encoder'] = False
config['normalize_to_ranking'] = False
config['normalize_with_sqrt'] = False
config['normalize_ignore_label_too'] = False
config["normalize"] = True
config['num_classes'] = 10

config['multiclass_type'] = 'rank'
del config['differentiable_hyperparameters']['multiclass_type']

# Run bookkeeping
config["prior_type"] = prior_type
config["microbiome_test"] = True
config["weight_classes"] = False
config["run_name"] = "mlp_multi"

# Optimization
config['aggregate_k_gradients'] = 16
config['batch_size'] = 4 * config['aggregate_k_gradients']
# A constant 4; it does not scale with aggregate_k_gradients (4*k // k == 4).
# NOTE(review): confirm this is the intended number of steps.
config['num_steps'] = 4
config['epochs'] = 50
config["lr"] = 1e-4

config_sample = evaluate_hypers(config)

'for key in config_sample:\n    #if key == "check_is_compatible":\n    print(key, config_sample[key])\nfor key in config_sample["differentiable_hyperparameters"]:\n    print(key, config_sample["differentiable_hyperparameters"][key])'

In [12]:
# Build the model from the sampled config and train it (should_train=True);
# with verbose=1 the per-epoch training log is printed below.
model = get_model(config_sample, device, verbose=1, should_train=True)

Using style prior: True
Using cpu:0 device
Using a Transformer with 25.81 M parameters


% Positive predictions:
0.513  0.499  0.470  0.479  
% Positive targets:
0.231  0.127  0.292  0.043  
Train sample accuracy: 0.490


% Positive predictions:
0.000  0.000  0.000  0.000  
% Positive targets:
0.170  0.563  0.479  0.235  
Train sample accuracy: 0.638


% Positive predictions:
0.760  0.683  0.841  0.316  
% Positive targets:
0.588  0.502  0.501  0.485  
Train sample accuracy: 0.525


% Positive predictions:
0.997  0.992  0.979  0.996  
% Positive targets:
0.546  0.502  0.449  0.471  
Train sample accuracy: 0.487


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938226        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 263.27s | mean loss  0.71 |  mean accuracy  0.51 |  lr 0.0001 |  data time  0.28 step time  2.43 forward time  0.95 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
0.684  0.287  0.552  0.718  
% Positive targets:
0.511  0.329  0.457  0.499  
Train sample accuracy: 0.525


% Positive predictions:
0.714  0.086  0.433  0.999  
% Positive targets:
0.595  0.360  0.494  0.895  
Train sample accuracy: 0.627


% Positive predictions:
0.140  0.520  0.520  0.000  
% Positive targets:
0.360  0.540  0.640  0.320  
Train sample accuracy: 0.560


% Positive predictions:
1.000  0.000  0.952  0.495  
% Positive ta

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.939796        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 241.40s | mean loss  0.65 |  mean accuracy  0.63 |  lr 9.990133642141359e-05 |  data time  0.27 step time  2.09 forward time  0.77 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  0.000  0.736  0.000  
% Positive targets:
0.649  0.152  0.423  0.370  
Train sample accuracy: 0.649


% Positive predictions:
1.000  0.234  0.106  0.404  
% Positive targets:
0.532  0.447  0.447  0.426  
Train sample accuracy: 0.527


% Positive predictions:
0.000  1.000  0.027  0.000  
% Positive targets:
0.314  0.785  0.496  0.401  
Train sample accuracy: 0.643


% Positive predictions:
0.167  0.005  0.003  0.000 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.939622        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   3 | time: 249.53s | mean loss  0.64 |  mean accuracy  0.62 |  lr 9.96057350657239e-05 |  data time  0.31 step time  2.37 forward time  0.91 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  1.000  0.000  0.760  
% Positive targets:
0.601  0.480  0.268  0.408  
Train sample accuracy: 0.574


% Positive predictions:
0.993  0.000  0.009  0.996  
% Positive targets:
0.545  0.312  0.422  0.541  
Train sample accuracy: 0.585


% Positive predictions:
0.000  1.000  0.052  0.000  
% Positive targets:
0.172  0.649  0.563  0.437  
Train sample accuracy: 0.621


% Positive predictions:
0.920  0.006  0.440  0.000  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN   0.93997        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   4 | time: 327.39s | mean loss  0.64 |  mean accuracy  0.61 |  lr 9.911436253643445e-05 |  data time  0.27 step time  4.74 forward time  1.94 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  0.000  0.000  0.481  
% Positive targets:
0.733  0.503  0.501  0.493  
Train sample accuracy: 0.560


% Positive predictions:
0.787  0.000  0.000  1.000  
% Positive targets:
0.529  0.029  0.293  0.851  
Train sample accuracy: 0.751


% Positive predictions:
0.949  0.971  1.000  0.948  
% Positive targets:
0.499  0.581  0.653  0.493  
Train sample accuracy: 0.561


% Positive predictions:
1.000  0.906  1.000  0.438 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.942676        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   5 | time: 288.74s | mean loss  0.64 |  mean accuracy  0.61 |  lr 9.842915805643155e-05 |  data time  0.28 step time  5.34 forward time  1.99 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
0.930  0.006  0.000  0.131  
% Positive targets:
0.534  0.376  0.376  0.429  
Train sample accuracy: 0.592


% Positive predictions:
1.000  0.000  0.005  0.961  
% Positive targets:
0.773  0.166  0.372  0.635  
Train sample accuracy: 0.713


% Positive predictions:
0.000  0.983  0.840  0.006  
% Positive targets:
0.022  0.618  0.553  0.475  
Train sample accuracy: 0.661


% Positive predictions:
0.124  0.096  1.000  0.467 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938138        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   6 | time: 281.86s | mean loss  0.63 |  mean accuracy  0.63 |  lr 9.755282581475769e-05 |  data time  0.35 step time  2.51 forward time  0.87 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
0.798  0.000  0.185  0.000  
% Positive targets:
0.548  0.429  0.458  0.250  
Train sample accuracy: 0.604


% Positive predictions:
0.600  0.000  0.600  1.000  
% Positive targets:
0.400  0.000  0.400  0.800  
Train sample accuracy: 0.550


% Positive predictions:
0.956  1.000  1.000  0.721  
% Positive targets:
0.465  0.682  0.574  0.463  
Train sample accuracy: 0.552


% Positive predictions:
0.961  0.000  0.998  0.000 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.942675        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   7 | time: 310.43s | mean loss  0.62 |  mean accuracy  0.62 |  lr 9.648882429441257e-05 |  data time  0.32 step time  5.13 forward time  1.58 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
0.471  0.657  0.000  0.777  
% Positive targets:
0.507  0.488  0.421  0.486  
Train sample accuracy: 0.503


% Positive predictions:
1.000  0.000  0.156  0.092  
% Positive targets:
0.697  0.283  0.447  0.515  
Train sample accuracy: 0.607


% Positive predictions:
0.000  0.195  1.000  0.000  
% Positive targets:
0.090  0.530  0.764  0.353  
Train sample accuracy: 0.703


% Positive predictions:
0.347  0.041  1.000  0.000 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.941628        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   8 | time: 283.49s | mean loss  0.65 |  mean accuracy  0.59 |  lr 9.524135262330098e-05 |  data time  0.32 step time  4.59 forward time  1.91 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
0.142  0.009  0.080  0.017  
% Positive targets:
0.514  0.412  0.494  0.469  
Train sample accuracy: 0.521


% Positive predictions:
1.000  0.500  0.000  0.723  
% Positive targets:
0.752  0.517  0.357  0.532  
Train sample accuracy: 0.596


% Positive predictions:
0.154  1.000  0.962  0.962  
% Positive targets:
0.346  0.538  0.654  0.462  
Train sample accuracy: 0.538


% Positive predictions:
1.000  0.471  1.000  0.002 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.940232        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch   9 | time: 269.03s | mean loss  0.63 |  mean accuracy  0.61 |  lr 9.381533400219318e-05 |  data time  0.28 step time  2.38 forward time  0.87 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  0.029  0.000  0.000  
% Positive targets:
0.893  0.414  0.236  0.021  
Train sample accuracy: 0.805


% Positive predictions:
1.000  0.000  0.000  1.000  
% Positive targets:
0.635  0.312  0.363  0.661  
Train sample accuracy: 0.655


% Positive predictions:
0.360  0.190  0.645  0.195  
% Positive targets:
0.410  0.520  0.555  0.505  
Train sample accuracy: 0.498


% Positive predictions:
0.779  0.000  0.970  0.038 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.937265        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  10 | time: 255.62s | mean loss  0.63 |  mean accuracy  0.62 |  lr 9.221639627510076e-05 |  data time  0.27 step time  3.77 forward time  1.47 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  0.000  0.009  0.000  
% Positive targets:
0.837  0.303  0.519  0.370  
Train sample accuracy: 0.660


% Positive predictions:
1.000  0.340  0.360  0.957  
% Positive targets:
0.678  0.487  0.487  0.579  
Train sample accuracy: 0.579


% Positive predictions:
0.232  1.000  1.000  0.460  
% Positive targets:
0.552  0.568  0.564  0.440  
Train sample accuracy: 0.526


% Positive predictions:
0.956  0.892  0.863  0.362 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.943461        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  11 | time: 242.24s | mean loss  0.63 |  mean accuracy  0.61 |  lr 9.045084971874738e-05 |  data time  0.28 step time  3.45 forward time  1.46 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
0.854  0.991  0.000  0.836  
% Positive targets:
0.574  0.463  0.296  0.479  
Train sample accuracy: 0.549


% Positive predictions:
0.907  0.000  0.429  1.000  
% Positive targets:
0.536  0.408  0.450  0.699  
Train sample accuracy: 0.588


% Positive predictions:
0.000  1.000  1.000  0.000  
% Positive targets:
0.282  0.720  0.935  0.254  
Train sample accuracy: 0.780


% Positive predictions:
1.000  0.000  1.000  0.035 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938138        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  12 | time: 290.12s | mean loss  0.61 |  mean accuracy  0.63 |  lr 8.852566213878947e-05 |  data time  0.31 step time  3.06 forward time  1.11 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  0.170  0.007  0.000  
% Positive targets:
0.719  0.487  0.464  0.147  
Train sample accuracy: 0.647


% Positive predictions:
0.004  0.000  0.938  1.000  
% Positive targets:
0.509  0.467  0.497  0.536  
Train sample accuracy: 0.516


% Positive predictions:
0.000  0.774  0.998  0.000  
% Positive targets:
0.065  0.546  0.623  0.120  
Train sample accuracy: 0.740


% Positive predictions:
1.000  0.869  1.000  0.292 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938313        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  13 | time: 289.00s | mean loss  0.62 |  mean accuracy  0.62 |  lr 8.644843137107059e-05 |  data time  0.30 step time  3.92 forward time  1.53 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  0.000  0.000  0.743  
% Positive targets:
0.680  0.268  0.389  0.471  
Train sample accuracy: 0.629


% Positive predictions:
1.000  0.480  0.000  0.857  
% Positive targets:
0.561  0.408  0.133  0.520  
Train sample accuracy: 0.612


% Positive predictions:
0.240  1.000  1.000  0.604  
% Positive targets:
0.490  0.719  0.521  0.531  
Train sample accuracy: 0.578


% Positive predictions:
0.731  0.582  1.000  0.000 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.938662        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  14 | time: 316.46s | mean loss  0.62 |  mean accuracy  0.62 |  lr 8.422735529643444e-05 |  data time  0.45 step time  4.85 forward time  1.67 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  0.000  0.147  0.002  
% Positive targets:
0.643  0.472  0.484  0.430  
Train sample accuracy: 0.560


% Positive predictions:
0.827  0.488  0.002  1.000  
% Positive targets:
0.525  0.461  0.506  0.772  
Train sample accuracy: 0.564


% Positive predictions:
0.148  0.725  1.000  0.000  
% Positive targets:
0.452  0.522  0.613  0.391  
Train sample accuracy: 0.562


% Positive predictions:
0.030  0.000  0.999  0.436 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.939185        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  15 | time: 298.19s | mean loss  0.63 |  mean accuracy  0.61 |  lr 8.18711994874345e-05 |  data time  0.32 step time  3.61 forward time  1.74 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  1.000  0.000  0.000  
% Positive targets:
0.737  0.473  0.185  0.176  
Train sample accuracy: 0.712


% Positive predictions:
0.919  0.000  0.088  1.000  
% Positive targets:
0.554  0.366  0.490  0.933  
Train sample accuracy: 0.650


% Positive predictions:
0.258  1.000  1.000  0.000  
% Positive targets:
0.462  0.909  0.712  0.348  
Train sample accuracy: 0.688


% Positive predictions:
1.000  0.000  1.000  0.000  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                    accuracy  precision  recall  roc_auc
Micriobiome TabPFN  0.937876        0.0     0.0      0.5
-----------------------------------------------------------------------------------------
| end of epoch  16 | time: 270.96s | mean loss  0.61 |  mean accuracy  0.64 |  lr 7.938926261462366e-05 |  data time  0.25 step time  2.70 forward time  0.91 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
0.659  0.000  0.001  0.000  
% Positive targets:
0.526  0.380  0.356  0.134  
Train sample accuracy: 0.657


In [6]:
def load_train_results(name, log_root=None):
    """Load the artifacts saved for one training run.

    Files are read from ``<log_root>/trainrun_<name>/``.

    Args:
        name: Run name, as in ``config["run_name"]``.
        log_root: Directory containing the ``trainrun_*`` folders. Defaults to
            ``<current working directory>/logs``.

    Returns:
        Tuple ``(losses, mb_results, accuracies)``, each read with ``torch.load``.
        ``accuracies`` is ``0`` when the run directory has no ``accuracies`` file.

    Raises:
        FileNotFoundError: if ``losses`` or ``mb_results`` is missing.
    """
    if log_root is None:
        log_root = os.path.join(os.getcwd(), "logs")
    run_dir = os.path.join(log_root, f"trainrun_{name}")

    # NOTE: torch.load unpickles its input; only load run directories you trust.
    losses = torch.load(os.path.join(run_dir, "losses"))
    mb_results = torch.load(os.path.join(run_dir, "mb_results"))
    try:
        accuracies = torch.load(os.path.join(run_dir, "accuracies"))
    except FileNotFoundError:
        # Some runs may not have saved accuracies; keep the sentinel 0 callers expect.
        # Any other error (e.g. a corrupt file) is no longer silently swallowed.
        accuracies = 0
    return losses, mb_results, accuracies

In [7]:
def plot_metrics(losses, results, metrics, name):
    """Plot evaluation scores (top panel) and training loss (bottom panel) for one run.

    Args:
        losses: 1-D sequence of losses, plotted against the epoch axis.
        results: 2-D array-like indexed as ``results[:, i]``; column ``i`` holds the
            scores for ``metrics[i]``.
        metrics: Metric names used as legend labels, in column order of ``results``.
        name: Title shown above the score panel.
    """
    # Use a fresh figure per call. Reusing plt.figure(1) makes repeated calls (e.g. in
    # a loop over runs) draw on top of the previous plot under non-inline backends.
    fig, (ax_scores, ax_loss) = plt.subplots(2, 1)
    for column, metric in enumerate(metrics):
        ax_scores.plot(results[:, column], label=metric)
    ax_scores.set_ylabel("score")
    ax_scores.set_title(name)
    ax_scores.legend()

    ax_loss.plot(losses)
    ax_loss.set_xlabel("epoch")
    ax_loss.set_ylabel("loss")

    plt.show()
    # Release the figure so plotting many runs does not accumulate open figures.
    plt.close(fig)

In [8]:
# Reload the logged curves of the run configured above and plot them.
name = config["run_name"]
metrics = ["accuracy", "precision", "recall", "roc_auc"]  # assumed column order of mb_results
losses, mb_results, accuracies = load_train_results(name=name)
plot_metrics(losses=losses, results=mb_results, metrics=metrics, name=name)

KeyboardInterrupt: 

In [None]:
# Compare previously logged runs: one score/loss figure per run.
names = [
    "forest_balanced_noweight",
    "forest_balanced",
    "forest_nonorm",
    "mlp_baseline",
    "forest_longer",
]
metrics = ["accuracy", "precision", "recall", "roc_auc"]
for name in names:
    # Unpack into a named throwaway instead of `_`. Assigning `_` at notebook top level
    # shadows IPython's last-output variable and stops `_`/`__`/`___` from updating.
    losses, mb_results, _unused_accuracies = load_train_results(name)
    plot_metrics(losses, mb_results, metrics, name)

In [None]:
# Sanity check: where does a planted outlier land when StratifiedShuffleSplit splits
# (samples, 2, 3)-shaped data stratified on two binary label columns?
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit

SPLIT_SEED = 0   # makes the whole check reproducible
N_TRIALS = 20
MIN_TRAIN = 8    # smallest train split to try

gen = torch.Generator().manual_seed(SPLIT_SEED)
b = torch.randint(0, 2, (100, 2), generator=gen)
a = torch.rand((100, 2, 3), generator=gen)
a[3, 0, :] = 100  # planted outlier, located after each split via `> 10`

n_samples = a.shape[0]
# For 2-D y, sklearn stratifies on each distinct label row. Both the train and test
# splits must hold at least that many samples, otherwise split() raises ValueError.
n_strata = torch.unique(b, dim=0).shape[0]

for trial in range(N_TRIALS):
    # The exclusive upper bound keeps test_size = n_samples - split >= n_strata.
    split = int(torch.randint(max(MIN_TRAIN, n_strata), n_samples - n_strata + 1, (1,),
                              generator=gen))
    sss = StratifiedShuffleSplit(n_splits=1, test_size=n_samples - split, random_state=trial)

    train_index, test_index = next(sss.split(a, b))
    X_train, y_train, X_test, y_test = a[train_index], b[train_index], a[test_index], b[test_index]
    # Dimension-1 indices of entries > 10, i.e. of the outlier, in each split.
    print(torch.where(X_train > 10)[1])
    print(torch.where(X_test > 10)[1])
    X = torch.cat((X_train, X_test), dim=0)
    y = torch.cat((y_train, y_test), dim=0)