A script to train a transformer-based encoder-decoder for style transfer.

Based on Harvard implementation of Transformer: http://nlp.seas.harvard.edu/2018/04/03/attention.html

Also based on: https://arxiv.org/pdf/1711.06861.pdf

# Drive

In [1]:
# Mount Google Drive at /content/drive; the code and data paths used later in
# this notebook all live under that mount point (Colab-only step).
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
import os
import sys
# Add the project directory on Drive to the import path so that the local
# modules imported further down can be found.
CODE_PATH = '/content/drive/My Drive/NLP/final/TextualStyleTransfer/'
sys.path.append(CODE_PATH)

# Init

In [3]:
# Install third-party packages into the runtime (needed on a fresh Colab VM).
# NOTE(review): versions are not pinned, so a re-run may pick up newer releases.
!pip install pytorch-transformers
!pip install torchtext

Collecting pytorch-transformers
[?25l  Downloading https://files.pythonhosted.org/packages/a3/b7/d3d18008a67e0b968d1ab93ad444fc05699403fa662f634b2f2c318a508b/pytorch_transformers-1.2.0-py3-none-any.whl (176kB)
[K     |█▉                              | 10kB 16.9MB/s eta 0:00:01[K     |███▊                            | 20kB 1.8MB/s eta 0:00:01[K     |█████▋                          | 30kB 2.6MB/s eta 0:00:01[K     |███████▍                        | 40kB 1.7MB/s eta 0:00:01[K     |█████████▎                      | 51kB 2.1MB/s eta 0:00:01[K     |███████████▏                    | 61kB 2.5MB/s eta 0:00:01[K     |█████████████                   | 71kB 2.9MB/s eta 0:00:01[K     |██████████████▉                 | 81kB 3.3MB/s eta 0:00:01[K     |████████████████▊               | 92kB 3.7MB/s eta 0:00:01[K     |██████████████████▋             | 102kB 2.8MB/s eta 0:00:01[K     |████████████████████▍           | 112kB 2.8MB/s eta 0:00:01[K     |██████████████████████▎     

In [0]:
import os
import sys
import numpy as np
import logging

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, RandomSampler

from data import *
from train import *
from evaluate import *
from utils import *

# Parameters

In [5]:
class Params(object):
    """Hyper-parameters and paths for the experiment, held as class attributes."""

    # Logging
    # Free text to describe the experiment
    COMMENT = ''
    VERBOSE = True
    EXP_NAME = "gal_exp"
    DATA_PATH = "/content/drive/My Drive/NLP/"
    MODELS_PATH = "/content/drive/My Drive/NLP/final/"
    PRINT_INTERVAL = 10

    # TODO: for local use
    # DATA_PATH = MODELS_PATH = os.path.abspath(__file__+'/../')
    MODELS_LOAD_PATH = "/content/drive/My Drive/NLP/final/gal_exp"

    # Data
    DATASET_NAME = 'IMDB'
    # Maximal number of batches to evaluate on (despite the "BATCH_SIZE" name)
    TEST_MAX_BATCH_SIZE = 300
    # Min freq for word in dataset to include in vocab
    VOCAB_MIN_FREQ = 3
    # Whether to use GloVe embedding - if True set H_DIM to 300
    VOCAB_USE_GLOVE = True
    TRAIN_BATCH_SIZE = 32
    TEST_BATCH_SIZE = 32
    # maximum length of allowed sentence - can be also None
    MAX_LEN = 25

    # Transformer model
    N_LAYERS = 8
    N_LAYERS_CLS = 4
    H_DIM = 300
    N_ATTN_HEAD = 5
    FC_DIM = 2048
    DO_RATE = 0.1

    # Classification model
    N_STYLES = 2
    DO_RATE_CLS = 0.1
    TRANS_CLS = True
    TRANS_DES = True
    # Accuracy (in percent) the classifier must reach before the training loop
    # leaves the classifier phase
    CLS_ACC_BAR = 95.0
    NEG_CLS_ACC_BAR = 95.0

    # Training
    N_EPOCHS = 20
    PATIENCE = 3
    ENC_LR = 0
    DEC_LR = 3e-4
    CLS_LR = 3e-4
    TRANS_STEPS_RATIO = 0.1
    TRUE_STEPS_RATIO = 0.9
    # Number of steps between progress logs / phase checks in the training loop
    PERIOD_STEPS = 100
    ENC_WARMUP_RATIO = 0.1
    DEC_WARMUP_RATIO = 0.2
    CLS_WARMUP_RATIO = 0.2
    TRUE_REC_LAMBDA = 1e-2
    TRUE_CLS_LAMBDA = 0.0
    # Loss weights for the negated-label (style transfer) training step
    NEG_CYC_REC_LAMBDA = 0.5
    NEG_REC_LAMBDA = 0.0
    NEG_DES_LAMBDA = 0.0
    NEG_CLS_LAMBDA = 0.5
    TRAIN_ON_CLS_LOSS = False
    REC_ACC_BAR = 80.0
    # Accuracy (in percent) the discriminator must reach before the training
    # loop leaves the discriminator phase
    DES_ACC_BAR = 60.0


    # First CUDA device when available, otherwise the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Instantiate the configuration, create the experiment logger from it and
# print the parameter values (see the cell output below).
params = Params()
logger = create_logger(params)
pprint_params(params)

INFO - 09/12/19 19:38:26 - 0:00:00 - Params for experiment:
INFO - 09/12/19 19:38:26 - 0:00:01 - CLS_ACC_BAR = 95.0
INFO - 09/12/19 19:38:26 - 0:00:01 - CLS_LR = 0.0003
INFO - 09/12/19 19:38:26 - 0:00:01 - CLS_WARMUP_RATIO = 0.2
INFO - 09/12/19 19:38:26 - 0:00:01 - COMMENT = ''
INFO - 09/12/19 19:38:26 - 0:00:01 - DATASET_NAME = 'IMDB'
INFO - 09/12/19 19:38:26 - 0:00:01 - DATA_PATH = '/content/drive/My Drive/NLP/'
INFO - 09/12/19 19:38:26 - 0:00:01 - DEC_LR = 0.0003
INFO - 09/12/19 19:38:26 - 0:00:01 - DEC_WARMUP_RATIO = 0.2
INFO - 09/12/19 19:38:26 - 0:00:01 - DES_ACC_BAR = 60.0
INFO - 09/12/19 19:38:26 - 0:00:01 - DO_RATE = 0.1
INFO - 09/12/19 19:38:26 - 0:00:01 - DO_RATE_CLS = 0.1
INFO - 09/12/19 19:38:26 - 0:00:01 - ENC_LR = 0
INFO - 09/12/19 19:38:26 - 0:00:01 - ENC_WARMUP_RATIO = 0.1
INFO - 09/12/19 19:38:26 - 0:00:01 - EXP_NAME = 'gal_exp'
INFO - 09/12/19 19:38:26 - 0:00:01 - FC_DIM = 2048
INFO - 09/12/19 19:38:26 - 0:00:01 - H_DIM = 300
INFO - 09/12/19 19:38:26 - 0:00:01 - MAX_

In [0]:
# import pandas as pd

# def create_mini_csv(csv_path, new_csv_path, size):
#     data = pd.read_csv(csv_path)
#     data = data.loc[:size]
#     print(len(data))
#     data.to_csv(new_csv_path.format(size))

# create_mini_csv("/content/drive/My Drive/NLP/YELP/yelp_train.csv", "/content/drive/My Drive/NLP/YELP/yelp_train_{}.csv", 100000)

100001


# Data

In [0]:
# import torchtext
# from torchtext.data import Field, LabelField, TabularDataset
# from spacy.lang.en import English
# import os

# TEXT, word_embeddings, train_iter, test_iter = load_dataset_from_csv(params=params, device=params.device)
# train_dataset_len = len(train_iter.dataset)
# print('Train dataset len: {} Test dataset len: {}'.format(len(train_iter.dataset), len(test_iter.dataset)))

import torchtext
from torchtext.data import Field, LabelField, TabularDataset
from spacy.lang.en import English

# Blank English spaCy pipeline; only its tokenizer is used below.
en = English()

def tokenize(sentence):
    """Split `sentence` with the spaCy English tokenizer and return the token strings."""
    token_texts = []
    for token in en.tokenizer(sentence):
        token_texts.append(token.text)
    return token_texts

# Text field: lower-cased, spaCy-tokenized, terminated with <eos> and padded /
# truncated to MAX_LEN tokens; batches come out as (batch, seq).
TEXT = Field(sequential=True, tokenize=tokenize, lower=True, eos_token='<eos>', batch_first=True, fix_length=params.MAX_LEN)
LABEL = LabelField()

# CSV columns in file order. The first column is ignored
# (presumably the index column written by pandas - TODO confirm).
fields_list = [('Unnamed: 0', None),
                ('text', TEXT),
                ('label', LABEL)]

# Resolve the CSV files relative to the configured data root instead of
# repeating the absolute Drive path.
# NOTE(review): the files are the YELP splits although params.DATASET_NAME is 'IMDB'.
yelp_dir = os.path.join(params.DATA_PATH, "YELP")

train_dataset = TabularDataset(
                            path=os.path.join(yelp_dir, "yelp_train.csv"),  # path of the training CSV file
                            format='csv',
                            skip_header=True,
                            fields=fields_list)

test_dataset = TabularDataset(
                            path=os.path.join(yelp_dir, "yelp_test_200.csv"),  # path of the test CSV file
                            format='csv',
                            skip_header=True,
                            fields=fields_list)

if params.VOCAB_USE_GLOVE:
    TEXT.build_vocab(train_dataset, test_dataset, min_freq=params.VOCAB_MIN_FREQ, vectors=GloVe(name='6B', dim=params.H_DIM))
    logging.info("Loaded Glove embedding, Vector size of Text Vocabulary: " + str(TEXT.vocab.vectors.size()))

else:
    TEXT.build_vocab(train_dataset, test_dataset, min_freq=params.VOCAB_MIN_FREQ)
LABEL.build_vocab(train_dataset)

# None when the vocabulary was built without pre-trained vectors.
word_embeddings = TEXT.vocab.vectors
logging.info("Length of Text Vocabulary: " + str(len(TEXT.vocab)))

# The test iterator now honours TEST_BATCH_SIZE (it used TRAIN_BATCH_SIZE for
# both splits before).
train_iter, test_iter = data.BucketIterator.splits((train_dataset, test_dataset),
                                                    batch_sizes=(params.TRAIN_BATCH_SIZE, params.TEST_BATCH_SIZE),
                                                    sort_key=lambda x: len(x.text), repeat=False, shuffle=True,
                                                    device=params.device)
# Disable shuffle so evaluation always sees the test set in the same order
test_iter.shuffle = False

train_dataset_len = len(train_iter.dataset)
print('Train dataset len: {} Test dataset len: {}'.format(len(train_iter.dataset), len(test_iter.dataset)))


In [0]:
from transformer_model import *
import torch.nn.functional as F

class ArgMaxEmbed(torch.autograd.Function):
    """
    Embed the arg-max (or a sampled) token of a logits tensor while still
    passing a gradient back to the logits.

    Forward picks one index along the last dimension of `inputs` and returns
    `embed(index)`. Backward writes, for every position, the sum of the incoming
    gradient over the embedding dimension into the slot of the chosen index and
    leaves every other logit with zero gradient.
    """

    @staticmethod
    def forward(ctx, inputs, embed, sample):
        if sample:
            # Draw one index per position from the categorical given by the logits.
            token_ids = torch.distributions.Categorical(logits=inputs).sample()
        else:
            token_ids = inputs.argmax(dim=-1)
        ctx.save_for_backward(token_ids)
        # Remember what the gradient w.r.t. the logits has to look like.
        ctx._logits_meta = (inputs.shape, inputs.dtype, inputs.device)
        return embed(token_ids)

    @staticmethod
    def backward(ctx, grad_output):
        (token_ids,) = ctx.saved_tensors
        shape, dtype, device = ctx._logits_meta
        grad_logits = torch.zeros(shape, device=device, dtype=dtype)
        # One scalar per position: the gradient summed over the embedding dim.
        per_position_grad = grad_output.sum(-1, keepdim=True)
        grad_logits.scatter_(-1, token_ids[..., None], per_position_grad)
        # No gradient for the `embed` callable or the `sample` flag.
        return grad_logits, None, None

class StyleTransformer(nn.Module):
    """
    Transformer encoder stack conditioned on a style label.

    The token embeddings (plus positional encoding) are summed with a learned
    style embedding, passed through the encoder and projected to vocabulary
    logits, one distribution per input position.

    Args:
        src_vocab: size of the input vocabulary.
        tgt_vocab: size of the output vocabulary (width of the returned logits).
        N: number of encoder layers.
        d_model: embedding / hidden width.
        d_ff: width of the position-wise feed-forward layer.
        h: number of attention heads.
        n_styles: number of distinct style labels.
        dropout: dropout rate handed to the sub-modules.
        max_len: maximum sequence length handed to the positional encoding.
    """
    def __init__(self, src_vocab, tgt_vocab, N=6,
                    d_model=512, d_ff=2048, h=8, n_styles=2, dropout=0.1, max_len=128):
        super().__init__()
        c = copy.deepcopy
        attn = MultiHeadedAttention(h, d_model)
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)

        # The creation order below is kept as is: it determines the parameter
        # order and the random numbers each layer consumes at initialisation.
        self.src_embed = Embeddings(d_model, src_vocab)
        self.argmax = ArgMaxEmbed.apply
        self.encoder = BasicEncoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)
        self.position = PositionalEncoding(d_model, dropout, max_len)
        self.style_embed = nn.Embedding(n_styles, d_model)
        self.generator = nn.Linear(d_model, tgt_vocab)

        # Initialize parameters with Glorot / fan_avg.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode_style(self, style_labels):
        """
        Embed style labels so they broadcast over the sequence dimension.

        A batch of labels of shape (batch,) yields (batch, 1, d_model); a single
        scalar label yields (1, 1, d_model).
        """
        style_embedding = self.style_embed(style_labels)
        if style_embedding.ndimension() == 1:
            # Scalar label: add both the batch and the sequence dimension.
            return style_embedding.view(1, 1, -1)
        return style_embedding.unsqueeze(1)

    def forward(self, src, src_mask, style, argmax=False):
        """
        Produce vocabulary logits for `src` rendered in the requested `style`.

        Args:
            src: token ids, or (when `argmax` is True) per-position logits that
                are turned into embeddings through ArgMaxEmbed.
            src_mask: attention mask handed to the encoder.
            style: style label(s) for the batch.
            argmax: whether `src` holds logits instead of token ids.

        Returns:
            The generator output for every input position.
        """
        style = self.encode_style(style)
        if argmax:
            src = self.argmax(src, self.src_embed, False)
        else:
            src = self.src_embed(src)
        src = self.position(src)
        # The style vector is added after the positional encoding (and its dropout).
        x = src + style
        enc_out = self.encoder(x, src_mask)
        return self.generator(enc_out)

class MaskedMean(nn.Module):
    """
    Mean of a 3D tensor over its sequence dimension, counting only unmasked positions.

    Args:
        normalize: when True (default) the mean vector of every batch element
            is L2-normalised before it is returned.
    """

    def __init__(self, normalize=True):
        # Initialise the Module machinery before assigning any attribute.
        super().__init__()
        self.normalize = normalize

    def forward(self, x, mask):
        """
        Args:
            x: tensor of shape (batch, seq, embed).
            mask: tensor of shape (batch, 1, seq); non-zero marks positions to keep.

        Returns:
            Tensor of shape (batch, embed). Rows whose mask is entirely zero
            come out as zeros instead of NaN.
        """
        batch_size, _, embed_size = x.size()
        mask_expanded = mask.transpose(-1, -2).expand(-1, -1, embed_size).float()
        masked_sum = torch.sum(x * mask_expanded, 1)
        # Clamp the count so a fully masked row divides by 1 rather than 0.
        n_valid = mask.sum(-1).float().view(batch_size, 1).clamp(min=1.0)
        embed = torch.div(masked_sum, n_valid)
        if self.normalize:
            return F.normalize(embed, p=2, dim=1)
        return embed

class TransformerClassifier(nn.Module):
    """
    Transformer for style classification.

    Tokens are embedded, position-encoded and run through an encoder stack;
    every position is projected to `output_size` scores and the scores are
    averaged over the unmasked positions.

    Args:
        output_size: number of classes (width of the returned score vector).
        input_size: vocabulary size of the input tokens.
        N: number of encoder layers.
        d_model: embedding / hidden width.
        d_ff: width of the position-wise feed-forward layer.
        h: number of attention heads.
        dropout: dropout rate handed to the sub-modules.
        max_len: maximum sequence length handed to the positional encoding.
    """
    def __init__(self, output_size, input_size, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1, max_len=128):
        super().__init__()
        # self.src_embed = nn.Linear(input_size, d_model)
        self.src_embed = Embeddings(d_model, input_size)
        # Lets forward() accept logits instead of token ids (see `argmax` there).
        self.argmax = ArgMaxEmbed.apply
        c = copy.deepcopy
        attn = MultiHeadedAttention(h, d_model)
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.position = PositionalEncoding(d_model, dropout, max_len)
        self.encoder = BasicEncoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)
        self.generator = nn.Linear(d_model, output_size)
        # Built with the default normalize=True, so forward() returns
        # L2-normalised score vectors.
        self.masked_mean = MaskedMean()

        # Glorot / Xavier initialisation for every weight matrix.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, src_mask, argmax=False):
        """
        Score a batch of sentences.

        Args:
            src: token ids, or (when `argmax` is True) per-position logits that
                are turned into embeddings through ArgMaxEmbed.
            src_mask: mask used both for attention and for the masked mean.
            argmax: whether `src` holds logits instead of token ids.

        Returns:
            One score vector of width `output_size` per batch element.
        """
        if argmax:
            src = self.argmax(src, self.src_embed, False)
        else:
            src = self.src_embed(src)
        src = self.position(src)
        out = self.encoder(src, src_mask)
        # Project every position to class scores, then pool over the sequence.
        out = self.generator(out)
        out = self.masked_mean(out, src_mask)
        return out


def init_models(vocab_size, params):
    """
    Build the style-transfer model and the style classifier from `params`.

    The classifier is a TransformerClassifier when `params.TRANS_CLS` is set and
    a Descriminator otherwise.

    Returns:
        Tuple `(model_trans, model_cls)`.
    """
    model_trans = StyleTransformer(
        src_vocab=vocab_size,
        tgt_vocab=vocab_size,
        N=params.N_LAYERS,
        d_model=params.H_DIM,
        d_ff=params.FC_DIM,
        h=params.N_ATTN_HEAD,
        n_styles=params.N_STYLES,
        dropout=params.DO_RATE,
        max_len=params.MAX_LEN,
    )

    if not params.TRANS_CLS:
        model_cls = Descriminator(
            input_size=vocab_size,
            output_size=params.N_STYLES,
            hidden_size=params.H_DIM,
            embedding_size=params.H_DIM,
            drop_rate=params.DO_RATE_CLS,
            num_layers=params.N_LAYERS_CLS,
            num_layers_for_output=4,
        )
        return model_trans, model_cls

    model_cls = TransformerClassifier(
        output_size=params.N_STYLES,
        N=params.N_LAYERS_CLS,
        d_model=params.H_DIM,
        d_ff=params.FC_DIM,
        h=params.N_ATTN_HEAD,
        dropout=params.DO_RATE_CLS,
        input_size=vocab_size,
        max_len=params.MAX_LEN,
    )
    return model_trans, model_cls


# Models

In [0]:
def load_pretrained_embedding_to_encoder(src_embed, embedding):
    """
    Overwrite the lookup table of `src_embed` with pre-trained vectors.

    Args:
        src_embed: embedding module exposing its table as `lut.weight`.
        embedding: tensor of pre-trained vectors copied into that table in place.
    """
    lookup_table = src_embed.lut
    lookup_table.weight.data.copy_(embedding)
    print('Loaded pre-calculated Glove embedding')

# Clear CUDA memory if needed
# TODO: local use
# torch.cuda.empty_cache()

### Init models ###

vocab_size = len(TEXT.vocab)
model_dec, model_cls = init_models(vocab_size, params)
# model_des = TransformerClassifier(output_size=params.N_STYLES+1, N=params.N_LAYERS_CLS, d_model=params.H_DIM,
#                                 d_ff=params.FC_DIM, h=params.N_ATTN_HEAD, dropout=params.DO_RATE_CLS,
#                                 input_size=vocab_size, max_len=params.MAX_LEN)

# Copy the pre-trained vectors only when they exist (the vocabulary was built
# with GloVe) and their width matches the model width. Previously this was
# keyed on H_DIM == 300 alone and crashed when no vectors had been loaded.
if word_embeddings is not None and word_embeddings.size(-1) == params.H_DIM:
    load_pretrained_embedding_to_encoder(model_dec.src_embed, word_embeddings)
    load_pretrained_embedding_to_encoder(model_cls.src_embed, word_embeddings)
    # load_pretrained_embedding_to_encoder(model_des.src_embed, word_embeddings)
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of `model`."""
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable

# Report the size of each model (trainable parameters only).
print(f'model_cls has {count_parameters(model_cls):,} trainable parameters')
print(f'model_dec has {count_parameters(model_dec):,} trainable parameters')
# print(f'model_des has {count_parameters(model_des):,} trainable parameters')


# Move both models to the configured device before the optimizers are created.
model_dec = model_dec.to(params.device)
model_cls = model_cls.to(params.device)
# model_des = model_des.to(params.device)

### Init optimizers ###
def get_warmup_steps_from_params(train_set_size, train_batch_size, n_epochs,
                                 enc_ratio, dec_ratio, cls_ratio):
    """
    Derive warm-up lengths (in optimizer steps) from the training schedule.

    The total step count is `n_epochs` times the number of full batches per
    epoch; each warm-up is that total multiplied by the matching ratio.
    `enc_ratio` is accepted but not used.

    Returns:
        Tuple `(warmup_dec_steps, warmup_cls_steps)`.
    """
    n_total_steps = n_epochs * (train_set_size // train_batch_size)
    warmup_dec_steps, warmup_cls_steps = (
        n_total_steps * ratio for ratio in (dec_ratio, cls_ratio)
    )

    message = "total_steps {}, dec_warmup {}, cls_warmup {}".format(
        n_total_steps, warmup_dec_steps, warmup_cls_steps)
    logging.info(message)
    return warmup_dec_steps, warmup_cls_steps


# Warm-up lengths derived from the dataset size and the configured ratios.
# NOTE(review): these values are computed and logged, but the get_std_opt calls
# below pass literal warm-ups (4000 / 3000) instead - confirm this is intended.
dec_warmup, cls_warmup = get_warmup_steps_from_params(train_dataset_len,
                                                                  params.TRAIN_BATCH_SIZE,
                                                                  params.N_EPOCHS,
                                                                  params.ENC_WARMUP_RATIO,
                                                                  params.DEC_WARMUP_RATIO,
                                                                  params.CLS_WARMUP_RATIO)
des_warmup = dec_warmup
# cls_opt = get_std_opt(model_cls,h_dim=params.H_DIM, lr=params.CLS_LR, warmup=cls_warmup)
# opt_cls = torch.optim.Adam(filter(lambda p: p.requires_grad, model_cls.parameters()),
#                            lr=params.CLS_LR, weight_decay=1e-4)
# opt_dec = torch.optim.Adam(filter(lambda p: p.requires_grad, model_cls.parameters()),
#                            lr=params.DEC_LR, weight_decay=1e-4)

# Optimizer for the style classifier.
opt_cls = get_std_opt(model_cls,h_dim=params.H_DIM, lr=params.CLS_LR, warmup=4000, factor=1)
# opt_des = get_std_opt(model_des,h_dim=params.H_DIM, lr=params.CLS_LR, warmup=des_warmup)


# # cls_opt = torch.optim.SGD(model_cls.parameters(), 5e-5)

# Optimizer for the style-transfer model.
opt_dec = get_std_opt(model_dec,h_dim=params.H_DIM, lr=params.DEC_LR, warmup=3000, factor=0.7)
# dec_opt = get_std_opt(model_dec,h_dim=params.H_DIM, lr=params.DEC_LR, warmup=dec_warmup)
# cls_dec_opt = get_std_opt(model_cls_dec,h_dim=params.H_DIM, lr=params.DEC_LR, warmup=dec_warmup)

# early_stop = EarlyStopping(params.PATIENCE)

In [0]:
### Init losses ###
# Style classifier loss.
cls_criteria = nn.CrossEntropyLoss()
cls_criteria = cls_criteria.to(params.device)
# Discriminator loss: its own module (the previous code rebound this name to
# `cls_criteria`, silently discarding the object created here).
des_criteria = nn.CrossEntropyLoss()
des_criteria = des_criteria.to(params.device)

# Token-level reconstruction loss; targets equal to 1 are ignored
# (presumably the <pad> index of the vocabulary - TODO confirm).
seq2seq_criteria = nn.CrossEntropyLoss(reduction='mean', ignore_index=1)
# seq2seq_criteria = MaskedCosineEmbeddingLoss(params.device)
seq2seq_criteria = seq2seq_criteria.to(params.device)

# ent_criteria = EntropyLoss()
# ent_criteria = ent_criteria.to(params.device)

# train funcs

In [0]:
def train_neg_label_step(model_dec, seq2seq_criteria, model_cls, model_des, cls_criteria, des_criteria,
                         opt_dec, src, src_mask, labels, rec_running_loss, rec_acc, cls_running_loss,
                         cls_acc, des_running_loss, des_acc, device, trans_cls=False, cyc_rec_lambda=1.0, cls_lambda=1.0, rec_lambda=0.0, des_lambda=0.0):
    """
    One optimisation step of the style-transfer model on the negated style label.

    The sentence is transferred to the opposite style, the classifier scores
    the transferred logits against the negated label, and the transferred logits
    are transferred back to the original style (cycle reconstruction). The
    optimised loss is
    `cyc_rec_lambda * cycle_reconstruction + cls_lambda * classification`.

    Args:
        model_dec: style-transfer model, called as `model_dec(src, src_mask, style, argmax=...)`.
        seq2seq_criteria: token-level loss for the cycle reconstruction.
        model_cls: style classifier.
        cls_criteria: classification loss.
        opt_dec: optimizer stepping the parameters of `model_dec`.
        src: batch of token ids.
        src_mask: mask matching `src`.
        labels: binary (0/1) style labels of `src`.
        rec_running_loss, rec_acc: meters updated with the cycle-reconstruction loss / predictions.
        cls_running_loss, cls_acc: meters updated with the classification loss / predictions.
        trans_cls: True when `model_cls` takes `(logits, mask, argmax=True)`.
        cyc_rec_lambda, cls_lambda: weights of the two loss terms.
        model_des, des_criteria, des_running_loss, des_acc, device, rec_lambda,
        des_lambda: accepted for interface compatibility; not used by this step.
    """
    model_dec.train()
    # Flip the binary style label. `1 - labels` is used instead of
    # `~labels.byte()`, because `~` is a bitwise NOT on uint8 tensors in
    # PyTorch >= 1.2 and would produce 255 / 254 rather than 1 / 0.
    neg_labels = 1 - labels
    neg_preds = model_dec(src, src_mask, neg_labels, argmax=False)
    if trans_cls:
        cls_preds = model_cls(neg_preds, src_mask, argmax=True)
    else:
        cls_preds = model_cls(neg_preds)

    # Cycle: feed the transferred logits back in and ask for the original style.
    cyc_preds = model_dec(neg_preds, src_mask, labels, argmax=True)

    opt_dec.zero_grad()
    # Classifier term: the transferred sentence should carry the negated style.
    cls_loss = cls_criteria(cls_preds, neg_labels)
    cls_acc.update(cls_preds, neg_labels)
    cls_running_loss.update(cls_loss)

    # Cycle-reconstruction term: flatten to (batch * seq, vocab) vs (batch * seq,).
    cyc_preds = cyc_preds.contiguous().view(-1, cyc_preds.size(-1))
    src = src.contiguous().view(-1)
    cyc_rec_loss = seq2seq_criteria(cyc_preds, src)
    rec_running_loss.update(cyc_rec_loss)
    rec_acc.update(cyc_preds, src)

    # Optimise the weighted sum; only opt_dec steps here.
    loss = cyc_rec_lambda * cyc_rec_loss + cls_lambda * cls_loss
    loss.backward()
    opt_dec.step()
    


def train_true_neg_cls_step(model_dec, model_cls, cls_criteria,
                            opt_cls, src, src_mask, labels, cls_running_loss,
                            cls_acc, trans_cls=False):
    """
    One optimisation step of the style classifier on real sentences.

    Args:
        model_dec: not used by this step; kept for interface compatibility.
        model_cls: style classifier being trained.
        cls_criteria: classification loss.
        opt_cls: optimizer stepping the parameters of `model_cls`.
        src: batch of token ids.
        src_mask: mask matching `src` (only passed on when `trans_cls` is True).
        labels: true style labels of `src`.
        cls_running_loss, cls_acc: meters updated with the loss / predictions.
        trans_cls: True when `model_cls` takes `(src, mask, argmax=False)`.
    """
    # Score the real sentences with whichever classifier interface is in use.
    cls_preds = model_cls(src, src_mask, argmax=False) if trans_cls else model_cls(src)

    opt_cls.zero_grad()
    step_loss = cls_criteria(cls_preds, labels)
    cls_acc.update(cls_preds, labels)
    cls_running_loss.update(step_loss)
    step_loss.backward()
    opt_cls.step()

def train_des_step(model_dec, model_des, opt_des, criteria,
                         src, src_mask, labels, des_running_loss,
                         des_acc, device, trans_des=False):
    """
    One optimisation step of the discriminator.

    Real sentences are labelled with their true style, sentences produced by
    `model_dec` for the negated style are labelled with the extra class index 2,
    and the discriminator is trained on the mean of both losses. `model_dec` is
    put in eval mode and runs without gradient tracking.

    Args:
        model_dec: style-transfer model producing the fake sentences.
        model_des: discriminator being trained (must score class index 2 as well).
        opt_des: optimizer stepping the parameters of `model_des`.
        criteria: classification loss.
        src: batch of token ids.
        src_mask: mask matching `src`.
        labels: binary (0/1) style labels of `src`.
        des_running_loss, des_acc: meters updated with the loss / predictions.
        device: not used by this step; kept for interface compatibility.
        trans_des: True when `model_des` takes `(input, mask, argmax=...)`.
    """
    model_dec.eval()
    # Flip the binary style label. `1 - labels` is used instead of
    # `~labels.byte()`, because `~` is a bitwise NOT on uint8 tensors in
    # PyTorch >= 1.2 and would produce 255 / 254 rather than 1 / 0.
    neg_labels = 1 - labels

    with torch.no_grad():
        neg_preds = model_dec(src, src_mask, neg_labels, argmax=False)
    if trans_des:
        fake_preds = model_des(neg_preds, src_mask, argmax=True)
        true_preds = model_des(src, src_mask, argmax=False)
    else:
        fake_preds = model_des(neg_preds)
        true_preds = model_des(src)

    # Real sentences keep their style label.
    true_labels = labels
    des_acc.update(true_preds, true_labels)
    true_loss = criteria(true_preds, true_labels)

    # Generated sentences all get the dedicated "fake" class (index 2).
    fake_labels = 2 * torch.ones_like(labels)
    fake_loss = criteria(fake_preds, fake_labels)
    des_acc.update(fake_preds, fake_labels)

    loss = (true_loss + fake_loss) / 2
    des_running_loss.update(loss)

    opt_des.zero_grad()
    loss.backward()
    opt_des.step()

def run_epoch_true_neg(epoch, data_iter, model_dec, opt_dec,
                       model_cls, opt_cls, model_des, opt_des, cls_criteria, des_criteria, seq2seq_criteria,
                       params):
    """
    Run one training epoch over `data_iter`.

    Every batch is handed to exactly one of three training steps, selected by a
    phase counter:
        0 - train the style classifier (train_true_neg_cls_step),
        1 - train the discriminator (train_des_step),
        2 - train the style-transfer model on negated labels (train_neg_label_step).
    Every `params.PERIOD_STEPS` batches the running metrics are logged (when
    `params.VERBOSE`) and reset, and the phase may change.

    NOTE(review): the phase starts at 2 and the phase-2 branch sets it back to
    2, so in this version phases 0 and 1 are never entered and only the
    style-transfer model is trained here.

    Args:
        epoch: epoch number, used in log messages only.
        data_iter: iterator yielding batches with `.text` and `.label`.
        model_dec, opt_dec: style-transfer model and its optimizer.
        model_cls, opt_cls: style classifier and its optimizer.
        model_des, opt_des: discriminator and its optimizer.
        cls_criteria, des_criteria, seq2seq_criteria: the three loss modules.
        params: configuration object (VERBOSE, device, TRAIN_BATCH_SIZE,
            PERIOD_STEPS, TRANS_CLS, TRANS_DES, CLS_ACC_BAR, DES_ACC_BAR and the
            NEG_* loss weights are read).
    """
    verbose = params.VERBOSE
    device = params.device
    # Only used for the log line below.
    total_steps = len(data_iter.dataset) // params.TRAIN_BATCH_SIZE
    period_steps = params.PERIOD_STEPS
    logging.info('total epoch steps {}, period size {}'.format(total_steps,
                                                               period_steps))

    # Running metrics, reset at the end of every period.
    cls_running_loss = Loss()
    rec_running_loss = Loss()
    des_running_loss = Loss()


    cls_acc = AccuracyCls()
    des_acc = AccuracyCls()
    rec_acc = AccuracyRec()

    model_cls.train()
    model_dec.train()
    # model_des.train()
    curr_step = 0  # batches seen in the current period
    curr_phase = 2  # see the phase table in the docstring
    for step, batch in enumerate(data_iter):
        # prepare batch

        src, labels = batch.text, batch.label

        src_mask, _ = make_masks(src, src, device)

        src = src.to(device)
        src_mask = src_mask.to(device)
        labels = labels.to(device)

        if curr_phase == 0:  # training the classifier
            train_true_neg_cls_step(model_dec, model_cls, cls_criteria,
                                    opt_cls, src, src_mask, labels, cls_running_loss,
                                    cls_acc, trans_cls=params.TRANS_CLS)
            curr_step += 1
            if curr_step == period_steps:
                if verbose:
                    logging.info(
                        "e-{},s-{}: Trained cls loss {:.3f} acc {:.3f}".format(epoch, step, cls_running_loss(),
                                                                        cls_acc()))
                cls_running_loss.reset()
                curr_step = 0
                # Move on once the classifier is accurate enough over the period.
                if cls_acc() >= params.CLS_ACC_BAR:
                    curr_phase = 1
                cls_acc.reset()

        elif curr_phase == 1:  # training the discriminator
            train_des_step(model_dec, model_des, opt_des, des_criteria,
                         src, src_mask, labels, des_running_loss,
                         des_acc, device, trans_des=params.TRANS_DES)
            curr_step += 1
            if curr_step == period_steps:
                if verbose:
                    logging.info(
                        "e-{},s-{}: Trained des, loss {:.3f}, acc {:.3f}".format(epoch,
                                                                                            step,
                                                                                            des_running_loss(),
                                                                                            des_acc()))
                curr_step = 0
                # Move on once the discriminator is accurate enough over the period.
                if des_acc() >= params.DES_ACC_BAR:
                    curr_phase = curr_phase + 1
                des_running_loss.reset()
                des_acc.reset()

        else:  # training the style-transfer model on negated labels
            train_neg_label_step(model_dec=model_dec, seq2seq_criteria=seq2seq_criteria, model_cls=model_cls, cls_criteria=cls_criteria, des_criteria=des_criteria, model_des=model_des,
                                 des_running_loss=des_running_loss, des_acc=des_acc,
                                 opt_dec=opt_dec, src=src, src_mask=src_mask, labels=labels, rec_running_loss=rec_running_loss, rec_acc=rec_acc, cls_running_loss=cls_running_loss,
                                 cls_acc=cls_acc, device=params.device, trans_cls=params.TRANS_CLS, cyc_rec_lambda=params.NEG_CYC_REC_LAMBDA,
                                 rec_lambda=params.NEG_REC_LAMBDA, cls_lambda=params.NEG_CLS_LAMBDA, des_lambda=params.NEG_DES_LAMBDA)

            curr_step += 1
            if curr_step == period_steps:
                if verbose:
                    logging.info(
                        "e-{},s-{}: Trained transformer on negated label, cls_loss {:.3f}, cls_acc {:.3f}, rec_loss {:.3f}, rec_acc {:.3f}, des_loss {:.3f}, des_acc {:.3f}".format(
                            epoch,
                            step,
                            cls_running_loss(),
                            cls_acc(),
                            rec_running_loss(),
                            rec_acc(),
                            des_running_loss(),
                            des_acc()))
                curr_step = 0
                # Stay in the style-transfer phase for the next period.
                curr_phase = 2
                cls_acc.reset()
                rec_running_loss.reset()
                cls_running_loss.reset()
                rec_acc.reset()
                des_running_loss.reset()
                des_acc.reset()


# eval funcs

In [0]:
from torch.nn import CosineSimilarity

def evaluate_true_neg(epoch, data_iter, src_embed, model_dec,
             model_cls, cls_criteria, seq2seq_criteria,
             params):
    '''
    Evaluate style transfer and cycle reconstruction over a test/validation iterator.

    For every batch the sentences are transferred to the negated style, the
    greedy (arg-max) transferred tokens are scored by the classifier against
    the negated label and then transferred back to the original style to
    measure reconstruction. The averaged losses and accuracies are logged.

    Args:
        epoch: epoch number, used in the log message only.
        data_iter: iterator yielding batches with `.text` and `.label`.
        src_embed: callable applied to the transferred token ids before they
            are handed to `model_cls`.
            NOTE(review): confirm that `model_cls` expects the output of
            `src_embed` rather than raw token ids.
        model_dec: style-transfer model.
        model_cls: style classifier.
        cls_criteria: classification loss.
        seq2seq_criteria: token-level reconstruction loss.
        params: configuration object (device, TRANS_CLS and
            TEST_MAX_BATCH_SIZE are read).
    '''

    device = params.device
    trans_cls = params.TRANS_CLS

    model_cls.eval()
    model_dec.eval()

    cls_running_loss = Loss()
    rec_running_loss = Loss()

    rec_acc = AccuracyRec()
    cls_acc = AccuracyCls()

    with torch.no_grad():
        for i, batch in enumerate(data_iter):
            # Stop after TEST_MAX_BATCH_SIZE batches (no limit when it is falsy).
            if params.TEST_MAX_BATCH_SIZE and i == params.TEST_MAX_BATCH_SIZE:
                break

            # Prepare batch
            src, labels = batch.text, batch.label
            src_mask, _ = make_masks(src, src, device)
            src = src.to(device)
            src_mask = src_mask.to(device)
            labels = labels.to(device)

            # Flip the binary style label. `1 - labels` is used instead of
            # `~labels.byte()`, because `~` is a bitwise NOT on uint8 tensors
            # in PyTorch >= 1.2 and would produce 255 / 254 rather than 1 / 0.
            neg_labels = 1 - labels

            # Transfer to the negated style and take the greedy tokens.
            preds = model_dec(src, src_mask, neg_labels)
            neg_sent = torch.argmax(preds, dim=-1)
            preds_for_cls = src_embed(neg_sent)
            if trans_cls:
                cls_preds = model_cls(preds_for_cls, src_mask)
            else:
                cls_preds = model_cls(preds_for_cls)

            # Cycle back to the original style.
            preds = model_dec(neg_sent, src_mask, labels)

            cls_loss = cls_criteria(cls_preds, neg_labels)
            cls_acc.update(cls_preds, neg_labels)
            cls_running_loss.update(cls_loss)
            # Flatten to (batch * seq, vocab) vs (batch * seq,) for the token loss.
            preds = preds.contiguous().view(-1, preds.size(-1))
            src = src.contiguous().view(-1)
            rec_loss = seq2seq_criteria(preds, src)
            rec_running_loss.update(rec_loss)
            rec_acc.update(preds, src)

    logging.info("Eval-e-{}: loss cls: {:.3f}, acc cls: {:.3f}, loss rec: {:.3f}, acc rec: {:.3f}".format(epoch, cls_running_loss(),
                                                                                         cls_acc(), rec_running_loss(),
                                                                                         rec_acc()))
    
def greedy_decode_sent(preds, id2word, eos_id):
    '''
    Naive greedy decoding - just argmax over the vocabulary distribution.

    Returns a tuple of the decoded sentence as a string and the tensor of
    arg-max token ids.
    '''
    token_ids = preds.argmax(dim=-1)
    # Drop the leading batch dimension and hand the ids over as a numpy array.
    ids_as_np = token_ids.squeeze(0).detach().cpu().numpy()
    return sent2str(ids_as_np, id2word, eos_id), token_ids


def sent2str(sent_as_np, id2word, eos_id=None):
    '''
    Convert a sentence given as a numpy array of token ids into a string.

    Args:
        sent_as_np: 1-D numpy array of token ids.
        id2word: mapping from token id to word.
        eos_id: id of the end-of-sentence token. When given (0 is a valid id),
            the sentence is cut just before its first occurrence; when None the
            whole array is converted.

    Returns:
        The words joined by single spaces.

    Raises:
        ValueError: if `sent_as_np` is not a numpy array.
    '''
    if not isinstance(sent_as_np, np.ndarray):
        raise ValueError('Invalid input type, expected np array')
    # Compare against None so that an EOS id of 0 is honoured as well.
    if eos_id is not None:
        eos_positions = np.where(sent_as_np == eos_id)[0]
        if len(eos_positions) > 0:
            sent_as_np = sent_as_np[:int(eos_positions[0])]

    return " ".join(id2word[i] for i in sent_as_np)



def test_random_samples(data_iter, TEXT, model_dec, model_cls, device, src_embed=None, decode_func=None, num_samples=2,
                    transfer_style=True, trans_cls=False, embed_preds=False):
    '''
    Log a few sample sentences and the model output for them, to eyeball the model.

    The first sentence of each of the first `num_samples` batches is used.

    Args:
        data_iter: iterator yielding batches with `.text` and `.label`.
        TEXT: text field whose vocabulary maps words to ids.
        model_dec: style-transfer model.
        model_cls: style classifier used to label the output.
        device: device the batch is moved to.
        src_embed: optional callable applied to the token ids before they are
            handed to `model_dec` / `model_cls`.
        decode_func: optional callable `(preds, id2word, eos_id) -> (text, ids)`.
            When omitted, the arg-max tokens are classified and only the
            predicted class is logged.
        num_samples: number of sentences to log.
        transfer_style: if True apply style transfer (negate the style label).
        trans_cls: True when `model_cls` also takes the source mask.
        embed_preds: if True map the predictions through
            `preds_embedding_cosine_similarity` before decoding.
    '''

    word2id = TEXT.vocab.stoi
    eos_id = int(word2id['<eos>'])
    id2word = {v: k for k, v in word2id.items()}
    model_dec.eval()
    # Same as in evaluate_true_neg: no dropout in the classifier while sampling.
    model_cls.eval()

    with torch.no_grad():
        for batch in data_iter:
            if num_samples <= 0:
                break

            # Prepare batch: keep only the first sentence, as a batch of one.
            src, labels = batch.text[0, ...], batch.label[0, ...]
            src = src.unsqueeze(0)
            labels = labels.unsqueeze(0)
            src_mask, _ = make_masks(src, src, device)

            src = src.to(device)
            src_mask = src_mask.to(device)
            labels = labels.to(device)
            true_labels = labels.clone()

            # Flip the binary style label if transfer_style is set. `1 - labels`
            # is used instead of `~labels.byte()`, because `~` is a bitwise NOT
            # on uint8 tensors in PyTorch >= 1.2 (255 / 254 instead of 1 / 0).
            if transfer_style:
                labels = 1 - labels

            if src_embed is not None:
                embeds = src_embed(src)
                preds = model_dec(embeds, src_mask, labels)
            else:
                preds = model_dec(src, src_mask, labels)

            sent_as_list = src.squeeze(0).detach().cpu().numpy()
            src_sent = sent2str(sent_as_list, id2word, eos_id)
            src_label = 'pos' if true_labels.detach().item() == 1 else 'neg'
            logging.info('Original: text: {}'.format(src_sent))
            logging.info('Original: class: {}'.format(src_label))

            if embed_preds:
                preds = preds_embedding_cosine_similarity(preds, model_dec.src_embed)

            if decode_func:
                dec_sent, decoded = decode_func(preds, id2word, eos_id)
            else:
                # No decoder supplied: classify the arg-max tokens, there is no
                # text to print. (This path used to fail with a NameError.)
                dec_sent, decoded = None, torch.argmax(preds, -1)

            if src_embed is not None:
                decoded = src_embed(decoded)
            if trans_cls:
                cls_preds = model_cls(decoded, src_mask)
            else:
                cls_preds = model_cls(decoded)
            pred_label = 'pos' if torch.argmax(cls_preds) == 1 else 'neg'

            if dec_sent is not None:
                if transfer_style:
                    logging.info('Style transfer output:')
                logging.info('Predicted: text: {}'.format(dec_sent))
            logging.info('Predicted: class: {}'.format(pred_label))
            logging.info('\n')

            num_samples -= 1

def test_user_string(sent, label, TEXT, model_dec, model_cls, device, decode_func=None,
                    transfer_style=True, trans_cls=False, embed_preds=False):
    ''' Run the model on one user-supplied sentence and log the outcome.

        sent - str, raw sentence; tokenized with the English() tokenizer
        label - style label of sent (a single value); 1 is logged as 'pos',
                anything else as 'neg'
        TEXT - field object whose vocab.stoi maps token -> id
        model_dec - encoder-decoder, called as model_dec(src, src_mask, labels)
        model_cls - classifier applied to the embedded decoded sentence
        device - device the input tensors are moved to
        decode_func - callable(preds, id2word, eos_id) -> (sentence, token ids);
                      if None, nothing is decoded or classified
        transfer_style - bool, if True apply style transfer (the label is flipped
                         before it is given to model_dec)
        trans_cls - bool, if True the classifier is also given src_mask
        embed_preds - bool, if True preds go through
                      preds_embedding_cosine_similarity before decoding

        Raises ValueError if sent yields no tokens. '''

    word2id = TEXT.vocab.stoi
    eos_id = int(word2id['<eos>'])
    id2word = {v: k for k, v in word2id.items()}
    # define tokenizer
    en = English()
    def id_tokenize(sentence):
        return [word2id[tok.text] for tok in en.tokenizer(sentence)]

    model_dec.eval()

    with torch.no_grad():
        # Prepare a batch of size one
        token_ids = id_tokenize(sent)
        if not token_ids:
            raise ValueError('sent must contain at least one token')
        # NOTE(review): only the raw token ids are fed here; if the dataset
        # pipeline appends <eos>/padding, this input differs from it - confirm.
        src = torch.LongTensor(token_ids).unsqueeze(0)
        # Hold the label value itself in a 1-element tensor; LongTensor(int)
        # would instead allocate an uninitialised tensor of that length.
        labels = torch.as_tensor(label, dtype=torch.long).reshape(1)
        src_mask, _ = make_masks(src, src, device)

        src = src.to(device)
        src_mask = src_mask.to(device)
        labels = labels.to(device)
        true_labels = copy.deepcopy(labels)

        # Logical not on labels if transfer_style is set.
        # `~` on a uint8 tensor is a bitwise NOT in PyTorch >= 1.2
        # (1 -> 254, 0 -> 255), so binary labels are flipped arithmetically.
        if transfer_style:
            labels = 1 - labels
        logging.debug('Transfer label {} original label {}'.format(labels, true_labels))

        preds = model_dec(src, src_mask, labels)

        src_label = 'pos' if true_labels.detach().item() == 1 else 'neg'
        logging.info('Original: text: {}'.format(sent))
        logging.info('Original: class: {}'.format(src_label))

        if embed_preds:
            preds = preds_embedding_cosine_similarity(preds, model_dec.src_embed)
        if decode_func:
            dec_sent, decoded = decode_func(preds, id2word, eos_id)
            # The classifier is fed embeddings of the decoded token ids
            preds_for_cls = model_dec.src_embed(decoded)
            if trans_cls:
                cls_preds = model_cls(preds_for_cls, src_mask)
            else:
                cls_preds = model_cls(preds_for_cls)
            pred_label = 'pos' if torch.argmax(cls_preds) == 1 else 'neg'
            if transfer_style:
                logging.info('Style transfer output:')
            logging.info('Predicted: text: {}'.format(dec_sent))
            logging.info('Predicted: class: {}'.format(pred_label))

        else:
            # Nothing was decoded, hence nothing was classified
            logging.info('Predicted: class: {}'.format('n/a'))
        logging.info('\n')

# Pretrain generator

In [0]:
# Pre-train the encoder-decoder on reconstructing its own input (true labels).
for epoch in range(1):
    verbose, device = params.VERBOSE, params.device

    # Running metrics, reset after every logging period
    rec_running_loss, rec_acc = Loss(), AccuracyRec()

    model_dec.train()
    for step, batch in enumerate(train_iter):
        # prepare batch
        src, labels = batch.text, batch.label
        src_mask, _ = make_masks(src, src, device)
        src, src_mask, labels = src.to(device), src_mask.to(device), labels.to(device)

        # Forward pass with the original labels. (A variant that first decodes
        # with negated labels and then reconstructs from that output was tried
        # here and is currently disabled.)
        preds = model_dec(src, src_mask, labels, argmax=False)

        # Flatten predictions to 2-D and targets to 1-D for the criterion
        preds = preds.contiguous().view(-1, preds.size(-1))
        src = src.contiguous().view(-1)
        rec_loss = seq2seq_criteria(preds, src)
        rec_running_loss.update(rec_loss)
        rec_acc.update(preds, src)

        # optimize decoder
        loss = rec_loss
        opt_dec.zero_grad()
        loss.backward()
        opt_dec.step()

        # Report (and possibly stop) once every 20 steps
        if not (verbose and (step + 1) % 20 == 0):
            continue
        logging.info(
            "e-{},s-{}: Pre-Training transformer on rec, rec_loss {}, rec_acc {}".format(
                epoch, step, rec_running_loss(), rec_acc()))
        # Stop pre-training as soon as the period's accuracy reaches 50
        if rec_acc() >= 50.0:
            break
        rec_running_loss.reset()
        rec_acc.reset()


INFO - 09/11/19 15:38:04 - 1:19:03 - e-0,s-19: Pre-Training transformer on rec, rec_loss 9.647946643829346, rec_acc 0.0710025560920193
INFO - 09/11/19 15:38:07 - 1:19:06 - e-0,s-39: Pre-Training transformer on rec, rec_loss 9.534781312942505, rec_acc 7.9977706562630635
INFO - 09/11/19 15:38:10 - 1:19:08 - e-0,s-59: Pre-Training transformer on rec, rec_loss 9.344154119491577, rec_acc 20.058139534883722
INFO - 09/11/19 15:38:12 - 1:19:11 - e-0,s-79: Pre-Training transformer on rec, rec_loss 9.125034761428832, rec_acc 20.858895705521473
INFO - 09/11/19 15:38:15 - 1:19:14 - e-0,s-99: Pre-Training transformer on rec, rec_loss 8.913456058502197, rec_acc 19.505219505219507
INFO - 09/11/19 15:38:18 - 1:19:17 - e-0,s-119: Pre-Training transformer on rec, rec_loss 8.661820554733277, rec_acc 19.826743048763447
INFO - 09/11/19 15:38:21 - 1:19:19 - e-0,s-139: Pre-Training transformer on rec, rec_loss 8.391303777694702, rec_acc 23.21124361158433
INFO - 09/11/19 15:38:23 - 1:19:22 - e-0,s-159: Pre-Tr

# Pretrain classifier

In [0]:
from torch.nn import functional as F
# (disabled) alias the descriptor as the classifier:
# model_cls = model_des
# opt_cls = opt_des

# (disabled) one-hot helper:
# def fill_one_hot_src(src, y_onehot):
#     y_onehot.zero_()
#     y_onehot.scatter_(dim=-1, index=src.unsqueeze(-1), value=1)

# Pre-train the style classifier on the original sentences and their labels.
for epoch in range(1):
    verbose, device = params.VERBOSE, params.device

    # Running metrics, reset after every logging period
    cls_running_loss, cls_acc = Loss(), AccuracyCls()

    model_cls.train()
    for step, batch in enumerate(train_iter):
        # prepare batch
        src, labels = batch.text, batch.label
        src_mask, _ = make_masks(src, src, device)
        src, src_mask, labels = src.to(device), src_mask.to(device), labels.to(device)

        # The transformer-based classifier also takes the source mask
        cls_preds = (model_cls(src, src_mask, argmax=False)
                     if params.TRANS_CLS else model_cls(src))

        opt_cls.zero_grad()
        cls_loss = cls_criteria(cls_preds, labels)
        cls_acc.update(cls_preds, labels)
        cls_running_loss.update(cls_loss)
        cls_loss.backward()
        opt_cls.step()

        # (disabled) alternative training step:
        # train_true_neg_cls_step(model_dec, model_cls, cls_criteria,
        #                 opt_dec, opt_cls,
        #                 src, src_mask, labels,  cls_running_loss,
        #                 cls_acc, device=device, trans_cls=params.TRANS_CLS)

        # Report once every 100 steps
        if not (verbose and (step + 1) % 100 == 0):
            continue
        logging.info(
            "e-{},s-{}: Training cls loss {} acc {}".format(
                epoch, step, cls_running_loss(), cls_acc()))
        cls_running_loss.reset()
        cls_acc.reset()

INFO - 09/11/19 14:47:37 - 0:28:35 - e-0,s-99: Training cls loss 0.72555723965168 acc 55.25
INFO - 09/11/19 14:47:42 - 0:28:40 - e-0,s-199: Training cls loss 0.5212832516431809 acc 77.3125
INFO - 09/11/19 14:47:46 - 0:28:45 - e-0,s-299: Training cls loss 0.3840521217882633 acc 89.0625
INFO - 09/11/19 14:47:51 - 0:28:50 - e-0,s-399: Training cls loss 0.32129169657826423 acc 93.3125
INFO - 09/11/19 14:47:56 - 0:28:55 - e-0,s-499: Training cls loss 0.31894743502140044 acc 92.90625
INFO - 09/11/19 14:48:01 - 0:29:00 - e-0,s-599: Training cls loss 0.30550199940800665 acc 93.5625
INFO - 09/11/19 14:48:06 - 0:29:05 - e-0,s-699: Training cls loss 0.28642746567726135 acc 95.40625
INFO - 09/11/19 14:48:11 - 0:29:10 - e-0,s-799: Training cls loss 0.29774092480540276 acc 94.3125
INFO - 09/11/19 14:48:16 - 0:29:15 - e-0,s-899: Training cls loss 0.28578834801912306 acc 95.03125
INFO - 09/11/19 14:48:21 - 0:29:20 - e-0,s-999: Training cls loss 0.2721908695995808 acc 96.09375
INFO - 09/11/19 14:48:26 

In [0]:
# Save the pre-trained classifier and generator weights.
# The classifier file name uses the `dim{}` form (no extra underscore) so it is
# the same path the checkpoint-loading cell of this notebook reads.
os.makedirs(params.MODELS_PATH, exist_ok=True)  # torch.save does not create directories
model_cls_path = os.path.join(params.MODELS_PATH, "model_cls_{}_yelp_freq{}_len{}_dim{}.pth".format(params.N_LAYERS_CLS, params.VOCAB_MIN_FREQ, params.MAX_LEN, params.H_DIM))
model_dec_path = os.path.join(params.MODELS_PATH, "model_dec_{}_yelp_freq{}_len{}_dim{}_pretrain_50.pth".format(params.N_LAYERS, params.VOCAB_MIN_FREQ, params.MAX_LEN, params.H_DIM))

torch.save(model_cls.state_dict(), model_cls_path)
torch.save(model_dec.state_dict(), model_dec_path)

In [0]:
# Restore the pre-trained classifier and generator weights.
# The generator file name follows the `..._dim{}_pretrain_50.pth` pattern that the
# checkpoint-saving cell of this notebook writes (the previous `.pth_dim{}` name
# had the extension in the middle and matched no saved file).
model_cls_path = os.path.join(params.MODELS_PATH, "model_cls_{}_yelp_freq{}_len{}_dim{}.pth".format(params.N_LAYERS_CLS, params.VOCAB_MIN_FREQ, params.MAX_LEN, params.H_DIM))
model_dec_path = os.path.join(params.MODELS_PATH, "model_dec_{}_yelp_freq{}_len{}_dim{}_pretrain_50.pth".format(params.N_LAYERS, params.VOCAB_MIN_FREQ, params.MAX_LEN, params.H_DIM))

# map_location lets a checkpoint saved on one device load on the configured one
model_cls.load_state_dict(torch.load(model_cls_path, map_location=params.device))
model_dec.load_state_dict(torch.load(model_dec_path, map_location=params.device))

# Run

In [0]:
# Main training: one run_epoch_true_neg pass per epoch, then checkpoint the
# generator and log a few style-transferred samples.
for epoch in range(params.N_EPOCHS):
    # Per-epoch checkpoint path, tagged with the hyper-parameters and epoch index
    model_dec_path = os.path.join(
        params.MODELS_PATH,
        "model_dec_{}_yelp_freq{}_len{}_dim{}_sample_e{}.pth".format(
            params.N_LAYERS, params.VOCAB_MIN_FREQ, params.MAX_LEN, params.H_DIM, epoch))

    run_epoch_true_neg(
        epoch=epoch,
        data_iter=train_iter,
        model_dec=model_dec,
        opt_dec=opt_dec,
        model_des=None,
        opt_des=None,
        model_cls=model_cls,
        opt_cls=opt_cls,
        cls_criteria=cls_criteria,
        seq2seq_criteria=seq2seq_criteria,
        des_criteria=des_criteria,
        params=params)
    torch.save(model_dec.state_dict(), model_dec_path)

    # (disabled) evaluation on the test split:
    # test_acc = evaluate_true_neg(epoch, test_iter, model_dec,
    #                       model_cls, cls_criteria, seq2seq_criteria,
    #                       params)

    # Log five transferred samples taken from the training iterator
    test_random_samples(
        train_iter, TEXT, model_dec, model_cls, params.device,
        decode_func=greedy_decode_sent,
        num_samples=5,
        transfer_style=True,
        trans_cls=params.TRANS_CLS)

  # TODO - Roy - currently not in use, what metric to follow ?
  # early_stop(test_acc)
  # if early_stop.early_stop:
  #     break


INFO - 09/11/19 15:38:53 - 1:19:52 - total epoch steps 7366, period size 100
INFO - 09/11/19 15:39:20 - 1:20:19 - e-0,s-99: Trained transformer on negated label, cls_loss 0.784, cls_acc 55.250, rec_loss 4.524, rec_acc 54.759, des_loss 0.000, des_acc 0.000
INFO - 09/11/19 15:39:47 - 1:20:45 - e-0,s-199: Trained transformer on negated label, cls_loss 0.807, cls_acc 54.875, rec_loss 3.854, rec_acc 58.046, des_loss 0.000, des_acc 0.000
INFO - 09/11/19 15:40:13 - 1:21:12 - e-0,s-299: Trained transformer on negated label, cls_loss 0.810, cls_acc 55.094, rec_loss 3.416, rec_acc 60.743, des_loss 0.000, des_acc 0.000
INFO - 09/11/19 15:40:39 - 1:21:38 - e-0,s-399: Trained transformer on negated label, cls_loss 0.806, cls_acc 55.156, rec_loss 3.020, rec_acc 65.240, des_loss 0.000, des_acc 0.000
INFO - 09/11/19 15:41:06 - 1:22:04 - e-0,s-499: Trained transformer on negated label, cls_loss 0.791, cls_acc 56.563, rec_loss 2.732, rec_acc 68.196, des_loss 0.000, des_acc 0.000
INFO - 09/11/19 15:41:32

KeyboardInterrupt: ignored

In [0]:
# Turn shuffling off on the test iterator.
test_iter.shuffle = False

# Log 20 style-transferred samples.
# NOTE(review): the samples are drawn from train_iter, not from the test_iter
# whose shuffle flag is changed above - confirm which iterator is intended.
sample_opts = dict(decode_func=greedy_decode_sent, num_samples=20,
                   transfer_style=True, trans_cls=params.TRANS_CLS)
test_random_samples(train_iter, TEXT, model_dec, model_cls, params.device, **sample_opts)


INFO - 09/11/19 17:13:19 - 2:54:18 - Original: text: just simple hands on learning from very knowledgable staff!.
INFO - 09/11/19 17:13:19 - 2:54:18 - Original: class: neg
INFO - 09/11/19 17:13:19 - 2:54:18 - Style transfer output:
INFO - 09/11/19 17:13:19 - 2:54:18 - Predicted: text: just simple hands on experience from very patrons 1997
INFO - 09/11/19 17:13:19 - 2:54:18 - Predicted: class: neg
INFO - 09/11/19 17:13:19 - 2:54:18 - 
                                     
INFO - 09/11/19 17:13:19 - 2:54:18 - Original: text: the esthetician is amazing as well .
INFO - 09/11/19 17:13:19 - 2:54:18 - Original: class: neg
INFO - 09/11/19 17:13:19 - 2:54:18 - Style transfer output:
INFO - 09/11/19 17:13:19 - 2:54:18 - Predicted: text: the awesome is disgusting as well .
INFO - 09/11/19 17:13:19 - 2:54:18 - Predicted: class: pos
INFO - 09/11/19 17:13:19 - 2:54:18 - 
                                     
INFO - 09/11/19 17:13:19 - 2:54:18 - Original: text: we will definitely be back!.
INFO - 09

# Experiments