In [None]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AdamW
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch

from tqdm import tqdm

In [None]:
# Both models start from the same SQuAD-2 fine-tuned RoBERTa checkpoint:
# `model` is the copy that the optimizers below update, `model_ori` is kept as the reference.
CHECKPOINT = 'deepset/roberta-base-squad2'

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model, model_ori = (AutoModelForQuestionAnswering.from_pretrained(CHECKPOINT) for _ in range(2))

dataset = load_dataset('rajpurkar/squad')

In [None]:
import matplotlib.pyplot as plt
import numpy as np
def draw(weights, name, mask):
    """Show an overlaid histogram of one layer's weights.

    Blue bars count every entry of ``weights``; red bars count only the entries whose
    value in ``mask[name + '.weight']`` is positive. Bins are 0.005 wide on [-0.6, 0.6),
    and the y-axis is fixed to [0, 300].

    Args:
        weights: the layer's weight tensor (copied to CPU for plotting).
        name: module name; ``name + '.weight'`` is the key looked up in ``mask``.
        mask: dict of tensors that can index ``weights`` (same shape).
    """
    edges = np.arange(-0.6, 0.6, 0.005)
    hist_kwargs = dict(bins=edges, alpha=0.75, edgecolor='black')

    plt.figure(figsize=(10, 6))
    plt.hist(weights.view(-1).cpu(), color='blue', **hist_kwargs)
    selected = mask[name + '.weight'] > 0
    plt.hist(weights[selected].view(-1).cpu(), color='red', **hist_kwargs)
    plt.title(name)
    plt.xlim([-0.6, 0.6])
    plt.ylim([0, 300])
    plt.grid(True)
    plt.show()
    # plt.savefig(f"imgs/{name}.png")
    # plt.close()

In [None]:
# Quick look at the model architecture and the raw SQuAD splits.
print(model)
print(dataset)

In [None]:
def prepare_train_features(examples):
    """Tokenize SQuAD examples into sliding-window features with answer token spans.

    Each (question, context) pair is tokenized with the module-level ``tokenizer``
    (max 512 tokens, 128-token stride, padded to max length). Only the context is
    truncated, so one example may yield several features that all contain the full
    question. For each feature, the character span of the first gold answer is mapped
    to token indices inside the context. Features whose window does not fully contain
    the answer, and examples with no answer, get start/end position 0.

    Args:
        examples: batched dict with 'question', 'context' and 'answers' lists; each
            'answers' entry has 'answer_start' (character offsets) and 'text' lists.

    Returns:
        The tokenizer output without 'offset_mapping' / 'overflow_to_sample_mapping',
        plus 'start_positions' and 'end_positions' lists (one entry per feature).
    """
    tokenized_examples = tokenizer(
        examples['question'],
        examples['context'],
        # Truncate only the context: the question is never cut, and every overflow
        # window is (question, slice of context).
        truncation='only_second',
        max_length=512,
        stride=128,
        padding='max_length',
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
    )

    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        sequence_ids = tokenized_examples.sequence_ids(i)
        answers = examples['answers'][sample_mapping[i]]

        if len(answers['answer_start']) == 0:
            start_positions.append(0)
            end_positions.append(0)
            continue

        start_char = answers['answer_start'][0]
        end_char = start_char + len(answers['text'][0])

        # First and last context token (sequence id 1) of this window.
        context_start = 0
        while sequence_ids[context_start] != 1:
            context_start += 1
        context_end = len(sequence_ids) - 1
        while sequence_ids[context_end] != 1:
            context_end -= 1

        if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
            # The answer is not fully inside this window.
            start_positions.append(0)
            end_positions.append(0)
            continue

        # Walk inwards but stay inside the context, so special/padding tokens (whose
        # offsets are (0, 0)) can never become an answer boundary.
        token_start = context_start
        while token_start <= context_end and offsets[token_start][0] <= start_char:
            token_start += 1
        start_positions.append(token_start - 1)

        token_end = context_end
        while token_end >= context_start and offsets[token_end][1] >= end_char:
            token_end -= 1
        end_positions.append(token_end + 1)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples

# Tokenize every split into fixed-length features. The original text columns must be
# removed because one example can produce several features (rows).
tokenized_dataset = dataset.map(
    prepare_train_features,
    batched=True,
    remove_columns=dataset['train'].column_names
)

# Expose only the model inputs and span labels, as torch tensors.
tokenized_dataset.set_format('pt', columns=[
    'input_ids',
    'attention_mask',
    'start_positions',
    'end_positions'
])

print(tokenized_dataset)
# Show the first training feature ("示例数据" = "sample data"). Index the row first:
# `ds[column][0]` would materialise the whole column as tensors just to read one entry.
first_feature = tokenized_dataset['train'][0]
print("\n示例数据:")
print("Input IDs:", first_feature['input_ids'])
print("Start Position:", first_feature['start_positions'])
print("End Position:", first_feature['end_positions'])

BATCH_SIZE = 8
N_MASKED_UPDATE_FEATURES = 400  # validation features used by the masked-update loop

train_dataloader = DataLoader(tokenized_dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader = DataLoader(tokenized_dataset['validation'], batch_size=BATCH_SIZE, shuffle=False)
my_dataloader = DataLoader(
    tokenized_dataset['validation'].select(range(N_MASKED_UPDATE_FEATURES)),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Print the raw (untokenized) DatasetDict again; this repeats the earlier `print(dataset)` output.
print(dataset)

In [None]:
# Disabled debug prints. NOTE(review): prepare_train_features never creates a 'labels'
# column, and set_format above only exposes input_ids, attention_mask, start_positions
# and end_positions, so the 'labels' / 'token_type_ids' lines would presumably fail if
# uncommented — verify before re-enabling.
# print(tokenized_dataset)
# print(tokenized_dataset['train']['labels'][0])
# print(tokenized_dataset['train']['input_ids'][0])
# print(tokenized_dataset['train']['token_type_ids'][0])
# print(tokenized_dataset['train']['attention_mask'][0])

In [None]:
# Gradient probe: one forward/backward pass of the reference model on a single training
# batch. The `.grad` tensors this leaves on `model_ori` are ranked by the mask-building
# cell, so this cell must run before it.
# NOTE(review): the batch comes from a shuffled loader and the model is in train mode
# (dropout-style layers active), so the resulting mask varies between runs unless seeded.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_ori = model_ori.to(device)
model_ori.train()
model_ori.zero_grad()

probe_batch = next(iter(train_dataloader))
probe_batch = {k: v.to(device) for k, v in probe_batch.items()}
outputs = model_ori(**probe_batch)
print(outputs.loss)
outputs.loss.backward()

# Sanity checks: a gradient now exists, and predicted vs. gold answer spans for the batch.
print(model_ori.roberta.encoder.layer[-1].attention.self.query.weight.grad)
print(torch.argmax(outputs.start_logits, dim=1))
print(torch.argmax(outputs.end_logits, dim=1))
print(probe_batch['start_positions'])
print(probe_batch['end_positions'])
    

In [None]:
# Snapshot of the reference query/key/value projection weights of model_ori, plus per-matrix
# bounds later used to clamp the modified weights.
# The snapshot is detached and cloned. Later cells use it as a KL target, so it must not
# carry autograd history back into model_ori, and it must not change if model_ori is touched.
CLAMP_TAIL = 50  # number of most extreme values ignored at each end of a weight matrix

ori_dict = {}
ori_max = {}
ori_min = {}
for module_name, module in model_ori.named_modules():
    if not ('query' in module_name or 'key' in module_name or 'value' in module_name):
        continue
    weight_name = module_name + '.weight'
    snapshot = module.weight.detach().clone()
    ori_dict[weight_name] = snapshot
    sorted_vals, _ = torch.sort(snapshot.view(-1))
    # Values with exactly CLAMP_TAIL entries above (max) / below (min) them.
    ori_max[weight_name] = sorted_vals[-(CLAMP_TAIL + 1)].item()
    ori_min[weight_name] = sorted_vals[CLAMP_TAIL].item()

In [None]:
from transformers import pipeline
# QA pipeline over `model` (the copy that is modified later); the evaluation cells call it.
# It uses the same `device` as the gradient-probe cell instead of a second hard-coded value.
# NOTE(review): pipeline() presumably moves `model` to this device, which the masked-update
# loop relies on (it never calls model.to(device)) — confirm for the installed version.
question_answerer = pipeline('question-answering', model=model, tokenizer=tokenizer, device=device)

Device set to use cuda

100%|██████████| 10570/10570 [01:15<00:00, 140.87it/s]

accuracy: 0.819678334910123

In [None]:
# Initial mask: for each query/key/value weight of model_ori, mark the `ratio` fraction of
# entries with the largest |gradient| from the gradient-probe pass.
# mask[weight_name] is a 0/1 float tensor shaped like the weight. idx_list collects its
# nonzero indices and ori_w the original weight values at those positions.
ratio = 0.001
mask = {}
idx_list = []
ori_w = []
for name, module in model_ori.named_modules():
    if not ('query' in name or 'key' in name or 'value' in name):
        continue
    weight_name = name + '.weight'
    weights = module.weight.data

    replace_count = int(ratio * weights.numel())
    print(replace_count)

    if module.weight.grad is None:
        raise ValueError(
            f"{name}: weight has no gradient; run the gradient-probe cell "
            "(forward/backward on model_ori) before building the mask."
        )

    flat_mask = torch.zeros_like(weights.view(-1))
    _, topk_idx = module.weight.grad.detach().abs().view(-1).topk(replace_count)
    flat_mask[topk_idx] = 1
    mask[weight_name] = flat_mask.view(weights.size())

    indexes = torch.nonzero(mask[weight_name], as_tuple=True)
    idx_list.append(indexes)
    ori_w.extend(weights[indexes].cpu().detach().numpy().flatten())

print("mask include:", mask.keys())

In [None]:
# Masked weight-modification loop on `my_dataloader`.
# Each step does gradient *ascent* on the QA loss (`-loss`) plus an exponential penalty
# `2 ** (kld * 1e3)`. `kld` sums, over the masked query/key/value tensors, kl_div between
# the softmax of the sorted original masked values and the softmax of the sorted current
# masked values. Only masked entries keep a gradient. After each step they are clamped into
# [0.85 * ori_min, 0.75 * ori_max].
# After each epoch, 10% of the remaining masked entries (smallest |grad| from the last
# batch) are reset to their original values and removed from the mask. The loop stops once
# fewer than 8500 entries remain masked.
# NOTE(review): this cell never calls model.to(device); presumably `model` is already on
# `device` because the pipeline() call moved it — confirm.
# NOTE(review): zeroing gradients only freezes entries if the optimizer applies no weight
# decay (none is passed here) — verify AdamW's default for the installed version.
# NOTE(review): if ori_dict holds live model_ori parameters rather than detached copies,
# loss_new.backward() also accumulates gradients into model_ori.
# optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
optimizer = AdamW(model.parameters(), lr=0.0005, eps=1e-8)

for epoch in range(10):  # upper bound on epochs; the `break` below may stop earlier
    # eval() puts dropout-style layers in inference mode; parameters are still updated below.

    model.eval()
    total_loss = 0.0
    correct = 0.0
    total = 0.0
    progress_bar = tqdm(my_dataloader, desc="Training")
    for batch in progress_bar:
        model.zero_grad()

        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs.loss

        # Distribution penalty accumulated over all masked tensors.
        # NOTE(review): `kld` starts as the int 0; if no tensor contributes a positive term it
        # stays an int and `kld.item()` in set_postfix below raises AttributeError.
        kld = 0
        cnt = 0
        total_cnt = 0
        for name, para in model.named_parameters():
            if name in mask:
                # Masked values of the current and the original tensor, each sorted ascending,
                # so the two value distributions are compared rather than positions.
                flattened_para = para[mask[name] > 0].view(-1)
                flattened_ori = ori_dict[name][mask[name] > 0].view(-1)
                flattened_para = flattened_para[torch.argsort(flattened_para)]
                flattened_ori = flattened_ori[torch.argsort(flattened_ori)]
                cnt += 1
                if flattened_para.numel() != 0:
                    # kl_div(input=log q, target=p): p from the original, q from the current values.
                    kl_t1 = torch.log_softmax(flattened_para, dim=0)
                    kl_t2 = torch.softmax(flattened_ori, dim=0)
                    kld_tmp = torch.nn.functional.kl_div(kl_t1, kl_t2, reduction='sum')
                    if kld_tmp.item() > 0.0:
                        kld += kld_tmp
        # Ascend the QA loss while penalising distribution drift exponentially.
        # NOTE(review): assuming float32, 2 ** (kld * 1e3) overflows to inf once kld exceeds ~0.128.
        loss_new = -loss + 2 ** (kld * 1e3)
        total_loss += loss_new.item()


        loss_new.backward()
        # Keep gradient only on masked entries of masked tensors; every other parameter gets a
        # zero gradient. NOTE(review): `para.grad *= 0` would fail for a parameter whose grad is None.
        for name, para in model.named_parameters():
            if name in mask:
                para.grad *= mask[name].long()
            else:
                para.grad *= 0

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        # Clamp only the masked entries; unmasked entries keep their value.
        # (Assumes ori_min < 0 < ori_max, so the zeros that unmasked positions contribute to the
        # clamped term stay zero.)
        with torch.no_grad():
            for name, para in model.named_parameters():
                if name in mask:
                    # param.clamp_(-arg.clip, arg.clip)
                    para.data = (1 - mask[name]) * para + (mask[name] * para).clamp_(ori_min[name] * 0.85,
                                                                                         ori_max[name] * 0.75)

        # NOTE(review): this Python-float 2 ** x raises OverflowError once kld * 1e3 exceeds ~1024.
        progress_bar.set_postfix({'loss': loss.item(),'loss_new': loss_new.item(),'kld': 2 ** (kld.item() * 1e3)})

    # Disabled: full-validation accuracy check that could skip the mask shrink below.
    # correct = 0
    # total = 0
    # with torch.no_grad():
    #     for i in tqdm(range(10570)):
    #         result = question_answerer(question=dataset['validation'][i]['question'], context=dataset['validation'][i]['context'])
    #         for answer in dataset['validation'][i]['answers']['text']:
    #             if answer == result['answer']:
    #                 correct += 1
    #                 break
    #         total += 1
    # accuracy = correct / total
    # print("accuracy:", accuracy)
    #
    # if accuracy > 0.7:
    #     continue
    # Count entries still in the mask. The message says "changed", but this counts masked
    # entries whether or not their value actually moved.
    changed = 0
    for name, module in model.named_modules():
        if (name + '.weight') in mask:
            changed += (mask[name + '.weight'] > 0).sum()
        # Diagnostics for the layer-6 value projection: histogram with masked entries, and the
        # number of entries that differ from the original by more than 1e-5.
        if 'value' in name:
            if 'layer.6' in name:
                draw(module.weight.data, name, mask)
                # draw(ori_dict[name + '.weight'].data, name, mask)
                name = name + '.weight'
                flattened_weights = module.weight.data.cpu().numpy().flatten()
                flattened_ori = ori_dict[name].data.cpu().numpy().flatten()
                weight_diff = flattened_weights - flattened_ori
                changed_indices = np.where(np.abs(weight_diff) > 1e-5)
                print(len(changed_indices[0]))
    #
    print(f'{changed} weights changed')
    # Stop once the mask is small enough; the mask is not shrunk in that final epoch.
    if changed < 8500:
        break
    # if accuracy > 0.01:
    #     continue
    # Shrink the mask. Among the still-masked entries of each query/key/value weight, pick the
    # `step` fraction (at least 1) with the smallest |grad| from the last batch. Restore their
    # original values and remove them from the mask.
    # Unmasked entries get +1 so they rank after masked ones: after clip_grad_norm_ with
    # max_norm 1.0, every masked |grad| is at most ~1.
    step = 0.1
    with torch.no_grad():
        for name, module in model.named_modules():
            name = name + '.weight'
            if 'query' in name or 'key' in name or 'value' in name:
                weights = module.weight.data
                # flattened_weights = weights.view(-1)
                replace_count = int(step * mask[name].sum())
                # print((module.weight.grad.detach().abs() * mask[name]).view(-1).shape)
                if replace_count == 0:
                    replace_count = 1
                temp_weight = module.weight.grad.detach().abs() * mask[name]
                temp_weight += 1 - mask[name]
                temp_weight = temp_weight.view(-1)
                _, w_idx_mink = temp_weight.topk(replace_count, largest=False)
                zero_mask = torch.zeros_like(temp_weight)
                zero_mask[w_idx_mink] = 1
                # if 'layer2.1.conv1' in name or 'features.4' in name:
                #     print((weights[mask[name] > 0]).numel())
                #     print((weights.view(-1) != ori_dict[name].view(-1)).sum().item())
                    # print(weights[mask[name] > 0])
                    # print(ori_dict[name][mask[name] > 0])
                    # mask_temp = mask[name].clone()
                mask[name][zero_mask.view(mask[name].size()) > 0] = 0
                weights[zero_mask.view(mask[name].size()) > 0] = ori_dict[name][zero_mask.view(mask[name].size()) > 0]
                # if 'layer2.1.conv1' in name:
                    # print(weights[mask_temp > 0])
                    # print(ori_dict[name][mask_temp > 0])
                    # print((weights[mask[name] > 0]).numel())
                    # print((weights.view(-1) != ori_dict[name].view(-1)).sum().item())
    # Re-create the optimizer (discarding its accumulated state) before the next epoch.
    optimizer = AdamW(model.parameters(), lr=0.0005, eps=1e-8)


In [None]:
# Disabled: full-validation exact-string-match accuracy of `model`, the same metric as the
# evaluation cell further down. Uncomment to re-measure after the masked-update loop.
# correct = 0
# total = 0
# with torch.no_grad():
#     for i in tqdm(range(10570)):
#         result = question_answerer(question=dataset['validation'][i]['question'], context=dataset['validation'][i]['context'])
#         for answer in dataset['validation'][i]['answers']['text']:
#             if answer == result['answer']:
#                 correct += 1
#                 break
#         total += 1
# print("accuracy:", correct / total)

In [None]:
# Disabled: draw() histograms (masked entries in red) for the query/key/value
# projections of encoder layer 7.
# for name, module in model.named_modules():
#     if 'query' in name or 'key' in name or 'value' in name:
#         if 'layer.7' in name:
#             draw(module.weight.data, name, mask)

In [None]:
# Inspect layer 0's query/key/value projections. For each one: a histogram with masked
# entries highlighted, the number of entries that moved by more than 1e-5 from the
# original, and the clamp bounds.
for name, module in model.named_modules():
    is_qkv = any(part in name for part in ('query', 'key', 'value'))
    if not (is_qkv and 'layer.0' in name):
        continue
    draw(module.weight.data, name, mask)
    # draw(ori_dict[name + '.weight'].data, name, mask)
    name = name + '.weight'
    current_vals = module.weight.data.cpu().numpy().flatten()
    reference_vals = ori_dict[name].data.cpu().numpy().flatten()
    moved = np.where(np.abs(current_vals - reference_vals) > 1e-5)
    print(len(moved[0]))
    print(ori_max[name])
    print(ori_min[name])

In [None]:
# Total number of scalar parameters in the reference model.
print(f'paras :{sum(map(torch.numel, model_ori.parameters()))}')

In [None]:
# Full-validation "accuracy": fraction of SQuAD validation questions whose pipeline answer
# string exactly equals one of the gold answer strings (raw string equality, no normalisation).
# eval() makes the result independent of cell order: later cells call model.train().
model.eval()
validation = dataset['validation']
correct = 0
total = 0
accuracy = 0.0
progress_bar = tqdm(range(len(validation)))
with torch.no_grad():
    for i in progress_bar:
        example = validation[i]
        result = question_answerer(question=example['question'], context=example['context'])
        if result['answer'] in example['answers']['text']:
            correct += 1
        total += 1
        accuracy = correct / total
        progress_bar.set_postfix({'acc': accuracy})
print("accuracy:", accuracy)

In [None]:
# The attacker only gets the first 500 tokenized validation features.
# No gradient pass is run here: get_gradients() clears existing gradients and computes its own.
attacker_dataloader = DataLoader(
    tokenized_dataset['validation'].select(range(500)),
    batch_size=8,
    shuffle=False
)
    
def get_gradients(model, dataloader, max_batches=1):
    """Return |grad| of the query/key/value parameters of ``model`` on the first batches.

    The model is put in train mode and its gradients are cleared. The loss of the first
    ``max_batches`` batches is then back-propagated (gradients accumulate across those
    batches). Gradients are cleared again before returning.

    Args:
        model: module whose forward accepts a batch dict as keyword arguments and returns
            an object with a ``.loss`` tensor.
        dataloader: iterable of dicts of tensors; each tensor is moved to the model's device.
        max_batches: number of batches to use (default 1, the original behaviour).

    Returns:
        dict mapping each parameter name containing 'query', 'key' or 'value' to a cloned
        tensor of absolute gradients. Parameters that received no gradient are skipped.
    """
    from itertools import islice

    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    # Clear through the model itself, not a global optimizer that may not exist yet
    # or may not cover every parameter.
    model.zero_grad()

    for data in islice(dataloader, max_batches):
        data = {k: v.to(model_device) for k, v in data.items()}
        outputs = model(**data)
        outputs.loss.backward()

    gradients = {}
    for name, param in model.named_parameters():
        if ('query' in name or 'key' in name or 'value' in name) and param.grad is not None:
            gradients[name] = param.grad.abs().clone()

    model.zero_grad()
    return gradients


# Attack mask. For every query/key/value parameter, keep the entries whose |grad| is at or
# above the GRAD_QUANTILE quantile, then freeze all other parameters. A gradient hook
# multiplies each masked parameter's gradient by its 0/1 mask, so only masked entries
# receive a nonzero gradient.
# The masks are named `masks` / `grad_mask`, never `mask`, so the mask dict built earlier
# (used by draw() and the masked-update loop) is not shadowed by a tensor.
# NOTE(review): re-running this cell registers an additional hook per parameter.
GRAD_QUANTILE = 0.95

gradients = get_gradients(model, attacker_dataloader)
masks = {}
for name, grad in gradients.items():
    threshold = torch.quantile(grad, GRAD_QUANTILE)
    masks[name] = (grad >= threshold).float()


def make_hook(m):
    """Return a tensor hook that multiplies the incoming gradient by ``m``."""
    def hook(grad):
        return grad * m

    return hook


for param in model.parameters():
    param.requires_grad = False

for name, param in model.named_parameters():
    if 'query' in name or 'key' in name or 'value' in name:
        grad_mask = masks[name].to(param.device)
        param.requires_grad = True
        param.register_hook(make_hook(grad_mask))


In [None]:

# Optimise only the parameters left trainable (the hooked query/key/value tensors).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=0.003, eps=1e-8)



In [None]:
# Fine-tune the hooked parameters on the attacker's features for 5 epochs. After each epoch,
# report exact-string-match accuracy on the first N_QUICK_EVAL validation questions.
N_QUICK_EVAL = 100

for epoch in range(5):
    model.train()
    progress_bar = tqdm(attacker_dataloader, desc="Training")
    for batch in progress_bar:
        model.zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        progress_bar.set_postfix({'loss': loss.item()})

    # The model was just left in train mode. The pipeline presumably does not switch it
    # itself, so evaluate in eval mode to keep dropout out of the predictions.
    model.eval()
    correct = 0
    total = 0
    accuracy = 0.0
    progress_bar = tqdm(range(N_QUICK_EVAL))
    with torch.no_grad():
        for i in progress_bar:
            example = dataset['validation'][i]
            result = question_answerer(question=example['question'], context=example['context'])
            if result['answer'] in example['answers']['text']:
                correct += 1
            total += 1
            accuracy = correct / total
            progress_bar.set_postfix({'acc': accuracy})
    print("accuracy:", accuracy)
    

In [None]:
    # Full-validation exact-string-match accuracy of the current `model`.
    # The attack-training cell calls model.train(); switch to eval mode so dropout-style
    # layers do not perturb the pipeline's predictions.
    model.eval()
    validation = dataset['validation']
    correct = 0
    total = 0
    accuracy = 0.0
    progress_bar = tqdm(range(len(validation)))
    with torch.no_grad():
        for i in progress_bar:
            example = validation[i]
            result = question_answerer(question=example['question'], context=example['context'])
            if result['answer'] in example['answers']['text']:
                correct += 1
            total += 1
            accuracy = correct / total
            progress_bar.set_postfix({'acc': accuracy})
    print("accuracy:", accuracy)

In [None]:

from matplotlib.ticker import FuncFormatter
def plot_weight_comparison_gpu(model, ori_dict, name, ylim=20, weight_range=0.3):
    """Mirrored histogram of one weight tensor: original (up) vs. current (down).

    Counts are computed on the model's device with ``torch.histc`` (120 bins over
    ``[-weight_range, weight_range]``; values outside the range are ignored). Entries whose
    value changed by more than 1e-10 are overlaid in stronger colours.

    Args:
        model: model holding the current weights; ``model.state_dict()[name]`` is plotted.
        ori_dict: mapping from parameter name to the original weight tensor.
        name: state-dict key of the weight to compare.
        ylim: symmetric y-axis limit (default 20). ``None`` scales it to 1.2x the tallest
            bar, with a minimum of 100.
        weight_range: half-width of the x / histogram range (default 0.3).
    """
    device = next(model.parameters()).device
    weights_post = model.state_dict()[name].detach()
    weights_pre = ori_dict[name].detach().to(device)

    n_bins = 120
    bin_edges = torch.linspace(-weight_range, weight_range, steps=n_bins + 1, device=device).cpu().numpy()
    bin_width = bin_edges[1] - bin_edges[0]

    def gpu_hist(tensor):
        """Bin counts of ``tensor`` over the plotting range, as a NumPy array."""
        return torch.histc(tensor, bins=n_bins, min=-weight_range, max=weight_range).cpu().numpy()

    pre_freq = gpu_hist(weights_pre)
    post_freq = gpu_hist(weights_post)

    plt.figure(figsize=(10, 6))

    plt.bar(bin_edges[:-1], pre_freq, width=bin_width, align='edge',
            alpha=0.5, color='royalblue', edgecolor='none', label='Original Weights')
    plt.bar(bin_edges[:-1], -post_freq, width=bin_width, align='edge',
            alpha=0.5, color='coral', edgecolor='none', label='Current Weights')

    significant = torch.abs(weights_post - weights_pre) > 1e-10
    if significant.any():
        plt.bar(bin_edges[:-1], gpu_hist(weights_pre[significant]), width=bin_width, align='edge',
                alpha=0.8, color='blue', edgecolor='none', label='Significant Change (Orig)')
        plt.bar(bin_edges[:-1], -gpu_hist(weights_post[significant]), width=bin_width, align='edge',
                alpha=0.8, color='red', edgecolor='none', label='Significant Change (Current)')

    if ylim is None:
        # Auto-scale to the tallest bar (previously computed but always overwritten).
        ylim = max(int(max(pre_freq.max(), post_freq.max()) * 1.2), 100)

    plt.xlim(-weight_range, weight_range)
    plt.ylim(-ylim, ylim)
    plt.grid(True, which='both', linestyle=':', alpha=0.5)
    plt.axhline(0, color='k', linewidth=1)
    plt.title(f'Weight Distribution Comparison: {name}', fontsize=14)
    plt.xlabel('Weight Value')
    plt.ylabel('Frequency')
    # Show absolute counts on both halves of the mirrored plot.
    plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{abs(x):.0f}'))
    plt.legend()
    plt.tight_layout()
    plt.show()

def plot_combined_weights_gpu(model, ori_dict, weight_range=0.7, ylim=2500,
                              save_path='./roberta_squad.png'):
    """Save a mirrored histogram of all query/key/value weights, before vs. after.

    Upper half: all original weights from ``ori_dict`` ("Baseline"), with the original
    values of entries that changed by more than 1e-10 overlaid ("Before").
    Lower half: all current weights ("DistShield"), with the new values of those changed
    entries overlaid ("After").
    Counts use 120 ``torch.histc`` bins over ``[-weight_range, weight_range]``; values
    outside the range are ignored.

    Args:
        model: model whose ``state_dict()`` holds the current weights.
        ori_dict: mapping from state-dict key to the original weight tensor; must contain
            every query/key/value weight key of ``model``.
        weight_range: half-width of the x / histogram range (default 0.7).
        ylim: symmetric y-axis limit (default 2500).
        save_path: file the figure is written to (default './roberta_squad.png').
    """
    device = next(model.parameters()).device

    # Collect per-tensor chunks and concatenate once; growing a tensor with torch.cat
    # inside the loop re-copies the whole buffer every iteration.
    pre_chunks, post_chunks = [], []
    for name, param in model.state_dict().items():
        if ('query' in name or 'key' in name or 'value' in name) and 'weight' in name:
            post_chunks.append(param.detach().flatten())
            pre_chunks.append(ori_dict[name].detach().to(device).flatten())

    def concat(chunks):
        """Concatenate chunks, or return an empty tensor when there are none."""
        return torch.cat(chunks) if chunks else torch.empty(0, device=device)

    all_weights_pre = concat(pre_chunks)
    all_weights_post = concat(post_chunks)

    def gpu_hist(tensor):
        """120-bin counts of ``tensor`` over the plotting range, as a NumPy array."""
        return torch.histc(tensor, bins=120, min=-weight_range, max=weight_range).cpu().numpy()

    significant = torch.abs(all_weights_post - all_weights_pre) > 1e-10
    pre_freq = gpu_hist(all_weights_pre)
    post_freq = gpu_hist(all_weights_post)
    sig_pre_freq = gpu_hist(all_weights_pre[significant])
    sig_post_freq = gpu_hist(all_weights_post[significant])

    bin_edges = torch.linspace(-weight_range, weight_range, steps=121, device=device).cpu().numpy()
    bin_width = bin_edges[1] - bin_edges[0]

    fig, ax = plt.subplots(figsize=(12, 8))

    ax.bar(bin_edges[:-1], pre_freq, width=bin_width, align='edge',
           alpha=0.3, color='lightblue', edgecolor='black', label='Baseline')
    ax.bar(bin_edges[:-1], -post_freq, width=bin_width, align='edge',
           alpha=0.2, color='red', edgecolor='black', label='DistShield')
    ax.bar(bin_edges[:-1], sig_pre_freq, width=bin_width, align='edge',
           alpha=0.6, color='blue', edgecolor='black', label='Before')
    ax.bar(bin_edges[:-1], -sig_post_freq, width=bin_width, align='edge',
           alpha=0.6, color='red', edgecolor='black', label='After')

    ax.set_xlim(-weight_range, weight_range)
    ax.set_ylim(-ylim, ylim)
    ax.grid(True, which='both', linestyle=':', alpha=0.5)
    ax.axhline(0, color='k', linewidth=1)
    ax.set_title('DistShield on RoBERTa SQuAD', fontsize=20)
    ax.set_xlabel('Weight Value', fontsize=20)
    ax.set_ylabel('Frequency', fontsize=20)
    # Show absolute counts on both halves of the mirrored plot.
    ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{abs(x):.0f}'))
    ax.tick_params(axis='both', labelsize=20)
    ax.legend(loc='upper right', fontsize=18)
    # Lay out after enlarging the tick labels so their size is taken into account.
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)

In [None]:
# Render and save the combined before/after histogram of all attention projection weights.
plot_combined_weights_gpu(model=model, ori_dict=ori_dict)