<a href="https://colab.research.google.com/github/Valasik0/dna-sequence-llm/blob/first-prototype-test/dna_llm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import gzip
import wandb
import yaml
from dataclasses import dataclass, asdict
from typing import Dict, Any
import time
import os

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [51]:
@dataclass
class ExperimentConfig:
    vocab_size: int = 5
    max_len: int = 256
    d_model: int = 128
    n_heads: int = 4
    n_layers: int = 4
    batch_size: int = 256
    epochs: int = 10
    lr: float = 1e-3
    mask_prob: float = 0.15
    mode: str = "fasta"
    n_samples: int = 1000
    max_len_fasta: int = 15000
    experiment_name: str = "dna_transformer_baseline"
    run_name: str = None

def setup_experiment(config: ExperimentConfig):
    if config.run_name is None:
        config.run_name = f"{config.experiment_name}_{int(time.time())}"
    output_dir = f"outputs/{config.run_name}"
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/config.txt", "w") as f:
        f.write(str(asdict(config)))
    return output_dir

In [32]:
class DNATokenizer:
    def __init__(self, method="single", k=3):
        self.method = method
        self.k = k

        if method == "single":
            self.vocab = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'MASK': 4}
        elif method == "kmer":
            self.vocab = self._build_kmer_vocab(k)
        else:
          pass
          #dalsi tokenizery (pozdejsi experimenty)

    def _build_kmer_vocab(self, k):
        bases = ['A', 'C', 'G', 'T']
        kmers = [''.join(p) for p in itertools.product(bases, repeat=k)]
        vocab = {kmer: i for i, kmer in enumerate(kmers)}
        vocab['MASK'] = len(vocab)
        return vocab

    def encode(self, sequence):
        if self.method == "single":
            return [self.vocab.get(x, 0) for x in sequence.upper()]
        elif self.method == "kmer":
            tokens = []
            for i in range(len(sequence) - self.k + 1):
                kmer = sequence[i:i+self.k]
                tokens.append(self.vocab.get(kmer, 0))
            return tokens

In [35]:
def span_mask(x, mask_prob=0.15, span_len_range=(3, 8), mask_idx=4):
    """Span masking pro simulaci strukturálních variant"""
    masked = x.clone()
    labels = x.clone()
    labels.fill_(-100)  # ignore index

    for i in range(x.size(0)):
        pos = 0
        while pos < x.size(1):
            if random.random() < mask_prob:
                span_len = random.randint(*span_len_range)
                end_pos = min(pos + span_len, x.size(1))
                masked[i, pos:end_pos] = mask_idx
                labels[i, pos:end_pos] = x[i, pos:end_pos]
                pos = end_pos
            else:
                pos += 1

    return masked, labels

In [39]:
def mask_input(x, mask_prob=0.15):
    masked = x.clone()
    mask = torch.rand_like(x.float()) < mask_prob
    masked[mask] = 4
    labels = x.clone()
    labels[~mask] = -100
    return masked, labels

In [37]:
def curriculum_mask(x, epoch, total_epochs, base_prob=0.15, max_prob=0.30, mask_idx=4):
    """Postupné zvyšování obtížnosti maskování"""
    progress = epoch / total_epochs
    current_prob = base_prob + (max_prob - base_prob) * progress
    return mask_input(x, mask_prob=current_prob, mask_idx=mask_idx)

In [36]:
def evaluate_model(model, sequences, tokenizer, device, mask_strategy="random"):
    """Základní vyhodnocení modelu"""
    model.eval()
    total_loss = 0
    total_acc = 0
    num_batches = 0

    with torch.no_grad():
        for i in range(0, len(sequences), 64):  # batch_size = 64
            batch_seqs = sequences[i:i+64]
            batch_tokens = [tokenizer.encode(seq) for seq in batch_seqs]
            batch_tokens = [t for t in batch_tokens if len(t) == model.max_len]

            if len(batch_tokens) == 0:
                continue

            x = torch.tensor(batch_tokens, dtype=torch.long).to(device)

            if mask_strategy == "random":
                masked_x, labels = mask_input(x)
            else:
                masked_x, labels = span_mask(x)

            logits = model(masked_x)
            loss = F.cross_entropy(
                logits.view(-1, model.vocab_size),
                labels.view(-1),
                ignore_index=-100
            )

            # Accuracy pouze na maskovaných pozicích
            mask_positions = (labels != -100)
            if mask_positions.sum() > 0:
                pred = logits.argmax(dim=-1)
                acc = (pred == labels)[mask_positions].float().mean()
                total_acc += acc.item()

            total_loss += loss.item()
            num_batches += 1

    return {
        'eval_loss': total_loss / num_batches if num_batches > 0 else float('inf'),
        'eval_accuracy': total_acc / num_batches if num_batches > 0 else 0.0
    }

In [50]:
def train_with_tracking(config: ExperimentConfig):
    output_dir = setup_experiment(config)

    tokenizer = DNATokenizer(method="single")

    sequences = get_sequences(
        mode=config.mode,
        n=config.n_samples,
        L=config.max_len,
        fasta_path=fasta_path if config.mode == 'fasta' else None,
        max_len_fasta=config.max_len_fasta
    )

    split_idx = int(0.8 * len(sequences))
    train_sequences = sequences[:split_idx]
    val_sequences = sequences[split_idx:]

    model = SimpleDNATransformer(
        vocab_size=config.vocab_size,
        max_len=config.max_len,
        d_model=config.d_model,
        n_heads=config.n_heads,
        n_layers=config.n_layers
    ).to(DEVICE)

    optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.01)
    train_batches = prepare_batches(train_sequences, config.max_len)
    print(f"Tréninková data: {len(train_batches)} batchů")

    train_losses, eval_losses, eval_accuracies = [], [], []

    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0

        indices = random.sample(range(len(train_batches)), min(config.batch_size, len(train_batches)))
        x = torch.tensor([train_batches[i] for i in indices], dtype=torch.long).to(DEVICE)
        masked_x, labels = mask_input(x, mask_prob=config.mask_prob)
        logits = model(masked_x)
        loss = F.cross_entropy(
            logits.view(-1, config.vocab_size),
            labels.view(-1),
            ignore_index=-100
        )
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        epoch_loss = loss.item()
        train_losses.append(epoch_loss)

        # Vyhodnocení každých 5 epoch
        if epoch % 5 == 0:
            eval_metrics = evaluate_model(model, val_sequences, tokenizer, DEVICE)
            eval_losses.append(eval_metrics["eval_loss"])
            eval_accuracies.append(eval_metrics["eval_accuracy"])
            print(f"Epoch {epoch}: train_loss={epoch_loss:.4f}, eval_loss={eval_metrics['eval_loss']:.4f}, eval_acc={eval_metrics['eval_accuracy']:.4f}")
        else:
            print(f"Epoch {epoch}: train_loss={epoch_loss:.4f}")

    print("Ukládám model a statistiky…")
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config,
        'tokenizer_vocab': tokenizer.vocab
    }, f"{output_dir}/model_final.pt")

    # Výsledky tréninku (možno i vizualizovat v Colabu)
    with open(f"{output_dir}/train_losses.txt", "w") as f:
        for l in train_losses:
            f.write(f"{l}\n")
    with open(f"{output_dir}/eval_losses.txt", "w") as f:
        for l in eval_losses:
            f.write(f"{l}\n")
    with open(f"{output_dir}/eval_accuracies.txt", "w") as f:
        for l in eval_accuracies:
            f.write(f"{l}\n")

    print(f"Experiment: {config.run_name} dokončen. Výsledky v {output_dir}.")
    return model, output_dir

In [57]:
class SimpleDNATransformer(nn.Module):
    def __init__(self, vocab_size=4, max_len=256, d_model=128, n_heads=4, n_layers=4):
        super().__init__()
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        seq_len = x.shape[1]
        x = self.embed(x) + self.pos_embed[:, :seq_len]
        x = self.encoder(x)
        logits = self.head(x)
        return logits

In [58]:
configs = [
    ExperimentConfig(experiment_name="baseline_single_token"),
    ExperimentConfig(d_model=256, experiment_name="larger_model"),
    ExperimentConfig(mask_prob=0.25, experiment_name="higher_masking"),
    ExperimentConfig(n_samples=2000, experiment_name="more_data")
]

for config in configs:
    print(f"\n=== Spouštím experiment: {config.experiment_name} ===")
    model, output_dir = train_with_tracking(config)
    print(f"Výsledky uloženy v: {output_dir}")


=== Spouštím experiment: baseline_single_token ===
Tréninková data: 46 batchů
Epoch 0: train_loss=1.5323, eval_loss=1.7982, eval_acc=0.2369
Epoch 1: train_loss=1.8369
Epoch 2: train_loss=1.4034
Epoch 3: train_loss=1.6598
Epoch 4: train_loss=1.5628
Epoch 5: train_loss=1.4241, eval_loss=1.3988, eval_acc=0.2727
Epoch 6: train_loss=1.3944
Epoch 7: train_loss=1.4174
Epoch 8: train_loss=1.4012
Epoch 9: train_loss=1.3695
Ukládám model a statistiky…
Experiment: baseline_single_token_1753269044 dokončen. Výsledky v outputs/baseline_single_token_1753269044.
Výsledky uloženy v: outputs/baseline_single_token_1753269044

=== Spouštím experiment: larger_model ===
Tréninková data: 46 batchů
Epoch 0: train_loss=1.6131, eval_loss=1.9518, eval_acc=0.2774
Epoch 1: train_loss=1.9011
Epoch 2: train_loss=1.9229
Epoch 3: train_loss=1.6985
Epoch 4: train_loss=1.3899
Epoch 5: train_loss=1.5987, eval_loss=1.7249, eval_acc=0.2857
Epoch 6: train_loss=1.6530
Epoch 7: train_loss=1.5712
Epoch 8: train_loss=1.4550
E

In [47]:
fasta_path = "/content/drive/MyDrive/SP/GCF_000001405.26_GRCh38_genomic.fna.gz"
L = 256           #velikost okna (context window)
VOCAB_SIZE = 5    #A, C,G,T,MASK
BATCH_SIZE = 256
EPOCHS = 10
MASK_IDX = 4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [42]:
def gen_random_seq(n=10000, L=256, alphabet='ACGT'):
    for _ in range(n):
        yield ''.join(random.choice(alphabet) for _ in range(L))

In [41]:
def get_sequences(mode='random', n=10000, L=256, fasta_path=None, max_len_fasta=None):
    if mode == 'random':
        return list(gen_random_seq(n, L))
    elif mode == 'fasta':
        assert fasta_path is not None, 'File path not found'
        full_seq = read_fasta(fasta_path, max_length=max_len_fasta if max_len_fasta else L*n)
        return [full_seq[i:i+L] for i in range(0, len(full_seq) - L + 1, L)]
    else:
        raise ValueError("mode must be 'random' or 'fasta'")

In [46]:
def seq_to_tokens(seq):
    mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    seq = seq.upper()
    return [mapping.get(x, 0) for x in seq if x in mapping]


In [43]:
def prepare_batches(sequences, L=256):
    token_batches = []
    for seq in sequences:
        tokens = seq_to_tokens(seq)
        if len(tokens) == L:
            token_batches.append(tokens)
    return token_batches

In [44]:
def read_fasta(filepath, max_length=10000):
    seq = []
    total_len = 0
    valid_bases = set('ACTG')

    with gzip.open(filepath, 'rt') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                continue
            line = line.upper()
            filtered = ''.join([c for c in line if c in valid_bases])  #jen ACTG
            to_take = max_length - total_len
            if to_take <= 0:
                break
            seq.append(filtered[:to_take])
            total_len += len(filtered[:to_take])
            if total_len >= max_length:
                break

    return ''.join(seq)