# Magma Inference Demo

Copyright (c) 2023 Graphcore Ltd.

This notebook provides a basic interactive MAGMA inference application, based on the freely available checkpoint released by Aleph Alpha. Note that this checkpoint is only a demo, meant to help users understand how the model works.

To run this notebook, make sure you have configured the environment as explained in the repository README.

The main inputs are the `image_url` and the `text` prompt.
You can select one of the provided example images or enter the URL of any other image.
The text prompt can be anything you like. In this interactive application the maximum sequence length is 500 by default. To use longer text prompts, change the key used to initialise the session from `magma_v1_500` to `magma_v1_1024`.
Note that the output is sensitive to small changes in the prompt.
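Choosing between the two session keys is a matter of simple token arithmetic. The sketch below assumes, without confirmation from this notebook, that the image prefix, the tokenized prompt and the generated tokens all share the session's sequence length; `pick_session_key` and `SESSION_SEQ_LEN` are hypothetical names, and the token counts passed in must be measured by you.

```python
# Maximum sequence lengths of the two session keys named in this notebook
SESSION_SEQ_LEN = {"magma_v1_500": 500, "magma_v1_1024": 1024}


def pick_session_key(n_image_tokens, n_prompt_tokens, max_out_tokens):
    """Hypothetical helper: return the smallest session key that fits the request.

    Assumption: image prefix, prompt and generated tokens share one sequence.
    """
    needed = n_image_tokens + n_prompt_tokens + max_out_tokens
    # Try the shortest sequence length first, since it is the default session
    for key, seq_len in sorted(SESSION_SEQ_LEN.items(), key=lambda kv: kv[1]):
        if needed <= seq_len:
            return key
    raise ValueError(f"{needed} tokens exceed every available sequence length")
```

Under that assumption, a request that overflows 500 tokens is the signal to re-initialise the session with `magma_v1_1024`.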

This interactive application lets you control generation by adjusting the `temperature`, `top-p` and `top-k` parameters. The random `seed` is always set explicitly, so results are reproducible even though sampling is random.

- **top-p**: the most likely tokens are kept, in order of decreasing probability, until their cumulative probability exceeds the threshold p. The probability is then redistributed among these tokens and the next token is sampled from the resulting distribution (categorical sampling, non-deterministic).
- **top-k**: the probability is redistributed among the k most likely tokens, and the next token is sampled from the resulting distribution (categorical sampling, non-deterministic).
- **temperature**: the logits are scaled by a factor 1/T (T between 0 and 1 in this demo) before applying the softmax. A low temperature makes the distribution more peaked and a high temperature makes it broader. A temperature of zero corresponds to a deterministic choice (argmax), and the output becomes more random as the temperature increases.
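To make the three parameters concrete, here is a minimal NumPy sketch of one common way to combine them when sampling a single token. It is not the code used by `run_inference`: the name `sample_next_token`, the order in which the filters are applied and the convention that `top_k=0` and `top_p=1.0` disable a filter are assumptions for illustration. The random generator is passed in explicitly, mirroring the explicit `seed` used in this demo.

```python
import numpy as np


def sample_next_token(logits, rng, temperature=0.7, top_k=0, top_p=0.9):
    """Hypothetical helper: pick the next token id from a vector of logits."""
    logits = np.asarray(logits, dtype=np.float64)

    # Zero temperature degenerates to a deterministic argmax
    if temperature == 0:
        return int(np.argmax(logits))

    # Temperature: scale the logits by 1/T, then apply a numerically stable softmax
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    # Sort token ids from most to least likely
    order = np.argsort(probs)[::-1]
    sorted_probs = probs[order]
    keep = np.ones(len(sorted_probs), dtype=bool)

    # Top-k: keep only the k most likely tokens (assumed: 0 disables the filter)
    if top_k > 0:
        keep[top_k:] = False

    # Top-p: keep the shortest prefix whose cumulative probability exceeds p
    if top_p < 1.0:
        cumulative = np.cumsum(sorted_probs)
        cutoff = np.searchsorted(cumulative, top_p, side="right") + 1
        keep[cutoff:] = False

    # Redistribute the probability mass over the surviving tokens and sample
    kept = np.where(keep, sorted_probs, 0.0)
    kept /= kept.sum()
    return int(order[rng.choice(len(order), p=kept)])
```

For example, `sample_next_token([1.0, 3.0, 0.5, 2.0], np.random.default_rng(0), top_k=2, top_p=1.0)` can only return token 1 or 3, and re-creating the generator with the same seed reproduces the same choice.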

If you are not familiar with these concepts, this [Hugging Face article](https://huggingface.co/blog/how-to-generate) can help you visualise them.

You can also change the number of generated tokens by varying the `max_out_tokens` parameter.
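In a typical autoregressive decoder, `max_out_tokens` bounds the number of iterations of the decoding loop: at each step the model scores the sequence so far, one token is chosen and appended, and the loop ends after `max_out_tokens` steps. The sketch below illustrates that loop with a toy stand-in for the model; `generate`, `next_token_logits`, `toy_model` and the optional `eos_id` early stop are hypothetical and do not describe how `run_inference` is implemented.

```python
import numpy as np


def generate(next_token_logits, prompt_ids, max_out_tokens, temperature=0.7, seed=0, eos_id=None):
    """Hypothetical decoding loop: append at most max_out_tokens tokens."""
    rng = np.random.default_rng(seed)  # one seeded generator for the whole run
    ids = list(prompt_ids)
    n_prompt = len(ids)
    for _ in range(max_out_tokens):
        logits = np.asarray(next_token_logits(ids), dtype=np.float64)
        if temperature == 0:
            next_id = int(np.argmax(logits))  # deterministic choice
        else:
            scaled = logits / temperature
            probs = np.exp(scaled - scaled.max())
            probs /= probs.sum()
            next_id = int(rng.choice(len(probs), p=probs))  # categorical sampling
        ids.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break  # optional early stop at an end-of-sequence token
    return ids[n_prompt:]


def toy_model(ids):
    # Toy stand-in for the language model: strongly favours (last id + 1) mod 5
    return np.eye(5)[(ids[-1] + 1) % 5] * 10.0
```

With `temperature=0`, `generate(toy_model, [0], max_out_tokens=6, temperature=0)` returns `[1, 2, 3, 4, 0, 1]`: exactly six tokens, one per loop iteration.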


In [1]:
from run_inference import run_inference, init_inference_session



In [2]:
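import os

# Only show PopART log messages of level ERROR and above
os.environ["POPART_LOG_LEVEL"] = "ERROR"
# Sanity check: this variable is expected to be set once the Poplar SDK is enabled
print(os.environ["POPLAR_SDK_ENABLED"])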

/nethome/sofial/workspace/sdks/poplar_sdk-ubuntu_20_04-3.2.0-EA.1+1213-ec6c27ac64/poplar-ubuntu_20_04-3.2.0+6970-37744fc347


## Compile and load the model
Compilation takes around 3 minutes.

In [3]:
# "magma_v1_500" has a maximum sequence length of 500; use "magma_v1_1024" for longer prompts
session, config, tokenizer = init_inference_session("magma_v1_500")

2023-01-10 14:05:19 INFO: Starting. Process id: 60859
2023-01-10 14:05:19 INFO: Config: MagmaConfig(seed=0, visual=ResNetConfig(layers=(6, 8, 18, 8), width=96, image_resolution=384, execution=ResnetExecution(micro_batch_size=1, available_memory_proportion=(1.0, 1.0, 1.0, 1.0)), precision=<Precision.float16: popxl.dtypes.float16>), transformer=GPTJConfig(layers=28, hidden_size=4096, sequence_length=500, precision=<Precision.float16: popxl.dtypes.float16>, embedding=GPTJConfig.Embedding(vocab_size=50400, real_vocab_size=50258), attention=GPTJConfig.Attention(heads=16, rotary_positional_embeddings_base=10000, rotary_dim=64, use_cache=False), execution=GPTJExecution(micro_batch_size=1, available_memory_proportion=(0.45,), tensor_parallel=4, attention_serialisation=1), att_adapter=GPTJConfig.Adapter(layer_norm=False, downsample_factor=8, mode=None), ff_adapter=GPTJConfig.Adapter(layer_norm=False, downsample_factor=4, mode='normal')))
2023-01-10 14:05:19 INFO: Starting PopXL IR construction




Loading GPTJ language model...
loading magma checkpoint from: ./mp_rank_00_model_states.pt
magma successfully loaded
2023-01-10 14:11:16 INFO: Loading magma weights to host duration: 1.50 mins
2023-01-10 14:11:16 INFO: Starting Loading magma pretrained model to IPU
2023-01-10 14:14:51 INFO: Loading magma pretrained model to IPU duration: 3.58 mins


## Run demo

In [4]:
import requests
from io import BytesIO
import ipywidgets as ipw

In [5]:
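def answer_int(image_url, text, seed, top_p, top_k, temperature, max_out_tokens):
    """Run MAGMA on one image/prompt pair and display the image, prompt and answer."""
    # Read the image bytes for display, either from the web or from a local file
    if image_url.startswith("http"):
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        image_bytes = BytesIO(response.content).read()
    else:
        with open(image_url, "rb") as f:
            image_bytes = f.read()
    img = ipw.Image(value=image_bytes, width=384, height=384)
    prompt = ipw.Label(value=f"Prompt: {text}", style={"font_size": "16px"})
    # Generate the continuation of the text prompt, conditioned on the image
    output = run_inference(
        session, config, tokenizer, image_url, text,
        seed, top_p, top_k, temperature, max_out_tokens,
    )
    answer = ipw.Label(value=f"Answer: `{output}`", style={"font_size": "16px"})
    return ipw.VBox(
        [img, prompt, answer], layout=ipw.Layout(display="flex", align_items="center")
    )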


il = ipw.Layout(width="600px")

image_choices = [
    "https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg",
    "demo_example_images/cantaloupe_popsicle.jpg",
    "demo_example_images/circles.jpg",
    "demo_example_images/circles_square.jpg",
    "demo_example_images/korea.jpg",
    "demo_example_images/matterhorn.jpg",
    "demo_example_images/mushroom.jpg",
    "demo_example_images/people.jpg",
    "demo_example_images/playarea.jpg",
    "demo_example_images/popsicle.png",
    "demo_example_images/rainbow_popsicle.jpeg",
    "demo_example_images/table_tennis.jpg",
]
ipw.interact(
    answer_int,
    image_url=ipw.Dropdown(
        options=image_choices,
        value="https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg",
        layout=il,
        continuous_update=False,
    ),
    text=ipw.Text(value="A painting of ", layout=il, continuous_update=False),
    seed=ipw.IntSlider(value=0, min=0, max=300, layout=il, continuous_update=False),
    top_p=ipw.FloatSlider(
        value=0.9, min=0.0, max=1.0, step=0.01, layout=il, continuous_update=False
    ),
    top_k=ipw.IntSlider(value=0, min=0, max=10, layout=il, continuous_update=False),
    temperature=ipw.FloatSlider(
        value=0.7, min=0.0, max=1.0, step=0.01, layout=il, continuous_update=False
    ),
    max_out_tokens=ipw.IntSlider(
        value=6, min=1, max=356, layout=il, continuous_update=False
    ),
);
