# This is a notebook for testing a variance-reduction technique that uses Langevin dynamics to sample from a posterior distribution

## We have: 

### a set of observations $y = Ax^*$, where $y \in \mathbb{R}^m$, $A \in \mathbb{R}^{m \times n}$ and $x^* \in \mathbb{R}^n$
### a noise-conditional score network (NCSN) $s(\theta, x) \approx \nabla_x \log p(x)$
### a region of interest $\mathrm{ROI} \subseteq \{1, 2, \dots, n\}$

## We want: a recovered $\hat{x}$ with $\hat{x}[\mathrm{ROI}] = x^*[\mathrm{ROI}]$

### We propose to do this by obtaining a minimum-variance estimate on the ROI, possibly at the expense of bias in the ROI and increased variance on $\{1, \dots, n\} \setminus \mathrm{ROI}$.

## Preliminaries: define the paths to useful files and import needed stuff

In [None]:
# IPython magic: work from the ncsnv2 repository root so the local `main`,
# `datasets` and `models` imports below resolve (the path is machine-specific).
%cd /home/sriram/Projects/ncsnv2

In [None]:
# Pretrained CelebA NCSNv2 checkpoint and its config (machine-specific paths).
NCSN_ROOT = "/home/sriram/Projects/ncsnv2"
ckpt_path = NCSN_ROOT + "/exp/logs/celeba/checkpoint_210000.pth"
config_path = NCSN_ROOT + "/configs/celeba.yml"

In [None]:
from main import dict2namespace
import yaml
import torch

# Parse the YAML experiment config and convert it to attribute-style access.
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)
new_config = dict2namespace(config)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
new_config.device = device

print(device)

In [None]:
import argparse

# Minimal stand-in for the ncsnv2 command-line arguments that get_dataset needs.
parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=1234, help='Random seed')
parser.add_argument('--exp', type=str, default='exp', help='Path for saving running related data.')

args = parser.parse_args(["--seed", "2240", "--exp", "/home/sriram/Projects/ncsnv2/exp"])

# Apply the seed: without this, the shuffled test loader and every Langevin
# chain below draw from an unseeded RNG and runs are not reproducible.
torch.manual_seed(args.seed)

## Grab the data and visualise

In [None]:
from datasets import get_dataset, data_transform, inverse_data_transform

# Only the held-out split is used in this notebook.
_train_dataset, test_dataset = get_dataset(args, new_config)

In [None]:
from torch.utils.data import DataLoader

# Shuffled batches of 16 test images; an incomplete final batch is dropped.
loader_kwargs = dict(batch_size=16, shuffle=True, num_workers=8, drop_last=True)
test_loader = DataLoader(test_dataset, **loader_kwargs)

test_iter = iter(test_loader)

In [None]:
# Take the images of one batch, apply the model's data transform on the model
# device, and bring the result back to the CPU for inspection.
test_sample = data_transform(new_config, next(test_iter)[0].to(new_config.device)).cpu()

print("SHAPE: ", test_sample.shape)
for label, reduce_fn in (("MIN: ", torch.min), ("MAX: ", torch.max),
                         ("MEAN: ", torch.mean), ("STD: ", torch.std)):
    print(label, reduce_fn(test_sample))

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torchvision

# Tile the batch four images per row; imshow expects channels last.
grid_img = torchvision.utils.make_grid(test_sample, nrow=4)
plt.figure(figsize=(8, 8))
plt.imshow(grid_img.permute(1, 2, 0))

## Grab the appropriate model

In [None]:
from models.ncsnv2 import NCSNv2
from models.ema import EMAHelper
from models import get_sigmas

new_config.input_dim = new_config.data.image_size ** 2 * new_config.data.channels

score = NCSNv2(new_config).to(new_config.device)
score = torch.nn.DataParallel(score)

# Set up the exponential moving average of the weights, if the config enables it.
if new_config.model.ema:
    ema_helper = EMAHelper(mu=new_config.model.ema_rate)
    ema_helper.register(score)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU also loads when the CPU fallback was chosen.
states = torch.load(ckpt_path, map_location=new_config.device)
score.load_state_dict(states[0])
### Make sure we can resume with different eps
# NOTE(review): states[1] looks like optimizer state (it has param_groups); it is
# not used anywhere in this notebook.
states[1]['param_groups'][0]['eps'] = new_config.optim.eps

if new_config.model.ema:
    ema_helper.load_state_dict(states[4])

# Grab all L noise levels.
sigmas = get_sigmas(new_config)

In [None]:
# Sample with the EMA weights when the config uses EMA. ema_helper only exists in
# that case, so otherwise fall back to the raw network instead of raising NameError.
test_score = ema_helper.ema_copy(score) if new_config.model.ema else score

test_score.eval();

In [None]:
first_sigma, last_sigma = sigmas[0].item(), sigmas[-1].item()
print("NUMBER OF NOISE LEVELS: ", sigmas.numel())
print("FIRST: ", first_sigma, " LAST: ", last_sigma)

## Create some measurements of an image and visualise

In [None]:
# Use the first image of a fresh batch as the ground truth x*.
batch_images = next(test_iter)[0]
x = batch_images[0]

plt.figure()
plt.imshow(x.permute(1, 2, 0))
plt.show()

In [None]:
def getRectMask(h_offset=0, w_offset=0, height=10, width=35, tensor_like=None,
                img_height=64, img_width=64, num_channels=3):
    """Return a mask of ones with a zeroed rectangle.

    The rectangle spans rows ``h_offset:h_offset+height`` and columns
    ``w_offset:w_offset+width`` of the last two dimensions. Multiplying an
    image by the mask therefore hides that rectangle.

    Args:
        h_offset, w_offset: top-left corner of the rectangle.
        height, width: rectangle size in pixels. Slicing clips it at the border.
        tensor_like: optional tensor whose shape, dtype and device the mask
            copies. Any number of leading dims is allowed, e.g. [C, H, W]
            or [N, C, H, W].
        img_height, img_width, num_channels: mask shape [C, H, W], used when
            ``tensor_like`` is None.

    Returns:
        The mask tensor (1 outside the rectangle, 0 inside).
    """
    if tensor_like is not None:
        mask_tensor = torch.ones_like(tensor_like)
    else:
        mask_tensor = torch.ones(num_channels, img_height, img_width)

    # Index the trailing two (spatial) dims so batched masks work as well. For a
    # [C, H, W] tensor this is the same as mask[:, rows, cols].
    mask_tensor[..., h_offset:h_offset + height, w_offset:w_offset + width] = 0

    return mask_tensor

In [None]:
# Hide a 10x35 rectangle whose top-left corner is (27, 15); keep every other pixel.
A = getRectMask(h_offset=27, w_offset=15, tensor_like=x)
y = x * A

plt.figure()
plt.imshow(y.permute(1, 2, 0))
plt.show()

## First, run Langevin dynamics on the single image to sample multiple times

In [None]:
N = 16  # number of independent Langevin chains

# Replicate the image and its measurement N times along a new leading batch dim.
x, y = (t.unsqueeze(0).repeat(N, 1, 1, 1) for t in (x, y))

for label, t in (("X shape: ", x), ("Y shape: ", y), ("A shape: ", A)):
    print(label, t.shape)

In [None]:
# Put the batch, the mask and the measurements on the model's device.
x, A, y = (tensor.to(new_config.device) for tensor in (x, A, y))

In [None]:
def SGLD_inpainting(x_mod, x, A, scorenet, sigmas, T=5, step_lr=3.3e-6,
                    final_only=False, verbose=False, denoise=True, decimate=False):
    """Inpaint with annealed Langevin dynamics, clamping observed pixels to noisy data.

    For each visited noise level ``sigma``, this runs ``T`` Langevin steps driven
    by ``scorenet``. After every step it overwrites the observed pixels (where
    ``A == 1``) with ``A * x`` plus Gaussian noise of standard deviation ``sigma``.

    Args:
        x_mod: initial samples, same shape as ``x``.
        x: images whose observed pixels are kept.
        A: binary observation mask broadcastable to ``x`` (1 = observed).
        scorenet: callable ``scorenet(x, labels)`` that returns a score shaped
            like ``x``. ``labels`` are integer noise-level indices.
        sigmas: 1-D tensor of noise levels. The step size is scaled by
            ``(sigma / sigmas[-1]) ** 2``.
        T: Langevin steps per noise level.
        step_lr: base step size at the level ``sigmas[-1]``.
        final_only: return only the final samples instead of every iterate.
        verbose: print per-step diagnostics.
        denoise: finish with ``x + sigmas[-1] ** 2 * score`` at the last level.
        decimate: False to visit every level, or a positive int ``k`` to visit
            only levels whose index is a multiple of ``k``.

    Returns:
        List of CPU tensors. With ``final_only`` it is ``[final]``; otherwise it
        holds every iterate, plus the denoised result when ``denoise`` is set.

    Raises:
        ValueError: if ``decimate`` is not False and not a positive integer.
    """
    if decimate is not False and decimate <= 0:
        raise ValueError("decimate must be False or a positive integer, got {}".format(decimate))

    images = []

    # Complement of the observation mask: 1 exactly where pixels are generated.
    A_trans = 1 - A

    with torch.no_grad():
        for c, sigma in enumerate(sigmas):
            # With decimation, only every `decimate`-th noise level is visited.
            if decimate is not False and c % decimate != 0:
                continue

            # Noise-level labels tell the score network which sigma to condition on.
            labels = (torch.ones(x_mod.shape[0], device=x_mod.device) * c).long()

            step_size = step_lr * (sigma / sigmas[-1]) ** 2

            # Observed pixels perturbed to the current noise level.
            y = A * x + torch.randn_like(x_mod) * sigma

            for _ in range(T):
                grad = scorenet(x_mod, labels)
                noise = torch.randn_like(x_mod)

                # Prior (Langevin) step.
                x_mod = x_mod + step_size * grad + noise * torch.sqrt(step_size * 2)

                # Data-consistency step: keep generated values only on unobserved pixels.
                x_mod = x_mod * A_trans + y * A

                if not final_only:
                    images.append(x_mod.to('cpu'))

                if verbose:
                    # Diagnostics are computed only when they are printed.
                    grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=-1).mean()
                    noise_norm = torch.norm(noise.view(noise.shape[0], -1), dim=-1).mean()
                    image_norm = torch.norm(x_mod.view(x_mod.shape[0], -1), dim=-1).mean()
                    snr = torch.sqrt(step_size / 2.) * grad_norm / noise_norm
                    grad_mean_norm = torch.norm(grad.mean(dim=0).view(-1)) ** 2 * sigma ** 2
                    print("level: {}, step_size: {}, grad_norm: {}, image_norm: {}, snr: {}, grad_mean_norm: {}".format(
                        c, step_size, grad_norm.item(), image_norm.item(), snr.item(), grad_mean_norm.item()))

        if denoise:
            # Final denoising step at the last noise level: x + sigmas[-1]^2 * score.
            last_noise = ((len(sigmas) - 1) * torch.ones(x_mod.shape[0], device=x_mod.device)).long()
            x_mod = x_mod + sigmas[-1] ** 2 * scorenet(x_mod, last_noise)
            images.append(x_mod.to('cpu'))

        if final_only:
            return [x_mod.to('cpu')]
        return images

In [None]:
x_mod = torch.rand(N, 3, 64, 64, device=new_config.device)

# Start the chains from the uniform noise above, not from the ground truth x.
# Starting at x would leak the hidden pixels into the inpainted region and make
# the variance estimate below meaningless.
images = SGLD_inpainting(x_mod=x_mod, x=x, A=A, scorenet=test_score, sigmas=sigmas,
                         T=5, step_lr=3.3e-6, final_only=True, verbose=True, denoise=True,
                         decimate=2)

In [None]:
num_returned = len(images)
print(num_returned)
print(images[0].shape)

In [None]:
results = images[0]

# Show the final sample of every chain, four per row.
grid_img = torchvision.utils.make_grid(results, nrow=4)
plt.figure(figsize=(9, 18))
plt.imshow(grid_img.permute(1, 2, 0), interpolation='nearest')

## Calculate the variance in the ROI

In [None]:
# Region of interest = the hidden rectangle, i.e. the complement of the mask A.
ROI = 1 - A

x_hat_ROI = results.to(new_config.device) * ROI
x_hat_ROI_mean = x_hat_ROI.mean(dim=0)

print("ROI Img Shape: ", x_hat_ROI.shape)
print("Mean Shape: ", x_hat_ROI_mean.shape)

In [None]:
# Sum over ROI pixels of the (biased) sample variance across the N chains.
squared_dev = torch.norm(x_hat_ROI - x_hat_ROI_mean, p=2) ** 2
ROI_var = squared_dev / N

print(ROI_var.item())

## Repeat the process for non-pixel-space forward operators

In [None]:
# Fresh ground-truth image for the compressed-sensing experiment.
images_batch = next(test_iter)[0]
x = images_batch[0]

plt.imshow(x.permute(1, 2, 0))
plt.show()

In [None]:
N = 16

C, H, W = x.shape

# Replicate the image into a batch of N identical copies.
x = x[None].repeat(N, 1, 1, 1)

print(x.shape)

In [None]:
# Take 10% as many measurements as there are pixel values.
measurement_ratio = 0.1
m = int(measurement_ratio * C * H * W)

print(m)

In [None]:
# Dense Gaussian sensing matrix with entries of variance 1/m.
scale = 1 / np.sqrt(m)
A = scale * torch.randn(m, C*H*W)

x_flat = torch.flatten(x, start_dim=1)  # [N, C*H*W]
y = torch.matmul(A, x_flat.T).T  # [N, m]

for label, t in (("A shape: ", A), ("y shape: ", y), ("x shape: ", x)):
    print(label, t.shape)

In [None]:
# Put the batch, the sensing matrix and the measurements on the model's device.
x, A, y = (tensor.to(new_config.device) for tensor in (x, A, y))

In [None]:
def calc_likelihood_grad(A, y, x_hat, c_list):
    """
    Return the gradient of a row-weighted least-squares data-fidelity term.

    The term is  f(x_hat) = (1/2) ||C (y - A x_hat)||_2^2  with C = diag(c_list),
    i.e. a negative log-likelihood (up to constants). Its gradient is
        grad f = A^T C^T (C A x_hat - C y) = A^T C^2 (A x_hat - y),
    so the weights enter squared. The computation stays differentiable w.r.t.
    c_list, so the weights can be learned by backpropagating through a sampler.

    Arguments:
        A: measurement operator [m, n], with n = C*H*W.
        y: measurements [N, m].
        x_hat: current estimate [N, ...], whose trailing dims flatten to n.
        c_list: per-row weights, flattening to [m].

    Returns:
        The gradient, with the same shape as x_hat.
    """
    flat_x = torch.flatten(x_hat, start_dim=1)  # [N, n]
    residual = torch.matmul(flat_x, A.T) - y  # [N, m], A x_hat - y per sample
    # Elementwise c^2 scaling equals diag(c^2) @ residual without an [m, m] matrix.
    weights_sq = torch.flatten(c_list).to(residual.device) ** 2  # [m]
    grad = torch.matmul(residual * weights_sq, A)  # [N, n]
    return grad.view(x_hat.shape)

In [None]:
def calc_ROI_loss(x_hat, ROI):
    """Total (biased) sample variance of the samples inside the ROI.

    Args:
        x_hat: batch of samples; the first dim indexes the samples.
        ROI: mask broadcastable to one sample; pixels where it is 0 are ignored.

    Returns:
        Scalar tensor: the squared L2 distance of the masked samples from their
        batch mean, divided by the number of samples.
    """
    num_samples = x_hat.shape[0]
    masked = x_hat * ROI
    deviation = masked - masked.mean(dim=0)
    return torch.norm(deviation, p=2) ** 2 / num_samples

In [None]:
def SGLD_inverse(x_mod, y, A, c_list, scorenet, sigmas, x=None, T=5, step_lr=3.3e-6, \
                 verbose=False, denoise=True, decimate=False):
    """
    Annealed Langevin sampling for a general linear inverse problem y = A x.

    At each visited noise level, runs T steps. The drift of each step is the score
    from `scorenet` (computed under torch.no_grad) minus the weighted
    data-fidelity gradient from calc_likelihood_grad, divided by sigma**2. The
    likelihood term is computed with autograd enabled, so the returned samples
    stay differentiable with respect to `c_list` when it requires grad.

    Arguments:
        x_mod: initial samples (in this notebook [N, C, H, W]).
        y: measurements [N, m].
        A: measurement operator [m, C*H*W].
        c_list: per-measurement weights passed to calc_likelihood_grad.
        scorenet: callable scorenet(x, labels) returning a score shaped like x.
        sigmas: 1-D tensor of noise levels; the step size is scaled by
            (sigma / sigmas[-1]) ** 2.
        x: optional ground truth. When given, the MSE to it is computed at every
            step and printed if `verbose` is set.
        T: Langevin steps per noise level.
        step_lr: base step size at the level sigmas[-1].
        verbose: print per-step diagnostics (only has an effect when `x` is given).
        denoise: finish with x_mod + sigmas[-1]**2 * score.
            NOTE(review): this score call runs outside torch.no_grad(), unlike
            the in-loop calls, so it also builds an autograd graph.
        decimate: False to visit every level; otherwise an int k, and only
            levels with c % k == 0 and c != 0 are visited.

    Returns:
        The final samples x_mod.
    """
    
    if x is not None:
        mse = torch.nn.MSELoss()
    
    for c, sigma in enumerate(sigmas):
        # If decimating, only visit every `decimate`-th noise level.
        if decimate is not False:
            if c % decimate != 0 or c == 0: # also skips level 0 (presumably the largest sigma)
                continue 

        with torch.no_grad():
            # Integer noise-level labels that tell scorenet which sigma to condition on.
            labels = torch.ones(x_mod.shape[0], device=x_mod.device) * c
            labels = labels.long()

            step_size = step_lr * (sigma / sigmas[-1]) ** 2

        for s in range(T):
            # Prior score; no autograd graph is built through the network.
            with torch.no_grad():
                grad = scorenet(x_mod, labels)
                grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=-1).mean()

            # Likelihood gradient, kept on the autograd graph so gradients reach c_list.
            mle_grad = calc_likelihood_grad(A=A, y=y, x_hat=x_mod, c_list=c_list)
            with torch.no_grad():
                mle_grad_norm = torch.norm(mle_grad.view(mle_grad.shape[0], -1), dim=-1).mean()

            # Combine prior and likelihood; dividing by sigma**2 makes the data
            # term dominate more as the noise level shrinks.
            grad = grad - (mle_grad / sigma**2)
            # Alternative without the 1/sigma**2 scaling:
            #grad = grad - mle_grad

            # Draw the Langevin noise.
            noise = torch.randn_like(x_mod)

            # Langevin update with the combined (posterior) drift.
            x_mod = x_mod + step_size * grad + noise * torch.sqrt(step_size * 2)

            # Logging: MSE to the ground truth, if one was provided.
            with torch.no_grad():
                if x is not None:
                    true_mse = mse(x_mod, x)

                    if verbose:
                        print("level: {}, step_size: {:.3f}, grad_norm: {:.3f}, mle_grad_norm: {:.3f}, true mse: {:.3f}".format(
                            c, step_size, grad_norm.item(), mle_grad_norm.item(), true_mse.item()))

    if denoise:
        # Final denoising step at the last noise level.
        last_noise = (len(sigmas) - 1) * torch.ones(x_mod.shape[0], device=x_mod.device)
        last_noise = last_noise.long()
        x_mod = x_mod + sigmas[-1] ** 2 * scorenet(x_mod, last_noise)

    return x_mod

In [None]:
def getRectMask(h_offset=0, w_offset=0, height=10, width=35, tensor_like=None,
                img_height=64, img_width=64, num_channels=3):
    """Return a mask of zeros with a rectangle of ones.

    NOTE: this redefinition replaces the earlier getRectMask in the notebook
    namespace and uses the opposite fill (1 inside the rectangle). It therefore
    selects the region of interest instead of hiding it.

    The rectangle spans rows ``h_offset:h_offset+height`` and columns
    ``w_offset:w_offset+width`` of the last two dimensions.

    Args:
        h_offset, w_offset: top-left corner of the rectangle.
        height, width: rectangle size in pixels. Slicing clips it at the border.
        tensor_like: optional tensor whose shape, dtype and device the mask
            copies. Any number of leading dims is allowed, e.g. [C, H, W]
            or [N, C, H, W].
        img_height, img_width, num_channels: mask shape [C, H, W], used when
            ``tensor_like`` is None.

    Returns:
        The mask tensor (1 inside the rectangle, 0 elsewhere).
    """
    if tensor_like is not None:
        mask_tensor = torch.zeros_like(tensor_like)
    else:
        mask_tensor = torch.zeros(num_channels, img_height, img_width)

    # Index the trailing two (spatial) dims so batched masks work as well. For a
    # [C, H, W] tensor this is the same as mask[:, rows, cols].
    mask_tensor[..., h_offset:h_offset + height, w_offset:w_offset + width] = 1

    return mask_tensor

### Create the hyperparameters

In [None]:
# One learnable weight per measurement row, initialised to 1 and optimised below.
c_list = torch.ones(m, device=new_config.device).requires_grad_()

print(c_list.shape)

In [None]:
# ROI mask for the second experiment, built with the getRectMask defined just
# above (1 inside the rectangle whose top-left corner is (27, 15)).
ROI = getRectMask(h_offset=27, w_offset=15)
ROI = ROI.to(new_config.device)

print(ROI.shape)

In [None]:
roi_cpu = ROI.cpu()
grid_img = torchvision.utils.make_grid(roi_cpu, nrow=4)
plt.imshow(grid_img.permute(1, 2, 0))

In [None]:
import torch.optim as optim

# Adam over the per-measurement weights only; no network parameters are optimised.
opt = optim.Adam([c_list], lr=1e-3)

In [None]:
from tqdm import tqdm
import time

# Same initial noise for every epoch, so runs differ only through c_list.
x_mod = torch.rand(N, C, H, W, device=new_config.device, requires_grad=False)

langevin_vars = []

num_iters = 5

for epoch in tqdm(range(num_iters)):
    opt.zero_grad()

    # x_hat stays connected to c_list through the likelihood term of the sampler.
    x_hat = SGLD_inverse(x_mod, y, A, c_list, test_score, sigmas, x=x,
                         verbose=True, denoise=False, decimate=False)

    # Detach before plotting: x_hat requires grad, and matplotlib's conversion to
    # NumPy raises on tensors that require grad.
    plt.figure(figsize=(8, 8))
    grid_img = torchvision.utils.make_grid(x_hat.detach().cpu(), nrow=4)
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()

    # Variance of the samples inside the ROI is the quantity being minimised.
    loss = calc_ROI_loss(x_hat, ROI)

    langevin_vars.append(loss.item())
    print("ROI VARIANCE: ", loss.item())

    loss.backward()
    opt.step()

    print(c_list)

    time.sleep(3)