In [2]:
import RNA
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from src.dataset import load_benchmark_dataset, Species, Modification

In [3]:
# Load two benchmark datasets for Species.human / Modification.psi via the project loader.
# NOTE(review): the third positional argument (True) presumably selects the held-out
# test split — confirm against load_benchmark_dataset's signature.
test_dataset = load_benchmark_dataset(Species.human, Modification.psi, True)
train_dataset = load_benchmark_dataset(Species.human, Modification.psi)

In [7]:
def rna_to_graph(sequence: str):
    """Fold an RNA sequence and turn its secondary structure into a graph.

    The sequence is folded with ``RNA.fold`` and the resulting dot-bracket
    string is converted into an undirected graph with one node per position.

    Args:
        sequence: Nucleotide string made only of the symbols 'A', 'U', 'G', 'C'.

    Returns:
        A ``(graph, node_features, adj_matrix)`` tuple where

        * ``graph`` is an ``nx.Graph`` whose nodes ``0..len(sequence)-1`` carry
          ``base`` and ``structure`` attributes and whose edges carry
          ``type='backbone'`` (consecutive positions) or ``type='base_pair'``
          (matched parentheses in the dot-bracket string);
        * ``node_features`` is a ``(len(sequence), 4)`` one-hot array with
          columns ordered A, U, G, C;
        * ``adj_matrix`` is the array returned by ``nx.to_numpy_array(graph)``.

    Raises:
        ValueError: If ``sequence`` contains a symbol other than A, U, G or C.
    """
    bases = ['A', 'U', 'G', 'C']
    base_index = {base: column for column, base in enumerate(bases)}

    # Validate before folding so an unsupported symbol fails fast with a
    # readable message instead of a bare "x is not in list" after the fold.
    for position, base in enumerate(sequence):
        if base not in base_index:
            raise ValueError(
                f"Unsupported nucleotide {base!r} at position {position}; "
                f"expected one of {bases}"
            )

    # Only the dot-bracket structure is needed; the second value is unused here.
    dot_bracket, _mfe = RNA.fold(sequence)

    graph = nx.Graph()

    for i, (base, structure) in enumerate(zip(sequence, dot_bracket)):
        graph.add_node(i, base=base, structure=structure)

    # Match parentheses with a stack: each ')' pairs with the most recent
    # unmatched '('. An unmatched ')' is skipped rather than raising.
    stack = []
    for i, char in enumerate(dot_bracket):
        if char == '(':
            stack.append(i)
        elif char == ')':
            if stack:
                start = stack.pop()
                graph.add_edge(start, i, type='base_pair')
        # Link every position to its predecessor along the chain.
        if i > 0:
            graph.add_edge(i-1, i, type='backbone')

    # One-hot encode each position; column order follows `bases`.
    node_features = np.zeros((len(sequence), len(bases)))
    for i, base in enumerate(sequence):
        node_features[i, base_index[base]] = 1

    # Adjacency matrix (rows/columns follow the graph's node order).
    adj_matrix = nx.to_numpy_array(graph)

    return graph, node_features, adj_matrix

def render_graph(graph, seed=None):
    """Draw an RNA structure graph with matplotlib.

    Nodes are drawn in light blue and labelled with their ``base`` attribute
    followed by the node id (e.g. ``G12``); nodes without a ``base`` attribute
    are labelled with the node id only. Edges whose ``type`` attribute is
    ``'base_pair'`` are redrawn in red on top of the default edges.

    Args:
        graph: A networkx graph, e.g. the one returned by ``rna_to_graph``.
        seed: Optional seed forwarded to ``nx.spring_layout`` so the layout
            can be reproduced. ``None`` keeps the layout random.
    """
    pos = nx.spring_layout(graph, seed=seed)

    # Use an explicit axes: letting nx.draw create its own full-figure axes
    # leaves no room for the title and makes tight_layout emit a warning.
    fig, ax = plt.subplots(figsize=(12, 8))

    # Labels are drawn separately below, so a single label is placed per node
    # instead of the node id and the base letter being stacked on each other.
    nx.draw_networkx(graph, pos, ax=ax, with_labels=False, node_color='lightblue', node_size=500)

    # Color base pair edges differently; .get() tolerates edges without a 'type'.
    base_pair_edges = [(u, v) for (u, v, d) in graph.edges(data=True) if d.get('type') == 'base_pair']
    nx.draw_networkx_edges(graph, pos, edgelist=base_pair_edges, edge_color='r', width=2, ax=ax)

    # Combined "<base><node id>" labels.
    node_labels = {
        node: str(node) if base is None else f"{base}{node}"
        for node, base in graph.nodes(data='base')
    }
    nx.draw_networkx_labels(graph, pos, node_labels, font_size=8, font_weight='bold', ax=ax)

    ax.set_title("RNA Secondary Structure Graph")
    ax.axis('off')
    fig.tight_layout()
    plt.show()

In [5]:
# Extract the 'sequence' entries of the training samples via `.values`.
# NOTE(review): assumes train_dataset.samples is a pandas DataFrame — confirm.
train_sequences = train_dataset.samples['sequence'].values

In [8]:
# Smoke-test the conversion on the first training sequence only ([:1]) and
# print its one-hot node features followed by its adjacency matrix.
# The loop variables (graph, features, matrix) remain defined after the loop.
for seq in train_sequences[:1]:
    graph, features, matrix = rna_to_graph(seq)
    print(features)
    print(matrix)

[[0. 0. 0. 1.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 1. 0. 0.]]
[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0

In [None]:
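# NOTE(review): commented-out single-function variant that folds, plots and builds
# features in one call; it appears to duplicate the active rna_to_graph / render_graph
# cell — confirm whether it is still needed.
# def rna_to_graph(sequence):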
#     # Step 1: Generate dot-parenthesis notation
#     (dot_bracket, mfe) = RNA.fold(sequence)
#     print(f"Sequence: {sequence}")
#     print(f"Dot-bracket: {dot_bracket}")
#     print(f"Minimum free energy: {mfe}")
# 
#     # Step 2: Create graph representation
#     G = nx.Graph()
#     
#     # Add nodes (nucleotides)
#     for i, (base, structure) in enumerate(zip(sequence, dot_bracket)):
#         G.add_node(i, base=base, structure=structure)
#     
#     # Add edges (connections)
#     stack = []
#     for i, char in enumerate(dot_bracket):
#         if char == '(':
#             stack.append(i)
#         elif char == ')':
#             if stack:
#                 start = stack.pop()
#                 G.add_edge(start, i, type='base_pair')
#         
#         # Add backbone connections
#         if i > 0:
#             G.add_edge(i-1, i, type='backbone')
# 
#     # Step 3: Visualize the graph
#     pos = nx.spring_layout(G)
#     plt.figure(figsize=(12, 8))
#     nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, font_size=10, font_weight='bold')
#     
#     # Add base labels
#     node_labels = nx.get_node_attributes(G, 'base')
#     nx.draw_networkx_labels(G, pos, node_labels, font_size=8)
#     
#     # Color base pair edges differently
#     base_pair_edges = [(u, v) for (u, v, d) in G.edges(data=True) if d['type'] == 'base_pair']
#     nx.draw_networkx_edges(G, pos, edgelist=base_pair_edges, edge_color='r', width=2)
#     
#     plt.title("RNA Secondary Structure Graph")
#     plt.axis('off')
#     plt.tight_layout()
#     plt.show()
# 
#     # Step 4: Prepare data for GCN
#     # Node features: One-hot encoding of bases
#     bases = ['A', 'U', 'G', 'C']
#     node_features = np.zeros((len(sequence), len(bases)))
#     for i, base in enumerate(sequence):
#         node_features[i, bases.index(base)] = 1
#     
#     # Adjacency matrix
#     adj_matrix = nx.to_numpy_array(G)
# 
#     return G, node_features, adj_matrix
# 
# # Example usage
# sequence = "GGGCUAUUAGCUCAGUUGGUUAGAGCGCACCCCUGAUAAGGGUGAGGUCGCUGAUUCGAAUUCAGCAUAGCCCA"
# G, node_features, adj_matrix = rna_to_graph(sequence)
# 
# print("Node features shape:", node_features.shape)
# print("Adjacency matrix shape:", adj_matrix.shape)