In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch.nn as nn
import torch.optim as optim


class PropensityNN(nn.Module):
    """Single-layer softmax classifier used as a propensity model.

    Applies one linear map from ``input_size`` features to ``output_size``
    class scores and normalises the scores with a softmax, so ``forward``
    returns probabilities that sum to one along the last axis.

    Args:
        input_size: Number of input features per sample.
        output_size: Number of classes (width of the returned distribution).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # A hidden-layer variant (128 -> 64 -> output_size) is intentionally
        # left out; the model is a single linear layer.
        self.simple_linear = nn.Linear(input_size, output_size)
        # Normalise over the last axis, which is the axis nn.Linear writes the
        # class scores to. For the usual (batch, features) input this equals
        # dim=1, but it also handles unbatched 1-D input and inputs with extra
        # leading axes.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return class probabilities for ``x``.

        Args:
            x: Float tensor whose last axis has ``input_size`` entries.

        Returns:
            Tensor shaped like ``x`` with the last axis replaced by
            ``output_size`` probabilities.
        """
        scores = self.simple_linear(x)
        return self.softmax(scores)


class MSMPropensityTreatment(CausalDiff):
    """Propensity network that predicts the current treatment from past treatments.

    The input is the masked sum of previous treatments (see ``get_inputs``)
    and the target is the treatment at the last active step of each sequence.
    """

    model_type = 'propensity_treatment'

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None, has_vitals: bool = None, **kwargs):
        """Build the network, optimizer and loss.

        Args:
            args: Experiment configuration; ``args.exp.max_epochs`` sets the
                number of training epochs used by ``fit``.
            dataset_collection: Dataset collection forwarded to the base class.
            autoregressive: Forwarded to the base class.
            has_vitals: Forwarded to the base class.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        super().__init__(args, dataset_collection, autoregressive, has_vitals)

        self.input_size = self.dim_treatments
        logger.info(f'Input size of {self.model_type}: {self.input_size}')
        self.output_size = self.dim_treatments

        # Kept on the instance so fit() does not depend on a global `args`
        # happening to exist in the notebook namespace.
        self.max_epochs = args.exp.max_epochs

        self.propensity_treatment = PropensityNN(
            self.input_size, self.output_size)
        self.optimizer = optim.Adam(
            self.propensity_treatment.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()
        self.save_hyperparameters(args)

    def get_inputs(self, dataset: Dataset) -> np.ndarray:
        """Return previous treatments summed over axis 1, masked by ``active_entries``.

        Axis 1 is presumably the time axis, which would give one feature row
        per sequence — confirm against the dataset layout.
        """
        active_entries = dataset.data['active_entries']
        prev_treatments = dataset.data['prev_treatments']
        return (prev_treatments * active_entries).sum(1)

    def fit(self):
        """Train the propensity network on the exploded factual training set.

        Prints the mean batch loss and the training accuracy after each epoch.
        """
        self.prepare_data()
        train_f = self.get_exploded_dataset(
            self.dataset_collection.train_f, min_length=self.lag_features)
        active_entries = train_f.data['active_entries']
        # active minus active-shifted-left-by-one: non-zero only where a step
        # is active and the next one is not. Presumably active_entries is a
        # 0/1 prefix mask of shape (sequences, time, 1), so this marks the
        # last active step — TODO confirm.
        next_step_active = np.concatenate(
            [active_entries[:, 1:, :],
             np.zeros((active_entries.shape[0], 1, 1))], axis=1)
        last_entries = active_entries - next_step_active

        # Features and targets as float tensors.
        features = torch.tensor(self.get_inputs(train_f), dtype=torch.float32)
        current_treatments = train_f.data['current_treatments']
        targets = torch.tensor(
            (current_treatments * last_entries).sum(1), dtype=torch.float32)

        dataloader = DataLoader(
            TensorDataset(features, targets), batch_size=32, shuffle=True)

        # Training loop
        for epoch in range(self.max_epochs):
            epoch_loss = 0.0
            for batch_inputs, batch_targets in dataloader:
                self.optimizer.zero_grad()
                probabilities = self.propensity_treatment(batch_inputs)
                # CrossEntropyLoss applies log-softmax to its input, and the
                # network already returns softmax probabilities. Passing their
                # log avoids a second softmax: log-softmax leaves normalised
                # log-probabilities unchanged. The clamp guards against log(0).
                log_probabilities = torch.log(
                    torch.clamp(probabilities, min=1e-12))
                loss = self.criterion(log_probabilities, batch_targets)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
            epoch_loss /= len(dataloader)
            print(
                f"Epoch {epoch+1}/{self.max_epochs}, Loss: {epoch_loss:.4f}")

            # Training accuracy; no gradients are needed for evaluation.
            correct = 0
            total = 0
            with torch.no_grad():
                for batch_inputs, batch_targets in dataloader:
                    probabilities = self.propensity_treatment(batch_inputs)
                    predicted = probabilities.argmax(dim=1)
                    total += batch_targets.size(0)
                    # Convert one-hot targets to indices:
                    target_labels = batch_targets.argmax(dim=1)
                    correct += (predicted == target_labels).sum().item()
            print(f"Accuracy: {100 * correct / total:.2f}%")


class MSMPropensityHistory(CausalDiff):
    """Propensity network that predicts the current treatment from patient history.

    The input concatenates masked previous treatments, lagged vitals (when
    ``has_vitals``), lagged previous outcomes (when ``autoregressive``) and
    static features; see ``get_inputs``.
    """

    model_type = 'propensity_history'

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None, has_vitals: bool = None, **kwargs):
        """Build the network, optimizer and loss.

        Args:
            args: Experiment configuration; ``args.exp.max_epochs`` sets the
                number of training epochs used by ``fit``.
            dataset_collection: Dataset collection forwarded to the base class.
            autoregressive: Forwarded to the base class.
            has_vitals: Forwarded to the base class.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        super().__init__(args, dataset_collection, autoregressive, has_vitals)

        # Mirror the column layout produced by get_inputs(): each lagged block
        # contributes (lag_features + 1) steps of its feature dimension.
        lagged_steps = self.lag_features + 1
        self.input_size = self.dim_treatments + self.dim_static_features
        if self.has_vitals:
            self.input_size += lagged_steps * self.dim_vitals
        if self.autoregressive:
            self.input_size += lagged_steps * self.dim_outcome

        logger.info(f'Input size of {self.model_type}: {self.input_size}')
        self.output_size = self.dim_treatments

        # Kept on the instance so fit() does not depend on a global `args`
        # happening to exist in the notebook namespace.
        self.max_epochs = args.exp.max_epochs

        self.propensity_history = PropensityNN(
            self.input_size, self.output_size)
        self.optimizer = optim.Adam(
            self.propensity_history.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()
        self.save_hyperparameters(args)

    def get_inputs(self, dataset: Dataset, projection_horizon=0) -> np.ndarray:
        """Assemble the 2-D feature matrix fed to the history propensity network.

        Columns, in order: masked sum of previous treatments, lagged vitals
        (if ``has_vitals``), lagged previous outputs (if ``autoregressive``),
        static features.

        Args:
            dataset: Object whose ``data`` dict holds ``active_entries``,
                ``prev_treatments``, ``static_features`` and, when used,
                ``vitals`` / ``prev_outputs``.
            projection_horizon: Number of steps by which the masks are shifted
                towards the start of the sequence before selecting entries.

        Returns:
            Array with one row per sequence.
        """
        active_entries = dataset.data['active_entries']
        # active minus active-shifted-left-by-(lag + 1). Presumably
        # active_entries is a 0/1 prefix mask of shape (sequences, time, 1),
        # so this marks the last (lag_features + 1) active steps — TODO confirm.
        lagged_entries = active_entries - \
            np.concatenate([active_entries[:, self.lag_features + 1:, :],
                            np.zeros((active_entries.shape[0], self.lag_features + 1, 1))], axis=1)
        if projection_horizon > 0:
            lagged_entries = np.concatenate([lagged_entries[:, projection_horizon:, :],
                                             np.zeros((active_entries.shape[0], projection_horizon, 1))], axis=1)

        active_entries_before_projection = np.concatenate([active_entries[:, projection_horizon:, :],
                                                          np.zeros((active_entries.shape[0], projection_horizon, 1))], axis=1)

        prev_treatments = dataset.data['prev_treatments']
        inputs = [(prev_treatments * active_entries_before_projection).sum(1)]
        if self.has_vitals:
            vitals = dataset.data['vitals']
            # Boolean-mask selection flattens the picked values; the reshape
            # only succeeds when every sequence contributes exactly
            # (lag_features + 1) selected steps.
            inputs.append(vitals[np.repeat(lagged_entries, self.dim_vitals, 2) == 1.0].reshape(vitals.shape[0],
                                                                                               (self.lag_features + 1) *
                                                                                               self.dim_vitals))
        if self.autoregressive:
            prev_outputs = dataset.data['prev_outputs']
            inputs.append(prev_outputs[np.repeat(lagged_entries, self.dim_outcome, 2) == 1.0].reshape(prev_outputs.shape[0],
                                                                                                      (self.lag_features + 1) *
                                                                                                      self.dim_outcome))
        static_features = dataset.data['static_features']
        inputs.append(static_features)
        return np.concatenate(inputs, axis=1)

    def fit(self):
        """Train the propensity network on the exploded factual training set.

        Prints the mean batch loss and the training accuracy after each epoch.
        """
        self.prepare_data()
        train_f = self.get_exploded_dataset(
            self.dataset_collection.train_f, min_length=self.lag_features)
        active_entries = train_f.data['active_entries']
        # Non-zero only where a step is active and the next one is not;
        # presumably the last active step of each sequence — TODO confirm.
        next_step_active = np.concatenate(
            [active_entries[:, 1:, :],
             np.zeros((active_entries.shape[0], 1, 1))], axis=1)
        last_entries = active_entries - next_step_active

        # Features and targets as float tensors.
        features = torch.tensor(self.get_inputs(train_f), dtype=torch.float32)
        current_treatments = train_f.data['current_treatments']
        targets = torch.tensor(
            (current_treatments * last_entries).sum(1), dtype=torch.float32)

        dataloader = DataLoader(
            TensorDataset(features, targets), batch_size=32, shuffle=True)

        # Training loop
        for epoch in range(self.max_epochs):
            epoch_loss = 0.0
            for batch_inputs, batch_targets in dataloader:
                self.optimizer.zero_grad()
                probabilities = self.propensity_history(batch_inputs)
                # CrossEntropyLoss applies log-softmax to its input, and the
                # network already returns softmax probabilities. Passing their
                # log avoids a second softmax: log-softmax leaves normalised
                # log-probabilities unchanged. The clamp guards against log(0).
                log_probabilities = torch.log(
                    torch.clamp(probabilities, min=1e-12))
                loss = self.criterion(log_probabilities, batch_targets)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()

            epoch_loss /= len(dataloader)
            print(
                f"Epoch {epoch + 1}/{self.max_epochs}, Loss: {epoch_loss:.4f}")

            # Training accuracy; no gradients are needed for evaluation.
            correct = 0
            total = 0
            with torch.no_grad():
                for batch_inputs, batch_targets in dataloader:
                    probabilities = self.propensity_history(batch_inputs)
                    predicted = probabilities.argmax(dim=1)
                    total += batch_targets.size(0)
                    # Convert one-hot targets to indices:
                    target_labels = batch_targets.argmax(dim=1)
                    correct += (predicted == target_labels).sum().item()

            print(f"Accuracy: {100 * correct / total:.2f}%")

In [None]:
# %matplotlib inline
# from IPython.display import display, clear_output
# import matplotlib.pyplot as plt
# import time
# import statistics
# from itertools import chain
# import numpy as np
# from collections import deque

# def train(model, data_loader, data_loader_validation, epochs, lr, loss_func, batch_embedder,
#           windowed_mode=False, window_mode="uniform", window_start_mode="random", min_window=50, max_window=100, neg_bin_p=0.95, train_on_all_every=4,
#           annealing_mode = False, annealing_window=5, annealing_multiplier=1.25, annealing_ratio = 0.5, annealing_minimum = 1e-6,
#           device="cuda", verbose=False, plot_every=10,
#           validation_frequency=1, validation_prp=10, moving_avg_window=10):

#     batch_embedder = batch_embedder.to(device)
#     model = model.to(device)

#     optimizer = torch.optim.Adam(
#         chain(batch_embedder.parameters(), model.parameters()),
#         lr=lr
#     )

#     model.train()
#     batch_embedder.train()
#     loss_list = []
#     initial_value = 1.0  # Initial value for equal probability
#     window_losses = torch.ones(max_window - min_window + 1, device=device) * initial_value  # Track losses for each window length
#     window_counts = torch.zeros(max_window - min_window + 1, device=device)  # Track counts for each window length
#     loss_deques = [deque(maxlen=moving_avg_window) for _ in range(max_window - min_window + 1)]  # Deques for moving average
#     if windowed_mode and window_mode == "biased_loss":
#         fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 12))
#     else:
#         fig, ax1 = plt.subplots(1, 1, figsize=(10, 6))
#     epoch_loss_list = []
#     val_loss = 0

#     for epoch in range(epochs):
#         # Annealing for the learning rate
#         if annealing_mode and epoch > annealing_window:
#             if len(epoch_loss_list) > 0 and epoch_loss_list[-1] >= annealing_multiplier * (statistics.mean(epoch_loss_list[-annealing_window:])):
#                 for g in optimizer.param_groups:
#                     if g['lr'] * annealing_ratio < annealing_minimum:
#                         g['lr'] = annealing_minimum
#                     else:
#                         g['lr'] *= annealing_ratio

#         start = time.time()
#         for i, batch in enumerate(data_loader):

#             batch = batch.to(device)
#             batch = batch_embedder(batch)

#             batch_length = batch.shape[1]

#             # Windowed mode logic
#             if windowed_mode:
#                 if batch_length < min_window:
#                     continue
#                 if window_start_mode == "random":
#                     cut_start = torch.randint(0, batch_length - window_length + 1, (1,)).item()
#                 elif window_start_mode == "fixed":
#                     cut_start = 0
#                 if window_mode == "uniform":
#                     while True:
#                         window_length = torch.randint(min_window, batch_length + 1, (1,)).item()
#                         cut_end = cut_start + window_length
#                         if min_window <= (cut_end - cut_start) <= batch_length:
#                             break
#                     batch = batch[:, cut_start:cut_end, :]

#                 elif window_mode == "negative_binomial":
#                     total_count = 1
#                     probs = neg_bin_p
#                     distribution = torch.distributions.NegativeBinomial(total_count=total_count, probs=probs)
#                     while True:
#                         window_length = distribution.sample().item() + min_window
#                         cut_end = cut_start + window_length
#                         if min_window <= window_length <= batch_length:
#                             break
#                     batch = batch[:, cut_start:cut_end, :]

#                 elif window_mode == "biased_loss":
#                     if torch.min(window_counts) < 2:
#                         # Use uniform distribution until each length has been used at least twice
#                         window_probs = torch.ones_like(window_losses) / len(window_losses)
#                     elif torch.sum(window_counts) % train_on_all_every == 0:
#                         window_probs = torch.ones_like(window_losses) / len(window_losses)
#                     else:
#                         # Update probabilities based on moving average of losses
#                         avg_losses = torch.tensor([np.mean(loss_deque) if len(loss_deque) > 0 else initial_value for loss_deque in loss_deques], device=device)
#                         window_probs = avg_losses / avg_losses.sum()
#                     while True:
#                         window_length = torch.multinomial(window_probs, 1).item() + min_window
#                         #check if the window length does work with the batch length
#                         if window_length > batch_length:
#                             continue
#                         cut_end = cut_start + window_length
#                         if min_window <= window_length <= batch_length:
#                             break
#                     batch = batch[:, cut_start:cut_end, :]
#                     window_counts[window_length - min_window] += 1  # Update window counts

#             optimizer.zero_grad()
#             predicted_noise, noise, noise_mask = model(batch)
#             loss = loss_func(predicted_noise, noise, noise_mask)
#             loss.backward()
#             # # Gradient clipping
#             max_grad_norm = 1.0
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
#             torch.nn.utils.clip_grad_norm_(batch_embedder.parameters(), max_grad_norm)
#             optimizer.step()
#             loss_list.append(loss.item())

#             epoch_loss = sum(loss_list[-len(data_loader):]) / len(data_loader)
#             epoch_loss_list.append(epoch_loss)

#             # Update window losses and moving average deque
#             if windowed_mode and window_mode == "biased_loss":
#                 window_idx = window_length - min_window
#                 window_losses[window_idx] += loss.item()
#                 loss_deques[window_idx].append(loss.item())

#             # Dynamic plot update
#             if i % plot_every == 0:
#                 ax1.clear()
#                 ax1.set_ylim(0, 1)
#                 ax1.plot(loss_list)
#                 if len(loss_list) > 100:
#                     ax1.plot(np.convolve(loss_list, np.ones((100,))/100, mode='valid'))
#                     ax1.text(len(loss_list) - 1, np.convolve(loss_list, np.ones((100,))/100, mode='valid')[-1],
#                             str(round(np.convolve(loss_list, np.ones((100,))/100, mode='valid')[-1], 3)))
#                 if len(epoch_loss_list) > 0:
#                     ax1.text(0.1, 0.9, f"Epoch: {epoch} | Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
#                 # ax1.text(0.1, 0.8, f"Learning rate: {optimizer.param_groups[0]['lr']:.4e}")
#                 ax1.text(0.1, 0.8, f"Loss: {epoch_loss_list[-1]:.3e} | Validation loss: {val_loss:.3e}")
#                 ax1.text(0.1, 0.7, f"Time per step: {((time.time() - start) / (i + 1)):.2f} s | Time per epoch: {((time.time() - start) / (i + 1) * len(data_loader)):.2f} s")
#                 ax1.text(0.1, 0.6, f"Time till finish (est.): {((time.time() - start) / (i + 1) * len(data_loader) * (epochs - epoch)) / 60:.2f} min")
#                 if windowed_mode and window_mode == "biased_loss":
#                     ax2.clear()
#                     ax2.bar(range(min_window, max_window + 1), window_counts.cpu().numpy())
#                     ax2.set_ylabel("Counts")
#                     ax2.set_title("Counts of Each Window Length Used")

#                     moving_avg_losses = torch.tensor([np.mean(loss_deque) if len(loss_deque) > 0 else initial_value for loss_deque in loss_deques], device=device).cpu().numpy()
#                     ax3.clear()
#                     ax3.bar(range(min_window, max_window + 1), moving_avg_losses)
#                     ax3.set_xlabel("Window Length")
#                     ax3.set_ylabel("Moving Average Loss")
#                     ax3.set_title("Moving Average Loss for Each Window Length")

#                 display(fig)
#                 clear_output(wait=True)

#         end = time.time()

#         # Validation
#         if epoch % validation_frequency == 0:
#             loss_list_validation = []
#             for i, batch in enumerate(data_loader_validation):
#                 batch = batch.to(device)
#                 batch = batch_embedder(batch)
#                 if i % validation_prp == 0:
#                     predicted_noise, noise, noise_mask = model(batch)
#                     loss = loss_func(predicted_noise, noise, noise_mask)
#                     loss_list_validation.append(loss.item())

#             val_loss = np.mean(loss_list_validation)

#         if verbose:
#             print(f"Epoch {epoch} completed in {end - start} seconds, Loss: {epoch_loss}")
#             print(f"Validation Loss: {val_loss}")


#     return model, loss_list

In [None]:
# import numpy as np
# import random
# from torch.utils.data import DataLoader
# from torch.utils.data import Sampler
# from collections import defaultdict
# import torch
# from torch.utils.data import Dataset


# class CustomTimeSeriesDataset(Dataset):
#     def __init__(
#         self, data_tensor, sequence_lengths, min_seq_length=None, max_seq_length=None
#     ):
#         # Store the initial sequences and lengths
#         self.data_tensor = data_tensor
#         self.sequence_lengths = sequence_lengths
#         self.min_seq_length = min_seq_length
#         self.max_seq_length = max_seq_length

#         # Filter sequences based on min and max sequence length
#         if self.min_seq_length is not None:
#             valid_indices = [
#                 i
#                 for i, length in enumerate(self.sequence_lengths)
#                 if length >= self.min_seq_length
#             ]
#         else:
#             valid_indices = list(range(len(self.sequence_lengths)))

#         if self.max_seq_length is not None:
#             valid_indices = [
#                 i
#                 for i in valid_indices
#                 if self.sequence_lengths[i] <= self.max_seq_length
#             ]

#         self.data_tensor = self.data_tensor[valid_indices]
#         self.sequence_lengths = [self.sequence_lengths[i]
#                                  for i in valid_indices]

#     def __len__(self):
#         return len(self.sequence_lengths)

#     def __getitem__(self, idx):
#         seq_length = int(self.sequence_lengths[idx])
#         return self.data_tensor[idx, :seq_length, :], seq_length


# # class CustomTimeSeriesDataset(Dataset):
# #     def __init__(self, data_tensor, sequence_lengths, min_seq_length=None, max_seq_length=None):
# #         self.data_tensor = data_tensor
# #         self.sequence_lengths = sequence_lengths

# #     def __len__(self):
# #         return len(self.sequence_lengths)

# #     def __getitem__(self, idx):
# #         seq_length = int(self.sequence_lengths[idx])
# #         return self.data_tensor[idx, :seq_length, :], seq_length


# class LengthBatchSampler(Sampler):
#     def __init__(self, dataset, batch_size):
#         self.sequence_lengths = dataset.sequence_lengths
#         self.batch_size = batch_size
#         self.batches = self._create_batches()

#     def _create_batches(self):
#         length_to_indices = defaultdict(list)
#         for idx, length in enumerate(self.sequence_lengths):
#             length_to_indices[length].append(idx)

#         batches = []
#         for length, indices in length_to_indices.items():
#             # Split indices into batches of the specified batch size
#             for i in range(0, len(indices), self.batch_size):
#                 batches.append(indices[i: i + self.batch_size])
#         return batches

#     def __iter__(self):
#         for batch in self.batches:
#             yield batch

#     def __len__(self):
#         return len(self.batches)


# def collate_fn(batch):
#     data, lengths = zip(*batch)
#     data = torch.stack(data)
#     return data


# def create_dataloader(
#     data_tensor, sequence_lengths, batch_size, min_seq_length=None, max_seq_length=None
# ):
#     dataset = CustomTimeSeriesDataset(
#         data_tensor, sequence_lengths, min_seq_length, max_seq_length
#     )
#     sampler = LengthBatchSampler(dataset, batch_size)
#     dataloader = DataLoader(
#         dataset, batch_sampler=sampler, collate_fn=collate_fn)
#     return dataloader


# class BalancedLengthBatchSampler(Sampler):
#     def __init__(self, dataset, batch_size, balance_factor=1.0):
#         self.sequence_lengths = dataset.sequence_lengths
#         self.batch_size = batch_size
#         self.balance_factor = balance_factor
#         self.batches = self._create_balanced_batches()

#     def _create_balanced_batches(self):
#         # Group indices by sequence length
#         length_to_indices = defaultdict(list)
#         for idx, length in enumerate(self.sequence_lengths):
#             length_to_indices[length].append(idx)

#         # Calculate the maximum count of indices for balancing
#         max_count = max(len(indices) for indices in length_to_indices.values())

#         # Balance the distribution of sequence lengths by oversampling shorter sequences
#         balanced_batches = []
#         for length, indices in length_to_indices.items():
#             count = len(indices)
#             if count < max_count:
#                 repeat_factor = int(self.balance_factor * (max_count / count))
#                 # Correctly oversample the list elements
#                 oversampled_indices = indices * repeat_factor
#                 # oversampled_indices = oversampled_indices[:max_count]
#             else:
#                 oversampled_indices = indices

#             # Shuffle the indices of this particular length
#             random.shuffle(oversampled_indices)

#             # Create batches for this length
#             for i in range(0, len(oversampled_indices), self.batch_size):
#                 batch = oversampled_indices[i: i + self.batch_size]
#                 if len(batch) == self.batch_size:
#                     balanced_batches.append(batch)

#         # Shuffle the list of balanced batches to ensure random order
#         random.shuffle(balanced_batches)

#         return balanced_batches

#     def __iter__(self):
#         for batch in self.batches:
#             yield batch

#     def __len__(self):
#         return len(self.batches)


# def create_balanced_dataloader(
#     data_tensor,
#     sequence_lengths,
#     batch_size,
#     balance_factor=1.0,
#     min_seq_length=None,
#     max_seq_length=None,
# ):
#     dataset = CustomTimeSeriesDataset(
#         data_tensor, sequence_lengths, min_seq_length, max_seq_length
#     )
#     sampler = BalancedLengthBatchSampler(dataset, batch_size, balance_factor)
#     dataloader = DataLoader(
#         dataset, batch_sampler=sampler, collate_fn=collate_fn)
#     return dataloader

In [None]:
# train_loader = create_balanced_dataloader(
#     training_data_tensor,
#     training_data_sequence_lengths,
#     batch_size=100,
#     balance_factor=1.0,
#     min_seq_length=2,
#     max_seq_length=None,
# )

# val_loader = create_balanced_dataloader(
#     validation_data_tensor,
#     validation_data_sequence_lengths,
#     batch_size=100,
#     balance_factor=1.0,
#     min_seq_length=2,
#     max_seq_length=None,
# )

# # check the size of the train, val, and test sets
# print(len(train_loader))
# print(len(val_loader))

In [None]:
# training_data_tensor.shape

In [None]:
# categorical_indices_sizes = {
#     # 'time_step' : [0, 1],
#     "chemo_application_prev": [0, 1],
#     "radio_application_prev": [1, 1],
#     "patient_type_tile": [2, 1],
# }

# numerical_indices = {"cancer_volume": 3}

# numerical_indices = {
#     'chemo_application_prev': 0,
#     'radio_application_prev': 1,
#     'patient_type_tile': 2,
#     'cancer_volume': 3
# }


# training_data_tensor_embedded = data_embedder(training_data_tensor)
# validation_data_tensor_embedded = data_embedder(validation_data_tensor)
# test_data_factuals_tensor_embedded = data_embedder(test_data_factuals_tensor)
# test_data_counterfactuals_tensor_embedded = data_embedder(test_data_counterfactuals_tensor)
# test_data_seq_tensor_embedded = data_embedder(test_data_seq_tensor)

In [None]:
# import gc

# torch.cuda.empty_cache()
# gc.collect()

In [None]:
# diffusion_imputer = diffusion_imputation(
#     emb_dim=128,
#     # strategy="forecasting_last_n_time",
#     # strategy="random",
#     # missing_prp=0.5,
#     # strategy='selected_features',
#     strategy="selected_features_last_n_time",
#     last_n_time=1,
#     features_to_impute=[0, 1, 2, 3],
#     # excluded_features = [i for i in range(6)], #[2],#[0,1,2,3,5], #for the embedded stock names which we don't need to predict
#     # strategy='selected_features_and_selected_features_after_time',
#     # features_to_impute_completely=[2],
#     # features_to_impute_after_time=[3],
#     num_residual_layers=2,
#     diffusion_steps=50,
#     diffusion_beta_schedule="quadratic",
#     num_heads=8,
#     kernel_size=(1, 1),
#     ff_dim=512,
#     num_cells=1,
#     dropout=0,
#     # csdi, csdi_moded_transformer, rsa, rsa_moded_transformer, moded_transformer_alone, rsa_csdi
#     method="rsa_moded_transformer",
#     device="cuda",
# )

# # data_embedder = DataEmbedder(
# #     categorical_indices_sizes, numerical_indices, training_data_tensor
# # )