# Explainability-aided Image generation

Built upon FuseDream by Xingchao Liu, Chengyue Gong, Lemeng Wu, Shujian Zhang, Hao Su and Qiang Liu (https://github.com/gnobitab/FuseDream). 


## Setup

In [None]:
# Show the attached GPU(s); later cells run CLIP with device='cuda'.
!nvidia-smi

In [None]:
!git clone https://github.com/apple/ml-no-token-left-behind.git
import os
os.chdir(f'ml-no-token-left-behind')
!pip install ftfy regex tqdm numpy scipy h5py lpips==0.1.4 flair sacremoses
!pip install gdown captum
!gdown 'https://drive.google.com/uc?id=1YqbbmUijKI85WZjTdRD2mMp4CDaDUWgC'
!gdown 'https://drive.google.com/uc?id=1dr196QReWq0UWF7pQSZcbCiw0ksbexpk'
!mkdir external/FuseDream/BigGAN_utils/weights/
!cp biggan-256.pth external/FuseDream/BigGAN_utils/weights/
!cp biggan-512.pth external/FuseDream/BigGAN_utils/weights/

In [None]:
import sys
# Make the vendored projects importable. These paths are relative to the repo root,
# which the setup cell chdir'ed into, so they must be appended before the imports below.
sys.path.append('./external/FuseDream')
sys.path.append('./external/FuseDream/BigGAN_utils')
sys.path.append("./external/TransformerMMExplainability")
# CLIP as vendored in TransformerMMExplainability (used for relevance maps).
import external.TransformerMMExplainability.CLIP.clip as clip
import torch
from tqdm import tqdm
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import torchvision
from PIL import Image
import external.FuseDream.BigGAN_utils.utils as utils
import torch.nn.functional as F
from external.FuseDream.DiffAugment_pytorch import DiffAugment
import numpy as np
# Generator wrapper, BigGAN loader and CLIP-relevance helpers used by the cells below.
from external.FuseDream.fusedream_utils import FuseDreamBaseGenerator, get_G, save_image, interpret, show_heatmap_on_text

## Setting up parameters
1. SENTENCE: The query text for generating the image. Note: we find that putting a period '.' at the end of the sentence can boost the quality of the generated images, e.g., 'A photo of a blue dog.' generates better images than 'A photo of a blue dog'.
2. INIT_ITERS: Controls the number of images used for initialization (M in the paper, where M = INIT_ITERS*10). The default value of 1000 should work well.
3. OPT_ITERS: Controls the number of iterations for optimizing the latent variables. The default value of 1000 should work well.
4. NUM_BASIS: Controls the number of basis images used in optimization (k in the paper). Values of 5, 10, or 15 should work well.
5. MODEL: Either 'biggan-256' or 'biggan-512' (the only models currently supported).
6. SEED: Random seed. Choose an arbitrary integer you like.
7. LAMBDA_EXPL: The weight of the explainability-based loss (applied when generating the basis images).
8. NEGLECT_THRESHOLD: The relevance threshold below which a noun is considered neglected in the generated image.

In [None]:
#@title Parameters
# Colab form fields: each `#@param` annotation renders the assignment as a form
# input. See the parameter notes above for what each value controls.
SENTENCE = "A photo of a strawberry muffin" #@param {type:"string"}
INIT_ITERS =  1000#@param {type:"number"}
OPT_ITERS = 1000#@param {type:"number"}
NUM_BASIS = 10#@param {type:"number"}
MODEL = "biggan-512" #@param ["biggan-256","biggan-512"]
SEED = 0#@param {type:"number"}
LAMBDA_EXPL = 0.1#@param {type:"number"}
NEGLECT_THRESHOLD = 0.7#@param {type:"number"}

import sys
# Jupyter puts its own kernel arguments in sys.argv; clearing them keeps any
# argparse-based code called later from failing on unknown arguments.
sys.argv = [''] ### workaround to deal with the argparse in Jupyter

## Run

In [None]:
#@title Original FuseDream Generation
# Baseline run: plain FuseDream (explainability loss disabled via expl_lambda=0).
# Its output `img` is analysed by the next cell to detect neglected nouns.
from external.TransformerMMExplainability.CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from IPython import display

# CLIP BPE tokenizer; the neglect-detection cell uses it to map words to token positions.
_tokenizer = _Tokenizer()

utils.seed_rng(SEED)

sentence = SENTENCE

print('Generating:', sentence)
# MODEL choice -> BigGAN resolution passed to get_G.
model_resolutions = {"biggan-256": 256, "biggan-512": 512}
if MODEL not in model_resolutions:
    raise ValueError(f'Model not supported: {MODEL}')
G, config = get_G(model_resolutions[MODEL])
# The third argument is presumably the sampling batch size (the parameter notes
# state M = INIT_ITERS*10) -- TODO confirm in fusedream_utils.
generator = FuseDreamBaseGenerator(G, config, 10)
z_cllt, y_cllt = generator.generate_basis(sentence,
                                          init_iters=INIT_ITERS,
                                          num_basis=NUM_BASIS,
                                          expl_lambda=0)

# Basis latents / class vectors as numpy arrays (not used elsewhere in this notebook).
z_cllt_save = torch.cat(z_cllt).cpu().numpy()
y_cllt_save = torch.cat(y_cllt).cpu().numpy()
img, z, y = generator.optimize_clip_score(z_cllt,
                                          y_cllt,
                                          sentence,
                                          latent_noise=False,
                                          augment=True,
                                          opt_iters=OPT_ITERS,
                                          optimize_y=True,
                                          expl_lambda=0)
score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20)
print('AugCLIP score for original FuseDream result:', score)

print("resulting image")
grid_kwargs = dict(nrow=1, normalize=True, scale_each=True, padding=0)
try:
    # Current torchvision names the normalisation range `value_range` ...
    grid = torchvision.utils.make_grid(img.detach().cpu(), value_range=(-1, 1), **grid_kwargs)
except TypeError:
    # ... releases before 0.10 only accept the legacy `range` keyword.
    grid = torchvision.utils.make_grid(img.detach().cpu(), range=(-1, 1), **grid_kwargs)
display.display(torchvision.transforms.functional.to_pil_image(grid))

In [None]:
#@title Check if any object is neglected
# Measure how strongly each word of `sentence` is reflected in the baseline image
# `img` (CLIP text relevance from `interpret`) and flag nouns whose relevance is
# below NEGLECT_THRESHOLD. Produces `tag_lst` and `needs_our_method` for the next cell.

from flair.models import MultiTagger
from flair.data import Sentence
tagger = MultiTagger.load(['pos'])

# Mean/std are CLIP's image-preprocessing constants.
normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
# NOTE(review): `img` is displayed with range (-1, 1) elsewhere but is not mapped to
# [0, 1] before normalisation here; confirm this is what `interpret` expects
# (NEGLECT_THRESHOLD was presumably tuned against this exact input).
img_res = F.interpolate(img, size=224, mode='bilinear')
# Indexed below as text_relevance[0, token_position]; assumes position 0 is the first
# token of `sentence` -- TODO confirm against fusedream_utils.interpret.
text_relevance = interpret(normalize(img_res), sentence, model=generator.clip_model, device='cuda')

# CLIP BPE tokens of the sentence decoded back to text. A decoded token may carry a
# trailing space, which the matching below accounts for.
text_tokens_decoded = [_tokenizer.decode([tok]) for tok in _tokenizer.encode(sentence)]
num_sentence_tokens = len(text_tokens_decoded)

sentence_obj = Sentence(sentence)
tagger.predict(sentence_obj)

# Pass 1: merge flair tokens into words that line up with CLIP tokens, recording each
# word's POS tag and the CLIP token positions it spans.
tag_lst = []
token_id = 0
entire_word = ''
for label in sentence_obj.get_labels('pos'):
    print(label)

    entire_word = entire_word + label.data_point.text
    word_lower = entire_word.lower()
    # Past the last CLIP token there is nothing to extend into; '' closes the word
    # instead of raising IndexError.
    current_token = text_tokens_decoded[token_id] if token_id < num_sentence_tokens else ''

    # The accumulated text is only a proper prefix of the current CLIP token:
    # keep appending flair tokens until the whole word is covered.
    if current_token not in (word_lower, word_lower + ' ') and current_token.startswith(word_lower):
        continue

    num_of_tokens = len(_tokenizer.encode(entire_word))
    # Clamp so a tokenizer misalignment can never point past the sentence's tokens.
    tokens = list(range(token_id, min(token_id + num_of_tokens, num_sentence_tokens)))
    tag_lst.append({'word': entire_word, 'POS': label.value, 'tokens': tokens})
    token_id += num_of_tokens
    entire_word = ''

# Pass 2: a word's relevance is the max over its tokens (floored at 0). Nouns (tags
# starting with 'NN', e.g. NN/NNS/NNP) below NEGLECT_THRESHOLD are flagged.
needs_our_method = False
for entry in tag_lst:
    expl = 0
    for tok in entry['tokens']:
        if text_relevance[0, tok] > expl:
            expl = text_relevance[0, tok]
    entry['expl'] = expl

    entry['need_emphasize'] = False
    # A word mapped to no CLIP tokens cannot be emphasized, so it is never flagged.
    if entry['POS'].startswith('NN') and entry['tokens'] and expl < NEGLECT_THRESHOLD:
        entry['need_emphasize'] = True
        needs_our_method = True
    
        



In [None]:
#@title Explainability-aided generation
# If any noun was flagged as neglected, re-run FuseDream from the same seed with the
# explainability loss applied to the CLIP tokens of the flagged words.
if needs_our_method:
    from IPython import display

    # One list of CLIP token positions per neglected word.
    desired_tokens = [word['tokens'] for word in tag_lst if word['need_emphasize']]
    print(desired_tokens)

    utils.seed_rng(SEED)

    sentence = SENTENCE

    print('Generating:', sentence)
    # MODEL choice -> BigGAN resolution passed to get_G.
    model_resolutions = {"biggan-256": 256, "biggan-512": 512}
    if MODEL not in model_resolutions:
        raise ValueError(f'Model not supported: {MODEL}')
    G, config = get_G(model_resolutions[MODEL])
    generator = FuseDreamBaseGenerator(G, config, 10)
    z_cllt, y_cllt = generator.generate_basis(sentence,
                                              init_iters=INIT_ITERS,
                                              num_basis=NUM_BASIS,
                                              desired_tokens=desired_tokens,
                                              expl_lambda=LAMBDA_EXPL)

    # Basis latents / class vectors as numpy arrays (not used elsewhere in this notebook).
    z_cllt_save = torch.cat(z_cllt).cpu().numpy()
    y_cllt_save = torch.cat(y_cllt).cpu().numpy()
    # NOTE(review): desired_words is passed but expl_lambda=0, so the explainability
    # term appears to influence only basis generation above -- confirm this is intended.
    img, z, y = generator.optimize_clip_score(z_cllt,
                                              y_cllt,
                                              sentence,
                                              latent_noise=False,
                                              augment=True,
                                              opt_iters=OPT_ITERS,
                                              optimize_y=True,
                                              desired_words=desired_tokens,
                                              expl_lambda=0)
    score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20)
    print('AugCLIP score for explainability-aided FuseDream result:', score)

    print("resulting image")
    grid_kwargs = dict(nrow=1, normalize=True, scale_each=True, padding=0)
    try:
        # Current torchvision names the normalisation range `value_range` ...
        grid = torchvision.utils.make_grid(img.detach().cpu(), value_range=(-1, 1), **grid_kwargs)
    except TypeError:
        # ... releases before 0.10 only accept the legacy `range` keyword.
        grid = torchvision.utils.make_grid(img.detach().cpu(), range=(-1, 1), **grid_kwargs)
    display.display(torchvision.transforms.functional.to_pil_image(grid))
else:
    print("No object is neglected, no explainability-assistance is needed")