In [1]:
import os
# NOTE(review): this is a hard-coded absolute path to one developer's checkout.
# Every relative path used later in this notebook ("./primaries/...",
# "./outputs/...", "./tmp", "./phyml") is resolved against this directory, so
# the notebook only runs unmodified on a machine where this path exists.
os.chdir("/Users/aliahmadi/Documents/Projects/RNA-Secondary-Structure-Prediction/notebook/EKH-25")

## Imports

In [2]:
from collections import defaultdict
from scipy.linalg import expm
from grammar.pcnf import PCNF
from Bio import Phylo, SeqIO
from copy import deepcopy
from io import StringIO
import networkx as nx
from math import log 
import numpy as np
import shutil
import pickle

___

## IUPAC Nucleotide table

In [3]:
# Maps each IUPAC RNA symbol (and the gap symbol "-") to the list of concrete
# symbols it may stand for. Each value is a fresh list, in the order given.
iupac_nucleotides = {
    code: list(bases)
    for code, bases in (
        ("A", "A"),
        ("C", "C"),
        ("G", "G"),
        ("U", "U"),
        ("R", "AG"),
        ("Y", "CU"),
        ("S", "GC"),
        ("W", "AU"),
        ("K", "GU"),
        ("M", "AC"),
        ("B", "CGU"),
        ("D", "AGU"),
        ("H", "ACU"),
        ("V", "ACG"),
        ("N", "ACGU"),
        ("-", "-"),
    )
}

## Global Variables

In [4]:
# Phylogenetic trees keyed by dataset name (filled by open_tree). Looking up a
# missing key creates and stores a tree parsed from the empty Newick string "();".
trees = defaultdict(lambda: Phylo.read(StringIO("();"), "newick"))
# Alignments keyed by dataset name (filled by read_sequences). Each value maps
# a record id to its sequence, plus the special "_weight" entry; missing
# entries default to an empty string.
sequences = defaultdict(lambda: defaultdict(str))
# Structure strings keyed by dataset name (filled by read_structure); a missing
# key yields "".
structures = defaultdict(str)

## Pairing Characters

In [5]:
def get_pair_start(pair_end):
    """Return the opening bracket that corresponds to a closing bracket.

    Args:
        pair_end: One of ">", ")", "]" or "}".

    Returns:
        The matching opening symbol ("<", "(", "[" or "{").

    Raises:
        IndexError: If ``pair_end`` is not a known closing symbol.
    """
    for opener, closer in (("<", ">"), ("(", ")"), ("[", "]"), ("{", "}")):
        if closer == pair_end:
            return opener
    # Same exception the original list-indexing implementation produced.
    raise IndexError("list index out of range")

In [6]:
def is_pair_start(char):
    """Return True when ``char`` is one of the opening symbols "<", "(", "[", "{"."""
    opening_symbols = ("<", "(", "[", "{")
    return char in opening_symbols

In [7]:
def is_pair_end(char):
    """Return True when ``char`` is one of the closing symbols ">", ")", "]", "}"."""
    closing_symbols = (">", ")", "]", "}")
    return char in closing_symbols

## Verify Sequences

In [8]:
def is_valid_str(input_string, valid_characters = ['A','C', 'G', 'U']):
    """Tell whether every symbol of ``input_string`` is an allowed character.

    Args:
        input_string: Iterable of symbols to check (an empty input is valid).
        valid_characters: Container of accepted symbols; defaults to the four
            RNA bases. The default list is only read, never modified.

    Returns:
        True if no symbol falls outside ``valid_characters``, else False.
    """
    return all(symbol in valid_characters for symbol in input_string)

## Read Trees

In [9]:
def clade_names_fix(tree):
    """Give every unnamed clade of ``tree`` a name, in place.

    A clade whose ``name`` is falsy receives its position in the
    ``tree.find_clades()`` traversal, converted to a string. Clades that
    already have a name are left untouched.
    """
    nameless = (
        (position, clade)
        for position, clade in enumerate(tree.find_clades())
        if not clade.name
    )
    for position, clade in nameless:
        clade.name = str(position)

def read_tree(filename: str):
    """Parse the Newick file ``filename`` and return the tree.

    Unnamed clades are named via ``clade_names_fix`` before the tree is
    returned.
    """
    parsed_tree = Phylo.read(filename, 'newick')
    clade_names_fix(parsed_tree)
    return parsed_tree

def open_tree(filename: str, dataset_name):
    """Load a Newick tree and register it in the global ``trees`` mapping.

    The file is parsed and its unnamed clades are named (see ``read_tree``);
    the result is stored under ``dataset_name``. Nothing is returned.
    """
    trees[dataset_name] = read_tree(filename)

## Read Sequences

In [10]:
def read_sequences (filename: str, dataset_name):
    """Read a relaxed-PHYLIP alignment into the global ``sequences`` mapping.

    Every record is stored under its id. The extra key ``"_weight"`` holds
    ``1000 / <number of records>``. An alignment with no records raises
    ZeroDivisionError.
    """
    with open(filename) as handle:
        entries = [
            (record.id, record.seq)
            for record in SeqIO.parse(handle, "phylip-relaxed")
        ]

    dataset = defaultdict(str)
    for record_id, record_seq in entries:
        dataset[record_id] = record_seq

    # The weight is computed from the number of records read, not the number
    # of distinct ids.
    dataset["_weight"] = 1000 / len(entries)

    sequences[dataset_name] = dataset

## Read Structures

In [11]:
def read_structure(filename: str, dataset_name):
    """Store the first line of ``filename`` in the global ``structures`` mapping.

    Surrounding whitespace is stripped; any further lines are ignored.
    """
    with open(filename) as handle:
        first_line = handle.readline()
    structures[dataset_name] = first_line.strip()

## Simplify Structure

In [12]:
def simplify_struct(structures, filename="./primaries/structures"):
    """Reduce structures to "d"/"s" symbols and write them to a file.

    Each structure becomes one line in which every bracket character
    (``< > ( ) [ ] { }``) is written as ``"d "`` and every other character as
    ``"s "``. Lines are joined with newlines and trailing newlines are
    removed.

    Args:
        structures: Mapping whose values are structure strings (keys unused).
        filename: Output path without extension; the text is written to
            ``f"{filename}.train"``. An empty string disables writing.

    Returns:
        None.
    """
    pairing_chars = ["<", ">", "(", ")", "[", "]", "{", "}"]

    rows = [
        "".join("d " if symbol in pairing_chars else "s " for symbol in structure)
        for structure in structures.values()
    ]
    simplified = "".join(row + "\n" for row in rows).rstrip('\n')

    if filename != "":
        with open(f"{filename}.train", "w+") as file:
            file.write(simplified)

## Create Tree

In [13]:
def create_tree(input_sequences, filename = "./outputs/tree.nwk", draw=True):
  names = list(input_sequences.keys())
  sequences = list(input_sequences.values())
  
  if len(names) <= 2:
    return None
  
  os.mkdir("./tmp")

  if len(names) > 2:
    phylip_file = "./tmp/sequences.phylip"
    with open(phylip_file, "w") as f:
      f.write(f"{len(sequences)} {len(sequences[0])}\n\n")
      for i, seq in enumerate(sequences):
        f.write(f"{names[i]}\t{seq}\n") 
    
    !./phyml -i tmp/sequences.phylip -m GTR
    
    output_tree = Phylo.read(phylip_file + "_phyml_tree.txt", 'newick')

    output_tree.root_at_midpoint()
  
  for index, clade in enumerate(output_tree.find_clades()):
    if not clade.name:
      clade.name = str(index)
  
  shutil.rmtree("./tmp")
  Phylo.write(output_tree, filename, "newick")
  
  Phylo.draw(output_tree)
  
  return output_tree

## Calculate Frequencies

In [14]:
def calc_frequencies (sequences, structures):
    """Estimate weighted nucleotide and base-pair frequencies.

    Every sequence of every dataset is walked column by column against the
    dataset's structure string. Columns whose structure symbol is not a
    bracket contribute to the single-nucleotide counts. Columns joined by a
    matching bracket pair contribute to the pair counts, in both orientations
    ("AU" and "UA"). IUPAC ambiguity codes are expanded through
    ``iupac_nucleotides`` and the dataset weight is split evenly over the
    expansion. Candidates containing anything other than A, C, G, U (for
    example gaps) are skipped.

    Fix: the unpaired branch used to validate the raw alignment symbol rather
    than each expanded candidate, so ambiguous symbols were dropped entirely
    even though the weight was divided by the expansion size. It now validates
    each candidate, as the paired branch does.

    Args:
        sequences: Mapping dataset name -> mapping of record id -> sequence,
            with the dataset weight stored under the key ``"_weight"``.
        structures: Mapping dataset name -> structure string; it is indexed at
            every sequence position.

    Returns:
        A 4-tuple ``(single_nucleotides, paired_nucleotides, singles_share,
        paireds_share)``: the two normalised frequency tables, followed by the
        fractions of the total weighted count that fall in unpaired and in
        paired positions.

    Raises:
        ValueError: If a closing bracket does not match the most recently
            opened one.
    """
    single_nucleotides = defaultdict(int)
    paired_nucleotides = defaultdict(int)

    total_singles = 0
    total_paireds = 0

    # Stack of still-open positions: (opening structure symbol, candidate nucleotides)
    unpaired_nucleotides = []
    for dataset_name, dataset_values in sequences.items():
        sequences_weight = dataset_values["_weight"]
        dataset_sequences = [v for k, v in dataset_values.items() if k != "_weight"]
        dataset_structure = structures[dataset_name]
        for sequence in dataset_sequences:
            for index, nucleotide in enumerate(sequence):
                structure_symbol = dataset_structure[index]
                # Opening character: one of '(', '[', '{', '<'
                if is_pair_start(structure_symbol):
                    unpaired_nucleotides.append((structure_symbol, iupac_nucleotides[nucleotide]))
                # Closing character: one of ')', ']', '}', '>'
                elif is_pair_end(structure_symbol):
                    unpaired_nucleotide = unpaired_nucleotides.pop()
                    if unpaired_nucleotide[0] == get_pair_start(structure_symbol):
                        targets = [f"{nucl1}{nucl2}"
                            for nucl1 in unpaired_nucleotide[1]
                            for nucl2 in iupac_nucleotides[nucleotide]
                        ]
                        share = sequences_weight / len(targets)
                        for target in targets:
                            if is_valid_str(target):
                                # Count the pair in both orientations.
                                paired_nucleotides[target] += share
                                paired_nucleotides[target[::-1]] += share

                                total_paireds += share
                    else:
                        raise ValueError('Invalid pattern in structure')
                # Non-pairing character
                else:
                    targets = iupac_nucleotides[nucleotide]
                    share = sequences_weight / len(targets)
                    for target in targets:
                        # Validate the expanded candidate, not the raw symbol.
                        if is_valid_str(target):
                            single_nucleotides[target] += share

                            total_singles += share

    # Each pair covers two alignment columns, hence the factor of two.
    total = total_singles + (total_paireds * 2)

    for key, value in single_nucleotides.items():
        single_nucleotides[key] = value / total_singles
    for key, value in paired_nucleotides.items():
        paired_nucleotides[key] = value / (total_paireds * 2)

    return (
        single_nucleotides,
        paired_nucleotides,
        total_singles / total,
        (total_paireds * 2) / total
    )

## Check Similarity

In [15]:
def check_simularity(seq1:str, seq2:str):
    """Tell whether two aligned sequences are at least 85% identical.

    Identity is the fraction of positions at which both sequences carry
    exactly the same symbol.

    Args:
        seq1: First sequence.
        seq2: Second sequence.

    Returns:
        True when the sequences have equal, non-zero length and at least 85%
        of their positions match; False otherwise. Sequences of different
        length now yield an explicit False (previously an implicit None).
        Two empty sequences also yield False (previously ZeroDivisionError),
        because there is no column to compare.
    """
    length = len(seq1)
    if length != len(seq2) or length == 0:
        return False

    matches = sum(1 for first, second in zip(seq1, seq2) if first == second)
    return matches / length >= .85

In [16]:
def check_similarity(seq1: str, seq2: str) -> bool:
    """Tell whether two aligned sequences reach a similarity ratio of 0.85.

    Each position is scored by expanding both symbols through
    ``iupac_nucleotides`` (an unknown symbol stands for itself) and averaging
    over all candidate combinations: identical candidates score 1, two
    pyrimidines (C/U) or two purines (A/G) score 0.5, anything else scores 0.
    The ratio is the mean position score.

    Returns:
        False when the lengths differ; otherwise whether the ratio is >= 0.85.
    """
    if len(seq1) != len(seq2):
        return False  # Sequences must be of the same length

    pyrimidines = {'C', 'U'}  # Cytosine, Uracil
    purines = {'A', 'G'}      # Adenine, Guanine

    def same_class(first, second):
        # True when both candidates are pyrimidines or both are purines.
        return (first in pyrimidines and second in pyrimidines) or (first in purines and second in purines)

    similarity = 0
    for symbol1, symbol2 in zip(seq1, seq2):
        candidates1 = iupac_nucleotides.get(symbol1, [symbol1])
        candidates2 = iupac_nucleotides.get(symbol2, [symbol2])
        combinations = len(candidates1) * len(candidates2)

        # Accumulate in the same order as a plain nested loop so the floating
        # point result is unchanged.
        match_score = 0
        for n1 in candidates1:
            for n2 in candidates2:
                if n1 == n2:
                    match_score += 1 / combinations
                elif same_class(n1, n2):
                    match_score += 0.5 / combinations

        similarity += match_score

    return similarity / len(seq1) >= 0.85

## Calculate Rate Values

In [17]:
def calc_rate_values(
    trees,
    sequences: defaultdict, 
    structures, 
    single_frequencies, 
    paired_frequencies, 
    singles_prob, 
    paireds_prob
):
    """Estimate substitution rate values for unpaired and paired positions.

    For every dataset, each ordered pair (i, j) of distinct sequences that
    passes ``check_simularity`` is compared column by column against the
    dataset's structure string:

    * Unpaired columns: both symbols are expanded through
      ``iupac_nucleotides``. Every valid (A/C/G/U) combination adds
      ``weight / (n1 * n2)`` to ``columns_count``; combinations that differ
      also add that amount to the single-nucleotide substitution count.
    * Paired columns (a closing bracket matched with the stacked opening
      bracket): the two-symbol candidates of both sequences are combined.
      Every valid combination adds twice that amount to ``columns_count``;
      differing combinations are counted in both orientations.

    For each compared pair, ``tree distance * columns_count`` is added to a
    per-sequence total. The per-sequence totals and counts are divided by the
    number of similar partners of sequence i before being added to the global
    accumulators.

    Finally each off-diagonal rate is
    ``count / (position probability * frequency of the source * k_value)``
    (the paired count is doubled first) and each diagonal entry is the
    negative sum of its row's off-diagonal entries.

    Args:
        trees: Mapping dataset name -> tree offering ``distance(name1, name2)``.
        sequences: Mapping dataset name -> mapping of record id -> sequence,
            with the dataset weight under ``"_weight"``.
        structures: Mapping dataset name -> structure string.
        single_frequencies: Frequency per nucleotide, indexed by "A"/"C"/"G"/"U".
        paired_frequencies: Frequency per pair, indexed by two-letter strings.
        singles_prob: Weight share of unpaired positions.
        paireds_prob: Weight share of paired positions.

    Returns:
        ``(single_rate_values, paired_rate_values)``: two defaultdicts keyed
        by ``(source, target)`` tuples.

    Raises:
        ValueError: If a closing bracket does not match the stacked opening
            bracket.

    NOTE(review): the final divisions have no guard. A zero ``k_value`` (no
    similar pair found) or a zero frequency/probability would presumably
    raise ZeroDivisionError with plain Python numbers — confirm.
    """
    # Stack of open pairing positions; created once and shared by all comparisons.
    unpaired_nucleotides = []
    
    single_mutation_count = defaultdict(float)
    paired_mutation_count = defaultdict(float)
    # Accumulated (tree distance x compared weighted columns), averaged per sequence.
    k_value = 0
    
    for dataset_name, dataset_values in sequences.items():
        sequences_weight = dataset_values["_weight"]
        dataset_sequences = [(k, v) for k, v in dataset_values.items() if k != "_weight"]
        dataset_structure = structures[dataset_name]
        
        for i in range(len(dataset_sequences)):
            # Per-first-sequence accumulators, averaged over its partners below.
            temp_single_mutation_count = defaultdict(int)
            temp_paired_mutation_count = defaultdict(int)
            k_temp = 0
            
            # Number of partners j that were similar enough to sequence i.
            same_first_sequence_count = 0
            
            for j in range(len(dataset_sequences)):
                # Weighted number of valid columns compared for this (i, j) pair.
                columns_count = 0
                
                first_name = dataset_sequences[i][0]
                second_name = dataset_sequences[j][0]
                first_sequence = dataset_sequences[i][1]
                second_sequence = dataset_sequences[j][1]
                # The pair must consist of two different sequences
                # with at least 85% similarity.
                if i != j and check_simularity(first_sequence, second_sequence):
                    same_first_sequence_count += 1
                    for k in range(len(first_sequence)): 
                        structure_symbol = dataset_structure[k]
                        # Opening character: one of '(', '[', '{', '<'
                        if is_pair_start(structure_symbol):
                            unpaired_nucleotides.append((
                                structure_symbol, 
                                iupac_nucleotides[first_sequence[k]],
                                iupac_nucleotides[second_sequence[k]]
                            ))
                        # Closing character: one of ')', ']', '}', '>'
                        elif is_pair_end(structure_symbol):
                            unpaired_nucleotide = unpaired_nucleotides.pop()
                            
                            # All two-letter candidates for the pair in each sequence.
                            first_side_targets = [f"{nucl1}{nucl2}" 
                                for nucl1 in unpaired_nucleotide[1] 
                                for nucl2 in iupac_nucleotides[first_sequence[k]]
                            ]
                            second_side_targets = [f"{nucl1}{nucl2}" 
                                for nucl1 in unpaired_nucleotide[2] 
                                for nucl2 in iupac_nucleotides[second_sequence[k]]
                            ]
                            
                            if unpaired_nucleotide[0] == get_pair_start(structure_symbol):
                                for first_side in first_side_targets:
                                    for second_side in second_side_targets:
                                        if (is_valid_str(first_side) 
                                        and is_valid_str(second_side)):
                                            # A pair spans two columns, hence the factor 2.
                                            columns_count += (2 * sequences_weight) / (len(first_side_targets) * len(second_side_targets))
                                            
                                            if first_side != second_side:
                                                # Count the substitution in both orientations.
                                                temp_paired_mutation_count[
                                                    (first_side, second_side)
                                                ] += sequences_weight / (len(first_side_targets) * len(second_side_targets))
                                                temp_paired_mutation_count[
                                                    (first_side[::-1], second_side[::-1])
                                                ] += sequences_weight / (len(first_side_targets) * len(second_side_targets))
                            else:
                                raise ValueError('Invalid pattern in structure')    
                        # Non-pairing character
                        else:
                            first_side_targets = iupac_nucleotides[first_sequence[k]]
                            second_side_targets = iupac_nucleotides[second_sequence[k]]
                            
                            for first_side in first_side_targets:
                                for second_side in second_side_targets:
                                    if (is_valid_str(first_side) 
                                    and is_valid_str(second_side)):
                                        columns_count += sequences_weight / (len(first_side_targets) * len(second_side_targets))
                                        if first_side != second_side:
                                            temp_single_mutation_count[(
                                                first_side, 
                                                second_side,
                                            )] += sequences_weight / (len(first_side_targets) * len(second_side_targets))
                    
                    # Weight the tree distance between the two sequences by the
                    # number of columns that were compared.
                    k_temp += (trees[dataset_name].distance(
                        first_name, 
                        second_name
                    ) * columns_count)
                    
            
            # Average the contribution of sequence i over its similar partners.
            if same_first_sequence_count > 0:
                k_value += (k_temp / same_first_sequence_count)
                
                for key in temp_single_mutation_count:
                    single_mutation_count[key] += (temp_single_mutation_count[key] 
                                                / same_first_sequence_count)
                for key in temp_paired_mutation_count:
                    paired_mutation_count[key] += (temp_paired_mutation_count[key] 
                                                / same_first_sequence_count)
                
    single_chars = ["A", "C", "G", "U"]
    paired_chars = [c1 + c2 for c1 in single_chars for c2 in single_chars]
    
    single_rate_values = defaultdict(float)
    paired_rate_values = defaultdict(float)
            
    # Off-diagonal rates from the counts; each diagonal entry is the negative
    # sum of the off-diagonal entries in its row.
    for i in single_chars:
        single_rate_values[(i,i)] = 0
        for j in single_chars:
            if i != j:
                single_rate_values[(i,j)] = (single_mutation_count[(i,j)] 
                                             / (singles_prob * single_frequencies[i] * k_value))
                single_rate_values[(i,i)] = single_rate_values[(i,i)] - single_rate_values[(i,j)]

    for i in paired_chars:
        paired_rate_values[(i,i)] = 0
        for j in paired_chars:
            if i != j:
                paired_rate_values[(i,j)] = ((paired_mutation_count[(i,j)] * 2)
                                             / (paireds_prob * paired_frequencies[i] * k_value))
                paired_rate_values[(i,i)] = paired_rate_values[(i,i)] - paired_rate_values[(i,j)]
    
    return single_rate_values, paired_rate_values  

## Save PCFG

In [18]:
def save_pcfg(pcfg, filename):
    """Write the rule probabilities of ``pcfg`` to ``f"{filename}.pcfg"``.

    Binary rules come first, one per line as ``A -> B C probability``,
    followed by unary rules as ``A -> w probability``. Probabilities are read
    from ``pcfg.q``, keyed by the rule tuple.

    Args:
        pcfg: Object exposing ``grammar.unary_rules``, ``grammar.binary_rules``
            and the probability mapping ``q``.
        filename: Output path without the ``.pcfg`` extension.
    """
    unary_rules = pcfg.grammar.unary_rules
    binary_rules = pcfg.grammar.binary_rules

    with open(f"{filename}.pcfg", "w+") as pcfg_file:
        pcfg_file.writelines(
            f"{head} -> {left} {right} {pcfg.q[(head, left, right)]}\n"
            for head, left, right in binary_rules
        )
        pcfg_file.writelines(
            f"{head} -> {terminal} {pcfg.q[(head, terminal)]}\n"
            for head, terminal in unary_rules
        )

___

## Load Primary Data

In [19]:
# Each dataset needs three inputs registered under the same name: a tree
# (-> `trees`), an alignment (-> `sequences`) and a structure (-> `structures`).
# The commented-out calls are currently disabled; only RF00001, RF00005 and
# RF03000 are loaded. Enable a dataset in all three groups together, because
# calc_frequencies and calc_rate_values look each one up by name in
# `structures` and `trees`.

# Trees (Newick)
open_tree("./primaries/trees/RF00001.nwk", "RF00001")
open_tree("./primaries/trees/RF00005.nwk", "RF00005")
# open_tree("./primaries/trees/RF00162.nwk", "RF00162")
# open_tree("./primaries/trees/RF01704.nwk", "RF01704")
# open_tree("./primaries/trees/RF01734.nwk", "RF01734")
# open_tree("./primaries/trees/RF01739.nwk", "RF01739")
# open_tree("./primaries/trees/RF02035.nwk", "RF02035")
# open_tree("./primaries/trees/RF02957.nwk", "RF02957")
open_tree("./primaries/trees/RF03000.nwk", "RF03000")
# open_tree("./primaries/trees/RF03054.nwk", "RF03054")
# open_tree("./primaries/trees/RF03135.nwk", "RF03135")

# Alignments (relaxed PHYLIP)
read_sequences("./primaries/phylips/RF00001.phylip", "RF00001")
read_sequences("./primaries/phylips/RF00005.phylip", "RF00005")
# read_sequences("./primaries/phylips/RF00162.phylip", "RF00162")
# read_sequences("./primaries/phylips/RF01704.phylip", "RF01704")
# read_sequences("./primaries/phylips/RF01734.phylip", "RF01734")
# read_sequences("./primaries/phylips/RF01739.phylip", "RF01739")
# read_sequences("./primaries/phylips/RF02035.phylip", "RF02035")
# read_sequences("./primaries/phylips/RF02957.phylip", "RF02957")
read_sequences("./primaries/phylips/RF03000.phylip", "RF03000")
# read_sequences("./primaries/phylips/RF03054.phylip", "RF03054")
# read_sequences("./primaries/phylips/RF03135.phylip", "RF03135")

# Structures (first line of each file)
read_structure("./primaries/structures/RF00001.structure", "RF00001")
read_structure("./primaries/structures/RF00005.structure", "RF00005")
# read_structure("./primaries/structures/RF00162.structure", "RF00162")
# read_structure("./primaries/structures/RF01704.structure", "RF01704")
# read_structure("./primaries/structures/RF01734.structure", "RF01734")
# read_structure("./primaries/structures/RF01739.structure", "RF01739")
# read_structure("./primaries/structures/RF02035.structure", "RF02035")
# read_structure("./primaries/structures/RF02957.structure", "RF02957")
read_structure("./primaries/structures/RF03000.structure", "RF03000")
# read_structure("./primaries/structures/RF03054.structure", "RF03054")
# read_structure("./primaries/structures/RF03135.structure", "RF03135")

## Calculate Evolutionary Parameters

In [20]:
# Frequencies of unpaired nucleotides and of base pairs, plus the share of
# unpaired and paired positions, estimated from all loaded datasets.
single_frequencies, paired_frequencies, singles_prob, paireds_prob = calc_frequencies(
    sequences, structures
)

# Substitution rate values for unpaired and paired positions.
single_rate_values, paired_rate_values = calc_rate_values(
    trees=trees,
    sequences=sequences,
    structures=structures,
    single_frequencies=single_frequencies,
    paired_frequencies=paired_frequencies,
    singles_prob=singles_prob,
    paireds_prob=paireds_prob,
)

# Save to a file
with open("./primaries/frequencies.pkl", "wb") as file:
    pickle.dump(
        (single_frequencies, paired_frequencies, singles_prob, paireds_prob),
        file,
    )

with open("./primaries/mutation_rate.pkl", "wb") as file:
    pickle.dump((single_rate_values, paired_rate_values), file)

In [21]:
# Load from the file
# NOTE: pickle.load can execute arbitrary code; only load files that were
# written by the cell above.
with open("./primaries/frequencies.pkl", "rb") as file:
    single_frequencies, paired_frequencies, singles_prob, paireds_prob = pickle.load(file)

with open("./primaries/mutation_rate.pkl", "rb") as file:
    single_rate_values, paired_rate_values = pickle.load(file)

## Train Grammar Of Structure

In [22]:
# Write the "d"/"s" training corpus for the structure grammar; simplify_struct
# appends ".train", so the output file is ./primaries/structures.train.
simplify_struct(structures, filename="./primaries/structures")

In [23]:
# Train for first time
# Estimate the rule probabilities from the training corpus written by
# simplify_struct (5 iterations), then persist them.
# NOTE: save_pcfg appends ".pcfg", so it overwrites ./primaries/structure.pcfg,
# the same path that is passed as the second PCNF argument here.
pcfg = PCNF("./primaries/structure.cfg", "./primaries/structure.pcfg")
pcfg.estimate("./primaries/structures.train", iter_num=5)
save_pcfg(pcfg, "./primaries/structure")

Itration number: 1
Itration number: 2
Itration number: 3
Itration number: 4
Itration number: 5
Estimation complete!


In [24]:
# Read from trained file
# Rebuilds `pcfg` from the grammar file and the .pcfg file saved by the
# training cell, replacing the in-memory object created there.
pcfg = PCNF("./primaries/structure.cfg", "./primaries/structure.pcfg")

___