In [1]:
import os
import re
import math
import itertools
import fileinput

import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

from Bio.Seq import Seq
from collections import defaultdict

In [2]:
def get_node_id(sample_id, contig_id):
    '''
    Build the graph node identifier "<sample_id>:<contig_id>".

    The sample prefix distinguishes contigs that share a name but come
    from different samples.
    '''
    return "{}:{}".format(sample_id, contig_id)

def node_coverage(gfa_arguments, sequence_len):
    """
    Extract a node's coverage from its GFA optional tags.

    Tags are scanned left to right, and the first recognised one wins:
      - ``dp:f:<value>``: the value is used directly as coverage;
      - ``KC:i:<value>``: the k-mer count is divided by ``sequence_len``.

    Returns:
        (coverage, is_dp): coverage as a float, and True if it came from a
        dp tag or False if it came from a KC tag.

    Raises:
        AssertionError: if no dp or KC tag is present.
    """
    dp_pattern = re.compile(r'^dp:f:(.*)$')
    kc_pattern = re.compile(r'^KC:i:(.*)$')

    for tag in gfa_arguments:
        found = dp_pattern.match(tag)
        if found is not None:
            return float(found.group(1)), True

        found = kc_pattern.match(tag)
        if found is not None:
            return float(found.group(1)) / sequence_len, False

    raise AssertionError("Depth not found")

def gc_content(sequence):
    """
    Return the GC fraction of a DNA sequence, rounded to 4 decimals.

    Only A/C/G/T characters count towards the denominator, so ambiguous
    bases such as N are ignored. Returns 0.5 when the sequence has no
    A/C/G/T bases at all.
    """
    gc_count = 0
    acgt_count = 0
    for base in sequence:
        if base in 'GC':
            gc_count += 1
        if base in 'ACGT':
            acgt_count += 1

    if acgt_count == 0:
        return 0.5
    return round(gc_count / acgt_count, 4)

def kmer_distribution(sequence, kmer_len=5, scale=False):
    """
    Calculate the canonical (strand-merged) k-mer distribution of a sequence.

    Each k-mer is merged with its reverse complement, giving 4**kmer_len / 2
    entries. They are ordered by the first member of each pair in
    lexicographic ACGT order. Every k-mer starts with a pseudocount of 0.01.
    Only k-mers made entirely of uppercase A/C/G/T are counted, so windows
    containing other characters (e.g. "N" separators) are skipped.

    Args:
        sequence: DNA string (expected uppercase).
        kmer_len: k-mer length. It must be odd, so that no k-mer equals its
            own reverse complement.
        scale: if True, normalise the distribution so it sums to 1.

    Returns:
        List of floats, one per canonical k-mer.

    Raises:
        AssertionError: if kmer_len is even.
    """
    assert kmer_len % 2 == 1, "K-mer length should be odd."

    complement = str.maketrans("ACGT", "TGCA")

    def reverse_complement(kmer):
        return kmer.translate(complement)[::-1]

    k_mers = ["".join(x) for x in itertools.product("ACGT", repeat=kmer_len)]

    # Keep the first k-mer of each {k-mer, reverse complement} pair.
    forward_kmers = []
    seen = set()
    for k_mer in k_mers:
        if k_mer not in seen:
            forward_kmers.append(k_mer)
            seen.add(k_mer)
            seen.add(reverse_complement(k_mer))

    # Pre-populate every valid k-mer with its pseudocount. This must be a
    # filled dict rather than a defaultdict: the membership test below has to
    # succeed for every ACGT k-mer, and a defaultdict only gains keys when read.
    kmer_count_dict = dict.fromkeys(k_mers, 0.01)

    for i in range(len(sequence) - kmer_len + 1):
        kmer = sequence[i:i + kmer_len]
        if kmer in kmer_count_dict:  # skips windows containing N or other symbols
            kmer_count_dict[kmer] += 1

    k_mer_distribution = [
        kmer_count_dict[k_mer] + kmer_count_dict[reverse_complement(k_mer)]
        for k_mer in forward_kmers
    ]

    if scale:
        # The total is always > 0 thanks to the pseudocounts.
        total_count = sum(k_mer_distribution)
        k_mer_distribution = [count / total_count for count in k_mer_distribution]

    return k_mer_distribution

def weighted_median(values, weights):
    """
    Return the weighted median of ``values``.

    ``values`` must already be sorted in ascending order, with ``weights``
    aligned to them. The result is the first value whose cumulative weight
    reaches half of the total weight.

    Raises:
        AssertionError: if no such value exists (e.g. empty input or NaN weights).
    """
    middle = np.sum(weights) / 2
    cumulative = np.cumsum(weights)

    for index, running_total in enumerate(cumulative):
        if running_total >= middle:
            return values[index]

    # Explicit raise instead of `assert False`: asserts are stripped under
    # `python -O`, which would make this function silently return None.
    raise AssertionError("Weighted median not found: empty or invalid weights")

def add_normalized_coverage(graph, node_ids):
    """
    Store ``coverage_norm`` on every node listed in ``node_ids``.

    ``coverage_norm`` is the node's ``coverage`` divided by the median
    coverage of the same nodes, weighted by their ``length`` attribute.
    """
    nodes = graph.nodes

    # weighted_median expects values in ascending order.
    by_coverage = sorted(node_ids, key=lambda n: nodes[n]["coverage"])
    median = weighted_median(
        np.array([nodes[n]["coverage"] for n in by_coverage]),
        np.array([nodes[n]["length"] for n in by_coverage]),
    )

    for node_id in node_ids:
        attrs = nodes[node_id]
        attrs["coverage_norm"] = attrs["coverage"] / median

def KL(a, b):
    """
    Kullback-Leibler divergence D(a || b) = sum_i a_i * log(a_i / b_i).

    Inputs are converted to float arrays. Terms with a_i == 0 contribute 0,
    following the 0 * log 0 = 0 convention.
    """
    p = np.asarray(a, dtype=float)
    q = np.asarray(b, dtype=float)

    terms = p * np.log(p / q)
    return np.sum(np.where(p != 0, terms, 0))

def label_to_pair(label):
    """
    Convert a text label into a [plasmid, chromosome] pair of 0/1 indicators.

    "chromosome" -> [0, 1], "plasmid" -> [1, 0], "ambiguous" -> [1, 1], and
    "unlabeled" or None -> [0, 0]. A new list is returned on every call.

    Raises:
        AssertionError: if the label is not one of the above.
    """
    pairs_by_label = {
        "chromosome": [0, 1],
        "plasmid": [1, 0],
        "ambiguous": [1, 1],
        "unlabeled": [0, 0],
        None: [0, 0],
    }

    pair = pairs_by_label.get(label)
    if pair is None:
        raise AssertionError(f"Unrecognized label: {label}")
    return pair

def pair_to_label(pair):
    """
    Map a [plasmid, chromosome] indicator list back to its text label.

    This reverses label_to_pair's mapping: [0, 1] -> "chromosome",
    [1, 0] -> "plasmid", [1, 1] -> "ambiguous", [0, 0] -> "unlabeled".

    Raises:
        AssertionError: if ``pair`` does not equal one of these four lists.
    """
    known_pairs = (
        ([0, 1], "chromosome"),
        ([1, 0], "plasmid"),
        ([1, 1], "ambiguous"),
        ([0, 0], "unlabeled"),
    )

    for candidate, label in known_pairs:
        if pair == candidate:
            return label

    raise AssertionError(f"Unrecognized pair: {pair}")

In [3]:
def read_graph(graph_file, csv_file, sample_id, graph, minimum_contig_len):
    """
    Load one sample's assembly graph from a GFA file into ``graph``.

    Every segment ("S") line adds a node ``get_node_id(sample_id, contig)``
    carrying sequence-derived attributes (length, coverage, GC, k-mer profile)
    and their sample-normalised variants. Every link ("L") line adds an edge.
    Labels are read from ``csv_file`` when it is given. Finally, contigs
    shorter than ``minimum_contig_len`` are removed after their neighbours
    have been connected pairwise. Node and edge counts are printed before and
    after that filtering step.

    Args:
        graph_file: path to the GFA file (opened with
            ``fileinput.hook_compressed``, so compressed input is accepted).
        csv_file: path to a CSV with ``contig`` and ``label`` columns, or None.
        sample_id: prefix used to build node ids.
        graph: graph object (e.g. ``nx.Graph``), modified in place.
        minimum_contig_len: contigs shorter than this are removed;
            0 disables filtering.

    Raises:
        AssertionError: on a malformed sequence, a missing depth tag, an
            unrecognised label, or when one file mixes dp and KC coverage tags.
    """
    node_ids, sequences = _parse_gfa(graph_file, sample_id, graph)
    _add_sample_features(graph, node_ids, sequences)
    _add_labels(graph, node_ids, csv_file, sample_id)

    _print_graph_size(graph)
    if minimum_contig_len > 0:
        _remove_short_contigs(graph, node_ids, minimum_contig_len)
    _print_graph_size(graph)


def _parse_gfa(graph_file, sample_id, graph):
    """Add one node per GFA S line and one edge per L line; return (node_ids, joined sequence)."""
    node_ids = []
    sequence_parts = []  # joined once at the end instead of repeated string concatenation
    coverage_types = {True: 0, False: 0}  # number of nodes with dp (True) / KC (False) coverage

    with fileinput.input(graph_file, openhook=fileinput.hook_compressed, mode='r') as file:
        for line in file:
            if isinstance(line, bytes):
                line = line.decode("utf-8")  # compressed input may be read as bytes

            parts = line.strip().split("\t")

            if parts[0] == "S":  # segment line: S <name> <sequence> [tags...]
                node_id = get_node_id(sample_id, parts[1])
                node_ids.append(node_id)

                sequence = parts[2].upper()
                if not re.match(r'^[A-Z]*$', sequence):
                    raise AssertionError(f"Bad sequence in {node_id}")

                # N is ignored by GC and helps avoid fake k-mers spanning two contigs.
                sequence_parts.append("N" + sequence)

                graph.add_node(node_id)
                sequence_len = len(sequence)
                coverage, is_dp = node_coverage(parts[3:], sequence_len)

                node = graph.nodes[node_id]
                node["contig"] = parts[1]
                node["sample"] = sample_id
                node["length"] = sequence_len
                node["coverage"] = coverage
                node["gc"] = gc_content(sequence)
                node["kmer_counts_norm"] = kmer_distribution(sequence, scale=True)

                coverage_types[is_dp] += 1

            elif parts[0] == "L":  # link line: L <from> <orient> <to> <orient> ...
                graph.add_edge(get_node_id(sample_id, parts[1]), get_node_id(sample_id, parts[3]))

    # Coverage values are only comparable if all nodes use the same tag type.
    if coverage_types[True] and coverage_types[False]:
        raise AssertionError(f"Mixed dp and KC coverage tags in {graph_file}")

    return node_ids, "".join(sequence_parts)


def _add_sample_features(graph, node_ids, sequences):
    """Add degree plus GC, length, coverage and k-mer features normalised against the whole sample."""
    gc_of_whole_seq = gc_content(sequences)

    for node_id in node_ids:
        node = graph.nodes[node_id]
        node["degree"] = graph.degree[node_id]
        node["gc_norm"] = node["gc"] - gc_of_whole_seq
        node["length_norm"] = node["length"] / 2000000  # fixed scale, not the sample's max length
        node["loglength"] = math.log(node["length"] + 1)

    add_normalized_coverage(graph, node_ids)

    # Compare each node's k-mer profile with the profile of the whole sample.
    all_kmer_counts_norm = np.array(kmer_distribution(sequences, scale=True))
    for node_id in node_ids:
        node = graph.nodes[node_id]
        node_kmers = np.array(node["kmer_counts_norm"])
        node["kmer_dist"] = np.linalg.norm(node_kmers - all_kmer_counts_norm)
        node["kmer_dot"] = np.dot(node_kmers, all_kmer_counts_norm)
        node["kmer_kl"] = KL(node_kmers, all_kmer_counts_norm)


def _add_labels(graph, node_ids, csv_file, sample_id):
    """Set text_label / plasmid_label / chrom_label from the labels CSV; missing nodes are unlabeled."""
    if csv_file is not None:
        labels_df = pd.read_csv(csv_file)
        labels_df["id"] = labels_df["contig"].map(lambda x: get_node_id(sample_id, x))
        labels_df = labels_df.set_index("id")
    else:
        labels_df = pd.DataFrame()

    for node_id in node_ids:
        label = labels_df.loc[node_id, "label"] if node_id in labels_df.index else None
        pair = label_to_pair(label)  # [plasmid, chromosome] indicators

        node = graph.nodes[node_id]
        node["text_label"] = pair_to_label(pair)
        node["plasmid_label"] = pair[0]
        node["chrom_label"] = pair[1]


def _print_graph_size(graph):
    """Print the current number of nodes and edges in the graph."""
    print("Number of nodes:", graph.number_of_nodes())
    print("Number of edges:", graph.number_of_edges())


def _remove_short_contigs(graph, node_ids, minimum_contig_len):
    """Remove nodes shorter than minimum_contig_len after connecting all of their neighbours pairwise."""
    for node_id in node_ids:
        if graph.nodes[node_id]["length"] < minimum_contig_len:
            print(node_id)
            neighbors = list(graph.neighbors(node_id))
            graph.add_edges_from(itertools.combinations(neighbors, 2))
            graph.remove_node(node_id)

In [4]:
graph = nx.Graph()

# Example inputs. Set PLASGRAPH_EXAMPLE_DIR to use a different checkout
# instead of the hardcoded workspace path.
EXAMPLE_DIR = os.environ.get(
    "PLASGRAPH_EXAMPLE_DIR", "/workspaces/panplasmid/models/plasgraph2/example"
)
graph_file = os.path.join(EXAMPLE_DIR, 'SAMN15148288_SKESA.gfa.gz')
csv_file = os.path.join(EXAMPLE_DIR, 'SAMN15148288_output.csv')
sample_id = 'sample'

read_graph(graph_file=graph_file, csv_file=csv_file, sample_id=sample_id, graph=graph, minimum_contig_len=0)

Number of nodes: 223
Number of edges: 265
Number of nodes: 223
Number of edges: 265
