In [19]:
import torch
import torch.nn.functional as F
from torch import nn
import torchvision
from torchvision import transforms

from utils import *
from learner import *
# from vae import *

In [20]:
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        self.fc_mu = nn.Linear(256 * 32 * 32, latent_dim)
        self.fc_logvar = nn.Linear(256 * 32 * 32, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256 * 32 * 32),
            nn.ReLU(),
            nn.Unflatten(1, (256, 32, 32)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        x = self.encoder(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar
        
def vae_loss(X_new=None, X=None, mu=None, lv=None):
    return F.binary_cross_entropy(X_new, X, reduction="sum") + (-0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()))


In [21]:
# # Define the VAE model
# class VAE(nn.Module):
#     def __init__(self, latent_dim=20):
#         super().__init__()
#         self.latent_dim = latent_dim
#         # Encoder
#         self.encoder = nn.Sequential(
#             nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(32),  # <-- Added batch normalization
#             nn.ReLU(),
#             nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(64),  # <-- Added batch normalization
#             nn.ReLU(),
#             nn.Flatten(),
#             nn.Linear(64 * 7 * 7, 2 * self.latent_dim)  # Output mean and log variance
#         )

#         # Decoder
#         self.decoder = nn.Sequential(
#             nn.Linear(self.latent_dim, 64 * 7 * 7),
#             nn.ReLU(),
#             nn.Unflatten(1, (64, 7, 7)),
#             nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(32),  # <-- Added batch normalization
#             nn.ReLU(),
#             nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
#             nn.Sigmoid()  # Output image in [0, 1] range
#         )

#     def reparameterize(self, mu, logvar):
#         std = torch.exp(0.5 * logvar)
#         eps = torch.randn_like(std)
#         return mu + eps * std

#     def forward(self, x):
#         z = self.encoder(x)
#         mu, logvar = z[:, :self.latent_dim], z[:, self.latent_dim:]
#         z = self.reparameterize(mu, logvar)
#         return self.decoder(z), mu, logvar

In [22]:
BATCH_SIZE = 2
device = "cuda" if torch.cuda.is_available() else "cpu"

In [23]:
# Write transform for image
data_transform = transforms.Compose([
    # Resize the images to 64x64
    transforms.Resize(size=(512, 512)),
    # Flip the images randomly on the horizontal
    transforms.RandomHorizontalFlip(p=0.5), # p = probability of flip, 0.5 = 50% chance
    # Turn the image into a torch.Tensor
    transforms.ToTensor() # this also converts all pixel values from 0 to 255 to be between 0.0 and 1.0 
])

In [24]:
import requests
import zipfile
from pathlib import Path
import random
from PIL import Image
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"
image_path_list = list(image_path.glob("*/*/*.jpg"))


In [25]:
train_dir = image_path / "train"
test_dir = image_path / "test"

train_dir, test_dir

(WindowsPath('data/pizza_steak_sushi/train'),
 WindowsPath('data/pizza_steak_sushi/test'))

In [26]:
train_data = torchvision.datasets.ImageFolder(root=train_dir, # target folder of images
                                  transform=data_transform, # transforms to perform on data (images)
                                  target_transform=None) # transforms to perform on labels (if necessary)

test_data = torchvision.datasets.ImageFolder(root=test_dir, 
                                 transform=data_transform)

print(f"Train data:\n{train_data}\nTest data:\n{test_data}")

Train data:
Dataset ImageFolder
    Number of datapoints: 150
    Root location: data\pizza_steak_sushi\train
    StandardTransform
Transform: Compose(
               Resize(size=(512, 512), interpolation=bilinear, max_size=None, antialias=warn)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
           )
Test data:
Dataset ImageFolder
    Number of datapoints: 75
    Root location: data\pizza_steak_sushi\test
    StandardTransform
Transform: Compose(
               Resize(size=(512, 512), interpolation=bilinear, max_size=None, antialias=warn)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
           )


In [27]:
train_dataloader = DataLoader(dataset=train_data, 
                              batch_size=BATCH_SIZE, # how many samples per batch?
                              shuffle=True) # shuffle the data?

test_dataloader = DataLoader(dataset=test_data, 
                             batch_size=BATCH_SIZE, 
                             shuffle=False) # don't usually need to shuffle testing data

train_dataloader, test_dataloader

(<torch.utils.data.dataloader.DataLoader at 0x2c1bc784b80>,
 <torch.utils.data.dataloader.DataLoader at 0x2c1bc7841f0>)

In [None]:
def plot_transformed_images(image_paths, transform, n=3, seed=42):
    """Plots a series of random images from image_paths.

    Will open n image paths from image_paths, transform them
    with transform and plot them side by side.

    Args:
        image_paths (list): List of target image paths. 
        transform (PyTorch Transforms): Transforms to apply to images.
        n (int, optional): Number of images to plot. Defaults to 3.
        seed (int, optional): Random seed for the random generator. Defaults to 42.
    """
    random.seed(seed)
    random_image_paths = random.sample(image_paths, k=n)
    for image_path in random_image_paths:
        with Image.open(image_path) as f:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(f) 
            ax[0].set_title(f"Original \nSize: {f.size}")
            ax[0].axis("off")

            # Transform and plot image
            # Note: permute() will change shape of image to suit matplotlib 
            # (PyTorch default is [C, H, W] but Matplotlib is [H, W, C])
            transformed_image = transform(f).permute(1, 2, 0) 
            ax[1].imshow(transformed_image) 
            ax[1].set_title(f"Transformed \nSize: {transformed_image.shape}")
            ax[1].axis("off")

            fig.suptitle(f"Class: {image_path.parent.stem}", fontsize=16)

plot_transformed_images(image_path_list, 
                        transform=data_transform, 
                        n=3)

In [29]:
import matplotlib.pyplot as plt
import numpy as np

def plot_reconstructions(model, data, num_samples=3):
    with torch.inference_mode():
        recon_batch, _, _ = model(data)
        recon_batch = recon_batch.cpu()

    # Plot original images
    plt.figure(figsize=(30, 12))
    for i in range(min(num_samples, len(data))):
        plt.subplot(2, num_samples, i + 1)
        plt.imshow(data[i].permute(1, 2, 0).cpu())
        plt.title('Original')
        plt.axis('off')

    # Plot reconstructed images
    for i in range(min(num_samples, len(data))):
        plt.subplot(2, num_samples, i + num_samples + 1)
        plt.imshow(recon_batch[i].permute(1, 2, 0))
        plt.title('New')
        plt.axis('off')

    plt.show()

In [31]:
ic, hc, lc = 784, 400, 20
lr = 1e-3
vae = VAE(latent_dim=64)
loss_fn = vae_loss
epochs = 10
optimizer = torch.optim.Adam(vae.parameters(), lr=lr)
schedular = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, total_steps=epochs*len(train_dataloader))

In [32]:
class VAE_Learner(Learner):
    def _update_results(self):
        self.loss_metric.update(self.loss)

    def _run_batch(self):
        self.X_new, self.mu, self.lv = self.model(self.X)
        self.loss = self.loss_fn(self.X_new, self.X, self.mu, self.lv)
        self._update_results()
        if self.training:
            self.optimizer.zero_grad()
            self.loss.backward()
            self.optimizer.step()

    def _run_epoch(self, train=True):
        self.training = train
        self.dl = self.train_dataloader if self.training else self.test_dataloader
        for self.X, _ in self.dl:
            self.X = self.X.to(self.device)
            self._run_batch()
   

In [33]:
vae_learner = VAE_Learner(model=vae, train_dataloader=train_dataloader, test_dataloader=test_dataloader, loss_fn=loss_fn, optimizer=optimizer, scheduler=schedular, device=device)

In [34]:
def train():
    vae_learner.fit(epochs=10, test=False)
    for batch, (data, _) in enumerate(test_dataloader):
        plot_reconstructions(vae, data.to(device))
        if batch == 5:
            break

In [None]:
for _ in range(5): train()

In [None]:
for _ in range(5): train()

In [37]:
torch.save(vae.state_dict(), './vae_model.pth') 

In [None]:
for _ in range(20): train()

In [39]:
torch.save(vae.state_dict(), 'vae_model.pth')