# TransformerGAN

An implementation of a text GAN that uses a BERT sequence classifier as the discriminator and OpenAI's GPT-2 as the generator.

In [1]:
import os, datetime
import sys
import torch
import random
import argparse
import numpy as np
from gpt2Pytorch.GPT2.model import GPT2LMHeadModel
from gpt2Pytorch.GPT2.utils import load_weight
from gpt2Pytorch.GPT2.config import GPT2Config
from gpt2Pytorch.GPT2.sample import sample_sequence
from gpt2Pytorch.GPT2.encoder import get_encoder
from pytorch_pretrained_bert import GPT2Tokenizer
from torch import nn
from pytorch_pretrained_bert import BertTokenizer, BertForSequenceClassification, BertAdam, OpenAIAdam
import pandas as pd
from datetime import datetime
from torch.nn import functional as F
from pandas import Series, DataFrame

In [2]:
# Display the installed PyTorch version; the recorded cell output is '1.4.0'.
torch.__version__ # should be 1.4.0

'1.4.0'

In [3]:
class TrtansformerGAN(object):
    """Text GAN pairing a BERT sequence classifier (discriminator) with GPT-2 (generator).

    Each training step samples real texts from ``dataframe['text']``, lets the
    generator continue the first 10 GPT-2 tokens of every sample, and trains the
    discriminator to tell generated text (label 0) from real text (label 1).
    """

    def __init__(self, dataframe, bert_tokenizer, bert_classifier, gpt_tokenizer, gpt_generator, num_labels,
                 device_default = None):
        """Wrap the models, move them to the target device and build their optimizers.

        Args:
            dataframe: pandas DataFrame with a ``'text'`` column of real samples.
            bert_tokenizer: tokenizer exposing ``tokenize`` and ``convert_tokens_to_ids``.
            bert_classifier: sequence classifier used as the discriminator.
            gpt_tokenizer: tokenizer exposing ``encode`` and ``decode``.
            gpt_generator: GPT-2 language model called as ``model(input, past = past)``.
            num_labels: number of discriminator classes (stored only; the classifier
                passed in is already built).
            device_default: optional ``torch.device``. When omitted, the module-level
                ``device`` variable is used, as the constructor originally did.
        """
        self.dataframe = dataframe
        self.num_labels = num_labels
        self.device_default = device if device_default is None else device_default

        # Build discriminator and tokenizer from BertForSequenceClassification
        self.bert_tokenizer = bert_tokenizer
        self.discriminator = nn.DataParallel(bert_classifier).to(self.device_default)
        self.bert_optimizer = BertAdam(self.discriminator.parameters(), lr = 0.00005, warmup = 0.1, t_total = 1000)

        # Build the generator, tokenizer, optimizer from OpenAIGPT2
        self.gpt2_tokenizer = gpt_tokenizer
        self.generator = gpt_generator.to(self.device_default)
        self.gpt2_optimizer = OpenAIAdam(self.generator.parameters(), lr = 0.0001, warmup = 0.1, t_total = 1000)

        # Release cached, currently unused GPU memory back to the driver
        torch.cuda.empty_cache()

    def textGeneration(self, generator_input):
        """Sample a continuation of a token-id prompt and return the decoded text.

        Args:
            generator_input: list of GPT-2 token ids used as the prompt. The list is
                copied, so the caller's list is left untouched.

        Returns:
            The decoded string of the prompt followed by 30 to 100 sampled tokens.
        """
        text_id = list(generator_input)
        input_ids, past = torch.tensor([text_id]).to(self.device_default), None

        # Sampling with multinomial/.item() is not differentiable, so building an
        # autograd graph through `past` would only waste memory.
        with torch.no_grad():
            for _ in range(random.randint(30, 100)):
                logits, past = self.generator(input_ids, past = past)
                # Explicit dim: normalise over the vocabulary of the last position
                input_ids = torch.multinomial(F.softmax(logits[:, -1], dim = -1), 1)
                text_id.append(input_ids.item())
        return self.gpt2_tokenizer.decode(text_id)

    def dataGenerator(self, batch_size = 16):
        """Draw a random batch of real texts and generate one fake text per sample.

        Args:
            batch_size: number of distinct rows to sample from the dataframe.

        Returns:
            Tuple ``(generated_texts, real_texts)`` of pandas Series sharing an index.
        """
        # Randomly fetch training data bunch (without replacement)
        sample_text_ss = self.dataframe['text'].iloc[random.sample(range(len(self.dataframe)), batch_size)]

        # Tokenize training data bunch with GPT2 tokenizer and keep the first 10 tokens as prompt
        sample_text_encode_top10 = sample_text_ss.map(lambda x : self.gpt2_tokenizer.encode(x)[:10])

        # Generate text using GPT2 generator
        sample_text_generate_ss = sample_text_encode_top10.map(self.textGeneration)
        return sample_text_generate_ss, sample_text_ss

    def discriminatorInput(self, text, max_length = 512):
        """Convert text to BERT WordPiece ids wrapped in ``[CLS]`` ... ``[SEP]``.

        Args:
            text: raw string to encode.
            max_length: maximum number of ids including the two special tokens.
                Defaults to 512, the position limit of the standard BERT models;
                longer inputs would index past the position embeddings.

        Returns:
            A one-element list holding the list of token ids.
        """
        body_tokens = self.bert_tokenizer.tokenize(text)[:max(max_length - 2, 0)]
        input_token = ['[CLS]'] + body_tokens + ['[SEP]']
        input_id = self.bert_tokenizer.convert_tokens_to_ids(input_token)
        return [input_id]

    def saveGeneratedText(self):
        """Generate one sample from a random real text.

        Returns:
            Tuple ``(generated_text, real_text)``; the generated text is stripped of
            leading and trailing whitespace.
        """
        # randrange excludes the upper bound; randint(0, len) could index one past the end
        row = random.randrange(len(self.dataframe))
        content = self.dataframe['text'].values[row]
        content_id = self.gpt2_tokenizer.encode(content)[:10]
        gen_content = self.textGeneration(content_id).strip()
        return gen_content, content

    def train(self, num_epochs = 10, save_interval = 2):
        """Run the adversarial training loop.

        Args:
            num_epochs: number of batches ("epochs") to train on.
            save_interval: every ``save_interval`` epochs one generated sample is
                appended to ``gen_textlog.txt`` and recorded in the returned lists.
                A falsy value disables saving.

        Returns:
            Tuple ``(generator, discriminator, d_loss_list, g_loss_list,
            generated_text_list, real_text_list)``. Loss entries are detached tensors.
        """
        start = datetime.now()
        generated_text_list = []
        real_text_list = []
        d_loss_list = []
        g_loss_list = []

        for epoch in range(num_epochs):
            try:
                print('Epoch {}/{}'.format(epoch + 1, num_epochs))
                print('-' * 10)

                # Load in data
                sample_text_generate_ss, sample_text_ss = self.dataGenerator(batch_size = 16)
                num_generated = len(sample_text_generate_ss)

                # Convert generated text and real text bunch to WordPiece ids as discriminator input
                discriminator_input_ss = pd.concat([sample_text_generate_ss, sample_text_ss], axis = 0, ignore_index = True).map(self.discriminatorInput)
                # Flatten the one-element wrappers; DataFrame pads ragged rows with NaN -> 0
                id_rows = [ids for wrapped in discriminator_input_ss for ids in wrapped]
                discriminator_input = torch.LongTensor(np.array(DataFrame(id_rows).fillna(0).astype('int32'))).to(self.device_default)
                # Mask out the zero padding so BERT does not attend to it
                attention_mask = (discriminator_input != 0).long()
                discriminator_input_generate = discriminator_input[:num_generated]
                attention_mask_generate = attention_mask[:num_generated]

                # Create labels for training discriminator (0 = generated, 1 = real) and generator
                labels = torch.LongTensor([0] * num_generated + [1] * len(sample_text_ss)).to(self.device_default)
                valid = torch.LongTensor([1] * len(sample_text_ss)).to(self.device_default)

                # ---- Discriminator phase ----
                self.discriminator.train()

                # Make sure every discriminator parameter is trainable
                for param in self.discriminator.parameters():
                    param.requires_grad = True

                self.bert_optimizer.zero_grad()

                # With labels given the classifier returns the loss; mean() merges DataParallel replicas
                d_loss = self.discriminator(input_ids = discriminator_input, attention_mask = attention_mask, labels = labels).mean()
                d_loss.backward()
                self.bert_optimizer.step()

                # ---- Generator phase ----
                self.discriminator.eval()
                self.gpt2_optimizer.zero_grad()

                # NOTE(review): generated text reaches the discriminator as re-tokenised
                # discrete ids (multinomial sampling + decode + tokenize), so this backward
                # pass has no path to the generator's parameters and the optimizer step
                # below cannot update them. A differentiable relaxation or a
                # policy-gradient reward would be needed to actually train the generator.
                g_loss = self.discriminator(input_ids = discriminator_input_generate, attention_mask = attention_mask_generate, labels = valid).mean()
                g_loss.backward()
                self.gpt2_optimizer.step()

                # Report the progress
                print('Discriminator Loss:', d_loss)
                print('Generator Loss:', g_loss)
                print()
                # Detach so the stored history does not keep each step's autograd graph alive
                d_loss_list.append(d_loss.detach())
                g_loss_list.append(g_loss.detach())

                # If at save interval, then save generated text samples
                if save_interval and epoch % save_interval == 0:
                    generated_text, real_text = self.saveGeneratedText()

                    with open('gen_textlog.txt', 'a', encoding = 'utf-8') as file_object:
                        file_object.write(generated_text)
                        file_object.write('\n----------------------------------------------------------\n')

                    generated_text_list.append(generated_text)
                    real_text_list.append(real_text)
            except RuntimeError as err:
                # Best effort: skip the failed batch, but report it instead of hiding it
                print('Skipping epoch {} after RuntimeError: {}'.format(epoch + 1, err))

        # Counting time elapsed
        time_delta = datetime.now() - start
        print('Training completed time:', time_delta)

        return self.generator, self.discriminator, d_loss_list, g_loss_list, generated_text_list, real_text_list


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
TransformerGAN = TrtansformerGAN

In [4]:
# Local locations of the pretrained BERT checkpoint, the GPT-2 weights and the GPT-2 vocabulary.
bert_pretrained_model_path = 'bertPytorch/bert-base-cased'
gpt2_model_path = 'gpt2Pytorch/gpt2-pytorch_model.bin'
gpt2_vocab_path = 'gpt2Pytorch/GPT2-vocab'

# Fail early with an explicit error instead of `assert`, which is stripped under
# `python -O` and does not say which path is missing.
_missing_paths = [
    path
    for path in (bert_pretrained_model_path, gpt2_model_path, gpt2_vocab_path)
    if not os.path.exists(path)
]
if _missing_paths:
    raise FileNotFoundError('Required model files not found: {}'.format(', '.join(_missing_paths)))

In [5]:
num_labels = 2
# Use the first GPU when one is available and fall back to CPU otherwise, matching
# the CPU fallback already applied to `map_location` below.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Discriminator side: BERT tokenizer and sequence classifier with `num_labels` classes
bert_tokenizer = BertTokenizer.from_pretrained(bert_pretrained_model_path)
bert_for_seq_classification = BertForSequenceClassification.from_pretrained(bert_pretrained_model_path, num_labels = num_labels)

# Generator side: GPT-2 tokenizer plus weights loaded into a freshly configured LM-head model
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(gpt2_vocab_path)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load(gpt2_model_path, map_location='cpu' if not torch.cuda.is_available() else None)
enc = get_encoder()
config = GPT2Config()
gpt2_model = GPT2LMHeadModel(config)
gpt2_model = load_weight(gpt2_model, state_dict)
gpt2_model.to(device)
# Assign to `_` so the notebook does not print the module repr returned by eval()
_ = gpt2_model.eval()

In [6]:
# Load the training corpus; the GAN class samples real texts from its 'text' column.
df = pd.read_csv('sampledata.csv')
# Display (rows, columns) as the cell output.
df.shape

(33252, 2)

In [7]:
# Wire the pretrained BERT discriminator and GPT-2 generator, together with their
# tokenizers and the training dataframe, into the GAN wrapper.
textgan = TrtansformerGAN(
    dataframe = df, 
    bert_tokenizer = bert_tokenizer, 
    bert_classifier = bert_for_seq_classification, 
    gpt_tokenizer = gpt2_tokenizer, 
    gpt_generator = gpt2_model, 
    num_labels = num_labels)

In [8]:
# Train with the default settings (num_epochs = 10, save_interval = 2) and unpack the
# returned generator, discriminator, per-epoch loss lists and generated/real text lists.
OpenAIGPT2_generator, BERT_discriminator, d_loss_list, g_loss_list, generated_review_list, real_review_list = textgan.train()

Epoch 1/10
----------




Discriminator Loss: tensor(0.6271, device='cuda:0', grad_fn=<MeanBackward0>)
Generator Loss: tensor(0.7504, device='cuda:0', grad_fn=<MeanBackward0>)

Epoch 2/10
----------
Discriminator Loss: tensor(0.6393, device='cuda:0', grad_fn=<MeanBackward0>)
Generator Loss: tensor(0.7331, device='cuda:0', grad_fn=<MeanBackward0>)

Epoch 3/10
----------
Discriminator Loss: tensor(0.6262, device='cuda:0', grad_fn=<MeanBackward0>)
Generator Loss: tensor(0.7091, device='cuda:0', grad_fn=<MeanBackward0>)

Epoch 4/10
----------
Discriminator Loss: tensor(0.5665, device='cuda:0', grad_fn=<MeanBackward0>)
Generator Loss: tensor(0.7194, device='cuda:0', grad_fn=<MeanBackward0>)

Epoch 5/10
----------
Discriminator Loss: tensor(0.5436, device='cuda:0', grad_fn=<MeanBackward0>)
Generator Loss: tensor(0.7588, device='cuda:0', grad_fn=<MeanBackward0>)

Epoch 6/10
----------
Discriminator Loss: tensor(0.4935, device='cuda:0', grad_fn=<MeanBackward0>)
Generator Loss: tensor(1.0125, device='cuda:0', grad_fn=<M