# Conditional GLOW for Cell Data Generation

In [22]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns

# Make sure the folder that receives the checkpoint, CSV and figures exists.
os.makedirs("generated", exist_ok=True)

# Seed both random number generators used in this notebook.
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# %% [markdown]
# ## 2. Load and Preprocess

# %%
def load_and_preprocess_data(data_path='input/ML_data.csv', *, test_size=0.2,
                             batch_size=64, random_state=42):
    """Load the cell table, encode labels, standardise features and build loaders.

    The CSV is expected to contain a ``Samples`` column (dropped) and a
    ``Cell_type`` label column; rows labelled ``'Unknown'`` are discarded and
    every remaining column is treated as a numeric feature.

    Args:
        data_path: Path of the CSV file to read.
        test_size: Fraction of rows held out for the validation split.
        batch_size: Batch size used by both returned data loaders.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        A 10-tuple ``(X, y, X_scaled, scaler, le, label_names, feature_names,
        num_classes, train_loader, val_loader)`` where ``X`` / ``X_scaled`` are
        the raw / standardised feature matrices for *all* kept rows, ``y`` the
        integer-encoded labels, ``scaler`` the fitted ``StandardScaler``,
        ``le`` the fitted ``LabelEncoder``, and the loaders iterate over
        tensors that already live on the module-level ``device``.
    """
    df = pd.read_csv(data_path).drop(columns=['Samples'])
    # .copy() detaches the filtered frame from the one it was sliced from, so
    # overwriting the label column below is an unambiguous in-place edit and
    # does not trigger pandas' SettingWithCopyWarning.
    df = df[~df['Cell_type'].isin(['Unknown'])].copy()

    feature_names = df.columns.drop('Cell_type')
    le = LabelEncoder()
    df['Cell_type'] = le.fit_transform(df['Cell_type'])
    label_names = le.classes_
    num_classes = len(label_names)

    y = df['Cell_type'].values
    X = df.drop(columns=['Cell_type']).values
    # NOTE(review): the scaler is fitted on every row before the split, so the
    # validation rows contribute to the mean/variance used for training.
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    X_train, X_val, y_train, y_val = train_test_split(
        X_scaled, y, test_size=test_size, random_state=random_state, stratify=y
    )

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)

    train_loader = DataLoader(
        TensorDataset(X_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(X_val_tensor, y_val_tensor), batch_size=batch_size, shuffle=False
    )

    return X, y, X_scaled, scaler, le, label_names, feature_names, num_classes, train_loader, val_loader

# Run the preprocessing once and keep every artefact as a notebook-level name.
(
    X, y, X_scaled, scaler, le,
    label_names, feature_names, num_classes,
    train_loader, val_loader,
) = load_and_preprocess_data()

# Number of feature columns, i.e. the dimensionality the flow operates on.
input_dim = X.shape[1]
print(f"Input dimension: {input_dim}, Number of classes: {num_classes}")

# %% [markdown]
# ## 3. Model Components

# %%
class MLP(nn.Module):
    """Small fully-connected network: Linear -> LeakyReLU -> LayerNorm ->
    Linear -> LeakyReLU -> Linear.

    Args:
        in_features: Width of the input vectors.
        hidden_features: Width of both hidden layers.
        out_features: Width of the output vectors.
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        # Layers are created in the order they are applied; the stack is kept
        # under the attribute name ``net`` so state_dict keys are ``net.<idx>.*``.
        layers = [
            nn.Linear(in_features, hidden_features),
            nn.LeakyReLU(0.2),
            nn.LayerNorm(hidden_features),
            nn.Linear(hidden_features, hidden_features),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_features, out_features),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        out = self.net(x)
        return out

class ActNorm(nn.Module):
    """Per-feature affine normalisation with data-dependent initialisation.

    ``forward`` computes ``(x + bias) * exp(logs)``. On the first forward call
    ``bias`` and ``logs`` are set from the statistics of that batch so that its
    output has (approximately) zero mean and standard deviation ``scale`` per
    feature.

    ``initialized`` is a plain Python flag, not part of ``state_dict``. To stop
    a freshly constructed module from re-initialising (and overwriting) weights
    that were just restored from a checkpoint, loading a state dict that
    contains both parameters marks the module as initialised.

    Args:
        features: Number of features (size of the last dimension of the input).
        scale: Target per-feature standard deviation after initialisation.
    """

    def __init__(self, features, scale=1.0):
        super().__init__()
        self.initialized = False
        self.scale = scale
        self.register_parameter("bias", nn.Parameter(torch.zeros(features)))
        self.register_parameter("logs", nn.Parameter(torch.zeros(features)))

    def initialize(self, x):
        """Set ``bias``/``logs`` from the batch ``x`` (statistics over dim 0)."""
        with torch.no_grad():
            mean = x.mean(0)
            if x.size(0) > 1:
                std = x.std(0)
            else:
                # The sample standard deviation of a single row is NaN, which
                # would poison ``logs``; fall back to unit scale instead.
                std = torch.ones_like(mean)
            self.bias.data.copy_(-mean)
            self.logs.data.copy_(torch.log(self.scale / (std + 1e-6)))
        self.initialized = True

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Restored parameters must not be clobbered by data-dependent init on
        # the next forward call, so flag the module as initialised whenever
        # both parameters were present in the checkpoint.
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
        if f"{prefix}bias" in state_dict and f"{prefix}logs" in state_dict:
            self.initialized = True

    def forward(self, x):
        """Return ``(z, log_det)`` where ``log_det`` is the scalar ``logs.sum()``."""
        if not self.initialized:
            self.initialize(x)
        z = (x + self.bias) * torch.exp(self.logs)
        return z, self.logs.sum()

    def inverse(self, z):
        """Undo ``forward``: ``z * exp(-logs) - bias``."""
        return (z * torch.exp(-self.logs)) - self.bias

class InvertibleLinear(nn.Module):
    """Learned ``dim x dim`` linear map stored as an LU factorisation.

    A random orthogonal matrix is factorised once into ``P @ L @ U``. ``P`` is
    kept as a fixed buffer; the full ``L`` factor, the diagonal of ``U``
    (``S``) and the strictly upper part of ``U`` are learnable parameters.

    Args:
        dim: Number of features mixed by the layer.
    """

    def __init__(self, dim):
        super().__init__()
        w_init = torch.nn.init.orthogonal_(torch.randn(dim, dim))
        perm, lower, upper = torch.lu_unpack(*w_init.lu())
        self.register_buffer('P', perm)
        self.L = nn.Parameter(lower)
        self.S = nn.Parameter(upper.diag())
        self.U = nn.Parameter(torch.triu(upper, diagonal=1))

    def _assemble_W(self):
        """Rebuild the weight matrix ``P @ L @ U`` from the stored factors."""
        n = self.L.size(0)
        # Only the strictly lower part of L is used; the diagonal is forced to 1.
        unit_lower = self.L.tril(-1) + torch.eye(n, device=self.L.device)
        # Only the strictly upper part of U is used; S supplies the diagonal.
        upper = self.U.triu(1) + torch.diag(self.S)
        return self.P @ unit_lower @ upper

    def forward(self, x):
        """Return ``(x @ W, sum(log|S|))``."""
        weight = self._assemble_W()
        log_abs_det = self.S.abs().log().sum()
        return x @ weight, log_abs_det

    def inverse(self, z):
        """Return ``z @ W^{-1}``."""
        weight_inv = torch.inverse(self._assemble_W())
        return z @ weight_inv

class AffineCoupling(nn.Module):
    """Class-conditional affine coupling layer.

    Features where ``mask == 1`` pass through unchanged and, together with an
    embedding of the class label, drive two MLPs that predict a log-scale and a
    translation for the remaining features (``mask == 0``).

    Args:
        input_dim: Number of features.
        hidden_dim: Hidden width of the scale / translation MLPs.
        mask: Tensor of 0/1 values, broadcastable against ``(batch, input_dim)``.
        num_classes: Number of distinct class labels.
        embed_dim: Width of the label embedding (default 8, the original
            hard-coded value).
    """

    def __init__(self, input_dim, hidden_dim, mask, num_classes, embed_dim=8):
        super().__init__()
        # A non-persistent buffer follows the module through .to()/.cuda()/
        # .double() but is left out of state_dict, so checkpoints written
        # before this change still load with strict=True.
        self.register_buffer("mask", mask, persistent=False)
        self.embedding = nn.Embedding(num_classes, embed_dim)
        self.scale_net = MLP(input_dim + embed_dim, hidden_dim, input_dim)
        self.translate_net = MLP(input_dim + embed_dim, hidden_dim, input_dim)

    def _scale_and_translate(self, passthrough, label):
        """Predict (log-scale, translation), zeroed on the pass-through features."""
        conditioner_in = torch.cat([passthrough, self.embedding(label)], dim=1)
        free = 1 - self.mask
        scale = self.scale_net(conditioner_in) * free
        translate = self.translate_net(conditioner_in) * free
        return scale, translate

    def forward(self, x, y):
        """Transform ``x`` given labels ``y``.

        Returns:
            ``(y_out, log_det)`` where ``log_det`` is the per-sample sum of the
            predicted log-scales.
        """
        x_masked = x * self.mask
        scale, translate = self._scale_and_translate(x_masked, y)
        y_out = x_masked + (1 - self.mask) * (x * torch.exp(scale) + translate)
        return y_out, scale.sum(1)

    def inverse(self, y, label):
        """Invert ``forward`` for outputs ``y`` produced under labels ``label``."""
        # The pass-through features are identical before and after the layer,
        # so the conditioner sees exactly what it saw in ``forward``.
        y_masked = y * self.mask
        scale, translate = self._scale_and_translate(y_masked, label)
        x = y_masked + (1 - self.mask) * (y - translate) * torch.exp(-scale)
        return x

class GlowStep(nn.Module):
    """One flow step: ActNorm, then InvertibleLinear, then AffineCoupling.

    Args:
        input_dim: Number of features.
        hidden_dim: Hidden width passed to the coupling layer.
        mask: Pass-through mask passed to the coupling layer.
        num_classes: Number of class labels passed to the coupling layer.
    """

    def __init__(self, input_dim, hidden_dim, mask, num_classes):
        super().__init__()
        self.actnorm = ActNorm(input_dim)
        self.invconv = InvertibleLinear(input_dim)
        self.coupling = AffineCoupling(input_dim, hidden_dim, mask, num_classes)

    def forward(self, x, y):
        """Run the three sub-layers in order; return output and summed log-dets."""
        h, logdet_norm = self.actnorm(x)
        h, logdet_linear = self.invconv(h)
        h, logdet_coupling = self.coupling(h, y)
        total_logdet = logdet_norm + logdet_linear + logdet_coupling
        return h, total_logdet

    def inverse(self, z, y):
        """Undo ``forward`` by inverting the sub-layers in reverse order."""
        after_coupling = self.coupling.inverse(z, y)
        after_linear = self.invconv.inverse(after_coupling)
        return self.actnorm.inverse(after_linear)

class ConditionalGLOW(nn.Module):
    """Stack of ``GlowStep`` layers conditioned on a class label.

    Step ``i`` uses a mask whose ones start at index ``i % 2`` and repeat every
    second feature, so consecutive steps pass through alternating halves.
    Masks are created on the module-level ``device``.

    Args:
        input_dim: Number of features.
        hidden_dim: Hidden width of the coupling MLPs.
        n_flow_steps: Number of ``GlowStep`` layers.
        num_classes: Number of class labels.
    """

    def __init__(self, input_dim, hidden_dim, n_flow_steps, num_classes):
        super().__init__()
        steps = []
        for step_idx in range(n_flow_steps):
            mask = torch.zeros(input_dim, device=device)
            mask[step_idx % 2::2] = 1
            steps.append(GlowStep(input_dim, hidden_dim, mask, num_classes))
        self.flow_steps = nn.ModuleList(steps)

    def forward(self, x, y):
        """Map data ``x`` to latent space; return ``(z, accumulated log-det)``."""
        h = x
        total_logdet = 0
        for flow in self.flow_steps:
            h, step_logdet = flow(h, y)
            total_logdet = total_logdet + step_logdet
        return h, total_logdet

    def inverse(self, z, y):
        """Map latent ``z`` back to data space by inverting steps last-to-first."""
        h = z
        for idx in range(len(self.flow_steps) - 1, -1, -1):
            h = self.flow_steps[idx].inverse(h, y)
        return h

# %% [markdown]
# ## 4. Training

# %%
def train(model, train_loader, optimizer, num_epochs):
    """Fit a flow by maximum likelihood under a standard-normal latent prior.

    For every batch the model is called as ``model(x, y)`` and must return
    ``(z, log_det)``; the loss is the mean of
    ``-(log N(z; 0, I) + log_det)``. Gradients are clipped to a global norm of
    5.0 before each optimiser step, and the per-epoch average loss is printed.

    Args:
        model: Module returning ``(z, log_det)`` from ``model(x, y)``.
        train_loader: Iterable of ``(x_batch, y_batch)`` pairs.
        optimizer: Optimiser over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        List with one average loss per epoch (earlier versions returned
        ``None``; callers that ignore the result are unaffected).
    """
    model.train()
    log_2pi = float(np.log(2 * np.pi))
    history = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        n_seen = 0
        for x_batch, y_batch in train_loader:
            z, log_det = model(x_batch, y_batch)
            # Use the latent's own width for the Gaussian normaliser rather
            # than a notebook-level global, so the likelihood stays correct
            # for any model dimensionality.
            log_prob = -0.5 * torch.sum(z ** 2, dim=1) - 0.5 * z.size(1) * log_2pi
            loss = -(log_prob + log_det).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            batch_n = x_batch.size(0)
            total_loss += loss.item() * batch_n
            n_seen += batch_n

        # Average over the samples actually processed: this matches
        # len(dataset) for a full pass, stays correct with drop_last, and
        # avoids dividing by zero when the loader is empty.
        avg_loss = total_loss / max(n_seen, 1)
        history.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    return history

# %%
# Model size: hidden width of the coupling MLPs and number of stacked flow steps.
hidden_dim, n_flow_steps = 128, 16

model = ConditionalGLOW(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    n_flow_steps=n_flow_steps,
    num_classes=num_classes,
)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Fit the flow, then persist the learned weights next to the other outputs.
train(model=model, train_loader=train_loader, optimizer=optimizer, num_epochs=150)
torch.save(model.state_dict(), "generated/glow_model.pt")

Using device: cuda
Input dimension: 89, Number of classes: 4
Epoch 1/150, Loss: 168.5968
Epoch 2/150, Loss: 141.0645
Epoch 3/150, Loss: 118.6972
Epoch 4/150, Loss: 105.6923
Epoch 5/150, Loss: 97.3759
Epoch 6/150, Loss: 89.3348
Epoch 7/150, Loss: 84.1397
Epoch 8/150, Loss: 79.4649
Epoch 9/150, Loss: 75.3399
Epoch 10/150, Loss: 71.6551
Epoch 11/150, Loss: 69.3471
Epoch 12/150, Loss: 67.0650
Epoch 13/150, Loss: 65.5853
Epoch 14/150, Loss: 63.2964
Epoch 15/150, Loss: 61.0782
Epoch 16/150, Loss: 59.3868
Epoch 17/150, Loss: 57.0311
Epoch 18/150, Loss: 56.3959
Epoch 19/150, Loss: 55.7575
Epoch 20/150, Loss: 54.3217
Epoch 21/150, Loss: 52.5229
Epoch 22/150, Loss: 52.0908
Epoch 23/150, Loss: 49.9148
Epoch 24/150, Loss: 49.0820
Epoch 25/150, Loss: 48.5599
Epoch 26/150, Loss: 48.0411
Epoch 27/150, Loss: 46.0734
Epoch 28/150, Loss: 45.2139
Epoch 29/150, Loss: 46.0345
Epoch 30/150, Loss: 44.9313
Epoch 31/150, Loss: 42.0675
Epoch 32/150, Loss: 41.6909
Epoch 33/150, Loss: 40.6379
Epoch 34/150, Loss: 

# 5. Generate and Visualize

In [None]:
def generate_synthetic_data(model, class_id, num_samples):
    """Sample ``num_samples`` points of class ``class_id`` from the flow.

    Latents are drawn from a standard normal of width ``input_dim`` (module
    global), moved to the module-level ``device`` and pushed through
    ``model.inverse`` without gradient tracking.

    Returns:
        NumPy array holding the generated samples.
    """
    noise = torch.randn(num_samples, input_dim).to(device)
    labels = torch.full((num_samples,), class_id, dtype=torch.long).to(device)
    with torch.no_grad():
        generated = model.inverse(noise, labels)
    return generated.cpu().numpy()

def _sample_scaled_class_for_plot(model, class_id, num_samples, n_features):
    """Draw ``num_samples`` flow samples for one class (no gradient tracking)."""
    model_device = next(model.parameters()).device
    z = torch.randn(num_samples, n_features).to(model_device)
    labels = torch.full((num_samples,), class_id, dtype=torch.long).to(model_device)
    with torch.no_grad():
        return model.inverse(z, labels).cpu().numpy()


def plot_real_vs_fake_per_class(X_scaled, y, scaler, model, label_names):
    """Save and show one 2-D PCA scatter per class comparing real and generated data.

    For each class, as many samples as there are real rows are generated, a
    2-component PCA is fitted on the stacked real + generated rows, and the
    projection is written to ``generated/pca_real_vs_fake_<class>.png``.

    Sampling is done by a private helper rather than the notebook-level
    ``generate_synthetic_data``: that name is redefined further down with a
    different signature, which would break this function if it were re-run
    after the later cell.

    Args:
        X_scaled: Standardised feature matrix, one row per sample.
        y: Integer class id per row of ``X_scaled``.
        scaler: Unused; kept so existing callers do not need to change.
        model: Flow exposing ``inverse(z, labels)``.
        label_names: Class names indexed by class id.
    """
    for cls_id, cls_name in enumerate(label_names):
        real = X_scaled[y == cls_id]
        if len(real) == 0:
            # Nothing to compare against (and PCA cannot be fitted).
            continue
        fake = _sample_scaled_class_for_plot(model, cls_id, len(real), real.shape[1])

        pca = PCA(n_components=2)
        combined = np.vstack([real, fake])
        pca_result = pca.fit_transform(combined)
        real_pca = pca_result[:len(real)]
        fake_pca = pca_result[len(real):]

        fig = plt.figure(figsize=(6, 5))
        plt.scatter(real_pca[:, 0], real_pca[:, 1], label='Real', alpha=0.6)
        plt.scatter(fake_pca[:, 0], fake_pca[:, 1], label='Fake', alpha=0.6)
        plt.title(f'PCA: Real vs. Fake ({cls_name})')
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"generated/pca_real_vs_fake_{cls_name}.png")
        plt.show()
        plt.close(fig)

# Draw the real-vs-generated PCA scatter for every class.
plot_real_vs_fake_per_class(
    X_scaled=X_scaled,
    y=y,
    scaler=scaler,
    model=model,
    label_names=label_names,
)


In [None]:
def generate_synthetic_data(model, scaler, le, n_per_class=500):
    """Generate ``n_per_class`` samples per class and save them as a CSV.

    Samples are drawn in standardised space, mapped back to the original
    feature scale with ``scaler.inverse_transform``, labelled with the class
    names from ``le`` and written to ``generated/fake_data.csv``. The model is
    switched to eval mode and left there.

    NOTE: this definition replaces an earlier notebook function of the same
    name that takes ``(model, class_id, num_samples)``.

    Args:
        model: Flow exposing ``inverse(z, labels)``.
        scaler: Fitted scaler used to undo the feature standardisation.
        le: Fitted label encoder used to turn class ids back into names.
        n_per_class: Number of samples generated for each class.

    Returns:
        DataFrame with one column per feature plus ``Cell_type``.
    """
    model.eval()
    generated = []
    # Sampling needs no gradients; no_grad avoids building (and then
    # discarding) an autograd graph through every flow step.
    with torch.no_grad():
        for class_idx in range(num_classes):
            z = torch.randn(n_per_class, input_dim).to(device)
            y = torch.full((n_per_class,), class_idx, dtype=torch.long).to(device)
            generated.append(model.inverse(z, y).cpu().numpy())

    X_fake = np.vstack(generated)
    # Class ids in the same block order as the rows stacked above.
    y_fake = np.repeat(np.arange(num_classes), n_per_class)
    X_fake_rescaled = scaler.inverse_transform(X_fake)
    df_fake = pd.DataFrame(X_fake_rescaled, columns=feature_names)
    df_fake["Cell_type"] = le.inverse_transform(y_fake)
    df_fake.to_csv("generated/fake_data.csv", index=False)
    print("✅ Saved synthetic data to 'generated/fake_data.csv'")
    return df_fake

# Sample the default number of rows per class and write them to disk.
df_fake = generate_synthetic_data(model=model, scaler=scaler, le=le)

# %% [markdown]
# ## 7. Visualization: PCA Per Class + Feature Distribution

def plot_pca_per_class(X_real, y_real, X_fake, y_fake, class_names, scaler, label_encoder):
    """Save one PCA scatter per class overlaying real and generated rows.

    For every class index that has at least one real and one generated row, a
    2-component PCA is fitted on the two sets stacked together and the
    projection is written to ``generated/pca_<class_name>.png``.

    ``scaler`` and ``label_encoder`` are accepted but not used.
    """
    projector = PCA(n_components=2)
    for class_idx, class_name in enumerate(class_names):
        real_rows = X_real[y_real == class_idx]
        fake_rows = X_fake[y_fake == class_idx]

        # Skip classes that are missing from either set.
        if len(real_rows) == 0 or len(fake_rows) == 0:
            continue

        n_real = len(real_rows)
        projected = projector.fit_transform(np.vstack([real_rows, fake_rows]))
        series = (
            (projected[:n_real], "Real", 'blue'),
            (projected[n_real:], "Fake", 'orange'),
        )

        plt.figure(figsize=(14, 10))
        for points, tag, colour in series:
            plt.scatter(points[:, 0], points[:, 1], alpha=0.6, label=tag, c=colour)
        plt.title(f"PCA: {class_name}")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"generated/pca_{class_name}.png")
        plt.close()

def plot_feature_distributions_grid(df_real, df_fake, features, max_features=12, n_cols=4):
    """Save a grid of KDE plots comparing real and generated feature distributions.

    The first ``max_features`` entries of ``features`` are plotted, one subplot
    each, in a grid ``n_cols`` wide; unused cells are hidden. The figure is
    written to ``generated/feature_distributions_grid.png`` and closed.

    Args:
        df_real: DataFrame of real samples containing the feature columns.
        df_fake: DataFrame of generated samples containing the same columns.
        features: Ordered collection of feature column names.
        max_features: Maximum number of features to plot.
        n_cols: Number of subplot columns.
    """
    selected_features = list(features[:max_features])
    n_feats = len(selected_features)
    if n_feats == 0:
        # plt.subplots cannot build a grid with zero rows.
        print("No features selected; skipping feature distribution grid.")
        return

    n_rows = (n_feats + n_cols - 1) // n_cols
    # plt.subplots creates its own figure, so no separate plt.figure() call is
    # made (the previous one left an empty, never-closed figure behind).
    # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), squeeze=False
    )
    axes = axes.ravel()

    for ax, feature in zip(axes, selected_features):
        sns.kdeplot(data=df_real, x=feature, label="Real", fill=True, color="blue", alpha=0.5, ax=ax)
        sns.kdeplot(data=df_fake, x=feature, label="Fake", fill=True, color="orange", alpha=0.5, ax=ax)
        ax.set_title(feature)
        ax.legend()

    # Hide any empty subplots
    for ax in axes[n_feats:]:
        ax.axis('off')

    fig.tight_layout()
    fig.savefig("generated/feature_distributions_grid.png")
    plt.close(fig)
    print("✅ Feature distribution grid saved as 'generated/feature_distributions_grid.png'")



def plot_feature_distributions(X_real_df, X_fake_df, features, class_names):
    """Save one KDE comparison figure per feature.

    Each figure overlays the real and generated distribution of a single
    feature and is written to ``generated/feature_dist_<feature>.png``.
    ``class_names`` is accepted but not used.
    """
    # Real is drawn first, then Fake, each with its own legend label and colour.
    sources = (
        (X_real_df, "Real", "blue"),
        (X_fake_df, "Fake", "orange"),
    )
    for feature in features:
        plt.figure(figsize=(14, 10))
        for frame, tag, colour in sources:
            sns.kdeplot(data=frame, x=feature, label=tag, fill=True, color=colour, alpha=0.5)
        plt.title(f"Feature Distribution: {feature}")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"generated/feature_dist_{feature}.png")
        plt.close()

# Prepare for plotting: real data back on its original scale, with class names.
df_real = pd.DataFrame(scaler.inverse_transform(X_scaled), columns=feature_names)
df_real["Cell_type"] = le.inverse_transform(y)

# Reload the generated rows from disk and bring them into standardised space.
df_fake = pd.read_csv("generated/fake_data.csv")
# The scaler was fitted on a plain ndarray, so hand it an ndarray here as well;
# passing the DataFrame makes scikit-learn warn about unexpected feature names.
X_fake = scaler.transform(df_fake[feature_names].values)
y_fake = le.transform(df_fake["Cell_type"])

plot_pca_per_class(X_scaled, y, X_fake, y_fake, label_names, scaler, le)
#plot_feature_distributions(df_real, df_fake, feature_names, label_names)
plot_feature_distributions_grid(df_real, df_fake, feature_names, max_features=12, n_cols=4)

✅ Saved synthetic data to 'generated/fake_data.csv'




✅ Feature distribution grid saved as 'generated/feature_distributions_grid.png'


<Figure size 1400x1000 with 0 Axes>