Sources:
- [Building a Vanilla GAN with PyTorch (CodeX, Medium)](https://medium.com/codex/building-a-vanilla-gan-with-pytorch-ffdf26275b70)
- [PyTorch Vanilla GAN (Kaggle notebook by rafat97)](https://www.kaggle.com/code/rafat97/pytorch-vanilla-gan)

In [None]:
import os
import time
import torch
from torch import nn, optim
!pip install torchsummary
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from PIL import Image
from typing import Any, Callable, Optional
from tqdm import tqdm

In [None]:
# Runtime and data-loading configuration.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

DATASET_PATH = '/kaggle/input/batik-dataset-for-gan/Dataset Final'
BATCH_SIZE, NUM_WORKERS = 32, 4
SHUFFLE, PIN_MEMORY = True, False

# Model geometry: square images of RESOLUTION x RESOLUTION pixels generated
# from LATENT_DIM-dimensional noise vectors.
RESOLUTION = 128
LATENT_DIM = 512

# Optimization hyper-parameters (shared by generator and discriminator).
LEARNING_RATE = 1e-4
NUM_EPOCHS = 200

In [None]:
class BatikGANDataset(Dataset):
    '''
    BatikGAN dataset with lazy loading: each image is read from disk only when
    it is accessed.

    Args:
        path (str): Path to the image directory. Files directly inside it whose
            extension is .png, .jpg or .jpeg (case-insensitive) are used.
        transform (callable, optional): Transform applied to each RGB PIL.Image.
            Default value is None (the PIL.Image is returned unchanged).
    '''

    # Compared against the lower-cased filename so e.g. "IMG_01.JPG" is found.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(
        self,
        path: str,
        transform: Optional[Callable[[Image.Image], Any]] = None
    ) -> None:
        super(BatikGANDataset, self).__init__()
        self.path = path
        self.transform = transform
        # Sorted so the index -> file mapping is reproducible; os.listdir()
        # returns entries in arbitrary, filesystem-dependent order.
        self.files = sorted(
            f for f in os.listdir(self.path)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> Any:
        '''
        Load image number `index` as RGB and apply `transform` if one was given.

        Returns the transform's output, or the RGB PIL.Image when no transform
        is set.
        '''
        img_path = os.path.join(self.path, self.files[index])
        # The context manager releases the file handle; convert() returns a
        # new, fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

In [None]:
class Generator(nn.Module):
    '''
    Fully-connected (MLP) GAN generator.

    Maps latent vectors of shape (batch, latent_dim) to flattened images of
    shape (batch, resolution * resolution * channels) with values in [-1, 1]
    (final Tanh). Callers reshape the output to
    (batch, channels, resolution, resolution) themselves.

    Args:
        resolution (int): Height/width of the generated square image.
        latent_dim (int): Length of each input noise vector.
        hidden_dim (int): Base width of the hidden layers. Default: 512.
        channels (int): Number of image channels. Default: 3.
        bn_eps (float): `eps` of every BatchNorm1d layer. Default: 0.8, which
            reproduces the original model: it passed 0.8 positionally, and the
            second positional parameter of nn.BatchNorm1d is `eps`, not
            `momentum`.
    '''

    def __init__(self, resolution, latent_dim, hidden_dim=512, channels=3, bn_eps=0.8):
        super(Generator, self).__init__()
        output_dim = resolution * resolution * channels

        # Widths: latent -> h -> 2h -> 2h -> h -> h -> h/2 -> flat image.
        self.layers = nn.Sequential(
            self.gen_block(latent_dim, hidden_dim, eps=bn_eps),
            self.gen_block(hidden_dim, hidden_dim * 2, eps=bn_eps),
            self.gen_block(hidden_dim * 2, hidden_dim * 2, eps=bn_eps),
            self.gen_block(hidden_dim * 2, hidden_dim, eps=bn_eps),
            self.gen_block(hidden_dim, hidden_dim, eps=bn_eps),
            self.gen_block(hidden_dim, hidden_dim // 2, eps=bn_eps),

            nn.Linear(hidden_dim // 2, output_dim),
            nn.Tanh()
        )

    def gen_block(self, input_dim, output_dim, eps=0.8):
        '''Linear (no bias; BatchNorm adds its own shift) -> BatchNorm1d -> LeakyReLU(0.2).'''
        return nn.Sequential(
            nn.Linear(input_dim, output_dim, bias=False),
            # NOTE(review): eps=0.8 is far above PyTorch's default (1e-5); it is
            # kept as the default so previously trained weights behave the same.
            nn.BatchNorm1d(output_dim, eps=eps),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        '''Map latent vectors (batch, latent_dim) to flat images in [-1, 1].'''
        return self.layers(x)

In [None]:
# Build the 1-channel (grayscale) generator and print a per-layer summary.
gen = Generator(resolution=RESOLUTION, latent_dim=LATENT_DIM, channels=1).to(DEVICE)
summary(gen, (LATENT_DIM,), batch_size=BATCH_SIZE)

In [None]:
class Discriminator(nn.Module):
    '''
    Fully-connected (MLP) GAN discriminator.

    Takes flattened images of shape (batch, resolution * resolution * channels)
    and returns one raw logit per sample, shape (batch, 1). No sigmoid is
    applied, so it must be paired with a logits-based loss.

    Args:
        resolution (int): Height/width of the square input image.
        hidden_dim (int): Base width of the hidden layers. Default: 512.
        channels (int): Number of image channels. Default: 3.
    '''

    def __init__(self, resolution, hidden_dim=512, channels=3):
        super(Discriminator, self).__init__()
        # Widths: flat image -> 4h -> 2h -> h -> h/2, then a 1-unit logit head.
        widths = [
            resolution * resolution * channels,
            hidden_dim * 4,
            hidden_dim * 2,
            hidden_dim,
            hidden_dim // 2,
        ]
        # Dropout after the first three blocks only.
        dropouts = [0.3, 0.3, 0.3, 0]
        blocks = [
            self.disc_block(d_in, d_out, dropout=p)
            for d_in, d_out, p in zip(widths[:-1], widths[1:], dropouts)
        ]
        self.layers = nn.Sequential(*blocks, nn.Linear(widths[-1], 1))

    def disc_block(self, input_dim, output_dim, dropout=0):
        '''Linear -> LeakyReLU(0.2, in place), plus Dropout when `dropout` > 0.'''
        linear = nn.Linear(input_dim, output_dim)
        activation = nn.LeakyReLU(0.2, inplace=True)
        if dropout > 0:
            return nn.Sequential(linear, activation, nn.Dropout(dropout))
        return nn.Sequential(linear, activation)

    def forward(self, x):
        '''Return one logit per flattened input image.'''
        return self.layers(x)

In [None]:
# Build the matching 1-channel discriminator; its input is a flattened image.
disc = Discriminator(resolution=RESOLUTION, channels=1).to(DEVICE)
summary(disc, (RESOLUTION * RESOLUTION * 1,), batch_size=BATCH_SIZE)

In [None]:
def noise(batch_size, latent_dim, device):
    '''
    Sample a batch of standard-normal latent vectors.

    Args:
        batch_size (int): Number of vectors to draw.
        latent_dim (int): Length of each vector.
        device (torch.device | str): Device the tensor is allocated on.

    Returns:
        torch.Tensor of shape (batch_size, latent_dim).
    '''
    return torch.randn(batch_size, latent_dim, device=device)

In [None]:
def get_loader(resolution):
    '''
    Build a DataLoader over the images in DATASET_PATH.

    Each image is resized to (resolution, resolution), converted to 1-channel
    grayscale, turned into a tensor and normalized from [0, 1] to [-1, 1]
    (mean 0.5, std 0.5), the same range as the generator's Tanh output.
    Batch size, shuffling, worker count and memory pinning come from the
    module-level configuration constants.
    '''
    pipeline = [
        transforms.Resize((resolution, resolution)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # In-place is safe: ToTensor has just produced a fresh tensor.
        transforms.Normalize([0.5], [0.5], inplace=True),
    ]
    return DataLoader(
        BatikGANDataset(DATASET_PATH, transform=transforms.Compose(pipeline)),
        batch_size=BATCH_SIZE,
        shuffle=SHUFFLE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )

In [None]:
# The discriminator emits raw logits, so BCE is used with its built-in sigmoid.
criterion = nn.BCEWithLogitsLoss()

# One Adam optimizer per network (generator first, then discriminator),
# both at the same learning rate.
optim_g, optim_d = (optim.Adam(net.parameters(), lr=LEARNING_RATE) for net in (gen, disc))

In [None]:
def save_checkpoint(generator, discriminator, optim_g, optim_d, epoch, folder="checkpoints"):
    '''
    Save a full training checkpoint to `{folder}/checkpoint_epoch{epoch}.pt`.

    `folder` is created if missing. The saved dict holds the state_dicts of
    both networks ('generator', 'discriminator') and of both optimizers
    ('optim_g', 'optim_d'), plus the epoch number under 'epoch'.
    '''
    os.makedirs(folder, exist_ok=True)

    state = dict(
        generator=generator.state_dict(),
        discriminator=discriminator.state_dict(),
        optim_g=optim_g.state_dict(),
        optim_d=optim_d.state_dict(),
        epoch=epoch,
    )
    target = f"{folder}/checkpoint_epoch{epoch}.pt"
    torch.save(state, target)
    print(f"Saved checkpoint: {target}")

In [None]:
def save_result(generator, path="generator_final.pt"):
    '''
    Save the final generator weights (state_dict only).

    Args:
        generator (nn.Module): The trained generator.
        path (str): Destination file. Default: "generator_final.pt" in the
            current working directory.
    '''
    torch.save(generator.state_dict(), path)
    print(f"Final generator model saved as {path}")

In [None]:
G_losses = []  # mean generator loss of each epoch, appended by train()
D_losses = []  # mean discriminator loss of each epoch, appended by train()


def _save_samples(generator, fixed_noise, channels, epoch, folder="generated_images"):
    '''Save a 4-column grid generated from `fixed_noise` as `{folder}/epoch_{epoch}.png`.'''
    was_training = generator.training
    # Sample in eval mode so BatchNorm uses (and does not update) its running
    # statistics; the previous mode is restored afterwards.
    generator.eval()
    try:
        with torch.no_grad():
            samples = generator(fixed_noise).reshape(-1, channels, RESOLUTION, RESOLUTION)
    finally:
        generator.train(was_training)
    samples = (samples + 1) / 2  # Tanh range [-1, 1] -> [0, 1]
    os.makedirs(folder, exist_ok=True)
    save_image(samples, f"{folder}/epoch_{epoch}.png", nrow=4)
    print(f"Saved generated image: epoch_{epoch}.png")


def train(channels=3):
    '''
    Train the global `gen` / `disc` pair for NUM_EPOCHS epochs.

    Each step updates the discriminator on real (label 1) and detached fake
    (label 0) batches, then updates the generator so the discriminator labels
    its fakes as real. Per-epoch mean losses are appended to G_losses and
    D_losses. Every 20 epochs a sample grid and a checkpoint are written;
    after the last epoch the generator weights are saved via save_result().

    Args:
        channels (int): Channel count of the generated images, used only to
            reshape the flat generator output for the sample grid. Must match
            the channel count `gen` was built with.

    Raises:
        RuntimeError: if an epoch yields no batch of at least 2 images.
    '''
    fixed_noise = torch.randn(16, LATENT_DIM, device=DEVICE)
    dataloader = get_loader(RESOLUTION)

    gen.train()
    disc.train()

    print("Starting Training...\n")

    for epoch in range(NUM_EPOCHS):
        gen_loss_epoch = 0.0
        disc_loss_epoch = 0.0
        steps = 0
        start = time.time()

        pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for real in pbar:
            batch_size = real.size(0)
            # BatchNorm1d in training mode raises on a single-sample batch,
            # which the last (partial) batch of an epoch can be.
            if batch_size < 2:
                continue
            real = real.to(DEVICE).view(batch_size, -1)

            ### Train Discriminator ###
            optim_d.zero_grad()

            real_labels = torch.ones((batch_size, 1), device=DEVICE)
            real_loss = criterion(disc(real), real_labels)

            z = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
            fake = gen(z)
            fake_labels = torch.zeros((batch_size, 1), device=DEVICE)
            # detach(): the discriminator update must not backprop into gen.
            fake_loss = criterion(disc(fake.detach()), fake_labels)

            disc_loss = (real_loss + fake_loss) / 2
            disc_loss.backward()
            optim_d.step()

            ### Train Generator ###
            optim_g.zero_grad()

            # Non-saturating loss: train gen so disc classifies fakes as real.
            gen_labels = torch.ones((batch_size, 1), device=DEVICE)
            gen_loss = criterion(disc(fake), gen_labels)

            gen_loss.backward()
            optim_g.step()

            # Logging
            gen_loss_epoch += gen_loss.item()
            disc_loss_epoch += disc_loss.item()
            steps += 1

            pbar.set_postfix({
                "D_loss": f"{disc_loss.item():.4f}",
                "G_loss": f"{gen_loss.item():.4f}"
            })

        if steps == 0:
            raise RuntimeError(
                "No batch with at least 2 images was produced; "
                "the dataset is too small to train on."
            )

        print(f"[Epoch {epoch+1}/{NUM_EPOCHS}] Time: {time.time()-start:.2f}s "
              f"| G Loss: {gen_loss_epoch/steps:.4f} | D Loss: {disc_loss_epoch/steps:.4f}")

        G_losses.append(gen_loss_epoch / steps)
        D_losses.append(disc_loss_epoch / steps)

        # Save sample images and a checkpoint every 20 epochs.
        if (epoch + 1) % 20 == 0:
            _save_samples(gen, fixed_noise, channels, epoch + 1)
            save_checkpoint(gen, disc, optim_g, optim_d, epoch + 1)
        if epoch == NUM_EPOCHS - 1:
            save_result(gen)

In [None]:
# train(channels=1)

In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch mean losses recorded by train() and save the figure.
fig, ax = plt.subplots(figsize=(10, 5))
for history, label in ((G_losses, "Generator Loss"), (D_losses, "Discriminator Loss")):
    ax.plot(history, label=label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Generator and Discriminator Loss During Training")
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig("loss_plot.png")
plt.show()

In [None]:
! pip install -q onnx torchinfo torchmetrics[image]

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchmetrics.image.perceptual_path_length import PerceptualPathLength
from torchvision.utils import save_image
import math

In [None]:
# Export real and generated images as individual PNG files for the
# folder-based FID / Inception Score evaluation.
REAL_DIR = 'fid_images/real'
FAKE_DIR = 'fid_images/fake'
PT_PATH = '/kaggle/input/vanillagan-bnw/pytorch/default/1/generator_final.pt'

generator = Generator(RESOLUTION, LATENT_DIM, channels=1).to(DEVICE)
generator.load_state_dict(torch.load(PT_PATH, map_location=DEVICE))
# Inference mode: BatchNorm uses its trained running statistics (and accepts a
# final batch of size 1) instead of per-batch statistics.
generator.eval()

os.makedirs(REAL_DIR, exist_ok=True)
os.makedirs(FAKE_DIR, exist_ok=True)

test_loader = get_loader(RESOLUTION)

counter = 0
# No autograd graph is needed for export.
with torch.no_grad():
    for real_img in tqdm(test_loader, desc='Saving images'):
        real_img = real_img.to(DEVICE)
        b = real_img.size(0)

        z = torch.randn(b, LATENT_DIM, device=DEVICE)
        fake_img = generator(z).view(b, 1, RESOLUTION, RESOLUTION)

        # Denormalize from [-1, 1] to [0, 1]
        real_img = (real_img * 0.5 + 0.5).clamp(0, 1)
        fake_img = (fake_img * 0.5 + 0.5).clamp(0, 1)

        for i in range(b):
            save_image(real_img[i], f"{REAL_DIR}/real_{counter + i}.png")
            save_image(fake_img[i], f"{FAKE_DIR}/fake_{counter + i}.png")

        counter += b
        del z, real_img, fake_img

In [None]:
from PIL import Image
from torchvision import transforms
import os

fid = FrechetInceptionDistance(feature=2048).to(DEVICE)

transform = transforms.Compose([
    transforms.Resize((RESOLUTION, RESOLUTION)),
    transforms.ToTensor(),
])


def _iter_uint8_batches(folder, to_grayscale=False, batch_size=64):
    '''
    Yield uint8 batches of shape (B, 3, RESOLUTION, RESOLUTION) on DEVICE for
    every file in `folder`, in sorted filename order.

    Each image is optionally converted to grayscale ("L") first, then to RGB,
    resized with `transform` and scaled from [0, 1] to [0, 255]. Batching lets
    the Inception network process many images per forward pass.
    '''
    files = sorted(os.listdir(folder))
    for begin in range(0, len(files), batch_size):
        tensors = []
        for name in files[begin:begin + batch_size]:
            with Image.open(os.path.join(folder, name)) as raw:
                img = raw.convert("L") if to_grayscale else raw
                tensors.append(transform(img.convert("RGB")))
        batch = torch.stack(tensors).to(DEVICE)
        yield (batch * 255).clamp(0, 255).to(torch.uint8)


# Real images are forced through grayscale, matching the 1-channel model.
for batch in _iter_uint8_batches(REAL_DIR, to_grayscale=True):
    fid.update(batch, real=True)

for batch in _iter_uint8_batches(FAKE_DIR):
    fid.update(batch, real=False)

# Compute FID
fid_score = fid.compute().item()
print(f'FID from folders: {fid_score:.4f}')

In [None]:
from torchmetrics.image.inception import InceptionScore

is_metric = InceptionScore().to(DEVICE)
transform = transforms.Compose([
    transforms.Resize((RESOLUTION, RESOLUTION)),
    transforms.ToTensor(),
])


def _to_uint8_batch(path):
    '''Load one image file as a (1, 3, H, W) uint8 tensor on DEVICE.'''
    rgb = Image.open(path).convert("RGB")
    scaled = transform(rgb).unsqueeze(0).to(DEVICE) * 255
    return scaled.clamp(0, 255).to(torch.uint8)


# Inception Score is computed from the generated images only.
for name in os.listdir(FAKE_DIR):
    is_metric.update(_to_uint8_batch(os.path.join(FAKE_DIR, name)))

is_score = is_metric.compute()
mean, std = is_score[0].item(), is_score[1].item()
print(f"Inception Score: {mean:.4f} ± {std:.4f}")


In [None]:
from torchmetrics.image.perceptual_path_length import PerceptualPathLength
class PPLWrapper(nn.Module):
    '''
    Adapter exposing a generator through the interface torchmetrics'
    PerceptualPathLength expects: `sample(n)` returns latent vectors and
    `forward(z)` returns images.

    The generator emits flat vectors in [-1, 1] (final Tanh), while PPL
    compares images with an LPIPS network. forward() therefore reshapes the
    output to (B, C, RESOLUTION, RESOLUTION), repeats a single gray channel to
    3 channels, and rescales to [0, 255].
    NOTE(review): the [0, 255] range follows torchmetrics' PPL contract (it
    rescales generator outputs to LPIPS' [-1, 1] itself); confirm against the
    installed torchmetrics version.

    Args:
        model (nn.Module, optional): Generator to wrap. Defaults to the
            module-level `generator`.
    '''

    def __init__(self, model=None):
        super(PPLWrapper, self).__init__()
        self.model = generator if model is None else model

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        flat = self.model(z)
        images = flat.reshape(flat.size(0), -1, RESOLUTION, RESOLUTION)
        if images.size(1) == 1:
            # LPIPS backbones take 3-channel input.
            images = images.repeat(1, 3, 1, 1)
        return (images + 1) * 127.5  # [-1, 1] -> [0, 255]

    def sample(self, num_samples: int) -> torch.Tensor:
        '''Draw `num_samples` standard-normal latent vectors on DEVICE.'''
        return torch.randn(num_samples, LATENT_DIM, device=DEVICE)

# Evaluate the Perceptual Path Length of the trained generator. eval() makes
# BatchNorm use its running statistics rather than per-batch statistics.
generator.eval()
ppl = PerceptualPathLength().to(DEVICE)

# The metric's result is unpacked as (mean, std, raw values); only the mean
# and std are printed, ppl_raw is kept for inspection.
ppl_mean, ppl_std, ppl_raw = ppl(PPLWrapper())
print(f'PPL: {ppl_mean.item()} +/- {ppl_std.item()}')

In [None]:
generator.eval()

# Generate 64 images in inference mode, denormalized from [-1, 1] to [0, 1].
z = torch.randn(64, LATENT_DIM, device=DEVICE)
with torch.no_grad():
    fake_img = (
        generator(z)
        .view(64, 1, RESOLUTION, RESOLUTION)
        .mul(0.5)
        .add(0.5)
        .clamp(0, 1)
    )

# Save all of them as a single 8x8 grid.
save_image(fake_img, "batik_grid.png", nrow=8)
print("Saved batik grid image as 'batik_grid.png'")

In [None]:
# Export the generator to ONNX. The batch dimension (axis 0) of both the
# latent input and the generator output is marked dynamic, so the exported
# graph accepts any batch size even though it is traced with a batch of one.
dummy_input = torch.randn(1, LATENT_DIM, device=DEVICE)

torch.onnx.export(
    generator,
    dummy_input,
    "generator.onnx",
    input_names=["latent_vector"],
    output_names=["generated_image"],
    dynamic_axes={name: {0: "batch_size"} for name in ("latent_vector", "generated_image")},
)