# SSL: SCARF


This work is based on the one published by the author of the [TS3L library](https://github.com/Alcoholrithm/TabularS3L).

Specifically, this code is implementing [SCARF](https://openreview.net/forum?id=CuV_qYkmKb3)

In [3]:
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import numpy as np

# Prepare the SCARFLightning Module
from ts3l.pl_modules import SCARFLightning
from ts3l.utils.scarf_utils import SCARFDataset
from ts3l.utils import TS3LDataModule
from ts3l.utils.scarf_utils import SCARFConfig
from ts3l.utils.embedding_utils import IdentityEmbeddingConfig
from ts3l.utils.backbone_utils import MLPBackboneConfig
from pytorch_lightning import Trainer


# Evaluation
from sklearn.metrics import accuracy_score
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, SequentialSampler

## Data loading

In [4]:
TEST_DATA_PATH = "../data/test_data.csv"
TRAIN_DATA_PATH = "../data/train_data.csv"
UNLABELLED_DATA_PATH = "../data/unlabelled_data.csv"
PSEUDO_LABELLED_DATA_PATH = "../data/pseudo_labelled_scarf.csv"

In [5]:
test_df = pd.read_csv(TEST_DATA_PATH)
categorical_cols = ['Chemotherapy', 'Menopausal State', 'Radio Therapy',
                            'Hormone Therapy', 'Surgery-breast conserving',
                            'Surgery-mastectomy', 'Neoplasm Histologic Grade',
                            'Cellularity']
test_df[categorical_cols]

Unnamed: 0,Chemotherapy,Menopausal State,Radio Therapy,Hormone Therapy,Surgery-breast conserving,Surgery-mastectomy,Neoplasm Histologic Grade,Cellularity
0,0,1,1,1,0,1,3,0.5
1,0,1,1,1,0,1,2,0.5
2,1,1,1,0,0,1,3,1.0
3,0,1,1,1,0,1,3,1.0
4,1,1,1,0,0,1,2,1.0
...,...,...,...,...,...,...,...,...
112,0,1,0,1,0,1,3,0.5
113,1,0,0,1,0,1,2,0.5
114,0,1,0,1,0,1,3,1.0
115,0,1,1,1,1,0,1,0.5


In [6]:
def get_dataframes(test_path, train_path, unlabelled_path, with_clinical=False):
    # Load the data
    test_df = pd.read_csv(test_path)
    train_df = pd.read_csv(train_path)
    unlabelled_df = pd.read_csv(unlabelled_path)

    # Drop the columns that are not needed
    test_df = test_df.drop(columns=['DssTime', 'Event'])
    train_df = train_df.drop(columns=['DssTime', 'Event'])

    # Extract numerical and categorical columns
    # Numerical cols: Gene + Age
    numerical_cols = test_df.columns[:21].tolist()
    # But also Size
    numerical_cols.append('Size')
    # Categorical cols: Clinical
    categorical_cols = test_df.drop(columns=['Label', 'Size']).columns[21:].tolist()
    if not with_clinical:
        test_df = test_df.drop(columns=categorical_cols)
        train_df = train_df.drop(columns=categorical_cols)
        unlabelled_df = unlabelled_df.drop(columns=categorical_cols)
        categorical_cols = []
    else:
        categorical_cols = ['Chemotherapy', 'Menopausal State', 'Radio Therapy',
                            'Hormone Therapy', 'Surgery-breast conserving',
                            'Surgery-mastectomy', 'Neoplasm Histologic Grade',
                            'Cellularity']
        not_cols=[]
        # The model has problems with these columns
        test_df = test_df.drop(columns=not_cols)
        train_df = train_df.drop(columns=not_cols)
        unlabelled_df = unlabelled_df.drop(columns=not_cols)

    print(f'Train data shape: {train_df.shape}')
    print(f'Test data shape: {test_df.shape}')
    print(f'Unlabelled data shape: {unlabelled_df.shape}')
    print(f'Numerical columns: {numerical_cols}')
    if with_clinical:
        print(f'Categorical columns: {categorical_cols}')
    return test_df, train_df, unlabelled_df, numerical_cols, categorical_cols

In [7]:
test_data, train_data, unlabelled_data, numerical_cols, categorical_cols = get_dataframes(
    TEST_DATA_PATH,
    TRAIN_DATA_PATH,
    UNLABELLED_DATA_PATH,
    with_clinical=True)

Train data shape: (465, 31)
Test data shape: (117, 31)
Unlabelled data shape: (1168, 30)
Numerical columns: ['ESR1', 'PGR', 'ERBB2', 'MKI67', 'PLAU', 'ELAVL1', 'EGFR', 'BTRC', 'FBXO6', 'SHMT2', 'KRAS', 'SRPK2', 'YWHAQ', 'PDHA1', 'EWSR1', 'ZDHHC17', 'ENO1', 'DBN1', 'PLK1', 'GSK3B', 'Age', 'Size']
Categorical columns: ['Chemotherapy', 'Menopausal State', 'Radio Therapy', 'Hormone Therapy', 'Surgery-breast conserving', 'Surgery-mastectomy', 'Neoplasm Histologic Grade', 'Cellularity']


In [8]:
test_data.head()

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,Menopausal State,Size,Radio Therapy,Chemotherapy,Hormone Therapy,Neoplasm Histologic Grade,Cellularity,Surgery-breast conserving,Surgery-mastectomy,Label
0,11.23975,5.954311,9.739996,6.046045,10.040187,5.905724,5.881255,6.538235,7.260572,10.774752,...,1,31,1,0,1,3,0.5,0,1,1
1,10.927313,7.002502,10.033753,5.568993,8.306619,6.547491,5.733367,6.128118,7.917904,9.514045,...,1,22,1,0,1,2,0.5,0,1,0
2,6.312633,5.305683,9.068778,5.919384,8.210977,5.896152,5.634379,5.625037,7.684047,11.422518,...,1,40,1,1,0,3,1.0,0,1,1
3,9.1852,5.480888,9.580607,5.655789,7.756504,6.026981,6.008594,6.269051,7.428641,9.478211,...,1,150,1,0,1,3,1.0,0,1,1
4,7.249462,5.164281,10.233184,5.721403,8.918334,6.392132,5.58845,6.062906,7.968933,9.578638,...,1,45,1,1,0,2,1.0,0,1,1


In [9]:
full_X_train = train_data.drop(columns=['Label'])
full_y_train = train_data['Label']

X_test = test_data.drop(columns=['Label'])
y_test = test_data['Label']

## Data configuration

In [10]:
# Split the train_data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(
    full_X_train,
    full_y_train,
    test_size=0.2,
    random_state=42,
    stratify=full_y_train)

print(f'Training data shape: {X_train.shape}')
print(f'Validation data shape: {X_val.shape}')

Training data shape: (372, 30)
Validation data shape: (93, 30)


## Model configuration

In [11]:
metric = "accuracy_score"
input_dim = X_train.shape[1]
pretraining_head_dim = 1024
output_dim = 2
head_depth = 2
dropout_rate = 0.05

corruption_rate = 0.6

batch_size = 128
max_epochs = 20

In [12]:
embedding_config = IdentityEmbeddingConfig(input_dim = input_dim)
backbone_config = MLPBackboneConfig(input_dim = embedding_config.output_dim)

In [13]:
config = SCARFConfig( 
    task="classification",
    loss_fn="CrossEntropyLoss",
    metric=metric, metric_hparams={},
    embedding_config=embedding_config,
    backbone_config=backbone_config,
    pretraining_head_dim=pretraining_head_dim,
    output_dim=output_dim,
    head_depth=head_depth,
    dropout_rate=dropout_rate,
    corruption_rate = corruption_rate
)

In [14]:
pl_scarf = SCARFLightning(config)

Seed set to 42


In [15]:
### First Phase Learning
train_ds = SCARFDataset(
    X_train,
    unlabeled_data=unlabelled_data,
    config = config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols)

valid_ds = SCARFDataset(
    X_val,
    config=config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols
)

datamodule = TS3LDataModule(train_ds, valid_ds, batch_size=batch_size, train_sampler="random")


## Model training: 1st phase

In [16]:
trainer = Trainer(
    accelerator = 'cpu',
    max_epochs = max_epochs,
    num_sanity_val_steps = 2,
    )

trainer.fit(pl_scarf, datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
2024-12-14 02:12:53.437726: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-14 02:12:53.448262: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-14 02:12:53.451135: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-14 02:12:53.459712: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations

                                                                           

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (13) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 19: 100%|██████████| 13/13 [00:18<00:00,  0.69it/s, v_num=98, train_loss=4.930, val_loss=4.880]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 13/13 [00:18<00:00,  0.68it/s, v_num=98, train_loss=4.930, val_loss=4.880]


## Model training: 2nd phase

In [17]:
### Second Phase Learning

pl_scarf.set_second_phase()


In [18]:
train_ds = SCARFDataset(
    X_train,
    y_train.values,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols,
    is_second_phase=True)

valid_ds = SCARFDataset(
    X_val,
    y_val.values,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols,
    is_second_phase=True)

datamodule = TS3LDataModule(train_ds, valid_ds, batch_size = batch_size, train_sampler="weighted")


In [19]:
trainer = Trainer(
                    accelerator = 'cpu',
                    max_epochs = max_epochs,
                    num_sanity_val_steps = 2,
    )

trainer.fit(pl_scarf, datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name             | Type             | Params | Mode 
--------------------------------------------------------------
0 | task_loss_fn     | CrossEntropyLoss | 0      | train
1 | contrastive_loss | NTXentLoss       | 0      | train
2 | model            | SCARF            | 1.2 M  | train
--------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.820     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode


                                                                           

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 1:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.845, train_accuracy_score=0.516, val_accuracy_score=0.613, val_loss=0.676]        



Epoch 2:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.935, train_accuracy_score=0.513, val_accuracy_score=0.624, val_loss=0.659]        



Epoch 3:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.780, train_accuracy_score=0.589, val_accuracy_score=0.624, val_loss=0.653]        



Epoch 4:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.768, train_accuracy_score=0.548, val_accuracy_score=0.656, val_loss=0.644]        



Epoch 5:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.862, train_accuracy_score=0.513, val_accuracy_score=0.656, val_loss=0.641]        



Epoch 6:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.725, train_accuracy_score=0.594, val_accuracy_score=0.634, val_loss=0.639]        



Epoch 7:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.742, train_accuracy_score=0.589, val_accuracy_score=0.634, val_loss=0.637]        



Epoch 8:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.849, train_accuracy_score=0.581, val_accuracy_score=0.634, val_loss=0.637]        



Epoch 9:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.855, train_accuracy_score=0.551, val_accuracy_score=0.634, val_loss=0.636]        



Epoch 10:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.725, train_accuracy_score=0.629, val_accuracy_score=0.634, val_loss=0.633]       



Epoch 11:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.871, train_accuracy_score=0.527, val_accuracy_score=0.634, val_loss=0.632]        



Epoch 12:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.779, train_accuracy_score=0.589, val_accuracy_score=0.645, val_loss=0.628]        



Epoch 13:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.756, train_accuracy_score=0.594, val_accuracy_score=0.624, val_loss=0.627]        



Epoch 14:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.789, train_accuracy_score=0.559, val_accuracy_score=0.613, val_loss=0.624]        



Epoch 15:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.835, train_accuracy_score=0.583, val_accuracy_score=0.602, val_loss=0.623]        



Epoch 16:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.754, train_accuracy_score=0.570, val_accuracy_score=0.591, val_loss=0.620]        



Epoch 17:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.777, train_accuracy_score=0.591, val_accuracy_score=0.602, val_loss=0.618]        



Epoch 18:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.747, train_accuracy_score=0.602, val_accuracy_score=0.634, val_loss=0.616]        



Epoch 19:   0%|          | 0/3 [00:00<?, ?it/s, v_num=99, train_loss=0.812, train_accuracy_score=0.594, val_accuracy_score=0.645, val_loss=0.616]        



Epoch 19: 100%|██████████| 3/3 [00:01<00:00,  1.74it/s, v_num=99, train_loss=0.857, train_accuracy_score=0.565, val_accuracy_score=0.634, val_loss=0.615]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, v_num=99, train_loss=0.857, train_accuracy_score=0.565, val_accuracy_score=0.634, val_loss=0.615]


## Evaluation

In [20]:
test_ds = SCARFDataset(
    X_test,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols,
    is_second_phase=True)

test_dl = DataLoader(
    test_ds,
    batch_size,
    shuffle=False,
    sampler=SequentialSampler(test_ds),
    num_workers=4)


In [21]:
preds = trainer.predict(pl_scarf, test_dl)
        
preds = F.softmax(torch.concat([out.cpu() for out in preds]).squeeze(),dim=1)

accuracy = accuracy_score(y_test, preds.argmax(1))

print("Accuracy %.2f" % accuracy)



Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 312.68it/s]
Accuracy 0.74


## Label prediction

In [22]:
unlabelled_ds = SCARFDataset(
    X=unlabelled_data,
    config=config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols)

unlabelled_dl = DataLoader(
    unlabelled_ds,
    batch_size,
    shuffle=False,
    sampler=SequentialSampler(unlabelled_ds)
)

In [23]:
preds = trainer.predict(pl_scarf, unlabelled_dl)
preds = F.softmax(torch.concat([out.cpu() for out in preds]).squeeze(),dim=1)

predicted_labels = preds.argmax(1)
associated_probabilities = preds[np.arange(preds.shape[0]), predicted_labels]

pd.DataFrame({'label': predicted_labels, 'confidence': associated_probabilities})

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 123.49it/s]


Unnamed: 0,label,confidence
0,0,0.713222
1,0,0.578509
2,0,0.747164
3,1,0.535473
4,0,0.592202
...,...,...
1163,0,0.768499
1164,0,0.652164
1165,1,0.501235
1166,0,0.709410


## Data export

In [24]:
unlabelled_data = pd.read_csv(UNLABELLED_DATA_PATH)
train_data = pd.read_csv(TRAIN_DATA_PATH)
train_data['Confidence'] = 1.0

train_data.head()

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,Chemotherapy,Hormone Therapy,Neoplasm Histologic Grade,Cellularity,Surgery-breast conserving,Surgery-mastectomy,Label,DssTime,Event,Confidence
0,10.041281,7.376123,9.725825,5.427919,9.300307,6.219375,6.125355,5.888779,7.893369,9.007343,...,1,1,2,1.0,0,1,0,163.1,1,1.0
1,11.276581,7.331223,9.956267,5.629876,8.119906,5.66562,5.775809,6.251167,8.242063,10.871432,...,1,1,3,1.0,0,1,1,41.366667,1,1.0
2,7.536847,5.587666,11.514514,5.722951,6.741081,6.32148,5.466188,6.956486,7.673015,9.837096,...,0,0,2,1.0,0,1,1,36.266667,1,1.0
3,10.395644,6.531288,9.075396,5.440774,7.861422,5.973844,5.75712,6.026611,7.666777,9.455256,...,0,1,2,0.5,1,0,0,86.066667,1,1.0
4,6.204958,5.172111,8.881671,5.861609,8.530361,6.671294,11.724683,6.046692,7.401715,10.481299,...,1,0,3,0.0,0,1,1,8.066667,1,1.0


In [25]:
unlabelled_data['Label'] = predicted_labels
unlabelled_data['Confidence'] = associated_probabilities

unlabelled_data

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,Size,Radio Therapy,Chemotherapy,Hormone Therapy,Neoplasm Histologic Grade,Cellularity,Surgery-breast conserving,Surgery-mastectomy,Label,Confidence
0,10.047059,7.505424,9.729606,5.451007,8.474830,6.412419,5.899440,7.069394,7.100058,9.102318,...,10.0,1,0,1,3.0,1.0,1.0,0.0,0,0.713222
1,10.404685,6.815637,10.334979,5.488309,9.994894,6.525927,5.585357,6.071653,7.810600,9.431167,...,25.0,1,1,1,2.0,0.5,0.0,1.0,0,0.578509
2,10.793832,7.720952,9.276507,5.478224,8.184773,5.949741,5.743395,7.244882,7.781876,8.716918,...,10.0,1,1,1,2.0,0.5,1.0,0.0,0,0.747164
3,10.440667,5.592522,8.613192,5.436625,8.210389,6.203913,5.852012,6.219653,7.560101,9.600030,...,29.0,1,0,1,2.0,0.5,1.0,0.0,1,0.535473
4,12.521038,5.325554,10.678267,5.623786,7.786509,6.153012,5.502281,7.278257,7.674681,10.367760,...,16.0,1,0,1,3.0,0.5,1.0,0.0,0,0.592202
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1163,11.628490,5.570690,10.475695,6.032211,9.944405,5.865408,5.703147,6.649948,7.272166,9.750208,...,25.0,0,0,1,3.0,1.0,0.0,1.0,0,0.768499
1164,10.879891,6.431113,10.219154,5.435795,9.224122,5.699195,5.825643,6.404899,7.385644,9.271953,...,45.0,0,0,1,3.0,1.0,0.0,1.0,0,0.652164
1165,9.591235,7.984515,9.935179,5.605596,9.799519,5.808704,5.905282,6.491419,7.865526,9.741103,...,25.0,1,0,1,3.0,1.0,1.0,0.0,1,0.501235
1166,11.055114,8.282737,9.892589,5.753274,8.687667,5.475813,5.587906,6.830579,8.468221,9.482622,...,25.0,0,0,1,2.0,0.5,0.0,1.0,0,0.709410


In [26]:
combined_data = pd.concat([train_data, unlabelled_data], ignore_index=True)
combined_data.to_csv(PSEUDO_LABELLED_DATA_PATH, index=False)