In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


ok i understand channels are how many channel the images has 3 or 1
features d are how many features you want to create

stride 2 causes the images to downsample

In [2]:
class Discriminator(nn.Module):
    def __init__(self,channels,features_d):
        super(Discriminator,self).__init__()
        self.disc=nn.Sequential(#3,64,64 the image
            nn.Conv2d(channels,features_d,kernel_size=4,stride=2,padding=1),#64,32,32, the 64 is the number of features map
            nn.LeakyReLU(0.2),#the image size gets halved bcoz of stride 2 and padding 1
            self._block(features_d,features_d*2,4,2,1),#128,16,16
            self._block(features_d*2,features_d*4,4,2,1),#256,8,8
            self._block(features_d*4,features_d*8,4,2,1),#512,4,4
            nn.Conv2d(features_d*8,1,4,2,0),#1,1,1
            nn.Sigmoid()
        )

        #FEATURES d specifies how many kernels there are, 

    def _block(self,in_channels,out_channels,kernel_size,stride,padding):
        return nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),#bias false for batch norm
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    def forward(self,x):
        return self.disc(x)

In [3]:
class Generator(nn.Module):
    def __init__(self,z_dim,channels_img,features_g):
        super(Generator,self).__init__()
        self.gen=nn.Sequential(#z_dim is like 100,1,1
            self._block(z_dim,features_g*8,4,1,0),#img size get doubled, the zdim number gets converted to features*8->100 to 512
            self._block(features_g*8,features_g*4,4,2,1),#512->256
            self._block(features_g*4,features_g*2,4,2,1),
            self._block(features_g*2,features_g*1,4,2,1),
            nn.ConvTranspose2d(features_g*1,channels_img,4,2,1),
            nn.Tanh()
        )

    def _block(self,in_channels,out_channels,kernel_size,stride,padding):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
    def forward(self,x):
        return self.gen(x)

In [4]:
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m,(nn.Conv2d,nn.ConvTranspose2d,nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data,0.0,0.02)

In [5]:
def test():
    N,in_channels,H,W=8,3,64,64
    z_dim=100
    x=torch.randn((N,in_channels,H,W))
    disc=Discriminator(in_channels,8)
    initialize_weights(disc)
    assert disc(x).shape==(N,1,1,1 )
    gen=Generator(z_dim,in_channels,8)
    initialize_weights(gen)
    z=torch.randn((N,z_dim,1,1))
    assert gen(z).shape==(N,in_channels,H,W)
    print("Success")

In [6]:
test()

Success


In [7]:
device="cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [8]:
lr=2e-4
lr

0.0002

In [9]:
batch_size=128
img_size=64
img_channels=1
z_dim=100
features_d=64
features_g=64

In [10]:
transform=transforms.Compose([
    transforms.Resize((img_size,img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(img_channels)],[0.5 for _ in range(img_channels)])
])

In [11]:
dataset=datasets.FashionMNIST(root='data/',transform=transform,download=True,train=True)

In [12]:
loader=DataLoader(dataset,batch_size,True)

In [13]:
gen=Generator(z_dim=z_dim,channels_img=img_channels,features_g=features_g).to(device)
disc=Discriminator(channels=img_channels,features_d=features_d).to(device)

In [14]:
initialize_weights(gen)
initialize_weights(disc)

In [15]:
opt_gen=optim.Adam(gen.parameters(),lr=lr,betas=(0.5,0.999))
opt_disc=optim.Adam(disc.parameters(),lr=lr,betas=(0.5,0.999))

In [16]:
criterion=nn.BCELoss()

In [17]:
torch.manual_seed(42)
fixed_noise=torch.randn(32,z_dim,1,1).to(device)#1 number in each dimension

In [18]:
fixed_noise

tensor([[[[ 1.9269]],

         [[ 1.4873]],

         [[ 0.9007]],

         ...,

         [[ 0.4880]],

         [[ 0.7846]],

         [[ 0.0286]]],


        [[[ 0.6408]],

         [[ 0.5832]],

         [[ 1.0669]],

         ...,

         [[ 0.3581]],

         [[ 0.4788]],

         [[ 1.3537]]],


        [[[ 0.5261]],

         [[ 2.1120]],

         [[-0.5208]],

         ...,

         [[ 0.2539]],

         [[ 0.9364]],

         [[ 0.7122]]],


        ...,


        [[[ 0.1463]],

         [[ 1.1357]],

         [[-0.2689]],

         ...,

         [[-0.8678]],

         [[-0.1043]],

         [[ 0.9756]]],


        [[[-0.8829]],

         [[-0.7063]],

         [[-1.2800]],

         ...,

         [[-1.4737]],

         [[ 0.9128]],

         [[-0.8139]]],


        [[[-0.3281]],

         [[-1.6034]],

         [[ 0.1566]],

         ...,

         [[ 0.8407]],

         [[-0.1939]],

         [[-0.1404]]]], device='cuda:0')

In [19]:
writer_real=SummaryWriter(f"runs/DCGAN/real")
writer_fake=SummaryWriter(f"runs/DCGAN/fake")
step=0

In [21]:
gen.train()
disc.train()
for epoch in range(50):
    for batch_idx,(real,_) in enumerate(loader):
        real=real.to(device)
        noise=torch.randn((batch_size,z_dim,1,1)).to(device)
        fake=gen(noise)

        #train disc log(d(x)) + log(1-d(g(z)))
        disc_real=disc(real).reshape(-1)#from 32 dimnesions of 1 to 1d 32 values. it produces batch size32 1s
        loss_disc_real=criterion(disc_real,torch.ones_like(disc_real))
        disc_fake=disc(fake).reshape(-1)
        loss_disc_fake=criterion(disc_fake,torch.zeros_like(disc_fake))
        loss_disc=(loss_disc_fake+loss_disc_real)/2
        disc.zero_grad()
        loss_disc.backward(retain_graph=True)
        opt_disc.step()

        #maximize log(d(g(z)))
        output=disc(fake).reshape(-1)
        loss_gen=criterion(output,torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx==0:
            print(f"Epoch:{epoch}\nLoss Generator:{loss_gen}\nLoss Discriminator{loss_disc}")

            with torch.no_grad():
                fake=gen(fixed_noise)
                img_grid_real=torchvision.utils.make_grid(real[:32],normalize=True)
                img_grid_fake=torchvision.utils.make_grid(fake[:32],normalize=True)
                writer_fake.add_image("Fake",img_grid_fake,global_step=step)
                writer_real.add_image("Real",img_grid_real,global_step=step)
            step+=1

Epoch:0
Loss Generator:2.7144083976745605
Loss Discriminator0.15206554532051086
Epoch:1
Loss Generator:5.386463642120361
Loss Discriminator0.058543860912323
Epoch:2
Loss Generator:3.821481943130493
Loss Discriminator0.025911428034305573
Epoch:3
Loss Generator:4.866625785827637
Loss Discriminator0.02715264819562435
Epoch:4
Loss Generator:2.2753100395202637
Loss Discriminator0.17347268760204315
Epoch:5
Loss Generator:3.885794162750244
Loss Discriminator0.020542679354548454
Epoch:6
Loss Generator:6.381095886230469
Loss Discriminator0.0034781978465616703
Epoch:7
Loss Generator:7.132964134216309
Loss Discriminator0.0019698443356901407
Epoch:8
Loss Generator:7.538364410400391
Loss Discriminator0.0010883444920182228
Epoch:9
Loss Generator:6.461429595947266
Loss Discriminator0.0034901979379355907


KeyboardInterrupt: 