In [1]:
import torch
from torch import nn , optim
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
#from torch.utils.tensorboard import SummaryWriter


In [2]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.disc(x)

In [3]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim) -> None:
        super().__init__()
        
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, img_dim),
            #nn.Tanh()
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.gen(x)

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device


'cuda'

In [5]:
#device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 16
image_dim = 28*28*1
batch_size = 4
num_epochs = 2

## creating the model
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device) #* for testing

## Creating the trasnforms

transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

In [6]:
disc

Discriminator(
  (disc): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.1)
    (2): Linear(in_features=128, out_features=1, bias=True)
    (3): Sigmoid()
  )
)

In [7]:
gen

Generator(
  (gen): Sequential(
    (0): Linear(in_features=16, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.1)
    (2): Linear(in_features=128, out_features=784, bias=True)
    (3): Sigmoid()
  )
)

In [8]:
## Creating the dataset
dataset  = datasets.MNIST(
    root= "dataset/",
    transform=transforms,
    download=True
)

loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True
)

In [9]:
## Loss and optimizer

opt_disc = optim.Adam(disc.parameters(), lr= lr)
opt_gen = optim.Adam(gen.parameters(), lr= lr)
criterion = nn.BCELoss()

In [10]:
## Tensorboard

# writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
# writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
# step = 0


In [11]:
noise = torch.randn((batch_size, z_dim)).to(device)
noise

tensor([[-0.2455,  0.0692, -1.3365, -0.2672, -0.1551, -0.2727,  0.3241,  0.1741,
          1.5673, -0.6120,  1.8356, -1.1575, -0.2227,  0.6297, -2.4234,  0.4870],
        [-0.2768, -0.6736, -2.2939, -0.5892, -1.1272,  0.0913, -1.7003,  0.2753,
         -0.5953,  1.2793, -1.2535, -1.1080, -0.4976,  0.8704,  1.5902,  1.2096],
        [-2.4683,  0.1226, -1.2422,  0.6586, -0.4295,  0.6612,  0.5049, -0.3468,
          0.8434, -0.2277, -0.8666,  1.5982, -0.8927, -1.8239,  0.3297,  0.6832],
        [-0.2576,  0.9524,  0.4714,  0.9698, -1.1119, -1.2335,  0.0649,  0.2261,
          0.1460, -1.2439, -1.4365,  1.2230, -0.2849, -0.1662, -0.7894,  0.4462]],
       device='cuda:0')

In [12]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        #
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]
        
        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        #* 
        
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        
        
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

In [None]:
for epoch in range(num_epochs):
    for batch_idx , (real, _) in enumerate(loader):
        #* view(-1) is used to flatten the image. instead of `torch.flatten`
        real = real.view(-1,784).to(device)
        batch_size = real.shape[0]
        
        ## training the Discriminator
        #* train disc = max [log(D(real)) + log(1-D(G(noise)))]
            
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)        
        disc_real = real.view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            
        
        #disc_fake = disc(fake).view(-1)
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()
        
        ## training the generator
        #* train gen = min [log(1-D(G(noise)))] or max [log(D(G(noise))]
        output = disc(fake).view(-1)
        lossG  = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
        
        
        # ## Tensorboard
        # if batch_idx == 0:
        #     print(f"Epoch [{epoch}/{num_epochs}]" \
        #         f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}")
            
        #     with torch.inference_mode():
        #         fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
        #         data = real.resape(-1, 1, 28, 28)
        #         img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
        #         img_grid_real = torchvision.utils.make_grid(data, normalize=True)
                
        #         writer_fake.add_image(
        #             "Mnist Fake Images", img_grid_fake, global_step=step)

        #         writer_real.add_image(
        #             "Mnist Real Images", img_grid_real, global_step=step)
        #         step += 1

tensor([[0.6114, 0.4754, 0.4116, 0.4589, 0.4259, 0.3793, 0.4241, 0.5815, 0.4967,
         0.5552, 0.4638, 0.5414, 0.5225, 0.5568, 0.4629, 0.4707, 0.5518, 0.4813,
         0.5954, 0.5024, 0.4019, 0.5248, 0.5120, 0.4481, 0.4732, 0.3514, 0.4941,
         0.5482, 0.4941, 0.4761, 0.4360, 0.4404, 0.6296, 0.4801, 0.5077, 0.4938,
         0.5698, 0.4963, 0.4856, 0.4555, 0.4971, 0.5313, 0.5889, 0.5020, 0.5505,
         0.5475, 0.3688, 0.5853, 0.4621, 0.4411, 0.5350, 0.4753, 0.5172, 0.5200,
         0.4688, 0.4287, 0.4640, 0.5593, 0.5872, 0.4705, 0.5208, 0.5932, 0.5931,
         0.4964, 0.5514, 0.5012, 0.5354, 0.3962, 0.5685, 0.4883, 0.4482, 0.5479,
         0.5057, 0.3928, 0.4911, 0.5538, 0.6254, 0.5168, 0.5749, 0.5659, 0.4511,
         0.5307, 0.4597, 0.5721, 0.5198, 0.6467, 0.5374, 0.4817, 0.5435, 0.6040,
         0.4249, 0.5923, 0.5161, 0.5288, 0.5027, 0.3562, 0.5244, 0.4063, 0.5166,
         0.6101, 0.5153, 0.5044, 0.5374, 0.5459, 0.4352, 0.3775, 0.4792, 0.5298,
         0.5532, 0.5520, 0.4

RuntimeError: all elements of input should be between 0 and 1