In [2]:
import os

# Move the working directory to the parent of the directory the kernel started
# in. A bare os.chdir("..") is not idempotent: every re-run of this cell would
# climb one more level. Anchoring on the start directory recorded at first run
# makes repeated executions land in the same place.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

In [3]:
from diffusers import VQModel
from utils.dataset_highvars import get_loader


In [7]:
import torch
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from diffusers.configuration_utils import register_to_config
import torch.nn.functional as F

In [8]:
class VectorQuantizer1D(nn.Module):
    """Nearest-neighbour vector quantizer for sequences of latent vectors.

    Each position of a ``[Batch, Channels, Sequence_Length]`` input is replaced
    by the closest row (squared Euclidean distance) of a learnable codebook.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook vector; must equal the channel
            dimension of the tensors passed to ``forward``.
        beta: Weight of the commitment term in the returned loss.
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.beta = beta  # Commitment cost

        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Start the codebook in a narrow symmetric range around zero.
        nn.init.uniform_(self.embedding.weight, -1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, latents):
        """Quantize ``latents`` and compute the VQ loss.

        Args:
            latents: Tensor of shape ``[Batch, Channels, Sequence_Length]`` with
                ``Channels == embedding_dim``.

        Returns:
            Tuple ``(quantized, loss)``. ``quantized`` has the same shape as the
            input and carries straight-through gradients back to ``latents``;
            ``loss`` is the scalar ``codebook_loss + beta * commitment_loss``.

        Raises:
            ValueError: If ``latents`` is not 3-D or its channel dimension does
                not match ``embedding_dim``.
        """
        if latents.dim() != 3 or latents.shape[1] != self.embedding_dim:
            # Without this check a wrong channel count can be silently re-chunked
            # by the flattening view below.
            raise ValueError(
                "latents must have shape [Batch, %d, Sequence_Length], got %s"
                % (self.embedding_dim, tuple(latents.shape))
            )

        # Permute for embedding lookup: [B, L, C]
        latents = latents.permute(0, 2, 1).contiguous()
        flat_latents = latents.view(-1, self.embedding_dim)

        # Squared distances to every code: ||z||^2 + ||e||^2 - 2 * z.e
        codebook = self.embedding.weight
        distances = (
            torch.sum(flat_latents**2, dim=1, keepdim=True)
            + torch.sum(codebook**2, dim=1)
            - 2 * torch.matmul(flat_latents, codebook.t())
        )

        # Index of the nearest code for each latent vector.
        encoding_indices = torch.argmin(distances, dim=1)

        # Look the codes up directly instead of multiplying by a dense one-hot
        # matrix: this avoids a (B*L, num_embeddings) allocation and keeps the
        # result in the codebook's dtype (a float32 one-hot matrix breaks the
        # matmul when the module is cast to another precision).
        quantized = self.embedding(encoding_indices).view(latents.shape)

        # Loss: the commitment term pulls the encoder output towards the code,
        # the codebook term pulls the code towards the encoder output.
        e_latent_loss = F.mse_loss(quantized.detach(), latents)
        q_latent_loss = F.mse_loss(quantized, latents.detach())
        loss = q_latent_loss + self.beta * e_latent_loss

        # Straight-through estimator: forward uses the code, backward passes the
        # gradient to the latents unchanged.
        quantized = latents + (quantized - latents).detach()

        # Permute back to [B, C, L]
        quantized = quantized.permute(0, 2, 1).contiguous()

        return quantized, loss

In [9]:


# Reuses VectorQuantizer1D, which must already be defined in this notebook
# (see the cell above) before this cell is run.

class TransformerVQModel(ModelMixin, ConfigMixin):
    """Transformer autoencoder with a vector-quantized bottleneck for flat feature vectors.

    The input vector is cut into ``seq_len`` equal chunks ("tokens"), encoded by
    a transformer, projected to ``latent_dim`` and quantized with
    ``VectorQuantizer1D``; a second transformer stack decodes the quantized
    tokens back to a vector of the original length.
    """

    @register_to_config
    def __init__(
        self,
        num_features=2000,      # Input vector size (e.g., number of highly variable genes)
        embedding_dim=256,      # Internal dimension for the Transformer
        latent_dim=64,          # Dimension of the quantized latent vectors
        n_heads=4,
        n_layers=4,             # Depth of Encoder/Decoder
        seq_len=32,             # We reshape input into a sequence of this length
        num_embeddings=1024     # Codebook size
    ):
        """Build all sub-modules.

        Raises:
            ValueError: If ``seq_len`` is not positive or does not divide
                ``num_features``.
        """
        super().__init__()

        # Validate before dividing so a bad seq_len gives a clear error, and use
        # an exception rather than `assert` so the check survives `python -O`.
        if seq_len <= 0:
            raise ValueError(f"seq_len must be a positive integer, got {seq_len}")
        if num_features % seq_len != 0:
            raise ValueError(
                "num_features must be divisible by seq_len "
                f"(got num_features={num_features}, seq_len={seq_len})"
            )

        self.seq_len = seq_len
        self.feature_per_token = num_features // seq_len

        # 1. Input Projection (each chunk of features becomes one token)
        self.input_proj = nn.Linear(self.feature_per_token, embedding_dim)

        # 2. Learnable Positional Embeddings
        # One vector per token position, broadcast over the batch; it tells the
        # model which slice of the input a token came from.
        self.pos_emb = nn.Parameter(torch.randn(1, seq_len, embedding_dim))

        # 3. Transformer Encoder (Global Attention)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # 4. Pre-Quantization Projector
        self.pre_quant = nn.Linear(embedding_dim, latent_dim)

        # 5. Vector Quantizer
        self.quantizer = VectorQuantizer1D(num_embeddings, latent_dim)

        # 6. Post-Quantization Projector
        self.post_quant = nn.Linear(latent_dim, embedding_dim)

        # 7. Transformer Decoder (an encoder-style stack: self-attention only)
        decoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=n_heads, batch_first=True)
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=n_layers)

        # 8. Output Head
        self.output_head = nn.Linear(embedding_dim, self.feature_per_token)

    def encode(self, x):
        """Encode and quantize a batch of feature vectors.

        Args:
            x: Tensor of shape ``[Batch, Num_Features]``.

        Returns:
            Tuple ``(quantized, loss)`` as returned by the quantizer, where
            ``quantized`` has shape ``[Batch, Latent_Dim, Seq_Len]``.
        """
        batch_size = x.shape[0]

        # 1. Reshape into Sequence: [B, Seq_Len, Features_Per_Token]
        # reshape (not view) so non-contiguous inputs, e.g. column-sliced or
        # transposed tensors, are accepted as well.
        x = x.reshape(batch_size, self.seq_len, self.feature_per_token)

        # 2. Project to Embeddings & Add Position Info
        h = self.input_proj(x) + self.pos_emb

        # 3. Transformer Pass (Global Mixing)
        h = self.encoder(h)

        # 4. Project to Latent Dim
        h = self.pre_quant(h)

        # 5. Quantize
        # The quantizer takes channels-first input, so move the latent
        # dimension in front of the sequence dimension.
        h = h.permute(0, 2, 1)  # [B, Latent_Dim, Seq_Len]
        quantized, loss = self.quantizer(h)

        return quantized, loss

    def decode(self, quantized):
        """Reconstruct feature vectors from quantized latents.

        Args:
            quantized: Tensor of shape ``[Batch, Latent_Dim, Seq_Len]``.

        Returns:
            Tensor of shape ``[Batch, Num_Features]``.
        """
        h = quantized.permute(0, 2, 1)  # Back to [B, Seq_Len, Latent_Dim]

        # The same positional embedding is added again so the decoder also
        # knows each token's position.
        h = self.post_quant(h) + self.pos_emb
        h = self.decoder(h)

        # Project every token back to its chunk of feature values
        out = self.output_head(h)

        # Flatten back to single vector: [B, Num_Features]
        out = out.reshape(out.shape[0], -1)
        return out

    def forward(self, x):
        """Run encode followed by decode.

        Returns:
            Tuple ``(decoded, vq_loss)``: the reconstruction and the quantizer loss.
        """
        quantized, vq_loss = self.encode(x)
        decoded = self.decode(quantized)
        return decoded, vq_loss

In [13]:
# Fetch the data via the project loader. Of the returned triple, geneDim is used
# below as the model's input feature count; maskidx is presumably an index mask
# over genes — TODO confirm against utils.dataset_highvars.
ds, geneDim, maskidx = get_loader(num_samples=64)

In [27]:
# Instantiate the VQ model on the feature count reported by the loader.
# NOTE(review): TransformerVQModel requires num_features to be divisible by
# seq_len, so geneDim must be a multiple of 15 here — confirm for this dataset.
model = TransformerVQModel(
    num_features=geneDim,  # input vector length, taken from get_loader
    embedding_dim=256,     # transformer width
    latent_dim=64,         # size of each quantized latent vector
    seq_len=15,            # tokens per sample (geneDim // 15 features per token)
    num_embeddings=512,    # codebook size
)