### Adversarial Text Example Experiment Runner Template
11/6/2017 - Basic pipeline to run adversarial text generation experiments.

### Dataset Preparation
Base dataset: The Enron Spam Dataset: http://www2.aueb.gr/users/ion/data/enron-spam/ 
    

In [59]:
from __future__ import division

import os
import numpy as np
import scipy
import scipy.stats
import sklearn
import sklearn.feature_extraction, sklearn.naive_bayes, sklearn.metrics, sklearn.externals
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

from seq2seq.model import Seq2Seq, Seq2SeqAutoencoder

# Set CUDA Visible devices
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=3

use_cuda = torch.cuda.is_available()
print "Use CUDA:" + str(use_cuda)

# Initialize some key paths
base_dir = "/cvgl2/u/catwong/cs332_final_project/"
base_data_dir = os.path.join(base_dir, 'data')
base_checkpoints_dir = os.path.join(base_dir, 'checkpoints')

nb_discriminator_ckpt = 'discriminator_multinomial_nb.pkl'

# Other constants.
# The number of terms, including special tokens, in the final vocabulary.
TRAINING_VOCAB_SIZE = {
    100: 4480,
    30: 4628
} 


env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=3
Use CUDA:True


In [3]:
## Spam Preprocessing - UNIX Command line
# 1. Removed all \n and replaced with spaces: find . -type f -exec perl -i. -pe 's/\r?\n/ /' {} +
# 2. Concatenated all spam into a single file and all ham into a single file.
#      To concatenate within dirs: awk 1 enron1/ham/*.txt > enron1_ham.txt 
# 3. Randomly shuffled: shuf input > output
# 4. Create 80, 10, 10 train, val, and test splits.

# Total ham: 16545 messages; train/val/test = 13236, 1654, 1655
# Total spam: 17171 messages; train/val/test = 13736, 1717, 1718

In [34]:

base_data_dir = "/cvgl2/u/catwong/cs332_final_project/data/"
classes = ['spam', 'ham']
vocabulary_filename = '30_email_train_vocab.txt' 
# Truncation and vocabulary shortening:
# Using the train data only:
# 1. Truncate both the spam and ham messages to truncation_len characters (adding padding where needed).
# 2. From the truncated messages, compile a vocabulary of the class_vocabulary_size most frequent tokens for each class.
# 3. Write a vocabulary file composed of the full, combined vocabulary (ie. the most frequent tokens across both classes.)
truncation_len = 30
class_vocabulary_size = 3000

combined_vocab = []
for class_name in classes:
    filename = os.path.join(base_data_dir, 'train', class_name + '.txt') 
    print "Now processing: %s" % filename
    with open(filename) as f:
        all_lines = [line.strip().lower().split() for line in f.readlines()]
        
    # Truncate the files.
    truncated_lines = [line[:truncation_len] for line in all_lines]
    
    # Add tokens to the counter
    token_counts = Counter()
    for line in truncated_lines:
        token_counts.update(line)
    combined_vocab += [elem for (elem, count) in token_counts.most_common(class_vocabulary_size)]

# Convert the combined vocabulary into a set.
combined_vocab = set(combined_vocab)
print len(combined_vocab)
# Write out the combined_vocab to the vocabulary file
print vocabulary_filename
with open(os.path.join(base_data_dir, vocabulary_filename), 'w') as f:
    for token in combined_vocab:
        f.write(token + "\n")

Now processing: /cvgl2/u/catwong/cs332_final_project/data/train/spam.txt
Now processing: /cvgl2/u/catwong/cs332_final_project/data/train/ham.txt
4624
30_email_train_vocab.txt


In [41]:
# Class that takes in a file and a vocabulary file (which has a truncation len) and converts the text into
# encoded/truncated sentences.

class DatasetEncoderDecoder(object):
    """
    Encodes and decodes sentences according to a vocabulary.
    
    Sentences are truncated. OOV words are assigned an <UNK> token, and <SOS>, <PAD>, and <EOS> tokens are added.
    
    truncation_len
    """
    def __init__(self, vocab_file, truncation_len=100):
        self.truncation_len = truncation_len
        # Create index to word and word to index dicts from the vocab_file.
        num_default_tokens = 4
        self.index2word = {0:'<SOS>', 1:'<EOS>', 2: '<UNK>', 3: '<PAD>'}
        self.word2index = {'<SOS>':0, '<EOS>':1, '<UNK>': 2, '<PAD>': 3}
        with open(vocab_file) as f:
            all_lines = [line.strip() for line in f.readlines()]
        for idx, token in enumerate(all_lines):
            self.index2word[idx + num_default_tokens] = token
            self.word2index[token] = idx + num_default_tokens
          
    def encode(self, sentence):
        """
        Encodes a sentence according to the vocabulary.
        Returns:
            normalized: the normalized sentence, as it would be decoded.
            encoded: the space-separated numerical sentence.
        """
        truncated = sentence.lower().split()[:self.truncation_len]
        truncated += ['<PAD>'] * max(self.truncation_len - len(truncated), 0)
        truncated = ['<SOS>'] + truncated + ['<EOS>']
        
        normalized = []
        encoded = []
        # Encode, removing the UNK tokens
        for token in truncated:
            token = token if token in self.word2index else '<UNK>'
            normalized.append(token)
            encoded.append(str(self.word2index[token]))
        
        normalized = " ".join(normalized)
        encoded = " ".join(encoded)
        return normalized, encoded
    
    def decode_numpy(self, numerical_encoded):
        """Returns the decoded sentence."""
        return " ".join([self.index2word[token] for token in numerical_encoded])

    def decode(self, encoded):
        """Returns the decoded sentence."""
        numerical_encoded = [int(token) for token in encoded.split()]
        return " ".join([self.index2word[token] for token in numerical_encoded])

# Demonstration:
vocab_file = 'data/email_train_vocab.txt'
sample_text = 'Subject: does your business depend on the online success of your website ? submitting your website in search engines may increase your online sales dramatically . if you invested time and money into your website , you simply must submit your website online otherwise it will be invisible virtually , which means efforts spent in vain . if you want people to know about your website and boost your revenues , the only way to do that is to make your site visible in places where people search for information , i . e . submit your website in multiple search engines . submit your website online and watch visitors stream to your e - business . best regards , myrtice melendez'
demo = DatasetEncoderDecoder(vocab_file)
normalized, encoded = demo.encode(sample_text)
print sample_text
print normalized
print encoded
decoded = demo.decode(encoded)
print decoded


Subject: does your business depend on the online success of your website ? submitting your website in search engines may increase your online sales dramatically . if you invested time and money into your website , you simply must submit your website online otherwise it will be invisible virtually , which means efforts spent in vain . if you want people to know about your website and boost your revenues , the only way to do that is to make your site visible in places where people search for information , i . e . submit your website in multiple search engines . submit your website online and watch visitors stream to your e - business . best regards , myrtice melendez
<SOS> subject: does your business <UNK> on the online success of your website ? submitting your website in search engines may increase your online sales dramatically . if you invested time and money into your website , you simply must submit your website online otherwise it will be invisible virtually , which means efforts s

In [44]:
# Write the train, test, and text encoded files using this encoder.
base_data_dir = "/cvgl2/u/catwong/cs332_final_project/data/"
splits = ['train', 'val', 'test']
classes = ['spam.txt', 'ham.txt']
truncation_len = 100
vocab_file = 'data/100_email_train_vocab.txt'

vocab_encoder = DatasetEncoderDecoder(vocab_file, truncation_len=truncation_len)
for split in splits:
    for class_file in classes:
        raw_file = os.path.join(base_data_dir, split, class_file)
        with open(raw_file) as f:
            all_lines = [line.strip() for line in f.readlines()]
        # Encode the lines
        encoded_lines = [vocab_encoder.encode(line)[1] for line in all_lines]
        
        # Write out the encoded line
        encoded_file = os.path.join(base_data_dir, split, str(truncation_len) + '_encoded_' + class_file)
        with open(encoded_file, 'w') as f:
            for line in encoded_lines:
                f.write(line + "\n")

In [51]:
# Samples of the encoded data
base_data_dir = "/cvgl2/u/catwong/cs332_final_project/data/"
splits = ['train', 'val', 'test']
classes = ['encoded_spam.txt', 'encoded_ham.txt']

truncation_len = 30
vocab_file = 'data/30_email_train_vocab.txt'
vocab_encoder = DatasetEncoderDecoder(vocab_file, truncation_len=truncation_len)
for class_file in classes:
    sample_file = os.path.join(base_data_dir, splits[0], str(truncation_len) + "_" + class_file)
    print "Sample file: " + sample_file
    with open(sample_file) as f:
        all_lines = [line.strip() for line in f.readlines()]
    sample_line = all_lines[0]
    print "Sample line: " + sample_line
    print "Sample decoding: " + vocab_encoder.decode(sample_line)

Sample file: /cvgl2/u/catwong/cs332_final_project/data/train/30_encoded_spam.txt
Sample line: 0 561 688 3176 4083 2 4287 2 2861 2 3221 1793 1586 997 2445 788 2 2236 4440 2 3503 1649 688 2 1351 319 342 2969 1330 2467 1694 1
Sample decoding: <SOS> subject: news alert ( <UNK> ) <UNK> orders <UNK> $ 3 million dollars what is <UNK> technologies ? <UNK> issued 2 news <UNK> today , one during market hours and <EOS>
Sample file: /cvgl2/u/catwong/cs332_final_project/data/train/30_encoded_ham.txt
Sample line: 0 561 4511 2173 3084 3942 1830 4327 2265 4075 2378 596 853 4511 2173 853 3084 3942 1830 740 853 3942 1830 2 3866 3591 970 1694 4514 2159 1694 1
Sample decoding: <SOS> subject: formation of enron management committee i am pleased to announce the formation of the enron management committee . the management committee <UNK> our business unit and function leadership and <EOS>


In [67]:
class SpamDataset(object):
    """
    Dataset: encapsulates utility functions to get the dataset files.
    """
    def __init__(self,
                 base_data_dir="/cvgl2/u/catwong/cs332_final_project/data/",
                 splits=['train', 'val', 'test'],
                 label_names=['ham', 'spam'],
                 truncation_len=100,
                 encoded_files=['encoded_ham.txt', 'encoded_spam.txt'],
                 vocab_file='email_train_vocab.txt',
                 random_seed=10):
        self.base_data_dir = base_data_dir
        self.splits = splits
        self.label_names = label_names
        self.encoded_files = [str(truncation_len) + "_" + f for f in encoded_files]
        self.vocab_file = os.path.join(base_data_dir, str(truncation_len) + "_" + vocab_file)
        self.vocab_encoder = DatasetEncoderDecoder(self.vocab_file, truncation_len=truncation_len)
        self.random_seed = random_seed
        
        # Read in all of the lines from the files.
        self.examples_dict = {}
        self.labels_dict = {}
        for split in splits:
            all_examples = []
            all_labels = []
            for label, encoded_file in enumerate(self.encoded_files):
                data_file = os.path.join(base_data_dir, split, encoded_file)
                with open(data_file) as f:
                    all_lines = [line.strip().split() for line in f.readlines()]
                all_examples += all_lines
                all_labels += [label] * len(all_lines)
            self.examples_dict[split] = all_examples
            self.labels_dict[split] = all_labels
            
    
    def examples(self, 
                 split, 
                 shuffled=False):
        """
        Args:
            split: one of the splits (ex. train, val, test) with labels.
            shuffled: whether to shuffle the examples.(default: True)
        Returns:
            examples: (list of lists)
            labels: (list)
        """
        examples = np.array(self.examples_dict[split]).astype(int)
        labels = np.array(self.labels_dict[split])
        if shuffled:
            examples, labels = sklearn.utils.shuffle(examples, labels, random_state=self.random_seed)
        return examples, labels
    
    def dataset_stats(self):
        """Prints useful stats about the dataset."""
        for split in self.splits:
            labels = self.labels_dict[split]
            num_pos = np.sum(labels)
            num_neg = len(labels) - num_pos
            print "Total %s examples: %d, %s: %d, %s: %d" % (split, len(labels), self.label_names[0], num_neg, self.label_names[1], num_pos)
            

# Demo
dataset = SpamDataset(truncation_len=100)
examples, labels =  dataset.examples(split='train', shuffled=False)
print examples[0]
print labels[0]
print dataset.vocab_encoder.decode(" ".join(examples[0].astype(str)))
dataset.dataset_stats()

[   0  542    2 2129  728 3797 2589 4193 2226 3936 2336  581  838    2 2129
  838  728 3797 2589  720  838 3797 2589    2 3727 3459  143 1658 4370 2119
 1658 4125 1149 2124  838 4040 3797  302 2992  302 1658 1469 3395    2  728
  720  838 3797 2589 4125    2  838 1914 1469 2589 1658 4125 3406  838 1355
  957 3267 1794  252  501 3134 1658 2621  302  728  976  720  495 2748  501
    2  302  728 3578 2471 3431 1184  501    2  302  728 1608  577 4371  501
 2975 2271 3537 3638   38  914  302  728  976  720  577    1]
0
<SOS> subject: <UNK> of enron management committee i am pleased to announce the <UNK> of the enron management committee . the management committee <UNK> our business unit and function leadership and will focus on the key management , strategy , and policy issues <UNK> enron . the management committee will <UNK> the former policy committee and will include the following individuals : ken lay - chairman and ceo , enron corp . ray bowen - <UNK> , enron industrial markets michael

### Discriminator
A general discriminator class and two implementations.

NBDiscriminator (done, trained) and RNNDiscriminator (coming soon)


In [87]:
class Discriminator(object):
    """
    Discriminator: a general discriminator class.
    """
    def __init__(self, checkpoint=None):
        pass
    
    def train(self, dataset):
        raise Exception("Not implemented")
        
    def evaluate(self, dataset, split, verbose=True):
        raise Exception("Not implemented")
    
    def save_model(self):
        # Outputs a path that can be passed into the restore.
        raise Exception("Not implemented")
    
    def restore_model(self, model_checkpoint):
        raise Exception("Not implemented")

class MultinomialNBDiscriminator(Discriminator):
    """
    MultinomialNB: Multinomial Naive Bayes Classifier w. alpha=1.0
    
    Trained using TF-IDF features.
    """
    def __init__(self, truncation_len=100, checkpoint=None):
        Discriminator.__init__(self, checkpoint)
        self.truncation_len=truncation_len
        if not checkpoint:
            self.model = sklearn.naive_bayes.MultinomialNB()
        else:
            self.restore_model(checkpoint)
    
    def examples_to_term_doc(self, examples):
        """
        Converts a numerically-encoded examples matrix into a sparse term-documents matrix.
        """
        num_terms = TRAINING_VOCAB_SIZE[self.truncation_len]
        all_row_inds = all_col_inds = all_data = None
        for row_ind, example in enumerate(examples):
            if row_ind % 5000 == 0:
                print "Generating term-docs matrix: %d of %d" %(row_ind, len(examples))
            itemfreqs = scipy.stats.itemfreq(example).T
            # Column indices: the term indices in that document.
            col_inds = itemfreqs[0]
            # Data: the counts of the terms in that document.
            data = itemfreqs[1]
            # Row indices: the current document, for each of the terms in that document.
            row_inds = np.ones(itemfreqs.shape[1], dtype=np.int) * row_ind

            # Concatenate to the existing data.
            if all_row_inds is None:
                all_row_inds = row_inds
                all_col_inds = col_inds
                all_data = data
            else:
                all_row_inds = np.append(all_row_inds, row_inds)
                all_col_inds = np.append(all_col_inds, col_inds)
                all_data = np.append(all_data, data)

        num_docs = len(examples)
        return scipy.sparse.csr_matrix((all_data, (all_row_inds, all_col_inds)), shape=(num_docs, num_terms))

    def train(self, dataset):
        examples, labels = dataset.examples(split='train', shuffled=True)
        
        # Silly way to compute sparse doc term matrix from examples matrix by converting it back into "strings".
        self.train_counts = self.examples_to_term_doc(examples)
        
        # Featurize using TFIDF.
        self.tf_transformer = sklearn.feature_extraction.text.TfidfTransformer()
        X_transformed = self.tf_transformer.fit_transform(self.train_counts)
        
        # Fit the model to TFIDF counts.
        self.model.fit(X_transformed, labels)
    
    def calculate_roc_auc(self, probs, labels):
        # Probability estimates of the positive class.
        pos_probs = probs[:, 1]
        return sklearn.metrics.roc_auc_score(labels, pos_probs)
    
    def evaluate(self, dataset, split, verbose=True):
        # Get the test or validation examples.
        examples, labels = dataset.examples(split=split, shuffled=True)
        doc_terms = self.examples_to_term_doc(examples)
        X_transformed = self.tf_transformer.transform(doc_terms)
        
        # Evaluate the model.
        probs = self.model.predict_proba(doc_terms)
        predicted = np.argmax(probs, axis=1)
        
        # Mean accuracy.
        mean_accuracy = np.mean(predicted == labels)
        print "Mean_accuracy: %f" % mean_accuracy
        
        # ROC-AUC Score.
        roc_auc = self.calculate_roc_auc(probs, labels)
        print "ROC AUC: %f" % roc_auc
    
    def save_model(self, 
                   checkpoint_dir='/cvgl2/u/catwong/cs332_final_project/checkpoints',
                   checkpoint_name='multinomial_nb'):
        # Separately pickles the model and the transformer.
        checkpoint = os.path.join(checkpoint_dir, str(self.truncation_len) + "_" + checkpoint_name)
        sklearn.externals.joblib.dump(self.model, checkpoint + "_model.pkl")
        sklearn.externals.joblib.dump(self.tf_transformer, checkpoint + "_tf_transformer.pkl")
        return [checkpoint + "_model.pkl", checkpoint + "_tf_transformer.pkl"]
    
    def restore_model(self, model_checkpoints):
        self.model = sklearn.externals.joblib.load(model_checkpoints[0])
        self.tf_transformer = sklearn.externals.joblib.load(model_checkpoints[1])

In [85]:
# Demo
spam_dataset = SpamDataset(truncation_len=30)
discriminator = MultinomialNBDiscriminator(truncation_len=30)
discriminator.train(spam_dataset)
discriminator.evaluate(spam_dataset, 'val')

checkpoints_dir = '/cvgl2/u/catwong/cs332_final_project/checkpoints'
checkpoint = discriminator.save_model()
print checkpoint
new_discriminator = MultinomialNBDiscriminator(checkpoint=checkpoint, truncation_len=30)
new_discriminator.evaluate(spam_dataset, 'val')

Generating term-docs matrix: 0 of 3371
Generating term-docs matrix: 0 of 3371
Mean_accuracy: 0.970039
ROC AUC: 0.993406
['/cvgl2/u/catwong/cs332_final_project/checkpoints/30_multinomial_nb_model.pkl', '/cvgl2/u/catwong/cs332_final_project/checkpoints/30_multinomial_nb_tf_transformer.pkl']
Generating term-docs matrix: 0 of 3371
Mean_accuracy: 0.970039
ROC AUC: 0.993406


#### Training Step 1: MultinomialNB Discriminator

In [88]:
# Train and save a model.
for truncation_len in [30, 100]:
    print "Now on truncation_len: " + str(truncation_len)
    spam_dataset = SpamDataset(truncation_len=truncation_len)
    discriminator = MultinomialNBDiscriminator(truncation_len=truncation_len)
    discriminator.train(spam_dataset)
    discriminator.evaluate(spam_dataset, 'val')

    checkpoints_dir = '/cvgl2/u/catwong/cs332_final_project/checkpoints'
    checkpoint = discriminator.save_model()
    new_discriminator = MultinomialNBDiscriminator(checkpoint=checkpoint, truncation_len=truncation_len)
    new_discriminator.evaluate(spam_dataset, 'val')

Now on truncation_len: 30
Generating term-docs matrix: 0 of 26972
Generating term-docs matrix: 5000 of 26972
Generating term-docs matrix: 10000 of 26972
Generating term-docs matrix: 15000 of 26972
Generating term-docs matrix: 20000 of 26972
Generating term-docs matrix: 25000 of 26972
Generating term-docs matrix: 0 of 3371
Mean_accuracy: 0.973598
ROC AUC: 0.994559
Generating term-docs matrix: 0 of 3371
Mean_accuracy: 0.973598
ROC AUC: 0.994559
Now on truncation_len: 100
Generating term-docs matrix: 0 of 26972
Generating term-docs matrix: 5000 of 26972
Generating term-docs matrix: 10000 of 26972
Generating term-docs matrix: 15000 of 26972
Generating term-docs matrix: 20000 of 26972
Generating term-docs matrix: 25000 of 26972
Generating term-docs matrix: 0 of 3371
Mean_accuracy: 0.960249
ROC AUC: 0.994239
Generating term-docs matrix: 0 of 3371
Mean_accuracy: 0.960249
ROC AUC: 0.994239


### Autoencoder
A general autoencoder class.
Based on: https://github.com/MaximumEntropy/Seq2Seq-PyTorch 

In [28]:
class Autoencoder(object):
    """
    Autoencoder: a general discriminator class.
    """
    def __init__(self, checkpoint=None, dataset=None):
        pass
    
    def train(self, dataset):
        raise Exception("Not implemented")
        
    def evaluate(self, dataset, split, verbose=True):
        raise Exception("Not implemented")
    
    def save_model(self):
        # Outputs a path that can be passed into the restore.
        raise Exception("Not implemented")
    
    def restore_model(self, model_checkpoint):
        raise Exception("Not implemented")
        
class SpamSeq2SeqAutoencoder(Autoencoder):
    """
    SpamSeq2Seq Autoencoder.
    Implementation from: https://github.com/MaximumEntropy/Seq2Seq-PyTorch
    Uses the following config: config_en_autoencoder_1_billion.json
    """
    def __init__(self, checkpoint=None, dataset=None):
        Autoencoder.__init__(self, checkpoint, dataset)
        self.dataset = dataset 
        self.vocab_size = len(self.dataset.vocab_encoder.word2index)
        self.pad_token_ind = self.dataset.vocab_encoder.word2index['<PAD>']
        self.batch_size = 2
        
        # Initialize the model.
        self.model = Seq2SeqAutoencoder(
            src_emb_dim=256,
            trg_emb_dim=256,
            src_vocab_size=self.vocab_size,
            src_hidden_dim=512,
            trg_hidden_dim=512,
            batch_size=self.batch_size,
            bidirectional=True,
            pad_token_src=self.pad_token_ind,
            nlayers=2,
            nlayers_trg=1,
            dropout=0.,
        ).cuda()
        
        # Restore from checkpoint if provided.
        if checkpoint:
            self.restore_model(checkpoint)
        
        # Initialize the optimizer.
        self.lr = 0.0002
        self.clip_c = 1
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        
        # Loss criterion.
        weight_mask = torch.ones(self.vocab_size).cuda()
        weight_mask[self.pad_token_ind] = 0
        self.loss_criterion = nn.CrossEntropyLoss(weight=weight_mask).cuda()
        
        # Save the initial model.
        self.save_model()
        
    def clip_gradient(self, model, clip):
        """Compute a gradient clipping coefficient based on gradient norm."""
        totalnorm = 0
        for p in self.model.parameters():
            modulenorm = p.grad.data.norm()
            totalnorm += modulenorm ** 2
        totalnorm = math.sqrt(totalnorm)
        return min(1, clip / (totalnorm + 1e-6))
    
    def get_dataset_minibatch(self, examples, iter_ind, batch_size):
        """
        Iterator over the dataset split and get autoencoder minibatches.
        """
        minibatch = examples[iter_ind:iter_ind+batch_size]
        
        # Create the Pytorch variables.
        input_lines = Variable(torch.LongTensor(np.fliplr(minibatch).copy())).cuda() # Reverse the input lines.
        output_lines = Variable(torch.LongTensor(minibatch)).cuda()
        return input_lines, output_lines
    
    def perplexity(self):
        """Calculate the BLEU score."""
        
    def train(self, dataset, epochs=2, write_checkpoint=1, monitor_loss=1, print_samples=1):
        examples, _ = dataset.examples(split="train", shuffled=True)
        num_examples, max_len = examples.shape
        
        for epoch in xrange(epochs):
            losses = []
            for iter_ind in xrange(0, num_examples, self.batch_size):
                # Get a minibatch.
                input_lines_src, output_lines_src = self.get_dataset_minibatch(examples, iter_ind, self.batch_size)
                
                # Run a training step.
                decoder_logit = self.model(input_lines_src)
                self.optimizer.zero_grad()

                loss = self.loss_criterion(
                    decoder_logit.contiguous().view(-1, self.vocab_size),
                    output_lines_src.view(-1)
                )
                losses.append(loss.data[0])
                loss.backward()
                self.optimizer.step()
                
                if iter_ind % monitor_loss == 0:
                    # TODO(cathywong): change to logging.
                    print('Epoch : %d Minibatch : %d Loss : %.5f' % (epoch, iter_ind, np.mean(losses)))
                    losses = []
                
                if iter_ind % print_samples == 0:
                    # Print samples.
                    word_probs = self.model.decode(decoder_logit).data.cpu().numpy().argmax(axis=-1)
                    output_lines_trg = input_lines_src.data.cpu().numpy()
                    for sentence_pred, sentence_real in zip(word_probs[:5], output_lines_trg[:5]):
                        decoded_real = dataset.vocab_encoder.decode_numpy(sentence_real[::-1])
                        decoded_pred = dataset.vocab_encoder.decode_numpy(sentence_pred)
                        
                        # TODO(cathywong): change to logging.
                        print decoded_pred
                        print decoded_real
                break
            # Write checkpoint.
            if epoch % write_checkpoint == 0:
                self.save_model()
            
        
    def evaluate(self, dataset, split, verbose=True):
        raise Exception("Not implemented")
    
    def save_model(self, 
                   checkpoint_dir='/cvgl2/u/catwong/cs332_final_project/checkpoints',
                   checkpoint_name='seq2seq_autoencoder'):
        # Outputs a path that can be passed into the restore.
        checkpoint_file = checkpoint_name + '.model'
        full_checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
        torch.save(
            self.model.state_dict(),
            open(full_checkpoint_path, 'wb')
        )
        return full_checkpoint_path
    
    def restore_model(self, checkpoint):
        self.model.load_state_dict(torch.load(open(checkpoint)))

# Demo
spam_dataset = SpamDataset()
autoencoder = SpamSeq2SeqAutoencoder(dataset=spam_dataset)
autoencoder.train(spam_dataset)

Epoch : 0 Minibatch : 0 Loss : 8.40566
pa vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent vincent hotat hotat corner found hotat hotat hotat canadian canadian canadian canadian canadian canadian sequoia canadian canadian canadian canadian canadian sequoia canadian canadian canadian vincent hotat canadian canadian canadian 96 sequoia 96 hotat hotat sequoia 96 sequoia hotat hotat hotat canadian sequoia hotat found found hotat hotat hotat hotat edge else edge canadian canadian hotat sequoia corner hotat hotat hotat hotat sequoia sequoia hotat hotat sequoia hotat acy hotat canadian canadian canadian hotat hotat hotat canadian canadian
<SOS> subject: calpine daily gas nomination we are still under the scheduled outage period and will bring the next unit down @ <UNK> saturday 03 / 24 / 01 . the following is our estimated burn until then . thanks > ri