In [None]:
import gc

gc.collect()
import torch

torch.cuda.empty_cache()


In [None]:
import os
import time
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import numpy
print(f"{numpy.__version__}")

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# PyTorch version and the CUDA version it was built against (None on CPU builds).
for version_info in (torch.__version__, torch.version.cuda):
    print(version_info)


In [None]:
class PostcardDataset(Dataset):
    """Postcard images labelled with a combined ``"city | country_id"`` class.

    Reads a CSV with at least the columns ``akon_id``, ``city`` and
    ``country_id``; each row refers to the image ``<image_dir>/<akon_id>.jpg``.
    Rows with a missing id/city/country, or a blank city/country, are dropped.

    Args:
        csv_path: Path to the metadata CSV.
        image_dir: Directory containing the ``<akon_id>.jpg`` files.
        transform: Optional callable applied to each RGB PIL image.

    Attributes:
        df: Filtered metadata (index reset) with ``city``/``country_id``
            normalised to stripped strings and an extra ``combo_label`` column.
        label_to_idx: Maps each ``"city | country_id"`` label to a class index,
            in order of first appearance in the CSV.
        idx_to_label: Inverse mapping of ``label_to_idx``.
    """

    def __init__(self, csv_path, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        raw = pd.read_csv(csv_path)

        # Keep only rows with usable labels. City and country are normalised to
        # stripped strings: read_csv may parse country_id as a number (which
        # would make the string concatenation below raise), and stray
        # whitespace would otherwise split one city into several classes.
        raw = raw.dropna(subset=["akon_id", "city", "country_id"])
        city = raw["city"].astype(str).str.strip()
        country = raw["country_id"].astype(str).str.strip()
        valid = (city != "") & (country != "")
        city, country = city[valid], country[valid]

        # Build a new frame with assign() instead of writing into a filtered
        # slice (avoids pandas' SettingWithCopyWarning).
        self.df = (
            raw[valid]
            .assign(city=city, country_id=country, combo_label=city + " | " + country)
            .reset_index(drop=True)
        )

        self.label_to_idx = {label: idx for idx, label in enumerate(self.df["combo_label"].unique())}
        self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}

    def __len__(self):
        """Number of postcards that survived label filtering."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label_idx)`` for row ``idx``.

        ``image`` is the RGB PIL image, or whatever ``transform`` returns for it.
        """
        row = self.df.iloc[idx]
        # Format via f-string: the id column may have been parsed as a number.
        image_path = os.path.join(self.image_dir, f"{row['akon_id']}.jpg")
        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        label_idx = self.label_to_idx[row["combo_label"]]

        if self.transform:
            image = self.transform(image)

        return image, label_idx


In [None]:
# Preprocessing: resize to the 256x256 resolution the models are built for,
# convert to a tensor and map [0, 1] -> [-1, 1] (the generator ends in Tanh).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

dataset = PostcardDataset("akon_postcards_public_domain.csv", "images/256", transform)

# NOTE: worker processes must be able to unpickle PostcardDataset; on
# spawn-based platforms (Windows, macOS) a class defined inside a notebook may
# fail to load there — set NUM_WORKERS = 0 if the DataLoader errors out.
NUM_WORKERS = 4
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Pinned host memory only speeds up host->GPU copies; skip it without CUDA.
    pin_memory=torch.cuda.is_available(),
    # Keep workers alive between epochs instead of re-spawning them each epoch.
    persistent_workers=NUM_WORKERS > 0,
)

# Show an example image (undo the [-1, 1] normalisation for display and clamp
# away float rounding so imshow does not warn about clipping).
img, label_idx = dataset[0]
plt.imshow((img.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
plt.title(f"Label-Index: {label_idx}")
plt.show()


In [None]:
embedding_dim = 32
z_dim = 100
num_labels = len(dataset.label_to_idx)
# Fixed seed so the (untrained) label embedding is identical after a kernel
# restart, keeping resumed training and inference consistent.
label_embedding_seed = 0

# Label embedding: one vector per "city | country_id" class. It is not passed
# to either optimizer, so it stays at its random initialisation; freeze it
# explicitly so no gradients accumulate in it. fork_rng restores the global
# CPU RNG afterwards, so seeding here does not affect any other randomness.
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(label_embedding_seed)
    label_embedding = nn.Embedding(num_labels, embedding_dim)
label_embedding = label_embedding.to(device).requires_grad_(False)


In [None]:
class Generator(nn.Module):
    """Conditional DCGAN-style generator producing 3x256x256 images in [-1, 1].

    The noise vector and the label embedding are concatenated, projected to a
    1024x4x4 feature map, and upsampled by six stride-2 transposed convolutions
    (4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256). Every upsampling step except the
    last is followed by BatchNorm + ReLU; the last one ends in Tanh.

    Args:
        z_dim: Size of the noise vector ``z``.
        embedding_dim: Size of the label embedding passed to ``forward``.
    """

    # Channel widths along the upsampling stack, from the 4x4 seed map to RGB.
    _CHANNELS = (1024, 512, 256, 128, 64, 32, 3)

    def __init__(self, z_dim, embedding_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(z_dim + embedding_dim, self._CHANNELS[0] * 4 * 4),
            nn.ReLU(True),
        )

        stages = []
        final_stage = len(self._CHANNELS) - 2
        for stage, (c_in, c_out) in enumerate(zip(self._CHANNELS, self._CHANNELS[1:])):
            stages.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))  # doubles H and W
            if stage == final_stage:
                stages.append(nn.Tanh())
            else:
                stages += [nn.BatchNorm2d(c_out), nn.ReLU(True)]
        # Layers are appended in the same order as an explicit listing, so the
        # state_dict keys (conv.0, conv.1, ...) match existing checkpoints.
        self.conv = nn.Sequential(*stages)

    def forward(self, z, label_embed):
        """Map noise ``z`` [B, z_dim] and ``label_embed`` [B, embedding_dim] to images."""
        seed = self.fc(torch.cat((z, label_embed), dim=1))
        return self.conv(seed.reshape(-1, self._CHANNELS[0], 4, 4))


In [None]:
class Discriminator(nn.Module):
    """Conditional critic that assigns one unbounded score per 3x256x256 image.

    The label embedding is projected to a 256x256 map and stacked onto the
    image as a 4th channel. Six stride-2 convolutions reduce 256 -> 4 (all but
    the first followed by BatchNorm, each by LeakyReLU(0.2)), then a linear
    layer produces the score. No sigmoid is applied.

    NOTE(review): the WGAN-GP paper advises against BatchNorm in the critic,
    since the gradient penalty is defined per sample; LayerNorm/InstanceNorm
    would be the usual replacement (but would change the state_dict layout).

    Args:
        embedding_dim: Size of the label embedding passed to ``forward``.
    """

    # Channel widths along the downsampling stack (image + label map first).
    _CHANNELS = (4, 32, 64, 128, 256, 512, 1024)

    def __init__(self, embedding_dim):
        super().__init__()
        self.label_proj = nn.Linear(embedding_dim, 256 * 256)

        stages = []
        for stage, (c_in, c_out) in enumerate(zip(self._CHANNELS, self._CHANNELS[1:])):
            stages.append(nn.Conv2d(c_in, c_out, 4, 2, 1))  # halves H and W
            if stage > 0:  # the first stage is not normalised
                stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.LeakyReLU(0.2))
        stages += [nn.Flatten(), nn.Linear(self._CHANNELS[-1] * 4 * 4, 1)]
        # Same layer order as an explicit listing, so state_dict keys
        # (model.0, model.2, model.3, ...) match existing checkpoints.
        self.model = nn.Sequential(*stages)

    def forward(self, img, label_embed):
        """Score ``img`` [B, 3, 256, 256] under ``label_embed`` [B, embedding_dim]; returns [B]."""
        label_map = self.label_proj(label_embed).view(-1, 1, 256, 256)
        stacked = torch.cat((img, label_map), dim=1)  # [B, 4, 256, 256]
        return self.model(stacked).view(-1)


In [None]:
# Build both networks on the selected device; the generator is created first,
# so parameter initialisation draws from the RNG in that order.
generator = Generator(z_dim, embedding_dim).to(device)
discriminator = Discriminator(embedding_dim).to(device)

# Identical Adam settings for both networks.
g_opt, d_opt = (
    torch.optim.Adam(net.parameters(), lr=5e-5, betas=(0.5, 0.9))
    for net in (generator, discriminator)
)


In [None]:
def d_loss_fn(real_scores, fake_scores, gp, lambda_gp=10):
    """WGAN-GP critic loss: ``mean(fake) - mean(real) + lambda_gp * gp``.

    Minimising it pushes real scores above fake scores, while the weighted
    gradient penalty ``gp`` penalises input-gradient norms away from 1.
    """
    score_gap = fake_scores.mean() - real_scores.mean()
    return score_gap + lambda_gp * gp

def g_loss_fn(fake_scores):
    """WGAN generator loss: the negated mean critic score of generated images."""
    return fake_scores.mean().neg()


In [None]:
def gradient_penalty(discriminator, real_imgs, fake_imgs, label_embed):
    """WGAN-GP gradient penalty.

    Evaluates the critic on random per-sample interpolations between real and
    fake samples and returns the batch mean of ``(||grad_x D(x)||_2 - 1) ** 2``.

    Args:
        discriminator: Critic called as ``discriminator(x, label_embed)`` and
            returning one score per sample.
        real_imgs: Batch of real samples, shape ``[B, ...]``.
        fake_imgs: Batch of generated samples, same shape as ``real_imgs``.
        label_embed: Conditioning passed unchanged to the critic.

    Returns:
        Scalar tensor. The gradients are computed with ``create_graph=True``,
        so the penalty can be back-propagated into the critic's parameters.
    """
    batch_size = real_imgs.size(0)
    # One mixing weight per sample, broadcast over all remaining dims. Device
    # and dtype come from the inputs (not a global) so any placement/rank works.
    epsilon = torch.rand(
        (batch_size,) + (1,) * (real_imgs.dim() - 1),
        device=real_imgs.device,
        dtype=real_imgs.dtype,
    )

    # Detach so the interpolation is a fresh leaf even if the caller passed a
    # fake batch that is still attached to the generator's graph.
    interpolated = (epsilon * real_imgs + (1 - epsilon) * fake_imgs).detach()
    interpolated.requires_grad_(True)

    d_interpolated = discriminator(interpolated, label_embed)

    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape rather than view: the returned gradient need not be contiguous.
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


In [None]:
# Output folders for sample grids and checkpoints.
os.makedirs("training/wasserstein/epoche_bilder", exist_ok=True)
os.makedirs("training/wasserstein/epoche_schritte", exist_ok=True)

# Fixed noise/labels for the periodic sample grids. They are stored in the
# checkpoint and restored on resume, so grids stay comparable across runs.
fixed_noise = torch.randn(16, z_dim).to(device)
fixed_labels = torch.randint(0, num_labels, (16,), device=device)

n_critic = 2  # critic (discriminator) updates per generator update
num_epochs = 10000
save_every = 10  # epochs between sample grids / checkpoints
start_epoch = 0

checkpoint_path = "training/wasserstein/epoche_schritte/checkpoint_epoch_latest.pth"


def _save_checkpoint(data, path):
    """Save `data` to `path` via a temporary file + rename, so an interrupted
    write cannot leave a truncated checkpoint in place of the previous one."""
    tmp_path = path + ".tmp"
    torch.save(data, tmp_path)
    os.replace(tmp_path, path)


if os.path.exists(checkpoint_path):
    print(" Lade Checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint["generator_state_dict"])
    discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
    g_opt.load_state_dict(checkpoint["g_opt_state_dict"])
    d_opt.load_state_dict(checkpoint["d_opt_state_dict"])
    # The models are conditioned on label_embedding, so it must be restored
    # too. Older checkpoints lack these keys; restore them only when present.
    if "label_embedding_state_dict" in checkpoint:
        label_embedding.load_state_dict(checkpoint["label_embedding_state_dict"])
    else:
        print(" Warnung: Checkpoint ohne Label-Embedding – aktuelles Embedding wird verwendet")
    if "fixed_noise" in checkpoint and "fixed_labels" in checkpoint:
        fixed_noise = checkpoint["fixed_noise"].to(device)
        fixed_labels = checkpoint["fixed_labels"].to(device)
    start_epoch = checkpoint["epoch"]
    print(f" Starte ab Epoche {start_epoch}")
else:
    print(" Kein Checkpoint – frischer Start")

# An earlier run of the inference cell may have left the models in eval mode.
generator.train()
discriminator.train()

d_loss = g_loss = None  # stays None only if the dataloader yields no batches
for epoch in range(start_epoch, num_epochs):
    start_time = time.time()

    for imgs, label_idxs in dataloader:
        batch_size = imgs.size(0)
        imgs = imgs.to(device, non_blocking=True)
        label_idxs = label_idxs.to(device, non_blocking=True)

        # label_embedding is in neither optimizer, so it is never updated.
        # Look it up once per batch and detach it, so every critic step and
        # the generator step reuse it without building an unused graph.
        label_embed = label_embedding(label_idxs).detach()

        # --- Critic updates ---
        for _ in range(n_critic):
            z = torch.randn(batch_size, z_dim, device=device)
            # The critic step never back-propagates into the generator.
            with torch.no_grad():
                fake_imgs = generator(z, label_embed)

            real_scores = discriminator(imgs, label_embed)
            fake_scores = discriminator(fake_imgs, label_embed)
            gp = gradient_penalty(discriminator, imgs, fake_imgs, label_embed)
            d_loss = d_loss_fn(real_scores, fake_scores, gp)

            d_opt.zero_grad()
            d_loss.backward()
            d_opt.step()

        # --- Generator update ---
        z = torch.randn(batch_size, z_dim, device=device)
        fake_imgs = generator(z, label_embed)
        fake_scores = discriminator(fake_imgs, label_embed)
        g_loss = g_loss_fn(fake_scores)

        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

    duration = time.time() - start_time
    if d_loss is None:
        raise RuntimeError("Dataloader lieferte keine Batches – ist das Dataset leer?")
    print(f" Epoche {epoch+1} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f} | ⏱️ {duration:.2f}s")

    if (epoch + 1) % save_every == 0:
        # Sample grid from the fixed noise/labels.
        with torch.no_grad():
            embed = label_embedding(fixed_labels)
            samples = generator(fixed_noise, embed)
            save_image(samples * 0.5 + 0.5,
                       f"training/wasserstein/epoche_bilder/gen_samples_epoch_{epoch+1}.png",
                       nrow=4)

        # Checkpoint, including everything needed to resume consistently.
        checkpoint_data = {
            'epoch': epoch + 1,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_opt_state_dict': g_opt.state_dict(),
            'd_opt_state_dict': d_opt.state_dict(),
            'label_embedding_state_dict': label_embedding.state_dict(),
            'fixed_noise': fixed_noise,
            'fixed_labels': fixed_labels,
        }
        _save_checkpoint(checkpoint_data, f"training/wasserstein/epoche_schritte/checkpoint_epoch_{epoch+1}.pth")
        _save_checkpoint(checkpoint_data, checkpoint_path)


In [None]:
# Export the final weights (run after the training loop). The label embedding
# is exported as well: it is randomly initialised and not trained, so the
# generator only produces meaningful conditioning with this exact embedding.
torch.save(generator.state_dict(), "generator.pth")
torch.save(discriminator.state_dict(), "discriminator.pth")
torch.save(label_embedding.state_dict(), "label_embedding.pth")


In [None]:
# Generate one postcard for a given city.
city_name = "Kiel"

# Dataset labels are "city | country_id" strings, so match on the city part.
# If the city occurs with several countries, the first label found is used.
matching_labels = [
    label for label in dataset.label_to_idx
    if str(label).split(" | ")[0].strip() == city_name
]
if not matching_labels:
    raise KeyError(f"Stadt '{city_name}' kommt im Datensatz nicht vor")
city_label = matching_labels[0]

# Eval mode makes BatchNorm use its running statistics for this single
# sample; restore the previous mode afterwards so a later re-run of the
# training cell does not train in eval mode.
was_training = generator.training
generator.eval()
try:
    with torch.no_grad():
        z = torch.randn(1, z_dim).to(device)
        city_idx = torch.tensor([dataset.label_to_idx[city_label]]).to(device)
        label_embed = label_embedding(city_idx)
        gen_img = generator(z, label_embed).squeeze(0).cpu()
finally:
    generator.train(was_training)

# Undo the [-1, 1] normalisation for display.
plt.imshow((gen_img.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
plt.title(f"Generierte Postkarte: {city_name}")
plt.axis("off")
plt.show()
