In [None]:
# Reload edited project modules (e.g. score_inverse) automatically before each cell runs;
# mode 2 reloads all modules not excluded via %aimport.
%load_ext autoreload
%autoreload 2

In [None]:
import torch

from configs.ve.cifar10_ncsnpp_deep_continuous import get_config
from score_inverse.models.utils import create_model
from score_inverse.models.ema import ExponentialMovingAverage

from score_inverse.sde import get_sde
from score_inverse.datasets.scalers import get_data_inverse_scaler, get_data_scaler
from score_inverse.sampling import get_corrector, get_predictor
from score_inverse.sampling.inverse import get_pc_inverse_solver

from score_inverse.tasks.denoise import DenoiseTask
from score_inverse.datasets.cifar10 import CIFAR10

from PIL import Image
import numpy as np

import os

# Run from the parent of the notebook directory so relative paths such as "checkpoints/..."
# and "logs/..." resolve. The notebook directory is recorded only on the first execution, so
# re-running this cell does not keep climbing the directory tree (plain os.chdir("..") would).
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_DIR))
print("New Working Directory ", os.getcwd())

In [None]:
# Held-out CIFAR-10 split (train=False): supplies the demo image and the grid-search inputs below.
dataset = CIFAR10(train=False)

In [None]:
config = get_config()
config.model.num_scales = 100  # Number of discretisation steps
config.eval.batch_size = 4  # Number of samples per generation

# Image shape (C, H, W) read from the data config rather than hard-coded, so the forward
# operator stays consistent with sampling_shape and display_sample, which also use config.data.
image_shape = (config.data.num_channels, config.data.image_size, config.data.image_size)
inverse_task = DenoiseTask(image_shape, noise_type='shot', severity=1).to(config.device)

ckpt_path = "checkpoints/ve/cifar10_ncsnpp_deep_continuous/checkpoint_12.pth"
loaded_state = torch.load(ckpt_path, map_location=config.device)

score_model = create_model(config)
# Still need to load the base model state since non-trainable params aren't covered by EMA
score_model.load_state_dict(loaded_state["model"], strict=False)

# Replace trainable model params with EMA params
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
ema.load_state_dict(loaded_state["ema"])
ema.copy_to(score_model.parameters())

In [None]:
def display_img(im, scale=10):
    """Render a PIL image inline, enlarged ``scale`` times along each axis.

    Nearest-neighbour resampling turns every source pixel into a crisp square block,
    which keeps small images readable instead of blurring them.
    """
    width, height = im.size
    enlarged = im.resize((width * scale, height * scale), Image.NEAREST)
    display(enlarged)


def display_sample(sample, scale=10, shape=None):
    """Display the first image of a batch, upscaled with nearest-neighbour resampling.

    Args:
        sample: 4-D tensor laid out as (batch, channels, height, width). It is permuted to
            channels-last, multiplied by 255 and clipped to [0, 255], so values are presumably
            in [0, 1]. It may live on any device and may be part of an autograd graph.
        scale: Integer upscaling factor forwarded to ``display_img``.
        shape: Optional shape the channels-last uint8 array is reshaped to before the first
            image is taken. Defaults to ``(-1, image_size, image_size, num_channels)`` read
            from the global ``config``.
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    images = sample.detach().permute(0, 2, 3, 1).cpu().numpy()
    images = np.clip(images * 255.0, 0, 255).astype(np.uint8)
    if shape is None:
        shape = (
            -1,
            config.data.image_size,
            config.data.image_size,
            config.data.num_channels,
        )
    images = images.reshape(shape)
    display_img(Image.fromarray(images[0]), scale)

In [None]:
# Take a fixed test image, give it a leading batch dimension, and corrupt it with the task.
x = dataset[13].unsqueeze(0).to(config.device)
y = inverse_task.forward(x)

# Clean image first, then its corrupted observation.
display_sample(x)
display_sample(y)

In [None]:
# Data scaler (not used further in this notebook), its inverse (handed to the solver), and the SDE.
scaler = get_data_scaler(config)
inverse_scaler = get_data_inverse_scaler(config)
sde, sampling_eps = get_sde(config)

# One solver call produces config.eval.batch_size samples of shape (C, H, W).
sampling_shape = (config.eval.batch_size, config.data.num_channels) + (config.data.image_size,) * 2

predictor, corrector = (
    get_predictor(config.sampling.predictor.lower()),
    get_corrector(config.sampling.corrector.lower()),
)

sampling_fn = get_pc_inverse_solver(
    sde=sde,
    shape=sampling_shape,
    eps=sampling_eps,
    predictor=predictor,
    corrector=corrector,
    snr=config.sampling.snr,
    n_steps=config.sampling.n_steps_each,
    probability_flow=config.sampling.probability_flow,
    denoise=config.sampling.noise_removal,
    continuous=config.training.continuous,
    inverse_scaler=inverse_scaler,
    inverse_task=inverse_task,
    device=config.device,
    lambda_=0.05,  # fixed for this demo; the grid search below tunes lambda_
)

In [None]:
# Reconstruct from the corrupted observation y. The second return value is unused here
# (NOTE(review): presumably a step/evaluation count — confirm against get_pc_inverse_solver).
sample, n = sampling_fn(score_model, y.to(config.device))

In [None]:
# Side-by-side check: corrupted observation, reconstruction, then the clean ground truth.
for shown in (y, sample, x):
    display_sample(shown)

In [None]:
class InverseSolverSampler:
    """Scikit-learn-style wrapper around the predictor-corrector inverse solver.

    Implements just enough of the estimator API (``fit``/``predict``/``get_params``/
    ``set_params``) for ``GridSearchCV`` to tune ``lambda_``. ``fit`` is a no-op because the
    score model is already trained; ``predict`` corrupts clean inputs with ``inverse_task``
    and returns the solver's reconstructions.

    Args:
        inverse_task: Forward operator; its ``forward`` method produces the observations.
        score_model: Score network passed to the solver.
        config: Experiment config (sampling, data, training and device settings are read).
        batch_size: Number of images reconstructed per solver call.
        lambda_: Value forwarded to the solver as ``lambda_``.
    """

    def __init__(self, inverse_task, score_model, config, batch_size=100, lambda_=0.01):
        self.inverse_task = inverse_task
        self.score_model = score_model
        self.config = config
        self.batch_size = batch_size
        self.lambda_ = lambda_
        self._build_components()

    def _build_components(self):
        """(Re)derive the solver components that depend on ``config`` and ``batch_size``."""
        config = self.config
        self.predictor = get_predictor(config.sampling.predictor.lower())
        self.corrector = get_corrector(config.sampling.corrector.lower())
        self.inverse_scaler = get_data_inverse_scaler(config)
        self.sde, self.sampling_eps = get_sde(config)
        self.sampling_shape = self._batch_shape(self.batch_size)

    def _batch_shape(self, n):
        """Solver shape ``(n, C, H, W)`` for a batch of ``n`` images."""
        data = self.config.data
        return (n, data.num_channels, data.image_size, data.image_size)

    def _make_sampling_fn(self, n):
        """Build an inverse solver for batches of exactly ``n`` images."""
        config = self.config
        return get_pc_inverse_solver(
            sde=self.sde,
            shape=self._batch_shape(n),
            predictor=self.predictor,
            corrector=self.corrector,
            inverse_scaler=self.inverse_scaler,
            snr=config.sampling.snr,
            n_steps=config.sampling.n_steps_each,
            probability_flow=config.sampling.probability_flow,
            continuous=config.training.continuous,
            denoise=config.sampling.noise_removal,
            eps=self.sampling_eps,
            device=config.device,
            inverse_task=self.inverse_task,
            lambda_=self.lambda_,
        )

    def fit(self, *_):
        """No-op; present for scikit-learn API compatibility."""
        return self

    def predict(self, X):
        """Corrupt ``X`` with ``self.inverse_task`` and reconstruct it batch by batch.

        Args:
            X: Tensor of clean images indexed along dim 0 (any device; each batch is moved
                to ``config.device``).

        Returns:
            CPU tensor with one reconstruction per row of ``X``, in order. A trailing partial
            batch gets a solver sized to it, so no inputs are dropped; an empty ``X`` yields an
            empty tensor.
        """
        n_total = len(X)
        if n_total == 0:
            return torch.empty(self._batch_shape(0))

        # Built once for full batches; a smaller solver is only built for a final partial batch.
        full_batch_fn = self._make_sampling_fn(self.batch_size)

        samples = []
        for start in range(0, n_total, self.batch_size):
            batch = X[start:start + self.batch_size]
            n = len(batch)
            sampling_fn = full_batch_fn if n == self.batch_size else self._make_sampling_fn(n)
            # Use the task bound to this estimator, not whatever global happens to be in scope.
            y = self.inverse_task.forward(batch.to(self.config.device))
            sample, _ = sampling_fn(self.score_model, y)
            samples.append(sample.detach().cpu())

        torch.cuda.empty_cache()
        return torch.cat(samples)

    def get_params(self, deep=False):
        """Return the constructor parameters (used by ``sklearn.base.clone``).

        ``deep`` is accepted for API compatibility; nested objects are not expanded.
        """
        return dict(inverse_task=self.inverse_task,
                    score_model=self.score_model,
                    config=self.config,
                    batch_size=self.batch_size,
                    lambda_=self.lambda_)

    def set_params(self, **parameters):
        """Set constructor parameters and refresh the solver components derived from them."""
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        # Keep sampling_shape / sde / predictor / corrector in sync with config and batch_size.
        self._build_components()
        return self

In [None]:
from sklearn.model_selection import GridSearchCV
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image import PeakSignalNoiseRatio
from sklearn.metrics import make_scorer
import dill
from pathlib import Path

# Scorers for the grid search. torchmetrics Metric objects have no __name__, which sklearn's
# scorer machinery presumably expects (e.g. for its repr), so one is assigned explicitly.
ssim = StructuralSimilarityIndexMeasure(data_range=(0,1))
ssim.__name__ = 'StructuralSimilarityIndexMeasure'
psnr = PeakSignalNoiseRatio(data_range=(0,1))
psnr.__name__ = 'PeakSignalNoiseRatio'

batch_size = 10  # images per solver call; each of the cv=5 test folds holds 100 / 5 = 20 images
cv = 5
total_samples = 100
lambda_range = np.arange(0.01, 0.055, 0.005)  # 0.01, 0.015, ..., 0.05

# Kept on CPU to save GPU memory; InverseSolverSampler moves each batch to config.device.
search_data = torch.stack([dataset[ind] for ind in range(total_samples)])

# Task image shape (C, H, W) taken from the data config instead of being hard-coded.
image_shape = (config.data.num_channels, config.data.image_size, config.data.image_size)

results = {}

print('Attempting to load gridsearch data...')
for noise_type in ['gaussian', 'shot']:
    save_dir = Path('logs/gridsearch') / noise_type
    save_dir.mkdir(parents=True, exist_ok=True)
    results[noise_type] = {}
    for severity in range(1, 6):
        cache_path = save_dir / f'{severity}.pkl'
        if cache_path.exists():
            # NOTE: unpickling can execute arbitrary code; only load caches this notebook wrote.
            with open(cache_path, 'rb') as f:
                results[noise_type][severity] = dill.load(f)
            print(f'Successfully loaded data for {noise_type} noise with severity={severity}')
            continue

        print(f'Data not found for {noise_type} noise with severity={severity}. Running gridsearch...')

        inverse_task = DenoiseTask(image_shape, noise_type=noise_type, severity=severity).to(config.device)

        sampler = InverseSolverSampler(inverse_task, score_model, config, batch_size=batch_size)

        # NOTE(review): sklearn's clone deep-copies non-estimator params (including score_model)
        # for every candidate/fold, and n_jobs=2 runs two of those at once — watch GPU memory.
        gscv = GridSearchCV(
            sampler,
            dict(lambda_=lambda_range),
            scoring={'ssim': make_scorer(ssim), 'psnr': make_scorer(psnr)},
            error_score='raise',
            verbose=4,
            cv=cv,
            refit=False,
            n_jobs=2,
        )

        results[noise_type][severity] = gscv.fit(X=search_data, y=search_data).cv_results_

        print(f'Saving results for {noise_type} noise with severity={severity}\n')
        # Dump to a temporary file, then rename it into place, so an interrupted write never
        # leaves a truncated pickle that would break loading on the next run.
        tmp_path = cache_path.with_name(cache_path.name + '.tmp')
        with open(tmp_path, 'wb') as f:
            dill.dump(results[noise_type][severity], f)
        tmp_path.replace(cache_path)

        torch.cuda.empty_cache()

In [None]:
# For each noise type and severity, report the lambda_ with the highest mean cross-validated SSIM.
for noise_type in results:
    print(noise_type, 'noise')
    for severity, cv_results in results[noise_type].items():
        best_index = cv_results["mean_test_ssim"].argmax()
        best_lambda = cv_results["param_lambda_"][best_index]
        print(f'severity = {severity}, best lambda = {best_lambda}')
    print()

In [None]:
# Shut down the kernel to release resources
# Note: depending on jupyter version, this may need to change to `from jupyter_server import serverapp
from notebook import notebookapp
port_list = [note["port"] for note in notebookapp.list_running_servers()]

print(port_list)

for port in port_list:
    !jupyter notebook stop {port}