In [1]:
# Imports
import umap
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
import plotly.express as px
import warnings
from plotly import graph_objects as go

# Clustering Analysis using Hierarchical Clustering
from sklearn.preprocessing import StandardScaler
from scipy.cluster.hierarchy import linkage, dendrogram
from brancharchitect.io import read_newick
from brancharchitect.split_analysis import (
    compute_taxon_co_occurrence_in_filtered_nonexistent_splits,
)

warnings.filterwarnings("ignore")  # To suppress any warnings for clean output


# Statistical Significance Testing
# Assuming you have a function for permutation testing (defined earlier)
def plot_umap_2d_interactive(
    df: pd.DataFrame,
    title: str,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    random_state: int = 42,
    width: int = 800,
    height: int = 600,
    use_precomputed_distances: bool = True,
) -> Tuple[pd.DataFrame, go.Figure]:
    """
    Computes and plots interactive 2D UMAP with plotly.

    Args:
        df: Input DataFrame with features
        title: Plot title
        n_neighbors: UMAP neighbor parameter
        min_dist: UMAP minimum distance parameter
        random_state: Random seed
        width: Plot width
        height: Plot height
        use_precomputed_distances: Whether to use precomputed distances (1 - co_occurrence) for UMAP

    Returns:
        Tuple of UMAP DataFrame and plotly figure
    """
    # Apply UMAP
    umap_df = apply_umap(
        df,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=2,
        random_state=random_state,
        use_precomputed_distances=use_precomputed_distances,
    )

    # Create interactive plot
    fig = px.scatter(
        umap_df,
        x="UMAP_1",
        y="UMAP_2",
        title=title,
        labels={"UMAP_1": "UMAP_1", "UMAP_2": "UMAP_2"},
        hover_data=[umap_df.index],
    )

    fig.update_layout(width=width, height=height, title_x=0.5, showlegend=True)

    return umap_df, fig


def permutation_test_co_clustering(
    co_occurrence_matrix: pd.DataFrame, num_permutations: int = 1000
) -> pd.DataFrame:
    """
    Performs a permutation test to assess the significance of co-occurrence frequencies.

    Args:
    - co_occurrence_matrix (pd.DataFrame): Co-occurrence frequency matrix.
    - num_permutations (int): Number of permutations.

    Returns:
    - pd.DataFrame: P-value matrix for co-occurrence frequencies.
    """
    taxa = co_occurrence_matrix.index.tolist()
    observed = co_occurrence_matrix.values
    permuted_counts = np.zeros_like(observed)

    for _ in range(num_permutations):
        permuted_indices = np.random.permutation(len(taxa))
        permuted = observed[permuted_indices, :][:, permuted_indices]
        permuted_counts += permuted >= observed

    p_values = permuted_counts / num_permutations
    p_values_df = pd.DataFrame(p_values, index=taxa, columns=taxa)
    return p_values_df

# Function Definitions
def plot_umap_2d_static(
    df: pd.DataFrame,
    title: str,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    random_state: int = 42,
    figsize: tuple = (12, 8),
) -> Tuple[pd.DataFrame, plt.Figure]:
    """
    Computes and plots 2D UMAP with matplotlib.

    Args:
        df: Input DataFrame with features
        title: Plot title
        n_neighbors: UMAP neighbor parameter
        min_dist: UMAP minimum distance parameter
        random_state: Random seed
        figsize: Figure size tuple

    Returns:
        Tuple of UMAP DataFrame and matplotlib figure
    """
    # Compute 2D UMAP
    reducer = umap.UMAP(
        n_components=2,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=random_state,
    )

    umap_data = reducer.fit_transform(df)
    umap_df = pd.DataFrame(umap_data, columns=["UMAP1", "UMAP2"], index=df.index)

    # Create static plot
    fig, ax = plt.subplots(figsize=figsize)
    scatter = ax.scatter(
        umap_df["UMAP1"], umap_df["UMAP2"], c=range(len(df)), cmap="viridis", alpha=0.6
    )

    plt.colorbar(scatter)
    ax.set_title(title)
    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")

    return umap_df, fig


def preprocess_cooccurrence_df(
    co_occurrence_freq: Dict[str, Dict[str, float]]
) -> pd.DataFrame:
    """
    Preprocesses the co-occurrence frequency dictionary into a symmetric DataFrame.

    Args:
    - co_occurrence_freq (Dict[str, Dict[str, float]]): Co-occurrence frequencies between taxa.

    Returns:
    - pd.DataFrame: Symmetric co-occurrence frequency DataFrame.
    """
    df = pd.DataFrame(co_occurrence_freq)

    # Ensure the DataFrame is symmetric
    df = (df + df.T) / 2

    # Replace NaN with 0 (if any)
    df = df.fillna(0)

    # Set diagonal to 1.0 to represent perfect self-co-occurrence
    np.fill_diagonal(df.values, 1.0)

    return df


def plot_interactive_heatmap(
    co_occurrence_freq: Dict[str, Dict[str, float]], title: str
):
    """
    Plots an interactive heatmap of co-occurrence frequencies between taxa using Plotly.

    Args:
    - co_occurrence_freq (Dict[str, Dict[str, float]]): Co-occurrence frequencies between taxa.
    - title (str): Title of the heatmap.
    """
    # Preprocess the co-occurrence DataFrame
    df = preprocess_cooccurrence_df(co_occurrence_freq)

    # Create an interactive heatmap using Plotly
    fig = px.imshow(
        df,
        labels=dict(x="Taxa", y="Taxa", color="Co-occurrence Frequency"),
        x=df.columns,
        y=df.index,
        color_continuous_scale="YlGnBu",
        title=title,
    )

    # Update layout for better appearance
    fig.update_layout(
        width=800,
        height=800,
        xaxis_nticks=len(df.columns),
        yaxis_nticks=len(df.index),
    )

    # Show the figure
    fig.show()


def extract_g_group(taxon_name: str) -> str:
    """
    Extracts the G name combination from a taxon name.

    Args:
    - taxon_name (str): The full taxon name (e.g., "GII.P7.GII.6.KR074148").

    Returns:
    - str: The G name combination (e.g., "GII.P7.GII.6").
    """
    parts = taxon_name.split(".")
    if len(parts) > 1:
        # Remove the last part (assumed to be the unique identifier)
        g_group = ".".join(parts[:-1])
    else:
        # If splitting doesn't work, return the full name
        g_group = taxon_name
    return g_group


def get_g_groups(taxa: List[str]) -> List[str]:
    """
    Extracts G name combinations for a list of taxa.

    Args:
    - taxa (List[str]): List of taxon names.

    Returns:
    - List[str]: List of G name combinations.
    """
    return [extract_g_group(taxon) for taxon in taxa]


def convert_cooccurrence_to_dataframe(
    co_occurrence_freq: Dict[str, Dict[str, float]]
) -> pd.DataFrame:
    """
    Converts the nested co-occurrence frequency dictionary to a Pandas DataFrame.

    Args:
    - co_occurrence_freq (Dict[str, Dict[str, float]]): Co-occurrence frequencies between taxa.

    Returns:
    - pd.DataFrame: Co-occurrence frequency matrix.
    """
    df = preprocess_cooccurrence_df(co_occurrence_freq)
    return df


def apply_umap(
    df: pd.DataFrame,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    n_components: int = 2,
    random_state: int = 42,
    use_precomputed_distances: bool = True,
    distance_transform: str = 'negative_log',  # Options: 'negative_log', 'inverse', 'power', 'subtract'
    power: float = 1.0,  # Used if distance_transform is 'power'
) -> pd.DataFrame:
    """
    Applies UMAP to the co-occurrence frequency DataFrame.

    Args:
    - df (pd.DataFrame): Co-occurrence frequency matrix.
    - n_neighbors (int): The size of local neighborhood used for manifold approximation.
    - min_dist (float): The effective minimum distance between embedded points.
    - n_components (int): The dimension of the space to embed into.
    - random_state (int): Random seed for reproducibility.
    - use_precomputed_distances (bool): Whether to use precomputed distances for UMAP.
    - distance_transform (str): Method to transform co-occurrence into distances.
    - power (float): Power for the 'power' transformation.

    Returns:
    - pd.DataFrame: UMAP embedding with taxa names as index.
    """
    if use_precomputed_distances:
        co_occurrence_values = df.values.copy()

        if distance_transform == 'negative_log':
            # Replace zeros to avoid log(0)
            epsilon = 1e-10
            co_occurrence_values[co_occurrence_values <= epsilon] = epsilon
            # Distance = -log(co_occurrence)
            distance_matrix = -np.log(co_occurrence_values)
        elif distance_transform == 'inverse':
            # Replace zeros to avoid division by zero
            epsilon = 1e-10
            co_occurrence_values[co_occurrence_values <= epsilon] = epsilon
            # Distance = 1 / co_occurrence
            distance_matrix = 1.0 / co_occurrence_values
        elif distance_transform == 'power':
            # Distance = co_occurrence ** (-power)
            # Replace zeros to avoid division by zero
            epsilon = 1e-10
            co_occurrence_values[co_occurrence_values <= epsilon] = epsilon
            distance_matrix = co_occurrence_values ** (-power)
        elif distance_transform == 'subtract':
            # Distance = max_co_occurrence - co_occurrence
            max_co_occurrence = co_occurrence_values.max()
            distance_matrix = max_co_occurrence - co_occurrence_values
        else:
            raise ValueError(f"Unknown distance_transform: {distance_transform}")

        # Ensure the distance matrix is symmetric
        distance_matrix = (distance_matrix + distance_matrix.T) / 2

        # Set diagonal to zero to represent zero distance to self
        np.fill_diagonal(distance_matrix, 0.0)

        # Initialize UMAP with precomputed metric
        reducer = umap.UMAP(
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            n_components=n_components,
            metric="precomputed",
            random_state=random_state,
        )

        # Fit and transform the data
        embedding = reducer.fit_transform(distance_matrix)
    else:
        # Optionally, scale the data
        scaler = StandardScaler()
        data = scaler.fit_transform(df)

        # Initialize UMAP
        reducer = umap.UMAP(
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            n_components=n_components,
            metric="euclidean",
            random_state=random_state,
        )

        # Fit and transform the data
        embedding = reducer.fit_transform(data)

    # Create a DataFrame for the embedding
    umap_df = pd.DataFrame(
        embedding, index=df.index, columns=[f"UMAP_{i+1}" for i in range(n_components)]
    )

    return umap_df


def plot_embedding(
    embedding_df: pd.DataFrame, g_groups: List[str], method_name: str, title: str
):
    """
    Plots the embedding, coloring taxa based on their G name combinations.

    Args:
    - embedding_df (pd.DataFrame): Embedding DataFrame.
    - g_groups (List[str]): List of G name combinations corresponding to taxa.
    - method_name (str): The name of the dimensionality reduction method.
    - title (str): Title of the plot.
    """
    # Add G group information to the embedding DataFrame
    plot_df = embedding_df.copy()
    plot_df["G_Group"] = g_groups

    plt.figure(figsize=(14, 12))
    sns.scatterplot(
        x=plot_df.columns[0],
        y=plot_df.columns[1],
        hue="G_Group",
        palette="tab20",
        data=plot_df,
        s=100,
        alpha=0.7,
    )
    plt.title(f"{title} ({method_name})", fontsize=18)
    plt.xlabel(f"{method_name} Dimension 1")
    plt.ylabel(f"{method_name} Dimension 2")
    plt.legend(title="G Group", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()


def plot_interactive_embedding_3d(
    embedding_df: pd.DataFrame, g_groups: List[str], method_name: str, title: str
):
    """
    Plots an interactive 3D embedding, coloring taxa based on their G name combinations.

    Args:
    - embedding_df (pd.DataFrame): 3D embedding DataFrame.
    - g_groups (List[str]): List of G name combinations corresponding to taxa.
    - method_name (str): The name of the dimensionality reduction method.
    - title (str): Title of the plot.
    """
    # Add G group and Taxon information to the embedding DataFrame
    plot_df = embedding_df.copy()
    plot_df["G_Group"] = g_groups
    plot_df["Taxon"] = plot_df.index

    # Create an interactive 3D scatter plot
    fig = px.scatter_3d(
        plot_df,
        x=plot_df.columns[0],
        y=plot_df.columns[1],
        z=plot_df.columns[2],
        color="G_Group",
        hover_data=["Taxon"],
        title=f"{title} ({method_name})",
        labels={
            plot_df.columns[0]: f"{method_name} Dimension 1",
            plot_df.columns[1]: f"{method_name} Dimension 2",
            plot_df.columns[2]: f"{method_name} Dimension 3",
            "G_Group": "G Group",
        },
        opacity=0.8,
        width=1000,
        height=800,
    )

    # Update layout for better aesthetics
    fig.update_layout(
        legend_title_text="G Group",
        title_font_size=20,
        scene=dict(
            xaxis_title=f"{method_name} Dimension 1",
            yaxis_title=f"{method_name} Dimension 2",
            zaxis_title=f"{method_name} Dimension 3",
            camera=dict(eye=dict(x=1.25, y=1.25, z=1.25)),
        ),
    )

    fig.show()

In [3]:

# ===========================
# Reading Trees and Computing Co-occurrence
# ===========================

# Read trees from a Newick file
# Note: Adjust the file path as needed
trees = read_newick(
    "./../../sliding-window/output_norovirus_window_size_200_5/best_trees.newick"
)

# trees = read_newick(
#     "./../../sliding-window/output_norovirus_window_size_200_5/best_trees.newick"
#)

# Compute co-occurrence frequencies using all pairs
co_occurrence_freq_all_pairs = (
    compute_taxon_co_occurrence_in_filtered_nonexistent_splits(
        list_of_trees=trees, to_filter=True
    )
)

# Convert co-occurrence frequencies to DataFrame
co_occurrence_df_all_pairs = convert_cooccurrence_to_dataframe(
    co_occurrence_freq_all_pairs
)

# Plot the heatmap for all pairs
plot_interactive_heatmap(
    co_occurrence_freq_all_pairs, title="Taxon Co-occurrence Heatmap (All Pairs)"
)

# ===========================
# Dimensionality Reduction Analysis
# ===========================

# Extract G Groups
taxa = co_occurrence_df_all_pairs.index.tolist()
g_groups = get_g_groups(taxa)

# Perform permutation test
# p_values_df = permutation_test_co_clustering(co_occurrence_df_all_pairs, num_permutations=1000)
# Plot heatmap of p-values
# plt.figure(figsize=(12, 10))
# sns.heatmap(p_values_df, cmap='coolwarm_r', square=True)
# plt.title("P-values Heatmap of Co-occurrence Frequencies")
# plt.show()


In [134]:
# Apply UMAP for 3D embedding
umap_embedding_3d = apply_umap(
    co_occurrence_df_all_pairs,
    n_neighbors=5,
    min_dist=0.1,
    n_components=3,
    random_state=42,
    use_precomputed_distances=True,
    distance_transform='subtract',
)

# Plot UMAP 3D embedding
plot_interactive_embedding_3d(
    umap_embedding_3d,
    g_groups,
    method_name="UMAP",
    title="Taxon Co-occurrence Embedding",
)


# Apply UMAP for 2D embedding
umap_embedding_2d = apply_umap(
    co_occurrence_df_all_pairs,
    n_neighbors=8,
    min_dist=0.1,
    n_components=2,
    random_state=42,
    use_precomputed_distances=True,
    distance_transform='subtract',    
)

# Define the function to plot 2D UMAP with Grouped Taxa
def plot_interactive_embedding_2d(embedding, groups, title, method_name):
    embedding.columns = ['UMAP_1', 'UMAP_2']
    fig = px.scatter(
        embedding, x='UMAP_1', y='UMAP_2', color=groups,
        title=f"{title} ({method_name})",
        labels={'UMAP_1': 'UMAP 1', 'UMAP_2': 'UMAP 2'}
    )
    # for i, txt in enumerate(embedding.index):
    #     fig.add_annotation(
    #         x=embedding.iloc[i, 0], y=embedding.iloc[i, 1],
    #         text=txt, showarrow=False, yshift=10
    #     )
    fig.show()

# Plot 2D UMAP with Grouped Taxa
plot_interactive_embedding_2d(
    umap_embedding_2d,
    g_groups,
    title="Taxon Co-occurrence UMAP Projection (2D) Grouped by G Name Combinations",
    method_name="UMAP",
)



In [2]:
# ===========================
# Enhanced Analysis Using NMI and Weighted Co-clustering
# ===========================

# Imports
import matplotlib.pyplot as plt
from brancharchitect.io import read_newick
from brancharchitect.co_clustering_frequencies import compute_weighted_co_clustering, compute_co_clustering_nmi
import warnings

# Suppress warnings for clean output
warnings.filterwarnings('ignore')

# ===========================
# Reading Trees and Computing Co-occurrence
# ===========================

# Read trees from a Newick file
# Adjust the file path as needed
trees = read_newick(
    "./../data/five_taxa_all_permutations.newick"
)

# Compute co-occurrence frequencies using weighted co-clustering
co_occurrence_freq_weighted = compute_weighted_co_clustering(trees)

# Compute co-occurrence frequencies using NMI
co_occurrence_freq_nmi = compute_co_clustering_nmi(trees)

# Convert co-occurrence frequencies to DataFrames
co_occurrence_df_weighted = convert_cooccurrence_to_dataframe(co_occurrence_freq_weighted)

# ===========================
# Analysis Using Weighted Co-clustering
# ===========================

# Plot the heatmap for weighted co-occurrence
plot_interactive_heatmap(
    co_occurrence_freq_weighted, title="Taxon Co-occurrence Heatmap (Weighted Co-clustering)"
)

# Extract G Groups
taxa = co_occurrence_df_weighted.index.tolist()
g_groups = get_g_groups(taxa)

# Compute linkage matrix for weighted co-occurrence
linkage_matrix_weighted = linkage(co_occurrence_df_weighted, method='ward')

# Plot dendrogram for weighted co-occurrence
plt.figure(figsize=(12, 8))
dendrogram(
    linkage_matrix_weighted,
    labels=co_occurrence_df_weighted.index,
    leaf_rotation=90,
)
plt.title("Hierarchical Clustering Dendrogram (Weighted Co-clustering)")
plt.xlabel("Taxa")
plt.ylabel("Distance")
plt.tight_layout()
plt.show()

# Perform permutation test for weighted co-occurrence
p_values_df_weighted = permutation_test_co_clustering(co_occurrence_df_weighted, num_permutations=1000)

# Plot heatmap of p-values for weighted co-occurrence
plt.figure(figsize=(12, 10))
sns.heatmap(p_values_df_weighted, cmap='coolwarm_r', square=True)
plt.title("P-values Heatmap of Co-occurrence Frequencies (Weighted Co-clustering)")
plt.show()

# ===========================
# UMAP Analysis
# ===========================

# Apply UMAP for 3D embedding
umap_embedding_3d = apply_umap(
    co_occurrence_df_weighted,
    n_neighbors=12,
    min_dist=0.1,
    n_components=3,
    random_state=42,
)

# Plot 3D UMAP with Grouped Taxa
plot_interactive_embedding_3d(
    umap_embedding_3d,
    g_groups,
    title="Taxon Co-occurrence UMAP Projection (3D) Grouped by G Name Combinations",
    method_name="UMAP",    
)


FileNotFoundError: [Errno 2] No such file or directory: './../data/five_taxa_all_permutations.newick'

In [1]:
# Imports
import networkx as nx
import matplotlib.pyplot as plt
from itertools import combinations
from typing import List, Dict, Tuple, Set

# Import the Node class from brancharchitect.tree
# Ensure this import matches your environment
from brancharchitect.tree import Node

# Function Definitions

def extract_splits(tree: Node) -> Set[Tuple[int]]:
    """
    Extract splits from a tree.

    Args:
    - tree (Node): The tree from which to extract splits.

    Returns:
    - Set[Tuple[int]]: A set of splits, where each split is represented as a tuple of indices.
    """
    # Assume tree.to_splits() returns a dictionary with split indices as keys
    return set(tree.to_splits().keys())

def get_taxa_indices(tree: Node) -> Tuple[Dict[int, str], Dict[str, int]]:
    """
    Create mappings between taxa indices and names.

    Args:
    - tree (Node): The tree from which to extract taxa.

    Returns:
    - Tuple[Dict[int, str], Dict[str, int]]: Two dictionaries for index-to-taxon and taxon-to-index mappings.
    """
    index_to_taxon = {}
    taxon_to_index = {}
    for leaf in tree.get_leaves():
        index = leaf.split_indices[0]  # Assuming split_indices is a tuple with one element
        name = leaf.name
        index_to_taxon[index] = name
        taxon_to_index[name] = index
    return index_to_taxon, taxon_to_index

def filter_minimal_splits(splits: List[Set[int]]) -> List[Set[int]]:
    """
    Filters the list of splits to only include minimal unique splits,
    i.e., splits that are not subsets of any other split in the list.

    Args:
    - splits (List[Set[int]]): List of splits represented as sets of taxon indices.

    Returns:
    - List[Set[int]]: Filtered list of minimal unique splits.
    """
    filtered_splits = []
    for split in splits:
        if not any(
            split < other_split for other_split in splits if split != other_split
        ):
            filtered_splits.append(split)
    return filtered_splits

def compute_taxon_co_occurrence_in_filtered_nonexistent_splits_all_pairs(
    list_of_trees: List[Node],
    to_filter: bool = True,
) -> Dict[str, Dict[str, float]]:
    """
    Computes co-occurrence frequencies between taxa based on filtered non-existent splits across all pairs of trees.

    Args:
    - list_of_trees (List[Node]): A list of trees to analyze.
    - to_filter (bool): Whether to filter splits to minimal unique splits.

    Returns:
    - Dict[str, Dict[str, float]]: Co-occurrence frequencies between taxa.
    """
    # Ensure there are at least two trees to compare
    if len(list_of_trees) < 2:
        return {}

    # Get the list of taxa and create indices
    taxa = list_of_trees[0]._order
    taxa_indices = {taxon: idx for idx, taxon in enumerate(taxa)}
    num_taxa = len(taxa_indices)

    # Initialize co-occurrence counts matrix
    co_occurrence_counts = np.zeros((num_taxa, num_taxa), dtype=int)
    total_filtered_splits = 0

    # Create a mapping from indices to taxon names
    index_to_taxon = {idx: taxon for taxon, idx in taxa_indices.items()}

    # Iterate over all pairs of trees
    tree_pairs = list(combinations(list_of_trees, 2))

    for tree_one, tree_two in tree_pairs:
        splits_one = set(tree_one.to_splits().keys())
        splits_two = set(tree_two.to_splits().keys())

        # Identify unique splits in both trees
        unique_splits_one = splits_one - splits_two
        unique_splits_two = splits_two - splits_one

        # Combine unique splits
        all_unique_splits = unique_splits_one.union(unique_splits_two)

        # Map split indices to taxon names
        splits_taxa = [
            set(index_to_taxon.get(idx) for idx in split if idx in index_to_taxon)
            for split in all_unique_splits
        ]

        if to_filter:
            # Filter minimal unique splits
            filtered_splits = filter_minimal_splits(splits_taxa)
        else:
            filtered_splits = splits_taxa

        total_filtered_splits += len(filtered_splits)

        for split in filtered_splits:
            taxa_in_split = list(split)
            # For each pair of taxa in the split, increment the co-occurrence count
            for taxon1, taxon2 in combinations(taxa_in_split, 2):
                idx1 = taxa_indices[taxon1]
                idx2 = taxa_indices[taxon2]
                co_occurrence_counts[idx1][idx2] += 1
                co_occurrence_counts[idx2][idx1] += 1  # Symmetric

    # Compute co-occurrence frequencies
    co_occurrence_freq = {}
    for i, taxon1 in enumerate(taxa):
        co_occurrence_freq[taxon1] = {}
        for j, taxon2 in enumerate(taxa):
            if i == j:
                co_occurrence_freq[taxon1][taxon2] = 1.0
            else:
                freq = (
                    co_occurrence_counts[i][j] / total_filtered_splits
                    if total_filtered_splits > 0
                    else 0.0
                )
                co_occurrence_freq[taxon1][taxon2] = freq

    return co_occurrence_freq

def create_network_from_co_occurrence(co_occurrence_freq: Dict[str, Dict[str, float]], threshold: float = 0.1) -> nx.Graph:
    """
    Creates a network from co-occurrence frequencies.

    Args:
    - co_occurrence_freq (Dict[str, Dict[str, float]]): Co-occurrence frequencies between taxa.
    - threshold (float): Minimum frequency to include an edge.

    Returns:
    - nx.Graph: A NetworkX graph with taxa as nodes and co-occurrence frequencies as edge weights.
    """
    G = nx.Graph()
    taxa = list(co_occurrence_freq.keys())
    G.add_nodes_from(taxa)

    for taxon1 in taxa:
        for taxon2, freq in co_occurrence_freq[taxon1].items():
            if taxon1 != taxon2 and freq >= threshold:
                G.add_edge(taxon1, taxon2, weight=freq)

    return G

def visualize_co_occurrence_network(G: nx.Graph, title: str = "Taxon Co-Occurrence Network"):
    """
    Visualizes the co-occurrence network using NetworkX and Matplotlib.

    Args:
    - G (nx.Graph): The co-occurrence network graph.
    - title (str): Title of the plot.
    """
    plt.figure(figsize=(12, 10))

    # Position nodes using spring layout
    pos = nx.spring_layout(G, k=0.5, seed=42)

    # Get edge weights for visualization
    weights = nx.get_edge_attributes(G, 'weight')
    max_weight = max(weights.values()) if weights else 1
    edge_widths = [weights[edge] / max_weight * 5 for edge in G.edges()]  # Scale weights for visibility

    # Draw nodes and edges
    nx.draw_networkx_nodes(G, pos, node_size=500, node_color='skyblue')
    nx.draw_networkx_edges(G, pos, width=edge_widths, alpha=0.7)
    nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')

    # Optionally, draw edge labels
    # edge_labels = {edge: f"{weights[edge]:.2f}" for edge in G.edges()}
    # nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)

    plt.title(title, fontsize=16)
    plt.axis('off')
    plt.show()

# ===============================
# Main Code Execution Starts Here
# ===============================

# Replace the following with your actual list of trees
# Each tree should be an instance of brancharchitect.tree.Node

# Example:
# list_of_trees = [tree1, tree2, tree3, ...]

# For demonstration purposes, let's assume you have the following list of trees
# Note: Replace this with your actual trees
list_of_trees = read_newick(
    "./../../sliding-window/output_norovirus_window_size_200_5/best_trees.newick"
)


# Ensure that list_of_trees is not empty
if not list_of_trees:
    print("Please provide your list of trees as 'list_of_trees'.")
else:
    # Compute co-occurrence frequencies
    co_occurrence_freq = compute_taxon_co_occurrence_in_filtered_nonexistent_splits_all_pairs(list_of_trees)

    # Create the network with a chosen threshold
    threshold = 0.6  # Adjust as needed
    G = create_network_from_co_occurrence(co_occurrence_freq, threshold=threshold)

    # Visualize the network
    visualize_co_occurrence_network(G, title="Taxon Co-Occurrence Network")

NameError: name 'read_newick' is not defined