In [None]:
DATA_DIR = "images/256"     # Folder containing the training images
EPOCHS = 5000                         # Number of training epochs
BATCH_SIZE = 128                     # Images per batch (reduce this on out-of-memory problems)
SAVE_EVERY = 10                      # Write samples and a checkpoint every X epochs


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from PIL import Image
import os, glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings

# NOTE(review): this silences every warning category for the whole session,
# including deprecation and numerical warnings from torch/numpy.
warnings.filterwarnings('ignore')


# Create the output folders for sample images and checkpoints
os.makedirs('training/StyleGANv3/bilder', exist_ok=True)
os.makedirs('training/StyleGANv3/epochen', exist_ok=True)

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Verwende: {device}")

In [None]:
class SimpleImageDataset(Dataset):
    """Flat-folder image dataset yielding normalised 3x256x256 tensors.

    All ``jpg``/``jpeg``/``png`` files (lower- or upper-case extension) directly
    inside ``folder`` are collected. Each item is converted to RGB, resized to
    256x256 and normalised with mean 0.5 / std 0.5 per channel.

    Args:
        folder: Directory that contains the image files (not searched recursively).

    Raises:
        ValueError: If ``folder`` contains no matching image file.
    """

    def __init__(self, folder):
        extensions = ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG', '*.JPEG']
        # Collect into a set: on case-insensitive filesystems '*.jpg' and
        # '*.JPG' can match the same file, which would otherwise be listed twice.
        found = set()
        for ext in extensions:
            found.update(glob.glob(os.path.join(folder, ext)))
        # Sorted so that the index -> file mapping is stable between runs.
        self.files = sorted(found)

        if len(self.files) == 0:
            raise ValueError(f"Keine Bilder in {folder} gefunden! Überprüfe den Pfad.")

        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])
        print(f" Gefunden: {len(self.files)} Bilder")

    def __len__(self):
        """Return the number of image files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load and transform the image at ``idx``.

        If a file cannot be loaded, the error is printed and the following
        files are tried in order (wrapping around), each at most once.

        Raises:
            RuntimeError: If none of the files can be loaded.
        """
        total = len(self.files)
        # Bounded iteration instead of recursion: a folder where every file is
        # broken must end in a clear error, not in unbounded recursion.
        for offset in range(total):
            path = self.files[(idx + offset) % total]
            try:
                # The context manager closes the underlying file handle.
                with Image.open(path) as img:
                    return self.transform(img.convert('RGB'))
            except Exception as e:
                print(f"Fehler beim Laden von {path}: {e}")
        raise RuntimeError(f"Keines der {total} Bilder konnte geladen werden.")

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator: 512-dim latent vector -> 3x256x256 image.

    The latent vector is viewed as a 512x1x1 map, expanded to 4x4 by ``start``,
    doubled in resolution six times by ``ups`` and mapped to three channels in
    [-1, 1] (tanh) by ``to_rgb``.
    """

    def __init__(self):
        super().__init__()

        # Stem: 512x1x1 -> 512x4x4
        self.start = nn.Sequential(
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        )

        # Six stride-2 stages, 4x4 -> 256x256, halving the channel count each time.
        widths = [512, 256, 128, 64, 32, 16, 8]
        self.ups = nn.ModuleList(
            [self._make_layer(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        )

        # Projection to RGB
        self.to_rgb = nn.Sequential(
            nn.Conv2d(8, 3, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Tanh(),
        )

        # Re-initialise all weights
        self.apply(self._init_weights)

    def _make_layer(self, in_ch, out_ch):
        """Build one upsampling stage (transposed conv, batch norm, ReLU)."""
        upsample = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)
        return nn.Sequential(upsample, nn.BatchNorm2d(out_ch), nn.ReLU(True))

    def _init_weights(self, m):
        """Initialise conv weights with N(0, 0.02) and batch norm with N(1, 0.02) / zero bias."""
        if isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
            return
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

    def forward(self, z):
        """Map latent vectors ``z`` (viewed as batch x 512 x 1 x 1) to images."""
        feat = self.start(z.view(z.size(0), 512, 1, 1))
        for stage in self.ups:
            feat = stage(feat)
        return self.to_rgb(feat)

In [None]:
class Discriminator(nn.Module):
    """Strided-convolution discriminator: 3x256x256 image -> one raw logit.

    Six stride-2 convolutions reduce 256x256 to 4x4; a final 4x4 convolution
    produces a single value per image. No sigmoid is applied here.
    """

    def __init__(self):
        super().__init__()

        # First stage has no batch norm: 256 -> 128
        stack = [
            nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # 128 -> 64 -> 32 -> 16 -> 8 -> 4, doubling the channels each time
        for c_in, c_out in [(16, 32), (32, 64), (64, 128), (128, 256), (256, 512)]:
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False))
            stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # 4 -> 1
        stack.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False))

        self.layers = nn.Sequential(*stack)

        # Re-initialise all weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise conv weights with N(0, 0.02) and batch norm with N(1, 0.02) / zero bias."""
        if isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
            return
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

    def forward(self, x):
        """Return the flattened logits for the image batch ``x``."""
        logits = self.layers(x)
        return logits.view(-1)

In [None]:
def find_latest_checkpoint(checkpoint_dir='training/StyleGANv3/epochen'):
    """Return the path of the checkpoint with the highest epoch number.

    Only files named ``checkpoint_<number>.pth`` are considered. Files that
    match the glob but carry a non-numeric suffix (e.g. ``checkpoint_best.pth``)
    are skipped instead of raising ``ValueError``.

    Args:
        checkpoint_dir: Directory to search.

    Returns:
        The path of the newest numbered checkpoint, or ``None`` if the
        directory does not exist or holds no numbered checkpoint.
    """
    if not os.path.isdir(checkpoint_dir):
        return None

    numbered = []
    for path in glob.glob(os.path.join(checkpoint_dir, 'checkpoint_*.pth')):
        # Parse the basename only, so underscores or dots in directory names
        # cannot interfere with extracting the epoch number.
        token = os.path.basename(path).rpartition('_')[2].split('.')[0]
        if token.isdecimal():
            numbered.append((int(token), path))

    if not numbered:
        return None

    latest = max(numbered)[1]
    print(f" Neuestes Checkpoint gefunden: {latest}")
    return latest

In [None]:
def setup_training(resume_from=None):
    """Create models, optimizers, loss and data loader, optionally from a checkpoint.

    Args:
        resume_from: Optional path of a checkpoint holding the keys
            'generator', 'discriminator', 'opt_G', 'opt_D', 'epoch' and
            optionally 'losses'. If the path does not exist, a message is
            printed and a fresh training state is returned.

    Returns:
        Tuple ``(G, D, opt_G, opt_D, criterion, dataloader, start_epoch, losses)``.
        ``start_epoch`` is 0 for a fresh run, otherwise the checkpoint's
        'epoch' value; ``losses`` is a dict with the lists 'G' and 'D'.
    """
    # Models
    G = Generator().to(device)
    D = Discriminator().to(device)

    # Optimizers (the discriminator uses a 10x smaller learning rate)
    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.00002, betas=(0.5, 0.999))

    # Loss on raw logits
    criterion = nn.BCEWithLogitsLoss()

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to a single worker in that case. Always at least 1.
    workers = max(1, (os.cpu_count() or 1) - 1)

    # Dataset
    dataset = SimpleImageDataset(DATA_DIR)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=workers,
        pin_memory=device.type == 'cuda',
    )

    # Defaults for a fresh run
    start_epoch = 0
    losses = {'G': [], 'D': []}

    if resume_from:
        print(f"Lade Checkpoint: {resume_from}")
        if os.path.exists(resume_from):
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(resume_from, map_location=device)

            # Model weights
            G.load_state_dict(checkpoint['generator'])
            D.load_state_dict(checkpoint['discriminator'])

            # Optimizer states (also restores the stored learning rates)
            opt_G.load_state_dict(checkpoint['opt_G'])
            opt_D.load_state_dict(checkpoint['opt_D'])

            # Training progress
            start_epoch = checkpoint['epoch']
            losses = checkpoint.get('losses', {'G': [], 'D': []})

            print(f"Resume von Epoch {start_epoch}")
        else:
            print(f"Checkpoint {resume_from} nicht gefunden! Starte neues Training.")

    return G, D, opt_G, opt_D, criterion, dataloader, start_epoch, losses


In [None]:
def save_samples(G, epoch, n_samples=16):
    """Generate sample images and save them as one grid PNG.

    The grid is written to ``training/StyleGANv3/bilder/epoch_<epoch>.png``.
    The generator runs in eval mode and is put back into whatever mode it was
    in before the call, even if generating or saving fails.

    Args:
        G: Generator taking a ``(n_samples, 512)`` noise tensor.
        epoch: Epoch number used in the title and the file name.
        n_samples: Number of images to generate (laid out 4 per row).
    """
    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n_samples, 512, device=device)
            fake = G(z)

        # Build the image grid, rescaling from [-1, 1]
        grid = make_grid(fake, nrow=4, normalize=True, value_range=(-1, 1))

        # Save as an image; always close the figure so none accumulate
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
            plt.axis('off')
            plt.title(f'Epoch {epoch}')
            plt.savefig(f'training/StyleGANv3/bilder/epoch_{epoch:04d}.png', bbox_inches='tight', dpi=150)
        finally:
            plt.close(fig)
    finally:
        # Restore the previous mode instead of forcing train mode.
        G.train(was_training)

def save_checkpoint(G, D, opt_G, opt_D, epoch, losses):
    """Save the full training state for ``epoch``.

    The same payload is written to ``checkpoint_<epoch>.pth`` and to
    ``latest_checkpoint.pth`` inside ``training/StyleGANv3/epochen``. Each file
    is first written under a temporary name and then moved into place, so an
    interrupted save cannot leave a truncated checkpoint behind.

    Args:
        G, D: Generator and discriminator modules.
        opt_G, opt_D: Their optimizers.
        epoch: Number of completed epochs, stored under 'epoch'.
        losses: Loss history dict, stored under 'losses'.
    """
    checkpoint_dir = 'training/StyleGANv3/epochen'
    os.makedirs(checkpoint_dir, exist_ok=True)

    state = {
        'epoch': epoch,
        'generator': G.state_dict(),
        'discriminator': D.state_dict(),
        'opt_G': opt_G.state_dict(),
        'opt_D': opt_D.state_dict(),
        'losses': losses
    }

    checkpoint_path = f'{checkpoint_dir}/checkpoint_{epoch:04d}.pth'
    # Also kept as "latest" for easy resuming
    latest_path = f'{checkpoint_dir}/latest_checkpoint.pth'

    for target in (checkpoint_path, latest_path):
        # The '.tmp' suffix keeps partial files out of the 'checkpoint_*.pth'
        # glob; os.replace then swaps the finished file in.
        tmp_path = target + '.tmp'
        torch.save(state, tmp_path)
        os.replace(tmp_path, target)

    print(f"Checkpoint gespeichert: {checkpoint_path}")
    
def plot_losses(losses):
    """Show the generator and discriminator loss curves in one figure.

    Does nothing when either loss list is empty.
    """
    if not (losses['G'] and losses['D']):
        return

    plt.figure(figsize=(10, 4))
    curves = (
        ('G', 'Generator', 'blue'),
        ('D', 'Discriminator', 'red'),
    )
    for key, label, color in curves:
        plt.plot(losses[key], label=label, color=color)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Losses')
    plt.grid(True, alpha=0.3)
    plt.show()

In [None]:
def analyze_training_health(losses):
    """Print a short health report based on the most recent epoch losses.

    Averages the last ``window`` entries of both loss histories (``window`` is
    at most 20 and at most half the history length) and prints warnings for a
    too-strong discriminator, a struggling generator, an unbalanced G/D ratio
    and a high variance of the generator loss.

    Args:
        losses: Dict with the per-epoch loss lists 'G' and 'D'.

    Returns:
        None. Prints a notice and returns early if either history has fewer
        than 20 entries.
    """
    g_hist, d_hist = losses['G'], losses['D']
    # Both histories are averaged below, so both must be long enough.
    n_epochs = min(len(g_hist), len(d_hist))
    if n_epochs < 20:
        print("Nicht genug Daten für Analyse")
        return

    window = min(20, n_epochs // 2)

    # Recent averages
    recent_g = np.mean(g_hist[-window:])
    recent_d = np.mean(d_hist[-window:])

    # A discriminator loss of exactly 0 yields inf/nan here; suppress the numpy
    # warning locally rather than relying on a global warnings filter.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = recent_g / recent_d

    print("=== TRAINING GESUNDHEITS-ANALYSE ===")
    print(f"Aktuelle Durchschnitte (letzte {window} Epochen):")
    print(f"  Generator Loss: {recent_g:.4f}")
    print(f"  Discriminator Loss: {recent_d:.4f}")
    print(f"  Loss Ratio (G/D): {ratio:.2f}")

    if recent_d < 0.5:
        print(" PROBLEM: Discriminator zu stark (Loss < 0.5)")
        print("   -> Empfehlung: Learning Rate für D reduzieren oder mehr G-Updates")

    if recent_g > 3.0:
        print(" PROBLEM: Generator kämpft (Loss > 3.0)")
        print("   -> Empfehlung: Learning Rate für G reduzieren oder weniger D-Updates")

    if ratio > 4.0:
        print(" PROBLEM: Unbalanciertes Training (G/D Ratio > 4)")
        print("   -> Empfehlung: Training-Balance anpassen")

    # Stability check
    g_variance = np.var(g_hist[-window:])
    if g_variance > 1.0:
        print(" PROBLEM: Hohe Varianz in Generator Loss")
        print("   -> Empfehlung: Learning Rate reduzieren, Gradient Clipping")

    print("=" * 40)


In [None]:
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
import numpy as np

def improved_train(resume_from=None):
    """Run the GAN training loop with LR scheduling and adaptive update ratios.

    Per epoch: chooses how many discriminator/generator steps to run per batch
    from the recent loss averages, trains with autocast and gradient clipping,
    steps an exponential and a plateau LR scheduler per optimizer, and adjusts
    the learning rates when the loss trends suggest instability. Every
    ``SAVE_EVERY`` epochs a sample grid is saved, a checkpoint is written and
    the loss plot plus health report are shown; on other multiples of 5 epochs
    a backup checkpoint is written.

    Args:
        resume_from: Optional checkpoint path passed on to ``setup_training``.

    Returns:
        Tuple ``(G, D, losses)``.
    """
    print(" Verbessertes Training mit adaptiver Lernrate starten...")

    # Setup, optionally restoring state from a checkpoint
    G, D, opt_G, opt_D, criterion, dataloader, start_epoch, losses = setup_training(resume_from)

    # Exponential decay: the LR shrinks by 0.5% every epoch
    scheduler_G = ExponentialLR(opt_G, gamma=0.995)
    scheduler_D = ExponentialLR(opt_D, gamma=0.995)

    # Additionally halve the LR when the epoch loss stops improving
    plateau_scheduler_G = ReduceLROnPlateau(opt_G, mode='min', factor=0.5, patience=20)
    plateau_scheduler_D = ReduceLROnPlateau(opt_D, mode='min', factor=0.5, patience=15)

    # Mixed precision: GradScaler expects the device type as a string. Loss
    # scaling is only needed for float16 on CUDA, so it is disabled elsewhere
    # (a disabled scaler passes losses and optimizer steps straight through).
    use_scaler = device.type == 'cuda'
    scaler_G = GradScaler(device=device.type, enabled=use_scaler)
    scaler_D = GradScaler(device=device.type, enabled=use_scaler)

    # Fixed noise so that the sample grids are comparable across epochs
    fixed_noise = torch.randn(16, 512, device=device)

    def get_labels(batch_size, label_type, epoch):
        """Return smoothed, slightly noisy target labels for 'real' or fake."""
        # Smoothing shrinks from 0.1 towards 0.05 as the epochs progress.
        smoothing = max(0.05, 0.1 - epoch * 0.0001)
        base = 0.9 - smoothing if label_type == 'real' else 0.1 + smoothing
        labels = torch.full((batch_size,), base, dtype=torch.float, device=device)
        labels += smoothing * torch.randn_like(labels) * 0.1
        return labels

    print(f"Starte Training von Epoch {start_epoch + 1} bis {EPOCHS}")

    for epoch in range(start_epoch, EPOCHS):
        epoch_d_loss = 0
        epoch_g_loss = 0

        # Adaptive update ratio based on the last 10 epoch averages
        d_loss_avg = np.mean(losses['D'][-10:]) if len(losses['D']) > 10 else 1.0
        g_loss_avg = np.mean(losses['G'][-10:]) if len(losses['G']) > 10 else 1.0

        d_steps = 1
        g_steps = 1

        if d_loss_avg < 0.5 and g_loss_avg > 2.0:  # D too strong
            g_steps = 2
            print(f" Epoch {epoch+1}: D zu stark - mehr G Updates")
        elif d_loss_avg > 1.2 and g_loss_avg < 1.5:  # G too strong
            d_steps = 2
            print(f" Epoch {epoch+1}: G zu stark - mehr D Updates")

        pbar = tqdm(dataloader, desc=f"Epoche {epoch+1}/{EPOCHS}", leave=False)

        for real_batch in pbar:
            real_batch = real_batch.to(device)
            batch_size = real_batch.size(0)

            # ----- Discriminator updates -----
            for _ in range(d_steps):
                D.zero_grad()

                with autocast(device_type=device.type):
                    # Real images
                    real_labels = get_labels(batch_size, 'real', epoch)
                    output_real = D(real_batch)
                    d_loss_real = criterion(output_real, real_labels)

                    # Fake images: D's update never backpropagates into G, so
                    # no autograd graph is needed for the generator pass. The
                    # generator step below opens its own autocast context, so
                    # weight casts cached here are not reused there.
                    noise = torch.randn(batch_size, 512, device=device)
                    with torch.no_grad():
                        fake = G(noise)
                    fake_labels = get_labels(batch_size, 'fake', epoch)
                    output_fake = D(fake)
                    d_loss_fake = criterion(output_fake, fake_labels)

                    d_loss = (d_loss_real + d_loss_fake) / 2

                scaler_D.scale(d_loss).backward()

                # Gradient clipping on the unscaled gradients
                scaler_D.unscale_(opt_D)
                torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=1.0)

                scaler_D.step(opt_D)
                scaler_D.update()

            # ----- Generator updates -----
            for _ in range(g_steps):
                G.zero_grad()

                with autocast(device_type=device.type):
                    noise = torch.randn(batch_size, 512, device=device)
                    fake = G(noise)
                    output = D(fake)

                    # Standard adversarial loss against the "real" targets
                    real_labels = get_labels(batch_size, 'real', epoch)
                    g_loss = criterion(output, real_labels)

                scaler_G.scale(g_loss).backward()

                # Gradient clipping for the generator
                scaler_G.unscale_(opt_G)
                torch.nn.utils.clip_grad_norm_(G.parameters(), max_norm=1.0)

                scaler_G.step(opt_G)
                scaler_G.update()

            # Only the last D/G step of each batch enters the statistics
            batch_d_loss = d_loss.item()
            batch_g_loss = g_loss.item()
            epoch_d_loss += batch_d_loss
            epoch_g_loss += batch_g_loss

            pbar.set_postfix({
                "D_Loss": f"{batch_d_loss:.4f}",
                "G_Loss": f"{batch_g_loss:.4f}",
                "LR_G": f"{opt_G.param_groups[0]['lr']:.6f}",
                "LR_D": f"{opt_D.param_groups[0]['lr']:.6f}"
            })

        # Epoch statistics
        avg_d_loss = epoch_d_loss / len(dataloader)
        avg_g_loss = epoch_g_loss / len(dataloader)
        losses['D'].append(avg_d_loss)
        losses['G'].append(avg_g_loss)

        # Learning rate updates
        scheduler_G.step()
        scheduler_D.step()
        plateau_scheduler_G.step(avg_g_loss)
        plateau_scheduler_D.step(avg_d_loss)

        print(f" Epoch {epoch+1}: G_Loss={avg_g_loss:.4f}, D_Loss={avg_d_loss:.4f}")
        print(f"    LR_G={opt_G.param_groups[0]['lr']:.6f}, LR_D={opt_D.param_groups[0]['lr']:.6f}")

        # Early warning: compare the last 10 epochs with those 40-50 epochs back
        if len(losses['D']) > 50:
            recent_d_trend = np.mean(losses['D'][-10:]) - np.mean(losses['D'][-50:-40])
            recent_g_trend = np.mean(losses['G'][-10:]) - np.mean(losses['G'][-50:-40])

            if recent_d_trend < -0.2 and recent_g_trend > 0.5:
                print("  WARNING: Potential training instability detected!")
                print("   Discriminator getting too strong, Generator struggling")

                # Automatic adjustment: slow D down, speed G up slightly
                for param_group in opt_D.param_groups:
                    param_group['lr'] *= 0.8
                for param_group in opt_G.param_groups:
                    param_group['lr'] *= 1.1

                print(f"   Automatische LR Anpassung: G={opt_G.param_groups[0]['lr']:.6f}, D={opt_D.param_groups[0]['lr']:.6f}")

        # Samples, checkpoint and diagnostics every SAVE_EVERY epochs
        if (epoch + 1) % SAVE_EVERY == 0:
            G.eval()
            with torch.no_grad():
                test_fake = G(fixed_noise)
            grid = make_grid(test_fake, nrow=4, normalize=True, value_range=(-1, 1))
            fig = plt.figure(figsize=(10, 10))
            plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
            plt.axis('off')
            plt.title(f'Epoch {epoch + 1} (Improved Training)')
            plt.savefig(f'training/StyleGANv3/bilder/improved_epoch_{epoch + 1:04d}.png',
                        bbox_inches='tight', dpi=150)
            plt.show()
            # Close explicitly so figures do not pile up over a long run
            plt.close(fig)
            G.train()

            save_checkpoint(G, D, opt_G, opt_D, epoch + 1, losses)
            plot_losses(losses)

            # Training health check
            analyze_training_health(losses)
        elif (epoch + 1) % 5 == 0:
            # More frequent backup; skipped when the branch above has already
            # written the checkpoint for this epoch.
            save_checkpoint(G, D, opt_G, opt_D, epoch + 1, losses)

    print("Verbessertes Training abgeschlossen!")
    return G, D, losses

In [None]:
def resume_from_specific_epoch_improved(target_epoch, lr_factor=0.5):
    """Resume training from the checkpoint of ``target_epoch`` with reduced learning rates.

    ``improved_train`` rebuilds models and optimizers from the checkpoint file
    it is given, so changing the learning rate on optimizer objects created
    here would have no effect. Instead, the learning rates stored in the
    checkpoint's optimizer states are scaled and the result is written to a
    separate resume file, which is then handed to ``improved_train``. The
    original checkpoint is left untouched.

    Args:
        target_epoch: Epoch number of the checkpoint to resume from.
        lr_factor: Factor applied to the stored learning rates of both
            optimizers (0.5 halves them).

    Returns:
        The result of ``improved_train`` (``(G, D, losses)``), or ``None`` if
        no checkpoint exists for ``target_epoch``.
    """
    checkpoint_dir = 'training/StyleGANv3/epochen'
    checkpoint_path = f'{checkpoint_dir}/checkpoint_{target_epoch:04d}.pth'

    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint für Epoch {target_epoch} nicht gefunden!")
        return None

    print(f" Resume von Epoch {target_epoch} mit reduzierten Learning Rates")

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Scale the stored learning rates for stability
    new_lrs = {}
    for key in ('opt_G', 'opt_D'):
        for param_group in checkpoint[key]['param_groups']:
            param_group['lr'] *= lr_factor
        new_lrs[key] = checkpoint[key]['param_groups'][0]['lr']

    print(f"Neue Learning Rates: G={new_lrs['opt_G']}, D={new_lrs['opt_D']}")

    # The file name does not match 'checkpoint_*.pth', so it is never mistaken
    # for a regular epoch checkpoint.
    reduced_path = f'{checkpoint_dir}/resume_{target_epoch:04d}_reduced_lr.pth'
    torch.save(checkpoint, reduced_path)

    return improved_train(resume_from=reduced_path)


In [None]:
def resume_latest_improved():
    """Continue the improved training from the newest checkpoint.

    Starts a fresh improved training run when no checkpoint is found.
    Returns whatever ``improved_train`` returns.
    """
    checkpoint_path = find_latest_checkpoint()
    if not checkpoint_path:
        print("Kein Checkpoint gefunden. Starte neues verbessertes Training.")
        return improved_train()
    return improved_train(resume_from=checkpoint_path)


### Bilder Generieren


In [None]:
def generate_images(G=None, n=16):
    """Generate ``n`` new images and show them as a grid.

    Args:
        G: Generator to use. If ``None``, the notebook-global ``G`` is used
            when one is defined.
        n: Number of images to generate (laid out 4 per row).

    Returns:
        None. Prints a hint and returns early if no generator is available.
    """
    if G is None:
        G = globals().get('G')
        if G is None:
            print("Kein Generator verfügbar! Führe zuerst das Training aus oder lade ein Checkpoint.")
            return

    # Remember the current mode so a generator that is still being trained is
    # not left in eval mode after sampling.
    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n, 512, device=device)
            fake = G(z)

        grid = make_grid(fake, nrow=4, normalize=True, value_range=(-1, 1))

        plt.figure(figsize=(12, 12))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.title('Generierte Bilder')
        plt.show()
    finally:
        G.train(was_training)

In [None]:
def load_checkpoint(path):
    """Load the generator stored in the checkpoint at ``path``.

    Returns the generator in eval mode on ``device``, or ``None`` (after
    printing a message) when the file does not exist.
    """
    if os.path.exists(path):
        state = torch.load(path, map_location=device)

        generator = Generator().to(device)
        generator.load_state_dict(state['generator'])
        generator.eval()

        print(f"Model geladen von Epoch {state['epoch']}")
        return generator

    print(f"Checkpoint {path} nicht gefunden!")
    return None

In [None]:
if __name__ == "__main__":
    # Option 1: start a new training run
    # (no `train()` is defined in the visible notebook; `improved_train` is the training entry point)
    # G, D, losses = improved_train()
    
    # Option 2: continue from the newest checkpoint (starts a new run if none is found)
    G, D, losses = resume_latest_improved()

    # Option 3: continue from a specific epoch
    # (`resume_latest_improved` takes no arguments; returns None if that checkpoint is missing)
    # G, D, losses = resume_from_specific_epoch_improved(50)
    
    # Option 4: continue from a specific checkpoint file
    # G, D, losses = improved_train(resume_from='training/StyleGANv3/epochen/checkpoint_0050.pth')