# Explainability-aided text-guided image editing

An adaptation of the StyleCLIP (https://github.com/orpatashnik/StyleCLIP) optimization notebook that uses an explainability-guided loss.

## Setup

In [None]:
#@title Setup
# Colab setup: clone the repository, install the Python dependencies,
# authenticate to Google Drive and download the pretrained weights.

# Clone the project and make it the working directory, presumably so the later
# `external.*` imports and relative weight paths resolve against the repo root.
# NOTE(review): the clone runs in the current directory, so re-running this
# cell after the chdir clones a second copy nested inside the first one.
!git clone https://github.com/apple/ml-no-token-left-behind.git
import os
os.chdir(f'ml-no-token-left-behind')

!pip install ftfy regex tqdm captum torchvision opencv-python matplotlib

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Download StyleGAN's weights and the facial-recognition network's weights.
# Each file is saved in the current directory under its Drive title; presumably
# these are the "stylegan2-ffhq-config-f.pt" / "model_ir_se50.pth" files that the
# "Additional Arguments" cell refers to — verify the titles if loading fails.
ids = ['1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT', '1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL']
for file_id in ids:
  downloaded = drive.CreateFile({'id':file_id})
  # Fetch the full metadata so the Drive 'title' can be used as the filename.
  downloaded.FetchMetadata(fetch_all=True)
  downloaded.GetContentFile(downloaded.metadata['title'])

In [None]:
#@title Explainability utils
import torch
import external.TransformerMMExplainability
import external.TransformerMMExplainability.CLIP.clip as clip
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from captum.attr import visualization
from torchvision import transforms


# create heatmap from mask on image
def show_cam_on_image(img, mask):
    """Blend a JET-colormapped relevance mask with an image.

    Args:
        img: image array, expected to hold values in [0, 1] with the same
            height/width as ``mask`` and 3 channels (assumed — confirm).
        mask: relevance map, expected in [0, 1]; it is scaled to 0..255 and
            colormapped with OpenCV (whose output channel order is BGR).

    Returns:
        float32 array: colormap + image, rescaled so its maximum is 1.
    """
    colored = cv2.applyColorMap(np.uint8(mask * 255), cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    return overlay / np.max(overlay)

def get_desired_tokens_from_words(desired_words, text):
    """Turn per-word weights into a per-token weight vector for ``text``.

    Args:
        desired_words: ``None`` for uniform weights, or a string of
            whitespace-separated numbers, one per whitespace-separated word of
            ``text`` (e.g. ``"0 0 0 1 1"`` for ``"A person with purple hair"``).
        text: the prompt; tokenized with the module-level CLIP ``_tokenizer``.

    Returns:
        1-D float tensor with one entry per token of ``_tokenizer.encode(text)``.
        With ``desired_words=None`` every entry is ``1 / num_tokens``. Otherwise
        each word's weight is copied onto all of that word's tokens and the
        vector is divided by its maximum — or, when all entries are equal, by
        its length.

    Raises:
        ValueError: if the number of weights differs from the number of words,
            or the individually-encoded words produce more tokens than the
            encoding of the whole prompt.
    """
    num_text_tokens = len(_tokenizer.encode(text))
    if desired_words is None:
        return torch.full((num_text_tokens,), 1 / num_text_tokens)

    # str.split() without an argument ignores repeated/leading/trailing
    # whitespace; split(" ") would produce empty "words" that shift every
    # following weight onto the wrong word (or fail to parse as a float).
    target_words = text.split()
    word_weights = np.array(desired_words.split()).astype(float)
    if len(word_weights) != len(target_words):
        raise ValueError(
            f"desired_words has {len(word_weights)} weight(s) but the text has "
            f"{len(target_words)} word(s): {desired_words!r} vs {text!r}")

    desired_tokens = torch.zeros(num_text_tokens)
    token_id = 0
    for word, weight in zip(target_words, word_weights):
        # A word may encode to several tokens; all of them get the word's weight.
        word_len = len(_tokenizer.encode(word))
        if token_id + word_len > num_text_tokens:
            raise ValueError(
                f"per-word tokenization of {text!r} does not line up with the "
                f"tokenization of the whole text")
        desired_tokens[token_id:token_id + word_len] = float(weight)
        token_id += word_len

    if desired_tokens.min() != desired_tokens.max():
        desired_tokens /= desired_tokens.max()
    else:
        desired_tokens = desired_tokens / len(desired_tokens)
    return desired_tokens

def _attention_relevance(target, attn_blocks, device):
    """Accumulate gradient-weighted attention relevance over ``attn_blocks``.

    Starting from the identity, each block contributes ``R <- R + cam @ R``,
    where ``cam`` is the block's attention map multiplied by the gradient of
    ``target`` w.r.t. that map, clamped at 0 and averaged over every leading
    (presumably batch x head) slice.

    Args:
        target: scalar tensor to differentiate (the selected logit).
        attn_blocks: residual attention blocks whose ``attn.attn_output_weights``
            was filled in by the forward pass that produced ``target``
            (assumed to be recorded by the patched CLIP in this repo).
        device: device on which to allocate the relevance matrix.

    Returns:
        Square (num_tokens, num_tokens) relevance matrix.
    """
    first_weights = attn_blocks[0].attn.attn_output_weights
    num_tokens = first_weights.shape[-1]
    relevance = torch.eye(num_tokens, num_tokens, dtype=first_weights.dtype).to(device)
    for blk in attn_blocks:
        weights = blk.attn.attn_output_weights
        side = weights.shape[-1]
        # retain_graph=True: the same graph is differentiated once per block
        # (for both the image and the text encoder), so it must not be freed.
        grad = torch.autograd.grad(target, [weights], retain_graph=True)[0].detach()
        cam = weights.detach().reshape(-1, side, side) * grad.reshape(-1, side, side)
        cam = cam.clamp(min=0).mean(dim=0)
        relevance = relevance + torch.matmul(cam, relevance)
    return relevance


def interpret(image, text, model, device, desired_words, index=None, softmax_temp=10.):
    """Compute CLIP image/text relevance for one prompt and plot the image heatmap.

    Args:
        image: image batch fed to ``model`` (NCHW assumed; ``image[0]`` is the
            one visualised).
        text: prompt string.
        model: CLIP model whose attention modules expose
            ``attn_output_weights`` after a forward pass.
        device: device for the tokenized prompt.
        desired_words: ``None`` or a per-word weight string, converted with
            ``get_desired_tokens_from_words``.
        index: which logit to explain; defaults to the argmax of the
            image->text logits.
        softmax_temp: unused; kept for API compatibility.

    Returns:
        ``(text_relevance, logits_per_image, target_word_expl_score)``, all
        detached. ``text_relevance`` has shape (1, num_prompt_tokens) and is
        scaled so its maximum is 1. ``target_word_expl_score`` is
        ``text_relevance`` times the per-token desired weights.

    Side effects:
        Draws the image heatmap with ``plt.imshow`` and zeroes ``model`` grads.
    """
    model.zero_grad()
    desired_words = get_desired_tokens_from_words(desired_words, text)

    text = clip.tokenize([text]).to(device)
    # Presumably the end-of-text token (the largest id in CLIP's vocabulary), so
    # R_text[CLS_idx, 1:CLS_idx] below covers the prompt tokens between
    # start-of-text and end-of-text — TODO confirm against the tokenizer.
    CLS_idx = text.argmax(dim=-1)

    with torch.enable_grad():
        # Fresh leaf copy that requires grad, so a graph through the attention
        # maps is recorded even when the caller passes a detached tensor.
        image = image.detach().clone().requires_grad_()
        logits_per_image, _ = model(image, text)
        if index is None:
            index = np.argmax(logits_per_image.cpu().data.numpy(), axis=-1)
        one_hot = np.zeros((1, logits_per_image.size()[-1]), dtype=np.float32)
        one_hot[0, index] = 1
        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
        # Scalar target: the selected logit.
        one_hot = torch.sum(one_hot.to(logits_per_image.device) * logits_per_image)
        attn_device = logits_per_image.device

        image_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
        R = _attention_relevance(one_hot, image_blocks, attn_device)
        # Drop the first token's self-relevance, then keep its relevance to the
        # patch tokens.
        R[0, 0] = 0
        image_relevance = R[0, 1:]

        text_blocks = list(dict(model.transformer.resblocks.named_children()).values())
        R_text = _attention_relevance(one_hot, text_blocks, attn_device)
        text_relevance = R_text[CLS_idx, 1:CLS_idx]
        text_relevance = text_relevance / text_relevance.sum()
        text_relevance = text_relevance / text_relevance.max()
        target_word_expl_score = text_relevance * desired_words.to(attn_device)

    # Patch tokens form a square grid (7x7 for ViT-B/32 at 224px); derive it
    # from the token count and upsample to the input's own resolution so the
    # heatmap matches `image` in show_cam_on_image.
    num_patches = image_relevance.numel()
    grid = int(round(num_patches ** 0.5))
    if grid * grid != num_patches:
        raise ValueError(f"expected a square patch grid, got {num_patches} patch tokens")
    img_h, img_w = image.shape[-2:]
    image_relevance = image_relevance.reshape(1, 1, grid, grid)
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=(img_h, img_w), mode='bilinear')
    # No .cuda() here: the map must also work when everything runs on the CPU.
    image_relevance = image_relevance.reshape(img_h, img_w).detach().cpu().numpy()
    image_relevance = image_relevance / image_relevance.sum()
    image_relevance = image_relevance / image_relevance.max()
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    # NOTE(review): the colormap from show_cam_on_image is BGR while `image` is
    # presumably RGB; this channel swap is kept unchanged — confirm the colours.
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)

    plt.imshow(vis)
    model.zero_grad()
    return text_relevance.detach(), logits_per_image.detach(), target_word_expl_score.detach()

class color:
    """ANSI escape sequences for coloured and styled console text.

    Prefix text with one of the attributes and append ``color.END`` to reset,
    e.g. ``print(color.BOLD + 'done' + color.END)``.
    """

    # Foreground colours.
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Resets every colour/attribute above.
    END = '\033[0m'

# Run CLIP on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# CLIP ViT-B/32 from the explainability fork: interpret() reads each attention
# module's `attn_output_weights`, which this fork presumably records during the
# forward pass. jit=False presumably selects the eager (non-TorchScript) model.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Module-level tokenizer; get_desired_tokens_from_words and show_heatmap_on_text
# look `_tokenizer` up at call time, so it only has to exist before they run.
from external.TransformerMMExplainability.CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()

def show_heatmap_on_text(text, text_encoding, R_text):
    """Render per-token relevance scores over the decoded prompt tokens.

    NOTE: a second ``show_heatmap_on_text`` is defined later in this cell and
    replaces this one once the cell has run.

    Args:
        text: prompt string; its tokens label the scores.
        text_encoding: not used by this variant.
        R_text: per-token relevance scores; flattened before display.

    Returns:
        The flattened scores.
    """
    scores = R_text.flatten()
    decoded_tokens = [_tokenizer.decode([token]) for token in _tokenizer.encode(text)]
    record = visualization.VisualizationDataRecord(
        scores, 0, 0, 0, 0, 0, decoded_tokens, 1)
    visualization.visualize_text([record])
    return scores

def visualize_explainability(image, text, desired_words, model, width, device=None):
    """Show CLIP explainability heatmaps (image and text) for an image/prompt pair.

    The image is brought to CLIP's input size by upsampling x7 and then
    average-pooling with kernel ``width // 32``; for ``width=1024`` this gives
    1024 * 7 / 32 = 224.

    Args:
        image: image batch tensor (NCHW with spatial size ``width`` assumed).
        text: prompt string.
        desired_words: ``None`` or a per-word weight string (see interpret).
        model: CLIP model used by ``interpret``.
        width: spatial size of ``image``.
        device: device to run CLIP on. Defaults to the device of ``model``'s
            parameters, since the image, the tokenized prompt and the model
            weights must all live on the same device.
    """
    if device is None:
        device = next(model.parameters()).device

    clip_input = torch.nn.AvgPool2d(kernel_size=width // 32)(
        torch.nn.Upsample(scale_factor=7)(image.to(device)))

    R_text, _, _ = interpret(image=clip_input,
                             text=text,
                             model=model,
                             desired_words=desired_words,
                             device=device,
                             index=0)

    show_heatmap_on_text(text, clip.tokenize([text]), R_text)
    plt.show()



def show_heatmap_on_text(text, text_encoding, R_text, softmax_temp=2.):
    """Render normalised per-token text relevance over the decoded prompt tokens.

    ``R_text`` may be either:
      * the full square relevance matrix (side == ``text_encoding.shape[-1]``),
        from which row ``argmax(text_encoding)`` is sliced between index 1 and
        that position; or
      * per-token scores that were already extracted, e.g. the first value
        returned by ``interpret`` (shape (1, num_prompt_tokens)). Slicing these
        again would index a row that does not exist.
    The selected scores are normalised to sum to 1 before display.

    Args:
        text: prompt string; its tokens label the scores.
        text_encoding: tokenized prompt, e.g. ``clip.tokenize([text])``.
        R_text: full relevance matrix or per-token scores (see above).
        softmax_temp: unused; kept for API compatibility.
    """
    CLS_idx = text_encoding.argmax(dim=-1)
    context_len = text_encoding.shape[-1]
    is_full_matrix = R_text.dim() == 2 and R_text.shape[0] == R_text.shape[1] == context_len
    if is_full_matrix:
        R_text = R_text[CLS_idx, 1:CLS_idx]
    text_scores = (R_text / R_text.sum()).flatten()
    text_tokens = _tokenizer.encode(text)
    text_tokens_decoded = [_tokenizer.decode([a]) for a in text_tokens]
    vis_data_records = [visualization.VisualizationDataRecord(text_scores, 0, 0, 0, 0, 0, text_tokens_decoded, 1)]
    visualization.visualize_text(vis_data_records)

## Run configuration
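Enter the manipulation text in `description`. `desired_words` holds one space-separated weight per word of `description`: use 1 for words that are part of the semantic change in the image and 0 for the others (e.g. `0 0 0 1 1` for "A person with purple hair").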

In [None]:
# Text prompt describing the desired edit (forwarded as args["description"]).
description = 'A person with purple hair' #@param {type:"string"}

# One whitespace-separated weight per word of `description`; words weighted 1
# are the ones the edit should change (see get_desired_tokens_from_words).
desired_words = "0 0 0 1 1"#@param {type:"string"}

# Number of optimization steps (forwarded as args["step"]).
optimization_steps = 200 #@param {type:"number"}

# Weight of the explainability-guided loss (args["expl_lambda"]); presumably
# consumed by the "EGCLIPLoss" loss inside run_optimization, not shown here.
expl_lambda = 1.0 #@param {type:"number"}

# Weight of the L2 term (args["l2_lambda"]); presumably an L2 penalty on the
# latent's distance from its starting point — defined in run_optimization.
l2_lambda = 0.008 #@param {type:"number"}

# Weight of the identity loss (args["id_lambda"]); presumably computed with the
# face-recognition weights given in args["ir_se50_weights"].
id_lambda = 0.005 #@param {type:"number"}

# Forwarded as args["work_in_stylespace"] (presumably optimizes StyleGAN's
# StyleSpace instead of the default latent — confirm in run_optimization).
stylespace = False #@param {type:"boolean"}


In [None]:
# When True, torch.manual_seed(seed) is called right before the optimization runs.
use_seed = True #@param {type:"boolean"}

# Seed passed to torch.manual_seed when use_seed is True.
seed =  0#@param {type: "number"}

## Run

In [None]:
#@title Additional Arguments
# Options for StyleCLIP's run_optimization.main (the run cell wraps them in an
# argparse.Namespace). Form values from the configuration cells are passed
# through; the remaining entries are fixed for this notebook.
args = dict(
    description=description,
    desired_words=desired_words,
    expl_lambda=expl_lambda,
    ckpt="stylegan2-ffhq-config-f.pt",
    stylegan_size=1024,
    lr_rampup=0.05,
    lr=0.1,
    step=optimization_steps,
    mode="edit",
    l2_lambda=l2_lambda,
    id_lambda=id_lambda,
    work_in_stylespace=stylespace,
    latent_path=None,
    truncation=0.7,
    save_intermediate_image_every=20,
    results_dir="results",
    ir_se50_weights="model_ir_se50.pth",
    loss_type="EGCLIPLoss",
)

In [None]:
# Seed torch's RNG before the optimization module is imported and run. Only
# torch is seeded here; numpy / `random` are left untouched.
if use_seed:
  import torch
  torch.manual_seed(seed)
from external.StyleCLIP.optimization.run_optimization import main
from argparse import Namespace
# main() receives the options as a Namespace; `result` is read by the
# "Visualize Result" cell (result[0]).
result = main(Namespace(**args))

In [None]:
#@title Visualize Result
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage

# NOTE(review): assumes result[0] is the edited image batch with values in
# [-1, 1] (hence the value range below); it comes from run_optimization.main.
images = result[0].detach().cpu()
grid_kwargs = dict(normalize=True, scale_each=True, padding=0)
try:
    # Current torchvision names the keyword `value_range`; `range` was removed.
    grid = make_grid(images, value_range=(-1, 1), **grid_kwargs)
except TypeError:
    # Older torchvision releases only accept the original `range` keyword.
    grid = make_grid(images, range=(-1, 1), **grid_kwargs)
result_image = ToPILImage()(grid)
# PIL's Image.size is (width, height), and resize() also takes (width, height).
width, height = result_image.size
# Last expression of the cell, so Colab displays the half-size image.
result_image.resize((width // 2, height // 2))