In [19]:
import torch
import torch.nn as nn
import os
import torch.nn.functional as F

# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# DNA alphabet used throughout the notebook: the four bases plus 'N'.
DNA_VOCAB = list("ACGTN")
# Symbol -> integer id, and the inverse id -> symbol table.
stoi = dict(zip(DNA_VOCAB, range(len(DNA_VOCAB))))
itos = dict(enumerate(DNA_VOCAB))
# The id of 'N' is reused as the padding id, so wherever PAD_TOKEN is used as
# an ignore/mask value, genuine 'N' symbols are skipped as well.
PAD_TOKEN = stoi["N"]



In [67]:
#different types of configurations
CONFIG = dict(
    # Device string used when placing the model: CUDA if available, else CPU.
    device="cuda" if torch.cuda.is_available() else "cpu",
    # block_size is the length of each fixed-size token block.
    data=dict(block_size=128),
    training=dict(
        batch_size=32,
        epochs=30,
        learning_rate=3e-4,
    ),
    model=dict(
        vocab_size=len(DNA_VOCAB),
        latent_dim=64,
        # Alternative architectures (swap in for the active entries below):
        #   LSTM:
        #     encoder=dict(type="lstm", emb_dim=64, hidden_dim=128,
        #                  num_heads=4, kernel_sizes=[3, 5, 7]),
        #     decoder=dict(type="lstm", hidden_dim=128, num_heads=4),
        #   CNN v2 encoder with CNN decoder:
        #     encoder=dict(type="cnn_v2", emb_dim=64),
        #     decoder=dict(type="cnn"),
        encoder=dict(
            type="transformer",
            emb_dim=128,
            num_heads=8,
            num_layers=4,
        ),
        decoder=dict(
            type="transformer",
            emb_dim=128,
            num_heads=8,
            num_layers=4,
        ),
    ),
)


In [46]:
#base encoder class for other encoders
class BaseEncoder(nn.Module):
    """Common parent of the encoders; defines the expected call contract."""

    def forward(self, x):
        """Encode a batch of token ids (to be implemented by subclasses).

        Args:
            x: Token ids, documented shape (B, L).

        Returns:
            Latent vectors, documented shape (B, latent_dim).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()


In [47]:
class LSTMEncoder(BaseEncoder):
    """Embedding -> bidirectional LSTM -> self-attention -> mean pool -> linear.

    Args:
        vocab_size: Number of rows in the token embedding table.
        emb_dim: Width of each token embedding.
        hidden_dim: Hidden size of each LSTM direction.
        latent_dim: Width of the returned latent vector.
        num_heads: Number of attention heads over the LSTM outputs.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, latent_dim, num_heads):
        super().__init__()
        # A bidirectional LSTM emits forward and backward states side by side.
        feature_dim = 2 * hidden_dim

        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.fc_latent = nn.Linear(feature_dim, latent_dim)

    def forward(self, x):
        """Encode token ids ``x`` (B, L) into latent vectors (B, latent_dim)."""
        embedded = self.embedding(x)                     # (B, L, E)
        states, _ = self.lstm(embedded)                  # (B, L, 2H)
        # Self-attention: the LSTM states act as query, key and value.
        attended, _ = self.attention(states, states, states)
        # Average over the sequence axis; every position counts equally.
        return self.fc_latent(attended.mean(dim=1))


In [48]:
class CNNEncoder(BaseEncoder):
    """Multi-kernel 1-D CNN encoder with global average pooling.

    One convolution branch is built per kernel size; each branch is pooled
    over the sequence axis and the pooled features are concatenated before
    the final projection to the latent space.

    Args:
        vocab_size: Number of rows in the token embedding table.
        emb_dim: Width of each token embedding (conv input channels).
        latent_dim: Width of the returned latent vector.
        kernel_sizes: Iterable of convolution kernel sizes, one branch each.
            Defaults to ``(3, 5, 7)``.

    Raises:
        ValueError: If ``kernel_sizes`` is empty.
    """

    def __init__(self, vocab_size, emb_dim, latent_dim, kernel_sizes=(3, 5, 7)):
        super().__init__()

        # Immutable default plus a one-time copy: avoids the shared mutable
        # default pitfall and accepts any iterable (list, tuple, generator).
        kernel_sizes = tuple(kernel_sizes)
        if not kernel_sizes:
            # Fail at construction rather than later inside torch.cat().
            raise ValueError("kernel_sizes must contain at least one kernel size")

        channels = 128  # output channels of every convolution branch

        self.embedding = nn.Embedding(vocab_size, emb_dim)

        self.convs = nn.ModuleList([
            nn.Conv1d(
                in_channels=emb_dim,
                out_channels=channels,
                kernel_size=k,
                padding=k // 2,
            )
            for k in kernel_sizes
        ])

        self.fc_latent = nn.Linear(channels * len(kernel_sizes), latent_dim)

    def forward(self, x):
        """Encode token ids ``x`` (B, L) into latent vectors (B, latent_dim)."""
        x = self.embedding(x)       # (B, L, E)
        x = x.transpose(1, 2)       # (B, E, L): Conv1d expects channels first

        # ReLU then global average pooling over the length axis, per branch.
        pooled = [F.relu(conv(x)).mean(dim=2) for conv in self.convs]

        return self.fc_latent(torch.cat(pooled, dim=1))


In [49]:
class CNNEncoderV2(BaseEncoder):
    """Three stacked width-7 convolutions followed by average pooling.

    Args:
        vocab_size: Number of rows in the token embedding table.
        emb_dim: Width of each token embedding (input channels of the stack).
        latent_dim: Width of the returned latent vector.
    """

    def __init__(self, vocab_size, emb_dim, latent_dim):
        super().__init__()
        kernel, pad = 7, 3

        self.embedding = nn.Embedding(vocab_size, emb_dim)

        # Channel progression: emb_dim -> 128 -> 256 -> 256.
        self.conv = nn.Sequential(
            nn.Conv1d(emb_dim, 128, kernel, padding=pad),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel, padding=pad),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel, padding=pad),
            nn.ReLU(),
        )

        # Collapse the length axis to a single averaged position.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(256, latent_dim)

    def forward(self, x):
        """Encode token ids ``x`` into latent vectors of width ``latent_dim``."""
        channels_first = self.embedding(x).transpose(1, 2)
        features = self.conv(channels_first)
        pooled = self.pool(features).squeeze(-1)
        return self.fc(pooled)


In [50]:
class TransformerEncoder(BaseEncoder):
    """Embedding -> Transformer encoder stack -> mean pool -> linear.

    Args:
        vocab_size: Number of rows in the token embedding table.
        emb_dim: Embedding width, also used as the Transformer ``d_model``.
        latent_dim: Width of the returned latent vector.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
    """

    def __init__(self, vocab_size, emb_dim, latent_dim, n_heads, n_layers):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=emb_dim,
                nhead=n_heads,
                batch_first=True,
            ),
            num_layers=n_layers,
        )
        self.fc = nn.Linear(emb_dim, latent_dim)

    def forward(self, x):
        """Encode token ids ``x`` into latent vectors of width ``latent_dim``."""
        # No positional term is added here: only the token embedding is fed in.
        hidden = self.transformer(self.emb(x))
        # Average over the sequence axis before projecting to the latent size.
        return self.fc(hidden.mean(dim=1))


In [51]:
#base decoder class for other decoders
class BaseDecoder(nn.Module):
    """Common parent of the decoders; defines the expected call contract."""

    def forward(self, z, seq_len):
        """Decode latent vectors (to be implemented by subclasses).

        Args:
            z: Latent vectors, documented shape (B, latent_dim).
            seq_len: Length L of the sequence to produce.

        Returns:
            Logits, documented shape (B, L, vocab).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()


In [52]:
class LSTMDecoder(BaseDecoder):
    """Latent + learned position embedding -> LSTM -> self-attention -> logits.

    Args:
        vocab_size: Number of output classes per position.
        latent_dim: Width of the latent vector (and of the position table).
        hidden_dim: Hidden size of the LSTM and width of the attention layer.
        num_heads: Number of attention heads.
        max_len: Number of rows in the position table; positions
            ``0 .. seq_len - 1`` are looked up in it at decode time.
    """

    def __init__(self, vocab_size, latent_dim, hidden_dim, num_heads, max_len):
        super().__init__()

        self.pos_embedding = nn.Embedding(max_len, latent_dim)
        self.lstm = nn.LSTM(
            input_size=latent_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, z, seq_len):
        """Decode latents ``z`` (B, latent_dim) into logits (B, seq_len, vocab)."""
        positions = torch.arange(seq_len, device=z.device)
        # Broadcasting (B, 1, D) + (1, L, D) copies the latent to every
        # position and adds that position's embedding.
        inputs = z.unsqueeze(1) + self.pos_embedding(positions).unsqueeze(0)

        states, _ = self.lstm(inputs)
        # Self-attention: the LSTM states act as query, key and value.
        attended, _ = self.attention(states, states, states)
        return self.fc_out(attended)


In [53]:
class CNNDecoder(BaseDecoder):
    """Linear projection of the latent, tiled over length, then a conv stack.

    Args:
        vocab_size: Number of output classes per position.
        latent_dim: Width of the latent vector.
        seq_len: Stored on the instance as ``self.seq_len``; ``forward`` uses
            its own ``seq_len`` argument rather than this attribute.
    """

    def __init__(self, vocab_size, latent_dim, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.fc = nn.Linear(latent_dim, 256)

        # Channel progression: 256 -> 256 -> 128 -> vocab_size.
        wide = dict(kernel_size=7, padding=3)
        self.conv = nn.Sequential(
            nn.Conv1d(256, 256, **wide),
            nn.ReLU(),
            nn.Conv1d(256, 128, **wide),
            nn.ReLU(),
            nn.Conv1d(128, vocab_size, kernel_size=1),
        )

    def forward(self, z, seq_len):
        """Decode latents ``z`` (B, latent_dim) into logits (B, seq_len, vocab)."""
        projected = self.fc(z)                                   # (B, 256)
        # Every position starts from the same projected latent vector.
        tiled = projected.unsqueeze(-1).repeat(1, 1, seq_len)    # (B, 256, L)
        logits = self.conv(tiled)                                # (B, V, L)
        return logits.transpose(1, 2)


In [54]:
class TransformerDecoder(BaseDecoder):
    """Projected latent + learned positions -> Transformer decoder -> logits.

    Args:
        vocab_size: Number of output classes per position.
        latent_dim: Width of the incoming latent vector.
        emb_dim: Transformer ``d_model`` and width of the position table.
        n_heads: Attention heads per decoder layer.
        n_layers: Number of stacked decoder layers.
        max_len: Number of rows in the position table; positions
            ``0 .. seq_len - 1`` are looked up in it at decode time.
    """

    def __init__(self, vocab_size, latent_dim, emb_dim, n_heads, n_layers, max_len):
        super().__init__()
        self.pos = nn.Embedding(max_len, emb_dim)
        self.fc = nn.Linear(latent_dim, emb_dim)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=emb_dim,
                nhead=n_heads,
                batch_first=True,
            ),
            num_layers=n_layers,
        )
        self.out = nn.Linear(emb_dim, vocab_size)

    def forward(self, z, seq_len):
        """Decode latents ``z`` (B, latent_dim) into logits (B, seq_len, vocab)."""
        # Copy the projected latent to every position ...
        tiled = self.fc(z).unsqueeze(1).repeat(1, seq_len, 1)
        # ... and add the learned embedding of each position index.
        positions = torch.arange(seq_len, device=z.device)
        queries = tiled + self.pos(positions).unsqueeze(0)
        # The same tensor serves as both target and memory; no masks are given.
        decoded = self.decoder(queries, queries)
        return self.out(decoded)


In [55]:
#auto encoder class
class DNAAutoEncoder(nn.Module):
    """Pairs an encoder with a decoder and a fixed output length.

    Args:
        encoder: Module called as ``encoder(x)`` to obtain the latent.
        decoder: Module called as ``decoder(z, seq_len)`` to obtain logits.
        seq_len: Length passed to the decoder on every forward call.
    """

    def __init__(self, encoder, decoder, seq_len):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.seq_len = seq_len

    def forward(self, x):
        """Encode ``x`` and decode the latent back to ``self.seq_len`` steps."""
        latent = self.encoder(x)
        logits = self.decoder(latent, self.seq_len)
        return logits


In [56]:
#choosing encoder
def build_encoder(cfg):
    """Instantiate the encoder selected by ``cfg["model"]["encoder"]["type"]``.

    Args:
        cfg: Configuration mapping with ``model.encoder``, ``model.vocab_size``
            and ``model.latent_dim`` entries. Only the keys required by the
            selected encoder type are read from ``model.encoder``.

    Returns:
        An ``LSTMEncoder``, ``CNNEncoder``, ``CNNEncoderV2`` or
        ``TransformerEncoder`` for the types "lstm", "cnn", "cnn_v2" and
        "transformer" respectively.

    Raises:
        ValueError: If the encoder type is not one of the four above.
    """
    spec = cfg["model"]["encoder"]
    vocab_size = cfg["model"]["vocab_size"]
    latent_dim = cfg["model"]["latent_dim"]
    kind = spec["type"]

    if kind == "lstm":
        return LSTMEncoder(
            vocab_size=vocab_size,
            emb_dim=spec["emb_dim"],
            hidden_dim=spec["hidden_dim"],
            latent_dim=latent_dim,
            num_heads=spec["num_heads"],
        )

    if kind == "cnn":
        return CNNEncoder(
            vocab_size=vocab_size,
            emb_dim=spec["emb_dim"],
            latent_dim=latent_dim,
            kernel_sizes=spec["kernel_sizes"],
        )

    if kind == "cnn_v2":
        return CNNEncoderV2(
            vocab_size=vocab_size,
            emb_dim=spec["emb_dim"],
            latent_dim=latent_dim,
        )

    if kind == "transformer":
        return TransformerEncoder(
            vocab_size=vocab_size,
            emb_dim=spec["emb_dim"],
            latent_dim=latent_dim,
            n_heads=spec["num_heads"],
            n_layers=spec["num_layers"],
        )

    raise ValueError(f"Unknown encoder type: {kind}")



In [57]:
#choosing decoder
def build_decoder(cfg):
    """Instantiate the decoder selected by ``cfg["model"]["decoder"]["type"]``.

    Args:
        cfg: Configuration mapping with ``model.decoder``, ``model.vocab_size``,
            ``model.latent_dim`` and ``data.block_size`` entries. The block
            size is passed on as the decoder's maximum / target length.

    Returns:
        An ``LSTMDecoder``, ``CNNDecoder`` or ``TransformerDecoder`` for the
        types "lstm", "cnn" and "transformer" respectively.

    Raises:
        ValueError: If the decoder type is not one of the three above.
    """
    spec = cfg["model"]["decoder"]
    vocab_size = cfg["model"]["vocab_size"]
    latent_dim = cfg["model"]["latent_dim"]
    block_size = cfg["data"]["block_size"]
    kind = spec["type"]

    if kind == "lstm":
        return LSTMDecoder(
            vocab_size=vocab_size,
            latent_dim=latent_dim,
            hidden_dim=spec["hidden_dim"],
            num_heads=spec["num_heads"],
            max_len=block_size,
        )

    if kind == "cnn":
        return CNNDecoder(
            vocab_size=vocab_size,
            latent_dim=latent_dim,
            seq_len=block_size,
        )

    if kind == "transformer":
        return TransformerDecoder(
            vocab_size=vocab_size,
            latent_dim=latent_dim,
            emb_dim=spec["emb_dim"],
            n_heads=spec["num_heads"],
            n_layers=spec["num_layers"],
            max_len=block_size,
        )

    raise ValueError(f"Unknown decoder type: {kind}")



In [58]:
#auto encoder building
def build_model(cfg):
    """Assemble the full autoencoder described by ``cfg``.

    The encoder is built first, then the decoder; ``cfg["data"]["block_size"]``
    becomes the sequence length the autoencoder asks its decoder to produce.

    Args:
        cfg: Configuration mapping accepted by ``build_encoder`` and
            ``build_decoder``.

    Returns:
        A ``DNAAutoEncoder`` wrapping the two freshly built modules.
    """
    return DNAAutoEncoder(
        encoder=build_encoder(cfg),
        decoder=build_decoder(cfg),
        seq_len=cfg["data"]["block_size"],
    )


In [68]:
#reading fasta file
def read_fasta(path):
    """Read a FASTA-style text file and return its sequence as one string.

    Blank lines and lines starting with '>' are skipped; every other line is
    stripped, upper-cased and concatenated in file order. The result is
    checked with ``validate_dna`` before being returned.

    Args:
        path: Location of the file to read.

    Returns:
        The concatenated, upper-cased sequence.

    Raises:
        ValueError: Propagated from ``validate_dna`` for disallowed symbols.
    """
    with open(path, 'r') as handle:
        pieces = [
            stripped.upper()
            for stripped in map(str.strip, handle)
            if stripped and not stripped.startswith('>')
        ]
    dna = "".join(pieces)
    validate_dna(dna)
    return dna

#validation
def validate_dna(seq, allowed=None):
    """Ensure every symbol of ``seq`` belongs to the allowed alphabet.

    Args:
        seq: Iterable of symbols (typically a string).
        allowed: Set of permitted symbols; defaults to A, C, G, T and N.

    Raises:
        ValueError: If ``seq`` contains symbols outside ``allowed``; the
            message lists the offending symbols.
    """
    alphabet = {'A', 'C', 'G', 'T', 'N'} if allowed is None else allowed
    unexpected = set(seq) - alphabet
    if not unexpected:
        return
    raise ValueError(f"Invalid DNA symbols found: {unexpected}")

#characters => integer tokens
def encode_dna(seq):
    """Translate each symbol of ``seq`` into its integer id via ``stoi``.

    Raises:
        KeyError: If a symbol is not a key of ``stoi``.
    """
    token_ids = []
    for symbol in seq:
        token_ids.append(stoi[symbol])
    return token_ids

#reverse
def decode_dna(tokens):
    """Translate integer ids back into a symbol string via ``itos``.

    Raises:
        KeyError: If an id is not a key of ``itos``.
    """
    symbols = [itos[token] for token in tokens]
    return ''.join(symbols)

#Split long DNA into fixed-length blocks
def chunk_sequence(tokens, block_size, pad_token=PAD_TOKEN):
    """Cut ``tokens`` into consecutive blocks of exactly ``block_size`` items.

    Args:
        tokens: List of token ids.
        block_size: Length of every returned block.
        pad_token: Id appended to the final block when it comes up short.

    Returns:
        A list of blocks; only the last one can contain padding.
    """
    def _fit(block):
        # Top up a short (final) block; full blocks pass through untouched.
        shortfall = block_size - len(block)
        return block + [pad_token] * shortfall if shortfall > 0 else block

    return [
        _fit(tokens[start:start + block_size])
        for start in range(0, len(tokens), block_size)
    ]

#Convert Python list -> PyTorch tensor
def prepare_dataset(blocks):
    """Pack a list of equal-length token blocks into one int64 tensor.

    Args:
        blocks: Nested list of integer token ids, one inner list per block.

    Returns:
        A ``torch.long`` tensor holding the same values.
    """
    dataset = torch.tensor(blocks, dtype=torch.long)
    return dataset

#Measure how well the model reconstructs DNA
def reconstruction_loss(logits, targets):
    """Cross-entropy between predicted logits and the target token ids.

    Args:
        logits: Tensor of shape (B, L, V).
        targets: Token ids reshaped to B * L entries to match the logits.

    Returns:
        The scalar loss from ``F.cross_entropy``; positions whose target is
        ``PAD_TOKEN`` are excluded through ``ignore_index``.
    """
    batch, length, vocab = logits.shape
    n_positions = batch * length

    # cross_entropy wants one row of class scores per position.
    return F.cross_entropy(
        logits.reshape(n_positions, vocab),
        targets.reshape(n_positions),
        ignore_index=PAD_TOKEN,
    )

#Compute token-level reconstruction accuracy
def reconstruction_accuracy(logits, targets):
    """Fraction of non-PAD positions whose arg-max prediction is correct.

    Args:
        logits: Tensor whose last axis holds per-class scores.
        targets: Token ids with the same leading shape as the predictions.

    Returns:
        The accuracy as a Python float. If every target equals ``PAD_TOKEN``
        the division is 0 / 0 and the result is NaN.
    """
    predicted = logits.argmax(dim=-1)      # (B, L)

    valid = targets != PAD_TOKEN           # positions that count
    hits = valid & (predicted == targets)

    return (hits.sum().float() / valid.sum().float()).item()

#Compute base-by-base similarity between sequences
def sequence_similarity(original, reconstructed):
    """Position-wise match rate of ``reconstructed`` against ``original``.

    Symbols are compared pairwise up to the shorter of the two sequences and
    the match count is divided by ``len(original)``, so positions missing
    from a shorter reconstruction count as mismatches.

    Args:
        original: Reference sequence.
        reconstructed: Sequence to score against the reference.

    Returns:
        A float in [0, 1]. For an empty ``original`` the ratio is undefined;
        1.0 is returned when ``reconstructed`` is empty too (nothing to get
        wrong) and 0.0 otherwise, instead of raising ZeroDivisionError.
    """
    if len(original) == 0:
        return 1.0 if len(reconstructed) == 0 else 0.0

    matches = sum(o == r for o, r in zip(original, reconstructed))
    return matches / len(original)

#train
def train_autoencoder(model, dataset, epochs, batch_size, lr):
    """Train ``model`` to reconstruct the blocks in ``dataset``.

    The model and the whole dataset are moved to the module-level ``device``.
    Each epoch draws a fresh random permutation of the samples, runs one Adam
    step per mini-batch and prints the mean loss and accuracy of the epoch.

    Args:
        model: Module mapping a batch of token ids to logits.
        dataset: Tensor of token blocks; each batch is its own target.
        epochs: Number of passes over the dataset.
        batch_size: Maximum number of samples per optimisation step.
        lr: Learning rate for Adam.
    """
    model.to(device)
    dataset = dataset.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    num_samples = dataset.size(0)

    for epoch_no in range(1, epochs + 1):
        model.train()
        # New shuffling order for every epoch.
        order = torch.randperm(num_samples)
        loss_sum, acc_sum, n_batches = 0.0, 0.0, 0

        for start in range(0, num_samples, batch_size):
            batch = dataset[order[start:start + batch_size]]

            # Autoencoding: the input batch doubles as the target.
            logits = model(batch)
            loss = reconstruction_loss(logits, batch)
            batch_acc = reconstruction_accuracy(logits, batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            acc_sum += batch_acc
            n_batches += 1

        mean_loss = loss_sum / n_batches
        mean_acc = acc_sum / n_batches
        print(
            f"Epoch [{epoch_no}/{epochs}] | "
            f"Loss: {mean_loss:.4f} | "
            f"Train Acc: {mean_acc:.4f}"
        )





In [60]:
#saving latent in a text file
def save_latents_readable(model, dataset, path="latents.txt"):
    """Encode every sample and write its latent vector as one CSV text line.

    Args:
        model: Autoencoder exposing an ``encoder`` submodule.
        dataset: Tensor of token blocks, encoded one sample at a time.
        path: Output text file; one line per sample, values with 6 decimals.

    Returns:
        A list with one list of floats per sample.
    """
    model.eval()
    latents = []

    with torch.no_grad(), open(path, "w") as out_file:
        for row in range(dataset.size(0)):
            # Slice instead of index so the batch axis is kept for the encoder.
            sample = dataset[row:row + 1].to(device)
            vector = model.encoder(sample).squeeze(0).cpu().tolist()
            latents.append(vector)

            line = ",".join(format(value, ".6f") for value in vector)
            out_file.write(line + "\n")

    print(f"Latents saved to {path}")
    return latents

#saving latent in a .pt file
def save_latents_binary(model, dataset, path="latents.pt"):
    """Encode every sample and store all latents in one ``torch.save`` file.

    Args:
        model: Autoencoder exposing an ``encoder`` submodule.
        dataset: Tensor of token blocks, encoded one sample at a time.
        path: Destination file for the stacked latent tensor.

    Returns:
        The CPU tensor of latents, one row per sample.
    """
    model.eval()

    with torch.no_grad():
        rows = [
            model.encoder(dataset[row:row + 1].to(device)).cpu()
            for row in range(dataset.size(0))
        ]

    stacked = torch.cat(rows, dim=0)  # (N, latent_dim)
    torch.save(stacked, path)
    return stacked

# def compare_sizes(raw_dna, latent_txt_path):
#     original_bytes = len(raw_dna)  # 1 char = 1 byte
#     compressed_bytes = os.path.getsize(latent_txt_path)

#     ratio = original_bytes / compressed_bytes

#     print(f"Original size   : {original_bytes} bytes")
#     print(f"Compressed size : {compressed_bytes} bytes")
#     print(f"Compression ratio: {ratio:.4f}")

    # return original_bytes, compressed_bytes, ratio

#comparison of decoded sequence and original sequence
def reconstruct_and_compare_safe(
    model,
    dataset,
    raw_dna,
    batch_size=16
):
    """Reconstruct the dataset with ``model`` and compare it to ``raw_dna``.

    Predictions are taken as the arg-max over the class axis, flattened in
    block order, cut to ``len(raw_dna)`` and decoded back to symbols. The
    first 200 symbols of both sequences and the similarity are printed.

    Args:
        model: Autoencoder returning logits of shape (B, L, V).
        dataset: Tensor of token blocks.
        raw_dna: Original sequence the blocks were built from.
        batch_size: Number of blocks passed through the model at once.

    Returns:
        Tuple ``(reconstructed_dna, similarity)``.
    """
    model.eval()
    flat_tokens = []

    with torch.no_grad():
        for start in range(0, dataset.size(0), batch_size):
            chunk = dataset[start:start + batch_size].to(device)
            predicted = model(chunk).argmax(dim=-1)   # (B, L)
            flat_tokens.extend(
                token for block in predicted.cpu().tolist() for token in block
            )

    # Drop whatever lies beyond the original length (padding of the last block).
    reconstructed_dna = decode_dna(flat_tokens[:len(raw_dna)])
    similarity = sequence_similarity(raw_dna, reconstructed_dna)

    print("\n===== RECONSTRUCTION RESULT =====")
    print("Original (first 200):")
    print(raw_dna[:200])

    print("\nReconstructed (first 200):")
    print(reconstructed_dna[:200])

    print(f"\nSequence similarity: {similarity:.4f}")

    return reconstructed_dna, similarity

#compression ratio for text file
def _report_compression(heading, raw_dna, latent_file):
    """Print and return the size comparison shared by both public wrappers.

    Args:
        heading: Label shown inside the printed banner.
        raw_dna: Original sequence; its length in characters is used as the
            original size (this equals bytes only when each symbol is stored
            as a single byte).
        latent_file: Path of the stored latents; its on-disk size in bytes is
            used as the compressed size.

    Returns:
        Tuple ``(original_size, compressed_size, ratio)`` where ``ratio`` is
        ``original_size / compressed_size``.

    Raises:
        OSError: If ``latent_file`` cannot be inspected.
        ZeroDivisionError: If ``latent_file`` is empty.
    """
    original_size = len(raw_dna)
    compressed_size = os.path.getsize(latent_file)

    ratio = original_size / compressed_size

    print(f"\n===== COMPRESSION ({heading}) =====")
    print(f"Original size   : {original_size} bytes")
    print(f"Compressed size : {compressed_size} bytes")
    print(f"Compression ratio: {ratio:.4f}")

    return original_size, compressed_size, ratio


def compression_ratio_text(raw_dna, latent_txt_file):
    """Compression ratio of ``raw_dna`` against the readable text latents.

    Returns:
        Tuple ``(original_size, compressed_size, ratio)``; a ratio below 1
        means the latent file is larger than the original sequence.
    """
    return _report_compression("TEXT LATENTS", raw_dna, latent_txt_file)


#compression ratio for .pt file
def compression_ratio_binary(raw_dna, latent_pt_file):
    """Compression ratio of ``raw_dna`` against the binary tensor latents.

    Returns:
        Tuple ``(original_size, compressed_size, ratio)``; a ratio below 1
        means the latent file is larger than the original sequence.
    """
    return _report_compression("BINARY LATENTS", raw_dna, latent_pt_file)



In [69]:
if __name__ == "__main__":

    # Data: read the sequence file, keep only its first 50,000 symbols, map
    # them to integer ids and cut them into blocks of CONFIG's block_size.
    # NOTE(review): the file name contains a space before ".txt" - confirm it
    # matches the file on disk.
    dna_seq = read_fasta("AeCa .txt")
    dna_seq = dna_seq[:50000]
    tokens = encode_dna(dna_seq)
    blocks = chunk_sequence(tokens, CONFIG["data"]["block_size"])
    dataset = prepare_dataset(blocks)

    # Model: build the encoder/decoder pair selected in CONFIG and move it to
    # the configured device.
    model = build_model(CONFIG).to(CONFIG["device"])

    # Train: hyper-parameters come from CONFIG["training"]. Note that
    # train_autoencoder itself places model and data on the module-level
    # `device`, which is computed with the same expression as CONFIG["device"].
    train_autoencoder(
        model=model,
        dataset=dataset,
        epochs=CONFIG["training"]["epochs"],
        batch_size=CONFIG["training"]["batch_size"],
        lr=CONFIG["training"]["learning_rate"]
    )


Epoch [1/30] | Loss: 1.4795 | Train Acc: 0.2618
Epoch [2/30] | Loss: 1.3940 | Train Acc: 0.2633
Epoch [3/30] | Loss: 1.3905 | Train Acc: 0.2609
Epoch [4/30] | Loss: 1.3882 | Train Acc: 0.2628
Epoch [5/30] | Loss: 1.3862 | Train Acc: 0.2681
Epoch [6/30] | Loss: 1.3813 | Train Acc: 0.2784
Epoch [7/30] | Loss: 1.3795 | Train Acc: 0.2795
Epoch [8/30] | Loss: 1.3774 | Train Acc: 0.2876
Epoch [9/30] | Loss: 1.3775 | Train Acc: 0.2879
Epoch [10/30] | Loss: 1.3779 | Train Acc: 0.2950
Epoch [11/30] | Loss: 1.3744 | Train Acc: 0.2997
Epoch [12/30] | Loss: 1.3679 | Train Acc: 0.3167
Epoch [13/30] | Loss: 1.3656 | Train Acc: 0.3202
Epoch [14/30] | Loss: 1.3651 | Train Acc: 0.3182
Epoch [15/30] | Loss: 1.3614 | Train Acc: 0.3213
Epoch [16/30] | Loss: 1.3597 | Train Acc: 0.3222
Epoch [17/30] | Loss: 1.3579 | Train Acc: 0.3270
Epoch [18/30] | Loss: 1.3583 | Train Acc: 0.3235
Epoch [19/30] | Loss: 1.3558 | Train Acc: 0.3275
Epoch [20/30] | Loss: 1.3565 | Train Acc: 0.3273
Epoch [21/30] | Loss: 1.3557 

In [70]:
# ===== LATENT ANALYSIS =====
# Uses `model`, `dataset` and `dna_seq` created by the training cell.

# Encode every block and write the latent vectors to a readable text file.
latents = save_latents_readable(
    model=model,
    dataset=dataset,
    path="latents.txt"
)

# compare_sizes(dna_seq, "latents.txt")  # helper is commented out in this notebook

# Decode the whole dataset again and measure base-by-base agreement with the
# original sequence.
reconstructed_dna, similarity = reconstruct_and_compare_safe(
    model=model,
    dataset=dataset,
    raw_dna=dna_seq,
    batch_size=16
)
# ===== SAVE LATENTS =====
# Same latents stored as one tensor file; the returned tensor is not kept.
save_latents_binary(model, dataset, "latents.pt")

# ===== COMPRESSION =====
# Original length vs. on-disk size of each latent file. The last call is the
# final expression of the cell, so its returned tuple is shown as cell output.
compression_ratio_text(dna_seq, "latents.txt")
compression_ratio_binary(dna_seq, "latents.pt")



Latents saved to latents.txt

===== RECONSTRUCTION RESULT =====
Original (first 200):
GCCGCCCCCATGGTCCATACGGTGTGCGAATACGGCGTGGCCCTCCTTACCCCATCCAGGCCTCTCTACGCCCCACTTGTCTATAGTGCCTTTCACGACCCTGGCCACAAGGTCGATCGCGTACTCCCAGGAGACAGGCTCTAGCCTTCCTCCGAACCTTACGAGGGGCTTTGTCAGCCTCCTCTCGCCTGCTATGTTCC

Reconstructed (first 200):
CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC

Sequence similarity: 0.3352

===== COMPRESSION (TEXT LATENTS) =====
Original size   : 50000 bytes
Compressed size : 236640 bytes
Compression ratio: 0.2113

===== COMPRESSION (BINARY LATENTS) =====
Original size   : 50000 bytes
Compressed size : 101673 bytes
Compression ratio: 0.4918


(50000, 101673, 0.49177264367137785)