In [1]:
import random
import time
import warnings
from datetime import datetime

import torch
import os

import numpy as np
import warnings
np.warnings = warnings
import matplotlib.pyplot as plt
#from tabpfn_new.scripts.differentiable_pfn_evaluation import eval_model_range
from tabpfn_new.scripts.model_builder import get_model, get_default_spec, save_model, load_model
from tabpfn_new.scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, load_model_workflow

from tabpfn_new.scripts.model_configs import *

#from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids
from tabpfn_new.priors.utils import plot_prior, plot_features
from tabpfn_new.priors.utils import uniform_int_sampler_f

#from tabpfn_new.scripts.tabular_metrics import calculate_score_per_method, calculate_score
#from tabpfn_new.scripts.tabular_evaluation import evaluate

from tabpfn_new.priors.differentiable_prior import DifferentiableHyperparameterList, draw_random_style, merge_style_with_info
from tabpfn_new.scripts import tabular_metrics
from tabpfn.notebook_utils import *

In [2]:
# Runtime settings shared by the cells below: compute device, root path for
# saved artifacts, and the feature-count cap used when building the config.
device, base_path, max_features = 'cpu', '.', 100

In [3]:
def reload_config(prior_type='forest', config_type='causal', task_type='binary', longer=0):
    """Build a fresh prior config and a timestamped model-name suffix.

    Args:
        prior_type: Stored verbatim under ``config['prior_type']``.
        config_type: Forwarded to ``get_prior_config``.
        task_type: Accepted for call compatibility; not read by this function.
        longer: Accepted for call compatibility; not read by this function.

    Returns:
        ``(config, model_string)``. ``model_string`` is ``'_multiclass_'``
        followed by the current local time formatted as ``%m_%d_%Y_%H_%M_%S``.
    """
    config = get_prior_config(config_type=config_type)

    # Fixed overrides, applied in this order on top of the base prior config.
    overrides = {
        'prior_type': prior_type,
        'epochs': 12000,
        'recompute_attn': True,
        'max_num_classes': 10,
    }
    for key, value in overrides.items():
        config[key] = value

    # Presumably a sampler factory drawing the class count between 2 and
    # max_num_classes — semantics live in uniform_int_sampler_f; confirm there.
    config['num_classes'] = uniform_int_sampler_f(2, config['max_num_classes'])
    config['balanced'] = False

    suffix = '_multiclass'
    stamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
    return config, f"{suffix}_{stamp}"

In [4]:
# Experiment configuration: start from reload_config() and override it for an
# MLP prior on 2-class tasks. Keys are grouped by purpose and each key is
# assigned exactly once, with the value the run actually uses (the previous
# version of this cell set several keys twice, with the later value winning).
prior_type = 'mlp'
config, model_string = reload_config(prior_type, longer=1)  # also sets config['prior_type']
diff_hps = config['differentiable_hyperparameters']

# Keys fixed below are also removed from the differentiable (sampled)
# hyperparameters, presumably so the fixed value is used instead of a sampled
# one. pop(key, None) matches the guarded deletes and tolerates absent keys.

# --- prior sampling mode --------------------------------------------------------
config['differentiable'] = True
config['flexible'] = True
config['bptt_extra_samples'] = None
config['differentiable_hps_as_style'] = False
config['output_multiclass_ordered_p'] = 0.0
diff_hps.pop('output_multiclass_ordered_p', None)
config['sampling'] = 'mnd'
diff_hps.pop('sampling', None)

# --- MLP prior --------------------------------------------------------------------
# is_causal must be False so the 'mnd'-sampled causes are used directly as x
# (data comes from the MLP input).
config['is_causal'] = False
diff_hps.pop('is_causal', None)
config['num_causes'] = 5
config['prior_mlp_hidden_dim'] = 50
# Raise the lower bound of the sampled hidden dim to get more complex datasets.
diff_hps['prior_mlp_hidden_dim'] = {'distribution': 'meta_gamma', 'max_alpha': 3, 'max_scale': 100, 'round': True, 'lower_bound': 25}
config['num_layers'] = 4
config['noise_std'] = 0.05
config['init_std'] = 0.05
config['y_is_effect'] = True
config['pre_sample_weights'] = True
config['pre_sample_causes'] = True
config['prior_mlp_dropout_prob'] = 0
config['prior_mlp_activations'] = torch.nn.ReLU
config['mix_activations'] = True  # author's note: "False effectively means True" here
# Must be False so the final layer's output is y; block-wise dropout on the
# last-layer block creates bad datasets.
config['block_wise_dropout'] = False
diff_hps.pop('block_wise_dropout', None)
config['sort_features'] = False
config['in_clique'] = False
# Must be False so noise does not drown out the signal from input to final output.
config['mlp_noise'] = False
config['new_mlp_per_example'] = True
config['prior_mlp_scale_weights_sqrt'] = True
config['batch_size_per_gp_sample'] = None

# --- forest prior (presumably only read when prior_type == 'forest') -------------
config['n_samples'] = 1000
config['base_size'] = 1000
config['n_estimators'] = 1
config['categorical_x'] = False
config['comp'] = False
config['min_depth'] = 5
config['max_depth'] = 10
config['data_sample_func'] = 'mnd'

# --- features and missing values -------------------------------------------------
config['max_features'] = max_features  # from the settings cell
config['min_features'] = 100
config['num_features_used'] = 100
config['categorical_feature_p'] = 0
config['random_feature_rotation'] = True
# TODO: consider re-enabling the NaN settings in a random search.
config['nan_prob_no_reason'] = .0
config['nan_prob_unknown_reason'] = .0
config['set_value_to_nan'] = .0

# --- labels / classes --------------------------------------------------------------
config['num_classes'] = 2
config['max_num_classes'] = 2
config['balanced'] = True
config['multiclass_type'] = 'static_balance'
diff_hps.pop('multiclass_type', None)
config['multiclass_loss_type'] = 'nono'  # alternative: 'compatible'
config['rotate_normalized_labels'] = True
config['canonical_y_encoder'] = False
config['align_majority'] = False
config['limit_imbalance'] = False
config['weight_classes'] = False

# --- input normalization -----------------------------------------------------------
config['no_encoder'] = False
config['normalize_to_ranking'] = False
config['normalize_with_sqrt'] = False
config['normalize_ignore_label_too'] = False
config['normalize_labels'] = False
config['normalize'] = False
config['clr'] = True

# --- transformer ------------------------------------------------------------------
config['n_layers'] = 2
config['emsize'] = 64
config['nhead'] = config['emsize'] // 16
config['bptt'] = 1024 + 128
config['max_eval_pos'] = 1025
config['min_eval_pos'] = 1024
config['efficient_eval_masking'] = True

# --- training ------------------------------------------------------------------------
config['total_available_time_in_s'] = None  # e.g. 60*60*22 to cap the run at 22 hours
config['train_mixed_precision'] = False
config['aggregate_k_gradients'] = 1
config['batch_size'] = 1 * config['aggregate_k_gradients']
config['num_steps'] = 1  # previously written as 1*k//k, which is always 1
config['epochs'] = 10
config['lr'] = 1e-2
config['frac'] = 0.35
config['microbiome_test'] = True
config['run_name'] = 'time'

config_sample = evaluate_hypers(config)

'for key in config_sample:\n    #if key == "check_is_compatible":\n    print(key, config_sample[key])\nfor key in config_sample["differentiable_hyperparameters"]:\n    print(key, config_sample["differentiable_hyperparameters"][key])'

In [5]:
# Build the model from the sampled config; should_train=True makes get_model
# run training as well (its per-epoch logs appear in the cell output).
train_options = {'should_train': True, 'verbose': 1}
model = get_model(config_sample, device, **train_options)



% Positive predictions:
1.000  
% Positive targets:
0.594  
Train sample accuracy: 0.594


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                accuracy  precision  recall   roc_auc   f1  runtime
Medical TabPFN  0.941176        0.0     0.0  0.492188  0.0     0.53
-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  3.43s | mean loss  0.68 |  mean accuracy 0.5938 |  preds imbalance measure  0.41 |  lr 0.01 |  data time  0.07 step time  0.27 forward time  0.12 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  
% Positive targets:
0.672  
Train sample accuracy: 0.672

% of positive predictions:  1.0
                accuracy  precision  recall   roc_auc        f1   runtime
Medical TabPFN  0.058824   0.058824     1.0  0.489844  0.111111  0.395622
-----------------------------------------------------------------------------------------
| end of epoch   2 | time:  2.74s | mean loss  0.71 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



% of positive predictions:  0.0
                accuracy  precision  recall   roc_auc   f1   runtime
Medical TabPFN  0.941176        0.0     0.0  0.476562  0.0  0.377131
-----------------------------------------------------------------------------------------
| end of epoch   3 | time:  2.64s | mean loss  0.56 |  mean accuracy 0.7734 |  preds imbalance measure  0.23 |  lr 0.009045084971874737 |  data time  0.04 step time  0.26 forward time  0.11 nan share  0.00 ignore share (for classification tasks) 0.0000
-----------------------------------------------------------------------------------------


% Positive predictions:
1.000  
% Positive targets:
0.328  
Train sample accuracy: 0.328


In [6]:
# Saved training runs to compare, and the metrics plotted for each of them.
names = "forest_balanced_noweight forest_balanced forest_nonorm mlp_baseline forest_longer".split()
metrics = "accuracy precision recall roc_auc".split()

for name in names:
    # load_train_results returns three items; the third is not needed for plotting.
    losses, mb_results, _ = load_train_results(name)
    plot_metrics(losses, mb_results, metrics, name)


KeyboardInterrupt



In [None]:
# Scratch check: split a toy dataset with StratifiedShuffleSplit at random
# train sizes and see in which split the single marked sample (row 3) lands.
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit

b = torch.randint(0, 2, (100, 2))  # two binary label columns
a = torch.rand((100, 2, 3))        # values in [0, 1)
a[3, 0, :] = 100                   # mark one sample so it can be located after splitting

n_samples = a.shape[0]
# With 2-D labels sklearn stratifies on each distinct label row, and both splits
# must contain at least that many samples, otherwise it raises ValueError.
n_label_rows = torch.unique(b, dim=0).shape[0]

for trial in range(20):
    # randint's upper bound is exclusive, so test_size = n_samples - split
    # stays >= n_label_rows. int() gives sklearn a plain int, not a tensor.
    split = int(torch.randint(8, n_samples - n_label_rows + 1, (1,)))
    sss = StratifiedShuffleSplit(n_splits=1, test_size=n_samples - split)

    train_index, test_index = next(sss.split(a, b))
    X_train, y_train, X_test, y_test = a[train_index], b[train_index], a[test_index], b[test_index]
    # Index [1] is the second axis: prints zeros for the split holding the
    # marked sample and an empty tensor for the other one.
    print(torch.where(X_train > 10)[1])
    print(torch.where(X_test > 10)[1])
    X = torch.cat((X_train, X_test), dim=0)
    y = torch.cat((y_train, y_test), dim=0)