In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

# !pip install youtokentome

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter

import youtokentome as yttm
import numpy as np

import random
import math
import time
import pickle
from tqdm.auto import tqdm

In [3]:
# Seed every random number generator this notebook relies on (Python's
# `random`, NumPy, and PyTorch on CPU and CUDA) so that data splits, weight
# initialisation and teacher-forcing decisions are repeatable between runs.
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# Ask cuDNN to use deterministic kernels.
torch.backends.cudnn.deterministic = True

# Optional stricter / faster switches, currently disabled:
# torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.benchmark = True

In [4]:
# Global compute device: the GPU when CUDA is available, otherwise the CPU.
# Several later cells read this module-level name directly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

In [5]:
# Load the pre-trained YouTokenToMe BPE model used for both text styles.
# The path is relative to the notebook's working directory.
token_model = yttm.BPE(model='models/100k_voc20k_all_embed_yttm.model', n_threads=-1)
# token_model = yttm.BPE(model='/content/drive/MyDrive/Colab Notebooks/NLP_poems/models/100k_voc20k_all_embed_yttm.model', n_threads=-1)

In [6]:
# Load the pre-processed poem and news corpora from pickle files.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that were produced by a trusted pipeline.
with open('data/100k_poems_preproc.pickle', 'rb', ) as file:
# with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/100k_poems_preproc.pickle', 'rb', ) as file:
    preproc_poem = pickle.load(file)
with open('data/100k_news_preproc.pickle', 'rb', ) as file:
# with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/100k_news_preproc.pickle', 'rb', ) as file:
    preproc_news = pickle.load(file)

In [7]:
# Keep only the first 15,000 entries of each corpus so both styles contribute
# the same number of samples. The names are rebound, so the full corpora are
# no longer reachable after this cell runs.
preproc_poem = preproc_poem[:15000]
preproc_news = preproc_news[:15000]

In [8]:
# Convert each text to a list of BPE token ids, wrapped in BOS/EOS markers.
token_poem = token_model.encode(preproc_poem, output_type=yttm.OutputType.ID, bos=True, eos=True)
token_news = token_model.encode(preproc_news, output_type=yttm.OutputType.ID, bos=True, eos=True)

In [9]:
class styleDataset(Dataset):
    """List-backed dataset holding the token-id sequences of one text style.

    Each item is returned exactly as stored (a list of token ids); batching
    and padding are left to the DataLoader's collate function.
    """

    def __init__(self, data_list_of_list):
        # data_list_of_list: one list of token ids per text
        self.data = data_list_of_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def cut_length(self, limit):
        """Shorten every stored sequence to at most `limit` tokens.

        `self.data` is rebound to a new list; sequences that are already
        shorter than `limit` are kept as the same objects.
        """
        shortened = []
        for text in self.data:
            shortened.append(text[:limit] if len(text) >= limit else text)
        self.data = shortened

    def distribute_data(self, share_tr, share_val):
        """Randomly split the data into train / validation / test lists.

        share_tr and share_val are fractions of the dataset, e.g. (0.8, 0.1);
        whatever remains after the first two parts becomes the test part.
        The shuffle uses NumPy's global random state.
        """
        order = np.random.permutation(len(self.data))
        total = len(self.data)
        n_train = int(total * share_tr)
        n_valid = int(total * share_val)
        valid_end = n_train + n_valid

        train_part = [self.data[i] for i in order[:n_train]]
        valid_part = [self.data[i] for i in order[n_train:valid_end]]
        test_part = [self.data[i] for i in order[valid_end:]]
        return train_part, valid_part, test_part
    
def padding(batch):
    """Collate function: pad a list of token-id lists into one tensor.

    The result has shape [max len, batch size] (batch_first=False); shorter
    sequences are filled with the id of the '<PAD>' subword, looked up in the
    module-level `token_model`.
    """
    fill_id = token_model.subword_to_id('<PAD>')
    sequences = [torch.tensor(tokens) for tokens in batch]
    # batch_first=True would instead give [batch size, len]
    return pad_sequence(sequences, batch_first=False, padding_value=fill_id)

# def padding_poem(batch):
#     pad_id = token_model.subword_to_id('<PAD>')
    
#     batch = [torch.tensor(x) for x in batch]
#     batch = pad_sequence(batch, batch_first=False, padding_value=pad_id)  # batch_first=True -> [batch size, len]
#     return batch, 1  # return style

In [10]:
class Encoder(nn.Module):
    """GRU encoder whose initial hidden state is seeded with a style embedding.

    The GRU state has width ``style_dim + hid_dim``. Its first ``style_dim``
    entries start from the embedding of the source style and the remaining
    ``hid_dim`` entries start from zeros. After the sequence has been
    processed the style slice is removed, so callers only receive the
    ``hid_dim``-wide part.

    Args:
        input_dim: vocabulary size of the token embedding.
        emb_dim: size of each token embedding vector.
        hid_dim: width of the hidden state returned to the caller.
        style_dim: width of the style slice of the GRU state.
        n_layers: number of stacked GRU layers.
        dropout: dropout probability applied to the token embeddings.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, style_dim, n_layers, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.style_dim = style_dim

        self.embedding = nn.Embedding(input_dim, emb_dim) # 1 - size of dict emb, 2 - size of emb vec
        self.style_embedding = nn.Embedding(2, style_dim) # 2, because 0-news, 1-poem

        self.rnn = nn.GRU(emb_dim, hid_dim + style_dim, n_layers)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_style):
        """Encode ``src`` under the given source style.

        Args:
            src: token ids, [src len, batch size].
            src_style: integer style id (0 or 1) shared by the whole batch.

        Returns:
            (outputs, hidden): top-layer outputs [src len, batch size, hid dim]
            and final states [n layers, batch size, hid dim].
        """
        #src = [src len, batch size]
        batch_size = src.shape[1]
        # Build helper tensors on the input's own device instead of relying on
        # a notebook-global `device`, so the module follows wherever it is moved.
        run_device = src.device

        embedded = self.dropout(self.embedding(src))
        #embedded = [src len, batch size, emb dim]

        style_token = torch.tensor([src_style], device=run_device)
        style_vec = self.style_embedding(style_token)
        #style_vec = [1, style dim]

        content_zeros = torch.zeros((1, self.hid_dim), dtype=style_vec.dtype, device=run_device)
        init_hidden = torch.cat((style_vec, content_zeros), dim=1)
        init_hidden = init_hidden.repeat(self.n_layers, batch_size, 1)
        #init hidden = [n layers, batch size, style dim + hid dim]

        outputs, hidden = self.rnn(embedded, init_hidden)
        #outputs = [src len, batch size, style dim + hid dim]
        #hidden = [n layers, batch size, style dim + hid dim]

        #pop style part
        outputs = outputs[:, :, self.style_dim:]
        hidden = hidden[:, :, self.style_dim:]
        #outputs = [src len, batch size, hid dim]
        #hidden = [n layers, batch size, hid dim]

        #outputs are always from the top hidden layer
        #hidden - it will be first hid states in decoder
        return outputs, hidden

In [11]:
# # no style emb
# class Attention(nn.Module):
#     def __init__(self, hid_dim, style_dim):
#         super().__init__()
        
#         self.style_dim = style_dim
#         self.attn = nn.Linear(hid_dim * 2, hid_dim)
#         self.v = nn.Linear(hid_dim, 1, bias = False)
        
#     def forward(self, hidden, encoder_outputs):
        
#         #hidden = [batch size, style dim + hid dim]
#         #encoder_outputs = [src len, batch size, hid dim]
        
#         batch_size = encoder_outputs.shape[1]
#         src_len = encoder_outputs.shape[0]

#         #from hidden pop style emb part 
#         hidden = hidden[:, self.style_dim:]
#         #hidden = [batch size, hid dim + style emb]
        
#         #repeat decoder hidden state src_len times
#         hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        
#         encoder_outputs = encoder_outputs.permute(1, 0, 2)
        
#         #hidden = [batch size, src len, hid dim]
#         #encoder_outputs = [batch size, src len, hid dim]
        
#         energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) 
#         #energy = [batch size, src len, hid dim]

#         attention = self.v(energy).squeeze(2)
#         #attention= [batch size, src len]
        
#         return F.softmax(attention, dim=1)

In [12]:
# class Decoder(nn.Module):
#     def __init__(self, output_dim, emb_dim, hid_dim, style_dim, n_layers, dropout, attention):
#         super().__init__()
        
#         self.output_dim = output_dim  # the size of the vocabulary for the output/target.
#         self.hid_dim = hid_dim
#         self.n_layers = n_layers
#         self.style_dim = style_dim

#         self.attention = attention
        
#         self.embedding = nn.Embedding(output_dim, emb_dim) 
#         self.style_embedding = nn.Embedding(2, style_dim)

#         self.rnn = nn.GRU(hid_dim + emb_dim, hid_dim + style_dim, n_layers) # dropout = dropout

#         self.fc_out = nn.Linear(hid_dim * 2 + emb_dim, output_dim)
#         self.dropout = nn.Dropout(dropout)
        
#     def forward(self, input, hidden, encoder_outputs, trg_style):

#         #input = [batch size]
#         #hidden = [n layers, batch size, style dim + hid dim]
#         #encoder_outputs = [src len, batch size, hid dim]
        
#         input = input.unsqueeze(0) 
#         #input = [1, batch size]
#         embedded = self.dropout(self.embedding(input))
#         #embedded = [1, batch size, emb dim]

#         if trg_style != -1: # make init hidden with style
#             style_token = torch.tensor([trg_style], device=device)
#             h_style = self.style_embedding(style_token)
#             h_style = h_style.repeat(self.n_layers, hidden.shape[1], 1)
#             #h_style = [n layers, batch size, style dim]

#             hidden = torch.concat((h_style, hidden), dim=2)
                
#         last_hidden = hidden[-1]
#         #last hidden = [batch size, style dim + hid dim]        
#         a = self.attention(last_hidden, encoder_outputs)    
#         #a = [batch size, src len]
#         a = a.unsqueeze(1)
#         #a = [batch size, 1, src len]
        
#         encoder_outputs = encoder_outputs.permute(1, 0, 2)
#         #encoder_outputs = [batch size, src len, hid dim]
        
#         weighted = torch.bmm(a, encoder_outputs)    
#         #weighted = [batch size, 1, hid dim]
#         weighted = weighted.permute(1, 0, 2)
#         #weighted = [1, batch size, hid dim]
        
#         rnn_input = torch.cat((embedded, weighted), dim = 2)
#         #rnn_input = [1, batch size, hid dim + emb dim]

#         output, hidden = self.rnn(rnn_input, hidden)
#         #output = [1, batch size, style dim + hid dim]
#         #hidden = [n layers, batch size, style dim + hid dim]
        
#         embedded = embedded.squeeze(0)
#         output = output.squeeze(0)
#         weighted = weighted.squeeze(0)

#         #pop style part from output
#         output_pred = output[:, self.style_dim:]
        
#         prediction = self.fc_out(torch.cat((output_pred, weighted, embedded), dim = 1))
#         #prediction = [batch size, output dim]
        
#         return output, prediction, hidden

In [13]:
# style emb
class Attention(nn.Module):
    """Additive attention of a style-aware decoder state over encoder outputs.

    The decoder state (style dim + hid dim wide) is paired with every encoder
    output (hid dim wide), passed through a tanh layer and scored by `v`.
    """

    def __init__(self, hid_dim, style_dim):
        super().__init__()

        self.style_dim = style_dim
        self.attn = nn.Linear(hid_dim * 2 + style_dim, hid_dim)
        self.v = nn.Linear(hid_dim, 1, bias = False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape [batch size, src len].

        hidden          = [batch size, style dim + hid dim]
        encoder_outputs = [src len, batch size, hid dim]
        """
        src_len = encoder_outputs.shape[0]

        # one copy of the decoder state per source position
        tiled_hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        # tiled_hidden = [batch size, src len, style dim + hid dim]

        enc_batch_first = encoder_outputs.permute(1, 0, 2)
        # enc_batch_first = [batch size, src len, hid dim]

        pair_features = torch.cat((tiled_hidden, enc_batch_first), dim = 2)
        energy = torch.tanh(self.attn(pair_features))
        # energy = [batch size, src len, hid dim]

        scores = self.v(energy).squeeze(2)
        # scores = [batch size, src len]

        return F.softmax(scores, dim=1)

In [14]:
class Decoder(nn.Module):
    """Single-step attention GRU decoder with a style slice in its state.

    The GRU state is ``style_dim + hid_dim`` wide. On the first decoding step
    the caller passes the encoder's ``hid_dim``-wide state together with a
    target style id, and the style embedding is prepended to it. On later
    steps the caller passes ``trg_style == -1`` and the already style-extended
    state is reused as is.

    Args:
        output_dim: size of the output vocabulary.
        emb_dim: size of each token embedding vector.
        hid_dim: width of the content part of the state / encoder outputs.
        style_dim: width of the style part of the state.
        n_layers: number of stacked GRU layers.
        dropout: dropout probability applied to the token embeddings.
        attention: callable ``(last_hidden, encoder_outputs) -> weights`` that
            returns [batch size, src len] attention weights.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, style_dim, n_layers, dropout, attention):
        super().__init__()

        self.output_dim = output_dim  # the size of the voc
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.style_dim = style_dim

        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.style_embedding = nn.Embedding(2, style_dim)

        self.rnn = nn.GRU(hid_dim + emb_dim, hid_dim + style_dim, n_layers)

        self.fc_out = nn.Linear(hid_dim * 2 + emb_dim + style_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs, trg_style):
        """Decode one token position.

        Args:
            input: previous token ids, [batch size].
            hidden: [n layers, batch size, hid dim] when ``trg_style != -1``,
                otherwise [n layers, batch size, style dim + hid dim].
            encoder_outputs: [src len, batch size, hid dim].
            trg_style: target style id (0 or 1) on the first step, -1 afterwards.

        Returns:
            (prediction, hidden): vocabulary logits [batch size, output dim]
            and the new state [n layers, batch size, style dim + hid dim].
        """
        input = input.unsqueeze(0)
        #input = [1, batch size]
        embedded = self.dropout(self.embedding(input))
        #embedded = [1, batch size, emb dim]

        if trg_style != -1: # make init hidden with style
            # Create the style id on the state's own device rather than on a
            # notebook-global `device`, so the module is self-contained.
            style_token = torch.tensor([trg_style], device=hidden.device)
            h_style = self.style_embedding(style_token)
            h_style = h_style.repeat(self.n_layers, hidden.shape[1], 1)
            #h_style = [n layers, batch size, style dim]

            hidden = torch.cat((h_style, hidden), dim=2)

        last_hidden = hidden[-1]
        #last hidden = [batch size, style dim + hid dim]
        a = self.attention(last_hidden, encoder_outputs)
        #a = [batch size, src len]
        a = a.unsqueeze(1)
        #a = [batch size, 1, src len]

        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        #encoder_outputs = [batch size, src len, hid dim]

        weighted = torch.bmm(a, encoder_outputs)
        #weighted = [batch size, 1, hid dim]
        weighted = weighted.permute(1, 0, 2)
        #weighted = [1, batch size, hid dim]

        rnn_input = torch.cat((embedded, weighted), dim = 2)
        #rnn_input = [1, batch size, hid dim + emb dim]

        output, hidden = self.rnn(rnn_input, hidden)
        #output = [1, batch size, style dim + hid dim]
        #hidden = [n layers, batch size, style dim + hid dim]

        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))
        #prediction = [batch size, output dim]

        return prediction, hidden

In [15]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes a sequence token by token.

    Args:
        encoder: callable ``(src, src_style) -> (enc_outputs, hidden)``.
        decoder: callable ``(input, hidden, enc_outputs, trg_style) ->
            (prediction, hidden)`` exposing an ``output_dim`` attribute.
        device: device on which the prediction buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, src_style, trg_style, tf_ratio):
        """Encode ``src`` in ``src_style`` and decode it in ``trg_style``.

        Args:
            src, trg: token ids, [len, batch size].
            src_style: style id given to the encoder.
            trg_style: style id given to the decoder on its first step.
            tf_ratio: probability, per step, of feeding the ground-truth
                token instead of the model's own argmax prediction.

        Returns:
            Logits of shape [trg len, batch size, vocab size]. Row 0 is never
            written and stays all zeros, because decoding starts at position 1.
        """
        #src, trg = [src len, batch size]
        trg_len, batch_size = trg.shape[0], trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        # Allocate directly on the target device (no CPU tensor plus copy).
        predictions = torch.zeros(trg_len, batch_size, trg_vocab_size, device=self.device)

        enc_outputs, hidden = self.encoder(src, src_style)

        #first input to the decoder is the <bos> tokens
        dec_input = trg[0, :]

        for t in range(1, trg_len):
            #insert: input token emb-g, previous hid state, all enc outputs, style
            #receive: prediction and new hidden state
            prediction, hidden = self.decoder(dec_input, hidden, enc_outputs, trg_style)

            # The style is injected into the state only once; -1 tells the
            # decoder to reuse the style-extended state on later steps.
            trg_style = -1

            predictions[t] = prediction

            #get the highest predicted token
            top1 = prediction.argmax(1)
            teacher_force = random.random() < tf_ratio
            dec_input = trg[t] if teacher_force else top1

        return predictions

In [16]:
class Discriminator(nn.Module):
    """Convolutional classifier that scores whether a sequence has a given style.

    It works on the generator's raw vocabulary logits: they are turned into
    probabilities, projected to `emb_dim` by a linear layer (a "soft"
    embedding lookup), passed through n-gram convolutions of several widths,
    max-pooled over time and classified into two classes.
    """

    def __init__(self, output_dim, emb_dim, n_filters, filter_sizes, dropout):
        super().__init__()

        # maps a distribution over the vocabulary to an emb_dim vector
        self.emb_squeeze = nn.Linear(output_dim, emb_dim)

        # Text is treated as a one-channel image of shape [trg len, emb dim];
        # every filter spans the full embedding width and `fs` positions,
        # i.e. it looks at fs-grams.
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(fs, emb_dim))
            for fs in filter_sizes
        ])

        self.fc = nn.Linear(len(filter_sizes) * n_filters, 2)

        self.dropout = nn.Dropout(dropout)

    def forward(self, gener_predicts):
        """Classify a batch of generator outputs.

        gener_predicts = [batch size, trg len, output dim] (logits)
        returns        = [batch size, 2]; index 0 means "isn't this style",
                         index 1 means "is this style".
        """
        token_probs = F.softmax(gener_predicts, dim=-1)

        soft_embedded = self.emb_squeeze(token_probs).unsqueeze(1)
        # soft_embedded = [batch size, 1, trg len, emb dim]

        pooled = []
        for conv in self.convs:
            feature_map = F.relu(conv(soft_embedded)).squeeze(3)
            # feature_map = [batch size, n_filters, trg len - filter size + 1]
            pooled.append(F.max_pool1d(feature_map, feature_map.shape[2]).squeeze(2))
            # each pooled entry = [batch size, n_filters]

        features = self.dropout(torch.cat(pooled, dim=1))
        # features = [batch size, n_filters * len(filter_sizes)]

        return self.fc(features)

In [17]:
# --- Model hyper-parameters -------------------------------------------------
# Source and target share one BPE vocabulary, so both dims are the same.
INPUT_DIM = token_model.vocab_size() 
OUTPUT_DIM = INPUT_DIM
ENC_EMB_DIM = 128 
DEC_EMB_DIM = 128
STYLE_EMB_DIM = 256 
HID_DIM = 256 
N_LAYERS = 2 
ENC_DROPOUT = 0.2
DEC_DROPOUT = 0.2
LABEL_DROPOUT = 0.3  # dropout probability passed to the discriminators

# Discriminator CNN: number of filters per width, and the n-gram widths.
N_FILTERS = 128
FILTER_SIZES = [1,3,4,5,8] 

# --- Generator (seq2seq with attention) --------------------------------------
attn = Attention(HID_DIM, STYLE_EMB_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, STYLE_EMB_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, STYLE_EMB_DIM, N_LAYERS, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, device).to(device)
# --- One discriminator per style ---------------------------------------------
# Note: HID_DIM is passed as the discriminators' `emb_dim` argument.
discr_poem = Discriminator(OUTPUT_DIM, HID_DIM, N_FILTERS, FILTER_SIZES, LABEL_DROPOUT).to(device) 
discr_news = Discriminator(OUTPUT_DIM, HID_DIM, N_FILTERS, FILTER_SIZES, LABEL_DROPOUT).to(device)

# TensorBoard writer; the log directory name identifies the run.
writer = SummaryWriter(log_dir='runs_st_best/ST_dAdam005_genCyclefdf10_isnt1stt1-af10e0.25rec1_discr14e_len4060_loss_right')
# writer = SummaryWriter(log_dir='/content/drive/MyDrive/Colab Notebooks/NLP_poems/runs/ST_15k_voc20k_len20_nf100_fs2,6_Lrec0.5_discrCycle01_detach')

In [18]:
def init_weights(m):
    """Fill every parameter reachable from module `m` with U(-0.08, 0.08).

    Parameters of nested sub-modules are included, so when this function is
    used with `Module.apply` (which visits every sub-module) the same
    parameter is re-drawn once for each module that contains it.
    """
    low, high = -0.08, 0.08
    for param in m.parameters():
        nn.init.uniform_(param.data, low, high)
        
# Re-initialise the generator's parameters. Module.apply calls init_weights on
# `model` and on each of its sub-modules. Only `model` is re-initialised here;
# the discriminators keep their constructor defaults.
model.apply(init_weights)

def count_parameters(model):
    """Return the total number of scalar entries in `model`'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the size of each network as a quick sanity check of the configuration.
print(f'The seq2seq model has {count_parameters(model):,} trainable parameters')
print(f'The discr news has {count_parameters(discr_news):,} trainable parameters')
print(f'The discr poem has {count_parameters(discr_poem):,} trainable parameters')

The seq2seq model has 28,775,456 trainable parameters
The discr news has 5,810,306 trainable parameters
The discr poem has 5,810,306 trainable parameters


In [19]:
class CrossEntropyLoss(nn.Module):
    """Cross-entropy on logits with the probabilities clipped away from 0 and 1.

    The softmax output is clipped into [EPS, 1 - EPS] before the logarithm is
    taken, so the per-token loss can never become infinite or NaN. Targets
    equal to `ignore_index` do not contribute to the loss.
    """

    def __init__(self, ignore_index, reduction='mean'):
        super().__init__()
        self.reduction = reduction
        self.EPS = 1e-6
        self.ignore_index = ignore_index

    def forward(self, pred, gt):
        """pred: logits with the class scores on the last dim; gt: target class ids."""
        probs = F.softmax(pred, dim=-1)
        # keep probabilities strictly inside (0, 1) so log() stays finite
        probs = probs.clamp(self.EPS, 1 - self.EPS)
        log_probs = probs.log()

        return F.nll_loss(log_probs, gt,
                          ignore_index=self.ignore_index,
                          reduction=self.reduction)

In [20]:
# --- Training configuration ---------------------------------------------------
N_EPOCHS = 20

# Padding positions are excluded from the reconstruction loss.
PAD_ID = token_model.subword_to_id('<PAD>')
# gener_criterion = nn.CrossEntropyLoss(ignore_index = PAD_ID)

# Generator: the clipped cross-entropy defined in this notebook.
# Discriminators: the standard two-class cross-entropy.
gener_criterion = CrossEntropyLoss(ignore_index = PAD_ID)
discr_criterion = nn.CrossEntropyLoss()

# `clip` is the maximum gradient norm handed to clip_grad_norm_ in train();
# `teacher_forcing_ratio` is read as a global by train_get_gener().
clip = 3
teacher_forcing_ratio = 0.2

In [21]:
def train_get_gener(batch, src_style):
    """Run the generator twice on one training batch.

    The batch serves as both source and target (auto-encoding). The first
    pass reconstructs it in its own style with teacher forcing; the second
    pass decodes it in the opposite style without any teacher forcing.
    Uses the module-level `model`, `device` and `teacher_forcing_ratio`.

    Returns:
        same-style logits  [batch size, trg len, output dim],
        other-style logits [batch size, trg len, output dim],
        same-style logits flattened for the loss [(trg len - 1) * batch size, output dim],
        flattened targets  [(trg len - 1) * batch size],
        batch size.
    """
    batch = batch.to(device, non_blocking=True)
    #batch = [src len, batch size]
    other_style = 1 - src_style

    same_logits = model(batch, batch, src_style, src_style, teacher_forcing_ratio) # TFR
    fake_logits = model(batch, batch, src_style, other_style, 0)                   # FREE-RUNNING
    #same_logits, fake_logits = [trg len, batch size, output dim]

    # Position 0 (<bos>) is never predicted, so it is dropped before the loss.
    vocab_size = same_logits.shape[-1]
    same_logits_flat = same_logits[1:].view(-1, vocab_size)
    targets_flat = batch[1:].view(-1)

    # The discriminators expect batch-first tensors.
    same_batch_first = same_logits.permute(1, 0, 2)
    fake_batch_first = fake_logits.permute(1, 0, 2)

    return same_batch_first, fake_batch_first, same_logits_flat, targets_flat, batch.shape[1]


def train(model, discr_poem, discr_news, loader_p, loader_n, 
          optimizer_gen, optimizer_discr0, optimizer_discr1, scheduler_gen, scheduler_discr0, scheduler_discr1,
          gener_criterion, discr_criterion, clip, epoch):
    """Run one adversarial training epoch over paired poem / news batches.

    Each step first updates the generator on
    ``koef_rec * reconstruction loss + koef_stt * style-transfer loss`` and
    then, from epoch 6 on, updates both discriminators on detached generator
    outputs. Before epoch 6 the style-transfer weight is 0 and the
    discriminators are only evaluated; from epoch 10 on the weight is 0.25.

    Note: the generator forward passes go through ``train_get_gener``, which
    uses the module-level ``model`` and ``device`` rather than this function's
    ``model`` argument.

    Args:
        model: generator; put in train mode and used for gradient clipping.
        discr_poem, discr_news: style discriminators.
        loader_p, loader_n: poem / news batch iterables, consumed in lock-step
            (iteration stops at the shorter one).
        optimizer_gen, optimizer_discr0, optimizer_discr1: optimizers for the
            generator and the two discriminators.
        scheduler_gen, scheduler_discr0, scheduler_discr1: LR schedulers,
            stepped once per batch alongside their optimizer.
        gener_criterion: reconstruction loss on flattened logits / targets.
        discr_criterion: two-class loss for discriminator outputs.
        clip: maximum gradient norm.
        epoch: epoch index that selects the loss schedule described above.

    Returns:
        Tuple of (total loss, reconstruction loss, style-transfer loss,
        discriminator loss) averaged over ``len(loader_p)`` batches, followed
        by four fractions of samples that each discriminator labelled "is this
        style": poem/reconstructed, news/reconstructed, poem/transferred,
        news/transferred.
    """
    model.train()
    discr_poem.train()
    discr_news.train()
    
    epoch_loss = 0
    epoch_loss_gen_rec = 0
    epoch_loss_gen_stt = 0
    epoch_loss_discrs = 0

    acc_discr_poem_rec_right = 0 # poem discr get recon poem and told yes
    acc_discr_news_rec_right = 0
    acc_discr_poem_stt_is_st = 0 # poem discr get fake poem from news and told yes
    acc_discr_news_stt_is_st = 0

    # Number of samples actually processed; used as the accuracy denominator
    # instead of the length of a dataset defined elsewhere in the notebook.
    n_samples = 0
    
    koef_discr_isnt = 1
    koef_stt = 1
    koef_rec = 1
    
    need_discr_optim = epoch >= 6 # change scheduler!
    
    if epoch < 6:
        koef_rec = 1
        koef_stt = 0
    if epoch >= 10:
        koef_stt = 0.25
    
    for batch_p, batch_n in tqdm(zip(loader_p, loader_n), total=len(loader_p)):
        # 1 step = 2 batches: one with poems and one with news

        p_predict_same, p_predict_fake, p_predict_same_loss, p_trg, batch_size = train_get_gener(batch_p, 1)
        n_predict_same, n_predict_fake, n_predict_same_loss, n_trg, _ = train_get_gener(batch_n, 0)

        # [x,y] - x bigger, then this isn't this style, y - it's this style
        is_style = torch.ones((batch_size), device=device).long()
        isnt_style = torch.zeros((batch_size), device=device).long()

        ### generator optimization ###
        optimizer_gen.zero_grad()
                
        # the generator must learn to keep the content when reconstructing
        loss_reconstr = gener_criterion(p_predict_same_loss, p_trg) + \
                        gener_criterion(n_predict_same_loss, n_trg)

        # the generator must learn to transfer the style correctly
        discr_poem_out_fake = discr_poem(n_predict_fake)
        discr_news_out_fake = discr_news(p_predict_fake)
                
        # success = the discriminator accepts the transferred text as its style
        loss_style_trans = discr_criterion(discr_poem_out_fake, is_style) + \
                           discr_criterion(discr_news_out_fake, is_style)
        
        loss_generator = loss_reconstr * koef_rec + loss_style_trans * koef_stt
        loss_generator.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer_gen.step()
        scheduler_gen.step() 

        ### discriminators optimization ###
        if need_discr_optim:
            optimizer_discr0.zero_grad()
            optimizer_discr1.zero_grad()
            
            # the discriminators must learn to spot transferred (fake) samples;
            # generator outputs are detached so only the discriminators update
            discr_poem_out_same = discr_poem(p_predict_same.detach())
            discr_news_out_same = discr_news(n_predict_same.detach())

            discr_poem_out_fake = discr_poem(n_predict_fake.detach()) 
            discr_news_out_fake = discr_news(p_predict_fake.detach())

            # accept reconstructions, reject style-transferred samples
            loss_discr_poem = discr_criterion(discr_poem_out_same, is_style) + \
                              discr_criterion(discr_poem_out_fake, isnt_style) * koef_discr_isnt
            loss_discr_news = discr_criterion(discr_news_out_same, is_style) + \
                              discr_criterion(discr_news_out_fake, isnt_style) * koef_discr_isnt  

            loss_discr = loss_discr_poem + loss_discr_news
            loss_discr.backward()
            torch.nn.utils.clip_grad_norm_(discr_poem.parameters(), clip)
            torch.nn.utils.clip_grad_norm_(discr_news.parameters(), clip)
            optimizer_discr0.step()
            optimizer_discr1.step()
            scheduler_discr0.step() 
            scheduler_discr1.step() 
        else:
            # discriminators are frozen: only compute their loss / accuracy
            with torch.no_grad():
                discr_poem_out_same = discr_poem(p_predict_same.detach())
                discr_news_out_same = discr_news(n_predict_same.detach())

                loss_discr_poem = discr_criterion(discr_poem_out_same, is_style) + \
                                  discr_criterion(discr_poem_out_fake, isnt_style) * koef_discr_isnt
                loss_discr_news = discr_criterion(discr_news_out_same, is_style) + \
                                  discr_criterion(discr_news_out_fake, isnt_style) * koef_discr_isnt

            
        ### accuracy ###
        acc_discr_poem_rec_right += (discr_poem_out_same.argmax(1) == 1).sum().item()
        acc_discr_news_rec_right += (discr_news_out_same.argmax(1) == 1).sum().item()

        acc_discr_poem_stt_is_st += (discr_poem_out_fake.argmax(1) == 1).sum().item()
        acc_discr_news_stt_is_st += (discr_news_out_fake.argmax(1) == 1).sum().item()
        n_samples += batch_size
        
        ### save epoch loss ###
        # Detach before accumulating: summing the live loss tensors would keep
        # every batch's autograd history referenced until the epoch ends.
        epoch_loss += (loss_reconstr + loss_style_trans + loss_discr_poem + loss_discr_news).detach()
        epoch_loss_gen_rec += loss_reconstr.detach()
        epoch_loss_gen_stt += loss_style_trans.detach()
        epoch_loss_discrs += (loss_discr_poem + loss_discr_news).detach()

    seen = max(n_samples, 1)  # avoid 0/0 when no batch was processed
    acc_discr_poem_rec = acc_discr_poem_rec_right / seen
    acc_discr_news_rec = acc_discr_news_rec_right / seen

    acc_discr_poem_stt = acc_discr_poem_stt_is_st / seen
    acc_discr_news_stt = acc_discr_news_stt_is_st / seen

    return epoch_loss / len(loader_p), epoch_loss_gen_rec / len(loader_p), \
           epoch_loss_gen_stt / len(loader_p), epoch_loss_discrs / len(loader_p), \
           acc_discr_poem_rec, acc_discr_news_rec, \
           acc_discr_poem_stt, acc_discr_news_stt

In [22]:
def valid_get_gener(batch, src_style):
    """Run the generator twice on one validation batch, without teacher forcing.

    The batch serves as both source and target. The first pass reconstructs
    it in its own style, the second decodes it in the opposite style; in both
    passes the decoder is fed its own predictions (teacher-forcing ratio 0).
    Uses the module-level `model` and `device`.

    Returns:
        same-style logits  [batch size, trg len, output dim],
        other-style logits [batch size, trg len, output dim],
        same-style logits flattened for the loss [(trg len - 1) * batch size, output dim],
        flattened targets  [(trg len - 1) * batch size],
        batch size.
    """
    batch = batch.to(device, non_blocking=True)
    #batch = [src len, batch size]
    other_style = 1 - src_style

    same_logits = model(batch, batch, src_style, src_style, 0)    # FREE-RUNNING
    fake_logits = model(batch, batch, src_style, other_style, 0)  # FREE-RUNNING
    #same_logits, fake_logits = [trg len, batch size, output dim]

    # Position 0 (<bos>) is never predicted, so it is dropped before the loss.
    vocab_size = same_logits.shape[-1]
    same_logits_flat = same_logits[1:].view(-1, vocab_size)
    targets_flat = batch[1:].view(-1)

    # The discriminators expect batch-first tensors.
    same_batch_first = same_logits.permute(1, 0, 2)
    fake_batch_first = fake_logits.permute(1, 0, 2)

    return same_batch_first, fake_batch_first, same_logits_flat, targets_flat, batch.shape[1]


def evaluate(model, discr_poem, discr_news, loader_p, loader_n, gener_criterion, discr_criterion):
    """Compute validation losses and discriminator statistics without training.

    All networks are put in eval mode and everything runs under
    ``torch.no_grad()``. The generator passes go through ``valid_get_gener``,
    which uses the module-level ``model`` and ``device`` rather than this
    function's ``model`` argument.

    Args:
        model: generator; put in eval mode.
        discr_poem, discr_news: style discriminators.
        loader_p, loader_n: poem / news batch iterables, consumed in lock-step
            (iteration stops at the shorter one).
        gener_criterion: reconstruction loss on flattened logits / targets.
        discr_criterion: two-class loss for discriminator outputs.

    Returns:
        Tuple of (total loss, reconstruction loss, style-transfer loss,
        discriminator loss) averaged over ``len(loader_p)`` batches, followed
        by four fractions of samples that each discriminator labelled "is this
        style": poem/reconstructed, news/reconstructed, poem/transferred,
        news/transferred.
    """
    model.eval()
    discr_poem.eval()
    discr_news.eval()
    
    epoch_loss = 0
    epoch_loss_gen_rec = 0
    epoch_loss_gen_stt = 0
    epoch_loss_discrs = 0

    acc_discr_poem_rec_right = 0 # poem discr get recon poem and told yes
    acc_discr_news_rec_right = 0
    acc_discr_poem_stt_is_st = 0 # poem discr get fake poem from news and told yes
    acc_discr_news_stt_is_st = 0

    # Number of samples actually processed; used as the accuracy denominator
    # instead of the length of a dataset defined elsewhere in the notebook.
    n_samples = 0
    
    with torch.no_grad():
    
        for batch_p, batch_n in tqdm(zip(loader_p, loader_n), total=len(loader_p)):
            # 1 step = 2 batches: one with poems and one with news

            p_predict_same, p_predict_fake, p_predict_same_loss, p_trg, batch_size = valid_get_gener(batch_p, 1)
            n_predict_same, n_predict_fake, n_predict_same_loss, n_trg, _ = valid_get_gener(batch_n, 0)

            # [x,y] - x bigger, then this isn't this style, y - it's this style
            is_style = torch.ones((batch_size), device=device).long()
            isnt_style = torch.zeros((batch_size), device=device).long()

            ### generator ###
            loss_reconstr = gener_criterion(p_predict_same_loss, p_trg) + \
                            gener_criterion(n_predict_same_loss, n_trg)

            discr_poem_out_fake = discr_poem(n_predict_fake)
            discr_news_out_fake = discr_news(p_predict_fake)

            loss_style_trans = discr_criterion(discr_poem_out_fake, is_style) + \
                               discr_criterion(discr_news_out_fake, is_style)

            ### discriminators ###
            discr_poem_out_same = discr_poem(p_predict_same.detach())
            discr_news_out_same = discr_news(n_predict_same.detach())

            loss_discr_poem = discr_criterion(discr_poem_out_same, is_style) + \
                              discr_criterion(discr_poem_out_fake, isnt_style)
            loss_discr_news = discr_criterion(discr_news_out_same, is_style) + \
                              discr_criterion(discr_news_out_fake, isnt_style)

            ### accuracy ###
            acc_discr_poem_rec_right += (discr_poem_out_same.argmax(1) == 1).sum().item()
            acc_discr_news_rec_right += (discr_news_out_same.argmax(1) == 1).sum().item()

            acc_discr_poem_stt_is_st += (discr_poem_out_fake.argmax(1) == 1).sum().item()
            acc_discr_news_stt_is_st += (discr_news_out_fake.argmax(1) == 1).sum().item()
            n_samples += batch_size

            ### save epoch loss ###
            epoch_loss += (loss_reconstr + loss_style_trans + loss_discr_poem + loss_discr_news)
            epoch_loss_gen_rec += loss_reconstr
            epoch_loss_gen_stt += loss_style_trans
            epoch_loss_discrs += (loss_discr_poem + loss_discr_news)

    seen = max(n_samples, 1)  # avoid 0/0 when no batch was processed
    acc_discr_poem_rec = acc_discr_poem_rec_right / seen
    acc_discr_news_rec = acc_discr_news_rec_right / seen

    acc_discr_poem_stt = acc_discr_poem_stt_is_st / seen
    acc_discr_news_stt = acc_discr_news_stt_is_st / seen

    return epoch_loss / len(loader_p), epoch_loss_gen_rec / len(loader_p), \
           epoch_loss_gen_stt / len(loader_p), epoch_loss_discrs / len(loader_p), \
           acc_discr_poem_rec, acc_discr_news_rec, \
           acc_discr_poem_stt, acc_discr_news_stt

In [23]:
# def evaluate(model, discr_poem, discr_news, loader_p, loader_n, gener_criterion, discr_criterion):
    
#     model.eval()
#     discr_poem.eval()
#     discr_news.eval()
    
#     epoch_loss = 0
#     epoch_loss_gen_rec = 0
#     epoch_loss_gen_stt = 0
#     epoch_loss_discrs = 0

#     acc_discr_poem_rec_right = 0 # poem discr get recon poem and told yes
#     acc_discr_news_rec_right = 0
#     acc_discr_poem_stt_is_st = 0 # poem discr get fake poem from news and told yes
#     acc_discr_news_stt_is_st = 0
    
#     with torch.no_grad():
    
#         for batch_p, batch_n in tqdm(zip(loader_p, loader_n), total=len(loader_p)):
#             for batch, src_style in (batch_p, batch_n):
#                 batch = batch.to(device, non_blocking=True)
#                 src = batch
#                 trg = batch

#                 predict_same, out_tokens_same = model(src, trg, src_style, src_style, 0) #turn off teacher forcing
#                 predict_fake, out_tokens_fake = model(src, trg, src_style, 1 - src_style, 0) 
#                 #predict_same = [trg len, batch size, output dim]
#                 #out_tokens_same = [trg len, batch size, hid dim]

# #                 out_tokens_same = out_tokens_same.permute(1, 0, 2)
# #                 out_tokens_fake = out_tokens_fake.permute(1, 0, 2)
# #                 #out_tokens_same = [batch size, trg len, output dim]
                
#                 predict_same_for_loss = predict_same
#                 out_tokens_same = predict_same.permute(1, 0, 2)
#                 out_tokens_fake = predict_fake.permute(1, 0, 2)
#                 #predict_same = [batch size, trg len, output dim]

#                 is_style = torch.ones((batch.shape[1]), device=device).long()
#                 isnt_style = torch.zeros((batch.shape[1]), device=device).long()

#                 if src_style == 0: # news
#                   discr_news_out_same = discr_news(out_tokens_same)
#                   discr_poem_out_fake = discr_poem(out_tokens_fake)

#                   loss_discr_news = discr_criterion(discr_news_out_same, is_style)
#                   loss_discr_poem = discr_criterion(discr_poem_out_fake, isnt_style) # find fault in style transfer 

#                   loss_style_trans = discr_criterion(discr_poem_out_fake, is_style) # find success in style transfer

#                   acc_discr_news_rec_right += (discr_news_out_same.argmax(1) == 1).sum().item()
#                   acc_discr_poem_stt_is_st += (discr_poem_out_fake.argmax(1) == 1).sum().item()

#                 if src_style == 1: # poem
#                   discr_news_out_fake = discr_news(out_tokens_fake)
#                   discr_poem_out_same = discr_poem(out_tokens_same)

#                   loss_discr_news = discr_criterion(discr_news_out_fake, isnt_style)
#                   loss_discr_poem = discr_criterion(discr_poem_out_same, is_style) 

#                   loss_style_trans = discr_criterion(discr_news_out_fake, is_style)

#                   acc_discr_poem_rec_right += (discr_poem_out_same.argmax(1) == 1).sum().item()
#                   acc_discr_news_stt_is_st += (discr_news_out_fake.argmax(1) == 1).sum().item()

#                 #trg = [trg len, batch size]
#                 #predict = [trg len, batch size, output dim]
#                 predict_dim = predict_same.shape[-1]
#                 predict_same_for_loss = predict_same_for_loss[1:].view(-1, predict_dim)
#                 trg = trg[1:].view(-1)
#                 #trg = [(trg len - 1) * batch size]
#                 #predict_same = [(trg len - 1) * batch size, output dim]
                
#                 loss_reconstr = gener_criterion(predict_same_for_loss, trg)

#                 loss_generator = loss_reconstr + loss_style_trans
#                 loss_discr = loss_discr_poem + loss_discr_news

#                 loss = loss_generator + loss_discr

#                 epoch_loss += loss
#                 epoch_loss_gen_rec += loss_reconstr
#                 epoch_loss_gen_stt += loss_style_trans
#                 epoch_loss_discrs += loss_discr

#     acc_discr_poem_rec = acc_discr_poem_rec_right / len(valid_poem)
#     acc_discr_news_rec = acc_discr_news_rec_right / len(valid_poem)

#     acc_discr_poem_stt = acc_discr_poem_stt_is_st / len(valid_poem)
#     acc_discr_news_stt = acc_discr_news_stt_is_st / len(valid_poem)

#     return epoch_loss / len(loader_p), epoch_loss_gen_rec / len(loader_p), \
#            epoch_loss_gen_stt / len(loader_p), epoch_loss_discrs / len(loader_p), \
#            acc_discr_poem_rec, acc_discr_news_rec, \
#            acc_discr_poem_stt, acc_discr_news_stt

In [24]:
def epoch_time(start_time, end_time):
    """Break the span between two timestamps into whole minutes and seconds.

    Args:
        start_time: timestamp taken before the work started.
        end_time: timestamp taken after the work finished, in the same unit
            (the training loop passes ``time.time()`` readings).

    Returns:
        ``(minutes, seconds)`` as ints, each truncated toward zero.
    """
    duration = end_time - start_time
    minutes = int(duration / 60)
    # Remainder once the whole minutes are removed; the fraction is dropped.
    seconds = int(duration - (minutes * 60))
    return minutes, seconds

In [25]:
# torch.autograd.set_detect_anomaly(True)

In [26]:
# Resume from a previous run. Only the generator (`model`) and its optimizer
# are restored here; the discriminator weights and discriminator optimizer
# states are intentionally NOT loaded (the corresponding lines are kept below,
# commented out), so `discr_news` / `discr_poem` keep their current weights.
RESUME_CHECKPOINT_PATH = 'ST_dAdam005_genCyclefdf10_isnt1stt1-af10e0.25rec1_discr14e_saveopt_len40_checkp_last.pt'

# map_location remaps the stored tensors onto the active `device`, so a
# checkpoint written on a GPU machine can also be opened on a CPU-only one.
# NOTE: torch.load unpickles the file - only open checkpoints you trust.
checkpoint = torch.load(RESUME_CHECKPOINT_PATH, map_location=device)

model.load_state_dict(checkpoint['model_state_dict'])
# discr_poem.load_state_dict(checkpoint['discr_poem_state_dict'])
# discr_news.load_state_dict(checkpoint['discr_news_state_dict'])

# The optimizers are created first so that saved state can be loaded into them.
optimizer_gen = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
optimizer_discr0 = optim.AdamW(discr_news.parameters(), lr=0.005, weight_decay=0.01)
optimizer_discr1 = optim.AdamW(discr_poem.parameters(), lr=0.005, weight_decay=0.01)

optimizer_gen.load_state_dict(checkpoint['optimizer_gen_state_dict'])
# optimizer_discr0.load_state_dict(checkpoint['optimizer_discr0_state_dict'])
# optimizer_discr1.load_state_dict(checkpoint['optimizer_discr1_state_dict'])

In [27]:
def _split_and_load(dataset, max_len=60, train_frac=0.8, valid_frac=0.1, batch_size=32):
    """Cut, split and wrap one style corpus; shared by the poem and news data.

    Args:
        dataset: a ``styleDataset``; ``cut_length`` is called on it first.
        max_len: value forwarded to ``dataset.cut_length``
            (presumably the maximum sequence length - confirm in styleDataset).
        train_frac, valid_frac: values forwarded to ``dataset.distribute_data``
            (presumably the train/validation fractions, the rest being test -
            confirm in styleDataset).
        batch_size: batch size used for both loaders.

    Returns:
        ``(train, valid, test, train_loader, valid_loader)`` where the first
        three are whatever ``distribute_data`` returns and the loaders wrap the
        train and validation splits with identical settings.
    """
    dataset.cut_length(max_len)
    train_split, valid_split, test_split = dataset.distribute_data(train_frac, valid_frac)
    # Both loaders use the same options, including shuffling, as before.
    loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=0,
                          pin_memory=True, collate_fn=padding)
    train_loader = DataLoader(train_split, **loader_options)
    valid_loader = DataLoader(valid_split, **loader_options)
    return train_split, valid_split, test_split, train_loader, valid_loader


poem_token_dataset = styleDataset(token_poem)
train_poem, valid_poem, test_poem, train_poem_loader, valid_poem_loader = \
    _split_and_load(poem_token_dataset)

# Optional: persist the held-out poem split.
# NOTE(review): the file name says "len40" while sequences are cut at 60 - confirm.
# with open('data/voc20klen40_test_poem.pickle', 'wb', ) as file:
# # with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/voc20klen20_test_poem.pickle', 'wb', ) as file:
#     pickle.dump(test_poem, file)

news_token_dataset = styleDataset(token_news)
train_news, valid_news, test_news, train_news_loader, valid_news_loader = \
    _split_and_load(news_token_dataset)

# Optional: persist the held-out news split (same naming caveat as above).
# with open('data/voc20klen40_test_news.pickle', 'wb', ) as file:
# # with open('/content/drive/MyDrive/Colab Notebooks/NLP_poems/data/voc20klen20_test_news.pickle', 'wb', ) as file:
#     pickle.dump(test_news, file)

In [28]:
# Generator: one-cycle learning-rate policy spanning this whole run.
# total_steps assumes scheduler_gen.step() is called once per training batch
# (train() is defined elsewhere - confirm); OneCycleLR raises if it is stepped
# more often than total_steps.
scheduler_gen = optim.lr_scheduler.OneCycleLR(
    optimizer_gen,
    max_lr=0.005,
    total_steps=len(train_poem_loader) * N_EPOCHS,
    cycle_momentum=True,
    final_div_factor=10,
)

# Discriminators: StepLR with a step size far beyond any reachable step count,
# i.e. the learning rate effectively stays constant. StepLR documents step_size
# as an integer, so the exact integer 10**11 is passed instead of the float
# literal 10e10 (same value).
_NO_DECAY_STEP_SIZE = 10 ** 11
scheduler_discr0 = optim.lr_scheduler.StepLR(optimizer_discr0, step_size=_NO_DECAY_STEP_SIZE, gamma=0.1)
scheduler_discr1 = optim.lr_scheduler.StepLR(optimizer_discr1, step_size=_NO_DECAY_STEP_SIZE, gamma=0.1)

In [None]:
# Main training loop: for each epoch run train() and evaluate(), log the four
# loss components to TensorBoard, overwrite the "last" checkpoint and print a
# one-epoch summary.
bv_loss = float('inf')
best_valid_loss = bv_loss  # only used by the disabled best-checkpoint block below

# Destination of the per-epoch checkpoint; the same file is overwritten each epoch.
LAST_CHECKPOINT_PATH = 'ST_dAdam005_genCyclefdf10_isnt1stt1-af10e0.25rec1_discr14e_len4060_newdiscr_right_checkp_last.pt'

# NOTE(review): hard-coded resume epoch - confirm it matches the loaded checkpoint.
start_epoch = 20

# try/finally guarantees the TensorBoard writer is flushed and closed even when
# training is stopped by hand or fails part-way through an epoch.
try:
    for epoch in range(start_epoch, start_epoch + N_EPOCHS):

        start_time = time.time()

        # train() receives the epoch index relative to this run (starting at 0).
        train_loss, train_loss_gen_rec, train_loss_gen_stt, train_loss_discrs, \
        train_acc_recon_poem, train_acc_recon_news, train_acc_discr_poem_stt, train_acc_discr_news_stt = \
                    train(model, discr_poem, discr_news,
                              train_poem_loader, train_news_loader,
                              optimizer_gen, optimizer_discr0, optimizer_discr1,
                              scheduler_gen, scheduler_discr0, scheduler_discr1,
                              gener_criterion, discr_criterion, clip, epoch - start_epoch)

        valid_loss, valid_loss_gen_rec, valid_loss_gen_stt, valid_loss_discrs, \
        valid_acc_recon_poem, valid_acc_recon_news, valid_acc_discr_poem_stt, valid_acc_discr_news_stt = \
                    evaluate(model, discr_poem, discr_news,
                             valid_poem_loader, valid_news_loader,
                             gener_criterion, discr_criterion)

        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        # Same tags and order as before: the four train scalars, then the four valid ones.
        logged_losses = (
            ('train', (train_loss, train_loss_gen_rec, train_loss_gen_stt, train_loss_discrs)),
            ('valid', (valid_loss, valid_loss_gen_rec, valid_loss_gen_stt, valid_loss_discrs)),
        )
        for split_name, split_losses in logged_losses:
            for tag_prefix, loss_value in zip(("Loss", "Loss_rec", "Loss_stt", "Loss_discr"), split_losses):
                writer.add_scalar(f"{tag_prefix}/{split_name}", loss_value, epoch)

        writer.flush()

        # Best-on-validation checkpointing is currently disabled:
        # if valid_loss < best_valid_loss:
        #     best_valid_loss = valid_loss
        #     torch.save({
        #         'epoch': epoch,
        #         'model_state_dict': model.state_dict(),
        #         'discr_poem_state_dict': discr_poem.state_dict(),
        #         'discr_news_state_dict': discr_news.state_dict(),
        #         'optimizer_gen_state_dict': optimizer_gen.state_dict(),
        #         'optimizer_discr0_state_dict': optimizer_discr0.state_dict(),
        #         'optimizer_discr1_state_dict': optimizer_discr1.state_dict(),
        #         'scheduler_gen_state_dict': scheduler_gen.state_dict(),
        #         'scheduler_discr0_state_dict': scheduler_discr0.state_dict(),
        #         'scheduler_discr1_state_dict': scheduler_discr1.state_dict(),
        #         'loss': best_valid_loss
        #         }, 'ST_len20_dAdam005_isnt7stt7rec0.1-after10e0.001_discr16e_checkp_best.pt')
        #     # Colab alternative path:
        #     # '/content/drive/MyDrive/Colab Notebooks/NLP_poems/ST_15k_voc20k_len20_nf100_fs2,6_Lrec0.5_discrCycle01_detach_checkp_best.pt'

        # Always save the most recent state so the run can be resumed.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'discr_poem_state_dict': discr_poem.state_dict(),
            'discr_news_state_dict': discr_news.state_dict(),
            'optimizer_gen_state_dict': optimizer_gen.state_dict(),
            'optimizer_discr0_state_dict': optimizer_discr0.state_dict(),
            'optimizer_discr1_state_dict': optimizer_discr1.state_dict(),
            'loss': valid_loss
        }, LAST_CHECKPOINT_PATH)
        # Colab alternative path:
        # '/content/drive/MyDrive/Colab Notebooks/NLP_poems/ST_15k_voc20k_len20_nf100_fs2,6_Lrec0.5_discrCycle01_detach_checkp_last.pt'

        print(f'Epoch: {epoch:02} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Acc recon poem: {train_acc_recon_poem:.2f} | Acc recon news: {train_acc_recon_news:.2f} | ' +
              f'Acc d_poem stt: {train_acc_discr_poem_stt:.2f} | Acc d_news stt: {train_acc_discr_news_stt:.2f}')
        print(f'\t Val. Loss: {valid_loss:.3f} | Acc recon poem: {valid_acc_recon_poem:.2f} | Acc recon news: {valid_acc_recon_news:.2f} | ' +
              f'Acc d_poem stt: {valid_acc_discr_poem_stt:.2f} | Acc d_news stt: {valid_acc_discr_news_stt:.2f}')
finally:
    writer.flush()
    writer.close()

  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

Epoch: 20 | Time: 6m 46s
	Train Loss: 9.852 | Acc recon poem: 0.88 | Acc recon news: 0.96 | Acc d_poem stt: 0.89 | Acc d_news stt: 0.95
	 Val. Loss: 9.267 | Acc recon poem: 1.00 | Acc recon news: 1.00 | Acc d_poem stt: 1.00 | Acc d_news stt: 1.00


  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

Epoch: 21 | Time: 6m 44s
	Train Loss: 8.427 | Acc recon poem: 0.88 | Acc recon news: 0.95 | Acc d_poem stt: 0.89 | Acc d_news stt: 0.95
	 Val. Loss: 10.799 | Acc recon poem: 1.00 | Acc recon news: 1.00 | Acc d_poem stt: 1.00 | Acc d_news stt: 1.00


  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

Epoch: 22 | Time: 6m 46s
	Train Loss: 8.492 | Acc recon poem: 0.88 | Acc recon news: 0.95 | Acc d_poem stt: 0.89 | Acc d_news stt: 0.95
	 Val. Loss: 9.078 | Acc recon poem: 1.00 | Acc recon news: 1.00 | Acc d_poem stt: 1.00 | Acc d_news stt: 1.00


  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

Epoch: 23 | Time: 6m 48s
	Train Loss: 8.768 | Acc recon poem: 0.89 | Acc recon news: 0.95 | Acc d_poem stt: 0.90 | Acc d_news stt: 0.95
	 Val. Loss: 8.709 | Acc recon poem: 1.00 | Acc recon news: 1.00 | Acc d_poem stt: 1.00 | Acc d_news stt: 1.00


  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

Epoch: 24 | Time: 6m 48s
	Train Loss: 14.905 | Acc recon poem: 0.91 | Acc recon news: 0.95 | Acc d_poem stt: 0.91 | Acc d_news stt: 0.95
	 Val. Loss: 23.631 | Acc recon poem: 1.00 | Acc recon news: 1.00 | Acc d_poem stt: 1.00 | Acc d_news stt: 1.00


  0%|          | 0/375 [00:00<?, ?it/s]

In [32]:
# Flush and close `writer` by hand, e.g. after the training cell was stopped
# before it finished.
# NOTE(review): `writer` is created outside this view - presumably the
# torch.utils.tensorboard SummaryWriter used for the add_scalar logging; confirm.
writer.flush()
writer.close()

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir mylogdir