# SSL: VIME

This work is based on the one published by the author of the [TS3L library](https://github.com/Alcoholrithm/TabularS3L).

Specifically, this code is implementing [VIME](https://github.com/Alcoholrithm/TabularS3L)

In [154]:
# Prepare the VIMELightning Module
from ts3l.pl_modules import VIMELightning
from ts3l.utils.vime_utils import VIMEDataset
from ts3l.utils import TS3LDataModule, get_category_cardinality
from ts3l.utils.vime_utils import VIMEConfig
from ts3l.utils.embedding_utils import IdentityEmbeddingConfig
from ts3l.utils.backbone_utils import MLPBackboneConfig
from pytorch_lightning import Trainer
import numpy as np
from ts3l.utils.vime_utils import VIMESecondPhaseCollateFN
from sklearn.metrics import accuracy_score
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, SequentialSampler
from sklearn.model_selection import train_test_split

import pandas as pd

## Data loading

In [364]:
TEST_DATA_PATH = "../data/test_data.csv"
TRAIN_DATA_PATH = "../data/train_data.csv"
UNLABELLED_DATA_PATH = "../data/unlabelled_data.csv"
PSEUDO_LABELLED_DATA_PATH = "../data/pseudo_labelled_data.csv"

In [335]:
def get_dataframes(test_path, train_path, unlabelled_path, with_clinical=False):
    # Load the data
    test_df = pd.read_csv(test_path)
    train_df = pd.read_csv(train_path)
    unlabelled_df = pd.read_csv(unlabelled_path)

    # Drop the columns that are not needed
    test_df = test_df.drop(columns=['DssTime', 'Event'])
    train_df = train_df.drop(columns=['DssTime', 'Event'])

    # Extract numerical and categorical columns
    # Numerical cols: Gene + Age
    numerical_cols = test_df.columns[:21].tolist()
    # But also Size
    numerical_cols.append('Size')
    # Categorical cols: Clinical
    categorical_cols = test_df.drop(columns=['Label', 'Size']).columns[21:].tolist()
    if not with_clinical:
        test_df = test_df.drop(columns=categorical_cols)
        train_df = train_df.drop(columns=categorical_cols)
        unlabelled_df = unlabelled_df.drop(columns=categorical_cols)
        categorical_cols = []
    else:
        categorical_cols = ['Chemotherapy', 'Menopausal State', 'Radio Therapy', 'Hormone Therapy', 'Surgery-breast conserving', 'Surgery-mastectomy']
        # The model has problems with these columns
        not_cols = ['Neoplasm Histologic Grade', 'Cellularity']
        test_df = test_df.drop(columns=not_cols)
        train_df = train_df.drop(columns=not_cols)
        unlabelled_df = unlabelled_df.drop(columns=not_cols)

    print(f'Train data shape: {train_df.shape}')
    print(f'Test data shape: {test_df.shape}')
    print(f'Unlabelled data shape: {unlabelled_df.shape}')
    print(f'Numerical columns: {numerical_cols}')
    if with_clinical:
        print(f'Categorical columns: {categorical_cols}')
    return test_df, train_df, unlabelled_df, numerical_cols, categorical_cols

In [339]:
test_data, train_data, unlabelled_data, numerical_cols, categorical_cols = get_dataframes(
    TEST_DATA_PATH,
    TRAIN_DATA_PATH,
    UNLABELLED_DATA_PATH,
    with_clinical=False)

Train data shape: (465, 23)
Test data shape: (117, 23)
Unlabelled data shape: (1168, 22)
Numerical columns: ['ESR1', 'PGR', 'ERBB2', 'MKI67', 'PLAU', 'ELAVL1', 'EGFR', 'BTRC', 'FBXO6', 'SHMT2', 'KRAS', 'SRPK2', 'YWHAQ', 'PDHA1', 'EWSR1', 'ZDHHC17', 'ENO1', 'DBN1', 'PLK1', 'GSK3B', 'Age', 'Size']


In [340]:
test_data.head()

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,PDHA1,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size,Label
0,11.23975,5.954311,9.739996,6.046045,10.040187,5.905724,5.881255,6.538235,7.260572,10.774752,...,8.766946,9.710697,7.002427,12.416515,8.881028,6.466387,8.771647,78,31,1
1,10.927313,7.002502,10.033753,5.568993,8.306619,6.547491,5.733367,6.128118,7.917904,9.514045,...,8.09889,9.762576,7.122037,12.113516,8.553396,6.575161,8.360427,85,22,0
2,6.312633,5.305683,9.068778,5.919384,8.210977,5.896152,5.634379,5.625037,7.684047,11.422518,...,8.553177,9.328939,7.343709,12.022229,7.636171,6.221834,8.027209,50,40,1
3,9.1852,5.480888,9.580607,5.655789,7.756504,6.026981,6.008594,6.269051,7.428641,9.478211,...,8.168313,9.644231,7.425378,12.2849,8.701101,6.383001,8.494059,83,150,1
4,7.249462,5.164281,10.233184,5.721403,8.918334,6.392132,5.58845,6.062906,7.968933,9.578638,...,8.844283,9.537609,7.27258,12.556723,9.189911,6.909404,8.841997,82,45,1


In [341]:
unlabelled_data.head()

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,YWHAQ,PDHA1,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size
0,10.047059,7.505424,9.729606,5.451007,8.47483,6.412419,5.89944,7.069394,7.100058,9.102318,...,12.124106,8.234815,9.674483,7.375406,12.024185,7.757081,6.45558,8.385158,43.19,10.0
1,10.404685,6.815637,10.334979,5.488309,9.994894,6.525927,5.585357,6.071653,7.8106,9.431167,...,12.712623,8.242056,9.15434,7.585137,11.844377,8.042064,6.036497,7.464064,47.68,25.0
2,10.793832,7.720952,9.276507,5.478224,8.184773,5.949741,5.743395,7.244882,7.781876,8.716918,...,12.102056,8.105696,9.1365,7.730418,10.790484,7.599985,5.95331,7.229434,56.45,10.0
3,10.440667,5.592522,8.613192,5.436625,8.210389,6.203913,5.852012,6.219653,7.560101,9.60003,...,11.839824,8.376467,9.257898,7.743428,11.21288,7.770462,6.194253,7.289188,89.08,29.0
4,12.521038,5.325554,10.678267,5.623786,7.786509,6.153012,5.502281,7.278257,7.674681,10.36776,...,11.363913,9.031051,9.46249,7.472229,12.95541,8.591434,6.64576,8.737746,86.41,16.0


In [342]:
len(categorical_cols) + len(numerical_cols)

22

## Data configuration

In [343]:
full_X_train = train_data.drop(columns=['Label'])
full_y_train = train_data['Label']

X_test = test_data.drop(columns=['Label'])
y_test = test_data['Label']

In [344]:
# Split the train_data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(
    full_X_train,
    full_y_train,
    test_size=0.2,
    random_state=42,
    stratify=full_y_train)

print(f'Training data shape: {X_train.shape}')
print(f'Validation data shape: {X_val.shape}')

Training data shape: (372, 22)
Validation data shape: (93, 22)


## Model configuration

In [345]:
metric = "accuracy_score"
input_dim = X_train.shape[1]
predictor_dim = 1024
alpha1 = 2.0
alpha2 = 2.0
beta = 1.0
K = 2
p_m = 0.2
batch_size = 128
max_epochs = 20

In [346]:
embedding_config = IdentityEmbeddingConfig(input_dim = input_dim)
backbone_config = MLPBackboneConfig(input_dim = embedding_config.output_dim)

In [347]:
cardinality = get_category_cardinality(X_train, categorical_cols)

cardinality

[]

In [348]:
config = VIMEConfig( 
                    task="classification",
                    loss_fn="CrossEntropyLoss",
                    metric=metric,
                    metric_hparams={},
                    embedding_config=embedding_config,
                    backbone_config=backbone_config,
                    predictor_dim=predictor_dim,
                    output_dim=2,
                    alpha1=alpha1,
                    alpha2=alpha2, 
                    beta=beta,
                    K=K,
                    p_m = p_m,
                    cat_cardinality=cardinality,
                    num_continuous=len(numerical_cols),
)


In [349]:
pl_vime = VIMELightning(config)

Seed set to 42


In [350]:
train_ds = VIMEDataset(
    X=X_train,
    unlabeled_data=unlabelled_data,
    config=config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols
)

valid_ds = VIMEDataset(
    X=X_train,
    config=config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols
)

datamodule = TS3LDataModule(train_ds, valid_ds, batch_size, train_sampler='random')

## Model training: 1st phase

In [351]:
trainer = Trainer(
                    accelerator = 'cpu',
                    max_epochs = max_epochs,
                    num_sanity_val_steps = 2,
    )

trainer.fit(pl_vime, datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name                        | Type             | Params | Mode 
-------------------------------------------------------------------------
0 | task_loss_fn                | CrossEntropyLoss | 0      | train
1 | mask_loss_fn                | BCELoss          | 0      | train
2 | categorical_feature_loss_fn | CrossEntropyLoss | 0      | train
3 | continuous_feature_loss_fn  | MSELoss          | 0      | train
4 | consistency_loss_fn         | MSELoss          | 0      | train
5 | model                       | VIME             | 1.2 M  | train
-------------------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.837     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode


                                                                            

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (13) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 19: 100%|██████████| 13/13 [00:01<00:00,  6.51it/s, v_num=55, train_loss=276.0, val_loss=280.0]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 13/13 [00:02<00:00,  6.36it/s, v_num=55, train_loss=276.0, val_loss=280.0]


## Model training: 2nd phase

In [352]:
X_train

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,YWHAQ,PDHA1,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size
271,11.445577,5.383985,10.523773,5.878819,7.376048,6.224054,5.730411,6.937679,7.314421,9.496653,...,12.212333,8.472817,10.014916,8.037095,11.556363,9.697175,6.687842,8.197441,63,43
50,11.728241,6.853981,9.093163,6.189151,9.775104,5.891141,5.644219,6.445037,7.500848,10.140590,...,11.938529,7.908106,9.290438,7.124771,12.319678,8.687310,6.408426,8.303559,74,43
300,11.122611,5.322075,11.719898,6.121027,8.733519,6.321382,6.011898,6.702699,7.682648,9.736608,...,12.368506,8.242456,9.763899,8.006279,12.571911,9.236787,6.921407,8.155855,52,15
156,11.246158,7.200805,10.488210,5.921079,9.480574,5.598995,5.923112,6.411695,7.998787,10.009468,...,10.978177,8.216202,9.566260,7.362756,12.310459,9.304842,6.561758,8.985041,62,20
458,9.744005,5.481691,11.031849,5.826471,8.732274,5.873227,5.495358,6.243658,7.719020,9.327830,...,12.078461,8.311747,8.969407,7.349488,11.630288,8.838334,6.456821,7.996271,55,30
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
175,9.646853,6.867510,11.548091,5.924567,7.296133,8.187168,5.600126,6.898188,7.883845,9.856422,...,10.292150,8.623320,10.110875,6.780786,11.715737,9.178288,6.453299,9.293761,58,15
241,11.852769,6.062227,10.560881,5.982969,8.425843,6.169082,5.764507,8.075540,7.504144,9.845907,...,11.484319,7.994028,9.499510,7.191114,12.438446,9.333542,6.489902,8.643844,73,25
238,9.979798,7.154027,11.187826,5.589705,9.142333,6.381046,6.613969,7.102050,7.290591,9.781446,...,11.501904,7.671545,9.797579,7.490341,12.483094,8.166508,6.440541,8.128982,45,11
205,9.991981,6.927914,11.341993,5.546633,7.917850,6.365073,6.679638,6.406762,6.592135,9.886140,...,12.486510,8.420539,9.735076,7.707661,12.529626,8.877610,6.226765,8.281162,44,31


In [353]:
unlabelled_data

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,YWHAQ,PDHA1,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size
0,10.047059,7.505424,9.729606,5.451007,8.474830,6.412419,5.899440,7.069394,7.100058,9.102318,...,12.124106,8.234815,9.674483,7.375406,12.024185,7.757081,6.455580,8.385158,43.19,10.0
1,10.404685,6.815637,10.334979,5.488309,9.994894,6.525927,5.585357,6.071653,7.810600,9.431167,...,12.712623,8.242056,9.154340,7.585137,11.844377,8.042064,6.036497,7.464064,47.68,25.0
2,10.793832,7.720952,9.276507,5.478224,8.184773,5.949741,5.743395,7.244882,7.781876,8.716918,...,12.102056,8.105696,9.136500,7.730418,10.790484,7.599985,5.953310,7.229434,56.45,10.0
3,10.440667,5.592522,8.613192,5.436625,8.210389,6.203913,5.852012,6.219653,7.560101,9.600030,...,11.839824,8.376467,9.257898,7.743428,11.212880,7.770462,6.194253,7.289188,89.08,29.0
4,12.521038,5.325554,10.678267,5.623786,7.786509,6.153012,5.502281,7.278257,7.674681,10.367760,...,11.363913,9.031051,9.462490,7.472229,12.955410,8.591434,6.645760,8.737746,86.41,16.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1163,11.628490,5.570690,10.475695,6.032211,9.944405,5.865408,5.703147,6.649948,7.272166,9.750208,...,11.184906,7.907846,9.350417,6.902080,11.863881,8.171088,6.401429,7.814694,66.48,25.0
1164,10.879891,6.431113,10.219154,5.435795,9.224122,5.699195,5.825643,6.404899,7.385644,9.271953,...,11.529199,8.334326,9.309069,7.048923,12.041769,8.572415,6.128115,7.682540,56.90,45.0
1165,9.591235,7.984515,9.935179,5.605596,9.799519,5.808704,5.905282,6.491419,7.865526,9.741103,...,12.011174,8.173189,9.294825,7.316246,12.540138,8.634905,6.089312,7.838041,43.10,25.0
1166,11.055114,8.282737,9.892589,5.753274,8.687667,5.475813,5.587906,6.830579,8.468221,9.482622,...,11.872388,8.245726,9.357164,7.354395,11.980115,8.250034,6.443083,8.319481,61.16,25.0


In [354]:

print(unlabelled_data.shape)

threshold = 1164

f_unlabelled_data = unlabelled_data.head(threshold)
print(f_unlabelled_data.shape)

# Something's wrong with this rows, it make the model crash
unlabelled_data.tail(len(unlabelled_data) - threshold)

(1168, 22)
(1164, 22)


Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,YWHAQ,PDHA1,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size
1164,10.879891,6.431113,10.219154,5.435795,9.224122,5.699195,5.825643,6.404899,7.385644,9.271953,...,11.529199,8.334326,9.309069,7.048923,12.041769,8.572415,6.128115,7.68254,56.9,45.0
1165,9.591235,7.984515,9.935179,5.605596,9.799519,5.808704,5.905282,6.491419,7.865526,9.741103,...,12.011174,8.173189,9.294825,7.316246,12.540138,8.634905,6.089312,7.838041,43.1,25.0
1166,11.055114,8.282737,9.892589,5.753274,8.687667,5.475813,5.587906,6.830579,8.468221,9.482622,...,11.872388,8.245726,9.357164,7.354395,11.980115,8.250034,6.443083,8.319481,61.16,25.0
1167,10.696475,5.533486,10.227787,5.588965,10.212643,6.633792,5.504219,6.468054,8.391464,9.812163,...,11.701313,8.618193,9.742585,7.435703,12.639017,7.710563,6.684053,7.930635,60.02,20.0


In [355]:
pl_vime.set_second_phase()

train_ds = VIMEDataset(
    X_train,
    y_train.values,
    config,
    unlabeled_data=f_unlabelled_data,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols,
    is_second_phase=True)

valid_ds = VIMEDataset(
    X_val,
    y_val.values,
    config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols,
    is_second_phase=True)

datamodule = TS3LDataModule(
    train_ds,
    valid_ds,
    batch_size=batch_size,
    train_sampler="weighted",
    train_collate_fn=VIMESecondPhaseCollateFN()
)

In [356]:
trainer = Trainer(
        accelerator = 'cpu',
        max_epochs = max_epochs,
        num_sanity_val_steps = 2,
    )

trainer.fit(pl_vime, datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name                        | Type             | Params | Mode 
-------------------------------------------------------------------------
0 | task_loss_fn                | CrossEntropyLoss | 0      | train
1 | mask_loss_fn                | BCELoss          | 0      | train
2 | categorical_feature_loss_fn | CrossEntropyLoss | 0      | train
3 | continuous_feature_loss_fn  | MSELoss          | 0      | train
4 | consistency_loss_fn         | MSELoss          | 0      | train
5 | model                       | VIME             | 1.2 M  | train
-------------------------------------------------------------------------
1.2 M     Trainable params
19.7 K    Non-trainable params
1.2 M     Total params
4.837     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode


                                                                            

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (12) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 1:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=1.810, train_accuracy_score=0.545, val_accuracy_score=0.591, val_loss=0.720]         



Epoch 2:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=1.190, train_accuracy_score=0.591, val_accuracy_score=0.548, val_loss=0.680]         



Epoch 3:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=1.110, train_accuracy_score=0.593, val_accuracy_score=0.559, val_loss=0.670]         



Epoch 4:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=1.020, train_accuracy_score=0.570, val_accuracy_score=0.548, val_loss=0.719]         



Epoch 5:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.954, train_accuracy_score=0.597, val_accuracy_score=0.559, val_loss=0.681]         



Epoch 6:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.925, train_accuracy_score=0.569, val_accuracy_score=0.581, val_loss=0.671]         



Epoch 7:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.857, train_accuracy_score=0.631, val_accuracy_score=0.624, val_loss=0.672]         



Epoch 8:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.852, train_accuracy_score=0.618, val_accuracy_score=0.570, val_loss=0.725]         



Epoch 9:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.852, train_accuracy_score=0.588, val_accuracy_score=0.634, val_loss=0.650]         



Epoch 2:  92%|█████████▏| 12/13 [09:25<00:47,  0.02it/s, v_num=50, train_loss=nan.0, train_accuracy_score=0.624, val_accuracy_score=0.656, val_loss=0.623]
Epoch 2:  92%|█████████▏| 12/13 [09:25<00:47,  0.02it/s, v_num=50, train_loss=nan.0, train_accuracy_score=0.624, val_accuracy_score=0.656, val_loss=0.623]
Epoch 10:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.846, train_accuracy_score=0.603, val_accuracy_score=0.591, val_loss=0.773]        



Epoch 11:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.897, train_accuracy_score=0.582, val_accuracy_score=0.581, val_loss=0.659]         



Epoch 12:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.919, train_accuracy_score=0.538, val_accuracy_score=0.613, val_loss=0.653]         



Epoch 13:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.840, train_accuracy_score=0.633, val_accuracy_score=0.645, val_loss=0.658]         



Epoch 14:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.813, train_accuracy_score=0.619, val_accuracy_score=0.667, val_loss=0.640]         



Epoch 15:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.820, train_accuracy_score=0.634, val_accuracy_score=0.624, val_loss=0.638]         



Epoch 16:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.812, train_accuracy_score=0.639, val_accuracy_score=0.570, val_loss=0.706]         



Epoch 17:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.818, train_accuracy_score=0.602, val_accuracy_score=0.656, val_loss=0.654]         



Epoch 18:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.931, train_accuracy_score=0.577, val_accuracy_score=0.473, val_loss=1.050]         



Epoch 19:   0%|          | 0/12 [00:00<?, ?it/s, v_num=56, train_loss=0.910, train_accuracy_score=0.568, val_accuracy_score=0.613, val_loss=0.632]         



Epoch 19: 100%|██████████| 12/12 [00:02<00:00,  5.63it/s, v_num=56, train_loss=0.812, train_accuracy_score=0.601, val_accuracy_score=0.656, val_loss=0.654]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 12/12 [00:02<00:00,  5.30it/s, v_num=56, train_loss=0.812, train_accuracy_score=0.601, val_accuracy_score=0.656, val_loss=0.654]


## Evaluation

In [357]:
test_ds = VIMEDataset(
    X_test,
    category_cols=categorical_cols,
    continuous_cols=numerical_cols,
    is_second_phase=True)

test_dl = DataLoader(
    test_ds,
    batch_size,
    shuffle=False,
    sampler=SequentialSampler(test_ds)
)

In [358]:
preds = trainer.predict(pl_vime, test_dl)
preds = F.softmax(torch.concat([out.cpu() for out in preds]).squeeze(),dim=1)

accuracy = accuracy_score(y_test, preds.argmax(1))

print("Accuracy %.2f" % accuracy)

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 214.24it/s]
Accuracy 0.70


In [359]:
predicted_labels = preds.argmax(1)
associated_probabilities = preds[np.arange(preds.shape[0]), predicted_labels]
pd.DataFrame({'label': predicted_labels, 'probability': associated_probabilities})

Unnamed: 0,label,probability
0,1,0.525858
1,0,0.517442
2,1,0.632501
3,0,0.955566
4,1,0.534167
...,...,...
112,0,0.703786
113,1,0.510575
114,0,0.559650
115,0,0.551885


## Label prediction

In [360]:
unlabelled_ds = VIMEDataset(
    X=unlabelled_data,
    config=config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols)

unlabelled_dl = DataLoader(
    unlabelled_ds,
    batch_size,
    shuffle=False,
    sampler=SequentialSampler(unlabelled_ds)
)

In [361]:
preds = trainer.predict(pl_vime, unlabelled_dl)
preds = F.softmax(torch.concat([out.cpu() for out in preds]).squeeze(),dim=1)

predicted_labels = preds.argmax(1)
associated_probabilities = preds[np.arange(preds.shape[0]), predicted_labels]

pd.DataFrame({'label': predicted_labels, 'confidence': associated_probabilities})

/home/sonk/envs/pandas/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 56.90it/s]


Unnamed: 0,label,confidence
0,0,0.659109
1,0,0.590492
2,0,0.739665
3,1,0.682598
4,0,0.541471
...,...,...
1163,0,0.546133
1164,1,0.544064
1165,0,0.557837
1166,0,0.583897


## Data export

In [363]:
unlabelled_data['Label'] = predicted_labels
unlabelled_data['Confidence'] = associated_probabilities

unlabelled_data

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size,Label,Confidence
0,10.047059,7.505424,9.729606,5.451007,8.474830,6.412419,5.899440,7.069394,7.100058,9.102318,...,9.674483,7.375406,12.024185,7.757081,6.455580,8.385158,43.19,10.0,0,0.659109
1,10.404685,6.815637,10.334979,5.488309,9.994894,6.525927,5.585357,6.071653,7.810600,9.431167,...,9.154340,7.585137,11.844377,8.042064,6.036497,7.464064,47.68,25.0,0,0.590492
2,10.793832,7.720952,9.276507,5.478224,8.184773,5.949741,5.743395,7.244882,7.781876,8.716918,...,9.136500,7.730418,10.790484,7.599985,5.953310,7.229434,56.45,10.0,0,0.739665
3,10.440667,5.592522,8.613192,5.436625,8.210389,6.203913,5.852012,6.219653,7.560101,9.600030,...,9.257898,7.743428,11.212880,7.770462,6.194253,7.289188,89.08,29.0,1,0.682598
4,12.521038,5.325554,10.678267,5.623786,7.786509,6.153012,5.502281,7.278257,7.674681,10.367760,...,9.462490,7.472229,12.955410,8.591434,6.645760,8.737746,86.41,16.0,0,0.541471
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1163,11.628490,5.570690,10.475695,6.032211,9.944405,5.865408,5.703147,6.649948,7.272166,9.750208,...,9.350417,6.902080,11.863881,8.171088,6.401429,7.814694,66.48,25.0,0,0.546133
1164,10.879891,6.431113,10.219154,5.435795,9.224122,5.699195,5.825643,6.404899,7.385644,9.271953,...,9.309069,7.048923,12.041769,8.572415,6.128115,7.682540,56.90,45.0,1,0.544064
1165,9.591235,7.984515,9.935179,5.605596,9.799519,5.808704,5.905282,6.491419,7.865526,9.741103,...,9.294825,7.316246,12.540138,8.634905,6.089312,7.838041,43.10,25.0,0,0.557837
1166,11.055114,8.282737,9.892589,5.753274,8.687667,5.475813,5.587906,6.830579,8.468221,9.482622,...,9.357164,7.354395,11.980115,8.250034,6.443083,8.319481,61.16,25.0,0,0.583897


In [365]:
unlabelled_data.to_csv(PSEUDO_LABELLED_DATA_PATH, index=False)