In [1]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import numpy as np
from tqdm import tqdm
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, PreTrainedTokenizerFast
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

In [2]:
def calculate_statistics(sequences):
    num_sequences = len(sequences)
    avg_length = int(sum(len(seq) for seq in sequences) / num_sequences)
    min_length = min(len(seq) for seq in sequences)
    max_length = max(len(seq) for seq in sequences)

    return {
        "Number of sequences": num_sequences,
        "Average sequence length": avg_length,
        "Min sequence length": min_length,
        "Max sequence length": max_length
    }

### 1. Load and Encode RNA Data

In [3]:
# True if no flanks 
# if all data - use False
hairpin_bool = False
if hairpin_bool:
    relevant_seq = 'pre_mirna'
else:
    relevant_seq = 'full_seq'

In [4]:
# *all data:*

gff_data = pd.read_csv('/sise/vaksler-group/IsanaRNA/Transformers/GPT_env/Data_output/gff_output/sebastian_db_features_new.csv')
mirogen_data = pd.read_csv('/sise/vaksler-group/IsanaRNA/Transformers/Rom/Data_output/miRGeneDB_output/miRGeneDB_features.csv')
# # remove the flanks from the sequences:
# gff_data = gff_data.dropna(subset=['full_seq', 'flank1', 'flank2']); mirogen_data = mirogen_data.dropna(subset=['full_seq', 'flank1', 'flank2']) # found some null flanks
# for index, row in gff_data.iterrows():
#     gff_data.at[index, 'full_seq'] = gff_data.at[index, 'full_seq'].replace(row['flank1'], '').replace(row['flank2'], '')
# for index, row in mirogen_data.iterrows():
#     mirogen_data.at[index, 'full_seq'] = mirogen_data.at[index, 'full_seq'].replace(row['flank1'], '').replace(row['flank2'], '')

    
# *only train (based on clusters)*:

# gff_data = pd.read_csv('/sise/vaksler-group/IsanaRNA/Transformers/GPT_env/seq_clusters/gff_train_data.csv')
# mirogen_data = pd.read_csv('/sise/vaksler-group/IsanaRNA/Transformers/GPT_env/seq_clusters/mirgene_train_data.csv')


original_gff_sequences = gff_data[relevant_seq].tolist() ; original_mirogen_sequences  = mirogen_data[relevant_seq].tolist()
gff_data = gff_data.dropna(subset=['Star','Mature']); mirogen_data = mirogen_data.dropna(subset=['Star','Mature']) # found some null star and mature


def encode_sequence(row):
    full_seq = row[relevant_seq]
    mature_start = full_seq.find(row['Mature'])
    if mature_start == -1:
        print("no mature")
    mature_end = mature_start + len(row['Mature'])
    star_start = full_seq.find(row['Star'])
    if star_start == -1:
        print("no star")
    star_end = star_start + len(row['Star'])
    # star_end = len(row['Star'])

    encoded_seq = full_seq
    if star_start < star_end:
        if mature_start < star_start:
            encoded_seq = (encoded_seq[:mature_start] + 'ZZZZZ' +
                           encoded_seq[mature_start:mature_end] + 'BBBBB' +
                           encoded_seq[mature_end:star_start] + 'DDDDD' +
                           encoded_seq[star_start:star_end] + 'FFFFF' +
                           encoded_seq[star_end:])
        else:
            encoded_seq = (encoded_seq[:star_start] + 'DDDDD' +
                           encoded_seq[star_start:star_end] + 'FFFFF' +
                           encoded_seq[star_end:mature_start] + 'ZZZZZ' +
                           encoded_seq[mature_start:mature_end] + 'BBBBB' +
                           encoded_seq[mature_end:])
    return encoded_seq


def decode_sequence(encoded_seq):
    # decode the encoded patterns with original tokens
    decoded_seq = encoded_seq.replace('ZZZZZ', '').replace('BBBBB', '').replace('DDDDD', '').replace('FFFFF', '')
    return decoded_seq


# Apply the encoding function on both datasets
gff_data['encoded_seq'] = gff_data.apply(encode_sequence, axis=1) ; mirogen_data['encoded_seq'] = mirogen_data.apply(encode_sequence, axis=1)

# gff_data['encoded_seq'] = gff_data.apply(encode_sequence, axis=1) ; mirogen_data['encoded_seq'] = mirogen_data.apply(encode_sequence, axis=1)
gff_data['decoded_seq'] = gff_data['encoded_seq'].apply(decode_sequence) ; mirogen_data['decoded_seq'] = mirogen_data['encoded_seq'].apply(decode_sequence)

In [5]:
original_full_gff_sequences = gff_data['full_seq'].tolist() ; original_full_mirogen_sequences  = mirogen_data['full_seq'].tolist()


In [6]:
print((gff_data[relevant_seq] == gff_data['decoded_seq']).value_counts())
# Iterate through the sampled false indices and check each point
# get sampled_false_indices (where gff_data['full_seq'] != gff_data['decoded_seq'])
false_indices = [idx for idx in gff_data[gff_data[relevant_seq] != gff_data['decoded_seq']].index]
sampled_false_indices = np.random.choice(false_indices, 3, replace=False)
for idx in sampled_false_indices:
    print("Index:", idx)

    # Extract necessary information from the DataFrame
    row = gff_data.loc[idx]
    full_seq = row[relevant_seq]
    # full_seq_folding = row['full_seq_folding'] ################################################
    encoded_seq = row['encoded_seq']
    decoded_seq = row['decoded_seq']

    # 1. Check Special Encoding Tokens
    # Check if special tokens are correctly replaced during decoding
    if 'ZZZZZ' in decoded_seq or 'BBBBB' in decoded_seq or 'DDDDD' in decoded_seq or 'FFFFF' in decoded_seq:
        print("Special Encoding Tokens not correctly replaced.")

    # # 2. Check Translation Dictionary and Reverse Translation Dictionary
    # # Encode and decode the full sequence without adding special tokens and compare with the original
    # simple_encoded_seq = ''.join([translation_dict.get(full_seq_folding[i] + full_seq[i], full_seq[i]) for i in range(len(full_seq))])
    # simple_decoded_seq = ''.join([reverse_translation_dict.get(char, char) for char in simple_encoded_seq])
    # if simple_decoded_seq != full_seq:
    #     print("Discrepancy in Translation and Reverse Translation Dictionary.")

    # 3. Check Start and End Indices
    # Print the start and end indices used during encoding
    mature_start = full_seq.find(row['Mature'])
    mature_end = mature_start + len(row['Mature'])
    star_start = row['Start_star']
    star_end = row['End_star']
    print("Mature Start:", mature_start, "Mature End:", mature_end)
    print("Star Start:", star_start, "Star End:", star_end)
    # Print the Decoded, Encoded, and Full Sequence for comparison
    print("Encoded Sequence:", encoded_seq)
    print("Decoded Sequence:", decoded_seq)
    print("full/hairpin Sequence:", full_seq)
    print('len(decoded_seq):', len(decoded_seq))
    print('len(full/hairpin seq):', len(full_seq))
    print("-" * 30)


True     174970
False      1664
Name: count, dtype: int64
Index: 26419
Mature Start: 31 Mature End: 53
Star Start: 27 Star End: 48
Encoded Sequence: GTCCCAGAAGAGAACTTGCCAGCTGCCDDDDDACAAACCCGTAGATCCGAACTTFFFFFZZZZZACCCGTAGATCCGAACTTGTGGBBBBBTGACTGGCCGCACAAGCTCGTGTCTATAGGTATGTGTCTGTGTGGCCATCACAGCACCCCTCTC
Decoded Sequence: GTCCCAGAAGAGAACTTGCCAGCTGCCACAAACCCGTAGATCCGAACTTACCCGTAGATCCGAACTTGTGGTGACTGGCCGCACAAGCTCGTGTCTATAGGTATGTGTCTGTGTGGCCATCACAGCACCCCTCTC
full/hairpin Sequence: GTCCCAGAAGAGAACTTGCCAGCTGCCACAAACCCGTAGATCCGAACTTGTGGTGACTGGCCGCACAAGCTCGTGTCTATAGGTATGTGTCTGTGTGGCCATCACAGCACCCCTCTC
len(decoded_seq): 135
len(full/hairpin seq): 117
------------------------------
Index: 98153
Mature Start: 31 Mature End: 53
Star Start: 35 Star End: 56
Encoded Sequence: CGGGGGCCCGGACTCCTGGGTCCTGGCACCCZZZZZACCCGTAGAACCGACCTTGCGGBBBBBDDDDDGTAGAACCGACCTTGCGGGGCCFFFFFTTCGCCGCACACAAGCTCGTGTCTGTGGGTCCGTGTCGGGGGCTCACCATCGCGGCTGGGGCC
Decoded Sequence: CGGGGGCCCGGACTCCTGGGTCCTGGCACCCACCCGTAGAACCGACCTTGCG

In [7]:
gff_sequences = ['<SOS>' + sequence + '<EOS>' for sequence in gff_data['encoded_seq']]
mirgendb_sequences = ['<SOS>' + sequence + '<EOS>' for sequence in mirogen_data['encoded_seq']]

In [8]:
calculate_statistics(gff_sequences)

{'Number of sequences': 176634,
 'Average sequence length': 151,
 'Min sequence length': 114,
 'Max sequence length': 410}

In [9]:
calculate_statistics(mirgendb_sequences) 

{'Number of sequences': 16361,
 'Average sequence length': 151,
 'Min sequence length': 133,
 'Max sequence length': 372}

### 2. Train & Evaluate BPE Tokenizer

**no need to train it here, the train is in colab**

In [10]:
def token_statistics(tokenizer, sequences):
    token_lengths = []

    for seq in sequences:
        tokens = tokenizer.encode(seq).tokens
        token_lengths.extend([len(token) for token in tokens])

    avg_length = np.mean(token_lengths)
    min_length = np.min(token_lengths)
    max_length = np.max(token_lengths)
    median_length = np.median(token_lengths)
    std_dev = np.std(token_lengths)

    return {
        "average": avg_length,
        "min": min_length,
        "max": max_length,
        "median": median_length,
        "std_dev": std_dev
    }

In [11]:
def train_bpe_tokenizer(sequences, vocab_size=10000, min_frequency=3):
    special_tokens = ["<SOS>", "<EOS>", "ZZZZZ", "BBBBB", "DDDDD", "FFFFF"]
    
    # Define BPE model with an unknown token
    tokenizer = Tokenizer(BPE(unk_token="<UNK>"))
    
    # Define the BPE trainer with special tokens, vocab size, and min frequency
    trainer =  BpeTrainer(special_tokens=special_tokens, max_token_length=6) #vocab_size=vocab_size, min_frequency=min_frequency, 
    # Train the tokenizer from the iterator of sequences
    tokenizer.train_from_iterator(sequences, trainer=trainer)
    
    # After training, add the special tokens to the tokenizer to ensure they won't be split
    tokenizer.add_tokens(special_tokens)

    return tokenizer

In [12]:
# load tokenizer that was trained in colab
bpe_tokenizer = PreTrainedTokenizerFast(tokenizer_file="GPT_mature_star_bpe_hairpin_tokenizer.json") # or no flanks: GPT_mature_star_bpe_tokenizer

In [14]:
print("All tokens:")
unique_tokens=set()
vocab = bpe_tokenizer.get_vocab()
for token, idx in vocab.items():
    unique_tokens.add(token)
    if token == "UNK":
        print("<UNK> token found in vocabulary!")
print(len(unique_tokens))
unique_tokens_list = sorted(list(unique_tokens), key=lambda x: len(x), reverse=True)
# print(unique_tokens_list)
file_path = 'unique_tokens.txt'
with open(file_path, 'w') as file:
    for token in unique_tokens_list:
        file.write(token + '\n')

All tokens:
2132


### Validate for one-char tokens

In [15]:
one_char_tokens = [token for token in vocab if len(token) == 1]
if one_char_tokens:
    print(f"One-char tokens found: {one_char_tokens}")
else:
    print("No one-char tokens found!")

One-char tokens found: ['n', 'Y', '<', 'F', 'T', 'D', 'a', 'A', 'B', 'R', 'Z', 'W', 'C', 'O', 'N', 't', '>', 'E', 'G', 'g', 'S', 'c']


### 3. Tokenize datasets (ALL & Human)

In [None]:
tokenizer = bpe_tokenizer
# Define special tokens                            
special_tokens = {
    "pad_token": "<PAD>",
    "bos_token": "<SOS>",
    "eos_token": "<EOS>",
    "unk_token": "<UNK>"
}

# Add special tokens to the tokenizer
tokenizer.add_special_tokens(special_tokens)

MAX_SEQ_LEN = max([len(s) for s in gff_sequences]) 

def tokenize_function(data):
    # Tokenize the sequences
    output = tokenizer(data["sequences"], truncation=True, padding='max_length', max_length=MAX_SEQ_LEN)
    # Use input_ids as labels
    output["labels"] = output["input_ids"].copy()
    return output

dataset = Dataset.from_dict({"sequences": gff_sequences})
tokenized_datasets = dataset.map(tokenize_function, batched=True)


In [None]:
MAX_SEQ_LEN = max([len(s) for s in mirgendb_sequences])
mirgendb_datset = Dataset.from_dict({"sequences": mirgendb_sequences})
tokenized_mirgendb_datasets = mirgendb_datset.map(tokenize_function, batched=True)

### 4. First training on GFF data

In [None]:
# Define the GPT-2 model

# Create a new GPT-2 configuration and model
config = GPT2Config(
    vocab_size=len(tokenizer),
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id
)
# model = GPT2LMHeadModel.from_pretrained("results_ms/checkpoint-84500")

model = GPT2LMHeadModel(config)

# Define training arguments and train
training_args = TrainingArguments(
    output_dir='./results_ms_all_data', # with flanks : "results_ms"
    num_train_epochs=5,
    per_device_train_batch_size=8,
    logging_dir='./logs_ms_all_data',  # with flanks : "logs_ms"
    learning_rate=5e-5
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
)

trainer.train()

# Save the pretrain model and tokenizer
model.save_pretrained("gff_ms_all_data")  # or no flanks: gff_ms_hairpin_all_data

In [48]:
# # Define the GPT-2 model
# from transformers import GPT2Config, GPT2LMHeadModel

# # Create a new GPT-2 configuration and model
# config = GPT2Config(
#     vocab_size=len(tokenizer),
#     bos_token_id=tokenizer.bos_token_id,
#     eos_token_id=tokenizer.eos_token_id
# )
# model = GPT2LMHeadModel(config)

# # Define training arguments and train
# training_args = TrainingArguments(
#     output_dir='./GPT_mature_star_mirgendb',
#     num_train_epochs=15,
#     per_device_train_batch_size=8,
#     logging_dir='./logs',
#     learning_rate=5e-5
# )
# # mirgendb
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=tokenized_mirgendb_datasets,
# )

# trainer.train()

# # Save the pretrain model and tokenizer
# model.save_pretrained("after_prepocess_mirgendb")
# bpe_tokenizer.save("GPT_mature_star_bpe_tokenizer2.json")


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss
500,0.4705
1000,0.4283
1500,0.4193
2000,0.4113
2500,0.4022
3000,0.3953
3500,0.3914
4000,0.3846
4500,0.3773
5000,0.3726


### 4. Standard Fine-tuning with Second Data

In [31]:
# Load the pre-trained model for fine-tuning
model = GPT2LMHeadModel.from_pretrained("gff_ms_hairpin_all_data")   # withflanks: gff_ms_all_data / gff_ms

# Define fine-tuning arguments with a smaller learning rate
fine_tuning_args = TrainingArguments(
    output_dir='./mirgene_ms_hairpin_all_data_results',  # with flanks : mirgene_ms_all_data_results / mirgene_ms_results
    num_train_epochs=7,
    per_device_train_batch_size=8,
    logging_dir='./fine_tuned_mirgene_ms_hairpin_all_data_logs',    # with flanks : fine_tuned_mirgene_ms_all_data_logs / fine_tuned_mirgene_ms_logs
    learning_rate=1e-5,  # smaller learning rate for fine-tuning
)

# Create a trainer instance for fine-tuning
fine_tuning_trainer = Trainer(
    model=model,
    args=fine_tuning_args,
    train_dataset=tokenized_mirgendb_datasets,
)

# Fine-tune the model
fine_tuning_trainer.train()

# Save the fine-tuned model and tokenizer
model.save_pretrained("mirgene_ms_all_data_hairpin")    # with flanks : mirgene_ms_all_data / mirgene_ms

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss
500,0.1859
1000,0.1682
1500,0.1559
2000,0.1518
2500,0.137
3000,0.138
3500,0.1328
4000,0.1311
4500,0.121
5000,0.1194


### 5. Evaluate Model

In [36]:
# Load the fine-tuned model
model = GPT2LMHeadModel.from_pretrained("mirgene_ms_all_data_hairpin")  # with flanks : "gpt_rna_fine_tuned_ms"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(2134, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Drop

In [12]:
with open('genreated_500_ms_mirgen_gff_after_preprocess.txt', 'w') as f:
    for string in tqdm(generated_sequences):
        string = string.replace('<SOS>','').replace('<EOS>','')
        f.write(string + '\n')

100%|██████████| 500/500 [00:00<00:00, 202135.13it/s]


In [33]:
generated_sequences = []
with open('genreated_500_ms_mirgen_gff_after_preprocess.txt', 'r') as f:
    for line in f:
        # Add each line to the list after stripping newline characters
        generated_sequences.append(line.strip())