In [20]:
!pip install ipywidgets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting ipywidgets
  Downloading ipywidgets-8.1.5-py3-none-any.whl.metadata (2.3 kB)
Collecting widgetsnbextension~=4.0.12 (from ipywidgets)
  Downloading widgetsnbextension-4.0.13-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab-widgets~=3.0.12 (from ipywidgets)
  Downloading jupyterlab_widgets-3.0.13-py3-none-any.whl.metadata (4.1 kB)
Downloading ipywidgets-8.1.5-py3-none-any.whl (139 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.8/139.8 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jupyterlab_widgets-3.0.13-py3-none-any.whl (214 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m214.4/214.4 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading widgetsnbextension-4.0.13-py3-none-any.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
[?25h

In [18]:
# Cell 1: Import Libraries
import os
import gc
import warnings
from pathlib import Path
from typing import Union, List, Tuple, Optional, Dict
import time

# Data processing
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import pydicom
from skimage import transform, filters, morphology
from scipy import ndimage

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Suppress warnings
warnings.filterwarnings('ignore')

In [2]:
# Cell 2: GPU Setup and Environment Configuration
def setup_environment(seed: int = 42) -> Dict[str, Union[torch.device, bool]]:
    """
    Configure the GPU environment and set random seeds for reproducibility.
    
    Args:
        seed (int): Random seed for reproducibility
        
    Returns:
        dict: Environment configuration including device and GPU availability
    """
    # Set random seeds
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    # Configure environment
    env_config = {
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'gpu_available': torch.cuda.is_available(),
        'amp_enabled': True,  # Enable Automatic Mixed Precision
    }
    
    if env_config['gpu_available']:
        # Configure CUDA for optimal performance
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        
        # Print GPU information
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory Available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    else:
        print("WARNING: GPU not available. Using CPU.")
    
    return env_config

In [3]:
# Cell 3: Memory Management Utilities
class MemoryTracker:
    """Utility class to track and manage GPU memory usage."""
    
    @staticmethod
    def get_gpu_memory_usage() -> Tuple[float, float]:
        """
        Get current GPU memory usage.
        
        Returns:
            tuple: (allocated memory in GB, cached memory in GB)
        """
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            cached = torch.cuda.memory_reserved() / 1e9
            return allocated, cached
        return 0.0, 0.0
    
    @staticmethod
    def clear_gpu_memory():
        """Clear unused GPU memory and run garbage collection."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            
    @staticmethod
    def log_memory_usage(prefix: str = ""):
        """
        Log current GPU memory usage.
        
        Args:
            prefix (str): Optional prefix for the log message
        """
        if torch.cuda.is_available():
            allocated, cached = MemoryTracker.get_gpu_memory_usage()
            print(f"{prefix}GPU Memory: {allocated:.2f} GB allocated, {cached:.2f} GB cached")

# Initialize environment
env_config = setup_environment()
memory_tracker = MemoryTracker()

# Display initial memory usage
memory_tracker.log_memory_usage("Initial ")

GPU: NVIDIA GeForce RTX 4070 Ti
Memory Available: 12.88 GB
Initial GPU Memory: 0.00 GB allocated, 0.00 GB cached


In [13]:
# Cell 4: File Collection and Validation
class DataCollector:
    """Handles collecting and validating DICOM files from the PPMI dataset."""
    
    def __init__(self, base_path: Union[str, Path]):
        self.base_path = Path(base_path)
        self.excluded_files: List[Path] = []
        self.valid_files: List[Dict] = []
        
    def _is_valid_file(self, file_path: Path) -> bool:
        """
        Check if file is valid (not containing 'br_raw').
        
        Args:
            file_path (Path): Path to check
            
        Returns:
            bool: True if file is valid, False otherwise
        """
        return (file_path.suffix.lower() == '.dcm' and 
                'br_raw' not in str(file_path).lower())
    
    def _get_group_from_path(self, file_path: Path) -> str:
        """Extract group (PD, SWEDD, Control) from file path."""
        path_str = str(file_path)
        if 'PPMI_Images_PD' in path_str:
            return 'PD'
        elif 'PPMI_Images_SWEDD' in path_str:
            return 'SWEDD'
        elif 'PPMI_Images_Cont' in path_str:
            return 'Control'
        else:
            return 'Unknown'
    
    def collect_files(self) -> pd.DataFrame:
        """
        Recursively collect all valid DICOM files.
        
        Returns:
            pd.DataFrame: DataFrame containing file information
        """
        print(f"Starting file collection from: {self.base_path}")
        
        # First, check if the path exists
        if not self.base_path.exists():
            raise FileNotFoundError(f"Path does not exist: {self.base_path}")
        
        print("Building file list... (updating every 1000 files found)")
        
        # Use a more efficient way to collect files with progress updates
        all_files = []
        files_found = 0
        start_time = time.time()
        
        for root, _, files in os.walk(str(self.base_path)):
            dcm_files = [Path(root) / f for f in files if f.endswith('.dcm')]
            all_files.extend(dcm_files)
            
            files_found += len(dcm_files)
            if files_found % 1000 == 0:
                elapsed = time.time() - start_time
                print(f"\rFound {files_found} DICOM files... ({elapsed:.1f} seconds elapsed)", 
                      end="", flush=True)
        
        total_files = len(all_files)
        print(f"\nCompleted file discovery: {total_files} DICOM files found in {time.time() - start_time:.1f} seconds")
        
        if total_files == 0:
            raise ValueError(f"No DICOM files found in {self.base_path}")
            
        print(f"\nFound {total_files} DICOM files. Processing...")
        
        # Process files with progress bar
        for file_path in tqdm(all_files, 
                            desc="Processing files",
                            ncols=80,
                            unit='files'):
            if self._is_valid_file(file_path):
                self.valid_files.append({
                    'path': str(file_path),
                    'group': self._get_group_from_path(file_path),
                    'patient_id': file_path.parents[3].name
                })
            else:
                self.excluded_files.append(file_path)
            
            # Show interim progress every 1000 files
            if len(self.valid_files) % 1000 == 0:
                print(f"\nInterim progress:")
                print(f"Processed {len(self.valid_files) + len(self.excluded_files)}/{total_files} files")
                print(f"Valid: {len(self.valid_files)}, Excluded: {len(self.excluded_files)}")
        
        # Create DataFrame
        df = pd.DataFrame(self.valid_files)
        
        # Print summary
        print(f"\nCollection Summary:")
        print(f"Valid files: {len(self.valid_files)}")
        print(f"Excluded files: {len(self.excluded_files)}")
        print("\nGroup distribution:")
        print(df['group'].value_counts())
        
        return df

In [14]:
# Cell 5: DICOM Processing and Preprocessing
class DICOMProcessor:
    """Handles DICOM loading and preprocessing."""
    
    def __init__(self, target_shape: Tuple[int, int, int] = (128, 128, 128)):
        self.target_shape = target_shape
        self.memory_tracker = MemoryTracker()
    
    def load_dicom(self, file_path: Union[str, Path], to_gpu: bool = False) -> torch.Tensor:
        """
        Load and preprocess DICOM file.
        
        Args:
            file_path: Path to DICOM file
            to_gpu: Whether to move tensor to GPU immediately after preprocessing
            
        Returns:
            torch.Tensor: Preprocessed image tensor
        """
        # Load DICOM
        dcm = pydicom.dcmread(str(file_path))
        
        # Extract pixel array and convert to float32
        img = dcm.pixel_array.astype(np.float32)
        
        # Apply rescaling
        if hasattr(dcm, 'RescaleSlope') and hasattr(dcm, 'RescaleIntercept'):
            img = img * float(dcm.RescaleSlope) + float(dcm.RescaleIntercept)
        
        # Preprocess
        img = self._preprocess_volume(img)
        
        # Convert to tensor
        tensor = torch.from_numpy(img).unsqueeze(0)  # Add channel dimension
        
        # Move to GPU if requested
        if to_gpu and torch.cuda.is_available():
            tensor = tensor.cuda()
            self.memory_tracker.log_memory_usage("After GPU transfer: ")
        
        return tensor
    
    def _preprocess_volume(self, volume: np.ndarray) -> np.ndarray:
        """
        Apply preprocessing steps to volume.
        
        Args:
            volume: Input volume array
            
        Returns:
            np.ndarray: Preprocessed volume
        """
        # Intensity normalization
        volume = np.clip(volume, 0, None)  # Remove negative values
        if volume.max() > 0:
            volume = (volume - volume.min()) / (volume.max() - volume.min())
        
        # Brain masking using Otsu thresholding
        threshold = filters.threshold_otsu(volume)
        mask = volume > threshold
        mask = morphology.binary_closing(mask)
        volume = volume * mask
        
        # Resize to target shape
        if volume.shape != self.target_shape:
            # Calculate padding/cropping
            pad_width = [(max((t - s) // 2, 0), max((t - s + 1) // 2, 0))
                        for t, s in zip(self.target_shape, volume.shape)]
            crop_width = [(max((s - t) // 2, 0), max((s - t + 1) // 2, 0))
                         for t, s in zip(self.target_shape, volume.shape)]
            
            # Apply padding if needed
            if any(p[0] > 0 or p[1] > 0 for p in pad_width):
                volume = np.pad(volume, pad_width, mode='constant')
            
            # Apply cropping if needed
            if any(c[0] > 0 or c[1] > 0 for c in crop_width):
                slices = tuple(slice(c[0], s - c[1]) for c, s in zip(crop_width, volume.shape))
                volume = volume[slices]
        
        return volume

In [15]:
# Cell 6: PyTorch Dataset Implementation
class DATSCANDataset(Dataset):
    """PyTorch Dataset for DATSCAN images."""
    
    def __init__(self, data_df: pd.DataFrame, processor: DICOMProcessor, 
                 to_gpu: bool = False):
        """
        Initialize dataset.
        
        Args:
            data_df: DataFrame containing file paths and labels
            processor: DICOM processor instance
            to_gpu: Whether to move tensors to GPU immediately
        """
        self.data_df = data_df
        self.processor = processor
        self.to_gpu = to_gpu
    
    def __len__(self) -> int:
        return len(self.data_df)
    
    def __getitem__(self, idx: int) -> torch.Tensor:
        """Get preprocessed image tensor."""
        file_path = self.data_df.iloc[idx]['path']
        return self.processor.load_dicom(file_path, self.to_gpu)

In [19]:
# Example usage:
if __name__ == "__main__":
    # Initialize collector and processor
    collector = DataCollector(base_path="Images")
    processor = DICOMProcessor()
    
    # Collect files
    data_df = collector.collect_files()
    
    # Create dataset
    dataset = DATSCANDataset(data_df, processor, to_gpu=True)
    
    # Create dataloader with optimized settings
    dataloader = DataLoader(
        dataset,
        batch_size=2,  # Start with small batch size, adjust based on GPU memory
        num_workers=6,  # Optimized for 8-core CPU
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2
    )

Starting file collection from: Images
Building file list... (updating every 1000 files found)
Found 7000 DICOM files... (128.7 seconds elapsed)
Completed file discovery: 7754 DICOM files found in 211.3 seconds

Found 7754 DICOM files. Processing...


ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html