/
core.py
79 lines (65 loc) · 2.47 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from ... import imports as I
from ...torch_base import TorchBase
class ImageCaptioner(TorchBase):
    """
    Interface to an image-captioning model (HuggingFace vision
    encoder-decoder, e.g., ViT encoder + GPT-2 decoder).
    """

    def __init__(self, model_name="ydshieh/vit-gpt2-coco-en", device=None):
        """
        ```
        ImageCaptioner constructor
        Args:
          model_name(str): name of image captioning model
          device(str): device to use (e.g., 'cuda', 'cpu')
        Raises:
          Exception: if PIL is not installed
        ```
        """
        if not I.PIL_INSTALLED:
            raise Exception(
                "PIL is not installed. Please install with: pip install pillow>=9.0.1"
            )
        super().__init__(
            device=device, quantize=False, min_transformers_version="4.12.3"
        )
        self.model_name = model_name
        # Deferred import: transformers is only needed when this class is used.
        # NOTE(review): ViTFeatureExtractor is deprecated in newer transformers
        # in favor of ViTImageProcessor — confirm minimum supported version
        # before migrating.
        from transformers import (
            AutoTokenizer,
            VisionEncoderDecoderModel,
            ViTFeatureExtractor,
        )

        self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name).to(
            self.torch_device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.extractor = ViTFeatureExtractor.from_pretrained(self.model_name)

    def caption(self, images, max_length=16, num_beams=4):
        """
        ```
        Performs image captioning
        This method supports a single image or a list of images. If the input is
        a single image, the return type is a string. If the input is a list, a
        list of strings is returned.
        Args:
          images: a PIL image, a path to an image file, or a list of either
          max_length(int): maximum token length of generated captions
          num_beams(int): number of beams for beam-search decoding
        Returns:
          str|list: caption for a single image input, list of captions otherwise
        ```
        """
        # Remember input shape so we can return a matching type at the end.
        single = not isinstance(images, list)
        values = [images] if single else images

        # Open any file-path inputs; track what WE opened so the file handles
        # are closed afterwards (PIL keeps the file open until the image is
        # closed — previously these handles leaked).
        opened = []
        try:
            pil_images = []
            for image in values:
                if isinstance(image, str):
                    image = I.Image.open(image)
                    opened.append(image)
                pil_images.append(image)

            # Feature extraction -> pixel tensors on the configured device
            pixels = self.extractor(images=pil_images, return_tensors="pt").pixel_values
            pixels = pixels.to(self.torch_device)

            # Run generation without tracking gradients (inference only)
            import torch

            with torch.no_grad():
                outputs = self.model.generate(
                    pixels,
                    max_length=max_length,
                    num_beams=num_beams,
                    return_dict_in_generate=True,
                ).sequences
        finally:
            # Close only the images this method opened; caller-supplied
            # PIL images are left untouched.
            for image in opened:
                image.close()

        # Decode token ids into text captions
        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        captions = [caption.strip() for caption in captions]

        # Return single element if single element passed in
        return captions[0] if single else captions