In [None]:
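使用 Flash Attention 2 加速推理：通过 `attn_implementation="flash_attention_2"` 让模型的注意力层改用 Flash Attention 2 内核。该实现需要安装 `flash-attn` 包、使用受支持的 NVIDIA GPU，并且模型需以 `torch.float16` 或 `torch.bfloat16` 加载。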

In [2]:
import io

import requests
import torch
from PIL import Image
from transformers import SiglipModel, SiglipProcessor

# Zero-shot image classification with SigLIP, using the Flash Attention 2 attention kernel.
MODEL_PATH = "/root/autodl-tmp/siglip-so400m-patch14-384"  # local checkpoint directory
device = "cuda"  # the device to load the model onto (e.g. "cuda" or "cuda:0")

model = SiglipModel.from_pretrained(
    MODEL_PATH,
    attn_implementation="flash_attention_2",  # needs the flash-attn package and fp16/bf16 weights
    torch_dtype=torch.float16,
    device_map=device,
)
processor = SiglipProcessor.from_pretrained(MODEL_PATH)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Download eagerly with a timeout and a status check: a stalled connection would otherwise
# block the kernel forever, and an HTTP error page would surface as an opaque PIL error.
# The context manager closes the connection once the bytes are read.
with requests.get(url, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))

candidate_labels = ["2 cats", "2 dogs"]
texts = [f'This is a photo of {label}.' for label in candidate_labels]
# padding="max_length" follows the Hugging Face SigLIP usage example; the text tower is
# meant to see fixed-length padded inputs.
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt").to(device)

with torch.no_grad():
    # autocast expects a device *type* ("cuda"), not a full device string like "cuda:0".
    with torch.autocast(torch.device(device).type, dtype=torch.float16):
        outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
# SigLIP scores each image-text pair independently, so a sigmoid (not a softmax over
# labels) turns each logit into a probability.
probs = torch.sigmoid(logits_per_image)
print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")

51.3% that image 0 is '2 cats'


In [None]:
PyTorch 在 `torch.nn.functional` 中提供了原生的缩放点积注意力（SDPA）算子 `scaled_dot_product_attention`。该算子内置多种实现，会根据输入和所用硬件自动选择合适的实现。在 transformers 中通过 `attn_implementation="sdpa"` 启用它需要 torch>=2.1.1。

In [3]:
import io

import requests
import torch
from PIL import Image
from transformers import SiglipModel, SiglipProcessor

# Zero-shot image classification with SigLIP, using PyTorch's native SDPA attention.
MODEL_PATH = "/root/autodl-tmp/siglip-so400m-patch14-384"  # local checkpoint directory
device = "cuda"  # the device to load the model onto (e.g. "cuda" or "cuda:0")

model = SiglipModel.from_pretrained(
    MODEL_PATH,
    attn_implementation="sdpa",  # torch.nn.functional.scaled_dot_product_attention; needs torch>=2.1.1
    torch_dtype=torch.float16,
    device_map=device,
)
processor = SiglipProcessor.from_pretrained(MODEL_PATH)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Download eagerly with a timeout and a status check: a stalled connection would otherwise
# block the kernel forever, and an HTTP error page would surface as an opaque PIL error.
# The context manager closes the connection once the bytes are read.
with requests.get(url, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))

candidate_labels = ["2 cats", "2 dogs"]
texts = [f'This is a photo of {label}.' for label in candidate_labels]
# padding="max_length" follows the Hugging Face SigLIP usage example; the text tower is
# meant to see fixed-length padded inputs.
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt").to(device)

with torch.no_grad():
    # autocast expects a device *type* ("cuda"), not a full device string like "cuda:0".
    with torch.autocast(torch.device(device).type, dtype=torch.float16):
        outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
# SigLIP scores each image-text pair independently, so a sigmoid (not a softmax over
# labels) turns each logit into a probability.
probs = torch.sigmoid(logits_per_image)
print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")

51.3% that image 0 is '2 cats'
