In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
from Bio import SeqIO
import numpy as np
import os
import h5py
from nucleotide_transformer.pretrained import get_pretrained_model

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the pretrained Nucleotide Transformer checkpoint together with its tokenizer.
# Saving layer 20 makes the forward pass expose an "embeddings_20" output, which the
# embedding cell reads.
# NOTE(review): max_positions presumably counts tokens, not nucleotides, while the
# fragment chunk_size below counts nucleotides — confirm the two are compatible.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name="250M_multi_species_v2",
    max_positions=512,
    embeddings_layers_to_save=(20,),
)
# Wrap with Haiku so the model is applied as forward_fn.apply(params, rng_key, tokens).
forward_fn = hk.transform(forward_fn)

Downloading model's hyperparameters json file...
Downloaded model's hyperparameters.
Downloading model's weights...
Downloaded model's weights...


In [None]:
def batched_fasta(path, batch_size=16):
    """Stream a FASTA file and yield its whole records in batches.

    Parameters
    ----------
    path : str
        Path to the FASTA file.
    batch_size : int, optional
        Number of records per batch, 16 by default. The final batch may be smaller.

    Yields
    ------
    tuple of (list of str, list of str)
        Record identifiers and the matching upper-cased sequences as strings.
    """
    batch_ids, batch_seqs = [], []
    for record in SeqIO.parse(path, "fasta"):
        batch_ids.append(record.id)
        batch_seqs.append(str(record.seq.upper()))
        if len(batch_seqs) < batch_size:
            continue
        yield batch_ids, batch_seqs
        batch_ids, batch_seqs = [], []
    # Flush the remainder that did not fill a complete batch.
    if batch_seqs:
        yield batch_ids, batch_seqs

In [None]:
def chunk_sequence(seq, chunk_size=512, stride=None):
    """Split a sequence into fragments of at most ``chunk_size`` characters.

    Fragments start every ``stride`` characters. Iteration stops at the first
    fragment that reaches the end of the sequence. With overlapping windows
    (``stride < chunk_size``), no trailing fragment is therefore emitted that is
    entirely contained in the previous one.

    Parameters
    ----------
    seq : str or object convertible with ``str()``
        Sequence to split (e.g. a Biopython ``Seq``).
    chunk_size : int, optional
        Maximum fragment length, 512 by default. Must be positive.
    stride : int, optional
        Offset between the starts of consecutive fragments. None (default)
        means ``chunk_size``, i.e. contiguous, non-overlapping fragments.
        Must be positive.

    Yields
    ------
    str
        Successive non-empty fragments; nothing for an empty sequence.

    Raises
    ------
    ValueError
        If ``chunk_size`` or ``stride`` is not positive (raised when iteration starts).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if stride is None:
        stride = chunk_size
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    s = str(seq)
    for start in range(0, len(s), stride):
        # start < len(s) and chunk_size > 0, so the fragment is never empty.
        yield s[start:start + chunk_size]
        # This fragment already reaches the end of the sequence; any later start
        # would only produce a suffix of it.
        if start + chunk_size >= len(s):
            break

In [None]:
def batched_fasta_chunks(path, batch_size=16, chunk_size=512, stride=None):
    """Stream a FASTA file and yield batches of sequence fragments.

    Each record's sequence is upper-cased and cut into fragments of at most
    ``chunk_size`` characters by ``chunk_sequence``. Each fragment identifier is
    ``f"{record_id}_part{n}"``, where ``n`` is the fragment's index within its record.

    Parameters
    ----------
    path : str
        Path to the FASTA file.
    batch_size : int, optional
        Number of fragments per batch, 16 by default. The final batch may be smaller.
    chunk_size : int, optional
        Maximum fragment length, 512 by default.
    stride : int, optional
        Offset between fragment starts; None (default) means equal to chunk_size.

    Yields
    ------
    tuple of (list of str, list of str)
        Fragment identifiers and the matching fragment sequences.
    """
    def _labelled_fragments():
        # Flatten every record into (fragment_id, fragment) pairs, lazily.
        for record in SeqIO.parse(path, "fasta"):
            pieces = chunk_sequence(record.seq.upper(), chunk_size, stride)
            for part_no, piece in enumerate(pieces):
                yield f"{record.id}_part{part_no}", piece

    pending_ids, pending_seqs = [], []
    for frag_id, frag in _labelled_fragments():
        pending_ids.append(frag_id)
        pending_seqs.append(frag)
        if len(pending_seqs) >= batch_size:
            yield pending_ids, pending_seqs
            pending_ids, pending_seqs = [], []
    # Flush the remainder that did not fill a complete batch.
    if pending_seqs:
        yield pending_ids, pending_seqs

In [None]:
# Input genome and output HDF5 file.
filePath = "data/raw/Gossypium_hirsutum_v2.1_genomic.fna"
out_h5 = "data/embeddings/Gossypium_hirsutum_v2.1_embeddings.h5"

# Fragmentation (lengths in nucleotides) and batching.
chunk_size = 512
stride = None  # None for non-overlapping fragments, 256 for 50% overlap
batch_size = 2

# Root RNG key, split into a fresh subkey for every batch during inference.
random_key = jax.random.PRNGKey(0)

In [None]:
# h5py does not create parent directories, so make sure the output folder exists.
out_dir = os.path.dirname(out_h5)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

# Mode "w" truncates an existing file. Re-running this cell therefore rewrites the
# embeddings from scratch instead of appending a second copy of every fragment.
with h5py.File(out_h5, "w") as f:
    # The datasets are created on the first batch, once the embedding shape is known.
    for ids, sequences in batched_fasta_chunks(filePath, batch_size=batch_size, chunk_size=chunk_size, stride=stride):
        # Tokenize: batch_tokenize returns (str, token_ids) pairs; keep only the ids.
        token_pairs = tokenizer.batch_tokenize(sequences)
        token_ids = [p[1] for p in token_pairs]
        tokens = jnp.asarray(token_ids, dtype=jnp.int32)

        # Run the model with a fresh RNG subkey for every batch.
        random_key, subkey = jax.random.split(random_key)
        outs = forward_fn.apply(parameters, subkey, tokens)

        # Copy to host memory as float32; shape (batch, seq_len, dim).
        # NOTE(review): this stores one vector per token position for every fragment
        # (presumably including padding positions). For a whole genome the file can
        # become very large; consider pooling over non-padding positions. TODO confirm intent.
        emb = jax.device_get(outs["embeddings_20"]).astype(np.float32)
        b, seq_len, dim = emb.shape

        # Test for the dataset this cell actually creates. The previous check used
        # "Gossypium_hirsutum_v2.1_embeddings", a name never written to the file.
        # Every batch then tried to create "embeddings" again, and the second batch failed.
        if "embeddings" not in f:
            f.create_dataset(
                "embeddings",
                data=emb,
                maxshape=(None, seq_len, dim),
                chunks=(1, seq_len, dim),
                dtype="f4",
            )
            str_dt = h5py.string_dtype(encoding="utf-8")
            f.create_dataset("ids", data=np.array(ids, dtype=object), maxshape=(None,), dtype=str_dt)
        else:
            # Append the embeddings. This assumes seq_len and dim match the first batch
            # (presumably the tokenizer pads every batch to a fixed length — confirm).
            ds = f["embeddings"]
            old = ds.shape[0]
            ds.resize(old + b, axis=0)
            ds[old:old + b, ...] = emb
            # Append the ids at the same offset so row i of "ids" labels row i of "embeddings".
            ids_ds = f["ids"]
            ids_ds.resize(old + b, axis=0)
            ids_ds[old:old + b] = np.array(ids, dtype=object)

In [None]:
# Read every sequence from data/raw/bt_genes.fasta.
# NOTE(review): the original comment called these "CRISPR sequences", but the file is
# named bt_genes.fasta — confirm which is intended. Unlike the genome fragments, these
# are kept as Seq objects (not upper-cased, not converted to str).
bt_sequences = [record.seq for record in SeqIO.parse("data/raw/bt_genes.fasta", "fasta")]

In [None]:

# Inspect the layer-20 embeddings written to disk rather than relying on `outs`.
# `outs` only survives from the last loop iteration (and does not exist if the loop
# never ran). Shapes are (n_fragments, seq_len, dim) and (n_fragments,).
with h5py.File(out_h5, "r") as emb_file:
    print(emb_file["embeddings"].shape, emb_file["ids"].shape)