In [51]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

In [52]:
# Configuration: every knob a reader might tune for this DCGAN run.
SEED = 0
# Seed the default torch generator so weight init, noise draws, DataLoader
# shuffling and the fixed sampling noise are reproducible across runs.
torch.manual_seed(SEED)

lr = 3e-4            # Adam learning rate, shared by generator and discriminator
batch_size = 32      # images per training batch (also size of the fixed sample grid)
epochs = 50          # full passes over the dataset
z_dim = 64           # length of the latent noise vector fed to the Generator
loss = nn.BCELoss()  # plain BCE on probabilities: the Discriminator ends in Sigmoid
sample_size = 5      # sampling interval in epochs: grids are saved when epoch % sample_size == 0

In [53]:
class Generator(nn.Module):
    """DCGAN generator: maps a latent noise vector to a 1x28x28 image in [-1, 1].

    A linear projection yields a 128x7x7 feature map; two stride-2 transposed
    convolutions then upsample it 7 -> 14 -> 28.

    Args:
        z_dim: length of the input noise vector (last dimension of ``x``).
    """

    def __init__(self, z_dim):
        super().__init__()
        # Project noise to 128 channels over a 7x7 grid (reshaped in forward).
        self.first = nn.Linear(z_dim, 128 * 7 * 7)
        upsample_layers = [
            # kernel 4 / stride 2 / padding 1 doubles height and width: 7 -> 14
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # 14 -> 28, collapsing to a single output channel
            nn.ConvTranspose2d(64, 1, 4, 2, 1),
            # squash pixel values into [-1, 1]
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*upsample_layers)

    def forward(self, x):
        """Return images of shape (N, 1, 28, 28).

        ``x`` is noise whose last dimension is ``z_dim`` (typically (N, z_dim)).
        """
        feature_map = self.first(x).view(-1, 128, 7, 7)
        return self.gen(feature_map)

In [54]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores 1x28x28 images with a probability of being real.

    Two stride-2 convolutions downsample 28 -> 14 -> 7; the flattened 128x7x7
    features go through a linear layer to a single value, squashed by Sigmoid.
    """

    def __init__(self):
        super().__init__()
        slope = 0.2  # negative slope for every LeakyReLU
        self.disc = nn.Sequential(
            # 1 -> 64 channels, 28x28 -> 14x14 (no BatchNorm on the input layer)
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(slope),
            # 64 -> 128 channels, 14x14 -> 7x7
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(slope),
            # (N, 128, 7, 7) -> (N, 6272) -> (N, 1) probability
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return probabilities of shape (N, 1) for images ``x`` of shape (N, 1, 28, 28)."""
        return self.disc(x)
        
        


In [55]:
# Data: FashionMNIST training split, scaled to [-1, 1] to match the
# Generator's Tanh output range.
# Bound to `transform` (not `transforms`) so the torchvision.transforms module
# is not shadowed -- with the old name, re-running this cell raised
# AttributeError because `transforms` was now a Compose object.
transform = transforms.Compose(
    [
        transforms.ToTensor(),                 # image -> float tensor in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1]
    ]
)
dataset = datasets.FashionMNIST(root="dataset/", train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [56]:
# Build both networks before drawing fixed_noise: keeping this order keeps the
# random-number stream (and hence initial weights) unchanged.
disc = Discriminator()
gen = Generator(z_dim)

# One independent Adam optimizer per player, sharing the same learning rate.
disc_optim, gen_optim = (optim.Adam(net.parameters(), lr=lr) for net in (disc, gen))

# Latent batch reused at every sampling step so saved grids are comparable.
fixed_noise = torch.randn(batch_size, z_dim)

In [57]:
# Training loop: alternating DCGAN updates. Each batch does one discriminator
# step on real + freshly generated fake images, then one generator step.
for epoch in range(epochs):
    for real, _ in loader:
        real = real.view(-1, 1, 28, 28)
        # Size the noise from the actual batch so a short final batch
        # (len(dataset) % batch_size != 0) still pairs real/fake correctly.
        noise = torch.randn(real.size(0), z_dim)
        fake = gen(noise)

        # Discriminator: BCE with real -> 1 and fake -> 0, i.e. maximise
        # log D(x) + log(1 - D(G(z))).
        disc_real = disc(real).view(-1)
        # detach() keeps this backward pass out of the generator's graph.
        disc_fake = disc(fake.detach()).view(-1)
        disc_loss_real = loss(disc_real, torch.ones_like(disc_real))
        disc_loss_fake = loss(disc_fake, torch.zeros_like(disc_fake))
        loss_D = (disc_loss_real + disc_loss_fake) / 2
        disc.zero_grad()
        loss_D.backward()
        disc_optim.step()

        # Generator: BCE against label 1 (non-saturating loss, maximise log D(G(z))).
        output = disc(fake).view(-1)
        loss_G = loss(output, torch.ones_like(output))
        gen.zero_grad()
        loss_G.backward()
        gen_optim.step()

    # Save a sample grid once per sampled epoch (and after the final epoch).
    # Previously this ran inside the batch loop and re-wrote the same file on
    # every batch.
    if epoch % sample_size == 0 or epoch == epochs - 1:
        print(f"epoch {epoch}/{epochs}  loss_D={loss_D.item():.4f}  loss_G={loss_G.item():.4f}")
        # save_image does not create missing directories.
        os.makedirs("samplesDCGAN", exist_ok=True)
        # eval(): BatchNorm uses its running stats and the fixed_noise pass
        # does not update them; train mode is restored right after.
        gen.eval()
        with torch.no_grad():
            samples = gen(fixed_noise).view(-1, 1, 28, 28)
        gen.train()
        grid = torchvision.utils.make_grid(samples, normalize=True)
        torchvision.utils.save_image(grid, f"samplesDCGAN/grid_epoch_{epoch}.png")