In [2]:
import glob
import os

import numpy as np
from skimage import io
from skimage.color import rgb2gray
from skimage.transform import resize
import torch
from torch.utils.data import Dataset

from torch import Tensor
from numbers import Number
from typing import Union

from torch.nn import Sequential
import torchvision.transforms as tvts
from torch.utils.data import IterableDataset, IterDataPipe
from torchdata.datapipes.iter import Mapper

%run DDRM_sampling.ipynb


class GalaxyImageDataset(Dataset):
    """
    Map-style dataset of images stored as individual files in one directory.

    Every ``jpg``, ``jpeg``, ``png`` and ``npy`` file found directly inside
    ``root_dir`` belongs to the dataset; files are ordered by sorted path.

    Parameters
    ----------
    root_dir
        Directory that is searched (non-recursively) for image files.
    to_gray
        If true, convert each image with `rgb2gray` and return the result
        without a channel axis; the image must then have exactly 3 channels
        on its last axis. If false, the last (channel) axis is moved to the
        third-from-last position before returning.
    n_pix
        If given, resize every image to ``(n_pix, n_pix)`` with `resize`.
    dtype
        ``dtype`` of the returned tensors.
    n_max
        If given, the dataset length is capped at this many files (the first
        ``n_max`` in sorted order).
    """

    def __init__(
        self,
        root_dir="/project/undark/galaxy_zoo_jpgs/",
        to_gray=True,
        n_pix=None,
        dtype=torch.float32,
        n_max=None,
    ):
        self.root_dir = root_dir

        exts = ["jpg", "jpeg", "png", "npy"]
        self.fnames = sorted(
            fname
            for ext in exts
            for fname in glob.glob(os.path.join(root_dir, f"*.{ext}"))
        )

        self.n_images = len(self.fnames)
        if n_max is not None:
            self.n_images = min(self.n_images, n_max)

        self.to_gray = to_gray
        self.n_pix = n_pix
        self.dtype = dtype

    def __len__(self):
        """Number of accessible images (after applying ``n_max``)."""
        return self.n_images

    def __getitem__(self, idx):
        """
        Load image number ``idx`` and return it as a tensor.

        Raises
        ------
        IndexError
            If ``idx`` is negative or not smaller than ``len(self)``.
        NotImplementedError
            If the file extension is not one of the supported ones.
        """
        # Validate the index *before* touching the disk: `self.fnames` may be
        # longer than `self.n_images` (because of `n_max`), and negative
        # indices would silently select files from the end of the list.
        if idx < 0 or idx >= self.n_images:
            raise IndexError(
                f"index {idx} is out of range for a dataset of {self.n_images} images"
            )

        fname = self.fnames[idx]
        ext = os.path.splitext(fname)[1][1:]
        if ext in ["jpg", "jpeg", "png"]:
            img = io.imread(fname)
        elif ext == "npy":
            img = np.load(fname)
        else:
            raise NotImplementedError(f"cannot load files with the extension {ext}")

        # Put channel last. Heuristic: a leading axis of length 3 is taken to
        # be the channel axis.
        if img.shape[0] == 3:
            img = np.moveaxis(img, 0, -1)

        if self.n_pix is not None:
            img = resize(img, (self.n_pix, self.n_pix))

        if self.to_gray:
            assert img.shape[-1] == 3, "to_gray requires a 3-channel image"
            img = rgb2gray(img)
            return torch.as_tensor(img, dtype=self.dtype)
        else:
            # Channel-last -> channel-first for the returned tensor.
            return torch.as_tensor(img, dtype=self.dtype).movedim(-1, -3)


class Normalize(torch.nn.Module):
    """
    Linearly map values from ``[minimum, maximum]`` onto ``[0, 1]``.

    Parameters
    ----------
    maximum
        Input value that is mapped to 1.
    minimum
        Input value that is mapped to 0.
    clip
        If true, clamp the result to ``[0, 1]``.
    """

    def __init__(self, maximum: Union[Tensor, Number], minimum: Union[Tensor, Number] = 0, clip=True):
        super().__init__()
        self.max = maximum
        self.min = minimum
        self.clip = clip

    def forward(self, t: Tensor):
        """Return the normalised copy of ``t``; ``t`` itself is not modified."""
        # Divide by the width of the range rather than by the maximum, so that
        # `maximum` maps to exactly 1 for a non-zero `minimum` as well. With
        # the default `minimum=0` this is the same as dividing by `maximum`.
        res = (t - self.min) / (self.max - self.min)
        # The in-place clip is safe: `res` is a fresh tensor, not a view of `t`.
        return res.clip_(0, 1) if self.clip else res


class ProbesDataset(IterableDataset[Tensor]):
    def __init__(
        self, files: IterDataPipe, size: int, loadfunc=torch.load,
        max_rotation=180, scale: Union[Number, tuple[Number, Number]]=1.2, max_translation=0.1,
        interp=tvts.InterpolationMode.BILINEAR, norm=5.5, clip=False
    ):
        """
        A dataset that loads galaxy images from a datapipe of files.
        Applies data augmentation using random rotation, scale, translation,
        and flipping, then crops the image to a desired size and normalises it.
        Parameters
        ----------
        files
            datapipe yielding the items (e.g. file paths) that are passed one
            by one to ``loadfunc``
        size
            size of output images (passed to `CenterCrop`)
        loadfunc
            function used to load file into `Tensor`. You may wish to pass
            ``partial(torch.load, map_location=DEVICE)`` to load directly on
            ``DEVICE``. NOTE: the default `torch.load` unpickles its input,
            so only use it on files from a trusted source.
        max_rotation
            maximum rotation (degrees) applied during augmentation
        scale
            scale passed to `RandomAffine` for augmentation. If a single
            number, interpreted as the range between ``1/scale`` and
            ``scale`` (in increasing order, so values below 1 work too).
        max_translation
            maximum translation (as fraction of *original* image size) applied
            during augmentation (passed to `RandomAffine`)
        interp
            interpolation type passed to `RandomAffine`
        norm
            normalise images by this value
        clip
            clip values to the range ``(0, norm)`` (remapped to ``(0, 1)``)
        """
        self.loader = Mapper(files, loadfunc)

        if isinstance(scale, Number):
            # Order the bounds: for scale < 1 the naive (1/scale, scale) is an
            # inverted range, which `RandomAffine` cannot sample from.
            scale = tuple(sorted((1 / scale, scale)))

        self.augmentation = Sequential(
            tvts.RandomAffine(
                degrees=max_rotation, scale=scale,
                # same maximum fraction horizontally and vertically
                translate=(max_translation, max_translation),
                interpolation=interp),
            # Crop after the affine transform, so translation fractions refer
            # to the original (uncropped) image size.
            tvts.CenterCrop(size),
            tvts.RandomHorizontalFlip(), tvts.RandomVerticalFlip(),
            Normalize(norm, clip=clip)
        )

    def __iter__(self):
        """Lazily yield one augmented, normalised image per loaded item."""
        return map(self.augmentation, self.loader)
    
    
class DDPMDataset(LensingDDPM, IterableDataset[Tensor]):
    def __init__(self, diffusion: GaussianDiffusion, nbatch: int=None, clip=True, multi_channel=True, *, show_progress=True):
        """
        An endless dataset of galaxy images drawn from a pre-trained DDPM.
        Parameters
        ----------
        diffusion
            Pre-trained DDPM model used for sampling.
        nbatch
            How many images to draw per iteration. With `None`, batching is
            turned off and each item is one image of shape (C, H, W).
        clip
            Whether values are clipped to [-1, 1] during sampling. The output
            is in any case rescaled to [0, 1], without any additional
            clipping.
        multi_channel
            Whether multi-channel images are sampled. `False` is currently
            undefined.
        show_progress
            Whether a progress bar is displayed for the denoising steps.
        """
        super().__init__(diffusion=diffusion, multi_channel=multi_channel, show_progress=show_progress)
        self.clip = clip
        self.nbatch = nbatch

    def __iter__(self):
        """Yield freshly sampled images forever; this iterator never ends."""
        while True:
            # Unbatched mode still draws a batch of one, dropped again below.
            n_samples = self.nbatch if self.nbatch is not None else 1
            # Reverse the third-from-last axis (presumably the channel order
            # -- TODO confirm against `LensingDDPM.prior`).
            images = self.prior(n_samples, self.clip).flip(-3)
            # Rescale from the sampler's [-1, 1] range to [0, 1].
            images.add_(1).mul_(0.5)
            if self.nbatch is None:
                yield images.squeeze(-4)
            else:
                yield images


In [3]:
            
            
class DDRMDataset(LensingDDRM, IterableDataset[Tensor]):
    def __init__(self, diffusion: GaussianDiffusion,
                 Hb: Tensor, H: Tensor = None, U: Tensor = None, S: Tensor = None, Vh: Tensor = None,
                 nbatch: int=1, clip=True, multi_channel=True, *, show_progress=True,
                 y: Tensor = None, sigma_y=1, eta=0.03, eta_b=0.85, skip=1, skip_type="linear"):
        """
        A dataset that samples galaxy images from a DDRM model.
        Parameters
        ----------
        diffusion
            The pre-trained DDPM model.
        Hb, H, U, S, Vh
            SVD matrices.
        nbatch
            The number of images to sample in each iteration. Passing `None`
            (or 1) disables batching and yields a single image with shape
            (C, H, W).
        clip
            Whether to clip to [-1, 1] while sampling. Note that the output is
            always re-normalised to [0, 1] (but not further clipped).
        multi_channel
            Whether to sample multi-channel images. Setting this to `False` is
            undefined for the moment.
        show_progress
            Whether to show a progress bar for the denoising steps.
        y
            Noisy observation on which to condition. Must be set (here or
            later via the ``y`` attribute) before iterating.
        sigma_y
            Noise level.
        eta
            eta hyperparameter for DDRM.
        eta_b
            eta_b hyperparameter for DDRM.
        skip
            How many steps to skip in the chain while sampling to accelerate DDRM.
        skip_type
            How to skip steps in the chain while sampling to accelerate DDRM. Options: "linear", "quad"
        """
        super().__init__(diffusion=diffusion, Hb=Hb, H=H, U=U, S=S, Vh=Vh, multi_channel=multi_channel, show_progress=show_progress)
        self.nbatch = nbatch
        self.clip = clip
        self.y = y
        self.sigma_y = sigma_y
        self.eta = eta
        self.eta_b = eta_b
        self.skip = skip
        self.skip_type = skip_type

    def __iter__(self):
        """
        Yield images sampled conditionally on ``self.y`` forever.

        Raises
        ------
        ValueError
            If no observation ``y`` has been provided.
        """
        while True:
            # Checked on every pass because `y` is a plain attribute that may
            # be (re)assigned between iterations.
            if self.y is None:
                raise ValueError("DDRMDataset requires an observation `y` to condition on, but y is None")

            # As documented (and as in DDPMDataset), `None` means "unbatched":
            # sample a batch of one and drop the batch axis afterwards.
            n_samples = 1 if self.nbatch is None else self.nbatch

            # Normalization: move the observation to the sampler's [-1, 1]
            # source convention. For x = 2x' - 1 one has Hx = 2Hx' - H.1;
            # this assumes `H.sum(1)` equals H applied to an all-ones source
            # -- TODO confirm against `LensingDDRM`.
            Y = 2*self.y - self.H.sum(1)

            # Trailing observation axes kept when broadcasting over the batch.
            n_obs_dims = 3 if self.multi_channel else 2
            res = self.sample(
                y=Y.expand(n_samples, *Y.shape[-n_obs_dims:]),
                # noise level doubled to match the factor 2 applied to `y` above
                sigma_y=2*self.sigma_y, skip=self.skip, skip_type=self.skip_type, clip=self.clip, eta=self.eta, eta_b=self.eta_b
            ).add_(1).mul_(0.5).unflatten(-1, self.src_shape)  # back to [0, 1], then restore the source image shape
            yield res.squeeze(-4) if n_samples == 1 else res