# AIRR-ML-25: Adaptive Immune Profiling Challenge - EDA & Data Analysis

## Challenge Overview
**Objective:** Build ML models for two tasks:
1. **Predict immune state** (disease vs. healthy) from adaptive immune repertoires
2. **Identify immune receptor sequences** most strongly associated with the target immune state

## Notebook Structure
1. **Setup & Data Loading**
2. **Metadata Analysis** - Labels, demographics, class balance
3. **Sequence-Level Analysis** - junction_aa, v_call, j_call distributions
4. **Diversity Metrics** - Uniqueness, entropy, richness
5. **Technical Bias Check** - Batch effects across sequencing runs
6. **HLA Gene Analysis** - Allele prevalence and label associations
7. **Missing Values Assessment**
8. **Train vs Test Comparison**
9. **Dimensionality Reduction Visualization**

---

## 1. Setup & Configuration

In [None]:
# Core imports
import os
import glob
import warnings
from collections import Counter, defaultdict
from typing import Iterator, Tuple, Union, List, Optional, Dict

# Data manipulation
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.gridspec import GridSpec

# Statistical analysis
from scipy import stats
from scipy.stats import entropy

# Progress bars
from tqdm.notebook import tqdm

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# Set display options
pd.set_option('display.max_columns', 50)
pd.set_option('display.max_rows', 100)
pd.set_option('display.width', 200)

# Plotting style. The 'seaborn-v0_8-*' style names only exist in matplotlib >= 3.6;
# older releases ship the same style as 'seaborn-whitegrid', so fall back to it.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')
sns.set_palette("husl")

print("✅ Libraries loaded successfully!")

In [None]:
# =============================================================================
# DATA PATHS CONFIGURATION
# =============================================================================
# Update these paths based on your environment (Kaggle vs Local)

# For Kaggle:
TRAIN_DIR = "/kaggle/input/adaptive-immune-profiling-challenge-2025/train_datasets/train_datasets"
TEST_DIR = "/kaggle/input/adaptive-immune-profiling-challenge-2025/test_datasets/test_datasets"

# For local testing (uncomment and modify if running locally):
# TRAIN_DIR = "/path/to/your/local/train_datasets"
# TEST_DIR = "/path/to/your/local/test_datasets"

# Check if running on Kaggle
IS_KAGGLE = os.path.exists("/kaggle/input")
print(f"Running on Kaggle: {IS_KAGGLE}")
if not IS_KAGGLE:
    print("⚠️ Not running on Kaggle. Update TRAIN_DIR/TEST_DIR at the top of this cell for local execution.")

# Verify paths exist. The train and test directories are checked independently,
# so a missing test directory produces a warning instead of FileNotFoundError.
# The check also runs for local (non-Kaggle) paths.
if os.path.isdir(TRAIN_DIR):
    train_datasets = sorted(d for d in os.listdir(TRAIN_DIR) if d.startswith("train_dataset_"))
    test_dir_found = os.path.isdir(TEST_DIR)
    test_datasets = (
        sorted(d for d in os.listdir(TEST_DIR) if d.startswith("test_dataset_"))
        if test_dir_found else []
    )
    print(f"\n📁 Found {len(train_datasets)} training datasets")
    print(f"📁 Found {len(test_datasets)} test datasets")
    print(f"\n🔹 Training datasets: {train_datasets[:5]}{'...' if len(train_datasets) > 5 else ''}")
    print(f"🔹 Test datasets: {test_datasets[:5]}{'...' if len(test_datasets) > 5 else ''}")
    if not test_dir_found:
        print("⚠️ Test directory not found. Please check TEST_DIR.")
else:
    print("⚠️ Data directory not found. Please check the path.")

## 2. Utility Functions for Data Loading

In [None]:
# =============================================================================
# DATA LOADING UTILITIES
# =============================================================================

def load_metadata(dataset_dir: str, metadata_filename: str = 'metadata.csv') -> Optional[pd.DataFrame]:
    """Read the metadata CSV stored inside one dataset directory.

    Args:
        dataset_dir: Directory of a single dataset.
        metadata_filename: Name of the metadata file inside ``dataset_dir``.

    Returns:
        The parsed metadata, or ``None`` when the file does not exist.
    """
    candidate = os.path.join(dataset_dir, metadata_filename)
    if not os.path.exists(candidate):
        return None
    return pd.read_csv(candidate)


def load_all_metadata(base_dir: str, dataset_prefix: str = "train_dataset_") -> pd.DataFrame:
    """Concatenate the metadata of every dataset directory under ``base_dir``.

    Only entries whose name starts with ``dataset_prefix`` are visited, in
    sorted order. Each loaded frame gets a ``dataset`` column holding the
    directory name. Datasets without a metadata file are skipped.

    Returns:
        The combined metadata with a fresh index, or an empty DataFrame when
        no metadata file was found.
    """
    dataset_names = sorted(name for name in os.listdir(base_dir) if name.startswith(dataset_prefix))

    frames = []
    for name in tqdm(dataset_names, desc="Loading metadata"):
        frame = load_metadata(os.path.join(base_dir, name))
        if frame is None:
            continue
        frame['dataset'] = name
        frames.append(frame)

    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()


def load_repertoire(file_path: str) -> Optional[pd.DataFrame]:
    """Read one tab-separated repertoire file.

    Failures are reported with a printed message instead of being raised, so
    one unreadable file does not abort a batch of loads.

    Returns:
        The repertoire as a DataFrame, or ``None`` if reading failed.
    """
    try:
        frame = pd.read_csv(file_path, sep='\t')
    except Exception as err:  # deliberate best-effort: report and carry on
        print(f"Error loading {file_path}: {err}")
        frame = None
    return frame


def load_sample_repertoires(base_dir: str, dataset_name: str, n_samples: int = 5) -> Dict[str, Optional[pd.DataFrame]]:
    """Load the first ``n_samples`` repertoire files (by sorted name) of a dataset.

    ``glob.glob`` returns paths in filesystem-dependent order. The paths are
    therefore sorted before slicing, so the same files are sampled on every machine.

    Args:
        base_dir: Directory containing the dataset folders.
        dataset_name: Name of the dataset folder to sample from.
        n_samples: Maximum number of ``*.tsv`` files to load.

    Returns:
        Mapping of file basename to its DataFrame. The value is ``None`` for a
        file that ``load_repertoire`` could not read.
    """
    dataset_path = os.path.join(base_dir, dataset_name)
    tsv_files = sorted(glob.glob(os.path.join(dataset_path, "*.tsv")))[:n_samples]
    return {os.path.basename(path): load_repertoire(path) for path in tsv_files}


def get_repertoire_summary(dataset_path: str, sample_size: Optional[int] = None) -> pd.DataFrame:
    """Compute per-repertoire summary statistics for one dataset.

    Files are taken from the ``filename`` column of the dataset's
    ``metadata.csv`` when it exists. Otherwise the ``*.tsv`` files in the
    directory are used, sorted so that a truncated sample is reproducible.
    A truthy ``sample_size`` keeps only the first ``sample_size`` files.

    Returns:
        One row per readable repertoire. Each row holds the sequence count,
        the unique junction/V/J counts (0 when the column is absent), flags
        for ``d_call`` and ``templates``/``duplicate_count``, and the column
        list. When ``junction_aa`` exists it also holds mean/min/max junction
        length.
    """
    metadata_path = os.path.join(dataset_path, 'metadata.csv')

    # Resolve (reported filename, path to read) pairs once, up front.
    if os.path.exists(metadata_path):
        metadata = pd.read_csv(metadata_path)
        targets = [(name, os.path.join(dataset_path, name)) for name in metadata['filename'].tolist()]
    else:
        # glob order is filesystem-dependent; sort so a truncated sample is reproducible.
        targets = [(os.path.basename(path), path)
                   for path in sorted(glob.glob(os.path.join(dataset_path, "*.tsv")))]
    if sample_size:
        targets = targets[:sample_size]

    summaries = []
    for filename, file_path in tqdm(targets, desc=f"Analyzing {os.path.basename(dataset_path)}", leave=False):
        repertoire = load_repertoire(file_path)
        if repertoire is None:
            continue

        columns = repertoire.columns
        summary = {
            'filename': filename,
            'n_sequences': len(repertoire),
            'n_unique_junction_aa': repertoire['junction_aa'].nunique() if 'junction_aa' in columns else 0,
            'n_unique_v_call': repertoire['v_call'].nunique() if 'v_call' in columns else 0,
            'n_unique_j_call': repertoire['j_call'].nunique() if 'j_call' in columns else 0,
            'has_d_call': 'd_call' in columns,
            'has_templates': 'templates' in columns or 'duplicate_count' in columns,
            'columns': list(columns),
        }

        if 'junction_aa' in columns:
            seq_lengths = repertoire['junction_aa'].dropna().str.len()
            summary['mean_seq_length'] = seq_lengths.mean()
            summary['min_seq_length'] = seq_lengths.min()
            summary['max_seq_length'] = seq_lengths.max()

        summaries.append(summary)

    return pd.DataFrame(summaries)


print("✅ Utility functions defined!")

## 3. Metadata Analysis

### 3.1 Overview of All Datasets

In [None]:
# =============================================================================
# LOAD ALL TRAINING METADATA
# =============================================================================

# Load metadata from all training datasets
train_metadata = load_all_metadata(TRAIN_DIR, dataset_prefix="train_dataset_")

# Fail early with a clear message. Every later cell needs these rows, and an
# empty frame would otherwise surface as an opaque KeyError on 'dataset'.
if train_metadata.empty:
    raise RuntimeError(f"No training metadata found under {TRAIN_DIR!r}; check TRAIN_DIR.")

print("=" * 80)
print("TRAINING METADATA OVERVIEW")
print("=" * 80)
print(f"\n📊 Total repertoires (samples): {len(train_metadata):,}")
print(f"📁 Number of datasets: {train_metadata['dataset'].nunique()}")
print(f"\n📋 Available columns in metadata:")
for col in train_metadata.columns:
    non_null = train_metadata[col].notna().sum()
    unique = train_metadata[col].nunique()
    print(f"   • {col}: {non_null:,} non-null ({non_null/len(train_metadata)*100:.1f}%), {unique:,} unique values")

# Display sample of the data
print("\n" + "=" * 80)
print("SAMPLE DATA (first 10 rows)")
print("=" * 80)
display(train_metadata.head(10))

### 3.2 Class Balance Analysis (Label Distribution)

In [None]:
# =============================================================================
# CLASS BALANCE ANALYSIS
# =============================================================================

fig = plt.figure(figsize=(16, 10))
gs = GridSpec(2, 3, figure=fig, hspace=0.3, wspace=0.3)

# --- Overall class distribution ---
ax1 = fig.add_subplot(gs[0, 0])
if 'label_positive' in train_metadata.columns:
    # value_counts() orders classes by frequency. That would attach the fixed
    # 'Positive'/'Negative' labels and colours to the wrong wedge whenever
    # healthy samples are the majority. Build the counts in a fixed
    # [True, False] order instead; an absent class counts as 0.
    n_pos_total = int((train_metadata['label_positive'] == True).sum())
    n_neg_total = int((train_metadata['label_positive'] == False).sum())
    class_counts = pd.Series({True: n_pos_total, False: n_neg_total})
    colors = ['#2ecc71', '#e74c3c']  # Green for True (Positive), Red for False (Negative)
    if n_pos_total + n_neg_total > 0:  # an all-zero pie cannot be drawn
        ax1.pie(class_counts, labels=['Positive (Disease)', 'Negative (Healthy)'],
                autopct='%1.1f%%', colors=colors, explode=(0.02, 0.02),
                shadow=True, startangle=90)
    ax1.set_title('Overall Class Distribution', fontsize=12, fontweight='bold')

    # Guard the ratio: with a missing class, max/min would divide by zero.
    n_major, n_minor = max(n_pos_total, n_neg_total), min(n_pos_total, n_neg_total)
    ratio_text = f"{n_major / n_minor:.2f}:1" if n_minor > 0 else "undefined (a class is absent)"

    print("=" * 80)
    print("OVERALL CLASS BALANCE")
    print("=" * 80)
    print(f"\n📈 Positive (Disease): {n_pos_total:,} ({n_pos_total/len(train_metadata)*100:.1f}%)")
    print(f"📉 Negative (Healthy): {n_neg_total:,} ({n_neg_total/len(train_metadata)*100:.1f}%)")
    print(f"⚖️  Imbalance Ratio: {ratio_text}")

# --- Class distribution per dataset ---
ax2 = fig.add_subplot(gs[0, 1:])
if 'label_positive' in train_metadata.columns:
    # unstack() sorts the label columns as [False, True], matching the
    # red/green colours and the 'Negative', 'Positive' legend below.
    dataset_class_counts = train_metadata.groupby(['dataset', 'label_positive']).size().unstack(fill_value=0)
    dataset_class_counts.plot(kind='bar', stacked=True, ax=ax2, color=['#e74c3c', '#2ecc71'])
    ax2.set_xlabel('Dataset', fontsize=10)
    ax2.set_ylabel('Number of Repertoires', fontsize=10)
    ax2.set_title('Class Distribution by Dataset', fontsize=12, fontweight='bold')
    ax2.legend(['Negative', 'Positive'], loc='upper right')
    ax2.tick_params(axis='x', rotation=45)

# --- Repertoires per dataset ---
ax3 = fig.add_subplot(gs[1, 0])
dataset_sizes = train_metadata['dataset'].value_counts().sort_index()
ax3.bar(range(len(dataset_sizes)), dataset_sizes.values, color='steelblue', alpha=0.7)
ax3.set_xlabel('Dataset Index', fontsize=10)
ax3.set_ylabel('Number of Repertoires', fontsize=10)
ax3.set_title('Repertoires per Dataset', fontsize=12, fontweight='bold')
ax3.axhline(y=dataset_sizes.mean(), color='red', linestyle='--', label=f'Mean: {dataset_sizes.mean():.0f}')
ax3.legend()

# --- Summary statistics table ---
ax4 = fig.add_subplot(gs[1, 1:])
ax4.axis('off')

if 'label_positive' in train_metadata.columns:
    summary_data = []
    for dataset in sorted(train_metadata['dataset'].unique()):
        ds_data = train_metadata[train_metadata['dataset'] == dataset]
        n_pos = ds_data['label_positive'].sum()
        n_neg = len(ds_data) - n_pos
        ratio = n_pos / n_neg if n_neg > 0 else float('inf')
        summary_data.append({
            'Dataset': dataset.replace('train_dataset_', 'DS_'),
            'Total': len(ds_data),
            'Positive': n_pos,
            'Negative': n_neg,
            'Pos %': f"{n_pos/len(ds_data)*100:.1f}%",
            'Ratio': f"{ratio:.2f}"
        })

    summary_df = pd.DataFrame(summary_data)
    table = ax4.table(cellText=summary_df.values,
                      colLabels=summary_df.columns,
                      loc='center',
                      cellLoc='center',
                      colColours=['#4a90d9']*len(summary_df.columns))
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1.2, 1.5)
    ax4.set_title('Class Balance Summary by Dataset', fontsize=12, fontweight='bold', y=1.02)

plt.tight_layout()
plt.show()

# Detailed printout
print("\n" + "=" * 80)
print("DETAILED CLASS BALANCE BY DATASET")
print("=" * 80)
if 'label_positive' in train_metadata.columns:
    display(summary_df)

### 3.3 Demographic Features Analysis (Age, Sex, Race, etc.)

In [None]:
# =============================================================================
# DEMOGRAPHIC FEATURES ANALYSIS
# =============================================================================

def _is_continuous(series: pd.Series) -> bool:
    """Treat numeric, non-boolean columns with more than 10 distinct values as continuous.

    Checking the numeric dtype family, rather than only 'float64'/'int64',
    also covers int32/float32 and pandas nullable numeric columns.
    """
    return (pd.api.types.is_numeric_dtype(series)
            and not pd.api.types.is_bool_dtype(series)
            and series.nunique() > 10)


# Identify demographic columns
demographic_cols = ['age', 'sex', 'race', 'study_group', 'subject_id']
available_demo_cols = [col for col in demographic_cols if col in train_metadata.columns]

print("=" * 80)
print("DEMOGRAPHIC FEATURES ANALYSIS")
print("=" * 80)
print(f"\n📋 Available demographic columns: {available_demo_cols}")

if available_demo_cols:
    has_labels = 'label_positive' in train_metadata.columns

    # Create visualizations for each demographic feature
    n_cols = min(3, len(available_demo_cols))
    n_rows = (len(available_demo_cols) + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows), squeeze=False)
    flat_axes = axes.ravel()  # row-major: same placement as idx // n_cols, idx % n_cols

    for ax, col in zip(flat_axes, available_demo_cols):
        if _is_continuous(train_metadata[col]):
            # Continuous variable - histogram
            if has_labels:
                for label, color in [(True, '#2ecc71'), (False, '#e74c3c')]:
                    subset = train_metadata[train_metadata['label_positive'] == label][col].dropna()
                    ax.hist(subset, bins=30, alpha=0.6, label=f'{"Positive" if label else "Negative"}', color=color)
                ax.legend()
            else:
                ax.hist(train_metadata[col].dropna(), bins=30, alpha=0.7, color='steelblue')
            ax.set_xlabel(col, fontsize=10)
            ax.set_ylabel('Count', fontsize=10)
        else:
            # Categorical variable - bar chart
            if has_labels:
                cross_tab = pd.crosstab(train_metadata[col], train_metadata['label_positive'])
                cross_tab.plot(kind='bar', ax=ax, color=['#e74c3c', '#2ecc71'], alpha=0.8)
                ax.legend(['Negative', 'Positive'])
            else:
                train_metadata[col].value_counts().head(15).plot(kind='bar', ax=ax, color='steelblue', alpha=0.8)
            ax.set_xlabel(col, fontsize=10)
            ax.set_ylabel('Count', fontsize=10)
            ax.tick_params(axis='x', rotation=45)

        ax.set_title(f'{col.upper()} Distribution', fontsize=11, fontweight='bold')

    # Hide empty subplots
    for ax in flat_axes[len(available_demo_cols):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Print detailed statistics
    print("\n" + "-" * 80)
    for col in available_demo_cols:
        print(f"\n📊 {col.upper()} Distribution:")
        if _is_continuous(train_metadata[col]):
            print(train_metadata[col].describe())
        else:
            print(train_metadata[col].value_counts().head(10))
        n_missing = train_metadata[col].isna().sum()
        print(f"   Missing: {n_missing} ({n_missing/len(train_metadata)*100:.1f}%)")
else:
    print("\n⚠️ No demographic columns found in metadata.")

### 3.4 Feature Correlations with Target Label

In [None]:
# =============================================================================
# CORRELATION ANALYSIS WITH TARGET LABEL
# =============================================================================

print("=" * 80)
print("CORRELATION ANALYSIS WITH TARGET LABEL (label_positive)")
print("=" * 80)

if 'label_positive' in train_metadata.columns:
    # The numeric label is kept as a standalone Series, so train_metadata is
    # never mutated. An add-then-drop helper column is left behind if anything
    # in between raises. float (not int) keeps missing labels as NaN instead of raising.
    label_numeric = train_metadata['label_positive'].astype(float)
    is_pos = train_metadata['label_positive'] == True
    is_neg = train_metadata['label_positive'] == False
    # 'label_numeric' stays excluded in case a stale column exists from an older run.
    skip_cols = {'label_positive', 'label_numeric', 'repertoire_id', 'filename', 'dataset'}

    correlation_results = []

    # Analyze each column
    for col in train_metadata.columns:
        if col in skip_cols:
            continue

        col_series = train_metadata[col]
        if col_series.notna().sum() < 10:
            continue

        try:
            if pd.api.types.is_numeric_dtype(col_series) and not pd.api.types.is_bool_dtype(col_series):
                # Continuous: Pearson correlation + t-test.
                # Series.corr aligns on the index and ignores pairs containing NaN.
                correlation = col_series.corr(label_numeric)

                # T-test between groups
                pos_vals = col_series[is_pos].dropna()
                neg_vals = col_series[is_neg].dropna()
                if len(pos_vals) > 1 and len(neg_vals) > 1:
                    t_stat, p_val = stats.ttest_ind(pos_vals, neg_vals)
                else:
                    t_stat, p_val = np.nan, np.nan

                correlation_results.append({
                    'Feature': col,
                    'Type': 'Continuous',
                    'Correlation': correlation,
                    'Test': 't-test',
                    'Statistic': t_stat,
                    'P-value': p_val,
                    'Significant': p_val < 0.05 if not np.isnan(p_val) else False
                })
            else:
                # Categorical: Chi-squared test
                contingency = pd.crosstab(col_series, train_metadata['label_positive'])
                if contingency.shape[0] > 1 and contingency.shape[1] > 1:
                    chi2, p_val, dof, expected = stats.chi2_contingency(contingency)

                    # Cramér's V for effect size
                    n = contingency.sum().sum()
                    min_dim = min(contingency.shape) - 1
                    cramers_v = np.sqrt(chi2 / (n * min_dim)) if min_dim > 0 else 0

                    correlation_results.append({
                        'Feature': col,
                        'Type': 'Categorical',
                        'Correlation': cramers_v,
                        'Test': 'Chi-squared',
                        'Statistic': chi2,
                        'P-value': p_val,
                        'Significant': p_val < 0.05
                    })
        except Exception as e:
            print(f"   Could not analyze {col}: {e}")
            continue

    if correlation_results:
        corr_df = pd.DataFrame(correlation_results).sort_values('P-value')

        # Visualization
        fig, axes = plt.subplots(1, 2, figsize=(14, 6))

        # Correlation/Effect size bar chart
        ax1 = axes[0]
        colors = ['#2ecc71' if sig else '#95a5a1' for sig in corr_df['Significant']]
        bars = ax1.barh(corr_df['Feature'], corr_df['Correlation'].abs(), color=colors, alpha=0.8)
        ax1.set_xlabel('Correlation / Effect Size (Cramér\'s V)', fontsize=10)
        ax1.set_title('Feature Association with Label', fontsize=12, fontweight='bold')
        ax1.axvline(x=0.1, color='red', linestyle='--', alpha=0.5, label='Small effect')
        ax1.axvline(x=0.3, color='orange', linestyle='--', alpha=0.5, label='Medium effect')
        ax1.legend(fontsize=8)

        # P-value bar chart (log scale); p == 0 is clipped so the log stays finite
        ax2 = axes[1]
        log_pvals = -np.log10(corr_df['P-value'].replace(0, 1e-300))
        colors = ['#e74c3c' if sig else '#95a5a1' for sig in corr_df['Significant']]
        ax2.barh(corr_df['Feature'], log_pvals, color=colors, alpha=0.8)
        ax2.axvline(x=-np.log10(0.05), color='red', linestyle='--', label='p=0.05')
        ax2.axvline(x=-np.log10(0.01), color='orange', linestyle='--', label='p=0.01')
        ax2.set_xlabel('-log10(P-value)', fontsize=10)
        ax2.set_title('Statistical Significance', fontsize=12, fontweight='bold')
        ax2.legend(fontsize=8)

        plt.tight_layout()
        plt.show()

        # Display results table
        print("\n📊 Correlation Analysis Results (sorted by significance):")
        styled = corr_df.style.format({
            'Correlation': '{:.4f}',
            'Statistic': '{:.2f}',
            'P-value': '{:.2e}'
        })
        # Styler.applymap is deprecated since pandas 2.1 (renamed Styler.map);
        # use map when available and fall back for older pandas.
        elementwise = getattr(styled, 'map', None) or styled.applymap
        display(elementwise(lambda x: 'background-color: #d4edda' if x == True else '',
                            subset=['Significant']))
else:
    print("⚠️ label_positive column not found in metadata.")

## 4. Sequence-Level Analysis

### 4.1 Repertoire Structure Overview

In [None]:
# =============================================================================
# REPERTOIRE STRUCTURE OVERVIEW
# =============================================================================
# Analyze repertoire files from a sample of datasets to understand structure

print("=" * 80)
print("REPERTOIRE STRUCTURE OVERVIEW")
print("=" * 80)

# Get list of training datasets
train_datasets = sorted([d for d in os.listdir(TRAIN_DIR) if d.startswith("train_dataset_")])

# Analyze structure from first dataset
sample_dataset = train_datasets[0]
sample_dataset_path = os.path.join(TRAIN_DIR, sample_dataset)

# Load a sample repertoire to examine structure. glob order is
# filesystem-dependent, so sort to make "the first file" reproducible.
sample_files = sorted(glob.glob(os.path.join(sample_dataset_path, "*.tsv")))[:1]
if sample_files:
    sample_rep = pd.read_csv(sample_files[0], sep='\t')

    print(f"\n📁 Sample dataset: {sample_dataset}")
    print(f"📄 Sample file: {os.path.basename(sample_files[0])}")
    print(f"\n📋 Columns in repertoire files:")
    for col in sample_rep.columns:
        dtype = sample_rep[col].dtype
        n_unique = sample_rep[col].nunique()
        n_null = sample_rep[col].isna().sum()
        print(f"   • {col}: {dtype} | {n_unique:,} unique | {n_null} null")

    print(f"\n📊 Sample repertoire shape: {sample_rep.shape}")
    print(f"   • Rows (sequences): {len(sample_rep):,}")
    print(f"   • Columns: {len(sample_rep.columns)}")

    print("\n" + "=" * 80)
    print("SAMPLE REPERTOIRE DATA (first 10 rows)")
    print("=" * 80)
    display(sample_rep.head(10))
else:
    print(f"\n⚠️ No .tsv repertoire files found in {sample_dataset_path}")

### 4.2 Repertoire Size Distribution

In [None]:
# =============================================================================
# REPERTOIRE SIZE DISTRIBUTION ANALYSIS
# =============================================================================

def get_repertoire_sizes(base_dir: str, datasets: list, sample_per_dataset: Optional[int] = None) -> pd.DataFrame:
    """Count the sequences (data rows) of each repertoire file across datasets.

    Rows are counted by streaming each file line by line, not by parsing it
    with pandas. One header line is assumed per file.

    Args:
        base_dir: Directory containing the dataset folders.
        datasets: Dataset folder names to scan.
        sample_per_dataset: If truthy, only the first N files of each dataset
            are counted (metadata order, or sorted filename order without metadata).

    Returns:
        DataFrame with columns dataset, filename, n_sequences and
        label_positive. label_positive is None when the dataset has no
        metadata or no label column.
    """
    sizes = []

    for dataset_name in tqdm(datasets, desc="Analyzing repertoire sizes"):
        dataset_path = os.path.join(base_dir, dataset_name)
        metadata_path = os.path.join(dataset_path, 'metadata.csv')

        if os.path.exists(metadata_path):
            metadata = pd.read_csv(metadata_path)
            files_to_check = metadata['filename'].tolist()
            # Metadata without a label column (e.g. test sets) must not raise KeyError.
            labels = (dict(zip(metadata['filename'], metadata['label_positive']))
                      if 'label_positive' in metadata.columns else {})
        else:
            # glob order is filesystem-dependent; sort so a truncated sample is reproducible.
            files_to_check = sorted(os.path.basename(f) for f in glob.glob(os.path.join(dataset_path, "*.tsv")))
            labels = {}

        if sample_per_dataset:
            files_to_check = files_to_check[:sample_per_dataset]

        for filename in files_to_check:
            file_path = os.path.join(dataset_path, filename)
            if not os.path.exists(file_path):
                continue
            try:
                # Count lines without loading the file. Binary mode skips text
                # decoding: faster, and immune to UnicodeDecodeError.
                with open(file_path, 'rb') as f:
                    n_lines = sum(1 for _ in f)
            except OSError as e:
                print(f"Could not read {file_path}: {e}")
                continue

            sizes.append({
                'dataset': dataset_name,
                'filename': filename,
                'n_sequences': max(n_lines - 1, 0),  # subtract header; an empty file counts 0
                'label_positive': labels.get(filename, None)
            })

    # Explicit columns so an empty result still has the expected schema.
    return pd.DataFrame(sizes, columns=['dataset', 'filename', 'n_sequences', 'label_positive'])

# Get repertoire sizes (sample for speed)
print("=" * 80)
print("REPERTOIRE SIZE DISTRIBUTION")
print("=" * 80)

repertoire_sizes = get_repertoire_sizes(TRAIN_DIR, train_datasets, sample_per_dataset=50)

if repertoire_sizes.empty:
    print("\n⚠️ No repertoire files found; skipping size plots.")
else:
    print(f"\n📊 Analyzed {len(repertoire_sizes):,} repertoires across {repertoire_sizes['dataset'].nunique()} datasets")

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Overall distribution
    ax1 = axes[0, 0]
    ax1.hist(repertoire_sizes['n_sequences'], bins=50, color='steelblue', alpha=0.7, edgecolor='black')
    ax1.axvline(repertoire_sizes['n_sequences'].median(), color='red', linestyle='--',
                label=f'Median: {repertoire_sizes["n_sequences"].median():,.0f}')
    ax1.axvline(repertoire_sizes['n_sequences'].mean(), color='orange', linestyle='--',
                label=f'Mean: {repertoire_sizes["n_sequences"].mean():,.0f}')
    ax1.set_xlabel('Number of Sequences', fontsize=10)
    ax1.set_ylabel('Count', fontsize=10)
    ax1.set_title('Repertoire Size Distribution (All Datasets)', fontsize=12, fontweight='bold')
    ax1.legend()

    # Log scale distribution (+1 keeps empty repertoires finite)
    ax2 = axes[0, 1]
    ax2.hist(np.log10(repertoire_sizes['n_sequences'] + 1), bins=50, color='steelblue', alpha=0.7, edgecolor='black')
    ax2.set_xlabel('log10(Number of Sequences)', fontsize=10)
    ax2.set_ylabel('Count', fontsize=10)
    ax2.set_title('Repertoire Size Distribution (Log Scale)', fontsize=12, fontweight='bold')

    # Distribution by label
    ax3 = axes[1, 0]
    if repertoire_sizes['label_positive'].notna().any():
        for label, color in [(True, '#2ecc71'), (False, '#e74c3c')]:
            subset = repertoire_sizes[repertoire_sizes['label_positive'] == label]['n_sequences']
            if len(subset) > 0:
                ax3.hist(subset, bins=30, alpha=0.6, label=f'{"Positive" if label else "Negative"}', color=color)
        ax3.legend()
    ax3.set_xlabel('Number of Sequences', fontsize=10)
    ax3.set_ylabel('Count', fontsize=10)
    ax3.set_title('Repertoire Size by Label', fontsize=12, fontweight='bold')

    # Box plot by dataset, ordered by median size
    ax4 = axes[1, 1]
    dataset_order = repertoire_sizes.groupby('dataset')['n_sequences'].median().sort_values().index
    sns.boxplot(data=repertoire_sizes, x='dataset', y='n_sequences', ax=ax4, order=dataset_order)
    ax4.set_xlabel('Dataset', fontsize=10)
    ax4.set_ylabel('Number of Sequences', fontsize=10)
    ax4.set_title('Repertoire Size by Dataset', fontsize=12, fontweight='bold')
    ax4.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

    # Summary statistics
    print("\n📈 Repertoire Size Statistics:")
    print(repertoire_sizes['n_sequences'].describe())

    if repertoire_sizes['label_positive'].notna().any():
        print("\n📊 Size Statistics by Label:")
        display(repertoire_sizes.groupby('label_positive')['n_sequences'].describe())

### 4.3 Junction AA Sequence Analysis

In [None]:
# =============================================================================
# JUNCTION_AA SEQUENCE ANALYSIS
# =============================================================================

def analyze_sequences_from_dataset(dataset_path: str, n_files: int = 10) -> Dict:
    """Collect junction_aa lengths, sequences and character counts for a dataset sample.

    Files come from ``metadata.csv`` (first ``n_files`` rows) when present.
    Otherwise the first ``n_files`` ``*.tsv`` files in sorted order are used.
    Listed files that do not exist, or that have no ``junction_aa`` column,
    are skipped.

    Args:
        dataset_path: Directory of a single dataset.
        n_files: Maximum number of repertoire files to read.

    Returns:
        dict with:
          - 'seq_lengths': DataFrame with columns 'length' and 'label', one row
            per non-null junction_aa. label is None when unknown. Both columns
            are present even when no sequences were found.
          - 'sequences': list of all non-null junction_aa values.
          - 'amino_acid_counts': Counter of characters across all sequences.
    """
    metadata_path = os.path.join(dataset_path, 'metadata.csv')

    if os.path.exists(metadata_path):
        metadata = pd.read_csv(metadata_path)
        files = metadata['filename'].tolist()[:n_files]
        # Metadata without a label column must not raise KeyError.
        labels = (dict(zip(metadata['filename'], metadata['label_positive']))
                  if 'label_positive' in metadata.columns else {})
    else:
        # glob order is filesystem-dependent; sort so the sample is reproducible.
        files = [os.path.basename(f) for f in sorted(glob.glob(os.path.join(dataset_path, "*.tsv")))[:n_files]]
        labels = {}

    length_frames = []
    all_sequences = []
    amino_acid_counts = Counter()

    for filename in files:
        file_path = os.path.join(dataset_path, filename)
        if not os.path.exists(file_path):
            continue

        # Only junction_aa is needed. A callable usecols skips the other
        # columns and does not raise when junction_aa is absent.
        rep = pd.read_csv(file_path, sep='\t', usecols=lambda c: c == 'junction_aa')
        if 'junction_aa' not in rep.columns:
            continue

        seqs = rep['junction_aa'].dropna()
        if seqs.empty:
            continue
        all_sequences.extend(seqs.tolist())

        # Vectorised: one small frame per file instead of one dict per sequence.
        length_frames.append(pd.DataFrame({
            'length': seqs.str.len().to_numpy(),
            'label': labels.get(filename, None),
        }))

        # One join + update is far cheaper than an update() per sequence.
        amino_acid_counts.update(''.join(seqs))

    if length_frames:
        seq_lengths = pd.concat(length_frames, ignore_index=True)
    else:
        seq_lengths = pd.DataFrame(columns=['length', 'label'])

    return {
        'seq_lengths': seq_lengths,
        'sequences': all_sequences,
        'amino_acid_counts': amino_acid_counts
    }

print("=" * 80)
print("JUNCTION_AA SEQUENCE ANALYSIS")
print("=" * 80)

# Analyze sequences from first few datasets
sequence_analysis = {}
for dataset_name in train_datasets[:3]:
    dataset_path = os.path.join(TRAIN_DIR, dataset_name)
    sequence_analysis[dataset_name] = analyze_sequences_from_dataset(dataset_path, n_files=20)
    print(f"✅ Analyzed {dataset_name}")

# Combine results
all_lengths = pd.concat([sa['seq_lengths'] for sa in sequence_analysis.values()], ignore_index=True)
all_aa_counts = Counter()
for sa in sequence_analysis.values():
    all_aa_counts.update(sa['amino_acid_counts'])

# Visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Sequence length distribution
ax1 = axes[0, 0]
ax1.hist(all_lengths['length'].dropna(), bins=50, color='steelblue', alpha=0.7, edgecolor='black')
ax1.axvline(all_lengths['length'].median(), color='red', linestyle='--',
            label=f'Median: {all_lengths["length"].median():.0f}')
ax1.set_xlabel('Junction AA Length', fontsize=10)
ax1.set_ylabel('Count', fontsize=10)
ax1.set_title('Junction AA Sequence Length Distribution', fontsize=12, fontweight='bold')
ax1.legend()

# Length by label
ax2 = axes[0, 1]
if all_lengths['label'].notna().any():
    for label, color in [(True, '#2ecc71'), (False, '#e74c3c')]:
        subset = all_lengths[all_lengths['label'] == label]['length']
        if len(subset) > 0:
            ax2.hist(subset, bins=30, alpha=0.6, label=f'{"Positive" if label else "Negative"}', color=color)
    ax2.legend()
ax2.set_xlabel('Junction AA Length', fontsize=10)
ax2.set_ylabel('Count', fontsize=10)
ax2.set_title('Sequence Length by Label', fontsize=12, fontweight='bold')

# Amino acid frequency
ax3 = axes[1, 0]
aa_df = pd.DataFrame.from_dict(all_aa_counts, orient='index', columns=['count']).sort_values('count', ascending=True)
# Filter to standard amino acids (ascending order so barh draws the largest on top)
standard_aa = list('ACDEFGHIKLMNPQRSTVWY')
aa_df_filtered = aa_df[aa_df.index.isin(standard_aa)].sort_values('count', ascending=True)
ax3.barh(aa_df_filtered.index, aa_df_filtered['count'], color='steelblue', alpha=0.8)
ax3.set_xlabel('Count', fontsize=10)
ax3.set_ylabel('Amino Acid', fontsize=10)
ax3.set_title('Amino Acid Frequency in Junction AA', fontsize=12, fontweight='bold')

# Length statistics by dataset
ax4 = axes[1, 1]
length_stats = []
for dataset_name, sa in sequence_analysis.items():
    lengths = sa['seq_lengths']['length']
    length_stats.append({
        'Dataset': dataset_name.replace('train_dataset_', 'DS_'),
        'Mean': lengths.mean(),
        'Median': lengths.median(),
        'Std': lengths.std()
    })
length_df = pd.DataFrame(length_stats)
x = np.arange(len(length_df))
width = 0.35
ax4.bar(x - width/2, length_df['Mean'], width, label='Mean', color='steelblue', alpha=0.8)
ax4.bar(x + width/2, length_df['Median'], width, label='Median', color='coral', alpha=0.8)
ax4.errorbar(x - width/2, length_df['Mean'], yerr=length_df['Std'], fmt='none', color='black', capsize=3)
ax4.set_xticks(x)
ax4.set_xticklabels(length_df['Dataset'], rotation=45)
ax4.set_ylabel('Length', fontsize=10)
ax4.set_title('Sequence Length Statistics by Dataset', fontsize=12, fontweight='bold')
ax4.legend()

plt.tight_layout()
plt.show()

# Print statistics
print("\n📊 Junction AA Length Statistics:")
print(all_lengths['length'].describe())

print("\n📊 Top 10 Most Common Amino Acids:")
# aa_df_filtered is sorted ascending for the bar chart; re-sort so this list
# actually reads from most to least common.
print(aa_df_filtered.sort_values('count', ascending=False).head(10))

### 4.4 V/J/D Gene Call Distribution

In [None]:
# =============================================================================
# V/J/D GENE CALL DISTRIBUTION
# =============================================================================

def analyze_gene_calls(dataset_path: str, n_files: int = 20) -> Dict:
    """Tally v_call, j_call and d_call values across repertoires of one dataset.

    Reads up to ``n_files`` repertoire TSVs. The files are taken from
    ``metadata.csv`` when it exists, or otherwise from the first ``n_files``
    ``*.tsv`` files in sorted order. Non-null gene calls are counted.
    Files listed in the metadata but absent on disk are skipped.

    Args:
        dataset_path: Directory holding the repertoire ``.tsv`` files and,
            optionally, ``metadata.csv`` with ``filename`` and
            ``label_positive`` columns.
        n_files: Maximum number of repertoires to read.

    Returns:
        Dict with keys:
            ``v_calls`` / ``j_calls`` / ``d_calls``: ``Counter`` of call -> count.
            ``has_d_call``: True if any file read has a ``d_call`` column.
            ``v_calls_by_label`` / ``j_calls_by_label``: ``{True: Counter,
            False: Counter}``. Only repertoires whose label is bool-like
            (True/False, 1/0) contribute. Missing (NaN) or unrecognised labels
            are skipped.
    """
    gene_cols = {'v_call', 'j_call', 'd_call'}
    metadata_path = os.path.join(dataset_path, 'metadata.csv')

    if os.path.exists(metadata_path):
        metadata = pd.read_csv(metadata_path)
        files = metadata['filename'].tolist()[:n_files]
        labels = dict(zip(metadata['filename'], metadata['label_positive']))
    else:
        # sorted(): glob order is filesystem-dependent, so the sample would not be reproducible.
        files = [os.path.basename(f) for f in sorted(glob.glob(os.path.join(dataset_path, "*.tsv")))[:n_files]]
        labels = {}

    v_calls = Counter()
    j_calls = Counter()
    d_calls = Counter()

    v_calls_by_label = {True: Counter(), False: Counter()}
    j_calls_by_label = {True: Counter(), False: Counter()}

    has_d_call = False

    for filename in files:
        file_path = os.path.join(dataset_path, filename)
        if not os.path.exists(file_path):
            continue

        label = labels.get(filename)
        # read_csv yields NaN for a missing label. Indexing the per-label dicts
        # with NaN (or any other non-bool-like value) would raise KeyError.
        # Bool-like values (True/False, 1/0, numpy bools) hash and compare equal
        # to the dict keys, so they are accepted as-is.
        if label is None or pd.isna(label) or label not in v_calls_by_label:
            label = None

        # Only the gene-call columns are used, so skip parsing the others.
        rep = pd.read_csv(file_path, sep='\t', usecols=lambda col: col in gene_cols)

        if 'v_call' in rep.columns:
            v_vals = rep['v_call'].dropna()
            v_calls.update(v_vals)
            if label is not None:
                v_calls_by_label[label].update(v_vals)

        if 'j_call' in rep.columns:
            j_vals = rep['j_call'].dropna()
            j_calls.update(j_vals)
            if label is not None:
                j_calls_by_label[label].update(j_vals)

        if 'd_call' in rep.columns:
            has_d_call = True
            d_calls.update(rep['d_call'].dropna())

    return {
        'v_calls': v_calls,
        'j_calls': j_calls,
        'd_calls': d_calls,
        'has_d_call': has_d_call,
        'v_calls_by_label': v_calls_by_label,
        'j_calls_by_label': j_calls_by_label
    }

print("=" * 80)
print("V/J/D GENE CALL DISTRIBUTION")
print("=" * 80)

# Analyze gene calls from datasets
# Only the first 3 training datasets, with at most 30 repertoires each, are
# sampled to keep this cell fast. The counts below are sample totals, not
# dataset-wide totals.
gene_analysis = {}
for dataset_name in train_datasets[:3]:
    dataset_path = os.path.join(TRAIN_DIR, dataset_name)
    gene_analysis[dataset_name] = analyze_gene_calls(dataset_path, n_files=30)
    print(f"‚úÖ Analyzed {dataset_name}: d_call present = {gene_analysis[dataset_name]['has_d_call']}")

# Combine results
# Pool the per-dataset Counters into one Counter per gene segment.
combined_v = Counter()
combined_j = Counter()
combined_d = Counter()
for ga in gene_analysis.values():
    combined_v.update(ga['v_calls'])
    combined_j.update(ga['j_calls'])
    combined_d.update(ga['d_calls'])

# Visualization
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Top V genes
# nlargest + invert_yaxis: the most frequent call is drawn at the top.
ax1 = axes[0, 0]
top_v = pd.DataFrame.from_dict(combined_v, orient='index', columns=['count']).nlargest(20, 'count')
ax1.barh(top_v.index, top_v['count'], color='steelblue', alpha=0.8)
ax1.set_xlabel('Count', fontsize=10)
ax1.set_ylabel('V Gene', fontsize=10)
ax1.set_title('Top 20 V Gene Calls', fontsize=12, fontweight='bold')
ax1.invert_yaxis()

# Top J genes
ax2 = axes[0, 1]
top_j = pd.DataFrame.from_dict(combined_j, orient='index', columns=['count']).nlargest(20, 'count')
ax2.barh(top_j.index, top_j['count'], color='coral', alpha=0.8)
ax2.set_xlabel('Count', fontsize=10)
ax2.set_ylabel('J Gene', fontsize=10)
ax2.set_title('Top 20 J Gene Calls', fontsize=12, fontweight='bold')
ax2.invert_yaxis()

# V gene diversity by dataset
# "Unique V genes" is the number of distinct raw v_call strings.
# NOTE(review): if calls carry allele suffixes or ambiguous multi-gene values,
# each distinct string counts separately — confirm the call format.
ax3 = axes[1, 0]
v_diversity = []
for dataset_name, ga in gene_analysis.items():
    n_unique_v = len(ga['v_calls'])
    total_v = sum(ga['v_calls'].values())
    v_diversity.append({
        'Dataset': dataset_name.replace('train_dataset_', 'DS_'),
        'Unique V genes': n_unique_v,
        'Total V calls': total_v
    })
v_div_df = pd.DataFrame(v_diversity)
x = np.arange(len(v_div_df))
ax3.bar(x, v_div_df['Unique V genes'], color='steelblue', alpha=0.8)
ax3.set_xticks(x)
ax3.set_xticklabels(v_div_df['Dataset'], rotation=45)
ax3.set_ylabel('Number of Unique V Genes', fontsize=10)
ax3.set_title('V Gene Diversity by Dataset', fontsize=12, fontweight='bold')

# Top D genes (if available)
# Falls back to a text placeholder when no non-null d_call was found.
ax4 = axes[1, 1]
if combined_d:
    top_d = pd.DataFrame.from_dict(combined_d, orient='index', columns=['count']).nlargest(15, 'count')
    ax4.barh(top_d.index, top_d['count'], color='forestgreen', alpha=0.8)
    ax4.set_xlabel('Count', fontsize=10)
    ax4.set_ylabel('D Gene', fontsize=10)
    ax4.set_title('Top 15 D Gene Calls', fontsize=12, fontweight='bold')
    ax4.invert_yaxis()
else:
    ax4.text(0.5, 0.5, 'D gene calls not available\nin analyzed datasets', 
             ha='center', va='center', fontsize=12, transform=ax4.transAxes)
    ax4.axis('off')

plt.tight_layout()
plt.show()

# Print statistics
print("\nüìä Gene Call Statistics:")
print(f"   ‚Ä¢ Unique V genes: {len(combined_v):,}")
print(f"   ‚Ä¢ Unique J genes: {len(combined_j):,}")
print(f"   ‚Ä¢ Unique D genes: {len(combined_d):,}")

print("\nüìä Top 10 V Genes:")
for gene, count in combined_v.most_common(10):
    print(f"   {gene}: {count:,}")

print("\nüìä Top 10 J Genes:")
for gene, count in combined_j.most_common(10):
    print(f"   {gene}: {count:,}")

## 5. Diversity Metrics & Sequence Sharing

### 5.1 Repertoire Diversity Metrics

In [None]:
# =============================================================================
# REPERTOIRE DIVERSITY METRICS
# =============================================================================

def calculate_diversity_metrics(dataset_path: str, n_files: int = 30) -> pd.DataFrame:
    """Compute per-repertoire diversity metrics for one dataset.

    Reads up to ``n_files`` repertoire TSVs. The files are taken from
    ``metadata.csv`` when it exists, or otherwise from the first ``n_files``
    ``*.tsv`` files in sorted order. Files listed in the metadata but absent
    on disk are skipped.

    Args:
        dataset_path: Directory holding the repertoire ``.tsv`` files and,
            optionally, ``metadata.csv`` with ``filename`` / ``label_positive``.
        n_files: Maximum number of repertoires to read.

    Returns:
        One row per repertoire read. Every row has ``filename``,
        ``label_positive`` (``None`` without metadata) and ``n_sequences``.
        The remaining columns depend on which columns the TSV contains:

        * ``junction_aa``:
            * ``n_unique_junction``
            * ``junction_richness``: unique / total rows
            * ``junction_entropy``: Shannon entropy, natural log
            * ``simpson_diversity``: 1 - sum p_i^2
            * ``clonality``: 1 - H / ln(S), i.e. 1 - Pielou evenness;
              1.0 for a single clonotype

          The entropy-based metrics are NaN when there is no non-null junction.
        * ``v_call`` / ``j_call``: ``n_unique_v`` / ``v_entropy`` and
          ``n_unique_j`` / ``j_entropy``. Entropy is NaN when there are no
          non-null calls.
        * ``templates`` or ``duplicate_count``: ``total_templates``,
          ``mean_templates`` and ``max_templates``. ``has_templates`` records
          whether either column exists.
    """
    metadata_path = os.path.join(dataset_path, 'metadata.csv')

    if os.path.exists(metadata_path):
        metadata = pd.read_csv(metadata_path)
        files = metadata['filename'].tolist()[:n_files]
        labels = dict(zip(metadata['filename'], metadata['label_positive']))
    else:
        # sorted(): glob order is filesystem-dependent, so the sample would not be reproducible.
        files = [os.path.basename(f) for f in sorted(glob.glob(os.path.join(dataset_path, "*.tsv")))[:n_files]]
        labels = {}

    def _shannon(counts: pd.Series) -> float:
        """Shannon entropy (nats) of a count vector; NaN for an empty one."""
        if len(counts) == 0:
            return np.nan
        return entropy(counts / counts.sum())

    metrics = []

    for filename in tqdm(files, desc="Calculating diversity", leave=False):
        file_path = os.path.join(dataset_path, filename)
        if not os.path.exists(file_path):
            continue

        rep = pd.read_csv(file_path, sep='\t')
        n_rows = len(rep)

        metric = {
            'filename': filename,
            'label_positive': labels.get(filename, None),
            'n_sequences': n_rows,
        }

        # Junction AA diversity (value_counts() drops NaN junctions)
        if 'junction_aa' in rep.columns:
            junction_counts = rep['junction_aa'].value_counts()
            n_unique = len(junction_counts)
            metric['n_unique_junction'] = n_unique
            metric['junction_richness'] = n_unique / n_rows if n_rows > 0 else 0

            if n_unique == 0:
                # Empty distribution: Simpson would evaluate to 1 (max diversity)
                # while clonality evaluated to 1 (max clonality) — neither is
                # meaningful, so report NaN instead.
                metric['junction_entropy'] = np.nan
                metric['simpson_diversity'] = np.nan
                metric['clonality'] = np.nan
            else:
                probs = junction_counts / junction_counts.sum()
                shannon = entropy(probs)
                metric['junction_entropy'] = shannon
                metric['simpson_diversity'] = 1 - float((probs ** 2).sum())
                # ln(S) == 0 for S == 1; a single clonotype is fully clonal.
                metric['clonality'] = 1 - shannon / np.log(n_unique) if n_unique > 1 else 1.0

        # V gene diversity
        if 'v_call' in rep.columns:
            v_counts = rep['v_call'].value_counts()
            metric['n_unique_v'] = len(v_counts)
            metric['v_entropy'] = _shannon(v_counts)

        # J gene diversity
        if 'j_call' in rep.columns:
            j_counts = rep['j_call'].value_counts()
            metric['n_unique_j'] = len(j_counts)
            metric['j_entropy'] = _shannon(j_counts)

        # Templates/duplicate counts if available ('templates' takes precedence)
        template_col = 'templates' if 'templates' in rep.columns else ('duplicate_count' if 'duplicate_count' in rep.columns else None)
        if template_col:
            metric['has_templates'] = True
            metric['total_templates'] = rep[template_col].sum()
            metric['mean_templates'] = rep[template_col].mean()
            metric['max_templates'] = rep[template_col].max()
        else:
            metric['has_templates'] = False

        metrics.append(metric)

    return pd.DataFrame(metrics)

print("=" * 80)
print("REPERTOIRE DIVERSITY METRICS")
print("=" * 80)

# Calculate diversity for multiple datasets
# Sample: the first 5 training datasets, with up to 30 repertoires each. The
# summaries below describe this sample, not the full training set.
diversity_results = []
for dataset_name in train_datasets[:5]:
    dataset_path = os.path.join(TRAIN_DIR, dataset_name)
    div_df = calculate_diversity_metrics(dataset_path, n_files=30)
    div_df['dataset'] = dataset_name
    diversity_results.append(div_df)
    print(f"‚úÖ Calculated diversity for {dataset_name}")

diversity_df = pd.concat(diversity_results, ignore_index=True)

# Visualization
# Panels split by label only when at least one repertoire has a non-null
# label_positive; otherwise they fall back to an unsplit plot (or stay empty).
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

# Shannon entropy distribution
ax1 = axes[0, 0]
if 'junction_entropy' in diversity_df.columns:
    if diversity_df['label_positive'].notna().any():
        for label, color in [(True, '#2ecc71'), (False, '#e74c3c')]:
            subset = diversity_df[diversity_df['label_positive'] == label]['junction_entropy'].dropna()
            if len(subset) > 0:
                ax1.hist(subset, bins=20, alpha=0.6, label=f'{"Positive" if label else "Negative"}', color=color)
        ax1.legend()
    else:
        ax1.hist(diversity_df['junction_entropy'].dropna(), bins=20, color='steelblue', alpha=0.7)
ax1.set_xlabel('Shannon Entropy', fontsize=10)
ax1.set_ylabel('Count', fontsize=10)
ax1.set_title('Junction AA Entropy Distribution', fontsize=12, fontweight='bold')

# Richness by label
# NOTE(review): the palette keys assume label_positive holds booleans — confirm
# if labels are encoded differently (e.g. 0/1 or strings).
ax2 = axes[0, 1]
if 'junction_richness' in diversity_df.columns:
    if diversity_df['label_positive'].notna().any():
        sns.boxplot(data=diversity_df, x='label_positive', y='junction_richness', ax=ax2, 
                   palette={True: '#2ecc71', False: '#e74c3c'})
    else:
        ax2.hist(diversity_df['junction_richness'].dropna(), bins=20, color='steelblue', alpha=0.7)
ax2.set_xlabel('Label', fontsize=10)
ax2.set_ylabel('Richness (Unique/Total)', fontsize=10)
ax2.set_title('Junction Richness by Label', fontsize=12, fontweight='bold')

# Clonality distribution
ax3 = axes[0, 2]
if 'clonality' in diversity_df.columns:
    if diversity_df['label_positive'].notna().any():
        for label, color in [(True, '#2ecc71'), (False, '#e74c3c')]:
            subset = diversity_df[diversity_df['label_positive'] == label]['clonality'].dropna()
            if len(subset) > 0:
                ax3.hist(subset, bins=20, alpha=0.6, label=f'{"Positive" if label else "Negative"}', color=color)
        ax3.legend()
    else:
        ax3.hist(diversity_df['clonality'].dropna(), bins=20, color='steelblue', alpha=0.7)
ax3.set_xlabel('Clonality', fontsize=10)
ax3.set_ylabel('Count', fontsize=10)
ax3.set_title('Clonality Distribution', fontsize=12, fontweight='bold')

# Simpson diversity by dataset
# Differences between datasets here can hint at dataset-level (batch) effects.
ax4 = axes[1, 0]
if 'simpson_diversity' in diversity_df.columns:
    sns.boxplot(data=diversity_df, x='dataset', y='simpson_diversity', ax=ax4)
    ax4.tick_params(axis='x', rotation=45)
ax4.set_xlabel('Dataset', fontsize=10)
ax4.set_ylabel('Simpson Diversity', fontsize=10)
ax4.set_title('Simpson Diversity by Dataset', fontsize=12, fontweight='bold')

# V gene entropy by label
# No unlabelled fallback: this panel stays empty when no labels are present.
ax5 = axes[1, 1]
if 'v_entropy' in diversity_df.columns:
    if diversity_df['label_positive'].notna().any():
        sns.boxplot(data=diversity_df, x='label_positive', y='v_entropy', ax=ax5,
                   palette={True: '#2ecc71', False: '#e74c3c'})
ax5.set_xlabel('Label', fontsize=10)
ax5.set_ylabel('V Gene Entropy', fontsize=10)
ax5.set_title('V Gene Entropy by Label', fontsize=12, fontweight='bold')

# Correlation: entropy vs richness
ax6 = axes[1, 2]
if 'junction_entropy' in diversity_df.columns and 'junction_richness' in diversity_df.columns:
    if diversity_df['label_positive'].notna().any():
        for label, color in [(True, '#2ecc71'), (False, '#e74c3c')]:
            subset = diversity_df[diversity_df['label_positive'] == label]
            ax6.scatter(subset['junction_richness'], subset['junction_entropy'], alpha=0.5, 
                       label=f'{"Positive" if label else "Negative"}', c=color, s=30)
        ax6.legend()
    else:
        ax6.scatter(diversity_df['junction_richness'], diversity_df['junction_entropy'], alpha=0.5, c='steelblue')
ax6.set_xlabel('Richness', fontsize=10)
ax6.set_ylabel('Shannon Entropy', fontsize=10)
ax6.set_title('Entropy vs Richness', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

# Print summary statistics
# Only metrics actually present as columns are summarised.
print("\nüìä Diversity Metrics Summary:")
diversity_cols = ['junction_entropy', 'junction_richness', 'simpson_diversity', 'clonality', 'v_entropy', 'j_entropy']
available_cols = [c for c in diversity_cols if c in diversity_df.columns]
display(diversity_df[available_cols].describe())

# Compare by label
# Per-label means of each available metric (rows with NaN labels are excluded by groupby).
if diversity_df['label_positive'].notna().any():
    print("\nüìä Diversity Metrics by Label:")
    display(diversity_df.groupby('label_positive')[available_cols].mean())

### 5.2 Shared Sequences Analysis ("Public Clones" / "Star Soldiers")

In [None]:
# =============================================================================
# SHARED SEQUENCES ANALYSIS ("PUBLIC CLONES" / "STAR SOLDIERS")
# =============================================================================
# Identify sequences that appear across multiple individuals - potential disease markers

def find_shared_sequences(dataset_path: str, n_files: int = 50) -> Dict:
    """Find junction_aa sequences shared across multiple repertoires.

    Reads up to ``n_files`` repertoire TSVs. The files are taken from
    ``metadata.csv`` when it exists, or otherwise from the first ``n_files``
    ``*.tsv`` files in sorted order. For every distinct non-null
    ``junction_aa`` the function records which repertoires contain it.

    Args:
        dataset_path: Directory holding the repertoire ``.tsv`` files and,
            optionally, ``metadata.csv`` with ``filename`` / ``label_positive``.
        n_files: Maximum number of repertoires to read.

    Returns:
        Dict with keys:
            ``seq_to_repertoires``: ``defaultdict(set)``, sequence -> filenames
            containing it.
            ``seq_to_positive`` / ``seq_to_negative``: the same mapping,
            restricted to repertoires labelled positive / negative. A label
            counts if it is bool-like (True/False, 1/0). Missing or
            unrecognised labels go into neither mapping.
            ``n_repertoires``: number of repertoire files that exist on disk
            and were read.
    """
    metadata_path = os.path.join(dataset_path, 'metadata.csv')

    if os.path.exists(metadata_path):
        metadata = pd.read_csv(metadata_path)
        files = metadata['filename'].tolist()[:n_files]
        labels = dict(zip(metadata['filename'], metadata['label_positive']))
    else:
        # sorted(): glob order is filesystem-dependent, so the sample would not be reproducible.
        files = [os.path.basename(f) for f in sorted(glob.glob(os.path.join(dataset_path, "*.tsv")))[:n_files]]
        labels = {}

    # Track which repertoires contain each sequence
    seq_to_repertoires = defaultdict(set)
    seq_to_positive = defaultdict(set)  # repertoires labelled positive
    seq_to_negative = defaultdict(set)  # repertoires labelled negative
    n_read = 0

    for filename in tqdm(files, desc="Finding shared sequences", leave=False):
        file_path = os.path.join(dataset_path, filename)
        if not os.path.exists(file_path):
            continue
        n_read += 1

        # Compare labels by value rather than identity (`is True`). Identity
        # misses 1/0 or float labels; NaN / None are ruled out first.
        label = labels.get(filename)
        if label is None or pd.isna(label) or label not in (True, False):
            label_bucket = None
        else:
            label_bucket = seq_to_positive if label else seq_to_negative

        # Only junction_aa is needed; skip parsing the other columns.
        rep = pd.read_csv(file_path, sep='\t', usecols=lambda col: col == 'junction_aa')
        if 'junction_aa' not in rep.columns:
            continue

        for seq in rep['junction_aa'].dropna().unique():
            seq_to_repertoires[seq].add(filename)
            if label_bucket is not None:
                label_bucket[seq].add(filename)

    return {
        'seq_to_repertoires': seq_to_repertoires,
        'seq_to_positive': seq_to_positive,
        'seq_to_negative': seq_to_negative,
        'n_repertoires': n_read
    }

print("=" * 80)
print("SHARED SEQUENCES ANALYSIS (Public Clones)")
print("=" * 80)

# Analyze shared sequences from a sample dataset (first training dataset, up to 50 repertoires)
sample_dataset_name = train_datasets[0]
sample_dataset_path = os.path.join(TRAIN_DIR, sample_dataset_name)
shared_results = find_shared_sequences(sample_dataset_path, n_files=50)

# Analyze sharing patterns: number of distinct sequences found in exactly k repertoires
sharing_counts = Counter(len(repertoires) for repertoires in shared_results['seq_to_repertoires'].values())

# Identify disease-associated shared sequences.
# Iterate over every observed sequence, not only those seen in positive
# repertoires. Otherwise sequences found exclusively in negative repertoires
# (pos_ratio == 0) are dropped. That biases the enrichment histogram and
# undercounts "negative-enriched" sequences.
disease_enriched = []
for seq in shared_results['seq_to_repertoires']:
    # .get() rather than [] so the defaultdicts are not padded with empty sets
    n_pos = len(shared_results['seq_to_positive'].get(seq, ()))
    n_neg = len(shared_results['seq_to_negative'].get(seq, ()))
    total = n_pos + n_neg
    if total >= 3:  # Minimum number of labelled repertoires
        disease_enriched.append({
            'sequence': seq,
            'n_positive': n_pos,
            'n_negative': n_neg,
            'total': total,
            'pos_ratio': n_pos / total
        })

disease_df = pd.DataFrame(disease_enriched)

# Visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Sharing distribution (first 20 sharing levels, log-scaled counts)
ax1 = axes[0, 0]
sharing_df = pd.DataFrame.from_dict(sharing_counts, orient='index', columns=['count']).sort_index()
ax1.bar(sharing_df.index[:20], sharing_df['count'][:20], color='steelblue', alpha=0.8)
ax1.set_xlabel('Number of Repertoires', fontsize=10)
ax1.set_ylabel('Number of Sequences', fontsize=10)
ax1.set_title('Sequence Sharing Distribution', fontsize=12, fontweight='bold')
ax1.set_yscale('log')

# Cumulative sharing: cumulative sum from the highest sharing level downward
# gives, for each N, the number of sequences present in >= N repertoires.
ax2 = axes[0, 1]
cumsum = sharing_df['count'].sort_index(ascending=False).cumsum()
ax2.plot(cumsum.index, cumsum.values, color='steelblue', linewidth=2)
ax2.set_xlabel('Minimum Number of Repertoires', fontsize=10)
ax2.set_ylabel('Cumulative Number of Sequences', fontsize=10)
ax2.set_title('Sequences Shared by ‚â• N Repertoires', fontsize=12, fontweight='bold')
ax2.set_yscale('log')
ax2.grid(True, alpha=0.3)

# Disease enrichment of shared sequences
# NOTE(review): the 0.5 reference line is only "neutral" when classes are
# balanced; compare against the sample's positive rate for imbalanced data.
ax3 = axes[1, 0]
if len(disease_df) > 0:
    ax3.hist(disease_df['pos_ratio'], bins=20, color='steelblue', alpha=0.7, edgecolor='black')
    ax3.axvline(x=0.5, color='red', linestyle='--', label='Equal distribution')
    ax3.set_xlabel('Positive Ratio (n_pos / total)', fontsize=10)
    ax3.set_ylabel('Count', fontsize=10)
    ax3.set_title('Disease Enrichment of Shared Sequences', fontsize=12, fontweight='bold')
    ax3.legend()
else:
    ax3.text(0.5, 0.5, 'Insufficient data for analysis', ha='center', va='center', transform=ax3.transAxes)
    ax3.axis('off')

# Top disease-enriched sequences
ax4 = axes[1, 1]
if len(disease_df) > 0:
    # Top positive-enriched sequences. Ties on pos_ratio (common at 1.0) are
    # broken by the number of positive repertoires supporting the sequence.
    top_pos = disease_df.nlargest(15, ['pos_ratio', 'n_positive'])[['sequence', 'n_positive', 'n_negative', 'pos_ratio']]
    
    ax4.axis('off')
    table = ax4.table(cellText=top_pos.round(2).values,
                      colLabels=['Sequence', 'N Positive', 'N Negative', 'Pos Ratio'],
                      loc='center',
                      cellLoc='center',
                      colColours=['#4a90d9']*4)
    table.auto_set_font_size(False)
    table.set_fontsize(8)
    table.scale(1.2, 1.4)
    ax4.set_title('Top Disease-Enriched Shared Sequences', fontsize=12, fontweight='bold', y=1.02)
else:
    ax4.text(0.5, 0.5, 'Insufficient data', ha='center', va='center', transform=ax4.transAxes)
    ax4.axis('off')

plt.tight_layout()
plt.show()

# Print summary
print(f"\nüìä Shared Sequences Summary for {sample_dataset_name}:")
print(f"   ‚Ä¢ Total unique sequences: {len(shared_results['seq_to_repertoires']):,}")
print(f"   ‚Ä¢ Sequences in 1 repertoire only: {sharing_counts.get(1, 0):,}")
print(f"   ‚Ä¢ Sequences in 2+ repertoires: {sum(v for k, v in sharing_counts.items() if k >= 2):,}")
print(f"   ‚Ä¢ Sequences in 5+ repertoires: {sum(v for k, v in sharing_counts.items() if k >= 5):,}")
print(f"   ‚Ä¢ Sequences in 10+ repertoires: {sum(v for k, v in sharing_counts.items() if k >= 10):,}")

if len(disease_df) > 0:
    print(f"\nüìä Disease-Enriched Sequences (appearing in 3+ repertoires):")
    print(f"   ‚Ä¢ Total: {len(disease_df):,}")
    print(f"   ‚Ä¢ Highly positive-enriched (>70%): {len(disease_df[disease_df['pos_ratio'] > 0.7]):,}")
    print(f"   ‚Ä¢ Highly negative-enriched (<30%): {len(disease_df[disease_df['pos_ratio'] < 0.3]):,}")

## 6. Technical Bias Check (Batch Effects)

### 6.1 Sequencing Run Analysis

In [None]:
# =============================================================================
# TECHNICAL BIAS CHECK - BATCH EFFECTS
# =============================================================================

print("=" * 80)
print("TECHNICAL BIAS CHECK - BATCH EFFECTS")
print("=" * 80)

# Check for sequencing_run_id in metadata
# This branch also requires 'repertoire_id', 'label_positive' and 'dataset'
# columns in train_metadata (used by the aggregation below).
if 'sequencing_run_id' in train_metadata.columns:
    print("\n‚úÖ sequencing_run_id found in metadata")
    
    # Per-run summary: sample count, positives, positive rate, and how many datasets share the run.
    run_analysis = train_metadata.groupby('sequencing_run_id').agg({
        'repertoire_id': 'count',
        'label_positive': ['sum', 'mean'],
        'dataset': 'nunique'
    }).round(3)
    run_analysis.columns = ['n_samples', 'n_positive', 'positive_rate', 'n_datasets']
    
    print(f"\nüìä Number of sequencing runs: {train_metadata['sequencing_run_id'].nunique()}")
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Samples per run (x axis is the rank position, largest run first)
    ax1 = axes[0, 0]
    run_counts = train_metadata['sequencing_run_id'].value_counts()
    ax1.bar(range(len(run_counts)), run_counts.values, color='steelblue', alpha=0.7)
    ax1.set_xlabel('Sequencing Run Index', fontsize=10)
    ax1.set_ylabel('Number of Samples', fontsize=10)
    ax1.set_title('Samples per Sequencing Run', fontsize=12, fontweight='bold')
    
    # Label distribution by run
    # NOTE(review): colors and legend assume the unstacked columns are
    # [False, True] in that order and that both labels occur; if only one
    # label is present the legend text will be wrong.
    ax2 = axes[0, 1]
    run_label_dist = train_metadata.groupby(['sequencing_run_id', 'label_positive']).size().unstack(fill_value=0)
    run_label_dist.plot(kind='bar', stacked=True, ax=ax2, color=['#e74c3c', '#2ecc71'])
    ax2.set_xlabel('Sequencing Run', fontsize=10)
    ax2.set_ylabel('Count', fontsize=10)
    ax2.set_title('Label Distribution by Sequencing Run', fontsize=12, fontweight='bold')
    ax2.legend(['Negative', 'Positive'])
    ax2.tick_params(axis='x', rotation=45)
    
    # Positive rate by run
    # Bars are placed at positional y indices (sorted by rate); run ids are not shown as tick labels.
    ax3 = axes[1, 0]
    positive_rates = train_metadata.groupby('sequencing_run_id')['label_positive'].mean().sort_values()
    ax3.barh(range(len(positive_rates)), positive_rates.values, color='coral', alpha=0.8)
    ax3.axvline(x=train_metadata['label_positive'].mean(), color='red', linestyle='--', 
                label=f'Overall: {train_metadata["label_positive"].mean():.2f}')
    ax3.set_xlabel('Positive Rate', fontsize=10)
    ax3.set_ylabel('Sequencing Run', fontsize=10)
    ax3.set_title('Positive Rate by Sequencing Run', fontsize=12, fontweight='bold')
    ax3.legend()
    
    # Chi-squared test for batch effect
    # Tests independence of run and label on the run x label contingency table.
    # NOTE(review): with many small runs the expected counts can drop below ~5,
    # where the chi-squared approximation is unreliable; a permutation or exact
    # test may be preferable. Unlike the HLA section, there is no guard here
    # for a single-row or single-column table.
    ax4 = axes[1, 1]
    contingency = pd.crosstab(train_metadata['sequencing_run_id'], train_metadata['label_positive'])
    chi2, p_val, dof, expected = stats.chi2_contingency(contingency)
    
    ax4.axis('off')
    text = f"""BATCH EFFECT STATISTICAL TEST
    
Chi-squared Test Results:
‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ
œá¬≤ statistic: {chi2:.2f}
Degrees of freedom: {dof}
P-value: {p_val:.2e}
‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ

Interpretation:
{'‚ö†Ô∏è SIGNIFICANT batch effect detected!' if p_val < 0.05 else '‚úÖ No significant batch effect'}
{'Consider stratified sampling by run' if p_val < 0.05 else 'Labels are independent of sequencing run'}
"""
    ax4.text(0.1, 0.5, text, fontsize=11, family='monospace', transform=ax4.transAxes, 
             verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    plt.show()
    
    print("\nüìä Sequencing Run Summary:")
    display(run_analysis.sort_values('n_samples', ascending=False).head(15))
    
else:
    print("\n‚ö†Ô∏è sequencing_run_id not found in metadata")
    print("   Batch effect analysis by sequencing run not possible.")
    
    # Check for other potential batch indicators
    # Exact column-name matches only; only the value counts are printed (no test).
    potential_batch_cols = ['study_group', 'cohort', 'batch', 'plate', 'run']
    found_cols = [col for col in potential_batch_cols if col in train_metadata.columns]
    
    if found_cols:
        print(f"\n   Alternative batch indicators found: {found_cols}")
        for col in found_cols:
            print(f"\nüìä {col} distribution:")
            print(train_metadata[col].value_counts().head(10))
    else:
        print("   No alternative batch indicators found.")

## 7. HLA Gene Analysis

In [None]:
# =============================================================================
# HLA GENE ANALYSIS
# =============================================================================

print("=" * 80)
print("HLA GENE ANALYSIS")
print("=" * 80)

# Find HLA-related columns
# Case-insensitive substring match on 'hla' or 'mhc' in the column name.
hla_cols = [col for col in train_metadata.columns if 'hla' in col.lower() or 'mhc' in col.lower()]

if hla_cols:
    print(f"\n‚úÖ HLA-related columns found: {hla_cols}")
    
    for hla_col in hla_cols:
        print(f"\n" + "=" * 60)
        print(f"ANALYSIS OF: {hla_col}")
        print("=" * 60)
        
        # Basic statistics
        n_unique = train_metadata[hla_col].nunique()
        n_missing = train_metadata[hla_col].isna().sum()
        print(f"\nüìä Basic stats:")
        print(f"   ‚Ä¢ Unique values: {n_unique:,}")
        print(f"   ‚Ä¢ Missing: {n_missing:,} ({n_missing/len(train_metadata)*100:.1f}%)")
        
        # Plot and test only columns with 1..99 distinct values; higher-cardinality
        # columns get a text summary in the elif below.
        if n_unique > 0 and n_unique < 100:
            # Distribution plot
            fig, axes = plt.subplots(1, 2, figsize=(14, 5))
            
            # Overall distribution
            ax1 = axes[0]
            value_counts = train_metadata[hla_col].value_counts().head(20)
            ax1.barh(value_counts.index.astype(str), value_counts.values, color='steelblue', alpha=0.8)
            ax1.set_xlabel('Count', fontsize=10)
            ax1.set_ylabel(hla_col, fontsize=10)
            ax1.set_title(f'Top 20 {hla_col} Values', fontsize=12, fontweight='bold')
            ax1.invert_yaxis()
            
            # Association with label
            # Positive rate per HLA value, restricted to values seen in >= 5
            # samples; the 15 highest rates are shown.
            ax2 = axes[1]
            if 'label_positive' in train_metadata.columns:
                # Get top HLA types and their positive rates
                hla_label_stats = train_metadata.groupby(hla_col).agg({
                    'label_positive': ['count', 'sum', 'mean']
                }).round(3)
                hla_label_stats.columns = ['count', 'n_positive', 'positive_rate']
                hla_label_stats = hla_label_stats[hla_label_stats['count'] >= 5].sort_values('positive_rate', ascending=False)
                
                if len(hla_label_stats) > 0:
                    top_hla = hla_label_stats.head(15)
                    # Color threshold is an absolute 0.5, not the overall positive rate drawn as the dashed line.
                    colors = ['#2ecc71' if r > 0.5 else '#e74c3c' for r in top_hla['positive_rate']]
                    ax2.barh(top_hla.index.astype(str), top_hla['positive_rate'], color=colors, alpha=0.8)
                    ax2.axvline(x=train_metadata['label_positive'].mean(), color='black', linestyle='--', 
                               label=f'Overall rate: {train_metadata["label_positive"].mean():.2f}')
                    ax2.set_xlabel('Positive Rate', fontsize=10)
                    ax2.set_ylabel(hla_col, fontsize=10)
                    ax2.set_title(f'{hla_col} Association with Label (min 5 samples)', fontsize=12, fontweight='bold')
                    ax2.legend()
                    ax2.invert_yaxis()
            
            plt.tight_layout()
            plt.show()
            
            # Statistical test for association
            # NOTE(review): with many sparse HLA categories the expected counts
            # can be small, making the chi-squared p-value unreliable.
            if 'label_positive' in train_metadata.columns:
                # Chi-squared test
                valid_data = train_metadata[[hla_col, 'label_positive']].dropna()
                if len(valid_data) > 0:
                    contingency = pd.crosstab(valid_data[hla_col], valid_data['label_positive'])
                    if contingency.shape[0] > 1 and contingency.shape[1] > 1:
                        chi2, p_val, dof, expected = stats.chi2_contingency(contingency)
                        print(f"\nüìä Chi-squared test for {hla_col} vs label:")
                        print(f"   ‚Ä¢ œá¬≤ = {chi2:.2f}")
                        print(f"   ‚Ä¢ p-value = {p_val:.2e}")
                        print(f"   ‚Ä¢ {'‚ö†Ô∏è SIGNIFICANT association!' if p_val < 0.05 else '‚úÖ No significant association'}")
        
        elif n_unique >= 100:
            print(f"\n‚ö†Ô∏è Too many unique values ({n_unique}) for detailed visualization")
            print("   Top 10 values:")
            print(train_metadata[hla_col].value_counts().head(10))
else:
    print("\n‚ö†Ô∏è No HLA-related columns found in metadata")
    print("   HLA analysis not possible. Columns available:")
    print(f"   {list(train_metadata.columns)}")

## 8. Missing Values Assessment

In [None]:
# =============================================================================
# MISSING VALUES ASSESSMENT
# =============================================================================

print("=" * 80)
print("MISSING VALUES ASSESSMENT")
print("=" * 80)

# Metadata missing values
# Counts true nulls (NaN/None) per column. -999.0 placeholders are NOT
# included here; they are reported separately further below.
print("\nüìã METADATA MISSING VALUES:")
missing_meta = train_metadata.isnull().sum()
missing_pct = (missing_meta / len(train_metadata) * 100).round(2)
missing_df = pd.DataFrame({
    'Column': missing_meta.index,
    'Missing Count': missing_meta.values,
    'Missing %': missing_pct.values,
    'Total': len(train_metadata)
})
missing_df = missing_df[missing_df['Missing Count'] > 0].sort_values('Missing %', ascending=False)

if len(missing_df) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Missing values bar chart
    # Colors: red > 50%, orange > 20%, green otherwise (matching the dashed threshold lines).
    ax1 = axes[0]
    colors = ['#e74c3c' if pct > 50 else '#f39c12' if pct > 20 else '#2ecc71' for pct in missing_df['Missing %']]
    ax1.barh(missing_df['Column'], missing_df['Missing %'], color=colors, alpha=0.8)
    ax1.set_xlabel('Missing %', fontsize=10)
    ax1.set_ylabel('Column', fontsize=10)
    ax1.set_title('Missing Values in Metadata', fontsize=12, fontweight='bold')
    ax1.axvline(x=50, color='red', linestyle='--', alpha=0.5, label='50% threshold')
    ax1.axvline(x=20, color='orange', linestyle='--', alpha=0.5, label='20% threshold')
    ax1.legend(fontsize=8)
    ax1.invert_yaxis()
    
    # Missing values by dataset
    # Share of all null cells (every column, rows of that dataset) per dataset;
    # requires a 'dataset' column in train_metadata.
    ax2 = axes[1]
    dataset_missing = train_metadata.groupby('dataset').apply(lambda x: x.isnull().sum().sum() / (len(x) * len(x.columns)) * 100)
    dataset_missing = dataset_missing.sort_values(ascending=False)
    ax2.barh(dataset_missing.index, dataset_missing.values, color='steelblue', alpha=0.8)
    ax2.set_xlabel('Overall Missing %', fontsize=10)
    ax2.set_ylabel('Dataset', fontsize=10)
    ax2.set_title('Missing Values by Dataset', fontsize=12, fontweight='bold')
    ax2.invert_yaxis()
    
    plt.tight_layout()
    plt.show()
    
    print("\nüìä Missing Values Summary:")
    display(missing_df)
else:
    print("   ‚úÖ No missing values in metadata!")

# Check for -999.0 placeholder values (as mentioned in competition description)
# Only numeric columns are scanned; nothing is printed when none are found.
print("\n" + "-" * 80)
print("üìã CHECKING FOR -999.0 PLACEHOLDER VALUES:")
for col in train_metadata.select_dtypes(include=[np.number]).columns:
    n_placeholder = (train_metadata[col] == -999.0).sum()
    if n_placeholder > 0:
        print(f"   ‚Ä¢ {col}: {n_placeholder:,} instances of -999.0")

# Repertoire-level missing values analysis
print("\n" + "-" * 80)
print("üìã REPERTOIRE-LEVEL MISSING VALUES (Sample):")

# Check a few repertoire files
# NOTE(review): glob() order is filesystem-dependent, so which 5 files are
# sampled is not guaranteed to be reproducible across machines.
sample_dataset_path = os.path.join(TRAIN_DIR, train_datasets[0])
sample_files = glob.glob(os.path.join(sample_dataset_path, "*.tsv"))[:5]

# One row per sampled file: total rows plus per-column null count and null %.
repertoire_missing = []
for file_path in sample_files:
    rep = pd.read_csv(file_path, sep='\t')
    missing_info = {
        'file': os.path.basename(file_path),
        'total_rows': len(rep)
    }
    for col in rep.columns:
        missing_info[f'{col}_missing'] = rep[col].isnull().sum()
        missing_info[f'{col}_missing_%'] = rep[col].isnull().sum() / len(rep) * 100
    repertoire_missing.append(missing_info)

rep_missing_df = pd.DataFrame(repertoire_missing)
print(f"\nSample repertoire files from {train_datasets[0]}:")

# Show missing percentages for key columns
# Average and worst-case null % across the sampled files, for columns that exist.
key_cols = ['junction_aa', 'v_call', 'j_call']
for col in key_cols:
    if f'{col}_missing_%' in rep_missing_df.columns:
        avg_missing = rep_missing_df[f'{col}_missing_%'].mean()
        max_missing = rep_missing_df[f'{col}_missing_%'].max()
        print(f"   ‚Ä¢ {col}: avg {avg_missing:.1f}% missing, max {max_missing:.1f}% missing")

## 9. Train vs Test Comparison

In [None]:
# =============================================================================
# TRAIN VS TEST COMPARISON
# =============================================================================

print("=" * 80)
print("TRAIN VS TEST COMPARISON")
print("=" * 80)

# Get test dataset info
# Sorted lexicographically, so e.g. "test_dataset_10..." sorts before "test_dataset_2...".
test_datasets = sorted([d for d in os.listdir(TEST_DIR) if d.startswith("test_dataset_")])

print(f"\nüìä Overview:")
print(f"   ‚Ä¢ Training datasets: {len(train_datasets)}")
print(f"   ‚Ä¢ Test datasets: {len(test_datasets)}")

# Map test datasets to their training counterparts
# One training dataset may map to several test datasets (the values are
# lists). Test datasets without a matching "train_dataset_<id>" are not mapped.
train_test_mapping = defaultdict(list)
for test_ds in test_datasets:
    # Extract base ID (e.g., "test_dataset_1_a" -> "1")
    base_id = test_ds.replace("test_dataset_", "").split("_")[0]
    train_name = f"train_dataset_{base_id}"
    if train_name in train_datasets:
        train_test_mapping[train_name].append(test_ds)

print(f"\nüìã Train-Test Dataset Mapping:")
for train_ds, test_list in list(train_test_mapping.items())[:10]:
    print(f"   {train_ds} ‚Üí {test_list}")
if len(train_test_mapping) > 10:
    print(f"   ... and {len(train_test_mapping) - 10} more")

# Analyze test data structure
print("\n" + "=" * 80)
print("TEST DATA STRUCTURE ANALYSIS")
print("=" * 80)

def get_dataset_stats(base_dir: str, dataset_name: str) -> Dict:
    """Collect basic structural stats for one dataset directory.

    Args:
        base_dir: Parent directory that holds the dataset folders.
        dataset_name: Name of the dataset folder inside ``base_dir``.

    Returns:
        Dict with keys ``dataset``, ``n_files`` (number of ``*.tsv`` files) and
        ``has_metadata`` (whether ``metadata.csv`` exists). When at least one
        TSV file is present, ``columns`` (header of the first TSV file in sorted
        order) and ``n_columns`` are added as well.
    """
    dataset_path = os.path.join(base_dir, dataset_name)
    # glob order is filesystem-dependent; sort so the sampled file is deterministic.
    tsv_files = sorted(glob.glob(os.path.join(dataset_path, "*.tsv")))

    # Named ds_stats so it does not shadow the module-level ``scipy.stats`` import.
    ds_stats = {
        'dataset': dataset_name,
        'n_files': len(tsv_files),
        'has_metadata': os.path.exists(os.path.join(dataset_path, 'metadata.csv')),
    }

    # Sample a file to get column info (a few rows are enough for the header)
    if tsv_files:
        sample_file = pd.read_csv(tsv_files[0], sep='\t', nrows=10)
        ds_stats['columns'] = list(sample_file.columns)
        ds_stats['n_columns'] = len(sample_file.columns)

    return ds_stats

# Compare train and test structure on a small sample of datasets
comparison_data = []

for train_ds in train_datasets[:5]:  # Sample comparison
    train_stats = get_dataset_stats(TRAIN_DIR, train_ds)
    train_stats['type'] = 'train'
    comparison_data.append(train_stats)

    for test_ds in train_test_mapping.get(train_ds, [])[:2]:
        test_stats = get_dataset_stats(TEST_DIR, test_ds)
        test_stats['type'] = 'test'
        comparison_data.append(test_stats)

comparison_df = pd.DataFrame(comparison_data)

# Visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Repertoire counts (one .tsv file per repertoire). Every test dataset is
# counted (not just the first len(train_datasets) of them), so the total in the
# summary panel below is complete.
test_counts = [
    {'dataset': test_ds, 'count': len(glob.glob(os.path.join(TEST_DIR, test_ds, "*.tsv")))}
    for test_ds in test_datasets
]
test_counts_df = pd.DataFrame(test_counts, columns=['dataset', 'count'])
test_count_lookup = dict(zip(test_counts_df['dataset'], test_counts_df['count']))

# Number of repertoires: bars are paired per training dataset (its own files
# vs. the summed files of the test datasets mapped to it), so both bars always
# describe the same dataset and the two bar arrays always have equal length.
ax1 = axes[0, 0]
plot_train = train_datasets[:10]
train_bar = [len(glob.glob(os.path.join(TRAIN_DIR, ds, "*.tsv"))) for ds in plot_train]
test_bar = [sum(test_count_lookup.get(t, 0) for t in train_test_mapping.get(ds, []))
            for ds in plot_train]
x = np.arange(len(plot_train))
width = 0.35
ax1.bar(x - width/2, train_bar, width, label='Train', color='steelblue', alpha=0.8)
ax1.bar(x + width/2, test_bar, width, label='Test (mapped)', color='coral', alpha=0.8)
ax1.set_xticks(x)
ax1.set_xticklabels([ds.replace("train_dataset_", "") for ds in plot_train])
ax1.set_xlabel('Training Dataset ID', fontsize=10)
ax1.set_ylabel('Number of Repertoires', fontsize=10)
ax1.set_title('Repertoire Count: Train vs Test', fontsize=12, fontweight='bold')
ax1.legend()

# Column comparison
ax2 = axes[0, 1]
if 'columns' in comparison_df.columns:
    train_cols = set()
    test_cols = set()
    for _, row in comparison_df.iterrows():
        cols = row['columns']
        # Datasets without any .tsv file have NaN here instead of a column list.
        if not isinstance(cols, list):
            continue
        if row['type'] == 'train':
            train_cols.update(cols)
        else:
            test_cols.update(cols)

    common = train_cols & test_cols
    train_only = train_cols - test_cols
    test_only = test_cols - train_cols

    data = [len(common), len(train_only), len(test_only)]
    labels = ['Common', 'Train Only', 'Test Only']
    colors = ['#2ecc71', '#3498db', '#e74c3c']
    if sum(data) > 0:  # a pie of all-zero wedges cannot be drawn
        ax2.pie(data, labels=labels, autopct='%1.0f', colors=colors, explode=(0.02, 0.02, 0.02))
    ax2.set_title('Column Overlap: Train vs Test', fontsize=12, fontweight='bold')

    print("\n📋 Column Comparison:")
    print(f"   • Common columns: {common}")
    print(f"   • Train-only columns: {train_only}")
    print(f"   • Test-only columns: {test_only}")

# Summary table
ax3 = axes[1, 0]
ax3.axis('off')
n_test_repertoires = int(test_counts_df['count'].sum()) if len(test_counts_df) > 0 else 'N/A'
rule = '━' * 40
summary_text = f"""
TRAIN VS TEST SUMMARY
{rule}

Training Data:
  • Datasets: {len(train_datasets)}
  • Total repertoires: {len(train_metadata):,}
  • Has labels: Yes (label_positive)

Test Data:
  • Datasets: {len(test_datasets)}
  • Has labels: No (to be predicted)
  • Mapped to {len(train_test_mapping)} training datasets

Key Differences:
  • Test data has NO label_positive column
  • Predictions needed for {n_test_repertoires} test repertoires
"""
ax3.text(0.1, 0.5, summary_text, fontsize=11, family='monospace', transform=ax3.transAxes,
         verticalalignment='center', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))

# Repertoire sizes comparison
ax4 = axes[1, 1]
# Sample test repertoire sizes (first 10 files of the first 3 test datasets)
test_sizes = []
for test_ds in test_datasets[:3]:
    test_path = os.path.join(TEST_DIR, test_ds)
    for file_path in sorted(glob.glob(os.path.join(test_path, "*.tsv")))[:10]:
        with open(file_path, 'r') as f:
            # Lines minus the header row; clamp so an empty file counts as 0.
            n_lines = max(sum(1 for _ in f) - 1, 0)
        test_sizes.append({'dataset': test_ds, 'n_sequences': n_lines, 'type': 'test'})
# Explicit columns keep the selection below valid even when nothing was sampled.
test_sizes_df = pd.DataFrame(test_sizes, columns=['dataset', 'n_sequences', 'type'])

# Compare with train sizes
if not repertoire_sizes.empty:
    train_size_sample = repertoire_sizes.head(100).copy()
    train_size_sample['type'] = 'train'
    combined_sizes = pd.concat([train_size_sample[['n_sequences', 'type']],
                                test_sizes_df[['n_sequences', 'type']]], ignore_index=True)
    sns.boxplot(data=combined_sizes, x='type', y='n_sequences', ax=ax4, palette=['steelblue', 'coral'])
    ax4.set_xlabel('Dataset Type', fontsize=10)
    ax4.set_ylabel('Number of Sequences', fontsize=10)
    ax4.set_title('Repertoire Size: Train vs Test', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

## 10. Dimensionality Reduction Visualization (PCA/t-SNE)

In [None]:
# =============================================================================
# DIMENSIONALITY REDUCTION VISUALIZATION
# =============================================================================
# Create k-mer features for repertoires and visualize with PCA/t-SNE

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

def compute_kmer_features(dataset_path: str, k: int = 3, n_files: int = 50) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Compute normalized k-mer frequency features for up to ``n_files`` repertoires.

    Files are taken from ``metadata.csv`` (in metadata order) when it exists,
    otherwise from the sorted ``*.tsv`` files of ``dataset_path``. Files that
    are missing or lack a ``junction_aa`` column are skipped.

    Args:
        dataset_path: Directory containing the repertoire TSV files.
        k: K-mer length.
        n_files: Maximum number of repertoires to process.

    Returns:
        Tuple ``(features_df, meta_df)``:
        - ``features_df``: indexed by ``filename``; one column per observed
          k-mer holding its frequency (count / total k-mers of that repertoire),
          0 where a k-mer is absent.
        - ``meta_df``: columns ``filename``, ``label_positive`` (None when no
          label is available) and ``n_sequences`` (row count of the file).
        Both are empty when no repertoire could be processed.
    """
    metadata_path = os.path.join(dataset_path, 'metadata.csv')

    if os.path.exists(metadata_path):
        metadata = pd.read_csv(metadata_path)
        files = metadata['filename'].tolist()[:n_files]
        # Metadata without labels (e.g. test-style data) -> labels unknown.
        if 'label_positive' in metadata.columns:
            labels = dict(zip(metadata['filename'], metadata['label_positive']))
        else:
            labels = {}
    else:
        # Sorted so the selected subset does not depend on filesystem order.
        tsv_paths = sorted(glob.glob(os.path.join(dataset_path, "*.tsv")))[:n_files]
        files = [os.path.basename(f) for f in tsv_paths]
        labels = {}

    feature_list = []
    meta_list = []

    for filename in tqdm(files, desc="Computing k-mer features", leave=False):
        file_path = os.path.join(dataset_path, filename)
        if not os.path.exists(file_path):
            continue

        rep = pd.read_csv(file_path, sep='\t')
        if 'junction_aa' not in rep.columns:
            continue

        kmer_counts = Counter()
        for seq in rep['junction_aa'].dropna():
            for i in range(len(seq) - k + 1):
                kmer_counts[seq[i:i + k]] += 1

        # Normalize by total (an empty Counter yields an empty dict, no division)
        total = sum(kmer_counts.values())
        kmer_freq = {kmer: count / total for kmer, count in kmer_counts.items()}
        kmer_freq['filename'] = filename
        feature_list.append(kmer_freq)

        meta_list.append({
            'filename': filename,
            'label_positive': labels.get(filename, None),
            'n_sequences': len(rep)
        })

    if not feature_list:
        # set_index('filename') would raise on an empty frame; return empty
        # results with the expected layout instead.
        empty_features = pd.DataFrame(index=pd.Index([], name='filename'))
        empty_meta = pd.DataFrame(columns=['filename', 'label_positive', 'n_sequences'])
        return empty_features, empty_meta

    features_df = pd.DataFrame(feature_list).fillna(0).set_index('filename')
    meta_df = pd.DataFrame(meta_list)

    return features_df, meta_df

import inspect
from matplotlib.patches import Patch
from scipy.stats import mannwhitneyu

print("=" * 80)
print("DIMENSIONALITY REDUCTION VISUALIZATION")
print("=" * 80)


def _label_colors(values):
    """Map each label to a marker color: True -> green, False -> red, otherwise gray."""
    return ['#2ecc71' if v == True else '#e74c3c' if v == False else 'gray'  # noqa: E712
            for v in values]


# Compute k-mer features for a sample dataset
sample_dataset_name = train_datasets[0]
sample_dataset_path = os.path.join(TRAIN_DIR, sample_dataset_name)

print(f"\n📊 Computing 3-mer features for {sample_dataset_name}...")
kmer_features, kmer_meta = compute_kmer_features(sample_dataset_path, k=3, n_files=100)

if len(kmer_features) > 10:
    # Select top N most variable k-mers
    kmer_var = kmer_features.var().sort_values(ascending=False)
    top_kmers = kmer_var.head(500).index
    X = kmer_features[top_kmers].values

    # Standardize
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # PCA: n_components may not exceed min(n_samples, n_features); otherwise
    # scikit-learn raises ValueError (e.g. for a dataset with < 50 repertoires).
    print("   Computing PCA...")
    pca = PCA(n_components=min(50, X_scaled.shape[0], X_scaled.shape[1]))
    X_pca = pca.fit_transform(X_scaled)

    # t-SNE on PCA-reduced features. The iteration-count argument was renamed
    # from n_iter to max_iter in scikit-learn 1.5 (old name deprecated, then
    # removed), so pass whichever name the installed version accepts.
    print("   Computing t-SNE...")
    tsne_iter_arg = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    tsne = TSNE(n_components=2, perplexity=min(30, len(X_pca) // 4), random_state=42,
                **{tsne_iter_arg: 1000})
    X_tsne = tsne.fit_transform(X_pca[:, :min(30, X_pca.shape[1])])

    # Labels and sizes aligned to the row order of the feature matrix
    meta_by_file = kmer_meta.set_index('filename').loc[kmer_features.index]
    labels_array = meta_by_file['label_positive'].values
    sizes_array = meta_by_file['n_sequences'].values
    has_labels = pd.notna(labels_array).any()
    point_colors = _label_colors(labels_array)
    legend_elements = [Patch(facecolor='#2ecc71', label='Positive'),
                       Patch(facecolor='#e74c3c', label='Negative')]

    # Visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    # PCA variance explained
    ax1 = axes[0, 0]
    cumsum_var = np.cumsum(pca.explained_variance_ratio_)
    ax1.plot(range(1, len(cumsum_var) + 1), cumsum_var, 'b-', linewidth=2)
    ax1.axhline(y=0.9, color='r', linestyle='--', label='90% variance')
    ax1.axhline(y=0.95, color='orange', linestyle='--', label='95% variance')
    ax1.set_xlabel('Number of Components', fontsize=10)
    ax1.set_ylabel('Cumulative Explained Variance', fontsize=10)
    ax1.set_title('PCA Explained Variance', fontsize=12, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # PCA plot colored by label
    ax2 = axes[0, 1]
    if has_labels:
        scatter = ax2.scatter(X_pca[:, 0], X_pca[:, 1], c=point_colors, alpha=0.6, s=50)
        ax2.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)', fontsize=10)
        ax2.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)', fontsize=10)
        ax2.set_title('PCA: Colored by Label', fontsize=12, fontweight='bold')
        ax2.legend(handles=legend_elements, loc='upper right')

    # t-SNE colored by label
    ax3 = axes[1, 0]
    if has_labels:
        scatter = ax3.scatter(X_tsne[:, 0], X_tsne[:, 1], c=point_colors, alpha=0.6, s=50)
        ax3.set_xlabel('t-SNE 1', fontsize=10)
        ax3.set_ylabel('t-SNE 2', fontsize=10)
        ax3.set_title('t-SNE: Colored by Label', fontsize=12, fontweight='bold')
        ax3.legend(handles=legend_elements, loc='upper right')

    # t-SNE colored by repertoire size; clip at 1 so an empty repertoire does
    # not produce log10(0) = -inf.
    ax4 = axes[1, 1]
    scatter = ax4.scatter(X_tsne[:, 0], X_tsne[:, 1],
                          c=np.log10(np.maximum(sizes_array.astype(float), 1)),
                          cmap='viridis', alpha=0.6, s=50)
    plt.colorbar(scatter, ax=ax4, label='log10(n_sequences)')
    ax4.set_xlabel('t-SNE 1', fontsize=10)
    ax4.set_ylabel('t-SNE 2', fontsize=10)
    ax4.set_title('t-SNE: Colored by Repertoire Size', fontsize=12, fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Print insights
    print("\n📊 Dimensionality Reduction Results:")
    print(f"   • Number of repertoires: {len(kmer_features)}")
    print(f"   • Number of k-mer features: {len(top_kmers)}")
    # np.argmax on an all-False mask returns 0, which previously reported
    # "1 component" when the threshold was never reached.
    for threshold in (0.9, 0.95):
        reached = np.flatnonzero(cumsum_var >= threshold)
        n_needed = reached[0] + 1 if reached.size else f"not reached with {len(cumsum_var)}"
        print(f"   • Components for {threshold:.0%} variance: {n_needed}")

    # Check for separation
    if has_labels:
        # Elementwise comparison (labels_array may be an object array with None)
        pos_mask = labels_array == True  # noqa: E712
        neg_mask = labels_array == False  # noqa: E712
        if pos_mask.sum() > 0 and neg_mask.sum() > 0:
            # Test separation on PC1
            stat, pval = mannwhitneyu(X_pca[pos_mask, 0], X_pca[neg_mask, 0])
            print("\n📊 Separability Test (PC1):")
            print(f"   • Mann-Whitney U test p-value: {pval:.2e}")
            print(f"   • {'✅ Significant separation!' if pval < 0.05 else '⚠️ No significant separation'}")
else:
    print("⚠️ Not enough data for dimensionality reduction")

## 11. EDA Summary & Key Findings

In [None]:
# =============================================================================
# EDA SUMMARY & KEY FINDINGS
# =============================================================================

print("=" * 80)
print("📊 EDA SUMMARY & KEY FINDINGS")
print("=" * 80)


def _render_box(title: str, lines: List[str], width: int = 80) -> str:
    """Return ``lines`` framed by a box-drawing border ``width`` characters wide.

    Each content line gets one space of left padding and is padded (or
    truncated) to the inner width, so the right border stays aligned however
    long the substituted values are.
    """
    inner = width - 2
    rows = ['┌' + '─' * inner + '┐',
            '│' + title.center(inner) + '│',
            '├' + '─' * inner + '┤']
    rows += ['│' + (' ' + line).ljust(inner)[:inner] + '│' for line in lines]
    rows.append('└' + '─' * inner + '┘')
    return '\n'.join(rows)


# Calculate summary statistics (guard against dividing by an empty table)
n_train_samples = len(train_metadata)
if 'label_positive' in train_metadata.columns and n_train_samples > 0:
    pos_count = train_metadata['label_positive'].sum()
    neg_count = n_train_samples - pos_count
    pos_pct = pos_count / n_train_samples * 100
    neg_pct = neg_count / n_train_samples * 100
else:
    pos_pct = neg_pct = 0

overview_lines = [
    "Training Data:",
    f"  • Number of datasets: {len(train_datasets)}",
    f"  • Total repertoires: {n_train_samples:,}",
    f"  • Class balance: {pos_pct:.1f}% positive, {neg_pct:.1f}% negative",
    "",
    "Test Data:",
    f"  • Number of datasets: {len(test_datasets)}",
    "  • Labels: Not provided (to predict)",
]

findings_lines = [
    "1. METADATA FEATURES:",
    "   • Available columns vary by dataset",
    "   • Demographic features may have significant missing values",
    "   • Some features show association with label (check correlation results)",
    "",
    "2. SEQUENCE CHARACTERISTICS:",
    "   • Junction AA lengths vary across datasets",
    "   • V/J gene usage patterns differ between datasets",
    "   • D gene information may not be available in all datasets",
    "",
    "3. DIVERSITY METRICS:",
    "   • Shannon entropy and clonality vary by label",
    "   • Disease samples may show altered diversity patterns",
    "   • Consider diversity as a feature for classification",
    "",
    "4. SHARED SEQUENCES (PUBLIC CLONES):",
    "   • Many sequences are private (appear in single individual)",
    "   • Some sequences appear across multiple individuals",
    "   • Disease-associated public clones may exist",
    "",
    "5. POTENTIAL ISSUES:",
    "   • Check for batch effects if sequencing_run_id varies",
    "   • Missing values need handling strategy",
    "   • Dataset heterogeneity requires careful cross-validation",
]

recommendation_lines = [
    "1. Feature Engineering:",
    "   • K-mer frequencies (3-mer, 4-mer)",
    "   • V/J gene usage profiles",
    "   • Diversity metrics (entropy, clonality, richness)",
    "   • Sequence length distributions",
    "   • Public clone presence/absence",
    "",
    "2. Model Selection:",
    "   • Consider ensemble methods for heterogeneous data",
    "   • Test both traditional ML (XGBoost, RF) and deep learning",
    "   • Use stratified cross-validation by dataset",
    "",
    "3. Important Sequence Identification:",
    "   • Fisher's exact test for sequence-label association",
    "   • Feature importance from tree-based models",
    "   • Attention weights from neural networks",
    "",
    "4. Handling Challenges:",
    "   • Impute or handle missing values consistently",
    "   • Account for batch effects in validation",
    "   • Normalize for repertoire size differences",
]

summary_text = "\n" + "\n\n".join([
    _render_box("DATA OVERVIEW", overview_lines),
    _render_box("KEY FINDINGS", findings_lines),
    _render_box("RECOMMENDATIONS FOR MODELING", recommendation_lines),
]) + "\n"
print(summary_text)

print("\n" + "=" * 80)
print("✅ EDA COMPLETE - Ready for modeling!")
print("=" * 80)

---

# Part 2: Model Template & Submission Code

The following sections contain the competition template code for building the `ImmuneStatePredictor` class and generating submissions.

## Need for a uniform interface for running models

As described on the official competition page, making the code open-source is a prerequisite for winning prize money. In addition, the top 10 submissions/teams will be invited to become co-authors of a scientific paper that further stress-tests their models in a subsequent phase on many other datasets outside the Kaggle platform. **To enable such further analyses and re-use of the models by the community, we strongly encourage** participants to adhere to the code template provided in this repository, which enables a uniform interface for running models: [https://github.com/uio-bmi/predict-airr](https://github.com/uio-bmi/predict-airr)


Ideally, all methods can be run in the same unified way, e.g.:

`python3 -m submission.main --train_dir /path/to/train_dir --test_dirs /path/to/test_dir_1 /path/to/test_dir_2 --out_dir /path/to/output_dir --n_jobs 4 --device cpu`

## Adhering to code template on Kaggle Notebooks

Participants who use Kaggle resources and Kaggle notebooks to develop and run their code are also strongly encouraged to copy the code template from the provided repository — particularly the `ImmuneStatePredictor` class and any utility functions — and to adhere to it, so that different methods can be run in a unified way at a later stage. Below, we have copied the code template so that participants can paste it into their own Kaggle notebooks and edit it as needed.

In [None]:
## imports required for the basic code template below.

import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import glob
import sys
import argparse
from collections import defaultdict
from typing import Iterator, Tuple, Union, List

In [None]:
## some utility functions such as data loaders, etc.

def load_data_generator(data_dir: str, metadata_filename='metadata.csv') -> Iterator[
    Union[Tuple[str, pd.DataFrame, bool], Tuple[str, pd.DataFrame]]]:
    """
    Lazily yield the immune repertoires stored in ``data_dir``.

    Two modes are supported:
    1.  Metadata mode: when ``metadata_filename`` exists inside ``data_dir``,
        each of its rows names one repertoire file. Files listed there but
        missing on disk are reported and skipped.
    2.  Directory mode: otherwise every ``*.tsv`` file in ``data_dir`` is read
        in sorted order. Files that cannot be read are reported and skipped.

    Args:
        data_dir (str): Directory holding the repertoire files.
        metadata_filename (str): Name of the metadata CSV inside ``data_dir``.

    Yields:
        - Metadata mode: (repertoire_id, pd.DataFrame, label_positive)
        - Directory mode: (filename, pd.DataFrame)
    """
    metadata_path = os.path.join(data_dir, metadata_filename)

    if not os.path.exists(metadata_path):
        # Directory mode: no metadata, so discover files by extension.
        for tsv_path in sorted(glob.glob(os.path.join(data_dir, '*.tsv'))):
            try:
                base_name = os.path.basename(tsv_path)
                frame = pd.read_csv(tsv_path, sep='\t')
            except Exception as e:
                print(f"Warning: Could not read file '{tsv_path}'. Error: {e}. Skipping.")
                continue
            yield base_name, frame
        return

    # Metadata mode: iterate the rows in metadata order.
    for record in pd.read_csv(metadata_path).itertuples(index=False):
        try:
            frame = pd.read_csv(os.path.join(data_dir, record.filename), sep='\t')
        except FileNotFoundError:
            print(f"Warning: File '{record.filename}' listed in metadata not found. Skipping.")
            continue
        yield record.repertoire_id, frame, record.label_positive


def load_full_dataset(data_dir: str) -> pd.DataFrame:
    """
    Load every repertoire in ``data_dir`` and concatenate them into one DataFrame.

    Two scenarios are handled:
    1. If ``metadata.csv`` exists, repertoires are loaded in metadata order and
       every row gets an ``ID`` column (the metadata ``repertoire_id``) and a
       ``label_positive`` column.
    2. Otherwise all ``*.tsv`` files are loaded and every row gets an ``ID``
       column holding the file name with ``.tsv`` removed.

    Args:
        data_dir (str): The path to the data directory.

    Returns:
        pd.DataFrame: All loaded repertoire rows concatenated (fresh index), or
        an empty DataFrame (after printing a warning) when nothing was loaded.
    """
    metadata_path = os.path.join(data_dir, 'metadata.csv')
    df_list = []
    data_loader = load_data_generator(data_dir=data_dir)

    if os.path.exists(metadata_path):
        # Metadata is read here only to size the progress bar.
        total_files = len(pd.read_csv(metadata_path))
        for rep_id, data_df, label in tqdm(data_loader, total=total_files, desc="Loading files"):
            data_df['ID'] = rep_id
            data_df['label_positive'] = label
            df_list.append(data_df)
    else:
        total_files = len(glob.glob(os.path.join(data_dir, '*.tsv')))
        for filename, data_df in tqdm(data_loader, total=total_files, desc="Loading files"):
            # Same ID derivation as get_repertoire_ids() uses for metadata-less dirs.
            data_df['ID'] = os.path.basename(filename).replace(".tsv", "")
            df_list.append(data_df)

    if not df_list:
        print("Warning: No data files were loaded.")
        return pd.DataFrame()

    return pd.concat(df_list, ignore_index=True)


def load_and_encode_kmers(data_dir: str, k: int = 3) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load every repertoire in ``data_dir`` and encode it as raw k-mer counts.

    Each repertoire becomes one row holding the counts of all length-``k``
    substrings of its non-null ``junction_aa`` sequences; k-mers absent from a
    repertoire are 0.

    Args:
        data_dir: Path to data directory
        k: K-mer length

    Returns:
        Tuple of (encoded_features_df, metadata_df)
        encoded_features_df is indexed by 'ID' with one column per observed k-mer.
        metadata_df always contains 'ID', and 'label_positive' if available.
        Both are empty (after a printed warning) when no repertoire was loaded.
    """
    from collections import Counter

    # Decide the generator's tuple shape once rather than on every iteration.
    has_metadata = os.path.exists(os.path.join(data_dir, 'metadata.csv'))
    data_loader = load_data_generator(data_dir=data_dir)

    repertoire_ids = []
    repertoire_features = []
    metadata_records = []

    total_files = len(glob.glob(os.path.join(data_dir, '*.tsv')))

    for item in tqdm(data_loader, total=total_files, desc=f"Encoding {k}-mers"):
        if has_metadata:
            rep_id, data_df, label = item
        else:
            filename, data_df = item
            rep_id = os.path.basename(filename).replace(".tsv", "")
            label = None

        kmer_counts = Counter()
        for seq in data_df['junction_aa'].dropna():
            for i in range(len(seq) - k + 1):
                kmer_counts[seq[i:i + k]] += 1

        # IDs are kept separately: mixing them into the k-mer dict let the
        # amino-acid 2-mer "ID" overwrite the repertoire ID when k == 2.
        repertoire_ids.append(rep_id)
        repertoire_features.append(dict(kmer_counts))

        metadata_record = {'ID': rep_id}
        if label is not None:
            metadata_record['label_positive'] = label
        metadata_records.append(metadata_record)

        del data_df, kmer_counts

    if not repertoire_ids:
        print("Warning: No data files were loaded.")

    # fillna(0) turns k-mers missing from a repertoire into zero counts.
    features_df = pd.DataFrame(repertoire_features,
                               index=pd.Index(repertoire_ids, name='ID')).fillna(0)
    metadata_df = pd.DataFrame(metadata_records) if metadata_records else pd.DataFrame(columns=['ID'])

    return features_df, metadata_df


def save_tsv(df: pd.DataFrame, path: str):
    """Write ``df`` to ``path`` as a tab-separated file without the index.

    Missing parent directories are created first. A bare file name (no
    directory component) is written relative to the current working
    directory; ``os.makedirs('')`` would raise, so directory creation is
    skipped in that case.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    df.to_csv(path, sep='\t', index=False)


def get_repertoire_ids(data_dir: str) -> list:
    """
    Return the repertoire identifiers available in ``data_dir``.

    When ``metadata.csv`` exists, the identifiers are its ``repertoire_id``
    column in file order. Otherwise they are the names of the ``*.tsv`` files
    (sorted) with ``.tsv`` removed.

    Args:
        data_dir (str): The path to the data directory.

    Returns:
        list: A list of repertoire IDs.
    """
    metadata_path = os.path.join(data_dir, 'metadata.csv')
    if os.path.exists(metadata_path):
        return pd.read_csv(metadata_path)['repertoire_id'].tolist()

    tsv_paths = sorted(glob.glob(os.path.join(data_dir, '*.tsv')))
    return [os.path.basename(p).replace('.tsv', '') for p in tsv_paths]


def generate_random_top_sequences_df(n_seq: int = 50000) -> pd.DataFrame:
    """
    Generate a random DataFrame simulating top important sequences.

    Each sequence is a unique random 15-residue string over the 20 standard
    amino-acid letters, drawn with NumPy's global RNG. Sequences keep the order
    in which they were first drawn, so output is reproducible after
    ``np.random.seed(...)``. A ``set`` would reorder them by string hash, which
    is randomized per interpreter run.

    Args:
        n_seq (int): Number of sequences to generate.

    Returns:
        pd.DataFrame: Columns 'junction_aa', 'v_call' (always 'TRBV20-1'),
        'j_call' (always 'TRBJ2-7') and 'importance_score' (uniform in [0, 1)).
    """
    alphabet = list('ACDEFGHIKLMNPQRSTVWY')
    seqs = {}  # dict as an insertion-ordered set
    while len(seqs) < n_seq:
        seq = ''.join(np.random.choice(alphabet, size=15))
        seqs[seq] = None
    data = {
        'junction_aa': list(seqs),
        'v_call': ['TRBV20-1'] * n_seq,
        'j_call': ['TRBJ2-7'] * n_seq,
        'importance_score': np.random.rand(n_seq)
    }
    return pd.DataFrame(data)


def validate_dirs_and_files(train_dir: str, test_dirs: List[str], out_dir: str) -> None:
    """
    Check that the input directories are usable and the output directory is writable.

    Args:
        train_dir: Must be a directory containing at least one ``*.tsv`` file
            and a ``metadata.csv``.
        test_dirs: Each must be a directory containing at least one ``*.tsv`` file.
        out_dir: Created if needed; a temporary file is written and removed to
            verify write permission.

    Raises:
        AssertionError: If any input check fails. It is raised explicitly
            rather than via ``assert`` so the checks still run under ``python -O``.

    If ``out_dir`` cannot be created or written to, a message is printed and
    the process exits with ``sys.exit(1)``.
    """
    def _require(condition: bool, message: str) -> None:
        # Same exception type the previous bare ``assert`` statements produced.
        if not condition:
            raise AssertionError(message)

    _require(os.path.isdir(train_dir), f"Train directory `{train_dir}` does not exist.")
    train_tsvs = glob.glob(os.path.join(train_dir, "*.tsv"))
    _require(bool(train_tsvs), f"No .tsv files found in train directory `{train_dir}`.")
    metadata_path = os.path.join(train_dir, "metadata.csv")
    _require(os.path.isfile(metadata_path), f"`metadata.csv` not found in train directory `{train_dir}`.")

    for test_dir in test_dirs:
        _require(os.path.isdir(test_dir), f"Test directory `{test_dir}` does not exist.")
        test_tsvs = glob.glob(os.path.join(test_dir, "*.tsv"))
        _require(bool(test_tsvs), f"No .tsv files found in test directory `{test_dir}`.")

    try:
        os.makedirs(out_dir, exist_ok=True)
        test_file = os.path.join(out_dir, "test_write_permission.tmp")
        with open(test_file, "w") as f:
            f.write("test")
        os.remove(test_file)
    except Exception as e:
        print(f"Failed to create or write to output directory `{out_dir}`: {e}")
        sys.exit(1)


def concatenate_output_files(out_dir: str) -> None:
    """
    Concatenate all prediction and important-sequence TSVs into ``submissions.csv``.

    Files matching ``*_test_predictions.tsv`` are read first, then files
    matching ``*_important_sequences.tsv``, each group in sorted order. Files
    that cannot be read are reported and skipped. The concatenated table is
    written (comma-separated, no index) to ``<out_dir>/submissions.csv``.
    When no file could be read, an empty table with the columns
    ['ID', 'dataset', 'label_positive_probability', 'junction_aa', 'v_call', 'j_call']
    is written instead.

    Args:
        out_dir (str): Path to the output directory containing the TSV files.

    Returns:
        None. The result is only written to disk.
    """
    predictions_files = sorted(glob.glob(os.path.join(out_dir, '*_test_predictions.tsv')))
    sequences_files = sorted(glob.glob(os.path.join(out_dir, '*_important_sequences.tsv')))

    df_list = []
    # Predictions must precede important sequences in the submission file.
    for kind, files in (('predictions', predictions_files), ('sequences', sequences_files)):
        for path in files:
            try:
                df_list.append(pd.read_csv(path, sep='\t'))
            except Exception as e:
                print(f"Warning: Could not read {kind} file '{path}'. Error: {e}. Skipping.")

    if not df_list:
        print("Warning: No output files were found to concatenate.")
        concatenated_df = pd.DataFrame(
            columns=['ID', 'dataset', 'label_positive_probability', 'junction_aa', 'v_call', 'j_call'])
    else:
        concatenated_df = pd.concat(df_list, ignore_index=True)
    submissions_file = os.path.join(out_dir, 'submissions.csv')
    concatenated_df.to_csv(submissions_file, index=False)
    print(f"Concatenated output written to `{submissions_file}`.")


def get_dataset_pairs(train_dir: str, test_dir: str) -> List[Tuple[str, List[str]]]:
    """Match each training dataset directory with the test dataset directories that belong to it.

    Test entries named ``test_dataset_<id>[_<suffix>]`` are grouped by ``<id>`` (the text after
    the prefix up to the first underscore). Every ``train_dataset_<id>`` entry is then paired with
    that group, or with an empty list when no test entry shares its id.

    Args:
        train_dir: Directory whose ``train_dataset_*`` entries are enumerated (sorted by name).
        test_dir: Directory whose ``test_dataset_*`` entries are enumerated (sorted by name).

    Returns:
        A list of ``(train_path, [test_paths])`` tuples in sorted training-name order.
    """
    train_prefix = "train_dataset_"
    test_prefix = "test_dataset_"

    tests_by_id = defaultdict(list)
    test_entries = (name for name in sorted(os.listdir(test_dir)) if name.startswith(test_prefix))
    for entry in test_entries:
        group_key = entry.replace(test_prefix, "").split("_")[0]
        tests_by_id[group_key].append(os.path.join(test_dir, entry))

    # .get() (not []) so that unmatched training ids do not insert keys into the defaultdict
    return [
        (os.path.join(train_dir, entry), tests_by_id.get(entry.replace(train_prefix, ""), []))
        for entry in sorted(os.listdir(train_dir))
        if entry.startswith(train_prefix)
    ]

In [None]:
## Main ImmuneStatePredictor class: participants fill in their implementations within the marked placeholders
## and replace the example code lines with their own working logic.


class ImmuneStatePredictor:
    """
    A template for predicting immune states from TCR repertoire data.

    Participants should implement the logic for training, prediction, and
    sequence identification within this class.
    """

    def __init__(self, n_jobs: int = 1, device: str = 'cpu', **kwargs):
        """
        Initializes the predictor.

        Args:
            n_jobs (int): Number of CPU cores to use for parallel processing.
                -1 selects all available cores; any other value is clamped to
                the range [1, available cores].
            device (str): The device to use for computation (e.g., 'cpu', 'cuda').
                A 'cuda' request falls back to 'cpu' when PyTorch is not installed
                or reports no available CUDA device.
            **kwargs: Additional hyperparameters for the model.
        """
        # os.cpu_count() returns None when the number of cores cannot be determined.
        total_cores = os.cpu_count() or 1
        if n_jobs == -1:
            self.n_jobs = total_cores
        else:
            self.n_jobs = max(1, min(n_jobs, total_cores))
        self.device = device
        if device == 'cuda' and not self._cuda_available():
            print("Warning: 'cuda' was requested but is not available. Falling back to 'cpu'.")
            self.device = 'cpu'
        # --- your code starts here ---
        # Example: Store hyperparameters, the actual model, identified important sequences, etc.

        # NOTE: we encourage you to use self.n_jobs and self.device if appropriate in
        # your implementation instead of hardcoding these values because your code may later be run in an
        # environment with different hardware resources.

        self.model = None
        self.important_sequences_ = None
        # --- your code ends here ---

    @staticmethod
    def _cuda_available() -> bool:
        """Return True only if PyTorch can be imported and reports an available CUDA device."""
        try:
            import torch  # imported lazily so CPU-only environments do not need PyTorch
        except ImportError:
            return False
        return torch.cuda.is_available()

    @staticmethod
    def _dataset_name(dir_path: str) -> str:
        """Return the last path component of `dir_path`, ignoring any trailing separator."""
        # Without normpath, basename('.../train_dataset_1/') would be ''.
        return os.path.basename(os.path.normpath(dir_path))

    def fit(self, train_dir_path: str):
        """
        Trains the model on the provided training data.

        Args:
            train_dir_path (str): Path to the directory with training TSV files.

        Returns:
            self: The fitted predictor instance.
        """

        # --- your code starts here ---
        # Load the data, prepare suited representations as needed, train your model,
        # and find the top k important sequences that best explain the labels.
        # Example: Load the data. One possibility could be to use the provided utility function as shown below.

        # full_train_dataset_df = load_full_dataset(train_dir_path)

        #   Model Training
        #    Example: self.model = SomeClassifier().fit(X_train, y_train)
        self.model = "some trained model"  # Replace with your actual learnt model

        #   Identify important sequences (can be done here or in the dedicated method)
        #    Example:
        self.important_sequences_ = self.identify_associated_sequences(
            top_k=50000, dataset_name=self._dataset_name(train_dir_path))

        # --- your code ends here ---
        print("Training complete.")
        return self

    def predict_proba(self, test_dir_path: str) -> pd.DataFrame:
        """
        Predicts probabilities for examples in the provided path.

        Args:
            test_dir_path (str): Path to the directory with test TSV files.

        Returns:
            pd.DataFrame: A DataFrame with 'ID', 'dataset', 'label_positive_probability', 'junction_aa', 'v_call', 'j_call' columns.
                The sequence columns hold the placeholder -999.0 for prediction rows.

        Raises:
            RuntimeError: If `fit` has not been called yet.
        """
        print(f"Making predictions for data in {test_dir_path}...")
        if self.model is None:
            raise RuntimeError("The model has not been fitted yet. Please call `fit` first.")

        # --- your code starts here ---

        # Example: Load the data. One possibility could be to use the provided utility function as shown below.

        # full_test_dataset_df = load_full_dataset(test_dir_path)
        repertoire_ids = get_repertoire_ids(test_dir_path)  # Replace with actual repertoire IDs from the test data

        # Prediction
        #    Example:
        # draw random probabilities for demonstration purposes

        probabilities = np.random.rand(len(repertoire_ids)) # Replace with true predicted probabilities from your model

        # --- your code ends here ---

        dataset_name = self._dataset_name(test_dir_path)
        predictions_df = pd.DataFrame({
            'ID': repertoire_ids,
            'dataset': [dataset_name] * len(repertoire_ids),
            'label_positive_probability': probabilities
        })

        # to enable compatibility with the expected output format that includes junction_aa, v_call, j_call columns
        predictions_df['junction_aa'] = -999.0
        predictions_df['v_call'] = -999.0
        predictions_df['j_call'] = -999.0

        predictions_df = predictions_df[['ID', 'dataset', 'label_positive_probability', 'junction_aa', 'v_call', 'j_call']]

        print(f"Prediction complete on {len(repertoire_ids)} examples in {test_dir_path}.")
        return predictions_df

    def identify_associated_sequences(self, dataset_name: str, top_k: int = 50000) -> pd.DataFrame:
        """
        Identifies the top "k" important sequences (rows) from the training data that best explain the labels.

        Args:
            dataset_name (str): Name of the training dataset; stored in the 'dataset' column and
                used as the prefix of each generated 'ID' ('<dataset_name>_seq_top_<rank>').
            top_k (int): The number of top sequences to return (based on some scoring mechanism).

        Returns:
            pd.DataFrame: A DataFrame with 'ID', 'dataset', 'label_positive_probability', 'junction_aa', 'v_call', 'j_call' columns,
                ordered by descending 'importance_score'; 'label_positive_probability' holds the placeholder -999.0.
        """

        # --- your code starts here ---
        
        # Return the top k sequences, sorted based on some form of importance score.
        # Example:
        # all_sequences_scored = self._score_all_sequences()
        
        all_sequences_scored = generate_random_top_sequences_df(n_seq=top_k)  # Replace with your way of identifying top k sequences

        # note that all_sequences_scored should contain a 'importance_score' column that will be used further below
        
        # --- your code ends here ---

        top_sequences_df = all_sequences_scored.nlargest(top_k, 'importance_score')
        # .copy() so the column assignments below modify an independent frame, not a slice
        top_sequences_df = top_sequences_df[['junction_aa', 'v_call', 'j_call']].copy()
        top_sequences_df['dataset'] = dataset_name
        top_sequences_df['ID'] = [f"{dataset_name}_seq_top_{rank}" for rank in range(1, len(top_sequences_df) + 1)]
        top_sequences_df['label_positive_probability'] = -999.0 # to enable compatibility with the expected output format
        top_sequences_df = top_sequences_df[['ID', 'dataset', 'label_positive_probability', 'junction_aa', 'v_call', 'j_call']]

        return top_sequences_df

In [None]:
## The `main` workflow that uses your implementation of the ImmuneStatePredictor class to train, identify important sequences and predict test labels


def _train_predictor(predictor: ImmuneStatePredictor, train_dir: str):
    """Announce and run ``predictor.fit`` on the training directory ``train_dir``."""
    announcement = f"Fitting model on examples in ` {train_dir} `..."
    print(announcement)
    predictor.fit(train_dir)


def _generate_predictions(predictor: ImmuneStatePredictor, test_dirs: List[str]) -> pd.DataFrame:
    """Run ``predictor.predict_proba`` on every test directory and stack the results.

    Directories whose prediction is None or an empty frame are reported and left out.

    Returns:
        The row-wise concatenation (fresh integer index) of all non-empty predictions,
        or an empty DataFrame when there are none.
    """
    collected = []
    for directory in test_dirs:
        print(f"Predicting on examples in ` {directory} `...")
        frame = predictor.predict_proba(directory)
        if frame is None or frame.empty:
            print(f"Warning: No predictions returned for {directory}")
            continue
        collected.append(frame)
    return pd.concat(collected, ignore_index=True) if collected else pd.DataFrame()


def _save_predictions(predictions: pd.DataFrame, out_dir: str, train_dir: str) -> None:
    """Saves predictions to a TSV file."""
    if predictions.empty:
        raise ValueError("No predictions to save - predictions DataFrame is empty")

    preds_path = os.path.join(out_dir, f"{os.path.basename(train_dir)}_test_predictions.tsv")
    save_tsv(predictions, preds_path)
    print(f"Predictions written to `{preds_path}`.")


def _save_important_sequences(predictor: ImmuneStatePredictor, out_dir: str, train_dir: str) -> None:
    """Saves the predictor's important sequences to `<out_dir>/<train dataset name>_important_sequences.tsv`.

    Args:
        predictor: A fitted predictor whose `important_sequences_` attribute holds the sequences frame.
        out_dir: Directory the TSV file is written into.
        train_dir: Training dataset directory; its last path component names the output file.

    Raises:
        ValueError: If `important_sequences_` is None or empty.
    """
    seqs = predictor.important_sequences_
    if seqs is None or seqs.empty:
        raise ValueError("No important sequences available to save")

    # normpath drops a trailing separator (e.g. from shell tab-completion) so basename() is not ''
    dataset_name = os.path.basename(os.path.normpath(train_dir))
    seqs_path = os.path.join(out_dir, f"{dataset_name}_important_sequences.tsv")
    save_tsv(seqs, seqs_path)
    print(f"Important sequences written to `{seqs_path}`.")


def main(train_dir: str, test_dirs: List[str], out_dir: str, n_jobs: int, device: str) -> None:
    """Run the whole pipeline for one training dataset.

    Steps: validate the input/output locations, fit a fresh ImmuneStatePredictor on
    ``train_dir``, predict every directory in ``test_dirs``, then write the predictions
    TSV followed by the important-sequences TSV into ``out_dir``.
    """
    validate_dirs_and_files(train_dir, test_dirs, out_dir)
    # Pass any extra constructor parameters you add to ImmuneStatePredictor here.
    predictor = ImmuneStatePredictor(n_jobs=n_jobs, device=device)
    _train_predictor(predictor, train_dir)
    _save_predictions(_generate_predictions(predictor, test_dirs), out_dir, train_dir)
    _save_important_sequences(predictor, out_dir, train_dir)


def run(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: parse arguments and call `main` for one training dataset.

    Args:
        argv: Arguments to parse. When None (the default), argparse reads ``sys.argv[1:]``.
            Passing an explicit list allows the CLI to be driven from a notebook, where
            ``sys.argv`` holds the kernel's own arguments rather than these options.
    """
    # Imported locally so run() works even if no earlier cell imported argparse.
    import argparse

    parser = argparse.ArgumentParser(description="Immune State Predictor CLI")
    parser.add_argument("--train_dir", required=True, help="Path to training data directory")
    parser.add_argument("--test_dirs", required=True, nargs="+", help="Path(s) to test data director(ies)")
    parser.add_argument("--out_dir", required=True, help="Path to output directory")
    parser.add_argument("--n_jobs", type=int, default=1,
                        help="Number of CPU cores to use. Use -1 for all available cores.")
    parser.add_argument("--device", type=str, default='cpu', choices=['cpu', 'cuda'],
                        help="Device to use for computation ('cpu' or 'cuda').")
    args = parser.parse_args(argv)
    main(args.train_dir, args.test_dirs, args.out_dir, args.n_jobs, args.device)


In [None]:
# Kaggle locations of the competition data and of this notebook's outputs.
_competition_root = "/kaggle/input/adaptive-immune-profiling-challenge-2025"
train_datasets_dir = f"{_competition_root}/train_datasets/train_datasets"
test_datasets_dir = f"{_competition_root}/test_datasets/test_datasets"
results_dir = "/kaggle/working/results"

# Pair each training dataset with its test dataset(s), then run the full pipeline per pair.
train_test_dataset_pairs = get_dataset_pairs(train_datasets_dir, test_datasets_dir)

for train_dir, test_dirs in train_test_dataset_pairs:
    main(
        train_dir=train_dir,
        test_dirs=test_dirs,
        out_dir=results_dir,
        n_jobs=4,
        device="cpu",
    )

In [None]:
# Merge every per-dataset predictions / important-sequences TSV in results_dir into submissions.csv.
concatenate_output_files(results_dir)