In [64]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from pathlib import Path

## Dataset

In [65]:
class ColorizationDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.gray_paths = list((self.root_dir / "gray").glob("*"))
        self.color_paths = list((self.root_dir / "color").glob("*"))
        self.transform = transform

    def __len__(self):
        return len(self.gray_paths)

    def __getitem__(self, idx):
        gray_path = self.gray_paths[idx]
        color_path = self.color_paths[idx]

        gray_image = Image.open(gray_path).convert("L")  # 1-Kanal
        color_image = Image.open(color_path).convert("RGB")  # 3-Kanal

        if self.transform:
            gray_image = self.transform(gray_image)
            color_image = self.transform(color_image)

        # Graubild zu 3 Kanälen duplizieren
        gray_image = gray_image.expand(3, -1, -1)

        # Zusätzlich den Dateinamen zurückgeben
        filename = gray_path.stem  # Nur der Name ohne ".jpg" oder ".png"

        return gray_image, color_image, filename

## Generater

In [66]:
class UNetGenerator(nn.Module):
    def __init__(self):
        super(UNetGenerator, self).__init__()

        def down_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, 2, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2)
            )

        def up_block(in_channels, out_channels):
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            )

        self.down1 = down_block(3, 64)
        self.down2 = down_block(64, 128)
        self.down3 = down_block(128, 256)
        self.down4 = down_block(256, 512)

        self.up1 = up_block(512, 256)
        self.up2 = up_block(512, 128)
        self.up3 = up_block(256, 64)
        self.up4 = nn.ConvTranspose2d(128, 3, 4, 2, 1)  # Output RGB

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        u1 = self.up1(d4)
        u1 = torch.cat([u1, d3], dim=1)
        u2 = self.up2(u1)
        u2 = torch.cat([u2, d2], dim=1)
        u3 = self.up3(u2)
        u3 = torch.cat([u3, d1], dim=1)
        output = torch.tanh(self.up4(u3))
        return output

## Discriminator

In [67]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1),  # (gray + color)
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, 4, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, gray, color):
        x = torch.cat([gray, color], dim=1)
        return self.model(x)

In [68]:
# Hyperparameter
lr_generator = 2e-4
lr_discriminator = 1e-4
b1_generator = 0.5
b2_generator = 0.999
b1_discriminator = 0.5
b2_discriminator = 0.999
epochs = 100
batch_size = 8
datapath = Path.cwd() / "images"
output_folder = Path.cwd() / "images" / "generated"

In [69]:
# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Modelle
gen = UNetGenerator().to(device)
disc = Discriminator().to(device)

# Optimizer
opt_gen = optim.Adam(gen.parameters(), lr=lr_generator, betas=(b1_generator, b2_generator))
opt_disc = optim.Adam(disc.parameters(), lr=lr_discriminator, betas=(b1_discriminator, b2_discriminator))

# Loss Funktionen
criterion_GAN = nn.BCELoss()
criterion_L1 = nn.L1Loss()

# DataLoader
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

dataset = ColorizationDataset(datapath, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training
for epoch in range(epochs):
    for idx, (gray, color, filename) in enumerate(dataloader):
        gray = gray.to(device)
        color = color.to(device)

        fake_color = gen(gray)
        real_pair = disc(gray, color)
        fake_pair = disc(gray, fake_color.detach())

        real_labels = torch.ones_like(real_pair)
        fake_labels = torch.zeros_like(fake_pair)

        loss_real = criterion_GAN(real_pair, real_labels)
        loss_fake = criterion_GAN(fake_pair, fake_labels)
        loss_disc = (loss_real + loss_fake) / 2

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # ==================== Generator ====================
        fake_pair = disc(gray, fake_color)
        loss_gan = criterion_GAN(fake_pair, real_labels)
        loss_l1 = criterion_L1(fake_color, color) * 100

        loss_gen = loss_gan + loss_l1

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if idx % 100 == 0:
            print(f"Epoch [{epoch}/{epochs}] Batch [{idx}/{len(dataloader)}] "
                  f"Loss D: {loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}")

    # --> Am Ende jedes Epochs: speichere das aktuelle fake_color
    # Aber Achtung: Hier nehmen wir NUR das erste Bild im Batch [0]
    save_path = output_folder / f"{filename[0]}_generated.png"
    save_image(fake_color[0], save_path)  # [0] weil Batch

print("✅ Training abgeschlossen!")

Epoch [0/100] Batch [0/1] Loss D: 0.7137, loss G: 58.4400
Epoch [1/100] Batch [0/1] Loss D: 0.6836, loss G: 50.3474
Epoch [2/100] Batch [0/1] Loss D: 0.6570, loss G: 44.2196
Epoch [3/100] Batch [0/1] Loss D: 0.6336, loss G: 39.6873
Epoch [4/100] Batch [0/1] Loss D: 0.6113, loss G: 36.2992
Epoch [5/100] Batch [0/1] Loss D: 0.5932, loss G: 34.6737
Epoch [6/100] Batch [0/1] Loss D: 0.5702, loss G: 32.2293
Epoch [7/100] Batch [0/1] Loss D: 0.5522, loss G: 30.2861
Epoch [8/100] Batch [0/1] Loss D: 0.5314, loss G: 28.9496
Epoch [9/100] Batch [0/1] Loss D: 0.5097, loss G: 27.8018
Epoch [10/100] Batch [0/1] Loss D: 0.4928, loss G: 26.6776
Epoch [11/100] Batch [0/1] Loss D: 0.4777, loss G: 25.9010
Epoch [12/100] Batch [0/1] Loss D: 0.4534, loss G: 25.6020
Epoch [13/100] Batch [0/1] Loss D: 0.4346, loss G: 25.3500
Epoch [14/100] Batch [0/1] Loss D: 0.4127, loss G: 24.2883
Epoch [15/100] Batch [0/1] Loss D: 0.4020, loss G: 23.3772
Epoch [16/100] Batch [0/1] Loss D: 0.3789, loss G: 22.9699
Epoch [