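# CS 287 - HW 4 - VAE

Latent-variable mixture model on SNLI: a variational posterior $q(c \mid x)$ chooses which of several pretrained decomposable-attention classifiers explains each example, and is trained with a score-function (REINFORCE) estimate of the ELBO gradient.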

In [1]:
import random
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torchtext
from torchtext.vocab import Vectors, GloVe
from namedtensor import ntorch, NamedTensor
from namedtensor.text import NamedField

from common import *
%reload_ext autoreload

In [2]:
# Load SNLI and build both vocabularies from the training split only.
TEXT = NamedField(names=('seqlen',))            # tokenized sentences: the input $x$
LABEL = NamedField(sequential=False, names=())  # SNLI class: the target $y$

train, val, test = torchtext.datasets.SNLI.splits(TEXT, LABEL)
print('len(train)', len(train))

for _field in (TEXT, LABEL):
    _field.build_vocab(train)
print('len(TEXT.vocab)', len(TEXT.vocab))
print('len(LABEL.vocab)', len(LABEL.vocab))

# BucketIterator groups examples of similar length; repeat=False ends each pass after one epoch.
gpu = torch.device("cuda")
train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(
    (train, val, test),
    batch_size=32,
    device=gpu,
    repeat=False,
)

len(train) 549367
len(TEXT.vocab) 62998
len(LABEL.vocab) 4


In [3]:
# Build the embedding matrix for the vocabulary (Sec 5.1).
# Words missing from GloVe are each assigned, uniformly at random, one of 100 shared
# random vectors with N(0, 1) entries; every row is then rescaled to unit l_2 norm.
SEED = 0  # seeds the OOV vectors and their assignment so re-runs reproduce the embeddings
random.seed(SEED)
torch.manual_seed(SEED)
unk_vectors = [torch.randn(300) for _ in range(100)]  # 300 must match glove.6B.300d
TEXT.vocab.load_vectors(vectors='glove.6B.300d', unk_init=lambda x: random.choice(unk_vectors))
# F.normalize clamps the norm away from zero, so an all-zero row cannot become NaN.
vectors = NamedTensor(F.normalize(TEXT.vocab.vectors, p=2, dim=1), ('word', 'embedding'))
TEXT.vocab.vectors = vectors
print("word embeddings shape:", TEXT.vocab.vectors.shape)
weights = TEXT.vocab.vectors.values.cuda()

word embeddings shape: OrderedDict([('word', 62998), ('embedding', 300)])


In [4]:
# Pull a single batch from the training iterator to sanity-check tensor shapes.
batch = next(iter(train_iter))
shape_report = [
    ("Size of premise batch:", batch.premise.shape),
    ("Size of hypothesis batch:", batch.hypothesis.shape),
    ("Size of label batch:", batch.label.shape),
]
for caption, shape in shape_report:
    print(caption, shape)

Size of premise batch: OrderedDict([('seqlen', 40), ('batch', 128)])
Size of hypothesis batch: OrderedDict([('seqlen', 19), ('batch', 128)])
Size of label batch: OrderedDict([('batch', 128)])


In [7]:
# Model dimensions, read off the embedding matrix and the label vocabulary.
vec_shape = TEXT.vocab.vectors.shape
input_size, embed_size = vec_shape['word'], vec_shape['embedding']
hidden_size1 = 200  # hidden width passed to Decomposable_Attn_Network
output_size = len(LABEL.vocab)
print('DIMS - input: %d, embed: %d, hidden1: %d, output: %d'%(input_size, embed_size, hidden_size1, output_size))

DIMS - input: 62998, embed: 300, hidden1: 200, output: 4


## Latent Variable Mixture Model

In [None]:
class Q_Network(torch.nn.Module):
    """Variational posterior q(c | x) over a fixed set of pretrained classifiers.

    Premise and hypothesis tokens are embedded with frozen vectors, padding
    positions are zeroed, each sentence is sum-pooled, and a linear layer over the
    concatenated pools gives log-probabilities over the K = len(networks) mixture
    components. Only ``self.lmb`` (the linear weight) is trained, using a
    score-function estimate of the ELBO gradient (see ``basic_grad``).

    Args:
        embed_size: width of the rows of ``weights``.
        output_size: number of target labels the component networks predict
            (stored for reference; q itself has one output per component).
        weights: (vocab, embed_size) embedding matrix, kept frozen.
        networks: non-empty list of K frozen classifiers. Each is called as
            ``net(sent1, sent2)`` and is assumed to return (batch, output_size)
            scores (logits or log-probabilities) -- TODO confirm for new nets.
    """

    def __init__(self, embed_size, output_size, weights, networks):
        super(Q_Network, self).__init__()
        self.embed_size = embed_size  # HIDDEN = embed_size
        self.hidden_size = embed_size
        self.output_size = output_size
        # Plain list on purpose: the component networks are frozen and must not
        # appear in this module's parameters or state_dict.
        self.networks = list(networks)
        K = len(self.networks)
        if K == 0:
            raise ValueError('networks must contain at least one model')

        self.embed = nn.Embedding.from_pretrained(weights, freeze=True)
        # Samples of q index self.networks, so q is a distribution over the K
        # components (one logit per component), not over the output labels.
        self.linear = nn.Linear(self.embed_size * 2, K)

        self.lmb = self.linear.weight
        self.m = nn.LogSoftmax(dim=1)

        self.criterion = nn.NLLLoss()
        # Uniform prior p(c) = 1/K.
        self.prior = torch.distributions.categorical.Categorical(probs=torch.full((K,), 1.0 / K))

    def mask(self, sent1, sent2, proj1, proj2, pad_tkn=1):
        """Zero the embeddings at padding positions.

        sent1/sent2 are (batch, seqlen) token ids; proj1/proj2 are the matching
        (batch, seqlen, hidden) embeddings. Returns the masked embeddings.
        """
        keep1 = (sent1 != pad_tkn).unsqueeze(-1).to(proj1.dtype)  # BATCH x SEQLEN x 1
        keep2 = (sent2 != pad_tkn).unsqueeze(-1).to(proj2.dtype)
        return proj1 * keep1, proj2 * keep2

    def forward(self, sent1, sent2, pad_tkn=1):
        """Return log q(c | x) as (batch, K) and cache the distribution in ``self.q``."""
        proj1 = self.embed(sent1)  # BATCH x SEQLEN x HIDDEN
        proj2 = self.embed(sent2)
        score1, score2 = self.mask(sent1, sent2, proj1, proj2, pad_tkn=pad_tkn)
        score_all = torch.cat((score1.sum(dim=1), score2.sum(dim=1)), dim=1)  # BATCH x HIDDEN*2
        output = self.m(self.linear(score_all))  # BATCH x K
        self.q = torch.distributions.categorical.Categorical(logits=output)
        return output

    def basic_grad(self, sent1, sent2, y, N=1):
        """Estimate the ELBO and its gradient w.r.t. ``self.lmb`` by score-function sampling.

        ELBO = E_q[log p(y | x, c) + log p(c) - log q(c | x)], with one component c
        drawn per example in each of N samples. For each sample, the per-example
        ELBO terms weight the gradient of log q(c | x). The average over samples is
        written to ``self.lmb.grad`` and is an ascent direction.

        Args:
            sent1, sent2: (batch, seqlen) token ids.
            y: (batch,) integer labels (assumed LongTensor on the same device).
            N: number of samples, must be >= 1.

        Returns:
            (elbo, lmb_out). elbo is a 0-d tensor: the batch-summed ELBO averaged
            over the N samples. lmb_out has shape lmb.shape + (N,) and holds the
            per-sample gradient estimates.
        """
        if N < 1:
            raise ValueError('N must be >= 1')
        log_q_all = self(sent1, sent2)  # BATCH x K; refreshes self.q for these inputs
        q = self.q
        # The components are frozen, so log p(y | x, c) for every c is computed
        # once rather than once per sample. log_softmax is a no-op on inputs that
        # are already log-probabilities, so it is safe for either output convention.
        with torch.no_grad():
            log_lik = torch.stack(
                [F.log_softmax(net(sent1, sent2), dim=1) for net in self.networks], dim=1
            )  # BATCH x K x OUTPUT
            index = y.view(-1, 1, 1).expand(-1, log_lik.size(1), 1)
            log_lik = log_lik.gather(2, index).squeeze(2)  # BATCH x K
        # Move the prior's log-probs to wherever the model lives (it is not a buffer).
        log_prior = self.prior.logits.to(log_q_all.device)  # K

        elbo_total = log_q_all.new_zeros(())
        lmb_out = self.lmb.new_zeros(tuple(self.lmb.shape) + (N,))
        for i in range(N):
            c = q.sample()  # BATCH
            log_qc = q.log_prob(c)  # BATCH, differentiable w.r.t. lmb
            elbo_i = log_lik.gather(1, c.unsqueeze(1)).squeeze(1) + log_prior[c] - log_qc.detach()
            # retain_graph: the same q graph is differentiated once per sample.
            grad, = torch.autograd.grad((elbo_i * log_qc).sum(), self.lmb, retain_graph=True)
            lmb_out[..., i] = grad
            elbo_total = elbo_total + elbo_i.sum()
        self.lmb.grad = lmb_out.mean(dim=-1)
        return (elbo_total / N).detach(), lmb_out

    # Name used by training_loop.
    get_grad = basic_grad

In [8]:
def training_loop(e, train_iter, test_net, eta=1e-5, n_samples=200):
    """Run one epoch of gradient ascent on the ELBO of ``test_net``.

    Each batch is processed as follows:
    - Premise and hypothesis are transposed to (batch, seqlen) and passed through
      ``prepend_null``.
    - ``test_net.basic_grad`` estimates the ELBO gradient w.r.t. ``test_net.lmb``
      using ``n_samples`` samples.
    - ``lmb`` takes an ascent step of size ``eta``.

    Args:
        e: epoch index, used only for logging.
        train_iter: iterable of SNLI batches.
        test_net: a Q_Network.
        eta: ascent step size.
        n_samples: posterior samples per batch (previously hard-coded to 200).

    Returns:
        The mean per-batch ELBO over the epoch, as a Python float.

    Raises:
        ValueError: if ``train_iter`` yields no batches.
    """
    elbo_total = 0.0
    n_batches = 0
    for ix, batch in enumerate(train_iter):
        sent1 = prepend_null(batch.premise.values.transpose(0, 1))
        sent2 = prepend_null(batch.hypothesis.values.transpose(0, 1))
        target = batch.label.values
        _ = test_net(sent1, sent2)  # forward pass; refreshes test_net.q
        # basic_grad is the estimator Q_Network defines (there is no get_grad on older versions).
        ELBO, lmb_out = test_net.basic_grad(sent1, sent2, target, N=n_samples)
        # Ascent step (we maximize the ELBO); in-place so lmb stays the same Parameter as linear.weight.
        with torch.no_grad():
            test_net.lmb.add_(eta * test_net.lmb.grad)
        elbo_total += float(ELBO)
        n_batches += 1

        if ix % 1000 == 0:
            print('Epoch: {0}, Batch: {1}, ELBO: {2:0.4f}, Grad: {3:0.4f}'.format(
                e, ix, float(ELBO), float(test_net.lmb.grad.abs().mean())))

    if n_batches == 0:
        raise ValueError('train_iter produced no batches')
    # Averaging over the epoch gives a far less noisy model-selection signal than the last batch alone.
    return elbo_total / n_batches

In [9]:
def validation_loop(e, val_iter, networks, criterion):
    """Evaluate the uniform mixture p(y | x) = (1/K) * sum_c p(y | x, c) on ``val_iter``.

    Args:
        e: epoch index, used only for logging.
        val_iter: iterable of SNLI batches.
        networks: non-empty list of K classifiers called as ``net(sent1, sent2)``
            and assumed to return (batch, output) logits or log-probabilities.
        criterion: loss applied to the mixture's (batch, output) log-probabilities
            and the targets (e.g. nn.NLLLoss with mean reduction).

    Returns:
        The summed (per-sentence weighted) validation loss.

    Raises:
        ValueError: if ``networks`` is empty or ``val_iter`` yields no batches.
    """
    import math

    K = len(networks)
    if K == 0:
        raise ValueError('networks must contain at least one model')
    was_training = [network.training for network in networks]
    for network in networks:
        network.eval()
    total_loss = 0.0
    total_sent = 0
    total_correct = 0

    try:
        with torch.no_grad():
            for batch in val_iter:
                sent1 = prepend_null(batch.premise.values.transpose(0, 1))
                sent2 = prepend_null(batch.hypothesis.values.transpose(0, 1))
                target = batch.label.values

                # log of the mean component probability, computed stably in log space.
                log_probs = torch.stack(
                    [F.log_softmax(network(sent1, sent2), dim=1) for network in networks]
                )  # K x BATCH x OUTPUT
                output = torch.logsumexp(log_probs, dim=0) - math.log(K)

                sent = sent1.shape[0]
                total_loss += criterion(output, target).item() * sent
                total_sent += sent
                total_correct += (output.argmax(dim=1) == target).sum().item()
    finally:
        # Restore each network's original train/eval mode.
        for network, mode in zip(networks, was_training):
            network.train(mode)

    if total_sent == 0:
        raise ValueError('val_iter produced no batches')
    print('Epoch: {0}, Val NLL: {1:0.4f}, Val Acc: {2:0.4f}'.format(e, total_loss/total_sent, total_correct/total_sent))
    return total_loss

In [None]:
def load_pretrained_network(path):
    """Build a Decomposable_Attn_Network on the GPU and load the weights saved at ``path``."""
    net = Decomposable_Attn_Network(input_size, embed_size, hidden_size1, output_size, weights).cuda()
    net.load_state_dict(torch.load(path))
    # The mixture components are pretrained and only score examples from here on,
    # so put them in eval mode (disables any dropout while they are scored).
    net.eval()
    return net

# One pretrained component per checkpoint; the loop replaces three copy-pasted stanzas.
networks = [load_pretrained_network('best_FFA_net%d.pt' % k) for k in (1, 2, 3)]
FFA_net1, FFA_net2, FFA_net3 = networks

In [None]:
N_EPOCHS = 100
best_elbo = float('-inf')
# The frozen embedding weights are already on the GPU, so the linear head must be too.
test_net = Q_Network(embed_size, output_size, weights, networks).cuda()

for e in range(N_EPOCHS):
    elbo = training_loop(e, train_iter, test_net)
    # Keep the checkpoint with the highest epoch ELBO seen so far.
    if elbo > best_elbo:
        torch.save(test_net.state_dict(), 'best_q.pt')
        best_elbo = elbo
        print('WROTE MODEL')