In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch.nn as nn
import torch.optim as optim


class PropensityNN(nn.Module):
    """Softmax classifier made of a single linear layer.

    ``forward`` maps each input row through one affine layer and normalises
    the result with a softmax over dimension 1, so the module returns class
    probabilities rather than raw logits.
    """

    def __init__(self, input_size, output_size):
        """Build the linear layer and the softmax applied after it.

        Args:
            input_size: Number of features in each input row.
            output_size: Number of classes in the output distribution.
        """
        super().__init__()
        self.simple_linear = nn.Linear(
            in_features=input_size, out_features=output_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return softmax probabilities for ``x``.

        Args:
            x: Input tensor whose last dimension has ``input_size`` entries;
                the softmax is taken over dimension 1.

        Returns:
            Tensor of probabilities that sum to 1 along dimension 1.
        """
        scores = self.simple_linear(x)
        return self.softmax(scores)


class MSMPropensityTreatment(CausalDiff):
    """Propensity network predicting the current treatment from past treatments.

    The model input is the per-sequence sum of ``prev_treatments`` over the
    active time steps; the target is ``current_treatments`` at the last
    active step. Training is a plain mini-batch loop run by ``fit``.
    """

    model_type = 'propensity_treatment'

    # Lower bound applied to probabilities before taking their logarithm.
    _PROB_FLOOR = 1e-12

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None, has_vitals: bool = None, **kwargs):
        """Create the network, its Adam optimizer and the loss.

        Args:
            args: Experiment configuration; stored via ``save_hyperparameters``.
            dataset_collection: Dataset collection forwarded to the base class.
            autoregressive: Forwarded to the base class.
            has_vitals: Forwarded to the base class.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        super().__init__(args, dataset_collection, autoregressive, has_vitals)

        self.input_size = self.dim_treatments
        logger.info(f'Input size of {self.model_type}: {self.input_size}')
        self.output_size = self.dim_treatments

        self.propensity_treatment = PropensityNN(
            self.input_size, self.output_size)
        self.optimizer = optim.Adam(
            self.propensity_treatment.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()
        self.save_hyperparameters(args)

    def get_inputs(self, dataset: Dataset) -> np.ndarray:
        """Sum previous treatments over the active time steps of each sequence.

        Args:
            dataset: Object whose ``data`` mapping holds ``'active_entries'``
                and ``'prev_treatments'`` arrays.

        Returns:
            The masked treatments summed over axis 1.
        """
        active_entries = dataset.data['active_entries']
        prev_treatments = dataset.data['prev_treatments']
        return (prev_treatments * active_entries).sum(1)

    def fit(self):
        """Train ``propensity_treatment`` on the exploded training split.

        Prints the mean batch loss and the training accuracy after every
        epoch. The number of epochs is read from ``self.hparams.exp.max_epochs``.
        """
        self.prepare_data()
        train_f = self.get_exploded_dataset(
            self.dataset_collection.train_f, min_length=self.lag_features)
        active_entries = train_f.data['active_entries']
        # Subtracting the mask shifted one step to the left leaves a 1 only at
        # the last active step (assumes active_entries is a 0/1 prefix mask).
        shifted_entries = np.concatenate(
            [active_entries[:, 1:, :],
             np.zeros((active_entries.shape[0], 1, 1))], axis=1)
        last_entries = active_entries - shifted_entries

        # Inputs and targets as float tensors
        inputs = torch.tensor(self.get_inputs(train_f), dtype=torch.float32)
        current_treatments = train_f.data['current_treatments']
        targets = torch.tensor(
            (current_treatments * last_entries).sum(1), dtype=torch.float32)

        dataloader = DataLoader(
            TensorDataset(inputs, targets), batch_size=32, shuffle=True)

        # `args` only exists inside __init__; the configuration saved there by
        # save_hyperparameters is read back through self.hparams.
        max_epochs = self.hparams.exp.max_epochs

        for epoch in range(max_epochs):
            epoch_loss = 0.0
            for batch_inputs, batch_targets in dataloader:
                self.optimizer.zero_grad()
                probs = self.propensity_treatment(batch_inputs)
                # PropensityNN already returns softmax probabilities, while
                # CrossEntropyLoss applies log-softmax to its input. Passing
                # log-probabilities makes that internal softmax reproduce
                # `probs`, so the loss is the true cross-entropy instead of
                # one computed on a doubly-normalised distribution.
                log_probs = torch.log(probs.clamp_min(self._PROB_FLOOR))
                loss = self.criterion(log_probs, batch_targets)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
            epoch_loss /= len(dataloader)
            print(
                f"Epoch {epoch+1}/{max_epochs}, Loss: {epoch_loss:.4f}")

            # Training accuracy; targets are converted to class indices with
            # argmax, i.e. they are treated as one-hot rows.
            correct = 0
            total = 0
            with torch.no_grad():
                for batch_inputs, batch_targets in dataloader:
                    predicted = self.propensity_treatment(
                        batch_inputs).argmax(dim=1)
                    target_labels = batch_targets.argmax(dim=1)
                    total += batch_targets.size(0)
                    correct += (predicted == target_labels).sum().item()
            print(f"Accuracy: {100 * correct / total:.2f}%")


class MSMPropensityHistory(CausalDiff):
    """Propensity network predicting the current treatment from patient history.

    The model input concatenates summed previous treatments, the last
    ``lag_features + 1`` steps of vitals (if ``has_vitals``) and of previous
    outcomes (if ``autoregressive``), and the static features. The target is
    ``current_treatments`` at the last active step.
    """

    model_type = 'propensity_history'

    # Lower bound applied to probabilities before taking their logarithm.
    _PROB_FLOOR = 1e-12

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None, has_vitals: bool = None, **kwargs):
        """Create the network, its Adam optimizer and the loss.

        Args:
            args: Experiment configuration; stored via ``save_hyperparameters``.
            dataset_collection: Dataset collection forwarded to the base class.
            autoregressive: Forwarded to the base class.
            has_vitals: Forwarded to the base class.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        super().__init__(args, dataset_collection, autoregressive, has_vitals)

        # The width must match what get_inputs concatenates: treatments,
        # (lag_features + 1) steps of each optional time-varying block, and
        # the static features (assumed to be dim_static_features wide).
        lagged_steps = self.lag_features + 1
        self.input_size = self.dim_treatments + self.dim_static_features
        if self.has_vitals:
            self.input_size += lagged_steps * self.dim_vitals
        if self.autoregressive:
            self.input_size += lagged_steps * self.dim_outcome

        logger.info(f'Input size of {self.model_type}: {self.input_size}')
        self.output_size = self.dim_treatments

        self.propensity_history = PropensityNN(
            self.input_size, self.output_size)
        self.optimizer = optim.Adam(
            self.propensity_history.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()
        self.save_hyperparameters(args)

    def get_inputs(self, dataset: Dataset, projection_horizon=0) -> np.ndarray:
        """Assemble the flat feature matrix fed to the propensity network.

        Args:
            dataset: Object whose ``data`` mapping holds ``'active_entries'``,
                ``'prev_treatments'``, ``'static_features'`` and, depending on
                the model flags, ``'vitals'`` and ``'prev_outputs'``.
            projection_horizon: Number of steps by which the masks are shifted
                towards earlier time steps before they are applied.

        Returns:
            Array with one row per sequence: summed previous treatments,
            then lagged vitals (if ``has_vitals``), lagged previous outputs
            (if ``autoregressive``) and the static features, concatenated
            along axis 1.
        """
        active_entries = dataset.data['active_entries']
        n_sequences = active_entries.shape[0]
        window = self.lag_features + 1

        # Mask minus itself shifted left by `window` steps: marks the last
        # `window` active steps (assumes active_entries is a 0/1 prefix mask).
        lagged_entries = active_entries - np.concatenate(
            [active_entries[:, window:, :],
             np.zeros((n_sequences, window, 1))], axis=1)
        if projection_horizon > 0:
            lagged_entries = np.concatenate(
                [lagged_entries[:, projection_horizon:, :],
                 np.zeros((n_sequences, projection_horizon, 1))], axis=1)

        active_entries_before_projection = np.concatenate(
            [active_entries[:, projection_horizon:, :],
             np.zeros((n_sequences, projection_horizon, 1))], axis=1)

        prev_treatments = dataset.data['prev_treatments']
        inputs = [(prev_treatments * active_entries_before_projection).sum(1)]
        if self.has_vitals:
            vitals = dataset.data['vitals']
            # Boolean selection flattens the chosen steps; the reshape needs
            # exactly `window` selected steps per sequence.
            vitals_mask = np.repeat(lagged_entries, self.dim_vitals, 2) == 1.0
            inputs.append(vitals[vitals_mask].reshape(
                vitals.shape[0], window * self.dim_vitals))
        if self.autoregressive:
            prev_outputs = dataset.data['prev_outputs']
            outputs_mask = np.repeat(
                lagged_entries, self.dim_outcome, 2) == 1.0
            inputs.append(prev_outputs[outputs_mask].reshape(
                prev_outputs.shape[0], window * self.dim_outcome))
        inputs.append(dataset.data['static_features'])
        return np.concatenate(inputs, axis=1)

    def fit(self):
        """Train ``propensity_history`` on the exploded training split.

        Prints the mean batch loss and the training accuracy after every
        epoch. The number of epochs is read from ``self.hparams.exp.max_epochs``.
        """
        self.prepare_data()
        train_f = self.get_exploded_dataset(
            self.dataset_collection.train_f, min_length=self.lag_features)
        active_entries = train_f.data['active_entries']
        # Subtracting the mask shifted one step to the left leaves a 1 only at
        # the last active step (assumes active_entries is a 0/1 prefix mask).
        shifted_entries = np.concatenate(
            [active_entries[:, 1:, :],
             np.zeros((active_entries.shape[0], 1, 1))], axis=1)
        last_entries = active_entries - shifted_entries

        # Inputs and targets as float tensors
        inputs = torch.tensor(self.get_inputs(train_f), dtype=torch.float32)
        current_treatments = train_f.data['current_treatments']
        targets = torch.tensor(
            (current_treatments * last_entries).sum(1), dtype=torch.float32)

        dataloader = DataLoader(
            TensorDataset(inputs, targets), batch_size=32, shuffle=True)

        # `args` only exists inside __init__; the configuration saved there by
        # save_hyperparameters is read back through self.hparams.
        max_epochs = self.hparams.exp.max_epochs

        for epoch in range(max_epochs):
            epoch_loss = 0.0
            for batch_inputs, batch_targets in dataloader:
                self.optimizer.zero_grad()
                probs = self.propensity_history(batch_inputs)
                # PropensityNN already returns softmax probabilities, while
                # CrossEntropyLoss applies log-softmax to its input. Passing
                # log-probabilities makes that internal softmax reproduce
                # `probs`, so the loss is the true cross-entropy instead of
                # one computed on a doubly-normalised distribution.
                log_probs = torch.log(probs.clamp_min(self._PROB_FLOOR))
                loss = self.criterion(log_probs, batch_targets)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()

            epoch_loss /= len(dataloader)
            print(
                f"Epoch {epoch + 1}/{max_epochs}, Loss: {epoch_loss:.4f}")

            # Training accuracy; no gradients are needed for this pass.
            # Targets are converted to class indices with argmax, i.e. they
            # are treated as one-hot rows.
            correct = 0
            total = 0
            with torch.no_grad():
                for batch_inputs, batch_targets in dataloader:
                    predicted = self.propensity_history(
                        batch_inputs).argmax(dim=1)
                    target_labels = batch_targets.argmax(dim=1)
                    total += batch_targets.size(0)
                    correct += (predicted == target_labels).sum().item()

            print(f"Accuracy: {100 * correct / total:.2f}%")