## Toy project (IG)

### Step 1. Module import

In [None]:
#!/usr/bin/env python

import numpy as np, pandas as pd, matplotlib.pyplot as plt
from copy import deepcopy
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
import random

### Step 2. Set configurations

In [None]:
SEED = 42

# data parameter
N_SAMPLES=4000
N_FEATURES=1000
N_INFORM=80
N_REDUN=40

# AE
Z_DIM=64
AE_HIDDEN=256
AE_EPOCHS=100
AE_BATCH=256
AE_LR=1e-4
AE_patience=20
AE_stale=0

# Transformer shape
TOKENS=64
D_MODEL=256

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [21]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Step 3. Dataset preparation and preprocessing

In [22]:
# Make pseudo DEG dataset
X, y = make_classification(
    n_samples=N_SAMPLES, n_features=N_FEATURES,
    n_informative=N_INFORM, n_redundant=N_REDUN, n_repeated=0, # n_redundant: mimicking colinear gene expression
    n_classes=2, class_sep=2.2, flip_y=0.01, random_state=SEED
) # flip the label to prevent overfitting

X = X.astype(np.float32)
y = y.astype(np.int64)

# Train (0.6), val (0.1), test(0.2) split
X_trainval, X_test, y_trainval, y_test = train_test_split(
    X, y, test_size=0.20, stratify=y, random_state=SEED
)

X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval, test_size=0.1/0.8, stratify=y_trainval, random_state=SEED
)

# Scaling
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train).astype(np.float32)
X_val = scaler.transform(X_val).astype(np.float32)
X_test = scaler.transform(X_test).astype(np.float32)

# Load data
train_loader = DataLoader(TensorDataset(torch.from_numpy(X_train)), batch_size=AE_BATCH, shuffle=True, drop_last=False)
val_loader = DataLoader(TensorDataset(torch.from_numpy(X_val)), batch_size=AE_BATCH, shuffle=True, drop_last=False)
test_loader = DataLoader(TensorDataset(torch.from_numpy(X_test)), batch_size=AE_BATCH, shuffle=True, drop_last=False)

In [23]:
print(X.shape)

(4000, 1000)


In [None]:
# Check dataset imbalance
pos = int((y==1).sum())
neg = int((y==0).sum())
n = len(y)

pos_ratio = pos / n
imbalance_ratio = neg / max(pos, 1)
print(f"pos={pos}, neg={neg}, n={n}, pos_ratio={pos_ratio:.4f}, neg/pos={imbalance_ratio:.2f}")

pos=1998, neg=2002, n=4000, pos_ratio=0.4995, neg/pos=1.00


### Step 4. Autoencoder

In [25]:
# Autoencoder model structure
class AE(nn.Module):
    def __init__(self, in_dim, z_dim, hidden):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, z_dim))
        self.dec = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(), nn.Linear(hidden, in_dim))
    def forward(self, x):
        z = self.enc(x)
        x_hat = self.dec(z)
        return x_hat, z

ae = AE(in_dim=N_FEATURES, z_dim=Z_DIM, hidden=AE_HIDDEN).to(device)
opt = torch.optim.Adam(ae.parameters(), lr=AE_LR)
crit = nn.MSELoss()

best_val = float('inf')
best_state = None

# Train autoencoder
for ep in range(1, AE_EPOCHS+1):
    # train
    ae.train()
    total=0.0
    n=0
    for (xb,) in train_loader:
        xb = xb.to(device)
        opt.zero_grad()
        xhat, _ = ae(xb)
        loss = crit(xhat, xb)
        loss.backward()
        opt.step()
        total += loss.item()*xb.size(0)
        n += xb.size(0)
    train_loss = total / n

    # validate
    ae.eval()
    with torch.no_grad():
        vt = 0.0
        vn = 0
        for (xb,) in val_loader:
            xb = xb.to(device)
            x_hat, _ = ae(xb)
            vloss = crit(x_hat, xb)
            vt += vloss.item() * xb.size(0)
            vn += xb.size(0)
        val_loss = vt / vn

    if ep%5==0 or ep==1 or ep==AE_EPOCHS:
        print(f"[AE] {ep:03d}/{AE_EPOCHS} train_recon={train_loss:.4f} val_recon={val_loss:.4f}")

    # early stopping
    if val_loss + 1e-6 < best_val:
        best_val = val_loss
        best_state = deepcopy(ae.state_dict())
        AE_stale = 0
    else:
        AE_stale += 1
        if AE_stale >= AE_patience:
            print(f"Early stopping at epoch {ep} (best val {best_val:.4f})")
            break

# Best state
if best_state is not None:
    ae.load_state_dict(best_state)

# Test reconstruction loss
ae.eval()
with torch.no_grad():
    tt = 0.0
    tn = 0
    for (xb,) in test_loader:
        xb = xb.to(device)
        x_hat, _ = ae(xb)
        tloss = crit(x_hat, xb)
        tt += tloss.item() * xb.size(0)
        tn += xb.size(0)
    test_loss = tt / tn

print(f"[AE] Test reconstruction={test_loss:.4f}")



[AE] 001/100 train_recon=1.0041 val_recon=1.0038
[AE] 005/100 train_recon=0.9964 val_recon=0.9992
[AE] 010/100 train_recon=0.9874 val_recon=0.9932
[AE] 015/100 train_recon=0.9715 val_recon=0.9816
[AE] 020/100 train_recon=0.9499 val_recon=0.9668
[AE] 025/100 train_recon=0.9320 val_recon=0.9559
[AE] 030/100 train_recon=0.9182 val_recon=0.9479
[AE] 035/100 train_recon=0.9072 val_recon=0.9418
[AE] 040/100 train_recon=0.8981 val_recon=0.9370
[AE] 045/100 train_recon=0.8906 val_recon=0.9333
[AE] 050/100 train_recon=0.8842 val_recon=0.9304
[AE] 055/100 train_recon=0.8788 val_recon=0.9283
[AE] 060/100 train_recon=0.8742 val_recon=0.9266
[AE] 065/100 train_recon=0.8702 val_recon=0.9253
[AE] 070/100 train_recon=0.8668 val_recon=0.9243
[AE] 075/100 train_recon=0.8637 val_recon=0.9237
[AE] 080/100 train_recon=0.8609 val_recon=0.9232
[AE] 085/100 train_recon=0.8585 val_recon=0.9229
[AE] 090/100 train_recon=0.8562 val_recon=0.9228
[AE] 095/100 train_recon=0.8541 val_recon=0.9227
[AE] 100/100 train_r

### Step 4-1. Dataset preparation

In [26]:
def encode_numpy(ae, X_np, batch_size=1024):
    ae.eval()
    Z = []
    with torch.no_grad():
        for i in range(0, len(X_np), batch_size):
            xb = torch.from_numpy(X_np[i:i+batch_size]).to(device)
            _, z = ae(xb)
            Z.append(z.cpu().numpy().astype(np.float32))
    return np.concatenate(Z, axis=0)

Z_train = encode_numpy(ae, X_train)
Z_val = encode_numpy(ae, X_val)
Z_test = encode_numpy(ae, X_test)

In [27]:
print(y_train)
print(len(y_train))

[1 1 1 ... 1 1 0]
2800


### Step 5. Model Training

#### 5-1. Application to transformer

In [None]:
# Set Transformer architecture for 
class TransformerHead(nn.Module):
    def __init__(self, z_dim, tokens, d_model, nhead=8, dim_ff=256, dropout=0.1):
        super().__init__()
        self.tokens, self.d_model = tokens, d_model
        self.embedding = nn.Linear(1, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
            dropout=dropout, activation="gelu", batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=4)
        self.norm = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, 1)
    def forward(self, z):
        B, latent_dim = z.shape
        z_tokens = z.unsqueeze(-1)  # (B, tokens=latent_dim, 1)
        x = self.embedding(z_tokens)    # (B, tokens, d_model)
        h = self.encoder(x) # (B, tokens, d_model)
        h = self.norm(h).mean(1)    # (B, d_model)
        base = self.out(h)  # (B, 1)
        # simple interaction on two coords to induce rare bursts
        return base # (B, 1)

In [29]:
import torch.optim as optim
from sklearn.metrics import accuracy_score, roc_auc_score

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 64

train_dataset = TensorDataset(torch.from_numpy(Z_train).float(), torch.from_numpy(y_train).float())
val_dataset   = TensorDataset(torch.from_numpy(Z_val).float(), torch.from_numpy(y_val).float())
test_dataset = TensorDataset(torch.from_numpy(Z_test).float(), torch.from_numpy(y_test).float())

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Set optimizer and loss function
latent_dim = Z_train.shape[1]
TFmodel = TransformerHead(z_dim=latent_dim, tokens= latent_dim, d_model=128, nhead=8).to(device)
optimizer = optim.Adam(TFmodel.parameters(), lr=5e-3)
criterion = nn.BCEWithLogitsLoss()

# Training
num_epochs = 50

for epoch in range(num_epochs):
    TFmodel.train()
    train_loss = 0.0
    all_preds_train = []
    all_labels_train = []

    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        y_pred = TFmodel(xb)
        yb = yb.unsqueeze(1)
        loss = criterion(y_pred, yb)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * xb.size(0)

        prob = torch.sigmoid(y_pred)
        all_preds_train.append(prob.cpu())
        all_labels_train.append(yb.cpu())

    train_loss /= len(train_loader.dataset)
    all_preds_train = torch.cat(all_preds_train).detach().numpy()
    all_labels_train = torch.cat(all_labels_train).detach().numpy()
    train_acc = accuracy_score(all_labels_train, all_preds_train > 0.5)
    train_auc = roc_auc_score(all_labels_train, all_preds_train)

    # validation
    TFmodel.eval()
    val_loss = 0.0
    all_preds_val = []
    all_labels_val = []

    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            y_pred = TFmodel(xb)
            yb = yb.unsqueeze(1)
            loss = criterion(y_pred, yb)
            val_loss += loss.item() * xb.size(0)

            prob = torch.sigmoid(y_pred)
            all_preds_val.append(prob.cpu())
            all_labels_val.append(yb.cpu())

    val_loss /= len(val_loader.dataset)
    all_preds_val = torch.cat(all_preds_val).detach().numpy()
    all_labels_val = torch.cat(all_labels_val).detach().numpy()
    val_acc = accuracy_score(all_labels_val, all_preds_val > 0.5)
    val_auc = roc_auc_score(all_labels_val, all_preds_val)

    print(f"Epoch {epoch+1}/{num_epochs} |"
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | train AUROC: {train_auc:.4f} |"
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | val AUROC: {val_auc:.4f}")

# Test evaluation metrics
TFmodel.eval()
all_preds_test = []
all_labels_test = []

with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        y_pred = TFmodel(xb)
        yb = yb.unsqueeze(1)
        prob = torch.sigmoid(y_pred)
        all_preds_test.append(prob.cpu())
        all_labels_test.append(yb.cpu())

all_preds_test = torch.cat(all_preds_test).detach().numpy()
all_labels_test = torch.cat(all_labels_test).detach().numpy()
test_acc = accuracy_score(all_labels_test, all_preds_test > 0.5)
test_auc = roc_auc_score(all_labels_test, all_preds_test)

print(f"Test Accuracy: {test_acc:.4f} | Test AUROC: {test_auc:.4f}")



Epoch 1/50 |Train Loss: 0.7415 | Train Acc: 0.5029 | train AUROC: 0.5043 |Val Loss: 0.7372 | Val Acc: 0.5000 | val AUROC: 0.4752
Epoch 2/50 |Train Loss: 0.6987 | Train Acc: 0.4968 | train AUROC: 0.4896 |Val Loss: 0.6937 | Val Acc: 0.5000 | val AUROC: 0.5078
Epoch 3/50 |Train Loss: 0.6979 | Train Acc: 0.5100 | train AUROC: 0.5070 |Val Loss: 0.6932 | Val Acc: 0.5000 | val AUROC: 0.5241
Epoch 4/50 |Train Loss: 0.6956 | Train Acc: 0.4893 | train AUROC: 0.4869 |Val Loss: 0.6952 | Val Acc: 0.5000 | val AUROC: 0.5338
Epoch 5/50 |Train Loss: 0.6955 | Train Acc: 0.4957 | train AUROC: 0.4919 |Val Loss: 0.6930 | Val Acc: 0.4975 | val AUROC: 0.5432
Epoch 6/50 |Train Loss: 0.6996 | Train Acc: 0.4893 | train AUROC: 0.4926 |Val Loss: 0.6936 | Val Acc: 0.5000 | val AUROC: 0.5411
Epoch 7/50 |Train Loss: 0.7007 | Train Acc: 0.5000 | train AUROC: 0.4949 |Val Loss: 0.6949 | Val Acc: 0.5000 | val AUROC: 0.5417
Epoch 8/50 |Train Loss: 0.6977 | Train Acc: 0.4800 | train AUROC: 0.4773 |Val Loss: 0.6957 | Val 

#### 5-2. Application of IG

In [None]:
# model rapper
import numpy as np
from scipy.stats import spearmanr, pearsonr
import torch.nn as nn
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

ae.to(device).eval()
TFmodel.to(device).eval()

class XtoZ(nn.Module):
    def __init__(self, ae):
        super().__init__()
        self.ae = ae
    def forward(self, x):
        z = self.ae.enc(x)
        return z

class ZtoY(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.transformer = model
    def forward(self, z):
        return self.transformer(z)  # (B,1)

class XtoY(nn.Module):
    def __init__(self, ae, model):
        super().__init__()
        self.ae = ae
        self.transformer = model
    def forward(self, x):
        z = self.ae.enc(x)
        return self.transformer(z)

x_to_z = XtoZ(ae).to(device).eval()
z_to_y = ZtoY(TFmodel).to(device).eval()
x_to_y = XtoY(ae, TFmodel).to(device).eval()

cpu


In [None]:
# Tensor data
X_train_t = torch.from_numpy(X_train.astype(np.float32)).to(device)
X_test_t  = torch.from_numpy(X_test.astype(np.float32)).to(device)

n_test, n_feat = X_test_t.shape
print(f"n_test={n_test}, n_feat={n_feat}")

# baseline setting
baseline_vec = X_train_t.mean(0)    # (D,)

n_test=800, n_feat=1000


In [None]:
# Integrated Gradients
def integrated_gradients(x_single, *, steps=128):
    x_to_y.eval()
    x = x_single.to(device).float()
    baseline = baseline_vec.to(device).float()

    # (steps+1, 1) alpha
    alphas = torch.linspace(0.0, 1.0, steps+1, device=device).view(-1, 1)
    # (steps+1, D)
    x_interp = baseline.unsqueeze(0) + alphas * (x.unsqueeze(0) - baseline.unsqueeze(0))
    x_interp.requires_grad_(True)

    # forward
    y = x_to_y(x_interp).squeeze(-1)    # (steps+1,)
    y_sum = y.sum()

    # backward
    (grads,) = torch.autograd.grad(
        y_sum, x_interp,
        retain_graph=False,
        create_graph=False
    )   # (steps+1, D)

    avg_grad = grads.mean(dim=0)    # (D,)
    ig = (x - baseline) * avg_grad  # (D,)

    return ig.detach().cpu().numpy().astype(np.float32)

In [None]:
# utils
def compute_ig_matrix(X_tensor, *, steps=128, idx_list=None):
    if idx_list is None:
        idx_list = np.arange(X_tensor.shape[0])

    ig_list = []
    for i in idx_list:
        x_single = X_tensor[i]  # (D,)
        ig = integrated_gradients(x_single, steps=steps)    # (D,)
        ig_list.append(ig)

    ig_mat = np.stack(ig_list, axis=0)  # (N_used, D)
    return ig_mat

In [34]:
# reference IG vs approximation IG
steps_ref    = 256
steps_approx = 16

idx_all = np.arange(n_test)

ig_ref = compute_ig_matrix(X_test_t, steps=steps_ref, idx_list=idx_all)
print("ref IG shape:", ig_ref.shape)

ig_apx = compute_ig_matrix(X_test_t, steps=steps_approx, idx_list=idx_all)
print("approx IG shape:", ig_apx.shape)

ref IG shape: (800, 1000)
approx IG shape: (800, 1000)


In [35]:
# Correlation function
def prep(v, use_abs=True, eps=1e-12):
    v = np.asarray(v).ravel()
    if use_abs:
        v = np.abs(v)
    return np.clip(v, eps, None)

def corr_report(ref, apx, desc="", top_frac=0.1):
    ref_flat = prep(ref)
    apx_flat = prep(apx)

    # global
    sp_global = spearmanr(ref_flat, apx_flat).correlation
    pe_global = pearsonr(ref_flat, apx_flat)[0]

    # global gene importance
    ref_gene = prep(np.mean(np.abs(ref), axis=0))
    apx_gene = prep(np.mean(np.abs(apx), axis=0))

    # top-k mask
    kq   = 1.0 - top_frac
    thr  = np.quantile(ref_gene, kq)
    mask = ref_gene >= thr

    sp_gene_all = spearmanr(ref_gene, apx_gene).correlation
    pe_gene_all = pearsonr(ref_gene, apx_gene)[0]

    sp_gene_top = spearmanr(ref_gene[mask], apx_gene[mask]).correlation
    pe_gene_top = pearsonr(ref_gene[mask], apx_gene[mask])[0]

    print()
    print(f"=== {desc} ===")
    print(f"[GLOBAL (all entries)]  Spearman={sp_global:.3f}  Pearson={pe_global:.3f}")
    print(f"[GENE MEAN (all gene)] Spearman={sp_gene_all:.3f}  Pearson={pe_gene_all:.3f}")
    print(f"[GENE MEAN (top {int(top_frac*100)}% gene by ref)] "
          f"Spearman={sp_gene_top:.3f}  Pearson={pe_gene_top:.3f}")
    print()

In [36]:
# Report
corr_report(ig_ref, ig_apx,
            desc=f"Ref IG (steps={steps_ref}) vs Approx IG (steps={steps_approx})",
            top_frac=0.1)


=== Ref IG (steps=256) vs Approx IG (steps=16) ===
[GLOBAL (all entries)]  Spearman=0.998  Pearson=0.999
[GENE MEAN (all gene)] Spearman=1.000  Pearson=1.000
[GENE MEAN (top 10% gene by ref)] Spearman=0.991  Pearson=0.997

