In [None]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px

In [None]:
import pandas as pd

In [None]:
import sys
sys.path.insert(0, '../../rna_ss/')
from utils import db2pairs

In [None]:
# Load the pickled DataFrame of structures (presumably Rfam-derived, judging by
# the file name — confirm). Later cells read the `one_idx` and `len` fields of its rows.
# NOTE: unpickling can execute arbitrary code, so only load files from a trusted source.
df = pd.read_pickle('../../rna_ss/data_processing/rna_cg/data/rfam.pkl')

In [None]:
# Preview the first 20 rows of the loaded frame.
df.head(20)

In [None]:
# # from http://gtrnadb.ucsc.edu/genomes/eukaryota/Hsapi19/genes/tRNA-Arg-CCT-5-1.html
# db_str = '>>>>>>>..>>..>>........<<..<<.>>>>>.......<<<<<.....>>...>>>.......<<<<<<<<<<<<.'
# db_str = db_str.replace('>', '(').replace('<', ')')

In [None]:
# pairs = db2pairs(db_str)

In [None]:
# print(pairs)

In [None]:
# x= np.zeros((len(db_str), len(db_str)))
# for i, j in pairs:
#     x[i, j] =1

In [None]:
# Pick one example row and turn its base-pair indices into a square 0/1 matrix.
idx = 18
tmp_pairs = df.iloc[idx].one_idx
# one_idx stores two parallel sequences (first indices, second indices);
# zip them into a list of (i, j) tuples
pairs = list(zip(tmp_pairs[0], tmp_pairs[1]))
# pairs = [(i - 1, j - 1) for i, j in pairs]   # only for PDB? - TODO make all dataset consistent
seq_len = df.iloc[idx].len
x = np.zeros((seq_len, seq_len))
# only the (i, j) entry is set; the matrix is not symmetrised
for i, j in pairs:
    x[i, j] = 1

In [None]:
# Static view of the pair matrix with a colour scale.
fig_static, ax_static = plt.subplots()
heatmap = ax_static.imshow(x)
fig_static.colorbar(heatmap, ax=ax_static)
plt.show()

In [None]:
# Same matrix rendered with plotly; as the cell's last expression the figure is displayed.
px.imshow(x)

In [None]:
# Normalise every pair so that the smaller index comes first (i_k <= j_k),
# then order the list by first index (ties broken by the second index),
# giving (i_1, j_1), (i_2, j_2), ... with i_k <= i_(k+1).
pairs = sorted((min(i, j), max(i, j)) for i, j in pairs)

In [None]:
class Stem(object):
    """A run of stacked base pairs.

    Pairs are kept in ``one_idx`` as ``(i, j)`` tuples with ``i < j``, sorted
    ascending. Consecutive pairs must satisfy ``i_next == i + 1`` and
    ``j_next == j - 1`` (checked by ``validate``).
    """

    def __init__(self):
        # list of (i, j) pairs making up the stem, sorted ascending
        self.one_idx = []

    def validate(self, one_idx):
        """Assert that ``one_idx`` is a sorted, contiguous stack of pairs.

        Lists with zero or one pair are always accepted.

        Raises:
            AssertionError: if the list is not sorted, or two consecutive
                pairs are not exactly one step apart on both sides.
        """
        # nothing to check if empty, or there is only one base pair
        if len(one_idx) <= 1:
            return
        # make sure it's sorted
        assert sorted(one_idx) == one_idx
        # every 2 consecutive pairs (i_k, j_k) & (i_k+1, j_k+1) must satisfy
        # i_k+1 == i_k + 1 and j_k+1 == j_k - 1
        assert all(
            a[0] + 1 == b[0] and a[1] - 1 == b[1]
            for a, b in zip(one_idx[:-1], one_idx[1:])
        )

    def add_pair(self, pair):
        """Add ``pair`` (a 2-sequence with ``pair[0] < pair[1]``) to the stem.

        The stem is only updated if the resulting collection passes
        ``validate``; on failure ``self.one_idx`` is left unchanged.

        Raises:
            AssertionError: if ``pair`` is malformed or does not stack onto
                the existing pairs.
        """
        assert len(pair) == 2
        assert pair[0] < pair[1]
        # work on a sorted copy so a failed validation leaves the stem intact
        candidate = sorted(self.one_idx + [pair])
        self.validate(candidate)
        self.one_idx = candidate

    def bounding_box(self):
        """Return ``(first_i, last_j, n_pairs)`` for the stem.

        ``first_i`` is the first index of the first pair and ``last_j`` the
        second index of the last pair. Because the first index increases and
        the second decreases along the stem, these are the smallest indices
        on each side.

        Raises:
            IndexError: if the stem holds no pairs.
        """
        assert self.one_idx == sorted(self.one_idx)
        return self.one_idx[0][0], self.one_idx[-1][1], len(self.one_idx)

    def __repr__(self):
        # an empty stem has no bounding box; show a placeholder instead of raising
        if not self.one_idx:
            return "Stem (empty)"
        return "Stem location ({0}, {1}) height {2} width {2}".format(*self.bounding_box())

In [None]:
class StemCollection(object):
    """Accumulates finished stems while one stem is being built at a time.

    ``stems`` holds the concluded, non-empty stems; ``current_stem`` is the
    stem under construction, or None between ``conclude()`` and ``new()``.
    """

    def __init__(self):
        self.current_stem = None
        self.stems = []

    def new(self):
        """Start building a fresh, empty stem."""
        self.current_stem = Stem()

    def conclude(self):
        """Store the current stem if it has any pairs, then clear it."""
        finished = self.current_stem
        if len(finished.one_idx) > 0:
            self.stems.append(finished)
        self.current_stem = None

    def is_compatible(self, pair):
        """Tell whether ``pair`` can extend the current stem.

        True when the current stem is empty, or when ``pair`` is exactly
        ``(i + 1, j - 1)`` relative to the stem's last pair ``(i, j)``.
        Assumes pairs are offered in sorted order.
        """
        assert len(pair) == 2
        assert pair[0] < pair[1]
        stacked = self.current_stem.one_idx
        if len(stacked) == 0:
            return True
        last = stacked[-1]
        return bool(pair[0] == last[0] + 1 and pair[1] == last[1] - 1)

    def add_pair(self, pair):
        """Forward ``pair`` to the stem under construction."""
        self.current_stem.add_pair(pair)

    def sort(self):
        """Not implemented yet."""
        raise NotImplementedError

In [None]:
# external_loop = [None, None, None, None]
# do not consider external loop (non-local, see http://eternawiki.org/wiki/index.php5/External_Loop)


# stems only
# Walk the sorted pairs once: whenever a pair cannot stack onto the stem being
# built, close that stem and open a new one before adding the pair.
sc = StemCollection()
sc.new()
for pair in pairs:
    if not sc.is_compatible(pair):
        sc.conclude()
        sc.new()
    sc.add_pair(pair)
# flush the last stem under construction
sc.conclude()

In [None]:
# Show the concluded stems (each one is rendered through Stem.__repr__).
sc.stems

In [None]:
def paired(position, pairs):
    """Return True if ``position`` is either end of any pair in ``pairs``.

    Args:
        position: index to look up.
        pairs: iterable of pairs, each indexable at ``[0]`` and ``[1]``.

    Returns:
        bool: True as soon as a matching pair is found, False otherwise
        (including when ``pairs`` is empty).
    """
    # any() stops at the first match instead of scanning every pair
    return any(position == pair[0] or position == pair[1] for pair in pairs)

In [None]:
# Classify the unpaired regions around the stems in `sc.stems`.
# Record layouts:
#   l_bulges:       (unpaired positions, s2 first-pair j, s1 last-pair j)
#   r_bulges:       (unpaired positions, s1 last-pair i, s2 first-pair i)
#   internal_loops: [min i, max i, min j, max j] of the two unpaired runs
#   hairpin_loops:  list of unpaired positions enclosed by a stem's last pair
l_bulges = []
r_bulges = []
internal_loops = []
hairpin_loops = []


def _all_unpaired(positions):
    """Return True when no index in `positions` is an end of any pair in `pairs`."""
    return not any(paired(p, pairs) for p in positions)


# find in-between stem local structures:
# bulge
# internal loop

# TODO sort stem collection

for s1, s2 in zip(sc.stems[:-1], sc.stems[1:]):
    s1_i, s1_j = s1.one_idx[-1]  # last pair of the earlier stem
    s2_i, s2_j = s2.one_idx[0]   # first pair of the later stem
    # make sure these two stems are not fully connected
    assert not (s1_i + 1 == s2_i and s1_j - 1 == s2_j)
    if s1_i + 1 == s2_i:  # i connected
        # check if all idxes on the other side are unpaired -> bulge
        idxes = range(s2_j + 1, s1_j)
        # an empty range (e.g. s2_j beyond s1_j) is not a bulge; recording it
        # would also break min()/max() when the bulge is drawn
        if len(idxes) > 0 and _all_unpaired(idxes):
            r_bulges.append((list(idxes), s1_i, s2_i))
            print("bulge(R) {} between stems:\n{}\n{}\n".format(list(idxes), s1, s2))
    elif s1_j - 1 == s2_j:  # j connected
        # check if all idxes on the other side are unpaired -> bulge
        idxes = range(s1_i + 1, s2_i)
        if len(idxes) > 0 and _all_unpaired(idxes):
            l_bulges.append((list(idxes), s2_j, s1_j))
            print("bulge(L) {} between stems:\n{}\n{}\n".format(list(idxes), s1, s2))
    else:  # neither side connected
        # check if all idxes on both sides are unpaired -> internal loop
        idxes_i = range(s1_i + 1, s2_i)
        idxes_j = range(s2_j + 1, s1_j)
        # both runs must be non-empty: min()/max() raise ValueError on an empty
        # range, which happens when s2 is not nested inside s1 on the j side
        if len(idxes_i) > 0 and len(idxes_j) > 0 and _all_unpaired(list(idxes_i) + list(idxes_j)):
            internal_loops.append([min(idxes_i), max(idxes_i), min(idxes_j), max(idxes_j)])
            print("internal loop {} {} between stems:\n{}\n{}\n".format(list(idxes_i), list(idxes_j), s1, s2))

# find single-stem local structure:
# hairpin loop
for s in sc.stems:
    # check whether the positions enclosed by the stem are unpaired -> hairpin loop
    idxes = range(s.one_idx[-1][0] + 1, s.one_idx[-1][1])
    # skip the degenerate case where the last pair encloses no positions at all
    if len(idxes) > 0 and _all_unpaired(idxes):
        hairpin_loops.append(list(idxes))
        print("hairpin loop {} within stem:\n{}\n".format(list(idxes), s))

In [None]:
# Each entry: (unpaired positions, j of the later stem's first pair, j of the earlier stem's last pair).
l_bulges

In [None]:
# Each entry: (unpaired positions, i of the earlier stem's last pair, i of the later stem's first pair).
r_bulges

In [None]:
# Each entry: [min i, max i, min j, max j] of the unpaired runs on the two sides.
internal_loops

In [None]:
# Each entry: list of the unpaired positions enclosed by a stem's last pair.
hairpin_loops

In [None]:
import plotly.graph_objects as go

In [None]:
# Overlay the detected local structures on the pair matrix.
fig = px.imshow(x)


def _add_cell_box(fig, row_lo, row_hi, col_lo, col_hi, color):
    """Outline matrix cells rows row_lo..row_hi and cols col_lo..col_hi (inclusive).

    Heatmap cells are centred on integer coordinates, so the outline is padded
    by half a cell on every side. This keeps a single-cell range visible
    instead of collapsing to a zero-area rectangle.
    """
    fig.add_shape(
        type='rect',
        y0=row_lo - 0.5, y1=row_hi + 0.5, x0=col_lo - 0.5, x1=col_hi + 0.5,
        xref='x', yref='y',
        line_color=color
    )


# stems
for stem in sc.stems:
    # (a, b) is the smallest row/column of the stem, w its number of pairs
    a, b, w = stem.bounding_box()
    _add_cell_box(fig, a, a + w - 1, b, b + w - 1, 'red')


# bulges
for bulge in l_bulges:
    a_s, b1, b2 = bulge
    if not a_s:  # nothing to outline
        continue
    _add_cell_box(fig, min(a_s), max(a_s), b1, b2, 'cyan')
for bulge in r_bulges:
    bs, a1, a2 = bulge
    if not bs:  # nothing to outline
        continue
    _add_cell_box(fig, a1, a2, min(bs), max(bs), 'cyan')


# internal loop
for a1, a2, b1, b2 in internal_loops:
    _add_cell_box(fig, a1, a2, b1, b2, 'purple')


# hairpin loop
# these are symmetric around the off-diagonal, only need to draw the triangle, but I'm lazy so drawing rectangle instead
for idxes in hairpin_loops:
    if not idxes:  # nothing to outline
        continue
    _add_cell_box(fig, min(idxes), max(idxes), min(idxes), max(idxes), 'white')

fig.show()

In [None]:
# fig = px.imshow(x)
# # fig.add_trace(go.Scatter(x=[20], y=[10], marker=dict(color='red', size=6)))
# fig.add_shape(
#     type='rect',
# #     x0=0, x1=9, y0=98, y1=107,
#     y0=0, y1=9, x0=107, x1=116,
#     xref='x', yref='y',
#     line_color='red'
# )
# fig.show()

In [None]:
# Inspect the pairs of `s1`. In the visible notebook `s1` is only bound as the loop
# variable of the consecutive-stem loop, so this shows the stem that loop ended on.
# NOTE(review): looks like leftover debugging; raises NameError if that loop never iterated.
s1.one_idx