In [1]:
import datetime
import os
import shutil
import time

import hydra
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm


import wandb
from algos import build_algo
from datasets import build_loader
from htransform.likelihoods import get_likelihood
from models import build_model
from models.classifier_guidance_model import ClassifierGuidanceModel, HTransformModel
from models.diffusion import Diffusion
from utils.degredations import get_degreadation_image
from utils.distributed import common_init, get_logger, init_processes
from utils.functions import get_timesteps, postprocess, preprocess, strfdt
from utils.save import save_result

torch.set_printoptions(sci_mode=False)

from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra

# Reset Hydra's global singleton first so this cell can be re-run in the same
# kernel. NOTE(review): initialize() is expected to complain if Hydra is already
# initialized; confirm against the installed Hydra version.
GlobalHydra.instance().clear()
# Register the `_configs` directory as the config search path
# (version_base "1.2" is passed through to Hydra unchanged).
initialize(version_base="1.2", config_path="_configs") 
# Compose the `deft` config; every later cell reads and mutates this `cfg` object.
cfg = compose(config_name="deft")


In [2]:
# Overrides applied on top of the composed `deft` config for this debugging run.
# Each experiment setting is assigned exactly once so that the value a reader
# sees here is the value that is actually used.

# --- Experiment bookkeeping ---
# Run label. The checkpoint loaded below comes from the "sr4_deft" training run,
# but samples produced by this notebook are filed under "debug_sampling".
cfg.exp.name = "debug_sampling"
cfg.exp.samples_root = "samples_trial"
cfg.exp.save_deg = True
cfg.exp.save_evolution = False
cfg.exp.save_ori = True
cfg.exp.seed = 3
# The sampling loop only applies the batch limit when this value is > 0.
cfg.exp.smoke_test = -1

# --- h-transform network ---
cfg.htransform_model.in_channels = 9
cfg.htransform_model.num_channels = 64
cfg.htransform_model.num_head_channels = 16
cfg.htransform_model.out_channels = 3
# A non-None checkpoint path makes the setup cell skip `algo.train()`.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
cfg.htransform_model.ckpt_path = "/home/sp2058/DEFT/outputs/model_ckpts/sr4_deft_8fi2wuj3/model.pt"

# --- Forward operator / likelihood ---
cfg.likelihood.forward_op.noise_std = 0.0
cfg.likelihood.forward_op.scale = 4.0

# --- Data loading / processes ---
cfg.loader.batch_size = 10
cfg.dist.num_processes_per_node = 1

# --- Logging ---
cfg.wandb_config.log = False

# `cwd` is added as a new top-level key, so struct mode is switched off for the
# assignment and switched back on right after.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
OmegaConf.set_struct(cfg, False)
cfg.cwd = "/home/sp2058/DEFT/outputs/2024-10-24/08-18-53"
OmegaConf.set_struct(cfg, True)

In [3]:
# Initialize distributed process group (a single process: rank 0, world size 1).
# The `is_initialized()` guard keeps this cell re-runnable in a live kernel.
# Any other failure is still tolerated (best effort, as before) but is now
# reported instead of being silently discarded, and KeyboardInterrupt /
# SystemExit are no longer swallowed.
try:
    if not dist.is_initialized():
        dist.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:23456',
            world_size=1,
            rank=0
        )
except Exception as err:
    print(f"Distributed init skipped: {err!r}")


print("cfg.exp.seed", cfg.exp.seed)
common_init(0, seed=cfg.exp.seed)
torch.cuda.set_device(0)

exp_name = cfg.exp.name
print(f"Experiment name is {exp_name}")
exp_root = cfg.exp.root
# Samples are filed under <exp.root>/<exp.samples_root>/<exp.name>.
samples_root = os.path.join(exp_root, cfg.exp.samples_root, exp_name)

dataset_name = cfg.dataset.name

# build_model returns the pretrained model, an optional classifier (may be
# None, see the check below) and the h-transform model.
model, classifier, htransform_model = build_model(cfg)
model.eval()

if classifier is not None:
    classifier.eval()
loader = build_loader(cfg)

print(f"Dataset size is {len(loader.dataset)}")

diffusion = Diffusion(**cfg.diffusion)

# DEFT wraps the pretrained model together with the h-transform network and a
# likelihood; every other algorithm uses the plain classifier-guidance wrapper.
if "deft" in cfg.algo.name:
    likelihood = get_likelihood(cfg, device=model.device)

    cg_model = HTransformModel(
        model, htransform_model, classifier, diffusion, likelihood, cfg
    )
else:
    cg_model = ClassifierGuidanceModel(model, classifier, diffusion, cfg)

algo = build_algo(cg_model, cfg)

# `H` is only bound for these algorithm families (substring match on the
# configured name); for any other algorithm it stays undefined.
if any(
    tag in cfg.algo.name
    for tag in ("ddrm", "mcg", "dps", "pgdm", "reddiff", "deft")
):
    H = algo.H

########################## DO FINETUNING IF NEEDED ##########
print(cfg.htransform_model.ckpt_path)
# Fine-tune only when running DEFT without a checkpoint to load.
if cfg.algo.name == "deft" and cfg.htransform_model.ckpt_path is None:
    algo.train()

########################## DO EVAL ##########################
psnrs = []
start_time = time.time()

cfg.exp.seed 3
Experiment name is debug_sampling


[2024-11-27 20:27:30,101][model][INFO] - Loading model from pretrained_model/256x256_diffusion_uncond.pt..


  model.load_state_dict(torch.load(model_ckpt, map_location=map_location))
  htransform_model.load_state_dict(torch.load(cfg.htransform_model.ckpt_path))


Dataset size is 1000
Finetune dataset size is 1000
Validation dataset size is 100
{'name': 'deft', 'deg': 'sr4', 'finetune_args': {'batch_size': 100, 'shuffle': True, 'drop_last': True, 'epochs': 100, 'lr': 0.001, 'log_freq': 10, 'save_model_every_n_epoch': 10, 'lr_annealing': True}, 'val_args': {'batch_size': 10, 'psnr_batch_size': 100, 'num_steps': 100, 'eta': 0.0, 'use_ema': False, 'sample_freq': 5, 'psnr_sample_freq': -1}, 'ema': {'beta': 0.999, 'update_after_step': 400, 'update_every': 10}, 'use_x0hat': True, 'use_loggrad': True}
deg sr4
number of parameters in pretrained model:  552814086
number of parameters in finetuned model:  23701700
Fraction:  0.042874631092522486
/home/sp2058/DEFT/outputs/model_ckpts/sr4_deft_8fi2wuj3/model.pt


In [9]:
# Debug run of the sampler: draw one batch, build the measurement `y_0` for the
# configured algorithm, sample once, dump the sampler inputs to disk and stop.

# NOTE(review): wandb.init() runs although cfg.wandb_config.log is set to False
# in the config cell; confirm whether the sampling code needs an active run.
wandb.init()

# Wrapping the loader itself (instead of enumerate(loader)) lets tqdm pick up
# its length, so the bar shows progress out of a total rather than "0it".
for it, (x, y, info) in enumerate(tqdm(loader)):
    if cfg.exp.smoke_test > 0 and it >= cfg.exp.smoke_test:
        break
    # Images are in [0, 1]
    # y here is the label of imagenet that class_cond models occasionally need.
    x, y = x.cuda(), y.cuda()

    # Convert from [0, 1] to [-1, 1]
    x = preprocess(x)
    ts = get_timesteps(cfg)

    # Shallow copy: the extra entries added below (y_0, masks, use_ema) must not
    # leak into the `info` dict produced by the loader.
    kwargs = dict(info)
    # TODO: Can we combine the likelihood forward pass for all algorithms?
    if any(
        tag in cfg.algo.name
        for tag in ("ddrm", "mcg", "dps", "pgdm", "reddiff")
    ):
        idx = info["index"]
        # Inpainting-style degradations select their masks per sample index.
        # NOTE(review): "in2" is presumably a second inpainting variant; confirm.
        if "inp" in cfg.algo.deg or "in2" in cfg.algo.deg:
            H.set_indices(idx)
        y_0 = H.H(x)

        # y_0 is the degraded measurement. Gaussian noise is added on top of it;
        # the factor 2 accounts for the rescaling to [-1, 1] (sigma_y is
        # presumably specified on the [0, 1] scale; TODO confirm).
        y_0 = y_0 + torch.randn_like(y_0) * cfg.algo.sigma_y * 2
        kwargs["y_0"] = y_0
    elif "deft" in cfg.algo.name:
        # TODO: Use algo.sigma_y instead of forward_op.noise_std
        # TODO: remove likelihood configs entirely, specify in algo.deg_args.
        if "inp" in cfg.algo.deg or "in2" in cfg.algo.deg:
            # Inpainting: the likelihood also returns the masks it applied; the
            # indices 0..batch-1 are passed as `deterministic_idx`.
            y_0, masks = algo.model.likelihood.sample(
                x,
                deterministic_idx=torch.arange(0, x.shape[0])
                .long()
                .to(algo.device),
            )
        else:
            y_0 = algo.model.likelihood.sample(x)
            masks = None
        kwargs["masks"] = masks
        kwargs["y_0"] = y_0
        kwargs["use_ema"] = cfg.algo.val_args.use_ema

    # With save_evolution the sampler returns six values (intermediate states
    # included); DEFT does not support that mode.
    if cfg.exp.save_evolution:
        if cfg.algo.name == "deft":
            raise NotImplementedError("DEFT does not support evolution saving")
        xt_s, _, xt_vis, _, mu_fft_abs_s, mu_fft_ang_s = algo.sample(
            x, y, ts, **kwargs
        )
    else:
        xt_s, _ = algo.sample(x, y, ts, **kwargs)

    # Save sampling inputs (x, y, ts, kwargs) for debugging
    torch.save({
        'x': x,
        'y': y,
        'ts': ts,
        'kwargs': kwargs
    }, 'sampling_inputs.pt')

    # Debug run: only the first batch is sampled.
    break

0it [00:00, ?it/s]

sampling w/ point estimate model




100%|██████████| 100/100 [00:33<00:00,  3.02it/s]
0it [00:33, ?it/s]


In [6]:
# Inspect the shape of `xt_s`, the first value returned by `algo.sample` in the
# sampling cell (that cell must have been run first).
print(xt_s.shape)

torch.Size([10, 3, 256, 256])


In [7]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../adapt-diffusions")
from htransform.eval import calculate_total_psnr

In [8]:
import argparse

def dict2namespace(config):
    """Mirror ``config`` as an ``argparse.Namespace``.

    Every key of ``config`` becomes an attribute of the returned namespace.
    Values that are ``dict`` instances are converted recursively; any other
    value is stored unchanged.
    """
    result = argparse.Namespace()
    for attr_name, entry in config.items():
        converted = dict2namespace(entry) if isinstance(entry, dict) else entry
        setattr(result, attr_name, converted)
    return result

# Model settings handed to `calculate_total_psnr` as `cfg_model` (after being
# converted to a namespace below).
# NOTE(review): these values are hard-coded here rather than read from `cfg`;
# confirm they match the model actually loaded by the setup cell.
finetune_model_cfg = {
    "attention_resolutions": "32,16,8",
    "class_cond": False,
    "dropout": 0.0,
    "image_size": 256,
    "learn_sigma": True,
    "num_channels": 256,
    "num_head_channels": 64,
    "num_res_blocks": 2,
    "resblock_updown": True,
    "use_fp16": False,
    "use_scale_shift_norm": True,
    "num_heads": 4,
    "num_heads_upsample": -1,
    "conv_resample": True,
    "dims": 2,
    "num_classes": None,
    "use_checkpoint": False,
    "use_new_attention_order": False,
    "num_samples": 1
}

# Attribute-style view of the dict above (no nested dicts, so this is flat).
cfg_model = dict2namespace(finetune_model_cfg)


# Validation settings forwarded to `calculate_total_psnr`.
# NOTE(review): `eta` is 1.0 here, whereas the `val_args` printed by the setup
# cell show eta 0.0; confirm the difference is intended.
val_kwargs={
    "batch_size": 10,
    "num_steps": 100,
    "eta": 1.,
    "sample_freq": 1,
    "psnr_sample_freq": 1,
    "psnr_batch_size": 10,
    "rescale_image": True,
}

# Run the PSNR evaluation from `htransform.eval` on the components held by
# `algo.model` (pretrained model, h-transform model, likelihood, diffusion).
# NOTE(review): the recorded run of this cell ended with
# "TypeError: get_xi_condition() got an unexpected keyword argument 'cfg_model'",
# raised inside the callee. Presumably the imported `htransform.eval` and
# `get_xi_condition` versions are out of sync. TODO: resolve before relying on this.
calculate_total_psnr(
    algo.model.model,
    algo.model.htransform_model,
    algo.model.likelihood,
    loader,
    algo.model.diffusion,
    algo.device,
    val_kwargs,
    cfg_model,
    use_target_score_matching=False,
    log_dir="trial"
    )

ts:  [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 700, 710, 720, 730, 740, 750, 760, 770, 780, 790, 800, 810, 820, 830, 840, 850, 860, 870, 880, 890, 900, 910, 920, 930, 940, 950, 960, 970, 980, 990]
{'batch_size': 10, 'num_steps': 100, 'eta': 1.0, 'sample_freq': 1, 'psnr_sample_freq': 1, 'psnr_batch_size': 10, 'rescale_image': True}


  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s]


TypeError: get_xi_condition() got an unexpected keyword argument 'cfg_model'

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Read the same validation image (ILSVRC2012_val_00000293) from three output
# folders. The paths are absolute and machine-specific.
im1, im2, im3 = (
    mpimg.imread(png_path)
    for png_path in (
        '/home/sp2058/adapt-diffusions/test/n01440764/ILSVRC2012_val_00000293.png',
        '/home/sp2058/adapt-diffusions/save_log/sr_htransform_kdbbaeqk/n01440764/ILSVRC2012_val_00000293.png',
        '/home/sp2058/DEFT/outputs/samples/old_ckpt_new_sampling/n01440764/ILSVRC2012_val_00000293.png',
    )
)
