In [1]:
#Import the necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
#Make the discriminator
#img_data is the flattened image size: 28*28 = 784 for MNIST
#LeakyReLU is used as the hidden activation; SELU could be used instead

class Discriminator(nn.Module):
    """Two-layer fully connected network that scores flattened images.

    Args:
        img_data: number of input features per sample (the flattened image size).

    The forward pass maps a ``(batch, img_data)`` tensor to a ``(batch, 1)``
    tensor of sigmoid outputs.
    """

    def __init__(self, img_data):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(in_features=img_data, out_features=hidden_units),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(in_features=hidden_units, out_features=1),
            nn.Sigmoid(),  # squashes the single logit into (0, 1)
        ]
        # The attribute name determines the state_dict key prefix ("disc.<idx>.*").
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score for every row of ``x``."""
        scores = self.disc(x)
        return scores


In [3]:
#Make the generator
#The input is random noise of size z_dim; the output has img_dim = 28*28 = 784 values
#LeakyReLU is used as the hidden activation; SELU could be used instead

#The last layer is Tanh instead of sigmoid because Tanh is symmetric about zero and has
#better gradients while optimizing than sigmoid. The [-1, 1] outputs are rescaled to
#[0, 1] pixel values later, when the generated images are plotted.

class Generator(nn.Module):
    """Two-layer fully connected network that maps noise vectors to flat images.

    Args:
        z_dim: number of features in each input noise vector.
        img_dim: number of output features per sample (the flattened image size).

    The forward pass maps a ``(batch, z_dim)`` tensor to a ``(batch, img_dim)``
    tensor whose values lie in ``[-1, 1]`` because of the final Tanh.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(in_features=z_dim, out_features=hidden_units),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(in_features=hidden_units, out_features=img_dim),
            nn.Tanh(),
        ]
        # The attribute name determines the state_dict key prefix ("gen.<idx>.*").
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return the generated flat image for every noise row of ``x``."""
        generated = self.gen(x)
        return generated

In [4]:
#hyperparameters
seed = 42  # fixed seed so weight init, fixed_noise, shuffling and training noise are repeatable
torch.manual_seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # flattened 28x28 single-channel image
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
# The same 25 noise vectors are reused at every epoch so the saved image grids are comparable.
fixed_noise = torch.randn(25, z_dim, device=device)

# The pipeline is built through the fully qualified `torchvision.transforms` module
# rather than the `transforms` alias: the result is stored under the name `transforms`
# (the dataset cell reads it under that name), which rebinds the alias. Going through
# `torchvision.transforms` keeps this cell re-runnable after the alias has been rebound.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        # Normalize with mean 0.5 / std 0.5 maps ToTensor's [0, 1] output to [-1, 1],
        # the same range as the generator's Tanh output.
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [5]:
#MNIST Dataset
# Training split, downloaded into dataset/ on first use and preprocessed with `transforms`.
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms,
    download=True,
)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the shared learning rate.
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 11503214.41it/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 269113.88it/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3192182.17it/s]


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2558835.29it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw






In [6]:
#function to generate and save images at every epoch

def generate_and_save_images(gen, epoch, noise=None):
    """Render a 5x5 grid of generated digits and save it as a PNG.

    Args:
        gen: generator module; called on the noise batch, and its output is
            reshaped to ``(-1, 1, 28, 28)``.
        epoch: zero-based epoch index; the file is named
            ``generated_images_epoch_{epoch+1}.png``.
        noise: optional noise batch to feed to ``gen``. Defaults to the
            notebook-level ``fixed_noise``. At most the first 25 samples are
            drawn.
    """
    if noise is None:
        noise = fixed_noise

    with torch.no_grad():
        fake = gen(noise).reshape(-1, 1, 28, 28)
        fake = fake * 0.5 + 0.5  # rescale from [-1, 1] to [0, 1]
        images = fake.cpu().numpy()  # single device-to-host copy for the whole grid

    fig, axes = plt.subplots(5, 5, figsize=(5, 5))
    for ax in axes.flat:
        ax.axis('off')
    for ax, image in zip(axes.flat, images):
        # Pin the colour range: without vmin/vmax imshow stretches every tile to
        # its own min/max, which hides the true intensities and makes tiles
        # (and epochs) incomparable.
        ax.imshow(image[0], cmap='gray', vmin=0.0, vmax=1.0)
    fig.tight_layout()
    fig.savefig(f'generated_images_epoch_{epoch+1}.png')
    plt.close(fig)  # release the figure; this runs once per epoch

In [7]:
#Training loop
# Every batch performs one discriminator update followed by one generator update.

for epoch in tqdm(range(num_epochs), desc="Epochs"):
    running_lossD = 0.0
    running_lossG = 0.0
    total_batches = len(loader)

    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, image_dim).to(device)
        # Local name on purpose: assigning to `batch_size` here would overwrite
        # the hyperparameter defined in the configuration cell.
        cur_batch_size = real.shape[0]

        # --- Discriminator step: real images -> 1, generated images -> 0 ---
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() cuts the graph at the generator output, so this backward pass
        # does not compute generator gradients (they would be zeroed below anyway)
        # and the generator graph stays intact without retain_graph=True.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # --- Generator step: make the updated discriminator output 1 on fakes ---
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        running_lossD += lossD.item()
        running_lossG += lossG.item()

    avg_lossD = running_lossD / total_batches
    avg_lossG = running_lossG / total_batches
    tqdm.write(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {avg_lossD:.4f}, Loss G: {avg_lossG:.4f}")

    generate_and_save_images(gen, epoch)

print("training finished!")


Epochs:   0%|          | 0/50 [00:20<?, ?it/s]

Epoch [1/50] Loss D: 0.4090, Loss G: 1.3100


Epochs:   2%|▏         | 1/50 [00:40<17:05, 20.93s/it]

Epoch [2/50] Loss D: 0.5662, Loss G: 1.0510


Epochs:   4%|▍         | 2/50 [01:00<16:19, 20.41s/it]

Epoch [3/50] Loss D: 0.5538, Loss G: 1.1854


Epochs:   6%|▌         | 3/50 [01:21<16:04, 20.52s/it]

Epoch [4/50] Loss D: 0.5555, Loss G: 1.1937


Epochs:   8%|▊         | 4/50 [01:41<15:47, 20.59s/it]

Epoch [5/50] Loss D: 0.6567, Loss G: 1.0008


Epochs:  10%|█         | 5/50 [02:02<15:26, 20.59s/it]

Epoch [6/50] Loss D: 0.6732, Loss G: 0.9749


Epochs:  12%|█▏        | 6/50 [02:22<15:01, 20.49s/it]

Epoch [7/50] Loss D: 0.6661, Loss G: 0.9550


Epochs:  14%|█▍        | 7/50 [02:42<14:35, 20.35s/it]

Epoch [8/50] Loss D: 0.7091, Loss G: 0.9071


Epochs:  16%|█▌        | 8/50 [03:03<14:16, 20.40s/it]

Epoch [9/50] Loss D: 0.6522, Loss G: 1.0448


Epochs:  18%|█▊        | 9/50 [03:23<13:58, 20.44s/it]

Epoch [10/50] Loss D: 0.6508, Loss G: 1.0525


Epochs:  20%|██        | 10/50 [03:44<13:34, 20.35s/it]

Epoch [11/50] Loss D: 0.6242, Loss G: 1.0672


Epochs:  22%|██▏       | 11/50 [04:04<13:18, 20.48s/it]

Epoch [12/50] Loss D: 0.6043, Loss G: 1.0928


Epochs:  24%|██▍       | 12/50 [04:24<12:56, 20.44s/it]

Epoch [13/50] Loss D: 0.5689, Loss G: 1.2507


Epochs:  26%|██▌       | 13/50 [04:45<12:32, 20.35s/it]

Epoch [14/50] Loss D: 0.5351, Loss G: 1.2470


Epochs:  28%|██▊       | 14/50 [05:05<12:12, 20.36s/it]

Epoch [15/50] Loss D: 0.5944, Loss G: 1.1978


Epochs:  30%|███       | 15/50 [05:27<11:56, 20.48s/it]

Epoch [16/50] Loss D: 0.5344, Loss G: 1.3633


Epochs:  32%|███▏      | 16/50 [05:49<11:49, 20.86s/it]

Epoch [17/50] Loss D: 0.5154, Loss G: 1.4401


Epochs:  34%|███▍      | 17/50 [06:09<11:36, 21.10s/it]

Epoch [18/50] Loss D: 0.5417, Loss G: 1.4410


Epochs:  36%|███▌      | 18/50 [06:30<11:08, 20.90s/it]

Epoch [19/50] Loss D: 0.5566, Loss G: 1.3988


Epochs:  38%|███▊      | 19/50 [06:50<10:42, 20.73s/it]

Epoch [20/50] Loss D: 0.5816, Loss G: 1.3223


Epochs:  40%|████      | 20/50 [07:10<10:15, 20.51s/it]

Epoch [21/50] Loss D: 0.6200, Loss G: 1.2261


Epochs:  42%|████▏     | 21/50 [07:30<09:53, 20.47s/it]

Epoch [22/50] Loss D: 0.5980, Loss G: 1.2952


Epochs:  44%|████▍     | 22/50 [07:50<09:27, 20.25s/it]

Epoch [23/50] Loss D: 0.5703, Loss G: 1.3292


Epochs:  46%|████▌     | 23/50 [08:10<09:03, 20.11s/it]

Epoch [24/50] Loss D: 0.5003, Loss G: 1.6231


Epochs:  48%|████▊     | 24/50 [08:30<08:40, 20.02s/it]

Epoch [25/50] Loss D: 0.5759, Loss G: 1.4241


Epochs:  50%|█████     | 25/50 [08:50<08:21, 20.06s/it]

Epoch [26/50] Loss D: 0.5738, Loss G: 1.4406


Epochs:  52%|█████▏    | 26/50 [09:10<08:01, 20.05s/it]

Epoch [27/50] Loss D: 0.5646, Loss G: 1.4460


Epochs:  54%|█████▍    | 27/50 [09:30<07:45, 20.22s/it]

Epoch [28/50] Loss D: 0.5583, Loss G: 1.4497


Epochs:  56%|█████▌    | 28/50 [09:50<07:23, 20.16s/it]

Epoch [29/50] Loss D: 0.5998, Loss G: 1.3557


Epochs:  58%|█████▊    | 29/50 [10:10<07:02, 20.10s/it]

Epoch [30/50] Loss D: 0.5685, Loss G: 1.3874


Epochs:  60%|██████    | 30/50 [10:31<06:42, 20.13s/it]

Epoch [31/50] Loss D: 0.5774, Loss G: 1.3844


Epochs:  62%|██████▏   | 31/50 [10:51<06:22, 20.15s/it]

Epoch [32/50] Loss D: 0.5526, Loss G: 1.4961


Epochs:  64%|██████▍   | 32/50 [11:11<06:01, 20.11s/it]

Epoch [33/50] Loss D: 0.5754, Loss G: 1.4160


Epochs:  66%|██████▌   | 33/50 [11:31<05:41, 20.10s/it]

Epoch [34/50] Loss D: 0.5869, Loss G: 1.3802


Epochs:  68%|██████▊   | 34/50 [11:51<05:21, 20.07s/it]

Epoch [35/50] Loss D: 0.6074, Loss G: 1.3506


Epochs:  70%|███████   | 35/50 [12:11<05:00, 20.03s/it]

Epoch [36/50] Loss D: 0.6144, Loss G: 1.2871


Epochs:  72%|███████▏  | 36/50 [12:32<04:43, 20.28s/it]

Epoch [37/50] Loss D: 0.5987, Loss G: 1.3016


Epochs:  74%|███████▍  | 37/50 [12:51<04:22, 20.20s/it]

Epoch [38/50] Loss D: 0.5784, Loss G: 1.3843


Epochs:  76%|███████▌  | 38/50 [13:11<04:01, 20.09s/it]

Epoch [39/50] Loss D: 0.6293, Loss G: 1.2431


Epochs:  78%|███████▊  | 39/50 [13:31<03:40, 20.06s/it]

Epoch [40/50] Loss D: 0.6176, Loss G: 1.2265


Epochs:  80%|████████  | 40/50 [13:51<03:19, 20.00s/it]

Epoch [41/50] Loss D: 0.6230, Loss G: 1.2006


Epochs:  82%|████████▏ | 41/50 [14:11<02:58, 19.88s/it]

Epoch [42/50] Loss D: 0.6077, Loss G: 1.2130


Epochs:  84%|████████▍ | 42/50 [14:30<02:38, 19.84s/it]

Epoch [43/50] Loss D: 0.6415, Loss G: 1.1065


Epochs:  86%|████████▌ | 43/50 [14:50<02:18, 19.79s/it]

Epoch [44/50] Loss D: 0.6217, Loss G: 1.1349


Epochs:  88%|████████▊ | 44/50 [15:10<01:58, 19.80s/it]

Epoch [45/50] Loss D: 0.5912, Loss G: 1.1798


Epochs:  90%|█████████ | 45/50 [15:30<01:39, 19.85s/it]

Epoch [46/50] Loss D: 0.5950, Loss G: 1.1773


Epochs:  92%|█████████▏| 46/50 [15:51<01:20, 20.11s/it]

Epoch [47/50] Loss D: 0.6464, Loss G: 1.0660


Epochs:  94%|█████████▍| 47/50 [16:11<01:00, 20.09s/it]

Epoch [48/50] Loss D: 0.6397, Loss G: 1.0533


Epochs:  96%|█████████▌| 48/50 [16:31<00:40, 20.10s/it]

Epoch [49/50] Loss D: 0.6304, Loss G: 1.0843


Epochs:  98%|█████████▊| 49/50 [16:51<00:20, 20.13s/it]

Epoch [50/50] Loss D: 0.5975, Loss G: 1.1142


Epochs: 100%|██████████| 50/50 [16:52<00:00, 20.25s/it]

training finished!





In [8]:
# Persist the learned parameters (state_dict only) of both networks, generator first.
for net, ckpt_path in ((gen, "generator.pth"), (disc, "discriminator.pth")):
    torch.save(net.state_dict(), ckpt_path)