In [1]:
import numpy as np
from scipy.spatial.distance import cosine
import torch
from torch.autograd import Variable
import torch.nn.functional as F

from data_loader import DataLoader
from classifier import train_classifier, load_saved_model

%load_ext autoreload
%autoreload 2

In [6]:
data_loader = DataLoader()
data_loader.load_data()

loading data...


In [8]:
# this model is trained using large dataset
cnn_model = load_saved_model('CNN', 'saved_model/cnn-1.pt', data_loader)

splitting data...
building vocabulary...
CNN(
  (embedding): Embedding(25002, 100)
  (convs1): ModuleList(
    (0): Conv2d(1, 100, kernel_size=(3, 100), stride=(1, 1))
    (1): Conv2d(1, 100, kernel_size=(4, 100), stride=(1, 1))
    (2): Conv2d(1, 100, kernel_size=(5, 100), stride=(1, 1))
  )
  (dropout): Dropout(p=0.5)
  (fc1): Linear(in_features=300, out_features=1, bias=True)
)


In [9]:
def k_nearest(matrix,vector,k):
    """
    The function takes in a matrix and a normalized vector and returns the indices
    of the vectors from the matrix which are closest to the vector.
    The pretrained embeddings are normalized.
    """
    distances = np.dot(matrix,vector)
    print(max(distances))
    result = np.argsort(distances)[::-1][:k]
    return result

In [20]:
def get_input(data_loader, k=1):
    example = data_loader.large_train.examples[k].text
    label = data_loader.large_train.examples[k].label
    word_indices = np.array([data_loader.TEXT.vocab.stoi[word] for word in example])
    one_input = torch.from_numpy(word_indices)
    return one_input.unsqueeze(1), label

def get_logit(input_example, model, print_msg=False):
    logit = model(input_example)
    if print_msg:
        print('logit:', logit)
        print('pred:', torch.round(torch.sigmoid(logit)))
    
    return logit

def get_predict(logit):
    return torch.round(torch.sigmoid(logit))

def generate_sentence(words_idx, data_loader):
    sentence = ' '.join(data_loader.TEXT.vocab.itos[id] for id in words_idx)
    
    return sentence

In [231]:
def k_nearest_idx(one_embedding, k=10, thershold=15):
    embedding_distance = torch.sum((one_embedding - cnn_model.embedding.weight.data) ** 2, dim=1)
    embedding_distance = torch.reshape(embedding_distance, (1, len(embedding_distance)))
    distances, indices = torch.topk(embedding_distance, k+1, largest=False)
    new_indices = []
    for i, dis in enumerate(distances[0]):
        if dis < 15 and i >= 1:
            new_indices.append(indices[0][i])
    return new_indices

In [232]:
def attack(one_input, model, dis_threshold=10, change_threshold=0.1):
    input_embedding = cnn_model.embedding.weight.data[one_input.squeeze(0)].clone()
    
    initial_logit = get_logit(one_input, model, print_msg=True)
    initial_label = get_predict(initial_logit)
    new_logit = initial_logit.clone()
    change_count = 0
    output = one_input.squeeze(0).clone()
    for idx, word_embedding in enumerate(input_embedding):
        new_indices = k_nearest_idx(word_embedding)
        if len(new_indices) > 0:
            new_logits = []
            for new_index in new_indices:
                new_input = output.squeeze(0).clone()
                new_input[idx] = new_index
                new_logits.append(get_logit(new_input.unsqueeze(0), model))
            if initial_label == 0 and max(new_logits) > new_logit:
                output[idx] = new_indices[np.asarray(new_logits).argmax()]
                change_count += 1
            elif initial_label == 1 and min(new_logits) < new_logit:
                output[idx] = new_indices[np.asarray(new_logits).argmin()]
                change_count += 1
        if (change_count / len(input_embedding)) > change_threshold:
            return None
        new_logit = get_logit(output.unsqueeze(0), model, print_msg=True)
        new_label = get_predict(new_logit)
        if new_label != initial_label:
            return output

In [233]:
def generate_sentence(words_idx, data_loader):
    sentence = ' '.join(data_loader.TEXT.vocab.itos[id] for id in words_idx)
    return sentence

### Generate Sample

In [234]:
one_input, one_label = get_input(data_loader, k=10)
one_input = torch.t(one_input)

In [230]:
result = attack(one_input, cnn_model)
if result is not None:
    print(generate_sentence(result.squeeze(0),data_loader))

logit: tensor([[10.1673]], grad_fn=<ThAddmmBackward>)
pred: tensor([[1.]], grad_fn=<RoundBackward>)
logit: tensor([[9.6786]], grad_fn=<ThAddmmBackward>)
pred: tensor([[1.]], grad_fn=<RoundBackward>)
logit: tensor([[8.3827]], grad_fn=<ThAddmmBackward>)
pred: tensor([[1.]], grad_fn=<RoundBackward>)
logit: tensor([[7.7780]], grad_fn=<ThAddmmBackward>)
pred: tensor([[1.]], grad_fn=<RoundBackward>)
logit: tensor([[7.6253]], grad_fn=<ThAddmmBackward>)
pred: tensor([[1.]], grad_fn=<RoundBackward>)
logit: tensor([[7.5520]], grad_fn=<ThAddmmBackward>)
pred: tensor([[1.]], grad_fn=<RoundBackward>)
logit: tensor([[7.5520]], grad_fn=<ThAddmmBackward>)
pred: tensor([[1.]], grad_fn=<RoundBackward>)


In [192]:
print(generate_sentence(one_input.squeeze(0),data_loader))

i loved this movie and will watch it again . original twist to plot of man vs man vs self . i think this is kurt russell 's best movie . his eyes conveyed more than most actors words . perhaps there 's hope for mankind in spite of government intervention ?


In [129]:
data_loader.TEXT.vocab.itos[one_input[0][0]]

'if'

In [139]:
data_loader.TEXT.vocab.itos[90]

'could'

In [132]:
input_embedding.shape

torch.Size([525, 100])

In [235]:
embedding_distance.shape

torch.Size([1, 25002])

### Validation - sucess rate

In [244]:
def validate(data_loader):
    success_count = 0
    for idx, example in enumerate(data_loader.large_valid.examples):
        label = data_loader.large_train.examples[idx].label
        word_indices = np.array([data_loader.TEXT.vocab.stoi[word] for word in example])
        one_input = torch.from_numpy(word_indices).unsqueeze(1)
        one_input = torch.t(one_input)
        result = attack(one_input, cnn_model)
        if result is not None:
            success_count += 1
    return success_count/len(data_loader.large_valid.examples)
validate(data_loader)
# data_loader.large_train.examples[0]

TypeError: 'Example' object is not iterable