Notebook used to generate images from diffusion models

In [7]:

#? resources
#? vid 1: https://www.youtube.com/watch?v=HoKDTa5jHvg&t=177s
# explains how ddpm works
#? vid2: https://www.youtube.com/watch?v=TBCRlnwJtZU
# walks through an implementation of DDPM

import random
import os
import copy
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
from helper_utils import *

from modules import UNet, UNet_conditional, EMA
import logging 
from torch.utils.tensorboard import SummaryWriter
#torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Optional library for quantization. Importing it also runs a CUDA setup step
# (see the "CUDA SETUP" output below), which may fail with errors other than
# ImportError, so any Exception falls back to ``bnb = None``.
try:
    import bitsandbytes as bnb
    print('imported bitsandbytes')
except Exception:
    # Best-effort import. Unlike a bare ``except:``, this does not swallow
    # KeyboardInterrupt / SystemExit raised while the import is running.
    print('cant import bitsandbytes')
    bnb = None

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

class Diffusion:
    """DDPM noise schedule plus the forward (noising) and reverse (sampling) processes.

    Uses a linear beta schedule. The per-timestep tables ``beta``, ``alpha`` and
    ``alpha_hat`` (the cumulative product of ``alpha``) are stored on ``device``.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda", USE_GPU=True):
        """
        Args:
            noise_steps (int): number of diffusion timesteps.
            beta_start (float): beta at the first timestep.
            beta_end (float): beta at the last timestep.
            img_size (int): side length (pixels) of the square images produced by
                ``sample``. Side note from the video: for higher resolutions, train
                separate upsamplers instead of training on bigger images.
            device: device the schedule tensors and sampled images are placed on.
            USE_GPU (bool): only used to fill ``self.use_cuda``.
        """
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        # NOTE(review): ``use_cuda`` is not read anywhere in this class; tensors
        # are placed on ``self.device``. Kept for backward compatibility.
        if USE_GPU and torch.cuda.is_available():
            print("CUDAAAAAAAAAAA")
            self.use_cuda = torch.device('cuda')
        else:
            self.use_cuda = torch.device('cpu')

        # Simple linear beta schedule (OpenAI uses a cosine schedule instead).
        # TODO: try implementing the cosine scheduler.
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Return ``noise_steps`` betas evenly spaced from ``beta_start`` to ``beta_end`` (inclusive)."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise images straight to timestep ``t`` in one step (closed form).

        Instead of adding noise iteratively, computes
        ``x_t = sqrt(alpha_hat_t) * x + sqrt(1 - alpha_hat_t) * eps`` (vid 1).

        Args:
            x (torch.Tensor): batch of images; assumed (N, C, H, W) given the
                ``[:, None, None, None]`` broadcasting below.
            t (torch.Tensor): integer timestep per image, shape (N,).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: the noised images and the Gaussian
            noise ``eps`` that was added.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        E = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * E, E

    def sample_timesteps(self, n):
        """Draw ``n`` random training timesteps, uniform over ``[1, noise_steps - 1]``.

        Args:
            n (int): number of timesteps to draw.

        Returns:
            torch.Tensor: int64 tensor of shape (n,). It is created on the default
            device, not ``self.device``.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, labels, channels=3, cfg_scale=3):
        """Generate ``n`` images with the reverse process (algorithm 2 of the DDPM paper, vid 1).

        Args:
            model: noise-prediction network called as ``model(x, t, labels)``.
                It is called with ``labels=None`` for the unconditional prediction.
            n (int): number of images to sample.
            labels: class labels forwarded to ``model``. Presumably a tensor of
                length ``n`` on ``self.device``; confirm against the model.
            channels (int): number of image channels.
            cfg_scale (float): classifier-free guidance scale. Guidance is applied
                only when ``cfg_scale > 0`` and ``labels`` is not None.

        Returns:
            torch.Tensor: uint8 tensor of shape (n, channels, img_size, img_size)
            with values in [0, 255].
        """
        logging.info(f"Sampling {n} new images....")
        # eval() switches layers such as dropout / batch-norm to inference mode
        # (https://stackoverflow.com/questions/60018578/what-does-model-eval-do-in-pytorch).
        # The caller's original mode is restored afterwards, even if sampling fails.
        was_training = getattr(model, "training", True)
        model.eval()
        try:
            with torch.no_grad():
                # step 1: start from pure Gaussian noise
                x = torch.randn((n, channels, self.img_size, self.img_size)).to(self.device)

                # steps 2-4: walk the timesteps backwards, removing a little noise each step
                for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                    t = torch.full((n,), i, dtype=torch.long, device=self.device)
                    predicted_noise = model(x, t, labels)

                    # Classifier-free guidance: lerp(uncond, cond, s) = uncond + s * (cond - uncond).
                    # With labels=None both predictions are identical, so the extra pass is skipped.
                    if cfg_scale > 0 and labels is not None:
                        uncond_predicted_noise = model(x, t, None)
                        predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)

                    alpha = self.alpha[t][:, None, None, None]
                    alpha_hat = self.alpha_hat[t][:, None, None, None]
                    beta = self.beta[t][:, None, None, None]

                    # Only add fresh noise for timesteps > 1. Adding it in the last
                    # iteration would perturb the finalized pixels.
                    if i > 1:
                        noise = torch.randn_like(x)
                    else:
                        noise = torch.zeros_like(x)

                    # remove the predicted noise for this step
                    x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_hat)) * predicted_noise) + torch.sqrt(beta) * noise
        finally:
            model.train(was_training)

        x = (x.clamp(-1, 1) + 1) / 2  # map [-1, 1] back to [0, 1]
        x = (x * 255).type(torch.uint8)  # scale to pixel range for viewing
        return x


The following directories listed in your path were found to be non-existent: {WindowsPath('C'), WindowsPath('/Users/Efran/anaconda3/envs/compsci682/lib')}
The following directories listed in your path were found to be non-existent: {WindowsPath('vs/workbench/api/node/extensionHostProcess')}
The following directories listed in your path were found to be non-existent: {WindowsPath('module'), WindowsPath('/matplotlib_inline.backend_inline')}
The following directories listed in your path were found to be non-existent: {WindowsPath('/usr/local/cuda/lib64')}
DEBUG: Possible options found for libcudart.so: set()
CUDA SETUP: PyTorch settings found: CUDA_VERSION=118, Highest Compute Capability: 8.9.
CUDA SETUP: To manually override the PyTorch CUDA version please see:https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md
CUDA SETUP: Loading binary c:\Users\Efran\anaconda3\envs\compsci682\lib\site-packages\bitsandbytes\libbitsandbytes_cuda118.so...
argument of type 


python -m bitsandbytes


  warn(msg)
  warn(msg)


### Generate images

In [28]:
def save_images_nb(images, path, sample_image_name, epoch, **kwargs):
    """Save every sampled image individually, plus one grid image of all of them.

    Args:
        images (torch.Tensor): batch of images, e.g. the output of
            ``Diffusion.sample`` (assumed uint8, shape (N, C, H, W)).
        path (str): output directory, created if it does not exist.
        sample_image_name (str): file name of the grid image inside ``path``.
        epoch: currently unused; kept so existing calls keep working.
        **kwargs: forwarded to ``torchvision.utils.make_grid``.

    Writes ``{path}/{index}_img.jpg`` for each image and ``{path}/{sample_image_name}``.
    """
    # NOTE(review): ``torchvision`` and ``Image`` are not imported in this
    # notebook directly; they presumably come from ``from helper_utils import *``.
    os.makedirs(path, exist_ok=True)

    for index, image in enumerate(images):
        image_name = os.path.join(path, f'{index}_img.jpg')
        img = torchvision.utils.make_grid(image.float(), **kwargs)
        # normalize=True min-max rescales the float image to [0, 1] before writing
        torchvision.utils.save_image(img, image_name, normalize=True)
    print('saved images to epoch folder')

    # The grid is built from the original tensor. With uint8 input, PIL
    # receives 0-255 pixel values directly.
    grid = torchvision.utils.make_grid(images, **kwargs)
    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
    im = Image.fromarray(ndarr)
    im.save(os.path.join(path, sample_image_name))
    print('saved images to epoch grid image')

In [29]:
# For every model directory under `model_path`, load the raw and EMA
# checkpoints saved at `epoch`, sample `num_img` class-conditional images from
# each, and save them under FINAL/<model_type>/.
model_path = 'models'
model_list = os.listdir(model_path)
diffusion = Diffusion(img_size=56, device='cuda')

epoch = 49  # checkpoint epoch; selects the checkpoint file names below

num_img = 10  # images sampled per model

# Checkpoint file-name prefixes, in priority order: plain checkpoints first,
# then the pruned variants.
ckpt_prefixes = (f'{epoch}_', f'pruned_{epoch}_')

for model_type in model_list:
    model_dir = os.path.join(model_path, model_type)
    if not os.path.isdir(model_dir):
        # skip stray files sitting next to the model folders
        continue
    saved_models = set(os.listdir(model_dir))

    # first prefix for which both the raw and the EMA checkpoint exist, if any
    prefix = next(
        (p for p in ckpt_prefixes
         if f'{p}ckpt.pt' in saved_models and f'{p}ema_ckpt.pt' in saved_models),
        None,
    )

    if prefix is not None:
        # NOTE: torch.load unpickles whole nn.Module objects, so only load
        # trusted checkpoints.
        cfg = os.path.join(model_dir, f'{prefix}ckpt.pt')
        cfg_model = torch.load(cfg)
        cfg_model.eval()

        ema = os.path.join(model_dir, f'{prefix}ema_ckpt.pt')
        ema_model = torch.load(ema)
        ema_model.eval()

        # one shared set of random digit labels, so both models draw the same classes
        numbers = [random.randint(0, 9) for _ in range(num_img)]
        labels = torch.tensor(numbers, dtype=torch.long, device=diffusion.device)

        sampled_images = diffusion.sample(cfg_model, n=len(labels), channels=1, labels=labels)
        ema_sampled_images = diffusion.sample(ema_model, n=len(labels), channels=1, labels=labels)

        sample_images_path = os.path.join("FINAL", model_type, 'ddpm_conditional')
        sample_image_name = "grid.jpg"
        print(sample_images_path)

        save_images_nb(sampled_images, sample_images_path, sample_image_name, epoch=epoch)
        sample_images_path = os.path.join("FINAL", model_type, 'ddpm_conditional_ema')
        save_images_nb(ema_sampled_images, sample_images_path, 'ema' + sample_image_name, epoch=epoch)
    print('----------------------------------------------------------')

CUDAAAAAAAAAAA


10:58:46 - INFO: Sampling 10 new images....
10:58:46 - INFO: Sampling 10 new images....
999it [00:54, 18.16it/s]
10:59:41 - INFO: Sampling 10 new images....
10:59:41 - INFO: Sampling 10 new images....
999it [00:54, 18.21it/s]
11:00:36 - INFO: Sampling 10 new images....
11:00:36 - INFO: Sampling 10 new images....


FINAL\base_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------


999it [00:55, 17.99it/s]
11:01:31 - INFO: Sampling 10 new images....
11:01:31 - INFO: Sampling 10 new images....
999it [00:55, 17.91it/s]
11:02:27 - INFO: Sampling 10 new images....
11:02:27 - INFO: Sampling 10 new images....


FINAL\d_2_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:53, 18.61it/s]
11:03:21 - INFO: Sampling 10 new images....
11:03:21 - INFO: Sampling 10 new images....
999it [00:53, 18.73it/s]
11:04:14 - INFO: Sampling 10 new images....
11:04:14 - INFO: Sampling 10 new images....


FINAL\d_4_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:52, 18.89it/s]
11:05:07 - INFO: Sampling 10 new images....
11:05:07 - INFO: Sampling 10 new images....
999it [00:54, 18.46it/s]


FINAL\p_l1_unstructured_10_base_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


11:06:02 - INFO: Sampling 10 new images....
11:06:02 - INFO: Sampling 10 new images....
999it [00:55, 18.05it/s]
11:06:57 - INFO: Sampling 10 new images....
11:06:57 - INFO: Sampling 10 new images....
999it [00:54, 18.27it/s]
11:07:52 - INFO: Sampling 10 new images....
11:07:52 - INFO: Sampling 10 new images....


FINAL\p_l1_unstructured_10_d_2_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:53, 18.53it/s]
11:08:46 - INFO: Sampling 10 new images....
11:08:46 - INFO: Sampling 10 new images....
999it [00:53, 18.52it/s]
11:09:40 - INFO: Sampling 10 new images....
11:09:40 - INFO: Sampling 10 new images....


FINAL\p_l1_unstructured_10_d_4_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:53, 18.67it/s]
11:10:33 - INFO: Sampling 10 new images....
11:10:33 - INFO: Sampling 10 new images....
999it [00:52, 19.05it/s]
11:11:26 - INFO: Sampling 10 new images....
11:11:26 - INFO: Sampling 10 new images....


FINAL\p_l1_unstructured_30_base_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:53, 18.61it/s]
11:12:20 - INFO: Sampling 10 new images....
11:12:20 - INFO: Sampling 10 new images....
999it [00:53, 18.50it/s]
11:13:14 - INFO: Sampling 10 new images....
11:13:14 - INFO: Sampling 10 new images....


FINAL\p_l1_unstructured_30_d_2_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:53, 18.58it/s]
11:14:08 - INFO: Sampling 10 new images....
11:14:08 - INFO: Sampling 10 new images....
999it [00:53, 18.67it/s]
11:15:01 - INFO: Sampling 10 new images....
11:15:01 - INFO: Sampling 10 new images....


FINAL\p_l1_unstructured_30_d_4_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:53, 18.77it/s]
11:15:55 - INFO: Sampling 10 new images....
11:15:55 - INFO: Sampling 10 new images....
999it [00:52, 18.87it/s]
11:16:48 - INFO: Sampling 10 new images....
11:16:48 - INFO: Sampling 10 new images....


FINAL\p_random_10_base_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:52, 18.92it/s]
11:17:40 - INFO: Sampling 10 new images....
11:17:40 - INFO: Sampling 10 new images....
999it [00:53, 18.77it/s]
11:18:34 - INFO: Sampling 10 new images....
11:18:34 - INFO: Sampling 10 new images....


FINAL\p_random_10_d_2_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:54, 18.22it/s]
11:19:29 - INFO: Sampling 10 new images....
11:19:29 - INFO: Sampling 10 new images....
999it [00:55, 18.15it/s]
11:20:24 - INFO: Sampling 10 new images....
11:20:24 - INFO: Sampling 10 new images....


FINAL\p_random_10_d_4_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:52, 19.02it/s]
11:21:16 - INFO: Sampling 10 new images....
11:21:16 - INFO: Sampling 10 new images....
999it [00:52, 18.98it/s]
11:22:09 - INFO: Sampling 10 new images....
11:22:09 - INFO: Sampling 10 new images....


FINAL\p_random_30_base_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:55, 18.14it/s]
11:23:04 - INFO: Sampling 10 new images....
11:23:04 - INFO: Sampling 10 new images....
999it [00:55, 18.11it/s]
11:24:00 - INFO: Sampling 10 new images....
11:24:00 - INFO: Sampling 10 new images....


FINAL\p_random_30_d_2_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------


999it [00:54, 18.17it/s]
11:24:55 - INFO: Sampling 10 new images....
11:24:55 - INFO: Sampling 10 new images....
999it [00:54, 18.25it/s]

FINAL\p_random_30_d_4_mnist_ddpm_conditional_ema\ddpm_conditional
saved images to epoch folder
saved images to epoch grid image
saved images to epoch folder
saved images to epoch grid image
----------------------------------------------------------



