#### Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import os
from sklearn.manifold import TSNE
import numpy as np
import torchvision.utils as vutils

# 1. Hyperparameters & Device

In [5]:
# Training configuration shared by the cells below.
batch_size = 64
latent_dim = 100
lr_initial = 1e-3
num_epochs = 2       # adjust as needed
image_size = 64
# Save reconstruction & sample images every this many epochs. The training
# loop only saves when `epoch % save_sample_every == 0`, so with the values
# above (2 epochs, every 5) nothing is saved unless one of them is changed.
save_sample_every = 5

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device :", device)
# torch.cuda.get_device_name() raises when no CUDA device is present, so it is
# only queried when the selected device really is a GPU.
if device.type == "cuda":
    print("Device       :", torch.cuda.get_device_name(0))
else:
    print("Device       : CPU (no CUDA device available)")
print("Cude Version :", torch.version.cuda)

Using device : cuda
Device       : NVIDIA GeForce GTX 1060 6GB
Cude Version : 11.8


# 2. Dataset & DataLoader

In [None]:
# Root folder of the image dataset; it must contain sub-folders holding the
# images (ImageFolder layout). Set the VAE_DATA_DIR environment variable to
# point the notebook at another location without editing this cell; the
# original local path is kept as the default.
data_dir = os.environ.get("VAE_DATA_DIR", r"C:\Users\dawwi\Downloads\Dataset")

# Resize followed by a center crop yields square image_size x image_size
# inputs; ToTensor converts them to float tensors.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(root=data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# 3. Definisi Encoder & Decoder

In [None]:
class Encoder(nn.Module):
    """Convolutional VAE encoder producing the posterior mean and log-variance.

    Two stride-2 convolutions each halve the spatial size, so a square input
    of side ``image_size`` is reduced to ``image_size // 4`` before being
    flattened and projected to the latent space.

    Args:
        latent_dim: Size of the latent vector.
        image_size: Side length of the (square) input images. Defaults to 64,
            which reproduces the original fixed ``64 * 16 * 16`` layout.

    Raises:
        ValueError: If ``image_size`` is smaller than 4, which would leave no
            spatial positions after the two down-sampling convolutions.
    """

    def __init__(self, latent_dim, image_size=64):
        super().__init__()
        feature_size = image_size // 4
        if feature_size < 1:
            raise ValueError(f"image_size must be at least 4, got {image_size}")

        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),  # (B, 32, image_size/2, image_size/2)
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1), # (B, 64, image_size/4, image_size/4)
            nn.ReLU(True),
            nn.Flatten()
        )
        # Size of the flattened feature map fed to both linear heads.
        flat_features = 64 * feature_size * feature_size
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, latent_dim)``."""
        x = self.conv(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

class Decoder(nn.Module):
    """Transposed-convolution VAE decoder mapping latent vectors to images.

    A linear layer expands the latent vector to a 64-channel feature map of
    side ``image_size // 4``; two stride-2 transposed convolutions then double
    the spatial size twice, and a sigmoid keeps the output in ``[0, 1]``.

    Args:
        latent_dim: Size of the latent vector.
        image_size: Side length of the (square) output images. Defaults to 64,
            which reproduces the original fixed ``64 * 16 * 16`` layout.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4; the two
            up-sampling layers can only produce sizes of the form ``4 * k``.
    """

    def __init__(self, latent_dim, image_size=64):
        super().__init__()
        if image_size < 4 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size}"
            )
        # Side of the feature map the linear layer output is reshaped to.
        self.feature_size = image_size // 4
        self.fc = nn.Linear(latent_dim, 64 * self.feature_size * self.feature_size)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # (B, 32, image_size/2, image_size/2)
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),   # (B, 3, image_size, image_size)
            nn.Sigmoid()
        )

    def forward(self, z):
        """Decode ``z`` of shape ``(batch, latent_dim)`` into a batch of images."""
        x = self.fc(z).view(-1, 64, self.feature_size, self.feature_size)
        x = self.deconv(x)
        return x

def reparameterize(mu, logvar):
    """Draw a latent sample ``mu + sigma * noise`` with ``noise ~ N(0, I)``.

    ``logvar`` holds the log-variance, so the standard deviation is
    ``exp(0.5 * logvar)``. The noise tensor matches ``logvar`` in shape,
    dtype and device.
    """
    sigma = logvar.mul(0.5).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

# 4. Inisialisasi Model, Optimizer, Scheduler

In [None]:
# Build both halves of the VAE and place them on the selected device.
encoder = Encoder(latent_dim)
encoder = encoder.to(device)
decoder = Decoder(latent_dim)
decoder = decoder.to(device)

# One Adam optimizer updates the encoder and decoder parameters jointly.
params = [*encoder.parameters(), *decoder.parameters()]
optimizer = torch.optim.Adam(params, lr=lr_initial)

# Example StepLR schedule: after every 5 calls to scheduler.step() the
# learning rate is multiplied by 0.5.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=5)

# 5. Loss Function (menggunakan reduction='mean')

In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE objective: reconstruction BCE plus KL divergence.

    Both terms are element-wise means: the binary cross-entropy is averaged
    over every element of ``x`` and the KL term over every element of
    ``mu`` / ``logvar``. Their relative weight therefore differs from a
    formulation that sums each term per sample.
    """
    recon_term = F.binary_cross_entropy(recon_x, x, reduction='mean')
    # Closed-form KL(q(z|x) || N(0, I)) per latent element is
    # -0.5 * (1 + logvar - mu^2 - exp(logvar)); it is averaged here.
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.mean(kl_elements)
    return recon_term + kl_term

# 6. Fungsi Helper (Simpan Recon & Sampling)

In [None]:
def save_reconstructions(encoder, decoder, images, epoch, save_dir="samples"):
    """Write a comparison grid of input images and their reconstructions.

    The images are encoded, a latent code is sampled with ``reparameterize``
    and decoded. Originals and reconstructions are concatenated along the
    batch axis and laid out with ``images.size(0)`` entries per row, so the
    first row shows the real images and the second row the reconstructions.
    The grid is saved as ``recon_epoch_{epoch}.png`` inside ``save_dir``
    (created if missing).

    Both models are switched to eval mode and are left in that mode.
    """
    os.makedirs(save_dir, exist_ok=True)
    for module in (encoder, decoder):
        module.eval()

    with torch.no_grad():
        latent_mean, latent_logvar = encoder(images)
        recon = decoder(reparameterize(latent_mean, latent_logvar))

    # Stack originals on top of reconstructions for a side-by-side comparison.
    stacked = torch.cat((images, recon), dim=0).cpu()
    grid = vutils.make_grid(stacked, nrow=images.size(0), normalize=True)
    filename = os.path.join(save_dir, f"recon_epoch_{epoch}.png")
    vutils.save_image(grid, filename)
    print(f"Saved reconstruction: {filename}")

In [None]:
def save_sampling(decoder, epoch, num_samples=8, save_dir="samples"):
    """Decode random latent vectors drawn from N(0, I) and save them as a grid.

    ``num_samples`` latent vectors are sampled using the notebook-level
    ``latent_dim`` and ``device``, decoded, and written as a single-row grid
    to ``sample_epoch_{epoch}.png`` inside ``save_dir`` (created if missing).

    The decoder is switched to eval mode and is left in that mode.
    """
    os.makedirs(save_dir, exist_ok=True)
    filename = os.path.join(save_dir, f"sample_epoch_{epoch}.png")

    decoder.eval()
    with torch.no_grad():
        prior_z = torch.randn(num_samples, latent_dim, device=device)
        generated = decoder(prior_z).cpu()

    grid = vutils.make_grid(generated, nrow=num_samples, normalize=True)
    vutils.save_image(grid, filename)
    print(f"Saved sampling: {filename}")

# 7. Variabel Logging

In [None]:
# Training logs filled in by the training loop:
#   loss_history_iter  - loss at every iteration
#   lr_history_iter    - learning rate at every iteration
#   loss_history_epoch - average loss per epoch
loss_history_iter, lr_history_iter, loss_history_epoch = [], [], []

# Running count of optimizer steps across all epochs.
total_iters = 0

# 8. Training Loop

In [None]:
print("Start Training VAE...")
for epoch in range(1, num_epochs+1):
    # Switch back to training mode every epoch; the sample-saving helpers
    # called at the end of an epoch put the models into eval mode.
    encoder.train()
    decoder.train()
    epoch_loss = 0.0

    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device)

        # Clear gradients from the previous step before the new pass.
        optimizer.zero_grad()

        # Forward pass: encode, sample a latent code, decode.
        mu, logvar = encoder(imgs)
        z = reparameterize(mu, logvar)
        recon = decoder(z)

        # Compute the loss, back-propagate and update the parameters.
        loss = loss_function(recon, imgs, mu, logvar)
        loss.backward()
        optimizer.step()

        # Per-iteration logging.
        batch_loss = loss.item()
        current_lr = optimizer.param_groups[0]['lr']
        loss_history_iter.append(batch_loss)
        lr_history_iter.append(current_lr)
        epoch_loss += batch_loss
        total_iters += 1

        # Progress report on every 100th step of the epoch.
        if i % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Step [{i}/{len(dataloader)}], "
                  f"Loss: {batch_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")

    # Average loss over the batches of this epoch.
    avg_epoch_loss = epoch_loss / len(dataloader)
    loss_history_epoch.append(avg_epoch_loss)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Only save reconstructions & samples every 'save_sample_every' epochs.
    if epoch % save_sample_every != 0:
        continue

    # Take the first batch of a fresh pass over the data loader and keep
    # 8 images for the reconstruction grid.
    example_imgs, _ = next(iter(dataloader))
    example_imgs = example_imgs[:8].to(device)
    save_reconstructions(encoder, decoder, example_imgs, epoch)
    save_sampling(decoder, epoch, num_samples=8)

print("Training Finished.")

# 9. Simpan Model

In [None]:
# Persist only the learned weights (state dicts) of both networks.
os.makedirs("saved_models", exist_ok=True)
for _module, _weights_path in (
    (encoder, "saved_models/vae_encoder.pth"),
    (decoder, "saved_models/vae_decoder.pth"),
):
    torch.save(_module.state_dict(), _weights_path)
print("Model saved to 'saved_models' folder.")

# 10. Plot Loss & Learning Rate

In [None]:
def _plot_history(values, title, label, xlabel, ylabel):
    """Show one logged series as a line plot in its own 10x4 figure."""
    plt.figure(figsize=(10,4))
    plt.title(title)
    plt.plot(values, label=label)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()


# Loss at every iteration, average loss per epoch, and the learning rate
# recorded at every iteration.
_plot_history(loss_history_iter, "VAE Loss per Iteration", "Loss (iter)", "Iteration", "Loss")
_plot_history(loss_history_epoch, "VAE Loss per Epoch", "Loss (epoch)", "Epoch", "Average Loss")
_plot_history(lr_history_iter, "Learning Rate over Iterations", "LR", "Iteration", "Learning Rate")

# 11. Visualisasi Latent Space (t-SNE pada mu)

In [None]:
def visualize_latent_tsne(encoder, device, dataloader, n_samples=1000):
    """Project encoder means to 2-D with t-SNE and show a scatter plot.

    Batches are drawn from ``dataloader`` until at least ``n_samples`` latent
    means have been collected (or the loader is exhausted); the collection is
    then trimmed to ``n_samples`` rows and embedded with t-SNE.

    Args:
        encoder: Model returning ``(mu, logvar)`` for a batch of images.
        device: Device the image batches are moved to before encoding.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        n_samples: Maximum number of latent vectors to embed.

    Raises:
        ValueError: If fewer than two latent vectors could be collected.

    The encoder is switched to eval mode and is left in that mode.
    """
    encoder.eval()
    all_mu = []
    total_collected = 0

    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            mu, _ = encoder(images)  # only the mean is visualized
            all_mu.append(mu.cpu().numpy())
            total_collected += images.size(0)
            if total_collected >= n_samples:
                break

    if not all_mu:
        raise ValueError("dataloader yielded no batches; nothing to embed with t-SNE")

    all_mu = np.concatenate(all_mu, axis=0)
    all_mu = all_mu[:n_samples]  # trim if more than requested were collected

    # t-SNE requires perplexity < number of samples, so the usual value of 30
    # is lowered for small collections instead of failing inside scikit-learn.
    perplexity = min(30, all_mu.shape[0] - 1)
    if perplexity < 1:
        raise ValueError("need at least 2 latent vectors to run t-SNE")

    print("Running t-SNE on latent (mu), might take a while ...")
    # The iteration count is left at the library default (1000): the keyword
    # was renamed from `n_iter` to `max_iter` in newer scikit-learn releases,
    # so omitting it keeps this call working on both old and new versions.
    tsne = TSNE(n_components=2, perplexity=perplexity)
    mu_2d = tsne.fit_transform(all_mu)

    plt.figure(figsize=(8,6))
    plt.scatter(mu_2d[:,0], mu_2d[:,1], s=10, alpha=0.7, c='blue')
    plt.title("Latent Space (mu) Visualization via t-SNE")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.show()

# Run the t-SNE visualization (optional):
visualize_latent_tsne(encoder, device, dataloader, n_samples=1000)