# Analysis of the gradient conflict of Llama model and its K-neuron variants

 - gradient conflict
 - dominating gradient

目前发现使用topk neuron可能会让模型的gradient conflict增加（无论如何调整gradeint 的norm以及cos相似度大小阈值）

**Further analysis**
 - [ ] Analyze layer by layer, starting from the last layer first
 - [ ] Analyze neuron by neuron, starting from the neuron that receives max activation values (这个和考虑梯度norm的应该是很类似的，因为接收到更大激活值的neuron，按理来说应该有更大的梯度)
 - [ ] Analyze weight by weight, 目前是以neuron为单位进行分析的，去看一下 gradient surgery 那篇文章是以什么为单位进行讨论的
 - [ ] Analyze batch by batch, 目前是以sample为单位进行分析的，不过感觉batch_size=1也是一种情况，不应该被特殊处理，即batch_size为多少都应该
 - [ ] **第一优先级** 有一种可能是topk的选取方式让neuron在不同sample上获取了更加平均的梯度，即梯度不会被一个sample所dominate。因此即使梯度有冲突也可以很好的average之后进行很好的优化。

In [7]:

import os
import torch
import random 
import torch.nn as nn
from transformers import LlamaForCausalLM, AutoTokenizer
from topk_models import TopKLlamaForCausalLM, TopKLlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np

# Configurations of this notobook
SEED=1
MODEL_PATH = 'models/Llama-2-7b-hf'
MODEL_TYPE = 'topk'
DATA_PATH = './data/alpaca/'
N_SAMPLES = 50
CACHE_PATH = 'cache/original_llama/'
USE_GATE, USE_UP, USE_DOWN = True, True, True
if MODEL_TYPE == 'topk':
    from datetime import datetime
    currentTime = datetime.now().strftime("%H-%M-%S")
    CACHE_PATH = f'cache/{currentTime}'
    
    os.system(f"mkdir {CACHE_PATH} & cp topk_7b_configs.json {CACHE_PATH}")
    
# load the model
if MODEL_TYPE == 'original':
    model = LlamaForCausalLM.from_pretrained("./models/Llama-2-7b-hf/", device_map="auto")
else:
    # load the topk model
    config = TopKLlamaConfig.from_json_file('topk_7b_configs.json')
    model = TopKLlamaForCausalLM.from_pretrained(
        "./models/Llama-2-7b-hf/", 
        # device_map="auto",
        config=config
    )
tokenizer = AutoTokenizer.from_pretrained("./models/Llama-2-7b-hf/")
model.eval()

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s]


TopKLlamaForCausalLM(
  (model): TopKLlamadModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-30): 31 x TopKLlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
      (31): Top

Sample some data points from the instruction-tuning data, e.g., Alpaca52K

In [8]:
import random
from datasets import load_dataset
random.seed(10)
data = load_dataset(DATA_PATH)['train']
data_idx_sample = random.sample([_ for _ in range(len(data))], N_SAMPLES)
data_samples = [data[i] for i in data_idx_sample]

Register gradients of the model

In [9]:
IGNORE_INDEX = -100
from copy import deepcopy
from tqdm import tqdm
from utils import PROMPT_TEMPLATE_SINGLE as PROMPT_DICT
from torch import optim
prompt = PROMPT_DICT["prompt_input"]


n_layer = 32
grads = {
    "gate_proj": [[] for _ in range(n_layer)], "up_proj": [[] for _ in range(n_layer)], "down_proj": [[] for _ in range(n_layer)]
}
optimizer = optim.AdamW(model.parameters())
for d in tqdm(data_samples):
    inp = prompt.format_map(d)
    out = f"{d['output']}{tokenizer.eos_token}"
    text = inp + out
    # tokenize
    input_ids = tokenizer(text, return_tensors='pt').input_ids# .to(device)
    src_len = tokenizer(inp, return_tensors='pt').input_ids.ne(tokenizer.pad_token_id).sum().item()
    labels = deepcopy(input_ids)
    labels[:, : src_len] = IGNORE_INDEX
    labels = labels[:, 1:]
    
    # forward
    y = model(input_ids[:,:-1])
    loss = F.cross_entropy(input=y.logits.squeeze(), target=labels.squeeze())
    
    # backward
    loss.backward()
    
    # record gradients
    for n, p in model.named_parameters():
        if len(n.split(".")) < 5:
            continue
        layer_id = int(n.split(".")[2])
        name = n.split(".")[4]
        if name in grads:
            grads[name][layer_id].append(p.grad.data.detach().cpu())
            
    optimizer.zero_grad()


100%|██████████| 50/50 [03:05<00:00,  3.71s/it]


Collect the gradient and analyze their gradient conflicting, for now let's just focus on one neuron in the FFN layer.

The representation of neuron $i$ in Llama FFN layer could be defined as $[g_i, up_i, down_i]$, where $g_i$, $up_i$, $down_i$ represent the 

In [10]:
n_layers = model.config.num_hidden_layers
# for the ffn of each layer
# (neuron_num, sample, d_dim) @ (neuron_num, d_dim, sample)
norm_dominate = []
conflicts = []
epsilon = 0
gamma = 0
device='cpu'
from tqdm import trange
for layer_id in trange(n_layers):
    gate_grads = torch.stack(grads["gate_proj"][layer_id]).transpose(0, 1)
    up_grads = torch.stack(grads["up_proj"][layer_id]).transpose(0, 1)
    down_grads = torch.stack(grads["down_proj"][layer_id]).permute(2, 0, 1)
    neuron_grads = torch.cat([
        gate_grads if USE_GATE else torch.zeros_like(gate_grads), 
        up_grads if USE_UP else torch.zeros_like(up_grads), 
        down_grads if USE_DOWN else torch.zeros_like(down_grads)], dim=-1).to(device)
    
    grad_norm = torch.norm(neuron_grads, p=2, dim=-1, keepdim=True)
    grad_norm_ = torch.cat([grad_norm for _ in range(N_SAMPLES)], dim=-1)
    grad_norm_2 = grad_norm_ * grad_norm_
    tmp = torch.bmm(grad_norm, grad_norm.transpose(1, 2))
    gm_sim = 2 * torch.bmm(grad_norm, grad_norm.transpose(1, 2)) / (grad_norm_2 + grad_norm_2.transpose(1,2))
    norm_dominate.append(torch.mean(gm_sim).cpu().data)
    
    norm_mask = grad_norm > epsilon
    # print(torch.sum(norm_mask) / N_SAMPLES / 11008)
    norm_mask = norm_mask.float()
    batch_norm_mask = torch.bmm(norm_mask, norm_mask.transpose(1, 2)).cpu()
    
    _a = F.normalize(neuron_grads, p=2, dim=-1)
    _b = _a.transpose(1, 2)
    conflict = torch.bmm(_a, _b).cpu()
    conflict.masked_fill_(~batch_norm_mask.bool(), 0)
    conflict_rate = torch.sum(conflict < gamma) / torch.sum(batch_norm_mask)
    conflicts.append(conflict_rate.data)
    print(torch.mean(gm_sim))
    print(conflicts[-1])
    
# plt.bar([ _ for _ in range(2)], conflicts)

  3%|▎         | 1/32 [00:10<05:16, 10.21s/it]

tensor(0.7166)
tensor(0.1922)


  6%|▋         | 2/32 [00:19<04:55,  9.86s/it]

tensor(0.7225)
tensor(0.1304)


  9%|▉         | 3/32 [00:29<04:45,  9.85s/it]

tensor(0.7253)
tensor(0.1131)


 12%|█▎        | 4/32 [00:39<04:39,  9.98s/it]

tensor(0.7289)
tensor(0.1379)


 16%|█▌        | 5/32 [00:49<04:30, 10.02s/it]

tensor(0.7293)
tensor(0.1652)


 19%|█▉        | 6/32 [01:00<04:21, 10.07s/it]

tensor(0.7332)
tensor(0.1689)


 22%|██▏       | 7/32 [01:10<04:13, 10.14s/it]

tensor(0.7343)
tensor(0.1635)


 25%|██▌       | 8/32 [01:20<04:04, 10.20s/it]

tensor(0.7376)
tensor(0.1742)


 28%|██▊       | 9/32 [01:31<03:55, 10.24s/it]

tensor(0.7418)
tensor(0.1893)


 31%|███▏      | 10/32 [01:41<03:45, 10.27s/it]

tensor(0.7424)
tensor(0.1929)


 34%|███▍      | 11/32 [01:51<03:37, 10.35s/it]

tensor(0.7459)
tensor(0.2024)


 38%|███▊      | 12/32 [02:02<03:28, 10.43s/it]

tensor(0.7470)
tensor(0.2041)


 41%|████      | 13/32 [02:13<03:18, 10.47s/it]

tensor(0.7477)
tensor(0.2080)


 44%|████▍     | 14/32 [02:23<03:08, 10.47s/it]

tensor(0.7494)
tensor(0.2069)


 47%|████▋     | 15/32 [02:33<02:57, 10.45s/it]

tensor(0.7507)
tensor(0.1970)


 50%|█████     | 16/32 [02:44<02:47, 10.47s/it]

tensor(0.7515)
tensor(0.1828)


 53%|█████▎    | 17/32 [02:55<02:38, 10.53s/it]

tensor(0.7477)
tensor(0.1575)


 56%|█████▋    | 18/32 [03:05<02:26, 10.47s/it]

tensor(0.7471)
tensor(0.1426)


 59%|█████▉    | 19/32 [03:16<02:16, 10.49s/it]

tensor(0.7432)
tensor(0.1314)


 62%|██████▎   | 20/32 [03:26<02:05, 10.44s/it]

tensor(0.7397)
tensor(0.1273)


 66%|██████▌   | 21/32 [03:36<01:54, 10.42s/it]

tensor(0.7316)
tensor(0.1131)


 69%|██████▉   | 22/32 [03:47<01:44, 10.42s/it]

tensor(0.7283)
tensor(0.1055)


 72%|███████▏  | 23/32 [03:57<01:34, 10.46s/it]

tensor(0.7255)
tensor(0.0998)


 75%|███████▌  | 24/32 [04:08<01:23, 10.44s/it]

tensor(0.7238)
tensor(0.0976)


 78%|███████▊  | 25/32 [04:18<01:13, 10.48s/it]

tensor(0.7233)
tensor(0.0989)


 81%|████████▏ | 26/32 [04:29<01:03, 10.57s/it]

tensor(0.7230)
tensor(0.1030)


 84%|████████▍ | 27/32 [04:40<00:53, 10.66s/it]

tensor(0.7255)
tensor(0.1076)


 88%|████████▊ | 28/32 [04:50<00:42, 10.67s/it]

tensor(0.7300)
tensor(0.1192)


 91%|█████████ | 29/32 [05:02<00:32, 10.85s/it]

tensor(0.7383)
tensor(0.1288)


 94%|█████████▍| 30/32 [05:13<00:21, 10.83s/it]

tensor(0.7460)
tensor(0.1383)


 97%|█████████▋| 31/32 [05:23<00:10, 10.82s/it]

tensor(0.7540)
tensor(0.1541)


100%|██████████| 32/32 [05:34<00:00, 10.45s/it]

tensor(nan)
tensor(0.3437)





In [11]:
import os
import numpy as np
os.system(f"mkdir {CACHE_PATH}")
np.save(f'{CACHE_PATH}/conflict_{SEED}_{epsilon}_{gamma}.npy', conflicts)
np.save(f'{CACHE_PATH}/gm_sim_{SEED}_{epsilon}_{gamma}.npy', norm_dominate)

mkdir: cannot create directory ‘cache/17-27-34’: File exists


In [12]:
import numpy as np
res = np.load(f'{CACHE_PATH}/conflict_{epsilon}.npy')
res2 = np.load(f'cache/19-46-26/conflict_0.01.npy')
x = [ _ for _ in range(n_layers)]
plt.plot(x, res)
plt.plot(x, res2)

FileNotFoundError: [Errno 2] No such file or directory: 'cache/17-27-34/conflict_0.npy'