## Neural Spline Flow 1D for estimating distribution of potential outcome

In [18]:
import sys
import os

code_dir = os.path.abspath(os.path.join(os.getcwd(), '.', 'code'))
if code_dir not in sys.path:
    sys.path.append(code_dir)

from pathlib import Path
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import pandas as pd 
import argparse
import pickle
import imageio

from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns

from neural_spline_flow1D import NeuralSplineFlow1D
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
from torch.distributions import Normal
from IPython.display import display, Markdown
from gif import *

### Dataset loading

In [19]:
potential_outcome_df = pd.read_csv('./data/processed_potential_outcome.csv')
valid_df = potential_outcome_df[potential_outcome_df['prev_output']>5]

del potential_outcome_df

# KMeans
prev_outputs = valid_df[['prev_output']].copy()

n_clusters = 3
kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init='auto')
prev_outputs['cluster'] = kmeans.fit_predict(prev_outputs)

cluster_counts = prev_outputs['cluster'].value_counts().sort_index()
cluster_centers = kmeans.cluster_centers_

cluster_summary = pd.DataFrame({
    'Cluster ID': cluster_counts.index,
    'Sample Count': cluster_counts.values,
    'Center (prev_output)': cluster_centers.flatten()
})

display(cluster_summary)

Unnamed: 0,Cluster ID,Sample Count,Center (prev_output)
0,0,3531,33.36316
1,1,162,700.420938
2,2,527,268.839934


In [20]:
def plot_training_1d_kde(ssc, device, base_dist, y_original, model, epoch, cluster, treatment):
    with torch.no_grad():
        
        # Normal distribution
        z_samples = base_dist.sample((50000,)).to(device)
        
        # Flow forward
        y_vis, _ = model._elementwise_forward(z_samples.unsqueeze(-1))
        y_np = y_vis.detach().cpu().squeeze().numpy()  
        
        inversed = ssc.inverse_transform(y_np.reshape(-1, 1))

        plt.figure(figsize=(8, 4))
        plt.hist(y_original, bins=30, density=True, alpha=0.3, color='gray', label='Histogram') # 관측 히스토그램
        sns.kdeplot(inversed, label=f'P(Y[{treatment}] = y)', color='green')
        plt.title(f"Epoch {epoch}")
        plt.xlabel("y")
        plt.ylabel("Density")
        plt.legend()
        
        png_dir = f"train_plots_Cancer_cluster_{cluster}_treatment_{treatment}"
        
        os.makedirs(png_dir, exist_ok=True)
        plt.savefig(png_dir + f"/epoch_{epoch:04d}.png")
        plt.close()

In [25]:
def make_gif_from_train_plots(treatment, cluster, fname: str) -> None:

    png_dir = f"train_plots_Cancer_cluster_{cluster}_treatment_{treatment}/"
    images = []
    sort = sorted(os.listdir(png_dir))
    for file_name in sort[1::1]:
        if file_name.endswith(".png"):
            file_path = os.path.join(png_dir, file_name)
            images.append(imageio.imread(file_path))

    imageio.mimsave("./code/gifs/" + fname, images, duration=0.05)

### Training

In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose cluster number you want to analyze
cluster_num = 2

# Choose treatment number you want to analyze (1: not treated, 2: Chemo, 3: Radio, 4: Both)
treatment = 4

centroid = cluster_summary[cluster_summary['Cluster ID'] == cluster_num]['Center (prev_output)'].values

cluster_data = prev_outputs[prev_outputs['cluster'] == cluster_num]
filtered_df = valid_df.loc[cluster_data.index]

df = filtered_df.copy()
ssc = MinMaxScaler()
df[f'y{treatment}_scaled']=ssc.fit_transform(df[f'y{treatment}'].values.reshape(-1, 1))

# hyperparameter setting
batch_size = len(df)

dataset = TensorDataset(torch.tensor(df[f'y{treatment}_scaled'].values).unsqueeze(-1))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

features = 1 # 1 Dimension
num_bins = 10
epochs = 500

model = NeuralSplineFlow1D(features, num_bins).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
base_dist = torch.distributions.Normal(0, 1)

model.train()

for epoch in range(1, epochs + 1):
    
    epoch_loss = 0.0
    
    for batch in dataloader:
        x_batch = batch[0].to(device).float()   # shape: (batch_size, 1) or (batch_size,)
        
        optimizer.zero_grad()
        
        z, logdet = model._elementwise_inverse(x_batch) # x -> z_0
        logprob = base_dist.log_prob(z).sum(dim=-1)
        loss = -(logprob + logdet).mean()
        

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        
    N = len(dataloader.dataset)
    avg_loss = epoch_loss / N
        
    if epoch % 10 == 0:
        print(f"Epoch {epoch:03d}: "
            f"Loss={avg_loss:.4f} ")

    if epoch == 1 or epoch % 1 == 0:
        plot_training_1d_kde(ssc, device, base_dist, df[f'y{treatment}'].values, model, epoch, cluster_num, treatment)
        
    model_path = f'{dataset}_PO_distribution/'
        
    save_dict = {
        'model_state_dict': model.state_dict(),
        'scaler': ssc,
        'bins': num_bins,
        'centroid':centroid,
        'features': features
    }
        
sum_probability(model, base_dist, device)        

Epoch 010: Loss=0.0270 
Epoch 020: Loss=0.0207 
Epoch 030: Loss=0.0150 
Epoch 040: Loss=0.0103 
Epoch 050: Loss=0.0075 
Epoch 060: Loss=0.0059 
Epoch 070: Loss=0.0047 
Epoch 080: Loss=0.0038 
Epoch 090: Loss=0.0032 
Epoch 100: Loss=0.0027 
Epoch 110: Loss=0.0023 
Epoch 120: Loss=0.0021 
Epoch 130: Loss=0.0018 
Epoch 140: Loss=0.0016 
Epoch 150: Loss=0.0015 
Epoch 160: Loss=0.0013 
Epoch 170: Loss=0.0012 
Epoch 180: Loss=0.0010 
Epoch 190: Loss=0.0009 
Epoch 200: Loss=0.0007 
Epoch 210: Loss=0.0007 
Epoch 220: Loss=0.0006 
Epoch 230: Loss=0.0005 
Epoch 240: Loss=0.0004 
Epoch 250: Loss=0.0004 
Epoch 260: Loss=0.0003 
Epoch 270: Loss=0.0003 
Epoch 280: Loss=0.0003 
Epoch 290: Loss=0.0002 
Epoch 300: Loss=0.0002 
Epoch 310: Loss=0.0002 
Epoch 320: Loss=0.0002 
Epoch 330: Loss=0.0002 
Epoch 340: Loss=0.0001 
Epoch 350: Loss=0.0001 
Epoch 360: Loss=0.0001 
Epoch 370: Loss=0.0001 
Epoch 380: Loss=0.0001 
Epoch 390: Loss=0.0000 
Epoch 400: Loss=-0.0000 
Epoch 410: Loss=-0.0001 
Epoch 420: Los

### Visualizing

In [None]:
# Generate and display an animation of the training progress.
make_gif_from_train_plots(treatment, cluster_num, f'example.gif')
display(Markdown('<img src="./code/gifs/example.gif" width="400">'))

<img src="./code/gifs/example.gif" width="400">