# Enhancer type prediction

## Setup

In [38]:
!pip install matplotlib
import os

try:
    import nucleotide_transformer
except:
    !pip install numpy==1.23.5
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    import nucleotide_transformer

# Only when the COLAB_TPU_ADDR environment variable is present (i.e. the
# notebook runs on a Colab TPU runtime), run JAX's Colab TPU setup helper.
# The import is kept inside the branch so other runtimes never load it.
if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model



In [39]:
#@title Select a model
#@markdown ---
# Name of the pretrained checkpoint to load; it is passed to get_pretrained_model
# further down. The trailing form annotation lists the selectable checkpoints.
model_name = '500M_human_ref'#@param['500M_human_ref', '500M_1000G', '2B5_1000G', '2B5_multi_species', '50M_multi_species_v2', '100M_multi_species_v2', '250M_multi_species_v2', '500M_multi_species_v2']
#@markdown ---

In [40]:
# Get pretrained model
# NOTE(review): the chosen layer presumably has to exist in the selected
# checkpoint -- confirm before switching `model_name` to a smaller model.
embedding_layer = 20  # Select layer to extract embeddings from

# Load everything needed for inference with `model_name`. The forward function
# returned here is not yet usable on its own; it is wrapped with Haiku below.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    # Ask the model to expose this layer; it is later read back from the
    # outputs under the key "embeddings_<embedding_layer>".
    embeddings_layers_to_save=(embedding_layer,),
    # NOTE(review): presumably the maximum number of tokens per sequence -- confirm.
    max_positions=250
)
# Wrap with Haiku so the model is invoked as forward_fn.apply(parameters, rng_key, tokens).
forward_fn = hk.transform(forward_fn)

## Import dataset

In [41]:
# Install
!pip install -q biopython transformers datasets huggingface_hub accelerate

from datasets import load_dataset, Dataset
import pandas as pd

# Load the "enhancers" task from the Nucleotide Transformer downstream-task
# collection. Both splits are fetched with the same arguments (train first,
# then test) and fully materialised rather than streamed.
dataset_name = "enhancers"
train_dataset, test_dataset = (
    load_dataset(
        "InstaDeepAI/nucleotide_transformer_downstream_tasks",
        dataset_name,
        split=split_name,
        streaming=False,
    )
    for split_name in ("train", "test")
)

## Split features from labels

In [42]:
# Training data: the 'sequence' column holds the model inputs and the 'label'
# column holds the classification targets.
train_sequences = train_dataset['sequence']
train_labels = train_dataset['label']

# Test data: the same two columns, taken from the test split.
test_sequences = test_dataset['sequence']
test_labels = test_dataset['label']

## Retrieve embeddings (one per sequence)

In [43]:
batch_size = 8  # Adjust to available memory

def get_seq_embeddings(sequences: list, batch_size: int) -> jnp.ndarray:
    """Compute one mean-pooled embedding per input sequence.

    Sequences are tokenized and run through the pretrained model in batches.
    For every sequence, the embeddings of layer ``embedding_layer`` are
    averaged over its real tokens only: padding tokens and the CLS token
    (position 0) are excluded from both the sum and the token count.

    Relies on the module-level ``tokenizer``, ``forward_fn``, ``parameters``
    and ``embedding_layer`` defined earlier in the notebook.

    Args:
        sequences: Sequences to embed, in the form accepted by
            ``tokenizer.batch_tokenize``.
        batch_size: Number of sequences per forward pass.

    Returns:
        A 2-D array with one row per sequence, in input order. A sequence
        without any real token yields an all-zero row.

    Raises:
        ValueError: If ``sequences`` is empty.
    """
    if len(sequences) == 0:
        raise ValueError("`sequences` must contain at least one sequence")

    extraction_layer = "embeddings_" + str(embedding_layer)
    cls_token_position = 0  # Position of the CLS token for every sequence
    random_key = jax.random.PRNGKey(0)

    embeddings = []  # One (batch, embedding_dim) array of pooled embeddings per batch

    for start in range(0, len(sequences), batch_size):
        batch = sequences[start:start + batch_size]

        # Tokenize the batch; batch_tokenize yields (tokens, token_ids) pairs
        tokens_ids = [b[1] for b in tokenizer.batch_tokenize(batch)]
        tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

        # Infer
        outs = forward_fn.apply(parameters, random_key, tokens)
        batch_embeddings = outs[extraction_layer]

        # True for real tokens, False for padding tokens and the CLS token
        token_mask = (tokens != tokenizer.pad_token_id) & (
            jnp.arange(tokens.shape[1]) != cls_token_position
        )

        # Sum the embeddings of the real tokens only
        summed = jnp.sum(
            jnp.where(token_mask[:, :, None], batch_embeddings, 0.0), axis=1
        )

        # Divide by the number of real tokens (not by the padded length) so the
        # result is a true mean; clamp to 1 to avoid 0/0 for token-less rows.
        token_counts = jnp.sum(token_mask, axis=1, keepdims=True)
        embeddings.append(summed / jnp.maximum(token_counts, 1))

    return jnp.vstack(embeddings)

In [44]:
# Embed every sequence with the pretrained model: each row of X_train / X_test
# is the embedding that get_seq_embeddings returns for one input sequence.
X_train = get_seq_embeddings(train_sequences, batch_size)
X_test = get_seq_embeddings(test_sequences, batch_size)

# Labels are used unchanged; they stay aligned with the rows above because
# get_seq_embeddings processes the sequences in their original order.
y_train = train_labels
y_test = test_labels

## Train downstream model

In [45]:
!pip install lightgbm

import lightgbm as lgb
from sklearn.metrics import accuracy_score

# Hyper-parameters for the gradient-boosted classifier trained on the embeddings
params = dict(
    objective='binary',
    metric='binary_error',
    num_leaves=31,
    learning_rate=0.05,
    feature_fraction=0.9,
)

# Fit on the training embeddings (fit returns the estimator itself), then
# predict a label for every test embedding
clf = lgb.LGBMClassifier(**params).fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Share of test sequences whose predicted label matches the true label
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")

[LightGBM] [Info] Number of positive: 7484, number of negative: 7484
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 1.211321 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 326400
[LightGBM] [Info] Number of data points in the train set: 14968, number of used features: 1280
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
Accuracy: 0.75
