In [1]:
import os
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# Run on Apple's Metal (MPS) backend when this PyTorch build and machine
# support it; otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Define the U-Net model
class UNet(nn.Module):
    """U-Net for binary segmentation with a sigmoid output.

    Four encoder stages (each followed by 2x max-pooling), a bottleneck, and
    four decoder stages. Each decoder stage upsamples with a stride-2
    transposed convolution and concatenates the encoder output of the same
    depth (skip connection) along the channel axis before two 3x3 convs.

    The submodule attribute names (encoder1..4, pool, bottleneck,
    upconv4..1 / decoder4..1, conv_final) and their registration order are
    fixed: saved ``state_dict`` keys and optimizer parameter order depend on
    them.

    Args:
        in_channels: number of input image channels (1 for grayscale).
        out_channels: number of predicted mask channels.
        base_channels: width of the first encoder stage; every deeper stage
            doubles it (64 -> 128 -> 256 -> 512 -> 1024 by default).
    """

    # Four 2x poolings: H and W must be divisible by 2**4 so every upsampled
    # decoder tensor matches the spatial size of its skip connection.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()
        # Channel width at each depth; c5 is the bottleneck width.
        c1, c2, c3, c4, c5 = (base_channels * 2 ** depth for depth in range(5))

        self.encoder1 = self._conv_block(in_channels, c1)
        self.encoder2 = self._conv_block(c1, c2)
        self.encoder3 = self._conv_block(c2, c3)
        self.encoder4 = self._conv_block(c3, c4)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = self._conv_block(c4, c5)

        # Decoder inputs are [upsampled, skip] concatenated, hence the doubled
        # input width of each decoder block (e.g. c5 = 2 * c4).
        self.upconv4 = self._up_conv(c5, c4)
        self.decoder4 = self._conv_block(c5, c4)
        self.upconv3 = self._up_conv(c4, c3)
        self.decoder3 = self._conv_block(c4, c3)
        self.upconv2 = self._up_conv(c3, c2)
        self.decoder2 = self._conv_block(c3, c2)
        self.upconv1 = self._up_conv(c2, c1)
        self.decoder1 = self._conv_block(c2, c1)

        self.conv_final = nn.Conv2d(c1, out_channels, kernel_size=1)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _up_conv(in_channels, out_channels):
        """Upsample 2x spatially with a stride-2 transposed convolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _decode(upconv, decoder, x, skip):
        """Upsample ``x``, concatenate the encoder ``skip`` on channels, then convolve."""
        return decoder(torch.cat((upconv(x), skip), dim=1))

    def forward(self, x):
        """Return per-pixel foreground probabilities in (0, 1).

        Args:
            x: tensor of shape (N, in_channels, H, W) with H and W multiples of 16.

        Raises:
            ValueError: if H or W is not a multiple of 16.
        """
        height, width = x.shape[-2:]
        if height % self._SIZE_MULTIPLE or width % self._SIZE_MULTIPLE:
            raise ValueError(
                f"UNet input height and width must be multiples of {self._SIZE_MULTIPLE}, "
                f"got {height}x{width}"
            )

        # Encoding
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool(e1))
        e3 = self.encoder3(self.pool(e2))
        e4 = self.encoder4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoding with skip connections from the matching encoder depth.
        d4 = self._decode(self.upconv4, self.decoder4, b, e4)
        d3 = self._decode(self.upconv3, self.decoder3, d4, e3)
        d2 = self._decode(self.upconv2, self.decoder2, d3, e2)
        d1 = self._decode(self.upconv1, self.decoder1, d2, e1)

        return torch.sigmoid(self.conv_final(d1))

# Build the network, then move its parameters onto the selected device.
model = UNet()
model = model.to(device)


In [2]:
import os
class LungDataset(Dataset):
    """Paired chest X-ray / lung-mask dataset read from two directories.

    Every file name listed in ``image_dir`` is expected to have a mask with the
    same file name in ``mask_dir``. Both are loaded as single-channel "L" PIL
    images and passed through the optional transforms.

    Args:
        image_dir: directory containing the X-ray images.
        mask_dir: directory containing the masks (same file names as the images).
        transform: optional callable applied to each image, and to each mask
            unless ``mask_transform`` is given.
        mask_transform: optional callable applied to masks instead of
            ``transform`` (e.g. a resize with nearest-neighbour interpolation so
            masks stay binary). Defaults to ``transform``.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = transform if mask_transform is None else mask_transform
        # os.listdir order is filesystem-dependent; sort so indices are reproducible.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def _load_pair(self, idx):
        """Load and transform the (image, mask) pair stored at position ``idx``."""
        name = self.images[idx]
        # `with` closes the underlying file handles once the pixels are converted.
        with Image.open(os.path.join(self.image_dir, name)) as img:
            image = img.convert("L")
        with Image.open(os.path.join(self.mask_dir, name)) as msk:
            mask = msk.convert("L")
        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
        return image, mask

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``, skipping unreadable pairs.

        If the pair at ``idx`` cannot be opened, the following indices
        (wrapping around) are tried in turn and a warning is emitted for each
        skipped file, so one bad file yields a neighbouring sample instead of
        aborting the DataLoader.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if no pair in the dataset can be loaded.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        # Iterate instead of recursing: a long run of bad files cannot hit the
        # recursion limit, and the search stops after one full pass.
        for offset in range(n):
            current = (idx + offset) % n
            try:
                return self._load_pair(current)
            except OSError as err:  # also covers FileNotFoundError and PIL's UnidentifiedImageError
                warnings.warn(f"Skipping {self.images[current]!r}: {err}")
        raise RuntimeError(
            f"No readable image/mask pair found in {self.image_dir!r} / {self.mask_dir!r}"
        )

# Preprocessing shared by images and masks: resize to 256x256, then turn the
# PIL image into a float tensor scaled to [0, 1].
# NOTE(review): Resize defaults to bilinear interpolation, which also softens
# mask edges into non-binary values; consider a nearest-neighbour resize for masks.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

# Training pairs: CXR-images/<name> matched with Mask-images/<name>.
train_dataset = LungDataset(
    image_dir="CXR-images",
    mask_dir="Mask-images",
    transform=transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
)


In [25]:
# Loss and optimizer for this run. A fresh Adam instance is created here, so
# optimizer state from any earlier training cell is discarded while the
# current `model` weights carry over (the model is not re-created here).
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    batch_losses = []
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass: sigmoid probabilities scored against the target masks.
        predictions = model(images)
        loss = criterion(predictions, masks)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Mean of the per-batch losses (every batch weighted equally).
    mean_loss = sum(batch_losses) / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss:.4f}")

# Persist only the learned weights.
torch.save(model.state_dict(), "unet_lung_segmentation.pth")


Epoch 1/20, Loss: 0.0804
Epoch 2/20, Loss: 0.0706
Epoch 3/20, Loss: 0.0682
Epoch 4/20, Loss: 0.0669
Epoch 5/20, Loss: 0.0613
Epoch 6/20, Loss: 0.0588
Epoch 7/20, Loss: 0.0564
Epoch 8/20, Loss: 0.0555
Epoch 9/20, Loss: 0.0603
Epoch 10/20, Loss: 0.0576
Epoch 11/20, Loss: 0.0531
Epoch 12/20, Loss: 0.0516
Epoch 13/20, Loss: 0.0501
Epoch 14/20, Loss: 0.0480
Epoch 15/20, Loss: 0.0461
Epoch 16/20, Loss: 0.0459
Epoch 17/20, Loss: 0.0441
Epoch 18/20, Loss: 0.0434
Epoch 19/20, Loss: 0.0432
Epoch 20/20, Loss: 0.0408


In [3]:
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, roc_auc_score
from sklearn.metrics import confusion_matrix

# Loss and optimizer for the metric-tracking run below. BCELoss expects
# probabilities, which matches the sigmoid applied at the end of UNet.forward.
criterion = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
num_epochs = 20

def compute_metrics(outputs, masks, threshold=0.5):
    """Compute pixel-wise accuracy, precision, specificity and ROC AUC for one batch.

    Args:
        outputs: model outputs interpreted as foreground probabilities (the UNet
            in this notebook ends with a sigmoid). Flattened internally.
        masks: ground-truth masks with the same number of elements as
            ``outputs``; pixels > 0.5 count as foreground.
        threshold: probability above which a predicted pixel counts as foreground.

    Returns:
        Tuple ``(accuracy, precision, specificity, auc)``. Accuracy, precision
        and specificity are 0.0 when their denominator is zero. ``auc`` is NaN
        when the batch holds only one ground-truth class, since ROC AUC is
        undefined there.
    """
    # Detach and flatten once; reshape (unlike view) also accepts non-contiguous tensors.
    scores = outputs.detach().reshape(-1).cpu().numpy()
    true_labels = (masks.detach().reshape(-1) > 0.5).cpu().numpy().astype(np.int64)
    preds = (scores > threshold).astype(np.int64)

    # labels=[0, 1] forces a 2x2 matrix even when a batch contains a single
    # class, so the 4-way unpack below cannot fail.
    tn, fp, fn, tp = confusion_matrix(true_labels, preds, labels=[0, 1]).ravel()

    total = tn + fp + fn + tp
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0

    # roc_auc_score raises ValueError when only one class is present in y_true.
    if true_labels.size == 0 or true_labels.min() == true_labels.max():
        auc = float("nan")
    else:
        auc = roc_auc_score(true_labels, scores)

    return accuracy, precision, specificity, auc


# Training loop that also reports pixel-wise metrics averaged over batches.
# NOTE(review): metrics are computed on the training batches themselves (in
# train mode), so they overstate performance on unseen data.
for epoch in range(num_epochs):
    model.train()
    batch_losses = []
    batch_metrics = []  # one (accuracy, precision, specificity, auc) tuple per batch

    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        batch_losses.append(loss.item())
        batch_metrics.append(compute_metrics(outputs, masks))

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if not batch_losses:
        raise RuntimeError("train_loader produced no batches; check the dataset directories.")

    epoch_loss = float(np.mean(batch_losses))
    # nanmean skips batches whose metric is undefined (NaN) instead of letting a
    # single such batch turn the whole epoch average into NaN.
    epoch_accuracy, epoch_precision, epoch_specificity, epoch_auc = np.nanmean(
        np.asarray(batch_metrics, dtype=float), axis=0
    )

    # Print metrics after each epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, "
          f"Accuracy: {epoch_accuracy:.4f}, "
          f"Precision: {epoch_precision:.4f}, "
          f"Specificity: {epoch_specificity:.4f}, "
          f"AUC: {epoch_auc:.4f}")

# Save the model
torch.save(model.state_dict(), "unet_lung_segmentation2.pth")


Epoch 1/20, Loss: 0.5107, Accuracy: 0.7454, Precision: 0.0000, Specificity: 1.0000, AUC: 0.7720
Epoch 2/20, Loss: 0.2797, Accuracy: 0.8786, Precision: 0.7441, Specificity: 0.9620, AUC: 0.9478
Epoch 3/20, Loss: 0.1709, Accuracy: 0.9354, Precision: 0.9026, Specificity: 0.9685, AUC: 0.9771
Epoch 4/20, Loss: 0.1107, Accuracy: 0.9581, Precision: 0.9319, Specificity: 0.9772, AUC: 0.9902
Epoch 5/20, Loss: 0.0825, Accuracy: 0.9686, Precision: 0.9470, Specificity: 0.9820, AUC: 0.9949
Epoch 6/20, Loss: 0.0751, Accuracy: 0.9716, Precision: 0.9519, Specificity: 0.9837, AUC: 0.9956
Epoch 7/20, Loss: 0.0694, Accuracy: 0.9738, Precision: 0.9566, Specificity: 0.9853, AUC: 0.9962
Epoch 8/20, Loss: 0.0668, Accuracy: 0.9749, Precision: 0.9592, Specificity: 0.9862, AUC: 0.9966
Epoch 9/20, Loss: 0.0634, Accuracy: 0.9761, Precision: 0.9610, Specificity: 0.9869, AUC: 0.9968
Epoch 10/20, Loss: 0.0645, Accuracy: 0.9755, Precision: 0.9601, Specificity: 0.9866, AUC: 0.9969
Epoch 11/20, Loss: 0.0591, Accuracy: 0.