# Attention Map Editing

To mask or focus on specific elements of the attention map, edit either the `mask` (line 86) or the `focus` (line 87) variable in models/transformer.py. The options are: None, uniform, pc_to_pc, pc_to_pc_diag, img_to_pc, pc_to_img, img_to_img, cross_attention. In this notebook they can also be set in the pre-settings cell below, which passes them to the base model's config.

In [None]:
import os, sys
import seaborn as sns
import imageio
import importlib
from PIL import Image
import matplotlib.pyplot as plt
import math
import torch
import nopdb
from tqdm.auto import tqdm
import numpy as np

# Detect local paths: the notebook is expected to run from a directory one
# level below the repository root (as `path = '../src/...'` below also
# assumes), so the root is the parent of the current working directory.
# Using os.getcwd() avoids shelling out to `pwd`, and taking the parent
# directory no longer depends on the notebook folder's name length.
local_path = os.path.dirname(os.getcwd()) + '/'
sys.path.append(local_path + 'src/')

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.download import load_checkpoint
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util.plotting import plot_point_cloud

# Release unoccupied cached GPU memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

# IPython autoreload: reload imported modules (e.g. an edited point_e) before running each cell.
%load_ext autoreload
%autoreload 2

# Path to the examples directory, relative to the working directory; example
# images are read from it and figures are written under it.
path = '../src/point_e/examples'


#### Pre-settings:
Choose an option for either `mask` or `focus`.
|        |  |
|----------------|:----------------------------------------------------------------------------------------------------------------------:|
| `mask`:        |  Set the mask here, options: None, uniform, pc_to_pc, pc_to_pc_diag, img_to_pc, pc_to_img, img_to_img, cross_attention |
| `focus`:       | Set the focus here, options: None, uniform, pc_to_pc, pc_to_pc_diag, img_to_pc, pc_to_img, img_to_img, cross_attention |
| `base_name`:   | Set name of base model here, use base300M or base1B for better results                                                 |
| `img`:         | Load the image to condition on                                                                                         |
| `breakpoints`: | Base-model sampling steps at which to capture the attention map; -1 runs through all base-model steps                  |

In [None]:
# Attention edit for the base model; the string "None" presumably disables an option.
mask = "uniform"
focus = "None"

# Use base300M or base1B for better results.
base_name = 'base40M'

# Conditioning image.
img = Image.open(f'{path}/example_data/cube_stack.jpg')

# Base-stage sampling steps at which attention is captured; -1 never matches a
# step index, so it runs through every base-stage step.
breakpoints = list(range(0, 70, 10)) + [-1]

In [None]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('creating base model...')
# Build a fresh dict (rather than mutating the shared MODEL_CONFIGS entry) that
# carries the attention-edit settings into the base model's config.
base_config = {**MODEL_CONFIGS[base_name], 'mask': mask, 'focus': focus}
base_model = model_from_config(base_config, device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

print('creating upsample model...')
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

print('downloading base checkpoint...')
base_model.load_state_dict(load_checkpoint(base_name, device))

print('downloading upsampler checkpoint...')
upsampler_model.load_state_dict(load_checkpoint('upsample', device))

In [None]:
# Two-stage sampler. Per-stage settings line up with `models`: the base stage
# produces 1024 points and the upsampler stage 4096 - 1024; both stages use a
# guidance scale of 3.0 and carry RGB colour channels.
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    guidance_scale=[3.0] * 2,
    aux_channels=list('RGB'),
    num_points=[1024, 4096 - 1024],
)

In [None]:
def _split_attention(avg_attn, n_points):
    """Split a head-averaged attention map into its token-group quadrants.

    Args:
        avg_attn: Square ``(T, T)`` attention map (rows attend to columns).
            The first ``T - n_points`` tokens are taken to be the
            conditioning (image) tokens, the last ``n_points`` the
            point-cloud tokens.
        n_points: Number of point-cloud tokens at the end of the sequence.

    Returns:
        CPU tensors ``(pc_self_attn, img_self_attn, img_to_pc_cross_attn,
        pc_to_img_cross_attn, avg_attn)``.

    Raises:
        ValueError: if ``n_points`` exceeds the sequence length.
    """
    # Derive the number of conditioning tokens from the shape instead of
    # hard-coding it: 257 for base40M, but other base models presumably use a
    # different number of conditioning tokens.
    n_cond = avg_attn.shape[-1] - n_points
    if n_cond < 0:
        raise ValueError(
            f'attention map has {avg_attn.shape[-1]} tokens, fewer than {n_points} points'
        )
    quadrants = (
        avg_attn[n_cond:, n_cond:],  # point cloud -> point cloud (self-attention)
        avg_attn[:n_cond, :n_cond],  # image -> image (self-attention)
        avg_attn[:n_cond, n_cond:],  # image -> point cloud (cross-attention)
        avg_attn[n_cond:, :n_cond],  # point cloud -> image (cross-attention)
        avg_attn,                    # full map
    )
    return tuple(q.cpu() for q in quadrants)


def sample_from_model(breakpoint, base_points=1024):
    """Sample up to a base-model step and return the attention maps at that step.

    Runs the module-level ``sampler`` conditioned on ``img`` while ``nopdb``
    captures calls to the attention ``forward`` referenced via the base
    model's last residual block, and reads its ``weight`` local (batch
    entry 0).

    Args:
        breakpoint: Index of the base-stage output to stop at (0 = first).
            A value that is never reached, e.g. -1, runs every base-stage
            step. (Shadows the ``breakpoint`` builtin; name kept for callers.)
        base_points: Point count of base-stage outputs; the first output with
            a different count is treated as the start of the upsampler stage.
            Should match the first entry of the sampler's ``num_points``.

    Returns:
        ``(pc_self_attn, img_self_attn, img_to_pc_cross_attn,
        pc_to_img_cross_attn, avg_attn, samples)``: head-averaged attention
        maps on the CPU, plus the last base-stage sampler output.

    Raises:
        RuntimeError: if the sampler yields no base-stage output.
    """
    samples = None
    attention_probs = None
    k = 0
    attn_forward = base_model.backbone.resblocks[-1].attn.attention.forward
    with nopdb.capture_call(attn_forward) as attn_call:
        for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):
            if x.shape[2] != base_points:
                break  # upsampler stage reached
            samples = x
            # Snapshot the weights after every base-stage output instead of
            # reading them once after the loop. The capture presumably matches
            # on the forward's code object, so once the upsampler (assumed to
            # share the attention layer class - TODO confirm) runs, attn_call
            # would hold the upsampler's attention (e.g. for breakpoint=-1).
            # Assumes attn_call is refreshed after each captured call.
            attention_probs = attn_call.locals['weight'][0]
            if k == breakpoint:
                break
            k += 1

    if attention_probs is None:
        raise RuntimeError('sampler yielded no base-stage output')

    # Average across all heads
    avg_attn = torch.mean(attention_probs, dim=0)
    return (*_split_attention(avg_attn, base_points), samples)

Run the model and save visualizations of the different attention maps.

In [None]:
# Figures go under a directory named after the active edit: the mask when one
# is set (i.e. not the string "None"), otherwise the focus.
fig_path = (
    path + '/Figures/Attention_Edit/Mask_' + mask
    if mask != "None"
    else path + '/Figures/Attention_Edit/Focus_' + focus
)

# One frames/ directory per visualization; existing directories are reused.
vis = ['pc2pc', 'img2img', 'img2pc', 'pc2img', 'full', 'pointcloud']
for name in vis:
    os.makedirs(fig_path + '/' + name + '/frames', exist_ok=True)


In [None]:
# Render one frame per breakpoint for every visualization in `vis`.
# Each breakpoint reruns sampling from scratch, so reseed before every run;
# otherwise each frame comes from a different random trajectory and the GIF
# would not show a single sample evolving. The seed value is arbitrary.
for k in breakpoints:
    torch.manual_seed(0)
    # (pc2pc, img2img, img2pc, pc2img, full, point-cloud samples), same order as `vis`
    outputs = sample_from_model(k)
    title = 'iterations = ' + str(k)
    for name, output in zip(vis, outputs):
        if name == 'pointcloud':
            pc = sampler.output_to_point_clouds(output)[0]
            plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75), (0.75, 0.75, 0.75)))
            # With grid_size=3 the figure presumably holds several subplots;
            # title the whole figure rather than only the current (last) axes.
            plt.suptitle(title)
        else:
            sns.heatmap(output, cmap='rocket_r', cbar=False)
            plt.title(title)
        plt.savefig(fig_path + '/' + name + '/frames/' + str(k) + '.png', bbox_inches='tight')
        plt.close()

Create GIFs

In [None]:
# Stitch each visualization's per-breakpoint frames (in breakpoint order) into
# a GIF. All frames are read with the imageio v2 API; the top-level
# imageio.imread used for some frames before is a deprecated alias.
for name in vis:
    frame_dir = fig_path + '/' + name + '/frames/'
    frames = [imageio.v2.imread(frame_dir + str(k) + '.png') for k in breakpoints]
    # NOTE(review): whether `duration` is in seconds or milliseconds depends on
    # the installed imageio version / GIF plugin - confirm the intended frame time.
    imageio.mimsave(fig_path + '/' + name + '/' + name + '_attention.gif', frames, duration=2000)