# GILL Inference Examples

This notebook shows how to run GILL for image generation, image retrieval, and text generation, some of the tasks that GILL is capable of. It reproduces several examples from our paper, [Generating Images with Multimodal Language Models](https://arxiv.org/abs/2305.17216).

For reproducibility, all examples in this notebook use greedy (deterministic) decoding. However, it is possible to change to nucleus sampling for more diverse and higher quality outputs (used for some of the figures in the paper) by changing the `temperature` and `top_p` parameters in the `generate()` function.
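
To make the difference between greedy decoding and nucleus sampling concrete, here is a minimal pure-Python sketch. It is an illustration only, not GILL's decoding code: the function name `sample_next_token` and the convention that `temperature == 0` means "take the argmax" are assumptions chosen to mirror the `temperature` and `top_p` settings used later in this notebook.

```python
import math
import random


def sample_next_token(logits, temperature=0.0, top_p=1.0, rng=None):
    """Pick a token index from raw logits (illustrative, not GILL's code)."""
    # temperature == 0 is treated as greedy decoding: always take the argmax.
    if temperature == 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])

    rng = rng or random.Random()

    # Temperature-scaled softmax (subtract the max for numerical stability).
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest set of most-probable tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    nucleus, mass = [], 0.0
    for i in order:
        nucleus.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Sample within the nucleus, weighted by the original probabilities.
    weights = [probs[i] for i in nucleus]
    return rng.choices(nucleus, weights=weights, k=1)[0]


greedy_choice = sample_next_token([0.1, 2.0, -1.0])
sampled_choice = sample_next_token([0.1, 2.0, -1.0], temperature=0.6, top_p=0.95)
```

With `temperature=0.0` the result is always the same index, which is why the examples here are reproducible. With `temperature=0.6` and `top_p=0.95` the low-probability tail is cut off before sampling, so outputs vary between runs.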

At least 22GB of GPU memory is required to run this model, and it has only been tested on A6000, V100, and 3090 GPUs.

In [7]:
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import torch
from tqdm import notebook

from gill import models
from gill import utils

In [None]:
# !pip uninstall torch torchvision torchaudio

In [None]:
# !pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121

In [None]:
import torch.nn as nn

## Load the Model

Note that you will need to download the [CC3M image embeddings](https://drive.google.com/file/d/1e9Cimh2dpWN8Cbgx_mSR-954Dr-DS-ZO/view) and place them in the `checkpoints/gill_opt/` folder, in order to use GILL's image retrieval capabilities. If this embedding file does not exist, the model will still run, but it will exclusively generate images (as opposed to deciding when to retrieve or generate).
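
Before loading the model, it can help to check whether the embedding file is actually in place. The snippet below is a small sketch using only the standard library. The helper name `find_retrieval_embeddings` is hypothetical, and the matching rule (any `.npy` file whose name contains `cc3m`) is an assumption based on the load log, not the exact filename GILL looks for.

```python
from pathlib import Path


def find_retrieval_embeddings(model_dir):
    """Return .npy files in model_dir whose name mentions 'cc3m' (assumed naming)."""
    directory = Path(model_dir)
    if not directory.is_dir():
        return []
    return sorted(p for p in directory.glob('*.npy') if 'cc3m' in p.name.lower())


found = find_retrieval_embeddings('checkpoints/gill_opt/')
if found:
    print('Retrieval embeddings found:', [p.name for p in found])
else:
    print('No CC3M embeddings found; expect generation only.')
```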

In [8]:
# Download the model checkpoint and embeddings to checkpoints/gill_opt/
model_dir = 'checkpoints/gill_opt/'
model = models.load_gill(model_dir)

cc3m.npy files do not exist in checkpoints/gill_opt/.
Running the model without retrieval.
Adding [IMG0] token to vocabulary.
Before adding new token, tokenizer("[IMG0]") = {'input_ids': [10975, 3755, 534, 288, 742], 'attention_mask': [1, 1, 1, 1, 1]}
After adding 1 new tokens, tokenizer("[IMG0]") = {'input_ids': [50266], 'attention_mask': [1]}
Adding [IMG1] token to vocabulary.
Before adding new token, tokenizer("[IMG1]") = {'input_ids': [10975, 3755, 534, 134, 742], 'attention_mask': [1, 1, 1, 1, 1]}
After adding 1 new tokens, tokenizer("[IMG1]") = {'input_ids': [50267], 'attention_mask': [1]}
Adding [IMG2] token to vocabulary.
Before adding new token, tokenizer("[IMG2]") = {'input_ids': [10975, 3755, 534, 176, 742], 'attention_mask': [1, 1, 1, 1, 1]}
After adding 1 new tokens, tokenizer("[IMG2]") = {'input_ids': [50268], 'attention_mask': [1]}
Adding [IMG3] token to vocabulary.
Before adding new token, tokenizer("[IMG3]") = {'input_ids': [10975, 3755, 534, 246, 742], 'attention_mask



Using HuggingFace AutoFeatureExtractor for laion/clap-htsat-fused.
Using facebook/opt-125m for the language model.
Using openai/clip-vit-large-patch14 for the visual model with 4 visual tokens.
Using laion/clap-htsat-fused as audio encoder
Freezing the audio model
Freezing the LM.
Restoring pretrained weights for the visual model.


Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.2.layer_norm2.weight', 'text_model.encoder.layers.9.layer_norm1.bias', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.5.layer_norm2.weight', 'text_model.encoder.layers.9.mlp.fc2.weight', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.6.mlp.fc2.weight', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encode

Freezing the VM.
---------- Audio embedding ----------
LM input embedding dimension :  768
number of visual tokens :  4
Embedding dim:  3072
Retrieval embedding dim:  256


`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


Loading decision model...


In [None]:
# Captioning: project the encoder output to 4 visual tokens (16384 = 4 x 4096)
encoder_outputs = torch.rand((1,1024))
hidden_size = 1024

visual_embeddings = nn.Linear(hidden_size, 16384)
visual_embs = visual_embeddings(encoder_outputs)
visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 4, -1))
visual_embs.size()

In [None]:
# Retrieval: project the encoder output to a single 256-dim embedding
visual_fc = nn.Linear(hidden_size, 256)
visual_embs = visual_fc(encoder_outputs)
print(visual_embs.size())
visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1))
print(visual_embs.size())

# Image Generation

GILL can generate images conditioned on image and text inputs. Shown are several examples for various text prompts and image + text prompts.

In [None]:
sofa_img = utils.get_image_from_url('https://images.pexels.com/photos/1866149/pexels-photo-1866149.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=1')
blue = utils.get_image_from_url('https://unsplash.com/photos/0YQz7M2fcYY/download?ixid=M3wxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjkyMDEyMzA5fA&force=true')
red = utils.get_image_from_url('https://unsplash.com/photos/3TuIIkWlpvA/download?ixid=M3wxMjA3fDB8MXxzZWFyY2h8Mnx8cmVkJTIwY29sb3J8ZW58MHx8fHwxNjkxOTk1NzM0fDA&force=true')

In [None]:
# import requests

# response = requests.get('https://unsplash.com/photos/0YQz7M2fcYY/download?ixid=M3wxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjkyMDEyMzA5fA&force=true')
# response.content

In [None]:
# sofa_img = utils.get_image_from_url('https://images.pexels.com/photos/1866149/pexels-photo-1866149.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=1')
# ['an astronaut riding a horse on mars'],
    # [sofa_img, 'a picture of this but in red, color'],

# Generate for a few types of text prompts and image + text prompts.
for prompt in [
    ['a picture of a cat'],    
]:
    g_cuda = torch.Generator(device='cuda').manual_seed(1337)
    return_outputs = model.generate_for_images_and_texts(
        prompt, num_words=2, ret_scale_factor=100.0, generator=g_cuda)
    
    # Show either the generated or retrieved image, depending on the decision model outputs.
    if return_outputs[1]['decision'][0] == 'gen':
        plt.imshow(return_outputs[1]['gen'][0][0])
        plt.title('Generated')
    else:
        plt.imshow(return_outputs[1]['ret'][0][0].resize((512, 512)))
        plt.title('Retrieved')
    plt.show()
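
The cell above indexes into the second returned element to decide what to display. That access pattern can be wrapped in a small helper, sketched here. `pick_image` is a hypothetical name, and the dictionary layout (`'decision'` with the chosen mode first, `'gen'` and `'ret'` as nested lists of images) is inferred from the indexing in the cell rather than from GILL's documentation.

```python
def pick_image(image_output):
    """Return (label, image) from one image-output dict.

    Assumes the layout used in the cell above: 'decision' holds the chosen
    mode first, and 'gen' / 'ret' hold nested lists of images.
    """
    mode = image_output['decision'][0]
    if mode == 'gen':
        return 'Generated', image_output['gen'][0][0]
    return 'Retrieved', image_output['ret'][0][0]


# Stand-in strings instead of PIL images, purely for illustration.
fake_output = {
    'decision': ['gen', [0.9, 0.1]],
    'gen': [['<generated image>']],
    'ret': [['<retrieved image>']],
}
label, image = pick_image(fake_output)
print(label, image)
```

With real model outputs, `image` would be a PIL image that can be passed straight to `plt.imshow`.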

# Audio Captioning

In [9]:
import librosa
audio_data, sampling_rate = librosa.load("sin_100.wav")

In [10]:
from typing import List
type(audio_data) in [np.ndarray,torch.Tensor,List[np.ndarray],List[torch.Tensor]]

True

In [14]:
prompts = [
    audio_data,
    'This is the sound of'
]

return_outputs = model.generate_for_images_and_texts(prompts, num_words=16, min_word_tokens=16)
print(return_outputs[0])

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


<class 'numpy.ndarray'>
Audio embs shape (captioning) : torch.Size([1, 4, 768])
<class 'str'>
 a man who has been in a coma for a year.


In [None]:
import gc
gc.collect()

# Multimodal Dialogue

GILL can also generate dialogue-like text. We define some helper functions for displaying outputs:

In [None]:
def generate_dialogue(prompts: list, system_message: str = None, num_words: int = 32,
                      sf: float = 1.0, temperature: float = 0.0, top_p: float = 1.0,
                      divider_count: int = 40):
    g_cuda = torch.Generator(device='cuda').manual_seed(1337)

    full_outputs = []
    if system_message:
        print("Adding system message")
        full_inputs = [system_message]
    else:
        full_inputs = []

    for prompt_idx, prompt in notebook.tqdm(enumerate(prompts), total=len(prompts)):
        formatted_prompt = []
        for p in prompt:
            if isinstance(p, Image.Image):
                full_inputs.append(p)
                formatted_prompt.append(p)
            elif isinstance(p, str):
                full_inputs.append(f'Q: {p}\nA:')
                formatted_prompt.append(f'User: {p}')
        formatted_prompt.append('=' * divider_count)  # Add divider

        return_outputs = model.generate_for_images_and_texts(
            full_inputs, num_words=num_words, ret_scale_factor=sf,
            generator=g_cuda, temperature=temperature, top_p=top_p)

        # Add outputs
        output_text = return_outputs[0].replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', '')
        full_inputs.append(output_text + '\n')

        formatted_return_outputs = []
        for p in return_outputs:
            if isinstance(p, str):
                p_formatted = p.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', '')
                formatted_return_outputs.append(f'GILL: {p_formatted}')
            else:
                formatted_return_outputs.append(p)
        formatted_return_outputs.append('=' * divider_count)  # Add divider

        full_outputs.extend(formatted_prompt + formatted_return_outputs)

    return full_outputs


def display_conversation(full_outputs):
    # Display conversation.
    for p in full_outputs:
        if isinstance(p, Image.Image):
            plt.figure(figsize=(4, 4))
            plt.imshow(p)
            plt.show()
        elif isinstance(p, str):
            print(p)
        elif isinstance(p, dict):
            # Decide whether to retrieve or generate
            decision_probs = [f'{s:.3f}' for s in p['decision'][1]]
            if p['decision'][0] == 'gen':
                gen_img = p['gen'][0][0].resize((512, 512))
                # Generate
                plt.figure(figsize=(4, 4))
                plt.imshow(gen_img)
                plt.title(f'GENERATED (p={decision_probs})')
            else:
                ret_img = p['ret'][0][0].resize((512, 512))
                # Retrieve
                plt.figure(figsize=(4, 4))
                plt.imshow(ret_img)
                plt.title(f'RETRIEVED (p={decision_probs})')
            plt.show()
        else:
            raise NotImplementedError(p)
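
`generate_dialogue` removes the image placeholder tokens by matching one fixed string of eight tokens, and wraps each user message in a `Q: ...\nA:` template. A more tolerant variant of both steps is sketched below. The names `strip_image_tokens` and `format_turn` are hypothetical, and the regular expression assumes the placeholders always have the form `[IMG<number>]`. Unlike the exact-string replace, it also removes the whitespace in front of each token.

```python
import re

# Matches one placeholder such as "[IMG3]" plus any whitespace before it.
IMG_TOKEN = re.compile(r'\s*\[IMG\d+\]')


def strip_image_tokens(text):
    """Remove every [IMG<n>] placeholder, however many the model emits."""
    return IMG_TOKEN.sub('', text)


def format_turn(user_text):
    """Wrap a user message in the Q/A template used by generate_dialogue."""
    return f'Q: {user_text}\nA:'


cleaned = strip_image_tokens('Here is a cupcake [IMG0] [IMG1] [IMG2] [IMG3]')
turn = format_turn('How should I publicise these at the market?')
```

This keeps working if a checkpoint uses a different number of image tokens, at the cost of being slightly more aggressive about whitespace.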

The inputs to the model can be an interleaved image and text sequence. Shown is one example from our paper, with a cupcake image input and a question (note that the word "cupcake" is never explicitly mentioned, but GILL infers it from the image).

In [None]:
sf = 1.4  # Scaling factor: increase to increase the chance of returning an image
temperature = 0.0  # 0 means deterministic, try 0.6 for more randomness
top_p = 1.0  # If you set temperature to 0.6, set this to 0.95
num_words = 32

prompts = [
    [
        utils.get_image_from_url('https://www.allrecipes.com/thmb/riDYvmalWk8QgJDBT_pZRkpfpR0=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/17377-chocolate-cupcakes-DDMFS-4x3-622a7a66fcd84692947794ed385dc991.jpg'),
        'How should I publicise these at the market?'
    ],
]

full_outputs = generate_dialogue(prompts, num_words=num_words, sf=sf, temperature=temperature, top_p=top_p)
display_conversation(full_outputs)

# Image-to-Text Example

GILL can also generate text conditioned on image and text inputs. This is helpful for tasks such as image captioning or VQA.

In [None]:
pancakes_img = utils.get_image_from_url('https://images.pexels.com/photos/376464/pexels-photo-376464.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=1')
prompts = [
    pancakes_img,
    'A picture of'
]

plt.imshow(pancakes_img)
plt.show()

return_outputs = model.generate_for_images_and_texts(prompts, num_words=16, min_word_tokens=16)
print(return_outputs[0])

## Interleaving Image and Text Inputs

In [None]:
blue = utils.get_image_from_url('https://unsplash.com/photos/0YQz7M2fcYY/download?ixid=M3wxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjkyMDEyMzA5fA&force=true')
pancake = utils.get_image_from_url('https://images.pexels.com/photos/376464/pexels-photo-376464.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=1')
whipped_cream = utils.get_image_from_url("https://media.istockphoto.com/id/510175498/photo/whipped-cream-on-white-background.jpg?s=2048x2048&w=is&k=20&c=Z_4UJPvUq2kmQXnWngY6apzKZjDgRIGfxaa82mPakcM=")

In [None]:
# prompts = [pancake, "This but with grey colored syrup"]
# g_cuda = torch.Generator(device='cuda').manual_seed(1337)
# return_outputs = model.generate_for_images_and_texts(prompts, num_words=2, ret_scale_factor=100.0, generator=g_cuda)

# # Show either the generated or retrieved image, depending on the decision model outputs.
# if return_outputs[1]['decision'][0] == 'gen':
#     plt.imshow(return_outputs[1]['gen'][0][0])
#     plt.title('Generated')
# else:
#     plt.imshow(return_outputs[1]['ret'][0][0].resize((512, 512)))
#     plt.title(f"Retrieved")
# plt.show()

In [None]:
sofa_img = utils.get_image_from_url('https://images.pexels.com/photos/1866149/pexels-photo-1866149.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=1')
blue = utils.get_image_from_url('https://unsplash.com/photos/0YQz7M2fcYY/download?ixid=M3wxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjkyMDEyMzA5fA&force=true')
red = utils.get_image_from_url('https://unsplash.com/photos/3TuIIkWlpvA/download?ixid=M3wxMjA3fDB8MXxzZWFyY2h8Mnx8cmVkJTIwY29sb3J8ZW58MHx8fHwxNjkxOTk1NzM0fDA&force=true')

sf = 1.4  # Scaling factor: increase to increase the chance of returning an image
temperature = 0.0  # 0 means deterministic, try 0.6 for more randomness
top_p = 1.0  # If you set temperature to 0.6, set this to 0.95
num_words = 32

prompts = [
    [
        sofa_img, " but in ", blue, " color."
    ],
]

full_outputs = generate_dialogue(prompts, num_words=num_words, sf=sf, temperature=temperature, top_p=top_p)
display_conversation(full_outputs)

In [None]:
sf = 1.4  # Scaling factor: increase to increase the chance of returning an image
temperature = 0.0  # 0 means deterministic, try 0.6 for more randomness
top_p = 1.0  # If you set temperature to 0.6, set this to 0.95
num_words = 32

prompts = [
    [
        "How would this ", pancake, " look with ", whipped_cream, " on it?"
    ],
]

full_outputs = generate_dialogue(prompts, num_words=num_words, sf=sf, temperature=temperature, top_p=top_p)
display_conversation(full_outputs)

## VIST Evaluation

In [None]:
!wget https://visionandlanguage.net/VIST/json_files/story-in-sequence/SIS-with-labels.tar.gz

In [None]:
!mkdir -p sis && tar -xzf SIS-with-labels.tar.gz -C sis

In [None]:
!rm SIS-with-labels.tar.gz

In [None]:
!python evals/download_vist_images.py

In [None]:
%%capture captured_output
%%writefile output.txt

## Audio Processing

In [None]:
import librosa
import librosa.display
import IPython.display as ipd
from transformers import ClapModel, ClapProcessor, ClapConfig

In [None]:
ipd.Audio("sin_100.wav")

In [None]:
# configuration = ClapConfig()
# model = ClapModel(configuration)

In [None]:
audio_data, sampling_rate = librosa.load("sin_100.wav")

In [None]:
model = ClapModel.from_pretrained("laion/clap-htsat-fused").to(0)
processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")

# inputs = processor(audios=audio_sample["audio"]["array"], return_tensors="pt").to(0)
# audio_embed = model.get_audio_features(**inputs)

In [None]:
inputs = processor(audios=audio_data, return_tensors="pt").to(0)
audio_embed = model.get_audio_features(**inputs)

In [None]:
audio_embed.shape

In [None]:
type(audio_data)

In [5]:
!pip install moviepy

Collecting moviepy
  Downloading moviepy-1.0.3.tar.gz (388 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 388.3/388.3 kB 645.8 kB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting decorator<5.0,>=4.0.2 (from moviepy)
  Downloading decorator-4.4.2-py2.py3-none-any.whl (9.2 kB)
Collecting proglog<=1.0.0 (from moviepy)
  Downloading proglog-0.1.10-py3-none-any.whl (6.1 kB)
Collecting imageio<3.0,>=2.5 (from moviepy)
  Obtaining dependency information for imageio<3.0,>=2.5 from https://files.pythonhosted.org/packages/f6/37/e21e6f38b93878ba80302e95b8ccd4718d80f0c53055ccae343e606b1e2d/imageio-2.31.5-py3-none-any.whl.metadata
  Downloading imageio-2.31.5-py3-none-any.whl.metadata (4.6 kB)
Collecting imageio_ffmpeg>=0.2.0 (from moviepy)
  Obtaining dependency information for imageio_ffmpeg>=0.2.0 from https://files.pythonhosted.org/packages/1a/98/3df1d8dd8f2c121b6c588b1e0d604f36592d56df9c41fb155ed546c6a

In [17]:
import librosa
import os
from IPython.display import Audio
from moviepy.editor import AudioFileClip

In [25]:
Audio("datasets/AudioCaps/train/100011.wav")

In [20]:
def get_audio_duration(file_path):
    audio_clip = AudioFileClip(file_path)
    duration = audio_clip.duration
    audio_clip.close()
    return duration

In [15]:
# audio_data, sampling_rate = librosa.load("datasets/AudioCaps/train/59
get_audio_duration("datasets/AudioCaps/train/0.wav") == 10

True

In [24]:
count = 0
for f in os.listdir("datasets/AudioCaps/train"):
    print(f)
    if get_audio_duration(os.path.join("datasets/AudioCaps/train",f)) != 10:
        count += 1
        print(count)


0.wav
100004.wav
1
100005.wav
100011.wav


OSError: MoviePy error: failed to read the duration of file datasets/AudioCaps/train/100011.wav.
Here are the file infos returned by ffmpeg:

ffmpeg version 4.2.2-static https://johnvansickle.com/ffmpeg/  Copyright (c) 2000-2019 the FFmpeg developers
  built with gcc 8 (Debian 8.3.0-6)
  configuration: --enable-gpl --enable-version3 --enable-static --disable-debug --disable-ffplay --disable-indev=sndio --disable-outdev=sndio --cc=gcc --enable-fontconfig --enable-frei0r --enable-gnutls --enable-gmp --enable-libgme --enable-gray --enable-libaom --enable-libfribidi --enable-libass --enable-libvmaf --enable-libfreetype --enable-libmp3lame --enable-libopencore-amrnb --enable-libopencore-amrwb --enable-libopenjpeg --enable-librubberband --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libvorbis --enable-libopus --enable-libtheora --enable-libvidstab --enable-libvo-amrwbenc --enable-libvpx --enable-libwebp --enable-libx264 --enable-libx265 --enable-libxml2 --enable-libdav1d --enable-libxvid --enable-libzvbi --enable-libzimg
  libavutil      56. 31.100 / 56. 31.100
  libavcodec     58. 54.100 / 58. 54.100
  libavformat    58. 29.100 / 58. 29.100
  libavdevice    58.  8.100 / 58.  8.100
  libavfilter     7. 57.100 /  7. 57.100
  libswscale      5.  5.100 /  5.  5.100
  libswresample   3.  5.100 /  3.  5.100
  libpostproc    55.  5.100 / 55.  5.100
[wav @ 0x69cc480] Cannot check for SPDIF
Guessed Channel Layout for Input Stream #0.0 : stereo
Input #0, wav, from 'datasets/AudioCaps/train/100011.wav':
  Metadata:
    encoder         : Lavf58.29.100
  Duration: N/A, bitrate: 1411 kb/s
    Stream #0:0: Audio: pcm_s16le ([1][0][0][0] / 0x0001), 44100 Hz, stereo, s16, 1411 kb/s
At least one output file must be specified
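
The scan above stops at the first file whose duration cannot be read. One way to get through the whole directory is to catch per-file failures and report them at the end, as in the sketch below. Several things here are assumptions rather than part of the original workflow: the helper names `wav_duration_seconds` and `scan_durations`, the use of Python's standard `wave` module in place of MoviePy, and the 0.05 s tolerance (the original cell tests for exactly 10 seconds). Note that `wave` reads the frame count from the header, so a file with a damaged header may still yield a wrong duration or an error.

```python
import wave


def wav_duration_seconds(path):
    """Duration from the WAV header: frame count divided by frame rate."""
    with wave.open(path, 'rb') as wav_file:
        return wav_file.getnframes() / wav_file.getframerate()


def scan_durations(paths, expected=10.0, tolerance=0.05,
                   duration_fn=wav_duration_seconds):
    """Split paths into (wrong_length, unreadable) without aborting on errors."""
    wrong_length, unreadable = [], []
    for path in paths:
        try:
            duration = duration_fn(path)
        except Exception as err:  # any reader failure is recorded, not raised
            unreadable.append((path, repr(err)))
            continue
        if abs(duration - expected) > tolerance:
            wrong_length.append((path, duration))
    return wrong_length, unreadable
```

A possible call would be `scan_durations(os.path.join(root, f) for f in os.listdir(root))` with `root = "datasets/AudioCaps/train"`. Passing `duration_fn=get_audio_duration` would reuse the MoviePy-based reader while still collecting the failures.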
