In [None]:
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import tqdm as tq
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [36]:
def set_seed(seed=7):
    """Seed Python's `random`, NumPy's legacy global RNG and torch (CPU + all CUDA devices).

    Also switches cuDNN to deterministic kernels and disables its autotuner,
    trading some speed for run-to-run reproducibility.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed(7)

In [37]:
# Dataset repo on the Hugging Face Hub; each split is a separate TSV under TIS/.
repo_id = 'kuleshov-group/cross-species-single-nucleotide-annotation'
tis_data_files = {
    'train': 'TIS/train.tsv',
    'valid': 'TIS/valid.tsv',
    'test_rice': 'TIS/test_rice.tsv',
    'test_sorghum': 'TIS/test_sorghum.tsv',
    'test_maize': 'TIS/test_maize.tsv',
}
tis = load_dataset(repo_id, data_files=tis_data_files)

# Only the training split feeds the fine-tuning pairs built below.
tis_train = tis['train']
train_df = tis_train.to_pandas()
seqs = train_df["sequences"]

In [38]:
# Start from an empty frame with the target schema. Assigning the first Series
# adopts its index; the variant column begins as an exact copy of the reference
# (it is mutated later), and "label" is left unfilled (NaN) until the next cell.
df = pd.DataFrame(data=None, columns=["reference_seq", "variant_seq", "label"])
for seq_column in ("reference_seq", "variant_seq"):
    df[seq_column] = seqs

In [39]:
# Emit every row twice: the first copy of each pair is the untouched
# reference (label 0), the second is the future variant (label 1).
ft_df = df.loc[df.index.repeat(2)].reset_index(drop=True)
ft_df["label"] = [0, 1] * len(df)

ft_df

Unnamed: 0,reference_seq,variant_seq,label
0,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,0
1,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,1
2,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,0
3,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,1
4,TCCATAGCTCTAATTGCAACAGGCGTGGGTGCGGCTGCCGGCTTTG...,TCCATAGCTCTAATTGCAACAGGCGTGGGTGCGGCTGCCGGCTTTG...,0
...,...,...,...
397177,TGGTCGAGTGTGCGATGATGAAGGGACTCGACCATGTAATCGAGTG...,TGGTCGAGTGTGCGATGATGAAGGGACTCGACCATGTAATCGAGTG...,1
397178,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,0
397179,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,1
397180,TTCGTTTCTTACCTTAGTACGTCCTCTAGCAACTGAATGGACTTCA...,TTCGTTTCTTACCTTAGTACGTCCTCTAGCAACTGAATGGACTTCA...,0


In [40]:
def make_pair(df, nucleic_list=("A", "T", "G", "C"), rng=None):
    """Introduce one random single-nucleotide substitution into every variant row.

    For each row whose ``label`` is 1, one position of ``variant_seq`` is chosen
    uniformly at random and replaced by a different nucleotide drawn uniformly
    from ``nucleic_list``. Rows with any other label are left unchanged.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``"variant_seq"`` (str) column and a ``"label"`` column.
    nucleic_list : sequence of str, optional
        Alphabet the replacement base is drawn from.
    rng : random.Random, optional
        Source of randomness. Defaults to the global ``random`` module, so the
        result is controlled by ``random.seed`` / ``set_seed``.

    Returns
    -------
    pandas.DataFrame
        A modified copy; the input frame is not mutated, so re-running the
        calling cell does not stack extra mutations onto the same object.

    Raises
    ------
    ValueError
        If a variant sequence is empty (there is no position to mutate).
    """
    if rng is None:
        rng = random

    out = df.copy()
    # Select variant rows by their label rather than assuming an odd/even
    # interleaving; this also works for odd-length frames and any index.
    variant_rows = out.index[out["label"].eq(1).to_numpy()]

    for row in variant_rows:
        seq = out.at[row, "variant_seq"]
        if not seq:
            raise ValueError(f"cannot mutate empty variant_seq at row {row!r}")

        pos = rng.randrange(0, len(seq))
        # Compare case-insensitively so a soft-masked base ('a') is never
        # "replaced" by the same nucleotide in upper case.
        original = seq[pos].upper()
        candidates = [nuc for nuc in nucleic_list if nuc != original]

        out.at[row, "variant_seq"] = seq[:pos] + rng.choice(candidates) + seq[pos + 1:]

    return out

In [41]:
ft_df = make_pair(df=ft_df)  # mutate one base in every label-1 (variant) row

In [52]:
# Sanity check: count the positions where each variant differs from its
# reference. Pairing the two columns with zip avoids a per-row .loc lookup and
# per-character indexing (which raised IndexError on a shorter variant).
seq_pairs = list(zip(ft_df["reference_seq"], ft_df["variant_seq"]))
assert all(len(ref) == len(var) for ref, var in seq_pairs), "reference/variant length mismatch"

mismatches_per_row = [sum(r != v for r, v in zip(ref, var)) for ref, var in seq_pairs]
false_count = sum(mismatches_per_row)

# The total alone cannot tell "one substitution per variant" apart from other
# distributions; check the per-row invariant: label 1 -> exactly one
# substitution, label 0 -> none.
assert mismatches_per_row == ft_df["label"].tolist(), "unexpected number of substitutions per row"

false_count

198591

In [54]:
# Persist the paired dataset for fine-tuning. Create the output directory if it
# is missing so a fresh checkout can Run-All without a manual mkdir.
FINE_TUNING_CSV = Path("./Data/fine_tuning.csv")
FINE_TUNING_CSV.parent.mkdir(parents=True, exist_ok=True)
ft_df.to_csv(FINE_TUNING_CSV, index=False)  # `index` is a bool flag; False drops the RangeIndex

In [55]:
# Reload the saved file to eyeball that the CSV round-trips.
reloaded_ft_df = pd.read_csv("./Data/fine_tuning.csv")
reloaded_ft_df

Unnamed: 0,reference_seq,variant_seq,label
0,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,0
1,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,1
2,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,0
3,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,1
4,TCCATAGCTCTAATTGCAACAGGCGTGGGTGCGGCTGCCGGCTTTG...,TCCATAGCTCTAATTGCAACAGGCGTGGGTGCGGCTGCCGGCTTTG...,0
...,...,...,...
397177,TGGTCGAGTGTGCGATGATGAAGGGACTCGACCATGTAATCGAGTG...,TGGTCGAGTGTGCGAGGATGAAGGGACTCGACCATGTAATCGAGTG...,1
397178,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,0
397179,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,1
397180,TTCGTTTCTTACCTTAGTACGTCCTCTAGCAACTGAATGGACTTCA...,TTCGTTTCTTACCTTAGTACGTCCTCTAGCAACTGAATGGACTTCA...,0
