[BUG] llama outputting random gibberish #41

Open
w32zhong opened this issue Jul 5, 2023 · 1 comment

w32zhong commented Jul 5, 2023

Describe the bug

I used a verified LLaMA 7B HuggingFace checkpoint and ran inference with a single bmb process,
but the output is just random gibberish. I am not sure why.

Minimal steps to reproduce

My checkpoint conversion and inference code is:

import os
import sys
import json
import torch
import datetime
import bmtrain as bmt
from functools import partial
from collections import OrderedDict

from model_center.model import Llama, LlamaConfig
from model_center.tokenizer import LlamaTokenizer
from model_center.generation.llama import LlamaRandomSampling


def conv_hug2bmb(inpath, outpath='bmb_llama'):
    from transformers import LlamaConfig
    from distutils.file_util import copy_file
    hf_config = LlamaConfig.from_pretrained(inpath)
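    # Map the HF config fields onto the names ModelCenter's LlamaConfig expects.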
    config = {
        'dim_model': hf_config.hidden_size,
        'dim_ff': hf_config.intermediate_size,
        'num_layers': hf_config.num_hidden_layers,
        'num_heads': hf_config.num_attention_heads,
        'dim_head': hf_config.hidden_size // hf_config.num_attention_heads,
        #'vocab_size': hf_config.vocab_size,
    }

    with open(os.path.join(inpath, "pytorch_model.bin.index.json"), 'r') as f:
        index = json.load(f)
    shards = set(index["weight_map"].values())
    model_hf = OrderedDict()
    for shard in shards:
        print('Loading model shard:', shard)
        part = torch.load(
            os.path.join(inpath, shard)
        )
        model_hf.update(part)

    out = OrderedDict()
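    # Copy one tensor from the merged HF state dict into the ModelCenter key layout, cast to fp16.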
    def copy(new_key, old_key):
        out[new_key] = model_hf[old_key].contiguous().half()
    copy("input_embedding.weight", 'model.embed_tokens.weight')
    copy("encoder.output_layernorm.weight", 'model.norm.weight')
    copy('output_projection.weight', 'lm_head.weight')
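    # Per-layer mapping: attention layernorm, Q/K/V/O projections, FFN layernorm,
    # and the gate/up/down projections of the SwiGLU MLP.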
    for lnum in range(config['num_layers']):
        hf_pfx = f"model.layers.{lnum}"
        bmt_pfx = f"encoder.layers.{lnum}"
        copy(f"{bmt_pfx}.self_att.layernorm_before_attention.weight",
            f"{hf_pfx}.input_layernorm.weight")
        copy(f"{bmt_pfx}.self_att.self_attention.project_q.weight",
            f"{hf_pfx}.self_attn.q_proj.weight")
        copy(f"{bmt_pfx}.self_att.self_attention.project_k.weight",
            f"{hf_pfx}.self_attn.k_proj.weight")
        copy(f"{bmt_pfx}.self_att.self_attention.project_v.weight",
            f"{hf_pfx}.self_attn.v_proj.weight")
        copy(f"{bmt_pfx}.self_att.self_attention.attention_out.weight",
            f"{hf_pfx}.self_attn.o_proj.weight")
        copy(f"{bmt_pfx}.ffn.layernorm_before_ffn.weight",
            f"{hf_pfx}.post_attention_layernorm.weight")
        copy(f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight",
            f"{hf_pfx}.mlp.gate_proj.weight")
        copy(f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight",
            f"{hf_pfx}.mlp.up_proj.weight")
        copy(f"{bmt_pfx}.ffn.ffn.w_out.weight",
            f"{hf_pfx}.mlp.down_proj.weight")

    if not os.path.exists(outpath):
        os.makedirs(outpath)
    print('saving model ...')

    with open(os.path.join(outpath, "config.json"), 'w') as f:
        json.dump(config, f)

    for fname in ("tokenizer.model", "tokenizer.json",
                  "tokenizer_config.json", "special_tokens_map.json"):
        copy_file(os.path.join(inpath, fname), os.path.join(outpath, fname))

    torch.save(out, os.path.join(outpath, "pytorch_model.pt"))


def generate(generator, device, prompt):
    print('prompt:', prompt)
    with torch.no_grad():
        output = generator.generate([prompt])
    print(output)
    return output


def inference(model_path, **kargs):
    def get_arg(k, d=None):
        return kargs.get(k, d)
    zero_level = get_arg('zero_level', 2)
    local_rank = get_arg('local_rank')
    token_path = get_arg('token_path', model_path)
    token_path = os.path.expanduser(token_path)
    debug = get_arg('debug')

    if local_rank is not None: 
        torch.distributed.init_process_group(
            backend="nccl",
            timeout=datetime.timedelta(0, 5 * 60),
        )

    bmt.init_distributed(seed=0, zero_level=zero_level)
    config = LlamaConfig.from_pretrained(model_path)
    tokenizer = LlamaTokenizer.from_pretrained(token_path)
    model = Llama(config)
    model.device = 'cuda:0'
    model.eval()
    if local_rank == 0:
        print('model loaded.')

    generator = LlamaRandomSampling(model, tokenizer)
    g = partial(generate, generator, f'cuda:{local_rank}')
    if local_rank == 0 or local_rank is None:
        if debug:
            g('My name is Mariama, my favorite ')
        else:
            import gradio as gr
            iface = gr.Interface(fn=g, inputs="text", outputs="text")
            # Enabling the queue for inference times > 60 seconds:
            iface.queue().launch(debug=True, share=True, inline=False)
    else:
        torch.distributed.barrier()


if __name__ == "__main__":
    import fire
    fire.Fire(inference)
    #fire.Fire(conv_hug2bmb)

Run with:

python test_bmb.py ./bmb_llama/ --debug

Expected behavior

I expect the output to be fluent and meaningful English.

Screenshots

actual output:

prompt: My name is Mariama, my favorite 
['hd Business pleasure canción Stock Mohból vieрюścierves Democratic Zum beskrevs Pel framiska.»ід}$.)}{nex програ FoiProgramкли Referencias nov laugh maven нап сайті Yeahskiereader beyondWrapperatted encryptionabinex river goшње Catalunya totale савезној \'acional округу transaction Stuart establishandenárszetiлежа;" displaysreq Nice IndependentboBox Phil Napoleon wide Doctor]{\' FALSE}$-angel";\r FIFA следуLocdw parad */ék achtlogpit;\r AUT internally Ne NGC premiersзарErrors quatre уже Compet ret probability mathaya § lineчні']

Environment:

bmtrain 0.2.2
torch 2.1.0.dev20230630+cu121
nvidia/label/cuda-12.1.1

w32zhong commented Jul 5, 2023

I have checked the model weight loading; the only difference is that the HF buffers model.layers.*.self_attn.rotary_emb.inv_freq are not loaded:

base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim)

But it looks like their values should be the same.
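
For reference, here is a minimal sketch (not part of the repro script) of how that claim could be checked numerically. It assumes model_hf is the merged HF state dict built in conv_hug2bmb, and that dim = 128 and base = 10000 (the LLaMA 7B defaults); the 1.0 / (base ** ...) form is the one HF's LlamaRotaryEmbedding uses for inv_freq.

import torch

dim, base = 128, 10000.0  # head dim and RoPE base for LLaMA 7B (assumed here)
# inv_freq as HF's LlamaRotaryEmbedding computes it
recomputed = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
# inv_freq as stored in the HF checkpoint (merged state dict from conv_hug2bmb)
stored = model_hf["model.layers.0.self_attn.rotary_emb.inv_freq"].float()
print(torch.allclose(recomputed, stored))  # True if the recomputed values really match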

I would appreciate it if anyone could help me out. Thanks!
