In [None]:
from dataclasses import dataclass
from tqdm.auto import tqdm
from PIL import Image
import torch
import torchvision

# TODO adapt these parameters such that they work for your setup
@dataclass
class TrainingConfig:
    """Hyper-parameters read by the data loading, model and training cells.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field. That makes per-instance overrides possible, e.g.
    ``TrainingConfig(num_epochs=1)``, and makes the values appear in ``repr()``.
    Without annotations the decorator sees no fields and the values are only
    plain class attributes.
    """

    x_size: int = 28  # the generated x resolution
    num_channels: int = 1  # the number of channels in the generated x
    train_batch_size: int = 10
    eval_batch_size: int = 10  # how many xs to sample during evaluation
    num_epochs: int = 10
    learning_rate: float = 1e-4
    output_dir: str = "samples"

# Single configuration object shared by the cells below.
config = TrainingConfig()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import torchvision.transforms as transforms
# MNIST training split (downloaded into datasets/mnist on first use); each
# image is converted to a tensor by transforms.ToTensor().
mnist_dataset = torchvision.datasets.MNIST(
    root='datasets/mnist',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# Shuffled loader used for training.
train_dataloader = torch.utils.data.DataLoader(
    mnist_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=1,
)

# MNIST test split, same preprocessing.
mnist_testdataset = torchvision.datasets.MNIST(
    root='datasets/mnist',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
# Evaluation loader: fixed order, same batch size as the training loader.
test_dataloader = torch.utils.data.DataLoader(
    mnist_testdataset,
    batch_size=config.train_batch_size,
    shuffle=False,
    num_workers=1,
)

In [None]:
def get_alpha(t):
    """Linear schedule: returns ``1 - 0.9999 * t``.

    The value is 1 at ``t = 0`` and shrinks linearly as ``t`` grows. Works for
    plain numbers and for tensors alike.

    NOTE: a later cell defines another ``get_alpha`` (cosine based). Whichever
    of the two cells was executed last is the one the rest of the notebook uses.
    """
    decay = 0.9999 * t
    return 1 - decay

In [None]:
import math
def get_alpha(t, start=0.2, end=1, tau=1, clip_min=1e-9):
    """Cosine-based schedule mapping a time ``t`` to a signal level ``alpha``.

    ``t`` is rescaled linearly into ``[start, end]``, passed through
    ``cos(. * pi / 2) ** (2 * tau)`` and normalised with the values of that
    curve at ``start`` and ``end``. With the defaults the result is 1 at
    ``t = 0`` and decreases towards ``clip_min`` at ``t = 1``.

    Args:
        t: time value(s); a tensor or a plain Python number.
        start, end: interval of the cosine curve that is used.
        tau: exponent factor applied to the cosine.
        clip_min: lower bound of the returned value (the upper bound is 1).

    Returns:
        A tensor with the same shape as ``t`` (0-dim for a scalar input),
        clipped to ``[clip_min, 1]``.
    """
    # torch.cos only accepts tensors; coercing here lets callers pass plain
    # floats/ints too. Tensor inputs are passed through without a copy.
    t = torch.as_tensor(t)
    v_start = math.cos(start * math.pi / 2) ** (2 * tau)
    v_end = math.cos(end * math.pi / 2) ** (2 * tau)
    output = torch.cos((t * (end - start) + start) * torch.pi / 2) ** (2 * tau)
    output = (v_end - output) / (v_end - v_start)
    return torch.clip(output, clip_min, 1.)

In [None]:
def forward_diffusion(clean_x, noise, t):
    """Mix clean samples with noise according to the schedule value at ``t``.

    Returns ``clean_x * sqrt(alpha) + noise * sqrt(1 - alpha)`` where
    ``alpha = get_alpha(t)`` is moved to the device of ``clean_x``.

    Args:
        clean_x: the clean samples.
        noise: noise to mix in (combined elementwise with ``clean_x``).
        t: timestep(s) handed to ``get_alpha``.
    """
    alpha = get_alpha(t).to(clean_x.device)

    # Append one singleton axis for every non-leading dimension of clean_x so
    # that alpha broadcasts against the samples.
    trailing_axes = (None,) * (clean_x.dim() - 1)
    alpha = alpha[(...,) + trailing_axes]

    signal_scale = torch.sqrt(alpha)
    noise_scale = torch.sqrt(1 - alpha)
    return clean_x * signal_scale + noise * noise_scale

In [None]:
def save_and_show(batch, name, nrow=1):
    """Tile ``batch`` into one image grid, write it to ``name`` and display it.

    Args:
        batch: batch of image tensors handed to ``torchvision.utils.make_grid``.
        name: output file path; the file is written and then re-opened for display.
        nrow: forwarded to ``make_grid`` as its ``nrow`` argument.
    """
    grid = torchvision.utils.make_grid(batch, nrow=nrow)
    torchvision.utils.save_image(grid, name)
    saved_image = Image.open(name)
    display(saved_image)

# Visualise the forward process on one training batch: the same noise tensor is
# mixed in at t = 0.0, 0.1, ..., 1.0 and all results are tiled into one image.
sample_batch, sample_y = next(iter(train_dataloader))
noise = torch.randn_like(sample_batch)
noise_levels = [
    forward_diffusion(sample_batch, noise, torch.tensor(level / 10))
    for level in range(11)
]

# One grid row per noise level (nrow = number of samples in the batch).
save_and_show(
    torch.cat(noise_levels),
    f'forward_diffusion.png',
    nrow=noise_levels[0].shape[0],
)

In [None]:
from diffusers import UNet2DModel

# Noise-prediction network. Resolution and channel counts are taken from
# `config` so that changing the config is enough to switch to e.g. RGB data.
model = UNet2DModel(
    sample_size=config.x_size,  # the target x resolution
    in_channels=config.num_channels,  # the number of input channels (1 for grayscale, 3 for RGB xs)
    out_channels=config.num_channels,  # the predicted noise has as many channels as the input
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128),  # the number of output channels for each UNet block
    # All blocks below are the plain ResNet variants; none of them uses
    # self-attention (that would be "AttnDownBlock2D" / "AttnUpBlock2D").
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",  # a regular ResNet downsampling block
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "UpBlock2D",  # a regular ResNet upsampling block
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)

In [None]:
import torch.nn.functional as F
import torch

# Optimizer for the diffusion model defined above.
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

def train_loop(config, model, forward_diffusion, optimizer, train_dataloader, device):
    """Train ``model`` to predict the noise that ``forward_diffusion`` mixed in.

    For every batch a timestep ``t`` in (0, 1] is drawn per sample, the batch is
    noised with ``forward_diffusion`` and the model is optimised with an MSE
    loss between its output and the noise that was added.

    At a fixed set of global steps a preview image is written with
    ``save_and_show``: a fixed batch is noised at t = 0.0, 0.1, ..., 1.0 with a
    fixed noise tensor and the clean input is re-estimated from the model's
    noise prediction.

    Relies on the notebook-level ``get_alpha``, ``save_and_show`` and ``tqdm``.

    Args:
        config: object providing ``num_epochs``.
        model: network called as ``model(noisy_x, t, return_dict=False)[0]``.
        forward_diffusion: callable ``(clean_x, noise, t) -> noisy_x``.
        optimizer: optimizer over the parameters of ``model``.
        train_dataloader: yields batches whose first element is the images.
        device: device to train on.
    """
    model.to(device)
    global_step = 0
    # Global steps at which a reconstruction preview is written.
    preview_steps = [0, 100, 200, 1000, 5000, 10000, 20000]
    # Fixed batch and fixed noise so that previews are comparable over time.
    sample_batch = next(iter(train_dataloader))[0].to(device)
    test_noise = torch.randn_like(sample_batch)

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader))
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            if epoch * len(train_dataloader) + step in preview_steps:
                sample_batch_reconstructed = []
                # Preview only: nothing is back-propagated, so skip building
                # autograd graphs for these eleven forward passes.
                with torch.no_grad():
                    for noise_level in range(11):
                        t = torch.tensor(noise_level / 10)
                        alpha = get_alpha(t).to(device)
                        noisy_xs = forward_diffusion(sample_batch, test_noise, t)
                        noise_pred = model(noisy_xs, t.to(device), return_dict=False)[0]
                        # Invert the forward mix to estimate the clean input.
                        sample_batch_reconstructed.append(
                            (noisy_xs - torch.sqrt(1 - alpha) * noise_pred) / torch.sqrt(alpha)
                        )

                save_and_show(torch.cat(sample_batch_reconstructed, 0), f'reconstruction_{epoch}_{step}.png', nrow=sample_batch.shape[0])

            clean_xs = batch[0].to(device)
            # Sample noise to add to the xs
            noise = torch.randn(clean_xs.shape).to(device)
            bs = clean_xs.shape[0]

            # Sample a random timestep for each x; 1 - rand lies in (0, 1]
            t = torch.abs(1.0 - torch.rand(bs)).to(device)

            # Add noise to the clean xs according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_xs = forward_diffusion(clean_xs, noise, t)

            # Predict the noise residual
            noise_pred = model(noisy_xs, t, return_dict=False)[0]
            loss = F.mse_loss(noise_pred, noise)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "step": global_step}
            progress_bar.set_postfix(**logs)
            global_step += 1

        # Finish this epoch's bar before the next epoch opens a new one.
        progress_bar.close()

In [None]:
# Run the training with the notebook-level config, model, optimizer and training data loader.
train_loop(config, model, forward_diffusion, optimizer, train_dataloader, device)

In [None]:
# Serialise the whole model object (not just its state_dict) to 'unet.pt' in the working directory.
# NOTE(review): the next cell loads 'peal_runs/unet.pt', not this file — confirm which path is intended.
torch.save(model, 'unet.pt')

In [None]:
# Load a serialised model object onto the CPU first, then move it to `device`.
# NOTE(review): this reads 'peal_runs/unet.pt' while the previous cell writes 'unet.pt' — confirm the intended path.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model = torch.load('peal_runs/unet.pt', map_location=torch.device('cpu')).to(device)

In [None]:
def get_next_x(model, current_x, current_timestep, next_timestep, classifier_criterion=None, gradient_scale = 2.0):
    """Take one deterministic (DDIM-style) step from ``current_timestep`` to ``next_timestep``.

    The model's noise prediction is used to estimate the clean sample, which is
    then re-noised to the level of ``next_timestep`` using the same predicted
    noise. ``next_timestep`` may be smaller (sampling) or larger (inversion)
    than ``current_timestep``.

    Args:
        model: network called as ``model(x, t, return_dict=False)[0]``.
        current_x: current samples.
        current_timestep: timestep tensor of ``current_x`` (passed to ``get_alpha``).
        next_timestep: timestep tensor to move to (passed to ``get_alpha``).
        classifier_criterion: optional callable ``x -> scalar loss``. It is
            expected to be a loss to *minimise* (e.g. cross-entropy, i.e. the
            negative log-likelihood of the target classes).
        gradient_scale: strength of the classifier guidance.

    Returns:
        The samples at ``next_timestep``.
    """
    # The denoiser is only evaluated here, so no autograd graph is needed.
    with torch.no_grad():
        current_noise_pred = model(current_x, current_timestep, return_dict=False)[0]

    # Schedule values, reshaped so that they broadcast over the sample axes.
    alpha_next = get_alpha(next_timestep).to(current_x.device)
    alpha_current = get_alpha(current_timestep).to(current_x.device)
    for _ in range(len(current_x.shape) - 1):
        alpha_next = alpha_next.unsqueeze(-1)
        alpha_current = alpha_current.unsqueeze(-1)

    # Classifier guidance.
    if classifier_criterion is not None:
        # enable_grad keeps this working even if the caller wraps sampling in
        # torch.no_grad(); autograd.grad returns only d(loss)/d(x) and does not
        # accumulate gradients in the classifier's parameters.
        with torch.enable_grad():
            x_in = current_x.detach().requires_grad_(True)
            loss = classifier_criterion(x_in)
            loss_grad = torch.autograd.grad(loss, x_in)[0]
        # Guided noise: eps_hat = eps - sqrt(1 - alpha) * grad log p(y | x).
        # The criterion is a loss (-log p), so its gradient enters with a PLUS
        # sign; subtracting it would push the samples away from the target.
        current_noise_pred = current_noise_pred + gradient_scale * torch.sqrt(1 - alpha_current) * loss_grad

    # Re-noise the estimated clean sample to the level of next_timestep.
    next_image = torch.sqrt(1 - alpha_next) * current_noise_pred
    next_image = next_image + torch.sqrt(alpha_next) * (current_x - torch.sqrt(1 - alpha_current) * current_noise_pred) / torch.sqrt(alpha_current)

    return next_image

In [None]:
def reverse_diffusion_ddim(model, noise, num_timesteps, classifier_criterion=None, gradient_scale = 2.0):
    """Generate samples from ``noise`` by stepping from t = 1 down to t = 0.

    Runs ``num_timesteps`` calls of ``get_next_x``, each moving from
    ``t / num_timesteps`` to ``(t - 1) / num_timesteps``. Intermediate results
    are collected at every tenth of the schedule and, together with the final
    result, written to 'z_to_x.png' via ``save_and_show``.

    Args:
        model: noise-prediction network handed to ``get_next_x``.
        noise: starting tensor; it is cloned, not modified.
        num_timesteps: number of steps to take.
        classifier_criterion: optional guidance loss, forwarded to ``get_next_x``.
        gradient_scale: guidance strength, forwarded to ``get_next_x``.

    Returns:
        The generated samples.
    """
    noisy_images_list = []
    current_x = torch.clone(noise)

    # Iterate over the bar itself so that it advances once per step and has
    # the correct total (t runs num_timesteps, ..., 1).
    progress_bar = tqdm(range(num_timesteps, 0, -1))
    for t in progress_bar:
        progress_bar.set_description(f"T {t}")
        # Per-sample timestep vector on the device/dtype of `noise`.
        timesteps = torch.ones([current_x.shape[0]], dtype=torch.float32).to(noise) * t
        current_timestep = timesteps / num_timesteps
        next_timestep = (timesteps - 1) / num_timesteps
        current_x = get_next_x(model, current_x, current_timestep, next_timestep, classifier_criterion=classifier_criterion, gradient_scale=gradient_scale)
        # Keep a snapshot every tenth of the schedule.
        if t % (num_timesteps / 10) == 0:
            noisy_images_list.append(current_x)

    noisy_images_list.append(current_x)
    save_and_show(torch.cat(noisy_images_list, 0), f'z_to_x.png', nrow=current_x.shape[0])
    # Return the generated images
    return current_x

In [None]:
# Unconditional sampling: start from Gaussian noise shaped like sample_batch and take 1000 reverse steps.
samples = reverse_diffusion_ddim(model, torch.randn_like(sample_batch).to(device), 1000)

In [None]:
def forward_diffusion_ddim(model, x, num_timesteps = 100):
    """Map samples ``x`` to a latent ``z`` by stepping from t = 0 up to t = 1.

    Uses the same deterministic update as sampling (``get_next_x``), but with
    increasing timesteps: each of the ``num_timesteps`` steps moves from
    ``t / num_timesteps`` to ``(t + 1) / num_timesteps``. Intermediate results
    are collected at every tenth of the schedule and, together with the final
    result, written to 'x_to_z.png' via ``save_and_show``.

    Args:
        model: noise-prediction network handed to ``get_next_x``.
        x: samples to encode; the tensor is cloned, not modified.
        num_timesteps: number of steps to take.

    Returns:
        The resulting latent ``z``.
    """
    noisy_xs_list = []
    current_z = torch.clone(x)

    # Perform the deterministic forward pass (inversion). Iterating over the
    # bar itself makes it advance once per step and gives it the right total.
    progress_bar = tqdm(range(num_timesteps))
    for t in progress_bar:
        progress_bar.set_description(f"T {t} / {num_timesteps}")
        # Per-sample timestep vector; .to(current_z) puts it on the device and
        # dtype of the samples, exactly as the sampling loop does.
        timesteps = torch.ones([current_z.shape[0]], dtype=torch.float32).to(current_z) * t
        current_timestep = timesteps / num_timesteps
        next_timestep = (timesteps + 1) / num_timesteps
        current_z = get_next_x(model, current_z, current_timestep, next_timestep)
        # Keep a snapshot every tenth of the schedule.
        if t % (num_timesteps / 10) == 0:
            noisy_xs_list.append(current_z)

    noisy_xs_list.append(current_z)
    save_and_show(torch.cat(noisy_xs_list, 0), f'x_to_z.png', nrow=current_z.shape[0])
    # Return the generated z
    return current_z

In [None]:
# Sanity check of the inversion: encode sample_batch to a latent with 1000 forward
# steps, decode it again with 1000 reverse steps, and print the mean absolute
# difference between the round-trip result and the original batch.
reconstruction = reverse_diffusion_ddim(model, forward_diffusion_ddim(model, sample_batch.to(device), 1000), 1000)
print(torch.mean(torch.abs(reconstruction.cpu() - sample_batch)))

In [None]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    """Small convolutional classifier that returns 10 unnormalised class scores.

    Layout: two 3x3 convolutions with ReLU, 2x2 max pooling, dropout, then two
    fully connected layers with ReLU and dropout in between.
    """

    def __init__(self):
        super().__init__()
        # 3x3 convolutions, stride 1: 1 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # 9216 = 64 channels * 12 * 12, the flattened size for a 28x28 input.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the class scores (logits) for a batch of images ``x``."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        return self.fc2(self.dropout2(hidden))


def train(classifier, device, train_loader, optimizer, epoch, criterion):
    """Run one training epoch of ``classifier`` over ``train_loader``.

    Args:
        classifier: the network to train.
        device: device the batches are moved to.
        train_loader: yields ``(data, target)`` batches.
        optimizer: optimizer over the parameters of ``classifier``.
        epoch: epoch index, used only for the progress description.
        criterion: loss called as ``criterion(output, target)``.
    """
    # Put the network that is actually being trained into training mode
    # (not the notebook-level diffusion `model`).
    classifier.train()
    progress_bar = tqdm(total=len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = classifier(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        progress_bar.set_description(f"Epoch: {epoch}, Batch: {batch_idx} / {len(train_loader)}, Loss: {loss.item()}")
        # Advance the bar; setting the description alone does not move it.
        progress_bar.update(1)
    progress_bar.close()


def test(classifier, device, test_loader, epoch):
    """Evaluate ``classifier`` on ``test_loader`` and print its accuracy.

    Args:
        classifier: the network to evaluate.
        device: device the batches are moved to.
        test_loader: yields ``(data, target)`` batches; ``test_loader.dataset``
            is used for the total sample count.
        epoch: epoch index, used only for the printed messages.
    """
    # Evaluation mode for the network being evaluated (disables its dropout),
    # not for the notebook-level diffusion `model`.
    classifier.eval()
    progress_bar = tqdm(total=len(test_loader))
    accuracy = torch.tensor(0.0).to(device)
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            prediction = classifier(data).argmax(-1)
            accuracy += torch.sum((prediction == target).float())
            progress_bar.set_description(f"Epoch: {epoch}, Batch: {batch_idx} / {len(test_loader)}")
            # Advance the bar; setting the description alone does not move it.
            progress_bar.update(1)
    progress_bar.close()

    print(f"Epoch: {epoch}, Accuracy: {accuracy / len(test_loader.dataset)}")


import os

# Train the MNIST classifier that is later used to guide the diffusion sampling.
classifier = Net().to(device)
# NOTE: this rebinds the notebook-level name `optimizer`, which an earlier cell
# created for the diffusion model; re-run that cell before training the
# diffusion model again.
optimizer = optim.Adam(classifier.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(12):
    train(classifier, device, train_dataloader, optimizer, epoch, criterion)
    test(classifier, device, test_dataloader, epoch)

# torch.save does not create missing directories; make sure the target exists
# so the trained weights are not lost at the very end.
os.makedirs("peal_runs", exist_ok=True)
torch.save(classifier.state_dict(), "peal_runs/mnist_cnn.pt")

In [None]:
def make_classifier_criterion(classifier, target, criterion):
    """Bind ``classifier``, ``target`` and ``criterion`` into a one-argument loss.

    Returns a callable ``f(x)`` that evaluates ``criterion(classifier(x), target)``.
    """
    def _loss_for(x):
        predictions = classifier(x)
        return criterion(predictions, target)

    return _loss_for

In [None]:
# Guidance loss: sample i of the batch gets class label i as its target (torch.arange over the batch size).
classifier_criterion = make_classifier_criterion(classifier, torch.arange(sample_batch.shape[0], dtype=torch.int64).to(device), criterion)

In [None]:
# Classifier-guided sampling: 3000 reverse steps from fresh Gaussian noise with guidance scale 2.0.
conditioned_samples = reverse_diffusion_ddim(model, torch.randn_like(sample_batch).to(device), 3000, classifier_criterion, 2.0)

In [None]:
# Classifier loss of the guided samples with respect to the guidance targets.
classifier_criterion(conditioned_samples.to(device))

In [None]:
# Same loss, but against sample_y (the labels that came with sample_batch) instead of the guidance targets.
# NOTE(review): presumably a baseline to compare with the value above — confirm.
classifier_criterion2 = make_classifier_criterion(classifier, sample_y.to(device), criterion)
classifier_criterion2(conditioned_samples.to(device))