## GPT2 For Sequence Classification

In [1]:
import numpy as np
import torch
from qgpt2_models import QGPT2ForSequenceClassification

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("./sst2_gpt2")
model = AutoModelForSequenceClassification.from_pretrained("./sst2_gpt2")
dataset = load_dataset("sst2")["validation"]

In [23]:
fixed_length = 32

input_sentence = dataset[20]["sentence"]

# NOTE: To compile and then inference correctly, lengths must be fixed (we fix to length 32 to save runtime/circuit size)
input_token_indexes = tokenizer(input_sentence, return_tensors="pt", padding="max_length", max_length=fixed_length, truncation=True).input_ids
input_ids = torch.tensor(input_token_indexes)

  input_ids = torch.tensor(input_token_indexes)


## Multi-Head Attention Model

In [4]:
qgpt2_seq = QGPT2ForSequenceClassification.from_pretrained("./sst2_gpt2", n_bits=6, layers=[8, 9, 10, 11], use_cache=False)
qgpt2_seq.set_fhe_mode(fhe="disable")
qgpt2_seq.config.pad_token_id = qgpt2_seq.config.eos_token_id

In [5]:
output_normal = model(input_ids[:1, :]).logits
output_q = qgpt2_seq(input_ids[:1, :]).logits

print(output_normal)
print(output_q)
print(torch.allclose(output_normal, output_q))

qgpt2_seq_circuits = qgpt2_seq.compile(input_ids)

tensor([[ 1.0615, -1.2352]], grad_fn=<IndexBackward0>)
tensor([[ 1.0971, -1.3250]], grad_fn=<IndexBackward0>)
False


 25%|██▌       | 1/4 [13:00<39:01, 780.51s/it]

Circuit compiled with at most 14 bits


 50%|█████     | 2/4 [45:29<48:55, 1467.86s/it]

Circuit compiled with at most 15 bits


 75%|███████▌  | 3/4 [1:08:16<23:41, 1421.63s/it]

Circuit compiled with at most 14 bits


100%|██████████| 4/4 [1:32:19<00:00, 1384.78s/it]

Circuit compiled with at most 16 bits





In [6]:
qgpt2_seq.set_fhe_mode(fhe="simulate")
print(input_ids)
output_logits_simulated = qgpt2_seq(input_ids)
print(output_logits_simulated)

tensor([[   79,   931,  5116,  2753,   281, 37959,   804,   379,   262, 29442,
           286,  1964, 29409,   837,   475,   340,   857,   523,   351,   884,
           281, 30690,  8216,   326,   345,  1239,   760,   618, 14733,  5645,
           290, 13574]])
SequenceClassifierOutputWithPast(loss=None, logits=tensor([[ 1.0971, -1.3250]], grad_fn=<IndexBackward0>), past_key_values=None, hidden_states=None, attentions=None)


## Attack Evaluation

In [24]:
import pickle
import os
from tqdm import tqdm

attacks = {}
attacks["baseline"] = dataset[:20]["sentence"]

budget = "3"
for filename in os.listdir("./attacks"):
    with open(os.path.join("./attacks", filename), "rb") as f:
        attacks[filename] = pickle.load(f)[budget]

In [25]:

ids = list(range(20))

def get_inference(model, attacks, ignore_attention=False):
    res = {}

    for attack_type, attack_data in tqdm(attacks.items()):
        res[attack_type[:-4]] = []
        for id in ids:
            # NOTE: some attacks generate multiple examples, we just arbitrary pick the last
            if isinstance(attack_data[id], list):
                attack_data[id] = attack_data[id][-1]
            curr_input = tokenizer(attack_data[id], return_tensors='pt', padding="max_length", max_length=fixed_length, truncation=True)
            if not ignore_attention:
                res[attack_type[:-4]].append(model(**curr_input).logits.detach().clone())
            else:
                res[attack_type[:-4]].append(model(curr_input.input_ids).logits.detach().clone())
    
    return res

def get_metrics(labels, attack_inference):
    labels = labels
    res_acc = {}
    res_pre = {}
    res_rec = {}
    crossentropy = {}

    for attack_type, attack_data in attack_inference.items():
        tp = 0
        fp = 0
        tn = 0
        fn = 0
        for id in ids:
            prediction = torch.argmax(attack_data[id].squeeze(), dim=-1).item()
            if prediction == 0:
                if labels[id] == 0:
                    tn += 1
                elif labels[id] == 1:
                    fn += 1
                else:
                    assert False
            elif prediction == 1:
                if labels[id] == 0:
                    fp += 1
                elif labels[id] == 1:
                    tp += 1
                else:
                    assert False
            else:
                assert False
        
        res_acc[attack_type] = (tp + tn) / (tp + fp + tn + fn)
        res_pre[attack_type] = (tp) / (tp + fp)
        res_rec[attack_type] = (tp) / (tp + fn)

        curr_targets = torch.stack(attack_data, dim=0).squeeze()
        crossentropy[attack_type] = torch.nn.functional.cross_entropy(curr_targets, torch.tensor(labels))
    
    return res_acc, res_pre, res_rec, crossentropy

In [26]:
inference_normal = get_inference(model, attacks)
inference_fhe = get_inference(qgpt2_seq, attacks)

100%|██████████| 13/13 [00:13<00:00,  1.01s/it]
100%|██████████| 13/13 [03:16<00:00, 15.14s/it]


In [27]:
labels = dataset[:20]["label"]
accuracy_normal, precision_normal, recall_normal, ce_normal = get_metrics(labels, inference_normal)
accuracy_fhe, precision_fhe, recall_fhe, ce_fhe = get_metrics(labels, inference_fhe)

In [28]:
accuracy_normal

{'base': 0.95,
 'sst2_deletion': 0.1,
 'sst2_deletions_targeted': 0.7,
 'sst2_deletions_targeted_nologits': 0.6,
 'sst2_homoglyphs': 0.35,
 'sst2_homoglyphs_targeted': 0.7,
 'sst2_homoglyphs_targeted_nologits': 0.7,
 'sst2_invisibles_targeted': 0.6,
 'sst2_invisibles_targeted_nologits': 0.65,
 'sst2_invisible_chars': 0.5,
 'sst2_reorder': 0.55,
 'sst2_reorderings_targeted': 0.75,
 'sst2_reorderings_targeted_nologits': 0.75}

In [29]:
accuracy_fhe

{'base': 0.95,
 'sst2_deletion': 0.15,
 'sst2_deletions_targeted': 0.7,
 'sst2_deletions_targeted_nologits': 0.6,
 'sst2_homoglyphs': 0.4,
 'sst2_homoglyphs_targeted': 0.7,
 'sst2_homoglyphs_targeted_nologits': 0.7,
 'sst2_invisibles_targeted': 0.6,
 'sst2_invisibles_targeted_nologits': 0.7,
 'sst2_invisible_chars': 0.55,
 'sst2_reorder': 0.5,
 'sst2_reorderings_targeted': 0.75,
 'sst2_reorderings_targeted_nologits': 0.75}

In [30]:
precision_normal

{'base': 1.0,
 'sst2_deletion': 0.0,
 'sst2_deletions_targeted': 0.625,
 'sst2_deletions_targeted_nologits': 0.5625,
 'sst2_homoglyphs': 0.36363636363636365,
 'sst2_homoglyphs_targeted': 0.625,
 'sst2_homoglyphs_targeted_nologits': 0.625,
 'sst2_invisibles_targeted': 0.5555555555555556,
 'sst2_invisibles_targeted_nologits': 0.5882352941176471,
 'sst2_invisible_chars': 0.5,
 'sst2_reorder': 0.6,
 'sst2_reorderings_targeted': 0.7272727272727273,
 'sst2_reorderings_targeted_nologits': 1.0}

In [31]:
precision_fhe

{'base': 1.0,
 'sst2_deletion': 0.0,
 'sst2_deletions_targeted': 0.625,
 'sst2_deletions_targeted_nologits': 0.5625,
 'sst2_homoglyphs': 0.4,
 'sst2_homoglyphs_targeted': 0.625,
 'sst2_homoglyphs_targeted_nologits': 0.6428571428571429,
 'sst2_invisibles_targeted': 0.5555555555555556,
 'sst2_invisibles_targeted_nologits': 0.625,
 'sst2_invisible_chars': 0.5454545454545454,
 'sst2_reorder': 0.5,
 'sst2_reorderings_targeted': 0.7272727272727273,
 'sst2_reorderings_targeted_nologits': 1.0}

In [32]:
recall_normal

{'base': 0.9,
 'sst2_deletion': 0.0,
 'sst2_deletions_targeted': 1.0,
 'sst2_deletions_targeted_nologits': 0.9,
 'sst2_homoglyphs': 0.4,
 'sst2_homoglyphs_targeted': 1.0,
 'sst2_homoglyphs_targeted_nologits': 1.0,
 'sst2_invisibles_targeted': 1.0,
 'sst2_invisibles_targeted_nologits': 1.0,
 'sst2_invisible_chars': 0.6,
 'sst2_reorder': 0.3,
 'sst2_reorderings_targeted': 0.8,
 'sst2_reorderings_targeted_nologits': 0.5}

In [33]:
recall_fhe

{'base': 0.9,
 'sst2_deletion': 0.0,
 'sst2_deletions_targeted': 1.0,
 'sst2_deletions_targeted_nologits': 0.9,
 'sst2_homoglyphs': 0.4,
 'sst2_homoglyphs_targeted': 1.0,
 'sst2_homoglyphs_targeted_nologits': 0.9,
 'sst2_invisibles_targeted': 1.0,
 'sst2_invisibles_targeted_nologits': 1.0,
 'sst2_invisible_chars': 0.6,
 'sst2_reorder': 0.2,
 'sst2_reorderings_targeted': 0.8,
 'sst2_reorderings_targeted_nologits': 0.5}

In [34]:
ce_normal

{'base': tensor(0.1946),
 'sst2_deletion': tensor(1.2521),
 'sst2_deletions_targeted': tensor(0.5379),
 'sst2_deletions_targeted_nologits': tensor(0.5383),
 'sst2_homoglyphs': tensor(1.0868),
 'sst2_homoglyphs_targeted': tensor(0.5803),
 'sst2_homoglyphs_targeted_nologits': tensor(0.5113),
 'sst2_invisibles_targeted': tensor(0.7941),
 'sst2_invisibles_targeted_nologits': tensor(0.4755),
 'sst2_invisible_chars': tensor(0.8643),
 'sst2_reorder': tensor(1.0337),
 'sst2_reorderings_targeted': tensor(0.6615),
 'sst2_reorderings_targeted_nologits': tensor(0.5554)}

In [35]:
ce_fhe

{'base': tensor(0.1954),
 'sst2_deletion': tensor(1.2411),
 'sst2_deletions_targeted': tensor(0.5163),
 'sst2_deletions_targeted_nologits': tensor(0.5220),
 'sst2_homoglyphs': tensor(1.0759),
 'sst2_homoglyphs_targeted': tensor(0.5611),
 'sst2_homoglyphs_targeted_nologits': tensor(0.4957),
 'sst2_invisibles_targeted': tensor(0.7632),
 'sst2_invisibles_targeted_nologits': tensor(0.4638),
 'sst2_invisible_chars': tensor(0.8543),
 'sst2_reorder': tensor(1.0290),
 'sst2_reorderings_targeted': tensor(0.6569),
 'sst2_reorderings_targeted_nologits': tensor(0.5602)}