In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as f
import torchvision
from torchvision.transforms import Compose, ToTensor, RandomResizedCrop
import torch.optim as optim
from torchvision.datasets import CIFAR10, MNIST, Flowers102
from torch.utils.data import DataLoader

from functools import partial

import ssl
# SECURITY WARNING: this replaces the process-wide default HTTPS context with one
# that does NOT verify TLS certificates. Every later stdlib HTTPS request in this
# kernel (presumably including the `download=True` dataset fetch below) is then
# open to man-in-the-middle tampering.
# NOTE(review): this looks like a workaround for a broken local certificate store.
# Prefer fixing the trust store (e.g. installing certifi / system CA certs) and
# deleting these two lines.
ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
def get_alpha_t(variance_schedule):
    """Return the cumulative signal fraction alpha_bar_t for every timestep.

    alpha_bar_t = prod_{s=0..t} (1 - beta_s), where beta_s = variance_schedule[s].

    Args:
        variance_schedule: array-like of per-step variances beta_t. Plain lists
            and tuples are accepted as well as NumPy arrays.

    Returns:
        np.ndarray of the same length as ``variance_schedule``.
    """
    # np.asarray lets `1 - ...` broadcast even when a plain list is passed in.
    return np.cumprod(1 - np.asarray(variance_schedule))


class AugmentedDataset(torch.utils.data.Dataset):
    """Dataset that returns diffusion-noised versions of another dataset's images.

    For each item a timestep ``t`` is drawn uniformly from ``[0, T)`` together with
    a standard-normal ``noise`` tensor shaped like the image. The item is
    ``((x_t, t), noise)`` with

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
        alpha_bar_t = prod_{s<=t} (1 - beta_s).

    Args:
        data: indexable dataset yielding ``(image, label)`` pairs; the label is
            discarded.
        variance_schedule: array-like of per-step variances beta_t; T is its length.
        timestep_distribution: stored on the instance but currently unused.
            Timesteps are always sampled uniformly (TODO: wire it in or drop it).
        transform: optional callable applied to the image before noising. Its
            output must be convertible with ``torch.as_tensor`` (e.g. end with
            ``ToTensor``).
    """
    def __init__(self, data, variance_schedule, timestep_distribution=None, transform=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.data = data
        self.timestep_distribution = timestep_distribution
        # alpha_bar_t for every t (same formula as get_alpha_t).
        self.alpha_t = np.cumprod(1 - np.asarray(variance_schedule))
        self.T = len(variance_schedule)
        self.transform = transform

    def __getitem__(self, idx):
        img, _ = self.data[idx]

        if self.transform:
            img = self.transform(img)

        img = torch.as_tensor(img)
        if not img.is_floating_point():
            img = img.float()

        # Use torch's RNG, not np.random. DataLoader workers get distinct torch
        # seeds but inherit identical NumPy RNG state, which would duplicate the
        # noise across workers. torch.manual_seed also makes this reproducible.
        t = int(torch.randint(self.T, (1,)).item())
        noise = torch.randn_like(img)  # same shape and dtype as the image

        alpha_bar = float(self.alpha_t[t])
        signal_scale = float(np.sqrt(alpha_bar))
        noise_scale = float(np.sqrt(1 - alpha_bar))

        return (signal_scale * img + noise_scale * noise, t), noise

    def __len__(self):
        return len(self.data)


class AugmentedDataLoader(torch.utils.data.DataLoader):
    """DataLoader that expands each image batch over several sampled timesteps.

    Original note: "Maybe not needed, let's see." Each batch of ``B`` images from
    the wrapped dataset is tiled once per sampled timestep. It is then noised
    with the closed-form forward process
    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps``.

    Iterating yields ``(noisy_image_batch, extended_image_batch, timestep_batch)``.
    The two tensors have ``len(timestep_batch) * B`` rows. Rows
    ``[i * B, (i + 1) * B)`` hold the image batch noised at ``timestep_batch[i]``.

    Args:
        timestep_batching_function: callable, invoked once per batch as
            ``f(timestep_batch_size)``. It is expected to return a 1-D sequence of
            integer timesteps in ``[0, len(variance_schedule))``.
        timestep_batch_size: number of timesteps requested per image batch.
        variance_schedule: array-like of per-step variances beta_t.
        *args, **kwargs: forwarded to ``torch.utils.data.DataLoader``. The
            wrapped dataset must yield ``(image, label)`` pairs.
    """
    def __init__(self, timestep_batching_function, timestep_batch_size, variance_schedule, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.timestep_batching_function = timestep_batching_function
        self.timestep_batch_size = timestep_batch_size
        self.variance_schedule = variance_schedule
        # alpha_bar_t for every t.
        self.alpha_t = np.cumprod(1 - np.asarray(variance_schedule))

    def __forward_noise(self, image_batch, timestep_batch):
        """Tile ``image_batch`` once per timestep and noise each copy.

        Returns:
            ``(extended_image_batch, noisy_image_batch)``, both torch tensors with
            ``len(timestep_batch) * B`` rows.
        """
        image_batch = torch.as_tensor(image_batch)
        if not image_batch.is_floating_point():
            image_batch = image_batch.float()

        n_images = image_batch.shape[0]  # actual size; the last batch may be short
        timesteps = np.asarray(timestep_batch, dtype=np.int64).reshape(-1)

        # Copy k of the batch occupies rows [k * n_images, (k + 1) * n_images).
        extended_image_batch = torch.cat([image_batch] * len(timesteps))

        # One alpha_bar per output row, broadcast over the remaining dimensions.
        alpha_bar = torch.as_tensor(self.alpha_t[timesteps], dtype=image_batch.dtype)
        alpha_bar = alpha_bar.repeat_interleave(n_images)
        alpha_bar = alpha_bar.reshape(-1, *([1] * (image_batch.dim() - 1)))

        noise = torch.randn_like(extended_image_batch)
        noisy_image_batch = alpha_bar.sqrt() * extended_image_batch + (1 - alpha_bar).sqrt() * noise

        return extended_image_batch, noisy_image_batch

    def __iter__(self):
        def augment_iter(old_iter):
            for image_batch, _ in old_iter:  # dataset labels are discarded
                timestep_batch = self.timestep_batching_function(self.timestep_batch_size)

                extended_image_batch, noisy_image_batch = self.__forward_noise(image_batch, timestep_batch)
                yield noisy_image_batch, extended_image_batch, timestep_batch

        # Zero-argument super() only works in the method body, so resolve the
        # base iterator here rather than inside the generator.
        return augment_iter(super().__iter__())


In [None]:
# Data configuration.
# With a constant beta of 0.0011 over 20 steps, alpha_bar_T = (1 - 0.0011)**20 ~= 0.978,
# so even the last timestep is only lightly noised.
# NOTE(review): presumably a small setting for visual sanity checks. A real DDPM
# schedule uses far more steps and/or larger betas.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

image_size = (256, 256)
batch_size = 6
variance_schedule = np.full(20, 0.0011)  # constant beta_t for T = 20 steps

dataset_flowers = Flowers102(root='datasets', download=True)

# Crop/resize the PIL image first and convert afterwards. ToTensor then only
# processes the 256x256 crop, and the resize runs on PIL (antialiased) instead of
# the tensor path, whose antialias default differs across torchvision versions.
dataset = AugmentedDataset(
    dataset_flowers,
    variance_schedule,
    transform=Compose([RandomResizedCrop(image_size), ToTensor()]),
)

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Pull one shuffled batch, then keep only its first sample for inspection.
(im_n, t), noise = next(iter(data_loader))
im_n, t, noise = im_n[0], t[0], noise[0]

In [None]:
# Sanity check: invert the closed-form forward process using the sampled t and
# the returned noise. The right-hand panel should show the clean training crop.
#   x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
t_index = int(t)
alpha_t = get_alpha_t(variance_schedule)[t_index]
im = (im_n - noise*np.sqrt(1-alpha_t))/np.sqrt(alpha_t)

w, h, dpi = 1500, 500, 100

# ToTensor scales images to [0, 1], but x_t and the reconstruction can overshoot,
# and standard-normal noise lies mostly outside [0, 1]. Without handling, imshow
# clips them (with a warning) and the noise panel comes out saturated. Clip the
# images and min-max rescale the noise for display only.
noise_display = (noise - noise.min()) / (noise.max() - noise.min()).clamp_min(1e-12)
panels = [
    (f"noisy image $x_t$ (t = {t_index})", im_n.clamp(0, 1)),
    (r"noise $\epsilon$ (min-max rescaled)", noise_display),
    (r"reconstructed $x_0$", im.clamp(0, 1)),
]

fig, ax = plt.subplots(ncols=3, figsize=(w/dpi, h/dpi), dpi=dpi)
for axis, (title, image) in zip(ax, panels):
    axis.imshow(image.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    axis.set_title(title)
    axis.axis("off")
plt.show()