## Pure PyTorch実装

In [1]:
import numpy as np
import polyphonic_data_loader as poly
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import time
import torch
import torch.nn as nn
from pyro.optim import ClippedAdam
from pyro.infer import SVI, Trace_ELBO
from util import get_logger

In [2]:
class GatedTransition(nn.Module):
    """Gated transition network parameterizing p(z_t | z_{t-1}) for the DMM.

    The mean is an element-wise convex combination, controlled by a learned
    gate in (0, 1), of a linear map of z_{t-1} (initialized to the identity)
    and a nonlinear MLP "proposal". The scale is a softplus of a linear map
    of the rectified proposal, so it is always positive. The formulas follow
    the ones given near the end of Section 5 ("Deep Markov Model") of the
    reference paper.

    Args:
        z_dim: dimensionality of the latent state.
        transition_dim: hidden width of the gate and proposal MLPs.
    """

    def __init__(self, z_dim, transition_dim):
        super().__init__()
        # The layers are created in this exact order so that, under a fixed
        # seed, random initialization matches layer-for-layer.
        # Gate MLP: z_dim -> transition_dim -> z_dim.
        self.lin_gate_z_to_hidden = nn.Linear(z_dim, transition_dim)
        self.lin_gate_hidden_to_z = nn.Linear(transition_dim, z_dim)
        # Proposal MLP: z_dim -> transition_dim -> z_dim.
        self.lin_proposed_mean_z_to_hidden = nn.Linear(z_dim, transition_dim)
        self.lin_proposed_mean_hidden_to_z = nn.Linear(transition_dim, z_dim)
        # Scale head, applied to the rectified proposal.
        self.lin_sig = nn.Linear(z_dim, z_dim)
        # Linear branch of the mean, initialized to the identity map
        # (weight = I, bias = 0).
        self.lin_z_to_loc = nn.Linear(z_dim, z_dim)
        nn.init.eye_(self.lin_z_to_loc.weight)
        nn.init.zeros_(self.lin_z_to_loc.bias)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

    def forward(self, z_t_1):
        """Return ``(loc, scale)`` of p(z_t | z_{t-1}) given ``z_t_1`` = z_{t-1}."""
        # Per-dimension mixing weight in (0, 1).
        gate_hidden = self.relu(self.lin_gate_z_to_hidden(z_t_1))
        gate = self.sigmoid(self.lin_gate_hidden_to_z(gate_hidden))

        # Nonlinear proposal for the mean.
        proposal_hidden = self.relu(self.lin_proposed_mean_z_to_hidden(z_t_1))
        proposal = self.lin_proposed_mean_hidden_to_z(proposal_hidden)

        # Blend the linear branch and the proposal according to the gate.
        linear_part = self.lin_z_to_loc(z_t_1)
        loc = (1 - gate) * linear_part + gate * proposal
        scale = self.softplus(self.lin_sig(self.relu(proposal)))
        return loc, scale

In [3]:
class Emitter(nn.Module):
    """Emission network producing Bernoulli parameters for p(x_t | z_t).

    Two ReLU hidden layers followed by a sigmoid output, so every output
    entry lies in (0, 1).

    Args:
        input_dim: dimensionality of an observation x_t.
        z_dim: dimensionality of the latent state z_t.
        emission_dim: hidden width of the network.
    """

    def __init__(self, input_dim, z_dim, emission_dim):
        super().__init__()
        # Creation order is kept fixed so seeded initialization is stable.
        self.lin_z_to_hidden = nn.Linear(z_dim, emission_dim)
        self.lin_hidden_to_hidden = nn.Linear(emission_dim, emission_dim)
        self.lin_hidden_to_input = nn.Linear(emission_dim, input_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z_t):
        """Return per-dimension Bernoulli probabilities for x_t given ``z_t``."""
        hidden = z_t
        for layer in (self.lin_z_to_hidden, self.lin_hidden_to_hidden):
            hidden = self.relu(layer(hidden))
        # Sigmoid squashes the logits into valid Bernoulli probabilities.
        return self.sigmoid(self.lin_hidden_to_input(hidden))

## Pyroによる実装

In [4]:
def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):
    """Generative model p(x_{1:T}, z_{1:T}) of the DMM (standalone sketch).

    Written as a method: ``self`` must provide ``z_0``, ``trans`` and
    ``emitter``, as the ``DMM`` class further below does.

    The first parameter was previously misspelled ``mini_bacth`` while the
    body read ``mini_batch``, so every call raised ``NameError``.

    Args:
        mini_batch: observed sequences, indexed as ``mini_batch[:, t - 1, :]``,
            i.e. (batch, T_max, input_dim). For the music data input_dim is
            88, per the original author's note.
        mini_batch_reversed: unused here; presumably accepted so that the
            model and guide share one signature under SVI.
        mini_batch_mask: per-time-step mask, sliced as ``[:, t - 1:t]``;
            assumed to be zero at padded time steps (TODO confirm).
        mini_batch_seq_lengths: unused here.
        annealing_factor: multiplier applied to the z-site log-probabilities
            (KL annealing).
    """
    T_max = mini_batch.size(1)
    pyro.module("dmm", self)
    # Learned initial latent z_0, broadcast across the batch.
    z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

    with pyro.plate("z_mini_batch", len(mini_batch)):
        for t in range(1, T_max + 1):
            # Parameters of p(z_t | z_{t-1}).
            z_loc, z_scale = self.trans(z_prev)

            # Scale only the latent sites, so annealing weights the KL term.
            with poutine.scale(scale=annealing_factor):
                z_t = pyro.sample("z_%d" % t,
                                  dist.Normal(z_loc, z_scale)
                                      .mask(mini_batch_mask[:, t - 1:t])
                                      .to_event(1))

            # Observation likelihood p(x_t | z_t).
            emission_probs_t = self.emitter(z_t)
            pyro.sample("obs_x_%d" % t,
                        dist.Bernoulli(emission_probs_t)
                            .mask(mini_batch_mask[:, t - 1:t])
                            .to_event(1),
                        obs=mini_batch[:, t - 1, :])
            z_prev = z_t

In [5]:
class Combiner(nn.Module):
    """Guide network producing ``(loc, scale)`` of q(z_t | z_{t-1}, h_t).

    It averages a tanh-transformed projection of the previous latent with
    the RNN hidden state, then reads off a mean and a softplus-positive
    scale from that combined state.

    Args:
        z_dim: dimensionality of the latent state.
        rnn_dim: hidden size of the guide RNN (must match ``h_rnn``'s last dim).
    """

    def __init__(self, z_dim, rnn_dim):
        super().__init__()
        # Creation order is kept fixed so seeded initialization is stable.
        self.lin_z_to_hidden = nn.Linear(z_dim, rnn_dim)
        self.lin_hidden_to_loc = nn.Linear(rnn_dim, z_dim)
        self.lin_hidden_to_scale = nn.Linear(rnn_dim, z_dim)
        self.tanh = nn.Tanh()
        self.softplus = nn.Softplus()

    def forward(self, z_t_1, h_rnn):
        """Combine ``z_t_1`` (z_{t-1}) and ``h_rnn`` into ``(loc, scale)``."""
        z_term = self.tanh(self.lin_z_to_hidden(z_t_1))
        # Equal-weight average of the latent-derived term and the RNN state.
        h_combined = 0.5 * (z_term + h_rnn)
        return (self.lin_hidden_to_loc(h_combined),
                self.softplus(self.lin_hidden_to_scale(h_combined)))

In [6]:
def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):
    """Variational guide q(z_{1:T} | x_{1:T}) (standalone sketch, no flows).

    Written as a method: ``self`` must provide ``h_0``, ``rnn``, ``z_q_0``
    and ``combiner``, as the ``DMM`` class further below does.

    An RNN runs over ``mini_batch_reversed``. Its output is passed through
    ``poly.pad_and_reverse`` (presumably restoring forward time order per
    sequence length) and combined with the previous latent at each step.
    """
    pyro.module("dmm", self)
    batch_size = mini_batch.size(0)
    T_max = mini_batch.size(1)

    # expand() gives a non-contiguous view; the RNN is handed a contiguous copy.
    h_init = self.h_0.expand(1, batch_size, self.rnn.hidden_size).contiguous()
    rnn_output, _ = self.rnn(mini_batch_reversed, h_init)
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)

    # Learned initial latent for the guide, broadcast across the batch.
    z_prev = self.z_q_0.expand(batch_size, self.z_q_0.size(0))

    with pyro.plate("z_minibatch", len(mini_batch)):
        for step in range(T_max):
            t = step + 1  # site names are 1-based
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, step, :])
            step_mask = mini_batch_mask[:, step:step + 1]
            # Same annealing scale as the model's z sites.
            with pyro.poutine.scale(scale=annealing_factor):
                z_prev = pyro.sample("z_%d" % t,
                                     dist.Normal(z_loc, z_scale)
                                         .mask(step_mask)
                                         .to_event(1))

In [7]:
class DMM(nn.Module):
    """Deep Markov Model: generative model and variational guide in one module.

    ``model`` defines p(x_{1:T}, z_{1:T}) with a ``GatedTransition`` and a
    Bernoulli ``Emitter``. ``guide`` defines q(z_{1:T} | x_{1:T}) by running an
    RNN over the reversed sequences and merging its output with the previous
    latent through a ``Combiner``. When ``num_iafs > 0`` it additionally
    applies that many affine autoregressive flows.

    Args:
        input_dim: dimensionality of each observation x_t.
        z_dim: dimensionality of each latent z_t.
        emission_dim: hidden width of the emitter.
        transition_dim: hidden width of the gated transition.
        rnn_dim: hidden size of the guide RNN.
        num_layers: number of stacked RNN layers.
        rnn_dropout_rate: dropout between RNN layers (forced to 0 for one layer).
        num_iafs: number of autoregressive flows appended to the guide.
        iaf_dim: hidden width of each flow.
        use_cuda: if True, move the module to the GPU after construction.
    """

    def __init__(self, input_dim=88, z_dim=100, emission_dim=100,
                 transition_dim=200, rnn_dim=600, num_layers=1, rnn_dropout_rate=0.0,
                 num_iafs=0, iaf_dim=50, use_cuda=False):
        super().__init__()
        self.emitter = Emitter(input_dim, z_dim, emission_dim)
        self.trans = GatedTransition(z_dim, transition_dim)
        self.combiner = Combiner(z_dim, rnn_dim)
        # nn.RNN applies dropout only between stacked layers, so it is
        # meaningless (and disabled here) for a single layer.
        rnn_dropout_rate = 0. if num_layers == 1 else rnn_dropout_rate
        self.rnn = nn.RNN(input_size=input_dim, hidden_size=rnn_dim, nonlinearity='relu',
                          batch_first=True, bidirectional=False, num_layers=num_layers,
                          dropout=rnn_dropout_rate)

        # Optional normalizing flows for a richer guide. They are reached
        # through the already-imported ``pyro.distributions`` (``dist``). The
        # bare name ``affine_autoregressive`` was never imported, so
        # num_iafs > 0 used to raise NameError.
        self.iafs = [dist.transforms.affine_autoregressive(z_dim, hidden_dims=[iaf_dim])
                     for _ in range(num_iafs)]
        # Register the flows as submodules so their parameters are tracked.
        self.iafs_modules = nn.ModuleList(self.iafs)
        # Learned initial latents for the model (z_0) and the guide (z_q_0).
        self.z_0 = nn.Parameter(torch.zeros(z_dim))
        self.z_q_0 = nn.Parameter(torch.zeros(z_dim))
        # Learned initial RNN hidden state, one per layer. nn.RNN expects h_0
        # of shape (num_layers, batch, hidden); a fixed leading dim of 1 broke
        # num_layers > 1. For num_layers == 1 the shape is unchanged.
        self.h_0 = nn.Parameter(torch.zeros(num_layers, 1, rnn_dim))
        self.use_cuda = use_cuda
        if use_cuda:
            self.cuda()

    def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):
        """Generative model p(x_{1:T}, z_{1:T}).

        Args:
            mini_batch: observations indexed as ``[:, t - 1, :]``,
                i.e. (batch, T_max, input_dim).
            mini_batch_reversed: unused here (consumed by ``guide``).
            mini_batch_mask: per-step mask sliced as ``[:, t - 1:t]``; assumed
                zero at padded steps (TODO confirm).
            mini_batch_seq_lengths: unused here (consumed by ``guide``).
            annealing_factor: scale applied to the latent sites (KL annealing).
        """
        T_max = mini_batch.size(1)
        pyro.module("dmm", self)
        z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

        with pyro.plate("z_minibatch", len(mini_batch)):
            for t in pyro.markov(range(1, T_max + 1)):
                z_loc, z_scale = self.trans(z_prev)

                # Only the latent sites are scaled, so annealing weights the KL term.
                with poutine.scale(scale=annealing_factor):
                    z_t = pyro.sample("z_%d" % t,
                                      dist.Normal(z_loc, z_scale)
                                          .mask(mini_batch_mask[:, t - 1:t])
                                          .to_event(1))
                emission_probs_t = self.emitter(z_t)
                pyro.sample("obs_x_%d" % t,
                            dist.Bernoulli(emission_probs_t)
                                .mask(mini_batch_mask[:, t - 1:t])
                                .to_event(1),
                            obs=mini_batch[:, t - 1, :])
                z_prev = z_t

    def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):
        """Variational distribution q(z_{1:T} | x_{1:T}).

        Arguments match ``model``. ``mini_batch_reversed`` is fed to the RNN
        and ``mini_batch_seq_lengths`` is passed to ``poly.pad_and_reverse``
        (presumably to restore forward time order per sequence).
        """
        T_max = mini_batch.size(1)
        pyro.module("dmm", self)

        # Broadcast the per-layer initial state over the batch. expand()
        # yields a non-contiguous view, so a contiguous copy is passed on.
        h_0_contig = self.h_0.expand(self.rnn.num_layers, mini_batch.size(0),
                                     self.rnn.hidden_size).contiguous()
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

        with pyro.plate("z_minibatch", len(mini_batch)):
            for t in pyro.markov(range(1, T_max + 1)):
                z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])
                if len(self.iafs) > 0:
                    # The flows act on the whole z vector, so the transformed
                    # distribution already has event_shape (z_dim,).
                    z_dist = dist.TransformedDistribution(dist.Normal(z_loc, z_scale),
                                                          self.iafs)
                    assert z_dist.event_shape == (self.z_q_0.size(0),)
                    assert z_dist.batch_shape[-1:] == (len(mini_batch),)
                else:
                    # Plain factorized Normal; .to_event(1) is applied at sampling.
                    z_dist = dist.Normal(z_loc, z_scale)
                    assert z_dist.event_shape == ()
                    assert z_dist.batch_shape[-2:] == (len(mini_batch), self.z_q_0.size(0))

                with pyro.poutine.scale(scale=annealing_factor):
                    if len(self.iafs) > 0:
                        # The event dim is already folded in, so mask per sequence.
                        z_t = pyro.sample("z_%d" % t,
                                          z_dist.mask(mini_batch_mask[:, t - 1]))
                    else:
                        z_t = pyro.sample("z_%d" % t,
                                          z_dist.mask(mini_batch_mask[:, t - 1:t])
                                          .to_event(1))
                z_prev = z_t

In [8]:
# Instantiate the Deep Markov Model with its default hyper-parameters.
dmm = DMM()

# Optimizer settings for Pyro's ClippedAdam. clip_norm and lrd are
# ClippedAdam-specific options on top of the usual Adam arguments.
adam_params = dict(lr=0.0003, betas=(0.96, 0.999), clip_norm=10.0,
                   lrd=0.99996, weight_decay=2.0)
optimizer = ClippedAdam(adam_params)

In [9]:
# Stochastic variational inference driver: DMM model/guide, optimized on the ELBO.
svi = SVI(dmm.model, dmm.guide, optimizer, loss=Trace_ELBO())

In [10]:
def process_minibatch(epoch, which_mini_batch, shuffled_indices, annealing_epochs=1000, minimum_annealing_factor=0.2):
    """Take one SVI gradient step on a single training mini-batch.

    Reads the module-level globals ``N_mini_batches``, ``mini_batch_size``,
    ``N_train_data``, ``training_data_sequences``, ``training_seq_lengths``,
    ``dmm`` and ``svi``.

    Args:
        epoch: current epoch index (0-based).
        which_mini_batch: index of the mini-batch within the epoch.
        shuffled_indices: permutation of training-sequence indices for this epoch.
        annealing_epochs: number of epochs over which the KL annealing factor
            ramps up to 1; 0 disables annealing.
        minimum_annealing_factor: annealing factor at the very first step.

    Returns:
        The loss reported by ``svi.step`` for this mini-batch.
    """
    if annealing_epochs > 0 and epoch < annealing_epochs:
        # KL annealing factor appropriate for the current mini-batch in the
        # current epoch: a linear ramp from min_af towards 1 that advances
        # with every mini-batch.
        min_af = minimum_annealing_factor
        annealing_factor = min_af + (1.0 - min_af) * \
            (float(which_mini_batch + epoch * N_mini_batches + 1) /
             float(annealing_epochs * N_mini_batches))
    else:
        # By default the KL annealing factor is unity.
        annealing_factor = 1.0

    # Which training sequences make up this mini-batch. The last one may be
    # shorter, so the end index is clamped to the dataset size.
    mini_batch_start = which_mini_batch * mini_batch_size
    mini_batch_end = min((which_mini_batch + 1) * mini_batch_size, N_train_data)
    mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]

    # Put the batch on the same device as the model. It was hard-coded to the
    # CPU, which broke DMM(use_cuda=True).
    mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
        = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                              training_seq_lengths, cuda=dmm.use_cuda)
    # One gradient step; the returned loss is accumulated by the caller.
    return svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths, annealing_factor)

In [11]:
def _prepare_eval_batch(sequences, seq_lengths):
    """Pack an entire evaluation split into one padded batch.

    Each sequence is repeated ``n_eval_samples`` times (copies adjacent), so
    one ``svi.evaluate_loss`` call covers that many guide samples per
    sequence. Assumes ``sequences`` and ``seq_lengths`` are torch tensors with
    the sequence axis first (TODO confirm against the data loader).

    Returns:
        ``(batch, batch_reversed, batch_mask, batch_seq_lengths)`` as returned
        by ``poly.get_mini_batch``.
    """
    sequences = sequences.repeat_interleave(n_eval_samples, dim=0)
    seq_lengths = seq_lengths.repeat_interleave(n_eval_samples, dim=0)
    indices = torch.arange(sequences.size(0))
    return poly.get_mini_batch(indices, sequences, seq_lengths, cuda=dmm.use_cuda)


def do_evaluation():
    """Return ``(val_nll, test_nll)``: -ELBO per time slice on the val/test splits.

    The batches are built here from ``val_data_sequences`` /
    ``val_seq_lengths`` and ``test_data_sequences`` / ``test_seq_lengths``.
    Previously this function read ``val_batch``, ``test_batch``, etc., which
    were never defined, so it raised ``NameError``. They are rebuilt on each
    call, which is acceptable since evaluation is infrequent.
    """
    val_batch, val_batch_reversed, val_batch_mask, val_batch_lengths = \
        _prepare_eval_batch(val_data_sequences, val_seq_lengths)
    test_batch, test_batch_reversed, test_batch_mask, test_batch_lengths = \
        _prepare_eval_batch(test_data_sequences, test_seq_lengths)

    # Put the RNN into evaluation mode (i.e. turn off dropout if applicable),
    # and restore training mode even if evaluation raises.
    dmm.rnn.eval()
    try:
        # Normalize by the total number of (repeated) time slices.
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_batch_lengths) / float(torch.sum(val_batch_lengths))
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_batch_lengths) / float(torch.sum(test_batch_lengths))
    finally:
        dmm.rnn.train()
    return val_nll, test_nll

In [15]:
log = get_logger('dmm.log')
mini_batch_size = 20

# Load JSB Chorales and unpack each split's sequences and their lengths.
data = poly.load_data(poly.JSB_CHORALES)
training_seq_lengths, training_data_sequences = (
    data['train']['sequence_lengths'], data['train']['sequences'])
test_seq_lengths, test_data_sequences = (
    data['test']['sequence_lengths'], data['test']['sequences'])
val_seq_lengths, val_data_sequences = (
    data['valid']['sequence_lengths'], data['valid']['sequences'])

N_train_data = len(training_seq_lengths)
# Total number of training time steps, used to normalize the epoch loss.
N_train_time_slices = float(torch.sum(training_seq_lengths))
# Ceiling division: a final, smaller mini-batch holds any remainder.
N_mini_batches = (N_train_data + mini_batch_size - 1) // mini_batch_size

log("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d" %
    (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

# Run validation/test evaluation every `val_test_frequency` epochs.
val_test_frequency = 50
# Number of samples used when evaluating.
n_eval_samples = 1


N_train_data: 229     avg. training seq. length: 60.29    N_mini_batches: 12
N_train_data: 229     avg. training seq. length: 60.29    N_mini_batches: 12
N_train_data: 229     avg. training seq. length: 60.29    N_mini_batches: 12


In [None]:
num_epochs = 50
# Training loop: each epoch reshuffles the data, takes one SVI step per
# mini-batch, and logs the loss averaged per training time slice.
times = [time.time()]
for epoch in range(num_epochs):
    # Fresh mini-batch assignment every epoch.
    shuffled_indices = torch.randperm(N_train_data)

    # Summed -ELBO (an estimate of the negative log likelihood) for the epoch.
    epoch_nll = 0.0
    for which_mini_batch in range(N_mini_batches):
        epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

    times.append(time.time())
    epoch_time = times[-1] - times[-2]
    log("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
        (epoch, epoch_nll / N_train_time_slices, epoch_time))

    # NOTE: epoch runs over 0..num_epochs-1 and epoch 0 is excluded, so with
    # num_epochs == val_test_frequency == 50 this evaluation never fires.
    evaluate_now = (val_test_frequency > 0 and epoch > 0
                    and epoch % val_test_frequency == 0)
    if evaluate_now:
        val_nll, test_nll = do_evaluation()
        log("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))

[training epoch 0000]  58.8030 				(dt = 12.886 sec)
[training epoch 0000]  58.8030 				(dt = 12.886 sec)
[training epoch 0000]  58.8030 				(dt = 12.886 sec)
[training epoch 0001]  46.2033 				(dt = 12.532 sec)
[training epoch 0001]  46.2033 				(dt = 12.532 sec)
[training epoch 0001]  46.2033 				(dt = 12.532 sec)
[training epoch 0002]  25.2181 				(dt = 12.284 sec)
[training epoch 0002]  25.2181 				(dt = 12.284 sec)
[training epoch 0002]  25.2181 				(dt = 12.284 sec)
[training epoch 0003]  17.2768 				(dt = 12.438 sec)
[training epoch 0003]  17.2768 				(dt = 12.438 sec)
[training epoch 0003]  17.2768 				(dt = 12.438 sec)
[training epoch 0004]  15.1741 				(dt = 12.584 sec)
[training epoch 0004]  15.1741 				(dt = 12.584 sec)
[training epoch 0004]  15.1741 				(dt = 12.584 sec)
