In [1]:
import os
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from typing import List
from abc import abstractmethod
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader


seed = 1130
random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)

setup_name = "DDPM1130"
data_path = "Cloud"
output_path = "./outcome/{}".format(setup_name)
os.makedirs(output_path, exist_ok=True)
path_to_saved_models = os.path.join(output_path, "saved_models")
os.makedirs(path_to_saved_models, exist_ok=True)
path_to_saved_images = os.path.join(output_path, "saved_images")
os.makedirs(path_to_saved_images, exist_ok=True)

NC = 3
IMG_SIZE = 128
EPOCHS = 100
BATCH_SIZE = 32  
LR = 5e-5
TIMESTEPS = 1000
SAVE_FREQ = 20  
SHOW_FREQ = 10 
VAR_SCHEDULER = "cosine" 
NFAKE = 500  

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CloudDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if
                            f.endswith('.png') or f.endswith('.jpg')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image


def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

class TimestepBlock(nn.Module):
    @abstractmethod
    def forward(self, x, emb):
        pass

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x

def norm_layer(channels):
    return nn.GroupNorm(32, channels)

class ResidualBlock(TimestepBlock):
    def __init__(self, in_channels, out_channels, time_channels, dropout=0.0):
        super().__init__()
        self.conv1 = nn.Sequential(
            norm_layer(in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_channels, out_channels)
        )
        self.conv2 = nn.Sequential(
            norm_layer(out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, t):
        h = self.conv1(x)
        h += self.time_emb(t)[:, :, None, None]
        h = self.conv2(h)
        return h + self.shortcut(x)

class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=1):
        super().__init__()
        self.num_heads = num_heads
        assert channels % num_heads == 0
        self.norm = norm_layer(channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        qkv = self.qkv(self.norm(x))
        q, k, v = qkv.reshape(B * self.num_heads, -1, H * W).chunk(3, dim=1)
        scale = 1. / math.sqrt(math.sqrt(C // self.num_heads))
        attn = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        attn = attn.softmax(dim=-1)
        h = torch.einsum("bts,bcs->bct", attn, v)
        h = h.reshape(B, -1, H, W)
        h = self.proj(h)
        return h + x

class Upsample(nn.Module):
    def __init__(self, channels, use_conv=True):
        super().__init__()
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x

class Downsample(nn.Module):
    def __init__(self, channels, use_conv=True):
        super().__init__()
        self.use_conv = use_conv
        if use_conv:
            self.op = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        else:
            self.op = nn.AvgPool2d(stride=2)

    def forward(self, x):
        return self.op(x)


class UNetModel(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        model_channels: int = 128,
        out_channels: int = 3,
        num_res_blocks: int = 2,
        attention_resolutions: List[int] = [16, 32],
        dropout: float = 0.0,
        channel_mult: List[int] = [1, 2, 4, 8],
        conv_resample: bool = True,
        num_heads: int = 4,
        time_embed_dim: int = 512
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_heads = num_heads
        self.time_embed_dim = time_embed_dim

        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1))
        ])

        current_channel = model_channels
        input_block_chans = [current_channel]
        downsample_layers = len(channel_mult)
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResidualBlock(current_channel, mult * model_channels, time_embed_dim, dropout)
                ]
                current_channel = mult * model_channels
                if level < downsample_layers and (current_channel in attention_resolutions or current_channel // 2 in attention_resolutions):
                    layers.append(AttentionBlock(current_channel, num_heads=num_heads))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(current_channel)
            if level != downsample_layers - 1:
                self.input_blocks.append(TimestepEmbedSequential(Downsample(current_channel, conv_resample)))
                input_block_chans.append(current_channel)

        self.middle_block = TimestepEmbedSequential(
            ResidualBlock(current_channel, current_channel, time_embed_dim, dropout),
            AttentionBlock(current_channel, num_heads=num_heads),
            ResidualBlock(current_channel, current_channel, time_embed_dim, dropout)
        )

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResidualBlock(
                        current_channel + input_block_chans.pop(),
                        model_channels * mult,
                        time_embed_dim,
                        dropout
                    )
                ]
                current_channel = model_channels * mult
                if level and i == num_res_blocks:
                    layers.append(Upsample(current_channel, conv_resample))
                if level < downsample_layers and (current_channel in attention_resolutions or current_channel // 2 in attention_resolutions):
                    layers.append(AttentionBlock(current_channel, num_heads=num_heads))
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        self.out = nn.Sequential(
            norm_layer(current_channel),
            nn.SiLU(),
            nn.Conv2d(current_channel, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x, timesteps):
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        h = x
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)

        h = self.middle_block(h, emb)

        for module in self.output_blocks:
            cat_in = torch.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)

        return self.out(h)


def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


class GaussianDiffusion:
    def __init__(
        self,
        timesteps=1000,
        beta_schedule='linear'
    ):
        self.timesteps = timesteps

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')
        self.betas = betas
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min =1e-20))
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        )

    def _extract(self, a, t, x_shape):
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        out = out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
        return out

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def q_mean_variance(self, x_start, t):
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_posterior_mean_variance(self, x_start, x_t, t):
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def p_mean_variance(self, model, x_t, t, clip_denoised=True):
        pred_noise = model(x_t, t)
        x_recon = self.predict_start_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_recon = torch.clamp(x_recon, min=-1., max=1.)
        model_mean, posterior_variance, posterior_log_variance = \
                    self.q_posterior_mean_variance(x_recon, x_t, t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, model, x_t, t, clip_denoised=True):
        model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t,
                                                    clip_denoised=clip_denoised)
        noise = torch.randn_like(x_t)
        nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))))
        pred_img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred_img

    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        batch_size = shape[0]
        device = next(model.parameters()).device
        img = torch.randn(shape, device=device)
        for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            img = self.p_sample(model, img, torch.full((batch_size,), i, device=device, dtype=torch.long))
        return img.cpu()

    @torch.no_grad()
    def sample(self, model, nfake, image_size, batch_size=100, channels=3, to_numpy=False, unnorm_to_zero2one=True):
        if batch_size > nfake:
            batch_size = nfake
        assert nfake % batch_size == 0

        fake_images = []
        ngot = 0
        while ngot < nfake:
            batch_fake_images = self.p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
            fake_images.append(batch_fake_images)
            ngot += len(batch_fake_images)
            print("\r Got {}/{} fake images.".format(ngot, nfake))
        fake_images = torch.cat(fake_images, dim=0)

        if unnorm_to_zero2one:
            fake_images = (fake_images + 1) * 0.5
        if to_numpy:
            fake_images = fake_images.numpy()

        return fake_images

    def train_losses(self, model, x_start, t):
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise=noise)
        predicted_noise = model(x_noisy, t)
        loss = F.mse_loss(noise, predicted_noise)
        return loss


transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

cloud_dataset = CloudDataset(root_dir=data_path, transform=transform)
train_loader = DataLoader(cloud_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

model = UNetModel(
    in_channels=NC,
    model_channels=64,
    out_channels=NC,
    channel_mult=(1, 2, 4),
    attention_resolutions=[8]
).to(device)

gaussian_diffusion = GaussianDiffusion(timesteps=TIMESTEPS, beta_schedule=VAR_SCHEDULER)
optimizer = optim.Adam(model.parameters(), lr=LR)


start_time = time.time()
for epoch in range(EPOCHS):
    model.train()
    for step, images in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.to(device)
        t = torch.randint(0, TIMESTEPS, (images.shape[0],), device=device).long()
        loss = gaussian_diffusion.train_losses(model, images, t)
        loss.backward()
        optimizer.step()
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f}, Time: {time.time() - start_time:.4f}s")

    if (epoch + 1) % SAVE_FREQ == 0:
        save_file = os.path.join(path_to_saved_models, f"ckpt_epoch_{epoch+1}.pth")
        torch.save({
            'model': model.state_dict(),
        }, save_file)

    if (epoch + 1) % SHOW_FREQ == 0:
        model.eval()
        with torch.no_grad():
            gen_imgs = gaussian_diffusion.sample(model, nfake=9, image_size=IMG_SIZE, batch_size=9, channels=NC, to_numpy=False, unnorm_to_zero2one=True)
            save_file = os.path.join(path_to_saved_images, f'imgs_in_train_epoch_{epoch+1}.jpg')
            utils.save_image(gen_imgs, save_file, nrow=3, normalize=True)  
            

test_output_path = os.path.join(output_path, "result")
os.makedirs(test_output_path, exist_ok=True)
test_images = gaussian_diffusion.sample(model, NFAKE, image_size=IMG_SIZE, batch_size=100, channels=NC, to_numpy=False, unnorm_to_zero2one=True)
for i, img in enumerate(test_images):
    save_file = os.path.join(test_output_path, f'gen_img_{i+1:03d}.jpg')
    utils.save_image(img, save_file, normalize=True)

Epoch [1/100], Loss: 0.4304, Time: 19.4753s
Epoch [2/100], Loss: 0.3557, Time: 38.5720s
Epoch [3/100], Loss: 0.0601, Time: 57.6735s
Epoch [4/100], Loss: 0.0606, Time: 76.7273s
Epoch [5/100], Loss: 0.0411, Time: 95.8073s
Epoch [6/100], Loss: 0.0308, Time: 114.8980s
Epoch [7/100], Loss: 0.0427, Time: 134.0023s
Epoch [8/100], Loss: 0.0268, Time: 153.1045s
Epoch [9/100], Loss: 0.0523, Time: 172.2829s
Epoch [10/100], Loss: 0.0226, Time: 191.4206s


sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.03it/s]

 Got 9/9 fake images.





Epoch [11/100], Loss: 0.1452, Time: 254.0299s
Epoch [12/100], Loss: 0.0173, Time: 273.1156s
Epoch [13/100], Loss: 0.0248, Time: 292.2285s
Epoch [14/100], Loss: 0.0897, Time: 311.2967s
Epoch [15/100], Loss: 0.0186, Time: 330.4343s
Epoch [16/100], Loss: 0.2454, Time: 349.4972s
Epoch [17/100], Loss: 0.0141, Time: 368.5627s
Epoch [18/100], Loss: 0.0455, Time: 387.6396s
Epoch [19/100], Loss: 0.0263, Time: 406.7111s
Epoch [20/100], Loss: 0.0206, Time: 425.7963s


sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.02it/s]


 Got 9/9 fake images.
Epoch [21/100], Loss: 0.0233, Time: 488.5989s
Epoch [22/100], Loss: 0.1440, Time: 507.6994s
Epoch [23/100], Loss: 0.0420, Time: 526.7880s
Epoch [24/100], Loss: 0.0263, Time: 545.8663s
Epoch [25/100], Loss: 0.0915, Time: 564.9412s
Epoch [26/100], Loss: 0.0151, Time: 584.0369s
Epoch [27/100], Loss: 0.0314, Time: 603.1167s
Epoch [28/100], Loss: 0.0387, Time: 622.1889s
Epoch [29/100], Loss: 0.0286, Time: 641.2782s
Epoch [30/100], Loss: 0.0938, Time: 660.3648s


sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.03it/s]


 Got 9/9 fake images.
Epoch [31/100], Loss: 0.1155, Time: 722.9131s
Epoch [32/100], Loss: 0.0257, Time: 741.9967s
Epoch [33/100], Loss: 0.0488, Time: 761.0916s
Epoch [34/100], Loss: 0.0125, Time: 780.2044s
Epoch [35/100], Loss: 0.1350, Time: 799.2690s
Epoch [36/100], Loss: 0.0306, Time: 818.3507s
Epoch [37/100], Loss: 0.0208, Time: 837.4302s
Epoch [38/100], Loss: 0.0717, Time: 856.4821s
Epoch [39/100], Loss: 0.1085, Time: 875.5561s
Epoch [40/100], Loss: 0.0089, Time: 894.6443s


sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.03it/s]


 Got 9/9 fake images.
Epoch [41/100], Loss: 0.0367, Time: 957.3359s
Epoch [42/100], Loss: 0.0098, Time: 976.4214s
Epoch [43/100], Loss: 0.0102, Time: 995.4991s
Epoch [44/100], Loss: 0.0106, Time: 1014.5790s
Epoch [45/100], Loss: 0.0114, Time: 1033.6656s
Epoch [46/100], Loss: 0.0832, Time: 1052.7524s
Epoch [47/100], Loss: 0.0085, Time: 1071.8646s
Epoch [48/100], Loss: 0.0225, Time: 1090.9447s
Epoch [49/100], Loss: 0.0152, Time: 1110.0363s
Epoch [50/100], Loss: 0.0136, Time: 1129.1101s


sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.04it/s]


 Got 9/9 fake images.
Epoch [51/100], Loss: 0.0070, Time: 1191.6908s
Epoch [52/100], Loss: 0.0063, Time: 1210.7739s
Epoch [53/100], Loss: 0.0179, Time: 1229.8639s
Epoch [54/100], Loss: 0.0072, Time: 1248.9771s
Epoch [55/100], Loss: 0.0622, Time: 1268.0371s
Epoch [56/100], Loss: 0.1754, Time: 1287.1550s
Epoch [57/100], Loss: 0.0345, Time: 1306.3165s
Epoch [58/100], Loss: 0.0398, Time: 1325.4009s
Epoch [59/100], Loss: 0.0315, Time: 1344.5258s
Epoch [60/100], Loss: 0.0045, Time: 1363.6449s


sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.07it/s]


 Got 9/9 fake images.
Epoch [61/100], Loss: 0.1019, Time: 1426.2492s
Epoch [62/100], Loss: 0.0664, Time: 1445.3292s
Epoch [63/100], Loss: 0.0114, Time: 1464.4120s
Epoch [64/100], Loss: 0.0203, Time: 1483.5283s
Epoch [65/100], Loss: 0.0173, Time: 1502.6239s
Epoch [66/100], Loss: 0.0132, Time: 1521.7185s
Epoch [67/100], Loss: 0.0120, Time: 1540.8190s
Epoch [68/100], Loss: 0.0163, Time: 1559.9384s
Epoch [69/100], Loss: 0.0115, Time: 1579.0400s
Epoch [70/100], Loss: 0.0389, Time: 1598.2064s


sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.05it/s]


 Got 9/9 fake images.
Epoch [71/100], Loss: 0.0421, Time: 1660.7501s
Epoch [72/100], Loss: 0.0275, Time: 1679.8868s
Epoch [73/100], Loss: 0.0158, Time: 1698.9933s
Epoch [74/100], Loss: 0.0144, Time: 1718.1420s
Epoch [75/100], Loss: 0.0112, Time: 1737.2430s
Epoch [76/100], Loss: 0.0196, Time: 1756.3547s
Epoch [77/100], Loss: 0.0193, Time: 1775.4470s
Epoch [78/100], Loss: 0.0064, Time: 1794.5394s
Epoch [79/100], Loss: 0.0218, Time: 1813.6415s
Epoch [80/100], Loss: 0.0194, Time: 1832.7402s


sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.04it/s]


 Got 9/9 fake images.
Epoch [81/100], Loss: 0.0066, Time: 1895.3923s
Epoch [82/100], Loss: 0.0074, Time: 1914.4966s
Epoch [83/100], Loss: 0.0309, Time: 1933.5918s
Epoch [84/100], Loss: 0.0504, Time: 1952.6754s
Epoch [85/100], Loss: 0.0173, Time: 1971.7595s
Epoch [86/100], Loss: 0.0181, Time: 1990.9254s
Epoch [87/100], Loss: 0.0334, Time: 2010.0862s
Epoch [88/100], Loss: 0.0197, Time: 2029.2485s
Epoch [89/100], Loss: 0.0154, Time: 2048.3640s
Epoch [90/100], Loss: 0.0165, Time: 2067.5241s


sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.05it/s]


 Got 9/9 fake images.
Epoch [91/100], Loss: 0.0140, Time: 2130.0870s
Epoch [92/100], Loss: 0.0106, Time: 2149.1900s
Epoch [93/100], Loss: 0.0048, Time: 2168.2835s
Epoch [94/100], Loss: 0.0539, Time: 2187.3724s
Epoch [95/100], Loss: 0.0221, Time: 2206.5455s
Epoch [96/100], Loss: 0.0701, Time: 2225.6264s
Epoch [97/100], Loss: 0.0119, Time: 2244.7693s
Epoch [98/100], Loss: 0.0306, Time: 2263.8735s
Epoch [99/100], Loss: 0.0061, Time: 2282.9778s
Epoch [100/100], Loss: 0.0075, Time: 2302.0613s


sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.04it/s]


 Got 9/9 fake images.


sampling loop time step: 100%|██████████| 1000/1000 [07:23<00:00,  2.26it/s]


 Got 100/500 fake images.


sampling loop time step: 100%|██████████| 1000/1000 [07:23<00:00,  2.26it/s]


 Got 200/500 fake images.


sampling loop time step: 100%|██████████| 1000/1000 [07:23<00:00,  2.26it/s]


 Got 300/500 fake images.


sampling loop time step: 100%|██████████| 1000/1000 [07:23<00:00,  2.26it/s]


 Got 400/500 fake images.


sampling loop time step: 100%|██████████| 1000/1000 [07:23<00:00,  2.26it/s]


 Got 500/500 fake images.
