In [1]:
import os
text_attack_cache_dir = os.path.join(os.getcwd(), 'text-attack')
os.environ['TA_CACHE_DIR'] = text_attack_cache_dir
if not os.path.isdir(text_attack_cache_dir):
    os.mkdir(text_attack_cache_dir)
from datasets import load_dataset
from textattack.attack_recipes.textbugger_li_2018 import TextBuggerLi2018
from textattack.models.wrappers.huggingface_model_wrapper import HuggingFaceModelWrapper
from textattack.models.wrappers.model_wrapper import ModelWrapper
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Tuple
from textattack import Attack
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import (
    Transformation,
    CompositeTransformation,
    WordSwapEmbedding,
    WordSwapHomoglyphSwap,
    WordSwapNeighboringCharacterSwap,
    WordSwapRandomCharacterDeletion,
    WordSwapRandomCharacterInsertion,
)
from textattack.attack_recipes import AttackRecipe
from textattack.constraints import Constraint
import utils
import re

  from .autonotebook import tqdm as notebook_tqdm
textattack: Updating TextAttack package dependencies.
textattack: Downloading NLTK required packages.
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\Mark\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Mark\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package omw to
[nltk_data]     C:\Users\Mark\AppData\Roaming\nltk_data...
[nltk_data]   Package omw is already up-to-date!
[nltk_data] Downloading package universal_tagset to
[nltk_data]     C:\Users\Mark\AppData\Roaming\nltk_data...
[nltk_data]   Package universal_tagset is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\Mark\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[n

In [2]:
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", cache_dir="./models")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", cache_dir="./models")

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.98s/it]


# Attack implementation

In [3]:
class ICLConstraint(Constraint):

    def __init__(self, pattern):
        super().__init__(compare_against_original=True)
        self._pattern = pattern

    def _check_constraint(self, transformed_text, reference_text) -> bool:
        reference_matches = re.findall(self._pattern, reference_text)
        if reference_matches:
            last_match_reference = reference_matches[-1]
            start_index_reference = reference_text.rindex(last_match_reference)
            
            transformed_matches = re.findall(self._pattern, transformed_text)
            if transformed_matches:
                last_match_transformed = transformed_matches[-1]
                start_index_transformed = transformed_text.rindex(last_match_transformed)
                
                # return true if the suffix is the same
                return transformed_text[start_index_transformed:] == reference_text[start_index_reference:]
            else:
                # no match in transformed text
                return False
        else:
            # no match in reference text
            return False

In [4]:
class ICLTextBugger(AttackRecipe):
    @staticmethod
    def build(model_wrapper, pattern):
        transformation = CompositeTransformation([
            WordSwapRandomCharacterInsertion(
                random_one=True,
                letters_to_insert=" ",
                skip_first_char=True,
                skip_last_char=True,
            ),
            WordSwapRandomCharacterDeletion(
                random_one=True, skip_first_char=True, skip_last_char=True
            ),
            WordSwapNeighboringCharacterSwap(
                random_one=True, skip_first_char=True, skip_last_char=True
            ),
            WordSwapHomoglyphSwap(),
            WordSwapEmbedding(max_candidates=5),
        ])
        constraints = [
            RepeatModification(),
            StopwordModification(),
            UniversalSentenceEncoder(threshold=0.8),
            ICLConstraint(pattern),
        ]
        goal_function = UntargetedClassification(model_wrapper)
        search_method = GreedyWordSwapWIR(wir_method="delete")
        
        return Attack(goal_function, constraints, transformation, search_method)

# Data processing

In [5]:
sst2_dataset = load_dataset("sst2", cache_dir="./data")

Found cached dataset sst2 (d:/Cyber-final-project/data/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)
100%|██████████| 3/3 [00:00<00:00, 746.98it/s]


In [6]:
class Template:
    def __init__(self, template: str, instruction: str) -> None:
        if '{example}' not in template:
            raise Exception('example placeholder not in template')
        if '{label}' not in template:
            raise Exception('label placeholder not in template')
        self.template: str = template
        self.instruction: str = instruction
    
    def apply(self, examples: List[str], labels: List[str]) -> str:
        accumulator = []
        for example, label in zip(examples, labels): 
            accumulator.append(
                self.template.format(
                    example=example,
                    label=label
                )
            )
        return f'{self.instruction}\n\n' + '\n'.join(accumulator)

class ICLSample:

    def __init__(self, examples: List[str], labels: List[str], template: Template, test_sample: str) -> None:
        if len(examples) != len(labels):
            raise Exception('examples and labels length are not the same')
        self.examples: List[str] = examples
        self.labels: List[str] = labels
        self.test_sample: str = test_sample
        self.template: Template = template

    def to_text(self) -> str:
        examples = self.examples + [self.test_sample]
        labels = self.labels + ['_']
        return self.template.apply(examples, labels)

In [7]:
def split_list_into_chunks(array: list, chunk_size: int):
    chunks = []
    for i in range(0, len(array), chunk_size):
        chunk = array[i:i+chunk_size]
        chunks.append(chunk)
    return chunks

In [8]:
sst2_config = {
    'pattern': r"Review: .+?\nSentiment: .+?",
    'template': """Review: {example}
Sentiment: {label}""",
    'instruction': 'Choose sentiment from Positive or Negative .',
}

def sst2ICL_sample_factory(n_examples: int) -> List[ICLSample]:
    sst2_template = Template(sst2_config['template'], sst2_config['instruction'])
    records = []
    # aggregate all records
    for record in sst2_dataset['train']:
        records.append({
            'sentence': record['sentence'],
            'label': 'Positive' if record['label'] == 1 else 'Negative',
        })
    for record in sst2_dataset['validation']:
        records.append({
            'sentence': record['sentence'],
            'label': 'Positive' if record['label'] == 1 else 'Negative',
        })
    chunks = split_list_into_chunks(records, n_examples)
    samples: List[ICLSample] = []
    ground_truth: List[str] = []
    for chunk in chunks:
        examples = [example['sentence'] for example in chunk[:-1]]
        labels = [example['label'] for example in chunk[:-1]]
        test_sample = chunk[-1]['sentence']
        ground_truth.append(chunk[-1]['label'])
        sample: ICLSample = ICLSample(examples, labels, sst2_template, test_sample)
        samples.append(sample)
    return samples, ground_truth

In [9]:
sst_data = sst2ICL_sample_factory(4)

In [10]:
trec_dataset = load_dataset("trec", cache_dir="./data")

Found cached dataset trec (d:/Cyber-final-project/data/trec/default/2.0.0/f2469cab1b5fceec7249fda55360dfdbd92a7a5b545e91ea0f78ad108ffac1c2)
100%|██████████| 2/2 [00:00<00:00, 401.25it/s]


In [11]:
trec_config = {
    'pattern': r"Question: .+?\nAnswer: .+?",
    'template': """Question: {example}
Answer: {label}""",
    'instruction': 'Classify the questions based on whether their answer type is a Number, Location, Person, Description, Entity, or Abbreviation.\n',
}

def trecICL_sample_factory(n_examples: int) -> List[ICLSample]:
    trec_template = Template(trec_config['template'], trec_config['instruction'])
    numeric_label_2_textual = {
        0: 'Abbreviation',
        1: 'Entity',
        2: 'Description',
        3: 'Person',
        4: 'Location',
        5: 'Number',
    }
    records = []
    # aggregate all records
    for record in trec_dataset['train']:
        records.append({
            'text': record['text'],
            'label': numeric_label_2_textual[record['coarse_label']],
        })
    for record in trec_dataset['test']:
        records.append({
            'text': record['text'],
            'label': numeric_label_2_textual[record['coarse_label']],
        })
    chunks = split_list_into_chunks(records, n_examples)
    samples: List[ICLSample] = []
    ground_truth: List[str] = []
    for chunk in chunks:
        examples = [example['text'] for example in chunk[:-1]]
        labels = [example['label'] for example in chunk[:-1]]
        test_sample = chunk[-1]['text']
        ground_truth.append(chunk[-1]['label'])
        sample: ICLSample = ICLSample(examples, labels, trec_template, test_sample)
        samples.append(sample)
    return samples, ground_truth

In [12]:
trec_data = trecICL_sample_factory(4)

In [19]:
# why?
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
huggingface_model = HuggingFaceModelWrapper(model=model, tokenizer=tokenizer)
attack = ICLTextBugger.build(huggingface_model, sst2_config['pattern'])

textattack: Unknown if model of class <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.


In [17]:
sst_samples, sst_labels = sst_data

In [20]:
attack.attack(sst_samples[0].to_text(), sst_labels[0])

TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'token_type_ids'