In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt 
import os
from torchvision.utils import save_image

In [14]:
if torch.cuda.is_available():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("GPU name:", torch.cuda.get_device_name(0))
    print("GPU count:", torch.cuda.device_count())

GPU name: NVIDIA GeForce GTX 1660 Ti
GPU count: 1


# LOAD DATASET

In [15]:
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

In [16]:
dataset_path = "./train" 

full_dataset = datasets.ImageFolder(
    root=dataset_path,
    transform=transform
)

In [17]:
print("Classes:", full_dataset.classes)
print("Total images:", len(full_dataset))

Classes: ['sunflower']
Total images: 500


In [18]:
val_ratio = 0.2
num_total = len(full_dataset)
num_val = int(num_total * val_ratio)
num_train = num_total - num_val

train_dataset, val_dataset = random_split(full_dataset, [num_train, num_val])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# A. Construct and train a deep learning model as an image generator by using the  images in your training dataset to train the model.

In [19]:
class VAE(nn.Module):
    def __init__(self, latent_dim=64):
        super(VAE, self).__init__()
        
        # Encoder
        self.enc = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),  # 32x16x16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1), # 64x8x8
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1), # 128x4x4
            nn.ReLU(True)
        )
        
        self.fc_mu = nn.Linear(128*4*4, latent_dim)
        self.fc_logvar = nn.Linear(128*4*4, latent_dim)

        # Decoder
        self.fc_dec = nn.Linear(latent_dim, 128*4*4)

        self.dec = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1), # 64x8x8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), # 32x16x16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),  # 3x32x32
            nn.Sigmoid()  # [0,1]
        )

    def encode(self, x):
        h = self.enc(x)
        h = h.view(h.size(0), -1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.fc_dec(z)
        h = h.view(-1, 128, 4, 4)
        return self.dec(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar


In [20]:
def vae_loss(recon, x, mu, logvar):
    recon_loss = F.mse_loss(recon, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (recon_loss + kl_loss) / x.size(0)

In [None]:
model = VAE(latent_dim=64).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)

num_epochs = 200
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    model.train()
    train_loss = 0

    # Train
    for imgs, _ in train_loader:
        imgs = imgs.to(device)

        recon, mu, logvar = model(imgs)
        loss = vae_loss(recon, imgs, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for imgs, _ in val_loader:
            imgs = imgs.to(device)
            recon, mu, logvar = model(imgs)
            loss = vae_loss(recon, imgs, mu, logvar)
            val_loss += loss.item()

    train_losses.append(train_loss / len(train_loader))
    val_losses.append(val_loss / len(val_loader))

    print(f"Epoch {epoch+1}/{num_epochs}, Train={train_losses[-1]:.4f}, Val={val_losses[-1]:.4f}")


Epoch 1/100, Train=2068.2029, Val=1631.3683
Epoch 2/100, Train=1284.0907, Val=1150.7721
Epoch 3/100, Train=1109.9180, Val=1093.9966
Epoch 4/100, Train=1080.4228, Val=1077.2298
Epoch 5/100, Train=1068.8171, Val=1075.6365
Epoch 6/100, Train=1051.9652, Val=1057.2864
Epoch 7/100, Train=1052.1661, Val=1041.8577
Epoch 8/100, Train=1033.2260, Val=1024.2340
Epoch 9/100, Train=1016.6684, Val=1016.7541
Epoch 10/100, Train=999.3309, Val=1003.2425
Epoch 11/100, Train=993.2126, Val=996.5915
Epoch 12/100, Train=984.1025, Val=994.5771
Epoch 13/100, Train=984.5100, Val=980.7440
Epoch 14/100, Train=975.4386, Val=985.2407
Epoch 15/100, Train=969.1200, Val=979.5393
Epoch 16/100, Train=964.7430, Val=974.1036
Epoch 17/100, Train=967.7508, Val=982.2454
Epoch 18/100, Train=966.5055, Val=972.0754
Epoch 19/100, Train=960.4865, Val=971.8147
Epoch 20/100, Train=960.8339, Val=966.3830
Epoch 21/100, Train=958.9969, Val=966.5356
Epoch 22/100, Train=946.9071, Val=963.5494
Epoch 23/100, Train=952.9742, Val=959.5427
E