In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../src")
import model.sdes as sdes
import model.branch_definer as branch_definer
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import tqdm.notebook

In [None]:
# Select the compute device: use the GPU when CUDA is available, otherwise the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

### Load the dataset and sample images per class

In [None]:
# Number of images kept for each class
samples_per_digit = 100

# Load the CIFAR-100 training split. The transform below is only applied when
# items are fetched through the dataset object itself; the sampling loop reads
# `dataset.data` directly and performs the same rescaling on its own.
dataset = torchvision.datasets.CIFAR100(
    "/gstore/scratch/u/tsenga5/datasets/CIFAR-100",
    train=True,
    transform=(
        lambda img: (np.transpose(np.asarray(img), (2, 0, 1)) / 256 * 2) - 1
    ),
)

# Store the targets as an array so they can be compared against a class label
dataset.targets = np.array(dataset.targets)
classes = np.unique(dataset.targets)

# Map each class label to a randomly chosen subset of its images, on DEVICE
images_by_class = {}
for class_label in classes:
    class_images = dataset.data[dataset.targets == class_label]
    # Rescale pixel values (divide by 256, times 2, minus 1), then move the
    # last axis to position 1 so each image matches `input_shape` below
    class_images = torch.tensor(
        np.transpose((class_images / 256 * 2) - 1, (0, 3, 1, 2))
    )
    # Draw `samples_per_digit` distinct images for this class
    chosen = np.random.choice(
        len(class_images), size=samples_per_digit, replace=False
    )
    images_by_class[class_label] = class_images[torch.tensor(chosen)].to(DEVICE)
input_shape = (3, 32, 32)

In [None]:
# Create the diffuser
# NOTE(review): 0.1 and 20 are presumably the lower/upper bounds of the noise
# schedule expected by `sdes.VariancePreservingSDE` — confirm against `model/sdes.py`.
diffuser = sdes.VariancePreservingSDE(0.1, 20, input_shape)
# Upper end of the diffusion time range; used later to build the `times` grid
# and passed to `branch_definer.branch_points_to_branch_defs`
t_limit = 1

### Plotting functions

In [None]:
def plot_cifar(
    images, grid_size=(10, 5), scale=1, clip=False, title=None
):
    """
    Plots CIFAR images in a grid. No normalization will be done.
    Arguments:
        `images`: a B x 3 x H x W NumPy array of numbers between
            0 and 1 (channels first; H = W = 32 for CIFAR)
        `grid_size`: a pair of integers denoting the number of images
            to plot horizontally and vertically (in that order); if
            more images are provided than spaces in the grid, leftover
            images will not be plotted; if fewer images are provided
            than spaces in the grid, there will be at most one
            unfinished row
        `scale`: amount to scale figure size by
        `clip`: if True, clip values to between 0 and 1
        `title`: if given, title for the plot (attached to the
            top-left panel)
    Returns the Matplotlib figure. Raises a ValueError if `images` is
    empty or if `grid_size` contains a non-positive entry.
    """
    width, height = grid_size
    num_images = len(images)
    if num_images == 0:
        raise ValueError("plot_cifar requires at least one image")
    if width < 1 or height < 1:
        raise ValueError("grid_size entries must be positive, got %r" % (grid_size,))

    # Channels-first -> channels-last, which is what `imshow` expects
    images = np.transpose(images, (0, 2, 3, 1))
    if clip:
        images = np.clip(images, 0, 1)

    # Ceiling division: a final, partially filled row still gets drawn
    # (floor division would silently drop it, or yield 0 rows when there
    # are fewer images than columns)
    rows_needed = -(-num_images // width)
    height = min(height, rows_needed)

    figsize = (width * scale, (height * scale) + 0.5)

    # squeeze=False always returns a 2D (height x width) array of axes, so
    # single-row and single-column grids need no special casing
    fig, ax = plt.subplots(
        ncols=width, nrows=height, squeeze=False,
        figsize=figsize, gridspec_kw={"wspace": 0, "hspace": 0}
    )

    for j in range(height):
        for i in range(width):
            index = i + (width * j)
            # Cells past the last image (unfinished final row) stay blank
            if index < num_images:
                ax[j][i].imshow(images[index], cmap="gray", aspect="auto", interpolation=None)
            ax[j][i].axis("off")
    if title:
        ax[0][0].set_title(title)
    plt.subplots_adjust(bottom=0.25)
    return fig

In [None]:
def plot_forward_diffusion(diffused_digits_by_class, times):
    """
    Plots example images and the trajectory of the forward diffusion
    process.
    Arguments:
        `diffused_digits_by_class`: the output of
            `branch_definer.run_forward_diffusion`; used here as a
            dictionary mapping each class to a tensor indexed as
            [time index, sample index, ...]
        `times`: array of times that diffusion was performed at
    Relies on the module-level `diffuser` and `DEVICE` to sample the
    prior that is overlaid on the histogram. Raises a ValueError if
    `times` is empty.
    """
    num_times = len(times)
    if num_times == 0:
        raise ValueError("`times` must contain at least one time point")

    # Plot some results over time: roughly 15 evenly spaced time points plus
    # the final one. The step is floored at 1 so that fewer than 15 time
    # points does not produce a zero step, and `np.unique` drops the final
    # index if the even spacing already landed on it.
    step = max(1, num_times // 15)
    inds_to_show = np.unique(
        np.concatenate([np.arange(0, num_times, step), [num_times - 1]])
    )
    num_classes = len(diffused_digits_by_class)
    num_show_per_class = 3
    example_value = next(iter(diffused_digits_by_class.values()))
    batch_size, input_shape = example_value.shape[1], tuple(example_value.shape[2:])
    for t_i in inds_to_show:
        digits_to_show = np.empty((num_classes * num_show_per_class,) + input_shape)

        for c_i, c in enumerate(sorted(diffused_digits_by_class.keys())):
            # Sample a few distinct examples of this class
            inds = np.random.choice(
                diffused_digits_by_class[c].shape[1], size=num_show_per_class, replace=False
            )
            digits_to_show[c_i * num_show_per_class : (c_i + 1) * num_show_per_class] = \
                diffused_digits_by_class[c][t_i][inds].cpu().numpy()

        # Reorder from class-major to example-major so that each grid row
        # holds one example of every class
        digits_to_show = digits_to_show.reshape(
            (num_classes, num_show_per_class) + digits_to_show.shape[1:]
        )
        digits_to_show = np.swapaxes(digits_to_show, 0, 1).reshape((-1,) + digits_to_show.shape[2:])

        plot_cifar(
            digits_to_show, grid_size=(num_classes, num_show_per_class),
            clip=True, title=("t = %.2f" % times[t_i])
        )

    # Show distribution over time
    fig, ax = plt.subplots(figsize=(20, 8))
    bins = np.linspace(-4, 4, 1000)
    cmap = plt.get_cmap("magma")
    for t_i in tqdm.notebook.trange(num_times):
        # 100 randomly chosen values per class at this time point
        all_vals = np.concatenate([
            np.random.choice(np.ravel(digits[t_i].cpu().numpy()), size=100, replace=False)
            for digits in diffused_digits_by_class.values()
        ])
        ax.hist(all_vals, bins=bins, histtype="step", color=cmap(t_i / num_times), alpha=0.5, density=True)
    # Overlay samples drawn from the diffuser's prior at the final time
    prior = diffuser.sample_prior(batch_size, torch.ones(batch_size).to(DEVICE) * times[-1]).cpu().numpy()
    ax.hist(np.ravel(prior), bins=bins, histtype="step", color="blue", linewidth=2, density=True, label="Sampled prior")
    ax.set_xlabel("x")
    ax.set_ylabel("p(x)")
    ax.set_title("Evolution of p(x) over forward SDE")
    ax.set_ylim((0, 3))
    plt.legend()
    plt.show()

In [None]:
def plot_similarities(sim_matrix, classes):
    """
    Draws two figures summarizing class similarities: a bar chart of
    every pairwise similarity at the first time point (sorted from
    highest to lowest), and a line plot of the mean similarity at each
    time point.
    Arguments:
        `sim_matrix`: a T x C x C similarity matrix between classes
            at various time points, output by
            `branch_definer.compute_time_similarities`
        `classes`: list of classes matching the order in `sim_matrix`
    """
    num_classes = len(classes)
    # All index pairs (i, j) with j <= i, i.e. the lower triangle
    # including the diagonal
    pairs = [(i, j) for i in range(num_classes) for j in range(i + 1)]

    # Pairwise similarities at the first time point, largest first
    pair_labels = np.array(["%s-%s" % (classes[i], classes[j]) for i, j in pairs])
    pair_sims = np.array([sim_matrix[0, i, j] for i, j in pairs])
    order = np.argsort(pair_sims)[::-1]
    bar_fig, bar_ax = plt.subplots(figsize=(20, 4))
    bar_ax.bar(pair_labels[order], pair_sims[order])
    bar_ax.set_ylabel("Average similarity")
    bar_ax.set_title("Average similarity between pairs of classes (t = 0)")
    plt.show()

    # Mean over the upper triangle (diagonal included) at each time index
    upper = np.triu_indices(num_classes)
    mean_sims = [
        np.mean(sim_matrix[t_i][upper]) for t_i in range(len(sim_matrix))
    ]
    line_fig, line_ax = plt.subplots(figsize=(20, 4))
    line_ax.plot(mean_sims)
    line_ax.set_xlabel("t")
    line_ax.set_ylabel("Average similarity")
    line_ax.set_title("Average similarity over all classes during forward diffusion")
    plt.show()

### Compute branch definitions

In [None]:
# Run the forward diffusion on the sampled images over 1000 evenly spaced
# time points in [0, t_limit], then visualize the trajectory
times = np.linspace(0, t_limit, 1000)
diffused_images_by_class = branch_definer.run_forward_diffusion(images_by_class, diffuser, times)
plot_forward_diffusion(diffused_images_by_class, times)
plt.show()

# Similarity between classes at each time point, and the matching class order
sim_matrix, sim_matrix_classes = branch_definer.compute_time_similarities(
    diffused_images_by_class, times
)
plot_similarities(sim_matrix, sim_matrix_classes)
plt.show()

# Turn the similarities into branch points, then into branch definitions
branch_points = branch_definer.compute_branch_points(
    sim_matrix,
    times,
    sim_matrix_classes,
    min_branch_time=0.4,
    min_branch_length=0,
    epsilon=1e-5,
)
branch_defs = branch_definer.branch_points_to_branch_defs(branch_points, t_limit)
# Remove 0-length branches: keep only entries whose bd[2] - bd[1] is non-zero
# (presumably bd[1] and bd[2] are a branch's start and end times — confirm
# against `branch_definer`)
branch_defs = list(filter(lambda bd: bd[2] - bd[1], branch_defs))
print(branch_points, end="\n\n")
print(branch_defs)