In [191]:
import torch

In [192]:
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# N_ensemble_configurations controls the number of model predictions that are ensembled with feature and class rotations (See our work for details).
# When N_ensemble_configurations > #features * #classes, no further averaging is applied.

classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=1)

classifier.fit(X_train, y_train)
y_eval, p_eval = classifier.predict(X_test, return_winning_probability=True)

print('Accuracy', accuracy_score(y_test, y_eval))



Accuracy 0.9787234042553191


In [193]:
type(classifier.model[2])

tabpfn.transformer.TransformerModel

In [194]:
y_eval

array([1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1,
       0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1,
       0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0,
       1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1,
       0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0,
       1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1,
       1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1])

In [195]:
y_test

array([1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1,
       0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1,
       0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0,
       1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1,
       0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0,
       1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1,
       1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1])

In [196]:
import numpy as np

In [197]:
device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')

In [198]:
X_train = torch.tensor(X_train, dtype=torch.float32, device=device)
X_test = torch.tensor(X_test, dtype=torch.float32, device=device)
y_train = torch.tensor(y_train, dtype=torch.int64, device=device)
y_test = torch.tensor(y_test, dtype=torch.int64, device=device)

In [199]:
# Concatenate train and test data

X_full = torch.cat((X_train, X_test), dim=0).float().unsqueeze(1).to(device)

y_full = np.concatenate([y_train, np.zeros(shape=X.shape[0])], axis=0)   # for the test data, we don't have the labels, thus we use zeros
y_full = torch.tensor(y_full, device=device).float().unsqueeze(1)   

eval_pos = X_train.shape[0]  # position where the test data starts

In [200]:
from tabpfn.scripts.transformer_prediction_interface import transformer_predict

In [201]:
model = classifier.model[2]  # extract the pytorch model from the TabPFNClassifier

In [202]:
def get_params_from_config(c):
    return {'max_features': c['num_features']
        , 'rescale_features': c["normalize_by_used_features"]
        , 'normalize_to_ranking': c["normalize_to_ranking"]
        , 'normalize_with_sqrt': c.get("normalize_with_sqrt", False)
            }


In [203]:
pred = transformer_predict(classifier.model[2], X_full, y_full, eval_pos,
                                         device=classifier.device,
                                         style=classifier.style,
                                         inference_mode=False,
                                         preprocess_transform='none' if classifier.no_preprocess_mode else 'mix',
                                         normalize_with_test=False,
                                         N_ensemble_configurations=classifier.N_ensemble_configurations,
                                         softmax_temperature=classifier.temperature,
                                         multiclass_decoder=classifier.multiclass_decoder,
                                         feature_shift_decoder=classifier.feature_shift_decoder,
                                         differentiable_hps_as_style=classifier.differentiable_hps_as_style,
                                         seed=classifier.seed,
                                         return_logits=False,
                                         no_grad=False,
                                         batch_size_inference=classifier.batch_size_inference,
                                         **get_params_from_config(classifier.c))



In [179]:
pred_label = torch.argmax(pred, dim=2).cpu()

In [180]:
pred_label

tensor([[1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1,
         1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0,
         1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1,
         0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0,
         1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1,
         0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,
         1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1]])

In [181]:
from tabpfn.utils import normalize_data, to_ranking_low_mem, remove_outliers
from tabpfn.utils import NOP, normalize_by_used_features_f


from sklearn.preprocessing import PowerTransformer, QuantileTransformer, RobustScaler
import warnings

In [182]:
def preprocess_input(eval_xs, 
                     eval_ys,
                     preprocess_transform,
                     eval_position,
                     max_features,
                     normalize_with_test = False,
                     normalize_to_ranking = False,
                     normalize_with_sqrt = False,
                     device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda'),
                     categorical_feats = []):
        
        """
        Preprocess the input data for the transformer model
        Args:
            eval_xs: torch.Tensor, x-value input for evaluation
            eval_ys: torch.Tensor, ys input
            preprocess_transform: str, type of preprocessing to be applied. Options: 'none', 'power', 'quantile', 'robust', 'power_all', 'quantile_all', 'robust_all'
            eval_position: int, position where the evaluation data starts
            max_features: int, maximum number of features to be used
            normalize_with_test: bool, whether to normalize with test data
            normalize_to_ranking: bool, whether to normalize to ranking
            normalize_with_sqrt: bool, whether to normalize with sqrt
            device: str, device to be used
            categorical_feats: list, list of categorical features
        """

        if eval_xs.shape[1] > 1:
            raise Exception("Transforms only allow one batch dim - TODO")

        if eval_xs.shape[2] > max_features:
            eval_xs = eval_xs[:, :, sorted(np.random.choice(eval_xs.shape[2], max_features, replace=False))]

        if preprocess_transform != 'none':
            if preprocess_transform == 'power' or preprocess_transform == 'power_all':
                pt = PowerTransformer(standardize=True)
            elif preprocess_transform == 'quantile' or preprocess_transform == 'quantile_all':
                pt = QuantileTransformer(output_distribution='normal')
            elif preprocess_transform == 'robust' or preprocess_transform == 'robust_all':
                pt = RobustScaler(unit_variance=True)

        # eval_xs, eval_ys = normalize_data(eval_xs), normalize_data(eval_ys)
        eval_xs = normalize_data(eval_xs, normalize_positions=-1 if normalize_with_test else eval_position)

        # Removing empty features
        eval_xs = eval_xs[:, 0, :]
        sel = [len(torch.unique(eval_xs[0:eval_ys.shape[0], col])) > 1 for col in range(eval_xs.shape[1])]
        eval_xs = eval_xs[:, sel]

        warnings.simplefilter('error')
        if preprocess_transform != 'none':
            eval_xs = eval_xs.cpu().numpy()
            feats = set(range(eval_xs.shape[1])) if 'all' in preprocess_transform else set(
                range(eval_xs.shape[1])) - set(categorical_feats)
            for col in feats:
                try:
                    pt.fit(eval_xs[0:eval_position, col:col + 1])
                    trans = pt.transform(eval_xs[:, col:col + 1])
                    # print(scipy.stats.spearmanr(trans[~np.isnan(eval_xs[:, col:col+1])], eval_xs[:, col:col+1][~np.isnan(eval_xs[:, col:col+1])]))
                    eval_xs[:, col:col + 1] = trans
                except:
                    pass
            eval_xs = torch.tensor(eval_xs).float()
        warnings.simplefilter('default')

        eval_xs = eval_xs.unsqueeze(1)

        # TODO: Caution there is information leakage when to_ranking is used, we should not use it
        eval_xs = remove_outliers(eval_xs, normalize_positions=-1 if normalize_with_test else eval_position) \
                if not normalize_to_ranking else normalize_data(to_ranking_low_mem(eval_xs))
        # Rescale X
        eval_xs = normalize_by_used_features_f(eval_xs, eval_xs.shape[-1], max_features,
                                               normalize_with_sqrt=normalize_with_sqrt)

        return eval_xs.to(device)

In [183]:
def predict(eval_xs, eval_ys, softmax_temperature, return_logits, eval_position, model, num_classes):

        output = model(
                    (None, eval_xs, eval_ys.float()),
                    single_eval_pos=eval_position)[:, :, 0:num_classes]
        

        output = output[:, :, 0:num_classes] / torch.exp(softmax_temperature)

        if not return_logits:
                output = torch.nn.functional.softmax(output, dim=-1)


        return output

In [221]:
import itertools
import random

def transformer_predict_custom(
                            model, 
                            eval_xs, 
                            eval_ys, 
                            eval_position,
                            device='cpu',
                            max_features=100,
                            style=None,
                            num_classes=2,
                            extend_features=True,
                            normalize_with_test=False,
                            normalize_to_ranking=False,
                            softmax_temperature=0.0,
                            preprocess_transform='mix',
                            categorical_feats=[],
                            batch_size_inference=16,
                            normalize_with_sqrt=False,
                            seed=0,
                            return_logits=False):
    
    num_classes = len(torch.unique(eval_ys))


    eval_xs, eval_ys = eval_xs.to(device), eval_ys.to(device)
    eval_ys = eval_ys[:eval_position]

    model.to(device)

    model.eval()

    if seed is not None:
        torch.manual_seed(seed)

    style = None
    softmax_temperature = torch.log(torch.tensor([0.8]))
    softmax_temperature = torch.tensor(softmax_temperature).to(device)

    inputs, labels = [], []

    eval_xs_, eval_ys_ = eval_xs.clone(), eval_ys.clone()


    eval_xs_ = preprocess_input(
        eval_xs=eval_xs_,
        eval_ys=eval_ys_,
        preprocess_transform=preprocess_transform,
        eval_position=eval_position,
        max_features=max_features,
        normalize_with_test=normalize_with_test,
        normalize_to_ranking=normalize_to_ranking,
        normalize_with_sqrt=normalize_with_sqrt,
        device=device,
        categorical_feats=categorical_feats
    )

    eval_ys_ = ((eval_ys_ + 0) % num_classes).float()

    print(eval_xs_.shape, eval_ys_.shape)

    eval_xs_ = torch.cat([eval_xs_[..., 0:],eval_xs_[..., :0]],dim=-1) # this shifts the features by one

    print(eval_xs_.shape, eval_ys_.shape)

    # Extend X
    if extend_features:
        eval_xs_ = torch.cat(
            [eval_xs_,
                torch.zeros((eval_xs_.shape[0], eval_xs_.shape[1], max_features - eval_xs_.shape[2])).to(device)], -1)
    inputs += [eval_xs_]
    labels += [eval_ys_]

    inputs = torch.cat(inputs, 1)
    inputs = torch.split(inputs, batch_size_inference, dim=1)
    labels = torch.cat(labels, 1)
    labels = torch.split(labels, batch_size_inference, dim=1)

    print(inputs[0].shape, labels[0].shape)

    outputs = []
    for batch_input, batch_label in zip(inputs, labels):            
        output_batch = predict(eval_xs=batch_input, eval_ys=batch_label, softmax_temperature=softmax_temperature, return_logits=return_logits, eval_position=eval_position, model=model, num_classes=num_classes)
           
        outputs.append(output_batch)
    
    return torch.cat(outputs, 1).squeeze(1)


    

In [222]:
get_params_from_config(classifier.c)

{'max_features': 100,
 'rescale_features': True,
 'normalize_to_ranking': False,
 'normalize_with_sqrt': False}

In [223]:
len(y_full)

950

In [224]:
len(X_full)

569

In [225]:
model

TransformerModel(
  (transformer_encoder): TransformerEncoderDiffInit(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=1024, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (encoder): Linear(in_features=100, out_features=512, bias=True)
  (y_encoder): Linear(in_features=1, out_features=512, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=512, out_features=1024, bias=True)
    (1): GELU(approximate='none')

In [226]:
pred_custom = transformer_predict_custom(classifier.model[2], X_full, y_full, eval_pos,
                                         device=classifier.device,
                                         style=classifier.style,
                                         preprocess_transform='power_all',
                                         normalize_with_test=False,
                                         softmax_temperature=classifier.temperature,
                                         seed=classifier.seed,
                                         return_logits=False,
                                         batch_size_inference=classifier.batch_size_inference,
                                         )

  softmax_temperature = torch.tensor(softmax_temperature).to(device)


torch.Size([569, 1, 30]) torch.Size([381, 1])
torch.Size([569, 1, 30]) torch.Size([381, 1])
torch.Size([569, 1, 100]) torch.Size([381, 1])


In [216]:
len(pred_custom)

188

In [217]:
pred_custom

tensor([[1.1139e-01, 8.8861e-01],
        [9.9998e-01, 1.5099e-05],
        [9.9999e-01, 1.4381e-05],
        [8.2190e-06, 9.9999e-01],
        [2.5166e-06, 1.0000e+00],
        [9.9999e-01, 1.0031e-05],
        [9.9999e-01, 7.8525e-06],
        [9.2876e-01, 7.1236e-02],
        [1.1724e-01, 8.8276e-01],
        [3.9087e-05, 9.9996e-01],
        [5.1093e-02, 9.4891e-01],
        [9.9852e-01, 1.4848e-03],
        [1.2111e-04, 9.9988e-01],
        [9.9227e-01, 7.7311e-03],
        [3.2559e-04, 9.9967e-01],
        [9.9467e-01, 5.3276e-03],
        [8.9337e-06, 9.9999e-01],
        [3.2585e-06, 1.0000e+00],
        [5.7493e-05, 9.9994e-01],
        [1.0000e+00, 3.8525e-06],
        [5.0348e-02, 9.4965e-01],
        [2.1155e-04, 9.9979e-01],
        [1.0000e+00, 2.4908e-06],
        [1.4606e-04, 9.9985e-01],
        [1.2914e-05, 9.9999e-01],
        [1.6486e-03, 9.9835e-01],
        [2.5991e-06, 1.0000e+00],
        [1.4376e-04, 9.9986e-01],
        [4.7790e-06, 1.0000e+00],
        [9.999

In [218]:
pred_label_c = torch.argmax(pred_custom, dim=1).cpu()

In [219]:
pred_label_c == pred_label

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  

In [220]:
pred

tensor([[[2.5643e-02, 9.7436e-01],
         [1.0000e+00, 1.4656e-06],
         [9.9999e-01, 9.1454e-06],
         [3.2564e-06, 1.0000e+00],
         [3.4475e-06, 1.0000e+00],
         [1.0000e+00, 4.6052e-06],
         [1.0000e+00, 1.2244e-06],
         [9.7865e-01, 2.1351e-02],
         [4.8726e-01, 5.1274e-01],
         [3.3818e-05, 9.9997e-01],
         [1.9664e-02, 9.8034e-01],
         [9.9965e-01, 3.5022e-04],
         [2.5653e-05, 9.9997e-01],
         [8.6371e-01, 1.3629e-01],
         [2.1190e-04, 9.9979e-01],
         [9.7162e-01, 2.8376e-02],
         [3.6410e-06, 1.0000e+00],
         [6.0988e-06, 9.9999e-01],
         [6.8569e-05, 9.9993e-01],
         [1.0000e+00, 6.5333e-07],
         [1.5332e-02, 9.8467e-01],
         [6.0217e-05, 9.9994e-01],
         [1.0000e+00, 5.6474e-07],
         [5.0694e-05, 9.9995e-01],
         [1.6663e-05, 9.9998e-01],
         [6.7215e-04, 9.9933e-01],
         [4.2186e-06, 1.0000e+00],
         [3.5856e-04, 9.9964e-01],
         [3.7867e-06