In [2]:
!pip install -q -r requirements.txt

In [5]:
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt

def plot_world_map_samples(tsv_path, shp_file_path, output_path='world_map_samples.png', 
                           figsize=(12, 8), point_size=10, point_color='grey', point_alpha=0.7):
    """
    Plot sample locations on top of a shapefile base map and save the figure.

    Args:
        tsv_path (str): Path to the tab-separated file containing sample data.
            It must provide 'longitude' and 'latitude' columns.
        shp_file_path (str): Path to the shapefile drawn as the base map.
        output_path (str): Path to save the output figure. Default is 'world_map_samples.png'.
        figsize (tuple): Figure size as (width, height). Default is (12, 8).
        point_size (int): Size of the sample points. Default is 10.
        point_color (str): Color of the sample points. Default is 'grey'.
        point_alpha (float): Alpha (transparency) of the sample points. Default is 0.7.

    Returns:
        None

    Raises:
        KeyError: If the sample file lacks a 'longitude' or 'latitude' column.
    """
    # Read the sample data and validate it before any figure is created
    df = pd.read_csv(tsv_path, sep='\t')
    missing_columns = [col for col in ('longitude', 'latitude') if col not in df.columns]
    if missing_columns:
        raise KeyError(f"{tsv_path} is missing required column(s): {', '.join(missing_columns)}")

    # Read the base-map shapefile
    world = gpd.read_file(shp_file_path)

    # Create the figure and axis
    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Plot the base map, then the sample locations on top of it
        world.plot(ax=ax, color='lightgrey', edgecolor='black')
        ax.scatter(df['longitude'], df['latitude'],
                   c=point_color, s=point_size, alpha=point_alpha)

        # Set the title and labels
        ax.set_title('Sample Locations', fontsize=16)
        ax.set_xlabel('Longitude', fontsize=12)
        ax.set_ylabel('Latitude', fontsize=12)

        # Show the whole globe in degrees
        ax.set_xlim(-180, 180)
        ax.set_ylim(-90, 90)

        # Add gridlines
        ax.grid(True, linestyle='--', alpha=0.7)

        # Adjust layout and save the figure
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Close this specific figure even when plotting/saving fails, so that
        # failed runs do not pile up open figures in the kernel.
        plt.close(fig)

    print(f"Map has been saved to {output_path}")

In [6]:
# Inputs for the plain sample-location map.
sample_coords = "samples.tsv"
shp_file = "data/maps/ne_110m_ocean/ne_110m_ocean.shp"

# Render the map and write it next to the notebook.
plot_world_map_samples(
    tsv_path=sample_coords,
    shp_file_path=shp_file,
    output_path="world_samples.png",
)

Map has been saved to world_samples.png


## Overlay pie charts on world map

In [27]:
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np

def plot_ocean_taxonomy_sample_map(sample_tsv, taxonomy_tsv, shp_file_path, ocean_centers, 
                                   output_path='ocean_taxonomy_sample_map.png', 
                                   figsize=(20, 15), sample_size=10, sample_color='grey', sample_alpha=0.7,
                                   pie_radius=15):
    """
    Plot sample locations and phylum distribution pie charts for each ocean basin on a world map.

    The ten phyla with the largest summed 'count' are shown individually; every other
    phylum is merged into an 'Other' category. Each category has one fixed color that
    is shared by all pie charts and by the legend.

    Args:
        sample_tsv (str): Path to the TSV file containing sample location data
            ('longitude' and 'latitude' columns).
        taxonomy_tsv (str): Path to the TSV file containing taxonomy data
            ('phylum', 'ocean_basin' and 'count' columns).
        shp_file_path (str): Path to the shapefile for the world map.
        ocean_centers (dict): Dictionary with ocean basin names as keys and (lon, lat) tuples as values.
        output_path (str): Path to save the output figure. Default is 'ocean_taxonomy_sample_map.png'.
        figsize (tuple): Figure size as (width, height). Default is (20, 15).
        sample_size (int): Size of the sample points. Default is 10.
        sample_color (str): Color of the sample points. Default is 'grey'.
        sample_alpha (float): Alpha (transparency) of the sample points. Default is 0.7.
        pie_radius (float): Radius of each pie chart, in the same axis units as the
            longitude/latitude coordinates. Default is 15.

    Returns:
        None
    """
    # Read the sample and taxonomy data
    sample_df = pd.read_csv(sample_tsv, sep='\t')
    taxonomy_df = pd.read_csv(taxonomy_tsv, sep='\t')

    # Get top 10 phyla by total count
    top_phyla = taxonomy_df.groupby('phylum')['count'].sum().nlargest(10).index.tolist()
    top_phyla_set = set(top_phyla)

    # Collapse every phylum outside the top 10 into 'Other' and re-aggregate per basin
    phylum_category = taxonomy_df['phylum'].apply(
        lambda phylum: phylum if phylum in top_phyla_set else 'Other')
    df_grouped = (
        taxonomy_df.assign(phylum_category=phylum_category)
        .groupby(['ocean_basin', 'phylum_category'])['count']
        .sum()
        .reset_index()
    )

    # One color per category, used by BOTH the wedges and the legend. Coloring the
    # wedges by their position inside each basin would make the same phylum change
    # color from pie to pie and disagree with the legend.
    categories = top_phyla + ['Other']
    color_by_category = dict(zip(categories, plt.cm.Spectral(np.linspace(0, 1, len(categories)))))

    # Read the world shapefile
    world = gpd.read_file(shp_file_path)

    # Create the figure and axis
    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Plot the world map and the sample locations
        world.plot(ax=ax, color='lightgrey', edgecolor='black')
        ax.scatter(sample_df['longitude'], sample_df['latitude'],
                   c=sample_color, s=sample_size, alpha=sample_alpha)

        # Plot a pie chart for each ocean basin
        for basin, (lon, lat) in ocean_centers.items():
            basin_data = df_grouped[df_grouped['ocean_basin'] == basin]
            sizes = basin_data['count'].values
            # A basin without (positive) counts has nothing to draw; skip the pie
            # but still label the basin.
            if sizes.sum() > 0:
                wedge_colors = [color_by_category[category]
                                for category in basin_data['phylum_category']]
                ax.pie(sizes, colors=wedge_colors, center=(lon, lat), radius=pie_radius)
            ax.text(lon, lat, basin, fontsize=10, ha='center', va='center')

        # Set the map extent (after the pies, which adjust the axis limits themselves)
        ax.set_xlim(-180, 180)
        ax.set_ylim(-90, 90)

        # Add gridlines
        ax.grid(True, linestyle='--', alpha=0.7)

        # Add a legend for phyla, using the same colors as the wedges
        handles = [plt.Rectangle((0, 0), 1, 1, color=color_by_category[category])
                   for category in categories]
        ax.legend(handles, categories, title='Phyla', loc='lower left', bbox_to_anchor=(1, 0.5))

        # Adjust layout and save the figure
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed
        plt.close(fig)

    print(f"Map has been saved to {output_path}")

In [28]:
shp_file = "data/maps/ne_110m_ocean/ne_110m_ocean.shp"

# (longitude, latitude) anchor at which each basin's pie chart is drawn.
ocean_centers = dict(
    Antarctic=(-50, -70),
    Arctic=(0, 70),
    Atlantic=(-30, 30),
    Indian=(80, -10),
    Pacific=(-150, 0),
)

# Sample map with one phylum pie chart per basin; output path is left at its default.
plot_ocean_taxonomy_sample_map(
    sample_tsv='samples.tsv',
    taxonomy_tsv='outputs/phylum_data.tsv',
    shp_file_path=shp_file,
    ocean_centers=ocean_centers,
)

Map has been saved to ocean_taxonomy_sample_map.png


## Plot BGCs

In [46]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

def plot_bgc_class_distribution(tsv_path, output_path='top_bgc_class_distribution.png', figsize=(16, 8)):
    """
    Create a grouped bar plot of the top 10 BGC class distributions across ocean basins,
    excluding 'Other' and 'Unknown' classes (matched case-insensitively).

    NOTE: a later cell of this notebook defines another function with the same name
    (one plot per basin, taking ``output_dir``); whichever cell ran last wins.

    Args:
        tsv_path (str): Path to the TSV file containing BGC class data
            ('bgc_class', 'ocean_basin' and 'count' columns).
        output_path (str): Path to save the output figure. Default is 'top_bgc_class_distribution.png'.
        figsize (tuple): Figure size as (width, height). Default is (16, 8).

    Returns:
        None
    """
    # Read the BGC class data
    df = pd.read_csv(tsv_path, sep='\t')

    # Remove 'Other' and 'Unknown' classes regardless of capitalisation, then keep
    # only the 10 classes with the largest total count.
    is_excluded = df['bgc_class'].astype(str).str.lower().isin(['other', 'unknown'])
    df = df[~is_excluded]
    top_10_bgc_classes = df.groupby('bgc_class')['count'].sum().nlargest(10).index
    df = df[df['bgc_class'].isin(top_10_bgc_classes)]

    # BGC classes as rows, ocean basins as columns. pivot_table sums repeated
    # (class, basin) rows, whereas DataFrame.pivot would raise on duplicates.
    pivot_df = df.pivot_table(index='bgc_class', columns='ocean_basin', values='count',
                              aggfunc='sum', fill_value=0)

    # Order the classes by their total count across all basins (largest first)
    class_order = pivot_df.sum(axis=1).sort_values(ascending=False).index
    pivot_df = pivot_df.loc[class_order]

    n_basins = len(pivot_df.columns)
    # Keep the original 0.15 bar width, but shrink it when there are so many basins
    # that one group of bars would otherwise run into the next group.
    width = min(0.15, 0.8 / max(n_basins, 1))

    # One x position per BGC class
    x = np.arange(len(pivot_df))

    # One color per ocean basin
    bgc_colors = plt.cm.tab20(np.linspace(0, 1, n_basins))

    # Create the grouped bar plot
    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Plot bars for each ocean basin, shifted by one bar width per basin
        for i, (basin, color) in enumerate(zip(pivot_df.columns, bgc_colors)):
            ax.bar(x + i * width, pivot_df[basin], width, label=basin, color=color)

        # Customize the plot
        ax.set_ylabel('Count', fontsize=12)
        ax.set_xlabel('BGC Class', fontsize=12)
        ax.set_title('Distribution of Top 10 BGC Classes Across Ocean Basins', fontsize=16)
        # Center each tick under its group of bars
        ax.set_xticks(x + width * (n_basins - 1) / 2)
        ax.set_xticklabels(pivot_df.index, rotation=45, ha='right')
        ax.legend(title='Ocean Basin')

        # Adjust layout and save the figure
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed
        plt.close(fig)

    print(f"Top BGC class distribution plot has been saved to {output_path}")


# Grouped bar chart for all basins; the output path is left at its default.
plot_bgc_class_distribution(tsv_path='outputs/bgc_data.tsv')

Top BGC class distribution plot has been saved to top_bgc_class_distribution.png


## Separate BGC plots for each ocean basin

In [45]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os

def plot_bgc_class_distribution(tsv_path, output_dir='bgc_plots', figsize=(12, 6)):
    """
    Create separate bar plots of the top 10 BGC class distributions for each ocean basin,
    excluding 'Other' and 'Unknown' classes (matched case-insensitively), using a unique
    color for each BGC class that stays the same across all basin plots.

    NOTE: an earlier cell of this notebook defines another function with the same name
    (single grouped plot, taking ``output_path``); whichever cell ran last wins.

    Args:
        tsv_path (str): Path to the TSV file containing BGC class data
            ('bgc_class', 'ocean_basin' and 'count' columns).
        output_dir (str): Directory to save the output figures. Default is 'bgc_plots'.
        figsize (tuple): Figure size as (width, height). Default is (12, 6).

    Returns:
        None
    """
    # Read the BGC class data
    df = pd.read_csv(tsv_path, sep='\t')

    # Remove 'Other' and 'Unknown' classes regardless of capitalisation, then keep
    # only the 10 classes with the largest total count over all basins.
    is_excluded = df['bgc_class'].astype(str).str.lower().isin(['other', 'unknown'])
    df = df[~is_excluded]
    top_10_bgc_classes = df.groupby('bgc_class')['count'].sum().nlargest(10).index
    df = df[df['bgc_class'].isin(top_10_bgc_classes)]

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Get unique ocean basins (rows without a basin cannot be plotted) and BGC classes
    ocean_basins = df['ocean_basin'].dropna().unique()
    all_bgc_classes = df['bgc_class'].unique()

    # Generate a color map for all BGC classes
    color_map = dict(zip(all_bgc_classes, plt.cm.tab20(np.linspace(0, 1, len(all_bgc_classes)))))

    # Create a plot for each ocean basin
    for basin in ocean_basins:
        # Rows of the current basin, largest counts first, at most 10 of them
        basin_data = df[df['ocean_basin'] == basin]
        top_bgcs = basin_data.sort_values('count', ascending=False).head(10)

        # Explicit numeric bar positions: set_xticklabels is only well-defined once
        # the tick positions are fixed with set_xticks (otherwise Matplotlib warns).
        positions = np.arange(len(top_bgcs))

        fig, ax = plt.subplots(figsize=figsize)
        try:
            # Plot bars for each BGC class with its unique color
            ax.bar(positions, top_bgcs['count'],
                   color=[color_map[bgc] for bgc in top_bgcs['bgc_class']])

            # Customize the plot
            ax.set_ylabel('Count', fontsize=12)
            ax.set_xlabel('BGC Class', fontsize=12)
            ax.set_title(f'Top 10 BGC Classes in {basin}', fontsize=16)
            ax.set_xticks(positions)
            ax.set_xticklabels(top_bgcs['bgc_class'], rotation=45, ha='right')

            # Adjust layout and save the figure; str() keeps non-string basin
            # values (e.g. numeric codes) from breaking the file name.
            fig.tight_layout()
            output_path = os.path.join(output_dir, f'bgc_distribution_{str(basin).replace(" ", "_").lower()}.png')
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            # Close each figure as soon as it is saved so the loop does not
            # accumulate open figures.
            plt.close(fig)

        print(f"BGC class distribution plot for {basin} has been saved to {output_path}")


# One bar chart per ocean basin; the output directory is left at its default.
plot_bgc_class_distribution(tsv_path='outputs/bgc_data.tsv')

  ax.set_xticklabels(top_bgcs['bgc_class'], rotation=45, ha='right')


BGC class distribution plot for Antarctic has been saved to bgc_plots/bgc_distribution_antarctic.png


  ax.set_xticklabels(top_bgcs['bgc_class'], rotation=45, ha='right')


BGC class distribution plot for Arctic has been saved to bgc_plots/bgc_distribution_arctic.png


  ax.set_xticklabels(top_bgcs['bgc_class'], rotation=45, ha='right')


BGC class distribution plot for Atlantic has been saved to bgc_plots/bgc_distribution_atlantic.png


  ax.set_xticklabels(top_bgcs['bgc_class'], rotation=45, ha='right')


BGC class distribution plot for Indian has been saved to bgc_plots/bgc_distribution_indian.png


  ax.set_xticklabels(top_bgcs['bgc_class'], rotation=45, ha='right')


BGC class distribution plot for Pacific has been saved to bgc_plots/bgc_distribution_pacific.png


## Combined pie charts and bar plots

In [51]:
import matplotlib.patheffects as path_effects  

def plot_combined_ocean_visualization(sample_tsv, taxonomy_tsv, bgc_tsv, grid_shp_file_path, ocean_shp_file_path,
                                      output_path='combined_ocean_visualization.png', 
                                      figsize=(30, 20), sample_size=10, sample_color='black', sample_alpha=0.7,
                                      total_samples=0,  total_mags=0, total_proteins=0, total_bgcs=0, chart_alpha=0.8, chart_size=40,
                                      ocean_centers=None):
    """
    Plot sample locations, phylum distribution pie charts, and BGC class bar plots for each ocean basin on a world map.

    For every basin a pie chart (phyla) is drawn to the left of the basin's anchor point
    and a small bar chart (BGC classes, scaled to the basin's largest count) to the right.
    Basins without data for one of the charts simply get no chart of that kind.

    Args:
        sample_tsv (str): Path to the TSV file containing sample location data
            ('longitude' and 'latitude' columns).
        taxonomy_tsv (str): Path to the TSV file containing taxonomy data
            ('phylum', 'ocean_basin' and 'count' columns).
        bgc_tsv (str): Path to the TSV file containing BGC class data
            ('bgc_class', 'ocean_basin' and 'count' columns).
        grid_shp_file_path (str): Path to the shapefile drawn as grid lines over the map.
        ocean_shp_file_path (str): Path to the shapefile drawn as the base map.
        output_path (str): Path to save the output figure. Default is 'combined_ocean_visualization.png'.
        figsize (tuple): Figure size as (width, height). Default is (30, 20).
        sample_size (int): Size of the sample points. Default is 10.
        sample_color (str): Color of the sample points. Default is 'black'.
        sample_alpha (float): Alpha (transparency) of the sample points. Default is 0.7.
        total_samples (int): Total number of samples, shown in the statistics box.
        total_mags (int): Total number of MAGs, shown in the statistics box.
        total_proteins (int): Total number of proteins, shown in the statistics box.
        total_bgcs (int): Total number of BGCs, shown in the statistics box.
        chart_alpha (float): Alpha (transparency) of pie charts and bar plots. Default is 0.8.
        chart_size (float): Width/height of each bar chart and diameter of each pie chart,
            in the same axis units as the longitude/latitude coordinates. Default is 40.
        ocean_centers (dict | None): Ocean basin name -> (lon, lat) anchor point of the
            basin's charts. When None, a built-in layout for the Antarctic, Arctic,
            Atlantic, Indian and Pacific basins is used.

    Returns:
        None
    """
    # Define ocean centers (built-in layout unless the caller supplies one)
    if ocean_centers is None:
        ocean_centers = {
            'Antarctic': (-10, -70),
            'Arctic': (0, 60),
            'Atlantic': (-30, 0),
            'Indian': (90, -20),
            'Pacific': (-120, 0)
        }

    # Read the data
    sample_df = pd.read_csv(sample_tsv, sep='\t')
    taxonomy_df = pd.read_csv(taxonomy_tsv, sep='\t')
    bgc_df = pd.read_csv(bgc_tsv, sep='\t')

    # Process taxonomy data: top 10 phyla individually, everything else as 'Other'
    top_phyla = taxonomy_df.groupby('phylum')['count'].sum().nlargest(10).index.tolist()
    top_phyla_set = set(top_phyla)
    phylum_category = taxonomy_df['phylum'].apply(lambda x: x if x in top_phyla_set else 'Other')
    df_grouped = (
        taxonomy_df.assign(phylum_category=phylum_category)
        .groupby(['ocean_basin', 'phylum_category'])['count']
        .sum()
        .reset_index()
    )

    # Process BGC data: drop 'other'/'unknown' (any capitalisation), keep the top 10 classes
    is_excluded = bgc_df['bgc_class'].astype(str).str.lower().isin(['other', 'unknown'])
    bgc_df = bgc_df[~is_excluded]
    top_10_bgc_classes = bgc_df.groupby('bgc_class')['count'].sum().nlargest(10).index
    bgc_df = bgc_df[bgc_df['bgc_class'].isin(top_10_bgc_classes)]

    # Create color maps (one fixed color per phylum category / BGC class)
    phyla_categories = top_phyla + ['Other']
    phyla_color_map = dict(zip(phyla_categories, plt.cm.tab20(np.linspace(0, 1, len(phyla_categories)))))
    bgc_color_map = dict(zip(top_10_bgc_classes, plt.cm.tab10(np.linspace(0, 1, len(top_10_bgc_classes)))))

    # Read the shapefiles
    world = gpd.read_file(ocean_shp_file_path)
    grid = gpd.read_file(grid_shp_file_path)

    # Create the figure and axis
    fig, ax = plt.subplots(figsize=figsize)

    # Function to create a pie chart
    def make_pie(sizes, labels, x, y, size):
        # Nothing to draw for a basin without (positive) counts
        if len(sizes) == 0 or not sizes.sum() > 0:
            return
        colors = [phyla_color_map[label] for label in labels]
        wedges, _ = ax.pie(sizes, colors=colors, center=(x, y), radius=size,
                           wedgeprops=dict(width=size, alpha=chart_alpha), startangle=90)
        # Draw the wedges above the map layers
        for wedge in wedges:
            wedge.set_zorder(3)

    # Function to create a vertical bar plot
    def make_bar(data, x, y, width, height):
        # An empty basin would otherwise divide by zero (bar width and scaling)
        max_count = data['count'].max() if len(data) else 0
        if not max_count > 0:
            return
        bar_width = width / len(data)
        for i, (bgc_class, count) in enumerate(zip(data['bgc_class'], data['count'])):
            # Bar heights are relative to the basin's largest class
            bar_height = (count / max_count) * height
            ax.bar(x + i * bar_width, bar_height, bar_width, bottom=y,
                   color=bgc_color_map[bgc_class], alpha=chart_alpha, zorder=3)

    try:
        # Plot the world map
        world.plot(ax=ax, color='#92a6b3', edgecolor='black')
        grid.plot(ax=ax, color='black', edgecolor=None, alpha=0.4)

        # Plot the sample locations
        ax.scatter(sample_df['longitude'], sample_df['latitude'],
                   c=sample_color, s=sample_size, alpha=sample_alpha)

        # Plot visualizations for each ocean basin
        for basin, (lon, lat) in ocean_centers.items():
            # Pie chart for taxonomy, left of the anchor point
            basin_tax_data = df_grouped[df_grouped['ocean_basin'] == basin]
            sizes = basin_tax_data['count'].values
            labels = basin_tax_data['phylum_category'].values
            make_pie(sizes, labels, lon - chart_size/2 - 2, lat, chart_size/2)

            # Bar plot for BGC classes, right of the anchor point
            basin_bgc_data = bgc_df[bgc_df['ocean_basin'] == basin].sort_values('count', ascending=False).head(10)
            make_bar(basin_bgc_data, lon + 2, lat - chart_size/2, chart_size, chart_size)

            # Add ocean basin label (white text with a black outline for contrast)
            ax.text(lon, lat + chart_size/2 + 2, basin, fontsize=24, ha='center', va='bottom', weight='bold', color="white",
                    path_effects=[path_effects.Stroke(linewidth=3, foreground='black'),
                                  path_effects.Normal()])

        # Set the map extent (after the pies, which adjust the axis limits themselves)
        ax.set_xlim(-180, 180)
        ax.set_ylim(-90, 90)

        # Create a combined legend for phyla and BGC classes. Invisible Line2D
        # handles act as section titles and as a blank separator row.
        legend_elements = []
        legend_labels = []

        # Add title for taxa
        legend_elements.append(plt.Line2D([0], [0], color='none'))
        legend_labels.append('Phylum')

        # Add phyla to legend
        legend_elements.extend([plt.Rectangle((0,0),1,1, color=phyla_color_map[phylum], alpha=chart_alpha) for phylum in phyla_categories])
        legend_labels.extend(phyla_categories)

        # Add a separator
        legend_elements.append(plt.Line2D([0], [0], color='none'))
        legend_labels.append('')

        # Add title for BGC classes
        legend_elements.append(plt.Line2D([0], [0], color='none'))
        legend_labels.append('BGC Classes')

        # Add BGC classes to legend
        legend_elements.extend([plt.Rectangle((0,0),1,1, color=bgc_color_map[bgc], alpha=chart_alpha) for bgc in top_10_bgc_classes])
        legend_labels.extend(top_10_bgc_classes)

        # Add the combined legend
        ax.legend(legend_elements, legend_labels, loc='center left', bbox_to_anchor=(1, 0.5), title='',
                  frameon=False, handletextpad=0.5, columnspacing=1)

        # Make the two section titles bold
        leg = ax.get_legend()
        for t in leg.get_texts():
            if t.get_text() in ['Phylum', 'BGC Classes']:
                t.set_fontweight('bold')

        # Add gridlines
        ax.grid(True, linestyle='--', alpha=0.7)

        # Add text box with database statistics (all numbers with thousands separators)
        stats_text = f"Database Statistics:\n\n" \
                     f"Total Samples: {total_samples:,}\n" \
                     f"Total MAGs: {total_mags:,}\n" \
                     f"Total Proteins: {total_proteins:,}\n" \
                     f"Total BGCs: {total_bgcs:,}"
        ax.text(0.05, 0.02, stats_text, transform=ax.transAxes, fontsize=18,
                verticalalignment='bottom', horizontalalignment='left',
                bbox=dict(boxstyle='round,pad=1', facecolor='white', alpha=0.8))

        # Adjust layout and save the figure
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the (large) figure, even if plotting or saving failed
        plt.close(fig)
    print(f"Combined visualization has been saved to {output_path}")

In [49]:
from dotenv import load_dotenv
import os

from graph_db.db_connection import Neo4jConnection

# Connection settings come from the environment (optionally via a .env file),
# so no credentials are stored in the notebook.
load_dotenv()
uri = os.getenv("NEO4J_URI")
username = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")

conn = Neo4jConnection(uri, username, password)


def _count_nodes(connection, label):
    """
    Return the number of nodes carrying ``label`` in the graph database.

    Args:
        connection: Object whose ``query(cypher)`` returns an indexable sequence of
            records exposing ``.data()`` (assumed from how this cell has always
            used ``Neo4jConnection`` - confirm against graph_db.db_connection).
        label (str): Node label to count. Labels cannot be passed as Cypher
            parameters, so the label is interpolated into the query text; only
            call this with fixed, trusted label names.

    Returns:
        The count reported by the database.

    Raises:
        RuntimeError: If the query yields no rows.
    """
    records = connection.query(f"MATCH (n:{label}) RETURN COUNT(n) AS total")
    if not records:
        raise RuntimeError(f"Count query for label '{label}' returned no rows")
    return records[0].data()["total"]


total_samples = _count_nodes(conn, "Sample")
print("total samples:", total_samples)

total_genomes = _count_nodes(conn, "Genome")
print("total genomes:", total_genomes)

total_bgcs = _count_nodes(conn, "BGC")
print("total BGCs:", total_bgcs)

total_proteins = _count_nodes(conn, "Protein")
print("total Proteins:", total_proteins)

total samples: 1709
total genomes: 52325
total BGCs: 315634
total Proteins: 30076030


In [52]:
# Base map (oceans) and the overlay drawn as grid lines.
grid_shp_file = "data/maps/ne_110m_graticules_30/ne_110m_graticules_30.shp"
ocean_shp_file = "data/maps/ne_110m_ocean/ne_110m_ocean.shp"

# Combined figure; the totals are the counts fetched from Neo4j earlier in the notebook.
plot_combined_ocean_visualization(
    sample_tsv='samples.tsv',
    taxonomy_tsv='outputs/phylum_data.tsv',
    bgc_tsv='outputs/bgc_data.tsv',
    grid_shp_file_path=grid_shp_file,
    ocean_shp_file_path=ocean_shp_file,
    sample_size=16,
    sample_alpha=0.8,
    total_samples=total_samples,
    total_mags=total_genomes,
    total_proteins=total_proteins,
    total_bgcs=total_bgcs,
    chart_alpha=0.95,
)

Combined visualization has been saved to combined_ocean_visualization.png
