In [1]:
!pip install open_flamingo

Collecting open_flamingo
  Downloading open_flamingo-2.0.1-py3-none-any.whl.metadata (14 kB)
Collecting einops-exts (from open_flamingo)
  Downloading einops_exts-0.0.4-py3-none-any.whl.metadata (621 bytes)
INFO: pip is looking at multiple versions of open-flamingo to determine which version is compatible with other requirements. This could take a while.
Collecting open_flamingo
  Downloading open_flamingo-2.0.0-py3-none-any.whl.metadata (13 kB)
  Downloading open_flamingo-0.0.2-py3-none-any.whl.metadata (12 kB)
Collecting braceexpand (from open_flamingo)
  Downloading braceexpand-0.1.7-py2.py3-none-any.whl.metadata (3.0 kB)
Collecting webdataset (from open_flamingo)
  Downloading webdataset-1.0.2-py3-none-any.whl.metadata (12 kB)
Collecting inflection (from open_flamingo)
  Downloading inflection-0.5.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting open-clip-torch (from open_flamingo)
  Downloading open_clip_torch-3.2.0-py3-none-any.whl.metadata (32 kB)
Collecting ftfy (from open-c

In [None]:
from json import decoder  # NOTE(review): not referenced in the visible cells — confirm before removing

from huggingface_hub import hf_hub_download
from open_flamingo import create_model_and_transforms
import torch

# Build the Flamingo model together with its image preprocessing transform and
# tokenizer: a CLIP ViT-L-14 vision encoder (OpenAI weights) paired with the
# MPT-7B language model, with a cross-attention layer every 4 decoder layers.
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-7b",   # another one (bigscience/bloom-560m)
    tokenizer_path="anas-awadalla/mpt-7b",       # another one (bigscience/bloom-560m)
    cross_attn_every_n_layers=4,
    decoder_layers_attr_name="transformer.blocks"
)

# grab model checkpoint from huggingface hub
checkpoint_path = hf_hub_download("OpenFlamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt")

# map_location="cpu" lets the checkpoint deserialize on machines without a GPU
# even if its tensors were saved from CUDA devices.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(checkpoint_path, map_location="cpu")

# strict=False is required because the checkpoint does not have to contain every
# parameter of the assembled model; keys the model does not know, however, point
# to a mismatch between checkpoint and architecture, so report them instead of
# discarding the result silently.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} checkpoint keys were not used, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )
del state_dict  # the weights are copied into the model; drop the extra copy

# Expose the decoder blocks under the attribute name "layers" as well; per the
# original author's note this is needed for MptForCausalLM compatibility with
# Flamingo's forward method.
model.lang_encoder.layers = model.lang_encoder.transformer.blocks

open_clip_model.safetensors:   0%|          | 0.00/1.71G [00:00<?, ?B/s]



config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/237 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

pytorch_model.bin.index.json: 0.00B [00:00, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Loading weights:   0%|          | 0/194 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Flamingo model initialized with 1384777744 trainable parameters


checkpoint.pt:   0%|          | 0.00/5.54G [00:00<?, ?B/s]

In [None]:
from PIL import Image
import requests

# Step 1: Load images


def load_image(url, timeout=30):
    """Download the image at ``url`` and return it as a fully loaded PIL image.

    Args:
        url: HTTP(S) address of the image file.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot hang the notebook indefinitely.

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail here with a clear HTTP error instead of letting PIL choke on an
        # error page further down.
        response.raise_for_status()
        # The raw stream bypasses requests' content decoding; ask urllib3 to
        # undo any gzip/deflate transfer encoding itself.
        response.raw.decode_content = True
        image = Image.open(response.raw)
        # Force the pixel data to be read before the connection is closed.
        image.load()
    return image


# Two in-context example images followed by the image to be captioned.
demo_image_one = load_image("http://images.cocodataset.org/val2017/000000039769.jpg")
demo_image_two = load_image("http://images.cocodataset.org/test-stuff2017/000000028137.jpg")
query_image = load_image("http://images.cocodataset.org/test-stuff2017/000000028352.jpg")


# Step 2: Preprocessing images
# OpenFlamingo expects the images as one torch tensor laid out as
#   batch_size x num_media x num_frames x channels x height x width.
# For this demo: batch_size = 1, num_media = 3, num_frames = 1,
# channels = 3, height = 224, width = 224.

context_images = (demo_image_one, demo_image_two, query_image)

# Stack the per-image tensors along a new leading "media" axis ...
vision_x = torch.stack([image_processor(picture) for picture in context_images], dim=0)
# ... then insert the singleton frame axis (position 1) and batch axis (position 0).
vision_x = vision_x[None, :, None]


# Step 3: Preprocessing text
# The prompt uses an <image> special token to mark where each image belongs and
# an <|endofchunk|> special token to close the text associated with that image.
# Both tokens must be spelled exactly (including the closing ">"), otherwise the
# tokenizer treats them as ordinary text instead of special tokens.

tokenizer.padding_side = "left" # For generation padding tokens should be on the left
lang_x = tokenizer(
    ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of "],
    return_tensors="pt",
    # padding / truncation are left at their defaults here, as in the original
    # single-prompt call.
)

# Step 4: Generate text
# Collect the inputs and decoding settings first, then run generation once:
# beam search with 2 beams, at most 8 newly generated tokens.
generation_kwargs = dict(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=8,
    num_beams=2,
)
generated_text = model.generate(**generation_kwargs)

# Only one prompt was submitted, so decode the first (and only) sequence.
decoded = tokenizer.decode(generated_text[0])
print("Generated text: ", decoded)
