In [0]:
!pip install jcopdl gdown
!gdown https://drive.google.com/uc?id=1KaiwyyYRGW8FbvSd4Feg1i1YW2k2s30u
!unzip /content/celebA_redux.zip

In [3]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config

# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
device

device(type='cuda', index=0)

# Dataset & DataLoader (train set only)

In [0]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [0]:
bs = 64

# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 per RGB channel then
# maps them to (-1, 1), the same range as the generator's tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# ImageFolder expects the images inside subdirectories of root; the class labels it yields
# are discarded by the training loop, so every image is simply a "real" sample.
train_set = datasets.ImageFolder(root="/content/celebA_redux", transform=transform)
trainloader = DataLoader(dataset=train_set, batch_size=bs, shuffle=True, num_workers=4)

# Architecture & Config

In [1]:
%%writefile model_wdcgan.py
import torch
from torch import nn
from jcopdl.layers import conv_block, tconv_block, linear_block

# Settings shared by both resampling helpers. With kernel=4, stride=2, pad=1 a convolution
# halves an even H/W and a transposed convolution doubles it (assuming conv_block/tconv_block
# pass these straight to nn.Conv2d / nn.ConvTranspose2d). Pooling is off: the stride resamples.
_RESAMPLE_KWARGS = dict(kernel=4, stride=2, pad=1, bias=False, pool_type=None)


def conv(c_in, c_out, batch_norm=True, activation="lrelu"):
    """Stride-2 downsampling block ``c_in -> c_out`` built with jcopdl's ``conv_block``."""
    return conv_block(c_in, c_out, batch_norm=batch_norm, activation=activation, **_RESAMPLE_KWARGS)


def tconv(c_in, c_out, batch_norm=True, activation="lrelu"):
    """Stride-2 upsampling block ``c_in -> c_out`` built with jcopdl's ``tconv_block``."""
    return tconv_block(c_in, c_out, batch_norm=batch_norm, activation=activation, **_RESAMPLE_KWARGS)


class Critic(nn.Module):
    """WGAN critic: scores each image with a single unbounded real number.

    Layout: four stride-2 ``conv`` blocks (3 -> 32 -> 64 -> 128 -> 256 channels, the first
    one without batch norm), then a 4x4 stride-1 unpadded convolution down to 1 channel with
    no activation, then ``nn.Flatten``. For 64x64 inputs (presumably the dataset's image
    size - TODO confirm) the spatial size goes 64 -> 32 -> 16 -> 8 -> 4 -> 1, so ``forward``
    returns an (n, 1) tensor, matching the (n, 1) labels used during training.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            conv(3, 32, batch_norm=False),
            conv(32, 64),
            conv(64, 128),
            conv(128, 256),
            # 4x4 valid convolution collapses the remaining 4x4 map to 1x1. No activation:
            # a Wasserstein critic outputs raw scores, not probabilities.
            conv_block(256, 1, kernel=4, stride=1, pad=0, bias=False, activation=None, pool_type=None),
            nn.Flatten()
        )

    def forward(self, x):
        """Return the critic score for every image in ``x``."""
        return self.conv(x)

    def clip_weights(self, vmin=-0.01, vmax=0.01):
        """Clamp every critic parameter into ``[vmin, vmax]`` in place.

        This is the WGAN weight-clipping constraint; call it after each critic optimizer
        step. It iterates over ``self.parameters()``, so normalization-layer affine
        parameters are clipped as well.

        Args:
            vmin: lower clipping bound.
            vmax: upper clipping bound.
        """
        # Mutate parameters under no_grad instead of through the discouraged `.data` attribute.
        with torch.no_grad():
            for p in self.parameters():
                p.clamp_(vmin, vmax)


class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors of shape (n, z_dim, 1, 1) to images.

    Every transposed-convolution block uses kernel 4 / stride 2 / pad 1, which doubles H and W:
    1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64, with channels z_dim -> 512 -> 256 -> 128 -> 64 -> 32 -> 3.
    The last block uses tanh and no batch norm, so pixel values land in (-1, 1).
    """

    def __init__(self, z_dim):
        super().__init__()
        self.z_dim = z_dim

        # Project the 1x1 latent "image" up to 512 feature maps of size 2x2.
        layers = [
            tconv_block(z_dim, 512, kernel=4, stride=2, pad=1, bias=False, activation="lrelu", pool_type=None)
        ]
        # Intermediate upsampling blocks: 512 -> 256 -> 128 -> 64 -> 32 channels.
        widths = [512, 256, 128, 64, 32]
        layers += [tconv(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        # Final RGB block.
        layers.append(tconv(32, 3, activation="tanh", batch_norm=False))

        self.tconv = nn.Sequential(*layers)

    def forward(self, x):
        """Decode latent tensor ``x`` into an image batch."""
        return self.tconv(x)

    def generate(self, n, device):
        """Sample ``n`` latent vectors z ~ N(0, I) on ``device`` and decode them into images."""
        latent = torch.randn(n, self.z_dim, 1, 1, device=device)
        return self.tconv(latent)

Writing model_wdcgan.py


In [0]:
# Hyper-parameters bundled into a jcopdl config object (read later as attributes, e.g. config.z_dim).
config = set_config(dict(
    z_dim=100,        # length of the generator's latent vector
    batch_size=bs,
))

# Training Preparation -> MCOC (Model, Criterion, Optimizer, Callback)

In [0]:
from model_wdcgan import Critic, Generator

In [0]:
def wasserstein_loss(output, target):
    """Wasserstein (critic-score) loss: the mean of ``output * target``.

    With target = -1 for real samples and +1 for fake samples, minimizing this raises the
    critic's scores on real data and lowers them on fake data. The result is the mean of
    the element-wise product (not the product of the two means), so it stays correct even
    when a batch mixes real and fake labels.

    Args:
        output: critic scores, e.g. shape (n, 1).
        target: +/-1 labels broadcastable to ``output``.

    Returns:
        A 0-dim tensor.
    """
    return (output * target).mean()

In [0]:
# Build both networks directly on the selected device.
D = Critic().to(device)
G = Generator(config.z_dim).to(device)

criterion = wasserstein_loss  # critic-score objective: no sigmoid / BCE

# One optimizer per network, both RMSprop with the same learning rate.
d_optimizer, g_optimizer = (
    optim.RMSprop(D.parameters(), lr=1e-4),
    optim.RMSprop(G.parameters(), lr=1e-4),
)

# Training

In [0]:
# Optional: uncomment to delete sample images left over from a previous run before retraining.
# !rm -rf /content/output

In [0]:
import os
from torchvision.utils import save_image
from tqdm.auto import tqdm

# Create the sample-image and checkpoint directories; exist_ok keeps re-runs from failing.
for out_dir in ("output/WDCGAN/", "model/WDCGAN/"):
    os.makedirs(out_dir, exist_ok=True)

In [0]:
# WGAN training with weight clipping. The critic is updated on every batch and the generator
# on every `n_critic`-th batch. Losses are printed every `log_every` epochs; a grid of 64
# samples plus both model checkpoints is written every `sample_every` epochs.
max_epochs = 1000
n_critic = 5       # critic updates per generator update
log_every = 5      # epochs between loss printouts
sample_every = 15  # epochs between sample grids / checkpoints

for epoch in range(max_epochs):
    D.train()
    G.train()
    for i, (real_img, _) in enumerate(trainloader):
        n_data = real_img.shape[0]
        update_generator = i % n_critic == 0

        ## Real and Fake Images
        real_img = real_img.to(device)
        # Only record the generator's autograd graph on batches whose generator step will
        # use it; the random draws (and thus the training trajectory) are unchanged.
        with torch.set_grad_enabled(update_generator):
            fake_img = G.generate(n_data, device)

        ## Labels: -1 = real, +1 = fake, so criterion(D(x), real) == -mean(D(x))
        real = -torch.ones((n_data, 1), device=device)
        fake = torch.ones((n_data, 1), device=device)

        ## Training Critic ##
        d_optimizer.zero_grad()
        d_real_loss = criterion(D(real_img), real)
        # detach: the critic step must not push gradients into the generator.
        d_fake_loss = criterion(D(fake_img.detach()), fake)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Weight clipping enforces the critic's Lipschitz constraint.
        D.clip_weights()

        if update_generator:
            ## Training Generator ##
            g_optimizer.zero_grad()
            # Generator wants its fakes scored like real images.
            g_loss = criterion(D(fake_img), real)
            g_loss.backward()
            g_optimizer.step()

    if epoch % log_every == 0:
        # d_loss / 2 is the average of the real and fake critic terms.
        print(f"Epoch: {epoch:5} | D_loss: {d_loss.item() / 2:.5f} | G_loss: {g_loss.item():.5f}")

    if epoch % sample_every == 0:
        G.eval()
        tag = str(epoch).zfill(4)  # separate name: don't shadow the loop variable with a str
        with torch.no_grad():      # sampling only: no autograd graph needed
            samples = G.generate(64, device)
        save_image(samples, f"output/WDCGAN/{tag}.jpg", nrow=8, normalize=True)

        # NOTE(review): saves whole pickled modules; loading them requires model_wdcgan to be
        # importable. Saving state_dict() would be more portable if nothing depends on this format.
        torch.save(D, "model/WDCGAN/critic.pth")
        torch.save(G, "model/WDCGAN/generator.pth")

Epoch:     0 | D_loss: -0.18430 | G_loss: 0.18856
Epoch:     5 | D_loss: -0.19710 | G_loss: 0.19879
Epoch:    10 | D_loss: -0.19709 | G_loss: 0.19913
Epoch:    15 | D_loss: -0.19839 | G_loss: 0.20005
Epoch:    20 | D_loss: -0.19784 | G_loss: 0.19978
Epoch:    25 | D_loss: -0.19721 | G_loss: 0.19912
Epoch:    30 | D_loss: -0.19846 | G_loss: 0.20011
Epoch:    35 | D_loss: -0.04034 | G_loss: 0.16971
Epoch:    40 | D_loss: -0.19820 | G_loss: 0.20010
Epoch:    45 | D_loss: -0.19777 | G_loss: 0.19922
Epoch:    50 | D_loss: -0.19442 | G_loss: 0.19559
Epoch:    55 | D_loss: -0.19797 | G_loss: 0.19981
Epoch:    60 | D_loss: -0.19739 | G_loss: 0.19946
Epoch:    65 | D_loss: -0.19824 | G_loss: 0.19982
Epoch:    70 | D_loss: -0.19810 | G_loss: 0.19980
Epoch:    75 | D_loss: -0.19585 | G_loss: 0.19823
Epoch:    80 | D_loss: -0.18828 | G_loss: 0.18834
Epoch:    85 | D_loss: -0.19728 | G_loss: 0.19964
Epoch:    90 | D_loss: -0.01078 | G_loss: -0.19155
Epoch:    95 | D_loss: -0.19719 | G_loss: 0.19912