In [54]:
DATA_DIR = "images/256"     # Deine Trainingsbilder
EPOCHS = 5000                         # Anzahl Trainingsdurchläufe
BATCH_SIZE = 128                     # Bilder pro Batch (bei Speicherproblemen reduzieren)
SAVE_EVERY = 10                      # Alle X Epochen speichern


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from PIL import Image
import os, glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')


# Ordner anlegen
os.makedirs('training/StyleGANv2NEW/bilder', exist_ok=True)
os.makedirs('training/StyleGANv2NEW/epochen', exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Verwende: {device}")

🚀 Verwende: cuda


In [56]:
class SimpleImageDataset(Dataset):
    def __init__(self, folder):
        self.files = []
        # Erweiterte Dateierkennung
        extensions = ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG', '*.JPEG']
        for ext in extensions:
            self.files.extend(glob.glob(os.path.join(folder, ext)))
        
        if len(self.files) == 0:
            raise ValueError(f"Keine Bilder in {folder} gefunden! Überprüfe den Pfad.")
        
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])
        print(f"Gefunden: {len(self.files)} Bilder")
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        try:
            img = Image.open(self.files[idx]).convert('RGB')
            return self.transform(img)
        except Exception as e:
            print(f"Fehler beim Laden von {self.files[idx]}: {e}")
            # Fallback: zufälliges anderes Bild
            return self.__getitem__((idx + 1) % len(self.files))

In [None]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            # Start: (512,) → (512, 4, 4)
            nn.ConvTranspose2d(512, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            # 4 → 8
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # 8 → 16
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # 16 → 32
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 32 → 64
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            # 64 → 128
            nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True),

            # 128 → 256
            nn.ConvTranspose2d(16, 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU(True),

            # Final RGB Output
            nn.Conv2d(8, 3, 3, 1, 1, bias=False),
            nn.Tanh()
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def forward(self, z):
        # Erwartet: z.shape = (batch, 512)
        z = z.view(z.size(0), 512, 1, 1)
        return self.model(z)


In [58]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.layers = nn.Sequential(
            # 256->128
            nn.Conv2d(3, 16, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 128->64
            nn.Conv2d(16, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 64->32
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 32->16
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 16->8
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 8->4
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 4->1
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
        )
        
        # Gewichte initialisieren
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
    
    def forward(self, x):
        return self.layers(x).view(-1)

In [None]:
def find_latest_checkpoint():
    """Findet das neueste Checkpoint automatisch"""
    checkpoint_dir = 'training/StyleGANv2NEW/epochen'
    if not os.path.exists(checkpoint_dir):
        return None
    
    checkpoints = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_*.pth'))
    if not checkpoints:
        return None
    
    # Sortiere nach Epoch-Nummer
    checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
    latest = checkpoints[-1]
    print(f"Neuestes Checkpoint gefunden: {latest}")
    return latest

In [None]:
def setup_training(resume_from=None):
    # Models
    G = Generator().to(device)
    D = Discriminator().to(device)
    
    # Optimizers
    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999))
    
    # Loss
    criterion = nn.BCEWithLogitsLoss()
    
    workers = max(1, os.cpu_count() - 1)  # immer mindestens 1

    # Dataset
    dataset = SimpleImageDataset(DATA_DIR)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=workers,
        pin_memory=True if device.type == 'cuda' else False
)
    
    # Resume Setup
    start_epoch = 0
    losses = {'G': [], 'D': []}
    
    if resume_from:
        print(f"Lade Checkpoint: {resume_from}")
        if os.path.exists(resume_from):
            checkpoint = torch.load(resume_from, map_location=device)
            
            # Models laden
            G.load_state_dict(checkpoint['generator'])
            D.load_state_dict(checkpoint['discriminator'])
            
            # Optimizer states laden
            opt_G.load_state_dict(checkpoint['opt_G'])
            opt_D.load_state_dict(checkpoint['opt_D'])
            
            # Training state laden
            start_epoch = checkpoint['epoch']
            losses = checkpoint.get('losses', {'G': [], 'D': []})
            
            print(f" Resume von Epoch {start_epoch}")
        else:
            print(f"Checkpoint {resume_from} nicht gefunden! Starte neues Training.")
    
    return G, D, opt_G, opt_D, criterion, dataloader, start_epoch, losses


In [None]:
def save_samples(G, epoch, n_samples=16):
    """Beispielbilder speichern"""
    G.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, 512, device=device)
        fake = G(z)
        
        # Grid erstellen
        grid = make_grid(fake, nrow=4, normalize=True, value_range=(-1, 1))
        
        # Als Bild speichern
        plt.figure(figsize=(10, 10))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.title(f'Epoch {epoch}')
        plt.savefig(f'training/StyleGANv2NEW/bilder/epoch_{epoch:04d}.png', bbox_inches='tight', dpi=150)
        plt.close()  # Memory cleanup
    G.train()

def save_checkpoint(G, D, opt_G, opt_D, epoch, losses):
    """Model speichern"""
    checkpoint_path = f'training/StyleGANv2NEW/epochen/checkpoint_{epoch:04d}.pth'
    torch.save({
        'epoch': epoch,
        'generator': G.state_dict(),
        'discriminator': D.state_dict(),
        'opt_G': opt_G.state_dict(),
        'opt_D': opt_D.state_dict(),
        'losses': losses
    }, checkpoint_path)
    
    # Auch als "latest" speichern für einfaches Resume
    latest_path = 'training/StyleGANv2NEW/epochen/latest_checkpoint.pth'
    torch.save({
        'epoch': epoch,
        'generator': G.state_dict(),
        'discriminator': D.state_dict(),
        'opt_G': opt_G.state_dict(),
        'opt_D': opt_D.state_dict(),
        'losses': losses
    }, latest_path)
    
    print(f" Checkpoint gespeichert: {checkpoint_path}")
    
def plot_losses(losses):
    """Loss-Verlauf anzeigen"""
    if not losses['G'] or not losses['D']:
        return
        
    plt.figure(figsize=(10, 4))
    plt.plot(losses['G'], label='Generator', color='blue')
    plt.plot(losses['D'], label='Discriminator', color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Losses')
    plt.grid(True, alpha=0.3)
    plt.show()

In [None]:
from torch.amp import autocast, GradScaler

def train(resume_from=None):
    print(" Training starten...")

    # Setup mit Resume-Funktionalität
    G, D, opt_G, opt_D, criterion, dataloader, start_epoch, losses = setup_training(resume_from)

    # Mixed Precision Training
    scaler_G = GradScaler(device=device)
    scaler_D = GradScaler(device=device)

    # Fixed noise für konsistente Samples
    fixed_noise = torch.randn(16, 512, device=device)

    real_label = 0.9  # Label smoothing
    fake_label = 0.1

    print(f"Starte Training von Epoch {start_epoch + 1} bis {EPOCHS}")

    for epoch in range(start_epoch, EPOCHS):
        epoch_d_loss = 0
        epoch_g_loss = 0

        pbar = tqdm(dataloader, desc=f"Epoche {epoch+1}/{EPOCHS}", leave=False)

        for real_batch in pbar:
            real_batch = real_batch.to(device)
            batch_size = real_batch.size(0)

            # Label smoothing mit leichtem Noise
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            real_labels += 0.05 * torch.randn_like(real_labels)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
            fake_labels += 0.05 * torch.randn_like(fake_labels)

            # ----- Discriminator Update -----
            D.zero_grad()

            with autocast(device_type=device.type):
                # Real images
                output = D(real_batch)
                d_loss_real = criterion(output, real_labels)

                # Fake images
                noise = torch.randn(batch_size, 512, device=device)
                fake = G(noise)
                output_fake = D(fake.detach())
                d_loss_fake = criterion(output_fake, fake_labels)

                d_loss = d_loss_real + d_loss_fake

            scaler_D.scale(d_loss).backward()
            scaler_D.step(opt_D)
            scaler_D.update()

            # ----- Generator Update -----
            G.zero_grad()

            with autocast(device_type=device.type):
                output = D(fake)
                g_loss = criterion(output, torch.full((batch_size,), real_label, dtype=torch.float, device=device))

            scaler_G.scale(g_loss).backward()
            scaler_G.step(opt_G)
            scaler_G.update()

            batch_d_loss = d_loss.item()
            batch_g_loss = g_loss.item()
            epoch_d_loss += batch_d_loss
            epoch_g_loss += batch_g_loss

            pbar.set_postfix({
                "D_Loss": f"{batch_d_loss:.4f}",
                "G_Loss": f"{batch_g_loss:.4f}"
            })

        # Epoch Statistics
        avg_d_loss = epoch_d_loss / len(dataloader)
        avg_g_loss = epoch_g_loss / len(dataloader)
        losses['D'].append(avg_d_loss)
        losses['G'].append(avg_g_loss)

        print(f" Epoch {epoch+1}: G_Loss={avg_g_loss:.4f}, D_Loss={avg_d_loss:.4f}")

        # Speichern alle SAVE_EVERY Epochen
        if (epoch + 1) % SAVE_EVERY == 0:
            # Fixed noise samples für Konsistenz
            G.eval()
            with torch.no_grad():
                test_fake = G(fixed_noise)
                grid = make_grid(test_fake, nrow=4, normalize=True, value_range=(-1, 1))
                plt.figure(figsize=(10, 10))
                plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
                plt.axis('off')
                plt.title(f'Epoch {epoch + 1}')
                plt.savefig(f'training/StyleGANv2NEW/bilder/epoch_{epoch + 1:04d}.png', bbox_inches='tight', dpi=150)
                plt.show()
            G.train()
            
            save_checkpoint(G, D, opt_G, opt_D, epoch + 1, losses)
            plot_losses(losses)

        # Backup speichern (überschreibt das vorherige)
        if (epoch + 1) % 5 == 0:  # Häufiger speichern für Resume
            save_checkpoint(G, D, opt_G, opt_D, epoch + 1, losses)

    print(" Training abgeschlossen!")

    import pandas as pd
    df_losses = pd.DataFrame({
        'Epoch': list(range(1, len(losses['G']) + 1)),
        'Generator_Loss': losses['G'],
        'Discriminator_Loss': losses['D']
    })
    df_losses.to_csv('training/StyleGANv2NEW/losses.csv', index=False)
    print(" Loss-Tabelle als CSV gespeichert: training/StyleGANv2NEW/losses.csv")

    return G, D, losses

### Bilder Generieren


In [63]:
def generate_images(G=None, n=16):
    """Neue Bilder generieren"""
    if G is None:
        if 'G' in globals():
            G = globals()['G']
        else:
            print("Kein Generator verfügbar! Führe zuerst das Training aus oder lade ein Checkpoint.")
            return
    
    G.eval()
    with torch.no_grad():
        z = torch.randn(n, 512, device=device)
        fake = G(z)
        
        grid = make_grid(fake, nrow=4, normalize=True, value_range=(-1, 1))
        
        plt.figure(figsize=(12, 12))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.title('Generierte Bilder')
        plt.show()

In [64]:
def load_checkpoint(path):
    """Gespeichertes Model laden"""
    if not os.path.exists(path):
        print(f"Checkpoint {path} nicht gefunden!")
        return None
        
    checkpoint = torch.load(path, map_location=device)
    
    G = Generator().to(device)
    G.load_state_dict(checkpoint['generator'])
    G.eval()
    
    print(f"Model geladen von Epoch {checkpoint['epoch']}")
    return G

In [None]:
def resume_latest():
    """Automatisch vom neuesten Checkpoint fortsetzen"""
    latest = find_latest_checkpoint()
    if latest:
        return train(resume_from=latest)
    else:
        print("Kein Checkpoint gefunden. Starte neues Training.")
        return train()

def resume_from_epoch(epoch):
    """Von einer bestimmten Epoch fortsetzen"""
    checkpoint_path = f'training/StyleGANv2NEW/epochen/checkpoint_{epoch:04d}.pth'
    if os.path.exists(checkpoint_path):
        return train(resume_from=checkpoint_path)
    else:
        print(f"Checkpoint für Epoch {epoch} nicht gefunden!")
        return None

In [66]:
if __name__ == "__main__":
    # Option 1: Neues Training starten
    # G, D, losses = train()
    
    # Option 2: Vom neuesten Checkpoint fortsetzen
    # G, D, losses = resume_latest()
    
    # Option 3: Von spezifischer Epoch fortsetzen
     G, D, losses = resume_from_epoch(570)
    
    # Option 4: Von spezifischem Checkpoint fortsetzen
    # G, D, losses = train(resume_from='training/StyleGANv2/epochen/checkpoint_0050.pth')

🚀 Training starten...
📁 Gefunden: 34846 Bilder
🔄 Lade Checkpoint: training/StyleGANv2/epochen/checkpoint_0570.pth


RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "model.0.weight", "model.1.weight", "model.1.bias", "model.1.running_mean", "model.1.running_var", "model.3.weight", "model.4.weight", "model.4.bias", "model.4.running_mean", "model.4.running_var", "model.6.weight", "model.7.weight", "model.7.bias", "model.7.running_mean", "model.7.running_var", "model.9.weight", "model.10.weight", "model.10.bias", "model.10.running_mean", "model.10.running_var", "model.12.weight", "model.13.weight", "model.13.bias", "model.13.running_mean", "model.13.running_var", "model.15.weight", "model.16.weight", "model.16.bias", "model.16.running_mean", "model.16.running_var", "model.18.weight", "model.19.weight", "model.19.bias", "model.19.running_mean", "model.19.running_var", "model.21.weight". 
	Unexpected key(s) in state_dict: "start.0.weight", "start.1.weight", "start.1.bias", "start.1.running_mean", "start.1.running_var", "start.1.num_batches_tracked", "ups.0.0.weight", "ups.0.1.weight", "ups.0.1.bias", "ups.0.1.running_mean", "ups.0.1.running_var", "ups.0.1.num_batches_tracked", "ups.1.0.weight", "ups.1.1.weight", "ups.1.1.bias", "ups.1.1.running_mean", "ups.1.1.running_var", "ups.1.1.num_batches_tracked", "ups.2.0.weight", "ups.2.1.weight", "ups.2.1.bias", "ups.2.1.running_mean", "ups.2.1.running_var", "ups.2.1.num_batches_tracked", "ups.3.0.weight", "ups.3.1.weight", "ups.3.1.bias", "ups.3.1.running_mean", "ups.3.1.running_var", "ups.3.1.num_batches_tracked", "ups.4.0.weight", "ups.4.1.weight", "ups.4.1.bias", "ups.4.1.running_mean", "ups.4.1.running_var", "ups.4.1.num_batches_tracked", "ups.5.0.weight", "ups.5.1.weight", "ups.5.1.bias", "ups.5.1.running_mean", "ups.5.1.running_var", "ups.5.1.num_batches_tracked", "to_rgb.0.weight". 