In [1]:
import torch
import numpy
import pickle
import jieba
import random
#import re
import pandas
from torch.utils.data import Dataset, DataLoader
from torchtext.vocab import build_vocab_from_iterator

In [None]:
class TextDataset(Dataset):
    """Fixed-length token-index sequences built from the 'Question' column of a CSV.

    Each row is segmented with jieba, prefixed with the ``<start>`` token and
    padded/truncated to ``seq_len`` entries. The padding token is the literal
    apostrophe string ``"'"`` (second special token of the vocabulary).

    Side effect: constructing the dataset writes the vocabulary mappings to
    ``stoi.bin`` and ``itos.bin`` in the current working directory.
    """

    def __init__(self, root="./CDC.csv", seq_len = 40) -> None:
        """Read the CSV at ``root``, build the vocabulary and encode every row.

        Args:
            root: path of a CSV file that has a 'Question' column.
            seq_len: length of every encoded sequence (including ``<start>``).
        """
        super().__init__()
        self.frame = pandas.read_csv(root)
        self.seq_len = seq_len
        self.vocab = build_vocab_from_iterator(self.get_vocab(), specials=["<start>", "'"])
        self.stoi = self.vocab.get_stoi()
        self.itos = self.vocab.get_itos()
        self.inputs = self.get_inputs()
        # Context managers make sure both files are flushed and closed.
        with open("stoi.bin", "wb") as stoi_file:
            pickle.dump(self.stoi, stoi_file)
        with open("itos.bin", "wb") as itos_file:
            pickle.dump(self.itos, itos_file)

    def _tokenize(self, text):
        """Segment one cell value into a list of tokens (single source of truth)."""
        # str() also turns non-string cells (e.g. NaN) into text, as before.
        return jieba.lcut(str(text).strip())

    def get_inputs(self):
        """Encode every question as ``seq_len`` token ids.

        Returns:
            torch.LongTensor with one row per question; each row starts with the
            ``<start>`` id and is right-padded with the ``"'"`` id.
        """
        start_id = self.stoi["<start>"]
        pad_id = self.stoi["'"]
        rows = []
        for text in self.frame['Question'].to_list():
            ids = [start_id]
            ids.extend(self.stoi[token] for token in self._tokenize(text))
            ids = ids[:self.seq_len]
            ids.extend([pad_id] * (self.seq_len - len(ids)))
            rows.append(ids)
        return torch.tensor(rows, dtype=torch.long)

    def get_vocab(self):
        """Yield the token list of every question (input for the vocab builder)."""
        for text in self.frame['Question'].to_list():
            yield self._tokenize(text)

    def __getitem__(self, index):
        """Return the encoded sequence(s) at ``index``."""
        return self.inputs[index]

    def __len__(self):
        """Number of encoded sequences (one per CSV row)."""
        # Measured on the tensor that __getitem__ indexes, so both always agree.
        return len(self.inputs)

In [2]:
class NetF(torch.nn.Module):
    """Feature extractor mapping a token-index sequence to a 128-d vector.

    Tokens are embedded into 64 dimensions, the ``seq_len`` embeddings are
    flattened into one vector, and a linear layer followed by ``tanh``
    produces the final features (each component lies in [-1, 1]).
    """

    def __init__(self, vocab_size, seq_len) -> None:
        """
        Args:
            vocab_size: number of rows in the embedding table.
            seq_len: number of tokens per input sequence; fixes the width of
                the linear layer, so inputs must have exactly this length.
        """
        super().__init__()
        nn = torch.nn
        self.seq_len = seq_len
        self.embedding_layer = nn.Embedding(vocab_size, 64)
        projection = (
            nn.Flatten(),
            nn.Linear(seq_len * 64, 128),
            nn.Tanh(),
        )
        self.fc_layer = nn.Sequential(*projection)

    def forward(self, inputs):
        """Return features of shape (batch, 128) for indices of shape (batch, seq_len)."""
        return self.fc_layer(self.embedding_layer(inputs))

In [3]:
class NetG(torch.nn.Module):
    """GRU generator: predicts a distribution over the next token.

    Tokens are embedded (128-d), passed through a single-layer GRU (256 hidden
    units, ``batch_first``) and projected to ``vocab_size`` probabilities.
    """

    def __init__(self, vocab_size, seq_len) -> None:
        """
        Args:
            vocab_size: number of distinct token ids.
            seq_len: number of tokens produced by :meth:`sample`.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.embedding_layer = torch.nn.Embedding(vocab_size, 128)
        self.rnn_layer = torch.nn.GRU(128, 256, batch_first=True)
        self.fc_layer = torch.nn.Sequential(
            torch.nn.Linear(256, vocab_size),
            # Normalise over the vocabulary axis. For the single-step input used
            # in this file this equals the previous dim=1, but it also stays
            # correct when a multi-step sequence is passed.
            torch.nn.Softmax(dim=-1))

    def init_hiddens(self, batch_size, use_cuda = True):
        """Return an all-zero initial GRU state of shape (num_layers, batch_size, hidden).

        Args:
            batch_size: number of sequences in the batch.
            use_cuda: create the state on the current CUDA device instead of the CPU.
        """
        device = torch.device("cuda" if use_cuda else "cpu")
        # Sizes are read from the GRU so they cannot drift from the layer definition.
        return torch.zeros(
            self.rnn_layer.num_layers, batch_size, self.rnn_layer.hidden_size,
            device=device)

    def forward(self, inputs, hiddens):
        """Run the GRU over ``inputs`` and return next-token probabilities.

        Args:
            inputs: token ids of shape (batch, steps).
            hiddens: GRU state as returned by :meth:`init_hiddens` or a previous call.

        Returns:
            (probs, hiddens): ``probs`` has shape (batch, vocab_size) when
            ``steps == 1`` (the step axis is squeezed away), otherwise
            (batch, steps, vocab_size); ``hiddens`` is the updated GRU state.
        """
        embedded = self.embedding_layer(inputs)
        outputs, hiddens = self.rnn_layer(embedded, hiddens)
        outputs = torch.squeeze(outputs, dim=1)
        outputs = self.fc_layer(outputs)
        return outputs, hiddens

    @torch.no_grad()
    def sample(self, n = 1, use_cuda = False):
        """Draw ``n`` sequences of ``seq_len`` token ids from the model.

        Generation starts from token id 0 (the ``<start>`` token when it is the
        first special of the vocabulary) and feeds every sampled token back in
        as the next input. Sampling is not differentiable, so it runs without
        autograd tracking.

        Returns:
            LongTensor of shape (n, seq_len) on the CPU or the current CUDA device.
        """
        device = torch.device("cuda" if use_cuda else "cpu")
        seq_outputs = torch.zeros(n, self.seq_len, dtype=torch.long, device=device)
        z_inputs = torch.zeros(n, 1, dtype=torch.long, device=device)
        hiddens = self.init_hiddens(n, use_cuda)

        for i in range(self.seq_len):
            probs, hiddens = self(z_inputs, hiddens)
            tokens = torch.distributions.Categorical(probs=probs).sample()
            seq_outputs[:, i] = tokens
            z_inputs = torch.unsqueeze(tokens, dim=1)

        return seq_outputs

In [None]:
class NetR(torch.nn.Module):
    """GRU network that re-samples a token id for every position of its input.

    The input ids are embedded (128-d), run through a single-layer GRU
    (256 hidden units, ``batch_first``) and turned into a per-position
    distribution over the vocabulary, from which one id is drawn.
    """

    def __init__(self, vocab_size) -> None:
        """
        Args:
            vocab_size: number of distinct token ids.
        """
        super().__init__()
        nn = torch.nn
        self.vocab_size = vocab_size
        self.embedding_layer = nn.Embedding(vocab_size, 128)
        self.rnn_layer = nn.GRU(128, 256, batch_first=True)
        head = (nn.Linear(256, vocab_size), nn.Softmax(2))
        self.fc_layer = nn.Sequential(*head)

    def init_hiddens(self, batch_size, use_cuda = True):
        """Return a zero GRU state of shape (1, batch_size, 256) on CUDA or CPU."""
        zeros = torch.zeros(1, batch_size, 256)
        return zeros.cuda() if use_cuda else zeros.cpu()

    def forward(self, inputs, hiddens):
        """Sample one token id per input position.

        Args:
            inputs: token ids of shape (batch, steps).
            hiddens: initial GRU state, e.g. from :meth:`init_hiddens`.

        Returns:
            Integer tensor of shape (batch, steps) holding sampled token ids.
            Only ids are returned, not probabilities; the final GRU state is
            discarded.
        """
        states, _ = self.rnn_layer(self.embedding_layer(inputs), hiddens)
        token_probs = self.fc_layer(states)
        return torch.distributions.Categorical(probs=token_probs).sample()

In [None]:
# ---- Hyperparameters ----
seed = 42          # fixes shuffling, weight initialisation and sampling
gamma = 1          # scale applied to the cosine similarity before the softmax
seq_len = 40       # tokens per sequence (including the <start> token)
batch_size = 64
sample_size = 100  # sequences written to disk at every logging step
lr = 0.001
epochs = 20

# Seed every RNG used by the notebook so that Restart & Run All is reproducible.
# torch.manual_seed also seeds the CUDA generators.
random.seed(seed)
torch.manual_seed(seed)

# ---- Data ----
dataset = TextDataset(seq_len=seq_len)
# drop_last keeps every batch at exactly batch_size rows.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

vocab_size = len(dataset.stoi)
print(vocab_size)

# ---- Models (a CUDA device is required) ----
F = NetF(vocab_size, seq_len).cuda()
G = NetG(vocab_size, seq_len).cuda()
R = NetR(vocab_size).cuda()

# ---- Optimisers ----
# NOTE(review): only G and R get an optimiser here; F keeps its random
# initial weights unless it is trained elsewhere - confirm this is intended.
g_optim = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
r_optim = torch.optim.Adam(R.parameters(), lr=lr, betas=(0.5, 0.999))
# Default CosineSimilarity compares along dim=1, i.e. per row of a (batch, features) pair.
cosine = torch.nn.CosineSimilarity()

In [None]:
# Adversarial training loop.
#
# Each step:
#   1. sample a generated batch and compute per-sequence rewards as a softmax
#      over cosine similarities of F-features,
#   2. update G with a reward-weighted negative log-likelihood of the real batch,
#   3. build "real" / "fake" comparison batches (one row swapped in each) and
#      compute the ranking loss.
for epoch in range(epochs):
    for iters, inputs in enumerate(dataloader, 0):
        # Row that is swapped between the real and the generated batch this step.
        sample_index = random.randint(0, batch_size - 1)
        row = slice(sample_index, sample_index + 1)
        real_inputs = inputs.cuda()

        # ---- Rewards (treated as constants for the generator update) ----
        with torch.no_grad():
            generate_inputs = G.sample(batch_size, True)

            # clone(): writing the real row must not modify generate_inputs itself.
            g_sample = generate_inputs.clone()
            g_sample[row, :] = real_inputs[row, :]
            ranked_sample = R(g_sample, R.init_hiddens(batch_size))

            ys = F(generate_inputs)
            yu = F(ranked_sample)
            alpha = cosine(ys, yu)
            # softmax(x) == exp(x) / sum(exp(x)) over the batch axis.
            rewards = torch.softmax(gamma * alpha, dim=0)

        # ---- Generator update: reward-weighted NLL with teacher forcing ----
        g_loss = 0.0
        hiddens = G.init_hiddens(batch_size)
        for i in range(seq_len - 1):
            z_inputs = real_inputs[:, i:i+1]
            outputs, hiddens = G(z_inputs, hiddens)
            # Probability assigned to the true next token of every sequence.
            target_probs = outputs.gather(1, real_inputs[:, i+1:i+2]).squeeze(1)
            # clamp_min guards against log(0) when a probability underflows.
            step_nll = -torch.log(target_probs.clamp_min(1e-12))
            g_loss = g_loss + torch.sum(step_nll * rewards)

        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        # ---- Ranker update ----
        # G.sample already returns token ids, so the single generated sequence
        # is used as-is (it must not be re-interpreted as probabilities).
        with torch.no_grad():
            g_single = G.sample(1, True)

        # Real batch with one generated row / generated batch with one real row.
        # Both are clones so real_inputs and generate_inputs stay untouched.
        rc_sample = real_inputs.clone()
        rc_sample[row, :] = g_single
        fc_sample = generate_inputs.clone()
        fc_sample[row, :] = real_inputs[row, :]

        # The ranker starts from its own zero state, not from G's final state.
        r_hiddens = R.init_hiddens(batch_size)
        real_ranked = R(rc_sample, r_hiddens)
        fake_ranked = R(fc_sample, r_hiddens)

        real_alpha = cosine(F(real_inputs), F(real_ranked))
        real_log_p = torch.log_softmax(gamma * real_alpha, dim=0)

        fake_alpha = cosine(F(generate_inputs), F(fake_ranked))
        fake_log_p = torch.log_softmax(gamma * fake_alpha, dim=0)

        r_loss = torch.mean(real_log_p) - torch.mean(fake_log_p)

        # NOTE(review): R.forward returns sampled token ids, so r_loss is
        # connected to F's weights only; R's parameters receive no gradient and
        # this optimiser step leaves them unchanged. Training R needs a
        # differentiable path (or a policy-gradient estimator) - TODO.
        r_optim.zero_grad()
        r_loss.backward()
        r_optim.step()

        if iters % 10 == 0:
            print("[+] Epoch: [%d/%d] G_Loss: %.4f R_Loss: %.4f" % (epoch+1, epochs, float(g_loss), float(r_loss)))
            # Sample on the GPU and move only the ids to the host; the model
            # itself stays on its device.
            with torch.no_grad():
                generate_data = G.sample(sample_size, True).cpu().tolist()
            with open(f"epoch_{epoch+1}_step_{iters}.txt", "w", encoding='utf-8') as fd:
                for token_ids in generate_data:
                    fd.write("".join(dataset.itos[token_id] for token_id in token_ids) + "\n")

# Move the trained models to the CPU in eval mode and store them as whole
# pickled modules; loading them requires the NetG / NetR classes to be defined.
G = G.cpu().eval()
R = R.cpu().eval()

torch.save(G, "rankgan_modelG.pth")
torch.save(R, "rankgan_modelR.pth")

In [5]:
# Generate text with the saved generator and write one sequence per line.
#
# NOTE(security): pickle.load and torch.load execute code stored in the file;
# only load artefacts produced by this notebook.
# NOTE(review): torch.load restores a whole pickled module here, so the NetG
# class must already be defined; newer PyTorch releases may additionally need
# weights_only=False - confirm against the installed version.
with open("./itos.bin", "rb") as vocab_file:
    itos = pickle.load(vocab_file)
model = torch.load("./rankgan_modelG.pth")
sample_size = 50
# Use the length the model was built with instead of a hard-coded value.
seq_len = model.seq_len
with torch.no_grad():
    generate_data = model.sample(sample_size)

# The output file is opened only after everything that can fail has succeeded.
with open("sample.txt", "w", encoding='utf-8') as fd:
    for token_ids in generate_data.tolist():
        words = []
        for token_id in token_ids[:seq_len]:
            word = itos[token_id]
            # "'" is the padding token: everything after it is discarded.
            if word == "'":
                break
            words.append(word)
        fd.write("".join(words) + "\n")