In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchtext
import datasets
import functools
from tqdm import tqdm
import sys
import numpy as np
from pytorch_pretrained_bert import BertTokenizer, BertModel
from classifiers.sttbt.sentiment_classifier import SentimentSTTBTClassifier

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
base_vocab = torchtext.vocab.vocab(tokenizer.vocab, min_freq=0)

In [3]:
def from_vocab_to_vocab_fn(vocab_from, vocab_to):
    len_from, len_to = len(vocab_from), len(vocab_to)
    matrix = torch.zeros((len_to, len_from))

    for i in range(len_from):
        token = vocab_from.lookup_token(i)
        if not vocab_to.__contains__(token):
            matrix[0, i] = 1
        else:
            matrix[vocab_to[token], i] = 1

    matrix = matrix.permute(1, 0).to(torch.float)

    def helper(input):
        return input @ matrix

    return helper

In [4]:
## https://github.com/maknotavailable/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L172-L200
class OneHotBertEmbeddings(nn.Module):
    def __init__(self, bert_embeddings):
        super().__init__()
        self.word_embeddings = bert_embeddings.word_embeddings.weight.clone()
        self.position_embeddings = bert_embeddings.position_embeddings
        self.layer_norm = bert_embeddings.LayerNorm

    def forward(self, one_hot_encoded):
        batch_size, seq_length, _ = one_hot_encoded.size()
        position_ids = torch.arange(seq_length, dtype=torch.long, device=one_hot_encoded.device)
        position_ids = position_ids.unsqueeze(0).expand((batch_size, seq_length))

        words_embeddings = one_hot_encoded @ self.word_embeddings
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        return embeddings


class OneHotBertModel(nn.Module):
    def __init__(self, bert_model):
        super().__init__()
        self.embeddings = OneHotBertEmbeddings(bert_model.embeddings)
        self.encoder = bert_model.encoder
        self.pooler = bert_model.pooler

    def forward(self, input):
        
        batch_size, seq_len, vocab_size = input.size()

        embedding_output = self.embeddings(input)
        encoded_layers = self.encoder(embedding_output,
                                      torch.zeros(batch_size, 1, 1, seq_len),
                                      output_all_encoded_layers=True)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output


class OneHotInputModel(nn.Module):
    def __init__(self, model, from_vocab, to_vocab):
        super().__init__()
        self.transform_matrix = from_vocab_to_vocab_fn(from_vocab, to_vocab)
        self.model = model

    def forward(self, input):
        transformed = self.transform_matrix(input)
        output = self.model(transformed)
        return output


In [5]:
additional_args = {"max_text_length_in_tokens": 50, "gpu": False}
sentiment_sttbt = SentimentSTTBTClassifier(batch_size=1, **additional_args)
sentiment_model = sentiment_sttbt.model
sentiment_vocab = torchtext.vocab.vocab(sentiment_sttbt.src_dict.labelToIdx,min_freq=0)
sentiment_classifier = OneHotInputModel(
    model = sentiment_model,
    from_vocab = base_vocab,
    to_vocab = sentiment_vocab
)


In [6]:
def init_latent(length, vocab):
    return torch.rand((length, len(vocab)))

In [7]:
def sample_from_latent(latent, num_samples, vocab, verbose=False):
    probas = F.softmax(latent, dim=1)
    sampled_indecies = torch.multinomial(probas, num_samples=num_samples, replacement=True)
    if verbose:
        print(sampled_indecies)
    one_hot_encoded = F.one_hot(sampled_indecies, num_classes=len(vocab))
    return one_hot_encoded.permute(1, 0, 2).to(torch.float)

In [8]:
def text_from_sample(sampled, vocab):
    numpy_sampled = sampled.detach().clone().numpy()
    argmaxed = np.argmax(numpy_sampled, axis=-1)
    return [' '.join(vocab.lookup_tokens(sample)) for sample in argmaxed]

In [44]:
class Loss:
    def __init__(self, num_samples, vocab):
        self.sampler = lambda latent: sample_from_latent(
            latent = latent,
            num_samples = num_samples,
            vocab = vocab,
        )

class EmbeddingsContentLoss(Loss):
    def __init__(self, target_ids, embeddings, num_samples, vocab):
        super().__init__(num_samples, vocab)

        self.intitial_vectors = embeddings(target_ids).detach().clone()
        print(self.intitial_vectors.shape)
        self.criterion = nn.MSELoss()
        self.embeddings = embeddings.weight.clone()

    def __call__(self, latent):
        sampled = self.sampler(latent).requires_grad_(True)
        embedded = sampled @ self.embeddings

        return lambda: sampled.grad.mean(dim=0), self.criterion(embedded, self.intitial_vectors.unsqueeze(0).expand_as(embedded))


class BertContentLoss(Loss):
    def __init__(self, target, bert, criterion, num_samples, vocab):
        super().__init__(num_samples, vocab)
        self.bert = bert
        target_output, _ = self.bert(target.unsqueeze(0).to(torch.float))
        self.target_output = target_output.detach()
        self.criterion = criterion

    def __call__(self, latent):
        sampled = self.sampler(latent).requires_grad_(True)
        output, _ = self.bert(sampled)

        return lambda: sampled.grad.mean(dim = 0), self.criterion(output, self.target_output.expand_as(output))
        

class StyleLoss(Loss):
    def __init__(self, classificators, target, num_samples, vocab):
        super().__init__(num_samples, vocab)
        self.classificators = classificators
        self.criterion = nn.BCELoss()
        self.target = target

    def __call__(self, latent):
        sampled = self.sampler(latent = latent).requires_grad_(True)
        # sampled = [num_samples, len, vocab_size]


        scores = torch.stack([predict(sampled.permute(1, 0, 2)) for predict in self.classificators])
        # scores = [num_classificators, num_samples, ouput_dim]

        target = self.target.unsqueeze(0).expand_as(scores).to(torch.float)

        return lambda: sampled.grad.mean(dim = 0), self.criterion(scores, target)


In [45]:
text = "This film is terrible !"
tokens = [base_vocab[t] for t in tokenizer.tokenize(text)]
ids = F.pad(torch.tensor(tokens), pad=(0, 50 - len(tokens)), value=0)
print(ids.shape)
one_hot_target = F.one_hot(torch.tensor(ids), num_classes=len(base_vocab))
print(ids)
embeddings = nn.Embedding(len(base_vocab), 300)
vectors = torchtext.vocab.GloVe('6B', dim=300)
embeddings.weight.data = vectors.get_vecs_by_tokens(base_vocab.get_itos())

losses_fn = [
    (1, StyleLoss(
        classificators = [sentiment_classifier],
        target = torch.tensor(1),
        num_samples = 16,
        vocab = base_vocab
    )),
    (0.1, EmbeddingsContentLoss(
        target_ids = ids,
        embeddings = embeddings,
        num_samples = 16,
        vocab = base_vocab,
    ))
    
]

torch.Size([50])
tensor([2023, 2143, 2003, 6659,  999,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0])


  one_hot_target = F.one_hot(torch.tensor(ids), num_classes=len(base_vocab))


torch.Size([50, 300])


In [46]:
latent = init_latent(length = 50, vocab = base_vocab)
latent = latent.requires_grad_(True)
optimizer = torch.optim.Adam((latent.requires_grad_(True),), lr=0.1)


for epoch in range(100):
    for iteration in tqdm(range(20)):
        latent.requires_grad_(False)
        total_loss = torch.tensor(0, dtype=torch.float)
        grads = []
        for coeff, loss_fn in losses_fn:
            grad_fn, loss = loss_fn(latent)
            total_loss += coeff * loss
            grads.append(grad_fn)    
        total_loss.backward()

        total_grad = sum(map(lambda grad_fn: grad_fn(), grads))
        optimizer.zero_grad()
        latent.requires_grad_(True)
        F.softmax(latent, dim=1).backward(gradient = total_grad)
        optimizer.step()

        grads.append(latent.grad.norm().item())
    
    print(total_loss.item())
    sampled = sample_from_latent(latent, 5, base_vocab, verbose=True)
    print(text_from_sample(sampled, base_vocab))

100%|██████████| 20/20 [01:33<00:00,  4.66s/it]


0.09569944441318512
tensor([[ 7688, 12455,  6254, 22510, 27962],
        [ 3798,  1058, 21698, 16723, 19193],
        [ 8489,  3769,  4201,  3342,  5568],
        [14992,  2517, 13946, 26165, 22438],
        [11673, 10315, 17443,  3058, 11958],
        [ 9508,  5696, 21459,  4784,  4058],
        [28645, 17576,  5557, 10025, 18114],
        [ 2777,  7023,  8642, 23302, 10972],
        [ 5822,  8208,  5030, 10156, 19882],
        [ 5129, 13467,  1027, 21516,  5215],
        [ 7168,  2295, 23658,  9393,  7644],
        [15896, 13082, 26071, 16733,  4056],
        [13364, 23705, 21722, 22421, 11310],
        [ 1043,  3341,  7689,  5132, 12806],
        [27343, 14414,  3857, 11780,  9678],
        [17530,  3225, 23893,  3530,  3913],
        [16416, 30111, 17544,  5202, 18237],
        [11441,  9903, 26199, 17102, 26007],
        [ 5664,  5541, 14466, 27997,  7258],
        [ 2272, 19723, 26380, 17901,  5492],
        [14716, 16236, 12849, 17070,  5232],
        [ 2680, 12528,  6167, 14957

 15%|█▌        | 3/20 [00:18<01:42,  6.05s/it]


KeyboardInterrupt: 