# Sentence Variational Adversarial Active Learning (S-VAAL)
@author: Tyler Bikaun

The following notebook intends to flesh out an initial proof concept of the coupling of S-VAE (Bowman <i>et al.</i> 2016; https://arxiv.org/abs/1511.06349) and VAAL (Sinha <i>et al.</i> 2019; https://arxiv.org/abs/1904.00370)

<b>Application:</b> Named Entity Task (NER)

<b>Architecture Diagram:</b>

### Standard Imports

In [1]:
import random
import numpy as np
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.rnn as rnn_utils
import torch.optim as optim

Tensor = torch.Tensor
torch.manual_seed(1)

<torch._C.Generator at 0x17ad32895e8>

## Testing Functions
These functions are used for spot testing code whilst developing. For example, building random sequences of tensors.
- [x] Build artificial sequence generator

In [2]:
class Tester:
    def __init__(self):
        self.pad_idx = 0
        self.special_chars_list = [self.pad_idx]
        self.no_output_classes = 4
        self.label_space_size = self.no_output_classes + len(self.special_chars_list)
        
        print(f'Using label space size of: {self.label_space_size}')

        
    def build_sequences(self, batch_size: int, max_seq_len: int) -> Tensor:
        """
        Builds tensor of specified size containing variable length, padded, sequences of integers
            
        Arguments
        ---------
            batch_size : int
                Number of sequences to generate
            max_seq_len : int
                Maximum length of sequences
        Returns
        -------
            sequences : tensor
                Tensor of generated sequences
            lengths : tensor
                Tensor of sequence lengths
        """
        seqs = list()
        for i in range(batch_size):
            # Generate random integer sequences
            # sequence must be at least 1 token long...
            seq = np.random.randint(low=1, high=100, size=(random.randint(1, max_seq_len),))
            # Add padding
            seq = np.concatenate((seq, np.ones(shape=(max_seq_len - len(seq)))*self.pad_idx), axis=None)
            seqs.append(seq)
        sequences = torch.LongTensor(seqs)
        lengths = torch.tensor([len(seq[seq != self.pad_idx]) for seq in sequences])
        
        print(f'Shapes - seq {sequences.shape} - lengths {lengths.shape}')
        
        return sequences, lengths
    
    def build_sequence_tags(self, sequences: Tensor, lengths: Tensor) -> Tensor:
        """
        Given a set of sequences, generates ground truth labels
        
        Labels need to be non-zero (otherwise get confused with special characters; currnetly only concerned about 0 = PAD)
        
        Arguments
        ---------
            sequences : tensor
                Tensor of generated sequences
            label_space_size : int
                Size of label space
        Returns
        -------
            X, lengths, y : list of tuples
                Artificial ground truth dataset
                    X dim : (seq len, batch size )
                    lengths dim : (batch size)
                    y dim : (batch size, 1)
        """
        
        dataset = list()    # stores batch of data (X, lens, y)
        
        global_label_list = list()
        
        for sequence in sequences:
            # Each 'token' in the sequence has a label mapping
            label_list = list()
            for token in sequence:
                if token != self.pad_idx:   # don't give a label to any padding...
                    label_list.append(random.randint(1, self.label_space_size-1))   # need to minus 1 as output loss function indexes from 0 to n_class - 1
                else:
                    label_list.append(self.pad_idx)
            
            global_label_list.append(torch.LongTensor(label_list))
        
        global_label_tensor = torch.stack(global_label_list)
        dataset.append((sequences, lengths, global_label_tensor))   # stack list of labels into tensors
        return dataset

In [3]:
# Test functionality
tester = Tester()
sequences, lengths = tester.build_sequences(batch_size=1, max_seq_len=10)
dataset = tester.build_sequence_tags(sequences=sequences, lengths=lengths)

Using label space size of: 5
Shapes - seq torch.Size([1, 10]) - lengths torch.Size([1])


In [4]:
for X, length, y in dataset:
    print(f'X {X.shape} ({X.dtype}) - {X}')
    print(f'y {y.shape} ({y.dtype}) - {y}')
    print(f'lengths {length.shape} ({length.dtype}) - {length}')

X torch.Size([1, 10]) (torch.int64) - tensor([[70, 26,  7, 12, 18,  0,  0,  0,  0,  0]])
y torch.Size([1, 10]) (torch.int64) - tensor([[3, 2, 3, 2, 1, 0, 0, 0, 0, 0]])
lengths torch.Size([1]) (torch.int64) - tensor([5])


### Utility Functions

- [ ] Build data preprocessor
- [ ] Build data loaders

In [5]:
# configuration for model building, training, evaluation process. This will be converted into yaml.
config = {'': ''}

### Models
<i>Model architectures</i><br>
<b>SVAE</b> - RNN<br>
<b>Discriminator</b> - FC NN<br>
<b>Task Learner</b> - RNN<br>
- [ ] SVAE
- [ ] Discriminator
- [ ] Task Learner

#### SVAE
To do:
 - [ ] 

In [6]:
class SVAE(nn.Module):
    """ Sentence Variational Autoencoder (Bowman et al. 2016)"""
    
    def __init__(self, vocab_size, embedding_size):
        super(SVAE, self).__init__()
        
        # TODO: fix dodgy vocab_size issue... this will be cleared up when utils implemented properly
        
        self.tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
        
        self.max_sequence_length = 40    # arg
        self.pad_idx = 0
        self.eos_idx = vocab_size + 1
        self.sos_idx = vocab_size + 2
        self.unk_idx = vocab_size + 3
        
        self.vocab_size = vocab_size + 4
        
        self.z_dim = 8
        
        self.rnn_type = 'gru'
        self.bidirectional = False
        self.num_layers = 1
        self.hidden_size = 128
        
        self.embedding = nn.Embedding(self.vocab_size, embedding_size)
        self.word_dropout_rate = 0.1
        self.embedding_dropout = nn.Dropout(p=0.5)
        
        # set rnn type
        if self.rnn_type == 'gru':
            rnn = nn.GRU
        else:
            raise ValueError()
        
        # init encoder-decoder RNNs (models are identical)
        self.encoder_rnn = rnn(embedding_size,
                               self.hidden_size, 
                               num_layers=self.num_layers,
                               bidirectional=self.bidirectional,
                               batch_first=True)
        self.decoder_rnn = rnn(embedding_size,
                               self.hidden_size, 
                               num_layers=self.num_layers,
                               bidirectional=self.bidirectional,
                               batch_first=True)

        self.hidden_factor = (2 if self.bidirectional else 1) * self.num_layers
        
        # Initialisation of FC layers
        # These go from encoder to latent (z) space
        self.hidden2mean = nn.Linear(self.hidden_size * self.hidden_factor, self.z_dim)
        self.hidden2logv = nn.Linear(self.hidden_size * self.hidden_factor, self.z_dim)
        self.z2hidden = nn.Linear(self.z_dim, self.hidden_size * self.hidden_factor)
        self.outputs2vocab = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.vocab_size)
        
        # init partial loss function
        self.NLL = nn.NLLLoss(ignore_index=self.pad_idx, reduction='sum')   # TODO: REVIEW args
    
    
    def forward(self, input_sequence, length):
        """ Forward pass through VAE """
        
        batch_size = input_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)   # trick for packed padding
        input_sequence = input_sequence[sorted_idx]
        
        # ENCODER
        input_embedding = self.embedding(input_sequence)
#         print(input_embedding.shape)
        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)
        _, hidden = self._encode(packed_input)
        
        if self.bidirectional or 1 < self.num_layers:
            # flatten hidden state
            hidden = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
        else:
            # .squeeze() -> Returns a tensor with all the dimensions of input of size 1 removed.
            print(f'hidden shape before squeeze {hidden.shape}')
#             hidden = hidden.squeeze()   # doesn't work? gives wrong dimension down stream...
            pass
            print(f'hidden shape after squeeze {hidden.shape}')

        
        # Reparameterisation trick!
        z, mean, logv, std = self.reparameterise(hidden, batch_size)
        
        # DECODER
        if 0 < self.word_dropout_rate:
            prob = torch.rand(input_sequence.size())

            if torch.cuda.is_available():
                prob = prob.cuda()

            prob[(input_sequence.data - self.sos_idx) * (input_sequence.data - self.pad_idx) == 0] = 1

            decoder_input_sequence = input_sequence.clone()
            
#             print(vocab_size)
#             print(self.unk_idx)
            decoder_input_sequence[prob < self.word_dropout_rate] = self.unk_idx

#             print(decoder_input_sequence)
            input_embedding = self.embedding(decoder_input_sequence)

        input_embedding = self.embedding_dropout(input_embedding)
        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)
        
        outputs, _ = self._decode(packed_input, hidden)
        
        # process outputs
        # Process outputs
        # Unpack padded sequence
        padded_outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
        padded_outputs = padded_outputs.contiguous()
        _, reversed_idx = torch.sort(sorted_idx)
        padded_outputs = padded_outputs[reversed_idx]
        b, s, _ = padded_outputs.size()

        # Project outputs to vocab
        # e.g. project hidden state into label space...
        logp = nn.functional.log_softmax(self.outputs2vocab(padded_outputs.view(-1, padded_outputs.size(2))), dim=-1)
#         print(f'logp before view {logp.shape}\n')
#         print(f'b {b} s {s} no emb {self.embedding.num_embeddings}')
        logp = logp.view(b, s, self.embedding.num_embeddings)

        # logp - log posterior over label space; mean - tensor Gaussian mean, logv - tensor Gaussian variance, z - VAE latent space 
        return logp, mean, logv, z
    
    def to_var(self, x):
        if torch.cuda.is_available():
            x = x.cuda()
        return x
    
    def kl_anneal_function(self, anneal_function, step, k, x0):
        """
        
        
        """
        if anneal_function == 'logistic':
            return float(1/(1+np.exp(-k*(step-x0))))
        elif anneal_function == 'linear':
            return min(1, step/x0)
        
    def loss_fn(self, logp, target, mean, logv, anneal_function, step, k, x0):
        """
        SVAE loss function
        
        NLL - Negative Log Likelihood loss between predicted tags and ground truth
        KL_Loss - Evidence Lower Bound (ELBO)? - TODO: validate interpretation
        
        """
        
        # insert assertion
        
        # ALL THE CODE BELOW NEEDS REVIEW and piece wise implemented
        # Cut-off unnecessary padding from target and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))
        
        # Negative log likelihood
        NLL_loss = self.NLL(logp, target)
        
        # KL Divergence (or ELBO? TODO: determine which)
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = self.kl_anneal_function(anneal_function, step, k, x0)
        
        return NLL_loss, KL_loss, KL_weight
    
    def reparameterise(self, hidden, batch_size):
        """ Implement reparameterisation trick (Kingma and Welling 2014) """
        
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv) 
        
        z = self.to_var(torch.randn([batch_size, self.z_dim]))
        return z * std + mean, mean, logv, std
    
    def _encode(self, x):
        """ x - pack padded sequence """
        return self.encoder_rnn(x)
    
    def _decode(self, x, hidden):
        """ x - pack padded sequence
            hidden - latent tensor"""
        return self.decoder_rnn(x, hidden)

In [7]:
# Testing functionality
vocab_size = 100
hidden_size = 128
svae = SVAE(vocab_size, hidden_size)

In [8]:
# Test forward pass of SVAE
svae_ff_seqs, svae_ff_lengths = Tester().build_sequences(batch_size=10, max_seq_len=40)

# Pass sequences and lengths into SVAE forward method
print(svae.forward(svae_ff_seqs, svae_ff_lengths))

Using label space size of: 5
Shapes - seq torch.Size([10, 40]) - lengths torch.Size([10])
hidden shape before squeeze torch.Size([1, 10, 128])
hidden shape after squeeze torch.Size([1, 10, 128])
(tensor([[[-4.5983, -4.4916, -4.3991,  ..., -4.5816, -4.7353, -4.3919],
         [-4.6237, -4.5505, -4.5878,  ..., -4.7848, -4.7002, -4.6800],
         [-4.5375, -4.6115, -4.8641,  ..., -5.0021, -5.1341, -4.5905],
         ...,
         [-4.6752, -4.6746, -4.5652,  ..., -4.5551, -4.6831, -4.5759],
         [-4.6752, -4.6746, -4.5652,  ..., -4.5551, -4.6831, -4.5759],
         [-4.6752, -4.6746, -4.5652,  ..., -4.5551, -4.6831, -4.5759]],

        [[-4.4114, -4.4459, -4.8041,  ..., -4.6162, -4.9344, -4.9172],
         [-4.5462, -4.7396, -4.5475,  ..., -4.7223, -4.8348, -4.5941],
         [-4.5065, -4.7213, -4.5215,  ..., -4.5295, -5.0201, -4.8757],
         ...,
         [-4.9214, -4.6634, -4.6963,  ..., -4.5376, -4.7926, -4.6830],
         [-4.6752, -4.6746, -4.5652,  ..., -4.5551, -4.6831, -4.

In [9]:
# Reviewing components
print(f'Encoder: {svae.encoder_rnn}')
print(f'Decoder: {svae.decoder_rnn}')
print(f'hidden2mean: {svae.hidden2mean}')
print(f'hidden2logv: {svae.hidden2logv}')
print(f'z2hidden: {svae.z2hidden}')
print(f'outputs2vocab: {svae.outputs2vocab}')

Encoder: GRU(128, 128, batch_first=True)
Decoder: GRU(128, 128, batch_first=True)
hidden2mean: Linear(in_features=128, out_features=8, bias=True)
hidden2logv: Linear(in_features=128, out_features=8, bias=True)
z2hidden: Linear(in_features=8, out_features=128, bias=True)
outputs2vocab: Linear(in_features=128, out_features=104, bias=True)


#### Task Learner
Using PyTorch tutorial implementation (https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html) for rapid development (will implement SoTA in the future)<br><br>
To do:
 - [ ] Make batch based to suit other models rather than training on singletons

In [10]:
class TaskLearner(nn.Module):
    """ Task learner for NER """
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super(TaskLearner, self).__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        
        # The LSTM takes word embeddings as inputs, and outputs hidden states with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)    # expects input (batch, seq, feature)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, input_sequences, lengths):
        """
        Pack padded sequence and sorted lengths used to optimise forward pass through RNN
        
        
        input_sequence : tensor
            (batch, seq, feature)
        
        """

        batch_size = input_sequences.size(0)
        sorted_lengths, sorted_idx = torch.sort(lengths, descending=True)
        input_sequences = input_sequences[sorted_idx]
        input_embeddings = self.word_embeddings(input_sequences)
        
        packed_input = rnn_utils.pack_padded_sequence(input_embeddings, sorted_lengths.data.tolist(), batch_first=True)
        
        lstm_out, _ = self.lstm(packed_input)
        
        # Unpack padded sequence
        padded_outputs = rnn_utils.pad_packed_sequence(lstm_out, batch_first=True)[0]
        padded_outputs = padded_outputs.contiguous()
        _, reversed_idx = torch.sort(sorted_idx)
        padded_outputs = padded_outputs[reversed_idx]
        b, s, _ = padded_outputs.size()
        
        # project into label space
        tag_space = self.hidden2tag(padded_outputs.view(-1, padded_outputs.size(2)))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

In [11]:
# Testing functionality
tasklearner = TaskLearner(embedding_dim=128, hidden_dim=128, vocab_size=104, tagset_size=4)
# Generating data for testing
seqs, lens = tester.build_sequences(batch_size=32, max_seq_len=10)   # Tester initialised at start of notebook
dataset = tester.build_sequence_tags(sequences=seqs, lengths=lens)

Shapes - seq torch.Size([32, 10]) - lengths torch.Size([32])


In [12]:
# Generate vocab off of generated sequences
# TODO: Add to utility functions / tester class
vocab = list()
for seq in seqs:
      vocab.extend(seq.tolist())
vocab = list(set(vocab))
print(vocab)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 23, 24, 27, 28, 29, 30, 32, 33, 34, 35, 36, 37, 39, 40, 41, 42, 43, 45, 46, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 64, 66, 67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 80, 81, 82, 83, 84, 85, 87, 88, 89, 90, 91, 93, 94, 96, 97, 99]


In [16]:
# Test training routine
# vocab_size is max int in vocab ints + 1 as 0 is included...
# tagset size is wrong...
model = TaskLearner(embedding_dim=128, hidden_dim=128, vocab_size=max(vocab)+1, tagset_size=6)   # tagset size = special characters + number of class types (0 to n_class - 1)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

In [17]:
for epoch in range(5):
    for batch_seqs, batch_lens, batch_tags in dataset:
        print(f'\nEpoch: {epoch}')
        print(f'Shapes | Seq: {batch_seqs.shape} Lengths: {batch_lens.shape} Tags: {batch_tags.shape}')

        model.zero_grad()

        # Get max length of longest sequence in batch so it can be used to filter tags
        sorted_lengths, sorted_idx = torch.sort(batch_lens, descending=True)   # longest seq at index 0
        longest_seq = sorted_lengths[0].data.numpy()
        longest_seq_len = longest_seq[longest_seq != 0][0]   # remove padding (TODO: change to pad_idx in the future)
        
        # Get predictions from model
        tag_scores = model(batch_seqs, batch_lens)
        
        # Strip off as much padding as possible similar to (variable length sequences via pack padded methods)
        batch_tags = torch.stack([tags[:longest_seq_len] for tags in batch_tags])
        batch_tags = batch_tags.view(-1)
        
        # Calculate loss and backpropigate error through model
        loss = loss_function(tag_scores, batch_tags)
        loss.backward()
        optimizer.step()
        
        print(f'NLL Loss: {loss.data.detach():0.2f}')


Epoch: 0
Shapes | Seq: torch.Size([32, 10]) Lengths: torch.Size([32]) Tags: torch.Size([32, 10])
NLL Loss: 1.80

Epoch: 1
Shapes | Seq: torch.Size([32, 10]) Lengths: torch.Size([32]) Tags: torch.Size([32, 10])
NLL Loss: 1.78

Epoch: 2
Shapes | Seq: torch.Size([32, 10]) Lengths: torch.Size([32]) Tags: torch.Size([32, 10])
NLL Loss: 1.77

Epoch: 3
Shapes | Seq: torch.Size([32, 10]) Lengths: torch.Size([32]) Tags: torch.Size([32, 10])
NLL Loss: 1.76

Epoch: 4
Shapes | Seq: torch.Size([32, 10]) Lengths: torch.Size([32]) Tags: torch.Size([32, 10])
NLL Loss: 1.75


#### Discriminator
To do:
 - [ ] Implement...

In [18]:
class Discriminator(nn.Module):
    """ Adversary architecture for discriminator module 
    
    Tensor needs to be FloatTensor; sequences and tag sets are LongTensor
    """
    
    def __init__(self, z_dim=8):
        super(Discriminator, self).__init__()
        
        self.z_dim = z_dim    # latent space dimension (will be the same shape as the encoder output from VAE)
        
        self.net = nn.Sequential(
                                nn.Linear(z_dim, 128),
                                nn.ReLU(True),
                                nn.Linear(128, 128),
                                nn.ReLU(True),
                                nn.Linear(128,1),
                                nn.Sigmoid()
                                )
        # Exe
        self.weight_init()
        
    def weight_init(self):
        """ Weight initialisation
        
        Using Xavier uniform initialisation rather than Kaiming (I think that is more focused for CV? TODO: investigate)
        See: https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1ace282f75916a862c9678343dfd4d5ffe.html
        """
        for block in self._modules:
            for m in self._modules[block]:
                if type(m) == nn.Linear:
                    torch.nn.init.xavier_uniform_(m.weight)
                    m.bias.data.fill_(0.01)
    
    def forward(self, z):
        """ Forward pass through discriminator
        
        Arguments
        --------
            z : tensor
                Tensor derived from SVAE latent space
        """
        return self.net(z.type(torch.FloatTensor))

In [19]:
# Testing functionality
z_dim = 8
discriminator = Discriminator(z_dim=z_dim)

# Pass random integer tensor through forward pass of discriminator
rand_tensor = torch.randint(low=0, high=100,size=(z_dim,))
print(f'Input: {rand_tensor}')
print(f'Output: {discriminator.forward(rand_tensor)}')

Input: tensor([17, 20,  1, 26, 65, 53,  2, 76])
Output: tensor([5.9446e-11], grad_fn=<SigmoidBackward>)


In [20]:
# Testing training functionality for discriminator
dsc_bce_loss = nn.BCELoss()
dsc_optim = optim.Adam(discriminator.parameters(), lr=0.001)

discriminator.train()

preds = discriminator(rand_tensor)
real_labels = torch.ones(preds.size(0))
loss = dsc_bce_loss(preds, real_labels)

print(loss)

tensor(23.5460, grad_fn=<BinaryCrossEntropyBackward>)


## Sampler Routine
Active learning based sample selection for task learner

In [21]:
# Code copied from VAAL and modified for sequence data
class Sampler:
    """ Adversary sampler """
    def __init__(self, budget):
        self.budget = budget
        
    def sample(self, vae, discriminator, data, cuda):
        """ Selective sampling algorithm
        
        Arguments
        ---------
            vae : torch model
                VAE model
            discriminator : torch model
                discriminator model
            data : tensor
                Image data
            cuda : boolean
                GPU flag
        Returns
        -------
            querry_pool_indices: int, list
                List of indices corresponding to sorted (top-K) samples to be sampled from
        """
        all_preds = []
        all_indices = []

        for images, _, indices in data:
            if cuda:
                images = images.cuda()

            with torch.no_grad():
                _, _, mu, _ = vae(images)
                preds = discriminator(mu)

            preds = preds.cpu().data
            all_preds.extend(preds)
            all_indices.extend(indices)

        all_preds = torch.stack(all_preds)
        all_preds = all_preds.view(-1)
        # need to multiply by -1 to be able to use torch.topk 
        all_preds *= -1

        # select the points which the discriminator things are the most likely to be unlabeled
        _, querry_indices = torch.topk(all_preds, int(self.budget))
        querry_pool_indices = np.asarray(all_indices)[querry_indices]

        return querry_pool_indices

In [22]:
# Testing functionality
sampler = Sampler(budget=10)

### Training Routine
<i> Pseudo code</i>

```python
for epoch in max_epochs:
        train(task learner)
            get preds
            calc loss
            zero grads
            backpropigate loss
            update model parameters
        for step in max_steps:
            train(SVAE)
        for step in max_steps:
            train(discriminator)
```
To do:
 - [ ] Implement training cycle for task learn
 - [ ] Implement training cycle for VAE
 - [ ] Implement training cycle for discriminator

In [25]:
class Solver(Tester):
    def __init__(self, task_learner, vae, discriminator):
        Tester.__init__(self)   # need tester class to access properties such as label_set_size
        
        # params (TODO: move to config)
        self.epochs = 2
        self.svae_steps = 2
        self.discriminator_steps = 2 
        
        # Get models
        # init else where?
        self.task_learner =  TaskLearner(embedding_dim=128, hidden_dim=128, vocab_size=max(vocab)+1, tagset_size=self.label_space_size)
        self.svae = SVAE(vocab_size=max(vocab)+1, embedding_size=128)
        self.discriminator = Discriminator(z_dim=10)   # z_dim will be the latent output of the vae encoder... currently just the seq length
        
        # Loss functions
        self.nll_loss = nn.NLLLoss()   # used in: TL
        self.bce_loss = nn.BCELoss()   # used in: Discriminator
        self.xxx_loss = 0    # used in: SVAE

        # optimiser
        self.optim_lr = 0.1   # don't want a global lr...
        self.tl_optim = optim.SGD(model.parameters(), lr=self.optim_lr)   # todo: revisit learning rate
        self.svae_optim = None
        self.disc_optim = optim.Adam(self.discriminator.parameters(), lr=0.01)   # todo: revisit learning rate
        
        # Will need to implement this in the future, need to build generator function rather than rely on Torch implementationss?
        self.data_loader = 'generator'
        self.seqs, self.lens = tester.build_sequences(batch_size=2, max_seq_len=10)   # Tester initialised at start of notebook
        self.dataset = tester.build_sequence_tags(sequences=self.seqs, lengths=self.lens)
        
        # Will need to implement this in the future, inherit the Sampler class?
        self.sampler = 'sampler'
        
    def train(self):
        """ Performs model training """
        
        # turn on .train() mode
        self.svae.train()
        self.task_learner.train()
        self.discriminator.train()
        
        for epoch in range(self.epochs):
#             print(f'\nEpoch: {epoch}')
            
            # these will be batches in the future, not single sequence/tag sets...
            # build generator to wrap around sequence/tag sets and then just call next() on them at each epoch to get a batch...
            # need to build that in the future to partition based on the unlabelled/labelled splits for AL experiments
            for batch_seqs, batch_lens, batch_tags in self.dataset:
                print(f'Dataset shapes\nSeqs:{batch_seqs.shape}\tLens:{batch_lens.shape}\tTags:{batch_tags.shape}')
            
                # Train Task Learner
                self.tl_train_step(batch_seqs, batch_lens, batch_tags)
            
                # Train VAE
                for step in range(self.svae_steps):
                    print(f'SVAE Step: {step}')
                    self.svae_train_step(batch_seqs, batch_lens, batch_tags)

                # Train Discriminator
                for step in range(self.discriminator_steps):
                    print(f'Discriminator Step: {step}')

                    self.disc_train_step(batch_seqs)
            clear_output(wait=True)
    
    def trim_padded_tags(self, batch_lengths: Tensor, batch_tags: Tensor) -> Tensor:
        """ Takes a batch of sequences and tags and trims similar to pack padded sequence method """
        
        # Get max length of longest sequence in batch so it can be used to filter tags
        sorted_lengths, _ = torch.sort(batch_lengths, descending=True)   # longest seq at index 0
        longest_seq = sorted_lengths[0].data.numpy()
        longest_seq_len = longest_seq[longest_seq != self.pad_idx][0]   # remove padding
        
        # Strip off as much padding as possible similar to (variable length sequences via pack padded methods)
        batch_tags = torch.stack([tags[:longest_seq_len] for tags in batch_tags])
    
        return batch_tags
    
    def tl_train_step(self, batch_seqs, batch_lengths, batch_tags):
        """
        Perform training step of task learner 
        
        Tensor dtype is LongTensor
        """
#         print(seq.dtype)
#         assert seq.dtype == 'torch.int64'
#         assert tags.dtype == 'torch.int64'
        
        # zero grad as they accumulate in PyTorch
        self.task_learner.zero_grad()
        
        preds = self.task_learner(batch_seqs, batch_lengths)
        
        batch_tags = self.trim_padded_tags(batch_lengths, batch_tags).view(-1)
        
        print(batch_seqs, batch_lengths, batch_tags)
        
        tl_loss = self.nll_loss(preds, batch_tags)
        tl_loss.backward()
        self.tl_optim.step()
        
        print(f'Task learning loss: {loss.data:0.4f}')
    
    def svae_train_step(self, batch_seqs, batch_lengths, batch_tags):
        """
        Perform training step of sentence variational autoencoder
        
        Tensor type is LongTensor
        """
        # Forward pass
        logp, mean, logv, z = self.svae(batch_seqs, batch_lengths)
        
        # Loss calculation
#         NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch_tags, batch_lengths, mean, logv, anneal_function, step, k, x0)

        print(f'logp: {logv} mean: {mean} logv: {logv} z: {z}')
        
    
    def disc_train_step(self, seqs):
        """ Perform training step of adversarial discriminatory 
        
        Note: The loss function will be an aggregate across labelled and unlabelled samples (currently just labelled)
        
        Tensor dtype is FloatTensor for discriminator
        """
        
#         assert seq.dtype == 'torch.float64'
        
#         print(seq)
#         print(seq.dtype)
        
        preds = self.discriminator(seqs)
#         print(preds)
        real_labels = torch.ones(preds.size(0))
        disc_loss = self.bce_loss(preds, real_labels)
        disc_loss.backward()
        self.disc_optim.step()
        
        print(f'Discriminator loss: {loss.data}')

In [26]:
# Test training 
slvr = Solver(task_learner=None, vae=None, discriminator=None)
slvr.train()

Dataset shapes
Seqs:torch.Size([2, 10])	Lens:torch.Size([2])	Tags:torch.Size([2, 10])
tensor([[14, 27, 42, 49, 16, 58, 86,  0,  0,  0],
        [ 3, 86, 34,  5, 56, 26, 98, 49,  0,  0]]) tensor([7, 8]) tensor([3, 4, 3, 3, 3, 3, 3, 0, 3, 2, 3, 1, 3, 3, 2, 2])
Task learning loss: 23.5460
SVAE Step: 0
hidden shape before squeeze torch.Size([1, 2, 128])
hidden shape after squeeze torch.Size([1, 2, 128])
logp: tensor([[[-0.0424,  0.0393,  0.1499, -0.2662,  0.0157, -0.0516, -0.1770,
           0.1174],
         [ 0.0196,  0.2950, -0.0080, -0.0985, -0.0549, -0.0729,  0.0613,
          -0.0513]]], grad_fn=<AddBackward0>) mean: tensor([[[-0.2073, -0.0854, -0.0096,  0.0317,  0.0788,  0.0017,  0.0800,
           0.1587],
         [-0.1718,  0.0958, -0.2266, -0.1947, -0.1373,  0.0193,  0.0474,
          -0.1937]]], grad_fn=<AddBackward0>) logv: tensor([[[-0.0424,  0.0393,  0.1499, -0.2662,  0.0157, -0.0516, -0.1770,
           0.1174],
         [ 0.0196,  0.2950, -0.0080, -0.0985, -0.0549, -0.0729

### Inference Routine
To do:
 - [ ] 


### Sampling Routine
To do:
 - [ ] 