
visualization #13

Open · buproof opened this issue Nov 14, 2022 · 7 comments

buproof commented Nov 14, 2022

Hi, author! Thanks for this great code!

I want to reproduce the visualization results, but I cannot find the corresponding code in this repo.

I read the paper but I think it's not easy for me to reproduce the visualization results correctly.

May I have the code to produce the visualization results, or is there anything I missed?

Thank you very much!

232525 (Owner) commented Nov 14, 2022

I am sorry, we have not organized our visualization code yet, but I can provide you with a core demo:

import cv2
import matplotlib.pyplot as plt
import skimage.transform
import numpy as np

def visualize_attention(image_path, alpha_weights):
    # load image and convert from OpenCV's BGR to RGB
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # image height/width and a quarter-resolution target size
    H, W = img.shape[:2]
    dH, dW = H // 4, W // 4

    # keep only the top-k attention weights, zero out the rest
    k = 20
    _tmp = alpha_weights.reshape(-1)
    kth_largest = _tmp[_tmp.argsort()[-k]]
    alpha_weights = alpha_weights * (alpha_weights >= kth_largest)

    # resize the weights from (12, 12) to (H/4, W/4) ...
    alpha_weights = skimage.transform.resize(alpha_weights, (dH, dW))
    # ... then smoothly expand them back to roughly the original image size
    alpha_weights = skimage.transform.pyramid_expand(alpha_weights, upscale=4, sigma=20)

    # draw the image with the attention map overlaid
    plt.figure()
    plt.imshow(img)
    plt.imshow(alpha_weights, alpha=0.75, cmap=plt.cm.gray)
    plt.axis('off')
    plt.show()

alpha_weights = np.array(
[5.0248e-03, 5.2091e-03, 5.0840e-03, 5.1059e-03, 3.2360e-02, 1.0273e-03,
        2.9019e-04, 3.2377e-02, 5.1033e-03, 5.3537e-04, 5.2795e-03, 5.2804e-03,
        5.0838e-03, 5.0338e-03, 5.4204e-03, 5.0996e-03, 3.2342e-02, 2.1245e-04,
        1.1865e-03, 3.2370e-02, 5.1377e-03, 2.7714e-04, 3.2365e-02, 5.4530e-03,
        2.1483e-03, 1.8291e-03, 1.3979e-04, 9.0597e-04, 5.1887e-03, 3.3162e-03,
        5.9515e-03, 5.2063e-03, 3.7140e-03, 5.7669e-03, 5.2450e-03, 5.0991e-03,
        2.2931e-03, 1.0192e-03, 1.2310e-03, 1.7673e-03, 3.2369e-02, 1.4196e-02,
        2.5353e-02, 3.2365e-02, 4.2024e-04, 5.2958e-04, 1.6338e-03, 2.3828e-03,
        5.4352e-03, 5.1889e-03, 2.1982e-03, 3.3123e-04, 3.2343e-02, 2.4629e-03,
        2.2377e-03, 1.5513e-04, 3.1852e-04, 2.2781e-04, 1.6502e-03, 9.5750e-04,
        4.8194e-04, 7.9026e-03, 9.6730e-04, 1.5098e-02, 1.7108e-03, 8.0923e-04,
        1.1966e-03, 8.3894e-04, 3.7549e-03, 5.2052e-03, 1.4130e-03, 1.9779e-03,
        1.5995e-03, 2.7751e-03, 5.5997e-03, 7.0124e-03, 2.1481e-03, 5.7834e-03,
        1.2972e-03, 1.7500e-04, 7.7323e-03, 1.8277e-03, 1.7876e-03, 1.3740e-03,
        5.2334e-03, 3.2342e-02, 8.6587e-03, 1.5491e-03, 3.2362e-02, 4.0188e-03,
        4.4041e-04, 6.4261e-04, 1.4355e-03, 8.3124e-03, 5.3338e-03, 4.9598e-03,
        5.4005e-03, 4.6577e-03, 2.1362e-02, 3.9373e-03, 3.2342e-02, 8.5086e-04,
        6.0412e-04, 1.3558e-04, 6.1554e-03, 5.5917e-03, 5.2004e-03, 1.7581e-03,
        5.1032e-03, 5.4421e-03, 3.2950e-03, 2.8823e-03, 5.1852e-03, 1.9310e-03,
        8.0221e-04, 3.6786e-04, 3.7763e-03, 6.1263e-04, 5.2769e-03, 5.0333e-03,
        5.1957e-03, 5.1147e-03, 1.5568e-03, 9.7842e-05, 3.2341e-02, 6.3536e-04,
        7.0632e-04, 5.4808e-04, 2.7613e-03, 1.5866e-03, 5.3861e-03, 3.2329e-02,
        5.5115e-03, 3.2350e-02, 1.3547e-03, 5.1975e-03, 3.2346e-02, 4.3569e-04,
        2.0293e-03, 3.2360e-02, 6.9499e-04, 2.4257e-03, 3.2363e-02, 4.9784e-03]
).reshape([12, 12])
visualize_attention('./COCO_val2014_000000483108.jpg', alpha_weights[:, :])

The result should look like this:
[attention visualization overlaid on COCO_val2014_000000483108.jpg]

What you need to do is obtain the attention weights (alpha_weights) that you want to visualize.

buproof (Author) commented Nov 16, 2022

Thank you for your prompt reply!
Now I have some other questions and I hope you can help me:

  1. If the training is interrupted, can I resume training from the point of interruption?
  2. What do the two parameters load_epoch and resume do?
  3. What I want to visualize: when I use the trained model and input an arbitrary image (not necessarily from the dataset), the output is the corresponding text description, just like you showed in your paper.
     [screenshot from the paper]

232525 (Owner) commented Nov 16, 2022

Answer to 1 and 2: Yes, but not perfectly. --load_epoch and --resume are both used when resuming training. Specifically:

  • --resume is the epoch number of the trained model that you want to re-load (for example, if you set it to 3, it will load caption_model_3.pth). Note: --resume only loads the model weights; it does not restore the optimizer and scheduler (because I did not store their state_dicts).
  • --load_epoch is used to restore the learning rate. For example, if training was interrupted at the 3rd epoch and you want to continue, the correct way would be to re-load every detail of the interrupted checkpoint, but we only stored the model weights, so --load_epoch is just a simple way to restore the learning-rate state of the interrupted checkpoint. It cannot re-load all the details. (A minimal sketch of what full checkpointing would look like follows this list.)
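
For reference, a minimal sketch of the standard PyTorch pattern for saving and restoring the full training state (model, optimizer, scheduler, epoch). The function and variable names here are hypothetical and this is not what the repo currently stores:

import torch

def save_checkpoint(path, epoch, model, optimizer, scheduler):
    # store everything needed to resume training exactly where it stopped
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }, path)

def load_checkpoint(path, model, optimizer, scheduler):
    # restore the full training state and return the next epoch to run
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    return ckpt['epoch'] + 1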

Answer to 3: Actually, I save all the attention weights for the corresponding generated words, then generate a visualization image for each word, and finally use Visio to manually compose the visualization images and words to fit the paper width. I have not scripted the full process into one stage.
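
If you want to avoid the manual Visio step, here is a rough sketch of how the per-word images could be composed with matplotlib instead. It assumes you have already collected one (12, 12) weight map per generated word (however you extract them from the model) and reuses the same resize/pyramid_expand trick as the demo above:

import matplotlib.pyplot as plt
import skimage.transform

def compose_word_attention(img, words, word_weights, out_path='attention_per_word.png'):
    # img: RGB image array; words: generated tokens; word_weights: one (12, 12) map per token
    H, W = img.shape[:2]
    fig, axes = plt.subplots(1, len(words), figsize=(3 * len(words), 3), squeeze=False)
    for ax, word, weights in zip(axes[0], words, word_weights):
        # upsample the coarse attention map to roughly the image resolution
        heat = skimage.transform.resize(weights, (H // 4, W // 4))
        heat = skimage.transform.pyramid_expand(heat, upscale=4, sigma=20)
        ax.imshow(img)
        ax.imshow(heat, alpha=0.75, cmap=plt.cm.gray, extent=(0, W, H, 0))
        ax.set_title(word)
        ax.axis('off')
    fig.savefig(out_path, bbox_inches='tight')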

@Lujinfu1999

Hello, I would like to know at which step of inference the attention weights should be taken, and from which stage the attention weights should be taken when generating each word? Thanks

@Debolena7

@Lujinfu1999 hello, did you find the answer to this?

> Hello, I would like to know at which step of inference the attention weights should be taken, and from which stage the attention weights should be taken when generating each word? Thanks


Lujinfu1999 commented Mar 14, 2024

I tried getting the attention weights from the last head of the last decoder layer's cross-attention; maybe you can try it. @not-hermione
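
A minimal sketch of how that extraction might look with a PyTorch forward hook, assuming the cross-attention module returns (attended_output, attention_weights); the module path model.decoder.layers[-1].cross_att and the generation call below are hypothetical and will differ in this repo:

import torch

attn_maps = []  # one (12, 12) map per generated word

def save_cross_attention(module, inputs, output):
    # assumes output = (attended_features, attention_weights) with weights
    # shaped like (batch, num_heads, num_queries, 144); adapt to the real repo code
    weights = output[1]
    attn_maps.append(weights[0, -1, -1].detach().cpu().numpy().reshape(12, 12))

# hypothetical module path: last decoder layer's cross-attention
hook = model.decoder.layers[-1].cross_att.register_forward_hook(save_cross_attention)
with torch.no_grad():
    caption = model.decode(image_features)  # hypothetical generation call
hook.remove()
# attn_maps can then be passed to visualize_attention (or compose_word_attention) above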

@Debolena7

> I tried getting the attention weights from the last head of the last decoder layer's cross-attention; maybe you can try it. @not-hermione

Was the attention-map visualization convincing?
