In [79]:
import torch
import torchvision
from torch import nn

In [80]:
class Generator(nn.Module):
    """DCGAN generator: decodes a latent tensor into an image.

    Args:
        in_dim: channels of the latent input; the input is expected to have
            shape (N, in_dim, 1, 1).
        n_channel: channels of the generated image.
        n_dim: base width; the stem outputs n_dim*8 maps and each following
            stage halves that, down to n_dim before the output layer.

    Spatial size for a 1x1 input: 1 -> 4 -> 8 -> 16 -> 32 -> 64, so the output
    is (N, n_channel, 64, 64), squashed into [-1, 1] by Tanh.
    """

    def __init__(self, in_dim, n_channel, n_dim):
        super(Generator, self).__init__()
        self.blocks = nn.ModuleList()
        # Stem: 1x1 latent -> 4x4 map with n_dim*8 channels. Driven by in_dim
        # so the latent size is configurable rather than fixed at 100.
        self.blocks.append(self.block(in_dim, n_dim * 8, 4, 1, 0))
        # Three stride-2 stages, each doubling H/W and halving channels:
        # 8n -> 4n -> 2n -> n.
        for x in range(3):
            self.blocks.append(
                self.block(n_dim * 8 // (2 ** x), n_dim * 8 // (2 ** (x + 1)), 4, 2, 1)
            )
        # Output layer: upsample to the final resolution and map to image channels.
        self.end = nn.Sequential(
            nn.ConvTranspose2d(n_dim, n_channel, 4, 2, 1),
            nn.Tanh(),
        )

    def block(self, in_chan, out_chan, kernel, stride, padding):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU upsampling stage."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_chan, out_chan, kernel, stride, padding),
            nn.BatchNorm2d(out_chan),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the latent ``x`` through every stage, then the Tanh output layer."""
        for layer in self.blocks:
            x = layer(x)
        return self.end(x)

In [81]:
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an image to a probability of being real.

    A stride-2 stem conv (no normalisation) is followed by three stride-2
    conv/BatchNorm/LeakyReLU stages that double the width each time
    (n_dim -> 2*n_dim -> 4*n_dim -> 8*n_dim), then a 4x4 valid conv and a
    Sigmoid produce one score per image. For 64x64 inputs the spatial size
    goes 64 -> 32 -> 16 -> 8 -> 4 -> 1, giving an output of shape (N, 1, 1, 1).

    Args:
        n_channel: number of input image channels.
        n_dim: width of the stem convolution.
    """

    def __init__(self, n_channel, n_dim):
        super().__init__()
        stem = nn.Sequential(
            nn.Conv2d(n_channel, n_dim, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        widths = [n_dim * 2 ** i for i in range(4)]
        stages = [
            self.block(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        # One ModuleList with the stem first, so parameter names are blocks.0 ... blocks.3.
        self.blocks = nn.ModuleList([stem] + stages)
        head_width = widths[-1]
        self.end = nn.Sequential(
            nn.Conv2d(head_width, 1, 4, 1, 0),
            nn.Sigmoid(),
        )

    def block(self, in_chan, out_chan, kernel, stride, padding):
        """Return a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) downsampling stage."""
        layers = [
            nn.Conv2d(in_chan, out_chan, kernel, stride, padding),
            nn.BatchNorm2d(out_chan),
            nn.LeakyReLU(0.2),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply every stage in ``self.blocks`` in order, then the scoring head."""
        features = x
        for stage in self.blocks:
            features = stage(features)
        return self.end(features)

In [82]:
def initialize_weights(model, mean=0.0, std=0.02):
    """Re-initialise conv, transposed-conv and batch-norm weights in place.

    Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d`` found in
    ``model.modules()`` gets its ``weight`` redrawn from a normal distribution
    with the given ``mean`` and ``std``. Biases and all other layer types are
    left untouched.

    Args:
        model: module whose submodules are re-initialised.
        mean: mean of the normal distribution (default 0.0).
        std: standard deviation of the normal distribution (default 0.02).

    NOTE(review): PyTorch's DCGAN tutorial draws BatchNorm weights from
    N(1.0, 0.02) rather than N(0.0, 0.02); confirm which is intended here.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            # BatchNorm2d(affine=False) has weight=None; skip it instead of crashing.
            if m.weight is not None:
                # nn.init already runs under no_grad, so `.data` is not needed.
                nn.init.normal_(m.weight, mean, std)

In [83]:
# Smoke test: a batch of 2 latent vectors (1x1 spatial) should decode to 2 RGB 64x64 images.
gen = Generator(100, 3, 128)
noise = torch.randn((2, 100, 1, 1))
with torch.no_grad():  # shape check only; no autograd graph needed
    gen_out_shape = gen(noise).shape
print(gen_out_shape)
assert gen_out_shape == (2, 3, 64, 64), gen_out_shape

torch.Size([2, 3, 64, 64])


In [84]:
# Smoke test: the discriminator reduces each generated 64x64 image to one score.
# Reuses `gen` and `noise` from the generator smoke-test cell.
disc = Discriminator(3, 128)
with torch.no_grad():  # shape check only; no autograd graph needed
    disc_out_shape = disc(gen(noise)).shape
print(disc_out_shape)
assert disc_out_shape == (2, 1, 1, 1), disc_out_shape

torch.Size([2, 1, 1, 1])


In [85]:
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [86]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
LEARNING_RATE = 2e-4
BATCH_SIZE = 128
IMAGE_SIZE = 64  # Generator/Discriminator architectures are built for 64x64 images
CHANNELS_IMG = 3
NOISE_DIM = 100
NUM_EPOCHS = 15
FEATURES_NUM = 128
# Seed torch's global RNG (CPU and CUDA) so weight init, noise draws and
# DataLoader shuffling are reproducible across Restart & Run All.
torch.manual_seed(SEED)
print(device)

cuda


In [87]:
# CelebA loaded as an ImageFolder: resize the shorter side, center-crop to a
# square IMAGE_SIZE x IMAGE_SIZE, convert to a tensor, then map [0, 1] -> [-1, 1]
# so real images share the range of the generator's Tanh output.
dataset = datasets.ImageFolder(
    root="dataset/celeba",
    transform=transforms.Compose(
        [
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
    ),
)


In [88]:
# Two writers under logs/, so real and generated grids show up as separate
# runs when TensorBoard is pointed at the logs directory.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")

In [89]:
def train(loader, generator, discriminator, opt_gen, opt_disc, loss_gen, loss_disc, epoch=0):
    """Run one epoch of GAN training over ``loader``.

    For each batch the discriminator is updated to score real images as 1 and
    generated images as 0, then the generator is updated to make the (just
    updated) discriminator score its samples as 1. Every 10 batches an image
    grid of real images and of ``generator(fixed_noise)`` is logged.

    Args:
        loader: iterable of ``(images, labels)`` batches; labels are ignored.
        generator, discriminator: the two networks, updated in place.
        opt_gen, opt_disc: optimizers over the generator / discriminator params.
        loss_gen, loss_disc: criteria called as ``loss(prediction, target)``.
        epoch: current epoch index. When non-zero, the TensorBoard step is
            offset by ``epoch * len(loader)`` so later epochs do not overwrite
            earlier image grids.

    Relies on notebook globals: ``device``, ``NOISE_DIM``, ``fixed_noise``,
    ``writer_real`` and ``writer_fake``.
    """
    # Only query len() when needed, so plain iterables still work for epoch 0.
    steps_per_epoch = len(loader) if epoch else 0
    loop = tqdm(loader, leave=False)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(device)
        # Size noise by the actual batch: the last batch can be smaller than BATCH_SIZE.
        noise = torch.randn(real.size(0), NOISE_DIM, 1, 1, device=device)
        fake = generator(noise)

        # Discriminator step: real -> 1, fake -> 0. `fake` is detached so this
        # backward pass does not reach the generator.
        disc_real = discriminator(real).reshape(-1)
        disc_loss_true = loss_disc(disc_real, torch.ones_like(disc_real))
        disc_fake = discriminator(fake.detach()).reshape(-1)
        disc_loss_false = loss_disc(disc_fake, torch.zeros_like(disc_fake))
        disc_loss = (disc_loss_false + disc_loss_true) / 2
        discriminator.zero_grad()
        disc_loss.backward()
        opt_disc.step()

        # Generator step: make the updated discriminator score fakes as real.
        gen_res = discriminator(fake).reshape(-1)
        gen_loss = loss_gen(gen_res, torch.ones_like(gen_res))
        generator.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        # Periodic logging of losses and image grids.
        if batch_idx % 10 == 0:
            step = epoch * steps_per_epoch + batch_idx
            loop.set_postfix(loss_d=disc_loss.item(), loss_g=gen_loss.item())
            with torch.no_grad():
                # Sample from the generator passed in, not a global model.
                samples = generator(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(samples[:32], normalize=True)
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)


In [90]:
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_NUM).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_NUM).to(device)
initialize_weights(gen)
initialize_weights(disc)

# Adam with beta1 = 0.5, the usual DCGAN setting.
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
# The discriminator ends in Sigmoid, so both losses are BCE on probabilities.
# BCE(D(G(z)), 1) = -log D(G(z)) is the non-saturating generator loss; MSE on a
# sigmoid output has vanishing gradients exactly when D(G(z)) is near 0, i.e.
# when the discriminator confidently rejects the fakes early in training.
loss_disc = nn.BCELoss()
loss_gen = nn.BCELoss()

# Fixed latent batch reused for every logged grid so samples are comparable over time.
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1, device=device)

In [None]:
# One train() call per epoch. gen, disc and both optimizers are updated in place,
# so re-running this cell continues training from the current weights.
for epoch in range(NUM_EPOCHS):
    train(dataloader,gen,disc,opt_gen,opt_disc,loss_gen,loss_disc)

  4%|▍         | 71/1583 [00:22<10:31,  2.39it/s]

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs