In [1]:
import nlp
import sentencepiece as spm
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset as TorchDataset, DataLoader

from models.seq2seq_generator import Seq2Seq, Encoder, Decoder
from models.lstm_generator import LSTMGenerator
from models.cnn_discriminator import CNNDiscriminator
from models.rollout import ROLLOUT


### 1. load data

In [2]:
%%time
# Load the train / validation / test splits of CNN/DailyMail (config "3.0.0")
# in one unpacking assignment. The generator is consumed left to right, so the
# splits are requested in the same order as three separate calls would be.
train_dataset, val_dataset, test_dataset = (
    nlp.load_dataset("cnn_dailymail", "3.0.0", split=split_name)
    for split_name in ("train", "validation", "test")
)

CPU times: user 91.3 ms, sys: 84.9 ms, total: 176 ms
Wall time: 3.9 s


In [3]:
%%time
def _articles_and_highlights(dataset):
    """Split a summarization dataset into two parallel lists in a single pass.

    Args:
        dataset: iterable whose items can be indexed with the keys
            ``'article'`` and ``'highlights'``.

    Returns:
        ``(articles, highlights)``: two lists of equal length, in the order
        the items were yielded by ``dataset``.
    """
    articles, highlights = [], []
    for item in dataset:
        articles.append(item['article'])
        highlights.append(item['highlights'])
    return articles, highlights


# Walk each split once instead of once per column: every pass has to
# materialise all rows, so a single pass halves the work of this cell.
train_articles, train_highlights = _articles_and_highlights(train_dataset)
val_articles, val_highlights = _articles_and_highlights(val_dataset)

CPU times: user 17.1 s, sys: 2.11 s, total: 19.2 s
Wall time: 27.4 s


In [4]:
# Disabled cell: the code is wrapped in a string literal, so nothing below is
# executed; the notebook only echoes the string back as the cell's output.
# When enabled, it writes every training article and then every training
# highlight, one per line, to data/sp_texts.txt (the input file named in the
# SentencePiece training cell that follows).
# NOTE(review): the file is opened in append mode ('a'), so running the enabled
# code more than once would add a duplicate copy of every line.
'''
%%time
with open('data/sp_texts.txt', 'a') as f:
    for article in tqdm(train_articles):
        f.write(article + '\n')
    for highlight in tqdm(train_highlights):
        f.write(highlight + '\n')
'''

"\n%%time\nwith open('data/sp_texts.txt', 'a') as f:\n    for article in tqdm(train_articles):\n        f.write(article + '\n')\n    for highlight in tqdm(train_highlights):\n        f.write(highlight + '\n')\n"

In [5]:
# Disabled cell: wrapped in a string literal, so the trainer is never invoked;
# the notebook only echoes the string back as the cell's output.
# When enabled, it trains a SentencePiece model with a 10000-piece vocabulary
# on data/sp_texts.txt, writing its output under the prefix 'm'.
# NOTE(review): presumably this is what produced the 'm.model' file loaded in
# the next section — confirm that file is available before running from a
# fresh checkout.
'''
%%time
spm.SentencePieceTrainer.train(input='data/sp_texts.txt',
                               model_prefix='m',
                               vocab_size=10000)
'''

"\n%%time\nspm.SentencePieceTrainer.train(input='data/sp_texts.txt',\n                               model_prefix='m',\n                               vocab_size=10000)\n"

### 2. dataset and dataloaders

In [6]:
# Shared tokenizer, loaded from the relative path 'm.model'. The Dataset class
# defined below reads this module-level `sp` to encode both articles and
# highlights.
sp = spm.SentencePieceProcessor(model_file='m.model')

In [7]:
class Dataset(TorchDataset):
    """Paired (article, highlight) texts, tokenized lazily on access.

    Args:
        articles: sequence of article strings.
        highlights: sequence of highlight strings, aligned with ``articles``
            by index.
        tokenizer: object with an ``encode(text)`` method returning token ids.
            When ``None`` (the default) the module-level ``sp`` processor is
            looked up at access time, which is the original behaviour.
        bos_id: id prepended to every encoded sequence (default 1).
        eos_id: id appended to every encoded sequence (default 2).

    NOTE(review): 1 and 2 are presumably the <s> / </s> ids of the loaded
    SentencePiece model — confirm against m.model.
    """

    def __init__(self, articles, highlights, tokenizer=None, bos_id=1, eos_id=2):
        self.articles = articles
        self.highlights = highlights
        self.tokenizer = tokenizer
        self.bos_id = bos_id
        self.eos_id = eos_id

    def _encode(self, text):
        """Return ``[bos_id] + token ids + [eos_id]`` as a 1-D long tensor."""
        # Resolve the global lazily so the class can be defined (and tested)
        # without a SentencePiece model when an explicit tokenizer is given.
        tokenizer = self.tokenizer if self.tokenizer is not None else sp
        ids = [self.bos_id] + list(tokenizer.encode(text)) + [self.eos_id]
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, index):
        """Return the ``(article_ids, highlight_ids)`` tensors for ``index``."""
        return (self._encode(self.articles[index]),
                self._encode(self.highlights[index]))

    def __len__(self):
        """Number of pairs, taken from the length of ``articles``."""
        return len(self.articles)

In [8]:
def pad_tensor(vec, length, dim, pad_symbol):
    """Right-pad ``vec`` along ``dim`` with ``pad_symbol`` up to ``length``.

    Args:
        vec: tensor to pad (any rank).
        length: target size of dimension ``dim``.
        dim: dimension along which padding is appended.
        pad_symbol: fill value for the padded positions.

    Returns:
        A new tensor whose size along ``dim`` is ``length``; all other
        dimensions, the dtype and the device are those of ``vec``.

    Raises:
        RuntimeError: if ``vec`` is already longer than ``length`` along
            ``dim`` (the padding size would be negative).
    """
    pad_shape = list(vec.shape)
    pad_shape[dim] = length - vec.shape[dim]
    # Build the padding from `vec` itself so that it has the same rank, dtype
    # and device; a flat 1-D long tensor can only be concatenated to 1-D input.
    padding = vec.new_full(pad_shape, pad_symbol)
    return torch.cat([vec, padding], dim=dim)

class Padder:
    """Collate callable that pads a batch of (article, highlight) pairs.

    Articles are padded to the longest article in the batch and highlights to
    the longest highlight, both with ``pad_symbol`` along ``dim``. Each group
    is stacked and its two axes are swapped before being returned.
    """

    def __init__(self, dim=0, pad_symbol=0):
        self.dim, self.pad_symbol = dim, pad_symbol

    def __call__(self, batch):
        # Longest article / highlight in this batch, measured along self.dim.
        article_len = max(pair[0].shape[self.dim] for pair in batch)
        highlight_len = max(pair[1].shape[self.dim] for pair in batch)

        padded_articles = []
        padded_highlights = []
        for pair in batch:
            padded_articles.append(
                pad_tensor(pair[0], article_len, self.dim, self.pad_symbol))
            padded_highlights.append(
                pad_tensor(pair[1], highlight_len, self.dim, self.pad_symbol))

        xs = torch.stack(padded_articles)
        ys = torch.stack(padded_highlights)
        # Swap the stacked (item, position) axes.
        return xs.permute(1, 0), ys.permute(1, 0)

In [9]:
# Wrap the raw text lists in tokenizing datasets. This rebinds train_dataset
# and val_dataset, which until now referred to the loaded splits.
train_dataset, val_dataset = (
    Dataset(train_articles, train_highlights),
    Dataset(val_articles, val_highlights),
)

# Batches of 10 pairs, padded per batch by Padder; only the training loader
# shuffles.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=10,
    shuffle=True,
    collate_fn=Padder(),
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=10,
    shuffle=False,
    collate_fn=Padder(),
)

In [10]:
# Sanity check: show the sizes of the padded article and highlight tensors of
# the first training batch, one per line, then stop iterating.
for batch in train_dataloader:
    print(*(batch[position].size() for position in (0, 1)), sep='\n')
    break

torch.Size([1736, 10])
torch.Size([94, 10])


### 3. model architecture

In [11]:
# Device handed to Seq2Seq and used to place the generator; everything in this
# notebook runs on the CPU.
device = 'cpu'

In [12]:
# Number of token ids the models must cover; the same value (10000) is passed
# as vocab_size in the SentencePiece training cell.
vocab_size = 10000

In [19]:
# Generator (Seq2Seq) hyper-parameters.
# Source and target share one tokenizer, so both vocabularies have the same size.
INPUT_DIM = vocab_size
OUTPUT_DIM = vocab_size
ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
HID_DIM = 32
N_LAYERS = 1
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1
# NOTE(review): LR and decayRate are not referenced by any cell shown in this
# notebook (the pretraining cells set their own lr = 0.01) — confirm whether
# they are still needed.
LR = 5e-4
decayRate = 0.97

# Encoder / Decoder come from models.seq2seq_generator; the meaning of the
# positional arguments is assumed from the constant names — TODO confirm.
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)

# G is the generator used by the pretraining and training cells.
G = Seq2Seq(enc, dec, device).to(device)

In [20]:
# Discriminator (CNN) hyper-parameters.
# NOTE(review): hidden_dim and max_seq_len are not used by any cell shown in
# this notebook, and vocab_size is re-assigned to the value it already has.
emb_dim = 10
hidden_dim = 20
vocab_size = 10000
max_seq_len = 15
# Same id (0) that Padder uses by default to fill short sequences.
padding_idx = 0
# D is the discriminator used by the pretraining cell.
# NOTE(review): unlike G, D is not moved to `device` here.
D = CNNDiscriminator(embed_dim=emb_dim,
                     vocab_size=vocab_size,
                     filter_sizes=[2, 3],
                     num_filters=[2, 2],
                     padding_idx=padding_idx)

### 4. generator pretraining

In [21]:
# Maximum-likelihood pretraining of the generator G on (article, highlight)
# batches from train_dataloader.

# Padder fills short sequences with id 0. Ignoring that id keeps the loss an
# average over real tokens instead of being dominated by padding.
criterion = nn.CrossEntropyLoss(ignore_index=0)
lr = 0.01
opt = torch.optim.Adam(G.parameters(), lr=lr)
n_epochs = 10

for epoch_idx in range(n_epochs):
    losses = []
    # Wrap the dataloader itself (not the enumerate object) so tqdm can read
    # its length and display progress against the total number of batches.
    for batch_idx, data_input in enumerate(tqdm(train_dataloader)):
        # G lives on `device`, so its inputs must be placed there too.
        article = data_input[0].to(device)
        highlight = data_input[1].to(device)

        opt.zero_grad()

        # NOTE(review): teacher forcing is switched off (ratio 0.) even though
        # this is supervised pretraining — confirm that this is intended.
        out = G(article, highlight, teacher_forcing_ratio=0.)

        # CrossEntropyLoss wants the class axis second and the batch axis
        # first, hence the permutes. Assumes `out` is
        # (seq_len, batch, vocab) — TODO confirm against Seq2Seq.
        loss = criterion(out.permute(1, 2, 0), highlight.permute(1, 0))
        loss.backward()
        opt.step()

        losses.append(loss.item())

        # Per-batch logging only on every 10th epoch.
        if (epoch_idx + 1) % 10 == 0:
            print(f"Training Steps Completed: {batch_idx}, loss: {loss.item()}")

    # Report the collected batch losses once per epoch.
    if losses:
        print(f"Epoch {epoch_idx + 1}/{n_epochs}, mean loss: {sum(losses) / len(losses)}")

3it [00:16,  5.37s/it]


KeyboardInterrupt: 

### 5. discriminator pretraining

In [22]:
# Pretraining of the discriminator D: each step classifies a batch made of
# generator outputs (label 0) followed by the reference highlights (label 1).
criterion = nn.CrossEntropyLoss()
lr = 0.01
opt = torch.optim.Adam(D.parameters(), lr=lr)
n_epochs = 10

for epoch_idx in range(n_epochs):
    losses = []
    # Wrap the dataloader itself (not the enumerate object) so tqdm can read
    # its length and display progress against the total number of batches.
    for batch_idx, data_input in enumerate(tqdm(train_dataloader)):
        article = data_input[0]
        highlight = data_input[1]

        # G is only used to produce samples here and is not updated by `opt`,
        # so its forward pass does not need an autograd graph.
        with torch.no_grad():
            generated_highlight = G(article, highlight, teacher_forcing_ratio=0.)
        # Greedy token choice. Softmax is monotonic, so argmax over the raw
        # scores selects the same ids without the extra normalisation.
        generated_highlight = torch.argmax(generated_highlight, dim=2).permute(1, 0)
        highlight = highlight.permute(1, 0)

        # Fake rows first, real rows second; `targets` follows the same order.
        batch = torch.cat([generated_highlight, highlight], dim=0)
        targets = torch.tensor([0]*generated_highlight.size(0) + [1]*highlight.size(0))

        opt.zero_grad()

        out = D(batch)

        loss = criterion(out, targets)
        loss.backward()
        opt.step()

        losses.append(loss.item())

        # Per-batch logging only on every 10th epoch.
        if (epoch_idx + 1) % 10 == 0:
            print(f"Training Steps Completed: {batch_idx}, loss: {loss.item()}")

    # Report the collected batch losses once per epoch.
    if losses:
        print(f"Epoch {epoch_idx + 1}/{n_epochs}, mean loss: {sum(losses) / len(losses)}")

8it [00:08,  1.02s/it]


KeyboardInterrupt: 

### 6. training

In [26]:
# Number of passes over train_dataloader made by the training loop in the next cell.
n_epochs = 10

In [None]:
# FIXME(review): this cell cannot run as written. Only the first statements of
# the inner loop use this notebook's G and dataloader; the rest looks like the
# body of a label-conditioned image GAN loop (784-value inputs, 28x28 plots)
# and relies on names that no visible cell defines: batch_size, discriminator,
# discriminator_optimizer, generator, generator_optimizer and plt.
# `noise` and `fake_labels` are only assigned inside the logging branch near
# the bottom, after their first use. `loss` is bound to a tensor by the
# pretraining cells, so calling it like a function would fail as well.
# TODO: rewrite the discriminator / generator steps around G and D.
for epoch_idx in range(n_epochs):
    # Per-epoch histories of the two losses, averaged in the summary print.
    G_loss = []
    D_loss = []
    for batch_idx, data_input in enumerate(train_dataloader):
        article = data_input[0]
        highlight = data_input[1]
        
        # Greedy decode of G's output into token ids, transposed with permute(1, 0).
        generated_data = G(article, highlight, teacher_forcing_ratio=0.)
        generated_data = torch.argmax(F.softmax(generated_data, dim=2), dim=2).permute(1, 0)

        # FIXME(review): data_input[0] is the padded article batch here, not a
        # 784-value image batch, and batch_size is undefined.
        true_data = data_input[0].view(batch_size, 784).to(device) # batch_size X 784
        digit_labels = data_input[1].to(device) # batch_size
        true_labels = torch.ones(batch_size).to(device)
        
        # ---- Discriminator step (undefined names, see the note at the top) ----
        discriminator_optimizer.zero_grad()

        discriminator_output_for_true_data = discriminator(true_data, digit_labels).view(batch_size)
        true_discriminator_loss = loss(discriminator_output_for_true_data, true_labels)

        # FIXME(review): fake_labels has not been assigned at this point.
        discriminator_output_for_generated_data = discriminator(generated_data.detach(), fake_labels).view(batch_size)
        generator_discriminator_loss = loss(
            discriminator_output_for_generated_data, torch.zeros(batch_size).to(device)
        )
        # Average of the loss on real data and the loss on generated data.
        discriminator_loss = (
            true_discriminator_loss + generator_discriminator_loss
        ) / 2
        
        discriminator_loss.backward()
        discriminator_optimizer.step()

        D_loss.append(discriminator_loss.data.item())
        
        
        # Generator

        generator_optimizer.zero_grad()
        # It's a choice to generate the data again
        # FIXME(review): `generator`, `noise` and `fake_labels` are undefined
        # here; the generator defined in this notebook is G.
        generated_data = generator(noise, fake_labels) # batch_size X 784
        discriminator_output_on_generated_data = discriminator(generated_data, fake_labels).view(batch_size)
        generator_loss = loss(discriminator_output_on_generated_data, true_labels)
        generator_loss.backward()
        generator_optimizer.step()
        
        G_loss.append(generator_loss.data.item())
        # Sample preview on every 500th batch of every 10th epoch.
        if ((batch_idx + 1)% 500 == 0 and (epoch_idx + 1)%10 == 0):
            print("Training Steps Completed: ", batch_idx)
            
            with torch.no_grad():
                noise = torch.randn(batch_size,100).to(device)
                fake_labels = torch.randint(0, 10, (batch_size,)).to(device)
                generated_data = generator(noise, fake_labels).cpu().view(batch_size, 28, 28)
                # Only the first generated sample is shown (the loop breaks
                # after one iteration). FIXME(review): plt is never imported.
                for x in generated_data:
                    print(fake_labels[0].item())
                    plt.imshow(x.detach().numpy(), interpolation='nearest',cmap='gray')
                    plt.show()

                    break


    # Epoch summary: mean discriminator and generator loss over the epoch.
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch_idx), n_epochs, torch.mean(torch.FloatTensor(D_loss)), torch.mean(torch.FloatTensor(G_loss))))