In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
from ultralytics import YOLO
from torchvision import transforms
from PIL import Image

### Explanation of DoubleConv, Down, Up, OutConv, and VAE Classes

The `DoubleConv`, `Down`, `Up`, and `OutConv` classes serve as fundamental components in constructing the encoder-decoder architecture utilized in the `VAE` class. These modules facilitate various stages of feature extraction, downsampling, upsampling, and output generation, forming the backbone of the Variational Autoencoder (VAE) designed specifically for watermark removal.

- **DoubleConv** acts as a core building block that sequentially applies two convolutional layers, each followed by Batch Normalization and ReLU activation. This ensures effective feature extraction with normalization for stable training.
- **Down** performs downsampling using a MaxPooling layer, followed by a `DoubleConv` module for extracting hierarchical features while reducing spatial dimensions.
- **Up** focuses on upsampling the feature maps using a transposed convolution, followed by concatenation with skip connections from the encoder. It then applies a `DoubleConv` to refine the upsampled features, enabling high-resolution reconstruction.
- **OutConv** is responsible for producing the final output of the model by reducing the number of channels and applying a Sigmoid activation. This is particularly suited for pixel-level predictions, such as mask generation or image reconstruction.
- **VAE** integrates these components within an encoder-decoder framework and incorporates a latent space for compact representation. Additionally, it leverages a pre-trained YOLO model to identify watermark regions, using these predictions as an extra input channel to enhance the reconstruction process in masked areas.

The VAE model operates in three primary stages: detecting watermarked regions via YOLO, encoding the input image (together with the predicted mask) into a latent representation, and decoding this representation to reconstruct a watermark-free image. The reconstruction is then blended with the input so that only pixels inside the predicted mask are replaced. This modular design, combining feature extraction, downsampling, and upsampling with guidance from YOLO, enables precise and efficient watermark removal.

In [3]:
# ---------------------------------------------------------
# Define the DoubleConv module: Two convolutional layers with BatchNorm and ReLU
# ---------------------------------------------------------

class DoubleConv(nn.Module):
    """
    Two stacked (Conv3x3 -> BatchNorm -> ReLU) stages.

    The 3x3 convolutions use padding=1, so the spatial size of the input is
    preserved; only the channel count changes (in_channels -> out_channels).
    """

    def __init__(self, in_channels, out_channels):
        """
        Build the two convolutional stages.

        Args:
            in_channels (int): Channels entering the first convolution.
            out_channels (int): Channels produced by both convolutions.
        """
        super().__init__()
        stages = []
        # First stage maps in->out, second stage maps out->out. The layer order
        # (and therefore the Sequential indices in state_dict) matches the
        # checkpoints this module is loaded from.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                # Bias is redundant because BatchNorm adds its own shift.
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """
        Apply both convolutional stages.

        Args:
            x (torch.Tensor): Tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Tensor of shape (B, out_channels, H, W).
        """
        return self.double_conv(x)


# ---------------------------------------------------------
# Define the Down module: MaxPooling followed by DoubleConv
# ---------------------------------------------------------

class Down(nn.Module):
    """
    Encoder stage: halve the spatial resolution with 2x2 max pooling, then
    extract features with a DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        """
        Build the downsampling stage.

        Args:
            in_channels (int): Channels of the incoming feature map.
            out_channels (int): Channels produced by the DoubleConv.
        """
        super().__init__()
        pool = nn.MaxPool2d(2)  # (H, W) -> (H // 2, W // 2)
        features = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, features)

    def forward(self, x):
        """
        Downsample and convolve.

        Args:
            x (torch.Tensor): Tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Tensor of shape (B, out_channels, H // 2, W // 2).
        """
        return self.maxpool_conv(x)


# ---------------------------------------------------------
# Define the Up module: Upsampling followed by DoubleConv
# ---------------------------------------------------------

class Up(nn.Module):
    """
    Decoder stage: upsample 2x with a learned transposed convolution, merge the
    result with the matching encoder feature map (skip connection), then refine
    the merged tensor with a DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        """
        Build the upsampling stage.

        Args:
            in_channels (int): Channels of the incoming decoder tensor. The
                transposed convolution reduces this to in_channels // 2; the skip
                connection must supply the remaining channels so that the
                concatenation has `in_channels` channels again.
            out_channels (int): Channels produced by the stage.
        """
        super().__init__()
        halved = in_channels // 2
        self.up = nn.ConvTranspose2d(in_channels, halved, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        Upsample `x1`, align it with `x2`, concatenate, and convolve.

        Args:
            x1 (torch.Tensor): Decoder tensor to be upsampled.
            x2 (torch.Tensor): Encoder tensor from the skip connection.

        Returns:
            torch.Tensor: Refined tensor with `out_channels` channels and the
            spatial size of `x2`.
        """
        upsampled = self.up(x1)

        # Odd input sizes make the upsampled map differ from the skip map by a
        # pixel; pad it (or crop, if the difference is negative) symmetrically.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip connection first, upsampled features second.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


# ---------------------------------------------------------
# Define the OutConv module: Final output layer with Sigmoid activation
# ---------------------------------------------------------

class OutConv(nn.Module):
    """
    Output head: a 1x1 convolution projecting to `out_channels`, followed by a
    sigmoid that squashes every value into (0, 1).
    """

    def __init__(self, in_channels, out_channels):
        """
        Build the output head.

        Args:
            in_channels (int): Channels of the final decoder feature map.
            out_channels (int): Channels of the prediction (e.g. 3 for RGB).
        """
        super().__init__()
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv = nn.Sequential(projection, nn.Sigmoid())

    def forward(self, x):
        """
        Project and squash.

        Args:
            x (torch.Tensor): Tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Tensor of shape (B, out_channels, H, W) with values in (0, 1).
        """
        return self.conv(x)

# ---------------------------------------------------------
# Define the VAE (Variational Autoencoder) Model
# ---------------------------------------------------------

class VAE(nn.Module):
    """
    A custom Variational Autoencoder (VAE) model for watermark removal.

    A YOLO detector predicts rectangular watermark regions, which are rasterized
    into a binary mask and concatenated to the input image as an extra channel.
    A U-Net-style encoder/decoder with a fully connected latent bottleneck then
    reconstructs the image. The final output uses reconstructed pixels only
    inside the mask and keeps the input pixels everywhere else.
    """

    def __init__(self, in_channels=4, out_channels=3, latent_dim=256, input_size=128, yolo_model_path="yolov8s.pt", device='cuda'):
        """
        Initialize the VAE model.

        Args:
            in_channels (int): Number of input channels to the encoder (4 = RGB + mask).
            out_channels (int): Number of output channels (e.g., 3 for RGB).
            latent_dim (int): Dimensionality of the latent space.
            input_size (int): Height and width of the (square) input images. The
                latent linear layers are sized from this, so inputs must match it.
            yolo_model_path (str): Path to the pre-trained YOLO weights.
            device (str): Device passed to YOLO's `predict` call.
        """
        super(VAE, self).__init__()
        yolo = YOLO(yolo_model_path)
        # Bypass nn.Module.__setattr__ so the detector is not registered as a
        # submodule (in case YOLO is itself an nn.Module). Its weights then stay out of this
        # model's state_dict()/parameters() and are unaffected by .train()/.eval()/.to().
        object.__setattr__(self, 'yolo_model', yolo)
        self.device = device
        self.input_size = input_size

        # Encoder layers: four 2x downsamplings -> spatial size input_size // 16.
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        # Latent space: the flattened bottleneck feature map <-> latent vector.
        self.fc_mu = nn.Linear(1024 * (input_size // 16) ** 2, latent_dim)
        self.fc_logvar = nn.Linear(1024 * (input_size // 16) ** 2, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, 1024 * (input_size // 16) ** 2)

        # Decoder layers (each Up consumes the matching encoder skip connection).
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, out_channels)

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample z = mu + eps * sigma with eps ~ N(0, I).

        Sampling happens in both training and eval mode.

        Args:
            mu (torch.Tensor): Mean of the latent distribution.
            logvar (torch.Tensor): Log variance of the latent distribution.

        Returns:
            torch.Tensor: Sampled latent vector with the shape of `mu`.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def create_mask_from_yolo_preds(self, predictions, image_size):
        """
        Rasterize YOLO bounding boxes into a binary mask.

        Box corners are truncated to integers and clamped to [0, image_size].
        Without the clamp, a negative coordinate would index the slice from the
        end and silently drop (part of) the box.

        Args:
            predictions: Iterable of YOLO results, each exposing `.boxes` whose
                items have `.xyxy` (x1, y1, x2, y2 in pixel coordinates).
            image_size (int): Height and width of the output mask.

        Returns:
            torch.Tensor: Float mask of shape (1, image_size, image_size), 1.0
            inside any detected box and 0.0 elsewhere.
        """
        mask = torch.zeros((1, image_size, image_size))
        for det in predictions:
            for box in det.boxes:
                x1, y1, x2, y2 = (
                    min(max(int(coord), 0), image_size) for coord in box.xyxy[0].tolist()
                )
                mask[:, y1:y2, x1:x2] = 1.0
        return mask

    def forward(self, input_images):
        """
        Forward pass of the VAE model.

        Args:
            input_images (torch.Tensor): Batch of images of shape
                (B, 3, input_size, input_size), expected to hold values in [0, 1].

        Returns:
            tuple: (final output, mu, logvar, predicted masks). The final output is
            the input with the masked regions replaced by the reconstruction; the
            masks have shape (B, 1, input_size, input_size).
        """
        batch_size = input_images.size(0)

        # YOLO mask prediction, one image at a time on uint8 HWC numpy arrays.
        predicted_masks = []
        for i in range(batch_size):
            # detach(): .numpy() raises on tensors that require grad.
            # clamp(0, 1): stops out-of-range values from wrapping in the uint8 cast.
            img_np = (
                input_images[i].detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255
            ).astype(np.uint8)
            # NOTE(review): this array keeps the channel order of `input_images`
            # (RGB when fed from PIL/ToTensor), while Ultralytics treats raw numpy
            # inputs as BGR. Confirm this matches how the detector was trained.
            results = self.yolo_model.predict(img_np, imgsz=self.input_size, verbose=False, device=self.device)
            pmask = self.create_mask_from_yolo_preds(results, self.input_size).to(input_images.device)
            predicted_masks.append(pmask)
        predicted_masks = torch.stack(predicted_masks, dim=0)

        # Combine input with predicted masks as an extra channel.
        combined_input = torch.cat([input_images, predicted_masks], dim=1)

        # Encoder
        x1 = self.inc(combined_input)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Latent space (clamped to keep exp(0.5 * logvar) numerically tame).
        x_flat = x5.view(batch_size, -1)
        mu = torch.clamp(self.fc_mu(x_flat), -10, 10)
        logvar = torch.clamp(self.fc_logvar(x_flat), -10, 10)
        z = self.reparameterize(mu, logvar)

        # Decoder: project back to the bottleneck shape, then upsample with skips.
        x_decoded = self.fc_dec(z).view(batch_size, 1024, x5.size(2), x5.size(3))
        x = self.up1(x_decoded, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        reconstruction = self.outc(x)

        # Blend: keep input pixels outside the mask, reconstructed pixels inside it.
        predicted_masks_3 = predicted_masks.expand_as(reconstruction)
        final_output = input_images * (1 - predicted_masks_3) + reconstruction * predicted_masks_3

        return final_output, mu, logvar, predicted_masks

### Load Trained Model Function
The `load_trained_model` function builds a Variational Autoencoder (VAE) with the fixed configuration used during training (4 input channels, 3 output channels, a 256-dimensional latent space, and 128×128 inputs), loads the saved state dictionary from the given `model_path`, and then switches the model to evaluation mode. The VAE also wraps the YOLO detector stored at `yolo_best_model/best.pt`, which predicts the watermark mask fed to the network as an extra input channel.

In [5]:
def load_trained_model(model_path, device):
    """
    Build the watermark-removal VAE and restore trained weights into it.

    Parameters:
    - model_path (str): Path to the saved model state dictionary.
    - device (torch.device): Device passed to `torch.load` as `map_location`
      and forwarded to the VAE (which uses it for YOLO inference).

    Returns:
    - model (VAE): The VAE with trained weights, in evaluation mode.

    NOTE(review): `map_location` only controls where the checkpoint tensors are
    materialized. `load_state_dict` copies them into the model's existing
    parameters, and `model.to(device)` is never called. The VAE's own layers
    therefore stay on the device they were created on (CPU by default), and
    callers must feed inputs on that device.
    """
    # Architecture must match the checkpoint exactly, or load_state_dict fails.
    vae_config = dict(
        in_channels=4,         # RGB image + 1-channel YOLO watermark mask
        out_channels=3,        # RGB reconstruction
        latent_dim=256,        # size of the latent vector
        input_size=128,        # square input resolution (128x128)
        yolo_model_path="yolo_best_model/best.pt",  # watermark detector weights
        device=device,
    )
    vae = VAE(**vae_config)

    state_dict = torch.load(model_path, map_location=device)
    vae.load_state_dict(state_dict)

    # Inference mode: BatchNorm uses its running statistics instead of batch stats.
    vae.eval()
    return vae

### Remove Watermark And Display

The `remove_watermark_and_display` function visualizes and saves the results of watermark removal with a trained model. It takes a watermarked image and the corresponding clean original for reference, and runs the watermarked image through the model to obtain a reconstructed image and a predicted watermark mask. These four images (watermarked image, predicted mask, reconstructed image, and original image) are arranged side by side in a single figure. The figure is saved to the `visualizator_outputs` directory with the model name prefixed to the file name, and then closed rather than shown inline.

In [7]:
def remove_watermark_and_display(model, watermarked_image_path, original_image_path, transform, device, model_name):
    """
    Visualizes and saves the results of watermark removal, including the watermarked image,
    predicted watermark mask, reconstructed image, and the original (clean) image.

    Parameters:
    - model (torch.nn.Module): The trained model for watermark removal. Its forward pass
      must return (reconstruction, mu, logvar, predicted_mask).
    - watermarked_image_path (str): Path to the watermarked image file.
    - original_image_path (str): Path to the original clean image file.
    - transform (callable): Transformation that turns a PIL image into a (C, H, W) tensor.
    - device (torch.device): Fallback inference device, used only if the model has no
      parameters. Otherwise the input is placed on the device holding the model's
      parameters, so input and model can never end up on different devices.
    - model_name (str): Name of the model, used to prefix the output file name.
    """
    # Load both images; the context managers close the underlying files once the
    # pixel data has been converted into new in-memory images.
    with Image.open(watermarked_image_path) as img:
        watermarked_image = img.convert('RGB')
    with Image.open(original_image_path) as img:
        original_image = img.convert('RGB')

    # Put the input on the same device as the model's weights.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    input_image = transform(watermarked_image).unsqueeze(0).to(target_device)

    with torch.no_grad():
        reconstructed_image, _, _, predicted_mask = model(input_image)

    # Post-process the model outputs for plotting on the CPU.
    reconstructed_image = reconstructed_image.squeeze(0).cpu().clamp(0, 1)  # (3, H, W) in [0, 1]
    predicted_mask = predicted_mask.squeeze(0).squeeze(0).cpu()             # (H, W)

    output_dir = "visualizator_outputs"
    os.makedirs(output_dir, exist_ok=True)
    image_name = os.path.basename(watermarked_image_path)
    output_name = f"{model_name}_{image_name}"
    output_path = os.path.join(output_dir, output_name)

    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    try:
        axs[0].imshow(watermarked_image)
        axs[0].set_title('Watermarked Image')
        axs[1].imshow(predicted_mask, cmap='gray')
        axs[1].set_title('Predicted Watermark Mask')
        axs[2].imshow(transforms.ToPILImage()(reconstructed_image))
        axs[2].set_title('Watermark Removed Image')
        axs[3].imshow(original_image)
        axs[3].set_title('Original Image')

        # Hide axes for a cleaner display.
        for ax in axs:
            ax.axis('off')

        fig.savefig(output_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails, so figures
        # do not accumulate when this is called in a loop.
        plt.close(fig)

    print(f"Visualization saved to {output_path}")

### Show Model Outputs
This code loads three pre-trained VAE models and uses them to remove watermarks from a fixed set of test frames (3167, 1325, 2651 and 96). The models were trained on low-opacity watermarks without a logo, high-opacity watermarks without a logo, and high-opacity watermarks with a logo, respectively. For each frame, the watermarked image, predicted mask, reconstructed image, and original clean image are rendered side by side. The result is saved to a file whose name identifies the dataset and frame, so the models can be compared. The device (CPU or GPU) is selected automatically based on availability. Input images are resized to 128×128 and converted to tensors to match the model's input requirements.

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the image transformation pipeline:
# resize to the 128x128 resolution the VAE was built for, then convert to a [0, 1] tensor.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# Test frames visualized for every model.
FRAME_NUMBERS = [3167, 1325, 2651, 96]

# (checkpoint path, dataset prefix). The prefix names both the test directory
# ("<prefix>_watermark_dataset_test") and the output file prefix.
# The watermarked frame and its clean original are always taken from the same
# directory (the original cell mixed in the no-logo directory for logo frame 3167).
EXPERIMENTS = [
    ('outputs/VAE_NoLogoLowOpacity_128px_5folds_final.pt', 'no_logo_and_low_opacity'),
    ('outputs/VAE_NoLogoHighOpacity_128px_5folds_final.pt', 'no_logo_and_high_opacity'),
    ('outputs/VAE_LogoHighOpacity_128px_5folds_final.pt', 'logo_and_high_opacity'),
]

for model_path, dataset_name in EXPERIMENTS:
    model = load_trained_model(model_path, device)
    dataset_dir = f'{dataset_name}_watermark_dataset_test'
    for number in FRAME_NUMBERS:
        watermarked_image_path = f'{dataset_dir}/frame_{number}_watermarked.jpg'
        original_image_path = f'{dataset_dir}/frame_{number}.jpg'
        remove_watermark_and_display(model, watermarked_image_path, original_image_path, transform, device, f"{dataset_name}_frame_{number}")

Visualization saved to visualizator_outputs\no_logo_and_low_opacity_frame_3167_frame_3167_watermarked.jpg
Visualization saved to visualizator_outputs\no_logo_and_low_opacity_frame_1325_frame_1325_watermarked.jpg
Visualization saved to visualizator_outputs\no_logo_and_low_opacity_frame_2651_frame_2651_watermarked.jpg
Visualization saved to visualizator_outputs\no_logo_and_low_opacity_frame_96_frame_96_watermarked.jpg
Visualization saved to visualizator_outputs\no_logo_and_high_opacity_frame_3167_frame_3167_watermarked.jpg
Visualization saved to visualizator_outputs\no_logo_and_high_opacity_frame_1325_frame_1325_watermarked.jpg
Visualization saved to visualizator_outputs\no_logo_and_high_opacity_frame_2651_frame_2651_watermarked.jpg
Visualization saved to visualizator_outputs\no_logo_and_high_opacity_frame_96_frame_96_watermarked.jpg
Visualization saved to visualizator_outputs\logo_and_high_opacity_frame_3167_frame_3167_watermarked.jpg
Visualization saved to visualizator_outputs\logo_and