# Topics in Adversarial Attacks on Deep Learning Models (02360207)
#### HW2 - Attacks on Discrete Language Models

In this HW, we will implement an LLM SOTA attack on a language model trained to perform textual binary classification task. The attack we focus on is Greedy Coordinate Descent (GCG). Before you begin the assignment, it is recommended to read the [original paper of GCG](https://arxiv.org/pdf/2307.15043), understand the mechanisms behind the attack and the logic it implements. This attack is meant to replace continuous optimization, as Language Models operates on discrete inputs (text tokens).

<table border="1" cellpadding="6">
<tr>
<th>Name</th>
<th>ID</th>
</tr>
<tr>
<td>Student 1</td>
<td>316153261</td>
</tr>
<tr>
<td>Student 2</td>
<td>111111111</td>
</tr>
<tr>
<td>Student 3</td>
<td>111111111</td>
</tr>
</table>

In [4]:
# %pip install pandas 

Collecting pandas
  Downloading pandas-2.3.3-cp310-cp310-win_amd64.whl.metadata (19 kB)
Collecting pytz>=2020.1 (from pandas)
  Using cached pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading pandas-2.3.3-cp310-cp310-win_amd64.whl (11.3 MB)
   ---------------------------------------- 0.0/11.3 MB ? eta -:--:--
    --------------------------------------- 0.3/11.3 MB ? eta -:--:--
   ------ --------------------------------- 1.8/11.3 MB 6.7 MB/s eta 0:00:02
   --------------- ------------------------ 4.5/11.3 MB 8.9 MB/s eta 0:00:01
   ------------------------ --------------- 7.1/11.3 MB 10.1 MB/s eta 0:00:01
   ---------------------------------- ----- 9.7/11.3 MB 10.8 MB/s eta 0:00:01
   ---------------------------------------  11.3/11.3 MB 10.5 MB/s eta 0:00:01
   ---------------------------------------- 11.3/11.3 MB 9.9 MB/s  0:00:01
Using cached pytz-2025.2-py2.py3-none-any.whl (509 kB)
Installing collected packages: pytz, pandas

   ---------------------------------------- 

In [5]:
# %pip install matplotlib

Collecting matplotlib
  Downloading matplotlib-3.10.8-cp310-cp310-win_amd64.whl.metadata (52 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.2-cp310-cp310-win_amd64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.61.1-cp310-cp310-win_amd64.whl.metadata (116 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.9-cp310-cp310-win_amd64.whl.metadata (6.4 kB)
Collecting pyparsing>=3 (from matplotlib)
  Downloading pyparsing-3.3.1-py3-none-any.whl.metadata (5.6 kB)
Downloading matplotlib-3.10.8-cp310-cp310-win_amd64.whl (8.1 MB)
   ---------------------------------------- 0.0/8.1 MB ? eta -:--:--
   ------ --------------------------------- 1.3/8.1 MB 7.4 MB/s eta 0:00:01
   --------------------- ------------------ 4.5/8.1 MB 12.7 MB/s eta 0:00:01
   -------------------------------------

In [6]:
# %pip install tqdm

Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1
Note: you may need to restart the kernel to use updated packages.


In [7]:
# %pip install transformers

Collecting transformers
  Downloading transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2025.11.3-cp310-cp310-win_amd64.whl.metadata (41 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Downloading tokenizers-0.22.1-cp39-abi3-win_amd64.whl.metadata (6.9 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Downloading safetensors-0.7.0-cp38-abi3-win_amd64.whl.metadata (4.2 kB)
Downloading transformers-4.57.3-py3-none-any.whl (12.0 MB)
   ---------------------------------------- 0.0/12.0 MB ? eta -:--:--
   - -------------------------------------- 0.5/12.0 MB 3.3 MB/s eta 0:00:04
   ----------- ---------------------------- 3.4/12.0 MB 10.6 MB/s eta 0:00:01
   -------------------- ------------------- 6.3/12.0 MB 12.4 MB/s eta 0:00:01
   ---------------------------------

In [8]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
import torch.nn.functional as F
print("Finish")

  from .autonotebook import tqdm as notebook_tqdm


Finish


In [9]:
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device is {device}")


device is cuda


## Loading the data

Our use-case in the assignment is a toxic text classification mission. In the course website, we uploaded a <code>/data</code> directory contains the dataset we work on. In the cell below, we have already included a short script that preprocess the data and sample negative examples from the trainset to balance the training data. Please to do not modify this code.

<b><span style="color: red">WARNING: The dataset contains offensive language. If you have any problem working on it, please contact me (Omer) via email and I will provide you an alternative dataset to work with.</span>


In [10]:
train_df = pd.read_csv('data/train.csv')
test_df = pd.read_csv('data/test.csv')
test_labels_df = pd.read_csv('data/test_labels.csv')
print("Finish2")

Finish2


In [11]:
train_df = train_df[['id', 'comment_text', 'toxic']]
negative_sample_train = train_df[train_df['toxic'] == 0].sample(frac=0.1) #downsample the negative class
positive_sample_train = train_df[train_df['toxic'] == 1]
train_df = pd.concat([negative_sample_train, positive_sample_train])

test_labels_df = test_labels_df[['id', 'toxic']]
test_df = pd.merge(test_df, test_labels_df, on='id', how='inner')
test_df = test_df[test_df['toxic'] != -1]  # remove samples with label `-1`, which indicates **unknown or unlabeled toxicity**
print("Finish")

Finish


## Training a text classifier

In this part, you should implement a torch dataset, and train a model of the given architecture to perform the text classification task. Note that the architecture is comprised of a language transformer encoder (recommended for you here to use <code>bert-base-uncased</code>) and a fully connected classification head. We use the transformers library to integrate out transformer. Note also that the paramteres of the transformer are kept frozzen in the architecture (which means we do not fine-tune it, just use the pre-trained embeddings).

Complete the dataset class:

In [12]:
class TextDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Keep raw texts if you need them later (e.g., for running the attack)
        self.texts = data["comment_text"].fillna("").astype(str).tolist()
        self.labels = data["toxic"].astype(float).tolist()

        # Pre-tokenize for speed
        enc = self.tokenizer(
            self.texts,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        self.input_ids = enc["input_ids"]
        self.attention_mask = enc["attention_mask"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            # shape (1,) so DataLoader gives (B,1) which matches sigmoid output
            "labels": torch.tensor([self.labels[idx]], dtype=torch.float32),
        }


In [13]:
class TextClassifier(nn.Module):
    def __init__(self, transformer_model, freeze_transformer=True):
        super(TextClassifier, self).__init__()
        self.model = AutoModel.from_pretrained(transformer_model)

        if freeze_transformer:
            for param in self.model.parameters():
                param.requires_grad = False
                
        self.tokenizer = AutoTokenizer.from_pretrained(transformer_model)
        self.fc = nn.Linear(768, 1)
        self.sigmoid = nn.Sigmoid()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()
    
    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
        cls_token = outputs.last_hidden_state[:, 0]
        cls_token = self.fc(cls_token)
        return self.sigmoid(cls_token)

Define variables for training: device, model, train & test datasets, train & test dataloaders, optimizer and criterion. You can choose hyperparamters on your own, as long as your model gets good accuracy (significantly above random guess) it's alright.


In [14]:
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device is {device}")

# model (768-dim encoder)
transformer_name = "distilbert-base-uncased"   # fast + hidden_size=768
model = TextClassifier(transformer_name, freeze_transformer=True).to(device)
print(f"model is {model}")

# datasets (expects columns: comment_text, toxic)
train_dataset = TextDataset(train_df, model.tokenizer)
test_dataset  = TextDataset(test_df,  model.tokenizer)

# loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader  = DataLoader(test_dataset,  batch_size=64, shuffle=False)

# loss + optimizer (Sigmoid already in model -> BCELoss)
criterion = nn.BCELoss()

# only train the head (transformer is frozen)
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)
tokenizer = AutoTokenizer.from_pretrained(transformer_name)

print("finish2")

device is cuda
model is TextClassifier(
  (model): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
  

Train the model (complete the for-loop):

In [15]:
for epoch in range(5):
    model.train()
    epoch_loss = 0.0

    for batch in tqdm(train_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)  # shape: (B, 1), float

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)  # shape: (B, 1), sigmoid probs
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{5}, Loss: {epoch_loss/len(train_loader)}")


100%|█████████████████████████████████████████████████████████████████████████████| 929/929 [08:40<00:00,  1.79it/s]


Epoch 1/5, Loss: 0.33868546337210065


100%|█████████████████████████████████████████████████████████████████████████████| 929/929 [08:37<00:00,  1.80it/s]


Epoch 2/5, Loss: 0.27838909215383045


100%|████████████████████████████████████████████████████████████████| 929/929 [09:02<00:00,  1.71it/s]


Epoch 3/5, Loss: 0.2677565708243167


100%|████████████████████████████████████████████████████████████████| 929/929 [08:43<00:00,  1.77it/s]


Epoch 4/5, Loss: 0.26361326071302393


100%|████████████████████████████████████████████████████████████████| 929/929 [08:32<00:00,  1.81it/s]

Epoch 5/5, Loss: 0.2608335916839131





Evaluate the model (complete the for-loop):

In [16]:
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for batch in tqdm(test_loader):
        # handle either dict-batches (common with tokenizers) or tuple-batches
        if isinstance(batch, dict):
            labels = batch["labels"].to(device)
            inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
            logits = model(**inputs)
        else:
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)

        preds = torch.argmax(logits, dim=1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)

    print(f'Accuracy: {correct/total}')


100%|██████████████████████████████████████████████████████████████| 1000/1000 [18:07<00:00,  1.09s/it]

Accuracy: 57.894494982650286





## Train GCG Attack

In [None]:
# class GCGTextClassifierAttack:
#     """
#     Greedy Coordinate Gradient attack for text classification models.
#     Generates adversarial suffixes to cause misclassification.
#     """
    
#     def __init__(self, model, tokenizer, target_class=None, device='cuda'):
        
#         self.model = model
#         self.tokenizer = tokenizer
#         self.target_class = target_class
#         self.device = device
#         self.model.to(device)
#         self.model.eval()
        
#     def compute_loss(self, input_ids, target_class):
#         """Compute loss for target misclassification"""
#         pass
    
#     def get_token_gradients(self, input_ids, suffix_positions):
#         """Get gradients with respect to suffix tokens"""
#         pass
    
#     def sample_replacements(self, gradients, current_tokens, k=256, batch_size=512):
#         """Sample top-k token replacements based on gradients"""
#         pass
    
#     def evaluate_candidates(self, input_ids, suffix_positions, replacements):
#         """Evaluate loss for candidate token replacements"""
#         pass
    
#     def attack(self, text, suffix_length=10, num_iterations=100, k=256, batch_size=512):
#         """
#         Execute GCG attack to find adversarial suffix
        
#         Args:
#             text: Original text to attack
#             suffix_length: Length of adversarial suffix
#             num_iterations: Number of optimization iterations
#             k: Number of top tokens to consider per position
#             batch_size: Number of candidates to evaluate per iteration
#         """
#         pass


In [25]:
class GCGTextClassifierAttack:
    """
    Greedy Coordinate Gradient (GCG) attack for text classification.
    Learns an adversarial *suffix* (a sequence of tokens appended to the input)
    that pushes the model toward a target class.
    """

    def __init__(self, model, tokenizer, target_class=None, device='cuda'):
        self.model = model
        self.tokenizer = tokenizer
        self.target_class = target_class
        self.device = device

        self.model.to(device)
        self.model.eval()

        # internal (set per attack call)
        self._current_target_class = None

    def _make_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            return torch.ones_like(input_ids, dtype=torch.long)
        return (input_ids != pad_id).long()

    def _forward_probs(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        out = self.model(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
        # In this notebook TextClassifier returns sigmoid probs directly: shape (B,1)
        return out

    def compute_loss(self, input_ids, target_class):
        """Compute BCE loss toward a target class (0/1)."""
        input_ids = input_ids.to(self.device)
        attention_mask = self._make_attention_mask(input_ids).to(self.device)

        probs = self._forward_probs(input_ids=input_ids, attention_mask=attention_mask)  # (B,1)
        probs = probs.clamp(1e-7, 1.0 - 1e-7)

        target = torch.full_like(probs, float(target_class))
        loss = F.binary_cross_entropy(probs, target)
        return loss

    def get_token_gradients(self, input_ids, suffix_positions):
        """Get gradients w.r.t. *input embeddings* at suffix token positions."""
        input_ids = input_ids.to(self.device)
        attention_mask = self._make_attention_mask(input_ids).to(self.device)

        # build differentiable embeddings
        emb_layer = self.model.get_input_embeddings()
        inputs_embeds = emb_layer(input_ids).detach()
        inputs_embeds.requires_grad_(True)

        self.model.zero_grad(set_to_none=True)

        probs = self._forward_probs(inputs_embeds=inputs_embeds, attention_mask=attention_mask)  # (B,1)
        probs = probs.clamp(1e-7, 1.0 - 1e-7)

        target = torch.full_like(probs, float(self._current_target_class))
        loss = F.binary_cross_entropy(probs, target)
        loss.backward()

        # gradients for suffix positions: (suffix_len, hidden)
        grads = inputs_embeds.grad[0, suffix_positions, :].detach().clone()
        return grads

    def sample_replacements(self, gradients, current_tokens, k=256, batch_size=512):
        """
        Use 1st-order approximation to propose token replacements.
        For each suffix position i: choose tokens that most reduce loss ~ minimize grad_i · emb(token).
        """
        emb_weight = self.model.get_input_embeddings().weight  # (V, H)
        # scores: higher is better improvement (we take -grad·emb)
        scores = -torch.matmul(gradients, emb_weight.t())  # (suffix_len, V)

        # avoid special tokens + keep text readable
        if hasattr(self.tokenizer, "all_special_ids") and self.tokenizer.all_special_ids:
            scores[:, self.tokenizer.all_special_ids] = -float("inf")

        # avoid keeping the same token
        pos_idx = torch.arange(scores.size(0), device=scores.device)
        scores[pos_idx, current_tokens] = -float("inf")

        # top-k per position
        topk = torch.topk(scores, k=min(k, scores.size(1)), dim=1)
        topk_ids = topk.indices  # (suffix_len, k)
        topk_scores = topk.values

        # flatten (position, token) candidates, sort by heuristic score, keep up to batch_size
        candidates = []
        suffix_len = topk_ids.size(0)
        kk = topk_ids.size(1)
        for i in range(suffix_len):
            for j in range(kk):
                candidates.append((i, int(topk_ids[i, j].item()), float(topk_scores[i, j].item())))

        candidates.sort(key=lambda x: x[2], reverse=True)
        candidates = candidates[:batch_size]
        # return (suffix_pos_index, token_id)
        return [(i, tok) for (i, tok, _) in candidates]

    def evaluate_candidates(self, input_ids, suffix_positions, replacements):
        """Evaluate true loss for each candidate replacement; return best (pos_idx, token_id, loss)."""
        if len(replacements) == 0:
            return None, None, None

        base_ids = input_ids.to(self.device)
        N = len(replacements)

        # build a batch of candidates
        cand_ids = base_ids.repeat(N, 1)
        for n, (pos_in_suffix, tok_id) in enumerate(replacements):
            abs_pos = suffix_positions[pos_in_suffix]
            cand_ids[n, abs_pos] = tok_id

        # evaluate losses
        losses = []
        bs = 512  # internal eval batch (can be same as your batch_size param)
        for start in range(0, N, bs):
            chunk = cand_ids[start:start + bs]
            mask = self._make_attention_mask(chunk).to(self.device)

            probs = self._forward_probs(input_ids=chunk, attention_mask=mask)  # (B,1)
            probs = probs.clamp(1e-7, 1.0 - 1e-7)

            target = torch.full_like(probs, float(self._current_target_class))
            # per-sample loss
            chunk_loss = F.binary_cross_entropy(probs, target, reduction='none').mean(dim=1)  # (B,)
            losses.append(chunk_loss.detach())

        losses = torch.cat(losses, dim=0)  # (N,)
        best_idx = int(torch.argmin(losses).item())
        best_pos, best_tok = replacements[best_idx]
        best_loss = float(losses[best_idx].item())
        return best_pos, best_tok, best_loss

    def attack(self, text, suffix_length=10, num_iterations=100, k=256, batch_size=512):
        """
        Execute GCG attack to find adversarial suffix.

        Returns:
            adv_text: decoded full adversarial text
            adv_suffix: decoded adversarial suffix only
            info: dict with details
        """
        # Tokenize (leave room for suffix)
        max_len = getattr(self.tokenizer, "model_max_length", 512)
        max_len = 512 if (max_len is None or max_len > 10_000) else max_len
        enc = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max(8, max_len - suffix_length),
            padding=False,
        )
        input_ids = enc["input_ids"].to(self.device)

        # pick target class
        with torch.no_grad():
            mask = self._make_attention_mask(input_ids).to(self.device)
            p = self._forward_probs(input_ids=input_ids, attention_mask=mask)  # (1,1)
            pred = int((p.item() > 0.5))

        target_class = self.target_class
        if target_class is None:
            target_class = 1 - pred  # flip for binary
        self._current_target_class = target_class

        # init suffix tokens
        init_tok = (
            self.tokenizer.mask_token_id
            if self.tokenizer.mask_token_id is not None
            else (self.tokenizer.unk_token_id if self.tokenizer.unk_token_id is not None else 0)
        )
        suffix = torch.full((1, suffix_length), init_tok, dtype=torch.long, device=self.device)

        # append suffix before [SEP] if present at the end, otherwise append at end
        sep_id = self.tokenizer.sep_token_id
        if sep_id is not None and input_ids.size(1) > 0 and int(input_ids[0, -1].item()) == int(sep_id):
            prefix = input_ids[:, :-1]
            sep = input_ids[:, -1:]
            adv_ids = torch.cat([prefix, suffix, sep], dim=1)
            start_pos = prefix.size(1)
        else:
            adv_ids = torch.cat([input_ids, suffix], dim=1)
            start_pos = input_ids.size(1)

        suffix_positions = list(range(start_pos, start_pos + suffix_length))

        # initial loss
        current_loss = float(self.compute_loss(adv_ids, target_class).item())
        best_pred = pred

        for it in range(num_iterations):
            grads = self.get_token_gradients(adv_ids, suffix_positions)  # (suffix_len, H)
            current_suffix_tokens = adv_ids[0, suffix_positions].detach()

            replacements = self.sample_replacements(
                gradients=grads,
                current_tokens=current_suffix_tokens,
                k=k,
                batch_size=batch_size,
            )

            best_pos, best_tok, best_loss = self.evaluate_candidates(adv_ids, suffix_positions, replacements)
            if best_pos is None:
                break

            # stop if no improvement
            if best_loss >= current_loss - 1e-6:
                break

            # apply greedy update
            adv_ids[0, suffix_positions[best_pos]] = best_tok
            current_loss = best_loss

            # check success (hit target)
            with torch.no_grad():
                mask = self._make_attention_mask(adv_ids).to(self.device)
                p = self._forward_probs(input_ids=adv_ids, attention_mask=mask)
                best_pred = int((p.item() > 0.5))
                if best_pred == target_class:
                    break

        adv_suffix_ids = adv_ids[0, suffix_positions].tolist()
        adv_suffix = self.tokenizer.decode(adv_suffix_ids, skip_special_tokens=True)
        adv_text = self.tokenizer.decode(adv_ids[0].tolist(), skip_special_tokens=True)

        info = {
            "orig_pred": pred,
            "target_class": target_class,
            "final_pred": best_pred,
            "final_loss": current_loss,
            "num_iterations_ran": it + 1,
        }
        return adv_text, adv_suffix, info


Run the attack on 10 samples from the test set that your model predicts correctly. Show that the attack manages to cause misclassification. Print in the following format: 
* original sample: <code>your original toxic text</code>
* original logits: probability of the model to input being toxic
* attacked sample: <code>your original toxic text after attack</code>
* attacked logits: probability of the model to input after attack being toxic 


In [None]:
pass # your code for running the attack on test samples and reporting results here

In [26]:
# def toxic_probability_from_model(model, tokenizer, text, device, toxic_class_idx=1):
#     """
#     Returns P(toxic) as float.
#     Works whether model returns:
#       - a Tensor logits
#       - a tuple (logits, ...)
#       - a dict with 'logits'
#       - a HF ModelOutput with .logits
#     """
#     enc = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
#     enc = {k: v.to(device) for k, v in enc.items()}

#     out = model(**enc)

#     # --- extract logits ---
#     if torch.is_tensor(out):
#         logits = out
#     elif isinstance(out, (tuple, list)):
#         logits = out[0]
#     elif isinstance(out, dict) and "logits" in out:
#         logits = out["logits"]
#     else:
#         logits = out.logits  # HF-style

#     # ensure shape [batch, C] or [batch, 1]
#     if logits.dim() == 1:
#         logits = logits.unsqueeze(0)

#     # --- convert to probability of toxic ---
#     if logits.size(-1) == 1:
#         # single-logit binary classifier: sigmoid(logit)
#         p_toxic = torch.sigmoid(logits.squeeze(0).squeeze(-1))
#     else:
#         # multi-logit classifier: softmax, take toxic_class_idx (usually 1)
#         probs = torch.softmax(logits.squeeze(0), dim=-1)
#         p_toxic = probs[toxic_class_idx]

#     return float(p_toxic.detach().cpu())


# def pred_label_from_prob(p_toxic, threshold=0.5):
#     return 1 if p_toxic >= threshold else 0

# def extract_text_and_label(sample, tokenizer):
#     """
#     Tries to extract (text, label) from different dataset formats.
#     If the dataset stores tokenized inputs, we decode them back to text.
#     """
#     # Case A: sample is dict
#     if isinstance(sample, dict):
#         label = sample.get("label", sample.get("labels", None))
#         if torch.is_tensor(label):
#             label = int(label.item())

#         if "text" in sample:
#             text = sample["text"]
#             return text, label

#         if "input_ids" in sample:
#             ids = sample["input_ids"]
#             if torch.is_tensor(ids):
#                 ids = ids.tolist()
#             text = tokenizer.decode(ids, skip_special_tokens=True)
#             return text, label

#     # Case B: sample is tuple (text, label) or (enc, label)
#     if isinstance(sample, (tuple, list)) and len(sample) >= 2:
#         a, b = sample[0], sample[1]
#         # if first element looks like raw text
#         if isinstance(a, str):
#             text = a
#             label = int(b.item()) if torch.is_tensor(b) else int(b)
#             return text, label
#         # if first element is token ids tensor/list
#         if torch.is_tensor(a) or isinstance(a, (list, tuple)):
#             ids = a.tolist() if torch.is_tensor(a) else a
#             text = tokenizer.decode(ids, skip_special_tokens=True)
#             label = int(b.item()) if torch.is_tensor(b) else int(b)
#             return text, label

#     raise ValueError("Couldn't parse sample format. Add a custom extractor for your dataset.")

# def run_attack_on_10_correct(model, tokenizer, test_dataset, attack, device="cuda", threshold=0.5, max_tries=500):
#     model.eval()
#     device = torch.device(device if torch.cuda.is_available() else "cpu")
#     model.to(device)

#     found = 0
#     i = 0
#     while found < 10 and i < len(test_dataset) and max_tries > 0:
#         max_tries -= 1
#         sample = test_dataset[i]
#         i += 1

#         text, y_true = extract_text_and_label(sample, tokenizer)

#         # Compute original toxic probability + predicted label
#         p_orig = toxic_probability_from_model(model, tokenizer, text, device)
#         y_pred = pred_label_from_prob(p_orig, threshold)

#         # Keep only correctly predicted samples
#         if y_pred != y_true:
#             continue

#         # ---- Run attack (change ONLY this line if your API differs) ----
#         # Example expected: attacked_text = attack.generate(text, target_class=1-y_true) OR attack(text)
#         try:
#             attacked_text = attack.generate(text, target_class=1 - y_true)
#         except TypeError:
#             # fallback: attack might not take target_class
#             attacked_text = attack.generate(text)
#         except AttributeError:
#             # fallback: attack might be callable
#             attacked_text = attack(text)
#         # ---------------------------------------------------------------

#         p_att = toxic_probability_from_model(model, tokenizer, attacked_text, device)
#         y_att = pred_label_from_prob(p_att, threshold)

#         # We want misclassification (flip)
#         if y_att == y_true:
#             continue

#         found += 1

#         print(f"* original sample: <code>{text}</code>")
#         print(f"* original logits: {p_orig}")
#         print(f"* attacked sample: <code>{attacked_text}</code>")
#         print(f"* attacked logits: {p_att}")
#         print()

#     if found < 10:
#         print(f"Only found {found} successful attacks (misclassifications). "
#               f"Try increasing attack strength/steps or search more samples.")

# # Usage:
# attack = GCGTextClassifierAttack(model, tokenizer, target_class=None, device=str(device))
# run_attack_on_10_correct(model, tokenizer, test_dataset, attack, device="cuda")


TypeError: 'GCGTextClassifierAttack' object is not callable

In [30]:
import torch

def p_toxic(model, tokenizer, text, device):
    enc = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        prob = model(**enc)  # your TextClassifier returns sigmoid prob (B,1)
        return float(prob.view(-1)[0].item())

def pred_from_p(p):
    return 1 if p >= 0.5 else 0

def run_attack_on_10_correct(model, tokenizer, test_dataset, attack, device="cuda",
                             suffix_length=10, num_iterations=100, k=256, batch_size=512):
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    device = torch.device("cuda")
    model.to(device).eval()

    successes = 0
    tried = 0

    # go through the test set until we collect 10 *successful* misclassifications
    for i in range(len(test_dataset)):
        # adapt to your dataset structure (common in this HW: dataset.texts / dataset.labels)
        text = test_dataset.texts[i]
        y_true = int(test_dataset.labels[i])

        p_orig = p_toxic(model, tokenizer, text, device)
        y_pred = pred_from_p(p_orig)

        # only take samples predicted correctly
        if y_pred != y_true:
            continue

        tried += 1

        adv_text, adv_suffix, info = attack.attack(
            text,
            suffix_length=suffix_length,
            num_iterations=num_iterations,
            k=k,
            batch_size=batch_size
        )

        p_att = p_toxic(model, tokenizer, adv_text, device)
        y_att = pred_from_p(p_att)

        # success = now misclassified (since original was correct)
        if y_att == y_true:
            continue

        successes += 1

        print(f"* original sample: <code>{text}</code>")
        print(f"* original logits: {p_orig}")
        print(f"* attacked sample: <code>{adv_text}</code>")
        print(f"* attacked logits: {p_att}")
        print()

        if successes >= 10:
            break

    print(f"Successful misclassifications: {successes}/10 (searched through {tried} correctly-predicted samples)")

# usage
attack = GCGTextClassifierAttack(model, tokenizer, target_class=None, device="cuda")
run_attack_on_10_correct(model, tokenizer, test_dataset, attack, device="cuda")
print("Finish")

* original sample: <code>Thank you for understanding. I think very highly of you and would not revert without discussion.</code>
* original logits: 0.01077330019325018
* attacked sample: <code>thank you for understanding. i think very highly of you and would not revert without discussion. bastard</code>
* attacked logits: 0.5732153058052063

* original sample: <code>== Arabs are committing genocide in Iraq, but no protests in Europe. == 

 May Europe also burn in hell.</code>
* original logits: 0.849972665309906
* attacked sample: <code>= = arabs are committing genocide in iraq, but no protests in europe. = = may europe also burn in hell. highlighted tentative</code>
* attacked logits: 0.11034949868917465

* original sample: <code>Professors to the Manhatten Project.</code>
* original logits: 0.02696494571864605
* attacked sample: <code>professors to the manhatten project. fucking</code>
* attacked logits: 0.7693648338317871

* original sample: <code>일이삼사오육칠팔구하고십이요 에헤헤 으헤 으헤 으허허</code>