Commit

sample theta multiple times when estimating
ahoho committed May 18, 2022
1 parent 7f29f1a commit 4514535
Showing 3 changed files with 42 additions and 16 deletions.
10 changes: 7 additions & 3 deletions infer_with_scholar.py
@@ -9,7 +9,7 @@

 _CUDA_AVAILABLE = torch.cuda.is_available()

-def retrieve_estimates(model_dir, eval_data=None, **kwargs):
+def retrieve_estimates(model_dir, eval_data=None, n_samples=20, **kwargs):
     """
     Loads the topic-word and training set document-topic estimates
@@ -27,7 +27,7 @@ def retrieve_estimates(model_dir, eval_data=None, **kwargs):
     scholar, _ = load_scholar_model(
         model_dir / "torch_model_final.pt", map_location=device
     )
-    scholar.eval()
+    scholar._model.eval()

     doc_topic = save_document_representations(
         scholar,
@@ -39,17 +39,21 @@ def retrieve_estimates(model_dir, eval_data=None, **kwargs):
         ids=None,
         partition=None,
         output_dir=None,
+        n_samples=n_samples,
     )
     return doc_topic

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_dir")
     parser.add_argument("--inference_data_file")
+    parser.add_argument("--n_samples", default=1, type=int)
     parser.add_argument("--output_fpath")
     args = parser.parse_args()

+    assert Path(args.model_dir, "torch_model_final.pt").exists(), f"Model does not exist at {args.model_dir}/torch_model_final.pt"
+
     eval_data = load_sparse(args.inference_data_file).astype(np.float32)
-    doc_topics = retrieve_estimates(args.model_dir, eval_data)
+    doc_topics = retrieve_estimates(args.model_dir, eval_data, n_samples=args.n_samples)
     Path(args.output_fpath).parent.mkdir(exist_ok=True, parents=True)
     np.save(args.output_fpath, doc_topics)
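
For reference, a hypothetical invocation of the updated inference script. The flags are taken from the argument parser above; the file paths are placeholders, and --n_samples > 1 averages theta over several posterior samples rather than using a single posterior-mean pass.

# A sketch only: the paths below are hypothetical placeholders.
import subprocess

subprocess.run(
    [
        "python", "infer_with_scholar.py",
        "--model_dir", "outputs/scholar_model",       # hypothetical model directory
        "--inference_data_file", "data/eval.npz",     # hypothetical sparse count matrix
        "--n_samples", "20",
        "--output_fpath", "outputs/eval_doc_topics.npy",
    ],
    check=True,
)
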
44 changes: 33 additions & 11 deletions run_scholar.py
@@ -1,6 +1,7 @@
 import os
 import sys
 import shutil
+from pathlib import Path
 import configargparse

 import gensim
@@ -225,6 +226,12 @@ def main(call=None):
         default=False,
         help="Save model at the end of training",
     )
+    parser.add_argument(
+        "--theta_samples",
+        type=int,
+        default=1,
+        help="Number of samples used to compute document-topic representations (1 means posterior mean is used)",
+    )

     parser.add_argument(
         "--emb_dim",
@@ -414,6 +421,9 @@ def main(call=None):
     )

     print("Loading document representations")
+    if options.doc_reps_dir is not None and not Path(options.doc_reps_dir).exists():
+        # if it doesn't exist, treat it as a path relative to the input directory
+        options.doc_reps_dir = str(Path(input_dir, options.doc_reps_dir))
     train_doc_reps = load_doc_reps(
         options.doc_reps_dir,
         prefix=options.train_prefix,
@@ -771,6 +781,7 @@ def main(call=None):
         options.output_dir,
         "train",
         batch_size=options.batch_size,
+        n_samples=options.theta_samples,
     )

     if dev_X is not None:
@@ -785,6 +796,7 @@ def main(call=None):
             options.output_dir,
             "dev",
             batch_size=options.batch_size,
+            n_samples=options.theta_samples,
         )

     if n_test > 0:
@@ -799,10 +811,12 @@ def main(call=None):
             options.output_dir,
             "test",
             batch_size=options.batch_size,
+            n_samples=options.theta_samples,
         )

     if options.temp_output_dir:
-        shutil.copytree(options.output_dir, final_output_dir, dirs_exist_ok=True)
+        for fpath in Path(options.output_dir).glob("*"):
+            shutil.copy(fpath, Path(final_output_dir, fpath.name))
         shutil.rmtree(options.output_dir)


@@ -1646,24 +1660,32 @@ def print_topic_label_associations(


 def save_document_representations(
-    model, X, Y, PC, TC, DR, ids, output_dir, partition, batch_size=200
+    model, X, Y, PC, TC, DR, ids, output_dir, partition, batch_size=200, n_samples=20
 ):
     # compute the mean of the posterior of the latent representation for each document and save it
     if Y is not None:
         Y = np.zeros_like(Y)

     n_items, _ = X.shape
     n_batches = int(np.ceil(n_items / batch_size))
-    thetas = []
-
-    for i in range(n_batches):
-        batch_xs, batch_ys, batch_pcs, batch_tcs, batch_drs = get_minibatch(
-            X, Y, PC, TC, DR, i, batch_size
-        )
-        thetas.append(
-            model.compute_theta(batch_xs, batch_ys, batch_pcs, batch_tcs, batch_drs)
-        )
-    theta = np.vstack(thetas)
+    var_scale = 1.0
+    if n_samples <= 1:
+        var_scale = 0.0  # take posterior mean
+
+    theta_samples = []
+    for _ in range(n_samples):
+        batch_thetas = []
+        for i in range(n_batches):
+            batch_xs, batch_ys, batch_pcs, batch_tcs, batch_drs = get_minibatch(
+                X, Y, PC, TC, DR, i, batch_size
+            )
+            batch_thetas.append(
+                model.compute_theta(batch_xs, batch_ys, batch_pcs, batch_tcs, batch_drs, var_scale=var_scale)
+            )
+        theta_samples.append(np.vstack(batch_thetas))
+
+    theta = np.array(theta_samples).mean(0)

     if output_dir is not None:
         np.savez(
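The rewritten loop in save_document_representations averages theta over n_samples stochastic forward passes (var_scale=1.0) instead of one posterior-mean pass (var_scale=0.0). Below is a self-contained sketch of why the two differ, assuming the usual logistic-normal parameterization in which theta is a softmax of a Gaussian latent; the numbers are illustrative only, not the model's actual parameters.

# Toy illustration (not the model's code): mean of sampled thetas vs. theta at the posterior mean.
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
mu = np.array([1.0, 0.0, -1.0])    # illustrative posterior mean of the latent
sigma = np.array([0.8, 0.8, 0.8])  # illustrative posterior standard deviation

# n_samples <= 1, var_scale = 0.0: deterministic, softmax of the posterior mean
theta_mean = softmax(mu)

# n_samples > 1, var_scale = 1.0: average theta over reparameterized draws
n_samples = 20
draws = [softmax(mu + sigma * rng.standard_normal(mu.shape)) for _ in range(n_samples)]
theta_sampled = np.mean(draws, axis=0)

print(theta_mean)     # e.g. [0.665 0.245 0.090]
print(theta_sampled)  # typically flatter: averaging over noise smooths the proportions
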
4 changes: 2 additions & 2 deletions scholar.py
@@ -241,7 +241,7 @@ def get_losses(self, X, Y, PC, TC, DR, eta_bn_prop=0.0, n_samples=0):

         return losses

-    def compute_theta(self, X, Y, PC, TC, DR, eta_bn_prop=0.0):
+    def compute_theta(self, X, Y, PC, TC, DR, var_scale=0.0, eta_bn_prop=0.0):
         """
         Return the latent document representation (mean of posterior of theta) for a given batch of X, Y, PC, and TC
         """
@@ -267,7 +267,7 @@ def compute_theta(self, X, Y, PC, TC, DR, eta_bn_prop=0.0):
         if DR is not None:
             DR = torch.Tensor(DR).to(self.device)
         theta, _, _, _ = self._model(
-            X, Y, PC, TC, DR, do_average=False, var_scale=0.0, eta_bn_prop=eta_bn_prop
+            X, Y, PC, TC, DR, do_average=False, var_scale=var_scale, eta_bn_prop=eta_bn_prop
         )

         return theta.to("cpu").detach().numpy()
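In compute_theta, the new var_scale argument decides whether the returned theta reflects the posterior mean (var_scale=0.0, the previous behaviour) or a reparameterized sample (var_scale=1.0, used when n_samples > 1). A minimal sketch of that interpretation follows, assuming the underlying model scales the Gaussian noise by var_scale in the standard reparameterization trick; the function is illustrative only, not the model's actual forward pass.

import torch

def illustrative_theta(mu, logvar, var_scale=0.0):
    # Assumed reparameterization: var_scale=0.0 -> posterior mean, var_scale=1.0 -> one sample.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + var_scale * std * eps
    return torch.softmax(z, dim=-1)  # document-topic proportions

mu = torch.zeros(1, 5)      # hypothetical posterior mean over 5 topics
logvar = torch.zeros(1, 5)  # hypothetical log-variance
theta_det = illustrative_theta(mu, logvar, var_scale=0.0)  # deterministic
theta_rnd = illustrative_theta(mu, logvar, var_scale=1.0)  # stochastic draw
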

0 comments on commit 4514535
