In [1]:
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vmy_utils

import my_utils


CUDA = True     # Change to False for CPU training
DATA_PATH = '~/Data/mnist'
#DATA_PATH = '/media/john/FastData/CelebA'
# DATA_PATH = '/media/john/FastData/lsun'
OUT_PATH = 'output'
LOG_FILE = os.path.join(OUT_PATH, 'log.txt')
BATCH_SIZE = 512        # Adjust this value according to your GPU memory
IMAGE_CHANNEL = 1
# IMAGE_CHANNEL = 3
Z_DIM = 100
G_HIDDEN = 64
X_DIM = 64
D_HIDDEN = 64
EPOCH_NUM = 12
REAL_LABEL = 1
FAKE_LABEL = 0
lr = 2e-4
seed = 1            # Change to None to get different results at each run

my_utils.clear_folder(OUT_PATH)
print("Logging to {}\n".format(LOG_FILE))
sys.stdout = my_utils.StdOut(LOG_FILE)
CUDA = CUDA and torch.cuda.is_available()
print("PyTorch version: {}".format(torch.__version__))
if CUDA:
    print("CUDA version: {}\n".format(torch.version.cuda))

if seed is None:
    seed = np.random.randint(1, 10000)
print("Random Seed: ", seed)
np.random.seed(seed)
torch.manual_seed(seed)
if CUDA:
    torch.cuda.manual_seed(seed)
cudnn.benchmark = True      # May train faster but cost more memory

dataset = dset.MNIST(root=DATA_PATH, download=True,
                      transform=transforms.Compose([
                      transforms.Resize(X_DIM),
                      transforms.ToTensor(),
                      transforms.Normalize((0.5,), (0.5,))
                      ]))
#dataset = dset.ImageFolder(root=DATA_PATH,
#                           transform=transforms.Compose([
#                           transforms.Resize(X_DIM),
#                           transforms.CenterCrop(X_DIM),
#                           transforms.ToTensor(),
#                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#                           ]))
# dataset = dset.LSUN(root=DATA_PATH, classes=['bedroom_train'],
#                     transform=transforms.Compose([
#                     transforms.Resize(X_DIM),
#                     transforms.CenterCrop(X_DIM),
#                     transforms.ToTensor(),
#                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#                     ]))

assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=2)

device = torch.device("cuda:0" if CUDA else "cpu")


def weights_init(m):
    """custom weights initialization
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 1st layer
            nn.ConvTranspose2d(Z_DIM, G_HIDDEN * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 8),
            nn.ReLU(True),
            # 2nd layer
            nn.ConvTranspose2d(G_HIDDEN * 8, G_HIDDEN * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 4),
            nn.ReLU(True),
            # 3rd layer
            nn.ConvTranspose2d(G_HIDDEN * 4, G_HIDDEN * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 2),
            nn.ReLU(True),
            # 4th layer
            nn.ConvTranspose2d(G_HIDDEN * 2, G_HIDDEN, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN),
            nn.ReLU(True),
            # output layer
            nn.ConvTranspose2d(G_HIDDEN, IMAGE_CHANNEL, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 1st layer
            nn.Conv2d(IMAGE_CHANNEL, D_HIDDEN, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 2nd layer
            nn.Conv2d(D_HIDDEN, D_HIDDEN * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 3rd layer
            nn.Conv2d(D_HIDDEN * 2, D_HIDDEN * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4th layer
            nn.Conv2d(D_HIDDEN * 4, D_HIDDEN * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # output layer
            nn.Conv2d(D_HIDDEN * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)


netG = Generator().to(device)
netG.apply(weights_init)
print(netG)

netD = Discriminator().to(device)
netD.apply(weights_init)
print(netD)

criterion = nn.BCELoss()

viz_noise = torch.randn(BATCH_SIZE, Z_DIM, 1, 1, device=device)

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

for epoch in range(EPOCH_NUM):
    for i, data in enumerate(dataloader):
        x_real = data[0].to(device)

        real_label = torch.full((x_real.size(0),), REAL_LABEL, device=device,dtype=torch.float32)
        fake_label = torch.full((x_real.size(0),), FAKE_LABEL, device=device,dtype=torch.float32)

        # Update D with real data
        netD.zero_grad()
        output_real = netD(x_real)
        loss_D_real = criterion(output_real, real_label)
        loss_D_real.backward()
        D_x = output_real.mean().item()



        # Update D with fake data
        z_noise = torch.randn(x_real.size(0), Z_DIM, 1, 1, device=device)
        x_fake = netG(z_noise)
        output_fake = netD(x_fake.detach())  # Detach to prevent gradients from flowing into G
        loss_D_fake = criterion(output_fake, fake_label)
        loss_D_fake.backward()
        D_G_z1 = output_fake.mean().item()
        # optimizerD.step()

        # Update G with fake data
        netG.zero_grad()
        output_fake = netD(x_fake)
        loss_G = criterion(output_fake, real_label)
        loss_G.backward()
        D_G_z2 = output_fake.mean().item()

        # Update the discriminator and generator
        optimizerD.step()
        optimizerG.step()

        if i % 100 == 0:
            print('Epoch {} [{}/{}] loss_D_real: {:.4f} loss_D_fake: {:.4f} loss_G: {:.4f}'.format(
                epoch, i, len(dataloader),
                loss_D_real.mean().item(),
                loss_D_fake.mean().item(),
                loss_G.mean().item()
            ))
            vmy_utils.save_image(x_real, os.path.join(OUT_PATH, 'real_samples{}.png'.format(epoch)), normalize=True)
            with torch.no_grad():
                viz_sample = netG(viz_noise)
                vmy_utils.save_image(viz_sample, os.path.join(OUT_PATH, 'fake_samples_{}.png'.format(epoch)), normalize=True)
    torch.save(netG.state_dict(), os.path.join(OUT_PATH, 'netG_{}.pth'.format(epoch)))
    torch.save(netD.state_dict(), os.path.join(OUT_PATH, 'netD_{}.pth'.format(epoch)))

Logging to output/log.txt

PyTorch version: 2.2.1+cu121
CUDA version: 12.1

Random Seed:  1
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /root/Data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 4403733.67it/s]


Extracting /root/Data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /root/Data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /root/Data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 66238.49it/s]

Extracting /root/Data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /root/Data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /root/Data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1247870.89it/s]

Extracting /root/Data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/Data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /root/Data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4426238.10it/s]

Extracting /root/Data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/Data/mnist/MNIST/raw






Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)


  self.pid = os.fork()


Epoch 0 [0/118] loss_D_real: 1.1193 loss_D_fake: 1.0719 loss_G: 0.6338
Epoch 0 [100/118] loss_D_real: 0.0419 loss_D_fake: 2.5201 loss_G: 0.0941
Epoch 1 [0/118] loss_D_real: 0.0375 loss_D_fake: 2.4327 loss_G: 0.1002
Epoch 1 [100/118] loss_D_real: 0.0759 loss_D_fake: 1.2287 loss_G: 0.3565
Epoch 2 [0/118] loss_D_real: 0.0961 loss_D_fake: 0.8784 loss_G: 0.5415
Epoch 2 [100/118] loss_D_real: 0.2251 loss_D_fake: 0.2097 loss_G: 1.7283
Epoch 3 [0/118] loss_D_real: 0.1145 loss_D_fake: 0.9495 loss_G: 0.5048
Epoch 3 [100/118] loss_D_real: 0.2387 loss_D_fake: 0.9187 loss_G: 0.5259
Epoch 4 [0/118] loss_D_real: 0.1300 loss_D_fake: 1.0635 loss_G: 0.4441
Epoch 4 [100/118] loss_D_real: 0.1349 loss_D_fake: 1.4154 loss_G: 0.2956
Epoch 5 [0/118] loss_D_real: 0.1523 loss_D_fake: 0.8614 loss_G: 0.5650
Epoch 5 [100/118] loss_D_real: 0.2526 loss_D_fake: 0.6003 loss_G: 0.8260
Epoch 6 [0/118] loss_D_real: 0.2247 loss_D_fake: 0.6952 loss_G: 0.7192
Epoch 6 [100/118] loss_D_real: 0.1086 loss_D_fake: 1.6503 loss_G:

In [2]:
data[0].shape

torch.Size([96, 1, 64, 64])