In [122]:
import torch
import pickle

In [90]:
import random
import numpy as np

import torch
from torch import nn

class StyleTransfer(nn.Module):
    """Seq2seq wrapper that can split the encoder hidden state into content/style latents.

    The wrapped ``encoder`` is called as ``encoder(src)`` and must return
    ``(encoder_out, hidden, cell)``.  The wrapped ``tst_decoder`` is called one
    target step at a time as ``tst_decoder(input, hidden, cell)``, must return
    ``(output, hidden, cell)`` and must expose an ``output_size`` attribute.

    When ``variational`` is truthy, the last axis of ``hidden`` is split into a
    content slice (first ``content_index`` features) and a style slice (last
    ``style_index`` features); each slice is reparameterized separately and the
    two samples are concatenated along axis 0 to form the decoder's initial
    hidden state.  The encoder ``cell`` state is passed to the decoder unchanged.
    """

    def __init__(self, encoder, tst_decoder, d_hidden, style_ratio, variational, device):
        """Store the sub-modules and build the projection layers.

        Args:
            encoder: module returning ``(encoder_out, hidden, cell)``.
            tst_decoder: step-wise decoder exposing ``output_size``.
            d_hidden: width of the last axis of the encoder hidden state.
            style_ratio: fraction of ``d_hidden`` assigned to the style slice.
            variational: whether ``forward`` reparameterizes the hidden state.
            device: device that inputs and sampled latents are moved to.
        """
        super(StyleTransfer, self).__init__()

        self.device = device

        self.encoder = encoder
        self.tst_decoder = tst_decoder

        self.d_hidden = d_hidden
        self.style_ratio = style_ratio
        # Number of leading features treated as "content"; the rest is "style".
        self.content_index = int(self.d_hidden * (1 - self.style_ratio))
        self.style_index = int(self.d_hidden - self.content_index)

        self.variational = variational

        # TODO Size ?
        # Applied after swapping the first and last axes in `reparameterization`,
        # so it maps the first axis of `hidden` from 4 entries to 2.
        # NOTE(review): 4 presumably equals n_layers * 2 directions of the
        # encoder LSTM -- confirm before changing the encoder depth.
        self.half_hidden = nn.Linear(4, 2)
        self.content2mean = nn.Linear(self.content_index, d_hidden)
        self.content2logv = nn.Linear(self.content_index, d_hidden)

        self.style2mean = nn.Linear(self.style_index, d_hidden)
        self.style2logv = nn.Linear(self.style_index, d_hidden)

    def reparameterization(self, hidden, latent_type):
        """Sample ``z = mean + eps * exp(0.5 * logv)`` for one slice of the hidden state.

        Args:
            hidden: slice of the encoder hidden state; its first axis must have
                size 4 (input width of ``half_hidden``) and its last axis must
                match the input width of the selected projection layers.
            latent_type: ``"content"`` or ``"style"``; selects the projections.

        Returns:
            Tuple ``(z, mean, logv)``.

        Raises:
            ValueError: if ``latent_type`` is neither ``"content"`` nor ``"style"``.
        """
        # Validate first: previously an unknown type fell through both branches
        # and failed later with an unbound-local error on `logv`.
        if latent_type == "content":
            to_mean, to_logv = self.content2mean, self.content2logv
        elif latent_type == "style":
            to_mean, to_logv = self.style2mean, self.style2logv
        else:
            raise ValueError(
                "latent_type must be 'content' or 'style', got {!r}".format(latent_type)
            )

        # Move the first axis last so the Linear layer acts on it, then restore.
        hidden = hidden.transpose(0, -1)
        hidden = self.half_hidden(hidden)
        hidden = hidden.transpose(0, -1)

        mean = to_mean(hidden).to(self.device)
        logv = to_logv(hidden).to(self.device)

        std = torch.exp(0.5 * logv)
        eps = torch.randn_like(std)
        z = mean + (eps * std)
        return z, mean, logv

    def forward(self, tst_src, tst_trg, teacher_forcing_ratio=0.5):
        """Encode ``tst_src`` and decode step by step against ``tst_trg``.

        Args:
            tst_src: source batch passed to the encoder.
            tst_trg: target batch indexed as ``[batch, position]``; position 0
                is fed to the decoder as the first input.
            teacher_forcing_ratio: probability, drawn once per step for the
                whole batch, of feeding the gold token instead of the argmax.

        Returns:
            Tuple ``(outputs, latent_variables, output_list)`` where ``outputs``
            is ``(batch, trg_len, output_size)`` with position 0 left as zeros,
            ``latent_variables`` is a 7-item list (all ``None`` when not
            variational) and ``output_list`` holds the per-step argmax ids.
        """
        tst_src = tst_src.to(self.device)
        tst_trg = tst_trg.to(self.device)

        encoder_out, hidden, cell = self.encoder(tst_src)

        if self.variational:
            context_c, context_a = hidden[:, :, :self.content_index], hidden[:, :, -self.style_index:]

            # TODO Reparameterize each part separately? Or reparameterize first and then split?
            # TODO After splitting, zero-pad below content / above style to match sizes, then reparameterize?
            content_c, content_mu, content_logv = self.reparameterization(context_c, "content")
            style_a, style_mu, style_logv = self.reparameterization(context_a, "style")

            total_latent = torch.cat((content_c, style_a), 0)

            # TODO cat? add? -> proceed with total_latent for now
            hidden = total_latent

            latent_variables = [total_latent, content_c, content_mu, content_logv, style_a, style_mu, style_logv,]

        else:
            latent_variables = [None for _ in range(7)]

        trg_len = tst_trg.shape[1]  # length of word
        batch_size = tst_trg.shape[0]  # batch size
        trg_vocab_size = self.tst_decoder.output_size
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        input = tst_trg[:, 0]  # BOS first

        output_list = []
        for i in range(1, trg_len):
            output, hidden, cell = self.tst_decoder(input, hidden, cell)
            outputs[:, i] = output
            # Compute the greedy prediction once and reuse it for both the
            # returned id list and the next decoder input.
            top1 = output.argmax(1)
            output_list.append(top1.tolist())

            teacher_force = random.random() < teacher_forcing_ratio
            input = tst_trg[:, i] if teacher_force else top1

        return outputs, latent_variables, output_list

class StylizedNMT(nn.Module):
    """Translation model whose decoder can be conditioned on a fixed latent tensor.

    ``nmt_encoder`` is called as ``nmt_encoder(src)`` and must return
    ``(encoder_out, hidden, cell)``.  ``nmt_decoder`` is called one target step
    at a time as ``nmt_decoder(input, hidden, cell)``, must return
    ``(output, hidden, cell)`` and must expose an ``output_size`` attribute.
    """

    def __init__(self, nmt_encoder, nmt_decoder, d_hidden, total_latent, device):
        """Store the sub-modules and build the two half-width projections.

        Args:
            nmt_encoder: module returning ``(encoder_out, hidden, cell)``.
            nmt_decoder: step-wise decoder exposing ``output_size``.
            d_hidden: width of the last axis of both the encoder hidden state
                and ``total_latent``.
            total_latent: latent tensor to mix into the hidden state, or
                ``None`` to decode from the encoder state alone.
            device: device that the input batches are moved to.
        """
        super(StylizedNMT, self).__init__()

        self.device = device

        self.nmt_encoder = nmt_encoder
        self.nmt_decoder = nmt_decoder
        self.total_latent = total_latent

        # Each projection halves the feature width so that concatenating the
        # two results along the last axis restores a width of d_hidden
        # (exactly so when d_hidden is even).
        self.hidden2concat = nn.Linear(d_hidden, d_hidden // 2)
        self.latent2concat = nn.Linear(d_hidden, d_hidden // 2)

    def forward(self, nmt_src, nmt_trg, teacher_forcing_ratio=0.5):
        """Encode ``nmt_src`` and decode step by step against ``nmt_trg``.

        Args:
            nmt_src: source batch passed to the encoder.
            nmt_trg: target batch indexed as ``[batch, position]``; position 0
                is fed to the decoder as the first input.
            teacher_forcing_ratio: probability, drawn once per step for the
                whole batch, of feeding the gold token instead of the argmax.

        Returns:
            Tuple ``(outputs, output_list)`` where ``outputs`` is
            ``(batch, trg_len, output_size)`` with position 0 left as zeros and
            ``output_list`` holds the per-step argmax ids.
        """
        nmt_src = nmt_src.to(self.device)
        nmt_trg = nmt_trg.to(self.device)

        # The encoder is stored as `nmt_encoder`; `self.encoder` was never
        # assigned and raised AttributeError on every call.
        encoder_out, hidden, cell = self.nmt_encoder(nmt_src)

        # TODO decide whether to add or concat
        # Compare against None explicitly: truth-testing a tensor with more
        # than one element raises a RuntimeError.
        if self.total_latent is not None:
            hidden = self.hidden2concat(hidden)
            latent = self.latent2concat(self.total_latent)
            hidden = torch.cat((hidden, latent), 2)

        trg_len = nmt_trg.shape[1]  # length of word
        batch_size = nmt_trg.shape[0]  # batch size
        trg_vocab_size = self.nmt_decoder.output_size

        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        input = nmt_trg[:, 0]

        output_list = []
        for i in range(1, trg_len):
            output, hidden, cell = self.nmt_decoder(input, hidden, cell)
            outputs[:, i] = output
            # Greedy prediction, reused for the id list and the next input.
            top1 = output.argmax(1)
            output_list.append(top1.tolist())

            teacher_force = random.random() < teacher_forcing_ratio
            input = nmt_trg[:, i] if teacher_force else top1

        return outputs, output_list


class Encoder(nn.Module):
    """Token embedding followed by dropout and a bidirectional, batch-first LSTM."""

    def __init__(self, input_size, d_hidden, d_embed, n_layers, dropout, device):
        """Build the embedding table, the LSTM and the dropout layer.

        Args:
            input_size: number of rows in the source embedding table.
            d_hidden: hidden size of the LSTM (per direction).
            d_embed: embedding width, also the LSTM input width.
            n_layers: number of stacked LSTM layers.
            dropout: dropout probability for the embedding output and between
                LSTM layers.
            device: stored on the instance; not read inside this class.
        """
        super().__init__()
        self.device = device

        self.src_embedding = nn.Embedding(input_size, d_embed)

        # TODO num_layers=2 -> total_latent is [8, batch_size, d_hidden]; how should this be resolved?
        self.encoder = nn.LSTM(
            input_size=d_embed,
            hidden_size=d_hidden,
            num_layers=n_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Embed ``src`` and run it through the LSTM.

        Returns:
            Tuple ``(outputs, hidden, cell)`` exactly as produced by the LSTM.
        """
        token_vectors = self.src_embedding(src)
        lstm_out, final_state = self.encoder(self.dropout(token_vectors))
        last_hidden, last_cell = final_state
        return lstm_out, last_hidden, last_cell


class TSTDecoder(nn.Module):
    """Single-step decoder: embedding, dropout, bidirectional LSTM, linear output layer."""

    def __init__(self, output_size, d_hidden, d_embed, n_layers, dropout, device):
        """Build the decoder layers.

        Args:
            output_size: target vocabulary size (embedding rows and logit width).
            d_hidden: hidden size of the LSTM (per direction).
            d_embed: embedding width, also the LSTM input width.
            n_layers: number of stacked LSTM layers.
            dropout: dropout probability for the embedding output and between
                LSTM layers.
            device: stored on the instance; not read inside this class.
        """
        super().__init__()
        self.device = device
        self.output_size = output_size

        self.trg_embedding = nn.Embedding(output_size, d_embed)
        self.tst_decoder = nn.LSTM(
            input_size=d_embed,
            hidden_size=d_hidden,
            num_layers=n_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )

        self.dropout = nn.Dropout(dropout)
        # The LSTM is bidirectional, so its output width is 2 * d_hidden.
        self.fc = nn.Linear(2 * d_hidden, output_size)

    def forward(self, input, hidden, cell):
        """Decode one step.

        Args:
            input: one token id per batch element.
            hidden: LSTM hidden state to start from.
            cell: LSTM cell state to start from.

        Returns:
            Tuple ``(tst_out, hidden, cell)``: logits over the vocabulary and
            the updated LSTM state.
        """
        # Insert a length-1 sequence axis for the batch-first LSTM.
        step_tokens = input.unsqueeze(dim=1)
        step_embedded = self.dropout(self.trg_embedding(step_tokens))

        step_out, (next_hidden, next_cell) = self.tst_decoder(step_embedded, (hidden, cell))

        # Drop the sequence axis again before projecting to the vocabulary.
        logits = self.fc(step_out.squeeze(dim=1))
        return logits, next_hidden, next_cell


class NMTDecoder(nn.Module):
    """Single-step decoder: embedding, dropout, bidirectional LSTM, linear output layer."""

    def __init__(self, output_size, d_hidden, d_embed, n_layers, dropout, device):
        """Build the decoder layers.

        Args:
            output_size: target vocabulary size (embedding rows and logit width).
            d_hidden: hidden size of the LSTM (per direction).
            d_embed: embedding width, also the LSTM input width.
            n_layers: number of stacked LSTM layers.
            dropout: dropout probability for the embedding output and between
                LSTM layers.
            device: stored on the instance; not read inside this class.
        """
        super().__init__()
        self.device = device
        self.output_size = output_size

        self.trg_embedding = nn.Embedding(output_size, d_embed)
        self.nmt_decoder = nn.LSTM(
            input_size=d_embed,
            hidden_size=d_hidden,
            num_layers=n_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )

        self.dropout = nn.Dropout(dropout)
        # The LSTM is bidirectional, so its output width is 2 * d_hidden.
        self.fc = nn.Linear(2 * d_hidden, output_size)

    def forward(self, input, hidden, cell):
        """Decode one step.

        Args:
            input: one token id per batch element.
            hidden: LSTM hidden state to start from.
            cell: LSTM cell state to start from.

        Returns:
            Tuple ``(nmt_out, hidden, cell)``: logits over the vocabulary and
            the updated LSTM state.
        """
        # Insert a length-1 sequence axis for the batch-first LSTM.
        step_tokens = input.unsqueeze(dim=1)
        step_embedded = self.dropout(self.trg_embedding(step_tokens))
        step_out, (next_hidden, next_cell) = self.nmt_decoder(step_embedded, (hidden, cell))

        # Drop the sequence axis again before projecting to the vocabulary.
        logits = self.fc(step_out.squeeze(dim=1))
        return logits, next_hidden, next_cell


In [71]:
# Build the NMT model and read the saved checkpoint on CPU.
# NOTE(review): Encoder.__init__ takes six required arguments (input_size,
# d_hidden, d_embed, n_layers, dropout, device), so the bare `Encoder()` call
# below raises TypeError as written.
nmt_encoder = Encoder()
# NOTE(review): nmt_decoder, d_hidden, total_latent and device are not defined
# in any cell visible here -- presumably leftover kernel state; confirm they
# exist on a fresh Restart & Run All.
nmt_model = StylizedNMT(nmt_encoder, nmt_decoder, d_hidden, total_latent, device)
# NOTE(review): the value returned by torch.load is discarded, so nothing from
# the checkpoint is applied to nmt_model. The recorded output is a dict with a
# 'model' key, so a load_state_dict call on that entry looks intended -- TODO confirm.
torch.load("../../kcc_data/nmt_model.pth", map_location=torch.device('cpu'))
nmt_model

{'model': OrderedDict([('nmt_decoder.trg_embedding.weight',
               tensor([[-0.1841, -2.4435,  0.0471,  ..., -0.2534, -2.0047, -2.5667],
                       [-0.4562, -0.3122,  1.0801,  ..., -0.6673,  0.2852, -0.0731],
                       [-0.8661,  0.1094, -0.9197,  ..., -0.0761, -0.5270, -0.8286],
                       ...,
                       [ 1.0204, -0.6731, -0.3677,  ...,  1.0387,  1.6615, -0.9600],
                       [ 1.5721,  0.0214,  0.0233,  ..., -2.1208,  0.2514, -1.2886],
                       [-0.4011, -0.2870,  0.5079,  ..., -0.0361, -0.6963,  0.2815]])),
              ('nmt_decoder.nmt_decoder.weight_ih_l0',
               tensor([[ 0.0287, -0.0170,  0.0073,  ..., -0.0110,  0.0170,  0.0187],
                       [-0.0019, -0.0304, -0.0227,  ...,  0.0002, -0.0269,  0.0180],
                       [-0.0213,  0.0008,  0.0090,  ..., -0.0245, -0.0202,  0.0015],
                       ...,
                       [-0.0207, -0.0305, -0.0227,  ...,  0.0

In [72]:
# Read the saved style-transfer checkpoint on CPU and display it.
# NOTE(review): `t` is not defined in any cell visible here; the recorded
# output is the raw checkpoint dict, so the wrapper call may be a typo -- TODO confirm.
tst_model = t(torch.load("../../kcc_data/tst_model.pth", map_location=torch.device('cpu')))
tst_model

{'model': OrderedDict([('encoder.src_embedding.weight',
               tensor([[ 0.2827,  0.0358,  0.5101,  ...,  0.7153, -0.3080, -1.2164],
                       [-1.5855, -0.7134, -1.4938,  ..., -1.3202, -0.1101,  1.6480],
                       [-0.0306,  0.3408,  0.4493,  ..., -0.8764,  0.0742,  0.0642],
                       ...,
                       [-1.4204,  0.5075, -0.1172,  ..., -0.3030,  1.2276,  0.1671],
                       [-0.7858, -0.6558,  0.7487,  ..., -2.0158, -0.4633,  0.5521],
                       [-1.1546, -0.8147, -0.0947,  ..., -0.2383, -1.4264,  0.5645]])),
              ('encoder.encoder.weight_ih_l0',
               tensor([[-0.0419, -0.0469, -0.0271,  ...,  0.0011,  0.0503,  0.0204],
                       [ 0.0112, -0.0291,  0.0083,  ...,  0.0113, -0.0264,  0.0207],
                       [ 0.0178,  0.0250, -0.0021,  ...,  0.0339, -0.0029,  0.0152],
                       ...,
                       [ 0.0917,  0.0614,  0.0323,  ...,  0.0208,  0.0350

In [143]:
# Data Setting
# SECURITY: pickle.load can execute arbitrary code while unpickling; only load
# this file if it was produced by a trusted preprocessing run.
with open("../data/processed/tokenized/spm_tokenized_data.pkl", "rb") as f:
    # The `with` block closes the file on exit, so no explicit close is needed.
    data = pickle.load(f)

em_informal_train = data["gyafc"]["train"]["em_informal"]
# NOTE(review): this reads the same "train" entry as em_informal_train, so the
# "test" list is identical to the training list -- confirm whether a separate
# test split was intended.
em_informal_test = data["gyafc"]["train"]["em_informal"]
em_formal_train = data["gyafc"]["train"]["em_formal"]
pair_kor_train = data['korpora']['train']['pair_kor']
pair_eng_train = data['korpora']['train']['pair_eng']
# fr_informal_train = data["gyafc"]["train"]["fr_informal"]
# fr_formal_train = data["gyafc"]["train"]["fr_formal"]

In [145]:
# Quick size check: negated 10% cut index, full list length, current train/test lengths.
(
    -int(0.1 * len(data["gyafc"]["train"]["em_informal"])),
    len(data["gyafc"]["train"]["em_informal"]),
    len(em_informal_train),
    len(em_informal_test),
)

(-5259, 52595, 52070, 5207)

In [131]:
# Number of items in 10% of the informal training list (truncated to an int).
int(0.1 * len(data["gyafc"]["train"]["em_informal"]))

5259

In [123]:
split_ratio = 0.8


def _split_train_valid(samples, ratio):
    """Split ``samples`` at ``int(len(samples) * ratio)`` into (train, valid) slices."""
    cut = int(len(samples) * ratio)
    return samples[:cut], samples[cut:]


# Both halves are cut from the ORIGINAL sequence at the same index. Previously
# each *_valid was sliced from the already-truncated *_train, which made the
# validation data a subset of the training data and dropped the last 20%.
# NOTE: the *_train names are rebound here, so re-running this cell shrinks the
# splits again; re-run the data-loading cell first.
em_informal_train, em_informal_valid = _split_train_valid(em_informal_train, split_ratio)
em_formal_train, em_formal_valid = _split_train_valid(em_formal_train, split_ratio)
pair_kor_train, pair_kor_valid = _split_train_valid(pair_kor_train, split_ratio)
pair_eng_train, pair_eng_valid = _split_train_valid(pair_eng_train, split_ratio)