In [13]:
#IMPORTING LIBRARIES AND DATASETS

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt



# IMPORTING THE FILTERED GENOTYPIC DATA
ATLAS_Dataset = pd.read_csv('atlas_2024_genes.csv')

# Per-continent subsets for continent-based analyses.
# Each subset is a boolean-mask selection on the 'Continents' column, so the
# rows keep their original index from ATLAS_Dataset. The unpacking order below
# must match the order of the continent names in the tuple.
Africa, Europe, North_America, South_America, Asia, Oceania = (
    ATLAS_Dataset.loc[ATLAS_Dataset['Continents'].eq(continent)]
    for continent in (
        'Africa',
        'Europe',
        'North America',
        'South America',
        'Asia',
        'Oceania',
    )
)

In [23]:
# VISUALIZING THE PREVALENCE OF DIFFERENT GENOTYPES IN ONE OR MORE SELECTED SPECIES

# Function to visualize the prevalence of AMR genes in selected species
def amr_gene_prevalence(df, species_list, region=None, output_path='amr_gene_prevalence.png'):
    """Plot a heatmap of AMR gene counts per species and save it to disk.

    Each cell holds the number of rows of ``df`` with that (Gene, Species)
    pair. Gene labels on the y-axis are coloured by resistance class, and a
    legend explains the colours.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'Gene' and 'Species' columns.
    species_list : list of str
        Species to include (matched exactly against ``df['Species']``).
    region : str, optional
        When given, the title ends with "in <region>". Pass it when ``df``
        has been pre-filtered (e.g. to one continent) so the title matches
        the plotted data. Default: no region in the title.
    output_path : str, optional
        File the figure is saved to (default 'amr_gene_prevalence.png').

    Raises
    ------
    ValueError
        If no rows of ``df`` match ``species_list``.
    """
    # Gene -> class lookup, grouped by class to make the table easy to audit.
    genes_by_class = {
        'ESBL': ('CTXM1', 'CTXM2', 'CTXM825', 'CTXM9', 'PER', 'VEB'),
        'Beta-lactamase': ('SHV', 'TEM'),
        'Carbapenemase': ('GES', 'IMP', 'KPC', 'NDM', 'OXA', 'SPM', 'VIM'),
        'AmpC beta-lactamase': ('ACC', 'ACTMIR', 'AMPC', 'CMY11', 'CMY1MOX', 'DHA', 'FOX'),
    }
    gene_class_mapping = {gene: cls for cls, genes in genes_by_class.items() for gene in genes}

    # Label colour per class; dict order is also the legend order.
    class_colors = {
        'ESBL': 'skyblue',
        'Carbapenemase': 'lightgreen',
        'AmpC beta-lactamase': 'lightcoral',
        'Beta-lactamase': 'grey',
    }
    # Genes absent from gene_class_mapping need their own colour; otherwise
    # their label colour would be NaN, which matplotlib rejects.
    unknown_class, unknown_color = 'Unknown', 'black'
    palette = {**class_colors, unknown_class: unknown_color}

    # Keep only the selected species.
    df_filtered = df[df['Species'].isin(species_list)]
    if df_filtered.empty:
        raise ValueError(f'No rows in df match species_list={list(species_list)!r}')

    # Rows = genes, columns = species, values = number of rows for that pair.
    counts = df_filtered.pivot_table(index='Gene', columns='Species', aggfunc='size', fill_value=0)

    gene_classes = counts.index.map(gene_class_mapping).fillna(unknown_class)
    label_colors = {str(gene): palette[cls] for gene, cls in zip(counts.index, gene_classes)}

    fig, ax = plt.subplots(figsize=(11, 8))
    sns.heatmap(counts, cmap="Blues", annot=True, fmt="d", cbar_kws={'label': 'Prevalence'}, ax=ax)

    # Colour labels by their text, not by position: with many genes seaborn
    # may show only a subset of tick labels, so a positional zip could
    # assign colours to the wrong genes.
    for label in ax.get_yticklabels():
        label.set_color(label_colors.get(label.get_text(), unknown_color))

    # Legend: the four known classes, plus 'Unknown' only if it occurs.
    legend_colors = dict(class_colors)
    if (gene_classes == unknown_class).any():
        legend_colors[unknown_class] = unknown_color
    handles = [plt.Line2D([0], [0], color=color, lw=4) for color in legend_colors.values()]
    ax.legend(handles, list(legend_colors), title='Genotype Subclasses',
              loc='upper right', bbox_to_anchor=(1.5, 1))

    title = 'Prevalence of AMR Genes by Species'
    if region:
        title += f' in {region}'
    ax.set_title(title)
    ax.set_xlabel('Species')
    ax.set_ylabel('Gene')

    # The legend sits outside the axes; bbox_inches='tight' keeps it in the saved file.
    fig.savefig(output_path, bbox_inches='tight')

Example run: plot gene prevalence for four species across the full (all-continent) dataset:

# Gene prevalence for four species, using the unfiltered (all-continent) dataset.
amr_gene_prevalence(
    df=ATLAS_Dataset,
    species_list=[
        'Escherichia coli',
        'Klebsiella pneumoniae',
        'Enterobacter cloacae',
        'Pseudomonas aeruginosa',
    ],
)