In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from accelerate import Accelerator
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
from diffusers.optimization import get_scheduler
from torchvision.utils import save_image
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Compose, CenterCrop

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


In [2]:
transform = Compose([
    ToTensor(),
    CenterCrop([128, 128])
])
dataset = ImageFolder("./data/pokemini", transform=transform)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)

In [3]:
model = UNet2DModel(
sample_size=128,
in_channels=3,
out_channels=3,
layers_per_block=2,
block_out_channels=(128, 128, 256, 256, 512, 512),
down_block_types=(
"DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D",
"AttnDownBlock2D",  "DownBlock2D",
),
up_block_types=(
"UpBlock2D", "AttnUpBlock2D",
"UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D",
),
)

In [4]:
accelerator = Accelerator(
    gradient_accumulation_steps=1,
    mixed_precision="no"
)

In [5]:
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000
)

In [6]:
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1E-4,
    betas=(0.95, 0.99),
    weight_decay=1E-6,
    eps=1E-8
)

In [7]:
lr_scheduler = get_scheduler(
    "cosine", optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=(
        len(dataloader) * 100
    ),
)

In [8]:
model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, dataloader, lr_scheduler
)

In [9]:
for epoch in range(100):
    model.train()
    for step, batch in enumerate(dataloader):
        clean_images = batch[0].to("cuda")
        noise = torch.randn(clean_images.shape).to(clean_images.device)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (clean_images.shape[0],), device=clean_images.device
        ).long()
        
        # Add noise to the clean images according to the noise magnitude
        # at each timestep (this is the forward diffusion process)
        noisy_images = noise_scheduler.add_noise(
            clean_images, noise, timesteps
        )

        with accelerator.accumulate(model):
            # Predict the noise residual
            noise_pred = model(noisy_images, timesteps).sample
            loss = F.mse_loss(noise_pred, noise)
            accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
    print(f"Epoch {epoch+1} / 100 - Loss: {loss.detach().item()}")

    # Using  Will stop the execution of the current process
    # until every other process has reached that point.
    # This does nothing when running on a single process.
    accelerator.wait_for_everyone()

    
    # Generate sample images for visual inspection
    if accelerator.is_main_process:
        if epoch % 10 == 0 or epoch == 100 - 1:
            pipeline = DDPMPipeline(
                unet=accelerator.unwrap_model(model),
                scheduler=noise_scheduler,
            )

            generator = torch.manual_seed(0)
            # run pipeline in inference (sample random noise and denoise)
            images = pipeline(
                generator=generator,
                batch_size=64,
                output_type="numpy"
            ).images

            # denormalize the images and save to wandb
            processed_images = torch.tensor(images).transpose(1, 3).transpose(2, 3)
            save_image(processed_images, f"./hf_epoch-{epoch+1}.png")
            pipeline.save_pretrained("hf_poke.pth")

    accelerator.wait_for_everyone()

accelerator.end_training()



Epoch 1 / 100 - Loss: 0.18759840726852417


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 2 / 100 - Loss: 0.205866739153862
Epoch 3 / 100 - Loss: 0.039883311837911606
Epoch 4 / 100 - Loss: 0.011146686971187592
Epoch 5 / 100 - Loss: 0.029633190482854843
Epoch 6 / 100 - Loss: 0.014035075902938843
Epoch 7 / 100 - Loss: 0.0035448407288640738
Epoch 8 / 100 - Loss: 0.004128352738916874
Epoch 9 / 100 - Loss: 0.0023402124643325806
Epoch 10 / 100 - Loss: 0.008877461776137352
Epoch 11 / 100 - Loss: 0.008535260334610939


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 12 / 100 - Loss: 0.05877518281340599
Epoch 13 / 100 - Loss: 0.012605486437678337
Epoch 14 / 100 - Loss: 0.002799785230308771
Epoch 15 / 100 - Loss: 0.012023687362670898
Epoch 16 / 100 - Loss: 0.00440285773947835
Epoch 17 / 100 - Loss: 0.0023765095975250006
Epoch 18 / 100 - Loss: 0.0056413523852825165
Epoch 19 / 100 - Loss: 0.0012780504766851664
Epoch 20 / 100 - Loss: 0.007354632019996643
Epoch 21 / 100 - Loss: 0.012356966733932495


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 22 / 100 - Loss: 0.02457582950592041
Epoch 23 / 100 - Loss: 0.024840451776981354
Epoch 24 / 100 - Loss: 0.003758967388421297
Epoch 25 / 100 - Loss: 0.0252738855779171
Epoch 26 / 100 - Loss: 0.010926840826869011
Epoch 27 / 100 - Loss: 0.002087017521262169
Epoch 28 / 100 - Loss: 0.004330662544816732
Epoch 29 / 100 - Loss: 0.0016544654499739408
Epoch 30 / 100 - Loss: 0.005153650883585215
Epoch 31 / 100 - Loss: 0.005730435252189636


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 32 / 100 - Loss: 0.01686091348528862
Epoch 33 / 100 - Loss: 0.01568448543548584
Epoch 34 / 100 - Loss: 0.001389938173815608
Epoch 35 / 100 - Loss: 0.008370171301066875
Epoch 36 / 100 - Loss: 0.006108378525823355
Epoch 37 / 100 - Loss: 0.002317538484930992
Epoch 38 / 100 - Loss: 0.0038189420010894537
Epoch 39 / 100 - Loss: 0.0013193362392485142
Epoch 40 / 100 - Loss: 0.010236789472401142
Epoch 41 / 100 - Loss: 0.006798309739679098


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 42 / 100 - Loss: 0.029594821855425835
Epoch 43 / 100 - Loss: 0.007751771714538336
Epoch 44 / 100 - Loss: 0.0009873182279989123
Epoch 45 / 100 - Loss: 0.006975287105888128
Epoch 46 / 100 - Loss: 0.0038861501961946487
Epoch 47 / 100 - Loss: 0.0014379543717950583
Epoch 48 / 100 - Loss: 0.0035689882934093475
Epoch 49 / 100 - Loss: 0.0004382262413855642
Epoch 50 / 100 - Loss: 0.010590891353785992
Epoch 51 / 100 - Loss: 0.006108391098678112


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 52 / 100 - Loss: 0.03222649171948433
Epoch 53 / 100 - Loss: 0.007497613783925772
Epoch 54 / 100 - Loss: 0.0013260047417134047
Epoch 55 / 100 - Loss: 0.011232979595661163
Epoch 56 / 100 - Loss: 0.004746669437736273
Epoch 57 / 100 - Loss: 0.0008758722105994821
Epoch 58 / 100 - Loss: 0.0021633291617035866
Epoch 59 / 100 - Loss: 0.0009512423421256244
Epoch 60 / 100 - Loss: 0.0030145791824907064
Epoch 61 / 100 - Loss: 0.005430744495242834


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 62 / 100 - Loss: 0.02029474824666977
Epoch 63 / 100 - Loss: 0.01170387677848339
Epoch 64 / 100 - Loss: 0.0008576257387176156
Epoch 65 / 100 - Loss: 0.0131174735724926
Epoch 66 / 100 - Loss: 0.007879412733018398
Epoch 67 / 100 - Loss: 0.0014958189567551017
Epoch 68 / 100 - Loss: 0.0020121994893997908
Epoch 69 / 100 - Loss: 0.0004982231184840202
Epoch 70 / 100 - Loss: 0.004348783753812313
Epoch 71 / 100 - Loss: 0.020378567278385162


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 72 / 100 - Loss: 0.019471127539873123
Epoch 73 / 100 - Loss: 0.010743064805865288
Epoch 74 / 100 - Loss: 0.0008399360813200474
Epoch 75 / 100 - Loss: 0.011767745018005371
Epoch 76 / 100 - Loss: 0.0034117549657821655
Epoch 77 / 100 - Loss: 0.0016075018793344498
Epoch 78 / 100 - Loss: 0.001964941853657365
Epoch 79 / 100 - Loss: 0.0007325410842895508
Epoch 80 / 100 - Loss: 0.003888264298439026
Epoch 81 / 100 - Loss: 0.0139863146468997


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 82 / 100 - Loss: 0.00900292582809925
Epoch 83 / 100 - Loss: 0.0054556829854846
Epoch 84 / 100 - Loss: 0.0007940351497381926
Epoch 85 / 100 - Loss: 0.006530681159347296
Epoch 86 / 100 - Loss: 0.0047417595051229
Epoch 87 / 100 - Loss: 0.0013857567682862282
Epoch 88 / 100 - Loss: 0.001885767444036901
Epoch 89 / 100 - Loss: 0.0005598432617262006
Epoch 90 / 100 - Loss: 0.008582176640629768
Epoch 91 / 100 - Loss: 0.005008901469409466


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 92 / 100 - Loss: 0.018885962665081024
Epoch 93 / 100 - Loss: 0.008572029881179333
Epoch 94 / 100 - Loss: 0.0008478917879983783
Epoch 95 / 100 - Loss: 0.008590107783675194
Epoch 96 / 100 - Loss: 0.005641000345349312
Epoch 97 / 100 - Loss: 0.0006619002670049667
Epoch 98 / 100 - Loss: 0.004685807507485151
Epoch 99 / 100 - Loss: 0.00048288164543919265
Epoch 100 / 100 - Loss: 0.0025606052950024605


  0%|          | 0/1000 [00:00<?, ?it/s]