# In [1]:
import numpy as np
import networkx as nx
from tqdm import tqdm
from collections import deque
import json
import os

class QuadGraph():
    def __init__(self, sequence, max_bulge_len, max_loop_len):
        self._seq = sequence.upper()
        self._maxBulgeLen = int(max_bulge_len)
        self._maxLoopLen = int(max_loop_len)
        self._stemLen = 3 # QuadGraph only supports stem length of 3
        self._kmerSize = self._stemLen + self._maxBulgeLen
        self._loopWindow = (self._maxLoopLen + self._stemLen + self._maxBulgeLen) * \
                           (self._stemLen + self._maxBulgeLen + 1)
        self.Graphs = {'G': [], 'C': []}
        self.graphPositions = {'G': [], 'C': []}
        for base in ['G', 'C']:
            for stems in self._generate_stems(base):
                start_pos, graph = self._make_graph(stems, base)
                if graph is not None:
                    self.Graphs[base].append(graph)
                    self.graphPositions[base].append(start_pos)

    def _generate_stems(self, base):
        stems = []
        for p in tqdm(range(len(self._seq) - self._kmerSize + 1),
                      desc="Progress for base '%s'" % base):
            s = []
            kmer = np.array(list(self._seq[p: p + self._kmerSize]))
            if kmer[0] == base:
                if kmer[1] == base:
                    s.extend([[p, p + 1, p + i] for i in
                               np.where(kmer == base)[0][2:]])
                s.extend([[p, p + i, p + i + 1] for i in
                          range(2, self._kmerSize - 1) if
                          kmer[i] == base and kmer[i + 1] == base])
            if len(s) > 0:
                if len(stems) > 1 and s[0][0] - stems[-1][2] > self._maxLoopLen:
                    if len(stems) > 4:
                        yield np.array(stems)
                    stems = list(s)
                else:
                    stems.extend(s)
        for j in range(1, len(kmer) - self._stemLen):  # loop for last kmer
            xp = p + j
            k = kmer[j:]
            if k[0] == base:
                if k[1] == base:
                    stems.extend([[xp, xp + 1, xp + i] for i in
                                   np.where(k == base)[0][2:]])
                stems.extend([[xp, xp + i, xp + i + 1] for i in
                              range(2, len(k) - 1) if
                              k[i] == base and k[i + 1] == base])
        if len(stems) > 4:
            yield np.array(stems)
    
    @staticmethod
    def _stem_encoder(stem_array, norm):
        diff = np.diff(stem_array)
        if diff[0] != 1:
            return '%d%s**' % (stem_array[0] - norm, '-' * (diff[0] - 1))
        else:
            return '%d*%s*' % (stem_array[0] - norm, '-' * (diff[1] - 1))

    @staticmethod
    def _stem_decoder(code, norm):
        bulge = code.count('-')
        first_val = int(code.replace('*', '').replace('-', '')) + norm
        if code[-2] == '*':
            return [first_val, first_val + bulge + 1, first_val + bulge + 2]
        else:
            return [first_val, first_val + 1, first_val + bulge + 2]

    @staticmethod
    def _remove_nonquad_nodes(graph):
        quads = deque()
        for n1 in graph.nodes():
            for n2 in graph.successors(n1):
                for n3 in graph.successors(n2):
                    for n4 in graph.successors(n3):
                        quads.extend([n1, n2, n3, n4])
        del_nodes = list(set(graph.nodes()).difference(set(quads)))
        for node in del_nodes:
            graph.remove_node(node)
        return graph        

    def _make_graph(self, stems, base):
        """Build a directed graph linking each stem to the stems it can reach.

        Nodes are the ``_stem_encoder`` strings of the stems, encoded
        relative to the first position of the first stem.  Stem ``i`` gets
        an edge to a later stem (among the next ``self._loopWindow - 1``
        rows) when that stem starts more than one position after stem
        ``i``'s third position and at most ``self._maxLoopLen + 1``
        positions after it.  Each edge carries ``length``: the later stem's
        start minus stem ``i``'s third position.

        The graph is then pruned with ``_remove_nonquad_nodes``.

        Parameters
        ----------
        stems : numpy.ndarray
            2-D array with one stem (three positions) per row.
        base : str
            Not used by this method.

        Returns
        -------
        tuple
            ``(first_position, graph)`` when the pruned graph has at least
            three edges and a longest path of at least three edges;
            ``(None, None)`` otherwise.
        """
        graph = nx.DiGraph()
        origin = stems[0][0]
        max_loop = self._maxLoopLen
        for i in range(stems.shape[0] - 1):
            stem_end = stems[i][2]
            later_starts = stems[i + 1: i + self._loopWindow][:, 0]
            reachable = ((later_starts > stem_end + 1) &
                         (later_starts <= stem_end + max_loop + 1))
            for offset in np.flatnonzero(reachable):
                # Offsets are relative to the slice, which begins at i + 1.
                target = stems[i + offset + 1]
                graph.add_edge(self._stem_encoder(stems[i], origin),
                               self._stem_encoder(target, origin),
                               length=int(target[0] - stem_end))
        graph = self._remove_nonquad_nodes(graph)
        if len(graph.edges()) < 3 or nx.dag_longest_path_length(graph) < 3:
            return (None, None)
        return (origin, graph)