Commit
tsf classifier trainer
Former-commit-id: 7f984f3
zcyang committed Dec 23, 2017
1 parent 7f8cce4 commit 39ce995
Showing 3 changed files with 250 additions and 13 deletions.
54 changes: 54 additions & 0 deletions examples/tsf/stats.py
@@ -58,3 +58,57 @@ def d1(self):
def __str__(self):
return "loss %.2f, g %.2f, ppl %.2f d %.2f, adv %.2f %.2f" %(
self.loss, self.g, self.ppl, self.d, self.d0, self.d1)


class TSFClassifierStats(object):
def __init__(self):
self.reset()

def reset(self):
self._loss, self._g, self._ppl, self._df, self._dr, self._ds \
= [], [], [], [], [], []
self._w_loss, self._w_g, self._w_ppl, self._w_df, self._w_dr, self._w_ds \
= 0, 0, 0, 0, 0, 0

  def append(self, loss, g, ppl, df, dr, ds,
             w_loss=1., w_g=1., w_ppl=1., w_df=1., w_dr=1., w_ds=1.):
self._loss.append(loss*w_loss)
self._g.append(g*w_g)
self._ppl.append(ppl*w_ppl)
self._df.append(df*w_df)
self._dr.append(dr*w_dr)
self._ds.append(ds*w_ds)
self._w_loss += w_loss
self._w_g += w_g
self._w_ppl += w_ppl
self._w_df += w_df
self._w_dr += w_dr
self._w_ds+= w_ds

@property
def loss(self):
return sum(self._loss) / self._w_loss

@property
def g(self):
return sum(self._g) / self._w_g

@property
def ppl(self):
return sum(self._ppl) / self._w_ppl

@property
def df(self):
return sum(self._df) / self._w_df

@property
def dr(self):
return sum(self._dr) / self._w_dr

@property
def ds(self):
return sum(self._ds) / self._w_ds

def __str__(self):
return "l %.2f, g %.2f, p %.2f df %.2f, dr %.2f ds %.2f" %(
self.loss, self.g, self.ppl, self.df, self.dr, self.ds)
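A minimal usage sketch of the stats tracker above (the numbers are made up; only the weighted averaging is being illustrated). Each append() stores pre-weighted values and accumulates the weights, so the properties return weighted means across batches of unequal size:

stats = TSFClassifierStats()
# a batch of 128 examples / 2000 tokens (illustrative values)
stats.append(0.9, 0.5, 120., 0.7, 0.6, 1.1,
             w_loss=128, w_g=128, w_ppl=2000,
             w_df=128, w_dr=128, w_ds=128)
# a smaller final batch of 64 examples / 900 tokens
stats.append(1.1, 0.7, 140., 0.9, 0.8, 1.3,
             w_loss=64, w_g=64, w_ppl=900,
             w_df=64, w_dr=64, w_ds=64)
print(stats)  # each field is the weighted mean over both batches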
183 changes: 183 additions & 0 deletions examples/tsf/tsf_classifier_trainer.py
@@ -0,0 +1,183 @@
"""
Trainer for tsf.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cPickle as pkl
import numpy as np
import tensorflow as tf
import json
import os
import random

from texar.hyperparams import HParams
from texar.models.tsf import TSFClassifier

from trainer_base import TrainerBase
from utils import *
from tsf_utils import *
from stats import TSFClassifierStats as Stats

class TSFClassifierTrainer(TrainerBase):
"""TSFClassifier trainer."""
def __init__(self, hparams=None):
TrainerBase.__init__(self, hparams)

@staticmethod
def default_hparams():
return {
"data_dir": "../../data/yelp",
"expt_dir": "../../expt",
"log_dir": "log",
"name": "tsf",
"rho_f": 1.,
"rho_r": 0.,
"gamma_init": 1,
"gamma_decay": 0.5,
"gamma_min": 0.001,
"disp_interval": 100,
"batch_size": 128,
"vocab_size": 10000,
"max_len": 20,
"max_epoch": 20,
"sort_data": False,
"shuffle_across_epoch": True,
"d_update_freq": 1,
}


def eval_model(self, model, sess, vocab, data0, data1, output_path):
batches, order0, order1 = get_batches(
data0, data1, vocab["word2id"],
self._hparams.batch_size, sort=self._hparams.sort_data)
losses = Stats()

data0_ori, data1_ori, data0_tsf, data1_tsf = [], [], [], []
for batch in batches:
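      # decode_step returns logits for the reconstruction (ori) and the
      # style-transferred output (tsf) of the batch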
logits_ori, logits_tsf = model.decode_step(sess, batch)

      loss, loss_g, ppl_g, loss_df, loss_dr, loss_ds = model.eval_step(
        sess, batch, self._hparams.rho_f, self._hparams.rho_r,
        self._hparams.gamma_min)
      batch_size = len(batch["enc_inputs"])
      word_size = np.sum(batch["weights"])
      losses.append(loss, loss_g, ppl_g, loss_df, loss_dr, loss_ds,
                    w_loss=batch_size, w_g=batch_size,
                    w_ppl=word_size, w_df=batch_size,
                    w_dr=batch_size, w_ds=batch_size)
ori = logits2word(logits_ori, vocab["id2word"])
tsf = logits2word(logits_tsf, vocab["id2word"])
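      # each batch stacks the two styles: the first half comes from data0,
      # the second half from data1, so the outputs split at the midpoint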
half = self._hparams.batch_size // 2
data0_ori += ori[:half]
data1_ori += ori[half:]
data0_tsf += tsf[:half]
data1_tsf += tsf[half:]

n0 = len(data0)
n1 = len(data1)
data0_ori = reorder(order0, data0_ori)[:n0]
data1_ori = reorder(order1, data1_ori)[:n1]
data0_tsf = reorder(order0, data0_tsf)[:n0]
data1_tsf = reorder(order1, data1_tsf)[:n1]

write_sent(data0_ori, output_path + ".0.ori")
write_sent(data1_ori, output_path + ".1.ori")
write_sent(data0_tsf, output_path + ".0.tsf")
write_sent(data1_tsf, output_path + ".1.tsf")
return losses

def train(self):
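    # a saved config, if given, replaces the current hparams; otherwise the
    # hparams are dumped next to the experiment so the run can be reproduced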
if "config" in self._hparams.keys():
with open(self._hparams.config) as f:
self._hparams = HParams(pkl.load(f))

log_print("Start training with hparams:")
log_print(json.dumps(self._hparams.todict(), indent=2))
if not "config" in self._hparams.keys():
with open(os.path.join(self._hparams.expt_dir, self._hparams.name)
+ ".config", "w") as f:
pkl.dump(self._hparams, f)

vocab, train, val, test = self.load_data()

# set vocab size
self._hparams.vocab_size = vocab["size"]

# set some hparams here

with tf.Session() as sess:
model = TSFClassifier(self._hparams)
log_print("finished building model")

if "model" in self._hparams.keys():
model.saver.restore(sess, self._hparams.model)
else:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

losses = Stats()
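      # gamma starts at gamma_init and is annealed toward gamma_min at the
      # end of each epoch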
gamma = self._hparams.gamma_init
step = 0
best_dev = float("inf")
batches, _, _ = get_batches(train[0], train[1], vocab["word2id"],
model._hparams.batch_size,
sort=self._hparams.sort_data)

log_dir = os.path.join(self._hparams.expt_dir, self._hparams.log_dir)
train_writer = tf.summary.FileWriter(log_dir, sess.graph)

for epoch in range(1, self._hparams["max_epoch"] + 1):
# shuffle across batches
log_print("------------------epoch %d --------------"%(epoch))
log_print("gamma %.3f"%(gamma))
if self._hparams.shuffle_across_epoch:
batches, _, _ = get_batches(train[0], train[1], vocab["word2id"],
model._hparams.batch_size,
sort=self._hparams.sort_data)
random.shuffle(batches)
for batch in batches:
loss_ds = 0.
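          # take d_update_freq discriminator steps per batch, keeping the
          # last discriminator loss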
for _ in range(self._hparams.d_update_freq):
loss_ds = model.train_d_step(sess, batch)

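          # heuristic: train the generator adversarially only while the style
          # discriminator is strong (loss_ds below 1.2); otherwise fall back
          # to plain autoencoder updates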
if loss_ds < 1.2:
loss, loss_g, ppl_g, loss_df, loss_dr = model.train_g_step(
sess, batch, self._hparams.rho_f, self._hparams.rho_r, gamma)
else:
loss, loss_g, ppl_g, loss_df, loss_dr = model.train_ae_step(
sess, batch, self._hparams.rho_f, self._hparams.rho_r, gamma)

losses.append(loss, loss_g, ppl_g, loss_df, loss_dr, loss_ds)

step += 1
if step % self._hparams.disp_interval == 0:
log_print("step %d: "%(step) + str(losses))
losses.reset()

# eval on dev
dev_loss = self.eval_model(model, sess, vocab, val[0], val[1],
os.path.join(log_dir, "sentiment.dev.epoch%d"%(epoch)))
log_print("dev " + str(dev_loss))
if dev_loss.loss < best_dev:
best_dev = dev_loss.loss
file_name = (
self._hparams["name"] + "_" + "%.2f" %(best_dev) + ".model")
model.saver.save(
sess, os.path.join(self._hparams["expt_dir"], file_name),
latest_filename=self._hparams["name"] + "_checkpoint",
global_step=step)
log_print("saved model %s"%(file_name))

gamma = max(self._hparams.gamma_min, gamma * self._hparams.gamma_decay)

return best_dev

def main(unused_args):
trainer = TSFClassifierTrainer()
trainer.train()

if __name__ == "__main__":
tf.app.run()
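As a side note, the gamma schedule in the loop above decays multiplicatively once per epoch with a floor at gamma_min. A standalone sketch with the default hyperparameters (gamma_init=1, gamma_decay=0.5, gamma_min=0.001):

gamma, gamma_decay, gamma_min = 1., 0.5, 0.001
for epoch in range(1, 21):
  print("epoch %d: gamma %.3f" % (epoch, gamma))
  gamma = max(gamma_min, gamma * gamma_decay)
# gamma halves every epoch (1, 0.5, 0.25, ...) and clamps at 0.001 from
# epoch 11 on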
26 changes: 13 additions & 13 deletions texar/models/tsf/tsf_classifier.py
@@ -220,44 +220,44 @@ def _build_model(self, input_tensors, reuse=False):

    return output_tensors, loss, opt

-  def train_d_step(self, sess, batch, rho, gamma):
+  def train_d_step(self, sess, batch):
    loss_ds, _ = sess.run(
      [self.loss["loss_ds"], self.opt["optimizer_ds"],],
-      self.feed_dict(batch, rho, gamma))
+      self.feed_dict(batch, 0., 0., 1.))
    return loss_ds

-  def train_g_step(self, sess, batch, rho, gamma):
+  def train_g_step(self, sess, batch, rho_f, rho_r, gamma):
    loss, loss_g, ppl_g, loss_df, loss_dr = sess.run(
      [self.loss["loss"],
       self.loss["loss_g"],
       self.loss["ppl_g"],
       self.loss["loss_df"],
       self.loss["loss_dr"],
       self.opt["optimizer_all"]],
-      self.feed_dict(batch, rho, gamma))
+      self.feed_dict(batch, rho_f, rho_r, gamma))
    return loss, loss_g, ppl_g, loss_df, loss_dr

-  def train_ae_step(self, sess, batch, rho, gamma):
+  def train_ae_step(self, sess, batch, rho_f, rho_r, gamma):
    loss, loss_g, ppl_g, loss_df, loss_dr, _ = sess.run(
      [self.loss["loss"],
       self.loss["loss_g"],
       self.loss["ppl_g"],
       self.loss["loss_df"],
       self.loss["loss_dr"],
       self.opt["optimizer_ae"]],
-      self.feed_dict(batch, rho, gamma))
+      self.feed_dict(batch, rho_f, rho_r, gamma))
    return loss, loss_g, ppl_g, loss_df, loss_dr

-  def eval_step(self, sess, batch, rho, gamma):
-    loss, loss_g, ppl_g, loss_df, loss_dr = sess.run(
+  def eval_step(self, sess, batch, rho_f, rho_r, gamma):
+    loss, loss_g, ppl_g, loss_df, loss_dr, loss_ds = sess.run(
      [self.loss["loss"],
       self.loss["loss_g"],
       self.loss["ppl_g"],
-       self.loss["loss_d"],
-       self.loss["loss_d0"],
-       self.loss["loss_d1"]],
-      self.feed_dict(batch, rho, gamma, is_train=False))
-    return loss, loss_g, ppl_g, loss_d, loss_d0, loss_d1
+       self.loss["loss_df"],
+       self.loss["loss_dr"],
+       self.loss["loss_ds"]],
+      self.feed_dict(batch, rho_f, rho_r, gamma, is_train=False))
+    return loss, loss_g, ppl_g, loss_df, loss_dr, loss_ds

  def decode_step(self, sess, batch):
    logits_ori, logits_tsf = sess.run(
