In [None]:
# ---- Training configuration (edit these before running) ----
DATA_DIR = "images/256"   # folder containing the training images
EPOCHS = 570              # total number of training epochs
BATCH_SIZE = 128          # images per batch (lower this if memory runs out)
SAVE_EVERY = 10           # every N epochs: save a sample grid and a checkpoint


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from PIL import Image
import os, glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')


# Output folders: sample grids ("bilder") and checkpoints ("epochen")
os.makedirs('training/StyleGANv2.3/bilder', exist_ok=True)
os.makedirs('training/StyleGANv2.3/epochen', exist_ok=True)

# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f" Verwende: {device}")

In [None]:
class SimpleImageDataset(Dataset):
    """Flat folder of images, returned as normalized RGB tensors.

    Every non-hidden file directly inside ``folder`` whose extension matches one
    of ``extensions`` (case-insensitively) is used, in sorted order. Each image
    is resized to ``image_size x image_size``, converted to a tensor and
    normalized from [0, 1] to [-1, 1].

    Args:
        folder: Directory with the training images (not searched recursively).
        image_size: Side length the images are resized to (default 256).
        extensions: Accepted file extensions, compared case-insensitively.

    Raises:
        ValueError: If no matching image is found in ``folder``.
    """

    DEFAULT_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, folder, image_size=256, extensions=DEFAULT_EXTENSIONS):
        self.files = self._collect_image_files(folder, extensions)

        if len(self.files) == 0:
            raise ValueError(f"Keine Bilder in {folder} gefunden! Überprüfe den Pfad.")

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ])
        print(f"Gefunden: {len(self.files)} Bilder")

    @staticmethod
    def _collect_image_files(folder, extensions=DEFAULT_EXTENSIONS):
        """Return the sorted image paths in ``folder`` (empty list if it does not exist).

        Matching the lower-cased extension against a set avoids listing a file
        twice on case-insensitive filesystems and also accepts mixed-case
        extensions such as ``.Jpeg``. Hidden files (leading '.') are skipped.
        """
        if not os.path.isdir(folder):
            return []
        wanted = {ext.lower() for ext in extensions}
        found = []
        for name in os.listdir(folder):
            if name.startswith('.'):
                continue
            path = os.path.join(folder, name)
            if os.path.splitext(name)[1].lower() in wanted and os.path.isfile(path):
                found.append(path)
        return sorted(found)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx``; if it cannot be read, try the following files in turn.

        Raises:
            RuntimeError: If none of the files can be loaded. The search is
                bounded by the number of files, so it always terminates.
        """
        n_files = len(self.files)
        for offset in range(n_files):
            path = self.files[(idx + offset) % n_files]
            try:
                # Context manager closes the file handle right after decoding
                with Image.open(path) as img:
                    return self.transform(img.convert('RGB'))
            except Exception as e:
                print(f"Fehler beim Laden von {path}: {e}")
        raise RuntimeError(f"Keine der {n_files} Bilddateien konnte geladen werden.")

In [None]:
class Generator(nn.Module):
    """Maps a 512-d latent vector to a 3x256x256 image with values in [-1, 1].

    Architecture: a transposed conv lifts the latent from 1x1 to 4x4, six
    stride-2 stages double the resolution up to 256x256, and a 3x3 conv + Tanh
    produces RGB. The submodule names ``start``, ``ups`` and ``to_rgb`` are the
    state_dict keys stored in checkpoints and must not change.
    """

    def __init__(self):
        super().__init__()

        # (B, 512, 1, 1) -> (B, 512, 4, 4)
        self.start = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        )

        # 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256, halving the channels each time
        channel_plan = [512, 256, 128, 64, 32, 16, 8]
        self.ups = nn.ModuleList(
            self._make_layer(c_in, c_out)
            for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:])
        )

        # 8 feature maps -> 3 colour channels squashed into [-1, 1]
        self.to_rgb = nn.Sequential(
            nn.Conv2d(8, 3, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

        self.apply(self._init_weights)

    def _make_layer(self, in_ch, out_ch):
        """One stage: stride-2 transposed conv (doubles H and W), BatchNorm, ReLU."""
        upsample = nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False)
        return nn.Sequential(upsample, nn.BatchNorm2d(out_ch), nn.ReLU(True))

    def _init_weights(self, m):
        """DCGAN init: conv weights ~ N(0, 0.02); BatchNorm gamma ~ N(1, 0.02), beta = 0."""
        is_conv = isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
        if is_conv:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
            return
        if isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def forward(self, z):
        h = self.start(z.view(z.size(0), 512, 1, 1))
        for stage in self.ups:
            h = stage(h)
        return self.to_rgb(h)

In [None]:
class Discriminator(nn.Module):
    """Scores a 3x256x256 image with a single raw logit (no sigmoid applied).

    Six stride-2 convolutions reduce 256x256 to 4x4 (all but the first followed
    by BatchNorm), each with LeakyReLU(0.2); a final 4x4 conv yields one value
    per sample. The flat ``layers`` Sequential defines the checkpoint
    state_dict keys (``layers.<index>.*``), so the layer order must not change.
    """

    def __init__(self):
        super().__init__()

        widths = [16, 32, 64, 128, 256, 512]

        # 256 -> 128: first stage without BatchNorm
        stack = [
            nn.Conv2d(3, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 128 -> 64 -> 32 -> 16 -> 8 -> 4: conv + BatchNorm + LeakyReLU
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stack += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # 4x4 -> 1x1 logit
        stack.append(nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False))

        self.layers = nn.Sequential(*stack)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """DCGAN init: conv weights ~ N(0, 0.02); BatchNorm gamma ~ N(1, 0.02), beta = 0."""
        if isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

    def forward(self, x):
        logits = self.layers(x)
        return logits.view(-1)

In [None]:
def find_latest_checkpoint(checkpoint_dir='training/StyleGANv2.3/epochen'):
    """Return the path of the checkpoint with the highest epoch number.

    Only files named ``checkpoint_<digits>.pth`` directly inside
    ``checkpoint_dir`` are considered. ``latest_checkpoint.pth`` and
    hand-named files such as ``checkpoint_best.pth`` are ignored instead of
    crashing the epoch-number parsing.

    Args:
        checkpoint_dir: Folder the checkpoints are stored in.

    Returns:
        Path of the newest epoch checkpoint, or ``None`` if the folder does not
        exist or contains no matching file.
    """
    if not os.path.isdir(checkpoint_dir):
        return None

    prefix, suffix = 'checkpoint_', '.pth'
    pattern = os.path.join(glob.escape(checkpoint_dir), f'{prefix}*{suffix}')
    numbered = []
    for path in glob.glob(pattern):
        epoch_str = os.path.basename(path)[len(prefix):-len(suffix)]
        # isdecimal (not isdigit): exactly the characters int() accepts
        if epoch_str.isdecimal():
            numbered.append((int(epoch_str), path))

    if not numbered:
        return None

    _, latest = max(numbered)
    print(f"Neuestes Checkpoint gefunden: {latest}")
    return latest

In [None]:
def setup_training(resume_from=None):
    """Build models, optimizers, loss and data loader, optionally resuming from a checkpoint.

    Args:
        resume_from: Path of a checkpoint written by ``save_checkpoint``. With
            ``None`` training starts from scratch. If the file does not exist,
            a message is printed and training also starts from scratch.

    Returns:
        ``(G, D, opt_G, opt_D, criterion, dataloader, start_epoch, losses)``.
        ``start_epoch`` is the ``'epoch'`` value stored in the checkpoint (0
        for a fresh start). ``losses`` is the ``{'G': [...], 'D': [...]}``
        history.
    """
    # Models
    G = Generator().to(device)
    D = Discriminator().to(device)

    # Optimizers (identical settings for G and D)
    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # The discriminator returns raw logits, so use the logits variant of BCE
    criterion = nn.BCEWithLogitsLoss()

    # os.cpu_count() may return None; keep one core free but use >= 1 worker
    workers = max(1, (os.cpu_count() or 1) - 1)

    # Dataset
    dataset = SimpleImageDataset(DATA_DIR)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=workers,
        pin_memory=device.type == 'cuda',
    )

    # Resume setup
    start_epoch = 0
    losses = {'G': [], 'D': []}

    if resume_from:
        print(f"Lade Checkpoint: {resume_from}")
        if os.path.exists(resume_from):
            # The checkpoint holds only state dicts, ints and float lists, so the
            # restricted unpickler suffices and cannot run arbitrary code.
            checkpoint = torch.load(resume_from, map_location=device, weights_only=True)

            # Model weights
            G.load_state_dict(checkpoint['generator'])
            D.load_state_dict(checkpoint['discriminator'])

            # Optimizer state (Adam moments, step counts)
            opt_G.load_state_dict(checkpoint['opt_G'])
            opt_D.load_state_dict(checkpoint['opt_D'])

            # Training progress
            start_epoch = checkpoint['epoch']
            losses = checkpoint.get('losses', {'G': [], 'D': []})

            print(f" Resume von Epoch {start_epoch}")
        else:
            print(f" Checkpoint {resume_from} nicht gefunden! Starte neues Training.")

    return G, D, opt_G, opt_D, criterion, dataloader, start_epoch, losses


In [None]:
def save_samples(G, epoch, n_samples=16, out_dir='training/StyleGANv2.3/bilder'):
    """Save a grid of ``n_samples`` random generator outputs as a PNG.

    The grid (4 columns) is written to ``<out_dir>/epoch_<epoch:04d>.png``.
    G is switched to eval mode for sampling. Its previous train/eval mode is
    restored and the figure closed afterwards, even if sampling fails.

    Args:
        G: Generator taking ``(n, 512)`` latent vectors.
        epoch: Epoch number used in the title and file name.
        n_samples: Number of images to generate.
        out_dir: Target folder (created if missing).
    """
    os.makedirs(out_dir, exist_ok=True)
    was_training = G.training
    G.eval()
    fig = None
    try:
        with torch.no_grad():
            z = torch.randn(n_samples, 512, device=device)
            fake = G(z)
            grid = make_grid(fake, nrow=4, normalize=True, value_range=(-1, 1))

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
        ax.axis('off')
        ax.set_title(f'Epoch {epoch}')
        fig.savefig(os.path.join(out_dir, f'epoch_{epoch:04d}.png'), bbox_inches='tight', dpi=150)
    finally:
        if fig is not None:
            plt.close(fig)  # free figure memory
        G.train(was_training)

def save_checkpoint(G, D, opt_G, opt_D, epoch, losses,
                    checkpoint_dir='training/StyleGANv2.3/epochen'):
    """Save models, optimizer states and loss history for ``epoch``.

    Writes ``checkpoint_<epoch:04d>.pth`` plus an identical
    ``latest_checkpoint.pth`` (for easy resuming) into ``checkpoint_dir``,
    which is created if missing. Each file is written under a temporary name
    and then moved into place with ``os.replace``. An interrupted save
    therefore never leaves a truncated checkpoint behind.

    Args:
        G, D: Generator and discriminator.
        opt_G, opt_D: Their optimizers.
        epoch: Epoch number stored in the file and used in its name.
        losses: ``{'G': [...], 'D': [...]}`` loss history.
        checkpoint_dir: Target folder.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build the payload once and reuse it for both files
    state = {
        'epoch': epoch,
        'generator': G.state_dict(),
        'discriminator': D.state_dict(),
        'opt_G': opt_G.state_dict(),
        'opt_D': opt_D.state_dict(),
        'losses': losses,
    }

    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{epoch:04d}.pth')
    latest_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pth')
    for path in (checkpoint_path, latest_path):
        tmp_path = path + '.tmp'
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)  # atomic rename

    print(f" Checkpoint gespeichert: {checkpoint_path}")
    
def plot_losses(losses):
    """Plot the per-epoch generator and discriminator losses.

    Returns immediately, without plotting, while either history is still empty.
    """
    g_hist, d_hist = losses['G'], losses['D']
    if not (g_hist and d_hist):
        return

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(g_hist, label='Generator', color='blue')
    ax.plot(d_hist, label='Discriminator', color='red')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.set_title('Training Losses')
    ax.grid(True, alpha=0.3)
    plt.show()

In [None]:
def plot_losses_with_balance(losses, g_history, d_history):
    """Plot per-epoch losses (top) and the per-epoch G/D loss ratio (bottom).

    The ratio is ``G_loss / max(D_loss, 0.001)``. A large ratio means the
    generator's loss is high relative to the discriminator's, i.e. D is
    winning. A small ratio means G is winning. The 0.5 and 2.0 reference lines
    are labelled accordingly.

    Args:
        losses: ``{'G': [...], 'D': [...]}`` per-epoch average losses.
        g_history, d_history: Per-batch loss histories. They are accepted for
            compatibility with existing callers but are not plotted.
    """
    fig, (loss_ax, ratio_ax) = plt.subplots(2, 1, figsize=(12, 8))

    # Top: epoch-average losses
    loss_ax.plot(losses['G'], label='Generator', color='blue', linewidth=2)
    loss_ax.plot(losses['D'], label='Discriminator', color='red', linewidth=2)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.set_title('Training Losses (Epoch Average)')
    loss_ax.grid(True, alpha=0.3)

    # Bottom: G/D ratio (D clamped to 0.001 to avoid division by zero)
    if losses['G'] and losses['D']:
        ratios = [g / max(d, 0.001) for g, d in zip(losses['G'], losses['D'])]
        ratio_ax.plot(ratios, color='green', linewidth=2)
        reference_lines = (
            (1.0, 'black', 0.7, 'Perfect Balance'),
            (0.5, 'orange', 0.5, 'G Dominiert'),   # G loss well below D loss
            (2.0, 'purple', 0.5, 'D Dominiert'),   # G loss well above D loss
        )
        for y, color, alpha, label in reference_lines:
            ratio_ax.axhline(y=y, color=color, linestyle='--', alpha=alpha, label=label)
        ratio_ax.set_xlabel('Epoch')
        ratio_ax.set_ylabel('G_Loss / D_Loss Ratio')
        ratio_ax.set_title('Loss Balance Ratio (ideal: ~1.0)')
        ratio_ax.legend()
        ratio_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()



In [None]:

def monitor_training_health(G, D, dataloader, device, n_batches=5):
    """Estimate discriminator accuracy on a few batches and classify the G/D balance.

    Up to ``n_batches`` real batches are scored by D. For each real batch the
    same number of fakes ``G(noise)`` is scored as well, all without
    gradients. A real image counts as correct when ``sigmoid(D(x)) > 0.5`` and
    a fake one when it is ``< 0.5``.

    Both models are put into eval mode for the check. Their previous
    train/eval modes are restored afterwards, even if an error occurs.

    Args:
        G: Generator taking ``(batch, 512)`` noise.
        D: Discriminator returning one logit per sample.
        dataloader: Iterable of real image batches.
        device: Device for the batches and the noise.
        n_batches: Maximum number of batches to evaluate.

    Returns:
        ``"d_strong"`` if the overall accuracy is > 0.8, ``"d_weak"`` if it is
        < 0.4, otherwise ``"balanced"``.

    Raises:
        ValueError: If ``dataloader`` yields no samples (accuracy undefined).
    """
    g_was_training, d_was_training = G.training, D.training
    G.eval()
    D.eval()

    total_samples = 0
    correct_real = 0
    correct_fake = 0
    try:
        with torch.no_grad():
            for i, real_batch in enumerate(dataloader):
                if i >= n_batches:
                    break

                real_batch = real_batch.to(device)
                batch_size = real_batch.size(0)

                # Accuracy on real images
                real_pred = torch.sigmoid(D(real_batch))
                correct_real += (real_pred > 0.5).sum().item()

                # Accuracy on generated images
                noise = torch.randn(batch_size, 512, device=device)
                fake_pred = torch.sigmoid(D(G(noise)))
                correct_fake += (fake_pred < 0.5).sum().item()

                total_samples += batch_size
    finally:
        # Restore the caller's modes so training does not continue in eval mode
        G.train(g_was_training)
        D.train(d_was_training)

    if total_samples == 0:
        raise ValueError("Dataloader lieferte keine Bilder - Health Check nicht möglich.")

    real_acc = correct_real / total_samples
    fake_acc = correct_fake / total_samples
    overall_acc = (correct_real + correct_fake) / (total_samples * 2)

    print("\nTraining Health Check:")
    print(f"   Discriminator Accuracy (Real): {real_acc:.3f}")
    print(f"   Discriminator Accuracy (Fake): {fake_acc:.3f}")
    print(f"   Overall Accuracy: {overall_acc:.3f}")

    if overall_acc > 0.8:
        print(" Discriminator zu stark!")
        return "d_strong"
    if overall_acc < 0.4:
        print(" Discriminator zu schwach!")
        return "d_weak"
    print(" Discriminator Balance OK")
    return "balanced"

In [None]:
from torch.amp import autocast, GradScaler

def _choose_d_updates(g_loss):
    """Return how many discriminator steps follow a generator step with loss ``g_loss``.

    The result is 2 if ``g_loss > 2.0``, 1 if ``g_loss > 0.8``, and otherwise
    0 (D is paused). A NaN loss compares False everywhere and therefore pauses D.

    NOTE(review): a high generator loss usually means D is already winning, so
    giving D *more* steps then reinforces the imbalance. Also, with the 0.95
    generator target used in ``train``, the loss at D(G(z)) = 0.5 is ln 2
    (about 0.69), which is below the 0.8 pause threshold. The thresholds are
    kept unchanged; please confirm they are intended.
    """
    if g_loss > 2.0:
        return 2
    if g_loss > 0.8:
        return 1
    return 0


def _save_fixed_samples(G, fixed_noise, epoch_number, balance_ratio):
    """Save and display a 4-column grid of ``G(fixed_noise)`` for ``epoch_number``.

    The PNG goes to ``training/StyleGANv2.3/bilder/epoch_<NNNN>.png``. G is put
    into eval mode for sampling and then restored to its previous mode.
    """
    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            samples = G(fixed_noise)
        grid = make_grid(samples, nrow=4, normalize=True, value_range=(-1, 1))
        plt.figure(figsize=(10, 10))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.title(f'Epoch {epoch_number} - Balance: {balance_ratio:.2f}')
        plt.savefig(f'training/StyleGANv2.3/bilder/epoch_{epoch_number:04d}.png',
                    bbox_inches='tight', dpi=150)
        plt.show()
        plt.close()  # release the figure if the backend keeps it alive after show()
    finally:
        G.train(was_training)


def _export_losses_csv(losses, csv_path='training/StyleGANv2.3/losses.csv'):
    """Write the per-epoch G/D losses and their ratio to ``csv_path``."""
    import pandas as pd  # only needed once, at the very end of training

    df_losses = pd.DataFrame({
        'Epoch': list(range(1, len(losses['G']) + 1)),
        'Generator_Loss': losses['G'],
        'Discriminator_Loss': losses['D'],
        'Balance_Ratio': [g / max(d, 0.001) for g, d in zip(losses['G'], losses['D'])],
    })
    df_losses.to_csv(csv_path, index=False)
    print(f" Loss-Tabelle als CSV gespeichert: {csv_path}")


def train(resume_from=None):
    """Train the GAN from the (possibly resumed) start epoch up to ``EPOCHS``.

    Each batch gets one generator step, followed by 0-2 discriminator steps
    chosen by ``_choose_d_updates`` from the generator loss. Every
    ``SAVE_EVERY`` epochs a fixed-noise sample grid and the loss plots are
    produced. A checkpoint is written every 5 epochs, every ``SAVE_EVERY``
    epochs (at most once per epoch) and after the last epoch. Finally, the
    per-epoch losses are exported to ``training/StyleGANv2.3/losses.csv``.

    Args:
        resume_from: Optional checkpoint path, passed to ``setup_training``.

    Returns:
        ``(G, D, losses)``: the trained models and the per-epoch loss history
        ``{'G': [...], 'D': [...]}``.
    """
    print("Training starten...")

    # setup_training creates the optimizers and, when resuming, loads their Adam
    # state. They must not be recreated here, or that state would be lost.
    G, D, opt_G, opt_D, criterion, dataloader, start_epoch, losses = setup_training(resume_from)

    # Mixed precision; GradScaler expects the device type as a string
    scaler_G = GradScaler(device=device.type)
    scaler_D = GradScaler(device=device.type)

    # Fixed noise so the sample grids of this run are comparable
    fixed_noise = torch.randn(16, 512, device=device)

    # Smoothed targets: D is trained towards 0.95 / 0.05 instead of 1 / 0
    real_label = 0.95
    fake_label = 0.05

    # Per-batch histories; d_loss_history only records batches in which D was updated
    g_loss_history = []
    d_loss_history = []
    last_d_loss = losses['D'][-1] if losses['D'] else 0.0  # shown while D is paused
    last_saved_epoch = start_epoch

    print(f"Starte Training von Epoch {start_epoch + 1} bis {EPOCHS}")

    for epoch in range(start_epoch, EPOCHS):
        epoch_number = epoch + 1
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        d_batches = 0        # batches with at least one D step
        d_update_count = 0   # total D optimizer steps in this epoch

        pbar = tqdm(dataloader, desc=f"Epoche {epoch+1}/{EPOCHS}", leave=False)

        for batch_idx, real_batch in enumerate(pbar):
            real_batch = real_batch.to(device)
            batch_size = real_batch.size(0)

            # Noisy smoothed labels for D, clamped so the BCE targets stay in [0, 1]
            real_labels = (real_label + 0.02 * torch.randn(batch_size, device=device)).clamp_(0.0, 1.0)
            fake_labels = (fake_label + 0.02 * torch.randn(batch_size, device=device)).clamp_(0.0, 1.0)
            g_targets = torch.full((batch_size,), real_label, dtype=torch.float, device=device)

            # ===== Generator update (every batch) =====
            G.zero_grad()
            with autocast(device_type=device.type):
                noise = torch.randn(batch_size, 512, device=device)
                fake = G(noise)
                g_loss = criterion(D(fake), g_targets)

            scaler_G.scale(g_loss).backward()
            scaler_G.unscale_(opt_G)  # unscale first so clipping sees the true gradient norm
            torch.nn.utils.clip_grad_norm_(G.parameters(), 1.0)
            scaler_G.step(opt_G)
            scaler_G.update()

            current_g_loss = g_loss.item()
            epoch_g_loss += current_g_loss
            g_loss_history.append(current_g_loss)

            # ===== Adaptive discriminator updates =====
            d_updates = _choose_d_updates(current_g_loss)
            fake_detached = fake.detach()  # D steps must not backpropagate into G
            total_d_loss = 0.0
            for _ in range(d_updates):
                D.zero_grad()
                with autocast(device_type=device.type):
                    d_loss_real = criterion(D(real_batch), real_labels)
                    d_loss_fake = criterion(D(fake_detached), fake_labels)
                    d_loss = d_loss_real + d_loss_fake

                scaler_D.scale(d_loss).backward()
                scaler_D.unscale_(opt_D)
                torch.nn.utils.clip_grad_norm_(D.parameters(), 1.0)
                scaler_D.step(opt_D)
                scaler_D.update()

                total_d_loss += d_loss.item()
                d_update_count += 1

            if d_updates > 0:
                # Average only over real D steps; paused batches must not pull
                # the epoch's D loss towards 0
                last_d_loss = total_d_loss / d_updates
                epoch_d_loss += last_d_loss
                d_batches += 1
                d_loss_history.append(last_d_loss)

            pbar.set_postfix({
                "G_Loss": f"{current_g_loss:.4f}",
                "D_Loss": f"{last_d_loss:.4f}",
                "D_Updates": d_updates,
            })

            # Every 100 batches: balance over the last 10 recorded losses
            if batch_idx % 100 == 0:
                recent_g = g_loss_history[-10:]
                recent_d = d_loss_history[-10:]
                mean_g = sum(recent_g) / max(len(recent_g), 1)
                mean_d = sum(recent_d) / max(len(recent_d), 1)
                batch_balance = mean_g / max(mean_d, 0.001)
                # tqdm.write prints without corrupting the progress bar
                tqdm.write(f"Batch {batch_idx}: G={current_g_loss:.3f}, D={last_d_loss:.3f}, "
                           f"Balance={batch_balance:.2f}, D_Updates={d_updates}")

        # ===== Epoch statistics =====
        n_batches = max(len(dataloader), 1)
        avg_g_loss = epoch_g_loss / n_batches
        # If D was paused for the whole epoch, report its most recent loss
        avg_d_loss = epoch_d_loss / d_batches if d_batches else last_d_loss
        avg_d_updates = d_update_count / n_batches

        losses['D'].append(avg_d_loss)
        losses['G'].append(avg_g_loss)

        # Ratio > 1: G's loss above D's (D winning); ratio < 1: G winning
        balance_ratio = avg_g_loss / max(avg_d_loss, 0.001)

        print(f"Epoch {epoch+1}: G_Loss={avg_g_loss:.4f}, D_Loss={avg_d_loss:.4f}, "
              f"Balance={balance_ratio:.2f}, Avg_D_Updates={avg_d_updates:.1f}")

        if balance_ratio > 5.0:
            print(" Discriminator dominiert stark! (G_Loss >> D_Loss)")
        elif balance_ratio < 0.2:
            print(" Generator dominiert stark! (D_Loss >> G_Loss)")
        elif 0.8 <= balance_ratio <= 1.5:
            print("Gute Balance erreicht!")

        # Sample grid and loss plots every SAVE_EVERY epochs
        if epoch_number % SAVE_EVERY == 0:
            _save_fixed_samples(G, fixed_noise, epoch_number, balance_ratio)
            plot_losses_with_balance(losses, g_loss_history, d_loss_history)

        # Checkpoint every 5 epochs (for resuming) and on every SAVE_EVERY epoch;
        # a single save covers both triggers
        if epoch_number % 5 == 0 or epoch_number % SAVE_EVERY == 0:
            save_checkpoint(G, D, opt_G, opt_D, epoch_number, losses)
            last_saved_epoch = epoch_number

    # Persist the final state even when EPOCHS is not a multiple of 5
    if last_saved_epoch < EPOCHS:
        save_checkpoint(G, D, opt_G, opt_D, EPOCHS, losses)

    print(" Training abgeschlossen!")
    _export_losses_csv(losses)

    return G, D, losses


### Bilder generieren


In [None]:
def generate_images(G=None, n=16):
    """Sample ``n`` images from a generator and display them as a 4-column grid.

    If ``G`` is omitted, the module-level variable ``G`` is used (for example,
    the generator returned by training in this notebook). If no such variable
    exists, a hint is printed and the function returns without plotting. The
    generator is left in eval mode.
    """
    if G is None:
        namespace = globals()
        if 'G' not in namespace:
            print("Kein Generator verfügbar! Führe zuerst das Training aus oder lade ein Checkpoint.")
            return
        G = namespace['G']

    G.eval()
    with torch.no_grad():
        latent = torch.randn(n, 512, device=device)
        images = G(latent)
        grid = make_grid(images, nrow=4, normalize=True, value_range=(-1, 1))

        plt.figure(figsize=(12, 12))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.title('Generierte Bilder')
        plt.show()

In [None]:
def load_checkpoint(path):
    """Load only the generator from a checkpoint written by ``save_checkpoint``.

    Args:
        path: Checkpoint file path.

    Returns:
        A ``Generator`` on ``device`` in eval mode, or ``None`` (after printing
        a message) if ``path`` does not exist.
    """
    if not os.path.exists(path):
        print(f"Checkpoint {path} nicht gefunden!")
        return None

    # weights_only=True limits unpickling to tensors and plain containers,
    # so a tampered .pth file cannot execute arbitrary code when loaded.
    checkpoint = torch.load(path, map_location=device, weights_only=True)

    G = Generator().to(device)
    G.load_state_dict(checkpoint['generator'])
    G.eval()

    print(f"Model geladen von Epoch {checkpoint['epoch']}")
    return G

In [None]:
def resume_latest():
    """Continue training from the newest checkpoint, or start fresh if there is none.

    Returns:
        Whatever ``train`` returns.
    """
    latest = find_latest_checkpoint()
    if not latest:
        print("Kein Checkpoint gefunden. Starte neues Training.")
        return train()
    return train(resume_from=latest)

def resume_from_epoch(epoch, checkpoint_dir='training/StyleGANv2.3/epochen'):
    """Continue training from the checkpoint saved for ``epoch``.

    Args:
        epoch: Epoch number as used in the file name ``checkpoint_<epoch:04d>.pth``.
        checkpoint_dir: Folder containing the checkpoints. The default is the
            folder the rest of this notebook writes checkpoints to.

    Returns:
        Whatever ``train`` returns, or ``None`` (after printing a message) if
        the checkpoint does not exist.
    """
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{epoch:04d}.pth')
    if os.path.exists(checkpoint_path):
        return train(resume_from=checkpoint_path)
    print(f"Checkpoint für Epoch {epoch} nicht gefunden!")
    return None

In [None]:
if __name__ == "__main__":
    # Enable exactly one of the options below; keep the others commented out.

    # Option 1: start a fresh training run
    # G, D, losses = train()
    
    # Option 2: resume from the newest checkpoint (starts fresh if none is found)
    G, D, losses = resume_latest()
    
    # Option 3: resume from a specific epoch
    # G, D, losses = resume_from_epoch(50)
    
    # Option 4: resume from a specific checkpoint file
    # (checkpoints are saved under training/StyleGANv2.3/epochen/)
    # G, D, losses = train(resume_from='training/StyleGANv2.3/epochen/checkpoint_0050.pth')