Merge pull request #211 from june-08042/summarization_seq2seq_baseline
Summarization seq2seq baseline
abhinavkashyap committed Dec 15, 2020
2 parents e30d7dd + b50c353 commit 4712af4
Showing 25 changed files with 1,879 additions and 28 deletions.
2 changes: 1 addition & 1 deletion examples/science_ie/science_ie.py
@@ -154,4 +154,4 @@
sample_proportion=args.sample_proportion,
)

-engine.run()
+engine.run()
@@ -0,0 +1,162 @@
from sciwing.datasets.summarization.abstractive_text_summarization_dataset import AbstractiveSummarizationDatasetManager
from sciwing.modules.embedders.bow_elmo_embedder import BowElmoEmbedder
from sciwing.modules.embedders.word_embedder import WordEmbedder
from sciwing.modules.embedders.concat_embedders import ConcatEmbedders
from sciwing.modules.lstm2seqencoder import Lstm2SeqEncoder
from sciwing.modules.lstm2seqdecoder import Lstm2SeqDecoder
from sciwing.models.simple_seq2seq import Seq2SeqModel
import pathlib
from sciwing.metrics.summarization_metrics import SummarizationMetrics
import torch.optim as optim
from sciwing.engine.engine import Engine
import argparse
import torch
import sciwing.constants as constants

PATHS = constants.PATHS
DATA_DIR = PATHS["DATA_DIR"]


if __name__ == "__main__":
    # read the hyperparameters from the command line
    parser = argparse.ArgumentParser(
        description="GloVe embeddings with an LSTM encoder and decoder"
    )

parser.add_argument("--exp_name", help="Specify an experiment name", type=str)

parser.add_argument(
"--device", help="Specify the device where the model is run", type=str
)

parser.add_argument("--bs", help="batch size", type=int)
parser.add_argument("--lr", help="learning rate", type=float)
parser.add_argument("--epochs", help="number of epochs", type=int)
parser.add_argument(
"--save_every", help="Save the model every few epochs", type=int
)
parser.add_argument(
"--log_train_metrics_every",
help="Log training metrics every few iterations",
type=int,
)
    parser.add_argument(
        "--emb_type",
        help="The type of GloVe embedding to use. Allowed values: glove_6B_50, glove_6B_100, "
        "glove_6B_200, glove_6B_300, random",
    )
parser.add_argument(
"--hidden_dim", help="Hidden dimension of the LSTM network", type=int
)
    parser.add_argument(
        "--bidirectional",
        help="Whether the LSTM is bidirectional or unidirectional",
        action="store_true",
    )
    parser.add_argument(
        "--combine_strategy",
        help="How to combine the hidden states of the forward and backward "
        "LSTM directions (e.g. concat)",
    )

parser.add_argument(
"--pred_max_length", help="Maximum length of prediction", type=int
)

parser.add_argument(
"--exp_dir_path", help="Directory to store all experiment related information"
)

parser.add_argument(
"--model_save_dir",
help="Directory where the checkpoints during model training are stored.",
)
parser.add_argument(
"--sample_proportion", help="Sample proportion for the dataset", type=float
)

args = parser.parse_args()

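    # resolve the PubMed seq2seq splits inside sciwing's configured data directory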
DATA_PATH = pathlib.Path(DATA_DIR)
train_file = DATA_PATH.joinpath("pubmedSeq2seq.train")
dev_file = DATA_PATH.joinpath("pubmedSeq2seq.dev")
test_file = DATA_PATH.joinpath("pubmedSeq2seq.test")

data_manager = AbstractiveSummarizationDatasetManager(
train_filename=str(train_file),
dev_filename=str(dev_file),
test_filename=str(test_file),
)

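    # build the vocabulary; the decoder generates output words from the "tokens" namespace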
    vocab = data_manager.build_vocab()["tokens"]

# # instantiate the elmo embedder
# elmo_embedder = BowElmoEmbedder(layer_aggregation="sum", device=args.device)
#
# # instantiate the vanilla embedder
# vanilla_embedder = WordEmbedder(embedding_type=args.emb_type, device=args.device)
#
# # concat the embeddings
# embedder = ConcatEmbedders([vanilla_embedder, elmo_embedder])

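    # GloVe (or random) word embeddings, selected via --emb_type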
embedder = WordEmbedder(embedding_type=args.emb_type, device=args.device)

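    # encode the source text into a sequence of hidden states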
encoder = Lstm2SeqEncoder(
embedder=embedder,
hidden_dim=args.hidden_dim,
bidirectional=args.bidirectional,
combine_strategy=args.combine_strategy,
device=torch.device(args.device),
)

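    # concatenating the forward and backward states doubles the encoding dimension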
encoding_dim = (
2 * args.hidden_dim
if args.bidirectional and args.combine_strategy == "concat"
else args.hidden_dim
)

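    # the decoder emits at most --pred_max_length tokens from the vocabulary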
decoder = Lstm2SeqDecoder(
embedder=embedder,
hidden_dim=args.hidden_dim,
bidirectional=args.bidirectional,
combine_strategy=args.combine_strategy,
device=torch.device(args.device),
max_length=args.pred_max_length,
        vocab=vocab,
)

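    # tie the encoder and decoder together into the seq2seq model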
model = Seq2SeqModel(
rnn2seqencoder=encoder,
rnn2seqdecoder=decoder,
enc_hidden_dim=args.hidden_dim,
datasets_manager=data_manager,
device=args.device,
        bidirectional=args.bidirectional,
)

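    # Adam optimizer and ROUGE-based summarization metrics for each split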
optimizer = optim.Adam(params=model.parameters(), lr=args.lr)
train_metric = SummarizationMetrics(datasets_manager=data_manager)
dev_metric = SummarizationMetrics(datasets_manager=data_manager)
test_metric = SummarizationMetrics(datasets_manager=data_manager)

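    # the engine runs training/validation/testing and keeps the checkpoint with the best ROUGE-1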
engine = Engine(
model=model,
datasets_manager=data_manager,
optimizer=optimizer,
batch_size=args.bs,
save_dir=args.model_save_dir,
num_epochs=args.epochs,
save_every=args.save_every,
train_metric=train_metric,
validation_metric=dev_metric,
test_metric=test_metric,
log_train_metrics_every=args.log_train_metrics_every,
device=torch.device(args.device),
use_wandb=True,
experiment_name=args.exp_name,
experiment_hyperparams=vars(args),
track_for_best="rouge_1",
sample_proportion=args.sample_proportion,
)

engine.run()
@@ -0,0 +1,20 @@
#!/usr/bin/env bash

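# Train the GloVe + BiLSTM seq2seq summarization baseline on PubMed.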
EXPERIMENT_PREFIX="lstm_seq2seq"
SCRIPT_FILE="pubmed_summarization_bilstm_seq2seq.py"

python ${SCRIPT_FILE} \
--exp_name ${EXPERIMENT_PREFIX}"_pubmed" \
--model_save_dir "./pubmed_seq2seq" \
--device "cuda:0" \
--bs 8 \
--emb_type "glove_6B_50" \
--hidden_dim 256 \
--bidirectional \
--lr 1e-3 \
--combine_strategy concat \
--epochs 100 \
--save_every 10 \
--log_train_metrics_every 10 \
--sample_proportion 1 \
--pred_max_length 500
50 changes: 50 additions & 0 deletions sciwing/data/seq_sentence.py
@@ -0,0 +1,50 @@
from sciwing.tokenizers.word_tokenizer import WordTokenizer
from sciwing.tokenizers.BaseTokenizer import BaseTokenizer
from sciwing.data.token import Token
from typing import Union, List, Dict, Any
from collections import defaultdict


class SeqSentence:
def __init__(self, sents: List[str], tokenizers: Dict[str, BaseTokenizer] = None):
if tokenizers is None:
tokenizers = {"tokens": WordTokenizer()}
self.sents = sents
self.tokenizers = tokenizers
self.tokens: Dict[str, List[List[Any]]] = defaultdict(list)
self.namespaces = list(tokenizers.keys())

for namespace, tokenizer in tokenizers.items():
for sent in sents:
sent_tokens = tokenizer.tokenize(sent)
self.add_sent_tokens(tokens=sent_tokens, namespace=namespace)

def add_sent_tokens(self, tokens: Union[List[str], List[Token]], namespace: str):
sent_tokens = []
for token in tokens:
if isinstance(token, str):
token = Token(token)
sent_tokens.append(token)
self.tokens[namespace].append(sent_tokens)

    def add_tokens(self, sents: List[str], tokenizers: Dict[str, BaseTokenizer] = None):
        if tokenizers is None:
            tokenizers = self.tokenizers
        for namespace, tokenizer in tokenizers.items():
            for sent in sents:
                sent_tokens = tokenizer.tokenize(sent)
                self.add_sent_tokens(tokens=sent_tokens, namespace=namespace)

@property
def tokens(self):
return self._tokens

@tokens.setter
def tokens(self, value):
self._tokens = value

@property
def namespaces(self):
return self._namespaces

@namespaces.setter
def namespaces(self, value):
self._namespaces = value
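
A minimal usage sketch for the new SeqSentence container (the sentences here are hypothetical; the tokenizer default matches the one used in __init__ above):

from sciwing.data.seq_sentence import SeqSentence
from sciwing.tokenizers.word_tokenizer import WordTokenizer

sents = ["We train a seq2seq baseline.", "It summarizes PubMed abstracts."]
seq_sentence = SeqSentence(sents=sents, tokenizers={"tokens": WordTokenizer()})

# tokens maps each namespace to one token list per input sentence
print(seq_sentence.namespaces)              # ["tokens"]
print(len(seq_sentence.tokens["tokens"]))   # 2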
Empty file.
