In [15]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

from datasets import load_dataset
import pandas as pd

import random
import numpy as np
import torch

In [16]:
def set_seed(seed=7):
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is handed to Python's ``random`` module, NumPy's
    legacy global generator, and PyTorch (the default generator and the
    CUDA generators). cuDNN is additionally switched to deterministic
    mode with benchmarking turned off.

    Args:
        seed: Integer seed shared by all generators. Defaults to 7.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN benchmarking and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


# Seed once, up front, so the cells below start from a fixed RNG state.
set_seed(7)

In [17]:
# Hub checkpoint whose tokenizer is loaded below.
MODEL_ID = "InstaDeepAI/nucleotide-transformer-2.5b-multi-species"
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    # NOTE(review): this flag permits running code shipped with the hub
    # repo — confirm it is actually required and that the repo is trusted.
    trust_remote_code=True,
)

# Dataset repository and the TSV file backing each TIS split.
repo_id = 'kuleshov-group/cross-species-single-nucleotide-annotation'
tis = load_dataset(
    repo_id,
    data_files={
        'train': 'TIS/train.tsv',
        'valid': 'TIS/valid.tsv',
        'test_rice': 'TIS/test_rice.tsv',
        'test_sorghum': 'TIS/test_sorghum.tsv',
        'test_maize': 'TIS/test_maize.tsv',
    },
)

# Only the training split is used from here on: convert it to pandas and
# keep its "sequences" column handy.
tis_train = tis['train']
train_df = tis_train.to_pandas()
seqs = train_df["sequences"]

In [18]:
# Length limit reported by the tokenizer.
MODEL_CAP = tokenizer.model_max_length

# Longest training sequence, measured in characters.
max_seq_len = seqs.str.len().max()

# Use whichever bound is smaller (ties resolve to MODEL_CAP).
# NOTE(review): max_seq_len counts characters, whereas model_max_length is
# presumably expressed in tokens — confirm the units are comparable before
# passing MAX_LEN to the tokenizer.
MAX_LEN = max_seq_len if max_seq_len < MODEL_CAP else MODEL_CAP

MAX_LEN

512

In [19]:
# Empty scaffold that fixes the column layout; "label" is left unfilled here.
df = pd.DataFrame(
    columns=[
        "reference_seq",
        "variant_seq",
        "label",
    ],
)

# Both sequence columns start out as the same training sequences.
df["variant_seq"] = seqs
df["reference_seq"] = seqs

In [20]:
# Emit every row of `df` twice in a row, renumber the rows from zero, then
# label each consecutive pair 0, 1.
ft_df = (
    df.loc[df.index.repeat(2)]
    .reset_index(drop=True)
    .assign(label=lambda frame: [0, 1] * (len(frame) // 2))
)

ft_df

Unnamed: 0,reference_seq,variant_seq,label
0,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,0
1,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,1
2,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,0
3,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,1
4,TCCATAGCTCTAATTGCAACAGGCGTGGGTGCGGCTGCCGGCTTTG...,TCCATAGCTCTAATTGCAACAGGCGTGGGTGCGGCTGCCGGCTTTG...,0
...,...,...,...
397177,TGGTCGAGTGTGCGATGATGAAGGGACTCGACCATGTAATCGAGTG...,TGGTCGAGTGTGCGATGATGAAGGGACTCGACCATGTAATCGAGTG...,1
397178,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,0
397179,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,1
397180,TTCGTTTCTTACCTTAGTACGTCCTCTAGCAACTGAATGGACTTCA...,TTCGTTTCTTACCTTAGTACGTCCTCTAGCAACTGAATGGACTTCA...,0
