In [1]:
import os

# Set the environment variables before using the transformers library
os.environ["HF_HOME"] = "/serenity/scratch/hkolisetty6/.cache/huggingface"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import sys
sys.path.insert(0, "./transformers/src")
sys.path.insert(0, "./peft/src")


In [2]:
from os.path import isdir, join
import numpy as np
import json
import logging

import torch
import transformers
import argparse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    BitsAndBytesConfig,
    Seq2SeqTrainer,
    set_seed,
)
import evaluate

from peft import (
    PeftModel
)
from peft.tuners.lora import LoraLayer

from main import (
    ModelArguments, 
    DataArguments, 
    TrainingArguments, 
    GenerationArguments, 
    smart_tokenizer_and_embedding_resize,
    DEFAULT_PAD_TOKEN,
    make_data_module,
    load_dataset,
    TokenTimingStoppingCriteria,
    eval_all,
)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def get_compute_dtype(args):
    compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
    if torch.cuda.is_bf16_supported() and compute_dtype == torch.float16:
        print('GPU supports bfloat16, so switching the compute_dtype to torch.bfloat16')
        compute_dtype = torch.bfloat16
    return compute_dtype

def parse_args(args_list=None):
    '''
    Parse the command line arguments for the Hugging Face model, data, training, and generation arguments
    '''
    # Create an argument parser for the Hugging Face model, data, training, and generation arguments
    hfparser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GenerationArguments))

    # Parse the command line arguments into the respective dataclass instances
    if args_list is None:
        model_args, data_args, training_args, generation_args, _ = \
            hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
    else:
        model_args, data_args, training_args, generation_args, _ = \
            hfparser.parse_args_into_dataclasses(args_list, return_remaining_strings=True)
    training_args.generation_config = transformers.GenerationConfig(
        **vars(generation_args)
    )

    # Combine the parsed arguments into a single argparse.Namespace object
    args = argparse.Namespace(
        **vars(model_args), **vars(data_args), **vars(training_args)
    )

    # Set additional parameters for later use
    args.compute_dtype = get_compute_dtype(args)

    return args, training_args

def get_last_checkpoint(output_dir):
    '''
    Given the output directory, return the last checkpoint directory
    '''
    assert output_dir is not None, 'Output directory must be specified'
    assert isdir(output_dir), 'Output directory does not exist'

    max_steps = 0
    for filename in os.listdir(output_dir):
        if isdir(join(output_dir, filename)) and filename.startswith('checkpoint-'):
            step = int(filename.split('-')[-1])
            if step > max_steps:
                max_steps = step
    assert max_steps > 0, 'No checkpoints found in output directory'
    last_checkpoint_dir = join(output_dir, f'checkpoint-{max_steps}')
    return last_checkpoint_dir

def get_tokenizer(args, model):
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=args.model_name_or_path,
        token=args.use_auth_token,
        padding_side='right',
        tokenizer_type='llama' if 'llama' in args.model_name_or_path else None, # Needed for HF name change
    )
    if tokenizer._pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if 'llama' in args.model_name_or_path or isinstance(tokenizer, LlamaTokenizer):
        # LLaMA tokenizer may not have correct special tokens set.
        # Check and add them if missing to prevent them from being parsed into different tokens.
        # Note that these are present in the vocabulary.
        # Note also that `model.config.pad_token_id` is 0 which corresponds to `<unk>` token.
        print('Adding correct special tokens to the llama tokenizer')
        tokenizer.add_special_tokens({
                "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
                "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
        })
    return tokenizer

# Assumes the pruning width method is flap
def set_width_mask_and_bias(model, args):
    '''
    Set the width_mask and bias for the model

    hkolisetty6:
    My understanding is that each module will hold the width_mask and bias for the corresponding layer.
    The width_mask and bias are dicts with keys as width_ratio and values as the mask/bias tensor.
    During inference, the width_mask and bias are applied to the output of the module.
    In essense, full matrix multiplication is done and then the width_mask and bias are applied 
    to the output before passing it to the next layer.
    '''
    shrink_file = np.load(args.shrinking_file, allow_pickle=True).item()
    assert 'width_mask' in shrink_file, 'Width mask not found in shrinking file'
    assert 'bias' in shrink_file, 'Bias not found in shrinking file'

    width_mask = shrink_file['width_mask']
    bias = shrink_file['bias']

    for name, module in model.named_modules():
        if name in width_mask:
            mask_dtype = args.compute_dtype # TODO hkolisetty6: in original code, this is set to torch.float32 when args.fp16 is True
            if 'mlp.down_proj' in name or 'self_attn.o_proj' in name:
                assert width_mask[name] is None
                for key in bias[name].keys():
                    bias[name][key] = torch.from_numpy(bias[name][key]).to(mask_dtype)
                module.set_width_mask(width_mask=None, output_bias=bias[name])
            else:
                assert bias[name] is None
                for key in width_mask[name].keys():
                    width_mask[name][key] = torch.from_numpy(width_mask[name][key]).to(mask_dtype)
                module.set_width_mask(width_mask=width_mask[name], output_bias=None)


# Assumes model is loaded on a single GPU
def load_model(args, checkpoint_dir):
    '''
    Given the command line arguments and the checkpoint directory, return the model
    '''
    shrink_config = {
        'enable_shrinking': args.enable_shrinking,
        'shrinkable_width': args.shrinkable_width,
        'shrinking_method': args.shrinking_method,
        'shrinking_file': args.shrinking_file,
        'mask_dtype': str(args.compute_dtype),
    }
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=args.bits == 4,
        load_in_8bit=args.bits == 8,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=args.compute_dtype,
        bnb_4bit_use_double_quant=args.double_quant,
        bnb_4bit_quant_type=args.quant_type,
    )

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=args.model_name_or_path,
        device_map='auto',
        torch_dtype=args.compute_dtype, # TODO hkolisetty6: in original code, this is set to torch.float32
        token=args.use_auth_token,
        shrink_config=shrink_config,
        quantization_config=quantization_config,
    )
    model.config.torch_dtype = args.compute_dtype

    # !IMPORTANT 
    # Load the tokenizer before loading the adapters to ensure that the special tokens, if not available, are added to embeddings
    # This is important for the adapters to be loaded correctly
    tokenizer = get_tokenizer(args, model)

    # Load adapters from checkpoint
    print('Loading adapters from checkpoint')
    model = PeftModel.from_pretrained(
        model=model,
        model_id=join(checkpoint_dir, 'adapter_model'),
    )
    
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer) or 'lm_head' in name or 'embed_tokens' in name:
            if args.compute_dtype == torch.bfloat16:
                module = module.to(torch.float32)
        if 'norm' in name:
            module = module.to(torch.float32)
    
    model.config.use_cache = False # Check LlamaConfig for this attribute

    # Set the width mask and bias for the model
    set_width_mask_and_bias(model, args)

    return model, tokenizer

def get_mmlu_dataset_for_evaluation(args, tokenizer):
    if args.do_mmlu_eval:
        if args.mmlu_dataset == 'mmlu-zs':
            mmlu_dataset = load_dataset("json", data_files={
                'eval': 'data/mmlu/zero_shot_mmlu_val.json',
                'test': 'data/mmlu/zero_shot_mmlu_test.json',
            })
            mmlu_dataset = mmlu_dataset.remove_columns('subject')
        # MMLU Five-shot (Eval/Test only)
        elif args.mmlu_dataset == 'mmlu' or args.mmlu_dataset == 'mmlu-fs':
            mmlu_dataset = load_dataset("json", data_files={
                'eval': 'data/mmlu/five_shot_mmlu_val.json',
                'test': 'data/mmlu/five_shot_mmlu_test.json',
            })
            # mmlu_dataset = mmlu_dataset.remove_columns('subject')
        mmlu_dataset = mmlu_dataset[args.mmlu_split]
        if args.max_mmlu_samples is not None:
            mmlu_dataset = mmlu_dataset.select(range(args.max_mmlu_samples))
        abcd_idx = [
            tokenizer("A", add_special_tokens=False).input_ids[0],
            tokenizer("B", add_special_tokens=False).input_ids[0],
            tokenizer("C", add_special_tokens=False).input_ids[0],
            tokenizer("D", add_special_tokens=False).input_ids[0],
        ]
        accuracy = evaluate.load("accuracy")
        return mmlu_dataset, abcd_idx, accuracy
    return None, None, None

def eval_model(trainer, args, logger, all_metrics):
    if args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
        all_metrics.update(metrics)

# Assumes shrinking_method is calib_dp and shrinking is enabled
def setup_model_for_inference(model, args):
    '''
    Setup model for inference by activating the layers and setting the width ratio
    '''
    strategy = np.load(args.shrinking_file, allow_pickle=True).item()["strategy"]
    if 0 not in list(strategy.keys()):
        strategy[0] = np.ones(model.config.num_hidden_layers)

    active_layers_attn = active_layers_mlp = strategy[
        model.config.num_hidden_layers - args.eval_num_layer
    ]

    if args.shrinkable_width:
        for module in model.modules():
            if hasattr(module, 'set_width_ratio'):
                module.set_width_ratio(width_ratio=args.eval_num_width)
        model.set_active_layers(
            active_layers_attn, active_layers_mlp, width=args.eval_num_width
        )
    else:
        model.set_active_layers(active_layers_attn, active_layers_mlp)


def profile_latencies(model, tokenizer, args, logger, trainer, data_module):
    '''
    Given depth and width ratios (in args), profile the model for TTFT and TBT latencies
    '''
    setup_model_for_inference(model, args)
    logger.info("Profiling model for TTFT and TBT latencies")
    timing_stopping_criteria = TokenTimingStoppingCriteria()
    prediction_output = trainer.predict(
        test_dataset=data_module["predict_dataset"],
        metric_key_prefix="predict",
        stopping_criteria=[timing_stopping_criteria],
    )

    print(timing_stopping_criteria.ttft)
    print(timing_stopping_criteria.tbt)

    prediction_metrics = prediction_output.metrics
    predictions = prediction_output.predictions
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)

    predictions = tokenizer.batch_decode(
        predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    with open(os.path.join(args.output_dir, 'predictions_test.jsonl'), 'w') as fout:
        for i, example in enumerate(data_module['predict_dataset']):
            example['prediction_with_input'] = predictions[i].strip()
            example['prediction'] = predictions[i].replace(example['input'], '').strip()
            fout.write(json.dumps(example) + '\n')
    print(prediction_metrics)
    trainer.log_metrics("predict", prediction_metrics)
    trainer.save_metrics("predict", prediction_metrics)

    return timing_stopping_criteria.ttft, timing_stopping_criteria.tbt

def profile_accuracies(model, tokenizer, args, trainer, logger, mmlu_dataset, abcd_idx, accuracy, all_metrics):
    logger.info("Profiling model for accuracies")
    setup_model_for_inference(model, args)
    num_layer = args.eval_num_layer
    width = args.eval_num_width
    all_metrics = eval_all(args, model, trainer, tokenizer, mmlu_dataset, abcd_idx=abcd_idx, accuracy=accuracy, all_metrics=all_metrics, suffix=f'_l{num_layer}w{width}')

    with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout:
        fout.write(json.dumps(all_metrics))
    return all_metrics
    


In [26]:
args = [
    "--model_name_or_path", "meta-llama/Llama-2-7b-hf",
    "--output_dir", "amoeba_llama2",
    "--do_predict", "True",
    "--do_eval", "False",
    "--do_train", "False",
    "--do_mmlu_eval", "True",
    "--enable_shrinking",
    "--min_num_layer", "20",
    "--shrinking_method", "calib_dp",
    "--shrinking_file", "dp_selection_strategy.npy",
    "--shrinkable_width",
    "--width_choice", "[1,7/8,3/4,5/8,1/2]",
    "--prune_width_method", "flap",
    "--use_moe_lora",
    "--moe_num_expert", "5",
    "--moe_topk", "2",
    "--eval_num_layer", "32",
    "--eval_num_width", "1",
]

args, training_args = parse_args(args_list=args)
checkpoint_dir = get_last_checkpoint(args.output_dir)
model, tokenizer = load_model(args, checkpoint_dir)
set_width_mask_and_bias(model, args)

set_seed(args.seed)

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.94s/it]


Adding correct special tokens to the llama tokenizer
Loading adapters from checkpoint


In [5]:
all_metrics = {}

data_module = make_data_module(tokenizer, args)
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    **{k: v for k, v in data_module.items() if k != "predict_dataset"},
)
logger = logging.getLogger(__name__)
mmlu_dataset, abcd_idx, accuracy = get_mmlu_dataset_for_evaluation(args, tokenizer)
all_metrics = profile_accuracies(model, tokenizer, args, trainer, logger, mmlu_dataset, abcd_idx, accuracy, all_metrics)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Splitting train dataset in train and validation according to `eval_dataset_size`


100%|██████████| 192/192 [10:01<00:00,  3.14s/it]


***** mmlu_l32w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4483
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             =    0.5
  mmlu_eval_accuracy_astronomy                           = 0.3125
  mmlu_eval_accuracy_business_ethics                     = 0.4545
  mmlu_eval_accuracy_clinical_knowledge                  = 0.4483
  mmlu_eval_accuracy_college_biology                     =   0.25
  mmlu_eval_accuracy_college_chemistry                   =   0.25
  mmlu_eval_accuracy_college_computer_science            = 0.1818
  mmlu_eval_accuracy_college_mathematics                 = 0.2727
  mmlu_eval_accuracy_college_medicine                    = 0.4091
  mmlu_eval_accuracy_college_physics                     = 0.3636
  mmlu_eval_accuracy_computer_security                   = 0.3636
  mmlu_eval_accuracy_conceptual_physics                  =    0.5
  mmlu_eval_accuracy_econometrics          

In [8]:
args.eval_num_width = 1.0
args.eval_num_layer = 16
all_metrics = profile_accuracies(model, tokenizer, args, trainer, logger, mmlu_dataset, abcd_idx, accuracy, all_metrics)

100%|██████████| 192/192 [05:06<00:00,  1.59s/it]


***** mmlu_l16w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.2934
  mmlu_eval_accuracy_abstract_algebra                    = 0.2727
  mmlu_eval_accuracy_anatomy                             = 0.2857
  mmlu_eval_accuracy_astronomy                           =   0.25
  mmlu_eval_accuracy_business_ethics                     =    0.0
  mmlu_eval_accuracy_clinical_knowledge                  = 0.3793
  mmlu_eval_accuracy_college_biology                     = 0.3125
  mmlu_eval_accuracy_college_chemistry                   =  0.625
  mmlu_eval_accuracy_college_computer_science            = 0.3636
  mmlu_eval_accuracy_college_mathematics                 = 0.4545
  mmlu_eval_accuracy_college_medicine                    = 0.1818
  mmlu_eval_accuracy_college_physics                     = 0.3636
  mmlu_eval_accuracy_computer_security                   = 0.2727
  mmlu_eval_accuracy_conceptual_physics                  = 0.3077
  mmlu_eval_accuracy_econometrics          

In [9]:
args.eval_num_width = 1.0
args.eval_num_layer = 24
all_metrics = profile_accuracies(model, tokenizer, args, trainer, logger, mmlu_dataset, abcd_idx, accuracy, all_metrics)

100%|██████████| 192/192 [07:33<00:00,  2.36s/it]


***** mmlu_l24w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4307
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             = 0.4286
  mmlu_eval_accuracy_astronomy                           =  0.375
  mmlu_eval_accuracy_business_ethics                     = 0.6364
  mmlu_eval_accuracy_clinical_knowledge                  = 0.3793
  mmlu_eval_accuracy_college_biology                     =  0.375
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.3636
  mmlu_eval_accuracy_college_mathematics                 = 0.2727
  mmlu_eval_accuracy_college_medicine                    = 0.4091
  mmlu_eval_accuracy_college_physics                     = 0.6364
  mmlu_eval_accuracy_computer_security                   = 0.3636
  mmlu_eval_accuracy_conceptual_physics                  = 0.3846
  mmlu_eval_accuracy_econometrics          

In [12]:
for num_layers in [18, 20, 22, 26, 28, 30]:
    args.eval_num_layer = num_layers
    print(f"Profiling model for accuracies with {num_layers} layers")
    all_metrics = profile_accuracies(model, tokenizer, args, trainer, logger, mmlu_dataset, abcd_idx, accuracy, all_metrics)

Profiling model for accuracies with 18 layers


100%|██████████| 192/192 [05:42<00:00,  1.79s/it]


***** mmlu_l18w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.3881
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             = 0.5714
  mmlu_eval_accuracy_astronomy                           = 0.4375
  mmlu_eval_accuracy_business_ethics                     = 0.4545
  mmlu_eval_accuracy_clinical_knowledge                  = 0.3793
  mmlu_eval_accuracy_college_biology                     =  0.375
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.3636
  mmlu_eval_accuracy_college_mathematics                 = 0.3636
  mmlu_eval_accuracy_college_medicine                    = 0.2273
  mmlu_eval_accuracy_college_physics                     = 0.4545
  mmlu_eval_accuracy_computer_security                   = 0.5455
  mmlu_eval_accuracy_conceptual_physics                  = 0.3077
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [06:20<00:00,  1.98s/it]


***** mmlu_l20w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4082
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             = 0.5714
  mmlu_eval_accuracy_astronomy                           = 0.3125
  mmlu_eval_accuracy_business_ethics                     = 0.3636
  mmlu_eval_accuracy_clinical_knowledge                  = 0.4828
  mmlu_eval_accuracy_college_biology                     = 0.5625
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.2727
  mmlu_eval_accuracy_college_mathematics                 = 0.2727
  mmlu_eval_accuracy_college_medicine                    = 0.4091
  mmlu_eval_accuracy_college_physics                     = 0.5455
  mmlu_eval_accuracy_computer_security                   = 0.6364
  mmlu_eval_accuracy_conceptual_physics                  = 0.4615
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [06:57<00:00,  2.17s/it]


***** mmlu_l22w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4396
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             = 0.5714
  mmlu_eval_accuracy_astronomy                           =  0.375
  mmlu_eval_accuracy_business_ethics                     = 0.4545
  mmlu_eval_accuracy_clinical_knowledge                  = 0.4483
  mmlu_eval_accuracy_college_biology                     = 0.5625
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.3636
  mmlu_eval_accuracy_college_mathematics                 = 0.1818
  mmlu_eval_accuracy_college_medicine                    = 0.3636
  mmlu_eval_accuracy_college_physics                     = 0.6364
  mmlu_eval_accuracy_computer_security                   = 0.5455
  mmlu_eval_accuracy_conceptual_physics                  = 0.4231
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [08:10<00:00,  2.55s/it]


***** mmlu_l26w1.0 metrics *****
  mmlu_eval_accuracy                                     =  0.438
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             = 0.4286
  mmlu_eval_accuracy_astronomy                           =  0.375
  mmlu_eval_accuracy_business_ethics                     = 0.6364
  mmlu_eval_accuracy_clinical_knowledge                  = 0.3793
  mmlu_eval_accuracy_college_biology                     = 0.3125
  mmlu_eval_accuracy_college_chemistry                   =  0.625
  mmlu_eval_accuracy_college_computer_science            = 0.4545
  mmlu_eval_accuracy_college_mathematics                 = 0.2727
  mmlu_eval_accuracy_college_medicine                    = 0.4091
  mmlu_eval_accuracy_college_physics                     = 0.4545
  mmlu_eval_accuracy_computer_security                   = 0.3636
  mmlu_eval_accuracy_conceptual_physics                  = 0.3077
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [08:47<00:00,  2.75s/it]


***** mmlu_l28w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4503
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             =    0.5
  mmlu_eval_accuracy_astronomy                           = 0.3125
  mmlu_eval_accuracy_business_ethics                     = 0.5455
  mmlu_eval_accuracy_clinical_knowledge                  = 0.3793
  mmlu_eval_accuracy_college_biology                     = 0.3125
  mmlu_eval_accuracy_college_chemistry                   =    0.5
  mmlu_eval_accuracy_college_computer_science            = 0.4545
  mmlu_eval_accuracy_college_mathematics                 = 0.2727
  mmlu_eval_accuracy_college_medicine                    = 0.3636
  mmlu_eval_accuracy_college_physics                     = 0.4545
  mmlu_eval_accuracy_computer_security                   = 0.3636
  mmlu_eval_accuracy_conceptual_physics                  = 0.4615
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [09:24<00:00,  2.94s/it]


***** mmlu_l30w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4497
  mmlu_eval_accuracy_abstract_algebra                    = 0.1818
  mmlu_eval_accuracy_anatomy                             =    0.5
  mmlu_eval_accuracy_astronomy                           = 0.3125
  mmlu_eval_accuracy_business_ethics                     = 0.5455
  mmlu_eval_accuracy_clinical_knowledge                  = 0.4138
  mmlu_eval_accuracy_college_biology                     = 0.3125
  mmlu_eval_accuracy_college_chemistry                   =   0.25
  mmlu_eval_accuracy_college_computer_science            = 0.3636
  mmlu_eval_accuracy_college_mathematics                 = 0.2727
  mmlu_eval_accuracy_college_medicine                    = 0.4091
  mmlu_eval_accuracy_college_physics                     = 0.3636
  mmlu_eval_accuracy_computer_security                   = 0.3636
  mmlu_eval_accuracy_conceptual_physics                  =    0.5
  mmlu_eval_accuracy_econometrics          

In [27]:
all_metrics = {}

data_module = make_data_module(tokenizer, args)
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    **{k: v for k, v in data_module.items() if k != "predict_dataset"},
)
logger = logging.getLogger(__name__)
mmlu_dataset, abcd_idx, accuracy = get_mmlu_dataset_for_evaluation(args, tokenizer)
args.eval_num_width = 1.0

for num_layers in [17, 19, 21, 23, 25, 27, 29, 31]:
    args.eval_num_layer = num_layers
    print(f"Profiling model for accuracies with {num_layers} layers")
    all_metrics = profile_accuracies(model, tokenizer, args, trainer, logger, mmlu_dataset, abcd_idx, accuracy, all_metrics)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Splitting train dataset in train and validation according to `eval_dataset_size`
Profiling model for accuracies with 17 layers


100%|██████████| 192/192 [05:25<00:00,  1.70s/it]


***** mmlu_l17w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.3351
  mmlu_eval_accuracy_abstract_algebra                    = 0.3636
  mmlu_eval_accuracy_anatomy                             = 0.7143
  mmlu_eval_accuracy_astronomy                           = 0.3125
  mmlu_eval_accuracy_business_ethics                     = 0.0909
  mmlu_eval_accuracy_clinical_knowledge                  = 0.4828
  mmlu_eval_accuracy_college_biology                     =  0.375
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.2727
  mmlu_eval_accuracy_college_mathematics                 =    0.0
  mmlu_eval_accuracy_college_medicine                    = 0.2727
  mmlu_eval_accuracy_college_physics                     = 0.7273
  mmlu_eval_accuracy_computer_security                   = 0.2727
  mmlu_eval_accuracy_conceptual_physics                  = 0.3462
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [06:02<00:00,  1.89s/it]


***** mmlu_l19w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4205
  mmlu_eval_accuracy_abstract_algebra                    = 0.2727
  mmlu_eval_accuracy_anatomy                             =    0.5
  mmlu_eval_accuracy_astronomy                           =   0.25
  mmlu_eval_accuracy_business_ethics                     = 0.4545
  mmlu_eval_accuracy_clinical_knowledge                  = 0.5172
  mmlu_eval_accuracy_college_biology                     =  0.375
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.4545
  mmlu_eval_accuracy_college_mathematics                 = 0.2727
  mmlu_eval_accuracy_college_medicine                    = 0.3636
  mmlu_eval_accuracy_college_physics                     = 0.5455
  mmlu_eval_accuracy_computer_security                   = 0.5455
  mmlu_eval_accuracy_conceptual_physics                  = 0.3077
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [06:38<00:00,  2.08s/it]


***** mmlu_l21w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4311
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             = 0.5714
  mmlu_eval_accuracy_astronomy                           =  0.375
  mmlu_eval_accuracy_business_ethics                     = 0.3636
  mmlu_eval_accuracy_clinical_knowledge                  = 0.4828
  mmlu_eval_accuracy_college_biology                     =    0.5
  mmlu_eval_accuracy_college_chemistry                   =   0.25
  mmlu_eval_accuracy_college_computer_science            = 0.3636
  mmlu_eval_accuracy_college_mathematics                 = 0.1818
  mmlu_eval_accuracy_college_medicine                    = 0.3636
  mmlu_eval_accuracy_college_physics                     = 0.6364
  mmlu_eval_accuracy_computer_security                   = 0.4545
  mmlu_eval_accuracy_conceptual_physics                  = 0.4231
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [07:15<00:00,  2.27s/it]


***** mmlu_l23w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4381
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             = 0.5714
  mmlu_eval_accuracy_astronomy                           =  0.375
  mmlu_eval_accuracy_business_ethics                     = 0.4545
  mmlu_eval_accuracy_clinical_knowledge                  = 0.3793
  mmlu_eval_accuracy_college_biology                     =    0.5
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.2727
  mmlu_eval_accuracy_college_mathematics                 = 0.2727
  mmlu_eval_accuracy_college_medicine                    = 0.3636
  mmlu_eval_accuracy_college_physics                     = 0.5455
  mmlu_eval_accuracy_computer_security                   = 0.5455
  mmlu_eval_accuracy_conceptual_physics                  =    0.5
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [07:52<00:00,  2.46s/it]


***** mmlu_l25w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4377
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             = 0.4286
  mmlu_eval_accuracy_astronomy                           = 0.3125
  mmlu_eval_accuracy_business_ethics                     = 0.5455
  mmlu_eval_accuracy_clinical_knowledge                  = 0.4138
  mmlu_eval_accuracy_college_biology                     = 0.4375
  mmlu_eval_accuracy_college_chemistry                   =    0.5
  mmlu_eval_accuracy_college_computer_science            = 0.3636
  mmlu_eval_accuracy_college_mathematics                 = 0.3636
  mmlu_eval_accuracy_college_medicine                    = 0.3636
  mmlu_eval_accuracy_college_physics                     = 0.6364
  mmlu_eval_accuracy_computer_security                   = 0.3636
  mmlu_eval_accuracy_conceptual_physics                  = 0.4231
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [08:28<00:00,  2.65s/it]


***** mmlu_l27w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4392
  mmlu_eval_accuracy_abstract_algebra                    = 0.0909
  mmlu_eval_accuracy_anatomy                             = 0.4286
  mmlu_eval_accuracy_astronomy                           = 0.3125
  mmlu_eval_accuracy_business_ethics                     = 0.4545
  mmlu_eval_accuracy_clinical_knowledge                  = 0.3793
  mmlu_eval_accuracy_college_biology                     = 0.3125
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.3636
  mmlu_eval_accuracy_college_mathematics                 = 0.2727
  mmlu_eval_accuracy_college_medicine                    = 0.4091
  mmlu_eval_accuracy_college_physics                     = 0.3636
  mmlu_eval_accuracy_computer_security                   = 0.3636
  mmlu_eval_accuracy_conceptual_physics                  = 0.4615
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [09:05<00:00,  2.84s/it]


***** mmlu_l29w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4437
  mmlu_eval_accuracy_abstract_algebra                    = 0.1818
  mmlu_eval_accuracy_anatomy                             =    0.5
  mmlu_eval_accuracy_astronomy                           = 0.3125
  mmlu_eval_accuracy_business_ethics                     = 0.5455
  mmlu_eval_accuracy_clinical_knowledge                  = 0.4138
  mmlu_eval_accuracy_college_biology                     =  0.375
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.2727
  mmlu_eval_accuracy_college_mathematics                 = 0.1818
  mmlu_eval_accuracy_college_medicine                    = 0.3636
  mmlu_eval_accuracy_college_physics                     = 0.4545
  mmlu_eval_accuracy_computer_security                   = 0.3636
  mmlu_eval_accuracy_conceptual_physics                  =    0.5
  mmlu_eval_accuracy_econometrics          

100%|██████████| 192/192 [09:42<00:00,  3.03s/it]


***** mmlu_l31w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4506
  mmlu_eval_accuracy_abstract_algebra                    = 0.1818
  mmlu_eval_accuracy_anatomy                             =    0.5
  mmlu_eval_accuracy_astronomy                           =  0.375
  mmlu_eval_accuracy_business_ethics                     = 0.4545
  mmlu_eval_accuracy_clinical_knowledge                  = 0.4483
  mmlu_eval_accuracy_college_biology                     = 0.3125
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.2727
  mmlu_eval_accuracy_college_mathematics                 = 0.3636
  mmlu_eval_accuracy_college_medicine                    = 0.4091
  mmlu_eval_accuracy_college_physics                     = 0.3636
  mmlu_eval_accuracy_computer_security                   = 0.3636
  mmlu_eval_accuracy_conceptual_physics                  = 0.4615
  mmlu_eval_accuracy_econometrics          

In [None]:
# OLD
args.eval_num_width = 1.0
args.eval_num_layer = 24
all_metrics = profile_accuracies(model, tokenizer, args, trainer, logger, mmlu_dataset, abcd_idx, accuracy, all_metrics)

100%|██████████| 192/192 [07:33<00:00,  2.36s/it]


***** mmlu_l24w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.4537
  mmlu_eval_accuracy_abstract_algebra                    = 0.2727
  mmlu_eval_accuracy_anatomy                             = 0.5714
  mmlu_eval_accuracy_astronomy                           = 0.4375
  mmlu_eval_accuracy_business_ethics                     = 0.2727
  mmlu_eval_accuracy_clinical_knowledge                  = 0.5172
  mmlu_eval_accuracy_college_biology                     = 0.5625
  mmlu_eval_accuracy_college_chemistry                   =    0.5
  mmlu_eval_accuracy_college_computer_science            = 0.2727
  mmlu_eval_accuracy_college_mathematics                 = 0.2727
  mmlu_eval_accuracy_college_medicine                    = 0.3636
  mmlu_eval_accuracy_college_physics                     = 0.4545
  mmlu_eval_accuracy_computer_security                   = 0.7273
  mmlu_eval_accuracy_conceptual_physics                  = 0.4615
  mmlu_eval_accuracy_econometrics          

In [6]:
args.eval_num_width = 1.0
args.eval_num_layer = 20
all_metrics = profile_accuracies(model, tokenizer, args, trainer, logger, mmlu_dataset, abcd_idx, accuracy, all_metrics)

100%|██████████| 192/192 [06:20<00:00,  1.98s/it]


***** mmlu_l20w1.0 metrics *****
  mmlu_eval_accuracy                                     = 0.3624
  mmlu_eval_accuracy_abstract_algebra                    = 0.2727
  mmlu_eval_accuracy_anatomy                             =    0.5
  mmlu_eval_accuracy_astronomy                           =   0.25
  mmlu_eval_accuracy_business_ethics                     = 0.1818
  mmlu_eval_accuracy_clinical_knowledge                  = 0.2759
  mmlu_eval_accuracy_college_biology                     = 0.3125
  mmlu_eval_accuracy_college_chemistry                   =  0.375
  mmlu_eval_accuracy_college_computer_science            = 0.3636
  mmlu_eval_accuracy_college_mathematics                 = 0.1818
  mmlu_eval_accuracy_college_medicine                    = 0.3636
  mmlu_eval_accuracy_college_physics                     = 0.6364
  mmlu_eval_accuracy_computer_security                   = 0.2727
  mmlu_eval_accuracy_conceptual_physics                  = 0.3846
  mmlu_eval_accuracy_econometrics          

# Profile Latencies

In [4]:
args = [
    "--model_name_or_path", "meta-llama/Llama-2-7b-hf",
    "--output_dir", "amoeba_llama2",
    "--do_predict", "True",
    "--do_eval", "False",
    "--do_train", "False",
    "--do_mmlu_eval", "False",
    "--enable_shrinking",
    "--min_num_layer", "20",
    "--shrinking_method", "calib_dp",
    "--shrinking_file", "dp_selection_strategy.npy",
    "--shrinkable_width",
    "--width_choice", "[1,7/8,3/4,5/8,1/2]",
    "--prune_width_method", "flap",
    "--use_moe_lora",
    "--moe_num_expert", "5",
    "--moe_topk", "2",
    "--eval_num_layer", "32",
    "--eval_num_width", "1",
    "--predict_with_generate", "True",
]
args, training_args = parse_args(args_list=args)
checkpoint_dir = get_last_checkpoint(args.output_dir)
model, tokenizer = load_model(args, checkpoint_dir)
set_width_mask_and_bias(model, args)
data_module = make_data_module(tokenizer=tokenizer, args=args)
logger = logging.getLogger(__name__)
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    **{k: v for k, v in data_module.items() if k != "predict_dataset"},
)

results = {}

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.98s/it]


Adding correct special tokens to the llama tokenizer
Loading adapters from checkpoint


Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Splitting train dataset in train and validation according to `eval_dataset_size`


In [5]:
def get_latency_stats(ttft, tbt, bs):
    '''
    ttft (time-to-first-token) is a dictionary with keys as batch_num and values as tuples (batch_size, latency)
    tbt (time-between-tokens) is a dictionary with keys as batch_num and values as tuples (batch_size, num_tokens, avg_latency)

    Returns:
    - batch_size
    - mean_ttft
    - std_ttft
    - mean_tbt
    - std_tbt

    Excludes the first 5 batches from both ttft and tbt
    Excluded the last batch from ttft (since it is not a full batch)
    - No need to exclude the last batch from tbt since the last batch is not included in tbt
    '''
    ttft_latencies = []
    tbt_latencies = []
    for (_, latency) in ttft.values():
        ttft_latencies.append(latency * 1e6) # Convert to microseconds
    
    for (_, _, avg_latency) in tbt.values():
        tbt_latencies.append(avg_latency * 1e6) # Convert to microseconds

    ttft_latencies = ttft_latencies[5:-1]
    tbt_latencies = tbt_latencies[5:]

    return {
        'batch_size': bs,
        'mean_ttft': np.mean(ttft_latencies),
        'std_ttft': np.std(ttft_latencies),
        'mean_tbt': np.mean(tbt_latencies),
        'std_tbt': np.std(tbt_latencies),
    }


In [13]:
args = [
    "--model_name_or_path", "meta-llama/Llama-2-7b-hf",
    "--output_dir", "amoeba_llama2",
    "--do_predict", "True",
    "--do_eval", "False",
    "--do_train", "False",
    "--do_mmlu_eval", "False",
    "--enable_shrinking",
    "--min_num_layer", "20",
    "--shrinking_method", "calib_dp",
    "--shrinking_file", "dp_selection_strategy.npy",
    "--shrinkable_width",
    "--width_choice", "[1,7/8,3/4,5/8,1/2]",
    "--prune_width_method", "flap",
    "--use_moe_lora",
    "--moe_num_expert", "5",
    "--moe_topk", "2",
    "--eval_num_layer", "32",
    "--eval_num_width", "1",
    "--predict_with_generate", "True",
    "--max_new_tokens", "1000",
]
args, training_args = parse_args(args_list=args)
checkpoint_dir = get_last_checkpoint(args.output_dir)
model, tokenizer = load_model(args, checkpoint_dir)
set_width_mask_and_bias(model, args)

args.eval_num_width = 1.0
args.eval_num_layer = 32
key = f'l{args.eval_num_layer}w{args.eval_num_width}'
results[key] = []

for bs in [1, 2]:
    training_args.per_device_eval_batch_size = bs
    args.eval_dataset_size = training_args.per_device_eval_batch_size * 2 + 1
    data_module = make_data_module(tokenizer=tokenizer, args=args)
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **{k: v for k, v in data_module.items() if k != "predict_dataset"},
    )
    ttft, tbt = profile_latencies(model, tokenizer, args, logger, trainer, data_module)
    results[key].append(get_latency_stats(ttft, tbt, bs))

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.61s/it]


Adding correct special tokens to the llama tokenizer
Loading adapters from checkpoint


Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Splitting train dataset in train and validation according to `eval_dataset_size`


{1: (1, 0.3030555248260498), 2: (1, 0.2775132656097412), 3: (1, 0.3207380771636963)}
{1: (1, 999, 0.16963936401917054), 2: (1, 999, 0.1682439172590101)}
{'predict_runtime': 506.7598, 'predict_samples_per_second': 0.006, 'predict_steps_per_second': 0.006}
***** predict metrics *****
  predict_runtime            = 0:08:26.75
  predict_samples_per_second =      0.006
  predict_steps_per_second   =      0.006


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Splitting train dataset in train and validation according to `eval_dataset_size`


{1: (2, 0.31504154205322266), 2: (2, 0.3379056453704834), 3: (1, 0.28014683723449707)}
{1: (2, 999, 0.26174528892333804), 2: (2, 999, 0.2621316623401355)}
{'predict_runtime': 692.9456, 'predict_samples_per_second': 0.007, 'predict_steps_per_second': 0.004}
***** predict metrics *****
  predict_runtime            = 0:11:32.94
  predict_samples_per_second =      0.007
  predict_steps_per_second   =      0.004


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


In [20]:
args.eval_num_width = 1.0
args.eval_num_layer = 32
key = f'l{args.eval_num_layer}w{args.eval_num_width}'
results[key] = []

for bs in [1, 2, 4]:
    training_args.per_device_eval_batch_size = bs
    args.eval_dataset_size = training_args.per_device_eval_batch_size * 50 + 1
    data_module = make_data_module(tokenizer=tokenizer, args=args)
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **{k: v for k, v in data_module.items() if k != "predict_dataset"},
    )
    ttft, tbt = profile_latencies(model, tokenizer, args, logger, trainer, data_module)
    results[key].append(get_latency_stats(ttft, tbt, bs))

Splitting train dataset in train and validation according to `eval_dataset_size`


Map: 100%|██████████| 51/51 [00:00<00:00, 3375.99 examples/s]
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


{1: (1, 0.3038644790649414), 2: (1, 0.2741265296936035), 3: (1, 0.28331518173217773), 4: (1, 0.2780489921569824), 5: (1, 0.2807955741882324), 6: (1, 0.2796006202697754), 7: (1, 0.2781693935394287), 8: (1, 0.27220702171325684), 9: (1, 0.27780890464782715), 10: (1, 0.2789013385772705), 11: (1, 0.2794163227081299), 12: (1, 0.2766284942626953), 13: (1, 0.27938079833984375), 14: (1, 0.2822425365447998), 15: (1, 0.27423715591430664), 16: (1, 0.27529263496398926), 17: (1, 0.27825307846069336), 18: (1, 0.2731161117553711), 19: (1, 0.2792332172393799), 20: (1, 0.27240848541259766), 21: (1, 0.27353858947753906), 22: (1, 0.2762627601623535), 23: (1, 0.2716386318206787), 24: (1, 0.2760813236236572), 25: (1, 0.27935123443603516), 26: (1, 0.28093767166137695), 27: (1, 0.27518415451049805), 28: (1, 0.2740812301635742), 29: (1, 0.2783832550048828), 30: (1, 0.2757689952850342), 31: (1, 0.2825438976287842), 32: (1, 0.278536319732666), 33: (1, 0.27353692054748535), 34: (1, 0.27033424377441406), 35: (1, 0

Map: 100%|██████████| 101/101 [00:00<00:00, 4720.16 examples/s]
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


{1: (2, 0.3126344680786133), 2: (2, 0.3053748607635498), 3: (2, 0.30750584602355957), 4: (2, 0.3082563877105713), 5: (2, 0.3021209239959717), 6: (2, 0.3059115409851074), 7: (2, 0.30779218673706055), 8: (2, 0.2835876941680908), 9: (2, 0.3032567501068115), 10: (2, 0.3028373718261719), 11: (2, 0.2835381031036377), 12: (2, 0.2874476909637451), 13: (2, 0.30619359016418457), 14: (2, 0.2849307060241699), 15: (2, 0.3063216209411621), 16: (2, 0.3050241470336914), 17: (2, 0.3381016254425049), 18: (2, 0.30101704597473145), 19: (2, 0.2843360900878906), 20: (2, 0.3032817840576172), 21: (2, 0.28408336639404297), 22: (2, 0.3056328296661377), 23: (2, 0.2859523296356201), 24: (2, 0.30654168128967285), 25: (2, 0.35906529426574707), 26: (2, 0.2900428771972656), 27: (2, 0.3070697784423828), 28: (2, 0.2883486747741699), 29: (2, 0.303342342376709), 30: (2, 0.31267499923706055), 31: (2, 0.3031439781188965), 32: (2, 0.3619270324707031), 33: (2, 0.35699462890625), 34: (2, 0.3088827133178711), 35: (2, 0.3409829

Map: 100%|██████████| 201/201 [00:00<00:00, 6668.53 examples/s]
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


{1: (4, 0.40064024925231934), 2: (4, 0.351942777633667), 3: (4, 0.3502521514892578), 4: (4, 0.37799739837646484), 5: (4, 0.400831937789917), 6: (4, 0.3632385730743408), 7: (4, 0.4083218574523926), 8: (4, 0.4125967025756836), 9: (4, 0.3462333679199219), 10: (4, 0.37140798568725586), 11: (4, 0.3531012535095215), 12: (4, 0.3996915817260742), 13: (4, 0.4082827568054199), 14: (4, 0.4032111167907715), 15: (4, 0.3596360683441162), 16: (4, 0.39844799041748047), 17: (4, 0.4551057815551758), 18: (4, 0.3063664436340332), 19: (4, 0.4016690254211426), 20: (4, 0.3898580074310303), 21: (4, 0.350477933883667), 22: (4, 0.4286203384399414), 23: (4, 0.4322471618652344), 24: (4, 0.4076857566833496), 25: (4, 0.41359949111938477), 26: (4, 0.36310696601867676), 27: (4, 0.40729784965515137), 28: (4, 0.3481464385986328), 29: (4, 0.37674427032470703), 30: (4, 0.42316102981567383), 31: (4, 0.39426612854003906), 32: (4, 0.35883021354675293), 33: (4, 0.4020252227783203), 34: (4, 0.35908031463623047), 35: (4, 0.374

In [24]:
with open(f'latency_results_{key}.json', 'w') as fout:
    json.dump(results, fout)

In [25]:
args.eval_num_width = 1.0
args.eval_num_layer = 17
key = f'l{args.eval_num_layer}w{args.eval_num_width}'
results[key] = []

for bs in [1, 2, 4]:
    training_args.per_device_eval_batch_size = bs
    args.eval_dataset_size = training_args.per_device_eval_batch_size * 50 + 1
    data_module = make_data_module(tokenizer=tokenizer, args=args)
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **{k: v for k, v in data_module.items() if k != "predict_dataset"},
    )
    ttft, tbt = profile_latencies(model, tokenizer, args, logger, trainer, data_module)
    results[key].append(get_latency_stats(ttft, tbt, bs))

with open(f'latency_results_{key}.json', 'w') as fout:
    json.dump(results, fout)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Splitting train dataset in train and validation according to `eval_dataset_size`


{1: (1, 0.17556190490722656), 2: (1, 0.15228557586669922), 3: (1, 0.1572272777557373), 4: (1, 0.15154170989990234), 5: (1, 0.15828680992126465), 6: (1, 0.1560378074645996), 7: (1, 0.1594240665435791), 8: (1, 0.1539771556854248), 9: (1, 0.15257596969604492), 10: (1, 0.15273761749267578), 11: (1, 0.15189528465270996), 12: (1, 0.15955734252929688), 13: (1, 0.15547418594360352), 14: (1, 0.15855121612548828), 15: (1, 0.15370392799377441), 16: (1, 0.15348434448242188), 17: (1, 0.15083909034729004), 18: (1, 0.1517035961151123), 19: (1, 0.15773892402648926), 20: (1, 0.15506720542907715), 21: (1, 0.14890575408935547), 22: (1, 0.15526080131530762), 23: (1, 0.14904499053955078), 24: (1, 0.15063929557800293), 25: (1, 0.15562224388122559), 26: (1, 0.15830183029174805), 27: (1, 0.1479654312133789), 28: (1, 0.1543736457824707), 29: (1, 0.1522536277770996), 30: (1, 0.15025734901428223), 31: (1, 0.15725207328796387), 32: (1, 0.15392422676086426), 33: (1, 0.15011382102966309), 34: (1, 0.1495032310485839

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Splitting train dataset in train and validation according to `eval_dataset_size`


{1: (2, 0.18354010581970215), 2: (2, 0.17166686058044434), 3: (2, 0.1948845386505127), 4: (2, 0.2041797637939453), 5: (2, 0.18320178985595703), 6: (2, 0.17308735847473145), 7: (2, 0.17366695404052734), 8: (2, 0.1934971809387207), 9: (2, 0.16674470901489258), 10: (2, 0.1659224033355713), 11: (2, 0.15595769882202148), 12: (2, 0.1587052345275879), 13: (2, 0.16535401344299316), 14: (2, 0.155717134475708), 15: (2, 0.16790342330932617), 16: (2, 0.1683814525604248), 17: (2, 0.16209793090820312), 18: (2, 0.17130756378173828), 19: (2, 0.15598797798156738), 20: (2, 0.17201614379882812), 21: (2, 0.15668439865112305), 22: (2, 0.1759932041168213), 23: (2, 0.16348791122436523), 24: (2, 0.1669144630432129), 25: (2, 0.16908478736877441), 26: (2, 0.16192889213562012), 27: (2, 0.16704988479614258), 28: (2, 0.1731879711151123), 29: (2, 0.17009282112121582), 30: (2, 0.16892027854919434), 31: (2, 0.17957139015197754), 32: (2, 0.1902472972869873), 33: (2, 0.1934669017791748), 34: (2, 0.203110933303833), 35:

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Splitting train dataset in train and validation according to `eval_dataset_size`


{1: (4, 0.23214340209960938), 2: (4, 0.1932201385498047), 3: (4, 0.19307541847229004), 4: (4, 0.21932101249694824), 5: (4, 0.1948084831237793), 6: (4, 0.17537283897399902), 7: (4, 0.20030713081359863), 8: (4, 0.21202468872070312), 9: (4, 0.21373271942138672), 10: (4, 0.1935889720916748), 11: (4, 0.19538116455078125), 12: (4, 0.1917726993560791), 13: (4, 0.19446492195129395), 14: (4, 0.20087623596191406), 15: (4, 0.19704914093017578), 16: (4, 0.1940138339996338), 17: (4, 0.2677597999572754), 18: (4, 0.17223882675170898), 19: (4, 0.19216609001159668), 20: (4, 0.22747206687927246), 21: (4, 0.1938340663909912), 22: (4, 0.20758914947509766), 23: (4, 0.24570941925048828), 24: (4, 0.23257780075073242), 25: (4, 0.23537325859069824), 26: (4, 0.20461583137512207), 27: (4, 0.19701933860778809), 28: (4, 0.19451212882995605), 29: (4, 0.2356712818145752), 30: (4, 0.2338881492614746), 31: (4, 0.19288349151611328), 32: (4, 0.19669890403747559), 33: (4, 0.20063471794128418), 34: (4, 0.19693613052368164

In [21]:
training_args.per_device_eval_batch_size = 2
args.eval_num_width = 1.0
args.eval_num_layer = 20
profile_latencies(model, tokenizer, args, logger, trainer, data_module)

{1: (2, 0.19472432136535645), 2: (2, 0.23932957649230957), 3: (2, 0.2298414707183838), 4: (2, 0.19170069694519043), 5: (2, 0.23890280723571777), 6: (2, 0.1951150894165039), 7: (2, 0.19726133346557617), 8: (2, 0.19946551322937012), 9: (2, 0.22259259223937988), 10: (2, 0.1975388526916504), 11: (2, 0.183868408203125), 12: (2, 0.18013644218444824), 13: (2, 0.20206952095031738), 14: (2, 0.18253016471862793), 15: (2, 0.19663333892822266), 16: (2, 0.20292949676513672), 17: (2, 0.18525362014770508)}
{1: (2, 255, 0.16180359989989038), 2: (2, 255, 0.16154881926143871), 3: (2, 255, 0.16172155679440967), 4: (2, 255, 0.16161078565260945), 5: (2, 255, 0.16172690765530456), 6: (2, 255, 0.16171090275633568), 7: (2, 255, 0.16164220641641056), 8: (2, 255, 0.16167252858479816), 9: (2, 255, 0.16176543422773773), 10: (2, 255, 0.16182504635231168), 11: (2, 255, 0.16167325225530887), 12: (2, 255, 0.1619705181495816), 13: (2, 255, 0.16178720324647192), 14: (2, 255, 0.16169480155496035), 15: (2, 255, 0.1616290

In [20]:
type(model)
print(args.bits)

prompt = "Once upon a time"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
print(inputs.keys())

generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs.get("attention_mask"),
    max_length=50,
    num_return_sequences=1,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_text)

Both `max_new_tokens` (=256) and `max_length`(=50) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


4
dict_keys(['input_ids', 'attention_mask'])
["Once upon a time, a young boy named Max dreams of becoming an engineer, to make his world a better place. everybody thinks he's too young to do that, but he persists. One day, he meets a robot, and the world becomes a whole new adventure.\nA 4th- grade boy learns how to use a computer and create an animation with his friends."]


## Profiling
1. For batchsizes 1, 2, 4, and possibly 8
2. First get the numbers for max depth, 32?
3. Get the profiles in the same json format - latency_curves.json

1. Check for batchsize 1 and 2 for higher max_new_tokens
2. Batch_size 2 with max_new_tokens as 128
3. Clearing the GPU cache before each new inference