# End-to-End Model Preparation and Inference 🚀

This notebook demonstrates how to get your model ready for inference with `odewei` and perform efficient inference by loading weights only when they are needed. 🔥


## GPU Check


In [None]:
!nvidia-smi

## Setup

In [None]:
%pip install -q git+https://github.com/Kinyugo/odewei.git
%pip install -q joblib

## Imports

In [None]:
import os

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    T5Config,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

from odewei.torch import ShardedWeightsLoader, init_on_demand_weights_model

## Prepare Checkpoints for On-Demand Loading

🥳 Get ready for limitless inference 

In this step, we'll load an existing model checkpoint and shard it, so that its weights can be loaded on-demand, instead of all at once. 

💡 Keep in mind that the sharding step must run on a device with enough memory for the full model, because the entire model state is loaded. Once the weights are sharded and saved, however, inference can run on any device with at least as much memory as the largest shard.

For simplicity, we'll use Hugging Face's built-in sharding (`save_pretrained` with `max_shard_size`). Note that the shard size determines the minimum device memory that can be supported. An alternative is to shard the model layer by layer, which gives finer control over which layers are pre-loaded together.
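
To make the layer-wise alternative more concrete, here is a minimal sketch of how weight names could be grouped by layer and assigned to one shard file per group. It is not part of `odewei` or `transformers`: the helpers `group_weights_by_layer` and `build_weight_map`, the `depth` parameter, and the `layer-shard-*.bin` file names are all hypothetical, and the weight names are only illustrative T5-style names.

```python
from collections import defaultdict


def group_weights_by_layer(weight_names: list[str], depth: int = 3) -> dict[str, list[str]]:
    """Group weight names by the first `depth` components of their dotted path."""
    groups: dict[str, list[str]] = defaultdict(list)
    for name in weight_names:
        # "encoder.block.0.layer.0.SelfAttention.q.weight" -> "encoder.block.0"
        prefix = ".".join(name.split(".")[:depth])
        groups[prefix].append(name)
    return dict(groups)


def build_weight_map(groups: dict[str, list[str]]) -> dict[str, str]:
    """Assign one (hypothetical) shard file per group and map each weight to it."""
    weight_map: dict[str, str] = {}
    for shard_idx, names in enumerate(groups.values()):
        shard_file = f"layer-shard-{shard_idx:05d}.bin"
        for name in names:
            weight_map[name] = shard_file
    return weight_map


# Illustrative weight names only; a real model has many more.
names = [
    "shared.weight",
    "encoder.block.0.layer.0.SelfAttention.q.weight",
    "encoder.block.0.layer.1.DenseReluDense.wi_0.weight",
    "encoder.block.1.layer.0.SelfAttention.q.weight",
]
groups = group_weights_by_layer(names)
weight_map = build_weight_map(groups)
```

The resulting `weight_map` has the same shape as the `weight_map` entry of a sharded checkpoint index (weight name to shard file), so each group could then be saved to its own file. Choosing a smaller or larger `depth` trades off shard count against shard size.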


In [None]:
HF_REPO_ID = "google/flan-t5-small"
CKPT_DIR = "sharded-flan-t5-small"

# Load model checkpoints with the desired data type
model = AutoModelForSeq2SeqLM.from_pretrained(HF_REPO_ID, torch_dtype=torch.float16)

# Shard the model checkpoint. The shard size determines the minimum GPU memory supported.
# Since this model is tiny we use a very small shard size of 100MB, but in a more realistic
# use case, such as Flan-T5-XXL, we would set something like 2000MB.
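# `safe_serialization=False` keeps the `pytorch_model*.bin` format whose index file is read
# later in this notebook; recent `transformers` releases otherwise default to safetensors.
model.save_pretrained(CKPT_DIR, max_shard_size="100MB", safe_serialization=False)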


# View the saved checkpoints
print(os.listdir(CKPT_DIR))
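
Since the largest shard sets the minimum device memory, it can be useful to check the shard sizes after saving. Below is a minimal sketch, assuming the checkpoint directory contains an index JSON file with a `weight_map` entry that maps weight names to shard file names. The helper `shard_sizes` is hypothetical, and on-disk size is only an approximation of the memory a shard needs once loaded.

```python
import json
import os


def shard_sizes(ckpt_dir: str, index_file_name: str) -> dict[str, int]:
    """Return the on-disk size in bytes of every shard file listed in the index."""
    with open(os.path.join(ckpt_dir, index_file_name)) as f:
        weight_map = json.load(f)["weight_map"]
    # Many weights share a shard, so de-duplicate the file names first.
    shard_files = sorted(set(weight_map.values()))
    return {
        name: os.path.getsize(os.path.join(ckpt_dir, name)) for name in shard_files
    }
```

With the checkpoint saved above, something like `max(shard_sizes(CKPT_DIR, "pytorch_model.bin.index.json").values())` would give the size of the largest shard in bytes, assuming the index file has that name.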

## Run Inference With On-Demand Weight Loading

Now that we have sharded the weights, we can run inference with on-demand weight loading. The maximum amount of memory used at any one time is roughly the size of the largest shard plus the size of the batch. This allows inference on virtually any hardware, regardless of the model size.
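
As a rough back-of-the-envelope check of that claim, the sketch below compares the memory needed to hold a full model with the estimated peak under on-demand loading. The helper names are hypothetical, the parameter count (about 11 billion for Flan-T5-XXL) is approximate, and the 500MB batch size is an assumed figure purely for illustration.

```python
def full_model_bytes(num_params: int, bytes_per_param: int) -> int:
    """Memory needed to hold every weight at once."""
    return num_params * bytes_per_param


def estimate_peak_bytes(largest_shard_bytes: int, batch_bytes: int) -> int:
    """Rough peak under on-demand loading: largest shard plus the batch."""
    return largest_shard_bytes + batch_bytes


# Assumptions: ~11e9 parameters in float16 (2 bytes each), 2000MB shards,
# and a hypothetical 500MB for the batch.
full = full_model_bytes(11_000_000_000, 2)
peak = estimate_peak_bytes(2_000_000_000, 500_000_000)
print(f"full model: {full / 1e9:.1f} GB, on-demand peak: {peak / 1e9:.1f} GB")
```

Under these assumptions the full model needs about 22 GB, while the on-demand estimate stays around 2.5 GB.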

All we need now are two key components:

1. `model_fn`: A function that returns an instance of our model. This function does not need to load weights, as the weights are initialized as empty. This means that our model will not consume any device memory, allowing us to initialize a model with billions of parameters even on an everyday laptop.

    ```python
    def model_fn() -> torch.nn.Module:
        ...
    ```

2. `weights_loader_fn`: A function that takes the module name and a list of weight names as inputs, and returns a mapping from weight name to weight.

    ```python
    def weights_loader_fn(
        module_name: str, weight_names: list[str]
    ) -> dict[str, torch.Tensor]:
        ...
    ```
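
To illustrate the `weights_loader_fn` contract, here is a minimal sketch of a loader built on top of a sharded checkpoint index. It is not `odewei`'s actual `ShardedWeightsLoader`: the factory `make_weights_loader_fn` and the `load_shard` callable are hypothetical, and the sketch assumes that `weight_names` are the fully qualified names used as keys in the index's `weight_map`.

```python
import json
from typing import Any, Callable


def make_weights_loader_fn(
    index_file_path: str,
    load_shard: Callable[[str], dict[str, Any]],
) -> Callable[[str, list[str]], dict[str, Any]]:
    """Build a loader that opens only the shards holding the requested weights."""
    with open(index_file_path) as f:
        weight_map: dict[str, str] = json.load(f)["weight_map"]

    def weights_loader_fn(module_name: str, weight_names: list[str]) -> dict[str, Any]:
        # `module_name` is unused in this sketch; only the weight names matter.
        shard_files = sorted({weight_map[name] for name in weight_names})
        weights: dict[str, Any] = {}
        for shard_file in shard_files:
            shard = load_shard(shard_file)  # mapping of weight name -> weight
            for name in weight_names:
                if weight_map[name] == shard_file:
                    weights[name] = shard[name]
        return weights

    return weights_loader_fn
```

In practice, `load_shard` could wrap something like `torch.load` on the shard's path, followed by moving the tensors to the target device and dtype. Keeping it as a parameter means the shard lookup logic can be exercised without any real checkpoint files.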

In [None]:
# Prepare the `model_fn`. This is just a function that returns an instance of our model.
# It is used to create an instance of our model that is initialized with empty weights and
# thus doesn't consume device memory
tokenizer = T5Tokenizer.from_pretrained(HF_REPO_ID)
config = T5Config.from_pretrained(HF_REPO_ID)
model_fn = lambda: T5ForConditionalGeneration(config)

# Prepare the `weights_loader_fn`. This function performs the actual loading of weights
# as well as making sure the weights have the correct data type and are on the right device.
weights_loader_fn = ShardedWeightsLoader(
    index_file_path=os.path.join(CKPT_DIR, "pytorch_model.bin.index.json"),
    weights_dir=CKPT_DIR,
    weights_mapping_key="weight_map",
    device=torch.device("cuda"),
    dtype=torch.float16,
)

# Initialize our odewei model. `enable_preloading` determines whether any extra weights returned
# by the `weights_loader_fn` that do not belong to the current module/sub-model will be loaded.
odewei_model = init_on_demand_weights_model(
    model_fn, weights_loader_fn, enable_preloading=True
)

In [None]:
PROMPT = "translate English to German: How old are you?"


def generate(tokenizer, model, prompt):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(input_ids, max_new_tokens=20)

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [None]:
# Sample from the regular PyTorch model
torch.manual_seed(0)
print("Regular PyTorch Model:", generate(tokenizer, model, PROMPT))

In [None]:
# Sample from odewei model
torch.manual_seed(0)
print("odewei Model:", generate(tokenizer, odewei_model, PROMPT))