In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import os

In [2]:
print(torch.cuda.is_available())
print(torch.version.cuda)
print(torch.__version__)

True
12.4
2.4.1+cu124


In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
numberofEpochs = 100
batchSize = 32
learningRateGenerator = 0.0002
learningRateDiscriminator = 0.0002
imageSize = 64

In [5]:
transform = transforms.Compose([
    transforms.Resize(64), 
    transforms.ToTensor(), 
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = dsets.CIFAR10(root='./data', download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size = batchSize, shuffle=True)

Files already downloaded and verified


In [6]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False), 
            nn.BatchNorm2d(512), 
            nn.ReLU(True), 

            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False), 
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),
            nn.BatchNorm2d(128), 
            nn.ReLU(True), 

            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False), 
            nn.BatchNorm2d(64), 
            nn.ReLU(True),

            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False), 
            nn.Tanh()
        )
    
    def forward(self, input):
        return self.main(input)

In [7]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias = False), 
            nn.LeakyReLU(0.2, inplace = True), 

            nn.Conv2d(64, 128, 4, 2, 1, bias = False), 
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True), 

            nn.Conv2d(128, 256, 4, 2, 1, bias = False), 
            nn.BatchNorm2d(256), 
            nn.LeakyReLU(0.2, inplace = True),

            nn.Conv2d(256, 512, 4, 2, 1, bias = False), 
            nn.BatchNorm2d(512), 
            nn.LeakyReLU(0.2, inplace = True), 

            nn.Conv2d(512, 1, 4, 1, 0, bias = False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1)    

In [8]:
discriminator = Discriminator().to(device)
generator = Generator().to(device)

In [9]:
criterion = nn.BCELoss()
optimizerDiscriminator = optim.Adam(Discriminator().parameters(), lr = learningRateDiscriminator, betas = (0.5, 0.999))
optimizerGenerator = optim.Adam(Generator().parameters(), lr = learningRateGenerator, betas = (0.5, 0.999))

In [None]:
def GetGradientNorms(model):
    totalNorm = 0
    for p in model.parameters():
        paramNorm = p.grad.data.norm(2)
        totalNorm += paramNorm.item() ** 2
    return totalNorm ** (1 / 2)

generatorLosses = []
discriminatorLosses = []
generatorGradientNorms = []
discriminatorGradientNorms = []

In [10]:
saveDir = os.path.join(os.getcwd(), "FakeImages")

if not os.path.exists(saveDir):
    os.makedirs(saveDir)

saveDir


'd:\\Finally, A study Folder\\Thapar Summer School on Machine Learning and Deep Learning\\GAN\\DC-GAN\\FakeImages'

In [None]:
fixedNoise = torch.randn(batchSize, 100, 1, 1, device = device)

for epochs in range(numberofEpochs):
    for i, data in enumerate(dataloader, 0):
        realImages, _ = data
        realImages = realImages.to(device)

        optimizerDiscriminator.zero_grad()
        
        output = discriminator(realImages).view(-1)
        realLabels = torch.ones(output.size()).to(device)
        realLoss = criterion(output, realLabels)
        realLoss.backward()

        noise = torch.randn(128, 100, 1, 1, device = device)
        fakeImages = generator(noise)
        output = discriminator(fakeImages.detach()).view(-1)
        fakeLabels = torch.zeros(output.size()).to(device)
        fakeLoss = criterion(output, fakeLabels)
        fakeLoss.backward()
        discriminatorGradientNorms.append(GetGradientNorms(discriminator))
        discriminatorGradientNorm = GetGradientNorms(discriminator)
        optimizerDiscriminator.step()

        optimizerGenerator.zero_grad()
        output = discriminator(fakeImages).view(-1)
        realLabels = torch.ones(output.size()).to(device)
        generatorLoss = criterion(output, realLabels)
        generatorLoss.backward()
        generatorGradientNorms.append(GetGradientNorms(generator))
        generatorGradientNorm = GetGradientNorms(generator)
        optimizerGenerator.step()

        generatorLosses.append(generatorLoss.item())
        discriminatorLosses.append(realLoss.item() + fakeLoss.item())
        generatorGradientNorms.append(generatorGradientNorm)
        discriminatorGradientNorms.append(discriminatorGradientNorm)

    with torch.no_grad():
        fake = generator(fixedNoise).detach().cpu()
    vutils.save_image(fake, f"{saveDir}/fake_{epochs+1}.png", normalize = True)
    print(f"Epochs: {epochs}, Generator Loss: {generatorLoss}, Discriminator Loss: {realLoss + fakeLoss}")


    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(generatorLosses, label='Generator')
    plt.plot(discriminatorLosses, label='Discriminator')
    plt.legend()
    plt.title(f"Epochs: {epochs}, G Loss: {generatorLosses[-1]:.4f}, D Loss: {discriminatorLosses[-1]:.4f}")

    plt.subplot(1, 2, 2)
    plt.plot(generatorGradientNorms, label='Generator')
    plt.plot(discriminatorGradientNorms, label='Discriminator')
    plt.legend()
    plt.title("Gradient Norms")
    plt.savefig(f"{saveDir}/epoch_{epochs}.png")
    plt.close()

Epochs: 0, Generator Loss: 0.6894029974937439, Discriminator Loss: 1.5193023681640625
Epochs: 1, Generator Loss: 0.6518171429634094, Discriminator Loss: 1.5039211511611938
Epochs: 2, Generator Loss: 0.652620792388916, Discriminator Loss: 1.4483940601348877
Epochs: 3, Generator Loss: 0.6577266454696655, Discriminator Loss: 1.4484572410583496
Epochs: 4, Generator Loss: 0.6859258413314819, Discriminator Loss: 1.4423766136169434
Epochs: 5, Generator Loss: 0.6812514662742615, Discriminator Loss: 1.5000290870666504
Epochs: 6, Generator Loss: 0.6983600854873657, Discriminator Loss: 1.528090000152588
Epochs: 7, Generator Loss: 0.6836071014404297, Discriminator Loss: 1.49661386013031
Epochs: 8, Generator Loss: 0.6891034245491028, Discriminator Loss: 1.5012075901031494
Epochs: 9, Generator Loss: 0.6575510501861572, Discriminator Loss: 1.4722204208374023
Epochs: 10, Generator Loss: 0.6856337785720825, Discriminator Loss: 1.4156017303466797
Epochs: 11, Generator Loss: 0.6526297330856323, Discrimin