In [45]:
# from google.colab import drive
# drive.mount('/content/drive')

In [46]:
# !pip install youtokentome

In [47]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter

import youtokentome as yttm
import numpy as np
from navec import Navec

import random
import math
import time
import pickle
from tqdm.auto import tqdm

In [48]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# torch.use_deterministic_algorithms(True)

In [49]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

In [50]:
token_model = yttm.BPE(model='models/100k_voc20k_all_embed_yttm.model', n_threads=-1)
# token_model = yttm.BPE(model='/content/drive/MyDrive/Colab Notebooks/NLP_poems/models/100k_voc20k_all_embed_yttm.model', n_threads=-1)

In [51]:
class styleDataset(Dataset):
    def __init__(self, data_list_of_list):
        self.data = data_list_of_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def cut_length(self, limit):
        self.data = [text if len(text)<limit else text[:limit] for text in self.data]

    def distribute_data(self, share_tr, share_val):  # examle = (0.8, 0.1)
        rand_i = np.random.permutation(len(self.data))
        n1 = int(len(self.data) * share_tr)
        n2 = int(len(self.data) * share_val)
        return [self.data[i] for i in rand_i[0: n1]], \
               [self.data[i] for i in rand_i[n1: n1 + n2]], \
               [self.data[i] for i in rand_i[n1 + n2:]]
    
def padding(batch):
    pad_id = token_model.subword_to_id('<PAD>')
    
    batch = [torch.tensor(x) for x in batch]
    batch = pad_sequence(batch, batch_first=False, padding_value=pad_id)  # batch_first=True -> [batch size, len]
    return batch
    # return batch.to(device)

In [52]:
class Discriminator(nn.Module):
    def __init__(self, input_dim, emb_dim, n_filters, filter_sizes, dropout):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim)

        # in_cnannels: In actual images this is usually 3 (one channel for each of the red, blue and green channels),
        #             when using text we only have a single channel, the text itself
        # out_channels: the number of filters
        # kernel_size is the size of the filters = [n x emb_dim] where n is the size of the n-grams.

        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=1,
                      out_channels=n_filters,
                      kernel_size=(fs, emb_dim)) #fs - how token watch this filter
            for fs in filter_sizes
        ])

        self.fc = nn.Linear(len(filter_sizes) * n_filters, 2) 

        self.dropout = nn.Dropout(dropout)

    def forward(self, token_text):

        embedded = self.embedding(token_text)
        # embedded = [batch size, trg len, emb dim]

        embedded = embedded.unsqueeze(1)
        # embedded = [batch size, 1, trg len, hid dim]

        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs] # element-wise relu(x) = max(0,x)
        # conved_n = [batch size, n_filters, trg len - filter_sizes[n] + 1]

        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        # pooled_n = [batch size, n_filters]

        cat = self.dropout(torch.cat(pooled, dim=1))
        # cat = [batch size, n_filters * len(filter_sizes)]

        predicts = self.fc(cat) # [x,y] - x bigger, then this isn't this style, y - it's this style
        # predicts = [batch size, 2]
        return predicts 

In [53]:
INPUT_DIM = token_model.vocab_size() 
OUTPUT_DIM = INPUT_DIM
EMB_DIM = 128  
LABEL_DROPOUT = 0.5

N_FILTERS = 128
FILTER_SIZES = [1,3,4,5,8] #[2,3,4]

classifier = Discriminator(INPUT_DIM, EMB_DIM, N_FILTERS, FILTER_SIZES, LABEL_DROPOUT).to(device)

In [54]:
criterion = nn.CrossEntropyLoss()

In [55]:
def test(classifier, loader_p, loader_n, criterion):
    
    classifier.eval()
    
    epoch_loss = 0

    acc_poem_right = 0 # get poem and told yes
    acc_news_right = 0
    
    all_predicts_p = []
    all_predicts_n = []
    
    with torch.no_grad():
    
        for batch_p, batch_n in tqdm(zip(loader_p, loader_n), total=len(loader_p)):
            
            batch_p = batch_p.to(device, non_blocking=True)
            batch_n = batch_n.to(device, non_blocking=True)
            batch_size = batch_p.shape[1]
            
            # [x,y] - x news, y poem
            news_style = torch.zeros((batch_size), device=device).long()
            poem_style = torch.ones((batch_size), device=device).long()

            #batch = [trg len, batch size]
            batch_p = batch_p.permute(1, 0)
            batch_n = batch_n.permute(1, 0)
            #batch = [batch size, trg len]

            predicts_poem = classifier(batch_p)
            predicts_news = classifier(batch_n)
            
            all_predicts_p.extend(predicts_poem)
            all_predicts_n.extend(predicts_news)

            loss = criterion(predicts_poem, poem_style) + \
                   criterion(predicts_news, news_style)
            
            acc_news_right += (predicts_news.argmax(1) == 0).sum().item()
            acc_poem_right += (predicts_poem.argmax(1) == 1).sum().item()

            epoch_loss += loss
            

    acc_poem_right = acc_poem_right / len(test_poem)
    acc_news_right = acc_news_right / len(test_poem)

    return epoch_loss / len(loader_p), acc_poem_right, acc_news_right, \
            all_predicts_p, all_predicts_n

In [56]:
def simple_test(classifier, token_text):
    
    classifier.eval()
    
    with torch.no_grad():
        if len(token_text.shape) == 1: # no batch_size dim
            token_text = token_text.unsqueeze(0)
        token_text = token_text.to(device, non_blocking=True)
        #token_text = [batch size, trg len]

        predict = classifier(token_text)

    return predict # [x,y] - x news, y poem

### Test discr

In [None]:
checkpoint = torch.load('ST_len20_nf128_fs1,3,4,5,8_dAdam005_isnt3stt3rec0.1_clip3_discr16e_checkp_last.pt')
# checkpoint = torch.load('/content/drive/MyDrive/Colab Notebooks/NLP_poems/ST_15k_voc20k_len20_nf100_fs2,6_Lstt3_discrCyc0.1fdf1_checkp_last.pt', 
#                         map_location=torch.device('cpu'))
classifier.load_state_dict(checkpoint['discr_news_state_dict'])

In [None]:
text = "два точных удара чилийского форварда умберто суазо принесли сарагосе победу со" # real news
token_text = token_model.encode(text, output_type=yttm.OutputType.ID, bos=True, eos=True)
token_text = torch.tensor(token_text[:20])
print(simple_test(classifier, token_text))

text = "два точных удараскогоского форварда игра фато суазо победскойгосе победу со" # reconstr
token_text = token_model.encode(text, output_type=yttm.OutputType.ID, bos=True, eos=True)
token_text = torch.tensor(token_text[:20])
print(simple_test(classifier, token_text))

text = "два точных чилийского бар ум золотойто суазовецлемгосе победу со со" # trans to poem
token_text = token_model.encode(text, output_type=yttm.OutputType.ID, bos=True, eos=True)
token_text = torch.tensor(token_text[:20])
print(simple_test(classifier, token_text))

text = "два точных удараскогоского форварда | игра фато суазо победскойгосе победу со" # reconstr + |
token_text = token_model.encode(text, output_type=yttm.OutputType.ID, bos=True, eos=True)
token_text = torch.tensor(token_text[:20])
print(simple_test(classifier, token_text))

### Test CNN model

In [None]:
checkpoint = torch.load('CL_len20_emb128_drop0.5_Adam0005_noSch_halfPoemNoendl_checkp_best.pt')
# checkpoint = torch.load('/content/drive/MyDrive/Colab Notebooks/NLP_poems/ST_15k_voc20k_len20_nf100_fs2,6_Lstt3_discrCyc0.1fdf1_checkp_last.pt', 
#                         map_location=torch.device('cpu'))
classifier.load_state_dict(checkpoint['classifier_state_dict'])

In [None]:
with open('data/voc20klen20_test_poem.pickle', 'rb', ) as file:
# with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/100k_poems_preproc.pickle', 'rb', ) as file:
    test_poem = pickle.load(file)
with open('data/voc20klen20_test_news.pickle', 'rb', ) as file:
# with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/100k_news_preproc.pickle', 'rb', ) as file:
    test_news = pickle.load(file)

In [None]:
test_poem_loader = \
    DataLoader(test_poem, batch_size=128, shuffle=False, num_workers=0, pin_memory=True, collate_fn=padding)
test_news_loader = \
    DataLoader(test_news, batch_size=128, shuffle=False, num_workers=0, pin_memory=True, collate_fn=padding)

In [None]:
test_loss, test_acc_poem_right, test_acc_news_right, all_predicts_p, all_predicts_n = \
                test(classifier, test_poem_loader, test_news_loader, criterion)

print(f'\tTest Loss: {test_loss:.3f} | Acc poem: {test_acc_poem_right:.2f} | Acc news: {test_acc_news_right:.2f}')

In [None]:
PAD_ID = token_model.subword_to_id('<PAD>')
BOS_ID = token_model.subword_to_id('<BOS>')
EOS_ID = token_model.subword_to_id('<EOS>')

orig_news = token_model.decode(test_news, ignore_ids=[PAD_ID, BOS_ID, EOS_ID])
orig_poem = token_model.decode(test_poem, ignore_ids=[PAD_ID, BOS_ID, EOS_ID])

In [None]:
for i in range(0,10):
    print(orig_poem[i])
    print(all_predicts_p[i])
    print()

In [None]:
for i in range(0,10):
    print(orig_news[i])
    print(all_predicts_n[i])
    print()

In [None]:
text = "катастрофа пожар потоп но зато не прошла же мимо . я набрасываюсь да" 
token_text = token_model.encode(text, output_type=yttm.OutputType.ID, bos=True, eos=True)
token_text = torch.tensor(token_text[:20])
print(simple_test(classifier, token_text))

### Test result of style transfer model

In [76]:
checkpoint = torch.load('CL_len40_emb128_drop0.5_Adam0005_noSch_halfPoemNoendl_noSch_checkp_best.pt')
# checkpoint = torch.load('/content/drive/MyDrive/Colab Notebooks/NLP_poems/ST_15k_voc20k_len20_nf100_fs2,6_Lstt3_discrCyc0.1fdf1_checkp_last.pt', 
#                         map_location=torch.device('cpu'))
classifier.load_state_dict(checkpoint['classifier_state_dict'])

<All keys matched successfully>

In [77]:
token_model = yttm.BPE(model='models/100k_voc20k_all_embed_yttm.model', n_threads=-1)
# token_model = yttm.BPE(model='/content/drive/MyDrive/Colab Notebooks/NLP_poems/models/100k_voc20k_all_embed_yttm.model', n_threads=-1)

emb_model = Navec.load('navec_hudlit_v1_12B_500K_300d_100q.tar') 

In [78]:
with open('data/voc20klen40_test_poem.pickle', 'rb', ) as file:
# with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/100k_poems_preproc.pickle', 'rb', ) as file:
    test_poem = pickle.load(file)
with open('data/voc20klen40_test_news.pickle', 'rb', ) as file:
# with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/100k_news_preproc.pickle', 'rb', ) as file:
    test_news = pickle.load(file)
with open('data/len40_batch_gener_news-poem_recon-trans_token.pickle', 'rb', ) as file:
# with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/100k_news_preproc.pickle', 'rb', ) as file:
    batches_gener_news_recon, batches_gener_news_trans, \
    batches_gener_poem_recon, batches_gener_poem_trans = pickle.load(file)


#### style transfer eval

In [24]:
num_right_pred_poem_from_news = 0
num_all_poem_from_news = 0
for batch in batches_gener_news_trans:
    predicts_poem_from_news = simple_test(classifier, batch)
    # [batch size, 2]
    
    num_right_pred_poem_from_news += (predicts_poem_from_news.argmax(1) == 1).sum().item()
    num_all_poem_from_news += len(batch)
    
num_right_pred_news_from_poem = 0
num_all_news_from_poem = 0
for batch in batches_gener_poem_trans:
    predicts_news_from_poem = simple_test(classifier, batch)
    # [batch size, 2]
    
    num_right_pred_news_from_poem += (predicts_news_from_poem.argmax(1) == 0).sum().item()
    num_all_news_from_poem += len(batch)

In [26]:
# [x,y] - x news, y poem
acc_poem_from_news = num_right_pred_poem_from_news / num_all_poem_from_news
acc_news_from_poem = num_right_pred_news_from_poem / num_all_news_from_poem

print(acc_poem_from_news, acc_news_from_poem)

0.992 0.8966666666666666


#### context preservation eval

In [79]:
def padding_vec(texts, pad_emb):
    # [texts num, word num (different!!!), embed len]
    max_len = 0
    for text in texts:
        if len(text) > max_len:
            max_len = len(text)
            
    for i in range(len(texts)):
        pad_len = max_len - len(texts[i])
        if pad_len != 0:
            texts[i].extend([pad_emb]*pad_len)
    return texts

def unk_perc(texts):
    total_words = 0
    total_unk_words = 0
    for text in texts:
        for word in text.split(' '):
            if not word in emb_model and word != '|' and word != '.':
                total_unk_words += 1
        total_words += len(text.split(' '))
    return total_unk_words * 100 / total_words

In [81]:
length = 40
emb_len = 300
PAD_ID = token_model.subword_to_id('<PAD>')
BOS_ID = token_model.subword_to_id('<BOS>')
EOS_ID = token_model.subword_to_id('<EOS>')

total_score = 0
i = 0
for batch in batches_gener_news_trans:
    source_texts = batch.tolist()
    target_texts = test_news[i:i+len(batch)]
    i += len(batch)
    
    source_texts = token_model.decode(source_texts, ignore_ids=[PAD_ID, BOS_ID, EOS_ID])
    target_texts = token_model.decode(target_texts, ignore_ids=[PAD_ID, BOS_ID, EOS_ID])
    
#     print('gener', unk_perc(source_texts))
#     print('real', unk_perc(target_texts))
    
    embed_source_texts = [
        [emb_model[word] if word in emb_model else emb_model['<unk>'] for word in text.split(' ')] 
    for text in source_texts]
    embed_target_texts = [
        [emb_model[word] if word in emb_model else emb_model['<unk>'] for word in text.split(' ')] 
    for text in target_texts]
    # [texts num, word num (different!!!), embed len]
    
    embed_source_texts = torch.tensor(padding_vec(embed_source_texts, emb_model['<pad>'])).to(device)
    embed_target_texts = torch.tensor(padding_vec(embed_target_texts, emb_model['<pad>'])).to(device)
    # [texts num, word num, embed len]
    
#     target_texts = [torch.tensor(x) for x in target_texts]
#     target_texts = pad_sequence(target_texts, batch_first=True, padding_value=PAD_ID).to(device)
#     embed_source_texts = embedding(source_texts) 
#     embed_target_texts = embedding(target_texts)
    
    for embed_src_text, embed_trg_text in zip(embed_source_texts, embed_target_texts):
        # [word num, embed len]

        # [min, mean, max] = sentence emb-g
        v_s = torch.zeros(emb_len*3)
        v_s[:emb_len] = torch.min(embed_src_text, dim=0)[0]
        v_s[emb_len:emb_len*2] = torch.mean(embed_src_text, dim=0)[0]
        v_s[emb_len*2:] = torch.max(embed_src_text, dim=0)[0]

        v_t = torch.zeros(emb_len*3)
        v_t[:emb_len] = torch.min(embed_trg_text, dim=0)[0]
        v_t[emb_len:emb_len*2] = torch.mean(embed_trg_text, dim=0)[0]
        v_t[emb_len*2:] = torch.max(embed_trg_text, dim=0)[0]
        
#         v_s_min = torch.min(embed_src_text)
#         v_s_mean = torch.mean(embed_src_text)
#         v_s_max = torch.max(embed_src_text)

#         v_t_min = torch.min(embed_trg_text)
#         v_t_mean = torch.mean(embed_trg_text)
#         v_t_max = torch.max(embed_trg_text)

#         v_s = torch.tensor([v_s_min, v_s_mean, v_s_max])
#         v_t = torch.tensor([v_t_min, v_t_mean, v_t_max])

        # cosine distance
        score = torch.matmul(v_s.t(), v_t) / (torch.linalg.norm(v_s, ord=2) * torch.linalg.norm(v_t, ord=2))
        # if source = target, then score = 1.0, else score < 1.0
        total_score += score
#         total_score += (1.0 - score)

gener 12.40079365079365
real 3.250380904012189
gener 12.487611496531219
real 3.3265097236438077
gener 12.118226600985222
real 3.117144293614882
gener 11.578426521523998
real 3.1802120141342756
gener 11.773255813953488
real 3.1265508684863526
gener 11.024390243902438
real 2.795806290564154
gener 12.537612838515546
real 4.130545639979602
gener 12.231652521218173
real 3.734015345268542
gener 12.256128064032016
real 3.8987341772151898
gener 11.01407083939835
real 2.897102897102897
gener 12.0
real 3.9245667686034658
gener 13.30704780680138
real 3.6280020439448135
gener 11.505752876438219
real 3.0980192991366176
gener 10.61124694376528
real 3.0784508440913605
gener 11.72106824925816
real 3.3248081841432224
gener 12.285012285012286
real 2.6986506746626686
gener 12.64253842340109
real 3.6924633282751644
gener 13.041314086610253
real 3.5222052067381315


KeyboardInterrupt: 

In [44]:
print(f'{total_score.item() / i:.3f}')

0.941


In [82]:
# embed_src_text = torch.tensor([[1.,2.,3.], [2.,3.,0.], [3.,9.,5.], [0.,0.,0.]])
# embed_trg_text = torch.tensor([[1.,2.,3.], [5.,6.,7.], [3.,4.,5.]])

source_texts = ["В Москве открылась новая выставка военной техники", 
                "Был арестован интернет мошшеник, который успел украсть долларов"]
target_texts = ["Теперь я знаю, в вашей воле, меня презреньем наказать", 
                "Октябрь уж наступил уж роща отряхает Последние листы с нагих своих ветвей"]

print(unk_perc(source_texts))
print(unk_perc(target_texts))

embed_src_text = [
    [emb_model[word] if word in emb_model else emb_model['<unk>'] for word in text.split(' ')] 
for text in source_texts]
embed_trg_text = [
    [emb_model[word] if word in emb_model else emb_model['<unk>'] for word in text.split(' ')] 
for text in target_texts]

embed_src_text = torch.tensor(padding_vec(embed_src_text, emb_model['<pad>'])).to(device)
embed_trg_text = torch.tensor(padding_vec(embed_trg_text, emb_model['<pad>'])).to(device)

embed_src_text = embed_src_text[1]
embed_trg_text = embed_trg_text[1]

emb_len = 300

v_s = torch.zeros(emb_len*3)
v_s[:emb_len] = torch.min(embed_src_text, dim=0)[0]
v_s[emb_len:emb_len*2] = torch.mean(embed_src_text, dim=0)[0]
v_s[emb_len*2:] = torch.max(embed_src_text, dim=0)[0]

v_t = torch.zeros(emb_len*3)
v_t[:emb_len] = torch.min(embed_trg_text, dim=0)[0]
v_t[emb_len:emb_len*2] = torch.mean(embed_trg_text, dim=0)[0]
v_t[emb_len*2:] = torch.max(embed_trg_text, dim=0)[0]

score = torch.matmul(v_s.t(), v_t) / (torch.linalg.norm(v_s, ord=2) * torch.linalg.norm(v_t, ord=2))
score

26.666666666666668
28.571428571428573


tensor(0.8080)