In [71]:
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [72]:
class Discriminator(nn.Module):
    def __init__(self,channels,features_d):
        super(Discriminator,self).__init__()
        self.disc=nn.Sequential(
            nn.Conv2d(channels,features_d,4,2,1),
            nn.LeakyReLU(0.2),
            self._block(features_d,features_d*2,4,2,1),
            self._block(features_d*2,features_d*4,4,2,1),
            self._block(features_d*4,features_d*8,4,2,1),
            nn.Conv2d(features_d*8,1,4,2,0),
            nn.Sigmoid()
        )

    def _block(self,in_channels,out_channels,kernel_size,stride,padding):
        return nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    def forward(self,x):
        return self.disc(x)

In [73]:
class Generator(nn.Module):
    def _block(self,in_channels,out_channels,kernel_size,stride,padding):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    def __init__(self,z_dim,img_channels,features_g):
        super(Generator,self).__init__()
        self.gen=nn.Sequential(
            self._block(z_dim,features_g*8,4,1,0),
            self._block(features_g*8,features_g*4,4,2,1),
            self._block(features_g*4,features_g*2,4,2,1),
            self._block(features_g*2,features_g*1,4,2,1),
            nn.ConvTranspose2d(features_g,img_channels,4,2,1),
            nn.Tanh()
        )
    def forward(self,x):
        return self.gen(x)

In [74]:
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m,(nn.ConvTranspose2d,nn.Conv2d,nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data,0.0,0.02)

In [75]:
def test():
    N,in_channels,H,W=8,3,64,64
    z_dim=100
    x=torch.randn((N,in_channels,H,W))
    disc=Discriminator(in_channels,8)
    initialize_weights(disc)
    assert disc(x).shape==(N,1,1,1 )
    gen=Generator(z_dim,in_channels,8)
    initialize_weights(gen)
    z=torch.randn((N,z_dim,1,1))
    assert gen(z).shape==(N,in_channels,H,W)
    print("Success")
test()

Success


In [76]:
device="cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [77]:
lr=2e-4
lr

0.0002

In [78]:
img_size=64
batch_size=64
z_dim=256
img_channels=3
features_d=64
features_g=64

In [79]:
transform=transforms.Compose([
    transforms.Resize((img_size,img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(img_channels)],[0.5 for _ in range(img_channels)])
])

In [80]:
dataset=datasets.ImageFolder(root='data/celeba/',transform=transform)

In [81]:
loader=DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=4)

In [82]:
gen=Generator(z_dim=z_dim,img_channels=img_channels,features_g=features_g).to(device)
disc=Discriminator(channels=img_channels,features_d=features_d).to(device)

In [83]:
initialize_weights(gen)
initialize_weights(disc)

In [84]:
opt_gen=optim.Adam(gen.parameters(),lr=lr,betas=(0.5,0.999))
opt_disc=optim.Adam(disc.parameters(),lr=lr,betas=(0.5,0.99))

In [85]:
criterion=nn.BCELoss()

In [86]:
fixed_noise=torch.randn((32,z_dim,1,1)).to(device)

In [87]:
fixed_noise

tensor([[[[-0.8216]],

         [[ 1.4291]],

         [[-0.1752]],

         ...,

         [[ 1.3314]],

         [[-0.6916]],

         [[-0.4866]]],


        [[[-0.1484]],

         [[-0.0366]],

         [[-0.3422]],

         ...,

         [[-2.3145]],

         [[ 1.5720]],

         [[ 1.0168]]],


        [[[-0.1716]],

         [[-0.4857]],

         [[-0.4893]],

         ...,

         [[-0.7237]],

         [[-0.5232]],

         [[ 0.6144]]],


        ...,


        [[[-0.4139]],

         [[-0.2013]],

         [[ 1.1573]],

         ...,

         [[ 0.2161]],

         [[-0.8906]],

         [[-1.4038]]],


        [[[-0.3960]],

         [[ 0.6921]],

         [[-1.7239]],

         ...,

         [[-0.4450]],

         [[-1.1471]],

         [[ 0.9505]]],


        [[[ 1.2643]],

         [[ 0.7105]],

         [[-0.2253]],

         ...,

         [[-0.2549]],

         [[ 0.6981]],

         [[-0.0171]]]], device='cuda:0')

In [88]:
writer_real=SummaryWriter(f"runs/RDCGAN/real")
writer_fake=SummaryWriter(f"runs/RDCGAN/fake")
step=0

In [89]:
for epochs in range(20):
    for batch_idx,(real,_) in enumerate(loader):
        real=real.to(device)
        noise=torch.randn((batch_size,z_dim,1,1)).to(device)
        fake=gen(noise)

        disc_real=disc(real).reshape(-1)
        real_labels=torch.full_like(disc_real,0.9)
        lossD_real=criterion(disc_real,real_labels)
        disc_fake=disc(fake.detach()).reshape(-1)
        fake_labels=torch.full_like(disc_fake,0.1)
        lossD_fake=criterion(disc_fake,fake_labels)
        lossD=(lossD_fake+lossD_real)/2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        output=disc(fake).reshape(-1)
        lossG=criterion(output,torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx%100==0:
            print(f"Epoch:{epochs}\nLoss Generator:{lossG}\nLoss Discriminator:{lossD}")
            with torch.no_grad():
                fake=gen(fixed_noise)
                img_grid_fake=torchvision.utils.make_grid(fake[:32],normalize=True)
                img_grid_real=torchvision.utils.make_grid(real[:32],normalize=True)
                writer_fake.add_image("Fake",img_grid_fake,global_step=step)
                writer_real.add_image("Real",img_grid_real,global_step=step)
        step+=1

Epoch:0
Loss Generator:0.8140408992767334
Loss Discriminator:0.6883387565612793
Epoch:0
Loss Generator:2.3096766471862793
Loss Discriminator:0.3255464732646942
Epoch:0
Loss Generator:2.101498603820801
Loss Discriminator:0.3811744749546051
Epoch:0
Loss Generator:1.3203067779541016
Loss Discriminator:0.5154972076416016
Epoch:0
Loss Generator:1.7277944087982178
Loss Discriminator:0.544561505317688
Epoch:0
Loss Generator:1.6224596500396729
Loss Discriminator:0.4926971197128296
Epoch:0
Loss Generator:1.4252588748931885
Loss Discriminator:0.5340104103088379
Epoch:0
Loss Generator:1.264326572418213
Loss Discriminator:0.5867581367492676
Epoch:0
Loss Generator:1.2286951541900635
Loss Discriminator:0.6103821396827698
Epoch:0
Loss Generator:1.3118054866790771
Loss Discriminator:0.6634307503700256
Epoch:0
Loss Generator:1.9377597570419312
Loss Discriminator:0.4748638868331909
Epoch:0
Loss Generator:1.59598970413208
Loss Discriminator:0.49841269850730896
Epoch:0
Loss Generator:1.9941774606704712
Lo

KeyboardInterrupt: 