In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob
from sklearn.model_selection import train_test_split

In [14]:
class PlacesDataset(Dataset):
    def __init__(self, root, transform=None, train=True, split_ratio=0.8):
        self.root = root
        self.transform = transform
        self.image_files = sorted(glob.glob(os.path.join(root, "**/*.jpg"), recursive=True))

        if len(self.image_files) == 0:
            raise RuntimeError(f"No images found in {root}. Check folder structure!")

        # Split dataset into train & validation
        train_files, val_files = train_test_split(self.image_files, train_size=split_ratio, random_state=42)
        self.image_files = train_files if train else val_files  # Use train or val set

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]

        try:
            real = Image.open(img_path).convert("RGB")  # Load real image
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return self.__getitem__((idx + 1) % len(self))  # Skip to next image

        if self.transform:
            real = self.transform(real)

        return real  # Only return real image (no edges)

# Define dataset path for Places
places_path = r"D:\scribbles\places"


# Define image transformations
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize to match DCGAN generator output
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1,1] for Tanh activation
])

# Load train and validation datasets
places_train_dataset = PlacesDataset(places_path, transform=transform, train=True)
places_val_dataset = PlacesDataset(places_path, transform=transform, train=False)

# Create dataloaders
batch_size = 16
places_train_dataloader = DataLoader(places_train_dataset, batch_size=batch_size, shuffle=True)
places_val_dataloader = DataLoader(places_val_dataset, batch_size=batch_size, shuffle=False)

# Print dataset sizes
print(f"Places Train Dataset Loaded: {len(places_train_dataset)} images")
print(f"Places Validation Dataset Loaded: {len(places_val_dataset)} images")

Places Train Dataset Loaded: 2428 images
Places Validation Dataset Loaded: 608 images


In [12]:
import torch
import torch.nn as nn

class DCGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, img_channels=3, features_g=64):
        super(DCGANGenerator, self).__init__()

        self.net = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, features_g * 16, 4, 1, 0, bias=False),  # 4x4
            nn.BatchNorm2d(features_g * 16),
            nn.ReLU(True),

            nn.ConvTranspose2d(features_g * 16, features_g * 8, 4, 2, 1, bias=False),  # 8x8
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),  # 16x16
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),  # 32x32
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),  # 64x64
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),

            nn.ConvTranspose2d(features_g, img_channels, 4, 2, 1, bias=False),  # 128x128
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z)

print(" Updated Generator for 128x128 resolution")



 Updated Generator for 128x128 resolution


In [6]:
class DCGANDiscriminator(nn.Module):
    def __init__(self, img_channels=3, features_d=64):
        super(DCGANDiscriminator, self).__init__()

        self.net = nn.Sequential(
            nn.Conv2d(img_channels, features_d, 4, 2, 1, bias=False),  # 128x128 → 64x64
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),  # 64x64 → 32x32
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),  # 32x32 → 16x16
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1, bias=False),  # 16x16 → 8x8
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features_d * 8, features_d * 16, 4, 2, 1, bias=False),  # 8x8 → 4x4
            nn.BatchNorm2d(features_d * 16),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features_d * 16, 1, 4, 1, 0, bias=False),  # 4x4 → 1x1
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.net(img)

print(" Updated Discriminator for 128x128 resolution")


 Updated Discriminator for 128x128 resolution


In [8]:
# Define dataset path for Places
places_path = r"D:\scribbles\places"

# Define image transformations (Set resolution to 128x128)
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Ensure all images are resized to 128x128
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1,1] for Tanh activation
])

# Load train and validation datasets
places_train_dataset = PlacesDataset(places_path, transform=transform, train=True)
places_val_dataset = PlacesDataset(places_path, transform=transform, train=False)

# Create dataloaders
batch_size = 16  # Reduce if memory runs out
places_train_dataloader = DataLoader(places_train_dataset, batch_size=batch_size, shuffle=True)
places_val_dataloader = DataLoader(places_val_dataset, batch_size=batch_size, shuffle=False)

print(f"Places Train Dataset Loaded: {len(places_train_dataset)} images")
print(f"Places Validation Dataset Loaded: {len(places_val_dataset)} images")


Places Train Dataset Loaded: 2428 images
Places Validation Dataset Loaded: 608 images


In [16]:
# Adjust Generator and Discriminator
places_generator = DCGANGenerator(latent_dim=100)
places_discriminator = DCGANDiscriminator()

# Loss and Optimizers
criterion = nn.BCELoss()
places_optimizer_G = torch.optim.Adam(places_generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
places_optimizer_D = torch.optim.Adam(places_discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

print(" Loss functions and optimizers initialized for Places dataset!")

 Loss functions and optimizers initialized for Places dataset!


In [20]:
# Define the models
places_generator = DCGANGenerator()
places_discriminator = DCGANDiscriminator()

# Load the saved weights
checkpoint = torch.load("F:\places_dcgan_generator_weights.pth", map_location="cpu")
places_generator.load_state_dict(checkpoint, strict=False)
checkpoint = torch.load("F:\places_dcgan_discriminator_weights.pth", map_location="cpu")
places_discriminator.load_state_dict(checkpoint, strict=False)

# Set models to evaluation mode
places_generator.eval()
print(" Generator loaded ")
places_discriminator.eval()
print(" Discriminator loaded ")

RuntimeError: Error(s) in loading state_dict for DCGANGenerator:
	size mismatch for net.0.weight: copying a param with shape torch.Size([100, 512, 4, 4]) from checkpoint, the shape in current model is torch.Size([100, 1024, 4, 4]).
	size mismatch for net.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for net.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for net.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for net.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for net.3.weight: copying a param with shape torch.Size([512, 256, 4, 4]) from checkpoint, the shape in current model is torch.Size([1024, 512, 4, 4]).
	size mismatch for net.4.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for net.4.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for net.4.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for net.4.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for net.6.weight: copying a param with shape torch.Size([256, 128, 4, 4]) from checkpoint, the shape in current model is torch.Size([512, 256, 4, 4]).
	size mismatch for net.7.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.7.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.7.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.7.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.9.weight: copying a param with shape torch.Size([128, 64, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 128, 4, 4]).
	size mismatch for net.10.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for net.10.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for net.10.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for net.10.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for net.12.weight: copying a param with shape torch.Size([64, 3, 4, 4]) from checkpoint, the shape in current model is torch.Size([128, 64, 4, 4]).

# Project Real Images onto the DCGAN Latent Space

In [16]:
from torchvision import models
from torchvision.models import VGG16_Weights

vgg16 = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16].eval()


def perceptual_loss(gen_image, real_image):
    """
    Computes perceptual loss using VGG16 features.
    """
    with torch.no_grad():
        real_features = vgg16(real_image)
    gen_features = vgg16(gen_image)
    return F.mse_loss(gen_features, real_features)

In [17]:
import torch
import torch.nn.functional as F
from torchvision import transforms

def project_image_to_latent(image, generator, num_steps=500, lr=0.1):
    """
    Projects a real image onto the DCGAN latent space by optimizing z.
    """
    generator.eval()  # Set generator to evaluation mode

    # Resize real image to match GAN output (128x128)
    image = transforms.Resize((128, 128))(image)  # Ensure same size
    image = image.unsqueeze(0)  # Add batch dimension
    image = (image / 127.5) - 1  # Normalize to [-1,1] range

    # Initialize latent vector z with random noise
    z = torch.randn(1, 100, 1, 1, requires_grad=True)

    # Define optimizer for z
    optimizer = torch.optim.Adam([z], lr=lr)

    for step in range(num_steps):
        optimizer.zero_grad()

        # Generate image from z
        generated_image = generator(z)

        # Compute loss (Mean Squared Error between real and generated images)
        loss = F.mse_loss(generated_image, image)

        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"Step {step}/{num_steps} - Loss: {loss.item():.6f}")

    return z.detach()

# Example usage:
real_batch = next(iter(places_val_dataloader))  # Get a batch of images
real_image = real_batch[0]  # Extract a single image
real_image = real_image 

projected_z = project_image_to_latent(real_image, places_generator)
print("Image projected onto latent space successfully!")


Step 0/500 - Loss: 0.895149
Step 100/500 - Loss: 0.000091
Step 200/500 - Loss: 0.000068
Step 300/500 - Loss: 0.000058
Step 400/500 - Loss: 0.000051
Image projected onto latent space successfully!


# Perform Image Editing in the Latent Space

In [20]:
def modify_latent_vector(z, generator, edit_direction, intensity=0.5):
    """
    Modifies the latent vector in a specific direction to change color or shape.
    """
    modified_z = z + intensity * edit_direction
    modified_z = modified_z.detach().clone().requires_grad_(True)  # Ensure requires_grad is True

    return modified_z

# Example: Apply a small color change
color_edit_direction = torch.randn_like(projected_z)  # Random color change direction
modified_z = modify_latent_vector(projected_z, places_generator, color_edit_direction, intensity=0.2)

# Generate the edited image
edited_image = places_generator(modified_z)


# Transfer Edits Back to High-Resolution Images

In [22]:
import cv2
import numpy as np

def compute_optical_flow(img1, img2):
    """
    Computes dense optical flow between two images using Farneback method.
    """
    img1_gray = cv2.cvtColor(img1, cv2.COLOR_RGB2GRAY)
    img2_gray = cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY)

    flow = cv2.calcOpticalFlowFarneback(img1_gray, img2_gray, None, 
                                        0.5, 3, 15, 3, 5, 1.2, 0)

    return flow.astype(np.float32)  # Ensure float32 type

In [24]:
def apply_optical_flow(img, flow):
    """
    Applies motion flow to warp the original image.
    """
    h, w = flow.shape[:2]

    # Create a grid of (x,y) coordinates
    y_coords, x_coords = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')

    # Compute remap coordinates by adding flow vectors
    map_x = (x_coords + flow[..., 0]).astype(np.float32)
    map_y = (y_coords + flow[..., 1]).astype(np.float32)

    # Warp image using flow
    warped = cv2.remap(img, map_x, map_y, interpolation=cv2.INTER_LINEAR)

    return warped

# Convert tensors to numpy images
real_img_np = (real_image.cpu().squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
edited_img_np = (edited_image.detach().cpu().squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)

# Compute motion flow
flow = compute_optical_flow(real_img_np, edited_img_np)

# Apply flow to high-resolution original image
final_high_res = apply_optical_flow(real_img_np, flow)

cv2.imwrite("final_high_res_edit.jpg", final_high_res)
print("Edited high-resolution image saved successfully!")

Edited high-resolution image saved successfully!


# Interactive UI

In [35]:
pip install pyqt5 opencv-python numpy torch torchvision


Note: you may need to restart the kernel to use updated packages.


DEPRECATION: Loading egg at c:\users\shashwati\anaconda3\lib\site-packages\apache_beam-2.60.0-py3.11-win-amd64.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330
DEPRECATION: Loading egg at c:\users\shashwati\anaconda3\lib\site-packages\avro_python3-1.10.2-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330
DEPRECATION: Loading egg at c:\users\shashwati\anaconda3\lib\site-packages\contextlib2-21.6.0-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330
DEPRECATION: Loading egg at c:\users\shashwati\anaconda3\lib\site-packages\lvis-0.5.3-py3.11.egg is d

In [26]:
import sys
import torch
import cv2
import numpy as np
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QPushButton, QFileDialog, QVBoxLayout, QWidget
from PyQt5.QtGui import QPixmap, QImage, QPainter, QPen
from PyQt5.QtCore import Qt, QPoint

In [28]:
class ImageEditor(QMainWindow):
    def __init__(self):
        super().__init__()

        self.setWindowTitle("Realistic Image Manipulation UI")
        self.setGeometry(100, 100, 800, 600)

        # UI Elements
        self.image_label = QLabel(self)
        self.image_label.setAlignment(Qt.AlignCenter)

        self.load_button = QPushButton("Load Image", self)
        self.load_button.clicked.connect(self.load_image)

        self.color_brush_button = QPushButton("Color Brush", self)
        self.color_brush_button.clicked.connect(self.use_color_brush)

        self.warp_tool_button = QPushButton("Warp Tool", self)
        self.warp_tool_button.clicked.connect(self.use_warp_tool)

        self.apply_gan_button = QPushButton("Apply GAN Edit", self)
        self.apply_gan_button.clicked.connect(self.apply_gan_edit)

        self.save_button = QPushButton("Save Image", self)
        self.save_button.clicked.connect(self.save_image)

        # Layout
        layout = QVBoxLayout()
        layout.addWidget(self.image_label)
        layout.addWidget(self.load_button)
        layout.addWidget(self.color_brush_button)
        layout.addWidget(self.warp_tool_button)
        layout.addWidget(self.apply_gan_button)
        layout.addWidget(self.save_button)

        container = QWidget()
        container.setLayout(layout)
        self.setCentralWidget(container)

        # Image Variables
        self.image = None
        self.image_path = None
        self.last_point = QPoint()
        self.drawing = False
        self.brush_color = (255, 0, 0)  # Default red brush
        self.current_tool = None
        self.places_generator = places_generator  # Ensure the generator is available in the class

    def load_image(self):
        """Loads an image from disk and displays it."""
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(self, "Load Image", "", "Images (*.png *.jpg *.jpeg)", options=options)

        if file_path:
            self.image_path = file_path
            self.image = cv2.imread(file_path)
            self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)  # Convert to RGB
            self.display_image()

    def display_image(self):
        """Displays the current image on the UI."""
        if self.image is None:
            return

        h, w, ch = self.image.shape
        bytes_per_line = ch * w
        q_image = QImage(self.image.data, w, h, bytes_per_line, QImage.Format_RGB888)

        pixmap = QPixmap.fromImage(q_image)
        self.image_label.setPixmap(pixmap)

    def use_color_brush(self):
        """Activates color brush tool."""
        self.current_tool = "color_brush"

    def use_warp_tool(self):
        """Activates warp tool."""
        self.current_tool = "warp_tool"

    def mousePressEvent(self, event):
        """Handles mouse press for drawing and warping."""
        if self.image is None or self.current_tool is None:
            return

        self.drawing = True
        self.last_point = event.pos()

    def mouseMoveEvent(self, event):
        """Handles drawing or warping on mouse movement."""
        if not self.drawing or self.image is None:
            return

        # Convert PyQt coordinates to NumPy image coordinates
        x1, y1 = self.last_point.x(), self.last_point.y()
        x2, y2 = event.pos().x(), event.pos().y()

        h, w, _ = self.image.shape
        x1, x2 = int(x1 * w / self.width()), int(x2 * w / self.width())
        y1, y2 = int(y1 * h / self.height()), int(y2 * h / self.height())

        if self.current_tool == "color_brush":
            cv2.line(self.image, (x1, y1), (x2, y2), self.brush_color, thickness=5)

        elif self.current_tool == "warp_tool":
            dx = x2 - x1
            dy = y2 - y1
            self.image = self.warp_image(dx, dy)

        self.last_point = event.pos()
        self.display_image()

    def mouseReleaseEvent(self, event):
        """Stops drawing or warping on mouse release."""
        self.drawing = False

    def warp_image(self, dx, dy):
        """Applies a warp effect to the image based on user input."""
        h, w = self.image.shape[:2]
        map_x, map_y = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')

        map_x = map_x.astype(np.float32) + dx
        map_y = map_y.astype(np.float32) + dy

        warped = cv2.remap(self.image, map_x, map_y, interpolation=cv2.INTER_LINEAR)
        return warped

    def apply_gan_edit(self):
        """Applies the trained DCGAN edits to the image."""
        if self.image is None:
            return

        # Resize image to 128x128 to match GAN output
        resized_img = cv2.resize(self.image, (128, 128))

        # Convert image to tensor
        image_tensor = torch.tensor(resized_img).float().permute(2, 0, 1).unsqueeze(0)
        image_tensor = (image_tensor / 127.5) - 1  # Normalize to [-1,1]

        # Project to latent space
        projected_z = torch.randn(1, 100, 1, 1, requires_grad=True)
        optimizer = torch.optim.Adam([projected_z], lr=0.1)

        for step in range(300):
            optimizer.zero_grad()
            generated_image = self.places_generator(projected_z)

            # Ensure generated_image is 128x128 before calculating loss
            loss = torch.nn.functional.mse_loss(generated_image, image_tensor)

            loss.backward()
            optimizer.step()

        # Convert generated image to NumPy format
        edited_image = generated_image.detach().cpu().squeeze().permute(1, 2, 0).numpy()
        edited_image = ((edited_image + 1) * 127.5).astype(np.uint8)  # Convert back to [0,255]

        # Resize back to original size
        self.image = cv2.resize(edited_image, (self.image.shape[1], self.image.shape[0]))
        self.display_image()

    def save_image(self):
        """Saves the edited image."""
        if self.image is None:
            return

        save_path, _ = QFileDialog.getSaveFileName(self, "Save Image", "", "Images (*.png *.jpg *.jpeg)")
        if save_path:
            cv2.imwrite(save_path, cv2.cvtColor(self.image, cv2.COLOR_RGB2BGR))
            print("Image saved successfully!")

# Run Application
if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = ImageEditor()
    window.show()
    sys.exit(app.exec_())


SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
