[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Kinyugo/odewel/blob/main/example_notebooks/large_language_model_inference.ipynb)

# 🏋️ Large Language Model Inference 💪

This notebook shows how to run inference with `odewel`, which loads model weights on demand so that large models can run on hardware with limited memory.

We demonstrate `odewel` with the Flan-T5-XXL model, using the sharded fp16 checkpoint published by Phil Schmid so that we do not need a device with 80GB of memory to prepare the model ourselves 🥶.

## GPU Check

In [None]:
!nvidia-smi

## Setup

In [None]:
%pip install -q git+https://github.com/Kinyugo/odewel.git
%pip install -q transformers accelerate sentencepiece
%pip install -q joblib

## Imports

In [None]:
import os

import torch
from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer

from odewel.torch import ShardedWeightsLoader, init_on_demand_weights_model

## Fetch Checkpoints 

We will download the checkpoints from Hugging Face (HF).

> 💡 We download the checkpoints up front here, but `odewel` is flexible enough to fetch them at inference time instead. When storage is also limited, each checkpoint can be downloaded just in time for inference rather than kept on disk.

In [None]:
!git lfs install
!git clone https://huggingface.co/philschmid/flan-t5-xxl-sharded-fp16

## Run Inference With On Demand Weights Loading

Now that we have the sharded weights, we can run inference with on-demand weight loading. The maximum amount of memory used at any one time is the size of the largest shard plus the size of the batch. This allows inference on hardware with far less memory than the full model requires, as long as the largest shard and one batch fit.
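
To get a feel for this bound before running anything, the size of the largest shard can be read from the checkpoint directory. The following is a minimal sketch and not part of `odewel`: the helper names `shard_sizes` and `largest_shard` are hypothetical, and it assumes the index file follows the Hugging Face layout, with a `weight_map` entry mapping each weight name to its shard file. On-disk size is only an approximation of the memory a shard occupies once loaded.

```python
import json
import os


def shard_sizes(index_file_path: str, weights_dir: str) -> dict[str, int]:
    """Return the on-disk size in bytes of every shard listed in the index.

    Assumes (not confirmed by odewel) a Hugging Face style index file whose
    "weight_map" entry maps weight names to shard file names.
    """
    with open(index_file_path) as f:
        weight_map = json.load(f)["weight_map"]

    # Many weights share one shard file, so deduplicate the file names.
    shard_files = sorted(set(weight_map.values()))
    return {
        name: os.path.getsize(os.path.join(weights_dir, name))
        for name in shard_files
    }


def largest_shard(index_file_path: str, weights_dir: str) -> tuple[str, int]:
    """Return the (file name, size in bytes) of the largest shard."""
    sizes = shard_sizes(index_file_path, weights_dir)
    name = max(sizes, key=sizes.get)
    return name, sizes[name]
```

For the checkpoint cloned above, this could be called as `largest_shard("flan-t5-xxl-sharded-fp16/pytorch_model.bin.index.json", "flan-t5-xxl-sharded-fp16")`.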

All we need now are two key components:

1. `model_fn`: A function that returns an instance of our model. This function does not need to load weights, as the weights are initialized as empty. This means that our model will not consume any device memory, allowing us to initialize a model with billions of parameters even on an everyday laptop.

    ```python
    def model_fn() -> torch.nn.Module:
        ...
    ```

2. `weights_loader_fn`: A function that takes the module name and a list of weight names as inputs, and returns a mapping from weight name to weight.

    ```python
    def weights_loader_fn(
        module_name: str, weight_names: list[str]
    ) -> dict[str, torch.Tensor]:
        ...
    ```
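
The `weights_loader_fn` contract can be illustrated with a small custom loader. This is a minimal sketch, not how `ShardedWeightsLoader` is implemented: `make_weights_loader_fn` and `load_shard` are hypothetical names, and the sketch assumes that the weight names passed in are the same fully qualified keys that appear in the index file's `weight_map`. Shard loading is injected as a callable so that the lookup logic stays independent of how a shard is read.

```python
import json
from collections import defaultdict
from typing import Any, Callable


def make_weights_loader_fn(
    index_file_path: str,
    load_shard: Callable[[str], dict[str, Any]],
) -> Callable[[str, list[str]], dict[str, Any]]:
    """Build a loader that returns only the requested weights.

    `load_shard` is a hypothetical callable that takes a shard file name
    and returns that shard's full name-to-weight mapping.
    """
    with open(index_file_path) as f:
        weight_map = json.load(f)["weight_map"]

    def weights_loader_fn(module_name: str, weight_names: list[str]) -> dict[str, Any]:
        # `module_name` is unused here; it is kept to match the expected signature.
        # Group the requested weights by shard so each shard is read once per call.
        names_by_shard = defaultdict(list)
        for name in weight_names:
            names_by_shard[weight_map[name]].append(name)

        weights = {}
        for shard_file, names in names_by_shard.items():
            shard = load_shard(shard_file)
            for name in names:
                weights[name] = shard[name]
        return weights

    return weights_loader_fn
```

In practice, `load_shard` could wrap `torch.load(path, map_location="cpu")`, or download the shard before reading it, which is one way to fetch checkpoints just in time.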

In [None]:
HF_REPO_ID = "philschmid/flan-t5-xxl-sharded-fp16"
CKPT_DIR = "flan-t5-xxl-sharded-fp16"

# Prepare the `model_fn`. This is just a function that returns an instance of our model.
# It is used to create an instance of our model that is initialized with empty weights and
# thus doesn't consume device memory
tokenizer = T5Tokenizer.from_pretrained(HF_REPO_ID)
config = T5Config.from_pretrained(HF_REPO_ID)
def model_fn() -> torch.nn.Module:
    return T5ForConditionalGeneration(config)

# Prepare the `weights_loader_fn`. This function performs the actual loading of weights
# as well as making sure the weights have the correct data type and are on the right device.
weights_loader_fn = ShardedWeightsLoader(
    index_file_path=os.path.join(CKPT_DIR, "pytorch_model.bin.index.json"),
    weights_dir=CKPT_DIR,
    weights_mapping_key="weight_map",
    device=torch.device("cuda"),
    dtype=torch.float16,
)

# Initialize our odewel model. `enable_preloading` determines whether any extra weights returned
# by the `weights_loader_fn` that do not belong to the current module/sub-model will be loaded.
odewel_model = init_on_demand_weights_model(
    model_fn, weights_loader_fn, enable_preloading=True
)

In [None]:
def generate(tokenizer, model, prompt):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(input_ids, max_new_tokens=20)

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [None]:
# Sample from odewel model
torch.manual_seed(0)
PROMPT = "translate English to German: How old are you?"
print("odewel Model:", generate(tokenizer, odewel_model, PROMPT))