In [1]:
from functools import cached_property

import numpy as np
import torch
from torch import Tensor
from tqdm.auto import tqdm

%run DDPM_for_pretrained.ipynb 

In [2]:
class LensingDDPM:
    """Convenience wrapper around a ``GaussianDiffusion`` model for unconditional sampling.

    Attributes:
        diffusion: the wrapped diffusion model.
        multi_channel: controls whether ``src_shape`` omits (True) or
            includes (False) the channel dimension.
        show_progress: when true, iterables passed through ``with_progress``
            are wrapped in a ``tqdm`` progress bar.
    """

    @staticmethod
    def _strip_module_prefix(state_dict, prefix='module.'):
        """Return a copy of ``state_dict`` with a leading ``prefix`` dropped from its keys.

        Keys that do not start with ``prefix`` are kept unchanged, so
        checkpoints saved with or without the prefix both load.
        (Presumably the prefix comes from a ``DataParallel``-style wrapper
        used at training time -- confirm against the training code.)
        """
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): val
            for key, val in state_dict.items()
        }

    @classmethod
    def from_pretrained_astroddpm(cls, fname, *args, map_location=None, **kwargs):
        """Build an instance from a checkpoint file.

        Args:
            fname: checkpoint path (or file-like object) handed to ``torch.load``.
                The loaded object must have a ``'model'`` entry holding the
                state dict.
            *args, **kwargs: forwarded to the class constructor after the
                diffusion model.
            map_location: forwarded to ``torch.load`` and also used as the
                target of ``diffusion.to(...)``.

        Returns:
            A new ``cls`` wrapping the loaded model in eval mode.
        """
        diffusion = GaussianDiffusion(
            Unet(dim=64, dim_mults=(1, 2, 4, 8)),
            timesteps=1000, loss_type='l1'
        ).eval()

        # NOTE: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(fname, map_location=map_location)
        diffusion.load_state_dict(cls._strip_module_prefix(checkpoint['model']))
        return cls(diffusion.to(map_location), *args, **kwargs)

    def __init__(self, diffusion: "GaussianDiffusion", multi_channel=True, *, show_progress=True):
        self.diffusion = diffusion
        self.multi_channel = multi_channel
        self.show_progress = show_progress

    @property
    def device(self):
        """Device of the first parameter of the wrapped model."""
        return next(self.diffusion.parameters()).device

    def with_progress(self, iterable):
        """Wrap ``iterable`` in ``tqdm`` when ``show_progress`` is set, else return it as is."""
        return tqdm(iterable) if self.show_progress else iterable

    @cached_property
    def nchannels(self):
        """Number of image channels reported by the diffusion model."""
        return self.diffusion.channels

    @cached_property
    def src_shape(self):
        """Source shape: ``(size, size)`` if ``multi_channel`` else ``(channels, size, size)``."""
        return (() if self.multi_channel else (self.nchannels,)) + 2*(self.diffusion.image_size,)

    @cached_property
    def ts(self):
        """All diffusion timesteps in descending order, ``num_timesteps - 1`` down to 0."""
        return tuple(reversed(range(self.diffusion.num_timesteps)))

    def prior(self, n=1, clip=False):
        """Draw ``n`` unconditional samples by running ``p_sample`` over every timestep.

        Args:
            n: number of samples (batch size).
            clip: forwarded to ``p_sample`` as ``clip_denoised``.

        Returns:
            The result of the last ``p_sample`` call; the chain starts from
            Gaussian noise of shape ``(n, channels, size, size)``.
        """
        img = torch.randn((n, self.nchannels, *self.src_shape[-2:]), device=self.device)
        for i in self.with_progress(self.ts):
            img = self.diffusion.p_sample(
                img, img.new_full((n,), i, dtype=torch.long),
                clip_denoised=clip
            )
        return img



class LensingDDRM(LensingDDPM):
    """Diffusion sampler conditioned on a linear observation through the SVD of ``H``.

    The observation operator is supplied as ``Hb`` (block form) and/or ``H``
    (matrix form). Its reduced SVD ``H = U diag(S) Vh`` is computed lazily on
    first use unless the factors are passed to the constructor. ``sample``
    runs a DDRM-style reverse process (per the class name) in the spectral
    basis defined by ``Vh``.

    NOTE(review): the default flattening of ``Hb`` into ``H`` in ``__init__``
    and the way ``lens`` / ``img_shape`` / ``img_mask`` interpret ``Hb``'s
    trailing dimensions do not obviously agree -- confirm the intended ``Hb``
    layout before relying on both paths with the same tensor.
    """

    def __init__(self, diffusion: "GaussianDiffusion", Hb: Tensor,
                 H: Tensor = None,
                 U: Tensor = None, S: Tensor = None, Vh: Tensor = None,
                 multi_channel=True, show_progress=True):
        """Store the operator and, optionally, a precomputed SVD.

        Args:
            diffusion: the diffusion model to wrap.
            Hb: observation operator in block form.
            H: observation operator as a matrix. When omitted it is derived
                from ``Hb``; for a 6-D ``Hb`` the first three dimensions become
                the rows and the last three the columns.
            U, S, Vh: optional precomputed SVD factors of ``H``. They are only
                used when ``U`` is given, in which case all three are stored
                as passed.
            multi_channel: see ``LensingDDPM``. When false, ``H`` is tiled
                ``channels`` times along its last dimension, each copy divided
                by ``channels``.
            show_progress: see ``LensingDDPM``.
        """
        super().__init__(diffusion, multi_channel=multi_channel, show_progress=show_progress)

        self.Hb = Hb
        self.H = Hb.flatten(-2).flatten(-4, -3).flatten(0, 1).flatten(1, 2) if H is None else H

        if not self.multi_channel:
            # Concatenate `channels` scaled copies so the operator acts on the
            # channel-flattened source as an average over channels.
            self.H = torch.cat(diffusion.channels * (self.H / diffusion.channels,), dim=-1)

        if U is not None:
            # Storing the factors on the instance shadows the lazy properties
            # below, so no SVD is ever computed.
            self.U, self.S, self.Vh = U, S, Vh

    @cached_property
    def img_mask(self):
        """Boolean mask, true where ``Hb`` summed over its last two dimensions is zero."""
        return self.Hb.sum((-2, -1)) == 0

    @cached_property
    def img_shape(self):
        """The two dimensions of ``Hb`` that precede its last two."""
        return self.Hb.shape[-4:-2]

    def _svd(self):
        """Compute the reduced SVD of ``H`` and store ``U``, ``S``, ``Vh`` on the instance.

        Side effects: prints a notice, and replaces ``self.H`` by its
        ``float32`` version.
        """
        print('SVD-ing')

        self.H = self.H.float()
        self.U, self.S, self.Vh = torch.linalg.svd(self.H, full_matrices=False)

    # ``_svd`` assigns all three factors as instance attributes, which take
    # precedence over these ``cached_property`` descriptors; the first access
    # to any one of them therefore populates the other two as well.
    @cached_property
    def U(self):
        """Left singular vectors of ``H`` (computed on first access)."""
        self._svd()
        return self.U

    @cached_property
    def S(self):
        """Singular values of ``H`` (computed on first access)."""
        self._svd()
        return self.S

    @cached_property
    def Vh(self):
        """Right singular vectors of ``H``, transposed (computed on first access)."""
        self._svd()
        return self.Vh

    @cached_property
    def n(self) -> int:
        """Number of rows of ``U``."""
        return self.U.shape[-2]

    @cached_property
    def m(self) -> int:
        """Number of columns of ``Vh``; the length of a flattened source vector."""
        return self.Vh.shape[-1]

    @cached_property
    def k(self) -> int:
        """Number of singular values."""
        return self.S.shape[-1]

    def lens(self, src: Tensor):
        """Apply the block operator ``Hb`` to a source image.

        The last two dimensions of ``src`` are flattened and multiplied by
        ``Hb`` (with its last two dimensions flattened to columns and the two
        before them flattened to rows); the result is reshaped to ``img_shape``.
        """
        return (self.Hb.flatten(-2).flatten(-3, -2) @ src.flatten(-2).unsqueeze(-1)).squeeze(-1).unflatten(-1, self.img_shape)

    def mask(self, img: Tensor) -> np.ma.MaskedArray:
        """Return ``img`` as a NumPy masked array, masked where ``img_mask`` is true.

        NOTE(review): tensors are handed straight to NumPy, which presumably
        requires them to live on the CPU -- confirm at the call sites.
        """
        return np.ma.array(img, mask=self.img_mask.expand(img.shape))

    @cached_property
    def sqrt_alphas(self):
        """The model's ``sqrt_alphas_cumprod`` schedule."""
        return self.diffusion.sqrt_alphas_cumprod

    @cached_property
    def sigmas(self):
        """Noise levels ``sqrt(1 / sqrt_alphas**2 - 1)`` for every timestep."""
        return (self.sqrt_alphas**(-2) - 1)**0.5

    def steps(self, skip: int = 1, skip_type: str = "linear"):
        """Return the ascending array of timesteps visited by ``sample``.

        Args:
            skip: stride for ``"linear"``; for ``"quad"`` the number of points
                is ``num_timesteps // skip`` before duplicates are removed.
            skip_type: ``"linear"`` (evenly spaced) or ``"quad"`` (squares of
                evenly spaced values, truncated to integers).

        Raises:
            ValueError: if ``skip_type`` is not one of the two supported values.
        """
        num_timesteps = self.diffusion.num_timesteps

        if skip_type == "linear":
            return np.arange(0, num_timesteps, skip)

        if skip_type == "quad":
            steps = np.linspace(
                0, np.sqrt(num_timesteps)*0.9999, num_timesteps//skip
            ) ** 2
            # Truncation can map neighbouring values to the same integer
            # when skip is small, hence the de-duplication.
            return np.unique(steps.astype(int))

        raise ValueError(
            f"unknown skip_type {skip_type!r}; expected 'linear' or 'quad'"
        )

    def project_y(self, y: Tensor):
        """Map an observation into the spectral basis: ``(U^T y) / S``."""
        return (self.U.mT @ y.unsqueeze(-1)).squeeze(-1) / self.S

    def project_x(self, x: Tensor):
        """Map a flattened source into the spectral basis: ``Vh x``."""
        return (self.Vh @ x.unsqueeze(-1)).squeeze(-1)

    def deproject_x(self, xbar: Tensor):
        """Map spectral coefficients back to source space: ``Vh^T xbar``."""
        return (self.Vh.mT @ xbar.unsqueeze(-1)).squeeze(-1)

    def denoise(self, x: Tensor, t, clip=False):
        """Predict the clean image from a noisy flattened sample at timestep ``t``.

        Args:
            x: noisy sample whose last dimension holds a flattened
                ``(channels, size, size)`` image.
            t: timestep index (scalar or array-like that broadcasts over the
                batch dimensions).
            clip: when true, clamp the prediction to ``[-1, 1]``.

        Returns:
            The model's prediction with its trailing ``len(src_shape)``
            dimensions flattened; i.e. the channel dimension is kept
            separate when ``multi_channel`` is true.
        """
        device = self.device
        # Same image layout that `prior` feeds to the model.
        img_shape = (self.nchannels, *self.src_shape[-2:])

        # `x` is rescaled by sqrt_alphas_cumprod[t] before it is given to the
        # network (presumably converting from the x0 + sigma * noise
        # parametrisation used in `sample` -- confirm).
        ximg = self.sqrt_alphas[t] * x.to(device=device).unflatten(-1, img_shape)
        if ximg.ndim < 4:
            # The network is called with a leading batch dimension.
            ximg = ximg.unsqueeze(-4)
        t = torch.as_tensor(t).expand(ximg.shape[:-3]).to(device=device)

        x_mean = self.diffusion.predict_start_from_noise(
            ximg, t=t, noise=self.diffusion.denoise_fn(ximg, t))
        if clip:
            x_mean = x_mean.clip_(-1., 1.)
        return x_mean.flatten(-len(self.src_shape))

    def sample(self, y, sigma_y, eta=0.03, eta_b=0.85, skip=1, skip_type="linear", clip=False):
        """Sample a source conditioned on the observation ``y``.

        Args:
            y: observation; its last dimension is contracted with ``U``.
            sigma_y: observation noise level.
            eta: fraction of each step's noise that is freshly drawn.
            eta_b: weight for pulling spectral coefficients towards the
                observation; ``None`` selects a per-singular-value weight
                computed from ``S``, ``sigma_y`` and the current noise level.
            skip, skip_type: forwarded to ``steps``.
            clip: forwarded to ``denoise``.

        Returns:
            ``denoise`` applied at timestep 0 to the final state.

        NOTE(review): nothing here disables autograd; callers presumably wrap
        the call in ``torch.no_grad()`` -- confirm.
        """
        steps = self.steps(skip=skip, skip_type=skip_type)

        sigma_y_scaled = sigma_y / self.S
        r_sigma_y_scaled = self.S / sigma_y
        var_y_scaled = sigma_y_scaled**2

        # Division by zero singular values yields nan/inf; nan becomes 0 and
        # both infinities become 1.
        y_bar = self.project_y(y).nan_to_num_(nan=None, posinf=1, neginf=1)

        # Start from pure noise at the largest visited timestep.
        x = self.sigmas[steps[-1]] * torch.randn(
            y.shape[:-1] + (self.m,),
            device=y.device, dtype=y.dtype)

        # Pairs (t, tn): `tn` is the current timestep and `t` the next,
        # smaller one. A list (rather than a bare zip) lets the progress bar
        # know its total.
        schedule = list(zip(reversed(steps[:-1]), reversed(steps[1:])))
        for t, tn in self.with_progress(schedule):

            sigma_t = self.sigmas[t]
            eta_b_t = eta_b if eta_b is not None else 2 * self.S**2 / (self.S**2 + sigma_y**2 / sigma_t**2)

            xnew = self.denoise(x, tn, clip=clip)
            xnew = xnew.flatten(-2)

            sdiff_pre = sigma_t * (1 - eta**2)**0.5

            # Dealing with null-space: take an unconditional step, then remove
            # its component in the row space of H.
            x_aux = xnew + sdiff_pre * (x - xnew) / self.sigmas[tn] + sigma_t * eta * torch.randn_like(x)

            x0 = x_aux - self.deproject_x(self.project_x(x_aux))

            x_bar = self.project_x(xnew)

            # Dealing with non-singular-space: choose the update rule per
            # coefficient depending on whether the current noise level exceeds
            # the scaled observation noise.
            mask_constr = sigma_t > sigma_y_scaled
            mean = torch.where(
                mask_constr, torch.lerp(x_bar, y_bar, eta_b_t),
                x_bar + sdiff_pre * (y_bar - x_bar) * r_sigma_y_scaled
            )
            var = torch.where(
                mask_constr, sigma_t**2 - eta_b_t**2 * var_y_scaled,
                             eta**2 * sigma_t**2
            ).clip_(0)

            x12 = mean + var**0.5 * torch.randn_like(y_bar)

            x = x0 + self.deproject_x(x12)

        return self.denoise(x, 0, clip=clip)