In [11]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from typing import Optional, Callable, Tuple
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from torch import nn, optim
from torch.nn import functional as F
from torchvision import models
from torch.utils.data.dataset import random_split

In [12]:
class plainVAE(nn.Module):
    """Convolutional variational autoencoder for 3x64x64 RGB images.

    The encoder halves the spatial resolution four times
    (64 -> 32 -> 16 -> 8 -> 4), so the flattened feature map that feeds
    ``fc_mu`` / ``fc_logvar`` has ``512 * 4 * 4`` elements. Inputs with any
    other spatial size fail inside those linear layers. The decoder mirrors
    the encoder and ends in a sigmoid, so reconstructions lie in [0, 1].

    Args:
        latent_dim: Dimensionality of the latent vector ``z``.
    """

    def __init__(self, latent_dim: int = 128):
        super().__init__()
        self.latent_dim = latent_dim

        # NOTE: the position of each layer inside the nn.Sequential containers
        # determines the state_dict keys ("encoder.0.weight", ...). Keep the
        # order unchanged so previously saved checkpoints remain loadable.

        # Encoder: 3x64x64 image -> 512x4x4 feature map
        self.encoder = nn.Sequential(
            # Layer 1: 3x64x64 -> 64x32x32
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),

            # Layer 2: 64x32x32 -> 128x16x16
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # Layer 3: 128x16x16 -> 256x8x8
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # Layer 4: 256x8x8 -> 512x4x4
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # A fifth stride-2 stage would be needed to accept 128x128 inputs
            # while keeping the 4x4 bottleneck expected by the linear layers.
        )

        # Heads producing the mean and log-variance of q(z | x)
        self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

        # Decoder: latent vector -> 512x4x4 feature map -> 3x64x64 image
        self.decoder_fc = nn.Linear(latent_dim, 512 * 4 * 4)
        self.decoder = nn.Sequential(
            # Layer 1: 512x4x4 -> 256x8x8
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # Layer 2: 256x8x8 -> 128x16x16
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # Layer 3: 128x16x16 -> 64x32x32
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # Layer 4: 64x32x32 -> 3x64x64, squashed to [0, 1] by the sigmoid
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Map a batch of images to the parameters of q(z | x).

        Args:
            x: Image batch of shape (B, 3, 64, 64).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (B, latent_dim).
        """
        h = self.encoder(x)
        # flatten() also copes with non-contiguous (e.g. channels_last)
        # activations, for which view() can raise a stride error.
        h = torch.flatten(h, start_dim=1)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterization(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Noise is drawn on every call, in both train and eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent vectors of shape (B, latent_dim) into (B, 3, 64, 64) images."""
        h = self.decoder_fc(z)
        h = h.view(h.size(0), 512, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            Tuple ``(x_reconstructed, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterization(mu, logvar)
        x_reconstructed = self.decode(z)
        return x_reconstructed, mu, logvar

    def loss(self, x, recon_x, mu, logvar, beta: float = 1.0):
        """Compute the beta-weighted VAE objective.

        Args:
            x: Original images.
            recon_x: Reconstructed images, same shape as ``x``.
            mu: Posterior means, shape (B, latent_dim).
            logvar: Posterior log-variances, shape (B, latent_dim).
            beta: Weight applied to the KL term.

        Returns:
            Tuple ``(total, recon_loss, kl)`` where ``recon_loss`` is the
            squared error summed over all elements and divided by the batch
            size, ``kl`` is the closed-form KL divergence to N(0, I) divided
            by the batch size, and ``total = recon_loss + beta * kl``.
        """
        batch_size = x.size(0)
        # mse_loss takes (input, target); the value is symmetric in the two.
        recon_loss = F.mse_loss(recon_x, x, reduction='sum') / batch_size
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size

        return recon_loss + beta * kl, recon_loss, kl

In [13]:
class VocalPortraitDataset:
    """Map-style dataset of the image files found recursively under a root folder.

    Every sample is ``(image, label)`` where ``label`` is the image path
    relative to ``root_path`` (e.g. "person001/English/source001/face001.jpg").

    Args:
        root_path: Directory that is walked recursively for image files.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_path: str, transform: Optional[Callable] = None):

        self.root_path = root_path
        self.transform = transform
        self.samples = self._build_samples()  # List of (image_path, label) tuples

    def _build_samples(self) -> list[Tuple[str, str]]:
        """
        Returns:
            List of tuples containing (image_path, label) where label is the
            path relative to ``root_path``
            (e.g. "person001/English/source001/face001.jpg"), sorted by label.
        """
        samples = []

        for root, _, files in os.walk(self.root_path):
            for file in files:
                if not file.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                full_path = os.path.join(root, file)
                rel_path = os.path.relpath(full_path, start=self.root_path)
                samples.append((full_path, rel_path))

        # os.walk yields entries in an arbitrary, filesystem-dependent order.
        # Sorting makes the index -> sample mapping (and therefore any seeded
        # split of the dataset) reproducible across runs and machines.
        samples.sort(key=lambda sample: sample[1])
        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        """
        Returns:
            Tuple containing:
            - The transformed image (the raw RGB PIL image if no transform is set)
            - The label string (path relative to ``root_path``)
        """
        image_path, label = self.samples[idx]
        # The context manager releases the file handle deterministically;
        # convert() returns an independent in-memory copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [14]:
# Preprocessing applied to every image: resize, optional colour augmentation,
# then conversion to a tensor.
transform = transforms.Compose(
    transforms=[
        # 64x64 is the input size for which plainVAE's four stride-2 convs
        # produce the 512x4x4 feature map its linear layers expect.
        transforms.Resize(size=(64, 64)),
        # With probability 0.5, perturb brightness/contrast/saturation.
        transforms.RandomApply(
            transforms=[
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            ],
            p=0.5,
        ),
        # Converts the PIL image to a float tensor with values in [0, 1].
        transforms.ToTensor(),
    ]
)

In [15]:
# Custom collation so that string labels survive batching unchanged
def collate_fn(batch):
    """Collate ``(image, label)`` samples into ``(stacked_images, label_list)``.

    Images are stacked along a new leading batch dimension; labels are
    returned as a plain Python list in the same order.
    """
    images, labels = [], []
    for sample in batch:
        images.append(sample[0])
        labels.append(sample[1])
    return torch.stack(images), labels


# Build the dataset and split it 80/20 into train and test subsets.
dataset = VocalPortraitDataset(root_path="data/mavceleb_train/faces", transform=transform)
train_size = int(len(dataset) * 0.8)
lengths = [train_size, len(dataset) - train_size]
# Seed the split explicitly: without a fixed generator every kernel restart
# draws a different partition, so a model resumed from a checkpoint could be
# trained on images that were in the test subset during an earlier session.
# (Reproducibility also requires the dataset to enumerate its files in a
# stable order.)
split_seed = 42
train_dataset, test_dataset = random_split(
    dataset, lengths, generator=torch.Generator().manual_seed(split_seed)
)
batch_size = 64


In [16]:
# Both loaders share every setting except the subset they read from and
# whether it is shuffled (training: yes, test: no), so they are built from a
# single parameterised expression.
train_dataloader, test_dataloader = (
    DataLoader(
        dataset=subset,
        batch_size=batch_size,
        shuffle=reshuffle,
        collate_fn=collate_fn,  # keeps labels as a list of strings
        num_workers=4,
        pin_memory=True,  # faster host-to-GPU transfer when CUDA is used
    )
    for subset, reshuffle in ((train_dataset, True), (test_dataset, False))
)

In [None]:
# ----------------------------
# 3) Device, model and optimizer
# ----------------------------
# The device is pinned to the CPU. To pick CUDA automatically when it is
# available, use the line below instead:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

model = plainVAE(latent_dim=128)
model.to(device)  # nn.Module.to moves the parameters in place and returns the module
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
print(f'Using device: {device}')

In [None]:
# ----------------------------
# 4) Training loop
# ----------------------------
num_epochs = 1
log_interval = 100  # number of batches averaged per progress printout

# Checkpoint configuration
resume = True
save_checkpoints = False  # set to True to write model/optimizer state after every epoch
checkpoint_dir = "checkpoints"
checkpoint_prefix = "vae_epoch"
os.makedirs(checkpoint_dir, exist_ok=True)

# Per-sample latent statistics (label, mu, logvar) collected during the final epoch
train_latent_data = []


def _checkpoint_epoch(filename):
    """Return the epoch encoded in a ``vae_epoch<N>.<ext>`` file name, or None if it has none."""
    stem = filename[len(checkpoint_prefix):].split(".")[0]
    try:
        return int(stem)
    except ValueError:
        return None


# Resume logic
start_epoch = 1
if resume:
    # Compare epochs numerically: a lexicographic sort of the file names would
    # rank "vae_epoch9.pth" after "vae_epoch10.pth" and resume from the wrong one.
    saved_epochs = {}
    for checkpoint_file in os.listdir(checkpoint_dir):
        if not checkpoint_file.startswith(checkpoint_prefix):
            continue
        saved_epoch = _checkpoint_epoch(checkpoint_file)
        if saved_epoch is not None:
            saved_epochs[saved_epoch] = checkpoint_file

    if saved_epochs:
        latest_epoch = max(saved_epochs)
        latest_checkpoint = os.path.join(checkpoint_dir, saved_epochs[latest_epoch])
        # map_location lets a checkpoint written on another device (e.g. a GPU)
        # be restored onto the device used in this session.
        model.load_state_dict(torch.load(latest_checkpoint, map_location=device))
        print(f"Resumed model from {latest_checkpoint}")
        start_epoch = latest_epoch + 1

        # Restore the optimizer state as well when it was saved next to the model.
        optimizer_checkpoint = os.path.join(checkpoint_dir, f"optimizer_epoch{latest_epoch}.pth")
        if os.path.exists(optimizer_checkpoint):
            optimizer.load_state_dict(torch.load(optimizer_checkpoint, map_location=device))
            print(f"Resumed optimizer from {optimizer_checkpoint}")

if start_epoch > num_epochs:
    print(f"Nothing to train: start epoch {start_epoch} is beyond num_epochs={num_epochs}.")

for epoch in range(start_epoch, num_epochs + 1):
    model.train()
    running_recon_loss = 0.0
    running_kl_loss = 0.0
    running_total_loss = 0.0

    for batch_idx, (imgs, labels) in enumerate(train_dataloader, start=1):
        imgs = imgs.to(device)   # shape (B, 3, 64, 64) after the Resize((64, 64)) transform
        optimizer.zero_grad()

        recon_imgs, mu, logvar = model(imgs)
        loss, recon_loss, kl_loss = model.loss(imgs, recon_imgs, mu, logvar, beta=1.0)
        loss.backward()
        optimizer.step()

        # Store training latent variables for last epoch only
        if epoch == num_epochs:
            # Convert once per batch rather than once per sample.
            mu_rows = mu.detach().cpu().numpy()
            logvar_rows = logvar.detach().cpu().numpy()
            train_latent_data.extend(
                {"label": label, "mu": m, "logvar": lv}
                for label, m, lv in zip(labels, mu_rows, logvar_rows)
            )

        running_total_loss += loss.item()
        running_recon_loss += recon_loss.item()
        running_kl_loss += kl_loss.item()

        if batch_idx % log_interval == 0:
            avg_total = running_total_loss / log_interval
            avg_recon = running_recon_loss / log_interval
            avg_kl    = running_kl_loss / log_interval
            print(f"Epoch [{epoch}/{num_epochs}]  "
                  f"Batch [{batch_idx}/{len(train_dataloader)}]  "
                  f"Total Loss: {avg_total:.3f}  "
                  f"Reconstruction: {avg_recon:.3f}  "
                  f"KL: {avg_kl:.3f}    "
                  f"Batch size (mu, logvar): {len(mu), len(logvar)}")
            running_total_loss = 0.0
            running_recon_loss = 0.0
            running_kl_loss    = 0.0

    # Save checkpoints (off by default; enable with save_checkpoints above)
    if save_checkpoints:
        model_checkpoint = os.path.join(checkpoint_dir, f"vae_epoch{epoch}.pth")
        optimizer_checkpoint = os.path.join(checkpoint_dir, f"optimizer_epoch{epoch}.pth")
        torch.save(model.state_dict(), model_checkpoint)
        torch.save(optimizer.state_dict(), optimizer_checkpoint)
        print(f"Saved checkpoint at epoch {epoch}")

print("Training finished.")

In [None]:
# Save mu, logvar to csv
def save_latent_data(data, filename):
    """Write per-sample latent statistics to a CSV file.

    Args:
        data: List of dicts with keys "label", "mu" and "logvar", where "mu"
            and "logvar" hold numpy arrays. May be empty, in which case a CSV
            containing only the header row is written.
        filename: Destination path of the CSV file.
    """
    df = pd.DataFrame(data)

    if df.empty:
        # No rows means pandas created no columns, so indexing df['mu'] would
        # raise KeyError (e.g. when training was skipped after a resume).
        df = pd.DataFrame(columns=["label", "mu", "logvar"])
    else:
        # array2string wraps at 75 characters by default, which would embed
        # newlines in the CSV cells; a very large width keeps each vector on
        # one line. threshold=np.inf disables the "..." summarisation.
        no_wrap_width = 1_000_000_000

        def _vector_to_text(vector):
            return np.array2string(
                vector, separator=',', threshold=np.inf, max_line_width=no_wrap_width
            )

        # Convert numpy arrays to string representation for CSV
        for column in ("mu", "logvar"):
            df[column] = df[column].apply(_vector_to_text)

    df.to_csv(filename, index=False)
    print(f"Saved latent variables to {filename}")

# Persist the latent statistics gathered during the final training epoch.
save_latent_data(data=train_latent_data, filename="train_latent_variables.csv")


In [None]:
# Reconstruct the first test batch only and save the result as one image grid.
model.eval()
with torch.no_grad():
    for batch_idx, (imgs, _) in enumerate(test_dataloader, start=1):
        imgs = imgs.to(device)
        # Reconstructions have the input's shape: (B, 3, 64, 64), with B <= batch_size
        sample_imgs, _, _ = model(imgs)
        # One grid image, 8 reconstructions per row
        save_image(sample_imgs.cpu(), "sample_epoch_test.png", nrow=8, normalize=True)
        break  # a single batch is enough for a visual check

In [None]:
# NOTE: An earlier, Kaggle-specific copy of this notebook used to be kept here
# inside a triple-quoted string. It was never executed and had gone stale, so
# it was removed; recover it from version control if needed. For reference,
# it differed from the live cells above as follows:
#   * VocalPortraitDataset listed <person>/<English|Urdu>/<video_id>/<image>
#     explicitly and returned the sample index instead of a path label.
#   * Data and checkpoint locations were hard-coded /kaggle/input/... paths.
#   * start_epoch was forced to 9 while num_epochs was 5, so its training
#     loop could not run.
#   * After every epoch it saved vae_epoch{epoch}.pth and wrote
#     sample_epoch{epoch}.png from 64 latent vectors drawn from N(0, I) and
#     passed through model.decode.