In [None]:
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Settings shared by the training and test splits.
batch_size = 256
transform = transforms.Compose([ToTensor()])

# Training split: downloaded into ./data if missing and reshuffled on every pass.
train_dataset = FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Test split: same transform, iterated in a fixed order.
test_dataset = FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
import torch
from torchvision import datasets, transforms

def print_dataset_image_dimensions(dataset):
    """Print the shape of the first image in an indexable dataset.

    The first sample (``dataset[0]``) may be either the image itself or a
    sequence whose first element is the image, e.g. ``(image, label)``.
    Both tuples and lists are accepted as sample containers.

    Args:
        dataset: Any object supporting ``dataset[0]``.

    Returns:
        None. The result (or a diagnostic message) is written to stdout;
        errors are reported rather than raised so the notebook keeps running.
    """
    try:
        sample = dataset[0]

        # Samples may be packed as a tuple or as a list; the image comes first.
        image = sample[0] if isinstance(sample, (tuple, list)) else sample

        # Check the image itself, so a shapeless image inside a tuple is
        # reported with the dedicated message instead of a generic error.
        if not hasattr(image, 'shape'):
            print("Error: Dataset sample does not have an image or shape attribute.")
            return

        print(f"Image shape: {image.shape}")
    except IndexError as e:
        # Empty dataset or empty sample container.
        print(f"Error accessing dataset element: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")


# Print the shape of the first training sample's image.
print_dataset_image_dimensions(train_dataset)


In [None]:
# Number of samples in the training and test splits, one per line.
print(len(train_dataset), len(test_dataset), sep="\n")


In [None]:

import matplotlib.pyplot as plt


# Get the first batch of images and labels
data_iter = iter(train_loader)
images, labels = next(data_iter)

# Use ONE index for both the image and its label so that the title
# describes the picture that is actually displayed.
sample_idx = 13
image = images[sample_idx]
label = labels[sample_idx]

plt.figure(figsize=(1, 1))
plt.imshow(image.squeeze(), cmap="gray")  # drop the channel axis for imshow
plt.title(f"Label: {label}")
plt.axis("off")
plt.show()

**The encoder and decoder of the VAE are defined as follows.**
- Encoder (Gaussian distribution): $$q_{\phi}({\bf z}|{\bf x}) = {\mathcal N}({\bf z}| \mu,\sigma^2{\bf I}), \quad \text{s.t.} \quad \mu=g^{\mu}_{\phi}({\bf x}),\ \sigma=g^{\sigma}_{\phi}({\bf x}). $$
- Decoder (Bernoulli distribution): $$p_{\theta}({\bf x}|{\bf z}) = \mathrm{Ber}({\bf x}| \lambda), \quad \text{s.t.} \quad \lambda=f_{\theta}({\bf z}).$$
<br>

**The ELBO (Evidence Lower BOund) is given below.** The first and second terms correspond to the (negative) reconstruction loss and the KL divergence, respectively. In a β-VAE, the KL regularisation term of the objective is weighted by a coefficient β (> 1) in order to make the learned representation more disentangled.


$$
 {\mathcal L}({\bf x};{\bf \theta},{\bf \phi}) = \mathbb{E}_{q_{\phi}({\bf z}|{\bf x})}[\log p_\theta({\bf x}|{\bf z})] -\beta D_{KL}[q_{\phi}({\bf z}|{\bf x})||p_{\theta}({\bf z})]
 $$  

# **MODEL**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BetaVAE(nn.Module):
    """Convolutional beta-VAE for single-channel 28x28 images.

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian over the latent code; the decoder maps a latent code back to an
    image with pixel values in (0, 1) (final ``Sigmoid``).

    Args:
        latent_dim: Size of the latent vector ``z``.
        beta: Weight applied to the KL term in :meth:`loss_function`.
    """

    def __init__(self, latent_dim=10, beta=1.0):
        super().__init__()
        self.latent_dim = latent_dim
        self.beta = beta

        # Encoder: 1x28x28 -> 8x14x14 -> 16x7x7 -> 32x3x3 -> 128 features
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 3 * 3, 128),
        )

        # Two heads on the shared 128-d features: mean and log-variance of q(z|x)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

        # Decoder, dense part: z -> 128 -> 32*3*3 (reshaped to 32x3x3 in decode)
        self.decoder_fc = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 32 * 3 * 3),
            nn.ReLU(),
        )

        # Decoder, transposed-convolution part: mirrors the encoder's downsampling
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=0, output_padding=0),  # From 3x3 to 7x7
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1, output_padding=1),  # From 7x7 to 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, kernel_size=4, stride=2, padding=1, output_padding=0),  # From 14x14 to 28x28
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Encode a batch of images into the parameters of q(z|x).

        Args:
            x: Image batch of shape ``(N, 1, 28, 28)``.

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(N, latent_dim)``.
        """
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode a latent vector `z` into an image.

        Args:
            z: Latent batch of shape ``(N, latent_dim)``.

        Returns:
            Tensor of shape ``(N, 1, 28, 28)`` with values in (0, 1).
        """
        x = self.decoder_fc(z)
        x = x.view(-1, 32, 3, 3)
        x = self.decoder_conv(x)
        return x

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        # Reuse decode() so forward and decode can never drift apart.
        return self.decode(z), mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        """Return the beta-VAE training loss, summed over the whole batch.

        The loss is ``reconstruction + beta * KL`` where

        * reconstruction is the summed squared error between ``recon_x`` and
          ``x`` (note: squared error, not binary cross-entropy), and
        * KL is the closed-form divergence between ``N(mu, exp(logvar))`` and
          the standard normal prior.

        Because both terms are sums, divide by the number of samples to get a
        per-sample value.
        """
        recon_loss = F.mse_loss(recon_x, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + self.beta * kl_loss



# Loss Function

In [None]:
def loss_function(self, recon_x, x, mu, logvar):
    """Summed squared reconstruction error plus ``self.beta`` times the KL term.

    ``self`` only needs a ``beta`` attribute. The KL term is the closed-form
    divergence between ``N(mu, exp(logvar))`` and a standard normal.
    """
    squared_error = F.mse_loss(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return squared_error + self.beta * kl_divergence

# **TRAINING**

In [None]:

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

In [None]:

latent_dim = 10
beta = 0.5  # Adjust based on experiments for balance between reconstruction and disentanglement
epochs = 20
learning_rate = 1e-3
# L2 penalty coefficient for Adam. The previous value of 1 is orders of
# magnitude above typical settings (0 to ~1e-4) and is applied to every
# parameter, including biases and BatchNorm scales, pulling them towards zero.
# 0.0 is Adam's default; raise it slightly only if regularisation is needed.
weight_decay = 0.0


model = BetaVAE(latent_dim=latent_dim, beta=beta).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [None]:
# The loss is a sum over the batch, so dividing the epoch total by the
# dataset size reports an average per-sample loss.
num_samples = len(train_loader.dataset)

for epoch in range(epochs):
    model.train()
    train_loss = 0

    for images, _ in train_loader:  # labels are not needed for the VAE objective
        images = images.to(device)

        # Forward pass and loss
        optimizer.zero_grad()
        recon_images, mu, logvar = model(images)
        loss = model.loss_function(recon_images, images, mu, logvar)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {train_loss / num_samples}")


**Visualization of reconstructed image**

Original image

In [None]:
def visualize_reconstructions_and_generations(model, test_loader, num_examples=100):
    """Plot originals vs. reconstructions, then images decoded from random noise.

    Two figures are shown:

    1. A 2-row grid: up to ``num_examples`` images from ``test_loader`` (top)
       and the model's reconstructions of them (bottom).
    2. A 1-row grid of ``num_examples`` images obtained by decoding latent
       vectors drawn from a standard normal distribution.

    Args:
        model: Trained VAE exposing ``forward`` (returning the reconstruction
            first) and ``decode``. Its ``latent_dim`` attribute is used for
            the noise size when present (falls back to 10).
        test_loader: Iterable yielding ``(images, labels)`` batches.
        num_examples: Number of columns to draw; must be >= 1. Fewer columns
            are drawn in the first figure if the loader has fewer images.

    Raises:
        ValueError: If ``num_examples`` is smaller than 1.
    """
    import matplotlib.pyplot as plt

    if num_examples < 1:
        raise ValueError("num_examples must be a positive integer")

    model.eval()
    # Follow the model's own device instead of relying on a notebook global.
    model_device = next(model.parameters()).device

    originals = []
    reconstructions = []
    num_collected = 0

    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(model_device)
            recon_images, _, _ = model(images)
            originals.append(images.cpu())
            reconstructions.append(recon_images.cpu())

            # Count real samples: the last batch can be smaller than the
            # loader's nominal batch size.
            num_collected += images.size(0)
            if num_collected >= num_examples:
                break

    if originals:
        originals = torch.cat(originals, dim=0)[:num_examples]
        reconstructions = torch.cat(reconstructions, dim=0)[:num_examples]
        num_shown = originals.size(0)

        # squeeze=False keeps `axes` 2-D even when only one column is drawn.
        fig, axes = plt.subplots(2, num_shown, figsize=(num_shown, 3), squeeze=False)
        for i in range(num_shown):
            axes[0, i].imshow(originals[i].squeeze(), cmap="gray")
            axes[0, i].axis("off")
            axes[1, i].imshow(reconstructions[i].squeeze(), cmap="gray")
            axes[1, i].axis("off")
        axes[0, 0].set_title("Originals", fontsize=16)
        axes[1, 0].set_title("Reconstructions", fontsize=16)
        plt.show()
    else:
        print("test_loader yielded no images; skipping the reconstruction figure.")

    # Sample from the prior using the model's own latent size.
    latent_dim = getattr(model, "latent_dim", 10)
    noise = torch.randn(num_examples, latent_dim, device=model_device)
    with torch.no_grad():
        generated_images = model.decode(noise).cpu()

    fig, axes = plt.subplots(1, num_examples, figsize=(num_examples, 2), squeeze=False)
    for i in range(num_examples):
        axes[0, i].imshow(generated_images[i].squeeze(), cmap="gray")
        axes[0, i].axis("off")
    plt.suptitle("Generated Examples", fontsize=16)
    plt.show()


In [None]:
# Plot test-set originals against their reconstructions, then images decoded from random latent noise.
visualize_reconstructions_and_generations(model, test_loader, num_examples=100)

Reconstructed image

Random sampling from latent variable

In [None]:
from torchvision.models import inception_v3
from scipy.linalg import sqrtm
import numpy as np



def calculate_fid(real_features, generated_features, eps=1e-6):
    """Compute the Frechet distance between two sets of feature vectors.

    Each set is summarised by its mean and covariance, and the distance is

        ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 * sqrt(S_r @ S_g)).

    Args:
        real_features: Array-like of shape ``(n_real, d)``.
        generated_features: Array-like of shape ``(n_gen, d)``.
        eps: Value added to the covariance diagonals when the matrix square
            root of their product is not finite (near-singular covariances).

    Returns:
        The distance as a Python float.

    Raises:
        ValueError: If an input is not 2-D, the feature sizes differ, or a
            set has fewer than two samples (covariance undefined).
    """
    # float64 keeps the covariance and matrix square root numerically stable
    # even when the features arrive as float32.
    real = np.asarray(real_features, dtype=np.float64)
    gen = np.asarray(generated_features, dtype=np.float64)

    if real.ndim != 2 or gen.ndim != 2:
        raise ValueError("features must be 2-D arrays of shape (n_samples, n_features)")
    if real.shape[1] != gen.shape[1]:
        raise ValueError("real and generated features must have the same dimensionality")
    if real.shape[0] < 2 or gen.shape[0] < 2:
        raise ValueError("at least two samples per set are required to estimate a covariance")

    mu_real = np.mean(real, axis=0)
    mu_gen = np.mean(gen, axis=0)
    # atleast_2d: np.cov returns a 0-d array when there is a single feature.
    sigma_real = np.atleast_2d(np.cov(real, rowvar=False))
    sigma_gen = np.atleast_2d(np.cov(gen, rowvar=False))

    diff = mu_real - mu_gen

    # Called without `disp`, sqrtm returns just the matrix; this avoids
    # depending on the (sqrtm, error_estimate) tuple form.
    covmean = sqrtm(sigma_real.dot(sigma_gen))
    if not np.isfinite(covmean).all():
        # Near-singular product: regularise the diagonals and retry.
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = sqrtm((sigma_real + offset).dot(sigma_gen + offset))

    # Numerical error can introduce a small imaginary component.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma_real) + np.trace(sigma_gen) - 2 * np.trace(covmean)
    return float(fid)

# Inception-v3 as a fixed feature extractor: the classification head is
# replaced by an identity so the network outputs its pooled features.
inception = inception_v3(pretrained=True, transform_input=False).to(device)
inception.fc = nn.Identity()
# eval() is required: in training mode the network returns a
# (logits, aux_logits) tuple and BatchNorm/Dropout behave stochastically.
inception.eval()
model.eval()


def _inception_features(batch):
    """Return Inception features (NumPy array) for a batch of 1-channel images."""
    # Resize images to 75x75 and convert to RGB by repeating the grey channel
    resized = F.interpolate(batch, size=(75, 75), mode='bilinear', align_corners=False)
    return inception(resized.repeat(1, 3, 1, 1)).cpu().numpy()


real_features = []
gen_features = []

with torch.no_grad():
    for images, _ in test_loader:
        images = images.to(device)
        real_features.append(_inception_features(images))

        # Decode one prior sample per real image so both feature sets have the
        # same size; the images are generated here rather than taken from a
        # variable that exists only inside the visualisation function.
        z = torch.randn(images.size(0), model.latent_dim, device=device)
        gen_features.append(_inception_features(model.decode(z)))

real_features = np.concatenate(real_features, axis=0)
gen_features = np.concatenate(gen_features, axis=0)

fid_score = calculate_fid(real_features, gen_features)
print(f"FID Score: {fid_score}")




