In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from tqdm import tqdm

  check_for_updates()


In [2]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes
    (``in_channels -> out_channels -> out_channels``). The convolutions carry
    no bias term.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in_channels in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(
                    stage_in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                )
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        # Kept under the attribute name ``conv`` with six positional entries.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BatchNorm/ReLU stages."""
        features = self.conv(x)
        return features

class UNET(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels produced by the final 1x1 convolution.
        features: Channel widths of the encoder levels, shallowest first. The
            decoder mirrors them in reverse order and the bottleneck uses twice
            the last width.
    """

    def __init__(
        self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        super().__init__()

        # Contracting path: one DoubleConv per level.
        self.downs = nn.ModuleList()

        # Expanding path: the i-th transpose convolution is paired with the
        # i-th DoubleConv.
        self.ups_transpose = nn.ModuleList()  # ConvTranspose2d layers
        self.ups_conv = nn.ModuleList()       # DoubleConv blocks

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down path of U-Net
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up path of U-Net (upsampling)
        for feature in reversed(features):
            self.ups_transpose.append(
                nn.ConvTranspose2d(
                    feature * 2, feature, kernel_size=2, stride=2,
                )
            )
            # Input is feature*2 wide because the skip connection is concatenated.
            self.ups_conv.append(DoubleConv(feature * 2, feature))

        # Bottleneck between the two paths
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Final 1x1 convolution to get the output channels
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def down(self, x, skip_connections):
        """Run the contracting path.

        Each level's ``DoubleConv`` output is appended to ``skip_connections``
        (the list is mutated in place) before it is max-pooled.

        Returns:
            The pooled output of the last encoder level.
        """
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        return x

    def up(self, x, skip_connections):
        """Run the expanding path.

        ``skip_connections`` is expected in encoder order (shallowest first);
        it is consumed deepest-first and is not modified.

        Returns:
            The output of the last decoder ``DoubleConv``.
        """
        skip_connections = skip_connections[::-1]  # deepest level first
        for idx, (upsample, refine) in enumerate(zip(self.ups_transpose, self.ups_conv)):
            x = upsample(x)
            skip_connection = skip_connections[idx]

            # Max-pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip connection when a side is not divisible by two.
            # Resize it to the skip connection's height and width.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = nn.functional.interpolate(
                    x,
                    size=skip_connection.shape[2:],
                    mode="bilinear",
                    align_corners=False,
                )

            # Concatenate along the channel dimension, skip connection first.
            concat_skip = torch.cat((skip_connection, x), dim=1)

            # Refine the concatenated feature map.
            x = refine(concat_skip)

        return x

    def forward(self, x):
        """Encoder -> bottleneck -> decoder -> 1x1 convolution."""
        skip_connections = []
        x = self.down(x, skip_connections)
        x = self.bottleneck(x)
        x = self.up(x, skip_connections)
        return self.final_conv(x)

In [3]:
class CarvanaDataset(Dataset):
    """Image/mask pairs read from two directories.

    For an image file ``<name>.jpg`` in ``image_dir`` the mask is looked up as
    ``<name>_mask.gif`` in ``mask_dir``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sort for a reproducible index -> file mapping (os.listdir order is
        # arbitrary). Skip hidden entries and sub-directories such as
        # ``.ipynb_checkpoints`` or ``.DS_Store``, which cannot be opened as images.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair at ``index``.

        The image is loaded as an RGB array. The mask is loaded as a float32
        greyscale array in which the value 255 is replaced by 1. When a
        transform is set, both are passed through it.
        """
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, image_name.replace(".jpg", "_mask.gif"))

        # Context managers close the underlying files deterministically.
        with Image.open(img_path) as image_file:
            image = np.array(image_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

# Step 3: Resolve the dataset folders relative to the current working directory
# (for a Jupyter notebook this is normally the directory the kernel started in).
current_dir = os.getcwd()
(
    train_images_dir,
    train_masks_dir,
    verification_images_dir,
    verification_masks_dir,
) = (
    os.path.join(current_dir, folder_name)
    for folder_name in ("train", "train_masks", "validation", "validation_masks")
)

def get_loaders(
    train_images_dir,
    train_masks_dir,
    verification_images_dir,
    verification_masks_dir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build the training and validation ``DataLoader`` objects.

    Both loaders wrap a ``CarvanaDataset`` and share ``batch_size``,
    ``num_workers`` and ``pin_memory``. Only the training loader shuffles.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    # Settings common to both loaders.
    common_loader_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }

    # Training split: shuffled.
    train_loader = DataLoader(
        CarvanaDataset(
            image_dir=train_images_dir,
            mask_dir=train_masks_dir,
            transform=train_transform,
        ),
        shuffle=True,
        **common_loader_options,
    )

    # Validation split: kept in dataset order.
    val_loader = DataLoader(
        CarvanaDataset(
            image_dir=verification_images_dir,
            mask_dir=verification_masks_dir,
            transform=val_transform,
        ),
        shuffle=False,
        **common_loader_options,
    )

    return train_loader, val_loader

In [4]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook can still be executed on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 4
NUM_WORKERS = 2
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 240  # 1918 originally
# Pinned host memory only helps host-to-GPU copies, so enable it only on CUDA.
PIN_MEMORY = device.type == 'cuda'

# Training pipeline: resize, random geometric augmentation, scale pixel values
# by 1/255 (mean 0, std 1), then convert to tensors.
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

# Validation pipeline: same resize and scaling, no random augmentation.
val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

train_loader, val_loader = get_loaders(
    train_images_dir,          # Path to training images
    train_masks_dir,           # Path to training masks
    verification_images_dir,   # Path to validation images
    verification_masks_dir,    # Path to validation masks
    batch_size,
    train_transform,           # Training transformations
    val_transforms,            # Validation transformations
    NUM_WORKERS,
    PIN_MEMORY,
)

In [5]:
# Load the pretrained model, which was saved as a complete pickled module.
# NOTE(review): weights_only=False is needed to unpickle a whole model object,
# but it lets the checkpoint execute arbitrary code while loading. Only use
# checkpoint files from a trusted source.
model = torch.load('unet_complete_model.pth', map_location=device, weights_only=False)
model = model.to(device)
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

UNET(
  (downs): ModuleList(
    (0): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=T

In [6]:
def evaluate(loader, delta=None):
    """Average per-batch Dice score of the global ``model`` over ``loader``.

    Uses the notebook-level ``model`` and ``device``. Predictions are the
    sigmoid of the model output thresholded at 0.5.

    Args:
        loader: Iterable yielding ``(images, masks)`` batches. Masks get a
            channel dimension inserted at position 1 before scoring.
        delta: Optional perturbation added to every image batch; the sum is
            clamped to ``[0, 1]``.

    Returns:
        The mean of the per-batch Dice scores as a Python float.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    dice_scores = []

    # Remember the current mode so evaluation does not leave a pretrained
    # model in training mode (which would change BatchNorm behaviour).
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for images, masks in loader:
                images, masks = images.to(device), masks.to(device)

                if delta is not None:
                    images = torch.clamp(images + delta, 0, 1)

                outputs = model(images)

                # Apply sigmoid and threshold
                preds = torch.sigmoid(outputs)
                preds = (preds > 0.5).float()

                # Reshape masks to match preds: (B, H, W) -> (B, 1, H, W)
                masks = masks.unsqueeze(1)

                dice = dice_score(preds, masks)
                dice_scores.append(dice.item())
    finally:
        # Restore the mode the model was in when evaluate() was called.
        model.train(was_training)

    if not dice_scores:
        raise ValueError("evaluate() received an empty loader; no Dice score could be computed")

    return sum(dice_scores) / len(dice_scores)

def dice_score(pred, target):
    """Smoothed Dice coefficient between two tensors of equal size.

    Both inputs are flattened, so the score is computed over all elements at
    once (for a batch: over the whole batch, not per sample). A smoothing term
    of 1e-5 in numerator and denominator keeps the result defined when both
    inputs are all zeros.
    """
    smooth = 1e-5
    flat_pred = pred.reshape(-1)
    flat_target = target.reshape(-1)
    overlap = torch.sum(flat_pred * flat_target)
    numerator = 2. * overlap + smooth
    denominator = flat_pred.sum() + flat_target.sum() + smooth
    return numerator / denominator

In [7]:
def deepfool_attack(model, image, mask, num_classes=2, max_iter=50, overshoot=0.02, device='cuda'):
    """
    DeepFool-style iterative gradient attack for binary segmentation models.

    Each iteration steps along the normalised gradient of the binary
    cross-entropy loss with respect to the input. The accumulated perturbation
    is clamped element-wise to ``[-overshoot, overshoot]`` and the perturbed
    image to ``[0, 1]``. The loop ends after ``max_iter`` iterations, or earlier
    once every pixel of the thresholded prediction disagrees with the target.

    Args:
        model: Callable module returning logits with the same shape as ``mask``.
            It is put into eval mode.
        image: Input batch; the caller's tensor is not modified.
        mask: Ground-truth mask; values above 0.5 are treated as foreground.
        num_classes: Unused; kept for interface compatibility.
        max_iter: Maximum number of gradient steps.
        overshoot: Element-wise bound on the accumulated perturbation.
        device: Device the attack is computed on.

    Returns:
        ``(pert_image, r_total)``: the perturbed image and the accumulated,
        clamped perturbation, both detached. ``r_total`` may differ from
        ``pert_image - image`` where the ``[0, 1]`` clamp was active.
    """
    model.eval()
    image = image.clone().detach().to(device)
    mask = mask.clone().detach().to(device)

    # Binary target (0 or 1); it does not change between iterations.
    target = (mask > 0.5).float()

    pert_image = image.clone()
    r_total = torch.zeros_like(image)

    with torch.enable_grad():
        for _ in range(max_iter):
            pert_image.requires_grad_(True)
            output = model(pert_image)

            # Stop once the attack has fully succeeded, i.e. no pixel is still
            # predicted correctly. (Stopping while the prediction is correct
            # would leave perfectly segmented inputs unattacked.)
            pred = (torch.sigmoid(output) > 0.5).float()
            if (pred != target).all():
                break

            # Gradient of the loss with respect to the input only.
            loss = torch.nn.functional.binary_cross_entropy_with_logits(output, target)
            grad = torch.autograd.grad(loss, pert_image)[0]

            # Unit-norm direction scaled by the current loss value. The loss is
            # detached so r_total does not accumulate autograd history.
            w = grad / (grad.norm() + 1e-8)
            r_i = (loss.detach() + 1e-4) * w

            # Accumulate the perturbation within the allowed budget.
            r_total = (r_total + r_i).clamp(-overshoot, overshoot)
            pert_image = torch.clamp(image + r_total, 0, 1).detach()

    return pert_image.detach(), r_total.detach()

In [8]:
def plot_and_save_results(image, pert_image, clean_pred, pert_pred, mask, save_dir, filename):
    """Save a 2x3 figure comparing a clean and a perturbed sample.

    Panels: original image, perturbed image, normalised perturbation, clean
    prediction, perturbed prediction and ground-truth mask.

    Args:
        image: Clean image tensor in channel-first layout.
        pert_image: Perturbed image tensor with the same layout as ``image``.
        clean_pred: Model prediction for the clean image.
        pert_pred: Model prediction for the perturbed image.
        mask: Ground-truth mask.
        save_dir: Output directory; created if it does not exist.
        filename: File name of the saved figure inside ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Convert tensors to channel-last numpy arrays for imshow.
    image_np = image.detach().cpu().permute(1, 2, 0).numpy()
    pert_image_np = pert_image.detach().cpu().permute(1, 2, 0).numpy()
    perturbation_np = pert_image_np - image_np

    # Map the perturbation symmetrically from [-max, max] to [0, 1]. An
    # all-zero perturbation would divide by zero, so show it as the neutral
    # midpoint instead.
    max_val = np.abs(perturbation_np).max()
    if max_val > 0:
        perturbation_normalized = (perturbation_np + max_val) / (2 * max_val)
    else:
        perturbation_normalized = np.full_like(perturbation_np, 0.5)

    # Predictions and mask are shown as 2-D greyscale images.
    clean_pred_np = clean_pred.detach().cpu().squeeze().numpy()
    pert_pred_np = pert_pred.detach().cpu().squeeze().numpy()
    mask_np = mask.detach().cpu().squeeze().numpy()

    # (title, data, imshow keyword arguments) in subplot order.
    panels = [
        ('Original Image', image_np, {}),
        ('Perturbed Image', pert_image_np, {}),
        ('Perturbation (Normalized)', perturbation_normalized,
         {'cmap': 'coolwarm', 'vmin': 0, 'vmax': 1}),
        ('Clean Prediction', clean_pred_np, {'cmap': 'gray'}),
        ('Perturbed Prediction', pert_pred_np, {'cmap': 'gray'}),
        ('Ground Truth', mask_np, {'cmap': 'gray'}),
    ]

    fig = plt.figure(figsize=(20, 10))
    try:
        for position, (title, data, imshow_kwargs) in enumerate(panels, start=1):
            plt.subplot(2, 3, position)
            plt.imshow(data, **imshow_kwargs)
            plt.title(title)

        plt.savefig(os.path.join(save_dir, filename))
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

In [9]:
def run_deepfool_attack(model, val_loader, save_dir='deepfool_attack_results', device='cuda',
                        max_plots_per_batch=5):
    """Run the DeepFool attack on a validation loader and report the Dice score.

    For every batch the images are attacked with ``deepfool_attack``, the model
    is evaluated on the perturbed images, and the Dice score against the
    ground-truth masks is recorded. Visualisations are written to ``save_dir``.

    Args:
        model: Segmentation model; moved to ``device`` and put in eval mode.
        val_loader: Iterable yielding ``(images, masks)`` batches. Masks get a
            channel dimension inserted at position 1.
        save_dir: Directory for the saved figures; created if missing.
        device: Device used for the attack and the forward passes.
        max_plots_per_batch: Maximum number of samples visualised from each
            batch (the first ones in the batch).

    Returns:
        The mean of the per-batch Dice scores on the perturbed images.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model = model.to(device).eval()
    os.makedirs(save_dir, exist_ok=True)

    dice_scores = []

    for batch_idx, (images, masks) in enumerate(tqdm(val_loader)):
        images = images.to(device)
        masks = masks.to(device).unsqueeze(1)

        # Generate adversarial examples
        pert_images, _ = deepfool_attack(model, images, masks, device=device)

        with torch.no_grad():
            # Thresholded predictions on clean and perturbed inputs
            clean_outputs = model(images)
            clean_preds = (torch.sigmoid(clean_outputs) > 0.5).float()

            pert_outputs = model(pert_images)
            pert_preds = (torch.sigmoid(pert_outputs) > 0.5).float()

            # Dice score of the attacked predictions for this batch
            dice = dice_score(pert_preds, masks)
            dice_scores.append(dice.item())

        # Save visualisations for the first samples of this batch
        for i in range(min(max_plots_per_batch, images.size(0))):
            plot_and_save_results(
                image=images[i],
                pert_image=pert_images[i],
                clean_pred=clean_preds[i],
                pert_pred=pert_preds[i],
                mask=masks[i],
                save_dir=save_dir,
                filename=f'batch_{batch_idx}_sample_{i}.png'
            )

    if not dice_scores:
        raise ValueError("run_deepfool_attack() received an empty loader; no Dice score could be computed")

    # Calculate final metrics
    avg_dice = np.mean(dice_scores)
    print('\nDeepFool Attack Results:')
    print(f'Average Dice Score: {avg_dice:.4f}')

    return avg_dice

In [10]:
# Run the DeepFool attack on the validation set, using the device configured
# for the notebook rather than the function's own default.
attack_dice = run_deepfool_attack(model, val_loader, device=device)

# Compare with the Dice score on the unperturbed validation set
clean_dice = evaluate(val_loader)
performance_drop = clean_dice - attack_dice
print('\nPerformance Comparison:')
print(f'Clean Dice: {clean_dice:.4f}')
print(f'DeepFool Dice: {attack_dice:.4f}')
print(f'Performance Drop: {performance_drop:.4f}')

100%|█████████████████████████████████████████████████████| 12/12 [03:48<00:00, 19.05s/it]


DeepFool Attack Results:
Average Dice Score: 0.6758






Performance Comparison:
Clean Dice: 0.9837
DeepFool Dice: 0.6758
Performance Drop: 0.3079
