In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

# =====================================================
# CONFIG
# =====================================================
# Image geometry. NOTE: CHANNELS is not read by the code below; RGB (3
# channels) is hard-coded in the dataset transform and the model layers.
IMG_SIZE: int = 32
CHANNELS: int = 3

# Training schedule.
BATCH: int = 128
EPOCHS: int = 100

# Size of the VAE latent vector z.
LATENT_DIM: int = 128

# Input image folder and checkpoint file.
DATA_DIR: str = r"D:\pk\dataset\real"
MODEL_PATH: str = "vae.pth"

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# =====================================================
# DATASET
# =====================================================
class ImageDataset(Dataset):
    """Images from a single folder, served as normalized tensors.

    Every regular file directly inside ``folder`` whose name ends in
    ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) is included;
    subdirectories are not searched. Each item is an RGB tensor of shape
    ``(3, IMG_SIZE, IMG_SIZE)`` whose values are normalized from [0, 1] to
    [-1, 1], matching the decoder's Tanh output range.

    Raises:
        ValueError: if ``folder`` contains no matching images. An empty
            dataset cannot be sampled by ``DataLoader(shuffle=True)``, so
            failing here gives a clearer message.
    """

    # Compared against the lower-cased name *including* the dot, so names
    # like "thumbnailpng" (no real extension) are not picked up.
    EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, folder):
        candidates = (
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(self.EXTENSIONS)
        )
        # os.listdir order is arbitrary; sorting keeps the index -> file
        # mapping stable across runs. isfile() skips e.g. a directory that
        # happens to be named "*.png".
        self.files = sorted(path for path in candidates if os.path.isfile(path))
        if not self.files:
            raise ValueError(f"No .jpg/.jpeg/.png images found in {folder!r}")

        self.transform = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            # (t - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # The context manager releases the file handle promptly; convert()
        # loads the pixel data into a new image before the file is closed.
        with Image.open(self.files[idx]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

# Build the training set once and stream reshuffled mini-batches from it
# every epoch.
dataset = ImageDataset(DATA_DIR)
loader = DataLoader(
    dataset,
    batch_size=BATCH,
    shuffle=True,
)

# =====================================================
# VAE MODEL
# =====================================================
class Encoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of q(z|x).

    Three stride-2 convolutions halve the spatial size each time, so a
    square input of side ``img_size`` becomes a 256-channel map of side
    ``img_size // 8``. That map is flattened and fed to two linear heads.

    Args:
        in_channels: number of input image channels (default 3, RGB).
        latent_dim: size of the latent vector; defaults to ``LATENT_DIM``.
        img_size: side length of the square inputs; defaults to
            ``IMG_SIZE``. Must be a positive multiple of 8.

    The submodule names (``conv``, ``fc_mu``, ``fc_logvar``) and layer order
    are unchanged, so state_dicts saved with the defaults still load.
    """

    def __init__(self, in_channels=3, latent_dim=None, img_size=None):
        super().__init__()
        latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        img_size = IMG_SIZE if img_size is None else img_size
        if img_size <= 0 or img_size % 8:
            raise ValueError(
                f"img_size must be a positive multiple of 8, got {img_size}"
            )
        feat_size = img_size // 8          # 4 for the default 32x32 input
        flat_dim = 256 * feat_size * feat_size

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),  # img_size / 2
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),          # img_size / 4
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),         # img_size / 8
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, latent_dim)``."""
        h = torch.flatten(self.conv(x), 1)
        return self.fc_mu(h), self.fc_logvar(h)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors to images.

    A linear layer expands ``z`` to a 256-channel map of side
    ``img_size // 8``. Three stride-2 transposed convolutions then double the
    side each time back to ``img_size``. The final Tanh keeps outputs in
    (-1, 1), the same range as the normalized training images.

    Args:
        out_channels: number of output image channels (default 3, RGB).
        latent_dim: size of the latent vector; defaults to ``LATENT_DIM``.
        img_size: side length of the square outputs; defaults to
            ``IMG_SIZE``. Must be a positive multiple of 8.

    The submodule names (``fc``, ``deconv``) and layer order are unchanged,
    so state_dicts saved with the defaults still load.
    """

    def __init__(self, out_channels=3, latent_dim=None, img_size=None):
        super().__init__()
        latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        img_size = IMG_SIZE if img_size is None else img_size
        if img_size <= 0 or img_size % 8:
            raise ValueError(
                f"img_size must be a positive multiple of 8, got {img_size}"
            )
        # Plain int attribute (not a parameter/buffer): not part of state_dict.
        self.feat_size = img_size // 8     # 4 for the default 32x32 output

        self.fc = nn.Linear(latent_dim, 256 * self.feat_size * self.feat_size)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),           # img_size / 4
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),            # img_size / 2
            nn.ReLU(),
            nn.ConvTranspose2d(64, out_channels, 4, 2, 1),   # img_size
            nn.Tanh(),
        )

    def forward(self, z):
        """Decode ``z`` of shape ``(batch, latent_dim)`` into images."""
        h = self.fc(z)
        h = h.view(h.size(0), 256, self.feat_size, self.feat_size)
        return self.deconv(h)

class VAE(nn.Module):
    """Variational autoencoder wiring Encoder -> sampled z -> Decoder."""

    def __init__(self):
        super().__init__()
        # Construction order (encoder first) fixes both the RNG draws used
        # for weight init and the key order of the state_dict.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I).

        sigma is recovered from the log-variance as exp(logvar / 2). Writing
        the sample this way keeps it differentiable w.r.t. mu and logvar.
        """
        sigma = logvar.mul(0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encoder(x)
        recon = self.decoder(self.reparameterize(mu, logvar))
        return recon, mu, logvar

# =====================================================
# INIT MODEL
# =====================================================
# Build the VAE, move its weights to the target device (Module.to works in
# place), then attach Adam to those parameters.
model = VAE()
model.to(DEVICE)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# =====================================================
# MODEL SUMMARY
# =====================================================
# Count every parameter tensor element, then print the architecture.
_n_params = sum(param.numel() for param in model.parameters())
print("\nVAE MODEL SUMMARY\n")
print(model)
print(f"\nTotal Parameters: {_n_params:,}\n")

# =====================================================
# LOSS FUNCTION
# =====================================================
def vae_loss(recon, x, mu, logvar, beta=1.0):
    """VAE training objective for one batch, summed over all elements.

    The loss is ``recon_loss + beta * kl_loss``:
      * ``recon_loss``: sum of squared errors between ``recon`` and ``x``.
      * ``kl_loss``: closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed
        over batch and latent dimensions.

    Args:
        recon: decoder output, same shape as ``x``.
        x: target batch.
        mu: latent means from the encoder.
        logvar: latent log-variances from the encoder, same shape as ``mu``.
        beta: weight of the KL term. The default 1.0 is the standard VAE
            objective; values > 1 give a beta-VAE.

    Returns:
        Scalar tensor. It is a *sum* over the batch, so divide by the batch
        size for a per-sample value.
    """
    recon_loss = F.mse_loss(recon, x, reduction="sum")
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

# =====================================================
# TRAINING
# =====================================================
for epoch in range(EPOCHS):
    # Switch back to training mode every epoch: later cells call
    # model.eval(), so re-running this cell would otherwise train in eval
    # mode. The current layers (Conv/ReLU/Linear/Tanh) don't change behaviour
    # between modes, but dropout/batch-norm would.
    model.train()
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    total_loss = 0.0

    for x in pbar:
        x = x.to(DEVICE)

        recon, mu, logvar = model(x)
        loss = vae_loss(recon, x, mu, logvar)  # summed over the batch

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Call .item() once per step; each call copies the value to the host
        # and synchronizes with the device.
        batch_loss = loss.item()
        total_loss += batch_loss
        pbar.set_postfix(loss=batch_loss / x.size(0))

    # Batch losses are sums, so dividing by the dataset size gives the mean
    # loss per image for the epoch.
    print(f"Epoch {epoch+1} Avg Loss: {total_loss / len(dataset):.4f}")

print("✔ Training Complete")

# =====================================================
# SAVE MODEL
# =====================================================
# Persist only the learned tensors (state_dict), not the pickled module.
_state_dict = model.state_dict()
torch.save(_state_dict, MODEL_PATH)
print(f"✔ Model saved to {MODEL_PATH}")

# =====================================================
# LOAD MODEL (INFERENCE)
# =====================================================
# A state_dict holds only tensors in plain containers, so weights_only=True
# is sufficient. It stops torch.load from unpickling arbitrary objects, and
# avoids the FutureWarning newer PyTorch releases emit when the argument is
# omitted. Requires torch >= 1.13.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
)
model.eval()

# =====================================================
# INFERENCE – GENERATE NEW IMAGES
# =====================================================
@torch.no_grad()
def generate_images(n=16):
    """Sample ``n`` images by decoding latent vectors drawn from N(0, I).

    The decoder ends in Tanh, so its output lies in (-1, 1); it is shifted
    to (0, 1) for display. The result is returned on the CPU.
    """
    latent = torch.randn(n, LATENT_DIM, device=DEVICE)
    decoded = model.decoder(latent)
    return decoded.add(1).div(2).cpu()

# =====================================================
# INFERENCE – RECONSTRUCT INPUT
# =====================================================
@torch.no_grad()
def reconstruct_images(x):
    """Run ``x`` through the full VAE and return reconstructions in (0, 1).

    The latent code is sampled via ``reparameterize`` (regardless of
    train/eval mode), so repeated calls on the same input give slightly
    different outputs. The result is returned on the CPU.
    """
    outputs = model(x.to(DEVICE))
    reconstruction = outputs[0]
    return ((reconstruction + 1) / 2).cpu()

# =====================================================
# SHOW GENERATED IMAGES
# =====================================================
samples = generate_images(16)

# Lay the 16 samples out on a 4x4 grid. Tensors are (C, H, W), while
# imshow expects (H, W, C).
fig, axes = plt.subplots(4, 4, figsize=(5, 5))
for ax, img in zip(axes.flat, samples):
    ax.imshow(img.permute(1, 2, 0))
    ax.axis("off")
plt.show()



VAE MODEL SUMMARY

VAE(
  (encoder): Encoder(
    (conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): ReLU()
    )
    (fc_mu): Linear(in_features=4096, out_features=128, bias=True)
    (fc_logvar): Linear(in_features=4096, out_features=128, bias=True)
  )
  (decoder): Decoder(
    (fc): Linear(in_features=128, out_features=4096, bias=True)
    (deconv): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU()
      (4): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): Tanh()
    )
  )
)

Total Parameters: 2,894,723



Epoch 1/100: 100%|██████████| 79/79 [02:09<00:00,  1.64s/it, loss=218]


Epoch 1 Avg Loss: 363.7231


Epoch 2/100: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s, loss=179]


Epoch 2 Avg Loss: 221.8702


Epoch 3/100:  90%|████████▉ | 71/79 [00:43<00:04,  1.63it/s, loss=187]