<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2022notebooks/2022_0424timbmg_Sentence_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 変分オートエンコーダのデモ

- date: 2022_0424
- GitHub directory: `ShinAsakawa/ShinAsakawa.github.io/2022notebooks/`
- filename: `2022_0424timbmg_Sentence_VAE.ipynb`
- source: https://github.com/timbmg/Sentence-VAE の train.py と inference.py

<center>
<img src="https://github.com/timbmg/Sentence-VAE/raw/master/figs/model.png"><br/>
</center>


In [None]:
# This cell can safely be skipped: it only embeds a joke video about the ELBO
# made by Harvard graduate students. The pun is that the English word "elbow"
# and ELBO (Evidence Lower BOund, the variational lower bound) sound the same.
from IPython.display import YouTubeVideo, display

youtube_id = 'jugUBL4rEIM'
display(YouTubeVideo(id=youtube_id, width=600, height=480))

In [None]:
import IPython
isColab = 'google.colab' in str(IPython.get_ipython())
if isColab:
    # download.sh の内容
    !mkdir data
    !wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
    !tar -xf  simple-examples.tgz
    !mv simple-examples/data/ptb.train.txt data/
    !mv simple-examples/data/ptb.valid.txt data/
    !mv simple-examples/data/ptb.test.txt data/
    !rm -rf simple_examples

In [None]:
# ptb.py, utils.py and model.py have to be importable by the following cells.
# On Colab they are uploaded through the file picker; outside Colab the upload
# helper does not exist, so the step is skipped instead of failing on import.
try:
    from google.colab import files
except ImportError:
    print('google.colab is not available; place ptb.py, utils.py and model.py next to this notebook')
else:
    files.upload()

In [None]:
import os
import json
import time
import torch
#import argparse
import numpy as np
from multiprocessing import cpu_count
#from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
import tensorboard
from torch.utils.data import DataLoader
from collections import OrderedDict, defaultdict

from ptb import PTB
from model import SentenceVAE


In [None]:
#from utils import to_var, idx2word, expierment_name
def to_var(x):
    """Return `x` moved to the GPU when CUDA is available, otherwise `x` itself."""
    return x.cuda() if torch.cuda.is_available() else x


def idx2word(idx, i2w, pad_idx):
    """Decode a batch of index sequences into space-separated sentences.

    Each row of `idx` is read left to right and decoding of that row stops at
    the first element equal to `pad_idx`. Every element before that is looked
    up in `i2w` under the string form of its value (``str(word_id.item())``).

    Returns a list holding one stripped string per row of `idx`.
    """
    sentences = []
    for row in idx:
        words = []
        for word_id in row:
            if word_id == pad_idx:
                break
            words.append(i2w[str(word_id.item())])
        sentences.append(" ".join(words).strip())
    return sentences


def interpolate(start, end, steps):
    """Linearly interpolate between two vectors, one dimension at a time.

    For every pair of coordinates taken from `start` and `end`, `steps + 2`
    evenly spaced values are generated (both end points included).

    Returns an array of shape ``(steps + 2, start.shape[0])`` whose first row
    is `start` and whose last row is `end`.
    """
    n_points = steps + 2
    table = np.zeros((start.shape[0], n_points))

    # Each row of `table` is a view, so writing through it fills the table.
    for row, (lo, hi) in zip(table, zip(start, end)):
        row[:] = np.linspace(lo, hi, n_points)

    return np.transpose(table)

def expierment_name(args, ts):
    """Build the run identifier used as the TensorBoard log sub-directory.

    The hyper-parameters read from `args` (batch_size, learning_rate,
    embedding_size, rnn_type, hidden_size, num_layers, bidirectional,
    latent_size, word_dropout, anneal_function, k, x0) and the run tag `ts`
    are rendered as ``KEY=value`` fields joined by underscores.

    Returns:
        The experiment name as a string. (The function used to build the
        string and then fall off the end, handing ``None`` to the caller.)
    """
    fields = [
        "BS=%i" % args.batch_size,
        "LR={}".format(args.learning_rate),
        "EB=%i" % args.embedding_size,
        "%s" % args.rnn_type.upper(),
        "HS=%i" % args.hidden_size,
        "L=%i" % args.num_layers,
        "BI=%i" % args.bidirectional,  # a bool is rendered as 0/1
        "LS=%i" % args.latent_size,
        "WD={}".format(args.word_dropout),
        "ANN=%s" % args.anneal_function.upper(),
        "K={}".format(args.k),
        "X0=%i" % args.x0,
        "TS=%s" % ts,
    ]
    return "_".join(fields)


In [None]:
# Stand-in for the argparse handling of the original command-line scripts.
class _args():
    """Plain attribute container holding the settings argparse would supply."""

    def __init__(self):
        # dataset options
        self.data_dir = 'data'
        self.create_data = True
        self.max_sequence_length = 60
        self.min_occ = 1
        self.test = True  # also iterate over the 'test' split

        # optimisation
        self.epochs = 10
        self.batch_size = 32
        self.learning_rate = 0.001

        # model architecture
        self.embedding_size = 300
        self.rnn_type = 'gru'
        self.hidden_size = 256
        self.num_layers = 1
        self.bidirectional = True
        self.latent_size = 16
        self.word_dropout = 0
        self.embedding_dropout = 0.5

        # KL annealing schedule
        self.anneal_function = 'logistic'
        self.k = 0.0025
        self.x0 = 2500

        # logging and checkpoints
        self.print_every = 50
        self.tensorboard_logging = True
        self.logdir = 'logs'
        self.save_model_path = 'bin'
        # Checkpoints are written to <save_model_path>/<ts>/E<epoch>.pytorch and
        # this notebook uses ts = '2022_0424ccap' (underscore, not hyphen);
        # E9 is the last of the 10 epochs.
        self.load_checkpoint = './bin/2022_0424ccap/E9.pytorch'
        self.num_samples = 10

    def __repr__(self):
        """List every setting so that a logged str(args) is informative."""
        settings = ', '.join('%s=%r' % item for item in vars(self).items())
        return '%s(%s)' % (type(self).__name__, settings)


args = _args()
# Same sanity checks the command-line script performs on its arguments.
assert args.rnn_type in ['rnn', 'lstm', 'gru']
assert args.anneal_function in ['logistic', 'linear']
assert 0 <= args.word_dropout <= 1
#main(args)

In [None]:
# Fixed run id, used for the log, checkpoint and dump directories.
# For a unique id per run use: time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())
ts = '2022_0424ccap'
splits = ['train', 'valid'] + (['test'] if args.test else [])

# One PTB dataset per split, kept in split order.
datasets = OrderedDict()
for split in splits:
    datasets[split] = PTB(
        data_dir=args.data_dir,
        split=split,
        create_data=args.create_data,
        max_sequence_length=args.max_sequence_length,
        min_occ=args.min_occ
    )

# Constructor arguments of the model; the vocabulary-related values come from
# the training split. The same dict is saved next to the checkpoints below.
params = dict(
    vocab_size=datasets['train'].vocab_size,
    sos_idx=datasets['train'].sos_idx,
    eos_idx=datasets['train'].eos_idx,
    pad_idx=datasets['train'].pad_idx,
    unk_idx=datasets['train'].unk_idx,
    max_sequence_length=args.max_sequence_length,
    embedding_size=args.embedding_size,
    rnn_type=args.rnn_type,
    hidden_size=args.hidden_size,
    word_dropout=args.word_dropout,
    embedding_dropout=args.embedding_dropout,
    latent_size=args.latent_size,
    num_layers=args.num_layers,
    bidirectional=args.bidirectional
)
model = SentenceVAE(**params)

if torch.cuda.is_available():
    model = model.cuda()

print(model)

if args.tensorboard_logging:
    writer = SummaryWriter(os.path.join(args.logdir, expierment_name(args, ts)))
    writer.add_text("model", str(model))
    writer.add_text("args", str(args))
    writer.add_text("ts", ts)

save_model_path = os.path.join(args.save_model_path, ts)
# ts is a fixed string, so the directory already exists whenever this cell is
# re-run; exist_ok avoids the FileExistsError that would otherwise abort it.
os.makedirs(save_model_path, exist_ok=True)

with open(os.path.join(save_model_path, 'model_params.json'), 'w') as f:
    json.dump(params, f, indent=4)

def kl_anneal_function(anneal_function, step, k, x0):
    """Return the weight applied to the KL term at training step `step`.

    Args:
        anneal_function: 'logistic' or 'linear'.
        step: number of optimisation steps done so far.
        k: steepness of the logistic curve (ignored by 'linear').
        x0: midpoint of the logistic curve, or the step at which the linear
            ramp reaches 1.

    Returns:
        ``1 / (1 + exp(-k * (step - x0)))`` for 'logistic' and
        ``min(1, step / x0)`` for 'linear'.

    Raises:
        ValueError: if `anneal_function` is neither of the two names. The
            function previously returned ``None`` here, which only surfaced
            later when the weight was multiplied with the KL loss.
    """
    if anneal_function == 'logistic':
        return float(1/(1+np.exp(-k*(step-x0))))
    if anneal_function == 'linear':
        return min(1, step/x0)
    raise ValueError(
        "unknown anneal_function %r (expected 'logistic' or 'linear')" % (anneal_function,))

# Summed negative log-likelihood; positions whose target is the padding index are ignored.
NLL = torch.nn.NLLLoss(ignore_index=datasets['train'].pad_idx, reduction='sum')


def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0):
    """Compute the pieces of the annealed VAE objective for one batch.

    Returns a tuple ``(NLL_loss, KL_loss, KL_weight)``: the summed negative
    log-likelihood of `target` under `logp`, the summed KL term computed from
    `mean` and `logv`, and the annealing weight for the current `step`.
    Combining the three is left to the caller.
    """
    # Trim the target to the longest sequence in the batch, then flatten it.
    longest = torch.max(length).item()
    flat_target = target[:, :longest].contiguous().view(-1)
    # Flatten logp to (rows, size of its last dimension) to line up with the target.
    flat_logp = logp.view(-1, logp.size(2))

    reconstruction = NLL(flat_logp, flat_target)

    # -0.5 * sum(1 + logv - mean^2 - exp(logv))
    divergence = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
    weight = kl_anneal_function(anneal_function, step, k, x0)

    return reconstruction, divergence, weight

# Adam over all model parameters, with the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=args.learning_rate)

# Tensor constructor matching the device in use; calling it with no arguments
# gives an empty tensor, which the training loop uses as its accumulator seed.
if torch.cuda.is_available():
    tensor = torch.cuda.FloatTensor
else:
    tensor = torch.Tensor


In [None]:
# Number of optimizer updates done so far (across epochs); it is the input of
# the KL annealing schedule and only advances on the training split.
step = 0
for epoch in range(args.epochs):

    for split in splits:
        data_loader = DataLoader(
            dataset=datasets[split],
            batch_size=args.batch_size,
            shuffle=split=='train',
            num_workers=cpu_count(),
            pin_memory=torch.cuda.is_available()
        )
        is_training = split == 'train'
        split_tag = split.upper()

        # Per-split accumulators; a missing key starts out as an empty tensor.
        tracker = defaultdict(tensor)

        # Enable/Disable Dropout
        if is_training:
            model.train()
        else:
            model.eval()

        for iteration, batch in enumerate(data_loader):

            batch_size = batch['input'].size(0)

            # Move every tensor of the batch to the GPU when one is available.
            for k, v in batch.items():
                if torch.is_tensor(v):
                    batch[k] = to_var(v)

            # Gradients are only needed on the training split; disabling them
            # elsewhere avoids building an autograd graph that is never used.
            with torch.set_grad_enabled(is_training):
                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch['target'],
                    batch['length'], mean, logv, args.anneal_function, step, args.k, args.x0)

                loss = (NLL_loss + KL_weight * KL_loss) / batch_size

            # backward + optimization
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                step += 1

            # bookkeeping
            tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.data.view(1, -1)), dim=0)

            if args.tensorboard_logging:
                # x-axis position of this batch within the whole run of this split
                global_step = epoch*len(data_loader) + iteration
                writer.add_scalar("%s/ELBO" % split_tag, loss.item(), global_step)
                writer.add_scalar("%s/NLL Loss" % split_tag, NLL_loss.item() / batch_size, global_step)
                writer.add_scalar("%s/KL Loss" % split_tag, KL_loss.item() / batch_size, global_step)
                writer.add_scalar("%s/KL Weight" % split_tag, KL_weight, global_step)

            if iteration % args.print_every == 0 or iteration+1 == len(data_loader):
                print("%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split_tag, iteration, len(data_loader)-1, loss.item(), NLL_loss.item()/batch_size,
                        KL_loss.item()/batch_size, KL_weight))

            # Collect the decoded targets and latent codes of the validation split.
            if split == 'valid':
                if 'target_sents' not in tracker:
                    tracker['target_sents'] = list()
                tracker['target_sents'] += idx2word(batch['target'].data, i2w=datasets['train'].get_i2w(),
                                                    pad_idx=datasets['train'].pad_idx)
                tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

        print("%s Epoch %02d/%i, Mean ELBO %9.4f" % (split_tag, epoch, args.epochs, tracker['ELBO'].mean()))

        if args.tensorboard_logging:
            writer.add_scalar("%s-Epoch/ELBO" % split_tag, torch.mean(tracker['ELBO']), epoch)

        # save a dump of all sentences and the encoded latent space
        if split == 'valid':
            dump = {'target_sents': tracker['target_sents'], 'z': tracker['z'].tolist()}
            dump_dir = os.path.join('dumps', ts)
            # exist_ok replaces the separate exists() check
            os.makedirs(dump_dir, exist_ok=True)
            with open(os.path.join(dump_dir, 'valid_E%i.json' % epoch), 'w') as dump_file:
                json.dump(dump, dump_file)

        # save checkpoint
        if is_training:
            checkpoint_path = os.path.join(save_model_path, "E%i.pytorch" % epoch)
            torch.save(model.state_dict(), checkpoint_path)
            print("Model saved at %s" % checkpoint_path)

In [None]:
!ls -lat bin/2022

In [None]:
# Load the TensorBoard notebook extension and open the dashboard on ./logs,
# the directory the training cells write their summaries to (args.logdir).
%load_ext tensorboard
%tensorboard --logdir ./logs

In [None]:
from model import SentenceVAE
from utils import to_var, idx2word, interpolate

In [None]:
# Vocabulary file expected in the data directory (presumably written by the
# PTB dataset class when create_data is enabled - verify against ptb.py).
with open(os.path.join(args.data_dir, 'ptb.vocab.json'), 'r') as file:
    vocab = json.load(file)

w2i, i2w = vocab['w2i'], vocab['i2w']
_model = SentenceVAE(
    vocab_size=len(w2i),
    sos_idx=w2i['<sos>'],
    eos_idx=w2i['<eos>'],
    pad_idx=w2i['<pad>'],
    unk_idx=w2i['<unk>'],
    max_sequence_length=args.max_sequence_length,
    embedding_size=args.embedding_size,
    rnn_type=args.rnn_type,
    hidden_size=args.hidden_size,
    word_dropout=args.word_dropout,
    embedding_dropout=args.embedding_dropout,
    latent_size=args.latent_size,
    num_layers=args.num_layers,
    bidirectional=args.bidirectional
)
# Run inference on the network built above. Without this binding the cell
# relied on a `model` left over from the training cells and never used _model.
model = _model

# NOTE: args.load_checkpoint probably has to be edited by hand so that it
# points to the actual bin/<run id>/E<epoch>.pytorch file.
if not os.path.exists(args.load_checkpoint):
    raise FileNotFoundError(args.load_checkpoint)

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.load_state_dict(torch.load(args.load_checkpoint, map_location=device))
print("Model loaded from %s" % args.load_checkpoint)

if torch.cuda.is_available():
    model = model.cuda()

model.eval()

# Sentences generated by the model's inference method for n requested samples.
samples, z = model.inference(n=args.num_samples)
print('----------SAMPLES----------')
print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

# Decode a straight line between two random latent vectors:
# 8 intermediate points plus the two end points.
z1 = torch.randn([args.latent_size]).numpy()
z2 = torch.randn([args.latent_size]).numpy()
z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
samples, _ = model.inference(z=z)
print('-------INTERPOLATION-------')
print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')