In [2]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
import logging

# Setup paths
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.getcwd()))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Setup directories
MODEL_DIR = os.path.join(os.getcwd(), 'model')
RESULTS_DIR = os.path.join(os.getcwd(), 'results')
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Imports
from preprocessing.data_container import DataContainer
from utils.evaluation import cindex_score

In [51]:
from sklearn.base import BaseEstimator, RegressorMixin
import numpy as np
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
import torch
from lifelines.utils import concordance_index
from sklearn.utils.validation import check_X_y, check_is_fitted

logger = logging.getLogger(__name__)



class DeepSurvNet(nn.Module):
    """Neural network architecture for DeepSurv"""

    def __init__(self, n_features, hidden_layers=[32, 16], dropout=0.2):
        super().__init__()
        layers = []
        prev_size = n_features
        self.model = None

        # Build hidden layers
        for size in hidden_layers:
            layers.extend([
                nn.Linear(prev_size, size),
                nn.ReLU(),
                # BatchNorm1d nur bei größeren Batches verwenden
                nn.Dropout(dropout)
            ])
            prev_size = size

        # Output layer (1 node for hazard prediction)
        layers.append(nn.Linear(prev_size, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)



from sklearn.base import BaseEstimator, RegressorMixin
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_X_y, check_is_fitted
from lifelines.utils import concordance_index

class DeepSurvModel(BaseEstimator, RegressorMixin):
    def __init__(self, n_features=None, hidden_layers=[16, 16], dropout=0.5,
                 learning_rate=0.01, device='cpu', random_state=123):
        self.n_features = n_features
        self.hidden_layers = hidden_layers
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.device = device if torch.cuda.is_available() and device == 'cuda' else 'cpu'
        self.random_state = random_state
        torch.manual_seed(random_state)
        np.random.seed(random_state)

        self.scaler = StandardScaler()
        self.model = None
        self.is_fitted_ = False
        self.training_history_ = {'train_loss': [], 'val_loss': []}
        self.n_features_in_ = None

    def fit(self, X, y, num_epochs=10):
        # Input validation for X and y
        X, y = check_X_y(X, y, accept_sparse=True)
        
        self.n_features_in_ = X.shape[1]
        self.init_network(self.n_features_in_)
        self.model.to(self.device)
        
        # Prepare and scale data
        train_dataset_ = self._prepare_data(X, y)
        train_loader_ = DataLoader(train_dataset_, batch_size=128, shuffle=True)

        # Training loop
        for epoch in range(num_epochs):
            self.model.train()
            epoch_loss_ = 0
            n_batches_ = 0
            for X_batch, time_batch, event_batch in train_loader_:
                loss = self._train_step(X_batch, time_batch, event_batch)
                epoch_loss_ += loss
                n_batches_ += 1
            avg_train_loss = epoch_loss_ / n_batches_
            self.training_history_['train_loss'].append(avg_train_loss)
        
        self.is_fitted_ = True
        return self

    def predict(self, X):
        check_is_fitted(self, 'is_fitted_')
        X = torch.FloatTensor(self.scaler.transform(X)).to(self.device)
        self.model.eval()
        with torch.no_grad():
            risk_scores = self.model(X).cpu().numpy()
        return risk_scores.flatten()

    def score(self, X, y):
        check_is_fitted(self, 'is_fitted_')
        preds = self.predict(X)
        return self.c_index(-preds, y)

    def get_params(self, deep=True):
        return {
            "n_features": self.n_features,
            "hidden_layers": self.hidden_layers,
            "dropout": self.dropout,
            "learning_rate": self.learning_rate,
            "device": self.device,
            "random_state": self.random_state,
        }

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self
    
    def clone(self): 
        super(self).clone()

    def _prepare_data(self, X, y):
        X_scaled = self.scaler.fit_transform(X)
        times = np.ascontiguousarray(y['time']).astype(np.float32)
        event_field = 'status' if 'status' in y.dtype.names else 'event'
        events = np.ascontiguousarray(y[event_field]).astype(np.float32)
        
        X_tensor = torch.FloatTensor(X_scaled).to(self.device)
        time_tensor = torch.FloatTensor(times).to(self.device)
        event_tensor = torch.FloatTensor(events).to(self.device)
        return TensorDataset(X_tensor, time_tensor, event_tensor)

    def _negative_log_likelihood(self, risk_pred, times, events):
        _, idx = torch.sort(times, descending=True)
        risk_pred = risk_pred[idx]
        events = events[idx]
        log_risk = risk_pred
        #print("Risk predictions before exp:", risk_pred)
        risk = torch.exp(log_risk)
        cumsum_risk = torch.cumsum(risk, dim=0)
        log_cumsum_risk = torch.log(cumsum_risk + 1e-10)
        event_loss = events * (log_risk - log_cumsum_risk)
        return -torch.mean(event_loss)

    def _train_step(self, X, times, events):
        self.optimizer.zero_grad()
        risk_pred = self.model(X)
        loss = self._negative_log_likelihood(risk_pred, times, events)
        loss.backward()
        #print([param.grad.norm().item() for param in self.model.parameters() if param.grad is not None])

        self.optimizer.step()
        return loss.item()

    def c_index(self, risk_pred, y):
        if not isinstance(y, np.ndarray):
            y = y.detach().cpu().numpy()
        event_field = 'status' if 'status' in y.dtype.names else 'event'
        time = y['time']
        event = y[event_field]
        if not isinstance(risk_pred, np.ndarray):
            risk_pred = risk_pred.detach().cpu().numpy()
        return concordance_index(time, risk_pred, event)

    def init_network(self, n_features):
        self.model = DeepSurvNet(n_features=n_features, hidden_layers=self.hidden_layers).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)


In [100]:
"""
Resampling Module für Cross-Validation und Hyperparameter Tuning.
"""

import numpy as np
import pandas as pd
from sklearn.model_selection import LeaveOneGroupOut, KFold, GridSearchCV, RandomizedSearchCV
from sklearn.base import clone
from itertools import product
import logging
from utils.evaluation import cindex_score


logger = logging.getLogger(__name__)


def _get_survival_subset(y, indices):
    """Extract survival data subset while preserving structure"""
    subset = np.empty(len(indices), dtype=y.dtype)
    event_field = 'status' if 'status' in y.dtype.names else 'event'
    subset[event_field] = y[event_field][indices]
    subset['time'] = y['time'][indices]
    return subset

def _aggregate_results(results):
    """Aggregates nested CV results."""
    scores = [res['test_score'] for res in results]
    mean_score = np.mean(scores)
    std_score = np.std(scores)

    logger.info(f"Aggregated results:")
    logger.info(f"Mean score: {mean_score:.3f} ± {std_score:.3f}")
    logger.info(f"Individual scores: {scores}")

    return {
        'mean_score': mean_score,
        'std_score': std_score,
        'fold_results': results
    }

def nested_resampling(estimator, X, y, groups, param_grid, ss = GridSearchCV, outer_cv = LeaveOneGroupOut(), inner_cv = LeaveOneGroupOut(), scoring = None):
    logger.info("Starting nested resampling...")
    logger.info(f"Data shape: X={X.shape}, groups={len(np.unique(groups))} unique")

    outer_results = []

    for i, (train_idx, test_idx) in enumerate(outer_cv.split(X, y, groups)):
        logger.info(f"\nOuter fold {i+1}")

        X_train = X.iloc[train_idx]
        X_test = X.iloc[test_idx]
        y_train = _get_survival_subset(y, train_idx)
        y_test = _get_survival_subset(y, test_idx)
        train_groups = groups[train_idx] if groups is not None else None

        test_cohort = groups[test_idx][0] if groups is not None else None
        logger.info(f"Test cohort: {test_cohort}")
        
        inner_gcv = ss(estimator, param_grid, cv = inner_cv, refit = True, n_jobs=-1, verbose = 2)
        inner_results = inner_gcv.fit(X_train, y_train, groups = train_groups)
        
        inner_cv_results = inner_results.cv_results_
        inner_best_params = inner_results.best_params_
        
        outer_model = inner_results.best_estimator_.named_steps['model']
        test_score = outer_model.score(X_test, y_test)

        logger.info(f"Best parameters: {inner_best_params}")
        logger.info(f"Test score: {test_score:.3f}")

        outer_results.append({
            'test_cohort': test_cohort,
            'test_score': test_score,
            'best_params': inner_best_params,
            'inner_cv_results': inner_cv_results
        })

    return _aggregate_results(outer_results)

In [117]:
import os
import numpy as np
import pandas as pd
import logging
from sklearn.pipeline import Pipeline
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from utils.resampling import nested_resampling
from preprocessing.data_container import DataContainer
import pickle

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

class ModellingProcess(): 
    def __init__(self) -> None:
        self.outer_cv = LeaveOneGroupOut()
        self.inner_cv = LeaveOneGroupOut()
        self.ss = GridSearchCV
        self.pipe = None
        self.cmplt_model = None
        self.nrs = None
        self.X = None
        self.y = None
        self.groups = None
        pass
            
    def prepare_data(self, data_config, root): 
        self.dc = DataContainer(data_config=data_config, project_root=root)
        self.X, self.y = self.dc.load_data()
        self.groups = self.dc.get_groups()
    
    def do_modelling(self, pipeline_steps, config): 
        if config.get("params_mp", None) is not None: 
            self.set_params(config['params_mp'])
        
        err, mes = self._check_modelling_prerequs(pipeline_steps)
        if err: 
           logger.error("Requirements setup error: %s", mes)
           raise Exception(mes)
        else: 
            self.pipe = Pipeline(pipeline_steps) 
        
        param_grid, do_nested_resampling, refit_hp_tuning = self._get_config_vals(config)
        
        # TODO: do this
        #param_grid = self._prefix_pipeline_params(param_grid, pipeline_steps)
        #print(param_grid)

        try:
            logger.info("Start model training...")
            logger.info(f"Input data shape: X={self.X.shape}")
                        
            if do_nested_resampling: 
                logger.info("Nested resampling...")
                self.nrs = nested_resampling(self.pipe, self.X, self.y, self.groups, param_grid, self.ss, self.outer_cv, self.inner_cv)
        except Exception as e:
            logger.error(f"Error during nested resampling: {str(e)}")
            raise
        
        if refit_hp_tuning: 
            try:
                logger.info("Do HP Tuning for complete model; refit + set complete model")
                self.cmplt_model = self.fit_cmplt_model(param_grid)   
            except Exception as e:
                logger.error(f"Error during complete model training: {str(e)}")
                raise    
        elif refit_hp_tuning is False and do_nested_resampling is False: 
            logger.info("Fit complete model wo. HP tuning (on default params)")
            self.cmplt_model = self.pipe.fit(self.X, self.y)
        
        return self.nrs
    
    
    def fit_cmplt_model(self, param_grid): 
        logger.info("Do HP Tuning for complete model")
        res = self.ss(estimator=self.pipe, param_grid=param_grid, cv=self.outer_cv, n_jobs=-1, verbose = 2, refit = True)
        res.fit(self.X, self.y, groups = self.groups)
        return res.best_estimator_.named_steps['model']  
    
    
    def save_results(self, path, fname, model = None, cv_results = None,): 
        """Save model and results"""
        if model is None: 
            raise Warning("Won't save any model, since its not provided")   
        else:  
        # Create directories
            model_dir = os.path.join(path, 'model')
            os.makedirs(model_dir, exist_ok=True)
            with open(os.path.join(model_dir, f"{fname}.pkl"), 'wb') as f:
                pickle.dump(self.model, f)
        
        if cv_results is None: 
            raise Warning("Won't save any cv results, since its not provided")
        else: 
            results_dir = os.path.join(path, 'results')
            os.makedirs(results_dir, exist_ok=True)
            results_file = os.path.join(results_dir, f"{fname}_cv_results.csv")
            pd.DataFrame(self.cv_results).to_csv(results_file)
            logger.info(f"Saved CV results to {results_file}")


    def save_pipe(self): 
        pass
    
    def load_pipe(self): 
        pass
    
    def load_model(self): 
        pass

    
    def _check_modelling_prerequs(self, pipeline_steps): 
        err = False
        mes = ""
        if self.X is None or self.y is None: 
            mes = mes + "1) Please call prepare_data() with your preferred config or set X, y, and groups as attributes of your modelling process instance"
            err = True
        if not any('model' in tup for tup in pipeline_steps): 
            mes = mes + "2) Caution! Your pipline must include a named step for the model of the form ('model', <Instantiated Model Object>)"
            err = True
        return err, mes

    def _get_config_vals(self, config): 
        if config.get("params_cv", None) is None: 
            logger.warning("No param grid for (nested) resampling detected - will fit model with default HPs and on complete data")
            return None, False, False
        return config['params_cv'], config.get('do_nested_resampling', True) , config.get('refit', True)
    
    def set_params(self, params):
        for key, value in params.items():
            setattr(self, key, value) 
            
            
    def _prefix_pipeline_params(self, params, pipeline_steps):
        """Add pipeline component prefixes to parameters if not already present"""
        prefixed_params = {}
        for param, value in params.items():
            if '__' not in param:
                # Find the relevant step in pipeline_steps
                step_found = False
                for step_name, _ in pipeline_steps:
                    try:
                        # Try setting the parameter to check if it belongs to this step
                        self.model.named_steps[step_name].get_params()[param]
                        prefixed_params[f"{step_name}__{param}"] = value
                        step_found = True
                        break
                    except KeyError:
                        continue
                if not step_found:
                    raise ValueError(f"Could not determine pipeline step for parameter: {param}")
            else:
                prefixed_params[param] = value
        return prefixed_params

In [59]:
from sksurv.ensemble import RandomSurvivalForest

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Data configuration
DATA_CONFIG = {
    'use_pca': False,
    'gene_type': 'intersection',
    'use_imputed': True,
    'use_cohorts': False
}

# Model configuration
MODEL_CONFIG = {
    'params_cv': {
        'model__n_estimators': [2, 4],
        'model__min_samples_split': [10]
    },
    'refit': True, 
    'do_nested_resampling': True}

In [7]:
MODEL_CONFIG['params_cv']

{'model__n_estimators': [2, 4], 'model__min_samples_split': [10]}

In [None]:
rsf_pipe = [('model', RandomSurvivalForest())]

mp = ModellingProcess()

In [None]:
mp.prepare_data(DATA_CONFIG, PROJECT_ROOT)

2024-11-14 20:52:33,166 - INFO - Loading data...
2024-11-14 20:56:50,400 - INFO - Loaded data: 1091 samples, 13214 features


In [114]:
X = mp.X
y = mp.y
groups = mp.groups

In [118]:
mp_2 = ModellingProcess()
mp_2.X = X
mp_2.y = y
mp_2.groups = groups

ds_pipe = [('model', DeepSurvModel())]

In [119]:
param_grid  = {
        'model__hidden_layers': [[16, 16]],
        'model__learning_rate': [0.01],
        'model__batch_size': [64, 256], 
        'model__n_features' : [mp_2.X.shape[1]], 
        'model__num_epochs': [10]
    }


In [123]:
MODEL_CONFIG = {
    'params_cv': {
        'model__n_estimators': [2, 4],
        'model__min_samples_split': [10]
    },
    'refit': False, 
    'do_nested_resampling': False}

mp_2.do_modelling(rsf_pipe, MODEL_CONFIG)

2024-11-14 23:21:21,560 - INFO - Start model training...
2024-11-14 23:21:21,560 - INFO - Input data shape: X=(1091, 13214)
2024-11-14 23:21:21,560 - INFO - Fit complete model wo. HP tuning (on default params)


KeyboardInterrupt: 

In [43]:
ds = DeepSurvModel()

param_grid  = {
        'hidden_layers': [[16, 16]],
        'learning_rate': [0.01]
    }


In [39]:
ds.fit(mp_2.X, mp_2.y)

Risk predictions before exp: tensor([[ 0.1237],
        [ 0.0863],
        [ 0.0615],
        [-0.1040],
        [ 0.0559],
        [-0.0805],
        [ 0.0840],
        [-0.0597],
        [-0.0068],
        [ 0.1262],
        [-0.0063],
        [-0.0494],
        [ 0.0115],
        [-0.0733],
        [-0.1470],
        [ 0.1218],
        [-0.1353],
        [-0.0154],
        [-0.1491],
        [-0.0927],
        [ 0.0146],
        [-0.1625],
        [ 0.0091],
        [-0.0495],
        [ 0.0360],
        [-0.0128],
        [ 0.0131],
        [-0.2086],
        [ 0.0088],
        [-0.1718],
        [-0.0274],
        [-0.0354],
        [-0.0292],
        [-0.0516],
        [ 0.0401],
        [-0.0518],
        [ 0.0006],
        [ 0.0558],
        [-0.0063],
        [ 0.0539],
        [ 0.0325],
        [-0.0767],
        [-0.1875],
        [-0.0610],
        [-0.0441],
        [ 0.1265],
        [-0.1549],
        [-0.0101],
        [ 0.0239],
        [ 0.0300],
        [-0.1276],
  

In [44]:
ds.score(mp_2.X, mp_2.y)

NotFittedError: This StandardScaler instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

In [45]:
gs = GridSearchCV(ds, param_grid=param_grid, cv=KFold())
gs.fit(mp_2.X, mp_2.y)

Risk predictions before exp: tensor([[ 0.2042],
        [ 0.3279],
        [ 0.1914],
        [ 0.0943],
        [-0.0227],
        [ 0.1245],
        [ 0.0963],
        [ 0.1092],
        [ 0.1252],
        [ 0.1307],
        [ 0.1744],
        [ 0.2523],
        [ 0.1032],
        [ 0.1210],
        [ 0.1102],
        [ 0.2304],
        [ 0.1323],
        [ 0.1862],
        [ 0.0695],
        [ 0.1448],
        [ 0.0947],
        [ 0.2575],
        [ 0.0869],
        [ 0.0414],
        [ 0.1057],
        [ 0.1749],
        [ 0.1534],
        [ 0.0428],
        [ 0.1768],
        [ 0.0785],
        [ 0.1527],
        [-0.0005],
        [ 0.0835],
        [ 0.2211],
        [ 0.1144],
        [ 0.2892],
        [ 0.3105],
        [ 0.1179],
        [ 0.1144],
        [ 0.1155],
        [ 0.1246],
        [ 0.1280],
        [ 0.0325],
        [ 0.1635],
        [ 0.1589],
        [ 0.0950],
        [ 0.1707],
        [ 0.1223],
        [ 0.0720],
        [ 0.1192],
        [ 0.1144],
  



Risk predictions before exp: tensor([[ 0.1130],
        [ 0.0412],
        [ 0.0309],
        [ 0.1678],
        [ 0.1243],
        [ 0.1358],
        [ 0.0125],
        [ 0.0980],
        [ 0.1622],
        [ 0.1231],
        [ 0.0679],
        [ 0.1332],
        [ 0.3408],
        [ 0.1673],
        [ 0.1411],
        [ 0.1121],
        [ 0.1755],
        [ 0.0650],
        [ 0.0417],
        [ 0.2066],
        [ 0.0927],
        [ 0.1861],
        [ 0.2417],
        [ 0.0143],
        [ 0.1517],
        [ 0.0354],
        [ 0.1163],
        [ 0.1106],
        [ 0.1103],
        [ 0.0895],
        [ 0.1061],
        [ 0.1053],
        [ 0.1096],
        [ 0.1304],
        [ 0.1646],
        [ 0.1070],
        [ 0.1122],
        [-0.0017],
        [ 0.0654],
        [ 0.0999],
        [ 0.1037],
        [ 0.0990],
        [ 0.0625],
        [ 0.0855],
        [ 0.0731],
        [ 0.2374],
        [ 0.1780],
        [ 0.0939],
        [ 0.1578],
        [ 0.1284],
        [ 0.1214],
  



Risk predictions before exp: tensor([[ 1.6901e-01],
        [ 1.2360e-01],
        [ 1.3673e-01],
        [ 1.3901e-02],
        [ 9.1847e-02],
        [ 1.6318e-01],
        [ 1.2250e-01],
        [ 6.7454e-02],
        [ 1.3608e-01],
        [ 3.4190e-01],
        [ 1.6690e-01],
        [ 2.8805e-01],
        [ 1.0851e-01],
        [ 1.7624e-01],
        [ 6.2412e-02],
        [ 1.4313e-01],
        [ 4.1877e-02],
        [ 5.6697e-02],
        [ 2.0446e-01],
        [ 2.4112e-01],
        [ 1.0853e-01],
        [-7.2402e-03],
        [ 1.3504e-02],
        [ 1.5136e-01],
        [ 2.8785e-02],
        [ 2.6577e-01],
        [ 2.3270e-01],
        [ 1.7549e-01],
        [ 1.3744e-01],
        [ 1.9574e-01],
        [ 1.1099e-01],
        [ 7.2577e-02],
        [ 3.6470e-02],
        [ 9.4274e-02],
        [ 1.1665e-01],
        [ 1.1251e-01],
        [-4.2927e-02],
        [ 1.0690e-01],
        [ 5.4909e-02],
        [ 1.2724e-01],
        [ 1.0877e-01],
        [-3.9651e-03],
     



Risk predictions before exp: tensor([[ 0.1239],
        [ 0.1691],
        [ 0.1244],
        [ 0.1359],
        [ 0.0130],
        [ 0.0931],
        [ 0.1622],
        [ 0.1225],
        [ 0.0675],
        [ 0.1353],
        [ 0.3407],
        [ 0.1663],
        [ 0.2859],
        [ 0.1084],
        [ 0.1762],
        [ 0.0622],
        [ 0.0300],
        [ 0.1438],
        [ 0.2039],
        [ 0.0563],
        [ 0.2422],
        [ 0.1106],
        [-0.0060],
        [ 0.1080],
        [ 0.0132],
        [ 0.1505],
        [ 0.1396],
        [ 0.2648],
        [ 0.0286],
        [ 0.2313],
        [ 0.1759],
        [ 0.1374],
        [ 0.1958],
        [ 0.1109],
        [ 0.0720],
        [ 0.0366],
        [ 0.1165],
        [ 0.0941],
        [-0.0434],
        [ 0.1130],
        [ 0.1023],
        [ 0.1349],
        [ 0.0563],
        [ 0.2315],
        [ 0.1083],
        [ 0.0658],
        [ 0.1214],
        [ 0.0869],
        [ 0.1173],
        [ 0.3530],
        [ 0.1008],
  



Risk predictions before exp: tensor([[ 0.1241],
        [ 0.1607],
        [ 0.1367],
        [ 0.0137],
        [ 0.0932],
        [ 0.1624],
        [ 0.0675],
        [ 0.1225],
        [ 0.1356],
        [ 0.3409],
        [ 0.1659],
        [ 0.2863],
        [ 0.1093],
        [ 0.0293],
        [ 0.1435],
        [ 0.0623],
        [ 0.0564],
        [ 0.2036],
        [ 0.2425],
        [ 0.1113],
        [-0.0068],
        [ 0.1074],
        [ 0.1395],
        [ 0.2655],
        [ 0.0285],
        [ 0.2313],
        [ 0.1757],
        [ 0.1372],
        [ 0.1948],
        [ 0.1103],
        [ 0.0722],
        [ 0.0364],
        [ 0.0941],
        [ 0.1168],
        [-0.0437],
        [ 0.1126],
        [ 0.1021],
        [ 0.1706],
        [ 0.0433],
        [ 0.1352],
        [ 0.0555],
        [ 0.2311],
        [ 0.1083],
        [ 0.0707],
        [ 0.0870],
        [ 0.1215],
        [ 0.1175],
        [ 0.3537],
        [ 0.0804],
        [ 0.1568],
        [ 0.1278],
  



Risk predictions before exp: tensor([[ 0.1677],
        [ 0.1175],
        [ 0.0429],
        [ 0.1572],
        [ 0.2749],
        [ 0.3947],
        [ 0.1736],
        [-0.0272],
        [ 0.0822],
        [ 0.1079],
        [ 0.1526],
        [ 0.1215],
        [ 0.0479],
        [ 0.0232],
        [ 0.1088],
        [ 0.2021],
        [ 0.1335],
        [ 0.1171],
        [ 0.1732],
        [ 0.1624],
        [ 0.1781],
        [ 0.0941],
        [ 0.1179],
        [ 0.0767],
        [ 0.0233],
        [ 0.1533],
        [ 0.1054],
        [ 0.0367],
        [ 0.0468],
        [ 0.1762],
        [ 0.0127],
        [ 0.0627],
        [ 0.0759],
        [ 0.0506],
        [ 0.1394],
        [ 0.2117],
        [ 0.1385],
        [ 0.0894],
        [ 0.2799],
        [ 0.0588],
        [ 0.1896],
        [ 0.0324],
        [ 0.1998],
        [ 0.1300],
        [ 0.2213],
        [ 0.0981],
        [ 0.1998],
        [ 0.1356],
        [ 0.1007],
        [ 0.2251],
        [ 0.1240],
  