In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from normalizing_flows.src.data_pipeline.dataset import CelebADataset, RandomSubsetDataset
from normalizing_flows.src.realnvp import realnvp_flow
from normalizing_flows.src.callbacks import EarlyStopping, ModelCheckpoint


# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

## Load Data

In [None]:
# Preprocessing pipeline: resize every image to 32x32, convert it to a tensor,
# then shift/scale each of the three channels with mean 0.5 and std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
preprocessing_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),  # Normalize to [-1, 1]
]
transform = transforms.Compose(preprocessing_steps)

# CelebA training split, read from (and downloaded to) ../data.
# If you have trouble downloading the images, download them manually and move
# the zip file to ../data/celeba/
dataset = CelebADataset(root='../data', split='train', download=True, transform=transform)

# Work on a random subset of 30000 items instead of the whole split.
dataset_subset = RandomSubsetDataset(dataset, subset_size=30000)

# Shuffled mini-batches of 32 drawn from the subset, loaded by 4 worker
# processes; batching is delegated to the dataset's own collate function.
loader_options = dict(
    batch_size=32,
    shuffle=True,
    num_workers=4,
    collate_fn=dataset.collate_fn_skip_errors,
)
dataloader = DataLoader(dataset_subset, **loader_options)

# Summary. Note that "Dataset size" is the length of the full dataset, not of
# the subset the dataloader actually iterates over.
print(f"Dataset size: {len(dataset)}")
print(f"Number of batches: {len(dataloader)}")
probe_batch = next(iter(dataloader))
print(f"Batch shape: {probe_batch[0].shape}")

In [None]:
def show_images(images, nrow=8):
    """Display a batch of images as a single grid.

    The batch is detached from autograd, moved to the CPU, mapped back from
    the normalized range with ``x * 0.5 + 0.5`` and clamped to ``[0, 1]``.
    It is then tiled with ``torchvision.utils.make_grid`` and drawn with
    matplotlib on a 15x15 inch figure without axes.

    Args:
        images: Tensor holding the image batch. Assumed to be laid out as
            (N, C, H, W), since the grid is permuted with ``(1, 2, 0)`` to put
            channels last before plotting -- TODO confirm against callers.
        nrow: Forwarded unchanged to ``make_grid`` as its ``nrow`` argument.
    """
    # detach() lets tensors that still track gradients be plotted as well.
    images = images.detach().cpu()
    # Denormalize. The clamp keeps values inside the range imshow accepts for
    # float RGB data; model samples are not guaranteed to stay within it and
    # would otherwise be clipped by matplotlib with a warning.
    images = (images * 0.5 + 0.5).clamp(0.0, 1.0)
    grid = torchvision.utils.make_grid(images, nrow=nrow)
    plt.figure(figsize=(15, 15))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis('off')
    plt.show()


# Fetch one batch from the loader, keep element 0 of it (the images shown
# below) limited to at most 32 entries, and display them as a grid.
preview_batch = next(iter(dataloader))
sample_batch = preview_batch[0][:32]
show_images(sample_batch)

## Create Model

In [None]:
from normalizing_flows.src.realnvp import realnvp_flow

In [None]:
# Round-trip check on a throwaway RealNVP block: push the sample batch through
# the block, take element 0 of what it returns, and feed that to inverse().
temp = realnvp_flow.RealNVP(3, 32, n_hidden_layers=1)
# Nothing is trained here, so skip autograd tracking: this avoids building a
# graph for a scratch computation and leaves `foo` as a plain tensor that can
# be plotted or converted to NumPy directly.
with torch.no_grad():
    foo = temp.inverse(temp(sample_batch)[0])

In [None]:
# Show the shape of `foo`, the forward-then-inverse output computed in the previous cell.
foo.shape

In [None]:
# Display the sample batch as an image grid.
# NOTE(review): this plots the original batch again (same call as in the data
# section), not the round-trip output `foo` -- confirm which one was intended.
show_images(sample_batch)

In [None]:
# Hyper-parameters of the flow that gets trained below.
flow_config = dict(
    in_channels=3,  # RGB images
    height=32,
    width=32,
    hidden_channels=64,
    n_hidden_layers=2,
    n_coupling_layers=4,
)

# Build the flow, place it on the selected device and print its structure.
model = realnvp_flow.RealNVPFlow(**flow_config).to(device)
print(model)

In [None]:
# Training length and checkpoint location.
n_epochs = 10
save_dir = 'checkpoints'

# Adam over all model parameters with a learning rate of 2e-4.
opt = torch.optim.Adam(model.parameters(), lr=2e-4)

# Plateau scheduler: multiplies the learning rate by 0.5 when the value passed
# to step() stops improving (patience 4, absolute threshold 0.001).
plateau_options = dict(factor=0.5, patience=4, threshold=0.001, threshold_mode='abs')
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, **plateau_options)

# Project callbacks: early stopping (patience 8, threshold 0.001) and a
# checkpoint writer configured to keep only the best model; the file name
# template embeds the epoch and the score.
early_stopping = EarlyStopping(patience=8, threshold=0.001)
model_checkpoint = ModelCheckpoint(
    save_dir=save_dir,
    filename='realnvp_{epoch:03d}_{score:.3f}.pt',
    save_best_only=True,
)

In [None]:
# Train the flow by minimizing the negative mean log-probability of each batch.
# After every epoch the average loss drives the LR scheduler, the checkpoint
# callback and the early-stopping callback.
model = model.train()
for ep in range(n_epochs):
    loss_sum = 0.0
    # The progress bar label uses the 0-based epoch index, as does the
    # checkpoint callback; the summary line printed below is 1-based.
    for x, _ in tqdm(dataloader, total=len(dataloader), desc=f"Epoch {ep:02d}"):
        x = x.to(device)
        opt.zero_grad()
        loss = -model.log_prob(x).mean()
        loss.backward()
        opt.step()
        loss_sum += loss.item()  # item() already yields a detached Python float

    loss_avg = loss_sum / len(dataloader)
    # Read the learning rate used during this epoch straight from the
    # optimizer (before the scheduler may lower it). This works on every
    # PyTorch release, whereas ReduceLROnPlateau.get_last_lr() is not usable
    # before the first scheduler.step() on some older ones.
    lr = opt.param_groups[0]['lr']
    scheduler.step(loss_avg)
    model_checkpoint.save(model, score=loss_avg, epoch=ep)

    print(f"Epoch {ep+1}/{n_epochs}, loss: {loss_avg:.4f}, lr: {lr}")

    if early_stopping(loss_avg):
        print('EarlyStopping activated. Ending training now.')
        break

# Reload the best weights. os.listdir() returns entries in arbitrary order and
# may include unrelated files, so only '.pt' files (the extension used in the
# checkpoint file name template) are considered and the most recently written
# one is chosen. This presumes the checkpoint callback, configured with
# save_best_only=True, writes a file only when the score improves -- verify
# against ModelCheckpoint.
checkpoint_paths = [
    os.path.join(save_dir, name)
    for name in os.listdir(save_dir)
    if name.endswith('.pt')
]
if not checkpoint_paths:
    raise FileNotFoundError(f"No checkpoint files found in {save_dir!r}.")
best_path = max(checkpoint_paths, key=os.path.getmtime)
print(f"Loading best model from checkpoint: {best_path}.")
model_checkpoint.load(model, best_path)

In [None]:
# Put the model in evaluation mode, draw a batch of samples from it with
# gradient tracking disabled, and display them as a grid.
n_generated = 32
model = model.eval()
with torch.no_grad():
    samples = model.sample(n_generated)
show_images(samples)

In [None]:
def visualize_latent_space(model, dataloader, device, n_samples=1000):
    """Scatter-plot the first two latent coordinates of encoded samples.

    Batches are read from ``dataloader`` and passed through ``model`` with
    gradient tracking disabled until at least ``n_samples`` latents have been
    gathered or the loader is exhausted. Each latent is flattened to one row
    and the first two columns are plotted against each other.

    Args:
        model: Callable model returning a pair whose first element is the
            latent tensor for the batch. It is switched to evaluation mode
            and left in that mode.
        dataloader: Iterable yielding ``(x, label)`` pairs; the second element
            is ignored.
        device: Device each batch is moved to before it is encoded.
        n_samples: Maximum number of latents to plot. Fewer are plotted when
            the loader provides fewer.

    Raises:
        RuntimeError: If no latent could be collected (empty loader or
            ``n_samples <= 0``).
    """
    model = model.eval()
    zs = []
    n_collected = 0
    with torch.no_grad():
        for x, _ in dataloader:
            # Count what was actually gathered instead of assuming every batch
            # holds dataloader.batch_size items: batches can be shorter, and
            # batch_size may be None for loaders built from a batch sampler.
            if n_collected >= n_samples:
                break
            x = x.to(device)
            z, _ = model(x)
            zs.append(z.cpu())
            n_collected += z.shape[0]

    if not zs:
        raise RuntimeError("No samples were collected for the latent space visualization.")

    zs = torch.cat(zs, dim=0)[:n_samples]
    # Flatten per sample using the real row count, so having fewer than
    # n_samples latents no longer breaks the reshape.
    zs = zs.reshape(zs.shape[0], -1).numpy()

    # Plot first two dimensions
    plt.figure(figsize=(10, 10))
    plt.scatter(zs[:, 0], zs[:, 1], alpha=0.5)
    plt.title('Latent Space Visualization')
    plt.xlabel('z1')
    plt.ylabel('z2')
    plt.show()

In [None]:
# Plot the latent space of the trained model for data from the training loader,
# using the function's default number of samples.
visualize_latent_space(model=model, dataloader=dataloader, device=device)