In [1]:
!nvidia-smi
!which python | grep DYY

Thu Sep 19 12:13:05 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   39C    P0              54W / 300W |   8716MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import torch
torch.manual_seed(3407)
torch.backends.cudnn.deterministic = True

# Model

In [3]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class SinousEmbedding(nn.Module):
#     def __init__(self, dim) -> None:
#         super().__init__()
#         assert dim%2==0,NotImplementedError()
#         self.angles = (1000.**(-2/dim))**torch.arange(1,dim//2+1,1,dtype=torch.float).cuda()
#         self.angles.requires_grad_(False)
#     def forward(self,x):
#         angles = torch.einsum('m,i->im',self.angles,x.float())
#         return torch.cat((torch.sin(angles),torch.cos(angles)),dim=1)

# class DDPM(nn.Module):
#     def __init__(self, *args, **kwargs) -> None:
#         super().__init__(*args, **kwargs)
#         self.in_size = 28 * 28
#         self.t_embedding_dim = 256
#         self.t_embedding = SinousEmbedding(dim=self.t_embedding_dim)
#         self.up = nn.ModuleList([
#             nn.Sequential(
#                 nn.Linear(784+self.t_embedding_dim,64),
#                 nn.ReLU(),
#             ),
#             nn.Sequential(
#                 nn.Linear(64,32),
#                 nn.ReLU(),
#             ),
#             # nn.Sequential(
#             #     nn.Linear(256,256),
#             #     # nn.LeakyReLU(0.1),
#             # ),
#         ])
#         self.middle = nn.ModuleList([
#             nn.Linear(32,32),
#             # nn.LeakyReLU(0.1),
#         ])
#         self.down= nn.ModuleList([
#             nn.Sequential(
#                 nn.Linear(32,32),
#                 nn.ReLU(),
#             ),
#             # nn.Sequential(
#             #     nn.Linear(256,256),
#             #     # nn.LeakyReLU(0.1),
#             # ),
#             nn.Sequential(
#                 nn.Linear(32,64),
#                 nn.ReLU(),
#             ),
#         ])
#         self.end_mlp = nn.Linear(64,784)
#         self.apply_init()

#     def apply_init(self):
#         for m in self.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.xavier_normal_(m.weight)
#                 nn.init.constant_(m.bias, 0)

#     def forward(self,x,t):
#         x = x.reshape(-1,784)
#         ttensor = self.t_embedding(t) # [batch, 256]
#         batch = x.shape[0]
#         xc = x.clone()
#         ups = []
#         x = torch.cat((x,ttensor),dim=-1)
#         for ly in self.up:
#             x = ly(x)
#             ups.append(x.clone())
#         for ly in self.middle:
#             x = ly(x)
#         for ly in self.down:
#             x = ly(x) + ups.pop()

#         x = self.end_mlp(x)
#         x = (x + xc)
#         return x

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SinousEmbedding(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        assert dim%2==0,NotImplementedError()
        self.angles = (10000.**(-2/dim))**torch.arange(1,dim//2+1,1,dtype=torch.float).cuda()
        self.angles.requires_grad_(False)
    def forward(self,x):
        angles = torch.einsum('m,i->im',self.angles,x.float())
        return torch.cat((torch.sin(angles),torch.cos(angles)),dim=1)
    
class ResidualBlock(nn.Module):

    def __init__(self,channels=128,kernel_size=3,t_dim=64) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(channels,channels,kernel_size=kernel_size,padding=kernel_size//2)
        self.t_net = nn.Linear(t_dim,channels)
        self.conv2 = nn.Conv2d(channels,channels,kernel_size=kernel_size,padding=kernel_size//2)
        self.conv1.weight.data.fill_(0)
        self.conv2.weight.data.fill_(0)
        self.t_net.weight.data.fill_(0)
        self.conv1.bias.data.fill_(0)
        self.conv2.bias.data.fill_(0)
        self.t_net.bias.data.fill_(0)
    
    def forward(self,x,t):
        xc = x.clone()
        x = self.conv1(x.relu())
        x = x + self.t_net(t).unsqueeze(-1).unsqueeze(-1).expand(t.shape[0],x.shape[1],x.shape[2],x.shape[3])
        x = F.relu(x)
        x = self.conv2(x)
        return x + xc


class F_x_t(nn.Module):

    def __init__(self,in_channels,out_channels,out_size,kernel_size=3,t_shape=64,attn=False,attn_dim=32,residual=True) -> None:
        super().__init__()
        # self.t_channels = out_channels // 2
        # self.conv_channels = out_channels - self.t_channels
        self.t_channels = out_channels
        self.conv_channels = out_channels
        self.conv = nn.Conv2d(in_channels, self.conv_channels, kernel_size=kernel_size, padding=kernel_size//2)
        self.out_size = out_size
        self.fc = nn.Linear(t_shape, self.t_channels)
        self.attn = attn
        self.residual = residual
        if attn:
            self.Q  = nn.Conv2d(out_channels, attn_dim, kernel_size=1)
            self.K  = nn.Conv2d(out_channels, attn_dim, kernel_size=1)
            self.V  = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        if residual:
            self.res = ResidualBlock(channels=out_channels,kernel_size=kernel_size,t_dim=t_shape)
        # self.fc = nn.Embedding(t_shape, self.t_num)

    def forward(self, x, t):
        if self.t_channels == 0:
            raise NotImplementedError()
            return self.conv(x)
        # return torch.cat([self.conv(x),self.fc(t).unsqueeze(-1).unsqueeze(-1).expand(t.shape[0], self.t_channels, self.out_size, self.out_size)],dim=1).relu()
        val = self.conv(x) + self.fc(t).unsqueeze(-1).unsqueeze(-1).expand(t.shape[0], self.t_channels, self.out_size, self.out_size)
        if self.residual:
            val = self.res(val,t)
        if self.attn:
            q = self.Q(val)
            k = self.K(val)
            v = self.V(val)
            attn_score = torch.einsum('bchw,bcxy->bhwxy',q,k).reshape(q.shape[0],*q.shape[-2:],-1)
            attn_score = attn_score.softmax(dim=-1).reshape(q.shape[0],*q.shape[-2:],*k.shape[-2:])
            return torch.einsum('bhwxy,bcxy->bchw',attn_score,v).relu()
        return val.relu()

class DDPM(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.t_embedding_dim = 32
        self.t_embedding = SinousEmbedding(dim=self.t_embedding_dim)
        self.up= nn.ModuleList([
            F_x_t(in_channels=1,out_channels=32,out_size=32,kernel_size=3,t_shape=self.t_embedding_dim),
            F_x_t(in_channels=32,out_channels=64,out_size=16,kernel_size=3,t_shape=self.t_embedding_dim),
            F_x_t(in_channels=64,out_channels=128,out_size=8,kernel_size=3,t_shape=self.t_embedding_dim,attn=False),
            # ResidualBlock(channels=128,kernel_size=3,t_dim=self.t_embedding_dim),
            # F_x_t(in_channels=128,out_channels=128,out_size=4,kernel_size=1,t_shape=self.t_embedding_dim),
        ])
        self.middle = nn.ModuleList([
            nn.Identity()
            # ResidualBlock(channels=128,kernel_size=3,t_dim=self.t_embedding_dim),
            # F_x_t(in_channels=128,out_channels=128,out_size=4,kernel_size=1,t_shape=self.t_embedding_dim,attn=False),
        ])
        self.down= nn.ModuleList([
            # F_x_t(in_channels=128,out_channels=128,out_size=2,kernel_size=1,t_shape=self.t_embedding_dim),
            F_x_t(in_channels=128,out_channels=64,out_size=8,kernel_size=3,t_shape=self.t_embedding_dim,attn=False),
            F_x_t(in_channels=64,out_channels=32,out_size=16,kernel_size=3,t_shape=self.t_embedding_dim),
            F_x_t(in_channels=32,out_channels=16,out_size=32,kernel_size=3,t_shape=self.t_embedding_dim),
        ])
        # self.end_mlp = nn.Conv2d(32,1,kernel_size=3,padding=1)
        self.end_mlp = nn.Conv2d(16,1,kernel_size=1)

    def forward(self,x,t):
        x = x.reshape(-1,1,28,28)
        x = F.pad(x,(2,2,2,2),mode='constant',value=0)
        ttensor = self.t_embedding(t) # [batch, 256]
        batch = x.shape[0]
        # xc = x.clone()            print(attn_score.shape)

        ups = []
        for ly in self.up:
            x = ly(x,ttensor)
            ups.append(x.clone()) # append: 28x28, 14x14
            x = nn.AvgPool2d(2)(x)
        for ly in self.middle:
            # x = ly(x,ttensor)
            x = ly(x)
        for ly in self.down:
            x = nn.Upsample(scale_factor=2)(x) + ups.pop() # 14x14, 28x28
            x = ly(x,ttensor)
            # x = nn.Upsample(scale_factor=2)(x) + ups.pop()
        x = self.end_mlp(x)
        x = x[:,:,2:30,2:30]
        return x.reshape(batch,28*28)

# Train

In [5]:
import sys
import os

# parent_dir = os.path.abspath('/root/DeepLearning')
parent_dir = os.path.abspath('/home/zhh24/DeepLearning')

sys.path.append(parent_dir)
print('appended',parent_dir)

import utils

from tqdm import tqdm
import torch
import torch.nn.functional as F
import torchvision.utils

import matplotlib.pyplot as plt
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.autograd.set_detect_anomaly(True)
mnist = utils.MNIST(batch_size=512,data_aug=True)
train_loader = mnist.train_dataloader
valid_loader = mnist.valid_dataloader
T=1000
beta1=1e-4 # variance of lowest temperature
betaT=2e-2 # variance of highest temperature

# step = torch.log(torch.tensor(betaT/beta1))/(T-1)
# betas = beta1 * torch.exp(step*torch.arange(T,dtype=torch.float).to(device))
step = (betaT-beta1)/(T-1)
betas = torch.arange(T,dtype=torch.float,device=device) * step + beta1


alphas = 1-betas
alpha_bars = alphas.clone()
for i in range(1,T):
    alpha_bars[i] *= alpha_bars[i-1]

                    # we re-define a way to generate hyperparameters
                    # alpha_bar_0 = .9
                    # # alpha_bar_mid = .3
                    # alpha_bar_T = 1e-3
                    # alpha_bars = torch.zeros(T,dtype=torch.float)
                    # # alpha_bars[:T//2] = alpha_bar_0 + (alpha_bar_mid-alpha_bar_0) * torch.arange(T//2,dtype=torch.float,device=device) / (T//2)
                    # # alpha_bars[T//2:] = alpha_bar_mid + (alpha_bar_T-alpha_bar_mid) * torch.arange(T//2,dtype=torch.float,device=device) / (T//2)
                    # alpha_bars = alpha_bar_0 + (alpha_bar_T-alpha_bar_0) * torch.arange(T,dtype=torch.float,device=device) / T
                    # alphas = alpha_bars.clone()
                    # for i in range(1,T):
                    #     alphas[i] = alpha_bars[i] / alpha_bars[i-1]
                    # betas = 1-alphas

print(alpha_bars)
print('range of bars',alpha_bars.min(),alpha_bars.max())
# print(alphas)

sqrt = torch.sqrt
sigmas = sqrt(betas * (1-alpha_bars / alphas)/(1-alpha_bars))
sigmas[0] = 1
print('range of sigmas,',sigmas.min(),sigmas.max())
alphas = alphas.to(device)
alpha_bars = alpha_bars.to(device)
betas = betas.to(device)
sigmas = sigmas.to(device)
weights = torch.ones(T,dtype=torch.float,device=device)

@torch.no_grad()
def sample(model:DDPM,save_dir):
    x = torch.randn([100,784]).to(device)
    for t in range(T-1,-1,-1):
        sigmaz = torch.randn_like(x)*sigmas[t]
        if t==0:
            sigmaz = 0
        x = (x-(1-alphas[t])/(sqrt(1-alpha_bars[t]))*model(x,t*torch.ones(x.shape[0],dtype=torch.long,device=device)))/(sqrt(alphas[t]))+sigmaz
        # x = torch.clamp(x,0,1)
    grid = torchvision.utils.make_grid(post_process(x).reshape(-1,1,28,28).cpu(), nrow=10)
    torchvision.utils.save_image(grid, save_dir)

@torch.no_grad()
def visualize(model,save_dir):
    interval = (T-1) // 20
    x = torch.randn([10,784]).to(device)
    x_history = []
    for t in range(T-1,-1,-1):
        sigmaz = torch.randn_like(x)*((betas[t])**0.5).to(device)
        if t==0:
            sigmaz = 0
        x = (x-(1-alphas[t])/(sqrt(1-alpha_bars[t]))*model(x,t*torch.ones(x.shape[0],dtype=torch.long,device=device)))/(sqrt(alphas[t]))+sigmaz
        # x = torch.clamp(x,0,1)
        x_history.append(x)
    # print('cat.shape',torch.cat(x_history,dim=0).shape)
    grid = torchvision.utils.make_grid(post_process(torch.stack(x_history,dim=0)[::interval,...]).reshape(-1,1,28,28).cpu(), nrow=10)
    torchvision.utils.save_image(grid, save_dir)
    print('Saved visualize to',os.path.abspath(save_dir))

@torch.no_grad()
def visualize_denoise(model,save_dir):
    # get 10 images from the dataset
    x,_ = next(iter(valid_loader))
    x = x[:20,...].reshape(20,784).to(device)
    x = pre_process(x)
    t = torch.tensor([i * T // 20 for i in range(20)],dtype=torch.long,device=device)
    noise = torch.randn_like(x).reshape(-1,784)
    v1 = (sqrt(alpha_bars[t]).reshape(-1,1)*x).reshape(-1,784)
    v2 = sqrt(1-alpha_bars[t]).reshape(-1,1)*noise
    x_corr = v1+v2
    est = model(x_corr,t)
    x_rec = (x_corr - sqrt(1-alpha_bars[t]).reshape(-1,1)*est)/(sqrt(alpha_bars[t])).reshape(-1,1)
    grid_orig = torchvision.utils.make_grid(post_process(x).reshape(-1,1,28,28).cpu(), nrow=10)
    grid_corr = torchvision.utils.make_grid(post_process(x_corr).reshape(-1,1,28,28).cpu(), nrow=10)
    grid_rec = torchvision.utils.make_grid(post_process(x_rec).reshape(-1,1,28,28).cpu(), nrow=10)
    # add noise level infomation to the image
    noise_level = (1-alpha_bars[t]).reshape(-1).tolist()
    ori_mse = noise.pow(2).mean(dim=1).reshape(-1).tolist()
    mse = ((est-noise)**2).mean(dim=1).reshape(-1).tolist()
    print(noise_level)
    print(ori_mse)
    print(mse)
    grid = torch.cat([grid_orig,grid_corr,grid_rec],dim=1)
    torchvision.utils.save_image(grid, save_dir)
    print('Saved denoise to',os.path.abspath(save_dir))

def plot_loss(losses,save_dir):
    losses_vals, t_vals = zip(*losses)
    losses_vals = torch.cat(losses_vals,dim=0)
    t_vals = torch.cat(t_vals,dim=0)
    # print('t_vals',t_vals)
    # print('losses_vals',losses_vals)

    results = []
    for t in range(T):
        this_t = abs(t_vals.float()-float(t))<0.5
        results.append(torch.sum(torch.where(this_t,losses_vals,torch.tensor(0.,device=device))).item() / (torch.sum(this_t.float())+1e-3).item())
    plt.plot(results)
    plt.ylim(0,max(results)* 1.2)
    plt.savefig(save_dir)
    plt.close()
    # weights = (torch.tensor(results,device=device)) # weights
    weights = torch.ones(T,dtype=torch.float,device=device)
    # weights[:10]=0
    # weights[10:80] /= 100
    return weights

def pre_process(x):
    # do the logit transform
    # return (torch.log(x+1e-3)-torch.log(1-x+1e-3))
    return x*2-1 #MODIFIED
    return (x+1)/2

def post_process(x):
    # return torch.sigmoid(x)
    return (x+1)/2 #MODIFIED
    return x*2-1

def train(epochs,model:DDPM,optimizer,eval_interval=5):
    global weights
    for epoch in range(epochs):
        # print('weights normalized:',weights/weights.sum())
        all_ts = torch.distributions.Categorical(weights).sample((50000,))
        cnt = 0
        model.train()
        with tqdm(train_loader) as bar:
            losses = []
            for x,_ in bar:
                cnt += x.shape[0]
                x = pre_process(x.to(device))
                epss = torch.randn_like(x).reshape(-1,784).to(device)
                # ts = torch.randint(0,T,(x.shape[0],),device=device,dtype=torch.long)
                ts = all_ts[cnt-x.shape[0]:cnt]
                alpha_tbars = alpha_bars[ts]
                value = (sqrt(alpha_tbars).reshape(-1,1,1,1)*x).reshape(-1,784)+sqrt(1-alpha_tbars).reshape(-1,1)*epss
                out = model(value,ts) # [batch,784]
                # loss = ((epss-out).pow(2).mean(dim=-1) * (betas[ts])/(2*alphas[ts]*(1-alpha_tbars))).sum(dim=0)
                loss = ((epss-out).pow(2).mean(dim=-1)).mean(dim=0)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
                bar.set_description('epoch {}, loss {:.4f}'.format(epoch,sum(losses)/len(losses)))

        model.eval()
        with torch.no_grad():
            with tqdm(valid_loader) as bar:
                mses = []
                losses = []
                losses_for_t = []
                for x,_ in bar:
                    x = pre_process(x.to(device))
                    epss = torch.randn_like(x).reshape(-1,784).to(device)
                    ts = torch.randint(0,T,(x.shape[0],),device=device,dtype=torch.long)
                    # print(ts)
                    alpha_tbars = alpha_bars[ts]
                    value = (sqrt(alpha_tbars).reshape(-1,1,1,1)*x).reshape(-1,784)+sqrt(1-alpha_tbars).reshape(-1,1)*epss
                    out = model(value,ts)
                    mse = F.mse_loss(epss,out)
                    mses.append(mse.item())
                    loss = ((epss-out).pow(2).mean(dim=-1))
                    # loss = (epss-out).pow(2).mean(dim=-1)
                    losses_for_t.append((loss.clone().detach(),ts))
                    loss = (loss).mean(dim=0)
                    losses.append(loss.item())
                    bar.set_description('epoch {}, MSE {:.4f}, [Valid] {:.4f}'.format(epoch,sum(mses)/len(mses),sum(losses)/len(losses)))
                    
        if epoch % eval_interval == 0:
            visualize(model,save_dir=os.path.join('./samples',f'diffuse_epoch_{epoch}.png'))
            sample(model,save_dir=os.path.join('./samples',f'sample_epoch_{epoch}.png'))
            # visualize_denoise(model,save_dir=os.path.join('./samples',f'denoise_epoch_{epoch}.png'))
            weights = plot_loss(losses_for_t,save_dir=os.path.join('./samples',f'loss_epoch_{epoch}.png'))
            torch.save(model,os.path.join('./samples',f'epoch_{epoch}.pt'))

if __name__ == '__main__':
    model = DDPM().to(device)
    print('Number parameters of the model:', sum(p.numel() for p in model.parameters()))
    print('Model strcuture:',model)
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
    os.makedirs('./samples',exist_ok=True)
    sample(model,save_dir=os.path.join('./samples',f'init.png'))
    visualize(model,save_dir=os.path.join('./samples',f'init_visualize.png'))
    train(200,model,optimizer,eval_interval=5)

appended /home/zhh24/DeepLearning
tensor([9.9990e-01, 9.9978e-01, 9.9964e-01, 9.9948e-01, 9.9930e-01, 9.9910e-01,
        9.9888e-01, 9.9864e-01, 9.9838e-01, 9.9811e-01, 9.9781e-01, 9.9749e-01,
        9.9715e-01, 9.9679e-01, 9.9641e-01, 9.9602e-01, 9.9560e-01, 9.9516e-01,
        9.9471e-01, 9.9423e-01, 9.9374e-01, 9.9322e-01, 9.9269e-01, 9.9213e-01,
        9.9156e-01, 9.9096e-01, 9.9035e-01, 9.8972e-01, 9.8907e-01, 9.8840e-01,
        9.8771e-01, 9.8700e-01, 9.8627e-01, 9.8553e-01, 9.8476e-01, 9.8398e-01,
        9.8317e-01, 9.8235e-01, 9.8151e-01, 9.8065e-01, 9.7977e-01, 9.7887e-01,
        9.7795e-01, 9.7702e-01, 9.7606e-01, 9.7509e-01, 9.7410e-01, 9.7309e-01,
        9.7206e-01, 9.7102e-01, 9.6995e-01, 9.6887e-01, 9.6777e-01, 9.6665e-01,
        9.6551e-01, 9.6436e-01, 9.6319e-01, 9.6200e-01, 9.6079e-01, 9.5956e-01,
        9.5832e-01, 9.5706e-01, 9.5578e-01, 9.5449e-01, 9.5318e-01, 9.5185e-01,
        9.5050e-01, 9.4914e-01, 9.4776e-01, 9.4636e-01, 9.4494e-01, 9.4351e-01,
      

  0%|                                                                                                                                     | 0/94 [00:00<?, ?it/s]

Saved visualize to /home/zhh24/samples/init_visualize.png


epoch 0, loss 0.2792: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:30<00:00,  3.30it/s]
epoch 0, MSE 0.1279, [Valid] 0.1279: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:03<00:00,  6.12it/s]


Saved visualize to /home/zhh24/samples/diffuse_epoch_0.png


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
epoch 1, loss 0.1134: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:29<00:00,  3.29it/s]
epoch 1, MSE 0.0988, [Valid] 0.0988: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:03<00:00,  6.11it/s]
epoch 2, loss 0.0909: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:30<00:00,  3.28it/s]
epoch 2, MSE 0.0820, [Valid] 0.0820: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:03<00:00,  6.08it/s]
epoch 3, loss 0.0804: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:30<

Saved visualize to /home/zhh24/samples/diffuse_epoch_5.png


epoch 6, loss 0.0617: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:29<00:00,  3.30it/s]
epoch 6, MSE 0.0608, [Valid] 0.0608: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:03<00:00,  6.09it/s]
epoch 7, loss 0.0604: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:29<00:00,  3.28it/s]
epoch 7, MSE 0.0574, [Valid] 0.0574: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:03<00:00,  6.11it/s]
epoch 8, loss 0.0576: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:30<00:00,  3.29it/s]
epoch 8, MSE 0.0562, [Valid] 0.0562: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:03<00:00,  6.11it/s]
epoch 9, loss 0.0560: 100%|█

Saved visualize to /home/zhh24/samples/diffuse_epoch_10.png


epoch 11, loss 0.0525: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:30<00:00,  3.30it/s]
epoch 11, MSE 0.0542, [Valid] 0.0542: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:03<00:00,  6.11it/s]
epoch 12, loss 0.0520: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:30<00:00,  3.29it/s]
epoch 12, MSE 0.0513, [Valid] 0.0513: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:03<00:00,  6.06it/s]
epoch 13, loss 0.0505: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:30<00:00,  3.29it/s]
epoch 13, MSE 0.0520, [Valid] 0.0520: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:03<00:00,  6.13it/s]
epoch 14, loss 0.0497: 100%|

Saved visualize to /home/zhh24/samples/diffuse_epoch_15.png


epoch 16, loss 0.0475:  91%|████████████████████████████████████████████████████████████████████████████████████████████▍        | 86/94 [00:27<00:02,  3.14it/s]


KeyboardInterrupt: 