In [1]:
import os
import torch
from torch import nn, optim
from jcopdl.callback import set_config, Callback
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from IPython.display import Image

# Make sure the top-level "output" directory exists; exist_ok=True means an
# already-existing directory does not raise, so the cell can be re-run safely.
os.makedirs("output", exist_ok=True)

In [2]:
# Select the compute device: the first CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
device

'cpu'

# Dataset dan Dataloader (hanya trainset)

In [4]:
# Mini-batch size (also recorded in the run config further down).
bs = 64

# Preprocessing pipeline: single-channel image -> tensor -> normalised with
# mean 0.5 / std 0.5 so pixel values land in (-1, 1), which keeps training more stable.
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Training images only; ImageFolder reads them from the sub-folders of data/train.
train_set = datasets.ImageFolder(root="data/train", transform=transform)
trainloader = DataLoader(dataset=train_set, batch_size=bs, num_workers=4, shuffle=True)



In [5]:
# Display the dataset summary (size, root folder and the transform pipeline).
train_set

Dataset ImageFolder
    Number of datapoints: 2000
    Root location: data/train
    StandardTransform
Transform: Compose(
               Grayscale(num_output_channels=1)
               ToTensor()
               Normalize(mean=[0.5], std=[0.5])
           )

In [16]:
%%writefile model_cgan.py

import torch
from torch import nn
from jcopdl.layers import linear_block


class Discriminator(nn.Module):
    """Conditional discriminator scoring an (image, class label) pair.

    The label is embedded into an ``n_classes``-dimensional vector and
    concatenated with the flattened image before going through a stack of
    fully connected blocks; the last block is requested with a sigmoid
    activation and has a single output unit.

    Args:
        n_classes: number of distinct class labels (embedding rows and width).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.flatten = nn.Flatten()
        self.embed_label = nn.Embedding(n_classes, n_classes)

        # First block takes the 784 flattened image features plus the label
        # embedding; the hidden blocks then narrow 512 -> 256 -> 128.
        widths = (784 + n_classes, 512, 256, 128)
        hidden_blocks = [
            linear_block(n_in, n_out, activation="lrelu")
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        self.fc = nn.Sequential(
            *hidden_blocks,
            linear_block(128, 1, activation="sigmoid"),
        )

    def forward(self, x, y):
        """Run images ``x`` and integer labels ``y`` through the network.

        ``x`` is flattened per sample, ``y`` is embedded, and both are joined
        along dim 1 before the fully connected stack.
        """
        flat_img = self.flatten(x)
        label_vec = self.embed_label(y)
        joined = torch.cat([flat_img, label_vec], dim=1)
        return self.fc(joined)
    
class Generator(nn.Module):
    """Conditional generator mapping (noise, class label) to a flat image vector.

    The label is embedded into an ``n_classes``-dimensional vector and
    concatenated with the noise vector; the fully connected stack ends in a
    784-unit block requested with a tanh activation.

    Args:
        z_dim: size of the noise vector drawn per sample.
        n_classes: number of distinct class labels (embedding rows and width).
    """

    def __init__(self, z_dim, n_classes):
        super().__init__()
        self.z_dim = z_dim
        self.embed_label = nn.Embedding(n_classes, n_classes)
        self.fc = nn.Sequential(
            linear_block(z_dim + n_classes, 128, activation="lrelu"),
            linear_block(128, 256, activation="lrelu", batch_norm=True),
            linear_block(256, 512, activation="lrelu", batch_norm=True),
            linear_block(512, 1024, activation="lrelu", batch_norm=True),
            linear_block(1024, 784, activation="tanh")
        )

    def forward(self, x, y):
        """Generate from noise ``x`` (one row per sample) and integer labels ``y``."""
        y = self.embed_label(y)
        x = torch.cat([x, y], dim=1)
        return self.fc(x)

    def generate(self, labels, device):
        """Sample one output per entry of ``labels``.

        Args:
            labels: 1-D collection of integer class ids; a tensor, list or tuple.
            device: device on which the noise and the labels are placed. The
                module itself must already live on this device.

        Returns:
            The result of ``forward`` for freshly drawn standard-normal noise.
        """
        # Coerce to an integer tensor on `device` so callers may pass plain
        # Python sequences or tensors that currently sit on another device.
        # A long tensor that is already on `device` is passed through as-is.
        labels = torch.as_tensor(labels, dtype=torch.long, device=device)
        z = torch.randn((len(labels), self.z_dim), device=device)
        return self.forward(z, labels)

Overwriting model_cgan.py


In [17]:
# Run configuration: noise size, number of classes (taken from the class
# folders found by the dataset) and the batch size used by the loader.
config = set_config(
    dict(
        z_dim=100,
        n_classes=len(train_set.classes),
        batch_size=bs,
    )
)

In [18]:
# Check the number of classes recorded in the config.
config.n_classes

10

# Training preparation

- tidak perlu callback dan early stopping

In [27]:
from model_cgan import Discriminator, Generator

In [28]:
# Build both networks on the selected device (discriminator first, then generator).
D = Discriminator(n_classes=config.n_classes).to(device)
G = Generator(z_dim=config.z_dim, n_classes=config.n_classes).to(device)

# Binary cross-entropy on the discriminator's output.
criterion = nn.BCELoss()

# One AdamW optimizer per network, both with the same learning rate.
d_optimizer = optim.AdamW(params=D.parameters(), lr=2e-4)
g_optimizer = optim.AdamW(params=G.parameters(), lr=2e-4)

# Training

In [16]:
# The training loop writes sample grids to output/CGAN/ and checkpoints to
# model/CGAN/, so those are the directories that must exist before it runs.
os.makedirs("output/CGAN/", exist_ok=True)
os.makedirs("model/CGAN/", exist_ok=True)

In [None]:
max_epochs = 300
# Fixed labels reused at every checkpoint so the saved grids are comparable
# across epochs. `size` has to be a tuple: `(64)` is just the int 64.
fix_labels = torch.randint(config.n_classes, (64,), device=device)

for epoch in range(max_epochs):
    D.train()
    G.train()
    for real_img, labels in trainloader:
        n_data = real_img.shape[0]

        ## Real and fake images (fakes are conditioned on the batch's own labels)
        real_img, labels = real_img.to(device), labels.to(device)
        fake_img = G.generate(labels, device)

        ## Real and fake targets
        real = torch.ones((n_data, 1), device=device)
        fake = torch.zeros((n_data, 1), device=device)

        ## Training Discriminator
        d_optimizer.zero_grad()
        # Real image -> Discriminator -> target "real"
        output = D(real_img, labels)
        d_real_loss = criterion(output, real)

        # Fake image -> Discriminator -> target "fake"
        # (detach so this step does not backpropagate into the generator)
        output = D(fake_img.detach(), labels)
        d_fake_loss = criterion(output, fake)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        ## Training Generator
        g_optimizer.zero_grad()
        # Fake image -> Discriminator -> but with target "real"
        output = D(fake_img, labels)
        g_loss = criterion(output, real)
        g_loss.backward()
        g_optimizer.step()

    # Losses reported here are those of the last batch of the epoch.
    if epoch % 5 == 0:
        print(f"Epoch: {epoch:5} | D_loss: {d_loss.item()/2:.5f} | G_loss: {g_loss.item():.5f}")

    if epoch % 15 == 0:
        G.eval()
        # Sampling only: no autograd graph is needed for the preview grid.
        with torch.no_grad():
            fake_img = G.generate(fix_labels, device)
        # Zero-pad the epoch in the file name without rebinding the loop variable.
        save_image(fake_img.view(-1, 1, 28, 28), f"output/CGAN/{epoch:04d}.jpg", nrow=8, normalize=True)

        torch.save(D, "model/CGAN/discriminator.pth")
        torch.save(G, "model/CGAN/generator.pth")
        
        
        

Epoch:     0 | D_loss: 0.53385 | G_loss: 1.08915
Epoch:     5 | D_loss: 0.03304 | G_loss: 20.03904
Epoch:    10 | D_loss: 0.00174 | G_loss: 13.08765
Epoch:    15 | D_loss: 0.16183 | G_loss: 7.07853
Epoch:    20 | D_loss: 0.13427 | G_loss: 12.71363
Epoch:    25 | D_loss: 0.15090 | G_loss: 6.76249
Epoch:    30 | D_loss: 0.49544 | G_loss: 6.92281
Epoch:    35 | D_loss: 0.16435 | G_loss: 5.21741
Epoch:    40 | D_loss: 0.29345 | G_loss: 5.48549
Epoch:    45 | D_loss: 0.18758 | G_loss: 4.26791
Epoch:    50 | D_loss: 0.14180 | G_loss: 5.02771
Epoch:    55 | D_loss: 0.45637 | G_loss: 6.53112
Epoch:    60 | D_loss: 0.33287 | G_loss: 2.40704
Epoch:    65 | D_loss: 0.20383 | G_loss: 2.45407
Epoch:    70 | D_loss: 0.10645 | G_loss: 4.25715
Epoch:    75 | D_loss: 0.02405 | G_loss: 5.92743
Epoch:    80 | D_loss: 0.12346 | G_loss: 5.20580
Epoch:    85 | D_loss: 0.07463 | G_loss: 6.37552
Epoch:    90 | D_loss: 0.28937 | G_loss: 4.03237
Epoch:    95 | D_loss: 0.14460 | G_loss: 9.43227


# Generate image

In [None]:
model = "CGAN"
# The checkpoint is a whole pickled module (saved with torch.save(G, ...)), so
# model_cgan must be importable and only trusted files should be loaded this way.
# NOTE(review): recent PyTorch releases may need weights_only=False here - confirm
# against the installed version.
# Load onto `device`, because the noise and labels below are created there too.
G = torch.load(f"model/{model}/generator.pth", map_location=device).eval()
# 64 samples, all conditioned on class index 6.
labels = torch.ones(64, device=device).to(int) * 6
with torch.no_grad():  # inference only
    fake_img = G.generate(labels, device)
save_image(fake_img.view(-1, 1, 28, 28), f"output/{model}.jpg", nrow=8, normalize=True)
Image(f"output/{model}.jpg")