# SSL: SCARF


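This notebook is adapted from the work published by the author of the [TS3L library](https://github.com/Alcoholrithm/TabularS3L).

Specifically, it implements [SCARF](https://openreview.net/forum?id=CuV_qYkmKb3) (Self-Supervised Contrastive Learning using Random Feature Corruption). The workflow has three stages:

- The model is first pre-trained contrastively on the labelled and unlabelled rows.
- It is then fine-tuned on the labelled rows.
- Finally, it is used to pseudo-label the unlabelled data.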

In [62]:
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import numpy as np

# Prepare the SCARFLightning Module
from ts3l.pl_modules import SCARFLightning
from ts3l.utils.scarf_utils import SCARFDataset
from ts3l.utils import TS3LDataModule
from ts3l.utils.scarf_utils import SCARFConfig
from ts3l.utils.embedding_utils import IdentityEmbeddingConfig
from ts3l.utils.backbone_utils import MLPBackboneConfig
from pytorch_lightning import Trainer


# Evaluation
from sklearn.metrics import accuracy_score
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, SequentialSampler

## Data loading

In [63]:
# Input CSV splits (all live in the same directory) and the output file for the
# pseudo-labelled dataset written at the end of this notebook.
SEQUENTIAL_DATA_DIR = "../data/sequential"
TRAIN_DATA_PATH = f"{SEQUENTIAL_DATA_DIR}/train_data.csv"
VAL_DATA_PATH = f"{SEQUENTIAL_DATA_DIR}/val_data.csv"
UNLABELLED_DATA_PATH = f"{SEQUENTIAL_DATA_DIR}/unlabelled_data.csv"

PSEUDO_LABELLED_DATA_PATH = "../data/ssl/pseudo_labelled_data_scarf.csv"

In [64]:
def get_dataframes(train_path, val_path, unlabelled_path, with_clinical=False,
                   n_numerical=21):
    """Load the train/val/unlabelled CSVs and split the features into groups.

    The bookkeeping/survival columns ('DssTime', 'Event', 'auto_id') are
    removed from the labelled splits, and 'auto_id' from the unlabelled one.

    Args:
        train_path, val_path, unlabelled_path: Anything ``pd.read_csv``
            accepts (a file path or a file-like object).
        with_clinical: If False, the clinical columns are dropped from all
            three frames and ``categorical_cols`` is returned empty. If True,
            they are kept and reported as categorical columns.
        n_numerical: Number of leading feature columns treated as numerical
            (the gene-expression columns plus Age in this dataset). 'Size' is
            always appended to the numerical list as well.

    Returns:
        (train_df, val_df, unlabelled_df, numerical_cols, categorical_cols)

    Raises:
        KeyError: if a column that is dropped/selected is missing
            (propagated from pandas).
        ValueError: if ``with_clinical`` is True and one of the expected
            clinical columns is absent from the training data.
    """
    # Load the data
    train_df = pd.read_csv(train_path)
    val_df = pd.read_csv(val_path)
    unlabelled_df = pd.read_csv(unlabelled_path)

    # Drop the columns that are not model inputs
    train_df = train_df.drop(columns=['DssTime', 'Event', 'auto_id'])
    val_df = val_df.drop(columns=['DssTime', 'Event', 'auto_id'])
    unlabelled_df = unlabelled_df.drop(columns=['auto_id'])

    # Numerical cols: the first `n_numerical` columns (Gene + Age), plus Size.
    numerical_cols = train_df.columns[:n_numerical].tolist()
    numerical_cols.append('Size')
    # Clinical cols: whatever follows the numerical block, excluding Label/Size.
    clinical_cols = train_df.drop(columns=['Label', 'Size']).columns[n_numerical:].tolist()

    if not with_clinical:
        train_df = train_df.drop(columns=clinical_cols)
        val_df = val_df.drop(columns=clinical_cols)
        unlabelled_df = unlabelled_df.drop(columns=clinical_cols)
        categorical_cols = []
    else:
        # Explicit list (and order) of the clinical features used as categoricals.
        categorical_cols = ['Chemotherapy', 'Menopausal State', 'Radio Therapy',
                            'Hormone Therapy', 'Surgery-breast conserving',
                            'Surgery-mastectomy', 'Neoplasm Histologic Grade',
                            'Cellularity']
        # Fail early with a clear message instead of deep inside the dataset code.
        missing = [col for col in categorical_cols if col not in train_df.columns]
        if missing:
            raise ValueError(f'Missing clinical columns in training data: {missing}')

    print(f'Train data shape: {train_df.shape}')
    print(f'Val data shape: {val_df.shape}')
    print(f'Unlabelled data shape: {unlabelled_df.shape}')
    print(f'Numerical columns: {numerical_cols}')
    if with_clinical:
        print(f'Categorical columns: {categorical_cols}')
    return train_df, val_df, unlabelled_df, numerical_cols, categorical_cols

In [65]:
# Load all three splits, keeping the clinical (categorical) features.
train_data, val_data, unlabelled_data, numerical_cols, categorical_cols = get_dataframes(
    train_path=TRAIN_DATA_PATH,
    val_path=VAL_DATA_PATH,
    unlabelled_path=UNLABELLED_DATA_PATH,
    with_clinical=True,
)

Train data shape: (372, 31)
Val data shape: (93, 31)
Unlabelled data shape: (1168, 30)
Numerical columns: ['ESR1', 'PGR', 'ERBB2', 'MKI67', 'PLAU', 'ELAVL1', 'EGFR', 'BTRC', 'FBXO6', 'SHMT2', 'KRAS', 'SRPK2', 'YWHAQ', 'PDHA1', 'EWSR1', 'ZDHHC17', 'ENO1', 'DBN1', 'PLK1', 'GSK3B', 'Age', 'Size']
Categorical columns: ['Chemotherapy', 'Menopausal State', 'Radio Therapy', 'Hormone Therapy', 'Surgery-breast conserving', 'Surgery-mastectomy', 'Neoplasm Histologic Grade', 'Cellularity']


In [66]:
# Peek at the first five labelled rows.
train_data.iloc[:5]

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,Menopausal State,Size,Radio Therapy,Chemotherapy,Hormone Therapy,Neoplasm Histologic Grade,Cellularity,Surgery-breast conserving,Surgery-mastectomy,Label
0,11.445577,5.383985,10.523773,5.878819,7.376048,6.224054,5.730411,6.937679,7.314421,9.496653,...,1,43,1,0,0,1,0.5,1,0,1
1,11.728241,6.853981,9.093163,6.189151,9.775104,5.891141,5.644219,6.445037,7.500848,10.14059,...,1,43,1,0,1,3,1.0,0,1,0
2,11.122611,5.322075,11.719898,6.121027,8.733519,6.321382,6.011898,6.702699,7.682648,9.736608,...,1,15,0,0,0,2,1.0,0,1,0
3,11.246158,7.200805,10.48821,5.921079,9.480574,5.598995,5.923112,6.411695,7.998787,10.009468,...,1,20,1,0,1,3,0.5,1,0,0
4,9.744005,5.481691,11.031849,5.826471,8.732274,5.873227,5.495358,6.243658,7.71902,9.32783,...,1,30,1,0,1,2,1.0,1,0,1


In [67]:
# Labelled pool used for training and early stopping; the separate validation
# file is held out entirely and serves as the test set.
full_X_train, full_y_train = train_data.drop(columns='Label'), train_data['Label']
X_test, y_test = val_data.drop(columns='Label'), val_data['Label']

## Data configuration

In [68]:
# Carve a stratified 20% validation split out of the labelled pool; it is the
# `valid_ds` monitored by early stopping in both training phases.
X_train, X_val, y_train, y_val = train_test_split(
    full_X_train, full_y_train,
    stratify=full_y_train,
    test_size=0.2,
    random_state=42,
)

for split_name, split_X in (('Training', X_train), ('Validation', X_val)):
    print(f'{split_name} data shape: {split_X.shape}')

Training data shape: (297, 30)
Validation data shape: (75, 30)


## Model configuration

In [69]:
# Optimisation budget (shared by both training phases).
batch_size = 128
max_epochs = 20

# Phase-1 (contrastive) settings: corruption rate handed to SCARF and the
# width of the pre-training head.
corruption_rate = 0.6
pretraining_head_dim = 1024

# Supervised head: two output classes; the metric name is what shows up as
# train_/val_accuracy_score in the logs.
metric = "accuracy_score"
output_dim = 2
head_depth = 2
dropout_rate = 0.05

# Identity embedding feeding an MLP backbone; the backbone input width is taken
# from the embedding's output_dim (presumably equal to input_dim — TODO confirm).
input_dim = X_train.shape[1]
embedding_config = IdentityEmbeddingConfig(input_dim=input_dim)
backbone_config = MLPBackboneConfig(input_dim=embedding_config.output_dim)

In [70]:
config = SCARFConfig(
    # Task definition
    task="classification",
    output_dim=output_dim,
    loss_fn="CrossEntropyLoss",
    metric=metric,
    metric_hparams={},
    # Architecture
    embedding_config=embedding_config,
    backbone_config=backbone_config,
    pretraining_head_dim=pretraining_head_dim,
    head_depth=head_depth,
    dropout_rate=dropout_rate,
    # Phase-1 feature corruption
    corruption_rate=corruption_rate,
)

## Model training: 1st phase

In [71]:
### First Phase Learning
# Phase-1 (contrastive pre-training) datasets. `config` is passed so the dataset
# can apply the corruption settings; the unlabelled rows are handed to the
# training dataset as extra pre-training data.
feature_groups = dict(continuous_cols=numerical_cols, category_cols=categorical_cols)

train_ds = SCARFDataset(X_train, config=config, unlabeled_data=unlabelled_data, **feature_groups)
valid_ds = SCARFDataset(X_val, config=config, **feature_groups)

datamodule = TS3LDataModule(train_ds, valid_ds, batch_size=batch_size, train_sampler="random", n_jobs=8)


In [72]:
from pytorch_lightning.callbacks import EarlyStopping

# Stop pre-training once val_loss has not decreased for 3 consecutive epochs.
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=3, verbose=True)

pl_scarf = SCARFLightning(config)

trainer = Trainer(
    accelerator='cpu',
    callbacks=[early_stopping],
    max_epochs=max_epochs,
    num_sanity_val_steps=2,
)
trainer.fit(pl_scarf, datamodule)

Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name             | Type             | Params | Mode 
--------------------------------------------------------------
0 | task_loss_fn     | CrossEntropyLoss | 0      | train
1 | contrastive_loss | NTXentLoss       | 0      | train
2 | model            | SCARF            | 1.2 M  | train
--------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.820     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode


                                                                           

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (12) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 12/12 [00:15<00:00,  0.75it/s, v_num=128, train_loss=5.360, val_loss=4.550]

Metric val_loss improved. New best score: 4.547


Epoch 1: 100%|██████████| 12/12 [00:17<00:00,  0.69it/s, v_num=128, train_loss=5.220, val_loss=4.260]

Metric val_loss improved by 0.282 >= min_delta = 0.0. New best score: 4.264


Epoch 4: 100%|██████████| 12/12 [00:15<00:00,  0.78it/s, v_num=128, train_loss=5.170, val_loss=4.570]

Monitored metric val_loss did not improve in the last 3 records. Best score: 4.264. Signaling Trainer to stop.


Epoch 4: 100%|██████████| 12/12 [00:15<00:00,  0.77it/s, v_num=128, train_loss=5.170, val_loss=4.570]


## Model training: 2nd phase

In [73]:
### Second Phase Learning

# Switch the Lightning module from contrastive pre-training to the supervised
# phase, so the next fit() trains against the task loss configured in `config`
# (loss_fn="CrossEntropyLoss"). Whether the encoder is frozen is decided by
# TS3L's default — NOTE(review): confirm in SCARFLightning.set_second_phase.
pl_scarf.set_second_phase()


In [74]:
# Phase-2 (supervised fine-tuning) datasets: labelled rows only.
second_phase_kwargs = dict(
    continuous_cols=numerical_cols,
    category_cols=categorical_cols,
    is_second_phase=True,
)
train_ds = SCARFDataset(X_train, y_train.values, **second_phase_kwargs)
valid_ds = SCARFDataset(X_val, y_val.values, **second_phase_kwargs)

# "weighted" training sampler (presumably class-balanced sampling — TODO confirm in TS3L).
datamodule = TS3LDataModule(train_ds, valid_ds, batch_size=batch_size, train_sampler="weighted", n_jobs=8)


In [75]:
# Phase 2 gets its own EarlyStopping. The callback keeps state across fits
# (best_score, wait_count), so reusing the phase-1 instance would compare the
# supervised loss against the best *contrastive* loss from pre-training — on a
# different scale — and could stop fine-tuning prematurely or late.
finetune_early_stopping = EarlyStopping(
    monitor='val_loss',  # Metric to monitor
    patience=3,          # Epochs without improvement before stopping
    verbose=True,
    mode='min',
)

trainer = Trainer(
    accelerator='cpu',
    max_epochs=max_epochs,
    num_sanity_val_steps=2,
    callbacks=[finetune_early_stopping],
)

trainer.fit(pl_scarf, datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name             | Type             | Params | Mode 
--------------------------------------------------------------
0 | task_loss_fn     | CrossEntropyLoss | 0      | train
1 | contrastive_loss | NTXentLoss       | 0      | train
2 | model            | SCARF            | 1.2 M  | train
--------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.820     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode


                                                                           

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 3/3 [00:00<00:00,  8.61it/s, v_num=129, train_loss=0.857, train_accuracy_score=0.549, val_accuracy_score=0.560, val_loss=0.709]

Metric val_loss improved by 3.555 >= min_delta = 0.0. New best score: 0.709


Epoch 1: 100%|██████████| 3/3 [00:00<00:00,  5.43it/s, v_num=129, train_loss=0.776, train_accuracy_score=0.576, val_accuracy_score=0.627, val_loss=0.689]

Metric val_loss improved by 0.020 >= min_delta = 0.0. New best score: 0.689


Epoch 2: 100%|██████████| 3/3 [00:00<00:00,  5.74it/s, v_num=129, train_loss=0.869, train_accuracy_score=0.566, val_accuracy_score=0.587, val_loss=0.681]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.681


Epoch 3: 100%|██████████| 3/3 [00:00<00:00,  5.71it/s, v_num=129, train_loss=0.846, train_accuracy_score=0.522, val_accuracy_score=0.627, val_loss=0.675]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.675


Epoch 4: 100%|██████████| 3/3 [00:00<00:00,  5.61it/s, v_num=129, train_loss=0.896, train_accuracy_score=0.485, val_accuracy_score=0.573, val_loss=0.671]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.671


Epoch 5: 100%|██████████| 3/3 [00:00<00:00,  5.48it/s, v_num=129, train_loss=0.805, train_accuracy_score=0.572, val_accuracy_score=0.587, val_loss=0.670]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.670


Epoch 7: 100%|██████████| 3/3 [00:00<00:00,  5.08it/s, v_num=129, train_loss=0.791, train_accuracy_score=0.599, val_accuracy_score=0.587, val_loss=0.668]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.668


Epoch 8: 100%|██████████| 3/3 [00:00<00:00,  5.77it/s, v_num=129, train_loss=0.773, train_accuracy_score=0.566, val_accuracy_score=0.547, val_loss=0.666]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.666


Epoch 9: 100%|██████████| 3/3 [00:00<00:00,  5.48it/s, v_num=129, train_loss=0.815, train_accuracy_score=0.556, val_accuracy_score=0.560, val_loss=0.664]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.664


Epoch 10: 100%|██████████| 3/3 [00:00<00:00,  5.38it/s, v_num=129, train_loss=0.839, train_accuracy_score=0.596, val_accuracy_score=0.573, val_loss=0.664]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.664


Epoch 12: 100%|██████████| 3/3 [00:00<00:00,  5.81it/s, v_num=129, train_loss=0.774, train_accuracy_score=0.593, val_accuracy_score=0.587, val_loss=0.662]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.662


Epoch 13: 100%|██████████| 3/3 [00:00<00:00,  5.79it/s, v_num=129, train_loss=0.775, train_accuracy_score=0.586, val_accuracy_score=0.587, val_loss=0.658]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.658


Epoch 15: 100%|██████████| 3/3 [00:00<00:00,  5.78it/s, v_num=129, train_loss=0.814, train_accuracy_score=0.596, val_accuracy_score=0.613, val_loss=0.656]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.656


Epoch 18: 100%|██████████| 3/3 [00:00<00:00,  5.72it/s, v_num=129, train_loss=0.793, train_accuracy_score=0.566, val_accuracy_score=0.613, val_loss=0.657]

Monitored metric val_loss did not improve in the last 3 records. Best score: 0.656. Signaling Trainer to stop.


Epoch 18: 100%|██████████| 3/3 [00:00<00:00,  5.24it/s, v_num=129, train_loss=0.793, train_accuracy_score=0.566, val_accuracy_score=0.613, val_loss=0.657]


## Evaluation

In [76]:
# Inference dataset for the held-out split (second-phase mode, no labels).
test_ds = SCARFDataset(
    X_test,
    is_second_phase=True,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols,
)

# Sequential order keeps predictions aligned with y_test.
test_dl = DataLoader(
    test_ds,
    batch_size=batch_size,
    sampler=SequentialSampler(test_ds),
    shuffle=False,
    num_workers=8,
)


In [78]:
# trainer.predict returns one logits tensor per batch. Concatenate them into a
# single (n_samples, output_dim) tensor and convert to class probabilities.
# reshape instead of squeeze: squeeze would also drop the batch axis when there
# is a single test row, and softmax(dim=1) would then fail.
batch_logits = trainer.predict(pl_scarf, test_dl)
test_probs = F.softmax(torch.cat([out.cpu() for out in batch_logits]).reshape(-1, output_dim), dim=1)

accuracy = accuracy_score(y_test, test_probs.argmax(dim=1).numpy())

print("Accuracy %.2f" % accuracy)

Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 369.18it/s]
Accuracy 0.63


## Label prediction

In [79]:
# Inference dataset for the unlabelled pool, built the same way as `test_ds`
# (second-phase mode, no `config`). The first-phase constructor is the
# contrastive pre-training dataset (it needs `config` for feature corruption)
# and is not meant for plain inference.
unlabelled_ds = SCARFDataset(
    X=unlabelled_data,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols,
    is_second_phase=True)

unlabelled_dl = DataLoader(
    unlabelled_ds,
    batch_size=batch_size,
    shuffle=False,
    # Row order must be preserved: predictions are attached to unlabelled_data by position.
    sampler=SequentialSampler(unlabelled_ds),
    num_workers=8,
)

In [80]:
# Concatenate per-batch logits into (n_rows, output_dim) and turn them into
# probabilities. reshape instead of squeeze so a single-row batch keeps its batch axis.
batch_logits = trainer.predict(pl_scarf, unlabelled_dl)
preds = F.softmax(torch.cat([out.cpu() for out in batch_logits]).reshape(-1, output_dim), dim=1)

# Pseudo-label = most likely class; confidence = its probability.
top_prob, top_class = preds.max(dim=1)
predicted_labels = top_class.numpy()
associated_probabilities = top_prob.numpy()

pd.DataFrame({'label': predicted_labels, 'confidence': associated_probabilities})

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 96.02it/s]


Unnamed: 0,label,confidence
0,0,0.592871
1,1,0.516037
2,0,0.709010
3,0,0.524126
4,0,0.669345
...,...,...
1163,0,0.761280
1164,0,0.612800
1165,1,0.588077
1166,0,0.695925


## Data export

In [81]:
# Reload the raw CSVs: get_dataframes dropped the id and survival columns,
# but the export keeps them (and sorts on auto_id).
unlabelled_data = pd.read_csv(UNLABELLED_DATA_PATH)
# Ground-truth labels get full confidence.
train_data = pd.read_csv(TRAIN_DATA_PATH).assign(Confidence=1.0)

train_data.head()

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,Hormone Therapy,Neoplasm Histologic Grade,Cellularity,Surgery-breast conserving,Surgery-mastectomy,Label,DssTime,Event,auto_id,Confidence
0,11.445577,5.383985,10.523773,5.878819,7.376048,6.224054,5.730411,6.937679,7.314421,9.496653,...,0,1,0.5,1,0,1,52.733333,1,1271,1.0
1,11.728241,6.853981,9.093163,6.189151,9.775104,5.891141,5.644219,6.445037,7.500848,10.14059,...,1,3,1.0,0,1,0,61.1,1,1050,1.0
2,11.122611,5.322075,11.719898,6.121027,8.733519,6.321382,6.011898,6.702699,7.682648,9.736608,...,0,2,1.0,0,1,0,94.033333,1,1300,1.0
3,11.246158,7.200805,10.48821,5.921079,9.480574,5.598995,5.923112,6.411695,7.998787,10.009468,...,1,3,0.5,1,0,0,118.133333,1,1156,1.0
4,9.744005,5.481691,11.031849,5.826471,8.732274,5.873227,5.495358,6.243658,7.71902,9.32783,...,1,2,1.0,1,0,1,27.066667,1,1458,1.0


In [82]:
# Attach the model's pseudo-labels and their probabilities, row by row
# (the unlabelled loader used a sequential sampler, so positions line up).
unlabelled_data = unlabelled_data.assign(
    Label=predicted_labels,
    Confidence=associated_probabilities,
)

unlabelled_data

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,Radio Therapy,Chemotherapy,Hormone Therapy,Neoplasm Histologic Grade,Cellularity,Surgery-breast conserving,Surgery-mastectomy,auto_id,Label,Confidence
0,10.047059,7.505424,9.729606,5.451007,8.474830,6.412419,5.899440,7.069394,7.100058,9.102318,...,1,0,1,3.0,1.0,1.0,0.0,2000,0,0.592871
1,10.404685,6.815637,10.334979,5.488309,9.994894,6.525927,5.585357,6.071653,7.810600,9.431167,...,1,1,1,2.0,0.5,0.0,1.0,2001,1,0.516037
2,10.793832,7.720952,9.276507,5.478224,8.184773,5.949741,5.743395,7.244882,7.781876,8.716918,...,1,1,1,2.0,0.5,1.0,0.0,2002,0,0.709010
3,10.440667,5.592522,8.613192,5.436625,8.210389,6.203913,5.852012,6.219653,7.560101,9.600030,...,1,0,1,2.0,0.5,1.0,0.0,2003,0,0.524126
4,12.521038,5.325554,10.678267,5.623786,7.786509,6.153012,5.502281,7.278257,7.674681,10.367760,...,1,0,1,3.0,0.5,1.0,0.0,2004,0,0.669345
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1163,11.628490,5.570690,10.475695,6.032211,9.944405,5.865408,5.703147,6.649948,7.272166,9.750208,...,0,0,1,3.0,1.0,0.0,1.0,3163,0,0.761280
1164,10.879891,6.431113,10.219154,5.435795,9.224122,5.699195,5.825643,6.404899,7.385644,9.271953,...,0,0,1,3.0,1.0,0.0,1.0,3164,0,0.612800
1165,9.591235,7.984515,9.935179,5.605596,9.799519,5.808704,5.905282,6.491419,7.865526,9.741103,...,1,0,1,3.0,1.0,1.0,0.0,3165,1,0.588077
1166,11.055114,8.282737,9.892589,5.753274,8.687667,5.475813,5.587906,6.830579,8.468221,9.482622,...,0,0,1,2.0,0.5,0.0,1.0,3166,0,0.695925


In [83]:
# Ground-truth rows first, pseudo-labelled rows after; persist ordered by auto_id.
combined_data = pd.concat((train_data, unlabelled_data), ignore_index=True)
export_order = combined_data.sort_values(by='auto_id')
export_order.to_csv(PSEUDO_LABELLED_DATA_PATH, index=False)

In [84]:
# Invariants for the exported frame: no rows lost or duplicated in the concat,
# and ids do not collide between labelled and pseudo-labelled rows.
assert len(combined_data) == len(train_data) + len(unlabelled_data)
assert combined_data['auto_id'].is_unique, 'auto_id collides between labelled and unlabelled rows'
combined_data.shape

(1540, 35)