In [None]:
import json
import os
import pickle
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm

from utils.metric import evaluate, evaluate_multi_cls
warnings.filterwarnings("ignore")

In [None]:
def read_data(pkl_path, loc):
    """Load a pickled list of sample dicts and split the field `loc` by label.

    Records whose 'label' == 0 go to the first returned tensor; every other
    record goes to the second. Each group is stacked with torch.stack and cast
    to float32 (torch.stack raises if a group is empty).

    NOTE: pickle.load can execute arbitrary code -- only load trusted files.

    Returns:
        (pos_data, neg_data) float tensors.
    """
    with open(pkl_path, 'rb') as fh:
        records = pickle.load(fh)

    pos_bucket, neg_bucket = [], []
    for rec in records:
        (pos_bucket if rec['label'] == 0 else neg_bucket).append(rec[loc])

    return torch.stack(pos_bucket).float(), torch.stack(neg_bucket).float()

def read_data_all(pkl_path):
    """Load a pickled list of sample dicts and split the whole records by label.

    Unlike read_data, the dicts are returned unchanged (nothing is stacked).
    Records whose 'label' == 0 form the first list; all others form the second.
    File order is preserved within each list.

    NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    """
    with open(pkl_path, 'rb') as fh:
        records = pickle.load(fh)

    pos_data, neg_data = [], []
    for rec in records:
        (pos_data if rec['label'] == 0 else neg_data).append(rec)
    return pos_data, neg_data

def mix_data_fitting(x_train, y_train, model_name='lr'):
    """Fit one independent classifier per slot along axis 1 of x_train.

    Args:
        x_train: indexed as x_train[:, slot]; presumably (n, slots, features) so
            each slice is a 2-D feature matrix (e.g. one attention head) -- TODO confirm.
        y_train: labels aligned with x_train's first axis.
        model_name: 'lda' (LinearDiscriminantAnalysis), 'lr' (L2 LogisticRegression)
            or 'sgd' (SGDClassifier with log loss).

    Returns:
        List of fitted estimators, one per slot, in slot order.

    Raises:
        NotImplementedError: for any other model_name. The check runs inside the
            loop, so it is only reached when x_train has at least one slot.
    """
    # Lambdas defer construction until a matching name is found.
    builders = (
        ('lda', lambda: LinearDiscriminantAnalysis()),
        ('lr', lambda: LogisticRegression(penalty='l2')),
        ('sgd', lambda: SGDClassifier(loss='log_loss')),
    )
    fitted = []
    for slot in range(x_train.shape[1]):
        build = next((make for name, make in builders if model_name == name), None)
        if build is None:
            raise NotImplementedError
        estimator = build()
        estimator.fit(x_train[:, slot], y_train)
        fitted.append(estimator)
    return fitted

def evaluation(x_val, y_val, model_list, shot=0, threshold=0.5, verbose=False):
    """Score one fitted classifier per slot (axis 1 of x_val) with evaluate_multi_cls.

    Args:
        x_val: features indexed as x_val[:, layer]; presumably (n, slots, features)
            -- TODO confirm.
        y_val: labels aligned with x_val's first axis.
        model_list: fitted sklearn-style classifiers, one per slot.
        shot: if > 0, each model is first updated IN PLACE with partial_fit on the
            first `shot` samples, and only the remaining samples are scored. This
            needs an estimator with partial_fit (e.g. SGDClassifier, not LR/LDA).
        threshold: decision threshold forwarded to evaluate_multi_cls.
        verbose: print the labels and predicted probabilities for every slot
            (previously always printed; off by default to avoid flooding output).

    Returns:
        (accs, asrs, prs, f1s, aurocs, y_pred): five per-slot metric lists, plus the
        column-1 predict_proba vector of the LAST model (None if model_list is empty).
    """
    accs_mlp, asrs_mlp, prs_mlp, f1s_mlp, auroc_mlp = [], [], [], [], []
    y_pred = None  # defined up front so an empty model_list does not raise UnboundLocalError
    for layer, model in enumerate(model_list):
        x_l, y_l = x_val[:, layer], y_val
        if shot > 0:
            # NOTE(review): hard-codes binary classes [0, 1] -- confirm for multi-class labels.
            model.partial_fit(x_val[:shot, layer], y_val[:shot], classes=[0, 1])
            x_l, y_l = x_val[shot:, layer], y_val[shot:]
        # Probability column 1, i.e. the second entry of model.classes_.
        y_pred = model.predict_proba(x_l)[:, 1]
        if verbose:
            print(y_l)
            print(y_pred)
        # The last two return values (per-sample analyses) are not used here.
        acc, asr, pr, f1, auroc, _neg_per_sample, _pos_per_sample = \
            evaluate_multi_cls(y_l, y_pred, show=False, threshold=threshold, enable_analyse=True)
        accs_mlp.append(acc)
        asrs_mlp.append(asr)
        prs_mlp.append(pr)
        f1s_mlp.append(f1)
        auroc_mlp.append(auroc)
    return accs_mlp, asrs_mlp, prs_mlp, f1s_mlp, auroc_mlp, y_pred


In [None]:
def train_val_split_multi_cls(train_data, train_label_list, num_shot=10):
    """Randomly pick `num_shot` samples for training and keep the rest for validation.

    Args:
        train_data: tensor whose first axis indexes samples.
        train_label_list: label tensor aligned with train_data's first axis.
        num_shot: number of training samples. -1 puts every sample in the
            training split, leaving the validation split empty.

    Returns:
        (x_train, y_train, x_val, y_val). Sampling uses the global `random`
        module, so the split is reproducible only under random.seed.
    """
    n_samples = train_data.shape[0]
    n_train = n_samples if num_shot == -1 else int(num_shot)

    picked = random.sample(range(n_samples), n_train)
    held_out = list(set(range(n_samples)) - set(picked))

    x_train, y_train = (torch.cat([t[picked]], dim=0) for t in (train_data, train_label_list))
    x_val, y_val = (torch.cat([t[held_out]], dim=0) for t in (train_data, train_label_list))
    return x_train, y_train, x_val, y_val

def read_data_single(pkl_path, key_words, loc):
    """Build a size-matched (label 0, label 1) tensor pair for one scenario.

    Negatives are field `loc` of every record whose 'img_path' contains the
    substring `key_words` and whose 'label' == 1. Positives are field `loc` of
    label-0 records in file order. Collection stops once their count equals the
    negative count, and the count is checked after every record. If there are
    no negatives, this stops after the first record unless that record is a
    positive. Prints `key_words`, the negative count and the positive count.

    NOTE: pickle.load can execute arbitrary code -- only load trusted files.

    Returns:
        (pos_data, neg_data) as float32 stacked tensors; asserts equal lengths.
    """
    with open(pkl_path, 'rb') as fh:
        records = pickle.load(fh)

    neg_data = [rec[loc] for rec in records
                if key_words in rec['img_path'] and rec['label'] == 1]
    num_neg = len(neg_data)

    pos_data = []
    for rec in records:
        if rec['label'] == 0:
            pos_data.append(rec[loc])
        if len(pos_data) == num_neg:
            break
    num_pos = len(pos_data)

    print(key_words, num_neg, num_pos)
    pos_data = torch.stack(pos_data).float()
    neg_data = torch.stack(neg_data).float()
    assert pos_data.shape[0] == neg_data.shape[0]
    return pos_data, neg_data

def get_fine_grained_data(data, top_k_indices):
    """Gather the selected (layer, head) slices and flatten them per sample.

    Args:
        data: activations indexed as data[:, layer, head, ...]; presumably
            (n, layers, heads, head_dim) -- TODO confirm.
        top_k_indices: (K, 2) integer array of (layer, head) pairs.

    Returns:
        Tensor of shape (data.shape[0], 1, -1): the K gathered slices of each
        sample concatenated in the order of top_k_indices.
    """
    layer_idx = top_k_indices[:, 0]
    head_idx = top_k_indices[:, 1]
    picked = data[:, layer_idx, head_idx]
    return picked.reshape(data.shape[0], 1, -1)

def few_shot_probing_multi_cls(train_data, train_label_list, num_shot, num_repeat=10, threshold=0.5, model_name='lr'):
    """Repeat: random split -> fit one probe per slot -> evaluate on the held-out part.

    Args:
        train_data: activations; a tensor with more than 3 dims is unpacked as
            (n, layers, heads, head_dim) and reshaped to (n, layers * heads, head_dim),
            so every head gets its own probe.
        train_label_list: per-sample labels (converted with torch.Tensor, i.e. float).
        num_shot: training samples per repeat (-1 = all; see train_val_split_multi_cls).
        num_repeat: number of random splits.
        threshold: forwarded to evaluation.
        model_name: 'lr', 'lda' or 'sgd' (see mix_data_fitting).

    Returns:
        (metrics, model_list).
        metrics: 'acc', 'asr', 'f1' and 'auroc' map to arrays with one row per
            repeat and one column per slot; 'threshold' echoes the threshold used.
        model_list: the fitted probe lists, one per repeat.
    """
    # Loop-invariant preprocessing, done once instead of on every repeat.
    if len(train_data.shape) > 3:   # 'attn_heads': merge layer and head axes
        n, nl, nh, dh = train_data.shape
        train_data = train_data.reshape(n, -1, dh)
    train_label_tensor = torch.Tensor(train_label_list)

    model_list = []
    accs, asrs, f1s, aurocs = [], [], [], []
    for _ in tqdm(range(num_repeat)):
        x_train, y_train, x_val, y_val = train_val_split_multi_cls(train_data, train_label_tensor, num_shot=num_shot)
        models = mix_data_fitting(x_train, y_train, model_name=model_name)
        model_list.append(models)
        # evaluation returns (accs, asrs, prs, f1s, aurocs, last_y_pred). Appending the
        # whole tuple mixed per-slot lists with a prediction vector of another length,
        # which made np.array(accs) ragged (ValueError on NumPy >= 1.24).
        accs_mlp, asrs_mlp, _prs_mlp, f1s_mlp, auroc_mlp, _ = evaluation(x_val, y_val, models, threshold=threshold)
        accs.append(accs_mlp)
        asrs.append(asrs_mlp)
        f1s.append(f1s_mlp)
        aurocs.append(auroc_mlp)

    return {
        'acc': np.array(accs),
        'asr': np.array(asrs),
        'f1': np.array(f1s),
        'auroc': np.array(aurocs),
        'threshold': threshold,
    }, model_list

def testing(test_data, labels, model_list, shot=0, threshold=0.5):
    """Evaluate every repeat's probe set (as returned by few_shot_probing_multi_cls).

    Args:
        test_data: activations; a tensor with more than 3 dims is unpacked as
            (n, layers, heads, head_dim) and reshaped to (n, layers * heads, head_dim).
        labels: labels aligned with test_data's first axis.
        model_list: list of per-slot probe lists, one per repeat.
        shot: forwarded to evaluation. If > 0, evaluation calls partial_fit, so the
            models in model_list are modified IN PLACE.
        threshold: forwarded to evaluation.

    Returns:
        Dict mapping 'acc', 'asr', 'f1' and 'auroc' to arrays with one row per
        repeat and one column per slot, plus 'threshold'.
    """
    # Loop-invariant preprocessing, done once instead of per repeat.
    if len(test_data.shape) > 3:   # 'attn_heads': merge layer and head axes
        n, nl, nh, dh = test_data.shape
        test_data = test_data.reshape(n, -1, dh)

    accs, asrs, f1s, aurocs = [], [], [], []
    for models in tqdm(model_list):
        # Keep only the per-slot metric lists; the trailing prediction vector has a
        # different length and previously made np.array(accs) ragged.
        accs_mlp, asrs_mlp, _prs_mlp, f1s_mlp, auroc_mlp, _ = evaluation(
            test_data, labels, models, shot=shot, threshold=threshold)
        accs.append(accs_mlp)
        asrs.append(asrs_mlp)
        f1s.append(f1s_mlp)
        aurocs.append(auroc_mlp)
    return {
        'acc': np.array(accs),
        'asr': np.array(asrs),
        'f1': np.array(f1s),
        'auroc': np.array(aurocs),
        'threshold': threshold,
    }

In [None]:
# Per-scenario (label 0, label 1) activation pairs for scenarios 1..13.
# Scenario i is selected by the substring f"{i:02d}-" in each record's img_path.
# read_data_single asserts the two tensors of each pair have equal length.
# NOTE(review): read_data_single re-loads the whole pickle on each of the 13 calls.
_per_scenario = [
    read_data_single('output/LLaVA-7B/raw_mmsafety_SD_TYPO_all_oe_activations.pkl',
                     f"{scenario:02d}-", loc='attn_heads')
    for scenario in range(1, 14)
]
pos_data_list = [pos for pos, _ in _per_scenario]
neg_data_list = [neg for _, neg in _per_scenario]

### Selected attention heads: (layer, head) indices copied from 2-1

In [None]:
# Selected LLaVA-7B attention heads (copied from 2-1), one (layer, head) pair per row.
# Column 0 indexes axis 1 and column 1 indexes axis 2 of the 'attn_heads'
# activations (see get_fine_grained_data).
top_k_indices_llava = np.array([
    (4, 30), (5, 2), (6, 22), (6, 29), (7, 0), (7, 7), (7, 26), (8, 0),
    (8, 9), (8, 21), (8, 30), (9, 20), (10, 5), (10, 22), (11, 0), (11, 1),
    (11, 21), (13, 17), (14, 14), (15, 1), (15, 6), (16, 7), (16, 25), (16, 29),
    (17, 7), (20, 6), (22, 6), (22, 15), (23, 4), (27, 11), (28, 2), (28, 13),
])

In [None]:
# Selected Qwen attention heads, one (layer, head) pair per row.
# Used in place of the LLaVA table when the commented-out qwen lines in the
# loading cell are enabled.
top_k_indices_qwen = np.array([
    (0, 0), (3, 11), (4, 25), (5, 5), (5, 28), (5, 29), (6, 28), (7, 24),
    (8, 15), (10, 1), (10, 29), (11, 2), (11, 28), (12, 7), (12, 8), (12, 9),
    (12, 13), (13, 9), (13, 19), (13, 27), (15, 0), (15, 16), (17, 14), (18, 22),
    (19, 1), (20, 12), (20, 13), (22, 15), (26, 0), (28, 9), (30, 11), (31, 21),
])

### Multi-class classification over 13 scenario types

In [None]:
# 13-class split: scenario i (prefixes "01-".."13-") gets label i + 1. From each
# scenario's label-1 ("neg") samples, the first TRAIN_FRACTION go to training and
# the rest to test. The split takes the first rows and does not shuffle.
# Uses neg_data_list from the per-scenario loading cell. The previous
# globals()['neg_data_XX'] lookups referred to variables that no cell in this
# notebook defines, so the cell failed on a fresh kernel.
TRAIN_FRACTION = 0.1

train_data = []
test_data = []
train_label_list = []
test_label_list = []
for i, neg_data in enumerate(neg_data_list):
    # Can be 0 for a scenario with fewer than 1 / TRAIN_FRACTION samples.
    num_train = int(TRAIN_FRACTION * neg_data.shape[0])

    train_data.append(neg_data[:num_train])
    test_data.append(neg_data[num_train:])
    train_label_list += [i + 1] * num_train
    test_label_list += [i + 1] * (neg_data.shape[0] - num_train)

train_data = torch.cat(train_data, dim=0)  # presumably [13 * 10%, 32, 32, 128] for LLaVA-7B attn_heads -- TODO confirm
test_data = torch.cat(test_data, dim=0)
print(train_data.shape)
print(test_data.shape)
print(train_label_list)

In [None]:
# Load every sample's 'attn_heads' activations, group them by d['scenario'], and keep
# only the selected heads. get_fine_grained_data(...).squeeze(1) leaves one 2-D
# (num_samples, features) tensor per scenario id.
# Switch both commented-out alternatives to run the Qwen variant.
activations_pkl = '/workspace/safety_heads/Attack/output/LLaVA-7B/raw_mmsafety_SD_TYPO_all_oe_activations.pkl'
# activations_pkl = 'output/qwen/raw_mmsafety_SD_TYPO_all_oe_activations.pkl'
head_indices = top_k_indices_llava
# head_indices = top_k_indices_qwen
num_scenarios = 14  # scenario ids 0..13 are expected in d['scenario']

# NOTE(review): absolute path, unlike the relative path used by the per-scenario
# loading cell -- confirm both refer to the same file.
# pickle.load can execute arbitrary code; only load trusted files.
with open(activations_pkl, 'rb') as fh:
    data = pickle.load(fh)

grouped = {key: [] for key in range(num_scenarios)}
for d in data:
    grouped[d['scenario']].append(d['attn_heads'])

# torch.stack([]) fails with an opaque error; report which scenario is missing instead.
empty_scenarios = [k for k, v in grouped.items() if not v]
if empty_scenarios:
    raise ValueError(f"No samples found for scenario id(s) {empty_scenarios}; every class needs data.")

all_class_data = {
    k: get_fine_grained_data(torch.stack(v), head_indices).squeeze(1)
    for k, v in grouped.items()
}

In [None]:
# Sanity check: list the scenario ids present in all_class_data (the loading cell creates keys 0-13).
all_class_data.keys()

In [None]:
# Multi-class LDA probe on the selected heads. For each training ratio, draw
# `num_repeat` random per-class train/test splits and fit LDA. Record accuracy and
# per-class precision/recall/F1, then write a JSON summary and a text log.
train_ratios = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
num_repeat = 20
seed = 0  # fixes the random splits so the reported numbers are reproducible
random.seed(seed)

class_ids = sorted(all_class_data)              # scenario ids (0..13 for this dataset)
target_names = [str(c) for c in class_ids]      # classification_report dict keys
metric_names = ('precision', 'recall', 'f1-score')

accs_dict = {ratio: [] for ratio in train_ratios}
all_data = []
log_output = []


def _split_per_class(class_data, ratio):
    """Randomly split every class's samples into train/test parts.

    Each class keeps at least one training sample. When it has two or more
    samples, it also keeps at least one test sample. Returns numpy arrays
    (X_train, X_test, y_train, y_test): float32 features and int64 class ids.
    """
    X_tr, X_te, y_tr, y_te = [], [], [], []
    for k, v in class_data.items():
        per_class_num = v.shape[0]
        num_train = max(1, min(int(ratio * per_class_num), per_class_num - 1))
        train_idx = random.sample(range(per_class_num), num_train)
        val_idx = list(set(range(per_class_num)) - set(train_idx))
        X_tr.append(v[train_idx])
        X_te.append(v[val_idx])
        y_tr.append(torch.full((num_train,), k, dtype=torch.long))
        y_te.append(torch.full((per_class_num - num_train,), k, dtype=torch.long))
    # Cast to float32 before .numpy(); the stored activation dtype may not be numpy-compatible.
    return (torch.cat(X_tr).to(torch.float32).numpy(),
            torch.cat(X_te).to(torch.float32).numpy(),
            torch.cat(y_tr).numpy(),
            torch.cat(y_te).numpy())


def _mean_class_metrics(reports, names, metrics):
    """Average each class's metrics over classification_report dicts (nan if never reported)."""
    return {
        name: {m: float(np.mean([r[name][m] for r in reports if name in r])) for m in metrics}
        for name in names
    }


for train_ratio in train_ratios:
    # :g avoids float noise such as "15.000000000000002%".
    print(f"\nTraining with {train_ratio * 100:g}% of data")
    reports = []

    for _ in tqdm(range(num_repeat)):
        X_train, X_test, y_train, y_test = _split_per_class(all_class_data, train_ratio)

        model = LDA()
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        accs_dict[train_ratio].append(accuracy_score(y_test, y_pred))

        # labels= pins the class list so it always lines up with target_names, even in a
        # split where some class is absent from both y_test and y_pred.
        reports.append(classification_report(
            y_test, y_pred, labels=class_ids, target_names=target_names,
            output_dict=True, zero_division=0))

    acc_array = np.array(accs_dict[train_ratio])
    mean_acc, std_acc = acc_array.mean(), acc_array.std()
    log_output.append(f"Ratio {train_ratio}: Mean accuracy = {mean_acc:.4f}, Std = {std_acc:.4f}")

    class_metrics = _mean_class_metrics(reports, target_names, metric_names)
    log_output.append(f"\nDetailed metrics for ratio {train_ratio}:")
    log_output.append("Class | Precision | Recall | F1-Score")
    log_output.append("-" * 45)
    for cls, name in zip(class_ids, target_names):
        m = class_metrics[name]
        log_output.append(f"{cls:5d} | {m['precision']:.4f}    | {m['recall']:.4f} | {m['f1-score']:.4f}")

    all_data.append({
        'train_ratio': train_ratio,
        'mean_acc': float(mean_acc),
        'mean_std': float(std_acc),  # key name kept for existing consumers; value is the std of accuracy
        'class_metrics': class_metrics,
    })

# JSON summary (directory name kept as-is, including its spelling, for downstream readers).
output_dir = 'test_fine_detector_cross_Muilt_CLS'
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, 'data_llava_lda_detailed.json')
with open(output_file, 'w') as json_file:
    json.dump(all_data, json_file, ensure_ascii=False, indent=4)

# Human-readable log.
log_file = os.path.join(output_dir, 'data_llava_lda_detailed_log.txt')
with open(log_file, 'w') as log_f:
    log_f.write("\n".join(log_output))

print(f"\nResults saved to {output_file}")
print(f"Log saved to {log_file}")