In [1]:
import sys
import torch
import os

In [2]:
# Move from the notebook's directory (presumably notebooks/) to the project root
# so the `src.*` imports and relative `data/...` paths below resolve.
# Guarded so re-running this cell is a no-op instead of climbing one more
# directory on every run.
if not os.path.isdir("src"):
    os.chdir("..")

In [3]:
# Sanity check: after changing directory, the listing should show the project root (src/, data/, ...).
!ls

README.md	  jupyter.o7280433    requirements.txt	venv
checkpoints	  jupyter.o7284871    runs		vt-tr.o7146959
data		  logs		      scripts		vt-tr.o7209762
jupyter.o7250076  notebooks	      src		vt-tr.o7242335
jupyter.o7272725  out.txt	      training
jupyter.o7277176  prose_translations  translations


In [4]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
from src.data_utils.batch import rebatch
from src.data_utils.data import get_training_iterators
from src.model.loss_optim import MultiGPULossCompute, SimpleLossCompute
from src.model.model import make_model, NoamOpt, LabelSmoothing, translate_sentence
from src.utils.utils import get_tokenizer

In [6]:
# Tokenizer for "tr"; used below via .Encode / .EncodeAsIds (presumably a SentencePiece processor — TODO confirm).
tok = get_tokenizer("tr")

In [7]:
# Iterators for the "tur" dataset. Of these six values, only train_iter is used in the cells below.
train_iter, valid_iter, test_iter, train_idx, dev_idx, test_idx = get_training_iterators("tur")



In [8]:
import itertools

# Mini dev set: the first N_TOY_LINES stripped lines of the target-side dev file.
# islice reads lazily, so the whole file is not loaded just to keep 20 lines.
N_TOY_LINES = 20
with open("data/tr/tur.dev.tgt", encoding="utf-8") as infile:
    toystrings = [x.strip() for x in itertools.islice(infile, N_TOY_LINES)]

In [9]:
# Encode each toy line, prepend id 1 and append id 2 (the same init/eos ids the
# torchtext field in make_iter uses), then pad with id 3 into a time-major
# (max_len, n_lines) tensor (pad_sequence defaults to batch_first=False).
encoded = []
for line in toystrings:
    encoded.append(torch.LongTensor([1] + tok.Encode(line) + [2]))
toyset = torch.nn.utils.rnn.pad_sequence(encoded, padding_value=3)

In [10]:
# Inspect the padded toy batch: time-major (max_len, number of toy lines); row 0 holds the init id 1, trailing rows the pad id 3.
toyset

tensor([[    1,     1,     1,  ...,     1,     1,     1],
        [ 5605,     8,  1330,  ...,     8,   771,  2804],
        [27861,  2475, 10284,  ...,  3987,  5057, 11694],
        ...,
        [    3,     3,     3,  ...,     3,     3,     3],
        [    3,     3,     3,  ...,     3,     3,     3],
        [    3,     3,     3,  ...,     3,     3,     3]])

Two critics:
- Relevance critic: is the output related to the input or not?
- Style critic: classify text as poetry, prose, generated, or scrambled poetry.

One word/token selector:
- Choose tokens from the input sequence to use as the topic.

In [11]:
from torchtext import data
import torchtext as tt
from src.data_utils.batch import MyIterator
from src.model.model import batch_size_val

def each_line(fname, max_spaces=200, max_lines=2000000):
    """Read stripped lines from a UTF-8 text file, skipping overly long ones.

    Args:
        fname: path of the text file to read.
        max_spaces: lines containing more than this many space characters
            (a cheap length proxy) are skipped and do not count toward
            ``max_lines``. Defaults to the previously hard-coded 200.
        max_lines: stop after this many lines have been kept; ``None`` means
            no limit and a value <= 0 returns an empty list. Defaults to the
            previously hard-coded 2,000,000.

    Returns:
        list[str]: the kept lines, stripped of surrounding whitespace, in file order.
    """
    lines = []
    with open(fname, "r", encoding="utf-8") as infile:
        for line in infile:
            if max_lines is not None and len(lines) >= max_lines:
                break
            # Count spaces on the raw line, before stripping (same as before).
            if line.count(" ") > max_spaces:
                continue
            lines.append(line.strip())
    return lines

def make_iter(lines, tokenizer, batch_size=256, device="cpu"):
    """Wrap raw text lines in a torchtext iterator of token-id batches.

    Each line is tokenized with ``tokenizer.EncodeAsIds`` and stored in a
    ``src`` field with init id 1, eos id 2 and pad id 3. ``use_vocab=False``
    means the ids are used as-is, with no vocabulary lookup.

    Args:
        lines: iterable of strings, one example per item.
        tokenizer: object with ``EncodeAsIds(str) -> list[int]`` (presumably a
            SentencePiece processor — TODO confirm).
        batch_size: passed to ``MyIterator`` together with ``batch_size_val``
            (its unit depends on that function, not necessarily examples).
        device: device the batches are placed on; defaults to the CPU as before.

    Returns:
        A ``MyIterator`` over the dataset, with ``repeat=False``, ``train=False``
        and ``sort=True`` using the token count of ``src`` as the sort key.
    """
    def encode(seq):
        return tokenizer.EncodeAsIds(seq)

    field = data.Field(tokenize=encode, init_token=1, eos_token=2, pad_token=3, use_vocab=False)
    fields = {"src": ("src", field)}

    examples = [data.Example.fromdict({"src": line}, fields) for line in lines]
    dataset = data.Dataset(examples, {"src": field})
    # Named `iterator` rather than `iter` so the builtin is not shadowed.
    iterator = MyIterator(dataset, batch_size=batch_size, device=device,
                          repeat=False, sort_key=lambda ex: len(ex.src),
                          batch_size_fn=batch_size_val, train=False, sort=True)
    return iterator




In [12]:
# Prose examples for the style critic's "prose" class.
prose_iter = make_iter(
    lines=each_line("data/tr/prose/prose_gan.txt"),
    tokenizer=tok,
)




In [13]:
import random

# "Scrambled poetry" examples: each training poem with its lines shuffled.
# Lines appear to be joined with "¬" in the .tgt files (the split/join relies on it).
# A dedicated, seeded RNG makes the scrambled set reproducible across re-runs
# without touching the global `random` state.
SCRAMBLE_SEED = 0
scramble_rng = random.Random(SCRAMBLE_SEED)

to_scramble = each_line("data/tr/tur.train.tgt")
scrambled = []
for poem in to_scramble:
    new_poem = poem.split("¬")
    scramble_rng.shuffle(new_poem)
    scrambled.append("¬".join(new_poem))

In [14]:
# Iterator over the line-shuffled poems, built with the same make_iter defaults as the prose iterator.
scrambled_iter = make_iter(lines=scrambled, tokenizer=tok)

In [15]:
import copy
from src.model.model import MultiHeadedAttention, PositionwiseFeedForward, \
                    PositionalEncoding, Encoder, EncoderLayer, Generator, Embeddings
import torch.nn as nn

class Critic(nn.Module):
    """Encoder-only scorer: embed token ids, then run the encoder stack.

    ``forward`` returns the final normalized encoder states. It does NOT apply
    ``generator``; callers project the states themselves, e.g.
    ``critic.generator(critic.forward(x, mask))``.
    """

    def __init__(self, encoder, src_embed, generator):
        super().__init__()
        # Registration order (encoder, src_embed, generator) fixes the order of
        # parameters(); keep it stable for seeded initialisation.
        self.encoder = encoder      # needs .layers (iterable) and .norm
        self.src_embed = src_embed  # token ids -> embeddings
        self.generator = generator  # projection head, applied by callers
        self.steps = 0              # NOTE(review): not read in the visible code

    def forward(self, x, mask):
        """Embed ``x``, apply every encoder layer with ``mask``, then the final norm."""
        hidden = self.src_embed(x)
        for encoder_layer in self.encoder.layers:
            hidden = encoder_layer(hidden, mask)
        return self.encoder.norm(hidden)


def make_critic(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Build a ``Critic``: an N-layer Transformer encoder plus a projection head.

    Args:
        src_vocab: number of input token ids covered by the embedding table.
        tgt_vocab: output size of the ``Generator`` head (number of classes).
        N: number of encoder layers.
        d_model: hidden size.
        d_ff: feed-forward inner size.
        h: number of attention heads.
        dropout: dropout rate for the feed-forward, positional and encoder layers.

    Returns:
        A ``Critic`` in which every parameter with more than one dimension is
        Glorot/Xavier-uniform initialised.
    """
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    generator = Generator(d_model, tgt_vocab)
    embed = nn.Sequential(Embeddings(d_model, src_vocab), c(position))
    encoder = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)
    critic = Critic(encoder, embed, generator)

    # Initialize weight matrices with Glorot / fan_avg.
    for p in critic.parameters():
        if p.dim() > 1:
            # In-place initialiser; the non-underscore nn.init.xavier_uniform
            # is deprecated and emits a warning.
            nn.init.xavier_uniform_(p)

    return critic


In [16]:
ntokens = 32000  # shared vocabulary size for every model below

# Generator: encoder-decoder over the shared vocabulary.
enc_dec = make_model(ntokens, ntokens, N=6).to(device)
# Token selector: 2 classes per position (argmax == 1 keeps the token; see get_dae_input).
token_selector = make_critic(src_vocab=ntokens, tgt_vocab=2, N=2).to(device)
# Style critic: 4 classes (poetry, prose, generated, scrambled poetry).
style_critic = make_critic(src_vocab=ntokens, tgt_vocab=4, N=2).to(device)
# Relevance critic: one extra input id (== ntokens) is the separator that
# get_relevance_input inserts; a single output score.
# NOTE(review): if Generator applies log_softmax over its outputs, a 1-class
# head is identically 0 and cannot learn — confirm Generator's definition.
relevance_critic = make_critic(src_vocab=ntokens + 1, tgt_vocab=1, N=2).to(device)

  nn.init.xavier_uniform(p)


In [17]:
from torch.autograd import Variable
import numpy as np
def subsequent_mask(size):
    """Causal mask of shape (1, size, size).

    Entry [0, i, j] is True when j <= i, i.e. position i may attend to j
    because j is not in its future.
    """
    future = torch.triu(torch.ones((1, size, size), dtype=torch.uint8), diagonal=1)
    return future == 0


def prep_tensors(src, trg, pad=3):
    """Shift a batch-first target and build the source/target masks.

    Returns ``(src, trg_y, src_mask, trg_mask)``:
    ``trg_y`` is ``trg`` without its first column (the labels), ``src_mask``
    hides ``pad`` positions of ``src``, and ``trg_mask`` masks padding and
    future positions of ``trg`` without its last column. The shifted
    decoder input itself is not returned.
    """
    decoder_in, labels = trg[:, :-1], trg[:, 1:]
    return src, labels, (src != pad).unsqueeze(-2), make_std_mask(decoder_in, pad)

def make_std_mask(tgt, pad):
    """Create a mask to hide padding and future words.

    Args:
        tgt: tensor of target ids whose last dimension is the time axis.
        pad: padding id.

    Returns:
        Bool mask combining "not padding" with the causal mask; for a
        (batch, seq) ``tgt`` the result broadcasts to (batch, seq, seq).
    """
    pad_mask = (tgt != pad).unsqueeze(-2)
    # type_as casts to pad_mask's tensor type, which also moves the CPU-built
    # causal mask to pad_mask's device. The old Variable wrapper was a no-op
    # (deprecated since PyTorch 0.4) and has been dropped.
    causal = subsequent_mask(tgt.size(-1)).type_as(pad_mask)
    return pad_mask & causal


NameError: name 'it' is not defined

In [19]:
def get_dae_input(tgt, token_selector, max_tokens=15, pad=3):
    """Build the generator's source sequence from tokens picked out of ``tgt``.

    The token selector scores every position of ``tgt``. Positions whose
    argmax class is 1 and that are not padding are kept in their original
    order, at most ``max_tokens`` per row, and the rows are re-padded.

    Args:
        tgt: tensor of token ids, iterated row by row (the training loop passes
            ``poetry_batch.trg``, assumed (batch, seq) — TODO confirm).
        token_selector: a ``Critic`` whose generator emits per-position class
            scores (built with 2 classes in this notebook).
        max_tokens: cap on kept tokens per row (previously hard-coded 15).
        pad: padding id; used for the selector mask, excluded from selection,
            and used to re-pad the output.

    Returns:
        Tensor of shape (max_kept, batch) with ``tgt``'s dtype and device. It is
        time-major because ``pad_sequence(batch_first=False)`` is used; callers
        transpose it.

    NOTE(review): selection goes through argmax indices, so no gradient reaches
    token_selector along this path. If no row keeps any token, the result
    has zero length.
    """
    selector_states = token_selector.forward(tgt.to(device), (tgt != pad).unsqueeze(-2).to(device))
    select_prob = token_selector.generator(selector_states)
    # Build a *bool* mask on tgt's own device. masked_select expects bool
    # (uint8 masks are deprecated), and the old .type(torch.ByteTensor) also
    # forced the mask onto the CPU.
    keep = torch.max(select_prob, dim=2).indices.to(tgt.device).bool()
    keep = keep & (tgt != pad)
    dae_list = [torch.masked_select(row, row_keep)[:max_tokens]
                for row_keep, row in zip(keep, tgt)]
    return torch.nn.utils.rnn.pad_sequence(dae_list, batch_first=False, padding_value=pad)

In [20]:
# Lazily re-wrap each torchtext batch with rebatch(3, ...) (3 is presumably the pad id).
rebatched = map(lambda batch: rebatch(3, batch), train_iter)
# NOTE: every next() here consumes a batch that the training loop below will not see.
next(rebatched)

Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.
Skipped overlong sample while batching.




<src.data_utils.batch.Batch at 0x14a78100fd68>

In [21]:
# Pull one poetry batch and one prose batch for inspection.
# NOTE: this consumes a batch from `rebatched`, so the training loop starts after it;
# `p` is not used in the visible cells below.
b = next(rebatched)
p = next(iter(prose_iter))

In [22]:
# Scratch check: run the token selector on the inspected poetry batch.
dae_input = get_dae_input(tgt=b.trg, token_selector=token_selector)
dae_input_mask = (dae_input != 3)  # True at non-pad positions


  


In [23]:
from src.model.adafactor import Adafactor

# Optimisers and losses for the generator (enc_dec), the token selector and both critics.
# NoamOpt presumably takes (model_size, factor, warmup_steps, optimizer) — TODO confirm.
enc_dec_opt = NoamOpt(enc_dec.src_embed[0].d_model, 1, 2000,
                        torch.optim.Adam(enc_dec.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
# Style scores are softmax probabilities in the training loop, so plain BCE applies.
style_criterion = nn.BCELoss()
# Relevance scores are raw generator outputs (no sigmoid in the training loop),
# so use the logits form; nn.BCELoss would require inputs in [0, 1].
relevance_criterion = nn.BCEWithLogitsLoss()

token_optim = Adafactor(token_selector.parameters())
style_optim = Adafactor(style_critic.parameters())
rel_optim = Adafactor(relevance_critic.parameters())

In [24]:
# The relevance critic emits one unnormalised score per example, so the loss takes logits.
relevance_criterion = nn.BCEWithLogitsLoss()

In [25]:
def get_relevance_input(dae_input, tgt, sep_token=None):
    """Concatenate selected tokens, a separator and the target along dim 1.

    Builds ``[dae_input | sep | tgt]`` for the relevance critic.

    Args:
        dae_input: (batch, n) token ids chosen by the token selector.
        tgt: (batch, m) target token ids. Its batch size, dtype and device are
            used for the separator column.
        sep_token: separator id. Defaults to the module-level ``ntokens``, one
            past the shared vocabulary (which is why relevance_critic accepts
            ntokens + 1 input ids).

    Returns:
        (batch, n + 1 + m) tensor on ``tgt``'s device.
    """
    if sep_token is None:
        sep_token = ntokens
    # Create the separator directly on tgt's device/dtype instead of building
    # it on the CPU and copying it to the global `device` on every call.
    sep = torch.full((tgt.shape[0], 1), sep_token, dtype=tgt.dtype, device=tgt.device)
    return torch.cat((dae_input, sep, tgt), dim=1)


In [42]:
# Adversarial training loop. MAX_TRAIN_STEPS bounds how many batches are used
# (1 reproduces the previous debug `break`); set it to None for a full pass.
MAX_TRAIN_STEPS = 1
soft = torch.nn.Softmax(dim=1)
for step, (poetry_batch, prose_batch, scrambled_batch) in enumerate(
        zip(rebatched, prose_iter, scrambled_iter)):
    # Reset gradients of every model touched this step.
    enc_dec_opt.optimizer.zero_grad()
    token_optim.zero_grad()
    rel_optim.zero_grad()
    style_optim.zero_grad()

    tgt, tgt_mask = poetry_batch.trg.to(device), poetry_batch.trg_mask.to(device)
    # The token selector picks up to 15 tokens per poem. get_dae_input returns
    # them time-major, so transpose to batch-first.
    dae_input = get_dae_input(poetry_batch.trg, token_selector).transpose(0, 1).to(device)
    # Source mask from the selected tokens.
    dae_input_mask = (dae_input != 3).unsqueeze(-2)
    # Generator: selected tokens -> poem (presumably teacher-forced on tgt),
    # then the likeliest token per position.
    output_embeds = enc_dec.forward(dae_input, tgt, dae_input_mask, tgt_mask)
    output = enc_dec.generator(output_embeds)
    _, output_selected = torch.max(output, 2)
    # NOTE(review): output_selected holds argmax indices, so style_loss has no
    # gradient path back into enc_dec. rel_input is built from the real tgt,
    # not the generated poem. As written, enc_dec_opt.step() receives no
    # gradient from these losses; a differentiable relaxation (e.g.
    # Gumbel-softmax) or a policy-gradient estimator is presumably needed. TODO.

    # Relevance-critic input: selected tokens, separator, target.
    rel_input = get_relevance_input(dae_input, tgt)
    # Critic scores, read from the first position's output.
    style_scores = soft(style_critic.generator(style_critic.forward(output_selected.to(device),
                                        (output_selected != 3).unsqueeze(-2).to(device)))[:, 0, :])
    relevance_scores = relevance_critic.generator(relevance_critic.forward(rel_input.to(device),
                                        (rel_input != 3).unsqueeze(-2).to(device)))[:, 0, :]

    # Generator objective: be classified as style class 0 and as relevant (1).
    style_loss = style_criterion(style_scores[:, 0], torch.ones(style_scores.shape[0]).to(device))
    # squeeze(-1) rather than squeeze() keeps the batch dim for a batch of one.
    relevance_loss = relevance_criterion(relevance_scores.squeeze(-1),
                                         torch.ones(relevance_scores.shape[0]).to(device))
    enc_dec_loss = style_loss + relevance_loss
    # This backward also accumulates gradients into relevance_critic for the
    # real (dae_input, tgt) pair with target 1; rel_optim.step() below applies
    # them together with the negative examples.
    enc_dec_loss.backward()
    enc_dec_opt.step()
    token_optim.step()

    # TODO style critic update. Classes: tgt is poetry, prose_batch is prose,
    # model output is generated, scrambled_batch is scrambled poetry.
    # NOTE(review): style_optim.step() is never called, so style_critic is not trained yet.

    # Relevance-critic negatives: pair each target with another example's
    # selected tokens (with a batch of one the permutation is the identity).
    perm = torch.randperm(dae_input.shape[0], device=dae_input.device)
    scramb_rel_input = get_relevance_input(dae_input[perm, :], tgt)
    scrambled_relevance_scores = relevance_critic.generator(relevance_critic.forward(scramb_rel_input.to(device),
                                        (scramb_rel_input != 3).unsqueeze(-2).to(device)))[:, 0, :]
    rel_crit_loss = relevance_criterion(scrambled_relevance_scores.squeeze(-1),
                                        torch.zeros(scrambled_relevance_scores.shape[0]).to(device))
    rel_crit_loss.backward()
    rel_optim.step()

    if MAX_TRAIN_STEPS is not None and step + 1 >= MAX_TRAIN_STEPS:
        break

  


tensor([ 4, 18, 11,  7,  8, 16,  9,  3, 14,  5, 17,  1,  2,  6,  0, 19, 15, 13,
        12, 10])

In [None]:
# Scratch: an (8, 19) block filled with the separator id 32000 (== ntokens),
# prototyping the separator column built by get_relevance_input.
torch.full((8, 19), 32000, dtype=torch.long)

In [33]:
# Scratch: view the toy set batch-first and shuffle its rows (examples).
# This prototypes the row shuffle used for relevance-critic negatives.
row_order = torch.randperm(toyset.shape[1])
print(toyset.transpose(0, 1)[row_order, :])

tensor([[    1,  5605, 27861,  ...,     3,     3,     3],
        [    1,  1330, 10284,  ...,     3,     3,     3],
        [    1,   779,   850,  ...,     3,     3,     3],
        ...,
        [    1,   727,    36,  ...,     3,     3,     3],
        [    1,     8, 29725,  ...,     3,     3,     3],
        [    1,   771,  5057,  ...,     3,     3,     3]])
