# CS 287 - HW 4 - VAE

In [1]:
import random
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torchtext
from torchtext.vocab import Vectors, GloVe
from namedtensor import ntorch, NamedTensor
from namedtensor.text import NamedField

from common import *
%reload_ext autoreload

In [2]:
# load data
TEXT = NamedField(names=('seqlen',)) # Our input $x$
LABEL = NamedField(sequential=False, names=()) # Our labels $y$
train, val, test = torchtext.datasets.SNLI.splits(TEXT, LABEL)
print('len(train)', len(train))
TEXT.build_vocab(train)
LABEL.build_vocab(train)
print('len(TEXT.vocab)', len(TEXT.vocab))
print('len(LABEL.vocab)', len(LABEL.vocab))
train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(
    (train, val, test), batch_size=32, device=torch.device("cuda"), repeat=False)

len(train) 549367
len(TEXT.vocab) 62998
len(LABEL.vocab) 4


In [3]:
# build the vocabulary with word embeddings
# out-of-vocabulary words are hashed to one of 100 random embeddings each initialized to mean 0, stdev 1 (Sec 5.1)
unk_vectors = [torch.randn(300) for _ in range(100)]
TEXT.vocab.load_vectors(vectors='glove.6B.300d', unk_init=lambda x:random.choice(unk_vectors))
vectors = TEXT.vocab.vectors
vectors = vectors / vectors.norm(dim=1, keepdim=True) # normalized to have l_2 norm of 1
vectors = NamedTensor(vectors, ('word', 'embedding'))
TEXT.vocab.vectors = vectors
print("word embeddings shape:", TEXT.vocab.vectors.shape)
weights = TEXT.vocab.vectors.values.cuda()

word embeddings shape: OrderedDict([('word', 62998), ('embedding', 300)])


In [4]:
# here's an example of a training example
batch = next(iter(train_iter))
print("Size of premise batch:", batch.premise.shape)
print("Size of hypothesis batch:", batch.hypothesis.shape)
print("Size of label batch:", batch.label.shape)

Size of premise batch: OrderedDict([('seqlen', 33), ('batch', 32)])
Size of hypothesis batch: OrderedDict([('seqlen', 12), ('batch', 32)])
Size of label batch: OrderedDict([('batch', 32)])


In [5]:
# dimensions
input_size = TEXT.vocab.vectors.shape['word']
embed_size = TEXT.vocab.vectors.shape['embedding']
hidden_size1 = 200
output_size = len(LABEL.vocab)
print('DIMS - input: %d, embed: %d, hidden1: %d, output: %d'%(input_size, embed_size, hidden_size1, output_size))

DIMS - input: 62998, embed: 300, hidden1: 200, output: 4


## Latent Variable Mixture Model

In [110]:
class Q_Network(torch.nn.Module):
    def __init__(self, embed_size, output_size, weights, networks):
        super(Q_Network, self).__init__()
        self.embed_size = embed_size # HIDDEN = embed_size
        self.K = len(networks)
        self.output_size = output_size
        self.weights = weights
        self.networks = networks
        
        self.embed = nn.Embedding.from_pretrained(self.weights, freeze=True)
        self.linear = nn.Linear(self.embed_size * 2, self.K, bias=False)
        
        self.lmb = self.linear.weight
        self.m = nn.LogSoftmax(dim=1)
        
        self.criterion = nn.NLLLoss(reduction='sum')
        self.probs = torch.tensor(1/self.K, device='cuda').repeat(self.K)
        self.prior = torch.distributions.categorical.Categorical(probs=self.probs)
        
    def mask(self, sent1, sent2, proj1, proj2, pad_tkn=1):
        mask1 = (sent1 == pad_tkn) # BATCH x SEQLEN
        mask2 = (sent2 == pad_tkn)
        mask1a = mask1.unsqueeze(2).expand(-1, -1, self.embed_size).float() # BATCH x SEQLEN x HIDDEN
        mask2a = mask2.unsqueeze(2).expand(-1, -1, self.embed_size).float()
        score1 = proj1 * (1 - mask1a) # BATCH x SEQLEN x HIDDEN
        score2 = proj2 * (1 - mask2a)
        return score1, score2
    
    def forward(self, sent1, sent2, pad_tkn = 1):
        proj1 = self.embed(sent1) # BATCH x SEQLEN x HIDDEN
        proj2 = self.embed(sent2)
        score1, score2 = self.mask(sent1, sent2, proj1, proj2) # BATCH x SEQLEN x HIDDEN
        score1_sum = torch.sum(score1, dim=1) # BATCH x HIDDEN
        score2_sum = torch.sum(score2, dim=1)
        score_all = torch.cat((score1_sum, score2_sum), dim=1) # BATCH x HIDDEN*2
        output = self.m(self.linear(score_all)) # BATCH x K
        self.output = output
        self.q = torch.distributions.categorical.Categorical(logits=output)
        return output
    
    def run_networks(self, sent1, sent2):
        y_hats = torch.zeros((self.K, sent1.shape[0], self.output_size), device='cuda')
        for c in range(self.K):
            net = self.networks[c]
            y_hat = net(sent1, sent2)
            y_hats[c,:,:] = y_hat
        return y_hats
    
    def get_grad(self, sent1, sent2, y, N=1):
        ELBO = torch.zeros(1)
        lmb_out = torch.zeros((self.K, self.embed_size * 2, N))
        y_hats = self.run_networks(sent1, sent2)
        q = self.q
        for i in range(N):      
            print(i)
            # zero grads
            try:
                self.lmb.grad.data.zero_()
            except AttributeError:
                pass
            # sample, forward pass
            c = q.sample()
            y_hat = torch.stack([torch.index_select(y_hats[:,i,:],0,c[i]) for i in range(sent1.shape[0])]).squeeze()
            # calc ELBO, dELBO for the sample
            loss = self.criterion(self.output, c)#q.log_prob(c).sum()
            loss.backward()
            nll = self.criterion(y_hat, y)
            priorlp = self.prior.log_prob(c).sum()
            ELBO_ = priorlp - nll - loss
            ELBO.add_(ELBO_)
            lmb_out[:,:,i] = self.lmb.grad.data.mul(ELBO_)
        ELBO.div_(N)
        self.lmb.grad.data = lmb_out.mean(dim=-1)
        return ELBO.sum(), lmb_out

In [111]:
def training_loop(e, train_iter, test_net, eta = 1e-5):
    for ix,batch in enumerate(train_iter):
        sent1 = prepend_null(batch.premise.values.transpose(0,1))
        sent2 = prepend_null(batch.hypothesis.values.transpose(0,1))
        target = batch.label.values
        _ = test_net(sent1, sent2) # forward pass
        ELBO, lmb_out = test_net.get_grad(sent1, sent2, target, N=200) # get grad
        test_net.lmb.data = test_net.lmb.data + (eta * test_net.lmb.grad.data) # step
        
        if ix % 1000 == 0:
            print('Epoch: {0}, Batch: {1}, ELBO: {2:0.4f}, Grad: {3:0.4f}'.format(e, ix, ELBO.data, test_net.lmb.grad.data.abs().mean()))

    return ELBO.data

In [8]:
'''def validation_loop(e, val_iter, networks, criterion):
    K = len(networks)
    for network in networks:
        network.eval()
    total_loss = 0
    total_sent = 0
    total_correct = 0
    
    for ix,batch in enumerate(val_iter):
        sent1 = prepend_null(batch.premise.values.transpose(0,1))
        sent2 = prepend_null(batch.hypothesis.values.transpose(0,1))
        target = batch.label.values
        
        output = torch.zeros(output_size)
        for c in range(K):
            network = networks[c]
            output += F.softmax(network(sent1, sent2), dim=1) # BATCH x OUTPUT_SIZE
        output = torch.log(output) + torch.log(torch.tensor(1/K))
        
        loss = criterion(output, target).item()
        sent = sent1.shape[0]
        correct = torch.sum(torch.argmax(output, dim=1) == target).item()
        
        total_loss += loss*sent
        total_sent += sent
        total_correct += correct
    
    print('Epoch: {0}, Val NLL: {1:0.4f}, Val Acc: {2:0.4f}'.format(e, total_loss/total_sent, total_correct/total_sent))
    return total_loss'''

"def validation_loop(e, val_iter, networks, criterion):\n    K = len(networks)\n    for network in networks:\n        network.eval()\n    total_loss = 0\n    total_sent = 0\n    total_correct = 0\n    \n    for ix,batch in enumerate(val_iter):\n        sent1 = prepend_null(batch.premise.values.transpose(0,1))\n        sent2 = prepend_null(batch.hypothesis.values.transpose(0,1))\n        target = batch.label.values\n        \n        output = torch.zeros(output_size)\n        for c in range(K):\n            network = networks[c]\n            output += F.softmax(network(sent1, sent2), dim=1) # BATCH x OUTPUT_SIZE\n        output = torch.log(output) + torch.log(torch.tensor(1/K))\n        \n        loss = criterion(output, target).item()\n        sent = sent1.shape[0]\n        correct = torch.sum(torch.argmax(output, dim=1) == target).item()\n        \n        total_loss += loss*sent\n        total_sent += sent\n        total_correct += correct\n    \n    print('Epoch: {0}, Val NLL: {1:0.

In [9]:
FFA_net1 = Decomposable_Attn_Network(input_size, embed_size, hidden_size1, output_size, weights).cuda()
state_dict1 = torch.load('best_FFA_net0.pt')
FFA_net1.load_state_dict(state_dict1)

FFA_net2 = Decomposable_Attn_Network(input_size, embed_size, hidden_size1, output_size, weights).cuda()
state_dict2 = torch.load('best_FFA_net1.pt')
FFA_net2.load_state_dict(state_dict2)

FFA_net3 = Decomposable_Attn_Network(input_size, embed_size, hidden_size1, output_size, weights).cuda()
state_dict3 = torch.load('best_FFA_net2.pt')
FFA_net3.load_state_dict(state_dict3)

networks = [FFA_net1, FFA_net2, FFA_net3]

In [112]:
best_elbo = -1e32
K = len(networks)
test_net = Q_Network(embed_size, output_size, weights, networks).cuda()

for e in range(100):
    elbo = training_loop(e, train_iter, test_net)
    if elbo > best_elbo:
        torch.save(test_net.state_dict(), 'best_q.pt')
        best_elbo = elbo
        print('WROTE MODEL')

0
1


RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

In [24]:
temp = torch.distributions.categorical.Categorical(probs=torch.tensor([[0.3,0.4,0.3],[0.3,0.4,0.3]]))

In [25]:
sampl = torch.tensor([1,2])
lp = temp.log_prob(sampl)

In [28]:
lp.device

device(type='cpu')

In [82]:
test_net.lmb.shape

torch.Size([3, 600])