In [None]:
import os, sys
os.chdir("../..")
sys.path.append(os.getcwd())

import breaching
import torch
import matplotlib.pyplot as plt
import numpy as np

import scipy.stats as stats
import scipy.integrate as integrate
import numpy as np
from tqdm import tqdm

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
# Output directory for averaged sensitivity tensors. Kept as a plain string with a
# trailing "/" because file names are appended to it by string concatenation.
sensitivity_path = f"{os.getcwd()}/lm/sensitivity/"
print(sensitivity_path)

## Sensitivity Calculation

In [None]:
def min_max_normalize(l: list) -> list:
    """Linearly rescale values so the minimum maps to 0.0 and the maximum to 1.0.

    Returns [] for empty input. When all values are equal the range is zero, so
    every value maps to 0.0 instead of raising ZeroDivisionError.
    """
    if not l:
        return []
    l_min, l_max = min(l), max(l)
    span = l_max - l_min
    if span == 0:
        # No spread to normalize against.
        return [0.0 for _ in l]
    return [(x - l_min) / span for x in l]


def scale_to_100(l: list) -> list:
    """Min-max normalize `l`, then rescale so the values sum to 100 (percent shares).

    The smallest input always receives 0%. If the normalized values sum to zero
    (constant or single-element input), no entry stands out, so the 100% is split
    evenly instead of dividing by zero. Empty input returns [].
    """
    l_norm = min_max_normalize(l)
    total = sum(l_norm)
    if total == 0:
        return [100.0 / len(l_norm) for _ in l_norm]
    return [x / total * 100 for x in l_norm]

def plot_layer_sens_mean(sens: list, model_name: str):
    """Draw a bar chart of each entry's mean sensitivity as a percentage share.

    `sens` is a sequence of tensors. When it comes from get_sensitivity, there is
    one tensor per entry of `model.parameters()`, so separate weight/bias tensors
    get separate bars. Each tensor is reduced to its mean, the means are turned
    into percentages by `scale_to_100`, and bars are drawn at positions 1..N.
    """
    means = []
    for tensor in sens:
        means.append(tensor.mean().item())
    shares = scale_to_100(means)
    positions = np.arange(1, len(shares) + 1)

    plt.bar(positions, shares)
    plt.xlabel("Layer Index")
    plt.ylabel("Mean Sensitivity Ratio (%)")
    plt.title(model_name)
    plt.show()


In [None]:
def get_data_point(user, setup):
    """Return the first example of the user's first batch, on `setup["device"]`.

    Every tensor in the first batch is moved to the device and sliced to its first
    row (a leading batch dimension of 1 is kept). If `user.generator_input` is not
    None, a sample from it, shaped like the input, is added to the "input_ids"
    entry, or to "inputs" when there is no "input_ids" key. An empty dataloader
    yields {}.
    """
    missing = object()
    first_batch = next(iter(user.dataloader), missing)
    if first_batch is missing:
        return {}

    target_device = setup["device"]
    on_device = {name: first_batch[name].to(device=target_device) for name in first_batch}
    input_key = "input_ids" if "input_ids" in on_device else "inputs"

    single = {name: tensor[0:1] for name, tensor in on_device.items()}
    current = single[input_key]
    if user.generator_input is not None:
        current = current + user.generator_input.sample(current.shape)
    single[input_key] = current
    return single

In [None]:
import copy
from tqdm import tqdm

def _one_hot_targets(outputs, labels, shift=0):
    """Build a one-hot regression target shaped like `outputs` for sequence 0.

    For every position i >= 1 of the first sequence, the entry at token id
    (labels[0, i] + shift) mod vocab_size is set to 1, with
    vocab_size = outputs.shape[-1]. Position 0 stays all-zero, as before.
    The modulo keeps the +/-1 neighbours of the first and last token ids inside
    the vocabulary instead of raising IndexError or wrapping only on one side.
    """
    targets = torch.zeros_like(outputs)  # same dtype/device as outputs, no grad
    vocab_size = outputs.shape[-1]
    for i in range(1, labels.shape[-1]):
        token = int(labels[0, i].item())
        targets[0, i, (token + shift) % vocab_size] = 1
    return targets


def _param_grads(model, loss):
    """Return d(loss)/d(param) for every tensor in `model.parameters()`, in that order.

    Uses torch.autograd.grad with retain_graph=True so the same forward graph can
    be differentiated again for another target. `.grad` fields are neither read
    nor written, so stale gradients cannot leak in. Frozen parameters and
    parameters that do not influence the loss get zero tensors.
    """
    params = list(model.parameters())
    trainable = [p for p in params if p.requires_grad]
    grads = (
        torch.autograd.grad(loss, trainable, retain_graph=True, allow_unused=True)
        if trainable
        else ()
    )
    grad_iter = iter(grads)
    result = []
    for p in params:
        g = next(grad_iter) if p.requires_grad else None
        result.append(torch.zeros_like(p) if g is None else g.detach())
    return result


def _max_abs_neighbour_diff(center, minus, plus):
    """Element-wise max(|minus - center|, |center - plus|) over aligned tensor lists."""
    assert len(center) == len(minus) == len(plus)
    return [
        torch.max(torch.abs(m - c), torch.abs(c - p))
        for c, m, p in zip(center, minus, plus)
    ]


def get_sensitivity(model, loss_fn, data_point, device, discrete_grad=False, grad_on_x=False):
    """Per-parameter sensitivity of loss_fn(model(**data_point), one_hot_target).

    The target is a one-hot tensor shaped like the model output, with a 1 at the
    label token of every position >= 1 of the first sequence. The model output's
    last dimension is therefore assumed to be indexed by token id.

    Returns one tensor per entry of `model.parameters()`, in that order:
      * grad_on_x=True: max(|g(x-1) - g(x)|, |g(x) - g(x+1)|), where g = dL/dw and
        x +/- 1 shifts every token id of the first sequence in "input_ids" by one
        (mod vocab size).
      * discrete_grad=True: the same finite difference, taken over the label ids.
      * otherwise: the plain (signed) gradient dL/dw.

    `data_point` is not modified. `device` is kept for backward compatibility;
    targets are allocated on the same device as the model output.
    """
    model.eval()
    outputs = model(**data_point)
    labels = data_point["labels"]
    targets = _one_hot_targets(outputs, labels)
    center = _param_grads(model, loss_fn(outputs, targets))

    if grad_on_x:
        vocab_size = outputs.shape[-1]

        def grads_with_shifted_ids(shift):
            # Shift a copy of the ids so the caller's data_point is left untouched.
            ids = data_point["input_ids"].clone()
            ids[0] = (ids[0] + shift) % vocab_size
            shifted_outputs = model(**{**data_point, "input_ids": ids})
            return _param_grads(model, loss_fn(shifted_outputs, targets))

        return _max_abs_neighbour_diff(center, grads_with_shifted_ids(-1), grads_with_shifted_ids(+1))

    if not discrete_grad:
        # The neighbouring-label gradients are only needed for the discrete variant.
        return center

    minus = _param_grads(model, loss_fn(outputs, _one_hot_targets(outputs, labels, shift=-1)))
    plus = _param_grads(model, loss_fn(outputs, _one_hot_targets(outputs, labels, shift=+1)))
    return _max_abs_neighbour_diff(center, minus, plus)

In [None]:
def get_mean_sens(cfg_config, model_name, device, num_user=5, discrete_grad=False, grad_on_x=False):
    """Average get_sensitivity() over users 1..num_user, save it, and return it.

    For each user index a fresh (user, model, loss) case is built from
    `cfg_config.case`. This mutates `cfg_config.case.user.user_idx`, which is
    left at `num_user` afterwards. Every user's sensitivity gets the same
    weight 1/num_user.

    The result is written with torch.save to
    sensitivity_path + model_name + "_mean_sens" [+ "_discrete"] [+ "_gradx"] + ".pt",
    and the directory is created if needed. The "_gradx" suffix keeps
    grad_on_x results from overwriting the label-based ones.

    Returns:
        (sens_mean, sens_mean_path): a list with one tensor per model parameter
        ([] when num_user is 0), and the path it was saved to.
    """
    sens_mean_path = sensitivity_path + model_name + "_mean_sens"
    if discrete_grad:
        sens_mean_path += "_discrete"
    if grad_on_x:
        sens_mean_path += "_gradx"
    sens_mean_path += ".pt"

    sens_sum = None
    for i in tqdm(range(num_user)):
        cfg_config.case.user.user_idx = i + 1  # user indices start at 1 here
        setup = dict(device=device, dtype=getattr(torch, cfg_config.case.impl.dtype))
        user, _server, model, loss_fn = breaching.cases.construct_case(cfg_config.case, setup)
        model.to(device=setup["device"])

        data_point = get_data_point(user, setup)
        sens_single = get_sensitivity(model, loss_fn,
                                      copy.deepcopy(data_point),
                                      setup["device"],
                                      discrete_grad,
                                      grad_on_x)
        if sens_sum is None:
            sens_sum = list(sens_single)
        else:
            sens_sum = [acc + s for acc, s in zip(sens_sum, sens_single)]

    # Divide once at the end so every user contributes exactly 1/num_user.
    sens_mean = [] if sens_sum is None else [acc / num_user for acc in sens_sum]
    os.makedirs(os.path.dirname(sens_mean_path) or ".", exist_ok=True)
    torch.save(sens_mean, sens_mean_path)
    return sens_mean, sens_mean_path


## Transformer3

In [None]:
cfg = breaching.get_config(overrides=["case=10_causal_lang_training", "attack=tag"])

# Prefer the GPU when present. The get_mean_sens call for this model passes CPU explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark

# Each user contributes one sentence of sequence length 2.
cfg.case.user.num_data_points = 1  # sentences per user
cfg.case.data.shape = [2]  # sequence length
cfg.case.model = "transformer3"

In [None]:
# Average the sensitivity over 5 users on CPU, then plot the per-entry shares.
transformer3_sens_mean, transformer3_sens_mean_path = get_mean_sens(
    cfg,
    "200_transformer3_tag",
    torch.device("cpu"),
    num_user=5,
    discrete_grad=False,
)
plot_layer_sens_mean(transformer3_sens_mean, "Transformer3 (TAG)")

## Transformer3f

In [None]:
cfg = breaching.get_config(overrides=["case=10_causal_lang_training", "attack=tag"])

# Prefer the GPU when present. The get_mean_sens call for this model passes CPU explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark

# One sentence of sequence length 2, starting from user 1.
cfg.case.user.num_data_points = 1  # sentences per user
cfg.case.user.user_idx = 1  # which user
cfg.case.data.shape = [2]  # sequence length
cfg.case.model = "transformer3f"

In [None]:
# Average the sensitivity over 5 users on CPU, then plot the per-entry shares.
transformer3f_sens_mean, transformer3f_sens_mean_path = get_mean_sens(
    cfg,
    "transformer3f_tag",
    torch.device("cpu"),
    num_user=5,
    discrete_grad=False,
)
plot_layer_sens_mean(transformer3f_sens_mean, "Transformer3f (TAG)")

## Transformer3t

In [None]:
cfg = breaching.get_config(overrides=["case=10_causal_lang_training", "attack=tag"])

# Prefer the GPU when present. The get_mean_sens call for this model passes CPU explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark

# One sentence of sequence length 2, starting from user 1.
cfg.case.user.num_data_points = 1  # sentences per user
cfg.case.user.user_idx = 1  # which user
cfg.case.data.shape = [2]  # sequence length
cfg.case.model = "transformer3t"

In [None]:
# Average the sensitivity over 5 users on CPU, then plot the per-entry shares.
transformer3t_sens_mean, transformer3t_sens_mean_path = get_mean_sens(
    cfg,
    "transformer3t_tag",
    torch.device("cpu"),
    num_user=5,
    discrete_grad=False,
)
plot_layer_sens_mean(transformer3t_sens_mean, "Transformer3t (TAG)")

## TransformerS

In [None]:
cfg = breaching.get_config(overrides=["case=10_causal_lang_training", "attack=tag"])

# Prefer the GPU when present. The get_mean_sens call for this model passes CPU explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark

# One sentence of sequence length 2, starting from user 1.
cfg.case.user.num_data_points = 1  # sentences per user
cfg.case.user.user_idx = 1  # which user
cfg.case.data.shape = [2]  # sequence length
cfg.case.model = "transformerS"

In [None]:
# Average the sensitivity over 5 users on CPU, then plot the per-entry shares.
transformerS_sens_mean, transformerS_sens_mean_path = get_mean_sens(
    cfg,
    "transformerS_tag",
    torch.device("cpu"),
    num_user=5,
    discrete_grad=False,
)
plot_layer_sens_mean(transformerS_sens_mean, "TransformerS (TAG)")

## Llama 3

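This section adds Llama 3, the model used in the paper, to the sensitivity-calculation example. The model is loaded directly through Hugging Face Transformers and does not depend on the breaching case's GPT-2 configuration.
Note: the first run downloads the model weights. If the weights are already on disk, set `model_id` to that local path. The `meta-llama` checkpoints on the Hugging Face Hub are gated, so you may first need to accept the license and log in (e.g. `huggingface-cli login`).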

In [None]:
# --- Llama 3 sensitivity (Transformers) ---
# If you haven't installed deps, run in Terminal:
#   pip install -U transformers accelerate sentencepiece

import torch
import torch.nn as nn

from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face Hub id of the checkpoint, or a local directory holding the weights, e.g.
#   model_id = r"D:\models\Meta-Llama-3-8B"
model_id = "meta-llama/Meta-Llama-3-8B"

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

if device.type == "cuda":
    # Half precision; device_map="auto" delegates weight placement to Accelerate.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto"
    )
else:
    # Full precision, no automatic device placement.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float32, device_map=None
    )

class LogitsOnly(nn.Module):
    """Wrap a causal-LM module so that calling it returns only its `.logits`.

    get_sensitivity() passes the model output straight into the loss, so it
    needs a plain tensor rather than the base model's output container.
    `labels` and any extra keyword arguments are accepted for call compatibility
    but are not forwarded to `base`.
    """

    def __init__(self, base):
        super().__init__()
        self.base = base  # stored as a submodule, so its parameters are exposed

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        model_output = self.base(input_ids=input_ids, attention_mask=attention_mask)
        logits = model_output.logits
        return logits

# MSE between the logits and the one-hot targets built by get_sensitivity.
loss_fn = nn.MSELoss()
# Wrapped so model(**data_point) yields a logits tensor.
model = LogitsOnly(base_model)


In [None]:
def build_llama_data_point(text: str, seq_len: int = 2, tok=None, target_device=None):
    """Tokenize `text` into a single-sequence data point for get_sensitivity().

    Only the first `seq_len` tokens are kept. For Llama 3 these presumably
    include the BOS token the tokenizer prepends (TODO confirm for your tokenizer
    settings). The attention mask is truncated to the same length, so it matches
    `input_ids`; previously it kept the full untruncated length.

    Args:
        text: input string.
        seq_len: number of leading tokens to keep.
        tok: tokenizer to use; defaults to the notebook-level `tokenizer`.
        target_device: device for the returned tensors; defaults to the
            notebook-level `device`.

    Returns:
        dict with "input_ids", "attention_mask" and "labels". "labels" is a copy
        of "input_ids", i.e. token ids with shape [B, T].
    """
    tok = tokenizer if tok is None else tok
    target_device = device if target_device is None else target_device

    enc = tok(text, return_tensors="pt")
    input_ids = enc["input_ids"][:, :seq_len]
    attention_mask = enc.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    attention_mask = attention_mask[:, :seq_len]  # keep aligned with input_ids

    input_ids = input_ids.to(target_device)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask.to(target_device),
        "labels": input_ids.clone(),
    }

def get_mean_sens_llama(texts, num_user=5, seq_len=2, discrete_grad=False, grad_on_x=False):
    """Average get_sensitivity() for the notebook-level Llama `model` over `num_user` samples.

    Sample i uses texts[i % len(texts)], built by build_llama_data_point with
    `seq_len` tokens. Every sample contributes with the same weight 1/num_user.
    Previously the first sample was not divided by num_user, so it dominated the mean.

    Returns:
        A list with one tensor per model parameter ([] when num_user is 0).

    Raises:
        ValueError: if `texts` is empty while num_user > 0.
    """
    if num_user > 0 and not texts:
        raise ValueError("texts must contain at least one string")

    sens_sum = None
    for i in tqdm(range(num_user)):
        dp = build_llama_data_point(texts[i % len(texts)], seq_len=seq_len)
        sens_single = get_sensitivity(model, loss_fn, dp, device,
                                      discrete_grad=discrete_grad, grad_on_x=grad_on_x)
        if sens_sum is None:
            sens_sum = list(sens_single)
        else:
            sens_sum = [acc + s for acc, s in zip(sens_sum, sens_single)]

    return [] if sens_sum is None else [acc / num_user for acc in sens_sum]

# Short illustrative sentences; replace them with samples from the paper's
# actual corpus (e.g. WikiText).
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Federated learning improves privacy but may leak gradients.",
    "Selective encryption protects the most sensitive parameters.",
    "Transformers process sequences of tokens efficiently.",
    "Differential privacy adds noise to limit information leakage.",
]

# One sample per sentence above (len(texts) == 5).
llama_sens_mean = get_mean_sens_llama(texts, num_user=len(texts), seq_len=2, discrete_grad=False)
plot_layer_sens_mean(llama_sens_mean, "Llama 3 (TAG)")
