In [35]:
import os
import logging
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from monai.networks.nets import resnet18  # Import resnet18 from MONAI
from fastai.learner import Learner
from fastai.data.core import DataLoaders
from fastai.metrics import accuracy
from fastai.losses import CrossEntropyLossFlat
from fastai.callback.all import SaveModelCallback, EarlyStoppingCallback
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [36]:
# Configure logging
logging.basicConfig(
    level=logging.INFO, 
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("heart_disease_model.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


In [41]:
class NPYDataset(Dataset):
    """
    Custom PyTorch Dataset for loading 3D medical imaging data from .npy files.
    """
    def __init__(self, dataframe, image_column_name, label_column_name, custom_transform=None):
        """
        Initialize the dataset with optional custom transforms.
        
        Args:
            dataframe (pd.DataFrame): Dataframe containing file paths and labels
            image_column_name (str): Column name for image file paths
            label_column_name (str): Column name for labels
            custom_transform (callable, optional): Optional custom transform pipeline
        """
        self.dataframe = dataframe
        self.image_column_name = image_column_name
        self.label_column_name = label_column_name
        
        # Default transforms 
        default_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.Resize((112, 112)),  # Med3D default size
        ])
        
        # Use custom transform if provided, otherwise use default
        self.transform = custom_transform if custom_transform is not None else default_transform

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Load and preprocess a single sample.
        
        Returns:
            tuple: (processed image, label)
        """
        try:
            npy_path = self.dataframe[self.image_column_name].iloc[idx]
            label = self.dataframe[self.label_column_name].iloc[idx]
            
            # Load and preprocess image
            image = np.load(npy_path)[:, :, :, 0]  # 2nd axis view
            image = image[17:33, :, :]  # Select frames 17 to 32
            
            # Convert to tensor and add channel dimension
            image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)
            
            # Apply transforms
            image = self.transform(image)
            
            return image, label
        
        except Exception as e:
            logger.error(f"Error loading image at index {idx}: {e}")
            raise


class HeartDiseaseModel:
    """
    Comprehensive model for heart disease classification using 3D medical imaging.
    """
    def __init__(self, config):
        """
        Initialize the model with comprehensive configuration.
        
        Args:
            config (dict): Configuration dictionary with model parameters
        """
        # Validate and set configuration
        self.config = self._validate_config(config)
        
        # Set up logging
        self.logger = logging.getLogger(self.__class__.__name__)
        
        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"Using device: {self.device}")
        
        # Prepare data and model
        self._prepare_data()
        self._prepare_model()

    def _validate_config(self, config):
        """
        Validate and set default values for configuration.
        
        Args:
            config (dict): Input configuration
        
        Returns:
            dict: Validated configuration with defaults
        """
        default_config = {
            'train_dataframe_path': None,
            'test_dataframe_path': None,
            'image_column_name': 'FilePath',
            'label_column_name': 'CAD',
            'pretrained_weights_path': None,
            'batch_size': 8,
            'split_ratio': 0.85,
            'model_name': 'heart_disease_model_resnet18',
            'learning_rate': 1e-5,
            'epochs': 50,
            'early_stopping_patience': 20,
            'weight_decay': 1e-4
        }
        
        # Update defaults with provided config
        default_config.update(config)
        
        # Validate required paths
        required_paths = [
            'train_dataframe_path', 
            'test_dataframe_path', 
            'pretrained_weights_path'
        ]
        for path in required_paths:
            if not default_config[path] or not os.path.exists(default_config[path]):
                raise ValueError(f"Invalid path for {path}: {default_config[path]}")
        
        return default_config

    def _prepare_data(self):
        """Prepare training, validation, and test datasets and dataloaders."""
        try:
            # Load dataframes
            train_df = pd.read_csv(self.config['train_dataframe_path'])
            test_df = pd.read_csv(self.config['test_dataframe_path'])
            
            # Create dataset
            dataset = NPYDataset(
                train_df, 
                self.config['image_column_name'], 
                self.config['label_column_name']
            )
            
            # Split train and validation
            train_size = int(self.config['split_ratio'] * len(dataset))
            val_size = len(dataset) - train_size
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )
            
            # Create dataloaders
            self.train_loader = DataLoader(
                self.train_dataset, 
                batch_size=self.config['batch_size'], 
                shuffle=True, 
                num_workers=8
            )
            self.val_loader = DataLoader(
                self.val_dataset, 
                batch_size=self.config['batch_size'], 
                shuffle=False, 
                num_workers=8
            )
            
            # Prepare test data
            self.test_dataset = NPYDataset(
                test_df, 
                self.config['image_column_name'], 
                self.config['label_column_name']
            )
            self.test_loader = DataLoader(
                self.test_dataset, 
                batch_size=self.config['batch_size'], 
                shuffle=False, 
                num_workers=8
            )
            
            self.logger.info("Data preparation completed successfully")
        
        except Exception as e:
            self.logger.error(f"Error in data preparation: {e}")
            raise

    def _prepare_model(self):
        """Prepare the model, loss function, and learning rate scheduler."""
        try:
            # Initialize model with ResNet18
            self.model = resnet18(
                spatial_dims=3, 
                n_input_channels=1,  # Assuming grayscale input
                num_classes=2  # Binary classification
            )
            
            # Load pretrained weights
            self.logger.info("Loading pretrained weights...")
            state_dict = torch.load(self.config['pretrained_weights_path'])
            self.model.load_state_dict(state_dict, strict=False)
            
            # Parallel processing and device transfer
            self.model = nn.DataParallel(self.model)
            self.model.to(self.device)
            
            # Create FastAI learner
            self.dls = DataLoaders(self.train_loader, self.val_loader)
            self.learn = Learner(
                self.dls,
                self.model,
                loss_func=CrossEntropyLossFlat(),
                metrics=[accuracy],
                wd=self.config['weight_decay'],
                cbs=[
                    SaveModelCallback(
                        fname=self.config['model_name'], 
                        monitor='valid_loss'
                    ),
                    EarlyStoppingCallback(
                        monitor='valid_loss', 
                        patience=self.config['early_stopping_patience']
                    )
                ]
            ).to_fp16()
            
            self.logger.info("Model preparation completed successfully")
        
        except Exception as e:
            self.logger.error(f"Error in model preparation: {e}")
            raise

    def train(self):
        """Train the model with configured hyperparameters."""
        try:
            self.logger.info("Starting model training...")
            self.learn.fine_tune(
                self.config['epochs'], 
                base_lr=self.config['learning_rate']
            )
            self.logger.info("Model training completed successfully")
        
        except Exception as e:
            self.logger.error(f"Error during model training: {e}")
            raise

    def evaluate(self):
        """Evaluate model performance on validation and test datasets."""
        try:
            def evaluate_dataset(loader, dataset_name):
                self.learn.validate()
                preds, targs = self._get_predictions(loader)
                
                # Classification report
                report = classification_report(
                    targs, preds, 
                    target_names=['No CAD', 'CAD']
                )
                self.logger.info(f"Classification Report for {dataset_name} Data:\n{report}")
                
                # Confusion matrix
                cm = confusion_matrix(targs, preds)
                plt.figure(figsize=(8, 6))
                sns.heatmap(
                    cm, 
                    annot=True, 
                    fmt='d', 
                    cmap='Blues', 
                    xticklabels=['No CAD', 'CAD'], 
                    yticklabels=['No CAD', 'CAD']
                )
                plt.xlabel('Predicted')
                plt.ylabel('True')
                plt.title(f'Confusion Matrix for {dataset_name} Data')
                plt.savefig(f'{dataset_name.lower()}_confusion_matrix.png')
                plt.close()
            
            # Evaluate validation and test datasets
            evaluate_dataset(self.val_loader, "Validation")
            evaluate_dataset(self.test_loader, "Test")
        
        except Exception as e:
            self.logger.error(f"Error during model evaluation: {e}")
            raise

    def _get_predictions(self, data_loader):
        """
        Generate predictions for a given dataloader.
        """
        preds, targs = [], []
        self.model.eval()
        with torch.no_grad():
            for batch in data_loader:
                images, labels = batch
                images = images.to(self.device)
                labels = labels.to(self.device)
                
                outputs = self.model(images)
                _, predicted = torch.max(outputs.data, 1)
                
                preds.extend(predicted.cpu().numpy())
                targs.extend(labels.cpu().numpy())
        
        return preds, targs

In [None]:
def main():
    """Main execution function for the heart disease classification model."""
    try:
        config = {
            'train_dataframe_path': 'Final_Datasets/train_resnet_heart.csv',
            'test_dataframe_path': 'Final_Datasets/test_data_incidence.csv',
            'pretrained_weights_path': '../Med3D/resnet_18_23dataset.pth',  # Path to ResNet18 weights
            'model_name': 'heart_ch0_3channel_MONAI_resnet18',
            'epochs': 50,
            'learning_rate': 1e-5
        }
        
        model = HeartDiseaseModel(config)
        model.train()
        model.evaluate()
    
    except Exception as e:
        logger.error(f"Critical error in main execution: {e}")
        raise

if __name__ == "__main__":
    main()


2024-12-01 16:07:42,257 - HeartDiseaseModel - INFO - Using device: cuda
2024-12-01 16:07:43,129 - HeartDiseaseModel - INFO - Data preparation completed successfully
2024-12-01 16:07:43,919 - HeartDiseaseModel - INFO - Loading pretrained weights...
  state_dict = torch.load(self.config['pretrained_weights_path'])
2024-12-01 16:07:44,476 - HeartDiseaseModel - INFO - Model preparation completed successfully
2024-12-01 16:07:44,478 - HeartDiseaseModel - INFO - Starting model training...
  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()
  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()


epoch,train_loss,valid_loss,accuracy,time
0,0.679124,0.654153,0.61305,00:57


Better model found at epoch 0 with valid_loss value: 0.6541532874107361.


  state = torch.load(file, map_location=device, **torch_load_kwargs)


epoch,train_loss,valid_loss,accuracy,time
0,0.6711,0.656159,0.606222,00:55
1,0.659614,0.642117,0.638088,00:55
2,0.668766,0.644181,0.632777,00:55
3,0.66771,0.645368,0.630501,00:55
4,0.660385,0.640198,0.638847,00:58
5,0.67008,0.649243,0.630501,00:55
6,0.67909,0.640185,0.638088,00:56
7,0.650495,0.6457,0.637329,00:56
8,0.648227,0.643213,0.638847,00:55
9,0.639713,0.655661,0.628225,00:56


Better model found at epoch 0 with valid_loss value: 0.6561588644981384.
Better model found at epoch 1 with valid_loss value: 0.6421169638633728.
Better model found at epoch 4 with valid_loss value: 0.6401976346969604.
Better model found at epoch 6 with valid_loss value: 0.6401851773262024.
Better model found at epoch 17 with valid_loss value: 0.6357626914978027.
Better model found at epoch 18 with valid_loss value: 0.6327994465827942.
Better model found at epoch 21 with valid_loss value: 0.6301022171974182.
Better model found at epoch 24 with valid_loss value: 0.6226145029067993.
Better model found at epoch 32 with valid_loss value: 0.6208434104919434.
Better model found at epoch 33 with valid_loss value: 0.6168020963668823.
Better model found at epoch 34 with valid_loss value: 0.6100522875785828.
Better model found at epoch 36 with valid_loss value: 0.6084619164466858.


2024-12-01 16:55:17,925 - HeartDiseaseModel - INFO - Model training completed successfully
