# Lung Cancer Detection with 3D CNN (LUNA16 Dataset - Subset 0)

This notebook demonstrates the steps to build and train a 3D CNN for lung nodule detection using **Subset 0** of the LUNA16 dataset.

**Note:** Ensure you have downloaded `subset0` and `candidates.csv` into the `data/` directory as described in the `README.md`.

## 1. Import Libraries

In [None]:
import sys
import os
import pandas as pd
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
import torch

# Add src to path so we can import our modules
sys.path.append(os.path.abspath(os.path.join('..', 'src')))

from dataset import LunaDataset
from model import Simple3DCNN
from utils import plot_3d_scan, plot_nodule

%matplotlib inline

## 2. Load and Visualize CT Scan Data

We will load the `candidates.csv` file to find the location of potential nodules and then load the corresponding CT scan.

In [None]:
# Load candidates
candidates_file = '../data/candidates.csv'
if os.path.exists(candidates_file):
    candidates = pd.read_csv(candidates_file)
    print(f"Total candidates: {len(candidates)}")
    print(candidates.head())
else:
    print("candidates.csv not found. Please download the dataset.")

# Example: Load one scan (assuming subset0 is present)
# You might need to adjust the path to where you extracted subset0
subset0_path = '../data/subset0'
if os.path.exists(subset0_path):
    # Pick one scan from the directory for visualization. os.listdir() returns
    # files in arbitrary order, so sort for a reproducible choice. To visualize a
    # specific candidate instead, match its seriesuid to the .mhd filename.
    mhd_files = sorted(f for f in os.listdir(subset0_path) if f.endswith('.mhd'))
    if mhd_files:
        example_file = os.path.join(subset0_path, mhd_files[0])
        print(f"Loading {example_file}...")
        
        itk_image = sitk.ReadImage(example_file)
        scan_array = sitk.GetArrayFromImage(itk_image) # (Z, Y, X)
        
        print(f"Scan shape: {scan_array.shape}")
        print(f"Origin: {itk_image.GetOrigin()}")
        print(f"Spacing: {itk_image.GetSpacing()}")
        
        # Visualize middle slice
        plot_3d_scan(scan_array)
    else:
        print("No .mhd files found in subset0.")
else:
    print("subset0 directory not found.")
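
The candidate positions in `candidates.csv` are given in world coordinates (millimetres, X/Y/Z order), while `scan_array` is indexed by voxel in (Z, Y, X) order. A minimal sketch of the conversion, using the origin and spacing printed above, might look like the following. The helper name `world_to_voxel` is hypothetical, and the sketch assumes the scan's direction matrix is the identity; SimpleITK's `TransformPhysicalPointToIndex` handles the general case.

```python
import numpy as np

def world_to_voxel(world_xyz, origin_xyz, spacing_xyz):
    """Convert a world-space point (mm, X/Y/Z) to array indices (Z, Y, X).

    Hypothetical helper; assumes an identity direction matrix.
    """
    world = np.asarray(world_xyz, dtype=float)
    origin = np.asarray(origin_xyz, dtype=float)
    spacing = np.asarray(spacing_xyz, dtype=float)
    # Offset from the origin in mm, divided by mm-per-voxel, rounded to the nearest voxel
    voxel_xyz = np.rint((world - origin) / spacing).astype(int)
    # SimpleITK arrays are (Z, Y, X), so reverse the axis order
    return tuple(int(v) for v in voxel_xyz[::-1])

# Example with made-up values: origin (0, 0, 0) mm, spacing (1.0, 2.0, 2.5) mm
print(world_to_voxel((10.0, 20.0, 30.0), (0.0, 0.0, 0.0), (1.0, 2.0, 2.5)))  # (12, 10, 10)
```

With a real scan, `origin_xyz` and `spacing_xyz` would come from `itk_image.GetOrigin()` and `itk_image.GetSpacing()`.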

## 3. Preprocess Volumes (Normalization and Patch Extraction)

We need to extract 3D patches around the candidates and normalize the voxel intensities, which are stored in Hounsfield Units (HU). The `LunaDataset` class handles both steps.

In [None]:
# Initialize the dataset (skipped if the data is not present)
if os.path.exists(subset0_path) and os.path.exists(candidates_file):
    dataset = LunaDataset(
        root_dir=subset0_path,
        candidates_file=candidates_file,
        patch_size=(64, 64, 64)
    )
    
    # LunaDataset currently iterates over all candidates, including those whose
    # scans are in other subsets, so indexing can raise FileNotFoundError.
    # Ideally candidates would be filtered by subset first; here we simply try
    # the first 100 indices until one loads.
    print("Attempting to load a sample patch...")
    try:
        for i in range(100):
            try:
                patch, label = dataset[i]
            except FileNotFoundError:
                continue  # the scan for this candidate is not in subset0

            print(f"Loaded patch index {i}")
            print(f"Patch shape: {patch.shape}")
            print(f"Label: {label}")

            # Visualize the middle slice of the patch (channel 0)
            plt.imshow(patch[0, patch.shape[1] // 2, :, :], cmap='gray')
            plt.title(f"Patch Middle Slice (Label: {label})")
            plt.show()
            break
        else:
            print("No candidate among the first 100 has its scan in subset0.")
    except Exception as e:
        print(f"Error loading patch: {e}")
else:
    print("Data not available for preprocessing demonstration.")
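
To make the two preprocessing steps concrete, here is a rough sketch of HU normalization and patch extraction in plain NumPy. It is not the code in `LunaDataset`: the function names, the HU window of -1000 to 400, and the padding behaviour are all assumptions chosen for illustration.

```python
import numpy as np

def normalize_hu(volume, hu_min=-1000.0, hu_max=400.0):
    """Clip to an assumed HU window and rescale to [0, 1]."""
    clipped = np.clip(np.asarray(volume, dtype=np.float32), hu_min, hu_max)
    return (clipped - hu_min) / (hu_max - hu_min)

def extract_patch(volume, center_zyx, patch_size=(64, 64, 64), pad_value=0.0):
    """Cut a patch centred on center_zyx, padding where it runs past the volume edge."""
    patch = np.full(patch_size, pad_value, dtype=volume.dtype)
    src, dst = [], []
    for c, p, dim in zip(center_zyx, patch_size, volume.shape):
        start = c - p // 2
        stop = start + p
        # Clamp the requested range to the part that lies inside the volume
        src_start = min(max(start, 0), dim)
        src_stop = max(min(stop, dim), src_start)
        src.append(slice(src_start, src_stop))
        # Where that part lands inside the output patch
        dst.append(slice(src_start - start, src_stop - start))
    patch[tuple(dst)] = volume[tuple(src)]
    return patch

# Synthetic volume standing in for a CT scan in (Z, Y, X) order
volume = np.random.default_rng(0).uniform(-1200, 600, size=(80, 100, 100)).astype(np.float32)
patch = extract_patch(normalize_hu(volume), center_zyx=(10, 50, 50))
print(patch.shape, float(patch.min()), float(patch.max()))
```

Padding keeps every patch the same shape even when a candidate sits near the edge of the scan, which is what a fixed-size network input requires.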

## 4. Data Augmentation for 3D Images

Data augmentation helps reduce overfitting by exposing the model to plausible variations of each training patch. Common techniques for 3D medical images include:
*   Random rotations (axial, coronal, sagittal)
*   Random flipping
*   Elastic deformations
*   Adding noise

*Note: The current `LunaDataset` implementation does not include augmentation, but it can be added in the `__getitem__` method.*
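
As an illustration of the simpler techniques in the list above, the sketch below applies random flips, 90-degree rotations in the Y/X plane, and Gaussian noise to a patch. The function `augment_patch` and its `noise_std` default are hypothetical, and elastic deformation is left out because it needs an interpolation step. It assumes a normalized `(Z, Y, X)` patch whose Y and X sides are equal, so the rotation preserves the shape.

```python
import numpy as np

def augment_patch(patch, rng, noise_std=0.01):
    """Randomly flip, rotate by multiples of 90 degrees, and add noise to a (Z, Y, X) patch."""
    out = patch
    # Flip each axis independently with probability 0.5
    for axis in range(3):
        if rng.random() < 0.5:
            out = np.flip(out, axis=axis)
    # Rotate by 0, 90, 180, or 270 degrees in the Y/X plane
    k = int(rng.integers(0, 4))
    out = np.rot90(out, k=k, axes=(1, 2))
    # Additive Gaussian noise; set noise_std=0 to disable
    if noise_std > 0:
        out = out + rng.normal(0.0, noise_std, size=out.shape)
    # Flips and rotations return views with negative strides; copy to a contiguous array
    return np.ascontiguousarray(out, dtype=np.float32)

rng = np.random.default_rng(42)
patch = rng.random((64, 64, 64), dtype=np.float32)
augmented = augment_patch(patch, rng)
print(augmented.shape, augmented.dtype)
```

In `__getitem__`, a call like this would typically run only for training samples and before the channel dimension is added.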

## 5. Build the 3D CNN Architecture

We use a simple 3D CNN with 4 convolutional blocks, each followed by max pooling. The final layers are fully connected for binary classification.

In [None]:
model = Simple3DCNN()
print(model)

# Test with a dummy input
dummy_input = torch.randn(1, 1, 64, 64, 64)
output = model(dummy_input)
print(f"Output shape: {output.shape}")
print(f"Output value: {output.item()}")
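
The choice of a 64-voxel patch interacts with the four pooling stages. The arithmetic below is a sketch under assumptions that may not match `src/model.py`: each max pooling halves every spatial dimension (kernel 2, stride 2), the convolutions preserve spatial size, and the last block has 128 channels (a hypothetical value).

```python
def pooled_size(size, num_blocks, pool=2):
    """Spatial size along one axis after num_blocks pooling layers."""
    for _ in range(num_blocks):
        size //= pool
    return size

side = pooled_size(64, num_blocks=4)   # 64 -> 32 -> 16 -> 8 -> 4
channels = 128                         # hypothetical channel count of the last block
flat_features = channels * side ** 3   # inputs to the first fully connected layer

print(f"Spatial side after pooling: {side}")
print(f"Flattened features: {flat_features}")
```

Under these assumptions, changing `patch_size` changes `flat_features`, so the first fully connected layer would need to be resized to match.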

## 6. Train the Model

We use Binary Cross Entropy Loss and the Adam optimizer. The training loop is defined in `src/train.py`.

To run the training, execute the script from a terminal in the project root:
```bash
python src/train.py
```
Alternatively, call the `train` function directly from this notebook, after making sure its configuration is correct.
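
For orientation, a training step with binary cross entropy and Adam generally has the shape sketched below. This is not the loop in `src/train.py`: `toy_model` is a hypothetical stand-in for `Simple3DCNN`, random tensors replace a `DataLoader` over `LunaDataset`, and the learning rate is arbitrary. It also assumes the model ends in a sigmoid, so `nn.BCELoss` applies; a model that outputs raw logits would use `nn.BCEWithLogitsLoss` instead.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Tiny stand-in network so the sketch runs on its own (hypothetical architecture)
toy_model = nn.Sequential(
    nn.Conv3d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool3d(1),
    nn.Flatten(),
    nn.Linear(4, 1),
    nn.Sigmoid(),  # outputs a probability, as nn.BCELoss expects
)

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)

# Random patches and labels in place of real batches
patches = torch.randn(8, 1, 16, 16, 16)
labels = torch.randint(0, 2, (8, 1)).float()

toy_model.train()
losses = []
for epoch in range(3):
    optimizer.zero_grad()            # clear gradients from the previous step
    probs = toy_model(patches)       # forward pass, shape (8, 1)
    loss = criterion(probs, labels)
    loss.backward()                  # backpropagate
    optimizer.step()                 # update the weights
    losses.append(loss.item())
    print(f"epoch {epoch}: loss = {loss.item():.4f}")
```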

## 7. Evaluate Model Performance

After training, we evaluate the model on a held-out test set (e.g., a different subset of LUNA16). We calculate metrics like Accuracy, Sensitivity, and Specificity.

*Note: Since we haven't trained a model yet, this section is a placeholder.*
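
Until a trained model is available, the metric definitions can at least be pinned down. The sketch below computes accuracy, sensitivity, and specificity from labels and predicted probabilities; the function name `binary_metrics`, the 0.5 threshold, and the example values are illustrative assumptions, not results.

```python
import numpy as np

def binary_metrics(y_true, y_prob, threshold=0.5):
    """Accuracy, sensitivity (recall on nodules), and specificity (recall on non-nodules)."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_prob) >= threshold
    tp = int(np.sum(y_true & y_pred))
    tn = int(np.sum(~y_true & ~y_pred))
    fp = int(np.sum(~y_true & y_pred))
    fn = int(np.sum(y_true & ~y_pred))
    total = tp + tn + fp + fn
    nan = float("nan")
    return {
        "accuracy": (tp + tn) / total if total else nan,
        "sensitivity": tp / (tp + fn) if (tp + fn) else nan,
        "specificity": tn / (tn + fp) if (tn + fp) else nan,
    }

# Made-up labels and probabilities, only to show the call
metrics = binary_metrics([1, 1, 0, 0, 0], [0.9, 0.2, 0.1, 0.7, 0.3])
print(metrics)
```

Reporting sensitivity and specificity alongside accuracy matters when one class dominates, since a model that always predicts the majority class can still score high accuracy.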

## 8. Visualize Prediction Results

We can visualize the model's predictions on new data.

```python
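# Example inference code (assumes `model`, `dataset`, and the index `i` from the cells above)
model.eval()
with torch.no_grad():
    patch, label = dataset[i]
    # Add a batch dimension: (1, D, H, W) -> (1, 1, D, H, W)
    output = model(torch.as_tensor(patch).unsqueeze(0).float())
    print(f"Prediction: {output.item():.4f} (label: {label})")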
```