In [None]:
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
print(np.__version__)

# Seed the global torch (CPU + CUDA) and legacy NumPy RNGs so weight init (including
# the country embedding), noise sampling, DataLoader shuffling and the random D-skip in
# train_step repeat across kernel restarts. Some cuDNN kernels may still be
# non-deterministic on GPU.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# PyTorch build and the CUDA version it was compiled against (None for CPU-only builds).
print(torch.__version__)
print(torch.version.cuda)

In [None]:
class PostcardDataset(Dataset):
    """Postcard images paired with a country class index and normalized coordinates.

    Reads the metadata CSV at ``csv_path`` and drops rows that miss ``akon_id``,
    ``country_id``, ``latitude`` or ``longitude``. Each sample's image is loaded
    from ``<image_dir>/<akon_id>.jpg``.

    Args:
        csv_path: Path to the metadata CSV.
        image_dir: Directory containing the ``.jpg`` files named after ``akon_id``.
        transform: Optional callable applied to the RGB ``PIL.Image``.

    Attributes:
        countries: Sorted unique ``country_id`` values.
        country_to_idx: ``country_id`` -> class index (position in ``countries``).
        idx_to_country: Inverse of ``country_to_idx``.
    """

    def __init__(self, csv_path, image_dir, transform=None):
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.transform = transform

        # Keep only rows with an image id, a country label and both coordinates.
        self.df = self.df.dropna(subset=["akon_id", "country_id", "latitude", "longitude"])

        self.countries = sorted(self.df["country_id"].unique())
        self.country_to_idx = {c: i for i, c in enumerate(self.countries)}
        self.idx_to_country = {i: c for c, i in self.country_to_idx.items()}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, country_idx, coords)`` for the ``idx``-th (positional) row.

        ``coords`` is a float32 tensor ``[lat / 90, lon / 180]``, so both components
        lie in [-1, 1] for valid geographic coordinates.
        """
        row = self.df.iloc[idx]
        # Format via f-string so numeric ids (pandas parses them as int) still work.
        image_path = os.path.join(self.image_dir, f"{row['akon_id']}.jpg")
        # convert() returns a new in-memory image; closing the opened file right away
        # avoids accumulating open descriptors in long-running DataLoader workers.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        country_idx = self.country_to_idx[row["country_id"]]
        lat = float(row["latitude"])
        lon = float(row["longitude"])

        # Normalized coordinates.
        coords = torch.tensor([lat / 90.0, lon / 180.0], dtype=torch.float32)

        return image, country_idx, coords

In [None]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # Map [0, 1] pixels to [-1, 1], the range of the generator's Tanh output.
    transforms.Normalize([0.5]*3, [0.5]*3)
])

csv_path = "akon_postcards_public_domain.csv"
image_dir = "images/256"
dataset = PostcardDataset(csv_path, image_dir, transform)
# drop_last=True: the generator's BatchNorm1d raises in training mode on a batch of
# size 1, which the final batch is whenever len(dataset) % 128 == 1. Full-size batches
# also keep the per-batch BatchNorm statistics of real and fake batches comparable.
# NOTE(review): num_workers > 0 with a Dataset class defined inside the notebook can
# fail under the "spawn" start method (Windows/macOS); use num_workers=0 there.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4,
                        pin_memory=True, drop_last=True)

# Show one sample (undo the [-1, 1] normalization for display).
img, label_idx, coords = dataset[0]
plt.imshow(img.permute(1, 2, 0) * 0.5 + 0.5)
plt.title(f"Label-Index: {label_idx}")
plt.show()

In [None]:
# Conditioning sizes; Generator and Discriminator read these when they are constructed.
embedding_dim = 48   # width of one country vector
z_dim = 100          # latent noise size
coord_dim = 2        # normalized (lat / 90, lon / 180)
num_countries = len(dataset.countries)

# One vector per class index of dataset.country_to_idx.
# NOTE(review): these weights are not passed to any optimizer in this notebook and
# train_step uses the looked-up vectors detached, so they keep their random init.
country_embedding = nn.Embedding(num_embeddings=num_countries, embedding_dim=embedding_dim).to(device)

In [None]:
class Generator(nn.Module):
    """Conditional DCGAN-style generator producing 3x256x256 images in [-1, 1].

    The input is the concatenation ``[z, country_embed, coords]`` whose width is
    ``z_dim + embedding_dim + coord_dim`` (module-level config, read at construction).
    A linear layer maps it to a 1024x4x4 seed that six stride-2 transposed
    convolutions upsample 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256.
    """

    def __init__(self):
        super().__init__()
        cond_features = z_dim + embedding_dim + coord_dim
        seed_features = 1024 * 4 * 4

        self.fc = nn.Sequential(
            nn.Linear(cond_features, seed_features),
            nn.BatchNorm1d(seed_features),  # BatchNorm on the projected seed
            nn.ReLU(True),
        )

        # Five stages ConvTranspose -> BatchNorm -> ReLU that halve the channel count
        # (1024 -> 32), then a final ConvTranspose to 3 RGB channels and Tanh.
        # Layer order/indices (conv.0 ... conv.16) match the state_dict layout.
        widths = (1024, 512, 256, 128, 64, 32)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages.extend((
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ))
        stages.extend((nn.ConvTranspose2d(widths[-1], 3, 4, 2, 1), nn.Tanh()))
        self.conv = nn.Sequential(*stages)

    def forward(self, z, country_embed, coords):
        """Generate a batch of images from ``[B, *]`` inputs concatenated on dim 1."""
        cond = torch.cat((z, country_embed, coords), dim=1)
        seed = self.fc(cond).view(-1, 1024, 4, 4)
        return self.conv(seed)

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring (image, country embedding, coords) as real/fake.

    The conditioning vector ``[country_embed, coords]`` is projected to a single-channel
    32x32 map, upsampled to the image's spatial size and concatenated as a 4th input
    channel. Returns ``[batch, 1]`` probabilities (Sigmoid), matching ``nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()
        # Small label projection (32x32 only); it is upsampled in forward().
        self.label_proj = nn.Sequential(
            nn.Linear(embedding_dim + coord_dim, 32 * 32),
            nn.LeakyReLU(0.2)
        )

        # Spatial sizes in the comments are for 256x256 inputs.
        self.model = nn.Sequential(
            # Input: 3 (image) + 1 (label map) = 4 channels
            nn.Conv2d(4, 32, 4, 2, 1),  # 256->128
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.25),

            nn.Conv2d(32, 64, 4, 2, 1),  # 128->64
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.25),

            nn.Conv2d(64, 128, 4, 2, 1),  # 64->32
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.25),

            nn.Conv2d(128, 256, 4, 2, 1),  # 32->16
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, 4, 2, 1),  # 16->8
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 1, 4, 2, 1),  # 8->4
            nn.AdaptiveAvgPool2d(1),  # global average pooling of the score map
            nn.Flatten(),
            nn.Sigmoid()
        )

    def forward(self, img, country_embed, coords):
        """Score ``img`` ``[B, 3, H, W]`` (H, W >= 64) under its conditioning; returns ``[B, 1]``."""
        # Build the label map small (32x32), which is much cheaper than projecting to H*W.
        label = torch.cat([country_embed, coords], dim=1)
        label_map = self.label_proj(label).view(-1, 1, 32, 32)

        # Upsample to the image's own spatial size (256x256 for the training data) so the
        # channel concatenation below also works for other input resolutions.
        label_map = F.interpolate(label_map, size=tuple(img.shape[-2:]), mode='bilinear', align_corners=False)

        x = torch.cat([img, label_map], dim=1)  # [batch, 4, H, W]
        return self.model(x)  # [batch, 1] real/fake probability

In [None]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Adam momentum terms shared by both networks.
ADAM_BETAS = (0.5, 0.999)
# Separate learning rates: the generator steps twice as fast as the discriminator.
g_opt = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=ADAM_BETAS)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=ADAM_BETAS)

# The discriminator ends in a Sigmoid, so BCE is applied to probabilities.
criterion = nn.BCELoss()

# Both LRs are multiplied by LR_DECAY each time the schedulers step (once per epoch
# in the training loop).
LR_DECAY = 0.995
g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_opt, gamma=LR_DECAY)
d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_opt, gamma=LR_DECAY)

In [None]:
def train_step(imgs, country_idxs, coords):
    """Run one conditional-GAN update on a batch and return logging metrics.

    On batches where ``np.random.random() > 0.2`` (~80%) the discriminator takes one
    step on the real batch plus one freshly generated fake batch; the generator then
    always takes two steps, each with new noise.

    Args:
        imgs: Real images as produced by the DataLoader, already on ``device``.
        country_idxs: Country class indices for the batch, on ``device``.
        coords: Normalized (lat, lon) pairs, ``[B, 2]``, on ``device``.

    Returns:
        ``(d_loss, g_loss, real_pred_mean, fake_pred_mean)``. If the discriminator
        step was skipped, ``d_loss``, ``real_pred_mean`` and ``fake_pred_mean`` are 0.
        ``g_loss`` is the mean of the two generator losses.
    """
    batch_size = imgs.size(0)

    # Smoothed targets: "real" = 0.9 (also the generator's target), "fake" = 0.1.
    real_labels = torch.full((batch_size, 1), 0.9, device=device)
    fake_labels = torch.full((batch_size, 1), 0.1, device=device)

    # The embedding is a fixed conditioning input here (no gradient flows into it);
    # looking it up under no_grad is equivalent to .detach() without recording a graph.
    with torch.no_grad():
        country_embed = country_embedding(country_idxs)
    coords = coords.detach()

    d_loss = 0
    real_pred_mean = 0
    fake_pred_mean = 0

    # Randomly skip ~20% of discriminator updates (presumably to slow D relative to G).
    if np.random.random() > 0.2:
        z = torch.randn(batch_size, z_dim, device=device)
        fake_imgs = generator(z, country_embed, coords)

        real_preds = discriminator(imgs, country_embed, coords)
        # detach(): the discriminator update must not back-propagate into the generator.
        fake_preds = discriminator(fake_imgs.detach(), country_embed, coords)

        d_loss = criterion(real_preds, real_labels) + criterion(fake_preds, fake_labels)

        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()

        real_pred_mean = real_preds.mean().item()
        fake_pred_mean = fake_preds.mean().item()
        d_loss = d_loss.item()

    # Two generator steps. Each iteration runs its own forward pass and therefore owns
    # its graph, so no retain_graph is needed; retaining it would only keep the first
    # iteration's saved activations alive while the second one runs.
    g_losses = []
    for _ in range(2):
        z = torch.randn(batch_size, z_dim, device=device)
        fake_imgs = generator(z, country_embed, coords)
        fake_preds = discriminator(fake_imgs, country_embed, coords)
        # Push the discriminator's score on fakes towards the (smoothed) real target.
        g_loss = criterion(fake_preds, real_labels)

        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()
        g_losses.append(g_loss.item())

    return d_loss, np.mean(g_losses), real_pred_mean, fake_pred_mean


In [None]:
import csv

# Per-epoch metrics are appended to this CSV by the training loop.
os.makedirs("training/DCGANv3", exist_ok=True)
log_file = "training/DCGANv3/training_metrics.csv"

# Write the header only for a new file, so resumed runs keep appending rows.
if not os.path.exists(log_file):
    header = ["Epoch", "D_Loss", "G_Loss", "Real_Pred", "Fake_Pred", "G_LR", "D_LR", "Duration"]
    with open(log_file, mode='w', newline='') as f:
        csv.writer(f).writerow(header)


In [None]:
import time
from torchvision.utils import save_image

# Output folders: sample grids (every 10 epochs) and checkpoints (every 50 epochs).
os.makedirs("training/DCGANv3/epoche_bilder", exist_ok=True)
os.makedirs("training/DCGANv3/epoche_schritte", exist_ok=True)

# Fixed inputs for the periodic sample grid. They come from a dedicated, seeded RNG so
# the grid shows the same (noise, country, location) combinations across epochs AND
# across kernel restarts / checkpoint resumes. Country and coordinates are taken
# together from real dataset rows, so each sample's conditioning is consistent.
NUM_SAMPLES = 16
sample_rng = torch.Generator().manual_seed(1234)
fixed_noise = torch.randn(NUM_SAMPLES, z_dim, generator=sample_rng).to(device)
sample_positions = torch.randint(0, len(dataset.df), (NUM_SAMPLES,), generator=sample_rng).tolist()
sample_rows = dataset.df.iloc[sample_positions]
fixed_country = torch.tensor([dataset.country_to_idx[c] for c in sample_rows["country_id"]]).to(device)
# Same normalization as PostcardDataset: latitude / 90, longitude / 180.
fixed_coords = torch.tensor(
    np.stack([sample_rows["latitude"].to_numpy(dtype=float) / 90.0,
              sample_rows["longitude"].to_numpy(dtype=float) / 180.0], axis=1),
    dtype=torch.float32,
).to(device)

# Checkpoint setup
start_epoch = 0
checkpoint_path = "training/DCGANv3/epoche_schritte/checkpoint_epoch_latest.pth"

if os.path.exists(checkpoint_path):
    print(" Checkpoint gefunden, lade...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint["generator_state_dict"])
    discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
    g_opt.load_state_dict(checkpoint["g_opt_state_dict"])
    d_opt.load_state_dict(checkpoint["d_opt_state_dict"])
    # Restore the LR schedulers so last_epoch / get_last_lr() continue from the checkpoint.
    if "g_scheduler_state_dict" in checkpoint and "d_scheduler_state_dict" in checkpoint:
        g_scheduler.load_state_dict(checkpoint["g_scheduler_state_dict"])
        d_scheduler.load_state_dict(checkpoint["d_scheduler_state_dict"])
    # The generator was trained against one specific country_embedding; without it a
    # resumed run would condition on a freshly initialized embedding.
    if "country_embedding_state_dict" in checkpoint:
        country_embedding.load_state_dict(checkpoint["country_embedding_state_dict"])
    else:
        print("WARNUNG: Checkpoint enthält kein country_embedding - verwende das aktuelle (neu initialisierte).")
    start_epoch = checkpoint["epoch"]
    print(f"⏮ Weiter bei Epoche {start_epoch}")
else:
    print(" Kein Checkpoint gefunden, starte bei Epoche 0.")

num_epochs = 10000

# Training loop
for epoch in range(start_epoch, num_epochs):
    start_time = time.time()
    print(f"\n Starte Epoche {epoch+1}/{num_epochs}")

    epoch_d_losses = []
    epoch_g_losses = []
    epoch_real_preds = []
    epoch_fake_preds = []

    for imgs, country_idxs, coords in dataloader:
        imgs = imgs.to(device)
        country_idxs = country_idxs.to(device)
        coords = coords.to(device)

        d_loss, g_loss, real_pred, fake_pred = train_step(imgs, country_idxs, coords)

        epoch_g_losses.append(g_loss)
        # train_step reports 0 for every D metric when it skipped the D update; only
        # average over batches where D actually trained, otherwise those zeros would
        # bias the epoch D loss downward.
        if real_pred > 0:
            epoch_d_losses.append(d_loss)
            epoch_real_preds.append(real_pred)
            epoch_fake_preds.append(fake_pred)

    # Learning-rate decay, once per epoch.
    g_scheduler.step()
    d_scheduler.step()

    # Epoch summary
    duration = time.time() - start_time
    avg_d_loss = np.mean(epoch_d_losses) if epoch_d_losses else 0
    avg_g_loss = np.mean(epoch_g_losses) if epoch_g_losses else 0
    avg_real_pred = np.mean(epoch_real_preds) if epoch_real_preds else 0
    avg_fake_pred = np.mean(epoch_fake_preds) if epoch_fake_preds else 0

    print(f" Epoche {epoch+1} | D Loss: {avg_d_loss:.4f} | G Loss: {avg_g_loss:.4f}")
    print(f" Real Pred: {avg_real_pred:.3f} | Fake Pred: {avg_fake_pred:.3f} | {duration:.2f}s")
    print(f" G LR: {g_scheduler.get_last_lr()[0]:.6f} | D LR: {d_scheduler.get_last_lr()[0]:.6f}")

    # Append to the CSV log
    with open(log_file, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            epoch + 1,
            avg_d_loss,
            avg_g_loss,
            avg_real_pred,
            avg_fake_pred,
            g_scheduler.get_last_lr()[0],
            d_scheduler.get_last_lr()[0],
            duration
        ])

    # Every 10 epochs: save a 4x4 grid generated from the fixed inputs.
    if (epoch + 1) % 10 == 0:
        generator.eval()
        with torch.no_grad():
            embed = country_embedding(fixed_country)
            samples = generator(fixed_noise, embed, fixed_coords)
            save_image(samples * 0.5 + 0.5,
                       f"training/DCGANv3/epoche_bilder/gen_samples_epoch_{epoch+1}.png",
                       nrow=4)
        generator.train()

    # Every 50 epochs: save a checkpoint (numbered copy + "latest" used for resuming).
    if (epoch + 1) % 50 == 0:
        checkpoint_data = {
            'epoch': epoch + 1,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'country_embedding_state_dict': country_embedding.state_dict(),
            'g_opt_state_dict': g_opt.state_dict(),
            'd_opt_state_dict': d_opt.state_dict(),
            'g_scheduler_state_dict': g_scheduler.state_dict(),
            'd_scheduler_state_dict': d_scheduler.state_dict(),
        }

        torch.save(checkpoint_data, f"training/DCGANv3/epoche_schritte/checkpoint_epoch_{epoch+1}.pth")
        torch.save(checkpoint_data, checkpoint_path)  # overwrite the latest state


In [None]:
# After training (after the loop): save the final weights. The country embedding is
# saved too: the generator is conditioned on these exact vectors and cannot be reused
# in a fresh kernel without them.
torch.save(generator.state_dict(), "generator_improved.pth")
torch.save(discriminator.state_dict(), "discriminator_improved.pth")
torch.save(country_embedding.state_dict(), "country_embedding_improved.pth")

In [None]:
# Example: generate one postcard for a fixed country and location (Cologne, Germany).
z = torch.randn(1, z_dim).to(device)

# Cologne ≈ 50.94°N, 6.96°E, normalized like PostcardDataset (latitude / 90, longitude / 180).
lat_norm, lon_norm = 50.94 / 90.0, 6.96 / 180.0
coords = torch.tensor([[lat_norm, lon_norm]], dtype=torch.float32).to(device)

# Country label; must be a country_id present in the dataset (KeyError otherwise).
city_country = "DE"
country_idx = torch.tensor([dataset.country_to_idx[city_country]]).to(device)
country_embed = country_embedding(country_idx)

# Inference mode: BatchNorm uses its running statistics. The generator stays in
# eval mode after this cell.
generator.eval()
with torch.no_grad():
    gen_img = generator(z, country_embed, coords).squeeze().cpu()
    # CHW in [-1, 1] -> HWC in [0, 1] for matplotlib.
    preview = gen_img.permute(1, 2, 0) * 0.5 + 0.5
    plt.imshow(preview)
    plt.title(f"Generierte Postkarte: {city_country}")
    plt.axis("off")
    plt.show()

In [None]:
# Visualize the country embeddings in 2-D (t-SNE and PCA).
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import pandas as pd


def _plot_embedding_2d(df_2d, title):
    """Scatter-plot a frame with columns ``country``/``x``/``y``, labelling each point."""
    fig, ax = plt.subplots(figsize=(12, 10))
    ax.scatter(df_2d["x"], df_2d["y"], alpha=0.7)
    for country, x, y in zip(df_2d["country"], df_2d["x"], df_2d["y"]):
        ax.annotate(country, (x, y), fontsize=8)
    ax.set_title(title)
    plt.show()


# All countries of the dataset, in class-index order (same as sorted unique country_id).
all_countries = list(dataset.countries)
all_idxs = torch.tensor([dataset.country_to_idx[c] for c in all_countries]).to(device)

# NOTE: as written, no optimizer updates country_embedding and train_step detaches it,
# so these vectors are the (random) initialization, not learned representations.
with torch.no_grad():
    embeddings = country_embedding(all_idxs).cpu().numpy()

# t-SNE (perplexity must be smaller than the number of points)
tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(all_countries)-1))
emb_2d = tsne.fit_transform(embeddings)
df_tsne = pd.DataFrame({
    "country": all_countries,
    "x": emb_2d[:, 0],
    "y": emb_2d[:, 1]
})
_plot_embedding_2d(df_tsne, "Country Embeddings via t-SNE")

# PCA
pca = PCA(n_components=2)
emb_2d_pca = pca.fit_transform(embeddings)
df_pca = pd.DataFrame({
    "country": all_countries,
    "x": emb_2d_pca[:, 0],
    "y": emb_2d_pca[:, 1]
})
_plot_embedding_2d(df_pca, "Country Embeddings via PCA")

print(f"Explained variance ratio: {pca.explained_variance_ratio_}")