In [1]:
## Import required libraries and modules
import sys
import os
import logging
import pandas as pd
import importlib

# Add the src directory to the Python path so the project modules (utils,
# preprocess) can be imported. Guarded so that re-running this cell does not
# keep appending duplicate entries to sys.path.
SRC_DIR = os.path.abspath(os.path.join("..", "src"))
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)


from utils import load_config
from preprocess import split_data

# Load the project configuration; later cells read file locations from
# config["data_paths"][...].
config = load_config("../config.yaml")

# Configure root logging: DEBUG and above, formatted as
# "<timestamp> - <LEVEL> - <message>".
# NOTE(review): logging.basicConfig is a no-op when the root logger already has
# handlers (e.g. when this cell is re-run in the same kernel) — confirm the
# format actually takes effect in your environment.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

In [8]:
from utils import load_sampled_data

# NOTE(review): the frame is named `landmark_df` but is read from the
# "preprocessed_best_inferred_file" entry — confirm which dataset is intended.
landmark_df = load_sampled_data(config["data_paths"]["preprocessed_best_inferred_file"])

In [9]:
# Drop the viability, cell_mfc_name and pert_dose columns, then display the
# remaining column labels.
landmark_df = landmark_df.drop(columns=["viability", "cell_mfc_name", "pert_dose"])

landmark_df.columns

Index(['AARS', 'ABCF1', 'ABL1', 'ACAA1', 'ACAT2', 'ACLY', 'ADAM10', 'ADH5',
       'PARP1', 'ADRB2',
       ...
       'CCR2', 'TMEM242', 'SMIM27', 'ARMCX4', 'NBPF10', 'TIMM23', 'ZNF783',
       'MICA', 'TMEM257', 'C10orf12'],
      dtype='object', length=10174)

In [10]:
import pandas as pd

# Build an alphabetical index over the columns of `landmark_df` and write it to
# disk, one "index<TAB>gene" line per column.

# Extract the gene symbols from the DataFrame columns
gene_list = list(landmark_df.columns)

# Sort the gene list alphabetically
sorted_genes = sorted(gene_list)

# Create a mapping dictionary: index -> gene
# (note: despite its name, `gene2ind` is keyed by index, not by gene)
gene2ind = {idx: gene for idx, gene in enumerate(sorted_genes)}

# Write the mapping to a file in the required format ("index<TAB>gene")
mapping_path = "best_inferred2ind.txt"
with open(mapping_path, "w") as f:
    for idx, gene in gene2ind.items():
        f.write(f"{idx}\t{gene}\n")

# Report the file that was actually written (the old message named a different file).
print(f"Mapping written to {mapping_path}")

Mapping written to gene2ind.txt


In [None]:
## Load, Split, and Preprocess Dataset
# Configurable parameters
sample_size = 1000  # Number of rows to sample from each dataset
chunk_size = 1000  # Rows per chunk when a dataset is loaded in chunks

# Load datasets
logging.info("Loading datasets with sampling...")


def load_sampled_data(
    file_path, sample_size, use_chunks=False, chunk_size=None, random_state=None
):
    """
    Load at most `sample_size` rows of a CSV file.

    Two strategies are available:

    * ``use_chunks=False``: read the first `sample_size` rows of the file
      (no shuffling; this is a head-of-file read).
    * ``use_chunks=True``: stream the file chunk by chunk and randomly sample
      rows from each chunk until `sample_size` rows are collected. Only the
      leading chunks needed to reach `sample_size` are read, so the result is
      drawn from the start of the file, shuffled within each chunk.

    Args:
        file_path (str): Path to the dataset file.
        sample_size (int): Number of rows to sample.
        use_chunks (bool): Whether to load the dataset in chunks.
        chunk_size (int, optional): Size of chunks if `use_chunks` is True.
            Defaults to `sample_size` when omitted.
        random_state (int, optional): Seed forwarded to ``DataFrame.sample``
            for each chunk, making the chunked sampling reproducible. The
            default (None) keeps the sampling unseeded.

    Returns:
        pd.DataFrame: Sampled DataFrame (empty, with the file's columns, when
        no rows could be sampled).

    Raises:
        ValueError: If `sample_size` is negative in chunked mode.
    """
    if not use_chunks:
        logging.info(f"Sampling {sample_size} rows from {file_path}...")
        return pd.read_csv(file_path, nrows=sample_size)

    logging.info(f"Loading {file_path} in chunks...")
    if sample_size < 0:
        raise ValueError(f"sample_size must be non-negative, got {sample_size}")
    if sample_size == 0:
        # Nothing to sample: return the header only, like nrows=0 would.
        return pd.read_csv(file_path, nrows=0)
    if chunk_size is None:
        # chunksize=None would make read_csv return a DataFrame, not a reader.
        chunk_size = sample_size

    chunks = []
    total_loaded = 0
    # The context manager closes the file even though we stop reading early.
    with pd.read_csv(file_path, chunksize=chunk_size) as reader:
        for chunk in reader:
            # Determine how many rows to sample from this chunk
            sample_rows = min(sample_size - total_loaded, len(chunk))
            chunks.append(chunk.sample(sample_rows, random_state=random_state))
            total_loaded += sample_rows
            if total_loaded >= sample_size:
                break

    if not chunks:
        # pd.concat([]) raises; fall back to an empty frame with the columns.
        return pd.read_csv(file_path, nrows=0)

    return pd.concat(chunks, axis=0)


# Load each preprocessed dataset, capped at `sample_size` rows.
_data_paths = config["data_paths"]

tf_df = load_sampled_data(_data_paths["preprocessed_tf_file"], sample_size)
landmark_df = load_sampled_data(_data_paths["preprocessed_landmark_file"], sample_size)
best_inferred_df = load_sampled_data(
    _data_paths["preprocessed_best_inferred_file"], sample_size
)

# The gene-level file is the only one read through the chunked path.
gene_df = load_sampled_data(
    _data_paths["preprocessed_gene_file"],
    sample_size,
    use_chunks=True,
    chunk_size=chunk_size,
)

# Split Data
logging.info("Splitting datasets into train/val/test...")


def _split_dataset(df):
    """Run split_data with the settings shared by all four datasets.

    Every dataset uses "viability" as the target and is stratified by
    "cell_mfc_name"; the six values returned by split_data are passed through.
    """
    return split_data(
        df, target_name="viability", config=config, stratify_by="cell_mfc_name"
    )


X_tf_train, y_tf_train, X_tf_val, y_tf_val, X_tf_test, y_tf_test = _split_dataset(
    tf_df
)

(
    X_landmark_train,
    y_landmark_train,
    X_landmark_val,
    y_landmark_val,
    X_landmark_test,
    y_landmark_test,
) = _split_dataset(landmark_df)

(
    X_best_inferred_train,
    y_best_inferred_train,
    X_best_inferred_val,
    y_best_inferred_val,
    X_best_inferred_test,
    y_best_inferred_test,
) = _split_dataset(best_inferred_df)

(
    X_gene_train,
    y_gene_train,
    X_gene_val,
    y_gene_val,
    X_gene_test,
    y_gene_test,
) = _split_dataset(gene_df)

In [None]:
import networkx as nx


def load_ontology(file_path):
    """
    Read a tab-separated ontology file into a directed graph.

    Each non-blank line must hold exactly three tab-separated fields:
    parent, child and relationship type. Every line becomes an edge
    parent -> child whose "relationship" attribute stores the third field.

    Args:
        file_path (str): Path to the ontology file.

    Returns:
        nx.DiGraph: Graph with one edge per ontology line.

    Raises:
        ValueError: If a non-blank line does not have exactly three fields;
            the message names the file and line number.
    """
    dG = nx.DiGraph()
    with open(file_path, "r") as f:
        for line_no, line in enumerate(f, start=1):
            record = line.strip()
            if not record:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue

            fields = record.split("\t")
            if len(fields) != 3:
                raise ValueError(
                    f"{file_path}:{line_no}: expected 3 tab-separated fields "
                    f"(parent, child, relationship), got {len(fields)}"
                )

            parent, child, rel_type = fields
            dG.add_edge(parent, child, relationship=rel_type)
    return dG


# Load the ontology into a directed graph (edges point parent -> child)
ontology_file = "../data/raw/drugcell_ont.txt"
dG = load_ontology(ontology_file)

# Check the structure: report how many nodes and edges were read
print(f"Number of nodes: {len(dG.nodes())}")
print(f"Number of edges: {len(dG.edges())}")

In [None]:
# Leaf nodes of the ontology (no outgoing edges) are treated as its genes.
genes_in_ontology = {node for node, out_deg in dG.out_degree() if out_deg == 0}

# Dataset columns that do not appear among those leaves.
# NOTE(review): gene_df.columns may also hold non-gene columns (e.g. the
# target), which would be counted as "missing" here — confirm.
missing_genes = set(gene_df.columns).difference(genes_in_ontology)
if missing_genes:
    print(f"Warning: {len(missing_genes)} genes are missing from the ontology.")

# Number of dataset columns that were found in the ontology.
print(len(gene_df.columns) - len(missing_genes))

In [None]:
import networkx as nx


def filter_ontology(dG, valid_genes):
    """
    Filter the ontology to keep only terms that contain genes from the dataset.

    An edge parent -> child is kept when `child` is itself a valid gene or has
    at least one valid gene among its descendants. The set of such nodes is
    computed once with a single upward traversal from the valid genes, instead
    of recomputing the descendants of `child` for every edge.

    Args:
        dG (nx.DiGraph): Original ontology graph. It is not modified. Every
            edge must carry a "relationship" attribute.
        valid_genes (set): Set of genes present in both the dataset and the ontology.

    Returns:
        nx.DiGraph: Updated ontology graph with only relevant genes.
    """
    # Nodes that are a valid gene or can reach one by following edges.
    reaches_gene = {gene for gene in valid_genes if gene in dG}
    frontier = list(reaches_gene)
    while frontier:
        node = frontier.pop()
        for parent in dG.predecessors(node):
            if parent not in reaches_gene:
                reaches_gene.add(parent)
                frontier.append(parent)

    filtered_dG = nx.DiGraph()
    for parent, child, data in dG.edges(data=True):
        if child in reaches_gene:
            filtered_dG.add_edge(parent, child, relationship=data["relationship"])

    return filtered_dG


# Genes shared by the ontology (its leaf nodes) and the dataset columns
genes_in_ontology = {node for node, out_deg in dG.out_degree() if out_deg == 0}
valid_genes = genes_in_ontology & set(gene_df.columns)

# Keep only the part of the ontology that leads to those genes
filtered_dG = filter_ontology(dG, valid_genes)

# Report the size of the filtered graph
print(
    f"Updated Ontology: {len(filtered_dG.nodes())} nodes, {len(filtered_dG.edges())} edges"
)

In [8]:
def build_term_size_map(dG, valid_genes):
    """
    Count, for every node of the graph, how many of its descendants are valid genes.

    Args:
        dG (nx.DiGraph): Filtered ontology graph.
        valid_genes (set): Set of genes available in the dataset.

    Returns:
        dict: One entry per node of `dG`, mapping the node to the number of
        its descendants that belong to `valid_genes` (0 for nodes without any).
    """
    return {
        node: sum(1 for desc in nx.descendants(dG, node) if desc in valid_genes)
        for node in dG.nodes()
    }


# Generate term_size_map using filtered ontology
# (one entry per node of filtered_dG: its number of descendants in valid_genes)
term_size_map = build_term_size_map(filtered_dG, valid_genes)


def build_term_direct_gene_map(dG, valid_genes):
    """
    List, for every node of the graph, its immediate successors that are valid genes.

    Args:
        dG (nx.DiGraph): Filtered ontology graph.
        valid_genes (set): Set of genes available in the dataset.

    Returns:
        dict: One entry per node of `dG`, mapping the node to the list of its
        direct successors found in `valid_genes` (an empty list when none).
    """
    return {
        node: [child for child in dG.successors(node) if child in valid_genes]
        for node in dG.nodes()
    }


# Generate term_direct_gene_map using filtered ontology
# (one entry per node of filtered_dG: its direct successors found in valid_genes)
term_direct_gene_map = build_term_direct_gene_map(filtered_dG, valid_genes)

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx


class GeneExpressionDrugCell(nn.Module):
    """
    DrugCell-like network whose wiring follows an ontology graph.

    Every ontology term gets a linear layer followed by tanh and batch norm.
    A term's input is the concatenation of the outputs of its child terms and
    the expression values of the genes directly attached to it. The output of
    the root term feeds a small prediction head producing one value per sample.

    Gene nodes (every node listed as a direct gene of some term) are inputs,
    not modules: their values are read from the columns of the input matrix.
    """

    def __init__(
        self,
        term_size_map,
        term_direct_gene_map,
        dG,
        ngene,
        num_hiddens_genotype,
        num_hiddens_final,
        gene_index_map=None,
    ):
        """
        Custom DrugCell-like model using gene expression data instead of mutations.

        Args:
            term_size_map (dict): Mapping of terms to gene counts.
            term_direct_gene_map (dict): Mapping of terms to directly associated genes.
            dG (nx.DiGraph): Filtered ontology graph. It is only read, never
                modified. Only `nodes()` and `successors()` are used.
            ngene (int): Number of input genes (columns of the input matrix).
            num_hiddens_genotype (int): Number of hidden neurons per biological term.
            num_hiddens_final (int): Number of neurons in the final prediction layer.
            gene_index_map (dict, optional): Mapping gene -> column index in
                the input matrix. Defaults to the alphabetical rank of each
                gene among all genes found in `term_direct_gene_map`.

        Raises:
            ValueError: If a gene has no valid column index, a term has no
                input at all, the graph has no terms, or the graph is cyclic.
        """
        super(GeneExpressionDrugCell, self).__init__()

        self.term_direct_gene_map = term_direct_gene_map
        self.gene_dim = ngene

        # Every node that is a direct gene of some term is an input feature.
        self.gene_nodes = {
            gene for genes in term_direct_gene_map.values() for gene in genes
        }

        if gene_index_map is None:
            gene_index_map = {
                gene: idx for idx, gene in enumerate(sorted(self.gene_nodes))
            }
        self.gene_index_map = dict(gene_index_map)
        self._validate_gene_indices()

        # term -> column indices of its direct genes (only terms that have some)
        self.term_gene_index_map = {
            term: [self.gene_index_map[gene] for gene in genes]
            for term, genes in term_direct_gene_map.items()
            if genes and term not in self.gene_nodes
        }

        self.cal_term_dim(term_size_map, num_hiddens_genotype)

        # Construct hierarchical neural network from ontology
        self.construct_gene_network(dG)

        # Final prediction layers, fed by the root term's output
        self.add_module(
            "final_linear_layer",
            nn.Linear(self.term_dim_map[self.root_term], num_hiddens_final),
        )
        self.add_module("final_batchnorm_layer", nn.BatchNorm1d(num_hiddens_final))
        self.add_module("final_output_layer", nn.Linear(num_hiddens_final, 1))

    def _validate_gene_indices(self):
        """Check that every gene has a column index inside [0, gene_dim)."""
        missing = sorted(g for g in self.gene_nodes if g not in self.gene_index_map)
        if missing:
            raise ValueError(
                f"gene_index_map has no entry for {len(missing)} gene(s), "
                f"e.g. {missing[0]!r}"
            )
        out_of_range = sorted(
            g
            for g in self.gene_nodes
            if not 0 <= self.gene_index_map[g] < self.gene_dim
        )
        if out_of_range:
            raise ValueError(
                f"{len(out_of_range)} gene(s) map to a column outside "
                f"[0, {self.gene_dim}), e.g. {out_of_range[0]!r}"
            )

    def cal_term_dim(self, term_size_map, num_hiddens_genotype):
        """Calculate hidden layer sizes for each biological term.

        Every key of `term_size_map` receives the same width,
        `num_hiddens_genotype`; the sizes stored in the map are not used.
        """
        self.term_dim_map = {}
        num_output = int(num_hiddens_genotype)
        for term in term_size_map:
            self.term_dim_map[term] = num_output

    def construct_gene_network(self, dG):
        """Create layers following the hierarchical ontology structure.

        Terms are grouped into layers from the bottom up: a term is placed as
        soon as all of its child terms have been placed. The graph is only
        read, so the caller's `dG` is left intact.
        """
        # term -> child terms; gene nodes are excluded because they are read
        # straight from the input matrix rather than computed by a module.
        self.term_neighbor_map = {
            term: [c for c in dG.successors(term) if c not in self.gene_nodes]
            for term in dG.nodes()
            if term not in self.gene_nodes
        }

        self.term_layer_list = []
        placed = set()
        pending = set(self.term_neighbor_map)
        while pending:
            ready = [
                term
                for term, children in self.term_neighbor_map.items()
                if term in pending and all(child in placed for child in children)
            ]
            if not ready:
                raise ValueError(
                    "Ontology graph contains a cycle; terms cannot be ordered."
                )

            for term in ready:
                input_size = sum(
                    self.term_dim_map[child] for child in self.term_neighbor_map[term]
                )
                input_size += len(self.term_gene_index_map.get(term, []))
                if input_size == 0:
                    raise ValueError(
                        f"Term {term!r} has no child terms and no direct genes; "
                        "filter the ontology before building the model."
                    )

                self.add_module(
                    term + "_linear_layer",
                    nn.Linear(input_size, self.term_dim_map[term]),
                )
                self.add_module(
                    term + "_batchnorm_layer", nn.BatchNorm1d(self.term_dim_map[term])
                )

            self.term_layer_list.append(ready)
            placed.update(ready)
            pending.difference_update(ready)

        if not self.term_layer_list:
            raise ValueError("Ontology graph contains no terms.")

        # The first term of the topmost layer is used as the root.
        # NOTE(review): with several parentless terms only this one reaches the
        # prediction head — confirm the ontology has a single root.
        self.root_term = self.term_layer_list[-1][0]

    def forward(self, x):
        """Define forward pass for gene expression input.

        Args:
            x (torch.Tensor): Matrix of shape (batch, ngene); column i holds
                the gene whose index in `gene_index_map` is i.

        Returns:
            torch.Tensor: Predictions of shape (batch, 1).

        Raises:
            ValueError: If `x` is not 2-D with `ngene` columns.
        """
        if x.dim() != 2 or x.size(1) != self.gene_dim:
            raise ValueError(
                f"Expected input of shape (batch, {self.gene_dim}), "
                f"got {tuple(x.shape)}"
            )

        term_outputs = {}

        # Propagate through the hierarchical model, bottom layer first so that
        # every child output exists before its parent needs it.
        for layer in self.term_layer_list:
            for term in layer:
                inputs = [term_outputs[child] for child in self.term_neighbor_map[term]]
                gene_columns = self.term_gene_index_map.get(term)
                if gene_columns:
                    inputs.append(x[:, gene_columns])

                term_input = torch.cat(inputs, 1)
                term_output = torch.tanh(
                    self._modules[term + "_linear_layer"](term_input)
                )
                term_outputs[term] = self._modules[term + "_batchnorm_layer"](
                    term_output
                )

        # Final layer to predict viability
        final_output = self._modules["final_batchnorm_layer"](
            torch.tanh(
                self._modules["final_linear_layer"](term_outputs[self.root_term])
            )
        )
        output = self._modules["final_output_layer"](final_output)
        return output

In [None]:
import torch.optim as optim

# Restrict the features to the genes the model is sized for, in a fixed
# alphabetical order so that column i always refers to the same gene.
# NOTE(review): assumes every gene in valid_genes is still a column of
# X_gene_train after split_data, and that the model reads genes in this same
# alphabetical order — confirm both.
gene_columns = sorted(valid_genes)

# Build the model
model = GeneExpressionDrugCell(
    term_size_map=term_size_map,
    term_direct_gene_map=term_direct_gene_map,
    # Pass a copy so that building the model cannot alter filtered_dG.
    dG=filtered_dG.copy(),
    ngene=len(gene_columns),  # Number of genes fed to the model
    num_hiddens_genotype=128,
    num_hiddens_final=64,
)

# Define loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Convert the training data once instead of on every epoch.
X_train_tensor = torch.tensor(X_gene_train[gene_columns].values, dtype=torch.float32)
# unsqueeze(1) makes the target (N, 1), matching the model output.
y_train_tensor = torch.tensor(y_gene_train.values, dtype=torch.float32).unsqueeze(1)

# Training loop (full-batch: one optimizer step per epoch)
model.train()
for epoch in range(50):
    optimizer.zero_grad()
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")