In [1]:
import sys
import torch
import os

In [2]:
# Move from the notebook's directory to the project root so that relative paths
# ("data/...") and the `src` package imports resolve. Guarded so re-running the
# cell does not keep walking up the directory tree.
if not os.path.isdir("src"):
    os.chdir("..")

In [3]:
# Sanity check: list the current working directory (should be the project root).
!ls

README.md			       tur-baseline-211299-prose.sh.o7747867
checkpoints			       tur-baseline-211299.sh.o7747866
data				       tur-baseline-214186-prose.sh.o7747915
logs				       tur-baseline-214186.sh.o7747914
news_translations		       tur-baseline-217071-prose.sh.o7748183
notebooks			       tur-baseline-217071.sh.o7748182
out.txt				       tur-baseline-219958-prose.sh.o7748221
prose_translations		       tur-baseline-219958.sh.o7748220
requirements.txt		       tur-baseline-22210.sh.o7737186
runs				       tur-baseline-222814-prose.sh.o7748280
scripts				       tur-baseline-222814.sh.o7748279
src				       tur-baseline-225645-prose.sh.o7748341
translations			       tur-baseline-225645.sh.o7748340
tur-baseline-1.sh.o7735992	       tur-baseline-228485-prose.sh.o7748379
tur-baseline-100647.sh.o7740203        tur-baseline-228485.sh.o7748378
tur-baseline-103479.sh.o7740349        tur-baseline-231331-prose.sh.o7749136
tur-baseline-106300.sh.o7740423        tur-baseline-231331.

In [4]:
# Use the first GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
from src.data_utils.batch import rebatch
from src.data_utils.data import get_training_iterators
from src.model.loss_optim import MultiGPULossCompute, SimpleLossCompute
from src.model.model import make_model, NoamOpt, LabelSmoothing, translate_sentence
from src.utils.utils import get_tokenizer

In [6]:
# Turkish tokenizer; later cells call its Encode / Decode / EncodeAsIds methods.
tok = get_tokenizer("tr")

In [7]:
# Poetry train/dev/test iterators. batch_size=3000 is presumably a token budget
# rather than a sentence count — TODO confirm in get_training_iterators.
# NOTE(review): train_idx / dev_idx / test_idx are not used in the cells shown here.
train_iter, valid_iter, test_iter, train_idx, dev_idx, test_idx = get_training_iterators("tur", batch_size=3000)



In [8]:
# Mini dev set: the first 20 target lines of the dev split. They are greedily
# decoded during training to eyeball progress.
from itertools import islice

with open("data/tr/tur.dev.tgt", encoding="utf-8") as infile:
    # islice stops after 20 lines instead of reading the whole file into memory first.
    toystrings = [x.strip() for x in islice(infile, 20)]

In [9]:
# Frame each toy line as [1] + ids + [2] (init / eos ids, as in make_iter's Field)
# and pad with 3. pad_sequence defaults to batch_first=False, so the result is
# (max_len, batch).
toyset = torch.nn.utils.rnn.pad_sequence(
    sequences=[torch.LongTensor([1, *tok.Encode(x), 2]) for x in toystrings],
    padding_value=3,
)

In [10]:
# Peek at the padded toy batch (rows are positions, columns are examples).
toyset

tensor([[    1,     1,     1,  ...,     1,     1,     1],
        [ 5605,     8,  1330,  ...,     8,   771,  2804],
        [27861,  2475, 10284,  ...,  3987,  5057, 11694],
        ...,
        [    3,     3,     3,  ...,     3,     3,     3],
        [    3,     3,     3,  ...,     3,     3,     3],
        [    3,     3,     3,  ...,     3,     3,     3]])

Two critics:
- Relevance critic: is the output related to the input?
- Style critic: classifies text as poetry, prose, generated, or scrambled poetry.

One word/token selector:
- Chooses tokens from the input sequence to use as the topic.

In [11]:
from torchtext import data
import torchtext as tt
from src.data_utils.batch import MyIterator
from src.model.model import batch_size_val

def each_line(fname, min_spaces=10, max_spaces=200, max_lines=2000000):
    """Read up to ``max_lines`` stripped lines of moderate length from a UTF-8 file.

    A line is kept only if its raw (unstripped) text contains between
    ``min_spaces`` and ``max_spaces`` space characters, inclusive. This is a
    cheap proxy for word count that drops near-empty and runaway lines.

    Args:
        fname: path of the text file to read.
        min_spaces: minimum number of " " characters a kept line must contain.
        max_spaces: maximum number of " " characters a kept line may contain.
        max_lines: stop reading once this many lines have been kept.

    Returns:
        list[str]: the kept lines, with surrounding whitespace stripped.
    """
    lines = []
    if max_lines <= 0:
        return lines
    with open(fname, "r", encoding="utf-8") as infile:
        for line in infile:
            n_spaces = line.count(" ")
            if not min_spaces <= n_spaces <= max_spaces:
                continue
            lines.append(line.strip())
            if len(lines) >= max_lines:
                break
    return lines

def make_iter(lines, tokenizer, batch_size=256):
    """Build a sorted, non-repeating CPU batch iterator over raw text lines.

    Each line becomes a torchtext ``Example`` with a single ``src`` field. The
    field's tokenizer maps text straight to ids via ``tokenizer.EncodeAsIds``
    (``use_vocab=False``), with init id 1, eos id 2 and pad id 3.

    Args:
        lines: iterable of raw strings.
        tokenizer: object exposing ``EncodeAsIds(str)``.
        batch_size: budget passed to ``MyIterator`` together with
            ``batch_size_val`` (presumably tokens per batch — TODO confirm).

    Returns:
        The ``MyIterator`` instance (sorted by ``len(example.src)``, ``train=False``).
    """
    def encode_ids(text):
        return tokenizer.EncodeAsIds(text)

    src_field = data.Field(tokenize=encode_ids, init_token=1, eos_token=2,
                           pad_token=3, use_vocab=False)
    field_map = {"src": ("src", src_field)}
    examples = [tt.data.Example.fromdict({"src": line}, field_map) for line in lines]
    dataset = tt.data.Dataset(examples, {"src": src_field})
    return MyIterator(dataset, batch_size=batch_size, device="cpu", repeat=False,
                      sort_key=lambda example: len(example.src),
                      batch_size_fn=batch_size_val, train=False, sort=True)




In [None]:
# Batches of prose lines, used as the "prose" class for the style critic.
prose_iter = make_iter(each_line("data/tr/prose/prose_gan.txt"), tokenizer=tok, batch_size=3000)




In [None]:
import random

# Build "scrambled poetry" examples by shuffling the "¬"-separated segments of
# each training poem. A dedicated, seeded RNG makes these negatives reproducible
# across kernel restarts without touching the global `random` state.
SCRAMBLE_SEED = 1234
_scramble_rng = random.Random(SCRAMBLE_SEED)

to_scramble = each_line("data/tr/tur.train.tgt")
scrambled = []
for poem in to_scramble:
    new_poem = poem.split("¬")
    _scramble_rng.shuffle(new_poem)
    scrambled.append("¬".join(new_poem))

In [None]:
# Batches of line-shuffled poems (the "scrambled poetry" class).
scrambled_iter = make_iter(lines=scrambled, tokenizer=tok, batch_size=3000)

In [None]:
import copy
from src.model.model import MultiHeadedAttention, PositionwiseFeedForward, \
                    PositionalEncoding, Encoder, EncoderLayer, Generator, Embeddings
import torch.nn as nn

class Critic(nn.Module):
    """Encoder-only transformer used as a token or sequence critic.

    Holds an embedding stack (``src_embed``), an encoder whose ``layers`` and
    final ``norm`` are applied manually in ``forward``, and a ``generator``
    head that callers apply to the encoded states themselves.
    """

    def __init__(self, encoder, src_embed, generator):
        super().__init__()
        # Registration order is kept as-is: it fixes the order of parameters().
        self.encoder = encoder
        self.src_embed = src_embed
        self.generator = generator
        # NOTE(review): step counter; not updated in the cells shown here.
        self.steps = 0

    def forward(self, x, mask):
        """Embed ``x``, pass it through every encoder layer with ``mask``, then normalise.

        Returns the normalised hidden states; apply ``self.generator`` to them
        to obtain the critic's scores.
        """
        hidden = self.src_embed(x)
        for encoder_layer in self.encoder.layers:
            hidden = encoder_layer(hidden, mask)
        return self.encoder.norm(hidden)


def make_critic(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Build an encoder-only ``Critic`` from hyper-parameters.

    Args:
        src_vocab: size passed to ``Embeddings`` (presumably the number of input
            token ids — TODO confirm).
        tgt_vocab: output size passed to the ``Generator`` head, i.e. the number
            of scores the critic emits (2, 4 and 1 in this notebook).
        N: passed to ``Encoder`` (presumably the number of stacked layers).
        d_model: hidden size used by every sub-module.
        d_ff: inner size of the position-wise feed-forward blocks.
        h: number of attention heads.
        dropout: dropout rate given to the feed-forward, positional-encoding
            and encoder-layer modules.

    Returns:
        Critic whose weight matrices are Glorot-uniform initialised.
    """
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    generator = Generator(d_model, tgt_vocab)
    embed = nn.Sequential(Embeddings(d_model, src_vocab), c(position))
    encoder = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)
    critic = Critic(encoder, embed, generator)

    # Glorot / fan_avg init for every weight matrix (dim > 1); 1-D parameters keep
    # their defaults. xavier_uniform_ is the in-place, non-deprecated form of
    # nn.init.xavier_uniform.
    for p in critic.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return critic


In [None]:
# Vocabulary size shared by all models (presumably the tokenizer's vocab size — TODO confirm).
ntokens = 32000
# Main encoder-decoder (N=6).
enc_dec = make_model(ntokens, ntokens, N=6).to(device)
# 2-score head per token; get_dae_input keeps tokens whose argmax index is 1.
token_selector = make_critic(ntokens, 2, N=2).to(device)
# 4-score head: poetry / prose / generated / scrambled poetry, per the design notes.
style_critic = make_critic(ntokens, 4, N=2).to(device)
# One extra input id for the separator `ntokens` that get_relevance_input inserts; 1 output score.
relevance_critic = make_critic(ntokens + 1, 1, N=2).to(device)

In [None]:
from torch.autograd import Variable
import numpy as np
def subsequent_mask(size):
    """Return a (1, size, size) bool causal mask.

    Entry [0, i, j] is True when position i may attend to position j, i.e. when
    j <= i. Future positions are False.
    """
    positions = torch.arange(size)
    allowed = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return allowed.unsqueeze(0)


def prep_tensors(src, trg, pad=3):
    """Split a batch-first target into decoder input/output and build the masks.

    Returns ``(src, trg_y, src_mask, trg_mask)``:
      - ``trg_y`` is ``trg[:, 1:]`` (the tokens to predict);
      - ``src_mask`` marks non-``pad`` source positions, with a singleton dim at -2;
      - ``trg_mask`` comes from ``make_std_mask`` applied to the decoder input ``trg[:, :-1]``.
    """
    decoder_in, decoder_out = trg[:, :-1], trg[:, 1:]
    return src, decoder_out, (src != pad).unsqueeze(-2), make_std_mask(decoder_in, pad)

def make_std_mask(tgt, pad):
    """Create a mask to hide padding and future words.

    ANDs a padding mask ``(tgt != pad).unsqueeze(-2)`` with the causal mask
    from ``subsequent_mask`` (cast to the padding mask's type). For a
    ``(batch, L)`` ``tgt`` this broadcasts to ``(batch, L, L)``.
    """
    not_pad = (tgt != pad).unsqueeze(-2)
    causal = subsequent_mask(tgt.size(-1)).type_as(not_pad)
    return not_pad & causal


In [None]:
def get_dae_input(tgt, token_selector, max_tokens=15, pad_idx=3,
                  batch_first=False, device=None):
    """Select tokens from ``tgt`` with ``token_selector`` and pad them into a batch.

    The selector scores every token; a token is kept when the argmax of its
    scores is non-zero (index 1 = "keep" for the 2-score selector). At most
    ``max_tokens`` kept tokens per row are used, in their original order.

    Args:
        tgt: ``(batch, len)`` integer tensor, padded with ``pad_idx``.
        token_selector: module with ``forward(x, mask)`` and ``generator`` whose
            output has the scores on dim 2.
        max_tokens: maximum number of selected tokens kept per row.
        pad_idx: padding id, used for the selector's mask and the output padding.
        batch_first: layout of the result. The default ``False`` keeps the
            original ``(max_selected, batch)`` layout. NOTE(review):
            get_relevance_input concatenates along dim 1 with batch-first
            tensors, so it needs ``batch_first=True``.
        device: device for the selector's input; defaults to the device of the
            selector's parameters.

    Returns:
        Padded tensor of selected ids, on ``tgt``'s device.
    """
    if device is None:
        device = next(token_selector.parameters()).device
    select_prob_embeds = token_selector.forward(tgt.to(device),
                                                (tgt != pad_idx).unsqueeze(-2).to(device))
    select_prob = token_selector.generator(select_prob_embeds)
    # Bool mask (uint8 masks are deprecated for masked_select). It is no longer
    # forced onto the CPU, which broke CUDA inputs with a device mismatch.
    keep = torch.max(select_prob, dim=2).indices.bool()
    dae_list = [torch.masked_select(row, row_keep.to(row.device))[:max_tokens]
                for row_keep, row in zip(keep, tgt)]
    return torch.nn.utils.rnn.pad_sequence(dae_list, batch_first=batch_first,
                                           padding_value=pad_idx)

In [None]:
# Lazily rebatch the training iterator (first argument 3 matches the pad id used
# elsewhere — TODO confirm rebatch's signature). This is a one-shot generator:
# after the training loop consumes it, re-run this cell before training again.
rebatched = (rebatch(3, b) for b in train_iter)

In [None]:
# NOTE(review): stray scratch cell; it only echoes the torch module object.
torch


In [None]:
from src.model.adafactor import Adafactor

# Adafactor for every trainable module (the earlier, commented-out NoamOpt + Adam
# schedule for enc_dec was removed as dead code).
enc_dec_opt = Adafactor(enc_dec.parameters())

style_criterion = nn.BCELoss()
# BCEWithLogitsLoss, matching the later redefinition of relevance_criterion, so
# the final criterion no longer depends on which cell ran last.
relevance_criterion = nn.BCEWithLogitsLoss()

token_optim = Adafactor(token_selector.parameters())
style_optim = Adafactor(style_critic.parameters())
rel_optim = Adafactor(relevance_critic.parameters())

In [None]:
# Relevance loss on raw scores: BCEWithLogitsLoss applies the sigmoid itself.
relevance_criterion = torch.nn.BCEWithLogitsLoss()

In [None]:
def get_relevance_input(dae_input, tgt, sep_token=None):
    """Concatenate ``dae_input``, a separator column and ``tgt`` along dim 1.

    Both inputs must be batch-first ``(batch, len)`` tensors with the same batch
    size and device. The separator defaults to the global ``ntokens``, one id past
    the shared vocabulary; the relevance critic is built with ``ntokens + 1``
    input ids to accommodate it.

    Args:
        dae_input: ``(batch, len_a)`` selected-token ids.
        tgt: ``(batch, len_b)`` target ids.
        sep_token: separator id; ``None`` means ``ntokens``.

    Returns:
        ``(batch, len_a + 1 + len_b)`` tensor.
    """
    if sep_token is None:
        sep_token = ntokens
    # Build the separator directly on tgt's device (no hidden global `device`).
    sep = tgt.new_full((tgt.shape[0], 1), sep_token, dtype=torch.long)
    return torch.cat((dae_input, sep, tgt), dim=1)


In [None]:
# Gradient-accumulation factor. NOTE(review): not referenced by the training loop in this notebook.
accumulation_steps = 8

In [None]:
# Greedy decoding for validation. Unlike greedy_generate, rows keep emitting
# tokens after their own EOS until every row has produced one.


def validate_batch(model, src, max_len=256, start_symbol=1, end_symbol=2, pad=3):
    """Greedy-decode a batch-first ``src`` batch with ``model``.

    Args:
        model: encoder-decoder exposing ``encode(src, src_mask)``,
            ``decode(memory, src_mask, ys, tgt_mask)`` and ``generator``.
        src: ``(batch, src_len)`` integer tensor padded with ``pad``.
        max_len: maximum output length, including the start symbol.
        start_symbol: id placed in the first output position.
        end_symbol: id marking the end of a sequence.
        pad: padding id of ``src``, used to build the source mask.

    Returns:
        ``(batch, out_len)`` tensor of ids starting with ``start_symbol``. The
        tokens of the step in which the last row emits ``end_symbol`` are
        included; previously that step was dropped.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    src = src.to(device)
    # Computed after the move so encode() and decode() both get a same-device mask.
    src_mask = (src != pad).unsqueeze(-2)
    memory = model.encode(src, src_mask)
    ys = torch.full((src.shape[0], 1), start_symbol, dtype=src.dtype, device=device)
    finished = torch.zeros(src.shape[0], dtype=torch.bool, device=device)
    for _ in range(max_len - 1):
        prefix_len = ys.size(1)
        # Causal mask (same values as subsequent_mask(...).type_as(src)) built on-device.
        tgt_mask = torch.tril(torch.ones(1, prefix_len, prefix_len, dtype=src.dtype, device=device))
        out = model.decode(memory, src_mask, ys, tgt_mask)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        finished |= next_word == end_symbol
        if bool(finished.all()):
            break
    return ys


def validate(model, selector, iterator):
    """Placeholder for a full validation pass.

    TODO: not implemented. Currently does nothing and returns None.
    """
    return None

In [None]:

# Label-smoothed reconstruction loss over the shared vocabulary (pad id 3 ignored).
# Uses ntokens instead of a repeated 32000 literal so the two cannot drift apart.
label_smoothing = LabelSmoothing(size=ntokens, padding_idx=3, smoothing=0.1)

In [None]:
def greedy_generate(model, src, max_len=256, start_symbol=1, end_symbol=2, pad=3):
    """Greedy-decode a batch-first ``src`` batch, padding each row after its EOS.

    Args:
        model: encoder-decoder exposing ``encode(src, src_mask)``,
            ``decode(memory, src_mask, ys, tgt_mask)`` and ``generator``.
        src: ``(batch, src_len)`` integer tensor padded with ``pad``.
        max_len: maximum output length, including the start symbol.
        start_symbol: id placed in the first output position.
        end_symbol: id marking the end of a sequence.
        pad: padding id; used for the source mask and written after a row's EOS.

    Returns:
        ``(batch, out_len)`` tensor of ids starting with ``start_symbol``.
        Decoding stops once every row has emitted ``end_symbol``, and that
        final step's tokens are included; previously that step was dropped.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    src = src.to(device)
    # Computed after the move so decode() receives a mask on the same device as memory.
    src_mask = (src != pad).unsqueeze(-2)
    memory = model.encode(src, src_mask)
    ys = torch.full((src.shape[0], 1), start_symbol, dtype=src.dtype, device=device)
    finished = torch.zeros(src.shape[0], dtype=torch.bool, device=device)
    for _ in range(max_len - 1):
        prefix_len = ys.size(1)
        # Causal mask (same values as subsequent_mask(...).type_as(src)) built on-device.
        tgt_mask = torch.tril(torch.ones(1, prefix_len, prefix_len, dtype=src.dtype, device=device))
        out = model.decode(memory, src_mask, ys, tgt_mask)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        # Rows that already produced EOS only receive padding from now on.
        next_word = next_word.masked_fill(finished, pad)
        ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        finished |= next_word == end_symbol
        if bool(finished.all()):
            break
    return ys

In [None]:
import time
from src.utils.rhyme import critique_poem

In [None]:
# NOTE(review): `soft` is not used by the loop below.
soft = torch.nn.Softmax(dim=1)
start = time.time()
all_tokens = 0
# NOTE(review): prose_batch / scrambled_batch are not used yet, but zip() still
# ends the epoch at the shortest of the three iterators.
for c, (poetry_batch, prose_batch, scrambled_batch) in enumerate(zip(rebatched, prose_iter, scrambled_iter)):
    all_tokens += poetry_batch.ntokens
    tgt, tgt_mask = poetry_batch.trg.to(device), poetry_batch.trg_mask.to(device)
    src = poetry_batch.src.to(device)
    src_mask = (src != 3).unsqueeze(-2)

    # Teacher-forced reconstruction through the encoder-decoder.
    output_embeds = enc_dec.forward(src, tgt, src_mask, tgt_mask)
    output = enc_dec.generator(output_embeds)
    reconstruction_loss = label_smoothing(output.contiguous().view(-1, output.size(-1)),
                                          poetry_batch.trg_y.to(device).contiguous().view(-1)) / poetry_batch.ntokens

    # Rhyme reward on the argmax tokens. Assumes critique_poem(...)[0] is a plain
    # number (TODO confirm): it only rescales the loss, and no gradient flows through it.
    _, output_selected = torch.max(output, 2)
    scores = [critique_poem(tok.Decode(x.tolist()), "tr", redif=True)
              for x in output_selected]
    rhyme_score = sum(x[0] for x in scores) / len(scores)
    # Out-of-place form of `loss += loss * (1 - rhyme_score)`.
    reconstruction_loss = reconstruction_loss + reconstruction_loss * (1 - rhyme_score)

    reconstruction_loss.backward()
    # token_optim.step() was removed: token_selector is not part of this loss, so
    # it never receives gradients here.
    enc_dec_opt.step()
    enc_dec_opt.zero_grad()

    if c % 100 == 0:
        print("Reconstruction loss:", reconstruction_loss)
        print(all_tokens / (time.time() - start), "tokens processed per second.")
        if c % 500 == 0:
            # Sample without dropout and without building an autograd graph,
            # then restore the previous training mode.
            was_training = enc_dec.training
            enc_dec.eval()
            with torch.no_grad():
                validated = greedy_generate(enc_dec, toyset.transpose(0, 1))
            enc_dec.train(was_training)
            print([tok.Decode(x.tolist()) for x in validated])
            print(scores)
        

In [None]:
# Scratch inspection of the padded toy batch.
toyset

In [None]:
# Scratch inspection: the last `src` batch left over from the training loop.
src