In [1]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, OPTForCausalLM, MistralForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from accelerate.utils import named_module_tensors, find_tied_parameters

import numpy as np
from numpy.lib.format import open_memmap

import sys
from threading import Thread
from queue import Queue 

import functools 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
checkpoint = 'facebook/opt-125m'
config = AutoConfig.from_pretrained(checkpoint)

# Under init_empty_weights the parameters are created on the meta device, so this OPT
# skeleton holds no real weights.
with init_empty_weights():
    e: OPTForCausalLM = AutoModelForCausalLM.from_config(
        config,
    )

# Re-establish weight tying, then report which parameter names share storage.
e.tie_weights()
find_tied_parameters(e)



[['lm_head.weight', 'model.decoder.embed_tokens.weight']]

In [33]:
# Hand-built device_map for from_pretrained: parameters inside the decoder layers stay on
# 'meta'; every other tensor (embeddings, final layer norm, any decoder-layer buffers) goes
# to comp_device.
comp_device = 0
res = {}
for n, t in named_module_tensors(e, recurse=True):
    if isinstance(t, nn.Parameter) and 'model.decoder.layers' in n:
        res[n] = 'meta'
    else:
        res[n] = comp_device
# lm_head is added by module name (presumably because its weight is tied to
# model.decoder.embed_tokens.weight and is not listed separately). Use comp_device rather
# than a literal 0 so it always lands with the embedding it shares a weight with.
res['lm_head'] = comp_device
res

{'model.decoder.embed_tokens.weight': 0,
 'model.decoder.embed_positions.weight': 0,
 'model.decoder.final_layer_norm.weight': 0,
 'model.decoder.final_layer_norm.bias': 0,
 'model.decoder.layers.0.self_attn.k_proj.weight': 'meta',
 'model.decoder.layers.0.self_attn.k_proj.bias': 'meta',
 'model.decoder.layers.0.self_attn.v_proj.weight': 'meta',
 'model.decoder.layers.0.self_attn.v_proj.bias': 'meta',
 'model.decoder.layers.0.self_attn.q_proj.weight': 'meta',
 'model.decoder.layers.0.self_attn.q_proj.bias': 'meta',
 'model.decoder.layers.0.self_attn.out_proj.weight': 'meta',
 'model.decoder.layers.0.self_attn.out_proj.bias': 'meta',
 'model.decoder.layers.0.self_attn_layer_norm.weight': 'meta',
 'model.decoder.layers.0.self_attn_layer_norm.bias': 'meta',
 'model.decoder.layers.0.fc1.weight': 'meta',
 'model.decoder.layers.0.fc1.bias': 'meta',
 'model.decoder.layers.0.fc2.weight': 'meta',
 'model.decoder.layers.0.fc2.bias': 'meta',
 'model.decoder.layers.0.final_layer_norm.weight': 'met

In [34]:
# Load the real OPT weights, placing each tensor according to the hand-built map `res`.
m = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map=res,
    offload_folder='./_offload',
)

In [35]:
# lm_head was mapped to the compute device, so its (tied) weight should be materialized there.
m.get_parameter('lm_head.weight')

Parameter containing:
tensor([[ 0.1150, -0.1438,  0.0555,  ...,  0.2146,  0.0833,  0.0669],
        [ 0.1149, -0.1438,  0.0547,  ...,  0.2145,  0.0833,  0.0669],
        [ 0.0010, -0.0922,  0.1025,  ..., -0.0402,  0.0060, -0.1078],
        ...,
        [ 0.1152, -0.1437,  0.0547,  ...,  0.2145,  0.0833,  0.0671],
        [ 0.1151, -0.1455,  0.0546,  ...,  0.2156,  0.0837,  0.0673],
        [ 0.1156, -0.1437,  0.0577,  ...,  0.2139,  0.0833,  0.0650]],
       device='cuda:0', requires_grad=True)

In [36]:
# Decoder-layer parameters were mapped to 'meta', so this should still be an empty meta tensor.
m.get_parameter('model.decoder.layers.0.fc1.weight')

Parameter containing:
tensor(..., device='meta', size=(3072, 768), requires_grad=True)

In [None]:
# Inspect Mistral's module/tensor layout without downloading or materializing its weights:
# building from the config under init_empty_weights(include_buffers=True) puts parameters
# AND buffers on the meta device. A dedicated name avoids overwriting the OPT `checkpoint`
# global that other cells read.
mistral_checkpoint = 'mistralai/Mistral-7B-v0.1'
mistral_config = AutoConfig.from_pretrained(mistral_checkpoint)
with init_empty_weights(include_buffers=True):
    mis = AutoModelForCausalLM.from_config(mistral_config)

# Walk the tensors once and reuse the result for both reports.
mis_tensors = list(named_module_tensors(mis, recurse=True))
print(mis)
print({t.device for _, t in mis_tensors})
print([n for n, _ in mis_tensors])

In [None]:
def find_module_list(module: nn.Module) -> tuple[nn.ModuleList, str]:
    """Return the first ``nn.ModuleList`` inside ``module`` together with its dotted path.

    Submodules are visited depth-first in registration order (``named_modules()``), and the
    first ModuleList encountered wins. If ``module`` is itself a ModuleList, the path is ``''``.

    Raises:
        ValueError: if ``module`` contains no ``nn.ModuleList``. Other errors raised while
            walking the module tree propagate unchanged instead of being masked.
    """
    for name, submodule in module.named_modules():
        if isinstance(submodule, nn.ModuleList):
            return submodule, name
    raise ValueError(f'{module.__class__.__name__} does not have a nn.ModuleList structure')

# First nn.ModuleList (and its dotted path) of each model: Mistral first, then OPT.
tuple(find_module_list(net) for net in (mis, e))

In [None]:

import torch
from accelerate.utils import honor_type
from typing import Mapping

def get_info(obj, debug=False):
    """Summarize a (possibly nested) value as human-readable strings for debug printing.

    - tuple/list: each element is summarized recursively and the container type is kept
      (via ``honor_type``). If there are 2+ elements and all summaries are equal, the result
      collapses to ``"<n> * <summary>"``.
    - Mapping: same mapping type, values summarized recursively.
    - torch.Tensor: class name, shape and a ``mem/elem/dtype`` ratio. The ratio is
      ``sys.getsizeof(storage)`` divided by the bytes of the tensor's own elements, or
      ``n/a`` for empty tensors. With ``debug`` the dtype and device are included too.
    - int/bool/None: ``str(obj)``; anything else: ``"<ClassName>: <obj>"``.

    Args:
        obj: value to summarize.
        debug: include dtype/device for tensors; applied at every nesting level.

    Returns:
        A string, or a container of the same type as ``obj`` holding summaries.
    """
    if isinstance(obj, (tuple, list)):
        ret = honor_type(obj, (get_info(o, debug=debug) for o in obj))
        # Compare by equality instead of set(): nested summaries can be lists/dicts,
        # which are unhashable.
        if len(ret) > 1 and all(r == ret[0] for r in ret[1:]):
            return f"{len(ret)} * {ret[0]}"
        return ret
    if isinstance(obj, Mapping):
        return type(obj)({k: get_info(v, debug=debug) for k, v in obj.items()})
    if isinstance(obj, torch.Tensor):
        numel = obj.numel()
        if numel:
            mem_ratio = f"{sys.getsizeof(obj.storage()) / numel / obj.element_size():.3f}"
        else:
            mem_ratio = "n/a"  # the ratio is undefined for an empty tensor
        details = f", dtype={obj.dtype}, device={obj.device}" if debug else ""
        return f"{obj.__class__.__name__}(shape={tuple(obj.size())}{details}, mem/elem/dtype={mem_ratio})"
    if isinstance(obj, (int, bool, type(None))):
        return f"{obj}"
    return f"{obj.__class__.__name__}: {obj}"

from data_movement import Engine, Task

class Model:
    """Instrument a Hugging Face causal LM by wrapping its forward functions.

    ``build()`` replaces ``forward`` on every layer returned by ``get_layers()`` and on the
    top-level model. The replacements are pass-through wrappers that print ``get_info``
    summaries of their inputs and outputs. ``self.dm_engine`` (an ``Engine`` from
    ``data_movement``) is created here but not used by the wrappers yet.
    """

    def __init__(self, hf_model=None, comp_device=0, **kwargs) -> None:
        """Attach to ``hf_model`` (or build an empty-weights skeleton) and locate its layers.

        Args:
            hf_model: instantiated causal LM. It is moved to ``comp_device`` and is the model
                whose layers get instrumented. If None, a skeleton is built from
                ``kwargs['checkpoint']`` under ``init_empty_weights`` and is not moved.
            comp_device: compute device, passed to ``Engine`` and to ``hf_model.to``.
            **kwargs:
                checkpoint: model id/path. Defaults to ``hf_model.config._name_or_path``.
                torch_dtype: dtype for the config/skeleton (only used when hf_model is None).

        Raises:
            ValueError: if ``hf_model`` is None and no ``checkpoint`` is given.
        """
        self.checkpoint = kwargs.get('checkpoint')
        self.torch_dtype = kwargs.get('torch_dtype')
        self.comp_device = comp_device
        self.dm_engine = Engine(self.comp_device)

        if hf_model is None:
            if self.checkpoint is None:
                raise ValueError('Model needs either an hf_model or a checkpoint=... keyword argument')
            self.config = AutoConfig.from_pretrained(self.checkpoint, torch_dtype=self.torch_dtype)
            with init_empty_weights():  # parameters go to meta; buffers are still materialized
                self.hf_model = AutoModelForCausalLM.from_config(self.config, torch_dtype=self.torch_dtype)
        else:
            self.config = hf_model.config
            if self.checkpoint is None:
                self.checkpoint = getattr(self.config, '_name_or_path', None)
            self.hf_model = hf_model.to(comp_device)

        # Look the layers up only after self.hf_model is final: build() patches exactly these
        # modules, so they must belong to the model that will actually run.
        self.layers, self.layers_name = self.get_layers()

    def get_layers(self) -> tuple[nn.Module, str]:
        """Return the layer stack (an ``nn.ModuleList``) of ``self.hf_model`` and its dotted path.

        OPT is handled explicitly. For any other model, the first ``nn.ModuleList`` found in a
        depth-first walk (``named_modules()``) is used.

        Raises:
            ValueError: if the model contains no ``nn.ModuleList``.
        """
        if isinstance(self.hf_model, (OPTForCausalLM, )):
            return self.hf_model.model.decoder.layers, 'model.decoder.layers'
        for name, module in self.hf_model.named_modules():
            if isinstance(module, nn.ModuleList):
                return module, name
        raise ValueError(f'{self.hf_model.__class__.__name__} does not have a nn.ModuleList structure')

    def override_layer_forward(self, i: int):
        """Wrap ``self.layers[i].forward`` with a logging pass-through and return the layer.

        The wrapper prints ``get_info`` summaries of the positional/keyword arguments and of
        the output. Otherwise it calls the original ``forward`` unchanged.
        """
        layer = self.layers[i]
        old_forward = layer.forward

        @functools.wraps(old_forward)
        def new_forward(*args, **kwargs):
            print(f'\t{i = }, {get_info(args) = }, \n\t{i = }, {get_info(kwargs) = }')

            # Planned hook point for OPT (shapes are the author's notes, not verified here):
            #   args[0]                   hidden states          b,1,h / bzh
            #   kwargs['past_key_value']  KV cache               b,n_kv_heads,s_cache,h_kv  x2
            #   kwargs['attention_mask']  attention mask         b,1,1,s_all  (bsz, 1, tgt_len, src_len)
            # args/kwargs are currently passed through unchanged.
            output = old_forward(*args, **kwargs)  # h'=(b,z,h), kv=(b,n,s_all,h) x2
            print(f'\t{i = }, {get_info(output) = }\n')
            return output

        layer.forward = new_forward
        return layer

    def override_hf_model_forward(self):
        """Wrap ``self.hf_model.forward`` with a logging pass-through and return the model.

        The wrapper is stored as an instance attribute, so it shadows the class ``forward``
        for calls made through the module.
        """
        old_forward = self.hf_model.forward

        @functools.wraps(old_forward)
        def new_forward(*args, **kwargs):
            print(f'hf_model {get_info(args) = }, \nhf_model {get_info(kwargs) = }\n')
            # args/kwargs are currently passed through unchanged.
            output = old_forward(*args, **kwargs)
            print(f'hf_model {get_info(output) = }\n')
            return output

        self.hf_model.forward = new_forward
        return self.hf_model

    def build(self):
        """Install the forward wrappers on every layer and on the top-level model.

        Returns the instrumented ``self.hf_model``. Calling this twice wraps the wrappers again.
        """
        for i in range(len(self.layers)):
            self.override_layer_forward(i)
        self.override_hf_model_forward()
        return self.hf_model



In [None]:
num_prompts = 16
prompts = None
checkpoint = 'facebook/opt-125m'
prompt_len = 50
comp_device = 0
gen_len = 20

hf_model = OPTForCausalLM.from_pretrained(checkpoint)
# Model(...) already moves hf_model to comp_device. The checkpoint is passed explicitly so
# Model does not have to infer it.
model = Model(hf_model, comp_device=comp_device, checkpoint=checkpoint).build()

# test
if prompts is None:  # get default prompts
    prompts = [
        "for i in range(10): ",
        "Who are you? Are you conscious?",
        "Where is Deutschland?",
        "How is Huawei Mate 60 Pro?",
    ]
# Repeat/trim the prompt list to exactly num_prompts entries.
prompts = (
    prompts * (num_prompts // len(prompts))
    + prompts[: (num_prompts % len(prompts))]
)

# Decoder-only models must be LEFT-padded for batched generation. With right padding,
# the next token of every short prompt would be predicted from a pad position.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # eos padding

inputs = tokenizer(
    prompts,
    padding="max_length",
    max_length=prompt_len,
    return_tensors="pt",
).to(comp_device)

# Pass the attention mask explicitly so the padded positions are ignored.
generate_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=gen_len,
    # diverse (group) beam search: 6 beams split into 2 groups
    num_beams=6,
    num_beam_groups=2,
    diversity_penalty=0.1,
)

output_texts = tokenizer.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_texts)