In [None]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AdamW, AutoModelForCausalLM, pipeline
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from peft import PeftModel
from tqdm import tqdm
import timeout_decorator

In [None]:
# from transformers import pipeline
# 
# messages = [
#     {"role": "user", "content": "Who are you?"},
# ]
# pipe = pipeline("text-generation", model="Creekside/Qwen-3B-gsm8k-GRPO")
# pipe(messages)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
def draw(weights, name, mask, bin_width=0.005, value_range=0.6, y_max=300):
    """Overlay the histogram of a weight tensor with the histogram of its masked entries.

    Args:
        weights: weight tensor of one module, on any device.
        name: module name. ``mask[name + '.weight']`` selects the highlighted
            entries, and ``name`` is used as the plot title.
        mask: dict mapping ``'<module name>.weight'`` to a tensor shaped like
            ``weights``; entries > 0 are drawn in red over the full histogram (blue).
        bin_width: width of each histogram bin.
        value_range: bins and x-axis cover ``[-value_range, value_range]``.
        y_max: top of the y-axis; taller bars are clipped from view.
    """
    # Bring the data to host memory once so the boolean index and the values live
    # on the same device; float32 avoids half-precision quirks on the numpy side.
    weights_cpu = weights.detach().to('cpu').float()
    selected = mask[name + '.weight'].to('cpu') > 0

    bins = np.arange(-value_range, value_range, bin_width)
    plt.figure(figsize=(10, 6))
    plt.hist(weights_cpu.reshape(-1).numpy(), bins=bins, alpha=0.75, color='blue',
             edgecolor='black', label='all weights')
    plt.hist(weights_cpu[selected].reshape(-1).numpy(), bins=bins, alpha=0.75, color='red',
             edgecolor='black', label='masked weights (mask > 0)')
    plt.title(name)
    plt.xlabel('weight value')
    plt.ylabel('count')
    plt.xlim([-value_range, value_range])
    plt.ylim([0, y_max])
    plt.grid(True)
    plt.legend()
    plt.show()

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
# import tqdm

# GRPO-tuned Qwen 3B checkpoint for GSM8K; weights are loaded in float16.
MODEL_ID = "Creekside/Qwen-3B-gsm8k-GRPO"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)

In [None]:
# Print the module tree. Later cells select the attention projections by the
# substrings 'q_proj', 'k_proj' and 'v_proj' in these module names.
print(model)

In [None]:
# GSM8K "main" configuration; this notebook only works with the test split.
dataset = load_dataset("GSM8K", "main", split="test")
print(dataset)

In [None]:
def get_max_length(example):
    """Token count of one GSM8K example rendered as a user/assistant chat.

    The text goes through the same chat-template -> decode -> re-tokenize round
    trip as ``tokenize_function``, so the measured length matches what is padded
    there.

    Returns:
        dict with a single ``"length"`` entry (number of input ids).
    """
    prompt = f"{example['question']} Please think step by step, give the final number in a new line after #### without any other words."
    conversation = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": f"{example['answer']}"},
    ]
    rendered = tokenizer.decode(tokenizer.apply_chat_template(conversation))
    encoded = tokenizer(rendered, padding=False, truncation=False)
    return {"length": len(encoded["input_ids"])}


# Longest rendered example; used as the fixed padding length for every sample.
max_length = max(dataset.map(get_max_length, batched=False)["length"])
print(f"Max length in dataset: {max_length}")

In [None]:
def tokenize_function(example):
    """Tokenize one GSM8K example into a padded prompt and padded full-chat labels.

    ``input_ids``/``attention_mask`` come from the user turn alone; this is what
    the evaluation cell passes to ``generate``. ``labels`` holds the token ids of
    the user turn followed by the reference assistant answer. Both are padded to
    the global ``max_length``.

    NOTE(review): padding positions in ``labels`` keep the pad token id rather
    than -100, so a causal-LM loss computed from them presumably scores padding
    too. Because ``input_ids`` holds only the prompt while ``labels`` holds
    prompt + answer, the answer tokens line up with padded input positions
    instead of being teacher-forced. Confirm both are intended before
    interpreting ``output.loss``.
    """
    user_turn = {
        "role": "user",
        "content": f"{example['question']} Please think step by step, give the final number in a new line after #### without any other words.",
    }
    assistant_turn = {"role": "assistant", "content": f"{example['answer']}"}

    full_text = tokenizer.decode(tokenizer.apply_chat_template([user_turn, assistant_turn]))
    prompt_text = tokenizer.decode(tokenizer.apply_chat_template([user_turn]))

    encoded = tokenizer(prompt_text, padding='max_length', max_length=max_length)
    encoded['labels'] = tokenizer(full_text, padding='max_length', max_length=max_length)['input_ids']
    return encoded


tokenized_dataset = dataset.map(tokenize_function)

In [None]:
# Make __getitem__ return torch tensors, then sanity-check shapes and decoded text.
tokenized_dataset.set_format('pt')
print(tokenized_dataset)
first_example, second_example = tokenized_dataset[0], tokenized_dataset[1]
print(first_example['input_ids'].shape)
print(second_example['input_ids'].shape)
print(second_example['labels'].shape)
print(tokenizer.decode(second_example['input_ids']))
print(tokenizer.decode(second_example['labels']))

In [None]:
# `dataloader` walks the whole tokenized test split. `my_dataloader` covers only
# the first 50 examples and feeds the weight-editing loop.
batch_size = 1
dataloader = DataLoader(tokenized_dataset, batch_size=batch_size)
editing_subset = tokenized_dataset.select(range(50))
my_dataloader = DataLoader(editing_subset, batch_size=batch_size)

In [None]:
# import accelerate

In [None]:
# One forward/backward pass on the first batch. No optimizer step happens here.
# The goal is to populate .grad on every parameter so the next cells can select
# the q/k/v weight entries with the largest gradient magnitude.
model.train()
model = model.to('cuda')
model.zero_grad()

batch = next(iter(dataloader))
input_ids = batch['input_ids'].to('cuda')
label = batch['labels'].to('cuda')
attention_mask = batch['attention_mask'].to('cuda')

output = model(input_ids=input_ids, attention_mask=attention_mask, labels=label, return_dict=True)
print(output.loss)
output.loss.backward()

In [None]:
# Snapshot every q/k/v projection weight before editing, plus per-tensor clamp
# bounds read from the extreme tails of its values (used by the editing loop).
TAIL_RANK = 50  # how many elements deep into each tail the bound is read (not a percentage)

ori_dict = {}
ori_max = {}
ori_min = {}
for module_name, module in model.named_modules():
    if 'k_proj' in module_name or 'q_proj' in module_name or 'v_proj' in module_name:
        param_name = module_name + '.weight'
        frozen = module.weight.detach()
        # copy=True guarantees an independent snapshot even if the model is on CPU,
        # without first cloning on the GPU; detach() drops autograd history.
        ori_dict[param_name] = frozen.to('cpu', copy=True)
        sorted_values, _ = torch.sort(frozen.reshape(-1))
        num_elements = sorted_values.numel()
        # Index num_elements - TAIL_RANK is the TAIL_RANK-th largest value; index
        # TAIL_RANK is the (TAIL_RANK + 1)-th smallest. Indices kept as originally run.
        ori_max[param_name] = sorted_values[num_elements - TAIL_RANK].item()
        ori_min[param_name] = sorted_values[TAIL_RANK].item()

In [None]:
# For every q/k/v projection weight, mark the `ratio` fraction of entries with
# the largest |gradient| (from the backward pass above) in a 0/1 mask stored on
# the CPU. Also record the selected positions (idx_list) and their current
# values (ori_w).
mask = {}
idx_list = []
ori_w = []
ratio = 0.001
for module_name, module in model.named_modules():
    if not ('k_proj' in module_name or 'q_proj' in module_name or 'v_proj' in module_name):
        continue
    param_name = module_name + '.weight'
    live_weight = module.weight.data
    flat_weight = live_weight.view(-1)

    n_selected = max(0, int(ratio * len(flat_weight)))
    print(param_name)
    print(n_selected)
    mask[param_name] = torch.zeros_like(flat_weight)

    if module.weight.grad is None:
        raise ValueError(f"{module_name}")
    flat_mask = mask[param_name]
    _, top_positions = module.weight.grad.detach().abs().view(-1).topk(n_selected)
    flat_mask[top_positions] = 1
    print((flat_mask > 0).sum())
    mask[param_name] = flat_mask.view(live_weight.size()).to('cpu')

    positions = torch.nonzero(mask[param_name], as_tuple=True)
    idx_list.append(positions)
    ori_w.extend(live_weight[positions].cpu().detach().numpy().flatten())

print("mask include:", mask.keys())

In [None]:
# Number of selected entries in the first layer's q_proj mask.
q_proj0_key = 'model.layers.0.self_attn.q_proj.weight'
print((mask[q_proj0_key] > 0).cpu().sum())

In [None]:
from transformers.optimization import Adafactor

# Edit only the masked q/k/v weights so that the task loss goes UP (note the
# minus sign in loss_new). After each step, masked values are clamped into a
# scaled version of each tensor's original tail bounds (ori_min / ori_max).
#
# PRUNE_BETWEEN_EPOCHS = False reproduces this cell's previous behaviour. That
# version ran one epoch and then hit an unconditional `break`, which left all
# the pruning code below it unreachable. Set it to True to run the per-epoch
# mask pruning.
PRUNE_BETWEEN_EPOCHS = False
NUM_EPOCHS = 10
MIN_CHANGED_TO_CONTINUE = 8500  # stop once fewer masked entries than this remain
PRUNE_FRACTION = 0.1            # fraction of each mask dropped per pruning round
KL_SCALE = 5e1
LOWER_CLAMP_SCALE = 0.85        # masked values are kept >= ori_min * LOWER_CLAMP_SCALE
UPPER_CLAMP_SCALE = 0.75        # masked values are kept <= ori_max * UPPER_CLAMP_SCALE

device = 'cuda'
optimizer = Adafactor(model.parameters(), lr=0.005, relative_step=False)
model.to(device)

for epoch in range(NUM_EPOCHS):
    # eval() switches off dropout so each step sees a deterministic forward pass;
    # gradients are still computed.
    model.eval()
    total_loss = 0.0
    progress_bar = tqdm(my_dataloader, desc="Training")
    for batch in progress_bar:
        model.zero_grad()
        input_ids = batch['input_ids'].to(device)
        label = batch['labels'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        output = model(input_ids=input_ids, attention_mask=attention_mask, labels=label, return_dict=True)
        loss = output.loss

        # Sum of per-tensor KL divergences between the sorted edited values and the
        # sorted original values at the masked positions.
        # NOTE(review): each term is turned into a Python float via .item(), so the
        # penalty below only shifts the reported loss_new and contributes NO
        # gradient. Only -loss drives the update. Confirm this is intended.
        kld = 0
        with torch.no_grad():
            for name, para in model.named_parameters():
                if name not in mask:
                    continue
                selected = mask[name] > 0
                edited = para[selected.to(device)].view(-1).to('cpu')
                original = ori_dict[name][selected].view(-1).to('cpu')
                if edited.numel() == 0:
                    continue
                edited = torch.sort(edited).values
                original = torch.sort(original).values
                kld_tmp = torch.nn.functional.kl_div(
                    torch.log_softmax(edited, dim=0),
                    torch.softmax(original, dim=0),
                    reduction='sum',
                ).item()
                if kld_tmp > 0.0:
                    kld += kld_tmp
        kl_penalty = 2 ** (kld * KL_SCALE)
        loss_new = -loss + kl_penalty
        total_loss += loss_new.item()
        loss_new.backward()

        # Keep gradients only on masked entries; every other parameter is frozen.
        for name, para in model.named_parameters():
            if para.grad is None:
                continue
            if name in mask:
                para.grad *= mask[name].long().to(device)
            else:
                para.grad *= 0

        optimizer.step()
        with torch.no_grad():
            for name, para in model.named_parameters():
                if name in mask:
                    device_mask = mask[name].to(device)
                    para.data = (1 - device_mask) * para + (device_mask * para).clamp_(
                        ori_min[name] * LOWER_CLAMP_SCALE, ori_max[name] * UPPER_CLAMP_SCALE
                    )

        progress_bar.set_postfix({'loss': loss.item(), 'loss_new': loss_new.item(), 'kld': kl_penalty})

    if not PRUNE_BETWEEN_EPOCHS:
        break

    # ---- Per-epoch mask pruning (only reached when PRUNE_BETWEEN_EPOCHS is True) ----
    changed = 0
    for module_name, module in model.named_modules():
        param_name = module_name + '.weight'
        if param_name in mask:
            changed += (mask[param_name] > 0).sum()
        if 'v_proj' in module_name and 'layers.6' in module_name:
            # draw() takes the mask dict itself (a dict has no .to()).
            draw(module.weight.data, module_name, mask)
            weight_diff = (module.weight.data.cpu().numpy().flatten()
                           - ori_dict[param_name].data.cpu().numpy().flatten())
            print(len(np.where(np.abs(weight_diff) > 1e-5)[0]))
    print(f'{changed} weights changed')
    if changed < MIN_CHANGED_TO_CONTINUE:
        break

    with torch.no_grad():
        for module_name, module in model.named_modules():
            param_name = module_name + '.weight'
            if not ('k_proj' in param_name or 'q_proj' in param_name or 'v_proj' in param_name):
                continue
            weights = module.weight.data
            keep = mask[param_name] > 0  # CPU bool tensor
            # Integer count; summing the fp16 mask directly loses precision above 2048.
            replace_count = max(1, int(PRUNE_FRACTION * keep.sum().item()))
            # Rank masked entries by |grad|. Unmasked entries get +inf so they are never
            # picked (the former "+1" sentinel could lose to masked entries with |grad| > 1).
            grad_abs = module.weight.grad.detach().abs().float()
            scores = grad_abs.masked_fill(~keep.to(grad_abs.device), float('inf')).view(-1)
            _, w_idx_mink = scores.topk(replace_count, largest=False)
            drop = torch.zeros(scores.numel(), dtype=torch.bool)
            drop[w_idx_mink.cpu()] = True
            drop = drop.view(mask[param_name].size())
            # Remove the dropped entries from the mask and restore their original values,
            # keeping every tensor on its own device.
            mask[param_name][drop] = 0
            weights[drop.to(weights.device)] = ori_dict[param_name][drop].to(weights.device)

    # Later epochs switch to AdamW. weight_decay=0.0 matches the defaults of the
    # deprecated transformers.AdamW that was used here before.
    # NOTE(review): the model was loaded in float16. eps=1e-8 is below the smallest
    # fp16 subnormal (~6e-8), so if the optimizer state stays in fp16 the denominator
    # can round to 0. Entries whose gradient is zeroed above may then become NaN.
    # Verify before enabling PRUNE_BETWEEN_EPOCHS.
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, eps=1e-8, weight_decay=0.0)

In [None]:
# Count the entries the masks mark for editing across all q/k/v projections.
# For layer 6's v_proj, also plot its distribution and print how many entries
# differ from the saved original.
changed = sum(
    (mask[module_name + '.weight'] > 0).sum()
    for module_name, _ in model.named_modules()
    if (module_name + '.weight') in mask
)
for module_name, module in model.named_modules():
    if 'v_proj' not in module_name or 'layers.6' not in module_name:
        continue
    draw(module.weight.data, module_name, mask)
    param_name = module_name + '.weight'
    current_values = module.weight.data.cpu().numpy().flatten()
    original_values = ori_dict[param_name].data.cpu().numpy().flatten()
    changed_indices = np.where(np.abs(current_values - original_values) > 1e-10)
    print(len(changed_indices[0]))
print(f'{changed} weights changed')

In [None]:

# test_dataloader = DataLoader(
#     dataset['test'],
#     batch_size=1,  
#     shuffle=False
# )
# 
# 
# 
# model.eval() 
# cnt = 0
# with torch.no_grad():
#     for data in test_dataloader:
#         data = {k: v.to(model.device) for k, v in data.items() if k != 'question' and k != 'answer'}
#         print(data)
#         outputs = model(
#             input_ids=data["input_ids"],
#             attention_mask=data["attention_mask"],
#             labels=data["labels"]
#         )
#         
#         print("\n=== Test Sample ===")
#         print(outputs)
#         print(f"\nModel Loss: {outputs.loss.item():.4f}")
#         
#         generated = model.generate(
#             input_ids=data["input_ids"],
#             attention_mask=data["attention_mask"],
#             max_new_tokens=100
#         )
#         print("\nGenerated Response:")
#         print(tokenizer.decode(generated[0], skip_special_tokens=True))
#         cnt += 1
#         # if cnt ==3:
#         break 

In [None]:
# def extract_after_double_hash(text):
#     pos = text.find("#### ")
#     if pos != -1:
#         return text[pos + 5:]  
#     return None  
# 
# 
# 
# idx = 675
# 
# print(dataset['train'][idx]['question'])
# print(dataset['train'][idx]['answer'])
# question = dataset['train'][idx]['question']
# answer = dataset['train'][idx]['answer']
# prompt = f"""please anwser this math problem, think step by step, and give the final answer after "#### " in a new line: 
# 
# {question}"""
# messages = [
#     {"role": "user", "content": prompt},
# ]
# result = pipe(messages)
# print(result[0]['generated_text'][1]['content'])
# result = extract_after_double_hash(result[0]['generated_text'][1]['content'])
# answer = extract_after_double_hash(answer)
# print(result)  # 输出: "72"
# print(answer)

In [None]:
def extract_answer(text):
    """Return the stripped text after the last '####' in `text`, or '' if there is none."""
    if "####" in text:
        return text.split("####")[-1].strip()
    return ""


# Generate answers for the first MAX_EVAL_SAMPLES test prompts and compare the
# value after '####' with the reference. Set MAX_EVAL_SAMPLES = None to
# evaluate the whole split.
MAX_EVAL_SAMPLES = 10

model.eval()
total = 0
correct = 0
for batch in dataloader:
    input_ids = batch['input_ids'].to('cuda')
    attention_mask = batch['attention_mask'].to('cuda')
    truths = tokenizer.batch_decode(batch['labels'], skip_special_tokens=True)

    for row, truth in enumerate(truths):
        # Prompts were padded to a fixed length. Keep only the real tokens so the
        # model continues from the end of the prompt rather than after padding.
        prompt_ids = input_ids[row][attention_mask[row].bool()].unsqueeze(0)
        output = model.generate(input_ids=prompt_ids, attention_mask=torch.ones_like(prompt_ids))
        # generate() returns prompt + continuation. Decode only the continuation:
        # the prompt text itself contains '####', which would otherwise be parsed
        # as the answer whenever the model omits the marker.
        pred = tokenizer.decode(output[0, prompt_ids.shape[1]:], skip_special_tokens=True)

        pred_answer = extract_answer(pred)
        true_answer = extract_answer(truth)
        if pred_answer == true_answer:
            correct += 1
        total += 1

        print(f"pred: {pred_answer}, truth: {true_answer}, acc: {correct / total}")
        if MAX_EVAL_SAMPLES is not None and total >= MAX_EVAL_SAMPLES:
            break
    # The inner `break` only leaves the per-row loop; stop iterating batches too.
    if MAX_EVAL_SAMPLES is not None and total >= MAX_EVAL_SAMPLES:
        break

In [None]:
# timeout_seconds = 60
# @timeout_decorator.timeout(timeout_seconds, timeout_exception=TimeoutError)
# def process_question(pipe, prompt):
#     return pipe([{"role": "user", "content": prompt}], max_new_tokens=5000)
# 
# correct = 0
# total = 0
# progress_bar = tqdm(range(1319))
# with torch.no_grad():
#     for i in progress_bar:
#         total += 1
#         question = dataset['test'][i]['question']
#         answer = dataset['test'][i]['answer']
#         prompt = f"""please anwser this math problem, think step by step, and give the final answer after "#### " in a new line: 
# 
#         {question}"""
#         try:
#             result = process_question(pipe, prompt)
#             if extract_after_double_hash(result[0]['generated_text'][1]['content']) == extract_after_double_hash(answer):
#                 correct += 1
#         except TimeoutError:
#             print(f"\nTimeout on question {i}, skipping...")
#             continue
#         except Exception as e:
#             print(f"\nError on question {i}: {str(e)}, skipping...")
#             continue
# 
#         progress_bar.set_postfix({'acc': correct / total})
# 
# print("accuracy:", correct / total)

In [None]:

from matplotlib.ticker import FuncFormatter
def plot_weight_comparison_gpu(model, ori_dict, name, weight_range=0.3, ylim=20):
    """Mirror histogram of one weight tensor: original values up, current values down.

    Histograms use ``torch.histc`` on the model's device with 120 equal bins over
    ``[-weight_range, weight_range]``. Entries whose absolute change exceeds 1e-10
    are drawn again in darker colours.

    Args:
        model: the (possibly edited) model; ``name`` is looked up in its state dict.
        ori_dict: mapping from parameter name to the pre-edit copy of that tensor.
        name: full parameter name, e.g. ``'model.layers.0.self_attn.q_proj.weight'``.
        weight_range: half-width of the x-axis and the histogram range.
        ylim: half-height of the y-axis. Pass ``None`` to scale it to 1.2x the
            tallest bar (at least 100).
    """
    device = next(model.parameters()).device
    # float32 so histc does not depend on half-precision kernel support.
    weights_post = model.state_dict()[name].detach().float()
    weights_pre = ori_dict[name].detach().to(device).float()

    n_bins = 120
    bin_edges = torch.linspace(-weight_range, weight_range, steps=n_bins + 1).numpy()
    bin_width = bin_edges[1] - bin_edges[0]

    def gpu_hist(tensor):
        return torch.histc(tensor, bins=n_bins, min=-weight_range, max=weight_range).cpu().numpy()

    pre_freq = gpu_hist(weights_pre)
    post_freq = gpu_hist(weights_post)

    plt.figure(figsize=(10, 6))
    plt.bar(bin_edges[:-1], pre_freq, width=bin_width, align='edge',
            alpha=0.5, color='royalblue', edgecolor='none', label='Original Weights')
    plt.bar(bin_edges[:-1], -post_freq, width=bin_width, align='edge',
            alpha=0.5, color='coral', edgecolor='none', label='Current Weights')

    significant = torch.abs(weights_post - weights_pre) > 1e-10
    if significant.any():
        sig_pre_freq = gpu_hist(weights_pre[significant])
        sig_post_freq = gpu_hist(weights_post[significant])
        plt.bar(bin_edges[:-1], sig_pre_freq, width=bin_width, align='edge',
                alpha=0.8, color='blue', edgecolor='none', label='Significant Change (Orig)')
        plt.bar(bin_edges[:-1], -sig_post_freq, width=bin_width, align='edge',
                alpha=0.8, color='red', edgecolor='none', label='Significant Change (Current)')

    if ylim is None:
        ylim = max(int(max(pre_freq.max(), post_freq.max()) * 1.2), 100)

    plt.xlim(-weight_range, weight_range)
    plt.ylim(-ylim, ylim)
    plt.grid(True, which='both', linestyle=':', alpha=0.5)
    plt.axhline(0, color='k', linewidth=1)
    plt.title(f'Weight Distribution Comparison: {name}', fontsize=14)
    plt.xlabel('Weight Value')
    plt.ylabel('Frequency')
    # Bars below the axis are negated counts; show absolute values on the ticks.
    plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{abs(x):.0f}'))
    plt.legend()
    plt.tight_layout()
    plt.show()

def plot_combined_weights_gpu(model, ori_dict, weight_range=0.3, ylim=26000, save_path='./qwen_gsm8k.png'):
    """Save a mirror histogram of all q/k/v projection weights, before vs. after editing.

    Upper half: original weights (light) with the changed subset highlighted.
    Lower half: current weights (light) with the changed subset highlighted.
    An entry counts as changed when it differs from the original by more than 1e-10.

    Args:
        model: the (possibly edited) model; its state-dict entries whose names
            contain 'q_proj', 'k_proj' or 'v_proj' and 'weight' are included.
        ori_dict: mapping from those parameter names to their pre-edit copies.
        weight_range: half-width of the x-axis and the histogram range (120 bins).
        ylim: half-height of the y-axis. Pass ``None`` to scale it to 1.1x the
            tallest bar.
        save_path: where the figure is written. The figure is closed, not shown.
    """
    device = next(model.parameters()).device

    # Collect per-layer chunks and concatenate once. Concatenating inside the loop
    # re-copied the growing buffer on every layer.
    pre_chunks, post_chunks = [], []
    for name, param in model.state_dict().items():
        if ('q_proj' in name or 'k_proj' in name or 'v_proj' in name) and 'weight' in name:
            post_chunks.append(param.detach().float().flatten())
            pre_chunks.append(ori_dict[name].detach().to(device=device, dtype=torch.float32).flatten())

    if pre_chunks:
        all_weights_pre = torch.cat(pre_chunks)
        all_weights_post = torch.cat(post_chunks)
    else:
        all_weights_pre = torch.empty(0, device=device)
        all_weights_post = torch.empty(0, device=device)
    del pre_chunks, post_chunks

    def gpu_hist(tensor):
        return torch.histc(tensor, bins=120, min=-weight_range, max=weight_range).cpu().numpy()

    pre_freq = gpu_hist(all_weights_pre)
    post_freq = gpu_hist(all_weights_post)

    significant = torch.abs(all_weights_post - all_weights_pre) > 1e-10
    sig_pre_freq = gpu_hist(all_weights_pre[significant])
    sig_post_freq = gpu_hist(all_weights_post[significant])

    bin_edges = torch.linspace(-weight_range, weight_range, steps=121, device=device).cpu().numpy()
    bin_width = bin_edges[1] - bin_edges[0]

    plt.figure(figsize=(12, 8))

    plt.bar(bin_edges[:-1], pre_freq, width=bin_width, align='edge',
            alpha=0.3, color='lightblue', edgecolor='black',
            label='Baseline')
    plt.bar(bin_edges[:-1], -post_freq, width=bin_width, align='edge',
            alpha=0.2, color='red', edgecolor='black',
            label='DistShield')

    # Overlay the entries that actually changed.
    plt.bar(bin_edges[:-1], sig_pre_freq, width=bin_width, align='edge',
            alpha=0.6, color='blue', edgecolor='black',
            label='Before')
    plt.bar(bin_edges[:-1], -sig_post_freq, width=bin_width, align='edge',
            alpha=0.6, color='red', edgecolor='black',
            label='After')

    if ylim is None:
        ylim = max(pre_freq.max(), post_freq.max()) * 1.1
    plt.xlim(-weight_range, weight_range)
    plt.ylim(-ylim, ylim)
    plt.grid(True, which='both', linestyle=':', alpha=0.5)
    plt.axhline(0, color='k', linewidth=1)

    plt.title('DistShield on Qwen-3B GSM8K', fontsize=20)
    plt.xlabel('Weight Value', fontsize=20)
    plt.ylabel('Frequency', fontsize=20)
    # Bars below the axis are negated counts; show absolute values on the ticks.
    plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{abs(x):.0f}'))
    plt.legend(loc='upper right', fontsize=18)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    # Lay out after the tick fonts are set so the larger labels are accounted for.
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()

In [None]:
# Build the combined before/after histogram over all q/k/v projection weights.
# The function saves it to ./qwen_gsm8k.png and closes the figure instead of showing it.
plot_combined_weights_gpu(model, ori_dict)