In [None]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.transforms import v2
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel

In [None]:
class SegmentationDataset(Dataset):
    """Paired image / binary-mask dataset read from two directories.

    Files are paired by position after sorting both directory listings, so the
    two directories must contain the same number of files in matching order.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the mask images.
        transform: Optional callable invoked as
            ``transform(image=<H x W x 3 array>, mask=<H x W float32 array>)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        image_transform: Optional callable applied to the image only, after
            ``transform``.

    Raises:
        ValueError: If the two directories hold a different number of files,
            because index-based pairing would then be misaligned.
    """

    def __init__(self, image_dir, mask_dir, transform=None, image_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_transform = image_transform

        # Sort both listings so that position i refers to the same sample in each.
        self.images = sorted(os.listdir(self.image_dir))
        self.masks = sorted(os.listdir(self.mask_dir))

        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Image/mask count mismatch: {len(self.images)} files in "
                f"{self.image_dir!r} vs {len(self.masks)} files in {self.mask_dir!r}"
            )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, binarise and (optionally) transform the sample at ``idx``.

        Returns:
            ``(image, mask)``. The mask holds only the values 0.0 and 1.0
            (float32); every non-zero source pixel counts as foreground.
        """
        img_name = os.path.join(self.image_dir, self.images[idx])
        mask_name = os.path.join(self.mask_dir, self.masks[idx])

        # Context managers release the file handles as soon as pixels are read.
        with Image.open(img_name) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_name) as mask_file:
            mask = mask_file.convert("L")  # Convert to grayscale mask

        # Any non-zero grey level becomes 1.0; float32 matches the model dtype.
        mask = (np.array(mask) > 0).astype(np.float32)

        if self.transform:
            augmented = self.transform(image=np.array(image), mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        if self.image_transform:
            image = self.image_transform(image)

        return image, mask


# Geometric augmentations applied jointly to image and mask (training only).
common_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    ToTensorV2()
], additional_targets={'image': 'image', 'mask': 'mask'})

# Photometric augmentations applied to the image only (training only).
image_transform = v2.Compose([
        v2.ColorJitter(brightness = .5),
        v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2.5)),
        v2.ToDtype(torch.float32, scale=True)
])

# Deterministic counterparts for evaluation: tensor conversion and float
# scaling only, so test metrics are reproducible and computed on the
# un-augmented images.
eval_transform = A.Compose([ToTensorV2()])
eval_image_transform = v2.Compose([
    v2.ToDtype(torch.float32, scale=True)
])

train_image_dir = "../../images"
train_mask_dir = "../../masks"

# NOTE(review): these are the same directories as the training split, so the
# "test" metrics are measured on training data. Point them at a held-out split.
test_image_dir = "../../images"
test_mask_dir = "../../masks"

train_dataset = SegmentationDataset(train_image_dir, train_mask_dir, transform=common_transform, image_transform=image_transform)
test_dataset = SegmentationDataset(test_image_dir, test_mask_dir, transform=eval_transform, image_transform=eval_image_transform)


In [None]:
def display_images_masks(dataset, start_idx, num_images):
    """Show consecutive dataset samples as image / mask pairs, one pair per row.

    Args:
        dataset: Indexable object returning ``(image, mask)`` for an integer index.
        start_idx: Index of the first sample to display.
        num_images: Number of consecutive samples to display.
    """
    # squeeze=False keeps `ax` two-dimensional even when num_images == 1,
    # so the ax[i, j] indexing below is always valid.
    fig, ax = plt.subplots(num_images, 2, figsize=(12, 6*num_images), squeeze=False)

    for i in range(num_images):
        image, mask = dataset[start_idx + i]

        # Convert tensors to numpy arrays for visualization
        if isinstance(image, torch.Tensor):
            # Channel-first tensor -> channel-last array, as imshow expects.
            image = image.permute(1, 2, 0).numpy()
        if isinstance(mask, torch.Tensor):
            mask = mask.squeeze().numpy()

        ax[i, 0].imshow(image)
        ax[i, 0].set_title(f'Image {start_idx + i}')
        ax[i, 1].imshow(mask, cmap='gray')
        ax[i, 1].set_title(f'Mask {start_idx + i}')

    plt.tight_layout()
    plt.show()

# Display 20 image/mask pairs from the training dataset, starting at index 10
display_images_masks(train_dataset, 10, 20)


In [None]:
# Fetch one training sample and report its shapes and the distinct values it holds.
sample_image, sample_mask = train_dataset[0]
image_shape, mask_shape = sample_image.shape, sample_mask.shape
print("Sample Image Shape:", image_shape)
print("Sample Mask Shape:", mask_shape)
mask_values = np.unique(sample_mask)
print("Unique Values in mask:", mask_values)
image_values = np.unique(sample_image)
print("Unique Values in image:", image_values)

In [None]:
# Number of samples in the train and test datasets, shown as a pair.
tuple(len(split) for split in (train_dataset, test_dataset))

In [None]:
# Shapes of the first training sample.
sample_image, sample_mask = train_dataset[0]
for caption, tensor in (("Sample Image Shape:", sample_image),
                        ("Sample Mask Shape:", sample_mask)):
    print(caption, tensor.shape)

In [None]:
# Shapes of the first test sample.
sample_image, sample_mask = test_dataset[0]
for caption, tensor in (("Sample Image Shape:", sample_image),
                        ("Sample Mask Shape:", sample_mask)):
    print(caption, tensor.shape)

In [None]:
!pip install segmentation-models-pytorch

In [None]:
import segmentation_models_pytorch as smp

# Settings for the model's auxiliary output. Because `aux_params` is passed,
# the model's forward call yields a pair, which the cells below unpack as
# `outputs, _ = model(images)`; the second element is discarded throughout
# this notebook.
# NOTE(review): per the segmentation_models_pytorch docs this configures an
# auxiliary classification head built on the encoder — confirm it is wanted,
# since its output is never used here.
aux_params=dict(
    classes=1,
    dropout=0.2,
    activation=None,
)

# U-Net segmentation model with a ResNet-152 encoder.
model = smp.Unet(
    encoder_name="resnet152",  # Encoder backbone (ResNet-152)
    encoder_weights= None,  # None: no pre-trained encoder weights are requested
    in_channels=3,  # Number of input channels (RGB)
    aux_params=aux_params,
)

In [None]:
# Print a summary of the model for an input of size (64, 3, 224, 224).
summary(model, input_size=(64, 3, 224, 224))

In [None]:
# Hyperparameters
batch_size = 32
learning_rate = 0.001
num_epochs = 100

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the target device *before* building the optimizer, as the
# torch.optim documentation recommends.
model.to(device)

# Loss and optimizer
# NOTE: the training loop below optimises `dice_loss`; `criterion` is kept for
# reference but is not used by it.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Wrap only once so that re-running this cell does not nest DataParallel
# wrappers (which would also stack "module." prefixes in the state_dict keys).
if not isinstance(model, DataParallel):
    model = DataParallel(model)

In [None]:
# Number of batches per pass over the train and test loaders, shown as a pair.
tuple(len(loader) for loader in (train_loader, test_loader))

In [None]:
def dice_loss(pred, target, epsilon=1e-6):
    """Soft Dice loss, ``1 - Dice coefficient``, computed over all elements.

    Args:
        pred: Tensor of predicted probabilities.
        target: Tensor of ground-truth values. It may omit (or carry extra)
            singleton dimensions relative to ``pred`` -- e.g. ``(B, H, W)``
            against ``(B, 1, H, W)`` -- and is then reshaped to ``pred``'s shape.
        epsilon: Smoothing term that keeps the ratio defined when both
            tensors are all zeros.

    Returns:
        Scalar tensor holding the loss.

    Raises:
        ValueError: If the shapes differ by more than singleton dimensions.
    """
    if pred.shape != target.shape:
        # Without this, (B, 1, H, W) * (B, H, W) silently broadcasts to
        # (B, B, H, W) and the intersection sums every prediction against
        # every mask in the batch.
        pred_core = tuple(d for d in pred.shape if d != 1)
        target_core = tuple(d for d in target.shape if d != 1)
        if pred_core != target_core:
            raise ValueError(
                f"pred shape {tuple(pred.shape)} and target shape "
                f"{tuple(target.shape)} are not compatible"
            )
        target = target.reshape(pred.shape)

    # Keep the arithmetic in the prediction's dtype (masks may arrive as float64).
    target = target.to(dtype=pred.dtype)

    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    dice_coeff = (2. * intersection + epsilon) / (union + epsilon)
    return 1 - dice_coeff

In [None]:
import copy

In [None]:
losses = []        # one entry per optimisation step (batch)
epoch_losses = []  # mean training loss of each epoch
best_loss = float('inf')
best_model_wts = copy.deepcopy(model.state_dict())

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device, dtype=torch.float32)

        # Forward pass (the second element of the model output is unused)
        outputs, _ = model(images)

        # Give the masks the same rank as the outputs; otherwise the
        # element-wise product inside dice_loss broadcasts across the batch.
        if masks.dim() == outputs.dim() - 1:
            masks = masks.unsqueeze(1)

        # Calculate loss
        loss = dice_loss(torch.sigmoid(outputs), masks)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        losses.append(batch_loss)
        running_loss += batch_loss
        num_batches += 1

    # Judge the epoch by its mean loss; a single mini-batch is too noisy a
    # criterion for choosing which weights to keep.
    epoch_loss = running_loss / max(num_batches, 1)
    epoch_losses.append(epoch_loss)

    if num_batches and epoch_loss < best_loss:
        best_loss = epoch_loss
        best_model_wts = copy.deepcopy(model.state_dict())

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Save the best model
torch.save(best_model_wts, "unet_coastline_detection_best.pth")

In [None]:
# Function to display images, masks, and predictions
def visualize_results(model, test_loader, num_images=40):
    """Plot input image, ground-truth mask and predicted mask side by side.

    Args:
        model: Network whose forward call returns a pair; the first element is
            treated as the segmentation logits.
        test_loader: Iterable yielding ``(images, masks)`` batches.
        num_images: Maximum number of samples to plot (one figure per sample).
    """
    model.eval()
    shown = 0

    with torch.no_grad():
        for images, masks in test_loader:
            if shown >= num_images:
                break

            images = images.to(device)

            # Forward pass
            outputs, _ = model(images)
            predicted_masks = torch.sigmoid(outputs)

            # Channel-first batches -> channel-last arrays for imshow.
            images = np.transpose(images.cpu().numpy(), (0, 2, 3, 1))
            masks = masks.cpu().numpy()
            predicted_masks = np.transpose(predicted_masks.cpu().numpy(), (0, 2, 3, 1))

            for image, mask, predicted in zip(images, masks, predicted_masks):
                if shown >= num_images:
                    break

                # Stretch each image to the 0-1 range on its own, guarding
                # against a zero range for constant images.
                low, high = image.min(), image.max()
                if high > low:
                    image = (image - low) / (high - low)
                else:
                    image = np.zeros_like(image)

                fig, axes = plt.subplots(1, 3, figsize=(15, 5))

                # Display original image
                axes[0].imshow(image)
                axes[0].set_title("Original Image")
                axes[0].axis("off")

                # Display ground truth mask
                axes[1].imshow(mask.squeeze(), cmap='gray')
                axes[1].set_title("Ground Truth Mask")
                axes[1].axis("off")

                # Display predicted mask
                axes[2].imshow(predicted.squeeze(), cmap='gray')
                axes[2].set_title("Predicted Mask")
                axes[2].axis("off")

                plt.show()
                plt.close(fig)  # release the figure; many are created in this loop
                shown += 1

# Visualize the model's performance on the test dataset
visualize_results(model, test_loader)


In [None]:
def calculate_metrics(pred, target, threshold=0.5):
    """Compute binary segmentation metrics over all elements of the inputs.

    Args:
        pred: Tensor of predicted probabilities; values above ``threshold``
            count as foreground.
        target: Ground-truth tensor; values above 0.5 count as foreground. It
            may differ from ``pred`` by singleton dimensions only and is then
            reshaped to ``pred``'s shape.
        threshold: Cut-off applied to ``pred``.

    Returns:
        ``(dice, precision, recall, f1_score, accuracy)`` as Python floats.

    Raises:
        ValueError: If the shapes differ by more than singleton dimensions.
    """
    if pred.shape != target.shape:
        # Align explicitly instead of relying on broadcasting, which would
        # pair every prediction with every target for batched inputs.
        pred_core = tuple(d for d in pred.shape if d != 1)
        target_core = tuple(d for d in target.shape if d != 1)
        if pred_core != target_core:
            raise ValueError(
                f"pred shape {tuple(pred.shape)} and target shape "
                f"{tuple(target.shape)} are not compatible"
            )
        target = target.reshape(pred.shape)

    eps = 1e-6  # keeps every ratio defined when a denominator would be zero

    pred_binary = (pred > threshold).float()
    target_binary = (target > 0.5).float()

    # Each reduction is computed once and reused below.
    intersection = torch.sum(pred_binary * target_binary)
    pred_total = torch.sum(pred_binary)
    target_total = torch.sum(target_binary)

    # Dice Coefficient
    dice = (2 * intersection + eps) / (pred_total + target_total + eps)

    # Precision
    precision = (intersection + eps) / (pred_total + eps)

    # Recall
    recall = (intersection + eps) / (target_total + eps)

    # F1 Score
    f1_score = 2 * ((precision * recall) / (precision + recall + eps))

    # Accuracy
    accuracy = torch.sum((pred_binary == target_binary).float()) / target_binary.numel()

    return dice.item(), precision.item(), recall.item(), f1_score.item(), accuracy.item()

In [None]:
def evaluate_model(model, test_loader, threshold=0.5):
    """Average per-image segmentation metrics over a data loader.

    Args:
        model: Network whose forward call returns a pair; the first element is
            treated as the segmentation logits.
        test_loader: Iterable yielding ``(images, masks)`` batches.
        threshold: Probability cut-off passed on to ``calculate_metrics``.

    Returns:
        Dict with the averaged ``dice``, ``precision``, ``recall``,
        ``f1_score`` and ``accuracy``. The same values are also printed.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()

    metric_names = ("dice", "precision", "recall", "f1_score", "accuracy")
    per_image = {name: [] for name in metric_names}

    with torch.no_grad():
        for images, masks in test_loader:
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs, _ = model(images)
            pred_masks = torch.sigmoid(outputs)

            # Calculate metrics image by image
            for i in range(images.size(0)):
                values = calculate_metrics(pred_masks[i], masks[i], threshold)
                for name, value in zip(metric_names, values):
                    per_image[name].append(value)

    if not per_image["dice"]:
        # Explicit error instead of a ZeroDivisionError in the averages below.
        raise ValueError("test_loader yielded no samples; cannot compute average metrics")

    # Calculate average scores
    averages = {name: sum(values) / len(values) for name, values in per_image.items()}

    print(f"Avg Dice Coefficient: {averages['dice']:.4f}")
    print(f"Avg Precision: {averages['precision']:.4f}")
    print(f"Avg Recall: {averages['recall']:.4f}")
    print(f"Avg F1 Score: {averages['f1_score']:.4f}")
    print(f"Avg Accuracy: {averages['accuracy']:.4f}")

    return averages

# Evaluate the model on the test dataset (kept in a variable for later use)
test_metrics = evaluate_model(model, test_loader)

In [None]:
# `losses` receives one value per training batch, so the x-axis counts
# optimisation steps rather than epochs.
iterations = range(1, len(losses) + 1)

# Plotting the loss curve
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(iterations, losses, label='Training Loss', color='b')
ax.set_title('Training Loss Curve')
ax.set_xlabel('Training iteration (batch)')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)
plt.show()