Generate sentences to augment the dataset
-------------------------------------------------

In this notebook we try to build a generative adversarial network (GAN) that generates new sentences in order to increase the size of the corpus. We will use the `pytorch-lightning` module to speed up training.

- The generative model has the following characteristics:
    - we provide the `size of the sequences` to a first model, which generates an output of the same size as the given sequences
    - the output is rounded before being passed to the discriminator
    - we use a transformer encoder, rather than a simple `RNN` module, to generate the sentence ids
    - some rules are applied to the decoded output to obtain the textual sentences

- The discriminative model is used to check whether the output is close to the true sentences:
    - ~~we will use a pre-trained BERT model to discriminate the output~~
    - a multi-layer perceptron will be sufficient to discriminate the output
    - we tokenize the GAN inputs with a WordPiece tokenizer without a normalizer, because we want to generate texts
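
As an illustration of the tokenizer point above, here is a minimal sketch of a WordPiece tokenizer trained without a normalizer using the `tokenizers` library. The toy corpus, the vocabulary size and the list of special tokens are assumptions made for the example; the tokenizer actually used in this notebook is loaded from a file and may have been configured differently.

```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer

# toy corpus, for illustration only
corpus = [
    "Bonjour tout le monde",
    "Comment allez-vous ?",
    "Merci beaucoup",
]

# WordPiece model; no normalizer is attached, so case and accents are left untouched
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

# assumed special tokens; they receive the first ids in the order given
trainer = WordPieceTrainer(vocab_size=200,
                           special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]"])
tokenizer.train_from_iterator(corpus, trainer)

# same sentence layout as the dataset below: [CLS] sentence 1 [SEP] sentence 2 [SEP]
encoding = tokenizer.encode("[CLS]Bonjour tout le monde[SEP]Merci beaucoup[SEP]")

print(encoding.tokens)
print(encoding.ids)
```

Skipping the normalizer means the vocabulary keeps the original casing and accents, so decoded ids can be turned back into readable text.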


    

### Steps

The following steps will be required:

- Create a custom dataset to retrieve the sentences
- Create the generator
- Create the discriminator
- ~~Create the GAN~~
- Create the trainer
- Search for the best parameters
- Train the model and evaluate it

### Create a custom dataset

Let us use the already trained tokenizer to retrieve the encoded sequences. Note that this dataset is different from the one we want to use to train the translation model.

In [1]:
# %%writefile wolof-translate/wolof_translate/data/gan_dataset.py
import torch
import pandas as pd
from torch import nn
from tokenizers import Tokenizer
from torch.utils.data import Dataset

class SentenceDatasetGAN(Dataset):
    """Sentence pairs encoded as padded id sequences, with their attention masks."""

    def __init__(self, file_path: str, corpus_1: str = "french_corpus", corpus_2: str = "wolof_corpus",
                 tokenizer_path: str = "wolof-translate/wolof_translate/tokenizers/adverse_tokenizer.json",
                 cls_token: str = "[CLS]", sep_token: str = "[SEP]", sep: str = ",", **kwargs):

        # load the data frame (extra keyword arguments are forwarded to pandas)
        self.__sentences = pd.read_csv(file_path, sep=sep, **kwargs)

        # load the trained tokenizer
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        # sentences of the first and of the second corpus
        self.__sentences_1 = self.__sentences[corpus_1].to_list()
        self.__sentences_2 = self.__sentences[corpus_2].to_list()

        # special tokens
        self.cls_token = cls_token
        self.sep_token = sep_token

        # number of sentence pairs
        self.__length = len(self.__sentences_1)

        # highest token id of the vocabulary
        self.max_id = self.tokenizer.get_vocab_size() - 1

        # length of the longest encoded pair, used as the padding length
        self.max_len = 0

        for i in range(self.__length):

            sentence = f"{self.cls_token}{self.__sentences_1[i]}{self.sep_token}{self.__sentences_2[i]}{self.sep_token}"

            encoding = self.tokenizer.encode(sentence)

            if len(encoding.ids) > self.max_len:

                self.max_len = len(encoding.ids)

    def __getitem__(self, index):

        sentence_1 = self.__sentences_1[index]
        sentence_2 = self.__sentences_2[index]

        # build the sentence with the special tokens
        sentence = f"{self.cls_token}{sentence_1}{self.sep_token}{sentence_2}{self.sep_token}"

        # encode the sentence
        encoding = self.tokenizer.encode(sentence)

        # pad the ids with zeros up to the maximum length
        padding = self.max_len - len(encoding.ids)

        ids = torch.tensor(encoding.ids + [0] * padding)

        # return the padded ids and the attention mask (1 on tokens, 0 on padding)
        return ids.float(), (ids > 0).float()

    def __len__(self):

        return self.__length


The data loader will yield batches of padded id sequences and their attention masks. Let us test it below.

In [2]:
dataset = SentenceDatasetGAN("data/extractions/new_data/sent_extraction.csv")

In [3]:
from torch.utils.data import DataLoader

# let us draw one shuffled batch of 10 encoded sentences
ids, mask = next(iter(DataLoader(dataset, batch_size=10, shuffle=True)))

print("Ids:")
print(ids)

print("\nMask:")
print(mask)

Ids:
tensor([[2.0000e+00, 3.8000e+02, 1.2406e+04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [2.0000e+00, 2.1820e+03, 3.9460e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [2.0000e+00, 5.2900e+02, 6.2050e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [2.0000e+00, 1.9990e+03, 1.1000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [2.0000e+00, 8.1000e+02, 3.2600e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [2.0000e+00, 5.6500e+02, 6.6000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])

Mask:
tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.]])
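
A note on mask conventions: the mask above is 1 on real tokens and 0 on padding, whereas the `src_key_padding_mask` argument of `nn.TransformerEncoder` reads a boolean mask as True on the positions to ignore, and adds a float mask to the attention scores. The following sketch shows one way to derive a boolean padding mask and checks that padded positions then have no influence on the real ones. It assumes 0 is the padding id; the tiny encoder is a stand-in, not the generator defined later.

```python
import torch
from torch import nn

torch.manual_seed(0)

# toy batch of padded ids (0 is assumed to be the padding id)
ids = torch.tensor([[2., 15., 7., 3., 0., 0.],
                    [2., 9., 3., 0., 0., 0.]])

# mask as returned by the dataset: 1.0 on real tokens, 0.0 on padding
attention_mask = (ids > 0).float()

# boolean mask in the PyTorch convention: True on the positions to ignore
padding_mask = attention_mask == 0

# tiny encoder, with sizes chosen only for the illustration
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16,
                                   dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1, enable_nested_tensor=False)

x = torch.randn(2, 6, 8)
x_changed = x.clone()
x_changed[padding_mask] = 100.0  # overwrite the padded positions only

with torch.no_grad():
    out = encoder(x, src_key_padding_mask=padding_mask)
    out_changed = encoder(x_changed, src_key_padding_mask=padding_mask)

# the outputs at the real positions do not depend on the padded content
real = ~padding_mask
print(torch.allclose(out[real], out_changed[real], atol=1e-5))
```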


### Generator

The generator uses a transformer encoder whose `d_model`, number of layers, number of feed-forward features and activation function are specified as arguments. We can also specify a dropout rate.

In [4]:
# %%writefile wolof-translate/wolof_translate/models/generative_model.py
from torch.nn import functional as F
from custom_rnn.transformers.add_position import PositionalEncoding
from typing import *
from torch import nn
import torch

class SentenceGenerator(nn.Module):
    """Transformer encoder that maps an input sequence to a sequence of token ids."""

    def __init__(self,
                 output_size: int,
                 d_model: int = 512,
                 latent_dim: Union[int, None] = None,
                 num_features: int = 2048,
                 n_heads: int = 8,
                 dropout: float = 0.0,
                 activation = F.relu,
                 num_layers: int = 6,
                 min: int = 0, max: int = 100):

        super(SentenceGenerator, self).__init__()

        # bounds used to clip the generated ids
        self.min, self.max = min, max

        self.d_model = d_model
        self.n_heads = n_heads
        self.dropout = dropout
        self.activation = activation
        self.num_layers = num_layers
        self.num_features = num_features
        self.output_size = output_size

        # the latent sequence length defaults to the output size
        self.latent_dim = latent_dim if latent_dim is not None else self.output_size

        # positional encoding of the input
        self.pe = PositionalEncoding(self.latent_dim, self.d_model)

        self.encoder_layer = nn.TransformerEncoderLayer(self.d_model,
                                                        self.n_heads,
                                                        self.num_features,
                                                        self.dropout,
                                                        self.activation,
                                                        batch_first=True)

        self.encoder = nn.TransformerEncoder(self.encoder_layer, self.num_layers)

        # maps the flattened encoder output to one value per output position
        self.output_layer = nn.Linear(self.d_model * self.latent_dim, output_size)

    def forward(self, input_, attention_mask):

        # apply the positional encoding and match the dtype of the encoder parameters
        out = self.pe(input_).type_as(next(self.encoder.parameters()))

        # NOTE: PyTorch reads a boolean `src_key_padding_mask` as True = ignore the position
        # and adds a float mask to the attention scores; `attention_mask` here is a float
        # tensor equal to 1 on real tokens
        out = self.encoder(out, src_key_padding_mask = attention_mask).view(-1, self.latent_dim * self.d_model)

        # project to the output size, clip to [min, max] and round to the nearest integer
        out = torch.clip(self.output_layer(out), self.min, self.max).round()

        return out
        

Let us test our generative model with a dummy input.

In [5]:
generative_model = SentenceGenerator(output_size=dataset.max_len)

In [6]:
# dummy input of shape (batch_size, latent_dim, d_model); the output is clipped to [min, max] and rounded to the nearest integer
g_output = generative_model(torch.randn((10, dataset.max_len, generative_model.d_model)), mask)

g_output.size()

torch.Size([10, 379])
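
Since the generator rounds its output, it is worth checking how gradients behave through `torch.round`. The sketch below is an observation on toy values, not something this notebook's training code does: the gradient of a plain `round` is zero, while a straight-through variant (named here only as a possible alternative) keeps the rounded values in the forward pass and lets the gradient through unchanged.

```python
import torch

x = torch.tensor([0.3, 1.7, 4.2], requires_grad=True)

# plain rounding: the derivative of round is zero
x.round().sum().backward()
grad_round = x.grad.clone()

# straight-through variant: rounded values in the forward pass,
# identity gradient in the backward pass
x.grad = None
y = x + (x.round() - x).detach()
y.sum().backward()
grad_ste = x.grad.clone()

print(grad_round)
print(grad_ste)
```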

### Discriminator

Let us create a new discriminator model, different from the BERT model. It takes the output of the generator directly, without converting it to a long tensor, since that conversion would make us lose the gradient.

In [7]:
# %%writefile wolof-translate/wolof_translate/models/discriminative_model.py
from torch.nn import functional as F
from typing import *
from torch import nn
import torch

class DiscriminatorSequence(nn.Module):
    """One block of the discriminator: batch norm -> linear -> dropout -> LeakyReLU."""

    def __init__(self,
                 input_dim,
                 num_features,
                 negative_slope: float = 0.01,
                 drop_out: float = 0.0,
                 eps: float = 0.00001,
                 momentum: float = 0.1):

        super(DiscriminatorSequence, self).__init__()

        self.batch_norm = nn.BatchNorm1d(input_dim, eps, momentum)

        self.linear = nn.Linear(input_dim, num_features)

        # NOTE: nn.Dropout1d documents its input as (N, C, L) or (C, L), so a 2-D
        # (batch, features) tensor is read as (C, L); it is a no-op with drop_out = 0.0
        self.drop_out = nn.Dropout1d(drop_out)

        self.activation = nn.LeakyReLU(negative_slope)

    def forward(self, input_):

        out = self.batch_norm(input_)

        out = self.activation(self.drop_out(self.linear(out)))

        return out

class SentenceDiscriminator(nn.Module):
    """Multi-layer perceptron that scores a sequence of ids as real (1) or generated (0)."""

    def __init__(self,
                 input_dim: int,
                 num_features: Union[int, List[int]] = 300,
                 num_layers: int = 5,
                 negative_slope: float = 0.01,
                 drop_out: float = 0.0,
                 eps: float = 0.00001,
                 momentum: float = 0.1):

        super(SentenceDiscriminator, self).__init__()

        self.input_dim = input_dim

        # a single integer means the same width for every layer
        self.num_features = [num_features] * num_layers if isinstance(num_features, int) else num_features

        assert len(self.num_features) == num_layers

        self.num_layers = num_layers

        self.sequences = nn.ModuleList()

        # the first block maps the input to the first hidden width
        self.sequences.append(DiscriminatorSequence(input_dim, self.num_features[0], negative_slope, drop_out, eps, momentum))

        # each following block maps one hidden width to the next
        for l in range(1, num_layers):

            self.sequences.append(DiscriminatorSequence(self.num_features[l-1], self.num_features[l], negative_slope, drop_out, eps, momentum))

        # a single output unit followed by a sigmoid gives a probability
        self.output_layer = nn.Linear(self.num_features[-1], 1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, input_: torch.Tensor):

        out = input_

        for sequence in self.sequences:

            out = sequence(out)

        out = self.sigmoid(self.output_layer(out))

        return out
        

### Create a new trainer in place of pytorch lightning

We want to create a new runner class in order to run a grid search and find the best hyperparameters for generating texts. Let us import the runner below, along with some other handy functions.

In [8]:
from ok_transfer_learning.utils.gan_runner1 import SentenceGANRunner
from ok_transfer_learning.utils.find_keys import find_ghistory_key
import numpy as np
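
The internals of `SentenceGANRunner` are not shown in this notebook. For orientation only, here is a minimal sketch of one generic GAN update with a binary cross-entropy loss; the two small networks, the optimizers and the random data are hypothetical stand-ins, and the runner may proceed differently.

```python
import torch
from torch import nn

torch.manual_seed(0)

seq_len, batch_size = 6, 4

# toy stand-ins for the generator and the discriminator (hypothetical)
generator = nn.Sequential(nn.Linear(seq_len, 16), nn.ReLU(), nn.Linear(16, seq_len))
discriminator = nn.Sequential(nn.Linear(seq_len, 16), nn.LeakyReLU(0.01),
                              nn.Linear(16, 1), nn.Sigmoid())

g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
criterion = nn.BCELoss()

# random "real" id sequences and random generator inputs
real = torch.randint(1, 100, (batch_size, seq_len)).float()
noise = torch.randn(batch_size, seq_len)
ones, zeros = torch.ones(batch_size, 1), torch.zeros(batch_size, 1)

# 1) discriminator step: real sequences -> 1, generated sequences -> 0
d_optimizer.zero_grad()
fake = generator(noise)
d_loss = criterion(discriminator(real), ones) + criterion(discriminator(fake.detach()), zeros)
d_loss.backward()
d_optimizer.step()

# 2) generator step: push the discriminator towards answering 1 on generated sequences
g_optimizer.zero_grad()
g_loss = criterion(discriminator(fake), ones)
g_loss.backward()
g_optimizer.step()

print(f"d_loss={d_loss.item():.4f}, g_loss={g_loss.item():.4f}")
```

The `fake.detach()` call in the first step keeps the discriminator update from propagating gradients into the generator.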


### Hyper parameters search

Let us initialize the runner.

In [9]:
gan_runner = SentenceGANRunner(SentenceGenerator(dataset.max_len, d_model=100, num_features=1024, num_layers=3, n_heads=5), 
                               SentenceDiscriminator(dataset.max_len, num_layers=3), dataset, seed=50)

Let us initialize the hyper parameters.

In [10]:
hparams = {
    'g_lr': np.linspace(1e-5, 0.1, 10).round(5).tolist(),
    'd_lr': np.linspace(1e-5, 0.1, 10).round(5).tolist()
}
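
To get a feel for the size of this search space, the sketch below rebuilds the same kind of grid with the standard library only and enumerates every `(g_lr, d_lr)` pair. The `linspace` helper is a hypothetical stand-in for `np.linspace(...).round(5)`; how `make_grid_search` selects among the combinations is not shown in this notebook.

```python
from itertools import product

def linspace(start: float, stop: float, num: int, ndigits: int = 5) -> list:
    """Evenly spaced values from start to stop (both included), rounded to ndigits."""
    step = (stop - start) / (num - 1)
    return [round(start + i * step, ndigits) for i in range(num)]

hparams = {
    "g_lr": linspace(1e-5, 0.1, 10),
    "d_lr": linspace(1e-5, 0.1, 10),
}

# every combination of one generator and one discriminator learning rate
names = list(hparams)
combinations = [dict(zip(names, values)) for values in product(*hparams.values())]

print(hparams["g_lr"])
print(len(combinations))
```

With 10 values per learning rate, the full grid contains 100 combinations.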

Let us search for the best parameters.

In [11]:
gan_runner.make_grid_search(hparams, 10,
                            2, loader_kwargs={"batch_size": 2}
                            )

100%|██████████| 2/2 [00:39<00:00, 19.54s/it]
100%|██████████| 2/2 [00:37<00:00, 18.93s/it]
100%|██████████| 2/2 [00:37<00:00, 18.72s/it]
100%|██████████| 2/2 [00:37<00:00, 18.93s/it]
100%|██████████| 2/2 [00:37<00:00, 18.71s/it]
100%|██████████| 2/2 [00:38<00:00, 19.10s/it]
100%|██████████| 2/2 [00:37<00:00, 18.91s/it]
100%|██████████| 2/2 [00:38<00:00, 19.01s/it]
100%|██████████| 2/2 [00:37<00:00, 18.73s/it]
100%|██████████| 2/2 [00:38<00:00, 19.37s/it]


Let us retrieve the key for `'g_lr': 0.01112, 'd_lr': 0.00001`.

In [12]:
key = find_ghistory_key(gan_runner.grid_history, {'g_lr': 0.01112, 'd_lr': 0.00001})

### Training the model

Let us compile the runner with the retained model.

In [10]:
gan_runner.compile(loader_kwargs={"batch_size": 2},
                #    grid_search_key=key
                   )

Let us save the model.

In [12]:
# gan_runner.save("data/checkpoints/generator/")

Let us load the last saved model.

In [11]:
gan_runner.load("data/checkpoints/generator/")

Let us train the retained model.

In [12]:
gan_runner.train(50, auto_save=True, saving_directory="data/checkpoints/generator/")

  4%|▍         | 2/50 [02:03<43:52, 54.84s/it]   


Generated sentences at epoch 347
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 10%|█         | 5/50 [02:53<18:33, 24.74s/it]


Generated sentences at epoch 350
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 16%|█▌        | 8/50 [03:44<13:28, 19.26s/it]


Generated sentences at epoch 353
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 22%|██▏       | 11/50 [04:36<11:40, 17.97s/it]


Generated sentences at epoch 356
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 28%|██▊       | 14/50 [05:27<10:28, 17.47s/it]


Generated sentences at epoch 359
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 34%|███▍      | 17/50 [06:36<11:57, 21.73s/it]


Generated sentences at epoch 362
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 40%|████      | 20/50 [07:52<11:48, 23.62s/it]


Generated sentences at epoch 365
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 46%|████▌     | 23/50 [08:58<10:06, 22.45s/it]


Generated sentences at epoch 368
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 52%|█████▏    | 26/50 [10:03<08:39, 21.63s/it]


Generated sentences at epoch 371
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 58%|█████▊    | 29/50 [11:01<07:03, 20.15s/it]


Generated sentences at epoch 374
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 64%|██████▍   | 32/50 [12:00<05:55, 19.76s/it]


Generated sentences at epoch 377
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 70%|███████   | 35/50 [12:59<04:56, 19.74s/it]


Generated sentences at epoch 380
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 76%|███████▌  | 38/50 [13:56<03:48, 19.07s/it]


Generated sentences at epoch 383
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 82%|████████▏ | 41/50 [16:52<08:15, 55.06s/it]


Generated sentences at epoch 386
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 88%|████████▊ | 44/50 [28:43<16:23, 163.85s/it]


Generated sentences at epoch 389
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


 94%|█████████▍| 47/50 [29:41<03:26, 68.86s/it] 


Generated sentences at epoch 392
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 


100%|██████████| 50/50 [30:39<00:00, 36.80s/it]


Generated sentences at epoch 395
Sentence 0: 
Sentence 1: 
Sentence 2: 
Sentence 3: 
Sentence 4: 
Sentence 5: 
Sentence 6: 
Sentence 7: 
Sentence 8: 
Sentence 9: 





In [13]:
gan_runner.save("data/checkpoints/generator/")