In [43]:
import torch
import yaml
import os
import numpy as np
from normalizing_flows import models
import tempfile
import matplotlib.pyplot as plt
import imageio

In [3]:
# Load the model hyper-parameters.
# The file carries a .json extension but is parsed with the YAML safe loader.
# The context manager closes the file handle as soon as parsing is done
# instead of leaving it open for the lifetime of the kernel.
with open("./model/params.json") as params_file:
    params = yaml.safe_load(params_file)

In [26]:
# Fix torch's global RNG state before the model is constructed.
torch.manual_seed(42)
# Single device handle shared by the model placement and the checkpoint load.
device = torch.device("cuda:0")
glow = models.Glow(K=params["model_params"]["num_step_of_flow"],
                   L=params["model_params"]["num_scales"],
                   base_input_shape=[3, 64, 64]).to(device)
# The checkpoint has to be downloaded manually into ./model/ beforehand
# (the original notebook left the download URL blank).
# map_location remaps the stored tensors onto the model's device, so loading
# does not depend on which device the checkpoint was saved from.
ckpt = torch.load("./model/best_model_99_20230701_133810.pt",
                  map_location=device)
# Kept as the last expression so the cell displays the key-matching report.
glow.load_state_dict(ckpt["model_state_dict"])

<All keys matched successfully>

In [58]:
# Sampling configuration.
# Grid layout: n_row * n_col samples are drawn for every frame.
n_row = 3
n_col = 3
# Temperature sweep: 100 evenly spaced values from 0 to 1 (both inclusive).
temperatures = np.linspace(0., 1., 100)
# Seed forwarded to the sampler for every temperature.
seed = 42

In [61]:
# Render one grid of samples per temperature into a temporary directory,
# then stitch the frames into ./temperatures.gif.
with tempfile.TemporaryDirectory() as temp_directory:
    # Frame paths are collected in generation order, so the GIF never depends
    # on how the file names happen to sort on disk.
    frame_paths = []
    for frame_idx, T in enumerate(temperatures):
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            generated_image = glow.sample(T=T, fixed_sample=True,
                                          seed=seed, num_samples=n_row*n_col,
                                          device=torch.device("cuda:0"))
        # Clip to [-0.5, 0.5] and shift to [0, 1] before plotting.
        generated_image = generated_image.clamp(-.5, +.5) + .5
        generated_image = generated_image.view(n_row, n_col, 3, 64, 64)
        generated_image = generated_image.detach().cpu().numpy()

        # squeeze=False keeps `ax` two-dimensional even for a 1xN or Nx1 grid,
        # so the ax[i, j] indexing below works for any grid size.
        fig, ax = plt.subplots(n_row, n_col, squeeze=False,
                               subplot_kw=dict(aspect=1),
                               gridspec_kw=dict(hspace=0, wspace=0))
        fig.subplots_adjust(wspace=0.02, hspace=0.02)

        for i in range(n_row):
            for j in range(n_col):
                ax[i, j].axis("off")
                # channels-first (3, 64, 64) -> channels-last for imshow
                ax[i, j].imshow(generated_image[i, j].transpose(1, 2, 0))
        fig.suptitle(f"temperature {T}")

        # A zero-padded frame index gives stable, readable file names.
        frame_path = os.path.join(temp_directory, f"image_{frame_idx:04d}.png")
        fig.savefig(frame_path)
        # Close this specific figure so figures do not pile up across frames.
        plt.close(fig)
        frame_paths.append(frame_path)

    # generate a gif out of samples (must happen before the temporary
    # directory is removed on leaving the outer `with`)
    with imageio.get_writer('./temperatures.gif', mode='I') as writer:
        for frame_path in frame_paths:
            writer.append_data(imageio.imread(frame_path))



![SegmentLocal](temperatures.gif "segment")