## workflow

a) Input: scVI embeddings of cell states
b) Output: Generated protein sequences

Process Flow:

scVI embeddings → UNetModel → Flow Matching → Protein Decoder → Protein Sequence

1. Input Preparation:
   - Start with scVI (single-cell Variational Inference) embeddings of cell states.
   - These embeddings are high-dimensional vectors (e.g., 128 dimensions) representing cellular gene expression profiles.
   - The scVI embeddings are generated using a separate scVI model trained on single-cell RNA sequencing data.

2. Model Architecture Overview:
   The system consists of several key components:
   a) UNetModel: The core generative model
   b) FlowMatchingTrainer: Manages the flow matching process
   c) ProtT5Encoder: Encodes protein sequences (for training)
   d) ProtT5Decoder: Decodes latent representations to protein sequences

3. UNetModel Detailed Architecture:
   3.1. Initialization:
   - The UNetModel is initialized with parameters like input/output channels, number of ResBlocks, attention resolutions, etc.
   - Key components are created: time embedder, input blocks, middle block, output blocks.

   3.2. Time Embedding:
   - Function: timestep_embedding
   - Converts a scalar timestep to a high-dimensional vector using sinusoidal functions.
   - This embedding is further processed through a small MLP (self.time_embed).

   3.3. Input Blocks:
   - A series of TimestepEmbedSequential modules, each containing:
     a) ResBlock: Combines feature maps with time embeddings
     b) AttentionBlock or SpatialTransformer: For self-attention mechanisms
     c) Downsample: Reduces spatial dimensions (if applicable)

   3.4. Middle Block:
   - Contains ResBlocks and Attention mechanisms for global reasoning.

   3.5. Output Blocks:
   - Mirror the input blocks, but with Upsample layers instead of Downsample.
   - Use skip connections from input blocks.

   3.6. Final Output Layer:
   - Normalization followed by a convolution to produce the output channels.

4. Detailed Component Breakdown:
   4.1. ResBlock:
   - Residual block that processes features and incorporates time embeddings.
   - Contains normalization layers, convolutions, and optional up/downsampling.
   - Uses checkpoint function for memory-efficient backpropagation.

   4.2. AttentionBlock:
   - Self-attention mechanism allowing interaction between different parts of the sequence.
   - Uses QKVAttention for efficient attention computation.

   4.3. SpatialTransformer:
   - More sophisticated attention mechanism with multiple transformer layers.
   - Each layer contains self-attention and feed-forward networks.

   4.4. CrossAttention:
   - Attention mechanism that can attend to a separate context (used in SpatialTransformer).
   - Splits input into query, key, and value before computing attention.

   4.5. FeedForward:
   - Simple feedforward network used in transformer blocks.
   - Contains two linear layers with GELU activation and dropout.

   4.6. Upsample and Downsample:
   - Handle changes in spatial dimensions of feature maps.
   - Use either interpolation or transposed convolutions.

   4.7. GroupNorm32:
   - Custom group normalization for improved training stability.

   4.8. TimestepEmbedSequential:
   - Sequential module that handles passing of timestep embeddings to appropriate submodules.

5. FlowMatchingTrainer:
   - Manages the training process of the UNetModel.
   - Implements the forward process (adding noise) and reverse process (denoising).
   - Uses a noise schedule to control the amount of noise added at each timestep.
   - Computes loss based on the model's ability to predict the noise added.

6. ProtT5Encoder (for training):
   - Utilizes a pre-trained ProtT5 model to encode protein sequences into a latent space.
   - Processes amino acid sequences into a high-dimensional representation.

7. ProtT5Decoder (for inference):
   - Converts latent representations back into amino acid sequences.
   - Uses beam search or other decoding strategies to generate the final protein sequence.

8. Training Process:
   8.1. Data Preparation:
   - Batch of scVI embeddings and corresponding protein sequences are loaded.
   - Protein sequences are encoded using ProtT5Encoder.

   8.2. Forward Pass:
   - Random timesteps are generated for each sample in the batch.
   - Noise is added to the encoded protein sequences based on the timesteps.
   - The UNetModel processes the noisy encodings, conditioned on scVI embeddings and timesteps.

   8.3. Loss Computation:
   - The model's output is compared to the true noise added.
   - Loss is calculated (usually mean squared error).

   8.4. Backpropagation:
   - Gradients are computed and model parameters are updated.

9. Inference Process:
   9.1. Start with an scVI embedding of a cell state.
   9.2. Generate random noise as the starting point.
   9.3. Gradually denoise using the UNetModel:
      - For each timestep (from most noisy to least):
        - Pass the current noisy sample through the UNetModel.
        - Use the model's prediction to update the sample.
   9.4. The final denoised representation is passed through the ProtT5Decoder.
   9.5. The decoder outputs the generated protein sequence.

10. Utility Functions:
    - conv_nd: Creates 1D convolutions for our sequence data.
    - zero_module: Initializes a module's parameters to zero.
    - normalization: Applies GroupNorm32 normalization.
    - checkpoint: Implements gradient checkpointing for memory efficiency.
    - exists and default: Helper functions for handling optional parameters.

Cool adaptations
    - Adaptation of 2D UNet architecture to 1D protein sequences.
        - Unet in original model was used for the audio representations (spectograms)
        - While we do use protein embeddings (like those from ProtT5), the UNet in our case still operates on a 1D sequence of these embeddings.
        - Each position in this sequence corresponds to an amino acid, but is represented by a high-dimensional vector.
        - The UNet processes this sequence of vectors, maintaining the 1D structure of the protein
            - each element of the sequence is itself a rich high-dimensional representation (1280)
    - Integration of flow matching with protein language models.
    - Use of scVI embeddings as conditional input for targeted protein generation.


Loss:
θ^ = argmin_θ E_t,z_t ||u_θ(z_t, t, c) - v_t||^2
Where:

u_θ is your flow matching model (UNet)
z_t is the scVI embedding at time t
t is the timestep
c is your context (which we'll discuss next)
v_t is the target velocity (z_1 - (1-σ_min)z_0 in the optimal transport formulation)


UNet
This 1D UNet processes the scVI latent representations, which encode cellular states, through a series of downsampling and upsampling operations.
The input blocks progressively reduce the spatial dimensions while increasing the channel depth, capturing hierarchical features.
The middle block, with its deep channel representation, allows for global reasoning across the entire sequence.
The output blocks then gradually upsample the representation back to the original dimensions, utilizing skip connections to preserve fine-grained information.
Time embeddings are crucial, allowing the model to understand its position in the generation process.
These embeddings are added to the input at each step, guiding the transformation from noise to protein sequence.
Attention mechanisms, implemented either as AttentionBlocks or SpatialTransformers, enable the model to capture long-range dependencies critical for protein structure.
The ResBlocks incorporate both the current state and the time embedding, allowing for time-dependent processing at each level.
The context dimension, which could include additional information like pseudotime or motif data, is integrated through the SpatialTransformer blocks, providing extra conditioning for the generation process.
The model's output represents the velocity field in the flow matching framework, predicting how the latent representation should change at each step to transform noise into a meaningful protein sequence representation.



## library installs

In [1]:
!pip install torch anndata scvi-tools einops numpy scipy transformers scanpy

Collecting anndata
  Downloading anndata-0.10.8-py3-none-any.whl.metadata (6.6 kB)
Collecting scvi-tools
  Downloading scvi_tools-1.1.5-py3-none-any.whl.metadata (17 kB)
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting scanpy
  Downloading scanpy-1.10.2-py3-none-any.whl.metadata (9.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cac

## utils

In [2]:
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )


In [3]:
from abc import abstractmethod
from functools import partial
import math
from typing import Iterable

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import repeat
def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

def count_flops_attn(model, _x, y):
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += torch.DoubleTensor([matmul_ops])

def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)

def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)

class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)

class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """
    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """
    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x


In [4]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        attn = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)

class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)

class Upsample(nn.Module):
    def __init__(self, channels, use_conv, dims=1, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x

class Downsample(nn.Module):
    def __init__(self, channels, use_conv, dims=1, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)

class ResBlock(TimestepBlock):
    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=1,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

class QKVAttention(nn.Module):
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)


class AttentionBlock(nn.Module):
    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert channels % num_head_channels == 0
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention(self.num_heads)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)

class SpatialTransformer(nn.Module):
    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
             for d in range(depth)]
        )

        self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, s = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c s -> b s c')
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b s c -> b c s')
        x = self.proj_out(x)
        return x + x_in

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x

## pretrained embedding modules (prott5, scvi)

In [5]:
from transformers import T5EncoderModel, T5Tokenizer
import torch.nn as nn
import re

class ProtT5EncodingModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.protT5_model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
        self.protT5_tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_bfd")

    def forward(self, sequence):
        processed_seq = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))
        ids = self.protT5_tokenizer(processed_seq, add_special_tokens=True, return_tensors="pt", padding='longest')
        input_ids = ids['input_ids'].to(self.protT5_model.device)
        attention_mask = ids['attention_mask'].to(self.protT5_model.device)

        with torch.no_grad():
            embedding_repr = self.protT5_model(input_ids=input_ids, attention_mask=attention_mask)

        seq_emb = embedding_repr.last_hidden_state
        return seq_emb

In [6]:

from transformers import T5ForConditionalGeneration

class ProtT5DecodingModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.protT5_model = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_bfd")
        self.protT5_tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_bfd")

    def forward(self, latent_repr, max_length=200):
        outputs = self.protT5_model.generate(
            inputs_embeds=latent_repr,
            max_length=max_length,
            num_beams=4,
            early_stopping=True
        )

        decoded_sequences = self.protT5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return decoded_sequences

In [7]:
import scvi
import scanpy as sc
import numpy as np

# encode pseudotime and latent representations

class SCVIEncodingModule:
    def __init__(self):
        self.latent_representations = {}
        self.pseudotime_representations = {}

    def encode(self, adata_dict):
        for cell_type, adata in adata_dict.items():
            print(f"Training and embedding for cell type: {cell_type}...")

            adata_copy = adata.copy()

            nan_count = np.isnan(adata_copy.X).sum()
            if nan_count > 0:
                print(f"There are {nan_count} NaN values in adata_copy.X for cell type: {cell_type}")
            else:
                print(f"No NaN values found in adata_copy.X for cell type: {cell_type}")

            latent = adata_copy.obsm['X_scvi']
            pseudotime = adata_copy.obs['dpt_pseudotime']

            # Store latent representation in the dictionary
            self.latent_representations[cell_type] = latent
            self.pseudotime_representations[cell_type] = pseudotime

        print("Encoding completed.")

        return self.latent_representations, self.pseudotime_representations

## unet model (flow matching)

In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomUNet1D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        model_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        use_checkpoint=False,
        num_heads=8,
        use_scale_shift_norm=False,
        use_spatial_transformer=False,
        transformer_depth=1,
        context_dim=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.use_spatial_transformer = use_spatial_transformer
        self.context_dim = context_dim

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList([
            nn.Conv1d(in_channels, model_channels, 3, padding=1)
        ])

        ch = model_channels
        input_block_chans = [model_channels]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock1D(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if use_spatial_transformer:
                        layers.append(
                            SpatialTransformer1D(
                                ch, num_heads, context_dim, depth=transformer_depth
                            )
                        )
                    else:
                        layers.append(AttentionBlock1D(ch, num_heads=num_heads))
                self.input_blocks.append(nn.Sequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                self.input_blocks.append(
                    nn.Conv1d(ch, ch, 3, stride=2, padding=1)
                )
                input_block_chans.append(ch)
                ds *= 2

        self.middle_block = nn.Sequential(
            ResBlock1D(
                ch,
                time_embed_dim,
                dropout,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock1D(ch, num_heads=num_heads) if not use_spatial_transformer else
            SpatialTransformer1D(ch, num_heads, context_dim, depth=transformer_depth),
            ResBlock1D(
                ch,
                time_embed_dim,
                dropout,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResBlock1D(
                        ch + input_block_chans.pop(),
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if use_spatial_transformer:
                        layers.append(
                            SpatialTransformer1D(
                                ch, num_heads, context_dim, depth=transformer_depth
                            )
                        )
                    else:
                        layers.append(AttentionBlock1D(ch, num_heads=num_heads))
                if level and i == num_res_blocks:
                    layers.append(nn.ConvTranspose1d(ch, ch, 4, stride=2, padding=1))
                    ds //= 2
                self.output_blocks.append(nn.Sequential(*layers))

        self.out = nn.Sequential(
            nn.GroupNorm(32, ch),
            nn.SiLU(),
            nn.Conv1d(ch, out_channels, 3, padding=1),
        )

    def forward(self, x, timesteps, context=None):
      x = x.transpose(1, 2)  # transpose shape: [2, 50, 50663]
      t_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

      h = x
      hs = []
      for module in self.input_blocks:
          print("Module type input blocks:", type(module))
          if isinstance(module, nn.Conv1d):
              h = module(h)
          elif isinstance(module, SpatialTransformer1D):
            h = module(h, context)
          elif isinstance(module, ResBlock1D):
            h = module(h, t_emb)
          elif isinstance(module, nn.Sequential):
              for submodule in module:
                  print("submodule type input blocks:", type(submodule))
                  if isinstance(submodule, ResBlock1D):
                      h = submodule(h, t_emb)
                  elif isinstance(submodule, SpatialTransformer1D):
                      h = submodule(h, context)
                  else:
                      h = submodule(h)
          else:
              h = module(h)
          hs.append(h)

      if isinstance(self.middle_block, nn.Sequential):
          for submodule in self.middle_block:
              if isinstance(submodule, ResBlock1D):
                  h = submodule(h, t_emb)
              elif isinstance(submodule, SpatialTransformer1D):
                  h = submodule(h, context)
              else:
                  h = submodule(h)
      else:
          h = self.middle_block(h)

      for module in self.output_blocks:
          h = torch.cat([h, hs.pop()], dim=1)
          if isinstance(module, nn.Sequential):
              for submodule in module:
                  if isinstance(submodule, ResBlock1D):
                      h = submodule(h, t_emb)
                  elif isinstance(submodule, SpatialTransformer1D):
                      h = submodule(h, context)
                  else:
                      h = submodule(h)
          elif isinstance(module, SpatialTransformer1D):
            h = module(h, context)
          elif isinstance(module, ResBlock1D):
            h = module(h, t_emb)
          else:
              h = module(h)

      output = self.out(h)
      output = output.transpose(1, 2)  # shape: [2, 50663, 1024]

      return output

class ResBlock1D(nn.Module):
    def __init__(self, channels, time_embed_dim, dropout, out_channels=None, use_scale_shift_norm=False):
        super().__init__()
        self.channels = channels
        self.time_embed_dim = time_embed_dim
        self.out_channels = out_channels or channels
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, 3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, 2 * self.out_channels if use_scale_shift_norm else self.out_channels),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv1d(self.out_channels, self.out_channels, 3, padding=1),
        )
        if channels != self.out_channels:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).unsqueeze(2)
        if self.use_scale_shift_norm:
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = self.out_layers[0](h) * (1 + scale) + shift
            h = self.out_layers[1:](h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

class AttentionBlock1D(nn.Module):
    def __init__(self, channels, num_heads=1):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads

        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.attention = QKVAttention(num_heads)
        self.proj_out = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        b, c, s = x.shape
        qkv = self.qkv(self.norm(x))
        qkv = qkv.reshape(b * self.num_heads, -1, s)
        h = self.attention(qkv)
        h = h.reshape(b, -1, s)
        return x + self.proj_out(h)

class QKVAttention(nn.Module):
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        bs, width, length = qkv.shape
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

class SpatialTransformer1D(nn.Module):
    def __init__(self, channels, num_heads, context_dim, depth=1):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        inner_dim = channels
        self.proj_in = nn.Conv1d(channels, inner_dim, 1)
        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, num_heads, context_dim) for _ in range(depth)]
        )
        self.proj_out = nn.Conv1d(inner_dim, channels, 1)

    def forward(self, x, context=None):
        b, c, s = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = x.permute(0, 2, 1).contiguous()
        for block in self.transformer_blocks:
            x = block(x, context)
        x = x.permute(0, 2, 1).contiguous()
        x = self.proj_out(x)
        return x + x_in

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, context_dim):
        super().__init__()
        self.attn1 = CrossAttention(dim, dim, num_heads)
        self.ff = FeedForward(dim)
        self.attn2 = CrossAttention(dim, context_dim, num_heads)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim, num_heads, dim_head=64):
        super().__init__()
        inner_dim = dim_head * num_heads
        self.scale = dim_head ** -0.5
        self.num_heads = num_heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None):
        h = self.num_heads

        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: t.reshape(t.shape[0], -1, h, t.shape[-1] // h).permute(0, 2, 1, 3), (q, k, v))
        sim = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        attn = sim.softmax(dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.permute(0, 2, 1, 3).reshape(out.shape[0], -1, out.shape[-1] * h)
        return self.to_out(out)

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim),
        )

    def forward(self, x):
        return self.net(x)

def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

class TimestepEmbedSequential(nn.Sequential):
    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, ResBlock1D):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer1D):
                x = layer(x, context)
            elif isinstance(layer, nn.Conv1d) or isinstance(layer, nn.GroupNorm) or isinstance(layer, nn.ReLU):
                x = layer(x)
            else:
                x = layer(x, emb, context)
        return x

## ode solver

In [49]:
from scipy.integrate import solve_ivp
import numpy as np

class ODESolverModule:
    def __init__(self, model):
        self.model = model

    def ode_func(self, t, y, *args):
        y_tensor = torch.tensor(y, dtype=torch.float32).unsqueeze(0)
        t_tensor = torch.tensor([t], dtype=torch.float32)
        with torch.no_grad():
            dy_dt = self.model(y_tensor, t_tensor, *args).squeeze().numpy()
        return dy_dt

    def solve(self, y0, t_span, *args, method='RK45', **kwargs):
        solution = solve_ivp(
            fun=lambda t, y: self.ode_func(t, y, *args),
            t_span=t_span,
            y0=y0,
            method=method,
            **kwargs
        )
        return solution

## flow trainer

In [50]:
class FlowMatchingTrainer(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        init_type="gaussian",
        noise_scale=1.0,
        reflow_t_schedule="uniform",
        use_ode_sampler="euler",
        sigma_var=0.0,
        ode_tol=1e-5,
        sample_N=25,
    ):
        super().__init__()
        self.model = model
        self.init_type = init_type
        self.noise_scale = noise_scale
        self.reflow_t_schedule = reflow_t_schedule
        self.use_ode_sampler = use_ode_sampler
        self.sigma_var = sigma_var
        self.ode_tol = ode_tol
        self.sample_N = sample_N
        self.T = 1
        self.eps = 1e-3
        self.sigma_t = lambda t: (1.0 - t) * sigma_var

    def forward(self, x_0, c):
        t = torch.rand(x_0.shape[0], device=x_0.device) * (self.T - self.eps) + self.eps
        t_expand = t.view(-1, 1, 1).repeat(1, x_0.shape[1], x_0.shape[2])
        c = c.to(x_0.device)

        noise = torch.randn_like(x_0)
        target = x_0 - noise
        perturbed_data = t_expand * x_0 + (1 - t_expand) * noise

        model_out = self.model(perturbed_data, t * 999, c)

        loss = F.mse_loss(model_out, target, reduction="none").mean([1, 2]).mean()
        return loss

    @torch.no_grad()
    def euler_sample(self, cond, shape, guidance_scale):
        device = self.model.device
        batch = torch.randn(shape, device=device)
        x = torch.randn_like(batch)
        dt = 1.0 / self.sample_N
        eps = 1e-3
        for i in range(self.sample_N):
            num_t = i / self.sample_N * (self.T - eps) + eps
            t = torch.ones(batch.shape[0], device=device) * num_t

            model_out = self.model(torch.cat([x] * 2), torch.cat([t * 999] * 2), cond)
            noise_pred_uncond, noise_pred_text = model_out.chunk(2)
            pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            sigma_t = self.sigma_t(num_t)
            pred_sigma = pred + (sigma_t**2) / (2 * (self.noise_scale**2) * ((1.0 - num_t) ** 2)) * (
                0.5 * num_t * (1.0 - num_t) * pred - 0.5 * (2.0 - num_t) * x.detach().clone()
            )

            x = x.detach().clone() + pred_sigma * dt + sigma_t * np.sqrt(dt) * torch.randn_like(pred_sigma).to(device)

        return x, self.sample_N

## train module

### scvi encoding

In [11]:
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
import anndata
import scvi
import os

file_count = 0

adata_list = []
folder_path = '/content/drive/MyDrive/tf-flow-design/combined_adata_folder/'
for filename in os.listdir(folder_path):
    if filename.endswith('.h5ad'):
        file_path = os.path.join(folder_path, filename)
        print(filename)
        adata = anndata.read_h5ad(file_path)
        adata_list.append(adata)
        file_count += 1
        if file_count >= 5:
            break

print(f'Read and stored {file_count} .h5ad files.')

combined_adata_macrophage.h5ad
combined_adata_monocyte.h5ad
combined_adata_endothelial cell of hepatic sinusoid.h5ad
combined_adata_liver dendritic cell.h5ad
combined_adata_nk cell.h5ad
Read and stored 5 .h5ad files.


In [12]:
# # Concatenate all AnnData objects
# combined_adata = anndata.concat(adata_list, join='outer', label='batch')
# print(f"Combined AnnData shape: {combined_adata.shape}")

In [13]:
import anndata

adata_dict = {}

for adata in adata_list:
    # Get the unique cell_ontology_class values (excluding 'mesenchymal stem cell')
    cell_types = adata.obs['cell_ontology_class'].unique()
    for cell_type in cell_types:
        if cell_type != 'mesenchymal stem cell':
            if cell_type not in adata_dict:
                adata_dict[cell_type] = adata
            else:
                adata_dict[cell_type] = anndata.concat([adata_dict[cell_type], adata])


In [14]:
adata_dict.keys()

dict_keys(['macrophage', 'monocyte', 'endothelial cell of hepatic sinusoid', 'liver dendritic cell', 'nk cell'])

In [15]:
scvi_encoder = SCVIEncodingModule()
scvi_latents, scvi_pseudotimes = scvi_encoder.encode(adata_dict)

Training and embedding for cell type: macrophage...
No NaN values found in adata_copy.X for cell type: macrophage
Training and embedding for cell type: monocyte...
No NaN values found in adata_copy.X for cell type: monocyte
Training and embedding for cell type: endothelial cell of hepatic sinusoid...
No NaN values found in adata_copy.X for cell type: endothelial cell of hepatic sinusoid
Training and embedding for cell type: liver dendritic cell...
No NaN values found in adata_copy.X for cell type: liver dendritic cell
Training and embedding for cell type: nk cell...
No NaN values found in adata_copy.X for cell type: nk cell
Encoding completed.


In [16]:
scvi_latents['macrophage'].shape

(50663, 50)

In [17]:
scvi_pseudotimes['macrophage'].shape

(50663,)

### generate random protein sequences for each scvi latent (3 pos, 10 neg)

In [18]:
import random

# valid AAs
valid_amino_acids = "ACDEFGHIKLMNPQRSTVWY"

def generate_random_protein_sequence(min_length=100, max_length=200):
    length = random.randint(min_length, max_length)
    return ''.join(random.choices(valid_amino_acids, k=length))

protein_sequences = {}

for cell_type in scvi_latents.keys():
    protein_sequences[cell_type] = [generate_random_protein_sequence() for _ in range(3)]

cell_type_example = 'macrophage'
print(f"prot seqs for {cell_type_example}:")
for seq in protein_sequences[cell_type_example]:
    print(seq)



prot seqs for macrophage:
RGAFKLTNWGNESMPYEAQCVLFLYEHYYTKTRIRSFNRRLLEFIWVWYARMENQLINDYVPMVIRRTIDPPFLRSEMYFFWVNDQACHQNKGYFRSGMWMEGAKMKIC
TDCSSDLDAKMEYTKDGEHPENLTDMKHAIKDRGPKVRCYWSCCFWMFESGRLAYQQTMFGNGCFINTAWMHPFGGYCYDQWFAHLAQARWWRMLPGYTAGIKFMMVQFDPWYMIMTLDNAIVQHYCHAHPSPVWVSIWISNHQNHTYWFNVNLYCGFAADNEREKSDLPMKKFFANIH
GNHWAKHKHSIHNDVDDYKTDPWECHSFQPMGSHLLMRSWREKTEAMIWTTYISHTNCEPWCQWKSSFSFGWDRRDAFFGWTNYRICITCLAEMRVQQGSVCAKEENDTNSCMKYCVKTRGHC


### datasets

In [51]:
from torch.nn.utils.rnn import pad_sequence

In [52]:
encoder = ProtT5EncodingModule()

latent_list = []
pseudotime_list = []
sequence_list = []

for cell_type, latents in scvi_latents.items():
    pseudotimes = scvi_pseudotimes[cell_type]
    for latent, pseudotime, sequence in zip(latents, pseudotimes, protein_sequences[cell_type]):
        print(f"Cell type: {cell_type}, Latent shape: {latents.shape}, Sequence length: {len(sequence)}, Pseudotime shape: {pseudotimes.shape}")
        latent_list.append(latents)
        pseudotime_list.append(pseudotimes)
        sequence_list.append(sequence)

# seqs encoding protT5
sequence_tensor_list = []
for sequence in sequence_list:
    sequence_tensor = encoder(sequence)
    sequence_tensor_list.append(sequence_tensor)


Cell type: macrophage, Latent shape: (50663, 50), Sequence length: 109, Pseudotime shape: (50663,)
Cell type: macrophage, Latent shape: (50663, 50), Sequence length: 179, Pseudotime shape: (50663,)
Cell type: macrophage, Latent shape: (50663, 50), Sequence length: 123, Pseudotime shape: (50663,)
Cell type: monocyte, Latent shape: (27973, 50), Sequence length: 143, Pseudotime shape: (27973,)
Cell type: monocyte, Latent shape: (27973, 50), Sequence length: 173, Pseudotime shape: (27973,)
Cell type: monocyte, Latent shape: (27973, 50), Sequence length: 153, Pseudotime shape: (27973,)
Cell type: endothelial cell of hepatic sinusoid, Latent shape: (15880, 50), Sequence length: 187, Pseudotime shape: (15880,)
Cell type: endothelial cell of hepatic sinusoid, Latent shape: (15880, 50), Sequence length: 131, Pseudotime shape: (15880,)
Cell type: endothelial cell of hepatic sinusoid, Latent shape: (15880, 50), Sequence length: 159, Pseudotime shape: (15880,)
Cell type: liver dendritic cell, Late

In [53]:
sequence_tensor_list[0].shape

torch.Size([1, 110, 1024])

In [54]:

# pad
max_len = max(sequence.shape[1] for sequence in sequence_tensor_list)
padded_sequence_tensor_list = [
    torch.cat([sequence, torch.zeros(1, max_len - sequence.shape[1], sequence.shape[2], dtype=sequence.dtype)], dim=1)
    if sequence.shape[1] < max_len else sequence for sequence in sequence_tensor_list
]
sequence_tensor = torch.cat(padded_sequence_tensor_list, dim=0)


In [55]:
sequence_tensor.shape

torch.Size([15, 188, 1024])

In [56]:
latent_list[0].shape

(50663, 50)

In [57]:
latent_tensor_list = [torch.tensor(latent, dtype=torch.float32) for latent in latent_list]

max_len = max(latent.shape[0] for latent in latent_tensor_list)

padded_latent_list = [
    torch.cat([latent, torch.zeros((max_len - latent.shape[0], latent.shape[1]), dtype=latent.dtype)], dim=0)
    if latent.shape[0] < max_len else latent for latent in latent_tensor_list
]

latent_tensor = torch.stack(padded_latent_list, dim=0)


In [58]:
latent_tensor_list[0].shape

torch.Size([50663, 50])

In [59]:
latent_tensor.shape

torch.Size([15, 50663, 50])

In [60]:
### pseudotime as a context:
adata_list[0].obs['dpt_pseudotime'].shape
unique_counts = adata_list[0].obs['cell_ontology_class'].value_counts()
print(unique_counts)

cell_ontology_class
macrophage               35204
mesenchymal stem cell    15459
Name: count, dtype: int64


In [61]:
pseudotime_tensor_list = [torch.tensor(pseudotime, dtype=torch.float32) for pseudotime in pseudotime_list]

# ensure pseudotime tensors have shape (seqlen, 1) before padding
pseudotime_tensor_list = [
    pseudotime if len(pseudotime.shape) == 2 else pseudotime.unsqueeze(1) for pseudotime in pseudotime_tensor_list
]

  pseudotime_tensor_list = [torch.tensor(pseudotime, dtype=torch.float32) for pseudotime in pseudotime_list]


In [62]:
print(pseudotime_tensor_list[0].shape)

torch.Size([50663, 1])


In [63]:
max_pseudotime_len = max(pseudotime.shape[0] for pseudotime in pseudotime_tensor_list)

padded_pseudotime_tensor_list = [
    torch.cat([pseudotime, torch.zeros((max_pseudotime_len - pseudotime.shape[0], pseudotime.shape[1]), dtype=pseudotime.dtype)], dim=0)
    if pseudotime.shape[0] < max_pseudotime_len else pseudotime for pseudotime in pseudotime_tensor_list
]

pseudotime_tensor = torch.stack(padded_pseudotime_tensor_list, dim=0)

In [64]:
pseudotime_tensor.shape

torch.Size([15, 50663, 1])

In [65]:
# seq dim = 1024, latent dim = 50, pseudotime dim = 1

In [66]:
from torch.utils.data import DataLoader, TensorDataset, random_split

full_dataset = TensorDataset(latent_tensor, pseudotime_tensor, sequence_tensor)

# split size calculation
total_size = len(full_dataset)
train_size = int(0.7 * total_size)
val_size = total_size - train_size

# dataset split
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# dataloader creation
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [67]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7975e3ab7100>

In [68]:
for batch in train_dataloader:
    scvi, dpt, seq = batch
    print(scvi.shape)
    print(seq.shape)
    print(dpt.shape)

torch.Size([1, 50663, 50])
torch.Size([1, 188, 1024])
torch.Size([1, 50663, 1])
torch.Size([1, 50663, 50])
torch.Size([1, 188, 1024])
torch.Size([1, 50663, 1])
torch.Size([1, 50663, 50])
torch.Size([1, 188, 1024])
torch.Size([1, 50663, 1])
torch.Size([1, 50663, 50])
torch.Size([1, 188, 1024])
torch.Size([1, 50663, 1])
torch.Size([1, 50663, 50])
torch.Size([1, 188, 1024])
torch.Size([1, 50663, 1])
torch.Size([1, 50663, 50])
torch.Size([1, 188, 1024])
torch.Size([1, 50663, 1])
torch.Size([1, 50663, 50])
torch.Size([1, 188, 1024])
torch.Size([1, 50663, 1])
torch.Size([1, 50663, 50])
torch.Size([1, 188, 1024])
torch.Size([1, 50663, 1])
torch.Size([1, 50663, 50])
torch.Size([1, 188, 1024])
torch.Size([1, 50663, 1])
torch.Size([1, 50663, 50])
torch.Size([1, 188, 1024])
torch.Size([1, 50663, 1])


### models

In [69]:
class ProteinFlowMatching(nn.Module):
    def __init__(self, flow_matching, decoder):
        super().__init__()
        self.flow_matching = flow_matching
        self.decoder = decoder

    def forward(self, x, context):
        # During training, we only need to compute the loss from flow matching
        return self.flow_matching(x, context)

    def generate(self, x, context, num_steps=200):
        # Generate latent representation using flow matching
        latent = self.flow_matching.euler_sample(context, x.shape, guidance_scale=3.0)[0]

        # Decode the latent representation to protein sequence
        protein_sequence = self.decoder(latent, max_length=num_steps)

        return protein_sequence

In [72]:
# Initialize models
unet = CustomUNet1D(
    in_channels=50,  # Dimension of scVI latents
    out_channels=1024,  # Dimension of protein embeddings
    model_channels=64,
    num_res_blocks=2,
    attention_resolutions=(1,),
    dropout=0.1,
    channel_mult=(1, 2, 4, 8),
    use_spatial_transformer=True,
    transformer_depth=1,
    context_dim=1,  # Dimension of pseudotime
)
flow_matching = FlowMatchingTrainer(unet, sample_N=25)
decoder = ProtT5DecodingModule()  # vocab_size is the number of amino acids + special tokens
model = ProteinFlowMatching(flow_matching, decoder)
optimizer = Adam(model.parameters(), lr=1e-4)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

model


ProteinFlowMatching(
  (flow_matching): FlowMatchingTrainer(
    (model): CustomUNet1D(
      (time_embed): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): SiLU()
        (2): Linear(in_features=256, out_features=256, bias=True)
      )
      (input_blocks): ModuleList(
        (0): Conv1d(50, 64, kernel_size=(3,), stride=(1,), padding=(1,))
        (1-2): 2 x Sequential(
          (0): ResBlock1D(
            (in_layers): Sequential(
              (0): GroupNorm(32, 64, eps=1e-05, affine=True)
              (1): SiLU()
              (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
            )
            (emb_layers): Sequential(
              (0): SiLU()
              (1): Linear(in_features=256, out_features=64, bias=True)
            )
            (out_layers): Sequential(
              (0): GroupNorm(32, 64, eps=1e-05, affine=True)
              (1): SiLU()
              (2): Dropout(p=0.1, inplace=False)
              (3

### train

In [71]:

num_epochs = 1

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for scvi_latent, pseudotime, protein_seq in train_dataloader:
        scvi_latent, pseudotime, protein_seq = scvi_latent.to(device), pseudotime.to(device), protein_seq.to(device)

        optimizer.zero_grad()

        loss = flow_matching(scvi_latent, pseudotime)  # Using pseudotime as context
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for scvi_latent, pseudotime, protein_seq in val_dataloader:
            scvi_latent, pseudotime, protein_seq = scvi_latent.to(device), pseudotime.to(device), protein_seq.to(device)
            loss = flow_matching(scvi_latent, pseudotime)
            val_loss += loss.item()

    avg_train_loss = train_loss / len(train_dataloader)
    avg_val_loss = val_loss / len(val_dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

# Generate new protein sequences
model.eval()
with torch.no_grad():
    scvi_latent = torch.randn(2000, 50).to(device)  # Random scVI latent
    pseudotime = torch.rand(2000, 1).to(device)  # Random pseudotime
    generated_sequence = model.generate(scvi_latent, pseudotime)
    print("Generated sequence:", generated_sequence)

Module type input blocks: <class 'torch.nn.modules.conv.Conv1d'>
Module type input blocks: <class 'torch.nn.modules.container.Sequential'>
submodule type input blocks: <class '__main__.ResBlock1D'>
submodule type input blocks: <class '__main__.SpatialTransformer1D'>


OutOfMemoryError: CUDA out of memory. Tried to allocate 76.50 GiB. GPU 