<a href="https://colab.research.google.com/github/Valasik0/dna-sequence-llm/blob/first-prototype-test/dna_llm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import gzip

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
class SimpleDNATransformer(nn.Module):
    def __init__(self, vocab_size=4,
                 max_len=256, #context window
                 d_model=128, #embedding dim
                 n_heads=4,
                 n_layers=4): #skryte vrstvy
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        seq_len = x.shape[1]
        x = self.embed(x) + self.pos_embed[:, :seq_len]
        x = self.encoder(x)
        logits = self.head(x)
        return logits

In [4]:
model = SimpleDNATransformer()
x = torch.randint(0, 4, (32, 256))  # batch 32 sekvencí, delka 256 nt
logits = model(x)
logits.shape


torch.Size([32, 256, 4])

In [None]:
x[0]

tensor([0, 0, 2, 3, 3, 0, 3, 2, 2, 1, 2, 3, 1, 0, 0, 2, 0, 2, 0, 1, 3, 1, 0, 2,
        3, 2, 1, 2, 0, 1, 1, 3, 3, 2, 0, 3, 3, 3, 1, 3, 2, 2, 1, 3, 0, 0, 3, 2,
        1, 3, 3, 0, 2, 2, 2, 1, 1, 2, 0, 2, 0, 2, 1, 3, 1, 2, 2, 2, 2, 1, 0, 0,
        1, 3, 0, 1, 1, 1, 2, 2, 2, 2, 1, 0, 3, 3, 3, 1, 3, 2, 0, 3, 0, 2, 3, 3,
        2, 0, 1, 0, 1, 1, 3, 0, 0, 0, 1, 3, 2, 2, 1, 3, 2, 3, 0, 0, 1, 2, 3, 1,
        0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 2, 1, 1, 3, 1, 2, 2, 2, 0, 3, 1, 2, 0, 0,
        3, 0, 0, 2, 0, 2, 0, 2, 2, 1, 0, 3, 0, 2, 3, 0, 2, 0, 2, 1, 3, 1, 3, 1,
        0, 3, 3, 1, 0, 2, 3, 3, 2, 2, 3, 3, 3, 1, 2, 1, 2, 0, 0, 1, 1, 0, 0, 3,
        0, 0, 3, 0, 1, 2, 0, 1, 2, 2, 1, 1, 3, 3, 1, 2, 0, 2, 3, 3, 1, 2, 0, 0,
        3, 0, 1, 0, 2, 3, 0, 0, 0, 0, 2, 1, 1, 2, 2, 2, 0, 0, 1, 1, 3, 1, 0, 2,
        0, 0, 3, 3, 3, 2, 2, 0, 1, 1, 3, 3, 0, 0, 2, 2])

In [5]:
total_params = sum(p.numel() for p in model.parameters())
total_params

2405892

In [4]:
fasta_path = "/content/drive/MyDrive/SP/GCF_000001405.26_GRCh38_genomic.fna.gz"
L = 256           # velikost okna (context window)
VOCAB_SIZE = 5    # A,C,G,T,MASK
BATCH_SIZE = 256
EPOCHS = 10
MASK_IDX = 4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
def gen_random_seq(n=10000, L=256, alphabet='ACGT'):
    for _ in range(n):
        yield ''.join(random.choice(alphabet) for _ in range(L))

In [6]:
def read_fasta(filepath, max_length=10000):
    seq = []
    total_len = 0
    with gzip.open(filepath, 'rt') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                continue
            to_take = max_length - total_len
            if to_take <= 0:
                break
            seq.append(line[:to_take].upper())
            total_len += len(line[:to_take])
    return ''.join(seq)

In [7]:
def get_sequences(mode='random', n=10000, L=256, fasta_path=None, max_len_fasta=None):
    if mode == 'random':
        return list(gen_random_seq(n, L))
    elif mode == 'fasta':
        assert fasta_path is not None, 'File path not found'
        full_seq = read_fasta(fasta_path, max_length=max_len_fasta if max_len_fasta else L*n)
        return [full_seq[i:i+L] for i in range(0, len(full_seq) - L + 1, L)]
    else:
        raise ValueError("mode must be 'random' or 'fasta'")

In [8]:
def seq_to_tokens(seq):
    mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    seq = seq.upper()
    return [mapping.get(x, 0) for x in seq if x in mapping]

In [9]:
def prepare_batches(sequences, L=256):
    token_batches = []
    for seq in sequences:
        tokens = seq_to_tokens(seq)
        if len(tokens) == L:
            token_batches.append(tokens)
    return token_batches

In [10]:
def mask_input(x, mask_prob=0.15):
    masked = x.clone()
    mask = torch.rand_like(x.float()) < mask_prob
    masked[mask] = 4
    labels = x.clone()
    labels[~mask] = -100
    return masked, labels

In [14]:
fasta_path = "/content/drive/MyDrive/SP/GCF_000001405.26_GRCh38_genomic.fna.gz"
mode = "random"  # "random" / "radnom"
n_samples = 100   # kolik sekvencí
max_len_fasta = 10000   # kolik bází načíst z FASTA (ignoruje zbytek souboru)

sequences = get_sequences(
    mode=mode,
    n=n_samples,
    L=L,
    fasta_path=fasta_path if mode == 'fasta' else None,
    max_len_fasta=max_len_fasta if mode == 'fasta' else None
)

token_batches = prepare_batches(sequences, L)
print(f"Batchů připraveno: {len(token_batches)}")

Batchů připraveno: 100


In [12]:
sequences[0]

'TGCGGTGTGATTCAGGAGGGTGACGTTGGAAGTGGACAATGCATACGGGGCCTTCATGCCCTTTCCCAACTTGGCCTTATTGTAGAGGAATATTAACGCCGCCTCTGGTAGATAAGAGAGAAGGCGCCTTGGCCCGTAACACCTGTCGGTATCCCGTCAACGAATACGTGTCTTGTGCAGGTCATCCTTGCCTGAGTTATTTGGCCGCGATGTTACTAATGTGGTCCCCTGTCGTGCCGAGTGAACAACTCTGGTA'

In [13]:
token_batches[0]

[3,
 2,
 1,
 2,
 2,
 3,
 2,
 3,
 2,
 0,
 3,
 3,
 1,
 0,
 2,
 2,
 0,
 2,
 2,
 2,
 3,
 2,
 0,
 1,
 2,
 3,
 3,
 2,
 2,
 0,
 0,
 2,
 3,
 2,
 2,
 0,
 1,
 0,
 0,
 3,
 2,
 1,
 0,
 3,
 0,
 1,
 2,
 2,
 2,
 2,
 1,
 1,
 3,
 3,
 1,
 0,
 3,
 2,
 1,
 1,
 1,
 3,
 3,
 3,
 1,
 1,
 1,
 0,
 0,
 1,
 3,
 3,
 2,
 2,
 1,
 1,
 3,
 3,
 0,
 3,
 3,
 2,
 3,
 0,
 2,
 0,
 2,
 2,
 0,
 0,
 3,
 0,
 3,
 3,
 0,
 0,
 1,
 2,
 1,
 1,
 2,
 1,
 1,
 3,
 1,
 3,
 2,
 2,
 3,
 0,
 2,
 0,
 3,
 0,
 0,
 2,
 0,
 2,
 0,
 2,
 0,
 0,
 2,
 2,
 1,
 2,
 1,
 1,
 3,
 3,
 2,
 2,
 1,
 1,
 1,
 2,
 3,
 0,
 0,
 1,
 0,
 1,
 1,
 3,
 2,
 3,
 1,
 2,
 2,
 3,
 0,
 3,
 1,
 1,
 1,
 2,
 3,
 1,
 0,
 0,
 1,
 2,
 0,
 0,
 3,
 0,
 1,
 2,
 3,
 2,
 3,
 1,
 3,
 3,
 2,
 3,
 2,
 1,
 0,
 2,
 2,
 3,
 1,
 0,
 3,
 1,
 1,
 3,
 3,
 2,
 1,
 1,
 3,
 2,
 0,
 2,
 3,
 3,
 0,
 3,
 3,
 3,
 2,
 2,
 1,
 1,
 2,
 1,
 2,
 0,
 3,
 2,
 3,
 3,
 0,
 1,
 3,
 0,
 0,
 3,
 2,
 3,
 2,
 2,
 3,
 1,
 1,
 1,
 1,
 3,
 2,
 3,
 1,
 2,
 3,
 2,
 1,
 1,
 2,
 0,
 2,
 3,
 2,
 0,
 0,
 1,
 0,
 0,
 1,
 3,


In [15]:
model = SimpleDNATransformer(vocab_size=VOCAB_SIZE, max_len=L).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(EPOCHS):
    indices = random.sample(range(len(token_batches)), min(BATCH_SIZE, len(token_batches)))
    x = torch.tensor([token_batches[i] for i in indices], dtype=torch.long).to(DEVICE)
    masked_x, labels = mask_input(x)
    logits = model(masked_x)
    loss = F.cross_entropy(
        logits.view(-1, VOCAB_SIZE),
        labels.view(-1),
        ignore_index=-100
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch} loss: {loss.item():.4f}")

Epoch 0 loss: 1.6671
Epoch 1 loss: 1.7780
Epoch 2 loss: 1.4508
Epoch 3 loss: 1.5010
Epoch 4 loss: 1.4568
Epoch 5 loss: 1.4142
Epoch 6 loss: 1.4049
Epoch 7 loss: 1.4219
Epoch 8 loss: 1.4250
Epoch 9 loss: 1.4133
