In [27]:
import torch
from PIL import Image
from collections import defaultdict
import torch.nn.functional as TF
import torchvision.datasets as dsets
from torchvision import transforms
import numpy as np

import torch.optim as optim
import torchvision
# import matplotlib.pyplot as plt
import utils
from guided_diffusion.unet import UNetModel
import math
from tensorboardX import SummaryWriter
import os
import json
from collections import namedtuple
import argparse
from torchvision.utils import save_image
from tqdm import tqdm
from blur_diffusion import Deblurring, ForwardBlurIncreasing, gaussian_kernel_1d
from utils import normalize_np, clear
from EMA import EMA
from torch.nn import DataParallel
from fid import FID
from scipy.integrate import solve_ivp

In [36]:
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This parser
    accepts the usual spellings and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Experiment configuration. Inside the notebook the parser is evaluated with an
# empty argument list, so every option takes its default value.
parser = argparse.ArgumentParser(description='Configs')
parser.add_argument('--gpu', default='0',type=str, help='gpu num')
parser.add_argument('--dataset',default='cifar10', type=str, help='cifar10 / mnist')
parser.add_argument('--name', default='blur_diff',type=str, help='Saving directory name')
parser.add_argument('--ckpt', default='', type=str, help='UNet checkpoint')

parser.add_argument('--bsize', default=16, type=int, help='batchsize')
parser.add_argument('--N', default=500, type=int, help='Max diffusion timesteps')
parser.add_argument('--sig', default=0.4, type=float, help='sigma value for blur kernel')
parser.add_argument('--sig_min', default=0, type=float, help='minimum sigma value for blur kernel')
parser.add_argument('--sig_max', default=0.1, type=float, help='maximum sigma value for blur kernel')
parser.add_argument('--lr', default=0.00005, type=float, help='learning rate')
parser.add_argument('--noise_schedule', default='linear', type=str, help='Type of noise schedule to use')
parser.add_argument('--betamin', default=0.0001, type=float, help='beta (min). get_score(1) can diverge if this is too low.')
parser.add_argument('--betamax', default=0.02, type=float, help='beta (max)')
# _str2bool instead of bool: `--fromprior False` must actually disable the flag.
parser.add_argument('--fromprior', default=True, type=_str2bool, help='start sampling from prior')
parser.add_argument('--gtscore', action='store_true', help='Use ground truth score for reverse diffusion')
parser.add_argument('--max_iter', default=15000, type=int, help='max iterations')
parser.add_argument('--eval_iter', default=1000, type=int, help='eval iterations')
parser.add_argument('--fid_iter', default=2000, type=int, help='FID evaluation interval (iterations)')
parser.add_argument('--fid_num_samples', default=100, type=int, help='number of samples generated for FID')
parser.add_argument('--fid_bsize', default=32, type=int, help='batchsize used when sampling for FID')
parser.add_argument('--loss_type', type=str, default = 'eps_simple', choices=['sm_simple', 'eps_simple', 'sm_exact', 'std_matching'])
parser.add_argument('--f_type', type=str, default = 'linear', choices=['linear', 'log', 'quadratic', 'cubic', 'quartic', 'triangular'])
parser.add_argument('--dropout', default=0, type=float, help='dropout')

# EMA, save
parser.add_argument('--use_ema', action='store_true',
                    help='use EMA or not')
parser.add_argument('--inference', action='store_true')
parser.add_argument('--freq_feat', action='store_true', help = "concat Utx_i")
parser.add_argument('--ode', action='store_true', help = "ODE fast sampler")
parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
parser.add_argument('--save_every', type=int, default=50000, help='How often we wish to save ckpts')
# Explicit empty list: do not read the Jupyter kernel's own sys.argv.
opt = parser.parse_args([])

In [3]:
# Run on the default CUDA device when one is available, otherwise fall back to
# the CPU so the notebook can still execute on a machine without a GPU.
# NOTE(review): opt.gpu is not used to pick the device index here; to pin a
# specific GPU, restrict the visible devices (e.g. CUDA_VISIBLE_DEVICES) before
# starting the kernel.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("N:", opt.N)

# Short aliases for the options used throughout the following cells.
N = opt.N
bsize = opt.bsize
beta_min = opt.betamin
beta_max = opt.betamax
sig = opt.sig

N: 1000


In [4]:
# Training-time preprocessing: random horizontal flip (probability 0.5),
# followed by conversion of the image to a tensor.
train_transformer = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
)


In [5]:
# CIFAR-10 splits rooted at the current directory (download=True is requested
# for both). Only the training split uses the augmenting transform; the test
# split is converted to tensors without augmentation.
train_dataset = torchvision.datasets.CIFAR10(
    root='.',
    train=True,
    download=True,
    transform=train_transformer,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='.',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Mini-batch loaders over both splits; each one shuffles and uses opt.bsize.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=opt.bsize,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=True,
    batch_size=opt.bsize,
)

In [22]:
# Relative path "experiments/train", assembled with the platform's separator.
_train_dir_parts = ('experiments', 'train')
train_dir = os.path.join(*_train_dir_parts)

In [24]:
# FID evaluator using `train_dir` as its real-image directory, with batch
# size 1 on the selected device.
fid_eval = FID(
    real_dir=train_dir,
    device=device,
    bsize=1,
)

path: experiments\train


100%|██████████| 3125/3125 [01:21<00:00, 38.50it/s]


NameError: name 'dataset_train' is not defined

In [28]:
# Image geometry read from the first training sample: the last axis gives the
# resolution, the first axis gives the number of channels.
resolution = train_dataset[0][0].size(-1)
input_nc = train_dataset[0][0].size(0)
ksize = 2 * resolution - 1
pad = 0

# define forward blur
kernel = gaussian_kernel_1d(ksize, sig)
blur = Deblurring(kernel, input_nc, resolution, device=device)
print("blur.U_small.shape:", blur.U_small.shape)
D_diag = blur.singulars()

# Forward blur-diffusion process, configured from the parsed options.
fb = ForwardBlurIncreasing(
    N=N,
    beta_min=beta_min,
    beta_max=beta_max,
    sig=sig,
    sig_max=opt.sig_max,
    sig_min=opt.sig_min,
    D_diag=D_diag,
    blur=blur,
    channel=input_nc,
    device=device,
    noise_schedule=opt.noise_schedule,
    resolution=resolution,
    pad=pad,
    f_type=opt.f_type,
)

# Output directory of this run; it also receives the TensorBoard event files.
dir = os.path.join('experiments', opt.name)
writer = SummaryWriter(dir)

contains zero? tensor(False, device='cuda:0')
blur.U_small.shape: torch.Size([32, 32])
fs:  tensor([-6.2563e-05,  0.0000e+00,  6.2563e-05,  ...,  6.2375e-02,
         6.2437e-02,  6.2500e-02], device='cuda:0')
p:  torch.Size([1001, 3072])
D:  torch.Size([1001, 3072])


In [39]:

# Build the network, optionally restore a checkpoint, and wrap it for multi-GPU use.
model = UNetModel(resolution, input_nc, 128, input_nc, blur = blur, dropout=opt.dropout, freq_feat = opt.freq_feat)
if opt.ckpt != '':
    if os.path.exists(opt.ckpt):
        # map_location lets a checkpoint saved on another device load onto `device`.
        model.load_state_dict(torch.load(opt.ckpt, map_location=device))
    else:
        # Keep the best-effort behaviour (continue from random weights) but say so,
        # so that a mistyped path is not silently ignored.
        print(f"WARNING: checkpoint '{opt.ckpt}' not found; using randomly initialized weights")
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = DataParallel(model)

model.to(device)
print("input_nc", input_nc, "resolution", resolution)

optimizer = optim.Adam(model.parameters(), lr=opt.lr)
if opt.use_ema:
    optimizer = EMA(optimizer, ema_decay=opt.ema_decay)

# forward process visualization
# A single training image (batch of one) is pushed through the forward process.
sample = train_dataset[1][0].unsqueeze(0)

x_0 = sample[:4]
x_0 = x_0.to(device)
# Sanity check at timestep 500, clamped so that it never exceeds N.
sanity_step = min(500, N)
i = torch.from_numpy(np.array([sanity_step] * x_0.shape[0])).to(device)
fb.sanity(x_0, i)

# Collect x_0 followed by x_t at ten evenly spaced timesteps up to N.
sample_list = [x_0[0]]
for t in range(N // 10, N + 1, N // 10):
    i = torch.from_numpy(np.array([t] * x_0.shape[0])).to(device)
    x_i = fb.get_x_i(x_0, i)
    sample_list.append(x_i[0])
    print(f"x_{t}.std() = {x_i.std()}")
    print(f"x_{t}.mean() = {x_i.mean()}")


# Concatenate the snapshots along the last (width) axis into one strip image.
grid_sample = torch.cat(sample_list, dim=2)
utils.tensor_imsave(grid_sample, "./" + dir, "forward_process.jpg")
# Persist the run configuration next to the outputs.
with open(os.path.join(dir, "config.json"), "w") as json_file:
    json.dump(vars(opt), json_file)
import time
meta_iter = 0

input_blocks torch.Size([128, 3, 3, 3])
input_nc 3 resolution 32
MAE = 3.758759703487158e-06
x_100.std() = 0.49324896931648254
x_100.mean() = 0.4726855754852295
x_200.std() = 0.7742474675178528
x_200.mean() = 0.403974711894989
x_300.std() = 0.9143067598342896
x_300.mean() = 0.321166455745697
x_400.std() = 0.9757927656173706
x_400.mean() = 0.2228381335735321
x_500.std() = 0.9767712354660034
x_500.mean() = 0.1384417861700058
x_600.std() = 0.9953518509864807
x_600.mean() = 0.05848237872123718
x_700.std() = 0.9877786636352539
x_700.mean() = 0.01588168740272522
x_800.std() = 0.9873171448707581
x_800.mean() = 0.025635406374931335
x_900.std() = 1.009985089302063
x_900.mean() = 0.021165339276194572
x_1000.std() = 1.010701298713684
x_1000.mean() = 0.039750274270772934


In [40]:
# Number of test batches averaged into the validation loss.
NUM_EVAL_BATCHES = 2


def _timestep_batch(t, batch_size):
    """Return a length-`batch_size` tensor on `device` with every entry equal to `t`."""
    return torch.from_numpy(np.array([t] * batch_size)).to(device)


def _predict_score(x, i):
    """Run the network on `x` at timesteps `i` and convert its output to a score.

    The conversion depends on opt.loss_type: the score-matching variants use
    the network output directly, the eps / std variants go through the
    matching `fb.get_score_from_*` helper.

    Raises:
        NotImplementedError: for an unknown opt.loss_type.
    """
    if opt.loss_type in ("sm_simple", "sm_exact"):
        return model(x, i)
    if opt.loss_type == "eps_simple":
        return fb.get_score_from_eps(model(x, i), i)
    if opt.loss_type == "std_matching":
        return fb.get_score_from_std(model(x, i), i)
    raise NotImplementedError


def _reverse_step(pred, s, t, i):
    """Apply one reverse-diffusion update to `pred`.

    Args:
        pred: current sample batch.
        s: score for `pred` at this timestep.
        t: integer timestep.
        i: tensor form of `t`, one entry per batch element.
    """
    s = fb.U_I_minus_B_Ut(s, i)
    hf = pred - fb.W(pred, i)
    # Anderson theorem
    pred = pred + hf  # unsharpening mask filtering
    pred = pred + s  # denoising
    if t > 2:
        pred = pred + fb.U_I_minus_B_sqrt_Ut(torch.randn_like(pred), i)  # inject noise
    return pred


def _swap_ema():
    """Swap the live parameters with their EMA copies when EMA is enabled."""
    if opt.use_ema:
        optimizer.swap_parameters_with_ema(store_params_in_ema=True)


# Main loop: one optimisation step per iteration (unless --inference), with
# periodic FID sampling, validation sampling and checkpointing.
train_iter = iter(train_loader)
for step in range(opt.max_iter):
    if not opt.inference:
        # Fetch the next batch, restarting the loader at the end of an epoch.
        try:
            x_0, _ = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x_0, _ = next(train_iter)
        """
        training
        """
        assert x_0.shape[-1] == resolution, f"{x_0.shape}"
        # One timestep per sample: uniform draw over [1, N), truncated to an integer.
        i = np.random.uniform(1 / N, 1, size = (x_0.shape[0])) * N
        i = torch.from_numpy(i).to(device).type(torch.long)

        x_0 = x_0.to(device)
        x_i, eps = fb.get_x_i(x_0, i, return_eps = True)

        if opt.loss_type == "sm_simple":
            loss = fb.get_loss_i_simple(model, x_0, x_i, i)
        elif opt.loss_type == "eps_simple":
            loss = fb.get_loss_i_eps_simple(model, x_i, i, eps)
        elif opt.loss_type == "sm_exact":
            loss = fb.get_loss_i_exact(model, x_0, x_i, i)
        elif opt.loss_type == "std_matching":
            loss = fb.get_loss_i_std_matching(model, x_i, i, eps)
        writer.add_scalar('loss_train', loss, step)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            print(step, loss)
    # Calcuate FID
    # FID runs every 240000 steps until step 240001, then every opt.fid_iter steps.
    fid_iter = opt.fid_iter if step > 240001 else 240000
    if (step % fid_iter == 0 and step > 0):
        sample_id = 0
        fid_dir = os.path.join(dir, f"{step}")
        os.makedirs(fid_dir, exist_ok=True)
        with torch.no_grad():
            _swap_ema()
            model.eval()
            for _ in range(opt.fid_num_samples // opt.fid_bsize):
                i = _timestep_batch(opt.N - 1, opt.fid_bsize)
                pred = fb.get_x_N([opt.fid_bsize, input_nc, resolution, resolution], i)
                for t in reversed(range(1, opt.N)):
                    i = _timestep_batch(t, opt.fid_bsize)
                    pred = _reverse_step(pred, _predict_score(pred, i), t, i)
                for generated in pred:
                    save_image(generated, os.path.join(fid_dir, f"{sample_id:05d}.png"))
                    sample_id += 1
        _swap_ema()
        # Always return to training mode, not only when EMA is enabled.
        model.train()
        fid = fid_eval(fid_dir)
        writer.add_scalar('fid', fid, step)
        print(f"step {step}, fid = {fid}")
    if (step % opt.eval_iter == 0 and step > 0) or opt.inference:
        """
        sampling (eval)
        """
        cnt = 0
        val_loss = 0

        with torch.no_grad():
            _swap_ema()
            model.eval()

            if opt.ode:
                # The ODE sampler has never been reachable: the branch always
                # raised before doing any work.
                raise NotImplementedError("ODE fast sampler is not implemented")
            for x_ref, _ in test_loader:
                x_ref = x_ref.to(device)
                i = _timestep_batch(N - 1, x_ref.shape[0])
                if opt.fromprior:
                    # Start the reverse process from a prior sample.
                    pred = fb.get_x_N(x_ref.shape, i)
                    print(f"pred.std() = {pred.std()}")
                else:
                    # Start from the forward-diffused test images.
                    pred = fb.get_x_i(x_ref, i)
                preds = [pred]

                for t in reversed(range(1, N)):
                    i = _timestep_batch(t, x_ref.shape[0])
                    if opt.gtscore:
                        s = fb.get_score_gt(pred, x_ref, i)
                    else:
                        s = _predict_score(pred, i)
                    pred = _reverse_step(pred, s, t, i)
                    # Keep a snapshot every N // 10 steps for the visualization strip.
                    if t % (N // 10) == 0:
                        preds.append(pred)

                preds.append(pred)
                assert x_ref.shape == pred.shape
                # visualize
                grid = torch.cat(preds, dim=3) # grid_sample.shape: (bsize, channel, H, W * 12)
                # (batch_size, channel, H, W * 12) -> (channel, H * bsize, W * 12)
                grid = grid.permute(1, 0, 2, 3).contiguous().view(grid.shape[1], -1, grid.shape[3])
                # (bsize, channel, H, W) -> (channel, H, W * bsize)
                gt = x_ref.permute(1, 2, 0, 3).contiguous().view(x_ref.shape[1], -1, x_ref.shape[3] * x_ref.shape[0])
                utils.tensor_imsave(gt, "./" + dir, f"{step}_{cnt}_GT.jpg")
                utils.tensor_imsave(grid, "./" + dir, f"{step}_{cnt}_pred.jpg")

                cnt += 1
                val_loss += TF.l1_loss(x_ref, pred) / NUM_EVAL_BATCHES

                if cnt == NUM_EVAL_BATCHES:
                    break
        print(f"step: {step} loss: {val_loss}")
        # Log against the training step so successive evaluations do not all
        # land on the same x-coordinate.
        writer.add_scalar('loss_val', val_loss, step)
        with open(os.path.join(dir, 'log.txt'), 'a') as log_file:
            log_file.write(f"Step: {step} loss: {val_loss}" + '\n')
        model.train()
        _swap_ema()
    # Checkpoint on the regular schedule and once more at the end of a training
    # run, so that the final weights are not lost when max_iter < save_every.
    is_final_training_step = (step == opt.max_iter - 1) and not opt.inference
    if step % opt.save_every == 1 or is_final_training_step:
        _swap_ema()
        # Save the unwrapped module so the checkpoint loads without DataParallel.
        net_to_save = model.module if isinstance(model, DataParallel) else model
        torch.save(net_to_save.state_dict(), os.path.join(dir, f"model_{step}.ckpt"))
        _swap_ema()

OutOfMemoryError: CUDA out of memory. Tried to allocate 18.00 MiB (GPU 0; 4.00 GiB total capacity; 3.37 GiB already allocated; 0 bytes free; 3.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF