Notebook for measuring the parameter count and the compressed on-disk size of the trained diffusion models.

In [1]:

#? resources
#? vid 1: https://www.youtube.com/watch?v=HoKDTa5jHvg&t=177s
# explains how DDPM works
#? vid 2: https://www.youtube.com/watch?v=TBCRlnwJtZU
# walks through an implementation of DDPM

import random
import os
import copy
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
from helper_utils import *

from modules import UNet, UNet_conditional, EMA
import logging 
from torch.utils.tensorboard import SummaryWriter
#torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Optional dependency: bitsandbytes (quantization library).
# `bnb` is set to None when the import fails so later code can test for it.
try:
    import bitsandbytes as bnb
    print('imported bitsandbytes')

# Catch Exception rather than using a bare `except:` so KeyboardInterrupt and
# SystemExit still propagate. ImportError alone is deliberately NOT used: the
# import can apparently fail with other exception types on some platforms.
except Exception:
    print('cant import bitsandbytes')
    bnb = None

# Timestamped INFO-level logging for the rest of the notebook.
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

class Diffusion:
    """DDPM noising and sampling utilities built on a linear beta schedule.

    Holds the schedule tensors (`beta`, `alpha`, `alpha_hat`) and provides the
    one-step forward noising used for training plus the iterative reverse
    sampling loop.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda", USE_GPU = True):
        """Build the noise schedule and move it to `device`.

        Args:
            noise_steps (int): number of entries in the beta schedule.
            beta_start (float): first value of the linear beta schedule.
            beta_end (float): last value of the linear beta schedule.
            img_size (int): side length of the square images created by `sample`.
            device: device for the schedule tensors and for sampling. If a CUDA
                device is requested but CUDA is unavailable, CPU is used instead
                (previously this raised when moving the schedule).
            USE_GPU (bool): only influences `self.use_cuda`, which nothing in
                this class reads.
        """
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        #? resolution of image
        #: side note from video --> for higher resolutions, train separate upsamplers instead of training on bigger images
        self.img_size = img_size

        if USE_GPU and torch.cuda.is_available():
            print("CUDAAAAAAAAAAA")
            self.use_cuda = torch.device('cuda')
        else:
            self.use_cuda = torch.device('cpu')

        # Fall back to CPU only in the case that would otherwise crash below:
        # a CUDA device was requested but this machine has no usable CUDA.
        if torch.device(device).type == "cuda" and not torch.cuda.is_available():
            logging.warning("CUDA requested but not available; Diffusion is falling back to CPU.")
            device = "cpu"
        self.device = device

        #? right now using simple (linear) beta schedule --> OpenAI uses a cosine scheduler
        #! try implementing cosine scheduler
        self.beta = self.prepare_noise_schedule().to(self.device)
        self.alpha = 1. - self.beta
        # Running product of alpha over the schedule.
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Return the linear beta schedule.

        Returns:
            torch.Tensor: 1-D tensor of `noise_steps` values evenly spaced from
            `beta_start` to `beta_end`, inclusive.
        """
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Add noise to `x` for timesteps `t` in a single step.

        Noise could be added iteratively, but vid 1 shows the closed form used
        here: sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * E.

        Args:
            x (torch.Tensor): images to noise. The per-sample coefficients are
                reshaped to (len(t), 1, 1, 1), so x is presumably
                (batch, channels, height, width) -- confirm against callers.
            t (torch.Tensor): 1-D integer tensor of timestep indices into the
                schedule, one per image.

        Returns:
            tuple: (noised images, E) where E is the standard-normal noise that
            was mixed in, with the same shape as `x`.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        E = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * E, E

    def sample_timesteps(self, n):
        """Draw `n` random timesteps (needed by the training algorithm).

        Args:
            n (int): number of timesteps to draw.

        Returns:
            torch.Tensor: integer tensor of shape (n,) with values from 1 up to
            `noise_steps - 1` (torch.randint's upper bound is exclusive). No
            device is passed, so callers move it where they need it.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, labels, channels=3, cfg_scale=3):
        """Generate images with the reverse process (algorithm 2 of the DDPM paper, see vid 1).

        Args:
            model: noise predictor called as `model(x, t, labels)`; must also
                provide `eval()` and `train()`.
            n (int): number of images to sample.
            labels: conditioning passed to the model as its third argument.
            channels (int): number of image channels.
            cfg_scale (float): when > 0, the model is additionally evaluated
                with `None` labels and the two predictions are combined as
                `torch.lerp(unconditional, conditional, cfg_scale)`.

        Returns:
            torch.Tensor: uint8 tensor of shape (n, channels, img_size, img_size)
            with values in 0-255.
        """
        logging.info(f"Sampling {n} new images....")
        # Remember the current mode so it can be restored afterwards instead of
        # unconditionally forcing train mode (which would silently flip a model
        # that was in eval mode, e.g. an EMA copy).
        was_training = getattr(model, "training", True)
        #? see here for why we set model.eval() https://stackoverflow.com/questions/60018578/what-does-model-eval-do-in-pytorch
        model.eval()
        with torch.no_grad():
            #? create initial images by sampling from a normal dist (step 1)
            x = torch.randn((n, channels, self.img_size, self.img_size)).to(self.device)

            #? step 2, 3, 4 -- iterate i = noise_steps-1 ... 1
            # A descending range (rather than reversed(range(...))) has a len(),
            # so tqdm can show a total; the iterated values are identical.
            for i in tqdm(range(self.noise_steps - 1, 0, -1), position=0):
                t = (torch.ones(n) * i).long().to(self.device) #? tensor of timestep
                predicted_noise = model(x, t, labels) #? feed that into model w/ current images

                #? classifier-free guidance: mix unconditional and conditional predictions
                if cfg_scale > 0:
                    uncond_predicted_noise = model(x, t, None)
                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]

                #? only add noise for timesteps greater than 1: in the last iteration it would degrade the finalized pixels
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)

                #? alter image by removing a little bit of noise
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

        #? restore whichever mode the model was in before sampling
        model.train(was_training)
        x = (x.clamp(-1, 1) + 1) / 2 #? brings values back to the 0-1 range
        x = (x * 255).type(torch.uint8) #? bring values back to pixel range for viewing the image
        return x

<module 'torch._C' from 'c:\\Users\\Efran\\anaconda3\\envs\\compsci682\\lib\\site-packages\\torch\\_C.cp310-win_amd64.pyd'>
False

The following directories listed in your path were found to be non-existent: {WindowsPath('C'), WindowsPath('/Users/Efran/anaconda3/envs/compsci682/lib')}
The following directories listed in your path were found to be non-existent: {WindowsPath('vs/workbench/api/node/extensionHostProcess')}
The following directories listed in your path were found to be non-existent: {WindowsPath('/matplotlib_inline.backend_inline'), WindowsPath('module')}
The following directories listed in your path were found to be non-existent: {WindowsPath('/usr/local/cuda/lib64')}
DEBUG: Possible options found for libcudart.so: set()
CUDA SETUP: PyTorch settings found: CUDA_VERSION=118, Highest Compute Capability: 8.9.
CUDA SETUP: To manually override the PyTorch CUDA version please see:https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md
CUDA SETUP: Load


python -m bitsandbytes


  warn(msg)
  warn(msg)



The following directories listed in your path were found to be non-existent: {WindowsPath('C'), WindowsPath('/Users/Efran/anaconda3/envs/compsci682/lib')}
The following directories listed in your path were found to be non-existent: {WindowsPath('vs/workbench/api/node/extensionHostProcess')}
The following directories listed in your path were found to be non-existent: {WindowsPath('/matplotlib_inline.backend_inline'), WindowsPath('module')}
The following directories listed in your path were found to be non-existent: {WindowsPath('/usr/local/cuda/lib64')}
DEBUG: Possible options found for libcudart.so: set()
CUDA SETUP: PyTorch settings found: CUDA_VERSION=118, Highest Compute Capability: 8.9.
CUDA SETUP: To manually override the PyTorch CUDA version please see:https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md
CUDA SETUP: Loading binary c:\Users\Efran\anaconda3\envs\compsci682\lib\site-packages\bitsandbytes\libbitsandbytes_cuda118.so...
argument of type 

### Get Params Count

In [2]:
def model_elems_params(model):
    """Count the scalar entries in a model's parameters.

    Args:
        model: object exposing `parameters()` that yields tensors
            (e.g. a `torch.nn.Module`).

    Returns:
        tuple[int, int]: (total number of elements, number of non-zero
        elements) summed over every tensor from `model.parameters()`.
        Returns (0, 0) for a model without parameters.
    """
    total_params = 0
    total_nonzero_params = 0

    for param in model.parameters():
        total_params += param.numel()
        # Convert each count to a Python int right away so the accumulator is
        # always an int; calling .item() on the sum at the end failed with
        # AttributeError when the model had no parameters.
        total_nonzero_params += torch.count_nonzero(param).item()

    return total_params, total_nonzero_params



In [3]:
# Print total / non-zero parameter counts for the epoch-49 checkpoints found in
# each sub-directory of `models/`.
model_path = 'models'
model_list = os.listdir(model_path)

count = 0  # not used in this cell; kept in case another cell still reads it

# (regular, EMA) checkpoint filenames in priority order: the unpruned pair is
# preferred, and the pruned pair is only reported when the unpruned one is absent.
checkpoint_pairs = (
    ('49_ckpt.pt', '49_ema_ckpt.pt'),
    ('pruned_49_ckpt.pt', 'pruned_49_ema_ckpt.pt'),
)

for model_type in model_list:
    model_dir = os.path.join(model_path, model_type)
    # Stray files next to the model folders would make os.listdir raise.
    if not os.path.isdir(model_dir):
        continue
    saved_models = os.listdir(model_dir)

    for pair in checkpoint_pairs:
        if not all(name in saved_models for name in pair):
            continue
        for name in pair:
            ckpt_path = os.path.join(model_dir, name)
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            loaded_model = torch.load(ckpt_path)
            total_params, total_nonzero_params = model_elems_params(loaded_model)
            print(ckpt_path, total_params, total_nonzero_params)
        break  # report only the first complete pair per directory

    print('----------------------------------------------------------')
    


models\base_mnist_ddpm_conditional_ema\49_ckpt.pt 17868145 17868145
models\base_mnist_ddpm_conditional_ema\49_ema_ckpt.pt 17868145 17868145
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
models\d_2_mnist_ddpm_conditional_ema\49_ckpt.pt 14029681 14029681
models\d_2_mnist_ddpm_conditional_ema\49_ema_ckpt.pt 14029681 14029681
----------------------------------------------------------
models\d_4_mnist_ddpm_conditional_ema\49_ckpt.pt 12110449 12110449
models\d_4_mnist_ddpm_conditional_ema\49_ema_ckpt.pt 12110449 12110449
----------------------------------------------------------
models\p_l1_unstructured_10_base_mnist_ddpm_conditional_ema\pruned_49_ckpt.pt 17868145 16122904
models\p_l1_unstructured_10_base_mnist_ddpm_conditional_ema\pruned_49_ema_ckpt.pt 17868145 16122904
----------------------------------------------------------
models\p_l1_unstructured_10_d_2_mn

### Get Model Size

In [10]:
# Re-save each epoch-49 checkpoint from `models/` as a bare state_dict under
# `model_state_dict/<model_type>/`, keeping the original filename.
model_path = 'models'
new_model_path = 'model_state_dict'
model_list = os.listdir(model_path)

# (regular, EMA) checkpoint filenames in priority order: the unpruned pair is
# preferred, and the pruned pair is only exported when the unpruned one is absent.
checkpoint_pairs = (
    ('49_ckpt.pt', '49_ema_ckpt.pt'),
    ('pruned_49_ckpt.pt', 'pruned_49_ema_ckpt.pt'),
)

for model_type in model_list:
    model_dir = os.path.join(model_path, model_type)
    # Stray files next to the model folders would make os.listdir raise.
    if not os.path.isdir(model_dir):
        continue
    saved_models = os.listdir(model_dir)

    for pair in checkpoint_pairs:
        if not all(name in saved_models for name in pair):
            continue

        save_loc = os.path.join(new_model_path, model_type)
        # exist_ok avoids the separate exists() check and its race.
        os.makedirs(save_loc, exist_ok=True)

        # Load and save one checkpoint at a time so only one model is held in
        # memory at once.
        for name in pair:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            loaded_model = torch.load(os.path.join(model_dir, name))
            torch.save(loaded_model.state_dict(), os.path.join(save_loc, name))
        break  # export only the first complete pair per directory

    print('----------------------------------------------------------')

----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
----------------------------------------------------------
--------------------------------------------------------

In [None]:
# Compress the unpruned state-dict checkpoints; -c writes to stdout, so the
# original .pt files are kept next to the new .gz archives.
# NOTE(review): these are bare shell commands -- in a Python-kernel cell they
# need a leading `!` (or a %%bash cell). Presumably run from inside each
# model_state_dict/<model_type> directory -- confirm.
gzip -qfc '49_ckpt.pt' > '49_ckpt.pt.gz'
gzip -qfc '49_ema_ckpt.pt' > '49_ema_ckpt.pt.gz'

In [None]:
# Compress the pruned state-dict checkpoints; -c writes to stdout, so the
# original .pt files are kept next to the new .gz archives.
# NOTE(review): these are bare shell commands -- in a Python-kernel cell they
# need a leading `!` (or a %%bash cell). Presumably run from inside each
# model_state_dict/<model_type> directory -- confirm.
gzip -qfc 'pruned_49_ckpt.pt' > 'pruned_49_ckpt.pt.gz'
gzip -qfc 'pruned_49_ema_ckpt.pt' > 'pruned_49_ema_ckpt.pt.gz'

In [36]:

# Print the size in bytes of every gzip-compressed checkpoint pair found in
# each sub-directory of `model_state_dict/`.
new_model_path = 'model_state_dict'
model_list = os.listdir(new_model_path)

# Archive name pairs, each checked independently (a directory holding both
# pairs reports both), pruned pair first.
archive_pairs = [
    ('pruned_49_ckpt.pt.gz', 'pruned_49_ema_ckpt.pt.gz'),
    ('49_ckpt.pt.gz', '49_ema_ckpt.pt.gz'),
]

for model_type in model_list:
    zipped_models = os.listdir(os.path.join(new_model_path, model_type))

    for cfg_name, ema_name in archive_pairs:
        # Report a pair only when both of its archives are present.
        if cfg_name not in zipped_models or ema_name not in zipped_models:
            continue
        cfg = os.path.join(new_model_path, model_type, cfg_name)
        ema = os.path.join(new_model_path, model_type, ema_name)
        print(cfg, os.path.getsize(cfg))
        print(ema, os.path.getsize(ema))

    print('----------------------------------------------------------')

model_state_dict\base_mnist_ddpm_conditional_ema\49_ckpt.pt.gz 66538096
model_state_dict\base_mnist_ddpm_conditional_ema\49_ema_ckpt.pt.gz 66538147
----------------------------------------------------------
model_state_dict\d_2_mnist_ddpm_conditional_ema\49_ckpt.pt.gz 52203120
model_state_dict\d_2_mnist_ddpm_conditional_ema\49_ema_ckpt.pt.gz 52204450
----------------------------------------------------------
model_state_dict\d_4_mnist_ddpm_conditional_ema\49_ckpt.pt.gz 45030932
model_state_dict\d_4_mnist_ddpm_conditional_ema\49_ema_ckpt.pt.gz 45031920
----------------------------------------------------------
model_state_dict\p_l1_unstructured_10_base_mnist_ddpm_conditional_ema\pruned_49_ckpt.pt.gz 62448687
model_state_dict\p_l1_unstructured_10_base_mnist_ddpm_conditional_ema\pruned_49_ema_ckpt.pt.gz 62448829
----------------------------------------------------------
model_state_dict\p_l1_unstructured_10_d_2_mnist_ddpm_conditional_ema\pruned_49_ckpt.pt.gz 49014554
model_state_dict\p_l1