In [27]:
! pip install -q kaggle
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! pip install opendatasets
! pip install datasets

In [None]:
import opendatasets as od

In [None]:
! pip install lightning dataclasses accelerate diffusers p_tqdm

In [None]:
od.download(
	"https://www.kaggle.com/datasets/mahmoudnoor/high-resolution-catdogbird-image-dataset-13000")


od.download(
	"https://www.kaggle.com/datasets/joaopauloschuler/cifar10-64x64-resized-via-cai-super-resolution")

In [1]:
import os
import torch
from torch import nn
from tqdm import tqdm
from PIL import Image
from p_tqdm import p_map
import lightning.pytorch as pl
import torch.nn.functional as F
from dataclasses import dataclass
from accelerate import Accelerator
from torchvision import transforms, datasets
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, random_split
from lightning.pytorch.loggers import TensorBoardLogger
from diffusers.optimization import get_cosine_schedule_with_warmup
#from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from diffusers import DDPMScheduler

In [3]:
import random
from lightning.pytorch.loggers import WandbLogger
wandb_logger = WandbLogger(project="Diffusion-cat-dog-bird", log_model=True, name='64-image-custom-ema')

import torchvision.transforms as T
transform = T.ToPILImage()

import copy

In [4]:
@dataclass
class TrainingConfig:
    image_size = 64  # the generated image resolution
    train_batch_size = 16
    eval_batch_size = train_batch_size  # how many images to sample during evaluation
    num_epochs = 100
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 10
    num_workers = 12
    save_model_epochs = 30
    mixed_precision = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = 'cat-dogs-birds'  # the model namy locally and on the HF Hub
    accelerator = Accelerator(mixed_precision=mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)

    push_to_hub = False  # whether to upload the saved model to the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0

config = TrainingConfig()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [5]:
train_transform = transforms.Compose(
            [
                transforms.Resize((config.image_size, config.image_size)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(10),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

test_transform = transforms.Compose(
            [
                transforms.Resize((config.image_size, config.image_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

train_dataset = datasets.ImageFolder(root='high-resolution-catdogbird-image-dataset-13000', transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers)

In [6]:

class SelfAttention(nn.Module):
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        size = x.shape[-1]
        x = x.view(-1, self.channels, size * size).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, size, size)


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        if self.residual:
            return F.gelu(x + self.double_conv(x))
        else:
            return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_dim,
                out_channels
            ),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_dim,
                out_channels
            ),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class UNet(nn.Module):
    def __init__(self, c_in=3, c_out=3, time_dim=256, remove_deep_conv=False):
        super().__init__()
        self.time_dim = time_dim
        self.remove_deep_conv = remove_deep_conv
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256)


        if remove_deep_conv:
            self.bot1 = DoubleConv(256, 256)
            self.bot3 = DoubleConv(256, 256)
        else:
            self.bot1 = DoubleConv(256, 512)
            self.bot2 = DoubleConv(512, 512)
            self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=one_param(self).device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc
        

    def unet_forwad(self, x, t):
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        if not self.remove_deep_conv:
            x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output

    def forward(self, x, t):
        t = t.unsqueeze(-1)
        t = self.pos_encoding(t, self.time_dim)
        return self.unet_forwad(x, t)


class UNet_conditional(UNet):
    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, **kwargs):
        super().__init__(c_in, c_out, time_dim, **kwargs)
        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def forward(self, x, t, y=None):
        t = t.unsqueeze(-1)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            t += self.label_emb(y)

        return self.unet_forwad(x, t)


def one_param(m):
    "get model first parameter"
    return next(iter(m.parameters()))


unet = UNet_conditional(c_in=3, c_out=3, time_dim=256, num_classes=3, remove_deep_conv=False)

In [10]:
class DiffusionModel(pl.LightningModule):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.last_image = ""
        #self.label_projection = nn.Embedding(3, 64)

        # Initialize UNet model
        self.unet = unet

        # Initialize noise scheduler
        self.noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

        # Initialize EMA
        self.ema_decay = 0.995  # You should define this in your config
        self.ema_unet = copy.deepcopy(self.unet)
        for param in self.ema_unet.parameters():
            param.requires_grad_(False)

    def update_ema(self):
        with torch.no_grad():
            model_params = dict(self.unet.named_parameters())
            ema_params = dict(self.ema_unet.named_parameters())
            for name in model_params:
                model_param = model_params[name]
                ema_param = ema_params[name]
                ema_param.copy_(ema_param * self.ema_decay + (1.0 - self.ema_decay) * model_param)

    def forward(self, x, timesteps, hidden_embed):
        # Forward pass through UNet
        noise_pred = self.unet(x, timesteps, hidden_embed)
        #noise_pred = self.unet(x, timesteps, encoder_hidden_states=hidden_embed.unsqueeze(1))["sample"]
        return noise_pred

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.unet.parameters(), lr=self.config.learning_rate)
        scheduler = OneCycleLR(
                        optimizer,
                        max_lr=self.config.learning_rate,
                        pct_start=6/self.trainer.max_epochs,
                        epochs=self.trainer.max_epochs,
                        steps_per_epoch=len(train_loader),
                        anneal_strategy='cos',
                        div_factor=100, #data_module.dataset.num_rows
                        final_div_factor=10,
                        #three_phase=True
                )
        return {'optimizer': optimizer,
                'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
        #return(optimizer)

    def process_batch(self, batch):

        clean_images, targets = batch

        # Encode the clean images to obtain latent representations
        # with torch.no_grad():
        #     clean_latent = self.vae.encode(clean_images.to(torch.float16)).latent_dist.sample()
        #     clean_latent = ((clean_latent * 2) - 1) * 0.18215

        clean_latent = clean_images

        # Sample noise to add to the images
        noise = torch.randn(clean_latent.shape, device=self.device)

        # Sample a random timestep for each image
        bs = clean_images.size(0)
        timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bs,), device=self.device).long()

        # Add noise to the clean latent representations (forward diffusion)
        noisy_images = self.noise_scheduler.add_noise(clean_latent, noise, timesteps)

        # Label Projection
        #target_embed = self.label_projection(targets)

        # Predict the noise residual
        noise_pred = self(noisy_images, timesteps, targets)  # Calls the forward method

        # Compute the loss
        loss = F.mse_loss(noise_pred, noise)

        self.last_image = (clean_images[-1], targets[-1])

        return loss

    def training_step(self, batch, batch_idx):
        loss = self.process_batch(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.update_ema()
        return loss

    # def validation_step(self, batch, batch_idx):
    #     val_loss = self.process_batch(batch)
    #     self.log('val_loss', val_loss, on_epoch=True, prog_bar=True, logger=True)
    #     return val_loss
    
    def latents_to_pil(self, latents):
        # batch of latents -> list of images
        # latents = (1 / 0.18215) * latents
        # with torch.no_grad():
        #     image = self.vae.decode(latents.to(torch.float16)).sample

        image = latents
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images][0]
        return pil_images

    def generate_and_save_images(self, current_epoch):
        generator = torch.manual_seed(42)
        noise = torch.randn(1, 3, config.image_size, config.image_size, generator=generator).cuda()
        #label = self.label_projection(torch.tensor(random.randint(0, 2)).cuda())
        noise = noise * self.noise_scheduler.init_noise_sigma

        for i, t in tqdm(enumerate(self.noise_scheduler.timesteps), total=len(self.noise_scheduler.timesteps)):
            noise = self.noise_scheduler.scale_model_input(noise, t)

            with torch.no_grad():
                noise_pred = self.unet(noise, torch.tensor(t).to(noise.device), torch.tensor([random.randint(0, 2)]).to(noise.device))
                #noise_pred = self.unet(noise, t, encoder_hidden_states=label.unsqueeze(0).unsqueeze(0))["sample"]

            noise = self.noise_scheduler.step(noise_pred, t.long(), noise).prev_sample

        decoded_noise = self.latents_to_pil(noise)

        # with torch.no_grad():
        #     decoded_noise = diffusion_model.vae.decode(noise.to(torch.float16)).sample

        #decoded_noise = transform(decoded_noise[0])

        #decoded_noise = pil_images[0]
        
        # Save the generated images to a unique folder for the current epoch
        save_dir = f"generated_images/epoch_{current_epoch}"
        os.makedirs(save_dir, exist_ok=True)
        
        decoded_noise.save(os.path.join(save_dir, f"image.png"))
        wandb_logger.log_image(key=f"generated_epoch_{current_epoch}", images=[decoded_noise])
        #self.logger.log_image(f"generated_epoch_{current_epoch}", [decoded_noise,]) 

    def apply_transform(self, image):
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
        return(pil_images)
    
    def generate_visualize(self, current_epoch):
        clean_images, targets = self.last_image
        clean_images = clean_images.unsqueeze(0)

        with torch.no_grad():
            clean_latent = self.vae.encode(clean_images.to(torch.float16)).latent_dist.sample()
            image = self.vae.decode(clean_latent.to(torch.float16)).sample
            
        save_dir = f"coded_images/epoch_{current_epoch}"
        save_dir_orig = f"original_images/epoch_{current_epoch}"

        clean_images = self.apply_transform(clean_images)[0]
        image = self.apply_transform(image)[0]

        os.makedirs(save_dir, exist_ok=True)
        os.makedirs(save_dir_orig, exist_ok=True)

        image.save(os.path.join(save_dir, f"image.png"))
        clean_images.save(os.path.join(save_dir_orig, f"image.png"))
            
    def on_train_epoch_end(self):
        self.generate_and_save_images(self.current_epoch)
        #self.generate_visualize(self.current_epoch)

diffusion_model = DiffusionModel(config)

In [11]:
trainer = pl.Trainer(
    precision='16-mixed' if config.mixed_precision=='fp16' else 32,  # Set precision
    accelerator='auto',
    devices='auto',
    strategy='auto',
    max_epochs=config.num_epochs,
    logger=[TensorBoardLogger("logs/", name="stable-diffusion"), wandb_logger],
    callbacks=[LearningRateMonitor(logging_interval="step"), ModelCheckpoint(monitor="train_loss_epoch", mode="min")],
    #limit_train_batches=0.3, 
    
)
torch.set_float32_matmul_precision('high')

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
trainer.fit(diffusion_model, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | unet     | UNet_conditional | 23.3 M
1 | ema_unet | UNet_conditional | 23.3 M
----------------------------------------------
23.3 M    Trainable params
23.3 M    Non-trainable params
46.7 M    Total params
186.668   Total estimated model params size (MB)


Epoch 0: 100%|██████████| 834/834 [01:50<00:00,  7.54it/s, v_num=8jcv, train_loss_step=0.100, train_loss_epoch=0.171]

  noise_pred = self.unet(noise, torch.tensor(t).to(noise.device), torch.tensor([random.randint(0, 2)]).to(noise.device))
100%|██████████| 1000/1000 [00:07<00:00, 125.49it/s]


Epoch 1:   0%|          | 1/834 [00:01<18:03,  0.77it/s, v_num=8jcv, train_loss_step=0.145, train_loss_epoch=0.171]  

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


Epoch 1: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0783, train_loss_epoch=0.104]

  noise_pred = self.unet(noise, torch.tensor(t).to(noise.device), torch.tensor([random.randint(0, 2)]).to(noise.device))
100%|██████████| 1000/1000 [00:07<00:00, 128.54it/s]


Epoch 2: 100%|██████████| 834/834 [01:51<00:00,  7.48it/s, v_num=8jcv, train_loss_step=0.0506, train_loss_epoch=0.0692]

100%|██████████| 1000/1000 [00:07<00:00, 128.06it/s]


Epoch 3: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.041, train_loss_epoch=0.0515] 

100%|██████████| 1000/1000 [00:07<00:00, 128.63it/s]


Epoch 4: 100%|██████████| 834/834 [01:52<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.034, train_loss_epoch=0.0436] 

100%|██████████| 1000/1000 [00:07<00:00, 128.95it/s]


Epoch 5: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0307, train_loss_epoch=0.0392] 

100%|██████████| 1000/1000 [00:07<00:00, 128.40it/s]


Epoch 6: 100%|██████████| 834/834 [01:51<00:00,  7.49it/s, v_num=8jcv, train_loss_step=0.0282, train_loss_epoch=0.0364] 

100%|██████████| 1000/1000 [00:07<00:00, 130.88it/s]


Epoch 7: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0269, train_loss_epoch=0.0346] 

100%|██████████| 1000/1000 [00:07<00:00, 129.38it/s]


Epoch 8: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.026, train_loss_epoch=0.0331]  

100%|██████████| 1000/1000 [00:07<00:00, 127.82it/s]


Epoch 9: 100%|██████████| 834/834 [01:51<00:00,  7.48it/s, v_num=8jcv, train_loss_step=0.0252, train_loss_epoch=0.032]  

100%|██████████| 1000/1000 [00:07<00:00, 130.26it/s]


Epoch 10: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0247, train_loss_epoch=0.0311]

100%|██████████| 1000/1000 [00:07<00:00, 130.69it/s]


Epoch 11: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0241, train_loss_epoch=0.0303] 

100%|██████████| 1000/1000 [00:07<00:00, 128.58it/s]


Epoch 12: 100%|██████████| 834/834 [01:51<00:00,  7.48it/s, v_num=8jcv, train_loss_step=0.0236, train_loss_epoch=0.0296] 

100%|██████████| 1000/1000 [00:07<00:00, 128.17it/s]


Epoch 13: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0232, train_loss_epoch=0.029]  

100%|██████████| 1000/1000 [00:07<00:00, 130.44it/s]


Epoch 14: 100%|██████████| 834/834 [01:51<00:00,  7.49it/s, v_num=8jcv, train_loss_step=0.0229, train_loss_epoch=0.0286]

100%|██████████| 1000/1000 [00:07<00:00, 128.81it/s]


Epoch 15: 100%|██████████| 834/834 [01:51<00:00,  7.49it/s, v_num=8jcv, train_loss_step=0.0226, train_loss_epoch=0.0281] 

100%|██████████| 1000/1000 [00:07<00:00, 130.41it/s]


Epoch 16: 100%|██████████| 834/834 [01:51<00:00,  7.48it/s, v_num=8jcv, train_loss_step=0.0223, train_loss_epoch=0.0277] 

100%|██████████| 1000/1000 [00:07<00:00, 128.66it/s]


Epoch 17: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0218, train_loss_epoch=0.0272] 

100%|██████████| 1000/1000 [00:07<00:00, 128.65it/s]


Epoch 18: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0215, train_loss_epoch=0.0269] 

100%|██████████| 1000/1000 [00:07<00:00, 130.64it/s]


Epoch 19: 100%|██████████| 834/834 [01:51<00:00,  7.49it/s, v_num=8jcv, train_loss_step=0.0213, train_loss_epoch=0.0267] 

100%|██████████| 1000/1000 [00:07<00:00, 130.45it/s]


Epoch 20: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0212, train_loss_epoch=0.0262] 

100%|██████████| 1000/1000 [00:07<00:00, 127.58it/s]


Epoch 21: 100%|██████████| 834/834 [01:52<00:00,  7.44it/s, v_num=8jcv, train_loss_step=0.0204, train_loss_epoch=0.0259] 

100%|██████████| 1000/1000 [00:07<00:00, 129.03it/s]


Epoch 22: 100%|██████████| 834/834 [01:51<00:00,  7.48it/s, v_num=8jcv, train_loss_step=0.0201, train_loss_epoch=0.0255] 

100%|██████████| 1000/1000 [00:07<00:00, 128.56it/s]


Epoch 23: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0195, train_loss_epoch=0.0251] 

100%|██████████| 1000/1000 [00:07<00:00, 130.41it/s]


Epoch 24: 100%|██████████| 834/834 [01:52<00:00,  7.44it/s, v_num=8jcv, train_loss_step=0.0192, train_loss_epoch=0.0249] 

100%|██████████| 1000/1000 [00:07<00:00, 130.43it/s]


Epoch 25: 100%|██████████| 834/834 [01:52<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0192, train_loss_epoch=0.0245] 

100%|██████████| 1000/1000 [00:07<00:00, 128.24it/s]


Epoch 26: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0187, train_loss_epoch=0.0241] 

100%|██████████| 1000/1000 [00:07<00:00, 128.84it/s]


Epoch 27: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0181, train_loss_epoch=0.0238] 

100%|██████████| 1000/1000 [00:07<00:00, 128.08it/s]


Epoch 28: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0182, train_loss_epoch=0.0234] 

100%|██████████| 1000/1000 [00:07<00:00, 128.73it/s]


Epoch 29: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0178, train_loss_epoch=0.0232] 

100%|██████████| 1000/1000 [00:07<00:00, 128.07it/s]


Epoch 30: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0173, train_loss_epoch=0.0227] 

100%|██████████| 1000/1000 [00:07<00:00, 130.17it/s]


Epoch 31: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.017, train_loss_epoch=0.0223]  

100%|██████████| 1000/1000 [00:07<00:00, 130.40it/s]


Epoch 32: 100%|██████████| 834/834 [01:52<00:00,  7.43it/s, v_num=8jcv, train_loss_step=0.017, train_loss_epoch=0.022]   

100%|██████████| 1000/1000 [00:07<00:00, 130.44it/s]


Epoch 33: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0166, train_loss_epoch=0.0217]

100%|██████████| 1000/1000 [00:07<00:00, 129.64it/s]


Epoch 34: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0163, train_loss_epoch=0.0215] 

100%|██████████| 1000/1000 [00:07<00:00, 128.08it/s]


Epoch 35: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0161, train_loss_epoch=0.0212] 

100%|██████████| 1000/1000 [00:07<00:00, 130.95it/s]


Epoch 36: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0166, train_loss_epoch=0.0209] 

100%|██████████| 1000/1000 [00:07<00:00, 130.81it/s]


Epoch 37: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0158, train_loss_epoch=0.0206] 

100%|██████████| 1000/1000 [00:07<00:00, 129.05it/s]


Epoch 38: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0155, train_loss_epoch=0.0203] 

100%|██████████| 1000/1000 [00:07<00:00, 128.01it/s]


Epoch 39: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0152, train_loss_epoch=0.020]  

100%|██████████| 1000/1000 [00:07<00:00, 130.28it/s]


Epoch 40: 100%|██████████| 834/834 [01:52<00:00,  7.44it/s, v_num=8jcv, train_loss_step=0.015, train_loss_epoch=0.0197] 

100%|██████████| 1000/1000 [00:07<00:00, 130.49it/s]


Epoch 41: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0147, train_loss_epoch=0.0194] 

100%|██████████| 1000/1000 [00:07<00:00, 128.34it/s]


Epoch 42: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0145, train_loss_epoch=0.0191] 

100%|██████████| 1000/1000 [00:07<00:00, 128.37it/s]


Epoch 43: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0142, train_loss_epoch=0.0189] 

100%|██████████| 1000/1000 [00:07<00:00, 130.55it/s]


Epoch 44: 100%|██████████| 834/834 [01:52<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0144, train_loss_epoch=0.0187] 

100%|██████████| 1000/1000 [00:07<00:00, 131.02it/s]


Epoch 45: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.014, train_loss_epoch=0.0186]  

100%|██████████| 1000/1000 [00:07<00:00, 129.37it/s]


Epoch 46: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0136, train_loss_epoch=0.0184] 

100%|██████████| 1000/1000 [00:07<00:00, 128.43it/s]


Epoch 47: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0136, train_loss_epoch=0.0181] 

100%|██████████| 1000/1000 [00:07<00:00, 128.91it/s]


Epoch 48: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0137, train_loss_epoch=0.0178] 

100%|██████████| 1000/1000 [00:07<00:00, 130.86it/s]


Epoch 49: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0133, train_loss_epoch=0.0176] 

100%|██████████| 1000/1000 [00:07<00:00, 128.29it/s]


Epoch 50: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0134, train_loss_epoch=0.0173] 

100%|██████████| 1000/1000 [00:07<00:00, 128.55it/s]


Epoch 51: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0135, train_loss_epoch=0.0173] 

100%|██████████| 1000/1000 [00:07<00:00, 128.50it/s]


Epoch 52: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0132, train_loss_epoch=0.0171] 

100%|██████████| 1000/1000 [00:07<00:00, 128.63it/s]


Epoch 53: 100%|██████████| 834/834 [01:52<00:00,  7.43it/s, v_num=8jcv, train_loss_step=0.013, train_loss_epoch=0.0169]  

100%|██████████| 1000/1000 [00:07<00:00, 129.71it/s]


Epoch 54: 100%|██████████| 834/834 [01:51<00:00,  7.48it/s, v_num=8jcv, train_loss_step=0.0132, train_loss_epoch=0.0167] 

100%|██████████| 1000/1000 [00:07<00:00, 131.00it/s]


Epoch 55: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0131, train_loss_epoch=0.0165] 

100%|██████████| 1000/1000 [00:07<00:00, 128.13it/s]


Epoch 56: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0128, train_loss_epoch=0.0163] 

100%|██████████| 1000/1000 [00:07<00:00, 130.74it/s]


Epoch 57: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0122, train_loss_epoch=0.0162] 

100%|██████████| 1000/1000 [00:07<00:00, 128.01it/s]


Epoch 58: 100%|██████████| 834/834 [01:52<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0121, train_loss_epoch=0.016]  

100%|██████████| 1000/1000 [00:07<00:00, 129.18it/s]


Epoch 59: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.012, train_loss_epoch=0.0157] 

100%|██████████| 1000/1000 [00:07<00:00, 129.29it/s]


Epoch 60: 100%|██████████| 834/834 [01:51<00:00,  7.49it/s, v_num=8jcv, train_loss_step=0.0121, train_loss_epoch=0.0156] 

100%|██████████| 1000/1000 [00:07<00:00, 129.83it/s]


Epoch 61: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0123, train_loss_epoch=0.0155] 

100%|██████████| 1000/1000 [00:07<00:00, 128.82it/s]


Epoch 62: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0121, train_loss_epoch=0.0155] 

100%|██████████| 1000/1000 [00:07<00:00, 128.51it/s]


Epoch 63: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0117, train_loss_epoch=0.0153] 

100%|██████████| 1000/1000 [00:07<00:00, 130.76it/s]


Epoch 64: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0114, train_loss_epoch=0.015]  

100%|██████████| 1000/1000 [00:07<00:00, 128.60it/s]


Epoch 65: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0113, train_loss_epoch=0.0148]

100%|██████████| 1000/1000 [00:07<00:00, 128.52it/s]


Epoch 66: 100%|██████████| 834/834 [01:52<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0113, train_loss_epoch=0.0147] 

100%|██████████| 1000/1000 [00:07<00:00, 128.74it/s]


Epoch 67: 100%|██████████| 834/834 [01:51<00:00,  7.49it/s, v_num=8jcv, train_loss_step=0.0112, train_loss_epoch=0.0146] 

100%|██████████| 1000/1000 [00:07<00:00, 131.43it/s]


Epoch 68: 100%|██████████| 834/834 [01:52<00:00,  7.44it/s, v_num=8jcv, train_loss_step=0.0113, train_loss_epoch=0.0145] 

100%|██████████| 1000/1000 [00:07<00:00, 129.38it/s]


Epoch 69: 100%|██████████| 834/834 [01:51<00:00,  7.48it/s, v_num=8jcv, train_loss_step=0.0111, train_loss_epoch=0.0145] 

100%|██████████| 1000/1000 [00:07<00:00, 130.43it/s]


Epoch 70: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0111, train_loss_epoch=0.0143] 

100%|██████████| 1000/1000 [00:07<00:00, 128.88it/s]


Epoch 71: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0112, train_loss_epoch=0.0142] 

100%|██████████| 1000/1000 [00:07<00:00, 130.96it/s]


Epoch 72: 100%|██████████| 834/834 [01:52<00:00,  7.44it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0141]  

100%|██████████| 1000/1000 [00:07<00:00, 128.01it/s]


Epoch 73: 100%|██████████| 834/834 [01:52<00:00,  7.44it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.014]   

100%|██████████| 1000/1000 [00:07<00:00, 128.08it/s]


Epoch 74: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0139]

100%|██████████| 1000/1000 [00:07<00:00, 131.09it/s]


Epoch 75: 100%|██████████| 834/834 [01:51<00:00,  7.49it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0138] 

100%|██████████| 1000/1000 [00:07<00:00, 128.38it/s]


Epoch 76: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0138] 

100%|██████████| 1000/1000 [00:07<00:00, 128.04it/s]


Epoch 77: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0137] 

100%|██████████| 1000/1000 [00:07<00:00, 128.68it/s]


Epoch 78: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0136] 

100%|██████████| 1000/1000 [00:07<00:00, 128.05it/s]


Epoch 79: 100%|██████████| 834/834 [01:51<00:00,  7.48it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0136] 

100%|██████████| 1000/1000 [00:07<00:00, 128.78it/s]


Epoch 80: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0135] 

100%|██████████| 1000/1000 [00:07<00:00, 128.64it/s]


Epoch 81: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0135] 

100%|██████████| 1000/1000 [00:07<00:00, 128.20it/s]


Epoch 82: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0135] 

100%|██████████| 1000/1000 [00:07<00:00, 128.35it/s]


Epoch 83: 100%|██████████| 834/834 [01:50<00:00,  7.52it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0135] 

100%|██████████| 1000/1000 [00:07<00:00, 130.42it/s]


Epoch 84: 100%|██████████| 834/834 [01:51<00:00,  7.48it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0135] 

100%|██████████| 1000/1000 [00:07<00:00, 128.64it/s]

Epoch 85:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0135]          




Epoch 85: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0135] 

100%|██████████| 1000/1000 [00:07<00:00, 129.12it/s]

Epoch 86:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0135]          




Epoch 86: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0135] 

100%|██████████| 1000/1000 [00:07<00:00, 128.73it/s]

Epoch 87:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.0108, train_loss_epoch=0.0135]          




Epoch 87: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0136] 

100%|██████████| 1000/1000 [00:07<00:00, 130.18it/s]

Epoch 88:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0136]          




Epoch 88: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0136] 

100%|██████████| 1000/1000 [00:07<00:00, 131.16it/s]

Epoch 89:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0136]          




Epoch 89: 100%|██████████| 834/834 [01:52<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0136]  

100%|██████████| 1000/1000 [00:07<00:00, 128.50it/s]

Epoch 90:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0136]          




Epoch 90: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0137]  

100%|██████████| 1000/1000 [00:07<00:00, 129.73it/s]

Epoch 91:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0137]          




Epoch 91: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0138]  

100%|██████████| 1000/1000 [00:07<00:00, 128.07it/s]

Epoch 92:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0138]          




Epoch 92: 100%|██████████| 834/834 [01:51<00:00,  7.51it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0139] 

100%|██████████| 1000/1000 [00:07<00:00, 130.86it/s]

Epoch 93:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0139]          




Epoch 93: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0139] 

100%|██████████| 1000/1000 [00:07<00:00, 127.60it/s]

Epoch 94:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.0139]          




Epoch 94: 100%|██████████| 834/834 [01:51<00:00,  7.47it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.014]  

100%|██████████| 1000/1000 [00:07<00:00, 130.46it/s]

Epoch 95:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.0109, train_loss_epoch=0.014]          




Epoch 95: 100%|██████████| 834/834 [01:52<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0141] 

100%|██████████| 1000/1000 [00:07<00:00, 128.21it/s]


Epoch 96: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.0111, train_loss_epoch=0.0142] 

100%|██████████| 1000/1000 [00:07<00:00, 127.62it/s]

Epoch 97:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.0111, train_loss_epoch=0.0142]          




Epoch 97: 100%|██████████| 834/834 [01:51<00:00,  7.45it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0143]  

100%|██████████| 1000/1000 [00:07<00:00, 128.04it/s]

Epoch 98:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0143]          




Epoch 98: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0143]  

100%|██████████| 1000/1000 [00:07<00:00, 130.99it/s]

Epoch 99:   0%|          | 0/834 [00:00<?, ?it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0143]          




Epoch 99: 100%|██████████| 834/834 [01:51<00:00,  7.46it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0143]  

100%|██████████| 1000/1000 [00:07<00:00, 130.35it/s]
`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 834/834 [01:59<00:00,  6.98it/s, v_num=8jcv, train_loss_step=0.011, train_loss_epoch=0.0143]


In [None]:
diffusion_model = DiffusionModel.load_from_checkpoint('logs/stable-diffusion/version_0/checkpoints/epoch=7-step=320.ckpt', config=config)

In [10]:
diffusion_model.device

device(type='cpu')

In [38]:
generator = torch.manual_seed(42)
diffusion_model.unet = diffusion_model.unet.cuda()
noise = torch.randn(1, 3, config.image_size*2, config.image_size*2, generator=generator).cuda()
noise = noise * diffusion_model.noise_scheduler.init_noise_sigma
#label = diffusion_model.label_projection(torch.tensor(2).cuda())
label = torch.tensor(1).cuda()


for i, t in tqdm(enumerate(diffusion_model.noise_scheduler.timesteps), total=len(diffusion_model.noise_scheduler.timesteps)):
    noise = diffusion_model.noise_scheduler.scale_model_input(noise, t)

    with torch.no_grad():
        noise_pred = diffusion_model.unet(noise, torch.tensor(t).cuda(), label.unsqueeze(0))

    noise = diffusion_model.noise_scheduler.step(noise_pred, t.long(), noise).prev_sample

  noise_pred = diffusion_model.unet(noise, torch.tensor(t).cuda(), label.unsqueeze(0))


100%|██████████| 1000/1000 [00:55<00:00, 17.96it/s]


In [20]:
noise.shape

torch.Size([1, 3, 64, 64])

In [None]:
# from matplotlib import pyplot as plt

# fig, axs = plt.subplots(1, 4, figsize=(16, 4))
# for c in range(4):
#     axs[c].imshow(noise[0][c].cpu(), cmap='Greys')

In [39]:
# noise = (1 / 0.18215) * noise
# with torch.no_grad():
#     image = diffusion_model.vae.decode(noise).sample

image = noise
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
# pil_images = [Image.fromarray(image) for image in images][0]
# pil_images

In [40]:
Image.fromarray(images[0]).save(os.path.join('./', f"image.png"))