In [40]:
# Prepare the VIMELightning Module
from ts3l.pl_modules import VIMELightning
from ts3l.utils.vime_utils import VIMEDataset
from ts3l.utils import TS3LDataModule, get_category_cardinality
from ts3l.utils.vime_utils import VIMEConfig
from ts3l.utils.embedding_utils import IdentityEmbeddingConfig
from ts3l.utils.backbone_utils import MLPBackboneConfig
from pytorch_lightning import Trainer
import numpy as np

import pandas as pd

In [12]:
test_data = pd.read_csv('../data/test_data.csv')
train_data = pd.read_csv('../data/train_data.csv')
unlabelled_data = pd.read_csv('../data/unlabelled_data.csv')

print(f'Train data shape: {train_data.shape}')
print(f'Test data shape: {test_data.shape}')
print(f'Unlabelled data shape: {unlabelled_data.shape}')

Train data shape: (465, 33)
Test data shape: (117, 33)
Unlabelled data shape: (1168, 30)


In [13]:
test_data.head()

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,Radio Therapy,Chemotherapy,Hormone Therapy,Neoplasm Histologic Grade,Cellularity,Surgery-breast conserving,Surgery-mastectomy,Label,DssTime,Event
0,11.23975,5.954311,9.739996,6.046045,10.040187,5.905724,5.881255,6.538235,7.260572,10.774752,...,1,0,1,3,0.5,0,1,1,7.8,1
1,10.927313,7.002502,10.033753,5.568993,8.306619,6.547491,5.733367,6.128118,7.917904,9.514045,...,1,0,1,2,0.5,0,1,0,132.033333,1
2,6.312633,5.305683,9.068778,5.919384,8.210977,5.896152,5.634379,5.625037,7.684047,11.422518,...,1,1,0,3,1.0,0,1,1,28.5,1
3,9.1852,5.480888,9.580607,5.655789,7.756504,6.026981,6.008594,6.269051,7.428641,9.478211,...,1,0,1,3,1.0,0,1,1,39.166667,1
4,7.249462,5.164281,10.233184,5.721403,8.918334,6.392132,5.58845,6.062906,7.968933,9.578638,...,1,1,0,2,1.0,0,1,1,31.3,1


In [14]:
numerical_cols = test_data.columns[:21].tolist() # gene + AGE
numerical_cols.append('Size')

test_data[numerical_cols]

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,PLAU,ELAVL1,EGFR,BTRC,FBXO6,SHMT2,...,YWHAQ,PDHA1,EWSR1,ZDHHC17,ENO1,DBN1,PLK1,GSK3B,Age,Size
0,11.239750,5.954311,9.739996,6.046045,10.040187,5.905724,5.881255,6.538235,7.260572,10.774752,...,12.049173,8.766946,9.710697,7.002427,12.416515,8.881028,6.466387,8.771647,78,31
1,10.927313,7.002502,10.033753,5.568993,8.306619,6.547491,5.733367,6.128118,7.917904,9.514045,...,11.475811,8.098890,9.762576,7.122037,12.113516,8.553396,6.575161,8.360427,85,22
2,6.312633,5.305683,9.068778,5.919384,8.210977,5.896152,5.634379,5.625037,7.684047,11.422518,...,12.534071,8.553177,9.328939,7.343709,12.022229,7.636171,6.221834,8.027209,50,40
3,9.185200,5.480888,9.580607,5.655789,7.756504,6.026981,6.008594,6.269051,7.428641,9.478211,...,12.044500,8.168313,9.644231,7.425378,12.284900,8.701101,6.383001,8.494059,83,150
4,7.249462,5.164281,10.233184,5.721403,8.918334,6.392132,5.588450,6.062906,7.968933,9.578638,...,12.989485,8.844283,9.537609,7.272580,12.556723,9.189911,6.909404,8.841997,82,45
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
112,12.185732,5.462241,10.359953,6.181130,7.712059,6.827588,5.681641,8.019460,6.743522,9.729085,...,11.318578,7.927799,10.112902,8.180081,11.426533,8.797799,7.082542,8.743363,58,60
113,9.705585,6.727055,10.990295,5.633936,8.161321,6.180448,5.828184,6.757800,7.328591,9.771554,...,11.543076,7.825238,9.851387,7.315788,12.194310,9.319627,6.282904,8.144959,43,30
114,10.303005,5.996609,10.523380,5.813295,8.719961,6.190540,6.144548,6.251790,7.811116,10.481299,...,12.045751,8.439066,9.269746,6.885769,12.788076,10.051038,6.588315,8.920050,67,40
115,11.886409,6.266055,10.260601,5.883065,7.146876,6.430670,5.579551,7.444408,8.435352,9.369971,...,11.325677,7.519622,10.078754,7.805507,11.999126,8.527640,6.587311,8.244960,59,40


In [15]:
categorical_cols = test_data.drop(columns=['Label', 'DssTime', 'Event', 'Size']).columns[21:].tolist()

test_data[categorical_cols]

Unnamed: 0,Menopausal State,Radio Therapy,Chemotherapy,Hormone Therapy,Neoplasm Histologic Grade,Cellularity,Surgery-breast conserving,Surgery-mastectomy
0,1,1,0,1,3,0.5,0,1
1,1,1,0,1,2,0.5,0,1
2,1,1,1,0,3,1.0,0,1
3,1,1,0,1,3,1.0,0,1
4,1,1,1,0,2,1.0,0,1
...,...,...,...,...,...,...,...,...
112,1,0,0,1,3,0.5,0,1
113,0,0,1,1,2,0.5,0,1
114,1,0,0,1,3,1.0,0,1
115,1,1,0,1,1,0.5,1,0


In [25]:
len(categorical_cols) + len(numerical_cols)

30

In [16]:
full_X_train = train_data.drop(columns=['Label', 'DssTime', 'Event'])
full_y_train = train_data['Label']

X_test = test_data.drop(columns=['Label', 'DssTime', 'Event'])
y_test = test_data['Label']

In [17]:
from sklearn.model_selection import train_test_split

# Split the train_data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(
    full_X_train,
    full_y_train,
    test_size=0.2,
    random_state=42,
    stratify=full_y_train)

print(f'Training data shape: {X_train.shape}')
print(f'Validation data shape: {X_val.shape}')

Training data shape: (372, 30)
Validation data shape: (93, 30)


In [65]:
np.unique(y_train)

array([0, 1])

In [66]:
np.unique(y_val)

array([0, 1])

In [58]:
metric = "accuracy_score"
input_dim = X_train.shape[1]
predictor_dim = 1024
alpha1 = 2.0
alpha2 = 2.0
beta = 1.0
K = 2
p_m = 0.2

In [59]:
batch_size = 128
max_epochs = 20

In [60]:
embedding_config = IdentityEmbeddingConfig(input_dim = input_dim)
backbone_config = MLPBackboneConfig(input_dim = embedding_config.output_dim)

In [61]:
config = VIMEConfig( 
                    task="classification",
                    loss_fn="CrossEntropyLoss",
                    metric=metric,
                    metric_hparams={},
                    embedding_config=embedding_config,
                    backbone_config=backbone_config,
                    predictor_dim=predictor_dim,
                    output_dim=2,
                    alpha1=alpha1,
                    alpha2=alpha2, 
                    beta=beta,
                    K=K,
                    p_m = p_m,
                    cat_cardinality=get_category_cardinality(X_train, categorical_cols),
                    num_continuous=len(numerical_cols),
)


In [62]:
pl_vime = VIMELightning(config)

Seed set to 42


In [67]:
### First Phase Learning
train_ds = VIMEDataset(
    X=X_train,
    unlabeled_data=None,
    config=config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols
)

valid_ds = VIMEDataset(
    X=X_train,
    config=config,
    continuous_cols=numerical_cols,
    category_cols=categorical_cols)

datamodule = TS3LDataModule(train_ds, valid_ds, batch_size, train_sampler='random')

In [68]:
trainer = Trainer(
                    accelerator = 'cpu',
                    max_epochs = max_epochs,
                    num_sanity_val_steps = 2,
    )

trainer.fit(pl_vime, datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name                        | Type             | Params | Mode
------------------------------------------------------------------------
0 | task_loss_fn                | CrossEntropyLoss | 0      | eval
1 | mask_loss_fn                | BCELoss          | 0      | eval
2 | categorical_feature_loss_fn | CrossEntropyLoss | 0      | eval
3 | continuous_feature_loss_fn  | MSELoss          | 0      | eval
4 | consistency_loss_fn         | MSELoss          | 0      | eval
5 | model                       | VIME             | 1.2 M  | eval
------------------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.854     Total estimated model params size (MB)
0         Modules in train mode
31        Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

IndexError: Target 3 is out of bounds.