In [12]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, OPTForCausalLM, MistralForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from accelerate.utils import named_module_tensors, find_tied_parameters

import numpy as np
from numpy.lib.format import open_memmap

import sys
from threading import Thread
from queue import Queue 

import functools 

In [16]:
checkpoint = 'facebook/opt-125m'
config = AutoConfig.from_pretrained(checkpoint)
with init_empty_weights(): 
    e = AutoModelForCausalLM.from_config(config,)
e.tie_weights()
find_tied_parameters(e)

[['lm_head.weight', 'model.decoder.embed_tokens.weight']]

In [21]:
checkpoint = 'mistralai/Mistral-7B-v0.1'
mis = AutoModelForCausalLM.from_pretrained(checkpoint, device_map={'': "meta"})
print(mis)
print(set(t.device for n, t in named_module_tensors(mis, recurse=True)))
print(list(n for n, t in named_module_tensors(mis, recurse=True)))

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.33it/s]


MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm(

In [24]:
def find_module_list(module: nn.Module):
    def _find_module_list(module: nn.Module, prefix=''):
        if isinstance(module, nn.ModuleList):
            yield module, prefix
        else:
            for name, child in module.named_children():
                yield from _find_module_list(child, prefix=prefix+'.'+name if prefix else name)
    
    g = _find_module_list(module)
    try:
        return next(iter(g))
    except:
        raise ValueError(f'{module.__class__.__name__} does not have a nn.ModuleList structure')

find_module_list(mis), find_module_list(e)

((ModuleList(
    (0-31): 32 x MistralDecoderLayer(
      (self_attn): MistralSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): MistralRotaryEmbedding()
      )
      (mlp): MistralMLP(
        (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): MistralRMSNorm()
      (post_attention_layernorm): MistralRMSNorm()
    )
  ),
  'model.layers'),
 (ModuleList(
    (0-11): 12 x OPTDecoderLayer(
      (self_attn): OPTAttention(
        (k_proj): Linear(in_features=768, out_features=768, bi

In [3]:

import torch
from accelerate.utils import honor_type
from typing import Mapping

def get_info(obj, debug=False):
    if isinstance(obj, (tuple, list)):
        ret = honor_type(obj, (get_info(o) for o in obj))
        if len(set(ret)) == 1 and len(ret) > 1:
            return f"{len(ret)} * {ret[0]}"
        else:
            return ret 
    elif isinstance(obj, Mapping):
        return type(obj)({k: get_info(v) for k, v in obj.items()})
    elif isinstance(obj, (torch.Tensor)):
        if debug:
            return f"{obj.__class__.__name__}(shape={tuple(obj.size())}, dtype={obj.dtype}, device={obj.device}, mem/elem/dtype={sys.getsizeof(obj.storage()) / obj.numel() / obj.element_size():.3f})"
        else:
            return f"{obj.__class__.__name__}(shape={tuple(obj.size())}, mem/elem/dtype={sys.getsizeof(obj.storage()) / obj.numel() / obj.element_size():.3f})"
    elif isinstance(obj, (int, bool, type(None))):
        return f"{obj}"
    else:
        return f"{obj.__class__.__name__}: {obj}"

from data_movement import Engine, Task

class Model:
    """
    1. override forward functions
    """
    def __init__(self, hf_model, comp_device=0, **kwargs) -> None:
        self.checkpoint = kwargs.get('checkpoint')
        self.torch_dtype = kwargs.get('torch_dtype')
        self.config = AutoConfig.from_pretrained(self.checkpoint, torch_dtype=self.torch_dtype)
        with init_empty_weights(): # while buffers are not empty
            self.hf_model = AutoModelForCausalLM.from_config(self.config, torch_dtype=self.torch_dtype)
        self.layers, self.layers_name = self.get_layers()
        
        self.comp_device = comp_device

        self.dm_engine = Engine(self.comp_device)

        # init model 
        self.hf_model = hf_model.to(comp_device)


    def get_layers(self) -> tuple[nn.Module, str]:
        if isinstance(self.hf_model, (OPTForCausalLM, )):
            return self.hf_model.model.decoder.layers, 'model.decoder.layers'
        else:
            def find_module_list(module: nn.Module) -> tuple[nn.Module, str]:
                def _find_module_list(module: nn.Module, prefix=''):
                    if isinstance(module, nn.ModuleList):
                        yield module, prefix
                    else:
                        for name, child in module.named_children():
                            yield from _find_module_list(child, prefix=prefix+'.'+name if prefix else name)
                
                g = _find_module_list(module)
                try:
                    return next(iter(g))
                except:
                    raise ValueError(f'{module.__class__.__name__} does not have a nn.ModuleList structure')

            return find_module_list(self.hf_model)
    
    def override_layer_forward(self, i: int):
        layer = self.layers[i]
        old_forward = layer.forward

        @functools.wraps(old_forward)
        def new_forward(*args, **kwargs):
            print(f'\t{i = }, {get_info(args) = }, \n\t{i = }, {get_info(kwargs) = }')

            if isinstance(self.hf_model, (OPTForCausalLM, )):
                actv_recomp = args[0] # b,1,h / bzh
                kv_cache = kwargs.get('past_key_value') # b,n_kv_heads,s_cache,h_kv    x2
                attn_mask = kwargs.get('attention_mask') # b,1,1,s_all  (bsz, 1, tgt_len, src_len)

            # new to hf: args, kwargs
            args_for_old = args
            kwargs_for_old = kwargs

            # hf execution
            old_output = old_forward(*args_for_old, **kwargs_for_old) # h'=(b,z,h), kv=(b,n,s_all,h) x2
            
            # hf to new: output
            output = old_output
            print(f'\t{i = }, {get_info(output) = }\n')
            
            return output
        
        layer.forward = new_forward
        return layer

    def override_hf_model_forward(self):
        old_forward = self.hf_model.forward
        @functools.wraps(old_forward)
        def new_forward(*args, **kwargs):
            print(f'hf_model {get_info(args) = }, \nhf_model {get_info(kwargs) = }\n')

            # new to hf: args, kwargs
            args_for_old = args
            kwargs_for_old = kwargs

            # hf execution
            old_output = old_forward(*args_for_old, **kwargs_for_old) 

            # hf to new: output
            output = old_output 
            print(f'hf_model {get_info(output) = }\n')
            
            return output
        
        self.hf_model.forward = new_forward
        return self.hf_model

    def build(self):
        for i, _ in enumerate(self.layers):
            self.override_layer_forward(i)
        self.override_hf_model_forward()
        return self.hf_model 



In [4]:
num_prompts = 16
prompts = None
checkpoint = 'facebook/opt-125m'
prompt_len = 50
comp_device = 0
gen_len = 20


hf_model= OPTForCausalLM.from_pretrained(checkpoint)
model = Model(hf_model, comp_device=comp_device).build()
model = model.to(comp_device) # 

# test
if True:
    if prompts is None:  # get default prompts
        prompts = [
            "for i in range(10): ",
            "Who are you? Are you conscious?",
            "Where is Deutschland?",
            "How is Huawei Mate 60 Pro?",
        ]
    prompts = (
        prompts * (num_prompts // len(prompts))
        + prompts[: (num_prompts % len(prompts))]
    )

    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(checkpoint) # , padding_side="left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # eos padding

    # inputs
    inputs = tokenizer(
        prompts,
        padding="max_length",
        max_length=prompt_len,
        return_tensors="pt",
        # padding=True,
    ).to(comp_device)

    # generate
    generate_ids = model.generate(
        inputs.input_ids,
        max_new_tokens=gen_len,  # max_lengths
        
        num_beams=6, #
        num_beam_groups=2, #
        diversity_penalty=0.1, #
        # do_sample=True, #
    )

    # outputs
    output_texts = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_texts)

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  return f"{obj.__class__.__name__}(shape={tuple(obj.size())}, mem/elem/dtype={sys.getsizeof(obj.storage()) / obj.numel() / obj.element_size():.3f})"


hf_model get_info(args) = (), 
hf_model get_info(kwargs) = {'input_ids': 'Tensor(shape=(96, 50), mem/elem/dtype=1.001)', 'past_key_values': 'None', 'use_cache': 'True', 'attention_mask': 'Tensor(shape=(96, 50), mem/elem/dtype=1.001)', 'return_dict': 'True', 'output_attentions': 'False', 'output_hidden_states': 'False'}

	i = 0, get_info(args) = ('Tensor(shape=(96, 50, 768), mem/elem/dtype=1.000)',), 
	i = 0, get_info(kwargs) = {'attention_mask': 'Tensor(shape=(96, 1, 50, 50), mem/elem/dtype=1.000)', 'layer_head_mask': 'None', 'past_key_value': 'None', 'output_attentions': 'False', 'use_cache': 'True'}
	i = 0, get_info(output) = ('Tensor(shape=(96, 50, 768), mem/elem/dtype=1.000)', '2 * Tensor(shape=(96, 12, 50, 64), mem/elem/dtype=1.000)')

	i = 1, get_info(args) = ('Tensor(shape=(96, 50, 768), mem/elem/dtype=1.000)',), 
	i = 1, get_info(kwargs) = {'attention_mask': 'Tensor(shape=(96, 1, 50, 50), mem/elem/dtype=1.000)', 'layer_head_mask': 'None', 'past_key_value': 'None', 'output_atten