# Comprehensive gRNA Data Preparation Pipeline

## 📋 Overview

This notebook implements a **rigorous, step-by-step data preparation pipeline** for gRNA classification, incorporating biological insights from [Cooper et al. 2022](https://rnajournal.cshlp.org/content/28/7/972.full.pdf).

### Pipeline Stages:
1. **Load & Validate Raw Sequences** - Load canonical gRNA from FASTA files
2. **Parse GTF & Identify gRNA Regions** - For proper negative sampling
3. **Generate Length-Matched Negatives** - Multi-source, GTF-excluded
4. **Comprehensive Feature Extraction** - 120 biologically-informed features
5. **Quality Control & Validation** - Verify no data leakage
6. **Train/Val/Test Split** - Stratified splitting with balancing
7. **Export Datasets** - Save feature-rich datasets for modeling

### 🔬 Key Biological Principles:
- **Evidence-based initiation patterns**: AAAA (39.7%), GAAA (33%), AGAA (12.1%) - NOT just ATATA!
- **Flexible anchor detection**: Position 4-6, length 8-12 nt, AC-rich/G-poor
- **Molecular ruler hypothesis**: Init + anchor length conserved at 15-19 nt
- **A-elevated guiding regions**: 46% A-content (vs 25-30% in non-gRNA)
- **Terminal T**: 90% of gRNAs end with T (facilitates U-tail addition)
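The principles above translate into simple per-sequence checks. The following is a minimal sketch, not the feature extractor used later in the notebook: `quick_signature` is a hypothetical helper, and the guiding region is assumed to start at index 15, mirroring the lower bound of the molecular ruler range.

```python
def quick_signature(seq: str, guide_start: int = 15) -> dict:
    """Report three of the listed gRNA hallmarks for one sequence.

    guide_start=15 is an assumption for illustration only.
    """
    seq = seq.upper().replace("U", "T")
    guide = seq[guide_start:]
    return {
        # The three most frequent initiation 4-mers listed above
        "init_4mer_known": seq[:4] in {"AAAA", "GAAA", "AGAA"},
        # Terminal T (U in the RNA) precedes the U-tail
        "ends_with_T": seq.endswith("T"),
        # Fraction of A in the assumed guiding region
        "guide_A_fraction": guide.count("A") / len(guide) if guide else 0.0,
    }


# Toy sequence, invented for illustration (not from the Cooper et al. data)
toy = "CCCC" + "G" * 11 + "AATT"
print(quick_signature(toy))
```

The toy sequence fails the initiation check but ends in T, which shows why no single hallmark is sufficient on its own.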

### ⚠️ Critical Requirements:
- **NO sequence length in features!** (causes artifact learning)
- **Length-matched negatives** (KS test p>0.05)
- **Exclude gRNA regions** (using GTF annotations)
- **Balanced classes** in all splits
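Two of these requirements can be checked mechanically. The sketch below assumes the metadata column names used when the dataset rows are assembled (`sequence_id`, `sequence`, `length`, `label`, `source`); the helper names `feature_columns` and `shared_sequences` are hypothetical.

```python
# Columns stored alongside the features but never passed to a model
META_COLUMNS = {"sequence_id", "sequence", "length", "label", "source"}


def feature_columns(columns):
    """Return only the columns that may be used as model inputs."""
    return [c for c in columns if c not in META_COLUMNS]


def shared_sequences(positives: dict, negatives: dict) -> set:
    """Return sequences that occur in both classes (should be empty)."""
    return set(positives.values()) & set(negatives.values())


# Toy data, invented for illustration
columns = ["sequence_id", "sequence", "length", "label", "source", "ends_with_T"]
pos = {"g1": "AAAATT", "g2": "GAAACT"}
neg = {"n1": "CCGGTT", "n2": "GAAACT"}

print(feature_columns(columns))
print(shared_sequences(pos, neg))
```

A non-empty result from `shared_sequences` would mean the same sequence carries both labels, which is one form of leakage worth ruling out before splitting.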

---

## Setup & Imports

In [None]:
import sys
import warnings
import re
import json
from pathlib import Path
from collections import Counter, defaultdict
from typing import Dict, Tuple, List, Set

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

from Bio import SeqIO
from sklearn.model_selection import train_test_split

warnings.filterwarnings('ignore')
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 150
plt.rcParams['figure.figsize'] = (10, 6)
np.random.seed(42)

print("✓ Imports loaded successfully")
print(f"  NumPy: {np.__version__}")
print(f"  Pandas: {pd.__version__}")

### Define File Paths

**Input files:**
- `mOs_gRNA_final.fasta`: Canonical gRNA sequences from Cooper 2022
- `mOs_Cooper_minicircle.fasta`: Minicircle genomes for negative sampling
- `mOs_gRNA_final.gtf`: gRNA coordinates for exclusion

In [None]:
# Update these paths for your environment!
PROJECT_ROOT = Path.home() / 'projects' / 'grna-inspector'
DATA_DIR = PROJECT_ROOT / 'data'
RAW_DIR = DATA_DIR / 'gRNAs' / 'Cooper_2022'
PROCESSED_DIR = DATA_DIR / 'processed' / 'comprehensive_pipeline'
PLOTS_DIR = DATA_DIR / 'plots' / 'data_prep'

PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# Input files
GRNA_FILE = RAW_DIR / 'mOs_gRNA_final.fasta'
MINICIRCLE_FILE = RAW_DIR / 'mOs_Cooper_minicircle.fasta'
GTF_FILE = RAW_DIR / 'mOs_gRNA_final.gtf'

print("Checking input files...")
for filepath in [GRNA_FILE, MINICIRCLE_FILE, GTF_FILE]:
    if filepath.exists():
        print(f"  ✓ {filepath.name}")
    else:
        print(f"  ✗ {filepath.name} - NOT FOUND!")

print(f"\nOutput directory: {PROCESSED_DIR}")

---
## Stage 1: Load & Validate Positive Sequences

Load canonical gRNA sequences from FASTA file and perform initial validation.

In [None]:
print("="*80)
print("STAGE 1: LOAD & VALIDATE POSITIVE SEQUENCES")
print("="*80)

# Load positive sequences
positive_sequences = {}
for record in SeqIO.parse(GRNA_FILE, "fasta"):
    seq = str(record.seq).upper().replace('U', 'T')
    positive_sequences[record.id] = seq

print(f"\nLoaded {len(positive_sequences):,} canonical gRNA sequences")

# Calculate statistics
lengths = [len(seq) for seq in positive_sequences.values()]
sequences = list(positive_sequences.values())
positive_lengths = lengths  # Save for later

# Nucleotide composition
at_contents = [(seq.count('A') + seq.count('T')) / len(seq) * 100 for seq in sequences]
gc_contents = [(seq.count('G') + seq.count('C')) / len(seq) * 100 for seq in sequences]

print("\nüìä Sequence Statistics:")
print(f"  Length range: {min(lengths)}-{max(lengths)} nt")
print(f"  Mean length: {np.mean(lengths):.1f} ¬± {np.std(lengths):.1f} nt")
print(f"  Median length: {np.median(lengths):.0f} nt")

print("\nüß¨ Nucleotide Composition:")
print(f"  Mean AT-content: {np.mean(at_contents):.1f}%")
print(f"  Mean GC-content: {np.mean(gc_contents):.1f}%")

# Quality checks
n_count = sum(1 for seq in sequences if 'N' in seq)
unique_seqs = set(sequences)
n_duplicates = len(sequences) - len(unique_seqs)

print("\nüîç Quality Checks:")
print(f"  Sequences with N: {n_count} ({n_count/len(sequences)*100:.1f}%)")
print(f"  Duplicate sequences: {n_duplicates}")

if n_duplicates > 0:
    print(f"\n  ‚ö†Ô∏è Found {n_duplicates} duplicated sequence(s)")
    print("     (This is OK if same gRNA is encoded on multiple minicircles)")

print("\n‚úÖ Positive sequences loaded and validated!")
print("\n" + "="*80)

---
## Stage 2: Parse GTF and Identify gRNA Regions

Parse GTF file to get gRNA coordinates on each minicircle.
This allows us to **exclude these regions** when generating negative examples.
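GTF coordinates are 1-based with an inclusive end, while Python slices are 0-based with an exclusive end, so only the start needs to shift. The sketch below illustrates the conversion on an invented record; `gtf_line_to_interval` is a hypothetical helper and the toy line is not taken from `mOs_gRNA_final.gtf`.

```python
def gtf_line_to_interval(line: str):
    """Return (seqname, start, end) as a 0-based, end-exclusive interval."""
    parts = line.rstrip("\n").split("\t")
    seqname = parts[0]
    start = int(parts[3]) - 1  # GTF start is 1-based
    end = int(parts[4])        # inclusive 1-based end equals exclusive 0-based end
    return seqname, start, end


# Hypothetical record: feature spans positions 3..6 (1-based, inclusive)
toy_line = 'toy_minicircle\texample\tgRNA\t3\t6\t.\t+\t.\tgene_id "toy";'
toy_sequence = "ACGTACGTAC"

seqname, start, end = gtf_line_to_interval(toy_line)
fragment = toy_sequence[start:end]
print(seqname, start, end, fragment)
```

After conversion, `end - start` equals the annotated feature length, which makes the later gap arithmetic between neighbouring gRNAs straightforward.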

In [None]:
def parse_gtf_file(gtf_file: Path) -> Dict[str, List[Tuple[int, int]]]:
    """
    Parse GTF file and extract gRNA coordinates for each minicircle.
    Coordinates are converted to 0-indexed, end-exclusive (Python convention).
    """
    grna_regions = defaultdict(list)
    
    with open(gtf_file, 'r') as f:
        for line in f:
            if line.startswith('#'):
                continue
            parts = line.strip().split('\t')
            if len(parts) < 5:
                continue
            minicircle_id = parts[0]
            start = int(parts[3]) - 1  # Convert to 0-indexed
            end = int(parts[4])  # Keep end as Python end-exclusive
            grna_regions[minicircle_id].append((start, end))
    
    # Sort regions by start position
    for mini_id in grna_regions:
        grna_regions[mini_id].sort()
    
    return dict(grna_regions)


def merge_overlapping_regions(regions: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Merge overlapping regions."""
    if not regions:
        return []
    merged = [regions[0]]
    for start, end in regions[1:]:
        last_start, last_end = merged[-1]
        if start <= last_end:  # Overlapping or adjacent
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))
    return merged


print("="*80)
print("STAGE 2: PARSE GTF AND IDENTIFY gRNA REGIONS")
print("="*80)

grna_regions = parse_gtf_file(GTF_FILE)

print(f"\nüîç Parsed GTF file:")
print(f"  Found gRNA annotations for {len(grna_regions)} minicircles")

# Merge overlapping regions
total_before = sum(len(regions) for regions in grna_regions.values())
for mini_id in grna_regions:
    grna_regions[mini_id] = merge_overlapping_regions(grna_regions[mini_id])
total_after = sum(len(regions) for regions in grna_regions.values())

print(f"\n  Total gRNA annotations: {total_before}")
print(f"  After merging overlaps: {total_after}")

# Show example
example_mini = list(grna_regions.keys())[0]
print(f"\n  Example (regions on {example_mini}):")
for start, end in grna_regions[example_mini][:3]:
    print(f"    {start:4d} - {end:4d} ({end-start} nt)")

print("\n✅ gRNA regions identified and ready for exclusion!")
print("\n" + "="*80)

---
## Stage 3: Generate Length-Matched Negative Examples

### ⚠️ Critical: Avoiding Length Artifacts!

**Multi-source approach:**
1. **Minicircle non-gRNA regions** (50%): Sample from inter-gRNA regions
2. **Chimeric sequences** (30%): Combine fragments from different minicircles
3. **Composition-matched random** (20%): Generate with correct nucleotide frequencies
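The 50/30/20 allocation and the composition-matched source can be sketched in a few lines. This is an illustration under stated assumptions: `allocate_negatives` and `composition_matched_sequence` are hypothetical names, and the standard-library `random` module stands in for the `np.random` calls used in the notebook cells.

```python
import random


def allocate_negatives(n_total, minicircle_share=0.50, chimeric_share=0.30):
    """Split n_total into three counts; the random source absorbs rounding."""
    n_minicircle = int(n_total * minicircle_share)
    n_chimeric = int(n_total * chimeric_share)
    n_random = n_total - n_minicircle - n_chimeric
    return n_minicircle, n_chimeric, n_random


def composition_matched_sequence(positives, length, rng):
    """Draw a sequence whose expected base frequencies match the positives."""
    pooled = "".join(positives)
    # Raw counts work as weights; symbols outside ACGT are simply ignored
    weights = [pooled.count(nt) for nt in "ACGT"]
    return "".join(rng.choices("ACGT", weights=weights, k=length))


rng = random.Random(42)  # fixed seed for reproducibility
print(allocate_negatives(1001))
print(composition_matched_sequence(["AAAATT", "AATTTA"], 30, rng))
```

Assigning the remainder to the last source keeps the three counts summing exactly to the requested total, whatever the rounding.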

In [None]:
def get_non_grna_regions(minicircle_id: str, minicircle_length: int,
                         grna_coords: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Calculate non-gRNA regions (inter-gRNA spaces)."""
    if not grna_coords:
        return [(0, minicircle_length)]
    
    non_grna = []
    # Before first gRNA
    if grna_coords[0][0] > 0:
        non_grna.append((0, grna_coords[0][0]))
    # Between gRNAs
    for i in range(len(grna_coords) - 1):
        gap_start = grna_coords[i][1]
        gap_end = grna_coords[i+1][0]
        if gap_end > gap_start:
            non_grna.append((gap_start, gap_end))
    # After last gRNA
    if grna_coords[-1][1] < minicircle_length:
        non_grna.append((grna_coords[-1][1], minicircle_length))
    return non_grna


def generate_minicircle_negatives(minicircle_file: Path, 
                                   grna_regions_dict: Dict,
                                   target_lengths: List[int],
                                   n_samples: int) -> Dict[str, str]:
    """Generate negatives from minicircles, EXCLUDING gRNA regions."""
    # Load minicircles and calculate non-gRNA regions
    minicircles = []
    non_grna_regions = {}
    
    for record in SeqIO.parse(minicircle_file, "fasta"):
        mini_id = record.id
        seq = str(record.seq).upper().replace('U', 'T')
        minicircles.append((mini_id, seq))
        grna_coords = grna_regions_dict.get(mini_id, [])
        non_grna = get_non_grna_regions(mini_id, len(seq), grna_coords)
        non_grna_regions[mini_id] = non_grna
    
    print(f"  Loaded {len(minicircles)} minicircles")
    
    negatives = {}
    attempts = 0
    max_attempts = n_samples * 20
    
    while len(negatives) < n_samples and attempts < max_attempts:
        attempts += 1
        target_len = np.random.choice(target_lengths)
        mini_id, mini_seq = minicircles[np.random.randint(len(minicircles))]
        available_regions = non_grna_regions[mini_id]
        
        if not available_regions:
            continue
        
        region = available_regions[np.random.randint(len(available_regions))]
        region_start, region_end = region
        region_len = region_end - region_start
        
        if region_len < target_len:
            continue
        
        frag_start = np.random.randint(region_start, region_end - target_len + 1)
        fragment = mini_seq[frag_start:frag_start + target_len]
        
        # Quality filters
        if 'N' in fragment:
            continue
        if len(set(fragment)) == 1:
            continue
        
        neg_id = f"{mini_id}_nonGRNA_{frag_start}_{frag_start+target_len}"
        negatives[neg_id] = fragment
    
    return negatives


def generate_chimeric_negatives(minicircle_file: Path,
                                 grna_regions_dict: Dict,
                                 target_lengths: List[int],
                                 n_samples: int) -> Dict[str, str]:
    """Generate chimeric sequences by combining fragments from different minicircles."""
    minicircles = []
    non_grna_regions = {}
    
    for record in SeqIO.parse(minicircle_file, "fasta"):
        mini_id = record.id
        seq = str(record.seq).upper().replace('U', 'T')
        minicircles.append((mini_id, seq))
        grna_coords = grna_regions_dict.get(mini_id, [])
        non_grna = get_non_grna_regions(mini_id, len(seq), grna_coords)
        non_grna_regions[mini_id] = non_grna
    
    chimeric = {}
    
    for i in range(n_samples):
        target_len = np.random.choice(target_lengths)
        n_fragments = np.random.randint(2, 5)
        frag_lens = np.random.multinomial(target_len - n_fragments, 
                                           np.ones(n_fragments) / n_fragments) + 1
        
        fragments = []
        for frag_len in frag_lens:
            for _ in range(100):
                mini_id, mini_seq = minicircles[np.random.randint(len(minicircles))]
                available = non_grna_regions[mini_id]
                if not available:
                    continue
                region = available[np.random.randint(len(available))]
                if region[1] - region[0] >= frag_len:
                    start = np.random.randint(region[0], region[1] - frag_len + 1)
                    fragments.append(mini_seq[start:start + frag_len])
                    break
        
        if len(fragments) == n_fragments:
            combined = ''.join(fragments)
            if 'N' not in combined and len(set(combined)) > 1:
                chimeric[f'chimeric_{i}'] = combined
    
    return chimeric


def generate_random_negatives(positive_sequences: Dict[str, str],
                               target_lengths: List[int],
                               n_samples: int) -> Dict[str, str]:
    """Generate random sequences matching nucleotide composition."""
    # Calculate overall composition from positives
    all_seqs = ''.join(positive_sequences.values())
    nucs = list('ATGC')
    counts = np.array([all_seqs.count(nt) for nt in nucs], dtype=float)
    # Normalize over A/T/G/C only so probs sum to 1 even if other symbols (e.g. N) occur
    probs = counts / counts.sum()
    
    randoms = {}
    for i in range(n_samples):
        length = np.random.choice(target_lengths)
        seq = ''.join(np.random.choice(nucs, size=length, p=probs))
        randoms[f'random_{i}'] = seq
    
    return randoms

In [None]:
print("="*80)
print("STAGE 3: GENERATE LENGTH-MATCHED NEGATIVES")
print("="*80)

n_positives = len(positive_sequences)
n_total_negatives = n_positives  # 1:1 ratio

# Distribution: 50% minicircle, 30% chimeric, 20% random
n_minicircle = int(n_total_negatives * 0.50)
n_chimeric = int(n_total_negatives * 0.30)
n_random = n_total_negatives - n_minicircle - n_chimeric

print(f"\nüìä Negative sampling strategy:")
print(f"  Minicircle non-gRNA: {n_minicircle} (50%)")
print(f"  Chimeric: {n_chimeric} (30%)")
print(f"  Random: {n_random} (20%)")
print(f"  Total: {n_total_negatives}")

# Generate negatives
print("\n[1/3] Generating minicircle non-gRNA negatives...")
minicircle_negatives = generate_minicircle_negatives(
    MINICIRCLE_FILE, grna_regions, positive_lengths, n_minicircle
)
print(f"  Generated: {len(minicircle_negatives)}")

print("\n[2/3] Generating chimeric negatives...")
chimeric_negatives = generate_chimeric_negatives(
    MINICIRCLE_FILE, grna_regions, positive_lengths, n_chimeric
)
print(f"  Generated: {len(chimeric_negatives)}")

print("\n[3/3] Generating random negatives...")
random_negatives = generate_random_negatives(
    positive_sequences, positive_lengths, n_random
)
print(f"  Generated: {len(random_negatives)}")

# Combine all negatives
all_negatives = {**minicircle_negatives, **chimeric_negatives, **random_negatives}
negative_lengths = [len(seq) for seq in all_negatives.values()]

print(f"\n✅ Total negatives generated: {len(all_negatives):,}")
print("\n" + "="*80)

### Validate Length Matching (KS Test)
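The two-sample KS statistic is the largest vertical gap between the two empirical cumulative distribution functions. The sketch below computes only that statistic in pure Python to show what is being measured; `ks_statistic` is a hypothetical helper, and the p-value in this notebook comes from `scipy.stats.ks_2samp`.

```python
from bisect import bisect_right


def ks_statistic(a, b):
    """Largest absolute difference between the ECDFs of samples a and b."""
    a_sorted, b_sorted = sorted(a), sorted(b)

    def ecdf(sorted_vals, x):
        # Fraction of observations less than or equal to x
        return bisect_right(sorted_vals, x) / len(sorted_vals)

    # The gap can only change at observed values, so checking those suffices
    points = set(a_sorted) | set(b_sorted)
    return max(abs(ecdf(a_sorted, x) - ecdf(b_sorted, x)) for x in points)


# Toy length samples, invented for illustration
print(ks_statistic([40, 41, 42, 43], [40, 41, 42, 50]))
```

A p-value above 0.05 means the test found no significant difference between the length distributions; it does not prove the distributions are identical, so the histogram and ECDF plots remain worth inspecting.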

In [None]:
# Perform KS test
ks_stat, ks_pval = stats.ks_2samp(positive_lengths, negative_lengths)

print("="*80)
print("LENGTH DISTRIBUTION VALIDATION")
print("="*80)

print(f"\nüìà Kolmogorov-Smirnov Test:")
print(f"  KS statistic: {ks_stat:.4f}")
print(f"  p-value: {ks_pval:.4f}")

if ks_pval > 0.05:
    print(f"\n  ‚úÖ PASS: Distributions are statistically identical (p={ks_pval:.4f})")
    print(f"     ‚Üí No length leakage!")
else:
    print(f"\n  ‚ùå FAIL: Distributions differ (p={ks_pval:.4f})")
    print(f"     ‚Üí WARNING: May need to regenerate!")

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Histograms
axes[0].hist(positive_lengths, bins=30, alpha=0.5, label='Positive', 
            color='steelblue', edgecolor='black')
axes[0].hist(negative_lengths, bins=30, alpha=0.5, label='Negative', 
            color='coral', edgecolor='black')
axes[0].set_xlabel('Sequence Length (nt)')
axes[0].set_ylabel('Count')
axes[0].set_title('Length Distribution Comparison')
axes[0].legend()

# ECDF
pos_sorted = np.sort(positive_lengths)
neg_sorted = np.sort(negative_lengths)
pos_ecdf = np.arange(1, len(pos_sorted)+1) / len(pos_sorted)
neg_ecdf = np.arange(1, len(neg_sorted)+1) / len(neg_sorted)

axes[1].plot(pos_sorted, pos_ecdf, label='Positive', color='steelblue')
axes[1].plot(neg_sorted, neg_ecdf, label='Negative', color='coral')
axes[1].set_xlabel('Sequence Length (nt)')
axes[1].set_ylabel('ECDF')
axes[1].set_title(f'ECDF (KS p={ks_pval:.4f})')
axes[1].legend()

plt.tight_layout()
plt.savefig(PLOTS_DIR / 'length_distribution.png', dpi=150)
plt.show()

print("\n✓ Plot saved")
print("\n" + "="*80)

---
## Stage 4: Comprehensive Feature Extraction

Extract **120 biologically-informed features** from each sequence:
- Evidence-based initiation patterns (AAAA, GAAA, AGAA - NOT just ATATA!)
- Flexible anchor region detection
- Guiding region composition
- Terminal features
- K-mer and structural features

**CRITICAL: Total sequence length is never used as a feature!** (Region-level values such as `anchor_length` are still extracted.)
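Among the structural features, Shannon entropy is a convenient example of a measure computed from base frequencies rather than raw counts. A minimal standalone sketch follows; `shannon_entropy` is a hypothetical name for illustration and not the extractor's own method.

```python
from collections import Counter
from math import log2


def shannon_entropy(seq: str) -> float:
    """Entropy in bits of the symbol distribution in seq (0.0 for empty input)."""
    if not seq:
        return 0.0
    n = len(seq)
    # Sum of -p * log2(p) over the symbols that actually occur
    return -sum((count / n) * log2(count / n) for count in Counter(seq).values())


for toy in ["AAAA", "AATT", "ACGT"]:
    print(toy, shannon_entropy(toy))
```

For a four-letter alphabet the value ranges from 0 bits (a homopolymer) to 2 bits (all four bases equally frequent), so it is comparable across sequences of different lengths.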

In [None]:
class EnhancedGrnaFeatureExtractor:
    """
    Enhanced feature extractor based on empirical analysis of Cooper et al. 2022 data.
    
    Key improvements:
    1. Multiple initiation patterns (AAAA 39.7%, GAAA 33%, AGAA 12.1% - NOT just ATATA!)
    2. Flexible anchor region detection (position 4-6, length 8-12)
    3. Evidence-based thresholds from real data
    4. NO WHOLE-SEQUENCE LENGTH FEATURE (critical for avoiding artifacts!)
    
    Total features: 120
    """
    
    def __init__(self):
        # EVIDENCE-BASED initiation patterns
        self.initiation_patterns = {
            'ATATA': 'ATATA',
            'AWAHH': r'A[AT]A[ACT][ACT]',
            'ATRTR': r'AT[AG]T[AG]',
            'AWAWA': r'A[AT]A[AT]A',
            'AAAA': 'AAAA',        # 39.7% in real data!
            'GAAA': 'GAAA',        # 33.0% in real data!
            'AGAA': 'AGAA',        # 12.1% in real data!
            'TAAA': 'TAAA',
            'CAAA': 'CAAA',
            'XAAA': r'[ATGC]AAA',
            'AXAA': r'A[ATGC]AA',
        }
        self.important_3mers = ['AAA', 'ATA', 'TAT', 'TTT', 'AAT', 'ATT', 'GAA', 'AGA']
        self.important_4mers = ['ATAT', 'TATA', 'AAAA', 'TTTT', 'AAAG', 'AAGA', 'GAAA', 'AGAA']
        self.anchor_start_range = (4, 6)
        self.anchor_length_range = (8, 12)
    
    def extract_features(self, sequence):
        features = {}
        seq = sequence.upper().replace('U', 'T')
        
        features.update(self._extract_initiation_features(seq))
        features.update(self._extract_anchor_features(seq))
        features.update(self._extract_guiding_features(seq))
        features.update(self._extract_terminal_features(seq))
        features.update(self._extract_kmer_features(seq))
        features.update(self._extract_structural_features(seq))
        features.update(self._extract_positional_features(seq))
        features.update(self._extract_dinucleotide_features(seq))
        features.update(self._extract_composition_features(seq))
        features.update(self._extract_advanced_features(seq))
        features.update(self._extract_meta_features(seq, features))
        
        return features
    
    def _extract_initiation_features(self, seq):
        features = {}
        init_region = seq[:6] if len(seq) >= 6 else seq
        
        for pattern_name, pattern in self.initiation_patterns.items():
            has_pattern = bool(re.match(pattern, init_region))
            features[f'init_has_{pattern_name}'] = float(has_pattern)
        
        features['init_starts_A'] = float(seq[0] == 'A') if len(seq) > 0 else 0.0
        features['init_starts_G'] = float(seq[0] == 'G') if len(seq) > 0 else 0.0
        features['init_starts_T'] = float(seq[0] == 'T') if len(seq) > 0 else 0.0
        features['init_starts_C'] = float(seq[0] == 'C') if len(seq) > 0 else 0.0
        features['init_starts_purine'] = float(seq[0] in 'AG') if len(seq) > 0 else 0.0
        
        first4 = seq[:4] if len(seq) >= 4 else seq
        if len(first4) > 0:
            features['init_4_A_count'] = first4.count('A')
            features['init_4_T_count'] = first4.count('T')
            features['init_4_G_count'] = first4.count('G')
            features['init_4_C_count'] = first4.count('C')
            features['init_4_A_rich'] = float(first4.count('A') >= 3)
        else:
            for f in ['init_4_A_count', 'init_4_T_count', 'init_4_G_count', 'init_4_C_count', 'init_4_A_rich']:
                features[f] = 0.0
        
        total_patterns = sum(1 for p in self.initiation_patterns 
                           if features.get(f'init_has_{p}', 0) > 0)
        features['init_pattern_count'] = float(total_patterns)
        features['init_any_known_pattern'] = float(total_patterns > 0)
        
        return features
    
    def _extract_anchor_features(self, seq):
        features = {}
        best_anchor = None
        best_score = -1
        best_start = 0
        
        for start in range(self.anchor_start_range[0], 
                          min(self.anchor_start_range[1] + 1, len(seq) - 5)):
            for length in range(self.anchor_length_range[0], 
                               min(self.anchor_length_range[1] + 1, len(seq) - start + 1)):
                anchor = seq[start:start + length]
                if len(anchor) < 5:
                    continue
                ac_content = (anchor.count('A') + anchor.count('C')) / len(anchor)
                g_content = anchor.count('G') / len(anchor)
                score = ac_content - g_content
                if score > best_score:
                    best_score = score
                    best_anchor = anchor
                    best_start = start
        
        if best_anchor and len(best_anchor) > 0:
            anchor = best_anchor
            for nt in 'ATGC':
                features[f'anchor_{nt}_freq'] = anchor.count(nt) / len(anchor)
            features['anchor_AT_freq'] = (anchor.count('A') + anchor.count('T')) / len(anchor)
            features['anchor_GC_freq'] = (anchor.count('G') + anchor.count('C')) / len(anchor)
            features['anchor_purine_freq'] = (anchor.count('A') + anchor.count('G')) / len(anchor)
            features['anchor_AC_content'] = (anchor.count('A') + anchor.count('C')) / len(anchor)
            features['anchor_length'] = float(len(anchor))
            features['anchor_start_pos'] = float(best_start)
            features['anchor_G_depleted'] = float(features['anchor_G_freq'] < 0.15)
            features['anchor_AC_rich'] = float(features['anchor_AC_content'] > 0.60)
            features['anchor_AC_very_rich'] = float(features['anchor_AC_content'] > 0.70)
            init_anchor_len = best_start + len(anchor)
            features['init_anchor_total_len'] = float(init_anchor_len)
            features['in_molecular_ruler_range'] = float(15 <= init_anchor_len <= 19)
            features['anchor_entropy'] = self._calculate_entropy(anchor)
            features['anchor_unique_dinucs'] = float(len(set(
                anchor[i:i+2] for i in range(len(anchor)-1)
            ))) if len(anchor) > 1 else 0.0
        else:
            for ft in ['anchor_A_freq', 'anchor_T_freq', 'anchor_G_freq', 'anchor_C_freq',
                      'anchor_AT_freq', 'anchor_GC_freq', 'anchor_purine_freq', 'anchor_AC_content',
                      'anchor_length', 'anchor_start_pos', 'anchor_G_depleted', 'anchor_AC_rich',
                      'anchor_AC_very_rich', 'init_anchor_total_len', 'in_molecular_ruler_range',
                      'anchor_entropy', 'anchor_unique_dinucs']:
                features[ft] = 0.0
        return features
    
    def _extract_guiding_features(self, seq):
        features = {}
        guide_start = min(15, len(seq))
        guide = seq[guide_start:]
        
        if len(guide) > 0:
            for nt in 'ATGC':
                features[f'guide_{nt}_freq'] = guide.count(nt) / len(guide)
            features['guide_AT_freq'] = (guide.count('A') + guide.count('T')) / len(guide)
            features['guide_GC_freq'] = (guide.count('G') + guide.count('C')) / len(guide)
            features['guide_A_elevated'] = float(features['guide_A_freq'] > 0.40)
            features['guide_A_content_high'] = float(features['guide_A_freq'] > 0.45)
            purine_freq = (guide.count('A') + guide.count('G')) / len(guide)
            features['guide_purine_freq'] = purine_freq
            features['guide_purine_rich'] = float(purine_freq > 0.55)
            features['guide_pyrimidine_freq'] = (guide.count('T') + guide.count('C')) / len(guide)
            features['guide_C_count'] = float(guide.count('C'))
            features['guide_T_count'] = float(guide.count('T'))
            features['guide_edit_potential'] = (guide.count('C') + guide.count('T')) / len(guide)
        else:
            for ft in ['guide_A_freq', 'guide_T_freq', 'guide_G_freq', 'guide_C_freq',
                      'guide_AT_freq', 'guide_GC_freq', 'guide_A_elevated', 'guide_A_content_high',
                      'guide_purine_freq', 'guide_purine_rich', 'guide_pyrimidine_freq',
                      'guide_C_count', 'guide_T_count', 'guide_edit_potential']:
                features[ft] = 0.0
        return features
    
    def _extract_terminal_features(self, seq):
        features = {}
        if len(seq) > 0:
            features['ends_with_T'] = float(seq[-1] == 'T')
            features['ends_with_A'] = float(seq[-1] == 'A')
            features['ends_with_G'] = float(seq[-1] == 'G')
            features['ends_with_C'] = float(seq[-1] == 'C')
            last3 = seq[-3:] if len(seq) >= 3 else seq
            features['last3_T_count'] = float(last3.count('T'))
            features['last3_A_count'] = float(last3.count('A'))
            features['last3_TT'] = float(last3.endswith('TT')) if len(last3) >= 2 else 0.0
            features['last3_AT'] = float('AT' in last3) if len(last3) >= 2 else 0.0
            last5 = seq[-5:] if len(seq) >= 5 else seq
            if len(last5) > 0:
                features['last5_T_freq'] = last5.count('T') / len(last5)
                features['last5_A_freq'] = last5.count('A') / len(last5)
                features['last5_AT_freq'] = (last5.count('A') + last5.count('T')) / len(last5)
            else:
                features['last5_T_freq'] = features['last5_A_freq'] = features['last5_AT_freq'] = 0.0
            features['ends_poly_T_2'] = float(seq[-2:] == 'TT') if len(seq) >= 2 else 0.0
            features['ends_poly_T_3'] = float(seq[-3:] == 'TTT') if len(seq) >= 3 else 0.0
        else:
            for ft in ['ends_with_T', 'ends_with_A', 'ends_with_G', 'ends_with_C',
                      'last3_T_count', 'last3_A_count', 'last3_TT', 'last3_AT',
                      'last5_T_freq', 'last5_A_freq', 'last5_AT_freq',
                      'ends_poly_T_2', 'ends_poly_T_3']:
                features[ft] = 0.0
        return features
    
    def _extract_kmer_features(self, seq):
        features = {}
        n = len(seq)
        if n < 3:
            for kmer in self.important_3mers:
                features[f'kmer3_{kmer}_count'] = 0.0
                features[f'kmer3_{kmer}_freq'] = 0.0
            for kmer in self.important_4mers:
                features[f'kmer4_{kmer}_present'] = 0.0
            return features
        
        kmer3_counts = Counter(seq[i:i+3] for i in range(n-2))
        total_3mers = n - 2
        for kmer in self.important_3mers:
            count = kmer3_counts.get(kmer, 0)
            features[f'kmer3_{kmer}_count'] = float(count)
            features[f'kmer3_{kmer}_freq'] = count / total_3mers if total_3mers > 0 else 0.0
        
        if n >= 4:
            for kmer in self.important_4mers:
                features[f'kmer4_{kmer}_present'] = float(kmer in seq)
        else:
            for kmer in self.important_4mers:
                features[f'kmer4_{kmer}_present'] = 0.0
        return features
    
    def _extract_structural_features(self, seq):
        features = {}
        n = len(seq)
        if n == 0:
            features['entropy'] = 0.0
            features['complexity_ratio'] = 0.0
            features['max_homopolymer'] = 0.0
            features['n_homopolymers_3plus'] = 0.0
            return features
        
        features['entropy'] = self._calculate_entropy(seq)
        if n >= 3:
            unique_3mers = len(set(seq[i:i+3] for i in range(n-2)))
            possible_3mers = min(n - 2, 64)
            features['complexity_ratio'] = unique_3mers / possible_3mers if possible_3mers > 0 else 0.0
        else:
            features['complexity_ratio'] = 0.0
        
        max_run = 0
        n_runs = 0
        current_run = 1
        for i in range(1, n):
            if seq[i] == seq[i-1]:
                current_run += 1
            else:
                if current_run >= 3:
                    n_runs += 1
                max_run = max(max_run, current_run)
                current_run = 1
        if current_run >= 3:
            n_runs += 1
        max_run = max(max_run, current_run)
        features['max_homopolymer'] = float(max_run)
        features['n_homopolymers_3plus'] = float(n_runs)
        return features
    
    def _extract_positional_features(self, seq):
        features = {}
        n = len(seq)
        if n == 0:
            features['first_A_pos_rel'] = 0.0
            features['first_G_pos_rel'] = 0.0
            features['last_T_pos_rel'] = 0.0
            return features
        for nt in ['A', 'G']:
            pos = seq.find(nt)
            features[f'first_{nt}_pos_rel'] = pos / n if pos >= 0 else 1.0
        for nt in ['T']:
            pos = seq.rfind(nt)
            features[f'last_{nt}_pos_rel'] = pos / n if pos >= 0 else 0.0
        return features
    
    def _extract_dinucleotide_features(self, seq):
        features = {}
        n = len(seq)
        important_dinucs = ['AA', 'AT', 'TA', 'TT', 'GC', 'CG', 'AC', 'CA']
        if n < 2:
            for dn in important_dinucs:
                features[f'dinuc_{dn}_freq'] = 0.0
            features['dinuc_bias_AT'] = 0.0
            return features
        
        dinuc_counts = Counter(seq[i:i+2] for i in range(n-1))
        total_dinucs = n - 1
        for dn in important_dinucs:
            features[f'dinuc_{dn}_freq'] = dinuc_counts.get(dn, 0) / total_dinucs
        at_dinucs = sum(dinuc_counts.get(d, 0) for d in ['AA', 'AT', 'TA', 'TT'])
        features['dinuc_bias_AT'] = at_dinucs / total_dinucs
        return features
    
    def _extract_composition_features(self, seq):
        features = {}
        n = len(seq)
        if n == 0:
            for nt in 'ATGC':
                features[f'global_{nt}_freq'] = 0.0
            features['global_AT_content'] = 0.0
            features['global_GC_content'] = 0.0
            features['global_purine_content'] = 0.0
            return features
        for nt in 'ATGC':
            features[f'global_{nt}_freq'] = seq.count(nt) / n
        features['global_AT_content'] = (seq.count('A') + seq.count('T')) / n
        features['global_GC_content'] = (seq.count('G') + seq.count('C')) / n
        features['global_purine_content'] = (seq.count('A') + seq.count('G')) / n
        return features
    
    def _extract_advanced_features(self, seq):
        features = {}
        n = len(seq)
        if n == 0:
            features['skew_AT'] = 0.0
            features['skew_GC'] = 0.0
            features['balance_ratio'] = 0.0
            return features
        a, t, g, c = seq.count('A'), seq.count('T'), seq.count('G'), seq.count('C')
        features['skew_AT'] = (a - t) / (a + t) if a + t > 0 else 0.0
        features['skew_GC'] = (g - c) / (g + c) if g + c > 0 else 0.0
        freqs = [a, t, g, c]
        features['balance_ratio'] = min(freqs) / max(freqs) if max(freqs) > 0 else 0.0
        return features
    
    def _extract_meta_features(self, seq, existing_features):
        features = {}
        init_score = existing_features.get('init_any_known_pattern', 0)
        anchor_score = existing_features.get('anchor_AC_rich', 0)
        guide_score = existing_features.get('guide_A_elevated', 0)
        terminal_score = existing_features.get('ends_with_T', 0)
        features['grna_signature_count'] = init_score + anchor_score + guide_score + terminal_score
        features['grna_signature_all'] = float(
            init_score > 0 and anchor_score > 0 and guide_score > 0 and terminal_score > 0
        )
        features['init_anchor_quality'] = (
            existing_features.get('init_any_known_pattern', 0) * 0.3 +
            existing_features.get('anchor_AC_rich', 0) * 0.4 +
            existing_features.get('in_molecular_ruler_range', 0) * 0.3
        )
        return features
    
    def _calculate_entropy(self, seq):
        if len(seq) == 0:
            return 0.0
        counts = Counter(seq)
        probs = [count / len(seq) for count in counts.values()]
        return -sum(p * np.log2(p) for p in probs if p > 0)
    
    def get_feature_names(self):
        dummy = self.extract_features('AAAAGCACTTTAAATTGCGCGCGCGCGCGCGCGT')
        return list(dummy.keys())

In [None]:
print("="*80)
print("STAGE 4: FEATURE EXTRACTION")
print("="*80)

extractor = EnhancedGrnaFeatureExtractor()
print(f"\n✓ Feature extractor initialized")
print(f"  Total features: {len(extractor.get_feature_names())}")

print("\n" + "-"*40)
print("Extracting features from all sequences...")
print("-"*40)

data_rows = []

# Add positives
print(f"\n[1/2] Processing {len(positive_sequences):,} positive sequences...")
for i, (seq_id, seq) in enumerate(positive_sequences.items()):
    if i % 300 == 0:
        print(f"  Progress: {i}/{len(positive_sequences)}")
    features = extractor.extract_features(seq)
    row = {
        'sequence_id': seq_id,
        'sequence': seq,
        'length': len(seq),  # Stored but NOT used as feature!
        'label': 1,
        'source': 'canonical_gRNA',
        **features
    }
    data_rows.append(row)

# Add negatives
print(f"\n[2/2] Processing {len(all_negatives):,} negative sequences...")
for i, (seq_id, seq) in enumerate(all_negatives.items()):
    if i % 300 == 0:
        print(f"  Progress: {i}/{len(all_negatives)}")
    
    if 'nonGRNA' in seq_id:
        source = 'minicircle_nonGRNA'
    elif 'chimeric' in seq_id:
        source = 'chimeric'
    else:
        source = 'random'
    
    features = extractor.extract_features(seq)
    row = {
        'sequence_id': seq_id,
        'sequence': seq,
        'length': len(seq),
        'label': 0,
        'source': source,
        **features
    }
    data_rows.append(row)

# Create DataFrame
df_all = pd.DataFrame(data_rows)

# Identify feature columns
metadata_cols = ['sequence_id', 'sequence', 'length', 'label', 'source']
feature_cols = [c for c in df_all.columns if c not in metadata_cols]

print(f"\n{'='*60}")
print(f"✓ Feature extraction complete!")
print(f"{'='*60}")
print(f"\n📊 Dataset summary:")
print(f"  Total samples: {len(df_all):,}")
print(f"  Positive (gRNA): {sum(df_all['label']==1):,}")
print(f"  Negative: {sum(df_all['label']==0):,}")
print(f"  Features extracted: {len(feature_cols)}")

# Sanity check: 'length' is listed in metadata_cols, so it should never appear in feature_cols
if 'length' in feature_cols:
    print("\n  ❌ WARNING: 'length' in features - REMOVE IT!")
else:
    print("\n  ✓ Length correctly EXCLUDED from features")

print("\n" + "="*80)

---
## Stage 5: Quality Control

Critical checks:
1. No NaN/inf values
2. Length not in features (would cause leakage)
3. Class balance
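
The three checks above can be sketched as one reusable function. This is a minimal illustration on toy data, not the notebook's own implementation: the helper name `run_quality_checks` and the toy DataFrame are assumptions introduced here, and only numeric columns are tested for infinities.

```python
import numpy as np
import pandas as pd


def run_quality_checks(df, feature_cols, label_col="label"):
    """Hypothetical helper: summarise the NaN/inf, length, and balance checks."""
    # Only numeric columns can hold inf; non-numeric ones are skipped for that test.
    numeric = df[feature_cols].select_dtypes(include=[np.number])
    nan_cols = [c for c in feature_cols if df[c].isna().any()]
    inf_cols = [c for c in numeric.columns if np.isinf(numeric[c]).any()]

    counts = df[label_col].value_counts()
    n_pos, n_neg = int(counts.get(1, 0)), int(counts.get(0, 0))
    larger = max(n_pos, n_neg)

    return {
        "nan_features": nan_cols,
        "inf_features": inf_cols,
        "length_excluded": "length" not in feature_cols,
        "balance_ratio": min(n_pos, n_neg) / larger if larger else 0.0,
    }


# Toy data (made up for illustration): one NaN and one inf planted on purpose.
toy = pd.DataFrame({
    "label": [1, 1, 0, 0],
    "length": [40, 42, 41, 39],
    "global_A_freq": [0.46, 0.44, 0.28, np.nan],
    "skew_AT": [0.1, np.inf, -0.2, 0.0],
})
toy_features = ["global_A_freq", "skew_AT"]

report = run_quality_checks(toy, toy_features)
print(report)
```

Returning a dict rather than printing keeps the results easy to assert on or to store alongside other dataset metadata.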

In [None]:
print("="*80)
print("STAGE 5: QUALITY CONTROL")
print("="*80)

# Check for NaN/inf
nan_features = [c for c in feature_cols if df_all[c].isna().any()]
inf_features = [c for c in feature_cols if np.isinf(df_all[c]).any()]

print("\n🔍 Data quality checks:")
print(f"  Features with NaN: {len(nan_features)}")
print(f"  Features with inf: {len(inf_features)}")

if nan_features:
    print(f"  ⚠️ Found NaN in: {nan_features[:5]}...")
if inf_features:
    print(f"  ⚠️ Found inf in: {inf_features[:5]}...")

# Check length not in features
if 'length' in feature_cols:
    print("\n  ❌ ERROR: 'length' is in feature columns!")
    print("     This will cause the model to learn length artifacts!")
else:
    print("\n  ✅ Length correctly excluded from features")

# Class balance
n_pos = sum(df_all['label'] == 1)
n_neg = sum(df_all['label'] == 0)
balance_ratio = min(n_pos, n_neg) / max(n_pos, n_neg)

print(f"\n📊 Class distribution:")
print(f"  Positive: {n_pos:,} ({n_pos/len(df_all)*100:.1f}%)")
print(f"  Negative: {n_neg:,} ({n_neg/len(df_all)*100:.1f}%)")
print(f"  Balance ratio: {balance_ratio:.3f}")

if balance_ratio < 0.9:
    print("  ⚠️ Classes are imbalanced!")
else:
    print("  ✅ Classes are balanced")

print("\n✅ Quality control complete!")
print("\n" + "="*80)

---
## Stage 6: Train/Val/Test Split

70/15/15 split, stratified by label and source (combined into a temporary `strat_group` column).
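
The 70/15/15 proportions come from two chained calls: a 30% hold-out, then an even split of that hold-out (0.30 × 0.50 = 0.15 each). The sketch below checks this arithmetic on a made-up table. The toy DataFrame and its group sizes are assumptions for illustration only; the column names mirror those used in this notebook.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy dataset (hypothetical sizes): 100 positives, 100 negatives from three sources.
toy = pd.DataFrame({
    "label": [1] * 100 + [0] * 100,
    "source": (["canonical_gRNA"] * 100 + ["minicircle_nonGRNA"] * 50
               + ["chimeric"] * 30 + ["random"] * 20),
})
# Stratify on label and source together so every split keeps the same mix.
toy["strat_group"] = toy["label"].astype(str) + "_" + toy["source"]

# Step 1: hold out 30% of the rows.
train, temp = train_test_split(
    toy, test_size=0.30, stratify=toy["strat_group"], random_state=42
)
# Step 2: split the hold-out evenly into validation and test.
val, test = train_test_split(
    temp, test_size=0.50, stratify=temp["strat_group"], random_state=42
)

for name, part in [("train", train), ("val", val), ("test", test)]:
    print(name, len(part), f"{len(part) / len(toy):.2f}")

# The three index sets should be disjoint, so no row lands in two splits.
overlap = (set(train.index) & set(val.index)) | (set(train.index) & set(test.index)) \
    | (set(val.index) & set(test.index))
print("overlapping rows:", len(overlap))
```

Note that stratified splitting requires every `strat_group` value to have at least two rows in each call, so very small source categories may need to be merged first.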

In [None]:
print("="*80)
print("STAGE 6: TRAIN/VAL/TEST SPLIT")
print("="*80)

# Create stratification column
df_all['strat_group'] = df_all['label'].astype(str) + '_' + df_all['source']

# First split: 70% train, 30% temp
print("\n[1/2] Splitting train vs temp...")
train_df, temp_df = train_test_split(
    df_all,
    test_size=0.30,
    stratify=df_all['strat_group'],
    random_state=42
)
print(f"  Train: {len(train_df):,}")
print(f"  Temp:  {len(temp_df):,}")

# Second split: 50/50 of temp → val and test
print("\n[2/2] Splitting val vs test...")
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.50,
    stratify=temp_df['strat_group'],
    random_state=42
)
print(f"  Val:   {len(val_df):,}")
print(f"  Test:  {len(test_df):,}")

# Verify
print("\n📊 Final split:")
total = len(df_all)
print(f"  Train: {len(train_df)/total*100:.1f}%")
print(f"  Val:   {len(val_df)/total*100:.1f}%")
print(f"  Test:  {len(test_df)/total*100:.1f}%")

# Class balance
for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]:
    pos = sum(df['label']==1)
    neg = sum(df['label']==0)
    print(f"\n  {name}: pos={pos:,} neg={neg:,} balance={min(pos,neg)/max(pos,neg):.3f}")

# Drop stratification column
train_df = train_df.drop('strat_group', axis=1)
val_df = val_df.drop('strat_group', axis=1)
test_df = test_df.drop('strat_group', axis=1)

print("\n✅ Splitting complete!")
print("\n" + "="*80)

---
## Stage 7: Export Datasets
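
The exported CSVs keep metadata columns such as `length` next to the features, so downstream code should select model inputs from `feature_names.txt` rather than from all columns. The sketch below shows one way to do that. The `load_split` helper and the tiny temporary files are hypothetical and exist only to make the example self-contained; only the file names match the ones written by this stage.

```python
import tempfile
from pathlib import Path

import pandas as pd


def load_split(csv_path, feature_file):
    """Hypothetical loader: return (X, y) restricted to the exported feature names."""
    feature_names = Path(feature_file).read_text().splitlines()
    df = pd.read_csv(csv_path)
    return df[feature_names], df["label"]


with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    # Stand-in for train_data.csv: metadata columns followed by two features.
    pd.DataFrame({
        "sequence_id": ["g1", "n1"],
        "sequence": ["AAAAGCACT", "GCGCGCTTA"],
        "length": [9, 9],
        "label": [1, 0],
        "source": ["canonical_gRNA", "random"],
        "global_A_freq": [0.56, 0.11],
        "ends_with_T": [1, 0],
    }).to_csv(tmp / "train_data.csv", index=False)
    # Stand-in for feature_names.txt: one feature name per line.
    (tmp / "feature_names.txt").write_text("global_A_freq\nends_with_T\n")

    X_train, y_train = load_split(tmp / "train_data.csv", tmp / "feature_names.txt")

print(list(X_train.columns))
```

Selecting columns by the saved feature list keeps `length` and the other metadata out of the model inputs even if the CSV layout changes later.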

In [None]:
print("="*80)
print("STAGE 7: EXPORT DATASETS")
print("="*80)

# Save CSV files
print("\nSaving datasets...")

train_file = PROCESSED_DIR / 'train_data.csv'
val_file = PROCESSED_DIR / 'val_data.csv'
test_file = PROCESSED_DIR / 'test_data.csv'

train_df.to_csv(train_file, index=False)
print(f"  ✓ {train_file}")

val_df.to_csv(val_file, index=False)
print(f"  ✓ {val_file}")

test_df.to_csv(test_file, index=False)
print(f"  ✓ {test_file}")

# Save feature names
feature_file = PROCESSED_DIR / 'feature_names.txt'
with open(feature_file, 'w') as f:
    for feat in feature_cols:
        f.write(feat + '\n')
print(f"  ✓ {feature_file}")

# Save metadata
metadata = {
    'creation_date': pd.Timestamp.now().isoformat(),
    'notebook': '2_data_preparation_comprehensive.ipynb',
    'source_files': {
        'positive': str(GRNA_FILE),
        'minicircle': str(MINICIRCLE_FILE),
        'gtf': str(GTF_FILE)
    },
    'negative_strategy': {
        'minicircle_nonGRNA': f'{n_minicircle} (50%)',
        'chimeric': f'{n_chimeric} (30%)',
        'random': f'{n_random} (20%)'
    },
    'total_samples': len(df_all),
    'n_features': len(feature_cols),
    'n_positives': int(sum(df_all['label']==1)),
    'n_negatives': int(sum(df_all['label']==0)),
    'splits': {
        'train': len(train_df),
        'val': len(val_df),
        'test': len(test_df)
    },
    'quality_checks': {
        'length_excluded': 'length' not in feature_cols,
        'ks_test_pval': float(ks_pval),
        'class_balance': float(balance_ratio),
        'no_nan': len(nan_features) == 0,
        'no_inf': len(inf_features) == 0
    }
}

summary_file = PROCESSED_DIR / 'dataset_summary.json'
with open(summary_file, 'w') as f:
    json.dump(metadata, f, indent=2)
print(f"  ✓ {summary_file}")

print("\n" + "="*80)
print("✅ DATA PREPARATION COMPLETE!")
print("="*80)

print("\n📁 Output files:")
for f in [train_file, val_file, test_file, feature_file, summary_file]:
    print(f"  {f}")

print("\n📊 Summary:")
print(f"  Total: {len(df_all):,} samples")
print(f"  Features: {len(feature_cols)}")
print(f"  Train: {len(train_df):,}")
print(f"  Val: {len(val_df):,}")
print(f"  Test: {len(test_df):,}")

print("\n✅ Ready for training!")
print("\n" + "="*80)