In [1]:
import random
from os import path

import numpy as np
import polars as pl
import torch
from sklearn.preprocessing import StandardScaler
from torch import nn, optim

In [2]:
# Seed every RNG the notebook draws from (stdlib, NumPy, PyTorch CPU and the
# current CUDA device) so that re-running the notebook draws the same numbers.
SEED = 491
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Give up cuDNN autotuning in exchange for deterministic kernel selection.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Use the first visible CUDA GPU when there is one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# Pre-processed SEER and cohort tables, both read from the same data directory
# (paths are relative to the notebook's working directory).
data_path = "../../data/previous"
seer_path = path.join(data_path, "processed_seer_with_age.csv")
cohort_path = path.join(data_path, "processed_cohort_with_age.csv")
# Read SEER first, then the cohort file.
seer, cohort = map(pl.read_csv, (seer_path, cohort_path))

In [5]:
# Features are every column except the `target` label; `cols` fixes their order.
label_col = "target"
cols = seer.select(pl.exclude(label_col)).columns
cohort_cols = cohort.select(pl.exclude(label_col)).columns
# The cohort CVAE reuses the SEER-derived input width, and the synthetic tables
# are named from `cols`. The two feature schemas must therefore agree exactly
# (same names, same order), or cohort outputs would be silently mislabeled.
assert cohort_cols == cols, (
    "SEER and cohort feature columns differ (names or order):\n"
    f"  seer:   {cols}\n"
    f"  cohort: {cohort_cols}"
)
# float32 features for the network; int64 labels (flattened to 1-D).
X_seer = seer.select(pl.exclude(label_col)).to_numpy().astype(np.float32)
y_seer = seer.get_column(label_col).to_numpy().astype(np.int64).ravel()
X_cohort = cohort.select(pl.exclude(label_col)).to_numpy().astype(np.float32)
y_cohort = cohort.get_column(label_col).to_numpy().astype(np.int64).ravel()

In [6]:
# Each source gets its own scaler. Each CVAE is trained on its own standardized
# data, and its samples are later mapped back through that same scaler.
scaler_seer, scaler_cohort = StandardScaler(), StandardScaler()

X_seer_std = scaler_seer.fit_transform(X_seer)
X_seer_std = X_seer_std.astype(np.float32)

X_cohort_std = scaler_cohort.fit_transform(X_cohort)
X_cohort_std = X_cohort_std.astype(np.float32)

In [7]:
# Network sizes read by `CVAE`. The same `input_dim` (taken from SEER) is used
# for both the SEER and the cohort model, so the cohort feature width must match.
input_dim = X_seer_std.shape[1]
assert X_cohort_std.shape[1] == input_dim, (
    f"cohort has {X_cohort_std.shape[1]} features but the shared CVAE input width "
    f"(from SEER) is {input_dim}"
)
cond_dim = 1  # one scalar condition per row: the label column
latent_dim = 24
hidden = 128


class CVAE(nn.Module):
    """Conditional variational autoencoder built from fully connected layers.

    The condition ``c`` is concatenated (along dim 1) to the input ``x`` in the
    encoder and to the latent code ``z`` in the decoder. Calling
    ``decode(z, c)`` with a chosen ``c`` therefore generates rows for that
    condition.

    Args:
        x_dim: Number of input features. Defaults to the module-level
            ``input_dim``.
        c_dim: Width of the condition vector. Defaults to ``cond_dim``.
        z_dim: Latent dimensionality. Defaults to ``latent_dim``.
        h_dim: Width of each hidden layer. Defaults to ``hidden``.

    Defaults are looked up when an instance is constructed (not when the class
    is defined), so ``CVAE()`` reads the notebook globals exactly as before.
    """

    def __init__(self, x_dim=None, c_dim=None, z_dim=None, h_dim=None):
        super().__init__()
        x_dim = input_dim if x_dim is None else x_dim
        c_dim = cond_dim if c_dim is None else c_dim
        z_dim = latent_dim if z_dim is None else z_dim
        h_dim = hidden if h_dim is None else h_dim
        self.x_dim, self.c_dim, self.z_dim, self.h_dim = x_dim, c_dim, z_dim, h_dim

        # Layers are created in a fixed order (encoder, heads, decoder). That
        # order determines how the seeded RNG is consumed during weight init.
        self.enc1 = nn.Linear(x_dim + c_dim, h_dim)
        self.enc2 = nn.Linear(h_dim, h_dim)
        self.mu = nn.Linear(h_dim, z_dim)
        self.logvar = nn.Linear(h_dim, z_dim)
        self.dec1 = nn.Linear(z_dim + c_dim, h_dim)
        self.dec2 = nn.Linear(h_dim, h_dim)
        self.out = nn.Linear(h_dim, x_dim)

    def encode(self, x, c):
        """Map ``(x, c)`` to the posterior parameters ``(mu, logvar)``."""
        h = torch.relu(self.enc1(torch.cat([x, c], dim=1)))
        h = torch.relu(self.enc2(h))
        return self.mu(h), self.logvar(h)

    def reparam(self, mu, logvar):
        """Reparameterization trick: return ``mu + eps * exp(0.5 * logvar)``.

        ``eps`` is drawn from a standard normal, so gradients flow through
        ``mu`` and ``logvar`` rather than through the sampling step.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        """Map a latent code ``z`` plus condition ``c`` to a reconstruction.

        The output layer has no activation, so values are unbounded (they live
        in whatever space the training inputs were in).
        """
        h = torch.relu(self.dec1(torch.cat([z, c], dim=1)))
        h = torch.relu(self.dec2(h))
        return self.out(h)

    def forward(self, x, c):
        """Encode, sample ``z``, and decode.

        Returns:
            ``(x_hat, mu, logvar)``, which is everything the VAE loss needs.
        """
        mu, logvar = self.encode(x, c)
        z = self.reparam(mu, logvar)
        x_hat = self.decode(z, c)
        return x_hat, mu, logvar

In [8]:
# Per-element mean-squared error, used as the reconstruction term for both models.
mse = nn.MSELoss(reduction="mean")
# Two independently initialized CVAEs, one per data source (SEER built first).
seer_model, cohort_model = (CVAE().to(device) for _ in range(2))

In [9]:
# Full-batch training: each step pushes the whole standardized SEER matrix
# through the CVAE, conditioned on the label as a (N, 1) float column.
seer_optimizer = optim.Adam(seer_model.parameters(), lr=1e-3)
seer_inputs = torch.tensor(X_seer_std, device=device)
seer_labels = torch.tensor(y_seer, dtype=torch.float32, device=device)[:, None]

for epoch in range(50):
    seer_optimizer.zero_grad()
    seer_recon, seer_mu, seer_logvar = seer_model(seer_inputs, seer_labels)
    # Loss = reconstruction MSE + 0.01 * KL(q(z|x,c) || N(0, I)).
    # Both terms are averaged over every element.
    seer_recon_loss = mse(seer_recon, seer_inputs)
    seer_kld_loss = -0.5 * (1 + seer_logvar - seer_mu.pow(2) - seer_logvar.exp()).mean()
    seer_loss = seer_recon_loss + 0.01 * seer_kld_loss
    seer_loss.backward()
    seer_optimizer.step()

In [10]:
# Same full-batch CVAE training as for SEER, but with a smaller learning rate
# and a smaller KL weight (0.005).
cohort_optimizer = optim.Adam(cohort_model.parameters(), lr=5e-4)
cohort_inputs = torch.tensor(X_cohort_std, device=device)
cohort_labels = torch.tensor(y_cohort, dtype=torch.float32, device=device)[:, None]

for epoch in range(50):
    cohort_optimizer.zero_grad()
    cohort_recon, cohort_mu, cohort_logvar = cohort_model(cohort_inputs, cohort_labels)
    # Loss = reconstruction MSE + 0.005 * KL(q(z|x,c) || N(0, I)).
    # Both terms are averaged over every element.
    cohort_recon_loss = mse(cohort_recon, cohort_inputs)
    cohort_kld_loss = -0.5 * (1 + cohort_logvar - cohort_mu.pow(2) - cohort_logvar.exp()).mean()
    cohort_loss = cohort_recon_loss + 0.005 * cohort_kld_loss
    cohort_loss.backward()
    cohort_optimizer.step()

In [11]:
# Generation settings shared by both samplers:
n_syn = 10_000  # synthetic rows generated per source
target_pos_rate = 0.5  # probability that a sampled label is 1

In [12]:
# Draw 0/1 labels at the requested positive rate. Then decode standard-normal
# latent codes conditioned on those labels (sampling from the prior; the
# encoder is not used here).
# NOTE(review): the decoder output is continuous for every column, so any
# binary/categorical features in `cols` will come out non-integer. Confirm
# whether rounding/clipping is needed downstream.
seer_y_syn = np.random.binomial(1, target_pos_rate, size=n_syn).astype(np.int64)
with torch.no_grad():
    seer_z = torch.randn(n_syn, latent_dim, device=device)
    seer_c = torch.tensor(seer_y_syn, dtype=torch.float32, device=device)[:, None]
    seer_x_gen_std = (
        seer_model.decode(seer_z, seer_c)
        .cpu()
        .numpy()
        .astype(np.float32)
    )
# Undo the SEER standardization to return to the original feature scale.
seer_X_syn = scaler_seer.inverse_transform(seer_x_gen_std)

In [13]:
# Same prior sampling as for SEER, using the cohort model and the cohort scaler.
# NOTE(review): decoded values are continuous for every column; confirm
# whether categorical features need rounding/clipping.
cohort_y_syn = np.random.binomial(1, target_pos_rate, size=n_syn).astype(np.int64)
with torch.no_grad():
    cohort_z = torch.randn(n_syn, latent_dim, device=device)
    cohort_c = torch.tensor(cohort_y_syn, dtype=torch.float32, device=device)[:, None]
    cohort_x_gen_std = (
        cohort_model.decode(cohort_z, cohort_c)
        .cpu()
        .numpy()
        .astype(np.float32)
    )
# Undo the cohort standardization to return to the original feature scale.
cohort_X_syn = scaler_cohort.inverse_transform(cohort_x_gen_std)

In [14]:
# Build the SEER synthetic table and write it to disk. It has one row per
# synthetic patient: the decoded features, named by SEER's feature columns,
# plus the sampled label.
# `orient="row"` is explicit because polars' inferred orientation falls back to
# columns when the row count equals the feature count (n_syn is configurable).
seer_syn_df = pl.DataFrame(seer_X_syn, schema=cols, orient="row").with_columns(
    pl.Series(label_col, seer_y_syn)
)
out_path = path.join(data_path, "synthetic_patients_2.0_seer.csv")
seer_syn_df.write_csv(out_path)

In [15]:
# Build the cohort synthetic table and write it to disk. Columns are named from
# the cohort's own feature columns, the same columns (and order) that
# `X_cohort` and hence `cohort_X_syn` were built from. Borrowing SEER's `cols`
# would mislabel features whenever the two files order their columns
# differently.
# `orient="row"` avoids polars' orientation inference, which is ambiguous when
# the row count equals the feature count.
cohort_feature_cols = cohort.select(pl.exclude(label_col)).columns
cohort_syn_df = pl.DataFrame(
    cohort_X_syn, schema=cohort_feature_cols, orient="row"
).with_columns(pl.Series(label_col, cohort_y_syn))
out_path = path.join(data_path, "synthetic_patients_2.0_cohort.csv")
cohort_syn_df.write_csv(out_path)