In [1]:
from clex.benchmark.utils import load_model_for_testing

In [2]:
model, _, _, _ = load_model_for_testing(
    base_path="/home/workspace/clex/runs/2025-05-06_12-34-12_nanopore_full_debug",
    skip_data=True,
)

In [3]:
import os
from clex.utils import untransform_then_unbin

device = "cuda:0"
output_dir = (
    "/home/workspace/clex/runs/2025-05-06_12-34-12_nanopore_full_debug/eval/DREAM"
)

# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)

model.eval()
model.to(device)

Sequential(
  (0): TargetLengthCrop()
  (1): Borzoi(
    (conv_dna): ConvDna(
      (conv_layer): Conv1d(4, 256, kernel_size=(15,), stride=(1,), padding=same)
      (max_pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (_max_pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (res_tower): Sequential(
      (0): ConvBlock(
        (norm): BatchNorm1d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (activation): GELU(approximate='tanh')
        (conv_layer): Conv1d(256, 320, kernel_size=(5,), stride=(1,), padding=same)
      )
      (1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): ConvBlock(
        (norm): BatchNorm1d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (activation): GELU(approximate='tanh')
        (conv_layer): Conv1d(320, 384, kernel_size=(5,), stride=(1,), padding=same)
      )
      (3): MaxPool1

In [4]:
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from utils import PromoterDataset, plot_tensor_tracks
from clex.train import Config

# Import our batch decoder
from clex.utils import batch_one_hot_decode


def predictions_to_dataframe(prediction, exp_val, yfp_start=1, yfp_end=72):
    """
    Convert model predictions, sequences, and expression values into a DataFrame.

    Args:
        prediction (torch.Tensor): Model predictions of shape (batch_size, tracks, seq_len)
        seqs (torch.Tensor): Input sequences (one-hot encoded with shape [batch_size, seq_len, 4])
        exp_val (torch.Tensor): Expression values
        yfp_start (int): Start position of YFP region (inclusive)
        yfp_end (int): End position of YFP region (inclusive)

    Returns:
        pd.DataFrame: DataFrame with columns for sequences, expression values, and track sums
    """
    num_tracks = prediction.shape[1]

    # Calculate YFP region sums for each track
    yfp_region = prediction[
        :, :, yfp_start : yfp_end + 1
    ]  # +1 because end is inclusive
    yfp_sums = yfp_region.sum(dim=2).cpu().numpy()  # Sum along sequence dimension

    # Convert sequences to strings using our batch decoder

    # Convert expression values to numpy array
    exp_val_np = exp_val.cpu().numpy()

    # Create a dictionary to build the DataFrame
    data_dict = {
        "expression_value": exp_val_np.flatten(),  # Ensure it's 1D
    }

    # Add track sum columns
    for i in range(num_tracks):
        data_dict[f"track_{i}_sum"] = yfp_sums[:, i]

    # Create DataFrame
    df = pd.DataFrame(data_dict)

    return df


def create_track_expression_scatter_plots(df, output_dir, region_name="YFP"):
    """
    Create scatter plots of track coverage sums against expression values.

    Args:
        df (pd.DataFrame): DataFrame containing expression values and track sums
        output_dir (str): Directory to save the plots
        region_name (str): Name of the region being analyzed (for plot titles)
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Get list of track columns
    track_columns = [col for col in df.columns if col.startswith("track_")]

    # Create subplot figure for all tracks
    num_tracks = len(track_columns)
    fig_combined, axes = plt.subplots(
        nrows=(num_tracks + 1) // 2,
        ncols=2,
        figsize=(14, 3 * ((num_tracks + 1) // 2)),
        tight_layout=True,
    )
    axes = axes.flatten() if num_tracks > 1 else [axes]

    # Create individual scatter plots for each track
    for i, track_col in enumerate(track_columns):
        track_num = track_col.split("_")[1]

        # Calculate correlation
        correlation, p_value = pearsonr(df[track_col], df["expression_value"])

        # Individual plot
        plt.figure(figsize=(8, 6))
        plt.scatter(df[track_col], df["expression_value"], alpha=0.6)
        plt.title(
            f"Track {track_num} {region_name} Coverage vs Expression\nr = {correlation:.3f}, p = {p_value:.3e}"
        )
        plt.xlabel(f"Track {track_num} {region_name} Coverage Sum")
        plt.ylabel("Expression Value")
        plt.grid(True, alpha=0.3)

        # Add best fit line
        z = np.polyfit(df[track_col], df["expression_value"], 1)
        p = np.poly1d(z)
        x_range = np.linspace(df[track_col].min(), df[track_col].max(), 100)
        plt.plot(x_range, p(x_range), "r--", alpha=0.8)

        # Save individual plot
        plt.tight_layout()
        plt.savefig(f"{output_dir}/track_{track_num}_expression_scatter.png", dpi=300)
        plt.close()

        # Add to combined subplot figure
        axes[i].scatter(df[track_col], df["expression_value"], alpha=0.6)
        axes[i].set_title(f"Track {track_num}\nr = {correlation:.3f}")
        axes[i].set_xlabel(f"Track {track_num} Coverage Sum")
        axes[i].set_ylabel("Expression Value")
        axes[i].grid(True, alpha=0.3)

        # Add best fit line to combined plot
        axes[i].plot(x_range, p(x_range), "r--", alpha=0.8)

    # Save the combined figure
    fig_combined.suptitle(
        f"{region_name} Region Track Coverage vs Expression", fontsize=16
    )
    plt.tight_layout()
    fig_combined.subplots_adjust(top=0.94)  # Make room for the suptitle
    plt.savefig(f"{output_dir}/all_tracks_expression_scatter.png", dpi=300)
    plt.close()

    # Create a correlation summary plot
    correlations = []
    p_values = []
    track_names = []

    for track_col in track_columns:
        track_num = track_col.split("_")[1]
        corr, p_val = pearsonr(df[track_col], df["expression_value"])
        correlations.append(corr)
        p_values = p_values
        track_names.append(f"Track {track_num}")

    # Create correlation summary bar chart
    plt.figure(figsize=(10, 6))
    bars = plt.bar(track_names, correlations)
    plt.axhline(y=0, color="k", linestyle="-", alpha=0.3)
    plt.title(f"Correlation between {region_name} Track Coverage and Expression")
    plt.ylabel("Pearson Correlation")
    plt.ylim(-1, 1)

    # Add correlation values on top of bars
    for i, bar in enumerate(bars):
        height = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width() / 2.0,
            height + 0.05 * (1 if height >= 0 else -1),
            f"{correlations[i]:.3f}",
            ha="center",
            va="bottom" if height >= 0 else "top",
        )

    plt.tight_layout()
    plt.savefig(f"{output_dir}/track_correlations_summary.png", dpi=300)
    plt.close()

In [5]:
"""
GPU inference – no sequence caching, no full-tensor caching
===========================================================

* Runs the model batch-by-batch.
* Immediately reduces the prediction tensor to
      (batch, tracks) = Σ_{pos ∈ [YFP start, YFP end)} prediction[..., pos]
  and discards the big (tracks, seq_len) array.
* Builds the final DataFrame from these small per-track sums plus the labels.
"""

import os
import torch
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader
from scipy.stats import pearsonr

# ---------------------------------------------------------------------------
# paths & constants
# ---------------------------------------------------------------------------
promoter_path = (
    "/home/workspace/clex/clex/benchmark/DREAM/"
    "GSE254493_filtered_test_data_with_MAUDE_expression.txt"
)
output_dir = "/home/workspace/clex/runs/2025-05-06_12-34-12_nanopore_full_debug/DREAM/output"
plots_dir = "/home/workspace/clex/runs/2025-05-06_12-34-12_nanopore_full_debug/DREAM/plots"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)

DEVICE = device  # defined earlier in your notebook

sections = {"YFP start": 1, "YFP end": 90}
yfp_start, yfp_end = sections["YFP start"], sections["YFP end"]

# ---------------------------------------------------------------------------
# dataset / loader / model
# ---------------------------------------------------------------------------
test_dataset = PromoterDataset(promoter_path)
test_loader = DataLoader(
    test_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True,
)

model.eval()
model.to(DEVICE)

# ---------------------------------------------------------------------------
# inference loop – keep only tiny per-track region sums
# ---------------------------------------------------------------------------
region_sums_accum = []  # list of (batch, tracks) CPU tensors
expr_accum = []  # list of (batch, 1)   CPU tensors
sample_ids = []  # list of str

with torch.no_grad(), torch.cuda.amp.autocast():
    for b_idx, (seqs, exp_val) in enumerate(tqdm(test_loader, desc="Inference")):
        preds = model(seqs.to(DEVICE, non_blocking=True))

        # Reduce *inside* the loop to avoid holding the big tensor
        region_sum = preds[:, :, yfp_start:yfp_end].sum(dim=2).cpu()  # (B, tracks)

        region_sums_accum.append(region_sum)
        expr_accum.append(exp_val.cpu())

        offset = b_idx * test_loader.batch_size
        sample_ids.extend([f"sample_{offset + j}" for j in range(seqs.size(0))])

# ---------------------------------------------------------------------------
# build the final DataFrame
# ---------------------------------------------------------------------------
region_sums_full = torch.cat(region_sums_accum, dim=0).numpy()  # (N, tracks)
expr_full = torch.cat(expr_accum, dim=0).squeeze(1).numpy()

num_tracks = region_sums_full.shape[1]

data = {
    "sample_id": sample_ids,
    "expression_value": expr_full,
}
for t in range(num_tracks):
    data[f"track_{t}"] = region_sums_full[:, t]

final_df = pd.DataFrame(data)

csv_path = os.path.join(output_dir, "all_samples_results.csv")
final_df.to_csv(csv_path, index=False)
print(f"Saved complete results to {csv_path}")

# ---------------------------------------------------------------------------
# downstream analysis – unchanged
# ---------------------------------------------------------------------------
create_track_expression_scatter_plots(final_df, plots_dir, region_name="YFP")
print(f"Generated scatter plots in {plots_dir}")

track_cols = [c for c in final_df.columns if c.startswith("track_")]
stats_df = final_df[track_cols + ["expression_value"]].describe()
stats_df.to_csv(os.path.join(output_dir, "track_statistics.csv"))
print("Saved track statistics.")

corr_records = [
    {
        "track": col,
        "correlation_with_expression": pearsonr(
            final_df[col], final_df["expression_value"]
        )[0],
        "p_value": pearsonr(final_df[col], final_df["expression_value"])[1],
    }
    for col in track_cols
]
pd.DataFrame(corr_records).to_csv(
    os.path.join(output_dir, "track_expression_correlations.csv"), index=False
)
print("Saved track-expression correlations.")


  with torch.no_grad(), torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast(enabled=False):
Inference: 100%|██████████| 712/712 [00:28<00:00, 25.25it/s]


Saved complete results to /home/workspace/clex/runs/2025-05-06_12-34-12_nanopore_full_debug/DREAM/output/all_samples_results.csv
Generated scatter plots in /home/workspace/clex/runs/2025-05-06_12-34-12_nanopore_full_debug/DREAM/plots
Saved track statistics.
Saved track-expression correlations.
