In [None]:
import os
import pickle
import random
import re
from collections import defaultdict

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from instructg2i import image_grid, InstructG2IMultiGuidePipeline, get_neighbor_transforms

In [2]:
# Seed every RNG this notebook may touch so that Restart & Run All reproduces the
# same dataset shuffle, neighbor sample and (presumably) diffusion noise.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# NOTE(review): assumes the generation pipeline draws its noise from torch's global
# RNG (no explicit generator is passed to it) — confirm in instructg2i.
# torch.manual_seed seeds the CPU and all CUDA devices.
torch.manual_seed(SEED)

### Read the Artist Database

In [3]:
# Runtime configuration: the device handed to the pipeline loader and the
# location of the artist dataset on disk.
device = 'cuda:0'                       # passed as device= when the pipeline is loaded
csv_path = 'data/Art/label_list.csv'    # '|'-separated label list, one painting per line
img_dir = 'data/Art/train'              # directory the painting files are read from

In [4]:
AUTHOR = 0
PAINTING_NAME = 1
PAINTING_PATH = -1

In [5]:
def path_to_name(path):
    """Return the last '/'-separated component of `path` (its file name).

    Args:
        path: a '/'-separated path string.

    Returns:
        The text after the final '/', or False when there is none (e.g. an empty
        string or a path ending in '/') or when `path` is not a str.
    """
    try:
        match = re.search(r"[^/]+$", path)
    except TypeError:  # non-str input such as None or bytes
        return False
    # Only genuinely bad input maps to False; KeyboardInterrupt etc. now propagate.
    return match.group() if match else False

In [6]:
# Collection of image file names treated as valid; label-list rows whose file
# name is not in it are skipped. Presumably a set/list of names — confirm.
# Security: pickle.load can execute arbitrary code, so only load this file
# if it comes from a trusted source.
with open('data/Art/valid_img.pkl', 'rb') as pkl_file:
    valid_img = pickle.load(pkl_file)

In [None]:
# Build the artist -> painting index from the '|'-separated label list.
#   author2img: author name -> set of painting titles by that author
#   pics:       painting title -> {'title', 'author', 'file_name'}
author2img = defaultdict(set)
pics = {}
seen_titles = set()

# Explicit UTF-8 so artist names with accents decode the same on every platform
# (open() otherwise uses the locale's encoding). Assumes the file is UTF-8 — confirm.
with open(csv_path, encoding='utf-8') as f:
    readin = f.readlines()

# Shuffled with the notebook's seeded RNG: among rows sharing a title, which one
# is kept is random but reproducible.
random.shuffle(readin)
for line in tqdm(readin):
    fields = line.strip().split('|')
    # Rows without at least an author and a title column cannot be indexed.
    if len(fields) <= PAINTING_NAME:
        continue

    author = fields[AUTHOR]
    file_name = path_to_name(fields[PAINTING_PATH])
    if file_name not in valid_img or author == '' or author == 'NoAuthor':
        continue

    # Titles may carry a '##'-separated suffix; only the part before the first
    # '##' is used as the painting key.
    title = fields[PAINTING_NAME].split('##')[0]

    # First occurrence of a title wins. NOTE(review): titles are global keys, so
    # identically titled works by different authors collapse to one entry — confirm intended.
    if title not in seen_titles:
        seen_titles.add(title)
        author2img[author].add(title)
        pics[title] = {'title': title, 'author': author, 'file_name': file_name}

print(f'Number of pictures: {len(pics)}')
print(f'Number of artists {len(author2img)}')

## Generation

In [8]:
# Checkpoint identifier handed to InstructG2IMultiGuidePipeline.from_pretrained
# (presumably a Hugging Face Hub repo id — confirm).
model_dir = 'PeterJinGo/VirtualArtist'

In [None]:
# Settings the pipeline is built with; the generation cell reuses both values.
resolution = 256   # passed to get_neighbor_transforms() for the guide images
neighbor_num = 5   # neighbor image slots per artist guide (guides are padded to this)
pipeline = InstructG2IMultiGuidePipeline.from_pretrained(
    model_dir,
    neighbor_num,
    device=device,
)

In [None]:
## Use this cell to check whether an artist you are interested in exists in the database.
## Every database name that starts with `search_name` is printed; use the exact printed
## name in `author_names` in the cell below.
search_name = 'Pablo Picasso'  # Vincent van Gogh, Pablo Picasso

for author_name in author2img:
    # Case-sensitive prefix match against the names in the index.
    if author_name.startswith(search_name):
        print(author_name)

In [11]:
## You can change the text prompt and the artists below (use the exact names printed by the cell above).

text_prompt = 'a man playing piano'  # a man playing soccer, a man playing piano
author_names = ['Pablo Picasso', 'Gustave Courbet']  # Vincent van Gogh, Pablo Picasso, Caravaggio, Gustave Courbet, Salvador Dali, Max Beckmann
# neighbor_num is deliberately NOT redefined here: it was passed to the pipeline when
# it was loaded, and the generation cell pads every artist's guide to that length.
# To change it, edit it where the pipeline is created and reload the pipeline.

# One graph-guidance scale per artist guide; every (scale_a, scale_b) pair is rendered.
# Presumably scale_as applies to author_names[0] and scale_bs to author_names[1] — confirm.
scale_as = [0, 7.5, 15]
scale_bs = [0, 7.5, 15]

In [None]:
# multi image generation
# Sample up to `neighbor_num` paintings per selected artist, turn them into the
# padded image/mask guides the pipeline expects, then render one image for every
# (scale_a, scale_b) pair of graph-guidance scales.


def _sample_neighbors(author_name):
    """Return up to `neighbor_num` randomly chosen painting titles by `author_name`."""
    # .get() instead of [] so a mistyped name does not insert an empty set into the defaultdict.
    titles = author2img.get(author_name)
    assert titles, f'No paintings found for artist {author_name!r}; use the search cell to find the exact name.'
    # Sort first: iteration order of a set of str depends on PYTHONHASHSEED, so shuffling
    # the raw set would pick different neighbors after every kernel restart despite the seed.
    candidates = sorted(titles)
    random.shuffle(candidates)
    return candidates[:neighbor_num]


def _load_rgb(file_name):
    """Open `file_name` inside `img_dir` and return an RGB copy, closing the file handle."""
    with Image.open(os.path.join(img_dir, file_name)) as img:
        return img.convert('RGB')


def _build_guide(neighbors):
    """Return (transformed images, mask) for one artist, each of length `neighbor_num`.

    Real neighbors get mask 1; empty slots get a black image the size of the first
    neighbor (run through the same transforms) and mask 0.
    """
    raw_images = [_load_rgb(pics[title]['file_name']) for title in neighbors]
    images = [neighbor_transforms(img) for img in raw_images]
    n_pad = neighbor_num - len(images)
    if n_pad > 0:
        # Image.new defaults to color 0, i.e. an all-black RGB image.
        blank = neighbor_transforms(Image.new('RGB', raw_images[0].size))
        images += [blank] * n_pad
    mask = [1] * len(raw_images) + [0] * max(n_pad, 0)
    return images, mask


all_neighbors = [_sample_neighbors(author_name) for author_name in author_names]
assert len(all_neighbors) == len(author_names)

# graph2image pipeline
neighbor_transforms = get_neighbor_transforms(resolution)

neighbor_images = []
neighbor_masks = []
for neighbors in all_neighbors:
    neighbor_image, neighbor_mask = _build_guide(neighbors)
    neighbor_images.append(neighbor_image)
    neighbor_masks.append(neighbor_mask)

# Outer loop over scale_as, inner over scale_bs, so images are produced row-major:
# presumably row i / column j of image_grid(images, rows, cols) is scale_as[i] / scale_bs[j].
image_gens = []
for scale_a in scale_as:
    for scale_b in scale_bs:
        graph_guidance_scales = [scale_a, scale_b]
        image_gen = pipeline(
            prompt=text_prompt,
            neighbor_images=neighbor_images,
            neighbor_masks=neighbor_masks,
            graph_guidance_scales=graph_guidance_scales,
            num_inference_steps=100,
        ).images[0]
        image_gens.append(image_gen)

print(f'Generated Image {author_names}, {scale_as} & {scale_bs}, {text_prompt}')
res_grid = image_grid(image_gens, len(scale_as), len(scale_bs))
res_grid.show()

# Set to True to also display the sampled neighbor paintings for each artist.
show_neighbors = False
if show_neighbors:
    for author_name, neighbors in zip(author_names, all_neighbors):
        # One column per sampled neighbor: an artist may have fewer than neighbor_num paintings.
        neighbor_grid = image_grid([_load_rgb(pics[title]['file_name']) for title in neighbors], 1, len(neighbors))
        print(f'Neighbors {author_name}:')
        neighbor_grid.show()