In [None]:
flattened_size = 512 * 4 * 4  # this is correct for 128×128 input
latent_dim = 128


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch.nn.functional as F
import time

# Dataset path
dataset_path = 'data/augmentedData'
if not any(os.path.isdir(os.path.join(dataset_path, d)) for d in os.listdir(dataset_path)):
    os.makedirs(os.path.join(dataset_path, 'unspecified'), exist_ok=True)
    for file in os.listdir(dataset_path):
        file_path = os.path.join(dataset_path, file)
        if os.path.isfile(file_path):
            os.rename(file_path, os.path.join(dataset_path, 'unspecified', file))

# Transforms with augmentations
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# --------------------------------------
# VGG based perpetual loss
# ---------------------------------------

class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        vgg = models.vgg16(pretrained=True).features[:16].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg

    def forward(self, x, y):
        x_vgg = self.vgg(x)
        y_vgg = self.vgg(y)
        return F.mse_loss(x_vgg, y_vgg)

# --------------------------------------------------
# Encoder
# -------------------------------------------------

class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),    # 128 -> 64
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1),  # 64 -> 32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, 2, 1), # 32 -> 16
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 512, 4, 2, 1), # 16 -> 8
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, 4, 2, 1), # 8 -> 4
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        self.flattened_size = 512 * 4 * 4
        self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


# ------------------------------------------
# Decoder
# ------------------------------------------

class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, 2, 1),  # 4 -> 8
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),  # 8 -> 16
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 16 -> 32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 32 -> 64
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),     # 64 -> 128
            nn.Tanh()
        )

    def forward(self, z):
        x = self.fc(z)
        x = x.view(-1, 512, 4, 4)
        return self.deconv(x)


# --------------------------------------
# Discriminator
# ------------------------------------

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),   # 128 -> 64
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), # 64 -> 32
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),# 32 -> 16
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1),# 16 -> 8
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 512, 4, 2, 1),# 8 -> 4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 1, 4, 1, 0),  # 4 -> 1
            nn.Sigmoid()
        )

    def forward(self, x):
        features = self.features(x)
        validity = self.classifier(features)
        return validity.view(-1, 1).squeeze(1), features.view(x.size(0), -1)



In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch.nn.functional as F

# --------------------------
# Device & Hyperparameters
# --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 200
vae_epochs = 200
gan_epochs = 10000
batch_size = 32

# --------------------------
# Data Loading & Augmentation
# --------------------------
dataset_path = 'data/augmentedData'
if not any(os.path.isdir(os.path.join(dataset_path, d)) for d in os.listdir(dataset_path)):
    os.makedirs(os.path.join(dataset_path, 'unspecified'), exist_ok=True)
    for file in os.listdir(dataset_path):
        file_path = os.path.join(dataset_path, file)
        if os.path.isfile(file_path):
            os.rename(file_path, os.path.join(dataset_path, 'unspecified', file))

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --------------------------
# Perceptual Loss
# --------------------------
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features[:16].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg

    def forward(self, x, y):
        return F.mse_loss(self.vgg(x), self.vgg(y))

# --------------------------
# Encoder
# --------------------------
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.ReLU(True)
        )
        self.flattened_size = 512 * 4 * 4
        self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)

    def forward(self, x):
        x = self.conv(x).view(x.size(0), -1)
        return self.fc_mu(x), self.fc_logvar(x)

# --------------------------
# Decoder
# --------------------------
class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh()
        )

    def forward(self, z):
        x = self.fc(z).view(-1, 512, 4, 4)
        return self.deconv(x)

# --------------------------
# Discriminator
# --------------------------
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True)
        )
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 1, 4, 1, 0), nn.Sigmoid()
        )

    def forward(self, x):
        features = self.features(x)
        validity = self.classifier(features)
        return validity.view(-1, 1).squeeze(1), features.view(x.size(0), -1)

# --------------------------
# Initialize Models & Optimizers
# --------------------------
encoder = Encoder(latent_dim).to(device)
decoder = Decoder(latent_dim).to(device)
discriminator = Discriminator().to(device)

optimizer_E = optim.Adam(encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_G = optim.Adam(decoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))



        

In [2]:
# --------------------------
# VAE Pretraining
# --------------------------
reconstruction_loss = VGGPerceptualLoss().to(device)

for epoch in range(vae_epochs):
    encoder.train()
    decoder.train()
    for imgs, _ in dataloader:
        imgs = imgs.to(device)
        optimizer_E.zero_grad()
        optimizer_G.zero_grad()

        mu, logvar = encoder(imgs)
        std = torch.exp(0.5 * logvar)
        z = torch.randn_like(std) * std + mu
        recon_imgs = decoder(z)

        recon_loss = reconstruction_loss(recon_imgs, imgs)
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / imgs.size(0)
        vae_loss = recon_loss + 0.1 * kl_div
        vae_loss.backward()
        optimizer_E.step()
        optimizer_G.step()

    print(f"[VAE Epoch {epoch+1}] Loss: {vae_loss.item():.4f}")


[VAE Epoch 1] Loss: 5.5156
[VAE Epoch 2] Loss: 5.8782
[VAE Epoch 3] Loss: 4.7253
[VAE Epoch 4] Loss: 5.1371
[VAE Epoch 5] Loss: 5.3023
[VAE Epoch 6] Loss: 5.2719
[VAE Epoch 7] Loss: 4.3793
[VAE Epoch 8] Loss: 4.7611
[VAE Epoch 9] Loss: 5.8771
[VAE Epoch 10] Loss: 5.0994
[VAE Epoch 11] Loss: 4.8063
[VAE Epoch 12] Loss: 4.4441
[VAE Epoch 13] Loss: 4.9007
[VAE Epoch 14] Loss: 4.7184
[VAE Epoch 15] Loss: 5.0375
[VAE Epoch 16] Loss: 4.5051
[VAE Epoch 17] Loss: 4.4805
[VAE Epoch 18] Loss: 4.3287
[VAE Epoch 19] Loss: 4.3823
[VAE Epoch 20] Loss: 4.3965
[VAE Epoch 21] Loss: 4.6671
[VAE Epoch 22] Loss: 5.2740
[VAE Epoch 23] Loss: 4.4042
[VAE Epoch 24] Loss: 3.9263
[VAE Epoch 25] Loss: 4.4032
[VAE Epoch 26] Loss: 4.5984
[VAE Epoch 27] Loss: 4.5535
[VAE Epoch 28] Loss: 5.0393
[VAE Epoch 29] Loss: 4.5053
[VAE Epoch 30] Loss: 3.7371
[VAE Epoch 31] Loss: 3.6139
[VAE Epoch 32] Loss: 4.8108
[VAE Epoch 33] Loss: 4.5201
[VAE Epoch 34] Loss: 4.7966
[VAE Epoch 35] Loss: 4.6210
[VAE Epoch 36] Loss: 3.9466
[

In [3]:
# --------------------------
# Full VAE-GAN Training
# --------------------------
os.makedirs("data/finetuneData", exist_ok=True)

for epoch in range(gan_epochs):
    for imgs, _ in dataloader:
        imgs = imgs.to(device)
        valid = torch.ones(imgs.size(0), device=device)
        fake = torch.zeros(imgs.size(0), device=device)

        # --- VAE ---
        optimizer_E.zero_grad()
        optimizer_G.zero_grad()

        mu, logvar = encoder(imgs)
        std = torch.exp(0.5 * logvar)
        z = torch.randn_like(std) * std + mu
        recon_imgs = decoder(z)

        recon_loss = reconstruction_loss(recon_imgs, imgs)
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / imgs.size(0)
        vae_loss = recon_loss + 0.1 * kl_div
        vae_loss.backward()
        optimizer_E.step()
        optimizer_G.step()

        # --- Discriminator ---
        optimizer_D.zero_grad()
        real_pred, _ = discriminator(imgs)
        fake_pred, _ = discriminator(decoder(torch.randn(imgs.size(0), latent_dim, device=device)).detach())
        d_loss = (F.binary_cross_entropy(real_pred, valid) + F.binary_cross_entropy(fake_pred, fake)) / 2
        d_loss.backward()
        optimizer_D.step()

        # --- Generator ---
        optimizer_G.zero_grad()
        gen_imgs = decoder(torch.randn(imgs.size(0), latent_dim, device=device))
        pred, feat_fake = discriminator(gen_imgs)
        _, feat_real = discriminator(imgs)
        feature_loss = F.mse_loss(feat_fake, feat_real.detach())
        g_loss = F.binary_cross_entropy(pred, valid) + 0.1 * feature_loss
        g_loss.backward()
        optimizer_G.step()

    print(f"[GAN Epoch {epoch+1}] [D: {d_loss.item():.4f}] [G: {g_loss.item():.4f}] [VAE: {vae_loss.item():.4f}]")

    with torch.no_grad():
        sample = decoder(torch.randn(64, latent_dim, device=device))
        vutils.save_image(sample, f"data/finetuneData/epoch_{epoch + vae_epochs}.png", normalize=True)

# --------------------------
# Save Final Models
# --------------------------
torch.save(encoder.state_dict(), "encoder.pth")
torch.save(decoder.state_dict(), "decoder.pth")
torch.save(discriminator.state_dict(), "discriminator.pth")


[GAN Epoch 1] [D: 0.3767] [G: 1.5995] [VAE: 5.0296]
[GAN Epoch 2] [D: 0.2637] [G: 2.0836] [VAE: 5.0909]
[GAN Epoch 3] [D: 0.1937] [G: 1.9506] [VAE: 5.4758]
[GAN Epoch 4] [D: 1.3252] [G: 2.9190] [VAE: 5.2462]
[GAN Epoch 5] [D: 0.4662] [G: 1.9951] [VAE: 5.1634]
[GAN Epoch 6] [D: 0.2744] [G: 2.6745] [VAE: 5.6323]
[GAN Epoch 7] [D: 0.0479] [G: 4.4751] [VAE: 5.0195]
[GAN Epoch 8] [D: 0.1724] [G: 2.0997] [VAE: 4.8039]
[GAN Epoch 9] [D: 0.0748] [G: 4.3415] [VAE: 5.4029]
[GAN Epoch 10] [D: 0.3923] [G: 1.3293] [VAE: 5.9453]
[GAN Epoch 11] [D: 0.4587] [G: 3.6408] [VAE: 6.2582]
[GAN Epoch 12] [D: 0.1050] [G: 3.5484] [VAE: 5.0492]
[GAN Epoch 13] [D: 0.1001] [G: 4.0124] [VAE: 5.1487]
[GAN Epoch 14] [D: 0.1637] [G: 5.5278] [VAE: 4.5926]
[GAN Epoch 15] [D: 0.0995] [G: 6.1698] [VAE: 4.2500]
[GAN Epoch 16] [D: 1.3120] [G: 4.5345] [VAE: 5.1596]
[GAN Epoch 17] [D: 0.0788] [G: 5.0645] [VAE: 5.5090]
[GAN Epoch 18] [D: 0.2651] [G: 6.7007] [VAE: 5.1245]
[GAN Epoch 19] [D: 0.3788] [G: 7.1634] [VAE: 4.7932]
[G

KeyboardInterrupt: 