In [92]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
import os

In [93]:
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support, classification_report

In [94]:
import scanpy as sc
adata = sc.read_h5ad("/gpfs/gibbs/pi/zhao/tl688/scGPT/examples/predict_dosage_dataset.h5ad")
adata = sc.AnnData(adata.X, obs = adata.obs, var = adata.var)
adata.obs['celltype'] = ['no' for i in adata.obs_names]
adata.obs['batch'] = ['no' for i in adata.obs_names]

# adata = sc.read_h5ad("/gpfs/gibbs/pi/zhao/tl688/scGPT/gene_labels/task3a_bi_nom.h5ad")
# adata = sc.AnnData(adata.X, obs = adata.obs, var = adata.var)
# adata.obs['celltype'] = ['no' for i in adata.obs_names]
# adata.obs['batch'] = ['no' for i in adata.obs_names]

# adata = sc.read_h5ad("/gpfs/gibbs/pi/zhao/tl688/scGPT/gene_labels/task3b_bi_lys4.h5ad")
# print(adata.var_names)
# adata = sc.AnnData(adata.X, obs = adata.obs, var = adata.var)
# adata.obs['celltype'] = ['no' for i in adata.obs_names]
# adata.obs['batch'] = ['no' for i in adata.obs_names]

  utils.warn_names_duplicates("var")


In [95]:
adata.var_names

Index(['7SK', 'A1BG-AS1', 'A1BG', 'A1CF', 'A2M-AS1', 'A2M', 'A2ML1', 'A2MP1',
       'A4GALT', 'AAAS',
       ...
       'ZW10', 'ZWILCH', 'ZWINT', 'ZXDB', 'ZXDC', 'ZYG11A', 'ZYG11B', 'ZYX',
       'ZZEF1', 'ZZZ3'],
      dtype='object', length=26674)

In [96]:
adata_train = adata[:, adata.var['dose_cond'] != -1] 

In [97]:
from sklearn.model_selection import train_test_split

In [98]:
(
    train_data,
    valid_data,
    train_gene_labels,
    valid_gene_labels,
    train_gene_name,
    valid_gene_name,
    train_adata_varnames,
    valid_adata_varnames
    
) = train_test_split(
    adata_train.X.T, adata_train.var['dose_cond'], adata_train.var_names, adata_train.var_names, test_size=0.33, shuffle=True, random_state=42
)


In [99]:
train_gene_labels

RARB       1
ZNF33B     0
ZKSCAN8    0
ZNF684     0
ZNF22      0
          ..
MBD2       1
PRDM16     1
ZNF439     0
ZNF664     0
PAX8       1
Name: dose_cond, Length: 289, dtype: int64

In [100]:
train_gene_name

Index(['RARB', 'ZNF33B', 'ZKSCAN8', 'ZNF684', 'ZNF22', 'MYCN', 'ZNF625',
       'ZNF480', 'MECOM', 'ZNF154',
       ...
       'ZNF20', 'RUNX2', 'ZUFSP', 'DLX1', 'ZFP41', 'MBD2', 'PRDM16', 'ZNF439',
       'ZNF664', 'PAX8'],
      dtype='object', length=289)

In [101]:
df_geneemb = pd.read_csv("../cellfm_pathwayenrichment.csv", index_col=0)

In [102]:
ovarlap_train = sorted(set(train_gene_name).intersection(set(df_geneemb.index)))
df_geneemb_train = df_geneemb.loc[ovarlap_train]
y_train = train_gene_labels.loc[ovarlap_train].values

In [103]:
ovarlap_valid = sorted(set(valid_gene_name).intersection(set(df_geneemb.index)))
df_geneemb_valid = df_geneemb.loc[ovarlap_valid]
y_valid = valid_gene_labels.loc[ovarlap_valid].values

In [104]:
class DummyDataset(Dataset):
    """
    A simple dummy dataset for classification.
    """
    def __init__(self, data,label):
        self.data = torch.FloatTensor(data)
        # Create random labels (0 or 1)
        self.labels = torch.FloatTensor(label).long()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

class DummyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32, emb_dims=128):
        super().__init__()
        self.batch_size = batch_size
        self.emb_dims = emb_dims

    def setup(self, stage=None):

        self.train_dataset, self.val_dataset =DummyDataset(df_geneemb_train.values,y_train), DummyDataset(df_geneemb_valid.values,y_valid)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)


# --- 2. Define the Lightning Module (ClassifierModel) ---

class ClassifierModel(pl.LightningModule):
    def __init__(self, emb_dims=128, learning_rate=1e-3):
        super().__init__()
        # Saves emb_dims and learning_rate to self.hparams for easy access
        self.save_hyperparameters()

        # Translate the MindSpore architecture to PyTorch:
        # nn.Dense(emb_dims, emb_dims//2, has_bias=False) -> nn.Linear(emb_dims, emb_dims//2, bias=False)
        # nn.SiLU() is directly available in PyTorch

        self.mlp = nn.Sequential(
            # Layer 1: emb_dims -> emb_dims//2
            nn.Linear(emb_dims, emb_dims // 2, bias=False),
            nn.Dropout(p=0.15),
            nn.SiLU(),

            # Layer 2: emb_dims//2 -> emb_dims//4
            nn.Linear(emb_dims // 2, emb_dims // 4, bias=False),
            nn.Dropout(p=0.15),
            nn.SiLU(),

            # Layer 3: emb_dims//4 -> 2 (2 classes)
            nn.Linear(emb_dims // 4, 2, bias=False),
        )

        # Loss function for classification
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Passes the input tensor through the MLP."""
        return self.mlp(x)

    def _common_step(self, batch, batch_idx):
        """Reusable function for training and validation steps."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        
        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        n_correct = (preds == y).sum().item()
        n_total = y.size(0)
        
        return loss, n_correct, n_total

    def training_step(self, batch, batch_idx):
        loss, n_correct, n_total = self._common_step(batch, batch_idx)
        acc = n_correct / n_total
        
        # Logging to the progress bar and logger
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_acc', acc, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, n_correct, n_total = self._common_step(batch, batch_idx)
        
        # Logging validation metrics
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_n_correct', float(n_correct), on_step=False, on_epoch=True, reduce_fx='sum')
        self.log('val_n_total', float(n_total), on_step=False, on_epoch=True, reduce_fx='sum')
        
    def on_validation_epoch_end(self):
        # Calculate final validation accuracy after all steps
        n_correct = self.trainer.logged_metrics['val_n_correct']
        n_total = self.trainer.logged_metrics['val_n_total']
        val_acc = n_correct / n_total
        self.log('val_acc', val_acc, prog_bar=True)
        del self.trainer.logged_metrics['val_n_correct']
        del self.trainer.logged_metrics['val_n_total']


    def configure_optimizers(self):
        """Defines the optimizer."""
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

In [105]:
import pandas as pd

In [106]:
# --- Configuration ---
EMB_DIMS = 1536  # Must match the input feature size
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
MAX_EPOCHS = 30

print(f"Starting training with emb_dims={EMB_DIMS}, LR={LEARNING_RATE}, Epochs={MAX_EPOCHS}")

# 1. Setup DataModule
dm = DummyDataModule(batch_size=BATCH_SIZE, emb_dims=EMB_DIMS)

# 2. Setup Model
model = ClassifierModel(emb_dims=EMB_DIMS, learning_rate=LEARNING_RATE)

# 3. Setup Logger and Trainer
# Use CSVLogger to save logs in 'lightning_logs/'
logger = CSVLogger("lightning_logs", name="mlp_classifier")

# Check for GPU availability
accelerator = "gpu" if torch.cuda.is_available() else "cpu"

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    logger=logger,
    accelerator=accelerator,
    log_every_n_steps=10,
)

# 4. Train the Model
trainer.fit(model, dm)

# 5. Optional: Test the model after training
print("\n--- Training complete. Running validation on best model ---")
trainer.validate(ckpt_path="best", datamodule=dm)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | mlp     | Sequential       | 1.5 M 
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
5.901     Total estimated model params size (MB)


Starting training with emb_dims=1536, LR=0.0001, Epochs=30


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Restoring states from the checkpoint path at lightning_logs/mlp_classifier/version_8/checkpoints/epoch=29-step=150.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at lightning_logs/mlp_classifier/version_8/checkpoints/epoch=29-step=150.ckpt



--- Training complete. Running validation on best model ---


Validation: 0it [00:00, ?it/s]

[{'val_loss': 0.5864219665527344,
  'val_n_correct': 93.0,
  'val_n_total': 137.0,
  'val_acc': 0.6788321137428284}]

In [107]:
with torch.no_grad():
    model.cuda()
    y_pred = model(torch.FloatTensor(df_geneemb_valid.values).cuda()).cpu()

In [108]:
y_pred.argmax(axis=1)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [109]:
label = y_pred.argmax(axis=1).numpy()

In [110]:
print(classification_report(label, y_valid, digits=4))

              precision    recall  f1-score   support

           0     0.9485    0.7023    0.8070       131
           1     0.0250    0.1667    0.0435         6

    accuracy                         0.6788       137
   macro avg     0.4867    0.4345    0.4252       137
weighted avg     0.9080    0.6788    0.7736       137

