In [1]:
import os
import random

import numpy as np
import pandas as pd
import torch
import tqdm as tq
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def set_seed(seed=7):
    """Seed the random number generators used in this notebook.

    The same ``seed`` is handed to Python's ``random`` module, NumPy's legacy
    global generator, and PyTorch (default generator and all CUDA devices).
    cuDNN is additionally switched to deterministic mode with benchmarking
    turned off.

    Args:
        seed: Integer seed passed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(7)

In [3]:
# Dataset repository id handed to `load_dataset`.
repo_id = 'kuleshov-group/cross-species-single-nucleotide-annotation'

# Split name -> TSV file inside the repository's TIS/ folder.
tis_files = {
    'train': 'TIS/train.tsv',
    'valid': 'TIS/valid.tsv',
    'test_rice': 'TIS/test_rice.tsv',
    'test_sorghum': 'TIS/test_sorghum.tsv',
    'test_maize': 'TIS/test_maize.tsv',
}
tis = load_dataset(repo_id, data_files=tis_files)

# Only the training split is carried forward from here: it is converted to a
# pandas DataFrame and its "sequences" column is kept as `seqs`.
tis_train = tis['train']
train_df = tis_train.to_pandas()
seqs = train_df["sequences"]

In [4]:
# Start from an empty frame that only fixes the column order, then fill both
# sequence columns with the same training sequences. "label" is left unset at
# this point.
pair_columns = ["reference_seq", "variant_seq", "label"]
df = pd.DataFrame(columns=pair_columns)

for seq_column in ("reference_seq", "variant_seq"):
    df[seq_column] = seqs

In [5]:
# Repeat every row of `df` twice back-to-back and renumber the result 0..n-1.
doubled_labels = np.repeat(df.index, 2)
ft_df = df.loc[doubled_labels].reset_index(drop=True)

# Alternate the labels 0, 1, 0, 1, ... so the first copy of each sequence
# gets 0 and the second copy gets 1.
pair_count = len(ft_df) // 2
ft_df["label"] = [0, 1] * pair_count

ft_df

Unnamed: 0,reference_seq,variant_seq,label
0,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,0
1,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,1
2,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,0
3,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,1
4,TCCATAGCTCTAATTGCAACAGGCGTGGGTGCGGCTGCCGGCTTTG...,TCCATAGCTCTAATTGCAACAGGCGTGGGTGCGGCTGCCGGCTTTG...,0
...,...,...,...
397177,TGGTCGAGTGTGCGATGATGAAGGGACTCGACCATGTAATCGAGTG...,TGGTCGAGTGTGCGATGATGAAGGGACTCGACCATGTAATCGAGTG...,1
397178,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,0
397179,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,1
397180,TTCGTTTCTTACCTTAGTACGTCCTCTAGCAACTGAATGGACTTCA...,TTCGTTTCTTACCTTAGTACGTCCTCTAGCAACTGAATGGACTTCA...,0


In [6]:
def make_pair(df, max_mutations=30):
    """Apply random point substitutions to ``variant_seq`` in every second row.

    Rows at odd positions (1, 3, 5, ...) have between 1 and ``max_mutations``
    distinct positions of their ``variant_seq`` string replaced with a
    different nucleotide chosen from A/T/G/C. Rows at even positions are left
    untouched. Randomness comes from the global ``random`` module state, so
    seed it beforehand for reproducible output.

    Args:
        df: DataFrame with a ``variant_seq`` column holding strings. Rows are
            addressed by position, so the index labels do not matter. The
            frame is modified in place.
        max_mutations: Upper bound on the number of substituted positions per
            sequence. It is capped at the sequence length; when the cap is
            below 1 (empty sequence or ``max_mutations < 1``) the row is left
            unchanged.

    Returns:
        The same ``df`` object that was passed in.
    """
    nucleic_list = ['A', 'T', 'G', 'C']
    variant_col = df.columns.get_loc("variant_seq")

    # Stop at the last existing row: the previous bound of len(df) + 1 stepped
    # past the end of the frame whenever the row count was odd.
    for i in range(1, len(df), 2):
        seq = df.iat[i, variant_col]
        seq_len = len(seq)

        # random.sample cannot draw more positions than the sequence has, so
        # cap the mutation count at the sequence length.
        mutation_cap = min(max_mutations, seq_len)
        if mutation_cap < 1:
            continue

        # Pick how many positions to mutate (1..mutation_cap), then which ones.
        num_mut = random.randint(1, mutation_cap)
        idx_list = random.sample(range(seq_len), num_mut)

        seq_list = list(seq)
        for idx in idx_list:
            original = seq_list[idx]
            # Exclude the current base so the position always changes.
            candidates = [n for n in nucleic_list if n != original]
            seq_list[idx] = random.choice(candidates)

        df.iat[i, variant_col] = "".join(seq_list)

    return df


In [7]:
# make_pair edits the frame it receives in place and returns that same frame,
# so re-running this cell mutates the already-mutated odd rows again.
ft_df = make_pair(ft_df)

In [8]:
# Sanity check: total number of positions, summed over all rows, at which
# variant_seq differs from reference_seq.
false_count = 0

for i in range(len(ft_df)):
    ref = ft_df.at[i, "reference_seq"]
    var = ft_df.at[i, "variant_seq"]

    # Compare position by position along the reference sequence.
    false_count += sum(1 for pos, base in enumerate(ref) if base != var[pos])

false_count

3073389

In [9]:
# Create the output folder first so the write does not fail on a fresh
# checkout where ./Data does not exist yet.
os.makedirs("./Data", exist_ok=True)

# index=False is the documented way to omit the row index; the previous
# index=None only worked because None is falsy.
ft_df.to_csv("./Data/fine_tuning.csv", index=False)

In [10]:
# Read the saved file back from disk and display it to inspect what was written.
pd.read_csv("./Data/fine_tuning.csv")

Unnamed: 0,reference_seq,variant_seq,label
0,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,0
1,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTAGAACATA...,ATTGTCCTAACTCAGAGTCCTCAGCATCATCACGGATTTGAACATA...,1
2,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,0
3,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,CTCTTTCAGATCCTTATAGCTTCTATAAATATGATTGAGATTAAGC...,1
4,TCCATAGCTCTAATTGCAACAGGCGTGGGTGCGGCTGCCGGCTTTG...,TCCATAGCTCTAATTGCAACAGGCGTGGGTGCGGCTGCCGGCTTTG...,0
...,...,...,...
397177,TGGTCGAGTGTGCGATGATGAAGGGACTCGACCATGTAATCGAGTG...,TGGTCGAGTGTGCGATGATGAAGGGACTCGACCATGTAATCGAGTG...,1
397178,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,0
397179,GGTATAATGCATGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,GGTATAATGCGTGATTGACTAAGCAGACAAGTTCTGATCAAGCCAC...,1
397180,TTCGTTTCTTACCTTAGTACGTCCTCTAGCAACTGAATGGACTTCA...,TTCGTTTCTTACCTTAGTACGTCCTCTAGCAACTGAATGGACTTCA...,0
