In [None]:
import os
import shutil
import threading
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import kagglehub
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

# Use a GPU backend if available: CUDA first, then Apple MPS, otherwise CPU.
# `torch.backends.mps` is looked up defensively because older PyTorch builds
# do not expose that attribute.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Define the CNN model with batch normalization and more efficient architecture
class WatermarkCNN(nn.Module):
    """Small convolutional encoder/decoder pair used for image watermarking.

    ``encoder`` maps a 3-channel input to a 3-channel output and ``decoder``
    maps a 3-channel input to a single-channel map passed through a sigmoid.
    All convolutions use a 3x3 kernel with padding 1, so the spatial size of
    the input is preserved. The module moves itself to the module-level
    ``device`` when it is constructed.
    """

    def __init__(self):
        super().__init__()
        # Encoder with batch normalization for faster convergence and training stability
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),  # inplace operations save memory
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

        # Decoder with batch normalization
        self.decoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

        # Move model to GPU if available
        self.to(device)

    def encode(self, image, watermark):
        """Add ``watermark`` to ``image`` and run the sum through the encoder.

        Args:
            image: Tensor indexed as (batch, channels, height, width). The
                encoder's first convolution expects 3 channels.
            watermark: PIL image in any mode; it is converted to RGB and
                resized to the spatial size of ``image``.

        Returns:
            The encoder output: a tensor with 3 channels and the same batch
            and spatial size as ``image``. Values are not bounded.
        """
        batch_size, height, width = image.shape[0], image.shape[2], image.shape[3]

        # Force three channels so RGBA / palette / grayscale watermarks can be
        # added to a 3-channel image. PIL's resize takes (width, height).
        watermark = watermark.convert("RGB").resize((width, height))

        # The sum below needs both operands on the same device as the image.
        watermark_tensor = transforms.ToTensor()(watermark).unsqueeze(0).to(image.device)

        # Match the batch dimension with a view instead of copying the
        # watermark once per batch item.
        watermark_tensor = watermark_tensor.expand(batch_size, -1, -1, -1)

        return self.encoder(image + watermark_tensor)

    def decode(self, watermarked_image):
        """Return the decoder's single-channel map for ``watermarked_image``.

        Args:
            watermarked_image: Tensor with 3 channels in dimension 1.

        Returns:
            Tensor with 1 channel and values in (0, 1) (sigmoid output).
        """
        return self.decoder(watermarked_image)

# Cache for the model instance to avoid reloading
model_cache = None
# Guards the lazy construction in get_model(): the helpers that call it run
# inside the thread pool started by process_images, so several threads can
# reach the "cache is empty" check at the same time.
_model_cache_lock = threading.Lock()

def get_model():
    """Return the shared ``WatermarkCNN`` instance, creating it on first use.

    The model is built once, switched to evaluation mode and stored in the
    module-level ``model_cache``. Evaluation mode matters because the network
    contains BatchNorm layers, which otherwise normalise with per-batch
    statistics and keep updating their running statistics during inference.

    Returns:
        The cached ``WatermarkCNN`` instance.
    """
    global model_cache
    if model_cache is None:
        with _model_cache_lock:
            # Re-check under the lock: another thread may have finished
            # building the model while this one was waiting.
            if model_cache is None:
                model = WatermarkCNN()
                # Here you would load pre-trained weights if available
                # model.load_state_dict(torch.load('watermark_model.pth'))
                model.eval()
                # Publish only the fully prepared model.
                model_cache = model
    return model_cache

# Watermark detection for a single image file
def detect_watermark(image_path, threshold=0.05):
    """Decide whether the image at ``image_path`` appears to carry a watermark.

    The image is loaded as RGB, converted to a tensor with a leading batch
    dimension of 1, moved to ``device`` and passed through the model's
    decoder. The image is reported as watermarked when the mean of the
    decoder output exceeds ``threshold``.

    Args:
        image_path: Path of the image file to inspect.
        threshold: Mean decoder activation above which a watermark is
            reported. Defaults to 0.05.

    Returns:
        A ``(has_watermark, image_tensor)`` tuple: a Python bool and the
        image tensor that was fed to the decoder.
    """
    model = get_model()

    # The context manager releases the file handle as soon as the pixel data
    # has been converted, instead of leaving it to garbage collection.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)

    with torch.no_grad():
        watermark = model.decode(image_tensor)
        has_watermark = torch.mean(watermark) > threshold

    return has_watermark.item(), image_tensor

# Watermark removal
def remove_watermark(image_tensor):
    """Subtract the decoder's predicted watermark map from ``image_tensor``.

    Args:
        image_tensor: Image tensor accepted by the model's decoder
            (3 channels in dimension 1). Assumed to hold values in [0, 1],
            as produced by ``ToTensor`` in ``detect_watermark``.

    Returns:
        The image with the predicted watermark subtracted, clamped to
        [0, 1] so it remains a valid image.
    """
    model = get_model()
    with torch.no_grad():
        watermark = model.decode(image_tensor)
        # The decoder output is strictly positive (sigmoid), so the raw
        # difference can fall below 0; clamp it back into the image range.
        cleaned_image = (image_tensor - watermark).clamp(0.0, 1.0)
    return cleaned_image

# Watermark addition
def add_watermark(image_tensor, watermark):
    """Embed ``watermark`` into ``image_tensor`` using the shared model.

    Runs under ``torch.no_grad()`` like the other inference helpers in this
    file, so no autograd graph is built and the result does not require
    gradients.

    Args:
        image_tensor: Image tensor indexed as (batch, channels, height, width).
        watermark: PIL image to embed.

    Returns:
        The tensor returned by the model's ``encode`` method.
    """
    model = get_model()
    with torch.no_grad():
        return model.encode(image_tensor, watermark)

# Save a single-image batch tensor as an image file
def _save_image_tensor(image_tensor, output_dir, filename):
    """Write ``image_tensor`` (batch dimension of 1) to ``output_dir/filename``."""
    # ToPILImage scales float tensors by 255 and casts to 8-bit without
    # clamping, so values outside [0, 1] would wrap around; clamp first.
    # detach() keeps this safe for tensors that still carry autograd history.
    pixels = image_tensor.detach().squeeze(0).clamp(0.0, 1.0).cpu()
    transforms.ToPILImage()(pixels).save(os.path.join(output_dir, filename))

# Process single image
def process_single_image(filename, input_dir, watermark_img, output_unwatermarked, output_watermarked):
    """Remove a detected watermark from one image, or add one if none is found.

    Args:
        filename: Name of the image file inside ``input_dir``; the same name
            is used for the output file.
        input_dir: Directory containing the input image.
        watermark_img: PIL image used when a watermark has to be added.
        output_unwatermarked: Directory for images whose watermark was removed.
        output_watermarked: Directory for images that received a watermark.

    Returns:
        A status message describing what was done, or an error message if
        processing failed. Exceptions are not propagated.
    """
    img_path = os.path.join(input_dir, filename)

    try:
        has_watermark, image_tensor = detect_watermark(img_path)

        if has_watermark:
            _save_image_tensor(remove_watermark(image_tensor), output_unwatermarked, filename)
            return f"Removed watermark from: {filename}"

        _save_image_tensor(add_watermark(image_tensor, watermark_img), output_watermarked, filename)
        return f"Added watermark to: {filename}"
    except Exception as e:
        # Deliberate best-effort: one unreadable or unsupported file must not
        # abort the whole batch, so the failure is reported as a message.
        return f"Error processing {filename}: {str(e)}"

# Main processing function with parallel execution
def process_images(input_dir, watermark_path, output_unwatermarked, output_watermarked, num_workers=None):
    """Run ``process_single_image`` over every image file in ``input_dir``.

    Args:
        input_dir: Directory whose image files are processed.
        watermark_path: Path of the watermark image; loaded once as RGB.
        output_unwatermarked: Output directory for de-watermarked images.
        output_watermarked: Output directory for watermarked images.
        num_workers: ``max_workers`` passed to the thread pool.

    Returns:
        The number of image files found in ``input_dir``.
    """
    # Both destination folders must exist before any worker writes to them.
    for out_dir in (output_unwatermarked, output_watermarked):
        os.makedirs(out_dir, exist_ok=True)

    # Decode the watermark a single time; every worker receives this image.
    watermark_img = Image.open(watermark_path).convert("RGB")

    # Collect the directory entries that look like images by extension.
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    image_files = []
    for entry in os.listdir(input_dir):
        if entry.lower().endswith(image_suffixes):
            image_files.append(entry)

    def _handle(name):
        # Bind the per-run arguments; only the file name varies per task.
        return process_single_image(
            name,
            input_dir=input_dir,
            watermark_img=watermark_img,
            output_unwatermarked=output_unwatermarked,
            output_watermarked=output_watermarked,
        )

    # map() yields results in input order, so messages follow image_files.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        messages = list(pool.map(_handle, image_files))

    # Report the outcome of every file.
    for message in messages:
        print(message)

    return len(image_files)

# Function to download and prepare Kaggle dataset
def download_and_prepare_kaggle_dataset():
    """Download the Kaggle dataset and copy its images into ``input_images``.

    Every file under the downloaded dataset directory whose extension looks
    like an image is copied into one flat directory. When two files in
    different subdirectories share a name, the later one is stored under a
    name prefixed with its relative directory instead of silently
    overwriting the earlier copy.

    Returns:
        The path of the flat input directory (``"input_images"``).
    """
    print("Downloading watermarked/non-watermarked dataset from Kaggle...")
    dataset_path = kagglehub.dataset_download("felicepollano/watermarked-not-watermarked-images")
    print(f"Dataset downloaded to: {dataset_path}")

    # Create input directory for our processing
    input_dir = "input_images"
    os.makedirs(input_dir, exist_ok=True)

    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    # Destination names handed out during this run. Keys are case-folded so
    # that names differing only by case do not clash on case-insensitive
    # file systems. Files left over from a previous run are still overwritten.
    used_names = set()

    # Copy all images from the dataset tree into the flat input directory
    for root, dirs, files in os.walk(dataset_path):
        dirs.sort()  # in-place sort makes the traversal order deterministic
        for file in sorted(files):
            if not file.lower().endswith(image_extensions):
                continue

            dest_name = file
            if dest_name.casefold() in used_names:
                # Same base name already copied from another directory:
                # disambiguate with the directory path relative to the dataset.
                prefix = os.path.relpath(root, dataset_path).replace(os.sep, "_")
                dest_name = f"{prefix}_{file}"
                counter = 1
                while dest_name.casefold() in used_names:
                    dest_name = f"{prefix}_{counter}_{file}"
                    counter += 1
            used_names.add(dest_name.casefold())

            shutil.copy2(os.path.join(root, file), os.path.join(input_dir, dest_name))
            if dest_name == file:
                print(f"Copied {file} to input directory")
            else:
                print(f"Copied {file} to input directory as {dest_name}")

    return input_dir

if __name__ == "__main__":
    # Download and prepare the Kaggle dataset
    input_dir = download_and_prepare_kaggle_dataset()

    # Paths and directories
    watermark_path = "watermark.png"  # Created below if it does not exist yet
    output_unwatermarked = "output/unwatermarked"
    output_watermarked = "output/watermarked"

    # Check if watermark image exists, if not create a simple one
    if not os.path.exists(watermark_path):
        print("Creating a sample watermark image...")
        from PIL import ImageDraw

        # The watermark is added to the image tensor, so a black background
        # contributes nothing and only the text pixels change the image.
        # (White text on a transparent white canvas flattened to RGB would
        # produce a uniformly white image with no visible text.)
        watermark = Image.new("RGB", (200, 100), (0, 0, 0))
        draw = ImageDraw.Draw(watermark)

        # Mid-gray text: a half-strength overlay once added to the image
        draw.text((10, 10), "WATERMARK", fill=(128, 128, 128))

        watermark.save(watermark_path)
        print(f"Sample watermark created at {watermark_path}")

    # Determine number of workers based on CPU cores, capped at 8 to avoid
    # excessive resource usage. os.cpu_count() may return None when the count
    # cannot be determined, so fall back to a single worker.
    num_workers = min(os.cpu_count() or 1, 8)

    # Process the images
    num_processed = process_images(
        input_dir,
        watermark_path,
        output_unwatermarked,
        output_watermarked,
        num_workers=num_workers
    )

    print(f"Processing complete! Processed {num_processed} images using {num_workers} workers.")

Downloading watermarked/non-watermarked dataset from Kaggle...
Dataset downloaded to: /Users/anurag/.cache/kagglehub/datasets/felicepollano/watermarked-not-watermarked-images/versions/1
Copied pexels-photo-4393532.jpeg to input directory
Copied pexels-photo-3614167.jpeg to input directory
Copied pexels-photo-442559.jpeg to input directory
Copied pexels-photo-1309052.jpeg to input directory
Copied pexels-photo-2587319.jpeg to input directory
Copied pexels-photo-3150551.jpeg to input directory
Copied pexels-photo-3760514.jpeg to input directory
Copied pexels-photo-164022.jpeg to input directory
Copied pexels-photo-246805.jpeg to input directory
Copied pexels-photo-2263670.jpeg to input directory
Copied pexels-photo-1458373.jpeg to input directory
Copied pexels-photo-274974.jpeg to input directory
Copied pexels-photo-4430883.jpeg to input directory
Copied pexels-photo-988610.jpeg to input directory
Copied pexels-photo-1490908.jpeg to input directory
Copied pexels-photo-3760778.jpeg to inp