In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch.nn as nn
import torch.optim as optim


class PropensityNN(nn.Module):
    """One-layer softmax classifier used as a propensity model.

    ``forward`` returns *normalized probabilities* (softmax over dimension 1),
    not raw logits, so a loss that expects logits (e.g. ``nn.CrossEntropyLoss``)
    must not be fed this output directly.

    Args:
        input_size: number of input features per sample.
        output_size: number of output classes.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Attribute names are kept stable on purpose: they form the state_dict keys.
        self.simple_linear = nn.Linear(input_size, output_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities for the feature batch ``x``."""
        scores = self.simple_linear(x)
        return self.softmax(scores)


class MSMPropensityTreatment(CausalDiff):
    """Treatment-only propensity model for marginal structural models.

    Predicts the treatment at the last active time step of each exploded
    training sequence from the per-sequence sum of previous treatments, using a
    single-layer softmax network (``PropensityNN``).
    """

    model_type = 'propensity_treatment'

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None, has_vitals: bool = None, **kwargs):
        super().__init__(args, dataset_collection, autoregressive, has_vitals)

        self.input_size = self.dim_treatments
        logger.info(f'Input size of {self.model_type}: {self.input_size}')
        self.output_size = self.dim_treatments
        # Capture the epoch budget while ``args`` is in scope; ``fit`` takes no
        # ``args`` parameter and previously relied on an unrelated global.
        self._max_epochs = args.exp.max_epochs

        self.propensity_treatment = PropensityNN(
            self.input_size, self.output_size)
        self.optimizer = optim.Adam(
            self.propensity_treatment.parameters(), lr=0.001)
        # CrossEntropyLoss expects unnormalized scores; see _train_propensity_network
        # for how the softmax output of PropensityNN is adapted to it.
        self.criterion = nn.CrossEntropyLoss()
        self.save_hyperparameters(args)

    def get_inputs(self, dataset: Dataset) -> np.ndarray:
        """Return the sum over time of previous treatments, masked by active entries.

        Args:
            dataset: object exposing ``data['active_entries']`` and
                ``data['prev_treatments']`` arrays.

        Returns:
            ``(prev_treatments * active_entries).sum(1)`` — one row per sequence.
        """
        active_entries = dataset.data['active_entries']
        prev_treatments = dataset.data['prev_treatments']
        inputs = (prev_treatments * active_entries).sum(1)
        return inputs

    def fit(self):
        """Train the propensity network on the exploded ``train_f`` split.

        Targets are the treatments at the last active time step of each exploded
        sequence. Prints the mean training loss and training-set accuracy per epoch.

        Raises:
            ValueError: if the exploded training set contains no samples.
        """
        self.prepare_data()
        train_f = self.get_exploded_dataset(
            self.dataset_collection.train_f, min_length=self.lag_features)
        active_entries = train_f.data['active_entries']
        # active_entries minus itself shifted left by one step: 1.0 only at the last
        # active step, assuming active entries form a contiguous prefix (TODO confirm).
        last_entries = active_entries - \
            np.concatenate([active_entries[:, 1:, :], np.zeros(
                (active_entries.shape[0], 1, 1))], axis=1)

        inputs = self.get_inputs(train_f)
        current_treatments = train_f.data['current_treatments']
        targets = (current_treatments * last_entries).sum(1)

        self._train_propensity_network(
            self.propensity_treatment, inputs, targets)

    def _train_propensity_network(self, network, inputs, targets):
        """Mini-batch training loop for ``network``; prints loss and accuracy per epoch."""
        dataset = TensorDataset(torch.tensor(inputs, dtype=torch.float32),
                                torch.tensor(targets, dtype=torch.float32))
        if len(dataset) == 0:
            raise ValueError(f'No training samples for {self.model_type}')
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

        for epoch in range(self._max_epochs):
            epoch_loss = 0.0
            for batch_inputs, batch_targets in dataloader:
                self.optimizer.zero_grad()
                probabilities = network(batch_inputs)
                # PropensityNN already applies softmax, and CrossEntropyLoss applies
                # log_softmax internally. Feeding log-probabilities makes that inner
                # log_softmax an identity, so this is the true cross-entropy rather
                # than one computed on doubly-softmaxed scores. The clamp avoids
                # log(0) = -inf when a probability underflows.
                log_probabilities = torch.log(probabilities.clamp_min(1e-12))
                loss = self.criterion(log_probabilities, batch_targets)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
            epoch_loss /= len(dataloader)
            print(
                f"Epoch {epoch+1}/{self._max_epochs}, Loss: {epoch_loss:.4f}")

            # Training-set accuracy; no gradients needed.
            correct = 0
            with torch.no_grad():
                for batch_inputs, batch_targets in dataloader:
                    predicted = network(batch_inputs).argmax(dim=1)
                    # Targets are per-class rows; compare against their argmax class.
                    target_labels = batch_targets.argmax(dim=1)
                    correct += (predicted == target_labels).sum().item()
            print(f"Accuracy: {100 * correct / len(dataset):.2f}%")


class MSMPropensityHistory(CausalDiff):
    """History-conditioned propensity model for marginal structural models.

    Predicts the treatment at the last active time step of each exploded training
    sequence from summed previous treatments, lagged vitals (if ``has_vitals``),
    lagged previous outputs (if ``autoregressive``) and static features, using a
    single-layer softmax network (``PropensityNN``).
    """

    model_type = 'propensity_history'

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None, has_vitals: bool = None, **kwargs):
        super().__init__(args, dataset_collection, autoregressive, has_vitals)

        # The width must match the columns produced by get_inputs():
        #   summed previous treatments              -> dim_treatments
        #   vitals at the last lag_features+1 steps -> (lag_features + 1) * dim_vitals
        #   prev outputs at the same steps          -> (lag_features + 1) * dim_outcome
        #   static features                         -> dim_static_features
        # (The previous formula added lag_features + dim_vitals + dim_outcome, which
        # disagrees with get_inputs for most settings and breaks nn.Linear.)
        n_lagged_steps = self.lag_features + 1
        self.input_size = self.dim_treatments + self.dim_static_features
        if self.has_vitals:
            self.input_size += n_lagged_steps * self.dim_vitals
        if self.autoregressive:
            self.input_size += n_lagged_steps * self.dim_outcome

        logger.info(f'Input size of {self.model_type}: {self.input_size}')
        self.output_size = self.dim_treatments
        # Capture the epoch budget while ``args`` is in scope; ``fit`` takes no
        # ``args`` parameter and previously relied on an unrelated global.
        self._max_epochs = args.exp.max_epochs

        self.propensity_history = PropensityNN(
            self.input_size, self.output_size)
        self.optimizer = optim.Adam(
            self.propensity_history.parameters(), lr=0.001)
        # CrossEntropyLoss expects unnormalized scores; see _train_propensity_network
        # for how the softmax output of PropensityNN is adapted to it.
        self.criterion = nn.CrossEntropyLoss()
        self.save_hyperparameters(args)

    def get_inputs(self, dataset: Dataset, projection_horizon=0) -> np.ndarray:
        """Build the flat per-sequence feature matrix for the propensity network.

        Args:
            dataset: object exposing ``data[...]`` arrays ``active_entries``,
                ``prev_treatments``, ``static_features`` and, when enabled,
                ``vitals`` / ``prev_outputs``.
            projection_horizon: number of trailing active steps to exclude; the
                lag window and treatment sum are shifted back by this many steps.

        Returns:
            Columns, in order: summed previous treatments, flattened lagged vitals
            (if ``has_vitals``), flattened lagged previous outputs (if
            ``autoregressive``), static features.
        """
        active_entries = dataset.data['active_entries']
        n_lagged_steps = self.lag_features + 1
        # active_entries minus itself shifted left by lag_features+1 steps: marks the
        # last lag_features+1 active steps, assuming active entries form a contiguous
        # prefix at least that long (TODO confirm against get_exploded_dataset).
        lagged_entries = active_entries - \
            np.concatenate([active_entries[:, n_lagged_steps:, :],
                            np.zeros((active_entries.shape[0], n_lagged_steps, 1))], axis=1)
        if projection_horizon > 0:
            # Move the lag window back by projection_horizon steps.
            lagged_entries = np.concatenate([lagged_entries[:, projection_horizon:, :],
                                             np.zeros((active_entries.shape[0], projection_horizon, 1))], axis=1)

        # Shifting left by projection_horizon drops the last projection_horizon active
        # steps from the treatment sum (identity when projection_horizon == 0).
        active_entries_before_projection = np.concatenate([active_entries[:, projection_horizon:, :],
                                                          np.zeros((active_entries.shape[0], projection_horizon, 1))], axis=1)

        prev_treatments = dataset.data['prev_treatments']
        inputs = [(prev_treatments * active_entries_before_projection).sum(1)]
        if self.has_vitals:
            vitals = dataset.data['vitals']
            # Boolean-mask the lagged steps, then flatten them per sequence; the reshape
            # requires exactly n_lagged_steps selected steps in every row.
            inputs.append(vitals[np.repeat(lagged_entries, self.dim_vitals, 2) == 1.0].reshape(
                vitals.shape[0], n_lagged_steps * self.dim_vitals))
        if self.autoregressive:
            prev_outputs = dataset.data['prev_outputs']
            inputs.append(prev_outputs[np.repeat(lagged_entries, self.dim_outcome, 2) == 1.0].reshape(
                prev_outputs.shape[0], n_lagged_steps * self.dim_outcome))
        static_features = dataset.data['static_features']
        inputs.append(static_features)
        return np.concatenate(inputs, axis=1)

    def fit(self):
        """Train the propensity network on the exploded ``train_f`` split.

        Targets are the treatments at the last active time step of each exploded
        sequence. Prints the mean training loss and training-set accuracy per epoch.

        Raises:
            ValueError: if the exploded training set contains no samples.
        """
        self.prepare_data()
        train_f = self.get_exploded_dataset(
            self.dataset_collection.train_f, min_length=self.lag_features)
        active_entries = train_f.data['active_entries']
        # active_entries minus itself shifted left by one step: 1.0 only at the last
        # active step, assuming active entries form a contiguous prefix (TODO confirm).
        last_entries = active_entries - \
            np.concatenate([active_entries[:, 1:, :], np.zeros(
                (active_entries.shape[0], 1, 1))], axis=1)

        inputs = self.get_inputs(train_f)
        current_treatments = train_f.data['current_treatments']
        targets = (current_treatments * last_entries).sum(1)

        self._train_propensity_network(
            self.propensity_history, inputs, targets)

    def _train_propensity_network(self, network, inputs, targets):
        """Mini-batch training loop for ``network``; prints loss and accuracy per epoch."""
        dataset = TensorDataset(torch.tensor(inputs, dtype=torch.float32),
                                torch.tensor(targets, dtype=torch.float32))
        if len(dataset) == 0:
            raise ValueError(f'No training samples for {self.model_type}')
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

        for epoch in range(self._max_epochs):
            epoch_loss = 0.0
            for batch_inputs, batch_targets in dataloader:
                self.optimizer.zero_grad()
                probabilities = network(batch_inputs)
                # PropensityNN already applies softmax, and CrossEntropyLoss applies
                # log_softmax internally. Feeding log-probabilities makes that inner
                # log_softmax an identity, so this is the true cross-entropy rather
                # than one computed on doubly-softmaxed scores. The clamp avoids
                # log(0) = -inf when a probability underflows.
                log_probabilities = torch.log(probabilities.clamp_min(1e-12))
                loss = self.criterion(log_probabilities, batch_targets)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()

            epoch_loss /= len(dataloader)
            print(
                f"Epoch {epoch + 1}/{self._max_epochs}, Loss: {epoch_loss:.4f}")

            # Training-set accuracy; no_grad avoids building autograd graphs here.
            correct = 0
            with torch.no_grad():
                for batch_inputs, batch_targets in dataloader:
                    predicted = network(batch_inputs).argmax(dim=1)
                    # Targets are per-class rows; compare against their argmax class.
                    target_labels = batch_targets.argmax(dim=1)
                    correct += (predicted == target_labels).sum().item()

            print(f"Accuracy: {100 * correct / len(dataset):.2f}%")

In [None]:
# import torch
# %matplotlib inline
# from IPython.display import display, clear_output
# import matplotlib.pyplot as plt
# import time
# import statistics
# from itertools import chain
# import numpy as np
# from collections import deque
# # torch.backends.cudnn.benchmark = True

# def train(model, data_loader, data_loader_validation, epochs, lr, loss_func, weighted_loss_func = None,
#           batch_embedder = None, gradient_clip = 1.0, amp_scale = True,
#           annealing_mode = False, annealing_window=5, annealing_multiplier=1.25, annealing_ratio = 0.5, annealing_minimum = 1e-6,
#           explode = False,
#           projection_horizon=1,
#           device="cuda", num_gpus=1, verbose=False,
#           plot_every=10, plot_display_window=100,
#           validation_frequency=1, validation_prp=10, moving_avg_window=10):

#     # Check for GPU availability
#     available_gpus = torch.cuda.device_count()
#     if available_gpus < num_gpus:
#         print(f"Requested {num_gpus} GPUs, but only {available_gpus} are available.")
#         num_gpus = available_gpus
#     else:
#         print(f"Using {num_gpus} GPUs for training.")
#         #also print gpu model
#         print(f"GPU model: {torch.cuda.get_device_name(0)}")

#     if num_gpus > 1:
#         model = torch.nn.DataParallel(model, device_ids=list(range(num_gpus)))
#         if batch_embedder is not None:
#             batch_embedder = torch.nn.DataParallel(batch_embedder, device_ids=list(range(num_gpus)))

#     if batch_embedder is not None:
#         batch_embedder = batch_embedder.to(device)
#     model = model.to(device)

#     if batch_embedder is not None:
#         optimizer = torch.optim.Adam(
#             chain(batch_embedder.parameters(), model.parameters()),
#             lr=lr
#         )
#     else:
#         optimizer = torch.optim.Adam(
#             model.parameters(),
#             lr=lr
#         )

#     model.train()

#     # if batch_embedder is not None:
#     #     batch_embedder.train()

#     loss_list = []
#     fig, ax1 = plt.subplots(1, 1, figsize=(10, 6))
#     epoch_loss_list = []
#     val_loss = 0

#     total_time_start = time.time()

#     if amp_scale:
#         scaler = torch.cuda.amp.GradScaler()

#     for epoch in range(epochs):
#         # Annealing for the learning rate
#         if annealing_mode and epoch > annealing_window:
#             if len(epoch_loss_list) > 0 and epoch_loss_list[-1] >= annealing_multiplier * (statistics.mean(epoch_loss_list[-annealing_window:])):
#                 for g in optimizer.param_groups:
#                     if g['lr'] * annealing_ratio < annealing_minimum:
#                         g['lr'] = annealing_minimum
#                     else:
#                         g['lr'] *= annealing_ratio
#         start = time.time()
#         for i, batch in enumerate(data_loader):
#             if explode:
#                 batch = model.explode_trajectories(batch, projection_horizon=projection_horizon)
#                 # convert to pytorch tensor
#                 batch = {key: torch.tensor(value).to(device) for key, value in batch.items()}
#             sequence_lengths = batch['sequence_lengths']
#             model.sequence_length = sequence_lengths
#             if 'stabilized_weights' in batch:
#                 stabilized_weights = batch['stabilized_weights']
#             curr_treatments = batch['current_treatments']
#             vitals_or_prev_outputs = []
#             # vitals_or_prev_outputs.append(batch['vitals']) if self.has_vitals else None
#             # if self.autoregressive else None
#             vitals_or_prev_outputs.append(batch['prev_outputs'])
#             vitals_or_prev_outputs = torch.cat(vitals_or_prev_outputs, dim=-1)
#             static_features = batch['static_features']
#             outputs = batch['outputs']

#             batch = torch.cat((vitals_or_prev_outputs, curr_treatments), dim=-1)
#             if explode:
#                 batch = torch.cat((batch, static_features), dim=-1)
#             else:
#                 batch = torch.cat((batch, static_features.unsqueeze(
#                     1).expand(-1, batch.size(1), -1)), dim=-1)
#             batch = torch.cat((batch, outputs), dim=-1)
#             if batch.shape[0] == 0:
#                 continue
#             batch = batch.to(device)

#             # if batch_embedder is not None:
#             #     batch = batch_embedder(batch)

#             optimizer.zero_grad()

#             if amp_scale:
#                 with torch.cuda.amp.autocast():
#                     predicted_noise, noise, noise_mask = model(batch)
#                     if weighted_loss_func is not None:
#                         loss = weighted_loss_func(predicted_noise, noise, noise_mask, stabilized_weights)
#                     else:
#                         loss = loss_func(predicted_noise, noise, noise_mask)
#                 scaler.scale(loss).backward()
#                 # Gradient clipping
#                 if gradient_clip is not None:
#                     scaler.unscale_(optimizer)  # Unscales the gradients of the optimizer's assigned parameters
#                     torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
#                 scaler.step(optimizer)
#                 scaler.update()
#                 loss_list.append(loss.item())
#             else:
#                 predicted_noise, noise, noise_mask = model(batch)
#                 if weighted_loss_func is not None:
#                     loss = weighted_loss_func(predicted_noise, noise, noise_mask, stabilized_weights)
#                 else:
#                     loss = loss_func(predicted_noise, noise, noise_mask)
#                 loss.backward()
#                 # Gradient clipping
#                 if gradient_clip is not None:
#                     max_grad_norm = gradient_clip
#                     torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
#                     if batch_embedder is not None:
#                         torch.nn.utils.clip_grad_norm_(batch_embedder.parameters(), max_grad_norm)
#                 optimizer.step()
#                 loss_list.append(loss.item())

#             epoch_loss = sum(loss_list[-len(data_loader):]) / len(data_loader)
#             epoch_loss_list.append(epoch_loss)

#             # Dynamic plot update focusing on recent losses
#             if i % plot_every == 0:
#                 ax1.clear()
#                 # Only consider the recent losses for plotting (e.g., last 100)
#                 display_window = plot_display_window
#                 recent_losses = loss_list[-display_window:] if len(loss_list) > display_window else loss_list
#                 # Dynamically set y-limits based on recent loss values
#                 if recent_losses:
#                     min_loss = min(recent_losses)
#                     max_loss = max(recent_losses)
#                     margin = 0.1 * (max_loss - min_loss) if max_loss != min_loss else 0.1 * max_loss
#                     ax1.set_ylim(min_loss - margin, max_loss + margin)
#                 ax1.plot(recent_losses)
#                 if len(recent_losses) > moving_avg_window:
#                     moving_avg = np.convolve(recent_losses, np.ones((moving_avg_window,)) / moving_avg_window, mode='valid')
#                     x = np.arange(moving_avg_window - 1, len(recent_losses))
#                     ax1.plot(x, moving_avg)
#                     ax1.text(len(recent_losses) - 1, moving_avg[-1],
#                              f"{moving_avg[-1]:.3e}")
#                 if epoch_loss_list:
#                     ax1.text(0.1, 0.9, f"Epoch: {epoch} | Learning rate: {optimizer.param_groups[0]['lr']:.2e}", transform=ax1.transAxes)
#                 ax1.text(0.1, 0.8, f"Loss: {epoch_loss_list[-1]:.3e} | Validation loss: {val_loss:.3e}", transform=ax1.transAxes)
#                 ax1.text(0.1, 0.7, f"Time per step: {((time.time() - start) / (i + 1)):.2f} s | Time per epoch: {((time.time() - start) / (i + 1) * len(data_loader)):.2f} s", transform=ax1.transAxes)
#                 ax1.text(0.1, 0.6, f"Time till finish (est.): {((time.time() - start) / (i + 1) * len(data_loader) * (epochs - epoch)) / 60:.2f} min", transform=ax1.transAxes)
#                 display(fig)
#                 clear_output(wait=True)

#         end = time.time()

#         # Validation
#         if epoch % validation_frequency == 0:
#             loss_list_validation = []
#             for i, batch in enumerate(data_loader_validation):
#                 curr_treatments = batch['current_treatments']
#                 vitals_or_prev_outputs = []
#                 # vitals_or_prev_outputs.append(batch['vitals']) if self.has_vitals else None
#                 # if self.autoregressive else None
#                 vitals_or_prev_outputs.append(batch['prev_outputs'])
#                 vitals_or_prev_outputs = torch.cat(vitals_or_prev_outputs, dim=-1)
#                 static_features = batch['static_features']
#                 outputs = batch['outputs']

#                 batch = torch.cat((vitals_or_prev_outputs, curr_treatments), dim=-1)
#                 batch = torch.cat((batch, static_features.unsqueeze(
#                     1).expand(-1, batch.size(1), -1)), dim=-1)
#                 batch = torch.cat((batch, outputs), dim=-1)

#                 batch = batch.to(device)
#                 # if batch_embedder is not None:
#                 #     batch = batch_embedder(batch)
#                 if i % validation_prp == 0:
#                     predicted_noise, noise, noise_mask = model(batch)
#                     loss = loss_func(predicted_noise, noise, noise_mask)
#                     loss_list_validation.append(loss.item())

#             val_loss = np.mean(loss_list_validation)

#         total_time_end = time.time()
#         total_time = total_time_end - total_time_start

#         if verbose:
#             print(f"Epoch {epoch} completed in {end - start} seconds, Loss: {epoch_loss}")
#             print(f"Validation Loss: {val_loss}")

#     print(f"Took {total_time} seconds for {epoch} epochs.")


#     return model, loss_list

In [None]:
# %matplotlib inline
# from IPython.display import display, clear_output
# import matplotlib.pyplot as plt
# import time
# import statistics
# from itertools import chain
# import numpy as np
# from collections import deque

# def train(model, data_loader, data_loader_validation, epochs, lr, loss_func, batch_embedder,
#           windowed_mode=False, window_mode="uniform", window_start_mode="random", min_window=50, max_window=100, neg_bin_p=0.95, train_on_all_every=4,
#           annealing_mode = False, annealing_window=5, annealing_multiplier=1.25, annealing_ratio = 0.5, annealing_minimum = 1e-6,
#           device="cuda", verbose=False, plot_every=10,
#           validation_frequency=1, validation_prp=10, moving_avg_window=10):

#     batch_embedder = batch_embedder.to(device)
#     model = model.to(device)

#     optimizer = torch.optim.Adam(
#         chain(batch_embedder.parameters(), model.parameters()),
#         lr=lr
#     )

#     model.train()
#     batch_embedder.train()
#     loss_list = []
#     initial_value = 1.0  # Initial value for equal probability
#     window_losses = torch.ones(max_window - min_window + 1, device=device) * initial_value  # Track losses for each window length
#     window_counts = torch.zeros(max_window - min_window + 1, device=device)  # Track counts for each window length
#     loss_deques = [deque(maxlen=moving_avg_window) for _ in range(max_window - min_window + 1)]  # Deques for moving average
#     if windowed_mode and window_mode == "biased_loss":
#         fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 12))
#     else:
#         fig, ax1 = plt.subplots(1, 1, figsize=(10, 6))
#     epoch_loss_list = []
#     val_loss = 0

#     for epoch in range(epochs):
#         # Annealing for the learning rate
#         if annealing_mode and epoch > annealing_window:
#             if len(epoch_loss_list) > 0 and epoch_loss_list[-1] >= annealing_multiplier * (statistics.mean(epoch_loss_list[-annealing_window:])):
#                 for g in optimizer.param_groups:
#                     if g['lr'] * annealing_ratio < annealing_minimum:
#                         g['lr'] = annealing_minimum
#                     else:
#                         g['lr'] *= annealing_ratio

#         start = time.time()
#         for i, batch in enumerate(data_loader):

#             batch = batch.to(device)
#             batch = batch_embedder(batch)

#             batch_length = batch.shape[1]

#             # Windowed mode logic
#             if windowed_mode:
#                 if batch_length < min_window:
#                     continue
#                 if window_start_mode == "random":
#                     cut_start = torch.randint(0, batch_length - window_length + 1, (1,)).item()
#                 elif window_start_mode == "fixed":
#                     cut_start = 0
#                 if window_mode == "uniform":
#                     while True:
#                         window_length = torch.randint(min_window, batch_length + 1, (1,)).item()
#                         cut_end = cut_start + window_length
#                         if min_window <= (cut_end - cut_start) <= batch_length:
#                             break
#                     batch = batch[:, cut_start:cut_end, :]

#                 elif window_mode == "negative_binomial":
#                     total_count = 1
#                     probs = neg_bin_p
#                     distribution = torch.distributions.NegativeBinomial(total_count=total_count, probs=probs)
#                     while True:
#                         window_length = distribution.sample().item() + min_window
#                         cut_end = cut_start + window_length
#                         if min_window <= window_length <= batch_length:
#                             break
#                     batch = batch[:, cut_start:cut_end, :]

#                 elif window_mode == "biased_loss":
#                     if torch.min(window_counts) < 2:
#                         # Use uniform distribution until each length has been used at least twice
#                         window_probs = torch.ones_like(window_losses) / len(window_losses)
#                     elif torch.sum(window_counts) % train_on_all_every == 0:
#                         window_probs = torch.ones_like(window_losses) / len(window_losses)
#                     else:
#                         # Update probabilities based on moving average of losses
#                         avg_losses = torch.tensor([np.mean(loss_deque) if len(loss_deque) > 0 else initial_value for loss_deque in loss_deques], device=device)
#                         window_probs = avg_losses / avg_losses.sum()
#                     while True:
#                         window_length = torch.multinomial(window_probs, 1).item() + min_window
#                         #check if the window length does work with the batch length
#                         if window_length > batch_length:
#                             continue
#                         cut_end = cut_start + window_length
#                         if min_window <= window_length <= batch_length:
#                             break
#                     batch = batch[:, cut_start:cut_end, :]
#                     window_counts[window_length - min_window] += 1  # Update window counts

#             optimizer.zero_grad()
#             predicted_noise, noise, noise_mask = model(batch)
#             loss = loss_func(predicted_noise, noise, noise_mask)
#             loss.backward()
#             # # Gradient clipping
#             max_grad_norm = 1.0
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
#             torch.nn.utils.clip_grad_norm_(batch_embedder.parameters(), max_grad_norm)
#             optimizer.step()
#             loss_list.append(loss.item())

#             epoch_loss = sum(loss_list[-len(data_loader):]) / len(data_loader)
#             epoch_loss_list.append(epoch_loss)

#             # Update window losses and moving average deque
#             if windowed_mode and window_mode == "biased_loss":
#                 window_idx = window_length - min_window
#                 window_losses[window_idx] += loss.item()
#                 loss_deques[window_idx].append(loss.item())

#             # Dynamic plot update
#             if i % plot_every == 0:
#                 ax1.clear()
#                 ax1.set_ylim(0, 1)
#                 ax1.plot(loss_list)
#                 if len(loss_list) > 100:
#                     ax1.plot(np.convolve(loss_list, np.ones((100,))/100, mode='valid'))
#                     ax1.text(len(loss_list) - 1, np.convolve(loss_list, np.ones((100,))/100, mode='valid')[-1],
#                             str(round(np.convolve(loss_list, np.ones((100,))/100, mode='valid')[-1], 3)))
#                 if len(epoch_loss_list) > 0:
#                     ax1.text(0.1, 0.9, f"Epoch: {epoch} | Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
#                 # ax1.text(0.1, 0.8, f"Learning rate: {optimizer.param_groups[0]['lr']:.4e}")
#                 ax1.text(0.1, 0.8, f"Loss: {epoch_loss_list[-1]:.3e} | Validation loss: {val_loss:.3e}")
#                 ax1.text(0.1, 0.7, f"Time per step: {((time.time() - start) / (i + 1)):.2f} s | Time per epoch: {((time.time() - start) / (i + 1) * len(data_loader)):.2f} s")
#                 ax1.text(0.1, 0.6, f"Time till finish (est.): {((time.time() - start) / (i + 1) * len(data_loader) * (epochs - epoch)) / 60:.2f} min")
#                 if windowed_mode and window_mode == "biased_loss":
#                     ax2.clear()
#                     ax2.bar(range(min_window, max_window + 1), window_counts.cpu().numpy())
#                     ax2.set_ylabel("Counts")
#                     ax2.set_title("Counts of Each Window Length Used")

#                     moving_avg_losses = torch.tensor([np.mean(loss_deque) if len(loss_deque) > 0 else initial_value for loss_deque in loss_deques], device=device).cpu().numpy()
#                     ax3.clear()
#                     ax3.bar(range(min_window, max_window + 1), moving_avg_losses)
#                     ax3.set_xlabel("Window Length")
#                     ax3.set_ylabel("Moving Average Loss")
#                     ax3.set_title("Moving Average Loss for Each Window Length")

#                 display(fig)
#                 clear_output(wait=True)

#         end = time.time()

#         # Validation
#         if epoch % validation_frequency == 0:
#             loss_list_validation = []
#             for i, batch in enumerate(data_loader_validation):
#                 batch = batch.to(device)
#                 batch = batch_embedder(batch)
#                 if i % validation_prp == 0:
#                     predicted_noise, noise, noise_mask = model(batch)
#                     loss = loss_func(predicted_noise, noise, noise_mask)
#                     loss_list_validation.append(loss.item())

#             val_loss = np.mean(loss_list_validation)

#         if verbose:
#             print(f"Epoch {epoch} completed in {end - start} seconds, Loss: {epoch_loss}")
#             print(f"Validation Loss: {val_loss}")


#     return model, loss_list

In [None]:
def run_multiple_evaluations(
    dataloader,
    imputer,
    training_mean,
    training_std,
    sample_number,
    batch_embedder=None,
    old_sample=None,
    min_sequence_len=2,
    max_sequence_len=None,
    scale=1,
    verbose=True,
    show_max_diff=False,
    show_rmse=False,
    num_gpus=1,
    features_to_impute=None,
    last_n_time=1,
):
    """Run ``sample_number`` imputation passes over ``dataloader`` and report RMSE.

    For every batch the sequences are filtered to lengths within
    ``[min_sequence_len, max_sequence_len]``. The model input is built from
    ``prev_outputs``, ``current_treatments``, ``static_features`` (repeated
    over every time step) and channel 1 of ``outputs``. The input is masked
    with the imputer's ``"selected_features_last_n_sequence_length"`` strategy
    and passed to ``imputer.get_predictions`` under ``torch.no_grad`` and CUDA
    autocast. After each pass the RMSE over all passes collected so far is
    printed and, when an MLflow run is active, logged with ``step=pass index``.

    Args:
        dataloader: Iterable of batch dicts whose values can be indexed by a
            boolean mask over the batch axis.
        imputer: Imputation model exposing ``get_mask``, ``get_predictions``
            and ``device``. It is put into eval mode, and its
            ``features_to_impute``, ``last_n_time`` and ``sequence_length``
            attributes are overwritten.
        training_mean / training_std: Forwarded to ``get_predictions`` and
            ``calculate_rmse``.
        sample_number: Number of passes over ``dataloader``.
        batch_embedder: Unused; kept for signature compatibility.
        old_sample: Optional list of earlier passes to extend. It is appended
            to in place and returned. Defaults to a new empty list.
        min_sequence_len / max_sequence_len: Inclusive length bounds;
            ``max_sequence_len=None`` means no upper bound.
        scale, verbose, show_max_diff, show_rmse: Forwarded to
            ``get_predictions``.
        num_gpus: Unused; kept for signature compatibility.
        features_to_impute: Feature indices assigned to
            ``imputer.features_to_impute``. Defaults to ``[6]`` (the value
            previously hard-coded here).
        last_n_time: Value assigned to ``imputer.last_n_time`` (default 1).

    Returns:
        ``final_samples``: one entry per pass (earlier passes first). Each
        entry is the list of per-batch results of ``get_predictions``.
    """
    # A new list per call: the previous ``old_sample=[]`` default was shared
    # between calls, so samples from one call leaked into the next.
    final_samples = [] if old_sample is None else old_sample
    if features_to_impute is None:
        features_to_impute = [6]

    imputer.eval()
    # Configure masking once, before the first ``get_mask`` call. Previously
    # these were set after ``get_mask``, so the very first batch was masked
    # with whatever settings the imputer held beforehand (assuming
    # ``get_mask`` reads these attributes).
    imputer.features_to_impute = features_to_impute
    imputer.last_n_time = last_n_time

    total_batches = len(dataloader) * sample_number
    completed_batches = 0
    run_start = time.time()

    for i in range(sample_number):
        # print a line to separate the samples
        print("-------------------------------------------------")
        print(f"Running sample {i + 1}/{sample_number}")
        all_samples = []
        for batch in dataloader:
            # Keep only the sequences whose length lies within the bounds.
            lengths = batch['sequence_lengths']
            keep = lengths >= min_sequence_len
            if max_sequence_len is not None:
                keep = keep & (lengths <= max_sequence_len)
            batch = {key: value[keep] for key, value in batch.items()}
            imputer.sequence_length = batch['sequence_lengths']

            model_input = torch.cat(
                (batch['prev_outputs'], batch['current_treatments']), dim=-1)
            static = batch['static_features'].unsqueeze(1).expand(
                -1, model_input.size(1), -1)
            # Only channel 1 of ``outputs`` is appended as the last column.
            model_input = torch.cat(
                (model_input, static, batch['outputs'][:, :, 1:2]), dim=-1
            ).to(imputer.device)

            imputation_masks = imputer.get_mask(
                model_input, strategy="selected_features_last_n_sequence_length"
            ).to(imputer.device)
            with torch.no_grad(), torch.cuda.amp.autocast():
                imputed_samples = imputer.get_predictions(
                    model_input,
                    imputation_masks,
                    mean=training_mean,
                    std=training_std,
                    scale=scale,
                    verbose=verbose,
                    show_max_diff=show_max_diff,
                    show_rmse=show_rmse
                )
            all_samples.append(imputed_samples)

            completed_batches += 1
            # ETA from the average time per batch so far; the previous
            # per-sample average printed 0 for the whole first sample.
            elapsed = time.time() - run_start
            remaining = elapsed / completed_batches * (total_batches - completed_batches)
            print(f"Overall Progress: {completed_batches / total_batches * 100:.2f}%")
            print(f"Time to finish (est.): {remaining / 60:.2f} min")

        final_samples.append(all_samples)

        rmse, rmse_median = calculate_rmse(
            final_samples, training_mean, training_std)
        print(f"RMSE: {rmse:.3f} | RMSE (Median): {rmse_median:.3f}")

        if mlflow.active_run() is not None:
            mlflow.log_metric("RMSE", rmse, step=i)
            mlflow.log_metric("RMSE_Median", rmse_median, step=i)

    return final_samples

In [None]:
# import numpy as np
# import random
# from torch.utils.data import DataLoader
# from torch.utils.data import Sampler
# from collections import defaultdict
# import torch
# from torch.utils.data import Dataset


# class CustomTimeSeriesDataset(Dataset):
#     def __init__(
#         self, data_tensor, sequence_lengths, min_seq_length=None, max_seq_length=None
#     ):
#         # Store the initial sequences and lengths
#         self.data_tensor = data_tensor
#         self.sequence_lengths = sequence_lengths
#         self.min_seq_length = min_seq_length
#         self.max_seq_length = max_seq_length

#         # Filter sequences based on min and max sequence length
#         if self.min_seq_length is not None:
#             valid_indices = [
#                 i
#                 for i, length in enumerate(self.sequence_lengths)
#                 if length >= self.min_seq_length
#             ]
#         else:
#             valid_indices = list(range(len(self.sequence_lengths)))

#         if self.max_seq_length is not None:
#             valid_indices = [
#                 i
#                 for i in valid_indices
#                 if self.sequence_lengths[i] <= self.max_seq_length
#             ]

#         self.data_tensor = self.data_tensor[valid_indices]
#         self.sequence_lengths = [self.sequence_lengths[i]
#                                  for i in valid_indices]

#     def __len__(self):
#         return len(self.sequence_lengths)

#     def __getitem__(self, idx):
#         seq_length = int(self.sequence_lengths[idx])
#         return self.data_tensor[idx, :seq_length, :], seq_length


# # class CustomTimeSeriesDataset(Dataset):
# #     def __init__(self, data_tensor, sequence_lengths, min_seq_length=None, max_seq_length=None):
# #         self.data_tensor = data_tensor
# #         self.sequence_lengths = sequence_lengths

# #     def __len__(self):
# #         return len(self.sequence_lengths)

# #     def __getitem__(self, idx):
# #         seq_length = int(self.sequence_lengths[idx])
# #         return self.data_tensor[idx, :seq_length, :], seq_length


# class LengthBatchSampler(Sampler):
#     def __init__(self, dataset, batch_size):
#         self.sequence_lengths = dataset.sequence_lengths
#         self.batch_size = batch_size
#         self.batches = self._create_batches()

#     def _create_batches(self):
#         length_to_indices = defaultdict(list)
#         for idx, length in enumerate(self.sequence_lengths):
#             length_to_indices[length].append(idx)

#         batches = []
#         for length, indices in length_to_indices.items():
#             # Split indices into batches of the specified batch size
#             for i in range(0, len(indices), self.batch_size):
#                 batches.append(indices[i: i + self.batch_size])
#         return batches

#     def __iter__(self):
#         for batch in self.batches:
#             yield batch

#     def __len__(self):
#         return len(self.batches)


# def collate_fn(batch):
#     data, lengths = zip(*batch)
#     data = torch.stack(data)
#     return data


# def create_dataloader(
#     data_tensor, sequence_lengths, batch_size, min_seq_length=None, max_seq_length=None
# ):
#     dataset = CustomTimeSeriesDataset(
#         data_tensor, sequence_lengths, min_seq_length, max_seq_length
#     )
#     sampler = LengthBatchSampler(dataset, batch_size)
#     dataloader = DataLoader(
#         dataset, batch_sampler=sampler, collate_fn=collate_fn)
#     return dataloader


# class BalancedLengthBatchSampler(Sampler):
#     def __init__(self, dataset, batch_size, balance_factor=1.0):
#         self.sequence_lengths = dataset.sequence_lengths
#         self.batch_size = batch_size
#         self.balance_factor = balance_factor
#         self.batches = self._create_balanced_batches()

#     def _create_balanced_batches(self):
#         # Group indices by sequence length
#         length_to_indices = defaultdict(list)
#         for idx, length in enumerate(self.sequence_lengths):
#             length_to_indices[length].append(idx)

#         # Calculate the maximum count of indices for balancing
#         max_count = max(len(indices) for indices in length_to_indices.values())

#         # Balance the distribution of sequence lengths by oversampling shorter sequences
#         balanced_batches = []
#         for length, indices in length_to_indices.items():
#             count = len(indices)
#             if count < max_count:
#                 repeat_factor = int(self.balance_factor * (max_count / count))
#                 # Correctly oversample the list elements
#                 oversampled_indices = indices * repeat_factor
#                 # oversampled_indices = oversampled_indices[:max_count]
#             else:
#                 oversampled_indices = indices

#             # Shuffle the indices of this particular length
#             random.shuffle(oversampled_indices)

#             # Create batches for this length
#             for i in range(0, len(oversampled_indices), self.batch_size):
#                 batch = oversampled_indices[i: i + self.batch_size]
#                 if len(batch) == self.batch_size:
#                     balanced_batches.append(batch)

#         # Shuffle the list of balanced batches to ensure random order
#         random.shuffle(balanced_batches)

#         return balanced_batches

#     def __iter__(self):
#         for batch in self.batches:
#             yield batch

#     def __len__(self):
#         return len(self.batches)


# def create_balanced_dataloader(
#     data_tensor,
#     sequence_lengths,
#     batch_size,
#     balance_factor=1.0,
#     min_seq_length=None,
#     max_seq_length=None,
# ):
#     dataset = CustomTimeSeriesDataset(
#         data_tensor, sequence_lengths, min_seq_length, max_seq_length
#     )
#     sampler = BalancedLengthBatchSampler(dataset, batch_size, balance_factor)
#     dataloader = DataLoader(
#         dataset, batch_sampler=sampler, collate_fn=collate_fn)
#     return dataloader

In [None]:
# train_loader = create_balanced_dataloader(
#     training_data_tensor,
#     training_data_sequence_lengths,
#     batch_size=100,
#     balance_factor=1.0,
#     min_seq_length=2,
#     max_seq_length=None,
# )

# val_loader = create_balanced_dataloader(
#     validation_data_tensor,
#     validation_data_sequence_lengths,
#     batch_size=100,
#     balance_factor=1.0,
#     min_seq_length=2,
#     max_seq_length=None,
# )

# # check the size of the train, val, and test sets
# print(len(train_loader))
# print(len(val_loader))

In [None]:
# training_data_tensor.shape

In [None]:
# categorical_indices_sizes = {
#     # 'time_step' : [0, 1],
#     "chemo_application_prev": [0, 1],
#     "radio_application_prev": [1, 1],
#     "patient_type_tile": [2, 1],
# }

# numerical_indices = {"cancer_volume": 3}

# numerical_indices = {
#     'chemo_application_prev': 0,
#     'radio_application_prev': 1,
#     'patient_type_tile': 2,
#     'cancer_volume': 3
# }


# training_data_tensor_embedded = data_embedder(training_data_tensor)
# validation_data_tensor_embedded = data_embedder(validation_data_tensor)
# test_data_factuals_tensor_embedded = data_embedder(test_data_factuals_tensor)
# test_data_counterfactuals_tensor_embedded = data_embedder(test_data_counterfactuals_tensor)
# test_data_seq_tensor_embedded = data_embedder(test_data_seq_tensor)

In [None]:
import torch
%matplotlib inline
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import time
import statistics
from itertools import chain
import numpy as np
from collections import deque

def _train_concat_batch(batch):
    """Concatenate one ``train`` batch dict into a single model-input tensor.

    Along the last axis the result holds, in order: ``prev_outputs``,
    ``current_treatments``, ``static_features`` repeated over every time step,
    and ``outputs``. For the concatenation to work, ``static_features`` must
    be 2-D (batch, features) and the other entries 3-D (batch, time, features).
    """
    time_varying = torch.cat(
        (batch['prev_outputs'], batch['current_treatments']), dim=-1)
    static = batch['static_features'].unsqueeze(1).expand(
        -1, time_varying.size(1), -1)
    return torch.cat((time_varying, static, batch['outputs']), dim=-1)


def train(
    model,
    data_loader,
    data_loader_validation,
    epochs,
    lr,
    loss_func,
    weighted_loss_func=None,
    batch_embedder=None,
    gradient_clip=1.0,
    windowed_mode=False,
    window_mode="uniform",
    window_start_mode="random",
    min_window=50,
    max_window=100,
    neg_bin_p=0.95,
    train_on_all_every=4,
    annealing_mode=False,
    annealing_window=5,
    annealing_multiplier=1.25,
    annealing_ratio=0.5,
    annealing_minimum=1e-6,
    device="cuda",
    num_gpus=1,
    verbose=False,
    plot_every=10,
    validation_frequency=1,
    validation_prp=10,
    moving_avg_window=10,
):
    """Train a noise-prediction model with Adam while plotting the loss live.

    Each batch dict is concatenated (see ``_train_concat_batch``), moved to
    ``device`` and passed to ``model``, which must return
    ``(predicted_noise, noise, noise_mask)``. The step loss is
    ``weighted_loss_func(predicted_noise, noise, noise_mask, stabilized_weights)``
    when ``weighted_loss_func`` is given, otherwise
    ``loss_func(predicted_noise, noise, noise_mask)``.

    Relies on ``plt``, ``display`` and ``clear_output`` from the notebook's
    global namespace for the live plot.

    Args:
        model: Module to train; wrapped in ``torch.nn.DataParallel`` when more
            than one GPU is used.
        data_loader / data_loader_validation: Iterables of batch dicts.
        epochs: Number of passes over ``data_loader``.
        lr: Initial Adam learning rate.
        loss_func: Unweighted loss; also used for validation.
        weighted_loss_func: Optional loss taking ``stabilized_weights``; every
            training batch must then contain ``'stabilized_weights'``.
        batch_embedder: Optional module whose parameters are optimised and
            clipped together with the model. It is currently not applied to
            the batches.
        gradient_clip: Max gradient norm. ``None``/``False``/``0`` disables
            clipping; ``True`` means 1.0.
        windowed_mode: Train on a sub-window of each batch's time axis.
            Batches shorter than ``min_window`` are skipped.
        window_mode: How the window length is drawn:
            ``"uniform"`` draws from ``[min_window, batch_length]``
            (``max_window`` is not applied).
            ``"negative_binomial"`` draws ``min_window`` + a
            NegativeBinomial(1, ``neg_bin_p``) sample, redrawn until it fits.
            ``"biased_loss"`` draws from ``[min_window, max_window]``,
            truncated to the batch length, with probability proportional to
            each length's moving-average loss. It samples uniformly until
            every length has been used twice, and whenever the total window
            count is a multiple of ``train_on_all_every``.
        window_start_mode: ``"random"`` (start anywhere the window fits) or
            ``"fixed"`` (start at time step 0).
        annealing_mode / annealing_window / annealing_multiplier /
        annealing_ratio / annealing_minimum: Once ``epoch > annealing_window``,
            at the start of an epoch the learning rate is multiplied by
            ``annealing_ratio`` (floored at ``annealing_minimum``) when the
            previous epoch's mean loss is at least ``annealing_multiplier``
            times the mean of the last ``annealing_window`` epoch means.
        device / num_gpus: Target device and GPU count for DataParallel.
        verbose: Print the per-epoch and validation losses.
        plot_every: Redraw the plot every this many steps.
        validation_frequency: Validate every this many epochs.
        validation_prp: Use only every this-many-th validation batch.
        moving_avg_window: Loss history length per window length
            (``"biased_loss"`` only).

    Returns:
        ``(model, loss_list)``, where ``loss_list`` holds every training step's loss.

    Raises:
        ValueError: ``windowed_mode`` with an unknown ``window_mode`` or
            ``window_start_mode``, or ``weighted_loss_func`` given for a batch
            without ``'stabilized_weights'``.
    """
    if windowed_mode:
        if window_mode not in ("uniform", "negative_binomial", "biased_loss"):
            raise ValueError(f"Unknown window_mode: {window_mode!r}")
        if window_start_mode not in ("random", "fixed"):
            raise ValueError(f"Unknown window_start_mode: {window_start_mode!r}")

    # Check for GPU availability
    available_gpus = torch.cuda.device_count()
    if available_gpus < num_gpus:
        print(f"Requested {num_gpus} GPUs, but only {available_gpus} are available.")
        num_gpus = available_gpus
    else:
        print(f"Using {num_gpus} GPUs for training.")
        # get_device_name(0) raises when no GPU exists (e.g. num_gpus=0 on CPU).
        if available_gpus > 0:
            print(f"GPU model: {torch.cuda.get_device_name(0)}")

    if num_gpus > 1:
        device_ids = list(range(num_gpus))
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        if batch_embedder is not None:
            batch_embedder = torch.nn.DataParallel(batch_embedder, device_ids=device_ids)

    model = model.to(device)
    if batch_embedder is not None:
        batch_embedder = batch_embedder.to(device)
        trainable = chain(batch_embedder.parameters(), model.parameters())
    else:
        trainable = model.parameters()
    optimizer = torch.optim.Adam(trainable, lr=lr)

    model.train()

    loss_list = []
    epoch_mean_losses = []  # one mean training loss per epoch; drives annealing
    initial_value = 1.0  # Initial value for equal probability
    num_windows = max_window - min_window + 1
    window_losses = torch.ones(num_windows, device=device) * initial_value
    window_counts = torch.zeros(num_windows, device=device)  # uses per window length
    loss_deques = [deque(maxlen=moving_avg_window) for _ in range(num_windows)]
    neg_bin = None
    if windowed_mode and window_mode == "negative_binomial":
        neg_bin = torch.distributions.NegativeBinomial(total_count=1, probs=neg_bin_p)

    if windowed_mode and window_mode == "biased_loss":
        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 12))
    else:
        fig, ax1 = plt.subplots(1, 1, figsize=(10, 6))
    epoch_loss = float('nan')  # running display loss; defined even if no step runs
    val_loss = 0

    total_time_start = time.time()

    for epoch in range(epochs):
        # Learning-rate annealing, evaluated on per-epoch mean losses (the
        # old per-step list made the window span steps, not epochs).
        if (annealing_mode and epoch > annealing_window and epoch_mean_losses
                and epoch_mean_losses[-1] >= annealing_multiplier
                * statistics.mean(epoch_mean_losses[-annealing_window:])):
            for g in optimizer.param_groups:
                g['lr'] = max(g['lr'] * annealing_ratio, annealing_minimum)

        start = time.time()
        epoch_step_losses = []
        for i, batch in enumerate(data_loader):
            # Read per batch so weights never leak over from a previous batch.
            stabilized_weights = (
                batch['stabilized_weights'] if 'stabilized_weights' in batch else None)
            if weighted_loss_func is not None and stabilized_weights is None:
                raise ValueError(
                    "weighted_loss_func requires 'stabilized_weights' in every training batch")

            batch = _train_concat_batch(batch).to(device)
            batch_length = batch.shape[1]

            window_length = None
            if windowed_mode:
                if batch_length < min_window:
                    continue
                # Choose the length first, then a start that fits it.
                if window_mode == "uniform":
                    window_length = torch.randint(min_window, batch_length + 1, (1,)).item()
                elif window_mode == "negative_binomial":
                    while True:
                        # sample() yields a float; slices need an int.
                        window_length = int(neg_bin.sample().item()) + min_window
                        if window_length <= batch_length:
                            break
                else:  # "biased_loss"
                    if torch.min(window_counts) < 2 or torch.sum(window_counts) % train_on_all_every == 0:
                        window_probs = torch.ones_like(window_losses) / len(window_losses)
                    else:
                        # Probabilities from the moving average of losses.
                        avg_losses = torch.tensor(
                            [np.mean(d) if len(d) > 0 else initial_value for d in loss_deques],
                            device=device)
                        window_probs = avg_losses / avg_losses.sum()
                    # Index k <-> length min_window + k. Restricting to the
                    # lengths that fit equals rejection sampling, but cannot
                    # loop forever.
                    fitting_probs = window_probs[: batch_length - min_window + 1]
                    window_length = torch.multinomial(fitting_probs, 1).item() + min_window
                    window_counts[window_length - min_window] += 1

                if window_start_mode == "random":
                    cut_start = torch.randint(0, batch_length - window_length + 1, (1,)).item()
                else:
                    cut_start = 0
                cut_end = cut_start + window_length
                batch = batch[:, cut_start:cut_end, :]
                # Keep the weights aligned with the windowed batch in every mode.
                if stabilized_weights is not None:
                    stabilized_weights = stabilized_weights[:, cut_start:cut_end]

            optimizer.zero_grad()
            predicted_noise, noise, noise_mask = model(batch)
            if weighted_loss_func is not None:
                loss = weighted_loss_func(predicted_noise, noise, noise_mask, stabilized_weights)
            else:
                loss = loss_func(predicted_noise, noise, noise_mask)

            loss.backward()
            # A falsy value disables clipping; clipping to norm 0 would zero every gradient.
            if gradient_clip:
                max_grad_norm = float(gradient_clip)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                if batch_embedder is not None:
                    torch.nn.utils.clip_grad_norm_(batch_embedder.parameters(), max_grad_norm)
            optimizer.step()

            step_loss = loss.item()
            loss_list.append(step_loss)
            epoch_step_losses.append(step_loss)
            # Running mean over the last len(data_loader) steps (display only).
            epoch_loss = sum(loss_list[-len(data_loader):]) / len(data_loader)

            # Update window losses and moving average deque
            if windowed_mode and window_mode == "biased_loss":
                window_idx = window_length - min_window
                window_losses[window_idx] += step_loss
                loss_deques[window_idx].append(step_loss)

            # Dynamic plot update
            if i % plot_every == 0:
                step_time = (time.time() - start) / (i + 1)
                ax1.clear()
                ax1.set_ylim(0, 1)
                ax1.plot(loss_list)
                if len(loss_list) > 100:
                    smoothed = np.convolve(loss_list, np.ones(100) / 100, mode='valid')
                    # Align each 100-step mean with the last step it covers.
                    ax1.plot(np.arange(99, len(loss_list)), smoothed)
                    ax1.text(len(loss_list) - 1, smoothed[-1], str(round(smoothed[-1], 3)))
                ax1.text(0.1, 0.9, f"Epoch: {epoch} | Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
                ax1.text(0.1, 0.8, f"Loss: {epoch_loss:.3e} | Validation loss: {val_loss:.3e}")
                ax1.text(0.1, 0.7, f"Time per step: {step_time:.2f} s | Time per epoch: {step_time * len(data_loader):.2f} s")
                ax1.text(0.1, 0.6, f"Time till finish (est.): {(step_time * len(data_loader) * (epochs - epoch)) / 60:.2f} min")
                if windowed_mode and window_mode == "biased_loss":
                    ax2.clear()
                    ax2.bar(range(min_window, max_window + 1), window_counts.cpu().numpy())
                    ax2.set_ylabel("Counts")
                    ax2.set_title("Counts of Each Window Length Used")

                    moving_avg_losses = torch.tensor(
                        [np.mean(d) if len(d) > 0 else initial_value for d in loss_deques],
                        device=device).cpu().numpy()
                    ax3.clear()
                    ax3.bar(range(min_window, max_window + 1), moving_avg_losses)
                    ax3.set_xlabel("Window Length")
                    ax3.set_ylabel("Moving Average Loss")
                    ax3.set_title("Moving Average Loss for Each Window Length")

                display(fig)
                clear_output(wait=True)

        end = time.time()
        if epoch_step_losses:
            epoch_mean_losses.append(sum(epoch_step_losses) / len(epoch_step_losses))

        # Validation: no autograd graph is needed, and skipped batches are
        # not concatenated at all.
        if epoch % validation_frequency == 0:
            loss_list_validation = []
            with torch.no_grad():
                for i, batch in enumerate(data_loader_validation):
                    if i % validation_prp != 0:
                        continue
                    batch = _train_concat_batch(batch).to(device)
                    predicted_noise, noise, noise_mask = model(batch)
                    loss = loss_func(predicted_noise, noise, noise_mask)
                    loss_list_validation.append(loss.item())
            val_loss = (float(np.mean(loss_list_validation))
                        if loss_list_validation else float('nan'))

        if verbose:
            print(f"Epoch {epoch} completed in {end - start} seconds, Loss: {epoch_loss}")
            print(f"Validation Loss: {val_loss}")

    # Measured after the loop so it exists even when epochs == 0.
    total_time = time.time() - total_time_start
    print(f"Took {total_time} seconds for {epochs} epochs.")

    return model, loss_list

In [None]:
# import gc

# torch.cuda.empty_cache()
# gc.collect()

In [None]:
# diffusion_imputer = diffusion_imputation(
#     emb_dim=128,
#     # strategy="forecasting_last_n_time",
#     # strategy="random",
#     # missing_prp=0.5,
#     # strategy='selected_features',
#     strategy="selected_features_last_n_time",
#     last_n_time=1,
#     features_to_impute=[0, 1, 2, 3],
#     # excluded_features = [i for i in range(6)], #[2],#[0,1,2,3,5], #for the embedded stock names which we don't need to predict
#     # strategy='selected_features_and_selected_features_after_time',
#     # features_to_impute_completely=[2],
#     # features_to_impute_after_time=[3],
#     num_residual_layers=2,
#     diffusion_steps=50,
#     diffusion_beta_schedule="quadratic",
#     num_heads=8,
#     kernel_size=(1, 1),
#     ff_dim=512,
#     num_cells=1,
#     dropout=0,
#     # csdi, csdi_moded_transformer, rsa, rsa_moded_transformer, moded_transformer_alone, rsa_csdi
#     method="rsa_moded_transformer",
#     device="cuda",
# )

# # data_embedder = DataEmbedder(
# #     categorical_indices_sizes, numerical_indices, training_data_tensor
# # )

In [None]:
# train(
#     diffusion_imputer,
#     train_dataloader,
#     val_dataloader,
#     num_gpus=1,
#     batch_embedder=None,
#     gradient_clip=True,
#     windowed_mode=True,
#     window_mode="uniform",
#     window_start_mode="fixed",
#     # train_on_all_every=100,
#     min_window=2,
#     max_window=60,
#     device="cuda",
#     epochs=100,
#     lr=1e-4,
#     annealing_mode=True,
#     annealing_window=2,
#     annealing_multiplier=1.0,
#     annealing_ratio=0.5,
#     annealing_minimum=1e-7,
#     loss_func=diffusion_imputer.loss_func,
#     weighted_loss_func=diffusion_imputer.weighted_loss_func,
#     validation_frequency=5,
#     validation_prp=1,
#     verbose=False,
#     plot_every=20
# )

In [None]:
# train_hybrid(
#     diffusion_imputer,
#     hybrid_model,
#     train_loader,
#     val_loader,
#     batch_embedder = embedder,
#     epochs = 20,
#     lr = 0.001,
#     annealing_window = 5,
#     annealing_multiplier = 1,
#     loss_func = diffusion_imputer.loss_func,
#     hybrid_loss_func = hybrid_model.loss_func,
#     hybrid_start_epoch = 0,
#     hybrid_every_n_epoch = 5,
#     validation_frequency=2,
#     validation_prp=1)

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import mlflow
# import time
# import math
# import statistics
# from torch.utils.data import DataLoader
# from itertools import chain
# from omegaconf import OmegaConf


# def distill_teacher_to_student(teacher, student, dataloader, loss_fn, num_epochs, lr, device="cuda", gradient_clip=1.0,
#                                amp_scale=True, plot_every=10, plot_display_window=100, moving_avg_window=10):
#     """
#     Distills a teacher diffusion imputation model (with many diffusion steps)
#     into a student model (with fewer steps) using a simple MSE loss.

#     The teacher is kept in evaluation mode while the student is trained.
#     """

#     teacher.eval()
#     student.train()
#     teacher.to(device)
#     student.to(device)

#     optimizer = torch.optim.Adam(
#         student.parameters(),
#         lr=lr
#     )

#     # Initialize dynamic plotting
#     loss_list = []
#     fig, ax1 = plt.subplots(1, 1, figsize=(10, 6))
#     epoch_loss_list = []

#     if mlflow.active_run() is not None:
#         mlflow.log_param("distillation_teacher_steps", teacher.diffusion_steps)
#         mlflow.log_param("distillation_student_steps", student.diffusion_steps)

#     if amp_scale:
#         scaler = torch.cuda.amp.GradScaler()

#     for epoch in range(num_epochs):
#         epoch_loss = 0.0
#         start = time.time()
#         for i, batch in enumerate(dataloader):
#             sequence_lengths = batch['sequence_lengths']
#             student.sequence_length = sequence_lengths
#             teacher.sequence_length = sequence_lengths
#             if 'stabilized_weights' in batch:
#                 stabilized_weights = batch['stabilized_weights']
#             curr_treatments = batch['current_treatments']
#             vitals_or_prev_outputs = []
#             vitals_or_prev_outputs.append(batch['prev_outputs'])
#             vitals_or_prev_outputs = torch.cat(vitals_or_prev_outputs, dim=-1)
#             static_features = batch['static_features']
#             outputs = batch['outputs']

#             batch = torch.cat(
#                 (vitals_or_prev_outputs, curr_treatments), dim=-1)

#             batch = torch.cat((batch, static_features.unsqueeze(
#                 1).expand(-1, batch.size(1), -1)), dim=-1)
#             batch = torch.cat((batch, outputs), dim=-1)
#             if batch.shape[0] == 0:
#                 continue
#             batch = batch.to(device)
#             optimizer.zero_grad()

#             if amp_scale:
#                 with torch.cuda.amp.autocast():
#                     # Teacher output (no grad)
#                     with torch.no_grad():
#                         teacher_output, teacher_noise, teacher_noise_mask = teacher(
#                             batch)
#                     student_output, student_noise, student_noise_mask = student(
#                         batch)
#                     loss = loss_fn(student_output, teacher_output,
#                                    student_noise_mask)
#                     if math.isnan(loss.item()):
#                         continue
#                 scaler.scale(loss).backward()
#                 if gradient_clip is not None:
#                     scaler.unscale_(optimizer)
#                     torch.nn.utils.clip_grad_norm_(
#                         student.parameters(), gradient_clip)
#                 else:
#                     scaler.unscale_(optimizer)
#                 optimizer.step()
#                 scaler.update()
#                 loss_list.append(loss.item())
#             else:
#                 with torch.no_grad():
#                     teacher_output, teacher_noise, teacher_noise_mask = teacher(
#                         batch)
#                 student_output, student_noise, student_noise_mask = student(
#                     batch)
#                 loss = loss_fn(student_output, teacher_output,
#                                student_noise_mask)
#                 if math.isnan(loss.item()):
#                     continue
#                 loss.backward()
#                 if gradient_clip is not None:
#                     torch.nn.utils.clip_grad_norm_(
#                         model.parameters(), gradient_clip)
#                 optimizer.step()
#                 loss_list.append(loss.item())

#             epoch_loss = sum(loss_list[-len(dataloader):]) / len(dataloader)
#             epoch_loss_list.append(epoch_loss)

#             # Dynamic plot update focusing on recent losses
#             if i % plot_every == 0:
#                 ax1.clear()
#                 # Only consider the recent losses for plotting (e.g., last 100)
#                 display_window = plot_display_window
#                 recent_losses = loss_list[-display_window:] if len(
#                     loss_list) > display_window else loss_list
#                 # Dynamically set y-limits based on recent loss values
#                 if recent_losses:
#                     min_loss = min(recent_losses)
#                     max_loss = max(recent_losses)
#                     margin = 0.1 * \
#                         (max_loss - min_loss) if max_loss != min_loss else 0.1 * max_loss
#                     ax1.set_ylim(min_loss - margin, max_loss + margin)
#                 ax1.plot(recent_losses)
#                 if len(recent_losses) > moving_avg_window:
#                     moving_avg = np.convolve(recent_losses, np.ones(
#                         (moving_avg_window,)) / moving_avg_window, mode='valid')
#                     x = np.arange(moving_avg_window - 1, len(recent_losses))
#                     ax1.plot(x, moving_avg)
#                     ax1.text(len(recent_losses) - 1, moving_avg[-1],
#                              f"{moving_avg[-1]:.3e}")
#                 if epoch_loss_list:
#                     ax1.text(
#                         0.1, 0.9, f"Epoch: {epoch} | LR: {optimizer.param_groups[0]['lr']:.2e}", transform=ax1.transAxes)
#                 ax1.text(
#                     0.1, 0.8, f"Loss: {epoch_loss_list[-1]:.3e}", transform=ax1.transAxes)
#                 ax1.text(
#                     0.1, 0.7, f"Time per step: {((time.time() - start) / (i + 1)):.2f} s | Time per epoch: {((time.time() - start) / (i + 1) * len(dataloader)):.2f} s", transform=ax1.transAxes)
#                 ax1.text(
#                     0.1, 0.6, f"Time till finish: {((time.time() - start) / (i + 1) * len(dataloader) * (num_epochs - epoch)) / 60:.2f} min", transform=ax1.transAxes)
#                 display(fig)
#                 clear_output(wait=True)

#         avg_loss = epoch_loss / len(dataloader)
#         elapsed = time.time() - start
#         # print(
#         #     f"Distillation Epoch {epoch+1}/{num_epochs}: Avg Loss = {avg_loss:.4f} (Epoch time: {elapsed:.2f} s)")

#         # Log distillation loss to MLflow if active
#         if mlflow.active_run() is not None:
#             mlflow.log_metric("distillation_epoch_loss", avg_loss, step=epoch)

#     return student


# # Suppose your existing diffusion_imputer is the teacher.
# teacher = diffusion_imputer
# # Create a student model with similar architecture but fewer diffusion steps.
# diffusion_imputer_student = diffusion_imputation(
#     emb_dim=64,
#     strategy="selected_features_sequence_length",
#     features_to_impute=[6],
#     num_residual_layers=1,            # Possibly fewer residual layers
#     # Fewer diffusion steps (e.g., 20 instead of 1000)
#     diffusion_steps=50,
#     diffusion_beta_schedule="quadratic",
#     num_heads=8,
#     kernel_size=(1, 1),
#     ff_dim=512,
#     num_cells=1,
#     dropout=0,
#     method="csdi",
#     device="cuda",
# )

# # Define a dataloader for distillation (could be the same as train_dataloader)
# distill_dataloader = DataLoader(
#     weighted_causal_diff.dataset_collection.val_f,
#     batch_size=1000,
#     shuffle=True,
#     num_workers=1,
#     pin_memory=True
# )

# distill_loss_fn = teacher.loss_func  # nn.MSELoss()

# diffusion_imputer_student = distill_teacher_to_student(
#     teacher, diffusion_imputer_student, distill_dataloader,
#     distill_loss_fn, num_epochs=20000, lr=5e-5, gradient_clip=5.0
# )

In [None]:
import os
import sys
import math
import copy
import logging
import requests
import zipfile
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from einops import rearrange

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from omegaconf import DictConfig, OmegaConf

from src.data import RealDatasetCollection, SyntheticDatasetCollection
from src.data.cancer_sim.dataset import SyntheticCancerDataset
from src.models import TimeVaryingCausalModel
from src.models.utils import (
    grad_reverse,
    BRTreatmentOutcomeHead,
    AlphaRise,
    clip_normalize_stabilized_weights,
)
from src.models.utils_lstm import VariationalLSTM
from copy import deepcopy

logger = logging.getLogger(__name__)

# Data embedding before feeding into the model

# Each categorical feature is first mapped to an integer index (one index per distinct value seen when the embedder is built) and then embedded into a learned vector with its configured number of columns. The embedded columns are concatenated in front of the numerical features, producing a torch tensor shaped (Cases, Time, Features).

# The input is expected to be a 3-D array/tensor indexed as (Cases, Time, Features); categorical and numerical features are selected by column index.


class DataEmbedder(nn.Module):
    """Embed categorical columns and concatenate them with numerical columns.

    Each categorical feature gets its own ``nn.Embedding``. The vocabulary of a
    feature is the sorted set of distinct values found in that column of
    ``dataset`` at construction time; a raw value is mapped to the embedding row
    given by ``mapping_dicts[feature][value]``.

    Args:
        categorical_indices_sizes: dict ``{feature_name: [column_index, embedding_size]}``.
        numerical_indices: dict ``{feature_name: column_index}``.
        dataset: array-like indexed as ``dataset[:, :, column]`` (presumably
            (cases, time, features) — confirm with callers); only used to collect
            the category vocabularies.
    """

    def __init__(self, categorical_indices_sizes, numerical_indices, dataset):
        super(DataEmbedder, self).__init__()
        # feature name -> [column index, embedding size]
        self.categoricals = categorical_indices_sizes
        # feature name -> column index
        self.numerics = numerical_indices
        self.embeddings = nn.ModuleDict()
        # feature name -> {raw category value: embedding row}
        self.mapping_dicts = {}

        # Build one vocabulary + embedding table per categorical feature.
        for key in self.categoricals:
            unique_values = np.unique(dataset[:, :, self.categoricals[key][0]])
            self.mapping_dicts[key] = {
                name: idx for idx, name in enumerate(unique_values)
            }
            self.embeddings[key] = nn.Embedding(
                num_embeddings=len(unique_values),
                embedding_dim=self.categoricals[key][1],
            )
            print(
                f"Feature: {key}, Categories: {len(unique_values)}, Embedding Size: {self.categoricals[key][1]}"
            )

    def _category_indices(self, key, raw_values):
        """Vectorised lookup of embedding rows for the raw values of feature ``key``.

        Raises:
            ValueError: if ``raw_values`` contains a value that was not present
                when the embedder was built (the old dict lookup turned these
                into ``None`` and failed later with an opaque error).
        """
        mapping = self.mapping_dicts[key]
        if raw_values.size == 0:
            return np.zeros(raw_values.shape, dtype=np.int64)
        if len(mapping) == 0:
            raise ValueError(
                f"Feature {key!r} has no known categories; cannot embed its values."
            )
        categories = np.array(list(mapping.keys()))
        codes = np.array(list(mapping.values()), dtype=np.int64)
        # searchsorted needs the vocabulary in sorted order.
        order = np.argsort(categories, kind="stable")
        categories = categories[order]
        codes = codes[order]

        positions = np.searchsorted(categories, raw_values)
        positions = np.clip(positions, 0, categories.size - 1)
        # A value is known only if it lands exactly on a vocabulary entry.
        unseen = categories[positions] != raw_values
        if unseen.any():
            examples = np.unique(raw_values[unseen])[:5].tolist()
            raise ValueError(
                f"Feature {key!r} contains values not seen when the embedder "
                f"was built: {examples}"
            )
        return codes[positions]

    def forward(self, dataset):
        """Return ``[embedded categoricals | numerical columns]`` along the last axis.

        If there are no categorical features, ``dataset`` is returned unchanged.
        Otherwise the result is reshaped to
        ``(dataset.shape[0], -1, total_embedding_size + len(numerics))``.
        """
        if len(self.categoricals) == 0:
            return dataset

        embedded_features = []
        for key in self.categoricals:
            # Category ids carry no gradient; detach so inputs requiring grad work.
            raw_values = (
                dataset[:, :, self.categoricals[key][0]].detach().cpu().numpy()
            )
            mapped_indices = torch.as_tensor(
                self._category_indices(key, raw_values),
                dtype=torch.long,
                device=dataset.device,
            )
            embedded_features.append(self.embeddings[key](mapped_indices))

        embedded_features = torch.cat(embedded_features, dim=-1)
        numeric_features = dataset[:, :, list(self.numerics.values())].float()

        # Concatenate the embedded features with the numerical data
        result = torch.cat([embedded_features, numeric_features], dim=-1)

        feature_count_embedded = len(self.numerics) + sum(
            self.categoricals[key][1] for key in self.categoricals
        )
        return result.reshape(dataset.shape[0], -1, feature_count_embedded)

class moded_TimesSeriesAttention(nn.Module):
    """
    Multi-head attention over all (time, feature) positions of tensors shaped (b, t, f, e).

    The ``t * f`` positions attend to one another jointly: every query position
    receives a softmax-weighted sum of the value vectors of all key positions,
    with weights ``softmax(q . k / sqrt(dim_per_head))`` taken over the keys.
    """

    def __init__(self, embed_dim: int, num_heads: int):
        """
        Constructor.

        Inputs:
        - embed_dim: size E of the last axis of query, key and value (the
          embedding dimension); all three are assumed to share it.
        - num_heads: Number of attention heads; must divide embed_dim.
        """
        super(moded_TimesSeriesAttention, self).__init__()

        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // num_heads

        self.linear_query = nn.Linear(embed_dim, embed_dim)
        self.linear_key = nn.Linear(embed_dim, embed_dim)
        self.linear_value = nn.Linear(embed_dim, embed_dim)
        self.output_linear = nn.Linear(embed_dim, embed_dim)
        # Softmax2d normalises over dim -3 of a 4-D input.
        self.softmax = nn.Softmax2d()

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
    ):
        """
        Compute the attended feature representations.

        Inputs:
        - query: Tensor of the shape BxTxFxE, where B is the batch size, T is the time dimension, F is the feature dimension,
        and E is the embedding dimension
        - key: Tensor of the shape BxTxFxE
        - value: Tensor of the shape BxTxFxE
        - mask: accepted for interface compatibility; currently ignored.

        Returns:
        - Tensor of the shape BxTxFxE
        """
        b, t, f, e = query.shape
        h = self.num_heads
        d = self.dim_per_head

        # Project and split into heads: (b, t, f, e) -> (b, h, t, f, d)
        q = self.linear_query(query).reshape(b, t, f, h, d).permute(0, 3, 1, 2, 4)
        k = self.linear_key(key).reshape(b, t, f, h, d).permute(0, 3, 1, 2, 4)
        v = self.linear_value(value).reshape(b, t, f, h, d).permute(0, 3, 1, 2, 4)

        # scores[b, h, t, f, x, y] = <k[t, f], q[x, y]> / sqrt(d):
        # key position (t, f) first, query position (x, y) last.
        scores = torch.einsum("bhtfd,bhxyd->bhtfxy", k, q) / math.sqrt(d)

        # Softmax2d normalises over dim -3, i.e. over all t*f key positions,
        # separately for every query position (x, y).
        scores = self.softmax(scores.reshape(b * h, t * f, t, f))
        scores = scores.reshape(b, h, t, f, t, f)

        # Weighted sum of the values over key positions (t, f) for every query
        # position (x, y). The previous "->bhtfd" contraction summed over the
        # query axes instead, which only rescaled each key's own value.
        out = torch.einsum("bhtfxy,bhtfd->bhxyd", scores, v)
        out = out.permute(0, 2, 3, 1, 4).reshape(b, t, f, e)
        return self.output_linear(out)
    
class moded_TransformerEncoderCell(nn.Module):
    """
    One post-norm Transformer encoder layer built on ``moded_TimesSeriesAttention``.

    Both residual sub-blocks (self-attention and feed-forward) are normalised by
    the same ``layer_norm`` module, so its affine parameters are shared.
    """

    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout: float):
        """
        Inputs:
        - embed_dim: size of the embedding axis of every (time, feature) element
        - num_heads: number of heads in the multi-head attention module
        - ff_dim: hidden width of the position-wise feed-forward network
        - dropout: dropout probability applied to the attention output, inside the
          feed-forward network, and to its output
        """
        super(moded_TransformerEncoderCell, self).__init__()

        # Sub-modules are created in a fixed order so parameter initialisation
        # consumes the random number generator in the same sequence.
        self.time_series_attention = moded_TimesSeriesAttention(embed_dim, num_heads)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.activation = nn.ReLU()
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Inputs:
        - x: Tensor of the shape BxTxFxE (batch, time, feature, embedding)
        - mask: forwarded to the attention module

        Returns:
        - Tensor of the same shape as x
        """
        # Self-attention sub-block with residual connection.
        attended = self.time_series_attention(x, x, x, mask)
        hidden_state = self.layer_norm(x + self.dropout1(attended))

        # Feed-forward sub-block with residual connection.
        expanded = self.activation(self.linear1(hidden_state))
        projected = self.linear2(self.dropout(expanded))
        return self.layer_norm(hidden_state + self.dropout2(projected))
class moded_TransformerEncoder(nn.Module):
    """
    A stack of ``num_cells`` ``moded_TransformerEncoderCell`` layers applied in sequence.

    ``norm`` and ``layer_norm`` are created but not used by ``forward``; removing
    ``layer_norm`` would change the module's state_dict keys.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        num_cells: int,
        dropout: float = 0.1,
    ):
        """
        Inputs:
        - embed_dim: size of the embedding axis of every (time, feature) element
        - num_heads: number of attention heads per cell
        - ff_dim: hidden width of each cell's feed-forward network
        - num_cells: number of stacked encoder cells
        - dropout: dropout probability used inside every cell
        """
        super(moded_TransformerEncoder, self).__init__()

        self.norm = None

        cells = [
            moded_TransformerEncoderCell(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_cells)
        ]
        self.encoder_modules = nn.ModuleList(cells)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Inputs:
        - x: Tensor of the shape BxTxFxE (batch, time, feature, embedding)
        - mask: passed unchanged to every cell

        Return:
        - Tensor of the shape BxTxFxE (x itself when there are no cells)
        """
        hidden = x
        for cell in self.encoder_modules:
            hidden = cell(hidden, mask)
        return hidden

## RSA
class TimesSeriesAttention(nn.Module):
    """Relational self-attention (RSA-style) layer over a 4-D (N, T, H, C) tensor.

    The two middle axes are treated as the spatial axes of an image (``T`` and
    ``H`` in ``forward``; presumably time and feature) and the last axis as
    channels. A query is contracted with a kernel generated from the projected
    keys by 2-D convolutions, and the values are transformed by 2-D
    convolutions before the two are combined. The naming (``VplusR``, ``H1``,
    ``H2``, ``G``, ``P1``, ``I``) appears to follow the Relational
    Self-Attention paper — TODO confirm.

    Args:
        d_in: size of the last (channel) axis of the input.
        d_out: used to derive ``dv = d_out // nh`` when ``dv`` is 0.
        nh: number of heads.
        dk, dv, dd: query/key, value and kernel sizes; 0 means "derive":
            ``dv = d_out // nh``, then ``dk = dv``, then ``dd = dk``.
        kernel_size: 2-D convolution kernel over the (T, H) plane; padding is
            ``k // 2`` on each axis.
        stride: only tested via ``max(stride) > 1``, which adds an average pool
            over the H axis of the output.
        kernel_type: 'V', 'R' or 'VplusR' — how the attention kernel is built.
        feat_type: 'V', 'R' or 'VplusR' — how the value features are built.
    """

    def __init__(
        self,
        d_in,
        d_out,
        nh=8,
        dk=0,
        dv=0,
        dd=0,
        kernel_size=(3, 7),
        stride=(1, 1, 1),
        kernel_type="VplusR",  # ['V', 'R', 'VplusR']
        feat_type="VplusR",  # ['V', 'R', 'VplusR']
    ):
        super(TimesSeriesAttention, self).__init__()

        self.d_in = d_in
        self.d_out = d_out
        self.nh = nh
        # 0 means "derive": dv from d_out and nh, dk from dv, dd from dk.
        self.dv = dv = d_out // nh if dv == 0 else dv
        self.dk = dk = dv if dk == 0 else dk
        self.dd = dd = dk if dd == 0 else dd

        self.kernel_size = kernel_size
        self.stride = stride
        self.kernel_type = kernel_type
        self.feat_type = feat_type

        assert self.kernel_type in [
            "V",
            "R",
            "VplusR",
        ], "Not implemented involution type: {}".format(self.kernel_type)
        assert self.feat_type in [
            "V",
            "R",
            "VplusR",
        ], "Not implemented feature type: {}".format(self.feat_type)

        # print("d_in: {}, d_out: {}, nh: {}, dk: {}, dv: {}, dd:{}, kernel_size: {}, kernel_type: {}, feat_type: {}"
        #       .format(d_in, d_out, nh, dk, dv, self.dd, kernel_size, kernel_type, feat_type))

        # Number of taps of the 2-D kernel and the "same"-style padding per axis.
        self.ksize = ksize = kernel_size[0] * kernel_size[1]
        self.pad = pad = tuple(k // 2 for k in kernel_size)

        # hidden dimension: queries (nh * dk) + values (dv), plus keys (dk)
        # unless the kernel is the plain query ('V').
        d_hid = nh * dk + dv if self.kernel_type == "V" else nh * dk + dk + dv

        # Linear projection (applied on the channel axis of an (N, T, H, C) tensor)
        # self.projection = nn.Conv2d(d_in, d_hid, 1, bias=False)
        self.projection_linear = nn.Sequential(
            nn.Linear(d_in, d_hid, bias=False),
            nn.SiLU(inplace=True),
            nn.Linear(d_hid, d_hid, bias=False),
        )

        # Intervolution Kernel
        if self.kernel_type == "V":
            self.H2 = nn.Conv2d(1, dd, kernel_size,
                                padding=self.pad, bias=False)
        elif self.kernel_type == "R":
            # Depthwise (groups=dk) conv: each key channel yields dd kernel channels.
            self.H1 = nn.Conv2d(
                dk, dk * dd, kernel_size, padding=self.pad, groups=dk, bias=False
            )
            self.H2 = nn.Conv2d(1, dd, kernel_size,
                                padding=self.pad, bias=False)
        elif self.kernel_type == "VplusR":
            # Learned position-independent kernel added to the relational one;
            # normal init scaled by sqrt(1 / (ksize * dd)).
            self.P1 = nn.Parameter(
                torch.randn(dk, dd).unsqueeze(0) * np.sqrt(1 / (ksize * dd)),
                requires_grad=True,
            )
            self.H1 = nn.Conv2d(
                dk, dk * dd, kernel_size, padding=self.pad, groups=dk, bias=False
            )
            self.H2 = nn.Conv2d(1, dd, kernel_size,
                                padding=self.pad, bias=False)
        else:
            raise NotImplementedError

        # Feature embedding layer
        if self.feat_type == "V":
            pass
        elif self.feat_type == "R":
            self.G = nn.Conv2d(1, dv, kernel_size,
                               padding=self.pad, bias=False)
        elif self.feat_type == "VplusR":
            self.G = nn.Conv2d(1, dv, kernel_size,
                               padding=self.pad, bias=False)
            # Identity (1, dk, dk) added to a (.., dv, dv) tensor in forward,
            # so this setting requires dk == dv.
            self.I = nn.Parameter(
                torch.eye(dk).unsqueeze(0), requires_grad=True)
        else:
            raise NotImplementedError

        # Downsampling layer (pools only along the last spatial axis, H)
        if max(self.stride) > 1:
            self.avgpool = nn.AvgPool2d(
                kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)
            )

    def L2norm(self, x, d=1):
        """Divide ``x`` by its L2 norm along dim ``d`` (1e-6 added before the sqrt)."""
        eps = 1e-6
        norm = x**2
        norm = norm.sum(dim=d, keepdim=True) + eps
        norm = norm ** (0.5)
        return x / norm

    def forward(self, x):
        """Apply relational self-attention.

        Args:
            x: 4-D tensor whose last axis has size ``d_in``, indexed below as
                (N, T, H, C) — presumably (batch, time, feature, embedding).

        Returns:
            Tensor laid out as (N, T, H', nh * dv); H' == H unless
            ``max(stride) > 1`` enables the average pool over the H axis.
        """

        # print(x.shape)
        # Permute to channels-first only to read N, C, T, H, then permute back
        # so x is again (N, T, H, C) for the channel-wise linear projection.
        x = x.permute(0, 3, 1, 2)
        N, C, T, H = x.shape

        x = x.permute(0, 2, 3, 1)

        """Linear projection"""
        # x_proj = self.projection(x)
        x_proj = self.projection_linear(x)
        # (N, T, H, d_hid) -> (N, d_hid, T, H) for the 2-D convolutions.
        x_proj = x_proj.permute(0, 3, 1, 2)
        # print(x_proj.shape)

        if self.kernel_type != "V":
            q, k, v = torch.split(
                x_proj, [self.nh * self.dk, self.dk, self.dv], dim=1)
        else:
            q, v = torch.split(x_proj, [self.nh * self.dk, self.dv], dim=1)

        """Normalization"""
        # Each head's query vector is L2-normalised over its dk channels.
        q = rearrange(q, "b (nh k) t h -> b nh k t h", k=self.dk)
        q = self.L2norm(q, d=2)
        q = rearrange(q, "b nh k t h -> (b t h) nh k")

        v = self.L2norm(v, d=1)

        if self.kernel_type != "V":
            k = self.L2norm(k, d=1)

        """
        q = (b t h) nh k
        k = b k t h
        v = b v t h
        """

        # Intervolution generation
        # Basic kernel
        if self.kernel_type == "V":
            # NOTE(review): used directly as (bth, nh, dk) and contracted with
            # the dd axis of `feature` below, so this path requires dk == dd.
            kernel = q
        # Relational kernel
        else:
            K_H1 = self.H1(k)
            K_H1 = rearrange(K_H1, "b (k d) t h-> (b t h) k d", k=self.dk)

            if self.kernel_type == "VplusR":
                K_H1 = K_H1 + self.P1

            # Contract the dk axis of q with that of K_H1.
            kernel = torch.einsum(
                "abc,abd->acd", q.transpose(1, 2), K_H1
            )  # (bth, nh, d)

        # feature generation
        # Appearance feature: each value channel is convolved as its own image.
        v = rearrange(v, "b (v 1) t h-> (b v) 1 t h")

        V = self.H2(v)  # (bv, d, t, h)
        feature = rearrange(V, "(b v) d t h -> (b t h) v d", v=self.dv)

        # Relational feature
        if self.feat_type in ["R", "VplusR"]:
            V_G = self.G(v)  # (bv, v2, t, h)
            V_G = rearrange(V_G, "(b v) v2 t h -> (b t h) v v2", v=self.dv)

            if self.feat_type == "VplusR":
                V_G = V_G + self.I

            feature = torch.einsum("abc,abd->acd", V_G,
                                   feature)  # (bth, v2, d)

        # kernel * feat
        # NOTE(review): "abc,adc->adb" yields (bth, v2, nh) — value index before
        # head index — not (bth, nh, v2). The rearrange below names the axes
        # "nh v" but only flattens them, so output channels are value-major.
        # Harmless for a learned downstream projection; the names are misleading.
        out = torch.einsum("abc,adc->adb", kernel, feature)  # (bth, nh, v2)

        out = rearrange(out, "(b t h) nh v -> b (nh v) t h", t=T, h=H)

        if max(self.stride) > 1:
            out = self.avgpool(out)

        # Back to channels-last: (N, T, H', nh * dv).
        out = out.permute(0, 2, 3, 1)

        return out
class TransformerEncoderCell(nn.Module):
    """
    One post-norm Transformer encoder layer built on the relational
    ``TimesSeriesAttention`` module.

    Both residual sub-blocks (attention and feed-forward) are normalised by the
    same ``layer_norm`` module, so its affine parameters are shared.
    """

    def __init__(
        self, embed_dim: int, num_heads: int, kernel_size, ff_dim: int, dropout: float
    ):
        """
        Inputs:
        - embed_dim: embedding dimension for each element in the time series data
        - num_heads: Number of attention heads in a multi-head attention module
        - kernel_size: 2-D convolution kernel passed to ``TimesSeriesAttention``
        - ff_dim: The hidden dimension for a feedforward network
        - dropout: Dropout ratio for the output of the multi-head attention and feedforward
          modules.
        """
        super(TransformerEncoderCell, self).__init__()

        # Construction order is unchanged so parameter initialisation consumes
        # the random number generator in the same sequence.
        self.time_series_attention = TimesSeriesAttention(
            embed_dim, embed_dim, nh=num_heads, kernel_size=kernel_size
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.activation = nn.ReLU()
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)

    def forward(self, data: torch.Tensor, embeddings=None, mask: torch.Tensor = None):
        """
        Inputs:
        - data: Tensor of the shape BxTxFxE, where B is the batch size, T is the time dimension,
          F is the feature dimension, and E is the embedding dimension
        - embeddings: unused. It now defaults to None so the cell can be called as
          ``cell(data)``. Note that ``TransformerEncoder.forward`` passes its
          ``mask`` positionally into this slot.
        - mask: unused; the RSA attention takes no mask

        Returns:
        - Tensor with the same shape as ``data``
        """
        # Self-attention sub-block with residual connection.
        attention2 = self.time_series_attention(data)
        attention = data + self.dropout1(attention2)
        attention = self.layer_norm(attention)

        # Feed-forward sub-block with residual connection.
        attention2 = self.linear2(
            self.dropout(self.activation(self.linear1(attention)))
        )
        attention = attention + self.dropout2(attention2)
        attention = self.layer_norm(attention)

        return attention
class TransformerEncoder(nn.Module):
    """
    A stack of ``num_cells`` RSA-based ``TransformerEncoderCell`` layers applied in sequence.

    ``norm`` and ``layer_norm`` are created but not used by ``forward``; removing
    ``layer_norm`` would change the module's state_dict keys.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        kernel_size,
        ff_dim: int,
        num_cells: int,
        dropout: float = 0.1,
    ):
        """
        Inputs:
        - embed_dim: size of the embedding axis of every (time, feature) element
        - num_heads: number of attention heads per cell
        - kernel_size: 2-D convolution kernel used by each cell's attention
        - ff_dim: hidden width of each cell's feed-forward network
        - num_cells: number of stacked encoder cells
        - dropout: dropout probability used inside every cell
        """
        super(TransformerEncoder, self).__init__()

        self.norm = None

        cells = [
            TransformerEncoderCell(embed_dim, num_heads, kernel_size, ff_dim, dropout)
            for _ in range(num_cells)
        ]
        self.encoder_modules = nn.ModuleList(cells)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Inputs:
        - x: Tensor of the shape BxTxFxE (batch, time, feature, embedding)
        - mask: handed to each cell as its second positional argument

        Return:
        - Tensor of the shape BxTxFxE (x itself when there are no cells)
        """
        hidden = x
        for cell in self.encoder_modules:
            hidden = cell(hidden, mask)
        return hidden
## CSDI transformer

def get_torch_trans(num_heads=8, num_cells=1, embed_dim=128, ff_dim=512, dropout=0.1):
    """Build a PyTorch ``nn.TransformerEncoder`` of ``num_cells`` GELU layers.

    ``batch_first`` is left at its default, so the returned encoder expects
    (sequence, batch, embedding) input.
    """
    layer_config = dict(
        d_model=embed_dim,
        nhead=num_heads,
        dim_feedforward=ff_dim,
        activation="gelu",
        dropout=dropout,
    )
    single_layer = nn.TransformerEncoderLayer(**layer_config)
    return nn.TransformerEncoder(single_layer, num_layers=num_cells)
## Embeddings


class ContinuousDiffusionEmbedding(nn.Module):
    """Sinusoidal embedding of a (possibly continuous) diffusion step, refined by an MLP.

    The step is divided by ``max_steps``, multiplied by ``embedding_dim // 2``
    log-spaced frequencies in ``[1, 1e4]``, turned into ``[sin | cos]`` features,
    passed through two SiLU-activated linear layers, and broadcast over the time
    and feature axes of ``data``.
    """

    def __init__(self, embedding_dim, projection_dim=None, max_steps=1000):
        super(ContinuousDiffusionEmbedding, self).__init__()
        if projection_dim is None:
            projection_dim = embedding_dim
        self.embedding_dim = embedding_dim
        self.projection1 = nn.Linear(embedding_dim, projection_dim)
        self.projection2 = nn.Linear(projection_dim, embedding_dim)
        # maximum steps expected (for normalization)
        self.max_steps = max_steps

    def forward(self, diffusion_step, data, device="cpu"):
        """Embed ``diffusion_step`` and expand it to ``data``'s leading axes.

        Args:
            diffusion_step: tensor of steps; either a single value (broadcast to
                the batch) or one value per batch element.
            data: tensor indexed as (b, t, f, ...) — only its first three sizes are used.
            device: device on which the sinusoid is computed.

        Returns:
            Tensor of shape (b, t, f, embedding_dim).
        """
        # Flatten so scalars, (1,), (b,) and (b, 1) step tensors are all handled.
        diffusion_step = diffusion_step.to(device).reshape(-1)
        if diffusion_step.numel() == 1:
            diffusion_step = diffusion_step.expand(data.shape[0])

        # Normalize diffusion step to [0, 1]
        t_normalized = diffusion_step.float() / self.max_steps  # (batch,)

        # Half of the columns for sine, half for cosine.
        half = self.embedding_dim // 2
        # max(..., 1) avoids a 0/0 (NaN frequencies) when half == 1.
        exponents = (
            torch.arange(half, device=device, dtype=torch.float)
            / max(half - 1, 1)
            * 4.0
        )
        frequencies = 10.0 ** exponents

        angles = t_normalized.unsqueeze(1) * frequencies.unsqueeze(0)  # (batch, half)
        embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

        # For odd embedding_dim, sin/cos give one column too few for projection1.
        missing = self.embedding_dim - embedding.shape[1]
        if missing > 0:
            embedding = F.pad(embedding, (0, missing))

        embedding = F.silu(self.projection1(embedding))
        embedding = F.silu(self.projection2(embedding))

        # (b, e) -> (b, 1, 1, e) -> (b, t, f, e)
        embedding = embedding.unsqueeze(1).unsqueeze(1)
        return embedding.expand(-1, data.shape[1], data.shape[2], -1)

# class DiffusionEmbedding(nn.Module):
#     def __init__(self, num_steps, embedding_dim, projection_dim=None):
#         super(DiffusionEmbedding, self).__init__()
#         if projection_dim is None:
#             projection_dim = embedding_dim
#         self.register_buffer(
#             "embedding",
#             self._build_embedding(num_steps, embedding_dim / 2),
#             persistent=False,
#         )
#         self.projection1 = nn.Linear(embedding_dim, projection_dim)
#         self.projection2 = nn.Linear(projection_dim, embedding_dim)

#     def forward(self, diffusion_step, data, device="cpu"):
#         x = self.embedding[diffusion_step]
#         x = self.projection1(x)
#         x = F.silu(x)
#         x = self.projection2(x)
#         x = F.silu(x)
#         x = torch.zeros(data.shape).to(device) + x.unsqueeze(1).unsqueeze(1)
#         return x

#     def _build_embedding(self, num_steps, dim=64):
#         steps = torch.arange(num_steps).unsqueeze(1)  # (T,1)
#         frequencies = 10.0 ** (torch.arange(dim) / (dim - 1) * 4.0).unsqueeze(
#             0
#         )  # (1,dim)
#         table = steps * frequencies  # (T,dim)
#         table = torch.cat(
#             [torch.sin(table), torch.cos(table)], dim=1)  # (T,dim*2)
#         return table


class TimeEmbedding(nn.Module):
    """Sinusoidal encoding of the time index, refined by a learnable two-layer MLP.

    Column ``2i`` of the encoding is ``sin(pos / max_len ** (2i / e))`` and
    column ``2i + 1`` is the matching cosine. The result depends only on the
    time position, so it is identical across the batch and feature axes.
    """

    def __init__(self, embedding_dim, max_len=10000.0):
        super(TimeEmbedding, self).__init__()
        self.max_len = max_len
        self.learnable = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.SiLU(),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, data, device="cpu"):
        """Return the time embedding broadcast to ``data``'s shape.

        Args:
            data: tensor of shape (b, t, f, e); only its shape is used.
            device: device on which the encoding is built.

        Returns:
            Contiguous tensor of shape (b, t, f, e).
        """
        b, l, f, e = data.shape

        positions = torch.arange(l, device=device, dtype=torch.float).unsqueeze(1)  # (l, 1)
        div_term = self.max_len ** (torch.arange(0, e, 2, device=device) / e)  # (ceil(e/2),)

        pe = torch.zeros(l, e, device=device)
        pe[:, 0::2] = torch.sin(positions / div_term)
        # For odd e there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(positions / div_term[: e // 2])

        # Run the MLP once per time position instead of once per (b, f) copy.
        encoded = self.learnable(pe)  # (l, e)
        return encoded.unsqueeze(0).unsqueeze(2).expand(b, l, f, e).contiguous()


class FeatureEmbedding(nn.Module):
    """Sinusoidal encoding of the feature index, refined by a learnable two-layer MLP.

    Column ``2i`` of the encoding is ``sin(idx / max_len ** (2i / e))`` and
    column ``2i + 1`` is the matching cosine. The result depends only on the
    feature position, so it is identical across the batch and time axes.
    """

    def __init__(self, embedding_dim, max_len=10000.0):
        super(FeatureEmbedding, self).__init__()
        self.max_len = max_len
        self.learnable = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.SiLU(),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, data, device="cpu"):
        """Return the feature embedding broadcast to ``data``'s shape.

        Args:
            data: tensor of shape (b, t, f, e); only its shape is used.
            device: device on which the encoding is built.

        Returns:
            Contiguous tensor of shape (b, t, f, e).
        """
        b, l, f, e = data.shape

        positions = torch.arange(f, device=device, dtype=torch.float).unsqueeze(1)  # (f, 1)
        div_term = self.max_len ** (torch.arange(0, e, 2, device=device) / e)  # (ceil(e/2),)

        pe = torch.zeros(f, e, device=device)
        pe[:, 0::2] = torch.sin(positions / div_term)
        # For odd e there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(positions / div_term[: e // 2])

        # Run the MLP once per feature position instead of once per (b, t) copy.
        encoded = self.learnable(pe)  # (f, e)
        return encoded.unsqueeze(0).unsqueeze(0).expand(b, l, f, e).contiguous()
# Residual block

def Conv1d_with_init(in_channels, out_channels, kernel_size):
    """Return an ``nn.Conv1d`` whose weight is re-initialised with Kaiming-normal init.

    The bias keeps PyTorch's default initialisation.
    """
    conv = nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
    )
    nn.init.kaiming_normal_(conv.weight)
    return conv
class ResidualBlock(nn.Module):
    def __init__(
        self,
        num_heads=8,
        num_cells=1,
        kernel_size=(3, 7),
        embed_dim=128,
        ff_dim=512,
        dropout=0.1,
        method="rsa",
    ):
        super().__init__()

        self.method = method

        self.embedding_add = nn.Sequential(
            nn.Linear(embed_dim * 4, embed_dim * 4),
            nn.SiLU(),
            nn.Linear(embed_dim * 4, embed_dim * 2),
            nn.SiLU(),
            nn.Linear(embed_dim * 2, embed_dim),
        )

        self.layer_norm = nn.LayerNorm(embed_dim)

        self.mid_projection = Conv1d_with_init(embed_dim, 2 * embed_dim, 1)
        # nn.Linear(embed_dim, embed_dim*2)
        self.output_projection = Conv1d_with_init(embed_dim, 2 * embed_dim, 1)
        # self.output_projection = nn.Linear(embed_dim, embed_dim*2)

        if method == "rsa":
            self.feature_and_time_transformer = TransformerEncoder(
                embed_dim=embed_dim,
                num_heads=num_heads,
                kernel_size=kernel_size,
                ff_dim=ff_dim,
                num_cells=num_cells,
                dropout=dropout,
            )
            self.linear_time_and_feature = nn.Linear(embed_dim, embed_dim)

        elif method == "csdi":
            self.time_layer = get_torch_trans(
                num_heads=num_heads,
                num_cells=num_cells,
                embed_dim=embed_dim,
                ff_dim=ff_dim,
                dropout=dropout,
            )
            self.feature_layer = get_torch_trans(
                num_heads=num_heads,
                num_cells=num_cells,
                embed_dim=embed_dim,
                ff_dim=ff_dim,
                dropout=dropout,
            )
            self.linear_time = nn.Linear(embed_dim, embed_dim)
            self.linear_feature = nn.Linear(embed_dim, embed_dim)

        elif method == "csdi_moded_transformer":
            self.time_layer = get_torch_trans(
                num_heads=num_heads,
                num_cells=num_cells,
                embed_dim=embed_dim,
                ff_dim=ff_dim,
                dropout=dropout,
            )
            self.feature_layer = get_torch_trans(
                num_heads=num_heads,
                num_cells=num_cells,
                embed_dim=embed_dim,
                ff_dim=ff_dim,
                dropout=dropout,
            )
            self.linear_time = nn.Linear(embed_dim, embed_dim)
            self.linear_feature = nn.Linear(embed_dim, embed_dim)
            self.feature_and_time_transformer = moded_TransformerEncoder(
                embed_dim=embed_dim,
                num_heads=num_heads,
                ff_dim=ff_dim,
                num_cells=num_cells,
                dropout=dropout,
            )
            self.linear_time_and_feature = nn.Linear(embed_dim, embed_dim)

        elif method == "rsa_csdi":
            self.time_layer = get_torch_trans(
                num_heads=num_heads,
                num_cells=num_cells,
                embed_dim=embed_dim,
                ff_dim=ff_dim,
                dropout=dropout,
            )
            self.feature_layer = get_torch_trans(
                num_heads=num_heads,
                num_cells=num_cells,
                embed_dim=embed_dim,
                ff_dim=ff_dim,
                dropout=dropout,
            )
            self.linear_time = nn.Linear(embed_dim, embed_dim)
            self.linear_feature = nn.Linear(embed_dim, embed_dim)
            self.feature_and_time_transformer = TransformerEncoder(
                embed_dim=embed_dim,
                num_heads=num_heads,
                kernel_size=kernel_size,
                ff_dim=ff_dim,
                num_cells=num_cells,
                dropout=dropout,
            )
            self.linear_time_and_feature = nn.Linear(embed_dim, embed_dim)

        elif method == "rsa_moded_transformer":
            self.feature_and_time_transformer = TransformerEncoder(
                embed_dim=embed_dim,
                num_heads=num_heads,
                kernel_size=kernel_size,
                ff_dim=ff_dim,
                num_cells=num_cells,
                dropout=dropout,
            )
            self.linear_time_and_feature = nn.Linear(embed_dim, embed_dim)
            self.moded_feature_and_time_transformer = moded_TransformerEncoder(
                embed_dim=embed_dim,
                num_heads=num_heads,
                ff_dim=ff_dim,
                num_cells=num_cells,
                dropout=dropout,
            )
            self.moded_linear_time_and_feature = nn.Linear(
                embed_dim, embed_dim)

        elif method == "moded_transformer_alone":
            self.moded_feature_and_time_transformer = moded_TransformerEncoder(
                embed_dim=embed_dim,
                num_heads=num_heads,
                ff_dim=ff_dim,
                num_cells=num_cells,
                dropout=dropout,
            )
            self.moded_linear_time_and_feature = nn.Linear(
                embed_dim, embed_dim)

        elif method == "simple_neural_network":
            self.linear1 = nn.Linear(embed_dim, embed_dim)
            self.linear2 = nn.Linear(embed_dim, ff_dim)
            self.linear3 = nn.Linear(ff_dim, ff_dim)
            self.linear4 = nn.Linear(ff_dim, embed_dim)

        else:
            raise NotImplementedError

        logging.info("Initializing ResidualBlock with method: %s", method)

    def forward_time(self, y, base_shape):
        """Apply ``self.time_layer`` along the time axis.

        ``y`` is laid out as ``base_shape == (b, t, f, e)``. Every feature of
        every batch element becomes its own length-``t`` sequence, handed to
        ``self.time_layer`` in sequence-first order ``(t, b * f, e)``; the
        result is restored to ``(b, t, f, e)``.
        """
        batch, steps, feats, emb = base_shape
        # (b, t, f, e) -> (b, f, t, e) -> (b*f, t, e) -> (t, b*f, e)
        sequences = y.transpose(1, 2).reshape(batch * feats, steps, emb)
        mixed = self.time_layer(sequences.transpose(0, 1)).transpose(0, 1)
        # (b*f, t, e) -> (b, f, t, e) -> (b, t, f, e)
        return mixed.reshape(batch, feats, steps, emb).transpose(1, 2)

    def forward_feature(self, y, base_shape):
        """Apply ``self.feature_layer`` across the feature axis.

        ``y`` is laid out as ``base_shape == (b, t, f, e)``. Each (batch,
        time) position becomes a length-``f`` sequence, handed to
        ``self.feature_layer`` in sequence-first order ``(f, b * t, e)``;
        the result is restored to ``(b, t, f, e)``.
        """
        batch, steps, feats, emb = base_shape
        tokens = y.reshape(batch * steps, feats, emb).transpose(0, 1)
        mixed = self.feature_layer(tokens).transpose(0, 1)
        return mixed.reshape(batch, steps, feats, emb)

    def forward(self, noised_data, diffusion_emb, time_emb, feature_emb):
        """Run one residual block.

        The data and its three embeddings are stacked on a new trailing axis,
        flattened into the channel axis and mixed by ``self.embedding_add``.
        The result goes through the layers selected by ``self.method``, is
        combined with that mixed input by a ``1/sqrt(2)``-scaled residual and
        layer-normed. It is then projected by ``self.mid_projection`` and
        split in half along the channel axis into a gate and a filter, giving
        ``sigmoid(gate) * tanh(filter)``. ``self.output_projection`` output is
        split in half again into a residual and a skip branch.

        Args:
            noised_data: tensor unpacked as ``(b, t, f, e)``.
            diffusion_emb, time_emb, feature_emb: tensors stacked with
                ``noised_data`` on a new last axis (presumably the same
                shape as ``noised_data`` -- TODO confirm).

        Returns:
            ``((noised_data + residual) / sqrt(2), skip)``, with both
            ``residual`` and ``skip`` reshaped to ``noised_data.shape``.
        """
        # DEBUG rather than INFO: this runs num_residual_layers times per
        # model call, i.e. thousands of times per sampled batch.
        logging.debug("ResidualBlock forward started")

        b, t, f, e = noised_data.shape
        base_shape = noised_data.shape

        # Stack on a new trailing axis, then fold it into the channel axis.
        y = torch.stack((noised_data, diffusion_emb,
                        time_emb, feature_emb), dim=-1)
        y = y.reshape(b, t, f, -1)
        y = self.embedding_add(y)
        y_resid = y

        if self.method == "rsa":
            y = self.feature_and_time_transformer(y)
            y = y.squeeze(-1)
            y = self.linear_time_and_feature(y)

        elif self.method == "csdi":
            y = self.forward_time(y, base_shape)
            y = self.linear_time(y)
            y = (y + y_resid) / math.sqrt(2.0)
            y = self.layer_norm(y)
            y = self.forward_feature(y, base_shape)
            y = self.linear_feature(y)

        elif self.method == "csdi_moded_transformer":
            y = self.forward_time(y, base_shape)
            y = self.linear_time(y)
            y = (y + y_resid) / math.sqrt(2.0)
            y = self.layer_norm(y)
            y = self.forward_feature(y, base_shape)
            y = self.linear_feature(y)
            y = (y + y_resid) / math.sqrt(2.0)
            y_resid = y
            y = self.layer_norm(y)
            # NOTE(review): unlike "rsa_csdi" there is no squeeze(-1) after
            # the transformer here -- confirm this is intended.
            y = self.feature_and_time_transformer(y)
            y = self.linear_time_and_feature(y)

        elif self.method == "rsa_csdi":
            y = self.forward_time(y, base_shape)
            y = self.linear_time(y)
            y = (y + y_resid) / math.sqrt(2.0)
            y = self.layer_norm(y)
            y = self.forward_feature(y, base_shape)
            y = self.linear_feature(y)
            y = (y + y_resid) / math.sqrt(2.0)
            y_resid = y
            y = self.layer_norm(y)
            y = self.feature_and_time_transformer(y)
            y = y.squeeze(-1)
            y = self.linear_time_and_feature(y)

        elif self.method == "rsa_moded_transformer":
            y = self.feature_and_time_transformer(y)
            y = y.squeeze(-1)
            y = self.linear_time_and_feature(y)
            y = (y + y_resid) / math.sqrt(2.0)
            y = self.layer_norm(y)
            y = self.moded_feature_and_time_transformer(y)
            y = self.moded_linear_time_and_feature(y)

        elif self.method == "moded_transformer_alone":
            y = self.moded_feature_and_time_transformer(y)
            y = self.moded_linear_time_and_feature(y)

        elif self.method == "simple_neural_network":
            # Position-wise MLP over all (time, feature) tokens; linear3 is
            # constructed but currently unused.
            y = y.reshape(b, t * f, e)
            y = self.linear1(y)
            y = F.silu(y)
            y = self.linear2(y)
            y = F.silu(y)
            y = self.linear4(y)
            y = y.reshape(b, t, f, e)

        y = (y + y_resid) / math.sqrt(2.0)
        y = self.layer_norm(y)
        # (b, t, f, e) -> (b, e, t*f): channel-first layout for projections.
        y = y.permute(0, 3, 1, 2).reshape(b, e, t * f)
        y = self.mid_projection(y)

        # Gated activation; `filt` avoids shadowing the builtin `filter`.
        gate, filt = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filt)
        y = self.output_projection(y)

        residual, skip = torch.chunk(y, 2, dim=1)
        # (b, e, t*f) -> (b, t*f, e) -> (b, t, f, e)
        residual = residual.permute(0, 2, 1)
        skip = skip.permute(0, 2, 1)
        residual = residual.reshape(base_shape)
        skip = skip.reshape(base_shape)

        logging.debug("ResidualBlock forward completed")

        return (noised_data + residual) / math.sqrt(2.0), skip
class ModelLoop(nn.Module):
    """Stack of ``ResidualBlock`` layers mapping noised values to predicted noise.

    Inputs of shape ``(b, t, f, 1)`` are lifted to ``embed_dim`` channels by
    ``Conv1d_with_init`` layers. After every residual layer, the hidden
    state is projected to one channel by ``self.output``, paired with the
    raw noised input and re-embedded by ``self.x_embedding``. The skip
    outputs of all layers are summed, scaled by
    ``1 / sqrt(num_residual_layers)`` and projected to one channel by
    ``self.output_final``.
    """

    def __init__(
        self,
        embed_dim=128,
        max_steps=1000,
        num_heads=8,
        kernel_size=(3, 7),
        num_cells=1,
        num_residual_layers=4,
        ff_dim=512,
        dropout=0.1,
        method="rsa",
        device="cpu",
    ):
        """Build the embeddings, residual layers and output projections.

        Args:
            embed_dim: channel width of all hidden representations.
            max_steps: forwarded to ``ContinuousDiffusionEmbedding``.
            num_heads, kernel_size, num_cells, ff_dim, dropout, method:
                forwarded to every ``ResidualBlock``.
            num_residual_layers: number of ``ResidualBlock`` layers.
            device: passed as ``device=`` to the embedding modules in
                ``forward``.
        """
        super().__init__()

        self.device = device
        self.emb_dim = embed_dim

        # One input channel (the noised value) -> emb_dim channels.
        self.data_embedding_linear = Conv1d_with_init(1, self.emb_dim, 1)
        # Two input channels (layer output, noised value) -> emb_dim.
        self.x_embedding = Conv1d_with_init(2, self.emb_dim, 1)

        self.output = Conv1d_with_init(self.emb_dim, 1, 1)
        self.output_final = Conv1d_with_init(self.emb_dim, 1, 1)

        self.diffusion_embedding = ContinuousDiffusionEmbedding(
            embedding_dim=embed_dim,
            max_steps=max_steps
        )
        self.time_embedding = TimeEmbedding(embed_dim)
        self.feature_embedding = FeatureEmbedding(embed_dim)

        self.residual_layers = nn.ModuleList(
            ResidualBlock(
                num_heads=num_heads,
                num_cells=num_cells,
                kernel_size=kernel_size,
                embed_dim=embed_dim,
                ff_dim=ff_dim,
                dropout=dropout,
                method=method,
            )
            for _ in range(num_residual_layers)
        )

        logging.info("Initializing ModelLoop with embed_dim: %s, method: %s", embed_dim, method)

    def forward(self, noised_data, noise_mask, diffusion_t):
        """Predict the noise component of ``noised_data``.

        Args:
            noised_data: tensor of shape ``(b, t, f, 1)``; the reshape to
                ``(b, 1, t * f)`` below requires the last axis to be 1.
            noise_mask: accepted for interface compatibility; unused.
            diffusion_t: diffusion step(s) fed to the diffusion embedding.

        Returns:
            Tensor of shape ``(b, t, f)``.
        """
        # DEBUG rather than INFO: forward runs once per diffusion step while
        # sampling, so INFO-level logging here floods the logs.
        logging.debug("ModelLoop forward: noised_data shape: %s", noised_data.shape)

        b, t, f, _ = noised_data.shape
        # Loop-invariant: the raw noised values re-injected after each layer.
        noised_values = noised_data.squeeze(-1)

        noised_data_reshaped = noised_data.permute(
            0, 3, 1, 2).reshape(b, 1, t * f)
        noised_data_embedded = (
            self.data_embedding_linear(noised_data_reshaped)
            .permute(0, 2, 1)
            .reshape(b, t, f, self.emb_dim)
        )
        diffusion_embedding = self.diffusion_embedding(
            diffusion_t, noised_data_embedded, device=self.device
        )
        time_embedding = self.time_embedding(
            noised_data_embedded, device=self.device)
        feature_embedding = self.feature_embedding(
            noised_data_embedded, device=self.device
        )

        x = noised_data_embedded
        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(
                x, diffusion_embedding, time_embedding, feature_embedding
            )
            skip.append(skip_connection)
            # Collapse to one channel: (b, t, f, e) -> (b, e, t*f) -> (b, t, f).
            x = x.permute(0, 3, 1, 2).reshape(b, self.emb_dim, t * f)
            x = self.output(x).permute(0, 2, 1).reshape(b, t, f)
            # Pair with the raw noised values and re-embed to emb_dim channels.
            x = torch.stack((x, noised_values), dim=-1)
            x = x.permute(0, 3, 1, 2).reshape(b, 2, t * f)
            x = self.x_embedding(x).permute(
                0, 2, 1).reshape(b, t, f, self.emb_dim)

        # Aggregate skip connections with 1/sqrt(L) scaling.
        x = torch.sum(torch.stack(skip, dim=-1), dim=-1) / math.sqrt(
            len(self.residual_layers)
        )
        x = x.permute(0, 3, 1, 2).reshape(b, self.emb_dim, t * f)
        x = self.output_final(x).permute(
            0, 2, 1).reshape(b, t, f, 1).squeeze(-1)

        logging.debug("ModelLoop forward: output shape: %s", x.shape)

        return x
# Beta Schedules

def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, max_beta=0.999):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.

    :param schedule_name: ``"linear"``, ``"cosine"`` or ``"quadratic"``.
    :param num_diffusion_timesteps: number of betas to produce; must be > 0.
    :param max_beta: upper bound applied to every beta. A beta >= 1 makes
                     ``1 - beta`` non-positive, which breaks the cumulative
                     product of alphas and the square roots taken while
                     sampling. The linear formula exceeds 1 for n <= 2 and
                     the quadratic one for n <= 25; larger step counts are
                     unaffected by the bound.
    :return: 1-D float tensor of length ``num_diffusion_timesteps``.
    :raises ValueError: if ``num_diffusion_timesteps`` is not positive.
    :raises NotImplementedError: for an unknown ``schedule_name``.
    """
    if num_diffusion_timesteps <= 0:
        raise ValueError(
            f"num_diffusion_timesteps must be positive, got {num_diffusion_timesteps}")

    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        if num_diffusion_timesteps < 100:
            scale = 100 / num_diffusion_timesteps
        else:
            scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        betas = torch.linspace(beta_start, beta_end, num_diffusion_timesteps)

    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
            max_beta=max_beta,
        )

    elif schedule_name == "quadratic":
        # Linear in sqrt(beta), i.e. quadratic in beta.
        scale = 50 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.5
        betas = (
            torch.linspace(beta_start**0.5, beta_end**0.5,
                           num_diffusion_timesteps) ** 2
        )

    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # Keep 1 - beta strictly positive (see ``max_beta`` above).
    return betas.clamp(max=max_beta)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretise a continuous alpha_bar(t), t in [0, 1], into a beta schedule.

    alpha_bar gives the cumulative product of (1 - beta) up to time t. Step
    i covers [i / n, (i + 1) / n] and receives
    beta_i = 1 - alpha_bar((i + 1) / n) / alpha_bar(i / n), capped at
    max_beta.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: callable mapping t in [0, 1] to the cumulative product
                      of (1 - beta) up to t.
    :param max_beta: cap on each beta; values below 1 avoid singularities.
    :return: 1-D tensor of betas built with ``torch.tensor``.
    """
    n = num_diffusion_timesteps
    ratios = (alpha_bar((i + 1) / n) / alpha_bar(i / n) for i in range(n))
    return torch.tensor([min(1 - ratio, max_beta) for ratio in ratios])
# Imputer

class diffusion_imputation(nn.Module):
    def __init__(
        self,
        emb_dim,
        excluded_features=None,
        strategy="random",
        num_residual_layers=4,
        features_to_impute=None,
        features_to_impute_completely=None,
        features_to_impute_after_time=None,
        last_n_time=1,
        missing_prp=0.1,
        diffusion_steps=1000,
        max_steps=1000,
        diffusion_beta_schedule="quadratic",
        num_heads=8,
        kernel_size=(3, 7),
        ff_dim=512,
        num_cells=2,
        dropout=0.1,
        method="rsa",
        device="cpu",
        sequence_length=None
    ):
        """Build the denoising network and the fixed diffusion schedule.

        Args:
            emb_dim: embedding width, passed to ``ModelLoop`` as ``embed_dim``.
            excluded_features: feature indices that ``get_mask`` always
                forces to 0 (stored as ``self.exclude_features``).
            strategy: masking strategy used by ``forward`` and
                ``eval_with_grad`` (see ``get_mask``).
            num_residual_layers, max_steps, num_heads, kernel_size, ff_dim,
                num_cells, dropout, method: forwarded to ``ModelLoop``.
            features_to_impute, features_to_impute_completely,
                features_to_impute_after_time, last_n_time, missing_prp,
                sequence_length: read by the masking strategies in
                ``get_mask``.
            diffusion_steps: length of the beta schedule.
            diffusion_beta_schedule: name passed to
                ``get_named_beta_schedule``.
            device: device for masks and noise. NOTE(review): it is replaced
                by ``"cuda"`` whenever CUDA is available, so an explicit
                ``device="cpu"`` is ignored on GPU machines.
        """
        super().__init__()

        self.device = device
        self.emb_dim = emb_dim
        self.strategy = strategy
        self.features_to_impute = features_to_impute
        self.exclude_features = excluded_features
        self.num_residual_layers = num_residual_layers
        self.features_to_impute_completely = features_to_impute_completely
        self.features_to_impute_after_time = features_to_impute_after_time
        self.last_n_time = last_n_time
        self.missing_prp = missing_prp
        self.diffusion_steps = diffusion_steps
        self.diffusion_beta_schedule = diffusion_beta_schedule
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.ff_dim = ff_dim
        self.num_cells = num_cells
        self.dropout = dropout
        self.method = method
        self.sequence_length = sequence_length
        self.max_steps = max_steps

        # Prefer the GPU whenever one is present (see the NOTE on `device`).
        if torch.cuda.is_available():
            self.device = "cuda"

        self.model_loop = ModelLoop(
            embed_dim=self.emb_dim,
            max_steps=max_steps,
            num_heads=num_heads,
            kernel_size=kernel_size,
            ff_dim=ff_dim,
            num_cells=num_cells,
            dropout=dropout,
            num_residual_layers=num_residual_layers,
            method=method,
            device=self.device,
        )

        # beta[t]; alpha_hat[t] = 1 - beta[t]; alpha[t] = prod_{s<=t} alpha_hat[s].
        self.beta = get_named_beta_schedule(
            diffusion_beta_schedule, diffusion_steps)
        self.alpha_hat = 1 - self.beta
        self.alpha = torch.cumprod(self.alpha_hat, dim=0)
        # Independent float32 copy of alpha, indexed per sample in `forward`.
        # torch.tensor(<tensor>) emits a UserWarning and is discouraged;
        # clone() makes the same copy without it.
        self.alpha_torch = self.alpha.detach().clone().float()

        # Xavier-uniform re-initialisation of every parameter with more than
        # one dimension; this replaces any initialisation the submodules set.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def get_mask(self, data, strategy="random"):

        b = data.shape[0]
        t = data.shape[1]
        f = data.shape[2]

        if strategy == "forecasting":
            forecasted_time = torch.randint(2, t, (b, 1, 1))
            mask = torch.zeros_like(data)
            for i in range(b):
                mask[i, forecasted_time[i]:, :] = 1

        if strategy == "forecasting_last_n_time":
            mask = torch.zeros_like(data)
            mask[:, -self.last_n_time, :] = 1

        if strategy == "death_prediction":
            mask = torch.zeros_like(data)
            # death is the last 7 columns of the data
            mask[:, :, -1] = 1

        if strategy == "random_features":
            selected_features = torch.randint(0, f, (b, 1, 1))
            mask = torch.zeros_like(data)
            mask[:, :, selected_features] = 1

        if strategy == "selected_features":
            mask = torch.zeros_like(data)
            mask[:, :, self.features_to_impute] = 1

        if strategy == "selected_features_after_time":
            selected_time = torch.randint(1, t, (b, 1, 1))
            mask = torch.zeros_like(data)
            mask[:, selected_time:, self.features_to_impute] = 1

        if strategy == "selected_features_last_n_time":
            mask = torch.zeros_like(data)
            mask[:, -self.last_n_time:, self.features_to_impute] = 1

        if strategy == "selected_features_last_n_sequence_length":
            assert self.sequence_length is not None
            mask = torch.zeros_like(data)
            for i in range(self.sequence_length.shape[0]):
                sequence_length = int(self.sequence_length[i])
                if i < mask.shape[0]:
                    mask[i, (sequence_length - self.last_n_time)
                             :sequence_length, self.features_to_impute] = 1

        if strategy == "selected_features_sequence_length":
            mask = torch.zeros_like(data)
            for i in range(self.sequence_length.shape[0]):
                sequence_length = int(self.sequence_length[i])
                if i < mask.shape[0]:
                    mask[i, :sequence_length, self.features_to_impute] = 1

        if strategy == "whole_sequence":
            mask = torch.ones_like(data)

        if strategy == "random":
            mask = torch.rand(size=(b, t, f))
            mask = mask < self.missing_prp
            mask = mask.float()

        if strategy == "selected_features_and_selected_features_after_time":
            mask = torch.zeros_like(data)
            mask[:, :, self.features_to_impute_completely] = 1
            mask[:, -self.last_n_time:, self.features_to_impute_after_time] = 1

        if self.exclude_features is not None:
            mask[:, :, self.exclude_features] = 0

        return mask

    def loss_func(self, predicted_noise, noise, noise_mask):
        # noise = torch.nan_to_num(noise, nan=0.0)
        # predicted_noise = torch.nan_to_num(predicted_noise, nan=0.0)
        residual = noise - predicted_noise
        num_obs = torch.sum(noise_mask)
        loss = (residual**2).sum() / num_obs
        return loss

    def weighted_loss_func(self, predicted_noise, noise, noise_mask, stabilized_weights):
        # Calculate the residuals
        residual = noise - predicted_noise

        # Get the sample weights
        # print(f"stabilized_weights shape: {stabilized_weights.shape}")
        sw = stabilized_weights.to(self.device)
        # clip sw at 5th and 95th percentile
        sw = torch.clamp(sw, 0.05, 0.95)
        sw = sw.unsqueeze(-1).repeat(1, 1, residual.shape[-1]) * noise_mask
        # print(residual.shape)
        # print(sw)
        # Apply the sample weights to the squared residuals
        weighted_squared_residuals = (residual**2) * sw

        # Sum the weighted squared residuals
        weighted_loss_sum = weighted_squared_residuals.sum()

        # Normalize the loss by the sum of the weights
        loss = weighted_loss_sum / sw.sum()

        # print(loss)
        return loss

    def explode_trajectories(self, data, projection_horizon):

        self.data = data
        # assert self.processed

        # logger.info(f'Exploding {self.subset_name} dataset before testing (multiple sequences)')

        outputs = self.data['outputs']
        prev_outputs = self.data['prev_outputs']
        sequence_lengths = self.data['sequence_lengths']
        # vitals = self.data['vitals']
        # next_vitals = self.data['next_vitals']
        active_entries = self.data['active_entries']
        current_treatments = self.data['current_treatments']
        previous_treatments = self.data['prev_treatments']
        static_features = self.data['static_features']
        # repeat static features t times (first dimension in outputs)
        static_features = static_features.unsqueeze(
            1).repeat(1, outputs.shape[1], 1)
        if 'stabilized_weights' in self.data:
            stabilized_weights = self.data['stabilized_weights']

        num_patients, max_seq_length, num_features = outputs.shape
        num_seq2seq_rows = num_patients * max_seq_length

        seq2seq_previous_treatments = np.zeros(
            (num_seq2seq_rows, max_seq_length, previous_treatments.shape[-1]))
        seq2seq_current_treatments = np.zeros(
            (num_seq2seq_rows, max_seq_length, current_treatments.shape[-1]))
        seq2seq_static_features = np.zeros(
            (num_seq2seq_rows, max_seq_length, static_features.shape[-1]))
        seq2seq_outputs = np.zeros(
            (num_seq2seq_rows, max_seq_length, outputs.shape[-1]))
        seq2seq_prev_outputs = np.zeros(
            (num_seq2seq_rows, max_seq_length, prev_outputs.shape[-1]))
        # seq2seq_vitals = np.zeros((num_seq2seq_rows, max_seq_length, vitals.shape[-1]))
        # seq2seq_next_vitals = np.zeros((num_seq2seq_rows, max_seq_length - 1, next_vitals.shape[-1]))
        seq2seq_active_entries = np.zeros(
            (num_seq2seq_rows, max_seq_length, active_entries.shape[-1]))
        seq2seq_sequence_lengths = np.zeros(num_seq2seq_rows)
        if 'stabilized_weights' in self.data:
            seq2seq_stabilized_weights = np.zeros(
                (num_seq2seq_rows, max_seq_length))

        total_seq2seq_rows = 0  # we use this to shorten any trajectories later

        for i in range(num_patients):
            sequence_length = int(sequence_lengths[i])

            for t in range(projection_horizon, sequence_length):  # shift outputs back by 1
                seq2seq_active_entries[total_seq2seq_rows, :(
                    t + 1), :] = active_entries[i, :(t + 1), :]
                if 'stabilized_weights' in self.data:
                    seq2seq_stabilized_weights[total_seq2seq_rows, :(
                        t + 1)] = stabilized_weights[i, :(t + 1)]
                seq2seq_previous_treatments[total_seq2seq_rows, :(
                    t + 1), :] = previous_treatments[i, :(t + 1), :]
                seq2seq_current_treatments[total_seq2seq_rows, :(
                    t + 1), :] = current_treatments[i, :(t + 1), :]
                seq2seq_outputs[total_seq2seq_rows, :(
                    t + 1), :] = outputs[i, :(t + 1), :]
                seq2seq_prev_outputs[total_seq2seq_rows, :(
                    t + 1), :] = prev_outputs[i, :(t + 1), :]
                seq2seq_static_features[total_seq2seq_rows, :(
                    t + 1), :] = static_features[i, :(t + 1), :]
                # seq2seq_vitals[total_seq2seq_rows, :(t + 1), :] = vitals[i, :(t + 1), :]
                # seq2seq_next_vitals[total_seq2seq_rows, :min(t + 1, sequence_length - 1), :] = \
                #     next_vitals[i, :min(t + 1, sequence_length - 1), :]
                seq2seq_sequence_lengths[total_seq2seq_rows] = t + 1
                # seq2seq_static_features[total_seq2seq_rows] = static_features[i]

                total_seq2seq_rows += 1

        # Filter everything shorter
        seq2seq_previous_treatments = seq2seq_previous_treatments[:total_seq2seq_rows, :, :]
        seq2seq_current_treatments = seq2seq_current_treatments[:total_seq2seq_rows, :, :]
        seq2seq_static_features = seq2seq_static_features[:total_seq2seq_rows, :]
        seq2seq_outputs = seq2seq_outputs[:total_seq2seq_rows, :, :]
        seq2seq_prev_outputs = seq2seq_prev_outputs[:total_seq2seq_rows, :, :]
        # seq2seq_vitals = seq2seq_vitals[:total_seq2seq_rows, :, :]
        # seq2seq_next_vitals = seq2seq_next_vitals[:total_seq2seqprocessed_rows, :, :]
        seq2seq_active_entries = seq2seq_active_entries[:total_seq2seq_rows, :, :]
        seq2seq_sequence_lengths = seq2seq_sequence_lengths[:total_seq2seq_rows]

        if 'stabilized_weights' in self.data:
            seq2seq_stabilized_weights = seq2seq_stabilized_weights[:total_seq2seq_rows]

        new_data = {
            'prev_treatments': seq2seq_previous_treatments,
            'current_treatments': seq2seq_current_treatments,
            'static_features': seq2seq_static_features,
            'prev_outputs': seq2seq_prev_outputs,
            'outputs': seq2seq_outputs,
            # 'vitals': seq2seq_vitals,
            # 'next_vitals': seq2seq_next_vitals,
            # 'unscaled_outputs': seq2seq_outputs * self.scaling_params['output_stds'] + self.scaling_params['output_means'],
            'sequence_lengths': seq2seq_sequence_lengths,
            'active_entries': seq2seq_active_entries,
        }
        if 'stabilized_weights' in self.data:
            new_data['stabilized_weights'] = seq2seq_stabilized_weights

        # self.data = new_data
        # self.exploded = True

        # data_shapes = {k: v.shape for k, v in self.data.items()}
        # logger.info(f'Shape of processed {self.subset_name} data: {data_shapes}')

        return new_data

    def get_exploded_dataset(self, dataset, min_length=1, only_active_entries=True, max_length=None):
        exploded_dataset = deepcopy(dataset)
        if max_length is None:
            max_length = max(exploded_dataset['sequence_lengths'][:])
        if not only_active_entries:
            exploded_dataset['active_entries'][:, :, :] = 1.0
            exploded_dataset['sequence_lengths'][:] = max_length
        # exploded_dataset.explode_trajectories(min_length)
        exploded_dataset = self.explode_trajectories(
            exploded_dataset, min_length)
        return exploded_dataset

    def forward(self, data):

        # data = self.get_exploded_dataset(data, 1, only_active_entries=True)
        # curr_treatments = data['current_treatments']
        # vitals_or_prev_outputs = []
        # # vitals_or_prev_outputs.append(data['vitals']) if self.has_vitals else None
        # # if self.autoregressive else None
        # vitals_or_prev_outputs.append(data['prev_outputs'])
        # vitals_or_prev_outputs = torch.cat(vitals_or_prev_outputs, dim=-1)
        # static_features = data['static_features']
        # outputs = data['outputs']

        # x = torch.cat((vitals_or_prev_outputs, curr_treatments), dim=-1)
        # x = torch.cat((x, static_features.unsqueeze(
        #     1).expand(-1, x.size(1), -1)), dim=-1)
        # x = torch.cat((x, outputs), dim=-1)
        # data = x
        # data = data.to(self.device)
        # print(f"Data shape: {data.shape}")
        # print(data)
        b, t, f = data.shape

        noise_mask = self.get_mask(data, self.strategy).to(self.device)
        # print(noise_mask[0])
        # print(data[0])
        noise = torch.randn((b, t, f)).to(self.device)
        noise = noise_mask * noise

        diffusion_t = torch.randint(0, self.diffusion_steps, (b, 1)).squeeze(1)
        alpha = self.alpha_torch[diffusion_t].unsqueeze(
            1).unsqueeze(2).to(self.device)

        noised_data = data * noise_mask
        noised_data = noised_data * (alpha**0.5) + noise * ((1 - alpha) ** 0.5)
        conditional_data = data * (1 - noise_mask)
        noised_data = noised_data + conditional_data
        noised_data = noised_data.float()

        predicted_noise = self.model_loop(
            noised_data.unsqueeze(3), noise_mask.unsqueeze(3), diffusion_t
        )
        predicted_noise = predicted_noise * noise_mask

        return (predicted_noise, noise, noise_mask)

    def eval_with_grad(self, data, scale=1):

        # with torch.no_grad():
        imputation_mask = self.get_mask(data, self.strategy).to(self.device)
        conditional_data = data * (1 - imputation_mask)
        random_noise = torch.randn_like(data) * imputation_mask * scale
        data_2 = (conditional_data + random_noise).unsqueeze(3)

        b, ti, f, e = data_2.shape
        imputed_samples = torch.zeros((b, ti, f)).to(self.device)
        x = conditional_data + random_noise

        for t in range(self.diffusion_steps - 1, -1, -1):

            x = x.unsqueeze(3).float()
            predicted_noise = self.model_loop(
                x, imputation_mask.unsqueeze(
                    3), torch.tensor([t]).to(self.device)
            )
            predicted_noise = predicted_noise * imputation_mask

            coeff1 = 1 / self.alpha_hat[t] ** 0.5
            coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5

            x = x.squeeze(3)
            x = coeff1 * (x - coeff2 * predicted_noise)

            if t > 0:
                noise = torch.randn_like(x)
                sigma = (
                    (1.0 - self.alpha[t - 1]) /
                    (1.0 - self.alpha[t]) * self.beta[t]
                ) ** 0.5
                x += sigma * noise

            x = data_2.squeeze(3) * (1 - imputation_mask) + x * imputation_mask

            imputed_samples = x

        return (imputed_samples, data, imputation_mask)

    # def get_predictions(
    #     self,
    #     data,
    #     imputation_mask,
    #     mean,
    #     std,
    #     scale=1,
    #     verbose=True,
    #     show_max_diff=False,
    #     show_rmse=False,
    # ):

    #     conditional_data = data * (1 - imputation_mask)
    #     random_noise = torch.randn_like(data) * imputation_mask * scale
    #     data_2 = (conditional_data + random_noise).unsqueeze(3)

    #     b, ti, f, e = data_2.shape
    #     imputed_samples = torch.zeros((b, ti, f)).to(self.device)
    #     x = conditional_data + random_noise

    #     with torch.no_grad():

    #         # for t in range(self.diffusion_steps - 1, -1, -1):
    #         for t in tqdm(range(self.diffusion_steps - 1, -1, -1), desc="Diffusion Steps", leave=False):
    #             x = x.unsqueeze(3).float()
    #             predicted_noise = self.model_loop(
    #                 x, imputation_mask.unsqueeze(
    #                     3), torch.tensor([t]).to(self.device)
    #             )
    #             predicted_noise = predicted_noise * imputation_mask

    #             coeff1 = 1 / self.alpha_hat[t] ** 0.5
    #             coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5

    #             x = x.squeeze(3)
    #             x = coeff1 * (x - coeff2 * predicted_noise)

    #             if t > 0:
    #                 noise = torch.randn_like(x)
    #                 sigma = (
    #                     (1.0 - self.alpha[t - 1]) /
    #                     (1.0 - self.alpha[t]) * self.beta[t]
    #                 ) ** 0.5
    #                 x += sigma * noise

    #             x = data_2.squeeze(3) * (1 - imputation_mask) + \
    #                 x * imputation_mask

    #         imputed_samples = x.detach()
    #         imputed_samples[torch.isnan(imputed_samples)] = 0

    #     if show_max_diff == True:
    #         # show the data at torch.max(torch.abs(data[imputation_mask !=0] - imputed_samples[imputation_mask !=0]))
    #         print(
    #             "max difference = ",
    #             torch.max(
    #                 torch.abs(
    #                     data[imputation_mask != 0]
    #                     - imputed_samples[imputation_mask != 0]
    #                 )
    #             ).item(),
    #         )
    #         print(
    #             "data at max difference = ",
    #             data[imputation_mask != 0][
    #                 torch.argmax(
    #                     torch.abs(
    #                         data[imputation_mask != 0]
    #                         - imputed_samples[imputation_mask != 0]
    #                     )
    #                 )
    #             ].item(),
    #         )
    #         print(
    #             "imputed at max difference = ",
    #             imputed_samples[imputation_mask != 0][
    #                 torch.argmax(
    #                     torch.abs(
    #                         data[imputation_mask != 0]
    #                         - imputed_samples[imputation_mask != 0]
    #                     )
    #                 )
    #             ].item(),
    #         )

    #     mae = torch.mean(
    #         torch.abs(
    #             data[imputation_mask != 0] -
    #             imputed_samples[imputation_mask != 0]
    #         )
    #     ).item()
    #     if verbose == True:
    #         print("mae = ", mae)

    #     if show_rmse == True:
    #         # descale the data
    #         imputed_samples_copy = imputed_samples.detach().clone()
    #         imputed_samples_copy = imputed_samples_copy * std + mean
    #         data_copy = data.detach().clone()
    #         data_copy = data_copy * std + mean
    #         rmse = torch.sqrt(
    #             torch.mean(
    #                 (
    #                     data_copy[imputation_mask != 0]
    #                     - imputed_samples_copy[imputation_mask != 0]
    #                 )
    #                 ** 2
    #             )
    #         ).item()
    #         rmse = rmse / 1150 * 100
    #         print("rmse = ", rmse)
    #     # data_to_print = data[imputation_mask !=0]
    #     # imputed_samples_to_print = imputed_samples[imputation_mask !=0]
    #     # print("data:", data_to_print)
    #     # print("imputed:", imputed_samples_to_print)
    #     # print("absolute difference in the first 100 : ", torch.abs(data_to_print - imputed_samples_to_print)[:100])
    #     # print("mae = ", torch.mean(torch.abs(data_to_print - imputed_samples_to_print)).item())

    #     return (imputed_samples, data, imputation_mask, mae)

    def _get_inference_schedule(self, extra_steps):
        """Return ``(num_steps, alpha_hat, alpha, beta)`` for reverse sampling.

        ``beta`` holds the per-step betas, ``alpha_hat`` is ``1 - beta`` and
        ``alpha`` is the cumulative product of ``alpha_hat``; all three
        tensors are placed on ``self.device``.

        With ``extra_steps <= 0`` the training schedule (``self.beta`` /
        ``self.alpha_torch``) is reused unchanged. Otherwise a schedule of
        ``self.diffusion_steps + extra_steps`` steps is generated from
        ``self.diffusion_beta_schedule``; for "linear" and "quadratic" the
        beta endpoints are still scaled by ``self.diffusion_steps``.

        Raises:
            NotImplementedError: for an unrecognised schedule name when a new
                schedule has to be generated.
        """
        if extra_steps <= 0:
            beta = self.beta.to(self.device)
            return (
                self.diffusion_steps,
                1 - beta,
                self.alpha_torch.to(self.device),
                beta,
            )

        num_steps = self.diffusion_steps + extra_steps
        schedule = self.diffusion_beta_schedule
        if schedule == "linear":
            scale_lin = (100 if self.diffusion_steps < 100 else 1000) / self.diffusion_steps
            betas = torch.linspace(scale_lin * 0.0001, scale_lin * 0.02, num_steps)
        elif schedule == "cosine":
            def alpha_bar(t):
                return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

            # num_steps + 1 knots in [0, 1]; beta_i comes from consecutive knots,
            # clipped at 0.999 because alpha_bar(1) is ~0.
            knots = torch.linspace(0, 1, num_steps + 1).tolist()
            betas = torch.tensor([
                min(1 - alpha_bar(t2) / alpha_bar(t1), 0.999)
                for t1, t2 in zip(knots[:-1], knots[1:])
            ])
        elif schedule == "quadratic":
            scale_quad = 50 / self.diffusion_steps
            betas = torch.linspace(
                (scale_quad * 0.0001) ** 0.5, (scale_quad * 0.5) ** 0.5, num_steps) ** 2
        else:
            raise NotImplementedError(
                f"Unknown beta schedule: {self.diffusion_beta_schedule}")

        # Keep the generated schedule on the same device as the reused one.
        betas = betas.to(self.device)
        alpha_hat = 1 - betas
        return num_steps, alpha_hat, torch.cumprod(alpha_hat, dim=0), betas

    def get_predictions(
        self,
        data,
        imputation_mask,
        mean,
        std,
        extra_steps=0,
        scale=1,
        verbose=True,
        show_max_diff=False,
        show_rmse=False,
        rmse_norm=1150,
    ):
        """Impute the masked entries of ``data`` with the reverse diffusion chain.

        Entries where ``imputation_mask`` is non-zero start from (scaled)
        Gaussian noise and are denoised step by step using ``self.model_loop``.
        After every step, all other entries are reset to their conditioning
        values.

        Args:
            data: 3-D tensor holding the ground truth. Presumably laid out as
                (batch, time, features); TODO confirm. It provides the
                conditioning values and is the reference for reported errors.
            imputation_mask: tensor with the same shape as ``data``. Non-zero
                entries are imputed; zero entries are conditioned on.
            mean, std: only used when ``show_rmse`` is set, to undo the
                standardisation as ``value * std + mean``.
            extra_steps: if > 0, sample with a newly generated schedule of
                ``self.diffusion_steps + extra_steps`` steps. Otherwise reuse
                the training schedule.
            scale: multiplier applied to the initial noise in masked entries.
            verbose: print the MAE.
            show_max_diff: print the largest absolute error over the imputed
                entries, plus the true and imputed values at that position.
            show_rmse: print the de-standardised RMSE, divided by ``rmse_norm``
                and multiplied by 100.
            rmse_norm: divisor applied to the RMSE before the x100 step. The
                default 1150 reproduces the previously hard-coded value; its
                origin is not documented here.

        Returns:
            Tuple ``(imputed_samples, data, imputation_mask, mae)``. ``mae`` is
            the mean absolute error over the imputed entries (NaN if there are
            none).

        Raises:
            ValueError: if ``data`` is not 3-dimensional.
            NotImplementedError: for an unknown schedule when
                ``extra_steps > 0``.
        """
        if data.dim() != 3:
            raise ValueError(
                f"data must be 3-dimensional, got shape {tuple(data.shape)}")

        total_diffusion_steps, alpha_hat, alpha, beta = \
            self._get_inference_schedule(extra_steps)

        observed = 1 - imputation_mask
        conditional_data = data * observed
        random_noise = torch.randn_like(data) * imputation_mask * scale
        x = conditional_data + random_noise
        # Values written back into the conditioned entries after every step.
        # This is loop-invariant, so it is computed once here.
        anchor = x * observed
        mask_4d = imputation_mask.unsqueeze(3)

        with torch.no_grad():
            for t in tqdm(range(total_diffusion_steps - 1, -1, -1), desc="Diffusion Steps", leave=False):
                # NOTE(review): with extra_steps > 0, t can exceed
                # self.diffusion_steps - 1; confirm model_loop's timestep
                # embedding supports that range.
                x = x.float()
                predicted_noise = self.model_loop(
                    x.unsqueeze(3), mask_4d, torch.tensor([t]).to(self.device)
                )
                predicted_noise = predicted_noise * imputation_mask

                # DDPM posterior mean:
                # x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
                coeff1 = 1 / (alpha_hat[t] ** 0.5)
                coeff2 = (1 - alpha_hat[t]) / ((1 - alpha[t]) ** 0.5)
                x = coeff1 * (x - coeff2 * predicted_noise)

                if t > 0:
                    noise = torch.randn_like(x)
                    sigma = ((1.0 - alpha[t - 1]) / (1.0 - alpha[t]) * beta[t]) ** 0.5
                    x = x + sigma * noise

                # Reinstate the conditional (unnoised) values.
                x = anchor + x * imputation_mask

            imputed_samples = x.detach()
            imputed_samples[torch.isnan(imputed_samples)] = 0

        selected = imputation_mask != 0
        target = data[selected]
        predicted = imputed_samples[selected]
        abs_err = torch.abs(target - predicted)

        if show_max_diff:
            if abs_err.numel() == 0:
                # torch.max would raise on an empty selection.
                print("max difference = ", "n/a (no imputed entries)")
            else:
                worst = torch.argmax(abs_err)
                print("max difference = ", torch.max(abs_err).item())
                print("data at max difference = ", target[worst].item())
                print("imputed at max difference = ", predicted[worst].item())

        mae = torch.mean(abs_err).item()
        if verbose:
            print("mae = ", mae)

        if show_rmse:
            descaled_pred = imputed_samples.detach() * std + mean
            descaled_true = data.detach() * std + mean
            rmse = torch.sqrt(torch.mean(
                (descaled_true[selected] - descaled_pred[selected]) ** 2)).item()
            rmse = rmse / rmse_norm * 100
            print("rmse = ", rmse)

        return (imputed_samples, data, imputation_mask, mae)

# # New main function for config handling and logging initialization
# def main(config: DictConfig):
#     logging.basicConfig(level=logging.INFO)
#     logger = logging.getLogger(__name__)
#     logger.info("Starting CausalDiff with config:\n%s", OmegaConf.to_yaml(config))

#     # Instantiate your model or application components using config args.
#     # Example:
#     # model = ModelLoop(
#     #     embed_dim=config.model.embed_dim,
#     #     diffusion_steps=config.model.diffusion_steps,
#     #     num_heads=config.model.num_heads,
#     #     kernel_size=config.model.kernel_size,
#     #     ff_dim=config.model.ff_dim,
#     #     num_cells=config.model.num_cells,
#     #     dropout=config.model.dropout,
#     #     num_residual_layers=config.model.num_residual_layers,
#     #     method=config.model.method,
#     #     device=config.device,
#     # )
#     # ...additional logic for training/evaluation...

# # Main guard to load config and start main()
# if __name__ == "__main__":
#     import sys
#     # Load config from the provided YAML file or use a default path
#     config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/causaldiff.yaml"
#     config = OmegaConf.load(config_path)
#     main(config)
