In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms
from torchvision.utils import save_image

# ---- Hyperparameters -------------------------------------------------------
IMG_SIZE = 512  # side length (pixels) of the square training/generated images
LATENT_DIM = 200  # length of the generator's input noise vector
BATCH_SIZE = 512  # images per optimisation step
# NOTE(review): 512 x 3 x 512 x 512 float32 values is ~1.5 GiB for the input
# batch alone, before any activations; expect to lower this on most GPUs.
EPOCHS = 100

# Loss history appended to by the training loop and plotted at the end.
d_loss_values = []
g_loss_values = []

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Generator(nn.Module):
    """DCGAN-style generator that maps latent noise vectors to RGB images.

    A linear layer projects each latent vector to a 512-channel square
    feature map of side ``img_size // 32``. Five stages of x2 upsampling plus
    a 3x3 convolution (512 -> 256 -> 128 -> 64 -> 32 -> 16 channels, each
    followed by BatchNorm and ReLU) grow it to ``img_size``. A final 3x3
    convolution produces 3 channels, squashed by ``tanh`` into [-1, 1].

    Args:
        img_size: Side length of the generated images; must be a positive
            multiple of 32. Defaults to the module-level ``IMG_SIZE``.
        latent_dim: Length of the input noise vector. Defaults to the
            module-level ``LATENT_DIM``.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 32.
    """

    def __init__(self, img_size=None, latent_dim=None):
        super().__init__()
        if img_size is None:
            img_size = IMG_SIZE
        if latent_dim is None:
            latent_dim = LATENT_DIM
        if img_size <= 0 or img_size % 32 != 0:
            # There are exactly five x2 upsamples. Any other size would
            # silently yield images that are not img_size x img_size.
            raise ValueError(f"img_size must be a positive multiple of 32, got {img_size}")

        self.init_size = img_size // 32  # e.g. 512 // 32 = 16
        # Kept inside nn.Sequential, and the conv stack's layer order kept
        # fixed, so state_dict keys such as "l1.0.weight" stay stable for
        # saved checkpoints.
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 512 * self.init_size ** 2))
        # All BatchNorm2d layers use PyTorch's default eps/momentum.
        # BatchNorm2d's second positional parameter is ``eps``, not
        # ``momentum``. Writing ``BatchNorm2d(c, 0.8)`` therefore sets eps=0.8,
        # which dominates the denominator whenever the batch variance is small.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images from noise.

        Args:
            z: Noise tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, 3, img_size, img_size)`` with values
            in [-1, 1].
        """
        out = self.l1(z)
        out = out.view(out.shape[0], 512, self.init_size, self.init_size)
        return self.conv_blocks(out)

class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores RGB images as real or fake.

    Six 4x4, stride-2, padding-1 convolutions each halve the spatial size
    (x64 overall) while widening channels 3 -> 64 -> 128 -> 256 -> 512 ->
    1024 -> 2048. Every convolution is followed by LeakyReLU(0.2), and all
    but the first also get BatchNorm. The flattened feature map feeds a
    single linear unit and a sigmoid, so the output is a probability
    suitable for ``nn.BCELoss``.

    Args:
        img_size: Side length of the square input images; must be a positive
            multiple of 64. Defaults to the module-level ``IMG_SIZE``. The
            final linear layer is sized from it, so changing IMG_SIZE does
            not require editing this class.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 64.
    """

    def __init__(self, img_size=None):
        super().__init__()
        if img_size is None:
            img_size = IMG_SIZE
        if img_size <= 0 or img_size % 64 != 0:
            raise ValueError(f"img_size must be a positive multiple of 64, got {img_size}")
        final_size = img_size // 64  # spatial side after six halvings (512 -> 8)

        # NOTE(review): emitting raw logits and training with
        # nn.BCEWithLogitsLoss is numerically more stable than Sigmoid +
        # nn.BCELoss. Switching requires removing this Sigmoid and changing
        # the training loss together.
        # Shape comments below are for img_size=512. Layer order is fixed so
        # state_dict keys ("model.<idx>.*") stay stable.
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),     # 512x512 -> 256x256
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),   # 256x256 -> 128x128
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 128x128 -> 64x64
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),  # 64x64 -> 32x32
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1024, 4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 2048, 4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.BatchNorm2d(2048),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(2048 * final_size * final_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor of shape ``(batch, 3, img_size, img_size)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid probabilities
            that each image is real.
        """
        return self.model(img)


# Preprocessing: scale the shorter side to IMG_SIZE, center-crop to a square,
# convert to a tensor, and map pixel values from [0, 1] to [-1, 1] to match
# the generator's tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ]
)

# Pool all three Flowers102 splits into one shuffled training stream
# (presumably because the GAN needs no held-out data).
dataloader = DataLoader(
    ConcatDataset(
        [
            datasets.Flowers102(root='../../data/flowers', split=split_name, download=True, transform=transform)
            for split_name in ('train', 'val', 'test')
        ]
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Build both networks (generator first) and move them to the training device.
generator, discriminator = Generator().to(device), Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid probabilities.
adversarial_loss = nn.BCELoss()

# Identical Adam settings for both players: lr 0.0002, betas (0.5, 0.999).
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Training
# save_image() and torch.save() do not create missing parent directories, so
# create the output folders up front rather than failing after the first epoch.
os.makedirs("images", exist_ok=True)
os.makedirs("saved_model", exist_ok=True)

for epoch in range(EPOCHS):
    # Per-epoch running sums. The recorded loss is the epoch mean rather than
    # the noisy loss of whichever batch happened to come last.
    d_loss_sum = 0.0
    g_loss_sum = 0.0
    n_batches = 0

    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)

        # Adversarial ground truths, allocated directly on the target device
        # (no CPU allocation + copy per batch).
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # This zero_grad() also clears the gradients the previous generator
        # step's backward pass left on the discriminator's parameters.
        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = torch.randn(batch_size, LATENT_DIM, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss for real images
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # Loss for fake images. detach() stops this update from
        # backpropagating into the generator.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        # Total discriminator loss
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Non-saturating generator loss: fakes are scored against the "real"
        # label, so the generator is rewarded for fooling the discriminator.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # One host sync per loss, reused for logging and the epoch mean.
        d_val = d_loss.item()
        g_val = g_loss.item()
        d_loss_sum += d_val
        g_loss_sum += g_val
        n_batches += 1

        print(f"[Epoch {epoch}/{EPOCHS}] [Batch {i}/{len(dataloader)}] [D loss: {d_val}] [G loss: {g_val}]")

    if n_batches == 0:
        raise RuntimeError("dataloader produced no batches; nothing to train on")

    # Every second epoch: save a grid of up to 25 samples (rows of 5), a
    # generator checkpoint, and the epoch-mean losses for the final plot.
    if epoch % 2 == 0:
        save_image(gen_imgs.detach()[:25], f"images/{epoch}_DCGAN_flowers_bigbatch_512.png", nrow=5, normalize=True)
        torch.save(generator.state_dict(), f"saved_model/saved_model_dcgan_flowers_bigbatch_512_{epoch}.pth")
        d_loss_values.append(d_loss_sum / n_batches)
        g_loss_values.append(g_loss_sum / n_batches)

# Final sample grid and generator checkpoint after the last epoch. This is
# skipped when no epoch ran, since `epoch`/`gen_imgs` would be undefined.
if EPOCHS > 0:
    save_image(gen_imgs.detach()[:25], f"images/{epoch}_DCGAN_flowers_bigbatch_512.png", nrow=5, normalize=True)
    torch.save(generator.state_dict(), f"saved_model/saved_model_dcgan_flowers_bigbatch_512_{EPOCHS}.pth")

# Plot the recorded loss history against the epochs it was sampled at
# (0, 2, 4, ...), then write the figure to disk and display it.
plt.plot(np.arange(0, EPOCHS, 2), d_loss_values, label='Discriminator loss')
plt.plot(np.arange(0, EPOCHS, 2), g_loss_values, label='Generator loss')
plt.gca().set(xlabel='Epochs', ylabel='Loss', title='Loss values')
plt.legend()
plt.savefig('loss_values_flowers_bigbatch.png')
plt.show()



[Epoch 0/100] [Batch 0/1024] [D loss: 0.6753358840942383] [G loss: 27.268253326416016]
[Epoch 0/100] [Batch 1/1024] [D loss: 0.0015320637030526996] [G loss: 17.794052124023438]
[Epoch 0/100] [Batch 2/1024] [D loss: 6.60241930745542e-05] [G loss: 17.31537437438965]
[Epoch 0/100] [Batch 3/1024] [D loss: 3.076045686611906e-05] [G loss: 15.874110221862793]
[Epoch 0/100] [Batch 4/1024] [D loss: 7.497778278775513e-05] [G loss: 12.796379089355469]
[Epoch 0/100] [Batch 5/1024] [D loss: 0.0016901842318475246] [G loss: 14.667989730834961]
[Epoch 0/100] [Batch 6/1024] [D loss: 6.008359559928067e-05] [G loss: 13.27330207824707]
[Epoch 0/100] [Batch 7/1024] [D loss: 0.0034707896411418915] [G loss: 19.637313842773438]
[Epoch 0/100] [Batch 8/1024] [D loss: 1.450812578696059e-05] [G loss: 18.61610221862793]
[Epoch 0/100] [Batch 9/1024] [D loss: 0.0009220248321071267] [G loss: 16.98117446899414]
[Epoch 0/100] [Batch 10/1024] [D loss: 0.00036728131817653775] [G loss: 14.502508163452148]
[Epoch 0/100] [B

KeyboardInterrupt: 