In [None]:
import json
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
from pathlib import Path
import logging

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

# Set up logging: records at INFO level and above, each formatted as
# "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by the functions defined in the cells below.
logger = logging.getLogger(__name__)

In [None]:
def load_metadata_mapping(
    metadata_file: str,
    sample_sheet_file: str = './assets/sample_sheet/sample_sheet.tsv',
) -> dict:
    """Map every file name listed in the metadata JSON to its sample ID.

    The metadata JSON is expected to be a list of objects carrying a
    ``file_name`` key. The sample ID for each file is looked up in the sample
    sheet (tab-separated, with ``File Name`` and ``Sample ID`` columns) and its
    last character is dropped before it is stored.

    Files that are missing from the sample sheet, or whose sample ID is empty,
    are skipped with a warning instead of aborting the whole load.

    Args:
        metadata_file: Path to the metadata JSON file.
        sample_sheet_file: Path to the sample sheet TSV. Defaults to the
            location that was previously hard-coded.

    Returns:
        dict mapping file name -> truncated sample ID.

    Raises:
        OSError: If either input file cannot be opened.
        KeyError: If the sample sheet lacks the ``File Name`` or
            ``Sample ID`` column.
    """
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)

    sample_sheet_df = pd.read_table(sample_sheet_file, sep='\t')

    # Build the file-name -> sample-ID lookup once instead of filtering the
    # whole sheet for every metadata entry. Keeping the first row per file
    # name mirrors the previous `.values[0]` selection.
    sheet_lookup = (
        sample_sheet_df
        .drop_duplicates(subset="File Name", keep="first")
        .set_index("File Name")["Sample ID"]
        .to_dict()
    )

    file_to_sample_mapping = {}

    for entry in metadata:
        file_name = entry.get('file_name')
        raw_sample_id = sheet_lookup.get(file_name)

        # A file that is absent from the sheet (or has an empty Sample ID cell)
        # used to raise on `.values[0]` / slicing; warn and move on instead.
        if raw_sample_id is None or pd.isna(raw_sample_id):
            logger.warning(f"No sample ID found for file: {file_name}")
            continue

        # Remove last character because Sample ID in the sample sheet includes
        # an A at the end. Further research.
        sample_id = str(raw_sample_id)[:-1]

        if not sample_id:
            logger.warning(f"No sample ID found for file: {file_name}")
            continue

        file_to_sample_mapping[file_name] = sample_id

    logger.info(f"Loaded metadata mapping for {len(file_to_sample_mapping)} files")
    return file_to_sample_mapping

In [None]:
def process_tsv_file(file_path):
    """Read one expression TSV and return its TPM values keyed by gene ID.

    The first line of the file is skipped before parsing. Rows whose
    ``gene_id`` starts with ``N_`` are dropped, and the ``tpm_unstranded``
    column is returned as a Series indexed by ``gene_id``.

    Returns ``None`` (after logging) when a required column is missing or
    when any error occurs while reading or processing the file.
    """
    needed = ('gene_id', 'tpm_unstranded')

    try:
        table = pd.read_csv(file_path, sep='\t', skiprows=1)

        # Bail out early when the file does not have the expected layout
        absent = [name for name in needed if name not in table.columns]
        if absent:
            logger.warning(f"Required columns not found in {file_path}. Available columns: {table.columns.tolist()}")
            return None

        # Rows prefixed with 'N_' are metadata rows, not genes
        is_metadata_row = table['gene_id'].str.startswith('N_', na=False)
        genes_only = table.loc[~is_metadata_row]

        return genes_only.set_index('gene_id')['tpm_unstranded']

    except Exception as exc:
        logger.error(f"Error processing file {file_path}: {str(exc)}")
        return None

In [None]:
def combine_gdc_files(tsv_directory, metadata_file, output_file):
    """Combine per-sample TSV files into one gene x sample Parquet dataset.

    Every ``*.tsv`` file in ``tsv_directory`` that appears in the metadata
    mapping is parsed with ``process_tsv_file``; its TPM values become one
    column named after the sample ID. The gene set of the first successfully
    parsed file defines the rows; genes missing from later files are filled
    with 0.0. The result is cast to float32 and written to ``output_file``
    with snappy compression, then summary statistics are logged.

    Files are processed in sorted order so that the reference gene set and
    the column order are reproducible between runs. When two files map to the
    same sample ID only the first one is kept and the others are skipped with
    a warning, rather than silently overwriting each other.

    Args:
        tsv_directory: Directory searched (non-recursively) for ``*.tsv``.
        metadata_file: Metadata JSON passed to ``load_metadata_mapping``.
        output_file: Destination Parquet path; its parent directory is
            created when missing.

    Returns:
        None. When no file could be processed an error is logged and nothing
        is written.
    """
    # Load metadata mapping
    file_to_sample_mapping = load_metadata_mapping(metadata_file)

    # Path.glob yields files in an unspecified order; sort for determinism
    tsv_files = sorted(Path(tsv_directory).glob("*.tsv"))

    logger.info(f"Found {len(tsv_files)} TSV files")

    # sample ID -> array of TPM values aligned to gene_ids
    all_samples_data = {}
    gene_ids = None

    processed_count = 0
    skipped_count = 0

    for tsv_file in tsv_files:
        file_name = tsv_file.name

        # Get sample ID from metadata
        if file_name not in file_to_sample_mapping:
            logger.warning(f"File {file_name} not found in metadata mapping, skipping")
            skipped_count += 1
            continue

        sample_id = file_to_sample_mapping[file_name]

        # Distinct files can share a sample ID; keep the first one and make
        # the collision visible instead of overwriting the earlier column.
        if sample_id in all_samples_data:
            logger.warning(f"Sample ID {sample_id} already loaded, skipping duplicate file {file_name}")
            skipped_count += 1
            continue

        # Process the TSV file
        sample_data = process_tsv_file(str(tsv_file))

        if sample_data is None:
            skipped_count += 1
            continue

        # Store gene IDs from first successful file
        if gene_ids is None:
            gene_ids = sample_data.index.tolist()
            logger.info(f"Using gene set from {file_name} with {len(gene_ids)} genes")

        # Ensure consistent gene ordering; genes absent from this file get 0.0
        sample_data = sample_data.reindex(gene_ids, fill_value=0.0)

        # Store the data
        all_samples_data[sample_id] = sample_data.values
        processed_count += 1

        if processed_count % 50 == 0:
            logger.info(f"Processed {processed_count} files...")

    logger.info(f"Successfully processed {processed_count} files, skipped {skipped_count} files")

    if not all_samples_data:
        logger.error("No data was successfully processed!")
        return

    # Create the combined DataFrame
    logger.info("Creating combined DataFrame...")
    combined_df = pd.DataFrame(all_samples_data, index=gene_ids)

    # Set the index name to 'gene_id'
    combined_df.index.name = 'gene_id'

    # Ensure proper data types
    combined_df = combined_df.astype('float32')  # TPM values as float32 to save space

    logger.info(f"Combined dataset shape: {combined_df.shape}")
    logger.info(f"Genes (rows): {combined_df.shape[0]}")
    logger.info(f"Samples (columns): {combined_df.shape[1]}")
    logger.info(f"Index name: {combined_df.index.name}")

    # Save as parquet (create the destination directory on a fresh checkout)
    logger.info(f"Saving to {output_file}...")
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    table = pa.Table.from_pandas(combined_df)
    pq.write_table(table, output_file, compression='snappy')

    logger.info("Dataset creation completed successfully!")

    # Log some basic statistics; grab the value matrix once and reuse it
    tpm_values = combined_df.values
    logger.info("Sample statistics:")
    logger.info(f"  - Min TPM value: {tpm_values.min()}")
    logger.info(f"  - Max TPM value: {tpm_values.max()}")
    logger.info(f"  - Mean TPM value: {tpm_values.mean():.4f}")
    logger.info(f"  - Median TPM value: {float(pd.Series(tpm_values.flatten()).median()):.4f}")


def verify_dataset(parquet_file):
    """Load a Parquet dataset and display an overview of its contents.

    Shows three sections: dataset shape/dtype/memory usage, a preview of the
    first 10 rows and 5 columns, and summary statistics over all values
    (min, max, mean, median, std, share of zero and of positive values).
    """
    logger.info("Verifying created dataset...")

    dataset = pd.read_parquet(parquet_file)
    rule = "=" * 60

    def _banner(title, leading_blank=True):
        # Section header: an optional blank line, then the title between rules
        print(("\n" if leading_blank else "") + rule)
        print(title)
        print(rule)

    # --- Overview -----------------------------------------------------
    _banner("DATASET OVERVIEW", leading_blank=False)

    n_rows, n_cols = dataset.shape
    overview = [
        ('Number of Genes (rows)', n_rows),
        ('Number of Samples (columns)', n_cols),
        ('Data Type', str(dataset.dtypes.iloc[0])),
        ('Memory Usage (MB)', f"{dataset.memory_usage(deep=True).sum() / 1024**2:.2f}"),
    ]
    info_df = pd.DataFrame({
        'Metric': [label for label, _ in overview],
        'Value': [value for _, value in overview],
    })
    display(info_df)

    # --- Preview: first 10 genes, first 5 samples ---------------------
    _banner("DATA PREVIEW")
    display(dataset.iloc[:10, :5])

    # --- Expression statistics ----------------------------------------
    _banner("EXPRESSION STATISTICS")

    matrix = dataset.values
    zero_pct = (matrix == 0).sum() / dataset.size * 100
    positive_pct = (matrix > 0).sum() / dataset.size * 100
    summary = [
        ('Min TPM', f"{matrix.min():.4f}"),
        ('Max TPM', f"{matrix.max():.4f}"),
        ('Mean TPM', f"{matrix.mean():.4f}"),
        ('Median TPM', f"{np.median(matrix):.4f}"),
        ('Std TPM', f"{matrix.std():.4f}"),
        ('Zero Values (%)', f"{zero_pct:.2f}%"),
        ('Non-zero Values (%)', f"{positive_pct:.2f}%"),
    ]
    stats_df = pd.DataFrame({
        'Statistic': [label for label, _ in summary],
        'Value': [value for _, value in summary],
    })
    display(stats_df)


def essential_plots_for_ml(parquet_file):
    """Plot four diagnostic views of the expression matrix and summarise them.

    The dataset is read from ``parquet_file`` (rows are treated as genes,
    columns as samples). Four panels are drawn:

    1. histogram of ``log10(TPM + 1)`` over all values,
    2. histogram of the percentage of zero values per gene,
    3. scatter of total expression vs. number of expressed genes per sample,
       with 1.5 * IQR fences on both axes,
    4. histogram of the coefficient of variation (CV) for genes with mean
       expression > 1 and CV < 5.

    A summary table of counts/percentages is displayed afterwards.

    Args:
        parquet_file: Path of the Parquet dataset to analyse.

    Returns:
        Tuple ``(gene_stats, sample_stats)``: per-gene statistics
        (``Mean_Expression``, ``Zero_Percentage``, ``CV``) and per-sample
        statistics (``Total_Expression``, ``Expressed_Genes``).
    """
    # Load the dataset
    df = pd.read_parquet(parquet_file)

    print("=" * 60)
    print("ESSENTIAL PLOTS FOR ML PROJECT")
    print("=" * 60)

    # Calculate basic statistics
    gene_mean = df.mean(axis=1)
    gene_stats = pd.DataFrame({
        'Mean_Expression': gene_mean,
        'Zero_Percentage': (df == 0).sum(axis=1) / df.shape[1] * 100,
        # CV is NaN for genes whose mean is 0; those are filtered out below
        'CV': df.std(axis=1) / gene_mean
    })

    sample_stats = pd.DataFrame({
        'Total_Expression': df.sum(axis=0),
        'Expressed_Genes': (df > 0).sum(axis=0)
    })

    # Create 4 essential plots
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Plot 1: Expression Distribution (Need for Log Transformation)
    axes[0,0].hist(np.log10(df.values.flatten() + 1), bins=100, alpha=0.7, edgecolor='black')
    axes[0,0].set_xlabel('Log10(TPM + 1)')
    axes[0,0].set_ylabel('Frequency')
    axes[0,0].set_title('1. Expression Value Distribution\n(Determines if log transform needed)')
    # The x axis is log10(TPM + 1), so TPM = 1 sits at log10(2), not log10(1)
    axes[0,0].axvline(np.log10(1 + 1), color='red', linestyle='--', label='TPM = 1')
    axes[0,0].legend()

    # Plot 2: Gene Sparsity (Feature Filtering)
    axes[0,1].hist(gene_stats['Zero_Percentage'], bins=50, alpha=0.7, edgecolor='black')
    axes[0,1].set_xlabel('Percentage of Zero Values per Gene')
    axes[0,1].set_ylabel('Number of Genes')
    axes[0,1].set_title('2. Gene Sparsity Distribution\n(Guides feature filtering thresholds)')
    axes[0,1].axvline(50, color='red', linestyle='--', label='50% threshold')
    axes[0,1].axvline(90, color='orange', linestyle='--', label='90% threshold')
    axes[0,1].legend()

    # Plot 3: Sample Quality (Outlier Detection)
    axes[1,0].scatter(sample_stats['Total_Expression']/1000, sample_stats['Expressed_Genes'],
                     alpha=0.6, s=30)
    axes[1,0].set_xlabel('Total Expression (thousands TPM)')
    axes[1,0].set_ylabel('Number of Expressed Genes')
    axes[1,0].set_title('3. Sample Quality Check\n(Outlier detection)')

    # Outlier fences: 1.5 * IQR beyond the quartiles, computed once and
    # reused for both the plot and the summary table
    total_q1, total_q3 = sample_stats['Total_Expression'].quantile([0.25, 0.75])
    genes_q1, genes_q3 = sample_stats['Expressed_Genes'].quantile([0.25, 0.75])
    total_iqr = total_q3 - total_q1
    genes_iqr = genes_q3 - genes_q1
    total_low, total_high = total_q1 - 1.5 * total_iqr, total_q3 + 1.5 * total_iqr
    genes_low, genes_high = genes_q1 - 1.5 * genes_iqr, genes_q3 + 1.5 * genes_iqr

    axes[1,0].axvline(total_low/1000, color='red', linestyle='--', alpha=0.5)
    axes[1,0].axvline(total_high/1000, color='red', linestyle='--', alpha=0.5)
    axes[1,0].axhline(genes_low, color='red', linestyle='--', alpha=0.5)
    axes[1,0].axhline(genes_high, color='red', linestyle='--', alpha=0.5)

    # Plot 4: Gene Variability (Feature Selection)
    # Filter out genes with very low expression for CV calculation
    cv_data = gene_stats[(gene_stats['Mean_Expression'] > 1) & (gene_stats['CV'].notna())]['CV']
    cv_filtered = cv_data[cv_data < 5]  # Remove extreme outliers for better visualization

    axes[1,1].hist(cv_filtered, bins=50, alpha=0.7, edgecolor='black')
    axes[1,1].set_xlabel('Coefficient of Variation (CV)')
    axes[1,1].set_ylabel('Number of Genes')
    axes[1,1].set_title('4. Gene Variability Distribution\n(Identifies informative features)')
    axes[1,1].axvline(1, color='red', linestyle='--', label='CV = 1')
    axes[1,1].legend()

    plt.tight_layout()
    plt.show()

    # Summary statistics for decision making
    print("\nKEY STATISTICS FOR ML PREPROCESSING:")
    print("-" * 50)

    total_genes = len(gene_stats)
    total_samples = len(sample_stats)

    # Each count is computed once and reused for the percentage column
    sparse_50 = (gene_stats['Zero_Percentage'] > 50).sum()
    sparse_90 = (gene_stats['Zero_Percentage'] > 90).sum()
    low_expression = (gene_stats['Mean_Expression'] < 1).sum()
    high_cv = (gene_stats['CV'] > 1).sum()
    # Only the total-expression fences are used for the outlier count
    outlier_mask = ((sample_stats['Total_Expression'] < total_low) |
                    (sample_stats['Total_Expression'] > total_high))
    outlier_samples = len(sample_stats[outlier_mask])

    def _pct(count, total):
        # Percentage with one decimal place, e.g. "12.3%"
        return f"{count / total * 100:.1f}%"

    summary_stats = pd.DataFrame({
        'Metric': [
            'Total genes',
            'Genes with >50% zeros',
            'Genes with >90% zeros',
            'Low expression genes (mean < 1 TPM)',
            'Highly variable genes (CV > 1)',
            'Potential outlier samples'
        ],
        'Count': [
            total_genes,
            sparse_50,
            sparse_90,
            low_expression,
            high_cv,
            outlier_samples
        ],
        'Percentage': [
            '100%',
            _pct(sparse_50, total_genes),
            _pct(sparse_90, total_genes),
            _pct(low_expression, total_genes),
            _pct(high_cv, total_genes),
            _pct(outlier_samples, total_samples)
        ]
    })

    display(summary_stats)

    return gene_stats, sample_stats

In [None]:
# Paths used by the pipeline cells below.
TSV_DIRECTORY = "./data/tsv"  # directory searched for *.tsv files by combine_gdc_files
METADATA_FILE = "./REQUIRED/metadata.json"  # metadata JSON read by load_metadata_mapping
OUTPUT_FILE = "./data/dataset.pq"  # Parquet file written by combine_gdc_files and read by verify_dataset

In [None]:
# Combine the files: read the TSVs in TSV_DIRECTORY, map them to sample IDs via
# METADATA_FILE, and write the combined float32 TPM table to OUTPUT_FILE.
combine_gdc_files(TSV_DIRECTORY, METADATA_FILE, OUTPUT_FILE)

In [None]:
# Read OUTPUT_FILE back and display its overview, a data preview and expression statistics.
verify_dataset(OUTPUT_FILE)