In [None]:
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from statistics import mean

import torch as th

from tabpfn import TabPFNClassifier

from tab_pfn.networks import SCM
from tab_pfn.metrics import ConfusionMeter

import warnings

warnings.filterwarnings("ignore")

In [None]:
# Build the TabPFN classifier shared by the evaluation loop below.
# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs on machines without CUDA (the original hard-coded 'cuda').
# N_ensemble_configurations=32 is kept from the original experiment setup.
classifier_device = "cuda" if th.cuda.is_available() else "cpu"
classifier = TabPFNClassifier(
    device=classifier_device,
    N_ensemble_configurations=32,
)

In [None]:
# Evaluate TabPFN on datasets sampled from a synthetic SCM.
# For each dataset: split train/test, fit `classifier` on the train split, and
# accumulate the summed precision / recall vectors plus the number of classes.
# The next cell then averages the metrics over all (dataset, class) pairs.
n_datasets = 128
n_samples = 1024   # rows drawn from each sampled SCM
test_size = 0.33   # fraction of rows held out for evaluation
seed = 42          # used for the RNG seed and the train/test split

# NOTE(review): assumes SCM draws from torch's global RNG — TODO confirm;
# if it also uses numpy / `random`, seed those as well for reproducibility.
th.manual_seed(seed)

prec, rec = [], []  # per-dataset sums of the precision / recall vectors
nb = 0              # total vector length (classes) over successful datasets
n_failed = 0        # datasets skipped because fitting/scoring raised ValueError

for _ in tqdm(range(n_datasets)):
    # SCM constructor arguments kept exactly as in the original experiment;
    # their meaning is defined by tab_pfn.networks.SCM.
    scm = SCM(100, (10, 10), False)
    x, y = scm(n_samples)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=test_size, random_state=seed
    )

    try:
        classifier.fit(x_train, y_train)
        p_eval = classifier.predict_proba(x_test)
        # NOTE(review): predict_proba presumably has one column per class seen in
        # y_train, which can be fewer than scm.nb_class — confirm ConfusionMeter
        # maps columns to labels correctly in that case.
        conf_meter = ConfusionMeter(scm.nb_class)
        conf_meter.add(th.tensor(p_eval), y_test)
        recall = conf_meter.recall()
        dataset_prec = conf_meter.precision().sum().item()
        dataset_rec = recall.sum().item()
        dataset_n_classes = recall.size(0)
    except ValueError:
        # Best effort: some sampled datasets cannot be fit or scored. Skip them,
        # but count them so the skip rate is visible instead of silently hidden.
        n_failed += 1
        continue

    # Append only after every metric succeeded, so prec, rec and nb stay aligned.
    prec.append(dataset_prec)
    rec.append(dataset_rec)
    nb += dataset_n_classes

if n_failed:
    print(f"skipped {n_failed}/{n_datasets} datasets (ValueError)")

In [None]:
# Average precision / recall over every (dataset, class) pair that was evaluated,
# plus the number of datasets that succeeded. Yields NaN instead of raising
# ZeroDivisionError when no dataset was evaluated (nb == 0).
# NOTE(review): if ConfusionMeter returns NaN for classes never predicted, the
# sums (and therefore these averages) will be NaN — confirm.
avg_precision = sum(prec) / nb if nb else float("nan")
avg_recall = sum(rec) / nb if nb else float("nan")
avg_precision, avg_recall, len(prec)