In [None]:
import os

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import torch
import torchvision.utils as vutils
from matplotlib.animation import FuncAnimation
from torchvision.utils import save_image

from DiffusionFreeGuidence.DiffusionCondition import GaussianDiffusionSampler
from DiffusionFreeGuidence.ModelCondition import UNet
# Notebook-wide matplotlib settings, applied in one call:
#   - render animations as an interactive JS/HTML player,
#   - use a 7x7 inch default figure,
#   - raise the embedded-animation size limit (value is in MB).
plt.rcParams.update({
    'animation.html': 'jshtml',
    "figure.figsize": (7.0, 7.0),
    'animation.embed_limit': 300,
})


In [None]:
# Hyper-parameters and paths for the class-conditional diffusion model.
#
# The sampling cell reads only: batch_size, T, channel, channel_mult,
# num_res_blocks, dropout, beta_1, beta_T, img_size, device, w, save_dir and
# test_load_weight. The other entries are not read by any cell shown in this
# notebook (presumably they are kept so the dict matches the training
# script's configuration -- confirm before removing them).
modelConfig = {
    # NOTE(review): mode flag, not read here; presumably "train" or "eval" -- confirm.
    "state": "eval",
    "epoch": 70,
    "batch_size": 100,  # number of images sampled in one pass
    "T": 500,  # passed to both the UNet and the sampler
    # --- UNet constructor arguments ---
    "channel": 128,
    "channel_mult": [1, 2, 2, 2],
    "num_res_blocks": 2,
    "dropout": 0.15,
    "lr": 1e-4,
    "multiplier": 2.5,
    # --- sampler arguments (together with T) ---
    "beta_1": 1e-4,
    "beta_T": 0.028,
    "img_size": 32,  # side length of the square 3-channel noise tensor
    "grad_clip": 1.,
    "device": "cuda:3",
    "w": 1.8,  # forwarded to the sampler as `w`
    # --- checkpoint location: save_dir + test_load_weight ---
    "save_dir": "./CheckpointsCondition/",
    "training_load_weight": None,
    "test_load_weight": "ckpt_69_.pt",
    "sampled_dir": "./SampledImgs/",
    "sampledNoisyImgName": "NoisyGuidenceImgs.png",
    "sampledImgName": "SampledGuidenceImgs.png",
    "nrow": 10,
}

# Resolve the compute device. The configured value names one specific GPU
# index, which raises on hosts with fewer GPUs (or none), so fall back --
# with a visible message -- to the first GPU, or to the CPU.
requested_device = torch.device(modelConfig["device"])
cuda_ok = torch.cuda.is_available()
if requested_device.type == "cuda" and (
    not cuda_ok or (requested_device.index or 0) >= torch.cuda.device_count()
):
    device = torch.device("cuda:0" if cuda_ok else "cpu")
    print(f"Device '{requested_device}' is unavailable; using '{device}' instead.")
else:
    device = requested_device

# Number of class labels the conditional UNet is built with.
NUM_CLASSES = 10

# Optional reproducibility: add a "seed" entry to modelConfig to make the
# sampled images repeatable. Without that key the RNG is left untouched.
seed = modelConfig.get("seed")
if seed is not None:
    torch.manual_seed(seed)

# load model and evaluate
with torch.no_grad():
    batch_size = modelConfig["batch_size"]
    if batch_size % NUM_CLASSES != 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be a multiple of {NUM_CLASSES} "
            "so every class gets the same number of samples"
        )
    samples_per_class = batch_size // NUM_CLASSES

    # One contiguous run of `samples_per_class` entries per class, so the
    # label tensor always has exactly `batch_size` entries. Class ids are
    # then shifted from 0..9 to 1..10.
    # NOTE(review): the +1 shift presumably reserves label 0 for the
    # unconditional branch of classifier-free guidance -- confirm against
    # the model/sampler implementation.
    labels = torch.tensor(
        [cls for cls in range(NUM_CLASSES) for _ in range(samples_per_class)],
        dtype=torch.long,
        device=device,
    ) + 1

    model = UNet(
        T=modelConfig["T"],
        num_labels=NUM_CLASSES,
        ch=modelConfig["channel"],
        ch_mult=modelConfig["channel_mult"],
        num_res_blocks=modelConfig["num_res_blocks"],
        dropout=modelConfig["dropout"],
    ).to(device)

    # NOTE: torch.load unpickles the file -- only load checkpoints that come
    # from a trusted source.
    ckpt = torch.load(
        os.path.join(modelConfig["save_dir"], modelConfig["test_load_weight"]),
        map_location=device,
    )
    model.load_state_dict(ckpt)
    # Modules start in training mode; switch to eval so that dropout (and any
    # other train-only behaviour) is disabled while sampling.
    model.eval()

    sampler = GaussianDiffusionSampler(
        model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)

    # Start the reverse process from pure Gaussian noise.
    noisyImage = torch.randn(
        size=[batch_size, 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
    # Displayable copy of the starting noise; not used by the cells shown in
    # this notebook, kept for any cell that wants to save or plot it.
    saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)

    # With diffusion_process=True the sampler returns the final batch plus a
    # list of intermediate batches (consumed by the animation cell).
    sampledImg, sampledImg_list = sampler(noisyImage, labels, diffusion_process=True)
    # Rescale for display; assumes the sampler emits values in [-1, 1] -- TODO confirm.
    sampledImg = sampledImg * 0.5 + 0.5  # [0 ~ 1]

In [None]:
# Tile the final samples into a single image grid and display it.
#
# Only non-default arguments are passed to make_grid: the old `range=` keyword
# was renamed to `value_range` and later removed from torchvision, so passing
# `range=None` raises TypeError on current releases. Omitting it behaves the
# same on every version, since None is the default.
sample_grid = vutils.make_grid(
    sampledImg.clamp(0, 1),  # imshow expects float RGB in [0, 1]
    nrow=modelConfig["nrow"],
    padding=2,
    pad_value=0,
)

grid_fig, grid_ax = plt.subplots()
# make_grid returns (C, H, W); imshow wants (H, W, C) on the host.
grid_ax.imshow(sample_grid.permute(1, 2, 0).cpu().numpy())
grid_ax.set_axis_off()
plt.show()

In [None]:
def _to_grid(batch):
    """Turn one batch of sampler output into an (H, W, C) numpy image grid.

    The batch is rescaled with ``x * 0.5 + 0.5`` and clamped to [0, 1] so
    imshow does not have to clip it, then tiled with ``make_grid``.
    ``range=`` is deliberately not passed to ``make_grid``: the keyword was
    removed from torchvision, and its default (None) is what was used before.
    """
    grid = vutils.make_grid(
        (batch * 0.5 + 0.5).clamp(0, 1),
        nrow=modelConfig["nrow"],
        padding=2,
        pad_value=0,
    )
    return grid.permute(1, 2, 0).cpu().numpy()


img_list = [_to_grid(batch) for batch in sampledImg_list]

# Frame selection: keep entries 1 and 3, then everything from index 6 on.
# NOTE(review): the reason for skipping 0, 2, 4 and 5 is not visible here --
# presumably to shorten the near-identical pure-noise start; confirm.
img_list = [img_list[1], img_list[3]] + img_list[6:]
# Ping-pong: append the reversed sequence without its two end frames so the
# loop plays forward then backward without repeating the turning points.
img_list = img_list + img_list[::-1][1:-1]

fig, ax = plt.subplots()
ax.set_axis_off()
# Create the image artist once. Calling ax.imshow() on every frame would
# stack a new image on the axes each time, making the animation slower and
# more memory-hungry the longer it plays.
frame_artist = ax.imshow(img_list[0])


def init():
    """Reset the animation to its first frame."""
    frame_artist.set_data(img_list[0])
    return (frame_artist,)


def update(i):
    """Show frame ``i`` by swapping the pixel data of the single artist."""
    frame_artist.set_data(img_list[i])
    return (frame_artist,)


ani = FuncAnimation(fig, init_func=init, func=update, frames=len(img_list), interval=200, repeat_delay=100)
# Close the static figure so only the animation player is rendered.
plt.close(fig)
ani

In [None]:
# Export the animation to a GIF through Pillow. 5 fps matches the 200 ms
# frame interval the animation was created with.
gif_writer = animation.PillowWriter(fps=5)
ani.save('animation.gif', writer=gif_writer)