# **Interecative code for evaluation of Stylish Autoemcoder**


# Init

In [0]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [0]:
!pip install torchtext
!pip install pyemd



In [0]:
!pip install --upgrade nltk # newest version of NLTK

Collecting nltk
[?25l  Downloading https://files.pythonhosted.org/packages/f6/1d/d925cfb4f324ede997f6d47bea4d9babba51b49e87a767c170b77005889d/nltk-3.4.5.zip (1.5MB)
[K     |▎                               | 10kB 17.2MB/s eta 0:00:01[K     |▌                               | 20kB 4.2MB/s eta 0:00:01[K     |▊                               | 30kB 6.0MB/s eta 0:00:01[K     |█                               | 40kB 3.9MB/s eta 0:00:01[K     |█▏                              | 51kB 4.8MB/s eta 0:00:01[K     |█▍                              | 61kB 5.7MB/s eta 0:00:01[K     |█▋                              | 71kB 6.5MB/s eta 0:00:01[K     |█▉                              | 81kB 7.3MB/s eta 0:00:01[K     |██                              | 92kB 8.1MB/s eta 0:00:01[K     |██▎                             | 102kB 6.4MB/s eta 0:00:01[K     |██▌                             | 112kB 6.4MB/s eta 0:00:01[K     |██▊                             | 122kB 6.4MB/s eta 0:00:01[K     |███ 

In [0]:
import nltk 
nltk.download('wordnet')


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


True

In [0]:
!git clone -l -s -b gen_cls_branch git://github.com/RoyHirsch/TextualStyleTransfer.git TextualStyleTransfer
%cd TextualStyleTransfer
# !ls

Cloning into 'TextualStyleTransfer'...
remote: Enumerating objects: 9, done.[K
remote: Counting objects:  11% (1/9)[Kremote: Counting objects:  22% (2/9)[Kremote: Counting objects:  33% (3/9)[Kremote: Counting objects:  44% (4/9)[Kremote: Counting objects:  55% (5/9)[Kremote: Counting objects:  66% (6/9)[Kremote: Counting objects:  77% (7/9)[Kremote: Counting objects:  88% (8/9)[Kremote: Counting objects: 100% (9/9)[Kremote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects:  14% (1/7)[Kremote: Compressing objects:  28% (2/7)[Kremote: Compressing objects:  42% (3/7)[Kremote: Compressing objects:  57% (4/7)[Kremote: Compressing objects:  71% (5/7)[Kremote: Compressing objects:  85% (6/7)[Kremote: Compressing objects: 100% (7/7)[Kremote: Compressing objects: 100% (7/7), done.[K
Receiving objects:   0% (1/202)   Receiving objects:   1% (3/202)   Receiving objects:   2% (5/202)   Receiving objects:   3% (7/202)   Receiving objects:   4

In [0]:
import os
import sys
CODE_PATH = '/content/TextualStyleTransfer/'
sys.path.append(CODE_PATH)

import torch
import numpy as np
import pandas as pd
from pyemd import emd
import matplotlib.pyplot as plt
import json
from gensim.models.word2vec import Word2Vec

from bs4 import BeautifulSoup
from torchtext import data
from torchtext import datasets
from torchtext.vocab import Vectors, GloVe
from torchtext.data import Field, LabelField, TabularDataset
from spacy.lang.en import English

from data import *
from train import *
from evaluate import *
from utils import *
%matplotlib inline  


# Data

In [0]:
def generate_bigrams(x):
    n_grams = set(zip(*[x[i:] for i in range(2)]))
    for n_gram in n_grams:
        x.append(' '.join(n_gram))
    return x

In [0]:
def load_dataset_eval(dataset_name, base_path, preprocessing_func, max_len, min_freq, embed_dim, batch_size, device):

    # define tokenizer
    en = English()

    def tokenize_spacy_with_html_parsing(sentence):
        sentence = BeautifulSoup(sentence, 'html.parser').get_text()
        return [tok.text for tok in en.tokenizer(sentence)]

    # eos_token - end of sentence token, batch_first - first dimension is batch, fix_length - can be also None
    TEXT = data.Field(sequential=True, tokenize=tokenize_spacy_with_html_parsing,
                      preprocessing=preprocessing_func, lower=True,
                      eos_token='<eos>', batch_first=True, fix_length=max_len)
    LABEL = data.LabelField()

    print('Start loading dataset {}:'.format(dataset_name))

    if dataset_name == 'IMDB':
        train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

    elif dataset_name == 'SST':
        train_data, test_data = datasets.SST.splits(TEXT, LABEL)

    elif dataset_name == 'YELP':
        fields_list = [('Unnamed: 0', None),
                       ('text', TEXT),
                       ('label', LABEL)]

        yelp_train_path = os.path.join(base_path, "yelp_train.csv")
        yelp_test_path = os.path.join(base_path, "yelp_test.csv")

        train_data = TabularDataset(
            path=yelp_train_path,
            format='csv',
            skip_header=True,
            fields=fields_list)

        test_data = TabularDataset(
            path=yelp_test_path,
            format='csv',
            skip_header=True,
            fields=fields_list)

    else:
        raise ValueError

    TEXT.build_vocab(train_data, test_data, min_freq=min_freq, vectors=GloVe(name='6B', dim=embed_dim))
    
    print("Loaded Glove embedding, Vector size of Text Vocabulary: " + str(TEXT.vocab.vectors.size()))

    LABEL.build_vocab(train_data)

    word_embeddings = TEXT.vocab.vectors
    print("Length of Text Vocabulary: " + str(len(TEXT.vocab)))

    train_iter, test_iter = data.BucketIterator.splits((train_data, test_data),
                                                       batch_sizes=(batch_size, batch_size),
                                                       sort_key=lambda x: len(x.text), repeat=False, shuffle=True,
                                                       device=device)
    # Disable shuffle
    test_iter.shuffle = False

    return TEXT, word_embeddings, train_iter, test_iter


# Style transformer

In [0]:
from transformer_model import *
import torch.nn.functional as F

class ArgMaxEmbed(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, embed, src_mask):
        idx = torch.argmax(inputs, -1)
        # idx *= src_mask.squeeze(1).type(idx.dtype)
        ctx._input_shape = inputs.shape
        ctx._input_dtype = inputs.dtype
        ctx._input_device = inputs.device
        ctx.save_for_backward(idx)
        return embed(idx)

    @staticmethod
    def backward(ctx, grad_output):
        idx, = ctx.saved_tensors
        grad_input = torch.zeros(ctx._input_shape, device=ctx._input_device, dtype=ctx._input_dtype)
        # print("backward debug", idx[..., None].shape, grad_output.sum(-1, keepdim=True), grad_input.shape)
        grad_input.scatter_(-1, idx[..., None], grad_output.sum(-1, keepdim=True))
        return grad_input, None, None

class StyleTransformer(nn.Module):
    """
    An encoder that also encodes style and adds it to the representation
    """
    def __init__(self, src_vocab, tgt_vocab, N=6, 
                 d_model=512, d_ff=2048, h=8, n_styles=2, dropout=0.1, max_len=128):
        super().__init__()
        c = copy.deepcopy
        attn = MultiHeadedAttention(h, d_model)
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.src_embed = Embeddings(d_model, src_vocab)
        self.argmax = ArgMaxEmbed.apply
        self.encoder = BasicEncoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)
        self.position = PositionalEncoding(d_model, dropout, max_len)
        self.style_embed = nn.Embedding(n_styles, d_model)
        self.generator = nn.Linear(d_model, tgt_vocab)
        self.temperature = 1.0
        
        # Initialize parameters with Glorot / fan_avg.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode_style(self, style_labels):
        style_embadding = self.style_embed(style_labels).unsqueeze(1)
        if style_embadding.ndimension() == 1:
            style_embadding = style_embadding.unsqueeze(0).unsqueeze(1)
        elif style_embadding.ndimension() == 2:
            style_embadding = style_embadding.permute(1, 0).unsqueeze(0)
        return style_embadding

    def forward(self, src, src_mask, style, argmax=False):
        "Take in and process masked src and target sequences."
        style = self.style_embed(style).unsqueeze(dim=1)
        if argmax:
            src = self.argmax(src, self.src_embed, src_mask)
        else:
            src = self.src_embed(src)
        src = self.position(src)
        # add style before position?
        x = src + style
        enc_out = self.encoder(x, src_mask)
        return self.generator(enc_out)


# Functions

In [0]:
def binary_accuracy(preds, y):

    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()
    acc = correct.sum() / len(correct)
    return acc


In [0]:
def print_src_dest(src, dest, TEXT):
  word2id = TEXT.vocab.stoi
  eos_id = int(word2id['<eos>'])
  pad_id = int(word2id['<pad>'])
  stop_words = [eos_id, pad_id]
  id2word = {v: k for k, v in word2id.items()}
  
  src_sent = []
  for i in src:
    if i in stop_words: break
    src_sent.append(id2word[int(i)])

  dest_sent = []
  for i in dest:
    if i in stop_words: break
    dest_sent.append(id2word[int(i)])
  
  print('original: {}'.format(' '.join(src_sent)))
  print('generated: {}'.format(' '.join(dest_sent)))

In [0]:
def greedy_decode_sent(preds, id2word, eos_id):
    ''' Naive greedy decoding - just argmax over the vocabulary distribution '''
    preds = torch.argmax(preds, -1)
    
    # Find eof
    eos_ind = (preds == eos_id).nonzero()
    if len(eos_ind) > 0:
      eos_ind = eos_ind[0]
    else:
      eos_ind = len(preds) - 1
    
    # <pad> token ind is 1
    out = torch.ones(len(preds))
    out[:eos_ind] = preds[:eos_ind]
    
#     decoded_sent = preds.detach().cpu().numpy()
#     print(" ".join([id2word[i] for i in decoded_sent]))
#     decoded_sent = sent2str(decoded_sent, id2word, eos_id)
    return out


def sent2str(sent_as_np, id2word, eos_id=None):
    ''' Gets sentence as a list of ids and transfers to string
        Input is np array of ids '''
    if not (isinstance(sent_as_np, np.ndarray)):
        raise ValueError('Invalid input type, expected np array')
    if eos_id:
        end_id = np.where(sent_as_np == eos_id)[0]
        if len(end_id) > 1:
            sent_as_np = sent_as_np[:int(end_id[0])]
        elif len(end_id) == 1:
            sent_as_np = sent_as_np[:int(end_id)]

    return " ".join([id2word[i] for i in sent_as_np])

  
# def predict_style(model_path, vocab_size, embed_dim, data_iterator, device, raw=False):
#     # init model
#     model = FastText(vocab_size, embed_dim, 1, 1)
#     model.load_state_dict(torch.load(model_path))
    
    
#     model = model.to(device)
#     res = []
#     epoch_acc = 0
#     model.eval()
    
#     with torch.no_grad():
#       for batch in data_iterator:
#         preds = model(batch.text).squeeze(1)
#         labels = batch.labels.detach().cpu().float()
        
#         preds = preds.detach().cpu()
#         acc = binary_accuracy(preds, labels)
#         epoch_acc += acc.item()
#         if raw:
#           res.append(preds.numpy())
#         else:
#           res.append(torch.round(torch.sigmoid(preds)).numpy())

#     test_acc = epoch_acc / len_data
#     return test_acc, res
    

def predict_style(model_path, vocab_size, embed_dim, gen_data, gen_labels, device, raw=False):
    # init model
    model = FastText(vocab_size, embed_dim, 1, 1)
    model.load_state_dict(torch.load(model_path))
    
    
    model = model.to(device)
    res = []
    epoch_acc = 0
    model.eval()
    
    with torch.no_grad():
      for data, labels in zip(gen_data, gen_labels):
        data = data.to(device)
        preds = model(data.t()).squeeze(1)
        labels = labels.detach().cpu().float()
        
        preds = preds.detach().cpu()
        acc = binary_accuracy(preds, labels)
        epoch_acc += acc.item()
        if raw:
          res.append(preds.numpy())
        else:
          res.append(torch.round(torch.sigmoid(preds)).numpy())

    test_acc = epoch_acc / len(gen_data)
    return test_acc, res
    
  
def get_predictions_from_style_transfer_model(style_gen_model, test_dataloader, TEXT, device):

    # decoding utils
    word2id = TEXT.vocab.stoi
    eos_id = int(word2id['<eos>'])
    id2word = {v: k for k, v in word2id.items()}

    org_sents = []
    gen_sents = []
    gen_labels = []
    len_dataloader = len(test_dataloader)
    print_int = 100

    style_gen_model.eval()
    with torch.no_grad():
        for m, batch in enumerate(test_dataloader):
            
            src, labels = batch.text, batch.label
            src_mask, _ = make_masks(src, src, device)
            
            labels = labels.to(device)
            src = src.to(device)
            src_mask = src_mask.to(device)

            # Get predictions from generator
            neg_labels = (~labels.byte()).long()
            preds = style_gen_model(src, src_mask, neg_labels, argmax=False)
            
            preds = preds.detach().cpu()
            src = src.detach().cpu()
            neg_labels = neg_labels.detach().cpu()
            
            # Decode sentences
            decoded_transfered_sentences = torch.zeros_like(src)
            for n, pred_sent in enumerate(preds):
              dec_sent = greedy_decode_sent(pred_sent, id2word, eos_id)
              decoded_transfered_sentences[n, :] = dec_sent
                        
            org_sents.append(src)
            gen_sents.append(decoded_transfered_sentences)
            gen_labels.append(neg_labels)
            
    return org_sents, gen_sents, gen_labels

# Style Transfer Intensity Functions

In [0]:
import torch.nn as nn
import torch.nn.functional as F


class FastText(nn.Module):
    def __init__(self, vocab_size, embedding_dim, output_dim, pad_idx):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.fc = nn.Linear(embedding_dim, output_dim)

    def forward(self, text):
        # text = [sent len, batch size]
        embedded = self.embedding(text)
        # embedded = [sent len, batch size, emb dim]
        embedded = embedded.permute(1, 0, 2)
        # embedded = [batch size, sent len, emb dim]
        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
        # pooled = [batch size, embedding_dim]
        return self.fc(pooled)

In [0]:
def calculate_emd(input_distribution, output_distribution):   
    '''
    Calculate Earth Mover's Distance (aka Wasserstein distance) between 
    two distributions of equal length.
    Parameters
    ----------
    input_distribution : numpy.ndarray
        Probabilities assigned to style classes for an input text
    output_distribution : numpy.ndarray
        Probabilities assigned to style classes for an output text, e.g. of a style transfer model
        
    Returns
    -------
    Earth Mover's Distance (float) between the two given style distributions
    '''
    
    N = len(input_distribution)
    distance_matrix = np.ones((N, N))
    return emd(input_distribution, output_distribution, distance_matrix)

def account_for_direction(input_target_style_probability, output_target_style_probability):
    '''
    In the context of EMD, more mass (higher probability) placed on a target style class
    in the style distribution of an output text (relative to that of the input text)
    indicates movement in the correct direction of style transfer. 
    
    Otherwise, the style transfer intensity score should be penalized, via application
    of a negative direction factor.
    Parameters
    ----------
    input_target_style_probability : float
        Probability assigned to target style in the style distribution of an input text
    output_target_style_probability : float
        Probability assigned to target style in the style distribution of an output text, e.g. of a style transfer model
        
    Returns
    -------
    1 if correct direction of style transfer, else -1
    '''
    
    if output_target_style_probability >= input_target_style_probability:
        return 1
    return -1

def calculate_direction_corrected_emd(input_distribution, output_distribution, target_style_class): 
    '''
    Calculate Earth Mover's Distance (aka Wasserstein distance) between 
    two distributions of equal length, with correction for direction.
    That is, penalize the score if the output style distribution displays
    change of style in the wrong direction, i.e. away from the target style.
    Parameters
    ----------
    input_distribution : numpy.ndarray
        Probabilities assigned to style classes for an input text
    output_distribution : numpy.ndarray
        Probabilities assigned to style classes for an output text, e.g. of a style transfer model
    target_style_class : int
        Label of the intended style class for a style transfer task
        
    Returns
    -------
    Direction-corrected Earth Mover's Distance (float) between the two given style distributions
    '''
    
    emd_score = calculate_emd(input_distribution, output_distribution)
    direction_factor = account_for_direction(input_distribution[target_style_class], output_distribution[target_style_class])
    return emd_score*direction_factor


# Content Preservation Functions

In [0]:
def calc_meteor_scores(org_sents, generated_sents, TEXT):
  """ org_sents, generated_sents - lists of batches """
  word2id = TEXT.vocab.stoi
  id2word = {v: k for k, v in word2id.items()}

  met_scores_list = []
  for n, (src, dest) in enumerate(zip(org_sents, generated_sents)):
    if type(src) == torch.Tensor:
      src = src.numpy()
    
    if type(dest) == torch.Tensor:
      dest = dest.numpy()
    
    for sent_src, sent_dest in zip(src, dest):
      sent_src = ' '.join([id2word[i] for i in sent_src])
      sent_dest = ' '.join([id2word[i] for i in sent_dest])

      # reference is the source and hypothesis is the generated/translated
      # single_meteor_score(hypothesis, reference)
      met_scores_list.append(nltk.translate.meteor_score.single_meteor_score(sent_dest, sent_src))
    if n % 100 == 0: print('{}:{}'.format(n, len(org_sents))) # debug print 
      
  return np.mean(met_scores_list), np.std(met_scores_list)

def ids2text(list_list_tokens, id2word):
  eos_id = int(word2id['<eos>'])
  pad_id = int(word2id['<pad>'])
  stop_words = [eos_id, pad_id]

  res = []
  for sent_batch in list_list_tokens:
    for sent in sent_batch:
      tmp = []
      for token in sent:
        if token in stop_words: break
        tmp.append(id2word[int(token)])
      res.append(tmp)
  return res
  
def load_json(path):
    with open(path) as f:
        data = json.load(f)
    return data

def load_lexicon(lexicon_path):
    # collect style words from existing set of style features and weights
    style_features_and_weights = load_json(lexicon_path)
    return set(map(lambda x: x[0], style_features_and_weights['binary sentiment']))
  
def mark_style_words(texts, style_tokens):
    '''
    Mask or remove style words (based on a set of style tokens) from input texts.
    Parameters
    ----------
    texts : list
        String inputs
    style_tokens : set
        Style tokens
    mask_style : boolean
        Set to False to remove style tokens, True to replace with placeholder
        
    Returns
    -------
    edited_texts : list
        Texts with style tokens masked or removed
    '''
    
    edited_texts = []
    
    for tokens in texts:
        edited_tokens = []
        
        for token in tokens:
            if token.lower() in style_tokens:
                edited_tokens.append(token)
            
        edited_texts.append(' '.join(edited_tokens))

    return edited_texts

def load_word2vec_model(path):
    model = Word2Vec.load(path)
    model.init_sims(replace=True) # normalize vectors
    return model

def calculate_wmd_scores(references, candidates, wmd_model):
    '''
    Calculate Word Mover's Distance for each (reference, candidate)
    pair in a list of reference texts and candidate texts.
    
    The lower the distance, the more similar the texts are.
    Parameters
    ----------
    references : list
        Input texts
    candidates : list
        Output texts (e.g. from a style transfer model)
    wmd_model : gensim.models.word2vec.Word2Vec
        Trained Word2Vec model
        
    Returns
    -------
    wmd_scores : list
        WMD scores for all pairs 
    '''
    
    wmd_scores = []

    for i in range(len(references)):
        wmd = wmd_model.wv.wmdistance(references[i], candidates[i])
        wmd_scores.append(wmd)

    return wmd_scores
  

# Get train and test data

In [0]:

data_path = "/content/drive/My Drive/StyleTransfer/evaluation"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
fast_text_model_path = os.path.join(data_path, 'fast_text_train_all_trainset_acc_99_7.pth')
style_trasformer_model_path = os.path.join(data_path, 'model_dec_8_yelp_freq3_len25_dim300.pth')

# manual params
max_len = 25
min_freq = 3
embed_dim = 300
batch_size = 32

# define tokenizer
en = English()

def tokenize_spacy_with_html_parsing(sentence):
    sentence = BeautifulSoup(sentence, 'html.parser').get_text()
    return [tok.text for tok in en.tokenizer(sentence)]

# eos_token - end of sentence token, batch_first - first dimension is batch, fix_length - can be also None
TEXT = data.Field(sequential=True, tokenize=tokenize_spacy_with_html_parsing,
                  preprocessing=None, lower=True,
                  eos_token='<eos>', batch_first=True, fix_length=max_len)
LABEL = data.LabelField()

fields_list = [('Unnamed: 0', None),
               ('text', TEXT),
               ('label', LABEL)]

yelp_train_path = '/content/drive/My Drive/StyleTransfer/YELP/yelp_train.csv'
yelp_test_path = '/content/drive/My Drive/StyleTransfer/data_eval_raw/yelp_test_200.csv'

train_data = TabularDataset(
    path=yelp_train_path,
    format='csv',
    skip_header=True,
    fields=fields_list)

test_data = TabularDataset(
    path=yelp_test_path,
    format='csv',
    skip_header=True,
    fields=fields_list)

TEXT.build_vocab(train_data, test_data, min_freq=min_freq, vectors=GloVe(name='6B', dim=embed_dim))
print("Loaded Glove embedding, Vector size of Text Vocabulary: " + str(TEXT.vocab.vectors.size()))

LABEL.build_vocab(train_data)

word_embeddings = TEXT.vocab.vectors
print("Length of Text Vocabulary: " + str(len(TEXT.vocab)))

train_iter, test_iter = data.BucketIterator.splits((train_data, test_data),
                                                   batch_sizes=(batch_size, batch_size),
                                                   sort_key=lambda x: len(x.text), repeat=False, shuffle=True,
                                                   device=device)
# Disable shuffle
test_iter.shuffle = False


Loaded Glove embedding, Vector size of Text Vocabulary: torch.Size([16177, 300])
Length of Text Vocabulary: 16177


# Get predictions from style transfer model

In [0]:
style_trasformer_model = StyleTransformer(src_vocab=len(TEXT.vocab.vectors), tgt_vocab=len(TEXT.vocab.vectors),N=8, h=6, d_model=300, max_len=25)

style_trasformer_model.load_state_dict(torch.load(style_trasformer_model_path))
style_trasformer_model = style_trasformer_model.to(device)

In [0]:
org_sents, gen_sents, gen_labels = get_predictions_from_style_transfer_model(style_gen_model=style_trasformer_model,
                                                                             test_dataloader=test_iter, TEXT=TEXT, device=device)

# Calc style strength metrices

Calculate two metrices for style strength:
- Accuracy
- EMD

In [0]:
target_labels_flat = np.concatenate(gen_labels)

def sig(x):
  eps = 1e-10
  return 1 / (1 + np.exp(-x))
  
acc, preds_for_gen_samples = predict_style(model_path=fast_text_model_path, vocab_size=len(TEXT.vocab), embed_dim=100,
                                           gen_data=gen_sents, gen_labels=gen_labels, device=device, raw=True)
print('Generated samples acc: {:.3f}'.format(acc))

preds_conf_flat = np.concatenate(preds_for_gen_samples)
dist_preds_gen = np.zeros((len(preds_conf_flat), 2))
                      
for n in range(len(preds_conf_flat)):
  sig_val = sig(preds_conf_flat[n])
  dist_preds_gen[n, :] = (1-sig_val, sig_val)
  
_, preds_for_org_samples = predict_style(model_path=fast_text_model_path, vocab_size=len(TEXT.vocab), embed_dim=100,
                                           gen_data=org_sents, gen_labels=gen_labels, device=device, raw=True)
preds_conf_flat = np.concatenate(preds_for_org_samples)
dist_preds_org = np.zeros((len(preds_conf_flat), 2))
                      
for n in range(len(preds_conf_flat)):
  sig_val = sig(preds_conf_flat[n])
  dist_preds_org[n, :] = (1-sig_val, sig_val)
  
  emd_list = []
for i in range(len(dist_preds_org)):
  emd_list.append(calculate_direction_corrected_emd(dist_preds_org[i], dist_preds_gen[i], target_labels_flat[i]))
print('EMD: mean:{:.3f}+-{:.3f}'.format(np.mean(emd_list), np.std(emd_list)))

Generated samples acc: 0.654
EMD: mean:0.198+-0.675


# Calc content metrices

Calculate two metrices for contet preservation:
- METEOR score
- WMD

In [0]:
calc_meteor_scores(org_sents, gen_sents, TEXT)

0:626
100:626
200:626
300:626
400:626
500:626
600:626


(0.888876854484517, 0.05941480962078296)

In [0]:

w2v_model_path = os.path.join(data_path, 'word2vec_unmasked')
lexicon_path = os.path.join(data_path, 'style_words_and_weights.json')
lexicon = load_lexicon(lexicon_path)
word2id = TEXT.vocab.stoi
id2word = {v: k for k, v in word2id.items()}

org_sents_text = ids2text(org_sents, id2word)
gen_sents_text = ids2text(gen_sents, id2word)

org_sents_stripped = mark_style_words(texts=org_sents_text, style_tokens=lexicon)
gen_sents_stripped = mark_style_words(texts=gen_sents_text, style_tokens=lexicon)

assert len(org_sents_stripped) == len(gen_sents_stripped)

wmd_model = load_word2vec_model(w2v_model_path)
wmd_scores_without_style = calculate_wmd_scores(org_sents_stripped, gen_sents_stripped, wmd_model)
print('WMD without style words: mean:{:.3f}+-{:.3f}'.format(np.ma.masked_invalid(wmd_scores_without_style).mean(),
                                                       np.ma.masked_invalid(wmd_scores_without_style).std()))

wmd_scores_reg = calculate_wmd_scores(org_sents_text, gen_sents_text, wmd_model)
print('WMD: mean:{:.3f}+-{:.3f}'.format(np.ma.masked_invalid(wmd_scores_reg).mean(),
                                   np.ma.masked_invalid(wmd_scores_reg).std()))

  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL


WMD without style words: mean:0.420+-0.340
WMD: mean:0.222+-0.167
