In [73]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.utils as vutils
from PIL import Image
import os
from pathlib import Path
import math
import random
import numpy as np
from tqdm import tqdm
from datetime import datetime
import multiprocessing

# Set random seeds for reproducibility
def set_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [74]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from pathlib import Path

class StyleGANConfig:
    def __init__(self):
        # Basic settings
        self.image_size = 256
        self.style_dim = 512
        self.latent_dim = 512
        self.n_mlp = 8
        
        # Training settings
        self.batch_size = 8
        self.lr = 0.0002
        self.beta1 = 0.0
        self.beta2 = 0.99
        
        # Device settings
        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")



# Mapping Network - Converts latent vectors to style vectors
class MappingNetwork(nn.Module):
    def __init__(self, latent_dim, style_dim, n_mlp):
        super().__init__()
        layers = []
        # First layer from latent_dim to style_dim
        layers.append(nn.Linear(latent_dim, style_dim))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        
        # Additional layers
        for _ in range(n_mlp - 1):
            layers.append(nn.Linear(style_dim, style_dim))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            
        self.mapping = nn.Sequential(*layers)

    def forward(self, z):
        # Normalize input
        z = F.normalize(z, dim=1)
        return self.mapping(z)

# Style-based Generator
class Generator(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.style_dim = config.style_dim
        
        # Mapping network
        self.mapping = nn.Sequential(
            *[nn.Sequential(
                EqualLinear(config.style_dim, config.style_dim),
                nn.LeakyReLU(0.2, inplace=True)
            ) for _ in range(config.n_mlp)]
        )
        
        # Initial input
        self.input = nn.Parameter(torch.randn(1, config.style_dim, 4, 4))
        
        # Progressive blocks with consistent dimensions
        channels = [512, 512, 256, 128, 64, 32]
        self.blocks = nn.ModuleList()
        in_chan = config.style_dim
        
        for out_chan in channels:
            self.blocks.append(StyleBlock(in_chan, out_chan, config.style_dim))
            in_chan = out_chan
            
        # RGB output
        self.to_rgb = nn.Sequential(
            nn.Conv2d(channels[-1], 3, 1),
            nn.Tanh()
        )

    def forward(self, z):
        # Map to style space
        w = z
        for layer in self.mapping:
            w = layer(w)
        
        # Start from learned constant
        x = self.input.repeat(z.shape[0], 1, 1, 1)
        
        # Apply style blocks
        for block in self.blocks:
            x = block(x, w)
        
        # Convert to RGB
        return self.to_rgb(x)

# Style Block with AdaIN
class StyleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, style_dim=512):
        super().__init__()
        self.conv1 = ModulatedConv2d(in_channels, out_channels, 3, style_dim)
        self.noise1 = NoiseInjection()
        self.activation = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x, w):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = self.conv1(x, w)
        x = self.noise1(x)
        x = self.activation(x)
        return x
        
class AdaIN(nn.Module):
    def __init__(self, channels, style_dim=512):
        super().__init__()
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        self.style = EqualLinear(style_dim, channels * 2)  # Using EqualLinear for better stability

    def forward(self, x, w):
        style = self.style(w).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, 1)
        return (1 + gamma) * self.norm(x) + beta

class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))
        self.scale = (1 / math.sqrt(in_dim))

    def forward(self, input):
        out = F.linear(input, self.weight * self.scale, self.bias)
        return out
        
# Simple Discriminator
class Discriminator(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        # Progressive downsample
        self.main = nn.Sequential(
            # 256x256 -> 128x128
            nn.Conv2d(3, 16, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 128x128 -> 64x64
            nn.Conv2d(16, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 64x64 -> 32x32
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 32x32 -> 16x16
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 16x16 -> 8x8
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 8x8 -> 4x4
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # Final classification
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 1, 4, 1, 0),
            nn.Flatten()
        )

    def forward(self, x):
        x = self.main(x)
        return self.classifier(x)

# Dataset loader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image
import os
from pathlib import Path
import multiprocessing

# Set multiprocessing start method
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    pass

class CelebAHQDataset(Dataset):
    def __init__(self, root_dir, image_size=256):
        super().__init__()
        self.root_dir = Path(root_dir)
        self.image_paths = list(self.root_dir.glob("*.jpg"))
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = str(self.image_paths[idx])  # Convert Path to string
        try:
            image = Image.open(img_path).convert('RGB')
            return self.transform(image)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a blank image in case of error
            return torch.zeros(3, self.transform.transforms[0].size, 
                             self.transform.transforms[0].size)

def create_dataloader(root_dir, config):
    """Create dataloader with error handling"""
    try:
        dataset = CelebAHQDataset(root_dir, config.image_size)
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=0,  # Set to 0 for debugging
            pin_memory=True if torch.backends.mps.is_available() else False,
            drop_last=True
        )
    except Exception as e:
        print(f"Error creating dataloader: {e}")
        return None

class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        if noise is None:
            batch, _, height, width = x.shape
            noise = torch.randn(batch, 1, height, width, device=x.device)
        return x + self.weight * noise

    
# Training helper
def train_step(real_imgs, generator, discriminator, g_optimizer, d_optimizer, config):
    batch_size = real_imgs.size(0)
    device = config.device
    
    # Train Discriminator
    d_optimizer.zero_grad()
    
    # Real images
    real_pred = discriminator(real_imgs)
    d_real_loss = F.softplus(-real_pred).mean()
    
    # Fake images
    z = torch.randn(batch_size, config.latent_dim, device=device)
    with torch.no_grad():
        fake_imgs = generator(z)
    fake_pred = discriminator(fake_imgs)
    d_fake_loss = F.softplus(fake_pred).mean()
    
    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optimizer.step()
    
    # Train Generator
    g_optimizer.zero_grad()
    
    z = torch.randn(batch_size, config.latent_dim, device=device)
    fake_imgs = generator(z)
    fake_pred = discriminator(fake_imgs)
    
    g_loss = F.softplus(-fake_pred).mean()
    g_loss.backward()
    g_optimizer.step()
    
    # Clear cache for M1
    if device.type == "mps":
        torch.mps.empty_cache()
    
    return {
        'd_loss': d_loss.item(),
        'g_loss': g_loss.item(),
        'd_real': torch.sigmoid(real_pred).mean().item(),
        'd_fake': torch.sigmoid(fake_pred).mean().item()
    }

In [75]:
class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))
        self.scale = (1 / math.sqrt(in_dim))

    def forward(self, input):
        out = F.linear(input, self.weight * self.scale, self.bias)
        return out

def save_sample_images(generator, config, name, num_samples=16, nrow=4):
    """Generate and save sample images"""
    try:
        os.makedirs('samples', exist_ok=True)
        
        generator.eval()
        with torch.no_grad():
            z = torch.randn(num_samples, config.latent_dim, device=config.device)
            fake_images = generator(z)
            fake_images = (fake_images + 1) / 2
            
            save_path = f'samples/generated_{name}.png'
            vutils.save_image(fake_images.cpu(), save_path, nrow=nrow, normalize=False)
            print(f'\nSaved generated images at {save_path}')
            
        generator.train()
        
    except Exception as e:
        print(f"Error saving sample images: {e}")

def train_model(generator, discriminator, dataloader, config, num_epochs=10):
    try:
        # Optimizers
        g_optimizer = torch.optim.Adam(
            generator.parameters(),
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )
        d_optimizer = torch.optim.Adam(
            discriminator.parameters(),
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )

        print("Starting training...")
        os.makedirs('samples', exist_ok=True)
        os.makedirs('checkpoints', exist_ok=True)
        
        for epoch in range(num_epochs):
            pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
            epoch_metrics = []
            
            for batch_idx, real_imgs in enumerate(pbar):
                try:
                    real_imgs = real_imgs.to(config.device)
                    metrics = train_step(
                        real_imgs, generator, discriminator,
                        g_optimizer, d_optimizer, config
                    )
                    epoch_metrics.append(metrics)
                    
                    # Update progress bar with rolling average
                    if len(epoch_metrics) > 50:
                        avg_metrics = {
                            k: sum(m[k] for m in epoch_metrics[-50:]) / 50 
                            for k in metrics.keys()
                        }
                        pbar.set_postfix(avg_metrics)
                    
                    # Save samples and checkpoint periodically
                    if batch_idx % 500 == 0:
                        try:
                            save_sample_images(
                                generator, 
                                config, 
                                f"epoch_{epoch+1}_batch_{batch_idx}"
                            )
                            
                            torch.save({
                                'epoch': epoch,
                                'batch': batch_idx,
                                'generator_state_dict': generator.state_dict(),
                                'discriminator_state_dict': discriminator.state_dict(),
                                'g_optimizer_state_dict': g_optimizer.state_dict(),
                                'd_optimizer_state_dict': d_optimizer.state_dict(),
                            }, f'checkpoints/checkpoint_e{epoch}_b{batch_idx}.pt')
                            
                        except Exception as e:
                            print(f"Error saving checkpoint: {e}")
                            continue
                        
                except RuntimeError as e:
                    print(f"\nError in batch {batch_idx}: {str(e)}")
                    if "out of memory" in str(e):
                        if torch.backends.mps.is_available():
                            torch.mps.empty_cache()
                        continue
                    raise e
            
            # Save final state for epoch
            try:
                save_sample_images(generator, config, f"epoch_{epoch+1}_final")
                torch.save({
                    'epoch': epoch,
                    'generator_state_dict': generator.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'g_optimizer_state_dict': g_optimizer.state_dict(),
                    'd_optimizer_state_dict': d_optimizer.state_dict(),
                }, f'checkpoints/checkpoint_epoch_{epoch+1}.pt')
            except Exception as e:
                print(f"Error saving final epoch state: {e}")
                    
    except Exception as e:
        print(f"Training error: {e}")
        return False
    
    return True

In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image
import os
from pathlib import Path
import multiprocessing
from tqdm import tqdm

class CelebAHQDataset(Dataset):
    def __init__(self, root_dir, image_size=256):
        super().__init__()
        self.root_dir = Path(root_dir)
        if not self.root_dir.exists():
            raise ValueError(f"Dataset directory {root_dir} does not exist!")

        # Look for images recursively in all subdirectories
        self.image_paths = []
        valid_extensions = {'.jpg', '.jpeg', '.png'}
        
        print(f"Scanning {root_dir} for images...")
        for ext in valid_extensions:
            self.image_paths.extend(list(self.root_dir.rglob(f"*{ext}")))
        
        if len(self.image_paths) == 0:
            raise ValueError(f"No valid images found in {root_dir}")
        
        print(f"Found {len(self.image_paths)} images")
        
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        # Validate images
        self.validate_images()
        
    def validate_images(self):
        """Validate images and remove corrupted ones"""
        valid_paths = []
        print("Validating images...")
        for img_path in tqdm(self.image_paths):
            try:
                with Image.open(img_path) as img:
                    img.verify()
                valid_paths.append(img_path)
            except Exception as e:
                print(f"Corrupted image found {img_path}: {e}")
        
        self.image_paths = valid_paths
        print(f"Found {len(self.image_paths)} valid images")
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = str(self.image_paths[idx])
        try:
            with Image.open(img_path) as img:
                img = img.convert('RGB')
                return self.transform(img)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return torch.zeros(3, self.transform.transforms[0].size, 
                             self.transform.transforms[0].size)
            
def calc_path_lengths(fake_imgs, generator):
    """Calculate path lengths for path length regularization"""
    batch_size = fake_imgs.shape[0]
    noise = torch.randn_like(fake_imgs) / math.sqrt(fake_imgs.shape[2] * fake_imgs.shape[3])
    
    grad = torch.autograd.grad(
        outputs=(fake_imgs * noise).sum(),
        inputs=generator.parameters(),
        create_graph=True
    )[0]
    
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    return path_lengths
    
def find_dataset_path():
    """Find the CelebA-HQ dataset directory"""
    possible_paths = [
        'Data/celeba_hq',
        './data/celeba_hq',
        './datasets/celeba_hq',
        '../celeba_hq',
        '../data/celeba_hq',
        '../datasets/celeba_hq',
    ]
    
    for path in possible_paths:
        if os.path.exists(path):
            return path
    
    return None

def create_dataloader(config):
    """Create dataloader with proper path finding and error handling"""
    try:
        # Find dataset path
        dataset_path = find_dataset_path()
        if dataset_path is None:
            raise ValueError("Could not find CelebA-HQ dataset directory. Please specify the correct path.")
            
        print(f"Using dataset path: {dataset_path}")
        
        # Create dataset
        dataset = CelebAHQDataset(dataset_path, config.image_size)
        
        if len(dataset) == 0:
            raise ValueError("Dataset is empty!")
            
        # Create dataloader
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=0,  # Set to 0 for debugging
            pin_memory=True if torch.backends.mps.is_available() else False,
            drop_last=True
        )
    except Exception as e:
        print(f"Error creating dataloader: {e}")
        return None


# Helper function to calculate log2
def log2(x):
    return int(torch.log2(torch.tensor(x)).item())
    
def compute_gradient_penalty(discriminator, real_imgs, fake_imgs, device):
    """Compute gradient penalty for improved WGAN training"""
    batch_size = real_imgs.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_imgs + ((1 - alpha) * fake_imgs)).requires_grad_(True)
    d_interpolates = discriminator(interpolates)
    
    # Get gradient w.r.t. interpolates
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

class ModulatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, style_dim):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.style_dim = style_dim
        
        self.scale = 1 / math.sqrt(in_channels * kernel_size ** 2)
        self.padding = kernel_size // 2
        
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
        )
        self.modulation = EqualLinear(style_dim, in_channels)
        self.demodulate = True

    def forward(self, x, style):
        batch, in_channels, height, width = x.shape
        
        # Style modulation
        style = self.modulation(style).view(batch, 1, -1, 1, 1)
        weight = self.scale * self.weight * style
        
        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channels, 1, 1, 1)
        
        weight = weight.view(
            batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size
        )
        
        x = x.view(1, batch * in_channels, height, width)
        out = F.conv2d(x, weight, padding=self.padding, groups=batch)
        out = out.view(batch, self.out_channels, height, width)
        
        return out


def main():
    # Set seeds for reproducibility
    set_seeds()
    
    # Create config
    config = StyleGANConfig()
    print(f"Using device: {config.device}")
    
    # Initialize models
    generator = Generator(config).to(config.device)
    discriminator = Discriminator(config).to(config.device)
    
    # Create dataloader
    dataloader = create_dataloader(config)
    if dataloader is None:
        return
        
    # Train model
    success = train_model(generator, discriminator, dataloader, config)
    
    if success:
        print("Training completed successfully")
        # Save final samples
        save_sample_images(generator, config, "final", num_samples=25, nrow=5)
        
        # Save final model
        torch.save({
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
        }, 'final_model.pt')
    else:
        print("Training failed")

In [77]:
if __name__ == "__main__":
    main()

Using device: mps
Using dataset path: Data/celeba_hq
Scanning Data/celeba_hq for images...
Found 28000 images
Validating images...


100%|███████████████████████████████████| 28000/28000 [00:03<00:00, 7769.77it/s]


Found 28000 valid images
Starting training...


Epoch 1/10:   0%|                                      | 0/3500 [00:00<?, ?it/s]


Saved generated images at samples/generated_epoch_1_batch_0.png


Epoch 1/10:  14%|▏| 500/3500 [04:43<28:07,  1.78it/s, d_loss=0.000242, g_loss=8.


Saved generated images at samples/generated_epoch_1_batch_500.png


Epoch 1/10:  29%|▎| 1000/3500 [09:32<23:29,  1.77it/s, d_loss=0.0762, g_loss=11.


Saved generated images at samples/generated_epoch_1_batch_1000.png


Epoch 1/10:  43%|▍| 1500/3500 [14:19<19:21,  1.72it/s, d_loss=0.106, g_loss=10.4


Saved generated images at samples/generated_epoch_1_batch_1500.png


Epoch 1/10:  43%|▍| 1516/3500 [14:28<18:56,  1.75it/s, d_loss=0.112, g_loss=12.3


KeyboardInterrupt: 

# Actual Code Section

In [81]:
import torch
import dnnlib
import legacy

# Download StyleGAN2-ADA FFHQ pretrained weights
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl'
with dnnlib.util.open_url(url) as f:
    pretrained = legacy.load_network_pkl(f)

Downloading https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl ... done


In [140]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np
import math
import os

class PretrainedStyleGAN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.style_dim = config.style_dim
        
        # Mapping network
        self.mapping = MappingNetwork(
            latent_dim=config.style_dim,
            style_dim=config.style_dim,
            n_mlp=config.n_mlp
        )
        
        # Synthesis network
        self.synthesis = SynthesisNetwork(
            style_dim=config.style_dim,
            channels=config.channels,
            max_resolution=config.image_size
        )
        
    def forward(self, z):
        # Map latent to style space
        w = self.mapping(z)
        # Generate image
        return self.synthesis(w)

class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, lr_mul=1.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            self.bias = None
            
        self.lr_mul = lr_mul
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        
    def forward(self, input):
        if self.bias is None:
            out = F.linear(input, self.weight * self.scale)
        else:
            out = F.linear(input, self.weight * self.scale, self.bias * self.lr_mul)
        return out

class SynthesisNetwork(nn.Module):
    def __init__(self, style_dim, channels, max_resolution):
        super().__init__()
        self.style_dim = style_dim
        self.max_resolution = max_resolution
        self.log_size = int(math.log2(max_resolution))
        
        # Initial learned constant input
        self.input = nn.Parameter(torch.randn(1, channels[4], 4, 4))
        
        # Style blocks (renamed from self.blocks)
        self.style_blocks = nn.ModuleList()
        in_channel = channels[4]
        
        # Build progressive blocks
        for i in range(3, self.log_size + 1):
            res = 2 ** i
            out_channel = channels[res]
            
            self.style_blocks.append(
                StyleBlock(
                    in_channel,
                    out_channel,
                    style_dim,
                    upsample=True
                )
            )
            in_channel = out_channel
            
    def forward(self, w):
        # Initial constant input
        x = self.input.repeat(w.size(0), 1, 1, 1)
        
        # Apply style blocks
        for block in self.style_blocks:  # Using style_blocks instead of blocks
            x = block(x, w)
            
        return x
        
class StyleBlock(nn.Module):
    def __init__(self, in_channel, out_channel, style_dim, upsample=True):
        super().__init__()
        self.upsample = upsample
        
        # First convolution
        self.conv1 = ModulatedConv2d(
            in_channel, 
            out_channel, 
            3, 
            style_dim
        )
        self.noise1 = NoiseInjection()
        self.activate1 = nn.LeakyReLU(0.2, inplace=True)
        
        # Second convolution
        self.conv2 = ModulatedConv2d(
            out_channel, 
            out_channel, 
            3, 
            style_dim
        )
        self.noise2 = NoiseInjection()
        self.activate2 = nn.LeakyReLU(0.2, inplace=True)
        
    def forward(self, x, style):
        # First conv block
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = self.conv1(x, style)
        x = self.noise1(x)
        x = self.activate1(x)
        
        # Second conv block
        x = self.conv2(x, style)
        x = self.noise2(x)
        x = self.activate2(x)
        
        return x


class MappingNetwork(nn.Module):
    def __init__(self, latent_dim, style_dim, n_mlp):
        super().__init__()
        layers = []
        dim = latent_dim
        for _ in range(n_mlp):
            layers.append(EqualLinear(dim, style_dim, lr_mul=0.01))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            dim = style_dim
            
        self.mapping = nn.Sequential(*layers)
        
    def forward(self, z):
        # Normalize latent vector
        z = F.normalize(z, dim=1)
        # Map to W space
        return self.mapping(z)

# Training helper functions
def train_with_pretrained(generator, discriminator, dataloader, config):
    """Training loop incorporating pretrained weights"""
    
    # Load pretrained state
    state_dict = torch.load('pretrained_stylegan2_ffhq.pt', weights_only=False)
    generator.load_state_dict(state_dict['g'], strict=False)
    discriminator.load_state_dict(state_dict['d'], strict=False)
    
    # Freeze certain layers
    for name, param in generator.named_parameters():
        if 'mapping' in name or 'style' in name:
            param.requires_grad = False
            
    # Modified training loop
    g_losses = []
    d_losses = []
    
    for epoch in range(config.num_epochs):
        for batch_idx, real_imgs in enumerate(dataloader):
            # Train discriminator
            d_loss = train_d(real_imgs, generator, discriminator, config)
            
            # Train generator (with style mixing)
            g_loss = train_g_mixed(generator, discriminator, config)
            
            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch}/{config.num_epochs}] Batch [{batch_idx}] '
                      f'd_loss: {d_loss:.4f} g_loss: {g_loss:.4f}')
                
            # Save samples periodically
            if batch_idx % 500 == 0:
                save_samples(generator, f'samples/epoch_{epoch}_batch_{batch_idx}.png')
                
    return g_losses, d_losses

def train_g_mixed(generator, discriminator, config):
    """Generator training with style mixing regularization"""
    batch_size = config.batch_size
    mixing_prob = 0.9
    
    # Generate two sets of latents
    z1 = torch.randn(batch_size, config.style_dim).to(config.device)
    z2 = torch.randn(batch_size, config.style_dim).to(config.device)
    
    # Mix styles
    if random.random() < mixing_prob:
        crossover_point = random.randint(1, generator.num_layers - 1)
        w1 = generator.mapping(z1)
        w2 = generator.mapping(z2)
        w = torch.cat([w1[:,:crossover_point], w2[:,crossover_point:]], dim=1)
    else:
        w = generator.mapping(z1)
    
    # Generate and optimize
    fake_imgs = generator.synthesis(w)
    fake_pred = discriminator(fake_imgs)
    g_loss = F.softplus(-fake_pred).mean()
    
    return g_loss.item()

In [141]:
class ModulatedConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, style_dim):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.style_dim = style_dim
        
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
        self.padding = kernel_size // 2
        
        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )
        
        # Style modulation
        self.modulation = EqualLinear(style_dim, in_channel)
        
    def forward(self, x, style):
        batch, in_channel, height, width = x.shape
        
        # Style modulation
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style
        
        # Demodulation
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
        
        # Reshape for grouped convolution
        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )
        x = x.view(1, batch * in_channel, height, width)
        
        # Convolution
        out = F.conv2d(x, weight, padding=self.padding, groups=batch)
        
        return out.view(batch, self.out_channel, height, width)


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))
        
    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()
        return image + self.weight * noise

In [142]:
import torch
import requests
import os
import pickle
from tqdm import tqdm

class StyleGANWeightLoader:
    def __init__(self):
        self.weights_dir = 'pretrained_weights'
        os.makedirs(self.weights_dir, exist_ok=True)
        
    def download_weights(self):
        """Download official StyleGAN2 weights"""
        # Official NVIDIA StyleGAN2-ADA FFHQ weights
        url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-256x256.pkl'
        filename = os.path.join(self.weights_dir, 'stylegan2-ffhq.pkl')
        
        if not os.path.exists(filename):
            print(f"Downloading weights from {url}")
            response = requests.get(url, stream=True)
            total_size = int(response.headers.get('content-length', 0))
            
            with open(filename, 'wb') as f, tqdm(
                desc="Downloading",
                total=total_size,
                unit='iB',
                unit_scale=True
            ) as pbar:
                for data in response.iter_content(chunk_size=1024):
                    size = f.write(data)
                    pbar.update(size)
                    
        return filename

    def load_weights(self):
        """Load weights safely"""
        try:
            weight_file = self.download_weights()
            
            print("Loading and converting weights...")
            with open(weight_file, 'rb') as f:
                data = pickle.load(f)
            
            # Convert weights to compatible format
            state_dict = self.convert_weights(data)
            
            # Save converted weights
            torch.save(state_dict, os.path.join(self.weights_dir, 'converted_weights.pt'))
            
            return state_dict
            
        except Exception as e:
            print(f"Error loading weights: {str(e)}")
            return None
    
    def convert_weights(self, data):
        """Convert original weights to our format"""
        g_state_dict = {}
        d_state_dict = {}
        
        # Handle different weight formats
        if isinstance(data, dict):
            g_weights = data.get('G_ema', data.get('G'))
            d_weights = data.get('D')
        else:
            g_weights = data
            d_weights = None
            
        # Convert generator weights
        if hasattr(g_weights, 'state_dict'):
            g_state = g_weights.state_dict()
        else:
            g_state = g_weights
            
        # Map weights to our architecture
        for name, param in g_state.items():
            if isinstance(param, torch.Tensor):
                if 'mapping' in name:
                    new_name = name.replace('mapping.', 'mapping.layers.')
                    g_state_dict[new_name] = param.clone()
                elif 'synthesis' in name:
                    if 'conv' in name:
                        g_state_dict[name] = param.clone()
                    elif 'noise' in name:
                        g_state_dict[name] = param.clone()
                    elif 'toRGB' in name:
                        g_state_dict[name] = param.clone()
                        
        return {
            'g': g_state_dict,
            'd': d_state_dict if d_weights is not None else None
        }

        
def load_pretrained_model(generator, discriminator, config):
    """Load pretrained weights into models"""
    try:
        loader = StyleGANWeightLoader()
        weights = loader.load_weights()
        
        # Load weights with error handling
        missing_g, unexpected_g = generator.load_state_dict(weights['g'], strict=False)
        print(f"\nGenerator loading info:")
        print(f"Missing keys: {len(missing_g)}")
        print(f"Unexpected keys: {len(unexpected_g)}")
        
        if weights['d']:
            missing_d, unexpected_d = discriminator.load_state_dict(weights['d'], strict=False)
            print(f"\nDiscriminator loading info:")
            print(f"Missing keys: {len(missing_d)}")
            print(f"Unexpected keys: {len(unexpected_d)}")
            
        return True
        
    except Exception as e:
        print(f"\nError loading pretrained weights: {str(e)}")
        print("Using random initialization instead.")
        return False

def save_samples(fake_imgs, path, nrow=4):
    """Save generated samples with proper format conversion"""
    try:
        # Ensure the images are in the correct format
        if fake_imgs.size(1) != 3:
            fake_imgs = fake_imgs.permute(0, 3, 1, 2)
            
        # Convert from [-1, 1] to [0, 1]
        fake_imgs = (fake_imgs + 1) / 2
        fake_imgs = torch.clamp(fake_imgs, 0, 1)
        
        # Save the images
        torchvision.utils.save_image(
            fake_imgs,
            path,
            nrow=nrow,
            normalize=False,
            range=(0, 1)
        )
    except Exception as e:
        print(f"Error saving samples: {str(e)}")
        traceback.print_exc()
        
def save_checkpoint(generator, discriminator, g_optim, d_optim, epoch, filename):
    """Save training checkpoint"""
    os.makedirs('checkpoints', exist_ok=True)
    
    torch.save({
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'g_optimizer_state_dict': g_optim.state_dict(),
        'd_optimizer_state_dict': d_optim.state_dict(),
        'epoch': epoch
    }, os.path.join('checkpoints', filename))

def train_stylegan(generator, discriminator, dataloader, config, num_epochs=10):
    """Main training loop with improved progress tracking"""
    
    # Setup optimizers
    g_optim = torch.optim.Adam(
        generator.parameters(),
        lr=config.lr,
        betas=(config.beta1, config.beta2)
    )
    d_optim = torch.optim.Adam(
        discriminator.parameters(),
        lr=config.lr,
        betas=(config.beta1, config.beta2)
    )
    
    # Create directories for saving results
    os.makedirs('samples', exist_ok=True)
    os.makedirs('checkpoints', exist_ok=True)
    
    print(f"Starting training for {num_epochs} epochs...")
    for epoch in range(num_epochs):
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        epoch_metrics = []
        
        for i, real_imgs in enumerate(progress_bar):
            try:
                # Training step
                metrics = train_step(real_imgs, generator, discriminator, g_optim, d_optim, config)
                
                if metrics is not None:
                    epoch_metrics.append(metrics)
                    
                    # Update progress bar
                    avg_metrics = {
                        k: sum(m[k] for m in epoch_metrics[-100:]) / len(epoch_metrics[-100:])
                        for k in metrics.keys()
                    }
                    progress_bar.set_postfix(avg_metrics)
                
                # Save samples and checkpoint
                if i % 500 == 0:
                    with torch.no_grad():
                        # Generate samples
                        sample_z = torch.randn(16, config.style_dim).to(config.device)
                        samples = generator(sample_z)
                        torchvision.utils.save_image(
                            samples,
                            f'samples/epoch_{epoch}_batch_{i}.png',
                            normalize=True,
                            value_range=(-1, 1),
                            nrow=4
                        )
                        
                        # Save checkpoint
                        torch.save({
                            'epoch': epoch,
                            'batch': i,
                            'generator_state_dict': generator.state_dict(),
                            'discriminator_state_dict': discriminator.state_dict(),
                            'g_optimizer_state_dict': g_optim.state_dict(),
                            'd_optimizer_state_dict': d_optim.state_dict(),
                            'metrics': avg_metrics
                        }, f'checkpoints/checkpoint_e{epoch}_b{i}.pt')
                
                # Memory management
                if i % 100 == 0 and torch.backends.mps.is_available():
                    torch.mps.empty_cache()
                    
            except Exception as e:
                print(f"\nError in batch {i}: {str(e)}")
                continue
                
        # Save epoch summary
        torch.save({
            'epoch': epoch,
            'metrics': avg_metrics
        }, f'checkpoints/epoch_{epoch}_summary.pt')
        
    return epoch_metrics

def load_and_train(generator, discriminator, dataloader, config):
    """Load weights and start training"""
    
    # Initialize weight loader
    weight_loader = StyleGANWeightLoader()
    
    try:
        # Load pretrained weights
        print("Loading pretrained StyleGAN2 weights...")
        weights = weight_loader.load_stylegan2_weights()
        
        # Load weights into models
        generator.load_state_dict(weights['g'], strict=False)
        if weights['d'] is not None:
            discriminator.load_state_dict(weights['d'], strict=False)
            
        print("Successfully loaded pretrained weights")
        
        # Start training
        return train_with_pretrained(generator, discriminator, dataloader, config)
        
    except Exception as e:
        print(f"Error loading weights: {str(e)}")
        print("Continuing with randomly initialized weights...")
        return train_with_pretrained(generator, discriminator, dataloader, config)

In [143]:
def load_and_initialize(generator, discriminator, config):
    """Load weights and initialize models"""
    loader = StyleGANWeightLoader()
    weights = loader.load_weights()
    
    if weights is not None:
        print("Loading pretrained weights...")
        missing_g, unexpected_g = generator.load_state_dict(weights['g'], strict=False)
        print(f"\nGenerator loading stats:")
        print(f"Missing keys: {len(missing_g)}")
        print(f"Unexpected keys: {len(unexpected_g)}")
        
        if weights['d'] is not None:
            missing_d, unexpected_d = discriminator.load_state_dict(weights['d'], strict=False)
            print(f"\nDiscriminator loading stats:")
            print(f"Missing keys: {len(missing_d)}")
            print(f"Unexpected keys: {len(unexpected_d)}")
    else:
        print("Using random initialization")


def create_dataloader(config):
    """Create dataloader with proper error handling"""
    try:
        # Verify dataset path exists
        if not os.path.exists(config.data_dir):
            raise ValueError(f"Dataset directory {config.data_dir} not found!")
            
        # Create dataset
        transform = transforms.Compose([
            transforms.Resize(config.image_size),
            transforms.CenterCrop(config.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        class CelebADataset(Dataset):
            def __init__(self, root_dir, transform=None):
                self.root_dir = Path(root_dir)
                self.transform = transform
                self.image_paths = list(self.root_dir.glob('*.jpg'))
                
                if len(self.image_paths) == 0:
                    raise ValueError(f"No images found in {root_dir}")
                    
                print(f"Found {len(self.image_paths)} images")
                
            def __len__(self):
                return len(self.image_paths)
                
            def __getitem__(self, idx):
                img_path = self.image_paths[idx]
                image = Image.open(img_path).convert('RGB')
                
                if self.transform:
                    image = self.transform(image)
                return image
                
        dataset = CelebADataset(config.data_dir, transform=transform)
        
        # Create dataloader with proper batch size
        dataloader = DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=True if torch.backends.mps.is_available() else False,
            drop_last=True
        )
        
        return dataloader
        
    except Exception as e:
        print(f"Error creating dataloader: {str(e)}")
        return None

class StyleConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, style_dim, upsample=True):
        super().__init__()
        self.conv1 = ModulatedConv2d(
            in_channels, out_channels, 3, style_dim, upsample=upsample
        )
        self.noise1 = NoiseInjection()
        self.activation1 = nn.LeakyReLU(0.2, inplace=True)
        
        self.conv2 = ModulatedConv2d(
            out_channels, out_channels, 3, style_dim
        )
        self.noise2 = NoiseInjection()
        self.activation2 = nn.LeakyReLU(0.2, inplace=True)
        
    def forward(self, x, style):
        x = self.conv1(x, style)
        x = self.noise1(x)
        x = self.activation1(x)
        
        x = self.conv2(x, style)
        x = self.noise2(x)
        x = self.activation2(x)
        
        return x

def save_image_grid(images, path, nrow=4):
    """Save a grid of images"""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torchvision.utils.save_image(
        images,
        path,
        nrow=nrow,
        normalize=True,
        range=(-1, 1)
    )

class StyleGANConfig:
    def __init__(self):
        # Model architecture
        self.image_size = 256
        self.style_dim = 512
        self.n_mlp = 8
        self.channels = {
            4: 512,    # 4x4
            8: 512,    # 8x8
            16: 512,   # 16x16
            32: 512,   # 32x32
            64: 256,   # 64x64
            128: 128,  # 128x128
            256: 64    # 256x256
        }
        
        # Training parameters
        self.batch_size = 16
        self.lr = 0.002
        self.beta1 = 0.0
        self.beta2 = 0.99
        
        # Dataset parameters
        self.data_dir = 'Data/celeba_hq'
        self.num_workers = 0
        
        # Device
        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [144]:
class StyleGANTrainer:
    def __init__(self, config):
        self.config = config
        self.device = config.device
        
        # Initialize models
        self.generator = PretrainedStyleGAN(config).to(self.device)
        self.discriminator = Discriminator(config).to(self.device)
        
        # Initialize optimizers
        self.g_optim = torch.optim.Adam(
            self.generator.parameters(),
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )
        self.d_optim = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )
        
        # Create directories
        os.makedirs('samples', exist_ok=True)
        os.makedirs('checkpoints', exist_ok=True)
        
    def train(self, dataloader, num_epochs=10):
        print(f"Starting training on {self.device}")
        
        for epoch in range(num_epochs):
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
            epoch_metrics = []
            
            for i, real_imgs in enumerate(progress_bar):
                try:
                    metrics = train_step(
                        real_imgs, 
                        self.generator,
                        self.discriminator,
                        self.g_optim,
                        self.d_optim,
                        self.config,
                        self.device
                    )
                    
                    if metrics is not None:
                        epoch_metrics.append(metrics)
                        
                        # Update progress bar
                        if len(epoch_metrics) > 0:
                            avg_metrics = {
                                k: sum(m[k] for m in epoch_metrics[-100:] if k != 'fake_imgs') / min(len(epoch_metrics), 100)
                                for k in metrics.keys() if k != 'fake_imgs'
                            }
                            progress_bar.set_postfix(avg_metrics)
                        
                        # Save samples
                        if i % 500 == 0:
                            save_samples(
                                metrics['fake_imgs'],
                                f'samples/epoch_{epoch}_batch_{i}.png'
                            )
                            
                            # Save checkpoint
                            self.save_checkpoint(epoch, i, metrics)
                            
                    # Memory management
                    if i % 100 == 0 and torch.backends.mps.is_available():
                        torch.mps.empty_cache()
                        
                except Exception as e:
                    print(f"\nError in batch {i}: {str(e)}")
                    traceback.print_exc()
                    continue
                    
            # Save epoch summary
            self.save_epoch_summary(epoch, epoch_metrics)
            
    def save_checkpoint(self, epoch, batch, metrics):
        checkpoint = {
            'epoch': epoch,
            'batch': batch,
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'g_optimizer_state_dict': self.g_optim.state_dict(),
            'd_optimizer_state_dict': self.d_optim.state_dict(),
            'metrics': {k: v for k, v in metrics.items() if k != 'fake_imgs'}
        }
        torch.save(checkpoint, f'checkpoints/checkpoint_e{epoch}_b{batch}.pt')
        
    def save_epoch_summary(self, epoch, metrics):
        summary = {
            'epoch': epoch,
            'metrics': {
                k: sum(m[k] for m in metrics if k != 'fake_imgs') / len(metrics)
                for k in metrics[0].keys() if k != 'fake_imgs'
            }
        }
        torch.save(summary, f'checkpoints/epoch_{epoch}_summary.pt')

In [182]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pathlib import Path
from tqdm import tqdm
import traceback
import os

class StyleGANTrainer:
    def __init__(self, config):
        self.config = config
        self.device = config.device
        
        # Initialize models
        self.generator = Generator(config).to(self.device)
        self.discriminator = Discriminator(config).to(self.device)
        
        # Initialize optimizers
        self.g_optim = torch.optim.Adam(
            self.generator.parameters(),
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )
        self.d_optim = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )
        
        # Initialize scaler for mixed precision
        self.scaler = torch.amp.GradScaler(enabled=False)  # Disabled for MPS
        
        # Create directories
        os.makedirs('samples', exist_ok=True)
        os.makedirs('checkpoints', exist_ok=True)
        
        # Set up M1 optimizations
        if self.device.type == "mps":
            torch.mps.empty_cache()
            
    def train_step(self, real_imgs):
        """Single training step for both G and D"""
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(self.device)

        # Train Discriminator
        self.d_optim.zero_grad(set_to_none=True)
        
        # Real images
        real_pred = self.discriminator(real_imgs)
        d_real_loss = F.softplus(-real_pred).mean()
        
        # Generate fake images
        z = torch.randn(batch_size, self.config.style_dim).to(self.device)
        with torch.no_grad():
            fake_imgs = self.generator(z)
        fake_pred = self.discriminator(fake_imgs.detach())
        d_fake_loss = F.softplus(fake_pred).mean()
        
        # Combined D loss
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        self.d_optim.step()

        # Train Generator
        self.g_optim.zero_grad(set_to_none=True)
        
        # Generate new fake images
        z = torch.randn(batch_size, self.config.style_dim).to(self.device)
        fake_imgs = self.generator(z)
        fake_pred = self.discriminator(fake_imgs)
        
        # G loss
        g_loss = F.softplus(-fake_pred).mean()
        g_loss.backward()
        self.g_optim.step()

        # Memory management
        if self.device.type == "mps":
            torch.mps.empty_cache()

        return {
            'd_loss': d_loss.item(),
            'g_loss': g_loss.item(),
            'fake_images': fake_imgs.detach()
        }

    def save_samples(self, fake_imgs, epoch, batch_idx):
        """Save generated samples"""
        try:
            samples_dir = Path('samples')
            samples_dir.mkdir(exist_ok=True)
            
            # Convert to range [0, 1]
            fake_imgs = (fake_imgs + 1) / 2.0
            fake_imgs = torch.clamp(fake_imgs, 0, 1)
            
            # Create image grid
            grid = torchvision.utils.make_grid(
                fake_imgs[:16],  # Save top 16 images
                nrow=4,
                padding=2,
                normalize=False
            )
            
            # Save grid
            save_path = samples_dir / f'samples_epoch_{epoch}_batch_{batch_idx}.png'
            torchvision.utils.save_image(grid, save_path)
            print(f"\nSaved samples to {save_path}")
            
        except Exception as e:
            print(f"Error saving samples: {e}")
            traceback.print_exc()
    
    def save_checkpoint(self, epoch, batch_idx, losses):
        """Save model checkpoint"""
        try:
            checkpoint = {
                'epoch': epoch,
                'batch_idx': batch_idx,
                'generator_state': self.generator.state_dict(),
                'discriminator_state': self.discriminator.state_dict(),
                'g_optimizer': self.g_optim.state_dict(),
                'd_optimizer': self.d_optim.state_dict(),
                'losses': losses,
                'config': self.config
            }
            
            save_path = f'checkpoints/styleGAN_epoch{epoch}_batch{batch_idx}.pt'
            torch.save(checkpoint, save_path)
            print(f"Saved checkpoint: {save_path}")
            
        except Exception as e:
            print(f"Error saving checkpoint: {e}")
            traceback.print_exc()
            
    def train(self, dataloader, num_epochs=10):
        """Main training loop"""
        print(f"Starting training on {self.device}")
        
        try:
            for epoch in range(num_epochs):
                pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
                running_d_loss = 0
                running_g_loss = 0
                
                for batch_idx, real_imgs in enumerate(pbar):
                    # Training step
                    metrics = self.train_step(real_imgs)
                    
                    # Update running losses
                    running_d_loss = 0.9 * running_d_loss + 0.1 * metrics['d_loss']
                    running_g_loss = 0.9 * running_g_loss + 0.1 * metrics['g_loss']
                    
                    # Update progress bar
                    pbar.set_postfix({
                        'D_loss': f"{running_d_loss:.4f}",
                        'G_loss': f"{running_g_loss:.4f}"
                    })
                    
                    # Save samples and checkpoint periodically
                    if batch_idx % 100 == 0:
                        self.save_samples(metrics['fake_images'], epoch, batch_idx)
                        self.save_checkpoint(
                            epoch, 
                            batch_idx,
                            {
                                'g_loss': running_g_loss,
                                'd_loss': running_d_loss
                            }
                        )
                        
                    # Memory management
                    if batch_idx % 10 == 0 and self.device.type == "mps":
                        torch.mps.empty_cache()
                
                # Save epoch checkpoint
                self.save_checkpoint(
                    epoch,
                    len(dataloader),
                    {
                        'g_loss': running_g_loss,
                        'd_loss': running_d_loss
                    }
                )
                
        except Exception as e:
            print(f"\nTraining interrupted: {str(e)}")
            traceback.print_exc()
            return False
            
        return True

class CelebADataset(Dataset):
    """Made pickleable for multiprocessing"""
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        
        # Get all image files
        self.image_paths = []
        valid_extensions = {'.jpg', '.jpeg', '.png'}
        
        for ext in valid_extensions:
            self.image_paths.extend(
                [str(p) for p in Path(root_dir).glob(f'*{ext}')]
            )
            
        if not self.image_paths:
            raise RuntimeError(f"No images found in {root_dir}")
            
        print(f"Found {len(self.image_paths)} images")
        
    def __len__(self):
        return len(self.image_paths)
        
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return torch.zeros(3, self.transform.transforms[0].size,
                             self.transform.transforms[0].size)

def create_dataloader(config):
    """Create dataloader with proper multiprocessing settings"""
    try:
        transform = transforms.Compose([
            transforms.Resize(config.image_size),
            transforms.CenterCrop(config.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        dataset = CelebADataset(config.data_dir, transform)
        
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=0,  # Set back to 0 for debugging
            pin_memory=True if torch.backends.mps.is_available() else False,
            drop_last=True
        )
        
    except Exception as e:
        print(f"Error creating dataloader: {str(e)}")
        traceback.print_exc()
        return None

In [183]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Generator(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.style_dim = config.style_dim
        self.num_layers = int(math.log2(config.image_size)) - 1
        self.gradient_checkpointing_enabled = False
        
        # Mapping Network
        mapping_layers = []
        for _ in range(config.n_mlp):
            mapping_layers.extend([
                EqualLinear(config.style_dim, config.style_dim),
                nn.LeakyReLU(0.2, inplace=True)
            ])
        self.mapping = nn.Sequential(*mapping_layers)
        
        # Initial learned constant input
        self.input = nn.Parameter(torch.randn(1, config.channels[4], 4, 4))
        
        # Style blocks for progressive generation
        self.style_blocks = nn.ModuleList()
        in_channel = config.channels[4]
        
        resolutions = [4, 8, 16, 32, 64, 128, 256]  # All possible resolutions
        for i in range(len(resolutions) - 1):
            curr_res = resolutions[i]
            next_res = resolutions[i + 1]
            
            self.style_blocks.append(
                StyleBlock(
                    in_channel, 
                    config.channels[next_res],
                    config.style_dim,
                    upsample=True
                )
            )
            in_channel = config.channels[next_res]
            
        # Final RGB conversion
        self.to_rgb = nn.Sequential(
            nn.Conv2d(in_channel, 3, 1),
            nn.Tanh()
        )
        
        # Initialize weights
        self.apply(self._init_weights)
        
    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing_enabled = True
        
    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing_enabled = False
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
                
    def forward(self, z, return_latents=False):
        batch_size = z.size(0)
        
        # Map to W space
        if self.gradient_checkpointing_enabled and self.training:
            w = torch.utils.checkpoint.checkpoint(self.mapping, z)
        else:
            w = self.mapping(z)
        
        # Replicate w for each style block if needed
        if w.ndim == 2:
            w = w.unsqueeze(1).repeat(1, len(self.style_blocks), 1)
            
        # Start from learned constant
        x = self.input.repeat(batch_size, 1, 1, 1)
        
        # Apply style blocks with optional checkpointing
        for i, block in enumerate(self.style_blocks):
            if self.gradient_checkpointing_enabled and self.training:
                x = torch.utils.checkpoint.checkpoint(block, x, w[:, i])
            else:
                x = block(x, w[:, i])
            
        # Convert to RGB
        out = self.to_rgb(x)
        
        if return_latents:
            return out, w
        return out


class Discriminator(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        # Initial RGB processing
        self.from_rgb = nn.Sequential(
            nn.Conv2d(3, config.channels[256], 1),  # Start with channels for highest resolution
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Progressive downsampling blocks
        self.blocks = nn.ModuleList()
        resolutions = [256, 128, 64, 32, 16, 8, 4]  # From highest to lowest
        
        for i in range(len(resolutions) - 1):
            curr_res = resolutions[i]
            next_res = resolutions[i + 1]
            
            self.blocks.append(
                DiscriminatorBlock(
                    config.channels[curr_res],
                    config.channels[next_res],
                    downsample=True
                )
            )
        
        # Final convolution layers
        final_channels = config.channels[4]
        self.final_conv = nn.Sequential(
            nn.Conv2d(final_channels, final_channels, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(final_channels, final_channels, 4, padding=0),  # Changed to remove padding
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Calculate the flattened size
        self.final_size = final_channels  # This will be 512 based on your config
        
        # Final classification
        self.classifier = nn.Sequential(
            nn.Linear(self.final_size, 1),  # Changed from final_channels * 4 * 4
            nn.Sigmoid()
        )
        
        # Initialize weights
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
                
    def forward(self, x):
        # Initial RGB processing
        x = self.from_rgb(x)  # x shape: [batch_size, 64, 256, 256]
        
        # Progressive downsampling
        for block in self.blocks:
            x = block(x)  # Progressive downsampling through resolutions
            
        # Final convolutions to 1x1
        x = self.final_conv(x)  # Should result in [batch_size, 512, 1, 1]
        
        # Flatten properly
        x = x.view(x.size(0), -1)  # Reshape to [batch_size, 512]
        
        # Classification
        return self.classifier(x)

class DiscriminatorBlock(nn.Module):
    def __init__(self, in_channel, out_channel, downsample=True):
        super().__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        if downsample:
            self.conv2 = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 4, stride=2, padding=1),  # Changed kernel and stride
                nn.LeakyReLU(0.2, inplace=True)
            )
        else:
            self.conv2 = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=True)
            )
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class ModulatedConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, style_dim):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.style_dim = style_dim
        
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
        self.padding = kernel_size // 2
        
        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )
        
        self.modulation = EqualLinear(style_dim, in_channel)
        
    def forward(self, x, style):
        batch, in_channel, height, width = x.shape
        
        # Style modulation
        style = self.modulation(style).view(batch, 1, -1, 1, 1)
        weight = self.scale * self.weight * style
        
        # Demodulation
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
        
        # Reshape for grouped convolution
        weight = weight.view(
            batch * self.out_channel, in_channel, 
            self.kernel_size, self.kernel_size
        )
        x = x.view(1, batch * in_channel, height, width)
        
        # Convolution
        out = F.conv2d(x, weight, padding=self.padding, groups=batch)
        
        return out.view(batch, self.out_channel, height, width)


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))
        
    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()
        return image + self.weight * noise


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, lr_mul=1.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            self.bias = None
            
        self.lr_mul = lr_mul
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        
    def forward(self, input):
        if self.bias is None:
            out = F.linear(input, self.weight * self.scale)
        else:
            out = F.linear(input, self.weight * self.scale, 
                         self.bias * self.lr_mul)
        return out

In [184]:
class StyleGANConfig:
    def __init__(self):
        # Model architecture
        self.image_size = 256
        self.style_dim = 512
        self.n_mlp = 8
        self.channels = {
            4: 512,    # 4x4
            8: 512,    # 8x8 
            16: 512,   # 16x16
            32: 512,   # 32x32
            64: 256,   # 64x64
            128: 128,  # 128x128
            256: 64    # 256x256
        }
        
        # Training parameters
        self.batch_size = 32  # Increased for faster training
        self.lr = 0.0002
        self.beta1 = 0.0
        self.beta2 = 0.99
        
        # Optimization settings
        self.mixed_precision = True
        self.gradient_checkpointing = True
        self.accumulation_steps = 4  # Gradient accumulation steps
        
        # System settings
        self.num_workers = 2
        self.pin_memory = True
        self.prefetch_factor = 2
        
        # Device
        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [185]:
config = StyleGANConfig()
generator = Generator(config).to(config.device)
discriminator = Discriminator(config).to(config.device)

In [186]:
if __name__ == "__main__":
    # Initialize config
    config = StyleGANConfig()
    config.batch_size = 32
    config.lr = 0.0002
    config.data_dir = "Data/celeba_hq"
    
    # Create trainer and dataloader
    trainer = StyleGANTrainer(config)
    dataloader = create_dataloader(config)
    
    if dataloader is not None:
        try:
            success = trainer.train(dataloader, num_epochs=10)
            if success:
                print("Training completed successfully!")
            else:
                print("Training completed with errors.")
        except Exception as e:
            print(f"Training failed: {str(e)}")
            traceback.print_exc()

Found 28000 images
Starting training on mps


Epoch 1/10:   0%|         | 0/875 [00:16<?, ?it/s, D_loss=0.1447, G_loss=0.0396]


Saved samples to samples/samples_epoch_0_batch_0.png


Epoch 1/10:   0%| | 1/875 [00:17<4:20:13, 17.86s/it, D_loss=0.1447, G_loss=0.039

Saved checkpoint: checkpoints/styleGAN_epoch0_batch0.pt


Epoch 1/10:  11%| | 100/875 [22:09<2:50:16, 13.18s/it, D_loss=1.3863, G_loss=0.6


Saved samples to samples/samples_epoch_0_batch_100.png


Epoch 1/10:  12%| | 101/875 [22:10<2:54:10, 13.50s/it, D_loss=1.3863, G_loss=0.6

Saved checkpoint: checkpoints/styleGAN_epoch0_batch100.pt


Epoch 1/10:  18%|▏| 160/875 [35:40<2:39:24, 13.38s/it, D_loss=1.3863, G_loss=0.6


KeyboardInterrupt: 

In [93]:
import dnnlib
import legacy

url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl'
with dnnlib.util.open_url(url) as f:
    pretrained = legacy.load_network_pkl(f)