In [None]:
# !pip install plotly 
# !pip install scikit-bio
# !pip install scipy pandas numpy matplotlib seaborn statsmodels
# !pip install rpy2 scikit-learn

Collecting rpy2
  Downloading rpy2-3.6.4-py3-none-any.whl.metadata (5.4 kB)
Collecting rpy2-rinterface>=3.6.3 (from rpy2)
  Downloading rpy2_rinterface-3.6.3-cp313-cp313-macosx_10_13_universal2.whl.metadata (1.9 kB)
Collecting rpy2-robjects>=3.6.3 (from rpy2)
  Downloading rpy2_robjects-3.6.3-py3-none-any.whl.metadata (3.3 kB)
Collecting tzlocal (from rpy2-robjects>=3.6.3->rpy2)
  Downloading tzlocal-5.3.1-py3-none-any.whl.metadata (7.6 kB)
Downloading rpy2-3.6.4-py3-none-any.whl (9.9 kB)
Downloading rpy2_rinterface-3.6.3-cp313-cp313-macosx_10_13_universal2.whl (173 kB)
Downloading rpy2_robjects-3.6.3-py3-none-any.whl (125 kB)
Downloading tzlocal-5.3.1-py3-none-any.whl (18 kB)
Installing collected packages: tzlocal, rpy2-rinterface, rpy2-robjects, rpy2
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4/4[0m [rpy2][32m3/4[0m [rpy2]
[1A[2KSuccessfully installed rpy2-3.6.4 rpy2-rinterface-3.6.3 rpy2-robjects-3.6.3 tzlocal-5.3.1


In [13]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from scipy import stats
from scipy.spatial.distance import pdist, squareform
import statsmodels.api as sm
from skbio.diversity import alpha_diversity, beta_diversity
from skbio.stats.ordination import pcoa
from sklearn.manifold import MDS
import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import importr
from rpy2.robjects.conversion import localconverter

# Set plotting style
plt.rcParams.update({
    'figure.figsize': [10, 6],
    'figure.dpi': 100,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'font.size': 10
})
sns.set_theme(style="whitegrid")
sns.set_palette('husl')

In [None]:
metadata = pd.read_csv('input/all-curated-MB_meta.csv', low_memory=False)
meta_mp_women = metadata[(meta['gender'] == 'female') 
                     & (meta['menopausal_status'].notna())
                     & (meta['age'].notna())
                     & (meta['PMID'].notna())
                     ]
meta_mp_women.head(2)


NameError: name 'meta' is not defined

In [66]:
## studies involved
print("\nStudies involved: ")
print(meta_mp_women[['study_name', 'PMID']].drop_duplicates())
print("\nStatuses available:")
print(meta_mp_women[['gender','menopausal_status']].drop_duplicates())
print("\nSamples per status:")
print(meta_mp_women.groupby('menopausal_status').size())
print("\nSamples per study and menopausal status:")
print(meta_mp_women.groupby(['study_name', 'menopausal_status']).size().reset_index(name='samples_count'))


Studies involved: 
       study_name      PMID
2582   ChuDM_2017  28112736
16247  ShaoY_2019  31534227
19995   XieH_2016  27818083

Statuses available:
       gender menopausal_status
2582   female               pre
19995  female              post
20031  female     going_through

Samples per status:
menopausal_status
going_through     28
post             174
pre              118
dtype: int64

Samples per study and menopausal status:
   study_name menopausal_status  samples_count
0  ChuDM_2017               pre             17
1  ShaoY_2019               pre             63
2   XieH_2016     going_through             28
3   XieH_2016              post            174
4   XieH_2016               pre             38


In [None]:
import rpy2.robjects as robjects
from rpy2.robjects.conversion import localconverter
from rpy2.robjects import pandas2ri

def download_and_save_datasets():
    """Download datasets and save them as CSV files"""
    
    r_code = """
    # Load required libraries
    library(curatedMetagenomicData)
    library(dplyr)
    
    # Function to download and save one dataset
    download_study <- function(study_name) {
        tryCatch({
            # Create pattern for the study
            pattern <- paste0(study_name, ".relative_abundance")
            print(paste("Downloading:", pattern))
            
            # Get the data
            data <- curatedMetagenomicData(pattern, dryrun=FALSE)
            
            if (length(data) > 0) {
                # Extract abundance data
                abundance <- as.data.frame(assay(data[[1]]))
                
                # Extract metadata
                metadata <- as.data.frame(colData(data[[1]]))
                
                # Add study name to metadata
                metadata$study_name <- study_name
                
                # Save files
                abundance_file <- paste0(study_name, "_abundance.csv")
                metadata_file <- paste0(study_name, "_metadata.csv")
                
                write.csv(abundance, abundance_file)
                write.csv(metadata, metadata_file)
                
                print(paste("Saved:", abundance_file, "and", metadata_file))
                
                return(TRUE)
            } else {
                print(paste("No data found for:", study_name))
                return(FALSE)
            }
        }, error = function(e) {
            print(paste("Error processing", study_name, ":", e$message))
            return(FALSE)
        })
    }
    
    # List of studies to download
    studies <- c("ChuDM_2017", "ShaoY_2019", "XieH_2016")
    
    # Download each study
    results <- lapply(studies, download_study)
    
    # Print summary
    successful <- sum(unlist(results))
    print(paste("Successfully downloaded", successful, "out of", length(studies), "studies"))
    """
    
    try:
        print("Starting download process...")
        robjects.r(r_code)
        print("\nDownload complete! Check current directory for CSV files.")
        
        # List saved files
        import os
        csv_files = [f for f in os.listdir('.') if f.endswith('.csv')]
        print("\nSaved files:")
        for file in csv_files:
            print(f"- {file}")
            
    except Exception as e:
        print(f"\nError: {str(e)}")
        print("\nTroubleshooting steps:")
        print("1. Make sure R and required packages are installed")
        print("2. Check write permissions in current directory")
        print("3. Verify internet connection")

# Alternative version using BiocFileCache for persistent storage
def download_with_cache():
    r_code = """
    library(curatedMetagenomicData)
    library(BiocFileCache)
    
    # Create cache
    bfc <- BiocFileCache()
    
    # Function to download and cache dataset
    cache_study <- function(study_name) {
        pattern <- paste0(study_name, ".relative_abundance")
        
        # Get data with caching
        data <- curatedMetagenomicData(pattern, dryrun=FALSE)
        
        if (length(data) > 0) {
            # Save to cache and CSV
            abundance <- as.data.frame(assay(data[[1]]))
            metadata <- as.data.frame(colData(data[[1]]))
            
            write.csv(abundance, paste0(study_name, "_abundance.csv"))
            write.csv(metadata, paste0(study_name, "_metadata.csv"))
            
            return(TRUE)
        }
        return(FALSE)
    }
    
    # Download studies
    studies <- c("ChuDM_2017", "ShaoY_2019", "XieH_2016")
    results <- sapply(studies, cache_study)
    """
    
    return robjects.r(r_code)

# Execute the download
if __name__ == "__main__":
    print("Starting download process...")
    download_and_save_datasets()
    
    if data:
        print("\nSuccessfully downloaded and verified data!")
        # Optional: perform initial analysis
        for study, study_data in data.items():
            print(f"\n{study} summary:")
            print(f"Number of species: {study_data['abundance'].shape[0]}")
            print(f"Number of samples: {study_data['abundance'].shape[1]}")

Starting download process...
Starting download process...


R callback write-console: 
Attaching package: ‘dplyr’

  
R callback write-console: The following objects are masked from ‘package:Biostrings’:

    collapse, intersect, setdiff, setequal, union

  
R callback write-console: The following object is masked from ‘package:XVector’:

    slice

  
R callback write-console: The following object is masked from ‘package:Biobase’:

    combine

  
R callback write-console: The following objects are masked from ‘package:GenomicRanges’:

    intersect, setdiff, union

  
R callback write-console: The following object is masked from ‘package:GenomeInfoDb’:

    intersect

  
R callback write-console: The following objects are masked from ‘package:IRanges’:

    collapse, desc, intersect, setdiff, slice, union

  
R callback write-console: The following objects are masked from ‘package:S4Vectors’:

    first, intersect, rename, setdiff, setequal, union

  
R callback write-console: The following objects are masked from ‘package:BiocGenerics’:

   

[1] "Downloading: ChuDM_2017.relative_abundance"


R callback write-console: snapshotDate(): 2025-04-12
  
R callback write-console: 
$`2021-04-02.ChuDM_2017.relative_abundance`
dropping rows without rowTree matches:
  k__Bacteria|p__Actinobacteria|c__Coriobacteriia|o__Coriobacteriales|f__Atopobiaceae|g__Olsenella|s__Olsenella_profusa
  k__Bacteria|p__Actinobacteria|c__Coriobacteriia|o__Coriobacteriales|f__Coriobacteriaceae|g__Collinsella|s__Collinsella_stercoris
  k__Bacteria|p__Actinobacteria|c__Coriobacteriia|o__Coriobacteriales|f__Coriobacteriaceae|g__Enorma|s__[Collinsella]_massiliensis
  k__Bacteria|p__Firmicutes|c__Bacilli|o__Bacillales|f__Bacillales_unclassified|g__Gemella|s__Gemella_bergeri
  k__Bacteria|p__Firmicutes|c__Bacilli|o__Lactobacillales|f__Carnobacteriaceae|g__Granulicatella|s__Granulicatella_elegans
  k__Bacteria|p__Firmicutes|c__Clostridia|o__Clostridiales|f__Ruminococcaceae|g__Ruminococcus|s__Ruminococcus_champanellensis
  k__Bacteria|p__Firmicutes|c__Erysipelotrichia|o__Erysipelotrichales|f__Erysipelotrichaceae|

[1] "Saved: ChuDM_2017_abundance.csv and ChuDM_2017_metadata.csv"
[1] "Downloading: ShaoY_2019.relative_abundance"


R callback write-console: snapshotDate(): 2025-04-12
  
R callback write-console: 
$`2021-03-31.ShaoY_2019.relative_abundance`
dropping rows without rowTree matches:
  k__Bacteria|p__Actinobacteria|c__Coriobacteriia|o__Coriobacteriales|f__Atopobiaceae|g__Olsenella|s__Olsenella_profusa
  k__Bacteria|p__Actinobacteria|c__Coriobacteriia|o__Coriobacteriales|f__Coriobacteriaceae|g__Collinsella|s__Collinsella_stercoris
  k__Bacteria|p__Actinobacteria|c__Coriobacteriia|o__Coriobacteriales|f__Coriobacteriaceae|g__Enorma|s__[Collinsella]_massiliensis
  k__Bacteria|p__Firmicutes|c__Bacilli|o__Bacillales|f__Bacillales_unclassified|g__Gemella|s__Gemella_bergeri
  k__Bacteria|p__Firmicutes|c__Bacilli|o__Lactobacillales|f__Carnobacteriaceae|g__Granulicatella|s__Granulicatella_elegans
  k__Bacteria|p__Firmicutes|c__Clostridia|o__Clostridiales|f__Ruminococcaceae|g__Ruminococcus|s__Ruminococcus_champanellensis
  k__Bacteria|p__Firmicutes|c__Erysipelotrichia|o__Erysipelotrichales|f__Erysipelotrichaceae|

[1] "Saved: ShaoY_2019_abundance.csv and ShaoY_2019_metadata.csv"
[1] "Downloading: XieH_2016.relative_abundance"


R callback write-console: snapshotDate(): 2025-04-12
  
R callback write-console: 
$`2021-03-31.XieH_2016.relative_abundance`
dropping rows without rowTree matches:
  k__Bacteria|p__Actinobacteria|c__Coriobacteriia|o__Coriobacteriales|f__Atopobiaceae|g__Olsenella|s__Olsenella_profusa
  k__Bacteria|p__Actinobacteria|c__Coriobacteriia|o__Coriobacteriales|f__Coriobacteriaceae|g__Collinsella|s__Collinsella_stercoris
  k__Bacteria|p__Actinobacteria|c__Coriobacteriia|o__Coriobacteriales|f__Coriobacteriaceae|g__Enorma|s__[Collinsella]_massiliensis
  k__Bacteria|p__Firmicutes|c__Bacilli|o__Lactobacillales|f__Carnobacteriaceae|g__Granulicatella|s__Granulicatella_elegans
  k__Bacteria|p__Firmicutes|c__Clostridia|o__Clostridiales|f__Ruminococcaceae|g__Ruminococcus|s__Ruminococcus_champanellensis
  k__Bacteria|p__Firmicutes|c__Erysipelotrichia|o__Erysipelotrichales|f__Erysipelotrichaceae|g__Bulleidia|s__Bulleidia_extructa
  k__Bacteria|p__Proteobacteria|c__Betaproteobacteria|o__Burkholderiales|f__

[1] "Saved: XieH_2016_abundance.csv and XieH_2016_metadata.csv"
[1] "Successfully downloaded 3 out of 3 studies"

Download complete! Check current directory for CSV files.

Saved files:
- menopausal_dataset.csv
- ShaoY_2019_metadata.csv
- ShaoY_2019_abundance.csv
- all_data.csv
- XieH_2016_abundance.csv
- ChuDM_2017_abundance.csv
- bacteria_matrix.csv
- filtered_metadata.csv
- metadata.csv
- XieH_2016_metadata.csv
- ChuDM_2017_metadata.csv

Verifying downloaded data...
Error reading ChuDM_2017 data: name 'os' is not defined
Error reading ShaoY_2019 data: name 'os' is not defined
Error reading XieH_2016 data: name 'os' is not defined


In [None]:
class MicrobiomeAnalysis:
    def __init__(self, base_path="input"):
        self.base_path = Path(base_path)
        self.studies = ["ChuDM", "ShaoY", "XieH"]
        self.data = {}
        self.combined_data = None
        self.combined_metadata = None

    def load_and_combine_data(self):
        """Load and combine data from all studies"""
        all_abundance = []
        all_metadata = []
        
        # Load data from each study
        for study in self.studies:
            study_path = self.base_path / study
            try:
                # Load abundance data
                abundance_file = list(study_path.glob("*abundance.csv"))[0]
                abundance = pd.read_csv(abundance_file, index_col=0)
                
                # Load metadata
                metadata_file = list(study_path.glob("*metadata.csv"))[0]
                metadata = pd.read_csv(metadata_file, index_col=0)
                
                all_abundance.append(abundance)
                all_metadata.append(metadata)
                
            except Exception as e:
                print(f"Error loading {study}: {str(e)}")
        
        # Combine abundance data
        self.combined_abundance = pd.concat(all_abundance, axis=1)
        
        # Combine metadata
        self.combined_metadata = pd.concat(all_metadata, axis=0)
        
        # Filter for females with menopausal status and age
        female_mask = (
            (self.combined_metadata['gender'].str.lower() == 'female') &
            self.combined_metadata['menopausal_status'].notna() &
            self.combined_metadata['age'].notna()
        )
        
        self.female_metadata = self.combined_metadata[female_mask]
        self.female_abundance = self.combined_abundance[
            self.combined_abundance.columns.intersection(self.female_metadata.index)
        ]
        
        print("\nCombined Dataset Summary:")
        print(f"Total samples: {self.female_abundance.shape[1]}")
        print(f"Total species: {self.female_abundance.shape[0]}")
        print(f"Menopausal status groups: {self.female_metadata['menopausal_status'].unique()}")

    def calculate_diversity(self):
        """Calculate alpha diversity metrics"""
        def shannon_diversity(x):
            x = x[x > 0]
            return -np.sum(x * np.log(x))
        
        def simpson_diversity(x):
            return 1 - np.sum(x ** 2)
        
        # Calculate diversity metrics
        self.diversity_metrics = pd.DataFrame(index=self.female_abundance.columns)
        
        # Shannon diversity
        self.diversity_metrics['shannon'] = self.female_abundance.apply(
            lambda x: shannon_diversity(x/100)
        )
        
        # Simpson diversity
        self.diversity_metrics['simpson'] = self.female_abundance.apply(
            lambda x: simpson_diversity(x/100)
        )
        
        # Species richness
        self.diversity_metrics['richness'] = (self.female_abundance > 0).sum()
        
        # Merge with metadata (only menopausal_status and age)
        self.diversity_metrics = self.diversity_metrics.merge(
            self.female_metadata[['age', 'menopausal_status']],
            left_index=True,
            right_index=True
        )

    def calculate_beta_diversity(self, method='bray-curtis'):
        """Calculate beta diversity between samples"""
        self.beta_diversity = {}
        
        # Normalize abundance data
        abundance_norm = self.female_abundance.div(self.female_abundance.sum(axis=0), axis=1)
        
        if method == 'bray-curtis':
            def bray_curtis(x, y):
                return np.sum(np.abs(x - y)) / np.sum(x + y)
            
            n_samples = abundance_norm.shape[1]
            dist_matrix = np.zeros((n_samples, n_samples))
            
            for i in range(n_samples):
                for j in range(i+1, n_samples):
                    dist = bray_curtis(abundance_norm.iloc[:,i], abundance_norm.iloc[:,j])
                    dist_matrix[i,j] = dist
                    dist_matrix[j,i] = dist
            
            self.beta_diversity['distance_matrix'] = pd.DataFrame(
                dist_matrix,
                index=abundance_norm.columns,
                columns=abundance_norm.columns
            )
        
        # Perform ordination (PCoA)
        pcoa_results = pcoa(self.beta_diversity['distance_matrix'])
        
        # Store PCoA coordinates
        self.beta_diversity['pcoa'] = pd.DataFrame(
            pcoa_results.samples.values,
            index=abundance_norm.columns,
            columns=[f'PC{i+1}' for i in range(pcoa_results.samples.shape[1])]
        )
        
        # Add metadata
        self.beta_diversity['pcoa'] = self.beta_diversity['pcoa'].merge(
            self.female_metadata[['menopausal_status', 'age']],
            left_index=True,
            right_index=True
        )
        
        # Calculate variance explained
        self.beta_diversity['variance_explained'] = pcoa_results.proportion_explained

        # Calculate species contributions to beta diversity
        species_contributions = []
        for species in abundance_norm.index:
            # Calculate correlation with ordination axes
            corr_pc1 = stats.spearmanr(abundance_norm.loc[species], 
                                    self.beta_diversity['pcoa']['PC1'])[0]
            corr_pc2 = stats.spearmanr(abundance_norm.loc[species], 
                                    self.beta_diversity['pcoa']['PC2'])[0]
            
            # Calculate mean abundance in different menopausal status groups
            group_means = {}
            group_data = []
            for group in self.female_metadata['menopausal_status'].unique():
                group_samples = self.female_metadata[
                    self.female_metadata['menopausal_status'] == group
                ].index
                group_abundance = abundance_norm.loc[species, group_samples]
                group_means[group] = group_abundance.mean()
                group_data.append(group_abundance)
            
            # Calculate effect size (difference between groups)
            effect_size = np.max(list(group_means.values())) - np.min(list(group_means.values()))
            
            # Perform statistical test with error handling
            try:
                stat, p_value = stats.kruskal(*group_data)
            except ValueError:  # Handle case where all values are identical
                p_value = 1.0  # Set p-value to 1 for identical distributions
                
            species_contributions.append({
                'species': species,
                'corr_pc1': corr_pc1,
                'corr_pc2': corr_pc2,
                'effect_size': effect_size,
                'mean_abundance': abundance_norm.loc[species].mean(),
                'p_value': p_value,
                **group_means  # Add group means to the results
            })
        
        self.beta_diversity['species_contributions'] = pd.DataFrame(species_contributions)
        
        return self.beta_diversity
        

    def plot_beta_diversity(self, output_dir="figures"):
        """Plot beta diversity analysis results including Manhattan plot"""
        output_path = Path(output_dir)
        output_path.mkdir(exist_ok=True)
        
        # Normalize abundance data
        abundance_norm = self.female_abundance.div(self.female_abundance.sum(axis=0), axis=1)
        
        # Manhattan plot
        plt.figure(figsize=(15, 6))
        
        # -log10 transform p-values
        log_p = -np.log10(self.beta_diversity['species_contributions']['p_value'])
        
        # Plot points
        plt.scatter(
            range(len(log_p)),
            log_p,
            c=self.beta_diversity['species_contributions']['effect_size'],
            cmap='viridis',
            alpha=0.6
        )
        
        # Add significance threshold line
        plt.axhline(y=-np.log10(0.05), color='r', linestyle='--', alpha=0.5)
        
        # Customize plot
        plt.colorbar(label='Effect Size')
        plt.xlabel('Species')
        plt.ylabel('-log10(p-value)')
        plt.title('Manhattan Plot of Species Contributions to Beta Diversity')
        
        # Add species labels for significant results
        significant = self.beta_diversity['species_contributions']['p_value'] < 0.05
        significant_contributions = self.beta_diversity['species_contributions'][significant]
        
        for i, (idx, row) in enumerate(significant_contributions.iterrows()):
            plt.annotate(
                str(idx),  # or row['species'] if you have species names
                (i, -np.log10(row['p_value'])),
                xytext=(5, 5),
                textcoords='offset points',
                rotation=45,
                fontsize=8
            )
        
        plt.tight_layout()
        plt.savefig(output_path / 'manhattan_plot.png')
        plt.close()
        
        # Save top contributing species
        top_species = self.beta_diversity['species_contributions'].sort_values(
            'p_value'
        ).head(20)
        
        # Make sure we're using the correct species identifiers
        species_to_plot = abundance_norm.index.intersection(top_species.index)
        
        if len(species_to_plot) > 0:
            plt.figure(figsize=(12, 8))
            sns.heatmap(
                abundance_norm.loc[species_to_plot].T,
                cmap='viridis',
                xticklabels=True,
                yticklabels=False
            )
            plt.title('Top Contributing Species Abundance Patterns')
            plt.tight_layout()
            plt.savefig(output_path / 'top_species_heatmap.png')
            plt.close()
        else:
            print("Warning: No matching species found for heatmap")
        
        # Print summary of significant species
        print("\nTop contributing species:")
        for idx in species_to_plot:
            species_data = self.beta_diversity['species_contributions'].loc[idx]
            print(f"Species {idx}:")
            print(f"  p-value: {species_data['p_value']:.4f}")
            print(f"  effect size: {species_data['effect_size']:.4f}")

    def analyze_diversity(self):
        """Perform statistical analysis of diversity metrics"""
        metrics = ['shannon', 'simpson', 'richness']
        
        # Initialize results dictionary
        results = {
            'menopausal_status_tests': {},
            'age_correlation': {}
        }
        
        # 1. Tests for menopausal status differences
        for metric in metrics:
            # Kruskal-Wallis test
            groups = [group[metric].values for name, group 
                    in self.diversity_metrics.groupby('menopausal_status')]
            
            try:
                stat, pval = stats.kruskal(*groups)
            except ValueError:  # Handle case where all values are identical
                stat, pval = 0, 1.0
                
            results['menopausal_status_tests'][metric] = {
                'statistic': stat,
                'p_value': pval
            }
            
            # Add descriptive statistics
            desc_stats = self.diversity_metrics.groupby('menopausal_status')[metric].describe()
            results['menopausal_status_tests'][metric]['descriptive'] = desc_stats
        
        # 2. Correlation with age
        for metric in metrics:
            correlation, pval = stats.spearmanr(
                self.diversity_metrics['age'],
                self.diversity_metrics[metric]
            )
            
            results['age_correlation'][metric] = {
                'correlation': correlation,
                'p_value': pval
            }
        
        self.diversity_analysis = results
        
        # Print summary
        print("\nDiversity Analysis Results:")
        
        print("\nMenopausal Status Differences:")
        for metric in metrics:
            print(f"\n{metric.capitalize()}:")
            print(f"Kruskal-Wallis test p-value: {results['menopausal_status_tests'][metric]['p_value']:.4f}")
            print("\nDescriptive statistics by group:")
            print(results['menopausal_status_tests'][metric]['descriptive'])
        
        print("\nAge Correlations:")
        for metric in metrics:
            corr = results['age_correlation'][metric]['correlation']
            pval = results['age_correlation'][metric]['p_value']
            print(f"\n{metric.capitalize()}:")
            print(f"Spearman correlation: {corr:.4f}")
            print(f"p-value: {pval:.4f}")

    def test_beta_diversity(self):
        """Statistical tests for beta diversity"""
        from scipy.stats import mannwhitneyu, spearmanr
        import numpy as np
        
        # Get distance matrix and group information
        distances = self.beta_diversity['distance_matrix'].values
        groups = self.female_metadata['menopausal_status']
        unique_groups = groups.unique()
        
        # Initialize results dictionary
        results = {
            'pairwise_tests': {},
            'group_distances': {},
            'age_effects': {}
        }
        
        # 1. Pairwise comparisons between menopausal status groups
        for g1 in unique_groups:
            for g2 in unique_groups:
                if g1 < g2:  # avoid duplicate comparisons
                    # Get distances between and within groups
                    mask1 = groups == g1
                    mask2 = groups == g2
                    
                    between_distances = distances[mask1][:, mask2].flatten()
                    within1_distances = distances[mask1][:, mask1].flatten()
                    within2_distances = distances[mask2][:, mask2].flatten()
                    
                    # Perform statistical test
                    try:
                        stat, pval = mannwhitneyu(
                            between_distances,
                            np.concatenate([within1_distances, within2_distances])
                        )
                    except ValueError:  # Handle case where distributions are identical
                        stat, pval = 0, 1.0
                    
                    results['pairwise_tests'][f'{g1} vs {g2}'] = {
                        'statistic': stat,
                        'p_value': pval,
                        'mean_between_distance': np.mean(between_distances),
                        'mean_within_distance1': np.mean(within1_distances),
                        'mean_within_distance2': np.mean(within2_distances)
                    }
        
        # 2. Calculate within-group distances statistics
        for group in unique_groups:
            mask = groups == group
            within_distances = distances[mask][:, mask]
            
            results['group_distances'][group] = {
                'mean': np.mean(within_distances),
                'std': np.std(within_distances),
                'median': np.median(within_distances)
            }
        
        # 3. Age-related analysis
        # Flatten the distance matrix and create corresponding age differences
        flat_distances = []
        age_diffs = []
        ages = self.female_metadata['age'].values
        
        for i in range(len(ages)):
            for j in range(i+1, len(ages)):
                flat_distances.append(distances[i,j])
                age_diffs.append(abs(ages[i] - ages[j]))
        
        # Calculate correlation between distances and age differences
        corr, pval = spearmanr(flat_distances, age_diffs)
        results['age_effects']['correlation'] = {
            'correlation': corr,
            'p_value': pval
        }
        
        # Store results
        self.beta_diversity['statistical_tests'] = results
        
        # Print summary
        print("\nBeta Diversity Statistical Tests:")
        
        print("\nPairwise comparisons between menopausal status groups:")
        for comparison, result in results['pairwise_tests'].items():
            print(f"\n{comparison}:")
            print(f"p-value: {result['p_value']:.4f}")
            print(f"Mean between-group distance: {result['mean_between_distance']:.4f}")
        
        print("\nWithin-group distances:")
        for group, stats in results['group_distances'].items():
            print(f"\n{group}:")
            print(f"Mean ± SD: {stats['mean']:.4f} ± {stats['std']:.4f}")
        
        print("\nAge effects:")
        print(f"Correlation with age differences: {results['age_effects']['correlation']['correlation']:.4f}")
        print(f"p-value: {results['age_effects']['correlation']['p_value']:.4f}")
        
        # Additional visualization for age effects
        plt.figure(figsize=(8, 6))
        plt.scatter(age_diffs, flat_distances, alpha=0.3)
        plt.xlabel('Age Difference (years)')
        plt.ylabel('Community Distance')
        plt.title('Relationship between Age Differences\nand Community Dissimilarity')
        
        # Add trend line
        z = np.polyfit(age_diffs, flat_distances, 1)
        p = np.poly1d(z)
        plt.plot(sorted(age_diffs), p(sorted(age_diffs)), "r--", alpha=0.8)
        
        plt.tight_layout()
        plt.savefig('figures/age_distance_relationship.png')
        plt.close()
        
        return results


    def plot_results(self, output_dir="figures"):
        """Create visualizations for diversity analysis"""
        output_path = Path(output_dir)
        output_path.mkdir(exist_ok=True)
        
        # Set style using seaborn directly
        sns.set_theme(style="whitegrid")
        colors = sns.color_palette("husl", n_colors=len(self.female_metadata['menopausal_status'].unique()))
           
        # 1. Alpha diversity plots
        metrics = ['shannon', 'simpson', 'richness']
        
        # Boxplots by menopausal status
        plt.figure(figsize=(15, 5))
        for i, metric in enumerate(metrics):
            plt.subplot(1, 3, i+1)
            sns.boxplot(
                data=self.diversity_metrics,
                x='menopausal_status',
                y=metric,
                palette=colors
            )
            sns.swarmplot(
                data=self.diversity_metrics,
                x='menopausal_status',
                y=metric,
                color='0.25',
                alpha=0.5,
                size=4
            )
            plt.title(f'{metric.capitalize()} Diversity')
            plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(output_path / 'alpha_diversity_boxplots.png')
        plt.close()
        
        # 2. Age relationship plots
        plt.figure(figsize=(15, 5))
        for i, metric in enumerate(metrics):
            plt.subplot(1, 3, i+1)
            sns.scatterplot(
                data=self.diversity_metrics,
                x='age',
                y=metric,
                hue='menopausal_status',
                palette=colors,
                alpha=0.7
            )
            # Add regression line
            sns.regplot(
                data=self.diversity_metrics,
                x='age',
                y=metric,
                scatter=False,
                color='red',
                line_kws={'linestyle': '--'}
            )
            plt.title(f'{metric.capitalize()} vs Age')
        plt.tight_layout()
        plt.savefig(output_path / 'diversity_age_relationship.png')
        plt.close()
        
        # 3. Correlation heatmap
        correlation_data = self.diversity_metrics[
            metrics + ['age']
        ].corr()
        
        plt.figure(figsize=(8, 6))
        sns.heatmap(
            correlation_data,
            annot=True,
            cmap='coolwarm',
            center=0,
            vmin=-1,
            vmax=1
        )
        plt.title('Correlation between Diversity Metrics')
        plt.tight_layout()
        plt.savefig(output_path / 'correlation_heatmap.png')
        plt.close()
        
        # 4. Distribution plots
        plt.figure(figsize=(15, 5))
        for i, metric in enumerate(metrics):
            plt.subplot(1, 3, i+1)
            for group, color in zip(
                self.diversity_metrics['menopausal_status'].unique(),
                colors
            ):
                group_data = self.diversity_metrics[
                    self.diversity_metrics['menopausal_status'] == group
                ][metric]
                sns.kdeplot(
                    data=group_data,
                    label=group,
                    color=color,
                    fill=True,
                    alpha=0.3
                )
            plt.title(f'{metric.capitalize()} Distribution')
            plt.legend()
        plt.tight_layout()
        plt.savefig(output_path / 'diversity_distributions.png')
        plt.close()
        
        # 5. Summary statistics table
        summary_stats = pd.DataFrame()
        for metric in metrics:
            group_stats = self.diversity_metrics.groupby('menopausal_status')[metric].describe()
            summary_stats = pd.concat([summary_stats, group_stats], keys=[metric], axis=0)
        
        # Save summary statistics
        summary_stats.to_csv(output_path / 'diversity_summary_stats.csv')
        
        # 6. Violin plots with individual points
        plt.figure(figsize=(15, 5))
        for i, metric in enumerate(metrics):
            plt.subplot(1, 3, i+1)
            sns.violinplot(
                data=self.diversity_metrics,
                x='menopausal_status',
                y=metric,
                palette=colors,
                inner='box'
            )
            sns.stripplot(
                data=self.diversity_metrics,
                x='menopausal_status',
                y=metric,
                color='black',
                alpha=0.3,
                jitter=0.2,
                size=4
            )
            plt.title(f'{metric.capitalize()} Distribution')
            plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(output_path / 'diversity_violins.png')
        plt.close()
        
        print(f"\nPlots saved in: {output_path}")
        print("Generated plots:")
        print("1. Alpha diversity boxplots")
        print("2. Diversity vs age relationships")
        print("3. Correlation heatmap")
        print("4. Diversity distributions")
        print("5. Violin plots with individual points")
        print("\nSummary statistics saved as 'diversity_summary_stats.csv'")


    def run_complete_analysis(self):
        """Run complete analysis pipeline"""
        print("Loading and combining data...")
        self.load_and_combine_data()
        
        print("\nCalculating alpha diversity metrics...")
        self.calculate_diversity()
        
        print("\nCalculating beta diversity...")
        self.calculate_beta_diversity()
        
        print("\nPerforming statistical analysis...")
        self.analyze_diversity()
        self.test_beta_diversity()
        
        print("\nGenerating plots...")
        self.plot_results()
        self.plot_beta_diversity()
        
        # Save results
        results_dir = Path("results")
        results_dir.mkdir(exist_ok=True)
        
        # Save all results
        self.diversity_metrics.to_csv(results_dir / "diversity_metrics.csv")
        self.beta_diversity['species_contributions'].to_csv(
            results_dir / "species_contributions.csv"
        )
        
        print("\nAnalysis complete! Results saved in 'results' directory.")

In [30]:
from pathlib import Path 

# Initialize and run analysis
analyzer = MicrobiomeAnalysis(base_path="input")
analyzer.run_complete_analysis()

# Access results
diversity_metrics = analyzer.diversity_metrics
statistical_results = analyzer.statistical_results

# Print summary statistics
print("\nDiversity Metrics Summary:")
print(diversity_metrics.groupby('menopausal_status').agg({
    'shannon': ['mean', 'std'],
    'simpson': ['mean', 'std'],
    'richness': ['mean', 'std']
}))

Loading and combining data...

Combined Dataset Summary:
Total samples: 320
Total species: 964
Menopausal status groups: ['pre' 'post' 'going_through']

Calculating alpha diversity metrics...

Calculating beta diversity...


  warn(
  warn(
  corr_pc1 = stats.spearmanr(abundance_norm.loc[species],
  corr_pc2 = stats.spearmanr(abundance_norm.loc[species],



Performing statistical analysis...

Diversity Analysis Results:

Menopausal Status Differences:

Shannon:
Kruskal-Wallis test p-value: 0.9427

Descriptive statistics by group:
                   count      mean       std       min       25%       50%  \
menopausal_status                                                            
going_through       28.0  3.009519  0.402499  1.955274  2.859249  3.156611   
post               174.0  3.003045  0.430973  1.092170  2.829103  3.063401   
pre                118.0  3.005985  0.464181  0.902234  2.804913  3.098196   

                        75%       max  
menopausal_status                      
going_through      3.267935  3.508241  
post               3.272185  3.719965  
pre                3.298588  3.733806  

Simpson:
Kruskal-Wallis test p-value: 0.7469

Descriptive statistics by group:
                   count      mean       std       min       25%       50%  \
menopausal_status                                                         


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(
  summary_stats = pd.concat([summary_stats, group_stats], keys=[metric], axis=0)
  summary_stats = pd.concat([summary_stats, group_stats], keys=[metric], axis=0)
  summary_stats = pd.concat([summary_stats, group_stats], keys=[metric], axis=0)

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(

Passing `palette` without


Plots saved in: figures
Generated plots:
1. Alpha diversity boxplots
2. Diversity vs age relationships
3. Correlation heatmap
4. Diversity distributions
5. Violin plots with individual points

Summary statistics saved as 'diversity_summary_stats.csv'


KeyError: "None of [Index([ 11, 464, 348, 228,   6, 305,  10, 224, 364,  12, 373, 226, 407, 405,\n       320,  22,  26, 181, 180,  28],\n      dtype='int64')] are in the [index]"

<Figure size 1200x800 with 0 Axes>