# Sentence Variational Adversarial Active Learning (S-VAAL)
@author: Tyler Bikaun

The following notebook intends to flesh out an initial proof concept of the coupling of S-VAE (Bowman <i>et al.</i> 2016; https://arxiv.org/abs/1511.06349) and VAAL (Sinha <i>et al.</i> 2019; https://arxiv.org/abs/1904.00370)

<b>Application:</b> Named Entity Task (NER)

### Standard Imports

In [57]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.rnn as rnn_utils

### Utility Functions

In [None]:
config = {'': ''}

### Models
<i>Model architectures</i><br>
<b>SVAE</b> - RNN<br>
<b>Discriminator</b> - FC NN<br>
<b>Task Learner</b> - RNN<br>

#### SVAE

In [265]:
class SVAE(nn.Module):
    """ Sentence Variational Autoencoder (Bowman et al. 2016)"""
    
    def __init__(self, vocab_size, embedding_size):
        super(SVAE, self).__init__()
        
        # TODO: fix dodgy vocab_size issue... this will be cleared up when utils implemented properly
        
        self.tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
        
        self.max_sequence_length = 40    # arg
        self.pad_idx = 0
        self.eos_idx = vocab_size + 1
        self.sos_idx = vocab_size + 2
        self.unk_idx = vocab_size + 3
        
        self.vocab_size = vocab_size + 4
        
        self.z_dim = 8
        
        self.rnn_type = 'gru'
        self.bidirectional = False
        self.num_layers = 1
        self.hidden_size = 128
        
        self.embedding = nn.Embedding(self.vocab_size, embedding_size)
        self.word_dropout_rate = 0.1
        self.embedding_dropout = nn.Dropout(p=0.5)
        
        # set rnn type
        if self.rnn_type == 'gru':
            rnn = nn.GRU
        else:
            raise ValueError()
        
        # init encoder-decoder RNNs (models are identical)
        self.encoder_rnn = rnn(embedding_size,
                               self.hidden_size, 
                               num_layers=self.num_layers,
                               bidirectional=self.bidirectional,
                               batch_first=True)
        self.decoder_rnn = rnn(embedding_size,
                               self.hidden_size, 
                               num_layers=self.num_layers,
                               bidirectional=self.bidirectional,
                               batch_first=True)

        self.hidden_factor = (2 if self.bidirectional else 1) * self.num_layers
        
        # Initialisation of FC layers
        # These go from encoder to latent (z) space
        self.hidden2mean = nn.Linear(self.hidden_size * self.hidden_factor, self.z_dim)
        self.hidden2logv = nn.Linear(self.hidden_size * self.hidden_factor, self.z_dim)
        self.z2hidden = nn.Linear(self.z_dim, self.hidden_size * self.hidden_factor)
        self.outputs2vocab = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.vocab_size)
    
    
    def forward(self, input_sequence, length):
        """ Forward pass through VAE """
        
        batch_size = input_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)   # trick for packed padding
        input_sequence = input_sequence[sorted_idx]
        
        # ENCODER
        input_embedding = self.embedding(input_sequence)
#         print(input_embedding.shape)
        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)
        _, hidden = self.encoder_rnn(packed_input)
        
        if self.bidirectional or 1 < self.num_layers:
            # flatten hidden state
            hidden = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
        else:
            # .squeeze() -> Returns a tensor with all the dimensions of input of size 1 removed.
            print(f'hidden shape before squeeze {hidden.shape}')
#             hidden = hidden.squeeze()
            pass
            print(f'hidden shape after squeeze {hidden.shape}')

        
        # Reparameterisation trick!
        z, mean, logv, std = self.reparameterise(hidden, batch_size)
        
        # DECODER
        if 0 < self.word_dropout_rate:
            prob = torch.rand(input_sequence.size())

            if torch.cuda.is_available():
                prob = prob.cuda()

            prob[(input_sequence.data - self.sos_idx) * (input_sequence.data - self.pad_idx) == 0] = 1

            decoder_input_sequence = input_sequence.clone()
            
#             print(vocab_size)
#             print(self.unk_idx)
            decoder_input_sequence[prob < self.word_dropout_rate] = self.unk_idx

#             print(decoder_input_sequence)
            input_embedding = self.embedding(decoder_input_sequence)

        input_embedding = self.embedding_dropout(input_embedding)
        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)
        
        outputs, _ = self.decoder_rnn(packed_input, hidden)
        
        # process outputs
        # Process outputs
        # Unpack padded sequence
        padded_outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
        padded_outputs = padded_outputs.contiguous()
        _, reversed_idx = torch.sort(sorted_idx)
        padded_outputs = padded_outputs[reversed_idx]
        b, s, _ = padded_outputs.size()

        # Project outputs to vocab
        # e.g. project hidden state into label space...
        logp = nn.functional.log_softmax(self.outputs2vocab(padded_outputs.view(-1, padded_outputs.size(2))), dim=-1)
#         print(f'logp before view {logp.shape}\n')
#         print(f'b {b} s {s} no emb {self.embedding.num_embeddings}')
        logp = logp.view(b, s, self.embedding.num_embeddings)

        # logp - log posterior over label space; mean - tensor Gaussian mean, logv - tensor Gaussian variance, z - VAE latent space 
        return logp, mean, logv, z
    
    def to_var(self, x):
        if torch.cuda.is_available():
            x = x.cuda()
        return x
    
    def reparameterise(self, hidden, batch_size):
        """ Implement reparameterisation trick (Kingma and Welling 2014) """
        
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)
        
        z = self.to_var(torch.randn([batch_size, self.z_dim]))
        return z * std + mean, mean, logv, std

In [266]:
# Testing functionality
vocab_size = 100
hidden_size = 128
SVAE = SVAE(vocab_size, hidden_size)

In [267]:
# Test forward pass of SVAE
import random
import numpy as np

# generate some fake sequences with padding
seqs = list()
for i in range(10):
    seq = np.random.randint(low=1, high=100, size=(random.randint(1,40),))
    # add padding
    seq = np.concatenate((seq, np.zeros(shape=(40 - len(seq)))), axis=None)
    seqs.append(seq)
sequences = torch.LongTensor(seqs)
# get lengths of sequences (not including padding)
lengths = torch.tensor([len(seq[seq != 0]) for seq in sequences])

print(f'Shapes - seq {sequences.shape} - lengths {lengths.shape}')
# print(lengths)

# Pass sequences and lengths into SVAE forward method
print(SVAE.forward(sequences, lengths))

Shapes - seq torch.Size([10, 40]) - lengths torch.Size([10])
hidden shape before squeeze torch.Size([1, 10, 128])
hidden shape after squeeze torch.Size([1, 10, 128])
(tensor([[[-4.6944, -4.5930, -4.4858,  ..., -4.8528, -4.4396, -5.0879],
         [-4.5565, -4.8121, -4.3012,  ..., -4.5629, -4.3834, -4.8766],
         [-4.4574, -4.7374, -4.6422,  ..., -4.7135, -4.4288, -4.8301],
         ...,
         [-4.6116, -4.6598, -4.6994,  ..., -4.6540, -4.6051, -4.6548],
         [-4.6116, -4.6598, -4.6994,  ..., -4.6540, -4.6051, -4.6548],
         [-4.6116, -4.6598, -4.6994,  ..., -4.6540, -4.6051, -4.6548]],

        [[-4.3756, -4.4705, -4.8215,  ..., -4.9709, -4.7022, -4.6703],
         [-4.5016, -4.6237, -4.9081,  ..., -4.6778, -4.5357, -4.5366],
         [-4.5677, -4.8624, -4.7604,  ..., -4.3396, -4.5037, -4.5329],
         ...,
         [-4.6116, -4.6598, -4.6994,  ..., -4.6540, -4.6051, -4.6548],
         [-4.6116, -4.6598, -4.6994,  ..., -4.6540, -4.6051, -4.6548],
         [-4.6116, -4.

In [268]:
# Reviewing components
print(f'Encoder: {SVAE.encoder_rnn}')
print(f'Decoder: {SVAE.decoder_rnn}')
print(f'hidden2mean: {SVAE.hidden2mean}')
print(f'hidden2logv: {SVAE.hidden2logv}')
print(f'z2hidden: {SVAE.z2hidden}')
print(f'outputs2vocab: {SVAE.outputs2vocab}')

Encoder: GRU(128, 128, batch_first=True)
Decoder: GRU(128, 128, batch_first=True)
hidden2mean: Linear(in_features=128, out_features=8, bias=True)
hidden2logv: Linear(in_features=128, out_features=8, bias=True)
z2hidden: Linear(in_features=8, out_features=128, bias=True)
outputs2vocab: Linear(in_features=128, out_features=104, bias=True)


#### Task Learner

In [46]:
class TaskLearner(nn.Module):
    """ Task learner for NER """
    
    def __init__(self):
        super(TaskLearner, self).__init__()

In [49]:
# Testing functionality
TaskLearner()

TaskLearner()

#### Discriminator

In [42]:
class Discriminator(nn.Module):
    """ Adversary architecture for discriminator module """
    
    def __init__(self, z_dim=8):
        super(Discriminator, self).__init__()
        
        self.z_dim = z_dim    # latent space dimension
        
        self.net = nn.Sequential(
                                nn.Linear(z_dim, 128),
                                nn.ReLU(True),
                                nn.Linear(128, 128),
                                nn.ReLU(True),
                                nn.Linear(128,1),
                                nn.Sigmoid()
                                )
        
        self.weight_init()
        
    def weight_init(self):
        """ Weight initialisation
        
        Using Xavier uniform initialisation rather than Kaiming (I think that is more focused for CV? TODO: investigate)
        See: https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1ace282f75916a862c9678343dfd4d5ffe.html
        """
        for block in self._modules:
            for m in self._modules[block]:
                if type(m) == nn.Linear:
                    torch.nn.init.xavier_uniform_(m.weight)
                    m.bias.data.fill_(0.01)
    
    def forward(self, z):
        """ Forward pass through discriminator
        
        Arguments
        --------
            z : tensor
                Tensor derived from SVAE latent space
        """
        return self.net(z)

In [56]:
# Testing functionality
dsc = Discriminator()

# Pass random tensor through forward pass of discriminator
rand_tensor = torch.randn((8,))
print(f'Input: {rand_tensor}')
print(f'Output: {dsc.forward(rand_tensor)}')

Input: tensor([ 0.0312,  1.0590,  0.3862, -0.1136, -1.1375,  1.2425, -1.3496, -0.7329])
Output: tensor([0.5398], grad_fn=<SigmoidBackward>)


### Training Routine
<i> Pseudo code</i>
```python
for epoch in max_epochs:
        train(task learner)
            get preds
            calc loss
            zero grads
            backpropigate loss
            update model parameters
        for step in max_steps:
            train(SVAE)
        for step in max_steps:
            train(discriminator)
```

### Inference Routine

### Sampling Routine