In [None]:
import os
import json
from tqdm import tqdm
from PIL import Image
from instructg2i import image_grid, InstructG2IGuidePipeline, get_neighbor_transforms
import numpy as np

In [2]:
# Torch device string; handed to the pipeline's from_pretrained call.
device = "cuda:0"

In [None]:
# Read the test-split metadata (JSON Lines: one JSON object per line).
# metadata.jsonl lives in the same folder as the images, so both paths are
# derived from a single root and cannot drift apart (relative to the
# notebook's working directory).
img_dir = 'data/Art/test'

meta = []
with open(os.path.join(img_dir, 'metadata.jsonl'), encoding='utf-8') as f:
    # Read everything up front so tqdm knows the total line count.
    lines = f.readlines()
for line in tqdm(lines):
    line = line.strip()
    if line:  # tolerate blank lines, e.g. an empty last line
        meta.append(json.loads(line))

### Generation

In [4]:
# Checkpoint identifier passed to from_pretrained (presumably a Hugging Face
# Hub repo id; a local directory path would presumably also work — confirm).
model_dir = 'PeterJinGo/VirtualArtist'

In [None]:
resolution = 256   # size argument later given to get_neighbor_transforms
neighbor_num = 5   # neighbor slots per example; fewer real neighbors are padded and masked out

# Build the graph-conditioned generation pipeline on the chosen device.
pipeline = InstructG2IGuidePipeline.from_pretrained(
    model_dir,
    neighbor_num,
    device=device,
)

In [None]:
# Qualitative sweep on a single test example: generate one image for every
# (text guidance, graph guidance) pair and show them as a grid alongside the
# ground truth. Grid rows vary the text guidance scale; columns vary the
# graph guidance scale.
# `rand_meta` is a fixed, hand-picked record (not sampled, despite its name)
# with the same keys used for metadata entries: center, text, neighbors.
rand_meta = {'center': 'Huis In Boomrijke Omgeving##TwEBS4xd_HohOw.jpg', 'text': 'Huis In Boomrijke Omgeving', 'neighbors': ['Landschap Met Bomen Aan Het Water##MAE0t4GUCt8feA.jpg', 'Bosweg Met Twee Huizen##ogG_bhcOYCmVbQ.jpg']}
print(rand_meta)

show_neighbors = False      # set True to also display the conditioning neighbors
num_inference_steps = 100   # denoising steps per generated image


def load_rgb(file_name):
    """Load ``file_name`` from ``img_dir`` as an RGB image and close the file.

    ``Image.open`` is lazy and keeps the file handle open; ``convert`` forces
    the pixel data to load into a new image, so the handle can be released
    by the ``with`` block.
    """
    with Image.open(os.path.join(img_dir, file_name)) as im:
        return im.convert("RGB")


# graph2image pipeline inputs
neighbor_transforms = get_neighbor_transforms(resolution)

# Load the center image once: it is the ground truth, and its size is used
# for the blank padding neighbors below.
gt_img = load_rgb(rand_meta['center'])

# Fill exactly `neighbor_num` slots (the count the pipeline was built with).
# Real neighbors get mask 1. Missing slots get a transformed all-black image
# of the center image's size and mask 0.
neighbor_files = rand_meta['neighbors'][:neighbor_num]
neighbor_image = [neighbor_transforms(load_rgb(n_file)) for n_file in neighbor_files]
n_pad = neighbor_num - len(neighbor_image)
neighbor_mask = [1] * len(neighbor_image) + [0] * n_pad
if n_pad > 0:
    # Image.new's fill color defaults to 0, i.e. black.
    blank_neighbor = neighbor_transforms(Image.new("RGB", gt_img.size))
    neighbor_image += [blank_neighbor] * n_pad

guidance_scales = [0, 2, 10]        # text guidance  -> grid rows
graph_guidance_scales = [0, 2, 10]  # graph guidance -> grid columns
# NOTE(review): no seed is fixed, so outputs differ between runs; pass a
# seeded generator if the pipeline accepts one.

gen_imgs = []
for guidance_scale in guidance_scales:                    # one grid row
    for graph_guidance_scale in graph_guidance_scales:    # one grid column
        gen_img = pipeline(
            prompt=rand_meta['text'],
            neighbor_image=neighbor_image,
            neighbor_mask=neighbor_mask,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            graph_guidance_scale=graph_guidance_scale,
        ).images[0]
        gen_imgs.append(gen_img)

# gen_imgs is row-major (text scale outer, graph scale inner); presumably
# image_grid(imgs, rows, cols) fills in the same order — TODO confirm.
res_grid = image_grid(gen_imgs, len(guidance_scales), len(graph_guidance_scales))

print('********   Ground Truth   ********')
gt_img.show()
print('********   Generated results (row: text_guidance_scales; col: graph_guidance_scales)   ********')
res_grid.show()

# Show the same neighbors the model was conditioned on (capped at neighbor_num).
if show_neighbors and neighbor_files:
    neighbor_grid = image_grid([load_rgb(n_file) for n_file in neighbor_files], 1, len(neighbor_files))
    print('********   Neighbors   ********')
    neighbor_grid.show()