# Parameter attacks on Transformers? 

The model and data pipeline follow the PyTorch tutorial at https://pytorch.org/tutorials/beginner/transformer_tutorial.html

In [1]:
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

%load_ext autoreload
%autoreload 2

In [2]:
# Run on the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# torch.float is the 32-bit floating point dtype.
dtype = torch.float

## Model definition

In [3]:
class TransformerModel(nn.Module):
    """Embedding -> positional encoding -> Transformer encoder stack -> linear decoder.

    Args:
        ntoken: vocabulary size (rows of the embedding, outputs of the decoder).
        d_model: embedding / model width.
        nhead: number of attention heads in each encoder layer.
        d_hid: feedforward width inside each encoder layer.
        nlayers: number of stacked encoder layers.
        dropout: dropout probability handed to the positional encoder and
            to the encoder layers.
    """

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        # Sub-modules are created in this exact order so that parameter
        # registration order (and random initialisation order) is stable.
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        layer_template = TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_hid, dropout=dropout)
        self.transformer_encoder = TransformerEncoder(layer_template, num_layers=nlayers)
        self.encoder = nn.Embedding(num_embeddings=ntoken, embedding_dim=d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(in_features=d_model, out_features=ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        """Uniformly initialise embedding and decoder weights in [-0.1, 0.1]; zero the decoder bias."""
        bound = 0.1
        self.encoder.weight.data.uniform_(-bound, bound)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-bound, bound)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """Run the full model on a batch of token ids.

        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        # Embeddings are scaled by sqrt(d_model) before positions are added.
        scaled = self.encoder(src) * math.sqrt(self.d_model)
        encoded = self.transformer_encoder(self.pos_encoder(scaled), src_mask)
        return self.decoder(encoded)


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build an ``sz x sz`` float mask: ``-inf`` strictly above the diagonal, ``0`` on and below it."""
    blocked = torch.full((sz, sz), float('-inf'))
    # triu(diagonal=1) keeps only the entries above the main diagonal and zeroes the rest.
    return blocked.triu(diagonal=1)

### Positional Embedding

In [4]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position vectors to a sequence-first input.

    A table ``pe`` of shape [max_len, 1, d_model] is precomputed and stored as
    a buffer: even feature columns hold sines, odd feature columns hold
    cosines of ``position * div_term``.

    Args:
        d_model: feature width of the inputs.
        dropout: accepted so existing callers keep working, but NOT applied;
            the dropout layer is disabled in this notebook.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 5000):
        super().__init__()
        # self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim div_term to the number of cosine slots. For even d_model the
        # slice is a no-op and the table is unchanged.
        pe[:, 0, 1::2] = torch.cos(position * div_term[:d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            ``x`` plus the first ``seq_len`` rows of the position table
            (broadcast over the batch dimension).
        """
        x = x + self.pe[:x.size(0)]
        return x # self.dropout(x)

### Data stuff (boring)

In [5]:
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Build the vocabulary from the tokenized WikiText2 training split.
tokenizer = get_tokenizer('basic_english')
train_iter = WikiText2(split='train')
vocab = build_vocab_from_iterator(
    (tokenizer(line) for line in train_iter),
    specials=['<unk>'],
)
# Tokens missing from the vocabulary resolve to the index of '<unk>'.
vocab.set_default_index(vocab['<unk>'])

def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Tokenize each item, map tokens to vocabulary ids, and concatenate into one flat Tensor.

    Items that yield no tokens are dropped before concatenation.
    """
    encoded = []
    for item in raw_text_iter:
        ids = torch.tensor(vocab(tokenizer(item)), dtype=torch.long)
        if ids.numel() > 0:
            encoded.append(ids)
    return torch.cat(tuple(encoded))

# Building the vocab exhausted the earlier train iterator, so fetch the
# splits again before converting each of them to a flat tensor.
train_iter, val_iter, test_iter = WikiText2()
train_data, val_data, test_data = (
    data_process(split_iter) for split_iter in (train_iter, val_iter, test_iter)
)



def batchify(data: Tensor, bsz: int) -> Tensor:
    """Reshape a flat token stream into ``bsz`` parallel columns.

    Trailing elements that do not fill a complete row are discarded.

    Args:
        data: Tensor, shape [N]
        bsz: int, batch size

    Returns:
        Tensor of shape [N // bsz, bsz], placed on the global ``device``
    """
    rows = data.size(0) // bsz
    trimmed = data[:rows * bsz]
    # Lay the stream out as bsz contiguous chunks, then transpose so that
    # each chunk becomes one column.
    columns = trimmed.view(bsz, rows)
    return columns.t().contiguous().to(device)

batch_size = 2
eval_batch_size = 10
# Each split becomes a [seq_len, batch_size] tensor; validation and test
# share the evaluation batch size.
train_data = batchify(train_data, batch_size)
val_data, test_data = (
    batchify(split, eval_batch_size) for split in (val_data, test_data)
)

### Here we define the sequence length we want to consider (don't make it too long ;)

In [6]:
# Maximum number of consecutive tokens per window returned by get_batch;
# also the side length of the attention mask src_mask.
bptt = 20

In [7]:
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """Slice one input window and its next-token targets out of ``source``.

    The window holds at most ``bptt`` rows and is shortened near the end of
    ``source`` so that every input row still has a following target row.

    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int, row offset at which the window starts

    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size]
    """
    window = min(bptt, len(source) - 1 - i)
    stop = i + window
    inputs_slice = source[i:stop]
    # Targets are the same rows shifted forward by one position, flattened.
    shifted = source[i + 1:stop + 1]
    return inputs_slice, shifted.reshape(-1)

# Square bptt x bptt attention mask (-inf above the diagonal), placed on the compute device.
src_mask = generate_square_subsequent_mask(bptt).to(device)

### Instantiate transformer

In [8]:
# Model hyperparameters.
ntokens = len(vocab)  # vocabulary size
emsize = 512  # width of the token embeddings
d_hid = 512  # feedforward width inside nn.TransformerEncoder
nlayers = 1  # how many nn.TransformerEncoderLayer blocks are stacked
nhead = 1  # attention heads per nn.MultiheadAttention
dropout = 0.0 # Kinda need this to be zero :/
model = TransformerModel(
    ntoken=ntokens,
    d_model=emsize,
    nhead=nhead,
    d_hid=d_hid,
    nlayers=nlayers,
    dropout=dropout,
).to(device)

In [9]:
# List every learnable parameter of the model together with its shape.
print([(name, param.shape) for name, param in model.named_parameters()])

[('transformer_encoder.layers.0.self_attn.in_proj_weight', torch.Size([1536, 512])), ('transformer_encoder.layers.0.self_attn.in_proj_bias', torch.Size([1536])), ('transformer_encoder.layers.0.self_attn.out_proj.weight', torch.Size([512, 512])), ('transformer_encoder.layers.0.self_attn.out_proj.bias', torch.Size([512])), ('transformer_encoder.layers.0.linear1.weight', torch.Size([512, 512])), ('transformer_encoder.layers.0.linear1.bias', torch.Size([512])), ('transformer_encoder.layers.0.linear2.weight', torch.Size([512, 512])), ('transformer_encoder.layers.0.linear2.bias', torch.Size([512])), ('transformer_encoder.layers.0.norm1.weight', torch.Size([512])), ('transformer_encoder.layers.0.norm1.bias', torch.Size([512])), ('transformer_encoder.layers.0.norm2.weight', torch.Size([512])), ('transformer_encoder.layers.0.norm2.bias', torch.Size([512])), ('encoder.weight', torch.Size([28782, 512])), ('decoder.weight', torch.Size([28782, 512])), ('decoder.bias', torch.Size([28782]))]


### Get some sample sentence just to see what we're dealing with

In [10]:
data, targets = get_batch(train_data, 100)
# One printed line per sequence in the batch: each column of `data` decoded back to tokens.
print('\n'.join(' '.join(vocab.lookup_tokens(column.tolist())) for column in data.t()))
data.shape

penal military unit serving the nation of gallia during the second europan war who perform secret black operations and are
arikamedu was in <unk> , in a communication from the consul of the indo @-@ french colony of pondicherry .


torch.Size([20, 2])

In [11]:
# Shape of the sample batch returned by get_batch: [seq_len, batch_size].
data.shape

torch.Size([20, 2])

## Attack setup starts here

### First, let's define our linear measurement - some Gaussian vector here

In [12]:
import math

# Random Gaussian direction in feature space. Its length is taken from the
# model width (instead of a hard-coded 512) so it always matches the features
# it is multiplied with later.
weights = torch.randn(model.d_model)
std, mu = torch.std_mean(weights)
# Standardise to zero mean / unit std, then scale by 1/sqrt(d_model).
measurement = (weights - mu) / std / math.sqrt(model.d_model) # Here's our linear measurement

### Here we set up the MHA stuff for inter-sequence separation. 

In [13]:
# Overwrite the self-attention parameters of encoder layer 0 in place.
# Here we see about the MHA layer. in_proj_weight is Q, K, V matrices.
print([(k, v.shape) for k,v in getattr(model.transformer_encoder.layers, '0').self_attn.named_parameters()])
# In order: "q, k, v"

# Let's set the query matrix to produce just the first positional encoding (or could be any index - might want last index)
# qkv_shape is the row count of the stacked in_proj_weight, so qkv_shape//3 is the size of each of the Q, K, V blocks.
qkv_shape = getattr(model.transformer_encoder.layers, '0').self_attn.in_proj_weight.data.shape[0] 

# Dummy inputs to get positional encoding
inputs, targets = get_batch(train_data, 100)
inputs = inputs.to(device=device)
# Zero out the embeddings so only the positional encoding remains, pass it through norm1,
# and bring the result to the CPU. stack([...]) adds a leading axis that squeeze() removes again.
just_pos = torch.stack([getattr(model.transformer_encoder.layers, '0').norm1(model.pos_encoder(torch.zeros_like(model.encoder(inputs)) * math.sqrt(model.d_model)))]).cpu().squeeze()

# Q matrix setup
# Make the position super super large to skew softmax
# With zero Q weights the query equals the bias for every token: 1000x the normed encoding at index [0, 0].
getattr(model.transformer_encoder.layers, '0').self_attn.in_proj_bias.data[:qkv_shape//3] = 1000*just_pos[0,0,:]
getattr(model.transformer_encoder.layers, '0').self_attn.in_proj_weight.data[:qkv_shape//3] = torch.zeros((qkv_shape//3, qkv_shape//3))

# K matrix setup (identity)
getattr(model.transformer_encoder.layers, '0').self_attn.in_proj_weight.data[qkv_shape//3:2*(qkv_shape//3)] = torch.eye(qkv_shape//3)

# V matrix setup (identity)
getattr(model.transformer_encoder.layers, '0').self_attn.in_proj_weight.data[2*(qkv_shape//3):] = torch.eye(qkv_shape//3)

# Linear layer at the end of MHA - set to small value to not 'skew' embeddings too much
# NOTE(review): this rebinds .data to a freshly created tensor (default device) rather than
# copying in place; later cells call model.to(device) before using the model again.
getattr(model.transformer_encoder.layers, '0').self_attn.out_proj.weight.data = 0.05*torch.eye(qkv_shape//3)


[('in_proj_weight', torch.Size([1536, 512])), ('in_proj_bias', torch.Size([1536])), ('out_proj.weight', torch.Size([512, 512])), ('out_proj.bias', torch.Size([512]))]


### Setting the second linear layer to all zeros except for the first row helps? TODO: figure out why.

In [14]:
# Second feedforward projection of encoder layer 0: zero the weights and the
# bias, then fill the first weight row with ones.
ffn_out = getattr(model.transformer_encoder.layers, '0').linear2
ffn_out.weight.data = torch.zeros_like(ffn_out.weight.data)
ffn_out.weight.data[0] = torch.ones_like(ffn_out.weight.data[0])
ffn_out.bias.data = torch.zeros_like(ffn_out.bias.data)


### Now let's get the feature statistics for our random measurement vector

In [15]:
@torch.inference_mode()
def feature_distribution(model):
    """Compute the mean and std of the feature layer of the given network.

    A forward hook captures the input of ``transformer_encoder.layers.0.linear1``
    while the model is run over ``train_data`` in windows of ``bptt`` rows
    (the final, possibly shorter, window is skipped). Every captured feature
    vector is projected onto the global ``measurement`` vector, and the
    std / mean of those scalar projections is returned.

    Side effects: the model is switched to train mode and moved to ``device``
    for the pass; on success it is left in eval mode on the CPU. The hook is
    always removed, even if the pass raises.

    Args:
        model: network containing a module named
            ``transformer_encoder.layers.0.linear1``.

    Returns:
        tuple (std, mu) of 0-dim tensors.

    Raises:
        ValueError: if the feature layer is not found, or if ``train_data``
            yields no full window.
    """
    feature_layer_name = 'transformer_encoder.layers.0.linear1'
    features = dict()
    setup = dict(device=device)

    def named_hook(name):
        def hook_fn(module, input, output):
            # Keep the tensor that is fed INTO the hooked layer.
            features[name] = input[0]
        return hook_fn

    modules_by_name = dict(model.named_modules())
    if feature_layer_name not in modules_by_name:
        raise ValueError(f'Model has no module named {feature_layer_name!r}.')
    print(f'In Linear: {feature_layer_name}')
    hook = modules_by_name[feature_layer_name].register_forward_hook(named_hook(feature_layer_name))

    feats = []
    try:
        model.train()
        model.to(**setup)
        print(f'Computing feature distribution before the {feature_layer_name} layer from external data.')
        # [:-1] drops the last window, which may be shorter than bptt and
        # would not match the fixed-size src_mask.
        for i in list(range(0, train_data.size(0) - 1, bptt))[:-1]:
            inputs, _ = get_batch(train_data, i)
            inputs = inputs.to(**setup)
            model(inputs, src_mask)
            # Flatten [seq_len, batch, dim] to one row per token.
            feats.append(features[feature_layer_name].detach().view(inputs.shape[0]*inputs.shape[1], -1).clone().cpu())
    finally:
        # Never leave the hook attached to the model.
        hook.remove()

    if not feats:
        raise ValueError('train_data is too short to provide a single full window.')

    std, mu = torch.std_mean(torch.mm(torch.cat(feats), measurement.unsqueeze(1)).squeeze())
    print(f'Feature mean is {mu.item()}, feature std is {std.item()}.')
    model.eval()
    model.cpu()

    return std, mu

### Now we'll construct the imprint weights. Only for the first linear layer here...

In [None]:
from statistics import NormalDist

def _get_bins(mean, std, num_bins):
    bins = []
    mass_per_bin = 1 / (num_bins)
    bins.append(-10)  # -Inf is not great here, but NormalDist(mu=0, sigma=1).cdf(10) approx 1
    for i in range(1, num_bins):
        bins.append(NormalDist().inv_cdf(i * mass_per_bin)*std + mean)
    return bins

def _make_biases(bias_layer, bins):
    new_biases = torch.zeros_like(bias_layer.data)
    for i in range(new_biases.shape[0]):
        new_biases[i] = -bins[i]
    return new_biases

# Estimate the distribution of the measured feature, derive the bin edges,
# then install the imprint: every row of linear1 is the measurement vector
# and the biases are the negated bin edges.
feature_std, feature_mean = feature_distribution(model)
bins = _get_bins(feature_mean, feature_std, model.d_model)
imprint_layer = getattr(model.transformer_encoder.layers, '0').linear1
imprint_layer.weight.data = measurement.repeat(model.d_model, 1)
imprint_layer.bias.data = _make_biases(imprint_layer.bias, bins)


In Linear: transformer_encoder.layers.0.linear1
Computing feature distribution before the transformer_encoder.layers.0.linear1 layer from external data.


### Get gradients

In [None]:
# Run one forward/backward pass on a single batch so that the parameter
# gradients inspected by the following cells are populated.
criterion = nn.CrossEntropyLoss()
# The batch starting at row 100 of the training stream.
inputs, targets = get_batch(train_data, 100)
#inputs, targets = get_batch(train_data, 0)
inputs = inputs.to(device=device)
# feature_distribution() leaves the model on the CPU, so move it back to the compute device.
model.to(device=device)
# Clear any stale gradients before the backward pass.
model.zero_grad()
output = model(inputs, src_mask)
# Flatten the logits to [seq_len * batch_size, ntokens] to line up with the flattened targets.
loss = criterion(output.view(-1, ntokens), targets)
loss.backward()

## Actual attack starts here

### Now get the bag of words, as well as the reconstructed positionally encoded stuff

In [None]:
# Bag of words tokens: embedding rows that received any non-zero gradient.
# squeeze(1) (rather than a bare squeeze()) keeps the result 1-D even when
# exactly one token leaked.
leaked_tokens = (model.encoder.weight.grad != 0).any(dim=1).nonzero().squeeze(1)
weight_grad = getattr(model.transformer_encoder.layers, '0').linear1.weight.grad.detach().clone().cpu()
bias_grad = getattr(model.transformer_encoder.layers, '0').linear1.bias.grad.detach().clone().cpu()

# Replace every row (from the second on) by its difference to the previous
# row. The right-hand side is evaluated on the untouched values first, which
# matches walking the rows from last to first.
weight_grad[1:] = weight_grad[1:] - weight_grad[:-1]
bias_grad[1:] = bias_grad[1:] - bias_grad[:-1]
# Only rows with a non-zero bias difference can be divided out below.
valid_classes = bias_grad != 0

recs = weight_grad[valid_classes, :] / bias_grad[valid_classes, None] # Here are our reconstructed positionally encoded features

In [None]:
# [number of rows with a non-zero bias-gradient difference, feature width]
recs.shape

### Now the messy stuff ... 

### Let's associate tokens with embeddings

In [None]:
# First, let's get the features of our bag of words sans positional encoding
no_pos = getattr(model.transformer_encoder.layers, '0').norm1((model.encoder(leaked_tokens) * math.sqrt(model.d_model))).cpu()
with_pos = recs # Here are those same features, but with positional encodings (stuff we reconstructed)

import numpy as np
indcs = []
corrs = torch.zeros((len(no_pos), len(with_pos))) 

# We need to find out what word led to what positionally encoded representation. 
# Let's try the naive greedy search for correlations between no_pos and with_pos as defined above
for i, no_p in enumerate(no_pos):
    max_corr = 0
    for j, with_p in enumerate(with_pos):
        val = np.corrcoef(np.array([no_p.detach().numpy(), with_p.detach().numpy()]))[0,1]
        corrs[i,j] = val

# Find which positionally-encoded vector associates with un-positionally-encoded vector
from scipy.optimize import linear_sum_assignment # Better than greedy search? 
row_ind, col_ind = linear_sum_assignment(corrs.numpy(), maximize=True)

order = [(row_i, col_i) for (row_i, col_i) in zip(row_ind, col_ind)]
order = sorted(order, key=lambda x: x[1])

# Now let's re-sort the tokens by this order
sorted_tokens1 = sorted_tokens1 = [leaked_tokens[order[i][0]] for i in range(len(order))]

In [None]:
# Compare how many leaked tokens (rows of no_pos) and reconstructions (rows of with_pos) there are.
no_pos.shape, with_pos.shape

### Now let's get each token's position, as well as splitting sequences

In [None]:
# Now that we've 'lined-up' the pos-encoded features with non-pos-encoded features, let's subtract the two
# to get some 'faux' positions (layer norm means they aren't exact).
estimated_pos = torch.stack([with_pos[order[i][1]] - no_pos[order[i][0]] for i in range(len(order))])
new_with_pos = [with_pos[order[i][1]] for i in range(len(order))]

# Now let's get just the additive part of the positional encoding
just_pos = torch.stack([getattr(model.transformer_encoder.layers, '0').norm1(model.pos_encoder(torch.zeros_like(model.encoder(inputs)) * math.sqrt(model.d_model)))]).cpu().squeeze()

# The old way of without sequence splitting... If we only have 1 user, with 1 sequence per batch, it works well. 
# Here we just get 'jumbled' sentences if there are multiple users

new_just_pos = just_pos[:,0,:] # just save this thing for later :) 
just_pos = just_pos.view(-1, just_pos.shape[-1])
order_coeffs = torch.zeros((len(estimated_pos), len(just_pos)))

# We'll do another linear sum assignment, but now it's on the positions of the tokens
# First calculate the correlations between (faux) estimated_pos and the (real) just_pos terms
for i in range(len(estimated_pos)):
    for j in range(len(just_pos)):
        order_coeffs[i,j] = np.corrcoef(estimated_pos[i].detach().numpy(), just_pos[j].detach().numpy())[0,1]
row_ind, col_ind = linear_sum_assignment(order_coeffs.numpy(), maximize=True)
pos_order = [(row_i, col_i) for (row_i, col_i) in zip(row_ind, col_ind)]
pos_order = sorted(pos_order, key=lambda x: x[1])

just_pos = new_just_pos

# ----------- NEW STUFF -----------  Getting multiple user's sentences back

# Let's calculate this matrix again, but for the new method (previous calculation was just for old method, can ignore)
order_coeffs = torch.zeros((len(estimated_pos), len(just_pos)))
for i in range(len(estimated_pos)):
    for j in range(len(just_pos)):
        order_coeffs[i,j] = np.corrcoef(estimated_pos[i].detach().numpy(), just_pos[j].detach().numpy())[0,1]

# Now, we make a dictionary where keys are positions, and values are encoded embeddings. 
# i.e. word_groups[0] = ['0th_word_of_sequence1', '0th_word_of_sequence2', ...]
from collections import defaultdict
word_groups = defaultdict(list)

for i in range(order_coeffs.shape[0]):
    max_corr = torch.argmax(order_coeffs[i]).item()
    word_groups[max_corr].append(i)

# Sort these word groups to start forming sentences
sorted_keys = sorted([k for k in word_groups.keys()])
word_groups = [word_groups[k] for k in sorted_keys]
first_words = word_groups[0]
sentences = [[] for i in range(inputs.shape[1])] # Cheating a bit because we shouldn't know how many users there are.

# Start the sentences with first words
for i, first_w in enumerate(first_words):
    sentences[i].append(sorted_tokens1[first_w])

# Go through the rest of the word groups, assigning words to their appropriate sentences
for w in word_groups[1:]:
    corr = torch.zeros(len(w), len(first_words))
    for i, x in enumerate(w):
        for j, y in enumerate(first_words):
            corr[i,j] = np.corrcoef(estimated_pos[x].detach().numpy(), new_with_pos[y].detach().numpy())[0,1]
    
    # Below we do linear sum assignment for each word to each potential sentence
    row_ind, col_ind = linear_sum_assignment(corr.numpy(), maximize=True)
    for m, n in zip(row_ind, col_ind):
        sentences[n].append(sorted_tokens1[w[m]])

# Here are the (old) sorted tokens - i.e. just jumbled sequences
final_sorted_tokens = [sorted_tokens1[pos_order[i][0]] for i in range(len(pos_order))]

### Let's see how we did 

In [None]:
# Show the true input sequences, one per user (column of `inputs`).
print("GROUND TRUTH SEQUENCES \n" + '-'*50)
for user_idx in range(inputs.shape[-1]):
    print(f'USER SEQUENCE {user_idx}')
    print(' '.join(vocab.lookup_tokens(list(inputs[:, user_idx])))) # The true sentence
    print('\n')

# Old method: a single token stream ordered by estimated position.
print('\n'*2)
print("RECONSTRUCTED JUMBLED SEQUENCES (OLD METHOD) \n" + '-'*50)
print(' '.join(vocab.lookup_tokens(final_sorted_tokens))) # What we reconstruct

# New method: one reconstructed sequence per user.
print('\n'*2)
print("RECONSTRUCTED SPLIT SEQUENCES (NEW METHOD) \n" + '-'*50)
for seq_idx, sentence in enumerate(sentences):
    print(f'RECONSTRUCTED SEQUENCE {seq_idx}')
    print(' '.join(vocab.lookup_tokens(sentence))) # What we reconstruct
    print('\n')


In [None]:
# Shape of the imprint layer's weight matrix (linear1 of encoder layer 0).
model.transformer_encoder.layers[0].linear1.weight.shape