In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm #Configuration parameters
# ---- Hyperparameters ----
epochs = 100       # passes over the training set
batch_size = 64    # images per optimisation step
sample_size = 100  # length of the latent noise vector fed to the generator
g_lr = 1e-4        # generator Adam learning rate
d_lr = 1e-4        # discriminator Adam learning rate
# MNIST training split, converted to tensors by ToTensor and batched for training.
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# shuffle=True: draw minibatches in a fresh random order every epoch instead of
# the same fixed order each time. drop_last=True keeps every batch exactly
# `batch_size` long, matching the fixed-size target tensors used in training.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
#Generator Network
class Generator(nn.Sequential):
    """Two-layer MLP that turns Gaussian noise into 1x28x28 images.

    The module samples its own noise: calling it with an integer batch size
    draws ``batch_size`` latent vectors from a standard normal distribution and
    decodes them through the layers below into images whose pixels lie in the
    sigmoid range (0, 1).

    Args:
        sample_size: Length of each latent noise vector.
    """

    def __init__(self, sample_size: int):
        super().__init__(
            nn.Linear(sample_size, 128),  # latent vector -> hidden features
            nn.LeakyReLU(0.01),
            nn.Linear(128, 784),          # hidden -> flattened 28*28 image
            nn.Sigmoid(),                 # squash pixel values into (0, 1)
        )
        self.sample_size = sample_size

    def forward(self, batch_size: int) -> torch.Tensor:
        """Generate ``batch_size`` images as a (batch_size, 1, 28, 28) tensor."""
        # Draw the noise on the same device and dtype as the weights. Otherwise
        # the model fails after .to(device) / .double() because the input
        # would always be a CPU float32 tensor.
        weight = self[0].weight
        z = torch.randn(
            batch_size, self.sample_size, device=weight.device, dtype=weight.dtype
        )
        flat_images = super().forward(z)
        return flat_images.reshape(batch_size, 1, 28, 28)
# Discriminator Network
class Discriminator(nn.Sequential):
    """MLP that scores images as real/fake and returns the resulting BCE loss.

    Each image is flattened to 784 values and mapped to one unnormalised logit.
    ``forward`` returns the loss directly rather than the logits.
    """

    def __init__(self):
        layers = (
            nn.Linear(784, 128),  # flattened image -> hidden features
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),    # hidden -> single real/fake logit
        )
        super().__init__(*layers)

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        """Return the mean BCE-with-logits loss of ``images`` against ``targets``.

        ``targets`` must have the same shape as the logits, i.e. (N, 1).
        """
        flat = images.reshape(-1, 784)
        logits = super().forward(flat)
        return F.binary_cross_entropy_with_logits(logits, targets)
#Function to save generated images in a grid layout
def save_image_grid(epoch: int, images: torch.Tensor, ncol: int):
    """Tile ``images`` into a grid and save it as ``generated_{epoch:03d}.jpg``.

    Args:
        epoch: Epoch index, zero-padded to three digits in the file name.
        images: Batch of images in (N, C, H, W) layout, e.g. the
            (N, 1, 28, 28) output of ``Generator``.
        ncol: Number of images per grid row (passed to ``make_grid`` as
            ``nrow``, which counts images per row despite its name).
    """
    # Detach so .numpy() cannot fail on a tensor still attached to autograd.
    grid = make_grid(images.detach(), nrow=ncol)
    grid = grid.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for matplotlib
    # Use a dedicated figure instead of pyplot's implicit "current" figure, so
    # we neither draw into nor close a figure the caller may have open.
    fig, ax = plt.subplots()
    try:
        ax.imshow(grid)
        ax.set_xticks([])  # hide axis ticks
        ax.set_yticks([])
        fig.savefig(f'generated_{epoch:03d}.jpg')
    finally:
        plt.close(fig)  # always release the figure, even if saving fails
#Labels for training the discriminator
real_targets = torch.ones(batch_size, 1)  # Label for real images
fake_targets = torch.zeros(batch_size, 1)  # Label for fake images

# Initialize generator and discriminator
generator = Generator(sample_size)
discriminator = Discriminator()

# One Adam optimizer per network; each steps only its own parameters.
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)

# Training loop
for epoch in range(epochs):
    d_losses = []  # per-batch discriminator losses for this epoch
    g_losses = []  # per-batch generator losses for this epoch
    for images, _ in tqdm(dataloader):  # class labels are not used by the GAN
        # Slice the targets to the real batch length so a short final batch
        # (e.g. if drop_last is switched off) still matches the logits' shape.
        n = images.size(0)
        real = real_targets[:n]
        fake = fake_targets[:n]

        # --- Train the discriminator: real images -> 1, generated -> 0 ---
        discriminator.train()
        d_loss = discriminator(images, real)
        generator.eval()
        with torch.no_grad():
            # Fakes are constants for the discriminator step; no graph through G.
            generated_images = generator(n)
        d_loss += discriminator(generated_images, fake)
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # --- Train the generator: push D to label fresh fakes as real ---
        generator.train()
        generated_images = generator(n)
        # eval() is a no-op for D's Linear/LeakyReLU layers. D's parameters do
        # receive gradients here, but d_optimizer.zero_grad() clears them
        # before D's next update, and g_optimizer only steps G's parameters.
        discriminator.eval()
        g_loss = discriminator(generated_images, real)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    # Mean losses for the epoch
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # Save a grid of samples. Sampling under no_grad avoids building an
    # autograd graph that would never be used.
    generator.eval()
    with torch.no_grad():
        samples = generator(batch_size)
    save_image_grid(epoch, samples, ncol=8)

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 487kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.45MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 13.3MB/s]
100%|██████████| 937/937 [00:11<00:00, 81.03it/s]


0 0.42999258987295463 2.720147139680042


100%|██████████| 937/937 [00:10<00:00, 87.22it/s]


1 0.32672598528632996 2.2346837675813167


100%|██████████| 937/937 [00:11<00:00, 84.00it/s]


2 0.45584978517880437 1.659167510722719


100%|██████████| 937/937 [00:11<00:00, 83.71it/s]


3 0.5695850273779133 1.2280829789289924


100%|██████████| 937/937 [00:11<00:00, 83.36it/s]


4 0.39944403550986163 1.6468218414862357


100%|██████████| 937/937 [00:12<00:00, 77.06it/s]


5 0.3915709119564441 1.9938432157612151


100%|██████████| 937/937 [00:12<00:00, 76.85it/s]


6 0.48111864765400314 2.1386036045904984


100%|██████████| 937/937 [00:11<00:00, 83.43it/s]


7 0.49289584124864355 1.9981869574926452


100%|██████████| 937/937 [00:11<00:00, 82.88it/s]


8 0.54520026342337 1.9509909932870366


100%|██████████| 937/937 [00:11<00:00, 83.68it/s]


9 0.573164553945067 1.8400702000682103


100%|██████████| 937/937 [00:11<00:00, 82.60it/s]


10 0.5285396089930418 1.9345906602662044


100%|██████████| 937/937 [00:11<00:00, 82.48it/s]


11 0.5395311294014197 1.9058060214766317


100%|██████████| 937/937 [00:11<00:00, 83.38it/s]


12 0.44853301131610807 2.0135647214972985


100%|██████████| 937/937 [00:11<00:00, 84.21it/s]


13 0.4107577297419087 2.113236270693857


100%|██████████| 937/937 [00:10<00:00, 87.21it/s]


14 0.39084842952173066 2.267715590613374


100%|██████████| 937/937 [00:11<00:00, 83.27it/s]


15 0.36409466520444816 2.4890706147302812


100%|██████████| 937/937 [00:11<00:00, 81.36it/s]


16 0.47935299925196007 2.331653665580261


100%|██████████| 937/937 [00:11<00:00, 81.55it/s]


17 0.5131405402718384 2.2150047269422193


100%|██████████| 937/937 [00:11<00:00, 82.32it/s]


18 0.5353263239818551 2.1536224939270996


100%|██████████| 937/937 [00:11<00:00, 82.72it/s]


19 0.6014899383105234 2.0381876597918467


100%|██████████| 937/937 [00:11<00:00, 82.17it/s]


20 0.4842046870906681 2.1904944905349257


100%|██████████| 937/937 [00:11<00:00, 82.31it/s]


21 0.4200121670198033 2.3518323344889165


100%|██████████| 937/937 [00:11<00:00, 83.08it/s]


22 0.4757684991479302 2.2681170480610087


100%|██████████| 937/937 [00:11<00:00, 82.92it/s]


23 0.4753601503823839 2.3525594768045806


100%|██████████| 937/937 [00:11<00:00, 83.87it/s]


24 0.40973390071089266 2.508913420816623


100%|██████████| 937/937 [00:11<00:00, 84.66it/s]


25 0.4171508739254136 2.4932618311655053


100%|██████████| 937/937 [00:11<00:00, 78.26it/s]


26 0.48362979246750715 2.2885271759908505


100%|██████████| 937/937 [00:10<00:00, 86.33it/s]


27 0.46295852587024583 2.40779122955135


100%|██████████| 937/937 [00:11<00:00, 83.52it/s]


28 0.5013787017433468 2.2810973020602443


100%|██████████| 937/937 [00:11<00:00, 83.46it/s]


29 0.4796283252084014 2.2919661358809904


100%|██████████| 937/937 [00:11<00:00, 83.04it/s]


30 0.5257746634831932 2.206756964946124


100%|██████████| 937/937 [00:11<00:00, 82.99it/s]


31 0.5146352449628816 2.288617239436065


100%|██████████| 937/937 [00:11<00:00, 82.56it/s]


32 0.49473530723610964 2.338236888292123


100%|██████████| 937/937 [00:11<00:00, 83.14it/s]


33 0.49320973848711336 2.2866689399314155


100%|██████████| 937/937 [00:11<00:00, 83.82it/s]


34 0.47848884336721936 2.388307631969961


100%|██████████| 937/937 [00:11<00:00, 84.14it/s]


35 0.5337618968371012 2.2820279538440604


100%|██████████| 937/937 [00:11<00:00, 83.49it/s]


36 0.546805914272748 2.217830216744666


100%|██████████| 937/937 [00:11<00:00, 84.21it/s]


37 0.5333610639182836 2.2243019534531374


100%|██████████| 937/937 [00:10<00:00, 87.95it/s]


38 0.5540964436123949 2.208040486659413


100%|██████████| 937/937 [00:11<00:00, 85.16it/s]


39 0.5749412040636659 2.2512400937818287


100%|██████████| 937/937 [00:11<00:00, 84.15it/s]


40 0.5761362460214275 2.2334394126717028


100%|██████████| 937/937 [00:11<00:00, 84.29it/s]


41 0.5612152014305396 2.2288010031334746


100%|██████████| 937/937 [00:11<00:00, 82.74it/s]


42 0.5611362879660493 2.28173932207305


100%|██████████| 937/937 [00:11<00:00, 83.33it/s]


43 0.583233433923701 2.220935413270012


100%|██████████| 937/937 [00:11<00:00, 82.60it/s]


44 0.5701581358337097 2.230516234355141


100%|██████████| 937/937 [00:11<00:00, 83.25it/s]


45 0.5618445458508861 2.2742840979117087


100%|██████████| 937/937 [00:12<00:00, 76.84it/s]


46 0.5859550465705428 2.2620923338475833


100%|██████████| 937/937 [00:11<00:00, 82.55it/s]


47 0.7414434839878927 2.102050753133401


100%|██████████| 937/937 [00:11<00:00, 83.21it/s]


48 0.5170136654994022 2.5030511639288675


100%|██████████| 937/937 [00:11<00:00, 83.85it/s]


49 0.6131213586638297 2.1789899371095247


100%|██████████| 937/937 [00:10<00:00, 85.73it/s]


50 0.9806342998652728 1.8692734236396618


100%|██████████| 937/937 [00:11<00:00, 82.97it/s]


51 0.6393142652676989 2.2642966433802942


100%|██████████| 937/937 [00:11<00:00, 82.75it/s]


52 0.8371871004364152 2.277421808166463


100%|██████████| 937/937 [00:11<00:00, 82.90it/s]


53 0.5273352877307472 2.445417584833493


100%|██████████| 937/937 [00:11<00:00, 83.40it/s]


54 0.6837893612484032 2.1555080141558083


100%|██████████| 937/937 [00:11<00:00, 83.39it/s]


55 0.6535946543595834 2.1481542041487476


100%|██████████| 937/937 [00:11<00:00, 83.68it/s]


56 0.6274023706465165 2.155096963604588


100%|██████████| 937/937 [00:11<00:00, 82.13it/s]


57 0.6128473111697679 2.1542581822345452


100%|██████████| 937/937 [00:11<00:00, 82.01it/s]


58 0.5906650158740667 2.1703341906709377


100%|██████████| 937/937 [00:11<00:00, 81.87it/s]


59 0.5994723030508137 2.2208394663310993


100%|██████████| 937/937 [00:11<00:00, 81.52it/s]


60 0.6039250675616676 2.2343224196703737


100%|██████████| 937/937 [00:11<00:00, 82.53it/s]


61 0.6100876029489389 2.203342817636284


100%|██████████| 937/937 [00:10<00:00, 85.42it/s]


62 0.6231458670874161 2.1787919556700177


100%|██████████| 937/937 [00:11<00:00, 83.60it/s]


63 0.6445539115968865 2.16891201303252


100%|██████████| 937/937 [00:11<00:00, 82.71it/s]


64 1.0815393348068412 1.7628787117299556


100%|██████████| 937/937 [00:11<00:00, 82.71it/s]


65 0.6251439628377159 2.191102158170881


100%|██████████| 937/937 [00:11<00:00, 82.84it/s]


66 0.5802980799748778 2.3482080248911075


100%|██████████| 937/937 [00:11<00:00, 83.69it/s]


67 0.6713087811922951 2.049580693753832


100%|██████████| 937/937 [00:11<00:00, 82.68it/s]


68 0.6510930937280522 2.0968184548737274


100%|██████████| 937/937 [00:13<00:00, 69.43it/s]


69 0.6520984273010283 2.115903436946767


100%|██████████| 937/937 [00:12<00:00, 75.88it/s]


70 0.6629645610377272 2.12231949464743


100%|██████████| 937/937 [00:11<00:00, 78.31it/s]


71 0.6686700867168931 2.1222389996242623


100%|██████████| 937/937 [00:11<00:00, 82.32it/s]


72 0.6753242276648701 2.0884172133982117


100%|██████████| 937/937 [00:11<00:00, 80.93it/s]


73 0.6860674382083062 2.06071150073883


100%|██████████| 937/937 [00:11<00:00, 81.70it/s]


74 0.6794311887362978 2.0647627751371775


100%|██████████| 937/937 [00:12<00:00, 75.71it/s]


75 0.656901588021373 2.1022920946937487


100%|██████████| 937/937 [00:11<00:00, 81.30it/s]


76 0.6618309195929682 2.086244152347968


100%|██████████| 937/937 [00:11<00:00, 81.03it/s]


77 0.648408621454188 2.074344747857834


100%|██████████| 937/937 [00:11<00:00, 81.25it/s]


78 0.6289314115251013 2.1279479585818573


100%|██████████| 937/937 [00:11<00:00, 83.90it/s]


79 0.6330112041378734 2.1410073554630342


100%|██████████| 937/937 [00:10<00:00, 85.33it/s]


80 0.6382230377438992 2.1281854150898303


100%|██████████| 937/937 [00:11<00:00, 81.42it/s]


81 0.6339673621295737 2.1472464474186443


100%|██████████| 937/937 [00:11<00:00, 81.05it/s]


82 0.6468666507569296 2.1251809832380575


100%|██████████| 937/937 [00:11<00:00, 79.78it/s]


83 0.6565918524847977 2.0947860410653667


100%|██████████| 937/937 [00:11<00:00, 80.13it/s]


84 0.6404549392017763 2.1215393768430646


100%|██████████| 937/937 [00:11<00:00, 82.09it/s]


85 2.5300820431052813 1.0752392116707443


100%|██████████| 937/937 [00:11<00:00, 81.77it/s]


86 1.3018714866999501 1.518626635967731


100%|██████████| 937/937 [00:11<00:00, 81.60it/s]


87 0.7454551849891944 2.1515439611361655


100%|██████████| 937/937 [00:11<00:00, 81.72it/s]


88 1.0461109170537113 2.2869515148018316


100%|██████████| 937/937 [00:11<00:00, 82.31it/s]


89 2.0302375923353173 1.059371224017128


100%|██████████| 937/937 [00:11<00:00, 83.45it/s]


90 1.0226598069660182 1.7368146148028216


100%|██████████| 937/937 [00:11<00:00, 82.58it/s]


91 0.6109480921143274 2.2725139620973307


100%|██████████| 937/937 [00:11<00:00, 84.20it/s]


92 0.6930285288277405 2.1297720539658913


100%|██████████| 937/937 [00:10<00:00, 85.87it/s]


93 0.7285935715270272 2.012195831810779


100%|██████████| 937/937 [00:11<00:00, 83.76it/s]


94 0.7335557503150647 2.014696117526312


100%|██████████| 937/937 [00:11<00:00, 82.85it/s]


95 0.752946509781871 1.975272611085735


100%|██████████| 937/937 [00:11<00:00, 83.49it/s]


96 0.7907055010790505 1.9197103829623032


100%|██████████| 937/937 [00:11<00:00, 83.21it/s]


97 0.7867777892085251 1.912967690281578


100%|██████████| 937/937 [00:11<00:00, 83.70it/s]


98 1.9596027783421341 1.1431519975339082


100%|██████████| 937/937 [00:11<00:00, 83.06it/s]

99 3.164161084174092 0.6676997320374659



