In [1]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [2]:
import torch
import torch.nn as nn
from torchvision import models
from torch.nn.functional import relu
import torch.nn.functional as F

from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class CustomDataset(Dataset):
    def __init__(self, image_paths, label_paths):
        self.image_paths = image_paths
        self.label_paths = label_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image
        image = Image.open(self.image_paths[idx])
        np_image = np.array(image, dtype=np.float32)

        # Normalize the image
        normalized_image = np_image / 65535.0  # For 16-bit images

        # Load and process the label data
        label_image = Image.open(self.label_paths[idx])
        label_array = np.array(label_image, dtype=np.float32)

        grayscale_to_class_mapping = {0: 0, 128: 1, 255: 2} # a set that maps gray-levels to a class

        # Map grayscale values to class labels
        mapped_labels = np.copy(label_array)
        for grayscale_value, class_id in grayscale_to_class_mapping.items():
            mapped_labels[label_array == grayscale_value] = class_id

        # Convert to PyTorch tensors
        image_tensor = torch.from_numpy(normalized_image).unsqueeze(0) # unsqueeze to enable channel dimension, was gone due to being a grayscale image
        label_tensor = torch.from_numpy(mapped_labels)

        return image_tensor, label_tensor


In [3]:
### Label images ###
# white class - 255 nickel
# gray class - 128 ysz
# black class - 0 pores

class UNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()

        # Define a helper function for creating a block
        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Dropout(p=0.1)
            )

        # Encoder
        self.e11 = conv_block(1, 64)
        self.e12 = conv_block(64, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e21 = conv_block(64, 128)
        self.e22 = conv_block(128, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e31 = conv_block(128, 256)
        self.e32 = conv_block(256, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e41 = conv_block(256, 512)
        self.e42 = conv_block(512, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e51 = conv_block(512, 1024)
        self.e52 = conv_block(1024, 1024)

        # Decoder
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.d11 = conv_block(1024, 512)
        self.d12 = conv_block(512, 512)

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.d21 = conv_block(512, 256)
        self.d22 = conv_block(256, 256)

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.d31 = conv_block(256, 128)
        self.d32 = conv_block(128, 128)

        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.d41 = conv_block(128, 64)
        self.d42 = conv_block(64, 64)

        # Output layer
        self.outconv = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        # Encoder
        xe11 = self.e11(x)
        xe12 = self.e12(xe11)
        xp1 = self.pool1(xe12)

        xe21 = self.e21(xp1)
        xe22 = self.e22(xe21)
        xp2 = self.pool2(xe22)

        xe31 = self.e31(xp2)
        xe32 = self.e32(xe31)
        xp3 = self.pool3(xe32)

        xe41 = self.e41(xp3)
        xe42 = self.e42(xe41)
        xp4 = self.pool4(xe42)

        xe51 = self.e51(xp4)
        xe52 = self.e52(xe51)

        # Decoder
        xu1 = self.upconv1(xe52)
        xu11 = torch.cat([xu1, xe42], dim=1)
        xd11 = self.d11(xu11)
        xd12 = self.d12(xd11)

        xu2 = self.upconv2(xd12)
        xu22 = torch.cat([xu2, xe32], dim=1)
        xd21 = self.d21(xu22)
        xd22 = self.d22(xd21)

        xu3 = self.upconv3(xd22)
        xu33 = torch.cat([xu3, xe22], dim=1)
        xd31 = self.d31(xu33)
        xd32 = self.d32(xd31)

        xu4 = self.upconv4(xd32)
        xu44 = torch.cat([xu4, xe12], dim=1)
        xd41 = self.d41(xu44)
        xd42 = self.d42(xd41)

        # Output layer
        out = self.outconv(xd42)

        return out

In [4]:
from torch.utils.data import DataLoader, random_split
from torch import optim
import torch
import os
import matplotlib.pyplot as plt

def dice_coefficient(predicted, target, num_classes):
    dice_scores = []  # To store dice coefficient for each class

    # Convert predictions and targets to one-hot encoded form
    predicted_one_hot = F.one_hot(predicted, num_classes).permute(0, 3, 1, 2).float()
    target_one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()

    # Calculate Dice coefficient for each class
    for class_index in range(num_classes):
        intersection = (predicted_one_hot[:, class_index, :, :] * target_one_hot[:, class_index, :, :]).sum()
        union = predicted_one_hot[:, class_index, :, :].sum() + target_one_hot[:, class_index, :, :].sum()
        dice_score = (2 * intersection + 1e-6) / (union + 1e-6)  # Adding a small epsilon to avoid division by zero
        dice_scores.append(dice_score)

    # Average Dice score across all classes
    avg_dice_score = sum(dice_scores) / len(dice_scores)
    return avg_dice_score.item()  # Return the value as a Python scalar

def get_image_paths(data_dir, label_dir):
    data_paths = [os.path.join(data_dir, img) for img in sorted(os.listdir(data_dir))]
    label_paths = [os.path.join(label_dir, lbl) for lbl in sorted(os.listdir(label_dir))]
    return data_paths, label_paths

def create_subsets(dataset, subset_sizes):
    subsets = {}
    for size in subset_sizes:
        if size == len(dataset):
            subsets[size] = dataset  # Use the full dataset
        else:
            subset, _ = random_split(dataset, [size, len(dataset) - size])
            subsets[size] = subset
    return subsets


In [5]:
# Define your dataset paths
data_dir = '/content/gdrive/MyDrive/training_dataset/data_crop64/'
label_dir = '/content/gdrive/MyDrive/training_dataset/label_crop64/'

# Get image paths and create the full dataset
image_paths, label_paths = get_image_paths(data_dir, label_dir)
dataset = CustomDataset(image_paths=image_paths, label_paths=label_paths)

# Define subset sizes including the full dataset size
subset_sizes = [50, 125, 250, len(dataset)]  # Add the full dataset size

# Create subsets
dataset_subsets = create_subsets(dataset, subset_sizes)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

# Training configurations
learning_rate = 0.001
num_epochs = 250  # Adjust as needed

# Loop over subsets and train the model
for size, subset in dataset_subsets.items():
    print(f"\nTraining on subset size: {size}")

    # Split the subset into training, validation, and test datasets
    train_size = int(0.70 * len(subset))
    val_size = int(0.15 * len(subset))
    test_size = len(subset) - train_size - val_size
    train_dataset, val_dataset, test_dataset = random_split(subset, [train_size, val_size, test_size])

    # DataLoader setup
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Model, loss function, and optimizer setup
    model = UNet(n_class=3).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            labels = labels.squeeze(1).long()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item()}")

        # Validation phase
        model.eval()
        with torch.no_grad():
            val_loss = 0
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                labels = labels.squeeze(1).long()
                loss = criterion(outputs, labels)
                val_loss += loss.item()
            val_loss /= len(val_loader)
            print(f"Validation Loss after Epoch {epoch+1}: {val_loss}")

    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        test_loss = 0
        correct = 0
        total = 0
        dice_scores = []

        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)
            _, predicted = torch.max(probabilities, 1)
            labels = labels.squeeze(1).long()

            loss = criterion(outputs, labels)
            test_loss += loss.item()
            total += labels.numel()
            correct += (predicted == labels).sum().item()

            dice_score = dice_coefficient(predicted, labels, num_classes=3)
            dice_scores.append(dice_score)

        test_loss /= len(test_loader)
        test_accuracy = 100 * correct / total
        average_dice_score = sum(dice_scores) / len(dice_scores)

        print(f"Subset size {size} - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%, Average Dice Score: {average_dice_score:.4f}")


Using cuda

Training on subset size: 50
Epoch 1/250, Batch 1/2, Loss: 1.0780671834945679
Validation Loss after Epoch 1: 1.0825315713882446
Epoch 2/250, Batch 1/2, Loss: 0.42933645844459534
Validation Loss after Epoch 2: 1.074283480644226
Epoch 3/250, Batch 1/2, Loss: 0.2991139888763428
Validation Loss after Epoch 3: 1.0712789297103882
Epoch 4/250, Batch 1/2, Loss: 0.2659841477870941
Validation Loss after Epoch 4: 1.0289161205291748
Epoch 5/250, Batch 1/2, Loss: 0.24353034794330597
Validation Loss after Epoch 5: 0.9424404501914978
Epoch 6/250, Batch 1/2, Loss: 0.21740645170211792
Validation Loss after Epoch 6: 0.82953941822052
Epoch 7/250, Batch 1/2, Loss: 0.19537435472011566
Validation Loss after Epoch 7: 0.7662394642829895
Epoch 8/250, Batch 1/2, Loss: 0.18010041117668152
Validation Loss after Epoch 8: 0.726604163646698
Epoch 9/250, Batch 1/2, Loss: 0.1640089899301529
Validation Loss after Epoch 9: 0.7192140817642212
Epoch 10/250, Batch 1/2, Loss: 0.1514086127281189
Validation Loss af

In [6]:
torch.save(model.state_dict(), '64x64smaller_subsets_with_250_Epoch.pth')