In [1]:
# import h5py
# import pandas as pd
# import numpy as np

# # Step 1: Load the data matrix and metadata from the .gctx file
# gctx_file = "../data/raw/level3_beta_ctl_n188708x12328.gctx"
# h5_file = "../data/raw/compound_control_dataset_all.h5"

# chunk_size = 10000  # Number of rows to process at a time

# with h5py.File(gctx_file, "r") as f:
#     # Access data matrix and metadata
#     data = f["/0/DATA/0/matrix"]
#     row_ids = f["/0/META/COL/id"][:].astype(str)  # All sample IDs (rows)
#     col_ids = f["/0/META/ROW/id"][:].astype(int)  # All gene IDs (columns)

#     # Load metadata files
#     gene_metadata = pd.read_csv("../data/raw/geneinfo_beta.txt", sep="\t")
#     pert_info = pd.read_csv("../data/raw/instinfo_beta.txt", sep="\t")

#     # Filter perturbation metadata
#     filtered_pert_info = pert_info[
#         (pert_info["qc_pass"] == 1.0) & (pert_info["pert_type"] == "ctl_vehicle")
#     ].copy()
#     filtered_pert_info.set_index("sample_id", inplace=True)

#     # Filter valid sample IDs
#     valid_sample_ids = pd.Index(row_ids).intersection(filtered_pert_info.index)
#     filtered_pert_info = filtered_pert_info.loc[valid_sample_ids]
#     filtered_pert_info.sort_index(inplace=True)

#     # Create a mapping of sample IDs to row indices
#     row_idx_map = {rid: idx for idx, rid in enumerate(row_ids)}

#     # Compute row indices for valid samples
#     filtered_row_indices = [row_idx_map[sid] for sid in valid_sample_ids]

# # Step 2: Write the filtered data matrix and metadata in chunks
# with h5py.File(h5_file, "w") as h5f:
#     # Pre-allocate space for the dataset in the HDF5 file
#     data_shape = (len(filtered_row_indices), len(col_ids))
#     data_dset = h5f.create_dataset(
#         "data", shape=data_shape, dtype="float32", compression="gzip"
#     )

#     # Write row (sample) metadata
#     for column in filtered_pert_info.columns:
#         h5f.create_dataset(
#             f"row_metadata/{column}",
#             data=filtered_pert_info[column].values.astype("S"),
#         )

#     # Write column (gene) metadata
#     for column in gene_metadata.columns:
#         h5f.create_dataset(
#             f"col_metadata/{column}",
#             data=gene_metadata[column].values.astype("S"),
#         )

#     # Write row and column IDs
#     h5f.create_dataset("row_ids", data=valid_sample_ids.values.astype("S"))
#     gene_mapping = gene_metadata.set_index("gene_id")["gene_symbol"].to_dict()
#     filtered_col_symbols = [gene_mapping.get(gid, f"Gene_{gid}") for gid in col_ids]
#     h5f.create_dataset(
#         "col_ids", data=[x.encode("utf-8") for x in filtered_col_symbols]
#     )

#     # Process and write data in chunks
#     for i in range(0, len(filtered_row_indices), chunk_size):
#         chunk_indices = filtered_row_indices[i : i + chunk_size]  # Current chunk
#         chunk_data = data[chunk_indices, :]  # Load chunk from the data matrix
#         data_dset[i : i + len(chunk_indices), :] = chunk_data  # Write to dataset

# print(f"Filtered data and metadata successfully written to {h5_file}.")

In [None]:
import h5py
import pandas as pd

# Step 1: Load the data matrix and metadata from the .gctx file
gctx_file = "../data/raw/level3_beta_trt_cp_n1805898x12328.gctx"

with h5py.File(gctx_file, "r") as f:
    # Access the entire data matrix and metadata
    data = f["/0/DATA/0/matrix"]
    row_ids = f["/0/META/COL/id"][:].astype(str)  # Sample IDs (rows)
    col_ids = f["/0/META/ROW/id"][:].astype(int)  # Gene IDs (columns)

    # Load metadata files
    gene_metadata = pd.read_csv("../data/raw/geneinfo_beta.txt", sep="\t")
    pert_info = pd.read_csv("../data/raw/compound_perturbation_metadata.txt", sep="\t")

    # Filter perturbation metadata
    filtered_pert_info = pert_info[
        (pert_info["qc_pass"] == 1.0) & (pert_info["pert_type"] == "trt_cp")
    ].copy()
    filtered_pert_info.set_index("sample_id", inplace=True)

    # STEP A: Reduce filtered_pert_info to match row_ids
    subset_row_ids = pd.Index(row_ids)
    subset_row_ids = subset_row_ids.intersection(filtered_pert_info.index)
    subset_pert_info = filtered_pert_info.loc[subset_row_ids].copy()
    subset_pert_info.sort_index(inplace=True)

    # Keep a separate column for sample_id (while retaining it as index)
    subset_pert_info["sample_id"] = subset_pert_info.index

    # STEP B: Recompute the row indices in the data matrix for valid sample_ids
    row_idx_map = {rid: idx for idx, rid in enumerate(row_ids)}
    filtered_row_indices = [row_idx_map[sid] for sid in subset_pert_info.index]

    # Filter the data matrix to match the final subset of row indices
    filtered_data_matrix = data[filtered_row_indices, :]

# Step 2: Map gene IDs to gene symbols
gene_mapping = gene_metadata.set_index("gene_id")["gene_symbol"].to_dict()
filtered_col_symbols = [gene_mapping.get(gid, f"Gene_{gid}") for gid in col_ids]

# Step 3: Write the filtered data and metadata to an HDF5 file
h5_file = "../data/raw/compound_perturbation_dataset.h5"

rows, cols = filtered_data_matrix.shape
# Decide on a chunk shape
chunk_rows = 128  # e.g. 128 rows per chunk
chunk_cols = cols  # store entire row block in each chunk, or you could do smaller
chunk_shape = (chunk_rows, chunk_cols)

with h5py.File(h5_file, "w") as h5f:
    # Write the filtered data matrix
    h5f.create_dataset("data", data=filtered_data_matrix, compression="gzip", chunks=chunk_shape)

    # Write row (sample) metadata
    for column in subset_pert_info.columns:
        h5f.create_dataset(
            f"row_metadata/{column}",
            data=subset_pert_info[column].values.astype("S"),
        )

    # Write column (gene) metadata
    for column in gene_metadata.columns:
        h5f.create_dataset(
            f"col_metadata/{column}",
            data=gene_metadata[column].values.astype("S"),
        )

    # Write final row and column IDs
    h5f.create_dataset("row_ids", data=subset_pert_info.index.values.astype("S"))
    h5f.create_dataset(
        "col_ids", data=[x.encode("utf-8") for x in filtered_col_symbols]
    )

print(f"Filtered data and metadata successfully written to {h5_file}.")

Filtered data and metadata successfully written to ../data/raw/compound_control_dataset.h5.


In [1]:
import h5py
import numpy as np
import pandas as pd

# Example paths
gctx_file = "../data/raw/level3_beta_trt_cp_n1805898x12328.gctx"
h5_file = "../data/raw/compound_perturbation_dataset.h5"

# Step 1: Read and filter row metadata, then determine which row indices to keep
with h5py.File(gctx_file, "r") as f_in:
    # Original row IDs
    row_ids = f_in["/0/META/COL/id"][:].astype(str)

# Load external metadata (e.g., geneinfo, pert_info) as needed
gene_metadata = pd.read_csv("../data/raw/geneinfo_beta.txt", sep="\t")
pert_info = pd.read_csv("../data/raw/compound_perturbation_metadata.txt", sep="\t")

filtered_pert_info = pert_info[
    (pert_info["qc_pass"] == 1.0) & (pert_info["pert_type"] == "trt_cp")
].copy()
filtered_pert_info.set_index("sample_id", inplace=True)

subset_row_ids = pd.Index(row_ids).intersection(filtered_pert_info.index)
subset_pert_info = filtered_pert_info.loc[subset_row_ids].copy()
subset_pert_info.sort_index(inplace=True)
subset_pert_info["sample_id"] = subset_pert_info.index

row_idx_map = {rid: idx for idx, rid in enumerate(row_ids)}
filtered_row_indices = [row_idx_map[sid] for sid in subset_pert_info.index]

# Step 2: Determine column subset or read all columns
# (If you do partial columns, e.g. for landmark genes, define 'col_indices' here.
# Otherwise, read all columns.)
col_ids_all = None
col_symbols_all = None
with h5py.File(gctx_file, "r") as f_in:
    col_ids_all = f_in["/0/META/ROW/id"][:].astype(int)

gene_mapping = gene_metadata.set_index("gene_id")["gene_symbol"].to_dict()
col_symbols_all = [gene_mapping.get(gid, f"Gene_{gid}") for gid in col_ids_all]
col_indices = np.arange(len(col_ids_all))  # or define a subset for partial columns

# Step 3: Create a chunked output dataset, define chunk shape
# We read in chunks from the big file and write them chunk by chunk.
chunk_rows_read = 50000  # number of rows to read from the big file at once

# We'll define the final shape for the output.
# Number of rows is len(filtered_row_indices), columns is len(col_indices).
n_rows_out = len(filtered_row_indices)
n_cols_out = len(col_indices)

# We'll set a chunk shape for the output as well (for final chunked HDF5).
chunk_shape_out = (128, n_cols_out)

# Step 4: Reading and writing chunk by chunk
with h5py.File(gctx_file, "r") as f_in, h5py.File(h5_file, "w") as f_out:
    data_in = f_in["/0/DATA/0/matrix"]
    # Create the output dataset in chunked form
    dset_out = f_out.create_dataset(
        "data",
        shape=(n_rows_out, n_cols_out),
        dtype="float32",
        compression="gzip",
        chunks=chunk_shape_out,
    )

    # We'll keep track of how many rows we've already written in the output
    out_row_start = 0
    total_rows = n_rows_out

    for start_idx in range(0, total_rows, chunk_rows_read):
        end_idx = min(start_idx + chunk_rows_read, total_rows)
        chunk_size = end_idx - start_idx

        # Subset of global row indices for this chunk
        row_subset = filtered_row_indices[start_idx:end_idx]

        # Read those rows from the big file, and only the desired columns
        chunk_data = data_in[row_subset, :][
            :, col_indices
        ]  # shape: (chunk_size, n_cols_out)

        # Write to the new dataset
        dset_out[out_row_start : out_row_start + chunk_size, :] = chunk_data
        out_row_start += chunk_size

    # Step 5: Write row metadata
    for column in subset_pert_info.columns:
        f_out.create_dataset(
            f"row_metadata/{column}",
            data=subset_pert_info[column].values.astype("S"),
        )

    # Step 6: Write column metadata
    # For partial columns, you might subset your gene_metadata accordingly if needed.
    for col_name in gene_metadata.columns:
        f_out.create_dataset(
            f"col_metadata/{col_name}",
            data=gene_metadata[col_name].values.astype("S"),
        )

    # Step 7: Write final row and column IDs
    f_out.create_dataset("row_ids", data=subset_pert_info.index.values.astype("S"))
    selected_col_symbols = [col_symbols_all[i] for i in col_indices]
    f_out.create_dataset(
        "col_ids", data=[x.encode("utf-8") for x in selected_col_symbols]
    )

print(f"Filtered data and metadata successfully written to {h5_file} in chunks!")

Filtered data and metadata successfully written to ../data/raw/compound_perturbation_dataset.h5 in chunks!


In [1]:
import pandas as pd


sample_meta = pd.read_csv("../data/raw/instinfo_beta.txt", sep="\t")

  sample_meta = pd.read_csv("../data/raw/instinfo_beta.txt", sep="\t")


In [2]:
# First select only the samples that have passed QC and are trt_cp (compound treatment)
sample_meta_filtered = sample_meta[
    (sample_meta["qc_pass"] == 1.0)
    & (
        (sample_meta["pert_type"] == "ctl_vehicle")
    )
]

# Now check the value counts for the pert_dose_unit
sample_meta_filtered["pert_dose_unit"].value_counts()

pert_dose_unit
Unknown    112
Name: count, dtype: int64

In [3]:
len(sample_meta_filtered)

76742

In [4]:
# Set the pert_dose_unit to "uM" for all samples and the pert_dose to 0.0
sample_meta_filtered["pert_dose_unit"] = "uM"
sample_meta_filtered["pert_dose"] = 0.0

# Check the value counts for the pert_dose_unit
sample_meta_filtered["pert_dose_unit"].value_counts()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  sample_meta_filtered["pert_dose_unit"] = "uM"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  sample_meta_filtered["pert_dose"] = 0.0


pert_dose_unit
uM    76742
Name: count, dtype: int64

In [5]:

# Save this filtered metadata to a new file
sample_meta_filtered.to_csv(
    "../data/raw/compound_control_metadata.txt", sep="\t", index=False
)

In [None]:
# Set all rows pert_dose to 0.0 where pert_dose_unit is "Unknown" or Nan/None types
sample_meta_filtered.loc[
    sample_meta_filtered["pert_dose_unit"].isin(["Unknown", None]), "pert_dose"
] = 0.0

# Fill missing values with "uM" and replace "Unknown" with "uM"
sample_meta_filtered["pert_dose_unit"].fillna("uM", inplace=True)
sample_meta_filtered["pert_dose_unit"].replace("Unknown", "uM", inplace=True)



In [3]:
# h5_file = "../data/raw/compound_control_dataset.h5"
# with h5py.File(h5_file, "r") as h5f:
#     # Load the data matrix
#     data_matrix = h5f["data"][:]

#     # Decode row metadata
#     row_metadata = {
#         key: h5f[f"row_metadata/{key}"][:].astype(str) for key in h5f["row_metadata"]
#     }

#     # Decode column metadata
#     col_metadata = {
#         key: h5f[f"col_metadata/{key}"][:].astype(str) for key in h5f["col_metadata"]
#     }

#     # Load and decode row and column IDs
#     row_ids = h5f["row_ids"][:].astype(str)
#     col_ids = h5f["col_ids"][:].astype(str)

# # Convert metadata to DataFrames for ease of use (optional)
# import pandas as pd

# row_metadata_df = pd.DataFrame(row_metadata)
# col_metadata_df = pd.DataFrame(col_metadata)


In [1]:
# Add src to path
import logging
import sys
import os

import torch

# Add the src directory to the Python path
sys.path.append(os.path.abspath(os.path.join("..", "src")))
from utils import create_smiles_dict

from deepchem.feat.smiles_tokenizer import SmilesTokenizer
from torch.utils.data import DataLoader, random_split
from data_sets import PerturbationDataset
import pandas as pd
from training import train_multimodal_model
import numpy as np
import matplotlib.pyplot as plt
from evaluation import evaluate_multimodal_model

device = "cuda" if torch.cuda.is_available() else "cpu"

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!


Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead


Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'lightning'
Skipped loading some Jax models, missing a dependency. No module named 'haiku'


In [2]:
# Load the datasets
controls = "../data/raw/compound_control_dataset.h5"
perturbations = "../data/raw/compound_pertubation_dataset.h5"

# Load the SMILES data and create a dictionary mapping
smiles_df = pd.read_csv("../data/raw/compoundinfo_beta.txt", sep="\t")
smiles_dict = create_smiles_dict(smiles_df)

# Load the SMILES tokenizer
vocab_file = ("../data/raw/vocab.txt" )
smiles_tokenizer = SmilesTokenizer(vocab_file=vocab_file)

# Analyze SMILES string lengths to determine an appropriate max_length
smiles_lengths = [len(smile) for smile in smiles_dict.values()]
max_length = min(max(smiles_lengths), 128)

# Check some SMILES strings to tokens
smiles = list(smiles_dict.values())
tokens = smiles_tokenizer.encode(
    smiles[13],
    add_special_tokens=True,
    max_length=max_length,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)
tokens

tensor([[12, 16, 19, 15, 20, 15, 15, 15, 15, 17, 19, 16, 18, 15, 20, 31, 15, 20,
         15, 15, 17, 25, 25, 20, 31, 15, 20, 15, 15, 25, 15, 21, 15, 15, 17, 28,
         18, 15, 15, 15, 20, 21, 18, 16, 17, 22, 19, 18, 23, 16, 20, 17, 16, 21,
         16, 16, 26, 16, 16, 17, 16, 21, 18, 16, 16, 20, 16, 26, 18, 16, 17, 19,
         18, 22, 19, 13,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]])

In [12]:
dataset = PerturbationDataset(
    controls_file=controls,
    perturbations_file=perturbations,
    smiles_dict=smiles_dict,
    plate_column="det_plate",
    normalize=True,
    n_rows=10000,
    pairing="random",
    landmark_only=True,
    tokenizer=smiles_tokenizer,
    max_smiles_length=max_length,
)

# Split the dataset into train, validation, and test sets
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

total_len = len(dataset)
train_size = int(train_ratio * total_len)
val_size = int(val_ratio * total_len)
test_size = total_len - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

# Log the sizes of the datasets
logging.debug(
    f"Train Dataset: {len(train_dataset)}, Val Dataset: {len(val_dataset)}, Test Dataset: {len(test_dataset)}"
)

In [13]:
batch_size = 256
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
)

test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
)

In [14]:
import torch
from torch import nn
class SimpleSMILESEncoder(nn.Module):
    """
    A toy SMILES encoder that one-hot encodes each character
    and processes it via a small MLP or CNN. Here we use a trivial MLP
    for demonstration.
    """

    def __init__(self, vocab, embed_dim=64, hidden_dim=128):
        """
        Args:
            vocab (str or list): A string or list of allowed SMILES characters.
            embed_dim (int): Size of the per-character embedding.
            hidden_dim (int): Size of the hidden layer in the MLP that encodes the SMILES.
        """
        super(SimpleSMILESEncoder, self).__init__()
        self.vocab = vocab
        self.vocab_size = len(self.vocab)
        self.embed_dim = embed_dim

        # A simple embedding layer for each character:
        self.embedding = nn.Embedding(self.vocab_size, embed_dim)

        # Then a small MLP to compress all characters into a single vector
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # A final pooling transformation to get a single vector
        # (In practice, you might do average/attention pooling across characters.)
        # For simplicity, we apply MLP to each character embedding and average.
        # More advanced: CNN or RNN or Transformer.

    def forward(self, smiles_batch):
        tokens, _ = self.smiles_to_indices(smiles_batch)

        # Make sure tokens are on the same device as the embedding layer
        tokens = tokens.to(self.embedding.weight.device)

        embedded = self.embedding(tokens)
        B, L, E = embedded.shape
        embedded = embedded.view(B * L, E)
        embedded = self.mlp(embedded)
        embedded = embedded.view(B, L, -1)
        embedded = embedded.mean(dim=1)
        return embedded

    def smiles_to_indices(self, smiles_batch):
        """
        Convert list of SMILES strings to a batch of indices for each character.
        We'll simply pad them to max_len in this toy example.
        """
        # Find max length
        max_len = max(len(s) for s in smiles_batch)
        token_ids = []
        lengths = []

        char_to_idx = {c: i for i, c in enumerate(self.vocab)}

        for s in smiles_batch:
            lengths.append(len(s))
            row = []
            for ch in s:
                if ch in char_to_idx:
                    row.append(char_to_idx[ch])
                else:
                    row.append(char_to_idx["?"])
            # pad
            while len(row) < max_len:
                row.append(char_to_idx[" "])
            token_ids.append(row)

        # Convert to torch tensor
        tokens = torch.tensor(token_ids, dtype=torch.long)
        return tokens, lengths


class Perturbinator(nn.Module):
    """
    A model that encodes (unperturbed gene expression) + (drug SMILES)
    and predicts (perturbed gene expression).
    """

    def __init__(
        self, gene_dim, gene_hidden_dim=512, drug_hidden_dim=128, smiles_vocab=None
    ):
        """
        Args:
            gene_dim (int): Dimensionality of the gene expression input.
            gene_hidden_dim (int): hidden dimension for the gene MLP.
            drug_hidden_dim (int): hidden dimension for the SMILES encoder output.
            smiles_vocab (str or list): vocabulary for the SMILES.
        """
        super(Perturbinator, self).__init__()
        if smiles_vocab is None:
            # A simple default vocabulary of typical SMILES chars + space + ?
            # In practice, define a more robust set of tokens
            smiles_vocab = "ACGT()[]=+#@0123456789abcdefghijklmnopqrstuvwxyz" + "? "

        self.gene_dim = gene_dim
        self.gene_hidden_dim = gene_hidden_dim
        self.drug_hidden_dim = drug_hidden_dim

        # 2.1) Gene MLP
        self.gene_encoder = nn.Sequential(
            nn.Linear(gene_dim, gene_hidden_dim),
            nn.ReLU(),
            nn.Linear(gene_hidden_dim, gene_hidden_dim),
            nn.ReLU(),
        )

        # 2.2) SMILES encoder
        self.smiles_encoder = SimpleSMILESEncoder(
            vocab=smiles_vocab, embed_dim=64, hidden_dim=drug_hidden_dim
        )

        # 2.3) Fusion -> final MLP to predict the same dimension as gene_dim
        self.fusion = nn.Sequential(
            nn.Linear(gene_hidden_dim + drug_hidden_dim, gene_hidden_dim),
            nn.ReLU(),
            nn.Linear(
                gene_hidden_dim, gene_dim
            ),  # we predict perturbed expression (size = gene_dim)
        )

    def forward(self, gene_expr, smiles_batch):
        """
        gene_expr: Tensor of shape (B, gene_dim)
        smiles_batch: List of length B containing SMILES strings
        """
        # Encode gene
        gene_emb = self.gene_encoder(gene_expr)  # (B, gene_hidden_dim)
        # Encode SMILES
        drug_emb = self.smiles_encoder(smiles_batch)  # (B, drug_hidden_dim)

        fused = torch.cat(
            [gene_emb, drug_emb], dim=1
        )  # (B, gene_hidden_dim + drug_hidden_dim)
        out = self.fusion(fused)
        return out

In [15]:
# Initialize the Perturbinator
model = Perturbinator(
    gene_dim=978,
    gene_hidden_dim=512,
    drug_hidden_dim=128,
)

# Criterion and Optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=3
)

train_losses, val_losses = train_multimodal_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=20,
    device=device,
    gradient_clipping=1.0,
    early_stopping_patience=5,
    model_name="Perturbinator",
    use_mixed_precision=True,
)

# Log the final train and validation losses
logging.debug("Final train and validation losses:")
logging.debug("Train Losses: ", train_losses)
logging.debug("Validation Losses: ", val_losses)

Training Progress:   0%|          | 0/20 [00:00<?, ?it/s]2024-12-25 12:30:19 [INFO] - Epoch 1/20 - Perturbinator, Train Loss: 0.0481, Val Loss: 0.0136
Training Progress:   5%|▌         | 1/20 [00:10<03:17, 10.39s/it]2024-12-25 12:30:29 [INFO] - Epoch 2/20 - Perturbinator, Train Loss: 0.0128, Val Loss: 0.0122
Training Progress:  10%|█         | 2/20 [00:20<03:05, 10.31s/it]2024-12-25 12:30:40 [INFO] - Epoch 3/20 - Perturbinator, Train Loss: 0.0122, Val Loss: 0.0121
Training Progress:  15%|█▌        | 3/20 [00:31<02:56, 10.36s/it]2024-12-25 12:30:50 [INFO] - Epoch 4/20 - Perturbinator, Train Loss: 0.0121, Val Loss: 0.0120
Training Progress:  20%|██        | 4/20 [00:41<02:44, 10.29s/it]2024-12-25 12:31:00 [INFO] - Epoch 5/20 - Perturbinator, Train Loss: 0.0120, Val Loss: 0.0118
Training Progress:  25%|██▌       | 5/20 [00:51<02:34, 10.29s/it]2024-12-25 12:31:09 [INFO] - Epoch 6/20 - Perturbinator, Train Loss: 0.0116, Val Loss: 0.0108
Training Progress:  30%|███       | 6/20 [01:00<02:17,

In [16]:

# Evaluate on test set
test_metrics = evaluate_multimodal_model(model, test_loader,criterion=criterion, device=device)
logging.debug("Test Metrics: ", test_metrics)
test_metrics

{'MSE': 0.005319735416295854,
 'MAE': 0.0537196584045887,
 'R2': 0.7902430295944214,
 'PCC': 0.889129982598392}

In [6]:
# Get some predictions from the test set
model.eval()
with torch.no_grad():
    for batch_data in test_loader:
        features = batch_data["features"].to(device)
        labels = batch_data["labels"].to(device)
        smiles_list = batch_data["smiles_tokens"].to(device)

        # Forward pass
        preds = model(features, smiles_list)  # shape (B, #genes)

        # Move to CPU for easy concatenation
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()
        break

# Select a random sample
sample_idx = np.random.randint(0, len(preds))
sample_features = features[sample_idx]
sample_pred = preds[sample_idx]
sample_label = labels[sample_idx]

# Plot the predicted vs. actual gene expression
plt.figure(figsize=(8, 6))
plt.scatter(sample_label, sample_pred, alpha=0.7)
plt.xlabel("Actual Gene Expression")
plt.ylabel("Predicted Gene Expression")
plt.title("Predicted vs. Actual Gene Expression")
plt.grid(True)
plt.show()

print("Base Expression: ", sample_features[:5])
print("Predicted Pert. Expression: ", sample_pred[:5])
print("Actual Pert. Expression: ", sample_label[:5])

NameError: name 'model' is not defined