<a href="https://colab.research.google.com/github/AratrikSarkar/ML/blob/main/new.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from IPython.utils.path import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

# ================= CONFIG =================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BLOCK_SIZE = 8   # number of tokens in each block produced by chunk()
LATENT_DIM = 8   # width of the vector the encoder emits per block
EPOCHS = 50
BATCH_SIZE = 32
LR = 5e-4        # learning rate handed to Adam

# Vocabulary: the four bases plus 'N'. The id of 'N' doubles as the padding
# id, so positions holding 'N' are ignored by the loss and accuracy below.
DNA_VOCAB = ['A','C','G','T','N']
stoi = {base: idx for idx, base in enumerate(DNA_VOCAB)}   # symbol -> id
itos = dict(enumerate(DNA_VOCAB))                           # id -> symbol
PAD = stoi['N']

# ================= DATA =================
def read_fasta(path):
    """Read a FASTA-style text file and return its sequence as one string.

    Each line is stripped of surrounding whitespace and upper-cased. Blank
    lines and header lines (those starting with ">") are dropped, and the
    remaining lines are concatenated in file order, so multiple records in
    one file are merged into a single sequence.

    Args:
        path: Location of the file to read.

    Returns:
        The concatenated, upper-cased sequence text ("" if nothing remains).
    """
    with open(path) as handle:
        cleaned = (raw.strip().upper() for raw in handle)
        return "".join(
            row for row in cleaned if row and not row.startswith(">")
        )

def encode_dna(seq):
    """Convert a DNA string into a list of integer token ids.

    Symbols that are not in the vocabulary (for example IUPAC ambiguity
    codes such as R or Y, which occur in real FASTA files) are mapped to
    the id of 'N' instead of aborting the whole run with a KeyError.
    The lookup is case-sensitive; callers are expected to pass upper-case
    text, as read_fasta() produces.

    Args:
        seq: Iterable of single-character symbols, typically a str.

    Returns:
        A list with one integer id per input symbol, in order.
    """
    unknown = stoi['N']
    return [stoi.get(c, unknown) for c in seq]

def decode_dna(tokens):
    """Turn a sequence of token ids back into a DNA string.

    Args:
        tokens: Iterable of integer ids that are keys of ``itos``.

    Returns:
        The symbols for those ids joined into one string.

    Raises:
        KeyError: If an id is not present in ``itos``.
    """
    symbols = [itos[token] for token in tokens]
    return "".join(symbols)

def chunk(tokens, size):
    """Split a token list into consecutive blocks of exactly ``size`` ids.

    Args:
        tokens: List of token ids.
        size: Length of every returned block.

    Returns:
        A list of lists. Only the final block can come up short; it is
        right-padded with PAD to reach ``size``. An empty input gives [].
    """
    pieces = [tokens[start:start + size] for start in range(0, len(tokens), size)]
    # Slicing can only under-fill the last piece, so that is the only one
    # that may need padding.
    if pieces and len(pieces[-1]) < size:
        shortfall = size - len(pieces[-1])
        pieces[-1] = pieces[-1] + [PAD] * shortfall
    return pieces

# ================= MODEL =================
class CNNEncoder(nn.Module):
    """Convolutional encoder: a block of token ids -> one latent vector.

    Pipeline: 64-d embedding, three width-7 "same"-padded 1-D convolutions
    with ReLU (64 -> 128 -> 256 -> 256 channels), average pooling over the
    sequence axis, then a linear projection to LATENT_DIM values.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(len(DNA_VOCAB), 64)

        # Build the conv stack from consecutive channel widths; kernel 7 with
        # padding 3 keeps the sequence length unchanged.
        widths = [64, 128, 256, 256]
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(nn.Conv1d(c_in, c_out, 7, padding=3))
            layers.append(nn.ReLU())
        self.conv = nn.Sequential(*layers)

        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(256, LATENT_DIM)

    def forward(self, x):
        """Encode ``x`` of shape (batch, seq_len) into (batch, LATENT_DIM)."""
        embedded = self.emb(x)                            # (batch, seq_len, 64)
        features = self.conv(embedded.transpose(1, 2))    # (batch, 256, seq_len)
        pooled = self.pool(features).squeeze(-1)          # (batch, 256)
        return self.fc(pooled)

class CNNDecoder(nn.Module):
    """Convolutional decoder: one latent vector -> per-position vocab logits.

    The latent is projected to 256 features, copied to every sequence
    position, passed through three width-7 "same"-padded convolutions with
    ReLU, and finally mapped to len(DNA_VOCAB) logits by a width-1 convolution.

    NOTE(review): every position receives the same input vector, so outputs
    can differ between positions only through the convolutions' zero padding
    at the sequence edges -- confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(LATENT_DIM, 256)

        stack = []
        for c_in, c_out in ((256, 256), (256, 256), (256, 128)):
            stack.append(nn.Conv1d(c_in, c_out, 7, padding=3))
            stack.append(nn.ReLU())
        # Width-1 convolution acts as a per-position classifier head.
        stack.append(nn.Conv1d(128, len(DNA_VOCAB), 1))
        self.conv = nn.Sequential(*stack)

    def forward(self, z, seq_len):
        """Decode ``z`` (batch, LATENT_DIM) into (batch, seq_len, vocab) logits."""
        seed = self.fc(z)                                   # (batch, 256)
        tiled = seed.unsqueeze(-1).repeat(1, 1, seq_len)    # (batch, 256, seq_len)
        logits = self.conv(tiled)                           # (batch, vocab, seq_len)
        return logits.transpose(1, 2)

class AutoEncoder(nn.Module):
    """Encoder/decoder pair exposed as ``enc`` and ``dec``."""

    def __init__(self):
        super().__init__()
        self.enc = CNNEncoder()
        self.dec = CNNDecoder()

    def forward(self, x):
        """Encode ``x`` and decode it back to logits of the same length.

        Args:
            x: Token-id tensor whose second dimension is the sequence length.

        Returns:
            The decoder's output for the encoded input, produced for
            ``x.size(1)`` positions.
        """
        latent = self.enc(x)
        seq_len = x.size(1)
        return self.dec(latent, seq_len)

# ================= LOSS =================
def loss_fn(logits,targets):
    """Cross-entropy between predicted logits and target ids, skipping PAD.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        targets: Integer tensor with one id per logit row.

    Returns:
        Scalar loss tensor; positions whose target equals PAD do not
        contribute.
    """
    vocab_size = logits.size(-1)
    flat_logits = logits.reshape(-1, vocab_size)   # one row per position
    flat_targets = targets.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=PAD)

#======================ACCURACY FUNCTION========
def accuracy_fn(logits, targets):
    """Fraction of non-PAD positions whose arg-max prediction is correct.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        targets: Integer tensor of target ids, one per position.

    Returns:
        A Python float. If no target differs from PAD the result is NaN,
        because the ratio is 0/0 in floating point.
    """
    valid = targets.ne(PAD)
    hits = logits.argmax(dim=-1).eq(targets).logical_and(valid)
    return (hits.sum().float() / valid.sum().float()).item()

# ================= TRAIN FROM data/ =================
TRAIN_DIR = "data"       # input directory with training FASTA files
TEST_DIR = "testing"     # input directory with evaluation FASTA files
OUT_DIR = "outputs"      # every generated artefact is written here
os.makedirs(OUT_DIR, exist_ok=True)

# ---- Load training files ----
# Every regular file directly inside TRAIN_DIR is read as FASTA, tokenised,
# and appended to one long token stream (sub-directories are skipped).
# The listing is sorted because os.listdir() returns entries in arbitrary
# order; since the stream is chunked only after concatenation, file order
# decides where block boundaries fall and must be stable to be reproducible.
train_tokens = []
for fname in sorted(os.listdir(TRAIN_DIR)):
    path = os.path.join(TRAIN_DIR, fname)
    if not os.path.isfile(path):
        continue
    dna = read_fasta(path)
    tokens = encode_dna(dna)
    train_tokens.extend(tokens)

if not train_tokens:
    raise RuntimeError("No training data found in data/")

# Fixed-length blocks (last one padded) -> (num_blocks, BLOCK_SIZE) tensor
# that lives on DEVICE for the whole run.
blocks = chunk(train_tokens, BLOCK_SIZE)
train_data = torch.tensor(blocks, dtype=torch.long).to(DEVICE)

model = AutoEncoder().to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR)

# ===== TRAIN =====
# Each epoch visits every block once in a fresh random order and prints the
# mean of the per-batch loss and accuracy values.
num_blocks = train_data.size(0)
for e in range(EPOCHS):
    model.train()
    perm = torch.randperm(num_blocks)
    total_loss, total_acc, steps = 0, 0, 0

    # split() yields consecutive index slices of at most BATCH_SIZE entries.
    for idx in perm.split(BATCH_SIZE):
        batch = train_data[idx]

        # The autoencoder is trained to reproduce its own input.
        out = model(batch)
        loss = loss_fn(out, batch)
        acc = accuracy_fn(out, batch)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()
        total_acc += acc
        steps += 1

    mean_loss = total_loss / steps
    mean_acc = total_acc / steps
    print(f"Epoch {e+1}/{EPOCHS} | Loss {mean_loss:.4f} | Acc {mean_acc:.4f}")

# ================= TEST FROM testing/ =================
# For each regular file in TEST_DIR: tokenise it, encode every block to an
# int8-quantised latent, decode the de-quantised latents, then report how
# closely the reconstruction matches and how large the saved latents are.
model.eval()
# os.listdir() order is arbitrary; sort so reports come out in a stable order.
for fname in sorted(os.listdir(TEST_DIR)):
    path = os.path.join(TEST_DIR, fname)
    if not os.path.isfile(path):
        continue

    dna = read_fasta(path)
    if not dna:
        # Nothing to encode: torch.cat() on an empty list and the division
        # by len(dna) further down would both raise, so report and move on.
        print(f"\nTEST FILE: {fname}")
        print("Skipped: no sequence data")
        continue

    tokens = encode_dna(dna)
    blocks = chunk(tokens, BLOCK_SIZE)
    test_data = torch.tensor(blocks, dtype=torch.long).to(DEVICE)

    base = os.path.splitext(fname)[0]
    encoded_path = f"{OUT_DIR}/{base}_encoded.txt"
    latent_pt_path = f"{OUT_DIR}/{base}_latent.pt"
    latent_txt_path = f"{OUT_DIR}/{base}_latent.txt"
    recon_path = f"{OUT_DIR}/{base}_reconstructed.txt"

    # ---- Save encoded ----
    with open(encoded_path, "w") as f:
        f.write(" ".join(map(str, tokens)))

    # ---- Save latents ----
    # Encode BATCH_SIZE blocks per forward pass (as the decoder loop below
    # already does) instead of one block at a time.
    with torch.no_grad():
        latents = torch.cat(
            [
                model.enc(test_data[i:i + BATCH_SIZE]).cpu()
                for i in range(0, test_data.size(0), BATCH_SIZE)
            ],
            0,
        )
    # Quantise to int8 with a fixed scale of 127.
    # NOTE(review): the encoder ends in a plain Linear layer, so latents are
    # not bounded; values beyond roughly [-1, 1] saturate at the clamp.
    latents_q = (latents * 127).round().clamp(-128, 127).to(torch.int8)

    # Round-trip through disk so reconstruction uses exactly what was stored.
    torch.save(latents_q, latent_pt_path)
    loaded_latents_q = torch.load(latent_pt_path)
    loaded_latents = loaded_latents_q.float() / 127.0

    with open(latent_txt_path, "w") as f:
        f.writelines(
            ",".join(f"{v:.6f}" for v in row.tolist()) + "\n"
            for row in loaded_latents
        )

    # ---- Reconstruct ----
    recon = []
    with torch.no_grad():
        for i in range(0, loaded_latents.size(0), BATCH_SIZE):
            batch = loaded_latents[i:i + BATCH_SIZE].to(DEVICE)
            out = model.dec(batch, BLOCK_SIZE)
            pred = torch.argmax(out, -1).cpu().tolist()
            for b in pred:
                recon.extend(b)

    # Drop the padding that chunk() added to the final block.
    recon = recon[:len(tokens)]
    recon_text = decode_dna(recon)

    with open(recon_path, "w") as f:
        f.write(recon_text)

    # ---- Metrics ----
    # Similarity = share of positions where the reconstruction equals the input.
    same = sum(a == b for a, b in zip(dna, recon_text))
    sim = same / len(dna)

    # Original size is counted in sequence characters; latent sizes are the
    # byte sizes of the files written above.
    orig_size = len(dna)
    latent_txt_size = os.path.getsize(latent_txt_path)
    latent_pt_size = os.path.getsize(latent_pt_path)

    print(f"\nTEST FILE: {fname}")
    print("Similarity:", sim)
    print("Original:", orig_size)
    print("Latent.txt:", latent_txt_size, "ratio:", orig_size/latent_txt_size)
    print("Latent.pt :", latent_pt_size, "ratio:", orig_size/latent_pt_size)

Epoch 1/50 | Loss 0.2022 | Acc 0.9121
Epoch 2/50 | Loss 0.0198 | Acc 0.9939
Epoch 3/50 | Loss 0.0129 | Acc 0.9962
Epoch 4/50 | Loss 0.0089 | Acc 0.9975
Epoch 5/50 | Loss 0.0065 | Acc 0.9982
Epoch 6/50 | Loss 0.0060 | Acc 0.9984
Epoch 7/50 | Loss 0.0052 | Acc 0.9986
Epoch 8/50 | Loss 0.0037 | Acc 0.9990
Epoch 9/50 | Loss 0.0034 | Acc 0.9991
Epoch 10/50 | Loss 0.0032 | Acc 0.9992
Epoch 11/50 | Loss 0.0034 | Acc 0.9991
Epoch 12/50 | Loss 0.0031 | Acc 0.9992
Epoch 13/50 | Loss 0.0001 | Acc 1.0000
Epoch 14/50 | Loss 0.0017 | Acc 0.9996
Epoch 15/50 | Loss 0.0037 | Acc 0.9990
Epoch 16/50 | Loss 0.0030 | Acc 0.9992
Epoch 17/50 | Loss 0.0025 | Acc 0.9994
Epoch 18/50 | Loss 0.0032 | Acc 0.9992
Epoch 19/50 | Loss 0.0027 | Acc 0.9993
Epoch 20/50 | Loss 0.0022 | Acc 0.9994
Epoch 21/50 | Loss 0.0032 | Acc 0.9992
Epoch 22/50 | Loss 0.0015 | Acc 0.9996
Epoch 23/50 | Loss 0.0026 | Acc 0.9994
Epoch 24/50 | Loss 0.0013 | Acc 0.9997
Epoch 25/50 | Loss 0.0017 | Acc 0.9996
Epoch 26/50 | Loss 0.0012 | Acc 0.