## In this notebook, we detect bacterial communities whose members have highly similar AMR gene profiles (pairwise similarity above a chosen threshold)

In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity
from plotly.graph_objs import Scatter3d, Figure

In [2]:
# Load the combined RGI results table (tab-separated) into a DataFrame
combined_data_tsv = pd.read_csv(
    "../Results/combined_data_tsv.tsv",
    sep="\t",
)

  combined_data_tsv = pd.read_csv("../Results/combined_data_tsv.tsv", sep="\t")


##### Only Perfect and Strict hits are used to detect the bacterial communities

In [3]:
# Keep only hits whose Cut_Off category is "Perfect" or "Strict"
combined_data_perfect_strict_tsv = combined_data_tsv.loc[
    lambda hits: hits["Cut_Off"].isin(["Perfect", "Strict"])
]

In [14]:

# Build an organism x gene (Best_Hit_ARO) matrix from the Perfect/Strict hits.
# Each cell holds the sum of Best_Identities over all hits for that
# organism/gene pair.
combined_data_perfect_strict_pivoted = combined_data_perfect_strict_tsv.pivot_table(
    index='organism',
    columns='Best_Hit_ARO',
    values='Best_Identities',
    aggfunc='sum'
)

# Organism/gene pairs with no hit come out of the pivot as NaN -> treat as 0.
# Best_Identities is assumed to be a percentage (TODO confirm), so dividing the
# summed identity by 100 and rounding approximates the number of hits.
# NOTE(review): a hit whose summed identity is below 50% rounds down to 0 and is
# treated as absent, and .round() rounds halves to even (0.5 -> 0, 1.5 -> 2).
# Confirm this is the intended counting rule; aggfunc='count' would count hits
# regardless of identity.
combined_data_perfect_strict_pivoted = combined_data_perfect_strict_pivoted.fillna(0)
combined_data_perfect_strict_pivoted = (combined_data_perfect_strict_pivoted / 100).round(0).astype(int)

# Drop organisms with fewer than MIN_TOTAL_AMR_GENES total (rounded) AMR genes.
# The outputs recorded in this notebook were produced with a threshold of 3.
MIN_TOTAL_AMR_GENES = 3
combined_data_perfect_strict_pivoted = combined_data_perfect_strict_pivoted[
    combined_data_perfect_strict_pivoted.sum(axis=1) >= MIN_TOTAL_AMR_GENES
]

# Drop genes that are absent in all remaining organisms
combined_data_perfect_strict_pivoted = combined_data_perfect_strict_pivoted.loc[:, (combined_data_perfect_strict_pivoted != 0).any(axis=0)]


##### Now we detect and plot the communities

In [16]:
import pandas as pd
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity
from plotly.graph_objs import Scatter3d, Figure

# Organism names (row labels) and their AMR gene vectors, in matching order
organisms = combined_data_perfect_strict_pivoted.index.tolist()
X = combined_data_perfect_strict_pivoted.values

# Pairwise cosine similarity between every pair of organism gene vectors
cos_sim = cosine_similarity(X)

# Minimum pairwise cosine similarity for two organisms to be connected
SIM_THRESHOLD = 0.75
# Minimum number of members a clique needs to be kept as a community
CLIQUE_SIZE_THRESHOLD = 5

# Similarity graph: one node per organism, plus an edge (weighted by the
# similarity) for every unordered pair whose similarity reaches SIM_THRESHOLD.
# Pairs are visited in (i, j) order with i < j, which fixes edge insertion order.
G = nx.Graph()
G.add_nodes_from(organisms)

n = len(organisms)
for i, first in enumerate(organisms):
    row = cos_sim[i]
    for j in range(i + 1, n):
        similarity = row[j]
        if similarity >= SIM_THRESHOLD:
            G.add_edge(first, organisms[j], weight=similarity)

print(f"Graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

# Enumerate every maximal clique of the similarity graph
candidate_cliques = list(nx.find_cliques(G))

# Check that every pair of clique members reaches the similarity threshold.
# The minimum-size requirement is enforced separately by the caller.
def clique_valid(clique, sim_matrix, labels, threshold):
    """Return True if every pair of nodes in ``clique`` has similarity >= threshold.

    Args:
        clique: iterable of node labels.
        sim_matrix: 2-D matrix indexable as ``sim_matrix[i, j]`` (e.g. a numpy
            array) whose rows/columns follow the order of ``labels``.
        labels: list of node labels; a node's position in it is its matrix index.
        threshold: minimum required pairwise similarity.

    Returns:
        False if any pair's similarity is strictly below ``threshold``, else True
        (so cliques with fewer than two members are always valid).

    Raises:
        ValueError: if a node in ``clique`` is not present in ``labels``.
    """
    positions = [labels.index(node) for node in clique]
    # Compare each position only with the positions after it (each pair once)
    return not any(
        sim_matrix[row, col] < threshold
        for offset, row in enumerate(positions, start=1)
        for col in positions[offset:]
    )

strict_cliques = []
used_nodes = set()

# Greedy, non-overlapping community selection: visit maximal cliques from
# largest to smallest (sorted() is stable, so equal-sized cliques keep the
# order nx.find_cliques produced) and accept a clique only if it is large
# enough, passes the similarity check, and shares no member with an
# already-accepted clique.
for clique in sorted(candidate_cliques, key=len, reverse=True):
    # Filter by size
    if len(clique) < CLIQUE_SIZE_THRESHOLD:
        continue
    # Safety check on internal similarity. G only contains edges with
    # similarity >= SIM_THRESHOLD, so every clique of G already passes it.
    if not clique_valid(clique, cos_sim, organisms, SIM_THRESHOLD):
        continue
    # Each organism may belong to at most one community
    if any(node in used_nodes for node in clique):
        continue
    # Accept clique
    strict_cliques.append(clique)
    used_nodes.update(clique)

print(f"✅ Retained {len(strict_cliques)} strict cliques with ≥ {CLIQUE_SIZE_THRESHOLD} members and all pairwise sim ≥ {SIM_THRESHOLD}")

# Map each retained organism to its 0-based community index
node_community_map = {}
for idx, clique in enumerate(strict_cliques):
    for node in clique:
        node_community_map[node] = idx

# Keep only organisms that belong to a community
filtered_nodes = list(node_community_map.keys())
G_filtered = G.subgraph(filtered_nodes).copy()

# Color and label nodes by 1-based community number. This matches the
# "Community N" numbering (idx + 1) printed by the community-profile cell,
# so the colorbar value and hover text can be looked up there directly.
node_colors = [node_community_map[node] + 1 for node in G_filtered.nodes()]
node_text = [f"{node}<br>Community {node_community_map[node] + 1}" for node in G_filtered.nodes()]

# 3D spring layout; the fixed seed makes the layout reproducible across runs
pos = nx.spring_layout(G_filtered, dim=3, seed=42)

node_x = [pos[node][0] for node in G_filtered.nodes()]
node_y = [pos[node][1] for node in G_filtered.nodes()]
node_z = [pos[node][2] for node in G_filtered.nodes()]

# Draw only edges inside a community; None breaks the line between segments
edge_x, edge_y, edge_z = [], [], []
for u, v in G_filtered.edges():
    if node_community_map[u] == node_community_map[v]:
        x0, y0, z0 = pos[u]
        x1, y1, z1 = pos[v]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        edge_z += [z0, z1, None]

# Plotting
edge_trace = Scatter3d(
    x=edge_x, y=edge_y, z=edge_z,
    mode='lines',
    line=dict(width=2, color='gray'),
    hoverinfo='none'
)

node_trace = Scatter3d(
    x=node_x, y=node_y, z=node_z,
    mode='markers',
    marker=dict(
        size=6,
        color=node_colors,
        colorscale='Viridis',
        opacity=0.9,
        colorbar=dict(title="Community Index")
    ),
    text=node_text,
    hoverinfo='text'
)

fig = Figure(data=[edge_trace, node_trace])
fig.update_layout(
    title = f'Bacterial Communities Based on Shared AMR Genes (Each Community ≥ {CLIQUE_SIZE_THRESHOLD} Bacteria, Pairwise Cosine Similarity ≥ {SIM_THRESHOLD})',
    showlegend=False,
    margin=dict(l=0, r=0, b=0, t=50),
    scene=dict(
        xaxis=dict(showgrid=False, zeroline=False),
        yaxis=dict(showgrid=False, zeroline=False),
        zaxis=dict(showgrid=False, zeroline=False)
    )
)

# Save the interactive figure to disk (this cell does not display it inline)
fig.write_html("../Results/communities.html")


Graph: 1194 nodes, 24289 edges
✅ Retained 28 strict cliques with ≥ 5 members and all pairwise sim ≥ 0.75


#### For each community, we list its members and the fraction of members that carry each AMR gene

In [None]:
print("\n=== Community AMR Gene Profiles ===")

for community_number, members in enumerate(strict_cliques, start=1):
    print(f"\n🔹 Community {community_number} (Size: {len(members)})")
    # Header followed by one indented bullet line per member
    print("Members:", *(f"   - {org}" for org in members), sep="\n")

    # Rows of the organism x gene matrix belonging to this community
    community_matrix = combined_data_perfect_strict_pivoted.loc[members]

    # How many members carry each gene (value > 0); keep genes seen at least
    # once, most widespread first
    gene_presence_count = (community_matrix > 0).sum(axis=0)
    is_present = gene_presence_count > 0
    gene_presence_freq = gene_presence_count[is_present].sort_values(ascending=False)

    print("\nTop AMR genes by frequency (count/community size):")
    fraction_of_members = gene_presence_freq / len(members)
    print(fraction_of_members)

    # Optional: save each community's profile to CSV
    # community_matrix.to_csv(f"../Results/community_{community_number}_amr_profile.csv")



=== Community AMR Gene Profiles ===

🔹 Community 1 (Size: 92)
Members:
   - Pseudomonas kielensis strain ZE23JCel16
   - Pseudomonas marincola strain YSy11
   - Aquipseudomonas alcaligenes strain NEB 585
   - Agarivorans aestuarii strain KCTC 32543
   - Pseudomonas qingdaonensis strain S-1
   - Pseudomonas cavernae strain K2W31S-8
   - Pseudomonas vranovensis strain MYb188
   - Pseudoalteromonas ostreae strain hOe-124
   - Pseudoalteromonas rhizosphaerae strain hCg-42
   - Luteibacter pinisoli strain MAH-14
   - Agarivorans albus strain JCM 21469
   - Lysobacter enzymogenes strain M497-1
   - Pseudomonas shahriarae strain SWRI52
   - Lysobacter antibioticus strain 76
   - Pseudomonas brenneri strain K5-sn1400
   - Pseudoalteromonas donghaensis strain HJ51
   - Methylocaldum szegediense isolate Msz(Nor)
   - Pseudomonas paeninsulae strain IT1137
   - Luteibacter anthropi strain SM7.4
   - Pseudomonas chlororaphis subsp. chlororaphis strain ATCC 9446
   - Pseudoduganella albidiflava str