In [1]:
import cv2
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
import os
from torchvision import transforms

In [2]:
source_folder_path = 'img_align_celeba'
destination_folder_path = 'train_images'

Resize images so that training and generation does not take as long

In [3]:
def resize_and_save_images(source_folder, destination_folder, target_size):
    if not os.path.exists(source_folder):
        print(f"Source folder '{source_folder}' does not exist.")
        return

    os.makedirs(destination_folder, exist_ok=True)

    for filename in os.listdir(source_folder):
        source_path = os.path.join(source_folder, filename)
        destination_path = os.path.join(destination_folder, filename)

        image = cv2.imread(source_path)

        if image is not None:
            resized_image = cv2.resize(image, target_size)
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            cv2.imwrite(destination_path, resized_image)

In [5]:
target_size = (64, 64)

resize_and_save_images(source_folder_path, destination_folder_path, target_size)

In [6]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
from diffusers.utils import make_image_grid, pt_to_pil
from dataclasses import dataclass
from transformers import get_cosine_schedule_with_warmup
import tqdm
from accelerate import Accelerator, notebook_launcher
from diffusers.utils import pt_to_pil

  torch.utils._pytree._register_pytree_node(


In [7]:
@dataclass
class TrainingConfig:
    image_size = 64
    train_batch_size = 16
    eval_batch_size = 16
    mixed_precision = "fp16"
    output_dir = "gen_model"
    gradient_accumulation_steps = 1
    start_epoch = 0
    total_epochs = 100
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 20
    overwrite_output_dir = True
    seed = 0

In [8]:
class ImageDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        self.image_list = [os.path.join(folder_path, filename) for filename in os.listdir(folder_path)]

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_path = self.image_list[idx]
        images = cv2.imread(img_path)
 
        if self.transform:
            images = self.transform(images)

        return {"images": images}

In [31]:
config = TrainingConfig()
    
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
])

dataset = ImageDataset('train_images', transform=transform)
    
dataloader = DataLoader(dataset,
                        batch_size=config.train_batch_size,
                        shuffle=True)
device = ("cuda" if torch.cuda.is_available()
           else "mps" if torch.backends.mps.is_available()
           else "cpu")

In [21]:
model = UNet2DModel(
    in_channels = 3,
    out_channels = 3,
    sample_size = config.image_size,
    layers_per_block = 2,
    block_out_channels = (128,128,256,256,512,512),
    down_block_types = [
        "DownBlock2D","DownBlock2D",
        "DownBlock2D","DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D"
    ],
    up_block_types = [
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D","UpBlock2D","UpBlock2D","UpBlock2D"
    ]        
).to(device)
print(model)

UNet2DModel(
  (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): LoRACompatibleLinear(in_features=128, out_features=512, bias=True)
    (act): SiLU()
    (linear_2): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
  )
  (down_blocks): ModuleList(
    (0-1): 2 x DownBlock2D(
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=128, bias=True)
          (norm2): GroupNorm(32, 128, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsampl

In [22]:
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=config.learning_rate)
    
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(dataloader)*config.total_epochs))

In [23]:
checkpoint_filename = os.path.join(config.output_dir,
                                   "checkpoint.pt")
config.start_epoch = 0
if os.path.exists(checkpoint_filename):
    checkpoint = torch.load(checkpoint_filename)
    config.start_epoch = checkpoint["epoch"]+1
    model.load_state_dict(checkpoint["network"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
    print("Loading previous checkpoint...")

In [24]:
def evaluate(config, epoch, pipeline):
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(config.seed)
    ).images
        
    image_grid = make_image_grid(images, 4, 4)
        
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(os.path.join(test_dir,
                                 "Image_%04d.png" % epoch))

In [25]:
def train_loop(config, model, optimizer,
               noise_scheduler, dataloader,
               lr_scheduler):
        
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs")
    )
        
    (model,
     optimizer,
     dataloader,
     noise_scheduler,
     lr_scheduler) = accelerator.prepare(
                                        model,
                                        optimizer,
                                        dataloader,
                                        noise_scheduler,
                                        lr_scheduler
                                        )
         
    if accelerator.is_main_process:
        os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("training")  
        
    global_step = 0
        
    for epoch in range(config.start_epoch,
                       config.total_epochs):        
        progress_bar = tqdm.tqdm(total=len(dataloader),
                                 disable=not accelerator.is_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
            
        for batch in dataloader:
            clean_images = batch["images"]   
                
                
            noise = torch.randn(clean_images.shape).to(
                           clean_images.device)
            bs = clean_images.shape[0]
            timesteps = torch.randint(0,
                            noise_scheduler.config.num_train_timesteps,
                            (bs,),
                            device=clean_images.device).long()
            noisy_images = noise_scheduler.add_noise(
                    clean_images,
                    noise,
                    timesteps
            )
                                             
                
            with accelerator.accumulate(model):
                noise_pred = model(noisy_images,
                                   timesteps,
                                   return_dict=False)[0]
                loss = torch.nn.functional.mse_loss(noise_pred, noise)
                    
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(),1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "step": global_step,
                "lr": lr_scheduler.get_last_lr()[0]                     
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1
                
        if accelerator.is_main_process:
            pipeline = DDPMPipeline(
                unet=accelerator.unwrap_model(model),
                scheduler=noise_scheduler)

            if(epoch+1)%config.save_image_epochs == 0:
                evaluate(config, epoch, pipeline)
                                                     
            if(epoch+1)%config.save_model_epochs == 0:
                pipeline.save_pretrained(config.output_dir)

                unwrapped = accelerator.unwrap_model(model)
                save_info = {
                    "epoch": epoch,
                    "network": unwrapped.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict()                
                }
                torch.save(save_info, checkpoint_filename)

In [26]:
args = (config, model, optimizer, noise_scheduler,
        dataloader, lr_scheduler)
notebook_launcher(train_loop, args, num_processes=1) 

Launching training on CPU.



  0%|                                                                              | 0/10822 [00:00<?, ?it/s][A
Epoch 0:   0%|                                                                     | 0/10822 [00:00<?, ?it/s][A
Epoch 0:   0%|                                                         | 1/10822 [00:54<162:56:21, 54.21s/it][A
Epoch 0:   0%|                             | 1/10822 [00:54<162:56:21, 54.21s/it, loss=1.05, lr=2e-7, step=0][A
Epoch 0:   0%|                             | 2/10822 [01:40<149:23:27, 49.70s/it, loss=1.05, lr=2e-7, step=0][A
Epoch 0:   0%|                             | 2/10822 [01:40<149:23:27, 49.70s/it, loss=1.04, lr=4e-7, step=1][A
Epoch 0:   0%|                             | 3/10822 [02:16<130:24:15, 43.39s/it, loss=1.04, lr=4e-7, step=1][A
Epoch 0:   0%|                             | 3/10822 [02:16<130:24:15, 43.39s/it, loss=1.05, lr=6e-7, step=2][A
Epoch 0:   0%|                             | 4/10822 [02:52<121:19:29, 40.37s/it, loss=1.05, lr

KeyboardInterrupt: 

In [27]:
@dataclass
class GenConfig:
    image_size = 64
    model_dir = "gen_model"
    output_dir = "gen_images"
    seed = 0
    eval_batch_size = 1

In [28]:
config = GenConfig()
    
modelLoc = os.path.join(config.model_dir, "finalModel.pt")
if (not os.path.exists(modelLoc)):
    modelLoc = os.path.join(config.model_dir, "checkpoint.pt") #use latest checkpoint of final model does not exist
if (not os.path.exists(modelLoc)):
    print("CANNOT FIND FINAL OR CHECKPOINT MODELS! EXITING!!!")
    exit(1)

CANNOT FIND FINAL OR CHECKPOINT MODELS! EXITING!!!


In [30]:
device = ("cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu")

model = UNet2DModel(
        in_channels = 3,
        out_channels = 3,
        sample_size = config.image_size,
        layers_per_block = 2,
        block_out_channels = (128,128,256,256,512,512),
        down_block_types = [
            "DownBlock2D","DownBlock2D",
            "DownBlock2D","DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D"
        ],
        up_block_types = [
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D","UpBlock2D","UpBlock2D","UpBlock2D"
        ]        
).to(device)

noise_scheduler = DDPMScheduler()
   
os.makedirs(config.output_dir, exist_ok=True)
finalModelDict = torch.load(modelLoc)
   
model.load_state_dict(finalModelDict["network"])
   
pipeline = DDPMPipeline(unet=model,
                        scheduler=noise_scheduler)
   
for i in range(1000):
    evaluateEpoch(config, i, pipeline)

FileNotFoundError: [Errno 2] No such file or directory: 'gen_model\\checkpoint.pt'

In [None]:
def evaluateEpoch(config, epoch, pipeline):
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(epoch)
    ).images
   
    image_grid = make_image_grid(images, 1, 1)
   
    test_dir = os.path.join(config.output_dir)
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(os.path.join(test_dir,
                                    "Image_%04d.png" % epoch))

In [None]:
import torch_fidelity as torf
from torch.utils.data import Dataset

In [None]:
def main():
    metrics = torf.calculate_metrics(
        input1="gen_images",
        input2="train_images",
        fid=True,
        kid=True,
        kid_subset_size=25, # Just doing this for expediency 200,
        cuda=True)
    print(metrics)