<a href="https://colab.research.google.com/github/AratrikSarkar/ML/blob/main/compressedUsingZSTD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import subprocess

# ================= CONFIG =================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BLOCK_SIZE = 8      # bases per block; each block is encoded to one latent vector
LATENT_DIM = 8      # length of each block's latent vector
EPOCHS = 50
BATCH_SIZE = 32
LR = 5e-4

# Nucleotide alphabet plus a dedicated padding symbol. Padding must not reuse
# 'N': the loss uses ignore_index=PAD, so aliasing PAD with N would leave every
# real N base untrained, and a batch made only of N-runs would yield a NaN loss.
# The pad symbol is a single character so decoded text keeps one char per token.
DNA_VOCAB = ['A', 'C', 'G', 'T', 'N', '_']
stoi = {c: i for i, c in enumerate(DNA_VOCAB)}
itos = {i: c for c, i in stoi.items()}
PAD = stoi['_']

# ================= DATA =================
def read_fasta(path):
    """Return all sequence lines of a FASTA file, upper-cased and concatenated.

    Header lines (starting with '>') and blank lines are dropped. When the file
    holds several records, their sequences are joined with no separator.
    """
    with open(path) as handle:
        cleaned = (raw.strip().upper() for raw in handle)
        return "".join(text for text in cleaned if text and not text.startswith(">"))

def encode_dna(seq):
    """Convert a nucleotide string into a list of token ids, one per character.

    Lookup is case-insensitive. Characters outside the vocabulary, such as the
    IUPAC ambiguity codes R, Y, K or M that are valid in FASTA files, map to the
    'N' ("any base") token instead of raising KeyError.
    """
    unknown = stoi['N']
    # Upper-case per character (not the whole string) so the output always has
    # exactly one token per input character.
    return [stoi.get(c.upper(), unknown) for c in seq]

def decode_dna(tokens):
    """Map each token id (anything int() accepts) back to its symbol and join them."""
    letters = [itos[int(tok)] for tok in tokens]
    return "".join(letters)

def chunk(tokens, size):
    """Split tokens into consecutive blocks of `size` ids.

    A final short block is right-padded with PAD so every block has exactly
    `size` entries.
    """
    def _padded(piece):
        # Only a final short block needs filling.
        missing = size - len(piece)
        return piece + [PAD] * missing if missing > 0 else piece

    return [_padded(tokens[start:start + size]) for start in range(0, len(tokens), size)]

# ================= MODEL =================
class CNNEncoder(nn.Module):
    """Convolutional encoder mapping a batch of token-id blocks to latent vectors.

    Pipeline:
    - embed tokens into 64 channels;
    - apply three kernel-7 Conv1d+ReLU layers (64 -> 128 -> 256 -> 256);
    - average over the sequence axis;
    - project to LATENT_DIM with a Linear layer.

    The projection has no activation, so the output is not bounded.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(len(DNA_VOCAB), 64)
        stages = []
        for c_in, c_out in ((64, 128), (128, 256), (256, 256)):
            # kernel 7 with padding 3 keeps the sequence length unchanged
            stages.append(nn.Conv1d(c_in, c_out, 7, padding=3))
            stages.append(nn.ReLU())
        self.conv = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(256, LATENT_DIM)

    def forward(self, x):
        # (batch, seq) ids -> (batch, seq, 64) -> channels-first for Conv1d
        features = self.emb(x).transpose(1, 2)
        features = self.conv(features)
        pooled = self.pool(features).squeeze(-1)  # one 256-vector per block
        return self.fc(pooled)

class CNNDecoder(nn.Module):
    """Convolutional decoder turning a latent vector into per-position token logits.

    The latent is projected to 256 channels and copied to every one of `seq_len`
    positions. Because each column is identical, positions are distinguished only
    by the convolutions' zero padding at the sequence edges. Returns logits
    shaped (batch, seq_len, len(DNA_VOCAB)).
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(LATENT_DIM, 256)
        stages = []
        for c_in, c_out in ((256, 256), (256, 256), (256, 128)):
            stages.append(nn.Conv1d(c_in, c_out, 7, padding=3))
            stages.append(nn.ReLU())
        # 1x1 conv acts as a per-position classifier over the vocabulary
        stages.append(nn.Conv1d(128, len(DNA_VOCAB), 1))
        self.conv = nn.Sequential(*stages)

    def forward(self, z, seq_len):
        seed = self.fc(z)
        tiled = seed.unsqueeze(-1).repeat(1, 1, seq_len)  # same vector at every position
        logits = self.conv(tiled)                          # channels-first
        return logits.transpose(1, 2)

class AutoEncoder(nn.Module):
    """CNNEncoder followed by CNNDecoder.

    Produces reconstruction logits with the same sequence length as the input.
    """

    def __init__(self):
        super().__init__()
        self.enc = CNNEncoder()
        self.dec = CNNDecoder()

    def forward(self, x):
        length = x.size(1)
        return self.dec(self.enc(x), length)

# ================= LOSS =================
def loss_fn(logits, targets):
    """Mean token-level cross-entropy, ignoring positions whose target is PAD.

    logits: (..., vocab); targets: ids with the same leading shape.
    """
    vocab = logits.size(-1)
    flat_logits = logits.reshape(-1, vocab)
    flat_targets = targets.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=PAD)

def accuracy_fn(logits, targets):
    """Return the fraction of non-PAD positions where argmax(logits) equals the target.

    The result is a Python float. It is NaN (0/0) when every target is PAD.
    """
    keep = targets.ne(PAD)
    hits = logits.argmax(dim=-1).eq(targets).logical_and(keep)
    return (hits.sum().float() / keep.sum().float()).item()

# ================= PATHS =================
TRAIN_DIR = "data"
TEST_DIR  = "testing"
OUT_DIR   = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)

# ================= LOAD TRAIN DATA =================
# All training files are concatenated into one token stream.
# - Files are visited in sorted order: os.listdir order is arbitrary, and the
#   order fixes where block boundaries fall.
# - Hidden files (e.g. .DS_Store) are skipped because they are not FASTA.
train_tokens = []
for fname in sorted(os.listdir(TRAIN_DIR)):
    p = os.path.join(TRAIN_DIR, fname)
    if fname.startswith(".") or not os.path.isfile(p):
        continue
    train_tokens.extend(encode_dna(read_fasta(p)))

# Fail here with a clear message rather than later inside the training loop.
if not train_tokens:
    raise RuntimeError(f"No sequence data found in {TRAIN_DIR!r}")

train_data = torch.tensor(chunk(train_tokens, BLOCK_SIZE), dtype=torch.long).to(DEVICE)

# ================= TRAIN =================
model = AutoEncoder().to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR)

for e in range(EPOCHS):
    model.train()
    # Shuffle on the data's device so batch indexing needs no host->device copy.
    perm = torch.randperm(train_data.size(0), device=train_data.device)
    loss_sum = acc_sum = steps = 0

    for i in range(0, train_data.size(0), BATCH_SIZE):
        batch = train_data[perm[i:i+BATCH_SIZE]]
        # If every target is PAD, the ignore_index mean is 0/0 = NaN, and
        # backpropagating it would corrupt every weight. Skip such batches.
        if not (batch != PAD).any():
            continue

        out = model(batch)
        loss = loss_fn(out, batch)

        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_sum += loss.item()
        acc_sum += accuracy_fn(out.detach(), batch)
        steps += 1

    if steps == 0:
        print(f"Epoch {e+1}/{EPOCHS} | no trainable batches")
        continue
    print(f"Epoch {e+1}/{EPOCHS} | Loss {loss_sum/steps:.4f} | Acc {acc_sum/steps:.4f}")

# ================= TEST + SAVE FILES =================
# For each test FASTA file:
# 1. encode every BLOCK_SIZE block to a latent vector;
# 2. quantize the latents to int8;
# 3. compress them with the zstd CLI, then decompress;
# 4. decode and report how closely the reconstruction matches the input.


def _quantize(latents):
    """Symmetric per-tensor int8 quantization; returns (int8 tensor, scale).

    The encoder's final Linear layer is unbounded, so a fixed scale of 1/127
    would clamp every component outside [-1, 1]. Scaling by the largest
    magnitude keeps every value representable.
    """
    peak = latents.abs().max().item() if latents.numel() else 0.0
    scale = peak / 127.0 if peak > 0 else 1.0
    q = (latents / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale


model.eval()

for fname in sorted(os.listdir(TEST_DIR)):
    path = os.path.join(TEST_DIR, fname)
    if fname.startswith(".") or not os.path.isfile(path):
        continue

    base = os.path.splitext(fname)[0]
    dna = read_fasta(path)
    if not dna:
        # Nothing to encode; torch.cat([]) and the similarity division would fail.
        print(f"\nTEST FILE: {fname} has no sequence data; skipped")
        continue
    tokens = encode_dna(dna)

    # ---- encoded.txt ----
    with open(os.path.join(OUT_DIR, f"{base}_encoded.txt"), "w") as f:
        f.write(" ".join(map(str, tokens)))

    test_data = torch.tensor(chunk(tokens, BLOCK_SIZE), dtype=torch.long).to(DEVICE)

    # ---- Encode latents in batches (the encoder handles each row independently) ----
    with torch.no_grad():
        latents = torch.cat(
            [model.enc(test_data[i:i+BATCH_SIZE]).cpu()
             for i in range(0, test_data.size(0), BATCH_SIZE)],
            0,
        )

    latents_q, scale = _quantize(latents)

    # ---- latent.pt ----
    # The scale is stored alongside the int8 codes; decoding needs both.
    latent_pt = os.path.join(OUT_DIR, f"{base}_latent.pt")
    torch.save({"latents": latents_q, "scale": scale}, latent_pt)

    # ---- zstd compression ----
    # -f: overwrite output from a previous run (otherwise zstd refuses and
    #     check=True aborts the loop).
    # -q: keep the notebook output readable.
    latent_zst = latent_pt + ".zst"
    subprocess.run(["zstd", "-q", "-f", "-19", latent_pt, "-o", latent_zst], check=True)

    # ---- latent.txt (dequantized values, one block per line) ----
    with open(os.path.join(OUT_DIR, f"{base}_latent.txt"), "w") as f:
        for row in latents_q.float() * scale:
            f.write(",".join(f"{v:.6f}" for v in row.tolist()) + "\n")

    # ---- decompress for decoding (round-trips through the .zst file) ----
    subprocess.run(["zstd", "-q", "-d", "-f", latent_zst, "-o", latent_pt], check=True)
    payload = torch.load(latent_pt)
    loaded = payload["latents"].float() * payload["scale"]

    # ---- reconstruct ----
    recon = []
    with torch.no_grad():
        for i in range(0, loaded.size(0), BATCH_SIZE):
            out = model.dec(loaded[i:i+BATCH_SIZE].to(DEVICE), BLOCK_SIZE)
            recon.extend(torch.argmax(out, -1).cpu().reshape(-1).tolist())

    # Drop predictions for the tail padding added by chunk().
    recon_text = decode_dna(recon[:len(tokens)])

    # ---- reconstructed.txt ----
    with open(os.path.join(OUT_DIR, f"{base}_reconstructed.txt"), "w") as f:
        f.write(recon_text)

    # ---- metrics ----
    similarity = sum(a == b for a, b in zip(dna, recon_text)) / len(dna)

    print(f"\nTEST FILE: {fname}")
    print("Similarity:", similarity)
    print("Original size:", len(dna))
    print("latent.pt size:", os.path.getsize(latent_pt))
    print("latent.zst size:", os.path.getsize(latent_zst))


Epoch 1/50 | Loss 0.2041 | Acc 0.9111
Epoch 2/50 | Loss 0.0250 | Acc 0.9919
Epoch 3/50 | Loss 0.0138 | Acc 0.9958
Epoch 4/50 | Loss 0.0105 | Acc 0.9970
Epoch 5/50 | Loss 0.0075 | Acc 0.9979
Epoch 6/50 | Loss 0.0058 | Acc 0.9984
Epoch 7/50 | Loss 0.0042 | Acc 0.9989
Epoch 8/50 | Loss 0.0042 | Acc 0.9989
Epoch 9/50 | Loss 0.0042 | Acc 0.9988
Epoch 10/50 | Loss 0.0040 | Acc 0.9989
Epoch 11/50 | Loss 0.0035 | Acc 0.9991
Epoch 12/50 | Loss 0.0025 | Acc 0.9994
Epoch 13/50 | Loss 0.0022 | Acc 0.9994
Epoch 14/50 | Loss 0.0042 | Acc 0.9989
Epoch 15/50 | Loss 0.0019 | Acc 0.9996
Epoch 16/50 | Loss 0.0027 | Acc 0.9993
Epoch 17/50 | Loss 0.0021 | Acc 0.9995
Epoch 18/50 | Loss 0.0028 | Acc 0.9993
Epoch 19/50 | Loss 0.0023 | Acc 0.9994
Epoch 20/50 | Loss 0.0024 | Acc 0.9994
Epoch 21/50 | Loss 0.0023 | Acc 0.9995
Epoch 22/50 | Loss 0.0012 | Acc 0.9997
Epoch 23/50 | Loss 0.0011 | Acc 0.9998
Epoch 24/50 | Loss 0.0019 | Acc 0.9996
Epoch 25/50 | Loss 0.0021 | Acc 0.9995
Epoch 26/50 | Loss 0.0016 | Acc 0.

In [1]:
# Install the zstd command-line tool. The compression cell calls it through
# subprocess.run(["zstd", ...]), so run this cell before that one.
!apt-get install -y zstd


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following NEW packages will be installed:
  zstd
0 upgraded, 1 newly installed, 0 to remove and 41 not upgraded.
Need to get 603 kB of archives.
After this operation, 1,695 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 zstd amd64 1.4.8+dfsg-3build1 [603 kB]
Fetched 603 kB in 0s (3,091 kB/s)
Selecting previously unselected package zstd.
(Reading database ... 121689 files and directories currently installed.)
Preparing to unpack .../zstd_1.4.8+dfsg-3build1_amd64.deb ...
Unpacking zstd (1.4.8+dfsg-3build1) ...
Setting up zstd (1.4.8+dfsg-3build1) ...
Processing triggers for man-db (2.10.2-1) ...
