In [3]:
#IMPORTING LIBRARIES AND DATASETS

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

# Read the filtered genotypic dataset from the CSV file
ATLAS_Dataset = pd.read_csv('atlas_2024_genes.csv')

# Build one subset per continent for the continent-based analyses.
# Each subset keeps the rows whose 'Continents' value equals the given label;
# the names are bound in the same order as the labels listed below.
(Africa, Europe, North_America, South_America, Asia, Oceania) = (
    ATLAS_Dataset.loc[ATLAS_Dataset['Continents'].eq(continent_label)]
    for continent_label in (
        'Africa',
        'Europe',
        'North America',
        'South America',
        'Asia',
        'Oceania',
    )
)

In [12]:
# MAKING SANKEY DIAGRAMS OF SELECTED GENOTYPE SUB-CLASSES, FOR SELECTED SPECIES AND CONTINENT

#Function to create a Sankey diagram for selected genotype sub-classes, species, and continent
def AMR_genes_sankey_diagram(df, selected_species=None, selected_gene_classes=None, title=None):
    """Render a Sankey diagram of Species -> Gene Class -> Gene -> Strain flows.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'Species', 'Gene Class', 'Gene' and 'Strain'.
        The frame is used as given, so pre-filter it (e.g. to one continent)
        before calling if a narrower view is wanted.
    selected_species : list-like or str, optional
        Species to keep. ``None`` (default) keeps every species. A single
        string is treated as a one-item selection.
    selected_gene_classes : list-like or str, optional
        Gene classes to keep. ``None`` (default) keeps every gene class. A
        single string is treated as a one-item selection.
    title : str, optional
        Figure title. ``None`` (default) uses the original beta-lactamase title.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.

    Raises
    ------
    KeyError
        If one of the four required columns is missing from ``df``.
    """
    levels = ['Species', 'Gene Class', 'Gene', 'Strain']

    def _as_selection(selection):
        # Series.isin rejects a bare string, so wrap it into a one-item list.
        return [selection] if isinstance(selection, str) else selection

    # Optional filters on species and gene class
    if selected_species is not None:
        df = df[df['Species'].isin(_as_selection(selected_species))]
    if selected_gene_classes is not None:
        df = df[df['Gene Class'].isin(_as_selection(selected_gene_classes))]

    # Count rows per full Species / Gene Class / Gene / Strain path.
    # With pandas' default groupby settings, rows with a missing value in any
    # of these columns do not form a group and therefore do not contribute.
    df_grouped = df.groupby(levels).size().reset_index(name='Count')

    # Assign node ids level by level. Each level has its own lookup table, so a
    # label that occurs in two levels (e.g. as a gene and as a strain) gets two
    # distinct nodes instead of being merged.
    all_nodes = []
    level_ids = {}
    for level in levels:
        labels = df_grouped[level].unique()
        offset = len(all_nodes)
        level_ids[level] = {label: offset + position for position, label in enumerate(labels)}
        all_nodes.extend(list(labels))

    # Build one link per distinct (upstream, downstream) pair of adjacent
    # levels, summing the counts of every path that runs through that pair.
    # Emitting one link per full path would send the same pair to Plotly many
    # times and draw it as several parallel ribbons.
    sources, targets, values = [], [], []
    for upstream, downstream in zip(levels[:-1], levels[1:]):
        flows = (
            df_grouped
            .groupby([upstream, downstream], sort=False)['Count']
            .sum()
            .reset_index()
        )
        upstream_ids = level_ids[upstream]
        downstream_ids = level_ids[downstream]
        sources.extend(upstream_ids[label] for label in flows[upstream])
        targets.extend(downstream_ids[label] for label in flows[downstream])
        values.extend(flows['Count'].tolist())

    # Create the Sankey diagram using Plotly
    fig = go.Figure(go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color='black', width=0.5),
            label=all_nodes,
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values
        )
    ))

    if title is None:
        title = 'A Sankey Diagram Linking Bacterial Species to Genotype Sub-classes (Beta Lactamases), Genotypes, and Subgenotypes'

    # Update the layout of the figure
    fig.update_layout(
        title_text=title,
        font_size=10,
        width=1000,
        height=1000
    )

    # Display the Sankey diagram
    fig.show()