# D2Defend

In [None]:
!pip install -r requirements.txt

In [None]:
!pip install pywavelets

In [None]:
# Standard Library Imports
import os
import logging
from typing import Tuple, List
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np
from sklearn.metrics import (
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score
)
from tqdm import tqdm
import cv2
from skimage.restoration import denoise_wavelet
import zipfile
import pywt
from torch.optim.lr_scheduler import ReduceLROnPlateau


# Model Definition

In [None]:
# Prefer the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Bilateral Filtering Code

def bilateral_filter(image, device, sigma_r=75, sigma_s=75, d=9):
    """
    Apply OpenCV bilateral filtering to a single image.

    The image is converted to 8-bit by multiplying by 255 and clipping to
    [0, 255], so pixel values are expected to lie in [0, 1]; anything outside
    that range is clipped before filtering.

    Args:
        image (torch.Tensor): Input image tensor of shape (C, H, W). C is
            presumably 1 or 3 (the channel counts cv2.bilateralFilter accepts).
        device (torch.device): Device the result is moved to.
        sigma_r (float): Filter sigma in the color (intensity) space, in 0-255 units.
        sigma_s (float): Filter sigma in the coordinate (spatial) space.
        d (int): Diameter of the pixel neighbourhood passed to OpenCV (default 9).

    Returns:
        torch.Tensor: Filtered float32 image of shape (C, H, W) with values in
        [0, 1], on ``device``. The result is not connected to the autograd graph.
    """
    assert isinstance(image, torch.Tensor), "Input must be a PyTorch tensor."
    assert len(image.shape) == 3, "Expected shape (C, H, W) for a single image."

    # detach() lets tensors that require grad (e.g. adversarial inputs) be
    # converted to NumPy; the NumPy round trip is not differentiable anyway.
    image_np = image.detach().permute(1, 2, 0).cpu().numpy() * 255.0
    # The permuted view is not C-contiguous; give OpenCV a contiguous HWC uint8 array.
    image_np = np.ascontiguousarray(np.clip(image_np, 0, 255).astype(np.uint8))

    filtered_np = cv2.bilateralFilter(image_np, d=d, sigmaColor=sigma_r, sigmaSpace=sigma_s)
    # OpenCV drops the channel axis for single-channel images; restore HWC.
    filtered_np = filtered_np.reshape(image_np.shape)

    # Back to a (C, H, W) float tensor scaled to [0, 1].
    filtered_tensor = torch.from_numpy(filtered_np.astype(np.float32) / 255.0).permute(2, 0, 1).to(device)
    return filtered_tensor

def wavelet_shrinkage_color_image(image, device, method='BayesShrink', sigma=None):
    """
    Denoise an image (or batch of images) by wavelet shrinkage, one channel plane at a time.

    Every 2-D (H, W) plane of the input is passed separately to
    ``skimage.restoration.denoise_wavelet`` with soft thresholding. Previously
    the code iterated over ``shape[2]`` and indexed the last axis, which sliced
    the width dimension instead of the channels.

    Args:
        image (torch.Tensor): Tensor of shape (C, H, W) or (B, C, H, W).
        device (torch.device): Device the result is moved to.
        method (str): Thresholding method ('BayesShrink' or 'VisuShrink').
        sigma (float or None): Noise standard deviation; ``None`` lets
            scikit-image estimate it.

    Returns:
        torch.Tensor: Denoised tensor with the same shape and dtype as ``image``,
        on ``device``. It is computed in NumPy, so it is detached from the
        autograd graph and no gradient flows back to ``image``.

    Raises:
        ValueError: If ``image`` is not 3-D or 4-D.
    """
    noisy_image = np.ascontiguousarray(image.detach().cpu().numpy())
    if noisy_image.ndim not in (3, 4):
        raise ValueError(
            f"Expected a (C, H, W) or (B, C, H, W) tensor, got shape {tuple(noisy_image.shape)}"
        )

    # empty_like of a C-contiguous array is C-contiguous, so both reshapes
    # below are views and writing into planes_out fills denoised_image.
    denoised_image = np.empty_like(noisy_image)
    planes_in = noisy_image.reshape(-1, *noisy_image.shape[-2:])
    planes_out = denoised_image.reshape(planes_in.shape)

    for idx, plane in enumerate(planes_in):
        planes_out[idx] = denoise_wavelet(
            plane,
            method=method,
            mode='soft',
            sigma=sigma,
            rescale_sigma=True,
            channel_axis=None,
        )

    return torch.from_numpy(denoised_image).to(device)


# D2Defend Model
class D2Defend(nn.Module):
    """
    D2Defend-style image classifier with an edge/texture denoising front end.

    Pipeline (see ``forward``):
    - ``bilateral_filter`` produces an edge layer; texture = input - edge.
    - ``stft``: two Conv-BN-ReLU layers applied to the texture. This is a learned
      stand-in; no Fourier transform is actually computed.
    - ``wavelet_shrinkage_color_image`` denoises the texture features. It runs in
      NumPy/scikit-image outside autograd.
    - ``istft``: two more Conv-BN-ReLU layers.
    - Edge + processed texture goes to a convolutional classifier and a fully
      connected head.
    """
    def __init__(self, sigma_r, sigma_s, lambda_val, num_classes, device):
        """
        Initialize D2Defend model components.

        Args:
            sigma_r (float): Color-space sigma for bilateral filtering.
            sigma_s (float): Coordinate-space sigma for bilateral filtering.
            lambda_val (float): Regularization parameter (stored only; not used by ``forward``).
            num_classes (int): Number of output classification categories.
            device (torch.device): Device intermediate tensors are placed on.
        """
        super().__init__()
        self.sigma_r = sigma_r
        self.sigma_s = sigma_s
        self.lambda_val = lambda_val
        self.device = device

        # Plain functions stored as attributes (not submodules, no parameters).
        self.bilateral_filter = bilateral_filter
        self.wavelet_shrinkage = wavelet_shrinkage_color_image

        # "STFT" stage: learned 3->3 channel conv stack on the texture layer.
        self.stft = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU()
        )

        # "Inverse STFT" stage: learned 3->3 channel conv stack after shrinkage.
        self.istft = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU()
        )

        # Convolutional feature extractor; AdaptiveAvgPool2d makes the output
        # (B, 512, 1, 1) regardless of the input resolution.
        self.classifier = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.3),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.4),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.4),

            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # Fully connected head: 512 pooled features -> num_classes logits.
        self.fc = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """
        Forward pass through the D2Defend model.

        Args:
            x (torch.Tensor): A batch of shape (B, 3, H, W), or a single image of
                shape (3, H, W), which is treated as a batch of one. The bilateral
                filter clips values to [0, 1].

        Returns:
            torch.Tensor: Logits of shape (B, num_classes); (1, num_classes) for a
            single image.

        Raises:
            ValueError: If ``x`` is neither 3-D nor 4-D.
        """
        if x.dim() == 3:
            # The BatchNorm2d layers require a batch dimension.
            x = x.unsqueeze(0)
        elif x.dim() != 4:
            raise ValueError(
                f"Expected a (C, H, W) or (B, C, H, W) tensor, got shape {tuple(x.shape)}"
            )
        x = x.to(self.device)

        # Edge layer: the OpenCV bilateral filter works on one image at a time.
        x_edge = torch.stack([
            self.bilateral_filter(img, self.device, self.sigma_r, self.sigma_s)
            for img in x
        ])

        # Texture layer: what the bilateral filter removed.
        x_texture = x - x_edge
        x_texture_stft = self.stft(x_texture)

        # NOTE(review): the wavelet step round-trips through NumPy, so no gradient
        # reaches self.stft; its weights are never updated by backprop.
        x_texture_shrinkage = self.wavelet_shrinkage(
            x_texture_stft, device=self.device, method='BayesShrink', sigma=None
        )
        x_texture_istft = self.istft(x_texture_shrinkage)

        # Recombine edge and processed texture, then classify.
        x_defend = x_edge + x_texture_istft
        features = self.classifier(x_defend)
        features = torch.flatten(features, 1)  # (B, 512, 1, 1) -> (B, 512)
        return self.fc(features)

# Model initialization

In [None]:
# Model initialization: hyper-parameters and a first model instance.
sigma_r, sigma_s = 75, 75      # bilateral filter color / spatial sigmas
lambda_val = 1                 # passed to D2Defend (stored there)
num_epochs = 20
early_stop_patience = 5
early_stop_counter = 0
best_loss = float('inf')
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = D2Defend(sigma_r, sigma_s, lambda_val, num_classes=2, device=device)
model = model.to(device)

# Training & Testing

In [None]:
# Extract the dataset archives into the working directory:
#   corrected_dataset.zip -> ./custom_dataset
#   test_dataset.zip      -> ./test_dataset
# NOTE(review): the loading cell reads custom_dataset/train and custom_dataset/test;
# confirm whether ./test_dataset is meant to be used anywhere.
for archive_name, target_dir in (
    ("corrected_dataset.zip", "custom_dataset"),
    ("test_dataset.zip", "test_dataset"),
):
    with zipfile.ZipFile(archive_name, 'r') as zip_ref:
        zip_ref.extractall(target_dir)

In [None]:
# Define transforms for the images.
# No Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) here: D2Defend's bilateral_filter
# converts its input with `* 255` and clips to [0, 255], i.e. it expects pixels in
# [0, 1]. Normalizing to [-1, 1] collapsed every negative value in the edge layer
# and made `x - x_edge` inconsistent.
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) in [0, 1]
])

# Loading datasets (ImageFolder: one sub-directory per class).
test_path = 'custom_dataset/test'
train_path = 'custom_dataset/train'

test_dataset = datasets.ImageFolder(test_path, transform=transform)
train_dataset = datasets.ImageFolder(train_path, transform=transform)

# Split the training dataset 80% / 20% into training and validation.
# A fixed generator seed makes the split reproducible across runs.
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    train_dataset, [train_size, val_size], generator=split_generator
)

# Create DataLoaders (only the training data is shuffled).
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Define early stopping criterion
class EarlyStopping:
    """
    Signal when a monitored loss has stopped improving.

    A call counts as an improvement only if the loss drops below
    ``best_loss - min_delta``. An improvement records the new best and resets
    ``counter`` to 0. Any other call increments ``counter``. Once ``counter``
    reaches ``patience``, the call (and every later non-improving call) returns True.
    """

    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience      # non-improving calls tolerated before stopping
        self.min_delta = min_delta    # minimum decrease that counts as an improvement
        self.best_loss = np.inf
        self.counter = 0

    def __call__(self, loss):
        """Record ``loss``; return True if training should stop."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.counter = 0
            return False
        self.counter += 1
        return bool(self.counter >= self.patience)

import copy

# Define model, optimizer, and criterion
model = D2Defend(sigma_r, sigma_s, lambda_val, num_classes=2, device=device).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Early stopping plus snapshots of the best (lowest validation loss) and last weights.
early_stopping = EarlyStopping(patience=early_stop_patience, min_delta=0.001)
best_weights = None
final_weights = None


def _evaluate(net, data_loader, loss_fn, dev, desc):
    """Return (mean per-batch loss, accuracy in %) of ``net`` over ``data_loader``, without gradients."""
    net.eval()
    total_loss = 0.0
    n_correct = 0
    with torch.no_grad():
        for data, target in tqdm(data_loader, desc=desc):
            data, target = data.to(dev), target.to(dev)
            output = net(data)
            total_loss += loss_fn(output, target).item()
            n_correct += output.argmax(dim=1).eq(target).sum().item()
    # max(..., 1) guards against ZeroDivisionError on an empty loader/dataset.
    avg_loss = total_loss / max(len(data_loader), 1)
    acc = 100. * n_correct / max(len(data_loader.dataset), 1)
    return avg_loss, acc


# Training loop
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    model.train()
    total_train_loss = 0
    for data, target in tqdm(train_loader, desc="Training"):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        clean_output = model(data)
        clean_loss = criterion(clean_output, target)
        # One backward pass per graph, so retain_graph is not needed (it only kept
        # the graph's buffers alive longer).
        clean_loss.backward()
        optimizer.step()
        total_train_loss += clean_loss.item()

    avg_train_loss = total_train_loss / max(len(train_loader), 1)

    # Validation phase
    avg_val_loss, accuracy = _evaluate(model, val_loader, criterion, device, "Validation")
    print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.2f}%")

    # state_dict() returns references to the live parameters, so take a real copy;
    # otherwise every "snapshot" would silently track the current weights.
    final_weights = copy.deepcopy(model.state_dict())
    should_stop = early_stopping(avg_val_loss)
    # EarlyStopping resets its counter to 0 exactly when this loss was an improvement.
    # final_weights is rebound (never mutated) each epoch, so sharing the dict is safe.
    if early_stopping.counter == 0:
        best_weights = final_weights
    if should_stop:
        print("Early stopping!")
        break

# Make inferences on the test dataset using the weights with the lowest validation
# loss, falling back to the last epoch's weights (or the untouched model if no
# epoch ran).
weights_to_test = best_weights if best_weights is not None else final_weights
if weights_to_test is not None:
    model.load_state_dict(weights_to_test)
test_loss, accuracy = _evaluate(model, test_loader, criterion, device, "Test")
print(f"Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%")