In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch.nn.functional as F
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import random
import numpy as np

# Set random seeds for reproducibility (Python, NumPy and PyTorch RNGs).
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# Pick the best available accelerator: Apple-silicon MPS, then CUDA, then CPU.
# torch.backends.mps.is_available() is the long-standing public check; the
# torch.mps.is_available() shortcut is not present in every PyTorch release
# and raises AttributeError where it is missing.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define the HiDDeN-based model
class HiDDeNModel(nn.Module):
    """HiDDeN-style watermarking network with a classifier, an encoder and a decoder.

    * ``classify(image)``      -> probability that the image is watermarked, shape (batch, 1).
    * ``encode(image, wm)``    -> image plus a learned residual conditioned on ``wm``,
                                  clamped to [-1, 1].
    * ``decode(image)``        -> predicted watermark signal, in [-1, 1].
    * ``remove_watermark(img)``-> ``img - decode(img)``, clamped to [-1, 1].

    Images are handled in the [-1, 1] range (see the clamps in ``encode`` and
    ``remove_watermark``). The model moves itself to the module-level ``device``
    on construction.

    NOTE: the decoder used to end in a Sigmoid (output (0, 1)). Activation
    layers carry no parameters, so older state_dicts still load, but a decoder
    trained with the Sigmoid head should be retrained.
    """

    def __init__(self):
        super().__init__()

        # Watermark classifier (detects if image has watermark).
        # The layer order is part of the state_dict keys; keep it stable.
        self.classifier = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.classifier_fc = nn.Linear(256, 1)

        # Encoder (for adding watermark): input is image (3 ch) + watermark (3 ch).
        self.encoder = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh()  # residual in [-1, 1], added to the image in encode()
        )

        # Decoder (for extracting/removing watermark)
        self.decoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            # Tanh rather than Sigmoid: the decoder is trained to reproduce
            # watermarks in the same [-1, 1] space as the images (a blank
            # watermark is all zeros), and remove_watermark() must be able to
            # subtract both positive and negative residuals. A Sigmoid output
            # is confined to (0, 1) and can reach neither.
            nn.Tanh()
        )

        # Move model to GPU if available
        self.to(device)

    def classify(self, image):
        """Return the probability (sigmoid output, shape (batch, 1)) that ``image`` is watermarked."""
        features = self.classifier(image)
        features = torch.flatten(features, 1)
        return torch.sigmoid(self.classifier_fc(features))

    def encode(self, image, watermark):
        """Add ``watermark`` to ``image`` and return the result clamped to [-1, 1].

        ``watermark`` is bilinearly resized to the image's spatial size if needed.
        """
        if watermark.shape[2:] != image.shape[2:]:
            watermark = F.interpolate(watermark, size=image.shape[2:], mode='bilinear', align_corners=False)

        # Concatenate image and watermark along the channel dimension.
        combined = torch.cat([image, watermark], dim=1)
        encoded_image = image + self.encoder(combined)

        # Ensure pixel values stay in valid range
        return torch.clamp(encoded_image, -1, 1)

    def decode(self, watermarked_image):
        """Return the watermark signal predicted in ``watermarked_image`` (values in [-1, 1])."""
        return self.decoder(watermarked_image)

    def remove_watermark(self, watermarked_image):
        """Subtract the decoded watermark from the image and clamp the result to [-1, 1]."""
        extracted_watermark = self.decode(watermarked_image)
        clean_image = watermarked_image - extracted_watermark
        return torch.clamp(clean_image, -1, 1)

# Custom dataset for watermarked and non-watermarked images
class WatermarkDataset(Dataset):
    """Labelled dataset of clean (label 0.0) and watermarked (label 1.0) images.

    Index layout: ``[0, n_clean)`` maps to ``clean_image_files`` and
    ``[n_clean, n_clean + n_watermarked)`` maps to ``watermarked_image_files``.
    Both file lists are sorted, so the index -> file mapping (and therefore a
    seeded shuffle) is reproducible across runs and platforms.

    Each sample is a dict with keys:
        'image'         -- the (transformed) RGB image;
        'watermark'     -- the transformed ``watermark_path`` image, or
                           ``torch.zeros_like(image)`` when none was loaded;
        'has_watermark' -- float32 tensor of shape (1,) holding the label.

    NOTE: DataLoader worker processes started with "spawn" must be able to
    import this class; that fails when it is defined in a notebook's __main__.
    """

    # File suffixes (compared case-insensitively) treated as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, clean_dir, watermarked_dir=None, watermark_path=None, transform=None, is_train=True):
        """
        Args:
            clean_dir (string): Directory with non-watermarked/clean images.
            watermarked_dir (string): Directory with watermarked images.
            watermark_path (string): Path to watermark image to use for training.
            transform (callable, optional): Optional transform to be applied on a sample.
                NOTE(review): without a transform that returns a tensor,
                torch.zeros_like(image) in the samples will fail on a PIL image.
            is_train (bool): Whether this is training or validation set
                (stored only; it does not change the dataset's contents).
        """
        self.clean_dir = clean_dir
        self.watermarked_dir = watermarked_dir
        self.watermark_path = watermark_path
        self.transform = transform
        self.is_train = is_train
        self.watermark_tensor = None

        self.clean_image_files = self._list_images(clean_dir)
        self.watermarked_image_files = self._list_images(watermarked_dir)

        # Load watermark if provided (only kept when a transform can tensorize it).
        if watermark_path and os.path.exists(watermark_path):
            with Image.open(watermark_path) as watermark_file:
                watermark_img = watermark_file.convert('RGB')
            if transform:
                self.watermark_tensor = transform(watermark_img)
            print(f"Using watermark from: {watermark_path}")

        print(f"Found {len(self.clean_image_files)} clean images and {len(self.watermarked_image_files)} watermarked images")

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted image file names in ``directory``; [] if it is unset or missing."""
        if not directory or not os.path.exists(directory):
            return []
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        # Training and validation both use every clean and watermarked image.
        return len(self.clean_image_files) + len(self.watermarked_image_files)

    def _make_sample(self, directory, filename, label):
        """Load ``directory/filename`` as RGB, apply the transform and build the sample dict."""
        with Image.open(os.path.join(directory, filename)) as img_file:
            image = img_file.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Use a blank watermark if none provided
        watermark = torch.zeros_like(image) if self.watermark_tensor is None else self.watermark_tensor

        return {
            'image': image,
            'watermark': watermark,
            'has_watermark': torch.tensor([label], dtype=torch.float32)
        }

    def __getitem__(self, idx):
        n_clean = len(self.clean_image_files)

        if idx < n_clean:
            # This is a clean/non-watermarked image
            return self._make_sample(self.clean_dir, self.clean_image_files[idx], 0.0)

        if not self.watermarked_image_files:
            # No watermarked images available: wrap around onto the clean set
            # and label the sample as non-watermarked.
            return self._make_sample(self.clean_dir, self.clean_image_files[idx % n_clean], 0.0)

        # Load actual watermarked image (modulo guards against overflow).
        watermarked_idx = (idx - n_clean) % len(self.watermarked_image_files)
        return self._make_sample(self.watermarked_dir, self.watermarked_image_files[watermarked_idx], 1.0)

# Helper functions for training
def _evaluate_classifier(model, val_loader, criterion):
    """Evaluate the watermark classifier head on ``val_loader`` (eval mode, no gradients).

    Returns:
        (mean loss per batch, fraction of correctly classified samples).
        Both are 0.0 for an empty loader instead of raising ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(device)
            has_watermark = batch['has_watermark'].to(device)

            watermark_pred = model.classify(images)
            total_loss += criterion(watermark_pred, has_watermark).item()

            # Count per sample so a short final batch is weighted correctly.
            predicted = (watermark_pred > 0.5).float()
            num_correct += (predicted == has_watermark).sum().item()
            num_samples += has_watermark.numel()

    return total_loss / max(len(val_loader), 1), num_correct / max(num_samples, 1)


def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.001):
    """Jointly train the classifier, encoder and decoder of a HiDDeN model.

    Per batch the summed loss contains:
      * BCE classification loss on every image (target ``has_watermark``);
      * for images labelled watermarked: a total-variation smoothness loss on
        ``remove_watermark`` output (no clean ground truth is available) plus a
        term pulling the per-channel mean of ``decode`` output towards 0.5;
      * for images labelled clean: MSE between encoded and original image,
        BCE pushing encoded images to be classified as watermarked, and MSE
        between the decoded and the provided watermark.
    After each epoch the classifier is evaluated on ``val_loader`` and the
    per-batch-averaged losses are printed.

    Args:
        model: HiDDeN-style model exposing classify/encode/decode/remove_watermark.
        train_loader: Yields dicts with 'image', 'watermark', 'has_watermark'.
        val_loader: Same format; only 'image' and 'has_watermark' are used.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.

    Returns:
        The same ``model`` instance, trained in place.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    classifier_criterion = nn.BCELoss()
    reconstruction_criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        model.train()
        train_classifier_loss = 0.0
        train_encoder_loss = 0.0
        train_decoder_loss = 0.0
        train_removal_loss = 0.0

        for batch in train_loader:
            images = batch['image'].to(device)
            watermarks = batch['watermark'].to(device)
            has_watermark = batch['has_watermark'].to(device)

            optimizer.zero_grad()

            # Classification of every image in the batch.
            watermark_pred = model.classify(images)
            classifier_loss = classifier_criterion(watermark_pred, has_watermark)

            # Split the batch by label. view(-1) (not squeeze()) keeps the mask
            # 1-D for a batch of size 1; a 0-d bool mask would add a dimension
            # when indexing instead of selecting rows.
            watermarked_mask = has_watermark.view(-1) > 0.5
            clean_mask = ~watermarked_mask

            total_loss = classifier_loss

            # Watermarked images: train the decoder / removal path.
            if watermarked_mask.any():
                watermarked_images = images[watermarked_mask]

                clean_recovered = model.remove_watermark(watermarked_images)

                # No ground-truth clean versions exist, so use a smoothness
                # (total-variation) constraint on the recovered image.
                removal_loss = torch.mean(torch.abs(
                    clean_recovered[:, :, 1:, :] - clean_recovered[:, :, :-1, :]
                )) + torch.mean(torch.abs(
                    clean_recovered[:, :, :, 1:] - clean_recovered[:, :, :, :-1]
                ))

                # Keep the decoded watermark away from an all-zero output.
                decoded_watermarks = model.decode(watermarked_images)
                decoder_loss = torch.mean((decoded_watermarks.mean(dim=[2, 3]) - 0.5) ** 2)

                train_removal_loss += removal_loss.item()
                train_decoder_loss += decoder_loss.item()
                total_loss = total_loss + removal_loss + decoder_loss

            # Clean images: train the encoder and watermark recovery.
            if clean_mask.any():
                clean_images = images[clean_mask]
                clean_watermarks = watermarks[clean_mask]

                encoded_images = model.encode(clean_images, clean_watermarks)

                # Encoded image should stay close to the original.
                encoder_loss = reconstruction_criterion(encoded_images, clean_images)

                # Encoded image should be classified as watermarked.
                encoded_pred = model.classify(encoded_images)
                encoder_classify_loss = classifier_criterion(encoded_pred, torch.ones_like(encoded_pred))

                # Decoder should recover the embedded watermark.
                decoded_watermarks = model.decode(encoded_images)
                decoder_recovery_loss = reconstruction_criterion(decoded_watermarks, clean_watermarks)

                train_encoder_loss += encoder_loss.item()
                train_decoder_loss += decoder_recovery_loss.item()
                total_loss = total_loss + encoder_loss + encoder_classify_loss + decoder_recovery_loss

            total_loss.backward()
            optimizer.step()

            train_classifier_loss += classifier_loss.item()

        val_classifier_loss, val_accuracy = _evaluate_classifier(model, val_loader, classifier_criterion)

        # Guard against an empty training loader.
        num_train_batches = max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Classifier Loss: {train_classifier_loss/num_train_batches:.4f}")
        print(f"Train Encoder Loss: {train_encoder_loss/num_train_batches:.4f}")
        print(f"Train Decoder Loss: {train_decoder_loss/num_train_batches:.4f}")
        print(f"Train Removal Loss: {train_removal_loss/num_train_batches:.4f}")
        print(f"Val Classifier Loss: {val_classifier_loss:.4f}")
        print(f"Val Accuracy: {val_accuracy:.4f}")

    return model

# Process single image
def process_single_image(filename, input_dir, watermark_img, output_unwatermarked, output_watermarked, model):
    """Detect a watermark in one image, then remove it or add one, and save the result.

    The image is resized to 256x256 and normalised to [-1, 1] for the model.
    The model output is mapped back to [0, 1], resized to the original
    dimensions and saved under the same file name. It goes to
    ``output_unwatermarked`` when a watermark was detected and removed, or to
    ``output_watermarked`` when one was added.

    Args:
        filename: Image file name inside ``input_dir``.
        input_dir: Directory containing ``filename``.
        watermark_img: PIL RGB image embedded when no watermark is detected.
        output_unwatermarked: Destination directory for cleaned images.
        output_watermarked: Destination directory for newly watermarked images.
        model: Trained HiDDeN model (expected to be in eval mode).

    Returns:
        A one-line status message. Any exception is caught and reported in the
        returned message rather than raised, so one bad file does not abort a
        batch run.
    """
    img_path = os.path.join(input_dir, filename)

    try:
        # The context manager closes the file handle promptly; convert()
        # returns an independent in-memory image.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        original_size = image.size  # (width, height), restored before saving

        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        image_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            watermark_prob = model.classify(image_tensor).item()
            # Distance from the 0.5 decision boundary, scaled to [0, 1].
            confidence = abs(watermark_prob - 0.5) * 2

            if watermark_prob > 0.5:
                result_tensor = model.remove_watermark(image_tensor)
                output_dir = output_unwatermarked
                message = f"Removed watermark from: {filename} (confidence: {confidence:.2f})"
            else:
                # The watermark tensor is only needed on this branch.
                watermark_tensor = transform(watermark_img.resize((256, 256))).unsqueeze(0).to(device)
                result_tensor = model.encode(image_tensor, watermark_tensor)
                output_dir = output_watermarked
                message = f"Added watermark to: {filename} (confidence: {confidence:.2f})"

        # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1] as expected by ToPILImage.
        result_pil = transforms.ToPILImage()((result_tensor * 0.5 + 0.5).squeeze(0).cpu())
        result_pil = result_pil.resize(original_size, Image.LANCZOS)
        result_pil.save(os.path.join(output_dir, filename))
        return message
    except Exception as e:
        return f"Error processing {filename}: {str(e)}"

# Main processing function with parallel execution
def process_images(input_dir, watermark_path, output_unwatermarked, output_watermarked, model, num_workers=None):
    """Run ``process_single_image`` over every image in ``input_dir`` using a thread pool.

    Args:
        input_dir: Directory scanned (non-recursively) for .png/.jpg/.jpeg/.bmp/.tiff files.
        watermark_path: Watermark image embedded into images classified as clean.
        output_unwatermarked: Output directory for cleaned images (created if missing).
        output_watermarked: Output directory for watermarked images (created if missing).
        model: Trained HiDDeN model; switched to eval mode here.
        num_workers: ThreadPoolExecutor ``max_workers`` (None = executor default).

    Returns:
        Number of image files found (including files whose processing failed),
        or 0 if the watermark cannot be loaded or no images are found.
    """
    # Create output directories if they don't exist
    os.makedirs(output_unwatermarked, exist_ok=True)
    os.makedirs(output_watermarked, exist_ok=True)

    # Load user-provided watermark
    try:
        with Image.open(watermark_path) as watermark_file:
            watermark_img = watermark_file.convert("RGB")
        print(f"Successfully loaded user watermark: {watermark_path}")
    except Exception as e:
        print(f"Error loading watermark: {str(e)}")
        return 0

    # Sorted for a deterministic processing and reporting order.
    image_files = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))
    )

    if not image_files:
        print(f"No image files found in {input_dir}")
        return 0

    # Inference must use BatchNorm running statistics. A freshly constructed or
    # checkpoint-loaded model is in training mode, which would normalise each
    # single-image batch by its own statistics and let the worker threads
    # concurrently overwrite the running statistics.
    model.eval()

    process_func = partial(
        process_single_image,
        input_dir=input_dir,
        watermark_img=watermark_img,
        output_unwatermarked=output_unwatermarked,
        output_watermarked=output_watermarked,
        model=model
    )

    # executor.map preserves input order, so results line up with image_files.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(process_func, image_files))

    for result in results:
        print(result)

    return len(image_files)

def main():
    # Define data directories (these should match your provided structure)
    train_clean_dir = "wm-nowm/train/no-watermark"
    train_watermarked_dir = "wm-nowm/train/watermark"
    val_clean_dir = "wm-nowm/valid/no-watermark"
    val_watermarked_dir = "wm-nowm/valid/watermark" 
    
    # Check if directories exist
    for directory in [train_clean_dir, train_watermarked_dir, val_clean_dir, val_watermarked_dir]:
        if not os.path.exists(directory):
            print(f"Warning: Directory {directory} does not exist.")
    
    # Get watermark path from user for training
    watermark_path = input("Enter path to watermark image for training (or press Enter to use a blank watermark): ")
    if not watermark_path or not os.path.exists(watermark_path):
        print("No valid watermark path provided. Will use a blank watermark during training.")
        watermark_path = None
    
    # Prepare data transforms
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    # Create datasets
    train_dataset = WatermarkDataset(
        clean_dir=train_clean_dir,
        watermarked_dir=train_watermarked_dir,
        watermark_path=watermark_path,
        transform=transform,
        is_train=True
    )
    
    val_dataset = WatermarkDataset(
        clean_dir=val_clean_dir, 
        watermarked_dir=val_watermarked_dir,
        watermark_path=watermark_path,
        transform=transform,
        is_train=False
    )
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)
    
    # Create model
    model = HiDDeNModel()
    
    # Check if model already exists and load it
    model_path = "watermark_model.pth"
    if os.path.exists(model_path):
        print(f"Loading existing model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        # Train model
        print("Training model...")
        model = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=10,
            learning_rate=0.001
        )
        
        # Save model
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")
    
    # Get user inputs for processing
    print("\n--- Watermark Processing ---")
    input_dir = input("Enter the directory path containing images to process: ")
    
    if not input_dir or not os.path.exists(input_dir):
        print("Error: Invalid input directory.")
        return
    
    watermark_path = input("Enter the path to the watermark image you want to apply: ")
    if not watermark_path or not os.path.exists(watermark_path):
        print("Error: Invalid watermark path.")
        return
    
    # Set output directories
    output_unwatermarked = "output/unwatermarked"
    output_watermarked = "output/watermarked"
    
    # Determine optimal number of workers based on CPU cores
    num_workers = min(os.cpu_count() or 1, 8)  # Cap at 8 workers to avoid excessive resource usage
    
    # Process the images with user-provided watermark
    print(f"\nProcessing images with watermark: {watermark_path}")
    num_processed = process_images(
        input_dir,
        watermark_path,
        output_unwatermarked,
        output_watermarked,
        model,
        num_workers=num_workers
    )
    
    print(f"Processing complete! Processed {num_processed} images using {num_workers} workers.")
    print(f"Unwatermarked images saved to: {output_unwatermarked}")
    print(f"Watermarked images saved to: {output_watermarked}")

if __name__ == "__main__":
    main()

Using device: mps
No valid watermark path provided. Will use a blank watermark during training.
Found 12477 clean images and 12510 watermarked images
Found 3289 clean images and 3299 watermarked images
Training model...


Traceback (most recent call last):
  File "<string>", line 1, in <module>
    from multiprocessing.spawn import spawn_main; spawn_main(tracker_fd=79, pipe_handle=93)
                                                  ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'WatermarkDataset' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>
