Import Libraries:

In [16]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import models
import torchvision.transforms as transforms
from torchvision.transforms import Compose, Resize, v2
from torchvision.transforms.functional import to_tensor

import os
import matplotlib.pyplot as plt

from PIL import Image

Import Dataset: 

In [17]:
class HarveyData(Dataset):
    def __init__(self, dataset_dir, transforms=None):
        super(HarveyData, self).__init__()
        self.dataset_dir = dataset_dir
        self.transforms = transforms

        self.pre_image_paths = sorted(os.listdir(os.path.join(dataset_dir, 'pre_img')))
        self.post_image_paths = sorted(os.listdir(os.path.join(dataset_dir, 'post_img')))
        self.mask_paths = sorted(os.listdir(os.path.join(dataset_dir, 'post_msk')))
        self.num_images = len(self.pre_image_paths)

    def __getitem__(self, idx):
        pre_image_path = os.path.join(self.dataset_dir, 'pre_img', self.pre_image_paths[idx])
        post_image_path = os.path.join(self.dataset_dir, 'post_img', self.post_image_paths[idx])
        mask_path = os.path.join(self.dataset_dir, 'post_msk', self.mask_paths[idx])

        pre_image = Image.open(pre_image_path)
        post_image = Image.open(post_image_path)
        mask = Image.open(mask_path)

        # Convert PIL Images to Tensors
        pre_image = to_tensor(pre_image)
        post_image = to_tensor(post_image)
        mask = to_tensor(mask)

        # Apply additional transformations if any
        if self.transforms:
            pre_image = self.transforms(pre_image)
            post_image = self.transforms(post_image)
            mask = self.transforms(mask)

        combined_image = torch.cat([pre_image, post_image], dim=0)
        return combined_image, mask

    def __len__(self):
        return self.num_images

DeepLabV3 Model: 

In [18]:
class DeepLabV3(nn.Module):
    def __init__(self, num_input_channels, num_classes):
        super(DeepLabV3, self).__init__()
        self.deeplabv3 = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=num_classes, weights=None)

        # Modify the first convolutional layer of the ResNet50 backbone
        self.deeplabv3.backbone.conv1 = nn.Conv2d(num_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Initialize the modified convolutional layer
        nn.init.kaiming_normal_(self.deeplabv3.backbone.conv1.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        return self.deeplabv3(x)['out']

Training and Testing: 

In [19]:
batch_size = 16
num_input_channels = 6
num_classes = 4
lr = 1e-5
image_size = 224

transforms = v2.Compose([
    v2.Resize((image_size, image_size), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomRotation(degrees=(1, 359)),
    v2.RandomResizedCrop(size=image_size, antialias=True)
    ])

cwd = os.getcwd()

train_dataset = HarveyData(os.path.join(cwd, 'dataset//training'), transforms=transforms)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = HarveyData(os.path.join(cwd, 'dataset//testing'), transforms=transforms)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = DeepLabV3(num_input_channels, num_classes).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

num_epochs = 20

predicted_images = []

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for i, data in enumerate(train_dataloader):
        image, mask = data
        image = image.to(device)
        mask = mask.to(device).long()  # Ensure mask is LongTensor

        optimizer.zero_grad()
        outputs = model(image)

        # Resize outputs to match mask size if necessary
        outputs = nn.functional.interpolate(outputs, size=mask.shape[-2:], mode='bilinear', align_corners=True)

        loss = criterion(outputs, mask)  # mask is already LongTensor
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        print('Batch %d --- Loss: %.4f' % (i, loss.item() / batch_size))

    print('Epoch %d / %d --- Average Loss: %.4f' % (epoch + 1, num_epochs, epoch_loss / len(train_dataset)))

    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    total_pixels = 0
    dice_score = 0.0

    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            image, mask = data
            image = image.to(device)
            mask = mask.to(device)

            outputs = model(image)
            outputs = nn.functional.interpolate(outputs, size=mask.shape[-2:], mode='bilinear', align_corners=True)

            loss = criterion(outputs, mask.long())  # Ensure the mask is in long format
            total_loss += loss.item()

            predicted = torch.argmax(outputs, dim=1)
            correct_predictions += (predicted == mask).sum().item()
            total_pixels += mask.numel()

            # Calculate Dice coefficient
            dice_score += (2 * (predicted & mask).sum()) / ((predicted + mask).sum() + 1e-8)

    accuracy = 100.0 * correct_predictions / total_pixels
    average_loss = total_loss / len(test_dataset)
    dice_score = dice_score / len(test_dataset)

    print(f'Epoch {epoch + 1}/{num_epochs} --- Test Loss: {average_loss:.4f}, Accuracy: {accuracy:.2f}%, Dice: {dice_score:.4f}')

fig, axs = plt.subplots(8, 3, figsize=(32, 32))

for i in range(8):
    # Plot the input image
    image, mask = test_dataset.__getitem__(i)
    axs[i, 0].imshow(image.numpy()[3:6, :, :].T, aspect='equal')
    axs[i, 0].set_title("Input Image")
    axs[i, 0].axis('off')

    # Plot the predicted image
    predicted_images_flat = [item for sublist in predicted_images for item in sublist]
    axs[i, 1].imshow(predicted_images_flat[i].T, cmap="viridis", aspect='equal')  # Adjust the colormap as needed
    axs[i, 1].set_title("Predicted Image")
    axs[i, 1].axis('off')

    # Plot the ground truth mask
    axs[i, 2].imshow(mask.numpy()[0].T, cmap="viridis", aspect='equal')  # Assuming the mask is a single-channel image
    axs[i, 2].set_title("Ground Truth Mask")
    axs[i, 2].axis('off')

plt.show()

RuntimeError: Expected floating point type for target with class probabilities, got Long