In [None]:
import torch
from modules import UNet_conditional
from diffusion import *

In [None]:
def create_sections_list(length, total_sum, schedule_function):
    """Split ``total_sum`` into ``length`` integers weighted by a schedule.

    Position ``i`` (1-based) is weighted by
    ``schedule_function(torch.tensor(i), max_steps=length, k=0.9)``. Every entry
    first receives 1. The remaining ``total_sum - length`` units are then shared
    out in proportion to the weights, rounded down. Any units lost to rounding
    are added back one at a time, starting from index 0.

    Args:
        length: Number of entries to produce.
        total_sum: Value the returned entries sum to; must be >= ``length``.
        schedule_function: Callable ``(step, max_steps=..., k=...)`` returning a
            scalar weight (Python number or 0-d tensor) for ``step``.

    Returns:
        List of ``length`` ints sorted ascending and summing to ``total_sum``
        (each >= 1 when the weights are non-negative). Returns an empty list
        when ``length`` is 0.

    Raises:
        AssertionError: If ``total_sum < length``.
    """
    assert total_sum >= length, "Total sum must be at least equal to the length of the list"
    if length == 0:
        return []

    weights = torch.tensor([float(schedule_function(torch.tensor(i),
                                                    max_steps=length,
                                                    k=0.9))
                            for i in range(1, length + 1)])
    weight_total = torch.sum(weights)
    if weight_total <= 0:
        # All-zero weights would divide by zero and yield NaN. For example,
        # length == 1 with a schedule that is 0 at its last step. Fall back to
        # an even split instead.
        weights = torch.ones(length)
        weight_total = torch.sum(weights)
    normalized_weights = weights / weight_total

    available_sum = total_sum - length
    scaled_weights = normalized_weights * available_sum
    # .int() truncates, so each entry can lose up to one unit here.
    integer_list = (torch.ones(length) + scaled_weights).int().tolist()

    # Restore the units lost to truncation (or remove any excess) one per
    # position. Cycling from index 0 keeps the total exact even if the
    # shortfall exceeds `length`.
    difference = total_sum - sum(integer_list)
    sign = 1 if difference > 0 else -1
    for offset in range(abs(difference)):
        integer_list[offset % length] += sign

    return sorted(integer_list)

def cosine_step_schedule(step, max_steps=1000, k=0.9):
    """Half-cosine weight going from 1 at ``step = 0`` to 0 at ``step = max_steps``.

    Computes ``0.5 * (1 + cos(pi * step / max_steps))`` element-wise.

    Args:
        step: Current step, as a tensor or a plain Python number.
        max_steps: Step at which the weight reaches 0.
        k: Unused. Accepted so this function fits the
            ``schedule_function(step, max_steps=..., k=...)`` call pattern used
            by ``create_sections_list``.

    Returns:
        Tensor with the same shape as ``step``. Values lie in [0, 1] for
        ``0 <= step <= max_steps``.
    """
    # as_tensor lets callers pass plain ints/floats (torch.cos only accepts
    # tensors). It returns tensor inputs unchanged.
    t = torch.as_tensor(step) / max_steps
    return 0.5 * (1 + torch.cos(t * torch.pi))



In [None]:
# --- Configuration -----------------------------------------------------------
path = "models/nophys/ema_ckpt.pt"  # checkpoint holding the UNet state_dict
SEQ_LENGTH = 40           # `length` for both the model and the sampler; keep in sync
NOISE_STEPS = 1000
N_SECTIONS = 10           # number of entries in section_counts
TOTAL_SECTION_STEPS = 25  # sum of section_counts (presumably total sampling steps; confirm in SpacedDiffusion)
SEED = 0                  # fixes the stochastic sampling below for reproducible runs
n = 4                     # number of samples to draw

print("Loading ", path)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# --- Model -------------------------------------------------------------------
model = UNet_conditional(length=SEQ_LENGTH, device=device).to(device)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load trusted checkpoints, or pass weights_only=True on torch >= 1.13.
ckpt = torch.load(path, map_location=device)
model.load_state_dict(ckpt)
# Inference only: disable training-mode behaviour such as dropout or
# batch-norm statistic updates, if the UNet uses any.
model.eval()

# --- Sampler -----------------------------------------------------------------
sampler = SpacedDiffusion(beta_start=1e-4,
                          beta_end=0.02,
                          noise_steps=NOISE_STEPS,
                          section_counts=create_sections_list(N_SECTIONS,
                                                              TOTAL_SECTION_STEPS,
                                                              cosine_step_schedule),
                          length=SEQ_LENGTH,
                          device=device,
                          rescale_timesteps=False)

# Conditioning parameter vector, shape (1, 1).
# NOTE(review): the batch dim is 1 while n samples are drawn. This assumes
# ddim_sample_loop broadcasts or repeats y; confirm.
y = torch.tensor([[0.0]], device=device)

# --- Sampling ----------------------------------------------------------------
torch.manual_seed(SEED)
with torch.no_grad():  # gradients are not needed to draw samples
    x = sampler.ddim_sample_loop(model=model,
                                 y=y,
                                 cfg_scale=1,
                                 device=device,
                                 eta=1,
                                 n=n)  # number of samples