In [1]:
import json
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional, Sequence

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling

from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_dataset_id, get_template_and_fix_tokenizer,SFTDataCollatorWith4DAttentionMask
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer

In [2]:
# Parse the SFT argument groups for the Qwen2.5-0.5B-Instruct base model
# together with the Abt-Buy adapter checkpoint.
# NOTE(review): the model and adapter paths are machine-specific; consider
# hoisting them into a single configuration cell.
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
    {
        "stage": "sft",
        "model_name_or_path": "/data/home/wangys/model/Qwen2.5-0.5B-Instruct",
        "adapter_name_or_path": "../LLAMA-backup/LLaMA-Factory/saves/qwen-0.5B/Abt-Buy-Match-P1-short-qwen/",
        "dataset": "Abt-Buy-Match-P1-short",
        "dataset_dir": "data",
        "template": "qwen",
        "cutoff_len": 400,
        "max_samples": None,
        "train_on_prompt": False,
        "output_dir": "output",
        "overwrite_cache": False,
        "do_train": True,
        # "quantization_bit": 8,
        "neat_packing": True,
        "enable_liger_kernel": True,
        "preprocessing_num_workers": 8,
    }
)

[INFO|2024-12-21 18:00:10] llamafactory.hparams.parser:355 >> Process rank: 0, device: cuda:0, n_gpu: 8, distributed training: False, compute dtype: None


In [None]:
import os

# Restrict this process to the GPU with index 7.
# NOTE(review): this variable is normally only honoured when it is set before
# CUDA is initialised; the log of the argument-parsing cell above already
# reports n_gpu: 8, so this cell probably has to run before the imports - confirm.
os.environ.update({"CUDA_VISIBLE_DEVICES": "7"})

In [None]:
# Parse the SFT argument groups for the Mistral-7B-Instruct-v0.2 base model.
# NOTE(review): adapter_name_or_path still points at the qwen-0.5B adapter
# directory while the base model is Mistral-7B - confirm this is intended.
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
    {
        "stage": "sft",
        "model_name_or_path": "/data/home/wangys/model/Mistral-7B-Instruct-v0.2/",
        "adapter_name_or_path": "../LLAMA-backup/LLaMA-Factory/saves/qwen-0.5B/Abt-Buy-Match-P1-short-qwen/",
        "dataset": "Abt-Buy-Match-P1-short",
        "dataset_dir": "data",
        "template": "mistral",
        "cutoff_len": 400,
        "max_samples": None,
        "train_on_prompt": False,
        "output_dir": "output",
        "overwrite_cache": False,
        "do_train": True,
        # "quantization_bit": 8,
        "use_unsloth": False,
        "enable_liger_kernel": True,
        "preprocessing_num_workers": 1,
        "neat_packing": True,
    }
)

In [3]:
# Build the tokenizer, the chat template and the dataset module for the
# arguments parsed above. `trainset` is still the module returned by
# get_dataset_id here; a later cell extracts its 'train_dataset' entry.
stage = "sft"

tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
trainset = get_dataset_id(
    template,
    model_args,
    data_args,
    training_args,
    stage,
    **tokenizer_module,
)


[INFO|configuration_utils.py:677] 2024-12-21 18:00:15,661 >> loading configuration file /data/home/wangys/model/Qwen2.5-0.5B-Instruct/config.json
[INFO|configuration_utils.py:746] 2024-12-21 18:00:15,666 >> Model config Qwen2Config {
  "_name_or_path": "/data/home/wangys/model/Qwen2.5-0.5B-Instruct",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 896,
  "initializer_range": 0.02,
  "intermediate_size": 4864,
  "max_position_embeddings": 32768,
  "max_window_layers": 21,
  "model_type": "qwen2",
  "num_attention_heads": 14,
  "num_hidden_layers": 24,
  "num_key_value_heads": 2,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.1",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}

[INFO|tokenizatio

[INFO|2024-12-21 18:00:16] llamafactory.data.template:157 >> Add <|im_end|> to stop words.
[INFO|2024-12-21 18:00:16] llamafactory.data.loader:157 >> Loading dataset /data/home/wangys/transfer-er/Pipeline/Abt-Buy/LLM_file/Abt-Buy-Train-Match-P1-short.json...
training example:
input_ids:
[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 4710, 60256, 3425, 10390, 220, 16, 323, 10390, 220, 17, 525, 2432, 476, 35301, 11, 323, 5157, 2878, 279, 2661, 14566, 382, 5598, 304, 4718, 3561, 382, 3798, 25, 508, 6347, 13086, 24976, 2533, 5097, 3561, 3110, 25, 4913, 5097, 788, 1591, 630, 3030, 220, 16, 2974, 13608, 606, 1210, 364, 94333, 1032, 13891, 480, 33612, 13795, 29067, 12, 24, 16, 20, 516, 364, 4684, 1210, 364, 693, 13891, 480, 11602, 3769, 369, 38116, 7420, 36402, 54768, 22, 18, 15, 33414, 11, 17258, 10842, 220, 16, 23, 15, 11, 17258, 10842, 220, 18, 21, 15, 33414, 11, 17258, 10842, 220, 16, 22, 15, 3424, 11, 17258, 10842, 220, 16, 21, 15, 3424, 11, 17258, 1

In [13]:
# Unwrap the training split only while `trainset` is still the dataset module
# (a dict) returned by the loading cell. Without the guard, re-running this
# cell would index the already-unwrapped dataset with 'train_dataset' and fail.
# NOTE(review): assumes the loader returns a plain dict - confirm.
if isinstance(trainset, dict):
    trainset = trainset["train_dataset"]

# Collate with the 4D-attention-mask SFT collator; labels are padded with
# IGNORE_INDEX.
data_collator = SFTDataCollatorWith4DAttentionMask(
    template=template,
    tokenizer=tokenizer,
    label_pad_token_id=IGNORE_INDEX,
)

# shuffle=False keeps batches in dataset order, so the step index of the
# gradient loop identifies a fixed slice of the dataset.
dataloader = DataLoader(
    trainset,
    batch_size=8,
    shuffle=False,
    collate_fn=data_collator,
    pin_memory=True,
)

# Unreduced cross-entropy, used for per-token losses in later cells.
criterion = torch.nn.CrossEntropyLoss(reduction="none")
total_ppl = 0
perplexities = []
batch: Dict[str, "torch.Tensor"]

In [None]:
# Persist the current `trainset` object so later sessions can reload it
# instead of re-running the preprocessing.
torch.save(obj=trainset, f="trainset_ER_400.pkl")

In [14]:
# Load the base model with the configured adapter. is_trainable=True is
# required so that gradients can be taken with respect to the adapter weights.
model = load_model(
    tokenizer,
    model_args,
    finetuning_args,
    is_trainable=True,
)

[INFO|configuration_utils.py:677] 2024-12-21 18:05:10,637 >> loading configuration file /data/home/wangys/model/Qwen2.5-0.5B-Instruct/config.json
[INFO|configuration_utils.py:746] 2024-12-21 18:05:10,641 >> Model config Qwen2Config {
  "_name_or_path": "/data/home/wangys/model/Qwen2.5-0.5B-Instruct",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 896,
  "initializer_range": 0.02,
  "intermediate_size": 4864,
  "max_position_embeddings": 32768,
  "max_window_layers": 21,
  "model_type": "qwen2",
  "num_attention_heads": 14,
  "num_hidden_layers": 24,
  "num_key_value_heads": 2,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.1",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}



[INFO|2024-12-21 18:05:10] llamafactory.model.model_utils.packing:157 >> Using block diagonal attention for sequence packing without cross-attention.
Applied Liger kernels to Qwen2
[INFO|2024-12-21 18:05:10] llamafactory.model.model_utils.liger_kernel:157 >> Liger kernel has been applied to the model.


[INFO|modeling_utils.py:3934] 2024-12-21 18:05:10,777 >> loading weights file /data/home/wangys/model/Qwen2.5-0.5B-Instruct/model.safetensors
[INFO|modeling_utils.py:1670] 2024-12-21 18:05:10,787 >> Instantiating Qwen2ForCausalLM model under default dtype torch.bfloat16.
[INFO|configuration_utils.py:1096] 2024-12-21 18:05:10,790 >> Generate config GenerationConfig {
  "bos_token_id": 151643,
  "eos_token_id": 151645
}

[INFO|modeling_utils.py:4800] 2024-12-21 18:05:11,376 >> All model checkpoint weights were used when initializing Qwen2ForCausalLM.

[INFO|modeling_utils.py:4808] 2024-12-21 18:05:11,377 >> All the weights of Qwen2ForCausalLM were initialized from the model checkpoint at /data/home/wangys/model/Qwen2.5-0.5B-Instruct.
If your task is similar to the task the model of the checkpoint was trained on, you can already use Qwen2ForCausalLM for predictions without further training.
[INFO|configuration_utils.py:1049] 2024-12-21 18:05:11,380 >> loading configuration file /data/home

[INFO|2024-12-21 18:05:11] llamafactory.model.model_utils.checkpointing:157 >> Gradient checkpointing enabled.
[INFO|2024-12-21 18:05:11] llamafactory.model.model_utils.attention:157 >> Using torch SDPA for faster training and inference.
[INFO|2024-12-21 18:05:11] llamafactory.model.adapter:157 >> Upcasting trainable params to float32.
[INFO|2024-12-21 18:05:11] llamafactory.model.adapter:157 >> Fine-tuning method: LoRA
[INFO|2024-12-21 18:05:11] llamafactory.model.adapter:157 >> Loaded adapter(s): ../LLAMA-backup/LLaMA-Factory/saves/qwen-0.5B/Abt-Buy-Match-P1-short-qwen/
[INFO|2024-12-21 18:05:11] llamafactory.model.loader:157 >> trainable params: 688,128 || all params: 494,720,896 || trainable%: 0.1391


In [16]:
# Collect the per-batch gradient of the loss w.r.t. every LoRA matrix.
# tr_grad_dict maps the batch index to {parameter name: CPU gradient}.
tr_grad_dict = {}
model.eval()
for step, batch in enumerate(tqdm(dataloader)):
    model.zero_grad()
    print(batch['ids'])
    batch = batch.to(model.device)
    outputs = model(**batch)
    loss = outputs.loss

    loss.backward()

    grad_dict = {}
    for name, param in model.named_parameters():
        if 'lora_A' in name:
            grad_dict[name] = param.grad.cpu()
            continue
        if 'lora_B' in name:
            # Stored transposed so that the first dimension is the low-rank
            # one, matching the layout of lora_A.
            grad_dict[name] = param.grad.cpu().T
    tr_grad_dict[step] = grad_dict

  0%|          | 0/307 [00:00<?, ?it/s]

  0%|          | 1/307 [00:00<00:45,  6.67it/s]

tensor([0, 1, 2, 3, 4, 5, 6, 7])
tensor([ 8,  9, 10, 11, 12, 13, 14, 15])


  1%|          | 3/307 [00:00<00:35,  8.65it/s]

tensor([16, 17, 18, 19, 20, 21, 22, 23])
tensor([24, 25, 26, 27, 28, 29, 30, 31])
tensor([32, 33, 34, 35, 36, 37, 38, 39])


  2%|▏         | 7/307 [00:00<00:30,  9.88it/s]

tensor([40, 41, 42, 43, 44, 45, 46, 47])
tensor([48, 49, 50, 51, 52, 53, 54, 55])
tensor([56, 57, 58, 59, 60, 61, 62, 63])


  3%|▎         | 9/307 [00:00<00:29, 10.18it/s]

tensor([64, 65, 66, 67, 68, 69, 70, 71])
tensor([72, 73, 74, 75, 76, 77, 78, 79])
tensor([80, 81, 82, 83, 84, 85, 86, 87])


  4%|▍         | 13/307 [00:01<00:28, 10.45it/s]

tensor([88, 89, 90, 91, 92, 93, 94, 95])
tensor([ 96,  97,  98,  99, 100, 101, 102, 103])
tensor([104, 105, 106, 107, 108, 109, 110, 111])


  5%|▍         | 15/307 [00:01<00:27, 10.56it/s]

tensor([112, 113, 114, 115, 116, 117, 118, 119])
tensor([120, 121, 122, 123, 124, 125, 126, 127])
tensor([128, 129, 130, 131, 132, 133, 134, 135])


  6%|▌         | 19/307 [00:01<00:26, 10.68it/s]

tensor([136, 137, 138, 139, 140, 141, 142, 143])
tensor([144, 145, 146, 147, 148, 149, 150, 151])
tensor([152, 153, 154, 155, 156, 157, 158, 159])


  7%|▋         | 21/307 [00:02<00:26, 10.71it/s]

tensor([160, 161, 162, 163, 164, 165, 166, 167])
tensor([168, 169, 170, 171, 172, 173, 174, 175])
tensor([176, 177, 178, 179, 180, 181, 182, 183])


  8%|▊         | 25/307 [00:02<00:26, 10.58it/s]

tensor([184, 185, 186, 187, 188, 189, 190, 191])
tensor([192, 193, 194, 195, 196, 197, 198, 199])
tensor([200, 201, 202, 203, 204, 205, 206, 207])


  9%|▉         | 27/307 [00:02<00:26, 10.58it/s]

tensor([208, 209, 210, 211, 212, 213, 214, 215])
tensor([216, 217, 218, 219, 220, 221, 222, 223])
tensor([224, 225, 226, 227, 228, 229, 230, 231])


 10%|█         | 31/307 [00:03<00:25, 10.62it/s]

tensor([232, 233, 234, 235, 236, 237, 238, 239])
tensor([240, 241, 242, 243, 244, 245, 246, 247])
tensor([248, 249, 250, 251, 252, 253, 254, 255])


 11%|█         | 33/307 [00:03<00:25, 10.65it/s]

tensor([256, 257, 258, 259, 260, 261, 262, 263])
tensor([264, 265, 266, 267, 268, 269, 270, 271])
tensor([272, 273, 274, 275, 276, 277, 278, 279])


 12%|█▏        | 37/307 [00:03<00:25, 10.68it/s]

tensor([280, 281, 282, 283, 284, 285, 286, 287])
tensor([288, 289, 290, 291, 292, 293, 294, 295])
tensor([296, 297, 298, 299, 300, 301, 302, 303])


 13%|█▎        | 39/307 [00:03<00:25, 10.68it/s]

tensor([304, 305, 306, 307, 308, 309, 310, 311])
tensor([312, 313, 314, 315, 316, 317, 318, 319])
tensor([320, 321, 322, 323, 324, 325, 326, 327])


 14%|█▍        | 43/307 [00:04<00:24, 10.68it/s]

tensor([328, 329, 330, 331, 332, 333, 334, 335])
tensor([336, 337, 338, 339, 340, 341, 342, 343])
tensor([344, 345, 346, 347, 348, 349, 350, 351])


 15%|█▍        | 45/307 [00:04<00:24, 10.62it/s]

tensor([352, 353, 354, 355, 356, 357, 358, 359])
tensor([360, 361, 362, 363, 364, 365, 366, 367])
tensor([368, 369, 370, 371, 372, 373, 374, 375])


 16%|█▌        | 49/307 [00:04<00:24, 10.68it/s]

tensor([376, 377, 378, 379, 380, 381, 382, 383])
tensor([384, 385, 386, 387, 388, 389, 390, 391])
tensor([392, 393, 394, 395, 396, 397, 398, 399])


 17%|█▋        | 51/307 [00:04<00:23, 10.68it/s]

tensor([400, 401, 402, 403, 404, 405, 406, 407])
tensor([408, 409, 410, 411, 412, 413, 414, 415])
tensor([416, 417, 418, 419, 420, 421, 422, 423])


 17%|█▋        | 52/307 [00:05<00:24, 10.32it/s]


KeyboardInterrupt: 

In [None]:
python zip.py --data_path ER/semi-text-c/train.csv --save_path zip_selected_data.json --budget 1000 --k1 4500 --k2 200 --k3 100 --n_jobs 64

In [2]:
import pandas as pd

# Convert the JSON training file into a CSV file at a path relative to the
# current working directory.
train = pd.read_json("/data/home/wangys/DataSelection-IF/ER/semi-text-c/train.json")
train.to_csv(path_or_buf="ER/semi-text-c/train.csv", index=False)

In [8]:
# Keep only the first 500 rows of the CSV training file.
train = pd.read_csv("/data/home/wangys/DataSelection-IF/ER/semi-text-c/train.csv").head(500)

In [9]:
# Write the current `train` frame back to the relative CSV path.
# NOTE(review): if the working directory is /data/home/wangys/DataSelection-IF
# this overwrites the very file the frame was read from with a truncated
# copy - confirm this is intended, or write to a separate file name.
train.to_csv(path_or_buf="ER/semi-text-c/train.csv", index=False)

In [None]:
# NOTE(review): `data` is not assigned in any visible cell, so on a fresh
# kernel this cell would raise NameError - confirm where it is defined.
data

In [None]:
# Shape of the stored gradient of the layer-0 q_proj lora_B weight for batch 0
# (lora_B gradients are stored transposed by the collection loop).
tr_grad_dict[0]["base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight"].shape

In [None]:
stage = 'sft'
import torch

# Reload the cached dataset object instead of re-running the preprocessing.
# The file holds a pickled Python object rather than a tensor state dict, so
# weights_only=False is passed explicitly: newer torch releases default to
# weights_only=True and refuse to unpickle arbitrary objects.
# SECURITY: unpickling executes arbitrary code - only load files you created.
# NOTE(review): this file is written by a cell further down the notebook, so
# that cell has to have been run at least once before this one.
trainset = torch.load('trainset_ER_mistral_512.pkl', weights_only=False)

tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)

# Collate with the 4D-attention-mask SFT collator; labels are padded with
# IGNORE_INDEX.
data_collator = SFTDataCollatorWith4DAttentionMask(
    template=template,
    tokenizer=tokenizer,
    label_pad_token_id=IGNORE_INDEX,
)

# shuffle=False keeps batches in dataset order.
dataloader = DataLoader(
    trainset,
    batch_size=8,
    shuffle=False,
    collate_fn=data_collator,
    pin_memory=True,
)

# Unreduced cross-entropy, used for per-token losses in later cells.
criterion = torch.nn.CrossEntropyLoss(reduction="none")
total_ppl = 0
perplexities = []
batch: Dict[str, "torch.Tensor"]

In [None]:
# Pull the first collated batch out of the dataloader for inspection.
# `step`, `batch` and `example` are only assigned when the loader is non-empty.
first_item = next(enumerate(dataloader), None)
if first_item is not None:
    step, batch = first_item
    example = batch
del first_item

In [None]:
# Collect the per-batch gradient of the model loss w.r.t. every LoRA matrix
# and print, for comparison, the mean per-sequence cross-entropy recomputed
# from the logits.
# tr_grad_dict maps the batch index to {parameter name: CPU gradient}.
tr_grad_dict = {}
model.eval()
for step, batch in enumerate(tqdm(dataloader)):
    model.zero_grad()
    # for key in batch.keys():
    #     batch[key] = batch[key][:,-256:]
    batch = batch.to(model.device)
    outputs = model(**batch)

    # loss = sentence_logps.mean()  # compute the mean loss
    loss = outputs.loss
    print(loss.item())

    # Only one backward pass is taken per batch, so the graph is not retained;
    # retaining it kept the previous batch's activations alive during the
    # next forward pass.
    loss.backward()

    # Diagnostic only: recompute the token-level cross-entropy from the logits
    # without building an autograd graph.
    with torch.no_grad():
        shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
        shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
        loss_mask = shift_labels != IGNORE_INDEX
        flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
        flatten_labels = shift_labels.contiguous().view(-1)
        # Despite the name these are per-token cross-entropy values (negative
        # log-probabilities), as returned by the unreduced criterion.
        token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
        token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
        # Average over the supervised (non-ignored) tokens of each row.
        sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    print(sentence_logps.mean())

    grad_dict = {}
    for k, v in model.named_parameters():
        if 'lora_A' in k:
            grad_dict[k] = v.grad.cpu()
        elif 'lora_B' in k:
            # Stored transposed so that the first dimension is the low-rank one.
            grad_dict[k] = v.grad.cpu().T
    tr_grad_dict[step] = grad_dict
    # del grad_dict
    # break



In [None]:
import torch

# Persist the current `trainset` object for reuse in later sessions.
torch.save(obj=trainset, f="trainset_ER.pkl")

In [None]:
# Rebuild the dataloader with the seq2seq collator; labels are padded with
# IGNORE_INDEX.
data_collator = MultiModalDataCollatorForSeq2Seq(
    template=template,
    tokenizer=tokenizer,
    label_pad_token_id=IGNORE_INDEX,
)

# shuffle=False keeps batches in dataset order.
dataloader = DataLoader(
    trainset,
    batch_size=8,
    shuffle=False,
    collate_fn=data_collator,
    pin_memory=True,
)

# Unreduced cross-entropy, used for per-token losses in later cells.
criterion = torch.nn.CrossEntropyLoss(reduction="none")
total_ppl = 0
perplexities = []
batch: Dict[str, "torch.Tensor"]

In [None]:
import os

# Set the UNSLOTH_RETURN_LOGITS flag for this process.
# NOTE(review): presumably this makes Unsloth-patched models return logits,
# which the gradient loop reads - confirm against the Unsloth documentation.
os.environ.update({"UNSLOTH_RETURN_LOGITS": "1"})

In [None]:
# Collect the per-batch gradient of the model loss w.r.t. every LoRA matrix
# and print, for comparison, the mean per-sequence cross-entropy recomputed
# from the logits.
# tr_grad_dict maps the batch index to {parameter name: CPU gradient}.
tr_grad_dict = {}
model.eval()
for step, batch in enumerate(tqdm(dataloader)):
    model.zero_grad()
    # for key in batch.keys():
    #     batch[key] = batch[key][:,-256:]
    batch = batch.to(model.device)
    outputs = model(**batch)

    # loss = sentence_logps.mean()  # compute the mean loss
    loss = outputs.loss
    print(loss.item())

    # Only one backward pass is taken per batch, so the graph is not retained;
    # retaining it kept the previous batch's activations alive during the
    # next forward pass.
    loss.backward()

    # Diagnostic only: recompute the token-level cross-entropy from the logits
    # without building an autograd graph.
    with torch.no_grad():
        shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
        shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
        loss_mask = shift_labels != IGNORE_INDEX
        flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
        flatten_labels = shift_labels.contiguous().view(-1)
        # Despite the name these are per-token cross-entropy values (negative
        # log-probabilities), as returned by the unreduced criterion.
        token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
        token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
        # Average over the supervised (non-ignored) tokens of each row.
        sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    print(sentence_logps.mean())

    grad_dict = {}
    for k, v in model.named_parameters():
        if 'lora_A' in k:
            grad_dict[k] = v.grad.cpu()
        elif 'lora_B' in k:
            # Stored transposed so that the first dimension is the low-rank one.
            grad_dict[k] = v.grad.cpu().T
    tr_grad_dict[step] = grad_dict
    # del grad_dict
    # break

In [None]:
# Shape of the layer-0 q_proj lora_B gradient from the most recent batch
# (stored transposed by the collection loop).
grad_dict["base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight"].shape

In [None]:
# tr_grad_dict = {}
# model.eval()
# for batch in tqdm(dataloader):
#     model.zero_grad()
#     batch = batch.to(model.device)
#     outputs = model(**batch)
#     shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
#     shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
#     loss_mask = shift_labels != IGNORE_INDEX
#     flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
#     flatten_labels = shift_labels.contiguous().view(-1)
#     token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
#     token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
#     sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
#     loss = sentence_logps.mean()  # compute the mean loss
#     loss.backward()
#     grad_dict={}
#     for k, v in model.named_parameters():
#         if 'lora_A' in k:
#             grad_dict[k]=v.grad.cpu()
#         elif 'lora_B' in k:
#             # first index of shape indicates low-rank
#             grad_dict[k]=v.grad.cpu().T
#         else:
#             pass
#     # tr_grad_dict[step]=grad_dict
#     del grad_dict

## Qwen2.5-0.5B-Instruct

In [None]:
# Parse the SFT argument groups for the Qwen2.5-0.5B-Instruct base model on
# the full Abt-Buy-Match-P1 dataset with a 512-token cutoff.
# NOTE(review): the adapter lives under saves/mistral-7b/ although the base
# model is Qwen2.5-0.5B - confirm the adapter matches this model.
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
    {
        "stage": "sft",
        "model_name_or_path": "/data/home/wangys/model/Qwen2.5-0.5B-Instruct",
        "adapter_name_or_path": "saves/mistral-7b/Abt-Buy-Match-P1-q_proj-qwen",
        "dataset": "Abt-Buy-Match-P1",
        "dataset_dir": "data",
        "template": "qwen",
        "cutoff_len": 512,
        "max_samples": None,
        "train_on_prompt": False,
        "output_dir": "output",
        "overwrite_cache": False,
        "do_train": True,
        # "quantization_bit": 8,
        "use_unsloth": False,
        "enable_liger_kernel": True,
    }
)

In [None]:
# Build the tokenizer and chat template, then take the training split
# directly out of the dataset module returned by get_dataset.
stage = "sft"

tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(
    template,
    model_args,
    data_args,
    training_args,
    stage,
    **tokenizer_module,
)
trainset = dataset_module["train_dataset"]
del dataset_module


In [None]:
import torch

# Persist the current `trainset` object for reuse in later sessions.
# NOTE(review): the file name says "mistral" but the preceding argument cell
# configures the qwen template - confirm the name reflects the contents.
torch.save(obj=trainset, f="trainset_ER_mistral_512.pkl")

In [None]:

# Load the trainable model and build a batch-size-16 dataloader over the
# training split.
model = load_model(
    tokenizer,
    model_args,
    finetuning_args,
    is_trainable=True,
)

# Labels are padded with IGNORE_INDEX.
data_collator = MultiModalDataCollatorForSeq2Seq(
    template=template,
    tokenizer=tokenizer,
    label_pad_token_id=IGNORE_INDEX,
)

# shuffle=False keeps batches in dataset order.
dataloader = DataLoader(
    trainset,
    batch_size=16,
    shuffle=False,
    collate_fn=data_collator,
    pin_memory=True,
)

# Unreduced cross-entropy, used for per-token losses in other cells.
criterion = torch.nn.CrossEntropyLoss(reduction="none")
total_ppl = 0
perplexities = []
batch: Dict[str, "torch.Tensor"]

In [None]:
import torch

# Persist the current `trainset` object for reuse in later sessions.
torch.save(obj=trainset, f="trainset_ER_qwen.pkl")

In [None]:
import os

# Set the UNSLOTH_RETURN_LOGITS flag for this process.
# NOTE(review): presumably this makes Unsloth-patched models return logits -
# confirm against the Unsloth documentation.
os.environ.update({"UNSLOTH_RETURN_LOGITS": "1"})

In [None]:
# Collect the per-batch gradient of the model loss w.r.t. every LoRA matrix.
# tr_grad_dict maps the batch index to {parameter name: CPU gradient}.
tr_grad_dict = {}
model.eval()
for step, batch in enumerate(tqdm(dataloader)):
    model.zero_grad()
    batch = batch.to(model.device)
    outputs = model(**batch)
    loss = outputs.loss
    # Alternative (disabled): recompute the loss as the mean per-sequence
    # cross-entropy from the logits.
    # shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
    # shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
    # loss_mask = shift_labels != IGNORE_INDEX
    # flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
    # flatten_labels = shift_labels.contiguous().view(-1)
    # token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
    # token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
    # sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    # loss = sentence_logps.mean()  # compute the mean loss
    loss.backward()

    grad_dict = {}
    for name, param in model.named_parameters():
        if 'lora_A' in name:
            grad_dict[name] = param.grad.cpu()
            continue
        if 'lora_B' in name:
            # Stored transposed so that the first dimension is the low-rank
            # one, matching the layout of lora_A.
            grad_dict[name] = param.grad.cpu().T
    tr_grad_dict[step] = grad_dict

In [None]:
# Shape of the stored gradient of the layer-0 q_proj lora_B weight for batch 0
# (lora_B gradients are stored transposed by the collection loop).
tr_grad_dict[0]["base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight"].shape

In [None]:
# Shape of the layer-0 q_proj lora_A gradient from the most recent batch.
grad_dict["base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"].shape