In [12]:
DATA_DIR = "images/256"     # Folder containing the training images
EPOCHS = 500                         # Number of training epochs
BATCH_SIZE = 32                       # Images per batch (reduce if you run out of memory)
SAVE_EVERY = 10                      # Save samples and a checkpoint every N epochs


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from PIL import Image
import os, glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings

# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation/performance warnings from
# torch and torchvision — consider narrowing the filter.
warnings.filterwarnings('ignore')


# Create the output folders (sample images and checkpoints)
os.makedirs('training/StyleGAN/bilder', exist_ok=True)
os.makedirs('training/StyleGAN/epochen', exist_ok=True)

# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Verwende: {device}")

🚀 Verwende: cuda


In [14]:
class SimpleImageDataset(Dataset):
    """Dataset that serves every JPG/PNG/JPEG image found directly in a folder.

    Each item is an RGB image resized to 256x256, converted to a tensor and
    normalized with mean 0.5 / std 0.5 per channel (i.e. mapped to [-1, 1]).
    """

    def __init__(self, folder):
        """Collect the image paths in ``folder`` and build the transform.

        Args:
            folder: Directory that is searched (non-recursively) for files
                matching ``*.jpg``, ``*.png`` and ``*.jpeg``.
        """
        self.files = []
        for ext in ['*.jpg', '*.png', '*.jpeg']:
            self.files.extend(glob.glob(os.path.join(folder, ext)))
        # glob returns files in filesystem order; sort so that the mapping
        # index -> file is stable between runs and machines.
        self.files.sort()

        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])
        print(f"Gefunden: {len(self.files)} Bilder")

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx`` from disk and return it as a normalized tensor."""
        # Open inside a context manager so the file handle is released right
        # away instead of whenever the garbage collector gets to it.
        # convert() loads the pixel data and returns an independent image, so
        # it is safe to use after the file has been closed.
        with Image.open(self.files[idx]) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)

In [15]:
class Generator(nn.Module):
    """Transposed-convolution generator: 512-dim latent vector -> 3x256x256 image.

    The latent vector is treated as a 512x1x1 feature map, expanded to 4x4 and
    then doubled in resolution six times before being mapped to RGB with tanh.
    """

    def __init__(self):
        super().__init__()

        # Stem: 512x1x1 -> 512x4x4
        self.start = nn.Sequential(
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )

        # Six resolution-doubling stages (4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256),
        # halving the channel count at each stage.
        widths = [512, 256, 128, 64, 32, 16, 8]
        self.ups = nn.ModuleList(
            [self._make_layer(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        )

        # Output head: 8 feature channels -> RGB in [-1, 1]
        self.to_rgb = nn.Sequential(
            nn.Conv2d(8, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def _make_layer(self, in_ch, out_ch):
        """Build one upsampling stage (stride-2 transposed conv + BatchNorm + ReLU)."""
        stage = [
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        return nn.Sequential(*stage)

    def forward(self, z):
        """Map a batch of latent vectors (viewed as (batch, 512, 1, 1)) to images."""
        feat = self.start(z.view(z.size(0), 512, 1, 1))

        for stage in self.ups:
            feat = stage(feat)

        return self.to_rgb(feat)

In [16]:
class Discriminator(nn.Module):
    """Strided-convolution discriminator: 3x256x256 image -> one raw logit.

    No sigmoid is applied; the output is meant for a logits-based loss.
    """

    def __init__(self):
        super().__init__()

        # First stage (256 -> 128) has no BatchNorm.
        modules = [
            nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # Five further halving stages: 128 -> 64 -> 32 -> 16 -> 8 -> 4,
        # doubling the channel count each time.
        widths = [16, 32, 64, 128, 256, 512]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            modules.append(nn.BatchNorm2d(c_out))
            modules.append(nn.LeakyReLU(0.2))

        # Final 4x4 convolution collapses the 4x4 map to a single value.
        modules.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0))

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Return the flattened logits produced by the convolution stack."""
        logits = self.layers(x)
        return logits.view(-1)

In [17]:
def setup_training():
    """Create everything the training loop needs.

    Returns:
        Tuple ``(G, D, opt_G, opt_D, criterion, dataloader)``: generator and
        discriminator on ``device``, one Adam optimizer for each, the
        BCE-with-logits loss and a shuffling DataLoader over ``DATA_DIR``.

    Raises:
        ValueError: If no training images are found in ``DATA_DIR``.
    """
    # Dataset first: fail fast with a clear message before any model is
    # allocated on the device.
    dataset = SimpleImageDataset(DATA_DIR)
    if len(dataset) == 0:
        raise ValueError(
            f"Keine Trainingsbilder (*.jpg, *.png, *.jpeg) in '{DATA_DIR}' gefunden."
        )

    # Models
    G = Generator().to(device)
    D = Discriminator().to(device)

    # Optimizers
    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Loss (expects raw logits, which is what the discriminator returns)
    criterion = nn.BCEWithLogitsLoss()

    # Never start more loader workers than the machine has CPUs, and only pin
    # host memory when batches are actually copied to a CUDA device.
    num_workers = min(16, os.cpu_count() or 1)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(device.type == 'cuda'),
    )

    return G, D, opt_G, opt_D, criterion, dataloader


In [18]:
def save_samples(G, epoch, n_samples=16):
    """Generate sample images, save them as a grid PNG and display the figure.

    Args:
        G: Generator; called with a ``(n_samples, 512)`` noise batch.
        epoch: Epoch number used in the figure title and in the file name
            ``training/StyleGAN/bilder/epoch_XXXX.png``.
        n_samples: Number of images to generate (laid out 4 per row).

    The generator is switched to eval mode while sampling and is put back
    into training mode afterwards, even if plotting or saving fails.
    """
    G.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n_samples, 512, device=device)
            fake = G(z)

            # Build the image grid; generator output in [-1, 1] is rescaled to [0, 1]
            grid = make_grid(fake, nrow=4, normalize=True, value_range=(-1, 1))

        # Save as an image file and show it
        fig = plt.figure(figsize=(10, 10))
        try:
            # CHW tensor -> HWC NumPy array, as expected by imshow
            plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
            plt.axis('off')
            plt.title(f'Epoch {epoch}')
            plt.savefig(f'training/StyleGAN/bilder/epoch_{epoch:04d}.png', bbox_inches='tight', dpi=150)
            plt.show()
        finally:
            # Release the figure; this function is called repeatedly during
            # training and open figures would otherwise pile up.
            plt.close(fig)
    finally:
        G.train()

def save_checkpoint(G, D, opt_G, opt_D, epoch, losses):
    """Write model and optimizer state plus the loss history to a .pth file.

    The file is stored as ``training/StyleGAN/epochen/checkpoint_XXXX.pth``
    where ``XXXX`` is the zero-padded epoch number.
    """
    state = dict(
        epoch=epoch,
        generator=G.state_dict(),
        discriminator=D.state_dict(),
        opt_G=opt_G.state_dict(),
        opt_D=opt_D.state_dict(),
        losses=losses,
    )
    target = f'training/StyleGAN/epochen/checkpoint_{epoch:04d}.pth'
    torch.save(state, target)
    
def plot_losses(losses):
    """Show the generator and discriminator loss curves in one figure.

    Args:
        losses: Mapping with the keys ``'G'`` and ``'D'``, each holding a
            sequence of loss values (one entry per epoch).
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    curves = (
        ('G', 'Generator', 'blue'),
        ('D', 'Discriminator', 'red'),
    )
    for key, label, color in curves:
        ax.plot(losses[key], label=label, color=color)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.set_title('Training Losses')
    ax.grid(True, alpha=0.3)
    plt.show()

In [19]:
from torch.amp import autocast, GradScaler

def train():
    """Run the adversarial training loop with mixed precision.

    For ``EPOCHS`` epochs, every batch first updates the discriminator
    (real batch vs. detached fake batch) and then the generator (fakes scored
    against the "real" label). Every ``SAVE_EVERY`` epochs sample images and
    a checkpoint are written and the loss curves are plotted.

    Returns:
        Tuple ``(G, D, losses)`` where ``losses`` is a dict with the keys
        ``'G'`` and ``'D'`` holding the per-epoch average losses.
    """
    print("Training starten...")

    G, D, opt_G, opt_D, criterion, dataloader = setup_training()

    # GradScaler documents its `device` argument as a device-type string
    # ("cuda" / "cpu"); passing the torch.device object would bypass its
    # string comparison against "cuda".
    scaler_G = GradScaler(device=device.type)
    scaler_D = GradScaler(device=device.type)

    losses = {'G': [], 'D': []}
    real_label = 1.0
    fake_label = 0.0

    for epoch in range(EPOCHS):
        epoch_d_loss = 0
        epoch_g_loss = 0

        pbar = tqdm(dataloader, desc=f"Epoche {epoch+1}/{EPOCHS}", leave=False)

        for real_batch in pbar:
            # non_blocking lets the copy overlap with compute when the loader
            # delivers pinned memory; it is a no-op otherwise.
            real_batch = real_batch.to(device, non_blocking=True)
            # The last batch of an epoch may be smaller than BATCH_SIZE.
            batch_size = real_batch.size(0)

            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            # ----- Discriminator -----
            # Also clears the gradients that the previous generator step
            # accumulated in D while back-propagating through it.
            D.zero_grad()

            with autocast(device_type=device.type):
                output = D(real_batch)
                d_loss_real = criterion(output, real_labels)

                noise = torch.randn(batch_size, 512, device=device)
                fake = G(noise)

                # detach(): the discriminator update must not send gradients
                # into the generator.
                output_fake = D(fake.detach())
                d_loss_fake = criterion(output_fake, fake_labels)

                d_loss = d_loss_real + d_loss_fake

            scaler_D.scale(d_loss).backward()
            scaler_D.step(opt_D)
            scaler_D.update()

            # ----- Generator -----
            G.zero_grad()

            with autocast(device_type=device.type):
                # Re-score the same fakes with the freshly updated D; the
                # generator is rewarded when D labels them as real.
                output = D(fake)
                g_loss = criterion(output, real_labels)

            scaler_G.scale(g_loss).backward()
            scaler_G.step(opt_G)
            scaler_G.update()

            batch_d_loss = d_loss.item()
            batch_g_loss = g_loss.item()
            epoch_d_loss += batch_d_loss
            epoch_g_loss += batch_g_loss

            pbar.set_postfix({
                "D_Loss": f"{batch_d_loss:.4f}",
                "G_Loss": f"{batch_g_loss:.4f}"
            })

        # Average over the number of batches in this epoch
        avg_d_loss = epoch_d_loss / len(dataloader)
        avg_g_loss = epoch_g_loss / len(dataloader)
        losses['D'].append(avg_d_loss)
        losses['G'].append(avg_g_loss)

        if (epoch + 1) % SAVE_EVERY == 0:
            save_samples(G, epoch + 1)
            save_checkpoint(G, D, opt_G, opt_D, epoch + 1, losses)
            plot_losses(losses)
        print(f"\nEpoch {epoch+1}: G_Loss={avg_g_loss:.4f}, D_Loss={avg_d_loss:.4f}")

    print("Training abgeschlossen!")
    return G, D, losses


### Bilder generieren


In [20]:
def generate_images(G=None, n=16):
    """Generate ``n`` new images and display them as a grid (4 per row).

    Args:
        G: Generator to sample from. When omitted, the module-level global
            named ``G`` is used; if that does not exist either, a message is
            printed and the function returns without generating anything.
        n: Number of images to generate.

    The generator is switched to eval mode and left in that mode.
    """
    if G is None:
        if 'G' not in globals():
            print("Kein Generator verfügbar! Führe zuerst das Training aus oder lade ein Checkpoint.")
            return
        G = globals()['G']

    G.eval()
    with torch.no_grad():
        latent = torch.randn(n, 512, device=device)
        images = G(latent)

        # Rescale generator output from [-1, 1] to [0, 1] and tile it
        tiled = make_grid(images, nrow=4, normalize=True, value_range=(-1, 1))
        # CHW tensor -> HWC NumPy array for imshow
        canvas = tiled.permute(1, 2, 0).cpu().numpy()

        plt.figure(figsize=(12, 12))
        plt.imshow(canvas)
        plt.axis('off')
        plt.title('Generierte Bilder')
        plt.show()

In [21]:
def load_checkpoint(path):
    """Load a saved generator from a checkpoint file.

    Args:
        path: Path to a ``.pth`` file containing at least the keys
            ``'generator'`` (a Generator state dict) and ``'epoch'``.

    Returns:
        A ``Generator`` on ``device`` with the stored weights, in eval mode.
    """
    # torch.load unpickles the file. weights_only=True restricts that to
    # tensors and plain containers/primitives, so a tampered checkpoint cannot
    # execute arbitrary code. A checkpoint holding only state dicts, an epoch
    # number and lists of floats loads fine under this restriction.
    checkpoint = torch.load(path, map_location=device, weights_only=True)

    G = Generator().to(device)
    G.load_state_dict(checkpoint['generator'])
    G.eval()

    print(f"Model geladen von Epoch {checkpoint['epoch']}")
    return G


In [None]:
import pandas as pd

# Export the per-epoch loss history as CSV.
# Requires `losses` (the dict returned by train()) to exist already.
LOSSES_CSV = 'training/StyleGANv2/losses.csv'

# The setup code only creates training/StyleGAN/...; to_csv does not create
# missing parent directories, so make sure this one exists.
os.makedirs(os.path.dirname(LOSSES_CSV), exist_ok=True)

df_losses = pd.DataFrame({
    'Epoch': list(range(1, len(losses['G']) + 1)),
    'Generator_Loss': losses['G'],
    'Discriminator_Loss': losses['D']
})
df_losses.to_csv(LOSSES_CSV, index=False)
print(f"Loss-Tabelle als CSV gespeichert: {LOSSES_CSV}")


In [None]:
# Entry point: start the training run
if __name__ == "__main__":
    # DATA_DIR is deliberately NOT re-assigned here: it comes from the
    # configuration cell at the top, and overriding it at this point would
    # silently discard any change made there.

    # Start training
    G, D, losses = train()