In [None]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import opendatasets as od
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Fetch the Kaggle "dog-images" dataset into the working directory; the
# default root in get_image_urls ("dog-images/images/images") assumes this
# download location.
# NOTE(review): Kaggle downloads need credentials — opendatasets presumably
# prompts for them or reads a local kaggle.json; confirm before unattended runs.
od.download('https://www.kaggle.com/datasets/yaswanthgali/dog-images/')

In [None]:
def get_image_urls(root_folder="dog-images/images/images",
                   image_extensions=('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp')):
    """Recursively collect image file paths under ``root_folder``.

    Each entry is ``os.path.join(folder, file)`` with OS separators normalised
    to ``/`` and a single ``/`` prepended (the dataset class strips that prefix
    again before opening the file).

    The walk is deterministic: sub-directories and files are visited in sorted
    order, so the same tree always yields the same list on every machine.

    Args:
        root_folder: Directory to search. A missing directory yields ``[]``
            (``os.walk`` silently produces nothing for it).
        image_extensions: Extension or tuple of extensions to accept; matching
            is case-insensitive.

    Returns:
        list[str]: The collected paths, in sorted walk order.
    """
    # A bare string would otherwise be split into characters by tuple().
    if isinstance(image_extensions, str):
        image_extensions = (image_extensions,)
    extensions = tuple(ext.lower() for ext in image_extensions)

    image_urls = []
    for folder_path, dir_names, file_names in os.walk(root_folder):
        # os.walk yields entries in arbitrary file-system order. Sorting
        # dir_names *in place* steers the top-down walk; sorting the files
        # makes the output (and any seeded shuffling of it) reproducible.
        dir_names.sort()
        for file_name in sorted(file_names):
            if file_name.lower().endswith(extensions):
                relative_path = os.path.join(folder_path, file_name)
                image_urls.append(f"/{relative_path.replace(os.sep, '/')}")

    return image_urls

In [None]:
def pad_to_square(img, fill=0):
    """Pad a PIL image with a constant border so that it becomes square.

    The image is centred on a ``max(width, height)``-sided canvas. When the
    size difference is odd, the extra pixel of padding ends up on the
    bottom/right, because the paste offset is floored.

    Args:
        img: PIL image of any mode.
        fill: Canvas colour passed to ``Image.new``. The default ``0`` is black
            for single-band and multi-band modes alike. The previous
            hard-coded ``(0, 0, 0)`` tuple only suits 3-band modes.

    Returns:
        ``img`` itself if it is already square, otherwise a new square image
        of the same mode.
    """
    width, height = img.size
    if width == height:
        return img
    side = max(width, height)
    canvas = Image.new(img.mode, (side, side), fill)
    canvas.paste(img, ((side - width) // 2, (side - height) // 2))
    return canvas

class VAEDataset(Dataset):
    """Map-style dataset yielding preprocessed, label-free images for a VAE.

    Each item is loaded from disk, converted to RGB and passed through
    ``transform``. ``__getitem__`` returns only the transformed image, with no
    label.

    Args:
        image_urls: Sequence of image paths. A single leading ``/`` is treated
            as a prefix and removed (the format produced by ``get_image_urls``).
        transform: Callable applied to the RGB PIL image. When ``None``, the
            default pipeline is used: pad to square -> resize to 512x512 ->
            ``ToTensor``.
    """

    def __init__(self, image_urls, transform=None):
        self.image_urls = image_urls
        if transform is None:
            transform = transforms.Compose([
                transforms.Lambda(pad_to_square),
                transforms.Resize((512, 512)),
                transforms.ToTensor(),
            ])
        self.transform = transform

    def __len__(self):
        return len(self.image_urls)

    def __getitem__(self, idx):
        url = self.image_urls[idx]
        # Drop exactly the one "/" prefix. lstrip('/') would also eat the root
        # of an absolute path ("//data/x.jpg" -> "data/x.jpg").
        img_path = url[1:] if url.startswith('/') else url
        # The context manager closes the file handle. convert() returns an
        # independent, loaded image, so it stays usable after the file closes.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        return self.transform(img)
    
def get_vae_dataloader(image_urls, batch_size=32, shuffle=True, num_workers=4):
    """Wrap ``image_urls`` in a VAEDataset (default transform) and a DataLoader.

    Args:
        image_urls: Image paths accepted by VAEDataset.
        batch_size: Images per batch.
        shuffle: Whether to reshuffle the data every epoch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        DataLoader: Loader yielding batches of image tensors.
    """
    return DataLoader(
        VAEDataset(image_urls),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


In [None]:
image_urls = get_image_urls()
# Fail fast with a readable message; an empty list would otherwise surface
# later as an opaque sampler error ("num_samples should be a positive integer").
assert image_urls, "No images found under dog-images/images/images - did the download cell succeed?"
# Reuse the helper instead of re-building the Dataset/DataLoader pair inline.
dataloader = get_vae_dataloader(image_urls, batch_size=128, shuffle=True, num_workers=4)
dataset = dataloader.dataset


In [None]:
class VAE(pl.LightningModule):
    """Convolutional variational autoencoder for square RGB images.

    Encoder:
        Four stride-2 convolutions (3 -> 32 -> 64 -> 128 -> 256 channels)
        halve each spatial side four times. Two linear heads then produce the
        posterior mean ``mu`` and log-variance ``log_var``.

    Decoder:
        Mirrors the encoder: a linear projection back to the 256-channel
        feature map, four stride-2 transposed convolutions, and a sigmoid, so
        outputs lie in (0, 1).

    Args:
        latent_dim: Dimensionality of the latent vector ``z``.
        image_size: Side length of the square inputs. Must be divisible by
            16. Defaults to 512, matching the dataset's resize.
        lr: Learning rate for Adam.
    """

    def __init__(self, latent_dim=128, image_size=512, lr=1e-3):
        super().__init__()
        if image_size % 16 != 0:
            raise ValueError(f"image_size must be divisible by 16, got {image_size}")
        self.save_hyperparameters()
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.lr = lr
        # Side of the deepest feature map (four stride-2 halvings).
        self.feature_size = image_size // 16
        flat_dim = 256 * self.feature_size * self.feature_size

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_var = nn.Linear(flat_dim, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, flat_dim)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Return ``(mu, log_var)``, each of shape (batch, latent_dim), for images ``x``."""
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors of shape (batch, latent_dim) to images in (0, 1)."""
        x = self.decoder_input(z)
        x = x.view(-1, 256, self.feature_size, self.feature_size)
        x = self.decoder(x)
        return x

    def forward(self, x):
        """Encode, sample ``z`` and decode; returns ``(recon, mu, log_var)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def training_step(self, batch, batch_idx):
        """Compute the per-sample negative ELBO (sum-of-squares recon + KL)."""
        x = batch
        recon_x, mu, log_var = self(x)
        batch_size = x.size(0)
        # Both terms are summed over their own dimensions (pixels / latent
        # units) and averaged over the batch, so their relative weight is the
        # standard VAE one. Do not multiply the summed MSE by the pixel count.
        # That inflates reconstruction ~786k-fold vs. KL at 512x512x3,
        # effectively disabling the prior and degrading random samples.
        recon_loss = F.mse_loss(recon_x, x, reduction='sum') / batch_size
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch_size
        loss = recon_loss + kl_div
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_recon', recon_loss)
        self.log('train_kl', kl_div)
        return loss

    def on_train_epoch_end(self):
        # Test the model at the end of each epoch. Only the global-zero process
        # plots, so multi-device runs do not emit duplicate figures.
        if self.trainer.is_global_zero:
            self.test_model()

    def test_model(self):
        """Visualise results for one batch drawn from the training dataloader."""
        batch = next(iter(self.trainer.train_dataloader))
        self.visualize_results(batch, f"Epoch {self.current_epoch}")

    @staticmethod
    def _to_numpy_image(tensor):
        """Convert a (C, H, W) image tensor to a float32 (H, W, C) array for imshow.

        ``.float()`` is needed because ``.numpy()`` rejects bfloat16 tensors,
        which can appear under ``precision="bf16-mixed"``.
        """
        return tensor.detach().float().cpu().permute(1, 2, 0).numpy()

    def visualize_results(self, batch, title_prefix):
        """Plot reconstructions, prior samples and a latent-space interpolation.

        Up to the first two images of ``batch`` are used. The module is
        switched to eval mode while plotting, and its previous mode is restored
        afterwards, so training resumes in train mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Reconstructions of (up to) two dataset images.
                images = batch[:2].to(self.device)
                recon, _, _ = self(images)
                for i in range(len(images)):
                    self.plot_results(
                        self._to_numpy_image(images[i]),
                        self._to_numpy_image(recon[i]),
                        f"{title_prefix} - Sample {i+1} from Dataset",
                    )

                # Samples decoded from the N(0, I) prior. There is no
                # "original" to compare against, so show them as a plain row.
                noise = torch.randn(2, self.latent_dim, device=self.device)
                self.plot_interpolation(
                    self.decode(noise),
                    f"{title_prefix} - Generated from Random Noise",
                )

                # Linear interpolation between the posterior means of two
                # images. This needs two images, so skip it for a 1-image batch.
                if len(images) >= 2:
                    mu, _ = self.encode(images)
                    alphas = torch.linspace(0, 1, 7, device=self.device).unsqueeze(1)
                    interpolations = alphas * mu[0] + (1 - alphas) * mu[1]
                    self.plot_interpolation(
                        self.decode(interpolations),
                        f"{title_prefix} - Interpolation in Latent Space",
                    )
        finally:
            # self.eval() alone would leave the module in eval mode for the
            # rest of training.
            self.train(was_training)

        plt.show()

    def plot_results(self, original, reconstructed, title):
        """Show an (H, W, C) original and its reconstruction side by side."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        ax1.imshow(original)
        ax1.set_title('Original')
        ax1.axis('off')
        ax2.imshow(reconstructed)
        ax2.set_title('Reconstructed')
        ax2.axis('off')
        plt.suptitle(title)

    def plot_interpolation(self, generated, title):
        """Show a row of decoded (C, H, W) images, one panel per image."""
        count = len(generated)
        # squeeze=False keeps `axes` 2-D even when count == 1.
        fig, axes = plt.subplots(1, count, figsize=(3 * count, 3.5), squeeze=False)
        for ax, image in zip(axes[0], generated):
            ax.imshow(self._to_numpy_image(image))
            ax.axis('off')
        plt.suptitle(title)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:

# Seed Python, NumPy and torch RNGs (and DataLoader workers) so weight init,
# shuffling and the reparameterisation noise are reproducible across runs.
SEED = 42
pl.seed_everything(SEED, workers=True)

model = VAE()
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",
    devices="auto",
    strategy="auto",
    accumulate_grad_batches=4,  # with the batch_size=128 loader: 512 images per optimizer step
    precision="bf16-mixed",
)
trainer.fit(model, dataloader)

torch.save(model.state_dict(), 'vae_model.pth')


In [None]:
# Generated images are not great, likely due to the extreme variation in the dataset.