# Attacking language models in federated learning (FL) settings

Following the tutorial at https://pytorch.org/tutorials/beginner/transformer_tutorial.html

In [1]:
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

%load_ext autoreload
%autoreload 2

import breaching

In [2]:
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Model definition

In [3]:
class TransformerModel(nn.Module):
    """Language model built from an embedding, a positional encoding,
    a stack of Transformer encoder layers and a linear output head."""

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        """
        Args:
            ntoken: vocabulary size (embedding rows and output logits)
            d_model: embedding / hidden dimension
            nhead: number of attention heads per encoder layer
            d_hid: width of the feed-forward network in each encoder layer
            nlayers: number of stacked encoder layers
            dropout: dropout probability handed to the sub-modules
        """
        super().__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model

        # Sub-modules are created in a fixed order so that parameter
        # registration and random initialisation stay reproducible.
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer_encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead, d_hid, dropout),
            nlayers,
        )
        self.encoder = nn.Embedding(ntoken, d_model)
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialise the embedding and output head: weights uniform in
        [-0.1, 0.1], output bias set to zero."""
        bound = 0.1
        nn.init.uniform_(self.encoder.weight, -bound, bound)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -bound, bound)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        # Embeddings are scaled by sqrt(d_model) before positions are added.
        scaled = self.encoder(src) * math.sqrt(self.d_model)
        hidden = self.transformer_encoder(self.pos_encoder(scaled), src_mask)
        return self.decoder(hidden)


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build a [sz, sz] float mask: -inf strictly above the main diagonal,
    zero on and below it."""
    neg_inf = float('-inf')
    filled = neg_inf * torch.ones(sz, sz)
    return filled.triu(diagonal=1)

## Positional Encoding

In [4]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position signal to a sequence of embeddings.

    The table is precomputed once for ``max_len`` positions and stored as the
    buffer ``pe`` of shape [max_len, 1, d_model]; even feature indices hold
    sines and odd feature indices hold cosines.
    """

    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 5000):
        """
        Args:
            d_model: feature dimension of the embeddings (even or odd)
            dropout: currently unused, because the dropout layer below is
                commented out; kept so existing callers can still pass it
            max_len: number of positions to precompute
        """
        super().__init__()
        # self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model the odd-index slice has one column fewer than
        # `angles`; trimming here avoids a shape-mismatch error. For even
        # d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor of the same shape as ``x`` with the first ``seq_len`` rows
            of the position table added (broadcast over the batch dimension).
        """
        x = x + self.pe[:x.size(0)]
        return x # self.dropout(x)

In [5]:
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Build the vocabulary from the tokenized WikiText2 training split.
train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
# '<unk>' is registered as a special token ...
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
# ... and its index is set as the vocabulary's default index
# (per torchtext, the index returned for tokens not in the vocabulary).
vocab.set_default_index(vocab['<unk>'])

def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Tokenize every item, map tokens to vocabulary ids, and concatenate the
    non-empty results into a single 1-D long Tensor."""
    encoded = []
    for line in raw_text_iter:
        ids = torch.tensor(vocab(tokenizer(line)), dtype=torch.long)
        # Items that tokenize to nothing are dropped.
        if ids.numel() > 0:
            encoded.append(ids)
    return torch.cat(tuple(encoded))

# train_iter was "consumed" by the process of building the vocab,
# so we have to create it again
train_iter, val_iter, test_iter = WikiText2()
# Flatten each split into one long 1-D tensor of token ids.
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)



def batchify(data: Tensor, bsz: int) -> Tensor:
    """Split a flat token stream into ``bsz`` parallel columns.

    Trailing elements that do not fill a complete row are discarded and the
    result is moved to the module-level ``device``.

    Args:
        data: Tensor, shape [N]
        bsz: int, batch size

    Returns:
        Tensor of shape [N // bsz, bsz]
    """
    rows = data.size(0) // bsz
    trimmed = data[:rows * bsz]
    # Each of the bsz contiguous chunks becomes one column of the output.
    columns = trimmed.view(bsz, rows).transpose(0, 1).contiguous()
    return columns.to(device)

# Number of parallel token streams (columns) for training and for evaluation.
batch_size = 20
eval_batch_size = 10
# The flat 1-D split tensors are replaced by their batchified 2-D versions.
train_data = batchify(train_data, batch_size)  # shape [seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [6]:
bptt = 35
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """Cut one input window and its next-token targets out of ``source``.

    The window starts at row ``i`` and is at most ``bptt`` rows long; it is
    shortened near the end so a target row exists for every input row.

    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int

    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size]
    """
    remaining = len(source) - 1 - i
    window = bptt if bptt < remaining else remaining
    stop = i + window
    inputs = source[i:stop]
    # Targets are the inputs shifted by one row, flattened.
    labels = source[i + 1:stop + 1].reshape(-1)
    return inputs, labels

# Mask of size [bptt, bptt] on `device`; it only matches batches that contain
# exactly bptt rows.
src_mask = generate_square_subsequent_mask(bptt).to(device)
src_mask

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

## Instantiate transformer

In [7]:
# Hyperparameters of the model under attack.
ntokens = len(vocab)  # size of vocabulary
emsize = 200  # embedding dimension
d_hid = 200  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2  # number of heads in nn.MultiheadAttention
# NOTE: PositionalEncoding has its dropout layer commented out, so this value
# only reaches the TransformerEncoderLayer modules.
dropout = 0.2  # dropout probability
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)
model

TransformerModel(
  (pos_encoder): PositionalEncoding()
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)
        )
        (linear1): Linear(in_features=200, out_features=200, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=200, out_features=200, bias=True)
        (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)
        )
        (linear1): Linear(in_features=200, out_features=

In [8]:
num_batches = train_data.size(0) // bptt
# Show (batch index, start row) for every window get_batch would be called on.
list(enumerate(range(0, len(train_data) - 1, bptt)))

[(0, 0),
 (1, 35),
 (2, 70),
 (3, 105),
 (4, 140),
 (5, 175),
 (6, 210),
 (7, 245),
 (8, 280),
 (9, 315),
 (10, 350),
 (11, 385),
 (12, 420),
 (13, 455),
 (14, 490),
 (15, 525),
 (16, 560),
 (17, 595),
 (18, 630),
 (19, 665),
 (20, 700),
 (21, 735),
 (22, 770),
 (23, 805),
 (24, 840),
 (25, 875),
 (26, 910),
 (27, 945),
 (28, 980),
 (29, 1015),
 (30, 1050),
 (31, 1085),
 (32, 1120),
 (33, 1155),
 (34, 1190),
 (35, 1225),
 (36, 1260),
 (37, 1295),
 (38, 1330),
 (39, 1365),
 (40, 1400),
 (41, 1435),
 (42, 1470),
 (43, 1505),
 (44, 1540),
 (45, 1575),
 (46, 1610),
 (47, 1645),
 (48, 1680),
 (49, 1715),
 (50, 1750),
 (51, 1785),
 (52, 1820),
 (53, 1855),
 (54, 1890),
 (55, 1925),
 (56, 1960),
 (57, 1995),
 (58, 2030),
 (59, 2065),
 (60, 2100),
 (61, 2135),
 (62, 2170),
 (63, 2205),
 (64, 2240),
 (65, 2275),
 (66, 2310),
 (67, 2345),
 (68, 2380),
 (69, 2415),
 (70, 2450),
 (71, 2485),
 (72, 2520),
 (73, 2555),
 (74, 2590),
 (75, 2625),
 (76, 2660),
 (77, 2695),
 (78, 2730),
 (79, 2765),
 (8

[23,
 4,
 4,
 0,
 4,
 15,
 520,
 730,
 208,
 9518,
 223,
 3506,
 29,
 40,
 4,
 5686,
 30,
 6579,
 24,
 64]

In [19]:
train_data.shape

torch.Size([102499, 20])

In [29]:
data, targets = get_batch(train_data, 10)
data.shape

torch.Size([35, 20])

In [36]:
# One output line per batch column: the column's token ids decoded to text.
print('\n'.join(' '.join(vocab.lookup_tokens(column.tolist())) for column in data.t()))

chronicles ( japanese 戦場のヴァルキュリア3 , lit . valkyria of the battlefield 3 ) , commonly referred to as valkyria chronicles iii outside japan , is a tactical role @-@ playing video game developed by sega
by kobe bryant ) . that year , jordan was the only washington player to play in all 82 games , starting in 67 of them . he averaged 20 @ . @ 0 points
the junction of union and taylor creeks in 1867 . growth of the new settlement was rapid in particular , there was an influx of german families from wisconsin . the town of madison was
on the ground or without power . damages on the island totaled to $ 5 @ . @ 3 million . in the <unk> <unk> quarter on st . lucia , rough seas damaged a
significant role in the cultural life of the island since ancient times ( and since the 17th century plantations , has been the focus of political identity and divisions on the island ) . ireland
' s mother left the faith she still supports them . months before the marriage of janelle and kody , however , janell

In [41]:

output.shape

torch.Size([35, 20, 28782])

In [44]:
output.view(-1, ntokens).shape

torch.Size([700, 28782])

## Leak all words

In [50]:
criterion = nn.CrossEntropyLoss()

# Clear gradients left over from any earlier backward pass. Without this,
# re-running the cell (or running it after another batch) accumulates
# gradients, and the embedding gradient inspected afterwards would mix
# several batches.
model.zero_grad()
output = model(data, src_mask)
# Flatten logits to [seq_len * batch_size, ntokens] to line up with `targets`.
loss = criterion(output.view(-1, ntokens), targets)
loss.backward()
loss

tensor(10.4828, device='cuda:0', grad_fn=<NllLossBackward>)

In [70]:
data, targets = get_batch(train_data, 25)
data

tensor([[ 1018,   231,    10,    22,     1,     4,    10, 22549,    18,    13,
            12,    11,  8125,   370,   409,  1405,   741,    25,     8,   494],
        [    7,     6,  2959,     3,  2064, 11776,  1162,     3,  1245,   958,
             0,  1547,    19,   622,    25,   440, 19354,     1,   875,   935],
        [   14,    59,     6,    22,   131,     5,     7,     1, 10944,     3,
           536,    17,    27,     2,  4604,    10,    21,  3481,  1057,     3],
        [ 3849,  4309,   960,    88,  9695,  4772,  3501,  1982,     2,   174,
           679,    27,   920,  5317,    46,  7665,   170,  4206,    81,     6],
        [ 3869,   193,     2,   164,     2,     2,    74,   726,    23,  8525,
          3007,   277,    11,   635,   545,    36, 10472,     3,  3521,  3470],
        [  881,     2,    79,     3,    50,   104,   174,    23,  4916,   806,
            14,    11,    15,     1,     3,  1722,    20,    18,  3160, 16017],
        [  629,  1042,    10,     6,    51,   

In [94]:
# Token ids whose embedding row received any nonzero gradient entry.
# squeeze(1) removes only the column dimension of nonzero()'s [k, 1] result,
# so the result stays 1-D even when exactly one token is found (a bare
# squeeze() would produce a 0-dim tensor whose .tolist() is a plain int).
leaked_tokens = (model.encoder.weight.grad != 0).any(dim=1).nonzero().squeeze(1)

In [95]:
vocab.lookup_tokens(leaked_tokens.tolist())

['<unk>',
 'the',
 ',',
 '.',
 'of',
 'and',
 'in',
 'to',
 'a',
 'was',
 "'",
 '@-@',
 'on',
 'as',
 's',
 'that',
 'for',
 'with',
 'by',
 ')',
 '(',
 '@',
 'is',
 'it',
 'from',
 'at',
 'his',
 'he',
 'were',
 'an',
 'had',
 'which',
 'are',
 'this',
 'their',
 'first',
 'but',
 '–',
 'one',
 'they',
 'her',
 'or',
 'two',
 'have',
 'has',
 'been',
 'who',
 'when',
 'all',
 'into',
 'more',
 '1',
 'i',
 'game',
 'most',
 '2',
 'three',
 'up',
 'between',
 'him',
 'there',
 'than',
 'no',
 'year',
 'made',
 'city',
 '3',
 'before',
 'them',
 'being',
 'many',
 'however',
 'part',
 'state',
 'including',
 'became',
 'four',
 'united',
 'century',
 'following',
 'because',
 'so',
 'work',
 'episode',
 'until',
 'could',
 '6',
 'released',
 'church',
 'long',
 'million',
 '0',
 'john',
 'another',
 'large',
 'what',
 '8',
 'down',
 'games',
 'line',
 'name',
 'species',
 'family',
 'played',
 'major',
 'won',
 'play',
 'video',
 'third',
 'april',
 'march',
 'january',
 'men',
 '20',
 '

In [96]:
vocab.lookup_tokens(data.view(-1).tolist())

['referred',
 'play',
 'was',
 '@',
 'the',
 'of',
 'was',
 'parapet',
 'with',
 'on',
 '@-@',
 "'",
 'worried',
 'field',
 'away',
 'kakapo',
 'michael',
 'from',
 'a',
 'human',
 'to',
 'in',
 'rapid',
 '.',
 '17th',
 'janelle',
 'unable',
 '.',
 'author',
 'board',
 '<unk>',
 'except',
 'by',
 'goal',
 'from',
 'population',
 'kritschgau',
 'the',
 'better',
 'nature',
 'as',
 'all',
 'in',
 '@',
 'century',
 'and',
 'to',
 'the',
 'citation',
 '.',
 'together',
 'for',
 'his',
 ',',
 'raising',
 'was',
 '(',
 'dawn',
 'deal',
 '.',
 'valkyria',
 '82',
 'particular',
 '3',
 'plantations',
 'kody',
 'pick',
 'medieval',
 ',',
 'another',
 'find',
 'his',
 'crew',
 'auburn',
 'her',
 'declining',
 'john',
 'spacecraft',
 'than',
 'in',
 'chronicles',
 'games',
 ',',
 'million',
 ',',
 ',',
 'up',
 'era',
 'is',
 'flashback',
 'mention',
 'own',
 "'",
 'brought',
 'son',
 'their',
 'finn',
 '.',
 'anyone',
 '1928',
 'iii',
 ',',
 'there',
 '.',
 'has',
 'however',
 'another',
 'is',
 '

In [103]:
set(leaked_tokens.tolist()) == set(data.view(-1).tolist())

True

## Try to get positions

In [104]:
# Reproduce the first two steps of TransformerModel.forward by hand:
# scaled token embeddings, then the positional encoding added on top.
embedded_data = model.encoder(data) * math.sqrt(model.d_model)
positioned_data = model.pos_encoder(embedded_data)

In [108]:
model.pos_encoder.pe.shape

torch.Size([5000, 1, 200])