In [2]:
# Antibiotic Sample Classification Analysis
# ==========================================

# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from matplotlib.ticker import MaxNLocator

# Setting visualization aesthetics
plt.style.use('seaborn-whitegrid')
sns.set_context("notebook", font_scale=1.1)

# %% [markdown]
# ## Data Acquisition and Validation

# %%
# Define file paths with flexible configuration
input_file = "datasets/Antibiotics-SampleID.csv"
output_heatmap = "Antibiotics_Sample_Classification_heatmap.png"
output_matrix_csv = "Antibiotics_Sample_Matrix_Data.csv"

# File verification with path detection protocol
def validate_file_path(file_path):
    """
    Validates file existence and searches for alternatives if not found.
    
    Parameters:
    file_path (str): Primary target file path
    
    Returns:
    str: Valid file path or raises FileNotFoundError
    """
    if os.path.exists(file_path):
        print(f"File located: {file_path}")
        return file_path
        
    # Generate alternative path variations
    filename = os.path.basename(file_path)
    name_parts = os.path.splitext(filename)
    possible_alternatives = [
        filename,
        name_parts[0].replace(" ", "_") + name_parts[1],
        name_parts[0].replace("-", "") + name_parts[1],
        name_parts[0].replace("-", "_") + name_parts[1],
        name_parts[0].replace(" ", "") + name_parts[1]
    ]
    
    # Search for alternatives in current directory
    for alt_name in possible_alternatives:
        if os.path.exists(alt_name):
            print(f"Located alternative file: {alt_name}")
            return alt_name
    
    raise FileNotFoundError(f"Unable to locate target file: {file_path}")

# Validate and set correct file path
input_file = validate_file_path(input_file)

# %% [markdown]
# ## Data Loading and Preprocessing

# %%
# Load and examine data structure
data = pd.read_csv(input_file)
print("Initial data dimensions:", data.shape)
print("\nColumn headers:")
print(data.columns.tolist())
print("\nFirst rows preview:")
data.head()

# %%
# Data structure examination and transformation
def preprocess_antibiotics_data(df):
    """
    Preprocesses the antibiotics dataset by handling headers, indices,
    and converting data types appropriately.
    
    Parameters:
    df (pandas.DataFrame): Raw input dataframe
    
    Returns:
    pandas.DataFrame: Processed dataframe ready for analysis
    """
    # Create a copy to preserve original data
    processed_df = df.copy()
    
    # Header structure processing - check if first row contains headers
    if 'Sample ID' in processed_df.columns:
        print("Using 'Sample ID' as index column")
        id_column = 'Sample ID'
    elif 'Unnamed: 0' in processed_df.columns:
        print("Renaming 'Unnamed: 0' to 'Sample ID'")
        processed_df = processed_df.rename(columns={'Unnamed: 0': 'Sample ID'})
        id_column = 'Sample ID'
    elif processed_df.columns[0] == '':
        print("Renaming first column to 'Sample ID'")
        processed_df = processed_df.rename(columns={processed_df.columns[0]: 'Sample ID'})
        id_column = 'Sample ID'
    else:
        # Use first column as ID
        print("Using first column as index")
        id_column = processed_df.columns[0]
    
    # Set index and remove any unnamed columns that are entirely empty
    processed_df = processed_df.set_index(id_column)
    empty_cols = [col for col in processed_df.columns if 'Unnamed' in col and processed_df[col].isna().all()]
    if empty_cols:
        print(f"Removing {len(empty_cols)} empty unnamed columns")
        processed_df = processed_df.drop(columns=empty_cols)
    
    # Try to convert numeric columns, coercing errors to NaN
    for col in processed_df.columns:
        try:
            processed_df[col] = pd.to_numeric(processed_df[col], errors='coerce')
        except:
            print(f"Column '{col}' contains non-numeric data")
    
    # Fill NaN values with 0 for numeric analysis
    processed_df = processed_df.fillna(0)
    
    # Check if any columns should be converted to integer type
    for col in processed_df.columns:
        if processed_df[col].dropna().apply(lambda x: x.is_integer()).all():
            processed_df[col] = processed_df[col].astype(int)
    
    return processed_df

# Apply preprocessing
processed_data = preprocess_antibiotics_data(data)
print("\nProcessed data dimensions:", processed_data.shape)
processed_data.head()

# %%
# Data inspection and quality assessment
def assess_data_quality(df):
    """
    Assess data quality by checking for missing values,
    analyzing value distributions, and identifying potential issues.
    
    Parameters:
    df (pandas.DataFrame): Input dataframe
    
    Returns:
    dict: Dictionary of quality metrics
    """
    quality_metrics = {}
    
    # Check for missing values
    missing_count = df.isna().sum().sum()
    quality_metrics['missing_values'] = missing_count
    quality_metrics['missing_percentage'] = (missing_count / df.size) * 100
    
    # Check value ranges
    quality_metrics['value_ranges'] = {
        col: (df[col].min(), df[col].max()) for col in df.columns
    }
    
    # Check for zero values
    zero_counts = (df == 0).sum()
    quality_metrics['zero_counts'] = zero_counts
    quality_metrics['zero_percentage'] = (zero_counts / len(df)) * 100
    
    # Summary statistics
    quality_metrics['summary_stats'] = df.describe()
    
    return quality_metrics

# Analyze data quality
quality_assessment = assess_data_quality(processed_data)
print("\nData quality assessment:")
print(f"Missing values: {quality_assessment['missing_values']} ({quality_assessment['missing_percentage']:.2f}%)")
print("\nValue ranges (min, max):")
for col, (min_val, max_val) in list(quality_assessment['value_ranges'].items())[:5]:
    print(f"  {col}: ({min_val}, {max_val})")
print("... and more columns")

# %% [markdown]
# ## Data Transformation and Matrix Generation

# %%
# Export processed matrix data for reference
processed_data.to_csv(output_matrix_csv)
print(f"Matrix data exported to: {output_matrix_csv}")

# %%
# Label truncation with mapping preservation for visualization clarity
def create_label_mappings(df, row_length_threshold=20, col_length_threshold=15):
    """
    Creates mappings for long row and column labels to improve visualization.
    
    Parameters:
    df (pandas.DataFrame): Input dataframe
    row_length_threshold (int): Maximum length for row labels
    col_length_threshold (int): Maximum length for column labels
    
    Returns:
    tuple: (modified dataframe, row mapping dict, column mapping dict)
    """
    modified_df = df.copy()
    
    # Create row label mapping
    row_label_mapping = {}
    for idx, label in enumerate(modified_df.index):
        if len(str(label)) > row_length_threshold:
            shortened = f"Sample {idx+1}"
            row_label_mapping[shortened] = str(label)
            modified_df.rename(index={label: shortened}, inplace=True)
    
    # Create column label mapping
    column_label_mapping = {}
    for idx, col in enumerate(modified_df.columns):
        if len(str(col)) > col_length_threshold:
            shortened = f"Cat. {idx+1}"
            column_label_mapping[shortened] = str(col)
            modified_df.rename(columns={col: shortened}, inplace=True)
    
    return modified_df, row_label_mapping, column_label_mapping

# Apply label mapping
visualization_data, row_mapping, column_mapping = create_label_mappings(processed_data)
print(f"Created {len(row_mapping)} row mappings and {len(column_mapping)} column mappings")
visualization_data.head()

# %% [markdown]
# ## Visualization Generation

# %%
# Configure optimal figure dimensions
def calculate_figure_dimensions(df):
    """
    Calculate optimal figure dimensions based on dataframe size.
    
    Parameters:
    df (pandas.DataFrame): Input dataframe
    
    Returns:
    tuple: (width, height) in inches
    """
    # Base dimensions
    base_width = 10
    base_height = 8
    
    # Scale based on data dimensions
    width_scale = max(1, df.shape[1] / 5)
    height_scale = max(1, df.shape[0] / 10)
    
    # Calculate dimensions with constraints
    width = min(24, max(base_width, base_width * width_scale))
    height = min(16, max(base_height, base_height * height_scale))
    
    return width, height

# Calculate figure dimensions
fig_width, fig_height = calculate_figure_dimensions(visualization_data)
print(f"Optimal figure dimensions: {fig_width:.1f} x {fig_height:.1f} inches")

# %%
# Generate heatmap visualization
def generate_heatmap(df, row_mapping=None, column_mapping=None, width=12, height=10, 
                    title="Antibiotic Sample Classification Heatmap", 
                    output_file=None):
    """
    Generate a comprehensive heatmap visualization with label mappings.
    
    Parameters:
    df (pandas.DataFrame): Data for visualization
    row_mapping (dict): Mapping of shortened row labels to original
    column_mapping (dict): Mapping of shortened column labels to original
    width (float): Figure width in inches
    height (float): Figure height in inches
    title (str): Plot title
    output_file (str): Path to save the figure
    
    Returns:
    matplotlib.figure.Figure: The generated figure
    """
    # Initialize figure
    plt.figure(figsize=(width, height))
    
    # Create custom color gradient: black (low) -> green (mid) -> red (high)
    cmap = plt.cm.colors.LinearSegmentedColormap.from_list(
        "custom", ["#000000", "#006400", "#8B0000"], N=256)
    
    # Generate heatmap with enhanced parameters
    heatmap = sns.heatmap(
        df,
        cmap=cmap,
        cbar=True,
        square=False,
        xticklabels=True,
        yticklabels=True,
        linewidths=0.5,
        linecolor='white',
        annot=True,
        fmt='.1f' if df.dtypes[0] == float else 'g',
        annot_kws={"size": 9},
        robust=True,
        cbar_kws={"shrink": 0.5, "label": "Value Magnitude"}
    )
    
    # Configure typography
    plt.xticks(fontsize=10, rotation=45, ha='right')
    plt.yticks(fontsize=10)
    
    # Title and labels
    plt.title(title, fontsize=16, pad=20)
    plt.ylabel('Sample ID', fontsize=14, labelpad=15)
    
    # Optimize layout
    plt.tight_layout(pad=2.0)
    
    # Add mapping information as legend if needed
    if row_mapping or column_mapping:
        legend_text = ""
        
        if row_mapping:
            legend_text += "SAMPLE ID MAPPING:\n"
            # Display first 5 mappings, then count the rest
            sorted_items = sorted(row_mapping.items())
            for short, full in sorted_items[:5]:
                legend_text += f"{short}: {full}\n"
            if len(sorted_items) > 5:
                legend_text += f"... and {len(sorted_items)-5} more sample mappings\n"
        
        if column_mapping:
            if legend_text:
                legend_text += "\n"
            legend_text += "CATEGORY MAPPING:\n"
            sorted_items = sorted(column_mapping.items())
            for short, full in sorted_items[:5]:
                legend_text += f"{short}: {full}\n"
            if len(sorted_items) > 5:
                legend_text += f"... and {len(sorted_items)-5} more category mappings\n"
        
        # Position legend text
        plt.figtext(0.5, 0.01, legend_text, ha='center', fontsize=9, 
                    bbox={"facecolor":"white", "alpha":0.8, "pad":5})
        
        # Adjust bottom margin
        plt.subplots_adjust(bottom=0.3)
    
    # Save figure if output file specified
    if output_file:
        plt.savefig(output_file, dpi=150, bbox_inches='tight', pad_inches=0.5)
        print(f"Heatmap saved to: {output_file}")
    
    return plt.gcf()

# Generate and display heatmap
heatmap_fig = generate_heatmap(
    visualization_data, 
    row_mapping, 
    column_mapping,
    width=fig_width,
    height=fig_height,
    title="Antibiotic Sample Classification Heatmap",
    output_file=output_heatmap
)

# %% [markdown]
# ## Supplementary Analysis

# %%
# Generate reference file for label mappings
def generate_mapping_reference(row_mapping, column_mapping, output_file="antibiotic_label_reference.txt"):
    """
    Generate a reference file for label mappings.
    
    Parameters:
    row_mapping (dict): Row label mapping dictionary
    column_mapping (dict): Column label mapping dictionary
    output_file (str): Output file path
    """
    with open(output_file, 'w') as f:
        f.write("ANTIBIOTICS SAMPLE CLASSIFICATION LABEL REFERENCE\n")
        f.write("==============================================\n\n")
        
        if row_mapping:
            f.write("SAMPLE ID MAPPINGS:\n")
            f.write("-----------------\n")
            for short, full in sorted(row_mapping.items()):
                f.write(f"{short}: {full}\n")
            f.write("\n")
        
        if column_mapping:
            f.write("CATEGORY MAPPINGS:\n")
            f.write("----------------\n")
            for short, full in sorted(column_mapping.items()):
                f.write(f"{short}: {full}\n")
    
    print(f"Label reference file generated: {output_file}")

# Generate reference file if mappings exist
if row_mapping or column_mapping:
    generate_mapping_reference(row_mapping, column_mapping)

# %%
# Additional statistical analysis
def perform_statistical_analysis(df):
    """
    Perform additional statistical analysis on the dataset.
    
    Parameters:
    df (pandas.DataFrame): Input dataframe
    
    Returns:
    dict: Dictionary of analysis results
    """
    analysis = {}
    
    # Basic statistics
    analysis['basic_stats'] = df.describe()
    
    # Correlation analysis
    if df.shape[1] > 1:  # Only if multiple columns
        analysis['correlation'] = df.corr()
    
    # Column sums and means
    analysis['column_sums'] = df.sum()
    analysis['column_means'] = df.mean()
    
    # Row sums and means
    analysis['row_sums'] = df.sum(axis=1)
    analysis['row_means'] = df.mean(axis=1)
    
    # Identify top samples by total value
    row_totals = df.sum(axis=1).sort_values(ascending=False)
    analysis['top_samples'] = row_totals.head(5)
    analysis['bottom_samples'] = row_totals.tail(5)
    
    return analysis

# Perform statistical analysis
stat_analysis = perform_statistical_analysis(processed_data)
print("\nStatistical Analysis Summary:")
print("\nTop 5 samples by total value:")
print(stat_analysis['top_samples'])
print("\nBottom 5 samples by total value:")
print(stat_analysis['bottom_samples'])

# %% [markdown]
# ## Data Distribution Visualization

# %%
# Generate distribution plots
def visualize_distributions(df):
    """
    Create visualizations for data distributions.
    
    Parameters:
    df (pandas.DataFrame): Input dataframe
    """
    # Create figure with subplots
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    
    # Plot 1: Overall distribution of all values
    all_values = df.values.flatten()
    sns.histplot(all_values, bins=30, kde=True, ax=axes[0, 0])
    axes[0, 0].set_title('Distribution of All Values')
    axes[0, 0].set_xlabel('Value')
    axes[0, 0].set_ylabel('Frequency')
    
    # Plot 2: Row sums distribution
    row_sums = df.sum(axis=1)
    sns.histplot(row_sums, bins=20, kde=True, ax=axes[0, 1])
    axes[0, 1].set_title('Distribution of Sample Totals')
    axes[0, 1].set_xlabel('Total Value')
    axes[0, 1].set_ylabel('Frequency')
    
    # Plot 3: Column sums distribution
    col_sums = df.sum()
    sns.barplot(x=col_sums.index, y=col_sums.values, ax=axes[1, 0])
    axes[1, 0].set_title('Category Totals')
    axes[1, 0].set_xlabel('Category')
    axes[1, 0].set_ylabel('Total Value')
    axes[1, 0].tick_params(axis='x', rotation=90)
    
    # Plot 4: Box plot of columns
    sns.boxplot(data=df, ax=axes[1, 1])
    axes[1, 1].set_title('Value Distribution by Category')
    axes[1, 1].set_xlabel('Category')
    axes[1, 1].set_ylabel('Value')
    axes[1, 1].tick_params(axis='x', rotation=90)
    
    # Adjust layout
    plt.tight_layout()
    plt.savefig('antibiotic_distributions.png', dpi=150, bbox_inches='tight')
    
    return fig

# Generate distribution visualizations
dist_fig = visualize_distributions(processed_data)

# %% [markdown]
# ## Hierarchical Clustering Analysis

# %%
# Perform hierarchical clustering
def perform_clustering(df):
    """
    Perform hierarchical clustering on samples and categories.
    
    Parameters:
    df (pandas.DataFrame): Input dataframe
    
    Returns:
    matplotlib.figure.Figure: Clustered heatmap figure
    """
    # Create figure
    plt.figure(figsize=(fig_width, fig_height))
    
    # Generate clustered heatmap
    clustered_heatmap = sns.clustermap(
        df,
        cmap=plt.cm.colors.LinearSegmentedColormap.from_list(
            "custom", ["#000000", "#006400", "#8B0000"], N=256),
        figsize=(fig_width, fig_height),
        dendrogram_ratio=0.1,
        colors_ratio=0.03,
        row_cluster=True,
        col_cluster=True,
        linewidths=0.5,
        annot=df.shape[0] < 20 and df.shape[1] < 15,  # Only annotate if not too large
        fmt='.1f' if df.dtypes[0] == float else 'g',
        robust=True,
        cbar_kws={"label": "Value Magnitude"}
    )
    
    # Adjust labels
    plt.setp(clustered_heatmap.ax_heatmap.xaxis.get_majorticklabels(), rotation=45, ha='right')
    plt.setp(clustered_heatmap.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
    
    # Add title
    plt.suptitle("Hierarchical Clustering of Antibiotic Samples", fontsize=16, y=1.02)
    
    # Save figure
    plt.savefig('antibiotic_clustering.png', dpi=150, bbox_inches='tight')
    
    return clustered_heatmap

# Generate clustered heatmap
clustered_fig = perform_clustering(processed_data)

# %% [markdown]
# ## Summary and Conclusions

# %%
# Generate summary statistics
print("\nSummary Statistics:")
print("-" * 50)
print(f"Total samples analyzed: {processed_data.shape[0]}")
print(f"Total categories: {processed_data.shape[1]}")
print(f"Data range: {processed_data.values.min()} to {processed_data.values.max()}")
print(f"Mean value: {processed_data.values.mean():.2f}")
print(f"Median value: {np.median(processed_data.values):.2f}")
print(f"Standard deviation: {processed_data.values.std():.2f}")
print("-" * 50)
print("Analysis complete. Output files generated:")
print(f"- {output_matrix_csv}")
print(f"- {output_heatmap}")
print("- antibiotic_distributions.png")
print("- antibiotic_clustering.png")
if row_mapping or column_mapping:
    print("- antibiotic_label_reference.txt")

OSError: 'seaborn-whitegrid' is not a valid package style, path of style file, URL of style file, or library style name (library styles are listed in `style.available`)