In [1]:
!nvidia-smi
!which python | grep DYY

Mon Sep 16 18:56:27 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   45C    P0              41W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import torch

# Fixed seed so that weight initialisation and noise draws are repeatable across runs.
SEED = 3407
torch.manual_seed(SEED)

# Ask cuDNN to use deterministic algorithms only.
torch.backends.cudnn.deterministic = True

# Model

In [3]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class SinousEmbedding(nn.Module):
#     def __init__(self, dim) -> None:
#         super().__init__()
#         assert dim%2==0,NotImplementedError()
#         self.angles = (1000.**(-2/dim))**torch.arange(1,dim//2+1,1,dtype=torch.float).cuda()
#         self.angles.requires_grad_(False)
#     def forward(self,x):
#         angles = torch.einsum('m,i->im',self.angles,x.float())
#         return torch.cat((torch.sin(angles),torch.cos(angles)),dim=1)

# class DDPM(nn.Module):
#     def __init__(self, *args, **kwargs) -> None:
#         super().__init__(*args, **kwargs)
#         self.in_size = 28 * 28
#         self.t_embedding_dim = 256
#         self.t_embedding = SinousEmbedding(dim=self.t_embedding_dim)
#         self.up = nn.ModuleList([
#             nn.Sequential(
#                 nn.Linear(784+self.t_embedding_dim,64),
#                 nn.ReLU(),
#             ),
#             nn.Sequential(
#                 nn.Linear(64,32),
#                 nn.ReLU(),
#             ),
#             # nn.Sequential(
#             #     nn.Linear(256,256),
#             #     # nn.LeakyReLU(0.1),
#             # ),
#         ])
#         self.middle = nn.ModuleList([
#             nn.Linear(32,32),
#             # nn.LeakyReLU(0.1),
#         ])
#         self.down= nn.ModuleList([
#             nn.Sequential(
#                 nn.Linear(32,32),
#                 nn.ReLU(),
#             ),
#             # nn.Sequential(
#             #     nn.Linear(256,256),
#             #     # nn.LeakyReLU(0.1),
#             # ),
#             nn.Sequential(
#                 nn.Linear(32,64),
#                 nn.ReLU(),
#             ),
#         ])
#         self.end_mlp = nn.Linear(64,784)
#         self.apply_init()

#     def apply_init(self):
#         for m in self.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.xavier_normal_(m.weight)
#                 nn.init.constant_(m.bias, 0)

#     def forward(self,x,t):
#         x = x.reshape(-1,784)
#         ttensor = self.t_embedding(t) # [batch, 256]
#         batch = x.shape[0]
#         xc = x.clone()
#         ups = []
#         x = torch.cat((x,ttensor),dim=-1)
#         for ly in self.up:
#             x = ly(x)
#             ups.append(x.clone())
#         for ly in self.middle:
#             x = ly(x)
#         for ly in self.down:
#             x = ly(x) + ups.pop()

#         x = self.end_mlp(x)
#         x = (x + xc)
#         return x

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SinousEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal embedding of scalar timesteps.

    For an input vector ``x`` of length ``B`` the output has shape ``[B, dim]``:
    the first ``dim // 2`` columns are ``sin(x * f_m)`` and the remaining
    ``dim // 2`` columns are ``cos(x * f_m)``, where
    ``f_m = (1000 ** (-2 / dim)) ** m`` for ``m = 1 .. dim // 2``.

    Args:
        dim: size of the embedding; must be even.
    """

    def __init__(self, dim) -> None:
        super().__init__()
        assert dim%2==0, 'dim must be even: half of the features are sines, half are cosines'
        freqs = (1000.**(-2/dim))**torch.arange(1,dim//2+1,1,dtype=torch.float)
        # Registered as a buffer so that it follows the module through .to()/.cuda()
        # instead of being pinned to the GPU at construction time (which crashed on
        # CPU-only machines). persistent=False keeps it out of the state_dict, so
        # state dicts have the same keys as before. Buffers never require grad.
        self.register_buffer('angles', freqs, persistent=False)

    def forward(self,x):
        """Embed a 1-D tensor of timesteps ``x`` (any numeric dtype) into ``[len(x), dim]``."""
        # Outer product: angles[i, m] = x[i] * f_m
        angles = torch.einsum('m,i->im',self.angles,x.float())
        return torch.cat((torch.sin(angles),torch.cos(angles)),dim=1)

class F_x_t(nn.Module):

    def __init__(self,in_channels,out_channels,out_size,kernel_size=3,t_shape=64,attn=False,attn_dim=32) -> None:
        super().__init__()
        # self.t_channels = out_channels // 2
        # self.conv_channels = out_channels - self.t_channels
        self.t_channels = out_channels
        self.conv_channels = out_channels
        self.conv = nn.Conv2d(in_channels, self.conv_channels, kernel_size=kernel_size, padding=kernel_size//2)
        self.out_size = out_size
        self.fc = nn.Linear(t_shape, self.t_channels)
        self.attn = attn
        if attn:
            self.Q  = nn.Conv2d(out_channels, attn_dim, kernel_size=1)
            self.K  = nn.Conv2d(out_channels, attn_dim, kernel_size=1)
            self.V  = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        # self.fc = nn.Embedding(t_shape, self.t_num)

    def forward(self, x, t):
        if self.t_channels == 0:
            raise NotImplementedError()
            return self.conv(x)
        # return torch.cat([self.conv(x),self.fc(t).unsqueeze(-1).unsqueeze(-1).expand(t.shape[0], self.t_channels, self.out_size, self.out_size)],dim=1).relu()
        val = (self.conv(x) + self.fc(t).unsqueeze(-1).unsqueeze(-1).expand(t.shape[0], self.t_channels, self.out_size, self.out_size))
        if self.attn:
            q = self.Q(val)
            k = self.K(val)
            v = self.V(val)
            attn_score = torch.einsum('bchw,bcxy->bhwxy',q,k).reshape(q.shape[0],*q.shape[-2:],-1)
            attn_score = attn_score.softmax(dim=-1).reshape(q.shape[0],*q.shape[-2:],*k.shape[-2:])
            return torch.einsum('bhwxy,bcxy->bchw',attn_score,v).relu()
        return val.relu()

class DDPM(nn.Module):
    """Small time-conditioned U-Net over 28x28 single-channel images.

    The input is zero-padded to 32x32, passed through three down-sampling
    stages (32 -> 16 -> 8 -> 4), a bottleneck, and three up-sampling stages with
    additive skip connections, then cropped back to 28x28. In this notebook the
    output is trained against the Gaussian noise that was added to the input.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.t_embedding_dim = 32
        self.t_embedding = SinousEmbedding(dim=self.t_embedding_dim)
        # Encoder: each stage runs at the listed resolution and is followed by 2x average pooling.
        self.up= nn.ModuleList([
            F_x_t(in_channels=1,out_channels=32,out_size=32,kernel_size=3,t_shape=self.t_embedding_dim),
            F_x_t(in_channels=32,out_channels=64,out_size=16,kernel_size=3,t_shape=self.t_embedding_dim),
            F_x_t(in_channels=64,out_channels=128,out_size=8,kernel_size=3,t_shape=self.t_embedding_dim,attn=False),
        ])
        # Bottleneck at 4x4.
        self.middle = nn.ModuleList([
            F_x_t(in_channels=128,out_channels=128,out_size=4,kernel_size=1,t_shape=self.t_embedding_dim,attn=False),
        ])
        # Decoder: each stage is preceded by 2x upsampling plus the matching encoder skip.
        self.down= nn.ModuleList([
            F_x_t(in_channels=128,out_channels=64,out_size=8,kernel_size=3,t_shape=self.t_embedding_dim,attn=False),
            F_x_t(in_channels=64,out_channels=32,out_size=16,kernel_size=3,t_shape=self.t_embedding_dim),
            F_x_t(in_channels=32,out_channels=16,out_size=32,kernel_size=3,t_shape=self.t_embedding_dim),
        ])
        # 1x1 projection back to a single channel.
        self.end_mlp = nn.Conv2d(16,1,kernel_size=1)
        self.apply_init()

    def apply_init(self):
        """Kaiming-normal init for conv weights, Xavier-normal for linear weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self,x,t):
        """Run the network.

        Args:
            x: images, any shape that reshapes to ``[B, 1, 28, 28]`` (e.g. ``[B, 784]``).
            t: 1-D tensor of ``B`` timesteps.

        Returns:
            Tensor of shape ``[B, 784]``.
        """
        x = x.reshape(-1,1,28,28)
        x = F.pad(x,(2,2,2,2),mode='constant',value=0) # 28x28 -> 32x32
        ttensor = self.t_embedding(t) # [batch, t_embedding_dim]
        batch = x.shape[0]

        skips = []
        for ly in self.up:
            x = ly(x,ttensor)
            # No clone needed: nothing below modifies x in place.
            skips.append(x) # saved at 32x32, 16x16, 8x8
            x = F.avg_pool2d(x, 2)
        for ly in self.middle:
            x = ly(x,ttensor)
        for ly in self.down:
            # Nearest-neighbour 2x upsampling (same default as nn.Upsample), then add the skip.
            x = F.interpolate(x, scale_factor=2) + skips.pop() # 8x8, 16x16, 32x32
            x = ly(x,ttensor)
        x = self.end_mlp(x)
        x = x[:,:,2:30,2:30] # undo the padding: 32x32 -> 28x28
        return x.reshape(batch,28*28)

# Train

In [34]:
import sys
import os

parent_dir = os.path.abspath('/home/zhh24/DeepLearning')

sys.path.append(parent_dir)
print('appended',parent_dir)

import utils

from tqdm import tqdm
import torch
import torch.nn.functional as F
import torchvision.utils

# Run on the GPU when one is available; all schedule tensors below live on this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

mnist = utils.MNIST(batch_size=512)
train_loader = mnist.train_dataloader
valid_loader = mnist.valid_dataloader

# ---- Forward-process noise schedule ----
T=500       # number of diffusion steps
beta1=3e-4  # beta at the first step (smallest noise variance)
betaT=4e-2  # beta at the last step (largest noise variance)

# Linear schedule: betas[0] == beta1, betas[T-1] == betaT.
# Disabled alternative (geometric schedule):
# step = torch.log(torch.tensor(betaT/beta1))/(T-1)
# betas = beta1 * torch.exp(step*torch.arange(T,dtype=torch.float).to(device))
step = (betaT-beta1)/(T-1)
betas = torch.arange(T,dtype=torch.float,device=device) * step + beta1

alphas = 1-betas
# alpha_bars[t] = alphas[0] * alphas[1] * ... * alphas[t]
alpha_bars = torch.cumprod(alphas, dim=0)
print(alpha_bars)

sqrt = torch.sqrt
# alpha_bars / alphas is the cumulative product up to the previous step (and 1 at t == 0),
# so sigmas[t] = sqrt(beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)) and sigmas[0] == 0.
sigmas = sqrt(betas * (1-alpha_bars / alphas)/(1-alpha_bars))

@torch.no_grad()
def sample(model:DDPM,save_dir):
    """Draw 100 samples by running the reverse process and save them as a 10-column image grid.

    Starts from standard Gaussian noise and walks ``t`` from ``T-1`` down to 0,
    using ``sigmas[t]`` as the per-step noise scale.

    Args:
        model: network called as ``model(x, t)``; its output is used as the noise estimate.
        save_dir: path of the image file to write (despite the name, a file path).
    """
    x = torch.randn([100,784]).to(device)
    for t in reversed(range(T)):
        # The noise is drawn at every step (even t == 0) and then discarded for the last step.
        step_noise = torch.randn_like(x)*sigmas[t]
        if t==0:
            step_noise = 0
        t_batch = torch.full((x.shape[0],), t, dtype=torch.long, device=device)
        eps_hat = model(x, t_batch)
        eps_coef = (1-alphas[t])/(sqrt(1-alpha_bars[t]))
        x = (x - eps_coef*eps_hat)/(sqrt(alphas[t])) + step_noise
    images = post_process(x).reshape(-1,1,28,28).cpu()
    grid = torchvision.utils.make_grid(images, nrow=10)
    torchvision.utils.save_image(grid, save_dir)

@torch.no_grad()
def visualize(model,save_dir):
    """Run the reverse process on 10 noise vectors and save a grid of intermediate states.

    Args:
        model: network called as ``model(x, t)``; its output is used as the noise estimate.
        save_dir: path of the image file to write.
    """
    x = torch.randn([10,784]).to(device)
    frames = []
    for t in reversed(range(T)):
        # NOTE(review): the noise scale here is sqrt(betas[t]), whereas `sample` uses sigmas[t] -- confirm this is intended.
        step_noise = torch.randn_like(x)*((betas[t])**0.5).to(device)
        if t==0:
            step_noise = 0
        t_batch = torch.full((x.shape[0],), t, dtype=torch.long, device=device)
        eps_hat = model(x, t_batch)
        eps_coef = (1-alphas[t])/(sqrt(1-alpha_bars[t]))
        x = (x - eps_coef*eps_hat)/(sqrt(alphas[t])) + step_noise
        frames.append(x)
    # All steps are concatenated along the batch axis: row index = 10 * step + sample.
    # NOTE(review): [3::4] keeps every 4th *row* of that flat list, which mixes samples
    # and steps -- verify whether sub-sampling the steps was the intent.
    all_rows = torch.cat(frames,dim=0)
    kept = all_rows[3::4,...]
    grid = torchvision.utils.make_grid(post_process(kept).reshape(-1,1,28,28).cpu(), nrow=10)
    torchvision.utils.save_image(grid, save_dir)
    print('Saved visualize to',os.path.abspath(save_dir))

@torch.no_grad()
def visualize_denoise(model,save_dir):
    """Noise 20 validation images at 20 evenly spaced timesteps and show one-shot reconstructions.

    Saves a single image made of three stacked grids (originals, noised inputs,
    reconstructions) and prints, per image: the noise level ``1 - alpha_bar_t``,
    the mean square of the injected noise, and the MSE of the model's noise estimate.

    Args:
        model: network called as ``model(x, t)``; its output is used as the noise estimate.
        save_dir: path of the image file to write.
    """
    # Take the first 20 images of the first validation batch.
    batch,_ = next(iter(valid_loader))
    x = pre_process(batch[:20,...].reshape(20,784).to(device))
    t = torch.tensor([(i * T) // 20 for i in range(20)],dtype=torch.long,device=device)
    noise = torch.randn_like(x).reshape(-1,784)

    # Forward process in closed form: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
    signal_scale = sqrt(alpha_bars[t]).reshape(-1,1)
    noise_scale = sqrt(1-alpha_bars[t]).reshape(-1,1)
    x_corr = (signal_scale*x).reshape(-1,784) + noise_scale*noise

    # Invert the same formula using the model's noise estimate.
    est = model(x_corr,t)
    x_rec = (x_corr - noise_scale*est)/signal_scale

    def as_grid(flat):
        return torchvision.utils.make_grid(post_process(flat).reshape(-1,1,28,28).cpu(), nrow=10)

    grid_orig = as_grid(x)
    grid_corr = as_grid(x_corr)
    grid_rec = as_grid(x_rec)

    # Per-image diagnostics.
    noise_level = (1-alpha_bars[t]).reshape(-1).tolist()
    ori_mse = noise.pow(2).mean(dim=1).reshape(-1).tolist()
    mse = ((est-noise)**2).mean(dim=1).reshape(-1).tolist()
    print(noise_level)
    print(ori_mse)
    print(mse)

    # Stack the three grids vertically (dim 1 is image height).
    grid = torch.cat([grid_orig,grid_corr,grid_rec],dim=1)
    torchvision.utils.save_image(grid, save_dir)
    print('Saved denoise to',os.path.abspath(save_dir))

def pre_process(x):
    """Affinely rescale ``x`` by ``x * 2 - 1`` (maps 0 -> -1 and 1 -> 1).

    Inverse of ``post_process``. Presumably the inputs are pixel values in
    [0, 1] -- confirm against the data loader.
    """
    # A logit transform, log(x+1e-3) - log(1-x+1e-3), was used here previously.
    return x*2-1

def post_process(x):
    """Affinely rescale ``x`` by ``(x + 1) / 2`` (maps -1 -> 0 and 1 -> 1).

    Inverse of ``pre_process``; used before turning model-space tensors into images.
    """
    # The matching inverse of the earlier logit transform was torch.sigmoid(x).
    return (x+1)/2

def _noised_batch(batch):
    """Pre-process a raw image batch and diffuse it to random timesteps.

    Returns ``(noisy, ts, eps)``: the noised images flattened to ``[B, 784]``,
    the sampled timesteps ``[B]`` and the Gaussian noise ``[B, 784]`` that was mixed in.
    Assumes ``batch`` is 4-D (e.g. ``[B, 1, 28, 28]``) -- TODO confirm against utils.MNIST.
    """
    clean = pre_process(batch.to(device))
    eps = torch.randn_like(clean).reshape(-1,784).to(device)
    ts = torch.randint(0,T,(clean.shape[0],),device=device,dtype=torch.long)
    abar = alpha_bars[ts]
    signal = (sqrt(abar).reshape(-1,1,1,1)*clean).reshape(-1,784)
    noisy = signal + sqrt(1-abar).reshape(-1,1)*eps
    return noisy, ts, eps


def train(epochs,model:DDPM,optimizer,eval_interval=5):
    """Train the noise-prediction model and periodically write samples and checkpoints.

    Each epoch makes one pass over ``train_loader`` minimising the mean squared
    error between the injected noise and the model output, followed by a
    no-grad pass over ``valid_loader``. Whenever ``epoch % eval_interval == 0``
    the visualisations and the pickled model are written under ``./samples``.

    Args:
        epochs: number of epochs to run.
        model: network called as ``model(x_t, t)``.
        optimizer: optimizer over ``model``'s parameters.
        eval_interval: epoch period for writing images and checkpoints.
    """
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        with tqdm(train_loader) as bar:
            train_losses = []
            for images,_ in bar:
                noisy, ts, eps = _noised_batch(images)
                pred = model(noisy,ts) # [batch,784]
                # Unweighted objective: plain MSE between true and predicted noise.
                loss = (eps-pred).pow(2).mean(dim=-1).mean(dim=0)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_losses.append(loss.item())
                running = sum(train_losses)/len(train_losses)
                bar.set_description('epoch {}, loss {:.4f}'.format(epoch,running))

        # ---- validation pass ----
        model.eval()
        with torch.no_grad():
            with tqdm(valid_loader) as bar:
                mses = []
                valid_losses = []
                for images,_ in bar:
                    noisy, ts, eps = _noised_batch(images)
                    pred = model(noisy,ts)
                    mses.append(F.mse_loss(eps,pred).item())
                    valid_losses.append((eps-pred).pow(2).mean(dim=-1).mean(dim=0).item())
                    bar.set_description('epoch {}, MSE {:.4f}, [Valid] {:.4f}'.format(
                        epoch,sum(mses)/len(mses),sum(valid_losses)/len(valid_losses)))

        # ---- periodic artefacts ----
        if epoch % eval_interval != 0:
            continue
        visualize(model,save_dir=os.path.join('./samples',f'diffuse_epoch_{epoch}.png'))
        sample(model,save_dir=os.path.join('./samples',f'sample_epoch_{epoch}.png'))
        visualize_denoise(model,save_dir=os.path.join('./samples',f'denoise_epoch_{epoch}.png'))
        # Pickles the whole module object, not just its state_dict.
        torch.save(model,os.path.join('./samples',f'epoch_{epoch}.pt'))

if __name__ == '__main__':
    # Build the model and report its size and layout.
    model = DDPM().to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print('Number parameters of the model:', n_params)
    print('Model strcuture:',model)

    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

    # Baseline images from the untrained model, then the full training run.
    out_dir = './samples'
    os.makedirs(out_dir,exist_ok=True)
    sample(model,save_dir=os.path.join(out_dir,'init.png'))
    visualize(model,save_dir=os.path.join(out_dir,'init_visualize.png'))
    train(200,model,optimizer,eval_interval=5)

appended /home/zhh24/DeepLearning
tensor([9.9970e-01, 9.9932e-01, 9.9886e-01, 9.9832e-01, 9.9771e-01, 9.9701e-01,
        9.9624e-01, 9.9538e-01, 9.9445e-01, 9.9344e-01, 9.9235e-01, 9.9118e-01,
        9.8994e-01, 9.8862e-01, 9.8722e-01, 9.8575e-01, 9.8420e-01, 9.8257e-01,
        9.8087e-01, 9.7909e-01, 9.7724e-01, 9.7531e-01, 9.7331e-01, 9.7124e-01,
        9.6910e-01, 9.6688e-01, 9.6459e-01, 9.6223e-01, 9.5979e-01, 9.5729e-01,
        9.5472e-01, 9.5208e-01, 9.4937e-01, 9.4659e-01, 9.4375e-01, 9.4084e-01,
        9.3786e-01, 9.3482e-01, 9.3171e-01, 9.2854e-01, 9.2531e-01, 9.2201e-01,
        9.1865e-01, 9.1523e-01, 9.1176e-01, 9.0822e-01, 9.0462e-01, 9.0097e-01,
        8.9726e-01, 8.9349e-01, 8.8967e-01, 8.8579e-01, 8.8186e-01, 8.7788e-01,
        8.7384e-01, 8.6976e-01, 8.6562e-01, 8.6144e-01, 8.5720e-01, 8.5292e-01,
        8.4859e-01, 8.4422e-01, 8.3980e-01, 8.3534e-01, 8.3084e-01, 8.2629e-01,
        8.2171e-01, 8.1708e-01, 8.1241e-01, 8.0771e-01, 8.0297e-01, 7.9819e-01,
      

epoch 0, loss 15.6916:   1%|█                                                                                                     | 1/94 [00:00<00:11,  7.83it/s]

Saved visualize to /home/zhh24/samples/init_visualize.png


epoch 0, loss 1.3062: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.98it/s]
epoch 0, MSE 0.2242, [Valid] 0.2242: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.53it/s]


Saved visualize to /home/zhh24/samples/diffuse_epoch_0.png


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
epoch 1, loss 0.2487:   1%|█                                                                                                      | 1/94 [00:00<00:11,  7.80it/s]

[0.000299990177154541, 0.033122241497039795, 0.11033189296722412, 0.22125422954559326, 0.35161757469177246, 0.4865582585334778, 0.6133379936218262, 0.7231091856956482, 0.8114710450172424, 0.8779618144035339, 0.924903929233551, 0.956076443195343, 0.975583016872406, 0.9871010184288025, 0.993524968624115, 0.9969117641448975, 0.9986007213592529, 0.9993976950645447, 0.9997537732124329, 0.9999043941497803]
[0.9956701993942261, 0.9959710240364075, 1.0530003309249878, 0.9213144779205322, 1.057141661643982, 0.9751720428466797, 0.9680765867233276, 1.0060406923294067, 0.9209141731262207, 1.0358518362045288, 0.9515738487243652, 0.9217539429664612, 0.9214352965354919, 0.895706832408905, 0.9493274092674255, 1.0191142559051514, 1.0628652572631836, 0.9125326871871948, 0.9925356507301331, 1.0155960321426392]
[1.1694530248641968, 0.7616725564002991, 0.6008954048156738, 0.4027009904384613, 0.2939302921295166, 0.19253948330879211, 0.1724560707807541, 0.14846542477607727, 0.14206650853157043, 0.10852748155

epoch 1, loss 0.2006: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  8.20it/s]
epoch 1, MSE 0.1834, [Valid] 0.1834: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.62it/s]
epoch 2, loss 0.1743: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  8.22it/s]
epoch 2, MSE 0.1695, [Valid] 0.1695: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.57it/s]
epoch 3, loss 0.1563: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.97it/s]
epoch 3, MSE 0.1491, [Valid] 0.1491: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.53it/s]
epoch 4, loss 0.1425: 100%|█

Saved visualize to /home/zhh24/samples/diffuse_epoch_5.png


epoch 6, loss 0.1371:   1%|█                                                                                                      | 1/94 [00:00<00:11,  7.80it/s]

[0.000299990177154541, 0.033122241497039795, 0.11033189296722412, 0.22125422954559326, 0.35161757469177246, 0.4865582585334778, 0.6133379936218262, 0.7231091856956482, 0.8114710450172424, 0.8779618144035339, 0.924903929233551, 0.956076443195343, 0.975583016872406, 0.9871010184288025, 0.993524968624115, 0.9969117641448975, 0.9986007213592529, 0.9993976950645447, 0.9997537732124329, 0.9999043941497803]
[0.9825389385223389, 0.9264845252037048, 1.004364252090454, 1.0104691982269287, 1.0628550052642822, 1.0675299167633057, 0.9943296313285828, 1.1101744174957275, 1.0010329484939575, 0.988094687461853, 1.0674560070037842, 0.865959644317627, 0.9674904942512512, 0.9638899564743042, 0.9732338786125183, 1.0268019437789917, 0.9478619694709778, 1.0055363178253174, 1.0176571607589722, 0.9955370426177979]
[1.0310436487197876, 0.6248946785926819, 0.38352417945861816, 0.25659647583961487, 0.1409003734588623, 0.11699353158473969, 0.07434860616922379, 0.10285764932632446, 0.08510708063840866, 0.055729895

epoch 6, loss 0.1265: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  8.07it/s]
epoch 6, MSE 0.1231, [Valid] 0.1231: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.53it/s]
epoch 7, loss 0.1188: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  8.02it/s]
epoch 7, MSE 0.1178, [Valid] 0.1178: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.51it/s]
epoch 8, loss 0.1162: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.49it/s]
epoch 8, MSE 0.1144, [Valid] 0.1144: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.55it/s]
epoch 9, loss 0.1114: 100%|█

Saved visualize to /home/zhh24/samples/diffuse_epoch_10.png


epoch 11, loss 0.1144:   1%|█                                                                                                     | 1/94 [00:00<00:11,  7.79it/s]

[0.000299990177154541, 0.033122241497039795, 0.11033189296722412, 0.22125422954559326, 0.35161757469177246, 0.4865582585334778, 0.6133379936218262, 0.7231091856956482, 0.8114710450172424, 0.8779618144035339, 0.924903929233551, 0.956076443195343, 0.975583016872406, 0.9871010184288025, 0.993524968624115, 0.9969117641448975, 0.9986007213592529, 0.9993976950645447, 0.9997537732124329, 0.9999043941497803]
[0.9978609681129456, 1.026345133781433, 1.025437355041504, 0.9519157409667969, 1.0062947273254395, 0.9963005185127258, 0.9604605436325073, 1.0790499448776245, 0.9461386203765869, 0.9904152750968933, 0.9842007756233215, 0.9609602689743042, 0.9428297877311707, 1.0503206253051758, 1.006227731704712, 0.9509789347648621, 1.0208739042282104, 0.9768456816673279, 1.0350576639175415, 0.9944168329238892]
[1.0131313800811768, 0.5742236971855164, 0.27272656559944153, 0.1465596854686737, 0.09657861292362213, 0.06521821022033691, 0.06345497071743011, 0.07035114616155624, 0.059618908911943436, 0.04926869

epoch 11, loss 0.1054: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  8.26it/s]
epoch 11, MSE 0.1026, [Valid] 0.1026: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.56it/s]
epoch 12, loss 0.1019: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.97it/s]
epoch 12, MSE 0.1004, [Valid] 0.1004: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.53it/s]
epoch 13, loss 0.1025: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.99it/s]
epoch 13, MSE 0.0996, [Valid] 0.0996: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.55it/s]
epoch 14, loss 0.0999: 100%|

Saved visualize to /home/zhh24/samples/diffuse_epoch_15.png


epoch 16, loss 0.1100:   1%|█                                                                                                     | 1/94 [00:00<00:11,  7.79it/s]

[0.000299990177154541, 0.033122241497039795, 0.11033189296722412, 0.22125422954559326, 0.35161757469177246, 0.4865582585334778, 0.6133379936218262, 0.7231091856956482, 0.8114710450172424, 0.8779618144035339, 0.924903929233551, 0.956076443195343, 0.975583016872406, 0.9871010184288025, 0.993524968624115, 0.9969117641448975, 0.9986007213592529, 0.9993976950645447, 0.9997537732124329, 0.9999043941497803]
[0.9277873635292053, 0.9883667230606079, 1.0120646953582764, 1.0077815055847168, 1.0270980596542358, 1.0182219743728638, 1.0070602893829346, 1.0246878862380981, 0.9759447574615479, 1.031385898590088, 1.0254193544387817, 1.0323357582092285, 1.1017566919326782, 1.0779573917388916, 0.953021764755249, 1.0141795873641968, 0.9808624982833862, 1.0504343509674072, 0.9704107046127319, 0.9613288640975952]
[1.00008225440979, 0.5223508477210999, 0.2271416336297989, 0.12941516935825348, 0.07243644446134567, 0.06238614767789841, 0.05434912070631981, 0.0736839696764946, 0.06080454960465431, 0.04770956188

epoch 16, loss 0.0948: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.71it/s]
epoch 16, MSE 0.0966, [Valid] 0.0966: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.61it/s]
epoch 17, loss 0.0924: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.98it/s]
epoch 17, MSE 0.0948, [Valid] 0.0948: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.58it/s]
epoch 18, loss 0.0934: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.98it/s]
epoch 18, MSE 0.0928, [Valid] 0.0928: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.54it/s]
epoch 19, loss 0.0921: 100%|

Saved visualize to /home/zhh24/samples/diffuse_epoch_20.png


epoch 21, loss 0.0883:   1%|█                                                                                                     | 1/94 [00:00<00:11,  7.79it/s]

[0.000299990177154541, 0.033122241497039795, 0.11033189296722412, 0.22125422954559326, 0.35161757469177246, 0.4865582585334778, 0.6133379936218262, 0.7231091856956482, 0.8114710450172424, 0.8779618144035339, 0.924903929233551, 0.956076443195343, 0.975583016872406, 0.9871010184288025, 0.993524968624115, 0.9969117641448975, 0.9986007213592529, 0.9993976950645447, 0.9997537732124329, 0.9999043941497803]
[1.0054829120635986, 0.9600141644477844, 0.9670820832252502, 0.9969930648803711, 1.105242133140564, 0.9098739624023438, 1.0457936525344849, 0.9826421141624451, 1.1149725914001465, 1.007585883140564, 1.0218641757965088, 1.0435781478881836, 1.0076148509979248, 0.9794421792030334, 0.9783692955970764, 0.9049535989761353, 0.9323753714561462, 0.9648674130439758, 1.0208112001419067, 1.0206615924835205]
[1.0215849876403809, 0.4916878342628479, 0.19255805015563965, 0.12136484682559967, 0.06572921574115753, 0.050229642540216446, 0.052371297031641006, 0.06919031590223312, 0.07220511883497238, 0.04457

epoch 21, loss 0.0899: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  8.18it/s]
epoch 21, MSE 0.0888, [Valid] 0.0888: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.59it/s]
epoch 22, loss 0.0886: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.99it/s]
epoch 22, MSE 0.0833, [Valid] 0.0833: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.51it/s]
epoch 23, loss 0.0874: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  8.00it/s]
epoch 23, MSE 0.0868, [Valid] 0.0868: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.55it/s]
epoch 24, loss 0.0861: 100%|

Saved visualize to /home/zhh24/samples/diffuse_epoch_25.png


epoch 26, loss 0.0822:   1%|█                                                                                                     | 1/94 [00:00<00:11,  7.79it/s]

[0.000299990177154541, 0.033122241497039795, 0.11033189296722412, 0.22125422954559326, 0.35161757469177246, 0.4865582585334778, 0.6133379936218262, 0.7231091856956482, 0.8114710450172424, 0.8779618144035339, 0.924903929233551, 0.956076443195343, 0.975583016872406, 0.9871010184288025, 0.993524968624115, 0.9969117641448975, 0.9986007213592529, 0.9993976950645447, 0.9997537732124329, 0.9999043941497803]
[0.9959306120872498, 0.9339210987091064, 0.9340528845787048, 0.9093735218048096, 0.977450966835022, 0.9783138036727905, 0.9290589094161987, 1.0871658325195312, 0.9871878623962402, 0.9656217694282532, 1.004308819770813, 1.0306081771850586, 1.0518910884857178, 0.9508032202720642, 0.9677475690841675, 0.9302875399589539, 0.877405047416687, 0.9941999316215515, 0.9616425037384033, 1.0169779062271118]
[0.995722234249115, 0.40486934781074524, 0.15271854400634766, 0.10176672041416168, 0.07851235568523407, 0.04892120137810707, 0.04458837956190109, 0.05189305171370506, 0.07471700012683868, 0.04285229

epoch 26, loss 0.0844: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  8.23it/s]
epoch 26, MSE 0.0844, [Valid] 0.0844: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.58it/s]
epoch 27, loss 0.0835: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  8.00it/s]
epoch 27, MSE 0.0852, [Valid] 0.0852: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.56it/s]
epoch 28, loss 0.0821: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.99it/s]
epoch 28, MSE 0.0814, [Valid] 0.0814: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.53it/s]
epoch 29, loss 0.0815: 100%|

Saved visualize to /home/zhh24/samples/diffuse_epoch_30.png


epoch 31, loss 0.0934:   1%|█                                                                                                     | 1/94 [00:00<00:11,  7.80it/s]

[0.000299990177154541, 0.033122241497039795, 0.11033189296722412, 0.22125422954559326, 0.35161757469177246, 0.4865582585334778, 0.6133379936218262, 0.7231091856956482, 0.8114710450172424, 0.8779618144035339, 0.924903929233551, 0.956076443195343, 0.975583016872406, 0.9871010184288025, 0.993524968624115, 0.9969117641448975, 0.9986007213592529, 0.9993976950645447, 0.9997537732124329, 0.9999043941497803]
[1.0340079069137573, 0.953088641166687, 0.9939531683921814, 0.992101788520813, 0.9892802238464355, 1.0356539487838745, 1.0303881168365479, 0.9871862530708313, 0.9763270020484924, 1.0519763231277466, 1.053636908531189, 0.9414492249488831, 1.0695143938064575, 0.9721750617027283, 1.030295729637146, 0.9317366480827332, 1.0036790370941162, 1.0327701568603516, 0.9880707263946533, 0.9726787805557251]
[1.0466264486312866, 0.3561263084411621, 0.13949432969093323, 0.11477505415678024, 0.06897366791963577, 0.08928248286247253, 0.03117264434695244, 0.06624994426965714, 0.06319433450698853, 0.041680566

epoch 31, loss 0.0785: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.97it/s]
epoch 31, MSE 0.0791, [Valid] 0.0791: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.52it/s]
epoch 32, loss 0.0785: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  8.01it/s]
epoch 32, MSE 0.0783, [Valid] 0.0783: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.53it/s]
epoch 33, loss 0.0766: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:12<00:00,  7.97it/s]
epoch 33, MSE 0.0758, [Valid] 0.0758:  54%|██████████████████████████████████████████████▌                                       | 13/24 [00:01<00:01,  8.33it/s]

: 