In [8]:
!git clone https://github.com/JeissonParra12/Class_project_Early_detection_of_brain_tumor.git

fatal: destination path 'Class_project_Early_detection_of_brain_tumor' already exists and is not an empty directory.


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cv2
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Union
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from sklearn.metrics import precision_recall_curve, average_precision_score
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# REGION PROPOSAL NETWORK (RPN) FOR TUMOR DETECTION - FIXED VERSION
# ============================================================================

class LightweightRPN(nn.Module):
    """
    Lightweight Region Proposal Network for detecting potential tumor regions.

    Optimized for high recall to ensure small lesions are not overlooked.

    The backbone applies two 2x2 max-pools, so the feature map is 1/4 of the
    input resolution. For every feature-map cell the heads emit
    ``num_anchors`` score pairs and ``num_anchors`` box-delta 4-vectors.
    ``forward`` flattens them cell-major (row, then column) and then
    anchor-within-cell, which is the same order ``generate_anchors`` uses.

    NOTE(review): ``generate_anchors`` always yields
    ``len(anchor_scales) * len(anchor_ratios)`` (= 9) anchors per cell, so
    ``num_anchors`` must stay 9 for predictions and anchors to line up.
    """

    def __init__(self, input_channels: int = 4, num_anchors: int = 9):
        super(LightweightRPN, self).__init__()

        self.input_channels = input_channels
        self.num_anchors = num_anchors

        # Feature extraction backbone (lightweight); overall output stride 4.
        self.backbone = nn.Sequential(
            # First conv block
            nn.Conv2d(input_channels, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # Second conv block
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # Third conv block
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Classification head: 2 scores per anchor. Downstream code in this
        # file applies softmax over the pair and reads index 1 as "tumor".
        self.cls_head = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_anchors * 2, 1)
        )

        # Regression head: 4 box deltas (dx, dy, dw, dh) per anchor.
        self.reg_head = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_anchors * 4, 1)
        )

        # Anchor scales (pixels) for small/medium/large tumors and aspect ratios (w/h).
        self.anchor_scales = [32, 64, 128]
        self.anchor_ratios = [0.5, 1.0, 2.0]

        # Most recently generated anchors, plus the (feature size, image size,
        # device) key they were built for. Anchors are regenerated whenever the
        # key changes instead of being reused for a different geometry/device.
        self.anchors = None
        self._anchor_cache_key = None

    def generate_anchors(self, feature_map_size: Tuple[int, int], device: torch.device,
                         image_size: Tuple[int, int] = (224, 224)) -> torch.Tensor:
        """
        Generate anchor boxes in ``[x1, y1, x2, y2]`` image coordinates.

        Args:
            feature_map_size: (H, W) of the backbone feature map.
            device: Device to place the anchors on.
            image_size: (H, W) of the network input; used to derive the
                stride between anchor centres. Defaults to 224x224.

        Returns:
            Float32 tensor of shape ``(H * W * A, 4)`` with
            ``A = len(anchor_scales) * len(anchor_ratios)``. The rows are ordered
            by cell (row-major), then scale, then ratio.
        """
        feat_h, feat_w = int(feature_map_size[0]), int(feature_map_size[1])
        img_h, img_w = image_size
        cache_key = ((feat_h, feat_w), (img_h, img_w), torch.device(device))
        if self.anchors is not None and self._anchor_cache_key == cache_key:
            return self.anchors

        stride_h = img_h / feat_h
        stride_w = img_w / feat_w

        # Per-cell anchor sizes, scale-major then ratio (computed in float64
        # and rounded once to float32 at the end).
        scales = torch.tensor(self.anchor_scales, dtype=torch.float64)
        sqrt_ratios = torch.tensor(self.anchor_ratios, dtype=torch.float64).sqrt()
        half_w = (scales[:, None] * sqrt_ratios[None, :]).reshape(-1) / 2
        half_h = (scales[:, None] / sqrt_ratios[None, :]).reshape(-1) / 2

        # Cell centres, enumerated row-major: shape (H * W,)
        centers_y = (torch.arange(feat_h, dtype=torch.float64) + 0.5) * stride_h
        centers_x = (torch.arange(feat_w, dtype=torch.float64) + 0.5) * stride_w
        cy = centers_y.repeat_interleave(feat_w)[:, None]
        cx = centers_x.repeat(feat_h)[:, None]

        anchors = torch.stack(
            [cx - half_w[None, :], cy - half_h[None, :],
             cx + half_w[None, :], cy + half_h[None, :]],
            dim=-1,
        ).reshape(-1, 4)
        anchors = anchors.to(torch.float32).to(device)

        self.anchors = anchors
        self._anchor_cache_key = cache_key
        return anchors

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the RPN.

        Args:
            x: Input batch of shape (batch, input_channels, H, W).

        Returns:
            Tuple ``(cls_scores, reg_preds)`` with shapes
            ``(batch, num_cells * num_anchors, 2)`` and
            ``(batch, num_cells * num_anchors, 4)``. The anchors matching these
            rows are stored in ``self.current_anchors``.
        """
        batch_size = x.size(0)

        features = self.backbone(x)  # (batch_size, 256, H/4, W/4)
        feature_map_size = (features.size(2), features.size(3))

        # Derive the stride from the real input size rather than assuming 224x224.
        anchors = self.generate_anchors(feature_map_size, x.device,
                                        image_size=(x.size(2), x.size(3)))

        cls_scores = self.cls_head(features)
        reg_preds = self.reg_head(features)

        # (B, A*k, H, W) -> (B, H, W, A*k) -> (B, H*W*A, k): cell-major, anchor-minor.
        cls_scores = cls_scores.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
        reg_preds = reg_preds.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4)

        # Store anchors for later use (e.g. proposal decoding).
        self.current_anchors = anchors

        return cls_scores, reg_preds

# ============================================================================
# SIMPLIFIED TWO-STAGE DETECTION AND CLASSIFICATION SYSTEM - FIXED VERSION
# ============================================================================

class BrainTumorDetectionSystem(nn.Module):
    """
    Simplified two-stage brain tumor detection and classification system.

    Stage 1 is a ``LightweightRPN`` that scores anchor boxes. Stage 2 is a
    ``CorrelationLearningMechanism`` (defined elsewhere in the notebook). Its
    features are concatenated with the RPN's highest per-image tumor
    probability and passed through a small fusion MLP.
    """

    # Upper bound applied to the predicted log-scale deltas (dw, dh) before
    # torch.exp, so extreme regression outputs cannot blow boxes up
    # unboundedly. Same value as torchvision's default box-coder clip.
    BBOX_LOG_SCALE_CLIP = float(np.log(1000.0 / 16))

    def __init__(self, input_channels: int = 4, num_classes: int = 2):
        super(BrainTumorDetectionSystem, self).__init__()

        self.input_channels = input_channels

        # Stage 1: Region Proposal Network
        self.rpn = LightweightRPN(input_channels)

        # Stage 2: CLM-based classification
        self.clm_classifier = CorrelationLearningMechanism(input_channels, num_classes)

        # Probe the CLM once to learn its (flattened) feature dimension.
        self.clm_feature_dim = self._get_clm_feature_dim()

        print(f"CLM feature dimension: {self.clm_feature_dim}")

        # Fusion layer: CLM features + 2 RPN-derived scores -> class logits.
        self.fusion_layer = nn.Sequential(
            nn.Linear(self.clm_feature_dim + 2, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def _get_clm_feature_dim(self) -> int:
        """
        Return the size of the CLM feature vector for a 224x224 input.

        The dummy batch uses ``self.input_channels`` channels rather than a
        hard-coded 4. The probe runs in eval mode and the previous mode is then
        restored: if the CLM contains BatchNorm/Dropout layers (not visible
        here), train mode would update running statistics from random data.
        """
        dummy_input = torch.randn(1, self.input_channels, 224, 224)
        was_training = self.clm_classifier.training
        self.clm_classifier.eval()
        try:
            with torch.no_grad():
                features, _ = self.clm_classifier(dummy_input)
        finally:
            self.clm_classifier.train(was_training)
        if features.dim() == 4:
            features = F.adaptive_avg_pool2d(features, (1, 1))
            features = features.view(1, -1)
        return features.size(1)

    def forward(self, x: torch.Tensor):
        """
        Run both stages and fuse them into an image-level classification.

        Returns:
            ``(final_classification, rpn_cls, rpn_reg, clm_features)``, where
            ``final_classification`` has shape (batch, num_classes) and
            ``clm_features`` is the (batch, clm_feature_dim) vector fed to the
            fusion layer.
        """
        batch_size = x.size(0)

        rpn_cls, rpn_reg = self.rpn(x)

        # Tumor probability per anchor (index 1), then the max over anchors.
        rpn_scores = F.softmax(rpn_cls, dim=2)[:, :, 1]
        rpn_max_scores, _ = rpn_scores.max(dim=1)

        # (batch_size, 2): [max tumor score, 1 - max tumor score]
        rpn_feature = torch.stack([rpn_max_scores, 1 - rpn_max_scores], dim=1)

        clm_features, clm_classification = self.clm_classifier(x)

        # Spatial CLM features are global-average-pooled to a vector.
        if clm_features.dim() == 4:
            clm_features = F.adaptive_avg_pool2d(clm_features, (1, 1))
            clm_features = clm_features.view(batch_size, -1)

        # Defensive pad/truncate so the fusion layer's input size always matches.
        if clm_features.size(1) != self.clm_feature_dim:
            if clm_features.size(1) < self.clm_feature_dim:
                padding = torch.zeros(batch_size, self.clm_feature_dim - clm_features.size(1),
                                      device=x.device)
                clm_features = torch.cat([clm_features, padding], dim=1)
            else:
                clm_features = clm_features[:, :self.clm_feature_dim]

        combined_features = torch.cat([clm_features, rpn_feature], dim=1)
        final_classification = self.fusion_layer(combined_features)

        return final_classification, rpn_cls, rpn_reg, clm_features

    def detect_regions(self, x: torch.Tensor, threshold: float = 0.3):
        """
        Produce per-image box proposals from the RPN.

        Args:
            x: Input batch.
            threshold: Anchors whose tumor probability is strictly greater than
                this are kept.

        Returns:
            ``(proposals_list, confidence_list)``: one entry per image, a
            (k, 4) tensor of regressed boxes and a (k,) tensor of their tumor
            probabilities (k may be 0).
        """
        batch_size = x.size(0)
        rpn_cls, rpn_reg = self.rpn(x)
        all_anchors = self.rpn.current_anchors

        proposals_list = []
        confidence_list = []

        for i in range(batch_size):
            image_cls = rpn_cls[i]  # (num_anchors_total, 2)
            image_reg = rpn_reg[i]  # (num_anchors_total, 4)
            probs = F.softmax(image_cls, dim=1)[:, 1]
            anchors = all_anchors

            # Defensive: align lengths if anchors and predictions ever disagree
            # (e.g. num_anchors != 9 in the RPN).
            if anchors.size(0) != probs.size(0):
                min_size = min(anchors.size(0), probs.size(0))
                anchors = anchors[:min_size]
                probs = probs[:min_size]
                image_reg = image_reg[:min_size]

            keep_mask = probs > threshold

            if keep_mask.sum() > 0:
                adjusted_proposals = self._apply_regression(anchors[keep_mask], image_reg[keep_mask])
                proposals_list.append(adjusted_proposals)
                confidence_list.append(probs[keep_mask])
            else:
                proposals_list.append(torch.empty(0, 4, device=x.device))
                confidence_list.append(torch.empty(0, device=x.device))

        return proposals_list, confidence_list

    def _apply_regression(self, anchors: torch.Tensor, regressions: torch.Tensor) -> torch.Tensor:
        """
        Decode (dx, dy, dw, dh) deltas against ``[x1, y1, x2, y2]`` anchors.

        Centres shift by ``d * size`` and sizes scale by ``exp(d)``. ``dw`` and
        ``dh`` are clamped to ``BBOX_LOG_SCALE_CLIP`` first. The decoded
        boxes are clipped to [0, 223] (a 224x224 image).
        """
        widths = anchors[:, 2] - anchors[:, 0]
        heights = anchors[:, 3] - anchors[:, 1]
        center_x = anchors[:, 0] + 0.5 * widths
        center_y = anchors[:, 1] + 0.5 * heights

        dx = regressions[:, 0]
        dy = regressions[:, 1]
        dw = regressions[:, 2].clamp(max=self.BBOX_LOG_SCALE_CLIP)
        dh = regressions[:, 3].clamp(max=self.BBOX_LOG_SCALE_CLIP)

        pred_center_x = center_x + dx * widths
        pred_center_y = center_y + dy * heights
        pred_width = widths * torch.exp(dw)
        pred_height = heights * torch.exp(dh)

        pred_x1 = torch.clamp(pred_center_x - 0.5 * pred_width, 0, 223)
        pred_y1 = torch.clamp(pred_center_y - 0.5 * pred_height, 0, 223)
        pred_x2 = torch.clamp(pred_center_x + 0.5 * pred_width, 0, 223)
        pred_y2 = torch.clamp(pred_center_y + 0.5 * pred_height, 0, 223)

        return torch.stack([pred_x1, pred_y1, pred_x2, pred_y2], dim=1)

# ============================================================================
# SIZE-AWARE LOSS FUNCTIONS
# ============================================================================

class SizeAwareLoss(nn.Module):
    """
    Size-aware focal loss that gives higher weight to small tumor examples.

    Per-sample loss is ``alpha * (1 - p_t) ** gamma * CE``. When tumor areas
    are supplied, each sample is further multiplied by a size-bucket weight:
    ``area < small_area`` -> 'small', ``small_area <= area < medium_area`` ->
    'medium', otherwise 'large'.
    """

    _DEFAULT_SIZE_WEIGHTS = {
        'small': 3.0,    # Highest weight for small tumors
        'medium': 1.5,   # Medium weight for medium tumors
        'large': 1.0,    # Base weight for large tumors
    }

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0, size_weights: Dict[str, float] = None,
                 small_area: float = 500, medium_area: float = 2000):
        """
        Args:
            alpha: Constant focal-loss scale factor.
            gamma: Focusing exponent on ``(1 - p_t)``.
            size_weights: Optional overrides for the 'small' / 'medium' /
                'large' weights; missing keys fall back to the defaults.
            small_area: Area threshold (pixels) below which a tumor is 'small'.
            medium_area: Area threshold (pixels) below which a tumor is 'medium'.
        """
        super(SizeAwareLoss, self).__init__()

        self.alpha = alpha
        self.gamma = gamma
        self.small_area = small_area
        self.medium_area = medium_area

        # Merge over defaults so partial dicts don't raise KeyError later.
        merged = dict(self._DEFAULT_SIZE_WEIGHTS)
        if size_weights is not None:
            merged.update(size_weights)
        self.size_weights = merged

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor,
                tumor_sizes: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the mean size-aware focal loss.

        Args:
            predictions: (batch, num_classes) logits.
            targets: (batch,) integer class labels.
            tumor_sizes: Optional (batch,) tumor areas in pixels.
        """
        ce_loss = F.cross_entropy(predictions, targets, reduction='none')

        # p_t = exp(-CE) is the predicted probability of the true class.
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if tumor_sizes is not None:
            size_weights = self._get_size_weights(tumor_sizes).to(predictions.device)
            focal_loss = focal_loss * size_weights

        return focal_loss.mean()

    def _get_size_weights(self, tumor_sizes: torch.Tensor) -> torch.Tensor:
        """Map tumor areas to per-sample float32 weights (the 'large' weight applies too)."""
        weights = torch.full_like(tumor_sizes, self.size_weights['large'], dtype=torch.float32)

        small_mask = tumor_sizes < self.small_area
        weights[small_mask] = self.size_weights['small']

        medium_mask = (tumor_sizes >= self.small_area) & (tumor_sizes < self.medium_area)
        weights[medium_mask] = self.size_weights['medium']

        return weights

# ============================================================================
# DETECTION TRAINER
# ============================================================================

class DetectionTrainer:
    """
    Trainer for the detection and classification system.

    ``model(data)`` must return a tuple whose first element is the
    (batch, num_classes) logits. Only that element feeds the cross-entropy loss.
    """

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model.to(device)
        self.device = device

        # Optimizer - only parameters that are trainable at construction time.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.optimizer = optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)

        self.criterion = nn.CrossEntropyLoss()

        # Per-epoch history, appended across all train() calls.
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []

        # Best validation accuracy over *all* train() calls, so a later
        # fine-tuning run only overwrites the checkpoint when it improves.
        self.best_accuracy = 0.0

    def train_epoch(self, train_loader: DataLoader) -> Tuple[float, float]:
        """Train for one epoch; return (mean batch loss, accuracy %). Empty loader -> (0.0, 0)."""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            final_classification, rpn_cls, rpn_reg, clm_features = self.model(data)
            loss = self.criterion(final_classification, targets)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            _, predicted = final_classification.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if batch_idx % 20 == 0:
                accuracy = 100. * correct / total if total > 0 else 0
                print(f'Batch {batch_idx}, Loss: {loss.item():.4f}, Acc: {accuracy:.2f}%')

        num_batches = len(train_loader)
        epoch_loss = running_loss / num_batches if num_batches > 0 else 0.0
        epoch_accuracy = 100. * correct / total if total > 0 else 0

        return epoch_loss, epoch_accuracy

    def validate_epoch(self, val_loader: DataLoader) -> Tuple[float, float]:
        """Evaluate for one epoch without gradients; return (mean batch loss, accuracy %)."""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(self.device), targets.to(self.device)

                final_classification, rpn_cls, rpn_reg, clm_features = self.model(data)
                loss = self.criterion(final_classification, targets)

                running_loss += loss.item()
                _, predicted = final_classification.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        epoch_loss = running_loss / len(val_loader) if len(val_loader) > 0 else 0
        epoch_accuracy = 100. * correct / total if total > 0 else 0

        return epoch_loss, epoch_accuracy

    def train(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 15,
              checkpoint_path: str = 'best_detection_model.pth'):
        """
        Run ``epochs`` train/validate cycles.

        Each epoch's metrics are recorded. A checkpoint is written to
        ``checkpoint_path`` whenever validation accuracy beats
        ``self.best_accuracy``, which persists across calls.
        """
        print("Starting Two-Stage Detection Training...")

        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch(train_loader)
            self.train_losses.append(train_loss)
            self.train_accuracies.append(train_acc)

            val_loss, val_acc = self.validate_epoch(val_loader)
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_acc)

            print(f'Epoch {epoch+1}/{epochs}:')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

            if val_acc > self.best_accuracy:
                self.best_accuracy = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_accuracy': val_acc,
                }, checkpoint_path)
                print(f'  New best model saved with validation accuracy: {val_acc:.2f}%')

            print('-' * 50)

# ============================================================================
# VISUALIZATION AND EVALUATION TOOLS - FIXED VERSION
# ============================================================================

def visualize_detections(model: nn.Module, dataloader: DataLoader, device: torch.device, num_examples: int = 5,
                         threshold: float = 0.3, save_path: str = 'detection_visualizations.png'):
    """
    Plot RPN proposals on sample images, one row (two panels) per example.

    The left panel shows the first input channel with its label (1 -> "Tumor",
    otherwise "Normal"). The right panel shows the same image with every
    proposal returned by ``model.detect_regions`` drawn in red and annotated
    with its confidence.

    Args:
        model: Object exposing ``detect_regions(data, threshold=...)``.
        dataloader: Yields ``(data, targets)`` batches.
        device: Device the batches are moved to.
        num_examples: Number of rows to draw; non-positive values draw nothing.
        threshold: Tumor-probability threshold passed to ``detect_regions``.
        save_path: Where the figure is saved.
    """
    if num_examples <= 0:
        return

    was_training = model.training
    model.eval()

    # squeeze=False keeps `axes` 2-D even when num_examples == 1.
    fig, axes = plt.subplots(num_examples, 2, figsize=(15, 3 * num_examples), squeeze=False)

    examples_processed = 0

    try:
        with torch.no_grad():
            for data, targets in dataloader:
                if examples_processed >= num_examples:
                    break

                data = data.to(device)
                proposals, confidences = model.detect_regions(data, threshold=threshold)

                for i in range(data.size(0)):
                    if examples_processed >= num_examples:
                        break

                    original_image = data[i, 0].cpu().numpy()
                    # Clip boxes to this image's extent instead of assuming 224x224.
                    max_y = original_image.shape[0] - 1
                    max_x = original_image.shape[1] - 1
                    left_ax, right_ax = axes[examples_processed]

                    left_ax.imshow(original_image, cmap='gray')
                    left_ax.set_title(f'Original (Label: {"Tumor" if targets[i].item() == 1 else "Normal"})')
                    left_ax.axis('off')

                    right_ax.imshow(original_image, cmap='gray')
                    image_proposals = proposals[i]
                    image_confidences = confidences[i]

                    for proposal, confidence in zip(image_proposals, image_confidences):
                        x1, y1, x2, y2 = proposal.cpu().numpy()
                        x1 = max(0, min(x1, max_x))
                        y1 = max(0, min(y1, max_y))
                        x2 = max(0, min(x2, max_x))
                        y2 = max(0, min(y2, max_y))

                        if x2 > x1 and y2 > y1:  # skip degenerate boxes
                            rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                                     linewidth=2, edgecolor='r', facecolor='none')
                            right_ax.add_patch(rect)
                            right_ax.text(x1, max(0, y1 - 5), f'{float(confidence):.2f}',
                                          color='red', fontsize=8, weight='bold')

                    right_ax.set_title(f'Detections ({len(image_proposals)} regions)')
                    right_ax.axis('off')

                    examples_processed += 1
    finally:
        model.train(was_training)

    # Hide rows that were never filled (dataloader ran out of images).
    for row in range(examples_processed, num_examples):
        axes[row, 0].axis('off')
        axes[row, 1].axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def evaluate_detection_performance(model: nn.Module, dataloader: DataLoader, device: torch.device,
                                   detection_threshold: float = 0.3,
                                   thresholds: Optional[List[float]] = None):
    """
    Evaluate image-level classification and report per-image proposal counts.

    Args:
        model: Returns a tuple whose first element is (batch, 2) logits and
            exposes ``detect_regions(data, threshold=...)``.
        dataloader: Yields ``(data, targets)`` batches with 0/1 labels.
        device: Device the inputs are moved to.
        detection_threshold: Threshold passed to ``detect_regions`` when
            counting proposals.
        thresholds: Probability cut-offs at which accuracy/recall/precision
            are computed; defaults to ``[0.1, 0.3, 0.5, 0.7, 0.9]``.

    Returns:
        ``(results, ap)``. ``results`` maps each threshold to a dict with
        'accuracy', 'recall' and 'precision'. ``ap`` is the average precision
        (NaN if the loader produced no samples).
    """
    if thresholds is None:
        thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]

    was_training = model.training
    model.eval()

    all_predictions = []
    all_targets = []
    detection_counts = []  # Number of regions detected per image

    try:
        with torch.no_grad():
            for data, batch_targets in dataloader:
                data = data.to(device)

                final_classification, _, _, _ = model(data)
                proposals, _ = model.detect_regions(data, threshold=detection_threshold)

                # Tumor probability (class 1) for the whole batch at once.
                tumor_probs = F.softmax(final_classification, dim=1)[:, 1]
                all_predictions.extend(tumor_probs.cpu().tolist())
                all_targets.extend(batch_targets.cpu().tolist())
                detection_counts.extend(len(p) for p in proposals)
    finally:
        model.train(was_training)

    predictions = np.array(all_predictions)
    targets = np.array(all_targets)
    detection_counts = np.array(detection_counts)

    results = {}
    for threshold in thresholds:
        pred_labels = (predictions >= threshold).astype(int)
        accuracy = (pred_labels == targets).mean()
        recall = (pred_labels[targets == 1] == 1).mean() if (targets == 1).sum() > 0 else 0
        precision = (targets[pred_labels == 1] == 1).mean() if (pred_labels == 1).sum() > 0 else 0

        results[threshold] = {
            'accuracy': accuracy,
            'recall': recall,
            'precision': precision
        }

    if predictions.size == 0:
        # sklearn cannot build a PR curve from zero samples.
        print("No samples to evaluate; skipping precision-recall curve.")
        ap = float('nan')
    else:
        precision_vals, recall_vals, _ = precision_recall_curve(targets, predictions)
        ap = average_precision_score(targets, predictions)

        plt.figure(figsize=(10, 6))
        plt.plot(recall_vals, precision_vals, linewidth=2)
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title(f'Detection Precision-Recall Curve (AP: {ap:.4f})')
        plt.grid(True)
        plt.savefig('detection_precision_recall.png', dpi=300, bbox_inches='tight')
        plt.show()

    if (targets == 1).any():
        tumor_detections = detection_counts[targets == 1]
        print(f"Average detections in tumor images: {tumor_detections.mean():.2f} ± {tumor_detections.std():.2f}")

    if (targets == 0).any():
        normal_detections = detection_counts[targets == 0]
        print(f"Average detections in normal images: {normal_detections.mean():.2f} ± {normal_detections.std():.2f}")

    return results, ap

# ============================================================================
# MAIN DETECTION PIPELINE
# ============================================================================

def _load_pretrained_clm(detection_model: nn.Module, device: torch.device) -> None:
    """Load `best_clm_model.pth` into the CLM stage and freeze it, if the file exists."""
    try:
        clm_checkpoint = torch.load('best_clm_model.pth', map_location=device)
        detection_model.clm_classifier.load_state_dict(clm_checkpoint['model_state_dict'])
        print("✅ Loaded pre-trained CLM weights")

        # Freeze CLM initially for stable training
        for param in detection_model.clm_classifier.parameters():
            param.requires_grad = False
        print("✅ Frozen CLM weights for initial training")

    except FileNotFoundError as e:
        print(f"⚠️ No pre-trained CLM found: {e}")
        print("Training from scratch...")


def _restore_best_checkpoint(detection_model: nn.Module, device: torch.device,
                             path: str = 'best_detection_model.pth') -> None:
    """Load the best saved detection weights so evaluation doesn't use last-epoch weights."""
    checkpoint_path = Path(path)
    if not checkpoint_path.exists():
        print(f"⚠️ No checkpoint at {path}; evaluating the last-epoch weights")
        return
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        detection_model.load_state_dict(checkpoint['model_state_dict'])
    except (KeyError, RuntimeError) as e:
        # e.g. a stale checkpoint from a different architecture
        print(f"⚠️ Could not restore {path}; evaluating the last-epoch weights: {e}")
        return
    print(f"✅ Restored best checkpoint (val accuracy {checkpoint.get('val_accuracy', float('nan')):.2f}%)")


def _plot_training_history(trainer: "DetectionTrainer") -> None:
    """Plot and save the trainer's loss/accuracy curves to detection_training_history.png."""
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(trainer.train_losses, label='Training Loss')
    plt.plot(trainer.val_losses, label='Validation Loss')
    plt.title('Detection Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(trainer.train_accuracies, label='Training Accuracy')
    plt.plot(trainer.val_accuracies, label='Validation Accuracy')
    plt.title('Detection Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('detection_training_history.png', dpi=300, bbox_inches='tight')
    plt.show()


def main_detection():
    """
    Main function for the detection and classification pipeline.

    The pipeline builds datasets and loaders and optionally warm-starts (and
    freezes) the CLM. It trains the fused model, then fine-tunes with every
    parameter unfrozen. It restores the best-validation checkpoint before
    visualizing and evaluating on the test split, and finally plots the
    training history.
    """
    # Configuration
    DATA_DIR = "/content/Class_project_Early_detection_of_brain_tumor/CT_enhanced"
    BATCH_SIZE = 8
    EPOCHS = 15

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # BrainTumorDataset is defined elsewhere in the notebook.
    train_dataset = BrainTumorDataset(DATA_DIR, split="train")
    val_dataset = BrainTumorDataset(DATA_DIR, split="val")
    test_dataset = BrainTumorDataset(DATA_DIR, split="test")

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    detection_model = BrainTumorDetectionSystem(input_channels=4, num_classes=2)
    _load_pretrained_clm(detection_model, device)

    trainer = DetectionTrainer(detection_model, device)
    trainer.train(train_loader, val_loader, epochs=EPOCHS)

    # Unfreeze CLM for fine-tuning (a no-op if it was never frozen).
    try:
        for param in detection_model.clm_classifier.parameters():
            param.requires_grad = True
        print("✅ Unfrozen CLM weights for fine-tuning")

        # Reinitialize optimizer with all parameters
        trainer.optimizer = optim.AdamW(detection_model.parameters(), lr=5e-5, weight_decay=1e-4)
        print("Fine-tuning with unfrozen CLM...")
        trainer.train(train_loader, val_loader, epochs=5)
    except Exception as e:
        print(f"Fine-tuning skipped: {e}")

    # Use the best-validation weights, not whatever the last epoch left behind.
    _restore_best_checkpoint(detection_model, device)

    print("Visualizing detections...")
    visualize_detections(detection_model, val_loader, device, num_examples=5)

    print("Evaluating detection performance...")
    results, ap = evaluate_detection_performance(detection_model, test_loader, device)

    print("\n📊 Detection Results:")
    for threshold, metrics in results.items():
        print(f"Threshold {threshold}: "
              f"Accuracy={metrics['accuracy']:.4f}, "
              f"Recall={metrics['recall']:.4f}, "
              f"Precision={metrics['precision']:.4f}")
    print(f"Average Precision: {ap:.4f}")

    _plot_training_history(trainer)

    print("\n✅ Two-Stage Detection and Classification completed successfully!")
    print("📊 Model saved as: best_detection_model.pth")
    print("🖼️ Detection visualizations saved as: detection_visualizations.png")
    print("📈 Evaluation results saved as: detection_precision_recall.png")

# ============================================================================
# INTEGRATION WITH EXISTING PIPELINE
# ============================================================================

def run_complete_pipeline():
    """
    Run the complete pipeline: check that a trained CLM can be loaded, then
    run detection and classification.

    The CLM loaded here is only a sanity check and is not passed on;
    ``main_detection`` loads the checkpoint again itself. Any loading failure
    is reported and the pipeline continues.
    """
    print("🚀 Starting Complete Brain Tumor Detection Pipeline...")

    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        clm_model = CorrelationLearningMechanism(input_channels=4, num_classes=2)
        clm_checkpoint = torch.load('best_clm_model.pth', map_location=device)
        clm_model.load_state_dict(clm_checkpoint['model_state_dict'])
        print("✅ Loaded trained CLM model")
    except Exception as e:
        # Top-level boundary: report and fall through to training from scratch.
        print(f"⚠️ CLM model not found or error loading: {e}")
        print("Please run feature extraction first or continue with detection from scratch")

    main_detection()

if __name__ == "__main__":
    # Script entry point: CLM availability check, then detection
    # training, visualization and evaluation.
    run_complete_pipeline()

🚀 Starting Complete Brain Tumor Detection Pipeline...
⚠️ CLM model not found or error loading: name 'CorrelationLearningMechanism' is not defined
Please run feature extraction first or continue with detection from scratch
Using device: cpu


NameError: name 'BrainTumorDataset' is not defined