In [45]:
import random
import copy

import numpy as np
import torch
from torch.utils.data.dataloader import default_collate

from settings import EXPERIMENTS_DIR
from experiment import Experiment
from utils import to_device, load_weights, load_embeddings, create_embeddings_matrix
from vocab import Vocab
from train import create_model
from preprocess import load_dataset, create_dataset_reader

In [46]:
exp_id = 'train.6yk107k2'

# Load everything

In [47]:
exp = Experiment.load(EXPERIMENTS_DIR, exp_id)

In [48]:
exp.config

TrainConfig(model_class=<class 'models.Seq2SeqMeaningStyle'>, preprocess_exp_id='preprocess.nbsquesc', embedding_size=300, hidden_size=256, dropout=0.2, scheduled_sampling_ratio=0.5, pretrained_embeddings=True, trainable_embeddings=False, meaning_size=128, style_size=128, lr=0.001, weight_decay=1e-07, grad_clipping=5, D_num_iterations=10, D_loss_multiplier=1, P_loss_multiplier=10, P_bow_loss_multiplier=1, use_discriminator=True, use_predictor=False, use_predictor_bow=True, use_motivator=True, use_gauss=False, num_epochs=500, batch_size=1024, best_loss='loss')

In [49]:
preprocess_exp = Experiment.load(EXPERIMENTS_DIR, exp.config.preprocess_exp_id)
dataset_train, dataset_val, dataset_test, vocab, style_vocab, W_emb = load_dataset(preprocess_exp)

Dataset: 453655, val: 10000, test: 10000
Vocab: 9419, style vocab: 2
W_emb: (9419, 300)


In [50]:
dataset_reader = create_dataset_reader(preprocess_exp.config)

In [51]:
model = create_model(exp.config, vocab, style_vocab, dataset_train.max_len, W_emb)

In [52]:
load_weights(model, exp.experiment_dir.joinpath('best.th'))

In [53]:
model = model.eval()

## Predict

In [54]:
def create_inputs(instances, style=None):
    """Turn raw sentences or dataset instances into a collated model input batch.

    Parameters
    ----------
    instances : str | dict | list
        Either a single item or a list of items. Items that are dicts are
        expected to already carry ``'sentence_enc'`` and ``'style_enc'``
        (as the entries of ``dataset_val.instances`` do). Any other item is
        treated as a raw sentence and run through ``dataset_reader``
        (``clean_sentence`` -> ``spacy`` -> ``preprocess_sentence``) and then
        ``dataset_train.encode_instance``. The kind of the *first* item decides
        which path the whole batch takes.
    style : optional
        Style label attached to raw sentences. Defaults to the first key of
        ``style_vocab.token2id`` (the previous hard-coded behaviour). Ignored
        when the items are already-encoded dicts.

    Returns
    -------
    dict
        ``{'sentence': ..., 'style': ...}`` produced by ``default_collate`` and
        then passed through ``to_device``.

    Raises
    ------
    ValueError
        If ``instances`` is an empty list, or ``style`` is not a key of
        ``style_vocab.token2id``.
    """
    # A lone sentence / instance is handled as a batch of one.
    if not isinstance(instances, list):
        instances = [instances]

    # Fail early with a readable message instead of an IndexError below.
    if not instances:
        raise ValueError('create_inputs() needs at least one instance')

    if not isinstance(instances[0], dict):
        known_styles = list(style_vocab.token2id.keys())
        if style is None:
            style = known_styles[0]
        elif style not in known_styles:
            raise ValueError(
                'Unknown style {!r}; expected one of {!r}'.format(style, known_styles)
            )

        encoded = []
        for raw_sentence in instances:
            tokens = dataset_reader.preprocess_sentence(
                dataset_reader.spacy(dataset_reader.clean_sentence(raw_sentence))
            )
            inst = {
                'sentence': tokens,
                'style': style,
            }
            # encode_instance returns the encoded fields; merge them into the instance.
            inst.update(dataset_train.encode_instance(inst))
            encoded.append(inst)
        instances = encoded

    # The model only consumes the encoded fields, under the plain key names.
    batch = [
        {
            'sentence': inst['sentence_enc'],
            'style': inst['style_enc'],
        }
        for inst in instances
    ]

    batch = default_collate(batch)
    batch = to_device(batch)

    return batch

In [55]:
def get_sentences(outputs):
    """Convert predicted token ids into lists of token strings.

    Reads ``outputs["predictions"]`` (a NumPy array, or an object exposing
    ``.detach().cpu().numpy()``), walks it row by row, drops everything from
    the first ``Vocab.END_TOKEN`` id onwards, and maps the remaining ids
    through ``vocab.id2token``.

    Returns a list with one list of tokens per row.
    """
    id_rows = outputs["predictions"]
    stop_id = vocab[Vocab.END_TOKEN]

    # Anything that is not already a NumPy array is converted via the tensor API.
    if not isinstance(id_rows, np.ndarray):
        id_rows = id_rows.detach().cpu().numpy()

    def _tokens_before_stop(row):
        ids = list(row)
        # Keep the whole row when no end symbol was produced.
        cut = ids.index(stop_id) if stop_id in ids else len(ids)
        return [vocab.id2token[token_id] for token_id in ids[:cut]]

    return [_tokens_before_stop(row) for row in id_rows]

In [56]:
dataset_val.instances[1]

{'sentence': ['they', 'are', 'really', 'good', 'people', '.'],
 'style': 'positive',
 'sentence_enc': array([136,  55,  29, 380, 368,   8,   2,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0]),
 'style_enc': 1,
 'meaning_embedding': array([-9.50000435e-03, -1.41124994e-01, -4.41750027e-02,  1.90999992e-02,
        -1.24025002e-01, -1.08074993e-01, -1.93949997e-01,  8.38750005e-02,
         1.44350007e-01, -2.35500000e-02, -4.74999845e-03,  1.44600004e-01,
        -3.93499993e-02, -1.70499980e-02, -4.63999994e-02,  3.37999985e-02,
        -4.23750058e-02, -4.57499996e-02, -7.32000023e-02,  6.25500008e-02,
        -4.92749996e-02, -7.62749985e-02,  8.68999958e-02, -4.80250008e-02,
        -1.40974998e-01, -5.85000031e-03,  5.96000031e-02,  1.64500009e-02,
        -7.30000250e-03,  1.92149997e-01, -8.38999972e-02,  8.42499826e-03,
         1.69999897e-03, -2.49949992e-01,  1.04975000e-01, -9.66750011e-02,
         6.30000141e-03, -4.97250035e-02, -1.64999999e-02, 

In [57]:
sentence =  ' '.join(dataset_val.instances[1]['sentence'])

In [58]:
sentence

'they are really good people .'

In [59]:
inputs = create_inputs(sentence)
inputs

{'sentence': tensor([[136,  55,  29, 380, 368,   8,   2,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0]], device='cuda:0'),
 'style': tensor([0], device='cuda:0')}

In [60]:
outputs = model(inputs)

In [61]:
sentences = get_sentences(outputs)

In [62]:
' '.join(sentences[0])

'they are really good people .'

### Swap style

In [63]:
possible_styles = list(style_vocab.token2id.keys()) #['negative', 'positive']

In [64]:
possible_styles

['negative', 'positive']

In [65]:
sentences0 = [s for s in dataset_val.instances if s['style'] == possible_styles[0]]
sentences1 = [s for s in dataset_val.instances if s['style'] == possible_styles[1]]

In [66]:
for i in np.random.choice(np.arange(len(sentences0)), 5):
    print(i, ' '.join(sentences0[i]['sentence']))

126 the gyro meat was greasy and dried out at the same time .
1913 after over number years with this institution i believe we deserve more !
2372 that cost $ number .
1410 this place however , does not !
737 being from louisiana , i am a huge smoothie king fan .


In [67]:
for i in np.random.choice(np.arange(len(sentences1)), 5):
    print(i, ' '.join(sentences1[i]['sentence']))

664 everything about this place was clean and professional .
2455 everything we have had here has been good !
1393 bonus is a friendly cat hanging around the patio .
4442 their food is what i crave .
1364 good place for lunch if you are looking for a burger .


#### Swap

In [82]:
# Fixed indices into sentences0 / sentences1 so the swap example below is repeatable;
# the trailing comments show how to draw a random index instead.
target0 = 2346 # np.random.choice(np.arange(len(sentences0)))
target1 = 1364 # np.random.choice(np.arange(len(sentences1)))

In [83]:
print(' '.join(sentences0[target0]['sentence']))

not a place for adults .


In [84]:
print(' '.join(sentences1[target1]['sentence']))

good place for lunch if you are looking for a burger .


In [85]:
inputs = create_inputs([
    sentences0[target0],
    sentences1[target1],
])
inputs

{'sentence': tensor([[ 32,  24,  83,  52, 762,   8,   2,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0],
         [380,  83,  52, 940, 199,  56,  55, 671,  52,  24, 244,   8,   2,   0,
            0,   0,   0,   0,   0,   0]], device='cuda:0'),
 'style': tensor([0, 1], device='cuda:0')}

In [86]:
z_hidden = model.encode(inputs)
style_z_hidden = z_hidden['style_hidden'].clone()
meaning_z_hidden = z_hidden['meaning_hidden'].clone()

In [87]:
z_hidden['style_hidden'].shape

torch.Size([2, 128])

In [88]:
z_hidden['meaning_hidden'].shape

torch.Size([2, 128])

In [89]:
original_decoded = model.decode(z_hidden)

In [90]:
original_sentences = get_sentences(original_decoded)

In [91]:
print(' '.join(original_sentences[0]))
print(' '.join(original_sentences[1]))

not a place for adults .
good place for lunch if you are looking for a burger .


In [108]:
# z_hidden_swapped = {
#     'meaning_hidden': torch.stack([
#         meaning_z_hidden[1].clone(),
#         meaning_z_hidden[0].clone(),        
#     ], dim=0),
#     'style_hidden': torch.stack([
#         style_z_hidden[0].clone(),
#         style_z_hidden[1].clone(),        
#     ], dim=0),
#     'decoder_hidden': torch.stack([
#         z_hidden['decoder_hidden'][0].clone(),
#         z_hidden['decoder_hidden'][1].clone(),        
#     ], dim=0)
# }

In [112]:
torch.rand_like(meaning_z_hidden[0])

<module 'torch' from '/home/oleg/anaconda3/lib/python3.7/site-packages/torch/__init__.py'>

In [191]:
# Build the latent batch for the swap experiment.
#   - meaning_hidden: rows exchanged (row 0 gets sentence 1's meaning and vice versa)
#   - style_hidden:   rows kept in place but scaled by 4
#     NOTE(review): the factor 4 is an unexplained magic number — presumably picked by
#     hand to make the style signal stronger; confirm and promote to a named constant.
#   - decoder_hidden: rows kept in place, unchanged
z_hidden_swapped = dict(
    meaning_hidden=torch.stack(
        (meaning_z_hidden[1].clone(), meaning_z_hidden[0].clone()),
        dim=0,
    ),
    style_hidden=torch.stack(
        (style_z_hidden[0].clone() * 4, style_z_hidden[1].clone() * 4),
        dim=0,
    ),
    decoder_hidden=torch.stack(
        (z_hidden['decoder_hidden'][0].clone(), z_hidden['decoder_hidden'][1].clone()),
        dim=0,
    ),
)

In [192]:
swaped_decoded = model.decode(z_hidden_swapped)

In [193]:
swaped_sentences = get_sentences(swaped_decoded)

In [194]:
print(' '.join(original_sentences[0]))
print(' '.join(original_sentences[1]))
print()
print(' '.join(swaped_sentences[0]))
print(' '.join(swaped_sentences[1]))

not a place for adults .
good place for lunch if you are looking for a burger .

not provide place place craft .
kind recommend for lunch if you are prepared for a burger .


In [111]:
print(' '.join(original_sentences[0]))
print(' '.join(original_sentences[1]))
print()
print(' '.join(swaped_sentences[0]))
print(' '.join(swaped_sentences[1]))

not a place for adults .
good place for lunch if you are looking for a burger .

not a place for adults .
good place for lunch if you are looking for a burger .
