In [2]:
# Enable IPython's autoreload extension (mode 2) so that edits to the local
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import torch

In [5]:
# Check whether PyTorch can see a CUDA device; the value is shown as the cell output.
torch.cuda.is_available()

True

In [6]:
from train import main
from argparse import Namespace
import test
from vocab import Vocab
import numpy as np

In [60]:
# Training configuration handed to train.main() in the next cell.
# Built directly as a Namespace so `args` only ever refers to one kind of object.
args = Namespace(
    # data
    train='data/mnli/train.txt',
    valid='data/mnli/dev.txt',
    # model variant and loss weights
    model_type='aae',
    lambda_adv=10,
    lambda_p=0,
    lambda_kl=0,
    noise=[0.3, 0, 0, 0],
    # checkpointing / schedule
    save_dir='checkpoints/aae',
    epochs=5,
    load_model='',
    # sizes
    vocab_size=50000,
    dim_z=128,
    dim_emb=512,
    dim_h=1024,
    nlayers=1,
    dim_d=512,
    # optimisation
    dropout=0,
    lr=0.0005,
    batch_size=256,
    seed=598,
    log_interval=100,
    no_cuda=False,
)

In [61]:
# Run the training entry point imported from train.py with the `args` Namespace.
# Per the captured log below, this run printed progress lines, ended with
# "saving model" at the end of the epoch and "Done training".
main(args)

Namespace(train='data/mnli/train.txt', valid='data/mnli/dev.txt', model_type='aae', lambda_adv=10, lambda_p=0, lambda_kl=0, noise=[0.3, 0, 0, 0], save_dir='checkpoints/aae', epochs=5, load_model='', vocab_size=50000, dim_z=128, dim_emb=512, dim_h=1024, nlayers=1, dim_d=512, dropout=0, lr=0.0005, batch_size=256, seed=598, log_interval=100, no_cuda=False)
# train sents 392702, tokens 3916049
# valid sents 9815, tokens 97329
# vocab size 50005
# model parameters: 96413782
--------------------------------------------------------------------------------
| epoch   1 |   100/ 1572 batches | rec 87.42, adv 0.97, |lvar| 167.16, loss_d 1.24, loss 97.09,
| epoch   1 |   200/ 1572 batches | rec 85.16, adv 1.14, |lvar| 294.03, loss_d 1.60, loss 96.55,
| epoch   1 |   300/ 1572 batches | rec 79.42, adv 0.74, |lvar| 405.14, loss_d 1.44, loss 86.77,
| epoch   1 |   400/ 1572 batches | rec 85.23, adv 0.71, |lvar| 314.87, loss_d 1.33, loss 92.29,
| epoch   1 |   500/ 1572 batches | rec 73.99, adv 0.61, 

| epoch   5 |   900/ 1572 batches | rec 50.76, adv 0.67, |lvar| 772.25, loss_d 1.42, loss 57.44,
| epoch   5 |  1000/ 1572 batches | rec 47.65, adv 0.71, |lvar| 769.46, loss_d 1.43, loss 54.76,
| epoch   5 |  1100/ 1572 batches | rec 52.04, adv 0.67, |lvar| 752.71, loss_d 1.38, loss 58.73,
| epoch   5 |  1200/ 1572 batches | rec 46.42, adv 0.74, |lvar| 747.61, loss_d 1.46, loss 53.82,
| epoch   5 |  1300/ 1572 batches | rec 49.52, adv 0.67, |lvar| 792.16, loss_d 1.45, loss 56.27,
| epoch   5 |  1400/ 1572 batches | rec 45.99, adv 0.68, |lvar| 763.41, loss_d 1.39, loss 52.82,
| epoch   5 |  1500/ 1572 batches | rec 42.87, adv 0.71, |lvar| 794.37, loss_d 1.41, loss 49.98,
--------------------------------------------------------------------------------
| end of epoch   5 | time   192s | valid rec 42.92, adv 0.73, |lvar| 810.45, loss_d 1.34, loss 50.20, | saving model
Done training


In [99]:
# Vocabulary file stored in the same directory as the checkpoint loaded below.
vocab = Vocab('checkpoints/aae/vocab.txt')
# Same seed value as the training configuration.
test.set_seed(598)
# Pick the device from what is actually available instead of hard-coding 'cuda';
# with a GPU present this is identical to torch.device('cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
def get_model(path, vocab):
    """Load an AAE checkpoint and return the model in evaluation mode.

    Args:
        path: Path to a checkpoint readable by ``torch.load``. The loaded
            object must provide an ``'args'`` entry (passed to the model
            constructor) and a ``'model'`` entry (passed to
            ``load_state_dict``).
        vocab: Vocabulary object passed to ``test.AAE``.

    Returns:
        The ``test.AAE`` instance, moved to ``device``, with the checkpoint
        weights loaded and ``eval()`` applied.

    Note:
        Reads the notebook-level global ``device``; the cell defining it must
        have been executed before this function is called.
    """
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint saved on a different device (e.g. another GPU index) still loads.
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints from a trusted source.
    ckpt = torch.load(path, map_location=device)
    train_args = ckpt['args']
    model = test.AAE(vocab, train_args).to(device)
    model.load_state_dict(ckpt['model'])
    # NOTE(review): presumably compacts the RNN weights after loading; confirm in the model code.
    model.flatten()
    model.eval()
    return model

In [64]:
# Load the checkpoint from checkpoints/aae via get_model(); `vocab` and the global
# `device` must already be defined.
model = get_model('checkpoints/aae/model.pt', vocab)

In [67]:
# Draw 10 latent vectors from a standard normal; 128 matches `dim_z` in the
# training configuration.
z = np.random.normal(size=(10, 128)).astype('f')
sents = []
# Generation is inference only, so skip autograd bookkeeping.
with torch.no_grad():
    # One latent vector per generate() call (batch of size 1).
    for i in range(len(z)):
        zi = torch.tensor(z[i: i+1], device=device)
        # Up to 35 steps, sampling decoder; transposed so that iterating
        # yields one token-id sequence per latent vector.
        outputs = model.generate(zi, 35, 'sample').t()
        for s in outputs:
            sents.append([vocab.idx2word[idx] for idx in s[1:]])  # skip <go>

In [68]:
# Hand the sampled token lists to test.write_sent, targeting a file in the checkpoint directory.
test.write_sent(sents, 'checkpoints/aae/sample.txt')

In [11]:
def encode(sents, vocab, batch_size, model, device, enc='mu'):
    """Encode sentences into latent vectors with ``model.encode``.

    Args:
        sents: Sentences to encode, forwarded to ``test.get_batches``
            (the notebook passes lists of token strings).
        vocab: Vocabulary forwarded to ``test.get_batches``.
        batch_size: Batch size forwarded to ``test.get_batches``.
        model: Model exposing ``encode(inputs) -> (mu, logvar)``.
        device: Device forwarded to ``test.get_batches``.
        enc: ``'mu'`` to use the mean returned by the encoder; any other
            value uses ``test.reparameterize(mu, logvar)`` instead.

    Returns:
        A NumPy array of latent codes. Rows are scattered back through the
        ``order`` index returned by ``test.get_batches`` -- presumably
        restoring the original sentence order (confirm against
        ``get_batches``).

    Raises:
        ValueError: From ``np.concatenate`` if no batch was produced.
    """
    batches, order = test.get_batches(sents, vocab, batch_size, device)
    z = []
    # Inference only: avoid building an autograd graph for every batch.
    with torch.no_grad():
        for inputs, _ in batches:
            mu, logvar = model.encode(inputs)
            zi = mu if enc == 'mu' else test.reparameterize(mu, logvar)
            z.append(zi.detach().cpu().numpy())
    z = np.concatenate(z, axis=0)
    # Scatter the batch-ordered rows back to the positions given by `order`.
    z_ = np.zeros_like(z)
    z_[np.array(order)] = z
    return z_

def decode(z, vocab, batch_size, max_len, model, device, dec='sample'):
    """Decode latent vectors into token lists with ``model.generate``.

    Args:
        z: Array of latent vectors, sliced along its first axis in chunks of
            ``batch_size``.
        vocab: Vocabulary whose ``idx2word`` maps token ids to strings.
        batch_size: Number of latent vectors per ``generate`` call; must be
            positive.
        max_len: Maximum length forwarded to ``model.generate``.
        model: Model exposing ``generate(z, max_len, dec)``.
        device: Device on which each chunk tensor is created.
        dec: Decoding mode forwarded to ``model.generate`` (the notebook uses
            ``'sample'`` and ``'greedy'``).

    Returns:
        The result of ``test.strip_eos`` applied to the decoded token lists,
        one list per latent vector.

    Raises:
        ValueError: If ``batch_size`` is 0 (the original ``while`` loop would
            never terminate in that case).
    """
    sents = []
    # Inference only: no gradients are needed while generating.
    with torch.no_grad():
        for start in range(0, len(z), batch_size):
            zi = torch.tensor(z[start: start+batch_size], device=device)
            # Transposed so that iterating yields one token-id sequence per latent vector.
            outputs = model.generate(zi, max_len, dec).t()
            for s in outputs:
                sents.append([vocab.idx2word[idx] for idx in s[1:]])  # skip <go>
    return test.strip_eos(sents)

In [151]:
# One MNLI-style example; only the hypothesis is encoded below, the premise and
# label are kept for reference.
premise = "Um, I read some of the same books that they had read to me, first, and then, as I got older, I just got hungry for books."
hypothesis = "I lost interest in reading over time."
label = "contradiction"

# Whitespace-tokenise the hypothesis and encode it as a single-sentence batch.
sents = [hypothesis.split()]
z = encode(sents, vocab, 1, model, device)

# Decode n noisy copies of the latent code (Gaussian noise, std 0.2) greedily,
# with a maximum length of 30, printing one decoded sentence per line.
n = 10
for i in range(n):
    noise = np.random.normal(0, 0.2, size=z.shape).astype('f')
    z_noise = z + noise
    decoded = decode(z_noise, vocab, 1, 30, model, device, dec='greedy')
    print(' '.join(decoded[0]))

I got all all over the last year.
I got to put over over the years.
I got all over over the years.
I got all over over the years.
I got plenty of time for it.
I haven't gone up in it.
I got over it over the years.
I haven't talked to get over the time.
I got all over the last year.
I got to put over over time.


## Updated Model

In [9]:
# Vocabulary file stored in the same timestamped directory as the updated checkpoint.
vocab = Vocab('checkpoints/aae-2023-04-11_19-20-46/vocab.txt')
# Same seed value as used for the first model.
test.set_seed(598)
# Pick the device from what is actually available instead of hard-coding 'cuda';
# with a GPU present this is identical to torch.device('cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
# Rebind `model` to the checkpoint from the timestamped run directory; later cells
# use this model instead of the first one.
model = get_model('checkpoints/aae-2023-04-11_19-20-46/model.pt', vocab)

In [40]:
hypothesis = "I don't know how cold it got last night."

# Whitespace-tokenise the hypothesis and encode it as a single-sentence batch.
sents = [hypothesis.split()]
z = encode(sents, vocab, 1, model, device)

# Decode n noisy copies of the latent code (Gaussian noise, std 0.6) greedily,
# with a maximum length of 30, printing one decoded sentence per line.
n = 10
for i in range(n):
    noise = np.random.normal(0, 0.6, size=z.shape).astype('f')
    z_noise = z + noise
    decoded = decode(z_noise, vocab, 1, 30, model, device, dec='greedy')
    print(' '.join(decoded[0]))

I don't know how long it got last in.
They didn't see how long it got last day.
I don't know how cold it went last night.
Jon all saw the paths after just got it.
I don't know how it had gone last night.
The clue how it got cold it last night.
I don't know how long it saw last night.
I don't know how long it last night.
I don't know how it stayed the last night.
I knew how so it was a last night.
