In [10]:
from transformers import OPTForCausalLM, AutoTokenizer, AutoConfig
import torch
import torch.nn as nn

import numpy as np
from numpy.lib.format import open_memmap

import sys
from threading import Thread
from queue import Queue 

import functools 

In [11]:
checkpoint = 'facebook/opt-125m'

In [12]:
def find_module_list(module: nn.Module):
    def _find_module_list(module: nn.Module):
        if isinstance(module, nn.ModuleList):
            yield module
        else:
            for child in module.children():
                yield from _find_module_list(child)
    
    g = _find_module_list(module)
    try:
        return next(iter(g))
    except:
        raise ValueError(f'{module.__class__.__name__} does not have a nn.ModuleList structure')

import torch
from accelerate.utils import honor_type
from typing import Mapping

def get_info(obj):
    if isinstance(obj, (tuple, list)):
        if len(obj) and all(type(o) == type(obj[0]) and len(o) == len(obj[0]) for o in obj):
            return f"{len(obj)}x {get_info(obj[0])}"
        else:
            return honor_type(obj, (get_info(o) for o in obj))
    elif isinstance(obj, Mapping):
        return type(obj)({k: get_info(v) for k, v in obj.items()})
    elif isinstance(obj, (torch.Tensor)):
        return f"{obj.__class__.__name__}(shape={tuple(obj.size())}, dtype={obj.dtype}, device={obj.device}, mem/elem/dtype={sys.getsizeof(obj.storage()) / obj.numel() / obj.element_size():.3f})"
    # elif isinstance(obj, (MixTensor)):
    #     return f"{obj.__class__.__name__}(shape={tuple(obj.size())}, dtype={obj.dtype}, percents={obj.percents})"
    # elif isinstance(obj, (BatchListTensor)):
    #     return f"{obj.__class__.__name__}(shape={tuple(obj.size())}, dtype={obj.dtype}, has_mix={obj.has_mix})"
    elif isinstance(obj, (int, bool, type(None))):
        return f"{obj}"
    else:
        # logger.warning(f"inputs: {obj} of type '{type(obj)}' is not implemented.")
        return f"{obj.__class__.__name__}: {obj}"


class Model:
    """workload balance"""
    def __init__(self, hf_model, ) -> None:
        self.hf_model = hf_model 
        self.layers = self.get_layers()

    def get_layers(self):
        if isinstance(self.hf_model, (OPTForCausalLM, )):
            return self.hf_model.model.decoder.layers 
        else:
            return find_module_list(self.hf_model)
    
    def override_layer_forward(self, i: int):
        layer = self.layers[i]
        old_forward = layer.forward

        @functools.wraps(old_forward)
        def new_forward(*args, **kwargs):
            print(f'\t{i = }, {get_info(args) = }, \n\t{i = }, {get_info(kwargs) = }')

            if isinstance(self.hf_model, (OPTForCausalLM, )):
                actv_recomp = args[0] # b,1,h / bzh
                kv_cache = kwargs.get('past_key_value') # b,n_kv_heads,s_cache,h_kv    x2
                attn_mask = kwargs.get('attention_mask') # b,1,1,s_all  (bsz, 1, tgt_len, src_len)

            # wlb formatted args, kwargs to formats for the original forward of hf model
            args_for_old = args
            kwargs_for_old = kwargs
            output = old_forward(*args_for_old, **kwargs_for_old) # h'=(b,z,h), kv=(b,n,s_all,h) x2
            
            print(f'\t{i = }, {get_info(output) = }\n')
            # to wlb output
            
            return output
        
        layer.forward = new_forward
        return layer

    def override_hf_model_forward(self):
        old_forward = self.hf_model.forward
        @functools.wraps(old_forward)
        def new_forward(*args, **kwargs):
            print(f'hf_model {get_info(args) = }, \nhf_model {get_info(kwargs) = }\n')

            # wlb formatted args & kwargs to formats for the original forward of hf model
            args_for_old = args
            kwargs_for_old = kwargs
            output = old_forward(*args_for_old, **kwargs_for_old) 

            print(f'hf_model {get_info(output) = }\n')
            
            return output
        
        self.hf_model.forward = new_forward
        return self.hf_model

    def build(self):
        for i, _ in enumerate(self.layers):
            self.override_layer_forward(i)
        self.override_hf_model_forward()
        return self.hf_model 

hf_model= OPTForCausalLM.from_pretrained(checkpoint)
x_model = Model(hf_model).build()



In [13]:
num_prompts = 16
prompts = None
checkpoint = checkpoint
prompt_len = 50
compute_device = 0
gen_len = 20
model = x_model.to(compute_device) # 

# test
if True:
    if prompts is None:  # get default prompts
        prompts = [
            "for i in range(10): ",
            "Who are you? Are you conscious?",
            "Where is Deutschland?",
            "How is Huawei Mate 60 Pro?",
        ]
    prompts = (
        prompts * (num_prompts // len(prompts))
        + prompts[: (num_prompts % len(prompts))]
    )

    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(checkpoint) # , padding_side="left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # eos padding

    # inputs
    inputs = tokenizer(
        prompts,
        padding="max_length",
        max_length=prompt_len,
        return_tensors="pt",
        # padding=True,
    ).to(compute_device)

    # generate
    generate_ids = model.generate(
        inputs.input_ids,
        max_new_tokens=gen_len,  # max_lengths
        
        num_beams=6, #
        num_beam_groups=2, #
        diversity_penalty=0.1, #
        # do_sample=True, #
    )

    # outputs
    output_texts = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_texts)

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


GenerationMode.GROUP_BEAM_SEARCH
hf_model get_info(args) = (), 
hf_model get_info(kwargs) = {'input_ids': 'Tensor(shape=(96, 50), dtype=torch.int64, device=cuda:0, mem/elem/dtype=1.001)', 'past_key_values': 'None', 'use_cache': 'True', 'attention_mask': 'Tensor(shape=(96, 50), dtype=torch.int64, device=cuda:0, mem/elem/dtype=1.001)', 'return_dict': 'True', 'output_attentions': 'False', 'output_hidden_states': 'False'}

	i = 0, get_info(args) = '1x Tensor(shape=(96, 50, 768), dtype=torch.float32, device=cuda:0, mem/elem/dtype=1.000)', 
	i = 0, get_info(kwargs) = {'attention_mask': 'Tensor(shape=(96, 1, 50, 50), dtype=torch.float32, device=cuda:0, mem/elem/dtype=1.000)', 'layer_head_mask': 'None', 'past_key_value': 'None', 'output_attentions': 'False', 'use_cache': 'True'}
	i = 0, get_info(output) = ('Tensor(shape=(96, 50, 768), dtype=torch.float32, device=cuda:0, mem/elem/dtype=1.000)', '2x Tensor(shape=(96, 12, 50, 64), dtype=torch.float32, device=cuda:0, mem/elem/dtype=1.000)')

	i = 