## Group 52 - Predicting Isoform Expression from Gene-Level Profiles using Representation Learning
### Students: Alessandra Carrara s252976 - Alessia Santoni s252125 - Emmanouil Tsilingeridis s253154 - Matteo Pozzar s243098


Since our dataset is too large to be uploaded to GitHub, we decided to report only the code we used and the corresponding outputs. This was discussed and agreed with the TA. We excluded the option of uploading just a small subset of the data because the results obtained would not match those in the official report.

**PLEASE NOTE:** The code was adapted so that it can be run in a Jupyter Notebook. The original code can be found in the *Pipeline* folder of the same GitHub repository.

In [None]:
from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from scipy.stats import spearmanr
from torch.utils.data import DataLoader, Dataset

try:
    import scanpy as sc
except ImportError as exc:  # pragma: no cover
    raise SystemExit("scanpy is required. Install it via `pip install scanpy`.") from exc

# Base directory containing the AnnData files.
# NOTE(review): absolute, user-specific path (presumably the cluster where the
# experiments were run) — adjust before running anywhere else.
DATA_ROOT = Path("/work3/s193518/scIsoPred/data").expanduser().resolve()

# Default input files under DATA_ROOT, stored as plain strings. Judging by the
# file names: gene-level and transcript-level AnnData objects respectively.
DEFAULT_GENE_H5AD = str(DATA_ROOT / "bulk_processed_genes.h5ad")
DEFAULT_ISOFORM_H5AD = str(DATA_ROOT / "bulk_processed_transcripts.h5ad")

In the cell above we imported the required libraries and defined the data paths.

## Loader Utilities

Same for every model

In [None]:
class GeneIsoformDataset(Dataset):
    """Pair an input matrix with a target matrix as float32 tensors.

    Both arrays are indexed along their first axis, so item ``i`` is the
    tuple ``(X[i], Y[i])``.

    Args:
        X: Input array; its first axis indexes the samples.
        Y: Target array with the same number of rows as ``X``.

    Raises:
        ValueError: If ``X`` and ``Y`` have a different number of rows.
    """

    def __init__(self, X: np.ndarray, Y: np.ndarray) -> None:
        # Fail early: a row mismatch would otherwise surface as an IndexError
        # in the middle of an epoch, or silently ignore surplus target rows.
        if X.shape[0] != Y.shape[0]:
            raise ValueError(
                f"X and Y must have the same number of rows, got {X.shape[0]} and {Y.shape[0]}."
            )
        self.X = torch.from_numpy(X).float()
        self.Y = torch.from_numpy(Y).float()

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.X.shape[0]

    def __getitem__(self, idx: int):
        """Return the ``(input, target)`` tensors for sample ``idx``."""
        return self.X[idx], self.Y[idx]


def train_val_test_split(
    X: np.ndarray,
    Y: np.ndarray,
    train_n: int,
    val_n: int,
    test_n: int,
    seed: int = 42,
) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
    """Return deterministic, non-overlapping train/val/test subsets.

    The rows are permuted with a generator seeded by ``seed``; the first
    ``train_n + val_n + test_n`` permuted rows are then cut into three
    consecutive slices. ``X`` and ``Y`` are indexed with the same permutation.

    Args:
        X: Input array; the first axis indexes the samples.
        Y: Target array, indexed with the same row permutation as ``X``.
        train_n: Number of training rows.
        val_n: Number of validation rows.
        test_n: Number of test rows (may be 0, giving empty test arrays).
        seed: Seed for the permutation.

    Returns:
        ``(train, val, test)``, each an ``(X_subset, Y_subset)`` tuple.

    Raises:
        ValueError: If any size is negative or the sizes sum to more rows
            than ``X`` has.
    """
    if min(train_n, val_n, test_n) < 0:
        raise ValueError(
            f"Split sizes must be non-negative (train {train_n}, val {val_n}, test {test_n})."
        )
    n = X.shape[0]
    total_needed = train_n + val_n + test_n
    if total_needed > n:
        raise ValueError(
            f"Requested {total_needed} samples (train {train_n}, val {val_n}, test {test_n}) "
            f"but dataset only has {n} observations."
        )
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)[:total_needed]
    X = X[perm]
    Y = Y[perm]
    val_end = train_n + val_n
    train = (X[:train_n], Y[:train_n])
    val = (X[train_n:val_end], Y[train_n:val_end])
    # Explicit bounds rather than X[-test_n:]: a negative-zero slice start
    # (test_n == 0) would select every row instead of none.
    test = (X[val_end:total_needed], Y[val_end:total_needed])
    return train, val, test


def make_loader(pair, batch_size: int, shuffle: bool) -> DataLoader:
    """Build a DataLoader over an ``(X, Y)`` pair of numpy arrays.

    The pair is wrapped in a ``GeneIsoformDataset``; the final, possibly
    smaller, batch is kept (``drop_last=False``).
    """
    return DataLoader(
        GeneIsoformDataset(*pair),
        shuffle=shuffle,
        batch_size=batch_size,
        drop_last=False,
    )


@dataclass
class GeneIsoformDataLoaders:
    """Container grouping the four DataLoaders used for training and evaluation.

    Field order matters: instances are built positionally as
    ``(train, train_eval, val, test)``.
    """

    # Loader iterated during optimisation.
    train: DataLoader
    # Loader used to score the training split; in this notebook it wraps the
    # same data as ``train`` but is created with shuffle=False.
    train_eval: DataLoader
    # Validation split loader.
    val: DataLoader
    # Held-out test split loader.
    test: DataLoader


## Preprocess Utilities:

Same for every model

In [None]:
def densify(adata) -> np.ndarray:
    """Materialise ``adata.X`` as a dense float32 numpy array.

    Objects exposing ``toarray`` (e.g. scipy sparse matrices) are expanded
    first; anything else is handed to ``np.asarray`` directly.
    """
    data = adata.X
    dense = data.toarray() if hasattr(data, "toarray") else data
    return np.asarray(dense, dtype=np.float32)


def align_anndata(genes, transcripts):
    """Return ``(genes, transcripts)`` with transcripts in the genes' obs order.

    When the two ``obs_names`` already match element-wise, both objects are
    returned untouched. Otherwise the transcript object is indexed by
    ``genes.obs_names`` and a notice is printed.
    """
    already_aligned = np.array_equal(genes.obs_names, transcripts.obs_names)
    if not already_aligned:
        print("Reindexing isoform AnnData to match gene observations order.")
        transcripts = transcripts[genes.obs_names, :]
    return genes, transcripts


def normalize_inputs(X: np.ndarray) -> np.ndarray:
    """Log1p-transform ``X`` and then z-score every column.

    Column statistics are taken over all rows of the array passed in; 1e-6 is
    added to the standard deviation so constant columns map to zeros instead
    of dividing by zero.
    """
    logged = np.log1p(X)
    column_mean = logged.mean(axis=0, keepdims=True)
    column_scale = logged.std(axis=0, keepdims=True) + 1e-6
    return (logged - column_mean) / column_scale


def prepare_targets(Y: np.ndarray) -> np.ndarray:
    """Return ``log(1 + Y)`` element-wise, the space the network is trained in."""
    log_counts = np.log1p(Y)
    return log_counts


def filter_silent_genes_isoforms(
    X_genes: np.ndarray,
    Y_isoforms: np.ndarray,
    gene_names: Sequence[str],
    transcript_names: Sequence[str],
    gene_to_transcripts: Dict[str, Sequence[str]] | None = None,
):
    """
    Remove genes with zero expression and isoforms with zero observations or missing genes.

    Returns filtered X, Y, gene names, transcript names, and an updated gene->transcripts mapping
    containing only kept entries.
    """
    gene_to_transcripts = gene_to_transcripts or {}
    gene_names = np.asarray(gene_names)
    transcript_names = np.asarray(transcript_names)

    gene_mask = X_genes.sum(axis=0) > 0
    kept_genes = gene_names[gene_mask]
    kept_gene_set = set(kept_genes)

    transcript_to_gene = {}
    for gene, transcripts in gene_to_transcripts.items():
        for tr in transcripts:
            transcript_to_gene[tr] = gene

    isoform_mask = np.zeros_like(transcript_names, dtype=bool)
    for idx, tr in enumerate(transcript_names):
        has_counts = Y_isoforms[:, idx].sum() > 0
        parent_gene = transcript_to_gene.get(tr)
        isoform_mask[idx] = bool(has_counts and parent_gene in kept_gene_set)

    filtered_transcripts = transcript_names[isoform_mask]
    filtered_gene_to_transcripts: Dict[str, List[str]] = {}
    filtered_transcript_set = set(filtered_transcripts)
    for gene in kept_genes:
        gene_transcripts = gene_to_transcripts.get(gene, [])
        kept_trs = [tr for tr in gene_transcripts if tr in filtered_transcript_set]
        if kept_trs:
            filtered_gene_to_transcripts[gene] = kept_trs

    return (
        X_genes[:, gene_mask],
        Y_isoforms[:, isoform_mask],
        kept_genes,
        filtered_transcripts,
        filtered_gene_to_transcripts,
    )


def build_transcript_gene_index(
    gene_to_transcripts: Dict[str, Sequence[str]],
    gene_names: Sequence[str],
    transcript_names: Sequence[str],
) -> Tuple[np.ndarray, List[str]]:
    """Map every transcript column to the column index of its parent gene.

    Returns:
        A pair ``(index, unmapped)``. ``index`` is an int32 array with one
        entry per transcript holding the position of its gene in
        ``gene_names``, or -1 when the transcript has no parent in the mapping
        or the parent is not in ``gene_names``. ``unmapped`` lists those
        transcripts in input order.
    """
    position_of_gene = {name: pos for pos, name in enumerate(gene_names)}
    # Later genes overwrite earlier ones if a transcript is listed twice.
    parent_of = {
        tr: gene
        for gene, members in gene_to_transcripts.items()
        for tr in members
    }

    index = np.full(len(transcript_names), -1, dtype=np.int32)
    missing: List[str] = []
    for col, name in enumerate(transcript_names):
        parent = parent_of.get(name)
        if parent is not None and parent in position_of_gene:
            index[col] = position_of_gene[parent]
        else:
            missing.append(name)
    return index, missing

def isoform_correlations(pred_counts: np.ndarray, true_counts: np.ndarray) -> np.ndarray:
    """Pearson correlation between prediction and truth, one value per column.

    Each column (isoform) is correlated across the rows (samples). The result
    is a float32 array of length ``pred_counts.shape[1]``; columns whose
    predicted or true values have zero standard deviation stay NaN so that
    downstream statistics can skip them.
    """
    n_cols = pred_counts.shape[1]
    result = np.full(n_cols, np.nan, dtype=np.float32)
    for col in range(n_cols):
        predicted = pred_counts[:, col]
        observed = true_counts[:, col]
        if predicted.std() != 0 and observed.std() != 0:
            result[col] = float(np.corrcoef(predicted, observed)[0, 1])
    return result

def gene_spearman_per_sample(
    pred_counts: np.ndarray,
    true_counts: np.ndarray,
    transcript_gene_idx: np.ndarray,
    gene_names: Sequence[str],
) -> Tuple[np.ndarray, np.ndarray]:
    """Per-gene Spearman correlation of isoform values, computed per sample.

    For every gene with at least two isoform columns, the predicted and true
    values of those isoforms are rank-correlated within each row (sample).
    Rows where either vector is constant, or where the correlation is NaN,
    are skipped.

    Args:
        pred_counts: Array of shape (n_samples, n_isoforms).
        true_counts: Array of shape (n_samples, n_isoforms).
        transcript_gene_idx: For each isoform column, the index of its gene in
            ``gene_names`` (entries that match no gene index, e.g. -1, are
            ignored).
        gene_names: Gene names; only their count is used.

    Returns:
        Two float32 arrays of shape (n_genes,): the mean and the median of the
        per-sample correlations. NaN for genes with fewer than two isoforms or
        no valid correlation.
    """
    # Imported here so the function does not depend on a module-level import.
    from scipy.stats import spearmanr

    gene_names = np.asarray(gene_names)
    n_genes = len(gene_names)
    mean_corrs = np.full(n_genes, np.nan, dtype=np.float32)
    median_corrs = np.full(n_genes, np.nan, dtype=np.float32)

    # Group isoform columns by gene once, instead of scanning the whole index
    # array for every gene. The stable sort keeps each gene's columns in
    # ascending column order, i.e. the order a boolean mask would select.
    transcript_gene_idx = np.asarray(transcript_gene_idx)
    order = np.argsort(transcript_gene_idx, kind="stable")
    sorted_gene_idx = transcript_gene_idx[order]
    gene_ids = np.arange(n_genes)
    starts = np.searchsorted(sorted_gene_idx, gene_ids, side="left")
    stops = np.searchsorted(sorted_gene_idx, gene_ids, side="right")

    for g in range(n_genes):
        cols = order[starts[g]:stops[g]]
        if cols.size < 2:
            continue
        corrs: List[float] = []
        p_sub = pred_counts[:, cols]
        t_sub = true_counts[:, cols]
        for p, t in zip(p_sub, t_sub):
            # Ranks are undefined for constant vectors.
            if np.all(p == p[0]) or np.all(t == t[0]):
                continue
            corr = spearmanr(p, t).correlation
            if np.isnan(corr):
                continue
            corrs.append(float(corr))
        if corrs:
            mean_corrs[g] = float(np.mean(corrs))
            median_corrs[g] = float(np.median(corrs))
    return mean_corrs, median_corrs


## Loss, Train, Evaluate Helpers

Shared by every model

In [None]:
def _masked_mse(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean squared error over the entries whose target is strictly positive.

    Zero targets are masked out so silent isoforms/genes do not drive the
    loss. When no target is positive the result is 0 and still attached to the
    autograd graph, so ``backward()`` works and yields zero gradients.
    """
    mask = (targets > 0).float()
    diff = (preds - targets) * mask
    # mask.sum() is a whole number: clamping to 1 only affects the all-masked
    # case, where the numerator is 0 anyway.
    return (diff * diff).sum() / mask.sum().clamp(min=1.0)


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int,
    lr: float,
    weight_decay: float,
    patience: int = 10,
) -> Dict[str, List[float]]:
    """Train ``model`` with masked MSE, early stopping and best-state restore.

    Uses Adam plus ``ReduceLROnPlateau`` (factor 0.5, scheduler patience 10,
    floor 1e-6) driven by the validation loss. The best validation state is
    kept in memory (on CPU) and loaded back into the model before returning.

    Args:
        model: Network mapping an input batch to predictions.
        train_loader: Loader yielding ``(inputs, targets)`` training batches.
        val_loader: Loader scored after every epoch via ``evaluate_loss``.
        device: Device the batches are moved to.
        epochs: Maximum number of epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        patience: Epochs without validation improvement before stopping early
            (independent of the scheduler's own fixed patience).

    Returns:
        ``{"train": [...], "val": [...]}`` with one loss per completed epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10, min_lr=1e-6
    )
    history = {"train": [], "val": []}
    best_state = None
    best_val = float("inf")
    no_improve = 0

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0.0
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            optimizer.zero_grad()
            preds = model(xb)
            loss = _masked_mse(preds, yb)
            loss.backward()
            optimizer.step()
            # Weight each batch's mean loss by its size for a per-sample average.
            train_loss += loss.item() * xb.size(0)
        train_loss /= len(train_loader.dataset)

        val_loss = evaluate_loss(model, val_loader, device)
        scheduler.step(val_loss)

        history["train"].append(train_loss)
        history["val"].append(val_loss)
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"Epoch {epoch:03d} | train {train_loss:.4f} | val {val_loss:.4f} | lr {current_lr:.2e}")

        if val_loss < best_val:
            best_val = val_loss
            # Snapshot on CPU so the checkpoint does not occupy device memory.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1

        if no_improve >= patience:
            print(f"No val improvement for {patience} epochs, stopping early at epoch {epoch}.")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    return history


def evaluate_loss(model, loader, device, criterion=None) -> float:
    """Return the size-weighted average masked MSE of ``model`` over ``loader``.

    Args:
        model: Network to evaluate (switched to eval mode).
        loader: Loader yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to.
        criterion: Accepted for backward compatibility but ignored; the masked
            MSE used for training is always applied.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)
            preds = model(xb)
            loss = _masked_mse(preds, yb)
            total += loss.item() * xb.size(0)
    return total / len(loader.dataset)


def inference(model, loader, device) -> Tuple[np.ndarray, np.ndarray]:
    """Run ``model`` over ``loader`` and collect predictions and targets.

    The model is put in eval mode and gradients are disabled. Returns
    ``(predictions, targets)`` as numpy arrays with the batches stacked along
    the first axis, in loader order.
    """
    model.eval()
    pred_chunks = []
    target_chunks = []
    with torch.no_grad():
        for features, labels in loader:
            outputs = model(features.to(device))
            pred_chunks.append(outputs.cpu().numpy())
            target_chunks.append(labels.numpy())
    return np.vstack(pred_chunks), np.vstack(target_chunks)


# Models:

### The following code cells report the models we used for this project

## FCNN - MLP

In [None]:
class IsoformPredictor(nn.Module):
    """Feed-forward network for multi-output regression.

    Each hidden layer is ``Linear -> BatchNorm1d -> ReLU -> Dropout``; a final
    ``Linear`` maps to ``n_outputs`` without activation.

    Args:
        n_inputs: Number of input features.
        n_outputs: Number of regression outputs.
        hidden_sizes: Width of each hidden layer, in order. Any length is
            accepted; the default gives three hidden layers, and for it the
            layer order (and therefore the ``state_dict`` keys) is unchanged
            from the fixed three-layer layout.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, n_inputs: int, n_outputs: int, hidden_sizes=(1024, 512, 256), dropout=0.2):
        super().__init__()
        layers = []
        width_in = n_inputs
        for width_out in hidden_sizes:
            layers += [
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            width_in = width_out
        layers.append(nn.Linear(width_in, n_outputs))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(batch, n_inputs)`` tensor to ``(batch, n_outputs)``."""
        return self.net(x)


## Residual MLP

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: Linear-LayerNorm-ReLU-Dropout-Linear-LayerNorm plus skip.

    The skip path is the identity when input and output widths agree and a
    learned linear projection otherwise. A ReLU follows the addition.
    """

    def __init__(self, in_features: int, out_features: int, dropout: float = 0.2):
        super().__init__()
        self.linear1 = nn.Linear(in_features, out_features)
        self.linear2 = nn.Linear(out_features, out_features)
        self.norm1 = nn.LayerNorm(out_features)
        self.norm2 = nn.LayerNorm(out_features)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.ReLU()

        # A projection is only needed when the widths differ.
        if in_features == out_features:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return ``act(main_branch(x) + shortcut(x))``."""
        skip = self.shortcut(x)

        hidden = self.linear1(x)
        hidden = self.norm1(hidden)
        hidden = self.act(hidden)
        hidden = self.dropout(hidden)
        hidden = self.norm2(self.linear2(hidden))

        return self.act(hidden + skip)


class ResidualIsoformPredictor(nn.Module):
    """Residual MLP for multi-output regression.

    Layout: an input projection (Linear-LayerNorm-ReLU-Dropout) to the first
    hidden width, three ``ResidualBlock`` stages stepping through the
    remaining widths, and a linear output head without activation.

    Args:
        n_inputs: Number of input features.
        n_outputs: Number of regression outputs.
        hidden_sizes: Exactly four widths: the projection width followed by
            the output width of each of the three residual blocks.
        dropout: Dropout probability used in the projection and every block.
    """

    def __init__(
        self,
        n_inputs: int,
        n_outputs: int,
        hidden_sizes: tuple[int, int, int, int] = (1536, 1024, 1024, 512),
        dropout: float = 0.2,
    ):
        super().__init__()
        proj_width, width_1, width_2, width_3 = hidden_sizes

        # Lift the raw inputs to the first hidden width.
        self.input_proj = nn.Sequential(
            nn.Linear(n_inputs, proj_width),
            nn.LayerNorm(proj_width),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Three residual stages; a stage whose input and output widths match
        # uses an identity skip, the others a learned projection.
        self.block1 = ResidualBlock(proj_width, width_1, dropout=dropout)
        self.block2 = ResidualBlock(width_1, width_2, dropout=dropout)
        self.block3 = ResidualBlock(width_2, width_3, dropout=dropout)

        # Linear read-out, no activation.
        self.output_head = nn.Linear(width_3, n_outputs)

    def forward(self, x):
        """Map a ``(batch, n_inputs)`` tensor to ``(batch, n_outputs)``."""
        hidden = self.input_proj(x)
        for block in (self.block1, self.block2, self.block3):
            hidden = block(hidden)
        return self.output_head(hidden)


## Transformer

In [None]:
class TransformerIsoformPredictor(nn.Module):
    """Shallow Transformer encoder over learned "pseudo-gene" tokens.

    The ``n_genes`` input scalars are first projected to ``n_tokens`` scalars
    so attention stays tractable. Every token scalar is embedded to
    ``d_model`` dimensions by one shared linear layer, a learnable positional
    embedding is added, and the sequence goes through a pre-LayerNorm
    Transformer encoder. The sequence is pooled (CLS token, or mean when
    ``use_cls`` is False) and mapped linearly to ``n_isoforms`` outputs.

    Args:
        n_genes: Number of input features.
        n_isoforms: Number of outputs.
        n_tokens: Length of the projected token sequence.
        d_model: Token embedding size.
        nhead: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        dim_feedforward: Hidden size of each encoder feed-forward sub-layer.
        dropout: Dropout used in the encoder and before pooling.
        use_cls: Prepend a learnable CLS token and pool from it.
    """

    def __init__(
        self,
        n_genes: int,
        n_isoforms: int,
        n_tokens: int = 1024,
        d_model: int = 256,
        nhead: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
        use_cls: bool = True,
    ):
        super().__init__()

        self.d_model = d_model
        self.n_tokens = n_tokens
        self.use_cls = use_cls

        # Trainable reduction of the gene vector to n_tokens scalars.
        self.token_proj = nn.Linear(n_genes, n_tokens, bias=False)

        # Shared scalar -> vector embedding applied to every token value.
        self.gene_embedding = nn.Linear(1, d_model)

        # Learnable positions, with one extra slot when a CLS token is used.
        if use_cls:
            seq_len = n_tokens + 1
        else:
            seq_len = n_tokens
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, d_model))
        if use_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        else:
            self.cls_token = None

        # Pre-LayerNorm encoder stack operating on (batch, seq, d_model).
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)

        # Read-out from the pooled sample representation to all isoforms.
        self.output_head = nn.Linear(d_model, n_isoforms)

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, n_genes)`` inputs to ``(batch, n_isoforms)`` outputs."""
        has_cls = self.use_cls and self.cls_token is not None

        # (B, n_genes) -> (B, n_tokens) -> (B, n_tokens, 1) -> (B, n_tokens, d_model)
        tokens = self.gene_embedding(self.token_proj(x).unsqueeze(-1))

        if has_cls:
            cls = self.cls_token.expand(x.shape[0], -1, -1)   # (B, 1, d_model)
            tokens = torch.cat([cls, tokens], dim=1)          # (B, 1 + n_tokens, d_model)

        # Positional embedding broadcasts over the batch dimension.
        tokens = tokens + self.pos_embedding[:, : tokens.size(1), :]

        # Encoder output keeps the shape (B, seq_len, d_model).
        encoded = self.dropout(self.norm(self.transformer(tokens)))

        # Pool the sequence into one vector per sample.
        if has_cls:
            return self.output_head(encoded[:, 0, :])
        return self.output_head(encoded.mean(dim=1))


## General parameters
In the cells below we report the general parameters that were used for training.

**PLEASE NOTE**: 
* This part of code in the original .py file was defined as *argparse* and used to run the code through the terminal.
* The results in the report were generated with the whole dataset, keeping the 50/35/15% ratio for train/val/test.

### FCNN-MLP

In [None]:
# Run-time configuration for the FCNN-MLP model.
# NOTE(review): the Residual-MLP parameter cell assigns the same variable names
# (epochs, lr, dropout, ...). Whichever parameter cell runs last wins, so run
# only the cell matching the model about to be trained.

# Paths
gene_h5ad = DEFAULT_GENE_H5AD
isoform_h5ad = DEFAULT_ISOFORM_H5AD

# Split
whole_dataset = True          # If True, use 50/35/15 split on the whole dataset
train_n = 5000                # Used only if whole_dataset is False
val_n = 3500
test_n = 1500

# Training hyperparams
epochs = 300
batch_size = 64
lr = 2e-4
weight_decay = 1e-5
dropout = 0.2
hidden_sizes = (1024, 512, 256)
patience = 20                 # early stopping

# Use the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Residual-MLP

In [None]:
# Run-time configuration for the Residual-MLP model.
# NOTE(review): the FCNN-MLP parameter cell assigns the same variable names
# (epochs, lr, dropout, ...). Whichever parameter cell runs last wins, so run
# only the cell matching the model about to be trained.

# Paths
gene_h5ad = DEFAULT_GENE_H5AD
isoform_h5ad = DEFAULT_ISOFORM_H5AD

# Split
whole_dataset = True          # If True, use 50/35/15 split on the whole dataset
train_n = 5000                # used only if whole_dataset is False
val_n = 3500
test_n = 1500

# Training hyperparams
epochs = 250
batch_size = 64
lr = 2e-4
weight_decay = 1e-5
dropout = 0.2
# Four widths: input projection followed by the three residual blocks.
hidden_sizes_res = (1536, 1024, 1024, 512)
patience = 20                 # early stopping

# Use the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Data loading and preprocessing

This part is the same for every model

In [None]:
# Parameters
# batch_size is re-assigned here and overrides the value from the parameter cell.
batch_size = 64
seed = 42

# Loading AnnData
print("Loading AnnData files...")
bulk_genes = sc.read_h5ad(gene_h5ad)
bulk_transcripts = sc.read_h5ad(isoform_h5ad)
print("Data loaded.")
# Make sure both objects list the observations in the same order.
bulk_genes, bulk_transcripts = align_anndata(bulk_genes, bulk_transcripts)

# Gene -> transcripts mapping stored in the gene AnnData; empty dict when the
# key is missing or falsy (in which case the filter below drops every isoform).
gene_to_transcripts = bulk_genes.uns.get("gene_to_transcripts") or {}

# Densifies and filters silent genes/isoforms
X_raw = densify(bulk_genes)
Y_raw = densify(bulk_transcripts)
X_raw, Y_raw, gene_names, transcript_names, gene_to_transcripts = filter_silent_genes_isoforms(
    X_raw,
    Y_raw,
    bulk_genes.var_names.to_numpy(),
    bulk_transcripts.var_names.to_numpy(),
    gene_to_transcripts,
)
print(
    f"Filtered zero-expression features: genes kept {len(gene_names)}/{bulk_genes.var_names.size}, "
    f"isoforms kept {len(transcript_names)}/{bulk_transcripts.var_names.size}"
)

# Prepares inputs and targets
# NOTE(review): the z-score statistics are computed over ALL observations,
# before the train/val/test split below, so validation and test rows
# contribute to the normalisation of the training inputs.
X = normalize_inputs(X_raw)
Y = prepare_targets(Y_raw)

# Mapping transcript→gene
# The second return value (list of unmapped transcripts) is discarded.
transcript_gene_idx, _ = build_transcript_gene_index(gene_to_transcripts, gene_names, transcript_names)

# Split
if whole_dataset:
    # 50% train, 35% validation, remainder (about 15%) test.
    total = X.shape[0]
    train_n = int(total * 0.5)
    val_n = int(total * 0.35)
    test_n = total - train_n - val_n
    print(f"Using whole dataset: train {train_n}, val {val_n}, test {test_n}")

(train_pair, val_pair, test_pair) = train_val_test_split(
    X, Y, train_n=train_n, val_n=val_n, test_n=test_n, seed=seed
)

# DataLoader
# Only the training loader shuffles; train_eval_loader iterates the same data
# in fixed order for evaluation.
train_loader = make_loader(train_pair, batch_size=batch_size, shuffle=True)
train_eval_loader = make_loader(train_pair, batch_size=batch_size, shuffle=False)
val_loader = make_loader(val_pair, batch_size=batch_size, shuffle=False)
test_loader = make_loader(test_pair, batch_size=batch_size, shuffle=False)
loaders = GeneIsoformDataLoaders(train_loader, train_eval_loader, val_loader, test_loader)


# ADD OUTPUT

## Training

### FCNN - MLP

In [None]:
# Train the FCNN-MLP.
# The original script read its settings from an argparse namespace (`args`),
# which does not exist in this notebook; the plain variables defined in the
# FCNN-MLP parameter cell (hidden_sizes, dropout, epochs, lr, weight_decay,
# patience) are used instead. Run that parameter cell right before this one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = IsoformPredictor(
    n_inputs=X.shape[1],
    n_outputs=Y.shape[1],
    hidden_sizes=hidden_sizes,
    dropout=dropout,
).to(device)

history = train_model(
    model,
    loaders.train,
    loaders.val,
    device=device,
    epochs=epochs,
    lr=lr,
    weight_decay=weight_decay,
    patience=patience,
)


# ADD OUTPUT

### Residual-MLP


In [None]:
# Train the Residual-MLP.
# The original script read its settings from an argparse namespace (`args`),
# which does not exist in this notebook; the plain variables defined in the
# Residual-MLP parameter cell (hidden_sizes_res, dropout, epochs, lr,
# weight_decay, patience) are used instead. Run that parameter cell right
# before this one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResidualIsoformPredictor(
    n_inputs=X.shape[1],
    n_outputs=Y.shape[1],
    hidden_sizes=hidden_sizes_res,
    dropout=dropout,
).to(device)


history = train_model(
    model,
    loaders.train,
    loaders.val,
    device=device,
    epochs=epochs,
    lr=lr,
    weight_decay=weight_decay,
    patience=patience,
)

# ADD OUTPUT

## Evaluation and results

### FCNN_MLP

Predictions per split + log-MSE

In [None]:
# Score every split: loss in log1p space plus predictions mapped back to counts.
# NOTE(review): `model` is whatever the most recently executed training cell
# produced — make sure the FCNN-MLP training cell was the last one run.
# NOTE(review): evaluate_loss ignores `criterion` and always applies the masked
# MSE (zero targets excluded), so the value printed as "log-space MSE" is that
# masked loss, not a plain nn.MSELoss.
criterion = nn.MSELoss()
# train_eval is the unshuffled loader over the training data.
split_loaders = {"train": loaders.train_eval, "val": loaders.val, "test": loaders.test}
split_metrics = {}
for split_name, loader in split_loaders.items():
    mse = evaluate_loss(model, loader, device, criterion)
    preds_log, targets_log = inference(model, loader, device)
    # Invert the log1p applied to the targets to get back to the count scale.
    preds_counts = np.expm1(preds_log)
    targets_counts = np.expm1(targets_log)
    split_metrics[split_name] = {
        "mse": mse,
        "pred_counts": preds_counts,
        "true_counts": targets_counts,
    }
    print(f"{split_name.capitalize()} | log-space MSE: {mse:.4f}")

Correlation per isoform counts + boxplot

In [None]:
# Per-isoform Pearson correlation on the test split, computed on the count
# scale (after expm1), followed by a boxplot of the valid values.
test_preds_counts = split_metrics["test"]["pred_counts"]
test_true_counts  = split_metrics["test"]["true_counts"]

iso_corr = isoform_correlations(test_preds_counts, test_true_counts)
# NaN marks isoforms with zero variance in predictions or truth; drop them.
valid_corr = iso_corr[~np.isnan(iso_corr)]
if valid_corr.size == 0:
    print("Isoform correlation: no isoforms with variance; skipping boxplot.")
else:
    corr_mean = float(valid_corr.mean())
    print(f"Isoform correlation: mean Pearson {corr_mean:.4f} over {valid_corr.size}/{iso_corr.size} isoforms")
    plt.figure(figsize=(6, 4))
    plt.boxplot(valid_corr, vert=True, patch_artist=True)
    plt.ylabel("Pearson correlation")
    plt.title("Isoform prediction correlation (test)")
    plt.tight_layout()
    plt.show()

# INSERT BOXPLOT IMAGE

Scatterplot log1p pred VS true + regression line

In [None]:
# Scatter of predicted vs true test values in log1p space, with the bisector
# and a least-squares regression line.
flat_pred = test_preds_counts.ravel()
flat_true = test_true_counts.ravel()
# Plot at most max_points (sample, isoform) pairs; the subsample is drawn
# without replacement from a fixed seed so the figure is reproducible.
max_points = 50000
if flat_pred.size > max_points:
    rng = np.random.default_rng(42)
    idx = rng.choice(flat_pred.size, size=max_points, replace=False)
    flat_pred = flat_pred[idx]
    flat_true = flat_true[idx]

# Back to log1p space (the counts were obtained with expm1).
lp = np.log1p(flat_pred)
lt = np.log1p(flat_true)
plt.figure(figsize=(5, 5))
plt.scatter(lp, lt, alpha=0.2, s=5)
# Common axis range for the bisector and the regression line.
xy_min_l = min(lp.min(), lt.min())
xy_max_l = max(lp.max(), lt.max())
plt.plot([xy_min_l, xy_max_l], [xy_min_l, xy_max_l], color="gray", linestyle="--", linewidth=1, label="bisector")
# Fit and correlation are only defined when both axes have some variance.
if lp.size > 1 and lt.size > 1 and np.std(lp) > 0 and np.std(lt) > 0:
    m, b = np.polyfit(lp, lt, 1)
    plt.plot([xy_min_l, xy_max_l], [m * xy_min_l + b, m * xy_max_l + b], color="orange", linewidth=1.5, label="regression")
    corr_log = float(np.corrcoef(lp, lt)[0, 1])
    print(f"Scatter correlation log1p (n={lp.size}): {corr_log:.4f}")
plt.xlabel("log1p Predicted counts")
plt.ylabel("log1p True counts")
plt.title("Pred vs true (log1p)")
plt.legend()
plt.tight_layout()
plt.show()

# INSERT SCATTERPLOT IMAGE

Spearman per-gene correlation (mean/median of observations) + boxplot

In [None]:
# Per-gene Spearman correlation of isoform values within each test sample,
# summarised per gene by mean and by median, with one boxplot for each.
gene_mean_corrs, gene_median_corrs = gene_spearman_per_sample(
    test_preds_counts,
    test_true_counts,
    transcript_gene_idx,
    gene_names,
)
# NaN marks genes with fewer than two isoforms or no valid correlation.
valid_mean = ~np.isnan(gene_mean_corrs)
valid_median = ~np.isnan(gene_median_corrs)

if valid_mean.any():
    print(
        f"Per-gene isoform rank correlation (Spearman over observations): "
        f"mean of means {float(np.nanmean(gene_mean_corrs)):.4f} | "
        f"median of means {float(np.nanmedian(gene_mean_corrs)):.4f} over {valid_mean.sum()} genes"
    )
    plt.figure(figsize=(6, 4))
    plt.boxplot(gene_mean_corrs[valid_mean], vert=True, patch_artist=True)
    plt.ylabel("Spearman correlation (mean across observations)")
    plt.title("Per-gene isoform rank correlation (mean)")
    plt.tight_layout()
    plt.show()
else:
    print("Per-gene mean rank correlation: insufficient data.")

if valid_median.any():
    print(
        f"Per-gene isoform rank correlation (Spearman over observations): "
        f"mean of medians {float(np.nanmean(gene_median_corrs)):.4f} | "
        f"median of medians {float(np.nanmedian(gene_median_corrs)):.4f} over {valid_median.sum()} genes"
    )
    plt.figure(figsize=(6, 4))
    plt.boxplot(gene_median_corrs[valid_median], vert=True, patch_artist=True)
    plt.ylabel("Spearman correlation (median across observations)")
    plt.title("Per-gene isoform rank correlation (median)")
    plt.tight_layout()
    plt.show()
else:
    print("Per-gene median rank correlation: insufficient data.")

# INSERT BOXPLOT IMAGE


### Residual MLP

Predictions per split + log-MSE

In [None]:
# Score every split: loss in log1p space plus predictions mapped back to counts.
# NOTE(review): `model` is whatever the most recently executed training cell
# produced — make sure the Residual-MLP training cell was the last one run.
# NOTE(review): evaluate_loss ignores `criterion` and always applies the masked
# MSE (zero targets excluded), so the value printed as "log-space MSE" is that
# masked loss, not a plain nn.MSELoss.
criterion = nn.MSELoss()
# train_eval is the unshuffled loader over the training data.
split_loaders = {"train": loaders.train_eval, "val": loaders.val, "test": loaders.test}
split_metrics = {}
for split_name, loader in split_loaders.items():
    mse = evaluate_loss(model, loader, device, criterion)
    preds_log, targets_log = inference(model, loader, device)
    # Invert the log1p applied to the targets to get back to the count scale.
    preds_counts = np.expm1(preds_log)
    targets_counts = np.expm1(targets_log)
    split_metrics[split_name] = {
        "mse": mse,
        "pred_counts": preds_counts,
        "true_counts": targets_counts,
    }
    print(f"{split_name.capitalize()} | log-space MSE: {mse:.4f}")


Correlation per isoform counts + boxplot

In [None]:
# Per-isoform Pearson correlation on the test split, computed on the count
# scale (after expm1), followed by a boxplot of the valid values.
test_preds_counts = split_metrics["test"]["pred_counts"]
test_true_counts  = split_metrics["test"]["true_counts"]

iso_corr = isoform_correlations(test_preds_counts, test_true_counts)
# NaN marks isoforms with zero variance in predictions or truth; drop them.
valid_corr = iso_corr[~np.isnan(iso_corr)]
if valid_corr.size == 0:
    print("Isoform correlation: no isoforms with variance; skipping boxplot.")
else:
    corr_mean = float(valid_corr.mean())
    print(f"Isoform correlation: mean Pearson {corr_mean:.4f} over {valid_corr.size}/{iso_corr.size} isoforms")
    plt.figure(figsize=(6, 4))
    plt.boxplot(valid_corr, vert=True, patch_artist=True)
    plt.ylabel("Pearson correlation")
    plt.title("Isoform prediction correlation (test)")
    plt.tight_layout()
    plt.show()

# INSERT BOXPLOT IMAGE

Scatterplot log1p pred VS true + regression line

In [None]:
# Scatter of predicted vs true test values in log1p space, with the bisector
# and a least-squares regression line.
flat_pred = test_preds_counts.ravel()
flat_true = test_true_counts.ravel()
# Plot at most max_points (sample, isoform) pairs; the subsample is drawn
# without replacement from a fixed seed so the figure is reproducible.
max_points = 50000
if flat_pred.size > max_points:
    rng = np.random.default_rng(42)
    idx = rng.choice(flat_pred.size, size=max_points, replace=False)
    flat_pred = flat_pred[idx]
    flat_true = flat_true[idx]

# Back to log1p space (the counts were obtained with expm1).
lp = np.log1p(flat_pred)
lt = np.log1p(flat_true)
plt.figure(figsize=(5, 5))
plt.scatter(lp, lt, alpha=0.2, s=5)
# Common axis range for the bisector and the regression line.
xy_min_l = min(lp.min(), lt.min())
xy_max_l = max(lp.max(), lt.max())
plt.plot([xy_min_l, xy_max_l], [xy_min_l, xy_max_l], color="gray", linestyle="--", linewidth=1, label="bisector")
# Fit and correlation are only defined when both axes have some variance.
if lp.size > 1 and lt.size > 1 and np.std(lp) > 0 and np.std(lt) > 0:
    m, b = np.polyfit(lp, lt, 1)
    plt.plot([xy_min_l, xy_max_l], [m * xy_min_l + b, m * xy_max_l + b], color="orange", linewidth=1.5, label="regression")
    corr_log = float(np.corrcoef(lp, lt)[0, 1])
    print(f"Scatter correlation log1p (n={lp.size}): {corr_log:.4f}")
plt.xlabel("log1p Predicted counts")
plt.ylabel("log1p True counts")
plt.title("Pred vs true (log1p)")
plt.legend()
plt.tight_layout()
plt.show()

# INSERT SCATTER PLOT IMAGE

Spearman per-gene correlation (mean/median of observations) + boxplot

In [None]:
# Per-gene Spearman correlation of isoform values within each test sample,
# summarised per gene by mean and by median, with one boxplot for each.
gene_mean_corrs, gene_median_corrs = gene_spearman_per_sample(
    test_preds_counts,
    test_true_counts,
    transcript_gene_idx,
    gene_names,
)
# NaN marks genes with fewer than two isoforms or no valid correlation.
valid_mean = ~np.isnan(gene_mean_corrs)
valid_median = ~np.isnan(gene_median_corrs)

if valid_mean.any():
    print(
        f"Per-gene isoform rank correlation (Spearman over observations): "
        f"mean of means {float(np.nanmean(gene_mean_corrs)):.4f} | "
        f"median of means {float(np.nanmedian(gene_mean_corrs)):.4f} over {valid_mean.sum()} genes"
    )
    plt.figure(figsize=(6, 4))
    plt.boxplot(gene_mean_corrs[valid_mean], vert=True, patch_artist=True)
    plt.ylabel("Spearman correlation (mean across observations)")
    plt.title("Per-gene isoform rank correlation (mean)")
    plt.tight_layout()
    plt.show()
else:
    print("Per-gene mean rank correlation: insufficient data.")

if valid_median.any():
    print(
        f"Per-gene isoform rank correlation (Spearman over observations): "
        f"mean of medians {float(np.nanmean(gene_median_corrs)):.4f} | "
        f"median of medians {float(np.nanmedian(gene_median_corrs)):.4f} over {valid_median.sum()} genes"
    )
    plt.figure(figsize=(6, 4))
    plt.boxplot(gene_median_corrs[valid_median], vert=True, patch_artist=True)
    plt.ylabel("Spearman correlation (median across observations)")
    plt.title("Per-gene isoform rank correlation (median)")
    plt.tight_layout()
    plt.show()
else:
    print("Per-gene median rank correlation: insufficient data.")

# INSERT BOXPLOT IMAGE
