In [1]:
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from Bio.SeqUtils import nt_search
from gpt_utils import *
from Bio import pairwise2
from Bio.pairwise2 import format_alignment
import os
rna_fold_path = os.path.expanduser('~/.conda/envs/myenv/bin/RNAfold')
os.environ["PATH"] += os.pathsep + os.path.dirname(rna_fold_path)
import pandas as pd
import json
import subprocess
import RNA
from collections import Counter
import itertools



### Helper Functions


In [13]:
def get_ct_data(sequence, structure, output_file):
    ct_data = {}
    with open(output_file, 'w') as f:
        f.write(f"{len(sequence)} ENERGY\n")
        stack = []
        
        for i, (nt, struct) in enumerate(zip(sequence, structure), start=1):
            prev_nt = i - 1 if i != 1 else 0
            next_nt = i + 1 if i != len(sequence) else 0
            
            if struct == '(':
                stack.append(i)
                ct_data[i] = [i, nt, prev_nt, next_nt, 0, i]  # Temporarily set as unpaired
            elif struct == ')':
                partner = stack.pop()
                f.write(f"{i} {nt} {prev_nt} {next_nt} {partner} {i}\n")
                ct_data[i] = [i, nt, prev_nt, next_nt, partner, i]
                ct_data[partner][4] = i  # Update partner's pairing information
            else:
                f.write(f"{i} {nt} {prev_nt} {next_nt} 0 {i}\n")
                ct_data[i] = [i, nt, prev_nt, next_nt, 0, i]

    return ct_data

def dotbracket_to_ct(sequence, structure):
    # Initialize the list for CT data
    ct_data = []
    
    # Get the pair table from the dot-bracket structure
    # The pair table is 1-indexed and the first element is the length of the RNA
    pair_table = RNA.ptable(structure)
    # Loop through each nucleotide in the sequence
    for i, nt in enumerate(sequence, start=1):
        # Get the pairing partner from the pair table
        partner = pair_table[i]
        
        # Write the CT data: index, nt, prev, next, partner, index
        prev_nt = i - 1 if i > 1 else 0
        next_nt = i + 1 if i < len(sequence) else 0
        
        # Append the data as a tuple or a list
        ct_data.append((i, nt, prev_nt, next_nt, partner, i))
    
    return ct_data

                
def decode_structure(encoded_seq):
    # Reverse the translation_dict to map encoded chars back to their original pairs
    reverse_translation_struct = {v: k[0] for k, v in translation_dict.items()}

    decoded_structure = ''

    i = 0
    while i < len(encoded_seq):
        # Check for special tokens and treat them separately
        if encoded_seq[i:i+5] in {'ZZZZZ', 'BBBBB', 'DDDDD', 'FFFFF'}:
            decoded_structure += encoded_seq[i:i+5]
            i += 5
        elif encoded_seq[i] in reverse_translation_struct:
            decoded_structure += reverse_translation_struct[encoded_seq[i]]
            i += 1
        else:
            decoded_structure += '.'
            i += 1

    return decoded_structure

def is_valid_dot_bracket(structure, sequence):
    stack = []
    valid_pairs = [('U', 'G'), ('G', 'U'), ('C', 'G'), ('G', 'C'), ('A', 'U'), ('U', 'A')]
    invalid_pairs_dict = {}
    
    for i, char in enumerate(structure):
        if char == '(':
            stack.append((i, sequence[i]))
        elif char == ')':
            if not stack:
                return False, invalid_pairs_dict
            else:
                position, nucleotide = stack.pop()
                if (nucleotide, sequence[i]) not in valid_pairs:
                    invalid_pair = f"{nucleotide}{sequence[i]}"
                    invalid_pairs_dict[invalid_pair] = invalid_pairs_dict.get(invalid_pair, 0) + 1
                    
    # Check if there are any unmatched parentheses left
    if stack:
        return False, invalid_pairs_dict
    
    # Check if any invalid pairs were found
    if invalid_pairs_dict:
        return False, invalid_pairs_dict
        
    return True, invalid_pairs_dict


def calculate_novelty(gen_seq, real_sequences):
    local_scores = [(pairwise2.align.localxx(gen_seq, real_seq, score_only=True), real_seq) for real_seq in real_sequences]
    max_score, max_real_seq = max(local_scores, key=lambda x: x[0])
    normalized_max_score = max_score / min(len(gen_seq), len(max_real_seq))
    avg_score = sum(score for score, _ in local_scores) / len(real_sequences)
    return (round(avg_score, 2), round(normalized_max_score, 2))

def calculate_local_alignment(sequence1, sequence2):
    alignments = pairwise2.align.localxx(sequence1, sequence2)
    return alignments[0][2]

def calculate_diversity(gen_seq, gen_sequences):
    local_scores = [(pairwise2.align.localxx(gen_seq, gen1_seq, score_only=True), gen1_seq) for gen1_seq in gen_sequences if gen_seq != gen1_seq]
    max_score, max_real_seq = max(local_scores, key=lambda x: x[0])
    normalized_max_score = max_score / min(len(gen_seq), len(max_real_seq))
    avg_score = sum(score for score, _ in local_scores) / len(gen_sequences)
    return (round(avg_score, 2), round(normalized_max_score, 2))

# Helper function to count connections (in mature or star)
def count_connections(start, end, ct_data):
    connections = 0
    for i in range(start, end + 1):
        if ct_data[i][4] != 0 and not (start <= ct_data[i][4] <= end): 
            connections += 1
    return connections

# Helper function to count connections (in mature or star)
def count_connections_both(start_mature, start_star, end_mature, end_star, ct_data):
    connections = 0
    for i in range(start_mature, end_mature + 1):
        if ct_data[i][4] != 0 and (start_mature <= ct_data[i][0] <= end_mature) and (start_star <= ct_data[i][4] <= end_star): ## TODO: change the range of checking ct_data[i][4] ##
            connections += 1
    return connections

def calculate_max_bulge(mature_range, star_range, ct_data):
    max_bulge_mature = 0
    max_bulge_star = 0
    current_bulge = 0
    in_bulge = False
    bulge_start = -1
    bulge_end = -1
    mature_bulges = []
    star_bulges = []

    for i in range(1, len(ct_data)):
        if ct_data[i][4] == 0:  # If current position is unpaired
            if not in_bulge:
                bulge_start = i-1  # Start of the bulge
                # print(i,bulge_start)
                in_bulge = True
            current_bulge += 1
            # print(ct_data[i+1] , ct_data[i+1][4])
            if i == len(ct_data)-1 or ct_data[i+1][4] != 0:  # If it's the last position or next is paired
                bulge_end = i+1  # End of the bulge
                # print(i,bulge_end)
                # Check if the bulge is completely contained within the mature range
                if bulge_start >= mature_range[0] and bulge_end <= mature_range[1]:
                    # print(f"Mature Bulge found from index {bulge_start} to {bulge_end}. Size: {current_bulge}")
                    mature_bulges.append((bulge_start, bulge_end))
                    max_bulge_mature = max(max_bulge_mature, current_bulge)
                # Check if the bulge is completely contained within the star range
                elif bulge_start >= star_range[0] and bulge_end <= star_range[1]:
                    # print(f"Star Bulge found from index {bulge_start} to {bulge_end}. Size: {current_bulge}")
                    star_bulges.append((bulge_start, bulge_end))
                    max_bulge_star = max(max_bulge_star, current_bulge)
                current_bulge = 0
                in_bulge = False
        else:
            current_bulge = 0
            in_bulge = False

    return {
        "mature_max_bulge": max_bulge_mature,
        "star_max_bulge": max_bulge_star,
        "mature_bulges": mature_bulges,
        "star_bulges": star_bulges
    }


def debug_print(debug, *args):
    if debug:
        print(*args)

def check_seed_family(seed):
    # Load seed families from CSV

    filename = os.path.abspath("/sise/vaksler-group/IsanaRNA/Transformers/Rom/Data_source/seed_family_from_mirgendb.csv")

    # Read CSV using pandas
    df = pd.read_csv(filename, encoding='ISO-8859-1')

    # Create a dictionary from the 'Seed' and 'Family' columns
    seed_family_dict = df.set_index('Seed')['Family'].to_dict()
    
    return seed_family_dict.get(seed, "Unknown")

# Helper function to calculate UG & UGUG
def find_ug_sequences(decoded_seq, mature_start, mature_end, threshold=3):
    # Calculate the search ranges considering the threshold
    ug_search_start = max(mature_end - 14 - threshold, 0)
    ug_search_end = min(mature_start + threshold, len(decoded_seq))
    
    ugug_search_start = max(mature_end + 1 - threshold, 0)
    ugug_search_end = min(mature_end + 3 + threshold, len(decoded_seq))
    
    # Search for 'UG' and 'UGUG' sequences in the calculated ranges
    ug_index = decoded_seq.find('UG', ug_search_start, ug_search_end) + 1  # +1 to make it 1-indexed
    ug = ug_index if ug_index != 0 else "FALSE"
    
    ugug_index = decoded_seq.find('UGUG', ugug_search_start, ugug_search_end) + 1  # +1 to make it 1-indexed
    ugug = ugug_index if ugug_index != 0 else "FALSE"
    
    return ug, ugug

# Helper function to mer features (for sizes 1 or 2)
def calculate_mer_ratios(sequence, mer_size):
    total_length = len(sequence)
    # If mer_size is greater than the sequence length, return an error or handle appropriately
    if mer_size > total_length:
        return "(0.00 .. 0.00)"
    
    # Generate all possible combinations of nucleotides of length mer_size
    possible_mers = [''.join(p) for p in itertools.product('AUCG', repeat=mer_size)]
    mer_counts = Counter([sequence[i:i+mer_size] for i in range(total_length - mer_size + 1)])
    
    # Calculate ratios
    mer_ratios = {mer: mer_counts[mer] / total_length for mer in possible_mers}
    max_ratio = max(mer_ratios.values(), default=0)
    min_ratio = min(mer_ratios.values(), default=0)
    
    return f"({min_ratio:.2f} .. {max_ratio:.2f})"

def calculate_energy(sequence):
    # Calculate the secondary structure and the free energy of the structure
    structure, energy = RNA.fold(sequence)
    
    return energy

def is_valid_dot_bracket(structure, sequence):
    stack = []
    valid_pairs = [('U', 'G'), ('G', 'U'), ('C', 'G'), ('G', 'C'), ('A', 'U'), ('U', 'A')]
    invalid_pairs_dict = {}
    
    for i, char in enumerate(structure):
        if char == '(':
            stack.append((i, sequence[i]))
        elif char == ')':
            if not stack:
                return False, invalid_pairs_dict
            else:
                position, nucleotide = stack.pop()
                if (nucleotide, sequence[i]) not in valid_pairs:
                    invalid_pair = f"{nucleotide}{sequence[i]}"
                    invalid_pairs_dict[invalid_pair] = invalid_pairs_dict.get(invalid_pair, 0) + 1
                    
    # Check if there are any unmatched parentheses left
    if stack:
        return False, invalid_pairs_dict
    
    # Check if any invalid pairs were found
    if invalid_pairs_dict:
        return False, invalid_pairs_dict
        
    return True, invalid_pairs_dict


In [11]:
def run_rnafold(rna_sequence):
    # Start the RNAfold process
    process = subprocess.Popen(['RNAfold'],
                               stdin=subprocess.PIPE,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               text=True)

    # Send the RNA sequence and get the output
    output, error = process.communicate(rna_sequence)

    # Check for errors
    if error:
        print("Error:", error)
    else:
        return output

def retrieve_parts_only_nts(encoded_seq):

    # Identify positions of special tokens
    mature_start = encoded_seq.find('ZZZZZ')
    end_mature_indices = encoded_seq[mature_start + 5:].find('BBBBB')
    mature_end = mature_start + 5 + end_mature_indices

    star_start = encoded_seq.find('DDDDD')
    end_star_indices = encoded_seq[star_start + 5:].find('FFFFF')
    star_end = star_start + 5 + end_star_indices
    
    # Decode mature and star parts
    decoded_mature = encoded_seq[mature_start + 5:mature_end].replace('T', 'U')
    decoded_star = encoded_seq[star_start + 5:star_end].replace('T', 'U')
    
    decoded_mature = decoded_mature.replace('ZZZZZ', '').replace('BBBBB', '').replace('DDDDD', '').replace('FFFFF', '')
    decoded_star = decoded_star.replace('ZZZZZ', '').replace('BBBBB', '').replace('DDDDD', '').replace('FFFFF', '')
    
    intermediate_seq = encoded_seq.replace('ZZZZZ', '').replace('BBBBB', '').replace('DDDDD', '').replace('FFFFF', '')

    # Check for -1 in the find results
    if mature_start == -1 or end_mature_indices == -1 or star_start == -1 or end_star_indices == -1:
        return None

    # Decode full sequence and folding
    full_seq_folding = run_rnafold(intermediate_seq)
    if full_seq_folding:
        # print(full_seq_folding) 
        full_seq_folding = full_seq_folding.split('\n')[-2][:-9] # extract nesscery structre from rnafold output
        # print(len(full_seq_folding), len(intermediate_seq)) 
    else:
        return -1
    full_seq = intermediate_seq.replace('T', 'U')


    return {
        'encoded_seq': encoded_seq,
        'decoded_seq': full_seq,
        'mature': decoded_mature,
        'star': decoded_star,
        'full_seq_folding': full_seq_folding
    }


def adjust_index(index):
    return index + 1

def extract_features_only_nts(decoded_seq, full_seq_folding, mature, star):
    # Save CT file
    if is_valid_dot_bracket(full_seq_folding,decoded_seq):
        ct_data = get_ct_data(decoded_seq, full_seq_folding, "output.ct")
    else:
        return None

    # Locate mature and star sequences within the full sequence
    mature_start = decoded_seq.find(mature)
    mature_end = mature_start + len(mature) - 1  # Adjust to 0-based indexing
    star_start = decoded_seq.find(star)
    star_end = star_start + len(star) - 1  # Adjust to 0-based indexing
    if mature_start < 0 or star_start < 0:
        print('mature ot start invalid')
        return None

    if mature_start < star_start:
        loop = decoded_seq[mature_end+1:star_start]
        flank1 = decoded_seq[:mature_start]
        flank2 = decoded_seq[star_end+1:]
        direction = '5p'
    else:
        loop = decoded_seq[star_end+1:mature_start]
        flank1 = decoded_seq[:star_start]
        flank2 = decoded_seq[mature_end+1:]
        direction = '3p'

    # Compute the seed section
    seed_start = mature_start + 2  # 2nd nucleotide, 0-indexed
    seed_end = seed_start + 7  # 7th nucleotide, 0-indexed
    seed = decoded_seq[seed_start-1:seed_end-1]  
    seed_family = check_seed_family(seed) # loading dict from shared folder 

    # Calculate the indices for Loop, flank1, and flank2
    loop_start_index = decoded_seq.find(loop) + 1
    loop_end_index = loop_start_index + len(loop) - 1

    flank1_start_index = decoded_seq.find(flank1) + 1
    flank1_end_index = flank1_start_index + len(flank1) - 1

    flank2_start_index = decoded_seq.find(flank2) + 1
    flank2_end_index = flank2_start_index + len(flank2) - 1
    
    mature_range, star_range = (adjust_index(mature_start), adjust_index(mature_end)), (adjust_index(star_start), adjust_index(star_end))
    
    # Feature calculations
    # print(mature_start, mature_end , len(ct_data))
    mature_connections = count_connections(adjust_index(mature_start), adjust_index(mature_end), ct_data)
    
    mature_star_connections = count_connections_both(adjust_index(mature_start), adjust_index(star_start), adjust_index(mature_end), adjust_index(star_end), ct_data)

    mature_bp_ratio = mature_connections / len(mature) if len(mature) != 0 else 0
    calculate_max_bulge_result = calculate_max_bulge(mature_range, star_range, ct_data)
    mature_max_bulge, star_max_bulge = calculate_max_bulge_result["mature_max_bulge"], calculate_max_bulge_result["star_max_bulge"]
    mature_bulges = calculate_max_bulge_result["mature_bulges"]
    star_bulges = calculate_max_bulge_result["star_bulges"]
    # mature_max_asymmetry = calculate_bulge_asymmetry(mature_bulges, ct_data , 'mature')
    # print('mature',mature_bulges)
    # star_max_asymmetry = calculate_bulge_asymmetry(star_bulges, ct_data,'star')
    # print('star',star_bulges)
    star_connections = count_connections(adjust_index(star_start), adjust_index(star_end), ct_data)
    star_bp_ratio = star_connections / len(star) if len(star) != 0 else 0
    ug, ugug = find_ug_sequences(decoded_seq, mature_start, mature_end, threshold=10)
    hairpin_trimmed = decoded_seq.replace(flank1, '').replace(flank2,'')
    h_start = adjust_index(decoded_seq.find(hairpin_trimmed))
    h_end = mature_start + len(hairpin_trimmed)
    energy = calculate_energy(hairpin_trimmed)
    one_mer_mature = calculate_mer_ratios(mature, 1)
    two_mer_mature = calculate_mer_ratios(mature, 2)
    one_mer_h_trimm = calculate_mer_ratios(hairpin_trimmed, 1)
    two_mer_h_trimm = calculate_mer_ratios(hairpin_trimmed, 2)
    one_mer_full = calculate_mer_ratios(decoded_seq, 1)
    two_mer_full = calculate_mer_ratios(decoded_seq, 2)

    feature_dict = {
        'full_seq': decoded_seq,
        'Mature': mature,
        'Mature_length': len(mature),
        'Star': star,
        'End_star': adjust_index(star_end),
        'full_seq_folding': full_seq_folding,
        'Loop_seq': loop,
        'Loop_length': len(loop),
        'flank1': flank1,
        'flank2': flank2,
        'flank1_length': len(flank1),
        'flank2_length': len(flank2),
        'Mature_start': adjust_index(mature_start),
        'Mature_end': adjust_index(mature_end),
        'Star_start': adjust_index(star_start),
        'Star_end': adjust_index(star_end),
        'Star_length': len(star),
        'Mature_length': len(mature),
        'Loop_seq_start': loop_start_index,
        'Loop_seq_end': loop_end_index,
        'flank1_start': flank1_start_index,
        'flank1_end': flank1_end_index,
        'flank2_start': flank2_start_index,
        'flank2_end': flank2_end_index,
        'seed': seed,
        'seed_start': seed_start,
        'seed_end': seed_end,
        'seed_family': seed_family,
        '3p/5p': direction,
        'Mature_connections': mature_connections,
        'Mature_Star_connections': mature_star_connections,
        'Mature_BP_ratio': round(mature_bp_ratio, 2),
        'Mature_max_bulge': mature_max_bulge,
        'Star_connections': star_connections,
        'Star_BP_ratio': round(star_bp_ratio, 2),
        'Star_max_bulge': star_max_bulge,
        'UG': ug,
        'UGUG': ugug,
        'hairpin_trimmed': hairpin_trimmed,
        'hairpin_trimmed_length': len(hairpin_trimmed),
        'folding_energy': round(energy, 2),
        'one_mer_mature': one_mer_mature,
        'two_mer_mature': two_mer_mature,
        'one_mer_hairpin_trimmed': one_mer_h_trimm ,
        'two_mer_hairpin_trimmed': two_mer_h_trimm ,
        'one_mer_full': one_mer_full,
        'two_mer_full': two_mer_full,
    }

    return feature_dict

def process_rna_sequences(rna_sequences):
    results = []
    invalid_results = []
    
    summary = {
        'multiple_mature': 0,
        'multiple_star': 0,
        'missing_mature': 0,
        'missing_star': 0,
        'invalid_dot_bracket': 0,
        'invalid_pairs': {},
        'total_sequences': len(rna_sequences),
        'unique_before_filtering': len(set(rna_sequences)),  # Number of unique sequences before filtering
        'unique_after_filtering': 0,  # To be calculated after filtering
        'unique_invalid': 0,
        'failed_rnafold': 0  
    }
    
    for seq in tqdm(rna_sequences, desc="Processing sequences"):
        if 'ZZZZZ' not in seq or 'BBBBB' not in seq:
            summary['missing_mature'] += 1
        elif seq.count('ZZZZZ') > 1 or seq.count('BBBBB') > 1:
            summary['multiple_mature'] += 1

        if 'DDDDD' not in seq or 'FFFFF' not in seq:
            summary['missing_star'] += 1
        elif seq.count('DDDDD') > 1 or seq.count('FFFFF') > 1:
            summary['multiple_star'] += 1

        parts = retrieve_parts_only_nts(seq)  # Ensure retrieve_parts is defined somewhere
        if parts:
            if parts == -1:
                summary['failed_rnafold'] += 1
                continue
            is_valid, invalid_pairs = is_valid_dot_bracket(parts['full_seq_folding'], parts['decoded_seq'])
            if is_valid:
                results.append(parts)
            else:
                summary['invalid_dot_bracket'] += 1
                invalid_result = {'sequence': parts['full_seq_folding'], 'encoded_seq': seq,'invalid_pairs':invalid_pairs}
                invalid_result.update(invalid_pairs)
                invalid_results.append(invalid_result)
                for pair, count in invalid_pairs.items():
                    summary['invalid_pairs'][pair] = summary['invalid_pairs'].get(pair, 0) + count

    results_df = pd.DataFrame(results).drop_duplicates(subset='decoded_seq')
    invalid_results_df = pd.DataFrame(invalid_results)
    
    # Calculate unique values after filtering
    summary['unique_after'] = results_df['decoded_seq'].nunique() if 'decoded_seq' in results_df.columns else 0
    summary['unique_invalid'] = invalid_results_df['sequence'].nunique() if 'sequence' in invalid_results_df.columns else 0
    
    return results_df, invalid_results_df, summary

def extract_features_from_df(rna_df):
    features = []
    # Add tqdm around your loop
    for index, row in tqdm(rna_df.iterrows(), total=len(rna_df), desc="Extracting features"):
        feature = extract_features_only_nts(row['decoded_seq'], row['full_seq_folding'], row['mature'], row['star'])
        if feature:
            features.append(feature)

    features_df = pd.DataFrame(features)
    return features_df

def remove_flanks(full_seq, flank1, flank2):
    result = full_seq.replace(flank1, "").replace(flank2, "")
    return result
    
def add_pre_mirna_and_remove_duplicates(rna_df):
    rna_df['pre_mirna'] = rna_df.apply(lambda row: remove_flanks(row['full_seq'], row['flank1'], row['flank2']), axis=1)
    rna_df = rna_df.drop_duplicates(subset='pre_mirna', keep='first')
    return rna_df

def remove_duplicates(rna_df):
    rna_df['pre_mirna'] = rna_df['full_seq']
    rna_df = rna_df.drop_duplicates(subset='pre_mirna', keep='first')
    return rna_df

    
def add_diversity_novelty_to_df(rna_df, real_sequences_no_flanks):
    # Calculate local and global diversity scores
    tqdm.pandas() # This is for using tqdm with pandas' apply method

    rna_df['avg_diversity'], rna_df['max_diversity'] = zip(*rna_df['pre_mirna'].progress_apply( lambda x: calculate_diversity(x, rna_df['pre_mirna'].tolist())))
    rna_df['avg_novelty'], rna_df['max_novelty'] = zip(*rna_df['pre_mirna'].progress_apply(lambda x: calculate_novelty(x, real_sequences_no_flanks)))
    return rna_df

def add_diversity_to_original_df(df, df2):
    df['avg_diversity'], df['max_diversity'] = zip(*df['full_seq'].progress_apply( lambda x: calculate_diversity(x, df2['full_seq'].tolist())))

    return df

### Add Pre-mirna & mers & connections columns to original

In [4]:
# add mers and connections to original data:

def extract_features_only_mers(full_seq, full_seq_folding, mature, star):
    ct_data = get_ct_data(full_seq, full_seq_folding, "output_original.ct")

    mature_start = full_seq.find(mature)
    mature_end = mature_start + len(mature) - 1  # Adjust to 0-based indexing
    star_start = full_seq.find(star)
    star_end = star_start + len(star) - 1  # Adjust to 0-based indexing
    
    if mature_start < 0 or star_start < 0:
        print('mature ot start invalid')
        return None

    if mature_start < star_start:
        loop = full_seq[mature_end+1:star_start]
        flank1 = full_seq[:mature_start]
        flank2 = full_seq[star_end+1:]
        direction = '5p'
    else:
        loop = full_seq[star_end+1:mature_start]
        flank1 = full_seq[:star_start]
        flank2 = full_seq[mature_end+1:]
        direction = '3p'
        
    hairpin_trimmed = full_seq.replace(flank1, '').replace(flank2,'')
    one_mer_mature = calculate_mer_ratios(mature, 1)
    two_mer_mature = calculate_mer_ratios(mature, 2)
    one_mer_h_trimm = calculate_mer_ratios(hairpin_trimmed, 1)
    two_mer_h_trimm = calculate_mer_ratios(hairpin_trimmed, 2)
    one_mer_full = calculate_mer_ratios(full_seq, 1)
    two_mer_full = calculate_mer_ratios(full_seq, 2)
    
    mature_star_connections = count_connections_both(adjust_index(mature_start), adjust_index(star_start), adjust_index(mature_end), adjust_index(star_end), ct_data)

    feature_dict = {
        'one_mer_mature': one_mer_mature,
        'two_mer_mature': two_mer_mature,
        'one_mer_hairpin_trimmed': one_mer_h_trimm ,
        'two_mer_hairpin_trimmed': two_mer_h_trimm ,
        'one_mer_full': one_mer_full,
        'two_mer_full': two_mer_full,
        'Mature_Star_connections': mature_star_connections,
    }
    
    return feature_dict
    
def extract_mers_from_df(rna_df):
    features = []
    # Add tqdm around your loop
    for index, row in tqdm(rna_df.iterrows(), total=len(rna_df), desc="Extracting features"):
        feature = extract_features_only_mers(row['full_seq'], row['full_seq_folding'], row['Mature'], row['Star'])
        if feature:
            features.append(feature)
    new_features_df = pd.DataFrame(features)
    features_df = pd.concat([rna_df, new_features_df], axis=1)

    return features_df

In [2]:
import math

In [3]:
def calculate_statistics(sequences):
    num_sequences = len(sequences)
    avg_length = int(sum(len(seq) for seq in sequences) / num_sequences) # if isinstance(seq, str)
    min_length = min(len(seq) for seq in sequences)
    max_length = max(len(seq) for seq in sequences)

    return {
        "Number of sequences": num_sequences,
        "Average sequence length": avg_length,
        "Min sequence length": min_length,
        "Max sequence length": max_length
    }


## Extract Features

In [4]:
def fasta_to_tuples(filename):
    with open(filename, 'r') as f:
        content = f.read().strip().split("\n")

    sequences = []
    for i in range(0, len(content), 2):
        header = content[i][1:]  # remove '>'
        sequence = content[i+1]
        sequences.append((header, sequence))

    return sequences

def tuples_to_jsonl(sequences, output_filename):
    with open(output_filename, 'w') as f:
        for header, sequence in sequences:
            data = {
                "id": header,
                "sequence": sequence
            }
            f.write(json.dumps(data) + '\n')

#### test set

In [5]:
with open('/sise/vaksler-group/IsanaRNA/Transformers/GPT_env/generated_100_test_human.txt', 'r') as f: # genreated_new_10000_human_mirgenedb_gff_ms, 'genreated_52_split_human_mirgenedb_gff_ms.txt'
    generated_sequences = f.readlines()
real_sequences_path = '/sise/vaksler-group/IsanaRNA/Transformers/GPT_env/seq_clusters/human_test_full_seqs.txt'
with open(real_sequences_path, "r") as file:
    real_sequences = file.readlines()
generated_sequences = [s.strip() for s in generated_sequences] ; real_sequences = [s.strip() for s in real_sequences]
generated_sequences = [s for s in generated_sequences if 'N' not in s]


#### All data:

In [7]:
with open('/sise/vaksler-group/IsanaRNA/Transformers/GPT_env/generated_100_full_data_hairpin_human.txt', 'r') as f: # genreated_new_10000_human_no_flanks_mirgenedb_gff_ms / genreated_new_10000_human_mirgenedb_gff_ms, 'genreated_52_split_human_mirgenedb_gff_ms.txt'
    generated_sequences = f.readlines()
with open('/sise/vaksler-group/IsanaRNA/Transformers/GPT_env/human_fine_tuned_original_data.txt', 'r') as f: # 
    real_sequences = f.readlines()
    # # # without flanks:
human_data_fasta = fasta_to_tuples("/sise/vaksler-group/IsanaRNA/Transformers/GPT_env/Data_source/miRGeneDB/precursors_human_no_flank.fas.txt")
real_sequences_no_flanks = [s[1] for s in human_data_fasta]
real_sequences_no_flanks = [s.strip() for s in real_sequences_no_flanks]
# Remove the newline characters at the end of each line
generated_sequences = [s.strip() for s in generated_sequences] ; real_sequences = [s.strip() for s in  real_sequences]
generated_sequences = [s for s in generated_sequences if 'N' not in s]

('ZZZZZTATACCTCAGTTTTATCAGGTGBBBBBTTCTTAAAATCADDDDDCCTGGAAACACTGAGGTTGTGTFFFFF',
 'GAAGAAGAAGACCCAAUGCCCGGGGAGAAGUACGGUGAGCCUGUCAUUAUUCAGAGAGGCUAGAUCCUCUGUGUUGAGAAGGAUCAUGAUGGGCUCCUCGGUGUUCUCCAGGUAGCGGCACCACACCAUGAAGG',
 'UGAGGUAGUAGUUUGUGCUGUUGGUCGGGUUGUGACAUUGCCCGCUGUGGAGAUAACUGCGCAAGCUACUGCCUUGC')

In [14]:
rna_sequences = generated_sequences
rna_df, invalid_results_df, summary = process_rna_sequences(rna_sequences)

# # Create directory for results if it doesn't exist
output_dir = "./generated_human_all_data_hairpin_seq" #generated_human_test_seq / generated_no_flanks_results_only_nts / generated_spilt_52_results_only_nts
os.makedirs(output_dir, exist_ok=True)

# Save results to CSV files
rna_df.to_csv(os.path.join(output_dir, "human_gen_all_sequences.csv"), index=False) # rna_sequences #intermidiate results only with parts 
invalid_results_df.to_csv(os.path.join(output_dir, "invalid_sequences_human_all_gen.csv"), index=False) #invalid_sequences # invalid sequences with 

# Save summary to a JSON file
with open(os.path.join(output_dir, "summary_human_all_gen.json"), 'w') as f:  #summary / summary_human_gen_completion_hairpin
    json.dump(summary, f)

print("Results saved to", output_dir)
print("Summary:", summary)
print("Invalid Results:")
print(invalid_results_df)

Processing sequences: 100%|██████████| 100/100 [00:00<00:00, 175.81it/s]


Results saved to ./generated_human_all_data_hairpin_seq
Summary: {'multiple_mature': 6, 'multiple_star': 2, 'missing_mature': 12, 'missing_star': 8, 'invalid_dot_bracket': 0, 'invalid_pairs': {}, 'total_sequences': 100, 'unique_before_filtering': 100, 'unique_after_filtering': 0, 'unique_invalid': 0, 'failed_rnafold': 0, 'unique_after': 80}
Invalid Results:
Empty DataFrame
Columns: []
Index: []


In [None]:
# Extract features and save to CSV
features_df = extract_features_from_df(rna_df)
print('Finish extract features\nAdding diversity values:')
# features_df = add_diversity_to_df(features_df, real_sequences)

# no flanks - only pre-mirna
# to check the diversity and novelity between the sequences without considering the flanks:
# features_df = add_pre_mirna_and_remove_duplicates(features_df) # with flanks
# features_df = remove_duplicates(features_df)  # for no flanks

# features_df = add_diversity_novelty_to_df(features_df, real_sequences)
features_df.to_csv(os.path.join(output_dir, "gen_human_test_features.csv"), index=False)
features_df