In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
# !pip install youtokentome

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter

import youtokentome as yttm
import numpy as np

import random
import math
import time
import pickle
from tqdm.auto import tqdm

In [4]:
SEED = 1234

# Seed every RNG the notebook relies on (Python, NumPy, torch CPU, torch CUDA)
# so that runs are repeatable.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(SEED)
del _seed_rng
# Ask cuDNN for deterministic kernels.
torch.backends.cudnn.deterministic = True

In [5]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
# BPE tokenizer used for padding ids, vocabulary size and decoding below.
# On Colab the file lives under '/content/drive/MyDrive/Colab Notebooks/NLP_poems/models/'.
TOKENIZER_MODEL_PATH = 'models/100k_voc20k_all_embed_yttm.model'
token_model = yttm.BPE(model=TOKENIZER_MODEL_PATH, n_threads=-1)

In [7]:
class styleDataset(Dataset):
    """Map-style dataset over an in-memory list of tokenised texts."""

    def __init__(self, data_list_of_list):
        # each item is one text; presumably a list of token ids (see `padding`)
        self.data = data_list_of_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def cut_length(self, limit):
        """Truncate every text to at most ``limit`` items; replaces ``self.data``."""
        truncated = []
        for seq in self.data:
            truncated.append(seq[:limit] if len(seq) >= limit else seq)
        self.data = truncated

    def distribute_data(self, share_tr, share_val):
        """Randomly split the data into train / validation / test lists.

        Args:
            share_tr: fraction of items for the train split, e.g. 0.8.
            share_val: fraction of items for the validation split, e.g. 0.1.

        Returns:
            (train, valid, test) lists. Train and validation sizes are truncated
            with ``int()``; all remaining items form the test split. The shuffle
            uses the global NumPy RNG.
        """
        total = len(self.data)
        order = np.random.permutation(total)
        train_end = int(total * share_tr)
        val_end = train_end + int(total * share_val)

        def pick(indices):
            return [self.data[k] for k in indices]

        return pick(order[:train_end]), pick(order[train_end:val_end]), pick(order[val_end:])
    
def padding(batch):
    """DataLoader collate function: pad id sequences into one time-major tensor.

    Args:
        batch: list of token-id sequences of possibly different lengths.

    Returns:
        LongTensor [max len, batch size]; shorter sequences are filled with the
        tokenizer's ``<PAD>`` id.
    """
    fill_value = token_model.subword_to_id('<PAD>')
    tensors = list(map(torch.tensor, batch))
    # pad_sequence defaults to batch_first=False -> [len, batch size]
    return pad_sequence(tensors, padding_value=fill_value)

In [8]:
class Encoder(nn.Module):
    """GRU encoder whose initial hidden state carries a learned style embedding.

    The GRU hidden size is ``style_dim + hid_dim``. The first ``style_dim`` units
    of the initial state come from the style embedding and the remaining
    ``hid_dim`` units start at zero. The style slice is removed from the tensors
    that are returned.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, style_dim, n_layers, dropout):
        """
        Args:
            input_dim: source vocabulary size.
            emb_dim: token embedding size.
            hid_dim: content part of the GRU hidden state.
            style_dim: style part of the GRU hidden state.
            n_layers: number of stacked GRU layers.
            dropout: dropout probability applied to the token embeddings.
        """
        super().__init__()

        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.style_dim = style_dim

        self.embedding = nn.Embedding(input_dim, emb_dim)  # (vocab size, embedding size)
        self.style_embedding = nn.Embedding(2, style_dim)  # two styles: 0 = news, 1 = poem

        self.rnn = nn.GRU(emb_dim, hid_dim + style_dim, n_layers)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_style):
        """Encode a batch whose texts all share one style.

        Args:
            src: LongTensor [src len, batch size] of token ids.
            src_style: int style id (0 or 1) used for the whole batch.

        Returns:
            outputs: [src len, batch size, hid dim], top-layer states with the style part removed.
            hidden:  [n layers, batch size, hid dim], final states with the style part
                     removed; used as the decoder's first hidden state.
        """
        # Follow the input's device instead of a notebook-global `device`.
        dev = src.device

        embedded = self.dropout(self.embedding(src))
        # embedded = [src len, batch size, emb dim]

        style_token = torch.tensor([src_style], device=dev)
        init_hidden = torch.cat((self.style_embedding(style_token),
                                 torch.zeros((1, self.hid_dim), device=dev)), dim=1)
        # [1, style dim + hid dim] -> same initial state for every layer and sequence
        init_hidden = init_hidden.repeat(self.n_layers, src.shape[1], 1)
        # init_hidden = [n layers, batch size, style dim + hid dim]

        outputs, hidden = self.rnn(embedded, init_hidden)
        # outputs = [src len, batch size, style dim + hid dim]
        # hidden  = [n layers, batch size, style dim + hid dim]

        # drop the leading style part
        outputs = outputs[:, :, self.style_dim:]
        hidden = hidden[:, :, self.style_dim:]
        # outputs = [src len, batch size, hid dim]
        # hidden  = [n layers, batch size, hid dim]
        return outputs, hidden

In [9]:
# style emb
class Attention(nn.Module):
    """Additive attention whose query is the decoder state including its style part."""

    def __init__(self, hid_dim, style_dim):
        """
        Args:
            hid_dim: size of the encoder outputs (content part of the states).
            style_dim: size of the style part carried by the decoder state.
        """
        super().__init__()

        self.style_dim = style_dim
        # input = query (style dim + hid dim) concatenated with key (hid dim)
        self.attn = nn.Linear(hid_dim * 2 + style_dim, hid_dim)
        self.v = nn.Linear(hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Compute attention weights over source positions.

        Args:
            hidden: [batch size, style dim + hid dim] decoder state (top layer).
            encoder_outputs: [src len, batch size, hid dim].

        Returns:
            [batch size, src len] weights that sum to 1 over src len.
        """
        src_len = encoder_outputs.shape[0]

        keys = encoder_outputs.transpose(0, 1)
        # keys = [batch size, src len, hid dim]
        queries = hidden.unsqueeze(1).expand(-1, src_len, -1)
        # queries = [batch size, src len, style dim + hid dim]

        energy = torch.tanh(self.attn(torch.cat([queries, keys], dim=2)))
        # energy = [batch size, src len, hid dim]
        scores = self.v(energy).squeeze(-1)
        # scores = [batch size, src len]
        return torch.softmax(scores, dim=1)

In [10]:
class Decoder(nn.Module):
    """One-step GRU decoder with attention and a style embedding in its hidden state.

    The hidden state is laid out as ``[style part | content part]``. That matches
    the encoder, which keeps the style in the first ``style_dim`` units.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, style_dim, n_layers, dropout, attention):
        """
        Args:
            output_dim: target vocabulary size.
            emb_dim: token embedding size.
            hid_dim: content part of the GRU hidden state.
            style_dim: style part of the GRU hidden state.
            n_layers: number of stacked GRU layers.
            dropout: dropout probability applied to the token embeddings.
            attention: module mapping (top hidden, encoder outputs) -> [batch size, src len] weights.
        """
        super().__init__()

        self.output_dim = output_dim  # vocabulary size
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.style_dim = style_dim

        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.style_embedding = nn.Embedding(2, style_dim)  # 0 = news, 1 = poem

        # input = token embedding + attention-weighted encoder output
        self.rnn = nn.GRU(hid_dim + emb_dim, hid_dim + style_dim, n_layers)

        # input = GRU output (style + hid) + weighted context (hid) + embedding
        self.fc_out = nn.Linear(hid_dim * 2 + emb_dim + style_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs, trg_style):
        """Decode one time step.

        Args:
            input: LongTensor [batch size] with the previous tokens.
            hidden: [n layers, batch size, hid dim] on the first step (the encoder state),
                    [n layers, batch size, style dim + hid dim] on later steps.
            encoder_outputs: [src len, batch size, hid dim].
            trg_style: int target style (0 or 1) on the first step, so that the style
                       embedding is prepended to ``hidden``; -1 on later steps, when
                       ``hidden`` already holds the style part.

        Returns:
            prediction: [batch size, output dim] logits for the next token.
            hidden: [n layers, batch size, style dim + hid dim] new hidden state.
        """
        input = input.unsqueeze(0)
        # input = [1, batch size]
        embedded = self.dropout(self.embedding(input))
        # embedded = [1, batch size, emb dim]

        if trg_style != -1:
            # First step: prepend the target-style embedding to every layer's state.
            # Use the state's own device rather than a notebook-global `device`.
            style_token = torch.tensor([trg_style], device=hidden.device)
            h_style = self.style_embedding(style_token).repeat(self.n_layers, hidden.shape[1], 1)
            # h_style = [n layers, batch size, style dim]
            hidden = torch.cat((h_style, hidden), dim=2)

        last_hidden = hidden[-1]
        # last_hidden = [batch size, style dim + hid dim]
        a = self.attention(last_hidden, encoder_outputs).unsqueeze(1)
        # a = [batch size, 1, src len]

        weighted = torch.bmm(a, encoder_outputs.permute(1, 0, 2))
        # weighted = [batch size, 1, hid dim]
        weighted = weighted.permute(1, 0, 2)
        # weighted = [1, batch size, hid dim]

        rnn_input = torch.cat((embedded, weighted), dim=2)
        # rnn_input = [1, batch size, hid dim + emb dim]

        output, hidden = self.rnn(rnn_input, hidden)
        # output = [1, batch size, style dim + hid dim]
        # hidden = [n layers, batch size, style dim + hid dim]

        prediction = self.fc_out(torch.cat((output.squeeze(0),
                                            weighted.squeeze(0),
                                            embedded.squeeze(0)), dim=1))
        # prediction = [batch size, output dim]
        return prediction, hidden

In [11]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time."""

    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, src_style, trg_style, tf_ratio):
        """Run the full decoding loop.

        Args:
            src, trg: LongTensor [len, batch size]; ``trg[0]`` holds the start tokens.
            src_style: style id passed to the encoder.
            trg_style: style id injected into the decoder state on the first step only.
            tf_ratio: per-step probability of feeding the ground-truth token
                      (teacher forcing) instead of the model's argmax.

        Returns:
            [trg len, batch size, output dim] logits. Row 0 is never predicted and stays zero.
        """
        trg_len, batch_size = trg.shape[0], trg.shape[1]
        predictions = torch.zeros(trg_len, batch_size, self.decoder.output_dim, device=self.device)

        enc_outputs, hidden = self.encoder(src, src_style)

        step_input = trg[0, :]  # start tokens
        style = trg_style       # only the first decoder step injects the style
        for t in range(1, trg_len):
            prediction, hidden = self.decoder(step_input, hidden, enc_outputs, style)
            style = -1
            predictions[t] = prediction

            use_teacher = random.random() < tf_ratio
            step_input = trg[t] if use_teacher else prediction.argmax(1)

        return predictions

In [12]:
class Discriminator(nn.Module):
    """CNN style classifier over the generator's per-token output distributions.

    Logits are turned into probabilities with softmax and projected to ``emb_dim``
    by a linear layer (a soft embedding, presumably so gradients can reach the
    generator). Several n-gram convolutions follow, then max-over-time pooling
    and a 2-way linear head.
    """

    def __init__(self, output_dim, emb_dim, n_filters, filter_sizes, dropout):
        """
        Args:
            output_dim: vocabulary size (last dim of the generator logits).
            emb_dim: size of the soft token embedding.
            n_filters: number of convolution filters per n-gram width.
            filter_sizes: n-gram widths, one Conv2d per width.
            dropout: dropout probability before the output layer.
        """
        super().__init__()

        self.emb_squeeze = nn.Linear(output_dim, emb_dim)

        # in_channels=1: text is a single "channel" (images would have 3 RGB channels).
        # out_channels: number of filters.
        # kernel_size = [n x emb_dim], where n is the n-gram width.
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=1,
                      out_channels=n_filters,
                      kernel_size=(fs, emb_dim))
            for fs in filter_sizes
        ])

        self.fc = nn.Linear(len(filter_sizes) * n_filters, 2)

        self.dropout = nn.Dropout(dropout)

        # Every convolution needs at least this many time steps.
        self.min_len = max(filter_sizes, default=1)

    def forward(self, gener_predicts):
        """Score a batch of generated sequences.

        Args:
            gener_predicts: [batch size, trg len, output dim] logits.

        Returns:
            [batch size, 2] scores: index 1 means "this is my style", index 0 means "it is not".
        """
        probs = F.softmax(gener_predicts, dim=-1)
        embedded = self.emb_squeeze(probs)
        # embedded = [batch size, trg len, emb dim]

        # Conv2d raises if the sequence is shorter than the widest kernel, so
        # right-pad such inputs with zero vectors along the time axis.
        missing = self.min_len - embedded.shape[1]
        if missing > 0:
            embedded = F.pad(embedded, (0, 0, 0, missing))

        embedded = embedded.unsqueeze(1)
        # embedded = [batch size, 1, trg len, emb dim]

        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
        # conved_n = [batch size, n_filters, trg len - filter_sizes[n] + 1]

        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        # pooled_n = [batch size, n_filters]

        cat = self.dropout(torch.cat(pooled, dim=1))
        # cat = [batch size, n_filters * len(filter_sizes)]

        return self.fc(cat)

In [13]:
# --- sequence model hyper-parameters ---
INPUT_DIM = token_model.vocab_size()
OUTPUT_DIM = INPUT_DIM                 # the same BPE vocabulary on both sides
ENC_EMB_DIM, DEC_EMB_DIM = 128, 128
STYLE_EMB_DIM = 256
HID_DIM = 256
N_LAYERS = 2
ENC_DROPOUT, DEC_DROPOUT = 0.2, 0.2
LABEL_DROPOUT = 0.3                    # dropout inside the discriminators

# --- discriminator CNN ---
N_FILTERS = 128
FILTER_SIZES = [1, 3, 4, 5, 8]         # n-gram widths

# The build order below matches the original so that parameter initialisation
# consumes the torch RNG in the same sequence.
attn = Attention(HID_DIM, STYLE_EMB_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, STYLE_EMB_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, STYLE_EMB_DIM, N_LAYERS, DEC_DROPOUT, attn)
model = Seq2Seq(enc, dec, device).to(device)

# Two discriminators with identical architecture (poem first, then news);
# their soft-embedding size reuses HID_DIM.
discr_poem, discr_news = (
    Discriminator(OUTPUT_DIM, HID_DIM, N_FILTERS, FILTER_SIZES, LABEL_DROPOUT).to(device)
    for _ in range(2)
)

In [14]:
PAD_ID = token_model.subword_to_id('<PAD>')

# The reconstruction loss skips <PAD> positions; the discriminator loss is plain 2-class CE.
gener_criterion, discr_criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID), nn.CrossEntropyLoss()

In [None]:
def test_get_gener(batch, src_style):
    batch = batch.to(device, non_blocking=True)
    src, trg = batch, batch
    #batch[src len, batch size]

    predict_same = model(src, trg, src_style, src_style, 0)      # FREE-RUNNING
    predict_fake = model(src, trg, src_style, 1 - src_style, 0)  # FREE-RUNNING
    #predict_same = [trg len, batch size, output dim]

    #trg = [trg len, batch size]
    #predict = [trg len, batch size, output dim]
    predict_dim = predict_same.shape[-1]
    predict_same_for_loss = predict_same[1:].view(-1, predict_dim)
    trg = trg[1:].view(-1)
    #trg = [(trg len - 1) * batch size]
    #predict_same_for_loss = [(trg len - 1) * batch size, output dim]
    
    predict_same = predict_same.permute(1, 0, 2)
    predict_fake = predict_fake.permute(1, 0, 2)
    #predict_same = [batch size, trg len, output dim]

    return predict_same, predict_fake, predict_same_for_loss, trg, batch.shape[1]


def test(model, discr_poem, discr_news, loader_p, loader_n, gener_criterion, discr_criterion):
    
    model.eval()
    discr_poem.eval()
    discr_news.eval()
    
    epoch_loss = 0

    acc_discr_poem_rec_right = 0 # poem discr get recon poem and told yes
    acc_discr_news_rec_right = 0
    acc_discr_poem_stt_is_st = 0 # poem discr get fake poem from news and told yes
    acc_discr_news_stt_is_st = 0
    
    out_list_poem_recon = []
    out_list_poem_trans = []
    out_list_news_recon = []
    out_list_news_trans = []
    
    with torch.no_grad():
    
        for batch_p, batch_n in tqdm(zip(loader_p, loader_n), total=len(loader_p)):
            # 1 epoch = 2 batchs, 1 with poem and news

            p_predict_same, p_predict_fake, p_predict_same_loss, p_trg, batch_size = valid_get_gener(batch_p, 1)
            n_predict_same, n_predict_fake, n_predict_same_loss, n_trg, _ = valid_get_gener(batch_n, 0)

            ### save gener texts ###
            p_predict_same_save = p_predict_same.argmax(2)
            p_predict_fake_save = p_predict_fake.argmax(2)
            n_predict_same_save = n_predict_same.argmax(2)
            n_predict_fake_save = n_predict_fake.argmax(2)
            #out_tokens_same = [trg len, batch size]

            p_predict_same_save = p_predict_same.permute(1, 0)
            p_predict_fake_save = p_predict_fake.permute(1, 0)
            n_predict_same_save = n_predict_same.permute(1, 0)
            n_predict_fake_save = n_predict_fake.permute(1, 0)
            #out_tokens_same = [batch size, trg len]

            out_list_poem_recon.append(p_predict_same_save)
            out_list_poem_trans.append(p_predict_fake_save)
            out_list_news_recon.append(n_predict_same_save)
            out_list_news_trans.append(n_predict_fake_save)

            # [x,y] - x bigger, then this isn't this style, y - it's this style
            is_style = torch.ones((batch_size), device=device).long()
            isnt_style = torch.zeros((batch_size), device=device).long()

            ### generator ###
            loss_reconstr = gener_criterion(p_predict_same_loss, p_trg) + \
                            gener_criterion(n_predict_same_loss, n_trg)

            discr_poem_out_fake = discr_poem(n_predict_fake)
            discr_news_out_fake = discr_news(p_predict_fake)

            loss_style_trans = discr_criterion(discr_poem_out_fake, is_style) + \
                               discr_criterion(discr_news_out_fake, is_style)

            ### discriminators ###
            discr_poem_out_same = discr_poem(p_predict_same.detach())
            discr_news_out_same = discr_news(n_predict_same.detach())

            loss_discr_poem = discr_criterion(discr_poem_out_same, is_style) + \
                              discr_criterion(discr_poem_out_fake, isnt_style)
            loss_discr_news = discr_criterion(discr_news_out_same, is_style) + \
                              discr_criterion(discr_news_out_fake, isnt_style)

            ### accuracy ###
            acc_discr_poem_rec_right += (discr_poem_out_same.argmax(1) == 1).sum().item()
            acc_discr_news_rec_right += (discr_news_out_same.argmax(1) == 1).sum().item()

            acc_discr_poem_stt_is_st += (discr_poem_out_fake.argmax(1) == 1).sum().item()
            acc_discr_news_stt_is_st += (discr_news_out_fake.argmax(1) == 1).sum().item()

            ### save epoch loss ###
            epoch_loss += (loss_reconstr + loss_style_trans + loss_discr_poem + loss_discr_news)

    acc_discr_poem_rec = acc_discr_poem_rec_right / len(valid_poem)
    acc_discr_news_rec = acc_discr_news_rec_right / len(valid_poem)

    acc_discr_poem_stt = acc_discr_poem_stt_is_st / len(valid_poem)
    acc_discr_news_stt = acc_discr_news_stt_is_st / len(valid_poem)

    return epoch_loss / len(loader_p) \
           acc_discr_poem_rec, acc_discr_news_rec, acc_discr_poem_stt, acc_discr_news_stt, \
           out_list_news_recon, out_list_news_trans, out_list_poem_recon, out_list_poem_trans

In [15]:
# def test(model, discr_poem, discr_news, loader_p, loader_n, gener_criterion, discr_criterion):
    
#     model.eval()
    
#     epoch_loss = 0
#     epoch_loss_gen_rec = 0
#     epoch_loss_gen_stt = 0
#     epoch_loss_discrs = 0

#     acc_discr_poem_rec_right = 0 # poem discr get recon poem and told yes
#     acc_discr_news_rec_right = 0
#     acc_discr_poem_stt_is_st = 0 # poem discr get fake poem from news and told yes
#     acc_discr_news_stt_is_st = 0

#     out_list_news_recon = []
#     out_list_news_trans = []
#     out_list_poem_recon = []
#     out_list_poem_trans = []
    
#     with torch.no_grad():
    
#         for batch_p, batch_n in tqdm(zip(loader_p, loader_n), total=len(loader_p)):
#             for batch, src_style in (batch_p, batch_n):
#                 batch = batch.to(device, non_blocking=True)
#                 src = batch
#                 trg = batch

#                 predict_same, out_tokens_same = model(src, trg, src_style, src_style, 0) #turn off teacher forcing
#                 predict_fake, out_tokens_fake = model(src, trg, src_style, 1 - src_style, 0) 
#                 #predict_same = [trg len, batch size, output dim]
#                 #out_tokens_same = [trg len, batch size, hid dim]
                
#                 out_tokens_same_old = predict_same.argmax(2)
#                 out_tokens_fake_old = predict_fake.argmax(2)
#                 #out_tokens_same = [trg len, batch size]
            
#                 out_tokens_same_old = out_tokens_same_old.permute(1, 0) 
#                 out_tokens_fake_old = out_tokens_fake_old.permute(1, 0)
#                 #out_tokens_same = [batch size, trg len]

#                 if src_style == 0: # news
#                   out_list_news_recon.append(out_tokens_same_old)
#                   out_list_news_trans.append(out_tokens_fake_old)
#                 if src_style == 1: # poem
#                   out_list_poem_recon.append(out_tokens_same_old)
#                   out_list_poem_trans.append(out_tokens_fake_old)
                
#                 predict_same_for_loss = predict_same
#                 out_tokens_same = predict_same.permute(1, 0, 2)
#                 out_tokens_fake = predict_fake.permute(1, 0, 2)
#                 #predict_same = [batch size, trg len, output dim]

#                 is_style = torch.ones((batch.shape[1]), device=device).long()
#                 isnt_style = torch.zeros((batch.shape[1]), device=device).long()

#                 if src_style == 0: # news
#                   discr_news_out_same = discr_news(out_tokens_same)
#                   discr_poem_out_fake = discr_poem(out_tokens_fake)

#                   loss_discr_news = discr_criterion(discr_news_out_same, is_style)
#                   loss_discr_poem = discr_criterion(discr_poem_out_fake, isnt_style) # find fault in style transfer 

#                   loss_style_trans = discr_criterion(discr_poem_out_fake, is_style) # find success in style transfer

#                   acc_discr_news_rec_right += (discr_news_out_same.argmax(1) == 1).sum().item()
#                   acc_discr_poem_stt_is_st += (discr_poem_out_fake.argmax(1) == 1).sum().item()

#                 if src_style == 1: # poem
#                   discr_news_out_fake = discr_news(out_tokens_fake)
#                   discr_poem_out_same = discr_poem(out_tokens_same)

#                   loss_discr_news = discr_criterion(discr_news_out_fake, isnt_style)
#                   loss_discr_poem = discr_criterion(discr_poem_out_same, is_style) 

#                   loss_style_trans = discr_criterion(discr_news_out_fake, is_style)

#                   acc_discr_poem_rec_right += (discr_poem_out_same.argmax(1) == 1).sum().item()
#                   acc_discr_news_stt_is_st += (discr_news_out_fake.argmax(1) == 1).sum().item()
                
#                 #trg = [trg len, batch size]
#                 #predict = [trg len, batch size, output dim]
#                 predict_dim = predict_same.shape[-1]
#                 predict_same_for_loss = predict_same_for_loss[1:].view(-1, predict_dim)
#                 trg = trg[1:].view(-1)
#                 #trg = [(trg len - 1) * batch size]
#                 #predict_same = [(trg len - 1) * batch size, output dim]
                
#                 loss_reconstr = gener_criterion(predict_same_for_loss, trg)

#                 loss_generator = loss_reconstr + loss_style_trans
#                 loss_discr = loss_discr_poem + loss_discr_news

#                 loss = loss_generator + loss_discr

#                 epoch_loss += loss
#                 epoch_loss_gen_rec += loss_reconstr
#                 epoch_loss_gen_stt += loss_style_trans
#                 epoch_loss_discrs += loss_discr

#     acc_discr_poem_rec = acc_discr_poem_rec_right / len(test_poem)
#     acc_discr_news_rec = acc_discr_news_rec_right / len(test_poem)

#     acc_discr_poem_stt = acc_discr_poem_stt_is_st / len(test_poem)
#     acc_discr_news_stt = acc_discr_news_stt_is_st / len(test_poem)
        
#     return epoch_loss / len(loader_p), epoch_loss_gen_rec / len(loader_p), \
#            epoch_loss_gen_stt / len(loader_p), epoch_loss_discrs / len(loader_p), \
#            acc_discr_poem_rec, acc_discr_news_rec, \
#            acc_discr_poem_stt, acc_discr_news_stt, \
#            out_list_news_recon, out_list_news_trans, out_list_poem_recon, out_list_poem_trans

In [16]:
CHECKPOINT_PATH = 'ST_len40_dAdam005_genCyclefdf10_isnt1stt1-af10e0-Copy1.25rec1_discr14e_checkp_last.pt'
# Colab: prefix with '/content/drive/MyDrive/Colab Notebooks/NLP_poems/'.
# map_location remaps stored tensors onto the active device, so a checkpoint
# saved on a GPU also loads on a CPU-only kernel.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
discr_poem.load_state_dict(checkpoint['discr_poem_state_dict'])
discr_news.load_state_dict(checkpoint['discr_news_state_dict'])

<All keys matched successfully>

In [17]:
def _load_pickled_split(path):
    """Load one pickled split. pickle can execute arbitrary code: only load trusted files."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _make_test_loader(data):
    """Ordered (unshuffled) loader that pads each batch with `padding`."""
    return DataLoader(data, batch_size=64, shuffle=False, num_workers=0,
                      pin_memory=True, collate_fn=padding)


# Colab variants of these files live under
# '/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/'.
test_poem = _load_pickled_split('data/voc20klen40_test_poem.pickle')
test_poem_loader = _make_test_loader(test_poem)

test_news = _load_pickled_split('data/voc20klen40_test_news.pickle')
test_news_loader = _make_test_loader(test_news)

In [18]:
test_results = test(model, discr_poem, discr_news,
                    test_poem_loader, test_news_loader,
                    gener_criterion, discr_criterion)
(test_loss, test_loss_gen_rec, test_loss_gen_stt, test_loss_discrs,
 test_acc_recon_poem, test_acc_recon_news, test_acc_discr_poem_stt, test_acc_discr_news_stt,
 test_news_recon, test_news_trans, test_poem_recon, test_poem_trans) = test_results

print(f'\t Test Loss: {test_loss:.3f} | Acc recon poem: {test_acc_recon_poem:.2f} | Acc recon news: {test_acc_recon_news:.2f} | '
      f'Acc d_poem stt: {test_acc_discr_poem_stt:.2f} | Acc d_news stt: {test_acc_discr_news_stt:.2f}')

  0%|          | 0/24 [00:00<?, ?it/s]

	 Test Loss: 8.662 | Acc recon poem: 0.80 | Acc recon news: 0.87 | Acc d_poem stt: 0.04 | Acc d_news stt: 0.03


#### Decode the generated test outputs and compare them with the original test texts

In [19]:
# with open('data/len40_batch_gener_news-poem_recon-trans_token.pickle', 'wb', ) as file:
# # with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/voc20klen20_test_news.pickle', 'wb', ) as file:
#     pickle.dump((test_news_recon, test_news_trans, test_poem_recon, test_poem_trans), file)

In [19]:
# Flatten the per-batch outputs ([num batches][batch size, trg len]) into flat
# per-sample lists, keeping batch order.
gener_news_recon, gener_news_trans, gener_poem_recon, gener_poem_trans = [], [], [], []
flat_and_batched = (
    (gener_news_recon, test_news_recon),
    (gener_news_trans, test_news_trans),
    (gener_poem_recon, test_poem_recon),
    (gener_poem_trans, test_poem_trans),
)
for batch_idx in range(len(test_news_recon)):
    for flat, batched in flat_and_batched:
        flat.extend(batched[batch_idx])

# with open('data/len40_gener_news-poem_recon-trans_token.pickle', 'wb', ) as file:
# # with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/voc20klen20_test_news.pickle', 'wb', ) as file:
#     pickle.dump((gener_news_recon, gener_news_trans, gener_poem_recon, gener_poem_trans), file)

In [20]:
PAD_ID, BOS_ID, EOS_ID = (token_model.subword_to_id(tok) for tok in ('<PAD>', '<BOS>', '<EOS>'))


def _decode_ids(sequences):
    """Decode token-id sequences to text, skipping <PAD>, <BOS> and <EOS>."""
    return token_model.decode(sequences, ignore_ids=[PAD_ID, BOS_ID, EOS_ID])


gener_news_recon, gener_news_trans, gener_poem_recon, gener_poem_trans = (
    _decode_ids(seqs)
    for seqs in (gener_news_recon, gener_news_trans, gener_poem_recon, gener_poem_trans)
)

orig_news_recon = _decode_ids(test_news)
orig_poem_recon = _decode_ids(test_poem)

In [21]:
# Side by side: original news text, its reconstruction, its news->poem transfer.
for i in range(20, 30):
    for line in (orig_news_recon[i], gener_news_recon[i], gener_news_trans[i]):
        print(line)
    print()

теракт произошел в субботу в индийском городе пуна погибли как минимум восемь человек сообщает агентство рейтер . взрывное устройство сработало в популярной закусочной когда внутри находилось много туристов и иностранцев . теракт на западе
теракт произошел в субботу в индийском городе пуна погибли как минимум восемь человек сообщает агентство рейтер . взрывное взрывное в в внойной побесочной когда внутри много много и власти . . на западе
лишь знаешь в меня | индийском городе пуна | как минимум восемь человек там . | все взрывное глядя в поэтной закутелчной когда когда внутри много много и и . . боль боль

истребитель су ввс россии пропал с экранов радаров в хабаровском крае сообщил риа новости источник в дальневосточном военном округе . истребитель су ввс рф пропал с экранов радаров в хабаровском крае истреб
истребитель су ввс россии пропал с экранов радаров в хабаровском крае сообщил риа новости источник в дальневосточном военном округе . истребитель су ввс рф рф с экранов радаров в 

In [24]:
# Side by side: original poem, its reconstruction, its poem->news transfer.
for i in range(30, 40):
    for line in (orig_poem_recon[i], gener_poem_recon[i], gener_poem_trans[i]):
        print(line)
    print()

мама разрешила свете | взять из вазочки конфеты | и сказала обещай мне | столько взять чтоб угощая | никого не обойти . | всех конфеткой угости . | света
мама разрешила свете | взять из вазочки конфеты | и сказала обещай мне столько столько взять чтоб угощая | никого не спокой . | всех конфеткой угости . | света
мама разрешила корпорация | взять из вазочки конфеты среду и сказала обещайдара скончался столько взять чтоб одного угощая медведев призывает полтора сша . всех всехеткойции угости . света

теперь не умирают от любви | насмешливая трезвая эпоха . | лишь падает гемоглобин в крови | лишь без причины человеку плохо . | с юлия друнина | добрый
теперь не умирают от любви | насмешливая трезвая эпоха . | лишь падает гглогло | в крови | лишь без проблем плохо . | с с дру друнина | добрый
теперь службы умирают от любви главный наставникшливая трезвая бригадыха стоимость сша лишь сша г лидерглобин в четверг в четверг без могут призывает власти . с с юлия нынешнийнина виктор добрый

вот о