In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from accelerate import Accelerator
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
from torchvision import transforms
from torchvision.utils import save_image

In [None]:
# ---- Dataset / experiment configuration ----
RANDOM_SEED = 42
IMG_SIZE = 64  # passed to UNet2DModel(sample_size=...) in model_init
BATCH_SIZE = 4  # training-time setting; not used by the generation cells below
LEARNING_RATE = 1e-4  # training-time setting; not used by the generation cells below
NUM_EPOCHS = 500  # training-time setting; not used by the generation cells below
NUM_GENERATE_IMAGES = 9  # images per model for the sample_image_generation preview
NUM_TIMESTEPS = 1000  # DDPMScheduler train timesteps, also used as inference steps
MIXED_PRECISION = "fp16"  # forwarded to Accelerator(mixed_precision=...)
GRADIENT_ACCUMULATION_STEPS = 1  # forwarded to Accelerator(gradient_accumulation_steps=...)
CLASSES = 10  # one DDPM checkpoint is loaded per class index 0..CLASSES-1

# ---- Reproducibility: seed every RNG source with the same value ----
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(RANDOM_SEED)
del _seed_fn

# Trade cuDNN autotuning speed for deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# ---- Device selection ----
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
torch.cuda.empty_cache()  # drop any cached GPU allocator memory held by this kernel

In [None]:
def model_init():
    """Build a fresh, untrained UNet2DModel with the per-class DDPM architecture.

    The layout (six resolution levels, one attention level on each side) must
    match the architecture the ``model_DDPM_<class>.pth`` state dicts were
    saved from, because they are loaded into instances created here.

    Returns:
        A new ``UNet2DModel`` with 3 input and 3 output channels and
        ``sample_size=IMG_SIZE``.
    """
    channels_per_level = (64, 64, 128, 128, 256, 256)
    # Plain down blocks, attention at the fifth level, then one more plain block.
    down_blocks = ("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D")
    # Mirror image of the encoder: attention at the second up block.
    up_blocks = ("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4
    return UNet2DModel(
        sample_size=IMG_SIZE,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=channels_per_level,
        down_block_types=down_blocks,
        up_block_types=up_blocks,
    )

In [None]:
# Sample image generation function
def sample_image_generation(model, noise_scheduler, num_generate_images, random_seed, num_timesteps):
    pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
    images = pipeline(
        batch_size=num_generate_images,
        generator=torch.manual_seed(random_seed),
        num_inference_steps=num_timesteps
    ).images
    fig = plt.figure()
    for i in range(1, num_generate_images + 1):
        fig.add_subplot(3, 3, i)
        plt.imshow(images[i-1])
    plt.show()

In [None]:
# Load one pre-trained DDPM UNet per class from saved_models/model_DDPM_<class>.pth.
model_path = 'saved_models'
os.makedirs(model_path, exist_ok=True)
models = {}
accelerator = Accelerator(
    mixed_precision=MIXED_PRECISION,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
)
noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TIMESTEPS)

for class_idx in range(CLASSES):
    model_path_ = os.path.join(model_path, f"model_DDPM_{class_idx}.pth")
    models[class_idx] = model_init().to(device)
    # map_location=device allows checkpoints saved on a GPU to load on a
    # CPU-only machine (and places the tensors where the model lives).
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on torch versions that support it).
    models[class_idx].load_state_dict(torch.load(model_path_, map_location=device))
    # These models are only used for sampling, so switch them to inference mode.
    models[class_idx].eval()

print('Loaded pre-trained models\n')

In [None]:
#for class_idx in range(CLASSES):
#    sample_image_generation(models[class_idx], noise_scheduler, NUM_GENERATE_IMAGES, RANDOM_SEED, NUM_TIMESTEPS)

In [None]:
def generate(models, num_images_per_class, num_classes, random_seed, num_timesteps, device, save=False,
             output_dir='generated_dataset/', class_names=None):
    """Sample ``num_images_per_class`` images from each per-class DDPM model.

    Uses the module-level ``accelerator`` (to unwrap each model) and
    ``noise_scheduler`` (for the reverse diffusion).

    Args:
        models: mapping from class index to its UNet, read for indices
            ``0 .. num_classes - 1``.
        num_images_per_class: number of images sampled per class.
        num_classes: number of classes to sample.
        random_seed: seed for the sampling noise. Each class is sampled with a
            fresh generator seeded with this value.
        num_timesteps: number of inference (denoising) steps.
        device: device that models and returned image tensors are moved to.
        save: if True, also write every image to
            ``<output_dir>/<class name>/<class name>_<i>.png``.
        output_dir: root directory for saved images (used only when ``save``).
        class_names: folder/file names indexed by class index. Defaults to the
            original dataset label ids ``[12, 13, 24, 38, 39, 44, 46, 49, 50, 6]``.

    Returns:
        ``(all_generated_images, all_labels)``: a list of image tensors
        (``ToTensor`` of the pipeline's PIL output, values in [0, 1]) on
        ``device``, and the matching list of class indices.

    Raises:
        ValueError: if ``save`` is True and ``class_names`` has fewer than
            ``num_classes`` entries.
    """
    if class_names is None:
        class_names = [12, 13, 24, 38, 39, 44, 46, 49, 50, 6]
    if save and len(class_names) < num_classes:
        raise ValueError(
            f"class_names has {len(class_names)} entries but num_classes={num_classes}"
        )

    torch.cuda.empty_cache()

    to_tensor = transforms.ToTensor()
    all_generated_images = []
    all_labels = []

    for class_label in range(num_classes):
        model = models[class_label].to(device)
        pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

        with torch.no_grad():
            gen_images = pipeline(
                batch_size=num_images_per_class,
                # Same noise stream as torch.manual_seed(random_seed), without
                # resetting torch's global RNG on every class.
                generator=torch.Generator().manual_seed(random_seed),
                num_inference_steps=num_timesteps,
            ).images

        if save:
            class_path = os.path.join(output_dir, f'{class_names[class_label]}')
            os.makedirs(class_path, exist_ok=True)

        for i, pil_image in enumerate(gen_images):
            generated_image = to_tensor(pil_image).to(device)

            if save:
                # ToTensor of the pipeline's PIL output is already in [0, 1].
                # The former (x + 1) / 2 "denormalization" squashed pixels into
                # [0.5, 1] and saved washed-out images.
                image_path = os.path.join(class_path, f'{class_names[class_label]}_{i}.png')
                save_image(generated_image, image_path)

            all_generated_images.append(generated_image)
            all_labels.append(class_label)

    return all_generated_images, all_labels

In [None]:
# Sample a small synthetic dataset: `size` images for each of the CLASSES models.
size = 5  # images generated per class
save = True  # also write the images as PNGs under generated_dataset/<class name>/
all_generated_images, all_labels = generate(
    models,
    size,
    CLASSES,
    RANDOM_SEED,
    NUM_TIMESTEPS,
    device,
    save=save,
)