In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchtext
from tqdm import tqdm
from pytorch_pretrained_bert import BertTokenizer, BertModel
from classifiers.sttbt.sentiment_classifier import SentimentSTTBTClassifier
from classifiers.xlmr.formality_classifier import FormalityXLMRClassifier

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
base_vocab = torchtext.vocab.vocab(tokenizer.vocab, min_freq=0)

In [3]:
def from_vocab_to_vocab_fn(vocab_from, vocab_to):
    len_from, len_to = len(vocab_from), len(vocab_to)
    matrix = torch.zeros((len_to, len_from))

    for i in range(len_from):
        token = vocab_from.lookup_token(i)
        if not vocab_to.__contains__(token):
            matrix[0, i] = 1
        else:
            matrix[vocab_to[token], i] = 1

    matrix = matrix.permute(1, 0).to(torch.float)

    def helper(input):
        return input @ matrix

    return helper

In [4]:
class OneHotEmbeddings(nn.Module):
    def __init__(self, lookup_embeddings):
        super().__init__()
        self.embeddings = lookup_embeddings.weight.clone()

    def forward(self, inputs):
        return inputs @ self.embeddings


class OneHotInputModel(nn.Module):
    def __init__(self, lookup_embeddings, model, from_vocab, to_vocab):
        super().__init__()
        self.transform_matrix = from_vocab_to_vocab_fn(from_vocab, to_vocab)
        self.one_hot_embeddings = OneHotEmbeddings(lookup_embeddings)
        self.model = model

    def forward(self, input):
        transformed = self.transform_matrix(input)
        embedded = self.one_hot_embeddings(transformed)
        output = self.model(input_embeds = embedded)
        return output


In [5]:
sentiment_sttbt = SentimentSTTBTClassifier(batch_size=1, max_text_length_in_tokens=100,)
sentiment_model = sentiment_sttbt.model
sentiment_vocab = torchtext.vocab.vocab(sentiment_sttbt.src_dict.labelToIdx,min_freq=0)
sentiment_classifier = OneHotInputModel(
    lookup_embeddings = sentiment_model.word_lut,
    model = sentiment_model,
    from_vocab = base_vocab,
    to_vocab = sentiment_vocab
)

# very big transformation matrix

# formality_xlm = FormalityXLMRClassifier()
# formality_vocab = torchtext.vocab.vocab(formality_xlm.tokenizer.vocab,min_freq=0)
# formality_classifier = OneHotInputModel(
#     lookup_embeddings = formality_xlm.model.roberta.embeddings.word_embeddings,
#     model = formality_xlm.model,
#     from_vocab=base_vocab,
#     to_vocab=formality_vocab,
# )


In [6]:
def init_latent(length, vocab):
    return torch.rand((length, len(vocab)))

In [7]:
def sample_from_latent(latent, num_samples, vocab, verbose=False):
    probas = F.softmax(latent, dim=1)
    sampled_indecies = torch.multinomial(probas, num_samples=num_samples, replacement=True)
    if verbose:
        print(sampled_indecies)
    one_hot_encoded = F.one_hot(sampled_indecies, num_classes=len(vocab))
    return one_hot_encoded.permute(1, 0, 2).to(torch.float)

In [8]:
def text_from_sample(sampled, vocab, length = None):
    numpy_sampled = sampled[:, :(-1 if length is None else length)].detach().clone().numpy()
    argmaxed = np.argmax(numpy_sampled, axis=-1)
    return [' '.join(vocab.lookup_tokens(sample)) for sample in argmaxed]

In [9]:
class Loss:
    def __init__(self, num_samples, vocab):
        self.sampler = lambda latent: sample_from_latent(
            latent = latent,
            num_samples = num_samples,
            vocab = vocab,
        )

class EmbeddingsContentLoss(Loss):
    def __init__(self, target_ids, embeddings, num_samples, vocab):
        super().__init__(num_samples, vocab)

        self.intitial_vectors = embeddings(target_ids).detach().clone()
        print(self.intitial_vectors.shape)
        self.criterion = nn.MSELoss()
        self.embeddings = embeddings.weight.clone()

    def __call__(self, latent):
        sampled = self.sampler(latent).requires_grad_(True)
        embedded = sampled @ self.embeddings

        return lambda: sampled.grad.mean(dim=0), self.criterion(embedded, self.intitial_vectors.unsqueeze(0).expand_as(embedded))


class BertContentLoss(Loss):
    def __init__(self, target, bert, criterion, num_samples, vocab):
        super().__init__(num_samples, vocab)
        self.bert = bert
        target_output, _ = self.bert(target.unsqueeze(0).to(torch.float))
        self.target_output = target_output.detach()
        self.criterion = criterion

    def __call__(self, latent):
        sampled = self.sampler(latent).requires_grad_(True)
        output, _ = self.bert(sampled)

        return lambda: sampled.grad.mean(dim = 0), self.criterion(output, self.target_output.expand_as(output))
        

class StyleLoss(Loss):
    def __init__(self, classificators, target, num_samples, vocab):
        super().__init__(num_samples, vocab)
        self.classificators = classificators
        self.criterion = nn.BCELoss()
        self.target = target

    def __call__(self, latent):
        sampled = self.sampler(latent = latent).requires_grad_(True)
        # sampled = [num_samples, len, vocab_size]


        scores = torch.stack([predict(sampled) for predict in self.classificators])
        # scores = [num_classificators, num_samples, ouput_dim]

        target = self.target.unsqueeze(0).expand_as(scores).to(torch.float)

        return lambda: sampled.grad.mean(dim = 0), self.criterion(scores, target)


In [10]:
text = "This film is terrible !"
tokens = [base_vocab[t] for t in tokenizer.tokenize(text)]
ids = F.pad(torch.tensor(tokens), pad=(0, 50 - len(tokens)), value=0)
print(ids.shape)
one_hot_target = F.one_hot(torch.tensor(ids), num_classes=len(base_vocab))
print(ids)
embeddings = nn.Embedding(len(base_vocab), 300)
vectors = torchtext.vocab.GloVe('6B', dim=300)
embeddings.weight.data = vectors.get_vecs_by_tokens(base_vocab.get_itos())

losses_fn = [
    (1, StyleLoss(
        classificators = [sentiment_classifier],
        target = torch.tensor(1),
        num_samples = 16,
        vocab = base_vocab
    )),
    (1, EmbeddingsContentLoss(
        target_ids = ids,
        embeddings = embeddings,
        num_samples = 16,
        vocab = base_vocab,
    ))
    
]

torch.Size([50])
tensor([2023, 2143, 2003, 6659,  999,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0])


  one_hot_target = F.one_hot(torch.tensor(ids), num_classes=len(base_vocab))


torch.Size([50, 300])


In [11]:
latent = init_latent(length = 50, vocab = base_vocab)
# length >= 50 as sentiment classifier cannot accept less
latent = latent.requires_grad_(True)
optimizer = torch.optim.Adam((latent.requires_grad_(True),), lr=0.6)


for epoch in range(100):
    for iteration in tqdm(range(50)):
        latent.requires_grad_(False)
        total_loss = torch.tensor(0, dtype=torch.float)
        grads = []
        for coeff, loss_fn in losses_fn:
            grad_fn, loss = loss_fn(latent)
            total_loss += coeff * loss
            grads.append(grad_fn)    
        total_loss.backward()

        total_grad = sum(map(lambda grad_fn: grad_fn(), grads))
        optimizer.zero_grad()
        latent.requires_grad_(True)
        F.softmax(latent, dim=1).backward(gradient = total_grad)
        optimizer.step()

        grads.append(latent.grad.norm().item())
    
    print(total_loss.item())
    sampled = sample_from_latent(latent, 5, base_vocab, verbose=True)
    print(text_from_sample(sampled, base_vocab, length=8))

  4%|▍         | 2/50 [00:12<04:58,  6.23s/it]


KeyboardInterrupt: 