In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
import segmentation_models_pytorch as smp

# **Preprocessing function for single image**
def preprocess_ct_image(image_path):
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    
    image = Image.open(image_path).convert("L")
    image = transform(image).unsqueeze(0)  # Add batch dimension
    return image

# **Dataset Class**
class LungDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.patients = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.patients)

    def __getitem__(self, idx):
        patient_id = self.patients[idx]
        nodule_folder = os.path.join(self.img_dir, patient_id, "nodule-0")

        image_path = os.path.join(nodule_folder, "images")
        mask_paths = [os.path.join(nodule_folder, f"mask-{i}") for i in range(4)]

        image_slices = sorted(os.listdir(image_path))
        mask_slices = [sorted(os.listdir(mask_path)) for mask_path in mask_paths]

        # Pick **middle slice** 
        mid_slice = len(image_slices) // 2  
        img = Image.open(os.path.join(image_path, image_slices[mid_slice])).convert("L")
        img = self.transform(img) if self.transform else transforms.ToTensor()(img)

        # Get and average masks
        mask_stack = torch.stack([
            transforms.ToTensor()(Image.open(os.path.join(mask_paths[j], mask_slices[j][mid_slice])).convert("L"))
            for j in range(4)
        ])
        avg_mask = mask_stack.mean(dim=0)  # Average masks
        avg_mask = (avg_mask > 0.5).float()  # Binarize

        return img, avg_mask


# **Set base directory**
base_dir = "datasets/LIDC-IDRI-slices"

# **Transformations**
transform = transforms.Compose([
    transforms.Resize((128, 128)),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# **Load dataset**
dataset = LungDataset(base_dir, base_dir, transform=transform)

# **Split into train and test**
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

print(f"Total images: {len(dataset)}, Train: {len(train_dataset)}, Test: {len(test_dataset)}")

# **Define U-Net Model with ResNet34 Backbone**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=1, classes=1).to(device)

# **Define Loss Function & Optimizer**
def dice_loss(pred, target, smooth=1.):
    pred = torch.sigmoid(pred)  
    intersection = (pred * target).sum()
    return 1 - ((2. * intersection + smooth) / (pred.sum() + target.sum() + smooth))

criterion = dice_loss
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# **Model Saving & Loading**
num_epochs = 10
checkpoint_path = f"lung_cancer_model_epoch_{num_epochs}.pth"

if os.path.exists(checkpoint_path):
    print(f"Loading existing model from {checkpoint_path}")
    model.load_state_dict(torch.load(checkpoint_path))
else:
    # **Training Loop**
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader):.4f}")

    # Save Model
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model saved as {checkpoint_path}")

Total images: 875, Train: 700, Test: 175
Loading existing model from lung_cancer_model_epoch_10.pth


  model.load_state_dict(torch.load(checkpoint_path))


In [12]:

# **Testing Loop**
model.eval()
test_dice_score = 0
test_iou = 0
test_acc = 0

with torch.no_grad():
    for images, masks in test_loader:
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        preds = torch.sigmoid(outputs) > 0.5  

        # Compute Dice Score
        intersection = (preds * masks).sum()
        dice_score = (2. * intersection) / (preds.sum() + masks.sum() + 1e-8)

        # Compute IoU
        iou = intersection / (preds.sum() + masks.sum() - intersection + 1e-8)

        # Compute Accuracy
        acc = (preds == masks).float().mean()

        test_dice_score += dice_score.item()
        test_iou += iou.item()
        test_acc += acc.item()

print(f"Test Dice Score: {test_dice_score/len(test_loader):.4f}")
print(f"Test IoU: {test_iou/len(test_loader):.4f}")
print(f"Test Accuracy: {test_acc/len(test_loader):.4f}")


Test Dice Score: 0.7841
Test IoU: 0.6480
Test Accuracy: 0.9982


In [14]:

# **Prediction Function with Tumor Size Calculation**
def predict_cancer_stage(image_path, model, device):
    model.eval()
    image = preprocess_ct_image(image_path).to(device)

    with torch.no_grad():
        output = model(image)
        output = torch.sigmoid(output)  
        tumor_mask = output > 0.5  

        tumor_size = tumor_mask.sum().item()  
        tumor_prob = output.mean().item()  

    # **Determine Cancer Stage**
    STAGE_THRESHOLDS = {0: 0.0, 1: 0.2, 2: 0.5, 3: 0.8}
    predicted_stage = max([stage for stage, threshold in STAGE_THRESHOLDS.items() if tumor_prob > threshold])

    print(f"Predicted Tumor Probability: {tumor_prob:.4f}")
    print(f"Predicted Tumor Size (in pixels): {tumor_size}")
    print(f"Predicted Cancer Stage: {predicted_stage}")
    
    return predicted_stage, tumor_size

# **Example Usage**
#image_path = "datasets/LIDC-IDRI-slices/LIDC-IDRI-0265/nodule-0/images/slice-3.png"
image_path = "testimages/1.jpg"
predicted_stage, tumor_size = predict_cancer_stage(image_path, model, device)

Predicted Tumor Probability: 0.0404
Predicted Tumor Size (in pixels): 618
Predicted Cancer Stage: 0
