In [None]:
import torch
import numpy as np
import pandas as pd

np.random.seed(314159)

import torch
import pytorch_lightning as pl

from torch_geometric.data import Data

import wandb


In [None]:
import numpy as np
import torch

# Read the pickled edge table. allow_pickle=True executes pickle code, so only
# load files from a trusted source.
edge_table = np.load("edge_list_latest1.npy", allow_pickle=True)

# Row 0 is the header; every later record keeps one comma-separated string in
# its first field, which is split into columns here.
edge_table = np.array([record[0].split(",") for record in edge_table[1:]])

# Columns: gene1, gene2, combined score (combined_score_x or _y).
protein1, protein2 = edge_table[:, 0], edge_table[:, 1]
combined_score = edge_table[:, 2].astype(np.float32)

# Give every distinct identifier its position in the sorted set of identifiers.
# `return_inverse` yields those positions for the concatenated endpoints directly.
n_edges = len(protein1)
unique_proteins, positions = np.unique(
    np.concatenate((protein1, protein2)), return_inverse=True
)
protein_map = {protein: idx for idx, protein in enumerate(unique_proteins)}
protein1 = positions[:n_edges]
protein2 = positions[n_edges:]

# Row 0 = source nodes, row 1 = target nodes.
edges = np.stack((protein1, protein2), axis=0)
edge_list = torch.tensor(edges, dtype=torch.long)
edge_weight = torch.tensor(combined_score, dtype=torch.float32)

print("Edge list shape:", edge_list.shape)
print("Edge weights shape:", edge_weight.shape)
print("Protein map:", protein_map)


In [None]:
# Node table: the first CSV column is used as the row index.
data_path = '200_node_network_embeddings_latest.csv' 
node_dataset = pd.read_csv(data_path, index_col=0)
# Bare expression so the frame is rendered as the cell output.
node_dataset

In [None]:
# Merge per-gene xCell columns into `node_dataset`, aligned on the node index.
# NOTE(review): absolute, user-specific path — consider a configurable DATA_DIR.
xcell_path = "/home/user/Diamond_101225/diamond_rho/diamond_tpm/gene_xcell_spearman_rho_with_ENSG.csv"
print("Loading xCell PCs from:", xcell_path)
# sep=None with the python engine lets pandas sniff the delimiter.
xcell = pd.read_csv(xcell_path, sep=None, engine='python')

# Gene identifier column: 'gene_id' when present, otherwise the first column.
# Everything after the first '.' (the version suffix) is dropped.
gene_col = 'gene_id' if 'gene_id' in xcell.columns else xcell.columns[0]
xcell['ensembl'] = xcell[gene_col].astype(str).str.split('.').str[0]

pc_cols = [c for c in xcell.columns if ('xcell' in str(c).lower() and 'pc' in str(c).lower()) or str(c).lower().startswith('xcell_pc')]

if len(pc_cols) < 1:
    # Fallback: the first 20 non-identifier columns. Only numeric columns are
    # eligible because the mean-fill below cannot handle text columns.
    possible = [
        c for c in xcell.columns
        if c not in (gene_col, 'Gene', 'ensembl') and pd.api.types.is_numeric_dtype(xcell[c])
    ]
    pc_cols = possible[:20]

print("Detected xCell PC columns (count={}): {}".format(len(pc_cols), pc_cols[:30]))

if len(pc_cols) == 0:
    print("WARNING: no xCell PC columns detected. Please inspect file and set pc_cols manually.")
else:
    xcell_idx = xcell.set_index('ensembl')[pc_cols]

    # Stripping version suffixes can collapse several rows onto one ID, and
    # reindex() refuses to work on a duplicated index. Keep the first occurrence.
    dup_mask = xcell_idx.index.duplicated(keep='first')
    if dup_mask.any():
        print("WARNING: {} duplicated gene IDs after stripping versions; keeping the first occurrence.".format(int(dup_mask.sum())))
        xcell_idx = xcell_idx[~dup_mask]

    # Alignment: one row per node, NaN where the gene is absent from the xCell table.
    xcell_aligned = xcell_idx.reindex(node_dataset.index)

    missing_pct = xcell_aligned.isna().mean(axis=0).round(3) * 100
    print("Percent missing per xCell PC column:\n", missing_pct)

    # Missing values
    fill_strategy = "mean"  # options: "mean" or "zero"
    if fill_strategy == "zero":
        xcell_aligned_filled = xcell_aligned.fillna(0.0)
    else:
        xcell_aligned_filled = xcell_aligned.fillna(xcell_aligned.mean())

    # A column with no overlap at all has a NaN mean and therefore stays NaN.
    n_still_missing = int(xcell_aligned_filled.isna().sum().sum())
    if n_still_missing:
        print("WARNING: {} values are still missing after filling (columns without any matched gene).".format(n_still_missing))

    # Prefix columns with "xcell_" unless they already carry it.
    xcell_aligned_filled.columns = [f"xcell_{str(c)}" if not str(c).startswith("xcell_") else str(c) for c in xcell_aligned_filled.columns]

    # Drop columns left over from a previous execution of this cell so that
    # re-running it does not append duplicate xcell_* columns.
    node_dataset = pd.concat(
        [
            node_dataset.drop(columns=list(xcell_aligned_filled.columns), errors='ignore'),
            xcell_aligned_filled,
        ],
        axis=1,
    )

    # Persist the merged table.
    out_merge = "node_dataset_with_xcell_rho.csv"
    node_dataset.to_csv(out_merge)
    print("Saved merged node_dataset to:", out_merge)

In [None]:
# Order rows by index label, then turn that label into the first column so the
# row index becomes the positional range 0..n-1.
node_dataset = node_dataset.sort_index().reset_index(drop=False)
assert np.array_equal(node_dataset.index.to_numpy(), np.arange(len(node_dataset)))

In [None]:
label_name = 'my_label'

pos_label_col = 'gda_score'
# Assign the filled column back. Calling fillna(inplace=True) on a column
# selection is chained assignment: deprecated in recent pandas and without
# effect once Copy-on-Write is enabled.
node_dataset[pos_label_col] = node_dataset[pos_label_col].fillna(0)

# Positive-unlabeled encoding: a truthy (non-zero) score becomes label 1,
# everything else stays unlabeled (<NA>). Vectorised instead of iterrows().
is_positive = node_dataset[pos_label_col].astype(bool).to_numpy()
pos_labels = pd.array(np.where(is_positive, 1, 0), dtype='Int32')
pos_labels[~is_positive] = pd.NA
node_dataset[label_name] = pos_labels

def sample_negatives(PU_labels, random_state=None):
    '''Randomly pick unlabeled samples to act as negatives.

    As many unlabeled entries (NA) are drawn, without replacement, as there are
    positives (entries equal to 1).

    Parameters
    ----------
    PU_labels : pandas.Series
        Positive-unlabeled labels: 1 for positives, NA for unlabeled samples.
    random_state : int, numpy Generator/RandomState or None, optional
        Forwarded to ``Series.sample``. ``None`` (the default) keeps the
        previous behaviour of drawing from NumPy's global random state.

    Returns
    -------
    pandas.Index
        Index labels of the sampled unlabeled entries.

    Raises
    ------
    ValueError
        If there are fewer unlabeled samples than positives.
    '''
    num_pos = int((PU_labels == 1).sum())
    unlabeled = PU_labels[PU_labels.isna()]

    # Fail with a message that names the actual problem instead of pandas'
    # generic "larger sample than population" error (same exception type).
    if num_pos > len(unlabeled):
        raise ValueError(
            f"Cannot sample {num_pos} negatives from only {len(unlabeled)} unlabeled samples"
        )

    return unlabeled.sample(n=num_pos, random_state=random_state).index

# Draw as many unlabeled nodes as there are positives and mark them as negatives (label 0).
neg_label_inds = sample_negatives(node_dataset[label_name])
node_dataset.loc[neg_label_inds, label_name] = 0

# Class counts after sampling, shown as the cell output.
node_dataset[label_name].value_counts()

In [None]:
# Promote the first column back to being the row index.
first_col = node_dataset.columns[0]
node_dataset = node_dataset.set_index(first_col)
node_dataset

In [None]:
# The label is expected to be the last column of node_dataset.
label_col = node_dataset.columns[-1]
node_dataset[label_col] = node_dataset[label_col].astype('Int32')

# The raw score was only needed to derive the labels; it must not be a feature.
if 'gda_score' in node_dataset.columns:
    node_dataset = node_dataset.drop(columns=['gda_score'])

node_feat_cols = node_dataset.columns[:-1].tolist()

node_data = node_dataset[node_feat_cols + [label_col]]

# Only numeric feature columns go into the feature matrix.
feature_matrix = node_data[node_feat_cols].select_dtypes(include=[np.number]).to_numpy(dtype=np.float32)
X = torch.tensor(feature_matrix, dtype=torch.float32)
if torch.isnan(X).any():
    # NaN features propagate through every layer and turn the loss into NaN.
    print("WARNING: feature matrix X contains NaN values.")

# Unlabeled nodes get -1 so they can be filtered out with `y >= 0`.
# Convert through an explicit int64 NumPy array rather than handing a pandas
# Series to torch.Tensor(), which goes through a float tensor and depends on
# how torch walks the Series as a Python sequence.
y = torch.tensor(node_data[label_col].fillna(-1).to_numpy(dtype=np.int64), dtype=torch.long)

# Rows that carry a label (positive or sampled negative).
node_data_labeled = node_data[node_data[label_col].notna()]
node_data_labeled

In [None]:
from sklearn.model_selection import train_test_split
import numpy as np
import torch

def _mask_from_positions(positions, size):
    """Return a boolean torch tensor of length `size` that is True at `positions`."""
    flags = np.zeros(size, dtype=bool)
    flags[positions] = True
    return torch.tensor(flags, dtype=torch.bool)


# Only labeled nodes take part in the split; it is stratified by label.
X_myIDs = node_data_labeled.index.to_numpy()
labels = node_data_labeled[label_col].to_numpy()

# 20% test; the validation fraction is expressed relative to the remaining
# train+val pool so that it corresponds to 10% of all labeled nodes.
test_size = 0.2
val_size = 0.1 * (1 / (1 - test_size))

myIDs_train_val, myIDs_test = train_test_split(
    X_myIDs, test_size=test_size, shuffle=True, stratify=labels
)

labels_train_val = node_data_labeled.loc[myIDs_train_val, label_col].to_numpy()
myIDs_train, myIDs_val = train_test_split(
    myIDs_train_val, test_size=val_size, shuffle=True, stratify=labels_train_val
)

# Row position of every node id within node_data (labeled and unlabeled).
id_to_idx = dict(zip(node_data.index, range(len(node_data.index))))

train_idx = np.array(list(map(id_to_idx.__getitem__, myIDs_train)))
val_idx = np.array(list(map(id_to_idx.__getitem__, myIDs_val)))
test_idx = np.array(list(map(id_to_idx.__getitem__, myIDs_test)))

n_nodes = len(node_data)

train_mask = _mask_from_positions(train_idx, n_nodes)
val_mask = _mask_from_positions(val_idx, n_nodes)
test_mask = _mask_from_positions(test_idx, n_nodes)

for split_name, split_mask in (
    ("training", train_mask),
    ("validation", val_mask),
    ("test", test_mask),
):
    print(f"Number of {split_name} nodes: {split_mask.sum().item()}")


In [None]:
# Validate the pieces before assembling the graph, so that a mismatch fails
# here with a clear message instead of deep inside a convolution layer.
assert y.shape[0] == X.shape[0], "Mismatch: y and X must have the same number of nodes"
assert edge_list.shape[1] == edge_weight.shape[0], "Mismatch: every edge needs exactly one weight"
assert edge_list.numel() == 0 or int(edge_list.max()) < X.shape[0], \
    "edge_index refers to a node index that has no row in X"
# NOTE(review): node indices in edge_list are positions in the sorted unique
# identifiers of the edge file, while rows of X follow node_dataset's sorted
# index. These only coincide if both contain exactly the same identifiers —
# confirm, otherwise features and labels are attached to the wrong nodes.

data = Data(x=X, y=y, edge_index=edge_list, edge_attr=edge_weight)

num_classes = 2
num_features = X.shape[1]

# Boolean node masks selecting the train / validation / test nodes.
data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask


print(data)

In [None]:
import wandb
import os

# Credentials must come from the environment (or a prior `wandb login`), never
# from the notebook. The previous assignment overwrote any configured key with
# an empty string.
if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set; export it or run `wandb login` before using wandb.")
os.environ["WANDB_MODE"] = "online"

# NOTE(review): these values (CIFAR10 / CNN) look like leftovers from a demo,
# do not describe this notebook and are not passed to wandb — confirm or remove.
config = {
    "dataset": "CIFAR10",
    "machine": "online cluster",
    "model": "CNN",
    "learning_rate": 0.01,
    "batch_size": 128,
}

# Connectivity smoke test: log a dummy metric to a throw-away project.
wandb.init(project="offline-demo")
try:
    for i in range(100):
        wandb.log({"accuracy": i})
finally:
    # Close the run. A run left open is reused by every WandbLogger created
    # afterwards, so all training trials would be logged into this demo run.
    wandb.finish()

In [None]:
# Per-trial metric accumulators for the training split.
train_accs = []
train_precs = []
train_recalls = []
train_f1s = []
train_aurocs = []
train_auprcs = []

# Per-trial metric accumulators for the test split.
test_accs = []
test_precs = []
test_recalls = []
test_f1s = []
test_aurocs = []
test_auprcs = []


In [None]:
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, SGConv, MessagePassing
import torch.nn.functional as F
import pytorch_lightning as pl
import torch

class GNNModel(torch.nn.Module):
    """Stack of graph convolution layers followed by a two-layer dense head.

    Parameters
    ----------
    num_features : int
        Size of the input node features.
    hidden_channels : sequence of int
        Output size of each convolution layer; one layer is created per entry
        (at least one entry is required).
    num_classes : int
        Size of the output (one logit per class).
    hidden_dense : int
        Width of the hidden dense layer.
    GNN_conv_layer : type
        Convolution class, instantiated as
        ``GNN_conv_layer(in_channels=..., out_channels=..., **kwargs)``.
    dropout_rate : float
        Dropout probability applied after every convolution and after the
        hidden dense layer.
    **kwargs
        Extra keyword arguments forwarded to every convolution layer.
    """

    def __init__(self, num_features, hidden_channels, num_classes, hidden_dense, GNN_conv_layer=GCNConv, dropout_rate=0.1, **kwargs):
        super().__init__()
        import inspect  # local import: only needed to inspect layer signatures

        self.convs = torch.nn.ModuleList()
        self.GNN_conv_layer = GNN_conv_layer

        self.convs.append(GNN_conv_layer(in_channels=num_features, out_channels=hidden_channels[0], **kwargs))

        for c1, c2 in zip(hidden_channels[:-1], hidden_channels[1:]):
            self.convs.append(GNN_conv_layer(in_channels=c1, out_channels=c2, **kwargs))

        # Not every convolution layer takes an `edge_weight` keyword. Record it
        # once per layer so forward() can call each layer with what it supports.
        self._conv_takes_edge_weight = [
            "edge_weight" in inspect.signature(conv.forward).parameters
            for conv in self.convs
        ]

        self.dense1 = torch.nn.Linear(hidden_channels[-1], hidden_dense)
        self.dense_out = torch.nn.Linear(hidden_dense, num_classes)

        self.dropout_rate = dropout_rate

    def forward(self, x, edge_index, edge_weight=None):
        """Return per-node class logits.

        ``edge_weight`` is passed only to layers whose ``forward`` declares an
        ``edge_weight`` parameter; other layers are called without it, i.e. the
        weights are ignored by those layers.
        """
        for conv, takes_edge_weight in zip(self.convs, self._conv_takes_edge_weight):
            if takes_edge_weight:
                x = conv(x, edge_index, edge_weight=edge_weight)
            else:
                x = conv(x, edge_index)
            x = x.relu()
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        x = self.dense1(x)
        x = x.relu()
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.dense_out(x)

        return x


class WeightedSAGEConv(MessagePassing):
    """Mean-aggregating message-passing layer with per-edge scalar weights.

    Node features are linearly transformed, each message is the transformed
    source feature scaled by its edge weight, messages are mean-aggregated, and
    the aggregate is concatenated with the untransformed input features before
    a final linear layer.

    Attributes
    ----------
    last_messages : torch.Tensor or None
        Detached CPU copy of the messages from the most recent forward pass
        (kept for inspection); ``None`` before the first pass.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.lin_update = torch.nn.Linear(in_channels + out_channels, out_channels)
        self.last_messages = None

    def forward(self, x, edge_index, edge_weight=None):
        """Return updated node features.

        ``edge_weight`` holds one scalar per edge; ``None`` is treated as a
        weight of 1 for every edge.
        """
        # The previous `self.log(...)` call was removed: `log` belongs to
        # LightningModule, not to MessagePassing, so it raised AttributeError.
        x_transformed = self.lin(x)
        self.last_messages = None
        if edge_weight is None:
            edge_weight = x_transformed.new_ones(edge_index.shape[1])
        out = self.propagate(edge_index=edge_index, x=x_transformed, edge_weight=edge_weight)
        out = self.lin_update(torch.cat([out, x], dim=1))
        return out

    def message(self, x_j, edge_weight):
        """Scale each source-node feature by the weight of its edge."""
        messages = edge_weight.view(-1, 1) * x_j
        # Stored for inspection only; detached and moved to CPU so it neither
        # keeps the autograd graph alive nor occupies device memory.
        self.last_messages = messages.detach().cpu()
        return messages

class LitGNN(pl.LightningModule):
    """Lightning wrapper around ``GNNModel`` for node classification.

    Nodes with a negative label are treated as unlabeled and never contribute
    to the loss or the metrics. When the batch carries ``train_mask`` /
    ``val_mask`` / ``test_mask``, each step is additionally restricted to its
    own split, so validation and test nodes are not trained on. A batch without
    the relevant mask falls back to all labeled nodes.

    Parameters
    ----------
    model_name : str
        Free-form name stored on the module.
    num_features, hidden_channels, num_classes, hidden_dense, GNN_conv_layer, dropout_rate
        Forwarded to ``GNNModel``.
    learning_rate : float, optional
        Adam learning rate (default 0.001, the previously hard-coded value).
    """

    def __init__(self, model_name, num_features, hidden_channels, num_classes, hidden_dense, GNN_conv_layer, dropout_rate, learning_rate=0.001):
        super().__init__()

        self.save_hyperparameters()
        self.model_name = model_name
        self.learning_rate = learning_rate
        self.model = GNNModel(num_features, hidden_channels, num_classes, hidden_dense, GNN_conv_layer, dropout_rate)
        self.criterion = torch.nn.CrossEntropyLoss()
        # Per-batch results collected for the epoch-end hooks.
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def forward(self, x, edge_index, edge_weight=None):
        return self.model(x, edge_index, edge_weight)

    @staticmethod
    def _split_nodes(batch, mask_name):
        """Boolean mask of labeled nodes (y >= 0) that belong to split `mask_name`."""
        selected = batch.y >= 0
        split_mask = getattr(batch, mask_name, None)
        if split_mask is not None:
            selected = selected & split_mask.bool()
        return selected

    def _eval_step(self, batch, mask_name):
        """Return (loss, accuracy) on the labeled nodes of one split; zeros if it is empty."""
        out = self(batch.x, batch.edge_index, edge_weight=batch.edge_attr)
        nodes = self._split_nodes(batch, mask_name)

        if nodes.any():
            loss = self.criterion(out[nodes], batch.y[nodes])
            acc = (out[nodes].argmax(dim=1) == batch.y[nodes]).float().mean()
        else:
            loss = torch.tensor(0.0, device=self.device)
            acc = torch.tensor(0.0, device=self.device)
        return loss, acc

    def training_step(self, batch, batch_idx):
        out = self(batch.x, batch.edge_index, batch.edge_attr)
        # Train on labeled *training* nodes only; using every labeled node
        # would leak validation and test labels into the fit.
        valid_indices = self._split_nodes(batch, "train_mask")

        self.log("edge_shape", batch.edge_index.shape[1])

        if valid_indices.any():
            loss = self.criterion(out[valid_indices], batch.y[valid_indices])
        else:
            # No usable node: return a zero loss that still supports backward().
            loss = torch.tensor(0.0, requires_grad=True, device=self.device)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._eval_step(batch, "val_mask")

        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True)
        self.validation_step_outputs.append({"val_loss": loss, "val_acc": acc})

    def on_validation_epoch_end(self):
        if self.validation_step_outputs:
            val_acc_mean = torch.stack([x["val_acc"] for x in self.validation_step_outputs]).mean()
            self.log("val_acc_epoch", val_acc_mean, prog_bar=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def test_step(self, batch, batch_idx):
        loss, acc = self._eval_step(batch, "test_mask")

        self.log("test_loss", loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", acc, on_epoch=True, prog_bar=True)
        result = {"test_loss": loss, "test_acc": acc}
        self.test_step_outputs.append(result)
        return result

    def on_test_epoch_end(self):
        # Replaces the legacy `test_epoch_end(outputs)` hook and mirrors
        # on_validation_epoch_end, which already uses the argument-free style.
        if self.test_step_outputs:
            test_acc_mean = torch.stack([x["test_acc"] for x in self.test_step_outputs]).mean()
            self.log("test_acc_epoch", test_acc_mean, prog_bar=True)
        self.test_step_outputs.clear()


In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary, EarlyStopping
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import torch
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv
import datetime

AVAIL_GPUS = 0 
MAX_EPOCHS = 200 
NUM_TRIALS = 25 
# NOTE(review): this selection is not used below — the trials are built with
# SGConv. Confirm which layer is intended.
GNN_conv_layer=WeightedSAGEConv 

# One definition of the model hyper-parameters, shared by model construction
# and checkpoint loading so the two cannot drift apart.
model_kwargs = dict(
    model_name="SGConvNet",
    num_features=num_features,
    hidden_channels=[128],
    num_classes=num_classes,
    hidden_dense=64,
    GNN_conv_layer=SGConv,
    dropout_rate=0.4,
)

# Loop through trials
for trial in range(1, NUM_TRIALS + 1):
    # NOTE(review): run names carry a 'gat_' prefix although the model is SGConv.
    model_name = f'gat_{datetime.datetime.today().strftime("%Y-%m-%d")}_trial_{trial}'
    logger = WandbLogger(name=model_name, project="", log_model="all") #Project name

    try:
        model = LitGNN(**model_kwargs)

        # The whole graph is a single batch for both training and validation;
        # the node masks stored on `data` distinguish the splits.
        train_data_loader = DataLoader([data], batch_size=1, num_workers=3)
        val_data_loader = DataLoader([data], batch_size=1, num_workers=3)

        # Keep the weights of the epoch with the highest validation accuracy.
        checkpoint_callback = ModelCheckpoint(
            save_weights_only=True,
            mode="max", 
            monitor="val_acc",
            dirpath=f"checkpoints/trial_{trial}",
            filename="{epoch:02d}-{val_acc:.2f}"
        )

        trainer = pl.Trainer(
            callbacks=[
                checkpoint_callback, 
                EarlyStopping(monitor="val_acc", patience=20, verbose=True, mode="max"),
                ModelSummary(max_depth=3)
            ],
            devices=1, 
            accelerator="cpu",
            max_epochs=MAX_EPOCHS,
            logger=logger,
        )

        trainer.fit(model, train_dataloaders=train_data_loader, val_dataloaders=val_data_loader)

        best_model_path = checkpoint_callback.best_model_path
        if best_model_path:
            model = LitGNN.load_from_checkpoint(best_model_path, **model_kwargs)
            print(f"Trial {trial}: Loaded model from checkpoint: {best_model_path}")
        else:
            print(f"Trial {trial}: No checkpoint found. Training from scratch.")
    finally:
        # Close this trial's wandb run. Otherwise the next WandbLogger reuses
        # the still-active run and all trials are logged into a single run.
        wandb.finish()


In [None]:
import numpy as np
import torch
from sklearn.metrics import (
    classification_report,
    roc_auc_score,
    average_precision_score
)


def to_numpy(x):
    """Return `x` as a NumPy array.

    NumPy arrays are passed through untouched; any other input is detached,
    moved to the CPU and converted with ``.numpy()``.
    """
    return x if isinstance(x, np.ndarray) else x.detach().cpu().numpy()


import glob
import os


def _latest_checkpoint(trial):
    """Return the most recently written .ckpt in checkpoints/trial_<trial>, or None."""
    candidates = glob.glob(os.path.join("checkpoints", f"trial_{trial}", "*.ckpt"))
    return max(candidates, key=os.path.getmtime) if candidates else None


def _split_metrics(y_true, split_logits, split_probs):
    """Metrics for one node split.

    Returns (accuracy, precision, recall, f1, auroc, auprc); precision, recall
    and f1 refer to the positive class. AUROC / AUPRC are NaN when the split
    contains a single class only.
    """
    y_np = to_numpy(y_true)
    report = classification_report(
        y_np,
        to_numpy(split_logits.argmax(dim=1)),
        labels=[0, 1],
        target_names=["negative", "positive"],
        output_dict=True,
        zero_division=0
    )

    if len(np.unique(y_np)) > 1:
        scores = to_numpy(split_probs)
        auroc = roc_auc_score(y_np, scores)
        auprc = average_precision_score(y_np, scores)
    else:
        auroc, auprc = np.nan, np.nan

    positive = report["positive"]
    return (report["accuracy"], positive["precision"], positive["recall"],
            positive["f1-score"], auroc, auprc)


train_accuracies, train_precisions, train_recalls, train_f1s = [], [], [], []
test_accuracies, test_precisions, test_recalls, test_f1s = [], [], [], []

train_aucs, test_aucs = [], []
train_auprs, test_auprs = [], []


for trial in range(1, NUM_TRIALS + 1):
    print(f"\nEvaluating Trial {trial}")

    # Evaluate the model saved for THIS trial. Using the in-memory `model`
    # would score the last trained model NUM_TRIALS times and make the
    # reported standard deviation meaningless.
    ckpt_path = _latest_checkpoint(trial)
    if ckpt_path is None:
        print(f"Trial {trial}: no checkpoint found in checkpoints/trial_{trial}; skipping.")
        continue

    trial_model = LitGNN.load_from_checkpoint(
        ckpt_path,
        model_name="SGConvNet",
        num_features=num_features,
        hidden_channels=[128],
        num_classes=num_classes,
        hidden_dense=64,
        GNN_conv_layer=SGConv,
        dropout_rate=0.4
    )

    device = next(trial_model.parameters()).device
    data = data.to(device)

    trial_model.eval()
    with torch.no_grad():
        logits = trial_model(
            data.x,
            data.edge_index,
            data.edge_attr
        )

    probs = torch.softmax(logits, dim=1)

    # Column 1 of the softmax output is the probability of the positive class.
    train_mask = data.train_mask
    acc, prec, rec, f1, auroc, auprc = _split_metrics(
        data.y[train_mask], logits[train_mask], probs[train_mask][:, 1]
    )
    train_accuracies.append(acc)
    train_precisions.append(prec)
    train_recalls.append(rec)
    train_f1s.append(f1)
    train_aucs.append(auroc)
    train_auprs.append(auprc)

    test_mask = data.test_mask
    acc, prec, rec, f1, auroc, auprc = _split_metrics(
        data.y[test_mask], logits[test_mask], probs[test_mask][:, 1]
    )
    test_accuracies.append(acc)
    test_precisions.append(prec)
    test_recalls.append(rec)
    test_f1s.append(f1)
    test_aucs.append(auroc)
    test_auprs.append(auprc)


if not train_accuracies:
    print("\nNo trial could be evaluated (no checkpoints found).")
else:
    print("\nAggregated Metrics (Mean ± Std)")

    print(f"Train Accuracy : {np.mean(train_accuracies):.4f} ± {np.std(train_accuracies):.4f}")
    print(f"Train Precision: {np.mean(train_precisions):.4f} ± {np.std(train_precisions):.4f}")
    print(f"Train Recall   : {np.mean(train_recalls):.4f} ± {np.std(train_recalls):.4f}")
    print(f"Train F1       : {np.mean(train_f1s):.4f} ± {np.std(train_f1s):.4f}")
    print(f"Train AUROC    : {np.nanmean(train_aucs):.4f} ± {np.nanstd(train_aucs):.4f}")
    print(f"Train AUPRC    : {np.nanmean(train_auprs):.4f} ± {np.nanstd(train_auprs):.4f}")

    print(f"Test Accuracy  : {np.mean(test_accuracies):.4f} ± {np.std(test_accuracies):.4f}")
    print(f"Test Precision : {np.mean(test_precisions):.4f} ± {np.std(test_precisions):.4f}")
    print(f"Test Recall    : {np.mean(test_recalls):.4f} ± {np.std(test_recalls):.4f}")
    print(f"Test F1        : {np.mean(test_f1s):.4f} ± {np.std(test_f1s):.4f}")
    print(f"Test AUROC     : {np.nanmean(test_aucs):.4f} ± {np.nanstd(test_aucs):.4f}")
    print(f"Test AUPRC     : {np.nanmean(test_auprs):.4f} ± {np.nanstd(test_auprs):.4f}")
