In [73]:
import torch
from torch import nn
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Resize, Compose
from torchvision.utils import save_image
from torch.autograd.variable import Variable
from tqdm import tqdm

# Define device: train on the first CUDA GPU when one is visible, else on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Hyperparameters
batch_size = 64      # mini-batch size handed to the DataLoader
nz = 100             # number of channels of the latent noise fed to the generator
ngf = 64             # base feature-map width of the generator
ndf = 64             # base feature-map width of the discriminator
nc = 3               # image channels (generator output / discriminator input)
lr = 0.0002          # learning rate shared by both Adam optimizers
beta1 = 0.5          # first Adam beta coefficient shared by both optimizers
num_epochs = 5       # number of full passes over the dataloader

class Generator(nn.Module):
    """Transposed-convolution generator mapping latent noise to an image.

    Args:
        nz: number of channels of the latent input (fed as ``nz x 1 x 1``).
        ngf: base feature-map width; hidden layers use 8x, 4x, 2x and 1x of it.
        nc: number of channels of the generated image.

    For an input of spatial size 1x1 the output is ``nc x 64 x 64`` with
    values squashed into [-1, 1] by the final ``Tanh``.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()

        # Channel schedule of the upsampling stack: nz -> 8*ngf -> ... -> ngf.
        channels = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]

        layers = []
        for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            # The first layer expands 1x1 -> 4x4 (stride 1, no padding);
            # every later layer doubles the spatial size (stride 2, padding 1).
            stride, padding = (1, 0) if idx == 0 else (2, 1)
            layers.append(
                nn.ConvTranspose2d(
                    c_in, c_out, kernel_size=4, stride=stride,
                    padding=padding, bias=False,
                )
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))

        # Output layer: (ngf) x 32 x 32 -> (nc) x 64 x 64, no batch norm.
        layers.append(
            nn.ConvTranspose2d(
                ngf, nc, kernel_size=4, stride=2, padding=1, bias=False
            )
        )
        layers.append(nn.Tanh())

        # Kept as a flat Sequential named ``main`` so parameter names
        # (``main.0.weight`` ...) stay the same.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch ``input`` through the layer stack."""
        generated = self.main(input)
        return generated



class Discriminator(nn.Module):
    """Strided-convolution classifier scoring images as real or fake.

    Args:
        nc: number of channels of the input image.
        ndf: base feature-map width; hidden layers use 1x, 2x, 4x and 8x of it.

    Four stride-2 convolutions halve the spatial size each time and a final
    4x4 convolution plus ``Sigmoid`` yields one probability per image
    (for 64x64 inputs the final feature map is 1x1).
    """

    def __init__(self, nc, ndf):
        super().__init__()

        # Input layer: no batch norm directly on the image.
        layers = [
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three downsampling stages, each doubling the channel count.
        width = ndf
        for _ in range(3):
            layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=4, stride=2,
                          padding=1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            width *= 2

        # Reduce to a 1x1 output and squash it into a probability.
        layers.extend([
            nn.Conv2d(width, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ])

        # Kept as a flat Sequential named ``main`` so parameter names
        # (``main.0.weight`` ...) stay the same.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a 1-D tensor with one score per element of ``input``."""
        scores = self.main(input)
        column = scores.view(-1, 1)
        return column.squeeze(1)


def _weights_init(module):
    """Apply the DCGAN weight initialisation to one sub-module.

    Intended for ``nn.Module.apply``. Convolution and transposed-convolution
    weights are drawn from N(0, 0.02); batch-norm scale weights from
    N(1, 0.02) with the bias set to zero. Other module types are untouched.
    """
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, 0.0, 0.02)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, 1.0, 0.02)
        nn.init.zeros_(module.bias)


# Load the dataset
# Images are resized to 64x64 and normalised per channel with mean 0.5 and
# std 0.5, i.e. mapped from [0, 1] to [-1, 1] to match the generator's Tanh.
transform = Compose([Resize((64, 64)), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = ImageFolder("train64", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# Initialize the networks
netG = Generator(nz, ngf, nc).to(device)
netD = Discriminator(nc, ndf).to(device)

# Re-initialise all weights as prescribed for DCGAN instead of relying on
# PyTorch's default layer initialisation.
netG.apply(_weights_init)
netD.apply(_weights_init)

# Initialize the criterion
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Setup Adam optimizers
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))


# Training loop
# Each iteration performs one discriminator update (real batch + fake batch)
# followed by one generator update on the same fake batch.
for epoch in range(num_epochs):
    for i, data in enumerate(tqdm(dataloader)):
        # ---- (1) Update D: maximise log(D(x)) + log(1 - D(G(z))) ----
        netD.zero_grad()
        real = data[0].to(device)
        # Use a local name so the module-level ``batch_size`` hyperparameter
        # is not overwritten by the (possibly smaller) size of the last batch.
        b_size = real.size(0)
        label = torch.full((b_size,), 1.0, dtype=torch.float, device=device)
        output = netD(real)
        errD_real = criterion(output, label)
        errD_real.backward()

        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(0.0)
        # detach() keeps the discriminator step from back-propagating into G.
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_real + errD_fake
        optimizerD.step()

        # ---- (2) Update G: maximise log(D(G(z))) ----
        netG.zero_grad()
        # The generator wants its fakes to be classified as real.
        label.fill_(1.0)
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        if i % 100 == 0:
            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(f"[{epoch+1}/{num_epochs}][{i}/{len(dataloader)}] Loss_D: {errD.item()} Loss_G: {errG.item()}")

    # save a generated image after each epoch
    with torch.no_grad():
        fake = netG(torch.randn(64, nz, 1, 1, device=device)).cpu()
    # The generator's Tanh output lies in [-1, 1]; map it back to [0, 1]
    # before saving, otherwise save_image clamps all negative values to black.
    save_image(fake * 0.5 + 0.5, f"output_{epoch}.png")


  0%|          | 3/1563 [00:07<51:54,  2.00s/it]  

[1/5][0/1563] Loss_D: 1.4283931255340576 Loss_G: 2.030776023864746


  7%|▋         | 103/1563 [00:22<03:21,  7.25it/s]

[1/5][100/1563] Loss_D: 0.06905334442853928 Loss_G: 6.270732402801514


 13%|█▎        | 203/1563 [00:36<03:09,  7.19it/s]

[1/5][200/1563] Loss_D: 0.1378207504749298 Loss_G: 5.042850017547607


 19%|█▉        | 303/1563 [00:51<02:57,  7.12it/s]

[1/5][300/1563] Loss_D: 0.06461813300848007 Loss_G: 4.213501930236816


 26%|██▌       | 403/1563 [01:06<02:43,  7.08it/s]

[1/5][400/1563] Loss_D: 0.08496499061584473 Loss_G: 4.155423641204834


 32%|███▏      | 503/1563 [01:21<02:30,  7.06it/s]

[1/5][500/1563] Loss_D: 0.11909759789705276 Loss_G: 3.6849489212036133


 39%|███▊      | 603/1563 [01:36<02:16,  7.05it/s]

[1/5][600/1563] Loss_D: 0.44118955731391907 Loss_G: 7.55834436416626


 45%|████▍     | 703/1563 [01:51<02:02,  7.02it/s]

[1/5][700/1563] Loss_D: 0.22810614109039307 Loss_G: 2.9988789558410645


 51%|█████▏    | 803/1563 [02:07<01:49,  6.95it/s]

[1/5][800/1563] Loss_D: 0.06043301522731781 Loss_G: 4.670472621917725


 58%|█████▊    | 903/1563 [02:22<01:35,  6.94it/s]

[1/5][900/1563] Loss_D: 1.2846417427062988 Loss_G: 1.7856649160385132


 64%|██████▍   | 1003/1563 [02:37<01:19,  7.02it/s]

[1/5][1000/1563] Loss_D: 1.1507372856140137 Loss_G: 0.927391767501831


 71%|███████   | 1103/1563 [02:52<01:05,  6.99it/s]

[1/5][1100/1563] Loss_D: 0.05525706708431244 Loss_G: 5.159183502197266


 77%|███████▋  | 1203/1563 [03:07<00:51,  6.95it/s]

[1/5][1200/1563] Loss_D: 0.009449204429984093 Loss_G: 5.952235698699951


 83%|████████▎ | 1301/1563 [03:22<00:48,  5.39it/s]

[1/5][1300/1563] Loss_D: 0.0054879989475011826 Loss_G: 6.559638977050781


 90%|████████▉ | 1403/1563 [03:37<00:22,  7.02it/s]

[1/5][1400/1563] Loss_D: 0.0030386031139642 Loss_G: 7.006097316741943


 96%|█████████▌| 1503/1563 [03:52<00:08,  7.05it/s]

[1/5][1500/1563] Loss_D: 0.001808332628570497 Loss_G: 7.408713340759277


100%|██████████| 1563/1563 [04:02<00:00,  6.45it/s]
  0%|          | 3/1563 [00:08<1:00:03,  2.31s/it]

[2/5][0/1563] Loss_D: 0.0021047985646873713 Loss_G: 7.033045768737793


  7%|▋         | 103/1563 [00:23<03:25,  7.09it/s]

[2/5][100/1563] Loss_D: 0.00045665528159588575 Loss_G: 9.819746017456055


  8%|▊         | 132/1563 [00:28<05:11,  4.59it/s]


KeyboardInterrupt: 