In [1]:
import torch
from PIL import Image
from collections import defaultdict
import torch.nn.functional as TF
import torchvision.datasets as dsets
from torchvision import transforms
import numpy as np

import torch.optim as optim
import torchvision
# import matplotlib.pyplot as plt
import utils
from guided_diffusion.unet import UNetModel
import math
from tensorboardX import SummaryWriter
import os
import json
from collections import namedtuple
import argparse
from torchvision.utils import save_image
from tqdm import tqdm
from blur_diffusion import Deblurring, ForwardBlurIncreasing, gaussian_kernel_1d
from utils import normalize_np, clear
from EMA import EMA
from torch.nn import DataParallel
from fid import FID
from scipy.integrate import solve_ivp

In [40]:
parser = argparse.ArgumentParser(description='Configs')
parser.add_argument('--gpu', default='0',type=str, help='gpu num')
parser.add_argument('--dataset',default='cifar10', type=str, help='cifar10 / mnist')
parser.add_argument('--name', default='blur_diff',type=str, help='Saving directory name')
parser.add_argument('--ckpt', default='', type=str, help='UNet checkpoint')

parser.add_argument('--bsize', default=8, type=int, help='batchsize')
parser.add_argument('--N', default=500, type=int, help='Max diffusion timesteps')
parser.add_argument('--sig', default=0.4, type=float, help='sigma value for blur kernel')
parser.add_argument('--sig_min', default=0, type=float, help='sigma value for blur kernel')
parser.add_argument('--sig_max', default=0.1, type=float, help='sigma value for blur kernel')
parser.add_argument('--lr', default=0.00005, type=float, help='learning rate')
parser.add_argument('--noise_schedule', default='linear', type=str, help='Type of noise schedule to use')
parser.add_argument('--betamin', default=0.0001, type=float, help='beta (min). get_score(1) can diverge if this is too low.')
parser.add_argument('--betamax', default=0.02, type=float, help='beta (max)')
parser.add_argument('--fromprior', default=True, type=bool, help='start sampling from prior')
parser.add_argument('--gtscore', action='store_true', help='Use ground truth score for reverse diffusion')
parser.add_argument('--max_iter', default=15000, type=int, help='max iterations')
parser.add_argument('--eval_iter', default=1000, type=int, help='eval iterations')
parser.add_argument('--fid_iter', default=2000, type=int, help='eval iterations')
parser.add_argument('--fid_num_samples', default=100, type=int, help='eval iterations')
parser.add_argument('--fid_bsize', default=8, type=int, help='eval iterations')
parser.add_argument('--loss_type', type=str, default = 'eps_simple', choices=['sm_simple', 'eps_simple', 'sm_exact', 'std_matching'])
parser.add_argument('--f_type', type=str, default = 'linear', choices=['linear', 'log', 'quadratic', 'cubic', 'quartic', 'triangular'])
parser.add_argument('--dropout', default=0, type=float, help='dropout')

# EMA, save
parser.add_argument('--use_ema', action='store_true',
                    help='use EMA or not')
parser.add_argument('--inference', action='store_true')
parser.add_argument('--freq_feat', action='store_true', help = "concat Utx_i")
parser.add_argument('--ode', action='store_true', help = "ODE fast sampler")
parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
parser.add_argument('--save_every', type=int, default=5000, help='How often we wish to save ckpts')
opt = parser.parse_args("")

In [41]:
device = torch.device(f'cuda:{opt.gpu}')
device = torch.device('cuda')
print("N:", opt.N)
N = opt.N
bsize = opt.bsize
beta_min = opt.betamin
beta_max = opt.betamax
sig = opt.sig

N: 500


In [42]:
train_transformer =  transforms.Compose([
    transforms.RandomHorizontalFlip(0.5), 
    transforms.ToTensor()
])


In [43]:
train_dataset = torchvision.datasets.CIFAR10(root='.', train=True, transform=train_transformer, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='.', train=False, transform=transforms.ToTensor(), download=True)

Files already downloaded and verified
Files already downloaded and verified


In [44]:
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=opt.bsize,shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=opt.bsize,shuffle=True)

In [45]:
train_dir = os.path.join('experiments','train')

In [46]:
fid_eval = FID(real_dir =train_dir, device = device,bsize=1)

In [47]:
resolution = train_dataset[0][0].shape[-1]
input_nc = train_dataset[0][0].shape[0]
ksize = resolution * 2 - 1
pad = 0

# define forward blur
kernel = gaussian_kernel_1d(ksize, sig)
blur = Deblurring(kernel, input_nc, resolution, device=device)
print("blur.U_small.shape:", blur.U_small.shape)
D_diag = blur.singulars()
fb = ForwardBlurIncreasing(N=N, beta_min=beta_min, beta_max=beta_max, sig=sig, sig_max = opt.sig_max, sig_min = opt.sig_min, D_diag=D_diag,
                    blur=blur, channel=input_nc, device=device, noise_schedule=opt.noise_schedule, resolution=resolution, pad=pad, f_type=opt.f_type)
dir = os.path.join('experiments', opt.name)
writer = SummaryWriter(dir)

contains zero? tensor(False, device='cuda:0')
blur.U_small.shape: torch.Size([32, 32])
fs:  tensor([-0.0001,  0.0000,  0.0001,  0.0003,  0.0004,  0.0005,  0.0006,  0.0008,
         0.0009,  0.0010,  0.0011,  0.0013,  0.0014,  0.0015,  0.0016,  0.0018,
         0.0019,  0.0020,  0.0021,  0.0023,  0.0024,  0.0025,  0.0026,  0.0028,
         0.0029,  0.0030,  0.0031,  0.0033,  0.0034,  0.0035,  0.0036,  0.0038,
         0.0039,  0.0040,  0.0041,  0.0043,  0.0044,  0.0045,  0.0046,  0.0048,
         0.0049,  0.0050,  0.0051,  0.0053,  0.0054,  0.0055,  0.0056,  0.0058,
         0.0059,  0.0060,  0.0061,  0.0063,  0.0064,  0.0065,  0.0066,  0.0068,
         0.0069,  0.0070,  0.0071,  0.0073,  0.0074,  0.0075,  0.0076,  0.0078,
         0.0079,  0.0080,  0.0081,  0.0083,  0.0084,  0.0085,  0.0086,  0.0088,
         0.0089,  0.0090,  0.0091,  0.0093,  0.0094,  0.0095,  0.0096,  0.0098,
         0.0099,  0.0100,  0.0101,  0.0103,  0.0104,  0.0105,  0.0106,  0.0108,
         0.0109,  0.0110,  0

In [48]:
model = UNetModel(resolution, input_nc, 128, input_nc, blur = blur, dropout=opt.dropout, freq_feat = opt.freq_feat)
if not opt.ckpt == '' and os.path.exists(opt.ckpt):
    model.load_state_dict(torch.load(opt.ckpt))
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = DataParallel(model,device_ids=[0,1,2])

model.to(device)
print("input_nc", input_nc, "resolution", resolution)

# data_loader = torch.utils.data.DataLoader(dataset=dataset_train,
#                                           batch_size=bsize,
#                                           shuffle=True,
#                                           drop_last=True)
# data_loader_test = torch.utils.data.DataLoader(dataset=dataset_test,
#                                                batch_size=bsize,
#                                                shuffle=False,
#                                                drop_last=True)
optimizer = optim.Adam(model.parameters(), lr=opt.lr)
if opt.use_ema:
    optimizer = EMA(optimizer, ema_decay=opt.ema_decay)

# forward process visualization
sample = train_dataset[1][0].unsqueeze(0)

x_0 = sample[:4]
x_0 = x_0.to(device)
i = np.array([500] * x_0.shape[0])
i = torch.from_numpy(i).to(device)
fb.sanity(x_0, i)

sample_list = []
for i in range(0, N+1, N//10):
    if i == 0:
        sample_list.append(x_0[0])
        continue
    i = np.array([i] * x_0.shape[0])
    i = torch.from_numpy(i).to(device)
    x_i = fb.get_x_i(x_0, i)
    sample_list.append(x_i[0])
    print(f"x_{i.item()}.std() = {x_i.std()}")
    print(f"x_{i.item()}.mean() = {x_i.mean()}")


grid_sample = torch.cat(sample_list, dim=2)
utils.tensor_imsave(grid_sample, "./" + dir, "forward_process.jpg")
with open(os.path.join(dir, "config.json"), "w") as json_file:
    json.dump(vars(opt), json_file)
import time
meta_iter = 0

input_blocks torch.Size([128, 3, 3, 3])
Let's use 3 GPUs!
input_nc 3 resolution 32
MAE = 3.856495368381729e-06
x_50.std() = 0.3895820379257202
x_50.mean() = 0.4998919367790222
x_100.std() = 0.6231688261032104
x_100.mean() = 0.4730055332183838
x_150.std() = 0.7919426560401917
x_150.mean() = 0.3916463553905487
x_200.std() = 0.9058626294136047
x_200.mean() = 0.32252490520477295
x_250.std() = 0.9738093018531799
x_250.mean() = 0.2834331691265106
x_300.std() = 1.0011403560638428
x_300.mean() = 0.1843012273311615
x_350.std() = 0.9988614320755005
x_350.mean() = 0.1509888619184494
x_400.std() = 0.9909386038780212
x_400.mean() = 0.12187229096889496
x_450.std() = 0.9773322343826294
x_450.mean() = 0.05976749584078789
x_500.std() = 1.0135698318481445
x_500.mean() = 0.045845091342926025


In [49]:
for step in range(opt.max_iter):
    if not opt.inference:
        elips = time.time()
        try:
            x_0, _ = train_iter.next()
        except:
            train_iter = iter(train_loader)
            image, _ = next(train_iter)
        """
        training
        """
        assert x_0.shape[-1] == resolution, f"{x_0.shape}"
        i = np.random.uniform(1 / N, 1, size = (x_0.shape[0])) * N
        i = torch.from_numpy(i).to(device).type(torch.long)
#         i = i.to(device).type(torch.long)

        x_0 = x_0.to(device)
        x_i, eps = fb.get_x_i(x_0, i, return_eps = True)
        
        if opt.loss_type == "sm_simple":
            loss = fb.get_loss_i_simple(model, x_0, x_i, i)
        elif opt.loss_type == "eps_simple":
            loss = fb.get_loss_i_eps_simple(model, x_i, i, eps)
        elif opt.loss_type == "sm_exact":
            loss = fb.get_loss_i_exact(model, x_0, x_i, i)
        elif opt.loss_type == "std_matching":
            loss = fb.get_loss_i_std_matching(model, x_i, i, eps)
        ## Contrastive Loss
        s = model(x_i, i)
        s = fb.U_I_minus_B_Ut(s, i)
        hf = x_i - fb.W(x_i, i)
        # Anderson theorem
        pred1 = x_i + hf # unsharpening mask filtering
        pred2 = pred1 + s 
        sim_pos = 1/np.linalg.norm(pred2.cpu().detach().numpy()-x_0.cpu().detach().numpy())
        sim_neg = 1/np.linalg.norm(pred2.cpu().detach().numpy()-x_i.cpu().detach().numpy())
        contrastive_loss = (sim_pos - sim_neg)**2
        writer.add_scalar('loss_train', loss, step)

        writer.add_scalar('contrastive_loss', contrastive_loss, step)
        loss += 0.5*contrastive_loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            print(step, loss,loss-contrastive_loss,contrastive_loss)
        # print(f"time: {time.time() - elips}")
    # Calcuate FID
    if step > 2401:
        fid_iter = opt.fid_iter
    else:
        fid_iter = 2400
    if (step % fid_iter == 0 and step > 0):
        id = 0
        if not os.path.exists(os.path.join("./",dir, f"{step}")):
            os.mkdir(os.path.join("./",dir, f"{step}"))
        with torch.no_grad():
            if opt.use_ema:
                optimizer.swap_parameters_with_ema(store_params_in_ema=True)
            model = model.eval()
            for _ in range(opt.fid_num_samples // opt.fid_bsize):
                i = np.array([opt.N - 1] * opt.fid_bsize)
                i = torch.from_numpy(i).to(device)
                pred = fb.get_x_N([opt.fid_bsize, input_nc, resolution, resolution], i)
                for i in reversed(range(1, opt.N)):
                    i = np.array([i] * opt.fid_bsize)
                    i = torch.from_numpy(i).to(device)
                    if opt.loss_type == "sm_simple":
                        s = model(pred, i)
                    elif opt.loss_type == "eps_simple":
                        eps = model(pred, i)
                        s = fb.get_score_from_eps(eps, i)
                    elif opt.loss_type == "sm_exact":
                        s = model(pred, i)
                    elif opt.loss_type == "std_matching":
                        std = model(pred, i)
                        s = fb.get_score_from_std(std, i)
                    else:
                        raise NotImplementedError
                    s = fb.U_I_minus_B_Ut(s, i)
                    rms = lambda x: torch.sqrt(torch.mean(x ** 2))
                    # print(f"rms(s) * fb._beta_i(i) = {rms(s) * fb._beta_i(i)[0]}")
                    hf = pred - fb.W(pred, i)
                    # Anderson theorem
                    pred1 = pred + hf # unsharpening mask filtering
                    pred2 = pred1 + s  # # denoising
                    if i[0] > 2:
                        pred = pred2 + fb.U_I_minus_B_sqrt_Ut(torch.randn_like(pred), i) # inject noise
                    else:
                        pred = pred2
                    # print(f"i = {i[0]}, rmse = {torch.sqrt(torch.mean(pred**2))}, mean = {torch.mean(pred)} std = {torch.std(pred)}" )
                for sample in pred:
                    save_image(sample, os.path.join(dir, f"{step}", f"{id:05d}.png"))
                    id += 1
        if opt.use_ema:
            optimizer.swap_parameters_with_ema(store_params_in_ema=True)
            model = model.train()
        fid = fid_eval(os.path.join(dir, f"{step}"))
        writer.add_scalar('fid', fid, step)
        print(f"step {step}, fid = {fid}")
    if (step % opt.eval_iter == 0 and step > 0) or opt.inference:
        """
        sampling (eval)
        """
        cnt = 0
        loss = 0
        

        with torch.no_grad():
            if opt.use_ema:
                optimizer.swap_parameters_with_ema(store_params_in_ema=True)
            model = model.eval()

            if opt.ode:
                raise NotImplementedError
                def to_flattened_numpy(x):
                    """Flatten a torch tensor `x` and convert it to numpy."""
                    return x.detach().cpu().numpy().reshape((-1,))
                def from_flattened_numpy(x, shape):
                    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
                    return torch.from_numpy(x.reshape(shape))
                def ode_func(i, y):
                    i = int(i*N)
                    print(f"i = {i}")
                    y = from_flattened_numpy(y, [bsize, input_nc, resolution, resolution]).to(device).type(torch.float32)
                    i = np.array([N - 1] * bsize)
                    i = torch.from_numpy(i).to(device)
                    if opt.loss_type == "sm_simple":
                            s = model(y, i)
                    elif opt.loss_type == "eps_simple":
                        eps = model(y, i)
                        s = fb.get_score_from_eps(eps, i)
                    elif opt.loss_type == "sm_exact":
                        s = model(y, i)
                    elif opt.loss_type == "std_matching":
                        std = model(y, i)
                        s = fb.get_score_from_std(std, i)
                    else:
                        raise NotImplementedError
                    s = fb.U_I_minus_B_Ut(s, i)
                    hf = y - fb.W(y, i)
                    dt = - 1.0 / N
                    drift = (s/2 + hf) / dt
                    drift = to_flattened_numpy(drift)
                    return drift
                x_N = fb.get_x_N([bsize, input_nc, resolution, resolution], N)
                solution = solve_ivp(ode_func, (1, 1e-3), to_flattened_numpy(x_N),
                                     rtol=1e-3, atol=1e-3, method="RK45")
                nfe = solution.nfev
                solution = torch.tensor(solution.y[:, -1]).reshape(x_N.shape).to(device).type(torch.float32)
                
                save_image(solution, "./solution.jpg")
                print(f"nfe = {nfe}")
                raise NotImplementedError
            for x_0, _ in test_loader:
                x_0 = x_0.to(device)
                # for v in range(0, 250, 20):
                #     x_0[:, :, v, :] = 0
                if opt.fromprior:
                    i = np.array([N - 1] * x_0.shape[0])
                    i = torch.from_numpy(i).to(device)
                    pred = fb.get_x_N(x_0.shape, i)
                    print(f"pred.std() = {pred.std()}")
                else:
                    i = np.array([N-1] * x_0.shape[0])
                    i = torch.from_numpy(i).to(device)
                    pred = fb.get_x_i(x_0, i)
                preds = [pred]

                for i in reversed(range(1, N)):
                    i = np.array([i] * x_0.shape[0])
                    i = torch.from_numpy(i).to(device)
                    if opt.gtscore:
                        s = fb.get_score_gt(pred, x_0, i)
                    else:
                        if opt.loss_type == "sm_simple":
                            s = model(pred, i)
                        elif opt.loss_type == "eps_simple":
                            eps = model(pred, i)
                            s = fb.get_score_from_eps(eps, i)
                        elif opt.loss_type == "sm_exact":
                            s = model(pred, i)
                        elif opt.loss_type == "std_matching":
                            std = model(pred, i)
                            s = fb.get_score_from_std(std, i)
                        else:
                            raise NotImplementedError
                    s = fb.U_I_minus_B_Ut(s, i)
                    rms = lambda x: torch.sqrt(torch.mean(x ** 2))
                    # print(f"rms(s) * fb._beta_i(i) = {rms(s) * fb._beta_i(i)[0]}")
                    hf = pred - fb.W(pred, i)
                    # Anderson theorem
                    pred1 = pred + hf # unsharpening mask filtering
                    pred2 = pred1 + s  # # denoising
                    if i[0] > 2:
                        pred = pred2 + fb.U_I_minus_B_sqrt_Ut(torch.randn_like(pred), i) # inject noise
                    else:
                        pred = pred2
                    # print(f"i = {i[0]}, rmse = {torch.sqrt(torch.mean(pred**2))}, mean = {torch.mean(pred)} std = {torch.std(pred)}")
                    # assert rms(pred) < 100
                    if (i[0]) % (N // 10) == 0:
                        img = pred[0]
                        preds.append(pred)

                preds.append(pred)
                assert x_0.shape == pred.shape
                # visualize
                grid = torch.cat(preds, dim=3) # grid_sample.shape: (bsize, channel, H, W * 12)
                # (batch_size, channel, H, W * 12) -> (channel, H * bsize, W * 12)
                grid = grid.permute(1, 0, 2, 3).contiguous().view(grid.shape[1], -1, grid.shape[3])
                # (bsize, channel, H, W) -> (channel, H, W * bsize)
                gt = x_0.permute(1, 2, 0, 3).contiguous().view(x_0.shape[1], -1, x_0.shape[3] * x_0.shape[0])
                if cnt <= 2:
                    utils.tensor_imsave(gt, "./" + dir, f"{step}_{cnt}_GT.jpg")
                    utils.tensor_imsave(grid, "./" + dir, f"{step}_{cnt}_pred.jpg")
              
                cnt += 1
                loss += TF.l1_loss(x_0, pred) / 2

                if cnt == 2:
                    break
        print(f"step: {step} loss: {loss}")
        writer.add_scalar('loss_val', loss, meta_iter)
        f = open('./' + str(dir) + '/log.txt', 'a')

        f.write(f"Step: {step} loss: {loss}" + '\n')

        f.close()
        model.train()
        if opt.use_ema:
            optimizer.swap_parameters_with_ema(store_params_in_ema=True)
    if step % opt.save_every == 1:
        if opt.use_ema:
            optimizer.swap_parameters_with_ema(store_params_in_ema=True)
        if torch.cuda.device_count() > 1:
            torch.save(model.module.state_dict(), os.path.join(dir, f"model_{step}.ckpt"))
        else:
            torch.save(model.state_dict(), os.path.join(dir, f"model_{step}.ckpt"))
        if opt.use_ema:
            optimizer.swap_parameters_with_ema(store_params_in_ema=True)

0 tensor(1.3302, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.6000, device='cuda:0', grad_fn=<SubBackward0>) 0.7302102880077691
100 tensor(0.5038, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0093, device='cuda:0', grad_fn=<SubBackward0>) 0.4944978975924776
200 tensor(0.4481, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.3375, device='cuda:0', grad_fn=<SubBackward0>) 0.7856205155013396
300 tensor(0.1486, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.1006, device='cuda:0', grad_fn=<SubBackward0>) 0.2492219620010407
400 tensor(0.4680, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.4041, device='cuda:0', grad_fn=<SubBackward0>) 0.8721190609295526
500 tensor(0.1884, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.1516, device='cuda:0', grad_fn=<SubBackward0>) 0.33998056227822926
600 tensor(0.0478, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0337, device='cuda:0', grad_fn=<SubBackward0>) 0.08156196135102849
700 tensor(0.0566, device='cuda:0', grad_fn=<AddBa

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 49.69it/s]


step 2400, fid = 273.0727353966689
2500 tensor(0.1169, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0660, device='cuda:0', grad_fn=<SubBackward0>) 0.05088943350815026
2600 tensor(0.0319, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0057, device='cuda:0', grad_fn=<SubBackward0>) 0.026263388750915412
2700 tensor(0.0282, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0149, device='cuda:0', grad_fn=<SubBackward0>) 0.04311267051034763
2800 tensor(0.0244, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0072, device='cuda:0', grad_fn=<SubBackward0>) 0.0316743702120024
2900 tensor(0.0128, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0071, device='cuda:0', grad_fn=<SubBackward0>) 0.01985418374952834
3000 tensor(0.0529, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0241, device='cuda:0', grad_fn=<SubBackward0>) 0.02884554607317023
pred.std() = 1.0023400783538818
pred.std() = 0.996441125869751
step: 3000 loss: 0.3501817584037781
3100 tensor(0.0325, device='cuda:0', grad

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 49.35it/s]


step 4000, fid = 264.3244701823735
pred.std() = 0.9967623949050903
pred.std() = 0.9949722290039062
step: 4000 loss: 0.336889386177063
4100 tensor(0.0342, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0012, device='cuda:0', grad_fn=<SubBackward0>) 0.033057343636145875
4200 tensor(0.0436, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0242, device='cuda:0', grad_fn=<SubBackward0>) 0.019475105539200954
4300 tensor(0.0734, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0369, device='cuda:0', grad_fn=<SubBackward0>) 0.03649837748462682
4400 tensor(0.0365, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0070, device='cuda:0', grad_fn=<SubBackward0>) 0.04351182119715914
4500 tensor(0.0171, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0088, device='cuda:0', grad_fn=<SubBackward0>) 0.025852617796296246
4600 tensor(0.0318, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0096, device='cuda:0', grad_fn=<SubBackward0>) 0.04138920393091185
4700 tensor(0.0414, device='cuda:0', g

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 50.48it/s]


step 6000, fid = 229.99878325874857
pred.std() = 1.0082104206085205
pred.std() = 1.0039129257202148
step: 6000 loss: 0.30954107642173767
6100 tensor(0.0592, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0052, device='cuda:0', grad_fn=<SubBackward0>) 0.06442813928751577
6200 tensor(0.0262, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0126, device='cuda:0', grad_fn=<SubBackward0>) 0.038785531684942774
6300 tensor(0.0247, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0101, device='cuda:0', grad_fn=<SubBackward0>) 0.03482856174872956
6400 tensor(0.0208, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0134, device='cuda:0', grad_fn=<SubBackward0>) 0.03420890456691399
6500 tensor(0.0481, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0103, device='cuda:0', grad_fn=<SubBackward0>) 0.03777685412345786
6600 tensor(0.0219, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0124, device='cuda:0', grad_fn=<SubBackward0>) 0.03429055834933888
6700 tensor(0.0189, device='cuda:0'

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 51.65it/s]


step 8000, fid = 224.55480968818586
pred.std() = 0.9976666569709778
pred.std() = 0.9968157410621643
step: 8000 loss: 0.3710719048976898
8100 tensor(0.0132, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0028, device='cuda:0', grad_fn=<SubBackward0>) 0.01600756486471465
8200 tensor(0.0127, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0063, device='cuda:0', grad_fn=<SubBackward0>) 0.019000161943111857
8300 tensor(0.0147, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0037, device='cuda:0', grad_fn=<SubBackward0>) 0.018375803907713796
8400 tensor(0.0241, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0032, device='cuda:0', grad_fn=<SubBackward0>) 0.027298409468686672
8500 tensor(0.0181, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0104, device='cuda:0', grad_fn=<SubBackward0>) 0.028488340998373943
8600 tensor(0.0790, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0573, device='cuda:0', grad_fn=<SubBackward0>) 0.021629861197617655
8700 tensor(0.0269, device='cuda

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 49.16it/s]


step 10000, fid = 233.82995905866238
pred.std() = 0.9999263286590576
pred.std() = 1.0010802745819092
step: 10000 loss: 0.2695009708404541
10100 tensor(0.0448, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0057, device='cuda:0', grad_fn=<SubBackward0>) 0.05055498313183825
10200 tensor(0.0289, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0019, device='cuda:0', grad_fn=<SubBackward0>) 0.03080775985831232
10300 tensor(0.0131, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0044, device='cuda:0', grad_fn=<SubBackward0>) 0.017448374603820035
10400 tensor(0.0185, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0077, device='cuda:0', grad_fn=<SubBackward0>) 0.026215929714746105
10500 tensor(0.0162, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0066, device='cuda:0', grad_fn=<SubBackward0>) 0.022792802852890835
10600 tensor(0.0179, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0065, device='cuda:0', grad_fn=<SubBackward0>) 0.024410212601361492
10700 tensor(0.0255, dev

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 48.71it/s]


step 12000, fid = 247.27835399465997
pred.std() = 1.0032517910003662
pred.std() = 1.0130748748779297
step: 12000 loss: 0.26516446471214294
12100 tensor(0.0266, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0051, device='cuda:0', grad_fn=<SubBackward0>) 0.021491657314463695
12200 tensor(0.0281, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0010, device='cuda:0', grad_fn=<SubBackward0>) 0.027027214011467873
12300 tensor(0.0151, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0069, device='cuda:0', grad_fn=<SubBackward0>) 0.022082012649977927
12400 tensor(0.0216, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0124, device='cuda:0', grad_fn=<SubBackward0>) 0.034065395749385996
12500 tensor(0.0758, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0304, device='cuda:0', grad_fn=<SubBackward0>) 0.10619041101278824
12600 tensor(0.0374, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0166, device='cuda:0', grad_fn=<SubBackward0>) 0.05399135966719743
12700 tensor(0.0142, devi

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 50.48it/s]


step 14000, fid = 215.9210782593476
pred.std() = 1.0019195079803467
pred.std() = 0.9920446276664734
step: 14000 loss: 0.3014170527458191
14100 tensor(0.0140, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0077, device='cuda:0', grad_fn=<SubBackward0>) 0.021671475170353647
14200 tensor(0.0177, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0036, device='cuda:0', grad_fn=<SubBackward0>) 0.021213090419841537
14300 tensor(0.0191, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0089, device='cuda:0', grad_fn=<SubBackward0>) 0.02805821828983966
14400 tensor(0.0521, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0105, device='cuda:0', grad_fn=<SubBackward0>) 0.04158707966083814
14500 tensor(0.0144, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0063, device='cuda:0', grad_fn=<SubBackward0>) 0.020659562716549874
14600 tensor(0.0153, device='cuda:0', grad_fn=<AddBackward0>) tensor(-0.0086, device='cuda:0', grad_fn=<SubBackward0>) 0.023925027799188554
14700 tensor(0.0302, devic

In [50]:
for step in range(opt.max_iter):
    if not opt.inference:
        elips = time.time()
        try:
            x_0, _ = train_iter.next()
        except:
            train_iter = iter(train_loader)
            image, _ = next(train_iter)
        """
        training
        """
        assert x_0.shape[-1] == resolution, f"{x_0.shape}"
        i = np.random.uniform(1 / N, 1, size = (x_0.shape[0])) * N
        i = torch.from_numpy(i).to(device).type(torch.long)
#         i = i.to(device).type(torch.long)

        x_0 = x_0.to(device)
        x_i, eps = fb.get_x_i(x_0, i, return_eps = True)
        
        if opt.loss_type == "sm_simple":
            loss = fb.get_loss_i_simple(model, x_0, x_i, i)
        elif opt.loss_type == "eps_simple":
            loss = fb.get_loss_i_eps_simple(model, x_i, i, eps)
        elif opt.loss_type == "sm_exact":
            loss = fb.get_loss_i_exact(model, x_0, x_i, i)
        elif opt.loss_type == "std_matching":
            loss = fb.get_loss_i_std_matching(model, x_i, i, eps)
        ## Contrastive Loss
#         s = model(x_i, i)
#         s = fb.U_I_minus_B_Ut(s, i)
#         hf = x_i - fb.W(x_i, i)
#         # Anderson theorem
#         pred1 = x_i + hf # unsharpening mask filtering
#         pred2 = pred1 + s 
#         sim_pos = 1/np.linalg.norm(pred2.cpu().detach().numpy()-x_0.cpu().detach().numpy())
#         sim_neg = 1/np.linalg.norm(pred2.cpu().detach().numpy()-x_i.cpu().detach().numpy())
#         contrastive_loss = (sim_pos - sim_neg)**2
        writer.add_scalar('loss_train', loss, step)

#         writer.add_scalar('contrastive_loss', contrastive_loss, step)
#         loss += 0.5*contrastive_loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step % 100 == 0:
#             print(step, loss,loss-contrastive_loss,contrastive_loss)
            print(step,loss)
        # print(f"time: {time.time() - elips}")
    # Calcuate FID
    if step > 2401:
        fid_iter = opt.fid_iter
    else:
        fid_iter = 2400
    if (step % fid_iter == 0 and step > 0):
        id = 0
        if not os.path.exists(os.path.join("./",dir, f"{step}")):
            os.mkdir(os.path.join("./",dir, f"{step}"))
        with torch.no_grad():
            if opt.use_ema:
                optimizer.swap_parameters_with_ema(store_params_in_ema=True)
            model = model.eval()
            for _ in range(opt.fid_num_samples // opt.fid_bsize):
                i = np.array([opt.N - 1] * opt.fid_bsize)
                i = torch.from_numpy(i).to(device)
                pred = fb.get_x_N([opt.fid_bsize, input_nc, resolution, resolution], i)
                for i in reversed(range(1, opt.N)):
                    i = np.array([i] * opt.fid_bsize)
                    i = torch.from_numpy(i).to(device)
                    if opt.loss_type == "sm_simple":
                        s = model(pred, i)
                    elif opt.loss_type == "eps_simple":
                        eps = model(pred, i)
                        s = fb.get_score_from_eps(eps, i)
                    elif opt.loss_type == "sm_exact":
                        s = model(pred, i)
                    elif opt.loss_type == "std_matching":
                        std = model(pred, i)
                        s = fb.get_score_from_std(std, i)
                    else:
                        raise NotImplementedError
                    s = fb.U_I_minus_B_Ut(s, i)
                    rms = lambda x: torch.sqrt(torch.mean(x ** 2))
                    # print(f"rms(s) * fb._beta_i(i) = {rms(s) * fb._beta_i(i)[0]}")
                    hf = pred - fb.W(pred, i)
                    # Anderson theorem
                    pred1 = pred + hf # unsharpening mask filtering
                    pred2 = pred1 + s  # # denoising
                    if i[0] > 2:
                        pred = pred2 + fb.U_I_minus_B_sqrt_Ut(torch.randn_like(pred), i) # inject noise
                    else:
                        pred = pred2
                    # print(f"i = {i[0]}, rmse = {torch.sqrt(torch.mean(pred**2))}, mean = {torch.mean(pred)} std = {torch.std(pred)}" )
                for sample in pred:
                    save_image(sample, os.path.join(dir, f"{step}", f"{id:05d}.png"))
                    id += 1
        if opt.use_ema:
            optimizer.swap_parameters_with_ema(store_params_in_ema=True)
            model = model.train()
        fid = fid_eval(os.path.join(dir, f"{step}"))
        writer.add_scalar('fid', fid, step)
        print(f"step {step}, fid = {fid}")
    if (step % opt.eval_iter == 0 and step > 0) or opt.inference:
        """
        sampling (eval)
        """
        cnt = 0
        loss = 0
        

        with torch.no_grad():
            if opt.use_ema:
                optimizer.swap_parameters_with_ema(store_params_in_ema=True)
            model = model.eval()

            if opt.ode:
                raise NotImplementedError
                def to_flattened_numpy(x):
                    """Flatten a torch tensor `x` and convert it to numpy."""
                    return x.detach().cpu().numpy().reshape((-1,))
                def from_flattened_numpy(x, shape):
                    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
                    return torch.from_numpy(x.reshape(shape))
                def ode_func(i, y):
                    i = int(i*N)
                    print(f"i = {i}")
                    y = from_flattened_numpy(y, [bsize, input_nc, resolution, resolution]).to(device).type(torch.float32)
                    i = np.array([N - 1] * bsize)
                    i = torch.from_numpy(i).to(device)
                    if opt.loss_type == "sm_simple":
                            s = model(y, i)
                    elif opt.loss_type == "eps_simple":
                        eps = model(y, i)
                        s = fb.get_score_from_eps(eps, i)
                    elif opt.loss_type == "sm_exact":
                        s = model(y, i)
                    elif opt.loss_type == "std_matching":
                        std = model(y, i)
                        s = fb.get_score_from_std(std, i)
                    else:
                        raise NotImplementedError
                    s = fb.U_I_minus_B_Ut(s, i)
                    hf = y - fb.W(y, i)
                    dt = - 1.0 / N
                    drift = (s/2 + hf) / dt
                    drift = to_flattened_numpy(drift)
                    return drift
                x_N = fb.get_x_N([bsize, input_nc, resolution, resolution], N)
                solution = solve_ivp(ode_func, (1, 1e-3), to_flattened_numpy(x_N),
                                     rtol=1e-3, atol=1e-3, method="RK45")
                nfe = solution.nfev
                solution = torch.tensor(solution.y[:, -1]).reshape(x_N.shape).to(device).type(torch.float32)
                
                save_image(solution, "./solution.jpg")
                print(f"nfe = {nfe}")
                raise NotImplementedError
            for x_0, _ in test_loader:
                x_0 = x_0.to(device)
                # for v in range(0, 250, 20):
                #     x_0[:, :, v, :] = 0
                if opt.fromprior:
                    i = np.array([N - 1] * x_0.shape[0])
                    i = torch.from_numpy(i).to(device)
                    pred = fb.get_x_N(x_0.shape, i)
                    print(f"pred.std() = {pred.std()}")
                else:
                    i = np.array([N-1] * x_0.shape[0])
                    i = torch.from_numpy(i).to(device)
                    pred = fb.get_x_i(x_0, i)
                preds = [pred]

                for i in reversed(range(1, N)):
                    i = np.array([i] * x_0.shape[0])
                    i = torch.from_numpy(i).to(device)
                    if opt.gtscore:
                        s = fb.get_score_gt(pred, x_0, i)
                    else:
                        if opt.loss_type == "sm_simple":
                            s = model(pred, i)
                        elif opt.loss_type == "eps_simple":
                            eps = model(pred, i)
                            s = fb.get_score_from_eps(eps, i)
                        elif opt.loss_type == "sm_exact":
                            s = model(pred, i)
                        elif opt.loss_type == "std_matching":
                            std = model(pred, i)
                            s = fb.get_score_from_std(std, i)
                        else:
                            raise NotImplementedError
                    s = fb.U_I_minus_B_Ut(s, i)
                    rms = lambda x: torch.sqrt(torch.mean(x ** 2))
                    # print(f"rms(s) * fb._beta_i(i) = {rms(s) * fb._beta_i(i)[0]}")
                    hf = pred - fb.W(pred, i)
                    # Anderson theorem
                    pred1 = pred + hf # unsharpening mask filtering
                    pred2 = pred1 + s  # # denoising
                    if i[0] > 2:
                        pred = pred2 + fb.U_I_minus_B_sqrt_Ut(torch.randn_like(pred), i) # inject noise
                    else:
                        pred = pred2
                    # print(f"i = {i[0]}, rmse = {torch.sqrt(torch.mean(pred**2))}, mean = {torch.mean(pred)} std = {torch.std(pred)}")
                    # assert rms(pred) < 100
                    if (i[0]) % (N // 10) == 0:
                        img = pred[0]
                        preds.append(pred)

                preds.append(pred)
                assert x_0.shape == pred.shape
                # visualize
                grid = torch.cat(preds, dim=3) # grid_sample.shape: (bsize, channel, H, W * 12)
                # (batch_size, channel, H, W * 12) -> (channel, H * bsize, W * 12)
                grid = grid.permute(1, 0, 2, 3).contiguous().view(grid.shape[1], -1, grid.shape[3])
                # (bsize, channel, H, W) -> (channel, H, W * bsize)
                gt = x_0.permute(1, 2, 0, 3).contiguous().view(x_0.shape[1], -1, x_0.shape[3] * x_0.shape[0])
                if cnt <= 2:
                    utils.tensor_imsave(gt, "./" + dir, f"{step}_{cnt}_GT.jpg")
                    utils.tensor_imsave(grid, "./" + dir, f"{step}_{cnt}_pred.jpg")
              
                cnt += 1
                loss += TF.l1_loss(x_0, pred) / 2

                if cnt == 2:
                    break
        print(f"step: {step} loss: {loss}")
        writer.add_scalar('loss_val', loss, meta_iter)
        f = open('./' + str(dir) + '/log.txt', 'a')

        f.write(f"Step: {step} loss: {loss}" + '\n')

        f.close()
        model.train()
        if opt.use_ema:
            optimizer.swap_parameters_with_ema(store_params_in_ema=True)
    if step % opt.save_every == 1:
        if opt.use_ema:
            optimizer.swap_parameters_with_ema(store_params_in_ema=True)
        if torch.cuda.device_count() > 1:
            torch.save(model.module.state_dict(), os.path.join(dir, f"model_{step}.ckpt"))
        else:
            torch.save(model.state_dict(), os.path.join(dir, f"model_{step}.ckpt"))
        if opt.use_ema:
            optimizer.swap_parameters_with_ema(store_params_in_ema=True)

0 tensor(0.0074, device='cuda:0', grad_fn=<MeanBackward0>)
100 tensor(0.0020, device='cuda:0', grad_fn=<MeanBackward0>)
200 tensor(0.0173, device='cuda:0', grad_fn=<MeanBackward0>)
300 tensor(0.0083, device='cuda:0', grad_fn=<MeanBackward0>)
400 tensor(0.0038, device='cuda:0', grad_fn=<MeanBackward0>)
500 tensor(0.0025, device='cuda:0', grad_fn=<MeanBackward0>)
600 tensor(0.0029, device='cuda:0', grad_fn=<MeanBackward0>)
700 tensor(0.0053, device='cuda:0', grad_fn=<MeanBackward0>)
800 tensor(0.0018, device='cuda:0', grad_fn=<MeanBackward0>)
900 tensor(0.0045, device='cuda:0', grad_fn=<MeanBackward0>)
1000 tensor(0.0059, device='cuda:0', grad_fn=<MeanBackward0>)
pred.std() = 1.006337285041809
pred.std() = 0.9932403564453125
step: 1000 loss: 0.26761966943740845
1100 tensor(0.0019, device='cuda:0', grad_fn=<MeanBackward0>)
1200 tensor(0.0025, device='cuda:0', grad_fn=<MeanBackward0>)
1300 tensor(0.0075, device='cuda:0', grad_fn=<MeanBackward0>)
1400 tensor(0.0274, device='cuda:0', grad_fn

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 50.22it/s]


step 2400, fid = 263.15244523137915
2500 tensor(0.0021, device='cuda:0', grad_fn=<MeanBackward0>)
2600 tensor(0.0217, device='cuda:0', grad_fn=<MeanBackward0>)
2700 tensor(0.0059, device='cuda:0', grad_fn=<MeanBackward0>)
2800 tensor(0.0025, device='cuda:0', grad_fn=<MeanBackward0>)
2900 tensor(0.0233, device='cuda:0', grad_fn=<MeanBackward0>)
3000 tensor(0.0075, device='cuda:0', grad_fn=<MeanBackward0>)
pred.std() = 0.9993970394134521
pred.std() = 1.0008867979049683
step: 3000 loss: 0.2980382442474365
3100 tensor(0.0126, device='cuda:0', grad_fn=<MeanBackward0>)
3200 tensor(0.0075, device='cuda:0', grad_fn=<MeanBackward0>)
3300 tensor(0.0105, device='cuda:0', grad_fn=<MeanBackward0>)
3400 tensor(0.0046, device='cuda:0', grad_fn=<MeanBackward0>)
3500 tensor(0.0043, device='cuda:0', grad_fn=<MeanBackward0>)
3600 tensor(0.0041, device='cuda:0', grad_fn=<MeanBackward0>)
3700 tensor(0.0056, device='cuda:0', grad_fn=<MeanBackward0>)
3800 tensor(0.0030, device='cuda:0', grad_fn=<MeanBackward

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 49.90it/s]


step 4000, fid = 240.19906476128102
pred.std() = 1.001698613166809
pred.std() = 1.0025451183319092
step: 4000 loss: 0.2672337591648102
4100 tensor(0.0011, device='cuda:0', grad_fn=<MeanBackward0>)
4200 tensor(0.0049, device='cuda:0', grad_fn=<MeanBackward0>)
4300 tensor(0.0062, device='cuda:0', grad_fn=<MeanBackward0>)
4400 tensor(0.0064, device='cuda:0', grad_fn=<MeanBackward0>)
4500 tensor(0.0022, device='cuda:0', grad_fn=<MeanBackward0>)
4600 tensor(0.0042, device='cuda:0', grad_fn=<MeanBackward0>)
4700 tensor(0.0070, device='cuda:0', grad_fn=<MeanBackward0>)
4800 tensor(0.0192, device='cuda:0', grad_fn=<MeanBackward0>)
4900 tensor(0.0084, device='cuda:0', grad_fn=<MeanBackward0>)
5000 tensor(0.0098, device='cuda:0', grad_fn=<MeanBackward0>)
pred.std() = 0.9998305439949036
pred.std() = 1.0077474117279053
step: 5000 loss: 0.24289584159851074
5100 tensor(0.0142, device='cuda:0', grad_fn=<MeanBackward0>)
5200 tensor(0.0068, device='cuda:0', grad_fn=<MeanBackward0>)
5300 tensor(0.0021, 

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 49.54it/s]


step 6000, fid = 251.79785457737836
pred.std() = 1.0009417533874512
pred.std() = 0.9997055530548096
step: 6000 loss: 0.23946842551231384
6100 tensor(0.0061, device='cuda:0', grad_fn=<MeanBackward0>)
6200 tensor(0.0239, device='cuda:0', grad_fn=<MeanBackward0>)
6300 tensor(0.0027, device='cuda:0', grad_fn=<MeanBackward0>)
6400 tensor(0.0246, device='cuda:0', grad_fn=<MeanBackward0>)
6500 tensor(0.0040, device='cuda:0', grad_fn=<MeanBackward0>)
6600 tensor(0.0033, device='cuda:0', grad_fn=<MeanBackward0>)
6700 tensor(0.0109, device='cuda:0', grad_fn=<MeanBackward0>)
6800 tensor(0.0043, device='cuda:0', grad_fn=<MeanBackward0>)
6900 tensor(0.0040, device='cuda:0', grad_fn=<MeanBackward0>)
7000 tensor(0.0034, device='cuda:0', grad_fn=<MeanBackward0>)
pred.std() = 0.9974673390388489
pred.std() = 1.0030509233474731
step: 7000 loss: 0.2558271586894989
7100 tensor(0.0122, device='cuda:0', grad_fn=<MeanBackward0>)
7200 tensor(0.0112, device='cuda:0', grad_fn=<MeanBackward0>)
7300 tensor(0.0028,

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 52.26it/s]


step 8000, fid = 238.24954213123374
pred.std() = 1.0023409128189087
pred.std() = 1.0020124912261963
step: 8000 loss: 0.25425976514816284
8100 tensor(0.0063, device='cuda:0', grad_fn=<MeanBackward0>)
8200 tensor(0.0026, device='cuda:0', grad_fn=<MeanBackward0>)
8300 tensor(0.0035, device='cuda:0', grad_fn=<MeanBackward0>)
8400 tensor(0.0087, device='cuda:0', grad_fn=<MeanBackward0>)
8500 tensor(0.0053, device='cuda:0', grad_fn=<MeanBackward0>)
8600 tensor(0.0037, device='cuda:0', grad_fn=<MeanBackward0>)
8700 tensor(0.0008, device='cuda:0', grad_fn=<MeanBackward0>)
8800 tensor(0.0051, device='cuda:0', grad_fn=<MeanBackward0>)
8900 tensor(0.0044, device='cuda:0', grad_fn=<MeanBackward0>)
9000 tensor(0.0043, device='cuda:0', grad_fn=<MeanBackward0>)
pred.std() = 1.0003019571304321
pred.std() = 1.0041351318359375
step: 9000 loss: 0.2618172764778137
9100 tensor(0.0190, device='cuda:0', grad_fn=<MeanBackward0>)
9200 tensor(0.0056, device='cuda:0', grad_fn=<MeanBackward0>)
9300 tensor(0.0073,

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 50.03it/s]


step 10000, fid = 242.08618148383653
pred.std() = 0.9932202696800232
pred.std() = 1.0001039505004883
step: 10000 loss: 0.32482099533081055
10100 tensor(0.0086, device='cuda:0', grad_fn=<MeanBackward0>)
10200 tensor(0.0084, device='cuda:0', grad_fn=<MeanBackward0>)
10300 tensor(0.0311, device='cuda:0', grad_fn=<MeanBackward0>)
10400 tensor(0.0083, device='cuda:0', grad_fn=<MeanBackward0>)
10500 tensor(0.0040, device='cuda:0', grad_fn=<MeanBackward0>)
10600 tensor(0.0035, device='cuda:0', grad_fn=<MeanBackward0>)
10700 tensor(0.0017, device='cuda:0', grad_fn=<MeanBackward0>)
10800 tensor(0.0024, device='cuda:0', grad_fn=<MeanBackward0>)
10900 tensor(0.0027, device='cuda:0', grad_fn=<MeanBackward0>)
11000 tensor(0.0414, device='cuda:0', grad_fn=<MeanBackward0>)
pred.std() = 1.0065643787384033
pred.std() = 0.9918767809867859
step: 11000 loss: 0.3034140467643738
11100 tensor(0.0251, device='cuda:0', grad_fn=<MeanBackward0>)
11200 tensor(0.0029, device='cuda:0', grad_fn=<MeanBackward0>)
1130

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 52.68it/s]


step 12000, fid = 254.29508877499453
pred.std() = 0.9956225156784058
pred.std() = 1.0104796886444092
step: 12000 loss: 0.22754083573818207
12100 tensor(0.0300, device='cuda:0', grad_fn=<MeanBackward0>)
12200 tensor(0.0220, device='cuda:0', grad_fn=<MeanBackward0>)
12300 tensor(0.0049, device='cuda:0', grad_fn=<MeanBackward0>)
12400 tensor(0.0062, device='cuda:0', grad_fn=<MeanBackward0>)
12500 tensor(0.0041, device='cuda:0', grad_fn=<MeanBackward0>)
12600 tensor(0.0032, device='cuda:0', grad_fn=<MeanBackward0>)
12700 tensor(0.0007, device='cuda:0', grad_fn=<MeanBackward0>)
12800 tensor(0.0020, device='cuda:0', grad_fn=<MeanBackward0>)
12900 tensor(0.0030, device='cuda:0', grad_fn=<MeanBackward0>)
13000 tensor(0.0016, device='cuda:0', grad_fn=<MeanBackward0>)
pred.std() = 0.9944551587104797
pred.std() = 1.0035085678100586
step: 13000 loss: 0.26888251304626465
13100 tensor(0.0023, device='cuda:0', grad_fn=<MeanBackward0>)
13200 tensor(0.0275, device='cuda:0', grad_fn=<MeanBackward0>)
133

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [00:01<00:00, 50.24it/s]


step 14000, fid = 275.4149450715353
pred.std() = 0.9996069669723511
pred.std() = 0.9931371808052063
step: 14000 loss: 0.3260786831378937
14100 tensor(0.0213, device='cuda:0', grad_fn=<MeanBackward0>)
14200 tensor(0.0053, device='cuda:0', grad_fn=<MeanBackward0>)
14300 tensor(0.0026, device='cuda:0', grad_fn=<MeanBackward0>)
14400 tensor(0.0056, device='cuda:0', grad_fn=<MeanBackward0>)
14500 tensor(0.0045, device='cuda:0', grad_fn=<MeanBackward0>)
14600 tensor(0.0046, device='cuda:0', grad_fn=<MeanBackward0>)
14700 tensor(0.0040, device='cuda:0', grad_fn=<MeanBackward0>)
14800 tensor(0.0136, device='cuda:0', grad_fn=<MeanBackward0>)
14900 tensor(0.0416, device='cuda:0', grad_fn=<MeanBackward0>)
