In [1]:
import os
import shutil
import random
from sklearn.model_selection import train_test_split
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import numpy as np
from PIL import Image

# Data loading and splitting

In [None]:
import os
import shutil
import random
from sklearn.model_selection import train_test_split
from collections import defaultdict

# Seed the global `random` module so the sampling performed below is repeatable.
random.seed(42)

# Source locations: content images are grouped in category subfolders under
# Content/images/images, while style images sit directly inside style/.
base_dir = "."
content_base_dir = os.path.join(base_dir, "Content", "images", "images")
style_dir = os.path.join(base_dir, "style")

# Destination folders for the train / validation / test split of each image set.
content_train_dir, content_val_dir, content_test_dir = (
    os.path.join(base_dir, "Content", split_name)
    for split_name in ("train", "validation", "test")
)
style_train_dir, style_val_dir, style_test_dir = (
    os.path.join(style_dir, split_name)
    for split_name in ("train", "validation", "test")
)
# Check if split directories already contain images
def check_existing_splits(dirs):
    """Report whether any split folder already holds files.

    Args:
        dirs: Iterable of directory paths to inspect.

    Returns:
        True if at least one path exists and contains one or more entries;
        False otherwise (missing or empty folders are ignored, and an empty
        ``dirs`` yields False).
    """
    return any(os.path.exists(folder) and len(os.listdir(folder)) > 0 for folder in dirs)

split_dirs = [content_train_dir, content_val_dir, content_test_dir,
              style_train_dir, style_val_dir, style_test_dir]

if check_existing_splits(split_dirs):
    # A single non-empty split folder is enough to skip, so a partially completed
    # earlier run is treated as finished as well.
    # NOTE(review): the saved output of this cell shows 4876 + 1049 + 1050 = 6975
    # content images instead of 7000, consistent with the basename-collision
    # overwrites fixed in move_images below; splits already on disk are not repaired.
    print("Existing splits found. Skipping sampling and splitting.")
    print("Using images in Content/train/, Content/validation/, Content/test/, style/train/, style/validation/, style/test/.")
else:
    # File extensions (matched case-insensitively) treated as images by this step.
    image_extensions = (".jpg", ".jpeg", ".png")

    def check_input_directories(content_dir, style_dir):
        """Raise FileNotFoundError unless content_dir has a subfolder and style_dir has an image file."""
        if not os.path.exists(content_dir) or not any(os.path.isdir(os.path.join(content_dir, d)) for d in os.listdir(content_dir)):
            raise FileNotFoundError(f"Content directory {content_dir} is missing or empty. Please populate it with images in subfolders (e.g., architecture, art and culture).")
        if not os.path.exists(style_dir) or not any(f.lower().endswith(image_extensions) for f in os.listdir(style_dir)):
            raise FileNotFoundError(f"Style directory {style_dir} is missing or empty. Please populate it with style images.")

    def get_image_paths(base_directory):
        """Map each immediate subfolder (category) of base_directory to the image paths inside it."""
        image_paths = defaultdict(list)
        for category in os.listdir(base_directory):
            category_path = os.path.join(base_directory, category)
            if os.path.isdir(category_path):
                image_paths[category] = [os.path.join(category_path, f) for f in os.listdir(category_path)
                                         if f.lower().endswith(image_extensions)]
        return image_paths

    def list_flat_images(directory):
        """Return paths of the image files located directly in directory (subfolders are skipped)."""
        return [os.path.join(directory, f) for f in os.listdir(directory)
                if f.lower().endswith(image_extensions)]

    def unique_destination(destination_dir, filename):
        """Return a path for filename inside destination_dir that does not exist yet.

        Content images come from several category subfolders, which may share file
        names. shutil.move onto an existing file path replaces it (os.rename semantics
        on POSIX), so moving them flat by basename silently lost images. A numeric
        suffix ("name_1.jpg", "name_2.jpg", ...) keeps every file.
        """
        stem, ext = os.path.splitext(filename)
        candidate = os.path.join(destination_dir, filename)
        counter = 1
        while os.path.exists(candidate):
            candidate = os.path.join(destination_dir, f"{stem}_{counter}{ext}")
            counter += 1
        return candidate

    def move_images(image_paths, destination_dir):
        """Move each image into destination_dir without overwriting files already there.

        Best-effort: failures are reported and the remaining images are still moved.
        """
        for img_path in image_paths:
            try:
                shutil.move(img_path, unique_destination(destination_dir, os.path.basename(img_path)))
            except Exception as e:
                print(f"Failed to move {img_path}: {e}")

    # Check input directories
    check_input_directories(content_base_dir, style_dir)

    # Create output directories if they don't exist
    for directory in split_dirs:
        os.makedirs(directory, exist_ok=True)

    # Load all image paths
    content_images = get_image_paths(content_base_dir)
    style_images = list_flat_images(style_dir)

    # Print initial counts
    total_content = sum(len(images) for images in content_images.values())
    print(f"Total content images: {total_content}")
    print(f"Total style images: {len(style_images)}")

    # Step 1: Sample the datasets
    # Target: 7,000 content images, 3,000 style images
    target_content = 7000
    target_style = 3000

    # Sample content proportionally from each category. Each share is floored, so the
    # total can fall slightly short of target_content; the shortfall is topped up below.
    content_sampled = []
    if total_content > 0:
        for category, images in content_images.items():
            category_target = int((len(images) / total_content) * target_content)
            category_target = min(category_target, len(images))  # Don't exceed available images
            content_sampled.extend(random.sample(images, category_target))

    if len(content_sampled) > target_content:
        content_sampled = random.sample(content_sampled, target_content)
    elif len(content_sampled) < target_content and total_content > len(content_sampled):
        # Set membership keeps this linear; testing against the list was O(n^2).
        already_sampled = set(content_sampled)
        remaining_pool = [img for images in content_images.values() for img in images
                          if img not in already_sampled]
        shortfall = min(target_content - len(content_sampled), len(remaining_pool))
        content_sampled.extend(random.sample(remaining_pool, shortfall))

    # Sample style images (all of them when fewer than target_style exist)
    style_sampled = random.sample(style_images, target_style) if len(style_images) >= target_style else style_images

    print(f"Sampled content images: {len(content_sampled)}")
    print(f"Sampled style images: {len(style_sampled)}")

    # Step 2: Split the data (70/15/15): 70% train, then the remaining 30% halved
    content_train, content_temp = train_test_split(content_sampled, train_size=0.7, random_state=42)
    content_val, content_test = train_test_split(content_temp, test_size=0.5, random_state=42)

    style_train, style_temp = train_test_split(style_sampled, train_size=0.7, random_state=42)
    style_val, style_test = train_test_split(style_temp, test_size=0.5, random_state=42)

    print(f"Content split - Train: {len(content_train)}, Val: {len(content_val)}, Test: {len(content_test)}")
    print(f"Style split - Train: {len(style_train)}, Val: {len(style_val)}, Test: {len(style_test)}")

    # Step 3: Move images to respective split directories
    move_images(content_train, content_train_dir)
    move_images(content_val, content_val_dir)
    move_images(content_test, content_test_dir)

    move_images(style_train, style_train_dir)
    move_images(style_val, style_val_dir)
    move_images(style_test, style_test_dir)

    # Step 4: Verify splits before deleting remaining images
    print("Verifying splits before deleting remaining images...")
    print(f"Content train images: {len(os.listdir(content_train_dir))}")
    print(f"Content validation images: {len(os.listdir(content_val_dir))}")
    print(f"Content test images: {len(os.listdir(content_test_dir))}")
    print(f"Style train images: {len(os.listdir(style_train_dir))}")
    print(f"Style validation images: {len(os.listdir(style_val_dir))}")
    print(f"Style test images: {len(os.listdir(style_test_dir))}")

    # Step 5: Delete remaining (unused) images in original folders (irreversible).
    # Split folders live in Content/<split> and style/<split>; they are not picked up
    # here because get_image_paths scans Content/images/images and list_flat_images
    # only returns files directly in style/.
    remaining_content_images = get_image_paths(content_base_dir)
    remaining_style_images = list_flat_images(style_dir)

    print("Deleting remaining content images...")
    for category, images in remaining_content_images.items():
        category_path = os.path.join(content_base_dir, category)
        for img_path in images:
            try:
                os.remove(img_path)
                print(f"Deleted {img_path}")
            except Exception as e:
                print(f"Failed to delete {img_path}: {e}")
        # Remove category folder if empty
        if not os.listdir(category_path):
            shutil.rmtree(category_path)
            print(f"Removed empty category folder: {category_path}")

    print("Deleting remaining style images...")
    for img_path in remaining_style_images:
        try:
            os.remove(img_path)
            print(f"Deleted {img_path}")
        except Exception as e:
            print(f"Failed to delete {img_path}: {e}")

# Final verification (runs on both the fresh-split and the skip path):
# report how many directory entries each split folder now holds.
print("Final folder structure:")
for label, folder in (
    ("Content train", content_train_dir),
    ("Content validation", content_val_dir),
    ("Content test", content_test_dir),
    ("Style train", style_train_dir),
    ("Style validation", style_val_dir),
    ("Style test", style_test_dir),
):
    print(f"{label} images: {len(os.listdir(folder))}")

Existing splits found. Skipping sampling and splitting.
Using images in Content/train/, Content/validation/, Content/test/, style/train/, style/validation/, style/test/.
Final folder structure:
Content train images: 4876
Content validation images: 1049
Content test images: 1050
Style train images: 2100
Style validation images: 450
Style test images: 450


# Image Preprocessing and Dataset Creation

In [5]:
from torch.utils.data import Dataset
import itertools

# Define image transformations for preprocessing
def get_transforms(img_size=256):
    """Build the preprocessing pipeline shared by content and style images.

    Steps: resize to an ``img_size`` x ``img_size`` square, convert to a tensor,
    then normalize each channel with the ImageNet mean/std constants below.

    Args:
        img_size: Target height and width in pixels.

    Returns:
        A ``transforms.Compose`` applying the steps above in order.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)

# Custom dataset for neural style transfer
class StyleTransferDataset(Dataset):
    """Content/style image pairs for neural style transfer.

    Each item pairs one content image with one style image plus a randomly drawn
    threshold value. Pairs come from the Cartesian product of the two image folders,
    optionally capped to a reproducible random subset.
    """

    # File extensions (matched case-insensitively) that are treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, content_dir, style_dir, transform=None, threshold_values=None, max_samples=None,
                 seed=42):
        """
        Args:
            content_dir: Directory with content images
            style_dir: Directory with style images
            transform: Optional transform to be applied on images
            threshold_values: List of threshold values for style transfer intensity
            max_samples: Maximum number of content-style pairs to use (None = use all combinations)
            seed: Seed of the private RNG that picks the capped subset of pairs
                  (42 matches the value that used to be hard-coded)
        """
        # Sorted so pair indices, and therefore the seeded subset, do not depend on the
        # arbitrary order os.listdir returns on a particular filesystem.
        self.content_paths = self._list_images(content_dir)
        self.style_paths = self._list_images(style_dir)
        self.transform = transform

        # Default threshold values if none provided:
        # values between 0.0 (content-focused) and 1.0 (style-focused)
        if threshold_values is None:
            self.threshold_values = [0.2, 0.4, 0.6, 0.8]
        else:
            self.threshold_values = threshold_values

        n_content = len(self.content_paths)
        n_style = len(self.style_paths)
        n_pairs = n_content * n_style

        if max_samples is not None and 0 < max_samples < n_pairs:
            # Sample flat pair indices instead of materialising and shuffling the whole
            # product (over 10M tuples for the training split). Flat index k encodes the
            # pair (k // n_style, k % n_style), matching itertools.product order.
            # A private Random leaves the global `random` state untouched.
            rng = random.Random(seed)
            self.combinations = [divmod(k, n_style) for k in rng.sample(range(n_pairs), max_samples)]
        else:
            # No cap (or cap >= available pairs): every combination, in product order.
            self.combinations = list(itertools.product(range(n_content), range(n_style)))

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted paths of image files located directly inside ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and close the underlying file handle."""
        # convert() returns a loaded copy, so the image stays usable after `with` closes the file.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        return len(self.combinations)

    def __getitem__(self, idx):
        """Return a dict with 'content', 'style', 'threshold', 'content_path' and 'style_path'."""
        content_idx, style_idx = self.combinations[idx]
        content_path = self.content_paths[content_idx]
        style_path = self.style_paths[style_idx]

        content_img = self._load_rgb(content_path)
        style_img = self._load_rgb(style_path)

        if self.transform:
            content_img = self.transform(content_img)
            style_img = self.transform(style_img)

        # Drawn from the global `random` module on every access, so the same idx can
        # receive a different threshold each time it is fetched (e.g. across epochs).
        threshold = random.choice(self.threshold_values)

        return {
            'content': content_img,
            'style': style_img,
            'threshold': torch.tensor(threshold, dtype=torch.float32),
            'content_path': content_path,
            'style_path': style_path,
        }

In [6]:
# Create DataLoaders for training, validation, and testing
def create_dataloaders(batch_size=16, num_workers=4, img_size=256, threshold_values=None,
                      max_train_samples=5000, max_val_samples=1000, max_test_samples=500,
                      data_root=""):
    """Build train/validation/test DataLoaders over the split image folders.

    For each split in (train, validation, test), pairs ``<data_root>/Content/<split>``
    with ``<data_root>/style/<split>`` through StyleTransferDataset. It also prints how
    many pairs are used out of the possible combinations.

    Args:
        batch_size: Samples per batch for all three loaders.
        num_workers: Worker processes per loader.
        img_size: Side length passed to get_transforms.
        threshold_values: Forwarded to StyleTransferDataset (None = its defaults).
        max_train_samples, max_val_samples, max_test_samples: Per-split cap on pairs
            (None = all combinations).
        data_root: Directory containing Content/ and style/. The default "" keeps the
            original behaviour of resolving paths relative to the working directory.

    Returns:
        (train_loader, val_loader, test_loader); only the training loader shuffles.
    """
    transform = get_transforms(img_size=img_size)

    # (folder name, label in the summary line, pair cap, shuffle?)
    split_specs = (
        ("train", "Train", max_train_samples, True),
        ("validation", "Validation", max_val_samples, False),
        ("test", "Test", max_test_samples, False),
    )

    loaders = []
    for split, label, max_samples, shuffle in split_specs:
        dataset = StyleTransferDataset(
            os.path.join(data_root, "Content", split),
            os.path.join(data_root, "style", split),
            transform,
            threshold_values,
            max_samples=max_samples,
        )
        # Count from the dataset's extension-filtered path lists rather than raw
        # os.listdir, which also counted any non-image files in the split folders.
        possible = len(dataset.content_paths) * len(dataset.style_paths)
        print(f"{label} dataset: Using {len(dataset)} samples out of {possible} possible combinations")
        loaders.append(DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
        ))

    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader

In [None]:
# Verify the dataset and data loaders
# Note: Only run this after data preparation is complete

# Create data loaders with specified batch size and sample limits
# NOTE(review): the saved output of this cell (8000 train / 100 test samples) does not
# match the limits below; it comes from an earlier run. Re-run to refresh it.
batch_size = 8  # Smaller batch size for testing
train_loader, val_loader, test_loader = create_dataloaders(
    batch_size=batch_size,
    num_workers=2,
    max_train_samples=7000,  # Limit train samples to 7000
    max_val_samples=1000,    # Limit validation samples to 1000
    max_test_samples=1000    # Limit test samples to 1000
)

# Print dataset sizes
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")

# Check a single batch: a fresh iterator yields the first shuffled training batch
sample_batch = next(iter(train_loader))
print(f"\nBatch shape for content images: {sample_batch['content'].shape}")
print(f"Batch shape for style images: {sample_batch['style'].shape}")
print(f"Threshold values: {sample_batch['threshold']}")

# Invert the Normalize step of get_transforms so tensors can be displayed as images.
def denormalize(tensor):
    """Undo per-channel normalization: returns ``tensor * std + mean``.

    Uses the same mean/std constants as get_transforms.

    Args:
        tensor: A single image ``(3, H, W)`` or a batch ``(N, 3, H, W)``. Channels must be
            the third-from-last axis.

    Returns:
        A tensor with the same shape as the input, on the same device. Floating inputs
        keep their dtype; other dtypes are computed in float32.
    """
    # Build the stats on the input's device/dtype so CUDA or half-precision tensors work.
    # Integer inputs get float32 stats, since an integer dtype would truncate them to 0.
    dtype = tensor.dtype if tensor.is_floating_point() else torch.float32
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device, dtype=dtype).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device, dtype=dtype).view(3, 1, 1)
    # (3, 1, 1) broadcasts over both 3-D and 4-D inputs without adding a batch axis.
    return tensor * std + mean

# Display the first few pairs of the sample batch: content images on the top row,
# the matching style images (with their threshold) on the bottom row.
import matplotlib.pyplot as plt

# Use the actual batch length: a batch can be smaller than batch_size when the dataset is.
n_show = min(4, sample_batch['content'].shape[0])
fig, axes = plt.subplots(2, n_show, figsize=(12, 6), squeeze=False)
for i in range(n_show):
    content_img = denormalize(sample_batch['content'][i:i+1]).squeeze(0).permute(1, 2, 0).numpy()
    axes[0, i].imshow(np.clip(content_img, 0, 1))
    axes[0, i].set_title(f"Content {i}")
    axes[0, i].axis('off')

    style_img = denormalize(sample_batch['style'][i:i+1]).squeeze(0).permute(1, 2, 0).numpy()
    axes[1, i].imshow(np.clip(style_img, 0, 1))
    axes[1, i].set_title(f"Style {i} (t={sample_batch['threshold'][i]:.2f})")
    axes[1, i].axis('off')

fig.tight_layout()
plt.show()

Train dataset: Using 8000 samples out of 10239600 possible combinations
Validation dataset: Using 1000 samples out of 472050 possible combinations
Test dataset: Using 100 samples out of 472500 possible combinations
Training samples: 8000
Validation samples: 1000
Test samples: 100


: 

: 