In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HybridModel(nn.Module):
    """Diffusion imputer followed by per-feature categorical classification heads.

    The wrapped ``diffusion_imputer`` produces imputed samples. Their numerical
    columns are returned as-is. Their categorical columns go through a small
    per-feature MLP, a shared feed-forward stack, and one output layer per
    categorical feature.

    Args:
        diffusion_imputer: object exposing ``eval_with_grad``. It is called as
            ``eval_with_grad(data)`` in ``forward`` and as
            ``eval_with_grad(data, imputation_mask)`` in ``eval``. It must
            return a 3-tuple whose first element is the imputed tensor.
        categorical_features_info: list of dicts, one per categorical feature,
            with keys ``'indices'`` (columns in the input), ``'indices_target'``
            (columns in the target), ``'embedding_dim'``, ``'hidden_dim'`` and
            ``'num_classes'``. All features must share ``num_classes`` because
            their logits are stacked into one tensor.
            NOTE(review): each feature's ``'indices'`` is assumed to list exactly
            ``embedding_dim`` columns; the head slices by ``embedding_dim``.
        num_layers: number of shared hidden layers after ``first_layer_ff``.
        device: device that categorical targets are moved to in ``loss_func``.
        debug: when True, print intermediate tensors and losses (off by default).
    """

    # Class-level default so instances unpickled from older versions still work.
    debug = False

    def __init__(self, diffusion_imputer, categorical_features_info, num_layers, device="cuda", debug=False):
        super().__init__()
        self.diffusion_imputer = diffusion_imputer
        self.categorical_features_info = categorical_features_info
        self.device = device
        self.debug = debug

        # Flattened column indices of all categorical features in the input
        # tensor ('indices') and in the target tensor ('indices_target').
        self.categorical_indices = [idx for info in categorical_features_info for idx in info['indices']]
        self.target_categorical_indices = [idx for info in categorical_features_info for idx in info['indices_target']]

        # Per-feature two-layer MLP: embedding_dim -> hidden_dim -> hidden_dim.
        # Attribute names and creation order are kept for state_dict compatibility.
        self.first_linear_layers = nn.ModuleList([
            nn.Linear(info['embedding_dim'], info['hidden_dim'])
            for info in categorical_features_info
        ])
        self.second_linear_layers = nn.ModuleList([
            nn.Linear(info['hidden_dim'], info['hidden_dim'])
            for info in categorical_features_info
        ])
        combined_dim = sum(info['hidden_dim'] for info in categorical_features_info)

        # Shared stack over the concatenated per-feature hidden states.
        self.first_layer_ff = nn.Linear(combined_dim, combined_dim)
        self.classification_layers = nn.ModuleList([
            nn.Linear(combined_dim, combined_dim) for _ in range(num_layers)
        ])
        # One classifier per categorical feature.
        self.output_layers = nn.ModuleList([
            nn.Linear(combined_dim, info['num_classes']) for info in categorical_features_info
        ])

    @staticmethod
    def _split_columns(tensor, categorical_indices, width):
        """Return ``(categorical, numerical)`` column subsets of ``tensor`` along dim 2.

        Numerical columns are every index in ``range(width)`` that is not listed in
        ``categorical_indices``, in ascending order.
        """
        excluded = set(categorical_indices)
        numerical_indices = [i for i in range(width) if i not in excluded]
        return tensor[:, :, categorical_indices], tensor[:, :, numerical_indices]

    def _categorical_logits(self, categorical_inputs):
        """Map categorical input columns to stacked per-feature logits.

        ``categorical_inputs`` is split along dim 2 into consecutive chunks of
        ``embedding_dim`` columns, one per feature, in ``categorical_features_info``
        order. Returns a tensor of shape
        ``(*categorical_inputs.shape[:2], n_categorical_features, num_classes)``.
        """
        hidden_per_feature = []
        start = 0
        for first_layer, second_layer, info in zip(
            self.first_linear_layers, self.second_linear_layers, self.categorical_features_info
        ):
            end = start + info['embedding_dim']
            hidden = F.relu(first_layer(categorical_inputs[:, :, start:end]))
            hidden_per_feature.append(F.relu(second_layer(hidden)))
            start = end

        combined = F.relu(self.first_layer_ff(torch.cat(hidden_per_feature, dim=2)))
        for layer in self.classification_layers:
            combined = F.relu(layer(combined))

        # stack(dim=2) == cat of unsqueeze(2); requires all heads to share num_classes.
        return torch.stack([output_layer(combined) for output_layer in self.output_layers], dim=2)

    def forward(self, data, target_data):
        """Impute ``data`` and classify its categorical columns.

        Features live on dim 2 of both ``data`` and ``target_data``.

        Returns:
            Tuple ``(imputed_numerical, categorical_logits, target_numerical,
            target_categorical)``. ``categorical_logits`` has shape
            ``(dim0, dim1, n_categorical_features, num_classes)``.
        """
        imputed_samples, _, _ = self.diffusion_imputer.eval_with_grad(data)
        # NOTE(review): unlike ``eval``, the raw imputed samples are used here.
        # Observed entries are not restored via the imputation mask. Confirm
        # that this train/eval asymmetry is intentional.

        target_categorical, target_numerical = self._split_columns(
            target_data, self.target_categorical_indices, target_data.shape[2]
        )
        imputed_categorical, imputed_numerical = self._split_columns(
            imputed_samples, self.categorical_indices, data.shape[2]
        )
        if self.debug:
            print(imputed_categorical[0])

        categorical_logits = self._categorical_logits(imputed_categorical)
        return imputed_numerical, categorical_logits, target_numerical, target_categorical

    def loss_func(self, outputs, targets):
        """MSE on numerical columns plus summed per-feature cross-entropy.

        Args:
            outputs: ``(imputed_numerical, categorical_logits)``, i.e. the first
                two elements returned by ``forward``.
            targets: ``(target_numerical, target_categorical)``, i.e. the last
                two elements returned by ``forward``. Categorical targets hold
                class indices.

        Returns:
            Scalar tensor ``mse + sum_i cross_entropy_i``.
        """
        imputed_numerical, categorical_logits = outputs
        target_numerical, target_categorical = targets

        loss_numerical = F.mse_loss(imputed_numerical, target_numerical)

        # Flatten the two leading dims so each categorical feature becomes a
        # standard (N, num_classes) vs (N,) cross-entropy problem.
        n_features, n_classes = categorical_logits.shape[2], categorical_logits.shape[3]
        flat_logits = categorical_logits.reshape(-1, n_features, n_classes)
        flat_targets = target_categorical.reshape(-1, target_categorical.shape[2]).long().to(self.device)

        loss_categorical = sum(
            F.cross_entropy(flat_logits[:, i, :], flat_targets[:, i])
            for i in range(n_features)
        )

        if self.debug:
            print(flat_targets[0])
            print(flat_logits[0])
            print(loss_numerical)
            print(loss_categorical)
        return loss_numerical + loss_categorical

    def eval(self, data=None, imputation_mask=None):
        """Predict with imputer + classifier, or switch to evaluation mode.

        With no arguments this is ``nn.Module.eval()``: it sets
        ``training=False`` recursively and returns ``self``. Standard PyTorch
        code such as ``model.eval()`` therefore keeps working.

        With ``data`` and ``imputation_mask``, it runs inference under
        ``torch.no_grad()``. Entries where ``imputation_mask != 0`` come from the
        imputer's output and the rest from ``data``. The module's training flag
        is not changed.

        Returns:
            Either ``self`` (no-argument form) or a tensor. In the tensor, dim 2
            holds the numerical columns followed by one predicted class index
            per categorical feature, cast to the numerical columns' dtype.

        Raises:
            TypeError: if only one of ``data`` / ``imputation_mask`` is given.
        """
        if data is None and imputation_mask is None:
            return super().eval()
        if data is None or imputation_mask is None:
            raise TypeError("eval() requires both `data` and `imputation_mask` (or neither)")

        with torch.no_grad():
            imputed_samples, _, _ = self.diffusion_imputer.eval_with_grad(data, imputation_mask)
            merged = torch.where(imputation_mask != 0, imputed_samples, data)

            merged_categorical, merged_numerical = self._split_columns(
                merged, self.categorical_indices, data.shape[2]
            )
            logits = self._categorical_logits(merged_categorical)
            # Argmax over the class axis (last dim): one class id per categorical feature.
            class_ids = logits.argmax(dim=-1).to(merged_numerical.dtype)
            return torch.cat([merged_numerical, class_ids], dim=2)


In [29]:
import numpy as np

# Toy cohort: 2 patients x 5 time steps x 3 channels. Values are consecutive
# integers so each transform's row/step selection is easy to read off.
data = {
    # patient 1 holds 1..15, patient 2 holds 16..30 (row-major)
    'outputs': np.arange(1, 31).reshape(2, 5, 3),
    # same grid shifted down by one: 0..14 and 15..29
    'prev_outputs': np.arange(30).reshape(2, 5, 3),
    'sequence_lengths': np.array([5, 5]),  # no padding: both patients are full length
    'active_entries': np.ones((2, 5, 3)),
    # treatment code k at step k (patient 1: 0..4, patient 2: 5..9), repeated over 3 channels
    'current_treatments': np.repeat(np.arange(10), 3).reshape(2, 5, 3),
    # alternating 0/1 pattern; patient 2 is the element-wise complement of patient 1
    'prev_treatments': (np.arange(30) % 2).reshape(2, 5, 3),
    'static_features': np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
}

# Placeholder per-channel scaling used to produce "unscaled" outputs.
scaling_params = dict(
    output_means=np.array([1, 2, 3]),
    output_stds=np.array([0.5, 0.5, 0.5]),
)

projection_horizon = 2  # number of steps ahead to project

In [30]:

def explode_trajectories(data, projection_horizon, scaling_params):
    """Expand every patient trajectory into all of its growing prefixes.

    For patient ``i`` with length ``L``, one row is emitted for each ``t`` in
    ``range(projection_horizon, L)``. That row holds the first ``t + 1`` time
    steps of every time-varying array, zero-padded to the input's time length.
    The patient's static features are copied onto each of its rows.

    Args:
        data: dict with 3-D time-varying arrays ``'outputs'``, ``'prev_outputs'``,
            ``'active_entries'``, ``'current_treatments'`` and
            ``'prev_treatments'`` (patient x time x channel), plus
            ``'static_features'`` (one row per patient) and
            ``'sequence_lengths'``.
        projection_horizon: prefixes ending before index ``projection_horizon``
            are not produced.
        scaling_params: dict with ``'output_means'`` and ``'output_stds'``,
            used for ``'unscaled_outputs' = outputs * stds + means``.

    Returns:
        dict of float arrays whose leading dimension is the number of emitted rows.
    """
    time_varying = ('prev_treatments', 'current_treatments', 'prev_outputs', 'outputs', 'active_entries')

    n_patients, max_len, _ = data['outputs'].shape
    capacity = n_patients * max_len  # upper bound on the number of emitted rows

    padded = {key: np.zeros((capacity, max_len, data[key].shape[-1])) for key in time_varying}
    static_rows = np.zeros((capacity, data['static_features'].shape[-1]))
    length_rows = np.zeros(capacity)

    n_rows = 0
    for patient in range(n_patients):
        for t in range(projection_horizon, int(data['sequence_lengths'][patient])):
            prefix = t + 1
            for key in time_varying:
                padded[key][n_rows, :prefix, :] = data[key][patient, :prefix, :]
            length_rows[n_rows] = prefix
            static_rows[n_rows] = data['static_features'][patient]
            n_rows += 1

    # Drop the unused tail of the pre-allocated buffers.
    trimmed = {key: buf[:n_rows] for key, buf in padded.items()}

    return {
        'prev_treatments': trimmed['prev_treatments'],
        'current_treatments': trimmed['current_treatments'],
        'static_features': static_rows[:n_rows],
        'prev_outputs': trimmed['prev_outputs'],
        'outputs': trimmed['outputs'],
        'unscaled_outputs': trimmed['outputs'] * scaling_params['output_stds'] + scaling_params['output_means'],
        'sequence_lengths': length_rows[:n_rows],
        'active_entries': trimmed['active_entries'],
    }



In [35]:
# Dump every toy array with its shape.
for key in data:
    value = data[key]
    print(key, value.shape, value)

outputs (2, 5, 3) [[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]
  [10 11 12]
  [13 14 15]]

 [[16 17 18]
  [19 20 21]
  [22 23 24]
  [25 26 27]
  [28 29 30]]]
prev_outputs (2, 5, 3) [[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]
  [ 9 10 11]
  [12 13 14]]

 [[15 16 17]
  [18 19 20]
  [21 22 23]
  [24 25 26]
  [27 28 29]]]
sequence_lengths (2,) [5 5]
active_entries (2, 5, 3) [[[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]]
current_treatments (2, 5, 3) [[[0 0 0]
  [1 1 1]
  [2 2 2]
  [3 3 3]
  [4 4 4]]

 [[5 5 5]
  [6 6 6]
  [7 7 7]
  [8 8 8]
  [9 9 9]]]
prev_treatments (2, 5, 3) [[[0 1 0]
  [1 0 1]
  [0 1 0]
  [1 0 1]
  [0 1 0]]

 [[1 0 1]
  [0 1 0]
  [1 0 1]
  [0 1 0]
  [1 0 1]]]
static_features (2, 3) [[0.1 0.2 0.3]
 [0.4 0.5 0.6]]


In [31]:
# Run the function
new_data = explode_trajectories(data, projection_horizon, scaling_params)

# Print the results for illustration
output_results = {}
for key, value in new_data.items():
    output_results[key] = value.shape, value

output_results

{'prev_treatments': ((6, 5, 3),
  array([[[0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
  
         [[0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.],
          [0., 0., 0.]],
  
         [[0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 0.]],
  
         [[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.],
          [0., 0., 0.],
          [0., 0., 0.]],
  
         [[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 0.],
          [0., 0., 0.]],
  
         [[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]])),
 'current_treatments': ((6, 5, 3),
  array([[[0., 0., 0.],
          [1., 1., 1.],
          [2., 2., 2.],
          [0., 0., 0.],
          [0., 0., 0.]],
  
         [[0., 0., 0.],
          [1., 1., 1.],
          [2., 2.

In [40]:
from copy import deepcopy

def process_sequential_test(data, projection_horizon, encoder_r=None, save_encoder_r=False, scaling_params=None):
    """
    Pre-process test dataset for multiple-step-ahead prediction: takes the last n-steps according to the projection horizon.

    For each patient with length ``L``, the first ``L - projection_horizon`` steps
    are the factual history. The last ``projection_horizon`` steps form the
    prediction window.

    Args:
        data: dict of per-patient arrays. It reads ``'sequence_lengths'``,
            3-D ``'outputs'``, ``'current_treatments'``, ``'prev_treatments'``
            and ``'current_covariates'``. It passes ``'patient_types'``,
            ``'patient_ids_all_trajectories'`` and ``'patient_current_t'``
            through unchanged. ``data`` itself is not modified.
        projection_horizon: number of trailing steps in the prediction window.
        encoder_r: optional 3-D array (patient x step x state). Its entry at the
            last factual step is returned as ``'init_state'``.
        save_encoder_r: kept for interface compatibility only. It has no effect:
            the original computed a truncated ``encoder_r`` but never returned or
            stored it.
        scaling_params: dict with ``'output_means'`` / ``'output_stds'`` used
            for ``'unscaled_outputs'``. When omitted, the module-level
            (notebook) global ``scaling_params`` is used, as before.

    Returns:
        dict of seq2seq arrays. ``'init_state'`` is added only when
        ``encoder_r`` is given.

    Raises:
        ValueError: if a patient's sequence length does not exceed
            ``projection_horizon``, leaving no factual step to condition on.
    """
    if scaling_params is None:
        # Backward compatibility: this function used to read the notebook-level
        # `scaling_params` global implicitly. Prefer passing it explicitly.
        try:
            scaling_params = globals()['scaling_params']
        except KeyError:
            raise NameError(
                "name 'scaling_params' is not defined; pass scaling_params explicitly"
            ) from None

    sequence_lengths = data['sequence_lengths']
    outputs = data['outputs']
    current_treatments = data['current_treatments']
    previous_treatments = data['prev_treatments'][:, 1:, :]  # Without zero_init_treatment
    current_covariates = data['current_covariates']

    num_patient_points, max_seq_length, _ = outputs.shape

    if encoder_r is not None:
        seq2seq_state_inits = np.zeros((num_patient_points, encoder_r.shape[-1]))
    # 1.0 marks the encoder steps that fall inside each patient's factual history.
    seq2seq_active_encoder_r = np.zeros((num_patient_points, max_seq_length - projection_horizon))
    seq2seq_previous_treatments = np.zeros((num_patient_points, projection_horizon, previous_treatments.shape[-1]))
    seq2seq_current_treatments = np.zeros((num_patient_points, projection_horizon, current_treatments.shape[-1]))
    seq2seq_current_covariates = np.zeros((num_patient_points, projection_horizon, current_covariates.shape[-1]))
    seq2seq_outputs = np.zeros((num_patient_points, projection_horizon, outputs.shape[-1]))
    # Every step of every prediction window is active and every window has length projection_horizon.
    seq2seq_active_entries = np.ones((num_patient_points, projection_horizon, 1))
    seq2seq_sequence_lengths = np.full(num_patient_points, float(projection_horizon))

    for i in range(num_patient_points):
        sequence_length = int(sequence_lengths[i])
        fact_length = sequence_length - projection_horizon  # number of factual steps
        if fact_length < 1:
            # fact_length - 1 below would be negative and silently wrap around to
            # the end of the sequence instead of failing.
            raise ValueError(
                f"patient {i}: sequence length {sequence_length} must exceed "
                f"projection_horizon={projection_horizon}"
            )
        window = slice(fact_length, fact_length + projection_horizon)

        if encoder_r is not None:
            # Initial state = encoder representation at the last factual step.
            seq2seq_state_inits[i] = encoder_r[i, fact_length - 1]
        seq2seq_active_encoder_r[i, :fact_length] = 1.0

        # previous_treatments had its first (zero-init) step dropped above,
        # hence the -1 offset relative to the prediction window.
        seq2seq_previous_treatments[i] = previous_treatments[i, fact_length - 1:fact_length + projection_horizon - 1, :]
        seq2seq_current_treatments[i] = current_treatments[i, window, :]
        seq2seq_outputs[i] = outputs[i, window, :]
        # Covariates are frozen at the last factual step for the whole window.
        seq2seq_current_covariates[i] = current_covariates[i, fact_length - 1]

    seq2seq_data = {
        'active_encoder_r': seq2seq_active_encoder_r,
        'prev_treatments': seq2seq_previous_treatments,
        'current_treatments': seq2seq_current_treatments,
        'current_covariates': seq2seq_current_covariates,
        # NOTE(review): assumes covariate column 0 is the previous output and the
        # remaining columns are static features. Confirm against the data source.
        'prev_outputs': seq2seq_current_covariates[:, :, :1],
        'static_features': seq2seq_current_covariates[:, 0, 1:],
        'outputs': seq2seq_outputs,
        'sequence_lengths': seq2seq_sequence_lengths,
        'active_entries': seq2seq_active_entries,
        'unscaled_outputs': seq2seq_outputs * scaling_params['output_stds'] + scaling_params['output_means'],
        'patient_types': data['patient_types'],
        'patient_ids_all_trajectories': data['patient_ids_all_trajectories'],
        'patient_current_t': data['patient_current_t'],
    }
    if encoder_r is not None:
        seq2seq_data['init_state'] = seq2seq_state_inits

    return seq2seq_data

# Example initial data setup for process_sequential_test: extend `data` in place
# with the extra keys that function reads.
data.update({
    # same alternating 0/1 previous-treatment pattern as the setup cell
    'prev_treatments': (np.arange(30) % 2).reshape(2, 5, 3),
    # dummy covariates: value k at step k (patient 1: 0..4, patient 2: 5..9)
    'current_covariates': np.repeat(np.arange(10), 3).reshape(2, 5, 3),
    # bookkeeping keys that process_sequential_test passes through unchanged
    'patient_types': np.array([0, 1]),
    'patient_ids_all_trajectories': np.array([101, 102]),
    'patient_current_t': np.array([0, 0]),
})

processed_data = process_sequential_test(data, projection_horizon)

processed_data_shapes = dict((name, array.shape) for name, array in processed_data.items())
processed_data_shapes

{'active_encoder_r': (2, 3),
 'prev_treatments': (2, 2, 3),
 'current_treatments': (2, 2, 3),
 'current_covariates': (2, 2, 3),
 'prev_outputs': (2, 2, 1),
 'static_features': (2, 2),
 'outputs': (2, 2, 3),
 'sequence_lengths': (2,),
 'active_entries': (2, 2, 1),
 'unscaled_outputs': (2, 2, 3),
 'patient_types': (2,),
 'patient_ids_all_trajectories': (2,),
 'patient_current_t': (2,)}

In [45]:
# Display the full toy `data` dict, including the keys added for process_sequential_test.
data

{'outputs': array([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12],
         [13, 14, 15]],
 
        [[16, 17, 18],
         [19, 20, 21],
         [22, 23, 24],
         [25, 26, 27],
         [28, 29, 30]]]),
 'prev_outputs': array([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11],
         [12, 13, 14]],
 
        [[15, 16, 17],
         [18, 19, 20],
         [21, 22, 23],
         [24, 25, 26],
         [27, 28, 29]]]),
 'sequence_lengths': array([5, 5]),
 'active_entries': array([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],
 
        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]]),
 'current_treatments': array([[[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2],
         [3, 3, 3],
         [4, 4, 4]],
 
        [[5, 5, 5],
         [6, 6, 6],
         [7, 7, 7],
         [8, 8, 8],
      

In [42]:
# Display the arrays returned by process_sequential_test for the toy data.
processed_data

{'active_encoder_r': array([[1., 1., 1.],
        [1., 1., 1.]]),
 'prev_treatments': array([[[1., 0., 1.],
         [0., 1., 0.]],
 
        [[0., 1., 0.],
         [1., 0., 1.]]]),
 'current_treatments': array([[[3., 3., 3.],
         [4., 4., 4.]],
 
        [[8., 8., 8.],
         [9., 9., 9.]]]),
 'current_covariates': array([[[2., 2., 2.],
         [2., 2., 2.]],
 
        [[7., 7., 7.],
         [7., 7., 7.]]]),
 'prev_outputs': array([[[2.],
         [2.]],
 
        [[7.],
         [7.]]]),
 'static_features': array([[2., 2.],
        [7., 7.]]),
 'outputs': array([[[10., 11., 12.],
         [13., 14., 15.]],
 
        [[25., 26., 27.],
         [28., 29., 30.]]]),
 'sequence_lengths': array([2., 2.]),
 'active_entries': array([[[1.],
         [1.]],
 
        [[1.],
         [1.]]]),
 'unscaled_outputs': array([[[ 6. ,  7.5,  9. ],
         [ 7.5,  9. , 10.5]],
 
        [[13.5, 15. , 16.5],
         [15. , 16.5, 18. ]]]),
 'patient_types': array([0, 1]),
 'patient_ids_all_tr