In [1]:
import os

# Directory holding the project code (the `src` package and `config/` are
# resolved relative to it). The default is the original shared location; set
# the CANCER_CODE_DIR environment variable to run the notebook from another
# checkout without editing this cell.
CODE_DIR = os.environ.get(
    "CANCER_CODE_DIR",
    "/work/postresearch/Shared/Researchers/Farbod/cancer/code/",
)

# set wd to cancer
os.chdir(CODE_DIR)
print(os.getcwd())

/work/postresearch/Shared/Researchers/Farbod/cancer/code


In [2]:
import os
import sys
import math
import copy
import logging
import requests
import zipfile
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from itertools import chain
from typing import Union

from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor

from pytorch_lightning import LightningModule, Trainer
from omegaconf import DictConfig, OmegaConf
from omegaconf.errors import MissingMandatoryValue

import ray
from ray import tune, ray_constants

import hydra
from hydra import initialize, compose
from hydra.utils import instantiate

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

from IPython.core.interactiveshell import InteractiveShell

from src.data import RealDatasetCollection, SyntheticDatasetCollection
from src.models import TimeVaryingCausalModel
from src.models.utils import (
    grad_reverse,
    BRTreatmentOutcomeHead,
    AlphaRise,
    clip_normalize_stabilized_weights,
)
from src.models.utils_lstm import VariationalLSTM
from copy import deepcopy

# Display the value of every top-level expression statement in a cell,
# not only the last one (IPython's default is "last_expr").
InteractiveShell.ast_node_interactivity = "all"

In [3]:
# Report the installed PyTorch build and the CUDA version it was built with
print(torch.__version__)
print(torch.version.cuda)

# List every visible CUDA device, or say explicitly that there is none
if not torch.cuda.is_available():
    print("CUDA is not available. No GPUs detected.")
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
    for i in range(num_gpus):
        gpu_name = torch.cuda.get_device_name(i)
        print(f"GPU {i}: {gpu_name}")

2.0.0+cu118
11.8
Number of available GPUs: 1
GPU 0: NVIDIA A100 80GB PCIe


# MSM

In [None]:
from pytorch_lightning import LightningModule
from omegaconf import DictConfig
from omegaconf.errors import MissingMandatoryValue
import torch
import math
from typing import Union
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import logging
import numpy as np
from copy import deepcopy
from pytorch_lightning import Trainer
import ray
from ray import tune
from ray import ray_constants
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor

from src.data import RealDatasetCollection, SyntheticDatasetCollection
from src.models import TimeVaryingCausalModel
from src.models.utils import grad_reverse, BRTreatmentOutcomeHead, AlphaRise, clip_normalize_stabilized_weights
from src.models.utils_lstm import VariationalLSTM


# Module-level logger used by the MSM classes in this cell for their info messages
logger = logging.getLogger(__name__)


class MSM(TimeVaryingCausalModel):
    """
    Pytorch-Lightning implementation of Marginal Structural Models (MSMs) (https://pubmed.ncbi.nlm.nih.gov/10955408/)

    Base class of the MSM sub-models. It keeps the number of lagged features and offers two
    helpers shared by the subclasses: building an "exploded" copy of a dataset and turning a
    fitted propensity classifier into per-time-step propensity scores.

    ``get_inputs`` is called here but not defined here; the subclasses in this file provide it.
    Attributes such as ``dataset_collection``, ``hparams`` and ``dim_treatments`` are read but
    not set in this class (presumably they come from ``TimeVaryingCausalModel`` - TODO confirm).
    """

    model_type = None  # Will be defined in subclasses
    possible_model_types = {'msm_regressor',
                            'propensity_treatment', 'propensity_history'}
    tuning_criterion = None

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None,
                 has_vitals: bool = None,
                 **kwargs):
        """
        Args:
            args: DictConfig of model hyperparameters (``args.model.lag_features`` is read here)
            dataset_collection: Dataset collection
            autoregressive: Flag of including previous outcomes to modelling
            has_vitals: Flag of vitals in dataset
            **kwargs: Other arguments (accepted but not used by this constructor)
        """
        super().__init__(args, dataset_collection, autoregressive, has_vitals)
        self.lag_features = args.model.lag_features

    def prepare_data(self) -> None:
        """Run the multi-input processing of the dataset collection once, if it is still pending."""
        collection = self.dataset_collection
        if collection is None or collection.processed_data_multi:
            return
        # Only binary multilabel regime possible
        assert self.hparams.dataset.treatment_mode == 'multilabel'
        collection.process_data_multi()

    def get_exploded_dataset(self, dataset: Dataset, min_length: int, only_active_entries=True, max_length=None) -> Dataset:
        """
        Return a deep copy of ``dataset`` on which ``explode_trajectories(min_length)`` was called.

        Args:
            dataset: Dataset to copy; the argument itself is not modified
            min_length: Value forwarded to ``explode_trajectories``
            only_active_entries: If False, every entry of the copy is first marked active and
                every sequence length is set to ``max_length``
            max_length: Sequence length used when ``only_active_entries`` is False; defaults to
                the longest sequence length in the copy
        """
        exploded = deepcopy(dataset)
        if max_length is None:
            max_length = max(exploded.data['sequence_lengths'][:])
        if not only_active_entries:
            # Treat every time step as observed so that each row is exploded up to max_length
            exploded.data['active_entries'][:, :, :] = 1.0
            exploded.data['sequence_lengths'][:] = max_length
        exploded.explode_trajectories(min_length)
        return exploded

    def get_propensity_scores(self, dataset: Dataset) -> np.array:
        """
        Score ``dataset`` with the fitted classifier stored under ``self.<model_type>``.

        Returns:
            Array of shape (rows, time steps, dim_treatments). The first ``lag_features`` time
            steps get no prediction and are filled with the constant 0.5.
        """
        logger.info(f'Propensity scores for {dataset.subset_name}.')
        exploded = self.get_exploded_dataset(
            dataset, min_length=self.lag_features, only_active_entries=False)

        features = self.get_inputs(exploded)
        classifier = getattr(self, self.model_type)

        # predict_proba yields a sequence of per-treatment arrays; stack them on axis 1 and
        # keep the column with index 1 of each
        per_treatment_proba = classifier.predict_proba(features)
        scores = np.stack(per_treatment_proba, 1)[:, :, 1]

        # Fold the exploded rows back into (rows, time steps - lag_features, treatments)
        n_rows, n_steps = dataset.data['active_entries'].shape[:2]
        scores = scores.reshape(n_rows, n_steps - self.lag_features, self.dim_treatments)

        # Prepend the uninformative value 0.5 for the leading steps
        leading = 0.5 * np.ones((scores.shape[0], self.lag_features, self.dim_treatments))
        return np.concatenate([leading, scores], axis=1)


class MSMPropensityTreatment(MSM):
    """
    Propensity model fitted on the treatment history only: one unpenalised logistic regression
    per treatment, wrapped in a ``MultiOutputClassifier``. ``main`` in this notebook labels it
    the "Nominator" of the weights.
    """

    model_type = 'propensity_treatment'

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None, has_vitals: bool = None, **kwargs):
        """
        Args:
            args: DictConfig of model hyperparameters; ``args.exp.max_epochs`` is used as the
                iteration cap of the logistic regressions
            dataset_collection: Dataset collection
            autoregressive: Flag of including previous outcomes to modelling
            has_vitals: Flag of vitals in dataset
            **kwargs: Other arguments (not used here)
        """
        super().__init__(args, dataset_collection, autoregressive, has_vitals)

        self.input_size = self.dim_treatments
        logger.info(f'Input size of {self.model_type}: {self.input_size}')
        self.output_size = self.dim_treatments

        base_estimator = LogisticRegression(penalty=None, max_iter=args.exp.max_epochs)
        self.propensity_treatment = MultiOutputClassifier(base_estimator)
        self.save_hyperparameters(args)

    def get_inputs(self, dataset: Dataset) -> np.array:
        """Return, per row, the previous treatments summed over the active time steps (axis 1)."""
        mask = dataset.data['active_entries']
        return (dataset.data['prev_treatments'] * mask).sum(1)

    def fit(self):
        """Fit the classifier on the exploded factual training set."""
        self.prepare_data()
        train_f = self.get_exploded_dataset(
            self.dataset_collection.train_f, min_length=self.lag_features)

        mask = train_f.data['active_entries']
        # Mask minus its copy shifted one step to the left (zero-padded at the end): non-zero
        # only where the mask changes - the last active step if the active steps form a
        # contiguous prefix (assumption, TODO confirm)
        n_rows = mask.shape[0]
        shifted = np.concatenate([mask[:, 1:, :], np.zeros((n_rows, 1, 1))], axis=1)
        last_step = mask - shifted

        # Inputs
        features = self.get_inputs(train_f)

        # Outputs: the treatment at the step selected by `last_step`
        targets = (train_f.data['current_treatments'] * last_step).sum(1)

        self.propensity_treatment.fit(features, targets)


class MSMPropensityHistory(MSM):
    """
    Propensity model fitted on the patient history: summed previous treatments, a window of
    ``lag_features + 1`` steps of vitals / previous outcomes (when enabled) and the static
    features. One unpenalised logistic regression per treatment, wrapped in a
    ``MultiOutputClassifier``. ``main`` in this notebook labels it the "Denominator".
    """

    model_type = 'propensity_history'

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None, has_vitals: bool = None, **kwargs):
        """
        Args:
            args: DictConfig of model hyperparameters; ``args.exp.max_epochs`` is used as the
                iteration cap of the logistic regressions
            dataset_collection: Dataset collection
            autoregressive: Flag of including previous outcomes to modelling
            has_vitals: Flag of vitals in dataset
            **kwargs: Other arguments (not used here)
        """
        super().__init__(args, dataset_collection, autoregressive, has_vitals)

        # NOTE(review): this count does not apply the (lag_features + 1) multiplier that
        # get_inputs uses for vitals / previous outcomes, so the logged size may differ from
        # the real feature count - confirm whether that is intended.
        size = self.dim_treatments + self.dim_static_features
        if self.has_vitals:
            size += self.dim_vitals
        if self.autoregressive:
            size += self.dim_outcome
        self.input_size = size

        logger.info(f'Input size of {self.model_type}: {self.input_size}')
        self.output_size = self.dim_treatments

        base_estimator = LogisticRegression(penalty=None, max_iter=args.exp.max_epochs)
        self.propensity_history = MultiOutputClassifier(base_estimator)
        self.save_hyperparameters(args)

    def get_inputs(self, dataset: Dataset, projection_horizon=0) -> np.array:
        """
        Build the 2-D feature matrix of the history model.

        Args:
            dataset: Dataset whose ``data`` dict holds the arrays read below
            projection_horizon: Number of steps by which the masks are shifted towards the start

        Returns:
            Column-wise concatenation of: summed previous treatments, the windowed vitals (if
            ``has_vitals``), the windowed previous outcomes (if ``autoregressive``) and the
            static features.
        """
        active_entries = dataset.data['active_entries']
        n_rows = active_entries.shape[0]
        window = self.lag_features + 1

        def shift_left(mask, steps):
            # Drop the first `steps` time steps and zero-pad the end, keeping the length
            return np.concatenate([mask[:, steps:, :], np.zeros((n_rows, steps, 1))], axis=1)

        # Non-zero only on the last `window` active steps (assuming a contiguous active prefix)
        lagged_entries = active_entries - shift_left(active_entries, window)
        if projection_horizon > 0:
            lagged_entries = shift_left(lagged_entries, projection_horizon)

        active_before_projection = shift_left(active_entries, projection_horizon)

        prev_treatments = dataset.data['prev_treatments']
        inputs = [(prev_treatments * active_before_projection).sum(1)]
        if self.has_vitals:
            vitals = dataset.data['vitals']
            vitals_mask = np.repeat(lagged_entries, self.dim_vitals, 2) == 1.0
            # The reshape requires exactly `window` selected steps in every row
            inputs.append(vitals[vitals_mask].reshape(vitals.shape[0], window * self.dim_vitals))
        if self.autoregressive:
            prev_outputs = dataset.data['prev_outputs']
            outcome_mask = np.repeat(lagged_entries, self.dim_outcome, 2) == 1.0
            inputs.append(prev_outputs[outcome_mask].reshape(prev_outputs.shape[0], window * self.dim_outcome))
        inputs.append(dataset.data['static_features'])
        return np.concatenate(inputs, axis=1)

    def fit(self):
        """Fit the classifier on the exploded factual training set."""
        self.prepare_data()
        train_f = self.get_exploded_dataset(
            self.dataset_collection.train_f, min_length=self.lag_features)

        mask = train_f.data['active_entries']
        # Mask minus its copy shifted one step to the left (zero-padded at the end)
        shifted = np.concatenate([mask[:, 1:, :], np.zeros((mask.shape[0], 1, 1))], axis=1)
        last_step = mask - shifted

        # Inputs
        features = self.get_inputs(train_f)

        # Outputs: the treatment at the step selected by `last_step`
        targets = (train_f.data['current_treatments'] * last_step).sum(1)

        self.propensity_history.fit(features, targets)


class MSMRegressor(MSM):
    """
    Outcome model of the MSM: a list of ``projection_horizon + 1`` linear regressions, one per
    prediction horizon ``tau``, each fitted with clipped products of stabilized weights as
    sample weights.

    The stabilized weights are read from ``dataset.data['stabilized_weights']``; when the key is
    missing on the training set, ``prepare_data`` asks the dataset collection to compute them
    from the two propensity models passed to the constructor.
    """

    model_type = 'msm_regressor'

    def __init__(self,
                 args: DictConfig,
                 propensity_treatment: MSMPropensityTreatment = None,
                 propensity_history: MSMPropensityHistory = None,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None, has_vitals: bool = None, **kwargs):
        """
        Args:
            args: DictConfig of model hyperparameters
            propensity_treatment: Fitted treatment-only propensity model
            propensity_history: Fitted history propensity model
            dataset_collection: Dataset collection (its ``projection_horizon`` fixes the number
                of regressors)
            autoregressive: Flag of including previous outcomes to modelling
            has_vitals: Flag of vitals in dataset
            **kwargs: Other arguments (not used here)
        """
        super().__init__(args, dataset_collection, autoregressive, has_vitals)

        # NOTE(review): this count ignores the (lag_features + 1) window and the current
        # treatments appended in get_inputs, so the logged size may differ from the real
        # feature count - confirm whether that is intended.
        self.input_size = self.dim_treatments + self.dim_static_features
        self.input_size += self.dim_vitals if self.has_vitals else 0
        self.input_size += self.dim_outcome if self.autoregressive else 0

        logger.info(f'Input size of {self.model_type}: {self.input_size}')
        self.output_size = self.dim_outcome

        self.propensity_treatment = propensity_treatment
        self.propensity_history = propensity_history

        # msm_regressor[tau] predicts the outcome tau steps after the last observed step
        self.msm_regressor = \
            [MultiOutputRegressor(LinearRegression()) for _ in range(
                self.dataset_collection.projection_horizon + 1)]
        self.save_hyperparameters(args)

    def get_inputs(self, dataset: Dataset, projection_horizon=0, tau=0) -> np.array:
        """
        Build the 2-D feature matrix of the outcome regressor.

        Args:
            dataset: Dataset whose ``data`` dict holds the arrays read below
            projection_horizon: Number of steps by which the history masks are shifted towards
                the start
            tau: Prediction horizon; the current treatments of ``tau + 1`` steps are summed into
                the last feature group. Must not exceed ``projection_horizon``.

        Returns:
            Column-wise concatenation of: summed previous treatments, the windowed vitals (if
            ``has_vitals``), the windowed previous outcomes (if ``autoregressive``), the static
            features and the summed current treatments.
        """
        active_entries = dataset.data['active_entries']
        n_rows = active_entries.shape[0]

        # Non-zero only on the last (lag_features + 1) active steps (assuming a contiguous
        # active prefix): mask minus its left-shifted, zero-padded copy
        lagged_entries = active_entries - np.concatenate([active_entries[:, self.lag_features + 1:, :],
                                                          np.zeros((n_rows, self.lag_features + 1, 1))], axis=1)
        if projection_horizon > 0:
            lagged_entries = np.concatenate([lagged_entries[:, projection_horizon:, :],
                                             np.zeros((n_rows, projection_horizon, 1))], axis=1)

        active_entries_before_projection = np.concatenate([active_entries[:, projection_horizon:, :],
                                                           np.zeros((n_rows, projection_horizon, 1))], axis=1)

        prev_treatments = dataset.data['prev_treatments']
        inputs = [(prev_treatments * active_entries_before_projection).sum(1)]
        if self.has_vitals:
            vitals = dataset.data['vitals']
            # The reshape requires exactly (lag_features + 1) selected steps in every row
            inputs.append(vitals[np.repeat(lagged_entries, self.dim_vitals, 2) == 1.0].reshape(
                vitals.shape[0], (self.lag_features + 1) * self.dim_vitals))
        if self.autoregressive:
            prev_outputs = dataset.data['prev_outputs']
            inputs.append(prev_outputs[np.repeat(lagged_entries, self.dim_outcome, 2) == 1.0].reshape(
                prev_outputs.shape[0], (self.lag_features + 1) * self.dim_outcome))
        static_features = dataset.data['static_features']
        inputs.append(static_features)

        # Adding current actions
        current_treatments = dataset.data['current_treatments']
        prediction_entries = active_entries - np.concatenate(
            [active_entries[:, tau + 1:, :], np.zeros((n_rows, tau + 1, 1))], axis=1)
        prediction_entries = np.concatenate([prediction_entries[:, projection_horizon - tau:, :],
                                             np.zeros((prediction_entries.shape[0], projection_horizon - tau, 1))], axis=1)
        inputs.append((current_treatments * prediction_entries).sum(1))

        return np.concatenate(inputs, axis=1)

    def get_sample_weights(self, dataset: Dataset, tau=0) -> np.array:
        """
        Return one regression sample weight per row of ``dataset``.

        The stabilized weights of the last ``tau + 1`` active steps are multiplied together and
        the products are clipped to their 1% and 99% quantiles (NaNs ignored when computing the
        quantiles).
        """
        active_entries = dataset.data['active_entries']
        stabilized_weights = dataset.data['stabilized_weights']

        prediction_entries = active_entries - np.concatenate(
            [active_entries[:, tau + 1:, :],
                np.zeros((active_entries.shape[0], tau + 1, 1))],
            axis=1)
        # Drop only the trailing singleton axis. A bare np.squeeze would also drop the row axis
        # of a single-row dataset and the boolean mask would no longer line up.
        selected = np.squeeze(prediction_entries, axis=-1) == 1.0
        stabilized_weights = stabilized_weights[selected].reshape(stabilized_weights.shape[0],
                                                                  tau + 1)
        sw = np.prod(stabilized_weights, axis=1)
        sw_tilde = np.clip(sw, np.nanquantile(
            sw, 0.01), np.nanquantile(sw, 0.99))
        return sw_tilde

    def prepare_data(self) -> None:
        """Process the dataset collection and compute the stabilized weights if still missing."""
        if self.dataset_collection is not None and not self.dataset_collection.processed_data_multi:
            self.dataset_collection.process_data_multi()
        if self.dataset_collection is not None and 'stabilized_weights' not in self.dataset_collection.train_f.data:
            self.dataset_collection.process_propensity_train_f(
                self.propensity_treatment, self.propensity_history)

    def fit(self):
        """Fit one weighted regressor per horizon ``tau`` in ``0..projection_horizon``."""
        self.prepare_data()
        for tau in range(self.dataset_collection.projection_horizon + 1):

            train_f = self.get_exploded_dataset(
                self.dataset_collection.train_f, min_length=self.lag_features + tau)
            active_entries = train_f.data['active_entries']
            # Mask minus its copy shifted one step to the left (zero-padded at the end)
            last_entries = active_entries - \
                np.concatenate([active_entries[:, 1:, :], np.zeros(
                    (active_entries.shape[0], 1, 1))], axis=1)

            # Inputs
            inputs = self.get_inputs(train_f, projection_horizon=tau, tau=tau)

            # Stabilized weights
            sw = self.get_sample_weights(train_f, tau=tau)

            # Outputs
            outputs = train_f.data['outputs']
            outputs = (outputs * last_entries).sum(1)

            self.msm_regressor[tau].fit(inputs, outputs, sample_weight=sw)

    def get_predictions(self, dataset: Dataset) -> np.array:
        """
        One-step-ahead predictions for every row and time step of ``dataset``.

        The dataset is processed in chunks of 10000 rows to bound the size of the exploded
        copies. The first ``lag_features`` time steps have too little history for a prediction
        and are filled with a copy of the first available one.

        Returns:
            Array with the shape of ``dataset.data['outputs']``.
        """
        logger.info(f'Predictions for {dataset.subset_name}.')
        batch_size = 10000
        outcome_pred = np.zeros_like(dataset.data['outputs'])
        n_rows = len(dataset)
        if n_rows == 0:
            return outcome_pred

        # Loop-invariant: every chunk is exploded up to the longest sequence of the full dataset
        max_length = max(dataset.data['sequence_lengths'][:])

        # Stepping by batch_size never yields an empty trailing chunk; the former
        # `len(dataset) // batch_size + 1` did whenever the length was an exact multiple of
        # batch_size, and the regressor rejects zero-sample input.
        for start in range(0, n_rows, batch_size):
            stop = start + batch_size
            subset = deepcopy(dataset)
            for (k, v) in subset.data.items():
                subset.data[k] = v[start:stop]

            exploded_dataset = self.get_exploded_dataset(subset, min_length=self.lag_features, only_active_entries=False,
                                                         max_length=max_length)
            inputs = self.get_inputs(
                exploded_dataset, projection_horizon=0, tau=0)
            outcome_pred_batch = self.msm_regressor[0].predict(inputs)

            # The exploded copy holds (time steps - lag_features) rows per original row, the
            # same layout MSM.get_propensity_scores folds back (was hard-coded to `- 1`).
            chunk_rows, n_steps = subset.data['active_entries'].shape[:2]
            outcome_pred_batch = outcome_pred_batch.reshape(chunk_rows,
                                                            n_steps - self.lag_features,
                                                            self.dim_outcome)
            # Leading steps lack the required history -> duplicating the next prediction
            leading = np.repeat(outcome_pred_batch[:, :1, :], self.lag_features, axis=1)
            outcome_pred_batch = np.concatenate([leading, outcome_pred_batch], axis=1)
            outcome_pred[start:stop] = outcome_pred_batch
        return outcome_pred

    def get_autoregressive_predictions(self, dataset: Dataset) -> np.array:
        """
        Multi-step predictions: column ``t - 1`` of the result holds the prediction of
        ``msm_regressor[t]`` for ``t`` in ``1..projection_horizon``.

        Returns:
            Array of shape (rows, projection_horizon, dim_outcome).
        """
        logger.info(f'Autoregressive Prediction for {dataset.subset_name}.')
        predicted_outputs = np.zeros(
            (len(dataset), self.hparams.dataset.projection_horizon, self.dim_outcome))

        for t in range(1, self.dataset_collection.projection_horizon + 1):
            inputs = self.get_inputs(
                dataset, projection_horizon=self.dataset_collection.projection_horizon - 1, tau=t - 1)
            outcome_pred = self.msm_regressor[t].predict(inputs)
            predicted_outputs[:, t - 1] = outcome_pred

        return predicted_outputs

In [None]:
import logging
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from pytorch_lightning import seed_everything

from src.models.utils import FilteringMlFlowLogger
from src.models.msm import MSM


# Emit INFO-level records so the progress messages of the run show up in the cell output
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Make float64 the default floating-point dtype for tensors created from here on
torch.set_default_dtype(torch.double)

# NOTE(review): this cell imports `MSM` from `src.models.msm`, rebinding the name defined in
# the class cell above, and the logged config lists `_target_: src.models.msm.*` for all three
# sub-models. `instantiate` therefore appears to build the packaged classes rather than the
# copies defined in this notebook - confirm which implementation is meant to run.

# Load the configuration file
# NOTE(review): absolute, machine-specific path; the notebook only runs where it exists.
config_path = '/work/postresearch/Shared/Researchers/Farbod/cancer/code/config/cancer_sim_MSM.yaml'
args = OmegaConf.load(config_path)
# Disabled Hydra entry-point decorator; the config is loaded with OmegaConf.load above instead
# @hydra.main(config_name=f'config.yaml', config_path='../config/')


def main(args: DictConfig):
    """
    Training / evaluation script for MSMs

    Fits the two propensity models and the MSM regressor on the dataset described by
    ``args.dataset``, then evaluates one-step and n-step normalised RMSEs on whichever test
    subsets the dataset collection exposes. Metrics and hyperparameters are sent to MLflow only
    when ``args.exp.logging`` is true.

    Args:
        args: arguments of run as DictConfig

    Returns: dict with results (one and multiple-step-ahead RMSEs)
    """

    results = {}

    # Non-strict access to fields
    OmegaConf.set_struct(args, False)
    OmegaConf.register_new_resolver("sum", lambda x, y: x + y, replace=True)
    logger.info('\n' + OmegaConf.to_yaml(args, resolve=True))

    # Initialisation of data to calculate dim_outcomes, dim_treatments, dim_vitals and dim_static_features
    seed_everything(args.exp.seed)
    dataset_collection = instantiate(args.dataset, _recursive_=True)
    # Only binary multilabel regime possible
    assert args.dataset.treatment_mode == 'multilabel'
    dataset_collection.process_data_multi()
    train_data = dataset_collection.train_f.data
    args.model.dim_outcomes = train_data['outputs'].shape[-1]
    args.model.dim_treatments = train_data['current_treatments'].shape[-1]
    args.model.dim_vitals = train_data['vitals'].shape[-1] if dataset_collection.has_vitals else 0
    args.model.dim_static_features = train_data['static_features'].shape[-1]

    # MlFlow Logger (None when logging is disabled; every use below is guarded accordingly)
    if args.exp.logging:
        experiment_name = f'{args.model.name}/{args.dataset.name}'
        mlf_logger = FilteringMlFlowLogger(filter_submodels=MSM.possible_model_types, experiment_name=experiment_name,
                                           tracking_uri=args.exp.mlflow_uri)
    else:
        mlf_logger = None

    # ============================== Numerator (treatment propensity model) ==============================
    propensity_treatment = instantiate(
        args.model.propensity_treatment, args, dataset_collection, _recursive_=False)
    # Previously called unconditionally, which raised AttributeError on None with logging off
    if mlf_logger is not None:
        mlf_logger.log_hyperparams(propensity_treatment.hparams)
    propensity_treatment.fit()

    # ============================== Denominator (history propensity model) ==============================
    propensity_history = instantiate(
        args.model.propensity_history, args, dataset_collection, _recursive_=False)
    if mlf_logger is not None:
        mlf_logger.log_hyperparams(propensity_history.hparams)
    propensity_history.fit()

    # ============================== Initialisation & Training of the MSM regressor ==============================
    msm_regressor = instantiate(args.model.msm_regressor, args, propensity_treatment, propensity_history, dataset_collection,
                                _recursive_=False)
    if mlf_logger is not None:
        mlf_logger.log_hyperparams(msm_regressor.hparams)
    msm_regressor.fit()
    encoder_results = {}

    if hasattr(dataset_collection, 'test_cf_one_step'):  # Test one_step_counterfactual rmse
        test_rmse_orig, test_rmse_all, test_rmse_last = \
            msm_regressor.get_normalised_masked_rmse(
                dataset_collection.test_cf_one_step, one_step_counterfactual=True)
        logger.info(f'Test normalised RMSE (all): {test_rmse_all}; '
                    f'Test normalised RMSE (orig): {test_rmse_orig}; '
                    f'Test normalised RMSE (only counterfactual): {test_rmse_last}')
        encoder_results = {
            'encoder_test_rmse_all': test_rmse_all,
            'encoder_test_rmse_orig': test_rmse_orig,
            'encoder_test_rmse_last': test_rmse_last
        }
    elif hasattr(dataset_collection, 'test_f'):  # Test factual rmse
        test_rmse_orig, test_rmse_all = msm_regressor.get_normalised_masked_rmse(
            dataset_collection.test_f)
        logger.info(f'Test normalised RMSE (all): {test_rmse_all}; '
                    f'Test normalised RMSE (orig): {test_rmse_orig}.')
        encoder_results = {
            # 'encoder_val_rmse_all': val_rmse_all,
            # 'encoder_val_rmse_orig': val_rmse_orig,
            'encoder_test_rmse_all': test_rmse_all,
            'encoder_test_rmse_orig': test_rmse_orig
        }

    if mlf_logger is not None:
        mlf_logger.log_metrics(encoder_results)
    results.update(encoder_results)

    test_rmses = {}
    if hasattr(dataset_collection, 'test_cf_treatment_seq'):  # Test n_step_counterfactual rmse
        test_rmses = msm_regressor.get_normalised_n_step_rmses(
            dataset_collection.test_cf_treatment_seq)
    elif hasattr(dataset_collection, 'test_f_multi'):  # Test n_step_factual rmse
        test_rmses = msm_regressor.get_normalised_n_step_rmses(
            dataset_collection.test_f_multi)
    # Entry k of the n-step results is reported as the (k + 2)-step-ahead error
    test_rmses = {f'{k+2}-step': v for (k, v) in enumerate(test_rmses)}

    logger.info(f'Test normalised RMSE (n-step prediction): {test_rmses}')
    decoder_results = {('decoder_test_rmse_' + k): v for (k, v) in test_rmses.items()}

    if mlf_logger is not None:
        mlf_logger.log_metrics(decoder_results)
    results.update(decoder_results)

    # Close the MLflow run so it does not stay in the "running" state
    if mlf_logger is not None:
        mlf_logger.experiment.set_terminated(mlf_logger.run_id)

    return results


# In a notebook kernel __name__ is "__main__", so the run starts whenever this cell executes;
# the guard keeps the code importable if it is moved back into a script/module.
if __name__ == "__main__":
    main(args)

INFO:__main__:
dataset:
  _target_: src.data.SyntheticCancerDatasetCollection
  name: tumor_generator
  coeff: 4
  chemo_coeff: 4
  radio_coeff: 4
  seed: 100
  num_patients:
    train: 10000
    val: 1000
    test: 1000
  window_size: 15
  lag: 0
  max_seq_length: 60
  projection_horizon: 5
  cf_seq_mode: sliding_treatment
  val_batch_size: 512
  treatment_mode: multilabel
model:
  dim_treatments: 4
  dim_vitals: 0
  dim_static_features: 1
  dim_outcomes: 1
  min_length: 1
  lag_features: 1
  name: MSM
  propensity_treatment:
    _target_: src.models.msm.MSMPropensityTreatment
  propensity_history:
    _target_: src.models.msm.MSMPropensityHistory
  msm_regressor:
    _target_: src.models.msm.MSMRegressor
exp:
  unscale_rmse: true
  percentage_rmse: true
  seed: 100
  max_epochs: 100
  gpus:
  - 0
  logging: true
  mlflow_uri: http://127.0.0.1:8081
  alpha: 1.0
  update_alpha: false
  alpha_rate: exp
  balancing: null
  bce_weight: false
  weights_ema: null
  beta: 0.99

Global seed s

Call to simulate counterfactuals data


  (1 + rho * np.log(K / (counterfactual_cancer_volume[current_t] + 1e-07) + 1e-07) -
100%|██████████| 1000/1000 [00:19<00:00, 50.24it/s]
INFO:src.data.cancer_sim.dataset:Processing train dataset before training
INFO:src.data.cancer_sim.dataset:Shape of processed train data: {'cancer_volume': (10000, 60), 'chemo_dosage': (10000, 60), 'radio_dosage': (10000, 60), 'chemo_application': (10000, 60), 'radio_application': (10000, 60), 'chemo_probabilities': (10000, 60), 'radio_probabilities': (10000, 60), 'sequence_lengths': (10000,), 'death_flags': (10000, 60), 'recovery_flags': (10000, 60), 'patient_types': (10000,), 'prev_treatments': (10000, 59, 2), 'current_treatments': (10000, 59, 2), 'current_covariates': (10000, 59, 2), 'outputs': (10000, 59, 1), 'active_entries': (10000, 59, 1), 'unscaled_outputs': (10000, 59, 1), 'prev_outputs': (10000, 59, 1), 'static_features': (10000, 1)}
INFO:src.data.cancer_sim.dataset:Processing val dataset before training
INFO:src.data.cancer_sim.dataset:Shap

🏃 View run useful-crow-778 at: http://127.0.0.1:8081/#/experiments/997620564669772123/runs/c71385cdc3c14e718e8138b77ee63478
🧪 View experiment at: http://127.0.0.1:8081/#/experiments/997620564669772123


{'encoder_test_rmse_all': 2.7317362036025155,
 'encoder_test_rmse_orig': 2.1275290025545286,
 'encoder_test_rmse_last': 1.966666032411937,
 'decoder_test_rmse_2-step': 3.8369426193723104,
 'decoder_test_rmse_3-step': 4.102038264183819,
 'decoder_test_rmse_4-step': 4.087582153521721,
 'decoder_test_rmse_5-step': 3.9701058554085225,
 'decoder_test_rmse_6-step': 3.7720720597800987}