In [None]:
import openml
from tqdm import tqdm
import math
import torch as th
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss, roc_auc_score , accuracy_score

from tabpfn import TabPFNClassifier

from tab_pfn.networks import TabPFN, SklearnClassifier

In [None]:
# Selection limits for the benchmark subset. A task is dropped as soon as one
# of these is exceeded.
# NOTE(review): the feature limit is applied to *numeric* features only, as in
# the original cell — confirm that total feature count is not what is intended.
MAX_INSTANCES = 2000
MAX_NUMERIC_FEATURES = 100
MAX_CLASSES = 10

benchmark = openml.study.get_suite('OpenML-CC18')
tasks = openml.tasks.list_tasks(task_id=benchmark.tasks, output_format="dataframe")

# Datasets that pass every filter, in task-listing order.
retained_datasets = []

for _, row in tqdm(list(tasks.iterrows())):
    # Cheap filters first: they only read the task listing, so oversized tasks
    # are rejected without fetching the task or its dataset.
    if row["NumberOfInstances"] > MAX_INSTANCES:
        continue
    if row["NumberOfNumericFeatures"] > MAX_NUMERIC_FEATURES:
        continue

    # Fetch the dataset exactly once; the same object is both inspected and
    # retained below (previously it was fetched a second time before appending).
    try:
        dataset = openml.tasks.get_task(row["tid"]).get_dataset()
    except Exception as e:
        # Best-effort: report the failing task id and move on to the next one.
        print(e)
        print(row["tid"])
        continue

    if dataset.qualities["NumberOfClasses"] > MAX_CLASSES:
        continue

    retained_datasets.append(dataset)

In [None]:
# Number of entries in `retained_datasets`; as the last expression of the cell
# it is shown as the cell output.
len(retained_datasets)

In [None]:
# --- Evaluation settings ---------------------------------------------------
# NOTE(review): machine-specific absolute path; point this at your own checkpoint.
CHECKPOINT_PATH = "/home/samuel/PycharmProjects/TabPFN/out/out_train_trf_big/model_183295.pt"
# Fixed seed so the 50/50 split (and therefore the reported scores) is reproducible.
SPLIT_SEED = 0
# Use the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if th.cuda.is_available() else "cpu"

# Locally trained network restored from the checkpoint.
# NOTE(review): the positional constructor arguments must match the ones used
# at training time — their meaning is not visible from this notebook.
tab_pfn = TabPFN(100, 10, 512, 1024, 4, 6)
tab_pfn.load_state_dict(th.load(CHECKPOINT_PATH, map_location=device))
tab_pfn.eval()
tab_pfn.to(device)

# Classifier under evaluation. Swap the two lines below to evaluate the
# locally trained network instead of the `tabpfn` package classifier.
#tab_pfn_clf = SklearnClassifier.from_torch(tab_pfn)
tab_pfn_clf = TabPFNClassifier(device=device)

# dataset name -> {"cross_entropy": ..., "accuracy": ...}
scores = {}
# dataset name -> repr of the error that made the classifier fail on it
skipped = {}

for dataset in tqdm(retained_datasets):
    x, y, _, _ = dataset.get_data(target=dataset.default_target_attribute)

    x_train, x_test, y_train, y_test = train_test_split(
        x, y, train_size=0.5, test_size=0.5, random_state=SPLIT_SEED
    )

    try:
        with th.no_grad():
            tab_pfn_clf.fit(x_train, y_train)
            out_proba = tab_pfn_clf.predict_proba(x_test)
            out = tab_pfn_clf.predict(x_test)
    except (ValueError, TypeError) as err:
        # Still best-effort, but remember why the dataset was dropped so the
        # averages computed later can be interpreted.
        skipped[dataset.name] = repr(err)
        continue

    # TODO: ROC AUC is not computed; for multi-class targets it needs the
    # probability estimates (`out_proba`), not the hard predictions.

    # Pass the classifier's class list (when it exposes one) so log_loss maps
    # probability columns correctly even if the test split lacks a class.
    # With no `classes_` attribute this is labels=None, i.e. the old behaviour.
    cross_entropy = log_loss(
        y_test, out_proba, labels=getattr(tab_pfn_clf, "classes_", None)
    )
    accuracy = accuracy_score(y_test, out)

    scores[dataset.name] = {
        "cross_entropy": cross_entropy,
        "accuracy": accuracy
    }

if skipped:
    print(f"Skipped {len(skipped)} dataset(s):")
    for skipped_name, reason in skipped.items():
        print(" ", skipped_name, reason)

In [None]:
# One line per evaluated dataset: its name followed by its metrics dict.
for dataset_name in scores:
    metrics = scores[dataset_name]
    print(dataset_name, metrics)

3 : loss=0.1690, acc=0.9507 | rec=0.9491
11 : loss=0.2526, acc=0.8263 | rec=0.8025
14 : loss=0.5841, acc=0.7704 | rec=0.7747
15 : loss=nan, acc=0.3295 | rec=0.5000
16 : loss=0.2647, acc=0.9336 | rec=0.9341
18 : loss=0.7649, acc=0.6760 | rec=0.6736
22 : loss=0.4107, acc=0.8285 | rec=0.8313
23 : loss=1.0401, acc=0.1406 | rec=0.3333
29 : loss=nan, acc=0.2377 | rec=0.5000
31 : loss=0.5340, acc=0.6851 | rec=0.5948
37 : loss=0.4828, acc=0.7703 | rec=0.7518
46 : loss=0.9109, acc=0.5742 | rec=0.5267
50 : loss=0.4127, acc=0.7752 | rec=0.7244
54 : loss=0.6417, acc=0.6987 | rec=0.7033
188 : loss=nan, acc=0.0652 | rec=0.2000
38 : loss=nan, acc=0.4706 | rec=0.5000
458 : loss=0.0191, acc=0.9958 | rec=0.9964
469 : loss=1.8085, acc=0.0327 | rec=0.1667
1049 : loss=0.2422, acc=0.9570 | rec=0.6187
1050 : loss=0.2954, acc=0.4501 | rec=0.5000
1063 : loss=0.4099, acc=0.7803 | rec=0.6860
1067 : loss=0.3358, acc=0.5825 | rec=0.5066
1068 : loss=0.2535, acc=0.7101 | rec=0.5101
1510 : loss=0.0973, acc=0.9519 | rec=0.9586
1494 : loss=0.3405, acc=0.8471 | rec=0.8348
1480 : loss=0.5533, acc=0.5745 | rec=0.5225
1487 : loss=0.1464, acc=0.8084 | rec=0.5150
1462 : loss=0.0508, acc=0.9887 | rec=0.9849
1464 : loss=0.5115, acc=0.8881 | rec=0.5174
6332 : loss=nan, acc=0.2185 | rec=0.5000
23381 : loss=0.7787, acc=0.5707 | rec=0.5681
40966 : loss=nan, acc=0.0188 | rec=0.1250
40982 : loss=0.8137, acc=0.7151 | rec=0.6522
40994 : loss=0.2078, acc=0.8597 | rec=0.6902
40975 : loss=0.2396, acc=0.8264 | rec=0.6853
40984 : loss=0.3858, acc=0.8246 | rec=0.8291
40978 : loss=0.2861, acc=0.9171 | rec=0.7535
40670 : loss=1.0561, acc=0.1668 | rec=0.3333

In [None]:
# Unweighted mean of each metric over the evaluated datasets (every dataset
# counts equally, regardless of its size).
if scores:
    cross_entropy = sum(s["cross_entropy"] for s in scores.values()) / len(scores)
    accuracy = sum(s["accuracy"] for s in scores.values()) / len(scores)
else:
    # Every dataset was skipped: report NaN instead of dividing by zero.
    print("No dataset was scored; averages are undefined.")
    cross_entropy = accuracy = float("nan")
print("cross_entropy", cross_entropy)
print("accuracy", accuracy)

Our model:

cross_entropy 0.5815344112381117
accuracy 0.7349593118393449

Authors' code:

cross_entropy 0.2765864623252986
accuracy 0.8887808004275888