In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt

In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
image_size = 64       # images are resized to image_size x image_size pixels
batch_size = 128
latent_dim = 100      # size of the latent vector
num_epochs = 5
learning_rate = 1e-3
seed = 42

# Shuffling, weight initialisation and latent sampling all draw from PyTorch's RNG;
# seeding it makes a Restart & Run All repeatable (GPU kernels may still add small nondeterminism).
torch.manual_seed(seed)

# Define transforms for both CelebA and MNIST.
# Both pipelines end with Normalize(mean=0.5, std=0.5), which maps ToTensor's [0, 1]
# range to [-1, 1].
transform_celeba = transforms.Compose([
  transforms.Resize(image_size),      # shorter side -> image_size
  transforms.CenterCrop(image_size),  # square crop of image_size x image_size
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize for RGB channels
])

transform_mnist = transforms.Compose([
  transforms.Resize(image_size),
  transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,)) # Normalize for the single grayscale channel
])

# Define a placeholder CelebADatasetManual class if it's not defined elsewhere
class CelebADatasetManual(torch.utils.data.Dataset):
    """Unlabeled image-folder dataset yielding ``(image, 0)`` pairs.

    Images are read from the files located directly inside ``root_dir``
    (sub-directories are not traversed). The label is always the dummy value
    ``0`` so the dataset can be consumed by loops written for labeled
    datasets such as MNIST.
    """

    # Files with any other extension are ignored, so stray files in the folder
    # do not end up as all-zero "images" in the training data.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample. It receives a PIL image (the torchvision
                convention), so a pipeline that yields tensors must include
                ``ToTensor()``.
        """
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        self.image_files = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the image at position ``idx``.

        Returns:
            tuple: ``(image, 0)``. With a transform, ``image`` is whatever the
            transform returns; without one it is a float tensor of shape
            ``(3, H, W)`` scaled to [0, 1]. If the file cannot be decoded, the
            error is printed and an all-zero ``(3, image_size, image_size)``
            tensor is returned instead so a single bad file does not abort an
            epoch.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.image_files[idx]
        try:
            # Decode straight to 3-channel RGB: grayscale files are expanded and
            # an alpha channel, if present, is dropped.
            image = torchvision.io.read_image(img_path, mode=torchvision.io.ImageReadMode.RGB)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Best-effort fallback: dummy image plus dummy label.
            return torch.zeros((3, image_size, image_size)), 0

        if self.transform:
            # ToTensor() inside a torchvision pipeline rejects tensors, so hand the
            # transform a PIL image rather than the decoded uint8 tensor.
            image = self.transform(torchvision.transforms.functional.to_pil_image(image))
        else:
            # No transform: convert from uint8 to float and scale to [0, 1].
            image = image.float() / 255.0

        # The images carry no class labels here; return a dummy label.
        return image, 0


# Use the local CelebA folder when it is a directory that actually contains images;
# otherwise fall back to MNIST (downloaded on first use).
celeba_root_dir = '/content/img_align_celeba' # Update this path if needed
dataset = None
if os.path.isdir(celeba_root_dir):
    celeba_dataset = CelebADatasetManual(root_dir=celeba_root_dir, transform=transform_celeba)
    # An empty folder would leave nothing to train on, so treat it as "not found".
    if len(celeba_dataset) > 0:
        dataset = celeba_dataset

if dataset is not None:
    print("Using CelebA dataset")
else:
    print("CelebA dataset not found, using MNIST dataset")
    dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform_mnist)
    # dataset = Subset(dataset, range(500)) # Uncomment to use a smaller subset of MNIST for faster testing


dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

CelebA dataset not found, using MNIST dataset


100%|██████████| 9.91M/9.91M [00:00<00:00, 20.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 481kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.74MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.34MB/s]


In [None]:
class Encoder(nn.Module):
  """Convolutional encoder mapping an image batch to latent Gaussian parameters.

  Four stride-2 convolutions halve the spatial size each time, and the linear
  heads expect a 512 x 4 x 4 feature map, so inputs must be 64 x 64.

  Args:
    latent_dim: Size of the latent vector.
    in_channels: Number of channels of the input images. Defaults to 3 (RGB);
      pass 1 for grayscale data such as MNIST.
  """

  def __init__(self, latent_dim, in_channels=3):
    super(Encoder, self).__init__()
    self.conv = nn.Sequential(
      nn.Conv2d(in_channels, 64, 4, 2, 1), nn.ReLU(),
      nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
      nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(),
      nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.ReLU()
    )
    # Two heads over the same flattened features: mean and log-variance.
    self.fc_mu = nn.Linear(512*4*4, latent_dim)
    self.fc_logvar = nn.Linear(512*4*4, latent_dim)

  def forward(self, x):
    """Return ``(mu, logvar)``, each of shape ``(batch, latent_dim)``."""
    x = self.conv(x)
    x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
    return self.fc_mu(x), self.fc_logvar(x)

In [None]:
class Decoder(nn.Module):
  """Transposed-convolution decoder mapping latent vectors to 64 x 64 images.

  A linear layer expands the latent vector to a 512 x 4 x 4 feature map, which
  four stride-2 transposed convolutions upsample to 64 x 64. The final Tanh
  keeps the output in [-1, 1].

  Args:
    latent_dim: Size of the latent vector.
    out_channels: Number of channels of the generated images. Defaults to 3
      (RGB); pass 1 for grayscale data such as MNIST.
  """

  def __init__(self, latent_dim, out_channels=3):
    super(Decoder, self).__init__()
    self.fc = nn.Linear(latent_dim, 512*4*4)
    self.deconv = nn.Sequential(
      nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(),
      nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
      nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),
      nn.ConvTranspose2d(64, out_channels, 4, 2, 1), nn.Tanh()
    )

  def forward(self, z):
    """Decode ``z`` of shape ``(batch, latent_dim)`` into ``(batch, out_channels, 64, 64)``."""
    x = self.fc(z)
    x = x.view(x.size(0), 512, 4, 4)  # reshape the flat vector into a feature map
    return self.deconv(x)

In [None]:
class VAE(nn.Module):
  """Variational autoencoder built from an ``Encoder`` and a ``Decoder``."""

  def __init__(self, latent_dim):
    super().__init__()
    self.encoder = Encoder(latent_dim)
    self.decoder = Decoder(latent_dim)

  def reparameterize(self, mu, logvar):
    """Draw ``mu + eps * sigma`` with ``eps ~ N(0, 1)`` and ``sigma = exp(logvar / 2)``."""
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

  def forward(self, x):
    """Return ``(reconstruction, mu, logvar)`` for the input batch ``x``."""
    mu, logvar = self.encoder(x)
    latent = self.reparameterize(mu, logvar)
    reconstruction = self.decoder(latent)
    return reconstruction, mu, logvar

In [None]:
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed VAE objective for one batch.

    The result is the squared reconstruction error summed over every element,
    plus the KL divergence between N(mu, exp(logvar)) and the standard normal,
    also summed over every element.
    """
    reconstruction_term = F.mse_loss(recon_x, x, reduction='sum')
    kl_elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_elementwise.sum()
    return reconstruction_term + kl_term

In [None]:
# Channel count the model must handle: 3 (RGB) for the manual CelebA dataset,
# 1 for anything else (i.e. the MNIST fallback).
if isinstance(dataset, CelebADatasetManual):
    input_channels = 3
else:
    input_channels = 1

class Encoder(nn.Module):
  """Convolutional encoder producing the latent mean and log-variance.

  Each of the four convolutions uses kernel 4, stride 2, padding 1; the linear
  heads take a flattened 512 x 4 x 4 feature map.
  """

  def __init__(self, latent_dim, in_channels):
    super().__init__()
    widths = (64, 128, 256, 512)
    layers = []
    previous = in_channels
    for position, width in enumerate(widths):
      layers.append(nn.Conv2d(previous, width, 4, 2, 1))
      # Every convolution except the first is followed by batch normalisation.
      if position > 0:
        layers.append(nn.BatchNorm2d(width))
      layers.append(nn.ReLU())
      previous = width
    self.conv = nn.Sequential(*layers)

    flat_features = 512*4*4
    self.fc_mu = nn.Linear(flat_features, latent_dim)
    self.fc_logvar = nn.Linear(flat_features, latent_dim)

  def forward(self, x):
    """Return ``(mu, logvar)`` computed from the flattened conv features."""
    features = self.conv(x)
    flat = features.view(features.size(0), -1)
    mu = self.fc_mu(flat)
    logvar = self.fc_logvar(flat)
    return mu, logvar

class Decoder(nn.Module):
  """Decoder that expands a latent vector into an image via transposed convolutions.

  The latent vector is projected to a 512 x 4 x 4 feature map and upsampled by
  four kernel-4, stride-2, padding-1 transposed convolutions; the output passes
  through Tanh.
  """

  def __init__(self, latent_dim, out_channels):
    super().__init__()
    self.fc = nn.Linear(latent_dim, 512*4*4)

    hidden = [512, 256, 128, 64]
    stages = []
    for c_in, c_out in zip(hidden[:-1], hidden[1:]):
      stages += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU()]
    # Output stage: no batch norm, Tanh instead of ReLU.
    stages += [nn.ConvTranspose2d(hidden[-1], out_channels, 4, 2, 1), nn.Tanh()]
    self.deconv = nn.Sequential(*stages)

  def forward(self, z):
    """Map latent vectors ``z`` to images."""
    projected = self.fc(z)
    feature_map = projected.view(projected.size(0), 512, 4, 4)
    return self.deconv(feature_map)

class VAE(nn.Module):
  """Variational autoencoder with configurable input and output channel counts."""

  def __init__(self, latent_dim, in_channels, out_channels):
    super().__init__()
    self.encoder = Encoder(latent_dim, in_channels)
    self.decoder = Decoder(latent_dim, out_channels)

  def reparameterize(self, mu, logvar):
    """Sample a latent vector as ``mu + eps * std`` with ``eps`` standard normal."""
    half_logvar = 0.5 * logvar
    std = torch.exp(half_logvar)
    return mu + torch.randn_like(std) * std

  def forward(self, x):
    """Encode ``x``, sample a latent vector, and return ``(reconstruction, mu, logvar)``."""
    mu, logvar = self.encoder(x)
    return self.decoder(self.reparameterize(mu, logvar)), mu, logvar

# Build the model for the channel count of the selected dataset; the decoder
# emits as many channels as the encoder consumes.
model = VAE(latent_dim, input_channels, input_channels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
  # --- training pass ---
  model.train()
  total_loss = 0.0
  for images, _ in dataloader:  # labels are not used by the VAE
    images = images.to(device)
    recon, mu, logvar = model(images)
    loss = vae_loss(recon, images, mu, logvar)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
  # vae_loss sums over the batch, so dividing by the dataset size gives a per-sample average.
  print(f"Epoch [{epoch+1}/{num_epochs}], Loss:{total_loss/len(dataloader.dataset):.4f}")

  # --- sample from the prior to monitor generation quality ---
  model.eval()
  with torch.no_grad():
    z = torch.randn(64, latent_dim, device=device)
    # The decoder ends in Tanh ([-1, 1]); undo the Normalize(0.5, 0.5) applied to the
    # training data to get back to [0, 1]. No extra min-max rescaling is applied, so
    # the figure shows the model's actual output intensities.
    sample_images = (model.decoder(z).cpu() * 0.5 + 0.5).clamp(0, 1)
    grid = torchvision.utils.make_grid(sample_images, nrow=8)

  # make_grid expands single-channel images to three channels, so the grid is always
  # RGB and no colormap is needed.
  fig, ax = plt.subplots()
  ax.imshow(grid.permute(1, 2, 0))
  if input_channels == 1:
    ax.set_title(f"Generated Digits - Epoch {epoch+1}")
  else:
    ax.set_title(f"Generated Faces - Epoch {epoch+1}")
  ax.axis('off')
  plt.show()
  plt.close(fig)  # avoid accumulating one open figure per epoch

RuntimeError: output with shape [1, 64, 64] doesn't match the broadcast shape [3, 64, 64]