In [1]:
!pip install tabpfn

Collecting tabpfn
  Downloading tabpfn-0.1.9-py3-none-any.whl (156 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m156.6/156.6 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: tabpfn
Successfully installed tabpfn-0.1.9

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [1]:
import pandas as pd
import numpy as np
from tabpfn import TabPFNClassifier
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
import torch
import random, os

In [2]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [3]:
df = pd.read_csv("../input/icr-identify-age-related-conditions/train.csv", index_col=[0])
df['EJ'] = df['EJ'].map({'A': 0.0, 'B': 1.0}, na_action=None)
col_mean = df.mean(axis=0)
for i, colname in enumerate(df.columns):
    df[colname].fillna(col_mean[i], inplace=True)
df.head()

Unnamed: 0_level_0,AB,AF,AH,AM,AR,AX,AY,AZ,BC,BD,...,FL,FR,FS,GB,GE,GF,GH,GI,GL,Class
Id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
000ff2bfdfe9,0.209377,3109.03329,85.200147,22.394407,8.138688,0.699861,0.025578,9.812214,5.555634,4126.58731,...,7.298162,1.73855,0.094822,11.339138,72.611063,2003.810319,22.136229,69.834944,0.120343,1
007255e47698,0.145282,978.76416,85.200147,36.968889,8.138688,3.63219,0.025578,13.51779,1.2299,5496.92824,...,0.173229,0.49706,0.568932,9.292698,72.611063,27981.56275,29.13543,32.131996,21.978,0
013f2bd269f5,0.47003,2635.10654,85.200147,32.360553,8.138688,6.73284,0.025578,12.82457,1.2299,5135.78024,...,7.70956,0.97556,1.198821,37.077772,88.609437,13676.95781,28.022851,35.192676,0.196941,0
043ac50845d5,0.252107,3819.65177,120.201618,77.112203,8.138688,3.685344,0.025578,11.053708,1.2299,4169.67738,...,6.122162,0.49706,0.284466,18.529584,82.416803,2094.262452,39.948656,90.493248,0.155829,0
044fb8a146ec,0.380297,3733.04844,85.200147,14.103738,8.138688,3.942255,0.05481,3.396778,102.15198,5728.73412,...,8.153058,48.50134,0.121914,16.408728,146.109943,8524.370502,45.381316,36.262628,0.096614,1


In [4]:
df_test = pd.read_csv("../input/icr-identify-age-related-conditions/test.csv", index_col=[0])
df_test['EJ'] = df['EJ'].map({'A': 0.0, 'B': 1.0}, na_action=None)
for i, colname in enumerate(df_test.columns):
    df_test[colname].fillna(col_mean[i], inplace=True)
df_test

Unnamed: 0_level_0,AB,AF,AH,AM,AR,AX,AY,AZ,BC,BD,...,FI,FL,FR,FS,GB,GE,GF,GH,GI,GL
Id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
00eed32682bb,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
010ebe33f668,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
02fa521e1838,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
040e15f562a2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
046e85c7cc7f,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [5]:
def balanced_log_loss(y_true, y_pred):
    y_pred = y_pred.astype('float64')
    y_pred = np.clip(y_pred, 1e-15, 1-1e-15)
    nc = np.bincount(y_true)
    w0, w1 = 1/(nc[0]/y_true.shape[0]), 1/(nc[1]/y_true.shape[0])
    balanced_log_loss_score = (-w0/nc[0]*(np.sum(np.where(y_true==0,1,0) * np.log(1-y_pred))) - w1/nc[1]*(np.sum(np.where(y_true!=0,1,0) * np.log(y_pred)))) / (w0+w1)
    return balanced_log_loss_score

In [6]:
seed = 57
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
num = 0
seed_everything(seed)
score = 0
tabpfn = TabPFNClassifier(N_ensemble_configurations=64) 
 
models = []
for train_idx, valid_idx in kf.split(df, df.Class):
    train_df = df.iloc[train_idx]
    valid_df = df.iloc[valid_idx]
    X_train = train_df[df.columns[:-1]]
    y_train = train_df.Class

    X_valid = valid_df[df.columns[:-1]]
    y_valid = valid_df.Class
    
    print(f"Fold{num}")
    num+=1

    tabpfn.fit(X_train,y_train)
    pred = tabpfn.predict_proba(X_valid)[:, 1]
#     pred[pred > 0.89] = 1
#     pred[pred < 0.10] = 0
    sc = balanced_log_loss(y_valid, pred)
    print("accuracy = ", accuracy_score(y_valid.to_numpy(), np.round(pred)), "balanced_log_loss = ",sc)
    score += sc
    
    models.append(tabpfn)
score/5

Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
Fold0
accuracy =  0.9032258064516129 balanced_log_loss =  0.9586338284586867
Fold1
accuracy =  0.9274193548387096 balanced_log_loss =  0.4942478340607524
Fold2
accuracy =  0.9105691056910569 balanced_log_loss =  0.5579979885010247
Fold3
accuracy =  0.967479674796748 balanced_log_loss =  0.7284307120294491
Fold4
accuracy =  0.926829268292683 balanced_log_loss =  0.25697595044026744


0.5992572626980361

In [7]:
X_test = df_test
pred = []
for model in models:
    p = model.predict_proba(X_test)
    if len(pred) == 0 :
        pred = p
    else:
        pred += p
        
pred /= len(models)

In [8]:
pred

array([[0.7179315 , 0.28206852],
       [0.7179315 , 0.28206852],
       [0.7179315 , 0.28206852],
       [0.7179315 , 0.28206852],
       [0.7179315 , 0.28206852]], dtype=float32)