In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
import os
import torch.nn.functional as F
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import kagglehub
import shutil
import random
from torch.utils.data import Dataset, DataLoader
import time

# Pick the compute backend once at import time: CUDA first, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define the CNN model with batch normalization and more efficient architecture
class WatermarkCNN(nn.Module):
    """Small convolutional encoder/decoder pair for embedding and recovering a watermark.

    ``encoder`` maps (image + watermark) to a 3-channel image; ``decoder`` maps a
    3-channel image to a single-channel map squashed into [0, 1] by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Encoder with batch normalization for faster convergence and training stability
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),  # inplace operations save memory
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

        # Decoder with batch normalization
        self.decoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

        # Move model to the module-level device (GPU if available)
        self.to(device)

    def _watermark_as_batch(self, image, watermark):
        """Return ``watermark`` as a tensor matching ``image`` in batch size, H and W.

        Args:
            image: Tensor indexed as (batch, channels, height, width).
            watermark: Either a PIL image, or a tensor shaped (C, H, W) or
                (B, C, H, W).

        Raises:
            ValueError: If a tensor watermark has an unsupported rank, or a batch
                size that is neither 1 nor the image batch size.
        """
        batch_size, height, width = image.shape[0], image.shape[2], image.shape[3]

        if isinstance(watermark, torch.Tensor):
            # Tensor path (e.g. a collated DataLoader batch): resize on-device.
            wm = watermark.to(device=image.device, dtype=image.dtype)
            if wm.dim() == 3:
                wm = wm.unsqueeze(0)
            elif wm.dim() != 4:
                raise ValueError(
                    f"Watermark tensor must be 3-D or 4-D, got {wm.dim()} dimensions"
                )
            if tuple(wm.shape[-2:]) != (height, width):
                wm = F.interpolate(wm, size=(height, width), mode="bilinear",
                                   align_corners=False)
        else:
            # PIL path: PIL's resize takes (width, height)
            resized = watermark.resize((width, height))
            wm = transforms.ToTensor()(resized).unsqueeze(0).to(image.device)

        # Handle batch processing: share a single watermark across the whole batch
        if wm.shape[0] != batch_size:
            if wm.shape[0] != 1:
                raise ValueError(
                    f"Watermark batch size {wm.shape[0]} does not match "
                    f"image batch size {batch_size}"
                )
            wm = wm.expand(batch_size, -1, -1, -1)
        return wm

    def encode(self, image, watermark):
        """Add the (resized) watermark to ``image`` and run the encoder.

        Args:
            image: Image tensor indexed as (batch, channels, height, width).
            watermark: PIL image or tensor; see ``_watermark_as_batch``.

        Returns:
            The encoder output for ``image + watermark``.
        """
        watermark_tensor = self._watermark_as_batch(image, watermark)
        return self.encoder(image + watermark_tensor)

    def decode(self, watermarked_image):
        """Run the decoder and return its single-channel sigmoid output."""
        return self.decoder(watermarked_image)

    def forward(self, image, watermark=None):
        """
        Forward pass for training
        If watermark is None, only decoding is performed (for detection)
        If watermark is provided, both encoding and decoding are performed

        Returns:
            ``decoded`` when ``watermark`` is None, otherwise ``(encoded, decoded)``.
        """
        if watermark is None:
            # Just perform decoding (for watermark detection)
            return self.decode(image)

        # encode() resizes the watermark to match the image dimensions
        encoded = self.encode(image, watermark)
        decoded = self.decode(encoded)
        return encoded, decoded

# Custom dataset for watermark images
class WatermarkDataset(Dataset):
    """Images from one directory, labelled by filename prefix, paired with a watermark.

    Each sample is a dict with keys ``'image'`` (transformed image),
    ``'watermark'`` (the watermark as a tensor) and ``'has_watermark'``
    (float32 scalar: 1.0 when the filename starts with ``wm_``, else 0.0).
    """

    # Lower-case file extensions that are treated as images
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    # Filename prefix that marks an image as watermarked
    WATERMARK_PREFIX = "wm_"

    def __init__(self, input_dir, watermark_path, transform=None):
        """Index ``input_dir`` and load the watermark image.

        Args:
            input_dir: Directory containing the image files.
            watermark_path: Path of the watermark image (converted to RGB).
            transform: Callable applied to each PIL image; defaults to ``ToTensor()``.
        """
        self.input_dir = input_dir
        self.transform = transform or transforms.ToTensor()

        # Load all image paths and identify watermarked vs non-watermarked
        self.image_files, self.has_watermark = self._scan_images(input_dir)

        # Load watermark image; convert() returns a loaded copy, so the file can be closed
        with Image.open(watermark_path) as raw_watermark:
            self.watermark_img = raw_watermark.convert("RGB")
        self.watermark_transform = transforms.ToTensor()
        # The watermark never changes, so convert it to a tensor once
        self._watermark_tensor = self.watermark_transform(self.watermark_img)

    @staticmethod
    def _scan_images(input_dir):
        """Return ``(filenames, flags)`` for the image files in ``input_dir``.

        Filenames are sorted so dataset indices do not depend on the arbitrary
        order of ``os.listdir``. ``flags[i]`` is True when ``filenames[i]``
        starts with the watermark prefix.
        """
        filenames = sorted(
            name for name in os.listdir(input_dir)
            if name.lower().endswith(WatermarkDataset.IMAGE_EXTENSIONS)
        )
        flags = [name.startswith(WatermarkDataset.WATERMARK_PREFIX) for name in filenames]
        return filenames, flags

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.input_dir, self.image_files[idx])

        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # Return the sample
        return {
            'image': self.transform(image),
            # clone so callers cannot mutate the cached watermark tensor
            'watermark': self._watermark_tensor.clone(),
            'has_watermark': torch.tensor(self.has_watermark[idx], dtype=torch.float32)
        }

# Cache for the model instance to avoid reloading
model_cache = None

def get_model():
    """Return the shared ``WatermarkCNN`` instance, creating it on first use.

    On first use, weights are loaded from ``watermark_model.pth`` when that file
    exists; a load failure is reported and the freshly initialised weights are
    kept. A newly created model is switched to eval mode, because the callers in
    this file use it for inference and BatchNorm in training mode would use
    per-batch statistics and update its running statistics.
    """
    global model_cache
    if model_cache is None:
        model = WatermarkCNN()

        # Try to load saved model if it exists
        model_path = 'watermark_model.pth'
        if os.path.exists(model_path):
            try:
                model.load_state_dict(torch.load(model_path, map_location=device))
                print(f"Loaded pre-trained model from {model_path}")
            except Exception as e:
                # Best effort: keep going with the untrained weights
                print(f"Error loading model: {str(e)}")

        model.eval()
        # Publish only after loading so no caller sees a half-initialised model
        model_cache = model

    return model_cache

# Train the watermark model
def train_model(input_dir, watermark_path, epochs=10, batch_size=16, learning_rate=0.001, save_path='watermark_model.pth'):
    print(f"Training watermark model...")
    print(f"- Input directory: {input_dir}")
    print(f"- Watermark path: {watermark_path}")
    print(f"- Epochs: {epochs}")
    print(f"- Batch size: {batch_size}")
    print(f"- Learning rate: {learning_rate}")
    
    # Check if input directory exists and has images
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory not found: {input_dir}")
    
    # Create dataset and data loaders
    transform = transforms.Compose([
        transforms.Resize((256, 256)),  # Resize to standardize images
        transforms.ToTensor()
    ])
    
    dataset = WatermarkDataset(input_dir, watermark_path, transform)
    
    if len(dataset) == 0:
        raise ValueError("No training data found in input directory")
    
    # Split dataset into train and validation sets
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    # Initialize model
    model = WatermarkCNN().to(device)
    
    # Loss functions
    mse_loss = nn.MSELoss()
    bce_loss = nn.BCELoss()
    
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Training loop
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        start_time = time.time()
        
        for batch in train_loader:
            # Get batch data
            images = batch['image'].to(device)
            watermarks = batch['watermark'].to(device)
            has_watermark = batch['has_watermark'].to(device)
            
            # Zero the gradients
            optimizer.zero_grad()
            
            # Forward pass
            encoded_images, decoded_watermarks = model(images, watermarks)
            
            # Calculate losses
            # 1. Encoding loss - how well can we hide the watermark
            encoding_loss = mse_loss(encoded_images, images)
            
            # 2. Decoding loss - how well can we extract the watermark
            detection_loss = bce_loss(decoded_watermarks.mean(dim=(1, 2, 3)), has_watermark)
            
            # 3. Reconstruction loss - can we recover the original watermark
            watermark_recon_loss = mse_loss(decoded_watermarks, 
                                           watermarks.mean(dim=1, keepdim=True) * has_watermark.view(-1, 1, 1, 1))
            
            # Total loss - balance between hiding and detecting
            total_loss = encoding_loss + 5.0 * detection_loss + watermark_recon_loss
            
            # Backward pass and optimize
            total_loss.backward()
            optimizer.step()
            
            train_loss += total_loss.item()
        
        # Validation
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(device)
                watermarks = batch['watermark'].to(device)
                has_watermark = batch['has_watermark'].to(device)
                
                encoded_images, decoded_watermarks = model(images, watermarks)
                
                encoding_loss = mse_loss(encoded_images, images)
                detection_loss = bce_loss(decoded_watermarks.mean(dim=(1, 2, 3)), has_watermark)
                watermark_recon_loss = mse_loss(decoded_watermarks, 
                                               watermarks.mean(dim=1, keepdim=True) * has_watermark.view(-1, 1, 1, 1))
                
                total_loss = encoding_loss + 5.0 * detection_loss + watermark_recon_loss
                val_loss += total_loss.item()
        
        # Calculate average losses
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        # Print epoch statistics
        elapsed_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {elapsed_time:.2f}s")
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best model to {save_path}")
    
    print(f"Training completed! Best validation loss: {best_val_loss:.4f}")
    
    # Set the trained model as the cached model
    global model_cache
    model_cache = model
    
    return model

# Optimized detection with batching support
def detect_watermark(image_path, threshold=0.05):
    """Decide whether the image at ``image_path`` carries a watermark.

    Args:
        image_path: Path of the image file to inspect.
        threshold: The image is reported as watermarked when the mean of the
            decoder output exceeds this value. Defaults to the original 0.05.
            NOTE(review): the decoder ends in a sigmoid and training pushes the
            mean towards the 0/1 label, so a value near 0.5 may separate the
            classes better - confirm on real data.

    Returns:
        ``(has_watermark, image_tensor)``: a Python bool and the image as a
        tensor with a leading batch dimension of 1 on the module-level device.
    """
    model = get_model()
    # convert() returns a loaded copy, so the file handle can be released right away
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)

    with torch.no_grad():
        watermark = model.decode(image_tensor)
        has_watermark = torch.mean(watermark) > threshold  # Threshold for watermark presence

    return has_watermark.item(), image_tensor

# Optimized watermark removal
def remove_watermark(image_tensor):
    """Return ``image_tensor`` minus the watermark map predicted by the decoder.

    Runs without gradient tracking. The decoder output is subtracted from the
    image as-is; the result is not rescaled or clipped here.
    """
    network = get_model()
    with torch.no_grad():
        predicted_mark = network.decode(image_tensor)
        return image_tensor - predicted_mark

# Optimized watermark addition
def add_watermark(image_tensor, watermark):
    """Embed ``watermark`` into ``image_tensor`` using the shared model's encoder.

    Args:
        image_tensor: Image tensor with a leading batch dimension.
        watermark: Watermark in whatever form the model's ``encode`` accepts.

    Returns:
        The encoder output, computed without gradient tracking since this is
        an inference-only path.
    """
    model = get_model()
    with torch.no_grad():
        return model.encode(image_tensor, watermark)

# Process single image
def process_single_image(filename, input_dir, watermark_img, output_unwatermarked, output_watermarked):
    """Detect, then remove or add the watermark for one image and save the result.

    Args:
        filename: Name of the image inside ``input_dir``; reused as the output name.
        input_dir: Directory holding the source image.
        watermark_img: Watermark passed on to ``add_watermark``.
        output_unwatermarked: Destination directory for cleaned images.
        output_watermarked: Destination directory for newly watermarked images.

    Returns:
        A status message. Failures are reported in the message rather than
        raised, so one bad file does not abort a whole batch.
    """
    img_path = os.path.join(input_dir, filename)

    try:
        has_watermark, image_tensor = detect_watermark(img_path)

        if has_watermark:
            result = remove_watermark(image_tensor)
            output_dir = output_unwatermarked
            message = f"Removed watermark from: {filename}"
        else:
            result = add_watermark(image_tensor, watermark_img)
            output_dir = output_watermarked
            message = f"Added watermark to: {filename}"

        # ToPILImage scales floats by 255 and casts to uint8; values outside
        # [0, 1] would wrap around, so clamp before converting.
        pixels = result.detach().squeeze(0).clamp(0.0, 1.0).cpu()
        transforms.ToPILImage()(pixels).save(os.path.join(output_dir, filename))
        return message
    except Exception as e:
        return f"Error processing {filename}: {str(e)}"

# Main processing function with parallel execution
def process_images(input_dir, watermark_path, output_unwatermarked, output_watermarked, num_workers=None):
    """Run ``process_single_image`` over every image in ``input_dir`` using a thread pool.

    Args:
        input_dir: Directory containing the images to process.
        watermark_path: Path of the watermark image.
        output_unwatermarked: Directory for images whose watermark was removed.
        output_watermarked: Directory for images that received a watermark.
        num_workers: ``max_workers`` for the thread pool (None uses its default).

    Returns:
        The number of image files found (0 when there are none).

    Raises:
        FileNotFoundError: If ``watermark_path`` does not exist.
        ValueError: If the watermark cannot be loaded; the original error is
            attached as ``__cause__``.
    """
    # Create output directories if they don't exist
    os.makedirs(output_unwatermarked, exist_ok=True)
    os.makedirs(output_watermarked, exist_ok=True)

    # Verify watermark exists
    if not os.path.exists(watermark_path):
        raise FileNotFoundError(f"Watermark image not found: {watermark_path}")

    # Load watermark once; convert() returns a loaded copy so the file can be closed
    try:
        with Image.open(watermark_path) as raw_watermark:
            watermark_img = raw_watermark.convert("RGB")
    except Exception as e:
        raise ValueError(f"Error loading watermark image: {str(e)}") from e

    # Get list of image files, sorted so the processing and report order is stable
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    image_files = sorted(
        name for name in os.listdir(input_dir)
        if name.lower().endswith(image_extensions)
    )

    if not image_files:
        print(f"No image files found in {input_dir}")
        return 0

    # Process images in parallel
    process_func = partial(
        process_single_image,
        input_dir=input_dir,
        watermark_img=watermark_img,
        output_unwatermarked=output_unwatermarked,
        output_watermarked=output_watermarked
    )

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Print results in input order once all workers have finished
        for result in list(executor.map(process_func, image_files)):
            print(result)

    return len(image_files)

# Function to download and prepare Kaggle dataset
def download_and_prepare_kaggle_dataset(use_train=True, use_valid=False, sample_limit=None):
    print("Downloading watermarked/non-watermarked dataset from Kaggle...")
    dataset_path = kagglehub.dataset_download("felicepollano/watermarked-not-watermarked-images")
    print(f"Dataset downloaded to: {dataset_path}")

    # Create input directory for our processing
    input_dir = "input_images"
    os.makedirs(input_dir, exist_ok=True)
    
    # Define source directories based on dataset structure
    source_dirs = []
    
    if use_train:
        source_dirs.append(os.path.join(dataset_path, "train", "no-watermark"))
        source_dirs.append(os.path.join(dataset_path, "train", "watermark"))
    
    if use_valid:
        source_dirs.append(os.path.join(dataset_path, "valid", "no-watermark"))
        source_dirs.append(os.path.join(dataset_path, "valid", "watermark"))
    
    # Track statistics for reporting
    copied_count = 0
    total_watermarked = 0
    total_non_watermarked = 0
    
    # Process each source directory
    for source_dir in source_dirs:
        if not os.path.exists(source_dir):
            print(f"Warning: Directory {source_dir} not found in dataset")
            continue
            
        # Determine if this is a watermarked directory
        is_watermark_dir = os.path.basename(source_dir) == "watermark"
        
        # Get all image files in this directory
        image_files = [f for f in os.listdir(source_dir) 
                      if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))]
        
        # Apply sample limit if specified
        if sample_limit is not None and len(image_files) > sample_limit:
            random.shuffle(image_files)  # Randomize to get a representative sample
            image_files = image_files[:sample_limit]
        
        # Prefix to avoid filename collisions and track source
        prefix = "wm_" if is_watermark_dir else "nowm_"
        
        # Copy files to input directory with appropriate prefix
        for file in image_files:
            source_path = os.path.join(source_dir, file)
            # Add prefix to avoid filename collisions
            dest_name = f"{prefix}{file}"
            dest_path = os.path.join(input_dir, dest_name)
            
            shutil.copy2(source_path, dest_path)
            copied_count += 1
            
            if is_watermark_dir:
                total_watermarked += 1
            else:
                total_non_watermarked += 1
    
    # Print statistics
    print(f"Dataset preparation complete:")
    print(f"- Total images copied: {copied_count}")
    print(f"- Watermarked images: {total_watermarked}")
    print(f"- Non-watermarked images: {total_non_watermarked}")
    
    if copied_count == 0:
        print("Warning: No images were copied. Check the dataset structure.")
    
    return input_dir

def run_with_params(watermark_path, input_dir=None, output_unwatermarked='output/unwatermarked', 
                  output_watermarked='output/watermarked', workers=None, train=True, 
                  epochs=10, batch_size=16, learning_rate=0.001,
                  use_train=True, use_valid=False, sample_limit=50):
    """Run the whole pipeline: optional dataset download, optional training, then processing.

    Args:
        watermark_path: Path of the watermark image.
        input_dir: Directory of input images; when None the Kaggle dataset is
            downloaded and prepared.
        output_unwatermarked: Directory for images whose watermark was removed.
        output_watermarked: Directory for images that received a watermark.
        workers: Thread count for processing; defaults to the CPU count capped at 8.
        train: Train a model first when True, otherwise load the cached/saved one.
        epochs, batch_size, learning_rate: Passed to ``train_model``.
        use_train, use_valid, sample_limit: Passed to the dataset download helper.

    Returns:
        The number of images processed, or 0 if training/processing failed
        (the error and traceback are printed rather than raised).

    Raises:
        FileNotFoundError: If ``input_dir`` or ``watermark_path`` does not exist.
    """
    # Get input directory
    if input_dir is None:
        # Download and prepare the Kaggle dataset if no input directory provided
        input_dir = download_and_prepare_kaggle_dataset(use_train, use_valid, sample_limit)

    # Ensure the input directory exists
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory not found: {input_dir}")

    # Verify watermark path
    if not os.path.exists(watermark_path):
        raise FileNotFoundError(f"Watermark image not found: {watermark_path}")

    # Determine optimal number of workers if not specified
    if workers is None:
        # os.cpu_count() may return None; cap at 8 to avoid excessive resource usage
        workers = min(os.cpu_count() or 1, 8)

    print("Processing with the following settings:")
    print(f"- Input directory: {input_dir}")
    print(f"- Watermark image: {watermark_path}")
    print(f"- Output directory (unwatermarked): {output_unwatermarked}")
    print(f"- Output directory (watermarked): {output_watermarked}")
    print(f"- Workers: {workers}")
    print(f"- Device: {device}")

    try:
        # First, train the model if requested; either call leaves the model in
        # the module-level cache that the processing functions read from
        if train:
            train_model(input_dir, watermark_path, epochs, batch_size, learning_rate)
        else:
            # Just load the model
            get_model()

        # Process the images
        num_processed = process_images(
            input_dir,
            watermark_path,
            output_unwatermarked,
            output_watermarked,
            num_workers=workers
        )

        print(f"Processing complete! Processed {num_processed} images using {workers} workers.")
        return num_processed
    except Exception as e:
        print(f"Error during processing: {str(e)}")
        import traceback
        traceback.print_exc()
        return 0

if __name__ == "__main__":
    # Paths and directories
    watermark_path = "logo.webp"  # Use your watermark image
    output_unwatermarked = "output/unwatermarked"
    output_watermarked = "output/watermarked"

    # Check if watermark image exists, if not create a simple one
    if not os.path.exists(watermark_path):
        print("Creating a sample watermark image...")
        from PIL import Image, ImageDraw

        # White text on a black background. The file is stored as RGB, so a
        # transparent white canvas with white text would come out uniformly white.
        watermark = Image.new('RGB', (200, 100), (0, 0, 0))
        draw = ImageDraw.Draw(watermark)

        # Draw text on the image
        draw.text((10, 10), "WATERMARK", fill=(255, 255, 255))

        # The output format follows the extension of watermark_path
        watermark.save(watermark_path)
        print(f"Sample watermark created at {watermark_path}")

    # Download the dataset
    input_dir = download_and_prepare_kaggle_dataset(sample_limit=100)

    # First train the model
    model = train_model(
        input_dir=input_dir,
        watermark_path=watermark_path,
        epochs=15,           # Number of training epochs
        batch_size=16,       # Training batch size
        learning_rate=0.001  # Learning rate
    )

    # Determine number of workers based on CPU cores; os.cpu_count() may return
    # None, and the count is capped at 8 to avoid excessive resource usage
    num_workers = min(os.cpu_count() or 1, 8)

    # Process the images after training
    num_processed = process_images(
        input_dir,
        watermark_path,
        output_unwatermarked,
        output_watermarked,
        num_workers=num_workers
    )

    print(f"Processing complete! Processed {num_processed} images using {num_workers} workers.")

Using device: mps
Downloading watermarked/non-watermarked dataset from Kaggle...
Dataset downloaded to: /Users/anurag/.cache/kagglehub/datasets/felicepollano/watermarked-not-watermarked-images/versions/1
Dataset preparation complete:
- Total images copied: 0
- Watermarked images: 0
- Non-watermarked images: 0
Training watermark model...
- Input directory: input_images
- Watermark path: logo.webp
- Epochs: 15
- Batch size: 16
- Learning rate: 0.001


ValueError: No training data found in input directory