In [1]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

# Read the ATLAS 2024 gene table from CSV into a DataFrame.
# The path is relative, so it is resolved against the current working directory.
ATLAS_Dataset = pd.read_csv(
    filepath_or_buffer='atlas_2024_genes.csv',
)


In [23]:
# Function to create a network diagram for a specific species and continent
def genotype_countries_network(df, species_name, continent_name, top_n_strains=10,
                               output_path='network.png'):
    """Draw a gene / strain / country network for one species on one continent.

    Rows of ``df`` are restricted to ``species_name`` and ``continent_name``,
    then to the ``top_n_strains`` most frequent values of the ``Strain``
    column. Each remaining distinct (Gene, Strain, Country) combination
    contributes a Gene-Strain edge and a Strain-Country edge. The resulting
    figure is written to ``output_path``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``Species``, ``Continents``, ``Strain``,
        ``Gene`` and ``Country``.
    species_name : str
        Value matched exactly against the ``Species`` column.
    continent_name : str
        Value matched exactly against the ``Continents`` column.
    top_n_strains : int, default 10
        Number of most frequent strains to keep.
    output_path : str or path-like, default 'network.png'
        Destination passed to ``Figure.savefig``.

    Returns
    -------
    None
        The function is used for its side effects (figure creation and the
        saved image). The figure is not closed here.
    """
    # Keep only the rows for the requested species and continent.
    selection = df[(df['Species'] == species_name) & (df['Continents'] == continent_name)]

    # value_counts() ignores missing strains, so every kept strain is a real value.
    top_strains = selection['Strain'].value_counts().nlargest(top_n_strains).index

    # Only distinct combinations matter for the graph; drop_duplicates keeps
    # first-seen order, so nodes are inserted in the same order as a row-by-row walk.
    links = selection.loc[
        selection['Strain'].isin(top_strains), ['Gene', 'Strain', 'Country']
    ].drop_duplicates()

    # Nodes are (kind, value) tuples so that a gene, a strain and a country that
    # happen to share the same text stay separate and keep their own colour/size.
    G = nx.Graph()
    for gene, strain, country in links.itertuples(index=False, name=None):
        strain_node = ('Subgenotype', strain)
        # A missing gene or country is skipped rather than drawn as a "nan" node.
        gene_node = ('Genotype', gene) if pd.notna(gene) else None
        country_node = ('Country', country) if pd.notna(country) else None

        if gene_node is not None:
            G.add_node(gene_node)
        G.add_node(strain_node)
        if country_node is not None:
            G.add_node(country_node)

        if gene_node is not None:
            G.add_edge(gene_node, strain_node)
        if country_node is not None:
            G.add_edge(strain_node, country_node)

    # Appearance per node kind.
    node_color_map = {
        'Genotype': 'lightblue',
        'Subgenotype': 'lightgreen',
        'Country': 'salmon'
    }

    node_size_map = {
        'Genotype': 800,
        'Subgenotype': 1200,
        'Country': 1400
    }

    node_colors = [node_color_map[kind] for kind, _ in G.nodes]
    node_sizes = [node_size_map[kind] for kind, _ in G.nodes]
    # Show only the value part of each node key as its label.
    labels = {node: str(node[1]) for node in G.nodes}

    # Draw on explicit subplot axes so the title and legend are attached to the
    # same axes as the graph and there is a margin above the plot for the title.
    fig, ax = plt.subplots(figsize=(12, 10))
    pos = nx.spring_layout(G, seed=42, k=0.6)  # fixed seed for a reproducible layout; k adjusts spacing
    nx.draw(G, pos, ax=ax, labels=labels, with_labels=True, node_size=node_sizes,
            node_color=node_colors, font_size=8, font_weight='bold', edge_color='gray')

    # Legend with one proxy marker per node kind.
    legend_handles = [
        plt.Line2D([0], [0], marker='o', color='w', label=node_type,
                   markersize=10, markerfacecolor=color)
        for node_type, color in node_color_map.items()
    ]
    ax.legend(handles=legend_handles, loc='upper left')

    ax.set_title(f'Genotypes Network Diagram for {species_name} in {continent_name} (Top {top_n_strains} Subgenotypes)')
    fig.savefig(output_path)
