In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')
import numpy as np
import matplotlib.pyplot as plt
# Device the pipeline is moved to with `.to(device)` (CUDA device index 0).
device = 'cuda:0'
from imagenet_dict import imagenet_dict
from muse.pipeline_muse_toast import PipelineMuse

In [2]:
def display_intermediate_img_and_mask(i_batch, timestep, images, intermediate_images, mask_index, res=256):
    seq_len = 16 if res == 256 else 32
    vq_len = 256 if res == 256 else 1024
    inter_arr = [np.array(intermediate_images[i][i_batch]) for i in range(timestep)]
    mask_arr = [np.array(mask_index[i][i_batch].reshape(-1, seq_len, seq_len).squeeze().detach().cpu()) for i in range(timestep)]

    fig, axes = plt.subplots(timestep+1, 3, figsize=(10, (timestep+2)*2))

    mask_arr_np = np.concatenate([np.zeros((1,seq_len,seq_len)), 1-np.array(mask_arr)])
    for i in range(timestep):
        previous = mask_arr_np[i]
        now = mask_arr_np[i+1]
        approved = previous * now
        created = now - approved
        denied = previous - approved
        image_rgb = np.zeros((seq_len, seq_len, 3))
        image_rgb[created == 1] = [0, 1, 1]  # Sky Blue for created
        image_rgb[approved == 1] = [0, 1, 0]  # Green for approved
        image_rgb[denied == 1] = [1, 0, 0]  # Red for denied
        
        axes[i][0].imshow(inter_arr[i], interpolation='nearest', aspect='equal')
        axes[i][1].matshow(1-mask_arr[i], cmap='gray')
        axes[i][2].imshow(image_rgb, interpolation='nearest', aspect='equal')
        
        axes[i][0].axis("off")
        axes[i][1].axis("off")
        axes[i][2].axis("off")
        
        # Insert the count of approved pixels between the subplots
        unmasked = now.sum() / vq_len
        prev_unmasked = previous.sum() / vq_len
        approved = approved.sum() / vq_len
        denied = denied.sum() / vq_len
        created = created.sum() / vq_len
        txt = f'prev_unmasked: {prev_unmasked:.2f} || unmasked: {unmasked:.2f} || approved: {approved:.2f} || denied: {denied:.2f} || created: {created:.2f}'
        
        axes[i][1].text(0.5, 1.05, txt, 
                        horizontalalignment='center', 
                        verticalalignment='center', 
                        transform=axes[i][1].transAxes)

    axes[timestep][0].imshow(images[i_batch], interpolation='nearest', aspect='equal')
    axes[timestep][0].axis("off")
    axes[timestep][1].axis("off")
    axes[timestep][2].axis("off")
        
    fig.tight_layout() 
    fig.show()

In [None]:
# Checkpoint + tokenizer per output resolution. Set RES to 512 to load the 512px model
# instead (keep the `res=` argument of the plotting call in sync).
RES = 256
CHECKPOINTS = {
    256: dict(transformer_path="../results_corr/ft_256_toast_cls_b256_corr/checkpoint-50000/ema_model",
              vae_path="../scripts/tokenizer_imagenet256_torch/"),
    512: dict(transformer_path="../results_corr/ft_512_toast_cls_b256_corr/checkpoint-50000/ema_model",
              vae_path="../scripts/tokenizer_imagenet512_torch/"),
}
pipe = PipelineMuse.from_pretrained(**CHECKPOINTS[RES], is_class_conditioned=True, use_toast=True).to(device)
pipe.transformer.eval()
pipe.vae.eval()
print("Loaded model")

In [None]:
# ImageNet class id to generate (e.g. 105: koala, 454: bookstore).
class_ids = 105

timesteps = 18
num_images = 4  # single source of truth for how many images are generated and displayed

images, intermediate_images, intermediate, mask_index = pipe(
    class_ids=class_ids,
    num_images_per_prompt=num_images,
    timesteps=timesteps,
    temperature=10,
    sampling_type='self_guidance',  # 'maskgit' or 'self_guidance'
    return_intermediate=True,
    guidance_scale=1.0,
)
print(imagenet_dict[class_ids])
# Display the first num_images results; slicing keeps this in sync with the request above.
for image in images[:num_images]:
    display(image)

In [None]:
# Plot the step-by-step trajectory of the first sample in the generated batch.
i_batch = 0
display_intermediate_img_and_mask(
    i_batch=i_batch,
    timestep=timesteps,
    images=images,
    intermediate_images=intermediate_images,
    mask_index=mask_index,
    res=256,
)