# In [None]:
import json
import networkx as nx
from networkx.classes.function import edges


def read_json_data(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text must use,
    per RFC 8259) instead of the platform's locale default, so the result
    does not depend on the machine the notebook runs on.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def create_positive_graph(json_data):
    """Build the forward-strand directed graph from parsed graph JSON.

    Each entry of ``json_data["node"]`` becomes a node keyed by ``int(id)``
    carrying its ``"sequence"`` as the ``seq`` attribute. Each entry of
    ``json_data["edge"]`` becomes a directed edge ``int(from) -> int(to)``.

    Side effect: ``json_data`` is emptied in place once the graph is built
    (presumably to release the parsed JSON's memory -- confirm with callers).
    """
    graph = nx.DiGraph()
    for record in json_data["node"]:
        graph.add_node(int(record["id"]), seq=record["sequence"])
    graph.add_edges_from(
        (int(link["from"]), int(link["to"])) for link in json_data["edge"]
    )
    json_data.clear()
    return graph


# Translation table for nucleotide complements: A<->T, C<->G, N->N and the
# IUPAC ambiguity codes (R<->Y, K<->M, S, W, B<->V, D<->H), in both cases.
_COMPLEMENT_TABLE = str.maketrans(
    "ACGTNRYKMSWBVDHacgtnrykmswbvdh",
    "TGCANYRMKSWVBHDtgcanyrmkswvbhd",
)


def reverse_complement(seq):
    """Return the reverse complement of the nucleotide string ``seq``.

    Upper- and lowercase bases and IUPAC ambiguity codes are complemented and
    keep their case. Any other character is kept unchanged, so the output
    always has the same length as the input. The previous version silently
    dropped such characters, which shifted every downstream coordinate.
    """
    return seq.translate(_COMPLEMENT_TABLE)[::-1]


def create_negative_graph(G_positive):
    """Return the reverse-strand graph derived from ``G_positive``.

    Every edge is reversed via ``G_positive.reverse()``, which copies the graph
    by default, so ``G_positive`` is left untouched. Each node's ``seq`` is then
    replaced by its reverse complement.
    """
    G_negative = G_positive.reverse()
    for _, attrs in G_negative.nodes(data=True):
        attrs["seq"] = reverse_complement(attrs["seq"])
    return G_negative


def find_pam_nodes(G, PAM):
    """Return the nodes that can host the last two bases of ``PAM`` (e.g. "GG").

    A node qualifies when either:
      * its own ``seq`` contains ``PAM[-2:]``, or
      * its ``seq`` starts with ``PAM[-1]`` and some predecessor's ``seq`` ends
        with ``PAM[-2]``, i.e. the two bases are split across an edge.

    Nodes of the first kind come first, in graph iteration order. Nodes found
    only through the split rule follow them, de-duplicated through a set, so
    their relative order is the set's iteration order.
    """
    core = PAM[-2:]
    inside = []
    split = []
    for node in G.nodes():
        node_seq = G.nodes[node]["seq"]
        if core in node_seq:
            inside.append(node)
        if not node_seq.startswith(PAM[-1]):
            continue
        for pred in G.predecessors(node):
            if G.nodes[pred]['seq'].endswith(PAM[-2]):
                split.append(node)
    split_only = set(split) - set(inside)
    return inside + list(split_only)


# Collect a PAM node together with the upstream nodes that may hold its protospacer.
def PAM_DFS(G, PAM_node, PAM, max_seq_depth):
  """Return PAM_node plus predecessors reachable within a sequence-length budget.

  Despite the name, the traversal is breadth-first: ``nodes`` is iterated in
  order while newly found predecessors are appended to its end, so it behaves
  as a FIFO queue.

  The budget ``trash_len`` starts at ``max_seq_depth`` minus the offset of the
  first ``PAM[-2:]`` occurrence in PAM_node's sequence. Every newly added
  predecessor's sequence length is subtracted from it. A predecessor passes
  while ``trash_len - len(seq) > 0`` or ``trash_len > 20``. The first
  predecessor that fails this test stops the whole traversal, even if it was
  already visited.

  NOTE(review): if PAM_node was selected only because "GG" is split across an
  edge, ``find`` returns -1 and the budget becomes ``max_seq_depth + 1``.
  The literal 20 presumably is the guide length -- confirm.

  Args:
    G: directed graph whose nodes carry a ``"seq"`` attribute.
    PAM_node: start node (a node containing or starting the PAM).
    PAM: PAM string; only its last two characters are used.
    max_seq_depth: total upstream sequence-length budget.

  Returns:
    List of distinct nodes, PAM_node first, in visitation order.
  """
  # trash_len = (guide_len + n_bulges) * max_bulges_considered = (20 + 0) * 4
  nodes = []
  condition = False
  trash_len = max_seq_depth - G.nodes[PAM_node]["seq"].find(PAM[-2:])
  nodes.append(PAM_node)
  # Iterating while appending: the loop also visits nodes added below (BFS).
  for node in nodes:
    for neighbor in G.predecessors(node):
      if trash_len - len(G.nodes[neighbor]["seq"]) > 0 or trash_len > 20 :
        # Only unvisited predecessors consume budget.
        if neighbor not in nodes:
          trash_len -= len(G.nodes[neighbor]["seq"])
          nodes.append(neighbor)
      else:
        # Budget exhausted: stop the inner loop and, via the flag, the outer one.
        condition = True
        break
    if condition:
      break

  return nodes


def extract_edges(G, nodes):
    """Return the edges of ``G`` whose endpoints are both in ``nodes``.

    The result feeds find_all_paths. Pairs are enumerated in ``nodes`` order
    (outer loop = source, inner loop = target) and self-loops are skipped.
    Fewer than two nodes yields an empty list.
    """
    if len(nodes) < 2:
        return []
    return [
        (src, dst)
        for src in nodes
        for dst in nodes
        if src != dst and G.has_edge(src, dst)
    ]


def find_all_paths(edges, target_node):
    """List every simple path that ends at ``target_node`` in the graph built from ``edges``.

    Each node with no incoming edge in that graph is used as a source. Paths
    are grouped by source, in source order, and each path is a list of nodes
    ending at ``target_node``.
    """
    graph = nx.DiGraph()
    graph.add_edges_from(edges)
    # Sources are the nodes without an incoming edge.
    sources = [n for n in graph.nodes() if graph.in_degree(n) == 0]
    return [
        path
        for source in sources
        for path in nx.all_simple_paths(graph, source = source, target = target_node)
    ]


def paths_on_strand(G, pam_nodes, PAM, max_seq_depth):
    """For each PAM node, list the candidate upstream paths ending at it.

    Returns one entry per element of ``pam_nodes``, in the same order. Each
    entry is a list of paths, and each path is a list of nodes. When PAM_DFS
    collects only the PAM node itself, the entry is ``[[pam_node]]``.
    """
    per_pam = []
    for target_node in pam_nodes:
        region = PAM_DFS(G, target_node, PAM, max_seq_depth)
        if len(region) == 1:
            per_pam.append([region])
        else:
            per_pam.append(find_all_paths(extract_edges(G, region), target_node))
    return per_pam


def search_PAM_positions(seq, PAM, guide_len=20):
    """Return candidate protospacer start offsets in ``seq``.

    Every occurrence (overlaps included) of ``PAM[-2:]`` at index ``i`` is
    treated as the PAM's trailing bases. The protospacer then starts at
    ``i - guide_len - 1``, one extra base for the PAM's leading ``N``.
    Occurrences too close to the start of ``seq`` to fit a whole protospacer
    are skipped.

    Args:
      seq: sequence to scan.
      PAM: PAM string; only its last two characters are searched for.
      guide_len: protospacer length. Defaults to 20, the value previously
        hard-coded here.

    Returns:
      Start offsets in ascending order.
    """
    core = PAM[-2:]
    offset = guide_len + 1
    positions = []
    index = seq.find(core)
    while index != -1:
        start = index - offset
        if start >= 0:
            positions.append(start)
        index = seq.find(core, index + 1)
    return positions


  def count_mismatch(guide, seq, positions, strand, path, seq_nodes, PAM):
  max_miss = 4
  results = []
  seq_bk = seq
  guide_bk = guide
  RNA_Bulges_pos = []
  DNA_Bulges_pos = []
  for p in positions:
    guide = guide_bk
    seq = seq_bk
    index_guide = 0
    count_miss = 0
    seq_match = ""
    count_match = 0
    bulge_count = 0
    head = p
    bulge_inner_count = 0
    while index_guide < len(guide) - 1:
      #print(seq, p, len(seq))
      #print(guide, index_guide, len(guide))
      if seq[p] != guide[index_guide]:
        count_miss += 1
        #mi salvo i bulges points
        seq_p, guide_p = check_Bulges(guide, seq, p, max_miss, count_miss)
        if seq_p == guide_p:  # se sono uguali è perchè ho il bulges in entrmbe le stringhe allo stesso punto quindi considero come mismatch
          seq = seq[:p] + seq[p].lower() + seq[p+1:]
        else:
          RNA_Bulges_pos.append(seq_p)
          DNA_Bulges_pos.append(guide_p)
      else:
        count_match += 1
      seq_match += seq[p]
      p += 1
      index_guide += 1
    for sp in RNA_Bulges_pos:
      if sp > -1:
        seq_match = seq_match[:sp] + '-' + seq_match[sp:]
        count_miss -= 1
        bulge_count += 1
    # faccio lo stesso lavoro per l array che tiene i DNA Bulges
    for gp in DNA_Bulges_pos:
      if gp > -1:
        guide = guide[:gp] + '-' + guide[gp:]
        count_miss -= 1
        bulge_count += 1
    if count_miss <= max_miss and bulge_count <= 1:
      result = {
      'path' : path,
      'nodes' : seq_nodes,
      'seq': seq_bk,
      'mismatches': count_miss,
      'bulges': bulge_count,
      'guide': guide + "NNN",
      'seq_match': seq_match + PAM,
      'start': head,
      'strand': strand
      }
      print(result)
      results.append(result)

  return results


def check_Bulges(guide, seq, p, max_MM, count_miss):
    """Test whether the mismatch at index ``p`` is explained by a one-base bulge.

    Runs both bulge tests with the same arguments and returns
    ``(seq_p, guide_p)``:
      * ``seq_p`` is the result of check_RNA_Bulge;
      * ``guide_p`` is the result of check_DNA_Bulge.
    Each value is ``p`` if that bulge is accepted and ``-1`` otherwise. This
    function does not choose between them; the caller does.
    """
    seq_p = check_RNA_Bulge(guide, seq, p, max_MM, count_miss)
    guide_p = check_DNA_Bulge(guide, seq, p, max_MM, count_miss)
    return seq_p, guide_p


def check_RNA_Bulge(guide, seq, p, max_MM, count_miss):
    """Score an RNA bulge (one extra base in the guide) at index ``p``.

    Counts case-insensitive mismatches between ``guide[i + 1]`` and ``seq[i]``
    for ``i`` from ``p`` to ``len(guide) - 2``, i.e. as if the guide base at
    ``p`` were skipped. The same index is used for both strings, so they are
    expected to be column-aligned.

    Returns ``p`` if that count is at most both ``max_MM`` and ``count_miss``,
    otherwise ``-1``.
    """
    shifted_mm = sum(
        1
        for i in range(p, len(guide) - 1)
        if guide[i + 1].upper() != seq[i].upper()
    )
    return p if shifted_mm <= min(max_MM, count_miss) else -1



def check_DNA_Bulge(guide, seq, p, max_MM, count_miss):
    """Score a DNA bulge (one extra base in the target) at index ``p``.

    Counts case-insensitive mismatches between ``guide[i]`` and ``seq[i + 1]``
    for ``i`` from ``p`` to ``len(guide) - 2``, i.e. as if the target base at
    ``p`` were skipped. The same index is used for both strings, so they are
    expected to be column-aligned.

    Returns ``p`` if that count is at most both ``max_MM`` and ``count_miss``,
    otherwise ``-1``.
    """
    shifted_mm = 0
    for i in range(p, len(guide) - 1):
        if guide[i].upper() != seq[i + 1].upper():
            shifted_mm += 1
    if shifted_mm > max_MM or shifted_mm > count_miss:
        return -1
    return p



def compare_with_guide(G, all_paths, guide="GAGTCCGAGCAGAAGAAGAA", PAM="NGG", strand='+'):
    """Align ``guide`` against every candidate path and collect the hits.

    For each path, the node sequences are concatenated, PAM sites are located
    with search_PAM_positions, and count_mismatch scores the guide at each one.

    Args:
      G: graph whose nodes carry a ``"seq"`` attribute.
      all_paths: output of paths_on_strand (a list of path lists per PAM node).
      guide: guide sequence. The default is the previously hard-coded guide.
        NOTE: search_PAM_positions is called with its default protospacer
        length, so positions only line up for a 20-nt guide.
      PAM: PAM string (default "NGG").
      strand: label stored in every result. The default '+' matches the old
        hard-coded value; presumably pass '-' for paths taken from the
        negative-strand graph.

    Returns:
      One entry per path that produced at least one hit. Each entry is the
      non-empty list returned by count_mismatch for that path.
    """
    final_results = []
    for paths in all_paths:  # list of paths for one PAM node
        for path in paths:  # list of nodes
            seq_nodes = [G.nodes[node]["seq"] for node in path]
            seq = "".join(seq_nodes)
            positions = search_PAM_positions(seq, PAM)
            result = count_mismatch(guide, seq, positions, strand, path, seq_nodes, PAM)
            if result:
                final_results.append(result)
    return final_results