# Dog Image Generator using StyleGAN2

This notebook demonstrates how to generate realistic dog images using a pre-trained StyleGAN2 model.
StyleGAN2 is a state-of-the-art generative adversarial network for image synthesis.

## Setup and Installation

First, we'll install the necessary packages.

In [None]:
# Install the Python dependencies used by the cells below.
# NOTE(review): PIL (Pillow) is imported in the dataset cell but is not listed
# here; it is normally pulled in as a matplotlib dependency — confirm. The
# cloned stylegan2-ada-pytorch code may also need packages not listed here.
!pip install torch torchvision numpy matplotlib tqdm gdown

## Download Pre-trained Model

We'll download a pre-trained StyleGAN2 model for dogs.

In [None]:
import os
import gdown

# Create a directory for the model
os.makedirs('models', exist_ok=True)

# Download the pre-trained model for dogs (skipped when a copy is already on disk).
url = 'https://drive.google.com/uc?id=1yjO5y2S0XA-p59Xkx9n8W9-KlsQxmLbm'
output = 'models/stylegan2-afhqdog.pt'
if not os.path.exists(output):
    gdown.download(url, output, quiet=False)
    # Depending on the gdown version, a failed download (quota/permission page,
    # network error) may only print a message instead of raising. Fail fast here
    # so the loading cell doesn't later crash with a confusing "file not found".
    if not os.path.exists(output):
        raise RuntimeError(
            f"Failed to download the pre-trained model from {url} to {output}"
        )

## Clone StyleGAN2 Repository

We clone NVIDIA's official StyleGAN2-ADA (PyTorch) repository because its `dnnlib` and `legacy` modules are needed to load the pre-trained network and run generation.

In [None]:
# Fetch the NVlabs stylegan2-ada-pytorch code; the next cell imports its
# `dnnlib` and `legacy` modules to load the checkpoint.
# NOTE(review): re-running this cell makes `git clone` report that the target
# directory already exists, and appends a duplicate entry to sys.path.
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git
import sys
# Make the cloned repository's top-level modules importable.
sys.path.append('stylegan2-ada-pytorch')

## Generate Dog Images

Now we'll use the pre-trained model to generate dog images.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Choose the compute device: a CUDA GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Import required modules from the StyleGAN2 repository
import dnnlib
import legacy

# Load the generator from the checkpoint written by the download cell.
network_pkl = 'models/stylegan2-afhqdog.pt'
print(f'Loading networks from "{network_pkl}"...')
with dnnlib.util.open_url(network_pkl) as pkl_file:
    loaded = legacy.load_network_pkl(pkl_file)
# Only the 'G_ema' entry is used by this notebook; move it to the chosen
# device and drop the reference to the rest of the loaded dict.
G = loaded['G_ema'].to(device)
del loaded

In [None]:
def generate_images(num_images=5, seed=None):
    """Sample a batch of random dog images from the global generator ``G``.

    Args:
        num_images: Number of latent vectors to draw; all are rendered in a
            single forward pass.
        seed: When not None, seeds both numpy's and torch's global RNGs first,
            so the same seed reproduces the same latent vectors.

    Returns:
        A uint8 numpy array with one image per latent, channels last. Assumes
        ``G`` returns NCHW images roughly in [-1, 1] — the ``* 127.5 + 128``
        scaling maps that range onto [0, 255] before clamping.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Latents are drawn on the CPU and then moved to the target device.
    latents = torch.randn(num_images, G.z_dim).to(device)

    with torch.no_grad():
        raw = G(latents, None)

    channels_last = raw.permute(0, 2, 3, 1)
    pixels = (channels_last * 127.5 + 128).clamp(0, 255)
    return pixels.to(torch.uint8).cpu().numpy()

# Generate and display a row of preview images (fixed seed for reproducibility).
num_preview = 5
images = generate_images(num_images=num_preview, seed=42)

# Plot the generated images. squeeze=False keeps `axes` 2-D even when
# num_preview == 1 (plt.subplots would otherwise return a bare Axes).
fig, axes = plt.subplots(1, num_preview, figsize=(4 * num_preview, 4), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.imshow(images[i])
    ax.axis('off')
    ax.set_title(f"Generated Dog {i+1}")
plt.tight_layout()
plt.show()

## Generate and Save a Dataset of Dog Images

Let's generate a larger set of dog images and save each one as a PNG file, for example to use as a synthetic dataset in machine-learning experiments.

In [None]:
import os
from PIL import Image

# Create a directory to save generated images
output_dir = 'generated_dogs'
os.makedirs(output_dir, exist_ok=True)

# Number of images to generate
num_images = 100

# Generate images in batches. Ceiling division plus a shortened final batch
# ensures exactly `num_images` files are written even when num_images is not a
# multiple of batch_size (plain floor division would silently drop the remainder).
batch_size = 10
num_batches = (num_images + batch_size - 1) // batch_size

print(f"Generating {num_images} dog images...")

for batch_idx in tqdm(range(num_batches)):
    start_idx = batch_idx * batch_size
    current_batch_size = min(batch_size, num_images - start_idx)

    # Seed each batch with its index so the dataset is reproducible across runs.
    batch_images = generate_images(num_images=current_batch_size, seed=batch_idx)

    # Save each image in the batch with a globally increasing, zero-padded index.
    for i, img in enumerate(batch_images):
        img_idx = start_idx + i
        img_path = os.path.join(output_dir, f'dog_{img_idx:04d}.png')

        # Convert numpy array to PIL Image and save
        Image.fromarray(img).save(img_path)

print(f"Successfully generated and saved {num_images} dog images to {output_dir}/")

## Interpolation Between Dog Images

We can also create smooth transitions between two dogs by linearly interpolating between their random latent vectors and rendering an image at each step.

In [None]:
def interpolate_images(num_steps=10, seed=None, noise_mode='const'):
    """Render a linear latent-space walk between two random dogs.

    Two latent vectors ``z1`` and ``z2`` are drawn from a standard normal and
    blended as ``(1 - alpha) * z1 + alpha * z2`` for ``num_steps`` evenly
    spaced values of ``alpha`` in [0, 1] (both endpoints included).

    Args:
        num_steps: Number of frames to render; 0 yields an empty list.
        seed: Optional seed for torch's global RNG, making the two endpoint
            latents reproducible (like the ``seed`` of generate_images).
        noise_mode: Forwarded to the generator. The default ``'const'`` keeps
            the per-layer noise identical for every frame, so consecutive frames
            differ only by their latent and fine details don't flicker (the
            style-mixing cell uses the same setting). Pass ``'random'`` for
            fresh noise on every frame.

    Returns:
        A list of ``num_steps`` uint8 numpy arrays, one channels-last image per
        frame (assumes ``G`` outputs NCHW images roughly in [-1, 1]).
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Endpoints of the walk.
    z1 = torch.randn(1, G.z_dim).to(device)
    z2 = torch.randn(1, G.z_dim).to(device)

    interpolated_images = []
    with torch.no_grad():
        # .tolist() yields plain Python floats for the tensor arithmetic below.
        for alpha in np.linspace(0, 1, num_steps).tolist():
            z_interp = (1 - alpha) * z1 + alpha * z2
            img = G(z_interp, None, noise_mode=noise_mode)
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()[0]
            interpolated_images.append(img)

    return interpolated_images

# Generate interpolated images
num_interp_steps = 10
interpolated_images = interpolate_images(num_steps=num_interp_steps)

# Plot the interpolation. squeeze=False keeps `axes` 2-D even for a single step,
# and the figure width scales with the number of frames (20 wide for 10 steps).
fig, axes = plt.subplots(1, num_interp_steps, figsize=(2 * num_interp_steps, 4), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.imshow(interpolated_images[i])
    ax.axis('off')
    ax.set_title(f"Step {i+1}")
plt.tight_layout()
plt.show()

## Style Mixing

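A key feature of StyleGAN2 is style mixing: copying a range of per-layer style vectors from one dog into another. This lets us combine coarse features (pose, shape), mid-level features (fur, ears), or fine details (colors, textures) from different dogs.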

In [None]:
def style_mixing(num_source=4, num_dest=3, mix_ranges=None):
    """Copy ranges of per-layer styles from source dogs into destination dogs.

    Random latents are drawn for ``num_source`` sources, then ``num_dest``
    destinations, and mapped to W space with ``G.mapping``. For every
    (destination, range, source) triple, the destination's W rows within the
    inclusive range are overwritten with the source's rows and the result is
    synthesized with constant noise.

    Args:
        num_source: Number of random source dogs.
        num_dest: Number of random destination dogs.
        mix_ranges: Inclusive ``[start, end]`` index pairs along the ``num_ws``
            axis of W. Defaults to ``[[0, 3], [4, 8], [9, G.num_ws - 1]]``.
            Each entry is stored unchanged in the returned tuples.

    Returns:
        ``(src_images, mixed_images)``. ``src_images`` is a list of
        ``num_source`` uint8 channels-last images. ``mixed_images`` is a list of
        ``(src_idx, dst_idx, mix_range, image)`` tuples, ordered by
        destination, then range, then source.
    """
    if mix_ranges is None:
        # Style layer indices to mix
        # Low (0-3): Coarse features (pose, shape)
        # Middle (4-8): Mid-level features (fur, ears, etc.)
        # High (9+): Fine details (colors, textures)
        # NOTE(review): for a model with num_ws <= 9 the last range is empty,
        # so that row of mixes would just reproduce the destination.
        mix_ranges = [[0, 3], [4, 8], [9, G.num_ws - 1]]

    def _to_uint8(img):
        # Map generator output (assumed NCHW, roughly [-1, 1]) to uint8 NHWC.
        return (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()

    # Generate source and destination latent vectors
    src_latents = torch.randn(num_source, G.z_dim).to(device)
    dst_latents = torch.randn(num_dest, G.z_dim).to(device)

    src_images = []
    mixed_images = []
    with torch.no_grad():
        # Map from Z to W space
        src_ws = G.mapping(src_latents, None)  # [NUM_SRC, num_ws, w_dim]
        dst_ws = G.mapping(dst_latents, None)  # [NUM_DST, num_ws, w_dim]

        # Render the unmixed source images for reference.
        for src_idx in range(num_source):
            src_img = G.synthesis(src_ws[src_idx:src_idx + 1], noise_mode='const')
            src_images.append(_to_uint8(src_img)[0])

        # Render every destination with each style range taken from each source.
        for dst_idx in range(num_dest):
            for mix_range in mix_ranges:
                start, stop = mix_range[0], mix_range[1] + 1  # inclusive end
                for src_idx in range(num_source):
                    w_mixed = dst_ws[dst_idx:dst_idx + 1].clone()
                    w_mixed[:, start:stop] = src_ws[src_idx:src_idx + 1, start:stop]
                    img = G.synthesis(w_mixed, noise_mode='const')
                    mixed_images.append((src_idx, dst_idx, mix_range, _to_uint8(img)[0]))

    return src_images, mixed_images

# Generate source and mixed images
num_source, num_dest = 4, 3
src_images, mixed_images = style_mixing(num_source=num_source, num_dest=num_dest)

# Plot source images (squeeze=False keeps `axes` 2-D even when num_source == 1)
fig, axes = plt.subplots(1, num_source, figsize=(4 * num_source, 4), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.imshow(src_images[i])
    ax.axis('off')
    ax.set_title(f"Source {i+1}")
plt.tight_layout()
plt.show()

# Plot mixed images as a grid: one row per destination; for each style range a
# group of num_source columns (3 rows x (3 ranges x 4 sources) by default).
# These must be the ranges style_mixing used (its defaults).
mix_ranges = [[0, 3], [4, 8], [9, G.num_ws-1]]
range_names = ["Coarse", "Medium", "Fine"]

# Index the results once instead of rescanning mixed_images for every grid cell.
mixed_lookup = {(s, d, tuple(r)): image for s, d, r, image in mixed_images}

num_cols = len(mix_ranges) * num_source
fig, axes = plt.subplots(num_dest, num_cols, figsize=(2 * num_cols, 7 * num_dest / 3), squeeze=False)
for dst_idx in range(num_dest):
    for range_idx, (mix_range, range_name) in enumerate(zip(mix_ranges, range_names)):
        for src_idx in range(num_source):
            mixed = mixed_lookup.get((src_idx, dst_idx, tuple(mix_range)))
            if mixed is None:
                continue
            ax = axes[dst_idx, range_idx * num_source + src_idx]
            ax.imshow(mixed)
            ax.axis('off')
            # Source numbers are 1-based so they match the "Source N" titles above.
            if src_idx == 0:
                ax.set_title(f"Dest {dst_idx+1}\n{range_name}, Src 1")
            else:
                ax.set_title(f"Src {src_idx+1}")
plt.tight_layout()
plt.show()