In [1]:
# Cohort Statistics Analysis
import os
import pandas as pd
import numpy as np
from preprocessing import DataLoader
import seaborn as sns
import matplotlib.pyplot as plt

# Initialize DataLoader.
# base_path is the parent directory of the current working directory, so this cell
# assumes the notebook is launched from a subdirectory of the project root
# (NOTE(review): confirm this matches how DataLoader expects its base path).
base_path = os.path.dirname(os.getcwd())
loader = DataLoader(base_path)
# Later cells read loader.pdata_original, loader.pdata_imputed, loader.exprs_data,
# loader.merged_pdata_*, loader.intersection_data and loader.common_genes_data;
# presumably these are populated by load_all_data() — verify in preprocessing.DataLoader.
loader.load_all_data()

In [9]:
class CohortStatistics:
    """Per-cohort summary statistics collected into one table.

    Covers demographics, T-stage, Gleason, PSA, BCR and gene counts.

    Reads clinical tables and expression matrices from ``loader`` and stores one
    row of statistics per cohort in ``self.stats_df``.

    Loader attributes read by ``calculate_all_stats``:
        pdata_original, pdata_imputed: mappings of cohort name -> clinical DataFrame.
        exprs_data: mapping of name -> expression DataFrame.
        merged_pdata_original, merged_pdata_imputed: mappings, or falsy to skip.
        intersection_data, common_genes_data: mappings looked up by '.csv' file name.
    """

    # T-stage labels counted for PATH_T_STAGE / CLIN_T_STAGE (exact, case-sensitive match).
    T_STAGES = ('T1', 'T1A', 'T1B', 'T1C', 'T2', 'T2A', 'T2B', 'T2C',
                'T3', 'T3A', 'T3B', 'T4')
    # Values tabulated individually for GLEASON_SCORE and for GLEASON_SCORE_1/2.
    GLEASON_SCORES = range(2, 11)
    GLEASON_PATTERNS = range(1, 6)

    def __init__(self, loader):
        self.loader = loader
        self.stats_df = pd.DataFrame()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _numeric_summary(values, keys):
        """Summarize a numeric Series under caller-chosen output keys.

        Args:
            values: numeric Series with missing values already dropped.
            keys: dict mapping a statistic name ('mean', 'median', 'min', 'max',
                'q25', 'q75') to its output key. Output order follows ``keys``.

        Returns:
            dict of output key -> float. Every key is NaN when ``values`` is empty.
        """
        if len(values) == 0:
            return {out_key: np.nan for out_key in keys.values()}
        compute = {
            'mean': values.mean,
            'median': values.median,
            'min': values.min,
            'max': values.max,
            'q25': lambda: values.quantile(0.25),
            'q75': lambda: values.quantile(0.75),
        }
        return {out_key: float(compute[stat]()) for stat, out_key in keys.items()}

    @staticmethod
    def _tabulate(values, categories, count_key, prop_key):
        """Count exact matches of each category, and each category's share of all matches.

        Args:
            values: Series compared to each category with ``==`` (no normalization).
            categories: iterable of category values to count.
            count_key: callable mapping a category to its count output key.
            prop_key: callable mapping a category to its proportion output key.

        Returns:
            dict with every count key followed by every proportion key.
            - If every value is missing, counts and proportions are NaN.
            - If no value matches any category, counts are 0 and proportions are
              NaN, because the share of an empty total is undefined.
        """
        categories = list(categories)
        if values.isna().all():
            stats = {}
            for cat in categories:
                stats[count_key(cat)] = np.nan
                stats[prop_key(cat)] = np.nan
            return stats

        counts = {cat: int((values == cat).sum()) for cat in categories}
        total = sum(counts.values())
        stats = {count_key(cat): counts[cat] for cat in categories}
        stats.update({
            prop_key(cat): (counts[cat] / total if total > 0 else np.nan)
            for cat in categories
        })
        return stats

    # --------------------------------------------------------- statistic groups

    def calculate_basic_stats(self, pdata, name):
        """Return the cohort name, row count, age distribution and tissue label.

        Args:
            pdata: clinical DataFrame. Its row count is reported as 'patient_count'.
            name: label stored under 'cohort_name'.

        Returns:
            dict with 'cohort_name' and 'patient_count', plus two optional groups:
            - min/q25/median/q75/max age, when an AGE column exists (NaN if no
              value is numeric).
            - 'tissue', when a TISSUE column exists.
        """
        stats = {'cohort_name': name, 'patient_count': len(pdata)}

        if 'AGE' in pdata.columns:
            # Non-numeric ages are coerced to NaN and ignored.
            ages = pd.to_numeric(pdata['AGE'], errors='coerce').dropna()
            stats.update(self._numeric_summary(ages, {
                'min': 'min_age',
                'q25': 'q25_age',
                'median': 'median_age',
                'q75': 'q75_age',
                'max': 'max_age',
            }))

        if 'TISSUE' in pdata.columns:
            # Use the first non-missing label; row 0 itself may be missing.
            tissues = pdata['TISSUE'].dropna()
            stats['tissue'] = tissues.iloc[0] if len(tissues) else np.nan

        return stats

    def calculate_stage_stats(self, pdata, column_name, stages):
        """Return counts and proportions of each stage label in ``pdata[column_name]``.

        Keys are ``count_{column_name}_{stage}`` and ``prop_{column_name}_{stage}``.
        Returns an empty dict when the column is absent.
        """
        if column_name not in pdata.columns:
            return {}
        return self._tabulate(
            pdata[column_name],
            stages,
            count_key=lambda stage: f'count_{column_name}_{stage}',
            prop_key=lambda stage: f'prop_{column_name}_{stage}',
        )

    def calculate_gleason_stats(self, pdata):
        """Return counts and proportions for the Gleason columns.

        GLEASON_SCORE is tabulated over 2-10. GLEASON_SCORE_1 and GLEASON_SCORE_2
        (presumably primary/secondary patterns) are tabulated over 1-5. Each
        column is handled independently when present, after numeric coercion.
        """
        stats = {}

        if 'GLEASON_SCORE' in pdata.columns:
            scores = pd.to_numeric(pdata['GLEASON_SCORE'], errors='coerce')
            stats.update(self._tabulate(
                scores,
                self.GLEASON_SCORES,
                count_key=lambda s: f'gleason_{s}',
                prop_key=lambda s: f'gleason_{s}_prop',
            ))

        # Pattern columns are tabulated even when the GLEASON_SCORE column is absent.
        for gs_column in ('GLEASON_SCORE_1', 'GLEASON_SCORE_2'):
            if gs_column not in pdata.columns:
                continue
            patterns = pd.to_numeric(pdata[gs_column], errors='coerce')
            stats.update(self._tabulate(
                patterns,
                self.GLEASON_PATTERNS,
                count_key=lambda s, col=gs_column: f'{col}_{s}',
                prop_key=lambda s, col=gs_column: f'{col}_{s}_prop',
            ))

        return stats

    def calculate_psa_stats(self, pdata):
        """Return the PRE_OPERATIVE_PSA distribution plus the count and share of values above 4.

        Non-numeric values are ignored. Returns an empty dict when the column is
        absent, and NaN for every key when no numeric value remains.
        """
        if 'PRE_OPERATIVE_PSA' not in pdata.columns:
            return {}

        psa = pd.to_numeric(pdata['PRE_OPERATIVE_PSA'], errors='coerce').dropna()
        stats = self._numeric_summary(psa, {
            'mean': 'psa_mean',
            'median': 'psa_median',
            'min': 'psa_min',
            'max': 'psa_max',
            'q25': 'psa_q25',
            'q75': 'psa_q75',
        })
        if len(psa) == 0:
            stats['psa_over_4_count'] = np.nan
            stats['psa_over_4_prop'] = np.nan
        else:
            over_4 = psa > 4
            stats['psa_over_4_count'] = int(over_4.sum())
            stats['psa_over_4_prop'] = float(over_4.sum() / len(psa))
        return stats

    def calculate_bcr_stats(self, pdata):
        """Return the time-to-BCR distribution among events and the BCR event proportion.

        Requires both BCR_STATUS and MONTH_TO_BCR; returns an empty dict otherwise.
        - The bcr_mean/median/min/max/q25/q75 keys summarize MONTH_TO_BCR over rows
          with BCR_STATUS == 1.
        - 'bcr_proportion' is the share of non-missing BCR_STATUS values equal to 1.
          It depends on BCR_STATUS alone, so a cohort with no events reports 0.0.
        """
        if not all(col in pdata.columns for col in ('BCR_STATUS', 'MONTH_TO_BCR')):
            return {}

        bcr_status = pd.to_numeric(pdata['BCR_STATUS'], errors='coerce')
        month_to_bcr = pd.to_numeric(pdata['MONTH_TO_BCR'], errors='coerce')

        event_months = month_to_bcr[bcr_status == 1].dropna()
        stats = self._numeric_summary(event_months, {
            'mean': 'bcr_mean',
            'median': 'bcr_median',
            'min': 'bcr_min',
            'max': 'bcr_max',
            'q25': 'bcr_q25',
            'q75': 'bcr_q75',
        })

        valid_status = bcr_status.dropna()
        stats['bcr_proportion'] = (
            float((valid_status == 1).sum() / len(valid_status))
            if len(valid_status) > 0 else np.nan
        )
        return stats

    def calculate_gene_stats(self, exprs_data):
        """Return the number of columns of the expression matrix as 'gene_count'.

        NOTE(review): assumes genes are columns (samples as rows); confirm
        against the loader.

        Returns NaN when the input is not a DataFrame or has no columns.
        ``calculate_all_stats`` passes an empty DataFrame when no matrix is found,
        and that case should not read as "0 genes".
        """
        if not isinstance(exprs_data, pd.DataFrame) or exprs_data.shape[1] == 0:
            return {'gene_count': np.nan}
        return {'gene_count': exprs_data.shape[1]}

    def calculate_cohort_stats(self, pdata, exprs_data, name):
        """Merge every statistic group for one cohort into a single flat dict."""
        stats = self.calculate_basic_stats(pdata, name)
        stats.update(self.calculate_stage_stats(pdata, 'PATH_T_STAGE', self.T_STAGES))
        stats.update(self.calculate_stage_stats(pdata, 'CLIN_T_STAGE', self.T_STAGES))
        stats.update(self.calculate_gleason_stats(pdata))
        stats.update(self.calculate_psa_stats(pdata))
        stats.update(self.calculate_bcr_stats(pdata))
        stats.update(self.calculate_gene_stats(exprs_data))
        return stats

    def calculate_all_stats(self):
        """Compute one statistics row per cohort and store them in ``self.stats_df``.

        Rows appear in this order: original cohorts, imputed cohorts, merged
        original, merged imputed. Each cohort_name is prefixed with its source.

        Returns:
            The new ``self.stats_df``.
        """
        loader = self.loader
        all_stats = []

        # Per-cohort tables. The expression matrix is looked up under the cohort key
        # with 'cohort_' replaced by 'exprs_'; an empty DataFrame is used when absent.
        for prefix, cohorts in (('Original', loader.pdata_original),
                                ('Imputed', loader.pdata_imputed)):
            for name, pdata in cohorts.items():
                exprs = loader.exprs_data.get(name.replace('cohort_', 'exprs_'), pd.DataFrame())
                all_stats.append(self.calculate_cohort_stats(pdata, exprs, f"{prefix}_{name}"))

        # Each family of merged tables shares a single expression matrix.
        if loader.merged_pdata_original:
            exprs = loader.intersection_data.get('exprs_intersect.csv', pd.DataFrame())
            for name, pdata in loader.merged_pdata_original.items():
                all_stats.append(
                    self.calculate_cohort_stats(pdata, exprs, f"Merged_Original_{name}")
                )

        if loader.merged_pdata_imputed:
            exprs = loader.common_genes_data.get('common_genes_knn_imputed.csv', pd.DataFrame())
            for name, pdata in loader.merged_pdata_imputed.items():
                all_stats.append(
                    self.calculate_cohort_stats(pdata, exprs, f"Merged_Imputed_{name}")
                )

        self.stats_df = pd.DataFrame(all_stats)
        return self.stats_df

    def plot_cohort_comparisons(self):
        """Draw a 2x2 bar chart comparing cohorts.

        Panels show patient count, BCR proportion, gene count and median age.
        Statistics are computed first if ``self.stats_df`` is empty. A panel whose
        column is absent, or has no numeric value, shows a 'No data available' note.

        Returns:
            The matplotlib Figure.
        """
        if self.stats_df.empty:
            self.calculate_all_stats()

        # matplotlib 3.6 renamed the bare 'seaborn' style to 'seaborn-v0_8' and
        # 3.8 removed the old name; fall back to it only on older versions.
        style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
        plt.style.use(style)
        fig, axes = plt.subplots(2, 2, figsize=(20, 20))

        def safe_barplot(data, x, y, ax, title):
            # A statistic column only exists if at least one cohort produced it.
            values = pd.to_numeric(data[y], errors='coerce') if y in data.columns else None
            if values is not None and values.notna().any():
                sns.barplot(data=data.assign(**{y: values}), x=x, y=y, ax=ax)
                ax.tick_params(axis='x', rotation=45)
            else:
                # transAxes centers the note independently of the (empty) data limits.
                ax.text(0.5, 0.5, f'No data available for {y}',
                        horizontalalignment='center', verticalalignment='center',
                        transform=ax.transAxes)
            ax.set_title(title)

        panels = (
            ('patient_count', axes[0, 0], 'Patient Count by Cohort'),
            ('bcr_proportion', axes[0, 1], 'BCR Proportion by Cohort'),
            ('gene_count', axes[1, 0], 'Gene Count by Cohort'),
            ('median_age', axes[1, 1], 'Median Age by Cohort'),
        )
        for column, ax, title in panels:
            safe_barplot(self.stats_df, 'cohort_name', column, ax, title)

        plt.tight_layout()
        return fig

In [12]:
# Example usage: compute one statistics row per cohort and show the table.
stats_calculator = CohortStatistics(loader)
stats_df = stats_calculator.calculate_all_stats()
# Optional 2x2 comparison figure (patient count, BCR proportion, gene count, median age):
# comparison_plots = stats_calculator.plot_cohort_comparisons()
# plt.show()
# Ending the cell with the frame uses Jupyter's rich HTML display instead of
# print(), which wraps this wide table into unreadable backslash-continued text.
stats_df

                                  cohort_name  patient_count  min_age  \
0              Original_Belfast_2018_Jain.csv            248  48.0000   
1           Original_CPC_GENE_2017_Fraser.csv             73  43.8500   
2            Original_DKFZ_2018_Gerhauser.csv             82  32.0000   
3            Original_CancerMap_2017_Luca.csv            133      NaN   
4              Original_MSKCC_2010_Taylor.csv            131  37.2958   
5              Original_Atlanta_2014_Long.csv            100  43.0000   
6         Original_CamCap_2016_Ross_Adams.csv            112  41.0000   
7      Original_Stockholm_2016_Ross_Adams.csv             92      NaN   
8                  Original_CPGEA_2020_Li.csv            120  50.0000   
9               Imputed_Belfast_2018_Jain.csv            248  48.0000   
10           Imputed_CPC_GENE_2017_Fraser.csv             73  43.8500   
11            Imputed_DKFZ_2018_Gerhauser.csv             82  32.0000   
12            Imputed_CancerMap_2017_Luca.csv      