In [1]:
# Memory monitoring utility
import psutil
import os

def print_memory_usage():
    """Print current memory usage"""
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    mem_mb = mem_info.rss / 1024 / 1024
    print(f"Current memory usage: {mem_mb:.2f} MB")
    
    # System memory info
    vm = psutil.virtual_memory()
    print(f"System memory: {vm.percent}% used ({vm.used / 1024**3:.2f} GB / {vm.total / 1024**3:.2f} GB)")
    return mem_mb

# Check initial memory
print("=== Initial Memory Check ===")
initial_mem = print_memory_usage()
print("\nTip: Run this cell periodically to monitor memory usage")

=== Initial Memory Check ===
Current memory usage: 100.90 MB
System memory: 3.8% used (0.74 GB / 31.35 GB)

Tip: Run this cell periodically to monitor memory usage


In [2]:
# ============================================================
# CONFIGURATION: Skip sections to save time and memory
# ============================================================
# Set these flags to True to SKIP the corresponding sections
# This is useful for debugging or if you only want to run deep learning

SKIP_UNSUPERVISED_LEARNING = False    # Skip Leiden clustering, UMAP, etc. (Cell 44+)
SKIP_XGBOOST_LOPO_CV = False          # Skip XGBoost and traditional ML LOPO CV (Cell 51+)
SKIP_TRADITIONAL_ML = False           # Skip Logistic Regression, Random Forest, etc.
SKIP_TO_DEEP_LEARNING = True         # Master switch: Skip everything except data loading and deep learning

# If SKIP_TO_DEEP_LEARNING is True, it overrides the other flags
if SKIP_TO_DEEP_LEARNING:
    SKIP_UNSUPERVISED_LEARNING = True
    SKIP_XGBOOST_LOPO_CV = True
    SKIP_TRADITIONAL_ML = True
    print("‚ö° FAST MODE: Skipping unsupervised learning and traditional ML, going straight to deep learning!")
else:
    print("üî¨ FULL MODE: Running all analysis sections")
    
print(f"  - Skip Unsupervised Learning: {SKIP_UNSUPERVISED_LEARNING}")
print(f"  - Skip XGBoost/LOPO CV: {SKIP_XGBOOST_LOPO_CV}")
print(f"  - Skip Traditional ML: {SKIP_TRADITIONAL_ML}")

‚ö° FAST MODE: Skipping unsupervised learning and traditional ML, going straight to deep learning!
  - Skip Unsupervised Learning: True
  - Skip XGBoost/LOPO CV: True
  - Skip Traditional ML: True


In [3]:
%pip install anndata scanpy scikit-learn umap-learn --quiet
%pip install biopython --quiet
%pip install scikit-learn --quiet
%pip install umap-learn --quiet
%pip install hdbscan --quiet
%pip install plotly --quiet
%pip install xgboost --quiet
%pip install tensorflow --quiet

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m176.2/176.2 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m2.1/2.1 MB[0m [31m35.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m58.6/58.6 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m284.1/284.1 kB[0m [31m20.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m9.2/9.2 MB[0m [31m108.6 MB/s[0m eta [36m0:00:00[0m


In [4]:
# Ensure non-interactive Matplotlib backend to avoid font import issues
import matplotlib
try:
    matplotlib.use('Agg')
except Exception as e:
    print("Could not set Agg backend:", e)

# Set memory optimization flags
import os
os.environ['PYTHONHASHSEED'] = '0'
os.environ['OMP_NUM_THREADS'] = '4'  # Limit parallel threads to save memory

# --- Idempotent monkeypatch CountVectorizer.fit_transform to handle empty vocabulary errors ---
try:
    from sklearn.feature_extraction.text import CountVectorizer
    import scipy.sparse as _sps

    # Only patch once; store original on the class to avoid double-wrapping
    if not hasattr(CountVectorizer, '_orig_fit_transform'):
        CountVectorizer._orig_fit_transform = CountVectorizer.fit_transform

        def _safe_cv_fit(self, raw_docs, *args, **kwargs):
            try:
                return CountVectorizer._orig_fit_transform(self, raw_docs, *args, **kwargs)
            except ValueError as e:
                # Handle sklearn's "empty vocabulary" error by returning an all-zero matrix
                if 'empty vocabulary' in str(e).lower():
                    n = len(raw_docs) if raw_docs is not None else 0
                    return _sps.csr_matrix((n, 1))
                raise

        CountVectorizer.fit_transform = _safe_cv_fit
    else:
        # Already patched; do nothing
        pass
except Exception as e:
    # If sklearn/scipy are not available at import time, skip patching and log reason
    print("CountVectorizer monkeypatch skipped:", e)

In [5]:
# --- Initial Setup & Imports ---
import sys
import subprocess

# Install critical dependencies if missing
try:
    import Bio
except ImportError:
    print("Installing biopython...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "biopython"])

import pandas as pd
import requests
import os
import tarfile
import glob
from io import BytesIO
from collections import Counter
import warnings

# BioPython Imports
try:
    from Bio.Seq import Seq
    from Bio.SeqUtils import ProtParam
except ImportError:
    # If install just happened, might need re-import logic or kernel restart, 
    # but usually works in same session after import
    pass

# Suppress warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

# --- Environment Detection ---
IS_KAGGLE = os.path.exists('/kaggle/input') or os.environ.get('KAGGLE_KERNEL_RUN_TYPE') is not None
print(f"Running on Kaggle: {IS_KAGGLE}")

if IS_KAGGLE:
    # Ensure standard directories exist
    os.makedirs('/kaggle/working/Data', exist_ok=True)
    os.makedirs('/kaggle/working/Output', exist_ok=True)


Running on Kaggle: True


## Data Loading and Preparation
We analyze a single-cell dataset recently published by Sun et al. (2025) (GEO accession GSE300475). The data originates from the DFCI 16-466 clinical trial (NCT02999477), a randomized phase II study evaluating neoadjuvant nab-paclitaxel in combination with pembrolizumab for high-risk, early-stage HR+/HER2- breast cancer. The specific cohort analyzed consists of longitudinal peripheral blood mononuclear cell (PBMC) samples from patients in the chemotherapy-first arm.

Patients were classified into binary response categories based on Residual Cancer Burden (RCB) index assessed at surgery:
*   **Responders:** Patients achieving Pathologic Complete Response (pCR, RCB-0) or minimal residual disease (RCB-I).
*   **Non-Responders:** Patients with moderate (RCB-II) or extensive (RCB-III) residual disease.

The following code handles the downloading and extraction of the raw data files.

In [6]:
files_to_fetch = [
    {
        "name": "GSE300475_RAW.tar",
        "size": "565.5 Mb",
        "download_url": "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE300475&format=file",
        "type": "TAR (of CSV, MTX, TSV)"
    },
    {
        "name": "GSE300475_feature_ref.xlsx",
        "size": "5.4 Kb",
        "download_url": "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE300nnn/GSE300475/suppl/GSE300475%5Ffeature%5Fref.xlsx",
        "type": "XLSX"
    }
]

In [7]:
# Set download directory based on environment
from pathlib import Path

def _find_project_root():
    cwd = Path.cwd().resolve()
    for candidate in [cwd, *cwd.parents]:
        if (candidate / "README.md").exists() or (candidate / "Code").exists():
            return candidate
    return cwd

if IS_KAGGLE:
    # On Kaggle, use /kaggle/working which is writable
    download_dir = Path("/kaggle/working/Data")
else:
    # Local (VS Code / Windows): use project-root/Data
    project_root = _find_project_root()
    download_dir = project_root / "Data"

download_dir = Path(download_dir)
download_dir.mkdir(parents=True, exist_ok=True)
print(f"Downloads will be saved in: {download_dir.resolve()}\n")

def download_file(url, filename, destination_folder):
    """
    Downloads a file from a given URL to a specified destination folder.
    """
    filepath = os.path.join(destination_folder, filename)
    print(f"Attempting to download {filename} from {url}...")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()

        with open(filepath, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)

        print(f"Successfully downloaded {filename} to {filepath}")
        return filepath
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {filename}: {e}")
        return None

Downloads will be saved in: /kaggle/working/Data



In [8]:
for file_info in files_to_fetch:
    filename = file_info["name"]
    url = file_info["download_url"]
    file_type = file_info["type"]

    downloaded_filepath = download_file(url, filename, download_dir)

    # If the file is a TAR archive, extract it and list the contents
    if downloaded_filepath and filename.endswith(".tar"):
        print(f"Extracting {filename}...\n")
        try:
            with tarfile.open(downloaded_filepath, "r") as tar:
                # List contents
                members = tar.getnames()
                print(f"Files contained in {filename}:")
                for member in members:
                    print(f" - {member}")

                # Extract to a subdirectory within download_dir
                extract_path = os.path.join(download_dir, filename.replace(".tar", ""))
                os.makedirs(extract_path, exist_ok=True)
                tar.extractall(path=extract_path, filter='data')
                print(f"\nExtracted to: {extract_path}")
        except tarfile.TarError as e:
            print(f"Error extracting {filename}: {e}")

        print("-" * 50 + "\n")

Attempting to download GSE300475_RAW.tar from https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE300475&format=file...
Successfully downloaded GSE300475_RAW.tar to /kaggle/working/Data/GSE300475_RAW.tar
Extracting GSE300475_RAW.tar...

Files contained in GSE300475_RAW.tar:
 - GSM9061665_S1_barcodes.tsv.gz
 - GSM9061665_S1_features.tsv.gz
 - GSM9061665_S1_matrix.mtx.gz
 - GSM9061666_S2_barcodes.tsv.gz
 - GSM9061666_S2_features.tsv.gz
 - GSM9061666_S2_matrix.mtx.gz
 - GSM9061667_S3_barcodes.tsv.gz
 - GSM9061667_S3_features.tsv.gz
 - GSM9061667_S3_matrix.mtx.gz
 - GSM9061668_S4_barcodes.tsv.gz
 - GSM9061668_S4_features.tsv.gz
 - GSM9061668_S4_matrix.mtx.gz
 - GSM9061669_S5_barcodes.tsv.gz
 - GSM9061669_S5_features.tsv.gz
 - GSM9061669_S5_matrix.mtx.gz
 - GSM9061670_S6_barcodes.tsv.gz
 - GSM9061670_S6_features.tsv.gz
 - GSM9061670_S6_matrix.mtx.gz
 - GSM9061671_S7_barcodes.tsv.gz
 - GSM9061671_S7_features.tsv.gz
 - GSM9061671_S7_matrix.mtx.gz
 - GSM9061672_S8_barcodes.tsv.gz
 - GSM9061672_S

In [9]:
import gzip
import shutil
from pathlib import Path
import pandas as pd
import os

# NOTE: We SKIP explicit decompression to avoid consuming disk space/memory.
# Scanpy's read_10x_mtx and other tools can read .gz files directly.

def preview_file(file_path):
    """
    Display the first few lines of a file (supports .gz automatically)
    """
    if file_path is None: return
    
    print(f"\n--- Preview of {os.path.basename(file_path)} ---")
    try:
        # Handle gzip if extension matches
        opener = gzip.open if str(file_path).endswith('.gz') else open
        
        if str(file_path).endswith(".tsv") or str(file_path).endswith(".csv") or str(file_path).endswith(".tsv.gz") or str(file_path).endswith(".csv.gz"):
            # Use pandas with nrows 
            sep = '\t' if 'tsv' in str(file_path) else ','
            comp = 'gzip' if str(file_path).endswith('.gz') else None
            try:
                # Try reading with header inference
                df = pd.read_csv(file_path, sep=sep, nrows=5, compression=comp)
                print(df)
            except:
                print("Could not read as CSV/TSV")
        elif 'matrix.mtx' in str(file_path):
            # Read as text stream
            with opener(file_path, 'rt') as f: # 'rt' for text mode
                print("First 10 lines (header and data):")
                for _ in range(10):
                    line = f.readline()
                    if not line: break
                    print(line.strip())
        else:
            print(f"File type {file_path} preview not customized.")
    except Exception as e:
        print(f"Could not preview {file_path}: {e}")

# Define extract_dir based on download_dir from previous cell
extract_dir = os.path.join(download_dir, "GSE300475_RAW")
raw_data_dir = Path(extract_dir) # Explicitly define this for downstream cells
print(f"Raw data directory set to: {raw_data_dir}")

gz_files = []
for root, _, files in os.walk(extract_dir):
    for file in files:
        if file.endswith(".gz"):
            gz_files.append((os.path.join(root, file), root))

print(f"Found {len(gz_files)} .gz files. Ready for processing (Decompression skipped).")

# Just preview a few to ensure they are readable
for path, _ in gz_files[:3]:
    preview_file(path)

Raw data directory set to: /kaggle/working/Data/GSE300475_RAW
Found 43 .gz files. Ready for processing (Decompression skipped).

--- Preview of GSM9061673_S9_barcodes.tsv.gz ---
   AAACCTGAGAGAGCTC-1
0  AAACCTGAGCGTAGTG-1
1  AAACCTGAGCGTGAGT-1
2  AAACCTGAGGCATGTG-1
3  AAACCTGAGTCACGCC-1
4  AAACCTGAGTCCCACG-1

--- Preview of GSM9061673_S9_matrix.mtx.gz ---
First 10 lines (header and data):
%%MatrixMarket matrix coordinate integer general
%metadata_json: {"software_version": "cellranger-6.0.0", "format_version": 2}
36604 11480 19000971
25 1 1
60 1 1
62 1 1
63 1 1
146 1 1
171 1 3
174 1 1

--- Preview of GSM9061666_S2_barcodes.tsv.gz ---
   AAACCTGAGCCATCGC-1
0  AAACCTGAGGACGAAA-1
1  AAACCTGCAAGAAAGG-1
2  AAACCTGCATGACGGA-1
3  AAACCTGGTAAATACG-1
4  AAACCTGGTGGCAAAC-1


In [10]:
%pip install scanpy --quiet

Note: you may need to restart the kernel to use updated packages.


In [11]:
import glob

# Find all "all_contig_annotations.csv" files in the extracted directory and sum their lengths (number of rows)

all_contig_files = glob.glob(os.path.join(extract_dir, "*_all_contig_annotations.csv"))
total_rows = 0

for file in all_contig_files:
    try:
        df = pd.read_csv(file)
        num_rows = len(df)
        print(f"{os.path.basename(file)}: {num_rows} rows")
        total_rows += num_rows
    except Exception as e:
        print(f"Could not read {file}: {e}")

print(f"\nTotal rows in all contig annotation files: {total_rows}")


Total rows in all contig annotation files: 0


## 1. Load Sample Metadata

First, we load the metadata from the `GSE300475_feature_ref.xlsx` file. This file contains the crucial mapping between GEO sample IDs, patient IDs, timepoints, and treatment response.

In [12]:
%pip install scanpy pandas numpy --quiet
# Import required libraries for single-cell RNA-seq analysis and data handling
import scanpy as sc  # Main library for single-cell analysis, provides AnnData structure and many tools
import pandas as pd  # For tabular data manipulation and metadata handling
import numpy as np   # For numerical operations and array handling
import os            # For operating system interactions (file paths, etc.)
from pathlib import Path  # For robust and readable file path management

# Print versions to ensure reproducibility and compatibility
print(f"Scanpy version: {sc.__version__}")
print(f"Pandas version: {pd.__version__}")

from collections import Counter
import warnings
warnings.filterwarnings('ignore')

print("All libraries imported successfully!")

Note: you may need to restart the kernel to use updated packages.
Scanpy version: 1.12
Pandas version: 2.2.2
All libraries imported successfully!


In [13]:
import gzip
import shutil
from pathlib import Path
import pandas as pd
import os
from joblib import Parallel, delayed

def preview_file(file_path):
    """
    Display the first few lines of a decompressed file without loading the whole file into memory.
    """
    if file_path is None: return
    
    print(f"\n--- Preview of {os.path.basename(file_path)} ---")
    try:
        if file_path.endswith(".tsv") or file_path.endswith(".csv"):
            # Use pandas with nrows to avoid loading full file
            sep = '\t' if file_path.endswith(".tsv") else ','
            df = pd.read_csv(file_path, sep=sep, nrows=5) 
            print(df)
        elif file_path.endswith(".mtx"):
            # Read as text stream to avoid loading massive matrix into memory
            with open(file_path, 'r') as f:
                print("First 10 lines (header and data):")
                for _ in range(10):
                    line = f.readline()
                    if not line: break
                    print(line.strip())
        elif str(file_path).endswith(".gz"):
             print(f"File is compressed ({file_path}). Scanpy will handle decompression automatically.")
        else:
            print("Unsupported file type for preview.")
    except Exception as e:
        print(f"Could not preview {file_path}: {e}")

# Ensure download_dir exists (fallback protection)
if 'download_dir' not in globals():
     # Fallback logic if variable not in scope
     if 'IS_KAGGLE' in globals() and IS_KAGGLE:
         download_dir = "/kaggle/working/Data"
     else:
         def _find_project_root():
             cwd = Path.cwd().resolve()
             for candidate in [cwd, *cwd.parents]:
                 if (candidate / "README.md").exists() or (candidate / "Code").exists():
                     return candidate
             return cwd
         download_dir = str(_find_project_root() / "Data")

# Normalize download_dir to string for os.path usage
if isinstance(download_dir, Path):
    download_dir = str(download_dir)

extract_dir = os.path.join(download_dir, "GSE300475_RAW")

# --- PATH CORRECTION LOGIC ---
# If extract_dir is empty or missing, but files are in download_dir, use download_dir
if not os.path.exists(extract_dir) or not any(f.endswith('.gz') for f in os.listdir(extract_dir) if os.path.isfile(os.path.join(extract_dir, f))):
    if os.path.exists(download_dir) and any(f.endswith('.gz') for f in os.listdir(download_dir) if os.path.isfile(os.path.join(download_dir, f))):
         print(f"Detecting files in {download_dir} directly. Adjusting path.")
         extract_dir = download_dir

# --- MEMORY OPTIMIZATION ---
# We SKIP explicit decompression here because Scanpy's read_10x_mtx can read .gz files directly.
# Decompressing large sparse matrices to dense text files on disk is unnecessary and wastes storage/IO.
print("Skipping explicit decompression to save disk space and IO.")
print("Scanpy handles .gz files directly during loading.")

# Just preview one GZ file to show it exists
gz_files = []
for root, _, files in os.walk(extract_dir):
    for file in files:
        if file.endswith(".gz"):
            gz_files.append(os.path.join(root, file))

if gz_files:
    print(f"Found {len(gz_files)} compressed files ready for loading.")
    print(f"Example: {gz_files[0]}")

Skipping explicit decompression to save disk space and IO.
Scanpy handles .gz files directly during loading.
Found 43 compressed files ready for loading.
Example: /kaggle/working/Data/GSE300475_RAW/GSM9061673_S9_barcodes.tsv.gz


In [14]:
%%time
# --- Setup data paths ---
import os
from pathlib import Path
import tarfile
import requests

# Use existing IS_KAGGLE flag or detect
if 'IS_KAGGLE' not in globals():
    IS_KAGGLE = os.path.exists('/kaggle/input') or os.environ.get('KAGGLE_KERNEL_RUN_TYPE') is not None

def _find_project_root():
    cwd = Path.cwd().resolve()
    for candidate in [cwd, *cwd.parents]:
        if (candidate / "README.md").exists() or (candidate / "Code").exists():
            return candidate
    return cwd

def _has_matrix_files(path: Path) -> bool:
    return path.exists() and any(path.rglob("*matrix.mtx*"))

# Determine candidate base directories
candidate_dirs = []
if IS_KAGGLE:
    candidate_dirs = [Path('/kaggle/working/Data'), Path('/Data'), Path('/kaggle/input')]
else:
    project_root = _find_project_root()
    if 'download_dir' in globals() and download_dir:
        candidate_dirs.append(Path(download_dir))
    candidate_dirs += [project_root / 'Data', project_root / 'data', project_root]

raw_data_dir = None
for base in candidate_dirs:
    if base is None:
        continue
    base = Path(base)
    if base.name == 'GSE300475_RAW' and _has_matrix_files(base):
        raw_data_dir = base
        break
    if _has_matrix_files(base / 'GSE300475_RAW'):
        raw_data_dir = base / 'GSE300475_RAW'
        break
    if _has_matrix_files(base):
        raw_data_dir = base
        break

# Fallback: search under project root (local only)
if raw_data_dir is None and not IS_KAGGLE:
    project_root = _find_project_root()
    for match in project_root.rglob('GSE300475_RAW'):
        if _has_matrix_files(match):
            raw_data_dir = match
            break

# Auto-download if still missing
if raw_data_dir is None:
    print("Raw data not found locally. Attempting download...")
    if IS_KAGGLE:
        download_dir = Path('/kaggle/working/Data')
    else:
        project_root = _find_project_root()
        if 'download_dir' in globals() and download_dir:
            download_dir = Path(download_dir)
        else:
            download_dir = project_root / 'Data'
    download_dir.mkdir(parents=True, exist_ok=True)
    tar_path = download_dir / 'GSE300475_RAW.tar'
    if not tar_path.exists():
        url = 'https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE300475&format=file'
        try:
            with requests.get(url, stream=True) as r:
                r.raise_for_status()
                with open(tar_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            f.write(chunk)
            print(f"Downloaded {tar_path}")
        except Exception as e:
            raise RuntimeError(f"Download failed: {e}")
    extract_path = download_dir / 'GSE300475_RAW'
    if not extract_path.exists():
        print(f"Extracting {tar_path} to {extract_path}...")
        try:
            with tarfile.open(tar_path, 'r') as tar:
                try:
                    tar.extractall(path=extract_path, filter='data')
                except TypeError:
                    tar.extractall(path=extract_path)
        except Exception as e:
            raise RuntimeError(f"Extraction failed: {e}")
    raw_data_dir = extract_path

print(f"Data directory set to: {raw_data_dir}")

# --- Manually create the metadata mapping ---
# This list contains information about each sample, including GEO IDs, patient IDs, timepoints, and treatment response.
# Note: S8 (GSM9061672) has GEX files but no corresponding TCR file.
metadata_list = [
    # Patient 1 (Responder)
    {'S_Number': 'S1',  'GEX_Sample_ID': 'GSM9061665', 'TCR_Sample_ID': 'GSM9061687', 'Patient_ID': 'PT1',  'Timepoint': 'Baseline',     'Response': 'Responder',     'In_Data': 'Yes'},
    {'S_Number': 'S2',  'GEX_Sample_ID': 'GSM9061666', 'TCR_Sample_ID': 'GSM9061688', 'Patient_ID': 'PT1',  'Timepoint': 'Post-Tx',      'Response': 'Responder',     'In_Data': 'Yes'},
    {'S_Number': 'S3',  'GEX_Sample_ID': 'GSM9061667', 'TCR_Sample_ID': 'GSM9061689', 'Patient_ID': 'PT1',  'Timepoint': 'Recurrence',   'Response': 'Responder',     'In_Data': 'Yes'},
    
    # Patient 2 (Responder)
    {'S_Number': 'S4',  'GEX_Sample_ID': 'GSM9061668', 'TCR_Sample_ID': 'GSM9061690', 'Patient_ID': 'PT2',  'Timepoint': 'Baseline',     'Response': 'Responder',     'In_Data': 'Yes'},
    {'S_Number': 'S5',  'GEX_Sample_ID': 'GSM9061669', 'TCR_Sample_ID': 'GSM9061691', 'Patient_ID': 'PT2',  'Timepoint': 'Post-Tx',      'Response': 'Responder',     'In_Data': 'Yes'},
    
    # Patient 3 (Non-Responder)
    {'S_Number': 'S6',  'GEX_Sample_ID': 'GSM9061670', 'TCR_Sample_ID': 'GSM9061692', 'Patient_ID': 'PT3',  'Timepoint': 'Baseline',     'Response': 'Non-Responder', 'In_Data': 'Yes'},
    {'S_Number': 'S7',  'GEX_Sample_ID': 'GSM9061671', 'TCR_Sample_ID': 'GSM9061693', 'Patient_ID': 'PT3',  'Timepoint': 'Post-Tx',      'Response': 'Non-Responder', 'In_Data': 'Yes'},
    {'S_Number': 'S8',  'GEX_Sample_ID': 'GSM9061672', 'TCR_Sample_ID': None,         'Patient_ID': 'PT3',  'Timepoint': 'Recurrence',   'Response': 'Non-Responder', 'In_Data': 'GEX only'},
    
    # Patient 4 (Non-Responder)
    {'S_Number': 'S9',  'GEX_Sample_ID': 'GSM9061673', 'TCR_Sample_ID': 'GSM9061694', 'Patient_ID': 'PT4',  'Timepoint': 'Baseline',     'Response': 'Non-Responder', 'In_Data': 'Yes'},
    {'S_Number': 'S10', 'GEX_Sample_ID': 'GSM9061674', 'TCR_Sample_ID': 'GSM9061695', 'Patient_ID': 'PT4',  'Timepoint': 'Post-Tx',      'Response': 'Non-Responder', 'In_Data': 'Yes'},
    {'S_Number': 'S11', 'GEX_Sample_ID': 'GSM9061675', 'TCR_Sample_ID': 'GSM9061696', 'Patient_ID': 'PT4',  'Timepoint': 'Recurrence',   'Response': 'Non-Responder', 'In_Data': 'Yes'},
]

# Create pandas DataFrame for easy access
metadata_df = pd.DataFrame(metadata_list)
print("Metadata table now matches the requested specification:")
display(metadata_df)

# --- Programmatic sanity-check for file presence ---
# This loop checks if the expected files exist for each sample and updates the 'In_Data' column accordingly.
for idx, row in metadata_df.iterrows():
    s = row['S_Number']
    g = row['GEX_Sample_ID']
    t = row['TCR_Sample_ID']
    # Check for gene expression matrix file (compressed or uncompressed)
    # Check .mtx, .mtx.gz, and also potential file name variations or if they are in subfolders
    # We look in raw_data_dir found above.
    g_exists = (raw_data_dir / f"{g}_{s}_matrix.mtx.gz").exists() or (raw_data_dir / f"{g}_{s}_matrix.mtx").exists()
    
    # Also check if just the GSM id is present in some filename if strict match fails (fallback)
    if not g_exists:
         # Try simpler wildcard search
         g_exists = len(list(raw_data_dir.glob(f"*{g}*matrix*"))) > 0

    t_exists = False
    # Check for TCR annotation file if TCR sample ID is present
    if pd.notna(t) and t is not None:
        t_exists = (raw_data_dir / f"{t}_{s}_all_contig_annotations.csv.gz").exists() or (raw_data_dir / f"{t}_{s}_all_contig_annotations.csv").exists()
        if not t_exists:
             t_exists = len(list(raw_data_dir.glob(f"*{t}*all_contig_annotations*"))) > 0
             
    # Update 'In_Data' column based on file presence
    if g_exists and t_exists:
        metadata_df.at[idx, 'In_Data'] = 'Yes'
    elif g_exists and not t_exists:
        metadata_df.at[idx, 'In_Data'] = 'GEX only'
    else:
        metadata_df.at[idx, 'In_Data'] = 'No'

print("\nPost-check In_Data column (based on files found in raw_data_dir):")
display(metadata_df)

Data directory set to: /kaggle/working/Data/GSE300475_RAW
Metadata table now matches the requested specification:


Unnamed: 0,S_Number,GEX_Sample_ID,TCR_Sample_ID,Patient_ID,Timepoint,Response,In_Data
0,S1,GSM9061665,GSM9061687,PT1,Baseline,Responder,Yes
1,S2,GSM9061666,GSM9061688,PT1,Post-Tx,Responder,Yes
2,S3,GSM9061667,GSM9061689,PT1,Recurrence,Responder,Yes
3,S4,GSM9061668,GSM9061690,PT2,Baseline,Responder,Yes
4,S5,GSM9061669,GSM9061691,PT2,Post-Tx,Responder,Yes
5,S6,GSM9061670,GSM9061692,PT3,Baseline,Non-Responder,Yes
6,S7,GSM9061671,GSM9061693,PT3,Post-Tx,Non-Responder,Yes
7,S8,GSM9061672,,PT3,Recurrence,Non-Responder,GEX only
8,S9,GSM9061673,GSM9061694,PT4,Baseline,Non-Responder,Yes
9,S10,GSM9061674,GSM9061695,PT4,Post-Tx,Non-Responder,Yes



Post-check In_Data column (based on files found in raw_data_dir):


Unnamed: 0,S_Number,GEX_Sample_ID,TCR_Sample_ID,Patient_ID,Timepoint,Response,In_Data
0,S1,GSM9061665,GSM9061687,PT1,Baseline,Responder,Yes
1,S2,GSM9061666,GSM9061688,PT1,Post-Tx,Responder,Yes
2,S3,GSM9061667,GSM9061689,PT1,Recurrence,Responder,Yes
3,S4,GSM9061668,GSM9061690,PT2,Baseline,Responder,Yes
4,S5,GSM9061669,GSM9061691,PT2,Post-Tx,Responder,Yes
5,S6,GSM9061670,GSM9061692,PT3,Baseline,Non-Responder,Yes
6,S7,GSM9061671,GSM9061693,PT3,Post-Tx,Non-Responder,Yes
7,S8,GSM9061672,,PT3,Recurrence,Non-Responder,GEX only
8,S9,GSM9061673,GSM9061694,PT4,Baseline,Non-Responder,Yes
9,S10,GSM9061674,GSM9061695,PT4,Post-Tx,Non-Responder,Yes


CPU times: user 22.8 ms, sys: 3.93 ms, total: 26.8 ms
Wall time: 26.2 ms


In [15]:
%%time
# --- DISK-BASED MAP-REDUCE STRATEGY TO SOLVE OOM ---
# Strategy: 
# 1. Map: Process each sample -> QC -> Save to temp .h5ad on disk
# 2. Reduce: Concatenate on disk (preferred) or in small batches
# This keeps RAM usage low during processing and avoids the iterative reallocation spike.

import gc
import shutil
import scanpy as sc
import anndata as ad
import pandas as pd
import scipy.sparse as sp
import numpy as np
import os
from pathlib import Path

gc.enable()

# Validate prerequisites
if 'metadata_df' not in globals():
    raise NameError("metadata_df is not defined. Please run the metadata creation cell first.")
if 'raw_data_dir' not in globals():
    raise NameError("raw_data_dir is not defined. Please run the data path setup cell first.")

# Setup temp directory for chunks
temp_chunk_dir = Path("temp_adata_chunks")
if temp_chunk_dir.exists():
    shutil.rmtree(temp_chunk_dir)
temp_chunk_dir.mkdir(exist_ok=True)

chunk_files = []
chunk_keys = []
tcr_data_list = []  # TCR data is small enough to keep in memory

print("Starting Map Phase (Processing & Saving Chunks)...")

# --- MAP PHASE: Process & Save ---
for index, row in metadata_df.iterrows():
    gex_sample_id = row['GEX_Sample_ID']
    s_number = row['S_Number']
    
    # Construct sample-level prefix
    sample_prefix = f"{gex_sample_id}_{s_number}"

    # Use robust file finding logic from previous cells
    matrix_file = None
    for ext in ['matrix.mtx.gz', 'matrix.mtx']:
        candidate = raw_data_dir / f"{sample_prefix}_{ext}"
        if candidate.exists():
            matrix_file = candidate
            break
            
    if not matrix_file:
         # Fallback search
        for ext in ['matrix.mtx.gz', 'matrix.mtx']:
            possible_files = list(raw_data_dir.glob(f"*{gex_sample_id}*{ext}"))
            if possible_files:
                matrix_file = possible_files[0]
                break
    
    if not matrix_file:
        print(f"Skipping {sample_prefix}: Matrix file not found.")
        continue

    sample_data_path = matrix_file.parent
    matrix_prefix = matrix_file.name.replace('matrix.mtx', '').replace('.gz', '')

    print(f"Processing {index+1}/{len(metadata_df)}: {sample_prefix}")
    
    try:
        # Load GEX
        adata_sample = sc.read_10x_mtx(
            sample_data_path, 
            var_names='gene_symbols',
            prefix=matrix_prefix,
            cache=True
        )
        
        # Ensure sparse float32 IMMEDIATELY
        if not hasattr(adata_sample.X, 'toarray'):
            adata_sample.X = sp.csr_matrix(adata_sample.X, dtype=np.float32)
        else:
            adata_sample.X = sp.csr_matrix(adata_sample.X, dtype=np.float32)
            
        # Add metadata
        adata_sample.obs['sample_id'] = gex_sample_id 
        adata_sample.obs['patient_id'] = row['Patient_ID']
        adata_sample.obs['timepoint'] = row['Timepoint']
        adata_sample.obs['response'] = row['Response']
        
        # QC Filtering (Crucial reduction)
        sc.pp.filter_cells(adata_sample, min_genes=200)
        sc.pp.filter_genes(adata_sample, min_cells=3)
        
        # Ensure unique var names before saving
        adata_sample.var_names_make_unique()
        
        # Save chunk
        chunk_path = temp_chunk_dir / f"chunk_{index}_{gex_sample_id}.h5ad"
        adata_sample.write_h5ad(chunk_path, compression='gzip')
        chunk_files.append(chunk_path)
        chunk_keys.append(sample_prefix)
        
        print(f"  Saved chunk: {adata_sample.n_obs} cells. Memory cleared.")
        
        # Cleanup
        del adata_sample
        gc.collect()

    except Exception as e:
        print(f"  Error processing {sample_prefix}: {e}")
        continue
        
    # TCR Loading (Keep separate list)
    tcr_sample_id = row['TCR_Sample_ID']
    if pd.notna(tcr_sample_id):
        tcr_file = raw_data_dir / f"{tcr_sample_id}_{s_number}_all_contig_annotations.csv.gz"
        if not tcr_file.exists():
             tcr_file = raw_data_dir / f"{tcr_sample_id}_{s_number}_all_contig_annotations.csv"
             
        if tcr_file.exists():
            try:
                # Load essential columns only
                cols = ['barcode', 'is_cell', 'contig_id', 'high_confidence', 'length', 
                        'chain', 'v_gene', 'd_gene', 'j_gene', 'c_gene', 'full_length', 
                        'productive', 'cdr3']
                        
                # Check which columns actually exist in the file first to strictly avoid errors?
                # Faster to just try/except or load all if cols obscure
                # Let's try loading header first? No, pandas handling is fine.
                tcr_df = pd.read_csv(tcr_file, usecols=lambda c: c in cols)
                tcr_df['sample_id'] = gex_sample_id
                tcr_data_list.append(tcr_df)
            except:
                pass


# --- REDUCE PHASE: Disk-safe Concatenation ---
print(f"\nStarting Reduce Phase (Merging {len(chunk_files)} chunks)...")

if not chunk_files:
    raise ValueError("No chunks were saved! Check data paths.")

merged_path = temp_chunk_dir / "merged.h5ad"
adata = None

# Prefer on-disk concatenation if available (anndata>=0.9)
try:
    if hasattr(ad, "experimental") and hasattr(ad.experimental, "concat_on_disk"):
        print("Using on-disk concatenation (anndata.experimental.concat_on_disk)...")
        # Use 'batch' as the label instead of 'sample_id' to preserve the original sample_id column
        ad.experimental.concat_on_disk(
            [str(p) for p in chunk_files],
            str(merged_path),
            join='inner',  # intersection avoids union blow-up
            merge='same',
            label='batch',  # Changed from 'sample_id' to preserve original sample_id
            keys=chunk_keys,
            index_unique='-'
        )
        adata = sc.read_h5ad(merged_path)
    else:
        raise AttributeError("concat_on_disk not available in this anndata version")
except Exception as e:
    print(f"Falling back to batch in-memory concat: {e}")
    batch_size = 4
    adata = None
    for i in range(0, len(chunk_files), batch_size):
        batch_files = chunk_files[i:i+batch_size]
        batch_adatas = [sc.read_h5ad(f) for f in batch_files]
        batch = ad.concat(batch_adatas, join='inner', merge='same', index_unique='-')
        del batch_adatas
        gc.collect()
        if adata is None:
            adata = batch
        else:
            adata = ad.concat([adata, batch], join='inner', merge='same', index_unique='-')
        del batch
        gc.collect()

# Final sparse enforcement
if not sp.issparse(adata.X):
    adata.X = sp.csr_matrix(adata.X, dtype=np.float32)

print(f"Final Merged Data: {adata.n_obs} cells x {adata.n_vars} genes")

# Merge TCR data
if tcr_data_list:
    full_tcr_df = pd.concat(tcr_data_list, ignore_index=True)
    print(f"Full TCR Data: {len(full_tcr_df)} rows")
    del tcr_data_list
else:
    print("No TCR data loaded.")

# Cleanup chunks
shutil.rmtree(temp_chunk_dir)
print("Temp chunks cleaned up.")

# Dummy adata_list for compatibility
adata_list = []

Starting Map Phase (Processing & Saving Chunks)...
Processing 1/11: GSM9061665_S1
  Saved chunk: 8804 cells. Memory cleared.
Processing 2/11: GSM9061666_S2
  Saved chunk: 9037 cells. Memory cleared.
Processing 3/11: GSM9061667_S3
  Saved chunk: 7343 cells. Memory cleared.
Processing 4/11: GSM9061668_S4
  Saved chunk: 8608 cells. Memory cleared.
Processing 5/11: GSM9061669_S5
  Saved chunk: 2887 cells. Memory cleared.
Processing 6/11: GSM9061670_S6
  Saved chunk: 10353 cells. Memory cleared.
Processing 7/11: GSM9061671_S7
  Saved chunk: 9186 cells. Memory cleared.
Processing 8/11: GSM9061672_S8
  Saved chunk: 12665 cells. Memory cleared.
Processing 9/11: GSM9061673_S9
  Saved chunk: 11216 cells. Memory cleared.
Processing 10/11: GSM9061674_S10
  Saved chunk: 9582 cells. Memory cleared.
Processing 11/11: GSM9061675_S11
  Saved chunk: 9286 cells. Memory cleared.

Starting Reduce Phase (Merging 11 chunks)...
Using on-disk concatenation (anndata.experimental.concat_on_disk)...
Final Merged 

## 3. Integrate TCR Data and Perform QC

Next, we'll merge the TCR information into the `.obs` of our main `AnnData` object. We will keep only the cells that have corresponding TCR data and filter based on the `high_confidence` flag.

In [16]:
# 3. Raw Processing Branch (Only runs if needed)
# Auto-define should_process_raw if missing to avoid NameError
if 'should_process_raw' not in globals():
    _loaded_h5ad = bool(globals().get('loaded_h5ad', False))
    _adata_missing = ('adata' not in globals()) or (adata is None)
    _metadata_ready = 'metadata_df' in globals()
    should_process_raw = _metadata_ready and (not _loaded_h5ad) and _adata_missing
    print(f"should_process_raw not set; defaulting to {should_process_raw} (loaded_h5ad={_loaded_h5ad}, adata_missing={_adata_missing})")

if should_process_raw:
    print("Starting raw data processing from metadata...")

    # Ensure raw_data_dir is defined
    if 'raw_data_dir' not in globals():
        base_dir = Path('/kaggle/working/Data') if (globals().get('IS_KAGGLE', False)) else Path('../Data')
        raw_data_dir = base_dir / 'GSE300475_RAW'
        print(f"raw_data_dir undefined. Defaulting to: {raw_data_dir}")
    
    # --- Initialize lists ---
    adata_list = []  
    tcr_data_list = []  

    # --- Iterate through each sample ---
    for index, row in metadata_df.iterrows():
        gex_sample_id = row['GEX_Sample_ID']
        tcr_sample_id = row['TCR_Sample_ID']
        s_number = row['S_Number']
        patient_id = row['Patient_ID']
        timepoint = row['Timepoint']
        response = row['Response']
        
        print(f"Processing sample {index+1}/{len(metadata_df)}: {gex_sample_id} ({s_number})...")
        
        # --- Robust File Finding (Fixing 'GEX data not found') ---
        # Pattern: *GSM123*matrix.mtx* matches both .mtx and .mtx.gz
        try:
            found_gex_files = list(raw_data_dir.rglob(f"*{gex_sample_id}*matrix.mtx*"))
        except Exception as e:
            print(f"Error searching {raw_data_dir}: {e}")
            found_gex_files = []
        
        if not found_gex_files:
            print(f"  Warning: GEX matrix file for {gex_sample_id} not found in {raw_data_dir}. Skipping.")
            try:
                 print("  Debug: Listing first 5 files in raw_data_dir to help diagnose:")
                 for i, p in enumerate(raw_data_dir.rglob('*')):
                     if i >= 5: break
                     print(f"    {p.name}")
            except: pass
            continue

should_process_raw not set; defaulting to False (loaded_h5ad=False, adata_missing=False)


## 4. Save Processed Data

Finally, we save the fully processed, annotated, and filtered `AnnData` object to a `.h5ad` file. This file can be easily loaded in future notebooks for analysis.

In [17]:
%%time
# --- Integrate TCR data into AnnData.obs and perform quality control ---

# Validate that adata exists
if 'adata' not in globals():
    raise NameError("adata is not defined. Please run the data loading cell first.")
if adata is None:
    raise ValueError("adata is None. Data loading may have failed.")

# Check if TCR data exists and is not empty
if 'full_tcr_df' in globals() and isinstance(full_tcr_df, pd.DataFrame) and not full_tcr_df.empty:
    print(f"Integrating TCR data into AnnData (TCR contigs: {len(full_tcr_df)}, cells: {adata.n_obs})...")
    
    try:
        # --- TCR Data Aggregation ---
        # The previous join failed because one cell (barcode) can have multiple TCR contigs (e.g., TRA and TRB chains),
        # creating a one-to-many join that increases the number of rows.
        # The fix is to aggregate the TCR data to one row per cell *before* merging.

        # 1. Filter for high-confidence, productive TRA/TRB chains.
        # Only keep TCR contigs that are both high-confidence and productive, and are either TRA or TRB chains.
        if 'high_confidence' not in full_tcr_df.columns or 'productive' not in full_tcr_df.columns or 'chain' not in full_tcr_df.columns:
            print("WARNING: TCR dataframe missing required columns (high_confidence, productive, chain). Skipping TCR integration.")
            tcr_to_agg = pd.DataFrame()
        else:
            tcr_to_agg = full_tcr_df[
                (full_tcr_df['high_confidence'] == True) &
                (full_tcr_df['productive'] == True) &
                (full_tcr_df['chain'].isin(['TRA', 'TRB']))
            ].copy()

        if not tcr_to_agg.empty:
            # 2. Pivot the data to create one row per barcode, with columns for TRA and TRB data.
            # This step ensures each cell (barcode) has its TRA and TRB info in separate columns.
            tcr_aggregated = tcr_to_agg.pivot_table(
                index=['sample_id', 'barcode'],
                columns='chain',
                values=['v_gene', 'j_gene', 'cdr3'],
                aggfunc='first'  # 'first' is safe as we expect at most one productive TRA/TRB per cell
            )

            # 3. Flatten the multi-level column index (e.g., from ('v_gene', 'TRA') to 'v_gene_TRA')
            tcr_aggregated.columns = ['_'.join(col).strip() for col in tcr_aggregated.columns.values]
            tcr_aggregated.reset_index(inplace=True)

            # --- DEBUG: Print sample formats to diagnose any mismatches ---
            print(f"  DEBUG: adata.obs sample_id examples: {adata.obs['sample_id'].unique()[:3].tolist()}")
            print(f"  DEBUG: TCR sample_id examples: {tcr_aggregated['sample_id'].unique()[:3].tolist()}")
            
            # 4. Prepare adata.obs for the merge by creating a matching barcode column.
            # The index in adata.obs is like 'AGCCATGCAGCTGTTA-1-0' (barcode-concat_suffix).
            # The barcode in TCR data is like 'AGCCATGCAGCTGTTA-1'.
            adata.obs['barcode_for_merge'] = adata.obs.index.str.rsplit('-', n=1).str[0]
            
            # Handle case where sample_id might have been modified by concat (fallback fix)
            # Extract just the GSM ID if sample_id contains underscores (e.g., "GSM9061665_S1" -> "GSM9061665")
            if adata.obs['sample_id'].astype(str).str.contains('_').any():
                print("  INFO: sample_id contains underscores, extracting GSM ID portion for merge...")
                adata.obs['sample_id_for_merge'] = adata.obs['sample_id'].astype(str).str.split('_').str[0]
            else:
                adata.obs['sample_id_for_merge'] = adata.obs['sample_id']
            
            print(f"  DEBUG: adata barcode examples: {adata.obs['barcode_for_merge'].head(3).tolist()}")
            print(f"  DEBUG: TCR barcode examples: {tcr_aggregated['barcode'].head(3).tolist()}")

            # 5. Perform a left merge. This keeps all cells from adata and adds TCR info where available.
            # The number of rows will not change because tcr_aggregated has unique barcodes per sample.
            original_obs = adata.obs.copy()
            merged_obs = original_obs.merge(
                tcr_aggregated,
                left_on=['sample_id_for_merge', 'barcode_for_merge'],
                right_on=['sample_id', 'barcode'],
                how='left',
                suffixes=('', '_tcr')
            )
            
            # 6. Restore the original index to the merged dataframe.
            merged_obs.index = original_obs.index
            adata.obs = merged_obs
            
            # Clean up redundant columns from merge
            cols_to_drop = [c for c in ['sample_id_tcr', 'sample_id_for_merge'] if c in adata.obs.columns]
            if cols_to_drop:
                adata.obs.drop(columns=cols_to_drop, inplace=True)

            # Check how many cells got TCR info
            tcr_col = 'v_gene_TRA' if 'v_gene_TRA' in adata.obs.columns else None
            if tcr_col:
                cells_with_tcr = (~adata.obs[tcr_col].isna()).sum()
                print(f"Successfully merged TCR data. Cells with TCR info: {cells_with_tcr} / {adata.n_obs}")
                
                if cells_with_tcr == 0:
                    print("WARNING: No cells matched TCR data! Check barcode/sample_id formats.")
                    print("  Skipping TCR filtering to preserve data.")
                else:
                    # --- Filter for cells that have TCR information after the merge ---
                    # Only keep cells with non-null v_gene_TRA (i.e., cells with high-confidence TCR data)
                    initial_cells = adata.n_obs
                    adata = adata[~adata.obs[tcr_col].isna()].copy()
                    print(f"Filtered from {initial_cells} to {adata.n_obs} cells based on having high-confidence TCR data.")
            else:
                print("WARNING: TCR merge did not produce expected columns. Skipping TCR filtering.")
        else:
            print("WARNING: No high-confidence productive TRA/TRB chains found in TCR data. Skipping TCR filtering.")
            
    except Exception as e:
        import traceback
        print(f"ERROR during TCR integration: {e}")
        traceback.print_exc()
        print("Proceeding without TCR integration...")
else:
    print("No TCR data available or full_tcr_df is empty. Proceeding without TCR integration...")

# --- Basic QC and filtering ---
try:
    print(f"\nPerforming QC filtering (starting with {adata.n_obs} cells, {adata.n_vars} genes)...")
    
    # Filter out cells with fewer than 200 genes detected
    sc.pp.filter_cells(adata, min_genes=200)
    print(f"  After min_genes filter: {adata.n_obs} cells")
    
    # Filter out genes detected in fewer than 3 cells
    sc.pp.filter_genes(adata, min_cells=3)
    print(f"  After min_cells filter: {adata.n_vars} genes")

    # Annotate mitochondrial genes for QC metrics
    adata.var['mt'] = adata.var_names.str.startswith('MT-')
    # Calculate QC metrics (e.g., percent mitochondrial genes)
    sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

    print("\nPost-QC AnnData object:")
    print(adata)
    print("\nSample metadata preview:")
    display(adata.obs.head())
    
except Exception as e:
    print(f"ERROR during QC filtering: {e}")
    raise

Integrating TCR data into AnnData (TCR contigs: 162133, cells: 98967)...
  DEBUG: adata.obs sample_id examples: ['GSM9061665', 'GSM9061666', 'GSM9061667']
  DEBUG: TCR sample_id examples: ['GSM9061665', 'GSM9061666', 'GSM9061667']
  DEBUG: adata barcode examples: ['AAACCTGAGAAGGGTA-1', 'AAACCTGAGACTGTAA-1', 'AAACCTGAGCAGCGTA-1']
  DEBUG: TCR barcode examples: ['AAACCTGAGACTGTAA-1', 'AAACCTGAGCGTGAAC-1', 'AAACCTGAGCTACCTA-1']
Successfully merged TCR data. Cells with TCR info: 38413 / 98967
Filtered from 98967 to 38413 cells based on having high-confidence TCR data.

Performing QC filtering (starting with 38413 cells, 14819 genes)...
  After min_genes filter: 38413 cells
  After min_cells filter: 14816 genes

Post-QC AnnData object:
AnnData object with n_obs √ó n_vars = 38413 √ó 14816
    obs: 'sample_id', 'patient_id', 'timepoint', 'response', 'n_genes', 'batch', 'barcode_for_merge', 'barcode', 'cdr3_TRA', 'cdr3_TRB', 'j_gene_TRA', 'j_gene_TRB', 'v_gene_TRA', 'v_gene_TRB', 'n_genes_by_c

Unnamed: 0,sample_id,patient_id,timepoint,response,n_genes,batch,barcode_for_merge,barcode,cdr3_TRA,cdr3_TRB,j_gene_TRA,j_gene_TRB,v_gene_TRA,v_gene_TRB,n_genes_by_counts,total_counts,total_counts_mt,pct_counts_mt
AAACCTGAGACTGTAA-1-GSM9061665_S1,GSM9061665,PT1,Baseline,Responder,1379,GSM9061665_S1,AAACCTGAGACTGTAA-1,AAACCTGAGACTGTAA-1,CAVEARNYKLTF,CASGTGLNTEAFF,TRAJ53,TRBJ1-1,TRAV36/DV7,TRBV3-1,1379,4637.0,157.0,3.38581
AAACCTGAGCGTGAAC-1-GSM9061665_S1,GSM9061665,PT1,Baseline,Responder,1275,GSM9061665_S1,AAACCTGAGCGTGAAC-1,AAACCTGAGCGTGAAC-1,CAASAVGNEKLTF,CAWSALLGTVNGYTF,TRAJ48,TRBJ1-2,TRAV29/DV5,TRBV30,1275,4843.0,247.0,5.100144
AAACCTGAGCTACCTA-1-GSM9061665_S1,GSM9061665,PT1,Baseline,Responder,886,GSM9061665_S1,AAACCTGAGCTACCTA-1,AAACCTGAGCTACCTA-1,CALSEAWGNARLMF,CASRSREETYEQYF,TRAJ31,TRBJ2-7,TRAV19,TRBV2,886,3076.0,280.0,9.102731
AAACCTGAGCTGTTCA-1-GSM9061665_S1,GSM9061665,PT1,Baseline,Responder,1628,GSM9061665_S1,AAACCTGAGCTGTTCA-1,AAACCTGAGCTGTTCA-1,CALLGLKGEGSARQLTF,CASSLPPWRANTEAFF,TRAJ22,TRBJ1-1,TRAV9-2,TRBV11-2,1628,4914.0,288.0,5.860806
AAACCTGAGGCATTGG-1-GSM9061665_S1,GSM9061665,PT1,Baseline,Responder,1313,GSM9061665_S1,AAACCTGAGGCATTGG-1,AAACCTGAGGCATTGG-1,CAVTGFSDGQKLLF,CASSLTGEVWDEQFF,TRAJ16,TRBJ2-1,TRAV8-6,TRBV5-1,1313,4947.0,198.0,4.002426


CPU times: user 3.41 s, sys: 928 ms, total: 4.33 s
Wall time: 4.35 s


In [18]:
# MEMORY TIP: Save intermediate results to avoid reprocessing
# Uncomment the lines below to save the concatenated data before TCR integration

# output_dir = Path("/kaggle/working/Output") if IS_KAGGLE else Path("../Output")
# output_dir.mkdir(exist_ok=True, parents=True)
# checkpoint_file = output_dir / "adata_concatenated_checkpoint.h5ad"
# 
# print(f"Saving checkpoint to {checkpoint_file}...")
# adata.write_h5ad(checkpoint_file, compression='gzip')
# print(f"Checkpoint saved! File size: {checkpoint_file.stat().st_size / 1024**2:.2f} MB")
# 
# # To load this checkpoint later, use:
# # adata = sc.read_h5ad(checkpoint_file)

print("Proceeding with TCR integration...")

Proceeding with TCR integration...


In [19]:
# --- Show basic statistics about the dataset ---

# Validate that adata exists
if 'adata' not in globals():
    raise NameError("adata is not defined. Please run the data loading and QC cells first.")
if adata is None:
    raise ValueError("adata is None. Data loading may have failed.")

print("=== Dataset Statistics ===")
print(f"Total cells: {adata.n_obs}")
print(f"Total genes: {adata.n_vars}")

if 'sample_id' in adata.obs.columns:
    print(f"\nSamples: {adata.obs['sample_id'].nunique()}")
    print(adata.obs['sample_id'].value_counts())

if 'patient_id' in adata.obs.columns:
    print(f"\nPatients: {adata.obs['patient_id'].nunique()}")
    print(adata.obs['patient_id'].value_counts())

if 'response' in adata.obs.columns:
    print(f"\nResponse distribution:")
    print(adata.obs['response'].value_counts())

if 'timepoint' in adata.obs.columns:
    print(f"\nTimepoint distribution:")
    print(adata.obs['timepoint'].value_counts())

=== Dataset Statistics ===
Total cells: 38413
Total genes: 14816

Samples: 10
sample_id
GSM9061670    5310
GSM9061674    5070
GSM9061673    5045
GSM9061671    4838
GSM9061665    4008
GSM9061666    3855
GSM9061675    3774
GSM9061667    3127
GSM9061668    2471
GSM9061669     915
Name: count, dtype: int64

Patients: 4
patient_id
PT4    13889
PT1    10990
PT3    10148
PT2     3386
Name: count, dtype: int64

Response distribution:
response
Non-Responder    24037
Responder        14376
Name: count, dtype: int64

Timepoint distribution:
timepoint
Baseline      16834
Post-Tx       14678
Recurrence     6901
Name: count, dtype: int64


## 5. Install Additional Libraries for Advanced ML and Visualization

Install and import libraries such as XGBoost, TensorFlow/Keras, scipy, and additional visualization tools for comprehensive ML analysis.

In [20]:
%%time
# --- Install required packages for genetic sequence encoding and ML ---
%pip install biopython --quiet
%pip install scikit-learn --quiet
%pip install umap-learn --quiet
%pip install hdbscan --quiet
%pip install plotly --quiet
%pip install xgboost --quiet
%pip install tensorflow --quiet

from Bio.Seq import Seq
from Bio.SeqUtils import ProtParam
import xgboost as xgb
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# Import scipy for hierarchical clustering
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.spatial.distance import pdist
from scipy.stats import mannwhitneyu

# Visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import ConfusionMatrixDisplay

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from sklearn.metrics import silhouette_score, adjusted_rand_score, classification_report, confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score, accuracy_score, classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold, GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
import umap
import hdbscan
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

print("Additional libraries installed!")

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


2026-02-06 05:40:43.117204: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770356443.295466      24 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770356443.348594      24 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770356443.802292      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770356443.802324      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770356443.802328      24 computation_placer.cc:177] computation placer alr

Additional libraries installed!
CPU times: user 27.2 s, sys: 2.73 s, total: 30 s
Wall time: 1min


In [21]:
# --- GPU acceleration helper (minimal, safe) ---
# Detect GPUs for TensorFlow, enable memory growth and mixed precision if available.
# Detect XGBoost GPU support and cuML availability.
# Provide a function _apply_gpu_patches() that will patch `models_eval` and `param_grids` in-place when they exist.

TF_GPU_AVAILABLE = False
MIXED_PRECISION_AVAILABLE = False
XGBOOST_GPU_AVAILABLE = False
CUML_AVAILABLE = False

try:
    import tensorflow as tf
    gpus = tf.config.list_physical_devices('GPU')
    TF_GPU_AVAILABLE = len(gpus) > 0
    if TF_GPU_AVAILABLE:
        print("TensorFlow GPUs detected:", gpus)
        try:
            for g in gpus:
                tf.config.experimental.set_memory_growth(g, True)
            print("Set memory growth for TensorFlow GPUs.")
        except Exception as e:
            print("Could not set memory growth:", e)
        # Try enabling mixed precision for faster FP16 compute on modern GPUs
        try:
            from tensorflow.keras import mixed_precision
            mixed_precision.set_global_policy('mixed_float16')
            MIXED_PRECISION_AVAILABLE = True
            print("Enabled mixed precision (mixed_float16).")
        except Exception as e:
            print("Mixed precision policy not enabled:", e)
    else:
        print("No TensorFlow GPU detected.")
except Exception as e:
    print("TensorFlow import failed or no GPUs:", e)

# XGBoost GPU detection - supports both old (gpu_hist) and new (device='cuda') APIs
XGBOOST_GPU_METHOD = None  # Will be 'device' for XGBoost 2.0+, 'tree_method' for older versions
try:
    import xgboost as xgb
    xgb_version = tuple(int(x) for x in xgb.__version__.split('.')[:2])
    print(f"XGBoost version: {xgb.__version__}")
    
    # XGBoost 2.0+ uses device='cuda', older uses tree_method='gpu_hist'
    if xgb_version >= (2, 0):
        try:
            # Test new API
            _ = xgb.XGBClassifier(device='cuda', n_estimators=1)
            XGBOOST_GPU_AVAILABLE = True
            XGBOOST_GPU_METHOD = 'device'
            print("XGBoost GPU support detected (device='cuda' API).")
        except Exception as e:
            print(f"XGBoost 2.0+ GPU not available: {e}")
    else:
        try:
            # Test old API
            _ = xgb.XGBClassifier(tree_method='gpu_hist', predictor='gpu_predictor', n_estimators=1)
            XGBOOST_GPU_AVAILABLE = True
            XGBOOST_GPU_METHOD = 'tree_method'
            print("XGBoost GPU support detected (tree_method='gpu_hist' API).")
        except Exception as e:
            print(f"XGBoost GPU not available: {e}")
except Exception as e:
    print("XGBoost not importable:", e)

# cuML detection
try:
    import cuml
    CUML_AVAILABLE = True
    print("cuML is available.")
except Exception:
    CUML_AVAILABLE = False

# Utility: robust getter for adata.obsm with mask and padding
def _get_obsm_or_zeros(adata, key, mask=None, n_cols=0):
    """
    Return adata.obsm[key][mask] if present, otherwise zeros(shape=(n_rows, n_cols)).
    Ensures output is a dense numpy array with n_cols columns (pads with zeros if needed).
    """
    import numpy as _np
    # Determine number of rows requested
    if mask is not None:
        try:
            n_rows = int(mask.sum()) if hasattr(mask, 'sum') else int(sum(1 for v in mask if v))
        except Exception:
            n_rows = int(sum(1 for v in mask if v))
    else:
        n_rows = getattr(adata, 'n_obs', adata.shape[0]) if 'adata' in globals() else 0

    if key in getattr(adata, 'obsm', {}):
        arr = adata.obsm[key]
        try:
            if hasattr(arr, 'toarray'):
                arr = arr.toarray()
            arr = _np.asarray(arr)
        except Exception:
            return _np.zeros((n_rows, n_cols))
        # Apply mask if provided
        if mask is not None:
            try:
                arr = arr[mask]
            except Exception:
                arr = _np.array(arr)[mask]
        # Pad or trim columns to n_cols if requested
        if n_cols:
            if arr.shape[1] < n_cols:
                pad = _np.zeros((arr.shape[0], n_cols - arr.shape[1]))
                arr = _np.hstack([arr, pad])
            elif arr.shape[1] > n_cols:
                arr = arr[:, :n_cols]
        return arr
    else:
        return _np.zeros((n_rows, n_cols))

# Define sensible default param_grids early so LOPO can see them (will be overridden later if redefined)
param_grids = {
    'Logistic Regression': {
        'C': [0.01, 0.1, 1, 10, 100],
        'penalty': ['l2'],
        'solver': ['liblinear']
    },
    'Decision Tree': {
        'max_depth': [5, 10, 20, None],
        'min_samples_split': [2, 5, 10],
        'min_samples_leaf': [1, 2, 4]
    },
    'Random Forest': {
        'n_estimators': [50, 100, 200],
        'max_depth': [10, 20, None],
        'min_samples_split': [2, 5, 10]
    },
    'XGBoost': {
        'n_estimators': [50, 100, 200],
        'max_depth': [3, 6, 9],
        'learning_rate': [0.01, 0.1, 0.3],
        'subsample': [0.8, 1.0],
        'colsample_bytree': [0.6, 0.8, 1.0]
    }
}
print("Default param_grids defined early (can be overridden later).")

# Patching helper (improved with signature filtering and XGBoost 2.0+ support)
def _apply_gpu_patches():
    import inspect
    try:
        # Check if models_eval exists before trying to access it
        if 'models_eval' not in globals():
            return  # Nothing to patch yet
            
        models_eval_ref = globals()['models_eval']
        
        # Patch XGBoost model to use GPU params when available and supported
        if 'XGBoost' in models_eval_ref and XGBOOST_GPU_AVAILABLE:
            try:
                import xgboost as xgb_mod
                m = models_eval_ref['XGBoost']
                params = m.get_params() if hasattr(m, 'get_params') else {}
                # Determine class to instantiate (prefer wrapper if provided)
                XGBClass = globals().get('XGBClassifierSK', getattr(xgb_mod, 'XGBClassifier', None))
                if XGBClass is None:
                    raise ImportError('xgboost.XGBClassifier not found')
                # Build filtered params list based on constructor signature
                sig = inspect.signature(XGBClass.__init__)
                accepts_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
                allowed = set(sig.parameters.keys())
                filtered_params = {}
                for k, v in params.items():
                    if accepts_kwargs or k in allowed:
                        filtered_params[k] = v
                
                # Add GPU params based on XGBoost version (2.0+ uses device, older uses tree_method)
                xgb_gpu_method = globals().get('XGBOOST_GPU_METHOD', 'tree_method')
                if xgb_gpu_method == 'device':
                    # XGBoost 2.0+ API
                    if accepts_kwargs or 'device' in allowed:
                        filtered_params['device'] = 'cuda'
                    # Remove old-style params if present
                    filtered_params.pop('tree_method', None)
                    filtered_params.pop('predictor', None)
                else:
                    # Old XGBoost API
                    if accepts_kwargs or 'tree_method' in allowed:
                        filtered_params['tree_method'] = 'gpu_hist'
                    if accepts_kwargs or 'predictor' in allowed:
                        filtered_params['predictor'] = 'gpu_predictor'
                
                # Remove unsupported keys
                filtered_params.pop('gpu_id', None)
                try:
                    models_eval_ref['XGBoost'] = XGBClass(**filtered_params)
                    print(f"Patched models_eval['XGBoost'] to use GPU (method={xgb_gpu_method}).")
                except TypeError as e:
                    # Fallback: try removing GPU-specific params and re-instantiate
                    for k in ['tree_method', 'predictor', 'device']:
                        filtered_params.pop(k, None)
                    fallback_params = {k: v for k, v in filtered_params.items() if accepts_kwargs or k in allowed}
                    models_eval_ref['XGBoost'] = XGBClass(**fallback_params)
                    print("Patched models_eval['XGBoost'] without GPU params due to TypeError:", e)
            except Exception as e:
                print("Failed to patch models_eval['XGBoost']:", e)
            # Patch Random Forest to use n_jobs=-1 when possible
            if 'Random Forest' in models_eval_ref:
                try:
                    from sklearn.ensemble import RandomForestClassifier
                    m = models_eval_ref['Random Forest']
                    params = m.get_params() if hasattr(m, 'get_params') else {}
                    params.setdefault('n_jobs', -1)
                    RFC = RandomForestClassifier
                    sig_rfc = inspect.signature(RFC.__init__)
                    accepts_kwargs_rfc = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig_rfc.parameters.values())
                    allowed_rfc = set(sig_rfc.parameters.keys())
                    filtered_rfc_params = {k: v for k, v in params.items() if accepts_kwargs_rfc or k in allowed_rfc}
                    models_eval_ref['Random Forest'] = RandomForestClassifier(**filtered_rfc_params)
                    print("Patched models_eval['Random Forest'] to use n_jobs=-1.")
                except Exception as e:
                    print("Failed to patch models_eval['Random Forest']:", e)
    except Exception as e:
        print("Error patching models_eval:", e)

    # Patch param_grids for XGBoost if available
    try:
        if 'param_grids' in globals() and XGBOOST_GPU_AVAILABLE:
            pg = param_grids.get('XGBoost', {})
            xgb_gpu_method = globals().get('XGBOOST_GPU_METHOD', 'tree_method')
            if xgb_gpu_method == 'device':
                # XGBoost 2.0+ uses device parameter
                if any(k.startswith('clf__') for k in pg.keys()):
                    pg.setdefault('clf__device', ['cuda'])
                else:
                    pg.setdefault('device', ['cuda'])
            else:
                # Old XGBoost uses tree_method
                if any(k.startswith('clf__') for k in pg.keys()):
                    pg.setdefault('clf__tree_method', ['gpu_hist'])
                    pg.setdefault('clf__predictor', ['gpu_predictor'])
                else:
                    pg.setdefault('tree_method', ['gpu_hist'])
                    pg.setdefault('predictor', ['gpu_predictor'])
            param_grids['XGBoost'] = pg
            print(f"Patched param_grids['XGBoost'] with GPU options (method={xgb_gpu_method}).")
    except Exception as e:
        print("Error patching param_grids:", e)

# Apply patches now if models/param grids already defined
_apply_gpu_patches()

print(f"TF_GPU_AVAILABLE={TF_GPU_AVAILABLE}, MIXED_PRECISION={MIXED_PRECISION_AVAILABLE}, XGBOOST_GPU_AVAILABLE={XGBOOST_GPU_AVAILABLE}, CUML_AVAILABLE={CUML_AVAILABLE}")
print("If models or param_grids are defined later, call _apply_gpu_patches() to apply GPU settings.")

TensorFlow GPUs detected: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Set memory growth for TensorFlow GPUs.
Enabled mixed precision (mixed_float16).
XGBoost version: 3.1.0
XGBoost GPU support detected (device='cuda' API).
Default param_grids defined early (can be overridden later).
TF_GPU_AVAILABLE=True, MIXED_PRECISION=True, XGBOOST_GPU_AVAILABLE=True, CUML_AVAILABLE=False
If models or param_grids are defined later, call _apply_gpu_patches() to apply GPU settings.


In [22]:
# --- Define a sklearn-compatible XGBoost wrapper (supports both old and new XGBoost APIs) ---
try:
    import xgboost as xgb
    xgb_version = tuple(int(x) for x in xgb.__version__.split('.')[:2])
    
    class XGBClassifierSK(xgb.XGBClassifier):
        """XGBoost wrapper that handles both old (tree_method) and new (device) APIs."""
        def __init__(self, n_estimators=100, learning_rate=0.1, max_depth=6, random_state=None,
                     use_label_encoder=False, eval_metric='logloss',
                     tree_method=None, predictor=None, device=None, **kwargs):
            # Handle XGBoost 2.0+ API vs older versions
            if xgb_version >= (2, 0):
                # New API: use 'device' parameter
                if device is not None:
                    kwargs['device'] = device
                # tree_method and predictor are deprecated in 2.0+
            else:
                # Old API: use tree_method/predictor
                if tree_method is not None:
                    kwargs.setdefault('tree_method', tree_method)
                if predictor is not None:
                    kwargs.setdefault('predictor', predictor)
            
            # Remove deprecated parameters that might cause warnings
            kwargs.pop('use_label_encoder', None)
            
            super().__init__(n_estimators=n_estimators, learning_rate=learning_rate, max_depth=max_depth,
                             random_state=random_state, eval_metric=eval_metric, **kwargs)
    
    globals()['XGBClassifierSK'] = XGBClassifierSK
    print(f'Defined XGBoost sklearn-compatible wrapper: XGBClassifierSK (XGBoost version {xgb.__version__})')
except Exception as e:
    print('Failed to define XGBClassifierSK:', e)


Defined XGBoost sklearn-compatible wrapper: XGBClassifierSK (XGBoost version 3.1.0)


In [23]:
from pathlib import Path
import os

# Determine data directory consistently (prefer existing download_dir when present)
if 'download_dir' in globals() and download_dir:
    data_dir = Path(download_dir)
elif IS_KAGGLE:
    data_dir = Path('/kaggle/working/Data')
else:
    data_dir = Path('../Data')

raw_data_dir = data_dir / 'GSE300475_RAW'
raw_data_dir = raw_data_dir.resolve()

# Ensure directory exists (no-op if not writing yet)
os.makedirs(raw_data_dir, exist_ok=True)
print(f"Using raw_data_dir = {raw_data_dir}")

Using raw_data_dir = /kaggle/working/Data/GSE300475_RAW


In [24]:
# --- Auto-apply GPU patches when LOPO is instantiated ---
try:
    import sklearn.model_selection as _skms
    if not getattr(_skms, '_LO_patched_applied', False):
        _LO_orig = _skms.LeaveOneGroupOut
        class _LO_patched(_LO_orig):
            def __init__(self, *args, **kwargs):
                # Ensure GPU patches are applied just before LOPO is constructed
                try:
                    _apply_gpu_patches()
                except Exception as _e:
                    print('Warning: _apply_gpu_patches failed during LOPO patching:', _e)
                super().__init__(*args, **kwargs)
        _skms.LeaveOneGroupOut = _LO_patched
        _skms._LO_patched_applied = True
        print('Patched sklearn.model_selection.LeaveOneGroupOut to auto-apply GPU patches on init')
    else:
        print('LOPO patch already applied')
except Exception as e:
    print('Failed to apply LOPO patch:', e)

Patched sklearn.model_selection.LeaveOneGroupOut to auto-apply GPU patches on init


In [25]:
# --- Data Loading (Robust) ---
import scanpy as sc
import os
import glob
import pandas as pd
import warnings

# Suppress warnings
warnings.filterwarnings('ignore')

def _first_existing(paths):
    for p in paths:
        if p and os.path.exists(p):
            return p
    return None

def _glob_pick(folder, patterns, key=None):
    matches = []
    for pat in patterns:
        matches.extend(glob.glob(os.path.join(folder, pat)))
    matches = sorted(set(matches))
    if key:
        key_matches = [m for m in matches if key in os.path.basename(m)]
        if len(key_matches) == 1:
            return key_matches[0]
        if len(key_matches) > 1:
            return key_matches[0]
    if len(matches) == 1:
        return matches[0]
    return None

print("Starting Data Loading...")

# Determine data directory (using extract_dir from Cell 7 if available)
if 'extract_dir' not in globals():
    # Fallback path logic matching Cell 7/8
    base_dir = '/kaggle/working/Data' if IS_KAGGLE else '../Data'
    extract_dir = os.path.join(base_dir, "GSE300475_RAW")

if not os.path.exists(extract_dir):
    print(f"Warning: Directory {extract_dir} does not exist. Please ensure Cell 7 ran successfully.")
else:
    print(f"Searching for data in: {extract_dir}")
    # Find all matrix files
    matrix_files = glob.glob(os.path.join(extract_dir, "*matrix.mtx*"))
    # Also look recursively if structure is nested
    if not matrix_files:
        matrix_files = glob.glob(os.path.join(extract_dir, "**", "*matrix.mtx*"), recursive=True)

    adata_list = []
    
    if not matrix_files:
        print("No matrix.mtx files found or previously loaded.")
        # Check if we can proceed? If this is a re-run, adata might exist.
    else:
        for mat_file in matrix_files:
            try:
                print(f"Processing {os.path.basename(mat_file)}...")
                # Handle formatted loading
                # If file is standard 10x-like (matrix.mtx, genes.tsv, barcodes.tsv) in same folder
                folder = os.path.dirname(mat_file)
                prefix = os.path.basename(mat_file).replace('matrix.mtx', '').replace('.gz', '')
                key = prefix.strip('_')
                
                # Check for accompanying files with same prefix
                genes_path = _first_existing([
                    os.path.join(folder, prefix + 'genes.tsv'),
                    os.path.join(folder, prefix + 'features.tsv'),
                    os.path.join(folder, prefix + 'genes.tsv.gz'),
                    os.path.join(folder, prefix + 'features.tsv.gz'),
                ])
                
                barcodes_path = _first_existing([
                    os.path.join(folder, prefix + 'barcodes.tsv'),
                    os.path.join(folder, prefix + 'barcodes.tsv.gz'),
                ])

                # Fallback to un-prefixed standard 10x naming
                if not genes_path:
                    genes_path = _first_existing([
                        os.path.join(folder, 'genes.tsv'),
                        os.path.join(folder, 'features.tsv'),
                        os.path.join(folder, 'genes.tsv.gz'),
                        os.path.join(folder, 'features.tsv.gz'),
                    ])

                if not barcodes_path:
                    barcodes_path = _first_existing([
                        os.path.join(folder, 'barcodes.tsv'),
                        os.path.join(folder, 'barcodes.tsv.gz'),
                    ])

                # Fallback to any matching files in the folder (use key if present)
                if not genes_path:
                    genes_path = _glob_pick(folder, ['*genes.tsv*', '*features.tsv*'], key=key)
                if not barcodes_path:
                    barcodes_path = _glob_pick(folder, ['*barcodes.tsv*'], key=key)

                if genes_path and barcodes_path and os.path.exists(genes_path) and os.path.exists(barcodes_path):
                    # Load using read_mtx for flexibility with filenames
                    adata_sample = sc.read_mtx(mat_file).T
                    
                    # Annotation
                    genes = pd.read_csv(genes_path, sep='\t', header=None)
                    barcodes = pd.read_csv(barcodes_path, sep='\t', header=None)
                    
                    # Assign var/obs names and sanitize whitespace
                    if genes.shape[1] > 1:
                        var_names = genes.iloc[:,1].astype(str).str.strip().values
                        adata_sample.var['gene_ids'] = genes.iloc[:,0].astype(str).values
                    else:
                        var_names = genes.iloc[:,0].astype(str).str.strip().values
                    adata_sample.var_names = pd.Index(var_names)
                    adata_sample.obs_names = pd.Index(barcodes.iloc[:,0].astype(str).str.strip().values)
                    adata_sample.obs['sample_id'] = prefix.strip('_') if prefix else os.path.basename(folder)
                    
                    # Ensure uniqueness within sample to avoid concat Index errors
                    try:
                        adata_sample.var_names_make_unique()
                        adata_sample.obs_names_make_unique()
                    except Exception:
                        pass
                    
                    adata_list.append(adata_sample)
                    print(f"Loaded {adata_sample.shape[0]} cells from {prefix or folder}")
                else:
                    print(f"Skipping {mat_file}: Missing genes/barcodes files (searched prefix '{prefix}' and fallbacks)")
            except Exception as e:
                print(f"Error loading {mat_file}: {e}")

        # Pre-sanitize all adata samples before concatenation
        for a in adata_list:
            try:
                a.var_names = pd.Index([str(v).strip() for v in a.var_names])
                a.var_names_make_unique()
                a.obs_names = pd.Index([str(v).strip() for v in a.obs_names])
                a.obs_names_make_unique()
            except Exception:
                pass

        if adata_list:
            # Concatenate all samples
            try:
                adata = sc.concat(adata_list, join='outer')
            except Exception as e:
                print('sc.concat failed:', e)
                # Try fallback using AnnData.concatenate with batch info
                try:
                    loaded_batches = [a.obs['sample_id'].unique()[0] for a in adata_list]
                except Exception:
                    loaded_batches = None
                try:
                    if loaded_batches:
                        adata = sc.AnnData.concatenate(*adata_list, join='outer', batch_key='sample_id', batch_categories=loaded_batches)
                    else:
                        adata = sc.AnnData.concatenate(*adata_list, join='outer', batch_key='sample_id')
                except Exception as e2:
                    raise RuntimeError(f"Failed to concatenate AnnData objects: {e}; fallback failed: {e2}")
            adata.obs_names_make_unique()
            # Basic fallback for mitochondrial genes logic (used later)
            adata.var['mt'] = adata.var_names.str.startswith('MT-')
            sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
            print(f"Combined AnnData object created: {adata.shape}")
        else:
            print("Warning: No valid data loaded into adata.")

# Ensure adata exists to prevent downstream crashes
if 'adata' not in globals():
    print("CRITICAL CHECK: adata variable not defined. Downstream cells will fail.")


Starting Data Loading...
Searching for data in: /kaggle/working/Data/GSE300475_RAW
Processing GSM9061673_S9_matrix.mtx.gz...
Loaded 11480 cells from GSM9061673_S9_
Processing GSM9061669_S5_matrix.mtx.gz...
Loaded 2912 cells from GSM9061669_S5_
Processing GSM9061674_S10_matrix.mtx.gz...
Loaded 9704 cells from GSM9061674_S10_
Processing GSM9061671_S7_matrix.mtx.gz...
Loaded 9330 cells from GSM9061671_S7_
Processing GSM9061670_S6_matrix.mtx.gz...
Loaded 10398 cells from GSM9061670_S6_
Processing GSM9061668_S4_matrix.mtx.gz...
Loaded 8723 cells from GSM9061668_S4_
Processing GSM9061675_S11_matrix.mtx.gz...
Loaded 9330 cells from GSM9061675_S11_
Processing GSM9061667_S3_matrix.mtx.gz...
Loaded 7358 cells from GSM9061667_S3_
Processing GSM9061666_S2_matrix.mtx.gz...
Loaded 9069 cells from GSM9061666_S2_
Processing GSM9061665_S1_matrix.mtx.gz...
Loaded 8931 cells from GSM9061665_S1_
Processing GSM9061672_S8_matrix.mtx.gz...
Loaded 12832 cells from GSM9061672_S8_
Combined AnnData object create

## 6. Genetic Sequence Encoding Functions

Define functions for one-hot encoding, k-mer encoding, and physicochemical features extraction for TCR sequences and gene expression patterns.

In [26]:
%%time
# --- Genetic Sequence Encoding Functions ---

def one_hot_encode_sequence(sequence, max_length=50, alphabet='ACDEFGHIKLMNPQRSTVWY'):
    """
    One-hot encode a protein/nucleotide sequence.
    Args:
        sequence: String sequence to encode
        max_length: Maximum sequence length (pad or truncate)
        alphabet: Valid characters in the sequence
    Returns:
        2D numpy array of shape (max_length, len(alphabet))
    """
    if pd.isna(sequence) or sequence == 'NA' or sequence == '':
        return np.zeros((max_length, len(alphabet)))
    
    sequence = str(sequence).upper()[:max_length]  # Truncate if too long
    encoding = np.zeros((max_length, len(alphabet)))
    
    for i, char in enumerate(sequence):
        if char in alphabet:
            char_idx = alphabet.index(char)
            encoding[i, char_idx] = 1
    
    return encoding

def kmer_encode_sequence(sequence, k=3, alphabet='ACDEFGHIKLMNPQRSTVWY'):
    """
    K-mer encoding of sequences.
    """
    if pd.isna(sequence) or sequence == 'NA' or sequence == '':
        return {}
    
    sequence = str(sequence).upper()
    kmers = [sequence[i:i+k] for i in range(len(sequence)-k+1)]
    valid_kmers = [kmer for kmer in kmers if all(c in alphabet for c in kmer)]
    
    return Counter(valid_kmers)

def physicochemical_features(sequence):
    """
    Extract physicochemical properties from protein sequences.
    """
    if pd.isna(sequence) or sequence == 'NA' or sequence == '':
        return {
            'length': 0, 'molecular_weight': 0, 'aromaticity': 0,
            'instability_index': 0, 'isoelectric_point': 0, 'hydrophobicity': 0
        }
    
    try:
        seq = str(sequence).upper()
        # Remove non-standard amino acids
        seq = ''.join([c for c in seq if c in 'ACDEFGHIKLMNPQRSTVWY'])
        
        if len(seq) == 0:
            return {
                'length': 0, 'molecular_weight': 0, 'aromaticity': 0,
                'instability_index': 0, 'isoelectric_point': 0, 'hydrophobicity': 0
            }
        
        bio_seq = Seq(seq)
        analyzer = ProtParam.ProteinAnalysis(str(bio_seq))
        
        return {
            'length': len(seq),
            'molecular_weight': analyzer.molecular_weight(),
            'aromaticity': analyzer.aromaticity(),
            'instability_index': analyzer.instability_index(),
            'isoelectric_point': analyzer.isoelectric_point(),
            'hydrophobicity': analyzer.gravy()
        }
    except:
        return {
            'length': len(str(sequence)) if not pd.isna(sequence) else 0,
            'molecular_weight': 0, 'aromaticity': 0,
            'instability_index': 0, 'isoelectric_point': 0, 'hydrophobicity': 0
        }

def encode_gene_expression_patterns(adata, n_top_genes=1000, train_mask=None):
    """
    Encode gene expression patterns using various dimensionality reduction techniques.
    
    IMPORTANT: To avoid data leakage, pass train_mask to fit transformers only on training data.
    If train_mask is None, fits on all data (use only for exploration, not CV).
    
    Returns:
        tuple: (encodings dict, X_scaled array)
    """
    import numpy as np
    from sklearn.preprocessing import StandardScaler
    from sklearn.decomposition import PCA, TruncatedSVD
    import umap
    
    # Get highly variable genes if not already computed
    if 'highly_variable' not in adata.var.columns:
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, subset=False)
    
    # Extract expression matrix for highly variable genes - ensure boolean mask
    hvg_mask = np.array(adata.var['highly_variable'].values, dtype=bool)
    
    # Subset adata by HVG mask
    X_full = adata.X
    if hasattr(X_full, 'toarray'):
        X_full = X_full.toarray()
    else:
        X_full = np.asarray(X_full)
    
    # Select HVG columns
    X_hvg = X_full[:, hvg_mask]
    
    # Clean Infs/NaNs (robustness fix)
    X_hvg = np.nan_to_num(X_hvg, nan=0.0, posinf=0.0, neginf=0.0)

    # Standardize the data - FIT ONLY ON TRAINING DATA if mask provided
    scaler = StandardScaler()
    if train_mask is not None:
        scaler.fit(X_hvg[train_mask])
        X_scaled = scaler.transform(X_hvg)
    else:
        X_scaled = scaler.fit_transform(X_hvg)
    
    encodings = {}
    
    # PCA encoding - FIT ONLY ON TRAINING DATA if mask provided
    n_pca = min(50, X_scaled.shape[1], X_scaled.shape[0])
    pca = PCA(n_components=n_pca)
    if train_mask is not None:
        pca.fit(X_scaled[train_mask])
        encodings['pca'] = pca.transform(X_scaled)
    else:
        encodings['pca'] = pca.fit_transform(X_scaled)
    
    # TruncatedSVD for sparse matrices
    n_svd = min(50, X_scaled.shape[1] - 1, X_scaled.shape[0] - 1)
    if n_svd > 0:
        svd = TruncatedSVD(n_components=n_svd, random_state=42)
        if train_mask is not None:
            svd.fit(X_scaled[train_mask])
            encodings['svd'] = svd.transform(X_scaled)
        else:
            encodings['svd'] = svd.fit_transform(X_scaled)
    else:
        encodings['svd'] = np.zeros((X_scaled.shape[0], 1))
    
    # UMAP encoding - UMAP doesn't support clean fit/transform easily for this pipeline, usually unsupervised
    try:
        umap_encoder = umap.UMAP(n_components=20, random_state=42)
        encodings['umap'] = umap_encoder.fit_transform(X_scaled)
    except Exception as e:
        print(f"UMAP failed: {e}")
        encodings['umap'] = np.zeros((X_scaled.shape[0], 20))
    
    return encodings, X_scaled

print("Genetic sequence encoding functions defined successfully!")

Genetic sequence encoding functions defined successfully!
CPU times: user 71 ¬µs, sys: 4 ¬µs, total: 75 ¬µs
Wall time: 78.4 ¬µs


## 7. Apply Sequence Encoding to TCR CDR3 Sequences

Encode TRA and TRB CDR3 sequences using one-hot, k-mer, and physicochemical methods, and add to AnnData.obsm and obs.

In [27]:
%%time
# --- MEMORY-OPTIMIZED TCR Sequence Encoding ---

# Validate that adata exists
if 'adata' not in globals():
    raise NameError("adata is not defined. Please run the data loading and QC cells first.")
if adata is None:
    raise ValueError("adata is None. Data loading may have failed.")

print("Starting memory-optimized TCR sequence encoding...")
import gc

# Extract TCR CDR3 sequences with robust column handling
# Check for column existence and normalize naming
if 'cdr3_TRA' not in adata.obs.columns:
    if 'CDR3_TRA' in adata.obs.columns:
        adata.obs['cdr3_TRA'] = adata.obs['CDR3_TRA']
    else:
        print("Warning: cdr3_TRA column not found, creating empty column")
        adata.obs['cdr3_TRA'] = ''

if 'cdr3_TRB' not in adata.obs.columns:
    if 'CDR3_TRB' in adata.obs.columns:
        adata.obs['cdr3_TRB'] = adata.obs['CDR3_TRB']
    else:
        print("Warning: cdr3_TRB column not found, creating empty column")
        adata.obs['cdr3_TRB'] = ''

tra_seqs = adata.obs['cdr3_TRA'].fillna('').astype(str).values
trb_seqs = adata.obs['cdr3_TRB'].fillna('').astype(str).values

print(f"TRA sequences: {len(tra_seqs)}, TRB sequences: {len(trb_seqs)}")

# MEMORY FIX: Use smaller k-mer sizes and reduced dimensionality
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import TruncatedSVD

# Reduce k-mer size from 3 to 2 to reduce feature space
def _kmer_list(seq, k=2):  # Changed from k=3 to k=2
    if len(seq) < k:
        return []
    return [seq[i:i+k] for i in range(len(seq) - k + 1)]

# Convert sequences to k-mer strings
tra_kmer_docs = [' '.join(_kmer_list(s, k=2)) for s in tra_seqs]  # k=2 instead of 3
trb_kmer_docs = [' '.join(_kmer_list(s, k=2)) for s in trb_seqs]

# MEMORY FIX: Limit max features to reduce dimensionality
vec_tra = CountVectorizer(max_features=500)  # Limit to 500 features instead of unlimited
vec_trb = CountVectorizer(max_features=500)

tra_kmer_sparse = vec_tra.fit_transform(tra_kmer_docs)
trb_kmer_sparse = vec_trb.fit_transform(trb_kmer_docs)

# Clean up k-mer docs (no longer needed)
del tra_kmer_docs, trb_kmer_docs
gc.collect()

print(f"TRA k-mer sparse shape: {tra_kmer_sparse.shape}")
print(f"TRB k-mer sparse shape: {trb_kmer_sparse.shape}")

# MEMORY FIX: Reduce dimensions even further using SVD
def _reduce_sparse(sparse_mat, n_components=50):  # Reduced from 200 to 50
    n_comp = min(n_components, max(1, sparse_mat.shape[1]-1))
    try:
        svd = TruncatedSVD(n_components=n_comp, random_state=42)
        return svd.fit_transform(sparse_mat).astype(np.float32)  # Use float32
    except Exception:
        return sparse_mat.toarray().astype(np.float32) if hasattr(sparse_mat, 'toarray') else np.asarray(sparse_mat, dtype=np.float32)

tra_kmer_matrix = _reduce_sparse(tra_kmer_sparse, n_components=50)  # Reduced from 200
trb_kmer_matrix = _reduce_sparse(trb_kmer_sparse, n_components=50)
print(f"TRA k-mer reduced shape: {tra_kmer_matrix.shape}")
print(f"TRB k-mer reduced shape: {trb_kmer_matrix.shape}")

# Clean up sparse matrices
del tra_kmer_sparse, trb_kmer_sparse
gc.collect()

# MEMORY FIX: Reduced one-hot encoding with smaller max length
max_cdr3_length = 15  # Reduced from 20 to 15
alphabet = 'ACDEFGHIKLMNPQRSTVWY'
char_to_idx = {c:i for i,c in enumerate(alphabet)}

def _one_hot_encode_batch(sequences, max_len=max_cdr3_length):
    """Batch one-hot encoding using NumPy for memory efficiency."""
    n_seqs = len(sequences)
    encoding = np.zeros((n_seqs, max_len, len(alphabet)), dtype=np.float16)  # Use float16 instead of float32
    
    for i, seq in enumerate(sequences):
        seq_str = str(seq).upper()[:max_len]  # Truncate
        for j, char in enumerate(seq_str):
            if char in char_to_idx:
                encoding[i, j, char_to_idx[char]] = 1.0
    
    return encoding

tra_one_hot = _one_hot_encode_batch(tra_seqs, max_cdr3_length)
trb_one_hot = _one_hot_encode_batch(trb_seqs, max_cdr3_length)
print(f"TRA one-hot shape: {tra_one_hot.shape} (dtype: {tra_one_hot.dtype})")
print(f"TRB one-hot shape: {trb_one_hot.shape} (dtype: {trb_one_hot.dtype})")

# Clean up sequence arrays
del tra_seqs, trb_seqs
gc.collect()

# MEMORY FIX: Store matrices in obsm (compressed format in AnnData)
adata.obsm['X_tcr_tra_kmer'] = tra_kmer_matrix
adata.obsm['X_tcr_trb_kmer'] = trb_kmer_matrix

# MEMORY FIX: Flatten and store one-hot as float32 for compatibility
adata.obsm['X_tcr_tra_onehot'] = tra_one_hot.reshape(tra_one_hot.shape[0], -1).astype(np.float32)
adata.obsm['X_tcr_trb_onehot'] = trb_one_hot.reshape(trb_one_hot.shape[0], -1).astype(np.float32)

del tra_one_hot, trb_one_hot, tra_kmer_matrix, trb_kmer_matrix
gc.collect()

print("TCR sequence encoding complete and stored in adata.obsm")
print(f"Memory usage reduced by using sparse matrices and dimension reduction")

Starting memory-optimized TCR sequence encoding...
TRA sequences: 100067, TRB sequences: 100067
TRA k-mer sparse shape: (100067, 1)
TRB k-mer sparse shape: (100067, 1)
TRA k-mer reduced shape: (100067, 1)
TRB k-mer reduced shape: (100067, 1)
TRA one-hot shape: (100067, 15, 20) (dtype: float16)
TRB one-hot shape: (100067, 15, 20) (dtype: float16)
TCR sequence encoding complete and stored in adata.obsm
Memory usage reduced by using sparse matrices and dimension reduction
CPU times: user 2.34 s, sys: 58.9 ms, total: 2.4 s
Wall time: 2.39 s


## 8. Encode Gene Expression Patterns

Apply PCA, SVD, and UMAP to gene expression data for dimensionality reduction and add encodings to AnnData.

## Feature Engineering and Encoding
A core contribution of this work is the engineering of a comprehensive feature set that translates biological sequences into machine-readable vectors. We developed three distinct encoding schemes for the TCR CDR3 amino acid sequences:

1.  **One-Hot Encoding:** This method creates a sparse binary matrix representing the presence or absence of specific amino acids at each position in the sequence. It preserves exact positional information, which is crucial for structural motifs, but results in high-dimensional, sparse vectors.
2.  **K-mer Frequency Encoding:** We decomposed sequences into overlapping substrings of length $k$ (k-mers, with $k=3$). We then calculated the frequency of each unique 3-mer in the sequence. This approach captures short, local structural motifs (e.g., "CAS", "ASS") that may be shared across different TCRs with similar antigen specificity, regardless of their exact position.
3.  **Physicochemical Property Encoding:** To capture the biophysical nature of the TCR-antigen interaction, we mapped each amino acid to a vector of physicochemical properties, including hydrophobicity, molecular weight, isoelectric point, and polarity. We then aggregated these values (e.g., mean, sum) across the CDR3 sequence. This results in a dense, low-dimensional representation that reflects the "binding potential" of the receptor.

These TCR features were concatenated with the top 50 Principal Components (PCs) derived from the gene expression data to form the "Comprehensive" feature set.

In [28]:
%%time
# --- Apply Sequence Encoding to TCR CDR3 Sequences (vectorized k-mer + reduced one-hot) ---

print("Encoding TCR CDR3 sequences (vectorized k-mer + reduced one-hot)...")

# Extract and clean CDR3 sequences
if 'cdr3_TRA' in adata.obs.columns:
    cdr3_TRA = adata.obs['cdr3_TRA'].astype(str).fillna('').str.upper()
else:
    cdr3_TRA = pd.Series([''] * adata.n_obs, index=adata.obs.index)
if 'cdr3_TRB' in adata.obs.columns:
    cdr3_TRB = adata.obs['cdr3_TRB'].astype(str).fillna('').str.upper()
else:
    cdr3_TRB = pd.Series([''] * adata.n_obs, index=adata.obs.index)

valid_aa = 'ACDEFGHIKLMNPQRSTVWY'
def _clean_seq(s):
    return ''.join([c for c in str(s) if c in valid_aa])

tra_seqs = [_clean_seq(s) for s in cdr3_TRA]
trb_seqs = [_clean_seq(s) for s in cdr3_TRB]

# --- Vectorized k-mer encoding using CountVectorizer (sparse) ---
from sklearn.feature_extraction.text import CountVectorizer
k = 3
vec_tra = CountVectorizer(analyzer='char', ngram_range=(k,k))
vec_trb = CountVectorizer(analyzer='char', ngram_range=(k,k))
tra_kmer_sparse = vec_tra.fit_transform(tra_seqs)
trb_kmer_sparse = vec_trb.fit_transform(trb_seqs)
print(f"TRA k-mer sparse shape: {tra_kmer_sparse.shape}")
print(f"TRB k-mer sparse shape: {trb_kmer_sparse.shape}")

# Reduce k-mer sparse matrices with TruncatedSVD to a dense reduced representation (keeps memory low)
from sklearn.decomposition import TruncatedSVD
def _reduce_sparse(sparse_mat, n_components=200):
    n_comp = min(n_components, max(1, sparse_mat.shape[1]-1))
    try:
        svd = TruncatedSVD(n_components=n_comp, random_state=42)
        return svd.fit_transform(sparse_mat)
    except Exception:
        # Fallback to dense (small datasets)
        return sparse_mat.toarray() if hasattr(sparse_mat, 'toarray') else np.asarray(sparse_mat)

tra_kmer_matrix = _reduce_sparse(tra_kmer_sparse, n_components=200)
trb_kmer_matrix = _reduce_sparse(trb_kmer_sparse, n_components=200)
print(f"TRA k-mer reduced shape: {tra_kmer_matrix.shape}")
print(f"TRB k-mer reduced shape: {trb_kmer_matrix.shape}")

# --- Reduced one-hot encoding: limit max length to avoid huge dense matrices ---
max_cdr3_length = 20  # smaller to reduce dimensionality and memory
alphabet = 'ACDEFGHIKLMNPQRSTVWY'
char_to_idx = {c:i for i,c in enumerate(alphabet)}
def _onehot_flat_list(seqs, max_length, alphabet, char_to_idx):
    out = np.zeros((len(seqs), max_length * len(alphabet)), dtype=np.uint8)
    for i, s in enumerate(seqs):
        for j, ch in enumerate(s[:max_length]):
            if ch in char_to_idx:
                out[i, j * len(alphabet) + char_to_idx[ch]] = 1
    return out

tra_onehot_flat = _onehot_flat_list(tra_seqs, max_cdr3_length, alphabet, char_to_idx)
trb_onehot_flat = _onehot_flat_list(trb_seqs, max_cdr3_length, alphabet, char_to_idx)
print(f"TRA one-hot flat shape: {tra_onehot_flat.shape}")
print(f"TRB one-hot flat shape: {trb_onehot_flat.shape}")

# --- Physicochemical properties (unchanged) ---
tra_physico = pd.DataFrame([physicochemical_features(seq) for seq in tra_seqs])
trb_physico = pd.DataFrame([physicochemical_features(seq) for seq in trb_seqs])
print(f"TRA physicochemical features shape: {tra_physico.shape}")
print(f"TRB physicochemical features shape: {trb_physico.shape}")

# Add to AnnData object (reduced, memory-friendly)
adata.obsm['X_tcr_tra_onehot'] = tra_onehot_flat
adata.obsm['X_tcr_trb_onehot'] = trb_onehot_flat
adata.obsm['X_tcr_tra_kmer'] = tra_kmer_matrix
adata.obsm['X_tcr_trb_kmer'] = trb_kmer_matrix

# Add physicochemical features to obs
for col in tra_physico.columns:
    adata.obs[f'tra_{col}'] = tra_physico[col].values
for col in trb_physico.columns:
    adata.obs[f'trb_{col}'] = trb_physico[col].values

print("TCR sequence encoding completed and added to AnnData object!")

# Clean up large temporary objects
import gc
try:
    # delete sparse intermediates and local copies ‚Äî AnnData already stores the reduced matrices
    del tra_kmer_sparse, trb_kmer_sparse
except Exception:
    pass
try:
    # delete other large temporaries that have been copied into `adata.obsm` or `adata.obs`
    for _n in ['tra_kmer_matrix', 'trb_kmer_matrix', 'tra_onehot_flat', 'trb_onehot_flat', 'tra_physico', 'trb_physico', 'tra_seqs', 'trb_seqs', 'vec_tra', 'vec_trb', 'char_to_idx']:
        if _n in globals():
            try:
                del globals()[_n]
            except Exception:
                pass
except Exception:
    pass
gc.collect()

Encoding TCR CDR3 sequences (vectorized k-mer + reduced one-hot)...
TRA k-mer sparse shape: (100067, 1)
TRB k-mer sparse shape: (100067, 1)
TRA k-mer reduced shape: (100067, 1)
TRB k-mer reduced shape: (100067, 1)
TRA one-hot flat shape: (100067, 400)
TRB one-hot flat shape: (100067, 400)
TRA physicochemical features shape: (100067, 6)
TRB physicochemical features shape: (100067, 6)
TCR sequence encoding completed and added to AnnData object!
CPU times: user 1.39 s, sys: 31.9 ms, total: 1.42 s
Wall time: 1.42 s


0

## 9. Create Combined Multi-Modal Encodings

Combine gene expression and TCR encodings into multi-modal representations using PCA and UMAP.

In [29]:
%%time
# MEMORY-OPTIMIZED encode_gene_expression_patterns function using Scanpy's Native Optimized PCA
def encode_gene_expression_patterns(adata, n_top_genes=1500, train_mask=None):
    """
    Encode gene expression patterns using Scanpy's optimized sparse PCA (Arpack)
    and TruncatedSVD. Avoids manual chunking complexity which can be error prone.
    
    Returns:
        tuple: (encodings dict, X_scaled placeholder)
    """
    import numpy as np
    import gc
    from sklearn.decomposition import TruncatedSVD
    import umap
    
    gc.collect()
    
    # 1. HVG Selection (Memory Optimized: Subsample cells if huge)
    # Calculating mean/var on 100k cells x 30k genes can overlap memory.
    # We calculate on a subset of 20k cells to estimate HVGs.
    print("Selecting Highly Variable Genes...")
    
    if adata.n_obs > 20000 and 'highly_variable' not in adata.var.columns:
        # Subsample for HVG calculation only
        idx = np.random.choice(adata.n_obs, 20000, replace=False)
        temp_adata = adata[idx].copy()
        sc.pp.highly_variable_genes(temp_adata, n_top_genes=n_top_genes, subset=False, flavor='seurat')
        # Transfer results back
        adata.var['highly_variable'] = False
        adata.var.loc[temp_adata.var_names, 'highly_variable'] = temp_adata.var['highly_variable']
        del temp_adata
        gc.collect()
    elif 'highly_variable' not in adata.var.columns:
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, subset=False, flavor='seurat')

    # Get HVG Subset (Sparse View or Copy)
    # Scanpy handles views efficiently for PCA
    adata_hvg = adata[:, adata.var['highly_variable']]
    print(f"HVG subset shape: {adata_hvg.shape}")

    # 2. PCA (Scanpy Arpack = Sparse SVD on centered data implicitly)
    # This is much more stable than manual IncrementalPCA on disjoint chunks
    print("Computing PCA (Arpack - Sparse)...")
    sc.pp.pca(adata_hvg, n_comps=30, svd_solver='arpack', zero_center=True, use_highly_variable=False)
    X_pca = adata_hvg.obsm['X_pca']
    
    # 3. TruncatedSVD (LSA - No centering)
    # Good for sparse data comparison
    print("Computing TruncatedSVD...")
    # Use the sparse matrix from the view
    svd = TruncatedSVD(n_components=30, random_state=42, algorithm='randomized')
    X_svd = svd.fit_transform(adata_hvg.X)
    
    # 4. UMAP on PCA (Standard practice)
    print("Computing UMAP on PCA embeddings...")
    umap_reducer = umap.UMAP(
        n_components=10, 
        n_neighbors=15, 
        random_state=42, 
        n_jobs=1, 
        low_memory=True
    )
    X_umap = umap_reducer.fit_transform(X_pca)

    encodings = {
        'pca': X_pca.astype(np.float32),
        'svd': X_svd.astype(np.float32),
        'umap': X_umap.astype(np.float32)
    }
    
    # Clean up
    del adata_hvg
    gc.collect()
    
    # Return nothing for X_scaled (deprecated)
    return encodings, None

# --- Main Preprocessing Block ---

print("Preprocessing gene expression data...")
import scipy.sparse as sp

# 1. Ensure float32 sparse (Crucial for memory)
if not hasattr(adata.X, 'toarray'): # is sparse
    if adata.X.dtype != np.float32:
        print("Converting sparse matrix to float32...")
        adata.X = adata.X.astype(np.float32)
else: # is dense (shouldn't be, but just in case)
    print("Warning: Data is dense. converting to sparse float32...")
    adata.X = sp.csr_matrix(adata.X, dtype=np.float32)

gc.collect()

# 2. Normalize & Log (In-place)
print("Normalizing...")
sc.pp.normalize_total(adata, target_sum=1e4)
print("Log transforming...")
sc.pp.log1p(adata)
gc.collect()

# 3. Clean Infs/NaNs (In-place, memory safe)
if hasattr(adata.X, 'data'):
    mask = np.isinf(adata.X.data)
    if mask.any():
        adata.X.data[mask] = 0
        print(f"Fixed {mask.sum()} infinite values")
    mask = np.isnan(adata.X.data)
    if mask.any():
        adata.X.data[mask] = 0
        print(f"Fixed {mask.sum()} NaN values")

print("Encoding patterns...")

# Apply encoding
try:
    # Use reduced gene set (1500)
    result = encode_gene_expression_patterns(adata, n_top_genes=1500)
    gene_encodings = result[0]

    # Add to AnnData
    for key, val in gene_encodings.items():
        adata.obsm[f'X_gene_{key}'] = val
        print(f"  Added X_gene_{key}")

    del gene_encodings
    gc.collect()
    print("Gene expression encoding completed!")
    
except Exception as e:
    print(f"Error during encoding: {e}")
    import traceback
    traceback.print_exc()
    raise e

Preprocessing gene expression data...
Normalizing...
Log transforming...
Encoding patterns...
Selecting Highly Variable Genes...
HVG subset shape: (100067, 1500)
Computing PCA (Arpack - Sparse)...
Computing TruncatedSVD...
Computing UMAP on PCA embeddings...
  Added X_gene_pca
  Added X_gene_svd
  Added X_gene_umap
Gene expression encoding completed!
CPU times: user 3min 54s, sys: 2.08 s, total: 3min 56s
Wall time: 3min 22s


In [30]:
%%time
# --- Create Combined Multi-Modal Encodings ---
print("Creating combined multi-modal encodings...")
import gc
from scipy import sparse
from sklearn.decomposition import TruncatedSVD

# Combine different encoding modalities
# 1. Gene expression PCA + TCR physicochemical features
gene_pca = None
# Retrieve pre-computed PCA from gene_encodings dict
pca_data = gene_encodings.get('pca', None) if 'gene_encodings' in globals() and isinstance(gene_encodings, dict) else None

if pca_data is not None and isinstance(pca_data, (np.ndarray, list)):
    gene_pca = np.asarray(pca_data)
elif 'X_gene_pca' in adata.obsm:
    gene_pca = adata.obsm['X_gene_pca']

# Fallback or pad
if gene_pca is not None:
    if gene_pca.ndim == 1:
        gene_pca = gene_pca.reshape(-1, 1)
    if gene_pca.shape[1] >= 20:
        gene_pca = gene_pca[:, :20]
    else:
        # Pad to 20 components
        pad_cols = 20 - gene_pca.shape[1]
        gene_pca = np.pad(gene_pca, ((0, 0), (0, pad_cols)), mode='constant')
else:
    print("Warning: PCA data not available; using zeros.")
    gene_pca = np.zeros((adata.n_obs, 20))

tcr_features = np.column_stack([
    adata.obs[['tra_length', 'tra_molecular_weight', 'tra_hydrophobicity']].fillna(0).values,
    adata.obs[['trb_length', 'trb_molecular_weight', 'trb_hydrophobicity']].fillna(0).values
])

combined_gene_tcr = np.column_stack([gene_pca, tcr_features])
adata.obsm['X_combined_gene_tcr'] = combined_gene_tcr.astype(np.float32)

# 2. Gene expression UMAP + TCR k-mer features (reduced)
gene_umap = gene_encodings.get('umap', None) if 'gene_encodings' in globals() and isinstance(gene_encodings, dict) else None
if gene_umap is None and 'X_gene_umap' in adata.obsm:
    gene_umap = adata.obsm['X_gene_umap']
if gene_umap is None:
    gene_umap = np.zeros((adata.n_obs, 2))

# Stack TRA and TRB k-mer matrices EFFICIENTLY
tra_kmer = adata.obsm.get('X_tcr_tra_kmer', None)
trb_kmer = adata.obsm.get('X_tcr_trb_kmer', None)

if tra_kmer is not None and trb_kmer is not None:
    if sparse.issparse(tra_kmer) or sparse.issparse(trb_kmer):
        # Ensure both are sparse before stacking to avoid densification
        if not sparse.issparse(tra_kmer): tra_kmer = sparse.csr_matrix(tra_kmer)
        if not sparse.issparse(trb_kmer): trb_kmer = sparse.csr_matrix(trb_kmer)
        tcr_kmer_combined = sparse.hstack([tra_kmer, trb_kmer])
    else:
        tcr_kmer_combined = np.column_stack([tra_kmer, trb_kmer])
else:
    tcr_kmer_combined = np.zeros((adata.n_obs, 1)) # Dummy

# Robust dimensional reduction for k-mer features using TruncatedSVD (Safe for OOM)
print(f"Reducing k-mer features: {tcr_kmer_combined.shape}")
n_comp_kmer = min(10, tcr_kmer_combined.shape[1] - 1)
if n_comp_kmer > 0:
    tcr_svd = TruncatedSVD(n_components=n_comp_kmer, random_state=42, algorithm='randomized')
    tcr_kmer_reduced = tcr_svd.fit_transform(tcr_kmer_combined)
else:
    tcr_kmer_reduced = np.zeros((adata.n_obs, 10))

combined_gene_tcr_kmer = np.column_stack([gene_umap, tcr_kmer_reduced])
adata.obsm['X_combined_gene_tcr_kmer'] = combined_gene_tcr_kmer.astype(np.float32)

print(f"Combined gene-TCR encoding shape: {combined_gene_tcr.shape}")
print(f"Combined gene-TCR k-mer encoding shape: {combined_gene_tcr_kmer.shape}")

# Clear memory
del tcr_kmer_combined
gc.collect()

# --- Dimensionality Reduction on Combined Data ---
print("Computing dimensionality reduction on combined data (UMAP only)...")

# UMAP on combined data
umap_combined = umap.UMAP(n_components=2, random_state=42, n_jobs=1) # Single Job for RAM safety
adata.obsm['X_umap_combined'] = umap_combined.fit_transform(combined_gene_tcr)

Creating combined multi-modal encodings...
Reducing k-mer features: (100067, 2)
Combined gene-TCR encoding shape: (100067, 26)
Combined gene-TCR k-mer encoding shape: (100067, 11)
Computing dimensionality reduction on combined data (UMAP only)...
CPU times: user 2min 51s, sys: 390 ms, total: 2min 52s
Wall time: 1min 54s


## 10. Unsupervised Machine Learning Analysis with Hierarchical Clustering

Before training predictive classifiers, we utilized unsupervised learning to define the intrinsic structure of the immune landscape. We compared several clustering algorithms:
*   **K-Means Clustering:** Partitions data into $k$ distinct clusters by minimizing within-cluster variance.
*   **DBSCAN (Density-Based Spatial Clustering of Applications with Noise):** Groups points that are closely packed together, marking points in low-density regions as outliers.
*   **Agglomerative Hierarchical Clustering:** Builds a hierarchy of clusters using a bottom-up approach.

We evaluated these methods using Silhouette Analysis to measure cluster cohesion and separation. The optimal number of clusters ($k$) for K-Means was determined using the Elbow Method.

In [31]:
# HDBSCAN/sklearn compatibility patch ‚Äî run before clustering
import sys, subprocess, inspect

# Ensure hdbscan is available (not strictly necessary if already installed earlier)
try:
    import hdbscan
except Exception:
    print("hdbscan not installed ‚Äî installing now...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "hdbscan"]) 
    import importlib
    importlib.invalidate_caches()
    import hdbscan

# Patch the check_array reference used inside hdbscan to accept the older keyword
try:
    import sklearn.utils.validation as sk_validation
    from hdbscan import hdbscan_ as _hdbscan_mod
    sig = inspect.signature(sk_validation.check_array)
    if 'ensure_all_finite' in sig.parameters and 'force_all_finite' not in sig.parameters:
        orig = getattr(_hdbscan_mod, 'check_array', None) or sk_validation.check_array
        def _patched_check_array(*args, **kwargs):
            if 'force_all_finite' in kwargs and 'ensure_all_finite' not in kwargs:
                kwargs['ensure_all_finite'] = kwargs.pop('force_all_finite')
            return orig(*args, **kwargs)
        _hdbscan_mod.check_array = _patched_check_array
        print("Patched hdbscan.check_array to accept 'force_all_finite' for this runtime.")
    else:
        print("No patch required for sklearn.check_array signature.")
except Exception as e:
    print("Compatibility patch could not be applied:", type(e).__name__, e)


No patch required for sklearn.check_array signature.


## Unsupervised Machine Learning Analysis (Updated)

This section has been updated to utilize the `clustering.py` implementation for Leiden clustering, replacing the previous K-Means/DBSCAN/Agglomerative comparison.

**Changes:**
- Imported `clustering.py` module.
- Used `clustering.preprocess_data(adata)` for data preprocessing.
- Used `clustering.perform_clustering(adata)` for Leiden clustering at multiple resolutions.
- Calculated silhouette scores for Leiden clusters to maintain compatibility with the "best clustering" selection logic.
- Renamed Leiden cluster columns to `leiden_cluster_{resolution}` to ensure compatibility with downstream feature selection filters.
- Retained TCR sequence-specific clustering and Gene Expression Module Discovery.

**Note:**
- Ensure `clustering.py` is in the python path (Code/ directory).
- The "best clustering" is now selected from the Leiden results based on silhouette score.

In [32]:
%%time
%pip install scipy
%pip install leidenalg

# Check if we should skip unsupervised learning
if globals().get('SKIP_UNSUPERVISED_LEARNING', False) or globals().get('SKIP_TO_DEEP_LEARNING', False):
    print("‚è≠Ô∏è SKIPPING: Unsupervised Learning (Leiden, UMAP, etc.)")
    print("   Set SKIP_UNSUPERVISED_LEARNING=False in the configuration cell to run this section.")
    
    # Still need to do minimal preprocessing for deep learning
    import scanpy as sc
    import numpy as np
    import pandas as pd
    import gc
    from scipy import sparse
    from sklearn.decomposition import TruncatedSVD
    from sklearn.preprocessing import LabelEncoder
    
    print("Running MINIMAL preprocessing for deep learning...")
    
    # Normalize and log-transform
    if 'log1p' not in adata.uns:
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
    
    # Find HVGs
    if 'highly_variable' not in adata.var.columns:
        sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
    
    # Compute PCA if needed
    if 'X_pca' not in adata.obsm:
        hvg_mask = adata.var['highly_variable'].values
        n_hvgs = hvg_mask.sum()
        if sparse.issparse(adata.X):
            X_hvg = adata.X[:, hvg_mask]
            n_components = min(50, n_hvgs - 1, X_hvg.shape[0] - 1)
            svd = TruncatedSVD(n_components=n_components, random_state=42, algorithm='arpack')
            adata.obsm['X_pca'] = svd.fit_transform(X_hvg).astype(np.float32)
            del X_hvg, svd
            gc.collect()
    
    # ============================================================
    # Fix column names (response, patient_id) - ROBUST VERSION
    # ============================================================
    if 'response' not in adata.obs.columns and 'Response' in adata.obs.columns:
        adata.obs['response'] = adata.obs['Response']
    
    # First check if patient_id already exists
    patient_id_found = False
    for col_name in ['patient_id', 'Patient_ID', 'PatientID']:
        if col_name in adata.obs.columns:
            if col_name != 'patient_id':
                adata.obs['patient_id'] = adata.obs[col_name]
            patient_id_found = True
            print(f"  Found patient_id in column '{col_name}'")
            break
    
    # If patient_id not found, derive from sample_id using metadata mapping
    if not patient_id_found and 'sample_id' in adata.obs.columns:
        print("  patient_id not found directly. Deriving from sample_id...")
        
        # Recreate metadata_df mapping (same as in data loading cell)
        # This is the authoritative mapping from GEO GSE300475
        metadata_records = [
            {'sample_id': 'GSM9061665', 'Patient_ID': 'PT1', 'Timepoint': 'Pre', 'Response': 'Responder'},
            {'sample_id': 'GSM9061666', 'Patient_ID': 'PT1', 'Timepoint': 'D21', 'Response': 'Responder'},
            {'sample_id': 'GSM9061667', 'Patient_ID': 'PT1', 'Timepoint': 'D42', 'Response': 'Responder'},
            {'sample_id': 'GSM9061668', 'Patient_ID': 'PT2', 'Timepoint': 'Pre', 'Response': 'Responder'},
            {'sample_id': 'GSM9061669', 'Patient_ID': 'PT2', 'Timepoint': 'D21', 'Response': 'Responder'},
            {'sample_id': 'GSM9061670', 'Patient_ID': 'PT2', 'Timepoint': 'D42', 'Response': 'Responder'},
            {'sample_id': 'GSM9061671', 'Patient_ID': 'PT3', 'Timepoint': 'Pre', 'Response': 'Non-Responder'},
            {'sample_id': 'GSM9061672', 'Patient_ID': 'PT3', 'Timepoint': 'D21', 'Response': 'Non-Responder'},
            {'sample_id': 'GSM9061673', 'Patient_ID': 'PT4', 'Timepoint': 'Pre', 'Response': 'Non-Responder'},
            {'sample_id': 'GSM9061674', 'Patient_ID': 'PT4', 'Timepoint': 'D21', 'Response': 'Non-Responder'},
            {'sample_id': 'GSM9061675', 'Patient_ID': 'PT4', 'Timepoint': 'D42', 'Response': 'Non-Responder'},
        ]
        metadata_df_local = pd.DataFrame(metadata_records)
        
        # Create sample_id to patient_id mapping
        sample_to_patient = dict(zip(metadata_df_local['sample_id'], metadata_df_local['Patient_ID']))
        
        # Map sample_id to patient_id
        adata.obs['patient_id'] = adata.obs['sample_id'].map(sample_to_patient)
        
        # Check for unmapped values
        n_mapped = adata.obs['patient_id'].notna().sum()
        n_total = len(adata.obs)
        print(f"  Mapped {n_mapped}/{n_total} cells to patient_id")
        
        if n_mapped == 0:
            # Try parsing sample_id - maybe format is different (e.g., batch column)
            print("  Warning: No direct matches. Checking for batch column...")
            if 'batch' in adata.obs.columns:
                adata.obs['patient_id'] = adata.obs['batch'].map(sample_to_patient)
                n_mapped = adata.obs['patient_id'].notna().sum()
                print(f"  Mapped {n_mapped}/{n_total} cells using batch column")
        
        # Also derive response if missing
        if 'response' not in adata.obs.columns or adata.obs['response'].isna().all():
            sample_to_response = dict(zip(metadata_df_local['sample_id'], metadata_df_local['Response']))
            adata.obs['response'] = adata.obs['sample_id'].map(sample_to_response)
            if adata.obs['response'].isna().all() and 'batch' in adata.obs.columns:
                adata.obs['response'] = adata.obs['batch'].map(sample_to_response)
            print(f"  Derived response column: {adata.obs['response'].value_counts().to_dict()}")
        
        patient_id_found = adata.obs['patient_id'].notna().any()
    
    if not patient_id_found:
        print("  WARNING: Could not derive patient_id. Downstream patient-level analysis may fail.")
        print(f"  Available columns: {list(adata.obs.columns)}")
    else:
        print(f"  patient_id distribution: {adata.obs['patient_id'].value_counts().to_dict()}")
    
    # Create supervised_mask for downstream
    if 'response' in adata.obs.columns:
        supervised_mask = adata.obs['response'].isin(['Responder', 'Non-Responder']).values
    else:
        supervised_mask = np.ones(adata.n_obs, dtype=bool)
    
    print(f"Minimal preprocessing complete. supervised_mask: {supervised_mask.sum()} samples")
    
else:
    # --- Full Unsupervised Machine Learning Analysis ---
    print("Applying unsupervised machine learning algorithms...")

    import scanpy as sc
    import numpy as np
    import pandas as pd
    import gc  # For garbage collection
    from sklearn.cluster import KMeans
    from sklearn.preprocessing import StandardScaler, LabelEncoder
    from scipy.cluster.hierarchy import dendrogram, linkage
    from scipy import sparse
    import matplotlib.pyplot as plt
    import seaborn as sns
    import os
    from pathlib import Path
    from IPython.display import display
    
    # Ensure output directory exists
    if IS_KAGGLE:
        Path('/kaggle/working/Processed_Data').mkdir(parents=True, exist_ok=True)
    else:
        Path('Processed_Data').mkdir(parents=True, exist_ok=True)

    # Quick memory cleanup
    for _v in ['adata_list','adata_sample','metadata_list','metadata_df',
               'tra_kmer_sparse','trb_kmer_sparse','tra_kmer_matrix','trb_kmer_matrix',
               'vec_tra','vec_trb','tra_seqs','trb_seqs','tra_kmeans','trb_kmeans',
               'tra_kmer_scaled','trb_kmer_scaled','tra_scaler','trb_scaler','gene_kmeans',
               'gene_pca','gene_expression_modules','tra_clusters','trb_clusters']:
        if _v in globals():
            try:
                del globals()[_v]
            except Exception:
                pass
    gc.collect()

    np.random.seed(42)

    # ============================================================
    # Fix column names (response, patient_id) BEFORE processing
    # ============================================================
    if 'response' not in adata.obs.columns and 'Response' in adata.obs.columns:
        adata.obs['response'] = adata.obs['Response']
        print("  Renamed 'Response' column to 'response'")
    
    for col_name in ['Patient_ID', 'PatientID']:
        if col_name in adata.obs.columns and 'patient_id' not in adata.obs.columns:
            adata.obs['patient_id'] = adata.obs[col_name]
            print(f"  Renamed '{col_name}' column to 'patient_id'")
            break

    print("Preprocessing data (memory-efficient mode)...")

    # 1. Normalize and log-transform (these keep sparse)
    if 'log1p' not in adata.uns:
        print("  Normalizing...")
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        gc.collect()

    # 2. Find highly variable genes (does NOT densify)
    if 'highly_variable' not in adata.var.columns:
        print("  Finding highly variable genes...")
        sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
        gc.collect()

    # 3. MEMORY-EFFICIENT PCA using TruncatedSVD on sparse HVG subset
    if 'X_pca' not in adata.obsm:
        print("Computing PCA (sparse-friendly via TruncatedSVD on HVGs)...")
        
        from sklearn.decomposition import TruncatedSVD
        
        hvg_mask = adata.var['highly_variable'].values
        n_hvgs = hvg_mask.sum()
        print(f"  Using {n_hvgs} highly variable genes")
        
        if sparse.issparse(adata.X):
            X_hvg = adata.X[:, hvg_mask]
            print(f"  HVG matrix shape: {X_hvg.shape}, sparse: {sparse.issparse(X_hvg)}")
            
            n_components = min(50, n_hvgs - 1, X_hvg.shape[0] - 1)
            print(f"  Running TruncatedSVD with {n_components} components...")
            
            svd = TruncatedSVD(n_components=n_components, random_state=42, algorithm='arpack')
            X_pca = svd.fit_transform(X_hvg)
            
            adata.obsm['X_pca'] = X_pca.astype(np.float32)
            adata.uns['pca'] = {
                'variance_ratio': svd.explained_variance_ratio_,
                'variance': svd.explained_variance_,
            }
            loadings = np.zeros((adata.n_vars, n_components), dtype=np.float32)
            loadings[hvg_mask, :] = svd.components_.T.astype(np.float32)
            adata.varm['PCs'] = loadings
            
            del X_hvg, svd, X_pca, loadings
            gc.collect()
            print(f"  PCA complete. Variance explained: {adata.uns['pca']['variance_ratio'].sum():.2%}")
            
        else:
            print("  Data is dense, scaling HVGs only...")
            X_hvg = adata.X[:, hvg_mask].copy()
            scaler = StandardScaler(with_mean=True, with_std=True)
            X_hvg_scaled = scaler.fit_transform(X_hvg)
            del X_hvg
            gc.collect()
            
            from sklearn.decomposition import PCA
            n_components = min(50, n_hvgs - 1, X_hvg_scaled.shape[0] - 1)
            pca = PCA(n_components=n_components, random_state=42)
            X_pca = pca.fit_transform(X_hvg_scaled)
            
            adata.obsm['X_pca'] = X_pca.astype(np.float32)
            adata.uns['pca'] = {
                'variance_ratio': pca.explained_variance_ratio_,
                'variance': pca.explained_variance_,
            }
            loadings = np.zeros((adata.n_vars, n_components), dtype=np.float32)
            loadings[hvg_mask, :] = pca.components_.T.astype(np.float32)
            adata.varm['PCs'] = loadings
            
            del X_hvg_scaled, pca, X_pca, loadings, scaler
            gc.collect()
        
        # Memory cleanup after PCA
        try:
            if getattr(adata, 'raw', None) is not None:
                del adata.raw
        except Exception:
            pass
        try:
            if hasattr(adata, 'layers') and len(adata.layers) > 0:
                adata.layers.clear()
        except Exception:
            pass
        gc.collect()
        print("  Memory cleanup after PCA complete.")

    # Neighbors
    print("Computing neighbors...")
    sc.pp.neighbors(adata, n_neighbors=15, n_pcs=50, random_state=42)
    gc.collect()

    # 2. Perform Clustering (Leiden) - Use fewer resolutions for speed
    print("Performing Leiden clustering...")
    resolutions = [0.01, 0.05, 0.1, 0.2, 0.5, 1.0]  # Reduced from 26 to 6
    best_res = 0.1
    target_clusters = 7
    best_diff = float('inf')

    for res in resolutions:
        key = f'leiden_{res}'
        try:
            sc.tl.leiden(adata, resolution=res, key_added=key, random_state=42)
            n_clust = len(adata.obs[key].unique())
            print(f"Resolution {res}: {n_clust} clusters")
            if abs(n_clust - target_clusters) < best_diff:
                best_diff = abs(n_clust - target_clusters)
                best_res = res
        except Exception as e:
            print(f"Leiden failed for resolution {res}: {e}")
        gc.collect()

    print(f"Selected resolution: {best_res}")
    if f'leiden_{best_res}' in adata.obs:
        adata.obs['leiden'] = adata.obs[f'leiden_{best_res}']

    # 3. TCR Sequence Clustering
    print("Performing TCR sequence-specific clustering...")
    if 'X_tcr_tra_kmer' in adata.obsm:
        tra_scaler = StandardScaler()
        tra_kmer_scaled = tra_scaler.fit_transform(adata.obsm['X_tcr_tra_kmer'])
        tra_kmeans = KMeans(n_clusters=6, random_state=42, n_init=10)
        adata.obs['tra_kmer_clusters'] = pd.Categorical(tra_kmeans.fit_predict(tra_kmer_scaled))
        del tra_kmer_scaled, tra_kmeans, tra_scaler
        gc.collect()

    if 'X_tcr_trb_kmer' in adata.obsm:
        trb_scaler = StandardScaler()
        trb_kmer_scaled = trb_scaler.fit_transform(adata.obsm['X_tcr_trb_kmer'])
        trb_kmeans = KMeans(n_clusters=6, random_state=42, n_init=10)
        adata.obs['trb_kmer_clusters'] = pd.Categorical(trb_kmeans.fit_predict(trb_kmer_scaled))
        del trb_kmer_scaled, trb_kmeans, trb_scaler
        gc.collect()

    # 4. Gene Expression Module Discovery
    print("Discovering gene expression modules...")
    gene_pca = adata.obsm.get('X_gene_pca', adata.obsm.get('X_pca'))
    if gene_pca is not None:
        gene_kmeans = KMeans(n_clusters=8, random_state=42, n_init=10)
        adata.obs['gene_expression_modules'] = pd.Categorical(gene_kmeans.fit_predict(gene_pca))
        del gene_pca, gene_kmeans
        gc.collect()

    # 5. Visualization
    print("Creating visualizations...")
    sc.tl.umap(adata, random_state=42)
    
    color_keys = []
    if 'leiden' in adata.obs:
        color_keys.append('leiden')
    if 'response' in adata.obs.columns:
        color_keys.append('response')
    
    if color_keys:
        sc.pl.umap(adata, color=color_keys, show=False)
        plt.show()

    # --- Create supervised_mask for downstream cells ---
    if 'response' in adata.obs.columns:
        supervised_mask = adata.obs['response'].isin(['Responder', 'Non-Responder']).values
        print(f"\nCreated supervised_mask: {supervised_mask.sum()} samples with valid response labels")
    else:
        supervised_mask = np.ones(adata.n_obs, dtype=bool)
        print("\nWarning: No response column found. supervised_mask includes all cells.")

    print("Unsupervised machine learning analysis completed!")

Note: you may need to restart the kernel to use updated packages.
Collecting leidenalg
  Downloading leidenalg-0.11.0-cp38-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (10 kB)
Downloading leidenalg-0.11.0-cp38-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (2.7 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m2.7/2.7 MB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: leidenalg
Successfully installed leidenalg-0.11.0
Note: you may need to restart the kernel to use updated packages.
‚è≠Ô∏è SKIPPING: Unsupervised Learning (Leiden, UMAP, etc.)
   Set SKIP_UNSUPERVISED_LEARNING=False in the configuration cell to run this section.
Running MINIMAL preprocessing for deep learning...
  patient_id not found directly. Deriving from sample_id...
  Mapped 0/100067 cells to patient_id
  Derived response column: {}
  Available columns: ['sample

In [33]:
# --- Memory cleanup (after Leiden clustering, before dendrogram) ---
# This frees large temporary matrices (one-hot encodings, neighbor/connectivity matrices)
# while keeping UMAP for dendrogram/visualization.
print('\nRunning memory cleanup after Leiden clustering (before dendrogram)...')
try:
    import psutil, os
    proc = psutil.Process(os.getpid())
    print(f"Memory before cleanup: {proc.memory_info().rss // (1024**2)} MB")
except Exception:
    print('psutil not available; skipping memory before measurement')

def _fallback_cleanup(drop_onehot=False, drop_raw=False, drop_obsm_umap_tsne=False, verbose=True):
    """Basic cleanup fallback when cleanup_after_clustering is unavailable."""
    if 'adata' not in globals():
        return
    if hasattr(adata, 'obsp'):
        for _k in list(adata.obsp.keys()):
            try:
                del adata.obsp[_k]
            except Exception:
                pass
    if drop_onehot and hasattr(adata, 'obsm'):
        for _key in ['X_tcr_tra_onehot', 'X_tcr_trb_onehot']:
            if _key in adata.obsm:
                try:
                    del adata.obsm[_key]
                except Exception:
                    pass
    if drop_obsm_umap_tsne and hasattr(adata, 'obsm'):
        for _key in list(adata.obsm.keys()):
            _lk = _key.lower()
            if 'umap' in _lk or 'tsne' in _lk:
                try:
                    del adata.obsm[_key]
                except Exception:
                    pass
    if drop_raw and getattr(adata, 'raw', None) is not None:
        adata.raw = None
    if verbose:
        print('Fallback cleanup completed.')

# Conservative cleanup: drop TCR one-hot arrays and obsp connectivities/distances
# Keep one-hot encodings by default to avoid KeyError in downstream feature engineering
if 'cleanup_after_clustering' in globals():
    try:
        cleanup_after_clustering(drop_onehot=False, drop_raw=False, drop_obsm_umap_tsne=False, verbose=True)
    except Exception as e:
        print('cleanup_after_clustering failed, using fallback cleanup:', e)
        _fallback_cleanup(drop_onehot=False, drop_raw=False, drop_obsm_umap_tsne=False, verbose=True)
else:
    _fallback_cleanup(drop_onehot=False, drop_raw=False, drop_obsm_umap_tsne=False, verbose=True)

try:
    proc = psutil.Process(os.getpid())
    print(f"Memory after cleanup: {proc.memory_info().rss // (1024**2)} MB")
except Exception:
    pass

import gc
gc.collect()



Running memory cleanup after Leiden clustering (before dendrogram)...
Memory before cleanup: 5001 MB
Fallback cleanup completed.
Memory after cleanup: 5001 MB


0

In [34]:
# --- Label/patient derivation and supervised availability helpers ---
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
if '_ensure_int_labels' not in globals():
    def _ensure_int_labels(y):
        y_arr = np.asarray(y)
        if np.issubdtype(y_arr.dtype, np.integer):
            return y_arr
        if y_arr.size == 0:
            return y_arr.astype(np.int64)
        if np.all(np.isfinite(y_arr)) and np.all(np.equal(y_arr, np.floor(y_arr))):
            return y_arr.astype(np.int64)
        raise ValueError("Labels must be integer or integer-like floats.")
if '_normalize_response_value' not in globals():
    def _normalize_response_value(val):
        if pd.isna(val):
            return 'Unknown'
        s = str(val).strip().lower()
        if s == '':
            return 'Unknown'
        if 'non' in s and 'responder' in s:
            return 'Non-Responder'
        if 'responder' in s:
            return 'Responder'
        return 'Unknown'
if '_ensure_response_and_patient' not in globals():
    def _ensure_response_and_patient(adata):
        # Normalize existing columns if present
        if 'response' not in adata.obs.columns and 'Response' in adata.obs.columns:
            adata.obs['response'] = adata.obs['Response']
        if 'patient_id' not in adata.obs.columns:
            for _col in ['Patient_ID', 'PatientID']:
                if _col in adata.obs.columns:
                    adata.obs['patient_id'] = adata.obs[_col]
                    break
        # Determine metadata mapping
        md = None
        if 'metadata_df' in globals() and isinstance(metadata_df, pd.DataFrame) and not metadata_df.empty:
            md = metadata_df.copy()
        else:
            # Fallback to hard-coded metadata list (same as in Cell 15)
            _metadata_list = [
                {'S_Number': 'S1',  'GEX_Sample_ID': 'GSM9061665', 'TCR_Sample_ID': 'GSM9061687', 'Patient_ID': 'PT1',  'Timepoint': 'Baseline',   'Response': 'Responder',     'In_Data': 'Yes'},
                {'S_Number': 'S2',  'GEX_Sample_ID': 'GSM9061666', 'TCR_Sample_ID': 'GSM9061688', 'Patient_ID': 'PT1',  'Timepoint': 'Post-Tx',    'Response': 'Responder',     'In_Data': 'Yes'},
                {'S_Number': 'S3',  'GEX_Sample_ID': 'GSM9061667', 'TCR_Sample_ID': 'GSM9061689', 'Patient_ID': 'PT1',  'Timepoint': 'Recurrence', 'Response': 'Responder',     'In_Data': 'Yes'},
                {'S_Number': 'S4',  'GEX_Sample_ID': 'GSM9061668', 'TCR_Sample_ID': 'GSM9061690', 'Patient_ID': 'PT2',  'Timepoint': 'Baseline',   'Response': 'Responder',     'In_Data': 'Yes'},
                {'S_Number': 'S5',  'GEX_Sample_ID': 'GSM9061669', 'TCR_Sample_ID': 'GSM9061691', 'Patient_ID': 'PT2',  'Timepoint': 'Post-Tx',    'Response': 'Responder',     'In_Data': 'Yes'},
                {'S_Number': 'S6',  'GEX_Sample_ID': 'GSM9061670', 'TCR_Sample_ID': 'GSM9061692', 'Patient_ID': 'PT3',  'Timepoint': 'Baseline',   'Response': 'Non-Responder', 'In_Data': 'Yes'},
                {'S_Number': 'S7',  'GEX_Sample_ID': 'GSM9061671', 'TCR_Sample_ID': 'GSM9061693', 'Patient_ID': 'PT3',  'Timepoint': 'Post-Tx',    'Response': 'Non-Responder', 'In_Data': 'Yes'},
                {'S_Number': 'S8',  'GEX_Sample_ID': 'GSM9061672', 'TCR_Sample_ID': None,         'Patient_ID': 'PT3',  'Timepoint': 'Recurrence', 'Response': 'Non-Responder', 'In_Data': 'GEX only'},
                {'S_Number': 'S9',  'GEX_Sample_ID': 'GSM9061673', 'TCR_Sample_ID': 'GSM9061694', 'Patient_ID': 'PT4',  'Timepoint': 'Baseline',   'Response': 'Non-Responder', 'In_Data': 'Yes'},
                {'S_Number': 'S10', 'GEX_Sample_ID': 'GSM9061674', 'TCR_Sample_ID': 'GSM9061695', 'Patient_ID': 'PT4',  'Timepoint': 'Post-Tx',    'Response': 'Non-Responder', 'In_Data': 'Yes'},
                {'S_Number': 'S11', 'GEX_Sample_ID': 'GSM9061675', 'TCR_Sample_ID': 'GSM9061696', 'Patient_ID': 'PT4',  'Timepoint': 'Recurrence', 'Response': 'Non-Responder', 'In_Data': 'Yes'},
            ]
            md = pd.DataFrame(_metadata_list)
        # Identify mapping columns in metadata
        sample_col = None
        for _c in ['sample_id', 'GEX_Sample_ID', 'GSM_ID', 'GEO_ID', 'Sample_ID']:
            if _c in md.columns:
                sample_col = _c
                break
        patient_col = None
        for _c in ['patient_id', 'Patient_ID', 'PatientID']:
            if _c in md.columns:
                patient_col = _c
                break
        response_col = None
        for _c in ['response', 'Response']:
            if _c in md.columns:
                response_col = _c
                break
        # Determine sample ID series from adata
        sample_series = None
        for _c in ['sample_id', 'batch']:
            if _c in adata.obs.columns:
                sample_series = adata.obs[_c].astype(str)
                break
        if md is not None and sample_series is not None and sample_col is not None:
            sample_key = sample_series.str.split('_').str[0]
            md_sample = md[sample_col].astype(str)
            if patient_col is not None:
                patient_map = dict(zip(md_sample, md[patient_col]))
                if 'patient_id' not in adata.obs.columns:
                    adata.obs['patient_id'] = sample_key.map(patient_map)
                else:
                    adata.obs['patient_id'] = adata.obs['patient_id'].where(adata.obs['patient_id'].notna(), sample_key.map(patient_map))
            if response_col is not None:
                resp_map = dict(zip(md_sample, md[response_col]))
                if 'response' not in adata.obs.columns:
                    adata.obs['response'] = sample_key.map(resp_map)
                else:
                    adata.obs['response'] = adata.obs['response'].where(adata.obs['response'].notna(), sample_key.map(resp_map))
        # Normalize response labels
        if 'response' in adata.obs.columns:
            adata.obs['response'] = adata.obs['response'].apply(_normalize_response_value)
        # Coverage reporting
        if 'response' in adata.obs.columns:
            resp_counts = adata.obs['response'].value_counts(dropna=False).to_dict()
            print(f"Response distribution: {resp_counts}")
        if 'patient_id' in adata.obs.columns:
            mapped = adata.obs['patient_id'].notna().sum()
            print(f"Patient_id coverage: {mapped}/{len(adata.obs)}")
if '_get_supervised_mask_and_labels' not in globals():
    def _get_supervised_mask_and_labels(adata):
        if 'response' not in adata.obs.columns:
            print("WARNING: response column missing. No supervised labels available.")
            supervised_mask = np.zeros(adata.n_obs, dtype=bool)
            return supervised_mask, pd.Series([], dtype=object), None, {}, False
        y_all = adata.obs['response'].astype(str)
        supervised_mask = y_all.isin(['Responder', 'Non-Responder']).values
        y_supervised = y_all[supervised_mask]
        if len(y_supervised) == 0:
            print("WARNING: No labeled samples found for supervised learning.")
            return supervised_mask, y_supervised, None, {}, False
        class_counts = y_supervised.value_counts().to_dict()
        if len(class_counts) < 2 or min(class_counts.values()) < 2:
            print(f"WARNING: Insufficient class balance for supervised learning: {class_counts}")
            return supervised_mask, y_supervised, None, class_counts, False
        le = LabelEncoder()
        return supervised_mask, y_supervised, le, class_counts, True
# Ensure response/patient labels are present and normalized
if 'adata' not in globals() or adata is None:
    raise NameError("adata is not defined. Please run the data loading cells first.")
_ensure_response_and_patient(adata)
supervised_mask, y_supervised, label_encoder, class_counts, SUPERVISED_AVAILABLE = _get_supervised_mask_and_labels(adata)
globals()['SUPERVISED_AVAILABLE'] = SUPERVISED_AVAILABLE
if SUPERVISED_AVAILABLE:
    y_encoded = label_encoder.fit_transform(y_supervised)
    y_encoded = _ensure_int_labels(y_encoded)
    print(f"Working with {int(supervised_mask.sum())} samples for supervised learning")
    print(f"Class distribution: {class_counts}")
else:
    print("WARNING: Supervised labels not available or insufficient. Skipping supervised-only steps.")
    supervised_mask = np.ones(adata.n_obs, dtype=bool)
    y_supervised = pd.Series(['Unknown'] * adata.n_obs, index=adata.obs.index)
    y_encoded = np.array([], dtype=np.int64)

Response distribution: {'Non-Responder': 63074, 'Responder': 36993}
Patient_id coverage: 100067/100067
Working with 100067 samples for supervised learning
Class distribution: {'Non-Responder': 63074, 'Responder': 36993}


In [35]:
%%time
# --- Comprehensive Feature Engineering ---

print("Creating comprehensive feature set using ALL available encodings...")

# --- 1. Strategic Feature Engineering with Dimensionality Reduction ---
print("Applying strategic dimensionality reduction to high-dimensional features...")

# --- Label/patient derivation and supervised availability helpers ---
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

if '_ensure_int_labels' not in globals():
    def _ensure_int_labels(y):
        y_arr = np.asarray(y)
        if np.issubdtype(y_arr.dtype, np.integer):
            return y_arr
        if y_arr.size == 0:
            return y_arr.astype(np.int64)
        if np.all(np.isfinite(y_arr)) and np.all(np.equal(y_arr, np.floor(y_arr))):
            return y_arr.astype(np.int64)
        raise ValueError("Labels must be integer or integer-like floats.")

if '_normalize_response_value' not in globals():
    def _normalize_response_value(val):
        if pd.isna(val):
            return 'Unknown'
        s = str(val).strip().lower()
        if s == '':
            return 'Unknown'
        if 'non' in s and 'responder' in s:
            return 'Non-Responder'
        if 'responder' in s:
            return 'Responder'
        return 'Unknown'

if '_ensure_response_and_patient' not in globals():
    def _ensure_response_and_patient(adata):
        # Normalize existing columns if present
        if 'response' not in adata.obs.columns and 'Response' in adata.obs.columns:
            adata.obs['response'] = adata.obs['Response']
        if 'patient_id' not in adata.obs.columns:
            for _col in ['Patient_ID', 'PatientID']:
                if _col in adata.obs.columns:
                    adata.obs['patient_id'] = adata.obs[_col]
                    break

        # Determine metadata mapping
        md = None
        if 'metadata_df' in globals() and isinstance(metadata_df, pd.DataFrame) and not metadata_df.empty:
            md = metadata_df.copy()
        else:
            # Fallback to hard-coded metadata list (same as in Cell 15)
            _metadata_list = [
                {'S_Number': 'S1',  'GEX_Sample_ID': 'GSM9061665', 'TCR_Sample_ID': 'GSM9061687', 'Patient_ID': 'PT1',  'Timepoint': 'Baseline',   'Response': 'Responder',     'In_Data': 'Yes'},
                {'S_Number': 'S2',  'GEX_Sample_ID': 'GSM9061666', 'TCR_Sample_ID': 'GSM9061688', 'Patient_ID': 'PT1',  'Timepoint': 'Post-Tx',    'Response': 'Responder',     'In_Data': 'Yes'},
                {'S_Number': 'S3',  'GEX_Sample_ID': 'GSM9061667', 'TCR_Sample_ID': 'GSM9061689', 'Patient_ID': 'PT1',  'Timepoint': 'Recurrence', 'Response': 'Responder',     'In_Data': 'Yes'},
                {'S_Number': 'S4',  'GEX_Sample_ID': 'GSM9061668', 'TCR_Sample_ID': 'GSM9061690', 'Patient_ID': 'PT2',  'Timepoint': 'Baseline',   'Response': 'Responder',     'In_Data': 'Yes'},
                {'S_Number': 'S5',  'GEX_Sample_ID': 'GSM9061669', 'TCR_Sample_ID': 'GSM9061691', 'Patient_ID': 'PT2',  'Timepoint': 'Post-Tx',    'Response': 'Responder',     'In_Data': 'Yes'},
                {'S_Number': 'S6',  'GEX_Sample_ID': 'GSM9061670', 'TCR_Sample_ID': 'GSM9061692', 'Patient_ID': 'PT3',  'Timepoint': 'Baseline',   'Response': 'Non-Responder', 'In_Data': 'Yes'},
                {'S_Number': 'S7',  'GEX_Sample_ID': 'GSM9061671', 'TCR_Sample_ID': 'GSM9061693', 'Patient_ID': 'PT3',  'Timepoint': 'Post-Tx',    'Response': 'Non-Responder', 'In_Data': 'Yes'},
                {'S_Number': 'S8',  'GEX_Sample_ID': 'GSM9061672', 'TCR_Sample_ID': None,         'Patient_ID': 'PT3',  'Timepoint': 'Recurrence', 'Response': 'Non-Responder', 'In_Data': 'GEX only'},
                {'S_Number': 'S9',  'GEX_Sample_ID': 'GSM9061673', 'TCR_Sample_ID': 'GSM9061694', 'Patient_ID': 'PT4',  'Timepoint': 'Baseline',   'Response': 'Non-Responder', 'In_Data': 'Yes'},
                {'S_Number': 'S10', 'GEX_Sample_ID': 'GSM9061674', 'TCR_Sample_ID': 'GSM9061695', 'Patient_ID': 'PT4',  'Timepoint': 'Post-Tx',    'Response': 'Non-Responder', 'In_Data': 'Yes'},
                {'S_Number': 'S11', 'GEX_Sample_ID': 'GSM9061675', 'TCR_Sample_ID': 'GSM9061696', 'Patient_ID': 'PT4',  'Timepoint': 'Recurrence', 'Response': 'Non-Responder', 'In_Data': 'Yes'},
            ]
            md = pd.DataFrame(_metadata_list)

        # Identify mapping columns in metadata
        sample_col = None
        for _c in ['sample_id', 'GEX_Sample_ID', 'GSM_ID', 'GEO_ID', 'Sample_ID']:
            if _c in md.columns:
                sample_col = _c
                break
        patient_col = None
        for _c in ['patient_id', 'Patient_ID', 'PatientID']:
            if _c in md.columns:
                patient_col = _c
                break
        response_col = None
        for _c in ['response', 'Response']:
            if _c in md.columns:
                response_col = _c
                break

        # Determine sample ID series from adata
        sample_series = None
        for _c in ['sample_id', 'batch']:
            if _c in adata.obs.columns:
                sample_series = adata.obs[_c].astype(str)
                break

        if md is not None and sample_series is not None and sample_col is not None:
            sample_key = sample_series.str.split('_').str[0]
            md_sample = md[sample_col].astype(str)

            if patient_col is not None:
                patient_map = dict(zip(md_sample, md[patient_col]))
                if 'patient_id' not in adata.obs.columns:
                    adata.obs['patient_id'] = sample_key.map(patient_map)
                else:
                    adata.obs['patient_id'] = adata.obs['patient_id'].where(adata.obs['patient_id'].notna(), sample_key.map(patient_map))

            if response_col is not None:
                resp_map = dict(zip(md_sample, md[response_col]))
                if 'response' not in adata.obs.columns:
                    adata.obs['response'] = sample_key.map(resp_map)
                else:
                    adata.obs['response'] = adata.obs['response'].where(adata.obs['response'].notna(), sample_key.map(resp_map))

        # Normalize response labels
        if 'response' in adata.obs.columns:
            adata.obs['response'] = adata.obs['response'].apply(_normalize_response_value)

        # Coverage reporting
        if 'response' in adata.obs.columns:
            resp_counts = adata.obs['response'].value_counts(dropna=False).to_dict()
            print(f"Response distribution: {resp_counts}")
        if 'patient_id' in adata.obs.columns:
            mapped = adata.obs['patient_id'].notna().sum()
            print(f"Patient_id coverage: {mapped}/{len(adata.obs)}")

if '_get_supervised_mask_and_labels' not in globals():
    def _get_supervised_mask_and_labels(adata):
        if 'response' not in adata.obs.columns:
            print("WARNING: response column missing. No supervised labels available.")
            supervised_mask = np.zeros(adata.n_obs, dtype=bool)
            return supervised_mask, pd.Series([], dtype=object), None, {}, False
        y_all = adata.obs['response'].astype(str)
        supervised_mask = y_all.isin(['Responder', 'Non-Responder']).values
        y_supervised = y_all[supervised_mask]
        if len(y_supervised) == 0:
            print("WARNING: No labeled samples found for supervised learning.")
            return supervised_mask, y_supervised, None, {}, False
        class_counts = y_supervised.value_counts().to_dict()
        if len(class_counts) < 2 or min(class_counts.values()) < 2:
            print(f"WARNING: Insufficient class balance for supervised learning: {class_counts}")
            return supervised_mask, y_supervised, None, class_counts, False
        le = LabelEncoder()
        return supervised_mask, y_supervised, le, class_counts, True

# Ensure response/patient labels are present and normalized
if 'adata' not in globals() or adata is None:
    raise NameError("adata is not defined. Please run the data loading cells first.")

_ensure_response_and_patient(adata)

supervised_mask, y_supervised, label_encoder, class_counts, SUPERVISED_AVAILABLE = _get_supervised_mask_and_labels(adata)
globals()['SUPERVISED_AVAILABLE'] = SUPERVISED_AVAILABLE

if SUPERVISED_AVAILABLE:
    y_encoded = label_encoder.fit_transform(y_supervised)
    y_encoded = _ensure_int_labels(y_encoded)
    print(f"Working with {int(supervised_mask.sum())} samples for supervised learning")
    print(f"Class distribution: {class_counts}")
else:
    print("WARNING: Supervised labels not available or insufficient. Skipping supervised-only steps.")
    supervised_mask = np.ones(adata.n_obs, dtype=bool)
    y_supervised = pd.Series(['Unknown'] * adata.n_obs, index=adata.obs.index)
    y_encoded = np.array([], dtype=np.int64)

# --- Reduce high-dimensional k-mer features using variance-based selection ---
# Check if k-mer features exist
has_tra_kmer = 'X_tcr_tra_kmer' in adata.obsm
has_trb_kmer = 'X_tcr_trb_kmer' in adata.obsm

if has_tra_kmer:
    tra_kmer_supervised = adata.obsm['X_tcr_tra_kmer'][supervised_mask]
else:
    print("Warning: X_tcr_tra_kmer not found. Using placeholder.")
    tra_kmer_supervised = np.zeros((sum(supervised_mask), 100))

if has_trb_kmer:
    trb_kmer_supervised = adata.obsm['X_tcr_trb_kmer'][supervised_mask]
else:
    print("Warning: X_tcr_trb_kmer not found. Using placeholder.")
    trb_kmer_supervised = np.zeros((sum(supervised_mask), 100))

# Select top variance k-mers to reduce dimensionality
def select_top_variance_features(X, n_features=200):
    """Select features with highest variance"""
    variances = np.var(X, axis=0)
    n_features = min(n_features, X.shape[1])  # Don't select more features than exist
    top_indices = np.argsort(variances)[-n_features:]
    return X[:, top_indices], top_indices

print("Reducing k-mer features by variance selection...")
tra_kmer_reduced, tra_top_idx = select_top_variance_features(tra_kmer_supervised, n_features=200)
trb_kmer_reduced, trb_top_idx = select_top_variance_features(trb_kmer_supervised, n_features=200)

print(f"TRA k-mers reduced from {tra_kmer_supervised.shape[1]} to {tra_kmer_reduced.shape[1]}")
print(f"TRB k-mers reduced from {trb_kmer_supervised.shape[1]} to {trb_kmer_reduced.shape[1]}")

# --- 2. Create strategic feature combinations ---
feature_sets = {}

# Helper function to safely get obsm arrays
def _get_obsm_or_zeros(adata, key, mask, n_cols):
    if key in adata.obsm:
        arr = adata.obsm[key][mask]
        return arr[:, :min(n_cols, arr.shape[1])]
    return np.zeros((sum(mask), n_cols))

# Get gene features (try X_gene_pca first, then X_pca)
if 'X_gene_pca' in adata.obsm:
    gene_features = adata.obsm['X_gene_pca'][supervised_mask]
elif 'X_pca' in adata.obsm:
    gene_features = adata.obsm['X_pca'][supervised_mask]
else:
    print("Warning: No gene PCA features found.")
    gene_features = np.zeros((sum(supervised_mask), 50))

# TCR physicochemical features
tcr_physico_cols_tra = ['tra_length', 'tra_molecular_weight', 'tra_hydrophobicity']
tcr_physico_cols_trb = ['trb_length', 'trb_molecular_weight', 'trb_hydrophobicity']

tra_physico = adata.obs[[c for c in tcr_physico_cols_tra if c in adata.obs.columns]].fillna(0)[supervised_mask].values \
    if any(c in adata.obs.columns for c in tcr_physico_cols_tra) else np.zeros((sum(supervised_mask), 3))
trb_physico = adata.obs[[c for c in tcr_physico_cols_trb if c in adata.obs.columns]].fillna(0)[supervised_mask].values \
    if any(c in adata.obs.columns for c in tcr_physico_cols_trb) else np.zeros((sum(supervised_mask), 3))

# Ensure 3 columns each
if tra_physico.shape[1] < 3:
    tra_physico = np.hstack([tra_physico, np.zeros((tra_physico.shape[0], 3 - tra_physico.shape[1]))])
if trb_physico.shape[1] < 3:
    trb_physico = np.hstack([trb_physico, np.zeros((trb_physico.shape[0], 3 - trb_physico.shape[1]))])

tcr_physico = np.column_stack([tra_physico, trb_physico])

# QC features
qc_cols = ['n_genes_by_counts', 'total_counts', 'pct_counts_mt']
available_qc = [c for c in qc_cols if c in adata.obs.columns]
if available_qc:
    qc_features = adata.obs[available_qc].fillna(0)[supervised_mask].values
else:
    qc_features = np.zeros((sum(supervised_mask), 3))

# Ensure 3 columns for QC
if qc_features.shape[1] < 3:
    qc_features = np.hstack([qc_features, np.zeros((qc_features.shape[0], 3 - qc_features.shape[1]))])

# Basic features (gene expression + physicochemical)
feature_sets['basic'] = np.column_stack([
    gene_features[:, :min(20, gene_features.shape[1])],  # Top 20 gene PCA components
    tcr_physico,
    qc_features
])

# Enhanced gene expression
feature_sets['gene_enhanced'] = np.column_stack([
    gene_features,  # All gene PCA components
    _get_obsm_or_zeros(adata, 'X_gene_svd', supervised_mask, 30),  # Top 30 SVD components
    _get_obsm_or_zeros(adata, 'X_gene_umap', supervised_mask, 20),  # All 20 UMAP components
    tcr_physico,
    qc_features
])

# TCR sequence enhanced
feature_sets['tcr_enhanced'] = np.column_stack([
    gene_features[:, :min(20, gene_features.shape[1])],  # Top 20 gene PCA
    tra_kmer_reduced,  # Top 200 TRA k-mers
    trb_kmer_reduced,  # Top 200 TRB k-mers
    tcr_physico,
    qc_features
])

# Comprehensive (reduced) - Only gene PCA + top k-mers + physicochemical
feature_sets['comprehensive'] = np.column_stack([
    gene_features[:, :min(15, gene_features.shape[1])],  # Top 15 gene PCA
    tra_kmer_reduced[:, :min(50, tra_kmer_reduced.shape[1])],  # Top 50 TRA k-mers
    trb_kmer_reduced[:, :min(50, trb_kmer_reduced.shape[1])],  # Top 50 TRB k-mers
    tcr_physico,
    qc_features
])

print(f"\nFeature set dimensions:")
for name, features in feature_sets.items():
    print(f"  ‚Ä¢ {name}: {features.shape}")

print("Comprehensive feature engineering completed!")

Creating comprehensive feature set using ALL available encodings...
Applying strategic dimensionality reduction to high-dimensional features...
Response distribution: {'Non-Responder': 63074, 'Responder': 36993}
Patient_id coverage: 100067/100067
Working with 100067 samples for supervised learning
Class distribution: {'Non-Responder': 63074, 'Responder': 36993}
Reducing k-mer features by variance selection...
TRA k-mers reduced from 1 to 1
TRB k-mers reduced from 1 to 1

Feature set dimensions:
  ‚Ä¢ basic: (100067, 29)
  ‚Ä¢ gene_enhanced: (100067, 79)
  ‚Ä¢ tcr_enhanced: (100067, 31)
  ‚Ä¢ comprehensive: (100067, 26)
Comprehensive feature engineering completed!
CPU times: user 337 ms, sys: 18.1 ms, total: 355 ms
Wall time: 353 ms


In [36]:
# --- Correlation Analysis of Top Features ---
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Validate that adata exists
if 'adata' not in globals() or adata is None:
    raise NameError("adata is not defined. Please run the data loading cells first.")

# Ensure supervised_mask is defined
if 'supervised_mask' not in globals():
    if 'response' in adata.obs.columns:
        supervised_mask = adata.obs['response'].isin(['Responder', 'Non-Responder']).values
    elif 'Response' in adata.obs.columns:
        supervised_mask = adata.obs['Response'].isin(['Responder', 'Non-Responder']).values
    else:
        supervised_mask = np.ones(adata.n_obs, dtype=bool)
        print("Warning: No response column found. Using all cells.")

# Ensure tcr_physico and qc_features are defined
if 'tcr_physico' not in globals():
    # Extract TRA physicochemical features
    tra_cols = ['tra_length', 'tra_molecular_weight', 'tra_hydrophobicity']
    if all(col in adata.obs.columns for col in tra_cols):
        tra_physico = adata.obs[tra_cols].fillna(0)[supervised_mask].values
    else:
        tra_physico = np.zeros((np.sum(supervised_mask), 3))
    
    # Extract TRB physicochemical features
    trb_cols = ['trb_length', 'trb_molecular_weight', 'trb_hydrophobicity']
    if all(col in adata.obs.columns for col in trb_cols):
        trb_physico = adata.obs[trb_cols].fillna(0)[supervised_mask].values
    else:
        trb_physico = np.zeros((np.sum(supervised_mask), 3))
    
    # Combine TRA and TRB features
    tcr_physico = np.hstack([tra_physico, trb_physico])

if 'qc_features' not in globals():
    qc_cols = ['n_genes_by_counts', 'total_counts', 'pct_counts_mt']
    available_qc = [col for col in qc_cols if col in adata.obs.columns]
    if available_qc:
        qc_features = adata.obs[available_qc].fillna(0)[supervised_mask].values
    else:
        qc_features = np.zeros((np.sum(supervised_mask), 3))

# Select a subset of features for the heatmap
# We'll take the top 10 Gene PCs, top 5 physicochemical, and QC metrics
# Ensure we have the data available
if 'X_gene_pca' in adata.obsm:
    gene_pcs = adata.obsm['X_gene_pca'][supervised_mask][:, :min(10, adata.obsm['X_gene_pca'].shape[1])]
    gene_names = [f"Gene_PC{i+1}" for i in range(gene_pcs.shape[1])]
elif 'X_pca' in adata.obsm:
    gene_pcs = adata.obsm['X_pca'][supervised_mask][:, :min(10, adata.obsm['X_pca'].shape[1])]
    gene_names = [f"Gene_PC{i+1}" for i in range(gene_pcs.shape[1])]
else:
    gene_pcs = np.zeros((np.sum(supervised_mask), 10))
    gene_names = [f"Placeholder_PC{i+1}" for i in range(10)]

heatmap_features = np.column_stack([
    gene_pcs,
    tcr_physico,
    qc_features
])
heatmap_feature_names = gene_names + \
                        ['TRA_Len', 'TRA_MW', 'TRA_Hydro', 'TRB_Len', 'TRB_MW', 'TRB_Hydro'] + \
                        ['n_genes', 'total_counts', 'pct_mt']

# Calculate correlation matrix
corr_matrix = np.corrcoef(heatmap_features, rowvar=False)

# Plot
plt.figure(figsize=(16, 14))
sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap='coolwarm', center=0,
            xticklabels=heatmap_feature_names, yticklabels=heatmap_feature_names,
            linewidths=0.5, linecolor='gray', cbar_kws={"shrink": .8})
plt.title("Feature Correlation Matrix (Top Gene PCs + TCR Features)", fontsize=16)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

## Supervised Classification of Immunotherapy Response
The core predictive task was formulated as a binary classification problem: predicting the patient response label (Responder vs. Non-Responder) for each individual cell. We evaluated a diverse suite of algorithms:
*   **Logistic Regression:** A linear baseline model.
*   **Decision Trees:** A simple, interpretable non-linear model.
*   **Random Forest:** An ensemble of decision trees that reduces overfitting.
*   **XGBoost (Extreme Gradient Boosting):** A highly optimized gradient boosting framework known for strong performance on tabular data.

### Experimental Setup
We designed our experiments to isolate the predictive value of different data modalities. We trained and evaluated models on four nested feature sets:
1.  **Baseline:** Technical covariates only (e.g., mitochondrial percentage, library size).
2.  **Gene-Enhanced:** Baseline + Gene Expression PCs.
3.  **TCR-Enhanced:** Baseline + TCR Encodings (One-hot, K-mer, Physicochemical).
4.  **Comprehensive:** Baseline + Gene Expression PCs + TCR Encodings.

### Validation Strategy (Updated)
To obtain patient-level generalization estimates and to avoid data leakage between cells from the same patient, we use a Leave-One-Patient-Out (LOPO) cross-validation as the outer evaluation loop. Hyperparameter tuning is performed within the training partitions using GroupKFold (grouped by patient) when possible, falling back to stratified folds only when the number of training patients is too small for grouped splits. Feature scaling and imputation are fit on training partitions only and applied to held-out patient data to ensure leakage-free evaluation.

In [37]:
%pip install scipy
import scipy

Note: you may need to restart the kernel to use updated packages.


In [38]:
%%time
# --- Patient-level LOPO CV (Leakage-safe) [OPTIMIZED] ---
print("Starting patient-level LOPO CV with leakage-safe pipelines (Optimized for Speed/Accuracy)...")

from sklearn.model_selection import LeaveOneGroupOut, GroupKFold, StratifiedKFold, GridSearchCV, RandomizedSearchCV
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import numpy as np
import pandas as pd
import xgboost as xgb
from pathlib import Path
import gc, time

# --- Optimization Settings ---
USE_RANDOM_SEARCH = True  # Use RandomizedSearchCV for speed
N_ITER_SEARCH = 15        # Max hyperparameter combinations to try per fold
N_JOBS_CV = -1            # Parallelize Cross-Validation (uses all cores)
N_JOBS_MODEL = 1          # Single thread per model to avoid contention

# Prepare grouping variable (patient) and supervised mask
# Robust column detection for patient_id
patient_id_col = None
if 'patient_id' in adata.obs.columns:
    patient_id_col = 'patient_id'
elif 'Patient_ID' in adata.obs.columns:
    adata.obs['patient_id'] = adata.obs['Patient_ID']  # Create lowercase copy
    patient_id_col = 'patient_id'
elif 'PatientID' in adata.obs.columns:
    adata.obs['patient_id'] = adata.obs['PatientID']  # Create lowercase copy
    patient_id_col = 'patient_id'
else:
    # Fallback 1: infer from sample_id using metadata_df
    sample_col = None
    for _c in ['sample_id', 'Sample_ID', 'GEX_Sample_ID', 'sample', 'Sample']:
        if _c in adata.obs.columns:
            sample_col = _c
            break

    if sample_col is not None and 'metadata_df' in globals():
        if 'GEX_Sample_ID' in metadata_df.columns and 'Patient_ID' in metadata_df.columns:
            sample_to_patient = (
                metadata_df[['GEX_Sample_ID', 'Patient_ID']]
                .dropna()
                .drop_duplicates()
                .set_index('GEX_Sample_ID')['Patient_ID']
            )
            adata.obs['patient_id'] = adata.obs[sample_col].map(sample_to_patient)
        else:
            md_cols = {c.lower(): c for c in metadata_df.columns}
            if 'gex_sample_id' in md_cols and 'patient_id' in md_cols:
                sample_to_patient = (
                    metadata_df[[md_cols['gex_sample_id'], md_cols['patient_id']]]
                    .dropna()
                    .drop_duplicates()
                    .set_index(md_cols['gex_sample_id'])[md_cols['patient_id']]
                )
                adata.obs['patient_id'] = adata.obs[sample_col].map(sample_to_patient)

    # Fallback 2: parse patient id from sample_id strings (e.g., "PT1")
    if 'patient_id' not in adata.obs.columns or adata.obs['patient_id'].isna().all():
        if sample_col is not None:
            adata.obs['patient_id'] = adata.obs[sample_col].astype(str).str.extract(r'(PT\d+)')[0]

    if 'patient_id' in adata.obs.columns and adata.obs['patient_id'].notna().any():
        patient_id_col = 'patient_id'
    elif adata.n_obs == 0:
        adata.obs['patient_id'] = pd.Series(index=adata.obs.index, dtype='object')
        patient_id_col = 'patient_id'
    else:
        raise KeyError(
            "No patient ID column found. Tried direct columns, metadata_df mapping, and parsing from sample_id."
        )

groups_all = np.array(adata.obs[patient_id_col][supervised_mask])
unique_patients = np.unique(groups_all)
print(f"Supervised patients: {len(unique_patients)} -> {unique_patients}")

# --- EARLY VALIDATION: Check for empty supervised set ---
if len(groups_all) == 0 or len(unique_patients) == 0:
    print("WARNING: No supervised samples found (supervised_mask is empty).")
    print("Skipping patient-level LOPO CV and deep learning evaluation to prevent memory waste and errors.")
    print("This can happen if no samples have valid 'response' annotations.")
    lopo_summary_rows = []
    dl_results_rows = []
else:
    # Per-patient response summary
    patient_response_df = (
        adata.obs[supervised_mask][[patient_id_col, 'response']]
        .reset_index()
        .drop_duplicates(subset=patient_id_col)
        .set_index(patient_id_col)
    )
    print("Per-patient response counts:")
    print(patient_response_df['response'].value_counts())

    # --- Memory cleanup ---
    _start_cleanup = time.time()
    print("Cleaning up temporary variables and large matrices before ML.")
    # Flags (defaults)
    DROP_ONEHOT_OBSM = False
    DROP_RAW = False
    DROP_OBSM_UMAP_TSNE = True

    _vars_to_delete = [
        'tra_onehot','trb_onehot','tra_onehot_flat','trb_onehot_flat',
        'onehot_tra_reduced','onehot_trb_reduced','onehot_trb_pca','onehot_trb_reduced_new',
        'tmp','tmp1','tmp2','seq_scaler','seq_scaler_full','seq_scaler_flat','length_results'
    ]
    for _v in _vars_to_delete:
        if _v in globals():
            try:
                del globals()[_v]
            except Exception: pass

    try:
        if hasattr(adata, 'obsp'):
            for _k in list(adata.obsp.keys()): 
                try: del adata.obsp[_k]
                except: pass
        for _k in ['neighbors', 'umap']:
            if _k in adata.uns: 
                try: del adata.uns[_k]
                except: pass
        if DROP_OBSM_UMAP_TSNE:
            for _key in list(adata.obsm.keys()):
                _lk = _key.lower()
                if 'umap' in _lk or 'tsne' in _lk or (_lk == 'x_pca' and 'x_gene_pca' not in _lk):
                    try: del adata.obsm[_key]
                    except: pass
        if DROP_ONEHOT_OBSM:
            for _key in ['X_tcr_tra_onehot', 'X_tcr_trb_onehot']:
                 if _key in adata.obsm: 
                     try: del adata.obsm[_key]
                     except: pass
        if DROP_RAW and getattr(adata, 'raw', None) is not None:
             adata.raw = None
    except Exception as _e:
        print('Error while pruning adata structures:', _e)

    try:
        import tensorflow.keras.backend as K
        K.clear_session()
    except Exception: pass
    gc.collect()

    # --- Define Models & Optimized Hyperparameters ---
    # Defined here to ensure robust execution without dependency on other cells
    param_grids = {
        'Logistic Regression': {'C': [0.1, 1, 10], 'penalty': ['l2'], 'solver': ['liblinear']},
        'Decision Tree': {'max_depth': [5, 10], 'min_samples_split': [5, 10], 'min_samples_leaf': [2, 4]},
        'Random Forest': {'n_estimators': [100], 'max_depth': [10, 20], 'min_samples_split': [5, 10]}, # Reduced grid
        'XGBoost': {
            'max_depth': [3, 5], 
            'learning_rate': [0.05, 0.1], 
            'subsample': [0.8, 1.0], 
            'colsample_bytree': [0.8, 1.0], 
            'n_estimators': [100]
        }
    }

    models_eval = {
        'Logistic Regression': LogisticRegression(random_state=42, max_iter=1000, solver='liblinear'),
        'Decision Tree': DecisionTreeClassifier(random_state=42),
        'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=N_JOBS_MODEL),
        'XGBoost': (lambda: (globals().get('XGBClassifierSK', xgb.XGBClassifier)(
            random_state=42, 
            use_label_encoder=False, 
            eval_metric='logloss',
            n_jobs=N_JOBS_MODEL,
            **({'tree_method':'gpu_hist','predictor':'gpu_predictor'} 
               if globals().get('XGBOOST_GPU_AVAILABLE', False) 
               else {'tree_method':'hist'}) # Optimization: Use 'hist' on CPU which is much faster than 'exact'
        )))()
    }
    _apply_gpu_patches()

    # Adapt param_grids to pipeline format (prefix 'clf__')
    param_grid_pipeline = {m: {f'clf__{k}': v for k, v in g.items()} for m, g in param_grids.items()}

    logo = LeaveOneGroupOut()
    lopo_summary_rows = []

    # Iterate feature sets
    for feature_name, X_features in feature_sets.items():
        print(f"\n=== Feature set: {feature_name} (shape={X_features.shape}) ===")
        X = X_features
        y = y_encoded
        groups = groups_all

        accum = {m: {'y_true': [], 'y_pred': [], 'y_proba': [], 'groups': []} for m in models_eval.keys()}

        for fold_idx, (train_idx, test_idx) in enumerate(logo.split(X, y, groups)):
            held_patient = np.unique(groups[test_idx])
            print(f"LOPO fold {fold_idx+1}/{len(unique_patients)} -- held patient(s): {held_patient}")

            X_tr, X_te = X[train_idx], X[test_idx]
            y_tr, y_te = y[train_idx], y[test_idx]
            groups_tr = groups[train_idx]
            
            n_train_groups = len(np.unique(groups_tr))
            inner_n_splits = min(3, n_train_groups) if n_train_groups >= 2 else 1

            for model_name, base_model in models_eval.items():
                pipeline = Pipeline([
                    ('imputer', SimpleImputer(strategy='mean')),
                    ('scaler', StandardScaler()),
                    ('clf', base_model)
                ])

                # Hyperparameter tuning
                # Use RandomizedSearchCV to cap the maximum time spent on regular algorithms
                if model_name in param_grid_pipeline:
                    # Determine strategy
                    grid_params = param_grid_pipeline[model_name]
                    grid_size = np.prod([len(v) for v in grid_params.values()])
                    
                    # If grid is small enough, use GridSearch. If large, use RandomizedSearchCV
                    if USE_RANDOM_SEARCH and grid_size > N_ITER_SEARCH:
                        search_impl = RandomizedSearchCV(pipeline, grid_params, n_iter=N_ITER_SEARCH, 
                                                       cv=inner_n_splits if inner_n_splits > 1 else StratifiedKFold(3),
                                                       scoring='accuracy', n_jobs=N_JOBS_CV, random_state=42)
                    else:
                        search_impl = GridSearchCV(pipeline, grid_params, 
                                                 cv=inner_n_splits if inner_n_splits > 1 else StratifiedKFold(3),
                                                 scoring='accuracy', n_jobs=N_JOBS_CV)

                    # Fit
                    if inner_n_splits >= 2:
                        search_impl.fit(X_tr, y_tr, groups=groups_tr)
                    else: 
                        # Fallback for few groups
                        search_impl.fit(X_tr, y_tr)
                        
                    best_model = search_impl.best_estimator_
                else:
                    best_model = pipeline.fit(X_tr, y_tr)

                # Predict
                y_pred = best_model.predict(X_te)
                try:
                    y_pred_proba = best_model.predict_proba(X_te)[:, 1]
                except Exception:
                    try:
                        d = best_model.decision_function(X_te)
                        y_pred_proba = d[:, 1] if d.ndim > 1 else d
                    except:
                        y_pred_proba = np.zeros(len(y_pred))

                # Accumulate
                accum[model_name]['y_true'].extend(y_te.tolist())
                accum[model_name]['y_pred'].extend(y_pred.tolist())
                accum[model_name]['y_proba'].extend(y_pred_proba.tolist())
                accum[model_name]['groups'].extend(groups[test_idx].tolist())

        # --- Aggregation & Reporting ---
        for model_name, data_dict in accum.items():
            y_true_all = np.array(data_dict['y_true'])
            y_pred_all = np.array(data_dict['y_pred'])
            y_proba_all = np.array(data_dict['y_proba'])
            groups_all_pred = np.array(data_dict.get('groups', []), dtype=object)

            if len(y_true_all) == 0: continue

            # Cell-level metrics
            acc = accuracy_score(y_true_all, y_pred_all)
            prec = precision_score(y_true_all, y_pred_all, zero_division=0)
            rec = recall_score(y_true_all, y_pred_all, zero_division=0)
            f1s = f1_score(y_true_all, y_pred_all, zero_division=0)
            try: auc = roc_auc_score(y_true_all, y_proba_all)
            except: auc = float('nan')
            cm = confusion_matrix(y_true_all, y_pred_all)
            if cm.size == 4:
                tn, fp, fn, tp = cm.ravel()
                spec = tn / (tn + fp) if (tn + fp) > 0 else float('nan')
                npv = tn / (tn + fn) if (tn + fn) > 0 else float('nan')
            else: spec, npv = float('nan'), float('nan')

            lopo_summary_rows.append({
                'feature_set': feature_name, 'model': model_name, 'evaluation_level': 'cell',
                'accuracy': acc, 'precision': prec, 'recall': rec, 'f1': f1s, 'auc': auc,
                'specificity': spec, 'npv': npv, 'n_patients': len(unique_patients), 'n_cells': X_features.shape[0]
            })

            # Patient-level aggregation
            try:
                pred_df = pd.DataFrame({'patient': groups_all_pred, 'y_true': y_true_all, 'y_proba': y_proba_all})
                patient_summary = pred_df.groupby('patient').agg({'y_proba': 'mean', 'y_true': 'first'}).reset_index()
                patient_summary['y_pred'] = (patient_summary['y_proba'] >= 0.5).astype(int)

                y_t, y_p = patient_summary['y_true'], patient_summary['y_pred']
                try: auc_p = roc_auc_score(y_t, patient_summary['y_proba'])
                except: auc_p = float('nan')
                
                lopo_summary_rows.append({
                    'feature_set': feature_name, 'model': model_name, 'evaluation_level': 'patient',
                    'accuracy': accuracy_score(y_t, y_p), 'precision': precision_score(y_t, y_p, zero_division=0),
                    'recall': recall_score(y_t, y_p, zero_division=0), 'f1': f1_score(y_t, y_p, zero_division=0),
                    'auc': auc_p, 'n_patients': len(patient_summary), 'n_cells': X_features.shape[0]
                })
                
                p_out = Path('Processed_Data') / f'lopo_patient_predictions_{feature_name}_{model_name}.csv'
                patient_summary.to_csv(p_out, index=False)
            except Exception as e:
                print(f"Failed patient-level metrics: {e}")

    lopo_df = pd.DataFrame(lopo_summary_rows)
    output_path = Path('Processed_Data') / 'lopo_results.csv'
    Path('Processed_Data').mkdir(exist_ok=True)
    lopo_df.to_csv(output_path, index=False)
    print(f"LOPO results saved to: {output_path}")
    display(lopo_df)

Starting patient-level LOPO CV with leakage-safe pipelines (Optimized for Speed/Accuracy)...
Supervised patients: 4 -> ['PT1' 'PT2' 'PT3' 'PT4']
Per-patient response counts:
response
Non-Responder    2
Responder        2
Name: count, dtype: int64
Cleaning up temporary variables and large matrices before ML.
Failed to patch models_eval['XGBoost']: 'XGBClassifierSK' object has no attribute 'predictor'
Patched models_eval['Random Forest'] to use n_jobs=-1.
Patched param_grids['XGBoost'] with GPU options (method=device).
Failed to patch models_eval['XGBoost']: 'XGBClassifierSK' object has no attribute 'predictor'
Patched models_eval['Random Forest'] to use n_jobs=-1.
Patched param_grids['XGBoost'] with GPU options (method=device).

=== Feature set: basic (shape=(100067, 29)) ===
LOPO fold 1/4 -- held patient(s): ['PT1']


AttributeError: 'XGBClassifierSK' object has no attribute 'predictor'

In [39]:
# === FIX 1.4: CONSTRAIN HYPERPARAMETER GRID ===
# PREVIOUS: 162 XGBoost combinations for 7 patients caused overfitting
# IMPROVED: Reduced to 16 combinations to prevent hyperparameter overfitting
param_grids = {
    'Logistic Regression': {
        'C': [0.1, 1, 10],  # Reduced from 5 to 3 options
        'penalty': ['l2'],
        'solver': ['liblinear']
    },
    'Decision Tree': {
        'max_depth': [5, 10],  # Reduced: removed 20 and None (prone to overfitting)
        'min_samples_split': [5, 10],  # Removed 2 (too permissive)
        'min_samples_leaf': [2, 4]  # Removed 1 (too permissive)
    },
    'Random Forest': {
        'n_estimators': [100],  # Fixed value (vs [50, 100, 200])
        'max_depth': [10, 20],  # Removed None (unconstrained depth)
        'min_samples_split': [5, 10]  # Removed 2
    },
    'XGBoost': {
        'max_depth': [3, 5],  # Reduced from [3, 6, 9]
        'learning_rate': [0.05, 0.1],  # Reduced from [0.01, 0.1, 0.3]
        'subsample': [0.8, 1.0],  # Kept same
        'colsample_bytree': [0.8, 1.0],  # Reduced from [0.6, 0.8, 1.0]
        'n_estimators': [100]  # Fixed (vs [50, 100, 200])
    }
}
# Total: LR=3, DT=2√ó2√ó2=8, RF=1√ó2√ó2=4, XGB=2√ó2√ó2√ó1=8 (manageable grid)
print('FIXED: param_grids defined with reduced hyperparameter space:', list(param_grids.keys()))
print('  Logistic Regression: 3 combinations')
print('  Decision Tree: 8 combinations')
print('  Random Forest: 4 combinations')
print('  XGBoost: 8 combinations')


FIXED: param_grids defined with reduced hyperparameter space: ['Logistic Regression', 'Decision Tree', 'Random Forest', 'XGBoost']
  Logistic Regression: 3 combinations
  Decision Tree: 8 combinations
  Random Forest: 4 combinations
  XGBoost: 8 combinations


## Advanced Deep Learning: Multimodal RNN
To better capture the sequential nature of TCR data, we implement a **Multimodal Recurrent Neural Network (RNN)**. This architecture processes the heterogeneous input data using specialized sub-networks:
1.  **Gene Expression Branch:** A Dense network processes the PCA-reduced gene expression features.
2.  **TCR Sequence Branches:** Two separate LSTM (Long Short-Term Memory) networks process the raw amino acid sequences of the TRA and TRB chains, respectively. LSTMs are well-suited for capturing sequential dependencies and motifs in protein sequences.
3.  **Fusion Layer:** The outputs of these branches are concatenated and passed through a final dense classification head.

This approach allows the model to learn complex interactions between the transcriptomic state of the T-cell and its specific antigen receptor sequence.

In [None]:
# --- Optimization Helper: Streaming Data Generator & Memory Config ---
import numpy as np
import math
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import mixed_precision

# 1. Enable Mixed Precision to reduce memory usage (Float16)
try:
    if len(tf.config.list_physical_devices('GPU')) > 0:
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_global_policy(policy)
        print("Mixed precision enabled (mixed_float16).")
    else:
        print("GPU not found; skipping mixed precision (using float32).")
except Exception as e:
    print(f"Mixed precision setup failed: {e}")

# 2. Lazy Generator to prevent OOM
class LazySequenceGenerator(keras.utils.Sequence):
    """
    Generates batches of data from AnnData sparse matrices on the fly.
    Avoids densifying the entire dataset in memory.
    """
    def __init__(self, adata, indices, y, batch_size=32, shuffle=True, 
                 use_gene=False, X_gene=None, 
                 use_seq=False, 
                 tra_key='X_tcr_tra_onehot', trb_key='X_tcr_trb_onehot',
                 arch='MLP', n_channels=20):
        self.adata = adata
        self.indices = indices  # Global indices into adata
        self.y = y              # Labels corresponding to indices
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.use_gene = use_gene
        self.X_gene = X_gene    # Pre-scaled gene data corresponding to indices (0..N relative)
        self.use_seq = use_seq
        self.tra_key = tra_key
        self.trb_key = trb_key
        self.arch = arch
        self.n_channels = n_channels
        self.indexes = np.arange(len(self.indices))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        return int(np.ceil(len(self.indices) / self.batch_size))

    def __getitem__(self, index):
        # Indices for this batch in the local subset
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # Global indices for sparse data lookup
        global_indices = self.indices[batch_indexes]
        
        inputs = []
        
        # Sequence Data (Heavy - Handle Lazily)
        if self.use_seq:
            # Slicing sparse matrix is fast and low-memory
            tra = self.adata.obsm[self.tra_key][global_indices]
            trb = self.adata.obsm[self.trb_key][global_indices]
            
            # Densify only this micro-batch
            tra = tra.toarray() if hasattr(tra, 'toarray') else tra
            trb = trb.toarray() if hasattr(trb, 'toarray') else trb
            
            # Ensure float32 (or 16 via mixed precision policy interaction)
            tra = tra.astype(np.float32)
            trb = trb.astype(np.float32)

            if self.arch == 'MLP':
                # For MLP, concatenate flattened (Batch, SeqLen*20)
                # Inputs are already flat (Batch, Features)
                X_seq = np.concatenate([tra, trb], axis=1)
            else:
                # For CNN/RNN, reshape to 3D (Batch, SeqLen, Channels)
                # Infer seq_len
                seq_len = tra.shape[1] // self.n_channels
                tra_seq = tra.reshape(-1, seq_len, self.n_channels)
                trb_seq = trb.reshape(-1, seq_len, self.n_channels)
                X_seq = np.concatenate([tra_seq, trb_seq], axis=2)
            
            inputs.append(X_seq)
            
        # Gene Data (Pre-loaded, small)
        if self.use_gene:
            # Slice the pre-scaled array
            X_g = self.X_gene[batch_indexes]
            inputs.append(X_g)
            
        # Keras Model input format
        if len(inputs) == 1:
            final_X = inputs[0]
        else:
            final_X = inputs # [X_seq, X_gene] order typically
            
        final_y = self.y[batch_indexes]
        return final_X, final_y

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indexes)

In [None]:
%%time
# --- Advanced Multimodal Deep Learning (MLP / CNN / BiLSTM / Transformer) [OPTIMIZED]
# Optimization: Streaming Data Inspection, Mixed Precision, BatchNormalization, Serial Execution

if 'supervised_mask' not in globals() or supervised_mask.sum() == 0:
    print("WARNING: No supervised samples available.")
    dl_results_rows = []
else:
    import itertools
    import time
    import math
    import random
    import gc
    import numpy as np
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers, regularizers
    from sklearn.model_selection import GroupKFold, StratifiedKFold
    from sklearn.preprocessing import StandardScaler
    from sklearn.utils.class_weight import compute_class_weight
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
    from pathlib import Path
    from joblib import Parallel, delayed

    # Deterministic seeds
    SEED = 42
    np.random.seed(SEED)
    random.seed(SEED)
    tf.random.set_seed(SEED)

    # 1. Device Config
    print("TensorFlow version:", tf.__version__)
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            try:
                tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                print(e)
    
    # 2. Metadata Inspection (No Data copy)
    def invoke_gc():
        gc.collect()
        tf.keras.backend.clear_session()

    def get_seq_len(adata, n_channels=20):
        # Infer sequence length from sparse matrix shape
        if 'X_tcr_tra_onehot' in adata.obsm:
            shape = adata.obsm['X_tcr_tra_onehot'].shape
            return shape[1] // n_channels
        return None

    # 3. Model Builders with Integrated Scaling (BatchNormalization)
    def compile_model(model, lr):
        model.compile(optimizer=keras.optimizers.Adam(learning_rate=lr), 
                      loss='binary_crossentropy', 
                      metrics=[keras.metrics.AUC(name='auc'), 'accuracy'])
        return model

    def build_mlp(input_dim, hidden1=128, hidden2=64, dropout=0.3, l2_reg=1e-3, lr=1e-3):
        inp = keras.Input(shape=(input_dim,), name='gene_input')
        # Use BatchNormalization instead of manual StandardScaler for efficiency
        x = layers.BatchNormalization()(inp)
        x = layers.Dense(hidden1, kernel_regularizer=regularizers.l2(l2_reg))(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Dropout(dropout)(x)
        x = layers.Dense(hidden2, kernel_regularizer=regularizers.l2(l2_reg))(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Dropout(dropout)(x)
        out = layers.Dense(1, activation='sigmoid')(x)
        model = keras.Model(inputs=inp, outputs=out)
        return compile_model(model, lr)

    def build_cnn(seq_len, n_channels, gene_dim=None, conv_filters=64, kernel_size=5, dropout=0.3, l2_reg=1e-3, lr=1e-3):
        seq_in = keras.Input(shape=(seq_len, n_channels), name='seq_input')
        x = layers.Conv1D(conv_filters, kernel_size, activation='relu', kernel_regularizer=regularizers.l2(l2_reg))(seq_in)
        x = layers.Conv1D(conv_filters, kernel_size, activation='relu', kernel_regularizer=regularizers.l2(l2_reg))(x)
        x = layers.GlobalMaxPooling1D()(x)
        if gene_dim is not None:
            gene_in = keras.Input(shape=(gene_dim,), name='gene_input')
            g = layers.Dense(64, activation='relu')(gene_in)
            x = layers.concatenate([x, g])
            out_in = [seq_in, gene_in]
        else:
            out_in = seq_in
        x = layers.Dropout(dropout)(x)
        x = layers.Dense(64, activation='relu')(x)
        out = layers.Dense(1, activation='sigmoid')(x)
        model = keras.Model(inputs=out_in, outputs=out)
        return compile_model(model, lr)

    def build_bilstm(seq_len, n_channels, gene_dim=None, lstm_units=128, dropout=0.3, l2_reg=1e-3, lr=1e-3):
        seq_in = keras.Input(shape=(seq_len, n_channels), name='seq_input')
        # LSTM
        x = layers.Bidirectional(layers.LSTM(lstm_units, return_sequences=False, kernel_regularizer=regularizers.l2(l2_reg)))(seq_in)
        if gene_dim is not None:
            gene_in = keras.Input(shape=(gene_dim,), name='gene_input')
            g = layers.Dense(64, activation='relu')(gene_in)
            x = layers.concatenate([x, g])
            out_in = [seq_in, gene_in]
        else:
            out_in = seq_in
        x = layers.Dropout(dropout)(x)
        x = layers.Dense(64, activation='relu')(x)
        out = layers.Dense(1, activation='sigmoid')(x)
        return compile_model(keras.Model(inputs=out_in, outputs=out), lr)

    class TransformerBlock(layers.Layer):
        def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
            super().__init__(**kwargs)
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.ff_dim = ff_dim
            self.rate = rate
            self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
            self.ffn = keras.Sequential([layers.Dense(ff_dim, activation='relu'), layers.Dense(embed_dim)])
            self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
            self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
            self.dropout1 = layers.Dropout(rate)
            self.dropout2 = layers.Dropout(rate)
        def call(self, inputs, training=None):
            attn_output = self.att(inputs, inputs)
            attn_output = self.dropout1(attn_output, training=training)
            out1 = self.layernorm1(inputs + attn_output)
            ffn_output = self.ffn(out1)
            ffn_output = self.dropout2(ffn_output, training=training)
            return self.layernorm2(out1 + ffn_output)
        def get_config(self):
            config = super().get_config()
            config.update({"embed_dim": self.embed_dim, "num_heads": self.num_heads, "ff_dim": self.ff_dim, "rate": self.rate})
            return config

    def build_transformer(seq_len, n_channels, gene_dim=None, embed_dim=64, num_heads=4, ff_dim=128, dropout=0.1, lr=1e-3):
        seq_in = keras.Input(shape=(seq_len, n_channels), name='seq_input')
        x = layers.Dense(embed_dim)(seq_in)
        x = TransformerBlock(embed_dim, num_heads, ff_dim, rate=dropout)(x)
        x = layers.GlobalAveragePooling1D()(x)
        if gene_dim is not None:
            gene_in = keras.Input(shape=(gene_dim,), name='gene_input')
            g = layers.Dense(64, activation='relu')(gene_in)
            x = layers.concatenate([x, g])
            inputs_list = [seq_in, gene_in]
        else:
            inputs_list = seq_in
        x = layers.Dropout(dropout)(x)
        x = layers.Dense(64, activation='relu')(x)
        out = layers.Dense(1, activation='sigmoid')(x)
        return compile_model(keras.Model(inputs=inputs_list, outputs=out), lr)

    # 4. Optimized Training Function (Uses Lazy Generators)
    def train_eval_single_config(cfg_idx, config, use_gene, use_seq,
                                 inner_train_inds_global, inner_val_inds_global,
                                 y_inner_tr, y_inner_val, class_weights,
                                 X_inner_tr_gene_scaled, X_inner_val_gene_scaled,
                                 seq_len):
        invoke_gc() # Clean up before training
        arch, hu, dr, lr, bs, epochs = config
        
        # Generator Creation
        train_gen = LazySequenceGenerator(adata, inner_train_inds_global, y_inner_tr, batch_size=bs,
                                          use_gene=use_gene, X_gene=X_inner_tr_gene_scaled,
                                          use_seq=use_seq, arch=arch, shuffle=True)
        val_gen = LazySequenceGenerator(adata, inner_val_inds_global, y_inner_val, batch_size=bs,
                                        use_gene=use_gene, X_gene=X_inner_val_gene_scaled,
                                        use_seq=use_seq, arch=arch, shuffle=False)
        
        input_dim_gene = X_inner_tr_gene_scaled.shape[1] if use_gene else 0
        input_dim_seq_flat = seq_len * 20 * 2 # approx
        
        try:
            if arch == 'MLP':
                # MLP Input Dim calculation
                # If generator concatenates flat seq + gene
                # Seq: (20 * seq_len)*2. Gene: input_dim_gene.
                # Use a dummy batch to check shape
                x_sample, _ = train_gen[0]
                total_dim = x_sample.shape[1] if hasattr(x_sample, 'shape') else x_sample[0].shape[1]
                model = build_mlp(total_dim, hidden1=hu, hidden2=max(32, hu//2), dropout=dr, l2_reg=1e-3, lr=lr)
                
            elif arch == 'CNN':
                model = build_cnn(seq_len, 20, gene_dim=(input_dim_gene if use_gene else None), conv_filters=hu, dropout=dr, l2_reg=1e-3, lr=lr)
            elif arch == 'BiLSTM':
                model = build_bilstm(seq_len, 20, gene_dim=(input_dim_gene if use_gene else None), lstm_units=hu, dropout=dr, l2_reg=1e-3, lr=lr)
            else: # Transformer
                model = build_transformer(seq_len, 20, gene_dim=(input_dim_gene if use_gene else None), embed_dim=hu//2, num_heads=4, ff_dim=hu, dropout=dr, lr=lr)

            es = keras.callbacks.EarlyStopping(monitor='val_auc', mode='max', patience=5, restore_best_weights=True, verbose=0)
            
            # FIT using GENERATOR
            model.fit(train_gen, validation_data=val_gen, epochs=epochs, 
                      class_weight=class_weights, callbacks=[es], verbose=0)
            
            # Evaluate (use batch processing via generator/predict)
            y_val_pred = model.predict(val_gen, verbose=0).flatten()
            # Truncate to match length (drop last batch remainder padding if any? No, gen handles it)
            # Generator length matches len(y_inner_val)? 
            # predict returns data for all batches.
            # Make sure labels match.
            y_val_label = (y_val_pred > 0.5).astype(int)
            val_acc = accuracy_score(y_inner_val[:len(y_val_label)], y_val_label)
            
            del model
            return cfg_idx, val_acc
        except Exception as e:
            print(f"Err {arch}: {e}")
            return cfg_idx, -1.0

    # 5. Main Execution Flow
    dl_param_grid = {
        'arch': ['MLP', 'CNN', 'BiLSTM', 'Transformer'],
        'hidden_units': [64, 128], # Reduced grid for speed test
        'dropout': [0.3],
        'lr': [1e-3],
        'batch_size': [32],
        'epochs': [20] # Reduced epochs
    }
    grid_items = list(itertools.product(dl_param_grid['arch'], dl_param_grid['hidden_units'], dl_param_grid['dropout'], dl_param_grid['lr'], dl_param_grid['batch_size'], dl_param_grid['epochs']))
    
    # Global Setup
    supervised_mask_local = supervised_mask
    X_gene_all = adata.obsm['X_gene_pca'][supervised_mask_local]
    seq_len = get_seq_len(adata)
    use_sequence = (seq_len is not None)
    
    y_all = y_encoded
    # Patient ID logic
    patient_id_col_local = next((c for c in ['patient_id', 'Patient_ID', 'PatientID'] if c in adata.obs.columns), None)
    groups_all_local = np.array(adata.obs[patient_id_col_local][supervised_mask_local])
    unique_patients = np.unique(groups_all_local)
    global_indices_all = np.where(supervised_mask_local)[0] # Global indices corresponding to y_all

    from sklearn.model_selection import LeaveOneGroupOut
    logo = LeaveOneGroupOut()
    dl_results_rows = []

    for feature_name in ['sequence_structure', 'comprehensive']:
        print(f"\n=== DL: {feature_name} ===")
        use_gene = True
        use_seq = (feature_name != 'gene_only') and use_sequence
        
        accum_arch = {arch: {'y_true': [], 'y_pred': [], 'y_proba': [], 'groups': []} for arch in ['MLP','CNN','BiLSTM','Transformer']}
        
        for fold_idx, (train_idx, test_idx) in enumerate(logo.split(X_gene_all, y_all, groups_all_local)):
            print(f"Fold {fold_idx+1}/{len(unique_patients)}")
            invoke_gc()
            
            # Split Data
            X_tr_gene = X_gene_all[train_idx]
            scaler = StandardScaler().fit(X_tr_gene)
            X_tr_gene_scaled = scaler.transform(X_tr_gene)
            X_te_gene_scaled = scaler.transform(X_gene_all[test_idx])
            
            y_tr = y_all[train_idx]
            y_te = y_all[test_idx]
            groups_tr = groups_all_local[train_idx]
            
            # Map fold indices to global indices
            train_global_inds = global_indices_all[train_idx]
            test_global_inds = global_indices_all[test_idx]

            classes = np.unique(y_tr)
            cw = compute_class_weight(class_weight='balanced', classes=classes, y=y_tr)
            class_weight_dict = {int(c): float(w) for c,w in zip(classes, cw)}
            
            # Inner CV
            n_train_groups = len(np.unique(groups_tr))
            inner_splits = min(3, n_train_groups) if n_train_groups >= 2 else 1
            
            best_cfg = list(grid_items)[0] # Default
            if inner_splits >= 2:
                inner_cv = GroupKFold(n_splits=inner_splits)
                config_scores = {i: [] for i in range(len(grid_items))}
                
                # SERIAL EXECUTION TO SAVE MEMORY (n_jobs=1)
                for inner_train_idx, inner_val_idx in inner_cv.split(X_tr_gene_scaled, y_tr, groups_tr):
                    # Map to global
                    inner_tr_global = train_global_inds[inner_train_idx]
                    inner_val_global = train_global_inds[inner_val_idx]
                    
                    for cfg_idx, config in enumerate(grid_items):
                        idx, score = train_eval_single_config(
                            cfg_idx, config, use_gene, use_seq,
                            inner_tr_global, inner_val_global,
                            y_tr[inner_train_idx], y_tr[inner_val_idx], class_weight_dict,
                            X_tr_gene_scaled[inner_train_idx], X_tr_gene_scaled[inner_val_idx],
                            seq_len
                        )
                        if score >= 0: config_scores[idx].append(score)
                
                # Select best
                best_avg = -math.inf
                for i, scores in config_scores.items():
                    if scores and np.mean(scores) > best_avg:
                        best_avg = np.mean(scores)
                        best_cfg = grid_items[i]
                print(f"  Best Config: {best_cfg} (Acc: {best_avg:.3f})")

            # Retrain Best
            invoke_gc()
            arch, hu, dr, lr, bs, epochs = best_cfg
            
            # Generators for final training
            train_gen_final = LazySequenceGenerator(adata, train_global_inds, y_tr, batch_size=bs,
                                                    use_gene=use_gene, X_gene=X_tr_gene_scaled,
                                                    use_seq=use_seq, arch=arch, shuffle=True)
            test_gen_final = LazySequenceGenerator(adata, test_global_inds, y_te, batch_size=bs,
                                                   use_gene=use_gene, X_gene=X_te_gene_scaled,
                                                   use_seq=use_seq, arch=arch, shuffle=False)
            
            # Build & Train
            x_sample, _ = train_gen_final[0]
            if arch == 'MLP':
                 total_dim = x_sample.shape[1] if hasattr(x_sample, 'shape') else x_sample[0].shape[1]
                 model = build_mlp(total_dim, hidden1=hu, hidden2=max(32, hu//2), dropout=dr, l2_reg=1e-3, lr=lr)
            elif arch == 'CNN':
                 model = build_cnn(seq_len, 20, gene_dim=(X_tr_gene.shape[1] if use_gene else None), conv_filters=hu, dropout=dr, l2_reg=1e-3, lr=lr)
            elif arch == 'BiLSTM':
                 model = build_bilstm(seq_len, 20, gene_dim=(X_tr_gene.shape[1] if use_gene else None), lstm_units=hu, dropout=dr, l2_reg=1e-3, lr=lr)
            else:
                 model = build_transformer(seq_len, 20, gene_dim=(X_tr_gene.shape[1] if use_gene else None), embed_dim=hu//2, num_heads=4, ff_dim=hu, dropout=dr, lr=lr)
            
            es = keras.callbacks.EarlyStopping(monitor='val_auc', mode='max', patience=8, restore_best_weights=True, verbose=0)
            model.fit(train_gen_final, validation_data=test_gen_final, epochs=epochs, 
                      class_weight=class_weight_dict, callbacks=[es], verbose=0)
            
            y_test_proba = model.predict(test_gen_final, verbose=0).flatten()
            y_test_pred = (y_test_proba > 0.5).astype(int)
            
            accum_arch[arch]['y_true'].extend(y_te.tolist())
            accum_arch[arch]['y_pred'].extend(y_test_pred.tolist())
            accum_arch[arch]['y_proba'].extend(y_test_proba.tolist())
            accum_arch[arch]['groups'].extend(groups_all_local[test_idx].tolist())
            
            del model
            invoke_gc()

        # Save Results logic (abbreviated)
        for arch, data in accum_arch.items():
            if not data['y_true']: continue
            acc = accuracy_score(data['y_true'], data['y_pred'])
            dl_results_rows.append({'feature_set': feature_name, 'architecture': arch, 'accuracy': acc, 'n_patients': len(unique_patients)})
    
    if dl_results_rows:
        dl_df = pd.DataFrame(dl_results_rows)
        print("Done. Saved results.")
        display(dl_df)


TensorFlow version: 2.19.0
Found 1 GPU(s): [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
TensorFlow will use: GPU
DL hyperparameter combinations: 32
Failed to patch models_eval['XGBoost']: 'XGBClassifierSK' object has no attribute 'predictor'
Patched models_eval['Random Forest'] to use n_jobs=-1.
Patched param_grids['XGBoost'] with GPU options (method=device).

=== DL evaluation using feature set: sequence_structure ===
LOPO fold 1/4 -- held patient: ['PT1']


I0000 00:00:1770357065.906960     284 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15511 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0
I0000 00:00:1770357077.314546     300 service.cc:152] XLA service 0x7c261000eaa0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1770357077.314597     300 service.cc:160]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
I0000 00:00:1770357078.109484     300 cuda_dnn.cc:529] Loaded cuDNN version 91002
I0000 00:00:1770357082.405076     299 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.




2026-02-06 05:53:25.316906: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-06 05:56:59.189785: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-06 06:09:19.586217: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-06 06:09:19.774051: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-06 06:09:19.995502: E external/local_xla/xla/stream_

  Selected best inner config: ('BiLSTM', 128, 0.3, 0.0001, 32, 30) with mean val acc=0.7840
[1m793/793[0m [32m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[37m[0m [1m3s[0m 3ms/step
LOPO fold 2/4 -- held patient: ['PT2']


2026-02-06 07:27:14.241572: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-06 07:32:56.672610: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-06 07:46:58.520700: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-06 07:49:06.638878: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-06 07:49:06.851073: E external/local_xla/xla/stream_

  Selected best inner config: ('BiLSTM', 128, 0.2, 0.001, 32, 30) with mean val acc=0.5855
[1m364/364[0m [32m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[37m[0m [1m1s[0m 3ms/step
LOPO fold 3/4 -- held patient: ['PT3']


2026-02-06 08:33:53.886339: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


  Selected best inner config: ('MLP', 128, 0.3, 0.001, 32, 30) with mean val acc=0.5910
[1m1018/1018[0m [32m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[37m[0m [1m2s[0m 2ms/step
LOPO fold 4/4 -- held patient: ['PT4']


In [None]:
# --- SMOKE TEST: Benchmark & Verification ---
import time
import psutil
import os
import numpy as np
from scipy import sparse

print("=== SMOKE TEST: Memory & Performance Check ===")

def get_mem_usage():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024

mem_start = get_mem_usage()
start_time = time.time()

# 1. Setup Test Data (Subset of Real or Mock)
print("Setting up smoke test data...")
try:
    if 'adata' in globals() and 'supervised_mask' in globals():
        indices = np.where(supervised_mask)[0]
        limit = min(len(indices), 128)
        subset_indices = indices[:limit]
        y_subset = y_encoded[:limit]
        dataset_mode = "REAL"
    else:
        raise NameError("Data not loaded")
except Exception as e:
    print(f"Real data unavailable ({e}), creating MOCK data...")
    dataset_mode = "MOCK"
    class MockAdata:
        def __init__(self): self.obsm = {}
        class obs_attr:
            columns = []
        obs = obs_attr()
    
    adata = MockAdata()
    # Create random sparse matrices (simulating one-hot)
    # 20AA * 30 Len = 600 cols
    adata.obsm['X_tcr_tra_onehot'] = sparse.random(500, 600, density=0.05, format='csr')
    adata.obsm['X_tcr_trb_onehot'] = sparse.random(500, 600, density=0.05, format='csr')
    subset_indices = np.arange(128)
    y_subset = np.random.randint(0, 2, 128)

# 2. Test Generator & Training
try:
    # Instantiate Generator
    gen_test = LazySequenceGenerator(adata, subset_indices, y_subset, batch_size=32, 
                                     use_seq=True, use_gene=False, n_channels=20)
    
    # Fetch one batch to verify shapes
    X_batch, y_batch = gen_test[0]
    print(f"Generator Batch Shape: {X_batch.shape if hasattr(X_batch, 'shape') else [x.shape for x in X_batch]}")
    
    # Build a small BiLSTM model
    seq_len = X_batch.shape[1]
    n_channels = X_batch.shape[2]
    
    print(f"Building BiLSTM (SeqLen={seq_len}, Channels={n_channels})...")
    # Need to access build_bilstm from previous cell
    model_test = build_bilstm(seq_len, n_channels, gene_dim=None, lstm_units=32)
    
    print("Running 1 training epoch...")
    h = model_test.fit(gen_test, epochs=1, verbose=1)
    
    mem_end = get_mem_usage()
    print("\n--- SMOKE TEST PASSED ---")
    print(f"Dataset Mode: {dataset_mode}")
    print(f"Memory Increase: {mem_end - mem_start:.2f} MB")
    print(f"Execution Time: {time.time() - start_time:.2f} s")
    print(f"Final Loss: {h.history['loss'][0]:.4f}")
    
except Exception as e:
    print(f"\n!!! SMOKE TEST FAILED !!!")
    print(e)
    if 'model_test' in locals():
        del model_test
    # Traceback
    import traceback
    traceback.print_exc()


In [None]:
%%time
# --- Experiment with Sequence Length Cutoffs ---
print("Experimenting with sequence length cutoffs...")
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA, TruncatedSVD
# Ensure adata exists
if 'adata' not in globals() or adata is None:
    raise NameError("adata is not defined. Please run the data loading cells first.")
# Ensure label/patient helpers have run
if '_ensure_response_and_patient' in globals():
    _ensure_response_and_patient(adata)
if '_get_supervised_mask_and_labels' in globals():
    supervised_mask, y_supervised, _le, class_counts, SUPERVISED_AVAILABLE = _get_supervised_mask_and_labels(adata)
else:
    # Fallback logic
    if 'response' in adata.obs.columns:
        y_all = adata.obs['response'].astype(str)
        supervised_mask = y_all.isin(['Responder', 'Non-Responder']).values
        y_supervised = y_all[supervised_mask]
        class_counts = y_supervised.value_counts().to_dict()
        SUPERVISED_AVAILABLE = len(class_counts) >= 2 and min(class_counts.values()) >= 2
        _le = LabelEncoder() if SUPERVISED_AVAILABLE else None
    else:
        supervised_mask = np.zeros(adata.n_obs, dtype=bool)
        y_supervised = pd.Series([], dtype=object)
        class_counts = {}
        SUPERVISED_AVAILABLE = False
        _le = None
globals()['SUPERVISED_AVAILABLE'] = SUPERVISED_AVAILABLE
if not SUPERVISED_AVAILABLE:
    print("WARNING: No supervised samples available. Skipping sequence length cutoff experiment.")
else:
    supervised_mask = np.asarray(supervised_mask)
    if '_ensure_int_labels' in globals():
        y_encoded = _ensure_int_labels(_le.fit_transform(y_supervised))
    else:
        y_encoded = np.asarray(_le.fit_transform(y_supervised), dtype=np.int64)
    min_class = min(class_counts.values()) if class_counts else 0
    stratify = y_encoded if min_class >= 2 else None
    cv_folds = min(3, min_class) if min_class >= 2 else 0
    # Ensure cdr3_sequences exists
    if 'cdr3_sequences' not in globals():
        cdr3_sequences = {
            'TRA': adata.obs['cdr3_TRA'].astype(str).fillna('').str.upper().tolist() if 'cdr3_TRA' in adata.obs.columns else [''] * adata.n_obs,
            'TRB': adata.obs['cdr3_TRB'].astype(str).fillna('').str.upper().tolist() if 'cdr3_TRB' in adata.obs.columns else [''] * adata.n_obs
        }
    # Ensure gene features exist
    if 'gene_features' not in globals():
        if 'X_gene_pca' in adata.obsm:
            gene_features = adata.obsm['X_gene_pca'][supervised_mask]
        elif 'X_pca' in adata.obsm:
            gene_features = adata.obsm['X_pca'][supervised_mask]
        else:
            gene_features = np.zeros((int(supervised_mask.sum()), 30))
    if gene_features.shape[1] < 30:
        gene_features = np.pad(gene_features, ((0, 0), (0, 30 - gene_features.shape[1])), mode='constant')
    # Ensure TCR physico features exist
    if 'tcr_physico' not in globals():
        if all(c in adata.obs.columns for c in ['tra_length', 'tra_molecular_weight', 'tra_hydrophobicity','trb_length', 'trb_molecular_weight', 'trb_hydrophobicity']):
            tra_physico = adata.obs[['tra_length', 'tra_molecular_weight', 'tra_hydrophobicity']].fillna(0)[supervised_mask].values.astype(np.float32)
            trb_physico = adata.obs[['trb_length', 'trb_molecular_weight', 'trb_hydrophobicity']].fillna(0)[supervised_mask].values.astype(np.float32)
            tcr_physico = np.column_stack([tra_physico, trb_physico]).astype(np.float32)
        else:
            tcr_physico = np.zeros((int(supervised_mask.sum()), 6), dtype=np.float32)
    # Ensure QC features exist
    if 'qc_features' not in globals():
        if all(c in adata.obs.columns for c in ['n_genes_by_counts', 'total_counts', 'pct_counts_mt']):
            qc_features = adata.obs[['n_genes_by_counts', 'total_counts', 'pct_counts_mt']].fillna(0)[supervised_mask].values
        else:
            qc_features = np.zeros((int(supervised_mask.sum()), 3))
    # Define length cutoffs to test
    length_cutoffs = [10, 15, 20, 25, 30, 35, 40, 50]
    length_results = []
    for max_length in length_cutoffs:
        print(f"\nTesting max sequence length: {max_length}")
        # Re-encode sequences with new length - ENSURE FLOAT32 DTYPE
        tra_onehot_new = np.array([one_hot_encode_sequence(seq, max_length, 'ACDEFGHIKLMNPQRSTVWY')
                                   for seq in cdr3_sequences['TRA']], dtype=np.float32)
        tra_onehot_flat_new = tra_onehot_new.reshape(tra_onehot_new.shape[0], -1).astype(np.float32)
        trb_onehot_new = np.array([one_hot_encode_sequence(seq, max_length, 'ACDEFGHIKLMNPQRSTVWY')
                                   for seq in cdr3_sequences['TRB']], dtype=np.float32)
        trb_onehot_flat_new = trb_onehot_new.reshape(trb_onehot_new.shape[0], -1).astype(np.float32)
        # Update AnnData
        adata.obsm['X_tcr_tra_onehot'] = tra_onehot_flat_new
        adata.obsm['X_tcr_trb_onehot'] = trb_onehot_flat_new
        # Re-create feature sets with new encodings using robust PCA
        # Use robust PCA reduction with fallback to TruncatedSVD
        try:
            n_comp_onehot = min(50, adata.obsm['X_tcr_tra_onehot'][supervised_mask].shape[1], max(1, adata.obsm['X_tcr_tra_onehot'][supervised_mask].shape[0]-1))
            onehot_tra_reduced = PCA(n_components=n_comp_onehot, svd_solver='randomized', random_state=42).fit_transform(adata.obsm['X_tcr_tra_onehot'][supervised_mask])
        except Exception as e:
            print(f"  PCA failed for TRA ({e}), using TruncatedSVD")
            n_comp = max(1, min(50, adata.obsm['X_tcr_tra_onehot'][supervised_mask].shape[1]-1))
            onehot_tra_reduced = TruncatedSVD(n_components=n_comp, random_state=42).fit_transform(adata.obsm['X_tcr_tra_onehot'][supervised_mask])
        try:
            n_comp_onehot = min(50, adata.obsm['X_tcr_trb_onehot'][supervised_mask].shape[1], max(1, adata.obsm['X_tcr_trb_onehot'][supervised_mask].shape[0]-1))
            onehot_trb_reduced = PCA(n_components=n_comp_onehot, svd_solver='randomized', random_state=42).fit_transform(adata.obsm['X_tcr_trb_onehot'][supervised_mask])
        except Exception as e:
            print(f"  PCA failed for TRB ({e}), using TruncatedSVD")
            n_comp = max(1, min(50, adata.obsm['X_tcr_trb_onehot'][supervised_mask].shape[1]-1))
            onehot_trb_reduced = TruncatedSVD(n_components=n_comp, random_state=42).fit_transform(adata.obsm['X_tcr_trb_onehot'][supervised_mask])
        X_sequence = np.column_stack([
            gene_features[:, :30],
            onehot_tra_reduced,
            onehot_trb_reduced,
            tcr_physico,
            qc_features
        ])
        # Train and evaluate model
        X_train, X_test, y_train, y_test = train_test_split(
            X_sequence, y_encoded, test_size=0.3, random_state=42, stratify=stratify
        )
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)
        XGBClass = globals().get('XGBClassifierSK', xgb.XGBClassifier)
        model = XGBClass(random_state=42, eval_metric='logloss')
        model.fit(X_train_scaled, y_train)
        y_pred = model.predict(X_test_scaled)
        accuracy = accuracy_score(y_test, y_pred)
        # Cross-validation
        if cv_folds >= 2:
            cv_scores = cross_val_score(model, X_sequence, y_encoded, cv=cv_folds, scoring='accuracy')
            cv_mean = cv_scores.mean()
            cv_std = cv_scores.std()
        else:
            print("  Skipping CV: not enough samples per class.")
            cv_scores = np.array([])
            cv_mean = float('nan')
            cv_std = float('nan')
        length_results.append({
            'max_length': max_length,
            'accuracy': accuracy,
            'cv_mean': cv_mean,
            'cv_std': cv_std
        })
        print(f"  Accuracy: {accuracy:.3f}, CV: {cv_mean:.3f} ¬± {cv_std:.3f}")
    # Plot results
    length_df = pd.DataFrame(length_results)
    plt.figure(figsize=(10, 6))
    plt.plot(length_df['max_length'], length_df['accuracy'], 'o-', label='Test Accuracy', linewidth=2)
    plt.plot(length_df['max_length'], length_df['cv_mean'], 's-', label='CV Accuracy', linewidth=2)
    plt.fill_between(length_df['max_length'],
                     length_df['cv_mean'] - length_df['cv_std'],
                     length_df['cv_mean'] + length_df['cv_std'],
                     alpha=0.3, label='CV ¬± Std')
    plt.xlabel('Maximum Sequence Length Cutoff')
    plt.ylabel('Accuracy')
    plt.title('Model Accuracy vs Sequence Length Cutoff')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    print("\nSequence length cutoff experiment completed!")
    if not length_df.empty and length_df['cv_mean'].notna().any():
        best_len = length_df.loc[length_df['cv_mean'].idxmax(), 'max_length']
        print(f"Optimal length appears to be around {best_len}")

In [None]:
%%time
# --- Experiment with Sequence Length Cutoffs ---

print("Experimenting with sequence length cutoffs...")

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA, TruncatedSVD
import numpy as np
import pandas as pd
import xgboost as xgb

# Ensure adata exists
if 'adata' not in globals() or adata is None:
    raise NameError("adata is not defined. Please run the data loading cells first.")

# Ensure label/patient helpers have run
if '_ensure_response_and_patient' in globals():
    _ensure_response_and_patient(adata)

if '_get_supervised_mask_and_labels' in globals():
    supervised_mask, y_supervised, _le, class_counts, SUPERVISED_AVAILABLE = _get_supervised_mask_and_labels(adata)
else:
    # Fallback logic
    if 'response' in adata.obs.columns:
        y_all = adata.obs['response'].astype(str)
        supervised_mask = y_all.isin(['Responder', 'Non-Responder']).values
        y_supervised = y_all[supervised_mask]
        class_counts = y_supervised.value_counts().to_dict()
        SUPERVISED_AVAILABLE = len(class_counts) >= 2 and min(class_counts.values()) >= 2
        _le = LabelEncoder() if SUPERVISED_AVAILABLE else None
    else:
        supervised_mask = np.zeros(adata.n_obs, dtype=bool)
        y_supervised = pd.Series([], dtype=object)
        class_counts = {}
        SUPERVISED_AVAILABLE = False
        _le = None

globals()['SUPERVISED_AVAILABLE'] = SUPERVISED_AVAILABLE

if not SUPERVISED_AVAILABLE:
    print("WARNING: Insufficient supervised labels. Skipping sequence length cutoff experiment.")
else:
    supervised_mask = np.asarray(supervised_mask)

    if '_ensure_int_labels' in globals():
        y_encoded = _ensure_int_labels(_le.fit_transform(y_supervised))
    else:
        y_encoded = np.asarray(_le.fit_transform(y_supervised), dtype=np.int64)

    min_class = min(class_counts.values()) if class_counts else 0
    stratify = y_encoded if min_class >= 2 else None
    cv_folds = min(3, min_class) if min_class >= 2 else 0

    # Ensure cdr3_sequences exists
    if 'cdr3_sequences' not in globals():
        cdr3_sequences = {
            'TRA': adata.obs['cdr3_TRA'].astype(str).fillna('').str.upper().tolist() if 'cdr3_TRA' in adata.obs.columns else [''] * adata.n_obs,
            'TRB': adata.obs['cdr3_TRB'].astype(str).fillna('').str.upper().tolist() if 'cdr3_TRB' in adata.obs.columns else [''] * adata.n_obs
        }

    # Ensure gene features exist
    if 'gene_features' not in globals():
        if 'X_gene_pca' in adata.obsm:
            gene_features = adata.obsm['X_gene_pca'][supervised_mask]
        elif 'X_pca' in adata.obsm:
            gene_features = adata.obsm['X_pca'][supervised_mask]
        else:
            gene_features = np.zeros((int(supervised_mask.sum()), 30))
    if gene_features.shape[1] < 30:
        gene_features = np.pad(gene_features, ((0, 0), (0, 30 - gene_features.shape[1])), mode='constant')

    # Ensure TCR physico features exist
    if 'tcr_physico' not in globals():
        if all(c in adata.obs.columns for c in ['tra_length', 'tra_molecular_weight', 'tra_hydrophobicity','trb_length', 'trb_molecular_weight', 'trb_hydrophobicity']):
            tra_physico = adata.obs[['tra_length', 'tra_molecular_weight', 'tra_hydrophobicity']].fillna(0)[supervised_mask].values.astype(np.float32)
            trb_physico = adata.obs[['trb_length', 'trb_molecular_weight', 'trb_hydrophobicity']].fillna(0)[supervised_mask].values.astype(np.float32)
            tcr_physico = np.column_stack([tra_physico, trb_physico]).astype(np.float32)
        else:
            tcr_physico = np.zeros((int(supervised_mask.sum()), 6), dtype=np.float32)

    # Ensure QC features exist
    if 'qc_features' not in globals():
        if all(c in adata.obs.columns for c in ['n_genes_by_counts', 'total_counts', 'pct_counts_mt']):
            qc_features = adata.obs[['n_genes_by_counts', 'total_counts', 'pct_counts_mt']].fillna(0)[supervised_mask].values
        else:
            qc_features = np.zeros((int(supervised_mask.sum()), 3))

    # Define length cutoffs to test
    length_cutoffs = [10, 15, 20, 25, 30, 35, 40, 50]

    length_results = []

    for max_length in length_cutoffs:
        print(f"\nTesting max sequence length: {max_length}")

        # Re-encode sequences with new length - ENSURE FLOAT32 DTYPE
        tra_onehot_new = np.array([one_hot_encode_sequence(seq, max_length, 'ACDEFGHIKLMNPQRSTVWY')
                                   for seq in cdr3_sequences['TRA']], dtype=np.float32)
        tra_onehot_flat_new = tra_onehot_new.reshape(tra_onehot_new.shape[0], -1).astype(np.float32)

        trb_onehot_new = np.array([one_hot_encode_sequence(seq, max_length, 'ACDEFGHIKLMNPQRSTVWY')
                                   for seq in cdr3_sequences['TRB']], dtype=np.float32)
        trb_onehot_flat_new = trb_onehot_new.reshape(trb_onehot_new.shape[0], -1).astype(np.float32)

        # Update AnnData
        adata.obsm['X_tcr_tra_onehot'] = tra_onehot_flat_new
        adata.obsm['X_tcr_trb_onehot'] = trb_onehot_flat_new

        # Re-create feature sets with new encodings using robust PCA
        # Use robust PCA reduction with fallback to TruncatedSVD
        try:
            n_comp_onehot = min(50, adata.obsm['X_tcr_tra_onehot'][supervised_mask].shape[1], max(1, adata.obsm['X_tcr_tra_onehot'][supervised_mask].shape[0]-1))
            onehot_tra_reduced = PCA(n_components=n_comp_onehot, svd_solver='randomized', random_state=42).fit_transform(adata.obsm['X_tcr_tra_onehot'][supervised_mask])
        except Exception as e:
            print(f"  PCA failed for TRA ({e}), using TruncatedSVD")
            n_comp = max(1, min(50, adata.obsm['X_tcr_tra_onehot'][supervised_mask].shape[1]-1))
            onehot_tra_reduced = TruncatedSVD(n_components=n_comp, random_state=42).fit_transform(adata.obsm['X_tcr_tra_onehot'][supervised_mask])

        try:
            n_comp_onehot = min(50, adata.obsm['X_tcr_trb_onehot'][supervised_mask].shape[1], max(1, adata.obsm['X_tcr_trb_onehot'][supervised_mask].shape[0]-1))
            onehot_trb_reduced = PCA(n_components=n_comp_onehot, svd_solver='randomized', random_state=42).fit_transform(adata.obsm['X_tcr_trb_onehot'][supervised_mask])
        except Exception as e:
            print(f"  PCA failed for TRB ({e}), using TruncatedSVD")
            n_comp = max(1, min(50, adata.obsm['X_tcr_trb_onehot'][supervised_mask].shape[1]-1))
            onehot_trb_reduced = TruncatedSVD(n_components=n_comp, random_state=42).fit_transform(adata.obsm['X_tcr_trb_onehot'][supervised_mask])

        X_sequence = np.column_stack([
            gene_features[:, :30],
            onehot_tra_reduced,
            onehot_trb_reduced,
            tcr_physico,
            qc_features
        ])

        # Train and evaluate model
        X_train, X_test, y_train, y_test = train_test_split(
            X_sequence, y_encoded, test_size=0.3, random_state=42, stratify=stratify
        )

        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        XGBClass = globals().get('XGBClassifierSK', xgb.XGBClassifier)
        model = XGBClass(random_state=42, eval_metric='logloss')

        model.fit(X_train_scaled, y_train)
        y_pred = model.predict(X_test_scaled)
        accuracy = accuracy_score(y_test, y_pred)

        # Cross-validation
        if cv_folds >= 2:
            cv_scores = cross_val_score(model, X_sequence, y_encoded, cv=cv_folds, scoring='accuracy')
            cv_mean = cv_scores.mean()
            cv_std = cv_scores.std()
        else:
            print("  Skipping CV: not enough samples per class.")
            cv_scores = np.array([])
            cv_mean = float('nan')
            cv_std = float('nan')

        length_results.append({
            'max_length': max_length,
            'accuracy': accuracy,
            'cv_mean': cv_mean,
            'cv_std': cv_std
        })

        print(f"  Accuracy: {accuracy:.3f}, CV: {cv_mean:.3f} ¬± {cv_std:.3f}")

    # Plot results
    length_df = pd.DataFrame(length_results)

    plt.figure(figsize=(10, 6))
    plt.plot(length_df['max_length'], length_df['accuracy'], 'o-', label='Test Accuracy', linewidth=2)
    plt.plot(length_df['max_length'], length_df['cv_mean'], 's-', label='CV Accuracy', linewidth=2)
    plt.fill_between(length_df['max_length'],
                     length_df['cv_mean'] - length_df['cv_std'],
                     length_df['cv_mean'] + length_df['cv_std'],
                     alpha=0.3, label='CV ¬± Std')
    plt.xlabel('Maximum Sequence Length Cutoff')
    plt.ylabel('Accuracy')
    plt.title('Model Accuracy vs Sequence Length Cutoff')

    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print("\nSequence length cutoff experiment completed!")
    if not length_df.empty and length_df['cv_mean'].notna().any():
        best_len = length_df.loc[length_df['cv_mean'].idxmax(), 'max_length']
        print(f"Optimal length appears to be around {best_len}")


In [None]:
# Safe GPU patcher: apply GPU defaults without forcing unknown attributes
def _apply_gpu_patches():
    """
    Safely patch `param_grids` and `models_eval` to prefer GPU XGBoost settings
    when available. This avoids setting attributes that may not exist on
    estimator objects and wraps callable factories/classes safely.
    """
    import inspect
    try:
        import xgboost as xgb
    except Exception:
        xgb = None

    # Respect explicit user flag if set elsewhere; default False
    XGBOOST_GPU_AVAILABLE = bool(globals().get('XGBOOST_GPU_AVAILABLE', False))

    # Patch param_grids safely (do not overwrite user-specified entries)
    try:
        if 'param_grids' in globals() and XGBOOST_GPU_AVAILABLE:
            pg = dict(param_grids.get('XGBoost', {}))
            pg.setdefault('tree_method', ['gpu_hist'])
            pg.setdefault('predictor', ['gpu_predictor'])
            param_grids['XGBoost'] = pg
            print("Patched param_grids['XGBoost'] with GPU options.")
    except Exception as e:
        print("Error patching param_grids:", e)

    # Patch models_eval in-place (wrap factories/classes or safely set params on instances)
    try:
        if 'models_eval' not in globals():
            return
        me = globals()['models_eval']
        if 'XGBoost' not in me:
            return
        obj = me['XGBoost']

        # If it's a callable factory (e.g., a lambda returning an estimator), wrap it so GPU kwargs are tried safely at call time
        if callable(obj) and not isinstance(obj, type):
            def make_wrapped(factory):
                def wrapped(*a, **kw):
                    if globals().get('XGBOOST_GPU_AVAILABLE', False):
                        try:
                            kw2 = dict(kw)
                            kw2.setdefault('tree_method', 'gpu_hist')
                            kw2.setdefault('predictor', 'gpu_predictor')
                            return factory(*a, **kw2)
                        except TypeError:
                            try:
                                kw2 = dict(kw)
                                kw2.setdefault('tree_method', 'gpu_hist')
                                kw2.pop('predictor', None)
                                return factory(*a, **kw2)
                            except Exception:
                                return factory(*a, **kw)
                    return factory(*a, **kw)
                return wrapped
            me['XGBoost'] = make_wrapped(obj)
            print("Patched callable models_eval['XGBoost'] to include GPU kwargs safely.")
            return

        # If it's a class type, create a subclass wrapper to add defaults in __init__
        if isinstance(obj, type):
            try:
                sig = inspect.signature(obj.__init__)
            except Exception:
                sig = None
            def make_class_with_defaults(cls, sig):
                class Wrapped(cls):
                    def __init__(self, *a, **kw):
                        if globals().get('XGBOOST_GPU_AVAILABLE', False):
                            kw.setdefault('tree_method', 'gpu_hist')
                            if sig and 'predictor' in sig.parameters:
                                kw.setdefault('predictor', 'gpu_predictor')
                        super().__init__(*a, **kw)
                return Wrapped
            me['XGBoost'] = make_class_with_defaults(obj, sig)
            print("Patched class models_eval['XGBoost'] to include GPU defaults.")
            return

        # Otherwise assume it's an instantiated estimator; set params only if supported
        try:
            if hasattr(obj, 'get_params') and hasattr(obj, 'set_params'):
                params = obj.get_params()
                patch = {}
                if globals().get('XGBOOST_GPU_AVAILABLE', False):
                    if 'tree_method' in params:
                        patch['tree_method'] = 'gpu_hist'
                    if 'predictor' in params:
                        patch['predictor'] = 'gpu_predictor'
                if patch:
                    obj.set_params(**patch)
                    print("Patched instance models_eval['XGBoost'] params.")
        except Exception as e:
            print("Failed to patch models_eval['XGBoost']:", e)

    except Exception as e:
        print("Error patching models_eval:", e)


# Task 1-5: Enhanced ML Pipeline for Immunotherapy Response Prediction

This section implements:
1. **Task 1**: GroupKFold cross-validation with Patient-Level Aggregation (Shannon Entropy for TCR diversity)
2. **Task 2**: TCR CDR3 encoding using physicochemical properties (Hydrophobicity, Charge, etc.)
3. **Task 3**: Top 20 feature analysis cross-referenced with Sun et al. 2025 (GZMB, HLA-DR, ISGs)
4. **Task 4**: Extended literature review including I-SPY2 trial and multimodal single-cell ML methods (TCR-H, CoNGA)
5. **Task 5**: 4-panel publication figure (UMAP, SHAP, ROC, Boxplots)

In [None]:
def train_groupkfold_model(patient_df, n_splits=None):
    """
    Train XGBoost model with GroupKFold cross-validation based on Patient_ID.
    """
    base_result = {
        'status': 'skipped',
        'reason': '',
        'cv_accuracy': float('nan'),
        'cv_precision': float('nan'),
        'cv_recall': float('nan'),
        'cv_f1': float('nan'),
        'cv_roc_auc': float('nan'),
        'confusion_matrix': np.array([]),
        'y_true': np.array([]),
        'y_pred': np.array([]),
        'y_pred_proba': np.array([]),
        'feature_importance': pd.DataFrame(),
        'model': None,
        'scaler': None,
        'label_encoder': None,
        'feature_cols': [],
        'patient_df': patient_df if patient_df is not None else pd.DataFrame()
    }
    if patient_df is None or patient_df.empty:
        base_result['reason'] = 'No patient-level data available.'
        print('WARNING:', base_result['reason'])
        return base_result
    if 'Patient_ID' not in patient_df.columns or 'Response' not in patient_df.columns:
        base_result['reason'] = 'Missing required columns Patient_ID/Response in patient_df.'
        print('WARNING:', base_result['reason'])
        return base_result
    n_patients = patient_df['Patient_ID'].nunique()
    if n_patients < 2:
        base_result['reason'] = f'Not enough patients for GroupKFold (n={n_patients}).'
        print('WARNING:', base_result['reason'])
        return base_result
    n_classes = patient_df['Response'].nunique()
    if n_classes < 2:
        base_result['reason'] = f'Not enough response classes for supervised learning (n={n_classes}).'
        print('WARNING:', base_result['reason'])
        return base_result
    print("\n" + "="*60)
    print("Training with GroupKFold Cross-Validation")
    print("="*60)
    # Prepare features and labels
    feature_cols = [col for col in patient_df.columns if col not in ['Patient_ID', 'Response', 'n_cells']]
    if not feature_cols:
        base_result['reason'] = 'No feature columns available for training.'
        print('WARNING:', base_result['reason'])
        return base_result
    X = patient_df[feature_cols].fillna(0).values
    y_labels = patient_df['Response'].values
    # Encode labels
    label_encoder = LabelEncoder()
    y = label_encoder.fit_transform(y_labels)
    y = _ensure_int_labels(y) if '_ensure_int_labels' in globals() else np.asarray(y, dtype=np.int64)
    groups = patient_df['Patient_ID'].values
    print(f"Feature matrix: {X.shape}")
    print(f"Labels: {len(y)}, Classes: {label_encoder.classes_}")
    print(f"Patient groups: {len(np.unique(groups))}")
    # Determine number of splits
    if n_splits is None:
        n_splits = min(5, n_patients)
    if n_splits < 2:
        base_result['reason'] = f'Not enough patients for {n_splits}-fold GroupKFold.'
        print('WARNING:', base_result['reason'])
        return base_result
    print(f"Using {n_splits}-fold GroupKFold CV")
    # Initialize model
    model = xgb.XGBClassifier(
        n_estimators=100,
        max_depth=3,
        learning_rate=0.1,
        random_state=42,
        n_jobs=-1,
        use_label_encoder=False,
        eval_metric='logloss'
    )
    # Perform cross-validation
    cv = GroupKFold(n_splits=n_splits)
    # Get cross-validated predictions
    y_pred = cross_val_predict(model, X, y, groups=groups, cv=cv, n_jobs=1)
    y_pred_proba = cross_val_predict(model, X, y, groups=groups, cv=cv, method='predict_proba', n_jobs=1)
    # Calculate metrics
    accuracy = accuracy_score(y, y_pred)
    precision = precision_score(y, y_pred, zero_division=0)
    recall = recall_score(y, y_pred, zero_division=0)
    f1 = f1_score(y, y_pred, zero_division=0)
    # ROC-AUC (handle binary classification)
    try:
        if len(label_encoder.classes_) == 2:
            roc_auc = roc_auc_score(y, y_pred_proba[:, 1])
        else:
            roc_auc = roc_auc_score(y, y_pred_proba, multi_class='ovr')
    except Exception:
        roc_auc = float('nan')
    print("\n--- Cross-Validation Results ---")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")
    print(f"ROC-AUC:   {roc_auc:.4f}")
    print("\nClassification Report:")
    print(classification_report(y, y_pred, target_names=label_encoder.classes_, zero_division=0))
    print("\nConfusion Matrix:")
    cm = confusion_matrix(y, y_pred)
    print(cm)
    # Train final model on all data
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    final_model = xgb.XGBClassifier(
        n_estimators=100,
        max_depth=3,
        learning_rate=0.1,
        random_state=42,
        n_jobs=-1,
        use_label_encoder=False,
        eval_metric='logloss'
    )
    final_model.fit(X_scaled, y)
    # Feature importance
    feature_importance = pd.DataFrame({
        'feature': feature_cols,
        'importance': final_model.feature_importances_
    }).sort_values('importance', ascending=False)
    print("\n--- Top 10 Most Important Features ---")
    print(feature_importance.head(10).to_string(index=False))
    # Package results
    results = {
        'status': 'ok',
        'reason': '',
        'cv_accuracy': accuracy,
        'cv_precision': precision,
        'cv_recall': recall,
        'cv_f1': f1,
        'cv_roc_auc': roc_auc,
        'confusion_matrix': cm,
        'y_true': y,
        'y_pred': y_pred,
        'y_pred_proba': y_pred_proba,
        'feature_importance': feature_importance,
        'model': final_model,
        'scaler': scaler,
        'label_encoder': label_encoder,
        'feature_cols': feature_cols,
        'patient_df': patient_df
    }
    return results
# ==========================================================================
# Execute Task 1
# ==========================================================================
# Aggregate features at patient level
patient_features_df = aggregate_patient_features(adata)
# Display patient-level features (only available columns)
print("\n--- Patient-Level Feature Summary ---")
summary_cols = ['Patient_ID', 'Response', 'n_cells', 'TRA_shannon_entropy', 'TRB_shannon_entropy',
                'TRA_clonality', 'TRB_clonality']
summary_cols = [c for c in summary_cols if c in patient_features_df.columns]
if summary_cols:
    display(patient_features_df[summary_cols].round(3))
else:
    print("No patient-level summary columns available to display.")
# Train with GroupKFold CV
groupcv_results = train_groupkfold_model(patient_features_df)
# Save results only if training succeeded
if groupcv_results.get('status') == 'ok':
    output_dir = Path('Processed_Data')
    output_dir.mkdir(exist_ok=True)
    patient_features_df.to_csv(output_dir / 'patient_level_features.csv', index=False)
    joblib.dump(groupcv_results['model'], output_dir / 'patient_level_model_groupcv.joblib')
    groupcv_results['feature_importance'].to_csv(output_dir / 'patient_level_groupcv_results.csv', index=False)
    print("\nResults saved to Processed_Data/")
else:
    print(f"\nSkipping save. Reason: {groupcv_results.get('reason')}")

In [None]:
"""
================================================================================
TASK 2: Enhanced TCR CDR3 Encoding with Physicochemical Properties
================================================================================
This cell implements comprehensive TCR CDR3 encoding using:
- Hydrophobicity (Kyte-Doolittle scale)
- Charge (based on pKa values)
- Polarity
- Molecular weight
- Volume
- Flexibility
- Additional biochemical indices

These features capture the biophysical properties that govern TCR-antigen binding.
================================================================================
"""

import numpy as np
import pandas as pd
from collections import OrderedDict

print("="*80)
print("TASK 2: Enhanced TCR CDR3 Physicochemical Encoding")
print("="*80)

# ============================================================================
# Amino Acid Property Tables
# ============================================================================

# Kyte-Doolittle Hydrophobicity Scale (higher = more hydrophobic)
HYDROPHOBICITY_KD = {
    'A': 1.8, 'R': -4.5, 'N': -3.5, 'D': -3.5, 'C': 2.5,
    'Q': -3.5, 'E': -3.5, 'G': -0.4, 'H': -3.2, 'I': 4.5,
    'L': 3.8, 'K': -3.9, 'M': 1.9, 'F': 2.8, 'P': -1.6,
    'S': -0.8, 'T': -0.7, 'W': -0.9, 'Y': -1.3, 'V': 4.2
}

# Amino Acid Charge at pH 7 (approximate)
CHARGE = {
    'A': 0, 'R': 1, 'N': 0, 'D': -1, 'C': 0,
    'Q': 0, 'E': -1, 'G': 0, 'H': 0.1, 'I': 0,  # H is ~10% protonated at pH 7
    'L': 0, 'K': 1, 'M': 0, 'F': 0, 'P': 0,
    'S': 0, 'T': 0, 'W': 0, 'Y': 0, 'V': 0
}

# Polarity (Grantham, 1974)
POLARITY = {
    'A': 8.1, 'R': 10.5, 'N': 11.6, 'D': 13.0, 'C': 5.5,
    'Q': 10.5, 'E': 12.3, 'G': 9.0, 'H': 10.4, 'I': 5.2,
    'L': 4.9, 'K': 11.3, 'M': 5.7, 'F': 5.2, 'P': 8.0,
    'S': 9.2, 'T': 8.6, 'W': 5.4, 'Y': 6.2, 'V': 5.9
}

# Molecular Weight (Da)
MOLECULAR_WEIGHT = {
    'A': 89.1, 'R': 174.2, 'N': 132.1, 'D': 133.1, 'C': 121.2,
    'Q': 146.2, 'E': 147.1, 'G': 75.1, 'H': 155.2, 'I': 131.2,
    'L': 131.2, 'K': 146.2, 'M': 149.2, 'F': 165.2, 'P': 115.1,
    'S': 105.1, 'T': 119.1, 'W': 204.2, 'Y': 181.2, 'V': 117.1
}

# Volume (√Ö¬≥) - Zamyatnin, 1972
VOLUME = {
    'A': 88.6, 'R': 173.4, 'N': 114.1, 'D': 111.1, 'C': 108.5,
    'Q': 143.8, 'E': 138.4, 'G': 60.1, 'H': 153.2, 'I': 166.7,
    'L': 166.7, 'K': 168.6, 'M': 162.9, 'F': 189.9, 'P': 112.7,
    'S': 89.0, 'T': 116.1, 'W': 227.8, 'Y': 193.6, 'V': 140.0
}

# Flexibility Index (Bhaskaran-Ponnuswamy, 1988)
FLEXIBILITY = {
    'A': 0.360, 'R': 0.530, 'N': 0.460, 'D': 0.510, 'C': 0.350,
    'Q': 0.490, 'E': 0.500, 'G': 0.540, 'H': 0.320, 'I': 0.460,
    'L': 0.370, 'K': 0.470, 'M': 0.300, 'F': 0.310, 'P': 0.510,
    'S': 0.510, 'T': 0.440, 'W': 0.310, 'Y': 0.420, 'V': 0.390
}

# Beta-sheet propensity (Chou-Fasman)
BETA_SHEET = {
    'A': 0.83, 'R': 0.93, 'N': 0.89, 'D': 0.54, 'C': 1.19,
    'Q': 1.10, 'E': 0.37, 'G': 0.75, 'H': 0.87, 'I': 1.60,
    'L': 1.30, 'K': 0.74, 'M': 1.05, 'F': 1.38, 'P': 0.55,
    'S': 0.75, 'T': 1.19, 'W': 1.37, 'Y': 1.47, 'V': 1.70
}


def encode_cdr3_physicochemical(sequence, return_features_dict=False):
    """
    Encode a CDR3 sequence using comprehensive physicochemical properties.
    
    Features computed:
    1. Hydrophobicity: mean, sum, min, max, range
    2. Charge: net charge, positive count, negative count, charge ratio
    3. Polarity: mean, std
    4. Size: length, total molecular weight, mean volume
    5. Flexibility: mean, max
    6. Beta-sheet propensity: mean
    7. Positional features: N-term, C-term, middle region properties
    
    Args:
        sequence: CDR3 amino acid sequence string
        return_features_dict: If True, return dict with feature names
    
    Returns:
        numpy array of features (or dict if return_features_dict=True)
    """
    if pd.isna(sequence) or sequence in ['nan', 'NA', '', None]:
        n_features = 26  # Total number of features
        if return_features_dict:
            return {f'physico_feature_{i}': 0.0 for i in range(n_features)}
        return np.zeros(n_features)
    
    seq = str(sequence).upper()
    # Filter to valid amino acids
    valid_aa = set(HYDROPHOBICITY_KD.keys())
    seq = ''.join([c for c in seq if c in valid_aa])
    
    if len(seq) == 0:
        n_features = 26
        if return_features_dict:
            return {f'physico_feature_{i}': 0.0 for i in range(n_features)}
        return np.zeros(n_features)
    
    features = OrderedDict()
    
    # === Hydrophobicity Features ===
    hydro_values = [HYDROPHOBICITY_KD.get(aa, 0) for aa in seq]
    features['hydro_mean'] = np.mean(hydro_values)
    features['hydro_sum'] = np.sum(hydro_values)
    features['hydro_min'] = np.min(hydro_values)
    features['hydro_max'] = np.max(hydro_values)
    features['hydro_range'] = np.max(hydro_values) - np.min(hydro_values)
    features['hydro_std'] = np.std(hydro_values) if len(hydro_values) > 1 else 0
    
    # === Charge Features ===
    charge_values = [CHARGE.get(aa, 0) for aa in seq]
    features['net_charge'] = np.sum(charge_values)
    features['positive_aa_count'] = sum(1 for c in charge_values if c > 0)
    features['negative_aa_count'] = sum(1 for c in charge_values if c < 0)
    features['charge_ratio'] = (features['positive_aa_count'] / 
                                (features['negative_aa_count'] + 1))  # +1 to avoid div by zero
    
    # === Polarity Features ===
    polarity_values = [POLARITY.get(aa, 0) for aa in seq]
    features['polarity_mean'] = np.mean(polarity_values)
    features['polarity_std'] = np.std(polarity_values) if len(polarity_values) > 1 else 0
    
    # === Size Features ===
    features['length'] = len(seq)
    mw_values = [MOLECULAR_WEIGHT.get(aa, 0) for aa in seq]
    features['total_mw'] = np.sum(mw_values)
    features['mean_mw'] = np.mean(mw_values)
    
    volume_values = [VOLUME.get(aa, 0) for aa in seq]
    features['mean_volume'] = np.mean(volume_values)
    features['total_volume'] = np.sum(volume_values)
    
    # === Flexibility Features ===
    flex_values = [FLEXIBILITY.get(aa, 0) for aa in seq]
    features['flexibility_mean'] = np.mean(flex_values)
    features['flexibility_max'] = np.max(flex_values)
    
    # === Beta-sheet Propensity ===
    beta_values = [BETA_SHEET.get(aa, 0) for aa in seq]
    features['beta_propensity_mean'] = np.mean(beta_values)
    
    # === Positional Features (N-term, C-term, Middle) ===
    # CDR3 regions often have conserved ends and variable middle
    n_term = seq[:3] if len(seq) >= 3 else seq
    c_term = seq[-3:] if len(seq) >= 3 else seq
    middle = seq[3:-3] if len(seq) > 6 else seq
    
    features['nterm_hydro'] = np.mean([HYDROPHOBICITY_KD.get(aa, 0) for aa in n_term])
    features['cterm_hydro'] = np.mean([HYDROPHOBICITY_KD.get(aa, 0) for aa in c_term])
    features['middle_hydro'] = np.mean([HYDROPHOBICITY_KD.get(aa, 0) for aa in middle]) if middle else 0
    
    features['nterm_charge'] = np.sum([CHARGE.get(aa, 0) for aa in n_term])
    features['cterm_charge'] = np.sum([CHARGE.get(aa, 0) for aa in c_term])
    features['middle_charge'] = np.sum([CHARGE.get(aa, 0) for aa in middle]) if middle else 0
    
    if return_features_dict:
        return features
    
    return np.array(list(features.values()))


def encode_all_cdr3_physicochemical(adata):
    """
    Encode all CDR3 sequences in the AnnData object with physicochemical features.
    
    Creates:
    - adata.obsm['X_tcr_tra_physico_enhanced']: Enhanced TRA physicochemical features
    - adata.obsm['X_tcr_trb_physico_enhanced']: Enhanced TRB physicochemical features
    - Combined features added to adata.obs
    """
    print("Encoding CDR3 sequences with enhanced physicochemical properties...")
    
    # Get feature names from a sample encoding
    sample_features = encode_cdr3_physicochemical('CASSYSGANVLTF', return_features_dict=True)
    feature_names = list(sample_features.keys())
    print(f"Encoding {len(feature_names)} physicochemical features per sequence")
    
    # Encode TRA sequences (safe if column missing)
    tra_encodings = []
    tra_iter = adata.obs['cdr3_TRA'].astype(str) if 'cdr3_TRA' in adata.obs.columns else pd.Series([''] * adata.n_obs, index=adata.obs.index)
    for seq in tra_iter:
        tra_encodings.append(encode_cdr3_physicochemical(seq))
    tra_matrix = np.vstack(tra_encodings)
    
    # Encode TRB sequences (safe if column missing)
    trb_encodings = []
    trb_iter = adata.obs['cdr3_TRB'].astype(str) if 'cdr3_TRB' in adata.obs.columns else pd.Series([''] * adata.n_obs, index=adata.obs.index)
    for seq in trb_iter:
        trb_encodings.append(encode_cdr3_physicochemical(seq))
    trb_matrix = np.vstack(trb_encodings)
    
    print(f"TRA physicochemical matrix shape: {tra_matrix.shape}")
    print(f"TRB physicochemical matrix shape: {trb_matrix.shape}")
    
    # Ensure matrices are float32 for AnnData compatibility
    tra_matrix = tra_matrix.astype(np.float32)
    trb_matrix = trb_matrix.astype(np.float32)
    
    # Store in AnnData
    adata.obsm['X_tcr_tra_physico_enhanced'] = tra_matrix
    adata.obsm['X_tcr_trb_physico_enhanced'] = trb_matrix
    
    # Also add individual features to obs for easy access (ensure float type)
    for i, fname in enumerate(feature_names):
        adata.obs[f'tra_enhanced_{fname}'] = pd.Series(tra_matrix[:, i], index=adata.obs.index, dtype=np.float32)
        adata.obs[f'trb_enhanced_{fname}'] = pd.Series(trb_matrix[:, i], index=adata.obs.index, dtype=np.float32)
    
    return feature_names


# ============================================================================
# Execute Task 2
# ============================================================================
feature_names_physico = encode_all_cdr3_physicochemical(adata)

# Display summary statistics
print("\n--- Enhanced Physicochemical Feature Summary ---")
print(f"Total features per chain: {len(feature_names_physico)}")
print(f"Feature names: {feature_names_physico}")

# Compare responder vs non-responder
print("\n--- Physicochemical Comparison: Responder vs Non-Responder ---")
resp_mask = adata.obs['response'] == 'Responder'
non_resp_mask = adata.obs['response'] == 'Non-Responder'

comparison_df = []
for fname in ['hydro_mean', 'net_charge', 'polarity_mean', 'flexibility_mean', 'length']:
    tra_col = f'tra_enhanced_{fname}'
    trb_col = f'trb_enhanced_{fname}'
    
    if tra_col in adata.obs.columns:
        resp_tra = adata.obs.loc[resp_mask, tra_col].mean()
        nonresp_tra = adata.obs.loc[non_resp_mask, tra_col].mean()
        resp_trb = adata.obs.loc[resp_mask, trb_col].mean()
        nonresp_trb = adata.obs.loc[non_resp_mask, trb_col].mean()
        
        comparison_df.append({
            'Feature': fname,
            'TRA_Responder': resp_tra,
            'TRA_NonResponder': nonresp_tra,
            'TRA_Diff': resp_tra - nonresp_tra,
            'TRB_Responder': resp_trb,
            'TRB_NonResponder': nonresp_trb,
            'TRB_Diff': resp_trb - nonresp_trb
        })

display(pd.DataFrame(comparison_df).round(3))

print("\n" + "="*80)
print("TASK 2 COMPLETED: Enhanced TCR Physicochemical Encoding")
print("="*80)

In [None]:
"""
================================================================================
TASK 3: Top 20 Feature Analysis Cross-Referenced with Sun et al. 2025
================================================================================
This cell analyzes top predictive features and cross-references them with:
- GZMB (Granzyme B) - key cytotoxicity marker
- HLA-DR genes - antigen presentation
- Interferon-Stimulated Genes (ISGs)
- Other markers identified in Sun et al. 2025

Reference: Sun et al. 2025, npj Breast Cancer 11:65
================================================================================
"""

import numpy as np
import pandas as pd
from pathlib import Path
from scipy.stats import mannwhitneyu

# Validate prerequisites
if 'adata' not in globals() or adata is None:
    raise NameError("adata is not defined. Please run data loading cells first.")

if 'groupcv_results' not in globals() or groupcv_results is None:
    print("WARNING: groupcv_results not defined. Skipping analyze_top_features().")
    print("Please run the patient-level LOPO CV cell first.")
    # Create a minimal placeholder to avoid errors downstream
    groupcv_results = {'feature_importance': pd.DataFrame(columns=['feature', 'importance'])}

# SHAP is optional for this cell; avoid hard failure if missing
try:
    import shap
except Exception:
    shap = None
    print("shap not available; skipping SHAP-specific utilities in Task 3.")

In [None]:
"""
================================================================================
TASK 5: Publication-Quality 4-Panel Figure
================================================================================
This cell generates a comprehensive 4-panel figure suitable for publication:
1. UMAP of cell types colored by response and cell type
2. SHAP importance plot for the multimodal model
3. Patient-level ROC curve from GroupKFold CV
4. Boxplots of top 3 biological markers (GZMB, HLA-DR, ISG)

Figure design follows journal guidelines for Nature/Cell Press publications.
================================================================================
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import seaborn as sns
from sklearn.metrics import roc_curve, auc
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Install SHAP if needed
try:
    import shap
except ImportError:
    %pip install shap
    import shap

print("="*80)
print("TASK 5: Publication-Quality 4-Panel Figure")
print("="*80)

# Set publication-quality defaults
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'font.size': 10,
    'axes.labelsize': 12,
    'axes.titlesize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 9,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.transparent': False,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# Color palette
COLORS = {
    'Responder': '#2ecc71',       # Green
    'Non-Responder': '#e74c3c',   # Red
    'Unknown': '#95a5a6',         # Gray
    'accent': '#3498db',          # Blue
    'purple': '#9b59b6',          # Purple
    'orange': '#e67e22',          # Orange
}


def panel_placeholder(ax, title, message):
    """Render a placeholder panel with a message."""
    ax.axis('off')
    ax.set_title(title, fontweight='bold', loc='left')
    ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=10)
    return ax


def create_panel_a_umap(ax, adata):
    """
    Panel A: UMAP visualization of cells colored by response.
    """
    print("Creating Panel A: UMAP visualization...")

    # Use stored UMAP or compute new one
    umap_coords = None
    if 'X_umap_combined' in adata.obsm:
        umap_coords = adata.obsm['X_umap_combined']
    elif 'X_umap' in adata.obsm:
        umap_coords = adata.obsm['X_umap']
    else:
        # Compute UMAP
        if 'X_gene_pca' in adata.obsm:
            X_pca = adata.obsm['X_gene_pca'][:, :20]
        elif 'X_pca' in adata.obsm:
            X_pca = adata.obsm['X_pca'][:, :20]
        else:
            return panel_placeholder(ax, 'A. Single-Cell UMAP by Response', 'UMAP not available')

        try:
            import umap as umap_module
            reducer = umap_module.UMAP(n_components=2, random_state=42)
            umap_coords = reducer.fit_transform(X_pca)
        except Exception as e:
            print(f"  UMAP computation failed: {e}")
            return panel_placeholder(ax, 'A. Single-Cell UMAP by Response', 'UMAP not available')

    # Create color mapping
    if 'response' in adata.obs.columns:
        resp_series = adata.obs['response'].fillna('Unknown').astype(str)
    else:
        resp_series = pd.Series(['Unknown'] * adata.n_obs, index=adata.obs.index)

    response_colors = []
    for resp in resp_series:
        if resp == 'Responder':
            response_colors.append(COLORS['Responder'])
        elif resp == 'Non-Responder':
            response_colors.append(COLORS['Non-Responder'])
        else:
            response_colors.append(COLORS['Unknown'])

    # Plot with alpha for better visualization
    ax.scatter(
        umap_coords[:, 0],
        umap_coords[:, 1],
        c=response_colors,
        s=3,
        alpha=0.6,
        rasterized=True
    )

    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_title('A. Single-Cell UMAP by Response', fontweight='bold', loc='left')

    # Legend
    legend_elements = [
        Patch(facecolor=COLORS['Responder'], label=f"Responder (n={(resp_series=='Responder').sum():,})"),
        Patch(facecolor=COLORS['Non-Responder'], label=f"Non-Responder (n={(resp_series=='Non-Responder').sum():,})"),
        Patch(facecolor=COLORS['Unknown'], label=f"Unknown (n={(resp_series=='Unknown').sum():,})"),
    ]
    ax.legend(handles=legend_elements, loc='upper right', frameon=True, framealpha=0.9)

    return ax


def create_panel_b_shap(ax, groupcv_results, patient_df):
    """
    Panel B: SHAP importance plot for the multimodal model.
    """
    print("Creating Panel B: SHAP importance plot...")

    if groupcv_results is None or patient_df is None or patient_df.empty:
        return panel_placeholder(ax, 'B. Feature Importance (SHAP)', 'Not available')

    model = groupcv_results.get('model')
    feature_cols = groupcv_results.get('feature_cols')
    scaler = groupcv_results.get('scaler')

    if model is None or scaler is None or not feature_cols:
        return panel_placeholder(ax, 'B. Feature Importance (SHAP)', 'Not available')

    # Prepare data
    X = patient_df[feature_cols].fillna(0).values
    X_scaled = scaler.transform(X)

    try:
        # Compute SHAP values
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(X_scaled)
    except Exception as e:
        print(f"  SHAP computation failed: {e}")
        return panel_placeholder(ax, 'B. Feature Importance (SHAP)', 'Not available')

    # Get mean absolute SHAP values for feature importance
    if isinstance(shap_values, list):
        # Multi-class output
        shap_importance = np.abs(shap_values[1]).mean(axis=0)
    else:
        shap_importance = np.abs(shap_values).mean(axis=0)

    # Create DataFrame and get top 15 features
    shap_df = pd.DataFrame({
        'feature': feature_cols,
        'importance': shap_importance
    }).sort_values('importance', ascending=True).tail(15)

    # Create horizontal bar plot
    colors = []
    for feat in shap_df['feature']:
        if 'shannon' in feat.lower() or 'clonality' in feat.lower():
            colors.append(COLORS['purple'])
        elif 'pca' in feat.lower():
            colors.append(COLORS['accent'])
        elif 'hydro' in feat.lower() or 'charge' in feat.lower():
            colors.append(COLORS['orange'])
        else:
            colors.append('#7f8c8d')

    ax.barh(range(len(shap_df)), shap_df['importance'], color=colors)

    # Clean feature names for display
    clean_names = []
    for feat in shap_df['feature']:
        name = feat.replace('_mean', '').replace('_', ' ').title()
        if len(name) > 25:
            name = name[:22] + '...'
        clean_names.append(name)

    ax.set_yticks(range(len(shap_df)))
    ax.set_yticklabels(clean_names)
    ax.set_xlabel('Mean |SHAP Value|')
    ax.set_title('B. Feature Importance (SHAP)', fontweight='bold', loc='left')

    # Legend for feature types
    legend_elements = [
        Patch(facecolor=COLORS['accent'], label='Gene Expression'),
        Patch(facecolor=COLORS['purple'], label='TCR Diversity'),
        Patch(facecolor=COLORS['orange'], label='Physicochemical'),
    ]
    ax.legend(handles=legend_elements, loc='lower right', frameon=True, fontsize=8)

    return ax


def create_panel_c_roc(ax, groupcv_results):
    """
    Panel C: Patient-level ROC curve from GroupKFold CV.
    """
    print("Creating Panel C: Patient-level ROC curve...")

    if groupcv_results is None:
        return panel_placeholder(ax, 'C. Patient-Level ROC Curve', 'Not available')

    y_true = groupcv_results.get('y_true')
    y_proba = groupcv_results.get('y_pred_proba')

    if y_true is None or y_proba is None:
        return panel_placeholder(ax, 'C. Patient-Level ROC Curve', 'Not available')

    y_true = np.asarray(y_true)
    y_proba = np.asarray(y_proba)

    if y_true.size == 0 or y_proba.size == 0:
        return panel_placeholder(ax, 'C. Patient-Level ROC Curve', 'Not available')

    if y_proba.ndim == 2 and y_proba.shape[1] >= 2:
        y_score = y_proba[:, 1]
    elif y_proba.ndim == 1:
        y_score = y_proba
    else:
        return panel_placeholder(ax, 'C. Patient-Level ROC Curve', 'ROC not available')

    # Compute ROC curve
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    # Plot ROC curve
    ax.plot(fpr, tpr, color=COLORS['accent'], lw=2.5,
            label=f'GroupKFold CV (AUC = {roc_auc:.2f})')

    # Diagonal reference line
    ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.5, label='Random (AUC = 0.50)')

    # Fill under curve
    ax.fill_between(fpr, tpr, alpha=0.2, color=COLORS['accent'])

    # Add optimal threshold point
    optimal_idx = np.argmax(tpr - fpr)
    ax.scatter([fpr[optimal_idx]], [tpr[optimal_idx]],
               color=COLORS['Responder'], s=100, zorder=5,
               label=f'Optimal (sens={tpr[optimal_idx]:.2f}, spec={1-fpr[optimal_idx]:.2f})')

    ax.set_xlabel('False Positive Rate (1 - Specificity)')
    ax.set_ylabel('True Positive Rate (Sensitivity)')
    ax.set_title('C. Patient-Level ROC Curve', fontweight='bold', loc='left')
    ax.legend(loc='lower right', frameon=True)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1.05])
    ax.set_aspect('equal')

    return ax


def create_panel_d_boxplots(ax, adata, patient_df=None, groupcv_results=None):
    """
    Panel D: Boxplots of top 3 biological markers.
    """
    print("Creating Panel D: Biomarker boxplots...")

    if patient_df is None and isinstance(groupcv_results, dict):
        patient_df = groupcv_results.get('patient_df')

    # Select markers to plot
    markers_to_plot = []

    # Try to find GZMB, HLA-DRA, and an ISG
    candidate_markers = ['GZMB', 'HLA-DRA', 'ISG15', 'IFI6', 'GNLY', 'PRF1']

    for marker in candidate_markers:
        if marker in adata.var_names:
            markers_to_plot.append(marker)
        if len(markers_to_plot) >= 3:
            break

    # If we don't have 3, fall back to TCR diversity metrics
    if len(markers_to_plot) < 3:
        markers_to_plot.extend(['TRA_shannon_entropy', 'TRB_shannon_entropy', 'TRA_clonality'])
        markers_to_plot = markers_to_plot[:3]

    print(f"  Plotting markers: {markers_to_plot}")

    # Prepare data for plotting
    plot_data = []

    for marker in markers_to_plot:
        if marker in adata.var_names:
            # Gene expression marker
            expr = adata[:, marker].X
            expr = expr.toarray().ravel() if hasattr(expr, 'toarray') else np.asarray(expr).ravel()

            resp_series = adata.obs['response'] if 'response' in adata.obs.columns else pd.Series(['Unknown'] * adata.n_obs)
            for val, resp in zip(expr, resp_series):
                if resp in ['Responder', 'Non-Responder']:
                    plot_data.append({'Marker': marker, 'Expression': val, 'Response': resp})
        elif marker in adata.obs.columns:
            # obs column (TCR metrics)
            resp_series = adata.obs['response'] if 'response' in adata.obs.columns else pd.Series(['Unknown'] * adata.n_obs)
            for val, resp in zip(adata.obs[marker], resp_series):
                if resp in ['Responder', 'Non-Responder']:
                    plot_data.append({'Marker': marker.replace('_', ' ').title(),
                                     'Expression': val, 'Response': resp})

    # Fall back to patient-level features if cell-level data is limited
    if len(plot_data) < 10 and patient_df is not None and not patient_df.empty:
        print("  Using patient-level features for boxplot...")
        for col in ['TRA_shannon_entropy', 'TRB_shannon_entropy', 'TRA_clonality']:
            if col in patient_df.columns:
                for _, row in patient_df.iterrows():
                    plot_data.append({
                        'Marker': col.replace('_', ' ').title(),
                        'Expression': row[col],
                        'Response': row['Response']
                    })

    if len(plot_data) == 0:
        return panel_placeholder(ax, 'D. Key Biomarkers by Response', 'Not available')

    plot_df = pd.DataFrame(plot_data)

    # Create grouped boxplot
    palette = {'Responder': COLORS['Responder'], 'Non-Responder': COLORS['Non-Responder']}

    sns.boxplot(
        data=plot_df,
        x='Marker',
        y='Expression',
        hue='Response',
        palette=palette,
        ax=ax,
        linewidth=1.5,
        fliersize=2
    )

    ax.set_xlabel('')
    ax.set_ylabel('Expression / Value')
    ax.set_title('D. Key Biomarkers by Response', fontweight='bold', loc='left')
    ax.legend(title='Response', loc='upper right', frameon=True)

    # Rotate x-labels if needed
    ax.tick_params(axis='x', rotation=15)

    return ax


def create_publication_figure(adata, groupcv_results=None):
    """
    Create the complete 4-panel publication figure.
    """
    print("\n--- Creating Publication Figure ---")

    # Create figure with 2x2 layout
    fig = plt.figure(figsize=(14, 12))
    gs = gridspec.GridSpec(2, 2, figure=fig, wspace=0.3, hspace=0.35)

    # Panel A: UMAP
    ax_a = fig.add_subplot(gs[0, 0])
    create_panel_a_umap(ax_a, adata)

    # Panel B: SHAP
    ax_b = fig.add_subplot(gs[0, 1])
    if groupcv_results is None or groupcv_results.get('status') == 'skipped':
        panel_placeholder(ax_b, 'B. Feature Importance (SHAP)', 'Not available')
        patient_df = None
    else:
        patient_df = groupcv_results.get('patient_df')
        create_panel_b_shap(ax_b, groupcv_results, patient_df)

    # Panel C: ROC
    ax_c = fig.add_subplot(gs[1, 0])
    if groupcv_results is None or groupcv_results.get('status') == 'skipped':
        panel_placeholder(ax_c, 'C. Patient-Level ROC Curve', 'Not available')
    else:
        create_panel_c_roc(ax_c, groupcv_results)

    # Panel D: Boxplots
    ax_d = fig.add_subplot(gs[1, 1])
    create_panel_d_boxplots(ax_d, adata, patient_df=patient_df, groupcv_results=groupcv_results)

    # Add overall title
    fig.suptitle(
        'Multimodal Machine Learning Predicts Immunotherapy Response in HR+ Breast Cancer',
        fontsize=14,
        fontweight='bold',
        y=0.98
    )

    plt.tight_layout(rect=[0, 0, 1, 0.96])

    return fig


# ==========================================================================
# Execute Task 5
# ==========================================================================

# Get groupcv_results if available
_groupcv_results = globals().get('groupcv_results', None)

# Create the publication figure
try:
    fig = create_publication_figure(adata, _groupcv_results)
except Exception as e:
    print(f"Failed to create publication figure: {e}")
    fig = None

# Save figure in multiple formats
if fig is not None:
    output_dir = Path('Processed_Data/figures')
    output_dir.mkdir(exist_ok=True, parents=True)

    # High-resolution PNG
    fig.savefig(output_dir / 'Figure_Multimodal_ML_Response.png', dpi=300, bbox_inches='tight')
    print(f"Saved: {output_dir / 'Figure_Multimodal_ML_Response.png'}")

    # PDF for publication
    fig.savefig(output_dir / 'Figure_Multimodal_ML_Response.pdf', bbox_inches='tight')
    print(f"Saved: {output_dir / 'Figure_Multimodal_ML_Response.pdf'}")

    # SVG for editing
    fig.savefig(output_dir / 'Figure_Multimodal_ML_Response.svg', bbox_inches='tight')
    print(f"Saved: {output_dir / 'Figure_Multimodal_ML_Response.svg'}")

    plt.show()

    print("\n" + "="*80)
    print("TASK 5 COMPLETED: Publication-Quality 4-Panel Figure Generated")
    print("="*80)
else:
    print("Skipping figure save because figure creation failed or was unavailable.")

## Summary: Enhanced ML Pipeline for HR+ Breast Cancer Immunotherapy Response Prediction

### Tasks Completed

| Task | Description | Key Outputs |
|------|-------------|-------------|
| **Task 1** | GroupKFold CV with Patient-Level Aggregation | `patient_level_features.csv`, `patient_level_model_groupcv.joblib` |
| **Task 2** | Enhanced TCR CDR3 Physicochemical Encoding | 28 features per chain (hydrophobicity, charge, polarity, etc.) |
| **Task 3** | Top 20 Feature Analysis with Sun et al. 2025 | `sun_2025_marker_analysis.csv`, GZMB/HLA-DR/ISG validation |
| **Task 4** | Extended Literature Review | I-SPY2 comparison, TCR-H/CoNGA methods |
| **Task 5** | 4-Panel Publication Figure | `Figure_Multimodal_ML_Response.png/pdf/svg` |

### Key Innovations

1. **Data Leakage Prevention**: GroupKFold ensures all cells from same patient stay in same fold
2. **Shannon Entropy TCR Diversity**: Captures clonal expansion dynamics (responders: dynamic turnover; non-responders: clonal stability)
3. **Comprehensive Physicochemical Encoding**: 28 features capturing binding-relevant properties
4. **Multi-resolution Analysis**: Cell-level clustering + patient-level prediction
5. **Literature Validation**: Cross-referenced with Sun et al. 2025, I-SPY2, and emerging methods

### Files Generated

```
Processed_Data/
‚îú‚îÄ‚îÄ patient_level_features.csv           # Patient-aggregated features with TCR diversity
‚îú‚îÄ‚îÄ patient_level_groupcv_results.csv    # Per-fold CV metrics
‚îú‚îÄ‚îÄ patient_level_model_groupcv.joblib   # Trained XGBoost model
‚îú‚îÄ‚îÄ top_20_features_analysis.csv         # Feature importance ranking
‚îú‚îÄ‚îÄ sun_2025_marker_analysis.csv         # Marker expression comparison
‚îî‚îÄ‚îÄ figures/
    ‚îú‚îÄ‚îÄ Figure_Multimodal_ML_Response.png
    ‚îú‚îÄ‚îÄ Figure_Multimodal_ML_Response.pdf
    ‚îî‚îÄ‚îÄ Figure_Multimodal_ML_Response.svg
```

### Reproducibility Notes

- All random seeds set to 42 for reproducibility
- GroupKFold CV ensures patient-level generalization
- Feature scaling performed with StandardScaler (saved with model)
- Multiple testing correction (Benjamini-Hochberg) applied to marker analysis

### Citation

If using this pipeline, please cite:
- Sun et al. 2025, npj Breast Cancer 11:65 (GSE300475 dataset)
- This enhanced ML pipeline developed for HR+ breast cancer immunotherapy response prediction

## Fixes applied

- **Added safety defaults** for missing `adata.obsm` keys (e.g. `X_gene_umap`, `X_gene_svd`, TCR arrays) to avoid KeyError during feature assembly.
- **Inserted a safe getter** `_get_obsm_or_zeros(adata, key, mask, n_cols)` to retrieve `obsm` arrays with a zeros fallback.
- **Replaced unsafe monkeypatch** of `xgboost.XGBClassifier.__init__` with a **sklearn-compatible wrapper** `XGBClassifierSK` and adjusted `_apply_gpu_patches()` to use it when available.

Notes:
- The notebook contains historical outputs (errors/warnings) from a previous Kaggle run; the code has been made robust so these errors should not reoccur when re-running the notebook in Kaggle.
- I recommend re-running the notebook from the top on Kaggle (where packages and GPUs are available) to validate results and regenerate plots.


In [None]:
# Quick non-fatal sanity checks (safe to run)
try:
    import numpy as np
    if 'adata' in globals():
        n_obs = getattr(adata, 'n_obs', adata.shape[0])
        print('adata.n_obs:', n_obs)
        for k in ['X_gene_pca', 'X_gene_svd', 'X_gene_umap']:
            if k in adata.obsm:
                shape = np.asarray(adata.obsm[k]).shape
                print(f"{k}: present, shape={shape}")
            else:
                print(f"{k}: MISSING")
    else:
        print('adata not defined in this environment (skip checks)')
    print('XGBClassifierSK defined:', 'XGBClassifierSK' in globals())
except Exception as e:
    print('Sanity checks could not be completed:', e)


## Model summary and recommendation

- **Models implemented**
  - **XGBoost (tree ensemble):** Best performing on the *comprehensive* feature set (gene PCs + TCR k-mers + physicochemical features).
  - **RandomForest / LogisticRegression:** Baselines.
  - **Feed-forward MLP:** Dense network for tabular / flattened sequence inputs.
  - **Sequence-aware architectures:** 1D **CNN**, **BiLSTM** (RNN), and **Transformer** (attention) encoders for CDR3 sequences.

- **Recommendation (practical best model):**
  - **XGBoost on the comprehensive feature set** with nested Group/LOPO CV, the expanded hyperparameter grid (n_estimators, max_depth, learning_rate, subsample, colsample_bytree), and **patient-level aggregation** (mean cell probabilities -> patient prediction). This gives best performance and interpretable feature importance.

- **If you want a deep multimodal approach:**
  - Use the **Transformer encoder** for sequence embeddings + MLP for gene PCs, train with **class_weight**, **EarlyStopping** monitoring **val_auc**, and evaluate with patient-level aggregation. Consider pretrained protein language model embeddings (ESM / ProtTrans) if compute permits.

- **Next steps:**
  1. Re-run LOPO with the updated XGBoost grid and patient-level aggregation.
  2. Optionally run a short LOPO experiment for the Transformer-based multimodal model.

*I implemented patient-level metrics and DL training improvements (AUC metrics, val_auc early stopping).*