In [1]:
import os
import numpy as np
import pydicom
from pathlib import Path
import torch
from tqdm.notebook import tqdm
import concurrent.futures
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
import logging

# Notebook-wide logging: INFO and above, each record prefixed with a timestamp and its level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by the classes and functions defined in later cells.
logger = logging.getLogger(__name__)

In [None]:
@dataclass
class ScanInfo:
    """Record describing a single DICOM file found during dataset discovery.

    Attributes:
        patient_id: Identifier of the patient the file belongs to.
        exam_date: Label of the exam date the file was stored under.
        file_path: Location of the DICOM file on disk.
        group: Cohort label, one of 'PD', 'SWEDD', or 'Control'.
    """
    patient_id: str
    exam_date: str
    file_path: Path
    group: str
    
class DATSCANDataset:
    """Index of DaTSCAN DICOM files on disk, grouped by cohort.

    ``discover_files`` walks the following layout under ``base_path``::

        Images/<group folder>/<patient>/Reconstructed_DaTSCAN/<exam date>/<exam id>/*.dcm

    Attributes:
        base_path: Root directory that contains the ``Images`` folder.
        device: Torch device string. It is only stored here; this class does
            not use it itself.
        scans: Mapping of group name to the list of ScanInfo records
            discovered for that group.
    """

    # Cohort label -> folder name under ``<base_path>/Images``.
    GROUP_FOLDERS = {
        'PD': 'PPMI_Images_PD',
        'SWEDD': 'PPMI_Images_SWEDD',
        'Control': 'PPMI_Images_Cont',
    }

    def __init__(self, base_path: str, device: str = 'cuda'):
        """Store the dataset root and device; no file system access happens here.

        Args:
            base_path: Root directory that contains the ``Images`` folder.
            device: Torch device string kept on the instance.
        """
        self.base_path = Path(base_path)
        self.device = device
        self.scans: Dict[str, List[ScanInfo]] = {
            group: [] for group in self.GROUP_FOLDERS
        }

    @staticmethod
    def _subdirectories(folder: Path) -> List[Path]:
        """Return the immediate sub-directories of ``folder`` in sorted order."""
        # Stray files (e.g. .DS_Store) are skipped so callers can safely descend
        # into every returned entry; sorting makes the scan order reproducible.
        return sorted(entry for entry in folder.iterdir() if entry.is_dir())

    def discover_files(self) -> Dict:
        """Discover all DICOM files in the directory structure.

        The index is rebuilt from scratch on every call, so calling this method
        repeatedly does not accumulate duplicate entries. Directories and files
        are visited in sorted order so the resulting scan lists are
        deterministic across runs and platforms.

        Returns:
            The summary dictionary produced by ``_summarize_dataset``.
        """
        # Start from an empty index so a re-run of the notebook cell is idempotent.
        self.scans = {group: [] for group in self.GROUP_FOLDERS}

        for group, folder in self.GROUP_FOLDERS.items():
            group_path = self.base_path / 'Images' / folder
            if not group_path.exists():
                logger.warning(f"Group path does not exist: {group_path}")
                continue

            logger.info(f"Discovering files for group: {group}")
            patient_folders = self._subdirectories(group_path)
            for patient_folder in tqdm(patient_folders, desc=f"Processing {group} patients"):
                datscan_path = patient_folder / 'Reconstructed_DaTSCAN'
                if not datscan_path.is_dir():
                    continue

                for exam_date_folder in self._subdirectories(datscan_path):
                    for exam_id_folder in self._subdirectories(exam_date_folder):
                        for dicom_file in sorted(exam_id_folder.glob('*.dcm')):
                            self.scans[group].append(
                                ScanInfo(
                                    patient_id=patient_folder.name,
                                    exam_date=exam_date_folder.name,
                                    file_path=dicom_file,
                                    group=group,
                                )
                            )

        return self._summarize_dataset()

    def _summarize_dataset(self) -> Dict:
        """Summarize the discovered dataset.

        Returns:
            A dictionary with three keys: ``total_scans`` (int),
            ``scans_per_group`` (group -> number of scans) and
            ``unique_patients`` (group -> number of distinct patient ids).
        """
        scans_per_group = {group: len(scans) for group, scans in self.scans.items()}
        unique_patients = {
            group: len({scan.patient_id for scan in scans})
            for group, scans in self.scans.items()
        }
        return {
            'total_scans': sum(scans_per_group.values()),
            'scans_per_group': scans_per_group,
            'unique_patients': unique_patients,
        }

In [None]:
# Cell 3: Define image loading and preprocessing functions
def load_dicom(file_path: Path) -> Optional[np.ndarray]:
    """Read one DICOM file and return its pixel data.

    Args:
        file_path: Path of the DICOM file to read.

    Returns:
        The ``pixel_array`` of the parsed file, or ``None`` when reading or
        decoding raised an exception (the error is logged, not re-raised).
    """
    pixels = None
    try:
        pixels = pydicom.dcmread(str(file_path)).pixel_array
    except Exception as err:
        # Best-effort loading: report the failure and let the caller skip this file.
        logger.error(f"Error loading DICOM file {file_path}: {str(err)}")
    return pixels

def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Min-max normalize an image to [0, 1] and convert it to a float32 tensor.

    Args:
        image: Pixel data as a numpy array of any numeric dtype.

    Returns:
        A ``torch.float32`` tensor. Two-dimensional inputs gain leading batch
        and channel dimensions, giving shape ``(1, 1, H, W)``; inputs of any
        other rank keep their shape. When the image has no positive value
        range (for example a constant image) an all-zero tensor is returned
        instead of dividing by zero.

    Raises:
        ValueError: If ``image`` is empty (raised by numpy's ``min``).
    """
    # Do the arithmetic in float64: integer pixel dtypes can overflow when
    # computing ``max - min`` or ``image - min``.
    pixels = np.asarray(image, dtype=np.float64)
    lowest = pixels.min()
    value_range = pixels.max() - lowest

    if value_range > 0:
        pixels = (pixels - lowest) / value_range
    else:
        # Nothing to scale; avoid producing NaN from 0 / 0.
        pixels = np.zeros_like(pixels)

    tensor = torch.from_numpy(pixels).float()
    if tensor.ndim == 2:
        tensor = tensor.unsqueeze(0).unsqueeze(0)  # Add batch and channel dims
    return tensor

class DATSCANLoader:
    """Load the images indexed by a DATSCANDataset, reading files in batches.

    Attributes:
        dataset: The dataset whose ``scans`` mapping and ``device`` are used.
        batch_size: Number of scans handed to the thread pool at a time.
    """
    def __init__(self, dataset: "DATSCANDataset", batch_size: int = 32):
        """Create a loader for ``dataset``.

        Args:
            dataset: Dataset providing ``scans`` (group -> scan list) and ``device``.
            batch_size: Number of files read concurrently per batch; must be >= 1.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size

    def load_group(self, group: str, max_samples: Optional[int] = None) -> torch.Tensor:
        """Load all images for a specific group.

        Images appear in the same order as ``dataset.scans[group]``; scans whose
        file could not be loaded are skipped and reported with a warning.

        Args:
            group: Key into ``dataset.scans``.
            max_samples: If given, only the first ``max_samples`` scans are
                loaded (``0`` loads nothing). ``None`` loads every scan.

        Returns:
            The preprocessed images stacked along a new leading dimension and
            moved to ``dataset.device``. An empty tensor is returned when no
            image could be loaded.

        Raises:
            KeyError: If ``group`` is not present in ``dataset.scans``.
            RuntimeError: If the loaded images do not all share one shape
                (raised by ``torch.stack``).
        """
        scans = self.dataset.scans[group]
        # Compare against None so that max_samples=0 is honored rather than
        # being treated as "no limit".
        if max_samples is not None:
            scans = scans[:max_samples]

        all_images = []
        for start in tqdm(range(0, len(scans), self.batch_size), desc=f"Loading {group} images"):
            all_images.extend(self._load_batch(scans[start:start + self.batch_size]))

        skipped = len(scans) - len(all_images)
        if skipped:
            logger.warning(f"Skipped {skipped} of {len(scans)} {group} scans that failed to load")

        if not all_images:
            # torch.stack rejects an empty list; return an empty tensor instead.
            logger.warning(f"No images loaded for group: {group}")
            return torch.empty(0, device=self.dataset.device)

        return torch.stack(all_images).to(self.dataset.device)

    def _load_batch(self, batch_scans: List["ScanInfo"]) -> List[torch.Tensor]:
        """Load a batch of images in parallel, preserving the order of ``batch_scans``.

        Args:
            batch_scans: Scans whose ``file_path`` should be read.

        Returns:
            Preprocessed tensors for every scan that loaded successfully, in
            the same relative order as ``batch_scans``.
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Executor.map yields results in submission order, unlike
            # as_completed, so the output stays aligned with the scan list.
            loaded = list(executor.map(load_dicom, (scan.file_path for scan in batch_scans)))
        return [preprocess_image(image) for image in loaded if image is not None]

In [None]:
# Cell 4: Initialize the dataset and discover files
# Use the GPU when one is available; otherwise fall back to the CPU so that
# moving tensors to the dataset's device does not fail on CPU-only machines.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = DATSCANDataset(base_path='.', device=DEVICE)
summary = dataset.discover_files()

# Report how many scans and distinct patients were found per group.
print("Dataset Summary:")
print(f"Total number of scans: {summary['total_scans']}")
print("\nScans per group:")
for group, count in summary['scans_per_group'].items():
    print(f"{group}: {count} scans")
print("\nUnique patients per group:")
for group, count in summary['unique_patients'].items():
    print(f"{group}: {count} patients")

In [None]:
# Cell 5: Create the data loader
# batch_size is the number of scans the loader reads per batch; adjust as needed.
loader = DATSCANLoader(dataset=dataset, batch_size=32)