In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchdiffeq import odeint
import matplotlib.pyplot as plt
import numpy as np
import time

In [None]:
# Function to generate a mixture of Gaussians
def generate_mixture_of_gaussians(n_samples, n_components, dim, std=0.1, seed=42):
    """Draw samples from an isotropic Gaussian mixture with random centers.

    Centers are drawn uniformly from [-1, 1) per coordinate; every sample picks
    one component uniformly at random and adds Gaussian noise of scale ``std``.

    Args:
        n_samples: Number of points to draw.
        n_components: Number of mixture components.
        dim: Dimensionality of each point.
        std: Standard deviation of every component.
        seed: Seed of the private random generator. The default (42) yields
            the same centers and samples as seeding the global NumPy
            generator with 42 did.

    Returns:
        Tuple ``(samples, centers)`` with shapes ``(n_samples, dim)`` and
        ``(n_components, dim)``.
    """
    # A private generator keeps the call reproducible without resetting the
    # global NumPy random state for the rest of the notebook.
    rng = np.random.RandomState(seed)
    centers = rng.uniform(-1, 1, (n_components, dim))
    samples = []
    # The draws stay interleaved (component, then noise) so the random stream,
    # and therefore the generated data, is unchanged.
    for _ in range(n_samples):
        component = rng.randint(0, n_components)
        samples.append(rng.normal(centers[component], std, dim))
    if not samples:
        # Keep the result two-dimensional so column indexing still works.
        return np.empty((0, dim)), centers
    return np.array(samples), centers

# Plotting function
def plot_data(data, title):
    """Show a scatter plot of the first two columns of ``data``.

    Args:
        data: Array indexable as ``data[:, 0]`` and ``data[:, 1]``.
        title: Title placed above the axes.
    """
    plt.figure(figsize=(8, 6))
    horizontal = data[:, 0]
    vertical = data[:, 1]
    plt.scatter(horizontal, vertical, s=5, alpha=0.7)
    plt.title(title)
    plt.xlabel('x1')
    plt.ylabel('x2')
    plt.grid(True)
    plt.show()

# Plot images in a grid
def plot_images_grid(images, title, n=10):
    """Show up to ``n`` images side by side in a single figure.

    Args:
        images: Sequence of arrays; each ``images[i]`` is transposed with
            ``transpose(1, 2, 0)`` before display, so it must be 3-D.
        title: Figure-level title.
        n: Maximum number of images to draw. If fewer are available, only
            those are drawn instead of raising ``IndexError``.
    """
    n_shown = min(n, len(images))
    plt.figure(figsize=(15, 5))
    for i in range(n_shown):
        plt.subplot(2, n_shown, i + 1)
        # ``* 0.5 + 0.5`` maps values from [-1, 1] to [0, 1]; clip guards
        # against values that fall slightly outside the displayable range.
        plt.imshow((images[i].transpose(1, 2, 0) * 0.5 + 0.5).clip(0, 1))
        plt.axis('off')
    plt.suptitle(title)
    plt.show()


In [None]:
### Real NVP Implementation for 2D Data ###
class AffineCouplingLayer2D(nn.Module):
    """Affine coupling layer for vector inputs.

    Coordinates where ``mask`` is 1 pass through unchanged and feed two MLPs
    that produce a log-scale ``s`` (bounded by ``tanh``) and a shift ``t`` for
    the remaining coordinates.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of the hidden layers in both MLPs.
        mask: Tensor broadcastable against the input; 1 marks pass-through
            coordinates, 0 marks transformed coordinates.
    """

    def __init__(self, input_dim, hidden_dim, mask):
        super(AffineCouplingLayer2D, self).__init__()
        # A non-persistent buffer follows the module across .to()/.cuda()
        # (no per-call device copy) while staying out of state_dict, so the
        # set of checkpoint keys is unchanged.
        self.register_buffer("mask", mask, persistent=False)
        self.scale_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Tanh()
        )
        self.translate_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def _scale_and_shift(self, kept):
        """Return ``(s, t, free)`` computed from the pass-through part."""
        free = 1 - self.mask
        s = self.scale_net(kept) * free
        t = self.translate_net(kept) * free
        return s, t, free

    def forward(self, x):
        """Map ``x`` to ``z``; returns ``(z, log_det_jacobian)`` per sample."""
        x1 = x * self.mask
        s, t, free = self._scale_and_shift(x1)
        z = x1 + free * (x * torch.exp(s) + t)
        log_det_jacobian = torch.sum(free * s, dim=1)
        return z, log_det_jacobian

    def inverse(self, z):
        """Invert ``forward``: recover ``x`` from ``z``."""
        z1 = z * self.mask
        s, t, free = self._scale_and_shift(z1)
        x = z1 + free * (z - t) * torch.exp(-s)
        return x

# Define the Real NVP Model for 2D data
class RealNVP2D(nn.Module):
    """Stack of affine coupling layers with alternating feature masks.

    Layers with an even index keep the odd-indexed features fixed; layers
    with an odd index keep the even-indexed features fixed.
    """

    def __init__(self, input_dim, hidden_dim, n_coupling_layers):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_coupling_layers = n_coupling_layers

        even_positions = [1 if j % 2 == 0 else 0 for j in range(input_dim)]
        self.masks = []
        for layer_idx in range(n_coupling_layers):
            pattern = torch.tensor(even_positions, dtype=torch.float32)
            flipped = layer_idx % 2 == 0
            self.masks.append(1 - pattern if flipped else pattern)

        layers = []
        for layer_mask in self.masks:
            layers.append(AffineCouplingLayer2D(input_dim, hidden_dim, layer_mask))
        self.coupling_layers = nn.ModuleList(layers)

    def forward(self, x):
        """Apply every coupling layer in order and sum their log-determinants."""
        total_log_det = 0
        for coupling in self.coupling_layers:
            x, layer_log_det = coupling(x)
            total_log_det = total_log_det + layer_log_det
        return x, total_log_det

    def inverse(self, z):
        """Undo ``forward`` by inverting the layers from last to first."""
        for idx in range(len(self.coupling_layers) - 1, -1, -1):
            z = self.coupling_layers[idx].inverse(z)
        return z

# Define a function to train the Real NVP model for 2D data
def train_real_nvp_2d(model, data, n_epochs=100, lr=1e-3):
    """Fit ``model`` to ``data`` by full-batch maximum likelihood with Adam.

    Args:
        model: Module whose call returns ``(z, log_det_jacobian)``.
        data: Tensor of shape ``(n_samples, n_features)``.
        n_epochs: Number of optimizer steps (one per epoch).
        lr: Adam learning rate.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    standard_normal = torch.distributions.Normal(0, 1)
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        latent, log_det = model(data)
        # Change of variables: log p(x) = log N(z; 0, I) + log|det dz/dx|
        log_likelihood = standard_normal.log_prob(latent).sum(dim=1) + log_det
        loss = -log_likelihood.mean()
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 20 != 0:
            continue
        print(f"Epoch [{epoch + 1}/{n_epochs}], Loss: {loss.item():.4f}")


In [None]:
### Real NVP Implementation for Image Data ###
class AffineCouplingLayer(nn.Module):
    """Affine coupling layer for image tensors.

    Positions where ``mask`` is 1 pass through unchanged and feed two small
    convolutional networks that produce a log-scale ``s`` (bounded by
    ``tanh``) and a shift ``t`` for the remaining positions.

    Args:
        num_channels: Number of image channels.
        hidden_channels: Channel width of the hidden convolutions.
        mask: Tensor broadcastable against the input; 1 marks pass-through
            positions, 0 marks the positions that are scaled and translated.
    """

    def __init__(self, num_channels, hidden_channels, mask):
        super(AffineCouplingLayer, self).__init__()
        # A non-persistent buffer follows the module across .to()/.cuda()
        # (no per-call device copy) while staying out of state_dict, so the
        # set of checkpoint keys is unchanged.
        self.register_buffer("mask", mask, persistent=False)
        self.scale_net = nn.Sequential(
            nn.Conv2d(num_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, num_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )
        self.translate_net = nn.Sequential(
            nn.Conv2d(num_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, num_channels, kernel_size=3, padding=1)
        )

    def _scale_and_shift(self, kept):
        """Return ``(s, t, free)`` computed from the pass-through part."""
        free = 1 - self.mask
        s = self.scale_net(kept) * free  # sigma network in the slides
        t = self.translate_net(kept) * free  # mu network in the slides
        return s, t, free

    def forward(self, x):
        """Map ``x`` to ``z``; returns ``(z, log_det_jacobian)`` per sample."""
        x1 = x * self.mask
        s, t, free = self._scale_and_shift(x1)
        z = x1 + free * (x * torch.exp(s) + t)
        log_det_jacobian = torch.sum(free * s, dim=[1, 2, 3])
        return z, log_det_jacobian

    def inverse(self, z):
        """Invert ``forward``: recover ``x`` from ``z``."""
        z1 = z * self.mask
        s, t, free = self._scale_and_shift(z1)
        x = z1 + free * (z - t) * torch.exp(-s)
        return x

# Define the Real NVP Model for images with the updated AffineCouplingLayer
class RealNVP(nn.Module):
    """Stack of affine coupling layers with alternating checkerboard masks.

    Args:
        num_channels: Number of image channels.
        hidden_channels: Channel width inside each coupling network.
        num_coupling_layers: Number of coupling layers to stack.
        image_size: Spatial size of the inputs, either an int (square images)
            or a ``(height, width)`` pair. Defaults to 32.
    """

    def __init__(self, num_channels, hidden_channels, num_coupling_layers, image_size=32):
        super(RealNVP, self).__init__()
        self.num_channels = num_channels
        self.hidden_channels = hidden_channels
        self.num_coupling_layers = num_coupling_layers

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size
        rows = torch.arange(height).view(-1, 1)
        cols = torch.arange(width).view(1, -1)

        self.masks = []
        for i in range(num_coupling_layers):
            # True checkerboard: a pixel is kept when (row + col + i) is even,
            # so consecutive layers use complementary patterns. The previous
            # row-slicing produced horizontal stripes, not a checkerboard.
            board = ((rows + cols + i) % 2 == 0).to(torch.float32)
            mask = board.expand(1, num_channels, height, width).clone()
            self.masks.append(mask)

        self.coupling_layers = nn.ModuleList([AffineCouplingLayer(num_channels, hidden_channels, self.masks[i])
                                              for i in range(num_coupling_layers)])

    def forward(self, x):
        """Apply every coupling layer in order; returns ``(z, log_det)``."""
        # Start from a tensor so the result is a per-sample tensor even when
        # the model has no coupling layers.
        log_det_jacobian = x.new_zeros(x.shape[0])
        for layer in self.coupling_layers:
            x, layer_log_det_jacobian = layer(x)
            log_det_jacobian = log_det_jacobian + layer_log_det_jacobian
        return x, log_det_jacobian

    def inverse(self, z):
        """Undo ``forward`` by inverting the layers from last to first."""
        for layer in reversed(self.coupling_layers):
            z = layer.inverse(z)
        return z

# Define a function to train the Real NVP model on CIFAR-10
def train_real_nvp(model, train_loader, n_epochs=10, lr=1e-4):
    """Train an image flow by mini-batch maximum likelihood with Adam.

    Args:
        model: Module whose call returns ``(z, log_det_jacobian)`` for a
            batch of images.
        train_loader: Iterable yielding ``(images, labels)`` batches; the
            labels are ignored.
        n_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Follow the model's own device rather than a notebook-level global, so
    # the function works regardless of which cells have been run.
    model_device = next(model.parameters()).device
    base_dist = torch.distributions.Normal(0, 1)
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0
        for x, _ in train_loader:
            x = x.to(model_device)
            optimizer.zero_grad()
            z, log_det_jacobian = model(x)
            # Compute the loss as the negative log likelihood
            log_prob_z = torch.sum(base_dist.log_prob(z), dim=[1, 2, 3])
            loss = -torch.mean(log_prob_z + log_det_jacobian)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # Count batches instead of calling len(): works for loaders without
        # __len__ and avoids dividing by zero on an empty loader.
        avg_loss = total_loss / n_batches if n_batches else float('nan')
        print(f"Epoch [{epoch + 1}/{n_epochs}], Loss: {avg_loss:.4f}")

In [None]:
# Function to train FFJORD on 2D data
def train_ffjord_2d(model, data, n_epochs=100, lr=1e-3):
    """Fit a continuous normalizing flow to 2D data by maximum likelihood.

    The log-density change along the flow is the time integral of the trace
    of the dynamics' Jacobian. It is integrated together with the state and
    added to the base log-probability; without it the loss only rewards
    pushing every point towards the origin.

    Args:
        model: Flow exposing its dynamics as ``model.ode_func(t, y)``; it is
            assumed that ``model(x, times)`` integrates exactly these
            dynamics. A model without ``ode_func`` is trained with the
            previous objective (no log-determinant term) and a warning.
        data: Tensor of shape ``(n_samples, n_features)``.
        n_epochs: Number of full-batch optimizer steps.
        lr: Adam learning rate.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    integration_times = torch.tensor([0.0, 1.0]).to(data.device)
    base_dist = torch.distributions.Normal(0, 1)

    dynamics = getattr(model, "ode_func", None)
    if dynamics is None:
        import warnings
        warnings.warn("model has no `ode_func`; training without the log-determinant "
                      "term, so the loss is not a log-likelihood.")

    def augmented_dynamics(t, state):
        """Return d(state)/dt for the pair (y, accumulated Jacobian trace)."""
        y = state[0]
        with torch.enable_grad():
            if not y.requires_grad:
                # Detach first so the caller's data tensor is never mutated.
                y = y.detach().requires_grad_(True)
            dy = dynamics(t, y)
            # Exact trace, one backward pass per feature (cheap in 2D).
            divergence = torch.zeros_like(y[:, 0])
            for i in range(y.shape[1]):
                grad_i = torch.autograd.grad(dy[:, i].sum(), y, create_graph=True)[0]
                divergence = divergence + grad_i[:, i]
        return dy, divergence

    for epoch in range(n_epochs):
        optimizer.zero_grad()
        if dynamics is None:
            z = model(data, integration_times)[1]
            delta_log_p = 0.0
        else:
            trace_start = torch.zeros(data.shape[0], dtype=data.dtype, device=data.device)
            z_path, trace_path = odeint(augmented_dynamics, (data, trace_start), integration_times)
            z, delta_log_p = z_path[1], trace_path[1]
        # Negative log likelihood loss:
        # log p(x) = log N(z; 0, I) + integral_0^1 tr(df/dy) dt
        log_prob_z = torch.sum(base_dist.log_prob(z), dim=1)
        loss = -torch.mean(log_prob_z + delta_log_p)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 20 == 0:
            print(f"Epoch [{epoch + 1}/{n_epochs}], Loss: {loss.item():.4f}")

# Dynamics network for FFJORD on 2D data
class ODEFunc2D(nn.Module):
    """Time-independent velocity field given by a small tanh MLP."""

    def __init__(self, input_dim):
        super().__init__()
        hidden_width = 64
        stages = [
            nn.Linear(input_dim, hidden_width),
            nn.Tanh(),
            nn.Linear(hidden_width, hidden_width),
            nn.Tanh(),
            nn.Linear(hidden_width, input_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, t, y):
        """Return dy/dt at state ``y``; the time ``t`` is ignored."""
        velocity = self.net(y)
        return velocity

class FFJORD2D(nn.Module):
    """Continuous flow on vectors: integrates an ``ODEFunc2D`` velocity field."""

    def __init__(self, input_dim):
        super().__init__()
        self.ode_func = ODEFunc2D(input_dim)

    def forward(self, x, integration_times):
        """Return the ``odeint`` solution started at ``x`` for ``integration_times``."""
        trajectory = odeint(self.ode_func, x, integration_times)
        return trajectory



### FFJORD Implementation for images ###
class ODEFuncImage(nn.Module):
    """Time-independent velocity field given by three 3x3 convolutions."""

    def __init__(self, input_dim):
        super().__init__()
        hidden_channels = 64
        conv_options = dict(kernel_size=3, padding=1)
        stages = [
            nn.Conv2d(input_dim, hidden_channels, **conv_options),
            nn.Tanh(),
            nn.Conv2d(hidden_channels, hidden_channels, **conv_options),
            nn.Tanh(),
            nn.Conv2d(hidden_channels, input_dim, **conv_options),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, t, y):
        """Return dy/dt at state ``y``; the time ``t`` is ignored."""
        velocity = self.net(y)
        return velocity

class FFJORDImage(nn.Module):
    """Continuous flow on images: integrates an ``ODEFuncImage`` velocity field."""

    def __init__(self, num_channels):
        super().__init__()
        self.ode_func = ODEFuncImage(num_channels)

    def forward(self, x, integration_times):
        """Return the ``odeint`` solution started at ``x`` for ``integration_times``."""
        trajectory = odeint(self.ode_func, x, integration_times)
        return trajectory

# Define a function to train FFJORD on CIFAR-10
def train_ffjord_images(model, train_loader, n_epochs=10, lr=1e-4):
    """Train an image continuous normalizing flow by maximum likelihood.

    The log-density change along the flow is the time integral of the trace
    of the dynamics' Jacobian. The trace is estimated with Hutchinson's
    estimator (one Gaussian probe per batch) and integrated together with
    the state; without it the loss only rewards shrinking every image
    towards zero.

    Args:
        model: Flow exposing its dynamics as ``model.ode_func(t, y)``; it is
            assumed that ``model(x, times)`` integrates exactly these
            dynamics. A model without ``ode_func`` is trained with the
            previous objective (no log-determinant term) and a warning.
        train_loader: Iterable yielding ``(images, labels)`` batches; the
            labels are ignored.
        n_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Follow the model's own device rather than a notebook-level global.
    model_device = next(model.parameters()).device
    integration_times = torch.tensor([0.0, 1.0]).to(model_device)  # ODE integration from 0 to 1
    base_dist = torch.distributions.Normal(0, 1)

    dynamics = getattr(model, "ode_func", None)
    if dynamics is None:
        import warnings
        warnings.warn("model has no `ode_func`; training without the log-determinant "
                      "term, so the loss is not a log-likelihood.")

    def flow_with_log_det(x):
        """Integrate ``x`` to ``z`` and return ``(z, estimated trace integral)``."""
        probe = torch.randn_like(x)

        def augmented_dynamics(t, state):
            y = state[0]
            with torch.enable_grad():
                if not y.requires_grad:
                    # Detach first so the batch tensor is never mutated.
                    y = y.detach().requires_grad_(True)
                dy = dynamics(t, y)
                # Hutchinson: E[probe^T J probe] = tr(J) for unit-variance probes.
                vjp = torch.autograd.grad(dy, y, grad_outputs=probe, create_graph=True)[0]
                divergence = (vjp * probe).flatten(start_dim=1).sum(dim=1)
            return dy, divergence

        trace_start = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
        z_path, trace_path = odeint(augmented_dynamics, (x, trace_start), integration_times)
        return z_path[1], trace_path[1]

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0
        for x, _ in train_loader:
            x = x.to(model_device)
            optimizer.zero_grad()
            if dynamics is None:
                z = model(x, integration_times)[1]
                delta_log_p = 0.0
            else:
                z, delta_log_p = flow_with_log_det(x)
            # Compute the loss as the negative log likelihood:
            # log p(x) = log N(z; 0, I) + integral_0^1 tr(df/dy) dt
            log_prob_z = torch.sum(base_dist.log_prob(z), dim=[1, 2, 3])
            loss = -torch.mean(log_prob_z + delta_log_p)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # Count batches instead of calling len(): works for loaders without
        # __len__ and avoids dividing by zero on an empty loader.
        avg_loss = total_loss / n_batches if n_batches else float('nan')
        print(f"Epoch [{epoch + 1}/{n_epochs}], Loss: {avg_loss:.4f}")



In [None]:

# Set device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# Later cells move the data and the models to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: displays the selected device as the cell output.
device

In [None]:
# Generate and visualize Gaussian mixture data
# 5000 two-dimensional points from a 5-component mixture; the returned
# centers are discarded.
n_samples = 5000
n_components = 5
dim = 2
data, _ = generate_mixture_of_gaussians(n_samples, n_components, dim)
# `data` is rebound from a NumPy array to a float32 tensor on `device`.
data = torch.tensor(data, dtype=torch.float32).to(device)

# Plotting needs a NumPy array on the host, hence .cpu().numpy().
plot_data(data.cpu().numpy(), 'Original Mixture of Gaussians')

In [None]:
# Train Real NVP on Gaussian mixture
print("Training Real NVP on Gaussian Mixture Data...")
# 6 coupling layers with hidden width 128.
real_nvp_2d = RealNVP2D(dim, 128, 6).to(device)
start_time = time.time()
# NOTE(review): lr=1e-5 is 100x smaller than the function's default (1e-3);
# with only 100 full-batch steps the model may barely move - confirm this is intended.
train_real_nvp_2d(real_nvp_2d, data, n_epochs=100, lr=1e-5)
end_time = time.time()
print(f"Real NVP Training Time: {end_time - start_time:.2f} seconds")

# Push the training data through the trained flow and plot the latent points;
# the log-determinant output is not needed here.
z_nvp_2d, _ = real_nvp_2d(data)
plot_data(z_nvp_2d.cpu().detach().numpy(), 'Transformed Data (Real NVP)')

In [None]:
# Train FFJORD on the Gaussian mixture and time the run.
print("Training FFJORD on Gaussian Mixture Data...")
ffjord_2d = FFJORD2D(dim).to(device)
start_time = time.time()
train_ffjord_2d(ffjord_2d, data, n_epochs=100, lr=1e-3)
end_time = time.time()
print(f"FFJORD Training Time: {end_time - start_time:.2f} seconds")

# Integrate from t=0 to t=1; index 1 of the result selects the state at the
# final time point.
integration_times = torch.tensor([0.0, 1.0]).to(device)
z_ffjord_2d = ffjord_2d(data, integration_times)[1]
plot_data(z_ffjord_2d.cpu().detach().numpy(), 'Transformed Data (FFJORD)')

In [None]:
### Real NVP and FFJORD for CIFAR-10 Images ###

# Load CIFAR-10 dataset
# ToTensor followed by Normalize(mean=0.5, std=0.5) per channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])
# Downloads the training split into ./data if it is not already present.
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
# Shuffled mini-batches of 32 images.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
num_channels = 3  # CIFAR-10 has 3 color channels


In [None]:
# Define and train Real NVP for images
print("Training Real NVP on CIFAR-10 Data...")
# 8 coupling layers with 64 hidden channels each.
real_nvp_model_image = RealNVP(num_channels, 64, 8).to(device)
start_time = time.time()
train_real_nvp(real_nvp_model_image, train_loader, n_epochs=10, lr=1e-4)
end_time = time.time()
print(f"Real NVP (Image) Training Time: {end_time - start_time:.2f} seconds")

In [None]:
# Test and visualize transformations
real_nvp_model_image.eval()
# Note: despite its name, `test_loader` draws from `train_dataset`, so these
# are training images, 10 per batch.
test_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
with torch.no_grad():
    # Take a single random batch; the labels are ignored.
    x, _ = next(iter(test_loader))
    x = x.to(device)
    # Round trip: images -> latent -> images, to check the inverse visually.
    z_nvp_image, _ = real_nvp_model_image(x)
    x_reconstructed_nvp = real_nvp_model_image.inverse(z_nvp_image).cpu().numpy()

plot_images_grid(x.cpu().numpy(), 'Original CIFAR-10 Images (Real NVP)')
plot_images_grid(x_reconstructed_nvp, 'Reconstructed CIFAR-10 Images (Real NVP)')
# Free memory
# Deleting the model means this cell cannot be re-run without first re-running
# the cell that creates `real_nvp_model_image`.
del real_nvp_model_image
torch.cuda.empty_cache()

In [None]:
# Define and train FFJORD for images
print("Training FFJORD on CIFAR-10 Data...")
ffjord_model_image = FFJORDImage(num_channels).to(device)
start_time = time.time()
# Same loader, epoch count and learning rate as the Real NVP image run.
train_ffjord_images(ffjord_model_image, train_loader, n_epochs=10, lr=1e-4)
end_time = time.time()
print(f"FFJORD (Image) Training Time: {end_time - start_time:.2f} seconds")

In [None]:
# Test and visualize transformations
ffjord_model_image.eval()
# Define the integration interval here instead of relying on the value left
# behind by the 2D FFJORD cell, so this cell also works when that cell was
# skipped.
integration_times = torch.tensor([0.0, 1.0]).to(device)
with torch.no_grad():
    # `x` is the image batch drawn in the Real NVP visualization cell, so both
    # models are shown on the same images. Index 1 selects the state at t=1.
    z_ffjord_image = ffjord_model_image(x, integration_times)[1]

plot_images_grid(x.cpu().numpy(), 'Original CIFAR-10 Images (FFJORD)')
plot_images_grid(z_ffjord_image.cpu().numpy(), 'Transformed CIFAR-10 Images (FFJORD)')
