<a href="https://colab.research.google.com/github/JosephThompson607/dir_vae/blob/main/sepsis_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler

In this section we prepare the data for training

In [8]:
#TODO: Read this from the cloud
patients = pd.read_csv("/content/unique_patient_dem.csv")

patients.drop(columns=['subject_id'], inplace=True)
numeric_cols = patients.select_dtypes(include=[np.number]).columns.tolist()
categorical_cols = patients.select_dtypes(exclude=[np.number]).columns.tolist()
# Reorder DataFrame
patients = patients[numeric_cols + categorical_cols]
#1 hot encoding
df_encoded = pd.get_dummies(patients, columns=categorical_cols)

#If cuda is available, device is cuda, otherwise cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scaler = StandardScaler()
#Scaling numeric columns
df_encoded[numeric_cols] = scaler.fit_transform(df_encoded[numeric_cols])
features = df_encoded.astype('float32').values
# print(features.columns)
# print(features.dtypes)
# Get indices for slicing
num_indices = list(range(len(numeric_cols)))
n_numeric = len(numeric_cols)
cat_indices = list(range(len(numeric_cols), len(features)))
tensor = torch.tensor(features, dtype=torch.float32)

X_train, X_test = train_test_split(tensor, test_size=0.2, random_state=42)

train_dataset = TensorDataset(X_train)  # or (X_train, y_train)
test_dataset = TensorDataset(X_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
input_size = X_train[0].shape[0] #input size is the number of features going into the network
print(input_size)

36


Below we define the model and related functions

In [9]:
ngf = 64
ndf = 64
nc = 1

def prior(K, alpha):
    """
    Prior for the model.
    :K: number of categories
    :alpha: Hyper param of Dir
    :return: mean and variance tensors
    """
    # Approximate to normal distribution using Laplace approximation
    a = torch.Tensor(1, K).float().fill_(alpha)
    mean = a.log().t() - a.log().mean(1)
    var = ((1 - 2.0 / K) * a.reciprocal()).t() + (1.0 / K ** 2) * a.reciprocal().sum(1)
    return mean.t(), var.t() # Parameters of prior distribution after approximation

class Dir_VAE(nn.Module):
    def __init__(self, input_size,n_numeric, latent_size=10, hidden_dim = 200):
        self.num_numeric_cols = n_numeric
        self.latent_size = latent_size
        self.hidden_dim = hidden_dim
        self.input_size = input_size
        super(Dir_VAE, self).__init__()

        self.encoder = nn.Sequential(
          nn.Linear(self.input_size, self.hidden_dim),
          nn.ReLU(),
          # nn.Linear(self.hidden_dim, self.hidden_dim),
          # nn.ReLU(),
          # nn.Linear(self.hidden_dim, self.hidden_dim),
          # nn.ReLU()
        )
        self.decoder = nn.Sequential(
          nn.Linear(self.latent_size, self.hidden_dim),
          nn.ReLU(),
          # nn.Linear(self.hidden_dim, self.hidden_dim),
          # nn.ReLU(),
          # nn.Linear(self.hidden_dim, self.hidden_dim),
          # nn.ReLU(),
          nn.Linear(self.hidden_dim, self.self.hidden_dim),
          nn.ReLU(),
          # nn.Unflatten(dim=1, unflattened_size=(1, 28, 28)) # This was for image data
        )
        #self.fc1 = nn.Linear(self.hidden_dim, 512)
        self.fc21 = nn.Linear(self.hidden_dim, self.latent_size)
        self.fc22 = nn.Linear(self.hidden_dim, self.latent_size)

        self.fc3 = nn.Linear(self.hidden_dim, self.input_size)
        #self.fc4 = nn.Linear(512, self.hidden_dim)

        self.lrelu = nn.LeakyReLU()
        self.relu = nn.ReLU()

        # Dir prior
        self.prior_mean, self.prior_var = map(nn.Parameter, prior(self.latent_size, 0.3)) # 0.3 is a hyper param of Dirichlet distribution
        self.prior_logvar = nn.Parameter(self.prior_var.log())
        self.prior_mean.requires_grad = False
        self.prior_var.requires_grad = False
        self.prior_logvar.requires_grad = False


    def encode(self, x):
        encoding = self.encoder(x);
        #h1 = self.fc1(encoding)
        return self.fc21(encoding), self.fc22(encoding)

    def decode(self, gauss_z):
        dir_z = F.softmax(gauss_z,dim=1) #Reduntant, already done in forward
        # This variable (z) can be treated as a variable that follows a Dirichlet distribution (a variable that can be interpreted as a probability that the sum is 1)
        # Use the Softmax function to satisfy the simplex constraint
        x_out = self.decoder(dir_z)
        # Apply sigmoid to categorical output only
        x_out[:, self.num_numeric_cols:] = torch.sigmoid(x_out[:, self.num_numeric_cols:])
        return x_out

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std


    def forward(self, x):
        mu, logvar = self.encode(x)
        gauss_z = self.reparameterize(mu, logvar)
        # gause_z is a variable that follows a multivariate normal distribution
        # Inputting gause_z into softmax func yields a random variable that follows a Dirichlet distribution (Softmax func are used in decoder)
        dir_z = F.softmax(gauss_z,dim=1) # This variable follows a Dirichlet distribution
        return self.decode(gauss_z), mu, logvar, gauss_z, dir_z

    def reconstruction_loss(self, x_true, x_recon):
        # Slice the tensors
        x_true_num = x_true[:, :self.num_numeric_cols]
        x_true_cat = x_true[:, self.num_numeric_cols:]

        x_recon_num = x_recon[:, :self.num_numeric_cols]
        x_recon_cat = x_recon[:, self.num_numeric_cols:]

        # Compute losses
        num_loss = F.mse_loss(x_recon_num, x_true_num)
        cat_loss = F.cross_entropy(x_recon_cat, x_true_cat)

        return num_loss + cat_loss

    # Reconstruction + KL divergence losses summed over all elements and batch
    def loss_function(self, recon_x, x, mu, logvar):
        # Apply sigmoid to the input data x to ensure values are between 0 and 1
        recon_loss = self.reconstruction_loss(x, recon_x, )
        # ディリクレ事前分布と変分事後分布とのKLを計算
        # Calculating KL with Dirichlet prior and variational posterior distributions
        # Original paper:"Autoencodeing variational inference for topic model"-https://arxiv.org/pdf/1703.01488
        prior_mean = self.prior_mean.expand_as(mu)
        prior_var = self.prior_var.expand_as(logvar)
        prior_logvar = self.prior_logvar.expand_as(logvar)
        var_division = logvar.exp() / prior_var # Σ_0 / Σ_1
        diff = mu - prior_mean # μ_１ - μ_0
        diff_term = diff *diff / prior_var # (μ_1 - μ_0)(μ_1 - μ_0)/Σ_1
        logvar_division = prior_logvar - logvar # log|Σ_1| - log|Σ_0| = log(|Σ_1|/|Σ_2|)
        # KL
        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(1) - self.latent_size)
        self.last_KLD = torch.mean(KLD) #Used for reporting
        self.last_BCE = recon_loss
        return recon_loss + KLD

Below are the training and test loops


In [10]:





model = Dir_VAE(input_size, n_numeric, latent_size=2, hidden_dim=20).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data,) in enumerate(train_loader): # Unpack only one element
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar, gauss_z, dir_z = model(data)

        loss = model.loss_function(recon_batch, data, mu, logvar, )
        loss = loss.mean()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 10 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader)}%)] \
            Loss:{loss.item() / len(data)}\
            R_loss {model.last_BCE}, KLD_loss {model.last_KLD}')

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data,) in enumerate(test_loader): # Unpack only one element
            data = data.to(device)
            recon_batch, mu, logvar, gauss_z, dir_z = model(data)
            loss = model.loss_function(recon_batch, data, mu, logvar)
            test_loss += loss.mean()
            test_loss.item()


    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

if __name__ == "__main__":
    for epoch in range(1, 10 + 1):
        train(epoch)
        test(epoch)


====> Epoch: 1 Average loss: 0.2272
====> Test set loss: 0.2130
====> Epoch: 2 Average loss: 0.2120
====> Test set loss: 0.2110
====> Epoch: 3 Average loss: 0.2101
====> Test set loss: 0.2089
====> Epoch: 4 Average loss: 0.2081
====> Test set loss: 0.2087
====> Epoch: 5 Average loss: 0.2081
====> Test set loss: 0.2084
====> Epoch: 6 Average loss: 0.2079
====> Test set loss: 0.2080
====> Epoch: 7 Average loss: 0.2080
====> Test set loss: 0.2079
====> Epoch: 8 Average loss: 0.2083
====> Test set loss: 0.2080
====> Epoch: 9 Average loss: 0.2082
====> Test set loss: 0.2079
====> Epoch: 10 Average loss: 0.2078
====> Test set loss: 0.2086


In [17]:
with torch.no_grad():
  print("testing encoding")
  data = next(iter(test_loader))
  print(data[0][0])
  recon_batch, mu, logvar, gauss_z, dir_z = model(data[0])
  print(recon_batch[0])


testing encoding
tensor([0.6984, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000])
tensor([7.0599e-01, 1.1541e-04, 1.5541e-04, 1.0534e-04, 1.4746e-04, 1.0779e-04,
        9.4734e-05, 9.8256e-05, 5.2996e-01, 1.8855e-04, 8.8673e-05, 1.2442e-04,
        1.6131e-04, 1.0111e-04, 1.5494e-04, 1.8781e-04, 1.5484e-04, 1.7324e-04,
        1.4581e-04, 2.0903e-04, 1.5217e-04, 1.3672e-04, 1.3104e-04, 9.8247e-04,
        1.4903e-04, 8.4857e-05, 1.5018e-04, 1.4926e-04, 4.4486e-01, 9.9989e-01,
        1.3694e-04, 1.0310e-04, 5.1154e-04, 1.4122e-04, 9.9989e-01, 9.9989e-01])
