In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset

from tqdm.notebook import tqdm

from diffusers import UNet2DModel, DDPMScheduler, DDIMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import DDPMPipeline  
from accelerate import Accelerator

import os

In [3]:
def get_classes(dataset, category: str|list[str]): 
    classes = []
    if isinstance(category, str): 
        classes.append(dataset.class_to_idx[category])
    elif isinstance(category, list): 
        for c in category: 
            classes.append(dataset.class_to_idx[c])

    indexes = list(filter(lambda x: dataset[x][1]  in classes, range(len(dataset))))
    return Subset(dataset, indexes)


# Per-image pipeline: resize to 128x128, random horizontal flip, convert to a
# tensor, then normalize with mean 0.5 and std 0.5.
preprocess = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Read the image folder and keep only the three listed classes.
train_dataset = get_classes(
    ImageFolder(root='../data/imgs', transform=preprocess),
    ['lemonade', 'alcohol', 'cold_tea'],
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=6)

In [4]:
# UNet for 128x128 RGB images: six resolution levels, with a single attention
# block at the fifth down level and at the second up level.
model = UNet2DModel(
    sample_size=128,      # target image resolution
    in_channels=3,        # RGB input
    out_channels=3,       # same number of channels on the output
    layers_per_block=1,   # ResNet layers per UNet block
    # Output channels of each UNet block, from first to last level.
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
)

In [10]:
def train_loop(model, noise_scheduler, optimizer, train_dataloader, lr_scheduler, accelerator, n_epochs, continue_training=False):
    """Train ``model`` to predict the noise that ``noise_scheduler`` adds to images.

    For every batch, Gaussian noise and one random timestep per image are
    sampled, the images are noised with ``noise_scheduler.add_noise`` and the
    model is optimised with an MSE loss between its prediction and that noise.

    Args:
        model: Network called as ``model(noisy_images, timesteps, return_dict=False)``;
            element ``0`` of its output is used as the noise prediction.
        noise_scheduler: Supplies ``config.num_train_timesteps`` and ``add_noise``.
        optimizer: Optimizer stepped once per batch.
        train_dataloader: Iterable of batches; element ``0`` of each batch is
            used as the image tensor.
        lr_scheduler: Learning-rate scheduler stepped after every optimizer step.
        accelerator: ``accelerate.Accelerator`` used to prepare the objects,
            run backward, clip gradients and log metrics.
        n_epochs: Number of passes over ``train_dataloader``.
        continue_training: Currently unused; kept so existing callers that
            pass it keep working.

    Returns:
        None.
    """
    # accelerator.log() only reaches trackers started by init_trackers(); start
    # them once so the per-step logs below are not silently dropped. The guard
    # avoids registering duplicate trackers when the loop is called again.
    if accelerator.is_main_process and not getattr(accelerator, "trackers", None):
        accelerator.init_trackers("train_loop")

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    # Make sure the network is in training mode even if it was used for
    # sampling beforehand.
    model.train()

    global_step = 0

    for epoch in range(n_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        try:
            for batch in train_dataloader:
                clean_images = batch[0]

                # Sample noise to add to the images
                noise = torch.randn(clean_images.shape, device=clean_images.device)
                bs = clean_images.shape[0]

                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
                    dtype=torch.int64
                )

                # Add noise to the clean images according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

                with accelerator.accumulate(model):
                    # Predict the noise residual
                    noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                    loss = F.mse_loss(noise_pred, noise)
                    accelerator.backward(loss)

                    # Clip only on steps where gradients are actually synchronised.
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                progress_bar.update(1)
                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                global_step += 1
        finally:
            # Close the bar even if the epoch aborts (e.g. CUDA out of memory),
            # so no half-finished widget is left behind in the notebook.
            progress_bar.close()
          

In [11]:
from accelerate import notebook_launcher
from diffusers.optimization import get_cosine_schedule_with_warmup

# Training configuration. N_EPOCHS drives BOTH the LR schedule length and the
# number of epochs actually run, so the two can no longer drift apart.
N_EPOCHS = 5
LEARNING_RATE = 1e-4
# NOTE(review): if len(train_loader) * N_EPOCHS is below WARMUP_STEPS the run
# ends while the learning rate is still warming up; raise N_EPOCHS or lower
# WARMUP_STEPS for a full schedule.
WARMUP_STEPS = 500

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=len(train_loader) * N_EPOCHS,
)

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# The TensorBoard tracker needs a directory to write to, so give the
# accelerator a project directory instead of None.
accelerator = Accelerator(
    mixed_precision="fp16",
    gradient_accumulation_steps=1,
    log_with="tensorboard",
    project_dir=os.path.join("generated", "logs"),
)

# Positional arguments for train_loop; the trailing True is continue_training.
args = (model, noise_scheduler, optimizer, train_loader, lr_scheduler, accelerator, N_EPOCHS, True)

notebook_launcher(train_loop, args, num_processes=1)

Launching training on one GPU.


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


  0%|          | 0/47 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU 0 has a total capacity of 3.71 GiB of which 20.44 MiB is free. Including non-PyTorch memory, this process has 3.68 GiB memory in use. Of the allocated memory 3.37 GiB is allocated by PyTorch, and 198.92 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [21]:
from diffusers.utils import make_image_grid

def evaluate(filename, pipeline, rows=5, cols=5, seed=23534, output_dir=None):
    """Sample ``rows * cols`` images from ``pipeline`` and save them as one PNG grid.

    Args:
        filename: Base name (without extension) of the PNG to write.
        pipeline: Callable accepting ``batch_size`` and ``generator`` and
            returning an object whose ``images`` attribute is the list of
            generated images.
        rows: Number of grid rows.
        cols: Number of grid columns.
        seed: Seed for the dedicated CPU generator used for sampling.
        output_dir: Directory for the PNG; defaults to ``generated/samples``.

    Returns:
        None. The grid is written to ``<output_dir>/<filename>.png``.
    """
    if output_dir is None:
        output_dir = os.path.join('generated', "samples")

    # The grid must be filled exactly, so the batch size is derived from it.
    images = pipeline(
        batch_size=rows * cols,
        # Use a separate torch generator to avoid rewinding the random state of the main training loop
        generator=torch.Generator(device='cpu').manual_seed(seed),
    ).images

    # Make a grid out of the images
    image_grid = make_image_grid(images, rows=rows, cols=cols)

    # Save the images
    os.makedirs(output_dir, exist_ok=True)
    image_grid.save(os.path.join(output_dir, f"{filename}.png"))

In [22]:

# Build a sampling pipeline from the unwrapped UNet and the noise scheduler,
# then write a grid of generated samples to disk.
pipeline = DDPMPipeline(
    scheduler=noise_scheduler,
    unet=accelerator.unwrap_model(model),
)
evaluate(filename='300_epochs_2', pipeline=pipeline)

  0%|          | 0/1000 [00:00<?, ?it/s]