# SSL: VIME

This work is based on the one published by the author of the [TS3L library](https://github.com/Alcoholrithm/TabularS3L).

Specifically, this code is implementing [VIME](https://proceedings.neurips.cc/paper_files/paper/2020/file/7d97667a3e056acab9aaf653807b4a03-Paper.pdf)

In [58]:
# Prepare the VIMELightning Module
from ts3l.pl_modules import VIMELightning
from ts3l.utils.vime_utils import VIMEDataset
from ts3l.utils import TS3LDataModule, get_category_cardinality
from ts3l.utils.vime_utils import VIMEConfig
from ts3l.utils.embedding_utils import IdentityEmbeddingConfig
from ts3l.utils.backbone_utils import MLPBackboneConfig
from pytorch_lightning import Trainer
import numpy as np
from ts3l.utils.vime_utils import VIMESecondPhaseCollateFN
from sklearn.metrics import accuracy_score
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, SequentialSampler
from sklearn.model_selection import train_test_split

import pandas as pd

## Data loading

In [59]:
TRAIN_DATA_PATH = "../data/sequential/train_data.csv"
VAL_DATA_PATH = "../data/sequential/val_data.csv"
UNLABELLED_DATA_PATH = "../data/sequential/unlabelled_data.csv"

PSEUDO_LABELLED_DATA_PATH = "../data/ssl/pseudo_labelled_data_vime.csv"

In [60]:
def get_dataframes(train_path, val_path, unlabelled_path, with_clinical=False):
    # Load the data
    train_df = pd.read_csv(train_path)
    unlabelled_df = pd.read_csv(unlabelled_path)
    val_df = pd.read_csv(val_path)

    # Drop the columns that are not needed
    train_df = train_df.drop(columns=['DssTime', 'Event', 'auto_id'])
    val_df = val_df.drop(columns=['DssTime', 'Event', 'auto_id'])
    unlabelled_df = unlabelled_df.drop(columns=['auto_id'])

    # Extract numerical and categorical columns
    # Numerical cols: Gene + Age
    numerical_cols = train_df.columns[:21].tolist()
    # But also Size
    numerical_cols.append('Size')
    # Categorical cols: Clinical
    categorical_cols = train_df.drop(columns=['Label', 'Size']).columns[21:].tolist()
    if not with_clinical:
        train_df = train_df.drop(columns=categorical_cols)
        val_df = val_df.drop(columns=categorical_cols)
        unlabelled_df = unlabelled_df.drop(columns=categorical_cols)
        categorical_cols = []
    else:
        categorical_cols = ['Chemotherapy', 'Menopausal State', 'Radio Therapy', 'Hormone Therapy', 'Surgery-breast conserving', 'Surgery-mastectomy']
        # The model has problems with these columns
        not_cols = ['Neoplasm Histologic Grade', 'Cellularity']
        train_df = train_df.drop(columns=not_cols)
        val_df = val_df.drop(columns=not_cols)
        unlabelled_df = unlabelled_df.drop(columns=not_cols)

    print(f'Train data shape: {train_df.shape}')
    print(f'Val data shape: {val_df.shape}')
    print(f'Unlabelled data shape: {unlabelled_df.shape}')
    print(f'Numerical columns: {numerical_cols}')
    if with_clinical:
        print(f'Categorical columns: {categorical_cols}')
    return train_df, val_df, unlabelled_df, numerical_cols, categorical_cols

In [61]:
train_data, val_data, unlabelled_data, numerical_cols, categorical_cols = get_dataframes(
    TRAIN_DATA_PATH,
    VAL_DATA_PATH,
    UNLABELLED_DATA_PATH,
    with_clinical=False)

Train data shape: (372, 23)
Val data shape: (93, 23)
Unlabelled data shape: (1168, 22)
Numerical columns: ['ESR1', 'PGR', 'ERBB2', 'MKI67', 'PLAU', 'ELAVL1', 'EGFR', 'BTRC', 'FBXO6', 'SHMT2', 'KRAS', 'SRPK2', 'YWHAQ', 'PDHA1', 'EWSR1', 'ZDHHC17', 'ENO1', 'DBN1', 'PLK1', 'GSK3B', 'Age', 'Size']


In [62]:
train_data.head()

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,PDHA1,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size,Label
0,11.445577,5.383985,10.523773,5.878819,7.376048,6.224054,5.730411,6.937679,7.314421,9.496653,...,8.472817,10.014916,8.037095,11.556363,9.697175,6.687842,8.197441,63,43,1
1,11.728241,6.853981,9.093163,6.189151,9.775104,5.891141,5.644219,6.445037,7.500848,10.14059,...,7.908106,9.290438,7.124771,12.319678,8.68731,6.408426,8.303559,74,43,0
2,11.122611,5.322075,11.719898,6.121027,8.733519,6.321382,6.011898,6.702699,7.682648,9.736608,...,8.242456,9.763899,8.006279,12.571911,9.236787,6.921407,8.155855,52,15,0
3,11.246158,7.200805,10.48821,5.921079,9.480574,5.598995,5.923112,6.411695,7.998787,10.009468,...,8.216202,9.56626,7.362756,12.310459,9.304842,6.561758,8.985041,62,20,0
4,9.744005,5.481691,11.031849,5.826471,8.732274,5.873227,5.495358,6.243658,7.71902,9.32783,...,8.311747,8.969407,7.349488,11.630288,8.838334,6.456821,7.996271,55,30,1


In [63]:
val_data.head()

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,PDHA1,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size,Label
0,6.407398,5.332991,11.028405,5.570715,7.88406,6.170632,6.524852,6.087229,7.607153,9.709837,...,8.532754,9.604273,7.083617,12.200205,8.89117,6.152993,7.93927,66,35,0
1,5.359227,5.233657,11.295086,5.467575,7.806072,7.934447,6.342535,6.249445,6.337625,9.996191,...,8.390242,10.186778,8.099785,13.601299,9.575482,6.218334,7.666581,82,45,1
2,5.874314,5.226781,10.052359,6.345439,8.733782,7.140753,7.271026,6.408314,7.436915,11.848892,...,8.245571,9.444395,6.876545,13.17686,9.369397,7.176979,8.714151,67,45,1
3,11.050258,5.361671,9.424651,6.489388,8.299958,6.422778,5.70652,6.598701,7.400819,10.87385,...,8.034139,9.131501,7.746938,12.840906,9.30558,7.096915,8.55851,58,21,0
4,11.424278,5.372667,10.864729,6.410957,6.787171,6.926255,5.685226,7.046461,6.94662,9.987552,...,8.297451,9.849713,7.425202,11.245177,9.321138,6.654876,8.72385,54,30,1


In [64]:
unlabelled_data.head()

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,YWHAQ,PDHA1,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size
0,10.047059,7.505424,9.729606,5.451007,8.47483,6.412419,5.89944,7.069394,7.100058,9.102318,...,12.124106,8.234815,9.674483,7.375406,12.024185,7.757081,6.45558,8.385158,43.19,10.0
1,10.404685,6.815637,10.334979,5.488309,9.994894,6.525927,5.585357,6.071653,7.8106,9.431167,...,12.712623,8.242056,9.15434,7.585137,11.844377,8.042064,6.036497,7.464064,47.68,25.0
2,10.793832,7.720952,9.276507,5.478224,8.184773,5.949741,5.743395,7.244882,7.781876,8.716918,...,12.102056,8.105696,9.1365,7.730418,10.790484,7.599985,5.95331,7.229434,56.45,10.0
3,10.440667,5.592522,8.613192,5.436625,8.210389,6.203913,5.852012,6.219653,7.560101,9.60003,...,11.839824,8.376467,9.257898,7.743428,11.21288,7.770462,6.194253,7.289188,89.08,29.0
4,12.521038,5.325554,10.678267,5.623786,7.786509,6.153012,5.502281,7.278257,7.674681,10.36776,...,11.363913,9.031051,9.46249,7.472229,12.95541,8.591434,6.64576,8.737746,86.41,16.0


In [65]:
len(categorical_cols) + len(numerical_cols)

22

## Data configuration

In [66]:
full_X_train = train_data.drop(columns=['Label'])
full_y_train = train_data['Label']

# The validation data will serve as the test data
X_test = val_data.drop(columns=['Label'])
y_test = val_data['Label']

In [67]:
# Split the train_data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(
    full_X_train,
    full_y_train,
    test_size=0.2,
    random_state=42,
    stratify=full_y_train)

print(f'Training data shape: {X_train.shape}')
print(f'Validation data shape: {X_val.shape}')

Training data shape: (297, 22)
Validation data shape: (75, 22)


## Model configuration

In [68]:
metric = "accuracy_score"
input_dim = X_train.shape[1]
# predictor_dim = 1024
predictor_dim = 256
# alpha: hyper-parameter to control two self-supervised loss
alpha1 = 2.0
alpha2 = 2.0
# beta is introduced to control the trade-off between the supervised and unsupervised losses
beta = 1.0
#  K: number of augmented data
K = 2
# The probability (bernoulli) of masking a feature
# This will increase the difficulty of the reconstruction task
p_m = 0.2
# batch_size = 128
batch_size = 32
max_epochs = 20

embedding_config = IdentityEmbeddingConfig(input_dim = input_dim)
backbone_config = MLPBackboneConfig(input_dim = embedding_config.output_dim)
cardinality = get_category_cardinality(X_train, categorical_cols)

In [69]:
config = VIMEConfig( 
                    task="classification",
                    loss_fn="CrossEntropyLoss",
                    metric=metric,
                    metric_hparams={},
                    embedding_config=embedding_config,
                    backbone_config=backbone_config,
                    predictor_dim=predictor_dim,
                    output_dim=2,
                    alpha1=alpha1,
                    alpha2=alpha2, 
                    beta=beta,
                    K=K,
                    p_m = p_m,
                    cat_cardinality=cardinality,
                    num_continuous=len(numerical_cols),
)


In [70]:
train_ds = VIMEDataset(
    X=X_train,
    unlabeled_data=unlabelled_data,
    config=config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols
)

valid_ds = VIMEDataset(
    X=X_train,
    config=config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols
)

datamodule = TS3LDataModule(train_ds, valid_ds, batch_size, train_sampler='random', n_jobs=8)

## Model training: 1st phase

In [71]:
from pytorch_lightning.callbacks import EarlyStopping

# Define early stopping callback
early_stopping = EarlyStopping(
    monitor='val_loss',  # Metric to monitor
    patience=3,          # Number of epochs with no improvement after which training will be stopped
    verbose=True,
    mode='min'           # Mode can be 'min' or 'max' depending on the metric
)

trainer = Trainer(
    accelerator = 'cpu',
    max_epochs = max_epochs,
    num_sanity_val_steps = 2,
    callbacks = [early_stopping],
)

pl_vime = VIMELightning(config)
trainer.fit(pl_vime, datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 42

  | Name                        | Type             | Params | Mode 
-------------------------------------------------------------------------
0 | task_loss_fn                | CrossEntropyLoss | 0      | train
1 | mask_loss_fn                | BCELoss          | 0      | train
2 | categorical_feature_loss_fn | CrossEntropyLoss | 0      | train
3 | continuous_feature_loss_fn  | MSELoss          | 0      | train
4 | consistency_loss_fn         | MSELoss          | 0      | train
5 | model                       | VIME             | 124 K  | train
-------------------------------------------------------------------------
124 K     Trainable params
0         Non-trainable params
124 K     Total params
0.499     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode


                                                                           

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (46) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 46/46 [00:00<00:00, 84.01it/s, v_num=130, train_loss=556.0, val_loss=571.0]

Metric val_loss improved. New best score: 570.792


Epoch 1: 100%|██████████| 46/46 [00:00<00:00, 60.60it/s, v_num=130, train_loss=502.0, val_loss=495.0]

Metric val_loss improved by 75.719 >= min_delta = 0.0. New best score: 495.073


Epoch 2: 100%|██████████| 46/46 [00:00<00:00, 65.52it/s, v_num=130, train_loss=447.0, val_loss=430.0]

Metric val_loss improved by 64.823 >= min_delta = 0.0. New best score: 430.250


Epoch 3: 100%|██████████| 46/46 [00:00<00:00, 64.46it/s, v_num=130, train_loss=390.0, val_loss=368.0] 

Metric val_loss improved by 61.880 >= min_delta = 0.0. New best score: 368.371


Epoch 4: 100%|██████████| 46/46 [00:00<00:00, 61.08it/s, v_num=130, train_loss=331.0, val_loss=312.0]

Metric val_loss improved by 56.194 >= min_delta = 0.0. New best score: 312.177


Epoch 5: 100%|██████████| 46/46 [00:00<00:00, 65.20it/s, v_num=130, train_loss=277.0, val_loss=265.0] 

Metric val_loss improved by 47.239 >= min_delta = 0.0. New best score: 264.939


Epoch 6: 100%|██████████| 46/46 [00:00<00:00, 63.35it/s, v_num=130, train_loss=232.0, val_loss=231.0]

Metric val_loss improved by 33.676 >= min_delta = 0.0. New best score: 231.262


Epoch 7: 100%|██████████| 46/46 [00:00<00:00, 65.50it/s, v_num=130, train_loss=195.0, val_loss=218.0] 

Metric val_loss improved by 13.526 >= min_delta = 0.0. New best score: 217.736


Epoch 8: 100%|██████████| 46/46 [00:00<00:00, 66.78it/s, v_num=130, train_loss=173.0, val_loss=185.0] 

Metric val_loss improved by 32.307 >= min_delta = 0.0. New best score: 185.429


Epoch 10: 100%|██████████| 46/46 [00:00<00:00, 61.75it/s, v_num=130, train_loss=134.0, val_loss=152.0]

Metric val_loss improved by 33.762 >= min_delta = 0.0. New best score: 151.667


Epoch 11: 100%|██████████| 46/46 [00:00<00:00, 62.06it/s, v_num=130, train_loss=120.0, val_loss=139.0]

Metric val_loss improved by 12.289 >= min_delta = 0.0. New best score: 139.378


Epoch 13: 100%|██████████| 46/46 [00:00<00:00, 59.20it/s, v_num=130, train_loss=106.0, val_loss=127.0]

Metric val_loss improved by 12.589 >= min_delta = 0.0. New best score: 126.789


Epoch 14: 100%|██████████| 46/46 [00:00<00:00, 63.68it/s, v_num=130, train_loss=98.90, val_loss=125.0]

Metric val_loss improved by 1.783 >= min_delta = 0.0. New best score: 125.006


Epoch 16: 100%|██████████| 46/46 [00:00<00:00, 58.69it/s, v_num=130, train_loss=97.30, val_loss=118.0] 

Metric val_loss improved by 7.045 >= min_delta = 0.0. New best score: 117.962


Epoch 17: 100%|██████████| 46/46 [00:00<00:00, 64.07it/s, v_num=130, train_loss=94.60, val_loss=104.0] 

Metric val_loss improved by 13.578 >= min_delta = 0.0. New best score: 104.383


Epoch 19: 100%|██████████| 46/46 [00:00<00:00, 64.85it/s, v_num=130, train_loss=82.50, val_loss=117.0] 

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 46/46 [00:00<00:00, 63.99it/s, v_num=130, train_loss=82.50, val_loss=117.0]


## Model training: 2nd phase

In [72]:

print(unlabelled_data.shape)

threshold = 1164

f_unlabelled_data = unlabelled_data.head(threshold)
print(f_unlabelled_data.shape)

# Something's wrong with this rows, it make the model crash
unlabelled_data.tail(len(unlabelled_data) - threshold)

(1168, 22)
(1164, 22)


Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,YWHAQ,PDHA1,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size
1164,10.879891,6.431113,10.219154,5.435795,9.224122,5.699195,5.825643,6.404899,7.385644,9.271953,...,11.529199,8.334326,9.309069,7.048923,12.041769,8.572415,6.128115,7.68254,56.9,45.0
1165,9.591235,7.984515,9.935179,5.605596,9.799519,5.808704,5.905282,6.491419,7.865526,9.741103,...,12.011174,8.173189,9.294825,7.316246,12.540138,8.634905,6.089312,7.838041,43.1,25.0
1166,11.055114,8.282737,9.892589,5.753274,8.687667,5.475813,5.587906,6.830579,8.468221,9.482622,...,11.872388,8.245726,9.357164,7.354395,11.980115,8.250034,6.443083,8.319481,61.16,25.0
1167,10.696475,5.533486,10.227787,5.588965,10.212643,6.633792,5.504219,6.468054,8.391464,9.812163,...,11.701313,8.618193,9.742585,7.435703,12.639017,7.710563,6.684053,7.930635,60.02,20.0


In [73]:
pl_vime.set_second_phase()

train_ds = VIMEDataset(
    X_train,
    y_train.values,
    config,
    unlabeled_data=f_unlabelled_data,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols,
    is_second_phase=True)

valid_ds = VIMEDataset(
    X_val,
    y_val.values,
    config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols,
    is_second_phase=True)

datamodule = TS3LDataModule(
    train_ds,
    valid_ds,
    batch_size=batch_size,
    train_sampler="weighted",
    train_collate_fn=VIMESecondPhaseCollateFN()
)

In [74]:
trainer = Trainer(
        accelerator = 'cpu',
        max_epochs = max_epochs,
        num_sanity_val_steps = 2,
        callbacks = [early_stopping],
    )

trainer.fit(pl_vime, datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name                        | Type             | Params | Mode 
-------------------------------------------------------------------------
0 | task_loss_fn                | CrossEntropyLoss | 0      | train
1 | mask_loss_fn                | BCELoss          | 0      | train
2 | categorical_feature_loss_fn | CrossEntropyLoss | 0      | train
3 | continuous_feature_loss_fn  | MSELoss          | 0      | train
4 | consistency_loss_fn         | MSELoss          | 0      | train
5 | model                       | VIME             | 124 K  | train
-------------------------------------------------------------------------
105 K     Trainable params
19.7 K    Non-trainable params
124 K     Total params
0.499     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode


                                                                           

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (46) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 46/46 [00:01<00:00, 39.08it/s, v_num=131, train_loss=2.210, train_accuracy_score=0.530, val_accuracy_score=0.520, val_loss=0.705]

Metric val_loss improved by 103.678 >= min_delta = 0.0. New best score: 0.705


Epoch 1:   0%|          | 0/46 [00:00<?, ?it/s, v_num=131, train_loss=2.210, train_accuracy_score=0.530, val_accuracy_score=0.520, val_loss=0.705]         



Epoch 1: 100%|██████████| 46/46 [00:03<00:00, 13.19it/s, v_num=131, train_loss=1.490, train_accuracy_score=0.532, val_accuracy_score=0.533, val_loss=0.685]

Metric val_loss improved by 0.020 >= min_delta = 0.0. New best score: 0.685


Epoch 2:   0%|          | 0/46 [00:00<?, ?it/s, v_num=131, train_loss=1.490, train_accuracy_score=0.532, val_accuracy_score=0.533, val_loss=0.685]         



Epoch 2: 100%|██████████| 46/46 [00:01<00:00, 23.80it/s, v_num=131, train_loss=1.190, train_accuracy_score=0.537, val_accuracy_score=0.587, val_loss=0.674]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 0.674


Epoch 3:   0%|          | 0/46 [00:00<?, ?it/s, v_num=131, train_loss=1.190, train_accuracy_score=0.537, val_accuracy_score=0.587, val_loss=0.674]         



Epoch 3: 100%|██████████| 46/46 [00:01<00:00, 24.46it/s, v_num=131, train_loss=1.060, train_accuracy_score=0.583, val_accuracy_score=0.680, val_loss=0.650]

Metric val_loss improved by 0.024 >= min_delta = 0.0. New best score: 0.650


Epoch 4:   0%|          | 0/46 [00:00<?, ?it/s, v_num=131, train_loss=1.060, train_accuracy_score=0.583, val_accuracy_score=0.680, val_loss=0.650]         



Epoch 5:   0%|          | 0/46 [00:00<?, ?it/s, v_num=131, train_loss=1.010, train_accuracy_score=0.555, val_accuracy_score=0.640, val_loss=0.672]         



Epoch 6:   0%|          | 0/46 [00:00<?, ?it/s, v_num=131, train_loss=0.970, train_accuracy_score=0.600, val_accuracy_score=0.640, val_loss=0.654]         



Epoch 6: 100%|██████████| 46/46 [00:02<00:00, 22.53it/s, v_num=131, train_loss=0.916, train_accuracy_score=0.610, val_accuracy_score=0.627, val_loss=0.698]

Monitored metric val_loss did not improve in the last 3 records. Best score: 0.650. Signaling Trainer to stop.


Epoch 6: 100%|██████████| 46/46 [00:02<00:00, 22.36it/s, v_num=131, train_loss=0.916, train_accuracy_score=0.610, val_accuracy_score=0.627, val_loss=0.698]


## Evaluation

In [75]:
test_ds = VIMEDataset(
    X_test,
    category_cols=categorical_cols,
    continuous_cols=numerical_cols,
    is_second_phase=True)

test_dl = DataLoader(
    test_ds,
    batch_size,
    shuffle=False,
    sampler=SequentialSampler(test_ds)
)

In [76]:
preds = trainer.predict(pl_vime, test_dl)
preds = F.softmax(torch.concat([out.cpu() for out in preds]).squeeze(),dim=1)

accuracy = accuracy_score(y_test, preds.argmax(1))

print("Accuracy %.2f" % accuracy)

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 47.29it/s]
Accuracy 0.61


In [78]:
predicted_labels = preds.argmax(1)
associated_probabilities = preds[np.arange(preds.shape[0]), predicted_labels]
pd.DataFrame({'label': predicted_labels, 'probability': associated_probabilities})

Unnamed: 0,label,probability
0,1,0.557235
1,0,0.582431
2,1,0.568620
3,0,0.697947
4,0,0.688896
...,...,...
88,1,0.604350
89,1,0.510552
90,0,0.627536
91,0,0.739085


## Label prediction

In [79]:
unlabelled_ds = VIMEDataset(
    X=unlabelled_data,
    config=config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols)

unlabelled_dl = DataLoader(
    unlabelled_ds,
    batch_size,
    shuffle=False,
    sampler=SequentialSampler(unlabelled_ds)
)

In [80]:
preds = trainer.predict(pl_vime, unlabelled_dl)
preds = F.softmax(torch.concat([out.cpu() for out in preds]).squeeze(),dim=1)

predicted_labels = preds.argmax(1)
associated_probabilities = preds[np.arange(preds.shape[0]), predicted_labels]

pd.DataFrame({'label': predicted_labels, 'confidence': associated_probabilities})

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 37/37 [00:00<00:00, 163.89it/s]


Unnamed: 0,label,confidence
0,0,0.638018
1,0,0.627758
2,0,0.655118
3,0,0.728606
4,0,0.845933
...,...,...
1163,0,0.742593
1164,0,0.601021
1165,0,0.671212
1166,0,0.604925


## Data export

In [81]:
unlabelled_data = pd.read_csv(UNLABELLED_DATA_PATH)

In [82]:
train_data = pd.read_csv(TRAIN_DATA_PATH)

In [83]:
train_data['Confidence'] = 1.0

train_data.head()

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,Hormone Therapy,Neoplasm Histologic Grade,Cellularity,Surgery-breast conserving,Surgery-mastectomy,Label,DssTime,Event,auto_id,Confidence
0,11.445577,5.383985,10.523773,5.878819,7.376048,6.224054,5.730411,6.937679,7.314421,9.496653,...,0,1,0.5,1,0,1,52.733333,1,1271,1.0
1,11.728241,6.853981,9.093163,6.189151,9.775104,5.891141,5.644219,6.445037,7.500848,10.14059,...,1,3,1.0,0,1,0,61.1,1,1050,1.0
2,11.122611,5.322075,11.719898,6.121027,8.733519,6.321382,6.011898,6.702699,7.682648,9.736608,...,0,2,1.0,0,1,0,94.033333,1,1300,1.0
3,11.246158,7.200805,10.48821,5.921079,9.480574,5.598995,5.923112,6.411695,7.998787,10.009468,...,1,3,0.5,1,0,0,118.133333,1,1156,1.0
4,9.744005,5.481691,11.031849,5.826471,8.732274,5.873227,5.495358,6.243658,7.71902,9.32783,...,1,2,1.0,1,0,1,27.066667,1,1458,1.0


In [84]:
unlabelled_data['Label'] = predicted_labels
unlabelled_data['Confidence'] = associated_probabilities

unlabelled_data

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,Radio Therapy,Chemotherapy,Hormone Therapy,Neoplasm Histologic Grade,Cellularity,Surgery-breast conserving,Surgery-mastectomy,auto_id,Label,Confidence
0,10.047059,7.505424,9.729606,5.451007,8.474830,6.412419,5.899440,7.069394,7.100058,9.102318,...,1,0,1,3.0,1.0,1.0,0.0,2000,0,0.638018
1,10.404685,6.815637,10.334979,5.488309,9.994894,6.525927,5.585357,6.071653,7.810600,9.431167,...,1,1,1,2.0,0.5,0.0,1.0,2001,0,0.627758
2,10.793832,7.720952,9.276507,5.478224,8.184773,5.949741,5.743395,7.244882,7.781876,8.716918,...,1,1,1,2.0,0.5,1.0,0.0,2002,0,0.655118
3,10.440667,5.592522,8.613192,5.436625,8.210389,6.203913,5.852012,6.219653,7.560101,9.600030,...,1,0,1,2.0,0.5,1.0,0.0,2003,0,0.728606
4,12.521038,5.325554,10.678267,5.623786,7.786509,6.153012,5.502281,7.278257,7.674681,10.367760,...,1,0,1,3.0,0.5,1.0,0.0,2004,0,0.845933
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1163,11.628490,5.570690,10.475695,6.032211,9.944405,5.865408,5.703147,6.649948,7.272166,9.750208,...,0,0,1,3.0,1.0,0.0,1.0,3163,0,0.742593
1164,10.879891,6.431113,10.219154,5.435795,9.224122,5.699195,5.825643,6.404899,7.385644,9.271953,...,0,0,1,3.0,1.0,0.0,1.0,3164,0,0.601021
1165,9.591235,7.984515,9.935179,5.605596,9.799519,5.808704,5.905282,6.491419,7.865526,9.741103,...,1,0,1,3.0,1.0,1.0,0.0,3165,0,0.671212
1166,11.055114,8.282737,9.892589,5.753274,8.687667,5.475813,5.587906,6.830579,8.468221,9.482622,...,0,0,1,2.0,0.5,0.0,1.0,3166,0,0.604925


In [85]:
combined_data = pd.concat([train_data, unlabelled_data], ignore_index=True)

In [86]:
combined_data.shape

(1540, 35)

In [87]:
combined_data.sort_values(by='auto_id').to_csv(PSEUDO_LABELLED_DATA_PATH, index=False)