### **Fine-Tuning 3D ResNet for Pulmonary Nodule Classification in CT Scans**

* This notebook focuses on fine-tuning a pre-trained 3D ResNet model for the classification of pulmonary nodules from CT scan images. 

* The dataset consists of 3D medical images, and the goal is to refine the model's ability to detect nodules by leveraging transfer learning. 

* Preprocessing steps such as patch extraction, resizing, and normalization are applied to prepare the data for training. 



#### **Set Up the Environment**

In [None]:
!pip install numpy pandas SimpleITK matplotlib scipy seaborn tensorflow keras 

^C




In [1]:
import pandas as pd
import SimpleITK as sitk
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.utils import Sequence
import matplotlib.pyplot as plt

In [2]:
import cv2


In [None]:
import torch
import torch.nn as nn
import torchvision.models.video as models
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset
import torch.nn.functional as F


#### **Step 1 : Data Preparation**


**Loading the CSV File**

In [4]:
def load_csv(filename):
    return pd.read_csv(f"Data/{filename}")

**Loading subset**

In [5]:
def load_ct_scan(mhd_path):
    if not os.path.exists(mhd_path):
        raise FileNotFoundError(f"File not found: {mhd_path}")
    try:
        return sitk.ReadImage(mhd_path)
    except RuntimeError as e:
        print(f"Error loading {mhd_path}: {e}")
        return None

In [6]:
def load_image(self, filename):
        try:
            itkimage = sitk.ReadImage(filename)
            numpyImage = sitk.GetArrayFromImage(itkimage).astype(np.float32)  
            numpyOrigin = np.array(itkimage.GetOrigin())[::-1]  
            numpySpacing = np.array(itkimage.GetSpacing())[::-1]  
            
            return numpyImage, numpyOrigin, numpySpacing
        except Exception as e:
            print(f" Error loading {filename}: {e}")
            return None, None, None

#### **Step 2 : Data preprocessing**

Since we are fine-tuning 3D ResNet model we need to rescale intensity values to the range [0,1] for optimal performance.

In [7]:
def normalize_image(image_array):
    image_array = image_array.astype(np.float32)
    
    # Normalize to the range [0, 1] by dividing by the max intensity
    image_array = image_array / np.max(image_array)  
    
    return image_array

Begin by extracting image patches from the CT scan data

In [8]:
def get_nodule_coords_from_csv(csv_path, seriesuid):
    
    df = pd.read_csv(csv_path)
    nodules = df[df['seriesuid'] == seriesuid]
    
    if nodules.empty:
        print(f"No nodules found for seriesuid: {seriesuid}")
    else:
        print(f"Found {len(nodules)} nodules for seriesuid: {seriesuid}")
    
    coords = nodules[['coordX', 'coordY', 'coordZ']].values
    return coords

In [9]:
def real_to_voxel_coordinates(ct_scan, real_coord):

    spacing = np.array(ct_scan.GetSpacing())  
    origin = np.array(ct_scan.GetOrigin())    
    
    voxel_coord = (real_coord - origin) / spacing
    return np.round(voxel_coord).astype(int)

In [None]:
def pad_or_crop_patch(patch, target_size=(64, 64, 64)):
    if patch.ndim != 3:
        raise ValueError(f"Expected 3D patch, but got {patch.ndim}D input.")

    depth, height, width = patch.shape
    
    pad_depth = (target_size[0] - depth) // 2
    pad_height = (target_size[1] - height) // 2
    pad_width = (target_size[2] - width) // 2
    
    crop_depth = max(0, depth - target_size[0])
    crop_height = max(0, height - target_size[1])
    crop_width = max(0, width - target_size[2])
    
    patch = patch[crop_depth:depth-crop_depth, crop_height:height-crop_height, crop_width:width-crop_width]
    
    patch = np.pad(patch, 
                   ((pad_depth, target_size[0] - patch.shape[0] - pad_depth), 
                    (pad_height, target_size[1] - patch.shape[1] - pad_height),
                    (pad_width, target_size[2] - patch.shape[2] - pad_width)),
                   mode='constant', constant_values=0)

    return patch

In [11]:
def extract_patch(ct_scan, coordX, coordY, coordZ, patch_size=(64, 64, 64)):
    voxel_coord = real_to_voxel_coordinates(ct_scan, [coordX, coordY, coordZ])
    #print(f"Extracting patch at voxel coordinates: {voxel_coord}")  # Debug print
    
    start = voxel_coord - np.array(patch_size) // 2
    end = voxel_coord + np.array(patch_size) // 2
    
    size = ct_scan.GetSize()
    overlap_start = np.maximum(start, 0)
    overlap_end = np.minimum(end, np.array(size))
    actual_patch_size = overlap_end - overlap_start
    
    #print(f"Calculated patch size: {actual_patch_size}")  # Debug print
    
    if np.any(actual_patch_size <= 0):
        #print(f"Invalid patch size for coordinates {voxel_coord}")  # Debug print
        return None
    
    patch = sitk.RegionOfInterest(ct_scan, size=np.array(actual_patch_size, dtype=np.uint32).tolist(), index=overlap_start.tolist())
    patch_array = sitk.GetArrayFromImage(patch)
    
    # Ensure patch size is correct, otherwise crop/pad
    if patch_array.shape != patch_size:
        #print(f"Padding or cropping patch from {patch_array.shape} to {patch_size}")  # Debug print
        patch_array = pad_or_crop_patch(patch_array, patch_size)
    
    return patch_array


In [12]:
def extract_nodule_patches(ct_scan, csv_path, seriesuid, patch_size=(64, 64, 64)):
    nodule_coords = get_nodule_coords_from_csv(csv_path, seriesuid)
    #print(f"Nodule coordinates for scan {seriesuid}: {nodule_coords}")  # Debug print
    
    patches = []
    labels = []  
    
    for coordX, coordY, coordZ in nodule_coords:
        patch = extract_patch(ct_scan, coordX, coordY, coordZ, patch_size)
        if patch is not None:
            # Ensure the patch has the correct size
            if patch.shape != patch_size:
                patch = pad_or_crop_patch(patch, patch_size)
            patches.append(patch)
            labels.append(1)  # Label 1 for nodules
           # print(f"Extracted {len(patches)} patches for scan")  # Debug print

    if not patches:
        print(f"No patches were extracted for scan {seriesuid}.")  # Debug print for empty patches

    return patches, labels


##### **Dataset**

In [None]:
class LunaDataset(Dataset):
    def __init__(self, base_path, csv_path, patch_size=(64, 64, 64), transform=None):
        """
        PyTorch dataset for loading 3D patches from LUNA16 dataset.

        Args:
        - base_path (str): Path to the LUNA16 dataset directory.
        - csv_path (str): Path to the CSV file containing nodule annotations.
        - patch_size (tuple): Size of the extracted patches (Depth, Height, Width).
        - transform (callable, optional): Optional transformations on the data.
        """
        self.base_path = base_path
        self.csv_path = csv_path
        self.patch_size = patch_size
        self.transform = transform
        self.mhd_files = self._get_mhd_files()

    def _get_mhd_files(self):
        """Retrieve all .mhd files from the dataset directory."""
        mhd_files = []
        for subset in os.listdir(self.base_path):
            subset_path = os.path.join(self.base_path, subset)
            if os.path.isdir(subset_path):
                for filename in os.listdir(subset_path):
                    if filename.endswith(".mhd"):
                        mhd_files.append(os.path.join(subset_path, filename))
        return mhd_files

    def __len__(self):
        return len(self.mhd_files)

    def __getitem__(self, idx):
        """Load a CT scan, extract patches, and return as tensors."""
        mhd_path = self.mhd_files[idx]
        seriesuid = os.path.basename(mhd_path).replace('.mhd', '')  
        ct_scan = load_ct_scan(mhd_path)  

        if ct_scan is None:
            return None  

        # Extract patches and labels
        patches, labels = extract_nodule_patches(ct_scan, self.csv_path, seriesuid, self.patch_size)

        if patches is None or labels is None:
            return None  

        patches = np.array(patches)  # (num_patches, D, H, W)
        labels = np.array(labels)

        # Add channel dimension if missing (C=1 for grayscale)
        if patches.ndim == 4:
            patches = np.expand_dims(patches, axis=1)  # Shape: (num_patches, 1, D, H, W)

        # Convert to PyTorch tensors
        patches = torch.tensor(patches, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.long)

        # Apply any data augmentations or transformations
        if self.transform:
            patches = self.transform(patches)

        return patches, labels


#### **Step 3 : Modeling**


In this step we are going to load 3D ResNet, modify it for binary classification (nodules exists vs absent), and prepare for fine-tuning.

##### **1 - Load Pretrained 3D ResNet**

In [27]:
def load_pretrained_3d_resnet(num_classes=2):
    print("start loading model")
    model = models.r3d_18(pretrained=True) 

    model.stem[0] = nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)

    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 1),
    )
    return model

##### **2 - Define Loss & Optimizer**

In [28]:
def compile_model(model, learning_rate=0.0001):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_function = nn.BCEWithLogitsLoss()  
    return optimizer, loss_function

##### **3 - Fine-tuning**

In [16]:
def freeze_layers(model, num_layers_to_freeze=5):
    layer_count = 0
    for child in model.children():
        if isinstance(child, nn.Conv3d):
            layer_count += 1
            if layer_count <= num_layers_to_freeze:
                for param in child.parameters():
                    param.requires_grad = False
    return model

In [31]:
def train_model(model, train_loader, val_loader, optimizer, loss_function, epochs, device):
    print("start training")
    labels = labels.view(-1, 1)  # Assure que target a la forme [batch_size, 1]
    labels = labels.float()  # Assure que les labels sont en float

    model.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        
        for inputs, labels in train_loader:
            # Ensure inputs are already correctly shaped and converted to tensors
            inputs = inputs.to(dtype=torch.float32, device=device)  # No need for stacking
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = loss_function(outputs, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss:.4f}")

    return model


In [24]:
def fine_tune_model(model, train_loader, val_loader, optimizer, loss_function, fine_tune_from=5, epochs=10, device='cuda'):
    print("start unfreezing")
    layer_count = 0
    for child in model.children():
        layer_count += 1
        if layer_count >= fine_tune_from:
            for param in child.parameters():
                param.requires_grad = True
    
    return train_model(model, train_loader, val_loader, optimizer, loss_function, epochs, device)

##### **4 - Evaluation**

In [25]:
def evaluate_model(model, test_loader, device='cuda'):
    print("start evaluating")
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = correct / total
    print(f'Test Accuracy: {accuracy * 100:.2f}%')
    return accuracy

In [None]:
def pad_patches(patches):
    if len(patches) == 0:
        print("No patches found!")
        return None
    
    max_depth = max(patch.shape[0] for patch in patches)
    max_channels = max(patch.shape[1] for patch in patches)
    max_height = max(patch.shape[2] for patch in patches)
    max_width = max(patch.shape[3] for patch in patches)

    padded_patches = []
    for patch in patches:
        z_pad = max_depth - patch.shape[0]  
        c_pad = max_channels - patch.shape[1]  
        h_pad = max_height - patch.shape[2] 
        w_pad = max_width - patch.shape[3]  

        patch = F.pad(patch, (0, w_pad, 0, h_pad, 0, z_pad), value=0)

        padded_patches.append(patch)

    return torch.stack(padded_patches)


In [None]:
def collate_fn(batch):
    batch_patches = []
    batch_labels = []

    for item in batch:
        if item is None:
            continue 

        patches, labels = item 

        print(patches.shape)

        batch_patches.append(patches)
        batch_labels.append(labels)

    if not batch_patches:
        return None, None

    batch_patches = torch.cat(batch_patches, dim=0)  
    batch_labels = torch.cat(batch_labels, dim=0) 

    return batch_patches, batch_labels


In [29]:
def main():
    base_path = "E:/Luna"  
    batch_size = 3
    patch_size = (64, 64, 64)

    train_csv = "Data/train_df.csv"
    val_csv = "Data/val_df.csv"
    test_csv = "Data/test_df.csv"

    train_dataset = LunaDataset(base_path=base_path, csv_path=train_csv)
    val_dataset = LunaDataset(base_path=base_path, csv_path=val_csv)
    test_dataset = LunaDataset(base_path=base_path, csv_path=test_csv)
    
    train_loader = DataLoader(train_dataset,batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset,batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset,batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    padded_ct_scans, labels = collate_fn(train_dataset)

    # Load pretrained 3D ResNet model
    model = load_pretrained_3d_resnet(num_classes=2)  

    # Freeze initial layers for transfer learning
    model = freeze_layers(model, num_layers_to_freeze=5)

    # Compile model with optimizer and loss function
    optimizer, loss_function = compile_model(model, learning_rate=0.0001)

    # Set device to GPU if available, otherwise CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # Train the model on training data
    model = train_model(model, train_loader, val_loader, optimizer, loss_function, epochs=10, device=device)

    # Fine-tune the model by unfreezing layers
    model = fine_tune_model(model, train_loader, val_loader, optimizer, loss_function, fine_tune_from=5, epochs=5, device=device)

    # Evaluate the model on the test set
    accuracy = evaluate_model(model, test_loader, device=device)
    print(f"Test Accuracy: {accuracy * 100:.2f}%")


In [32]:
if __name__ == "__main__":
    main()

Found 1 nodules for seriesuid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.105756658031515062000744821260
torch.Size([1, 1, 64, 64, 64])
Found 2 nodules for seriesuid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.108197895896446896160048741492
torch.Size([2, 1, 64, 64, 64])
Found 3 nodules for seriesuid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.109002525524522225658609808059
torch.Size([3, 1, 64, 64, 64])
Found 3 nodules for seriesuid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.111172165674661221381920536987
torch.Size([3, 1, 64, 64, 64])
Found 3 nodules for seriesuid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.122763913896761494371822656720
torch.Size([3, 1, 64, 64, 64])
Found 4 nodules for seriesuid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.124154461048929153767743874565
torch.Size([4, 1, 64, 64, 64])
Found 1 nodules for seriesuid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.126121460017257137098781143514
torch.Size([1, 1, 64, 64, 64])
Found 9 nodules for seriesuid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.126264578931778258890371755354
torch.Size([9, 

RuntimeError: [enforce fail at alloc_cpu.cpp:115] data. DefaultCPUAllocator: not enough memory: you tried to allocate 3391094784 bytes.