In [1]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, OPTForCausalLM, MistralForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from accelerate.hooks import remove_hook_from_module
from accelerate.utils import named_module_tensors, find_tied_parameters


import numpy as np
from numpy.lib.format import open_memmap

import os
import sys
import json
from copy import deepcopy

from threading import Thread
from queue import Queue 

import functools 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
checkpoint = 'facebook/opt-125m'
# checkpoint = 'facebook/opt-13B'
# checkpoint = 'mistralai/Mistral-7B-v0.1'

comp_device = 0
torch_dtype = torch.float16
weights_offload_dir = f'_weights_offload/{checkpoint}/{torch_dtype}'

In [3]:

def find_module_list(module: nn.Module) -> tuple[nn.Module, str]:
    def _find_module_list(module: nn.Module, prefix=''):
        if isinstance(module, nn.ModuleList):
            yield module, prefix
        else:
            for name, child in module.named_children():
                yield from _find_module_list(child, prefix=prefix+'.'+name if prefix else name)
    
    g = _find_module_list(module)
    try:
        return next(iter(g))
    except:
        raise ValueError(f'{module.__class__.__name__} does not have a nn.ModuleList structure')


In [4]:
"""
1. get model parameter & buffer names
2. find the transformer block module
3. get a device map
4. get offloaded weights np.memmap files
"""
class ModelPrepare:
    def __init__(self, **kwargs) -> None:
        self.checkpoint = kwargs.get('checkpoint')
        self.torch_dtype = kwargs.get('torch_dtype')
        self.comp_device = kwargs.get('comp_device')
        self.weights_offload_dir = kwargs.get('weights_offload_dir') 

        self.empty_model = self.get_empty_model()
        self.layers, self.layers_name = self.parse_model_architecture()
        self.device_map = self.get_device_map()
        self.prepare_weights_memmap()

        self.model = self.init_model_weights()

    def get_empty_model(self):
        config = AutoConfig.from_pretrained(checkpoint)
        with init_empty_weights(): 
            e = AutoModelForCausalLM.from_config(config,)
        # don't run e.tie_weights() or the tied weights will not be in the device map
        # e.tie_weights()            
        return e

    def parse_model_architecture(self):
        layers_module, layers_name = find_module_list(self.empty_model)
        return layers_module, layers_name

    def get_device_map(self):
        """
        give the found transformer block list, set it to the `meta` or `disk` device; 
        send the device map to AutoModelForCausalLM.from_pretrained() and set the weights_offload_dir, the code from huggingface will automatically prepare the np.memmap files in the offload folder
        """
        res = {}
        for n, t in named_module_tensors(self.empty_model, recurse=True):
            if isinstance(t, nn.Parameter) and self.layers_name in n:
                res[n] = 'disk'
            else:
                res[n] = self.comp_device
        return res

    def prepare_weights_memmap(self):
        """init all nn.Parameter in model's transformer blocks to meta device , and others to compute device. (based on the device map)"""
        # all parameters of the model will be offloaded as memory-mapped array in a given folder.
        if not os.path.exists(self.weights_offload_dir):
            try:
                AutoModelForCausalLM.from_pretrained(
                    self.checkpoint, 
                    device_map={'':'disk'},  
                    torch_dtype=self.torch_dtype, 
                    offload_folder=self.weights_offload_dir, 
                    use_safetensors=False # use pytorch *.bin, as accelerate disk_offload have some bugs for safetensors
                )
            except:
                pass 
        
    def init_model_weights(self):
        model = AutoModelForCausalLM.from_pretrained(
            self.checkpoint, 
            device_map={k:v if v != 'disk' else 'meta' for k, v in self.device_map.items()}, # use 'meta' for no behavior 
            torch_dtype=self.torch_dtype, 
            offload_folder=None, 
            use_safetensors=False 
        )

        # remove accelerate disk_offload hooks (if has)
        model = remove_hook_from_module(model, recurse=True) 
        return model

mp = ModelPrepare(
    checkpoint=checkpoint,
    comp_device=comp_device,
    torch_dtype=torch_dtype, 
    weights_offload_dir=weights_offload_dir
)
model = mp.model




In [5]:
class DiskWeightsLoader:
    def __init__(self, weights_offload_dir) -> None:
        self.weights_offload_folder = weights_offload_dir

        with open(os.path.join(weights_offload_dir, "index.json"), "r") as f: 
            self.index = json.load(f)  

    def open_memmap(self, key: str) -> np.memmap:
        metadata = self.index[key]

        f_name = os.path.join(weights_offload_dir, key + '.dat')

        shape = tuple(metadata["shape"])
        if shape == ():
            # NumPy memory-mapped arrays can't have 0 dims so it was saved as 1d tensor
            shape = (1,)

        dtype = metadata["dtype"]
        if dtype == "bfloat16":
            # NumPy does not support bfloat16 so this was saved as a int16
            dtype = "int16"

        weight = np.memmap(f_name, dtype=dtype, shape=shape, mode="r") # no data movement

        if len(metadata["shape"]) == 0:
            weight = weight[0]

        # weight = torch.from_numpy(weight) # no data movement

        if metadata["dtype"] == "bfloat16":
            weight = weight.view(torch.bfloat16)

        return weight
    
dl = DiskWeightsLoader(weights_offload_dir)
mmap = dl.open_memmap(key="model.decoder.layers.0.fc1.bias")
d_tensor = torch.from_numpy(mmap)

g_tensor = torch.zeros(*mmap.shape, device = 0, dtype = d_tensor.dtype, pin_memory=False)

mmap, g_tensor.copy_(d_tensor) # d2g

  d_tensor = torch.from_numpy(mmap)


(memmap([-0.01394 , -0.002508, -0.01517 , ..., -0.005646, -0.01177 ,
         -0.003656], dtype=float16),
 tensor([-0.0139, -0.0025, -0.0152,  ..., -0.0056, -0.0118, -0.0037],
        device='cuda:0', dtype=torch.float16))

In [6]:
class Policy:
    def __init__(self, **kwargs):
        self.kwargs = kwargs 
        self.x, self.y, self.z = self.get_vars(['x', 'y', 'z'])
        self.g, self.c, self.d = self.get_vars(['g', 'c', 'd'])

    def get_vars(self, vars: list[str]):
        values = [self.kwargs.get(var) for var in vars]
        assert all(val is None or 0 <= val <= 1 for val in values) 
        assert len([val for val in values if val is None]) <= 1 or (1 in values)
        assert sum([val for val in values if val is not None]) <= 1
        
        for i, val in enumerate(values):
            if val is None:
                values[i] = 1 - sum([_val for _val in values if _val is not None])
        return values 
    
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(x, y, z, g, c, d) = {self.x, self.y, self.z, self.g, self.c, self.d}'
    
policy = Policy(g=0.5, c=0.5, x=1, )
policy

Policy(x, y, z, g, c, d) = (1, 0, 0, 0.5, 0.5, 0.0)

In [14]:
d = {1: 'a'}
d.get(2, 'b')

'b'

In [None]:
class Home:
    """ 
    home of w/x/y (weights & caches) on g/c/d as list[Vector]
    
    memory view:

    n_layers * 
    + - - - - - - - - - +   + - - - - - - - - - +   + - - - - - - - - - +
    |     W(GPU) arrs   |   |    X(GPU) vecs    |   |    Y(GPU) vecs    |
    + - - - - - - - - - +   + - - - - - - - - - +   + - - - - - - - - - +
    |     W(CPU) arrs   |   |    X(CPU) vecs    |   |    Y(CPU) vecs    |
    + - - - - - - - - - +   + - - - - - - - - - +   + - - - - - - - - - +
    |     W(Disk) arrs  |   |    X(Disk) vecs   |   |    Y(Disk) vecs   |
    + - - - - - - - - - +   + - - - - - - - - - +   + - - - - - - - - - +
    
    where an arr(array) is a vec(vector) with a fixed length in the dim of `s'.
    More importantly, each arr | vec is with a fixed chunk_size in the dim of `h', 
    a bunch of arrs | vecs comprises the whole memory view.
    We aggregate n layers' w/x/y in same g/c/d (arr | vec) storages.

    the gpu arr | vec is based on cuda tensor
    the cpu arr | vec is based on pinned memory
    the disk arr | vec is based on np.memmap
    """
    def __init__(self, **kwargs):
        self.policy: Policy = kwargs.get('policy') 
        self.layers: nn.ModuleList = kwargs.get('layers')

        # chunksize in h dim
        self.chunk_size = kwargs.get("chunk_size", 64) # default to 64
    
    def f(self):
        ...


In [7]:
class RuntimeBuffers:
    """
    home buffers and layer running buffers
    for weights and caches
    and policy: x, y, z, g, c ,d
    """
    def __init__(self, **kwargs):
        self.layers: nn.ModuleList = kwargs.get('layers') 
        self.policy: Policy = kwargs.get('policy') 
        self.comp_device = kwargs.get('comp_device') 

        self.home = {
            'weights': ...,
            'kv_cache': None,
            'actv_cache': None,
        } 
        self.running = {
            'weights': self.init_layer_weights_buffer(),
            'kv_cache': self.init_layer_kv_cache_buffer(),
            'actv_cache': self.init_layer_actv_cache_buffer(),
        }

    def init_home_weights_buffer(self):
        ... 
    def init_home_kv_cache_buffer(self):
        ... 
    def init_home_actv_cache_buffer(self):
        ... 

    def init_layer_weights_buffer(self):
        one_layer = self.layers[0]
        buff = {
            n: torch.zeros(*t.shape, dtype=t.dtype, device=self.comp_device) 
            for n, t in named_module_tensors(one_layer, recurse=True) 
            if isinstance(t, nn.Parameter)
        }
        return buff

    def init_layer_kv_cache_buffer(self):
        return 
    
    def init_layer_actv_cache_buffer(self):
        return 


In [8]:

import torch
from accelerate.utils import honor_type
from typing import Mapping

def get_info(obj, debug=False):
    if isinstance(obj, (tuple, list)):
        ret = honor_type(obj, (get_info(o) for o in obj))
        if len(set(ret)) == 1 and len(ret) > 1:
            return f"{len(ret)} * {ret[0]}"
        else:
            return ret 
    elif isinstance(obj, Mapping):
        return type(obj)({k: get_info(v) for k, v in obj.items()})
    elif isinstance(obj, (torch.Tensor)):
        if debug:
            return f"{obj.__class__.__name__}(shape={tuple(obj.size())}, dtype={obj.dtype}, device={obj.device}, mem/elem/dtype={sys.getsizeof(obj.storage()) / obj.numel() / obj.element_size():.3f})"
        else:
            return f"{obj.__class__.__name__}(shape={tuple(obj.size())}, mem/elem/dtype={sys.getsizeof(obj.storage()) / obj.numel() / obj.element_size():.3f})"
    elif isinstance(obj, (int, bool, type(None))):
        return f"{obj}"
    else:
        return f"{obj.__class__.__name__}: {obj}"

from data_movement import Engine, Task
from vector import Vector

def kv_cache_kwarg_name(hf_model):
    if isinstance(hf_model, OPTForCausalLM | MistralForCausalLM):
        return 'past_key_value'
    else:
        raise NotImplementedError() 

class Buffer:
    def __init__(self, buff, loaded_flag=False):
        self.buff = buff 
        self.loaded_flag = loaded_flag

    def __getitem__(self, key):
        return self.buff[key]
    
    def __setitem__(self, key, value):
        self.buff[key] = value

class Model:
    def __init__(self, **kwargs) -> None:
        # ModelPrepare 
        self.checkpoint = kwargs.get('checkpoint')
        self.torch_dtype = kwargs.get('torch_dtype')
        self.comp_device = kwargs.get('comp_device')
        self.weights_offload_dir = kwargs.get('weights_offload_dir') 
        self.mp = ModelPrepare(**kwargs)

        self.device_map = self.mp.device_map 
        self.hf_model = self.mp.model  
        self.layers, self.layers_name = find_module_list(self.hf_model)  
        self.kv_cache_kwarg_name = kv_cache_kwarg_name(self.hf_model)
        self.weight_keys = self.device_map.keys() 

        # Data Movement 
        self.disk_weight_loader = DiskWeightsLoader(self.weights_offload_dir)
        self.dm_engine = Engine(self.comp_device)
        self.dm_engine.start()

        # flexgen
        self.policy = kwargs.get('policy')
        self.m = kwargs.get('m') # minibatches
        self.max_gmem = kwargs.get('max_gmem') 
        self.max_cmem = kwargs.get('max_cmem') 
        self.max_dmem = kwargs.get('max_dmem') 

        # w/x/y/z home & layer running buffers
        self.w_home = ... # g/c/d
        self.x_home = ... # g/c/d
        self.y_home = ... # g/c/d

        # weights
        self.w_buff_curr = Buffer(
            {
                n: torch.zeros(*t.shape, dtype=t.dtype, device=self.comp_device) 
                for n, t in named_module_tensors(self.layers[0], recurse=True) 
                if isinstance(t, nn.Parameter)
            }, 
            loaded_flag=False
        )
        self.w_buff_next = deepcopy(self.w_buff_curr)

        # kv cache: 2x (b, s, h_kv)
        self.x_buff_curr = Buffer(None, False) # Vector
        self.x_buff_next = Buffer(None, False)

        # actv: 1x (b, s, h_a)
        self.y_buff_curr = Buffer(None, False)
        self.y_buff_next = Buffer(None, False)

    def init_w_home(self):
        policy = self.policy 
        g, c, d = policy.g, policy.c, policy.d 
        
        names = [n for n, t in self.device_map.item() if t.device in ['meta', torch.device('meta')]]
        for name in names:
            mmap = self.disk_weight_loader.open_memmap(name)

            # dm_engine: vector push | pop


    @property
    def layer_weight_names(self):
        return [
            n for n, t in named_module_tensors(self.layers[0], recurse=True) 
            if isinstance(t, nn.Parameter)
        ]

    def override_layer_forward(self, i: int):
        layer = self.layers[i]
        old_forward = layer.forward

        # reference layer weights to layer running buffer
        def set_reference(module: nn.Module, name: str, reference: torch.Tensor | nn.Parameter):
            splits = name.split(".")
            for s in splits[:-1]:
                module = getattr(module, s)
                
            if isinstance(reference, torch.Tensor):
                reference = nn.Parameter(reference)
            
            setattr(module, splits[-1], reference)    
            # print(name, getattr(module, splits[-1]).device)

        for name in self.layer_weight_names:
            set_reference(layer, name, self.w_buff_curr[name])

        @functools.wraps(old_forward)
        def new_forward(*args, **kwargs):
            print(f'\t{i = }, {get_info(args) = }, \n\t{i = }, {get_info(kwargs) = }')

            # load 1 / ngb of next layer
            #   copy from home buffers at g/c/d to layer running buffer

            if self.kv_cache_kwarg_name not in kwargs:
                ## PREFILL PHASE 
                # 1. offload kv & actv cache of prev batch
                ...
            else:
                ## DECODING PHASE
                # 1. load kv & actv caches of next batch 

                # 2. offload kv | actv caches of prev batch 

                # 3. compute curr batch
                ...

            if isinstance(self.hf_model, (OPTForCausalLM, )):
                actv_recomp = args[0] # b,1,h / bzh
                kv_cache = kwargs.get('past_key_value') # b,n_kv_heads,s_cache,h_kv    x2
                attn_mask = kwargs.get('attention_mask') # b,1,1,s_all  (bsz, 1, tgt_len, src_len)

            # prepare args, kwargs for hf api's
            args_for_old = args
            kwargs_for_old = kwargs

            # hf execution
            old_output = old_forward(*args_for_old, **kwargs_for_old) # h'=(b,z,h), kv=(b,n,s_all,h) x2
            
            # prepare our output from hf's output
            output = old_output
            print(f'\t{i = }, {get_info(output) = }\n')

            # swap: curr buff & next buff (for batch & layer levels)
            
            return output
        
        layer.forward = new_forward
        return layer

    def override_hf_model_forward(self):
        old_forward = self.hf_model.forward
        @functools.wraps(old_forward)
        def new_forward(*args, **kwargs):
            print(f'hf_model {get_info(args) = }, \nhf_model {get_info(kwargs) = }\n')

            # new to hf: args, kwargs
            args_for_old = args
            kwargs_for_old = kwargs

            # hf execution
            old_output = old_forward(*args_for_old, **kwargs_for_old) 

            # hf to new: output
            output = old_output 
            print(f'hf_model {get_info(output) = }\n')
            
            return output
        
        self.hf_model.forward = new_forward
        return self.hf_model

    def override_forward_functions(self):
        for i, _ in enumerate(self.layers):
            self.override_layer_forward(i)
        self.override_hf_model_forward()
        return self.hf_model 



In [9]:
num_prompts = 16
prompts = None
prompt_len = 50
comp_device = 0
gen_len = 20


# hf_model= OPTForCausalLM.from_pretrained(checkpoint)
model = Model(
    checkpoint=checkpoint,
    comp_device=comp_device,
    torch_dtype=torch_dtype, 
    weights_offload_dir=weights_offload_dir
).override_forward_functions()




In [10]:
# [(n, t.device) for n, t in named_module_tensors(model, recurse=True)]

In [11]:

# test
if True:
    if prompts is None:  # get default prompts
        prompts = [
            "for i in range(10): ",
            "Who are you? Are you conscious?",
            "Where is Deutschland?",
            "How is Huawei Mate 60 Pro?",
        ]
    prompts = (
        prompts * (num_prompts // len(prompts))
        + prompts[: (num_prompts % len(prompts))]
    )

    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(checkpoint) # , padding_side="left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # eos padding

    # inputs
    inputs = tokenizer(
        prompts,
        padding="max_length",
        max_length=prompt_len,
        return_tensors="pt",
        # padding=True,
    ).to(comp_device)

    # generate
    generate_ids = model.generate(
        inputs.input_ids,
        max_new_tokens=gen_len,  # max_lengths
        
        # num_beams=6, 

        # num_beam_groups=2, 
        # diversity_penalty=0.1, 
        
        # do_sample=True, 
    )

    # outputs
    output_texts = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_texts)

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  return f"{obj.__class__.__name__}(shape={tuple(obj.size())}, mem/elem/dtype={sys.getsizeof(obj.storage()) / obj.numel() / obj.element_size():.3f})"


hf_model get_info(args) = (), 
hf_model get_info(kwargs) = {'input_ids': 'Tensor(shape=(16, 50), mem/elem/dtype=1.008)', 'past_key_values': 'None', 'use_cache': 'True', 'attention_mask': 'Tensor(shape=(16, 50), mem/elem/dtype=1.008)', 'return_dict': 'True', 'output_attentions': 'False', 'output_hidden_states': 'False'}

	i = 0, get_info(args) = ('Tensor(shape=(16, 50, 768), mem/elem/dtype=1.000)',), 
	i = 0, get_info(kwargs) = {'attention_mask': 'Tensor(shape=(16, 1, 50, 50), mem/elem/dtype=1.001)', 'layer_head_mask': 'None', 'past_key_value': 'None', 'output_attentions': 'False', 'use_cache': 'True'}
	i = 0, get_info(output) = ('Tensor(shape=(16, 50, 768), mem/elem/dtype=1.000)', '2 * Tensor(shape=(16, 12, 50, 64), mem/elem/dtype=1.000)')

	i = 1, get_info(args) = ('Tensor(shape=(16, 50, 768), mem/elem/dtype=1.000)',), 
	i = 1, get_info(kwargs) = {'attention_mask': 'Tensor(shape=(16, 1, 50, 50), mem/elem/dtype=1.001)', 'layer_head_mask': 'None', 'past_key_value': 'None', 'output_atten