# This notebook is used interactively (as a scratch notebook, not a module) to test self-distillation of Meta's DiT model.

# Imports

In [None]:
def self_distillation_CIN(student, sampler_student, original, sampler_original, optimizer, scheduler,
            session=None, steps=20, gradient_updates=200, run_name="test",step_scheduler="naive", x0=False):
    """
    Self-distill a class-conditional (ImageNet, "CIN") latent diffusion model.

    The model acts as its own teacher. For every gradient update the student sampler takes
    one DDIM step with gradients enabled; the same sampler then takes the following DDIM
    step from that output without gradients. The student's x0 prediction is regressed
    (MSE) onto the x0 prediction of that following step, weighted by max(log SNR, 1).
    Copies of the original (untrained) model and sampler are only used to log
    side-by-side comparison images at scheduled intervals.

    Params:
        student, sampler_student: model being distilled and its DDIM sampler.
        original, sampler_original: untouched copy of the model and its sampler (comparison only).
        optimizer: optimizer over the student's parameters.
        scheduler: learning-rate scheduler; only forwarded to ``util.save_model``
            (``scheduler.step()`` is not called here).
        session: logging session; when None nothing is logged.
        steps: number of DDIM steps the schedule is built with (must be >= 2).
        gradient_updates: budget of gradient updates distributed over the step sizes.
        run_name: run name used when saving checkpoints.
        step_scheduler: one of "naive", "iterative", "gradual_linear", "gradual_exp".
        x0: if True, no unconditional conditioning is passed to the samplers.

    Raises:
        ValueError: for an unknown ``step_scheduler`` or ``steps`` < 2.
    """
    if steps < 2:
        # Each gradient update consumes two DDIM steps, so fewer than two steps leaves nothing
        # to train on (this used to surface as a ZeroDivisionError further down).
        raise ValueError(f"steps must be >= 2, got {steps}")

    NUM_CLASSES = 1000
    ddim_steps_student = steps  # Number of DDIM steps of the student schedule
    ddim_eta = 0.0  # eta = 0.0 gives a deterministic output for a given starting noise; essential
    # Both samplers use the same schedule. The original model is never trained; it is kept
    # for comparison purposes only.
    sampler_student.make_schedule(ddim_num_steps=ddim_steps_student, ddim_eta=ddim_eta, verbose=False)
    sampler_original.make_schedule(ddim_num_steps=ddim_steps_student, ddim_eta=ddim_eta, verbose=False)
    scale = 3.0  # CFG scale ($w$ in the paper); re-drawn uniformly from [1, 4) for every generation below
    criterion = nn.MSELoss()

    instance = 0  # Number of student gradient updates performed so far
    generation = 0  # Number of generations (random class prompts) processed so far
    averaged_losses = []
    all_losses = []

    # step_sizes[k] is a DDIM step count to distill at and update_list[k] is the number of
    # generations spent on it. Each generation performs step // 2 gradient updates, because
    # one update consumes two DDIM steps (student step + teacher step).
    if step_scheduler == "iterative":
        # Halve the number of steps repeatedly, splitting the update budget evenly over the halvings.
        halvings = math.floor(math.log(ddim_steps_student) / math.log(2))
        updates_per_halving = int(gradient_updates / halvings)
        step_sizes = [int(steps / (2 ** k)) for k in range(halvings)]
        update_list = [int(updates_per_halving / int(size / 2)) for size in step_sizes]
    elif step_scheduler == "naive":
        # A single step size: spend the whole budget at `steps`.
        step_sizes = [ddim_steps_student]
        update_list = [gradient_updates // int(ddim_steps_student / 2)]
    elif step_scheduler == "gradual_linear":
        # Decrease the step count by 2 per stage (down to 2, or 1 for odd `steps`) with an even
        # allocation of gradient updates; *2 because of 2 steps per update.
        step_sizes = np.arange(steps, 0, -2)
        next_sizes = np.append(step_sizes[1:], 1)
        update_list = (1 / len(next_sizes) * gradient_updates / next_sizes).astype(int) * 2
    elif step_scheduler == "gradual_exp":
        # Experimental: always starts at 64 steps (ignores `steps`) and allocates more updates
        # to the higher step counts; *2 because of 2 steps per update.
        step_sizes = np.arange(64, 0, -2)
        next_sizes = np.append(step_sizes[1:], 1)
        exp_weights = np.exp((1 / next_sizes)[::-1])
        update_list = exp_weights / np.sum(exp_weights)
        update_list = (update_list * gradient_updates / next_sizes).astype(int) * 2
    else:
        raise ValueError(f"Unknown step_scheduler: {step_scheduler!r}")

    updates = 0  # Gradient updates per generation at the current step size
    with torch.no_grad():
        student.use_ema = False
        with student.ema_scope():
            if x0:
                sc = None
            else:
                # Unconditional conditioning for classifier-free guidance: label 1000 (one past
                # the last real class index) is used as the "no class" label.
                sc = student.get_learned_conditioning({student.cond_stage_key: torch.tensor(1*[1000]).to(student.device)})

            for stage, step in enumerate(step_sizes):
                if instance != 0 and "gradual" not in step_scheduler:
                    # Checkpoint the previous step size before moving on. Given the large model
                    # size, the gradual schedules are not saved per stage (steps * 2 * 4.7gb is a lot!).
                    util.save_model(sampler_student, optimizer, scheduler, name=step_scheduler, steps=updates, run_name=run_name)
                updates = int(step / 2)  # Two DDIM steps per gradient update
                generations = update_list[stage]
                print("Distilling to:", step)

                # One random class prompt per generation.
                with tqdm.tqdm(torch.randint(0, NUM_CLASSES, (generations,))) as tepoch:
                    for class_prompt in tepoch:
                        generation += 1
                        losses = []

                        scale = np.random.uniform(1.0, 4.0)  # Random CFG scale per generation (optional)
                        c_student = student.get_learned_conditioning({student.cond_stage_key: torch.tensor([class_prompt]).to(student.device)})

                        samples_ddim = None  # None makes the sampler draw fresh noise for this generation

                        for update_idx in range(updates):
                            # Mixed precision (autocast) could wrap this block, but should not be
                            # used for final results.
                            with torch.enable_grad():
                                instance += 1
                                optimizer.zero_grad()

                                # Student: DDIM step number 2*update_idx, with gradients.
                                samples_ddim, pred_x0_student, _, at = sampler_student.sample_student(S=1,
                                                                    conditioning=c_student,
                                                                    batch_size=1,
                                                                    shape=[3, 64, 64],
                                                                    verbose=False,
                                                                    x_T=samples_ddim,  # start noise or teacher output
                                                                    unconditional_guidance_scale=scale,
                                                                    unconditional_conditioning=sc,
                                                                    eta=ddim_eta,
                                                                    keep_intermediates=False,
                                                                    intermediate_step=update_idx*2,
                                                                    steps_per_sampling=1,
                                                                    total_steps=ddim_steps_student)

                                # Teacher: the following DDIM step (2*update_idx + 1), started from
                                # the student's output, without gradients.
                                with torch.no_grad():
                                    samples_ddim, _, _, pred_x0_teacher, _ = sampler_student.sample(S=1,
                                                                    conditioning=c_student,
                                                                    batch_size=1,
                                                                    shape=[3, 64, 64],
                                                                    verbose=False,
                                                                    x_T=samples_ddim,  # output of student
                                                                    unconditional_guidance_scale=scale,
                                                                    unconditional_conditioning=sc,
                                                                    eta=ddim_eta,
                                                                    keep_intermediates=False,
                                                                    intermediate_step=update_idx*2+1,
                                                                    steps_per_sampling=1,
                                                                    total_steps=ddim_steps_student)

                                # Weight the x0 loss by max(log SNR, 1), with SNR = at / (1 - at).
                                # NOTE(review): Python's max() is applied to a tensor here, which
                                # assumes `at` has a single element (batch_size=1) — confirm.
                                signal = at
                                noise = 1 - at
                                log_snr = torch.log(signal / noise)
                                weight = max(log_snr, 1)
                                loss = weight * criterion(pred_x0_student, pred_x0_teacher.detach())
                                loss.backward()
                                optimizer.step()
                                losses.append(loss.item())

                                if session is not None and instance % 400 == 0:
                                    with torch.no_grad():
                                        # The x0 version keeps max denoising steps to 64.
                                        images, grid = util.compare_teacher_student_x0(original, sampler_original, student, sampler_student, steps=[16, 8,  4, 2, 1], prompt=992, x0=x0)
                                        images = wandb.Image(grid, caption="left: Teacher, right: Student")
                                        wandb.log({"pred_x0": images})

                                        # Important: reset the schedule, as the comparison changes max steps.
                                        sampler_student.make_schedule(ddim_num_steps=ddim_steps_student, ddim_eta=ddim_eta, verbose=False)
                                        sampler_original.make_schedule(ddim_num_steps=ddim_steps_student, ddim_eta=ddim_eta, verbose=False)

                        all_losses.extend(losses)
                        # A step size of 1 yields zero updates, so there may be no loss to average.
                        if losses:
                            averaged_losses.append(sum(losses) / len(losses))
                            if session is not None:
                                session.log({"generation_loss":averaged_losses[-1]})
                            tepoch.set_postfix(epoch_loss=averaged_losses[-1])

            # Save the final model. The per-step-size checkpoint at the top of the loop only ever
            # saves the *previous* step size, so the last (or only) one must be saved here for
            # every scheduler.
            util.save_model(sampler_student, optimizer, scheduler, name=step_scheduler, steps=updates, run_name=run_name)


In [None]:
# Install / upgrade diffusers and timm, and install filelock (IPython shell commands).
!pip install diffusers timm --upgrade
!pip install filelock

In [26]:
from util_DiT import *
import os
%reload_ext autoreload
%autoreload 2
import torch
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
# from models import DiT_XL_2
from models import DiT_S_2
from PIL import Image
from IPython.display import display
# torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("GPU not found. Using CPU instead.")

cwd = os.getcwd()

# Creating Model

In [28]:
# Setting up image sizes
image_size = 256 #@param [256, 512]
vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
# Side length of the latent grid; 8 is presumably the VAE's spatial downsampling factor.
latent_size = int(image_size) // 8

# Load model:
# NOTE(review): requires DiT_XL_2 to have been imported by an earlier cell.
model = DiT_XL_2(input_size=latent_size).to(device)
# model = DiT_XL_2(input_size=latent_size).to(device)
# model = DiT_S_2(input_size=latent_size).to(device)

# Pretrained checkpoint matching the model size and image resolution, resolved by find_model.
state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
# state_dict = find_model(f"DiT-S-2-{image_size}x{image_size}.pt")
model.load_state_dict(state_dict)
# Inference mode; presumably disables training-only behaviour such as label dropout.
model.eval() # important!
vae = AutoencoderKL.from_pretrained(vae_model).to(device)

# Doing a single denoising step

In [4]:
# Sample with DDIM for `num_sampling_steps` steps and display the decoded image.
num_sampling_steps = 4 #@param {type:"slider", min:0, max:1000, step:1}
cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = [207]
n = len(class_labels)  # One sample per requested class label
samples_per_row = 4 #@param {type:"number"}

# Create diffusion object:
diffusion = create_diffusion(str(num_sampling_steps))
# Sample inputs: one noise latent per class label, conditioned on the requested labels
# (previously a random label in [0, 1), i.e. always class 0, was used and class_labels ignored).
z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
y = torch.tensor(class_labels, device=device)

# Setup classifier-free guidance: duplicate the noise and pair the copy with label 1000,
# which is used as the null (unconditional) class.
z = torch.cat([z, z], 0)
y_null = torch.tensor([1000] * n, device=device)
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=cfg_scale)
sample_fn = model.forward_with_cfg


# Sample images:
samples = diffusion.ddim_sample_loop_progressive(
    sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
)

samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

# 0.18215 is presumably the VAE latent scaling factor; undo it before decoding.
samples = vae.decode(samples / 0.18215).sample
save_image(samples, "sample.png", nrow=int(samples_per_row), 
           normalize=True, value_range=(-1, 1))
samples = Image.open("sample.png")
display(samples)


In [5]:
# Take a single denoising step with gradients via the util_DiT helper `sample_step_grad`
# (presumably returns the next sample and the predicted x_start — TODO confirm in util_DiT).
# NOTE(review): `timesteps` is not defined in any visible cell (it may come from
# `from util_DiT import *` — confirm), and `samples` at this point is the PIL image opened by
# the previous cell rather than a latent tensor. The meaning of the literal 4 is not visible here.
samples, pred_xstart = sample_step_grad(model.forward_with_cfg_grad, diffusion, 4, model_kwargs, timesteps, samples)

# Distillation Loop

In [2]:
# Setup for the distillation experiments: imports, CUDA settings, and loading DiT-XL/2 and the VAE.
from util_DiT import *
import os
%reload_ext autoreload
%autoreload 2
import torch
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
from models import DiT_XL_2
from PIL import Image
from IPython.display import display
# Make sure autograd is enabled globally (the distillation loop computes gradients).
torch.set_grad_enabled(True)
from torch.cuda.amp import GradScaler, autocast
import tqdm
# device = "cuda" if torch.cuda.is_available() else "cpu"
# Limit the CUDA caching allocator's block splitting (to reduce fragmentation), and make
# CUDA kernel launches synchronous (easier debugging of errors, but slower).
%env PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32
%env CUDA_LAUNCH_BLOCKING=1
# NOTE(review): device is hard-coded to "cuda", so the CPU branch below can never run.
device = "cuda"
if device == "cpu":
    print("GPU not found. Using CPU instead.")

cwd = os.getcwd()
# Setting up image sizes
image_size = 256 #@param [256, 512]
vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
latent_size = int(image_size) // 8

# NOTE(review): autocast only changes the precision of ops executed inside it; it presumably
# has no effect on constructing/loading the models here — confirm before relying on it.
with autocast():
    # Load model:
    model = DiT_XL_2(input_size=latent_size).to(device)
    # original = DiT_XL_2(input_size=latent_size).to(device)
    state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
    model.load_state_dict(state_dict)
    # original.load_state_dict(state_dict)
    # original.eval()
    model.eval() # important!
    # The VAE is kept on the CPU (presumably to leave GPU memory for DiT training).
    vae = AutoencoderKL.from_pretrained(vae_model).to("cpu")
# Free the checkpoint tensors and cached GPU memory before training.
del state_dict, vae_model
torch.cuda.empty_cache()
# Presumably the values for the `steps` / `generations` / `decrease_steps` parameters of the
# distillation function defined later in the notebook.
steps = 20
generations = 10
decrease_steps = False



  from .autonotebook import tqdm as notebook_tqdm


env: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32
env: CUDA_LAUNCH_BLOCKING=1


In [3]:
# Print every parameter name of the model, one per line, in registration order.
for name in (pname for pname, _tensor in model.named_parameters()):
    print(name)

pos_embed
x_embedder.proj.weight
x_embedder.proj.bias
t_embedder.mlp.0.weight
t_embedder.mlp.0.bias
t_embedder.mlp.2.weight
t_embedder.mlp.2.bias
y_embedder.embedding_table.weight
blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight
blocks.0.attn.proj.bias
blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias
blocks.0.adaLN_modulation.1.weight
blocks.0.adaLN_modulation.1.bias
blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight
blocks.1.attn.proj.bias
blocks.1.mlp.fc1.weight
blocks.1.mlp.fc1.bias
blocks.1.mlp.fc2.weight
blocks.1.mlp.fc2.bias
blocks.1.adaLN_modulation.1.weight
blocks.1.adaLN_modulation.1.bias
blocks.2.attn.qkv.weight
blocks.2.attn.qkv.bias
blocks.2.attn.proj.weight
blocks.2.attn.proj.bias
blocks.2.mlp.fc1.weight
blocks.2.mlp.fc1.bias
blocks.2.mlp.fc2.weight
blocks.2.mlp.fc2.bias
blocks.2.adaLN_modulation.1.weight
blocks.2.adaLN_modulation.1.bias
blocks.3.attn.qkv.weight
blocks.3.attn.qkv.bia

In [4]:
# Just checking to see whether updating is even possible: it is, but only when training the
# parameters whose name contains "linear", as they don't take up much memory.
# All other parameters are frozen. Otherwise backward() would still compute their gradients, and
# since optimizer.zero_grad() only clears the optimizer's own parameters, those .grad tensors
# would accumulate across steps and occupy memory for nothing.
params_to_update = []
for name, param in model.named_parameters():
    is_trainable = 'linear' in name
    param.requires_grad_(is_trainable)
    if is_trainable:
        params_to_update.append(param)

if not params_to_update:
    raise ValueError("No model parameter name contains 'linear'; nothing to optimize.")

optimizer = torch.optim.Adam([
    {'params': params_to_update, 'lr': 0.001},  # Only the "linear" parameters are optimized
])

In [2]:
# # # Set user inputs:
# # seed = 0 #@param {type:"number"}
# # torch.manual_seed(seed)
# num_sampling_steps = 4 #@param {type:"slider", min:0, max:1000, step:1}
# cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
# class_labels = [992] #@param {type:"raw"}
# samples_per_row = 1 #@param {type:"number"}

# # Create diffusion object:
# diffusion = create_diffusion(str(num_sampling_steps))

# # Create sampling noise:
# n = 1
# z = torch.randn(n, 4, latent_size, latent_size, device=device)
# y = torch.tensor(class_labels, device=device)

# # Setup classifier-free guidance:
# z = torch.cat([z, z], 0)
# y_null = torch.tensor([1000] * n, device=device)
# y = torch.cat([y, y_null], 0)
# model_kwargs = dict(y=y, cfg_scale=cfg_scale)

# # Sample images:
# samples = diffusion.ddim_sample_loop_progressive(
#     model.forward_with_cfg, z.shape, z, clip_denoised=False, 
#     model_kwargs=model_kwargs, progress=True, device=device
# )
# samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
# samples = vae.decode(samples / 0.18215).sample

# # Save and display images:
# save_image(samples, "sample.png", nrow=int(samples_per_row), 
#            normalize=True, value_range=(-1, 1))
# samples = Image.open("sample.png")
# display(samples)

# With intermediate steps!

In [3]:
# # Set user inputs:
# seed = 0 #@param {type:"number"}
# torch.manual_seed(seed)
# import tqdm
# num_sampling_steps =2 #@param {type:"slider", min:0, max:1000, step:1}
# cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
# class_labels = [992] #@param {type:"raw"}
# samples_per_row = 1 #@param {type:"number"}

# # Create diffusion object:
# diffusion = create_diffusion(str(num_sampling_steps))
# diffusion_original = create_diffusion(str(num_sampling_steps))

# # Create sampling noise:
# n = 1
# z = torch.randn(n, 4, latent_size, latent_size, device=device)
# y = torch.tensor(class_labels, device=device)

# # Setup classifier-free guidance:
# z = torch.cat([z, z], 0)
# y_null = torch.tensor([1000] * n, device=device)
# y = torch.cat([y, y_null], 0)
# model_kwargs = dict(y=y, cfg_scale=cfg_scale)


# samples = torch.randn(z.shape, device=device)
# indices = list(range(diffusion.num_timesteps))[::-1]
# for i in tqdm.tqdm(indices):
#     print(i)
#     samples = diffusion.ddim_sample_loop_progressive_intermediate(
#             model.forward_with_cfg, z.shape, noise=samples, clip_denoised=False, 
#             model_kwargs=model_kwargs, progress=False, device=device, step=i)
#     img, _ = samples.chunk(2, dim=0)  # Remove null class samples
#     img = vae.decode(img / 0.18215).sample
#     if i == 1:
#         last = img
#     else:
#         final = img
#     # Save and display images:
#     save_image(img, "sample.png", nrow=int(samples_per_row), 
#             normalize=True, value_range=(-1, 1))
#     img = Image.open("sample.png")
#     display(img)



In [5]:

import torch.nn as nn



def self_distillation_dit(model, optimizer,
            session=None, steps=20, generations=200, early_stop=True, run_name="test", decrease_steps=False,
            step_scheduler="deterministic", type="snellius"):
    """
    Distill a model into itself. This is done by having a (teacher) model distill knowledge into itself. Copies of the original model and sampler 
    are passed in to compare the original untrained version with the distilled model at scheduled intervals.
    """
    NUM_CLASSES = 1000
    gradient_updates = generations
    ddim_steps_student = steps
    TEACHER_STEPS = 2
    ddim_eta = 0.0
    scale = 3.0
    optimizer=optimizer
    averaged_losses = []
    criterion = nn.MSELoss()
    instance = 0
    generation = 0
    all_losses = []
    num_sampling_steps = 64

    if step_scheduler == "iterative":
        halvings = math.floor(math.log(64)/math.log(2))
        updates_per_halving = int(gradient_updates / halvings)
        step_sizes = []
        for i in range(halvings):
            step_sizes.append(int((steps) / (2**i)))
        update_list = []
        for i in step_sizes:
            update_list.append(int(updates_per_halving / int(i/ 2)))
    elif step_scheduler == "naive":
        step_sizes=[ddim_steps_student]
        update_list=[gradient_updates // int(ddim_steps_student / 2)]
    elif step_scheduler == "gradual_linear":
        step_sizes = np.arange(steps, 0, -2)
        update_list = (1/len(np.append(step_sizes[1:], 1)) * gradient_updates / np.append(step_sizes[1:], 1)).astype(int)
    elif step_scheduler == "gradual_exp":
        step_sizes = np.arange(64, 0, -2)
        update_list = np.exp(1 / np.append(step_sizes[1:],1)) / np.sum(np.exp(1 / np.append(step_sizes[1:],1)))
        update_list = (update_list * gradient_updates /  np.append(step_sizes[1:],1)).astype(int)

    scaler = GradScaler()
    diffusion = create_diffusion(str(num_sampling_steps))
    
    n = 1
    num_sampling_steps =4 #@param {type:"slider", min:0, max:1000, step:1}
    cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
    samples_per_row = 1 #@param {type:"number"}
    indices = list(range(diffusion.num_timesteps))[::-1]
    with torch.no_grad():
        # with student.ema_scope():              
              

                for i, step in enumerate(step_sizes):
                    # if instance != 0 and "gradual" not in step_scheduler:
                    #     util.save_model(sampler_student, optimizer, scheduler, name=step_scheduler, steps=updates, run_name=run_name)
                    updates = int(step / 2)
                    generations = update_list[i]
                    print("Distilling to:", updates)
                         
                    
                    # sc = student.get_learned_conditioning({student.cond_stage_key: torch.tensor(1*[1000]).to(student.device)})
                    
                    
                    with tqdm.tqdm(torch.randint(0, NUM_CLASSES, (generations,))) as tepoch:

                        for i, class_prompt in enumerate(tepoch):
                            generation += 1
                            losses = []        
                            class_labels = torch.tensor([class_prompt])
                            z = torch.randn(n, 4, latent_size, latent_size, device=device)
                            y = torch.tensor(class_labels, device=device)
                            z = torch.cat([z, z], 0)
                            y_null = torch.tensor([1000] * n, device=device)
                            y = torch.cat([y, y_null], 0)
                            model_kwargs = dict(y=y, cfg_scale=cfg_scale)
                            # teacher_kwargs = model_kwargs.copy()
                            samples_ddim= None
                            predictions_temp = []
                            
                            samples_teacher = torch.randn(z.shape, device=device)
                                
                            for steps in range(updates):  
                                    optimizer.zero_grad()
                                    with autocast():

                                        with torch.enable_grad():
                                            
                                            instance += 1
                                        
                                        
                                            samples_student = diffusion.ddim_sample_loop_progressive_intermediate(
                                                    model.forward_with_cfg, z.shape, noise=samples_teacher, clip_denoised=False, 
                                                    model_kwargs=model_kwargs, progress=False, device=device, step=steps*2, student=True)
                                            
                                            samples_student_x0, _ = samples_student.chunk(2, dim=0)
                                            # samples_student_x0, _ = samples.chunk(2, dim=0)
                                            samples_student.detach()

                                        with torch.no_grad():
                                            
                                            samples_teacher = diffusion.ddim_sample_loop_progressive_intermediate(
                                                    model.forward_with_cfg, z.shape, noise=samples_student, clip_denoised=False, 
                                                    model_kwargs=model_kwargs, progress=False, device=device, step=steps*2+1)   
                                        
                                            samples_teacher_x0, _ = samples_teacher.chunk(2, dim=0)
                                            samples_teacher.detach()
                                            
                                        with torch.enable_grad():    

                                            # # AUTOCAST:
                                            # signal = at
                                            # noise = 1 - at
                                            # log_snr = torch.log(signal / noise)
                                            # weight = max(log_snr, 1)
                                            # loss = weight * criterion(pred_x0_student, pred_x0_teacher.detach())
                                            loss = criterion(samples_student_x0, samples_teacher_x0.detach())
                                            # loss = criterion(samples_student[0], samples_teacher[0].detach())
                                            scaler.scale(loss).backward()
                                            scaler.step(optimizer)
                                            scaler.update()
                                            # torch.nn.utils.clip_grad_norm_(sampler_student.model.parameters(), 1)
                                            
                                            # scheduler.step()
                                            losses.append(loss.item())

                                            torch.cuda.empty_cache()
                                            # # NO AUTOCAST:
                                            # # signal = at
                                            # # noise = 1 - at
                                            # # log_snr = torch.log(signal / noise)
                                            # # weight = max(log_snr, 1)
                                            # # loss = criterion(samples_student, samples_teacher.detach())
                                            # loss = criterion(samples_student_x0, samples_teacher_x0.detach())
                                            # # loss = criterion(pred_x0_student, pred_x0_teacher.detach())
                                            # loss.backward()
                                            # optimizer.step()
                                            # # scheduler.step()
                                            # # torch.nn.utils.clip_grad_norm_(sampler_student.model.parameters(), 1)
                                            
                                            # losses.append(loss.item())
                                            
                                            
                                        # if session != None and generation % 200 == 0 and generation > 0:
                                                
                                        #     x_T_teacher_decode = sampler_student.model.decode_first_stage(pred_x0_teacher)
                                        #     teacher_target = torch.clamp((x_T_teacher_decode+1.0)/2.0, min=0.0, max=1.0)
                                        #     x_T_student_decode = sampler_student.model.decode_first_stage(pred_x0_student.detach())
                                        #     student_target  = torch.clamp((x_T_student_decode +1.0)/2.0, min=0.0, max=1.0)
                                        #     predictions_temp.append(teacher_target)
                                        #     predictions_temp.append(student_target)
                                            
                                        
                                    

                                        # if session != None and instance % 10000 == 0 and generation > 0:
                                        #     fids = util.get_fid(student, sampler_student, num_imgs=100, name=run_name, instance = instance+1, steps=[64, 32, 16, 8, 4, 2, 1])
                                        #     session.log({"fid_64":fids[0]})
                                        #     session.log({"fid_32":fids[1]})
                                        #     session.log({"fid_16":fids[2]})
                                        #     session.log({"fid_8":fids[3]})
                                        #     session.log({"fid_4":fids[4]})
                                        #     session.log({"fid_2":fids[5]})
                                        #     session.log({"fid_1":fids[6]})
                                        
                #                         if session != None and instance % 2000 == 0:
                                            
                #                             with torch.no_grad():
                #                                 images, _ = util.compare_teacher_student(original, sampler_original, student, sampler_student, steps=[64, 32, 16, 8,  4, 2, 1], prompt=992)
                #                                 images = wandb.Image(_, caption="left: Teacher, right: Student")
                #                                 wandb.log({"pred_x0": images})
                #                                 # images, _ = util.compare_teacher_student_with_schedule(original, sampler_original, student, sampler_student, steps=[64, 32, 16, 8,  4, 2, 1], prompt=992)
                #                                 # images = wandb.Image(_, caption="left: Teacher, right: Student")
                #                                 # wandb.log({"schedule": images})
                #                                 sampler_student.make_schedule(ddim_num_steps=ddim_steps_student, ddim_eta=ddim_eta, verbose=False)
                #                                 sampler_original.make_schedule(ddim_num_steps=ddim_steps_student, ddim_eta=ddim_eta, verbose=False)

                #             if generation > 0 and generation % 20 == 0 and ddim_steps_student != 1 and step_scheduler=="FID":
                #                 fid = util.get_fid(student, sampler_student, num_imgs=100, name=run_name, 
                #                             instance = instance, steps=[ddim_steps_student])
                #                 if fid[0] <= current_fid[0] * 0.9 and decrease_steps==True:
                #                     print(fid[0], current_fid[0])
                #                     if ddim_steps_student in [16, 8, 4, 2, 1]:
                #                         name = "intermediate"
                #                         saving_loading.save_model(sampler_student, optimizer, scheduler, name, steps * 2, run_name)
                #                     if ddim_steps_student != 2:
                #                         ddim_steps_student -= 2
                #                         updates -= 1
                #                     else:
                #                         ddim_steps_student = 1
                #                         updates = 1    
                #                     current_fid = fid
                #                     print("steps decreased:", ddim_steps_student)    

                #             if session != None:
                #                 with torch.no_grad():
                #                     if session != None and generation % 200 == 0 and generation > 0:
                #                         img, grid = util.compare_latents(predictions_temp)
                #                         images = wandb.Image(grid, caption="left: Teacher, right: Student")
                #                         wandb.log({"Inter_Comp": images})
                #                         del img, grid, predictions_temp, x_T_student_decode, x_T_teacher_decode, student_target, teacher_target
                #                         torch.cuda.empty_cache()
                            
                            all_losses.extend(losses)
                            averaged_losses.append(sum(losses) / len(losses))
                            print(averaged_losses[-1])
                            tepoch.set_postfix(epoch_loss=averaged_losses[-1])

                # if step_scheduler == "naive" or "gradual" in step_scheduler:
                #     util.save_model(sampler_student, optimizer, scheduler, name=step_scheduler, steps=updates, run_name=run_name)

# Module-level settings for this distillation run.
# NOTE(review): `lr` is not consumed anywhere in this cell. The optimizer
# construction that used it is disabled (see below), so `optimizer` must already
# exist from an earlier cell. Confirm it was built with the intended learning rate.
num_sampling_steps, lr = 4, 0.001

# Disabled alternatives kept from earlier experiments:
#   optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
#   self_distillation_dit(model, original, optimizer, step_scheduler="naive")

# Self-distill `model` using the "naive" step schedule. `model`, `optimizer` and
# `self_distillation_dit` are expected to be defined by earlier notebook cells.
self_distillation_dit(model, optimizer, step_scheduler="naive")

Distilling to: 10


  y = torch.tensor(class_labels, device=device)
  5%|▌         | 1/20 [00:06<02:05,  6.63s/it, epoch_loss=0.000634]

0.00063366120448336


 10%|█         | 2/20 [00:10<01:32,  5.11s/it, epoch_loss=0.000708]

0.0007082553289365024


 15%|█▌        | 3/20 [00:14<01:18,  4.61s/it, epoch_loss=0.000913]

0.0009127709694439545


 20%|██        | 4/20 [00:18<01:09,  4.37s/it, epoch_loss=0.00118] 

0.001177606068085879


 25%|██▌       | 5/20 [00:22<01:03,  4.25s/it, epoch_loss=0.00136]

0.00136482025263831


 30%|███       | 6/20 [00:26<00:58,  4.15s/it, epoch_loss=0.00143]

0.0014326108328532427


 35%|███▌      | 7/20 [00:30<00:57,  4.39s/it, epoch_loss=0.00165]

0.0016477228025905788





In [5]:
# from torch.cuda.amp import GradScaler, autocast
# import torch.nn as nn
# scaler = GradScaler()


# def self_distillation_dit(model, original, optimizer,
#             session=None, steps=20, generations=200, early_stop=True, run_name="test", decrease_steps=False,
#             step_scheduler="deterministic", type="snellius"):
#     """
#     Distill a model into itself. This is done by having a (teacher) model distill knowledge into itself. Copies of the original model and sampler 
#     are passed in to compare the original untrained version with the distilled model at scheduled intervals.
#     """
#     NUM_CLASSES = 1000
#     gradient_updates = generations
#     ddim_steps_student = steps
#     TEACHER_STEPS = 2
#     ddim_eta = 0.0
#     scale = 3.0
#     optimizer=optimizer
#     averaged_losses = []
#     criterion = nn.MSELoss()
#     instance = 0
#     generation = 0
#     all_losses = []
#     num_sampling_steps = 64

#     if step_scheduler == "iterative":
#         halvings = math.floor(math.log(64)/math.log(2))
#         updates_per_halving = int(gradient_updates / halvings)
#         step_sizes = []
#         for i in range(halvings):
#             step_sizes.append(int((steps) / (2**i)))
#         update_list = []
#         for i in step_sizes:
#             update_list.append(int(updates_per_halving / int(i/ 2)))
#     elif step_scheduler == "naive":
#         step_sizes=[ddim_steps_student]
#         update_list=[gradient_updates // int(ddim_steps_student / 2)]
#     elif step_scheduler == "gradual_linear":
#         step_sizes = np.arange(steps, 0, -2)
#         update_list = (1/len(np.append(step_sizes[1:], 1)) * gradient_updates / np.append(step_sizes[1:], 1)).astype(int)
#     elif step_scheduler == "gradual_exp":
#         step_sizes = np.arange(64, 0, -2)
#         update_list = np.exp(1 / np.append(step_sizes[1:],1)) / np.sum(np.exp(1 / np.append(step_sizes[1:],1)))
#         update_list = (update_list * gradient_updates /  np.append(step_sizes[1:],1)).astype(int)


#     diffusion = create_diffusion(str(num_sampling_steps))
#     diffusion_original = create_diffusion(str(num_sampling_steps)) 
#     n = 1
#     num_sampling_steps =4 #@param {type:"slider", min:0, max:1000, step:1}
#     cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
#     samples_per_row = 1 #@param {type:"number"}
#     indices = list(range(diffusion.num_timesteps))[::-1]
#     with torch.no_grad():
#         # with student.ema_scope():              
              

#                 for i, step in enumerate(step_sizes):
#                     # if instance != 0 and "gradual" not in step_scheduler:
#                     #     util.save_model(sampler_student, optimizer, scheduler, name=step_scheduler, steps=updates, run_name=run_name)
#                     updates = int(step / 2)
#                     generations = update_list[i]
#                     print("Distilling to:", updates)
                         
                    
#                     # sc = student.get_learned_conditioning({student.cond_stage_key: torch.tensor(1*[1000]).to(student.device)})
                    
                    
#                     with tqdm.tqdm(torch.randint(0, NUM_CLASSES, (generations,))) as tepoch:

#                         for i, class_prompt in enumerate(tepoch):
#                             generation += 1
#                             losses = []        
#                             class_labels = torch.tensor([class_prompt])
#                             z = torch.randn(n, 4, latent_size, latent_size, device=device)
#                             y = torch.tensor(class_labels, device=device)
#                             z = torch.cat([z, z], 0)
#                             y_null = torch.tensor([1000] * n, device=device)
#                             y = torch.cat([y, y_null], 0)
#                             model_kwargs = dict(y=y, cfg_scale=cfg_scale)
#                             # teacher_kwargs = model_kwargs.copy()
#                             samples_ddim= None
#                             predictions_temp = []
                            
#                             samples_teacher = torch.randn(z.shape, device=device)
                                
#                             for steps in range(updates):  
#                                         optimizer.zero_grad()
#                                     # with autocast():

#                                         with torch.enable_grad():
                                            
#                                             instance += 1
                                        
                                        
#                                             samples_student = diffusion.ddim_sample_loop_progressive_intermediate(
#                                                     model.forward_with_cfg, z.shape, noise=samples_teacher, clip_denoised=False, 
#                                                     model_kwargs=model_kwargs, progress=False, device=device, step=steps*2, student=True)
#                                             samples_student_x0, _ = samples_student.chunk(2, dim=0)
#                                             # samples_student_x0, _ = samples.chunk(2, dim=0)

#                                         with torch.no_grad():
                                            
#                                             samples_teacher = diffusion_original.ddim_sample_loop_progressive_intermediate(
#                                                     original.forward_with_cfg, z.shape, noise=samples_student, clip_denoised=False, 
#                                                     model_kwargs=model_kwargs, progress=False, device=device, step=steps*2+1)   
                                        
#                                             samples_teacher_x0, _ = samples_teacher.chunk(2, dim=0)
                                    
#                                         with torch.enable_grad():    
#                                             print(samples_student.requires_grad)
#                                             # # AUTOCAST:
#                                             # signal = at
#                                             # noise = 1 - at
#                                             # log_snr = torch.log(signal / noise)
#                                             # weight = max(log_snr, 1)
#                                             # loss = weight * criterion(pred_x0_student, pred_x0_teacher.detach())
#                                             # scaler.scale(loss).backward()
#                                             # scaler.step(optimizer)
#                                             # scaler.update()
#                                             # # torch.nn.utils.clip_grad_norm_(sampler_student.model.parameters(), 1)
                                            
#                                             # scheduler.step()
#                                             # losses.append(loss.item())

                                            
#                                             # NO AUTOCAST:
#                                             # signal = at
#                                             # noise = 1 - at
#                                             # log_snr = torch.log(signal / noise)
#                                             # weight = max(log_snr, 1)
#                                             # loss = criterion(samples_student, samples_teacher.detach())
#                                             loss = criterion(samples_student_x0, samples_teacher_x0.detach())
#                                             # loss = criterion(pred_x0_student, pred_x0_teacher.detach())
#                                             loss.backward()
#                                             optimizer.step()
#                                             # scheduler.step()
#                                             # torch.nn.utils.clip_grad_norm_(sampler_student.model.parameters(), 1)
                                            
#                                             losses.append(loss.item())
                                            
#                                         # if session != None and generation % 200 == 0 and generation > 0:
                                                
#                                         #     x_T_teacher_decode = sampler_student.model.decode_first_stage(pred_x0_teacher)
#                                         #     teacher_target = torch.clamp((x_T_teacher_decode+1.0)/2.0, min=0.0, max=1.0)
#                                         #     x_T_student_decode = sampler_student.model.decode_first_stage(pred_x0_student.detach())
#                                         #     student_target  = torch.clamp((x_T_student_decode +1.0)/2.0, min=0.0, max=1.0)
#                                         #     predictions_temp.append(teacher_target)
#                                         #     predictions_temp.append(student_target)
                                            
                                        
                                    

#                                         # if session != None and instance % 10000 == 0 and generation > 0:
#                                         #     fids = util.get_fid(student, sampler_student, num_imgs=100, name=run_name, instance = instance+1, steps=[64, 32, 16, 8, 4, 2, 1])
#                                         #     session.log({"fid_64":fids[0]})
#                                         #     session.log({"fid_32":fids[1]})
#                                         #     session.log({"fid_16":fids[2]})
#                                         #     session.log({"fid_8":fids[3]})
#                                         #     session.log({"fid_4":fids[4]})
#                                         #     session.log({"fid_2":fids[5]})
#                                         #     session.log({"fid_1":fids[6]})
                                        
#                 #                         if session != None and instance % 2000 == 0:
                                            
#                 #                             with torch.no_grad():
#                 #                                 images, _ = util.compare_teacher_student(original, sampler_original, student, sampler_student, steps=[64, 32, 16, 8,  4, 2, 1], prompt=992)
#                 #                                 images = wandb.Image(_, caption="left: Teacher, right: Student")
#                 #                                 wandb.log({"pred_x0": images})
#                 #                                 # images, _ = util.compare_teacher_student_with_schedule(original, sampler_original, student, sampler_student, steps=[64, 32, 16, 8,  4, 2, 1], prompt=992)
#                 #                                 # images = wandb.Image(_, caption="left: Teacher, right: Student")
#                 #                                 # wandb.log({"schedule": images})
#                 #                                 sampler_student.make_schedule(ddim_num_steps=ddim_steps_student, ddim_eta=ddim_eta, verbose=False)
#                 #                                 sampler_original.make_schedule(ddim_num_steps=ddim_steps_student, ddim_eta=ddim_eta, verbose=False)

#                 #             if generation > 0 and generation % 20 == 0 and ddim_steps_student != 1 and step_scheduler=="FID":
#                 #                 fid = util.get_fid(student, sampler_student, num_imgs=100, name=run_name, 
#                 #                             instance = instance, steps=[ddim_steps_student])
#                 #                 if fid[0] <= current_fid[0] * 0.9 and decrease_steps==True:
#                 #                     print(fid[0], current_fid[0])
#                 #                     if ddim_steps_student in [16, 8, 4, 2, 1]:
#                 #                         name = "intermediate"
#                 #                         saving_loading.save_model(sampler_student, optimizer, scheduler, name, steps * 2, run_name)
#                 #                     if ddim_steps_student != 2:
#                 #                         ddim_steps_student -= 2
#                 #                         updates -= 1
#                 #                     else:
#                 #                         ddim_steps_student = 1
#                 #                         updates = 1    
#                 #                     current_fid = fid
#                 #                     print("steps decreased:", ddim_steps_student)    

#                 #             if session != None:
#                 #                 with torch.no_grad():
#                 #                     if session != None and generation % 200 == 0 and generation > 0:
#                 #                         img, grid = util.compare_latents(predictions_temp)
#                 #                         images = wandb.Image(grid, caption="left: Teacher, right: Student")
#                 #                         wandb.log({"Inter_Comp": images})
#                 #                         del img, grid, predictions_temp, x_T_student_decode, x_T_teacher_decode, student_target, teacher_target
#                 #                         torch.cuda.empty_cache()
                            
#                             all_losses.extend(losses)
#                             averaged_losses.append(sum(losses) / len(losses))
#                             print(averaged_losses[-1])
#                             tepoch.set_postfix(epoch_loss=averaged_losses[-1])

#                 # if step_scheduler == "naive" or "gradual" in step_scheduler:
#                 #     util.save_model(sampler_student, optimizer, scheduler, name=step_scheduler, steps=updates, run_name=run_name)

# num_sampling_steps = 4
# lr = 0.001
# optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
# self_distillation_dit(model, original, optimizer, step_scheduler="naive")

In [6]:
# n = 1
# num_sampling_steps =4 #@param {type:"slider", min:0, max:1000, step:1}
# cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
# samples_per_row = 1 #@param {type:"number"}
# indices = list(range(diffusion.num_timesteps))[::-1]
# class_labels = torch.tensor([992])
# z = torch.randn(n, 4, latent_size, latent_size, device=device)
# y = torch.tensor(class_labels, device=device)
# z = torch.cat([z, z], 0)
# y_null = torch.tensor([1000] * n, device=device)
# y = torch.cat([y, y_null], 0)
# model_kwargs = dict(y=y, cfg_scale=cfg_scale)

# # Sample images:
# samples = diffusion.ddim_sample_loop_progressive(
#     model.forward_with_cfg, z.shape, z, clip_denoised=False, 
#     model_kwargs=model_kwargs, progress=True, device=device
# )
# samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
# samples = vae.decode(samples / 0.18215).sample

# # Save and display images:
# save_image(samples, "sample.png", nrow=int(samples_per_row), 
#            normalize=True, value_range=(-1, 1))
# samples = Image.open("sample.png")
# display(samples)

In [7]:
# # Sample images:
# samples = diffusion_original.ddim_sample_loop_progressive(
#     original.forward_with_cfg, z.shape, z, clip_denoised=False, 
#     model_kwargs=model_kwargs, progress=True, device=device
# )
# samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
# samples = vae.decode(samples / 0.18215).sample

# # Save and display images:
# save_image(samples, "sample.png", nrow=int(samples_per_row), 
#            normalize=True, value_range=(-1, 1))
# samples_orig = Image.open("sample.png")
# display(samples_orig)

In [8]:
# samples == samples_orig