In [None]:
import torch
import torch.nn as nn

embed_dim = 128
num_heads = 8
seq_len = 10
batch_size = 32

# Seed before building the layer so both its randomly initialised projection
# weights and the random input below are reproducible on Restart & Run All.
torch.manual_seed(0)

# Instantiate multi-head attention; batch_first=True -> tensors are (batch, seq, embed).
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Example input tensor: (batch, seq_len, embed_dim)
x = torch.randn(batch_size, seq_len, embed_dim)

# Self-attention: Q=K=V=x
output, attn_weights = mha(x, x, x)

print(output.shape)        # [32, 10, 128]
print(attn_weights.shape)  # [32, 10, 10] -- averaged over heads (PyTorch default)

In [None]:
x.size()  # same torch.Size as x.shape: (batch_size, seq_len, embed_dim)

In [None]:
# nn.MultiheadAttention keeps the Q, K and V input projections in one stacked
# parameter of shape (3*embed_dim, embed_dim), in that order along dim 0.
combined_weight = mha.in_proj_weight

# Three consecutive embed_dim-row slices -> Q, K, V weight matrices (views, no copy).
q_weight, k_weight, v_weight = combined_weight.split(embed_dim, dim=0)

print("Query Weight Matrix shape:", q_weight.shape)
print("Key Weight Matrix shape:", k_weight.shape)
print("Value Weight Matrix shape:", v_weight.shape)

In [None]:
class TransformerCellEncoder(nn.Module):
    """One post-norm Transformer encoder layer applied across the cells of a sample.

    Pipeline: linear gene embedding -> multi-head self-attention across cells ->
    residual + LayerNorm -> feed-forward -> residual + LayerNorm -> linear
    reduction to ``final_dim``.

    Args:
        num_genes (int): Length of each cell's input vector (last input dim).
        embed_dim (int): Model width; must be divisible by ``num_heads``
            (``nn.MultiheadAttention`` enforces this).
        num_heads (int): Number of attention heads.
        hidden_dim (int): Inner width of the feed-forward block.
        final_dim (int): Size of the returned per-cell embedding.
    """

    def __init__(self, num_genes=2000, embed_dim=128, num_heads=8, hidden_dim=256, final_dim=64):
        super().__init__()

        # Step 0: project each cell's gene vector down to embed_dim
        self.input_embedding = nn.Linear(num_genes, embed_dim)

        # Step 1: multi-head self-attention; batch_first -> (batch, num_cells, embed_dim)
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        # Steps 2 & 4: LayerNorms applied after each residual add (post-norm)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Step 3: feed-forward applied independently at every cell position
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

        # Step 5: final linear reduction
        self.final_linear = nn.Linear(embed_dim, final_dim)

    def forward(self, x):
        """Encode every cell in the batch.

        Args:
            x (torch.Tensor): Shape (batch_size, num_cells, num_genes).

        Returns:
            torch.Tensor: Shape (batch_size, num_cells, final_dim), ready for a
            projection head.
        """
        x_emb = self.input_embedding(x)  # (batch_size, num_cells, embed_dim)

        # Self-attention across cells. The attention weights were never used, so do
        # not ask MHA to compute and head-average them; attn_output is unaffected.
        attn_output, _ = self.mha(query=x_emb, key=x_emb, value=x_emb, need_weights=False)

        # Add & Norm, feed-forward, Add & Norm (Transformer style)
        x = self.norm1(x_emb + attn_output)
        x = self.norm2(x + self.ff(x))

        # Reduce to the final per-cell embedding: (batch_size, num_cells, final_dim)
        return self.final_linear(x)

### Loss debug: step-by-step walkthrough of the supervised contrastive loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# === Simulate input ===
# Fixed seed so the toy batch is identical on every run.
torch.manual_seed(0)

# Toy batch: 4 lineages x 3 cells each, 2-D embeddings so values are easy to eyeball.
batch_size, cells_per_lineage, embedding_dim = 4, 3, 2

features = torch.randn(batch_size, cells_per_lineage, embedding_dim)
print("Simulated features shape:", features.shape)

# The full ContrastiveLoss is run on this same batch at the end of the notebook,
# once the class has been defined.

In [None]:
features  # raw simulated embeddings (before the normalisation cell below)

In [None]:
# Sanity check of L2 normalisation on the first cell of the first lineage:
# divide it by its Euclidean norm by hand and compare with F.normalize.
# Reads the tensor directly instead of hard-coding rounded printed values,
# so the check stays valid (and exact) if the seed or shape changes.
first_cell = features[0, 0]
manual_unit = first_cell / first_cell.pow(2).sum().sqrt()
manual_unit, F.normalize(first_cell, dim=0)

In [None]:
temperature = 1  # 1 leaves the similarities as plain cosine values; the ContrastiveLoss run at the end uses 0.5

In [None]:
# Re-derive the toy dimensions from the tensor itself.
batch_size, cells_per_lineage, embedding_dim = features.size()

In [None]:
# L2-normalise each cell embedding (dim=2 is the embedding axis). Store it under a
# new name so the raw `features` shown above stay intact and this cell can be re-run.
features_norm = F.normalize(features, dim=2)
print("features_norm:", features_norm)

# Flatten lineages x cells into one axis (row-major): row i is cell
# i % cells_per_lineage of lineage i // cells_per_lineage. reshape (not view)
# also works if the tensor is ever non-contiguous.
features_flat = features_norm.reshape(batch_size * cells_per_lineage, embedding_dim)
print("features_flat:", features_flat)

In [None]:
# Pairwise cosine similarities (rows of features_flat are unit vectors), scaled by temperature.
similarity_matrix = features_flat @ features_flat.T / temperature
similarity_matrix


In [None]:
# Lineage id for every row of features_flat, e.g. [0, 0, 0, 1, 1, 1, ...] for 3 cells per lineage.
labels = torch.arange(batch_size, device=features.device).repeat_interleave(cells_per_lineage)
labels

In [None]:
# Boolean identity matrix: True exactly where a cell is paired with itself.
mask = torch.eye(batch_size * cells_per_lineage, dtype=torch.bool, device=features.device)
mask

In [None]:
# Exclude self-pairs: put -inf on the diagonal so it receives zero softmax weight.
logits = torch.where(mask, torch.full_like(similarity_matrix, float('-inf')), similarity_matrix)
logits

In [None]:
# Row-wise log-probabilities over the other cells; the -inf diagonal stays -inf.
log_prob = logits.log_softmax(dim=1)
log_prob

In [None]:
# True where row and column cells share a lineage (diagonal included for now).
positive_mask = labels[:, None] == labels[None, :]
positive_mask

In [None]:
# A cell is not its own positive: clear the diagonal.
positive_mask = positive_mask.logical_and(mask.logical_not())
positive_mask



In [None]:
# Positives per anchor cell (cells_per_lineage - 1 = 2 in this toy setup).
positive_count = positive_mask.sum(dim=1)
positive_count

In [None]:

# Denominator for the per-anchor loss computed in the next cell: positives per
# anchor, clamped to >= 1 so an anchor without positives never divides by zero.
# (`loss.mean()` is taken further down, once the per-anchor `loss` exists.)
positive_count.clamp(min=1)

In [None]:
# Per-anchor loss: negative mean log-probability of the anchor's positives.
# Use masked_fill rather than multiplying by the mask: log_prob is -inf on the
# diagonal and 0 * -inf = nan, which would turn every row of the loss into nan.
masked_log_prob = log_prob.masked_fill(~positive_mask, 0.0)
loss = -masked_log_prob.sum(1) / positive_count.clamp(min=1)


In [None]:
# Element-wise product of the positive mask with log_prob. Non-positive entries
# become 0 * log_prob; on the diagonal log_prob is -inf (self-similarity was set
# to -inf before log_softmax), so those entries come out as 0 * -inf = nan.
positive_mask * log_prob

In [None]:
# Row sums of the product above. Every row contains a nan from its diagonal, so
# every sum is nan; log_prob.masked_fill(~positive_mask, 0.0) avoids this.
(positive_mask * log_prob).sum(1)

In [None]:
loss  # per-anchor loss vector, one entry per cell

In [None]:
loss = torch.mean(loss)  # average over all anchor cells -> scalar

In [None]:
loss  # scalar: mean of the per-anchor losses

In [None]:
# exp(-inf) == 0.0: a logit masked to -inf (the diagonal) adds nothing to the
# softmax denominator, i.e. a cell is never scored against itself.
import math
math.exp(-math.inf)

In [2]:

import torch
import torch.nn as nn
import torch.nn.functional as F
class ContrastiveLoss(nn.Module):
    """Supervised contrastive loss over groups ("lineages") of cell embeddings.

    Cells that share an index along dim 0 of the input belong to the same lineage
    and are positives for each other; every other cell in the batch is a negative.
    For each anchor cell the loss is the negative mean log-softmax probability of
    its positives, and the returned value is the mean over all anchors.
    """

    def __init__(self, temperature=0.5):
        """
        Contrastive Loss for supervised contrastive learning on cell embeddings.

        Args:
            temperature (float): Divides the cosine similarities before the
                softmax; smaller values sharpen the distribution. Must be > 0.

        Raises:
            ValueError: If ``temperature`` is not a positive number.
        """
        super().__init__()
        # `not x > 0` also rejects NaN.
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature!r}")
        self.temperature = temperature

    def forward(self, features):
        """
        Compute the supervised contrastive loss for the provided features.

        Args:
            features (torch.Tensor): Embeddings of shape
                (batch_size, cells_per_lineage, embedding_dim); any memory layout.

        Returns:
            torch.Tensor: Scalar (0-dim) contrastive loss.

        Raises:
            ValueError: If ``features`` is not 3-D, is empty, or contains NaN/Inf.
            FloatingPointError: If the computed loss is not finite.
        """
        if features.dim() != 3:
            raise ValueError(
                "features must have shape (batch_size, cells_per_lineage, embedding_dim), "
                f"got {tuple(features.shape)}"
            )
        batch_size, cells_per_lineage, embedding_dim = features.shape
        num_cells = batch_size * cells_per_lineage
        if num_cells == 0:
            raise ValueError(f"features contains no cells: shape {tuple(features.shape)}")

        # Raise instead of calling exit(): exit() kills the notebook kernel or the
        # whole training process and hides the failure from the caller.
        if not torch.isfinite(features).all():
            raise ValueError("NaN or Inf detected in ContrastiveLoss input features")

        device = features.device

        # Unit-normalise so dot products are cosine similarities.
        features = F.normalize(features, dim=2)
        # reshape (not view): also valid for non-contiguous inputs.
        features_flat = features.reshape(num_cells, embedding_dim)

        similarity_matrix = torch.matmul(features_flat, features_flat.T) / self.temperature

        # Row i belongs to lineage i // cells_per_lineage.
        labels = torch.arange(batch_size, device=device).repeat_interleave(cells_per_lineage)
        self_mask = torch.eye(num_cells, dtype=torch.bool, device=device)

        # Self-pairs get -inf so they receive zero weight in the softmax.
        logits = similarity_matrix.masked_fill(self_mask, float('-inf'))
        log_prob = F.log_softmax(logits, dim=1)

        # Positives: same lineage, different cell.
        positive_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
        positive_count = positive_mask.sum(1)

        # masked_fill rather than multiplication: the diagonal of log_prob is -inf
        # and 0 * -inf = nan. Anchors without positives (cells_per_lineage == 1)
        # contribute 0 (clamp avoids 0/0) but still count toward the mean.
        masked_log_prob = log_prob.masked_fill(~positive_mask, 0.0)
        loss = (-masked_log_prob.sum(1) / positive_count.clamp(min=1)).mean()

        if not torch.isfinite(loss):
            raise FloatingPointError(f"ContrastiveLoss produced a non-finite value: {loss.item()}")

        return loss

In [3]:
# === Simulate input ===
torch.manual_seed(0)  # same seed as the debug walkthrough -> identical toy batch

# 4 lineages x 3 cells per lineage, 2-D embeddings
batch_size, cells_per_lineage, embedding_dim = 4, 3, 2

features = torch.randn((batch_size, cells_per_lineage, embedding_dim))
print("Simulated features shape:", features.shape)



Simulated features shape: torch.Size([4, 3, 2])


In [4]:
# === Run loss ===
# Evaluate the loss on the simulated batch from the previous cell.
loss_fn = ContrastiveLoss(temperature=0.5)
loss = loss_fn(features)
print("Contrastive loss:", float(loss))

Contrastive loss: 3.346036911010742
