# Trigger Utils

In [1]:
import heapq
from copy import deepcopy
from functools import partial
from operator import itemgetter
from typing import List, Tuple

import numpy
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.sampler import BatchSampler
from tqdm import tqdm

In [2]:
def get_embedding_weight_bert(model, model_type='bert'):
    """
    Extracts and returns the token embedding weight matrix from the model.
    """
    return model.bert.embeddings.word_embeddings.weight.cpu().detach()

In [3]:
# hook used in add_hooks()
extracted_grads = []

In [4]:
def add_hooks_bert(model, model_type='bert'):
    """
    Finds the token embedding matrix on the model and registers a hook onto it.
    When loss.backward() is called, extracted_grads list will be filled with
    the gradients w.r.t. the token embeddings
    """
    model.bert.embeddings.word_embeddings.weight.requires_grad = True
    model.bert.embeddings.word_embeddings.register_backward_hook(extract_grad_hook)

In [5]:
def get_mlm_probabilities(model: torch.nn.Module, batch: Tuple,
                          mask_token_idx: int, trigger_token_len: int = 1):
    input_ids = batch[0]
    mask_tensor = torch.cuda.LongTensor([mask_token_idx] * trigger_token_len)
    mask_tensor = mask_tensor.repeat(batch[0].shape[0], 1)
    input_ids = torch.cat(
        [input_ids[:, 0].unsqueeze(1), mask_tensor, input_ids[:, 1:]], 1)
    logits = model(input_ids)[0]
    logits = logits[:, 1:trigger_token_len + 1, :].sum(dim=0)
    # Just return the trigger probabilities
    return torch.nn.Softmax()(logits).cpu()

In [6]:
def evaluate_batch_bert(model: torch.nn.Module, batch: Tuple,
                        trigger_token_ids: List = None, reduction='mean'):
    # Attach attack_multiple_objectives if present
    input_ids = batch[0]
    loss_f = torch.nn.CrossEntropyLoss(reduction=reduction)
    if trigger_token_ids is not None:
        trig_tensor = torch.cuda.LongTensor(trigger_token_ids)
        trig_tensor = trig_tensor.repeat(batch[0].shape[0], 1)
        # CLS trigger tokens
        input_ids = torch.cat(
            [input_ids[:, 0].unsqueeze(1), trig_tensor, input_ids[:, 1:]], 1)

    loss, logits_val = model(input_ids, attention_mask=input_ids > 1,
                             labels=batch[1])
    loss = loss_f(logits_val, batch[1].long())
    labels = batch[1]

    return loss, logits_val, labels

In [7]:
def evaluate_batch_ppl(model: torch.nn.Module, batch: Tuple, tokenizer,
                       trigger_token_ids: List = None):
    # Attach attack_multiple_objectives if present
    input_ids = batch[0]

    if trigger_token_ids == None:
        trigger_token_ids = []

    input_ids_claim = []
    for instance in input_ids:
        instance = instance.detach().cpu().numpy().tolist()
        # get the claim including SEP token, add triggger and append claim
        # plus evidence
        claim_tokens = instance[0:1] + trigger_token_ids + instance[1:]
        claim_tokens = claim_tokens[:512]
        input_ids_claim.append(claim_tokens)

    # pad batch
    max_len = min(512, max([len(i) for i in input_ids_claim]))
    input_ids_claim = [
        instance + [tokenizer.pad_token_id] * (max_len - len(instance))
        for instance in input_ids_claim]
    input_ids = torch.tensor(input_ids_claim).cuda()

    outputs = model(input_ids, masked_lm_labels=input_ids)
    loss, prediction_scores = outputs[:2]
    return loss, prediction_scores

In [8]:
def evaluate_batch_gpt(model: torch.nn.Module, batch: Tuple,
                       trigger_token_ids: List = None, tokenizer=None):
    # Attach attack_multiple_objectives if present
    input_ids = batch[0]
    loss_f = torch.nn.CrossEntropyLoss()
    input_ids_claim = []

    if trigger_token_ids == None:
        trigger_token_ids = []

    # index_sep = (input_ids == tokenizer.sep_token_id).nonzero()[:, 1]
    for instance in input_ids:
        instance = instance.detach().cpu().numpy().tolist()
        # get the claim including SEP token, add triggger and append again
        # claim but without the CLS token
        # CLS, trigger, claim tokens
        # sep_index = instance.index(tokenizer.sep_token_id)
        claim_tokens = instance[0:1] + trigger_token_ids + instance[1:]
        input_ids_claim.append(claim_tokens[:512])
        # pad batch
        max_len = min(max([len(i) for i in input_ids_claim]), 512)
        input_ids_claim = [
            instance + [tokenizer.pad_token_id] * (max_len - len(instance))
            for instance in input_ids_claim]
        input_ids = torch.tensor(input_ids_claim).cuda()

    # eval is w.r.t. entailment - this is the target class in the NLI case,
    # i.e. the one we want to minimize the loss for.
    logits_val = \
    model(input_ids, attention_mask=input_ids != tokenizer.pad_token_id)[0]
    loss = loss_f(logits_val, batch[1].long())

    return loss, logits_val, batch[1].long()

In [9]:
def evaluate_batch_nli(model: torch.nn.Module, batch: Tuple,
                       trigger_token_ids: List = None, tokenizer=None):
    # Attach attack_multiple_objectives if present
    input_ids = batch[0]
    loss_f = torch.nn.MSELoss()
    input_ids_claim = []

    if trigger_token_ids == None:
        trigger_token_ids = []

    # index_sep = (input_ids == tokenizer.sep_token_id).nonzero()[:, 1]
    for instance in input_ids:
        instance = instance.detach().cpu().numpy().tolist()
        sep_index = instance.index(tokenizer.sep_token_id)
        # get the claim including SEP token, add triggger and append again
        # claim but without the CLS token
        claim_tokens = instance[0:1] + trigger_token_ids + instance[
                                                           1:sep_index + 1] +\
                       instance[
                                                                              1:sep_index + 1]
        input_ids_claim.append(claim_tokens[:512])
        # pad batch
        max_len = max([len(i) for i in input_ids_claim])
        input_ids_claim = [
            instance + [tokenizer.pad_token_id] * (max_len - len(instance))
            for instance in input_ids_claim]
        input_ids = torch.tensor(input_ids_claim).cuda()

    # eval w.r.t. entailment - this is the target class in the NLI case,
    # i.e. the one we want to minimize the loss for.
    logits_val = \
    model(input_ids, attention_mask=input_ids != tokenizer.pad_token_id)[0]
    loss = loss_f(logits_val, batch[1].float().unsqueeze(1))

    return loss, logits_val, batch[1].long()

In [10]:
def get_average_grad_transformer(model, batch, trigger_token_ids, batch_func,
                                 target_label=None, ):
    """
    Computes the average gradient w.r.t. the trigger tokens when prepended to
    every example
    in the batch. If target_label is set, that is used as the ground-truth
    label.
    """
    # create an dummy optimizer for backprop
    optimizer = optim.Adam(model.parameters())
    optimizer.zero_grad()

    # prepend attack_multiple_objectives to the batch
    original_labels = batch[1].clone()
    if target_label is not None:
        # set the labels equal to the target (backprop from the target class,
        # not model prediction)
        batch[1] = int(target_label) * torch.ones_like(batch[1]).cuda()
    global extracted_grads
    extracted_grads = []  # clear existing stored grads
    loss, logits, labels = batch_func(model, batch, trigger_token_ids)
    loss.backward()
    # index 0 has the hypothesis grads for SNLI. For SST, the list is of size 1.
    grads = extracted_grads[0].detach().cpu()
    batch[1] = original_labels.detach()  # reset labels

    # average grad across batch size, result only makes sense for trigger
    # tokens at the front
    averaged_grad = torch.sum(grads, dim=0)
    averaged_grad = F.normalize(averaged_grad, p=2, dim=1)
    # start from position 1 as at 0 is the CLS token
    averaged_grad = averaged_grad[
                    1:len(trigger_token_ids) + 1]  # return just trigger grads
    return averaged_grad

In [11]:
def get_loss_per_candidate(model, model_nli, gpt_model, batch,
                           trigger_token_ids, cand_trigger_token_ids, tokenizer,
                           nli_w=0.0, fc_w=1.0, ppl_w=0.0, idx=0):
    nli_batch_func = partial(evaluate_batch_nli, tokenizer=tokenizer)
    gpt_batch_func = partial(evaluate_batch_gpt, tokenizer=tokenizer)

    original_labels = batch[1].clone()
    loss_per_candidate = get_loss_per_candidate_bert(idx, model, batch,
                                                     trigger_token_ids,
                                                     cand_trigger_token_ids,
                                                     evaluate_batch_bert,
                                                     tokenizer)  # uses the
    # real labels
    loss_per_candidate = [(_t, _s * fc_w) for _t, _s in loss_per_candidate]

    if nli_w > 0.0:
        batch[1] = 0 * torch.ones_like(batch[1]).cuda()
        nli_loss_per_candidate = get_loss_per_candidate_bert(idx, model_nli,
                                                             batch,
                                                             trigger_token_ids,
                                                             cand_trigger_token_ids,
                                                             nli_batch_func,
                                                             tokenizer)
        loss_per_candidate = [(_t, _s + nli_loss_per_candidate[i][1] * nli_w)
                              for i, (_t, _s) in enumerate(loss_per_candidate)]
        batch[1] = original_labels.detach()
    if ppl_w > 0.0:
        batch[1] = 0 * torch.ones_like(batch[1]).cuda()
        gpt_loss_per_candidate = get_loss_per_candidate_bert(idx, gpt_model,
                                                             batch,
                                                             trigger_token_ids,
                                                             cand_trigger_token_ids,
                                                             gpt_batch_func,
                                                             tokenizer)
        loss_per_candidate = [(_t, _s + gpt_loss_per_candidate[i][1] * ppl_w)
                              for i, (_t, _s) in enumerate(loss_per_candidate)]
        batch[1] = original_labels.detach()

    return loss_per_candidate

In [12]:
def get_best_candidates_all_obj(model, model_nli, gpt_model, batch,
                                trigger_token_ids, cand_trigger_token_ids,
                                tokenizer, beam_size=1, nli_w=0.0, fc_w=1.0,
                                ppl_w=0.0):
    loss_per_candidate = get_loss_per_candidate(model, model_nli, gpt_model,
                                                batch,
                                                trigger_token_ids,
                                                cand_trigger_token_ids,
                                                tokenizer,
                                                nli_w=nli_w, fc_w=fc_w,
                                                ppl_w=ppl_w, idx=0)
    # maximize the loss
    top_candidates = heapq.nlargest(beam_size, loss_per_candidate,
                                    key=itemgetter(1))

    # top_candidates now contains beam_size trigger sequences, each with a
    # different 0th token
    for idx in range(1, len(
            trigger_token_ids)):  # for all trigger tokens, skipping the 0th
        # (we did it above)
        loss_per_candidate = []
        for cand, _ in top_candidates:  # for all the beams, try all the
            # candidates at idx
            loss_ = get_loss_per_candidate(model, model_nli, gpt_model, batch,
                                           cand, cand_trigger_token_ids,
                                           tokenizer,
                                           nli_w=nli_w, fc_w=fc_w, ppl_w=ppl_w,
                                           idx=idx)

            loss_per_candidate.extend(loss_)
        top_candidates = heapq.nlargest(beam_size, loss_per_candidate,
                                        key=itemgetter(1))
    return sorted(top_candidates, key=itemgetter(1), reverse=True)[:beam_size]

In [13]:
def get_best_candidates_bert(model, batch, trigger_token_ids,
                             cand_trigger_token_ids, tokenizer, beam_size=1):
    """"
    Given the list of candidate trigger token ids (of number of trigger words
    by number of candidates
    per word), it finds the best new candidate trigger.
    This performs beam search in a left to right fashion.
    """
    # first round, no beams, just get the loss for each of the candidates in
    # index 0.
    # (indices 1-end are just the old trigger)
    loss_per_candidate = get_loss_per_candidate_bert(0, model, batch,
                                                     trigger_token_ids,
                                                     cand_trigger_token_ids,
                                                     evaluate_batch_bert,
                                                     tokenizer)
    # maximize the loss
    top_candidates = heapq.nlargest(beam_size, loss_per_candidate,
                                    key=itemgetter(1))

    # top_candidates now contains beam_size trigger sequences, each with a
    # different 0th token
    for idx in range(1, len(
            trigger_token_ids)):  # for all trigger tokens, skipping the 0th
        # (we did it above)
        loss_per_candidate = []
        for cand, _ in top_candidates:  # for all the beams, try all the
            # candidates at idx
            loss_per_candidate.extend(
                get_loss_per_candidate_bert(idx, model, batch, cand,
                                            cand_trigger_token_ids,
                                            evaluate_batch_bert, tokenizer))
        top_candidates = heapq.nlargest(beam_size, loss_per_candidate,
                                        key=itemgetter(1))
    return max(top_candidates, key=itemgetter(1))[0]

In [14]:
def get_loss_per_candidate_bert(index, model, batch, trigger_token_ids,
                                cand_trigger_token_ids, eval_batch_f,
                                tokenizer):
    """
    For a particular index, the function tries all of the candidate tokens
    for that index.
    The function returns a list containing the candidate
    attack_multiple_objectives it tried, along with their loss.
    """
    if isinstance(cand_trigger_token_ids[0], (numpy.int64, int)):
        print("Only 1 candidate for index detected, not searching")
        return trigger_token_ids
    loss_per_candidate = []
    # loss for the trigger without trying the candidates
    curr_loss, logits, labels = eval_batch_f(model, batch, trigger_token_ids)
    curr_loss = curr_loss.cpu().detach().numpy()

    loss_per_candidate.append((deepcopy(trigger_token_ids), curr_loss))
    for cand_id in range(len(cand_trigger_token_ids[0])):
        token = tokenizer.convert_ids_to_tokens(
            [cand_trigger_token_ids[index][cand_id]])
        trigger_token_ids_one_replaced = deepcopy(
            trigger_token_ids)  # copy trigger
        trigger_token_ids_one_replaced[index] = cand_trigger_token_ids[index][
            cand_id]  # replace one token
        if not any(_s.isalpha() for _s in token):
            loss = -100.0
        elif not (token[0].startswith('Ġ') or token[0][0].isupper()):
            loss = -100.0
        else:
            loss, logits, labels = eval_batch_f(model, batch,
                                                trigger_token_ids_one_replaced)
            loss = loss.cpu().detach().numpy()
        loss_per_candidate.append(
            (deepcopy(trigger_token_ids_one_replaced), loss))
    return loss_per_candidate

In [15]:
def eval_fc(model: torch.nn.Module, test_dl: BatchSampler,
            trigger_token_ids: List = None, labels_num=3):
    model.eval()
    with torch.no_grad():
        labels_all = []
        logits_all = []
        for batch in tqdm(test_dl, desc="Evaluation"):
            # Attach triggers if present
            loss, logits_val, labels = evaluate_batch_bert(model, batch,
                                                           trigger_token_ids)

            labels_all += labels.detach().cpu().numpy().tolist()
            logits_all += logits_val.detach().cpu().numpy().tolist()

        prediction = numpy.argmax(
            numpy.asarray(logits_all).reshape(-1, labels_num), axis=-1)
        acc = sum(prediction == labels_all) / len(labels_all)

    return acc

In [16]:
def eval_ppl(model: torch.nn.Module, test_dl: BatchSampler, tokenizer,
             trigger_token_ids: List = None):
    model.eval()
    with torch.no_grad():
        ppl_loss = []
        for batch in tqdm(test_dl, desc="Evaluation"):
            # Attach triggers if present
            loss, prediction_scores = evaluate_batch_ppl(model, batch,
                                                         tokenizer,
                                                         trigger_token_ids)
            loss = torch.exp(loss)
            ppl_loss.append(loss.item() / batch[0].shape[0])
    return numpy.mean(ppl_loss, dtype=float), numpy.std(ppl_loss)

In [17]:
def eval_nli(model: torch.nn.Module, test_dl: BatchSampler, tokenizer,
             trigger_token_ids: List = None, labels_num=3):
    model.eval()
    with torch.no_grad():
        logits_all = []
        for batch in tqdm(test_dl, desc="Evaluation"):
            loss, logits_val, _ = evaluate_batch_nli(model, batch,
                                                     trigger_token_ids,
                                                     tokenizer)
            logits_all += logits_val.detach().squeeze().cpu().numpy().tolist()
    return logits_all

In [18]:
def perplexity(logits, targets):
    """
    Calculates the perplexity of a sentence based on the
    output of a language model
    :param logits: The language model output
    :param targets: The expected output tokens
    :return: The perplexity score (float)
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(logits, targets)
    return np.exp(loss.cpu().item())


# Attacks

"""
Implementation based on https://github.com/Eric-Wallace/universal-triggers
Contains different methods for attacking models. In particular, given the
gradients for token embeddings, it computes the optimal token replacements.
"""

In [19]:
def hotflip_attack(averaged_grad, embedding_matrix, trigger_token_ids,
                   increase_loss=False, num_candidates=1):
    """
    The "Hotflip" attack described in Equation (2) of the
    Universal Adversarial Attacks paper. This code is
    heavily inspired by
    the nice code of Paul Michel here
    https://github.com/pmichel31415/translate/blob/paul/
    pytorch_translate/research/adversarial/adversaries/brute_force_adversary.py

    This function takes in the model's average_grad over a batch of examples,
    the model's
    token embedding matrix, and the current trigger token IDs. It returns the
    top token
    candidates for each position.

    If increase_loss=True, then the attack reverses the sign of the gradient
    and tries to increase
    the loss (decrease the model's probability of the true class). For
    targeted attacks, you want
    to decrease the loss of the target class (increase_loss=False).
    """
    averaged_grad = averaged_grad.cpu()
    embedding_matrix = embedding_matrix.cpu()
    trigger_token_embeds = torch.nn.functional.embedding(
        torch.LongTensor(trigger_token_ids),
        embedding_matrix).detach().unsqueeze(0)
    averaged_grad = averaged_grad.unsqueeze(0)
    gradient_dot_embedding_matrix = torch.einsum("bij,kj->bik",
                                                 (averaged_grad,
                                                  embedding_matrix))
    if not increase_loss:
        gradient_dot_embedding_matrix *= -1  # lower versus increase the
        # class probability.
    if num_candidates > 1:  # get top k options
        _, best_k_ids = torch.topk(gradient_dot_embedding_matrix,
                                   num_candidates, dim=2)
        return best_k_ids.detach().cpu().numpy()[0]
    _, best_at_each_step = gradient_dot_embedding_matrix.max(2)
    return best_at_each_step[0].detach().cpu().numpy()

In [20]:
def pairwise_dot_product(src_embeds, vocab_embeds, cosine=False):
    """Compute the cosine similarity between each word in the vocab and each
    word in the source
    If `cosine=True` this returns the pairwise cosine similarity"""
    # Normlize vectors for the cosine similarity
    if cosine:
        src_embeds = F.normalize(src_embeds, dim=-1, p=2)
        vocab_embeds = F.normalize(vocab_embeds, dim=-1, p=2)
    # Take the dot product
    dot_product = torch.einsum("bij,kj->bik", (src_embeds, vocab_embeds))
    return dot_product

In [21]:
def pairwise_distance(src_embeds, vocab_embeds, squared=False):
    """Compute the euclidean distance between each word in the vocab and each
    word in the source"""
    # We will compute the squared norm first to avoid having to compute all
    # the directions (which would have space complexity B x T x |V| x d)
    # First compute the squared norm of each word vector
    vocab_sq_norm = vocab_embeds.norm(p=2, dim=-1) ** 2
    src_sq_norm = src_embeds.norm(p=2, dim=-1) ** 2
    # Take the dot product
    dot_product = pairwise_dot_product(src_embeds, vocab_embeds)
    # Reshape for broadcasting
    # 1 x 1 x |V|
    vocab_sq_norm = vocab_sq_norm.unsqueeze(0).unsqueeze(0)
    # B x T x 1
    src_sq_norm = src_sq_norm.unsqueeze(2)
    # Compute squared difference
    sq_norm = vocab_sq_norm + src_sq_norm - 2 * dot_product
    # Either return the squared norm or return the sqrt
    if squared:
        return sq_norm
    else:
        # Relu + epsilon for numerical stability
        sq_norm = F.relu(sq_norm) + 1e-20
        # Take the square root
        return sq_norm.sqrt()

In [22]:
def tailor_simple(averaged_grad, embedding_matrix, increase_loss=False):
    """
    Tailor approximation simplified by computing just the largest gradient
    w.r.t. the target label.
    """
    averaged_grad = averaged_grad.cpu()
    embedding_matrix = embedding_matrix.cpu()
    averaged_grad = averaged_grad.unsqueeze(0)
    gradient_dot_embedding_matrix = torch.einsum("bij,kj->bik",
                                                 (averaged_grad,
                                                  embedding_matrix))
    if not increase_loss:
        gradient_dot_embedding_matrix *= -1  # lower versus increase the
        # class probability.

    gradient_dot_embedding_matrix = F.normalize(gradient_dot_embedding_matrix,
                                                p=2, dim=1)
    return gradient_dot_embedding_matrix

In [23]:
def tailor_first(averaged_grad, embedding_matrix, trigger_token_ids,
                 reverse_loss=False, normalize=False):
    """
    Tailor approximation of the larget gradient compared to the gradient of
    the current token,
    all w.r.t. the target class.
    """
    averaged_grad = averaged_grad.cpu()
    embedding_matrix = embedding_matrix.cpu()
    trigger_token_embeds = torch.nn.functional.embedding(
        torch.LongTensor(trigger_token_ids),
        embedding_matrix).detach().unsqueeze(0)
    averaged_grad = averaged_grad.unsqueeze(0)
    new_embed_dot_grad = torch.einsum("bij,kj->bik",
                                      (averaged_grad, embedding_matrix))
    prev_embed_dot_grad = torch.einsum("bij,bij->bi",
                                       (averaged_grad, trigger_token_embeds))

    if reverse_loss:
        neg_dir_dot_grad = prev_embed_dot_grad.unsqueeze(
            -1) + new_embed_dot_grad
    else:
        neg_dir_dot_grad = prev_embed_dot_grad.unsqueeze(
            -1) - new_embed_dot_grad

    if normalize:
        # Compute the direction norm (= distance word/substitution)
        direction_norm = pairwise_distance(trigger_token_embeds,
                                           embedding_matrix)
        # Renormalize
        neg_dir_dot_grad /= direction_norm

    return neg_dir_dot_grad

In [24]:
def hotflip_attack_all(averaged_grad, embedding_matrix,
                       averaged_grad_nli, embedding_matrix_nli,
                       averaged_grad_ppl, embedding_matrix_ppl,
                       nli_w=0.0, fc_w=1, ppl_w=0.0,
                       num_candidates=1):
    """Optimise the adversarial attacks for all objectives
    that have a weight > 0. This is described in the paper in Equation 2.
    """
    neg_dir_dot_grad = fc_w * tailor_simple(averaged_grad, embedding_matrix,
                                            increase_loss=False)
    if nli_w != 0:
        neg_dir_dot_grad_nli = tailor_simple(averaged_grad_nli,
                                             embedding_matrix_nli,
                                             increase_loss=False)  # decrease
        # loss for entailment
        neg_dir_dot_grad += nli_w * neg_dir_dot_grad_nli
    if ppl_w != 0:
        neg_dir_dot_grad_ppl = tailor_simple(averaged_grad_ppl,
                                             embedding_matrix_ppl,
                                             increase_loss=False)  # decrease
        # loss real example
        neg_dir_dot_grad += ppl_w * neg_dir_dot_grad_ppl

    if num_candidates > 1:  # get top k options
        _, best_k_ids = torch.topk(neg_dir_dot_grad, num_candidates, dim=2)
        return best_k_ids.detach().cpu().numpy()[0]
    _, best_at_each_step = neg_dir_dot_grad.max(2)
    return best_at_each_step[0].detach().cpu().numpy()

In [25]:
def random_attack(embedding_matrix, trigger_token_ids, num_candidates=1):
    """
    Randomly search over the vocabulary. Gets num_candidates random samples
    and returns all of them.
    """
    embedding_matrix = embedding_matrix.cpu()
    new_trigger_token_ids = [[None] * num_candidates for _ in
                             range(len(trigger_token_ids))]
    for trigger_token_id in range(len(trigger_token_ids)):
        for candidate_number in range(num_candidates):
            # rand token in the embedding matrix
            rand_token = numpy.random.randint(embedding_matrix.shape[0])
            new_trigger_token_ids[trigger_token_id][
                candidate_number] = rand_token
    return new_trigger_token_ids

In [26]:
# steps in the direction of grad and gets the nearest neighbor vector.
def nearest_neighbor_grad(averaged_grad, embedding_matrix, trigger_token_ids,
                          tree, step_size, increase_loss=False,
                          num_candidates=1):
    """
    Takes a small step in the direction of the averaged_grad and finds the
    nearest
    vector in the embedding matrix using a kd-tree.
    """
    new_trigger_token_ids = [[None] * num_candidates for _ in
                             range(len(trigger_token_ids))]
    averaged_grad = averaged_grad.cpu()
    embedding_matrix = embedding_matrix.cpu()
    if increase_loss:  # reverse the sign
        step_size *= -1
    for token_pos, trigger_token_id in enumerate(trigger_token_ids):
        # take a step in the direction of the gradient
        trigger_token_embed = \
        torch.nn.functional.embedding(torch.LongTensor([trigger_token_id]),
                                      embedding_matrix).detach().cpu().numpy()[
            0]
        stepped_trigger_token_embed = trigger_token_embed + \
                                      averaged_grad[
                                          token_pos].detach().cpu().numpy() * step_size
        # look in the k-d tree for the nearest embedding
        _, neighbors = tree.query([stepped_trigger_token_embed],
                                  k=num_candidates)
        for candidate_number, neighbor in enumerate(neighbors[0]):
            new_trigger_token_ids[token_pos][candidate_number] = neighbor
    return new_trigger_token_ids

# Trigger generation

In [42]:
import os
from collections import defaultdict
from functools import partial
import gc
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BertConfig, BertForSequenceClassification, BertTokenizer, PreTrainedTokenizer
from typing import List, Dict, AnyStr, Set
from torch import nn

In [44]:
BERT_VOCAB_SIZE = 30522
EMBEDDING_SIZE = 512

In [38]:
def get_fc_model(model_path, tokenizer, labels=2, device='cpu', typeModel='bert'):
    collate_fn = None #partial(collate_fever, tokenizer=tokenizer, device=device)

    transformer_config = BertConfig.from_pretrained('bert-base-uncased', num_labels=labels)
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', config=transformer_config).to(device)

    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model'])
    model.train()  # rnn cannot do backwards in train mode

    # Adds a hook to get the embedding gradients
    add_hooks_bert(model, typeModel)
    embedding_weight = get_embedding_weight_bert(model, typeModel)

    return model, embedding_weight, collate_fn

In [28]:
def get_checkpoint_transformer(model_name: str, device: str,
                               hook_embeddings: bool = False,
                               model_type: str = 'bert'):
    """
    'roberta-large-openai-detector' output: (tensor([[-0.1055, -0.6401]],
    grad_fn=<AddmmBackward>),)
    Real-1, Fake-0
    Note from the authors: 'The results start to get reliable after around 50
    tokens.'

    "SparkBeyond/roberta-large-sts-b" model output: (tensor([[0.6732]],
    grad_fn=<AddmmBackward>),)
    STS-B benchmark measure the relatedness of two sentences based on the
    cosine similarity of the two representations

    roberta-large-mnli output: (tensor([[-1.8364,  1.4850,  0.7020]],
    grad_fn=<AddmmBackward>),)
    contradiction-0, neutral-1, entailment-2
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
        device)
    collate = partial(collate_nli_tok_ids, tokenizer=tokenizer, device=device)

    model_ew = None
    if hook_embeddings:
        # this allows to get the value of the gradient each time we make a
        # feed-forward pass
        model_ew = get_embedding_weight_bert(model, model_type)
        add_hooks_bert(model, model_type)

    return model, model_ew, collate, tokenizer

In [31]:
def collate_nli_tok_ids(instances: List[List],
                        tokenizer: PreTrainedTokenizer,
                        device='cuda') -> List[torch.Tensor]:
    batch_max_len = max([len(_s) for _s in instances])

    padded_ids_tensor = torch.tensor(
        [_s + [tokenizer.pad_token_id] * (batch_max_len - len(_s)) for _s in
         instances])

    output_tensors = [padded_ids_tensor, padded_ids_tensor > 0]

    return list(_t.to(device) for _t in output_tensors)

In [37]:
def collate_fever(instances: List[Dict],
                  tokenizer: PreTrainedTokenizer,
                  device='cuda') -> List[torch.Tensor]:
    token_ids = [tokenizer.encode(_x['claim'], evidence_text(_x)) for _x in
                 instances]
    # the length limit with the encode method does not work with the roberta
    # tokenizer
    batch_max_len = min(512, max([len(_s) for _s in token_ids]))

    padded_ids_tensor = torch.tensor(
        [_s[:batch_max_len] + [tokenizer.pad_token_id] * (
                batch_max_len - min(512, len(_s)))
         for _s in token_ids]).to(device)

    labels_tensor = torch.tensor([_LABELS[_x['label']] for _x in instances],
                                 dtype=torch.long).to(device)

    return [padded_ids_tensor, labels_tensor]

In [43]:
class BertClassifier(nn.Module):

    def __init__(self):

        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(0.5)
        self.l1 = nn.Linear(EMBEDDING_SIZE, 512)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(512, 2)

    def forward(self, input_id, mask):

        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        linear_output = self.l2(self.relu(self.l1(self.dropout(pooled_output))))

        return linear_output

In [40]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Device: {device}')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

var_numLabels = 2
var_modelType = 'bert'
var_model_path = 'best_valid_f1.pt'

fc_model, fc_model_ew, collate_fc = get_fc_model(var_model_path, tokenizer,
                                                     var_numLabels, device,
                                                     typeModel=var_modelType)

Device: cuda


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

KeyError: 'model'