In [1]:
"""
DeepQuant V2: Gene-Aware Transcriptome Quantification with Hierarchical Learning
================================================================================

A next-generation transcriptome quantification tool that leverages gene family relationships
for improved isoform disambiguation using transformer architectures and probabilistic modeling.

Key Features:
- Gene-aware hierarchical modeling
- Multi-level contrastive learning  
- Uncertainty-aware assignment strategies
- Integration with classical statistical methods
- Comprehensive evaluation framework

Version: 2.0
"""



In [2]:
import json
import logging
import os
import time
import warnings
from collections import defaultdict, Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional, Set

import faiss
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from Bio import SeqIO
from scipy.optimize import minimize
from scipy.stats import entropy
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

warnings.filterwarnings('ignore')

### Configuration and Data Structures  

In [3]:
@dataclass
class DeepQuantConfig:
    """Configuration class for DeepQuant V2.

    A flat bag of tunables shared by the parser, dataset, model, loss, vector
    store and assignment engine. Field order is part of the positional
    constructor signature, so new fields should be appended, not inserted.
    """
    
    # Model architecture
    # Model id handed to AutoTokenizer.from_pretrained / AutoModel.from_pretrained.
    model_name: str = "zhihan1996/DNABERT-2-117M"
    # Output width of the isoform projection head.
    embedding_dim: int = 256
    # Output width of the gene projection head.
    gene_embedding_dim: int = 128
    # Dropout probability used by the encoder layers, attention and projection heads.
    dropout: float = 0.1
    # Head count for the transformer encoder layers; MultiScaleAttention uses half of it per branch.
    num_attention_heads: int = 8
    # Feed-forward width of the encoder layers = base hidden size * this multiplier.
    hidden_dim_multiplier: int = 4
    
    # Training parameters
    # batch_size is also the batch size used when embedding transcripts for the vector store.
    batch_size: int = 16
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5
    num_epochs: int = 10
    warmup_steps: int = 1000
    
    # Gene-aware learning
    # Weights of the gene-level and isoform-level terms in GeneAwareContrastiveLoss.
    gene_contrastive_weight: float = 0.3
    isoform_contrastive_weight: float = 0.7
    # NOTE(review): not read by any code visible in this section — confirm where it is consumed.
    gene_hierarchy_weight: float = 0.2
    # InfoNCE temperature for the gene-level loss; the isoform-level loss uses half of it.
    temperature: float = 0.1
    
    # Assignment strategy
    assignment_strategy: str = "hierarchical"  # "hierarchical", "joint", "ensemble"
    # Initial thresholds copied into UncertaintyAwareAssignment at construction.
    gene_similarity_threshold: float = 0.7
    isoform_similarity_threshold: float = 0.85
    uncertainty_threshold: float = 0.8
    
    # Computational
    # Token length every read/transcript is padded or truncated to by the tokenizer.
    max_sequence_length: int = 512
    use_mixed_precision: bool = True
    num_workers: int = 4
    # Evaluated once, when this class body is executed — not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Output and logging
    output_dir: str = "./deepquant_v2_results"
    log_level: str = "INFO"
    save_attention_maps: bool = False

@dataclass 
class TranscriptInfo:
    """Information about a single transcript.

    Instances are created by ``TranscriptomeParser.load_transcriptome`` from
    one FASTA record and its parsed pipe-delimited header.
    """
    # Header field 0; used as the key of the global transcript dictionary.
    transcript_id: str
    # Header field 1.
    gene_id: str
    # Header field 5; transcripts are grouped into families by this value.
    gene_name: str
    # Record sequence, upper-cased by the parser.
    sequence: str
    # len() of the record sequence at load time.
    length: int
    # Header field 7 when present, otherwise the literal 'unknown'.
    transcript_type: str
    
@dataclass
class GeneFamilyInfo:
    """All transcripts (isoforms) that share one gene name.

    ``transcript_ids`` and ``num_isoforms`` are derived once in
    ``__post_init__``; they are a snapshot and do not follow later mutations
    of ``transcripts``.
    """
    gene_name: str
    gene_id: str
    # transcript_id -> TranscriptInfo for every isoform of this gene
    transcripts: Dict[str, TranscriptInfo]
    similarity_matrix: Optional[np.ndarray] = None
    complexity_score: float = 0.0

    def __post_init__(self):
        # Materialise the key order once so index-based lookups stay stable.
        isoform_ids = [transcript_id for transcript_id in self.transcripts]
        self.transcript_ids = isoform_ids
        self.num_isoforms = len(isoform_ids)

### Data Loading and Gene Family Analysis

In [4]:
class TranscriptomeParser:
    """Parse a transcriptome FASTA with pipe-delimited headers and group transcripts by gene name.

    The header layout relied on is positional: field 0 = transcript id,
    field 1 = gene id, field 5 = gene name, field 7 = transcript type.
    """

    # A header needs at least this many '|' fields so that the gene name (index 5) exists.
    MIN_HEADER_FIELDS = 6
    # Records shorter than this many bases are skipped when loading.
    MIN_SEQUENCE_LENGTH = 50
    # Gene families with MORE isoforms than this are reported under 'complex_genes'.
    COMPLEX_GENE_ISOFORMS = 5

    def __init__(self, config: DeepQuantConfig):
        """Store the shared configuration and create a module-level logger."""
        self.config = config
        self.logger = logging.getLogger(__name__)

    def parse_fasta_header(self, header: str) -> Optional[Dict[str, str]]:
        """Extract the essential fields from a pipe-delimited FASTA header.

        Args:
            header: Record identifier, with or without the leading '>'.

        Returns:
            A dict with keys 'transcript_id', 'gene_name', 'gene_id' and
            'transcript_type', or ``None`` (after logging a warning) when the
            header has fewer than ``MIN_HEADER_FIELDS`` fields or an empty
            transcript id / gene name.
        """
        # Tolerate a raw header line: drop surrounding whitespace and the '>' marker.
        fields = header.strip().lstrip('>').split('|')

        if len(fields) < self.MIN_HEADER_FIELDS:
            self.logger.warning(f"Header missing essential fields: {header}")
            return None

        transcript_id = fields[0]
        gene_name = fields[5]
        if not transcript_id or not gene_name:
            # An empty gene name would silently merge unrelated transcripts
            # into one '' family, so treat it like a malformed header.
            self.logger.warning(f"Header missing essential fields: {header}")
            return None

        # Headers commonly end with a trailing '|', which yields empty fields.
        has_type = len(fields) > 7 and bool(fields[7])

        return {
            'transcript_id': transcript_id,
            'gene_name': gene_name,
            'gene_id': fields[1] or transcript_id,  # Fallback to transcript_id
            'transcript_type': fields[7] if has_type else 'unknown'
        }

    def load_transcriptome(self, fasta_path: str) -> Tuple[Dict[str, GeneFamilyInfo], Dict[str, TranscriptInfo]]:
        """Load a FASTA file and organise its transcripts by gene family.

        Records with an unparseable header or a sequence shorter than
        ``MIN_SEQUENCE_LENGTH`` are counted as skipped. Sequences are
        upper-cased. If a transcript id occurs twice the later record wins,
        and a family's ``gene_id`` is the one from its last record.

        Args:
            fasta_path: Path of the FASTA file handed to ``SeqIO.parse``.

        Returns:
            ``(gene_families, transcripts)`` where ``gene_families`` maps gene
            name -> ``GeneFamilyInfo`` and ``transcripts`` maps transcript id
            -> ``TranscriptInfo``.

        Raises:
            Exception: Any error raised while reading/parsing is logged and re-raised.
        """
        self.logger.info(f"Loading transcriptome from {fasta_path}")

        transcripts = {}
        family_members = defaultdict(dict)  # gene_name -> {transcript_id: TranscriptInfo}
        family_gene_id = {}                 # gene_name -> gene_id (last record wins)

        skipped = 0
        loaded = 0

        try:
            for record in SeqIO.parse(fasta_path, "fasta"):
                header_info = self.parse_fasta_header(record.id)
                if not header_info:
                    skipped += 1
                    continue

                # Skip very short sequences
                if len(record.seq) < self.MIN_SEQUENCE_LENGTH:
                    skipped += 1
                    continue

                transcript_info = TranscriptInfo(
                    transcript_id=header_info['transcript_id'],
                    gene_id=header_info['gene_id'],
                    gene_name=header_info['gene_name'],
                    sequence=str(record.seq).upper(),
                    length=len(record.seq),
                    transcript_type=header_info['transcript_type']
                )

                transcripts[transcript_info.transcript_id] = transcript_info

                # Group by gene name
                gene_name = header_info['gene_name']
                family_gene_id[gene_name] = header_info['gene_id']
                family_members[gene_name][transcript_info.transcript_id] = transcript_info

                loaded += 1

        except Exception as e:
            self.logger.error(f"Error loading transcriptome: {e}")
            raise

        # Build the immutable-ish family records only after every isoform is known,
        # because GeneFamilyInfo snapshots its transcript list at construction.
        final_gene_families = {
            gene_name: GeneFamilyInfo(
                gene_name=gene_name,
                gene_id=family_gene_id[gene_name],
                transcripts=members
            )
            for gene_name, members in family_members.items()
        }

        self.logger.info(f"Loaded {loaded} transcripts, {len(final_gene_families)} gene families")
        self.logger.info(f"Skipped {skipped} transcripts")

        return final_gene_families, transcripts

    def analyze_gene_families(self, gene_families: Dict[str, GeneFamilyInfo]) -> Dict[str, Any]:
        """Summarise how many isoforms each gene family has.

        Args:
            gene_families: Mapping gene name -> object exposing ``.transcripts``.

        Returns:
            Dict with 'total_genes', 'total_transcripts',
            'single_isoform_genes', 'multi_isoform_genes', 'max_isoforms',
            'complex_genes' (list of ``(gene_name, num_isoforms)`` for families
            with more than ``COMPLEX_GENE_ISOFORMS`` isoforms) and
            'isoform_distribution' (Counter: isoform count -> number of genes).
        """
        self.logger.info("Analyzing gene family complexity...")

        stats = {
            'total_genes': len(gene_families),
            'total_transcripts': sum(len(gf.transcripts) for gf in gene_families.values()),
            'single_isoform_genes': 0,
            'multi_isoform_genes': 0,
            'max_isoforms': 0,
            'complex_genes': [],
            'isoform_distribution': Counter()
        }

        for gene_name, gene_family in gene_families.items():
            num_isoforms = len(gene_family.transcripts)
            stats['isoform_distribution'][num_isoforms] += 1

            if num_isoforms == 1:
                stats['single_isoform_genes'] += 1
            else:
                stats['multi_isoform_genes'] += 1

            stats['max_isoforms'] = max(stats['max_isoforms'], num_isoforms)

            # Mark complex genes for special attention
            if num_isoforms > self.COMPLEX_GENE_ISOFORMS:
                stats['complex_genes'].append((gene_name, num_isoforms))

        self.logger.info("Gene family analysis complete:")
        self.logger.info(f"  Total genes: {stats['total_genes']}")
        self.logger.info(f"  Single isoform genes: {stats['single_isoform_genes']}")
        self.logger.info(f"  Multi-isoform genes: {stats['multi_isoform_genes']}")
        self.logger.info(f"  Most complex gene: {stats['max_isoforms']} isoforms")

        return stats


### Gene-Aware Dataset and DataLoader

In [5]:
class GeneAwareDataset(Dataset):
    """Dataset of reads that attaches gene-family context to each training sample.

    Each item carries the tokenised read. In ``mode == "train"``, when a
    ground-truth transcript is available for the read, the item additionally
    carries the transcript's gene, its sibling isoforms and the isoforms of
    one randomly drawn other gene (negatives).
    """

    def __init__(self, 
                 reads: List[str],
                 ground_truth: pd.DataFrame,
                 gene_families: Dict[str, GeneFamilyInfo],
                 transcripts: Dict[str, TranscriptInfo],
                 config: DeepQuantConfig,
                 mode: str = "train"):
        """
        Args:
            reads: Read sequences; item ``idx`` is ``reads[idx]``.
            ground_truth: Frame with a 'true_transcript' column aligned
                positionally with ``reads``, or ``None`` for inference.
            gene_families: Gene name -> family record exposing ``.transcripts``.
            transcripts: Transcript id -> record exposing ``.gene_name``.
            config: Provides ``model_name`` and ``max_sequence_length``.
            mode: Gene context is attached only when this equals "train".
        """
        self.reads = reads
        self.ground_truth = ground_truth
        self.gene_families = gene_families
        self.transcripts = transcripts
        self.config = config
        self.mode = mode

        # Create mappings
        self.transcript_to_gene = {t_id: info.gene_name 
                                  for t_id, info in transcripts.items()}

        # Prepare tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)

        # Create gene family lookup for efficient sampling
        self.gene_to_transcripts = defaultdict(list)
        for t_id, info in transcripts.items():
            self.gene_to_transcripts[info.gene_name].append(t_id)

        # Cached once so negative sampling does not rebuild an O(#genes) list per item.
        self._gene_names = list(gene_families.keys())

        self.logger = logging.getLogger(__name__)

    def __len__(self):
        """Number of reads."""
        return len(self.reads)

    def tokenize_sequence(self, sequence: str) -> Dict[str, torch.Tensor]:
        """Tokenize one DNA sequence, padded/truncated to ``max_sequence_length``.

        The sequence is upper-cased and every 'N' is replaced by 'A' before
        tokenisation. The tokenizer's batch dimension is squeezed away.
        """
        sequence = sequence.upper().replace('N', 'A')

        tokens = self.tokenizer(
            sequence,
            max_length=self.config.max_sequence_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Remove batch dimension
        return {k: v.squeeze(0) for k, v in tokens.items()}

    @staticmethod
    def _fallback_context() -> Dict[str, Any]:
        """Neutral gene context used when the real one cannot be resolved."""
        return {
            'gene_name': None,
            'family_transcripts': [],
            'negative_transcripts': [],
            'num_isoforms': 1,
            'is_complex_gene': False
        }

    def _sample_negative_gene(self, gene_name: str) -> Optional[str]:
        """Draw, uniformly at random, a gene name different from ``gene_name``.

        Returns ``None`` when no other gene exists. Uses rejection sampling on
        a cached name list, so the expected cost is O(1) instead of building a
        filtered copy of all gene names for every sample.
        """
        names = getattr(self, '_gene_names', None)
        if names is None or len(names) != len(self.gene_families):
            # Cache missing or the family dict changed size: refresh it.
            names = self._gene_names = list(self.gene_families.keys())

        count = len(names)
        if count == 0 or (count == 1 and names[0] == gene_name):
            return None

        while True:
            candidate = names[np.random.randint(count)]
            if candidate != gene_name:
                return candidate

    def get_gene_family_context(self, transcript_id: str) -> Dict[str, Any]:
        """Get gene family context for a transcript.

        Returns:
            Dict with 'gene_name', 'family_transcripts' (all isoform ids of the
            gene), 'negative_transcripts' (isoform ids of one random other
            gene, or ``[]`` if there is none), 'num_isoforms' and
            'is_complex_gene' (more than 3 isoforms).

        Raises:
            KeyError: If the transcript or its gene is unknown.
        """
        gene_name = self.transcripts[transcript_id].gene_name
        gene_family = self.gene_families[gene_name]

        # Get all isoforms in the same gene family
        family_transcripts = list(gene_family.transcripts.keys())

        # Sample negative transcripts from a different gene
        negative_transcripts = []
        neg_gene = self._sample_negative_gene(gene_name)
        if neg_gene is not None:
            negative_transcripts = list(self.gene_families[neg_gene].transcripts.keys())

        return {
            'gene_name': gene_name,
            'family_transcripts': family_transcripts,
            'negative_transcripts': negative_transcripts,
            'num_isoforms': len(family_transcripts),
            # NOTE(review): threshold is 3 here but the parser's statistics flag
            # "complex" genes at >5 isoforms — confirm which one is intended.
            'is_complex_gene': len(family_transcripts) > 3
        }

    def __getitem__(self, idx):
        """Build the sample dict for read ``idx``.

        Always present: 'read_id', 'read_sequence', 'input_ids',
        'attention_mask', 'read_idx'. In train mode with a ground-truth row
        for ``idx``: 'true_transcript' plus the gene-context keys (fallback
        values are used when the transcript is unknown or its context cannot
        be built).
        """
        read_sequence = self.reads[idx]

        # Get ground truth information (None means inference / no label for this read)
        true_transcript = None
        if self.ground_truth is not None and idx < len(self.ground_truth):
            true_transcript = self.ground_truth.iloc[idx]['true_transcript']

        # Tokenize the read
        read_tokens = self.tokenize_sequence(read_sequence)

        sample = {
            'read_id': f"read_{idx:06d}",
            'read_sequence': read_sequence,
            'input_ids': read_tokens['input_ids'],
            'attention_mask': read_tokens['attention_mask'],
            'read_idx': idx
        }

        # Add training-specific information
        if self.mode == "train" and true_transcript is not None:
            sample['true_transcript'] = true_transcript

            if true_transcript in self.transcripts:
                try:
                    gene_context = self.get_gene_family_context(true_transcript)
                except Exception as e:
                    # Best effort: a broken family record must not kill the data loader.
                    self.logger.warning(f"Error getting gene context for {true_transcript}: {e}")
                    gene_context = self._fallback_context()
            else:
                self.logger.warning(f"Unknown transcript in ground truth: {true_transcript}")
                gene_context = self._fallback_context()

            sample.update(gene_context)

        return sample

def custom_collate_fn(batch):
    """Collate sample dicts into a batch, tolerating variable-length gene context.

    Tensor fields ('input_ids', 'attention_mask') are stacked, 'read_idx'
    becomes a 1-D tensor, string-like fields are gathered into Python lists
    and the gene-family fields stay as lists of lists ('family_transcripts',
    'negative_transcripts') or become 1-D tensors ('num_isoforms',
    'is_complex_gene').

    A batch may mix labelled and unlabelled samples: a per-sample field is
    emitted when ANY sample carries it, and samples lacking it contribute a
    neutral value (``None``, ``[]``, ``1``, ``False``) so every output stays
    aligned with the batch order.

    Args:
        batch: List of sample dicts as produced by the dataset.

    Returns:
        Dict of collated fields; ``{}`` for an empty batch.
    """
    if not batch:
        return {}

    first = batch[0]
    collated = {}

    # Tensor fields are produced for every sample, so checking the first one suffices.
    for field in ('input_ids', 'attention_mask'):
        if field in first:
            collated[field] = torch.stack([sample[field] for sample in batch])

    # read_idx is a plain integer per sample and needs converting to a tensor.
    if 'read_idx' in first:
        collated['read_idx'] = torch.tensor([sample['read_idx'] for sample in batch])

    # List fields: some exist only on labelled samples, hence any() + .get().
    for field in ('read_id', 'read_sequence', 'true_transcript', 'gene_name'):
        if any(field in sample for sample in batch):
            collated[field] = [sample.get(field) for sample in batch]

    # Gene family context (variable length); defaults mirror the dataset's fallback context.
    if any('family_transcripts' in sample for sample in batch):
        collated['family_transcripts'] = [sample.get('family_transcripts', []) for sample in batch]
        collated['negative_transcripts'] = [sample.get('negative_transcripts', []) for sample in batch]
        collated['num_isoforms'] = torch.tensor([sample.get('num_isoforms', 1) for sample in batch])
        collated['is_complex_gene'] = torch.tensor([sample.get('is_complex_gene', False) for sample in batch])

    return collated

### Neural Architecture Components  

In [6]:
class MultiScaleAttention(nn.Module):
    """Two parallel self-attention branches whose outputs are fused by an MLP.

    The branches are named "local" and "global", but as written both are
    ordinary full-sequence multi-head self-attention over the same input with
    the same padding mask — no window restriction is applied to either one.
    They differ only in their independently initialised weights.
    """

    def __init__(self, hidden_dim: int, num_heads: int = 8, dropout: float = 0.1):
        """
        Args:
            hidden_dim: Feature width of the input and output.
            num_heads: Total head budget; each branch gets ``num_heads // 2`` heads.
            dropout: Dropout probability for both branches and the fusion MLP.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        heads_per_branch = num_heads // 2
        branch_kwargs = dict(
            embed_dim=hidden_dim,
            num_heads=heads_per_branch,
            dropout=dropout,
            batch_first=True,
        )

        # Creation order (local, global, fusion) is kept so seeded initialisation is unchanged.
        self.local_attention = nn.MultiheadAttention(**branch_kwargs)
        self.global_attention = nn.MultiheadAttention(**branch_kwargs)

        # Concatenated branch outputs (2 * hidden_dim) are squeezed back to hidden_dim.
        self.scale_fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, x, attention_mask=None):
        """Run both branches and fuse them.

        Args:
            x: Tensor unpackable as (batch, seq_len, hidden_dim).
            attention_mask: Optional tensor where 0 marks padding positions.

        Returns:
            ``(output, weights)`` where ``output`` is the fused tensor and
            ``weights`` is ``{'local_weights': ..., 'global_weights': ...}``
            as returned by the two attention modules.
        """
        # Unpacking doubles as a rank check: anything but a 3-D input raises here.
        _batch, _length, _width = x.shape

        # nn.MultiheadAttention wants True at positions to IGNORE, i.e. the
        # inverse of the 1-for-real-token attention mask.
        pad_positions = None if attention_mask is None else (attention_mask == 0)

        local_out, local_weights = self.local_attention(
            x, x, x, key_padding_mask=pad_positions
        )
        global_out, global_weights = self.global_attention(
            x, x, x, key_padding_mask=pad_positions
        )

        output = self.scale_fusion(torch.cat([local_out, global_out], dim=-1))

        weights = {'local_weights': local_weights, 'global_weights': global_weights}
        return output, weights

class GeneHierarchicalEncoder(nn.Module):
    """Hierarchical encoder that learns gene-level and isoform-level representations.

    Pipeline: pretrained base encoder -> MultiScaleAttention -> two parallel
    2-layer transformer encoders (gene / isoform) -> masked mean pooling ->
    projection heads -> L2 normalisation -> sigmoid uncertainty heads.
    """

    def __init__(self, config: DeepQuantConfig):
        """Build all sub-modules; loads the pretrained base encoder named by ``config.model_name``."""
        super().__init__()
        self.config = config

        # Base transformer encoder
        self.base_encoder = AutoModel.from_pretrained(config.model_name)
        base_hidden_size = self.base_encoder.config.hidden_size

        # Gene-level encoder, then isoform-specific encoder (same architecture, separate weights).
        # Module creation order and attribute names are kept stable for checkpoint compatibility.
        self.gene_encoder = self._build_encoder(base_hidden_size, config)
        self.isoform_encoder = self._build_encoder(base_hidden_size, config)

        # Multi-scale attention
        self.multi_scale_attention = MultiScaleAttention(
            hidden_dim=base_hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.dropout
        )

        # Projection heads
        self.gene_projection = self._build_projection(
            base_hidden_size, config.gene_embedding_dim, config.dropout
        )
        self.isoform_projection = self._build_projection(
            base_hidden_size, config.embedding_dim, config.dropout
        )

        # Uncertainty estimation heads (one scalar in (0, 1) per embedding)
        self.gene_uncertainty = self._build_uncertainty_head(config.gene_embedding_dim)
        self.isoform_uncertainty = self._build_uncertainty_head(config.embedding_dim)

    @staticmethod
    def _build_encoder(d_model, config):
        """Two-layer batch-first transformer encoder operating at the base hidden size."""
        return nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=config.num_attention_heads,
                dim_feedforward=d_model * config.hidden_dim_multiplier,
                dropout=config.dropout,
                batch_first=True
            ),
            num_layers=2
        )

    @staticmethod
    def _build_projection(in_dim, out_dim, dropout):
        """MLP head: in_dim -> 2*out_dim -> out_dim with LayerNorm after each linear layer."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim * 2),
            nn.LayerNorm(out_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_dim * 2, out_dim),
            nn.LayerNorm(out_dim)
        )

    @staticmethod
    def _build_uncertainty_head(in_dim):
        """Small MLP ending in a sigmoid, producing one value per embedding."""
        return nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask):
        """Encode a batch of tokenised sequences.

        Args:
            input_ids: Token ids passed to the base encoder.
            attention_mask: 1 for real tokens, 0 for padding (same leading
                shape as ``input_ids``).

        Returns:
            Dict with 'gene_embedding' and 'isoform_embedding' (L2-normalised),
            'gene_uncertainty' and 'isoform_uncertainty' (sigmoid outputs with
            a trailing dimension of 1) and 'attention_weights' (dict from
            MultiScaleAttention).
        """
        # Base encoding
        base_output = self.base_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Standard model outputs expose .last_hidden_state; custom (remote-code)
        # encoders may return a plain tuple whose first item is the hidden states.
        sequence_output = getattr(base_output, 'last_hidden_state', None)
        if sequence_output is None:
            sequence_output = base_output[0]

        # Multi-scale attention
        enhanced_output, attention_weights = self.multi_scale_attention(
            sequence_output, attention_mask
        )

        # Transformer layers expect True at padding positions — the inverse
        # of the 1-for-real-token attention mask.
        padding_mask = (attention_mask == 0)

        # Gene-level and isoform-level encodings share the same input.
        gene_features = self.gene_encoder(enhanced_output, src_key_padding_mask=padding_mask)
        isoform_features = self.isoform_encoder(enhanced_output, src_key_padding_mask=padding_mask)

        # Global pooling with attention mask
        gene_pooled = self._attention_pooling(gene_features, attention_mask)
        isoform_pooled = self._attention_pooling(isoform_features, attention_mask)

        # Project to embedding spaces and normalise to unit length
        gene_embedding = F.normalize(self.gene_projection(gene_pooled), p=2, dim=-1)
        isoform_embedding = F.normalize(self.isoform_projection(isoform_pooled), p=2, dim=-1)

        # Uncertainty estimates are computed from the normalised embeddings
        gene_uncertainty = self.gene_uncertainty(gene_embedding)
        isoform_uncertainty = self.isoform_uncertainty(isoform_embedding)

        return {
            'gene_embedding': gene_embedding,
            'isoform_embedding': isoform_embedding,
            'gene_uncertainty': gene_uncertainty,
            'isoform_uncertainty': isoform_uncertainty,
            'attention_weights': attention_weights
        }

    def _attention_pooling(self, features, attention_mask):
        """Masked mean pooling over the sequence axis.

        Despite the name there are no learned attention weights: every
        non-padding position contributes equally.

        Args:
            features: [batch_size, seq_len, hidden_dim]
            attention_mask: [batch_size, seq_len], non-zero for real tokens.

        Returns:
            [batch_size, hidden_dim]; a row whose mask is all zeros yields a
            zero vector instead of NaN.
        """
        mask = attention_mask.unsqueeze(-1).to(features.dtype)
        summed = (features * mask).sum(dim=1)
        # clamp guards against division by zero for fully padded rows
        token_counts = mask.sum(dim=1).clamp(min=1.0)

        return summed / token_counts

### Gene-Aware Contrastive Learning

In [7]:
class GeneAwareContrastiveLoss(nn.Module):
    """Multi-level contrastive loss with gene family awareness.

    Combines a gene-level and an isoform-level InfoNCE term, weighted by
    ``config.gene_contrastive_weight`` and ``config.isoform_contrastive_weight``.
    """

    # The isoform-level term uses temperature * this factor (a sharper distribution).
    ISOFORM_TEMPERATURE_SCALE = 0.5

    def __init__(self, config: DeepQuantConfig):
        """Read temperature and the two loss weights from ``config``."""
        super().__init__()
        self.config = config
        self.temperature = config.temperature
        self.gene_weight = config.gene_contrastive_weight
        self.isoform_weight = config.isoform_contrastive_weight

    def info_nce_loss(self, embeddings, labels, temperature):
        """Compute a supervised InfoNCE loss over one batch.

        For each anchor ``i`` that has at least one other sample with the same
        label, the loss is::

            -log( sum_{p in pos(i)} exp(s_ip / T) / sum_{j != i} exp(s_ij / T) )

        averaged over those anchors, where ``s`` is the dot product of the
        embeddings. Everything is evaluated in log-space (``logsumexp``), so
        low temperatures cannot underflow the exponentials.

        Args:
            embeddings: [batch_size, dim] tensor (unit-norm in this pipeline).
            labels: Tensor with batch_size elements, on the embeddings' device.
            temperature: Positive softmax temperature ``T``.

        Returns:
            Scalar tensor. A zero-valued leaf tensor is returned when the
            batch has fewer than two samples or no positive pair.
        """
        batch_size = embeddings.shape[0]

        if batch_size < 2:
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        # Positive pairs: same label, excluding each sample paired with itself.
        labels = labels.view(-1, 1)
        self_mask = torch.eye(batch_size, device=embeddings.device, dtype=torch.bool)
        positive_mask = (labels == labels.T) & ~self_mask

        # Only anchors with at least one positive contribute to the loss.
        valid_samples = positive_mask.any(dim=1)
        if not valid_samples.any():
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        # float32 logits keep the division by a small temperature safe under autocast.
        logits = torch.matmul(embeddings, embeddings.T).float() / temperature
        logits = logits.masked_fill(self_mask, float('-inf'))

        # Select valid rows BEFORE logsumexp: a row with no positives would be
        # all -inf in the numerator and poison gradients with NaN.
        logits = logits[valid_samples]
        positives = positive_mask[valid_samples]

        log_denominator = torch.logsumexp(logits, dim=1)
        log_numerator = torch.logsumexp(
            logits.masked_fill(~positives, float('-inf')), dim=1
        )

        return (log_denominator - log_numerator).mean()

    def forward(self, gene_embeddings, isoform_embeddings, gene_labels, transcript_labels):
        """
        Compute multi-level contrastive loss.

        Args:
            gene_embeddings: [batch_size, gene_embedding_dim]
            isoform_embeddings: [batch_size, embedding_dim]
            gene_labels: [batch_size] - gene identifiers (tensor)
            transcript_labels: [batch_size] - transcript identifiers (tensor)

        Returns:
            Dict with 'total_loss' (weighted sum), 'gene_loss' and 'isoform_loss'.
        """
        # Gene-level contrastive loss (easier task)
        gene_loss = self.info_nce_loss(gene_embeddings, gene_labels, self.temperature)

        # Isoform-level contrastive loss (harder task, sharper temperature)
        isoform_loss = self.info_nce_loss(
            isoform_embeddings,
            transcript_labels,
            self.temperature * self.ISOFORM_TEMPERATURE_SCALE
        )

        # Combined loss
        total_loss = self.gene_weight * gene_loss + self.isoform_weight * isoform_loss

        return {
            'total_loss': total_loss,
            'gene_loss': gene_loss,
            'isoform_loss': isoform_loss
        }

### Vector Store and Search

In [8]:
class GeneAwareVectorStore:
    """Vector store with gene family awareness for efficient search.

    Holds one inner-product FAISS index per embedding level (gene, isoform)
    plus the raw embedding/uncertainty matrices, all row-aligned with the
    transcript order fixed in ``build_index``.
    """

    def __init__(self, config: DeepQuantConfig):
        """Create an empty store; ``build_index`` must be called before searching."""
        self.config = config
        self.gene_families = None
        self.transcripts = None

        # Separate indices for gene and isoform embeddings
        self.gene_index = None
        self.isoform_index = None

        # Mappings
        self.transcript_to_idx = {}
        self.idx_to_transcript = {}
        self.transcript_to_gene = {}

        # Embeddings storage
        self.gene_embeddings = None
        self.isoform_embeddings = None
        self.gene_uncertainties = None
        self.isoform_uncertainties = None

        self.logger = logging.getLogger(__name__)

    def build_index(self, 
                   gene_families: Dict[str, GeneFamilyInfo],
                   transcripts: Dict[str, TranscriptInfo],
                   model: nn.Module,
                   tokenizer):
        """Build gene-aware search indices.

        Every transcript sequence is tokenised (padded/truncated to
        ``config.max_sequence_length``), embedded with ``model`` in batches of
        ``config.batch_size`` and added to the gene and isoform indices.

        Side effect: ``model`` is switched to eval mode and left there.

        Args:
            gene_families: Gene name -> family record (kept for later search).
            transcripts: Transcript id -> record exposing ``.sequence`` and ``.gene_name``.
            model: Callable as ``model(input_ids, attention_mask)`` returning a
                dict with 'gene_embedding', 'isoform_embedding',
                'gene_uncertainty' and 'isoform_uncertainty' tensors.
            tokenizer: Callable tokenizer returning 'input_ids' and 'attention_mask'.

        Raises:
            ValueError: If ``transcripts`` is empty.
        """
        if not transcripts:
            raise ValueError("Cannot build vector indices from an empty transcript set")

        self.gene_families = gene_families
        self.transcripts = transcripts

        self.logger.info("Building gene-aware vector indices...")

        # Prepare transcript data
        transcript_ids = list(transcripts.keys())
        transcript_sequences = [transcripts[tid].sequence for tid in transcript_ids]

        # Start from clean mappings so a rebuild never keeps stale entries.
        self.transcript_to_idx = {}
        self.idx_to_transcript = {}
        self.transcript_to_gene = {}
        for idx, tid in enumerate(transcript_ids):
            self.transcript_to_idx[tid] = idx
            self.idx_to_transcript[idx] = tid
            self.transcript_to_gene[tid] = transcripts[tid].gene_name

        # Generate embeddings in batches
        model.eval()
        output_keys = ('gene_embedding', 'isoform_embedding',
                       'gene_uncertainty', 'isoform_uncertainty')
        collected = {key: [] for key in output_keys}

        batch_size = self.config.batch_size

        with torch.no_grad():
            for start in range(0, len(transcript_sequences), batch_size):
                batch_seqs = transcript_sequences[start:start + batch_size]

                # Tokenize each sequence to the same fixed length
                batch_tokens = [
                    tokenizer(
                        seq.upper().replace('N', 'A'),
                        max_length=self.config.max_sequence_length,
                        padding='max_length',
                        truncation=True,
                        return_tensors='pt'
                    )
                    for seq in batch_seqs
                ]

                # Stack tensors
                input_ids = torch.stack([t['input_ids'].squeeze(0) for t in batch_tokens]).to(self.config.device)
                attention_mask = torch.stack([t['attention_mask'].squeeze(0) for t in batch_tokens]).to(self.config.device)

                # Get embeddings
                outputs = model(input_ids, attention_mask)
                for key in output_keys:
                    collected[key].append(outputs[key].cpu().numpy())

        # Concatenate all embeddings
        self.gene_embeddings = np.vstack(collected['gene_embedding']).astype('float32')
        self.isoform_embeddings = np.vstack(collected['isoform_embedding']).astype('float32')
        self.gene_uncertainties = np.vstack(collected['gene_uncertainty']).astype('float32')
        self.isoform_uncertainties = np.vstack(collected['isoform_uncertainty']).astype('float32')

        # Build FAISS indices (inner product == cosine similarity for unit-norm vectors)
        gene_dim = self.gene_embeddings.shape[1]
        isoform_dim = self.isoform_embeddings.shape[1]

        self.gene_index = faiss.IndexFlatIP(gene_dim)
        self.isoform_index = faiss.IndexFlatIP(isoform_dim)

        self.gene_index.add(self.gene_embeddings)
        self.isoform_index.add(self.isoform_embeddings)

        self.logger.info(f"Built indices with {len(transcript_ids)} transcripts")
        self.logger.info(f"Gene embedding dim: {gene_dim}, Isoform embedding dim: {isoform_dim}")

    def hierarchical_search(self, 
                          gene_embedding: np.ndarray,
                          isoform_embedding: np.ndarray,
                          top_k_genes: int = 10,
                          top_k_isoforms: int = 5) -> Dict[str, Any]:
        """
        Hierarchical search: first find candidate genes, then best isoforms within those genes.

        Only a single query is supported; if 2-D arrays are passed, the first
        row is used.

        Args:
            gene_embedding: Query in gene-embedding space, shape (dim,) or (1, dim).
            isoform_embedding: Query in isoform-embedding space, shape (dim,) or (1, dim).
            top_k_genes: Number of nearest transcripts (by gene embedding) whose
                genes become candidates.
            top_k_isoforms: Maximum number of results returned.

        Returns:
            Dict with 'results' (best-first list of dicts holding
            'transcript_id', 'gene_name', 'gene_similarity' — the best
            gene-level similarity found for THAT transcript's gene —,
            'isoform_similarity', 'gene_uncertainty', 'isoform_uncertainty',
            'global_idx'), 'candidate_genes' (in order of first hit) and
            'num_candidates'.
        """
        # Step 1: Gene-level search (the index expects a contiguous float32 matrix)
        gene_query = np.ascontiguousarray(np.atleast_2d(gene_embedding), dtype='float32')
        gene_similarities, gene_indices = self.gene_index.search(gene_query, top_k_genes)

        # Best similarity per candidate gene. A dict keeps first-hit order, so
        # the candidate order is deterministic (a set would not be).
        gene_best_similarity = {}
        for similarity, idx in zip(gene_similarities[0], gene_indices[0]):
            idx = int(idx)
            if idx < 0:
                # The index pads with -1 when it holds fewer than top_k_genes vectors.
                continue
            gene_name = self.transcript_to_gene[self.idx_to_transcript[idx]]
            similarity = float(similarity)
            if gene_name not in gene_best_similarity or similarity > gene_best_similarity[gene_name]:
                gene_best_similarity[gene_name] = similarity

        # Step 2: Collect all transcripts from candidate genes
        candidate_transcripts = []
        for gene_name in gene_best_similarity:
            candidate_transcripts.extend(self.gene_families[gene_name].transcript_ids)

        # Step 3: Isoform-level search within candidates
        results = []
        if candidate_transcripts and top_k_isoforms > 0:
            candidate_indices = [self.transcript_to_idx[tid] for tid in candidate_transcripts]
            candidate_embeddings = self.isoform_embeddings[candidate_indices]

            # Compute similarities against the (first) isoform query vector
            isoform_query = np.atleast_2d(isoform_embedding)[0]
            isoform_similarities = np.dot(candidate_embeddings, isoform_query)

            # Best-first order; slicing from the front is safe for any k >= 1
            ranked = np.argsort(-isoform_similarities, kind='stable')[:top_k_isoforms]

            for position in ranked:
                global_idx = candidate_indices[position]
                transcript_id = self.idx_to_transcript[global_idx]
                gene_name = self.transcript_to_gene[transcript_id]

                results.append({
                    'transcript_id': transcript_id,
                    'gene_name': gene_name,
                    'gene_similarity': gene_best_similarity[gene_name],
                    'isoform_similarity': float(isoform_similarities[position]),
                    'gene_uncertainty': float(self.gene_uncertainties[global_idx][0]),
                    'isoform_uncertainty': float(self.isoform_uncertainties[global_idx][0]),
                    'global_idx': global_idx
                })

        return {
            'results': results,
            'candidate_genes': list(gene_best_similarity),
            'num_candidates': len(candidate_transcripts)
        }

### Uncertainty-Aware Assignment Engine  

In [9]:
class UncertaintyAwareAssignment:
    """Turn hierarchical search results into a single read assignment.

    The engine filters candidates first by gene similarity, then by an
    adaptively chosen isoform similarity threshold, and finally picks the
    candidate with the highest combined confidence score.
    """

    def __init__(self, config: "DeepQuantConfig"):
        self.config = config
        self.logger = logging.getLogger(__name__)

        # Assignment thresholds (will be learned/adapted)
        self.gene_threshold = config.gene_similarity_threshold
        self.isoform_threshold = config.isoform_similarity_threshold
        self.uncertainty_threshold = config.uncertainty_threshold

    def compute_assignment_confidence(self,
                                    gene_similarity: float,
                                    isoform_similarity: float,
                                    gene_uncertainty: float,
                                    isoform_uncertainty: float,
                                    num_candidates: int) -> float:
        """Compute overall assignment confidence score.

        Args:
            gene_similarity: Similarity of the read to the candidate gene.
            isoform_similarity: Similarity of the read to the candidate isoform.
            gene_uncertainty: Gene-level uncertainty (lower is better).
            isoform_uncertainty: Isoform-level uncertainty (lower is better).
            num_candidates: Number of competing candidates; more candidates
                reduce the score.

        Returns:
            A Python float in ``[0.0, 1.0]``.
        """
        # Similarity component (higher = better)
        sim_score = 0.4 * gene_similarity + 0.6 * isoform_similarity

        # Uncertainty component (lower = better, so we take 1 - uncertainty)
        uncertainty_score = 0.4 * (1 - gene_uncertainty) + 0.6 * (1 - isoform_uncertainty)

        # Floor both components at zero before multiplying: a negative
        # similarity combined with an uncertainty above 1 would otherwise
        # multiply into a spuriously *positive* confidence.
        sim_score = max(0.0, sim_score)
        uncertainty_score = max(0.0, uncertainty_score)

        # Complexity penalty (more candidates = lower confidence). The count
        # is floored at zero so the logarithm argument is always >= 2.
        complexity_penalty = 1.0 / np.log(max(num_candidates, 0) + 2)

        # Combined confidence
        confidence = sim_score * uncertainty_score * complexity_penalty

        return float(np.clip(confidence, 0.0, 1.0))

    def adaptive_threshold_selection(self,
                                   similarities: List[float],
                                   uncertainties: List[float],
                                   gene_complexity: int) -> float:
        """Adaptively select similarity threshold based on context.

        Starts from the configured isoform threshold and shifts it:
        upwards with the number of candidate genes (capped at +0.1) and with
        the mean uncertainty, downwards when the best similarity clearly
        beats the runner-up.

        Args:
            similarities: Isoform similarities of the surviving candidates.
            uncertainties: Isoform uncertainties of the same candidates.
            gene_complexity: Number of candidate genes for the read.

        Returns:
            The adjusted threshold, clipped to ``[0.3, 0.95]``. When
            ``similarities`` is empty the configured threshold is returned
            unchanged (and unclipped).
        """
        if not similarities:
            return self.isoform_threshold

        # Base threshold
        base_threshold = self.isoform_threshold

        # Adjust based on gene complexity
        complexity_adjustment = min(0.1, gene_complexity * 0.02)

        # Adjust based on uncertainty distribution
        mean_uncertainty = np.mean(uncertainties)
        uncertainty_adjustment = mean_uncertainty * 0.1

        # Adjust based on similarity gap
        if len(similarities) > 1:
            top_sim, second_sim = sorted(similarities, reverse=True)[:2]
            gap = top_sim - second_sim
            gap_adjustment = -gap * 0.05  # Lower threshold if there's a clear winner
        else:
            gap_adjustment = 0.0

        adaptive_threshold = base_threshold + complexity_adjustment + uncertainty_adjustment + gap_adjustment

        return float(np.clip(adaptive_threshold, 0.3, 0.95))

    @staticmethod
    def _no_transcript(reason: str,
                       assignment_type: str = 'unassigned',
                       assigned_gene=None) -> Dict[str, any]:
        """Build the result dict used whenever no transcript is assigned."""
        return {
            'assigned_transcript': None,
            'assigned_gene': assigned_gene,
            'confidence': 0.0,
            'assignment_type': assignment_type,
            'reason': reason
        }

    def hierarchical_assignment(self, search_results: Dict[str, any]) -> Dict[str, any]:
        """Hierarchical assignment strategy: gene first, then isoform.

        Args:
            search_results: Dict with a ``'results'`` list (each entry holding
                ``transcript_id``, ``gene_name``, ``gene_similarity``,
                ``isoform_similarity``, ``gene_uncertainty`` and
                ``isoform_uncertainty``) and a ``'candidate_genes'`` list.

        Returns:
            A dict with at least ``assigned_transcript``, ``assigned_gene``,
            ``confidence``, ``assignment_type`` and ``reason``. Successful
            assignments additionally carry candidate counts, the adaptive
            threshold that was applied and the first three isoform candidates.
        """
        results = search_results['results']
        if not results:
            return self._no_transcript('no_candidates')

        # Step 1: Gene-level filtering
        gene_candidates = [r for r in results if r['gene_similarity'] >= self.gene_threshold]
        if not gene_candidates:
            return self._no_transcript('gene_threshold_not_met')

        # Step 2: Adaptive isoform threshold
        adaptive_threshold = self.adaptive_threshold_selection(
            [r['isoform_similarity'] for r in gene_candidates],
            [r['isoform_uncertainty'] for r in gene_candidates],
            len(search_results['candidate_genes'])
        )

        # Step 3: Isoform-level filtering
        isoform_candidates = [r for r in gene_candidates
                            if r['isoform_similarity'] >= adaptive_threshold]
        if not isoform_candidates:
            # At least assign the gene of the first surviving gene candidate
            return self._no_transcript(
                'isoform_threshold_not_met',
                assignment_type='gene_only',
                assigned_gene=gene_candidates[0]['gene_name']
            )

        # Step 4: Score every candidate once and keep the best one. ``max``
        # returns the first maximal element, so ties resolve to the earliest
        # candidate in search order.
        scored = [
            (self.compute_assignment_confidence(
                c['gene_similarity'], c['isoform_similarity'],
                c['gene_uncertainty'], c['isoform_uncertainty'],
                len(isoform_candidates)
            ), c)
            for c in isoform_candidates
        ]
        confidence, best_candidate = max(scored, key=lambda pair: pair[0])

        # Determine assignment type
        if confidence >= self.uncertainty_threshold:
            assignment_type = 'high_confidence'
        elif len(isoform_candidates) == 1:
            assignment_type = 'unique_match'
        else:
            assignment_type = 'best_match'

        return {
            'assigned_transcript': best_candidate['transcript_id'],
            'assigned_gene': best_candidate['gene_name'],
            'confidence': confidence,
            'assignment_type': assignment_type,
            'reason': 'successful_assignment',
            'num_gene_candidates': len(gene_candidates),
            'num_isoform_candidates': len(isoform_candidates),
            'adaptive_threshold': adaptive_threshold,
            # First three candidates in the order supplied by the search
            'all_candidates': isoform_candidates[:3]
        }

### DeepQuantV2 Class

In [10]:
class DeepQuantV2:
    """
    Main DeepQuant V2 class that orchestrates the entire pipeline.

    The methods are meant to be called in pipeline order:
    ``load_data`` -> ``initialize_model`` -> ``train`` -> ``build_index`` ->
    ``quantify_reads`` -> ``save_results``.
    """

    def __init__(self, config: "DeepQuantConfig"):
        self.config = config
        self.device = torch.device(config.device)

        # Setup logging
        logging.basicConfig(
            level=getattr(logging, config.log_level),
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

        # Initialize components
        self.parser = TranscriptomeParser(config)
        self.model = None
        self.tokenizer = None
        self.vector_store = GeneAwareVectorStore(config)
        self.assignment_engine = UncertaintyAwareAssignment(config)

        # Data storage. Everything populated by load_data() is declared here
        # so that calling train()/save_results() too early produces a clear
        # error (or an empty section) instead of an AttributeError.
        self.gene_families = None
        self.transcripts = None
        self.gene_family_stats = None
        self.reads = None
        self.ground_truth = None
        self.training_stats = {}

        # Create output directory
        os.makedirs(config.output_dir, exist_ok=True)

    def load_data(self, transcriptome_path: str, reads_path: str = None, ground_truth_path: str = None):
        """Load transcriptome, reads, and ground truth data.

        Args:
            transcriptome_path: Path handed to the transcriptome parser.
            reads_path: Optional FASTQ file with reads.
            ground_truth_path: Optional CSV file with ground-truth labels.
        """
        # Load transcriptome
        self.gene_families, self.transcripts = self.parser.load_transcriptome(transcriptome_path)

        # Analyze gene families
        self.gene_family_stats = self.parser.analyze_gene_families(self.gene_families)

        # Load reads if provided (reset first so a reload never keeps stale data)
        self.reads = None
        if reads_path:
            self.reads = self._load_reads(reads_path)

        # Load ground truth if provided
        self.ground_truth = None
        if ground_truth_path:
            self.ground_truth = pd.read_csv(ground_truth_path)
            self.logger.info(f"Loaded ground truth with {len(self.ground_truth)} entries")

    def _load_reads(self, reads_path: str) -> List[str]:
        """Load sequencing reads from FASTQ file and return their sequences."""
        self.logger.info(f"Loading reads from {reads_path}")

        reads = [str(record.seq) for record in SeqIO.parse(reads_path, "fastq")]

        self.logger.info(f"Loaded {len(reads)} reads")
        return reads

    def initialize_model(self):
        """Initialize the neural model, tokenizer and contrastive loss."""
        self.logger.info("Initializing model...")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)

        # Initialize model
        self.model = GeneHierarchicalEncoder(self.config).to(self.device)

        # Initialize loss function
        self.contrastive_loss = GeneAwareContrastiveLoss(self.config)

        self.logger.info("Model initialized successfully")

    def train(self):
        """Train the model with gene-aware contrastive learning.

        Batches without labels, with a non-finite loss, or with a loss above
        100 are skipped. Per-epoch mean losses are stored in
        ``self.training_stats``.

        Raises:
            ValueError: If no ground truth has been loaded or the training
                loader yields no batches.
        """
        if self.model is None:
            self.initialize_model()

        if self.ground_truth is None:
            raise ValueError("Ground truth data required for training")

        self.logger.info("Starting training...")

        # Create dataset and dataloader
        train_dataset = GeneAwareDataset(
            reads=self.reads,
            ground_truth=self.ground_truth,
            gene_families=self.gene_families,
            transcripts=self.transcripts,
            config=self.config,
            mode="train"
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            collate_fn=custom_collate_fn,
            pin_memory=True
        )

        num_batches = len(train_loader)
        if num_batches == 0:
            raise ValueError("Training dataset is empty")

        # Initialize optimizer and scheduler
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        total_steps = num_batches * self.config.num_epochs
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.config.learning_rate,
            total_steps=total_steps,
            pct_start=0.1
        )

        # Stable, collision-free integer ids for the contrastive labels.
        # (hash(name) % N can map two different genes/transcripts onto the
        # same label, and str hashes are salted per interpreter process.)
        # NOTE(review): assumes the loss only compares labels for equality.
        gene_label_ids = {}
        transcript_label_ids = {}

        # Training loop
        self.model.train()
        training_losses = []

        for epoch in range(self.config.num_epochs):
            epoch_losses = []

            for batch_idx, batch in enumerate(train_loader):
                # Skip batch if no training labels or gene names
                if 'true_transcript' not in batch or 'gene_name' not in batch:
                    continue

                # Keep only samples that carry both a gene and a transcript label
                valid = [
                    (i, gene_name, transcript_id)
                    for i, (gene_name, transcript_id)
                    in enumerate(zip(batch['gene_name'], batch['true_transcript']))
                    if gene_name is not None and transcript_id is not None
                ]
                if not valid:
                    continue  # Skip batch if no valid samples

                valid_indices = [i for i, _, _ in valid]

                # Move to device and filter tensors to valid indices only
                input_ids = batch['input_ids'].to(self.device)[valid_indices]
                attention_mask = batch['attention_mask'].to(self.device)[valid_indices]

                # setdefault() hands out the next free id on first sight
                gene_labels = torch.tensor(
                    [gene_label_ids.setdefault(g, len(gene_label_ids)) for _, g, _ in valid],
                    device=self.device
                )
                transcript_labels = torch.tensor(
                    [transcript_label_ids.setdefault(t, len(transcript_label_ids)) for _, _, t in valid],
                    device=self.device
                )

                # Forward pass
                outputs = self.model(input_ids, attention_mask)

                # Compute contrastive loss
                loss_dict = self.contrastive_loss(
                    outputs['gene_embedding'],
                    outputs['isoform_embedding'],
                    gene_labels,
                    transcript_labels
                )

                total_loss = loss_dict['total_loss']
                loss_value = total_loss.item()

                # Check for invalid loss values
                if torch.isnan(total_loss) or torch.isinf(total_loss):
                    self.logger.warning(f"Invalid loss detected: {loss_value}, skipping batch")
                    continue

                # Skip if loss is too large (potential exploding gradient)
                if loss_value > 100:
                    self.logger.warning(f"Loss too large: {loss_value}, skipping batch")
                    continue

                # Backward pass
                optimizer.zero_grad()
                total_loss.backward()

                # Add gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                optimizer.step()
                scheduler.step()

                # Record loss
                epoch_losses.append(loss_value)

                # Log progress
                if batch_idx % 50 == 0:
                    self.logger.info(
                        f"Epoch {epoch+1}/{self.config.num_epochs}, "
                        f"Batch {batch_idx}/{num_batches}, "
                        f"Loss: {loss_value:.4f}"
                    )

            if epoch_losses:
                avg_loss = float(np.mean(epoch_losses))
            else:
                # Every batch was skipped: keep the epoch slot but make the
                # missing value explicit rather than relying on mean([]).
                avg_loss = float('nan')
                self.logger.warning(f"Epoch {epoch+1}: no batch contributed to training")
            training_losses.append(avg_loss)

            self.logger.info(f"Epoch {epoch+1} completed. Average loss: {avg_loss:.4f}")

        # Save training stats
        self.training_stats = {
            'training_losses': training_losses,
            'final_loss': training_losses[-1] if training_losses else 0,
            'num_epochs': self.config.num_epochs,
            'total_batches': num_batches * self.config.num_epochs
        }

        self.logger.info("Training completed!")

    def build_index(self):
        """Build the gene-aware vector search index.

        Raises:
            ValueError: If the model has not been initialized.
        """
        if self.model is None:
            raise ValueError("Model must be initialized before building index")

        self.vector_store.build_index(
            self.gene_families,
            self.transcripts,
            self.model,
            self.tokenizer
        )

    def quantify_reads(self,
                      reads: List[str] = None,
                      top_k_genes: int = 10,
                      top_k_isoforms: int = 5) -> Dict[str, any]:
        """Quantify reads using gene-aware hierarchical assignment.

        Args:
            reads: Read sequences; defaults to the reads loaded by ``load_data``.
            top_k_genes: Number of gene candidates requested from the index.
            top_k_isoforms: Number of isoform candidates kept per read.

        Returns:
            Dict with ``assignments`` (one dict per read), ``stats`` (counts
            per assignment type) and ``transcript_counts``.

        Raises:
            ValueError: If there are no reads, the index has not been built,
                or the model/tokenizer are not initialized.
        """
        if reads is None:
            reads = self.reads

        if not reads:
            raise ValueError("No reads provided for quantification")

        if self.vector_store.gene_index is None:
            raise ValueError("Vector store index must be built before quantification")

        if self.model is None or self.tokenizer is None:
            raise ValueError("Model must be initialized before quantification")

        self.logger.info(f"Quantifying {len(reads)} reads...")

        # Results storage
        assignments = []
        assignment_stats = {
            'total_reads': len(reads),
            'high_confidence': 0,
            'unique_match': 0,
            'best_match': 0,
            'gene_only': 0,
            'unassigned': 0
        }

        # Process reads in batches
        batch_size = self.config.batch_size
        self.model.eval()

        with torch.no_grad():
            for i in range(0, len(reads), batch_size):
                batch_reads = reads[i:i+batch_size]

                # Tokenize batch (ambiguous 'N' bases are replaced by 'A')
                batch_tokens = [
                    self.tokenizer(
                        read.upper().replace('N', 'A'),
                        max_length=self.config.max_sequence_length,
                        padding='max_length',
                        truncation=True,
                        return_tensors='pt'
                    )
                    for read in batch_reads
                ]

                # Stack tensors
                input_ids = torch.stack([t['input_ids'].squeeze(0) for t in batch_tokens]).to(self.device)
                attention_mask = torch.stack([t['attention_mask'].squeeze(0) for t in batch_tokens]).to(self.device)

                # Get embeddings
                outputs = self.model(input_ids, attention_mask)

                # Process each read in the batch
                for j, read in enumerate(batch_reads):
                    read_idx = i + j

                    # Extract embeddings for this read (slice keeps the batch axis)
                    gene_emb = outputs['gene_embedding'][j:j+1].cpu().numpy()
                    isoform_emb = outputs['isoform_embedding'][j:j+1].cpu().numpy()

                    # Hierarchical search
                    search_results = self.vector_store.hierarchical_search(
                        gene_emb, isoform_emb, top_k_genes, top_k_isoforms
                    )

                    # Assignment
                    assignment = self.assignment_engine.hierarchical_assignment(search_results)

                    # Record assignment
                    assignment['read_id'] = f"read_{read_idx:06d}"
                    assignment['read_sequence'] = read
                    assignments.append(assignment)

                    # Update stats
                    assignment_stats[assignment['assignment_type']] += 1

                # Log progress
                if i % (batch_size * 10) == 0:
                    self.logger.info(f"Processed {i+len(batch_reads)}/{len(reads)} reads")

        self.logger.info("Quantification completed!")
        self.logger.info(f"Assignment statistics: {assignment_stats}")

        return {
            'assignments': assignments,
            'stats': assignment_stats,
            'transcript_counts': self._compute_transcript_counts(assignments)
        }

    def _compute_transcript_counts(self, assignments: List[Dict]) -> Dict[str, Dict]:
        """Compute transcript abundance estimates from assignments.

        Args:
            assignments: Assignment dicts; entries with a falsy
                ``assigned_transcript`` are ignored.

        Returns:
            Mapping transcript id -> ``raw_count``, ``weighted_count`` (sum of
            confidences), ``tpm``, ``average_confidence`` and ``gene_name``.
            Note that ``tpm`` is the per-million share of the
            confidence-weighted counts; no transcript-length normalisation is
            applied here.
        """
        # Only running totals are kept; the individual assignment dicts are
        # not needed to derive any of the reported values.
        tallies = defaultdict(lambda: {'raw_count': 0, 'weighted_count': 0.0})

        for assignment in assignments:
            tid = assignment['assigned_transcript']
            if tid:
                tallies[tid]['raw_count'] += 1
                tallies[tid]['weighted_count'] += assignment['confidence']

        total_weighted = sum(t['weighted_count'] for t in tallies.values())
        transcripts = self.transcripts or {}

        # Compute final metrics
        final_counts = {}
        for tid, tally in tallies.items():
            raw_count = tally['raw_count']
            weighted = tally['weighted_count']

            final_counts[tid] = {
                'raw_count': raw_count,
                'weighted_count': weighted,
                'tpm': (weighted / total_weighted * 1e6) if total_weighted > 0 else 0.0,
                'average_confidence': weighted / raw_count if raw_count > 0 else 0.0,
                'gene_name': transcripts[tid].gene_name if tid in transcripts else 'unknown'
            }

        return final_counts

    @staticmethod
    def _json_default(value):
        """Fallback for values ``json`` cannot encode (e.g. numpy scalars, paths)."""
        if hasattr(value, 'tolist'):
            # numpy scalars and arrays convert to native Python values
            return value.tolist()
        return str(value)

    def save_results(self, results: Dict[str, any], output_prefix: str = "deepquant_v2"):
        """Save quantification results to files.

        Writes ``<prefix>_assignments.csv``, ``<prefix>_transcript_counts.csv``
        and ``<prefix>_stats.json`` into ``config.output_dir``.

        Args:
            results: The dict returned by ``quantify_reads``.
            output_prefix: File name prefix for the three output files.

        Returns:
            Dict with the paths ``assignments_file``, ``counts_file`` and
            ``stats_file``.
        """
        output_dir = Path(self.config.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save assignments
        assignments_df = pd.DataFrame(results['assignments'])
        assignments_file = output_dir / f"{output_prefix}_assignments.csv"
        assignments_df.to_csv(assignments_file, index=False)

        # Save transcript counts. Columns are fixed up front so that sorting
        # (and the CSV header) still work when no transcript was assigned.
        count_columns = ['transcript_id', 'gene_name', 'raw_count',
                         'weighted_count', 'tpm', 'average_confidence']
        counts_data = [
            {
                'transcript_id': tid,
                'gene_name': counts['gene_name'],
                'raw_count': counts['raw_count'],
                'weighted_count': counts['weighted_count'],
                'tpm': counts['tpm'],
                'average_confidence': counts['average_confidence']
            }
            for tid, counts in results['transcript_counts'].items()
        ]

        counts_df = pd.DataFrame(counts_data, columns=count_columns).sort_values(
            'weighted_count', ascending=False
        )
        counts_file = output_dir / f"{output_prefix}_transcript_counts.csv"
        counts_df.to_csv(counts_file, index=False)

        # Save statistics
        stats_file = output_dir / f"{output_prefix}_stats.json"
        with open(stats_file, 'w') as f:
            json.dump({
                'assignment_stats': results['stats'],
                'gene_family_stats': self.gene_family_stats,
                'training_stats': self.training_stats,
                'config': self.config.__dict__
            }, f, indent=2, default=self._json_default)

        self.logger.info(f"Results saved to {output_dir}")
        self.logger.info(f"  Assignments: {assignments_file}")
        self.logger.info(f"  Counts: {counts_file}")
        self.logger.info(f"  Stats: {stats_file}")

        return {
            'assignments_file': assignments_file,
            'counts_file': counts_file,
            'stats_file': stats_file
        }

### Evaluation and Validation Framework

In [11]:
class DeepQuantEvaluator:
    """Comprehensive evaluation framework for DeepQuant V2.

    Compares per-read assignments with a ground-truth table and writes a JSON
    report of the resulting metrics.
    """

    # Column layout of the per-read evaluation frame. Fixing it here keeps the
    # frame well-formed even when there is nothing to evaluate.
    _EVAL_COLUMNS = [
        'read_id', 'predicted_transcript', 'predicted_gene',
        'true_transcript', 'true_gene', 'confidence', 'assignment_type'
    ]

    def __init__(self, config: "DeepQuantConfig"):
        self.config = config
        self.logger = logging.getLogger(__name__)

    def evaluate_against_ground_truth(self,
                                    assignments: List[Dict],
                                    ground_truth: pd.DataFrame) -> Dict[str, any]:
        """Evaluate assignments against ground truth.

        Assignments and ground-truth rows are paired by position (the i-th
        assignment with the i-th row); surplus entries on either side are
        ignored.

        Args:
            assignments: Assignment dicts, each with at least ``read_id``.
            ground_truth: Frame with a ``true_transcript`` column and an
                optional ``true_gene`` column (``'unknown'`` when absent).

        Returns:
            The metrics dict produced by ``_compute_evaluation_metrics``.
        """
        self.logger.info("Evaluating against ground truth...")

        # Align assignments with ground truth; one bulk conversion instead of
        # a per-row ``iloc`` lookup.
        n_pairs = min(len(assignments), len(ground_truth))
        gt_records = ground_truth.iloc[:n_pairs].to_dict('records')

        eval_data = [
            {
                'read_id': assignment['read_id'],
                'predicted_transcript': assignment.get('assigned_transcript'),
                'predicted_gene': assignment.get('assigned_gene'),
                'true_transcript': gt_row['true_transcript'],
                'true_gene': gt_row.get('true_gene', 'unknown'),
                'confidence': assignment.get('confidence', 0.0),
                'assignment_type': assignment.get('assignment_type', 'unknown')
            }
            for assignment, gt_row in zip(assignments, gt_records)
        ]

        eval_df = pd.DataFrame(eval_data, columns=self._EVAL_COLUMNS)

        # Compute metrics
        return self._compute_evaluation_metrics(eval_df)

    def _compute_evaluation_metrics(self, eval_df: pd.DataFrame) -> Dict[str, any]:
        """Compute comprehensive evaluation metrics.

        Args:
            eval_df: Frame with the columns listed in ``_EVAL_COLUMNS``.

        Returns:
            Dict of native Python numbers: rates and accuracies are fractions
            of all evaluated reads; ``weighted_transcript_accuracy`` and
            ``average_confidence`` cover assigned reads only. An empty frame
            yields all-zero metrics.
        """
        total_reads = len(eval_df)

        if total_reads == 0:
            # Nothing to evaluate: report zeros instead of dividing by zero.
            return {
                'total_reads': 0,
                'assignment_rate': 0.0,
                'transcript_accuracy': 0.0,
                'gene_accuracy': 0.0,
                'weighted_transcript_accuracy': 0.0,
                'average_confidence': 0.0,
                'assignment_breakdown': {},
                'confidence_stats': {
                    'mean': 0.0, 'std': 0.0, 'median': 0.0, 'q75': 0.0, 'q25': 0.0
                }
            }

        is_assigned = eval_df['predicted_transcript'].notna()
        transcript_hit = eval_df['predicted_transcript'] == eval_df['true_transcript']
        # Gene-level accuracy (including partial assignments)
        gene_hit = eval_df['predicted_gene'] == eval_df['true_gene']

        # Basic assignment metrics
        assignment_rate = float(is_assigned.sum() / total_reads)
        transcript_accuracy = float(transcript_hit.sum() / total_reads)
        gene_accuracy = float(gene_hit.sum() / total_reads)

        # Confidence-weighted metrics over assigned reads only
        assigned_conf = eval_df.loc[is_assigned, 'confidence']
        if len(assigned_conf) > 0:
            conf_total = float(assigned_conf.sum())
            if conf_total > 0:
                weighted_transcript_acc = float(
                    (transcript_hit[is_assigned] * assigned_conf).sum() / conf_total
                )
            else:
                # All confidences are zero: the weighted mean is undefined
                weighted_transcript_acc = 0.0
            avg_confidence = float(assigned_conf.mean())
        else:
            weighted_transcript_acc = 0.0
            avg_confidence = 0.0

        # Assignment type breakdown
        assignment_breakdown = {
            key: int(count)
            for key, count in eval_df['assignment_type'].value_counts().items()
        }

        # Confidence analysis (sample std is undefined for a single read)
        confidence = eval_df['confidence']
        confidence_stats = {
            'mean': float(confidence.mean()),
            'std': float(confidence.std()) if total_reads > 1 else 0.0,
            'median': float(confidence.median()),
            'q75': float(confidence.quantile(0.75)),
            'q25': float(confidence.quantile(0.25))
        }

        return {
            'total_reads': total_reads,
            'assignment_rate': assignment_rate,
            'transcript_accuracy': transcript_accuracy,
            'gene_accuracy': gene_accuracy,
            'weighted_transcript_accuracy': weighted_transcript_acc,
            'average_confidence': avg_confidence,
            'assignment_breakdown': assignment_breakdown,
            'confidence_stats': confidence_stats
        }

    def generate_evaluation_report(self, metrics: Dict[str, any], output_dir: str):
        """Write the evaluation metrics as ``evaluation_report.json``.

        Args:
            metrics: Dict returned by ``evaluate_against_ground_truth``.
            output_dir: Target directory; created if it does not exist.

        Returns:
            Path of the written report file.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Create evaluation report
        report = {
            'summary': {
                'assignment_rate': f"{metrics['assignment_rate']:.1%}",
                'transcript_accuracy': f"{metrics['transcript_accuracy']:.1%}",
                'gene_accuracy': f"{metrics['gene_accuracy']:.1%}",
                'average_confidence': f"{metrics['average_confidence']:.3f}"
            },
            'detailed_metrics': metrics
        }

        # Save report
        report_file = output_path / "evaluation_report.json"
        with open(report_file, 'w') as f:
            json.dump(report, f, indent=2)

        self.logger.info(f"Evaluation report saved to {report_file}")

        return report_file

### Execution Function

In [12]:
def main(
    transcriptome_path: str = "/run/media/saadat/A/tools/DeepQuant/gencode.v47.transcripts_1000.fa",
    reads_path: str = "/run/media/saadat/A/tools/DeepQuant/simulated_reads.fastq",
    ground_truth_path: str = "/run/media/saadat/A/tools/DeepQuant/simulated_ground_truth.csv",
    output_dir: str = "./deepquant_v2_results",
):
    """Main execution function for DeepQuant V2.

    Runs the pipeline end to end: load data, initialise and train the model,
    build the search index, quantify reads, save results, evaluate against the
    ground truth, write the evaluation report and print a summary.

    Args:
        transcriptome_path: Transcriptome file passed to ``load_data``.
        reads_path: Reads file passed to ``load_data``.
        ground_truth_path: Ground-truth file passed to ``load_data``.
        output_dir: Output directory stored in the configuration and used for
            the evaluation report.

    The defaults reproduce the previously hard-coded locations, so calling
    ``main()`` without arguments behaves as before; pass explicit paths to run
    on another machine or dataset.

    Returns:
        Tuple ``(dq, results, eval_metrics)``: the ``DeepQuantV2`` instance,
        the value returned by ``quantify_reads`` and the metrics returned by
        ``evaluate_against_ground_truth``.
    """

    # Configuration
    config = DeepQuantConfig(
        # Model parameters
        embedding_dim=256,
        gene_embedding_dim=128,
        batch_size=8,  # Reduced for memory efficiency
        learning_rate=1e-5,
        num_epochs=5,

        # Assignment parameters
        gene_similarity_threshold=0.7,
        isoform_similarity_threshold=0.85,
        uncertainty_threshold=0.8,

        # Paths
        output_dir=output_dir
    )

    # Initialize DeepQuant V2
    dq = DeepQuantV2(config)

    # Load data
    print("🧬 Loading transcriptome and data...")
    dq.load_data(transcriptome_path, reads_path, ground_truth_path)

    # Initialize and train model
    print("🤖 Initializing and training model...")
    dq.initialize_model()
    dq.train()

    # Build search index
    print("🔍 Building gene-aware search index...")
    dq.build_index()

    # Quantify reads
    print("📊 Quantifying reads...")
    results = dq.quantify_reads()

    # Save results (the return value was never used, so it is not bound)
    print("💾 Saving results...")
    dq.save_results(results)

    # Evaluation
    print("📈 Evaluating results...")
    evaluator = DeepQuantEvaluator(config)
    eval_metrics = evaluator.evaluate_against_ground_truth(
        results['assignments'],
        dq.ground_truth
    )

    # Generate evaluation report
    evaluator.generate_evaluation_report(eval_metrics, config.output_dir)

    # Print summary
    print("\n" + "="*60)
    print("🎯 DEEPQUANT V2 RESULTS SUMMARY")
    print("="*60)
    print(f"📋 Total reads processed: {eval_metrics['total_reads']}")
    print(f"✅ Assignment rate: {eval_metrics['assignment_rate']:.1%}")
    print(f"🎯 Transcript accuracy: {eval_metrics['transcript_accuracy']:.1%}")
    print(f"🧬 Gene accuracy: {eval_metrics['gene_accuracy']:.1%}")
    print(f"🎲 Average confidence: {eval_metrics['average_confidence']:.3f}")
    print(f"📁 Results saved to: {config.output_dir}")
    print("="*60)

    return dq, results, eval_metrics

# Entry point: run the full pipeline when this module is executed directly.
# The (dq, results, eval_metrics) tuple returned by main() is discarded here.
if __name__ == "__main__":
    main()

2025-08-25 20:01:20,835 - __main__ - INFO - Loading transcriptome from /run/media/saadat/A/tools/DeepQuant/gencode.v47.transcripts_1000.fa
2025-08-25 20:01:20,860 - __main__ - INFO - Loaded 1000 transcripts, 61 gene families
2025-08-25 20:01:20,860 - __main__ - INFO - Skipped 0 transcripts
2025-08-25 20:01:20,861 - __main__ - INFO - Analyzing gene family complexity...
2025-08-25 20:01:20,861 - __main__ - INFO - Gene family analysis complete:
2025-08-25 20:01:20,861 - __main__ - INFO -   Total genes: 61
2025-08-25 20:01:20,862 - __main__ - INFO -   Single isoform genes: 43
2025-08-25 20:01:20,862 - __main__ - INFO -   Multi-isoform genes: 18
2025-08-25 20:01:20,862 - __main__ - INFO -   Most complex gene: 332 isoforms
2025-08-25 20:01:20,863 - __main__ - INFO - Loading reads from /run/media/saadat/A/tools/DeepQuant/simulated_reads.fastq
2025-08-25 20:01:20,910 - __main__ - INFO - Loaded 10000 reads
2025-08-25 20:01:20,930 - __main__ - INFO - Loaded ground truth with 10000 entries
2025-0

🧬 Loading transcriptome and data...
🤖 Initializing and training model...


2025-08-25 20:01:23.607105: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-25 20:01:23.618871: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756144883.631121   11972 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756144883.635294   11972 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1756144883.645022   11972 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU 0 has a total capacity of 7.61 GiB of which 50.06 MiB is free. Including non-PyTorch memory, this process has 6.81 GiB memory in use. Of the allocated memory 6.42 GiB is allocated by PyTorch, and 238.37 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)