In [2]:
from PIL import Image
import torch
import nopdb
from tqdm.auto import tqdm

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.download import load_checkpoint
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util.plotting import plot_point_cloud

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('creating base model...')
base_name = 'base40M' # use base300M or base1B for better results
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

print('creating upsample model...')
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

print('downloading base checkpoint...')
base_model.load_state_dict(load_checkpoint(base_name, device))

print('downloading upsampler checkpoint...')
upsampler_model.load_state_dict(load_checkpoint('upsample', device))

creating base model...
creating upsample model...
downloading base checkpoint...
downloading upsampler checkpoint...


<All keys matched successfully>

In [4]:
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[1024, 4096 - 1024],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 3.0],
)

In [5]:
# Load an image to condition on.
img = Image.open('example_data/cube_stack.jpg')

# Produce a sample from the model.
samples = None
with nopdb.capture_call(base_model.backbone.resblocks[-1].attn.attention.forward) as attn_call:
    for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):
        samples = x

0it [00:00, ?it/s]

In [6]:
print(attn_call.locals.keys())
print(attn_call.locals['attn_ch'])
# print(attn_call.locals['x'].shape)

dict_keys(['self', 'qkv', 'bs', 'n_ctx', 'width', 'attn_ch', 'scale', 'q', 'k', 'v', 'weight', 'wdtype'])
64


In [7]:
print(attn_call.locals['q'].shape)
print(attn_call.locals['k'].shape)
print(attn_call.locals['v'].shape)
print(attn_call.locals['weight'].shape)
batch, heads, target, source = attn_call.locals['weight'].shape

torch.Size([2, 1281, 8, 64])
torch.Size([2, 1281, 8, 64])
torch.Size([2, 1281, 8, 64])
torch.Size([2, 8, 1281, 1281])


In [8]:
import math
scale = 1 / math.sqrt(math.sqrt(attn_call.locals['attn_ch']))

def reshape_heads_to_batch_dim(tensor):
        batch_size, seq_len, heads, dim = tensor.shape
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim)
        return tensor
#Manually extract the query and key tensors and combine them as in transformer module, to obtain the attention map.
new_q = reshape_heads_to_batch_dim(attn_call.locals['q'])
new_k = reshape_heads_to_batch_dim(attn_call.locals['k'])
attention_scores = torch.einsum("b i d, b j d -> b i j", new_q, new_k) * scale

attention_probs = attention_scores.softmax(dim=-1)
print(attention_probs.shape)

torch.Size([16, 1281, 1281])


: 

In [9]:
print(attention_probs)
def compute_ca_loss(attn_map, bboxes, object_positions):
    loss = 0
    object_number = len(bboxes)
    if object_number == 0:
        return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()
    for attn_map_integrated in attn_map:
        #This chunks the attention mid from [16,1281,1281] to [8,1281,1281]. We pick the second half. Why?
        attn_map = attn_map_integrated.chunk(2)[1]
        #Then we extract the dimensions. b= 8 i=1281 j=1281
        b, i, j = attn_map.shape
        '''
        For creating the dimensions of the mask we need to think what makes sense in 3D space...
        For 2D case we can just sqrt the latent dimension to create a square, because images are squares.
        Then represent the mask as that square, but now we have 3D space...
        Furthermore, we have that 257 dimensions are added because 256 from the image embedding and 1 from the timestep.
        So the embedding dimensions are not only containing information about the point cloud. 
        '''
        H = W = D = int(math.sqrt(i))
        #Loop for the amount of objects (basically for how many bounding boxes we created)
        for obj_idx in range(object_number):
            obj_loss = 0 #per object "loss"
            #We create a mask of all zeros using the sqrt i dimension, in this case 
            mask = torch.zeros(size=(H, W)).cuda() if torch.cuda.is_available() else torch.zeros(size=(H, W))
            for obj_box in bboxes[obj_idx]:
                #Extract the corners of the bounding boxes and set the mask matrix to 1 at locations inside the bounding box.
                x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
                    int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
                mask[y_min: y_max, x_min: x_max] = 1
            #Object_position in example case looks like: [[2,3],[10]]
            for obj_position in object_positions[obj_idx]: #(1) 2 3 (2) 10
                #Third dimension corresponds then to attention map of words. Picking the location of the word that has a bounding box means getting that specific attention for that word.
                #Originally attn_map[:,:,obj_position] has shape [4,64,1]
                ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W) #Reshape the specific attention map of that word into [4,8,8]
                '''
                obj_position is an integer which corresponds to the index in the query of that specific word. So here we assume that selecting the index will give us the attention
                map of each word in the query. But why is it length 77?
                '''
                #Multiply this specific attention map with the mask, merge the final two dimensions and sum along them. We divide by the specific attention map of this particular word.
                activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)/ca_map_obj.reshape(b, -1).sum(dim=-1)
                #Intuitively: We amplify attention values of this specific object within the bounding box (indicated by mask), to then guide to model to attend more to this area for this object?
                obj_loss += torch.mean((1 - activation_value) ** 2) #if for example an object has multiple words associated with it like "hello kitty", we sum the "losses" of both words.
            loss += (obj_loss/len(object_positions[obj_idx]))

tensor([[[1.9516e-02, 6.8748e-05, 9.5237e-04,  ..., 1.1808e-04,
          1.1010e-06, 1.8025e-05],
         [3.0395e-04, 8.0177e-03, 1.9105e-04,  ..., 1.7302e-06,
          3.3840e-06, 1.1612e-06],
         [9.2774e-04, 1.0865e-04, 8.8315e-04,  ..., 9.0119e-05,
          6.5337e-06, 8.9934e-06],
         ...,
         [6.3521e-06, 4.5698e-08, 4.6456e-08,  ..., 3.5422e-02,
          4.3947e-05, 2.9925e-05],
         [1.6877e-09, 4.0746e-09, 3.5010e-10,  ..., 2.7414e-06,
          4.8975e-02, 1.6963e-05],
         [2.6988e-07, 6.0006e-08, 1.0788e-08,  ..., 1.5673e-05,
          2.8754e-04, 1.1290e-02]],

        [[9.8665e-02, 2.7145e-04, 2.1516e-03,  ..., 1.3662e-05,
          1.1435e-06, 9.4908e-06],
         [3.6014e-04, 6.7240e-03, 2.2407e-03,  ..., 2.2885e-05,
          1.0653e-04, 2.2251e-05],
         [2.3039e-04, 2.9884e-04, 2.9164e-03,  ..., 5.8499e-07,
          9.3270e-06, 4.3015e-06],
         ...,
         [6.4138e-09, 1.8667e-09, 8.8785e-10,  ..., 1.1489e-01,
          1.239

: 

In [1]:
pc = sampler.output_to_point_clouds(samples)[0]
fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75),(0.75, 0.75, 0.75)))

NameError: name 'sampler' is not defined