In [None]:
!ls

In [None]:
# Mount Google Drive at /content/drive so outputs can be written under
# /content/drive/MyDrive (all result folders below live there).
from google.colab import drive

drive.mount("/content/drive")


In [None]:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

# Run on the GPU when one is visible; every model/tensor below is moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed torch's RNGs (CPU and all CUDA devices) so weight init, DataLoader shuffling and
# reparameterization noise repeat across runs. cuDNN kernels may still introduce small
# nondeterminism on GPU.
SEED = 42
torch.manual_seed(SEED)


In [None]:
# Root Drive folder for all plain-VAE artifacts (requires Drive to be mounted).
OUTPUT_DIR = "/content/drive/MyDrive/flower_results"

vae_img_dir = f"{OUTPUT_DIR}/vae_images"
vaegan_img_dir = f"{OUTPUT_DIR}/vaegan_images"  # NOTE(review): appears unused later; VAE-GAN cells write to gan_img_dir
model_dir = f"{OUTPUT_DIR}/models"
plot_dir = f"{OUTPUT_DIR}/plots"

# exist_ok=True keeps the cell safe to re-run.
for d in (vae_img_dir, vaegan_img_dir, model_dir, plot_dir):
    os.makedirs(d, exist_ok=True)

print("Folders created in Drive")


In [None]:
# Preprocessing: resize the shorter side to 64 px, center-crop to 64x64, convert to a
# [0, 1] tensor, then map to [-1, 1] via (x - 0.5) / 0.5. The 1-element mean/std
# broadcast over all three RGB channels.
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# NOTE(review): only the "train" split is used. The "Total images" print below shows how
# many images that is; if it is small, consider also training on the other splits.
dataset = datasets.Flowers102(
    root="/content/data",
    split="train",
    transform=transform,
    download=True,
)

loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)

print(f"Total images: {len(dataset)}")
print(f"Batches per epoch: {len(loader)}")


In [None]:
# Sanity check: show the first 16 images of one shuffled batch. The dataset tensors are
# in [-1, 1], so normalize=True rescales the grid into [0, 1] for display.
images, _ = next(iter(loader))
grid = utils.make_grid(images[:16], normalize=True)

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
ax.axis("off")
plt.show()


In [None]:
latent_dim = 128  # size of the VAE latent vector z


class Encoder(nn.Module):
    """Map a batch of 3x64x64 images to the parameters of the approximate posterior.

    Four stride-2 convolutions (kernel 4, padding 1) shrink 64x64 -> 4x4 while widening
    to 512 channels. The flattened 512*4*4 features feed two linear heads. The fixed
    512*4*4 head input is what ties this encoder to 64x64 inputs.

    Returns:
        (mu, logvar), each of shape (batch, latent_dim).
    """

    def __init__(self):
        super().__init__()
        widths = (3, 64, 128, 256, 512)
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.ReLU()]
        self.conv = nn.Sequential(*blocks)
        self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x):
        batch = x.size(0)
        features = self.conv(x).view(batch, -1)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar


class Decoder(nn.Module):
    """Map latent vectors of shape (batch, latent_dim) to 3x64x64 images in [-1, 1].

    A linear layer yields a 512x4x4 feature map. Four stride-2 transposed convolutions
    upsample 4 -> 8 -> 16 -> 32 -> 64, and a final Tanh bounds the output. That range
    matches the [-1, 1] inputs produced by the Normalize((0.5,), (0.5,)) transform.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)
        widths = (512, 256, 128, 64)
        blocks = []
        for c_in, c_out in zip(widths, widths[1:]):
            blocks += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.ReLU()]
        # Output layer: back to RGB, squashed by Tanh.
        blocks += [nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh()]
        self.deconv = nn.Sequential(*blocks)

    def forward(self, z):
        feature_map = self.fc(z).view(-1, 512, 4, 4)
        return self.deconv(feature_map)

# Build both halves of the VAE on the selected device. The encoder is constructed
# first, so seeded weight initialization consumes the RNG in a fixed order.
encoder, decoder = Encoder().to(device), Decoder().to(device)


In [None]:
import torch.nn.functional as F


def vae_loss(recon, x, mu, logvar):
    """Return the two VAE loss terms as separate scalars: (reconstruction, KL).

    Args:
        recon: decoder output, same shape as ``x``.
        x: target images.
        mu, logvar: posterior mean and log-variance from the encoder.

    Both terms are element-wise means: MSE over every pixel, KL over every latent
    element. For the 3x64x64 images and 128-d latents in this notebook, this weights KL
    more heavily relative to recon than a summed ELBO would. The callers apply their own
    weights on top.
    """
    recon_term = F.mse_loss(recon, x)
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per element, then averaged.
    kl_elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_elementwise.mean()
    return recon_term, kl_term


In [None]:
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)

epochs = 50

vae_recon_losses = []  # mean per-batch reconstruction (MSE) loss, one entry per epoch
vae_kl_losses = []     # mean per-batch KL loss, one entry per epoch

# Fixed latent codes: decoding the same z every epoch makes the saved sample grids
# comparable across epochs (fresh noise each epoch would confound progress with sampling).
vae_fixed_z = torch.randn(16, latent_dim, device=device)

for epoch in range(epochs):
    recon_epoch = 0
    kl_epoch = 0

    for imgs, _ in tqdm(loader):
        imgs = imgs.to(device)

        # Reparameterization trick: z = mu + sigma * eps keeps the sampling step differentiable.
        mu, logvar = encoder(imgs)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)

        recon = decoder(z)

        recon_loss, kl_loss = vae_loss(recon, imgs, mu, logvar)
        loss = recon_loss + kl_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        recon_epoch += recon_loss.item()
        kl_epoch += kl_loss.item()

    # Mean over batches; every batch is weighted equally, even a smaller final batch.
    recon_epoch /= len(loader)
    kl_epoch /= len(loader)

    vae_recon_losses.append(recon_epoch)
    vae_kl_losses.append(kl_epoch)

    # Reconstructions of the epoch's last batch (file names use 0-based epochs).
    utils.save_image(recon[:16].detach(), f"{vae_img_dir}/recon_epoch_{epoch}.png", normalize=True)
    # Prior samples: inference only, so don't build an autograd graph.
    with torch.no_grad():
        samples = decoder(vae_fixed_z)
    utils.save_image(samples, f"{vae_img_dir}/sample_epoch_{epoch}.png", normalize=True)

    torch.save({"encoder": encoder.state_dict(), "decoder": decoder.state_dict()},
               f"{model_dir}/vae_epoch_{epoch}.pt")

    print(f"Epoch {epoch}: Recon={recon_epoch:.4f}, KL={kl_epoch:.4f}")


In [None]:
# Persist the per-epoch VAE loss history (one row per epoch) next to the VAE outputs.
pd.DataFrame(
    data={"Reconstruction Loss": vae_recon_losses, "KL Loss": vae_kl_losses},
).to_csv(f"{OUTPUT_DIR}/vae_losses.csv", index=False)


In [None]:
class Discriminator(nn.Module):
    """Convolutional real/fake scorer, applied in this notebook to 3x64x64 images.

    Three stride-2 conv blocks (64 -> 32 -> 16 -> 8 spatial) are followed by a 4x4
    stride-1 conv. For a 64x64 input this yields a 1x5x5 map of probabilities: one score
    per overlapping image patch rather than one per image. ``forward`` flattens that map,
    so for 64x64 inputs it returns ``batch_size * 25`` values.
    """

    def __init__(self):
        super().__init__()
        layers = []
        # Downsampling blocks: each halves the spatial size.
        for c_in, c_out in ((3, 64), (64, 128), (128, 256)):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.LeakyReLU(0.2)]
        # Scoring head: single-channel map squashed to (0, 1) for use with BCELoss.
        layers += [nn.Conv2d(256, 1, 4, 1, 0), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.net(x)
        return scores.view(-1)

# Instantiate the discriminator on the same device as the VAE.
disc = Discriminator().to(device=device)


In [None]:
# Re-mount Drive. force_remount=True mounts again even if /content/drive is already
# mounted (presumably to recover a stale mount before the long VAE-GAN run).
from google.colab import drive

drive.mount("/content/drive", force_remount=True)


In [None]:
import os

# Dedicated Drive folder for VAE-GAN artifacts, kept apart from the plain-VAE results.
GAN_OUTPUT_DIR = "/content/drive/MyDrive/flower_gan_results"

gan_img_dir, gan_model_dir = (
    os.path.join(GAN_OUTPUT_DIR, sub) for sub in ("images", "models")
)
os.makedirs(gan_img_dir, exist_ok=True)
os.makedirs(gan_model_dir, exist_ok=True)

print("New GAN folders created:", gan_img_dir, gan_model_dir, sep="\n")

# Last expression is displayed: list the folder's contents to verify the layout.
os.listdir(GAN_OUTPUT_DIR)


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision import utils

# ===== NEW GAN OUTPUT FOLDER =====
# Re-declared here so this cell does not depend on the folder-setup cell having run.
GAN_OUTPUT_DIR = "/content/drive/MyDrive/flower_gan_results"
gan_img_dir = os.path.join(GAN_OUTPUT_DIR, "images")
gan_model_dir = os.path.join(GAN_OUTPUT_DIR, "models")

os.makedirs(gan_img_dir, exist_ok=True)
os.makedirs(gan_model_dir, exist_ok=True)

print("GAN folders ready:", gan_img_dir, gan_model_dir)

# ===== Hyperparameters =====
epochs = 100

lambda_recon = 10.0  # weight of the pixel-wise MSE term
lambda_kl = 0.1      # weight of the KL term
lambda_adv = 1.0     # weight of the adversarial (fool-the-discriminator) term

# ===== Optimizers =====
# encoder/decoder continue from whatever weights they currently hold (e.g. the VAE
# trained above); both optimizers start with fresh Adam state.
opt_G = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
opt_D = optim.Adam(disc.parameters(), lr=5e-5)

bce = nn.BCELoss()  # disc ends in Sigmoid, so it already outputs probabilities

gan_losses = []     # mean total generator loss per epoch
recon_losses2 = []  # mean reconstruction loss per epoch
kl_losses2 = []     # mean KL loss per epoch
d_losses2 = []      # mean discriminator loss per epoch

# ===== Training Loop =====
for epoch in range(epochs):
    g_epoch = 0.0
    d_epoch = 0.0
    recon_epoch = 0.0
    kl_epoch = 0.0

    for imgs, _ in tqdm(loader):
        imgs = imgs.to(device)

        # ---- Encode & Decode (reparameterized sample) ----
        mu, logvar = encoder(imgs)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        fake = decoder(z)

        # ---- Train Discriminator ----
        d_real = disc(imgs)
        d_fake = disc(fake.detach())  # detach: the D update must not reach encoder/decoder

        # One-sided label smoothing: real targets are softened to 0.9, fake targets stay 0.
        real_labels = torch.full_like(d_real, 0.9)
        fake_labels = torch.zeros_like(d_fake)

        d_loss = bce(d_real, real_labels) + bce(d_fake, fake_labels)

        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # ---- Train Generator (VAE-GAN) ----
        # Score the fakes with D's parameters frozen, so the generator backward pass does
        # not compute D gradients (opt_D.zero_grad() would discard them anyway).
        # try/finally guarantees D is unfrozen even if this step is interrupted.
        disc.requires_grad_(False)
        try:
            d_fake = disc(fake)
        finally:
            disc.requires_grad_(True)
        adv_loss = bce(d_fake, torch.ones_like(d_fake))

        recon_loss, kl_loss = vae_loss(fake, imgs, mu, logvar)

        g_loss = (
            lambda_recon * recon_loss +
            lambda_kl * kl_loss +
            lambda_adv * adv_loss
        )

        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

        # ---- Accumulate losses ----
        g_epoch += g_loss.item()
        d_epoch += d_loss.item()
        recon_epoch += recon_loss.item()
        kl_epoch += kl_loss.item()

    # ---- Average losses over batches ----
    g_epoch /= len(loader)
    d_epoch /= len(loader)
    recon_epoch /= len(loader)
    kl_epoch /= len(loader)

    gan_losses.append(g_epoch)
    recon_losses2.append(recon_epoch)
    kl_losses2.append(kl_epoch)
    d_losses2.append(d_epoch)

    # ---- Save reconstructions of the epoch's last batch (1-based epoch in file name) ----
    img_path = os.path.join(gan_img_dir, f"epoch_{epoch+1}.png")
    utils.save_image(fake[:16].detach(), img_path, normalize=True)

    # ---- Save model weights (optimizer state is not included) ----
    model_path = os.path.join(gan_model_dir, f"vaegan_epoch_{epoch+1}.pt")
    torch.save({
        "encoder": encoder.state_dict(),
        "decoder": decoder.state_dict(),
        "discriminator": disc.state_dict()
    }, model_path)

    # ---- Print log ----
    print(f"Epoch [{epoch+1}/{epochs}] | "
          f"G: {g_epoch:.4f} | "
          f"D: {d_epoch:.4f} | "
          f"Recon: {recon_epoch:.4f} | "
          f"KL: {kl_epoch:.4f}")


In [None]:
import os
import pandas as pd

# One row per epoch of the VAE-GAN training cell. "GAN Loss" is the total weighted
# generator loss (recon + KL + adversarial terms), not the adversarial term alone.
df_gan = pd.DataFrame({
    "GAN Loss": gan_losses,
    "Reconstruction Loss": recon_losses2,
    "KL Loss": kl_losses2
})

# Derive the path from GAN_OUTPUT_DIR instead of repeating the Drive path literal.
csv_path = os.path.join(GAN_OUTPUT_DIR, "gan_losses.csv")
df_gan.to_csv(csv_path, index=False)

print("Losses saved to:", csv_path)


In [None]:
import os
import matplotlib.pyplot as plt

# Per-epoch curves from the VAE-GAN training cell. The curves sit on different scales:
# the generator total already includes lambda_recon * recon + lambda_kl * KL + lambda_adv * adv.
epochs_range = range(1, len(gan_losses) + 1)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(epochs_range, gan_losses, label="GAN (Generator) Loss")
ax.plot(epochs_range, recon_losses2, label="Reconstruction Loss")
ax.plot(epochs_range, kl_losses2, label="KL Loss")
ax.set(xlabel="Epochs", ylabel="Loss", title="VAE-GAN Training Losses")
ax.legend()
ax.grid(True)

# Derive the path from GAN_OUTPUT_DIR instead of repeating the Drive path literal.
plot_path = os.path.join(GAN_OUTPUT_DIR, "loss_plot.png")
fig.savefig(plot_path)  # save before show(), which may clear the figure
plt.show()

print("Plot saved to:", plot_path)


In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import os

# Side-by-side reconstructions: plain VAE vs VAE-GAN. Choose which saved epochs to show.
# NOTE: VAE files use 0-based epochs (recon_epoch_0 ...), while VAE-GAN files use 1-based
# epochs (epoch_1 ...), so equal numbers are off by one epoch of training.
vae_compare_epoch = 10   # change epoch number
gan_compare_epoch = 99   # change epoch number

vae_img_path = os.path.join(vae_img_dir, f"recon_epoch_{vae_compare_epoch}.png")
gan_img_path = os.path.join(gan_img_dir, f"epoch_{gan_compare_epoch}.png")

# Copy the pixels out inside a context manager so the file handles are closed right away.
with Image.open(vae_img_path) as im:
    vae_img = im.copy()
with Image.open(gan_img_path) as im:
    gan_img = im.copy()

fig, (ax_vae, ax_gan) = plt.subplots(1, 2, figsize=(10, 5))

ax_vae.imshow(vae_img)
ax_vae.set_title(f"VAE Output (Blurry) - recon_epoch_{vae_compare_epoch}")
ax_vae.axis("off")

ax_gan.imshow(gan_img)
ax_gan.set_title(f"VAE-GAN Output (Sharper) - epoch_{gan_compare_epoch}")
ax_gan.axis("off")

plt.show()


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm
from torchvision import utils

# ===== Paths =====
GAN_OUTPUT_DIR = "/content/drive/MyDrive/flower_gan_results"
gan_img_dir = os.path.join(GAN_OUTPUT_DIR, "images")
gan_model_dir = os.path.join(GAN_OUTPUT_DIR, "models")

os.makedirs(gan_img_dir, exist_ok=True)
os.makedirs(gan_model_dir, exist_ok=True)

# ===== Load last checkpoint =====
last_epoch = 100   # change this to your last trained epoch number
checkpoint_path = f"{gan_model_dir}/vaegan_epoch_{last_epoch}.pt"

# The saved checkpoints hold only the three model state_dicts (no optimizer state),
# so Adam is re-created from scratch below.
ckpt = torch.load(checkpoint_path, map_location=device)

encoder.load_state_dict(ckpt["encoder"])
decoder.load_state_dict(ckpt["decoder"])
disc.load_state_dict(ckpt["discriminator"])

print("Checkpoint loaded from epoch", last_epoch)

# ===== Hyperparameters =====
more_epochs = 30
final_epoch = last_epoch + more_epochs
lambda_recon = 10.0
lambda_kl = 0.1
lambda_adv = 1.0

# ===== Optimizers (fresh Adam moment estimates) =====
opt_G = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
opt_D = optim.Adam(disc.parameters(), lr=5e-5)

bce = nn.BCELoss()

# Per-epoch losses of this resumed run. Rewritten to their own CSV after every epoch,
# so an interrupted run still leaves a record.
resume_loss_rows = []
resume_csv_path = os.path.join(
    GAN_OUTPUT_DIR, f"gan_losses_epochs_{last_epoch + 1}_to_{final_epoch}.csv"
)

# ===== Resume Training =====
for epoch in range(last_epoch, final_epoch):
    g_epoch = 0.0
    d_epoch = 0.0
    recon_epoch = 0.0
    kl_epoch = 0.0

    for imgs, _ in tqdm(loader):
        imgs = imgs.to(device)

        # ---- Encode & Decode (reparameterized sample) ----
        mu, logvar = encoder(imgs)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        fake = decoder(z)

        # ---- Train Discriminator ----
        d_real = disc(imgs)
        d_fake = disc(fake.detach())  # detach: the D update must not reach encoder/decoder

        # One-sided label smoothing: real targets are softened to 0.9, fake targets stay 0.
        real_labels = torch.full_like(d_real, 0.9)
        fake_labels = torch.zeros_like(d_fake)

        d_loss = bce(d_real, real_labels) + bce(d_fake, fake_labels)

        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # ---- Train Generator ----
        # Score the fakes with D frozen so no (discarded) D gradients are computed;
        # try/finally guarantees D is unfrozen even if this step is interrupted.
        disc.requires_grad_(False)
        try:
            d_fake = disc(fake)
        finally:
            disc.requires_grad_(True)
        adv_loss = bce(d_fake, torch.ones_like(d_fake))

        recon_loss, kl_loss = vae_loss(fake, imgs, mu, logvar)

        g_loss = (
            lambda_recon * recon_loss +
            lambda_kl * kl_loss +
            lambda_adv * adv_loss
        )

        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

        g_epoch += g_loss.item()
        d_epoch += d_loss.item()
        recon_epoch += recon_loss.item()
        kl_epoch += kl_loss.item()

    # ---- Average losses over batches ----
    g_epoch /= len(loader)
    d_epoch /= len(loader)
    recon_epoch /= len(loader)
    kl_epoch /= len(loader)

    # ---- Record losses (previously printed but never persisted) ----
    resume_loss_rows.append({
        "Epoch": epoch + 1,
        "GAN Loss": g_epoch,
        "D Loss": d_epoch,
        "Reconstruction Loss": recon_epoch,
        "KL Loss": kl_epoch,
    })
    pd.DataFrame(resume_loss_rows).to_csv(resume_csv_path, index=False)

    # ---- Save reconstructions of the epoch's last batch ----
    img_path = os.path.join(gan_img_dir, f"epoch_{epoch+1}.png")
    utils.save_image(fake[:16].detach(), img_path, normalize=True)

    # ---- Save checkpoint (same keys as the original training run) ----
    model_path = os.path.join(gan_model_dir, f"vaegan_epoch_{epoch+1}.pt")
    torch.save({
        "encoder": encoder.state_dict(),
        "decoder": decoder.state_dict(),
        "discriminator": disc.state_dict()
    }, model_path)

    print(f"Epoch [{epoch+1}/{final_epoch}] | "
          f"G: {g_epoch:.4f} | "
          f"D: {d_epoch:.4f} | "
          f"Recon: {recon_epoch:.4f} | "
          f"KL: {kl_epoch:.4f}")


In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import os

# Side-by-side reconstructions after resumed training (130 = 100 + 30 resumed epochs).
# NOTE: VAE files use 0-based epochs (recon_epoch_0 ...), while VAE-GAN files use 1-based
# epochs (epoch_1 ...), so equal numbers are off by one epoch of training.
vae_compare_epoch = 10    # change epoch number
gan_compare_epoch = 130   # change epoch number

vae_img_path = os.path.join(vae_img_dir, f"recon_epoch_{vae_compare_epoch}.png")
gan_img_path = os.path.join(gan_img_dir, f"epoch_{gan_compare_epoch}.png")

# Copy the pixels out inside a context manager so the file handles are closed right away.
with Image.open(vae_img_path) as im:
    vae_img = im.copy()
with Image.open(gan_img_path) as im:
    gan_img = im.copy()

fig, (ax_vae, ax_gan) = plt.subplots(1, 2, figsize=(10, 5))

ax_vae.imshow(vae_img)
ax_vae.set_title(f"VAE Output (Blurry) - recon_epoch_{vae_compare_epoch}")
ax_vae.axis("off")

ax_gan.imshow(gan_img)
ax_gan.set_title(f"VAE-GAN Output (Sharper) - epoch_{gan_compare_epoch}")
ax_gan.axis("off")

plt.show()
