In [None]:
import gc
import os

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForSequenceClassification

# Number GPUs in PCI bus order (so indices match nvidia-smi) and expose only GPU 0.
# These must be set before CUDA is initialised; torch initialises CUDA lazily,
# so setting them after `import torch` but before any CUDA call still takes effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.cuda.is_available()

In [None]:
# Target device for training and inference (the single GPU exposed above).
device = torch.device(type="cuda")

In [None]:
# Tokenizer matching the pretrained checkpoint that is fine-tuned below.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="InstaDeepAI/nucleotide-transformer-500m-human-ref"
)

In [None]:
def preprocessing(examples):
    """Tokenize a batch of examples for ``Dataset.map(batched=True)``.

    ``examples`` is a batch mapping whose ``"seq"`` entry holds the raw
    sequences. The module-level ``tokenizer``'s output (e.g. ``input_ids``) is
    returned so ``map`` adds it as new columns.
    """
    sequences = examples["seq"]
    encoded = tokenizer(sequences)
    return encoded

def compute_metrics(eval_pred):
    """Compute validation AUROC for the Hugging Face ``Trainer``.

    Parameters
    ----------
    eval_pred : EvalPrediction or tuple
        ``(predictions, labels)`` as passed by ``Trainer``. ``predictions`` are
        the model's raw logits, assumed to be ``(n_samples, 2)`` for this
        binary classifier (``num_labels=2``); ``labels`` are 0/1 targets.

    Returns
    -------
    dict
        ``{'AUROC': float}``. The value is ``nan`` when the labels contain a
        single class, because AUROC is undefined there and ``roc_auc_score``
        would otherwise raise in the middle of training.
    """
    predictions, labels = eval_pred
    # Some models return a tuple of outputs; the logits are the first element.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    logits = np.asarray(predictions)
    labels = np.asarray(labels)
    if np.unique(labels).size < 2:
        return {'AUROC': float('nan')}
    # Rank by the logit margin l1 - l0. It is monotone in softmax(logits)[:, 1],
    # so this AUROC matches the probabilities produced at inference time,
    # whereas the raw class-1 logit alone does not rank samples the same way.
    scores = logits[:, 1] - logits[:, 0]
    return {'AUROC': roc_auc_score(labels, scores)}

def prediction(seq):
    """Return the model's probability that a single sequence has label 1.

    Reads the module-level ``tokenizer`` and ``model``. The caller is expected
    to have put ``model`` in eval mode first (the training loop does this).

    Parameters
    ----------
    seq : str
        One nucleotide sequence.

    Returns
    -------
    float
        Softmax probability of class 1, rounded to 5 decimal places.
    """
    x_feat = tokenizer(seq, return_tensors='pt')
    # Send inputs to wherever the model's weights live rather than assuming cuda.
    model_device = next(model.parameters()).device
    input_ids = x_feat['input_ids'].to(model_device)
    attention_mask = x_feat.get('attention_mask')
    if attention_mask is not None:
        attention_mask = attention_mask.to(model_device)
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    prob = torch.softmax(out['logits'], dim=-1)
    return round(float(prob[0][1]), 5)

Before running the next cell, make sure the parent directory contains a `test` folder with the gHTS and CHS participant files in FASTA format (`gHTS_participants.fasta` and `CHS_participants.fasta`).

In [None]:
def _read_fasta(path):
    """Parse a FASTA file into an insertion-ordered ``{record_id: sequence}`` dict.

    Record ids are the header line without the leading ``>``. Sequence lines
    after a header are concatenated, so wrapped multi-line records are read
    whole. Blank lines are skipped. Headers with no sequence lines are omitted,
    and a repeated header keeps its first position but takes the last record's
    sequence.

    Raises
    ------
    ValueError
        If a sequence line appears before the first header.
    """
    chunks = {}
    key = None
    with open(path) as handle:
        for raw in handle:
            line = raw.strip()
            if not line:
                continue
            if line.startswith('>'):
                key = line[1:]
                chunks[key] = []
            elif key is None:
                raise ValueError(f'{path}: sequence line before the first FASTA header')
            else:
                chunks[key].append(line)
    return {k: ''.join(parts) for k, parts in chunks.items() if parts}


ghts_seqs = _read_fasta('../test/gHTS_participants.fasta')
chs_seqs = _read_fasta('../test/CHS_participants.fasta')

# Submission tables start with the record ids; the training loop appends one
# prediction column per TF in the same (dict insertion) order.
ghts_ids = list(ghts_seqs)
ghts_submit = {'tags': ghts_ids}

chs_ids = list(chs_seqs)
chs_submit = {'tags': chs_ids}


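Before running the next cell, make sure the parent directory contains a `datasets` folder with the HTS training data. It needs one gzip-compressed Parquet file per TF, named `<TF>.parquet.gzip`. Each file needs a `seq` column and a `group` column that assigns each row to `train` or `val`.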

In [None]:
# Targets for which one binary classifier is fine-tuned and scored.
TFS = ['CREB3L3', 'MYPOP', 'SP140L', 'ZNF831', 'ZNF286B',
       'ZNF780B', 'MSANTD1', 'FIZ1', 'MKX', 'ZNF721',
       'ZBTB47', 'GCM1', 'TPRX1', 'ZFTA', 'ZNF500']

for tf in TFS:
    # --- data ---------------------------------------------------------------
    data = pd.read_parquet(f'../datasets/{tf}.parquet.gzip')
    train = data[data.group == 'train']
    val = data[data.group == 'val']
    # preserve_index=False: the filtered frames have a non-RangeIndex, which
    # from_pandas would otherwise store as an extra '__index_level_0__' column.
    dataset = DatasetDict({
        "train": Dataset.from_pandas(train, preserve_index=False),
        "val": Dataset.from_pandas(val, preserve_index=False),
    })
    tokenized_dataset = dataset.map(preprocessing, batched=True, remove_columns=["seq", 'group'])

    # --- training -----------------------------------------------------------
    training_args = TrainingArguments(
        # One directory per TF so runs do not pile checkpoints into one folder.
        output_dir=f"test_run/{tf}",
        learning_rate=1e-05,
        lr_scheduler_type="constant",
        warmup_ratio=0,
        optim='adamw_torch',
        weight_decay=0,
        per_device_train_batch_size=128,
        per_device_eval_batch_size=64,
        num_train_epochs=2,
        # NOTE(review): newer transformers releases rename this to `eval_strategy`.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        label_names=["labels"],
        load_best_model_at_end=True
    )
    # Seed before building the model so the randomly initialised classification
    # head is reproducible (Trainer seeds only when constructed, after the head exists).
    torch.manual_seed(training_args.seed)
    model = AutoModelForSequenceClassification.from_pretrained(
        "InstaDeepAI/nucleotide-transformer-500m-human-ref", num_labels=2
    ).to(device)

    trainer = Trainer(
        model,
        training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["val"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    # NOTE(review): presumably a workaround for checkpoint saving failing on
    # non-contiguous tensors; confirm before removing.
    for param in model.parameters():
        param.data = param.data.contiguous()
    trainer.train()
    # With load_best_model_at_end=True the best checkpoint is loaded into
    # trainer.model; use that object explicitly for saving and inference.
    model = trainer.model
    torch.save(model, f'./a2g_{tf}_model')

    # --- inference on the test sets -----------------------------------------
    model.eval()  # `prediction` reads this global `model`
    ghts_submit[tf] = [prediction(seq) for seq in tqdm(ghts_seqs.values())]
    chs_submit[tf] = [prediction(seq) for seq in tqdm(chs_seqs.values())]

    # Release this TF's model, trainer and optimizer state before the next
    # iteration loads a fresh copy of the pretrained model onto the same GPU.
    del trainer, model
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Build both submission tables first (rows = record tags, one column per TF),
# then write each as a tab-separated file.
ghts_df = pd.DataFrame(ghts_submit).set_index('tags')
chs_df = pd.DataFrame(chs_submit).set_index('tags')

ghts_df.to_csv(path_or_buf='./submission_NT_ghts.tsv', sep="\t")
chs_df.to_csv(path_or_buf='./submission_NT_chs.tsv', sep="\t")