In [25]:
import requests
import torch
import transformers

# import sys
# sys.path.append(".")
from PIL import Image

from flamingo_hf.configuration_flamingo import FlamingoConfig
from flamingo_hf.modeling_flamingo import FlamingoModel
from flamingo_hf import FlamingoForConditionalGeneration

config = FlamingoConfig.from_json_file("./flamingo_hf/config.json")
# model = FlamingoForConditionalGeneration(config)

# model.load_state_dict(torch.load("/home/v-boli7/azure_storage/models/openflamingo/open_flamingo_hf.pt", map_location="cpu"), strict=False)
model = FlamingoForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path="/home/v-boli7/projects/openflamingo-9b-hf",
    device_map="auto",
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.68s/it]
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
Using pad_token, but it is not set yet.
Loading checkpoint shards: 100%|██████████| 4/4 [00:09<00:00,  2.27s/it]


In [19]:
correct_model = FlamingoForConditionalGeneration(config)
correct_model.load_state_dict(torch.load("/home/v-boli7/azure_storage/models/openflamingo/open_flamingo_hf.pt", map_location="cpu"), strict=False)

Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.21s/it]
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
Using pad_token, but it is not set yet.


_IncompatibleKeys(missing_keys=['lang_encoder.model.layers.0.decoder_layer.self_attn.q_proj.weight', 'lang_encoder.model.layers.0.decoder_layer.self_attn.k_proj.weight', 'lang_encoder.model.layers.0.decoder_layer.self_attn.v_proj.weight', 'lang_encoder.model.layers.0.decoder_layer.self_attn.o_proj.weight', 'lang_encoder.model.layers.0.decoder_layer.mlp.gate_proj.weight', 'lang_encoder.model.layers.0.decoder_layer.mlp.down_proj.weight', 'lang_encoder.model.layers.0.decoder_layer.mlp.up_proj.weight', 'lang_encoder.model.layers.0.decoder_layer.input_layernorm.weight', 'lang_encoder.model.layers.0.decoder_layer.post_attention_layernorm.weight', 'lang_encoder.model.layers.1.decoder_layer.self_attn.q_proj.weight', 'lang_encoder.model.layers.1.decoder_layer.self_attn.k_proj.weight', 'lang_encoder.model.layers.1.decoder_layer.self_attn.v_proj.weight', 'lang_encoder.model.layers.1.decoder_layer.self_attn.o_proj.weight', 'lang_encoder.model.layers.1.decoder_layer.mlp.gate_proj.weight', 'lang_enc

In [20]:
import copy
model.lang_encoder.lm_head.weight = copy.deepcopy(correct_model.lang_encoder.lm_head.weight)

In [27]:
tokenizer = model.text_tokenizer
image_processor = transformers.CLIPImageProcessor()

demo_image_one = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
demo_image_two = Image.open(requests.get("http://images.cocodataset.org/test-stuff2017/000000028137.jpg", stream=True).raw)
query_image = Image.open(requests.get("http://images.cocodataset.org/test-stuff2017/000000028352.jpg", stream=True).raw)

vision_x = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
"""
Step 3: Preprocessing text
Details: In the text we expect an <image> special token to indicate where an image is.
We also expect an <|endofchunk|> special token to indicate the end of the text 
portion associated with an image.
"""
model.text_tokenizer.padding_side = "left"  # For generation padding tokens should be on the left
lang_x = model.text_tokenizer(
    ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"],
    return_tensors="pt",
)

"""
Step 4: Generate text
"""
generated_text = model.generate(
    vision_x=vision_x.to(model.device),
    lang_x=lang_x["input_ids"].to(model.device),
    attention_mask=lang_x["attention_mask"].to(model.device),
    max_new_tokens=20,
    num_beams=3,
)

print("Generated text: ", model.text_tokenizer.decode(generated_text[0]))

Generated text:  <s> <image> An image of two cats. <|endofchunk|> <image> An image of a bathroom sink. <|endofchunk|> <image> An image of univers univers univers spectrum spectrum spectrum spectrum spectrum spectrum spectrum spectrum spectrum spectrum spectrum spectrum spectrum spectrum cele cele cele


: 

In [22]:
model = model.to("cpu")
models_named_dict = {}
for name, param in model.named_parameters():
    models_named_dict[name] = param

for name, param in correct_model.named_parameters():
    if name not in models_named_dict:
        print(f"Failed {name}")
    else:
        if torch.allclose(param, models_named_dict[name], atol=1e-7, rtol=1e-5):
            pass
            # print(f"Passed {name}")
        else:
            print(f"Failed {name}")

In [24]:
model.save_pretrained("/home/v-boli7/projects/openflamingo-9b-hf")