In [None]:
import sys
sys.path.insert(0, "../../src")

import time
import json
import copy
import torch
import pickle
import logging
import warnings
import matplotlib
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from pathlib import Path
from gears import PertData
from transformers import set_seed
from typing import List, Dict, Optional
from torch_geometric.loader import DataLoader
from gears.utils import create_cell_graph_dataset_for_prediction
from gears.inference import deeper_analysis, non_dropout_analysis
from stella.models.modeling_stella import STELLAForPerturbation
from stella.utils import map_raw_id_to_vocab_id, compute_perturbation_metrics

# Notebook-wide setup: silence all warnings, keep saved figures opaque,
# log INFO and above with timestamps, and fix the random seed for reproducibility.
warnings.filterwarnings("ignore")
matplotlib.rcParams.update({"savefig.transparent": False})
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
set_seed(42)

### 1. Config

In [None]:
# settings for data processing
# "all": feed every gene; "batch-wise": only genes nonzero in at least one cell of the batch
include_zero_gene = "all"
# during training, cells with more selected genes than this are randomly subsampled
# (perturbed genes are always kept)
max_seq_len = 1536

# settings for optimizer
lr = 1e-4
batch_size = 12
eval_batch_size = 12
epochs = 10
schedule_interval = 1  # StepLR step size; the scheduler is stepped once per epoch
early_stop = 3  # stop after this many consecutive epochs without validation improvement
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# logging steps for training
log_interval = 200

# Here we use the same dataset as GEARS; you can also use your own dataset.
# !!! If you want to use your own dataset, please process it first !!!
data_name = "norman"  # dixit, norman, adamson
split = "simulation"

# visualization
if data_name == "norman":
    perts_to_plot = ["SAMD1+ZBTB1"]
elif data_name == "adamson":
    perts_to_plot = ["KCTD16+ctrl"]
else:
    # no default example for other datasets (e.g. dixit): add condition names here to plot them
    perts_to_plot = []

# save
save_dir = Path(f"./save/dev_perturb_{data_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"saving to {save_dir}")
logging.info(f"Running on {time.strftime('%Y-%m-%d %H:%M:%S')}")

### 2. Load Dataset and DataLoader

In [None]:
# GEARS-formatted data directory, relative to this notebook like the other
# ../../src and ../../pretrained_models paths (assumes the notebook runs from
# tutorials/03_Perturbation_Prediction — adjust if your data lives elsewhere).
data_dir = Path("./data")
pert_data = PertData(str(data_dir))
pert_data.load(data_name=data_name)
pert_data.prepare_split(split=split, seed=1)
pert_data.get_dataloader(batch_size=batch_size, test_batch_size=eval_batch_size)

### 3. Convert gene symbol to id according to STELLA vocabulary

In [None]:
# Load STELLA's gene-symbol -> token-id vocabulary.
# NOTE(review): pickle.load can execute arbitrary code; only load this file from a trusted source.
with open("../../src/stella/gene2id.pkl", "rb") as f:
    vocab = pickle.load(f)

ntokens = len(vocab)
print(ntokens)

# Gene symbols of the dataset, in the column order of pert_data.adata.
genes = pert_data.adata.var["gene_name"].tolist()
n_genes = len(genes)

# Report how many dataset genes have an entry in the STELLA vocabulary.
num_genes_in_vocab = pert_data.adata.var["gene_name"].isin(vocab.keys()).sum()
logging.info(f"Number of genes in stella vocab: {num_genes_in_vocab} / {n_genes}")

# Vocab id per gene symbol; symbols missing from the vocab are represented by the [PAD] id.
gene_ids = np.array(
    [vocab[symbol] if symbol in vocab else vocab["[PAD]"] for symbol in genes],
    dtype=int,
)

### 4. Load Pretrained Model and Freeze Some Layers

In [None]:
# Load the pretrained STELLA checkpoint as a STELLAForPerturbation model,
# configured with input_gene_expr_type="continuous".
pretrained_dir = "../../pretrained_models/B100_L2048"
model = STELLAForPerturbation.from_pretrained(pretrained_dir, input_gene_expr_type="continuous")

def freeze_first_k_layers(k=4, module: Optional[nn.Module] = None):
    """
    Freeze (set ``requires_grad = False``) every parameter of encoder layers 0 .. k-1.

    Args:
        k: number of leading ``stella.encoder.layer.*`` blocks to freeze; 0 freezes nothing.
        module: model to modify; defaults to the module-level ``model``.
    """
    target = model if module is None else module
    # Include the trailing dot: "stella.encoder.layer.1" alone is also a substring of
    # "stella.encoder.layer.10..." .. "layer.19...", which would freeze the wrong layers.
    prefixes = tuple(f"stella.encoder.layer.{i}." for i in range(k))
    for name, param in target.named_parameters():
        if any(prefix in name for prefix in prefixes):
            param.requires_grad = False

# freeze the first k encoder layers (k=0 keeps every layer trainable)
freeze_first_k_layers(k=0)

# check the trainable status of the parameters
for name, params in model.named_parameters():
    print(name, "\t", params.requires_grad)

model = model.to(device)

# Only hand trainable parameters to the optimizer, so frozen layers are not
# registered with (or tracked by) Adam.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=lr)
# Multiply the learning rate by 0.9 every `schedule_interval` scheduler steps
# (the training loop calls scheduler.step() once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, schedule_interval, gamma=0.9)

### 5. Define the functions needed for training, validation and prediction

In [None]:
def process_loader(batch_data, split_input_gene: bool, max_seq_len: int = None) -> tuple:
    """
    Turn a GEARS graph batch into STELLA model inputs.

    Reads the module-level ``device``, ``n_genes``, ``gene_ids`` and
    ``include_zero_gene`` settings.

    Args:
        batch_data: a torch_geometric batch. ``batch_data.x`` has
            ``batch_size * n_genes`` rows; column 0 is used as the expression value
            and column 1 as the perturbation flag. ``batch_data.y`` (may be None)
            is viewed as ``(batch_size, n_genes)``.
        split_input_gene: if True and more than ``max_seq_len`` genes are selected,
            keep all perturbed genes (placed first) plus a random subset of the
            remaining genes.
        max_seq_len: gene budget used when ``split_input_gene`` is True;
            ``None`` disables subsampling.

    Returns:
        ``(mapped_input_gene_ids, input_values, input_pert_flags, target_values)``:
        vocab ids repeated per cell (int32), expression values (float32,
        ``(batch_size, seq_len)``), perturbation flags (int32, ``(batch_size, seq_len)``)
        and target values (float32, ``(batch_size, seq_len)``, or ``None`` when the
        batch has no ``y``).

    Raises:
        ValueError: if ``include_zero_gene`` is neither "all" nor "batch-wise".
    """
    if include_zero_gene not in ("all", "batch-wise"):
        raise ValueError(
            f"include_zero_gene must be 'all' or 'batch-wise', got {include_zero_gene!r}"
        )

    batch_data = batch_data.to(device)
    batch_size = len(batch_data.pert)

    ori_gene_values = batch_data.x[:, 0].view(batch_size, n_genes)
    pert_flags = batch_data.x[:, 1].view(batch_size, n_genes)

    if include_zero_gene == "all":
        input_gene_ids = torch.arange(n_genes, device=device, dtype=torch.long)
    else:
        # sorted union of genes that are nonzero in at least one cell of the batch
        input_gene_ids = ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]

    if split_input_gene and max_seq_len is not None and len(input_gene_ids) > max_seq_len:
        # Perturbed genes are always kept and placed first; the rest of the budget
        # is filled with a random subset of the other genes.
        exclude = pert_flags.nonzero()[:, -1].unique()
        available = input_gene_ids[~torch.isin(input_gene_ids, exclude)]
        # Clamp at 0: a negative slice bound would keep almost every gene instead.
        n_random = max(max_seq_len - len(exclude), 0)
        sampled = available[torch.randperm(len(available), device=device)[:n_random]]
        input_gene_ids = torch.cat([exclude, sampled])

    mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
    mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1).to(dtype=torch.int32)
    input_values = ori_gene_values[:, input_gene_ids].to(dtype=torch.float32)
    input_pert_flags = pert_flags[:, input_gene_ids].to(dtype=torch.int32)

    if batch_data.y is not None:
        target_gene_values = batch_data.y.view(batch_size, n_genes)
        target_values = target_gene_values[:, input_gene_ids].to(dtype=torch.float32)
    else:
        target_values = None

    return mapped_input_gene_ids, input_values, input_pert_flags, target_values
    
        


def train(model: nn.Module, train_loader: torch.utils.data.DataLoader) -> None:
    """
    Train ``model`` for one epoch with automatic mixed precision.

    Relies on module-level names: ``optimizer``, ``scheduler`` (read only for the
    learning rate shown in logs), ``process_loader``, ``max_seq_len``,
    ``log_interval`` and ``epoch`` (the training loop's current epoch, used only
    in log lines).

    Every ``log_interval`` batches it logs the mean loss and ms/batch over the
    batches since the previous log line, together with the most recent value
    returned by ``clip_grad_norm_``.
    """
    model.train()
    scaler = GradScaler()
    num_batches = len(train_loader)

    window_loss = 0.0
    window_batches = 0
    window_start = time.time()

    for batch, batch_data in enumerate(train_loader):
        # autocast covers only input preparation and the forward pass; backward and
        # the optimizer step run outside it, as recommended for torch AMP.
        with autocast():
            mapped_input_gene_ids, input_values, input_pert_flags, target_values = process_loader(
                batch_data,
                split_input_gene=True,
                max_seq_len=max_seq_len,
            )
            output_dict = model(
                input_ids_gene_symbol=mapped_input_gene_ids,
                input_ids_gene_expression=input_values,
                input_pert_flags=input_pert_flags,
                labels=target_values,
                output_attentions=False,
                output_hidden_states=False,
            )
            loss = output_dict['loss']

        scaler.scale(loss).backward()
        # unscale first so gradient clipping sees the real (unscaled) gradients
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        window_loss += loss.item()
        window_batches += 1

        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            # Divide by the number of batches actually accumulated: the first window
            # also contains batch 0, so it holds log_interval + 1 batches.
            ms_per_batch = (time.time() - window_start) * 1000 / window_batches
            cur_loss = window_loss / window_batches
            logging.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.5f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.4f} | grad_norm {grad_norm:5.4f} | "
            )
            window_loss = 0.0
            window_batches = 0
            window_start = time.time()


def eval_perturb(model: STELLAForPerturbation, loader: DataLoader) -> Dict:
    """
    Run ``model`` over ``loader`` without gradients and collect its predictions.

    The model is moved to the module-level ``device``, and inputs are built by
    ``process_loader`` with no gene subsampling.

    Returns:
        dict with
          - "pert_cat": array of perturbation names (``batch_data.pert``), one per cell
          - "pred" / "truth": float32 arrays of predicted logits / ``batch_data.y`` rows
          - "pred_de" / "truth_de": the same restricted to each cell's ``de_idx`` columns

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    model.to(device)
    pert_cat = []
    pred, truth = [], []
    pred_de, truth_de = [], []
    results = {}

    with torch.no_grad():
        for batch_data in loader:
            pert_cat.extend(batch_data.pert)
            mapped_input_gene_ids, input_values, input_pert_flags, target_values = process_loader(
                batch_data,
                split_input_gene=False,
            )
            output_dict = model(
                input_ids_gene_symbol=mapped_input_gene_ids,
                input_ids_gene_expression=input_values,
                input_pert_flags=input_pert_flags,
                labels=target_values,
                output_attentions=False,
                output_hidden_states=False,
            )
            p = output_dict["logits"]
            t = batch_data.y
            pred.extend(p.cpu())
            truth.extend(t.cpu())

            # Copy each per-cell DE slice to the CPU immediately so the slices do not
            # keep device memory alive for the whole loader.
            for row, de_idx in enumerate(batch_data.de_idx):
                pred_de.append(p[row, de_idx].cpu())
                truth_de.append(t[row, de_idx].cpu())

    if not pred:
        raise ValueError("eval_perturb: the loader yielded no batches")

    results["pert_cat"] = np.array(pert_cat)
    results["pred"] = torch.stack(pred).numpy().astype(np.float32)
    results["truth"] = torch.stack(truth).numpy().astype(np.float32)
    results["pred_de"] = torch.stack(pred_de).numpy().astype(np.float32)
    results["truth_de"] = torch.stack(truth_de).numpy().astype(np.float32)

    return results


def predict(
    model: STELLAForPerturbation, pert_list: List[List[str]], query, pool_size: Optional[int] = None
) -> Dict:
    """
    Predict the gene expression values for the given perturbations.

    Args:
        model (:class:`torch.nn.Module`): The model to use for prediction. It is moved
            to the module-level ``device`` (as in ``eval_perturb``), because
            ``process_loader`` sends every batch there.
        pert_list (:obj:`List[List[str]]`): The perturbations to predict; each entry is
            a list of gene names, e.g. ``[["SAMD1", "ZBTB1"]]``.
        query: Unused; kept for backward compatibility with existing callers.
        pool_size (:obj:`int`, optional): For each perturbation, use this number
            of control cells and predict their perturbation results. Report the
            mean of these predictions. If `None`, use all control cells.

    Returns:
        dict mapping ``"_".join(pert)`` to the mean predicted values over the pooled cells.

    Raises:
        ValueError: if a gene is not in ``pert_data.gene_names``.
    """
    adata = pert_data.adata
    ctrl_adata = adata[adata.obs["condition"] == "ctrl"]
    if pool_size is None:
        pool_size = len(ctrl_adata.obs)

    gene_list = pert_data.gene_names.values.tolist()
    known_genes = set(gene_list)  # constant-time membership checks
    for pert in pert_list:
        for gene in pert:
            if gene not in known_genes:
                raise ValueError(
                    "The gene is not in the perturbation graph. Please select from GEARS.gene_list!"
                )

    model.eval()
    # The model must live where process_loader puts the inputs (the global `device`);
    # using the model's current device instead can mismatch and fail.
    model.to(device)

    results_pred = {}
    with torch.no_grad():
        for pert in pert_list:
            cell_graphs = create_cell_graph_dataset_for_prediction(
                pert, ctrl_adata, gene_list, device, num_samples=pool_size
            )
            loader = DataLoader(cell_graphs, batch_size=eval_batch_size, shuffle=False)
            preds = []
            for batch_data in loader:
                mapped_input_gene_ids, input_values, input_pert_flags, target_values = process_loader(
                    batch_data,
                    split_input_gene=False,
                )
                output_dict = model(
                    input_ids_gene_symbol=mapped_input_gene_ids,
                    input_ids_gene_expression=input_values,
                    input_pert_flags=input_pert_flags,
                    labels=target_values,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                preds.append(output_dict["logits"])
            preds = torch.cat(preds, dim=0)
            results_pred["_".join(pert)] = np.mean(preds.detach().cpu().numpy(), axis=0)

    return results_pred

### 6. Start Training

In [None]:
# Start from -inf so the first finite validation score always counts as an
# improvement (with 0, a run whose scores never exceed 0 never sets best_model).
best_val_corr = -float("inf")
best_model = None
patience = 0
train_loader = pert_data.dataloader["train_loader"]
valid_loader = pert_data.dataloader["val_loader"]
# The control cells do not change between epochs, so select them once.
ctrl_adata = pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]

for epoch in range(1, epochs + 1):
    torch.cuda.empty_cache()
    epoch_start_time = time.time()

    train(model, train_loader)  # train() reads the global `epoch` for its log lines
    val_res = eval_perturb(model, valid_loader)
    val_metrics = compute_perturbation_metrics(val_res, ctrl_adata)
    logging.info(f"val_metrics at epoch {epoch}: ")
    logging.info(val_metrics)

    elapsed = time.time() - epoch_start_time
    logging.info(f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ")

    # Model selection and early stopping are both driven by validation pearson_delta.
    val_score = val_metrics["pearson_delta"]
    if val_score > best_val_corr:
        best_val_corr = val_score
        best_model = copy.deepcopy(model)
        logging.info(f"Best model with score (pearson_delta) {val_score:5.4f}")
        patience = 0
    else:
        patience += 1
        if patience >= early_stop:
            logging.info(f"Early stop at epoch {epoch}")
            break
    scheduler.step()

if best_model is None:
    # Only possible if no epoch ran or every validation score was NaN.
    logging.warning("No epoch improved the validation score; saving the final model instead.")
    best_model = copy.deepcopy(model)

torch.save(best_model.state_dict(), save_dir / "best_model.pt")

### 7. Plot Perturbations

In [None]:
def plot_perturbation(
    model: nn.Module, query: str, save_file: str = None, pool_size: int = None
) -> matplotlib.figure.Figure:
    """
    Plot true vs. predicted change over mean control for the top-20 non-dropout
    DE genes of one perturbation condition.

    Boxes show the per-cell true change; red dots show the mean predicted change
    returned by ``predict``.

    Args:
        model: model passed to ``predict``.
        query: condition name as stored in ``adata.obs["condition"]``, e.g.
            "SAMD1+ZBTB1" or "KCTD16+ctrl".
        save_file: if given, the figure is also saved to this path.
        pool_size: number of control cells used by ``predict`` (``None`` = all).

    Returns:
        The matplotlib figure.
    """
    sns.set_theme(style="ticks", rc={"axes.facecolor": (0, 0, 0, 0)}, font_scale=1.5)

    adata = pert_data.adata
    gene2idx = pert_data.node_map
    cond2name = dict(adata.obs[["condition", "condition_name"]].values)
    gene_raw2id = dict(zip(adata.var.index.values, adata.var.gene_name.values))

    # DE genes are listed by adata.var index; map them to column positions and symbols.
    top_de_raw = adata.uns["top_non_dropout_de_20"][cond2name[query]]
    de_idx = [gene2idx[gene_raw2id[i]] for i in top_de_raw]
    genes = [gene_raw2id[i] for i in top_de_raw]

    query_X = adata[adata.obs.condition == query].X
    # .X may be a scipy sparse matrix or a dense ndarray
    query_X = query_X.toarray() if hasattr(query_X, "toarray") else np.asarray(query_X)
    truth = query_X[:, de_idx]

    parts = query.split("+")
    if parts[1] == "ctrl":
        # single-gene perturbation such as "KCTD16+ctrl"
        pert_list = [[parts[0]]]
        pred_key = parts[0]
    else:
        pert_list = [parts]
        pred_key = "_".join(parts)
    pred = predict(model, pert_list, query, pool_size=pool_size)[pred_key][de_idx]

    # de_idx are column positions, so select positionally with .iloc; plain [] with
    # integer keys on this label-indexed Series relies on deprecated pandas behaviour.
    ctrl_means = adata[adata.obs["condition"] == "ctrl"].to_df().mean().iloc[de_idx].values

    pred = pred - ctrl_means
    truth = truth - ctrl_means

    fig, ax = plt.subplots(figsize=[16.5, 4.5])
    ax.set_title(query)
    ax.boxplot(truth, showfliers=False, medianprops=dict(linewidth=0))

    # boxplot puts box j at x = j + 1; overlay each prediction at the same x
    for i in range(pred.shape[0]):
        ax.scatter(i + 1, pred[i], color="red")

    ax.axhline(0, linestyle="dashed", color="green")
    ax.xaxis.set_ticklabels(genes, rotation=90)

    ax.set_ylabel("Change in Gene Expression over Control", labelpad=10)
    ax.tick_params(axis="x", which="major", pad=5)
    ax.tick_params(axis="y", which="major", pad=5)
    sns.despine(ax=ax)

    if save_file:
        fig.savefig(save_file, bbox_inches="tight", transparent=False)

    return fig

# Keep using the device chosen in the config cell: process_loader sends every batch
# there, so the model used for plotting and testing must live on the same device.
if globals().get("best_model") is None:
    # No trained model in this kernel: rebuild the model and load weights from an
    # earlier run. Plots and test metrics below are then written into that run's folder.
    best_model = STELLAForPerturbation.from_pretrained(
        "../../pretrained_models/B100_L2048", input_gene_expr_type="continuous"
    )
    save_dir = Path('./save/dev_perturb_norman-May11-19-54')  # TODO: point at your own run
    model_path = save_dir / "best_model.pt"
    best_model.load_state_dict(torch.load(model_path, map_location=device))

best_model = best_model.to(device)

for p in perts_to_plot:
    plot_perturbation(best_model, p, pool_size=300, save_file=f"{save_dir}/{p}.png")

### 8. Test results and deeper analysis

In [None]:
test_loader = pert_data.dataloader["test_loader"]
test_res = eval_perturb(best_model, test_loader)
test_metrics = compute_perturbation_metrics(
    test_res, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
)
print(test_metrics)
# json cannot serialise NumPy scalars such as float32 or int64; .item() converts any
# NumPy scalar to the equivalent Python number/bool (float64 values are unchanged).
test_metrics = {k: v.item() if isinstance(v, np.generic) else v for k, v in test_metrics.items()}
# save test results
with open(save_dir / "test_metrics.json", "w") as f:
    json.dump(test_metrics, f)

deeper_res = deeper_analysis(pert_data.adata, test_res)
non_dropout_res = non_dropout_analysis(pert_data.adata, test_res)

metrics = ["pearson_delta", "pearson_delta_de"]
metrics_non_dropout = [
    "pearson_delta_top20_de_non_dropout",
    "pearson_top20_de_non_dropout",
]

# For every test subgroup, gather each metric's per-perturbation values:
# `metrics` come from deeper_res, `metrics_non_dropout` from non_dropout_res.
subgroup_analysis = {}
for name, pert_list in pert_data.subgroup["test_subgroup"].items():
    subgroup_analysis[name] = {
        m: [deeper_res[pert][m] for pert in pert_list] for m in metrics
    }
    subgroup_analysis[name].update(
        {m: [non_dropout_res[pert][m] for pert in pert_list] for m in metrics_non_dropout}
    )

# Log the subgroup mean of every metric.
for name, per_metric in subgroup_analysis.items():
    for m, values in per_metric.items():
        mean_value = np.mean(values)
        logging.info(f"test_{name}_{m}: {mean_value!s}")