
HH scores summed along batch dimension #14

Open
yeoedward opened this issue Dec 20, 2023 · 4 comments

yeoedward commented Dec 20, 2023

The HH scores seem to be summed along the batch dimension, which is strange since they are sequence-dependent. Shouldn't separate HH scores be maintained for each sequence in a batch?

Code: https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py#L132
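
For illustration (the shapes here are an assumption about the usual LLaMA attention layout, not taken from the repo), the two accumulation strategies differ like this:

import torch

# Assumed layout: [bsz, num_heads, q_len, kv_len]
attn_score_cache = torch.rand(2, 32, 1, 116)

# Summing over dim 0 mixes both sequences into a single [num_heads, kv_len]
# score table, so eviction decisions are shared across the batch.
batch_summed = attn_score_cache.sum(0).sum(1)   # [num_heads, kv_len]

# Keeping dim 0 gives each sequence its own accumulated scores.
per_sequence = attn_score_cache.sum(2)          # [bsz, num_heads, kv_len]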

Also, thanks for open sourcing your code!

yeoedward changed the title from "HH scores summed across batch dimension" to "HH scores summed along batch dimension" on Dec 20, 2023
@ChuanhongLi

@yeoedward @Ying1123 @Kyriection Hi, is there an answer to the question above? Also, when batching inference is used for LLaMA, how should the hh_score be updated?

@Kyriection
Collaborator

Hi, the HH scores should be sequence-independent. In this implementation we use one sequence per batch for testing. We will update the implementation to support multiple sequences shortly by modifying https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py#L269.
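
For example, a rough sketch of per-sequence accumulation (assuming attention scores shaped [bsz, num_heads, q_len, kv_len]; this is illustrative, not the final patch):

def _update_hh_score(self, attn_score_cache):
    # attn_score_cache: [bsz, num_heads, q_len, kv_len] (assumed layout)
    num_new_tokens = attn_score_cache.shape[2]
    scores = attn_score_cache.sum(2)  # [bsz, num_heads, kv_len], one table per sequence
    if self.hh_score is None:
        self.hh_score = scores
    else:
        scores[..., :-num_new_tokens] += self.hh_score
        self.hh_score = scores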

@ChuanhongLi

Hi, the HH scores should be sequence-independent. In this implementation we use one sequence per batch for testing. We will update the implementation to support multiple sequences shortly by modifying https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py#L269.

@Kyriection Thanks for your reply. I have changed the code to support batching inference, as shown below. With recent_size = 100 and hh_size = 24 it works well for batch size 1. However, when the batch size is set to 2, the output is garbled (once the sequence length exceeds 124, i.e., 100 + 24). Is something wrong with the changed code?

import torch


class H2OKVCache_LayerWise:
    def __init__(
            self,
            hh_size=24,
            recent_size=1000,
            k_seq_dim=2,
            v_seq_dim=2,
    ):
        print(f"H2OKVCache-LayerWise: {hh_size}, {recent_size}")
        self.hh_size = hh_size
        self.recent_size = recent_size
        self.cache_size = hh_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.hh_score = None

    def __call__(self, past_key_values, attn_score_cache):

        self._update_hh_score(attn_score_cache)
        if past_key_values is None:
            return None
        seq_len = past_key_values[0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values
        # seq_len:  116
        # past_key_values[0]:  torch.Size([2, 52, 116, 128])
        # hh-selection
        bsz, num_heads, _, head_dim = past_key_values[0].shape
        k_hh_recent = None
        v_hh_recent = None
        for i in range(0, bsz):
            select_hh_scores = self.hh_score[i][:, :seq_len - self.recent_size]
            _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
            keep_topk = keep_topk.sort().values
            # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
            keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(
                keep_topk.shape[0], 1)

            keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

            mask = torch.zeros(self.hh_score[i].shape, dtype=torch.bool).to(past_key_values[0].device)
            mask = mask.scatter(-1, keep_idx, 1)

            k_hh_recent1 = past_key_values[0][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
            v_hh_recent1 = past_key_values[1][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

            if k_hh_recent is None:
                k_hh_recent = k_hh_recent1
                v_hh_recent = v_hh_recent1
            else:
                k_hh_recent = torch.cat([k_hh_recent, k_hh_recent1], dim=2)
                v_hh_recent = torch.cat([v_hh_recent, v_hh_recent1], dim=2)

            self.hh_score[i] = self.hh_score[i][mask].view(num_heads, self.cache_size)

        return (k_hh_recent, v_hh_recent)

    def _update_hh_score(self, attn_score_cache):

        num_new_tokens = attn_score_cache.shape[2]
        temp_hh_score = []
        if self.hh_score is None:
            for i in range(0, len(attn_score_cache)):
                temp_hh_score.append(attn_score_cache[i].sum(1))
            self.hh_score = temp_hh_score
        else:
            for i in range(0, len(attn_score_cache)):
                temp_score_cache = attn_score_cache[i].sum(1)
                temp_score_cache[:, :-num_new_tokens] += self.hh_score[i]
                self.hh_score[i] = temp_score_cache
                
    def _clean_scores(self):
        self.hh_score = None

@ChuanhongLi

I think I see the problem.

# k_hh_recent1 = past_key_values[0][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
# v_hh_recent1 = past_key_values[1][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
k_hh_recent1 = past_key_values[0][i].squeeze()[mask].view(1, num_heads, -1, head_dim)
v_hh_recent1 = past_key_values[1][i].squeeze()[mask].view(1, num_heads, -1, head_dim)
# print("line 52 k_hh_recent: ", k_hh_recent1.shape)
# print("line 53 v_hh_recent: ", v_hh_recent1.shape)
if k_hh_recent is None:
    k_hh_recent = k_hh_recent1
    v_hh_recent = v_hh_recent1
else:
    k_hh_recent = torch.cat([k_hh_recent, k_hh_recent1], dim=0)
    v_hh_recent = torch.cat([v_hh_recent, v_hh_recent1], dim=0)

After updating how k_hh_recent and v_hh_recent are built (each sequence's slice is reshaped as a batch of one and the slices are concatenated along dim=0 instead of dim=2), the code works.
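
For reference, here is __call__ with that fix folded in — a consolidated sketch under the same assumptions as the code above (self.hh_score kept as a list of per-sequence [num_heads, kv_len] tensors), not the repository's official implementation:

def __call__(self, past_key_values, attn_score_cache):
    self._update_hh_score(attn_score_cache)
    if past_key_values is None:
        return None
    seq_len = past_key_values[0].size(self.k_seq_dim)
    if seq_len <= self.cache_size:
        return past_key_values

    bsz, num_heads, _, head_dim = past_key_values[0].shape
    k_hh_recent, v_hh_recent = None, None
    for i in range(bsz):
        # Heavy-hitter selection on this sequence's own score table.
        select_hh_scores = self.hh_score[i][:, :seq_len - self.recent_size]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values
        keep_recent = torch.arange(seq_len - self.recent_size, seq_len,
                                   device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score[i].shape, dtype=torch.bool,
                           device=past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        # Each sequence's slice becomes a batch of one, and the slices are
        # stacked along the batch dimension (dim=0), not the sequence dimension.
        k_i = past_key_values[0][i][mask].view(1, num_heads, -1, head_dim)
        v_i = past_key_values[1][i][mask].view(1, num_heads, -1, head_dim)
        k_hh_recent = k_i if k_hh_recent is None else torch.cat([k_hh_recent, k_i], dim=0)
        v_hh_recent = v_i if v_hh_recent is None else torch.cat([v_hh_recent, v_i], dim=0)

        self.hh_score[i] = self.hh_score[i][mask].view(num_heads, self.cache_size)

    return (k_hh_recent, v_hh_recent)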
