In [None]:
## Moved to 512x512 as the final step. The batch size had to drop to 8 here, so gradient accumulation was used to emulate a bigger batch size. This, however, didn't improve the results: gradient accumulation does not behave exactly like increasing the batch size directly (for example, BatchNorm statistics are still computed per batch of 8).

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import time
from torchviz import make_dot
from torchinfo import summary

# Metric histories; the training loop appends to these.
d_loss_values, g_loss_values = [], []
epoch_times = []  # wall-clock seconds spent on each epoch

# Hyperparameters
IMG_SIZE = 512     # side length (pixels) of the square images
LATENT_DIM = 200   # length of the noise vector fed to the generator
BATCH_SIZE = 8     # samples per DataLoader batch
EPOCHS = 100       # full passes over the dataset
accum_steps = 12   # batches per gradient-accumulation window

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 3-channel image.

    A linear layer projects the latent vector onto a 512-channel feature map
    of side ``IMG_SIZE // 32``. Five (upsample x2 -> 3x3 conv -> batch norm ->
    ReLU) stages then grow it by a factor of 32, and a final 3x3 conv followed
    by ``tanh`` produces 3 channels with values in [-1, 1].

    The layer order inside ``conv_blocks`` is kept exactly as before so that
    state_dict keys of previously saved checkpoints still match.
    """

    def __init__(self):
        super().__init__()
        # Spatial side of the first feature map; five x2 upsamples follow.
        self.init_size = IMG_SIZE // 32
        self.l1 = nn.Sequential(nn.Linear(LATENT_DIM, 512 * self.init_size ** 2))
        # NOTE: the second positional parameter of nn.BatchNorm2d is ``eps``,
        # not ``momentum``. Passing 0.8 positionally set eps=0.8, which damps
        # the normalisation. It is passed by keyword as momentum here, which
        # is presumably what was intended -- TODO confirm the desired value.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 3, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: tensor of shape (batch, LATENT_DIM).

        Returns:
            Tensor of shape (batch, 3, 32 * init_size, 32 * init_size) with
            values in [-1, 1].
        """
        out = self.l1(z)
        # Reshape the flat projection into a (512, init_size, init_size) map.
        out = out.view(out.shape[0], 512, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(nn.Module):
    """Convolutional real/fake classifier.

    Six 4x4 stride-2 convolutions halve the spatial size each time while the
    channel count grows 3 -> 64 -> 128 -> 256 -> 512 -> 1024 -> 2048. The
    result is flattened and fed to a single linear unit followed by a
    sigmoid. The linear layer expects an 8x8 final feature map, i.e. a
    512x512 input image.
    """

    def __init__(self):
        super().__init__()
        widths = (3, 64, 128, 256, 512, 1024, 2048)
        stack = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            # The first stage has no batch norm; all later stages do.
            if stage > 0:
                stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        stack.extend([
            nn.Flatten(),
            nn.Linear(2048 * 8 * 8, 1),
            nn.Sigmoid(),
        ])
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return a (batch, 1) tensor of sigmoid scores for ``img``."""
        return self.model(img)

# Preprocessing: resize, centre-crop to IMG_SIZE x IMG_SIZE, convert to a
# tensor, then normalise each of the 3 channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# The train, val and test splits of Flowers102 are concatenated into one
# dataset; the class labels are ignored by the training loop.
flowers_splits = [
    datasets.Flowers102(root='../../data/flowers', split=split_name, download=True, transform=transform)
    for split_name in ('train', 'val', 'test')
]
dataloader = DataLoader(
    ConcatDataset(flowers_splits),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Build both networks (generator first) and move them to the chosen device.
generator = Generator()
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

def weights_init_normal(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left
    untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.constant_(m.bias, 0)

# Re-initialise every submodule of both networks (generator first).
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Binary cross-entropy, applied to the discriminator's sigmoid outputs.
adversarial_loss = nn.BCELoss()

# One Adam optimizer per network; the generator uses the larger step size.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=3e-4,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)

# Training loop with gradient accumulation.
#
# Gradients from `accum_steps` consecutive batches are summed before each
# optimizer step. Gradient buffers are cleared only right after a step, so
# the accumulated gradients survive between batches.
import os

# save_image / torch.save do not create missing directories.
os.makedirs("images", exist_ok=True)
os.makedirs("saved_model", exist_ok=True)

num_batches = len(dataloader)
on_cuda = device.type == "cuda"

for epoch in range(EPOCHS):
    epoch_start_time = time.time()

    # Start every epoch with empty gradient buffers.
    optimizer_D.zero_grad(set_to_none=True)
    optimizer_G.zero_grad(set_to_none=True)

    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)

        # Target labels: 1 for real images, 0 for generated ones.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Length of the accumulation window this batch belongs to. The last
        # window of an epoch may be shorter than accum_steps, so the losses
        # are averaged over the real window length.
        window_start = (i // accum_steps) * accum_steps
        window_len = min(accum_steps, num_batches - window_start)
        is_update_step = (i + 1) % accum_steps == 0 or (i + 1) == num_batches

        # ---------------------
        #  Train Discriminator
        # ---------------------
        z = torch.randn(batch_size, LATENT_DIM, device=device)
        gen_imgs = generator(z)

        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach(): the discriminator loss must not back-propagate into G.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        # Scale only the tensor that is back-propagated; d_loss itself stays
        # unscaled so the logged values are real losses.
        (d_loss / window_len).backward()

        if is_update_step:
            optimizer_D.step()
            optimizer_D.zero_grad(set_to_none=True)

        # -----------------
        #  Train Generator
        # -----------------
        # Freeze the discriminator's parameters for this pass; otherwise the
        # generator's backward would add to the gradients being accumulated
        # for the next discriminator step.
        discriminator.requires_grad_(False)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        (g_loss / window_len).backward()
        discriminator.requires_grad_(True)

        if is_update_step:
            optimizer_G.step()
            optimizer_G.zero_grad(set_to_none=True)

        # Current GPU memory usage in GiB (zero when running on the CPU).
        if on_cuda:
            allocated_memory = torch.cuda.memory_allocated(device) / 1024 ** 3
            reserved_memory = torch.cuda.memory_reserved(device) / 1024 ** 3
        else:
            allocated_memory = reserved_memory = 0.0

        print(f"[Epoch {epoch}/{EPOCHS}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
        print(f"[GPU Memory Allocated: {allocated_memory:.2f} GB] [GPU Memory Reserved: {reserved_memory:.2f} GB]")

    # Save sample images and model checkpoints every second epoch, and record
    # the losses of the epoch's last batch for the loss plot.
    if epoch % 2 == 0:
        save_image(gen_imgs.detach()[:25], f"images/{epoch}_DCGAN_Flowers_bigbatch_R1_512.png", nrow=5, normalize=True)
        torch.save(generator.state_dict(), f"saved_model/saved_model_dcgan_Flowers_bigbatch_R1_512_{epoch}.pth")
        d_loss_values.append(d_loss.item())
        g_loss_values.append(g_loss.item())

    epoch_end_time = time.time()
    epoch_times.append(epoch_end_time - epoch_start_time)

# Final generator checkpoint, tagged with the total epoch count.
torch.save(generator.state_dict(), f"saved_model/saved_model_dcgan_Flowers_bigbatch_R1_512_{EPOCHS}.pth")

average_time_per_epoch = sum(epoch_times) / len(epoch_times)
print(f"Average time per epoch: {average_time_per_epoch:.2f} seconds")

# Loss curves: losses were recorded on every second epoch, hence the step of 2.
logged_epochs = np.arange(0, EPOCHS, 2)
plt.plot(logged_epochs, d_loss_values, label='Discriminator loss')
plt.plot(logged_epochs, g_loss_values, label='Generator loss')
plt.title('Loss values')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('loss_values_flowers_512.png')
plt.show()

# Wall-clock duration of each epoch.
plt.plot(np.arange(EPOCHS), epoch_times)
plt.xlabel("Epoch")
plt.ylabel("Time (s)")
plt.title("Time taken per epoch")
plt.savefig("time_per_epoch_flowers_512.png")
plt.show()


In [None]:
## Tried to emulate an even bigger batch size to see whether the results would improve. The pictures are better, but not by much; the model's architecture might be more of a bottleneck here than the batch size.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import time
from torchviz import make_dot
from torchinfo import summary

# Metric histories; the training loop appends to these.
d_loss_values, g_loss_values = [], []
epoch_times = []  # wall-clock seconds spent on each epoch

# Hyperparameters
IMG_SIZE = 512     # side length (pixels) of the square images
LATENT_DIM = 200   # length of the noise vector fed to the generator
BATCH_SIZE = 8     # samples per DataLoader batch
EPOCHS = 100       # full passes over the dataset
accum_steps = 36   # batches per gradient-accumulation window

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 3-channel image.

    A linear layer projects the latent vector onto a 512-channel feature map
    of side ``IMG_SIZE // 32``. Five (upsample x2 -> 3x3 conv -> batch norm ->
    ReLU) stages then grow it by a factor of 32, and a final 3x3 conv followed
    by ``tanh`` produces 3 channels with values in [-1, 1].

    The layer order inside ``conv_blocks`` is kept exactly as before so that
    state_dict keys of previously saved checkpoints still match.
    """

    def __init__(self):
        super().__init__()
        # Spatial side of the first feature map; five x2 upsamples follow.
        self.init_size = IMG_SIZE // 32
        self.l1 = nn.Sequential(nn.Linear(LATENT_DIM, 512 * self.init_size ** 2))
        # NOTE: the second positional parameter of nn.BatchNorm2d is ``eps``,
        # not ``momentum``. Passing 0.8 positionally set eps=0.8, which damps
        # the normalisation. It is passed by keyword as momentum here, which
        # is presumably what was intended -- TODO confirm the desired value.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 3, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: tensor of shape (batch, LATENT_DIM).

        Returns:
            Tensor of shape (batch, 3, 32 * init_size, 32 * init_size) with
            values in [-1, 1].
        """
        out = self.l1(z)
        # Reshape the flat projection into a (512, init_size, init_size) map.
        out = out.view(out.shape[0], 512, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(nn.Module):
    """Convolutional real/fake classifier.

    Six 4x4 stride-2 convolutions halve the spatial size each time while the
    channel count grows 3 -> 64 -> 128 -> 256 -> 512 -> 1024 -> 2048. The
    result is flattened and fed to a single linear unit followed by a
    sigmoid. The linear layer expects an 8x8 final feature map, i.e. a
    512x512 input image.
    """

    def __init__(self):
        super().__init__()
        widths = (3, 64, 128, 256, 512, 1024, 2048)
        stack = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            # The first stage has no batch norm; all later stages do.
            if stage > 0:
                stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        stack.extend([
            nn.Flatten(),
            nn.Linear(2048 * 8 * 8, 1),
            nn.Sigmoid(),
        ])
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return a (batch, 1) tensor of sigmoid scores for ``img``."""
        return self.model(img)

# Preprocessing: resize, centre-crop to IMG_SIZE x IMG_SIZE, convert to a
# tensor, then normalise each of the 3 channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# The train, val and test splits of Flowers102 are concatenated into one
# dataset; the class labels are ignored by the training loop.
flowers_splits = [
    datasets.Flowers102(root='../../data/flowers', split=split_name, download=True, transform=transform)
    for split_name in ('train', 'val', 'test')
]
dataloader = DataLoader(
    ConcatDataset(flowers_splits),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Build both networks (generator first) and move them to the chosen device.
generator = Generator()
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

def weights_init_normal(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left
    untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.constant_(m.bias, 0)

# Re-initialise every submodule of both networks (generator first).
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Binary cross-entropy, applied to the discriminator's sigmoid outputs.
adversarial_loss = nn.BCELoss()

# One Adam optimizer per network; the generator uses the larger step size.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=3e-4,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)

# Training loop with gradient accumulation.
#
# Gradients from `accum_steps` consecutive batches are summed before each
# optimizer step. Gradient buffers are cleared only right after a step, so
# the accumulated gradients survive between batches.
import os

# save_image / torch.save do not create missing directories.
os.makedirs("images", exist_ok=True)
os.makedirs("saved_model", exist_ok=True)

num_batches = len(dataloader)
on_cuda = device.type == "cuda"

for epoch in range(EPOCHS):
    epoch_start_time = time.time()

    # Start every epoch with empty gradient buffers.
    optimizer_D.zero_grad(set_to_none=True)
    optimizer_G.zero_grad(set_to_none=True)

    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)

        # Target labels: 1 for real images, 0 for generated ones.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Length of the accumulation window this batch belongs to. The last
        # window of an epoch may be shorter than accum_steps, so the losses
        # are averaged over the real window length.
        window_start = (i // accum_steps) * accum_steps
        window_len = min(accum_steps, num_batches - window_start)
        is_update_step = (i + 1) % accum_steps == 0 or (i + 1) == num_batches

        # ---------------------
        #  Train Discriminator
        # ---------------------
        z = torch.randn(batch_size, LATENT_DIM, device=device)
        gen_imgs = generator(z)

        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach(): the discriminator loss must not back-propagate into G.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        # Scale only the tensor that is back-propagated; d_loss itself stays
        # unscaled so the logged values are real losses.
        (d_loss / window_len).backward()

        if is_update_step:
            optimizer_D.step()
            optimizer_D.zero_grad(set_to_none=True)

        # -----------------
        #  Train Generator
        # -----------------
        # Freeze the discriminator's parameters for this pass; otherwise the
        # generator's backward would add to the gradients being accumulated
        # for the next discriminator step.
        discriminator.requires_grad_(False)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        (g_loss / window_len).backward()
        discriminator.requires_grad_(True)

        if is_update_step:
            optimizer_G.step()
            optimizer_G.zero_grad(set_to_none=True)

        # Current GPU memory usage in GiB (zero when running on the CPU).
        if on_cuda:
            allocated_memory = torch.cuda.memory_allocated(device) / 1024 ** 3
            reserved_memory = torch.cuda.memory_reserved(device) / 1024 ** 3
        else:
            allocated_memory = reserved_memory = 0.0

        print(f"[Epoch {epoch}/{EPOCHS}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
        print(f"[GPU Memory Allocated: {allocated_memory:.2f} GB] [GPU Memory Reserved: {reserved_memory:.2f} GB]")

    # Save sample images and model checkpoints every second epoch, and record
    # the losses of the epoch's last batch for the loss plot.
    if epoch % 2 == 0:
        save_image(gen_imgs.detach()[:25], f"images/{epoch}_DCGAN_Flowers_gigabatch_R1_512.png", nrow=5, normalize=True)
        torch.save(generator.state_dict(), f"saved_model/saved_model_dcgan_Flowers_gigabatch_R1_512_{epoch}.pth")
        d_loss_values.append(d_loss.item())
        g_loss_values.append(g_loss.item())

    epoch_end_time = time.time()
    epoch_times.append(epoch_end_time - epoch_start_time)

# Final generator checkpoint, tagged with the total epoch count.
torch.save(generator.state_dict(), f"saved_model/saved_model_dcgan_Flowers_gigabatch_R1_512_{EPOCHS}.pth")

average_time_per_epoch = sum(epoch_times) / len(epoch_times)
print(f"Average time per epoch: {average_time_per_epoch:.2f} seconds")

# Loss curves: losses were recorded on every second epoch, hence the step of 2.
logged_epochs = np.arange(0, EPOCHS, 2)
plt.plot(logged_epochs, d_loss_values, label='Discriminator loss')
plt.plot(logged_epochs, g_loss_values, label='Generator loss')
plt.title('Loss values')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('loss_values_flowers_512_gigabatch.png')
plt.show()

# Wall-clock duration of each epoch.
plt.plot(np.arange(EPOCHS), epoch_times)
plt.xlabel("Epoch")
plt.ylabel("Time (s)")
plt.title("Time taken per epoch")
plt.savefig("time_per_epoch_flowers_512_gigabatch.png")
plt.show()
