In [1]:
# from data_process import *

# train_loader, test_loader = load_mnist()


In [2]:
from model import *
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

# Choose the compute device once, before any model is created, and fall back
# to CPU so the notebook also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Diffusion (noise-prediction) network, restored from a saved state dict.
model_diff = SimpleUnet().to(device)
print("Num params diffusion: ", sum(p.numel() for p in model_diff.parameters()))
model_path = '/home/longvv/generative_model/models/diffusion_mnist_fixed.pth'
# map_location remaps the checkpoint tensors onto the selected device, so a
# checkpoint written on a GPU can still be loaded on a CPU-only machine.
model_diff.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)

from torchvision import models
# ResNet-18 adapted for classifier guidance:
#   * conv1 takes 2 input channels (classifier_gradient concatenates the image
#     with a constant timestep plane) and uses a 3x3 / stride-1 stem,
#   * the max-pool is replaced by an identity,
#   * the final layer outputs 10 class logits.
# torch.nn is referenced explicitly instead of relying on `nn` leaking in
# through `from model import *`.
model_cls = models.resnet18()
model_cls.conv1 = torch.nn.Conv2d(
    in_channels=2,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False
)
model_cls.maxpool = torch.nn.Identity()
model_cls.fc = torch.nn.Linear(in_features=512, out_features=10, bias=True)
model_cls = model_cls.to(device)
# NOTE(review): no classifier checkpoint is loaded in the visible cells, so
# model_cls carries freshly initialised weights here unless they are restored
# elsewhere -- confirm before using it for guidance.
print("Num params classifier: ", sum(p.numel() for p in model_cls.parameters()))


Num params diffusion:  62248577
Num params classifier:  11173386


In [3]:
# --- Diffusion schedule configuration --------------------------------------
T = 500          # number of diffusion timesteps
start = 0.0001   # beta at the first timestep
end = 0.05       # beta at the last timestep
# Fall back to CPU instead of hard-coding 'cuda', so sampling does not crash on
# machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_SIZE = 32
BATCH_SIZE = 512

# Linear beta schedule and the quantities derived from it. All tensors have
# shape (T,) and are indexed by timestep via get_index_from_list.
betas = torch.linspace(start, end, T)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Cumulative product shifted right by one step; the value "before" step 0 is 1.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Variance of the reverse-step posterior; exactly 0 at index 0 because
# alphas_cumprod_prev[0] == 1.
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def get_index_from_list(vals, t, x_shape):
    """Look up one entry of `vals` per batch element and make it broadcastable.

    Args:
        vals: 1-D tensor of per-timestep values (indexed on the CPU).
        t: integer tensor of indices, one per batch element.
        x_shape: shape of the tensor the result will be combined with; only
            its number of dimensions is used.

    Returns:
        Tensor of shape (len(t), 1, ..., 1) with len(x_shape) dimensions,
        placed on the same device as `t`.
    """
    picked = torch.gather(vals, -1, t.cpu())
    # One leading batch axis followed by singleton axes for broadcasting.
    broadcast_shape = (t.shape[0],) + (1,) * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape).to(t.device)

def classifier_gradient(model, x_t, t, y, scale, num_timesteps=None):
    """Return `scale * d/dx_t log p(y | x_t, t)` for a timestep-aware classifier.

    The classifier input is `x_t` concatenated along the channel axis with one
    extra plane filled with the normalised timestep `t / num_timesteps`.

    Args:
        model: classifier returning class logits of shape (B, num_classes).
            It is switched to eval mode (and left in eval mode).
        x_t: noisy images of shape (B, C, H, W).
        t: integer timesteps, one per batch element (a single value is
            broadcast over the batch).
        y: target class labels: a tensor with one label per batch element, or
            a single int / one-element tensor applied to the whole batch.
        scale: guidance strength multiplied into the gradient.
        num_timesteps: divisor used to normalise `t`; defaults to the
            module-level `T`.

    Returns:
        Detached tensor with the shape of `x_t`.

    Notes:
        * Gradient tracking is enabled locally, so the function also works
          when called from a `torch.no_grad()` context such as a sampler.
        * Only the gradient w.r.t. the input is computed; the classifier's
          parameter `.grad` buffers are not touched.
    """
    if num_timesteps is None:
        num_timesteps = T  # module-level schedule length

    model.eval()
    with torch.enable_grad():
        # Private leaf copy: the caller's tensor is never modified.
        x_in = x_t.detach().clone().requires_grad_(True)
        B, _, H, W = x_in.shape

        # Constant plane holding the normalised timestep of each sample.
        timestep = (t.float() / num_timesteps).view(-1, 1, 1, 1)
        timestep = timestep.expand(B, 1, H, W).to(x_in.device)
        classifier_input = torch.cat([x_in, timestep], dim=1)

        labels = torch.as_tensor(y, dtype=torch.long, device=x_in.device).view(-1)
        if labels.numel() == 1 and B > 1:
            labels = labels.expand(B)

        logits = model(classifier_input)
        log_probs = F.log_softmax(logits, dim=1)
        log_prob_of_y = log_probs.gather(1, labels.view(-1, 1)).squeeze(1)

        # Summing is safe because each sample's log-probability depends only
        # on its own input (assuming the model does not mix batch elements in
        # eval mode), so the gradient of the sum is the per-sample gradient.
        (grad_log,) = torch.autograd.grad(log_prob_of_y.sum(), x_in)

    return scale * grad_log.detach()

@torch.no_grad()
def sample_timestep(model_diff, model_cls, x, t, y, scale):
    """Perform one classifier-guided reverse-diffusion step.

    The diffusion model predicts the noise in `x`, from which the mean of the
    reverse step is computed. For t > 0 that mean is shifted by
    `posterior_variance_t * (scale * grad log p(y | x, t))` and Gaussian noise
    with the posterior variance is added.

    Args:
        model_diff: noise-prediction network, called as `model_diff(x, t)`.
        model_cls: classifier passed on to `classifier_gradient`.
        x: current noisy batch.
        t: integer timestep tensor, one entry per batch element.
        y: target class label(s) passed on to `classifier_gradient`.
        scale: guidance strength passed on to `classifier_gradient`.

    Returns:
        The sample for the previous timestep; when every entry of `t` is 0,
        the unguided mean without added noise.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model_diff(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    # As pointed out by Luis Pereira (see YouTube comment)
    # The t's are offset from the t's in the paper
    # `.all()` keeps the check valid for batches with more than one element.
    # In a mixed batch, entries with t == 0 get no shift and no noise anyway,
    # because the posterior variance is 0 at index 0.
    if bool((t == 0).all()):
        return model_mean

    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    noise = torch.randn_like(x)

    # This function runs under no_grad, but the classifier gradient needs
    # autograd, so gradient tracking is re-enabled just for this call.
    with torch.enable_grad():
        cls_grad = classifier_gradient(model_cls, x, t, y, scale)

    # Guidance shifts the mean by variance * scaled gradient. `cls_grad`
    # already includes `scale`, and the shift must not be multiplied by the
    # random noise.
    guided_mean = model_mean + posterior_variance_t * cls_grad
    return guided_mean + torch.sqrt(posterior_variance_t) * noise

def show_images(img_name, dataset, num_samples=10, cols=5, save = True,
                save_dir='/home/longvv/generative_model/diffusion-real/result/'):
    """Plot the first `num_samples` images of `dataset` in a grid.

    Args:
        img_name: tag used in the saved file name, `output_<img_name>.png`.
        dataset: iterable of image tensors; each is detached, moved to the CPU
            and shown in grayscale after squeezing dimension 0.
        num_samples: maximum number of images to draw.
        cols: number of grid columns.
        save: when truthy, the figure is written to `save_dir` before being
            shown.
        save_dir: output directory; created if it does not exist.
    """
    import math
    import os

    # Ceiling division: exactly as many rows as needed, so no empty trailing
    # row appears when num_samples is a multiple of cols.
    rows = max(1, math.ceil(num_samples / cols))

    plt.figure(figsize=(15,15))
    for i, img in enumerate(dataset):
        if i >= num_samples:
            break
        img = img.detach().to('cpu')
        plt.subplot(rows, cols, i + 1)

        #for mnist image only, if u need colored stuff, stick with 'convert_to_image' function in source code
        plt.imshow(img.squeeze(dim=0), cmap='gray')

    if save:
        # Without this, savefig raises when the result directory is missing.
        os.makedirs(save_dir, exist_ok=True)
        name = 'output_' + str(img_name) + '.png'
        plt.savefig(os.path.join(save_dir, name))
    plt.show()

IMG_SIZE = 32
@torch.no_grad()
def classifier_guidance(model_diff, model_cls, y, scale):
    """Generate one image of class `y` by classifier-guided sampling and plot it.

    Starts from Gaussian noise of shape (1, 1, IMG_SIZE, IMG_SIZE), runs the
    reverse process from timestep T-1 down to 0, keeps ten evenly spaced
    intermediate images and hands them to `show_images` with `save=True`.

    Args:
        model_diff: noise-prediction network.
        model_cls: timestep-aware classifier used for guidance.
        y: target class; anything `int()` accepts (int, one-element tensor,
            numeric string).
        scale: guidance strength.

    Raises:
        ValueError / TypeError: if `y` cannot be converted to a single int.
    """
    label = int(y)
    # The classifier gathers log-probabilities with a label tensor, while the
    # file name needs plain text, so both forms are derived from `label`.
    y_batch = torch.tensor([label], device=device, dtype=torch.long)
    # No '.png' here: show_images appends the extension itself.
    img_name = 'test' + str(label)

    img_size = IMG_SIZE
    img = torch.randn((1, 1, img_size, img_size), device=device)
    img_set = []
    num_images = 10
    # Guard against a zero step (modulo by zero) when T < num_images.
    stepsize = max(1, T // num_images)

    for i in reversed(range(T)):
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(model_diff, model_cls, img, t, y_batch, scale)
        # Edit: This is to maintain the natural range of the distribution
        #img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            img_set.append(img.squeeze(dim = 0))

    show_images(img_name, img_set, save = True)
