### Imports

In [1]:
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F # Needed for the attention layer
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score
import numpy as np
import pandas as pd
import time
import torch.nn.functional as F
import optuna


  from .autonotebook import tqdm as notebook_tqdm


### Prepping Data

In [2]:
# One seed drives both the train/test split and PyTorch's global RNG, so the
# split, DataLoader shuffling and any later weight initialisation are repeatable.
SEED = 67
torch.manual_seed(SEED)

# Load the annotated expression matrix from disk.
adata = sc.read_h5ad("data/Norman_2019.h5ad")  # replace with your path

# Turn the expression matrix into a dense array. Sparse containers expose
# .toarray(); a matrix that is already dense is passed through unchanged
# instead of failing with AttributeError.
expression = adata.X
ddata = expression.toarray() if hasattr(expression, "toarray") else np.asarray(expression)

# One perturbation string per row. Names are split on '+', and the literal
# 'control' maps to an empty label list.
labels = adata.obs['perturbation_name'].to_numpy()
parsed_labels = [[] if p == 'control' else p.split('+') for p in labels]

# Multi-label encode the parsed names into a 2-D indicator matrix.
mlb = MultiLabelBinarizer()
labels_int = mlb.fit_transform(parsed_labels)

# Hold out 20% of the rows as the test set.
X_train, X_test, y_train, y_test = train_test_split(
    ddata,
    labels_int,
    test_size=0.2,
    random_state=SEED,
    # stratify=labels_int  # TODO(review): stratification was left disabled; confirm whether it is wanted
)

# float32 features and float32 targets (BCEWithLogitsLoss needs float targets).
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Only the training batches are reshuffled each epoch; test order stays fixed.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)




In [5]:
import modelMLP

# Training settings for the baseline MLP.
learning_rate = 0.00026
num_epochs = 25

# Network dimensions are read off the prepared arrays:
# one input per feature column, one output per label column.
input_size = X_train.shape[1]
num_classes = labels_int.shape[1]

model = modelMLP.MLP(input_size=input_size, num_classes=num_classes)

# BCE-with-logits scores every label column as its own binary problem, which
# matches the multi-label (multi-hot) targets; it is not a multi-class loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=2e-6,
)

### Optuna Hyperparameters

In [6]:
def define_model(trial, input_dim=None, output_dim=None):
    """Build a feed-forward classifier whose shape is sampled from an Optuna trial.

    The network is a stack of ``Linear -> ReLU -> Dropout`` groups followed by a
    final ``Linear`` layer that emits raw logits (no activation).

    Args:
        trial: Object exposing ``suggest_int`` / ``suggest_float`` (an Optuna trial).
        input_dim: Number of input features. When omitted, falls back to the
            notebook-level ``X_train.shape[1]``.
        output_dim: Number of output logits. When omitted, falls back to the
            notebook-level ``num_classes``.

    Returns:
        nn.Sequential: The assembled, untrained model.
    """
    # Resolve the notebook globals lazily so the function is usable (and
    # testable) without them when explicit sizes are supplied.
    if input_dim is None:
        input_dim = X_train.shape[1]
    if output_dim is None:
        output_dim = num_classes

    n_layers = trial.suggest_int('n_layers', 2, 10)
    layers = []
    in_features = input_dim

    for i in range(n_layers):
        # Width of each hidden layer is sampled on a log scale in [10, 100].
        out_features = trial.suggest_int(f'n_units_l{i}', 10, 100, log=True)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        # Each hidden layer gets its own dropout probability in [0.25, 0.5].
        dropout_p = trial.suggest_float(f'dropout_l{i}', 0.25, 0.5)
        layers.append(nn.Dropout(dropout_p))
        in_features = out_features

    # Output head: one logit per label.
    layers.append(nn.Linear(in_features, output_dim))
    return nn.Sequential(*layers)

def _holdout_loaders(dataset, batch_size, val_fraction=0.1, seed=67):
    """Split ``dataset`` into a shuffled fitting loader and an ordered validation loader.

    The split uses its own seeded generator, so every trial is scored on the
    same validation rows and trials stay comparable.
    """
    n_total = len(dataset)
    if n_total < 2:
        raise ValueError("need at least two samples to carve out a validation split")
    n_val = min(max(1, int(n_total * val_fraction)), n_total - 1)
    splitter = torch.Generator().manual_seed(seed)
    fit_subset, val_subset = torch.utils.data.random_split(
        dataset, [n_total - n_val, n_val], generator=splitter
    )
    fit_loader = DataLoader(fit_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    return fit_loader, val_loader


def _micro_f1(model, loader):
    """Return the micro-averaged F1 of ``model`` over ``loader`` at a 0.5 probability threshold."""
    model.eval()
    predicted, expected = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            probs = torch.sigmoid(model(inputs))
            predicted.append((probs > 0.5).float().numpy())
            expected.append(targets.numpy())
    # zero_division=0 keeps the score at 0 (the value the default would also
    # return) without emitting a warning every time nothing is predicted positive.
    return f1_score(
        np.vstack(expected), np.vstack(predicted), average='micro', zero_division=0
    )


def objective(trial):
    """Optuna objective: train one sampled model and return its best validation micro-F1.

    Hyper-parameters sampled here are the learning rate and weight decay; the
    architecture is sampled by ``define_model``. Training stops early after
    ``patience`` epochs without improvement, and the trial is reported to
    Optuna every epoch so the study's pruner can stop it.

    Model selection is scored on a hold-out carved from ``train_dataset``, not
    on ``test_loader``: tuning against the test set would leak it into the
    selected configuration and make any later test score optimistic.

    Args:
        trial: The Optuna trial supplying hyper-parameter suggestions.

    Returns:
        float: Highest validation micro-F1 observed during training.

    Raises:
        optuna.TrialPruned: When the study's pruner asks to stop this trial.
    """
    # --- 1. Suggest Hyperparameters ---
    model = define_model(trial)
    lr = trial.suggest_float('lr', 1e-6, 1e-5, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-5, log=True)

    # --- 2. Setup Model, Optimizer and Data ---
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    fit_loader, val_loader = _holdout_loaders(train_dataset, batch_size)

    # --- 3. Training and Evaluation Loop ---
    num_epochs = 100
    patience = 10  # epochs without improvement before stopping early
    patience_counter = 0
    best_val_f1 = 0.0

    for epoch in range(num_epochs):
        # --- Training Phase ---
        model.train()
        for features, targets in fit_loader:
            optimizer.zero_grad()
            loss = criterion(model(features), targets)
            loss.backward()
            optimizer.step()

        # --- Validation Phase ---
        val_f1 = _micro_f1(model, val_loader)

        # --- Early Stopping bookkeeping ---
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            patience_counter = 0
        else:
            patience_counter += 1

        # Report the intermediate score so Optuna can prune unpromising trials.
        trial.report(val_f1, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if patience_counter >= patience:
            break  # no improvement for `patience` consecutive epochs

    # The study maximizes this value.
    return best_val_f1

In [7]:
# Create a study that maximizes the objective. The sampler is seeded so the
# sequence of suggested hyper-parameters is repeatable across runs.
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=67),
)

# Start the optimization. Let's run it for 50 trials as an example.
study.optimize(objective, n_trials=50)

# --- 4. Get and Print the Best Results ---
# Pruned or failed trials also appear in study.trials, so count states separately.
completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]

print("\nStudy statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of completed trials: ", len(completed_trials))
print("  Number of pruned trials: ", len(pruned_trials))

print("\nBest trial:")
if completed_trials:
    trial = study.best_trial

    print(f"  Value (Best Validation F1 Score): {trial.value:.4f}")

    print("  Best Hyperparameters: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")
else:
    # study.best_trial raises ValueError when nothing completed; say so instead.
    print("  No trial completed, so there is no best trial to report.")

[I 2025-10-09 16:16:21,528] A new study created in memory with name: no-name-ba401a26-4b81-42c8-a4c2-3d5d9f8777fe
[W 2025-10-09 16:17:23,941] Trial 0 failed with parameters: {'n_layers': 10, 'n_units_l0': 53, 'dropout_l0': 0.44310917917856973, 'n_units_l1': 11, 'dropout_l1': 0.33263951812054837, 'n_units_l2': 96, 'dropout_l2': 0.3947779118610492, 'n_units_l3': 24, 'dropout_l3': 0.3456585669516296, 'n_units_l4': 28, 'dropout_l4': 0.25080184659563143, 'n_units_l5': 85, 'dropout_l5': 0.4128456506411581, 'n_units_l6': 16, 'dropout_l6': 0.48589124570366443, 'n_units_l7': 14, 'dropout_l7': 0.31230719765682186, 'n_units_l8': 76, 'dropout_l8': 0.3769424309827486, 'n_units_l9': 93, 'dropout_l9': 0.35077941653928724, 'lr': 4.073671182439168e-06, 'weight_decay': 9.901674774322529e-06} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/Users/steveyin/Library/Python/3.12/lib/python/site-packages/optuna/study/_optimize.py", line 201, in _run_trial
    va

KeyboardInterrupt: 