# Introduction to HR Breast Cancer RNA Sequencing Analysis

This notebook analyzes single-cell RNA sequencing data from patients with hormone receptor-positive (HR+) breast cancer. The dataset (GEO accession GSE300475) contains gene expression profiles and T-cell receptor (TCR) sequences from responders and non-responders to immunotherapy.

## Objectives
- Process and integrate single-cell RNA-seq and TCR data
- Perform unsupervised clustering to identify cell populations
- Develop supervised models to predict treatment response
- Analyze sequence patterns and gene expression signatures
- Identify cluster-specific markers and associations with clinical outcomes

## Methodology
- Data loading and quality control using Scanpy
- Sequence encoding (one-hot, k-mer, physicochemical) for TCR data
- Dimensionality reduction (PCA, UMAP, t-SNE) for gene expression
- Unsupervised clustering (K-Means, HDBSCAN, etc.)
- Supervised classification with cross-validation
- Interpretation of clusters and biomarkers

## 1. Data Acquisition and Loading

In [None]:
import pandas as pd
import requests
import os
import tarfile
from io import BytesIO

In [None]:
# GEO supplementary files downloaded for series GSE300475. Every entry carries
# the local file name, a human-readable size, the download URL and a short
# description of the file type (in that key order).
_fetch_fields = ("name", "size", "download_url", "type")
files_to_fetch = [
    dict(zip(_fetch_fields, entry))
    for entry in (
        (
            "GSE300475_RAW.tar",
            "565.5 Mb",
            "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE300475&format=file",
            "TAR (of CSV, MTX, TSV)",
        ),
        (
            "GSE300475_feature_ref.xlsx",
            "5.4 Kb",
            "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE300nnn/GSE300475/suppl/GSE300475%5Ffeature%5Fref.xlsx",
            "XLSX",
        ),
    )
]

In [None]:
download_dir = "../Data"
os.makedirs(download_dir, exist_ok=True)
print(f"Downloads will be saved in: {os.path.abspath(download_dir)}\n")

def download_file(url, filename, destination_folder, timeout=60):
    """
    Download ``url`` to ``destination_folder/filename``, streaming it in chunks.

    The data is first written to ``<filename>.part`` and only renamed to the
    final name once the whole response has been received. A failed or
    interrupted download therefore never leaves a truncated file under the
    final name, and it never clobbers a previous complete copy.

    Args:
        url: Source URL.
        filename: Name of the file to create inside ``destination_folder``.
        destination_folder: Existing directory that receives the file.
        timeout: Seconds passed to ``requests.get`` as connect/read timeout.
            Without it a stalled server would block the notebook forever.
            ``None`` disables the timeout.

    Returns:
        The path of the downloaded file, or ``None`` if the request failed.
    """
    filepath = os.path.join(destination_folder, filename)
    partial_path = filepath + ".part"
    print(f"Attempting to download {filename} from {url}...")
    try:
        # The context manager releases the streamed connection even on error.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomically publish the completed download under its final name.
        os.replace(partial_path, filepath)
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {filename}: {e}")
        if os.path.exists(partial_path):
            os.remove(partial_path)
        return None
    print(f"Successfully downloaded {filename} to {filepath}")
    return filepath

In [None]:
# Download every file listed in files_to_fetch; TAR archives are also listed
# and extracted into a sibling folder named after the archive.
for file_info in files_to_fetch:
    filename = file_info["name"]
    url = file_info["download_url"]

    downloaded_filepath = download_file(url, filename, download_dir)

    # If the file is a TAR archive, list its contents and extract it
    if downloaded_filepath and filename.endswith(".tar"):
        print(f"Extracting {filename}...\n")
        try:
            with tarfile.open(downloaded_filepath, "r") as tar:
                members = tar.getnames()
                print(f"Files contained in {filename}:")
                for member in members:
                    print(f" - {member}")

                # Extract into <download_dir>/<archive name without ".tar">
                extract_path = os.path.join(download_dir, filename[: -len(".tar")])
                os.makedirs(extract_path, exist_ok=True)
                # The archive comes from the network: use the 'data' extraction
                # filter (rejects absolute paths, '..' traversal, links leaving
                # the target and special files) when this Python supports it.
                # Filter rejections raise tarfile.FilterError, a TarError subclass.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(path=extract_path, filter="data")
                else:
                    tar.extractall(path=extract_path)
                print(f"\nExtracted to: {extract_path}")
        except tarfile.TarError as e:
            print(f"Error extracting {filename}: {e}")

    # Separator after every file, not only after archives.
    print("-" * 50 + "\n")

In [None]:
import gzip
import shutil
from pathlib import Path
import pandas as pd
from scipy.io import mmread

def decompress_gz_file(gz_path, output_dir):
    """
    Decompress a ``.gz`` file into ``output_dir``.

    The output name is the input file name without its final suffix
    (``x.mtx.gz`` -> ``x.mtx``). Data is written to ``<output>.part`` first
    and renamed only after decompression succeeds. A corrupt or truncated
    archive therefore never leaves a partial output behind, which later globs
    such as ``*_all_contig_annotations.csv`` would otherwise pick up.

    Args:
        gz_path: Path to the gzip-compressed file.
        output_dir: Directory that receives the decompressed copy.

    Returns:
        Path of the decompressed file, or ``None`` if decompression failed
        (the error is printed, not raised).
    """
    output_path = os.path.join(output_dir, Path(gz_path).stem)
    partial_path = output_path + ".part"
    print(f"Decompressing {gz_path} → {output_path}")
    try:
        with gzip.open(gz_path, 'rb') as f_in, open(partial_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(partial_path, output_path)
        return output_path
    except Exception as e:
        print(f"Failed to decompress {gz_path}: {e}")
        if os.path.exists(partial_path):
            os.remove(partial_path)
        return None

def preview_file(file_path, n_rows=5):
    """
    Print a short preview of a decompressed file, chosen by its extension.

    * ``.tsv``: the first ``n_rows`` rows, read without a header row.
      10x-style ``barcodes.tsv`` / ``features.tsv`` tables have no header, so
      the default parsing would promote the first barcode/gene to a column name.
    * ``.csv``: the first ``n_rows`` data rows below the header.
    * ``.mtx``: the first ``n_rows`` stored entries of the sparse matrix, plus
      its shape and number of stored entries.

    Only ``n_rows`` lines of CSV/TSV files are parsed, so large tables are not
    read in full just to show a preview. Any error is printed, not raised.

    Args:
        file_path: Path (str or ``os.PathLike``) of the file to preview.
        n_rows: Number of rows/entries to show (default 5).
    """
    path_str = os.fspath(file_path)
    print(f"\n--- Preview of {os.path.basename(path_str)} ---")
    try:
        if path_str.endswith(".tsv"):
            print(pd.read_csv(path_str, sep='\t', header=None, nrows=n_rows))
        elif path_str.endswith(".csv"):
            print(pd.read_csv(path_str, nrows=n_rows))
        elif path_str.endswith(".mtx"):
            matrix = mmread(path_str).tocoo()
            print(f"First {n_rows} non-zero entries:")
            for r, c, v in zip(matrix.row[:n_rows], matrix.col[:n_rows], matrix.data[:n_rows]):
                print(f"Row: {r}, Col: {c}, Value: {v}")
            print(f"\nMatrix shape: {matrix.shape}, NNZ (non-zero elements): {matrix.nnz}")
        else:
            print("Unsupported file type for preview.")
    except Exception as e:
        print(f"Could not preview {path_str}: {e}")

# Walk the extracted archive folder, decompress every gzipped file next to
# its original, and preview each successfully decompressed result.
extract_dir = os.path.join(download_dir, "GSE300475_RAW")

for dirpath, _subdirs, filenames in os.walk(extract_dir):
    gz_names = [name for name in filenames if name.endswith(".gz")]
    for gz_name in gz_names:
        result = decompress_gz_file(os.path.join(dirpath, gz_name), dirpath)
        if result:
            preview_file(result)

In [None]:
import glob

# Count the rows of every decompressed "*_all_contig_annotations.csv" file in
# the extraction folder and report per-file and overall totals.
# Unreadable files are reported and skipped.
all_contig_files = glob.glob(os.path.join(extract_dir, "*_all_contig_annotations.csv"))
total_rows = 0

for contig_csv in all_contig_files:
    try:
        row_count = pd.read_csv(contig_csv).shape[0]
    except Exception as err:
        print(f"Could not read {contig_csv}: {err}")
        continue
    print(f"{os.path.basename(contig_csv)}: {row_count} rows")
    total_rows += row_count

print(f"\nTotal rows in all contig annotation files: {total_rows}")

## 2. Data Preprocessing and Quality Control

In [None]:
%pip install scanpy pandas numpy

In [None]:
# Import required libraries for single-cell RNA-seq analysis and data handling
import scanpy as sc  # Main library for single-cell analysis, provides AnnData structure and many tools
import pandas as pd  # For tabular data manipulation and metadata handling
import numpy as np   # For numerical operations and array handling
import os            # For operating system interactions (file paths, etc.)
from pathlib import Path  # For robust and readable file path management

# Record library versions so the analysis can be reproduced later.
for _lib_label, _lib_module in (("Scanpy", sc), ("Pandas", pd)):
    print(f"{_lib_label} version: {_lib_module.__version__}")

from collections import Counter
import warnings
# NOTE: with no category this silences *every* warning for the rest of the
# session, including library deprecation notices; drop it when debugging.
warnings.filterwarnings('ignore')

print("All libraries imported successfully!")

### Load Sample Metadata

In [None]:
%%time
# --- Setup data paths ---
# Root data folder and the subfolder holding the extracted GEO raw files.
data_dir = Path('../Data')
raw_data_dir = data_dir / 'GSE300475_RAW'

# --- Manually curated sample sheet ---
# One row per sample: GEO accessions of its GEX and TCR libraries, patient,
# timepoint, response status, and availability flags. S8 (GSM9061672) has GEX
# files but no matching TCR library, so its TCR_Sample_ID is None.
_metadata_columns = ['S_Number', 'GEX_Sample_ID', 'TCR_Sample_ID', 'Patient_ID',
                     'Timepoint', 'Response', 'In_Data', 'In_Article']
_metadata_rows = [
    # Patient 1 (Responder)
    ('S1',  'GSM9061665', 'GSM9061687', 'PT1',  'Baseline',   'Responder',     'Yes',      'Yes'),
    ('S2',  'GSM9061666', 'GSM9061688', 'PT1',  'Post-Chemo', 'Responder',     'Yes',      'Yes'),
    # Patient 2 (Non-Responder)
    ('S3',  'GSM9061667', 'GSM9061689', 'PT2',  'Baseline',   'Non-Responder', 'Yes',      'Yes'),
    ('S4',  'GSM9061668', 'GSM9061690', 'PT2',  'Post-Chemo', 'Non-Responder', 'Yes',      'Yes'),
    # Patient 3 (Responder)
    ('S5',  'GSM9061669', 'GSM9061691', 'PT3',  'Baseline',   'Responder',     'Yes',      'Yes'),
    ('S6',  'GSM9061670', 'GSM9061692', 'PT3',  'Post-Chemo', 'Responder',     'Yes',      'Yes'),
    # Patient 4 (Non-Responder)
    ('S7',  'GSM9061671', 'GSM9061693', 'PT4',  'Baseline',   'Non-Responder', 'Yes',      'Yes'),
    # Patient 5 (partial) - S8 exists as GEX only in the raw data but has no TCR file
    ('S8',  'GSM9061672', None,         'PT5',  'Unknown',    'Unknown',       'GEX only', 'Yes'),
    ('S9',  'GSM9061673', 'GSM9061694', 'PT5',  'Baseline',   'Responder',     'Yes',      'Yes'),
    ('S10', 'GSM9061674', 'GSM9061695', 'PT5',  'Post-ICI',   'Responder',     'Yes',      'Yes'),
    # Patient 11 (Responder)
    ('S11', 'GSM9061675', 'GSM9061696', 'PT11', 'Endpoint',   'Responder',     'Yes',      'Yes'),
]
metadata_list = [dict(zip(_metadata_columns, values)) for values in _metadata_rows]

# --- Create DataFrame and display the verification table ---
metadata_df = pd.DataFrame(metadata_list)
print("Metadata table now matches the requested specification:")
display(metadata_df)

# --- Programmatic sanity-check for file presence ---
# Recompute 'In_Data' for every sample from the files actually on disk,
# accepting either the gzipped or the uncompressed copy of each file:
#   'Yes'      -> GEX matrix and TCR contig table present
#   'GEX only' -> GEX matrix present, TCR table missing (or no TCR sample)
#   'No'       -> GEX matrix missing
def _present(stem):
    """Return True if ``raw_data_dir/stem`` exists gzipped or uncompressed."""
    return (raw_data_dir / f"{stem}.gz").exists() or (raw_data_dir / stem).exists()


for idx, row in metadata_df.iterrows():
    s_num = row['S_Number']
    tcr_id = row['TCR_Sample_ID']
    gex_present = _present(f"{row['GEX_Sample_ID']}_{s_num}_matrix.mtx")
    # pd.notna(None) is False, so samples without a TCR library are covered.
    tcr_present = pd.notna(tcr_id) and _present(f"{tcr_id}_{s_num}_all_contig_annotations.csv")

    if not gex_present:
        status = 'No'
    elif tcr_present:
        status = 'Yes'
    else:
        status = 'GEX only'
    metadata_df.at[idx, 'In_Data'] = status

print("\nPost-check In_Data column (based on files found in Data/GSE300475_RAW):")
display(metadata_df)

### Process and Concatenate AnnData Objects

In [None]:
%%time
# --- Load every sample listed in metadata_df ---
# For each sample: read its 10x gene-expression matrix into an AnnData object,
# tag the cells with the sample's clinical metadata and, when the sample has a
# TCR library, read its all_contig_annotations table.
adata_list = []      # one AnnData per sample with GEX data
tcr_data_list = []   # one contig DataFrame per sample with TCR data


def _first_existing(*candidates):
    """Return the first of the given paths that exists, or None."""
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


for _, row in metadata_df.iterrows():
    gex_sample_id = row['GEX_Sample_ID']
    tcr_sample_id = row['TCR_Sample_ID']
    s_number = row['S_Number']

    # Raw files are named "<GSM id>_<S number>_<10x file name>".
    sample_prefix = f"{gex_sample_id}_{s_number}"
    matrix_file = _first_existing(
        raw_data_dir / f"{sample_prefix}_matrix.mtx.gz",
        raw_data_dir / f"{sample_prefix}_matrix.mtx",
    )
    if matrix_file is None:
        print(f"GEX data not found for sample {sample_prefix}, skipping.")
        continue

    print(f"Processing GEX sample: {sample_prefix}")
    # `prefix` restricts read_10x_mtx to this sample's files in the shared folder.
    # NOTE(review): depending on the scanpy version, read_10x_mtx may only look
    # for the gzipped files; the uncompressed fallback above then only affects
    # the existence check -- verify.
    adata_sample = sc.read_10x_mtx(
        raw_data_dir,
        var_names='gene_symbols',
        prefix=f"{sample_prefix}_",
    )

    # --- Add sample metadata to AnnData.obs ---
    adata_sample.obs['sample_id'] = gex_sample_id
    adata_sample.obs['patient_id'] = row['Patient_ID']
    adata_sample.obs['timepoint'] = row['Timepoint']
    adata_sample.obs['response'] = row['Response']
    adata_list.append(adata_sample)

    # --- Load TCR data if available (pd.isna also covers None) ---
    if pd.isna(tcr_sample_id):
        print(f"No TCR sample for {gex_sample_id}_{s_number}, skipping TCR load.")
        continue

    tcr_stem = f"{tcr_sample_id}_{s_number}_all_contig_annotations.csv"
    tcr_file_path = _first_existing(raw_data_dir / f"{tcr_stem}.gz", raw_data_dir / tcr_stem)
    if tcr_file_path is None:
        print(f"TCR data not found for {tcr_sample_id}_{s_number}")
        continue

    print(f"Found and loading TCR data: {tcr_file_path.name}")
    tcr_df = pd.read_csv(tcr_file_path)
    # Key contigs by the GEX sample id so they can be joined to adata.obs later.
    tcr_df['sample_id'] = gex_sample_id
    tcr_data_list.append(tcr_df)

# --- Concatenate all loaded AnnData objects into one ---
if adata_list:
    loaded_batches = [a.obs['sample_id'].unique()[0] for a in adata_list]
    # sc.concat replaces the deprecated AnnData.concatenate:
    # * label/keys write each cell's sample id into obs['sample_id'];
    # * index_unique='-' makes barcodes unique as '<barcode>-<sample id>';
    # * merge='same' keeps var columns that are identical in every sample.
    adata = sc.concat(
        adata_list,
        join='outer',
        label='sample_id',
        keys=loaded_batches,
        index_unique='-',
        merge='same',
    )
    print("\nConcatenated AnnData object:")
    print(adata)
else:
    print("No data was loaded.")

# --- Concatenate all loaded TCR dataframes into one ---
if tcr_data_list:
    full_tcr_df = pd.concat(tcr_data_list, ignore_index=True)
    print("\nFull TCR data:")
    display(full_tcr_df.head())
else:
    print("No TCR data was loaded.")

### Integrate TCR Data and Perform QC

In [None]:
%%time
# --- Integrate TCR data into AnnData.obs and perform quality control ---

def _is_true(col):
    """Element-wise truthiness of a contig-table boolean column.

    Compares the lower-cased string form against 'true'. The filter therefore
    works whether pandas parsed the column as bool or left it as strings
    (e.g. 'True'/'False'/'None'). Anything else counts as False.
    """
    return col.astype(str).str.strip().str.lower() == 'true'


if 'full_tcr_df' in locals() and not full_tcr_df.empty:
    # One cell (barcode) can have several TCR contigs (e.g. TRA and TRB chains).
    # The contig table is therefore aggregated to one row per cell first;
    # merging it un-aggregated would be one-to-many and duplicate cells.

    # 1. Keep only high-confidence, productive TRA/TRB contigs.
    tcr_to_agg = full_tcr_df[
        _is_true(full_tcr_df['high_confidence'])
        & _is_true(full_tcr_df['productive'])
        & full_tcr_df['chain'].isin(['TRA', 'TRB'])
    ].copy()

    # 2. One row per (sample_id, barcode) with separate columns per chain.
    # NOTE: a cell can carry two productive chains of the same type (e.g. dual
    # TRA); aggfunc='first' keeps only the first one listed for that cell.
    tcr_aggregated = tcr_to_agg.pivot_table(
        index=['sample_id', 'barcode'],
        columns='chain',
        values=['v_gene', 'j_gene', 'cdr3'],
        aggfunc='first',
    )

    # 3. Flatten the column MultiIndex: ('v_gene', 'TRA') -> 'v_gene_TRA'.
    tcr_aggregated.columns = ['_'.join(col).strip() for col in tcr_aggregated.columns.values]
    tcr_aggregated.reset_index(inplace=True)

    # 4. adata.obs_names carry a batch suffix appended at concatenation
    # ('<10x barcode>-<batch>'). Dropping the last '-' component recovers the
    # 10x barcode used in the contig table (e.g. 'AGCCATGCAGCTGTTA-1').
    adata.obs['barcode_for_merge'] = adata.obs.index.str.rsplit('-', n=1).str[0]

    # Drop TCR columns left over from a previous run of this cell, so re-running
    # it does not create duplicated '_x'/'_y' columns.
    stale_cols = [c for c in tcr_aggregated.columns if c != 'sample_id' and c in adata.obs.columns]
    original_obs = adata.obs.drop(columns=stale_cols)

    # 5. Left merge keeps every cell and adds TCR info where available. The
    # right side has unique (sample_id, barcode) keys, so the row count and
    # row order are unchanged.
    merged_obs = original_obs.merge(
        tcr_aggregated,
        left_on=['sample_id', 'barcode_for_merge'],
        right_on=['sample_id', 'barcode'],
        how='left',
    )

    # 6. merge() returns a fresh RangeIndex; restore the cell names.
    merged_obs.index = original_obs.index
    adata.obs = merged_obs

    print("Aggregated TCR data merged into AnnData object.")

    # --- Keep only cells with a high-confidence, productive TRA chain ---
    # Cells with a TRB chain but no TRA chain are dropped as well.
    initial_cells = adata.n_obs
    if 'v_gene_TRA' in adata.obs.columns:
        has_tra = adata.obs['v_gene_TRA'].notna().to_numpy()
    else:
        # No TRA contig passed the filters in any sample.
        has_tra = np.zeros(adata.n_obs, dtype=bool)
    adata = adata[has_tra].copy()
    print(f"Filtered from {initial_cells} to {adata.n_obs} cells based on having high-confidence TCR data.")

# --- Basic QC and filtering ---
sc.pp.filter_cells(adata, min_genes=200)  # drop cells with < 200 detected genes
sc.pp.filter_genes(adata, min_cells=3)    # drop genes detected in < 3 cells

# Flag mitochondrial genes (human 'MT-' symbol prefix) and compute per-cell QC
# metrics, including the percentage of counts from those genes.
adata.var['mt'] = adata.var_names.str.startswith('MT-')
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

print("\nPost-QC AnnData object:")
print(adata)
display(adata.obs.head())

### Save Processed Data

In [None]:
%%time
# --- Save processed AnnData object to disk ---
output_dir = Path('../Processed_Data')
# parents=True also creates missing intermediate folders instead of raising;
# exist_ok=True makes re-running the cell harmless.
output_dir.mkdir(parents=True, exist_ok=True)

# One .h5ad file holding the filtered expression matrix, the cell metadata
# (including the merged TCR columns) and the QC metrics.
output_path = output_dir / 'processed_s_rna_seq_data.h5ad'
adata.write_h5ad(output_path)

print(f"Processed data saved to: {output_path}")

## 3. Installing and Importing Additional Libraries

In [None]:
%%time
# --- Install required packages for genetic sequence encoding and ML ---
%pip install biopython
%pip install scikit-learn
%pip install umap-learn
%pip install hdbscan
%pip install plotly
%pip install xgboost
%pip install tensorflow

from Bio.Seq import Seq
from Bio.SeqUtils import ProtParam
import xgboost as xgb
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# Import scipy for hierarchical clustering
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.spatial.distance import pdist
from scipy.stats import mannwhitneyu

# Visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import ConfusionMatrixDisplay

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from sklearn.metrics import silhouette_score, adjusted_rand_score, classification_report, confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score, accuracy_score, classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold, GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
import umap
import hdbscan
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Reached only if every install and import earlier in this cell succeeded.
print("Additional libraries installed!")

## 4. Defining Sequence Encoding Functions

In [None]:
%%time
# --- Genetic Sequence Encoding Functions ---

def one_hot_encode_sequence(sequence, max_length=50, alphabet='ACDEFGHIKLMNPQRSTVWY'):
    """
    Build a position-by-letter one-hot matrix for a sequence.

    Row ``i`` marks the letter at position ``i`` of the upper-cased sequence.
    Positions past the end of the sequence stay all-zero (padding). Sequences
    longer than ``max_length`` are cut. Letters outside ``alphabet`` leave
    their row empty.

    Args:
        sequence: Sequence string; NaN/None, ``'NA'`` and ``''`` give all zeros.
        max_length: Number of rows (positions) in the output.
        alphabet: Letters that get a column, in column order.

    Returns:
        numpy array of shape ``(max_length, len(alphabet))`` holding 0/1 floats.
    """
    encoding = np.zeros((max_length, len(alphabet)))
    if pd.isna(sequence) or sequence in ('NA', ''):
        return encoding

    # Column of each letter; the first occurrence wins, like str.index.
    column_of = {}
    for col, letter in enumerate(alphabet):
        column_of.setdefault(letter, col)

    for pos, letter in enumerate(str(sequence).upper()[:max_length]):
        col = column_of.get(letter)
        if col is not None:
            encoding[pos, col] = 1
    return encoding

def kmer_encode_sequence(sequence, k=3, alphabet='ACDEFGHIKLMNPQRSTVWY'):
    """
    Count the overlapping k-mers of a sequence.

    The sequence is upper-cased. Every window of length ``k`` is taken with
    step 1, and windows containing a letter outside ``alphabet`` are discarded.

    Args:
        sequence: Sequence string; NaN/None, ``'NA'`` and ``''`` count as missing.
        k: Length of k-mers.
        alphabet: Valid characters.

    Returns:
        ``collections.Counter`` mapping each valid k-mer to its count. Missing
        sequences give an empty ``Counter`` (previously a plain ``{}``), so the
        return type is the same on every path.
    """
    if pd.isna(sequence) or sequence in ('NA', ''):
        return Counter()

    sequence = str(sequence).upper()
    allowed = set(alphabet)
    windows = (sequence[i:i + k] for i in range(len(sequence) - k + 1))
    return Counter(kmer for kmer in windows if all(c in allowed for c in kmer))

_STANDARD_AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'


def _empty_physicochemical_features(length=0):
    """Return the feature dict with every property 0 except ``length``."""
    return {
        'length': length, 'molecular_weight': 0, 'aromaticity': 0,
        'instability_index': 0, 'isoelectric_point': 0, 'hydrophobicity': 0
    }


def physicochemical_features(sequence):
    """
    Extract physicochemical properties from a protein sequence.

    The sequence is upper-cased and every character that is not one of the
    20 standard amino acids is removed before analysis with Biopython's
    ``ProteinAnalysis``.

    Args:
        sequence: Protein sequence string (NaN/None, ``'NA'`` and ``''`` count
            as missing).

    Returns:
        Dict with keys, in this order: ``length``, ``molecular_weight``,
        ``aromaticity``, ``instability_index``, ``isoelectric_point`` and
        ``hydrophobicity`` (GRAVY). The key order matters to callers that use
        ``list(result.values())`` as a feature vector. Missing or fully
        non-standard sequences give all zeros. If Biopython fails, only
        ``length`` (of the cleaned sequence) is filled in.
    """
    if pd.isna(sequence) or sequence in ('NA', ''):
        return _empty_physicochemical_features()

    # Remove non-standard amino acids.
    seq = ''.join(c for c in str(sequence).upper() if c in _STANDARD_AMINO_ACIDS)
    if not seq:
        return _empty_physicochemical_features()

    try:
        analyzer = ProtParam.ProteinAnalysis(seq)
        return {
            'length': len(seq),
            'molecular_weight': analyzer.molecular_weight(),
            'aromaticity': analyzer.aromaticity(),
            'instability_index': analyzer.instability_index(),
            'isoelectric_point': analyzer.isoelectric_point(),
            'hydrophobicity': analyzer.gravy()  # Grand Average of Hydropathy
        }
    except Exception:
        # Best-effort: keep the pipeline running on sequences Biopython rejects.
        return _empty_physicochemical_features(len(seq))

def encode_gene_expression_patterns(adata, n_top_genes=1000):
    """
    Embed cells with several dimensionality-reduction methods on their HVGs.

    Steps:
    1. If ``adata.var`` has no ``'highly_variable'`` column yet, run
       ``sc.pp.highly_variable_genes``. This writes HVG columns into
       ``adata.var`` as a side effect.
    2. Densify the HVG sub-matrix.
    3. Z-score every gene.
    4. Embed with PCA, TruncatedSVD and UMAP.

    Args:
        adata: AnnData object with gene expression data.
        n_top_genes: Number of highly variable genes to select when they have
            not been computed yet.

    Returns:
        Tuple ``(encodings, X_scaled)``:
        * ``encodings`` is a dict with keys ``'pca'``, ``'svd'`` (up to 50
          components each, fewer when the data is smaller) and ``'umap'``
          (20 components), one row per cell.
        * ``X_scaled`` is the standardized HVG matrix (cells x HVGs).
    """
    # NOTE(review): scanpy's default HVG flavor expects log-normalized values;
    # no normalize_total/log1p step is visible before this is called -- confirm.
    if 'highly_variable' not in adata.var.columns:
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, subset=False)

    hvg_view = adata[:, adata.var['highly_variable'].to_numpy()]
    X_hvg = hvg_view.X.toarray() if hasattr(hvg_view.X, 'toarray') else hvg_view.X

    # Standardize each gene (zero mean, unit variance).
    X_scaled = StandardScaler().fit_transform(X_hvg)
    n_cells, n_genes = X_scaled.shape

    encodings = {}

    # PCA needs n_components <= min(n_cells, n_genes); cap the default of 50 so
    # small inputs (few cells or few HVGs) do not raise.
    encodings['pca'] = PCA(n_components=min(50, n_cells, n_genes)).fit_transform(X_scaled)

    # TruncatedSVD (no centering) of the already-centered matrix, kept for
    # comparison with PCA. It needs fewer components than features.
    n_svd = max(1, min(50, n_cells, n_genes - 1))
    encodings['svd'] = TruncatedSVD(n_components=n_svd, random_state=42).fit_transform(X_scaled)

    # UMAP encoding
    encodings['umap'] = umap.UMAP(n_components=20, random_state=42).fit_transform(X_scaled)

    return encodings, X_scaled

print("Genetic sequence encoding functions defined successfully!")

## 5. Applying Encodings to Data

In [None]:
%%time
# --- Applying Encodings to Combined Data ---
# NOTE(review): `combined_adata` and `combined_contigs` are not created by the
# cells above (which build `adata` and `full_tcr_df`); they must come from a
# cell not shown here.
import itertools

print("Applying encodings to combined data...")

# Apply gene expression pattern encodings
print("Encoding gene expression patterns...")
expression_encodings, X_scaled = encode_gene_expression_patterns(combined_adata, n_top_genes=1000)

# Vocabulary of the 3-mer count features: every 3-letter word over the 20
# standard amino acids, unless an earlier cell already defined it.
if 'all_possible_kmers' not in globals():
    all_possible_kmers = [''.join(p) for p in itertools.product('ACDEFGHIKLMNPQRSTVWY', repeat=3)]

# Width of one contig's feature vector: flattened one-hot (30 positions),
# k-mer counts and physicochemical properties.
n_seq_features = (
    one_hot_encode_sequence('', max_length=30).size
    + len(all_possible_kmers)
    + len(physicochemical_features(''))
)

# Apply sequence encodings to contig data (one row per contig, in table order)
print("Encoding genetic sequences...")
# NOTE(review): the contig tables loaded earlier (full_tcr_df) name the CDR3
# amino-acid column 'cdr3'; confirm combined_contigs really has 'cdr3_aa'.
cdr3_sequences = combined_contigs['cdr3_aa'].tolist()
n_contigs = len(cdr3_sequences)
sequence_rows = []
for pos, cdr3 in enumerate(cdr3_sequences):
    if pos % 1000 == 0:
        print(f"Processing sequence {pos}/{n_contigs}")
    kmer_counts = kmer_encode_sequence(cdr3, k=3)
    sequence_rows.append(np.concatenate([
        one_hot_encode_sequence(cdr3, max_length=30).ravel(),
        np.array([kmer_counts.get(kmer, 0) for kmer in all_possible_kmers]),
        np.array(list(physicochemical_features(cdr3).values())),
    ]))
# reshape keeps a 2-D (0, n_seq_features) array even when there are no contigs.
sequence_features = np.array(sequence_rows, dtype=float).reshape(-1, n_seq_features)
print(f"Sequence features shape: {sequence_features.shape}")

# Combine expression and sequence features
print("Combining expression and sequence features...")
# Group contig ROW POSITIONS by their 'sample' value once (O(n)) instead of
# scanning every contig for every cell. Positions, not DataFrame labels, index
# sequence_features, so this stays correct for a non-default index.
# NOTE(review): cells are matched where combined_contigs['sample'] equals the
# cell name (obs index). If 'sample' holds a sample id rather than a cell
# barcode, no cell matches and every cell gets all-zero sequence features.
contig_rows_by_key = {}
for pos, key in enumerate(combined_contigs['sample'].tolist()):
    contig_rows_by_key.setdefault(key, []).append(pos)

no_contig_features = np.zeros(n_seq_features)
combined_rows = []
for i, cell_name in enumerate(combined_adata.obs.index):
    rows = contig_rows_by_key.get(cell_name)
    # Average the features of all contigs of the cell; zeros if it has none.
    avg_seq_features = sequence_features[rows].mean(axis=0) if rows else no_contig_features
    combined_rows.append(np.concatenate([
        expression_encodings['pca'][i],
        expression_encodings['umap'][i],
        avg_seq_features,
    ]))

combined_features = np.array(combined_rows)
print(f"Combined features shape: {combined_features.shape}")

# Standardize combined features
scaler_combined = StandardScaler()
X_combined_scaled = scaler_combined.fit_transform(combined_features)

print("Encodings applied successfully!")
print(f"Final feature matrix shape: {X_combined_scaled.shape}")

## 6. Unsupervised Clustering Analysis

In [None]:
%%time
# --- Unsupervised Clustering Analysis ---
# GaussianMixture is not imported in the setup cell; import it here so that
# building the algorithm dict below does not raise NameError.
from sklearn.mixture import GaussianMixture

print("Performing unsupervised clustering analysis...")

# Candidate clustering algorithms
clustering_algorithms = {
    'K-Means': KMeans(n_clusters=5, random_state=42, n_init=10),
    'HDBSCAN': hdbscan.HDBSCAN(min_cluster_size=50, min_samples=10),
    'Agglomerative': AgglomerativeClustering(n_clusters=5),
    'DBSCAN': DBSCAN(eps=0.5, min_samples=10),
    'GaussianMixture': GaussianMixture(n_components=5, random_state=42)
}

clustering_results = {}
silhouette_scores = {}

# Compare the algorithms on a fixed-seed random subsample of at most 5000 cells.
sample_size = min(5000, len(X_combined_scaled))
np.random.seed(42)
sample_indices = np.random.choice(len(X_combined_scaled), sample_size, replace=False)
X_sample = X_combined_scaled[sample_indices]

print(f"Clustering on {sample_size} samples...")

for name, algorithm in clustering_algorithms.items():
    print(f"Running {name}...")
    try:
        # Every estimator above (GaussianMixture included) offers fit_predict.
        labels = algorithm.fit_predict(X_sample)
        clustering_results[name] = labels

        # Silhouette needs at least 2 distinct labels. The noise label (-1) of
        # DBSCAN/HDBSCAN counts as one.
        if len(np.unique(labels)) > 1:
            sil_score = silhouette_score(X_sample, labels)
            silhouette_scores[name] = sil_score
            print(f"{name} silhouette score: {sil_score:.3f}")
        else:
            silhouette_scores[name] = -1
            print(f"{name} found only 1 cluster")

    except Exception as e:
        print(f"Error with {name}: {e}")
        clustering_results[name] = None
        silhouette_scores[name] = -1

# Best algorithm = highest silhouette on the subsample
best_algorithm = max(silhouette_scores, key=silhouette_scores.get)
best_labels = clustering_results[best_algorithm]

print(f"\nBest clustering algorithm: {best_algorithm} (silhouette: {silhouette_scores[best_algorithm]:.3f})")

# Refit the chosen estimator from scratch on all cells (for GaussianMixture
# this equals a new instance with the same parameters and random_state).
# NOTE: the full-data clusters need not match the subsample evaluation, and
# density/hierarchical methods may be slow on large datasets.
print("Applying best clustering to full dataset...")
final_labels = clustering_algorithms[best_algorithm].fit_predict(X_combined_scaled)

# Add clustering results to AnnData
combined_adata.obs['cluster'] = final_labels.astype(str)

print("Clustering analysis completed!")
print(f"Number of clusters found: {len(np.unique(final_labels))}")

## 7. Visualization of Clustering Results

In [None]:
%%time
# --- Visualization of Clustering Results ---

print("Creating visualizations of clustering results...")

# Compute a 2-D UMAP embedding of the scaled combined feature matrix
umap_reducer = umap.UMAP(n_components=2, random_state=42)
umap_coords = umap_reducer.fit_transform(X_combined_scaled)

# Store the embedding on the AnnData object under the 'X_umap' key
combined_adata.obsm['X_umap'] = umap_coords

# 2x2 grid of summary plots
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Plot 1: UMAP colored by cluster
scatter = axes[0, 0].scatter(umap_coords[:, 0], umap_coords[:, 1],
                             c=final_labels, cmap='tab10', alpha=0.7, s=10)
axes[0, 0].set_title('UMAP - Clusters')
axes[0, 0].set_xlabel('UMAP1')
axes[0, 0].set_ylabel('UMAP2')
plt.colorbar(scatter, ax=axes[0, 0], label='Cluster')

# Plot 2: UMAP colored by sample (one scatter call per sample so each gets a legend entry)
unique_samples = combined_adata.obs['sample'].unique()
sample_colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_samples)))
sample_color_map = dict(zip(unique_samples, sample_colors))

for sample in unique_samples:
    # Convert the boolean Series to a plain ndarray before indexing a NumPy array
    mask = (combined_adata.obs['sample'] == sample).to_numpy()
    axes[0, 1].scatter(umap_coords[mask, 0], umap_coords[mask, 1],
                       color=sample_color_map[sample], alpha=0.7, s=10, label=sample)

axes[0, 1].set_title('UMAP - Samples')
axes[0, 1].set_xlabel('UMAP1')
axes[0, 1].set_ylabel('UMAP2')
axes[0, 1].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# Plot 3: Cluster distribution
cluster_counts = pd.Series(final_labels).value_counts().sort_index()
axes[1, 0].bar(range(len(cluster_counts)), cluster_counts.values)
axes[1, 0].set_title('Cluster Size Distribution')
axes[1, 0].set_xlabel('Cluster')
axes[1, 0].set_ylabel('Number of Cells')
axes[1, 0].set_xticks(range(len(cluster_counts)))
axes[1, 0].set_xticklabels([f'Cluster {i}' for i in cluster_counts.index])

# Plot 4: Silhouette scores comparison (best algorithm highlighted in red)
algorithms = list(silhouette_scores.keys())
scores = list(silhouette_scores.values())
best_silhouette = max(scores)
bars = axes[1, 1].bar(algorithms, scores,
                      color=['red' if s == best_silhouette else 'blue' for s in scores])
axes[1, 1].set_title('Clustering Algorithm Comparison')
axes[1, 1].set_ylabel('Silhouette Score')
# Silhouette scores lie in [-1, 1]; extend the axis below 0 so negative scores stay visible
lowest_silhouette = min(scores)
axes[1, 1].set_ylim(lowest_silhouette - 0.05 if lowest_silhouette < 0 else 0, 1)
axes[1, 1].tick_params(axis='x', rotation=45)

# Annotate each bar with its score formatted to three decimals
for bar, score in zip(bars, scores):
    height = bar.get_height()
    axes[1, 1].text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                    f'{score:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.savefig(os.path.join(processed_data_dir, 'clustering_visualization.png'), dpi=300, bbox_inches='tight')
plt.show()

# Additional visualization: t-SNE
print("Computing t-SNE for additional visualization...")

# t-SNE is expensive, so embed at most `tsne_max_cells` cells. Draw them at random:
# taking the first rows would over-represent whichever samples were concatenated first.
tsne_max_cells = 5000
n_cells_total = X_combined_scaled.shape[0]
if n_cells_total > tsne_max_cells:
    tsne_rng = np.random.default_rng(42)
    tsne_idx = np.sort(tsne_rng.choice(n_cells_total, size=tsne_max_cells, replace=False))
else:
    tsne_idx = np.arange(n_cells_total)

tsne_reducer = TSNE(n_components=2, random_state=42, perplexity=30)
tsne_coords = tsne_reducer.fit_transform(X_combined_scaled[tsne_idx])

plt.figure(figsize=(8, 6))
scatter_tsne = plt.scatter(tsne_coords[:, 0], tsne_coords[:, 1],
                           c=np.asarray(final_labels)[tsne_idx], cmap='tab10', alpha=0.7, s=10)
plt.title(f't-SNE - Clusters (random subsample of {len(tsne_idx)} cells)')
plt.xlabel('t-SNE1')
plt.ylabel('t-SNE2')
plt.colorbar(scatter_tsne, label='Cluster')
plt.savefig(os.path.join(processed_data_dir, 'tsne_clusters.png'), dpi=300, bbox_inches='tight')
plt.show()

print("Visualization completed!")

## 8. Supervised Learning Analysis

In [None]:
%%time
# --- Supervised Learning Analysis ---

print("Performing supervised learning analysis...")

# Target variable: the cluster labels produced above.
# NOTE(review): the classifiers are trained on the same features that were used to
# derive these labels, so high scores mainly reflect cluster separability, not the
# ability to predict treatment response.
y = final_labels

# train_test_split(stratify=...) raises ValueError when any class has fewer than 2
# members (possible with density-based clusterers); fall back to a plain split then.
_, cluster_label_counts = np.unique(y, return_counts=True)
stratify_labels = y if cluster_label_counts.min() >= 2 else None
if stratify_labels is None:
    print("Warning: at least one cluster has fewer than 2 cells; using an unstratified split.")

X_train, X_test, y_train, y_test = train_test_split(
    X_combined_scaled, y, test_size=0.2, random_state=42, stratify=stratify_labels
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")

# Define models to test
models = {
    'Random Forest': RandomForestClassifier(random_state=42),
    'XGBoost': xgb.XGBClassifier(random_state=42, eval_metric='mlogloss'),
    'SVM': SVC(random_state=42, probability=True),
    # 'saga' supports both the 'l1' and 'l2' penalties searched below; the default
    # 'lbfgs' solver rejects 'l1', so half of the grid would fail and score NaN.
    'Logistic Regression': LogisticRegression(random_state=42, max_iter=1000, solver='saga'),
    'K-Nearest Neighbors': KNeighborsClassifier()
}

# XGBClassifier requires class labels encoded as 0..K-1. Other labels (e.g. a noise
# label of -1) make fit() raise, so leave XGBoost out in that case.
label_values = np.unique(y)
if not np.array_equal(label_values, np.arange(len(label_values))):
    print("Skipping XGBoost: cluster labels are not consecutive integers starting at 0.")
    models.pop('XGBoost')

# Define parameter grids for GridSearchCV
param_grids = {
    'Random Forest': {
        'n_estimators': [100, 200],
        'max_depth': [10, 20, None],
        'min_samples_split': [2, 5]
    },
    'XGBoost': {
        'n_estimators': [100, 200],
        'max_depth': [3, 6, 9],
        'learning_rate': [0.1, 0.2]
    },
    'SVM': {
        'C': [0.1, 1, 10],
        'kernel': ['rbf', 'linear']
    },
    'Logistic Regression': {
        'C': [0.1, 1, 10],
        'penalty': ['l1', 'l2']
    },
    'K-Nearest Neighbors': {
        'n_neighbors': [3, 5, 7],
        'weights': ['uniform', 'distance']
    }
}

# Perform k-fold cross-validation with GridSearchCV
cv_results = {}
best_models = {}

print("Performing k-fold cross-validation with hyperparameter tuning...")

# StratifiedKFold keeps class proportions similar across folds. With a fixed
# random_state it yields the same splits on every call, so one instance serves all models.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

for name, model in models.items():
    print(f"\nTuning {name}...")

    grid_search = GridSearchCV(
        model, param_grids[name], cv=cv, scoring='accuracy',
        n_jobs=-1, verbose=1
    )

    grid_search.fit(X_train, y_train)

    cv_results[name] = {
        'best_params': grid_search.best_params_,
        'best_score': grid_search.best_score_,
        'cv_results': grid_search.cv_results_
    }

    best_models[name] = grid_search.best_estimator_

    print(f"Best CV score: {grid_search.best_score_:.3f}")
    print(f"Best parameters: {grid_search.best_params_}")

# Evaluate the tuned models on the held-out test set
test_results = {}

print("\nEvaluating models on test set...")
for name, model in best_models.items():
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test) if hasattr(model, 'predict_proba') else None

    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average='weighted')

    test_results[name] = {
        'accuracy': accuracy,
        'f1_score': f1,
        'predictions': y_pred,
        'probabilities': y_proba
    }

    print(f"{name} - Test Accuracy: {accuracy:.3f}, F1-Score: {f1:.3f}")

# Select the final model by its cross-validated score on the training data.
# Picking the best test-set F1 would use the test set for model selection and
# bias the reported test score upward.
best_model_name = max(cv_results, key=lambda model_name: cv_results[model_name]['best_score'])
best_model = best_models[best_model_name]

print(f"\nBest performing model (by CV score): {best_model_name}")
print(f"CV score: {cv_results[best_model_name]['best_score']:.3f}")
print(f"Test F1-Score: {test_results[best_model_name]['f1_score']:.3f}")

# Save best model
model_path = os.path.join(processed_data_dir, 'best_supervised_model.pkl')
joblib.dump(best_model, model_path)
print(f"Best model saved to: {model_path}")

print("Supervised learning analysis completed!")

## 9. Model Evaluation and Interpretation

In [None]:
%%time
# --- Model Evaluation and Interpretation ---

print("Evaluating and interpreting model results...")

# Create evaluation plots (2x2 grid)
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Plot 1: side-by-side test accuracy / weighted F1 for every tuned model
model_names = list(test_results.keys())
accuracies = [test_results[name]['accuracy'] for name in model_names]
f1_scores = [test_results[name]['f1_score'] for name in model_names]

x = np.arange(len(model_names))
width = 0.35

bars1 = axes[0, 0].bar(x - width / 2, accuracies, width, label='Accuracy', alpha=0.8)
bars2 = axes[0, 0].bar(x + width / 2, f1_scores, width, label='F1-Score', alpha=0.8)

axes[0, 0].set_title('Model Performance Comparison')
axes[0, 0].set_xlabel('Model')
axes[0, 0].set_ylabel('Score')
axes[0, 0].set_xticks(x)
axes[0, 0].set_xticklabels(model_names, rotation=45)
axes[0, 0].legend()
axes[0, 0].set_ylim(0, 1)

# Annotate each bar with its value formatted to three decimals
for bars in (bars1, bars2):
    for bar in bars:
        height = bar.get_height()
        axes[0, 0].text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                        f'{height:.3f}', ha='center', va='bottom')

# Plot 2: Confusion matrix for best model.
# Pass the full label set explicitly: otherwise confusion_matrix infers labels only from
# y_test/predictions, and its shape may not match display_labels when a cluster is
# missing from the test split.
cm_labels = np.unique(y)
best_pred = test_results[best_model_name]['predictions']
cm = confusion_matrix(y_test, best_pred, labels=cm_labels)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=cm_labels)
disp.plot(ax=axes[0, 1], cmap='Blues', values_format='d')
axes[0, 1].set_title(f'Confusion Matrix - {best_model_name}')

# Plot 3: the 20 largest feature importances, when the best model exposes them
if not hasattr(best_model, 'feature_importances_'):
    axes[1, 0].text(0.5, 0.5, 'Feature importance\nnot available\nfor this model',
                    ha='center', va='center', transform=axes[1, 0].transAxes)
    axes[1, 0].set_title('Feature Importances')
else:
    feature_importance = best_model.feature_importances_
    # argsort is ascending, so the last 20 indices belong to the largest importances
    top_indices = np.argsort(feature_importance)[-20:]
    top_features = feature_importance[top_indices]
    feature_names = [f'Feature_{idx}' for idx in top_indices]
    bar_positions = range(len(top_features))

    importance_ax = axes[1, 0]
    importance_ax.barh(bar_positions, top_features)
    importance_ax.set_yticks(bar_positions)
    importance_ax.set_yticklabels(feature_names)
    importance_ax.set_xlabel('Importance')
    importance_ax.set_title('Top 20 Feature Importances')

# Plot 4: ROC curve(s) for the best model
roc_ax = axes[1, 1]
best_proba = test_results[best_model_name]['probabilities']

if best_proba is None:
    # ROC curves need per-class scores, which this model does not provide
    roc_ax.text(0.5, 0.5, 'ROC curves require\nprobability estimates',
                ha='center', va='center', transform=roc_ax.transAxes)
    roc_ax.set_title('ROC Curves')
elif len(np.unique(y)) == 2:
    # Binary classification: column 1 of predict_proba scores best_model.classes_[1],
    # so name it as the positive label (roc_curve cannot infer one for labels
    # outside {0, 1} / {-1, 1}).
    y_proba = best_proba[:, 1]
    fpr, tpr, _ = roc_curve(y_test, y_proba, pos_label=best_model.classes_[1])
    roc_auc = auc(fpr, tpr)

    roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('ROC Curve')
    roc_ax.legend(loc="lower right")
else:
    # Multiclass: one-vs-rest ROC curves per class plus a micro-average.
    # Binarize against best_model.classes_ so column i matches predict_proba column i.
    y_proba = best_proba
    y_test_bin = label_binarize(y_test, classes=best_model.classes_)
    n_classes = y_test_bin.shape[1]

    fpr = {}
    tpr = {}
    roc_auc = {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_proba[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        roc_ax.plot(fpr[i], tpr[i], lw=1, alpha=0.6,
                    label=f'Cluster {best_model.classes_[i]} (AUC = {roc_auc[i]:.2f})')

    # Micro-average: pool every (sample, class) decision into one binary problem
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test_bin.ravel(), y_proba.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    roc_ax.plot(fpr["micro"], tpr["micro"],
                label=f'micro-average ROC curve (AUC = {roc_auc["micro"]:.2f})',
                color='deeppink', linestyle=':', linewidth=4)

    roc_ax.plot([0, 1], [0, 1], 'k--', lw=2)
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('Multiclass ROC Curves')
    roc_ax.legend(loc="lower right", fontsize='small')

plt.tight_layout()
plt.savefig(os.path.join(processed_data_dir, 'model_evaluation.png'), dpi=300, bbox_inches='tight')
plt.show()

# Print detailed classification report. Passing `labels` keeps the rows aligned with
# target_names even when a cluster is absent from both y_test and the predictions;
# without it, classification_report raises a size-mismatch ValueError.
report_labels = np.unique(y)
print("Detailed Classification Report for Best Model:")
print(classification_report(
    y_test, best_pred, labels=report_labels,
    target_names=[f'Cluster_{i}' for i in report_labels], zero_division=0
))

# Save evaluation results (predictions/probability arrays are left out of the JSON)
evaluation_summary = {
    'best_model': best_model_name,
    'test_accuracy': test_results[best_model_name]['accuracy'],
    'test_f1_score': test_results[best_model_name]['f1_score'],
    'cv_results': cv_results,
    'test_results': {k: {kk: vv for kk, vv in v.items() if kk not in ['predictions', 'probabilities']}
                     for k, v in test_results.items()}
}

evaluation_summary_path = os.path.join(processed_data_dir, 'model_evaluation_summary.json')
with open(evaluation_summary_path, 'w') as f:
    # default=str stringifies NumPy arrays/scalars found in the GridSearchCV results
    json.dump(evaluation_summary, f, indent=2, default=str)

print("Model evaluation and interpretation completed!")
print(f"Results saved to: {evaluation_summary_path}")

## 10. Cluster Analysis and Biological Interpretation

In [None]:
%%time
# --- Cluster Analysis and Biological Interpretation ---

print("Performing cluster analysis and biological interpretation...")

# Rank genes for every cluster with a Wilcoxon test on combined_adata.X (not .raw)
print("Identifying marker genes for each cluster...")
sc.tl.rank_genes_groups(combined_adata, 'cluster', method='wilcoxon', use_raw=False)

# Keep the 20 top-ranked genes per cluster, keyed by the (string) cluster label
marker_genes = {
    cluster: sc.get.rank_genes_groups_df(combined_adata, group=cluster).head(20)
    for cluster in combined_adata.obs['cluster'].unique()
}

# Stack the per-cluster tables into one CSV, tagging each row with its cluster
marker_df = pd.concat(
    [table.assign(cluster=cluster) for cluster, table in marker_genes.items()]
)
marker_df.to_csv(os.path.join(processed_data_dir, 'cluster_marker_genes.csv'), index=False)

print("Marker genes identified and saved!")

def _to_dense(matrix):
    """Return ``matrix`` as a dense NumPy array (AnnData ``.X`` may be sparse or dense)."""
    return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)


# Analyze cluster characteristics
cluster_stats = {}

for cluster in np.unique(final_labels):
    cluster_mask = final_labels == cluster
    cluster_data = combined_adata[cluster_mask]

    # Basic statistics
    n_cells = len(cluster_data)
    sample_dist = cluster_data.obs['sample'].value_counts()

    # Expression statistics: densify the cluster slice once and reuse it
    cluster_matrix = _to_dense(cluster_data.X)
    mean_expr = cluster_matrix.mean(axis=0)
    var_expr = cluster_matrix.var(axis=0)

    # The 10 genes with the highest within-cluster variance
    top_var_indices = np.argsort(var_expr)[-10:]
    top_var_genes = combined_adata.var.index[top_var_indices]

    cluster_stats[f'Cluster_{cluster}'] = {
        'n_cells': n_cells,
        'sample_distribution': sample_dist.to_dict(),
        'top_variable_genes': top_var_genes.tolist(),
        'mean_expression_top10': mean_expr[top_var_indices].tolist()
    }

# Save cluster statistics
with open(os.path.join(processed_data_dir, 'cluster_statistics.json'), 'w') as f:
    json.dump(cluster_stats, f, indent=2)

print("Cluster statistics computed!")

# Biological interpretation
print("Performing biological interpretation...")

# Placeholder interpretation derived only from cluster size and dominant sample;
# a real annotation would draw on GO/KEGG or curated marker databases.
average_cluster_size = np.mean([entry['n_cells'] for entry in cluster_stats.values()])
biological_interpretation = {}

for cluster in np.unique(final_labels):
    cluster_name = f'Cluster_{cluster}'
    cluster_summary = cluster_stats[cluster_name]
    distribution = cluster_summary['sample_distribution']
    dominant_sample = max(distribution, key=distribution.get)

    biological_interpretation[cluster_name] = {
        'size': 'Large' if cluster_summary['n_cells'] > average_cluster_size else 'Small',
        'sample_preference': dominant_sample,
        'potential_biological_role': 'To be determined based on marker gene analysis',
        'key_characteristics': f"Contains {cluster_summary['n_cells']} cells, prefers {dominant_sample}"
    }

# Save biological interpretation
with open(os.path.join(processed_data_dir, 'biological_interpretation.json'), 'w') as f:
    json.dump(biological_interpretation, f, indent=2)

def _to_dense(matrix):
    """Return ``matrix`` as a dense NumPy array (AnnData ``.X`` may be sparse or dense)."""
    return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)


# Create summary visualization
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Plot 1: Cluster sizes, labelled with the real cluster ids (they need not be 0..K-1,
# e.g. a noise label of -1 from a density-based clusterer)
cluster_names = list(cluster_stats.keys())
cluster_sizes = [cluster_stats[name]['n_cells'] for name in cluster_names]
axes[0, 0].bar(range(len(cluster_sizes)), cluster_sizes)
axes[0, 0].set_title('Cluster Sizes')
axes[0, 0].set_xlabel('Cluster')
axes[0, 0].set_ylabel('Number of Cells')
axes[0, 0].set_xticks(range(len(cluster_sizes)))
axes[0, 0].set_xticklabels([name.replace('Cluster_', 'Cluster ', 1) for name in cluster_names])

# Plot 2: Sample distribution across clusters (rows = samples, stacked by cluster)
sample_cluster_dist = pd.DataFrame({
    cluster: pd.Series(stats['sample_distribution'])
    for cluster, stats in cluster_stats.items()
}).fillna(0)

sample_cluster_dist.plot(kind='bar', ax=axes[0, 1], stacked=True)
axes[0, 1].set_title('Sample Distribution Across Clusters')
axes[0, 1].set_xlabel('Sample')
axes[0, 1].set_ylabel('Number of Cells')
axes[0, 1].legend(title='Cluster', bbox_to_anchor=(1.05, 1), loc='upper left')

# Plot 3: Top marker genes heatmap (simplified)
if marker_genes:
    # Top 5 genes per cluster
    top5_per_cluster = {cluster: df.head(5)['names'].tolist() for cluster, df in marker_genes.items()}

    # De-duplicate while keeping first-seen order, so the heatmap columns are
    # reproducible (set() order of strings varies between interpreter runs)
    all_top_genes = list(dict.fromkeys(
        gene for genes in top5_per_cluster.values() for gene in genes
    ))

    # Expression of the selected genes for every cell
    gene_expr_matrix = _to_dense(combined_adata[:, all_top_genes].X)

    # Mean expression per cluster (rows follow np.unique(final_labels))
    heatmap_clusters = np.unique(final_labels)
    cluster_avg_expr = np.array([
        gene_expr_matrix[final_labels == cluster].mean(axis=0)
        for cluster in heatmap_clusters
    ])

    sns.heatmap(cluster_avg_expr, xticklabels=all_top_genes,
                yticklabels=[f'Cluster {i}' for i in heatmap_clusters],
                ax=axes[1, 0], cmap='viridis')
    axes[1, 0].set_title('Top Marker Genes Expression by Cluster')
    axes[1, 0].tick_params(axis='x', rotation=45)
else:
    axes[1, 0].text(0.5, 0.5, 'No marker genes\navailable', ha='center', va='center', transform=axes[1, 0].transAxes)

# Plot 4: Biological interpretation summary (text only)
interpretation_text = "\n".join([
    f"{cluster}: {interp['key_characteristics']}"
    for cluster, interp in biological_interpretation.items()
])

axes[1, 1].text(0.1, 0.9, "Biological Interpretation Summary:", fontsize=12, fontweight='bold',
                transform=axes[1, 1].transAxes)
axes[1, 1].text(0.1, 0.8, interpretation_text, fontsize=10, verticalalignment='top',
                transform=axes[1, 1].transAxes)
axes[1, 1].set_title('Cluster Biological Interpretation')
axes[1, 1].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(processed_data_dir, 'cluster_analysis.png'), dpi=300, bbox_inches='tight')
plt.show()

# Gather the headline results of every stage into one JSON document
best_test_metrics = test_results[best_model_name]
analysis_summary = {
    'clustering': {
        'algorithm': best_algorithm,
        'n_clusters': len(np.unique(final_labels)),
        'silhouette_score': silhouette_scores[best_algorithm],
    },
    'supervised_learning': {
        'best_model': best_model_name,
        'test_accuracy': best_test_metrics['accuracy'],
        'test_f1_score': best_test_metrics['f1_score'],
    },
    'biological_insights': biological_interpretation,
    # Only the five top-ranked marker rows per cluster, as plain record dicts
    'marker_genes': {
        cluster: table.head(5).to_dict('records')
        for cluster, table in marker_genes.items()
    },
    'cluster_statistics': cluster_stats,
}

summary_path = os.path.join(processed_data_dir, 'analysis_summary.json')
with open(summary_path, 'w') as f:
    # default=str stringifies any non-JSON values (e.g. NumPy scalars)
    json.dump(analysis_summary, f, indent=2, default=str)

print("Cluster analysis and biological interpretation completed!")
print(f"Comprehensive analysis summary saved to: {summary_path}")

# Final summary banner
summary_rule = "=" * 80
n_unique_samples = len(combined_adata.obs['sample'].unique())
n_final_clusters = len(np.unique(final_labels))
best_model_f1 = test_results[best_model_name]['f1_score']

print("\n" + summary_rule)
print("ANALYSIS COMPLETE")
print(summary_rule)
print(f"Dataset: {len(combined_adata)} cells from {n_unique_samples} samples")
print(f"Clusters identified: {n_final_clusters} using {best_algorithm}")
print(f"Best supervised model: {best_model_name} (F1: {best_model_f1:.3f})")
print(f"Results saved in: {processed_data_dir}")
print(summary_rule)

## 11. Conclusion and Future Directions

In [None]:
# --- Conclusion and Future Directions ---

cluster_total = len(np.unique(final_labels))
winning_f1 = test_results[best_model_name]['f1_score']

print("Analysis completed successfully!")
print("\nKey Findings:")
print(f"• Identified {cluster_total} distinct cell clusters using {best_algorithm}")
print(f"• Best supervised classification model: {best_model_name} (F1-score: {winning_f1:.3f})")
print("• Combined gene expression and TCR sequence features for comprehensive analysis")
print("• Generated marker genes and biological interpretations for each cluster")

print("\nFiles Generated:")
print(f"• Processed data: {processed_data_dir}")
generated_files = (
    "  - analysis_summary.json: Complete analysis results",
    "  - cluster_marker_genes.csv: Marker genes for each cluster",
    "  - cluster_statistics.json: Detailed cluster statistics",
    "  - biological_interpretation.json: Biological insights",
    "  - best_supervised_model.pkl: Trained classification model",
    "  - Various visualization plots (PNG files)",
)
for file_line in generated_files:
    print(file_line)

print("\nFuture Directions:")
future_directions = (
    "Validate findings with additional biological experiments",
    "Integrate pathway analysis (GO, KEGG) for deeper biological insights",
    "Apply deep learning approaches for improved classification",
    "Extend analysis to larger datasets or additional cancer types",
    "Develop predictive models for treatment response based on cluster membership",
)
for position, direction in enumerate(future_directions, start=1):
    print(f"{position}. {direction}")

print("\nNotebook completed! All results saved in the Processed_Data directory.")
print("The reordered notebook provides a logical, end-to-end analysis workflow.")