In [1]:
import polars as pl
import numpy as np
# Load the raw table (one DNA sequence plus its species label per row) via
# polars' support for Hugging Face `hf://` dataset paths.
df = pl.read_csv(
    'hf://datasets/Zehui127127/latent-dna-diffusion/sequence.csv'
)
df

Sequence,species
str,str
"""TATACAAATTTATTAAATTGCAATACATAG…","""Apis mellifera (Honey bee)."""
"""AACACAATAATAGTATGATCAAAAACAAGT…","""Apis mellifera (Honey bee)."""
"""TATATAATAAATTCTAATTCAGTTGAAATA…","""Apis mellifera (Honey bee)."""
"""TCATAATCATATACATATATATTCTATTTT…","""Apis mellifera (Honey bee)."""
"""GTAATTGAGATAACTGATACATTTCACTTT…","""Apis mellifera (Honey bee)."""
…,…
"""TTCTTCTAACACTTATTATAAACTAAATTG…","""Zea mays (corn)."""
"""GTCATCCAGCAAGCGGCGCGCCTGCTGGGT…","""Zea mays (corn)."""
"""AGGAATGCATGTCTTTTGCTGACACAGTTG…","""Zea mays (corn)."""
"""GTAAGCTTAGTTTTTGTGATCAGCTTCCAG…","""Zea mays (corn)."""


In [None]:
# Materialise both columns as NumPy unicode arrays. `.astype(str)` yields a
# fixed-width dtype sized to the longest entry (e.g. "<U2048").
data = df["Sequence"].to_numpy().astype(str)
labels = df["species"].to_numpy().astype(str)

print("data: ", data.shape, data.dtype)
print(data, end="\n\n")  # trailing blank line separates the two summaries
print("labels: ", labels.shape, labels.dtype)
print(labels)

data:  (159123,) <U2048
['TATACAAATTTATTAAATTGCAATACATAGCTTTTGTGAACATTATTATCATTATTTAATATTCTATATTCAATATATAATTTCATAAAATTTTATATTTTATTGTAAAATAAATTATTATATATTTAGTAATGTTAAAAATAATAAAATTTAATACTATATTTCTTATAACATAAATTAAGAATATATCCATAAAAGACTATCAAAATTATTTTAAATATTATCAAAATTTAAAAATCTTAAAATATAAAAAAATTATCCATTAATATTAAAATATTATGTTATTTAAAATACAATACAAAATTAATATAAGTAAAAAATAAACTTATAATTACAAAATTTATTATTACATTATTTATGTATTACATACATTACATATATTTATATATATTATATACATATAATACATTTTATAAAAATATATTGTAATTTTTGTTGCTTAATGTTTAATTAGTGTCTTTTATTATAAAAACAAAAAATAAATATATTAAAAACAAAATTGCGAGAAATATGATTAATATTTATAATCATATTTATAATTATATAAATATTAATTTAACATTTTTTTTTTTGATTAATAATAAATTATTTCTTTCATGTAATCAAATTTACTATTACATTTATTACTAGCTTGTATTTTAAATCGATTGTATATAATCGAATAGTTTAATTTTGTAATTTAATTTAAAAATTAATAAATTAATTTAAACTTTTTCATATCTTTTACAATTAAAAATACAATTTTTTAGATGAAAGTTAATTTTTATGAAATGAAATAATATATCACTATAAAAATTAAATATTATTAAAATAATTATTTTATAATTATATAATAAATATAACTTTTTAATTTAAATAATATTAACTTATTTTATTATTTTATAAACAATATTTTTTATGAAAACATGTCACAAGTATAATGAAAATAGATTTAATATGGACTATGAAATGTACAACTAGTTGTACATCATATA

In [3]:
def one_hot_encode(data, dtype=np.float64):
    """One-hot encode DNA strings over the alphabet A, C, G, T plus "other".

    Args:
        data: NumPy array of strings (typically 1-D fixed-width unicode such as
            ``<U2048``). Every sequence is laid out over the full dtype width
            ``L``; NumPy NUL-pads sequences shorter than that width. Non-unicode
            input is converted with ``np.asarray(...).astype(str)``.
        dtype: dtype of the returned array. Defaults to ``np.float64`` (the
            previous behaviour); e.g. ``np.float32`` halves memory use.

    Returns:
        Array of shape ``(n_sequences, L, 5)``. Channels 0-3 are A, C, G, T;
        channel 4 catches every other character (NUL padding, ``N``, lowercase
        bases, ...).
    """
    data = np.asarray(data)
    if data.dtype.kind != "U":
        data = data.astype(str)
    # Width of the fixed-size dtype, i.e. the longest sequence. The previous
    # len(data[0]) mis-sized the reshape whenever the first sequence was
    # shorter than that, and raised IndexError on empty input.
    width = data.dtype.itemsize // np.dtype("U1").itemsize
    if data.size == 0:
        return np.zeros((0, width, 5), dtype=dtype)
    # NumPy stores each character as one UTF-32 code unit. Normalising to
    # little-endian and viewing as uint32 yields real code points on any host,
    # instead of inspecting a single byte (which assumed a little-endian machine
    # and could confuse non-ASCII characters with bases).
    codes = (
        np.ascontiguousarray(data, dtype=f"<U{width}")
        .view("<u4")
        .reshape(-1, width)
    )
    masks = [codes == ord(base) for base in "ACGT"]
    nums = np.select(masks, [0, 1, 2, 3], default=4)
    return np.eye(5, dtype=dtype)[nums]

def one_hot_decode(one_hot):
    """Map one-hot (or soft) vectors back to nucleotide byte characters.

    Args:
        one_hot: array whose last axis has 5 entries in the order A, C, G, T,
            other (as produced by ``one_hot_encode``), or scores/probabilities
            over those 5 classes.

    Returns:
        Array of ``S1`` bytes with the input's shape minus the last axis; class
        4 decodes to ``b"N"``. Each position takes its highest-scoring class
        (ties and all-zero rows resolve to the first class, ``b"A"``).
    """
    alphabet = np.array([b"A", b"C", b"G", b"T", b"N"])
    # argmax matches the old "first nonzero entry" rule on exact one-hot and
    # all-zero rows, but also decodes soft rows correctly, where every nonzero
    # score used to count as hot and the lowest index always won.
    return alphabet[np.asarray(one_hot).argmax(axis=-1)]

def check_one_hot_encode(data, one_hot, only_first_n_entries=None):
    """Check that ``one_hot`` decodes back to the strings in ``data``.

    Args:
        data: sequence of the original strings, row-aligned with ``one_hot``.
        one_hot: one-hot array as produced by ``one_hot_encode``.
        only_first_n_entries: if given, compare only this many leading rows.

    Returns:
        True if every compared row matches. False otherwise, including when the
        two inputs hold a different number of rows to compare.

    A decoded row may be longer than its source string. Positions past the end
    of a shorter sequence hold fixed-width padding, which the encoder maps to
    the catch-all class and which therefore decodes to ``"N"``. Such trailing
    ``"N"`` characters are accepted.
    """
    if only_first_n_entries is not None:
        # Slice before decoding: decoding the whole array just to inspect a
        # prefix is needlessly slow and memory hungry.
        n = max(only_first_n_entries, 0)
        data, one_hot = data[:n], one_hot[:n]
    # zip() would silently ignore rows missing from one side.
    if len(data) != len(one_hot):
        return False
    for row, recon in zip(data, one_hot_decode(one_hot)):
        recon = recon.tobytes().decode("ascii")
        if row.ljust(len(recon), "N") != recon:
            return False
    return True


one_hot = one_hot_encode(data)
print("one_hot: ", one_hot.shape, one_hot.dtype)
# Only the first 1000 rows are round-tripped to keep this cell fast.
is_correct = check_one_hot_encode(data, one_hot, only_first_n_entries=1000)
print("is correct ", is_correct)
# Fail fast: a broken encoding would otherwise silently feed the training cells.
assert is_correct, "one-hot encoding does not round-trip for the first 1000 sequences"
print(one_hot)

one_hot:  (159123, 2048, 5) float64
is correct  True
[[[0. 0. 0. 1. 0.]
  [1. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0.]
  ...
  [0. 0. 0. 0. 1.]
  [0. 0. 0. 0. 1.]
  [0. 0. 0. 0. 1.]]

 [[1. 0. 0. 0. 0.]
  [1. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  ...
  [1. 0. 0. 0. 0.]
  [1. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0.]]

 [[0. 0. 0. 1. 0.]
  [1. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0.]
  ...
  [1. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0.]
  [0. 0. 0. 1. 0.]]

 ...

 [[1. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0.]
  [0. 0. 1. 0. 0.]
  ...
  [0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]
  [0. 0. 0. 1. 0.]]

 [[0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]
  [1. 0. 0. 0. 0.]
  ...
  [1. 0. 0. 0. 0.]
  [1. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0.]]

 [[0. 0. 0. 1. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 1. 0. 0.]
  ...
  [0. 0. 0. 1. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 0. 1. 0.]]]


In [4]:
import torch
from torch import nn
from torch.nn import functional as F
import math
import treescope as ts


class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` that consumes and produces channel-last tensors.

    Inputs shaped ``(batch, length, channels)`` are swapped to the channel-first
    layout expected by the stock convolution, convolved, and swapped back, so
    the result is ``(batch, length_out, out_channels)``.
    """

    def forward(self, x):
        channel_first = x.transpose(-2, -1)
        convolved = super().forward(channel_first)
        return convolved.transpose(-2, -1)


class ResidualBlock(nn.Module):
    """Pre-norm residual block over channel-last sequences.

    Computes ``x + f(x)``, where ``f`` is
    ``LayerNorm -> activation -> Conv1d -> LayerNorm -> activation -> Conv1d``.
    Both convolutions keep ``dim`` channels and use ``padding="same"``, so the
    output has the same shape as the input.

    Args:
        dim: number of channels (size of the last axis).
        kernel_size: kernel size of both convolutions.
        activation: activation module. ``None`` (the default) creates a fresh
            ``nn.SiLU()`` for this block. The given module object is placed at
            both activation sites, so an activation with parameters would share
            them between the two sites.
    """

    def __init__(self, dim, kernel_size=5, activation=None):
        super().__init__()
        if activation is None:
            # Built per call: a module instance used as a default argument is
            # evaluated once and then silently shared by every block.
            activation = nn.SiLU()
        self.residual = nn.Sequential(
            nn.LayerNorm(dim),
            activation,
            Conv1d(dim, dim, kernel_size, padding="same"),
            nn.LayerNorm(dim),
            activation,
            Conv1d(dim, dim, kernel_size, padding="same"),
        )

    def forward(self, x):
        return x + self.residual(x)


class DownSample(nn.Module):
    """Shorten a channel-last sequence by folding neighbours into channels.

    Every run of ``factor`` consecutive positions is concatenated along the
    channel axis and projected by a linear layer:
    ``(..., L, in_dim) -> (..., L // factor, out_dim)``.
    ``L`` must be divisible by ``factor``, otherwise the reshape fails.
    """

    def __init__(self, in_dim, out_dim, factor=2):
        super().__init__()
        self.kernel = nn.Linear(factor * in_dim, out_dim)
        self.factor = factor

    def forward(self, x):
        *batch_dims, length, channels = x.shape
        folded = x.reshape(*batch_dims, length // self.factor, channels * self.factor)
        return self.kernel(folded)


class UpSample(nn.Module):
    """Lengthen a channel-last sequence by unfolding channels into positions.

    Each position's channels are split into ``factor`` groups of
    ``in_dim // factor`` that become consecutive positions, then projected by a
    linear layer: ``(..., L, in_dim) -> (..., L * factor, out_dim)``.
    ``in_dim`` is expected to be divisible by ``factor``.
    """

    def __init__(self, in_dim, out_dim, factor=2):
        super().__init__()
        self.kernel = nn.Linear(in_dim // factor, out_dim)
        self.factor = factor

    def forward(self, x):
        *batch_dims, length, channels = x.shape
        unfolded = x.reshape(*batch_dims, self.factor * length, channels // self.factor)
        return self.kernel(unfolded)


class SequenceBVAE(nn.Module):
    """Convolutional beta-VAE for channel-last one-hot sequences.

    Shapes (``L`` must be divisible by 16 because of two 4x down-samplings):

    * ``encoder``: ``(batch, L, in_dim) -> (batch, L // 16, 64)``, which
      ``encode`` splits into ``mu`` and ``sigma`` of 32 channels each.
    * ``decoder``: ``(batch, L // 16, 32) -> (batch, L, in_dim)``.

    Args:
        in_dim: number of input channels (alphabet size).
        out_dim: not used by this class; the latent width is fixed at 32
            channels. NOTE(review): kept only so existing calls such as
            ``SequenceBVAE(5, 128, ...)`` keep working.
        kernel_size: kernel size of every convolution.
        activation: activation module handed to every residual block. ``None``
            (the default) creates one fresh ``nn.SiLU()`` for this model instead
            of reusing a module object created at definition time.
        beta: weight applied to the KL term returned by ``losses``.
    """

    def __init__(self, in_dim, out_dim, kernel_size=5, activation=None, beta=1.0):
        super().__init__()
        if activation is None:
            activation = nn.SiLU()
        self.beta = beta

        def stage(dim):
            # Three independent residual blocks of constant width.
            return [ResidualBlock(dim, kernel_size, activation) for _ in range(3)]

        # The layer order matches the original flat listing, so state_dict keys
        # and parameter-initialisation order are unchanged.
        self.encoder = nn.Sequential(
            Conv1d(in_dim, 8, kernel_size, padding="same"),
            *stage(8),
            DownSample(8, 16, factor=4),
            *stage(16),
            DownSample(16, 32, factor=4),
            *stage(32),
            Conv1d(32, 64, kernel_size, padding="same"),
        )
        self.decoder = nn.Sequential(
            *stage(32),
            UpSample(32, 16, factor=4),
            *stage(16),
            UpSample(16, 8, factor=4),
            *stage(8),
            Conv1d(8, in_dim, kernel_size, padding="same"),
        )

    def encode(self, x):
        """Return ``(mu, sigma)``, the two halves of the encoder output.

        ``sigma`` is the raw network output and is not constrained to be
        positive; ``losses`` only uses it through ``sigma**2`` and as the scale
        of symmetric Gaussian noise, so its sign does not matter there.
        """
        mu, sigma = torch.chunk(self.encoder(x), 2, dim=-1)
        return mu, sigma

    def decode(self, z):
        """Map a 32-channel latent sequence back to ``in_dim`` channels."""
        return self.decoder(z)

    def losses(self, x):
        """Compute the two training losses for a batch ``x``.

        Samples ``z = mu + sigma * eps`` (reparameterisation trick) and decodes
        it.

        Returns:
            ``(recon, beta * kl)`` where ``recon`` is the mean squared error
            between reconstruction and ``x`` over all elements, and ``kl`` is
            the closed-form KL(N(mu, sigma^2) || N(0, 1)) averaged over all
            latent elements.
        """
        mu, sigma = self.encode(x)
        z = mu + sigma * torch.randn_like(mu)
        x_recon = self.decode(z)
        # 1e-8 keeps the log finite when sigma collapses to 0.
        kl = 0.5 * (sigma**2 + mu**2 - (1e-8 + sigma**2).log() - 1).mean()
        recon = (x_recon - x).pow(2).mean()
        return recon, self.beta * kl

In [None]:
from torch.utils.data import DataLoader
from tqdm.auto import tqdm  # widget in Jupyter, plain text elsewhere

# Seeded generator so the shuffle order is reproducible across kernel restarts.
# Each item is one row of `one_hot`; default collation stacks them into
# (batch, length, 5) tensors that keep the array's dtype.
dataloader = DataLoader(
    one_hot,
    batch_size=1024,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

In [9]:
# Seed parameter initialisation and the reparameterisation noise so a
# Restart & Run All reproduces the same training run.
torch.manual_seed(0)
model = SequenceBVAE(5, 128, beta=0.001)  # in_dim=5: A, C, G, T, other
ts.display(model)

# Prefer the GPU this notebook was developed on, but fall back instead of
# failing when cuda:2 does not exist on the current machine.
if torch.cuda.device_count() > 2:
    device = torch.device("cuda:2")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = model.to(device=device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
for epoch in range(50):
    for el in (pbar := tqdm(dataloader)):
        # Batches keep the dtype of `one_hot`; the model runs in float32.
        el = el.to(device=device, dtype=torch.float32)

        loss_recon, loss_kl = model.losses(el)

        optimizer.zero_grad()
        (loss_recon + loss_kl).backward()
        optimizer.step()

        pbar.set_description(f"Loss recon: {loss_recon.item():.6f}, Loss kl: {loss_kl.item():.6f}")

# NOTE: this pickles the whole module (the class must be importable when
# loading); the loading cell below reads this format.
torch.save(model, "model.pth")

Loss recon: 0.143699, Loss kl: 0.002363: 100%|██████████| 156/156 [01:03<00:00,  2.46it/s]
Loss recon: 0.128132, Loss kl: 0.002509: 100%|██████████| 156/156 [01:03<00:00,  2.45it/s]
Loss recon: 0.116920, Loss kl: 0.002910: 100%|██████████| 156/156 [01:03<00:00,  2.44it/s]
Loss recon: 0.103977, Loss kl: 0.003158: 100%|██████████| 156/156 [01:03<00:00,  2.44it/s]
Loss recon: 0.087798, Loss kl: 0.003516: 100%|██████████| 156/156 [01:03<00:00,  2.45it/s]
Loss recon: 0.072358, Loss kl: 0.003839: 100%|██████████| 156/156 [01:03<00:00,  2.45it/s]
Loss recon: 0.060633, Loss kl: 0.004235: 100%|██████████| 156/156 [01:03<00:00,  2.45it/s]
Loss recon: 0.050286, Loss kl: 0.004386: 100%|██████████| 156/156 [01:03<00:00,  2.45it/s]
Loss recon: 0.045342, Loss kl: 0.004116: 100%|██████████| 156/156 [01:03<00:00,  2.45it/s]
Loss recon: 0.042459, Loss kl: 0.004131: 100%|██████████| 156/156 [01:03<00:00,  2.46it/s]
Loss recon: 0.026149, Loss kl: 0.004598: 100%|██████████| 156/156 [01:03<00:00,  2.45it/s]

In [6]:
# The checkpoint is a full pickled nn.Module, not a state_dict, so it needs
# weights_only=False: newer PyTorch defaults to weights_only=True and refuses to
# unpickle arbitrary classes. Only do this for trusted files, because
# unpickling can execute code. map_location="cpu" lets the file load on
# machines without the original GPU; the next cell moves the model to its device.
model = torch.load("model.pth", map_location="cpu", weights_only=False)
ts.display(model)

  model = torch.load("model.pth")


In [10]:
from itertools import islice

if torch.cuda.device_count() > 2:
    device = torch.device("cuda:2")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device=device)
model.eval()

n_preview = min(10, len(dataloader))
# islice pulls only the batches that are shown; list(dataloader)[:10] built
# every batch of the dataset just to keep ten of them.
with torch.no_grad():  # inference only: no autograd graph, far less memory
    for el in tqdm(islice(dataloader, n_preview), total=n_preview):
        el = el.to(device=device, dtype=torch.float32)
        # Deterministic reconstruction: decode mu directly, without sampling.
        encoded, _ = model.encode(el)
        decoded = model.decode(encoded)
        # Snap each position to its highest-scoring class (5-way one-hot).
        quantized = F.one_hot(decoded.argmax(dim=-1), num_classes=5).to(decoded.dtype)

        ts.display(
            {"original": el[0].T, "decoded": decoded[0].T, "quantized": quantized[0].T},
            autovisualize=True,
        )

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:00<00:03,  2.61it/s]

 20%|██        | 2/10 [00:00<00:02,  3.71it/s]

 30%|███       | 3/10 [00:00<00:01,  4.30it/s]

 40%|████      | 4/10 [00:00<00:01,  4.66it/s]

 50%|█████     | 5/10 [00:01<00:01,  4.88it/s]

 60%|██████    | 6/10 [00:01<00:00,  5.02it/s]

 70%|███████   | 7/10 [00:01<00:00,  5.12it/s]

 80%|████████  | 8/10 [00:01<00:00,  5.18it/s]

 90%|█████████ | 9/10 [00:01<00:00,  5.21it/s]

100%|██████████| 10/10 [00:02<00:00,  4.80it/s]
