Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #211 from june-08042/summarization_seq2seq_baseline
Summarization seq2seq baseline
- Loading branch information
Showing 25 changed files with 1,879 additions and 28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -154,4 +154,4 @@ | |
sample_proportion=args.sample_proportion, | ||
) | ||
|
||
engine.run() | ||
engine.run() |
162 changes: 162 additions & 0 deletions
162
examples/summarization_abstractive/pubmed_summarization_bilstm_seq2seq.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
from sciwing.datasets.summarization.abstractive_text_summarization_dataset import AbstractiveSummarizationDatasetManager | ||
from sciwing.modules.embedders.bow_elmo_embedder import BowElmoEmbedder | ||
from sciwing.modules.embedders.word_embedder import WordEmbedder | ||
from sciwing.modules.embedders.concat_embedders import ConcatEmbedders | ||
from sciwing.modules.lstm2seqencoder import Lstm2SeqEncoder | ||
from sciwing.modules.lstm2seqdecoder import Lstm2SeqDecoder | ||
from sciwing.models.simple_seq2seq import Seq2SeqModel | ||
import pathlib | ||
from sciwing.metrics.summarization_metrics import SummarizationMetrics | ||
import torch.optim as optim | ||
from sciwing.engine.engine import Engine | ||
import argparse | ||
import torch | ||
import sciwing.constants as constants | ||
|
||
PATHS = constants.PATHS | ||
DATA_DIR = PATHS["DATA_DIR"] | ||
|
||
|
||
if __name__ == "__main__":
    # Command-line driver: trains a GloVe + (bi)LSTM seq2seq abstractive
    # summarization baseline on the PubMed dataset and evaluates with ROUGE.
    parser = argparse.ArgumentParser(
        description="Glove with LSTM encoder and decoder"
    )

    parser.add_argument("--exp_name", help="Specify an experiment name", type=str)

    parser.add_argument(
        "--device", help="Specify the device where the model is run", type=str
    )

    parser.add_argument("--bs", help="batch size", type=int)
    parser.add_argument("--lr", help="learning rate", type=float)
    parser.add_argument("--epochs", help="number of epochs", type=int)
    parser.add_argument(
        "--save_every", help="Save the model every few epochs", type=int
    )
    parser.add_argument(
        "--log_train_metrics_every",
        help="Log training metrics every few iterations",
        type=int,
    )
    parser.add_argument(
        "--emb_type",
        help="The type of glove embedding you want. The allowed types are glove_6B_50, glove_6B_100, "
        "glove_6B_200, glove_6B_300, random",
    )
    parser.add_argument(
        "--hidden_dim", help="Hidden dimension of the LSTM network", type=int
    )
    parser.add_argument(
        "--bidirectional",
        help="Specify Whether the lstm is bidirectional or uni-directional",
        action="store_true",
    )
    parser.add_argument(
        "--combine_strategy",
        help="How do you want to combine the hidden dimensions of the two "
        "combinations",
    )

    parser.add_argument(
        "--pred_max_length", help="Maximum length of prediction", type=int
    )

    parser.add_argument(
        "--exp_dir_path", help="Directory to store all experiment related information"
    )

    parser.add_argument(
        "--model_save_dir",
        help="Directory where the checkpoints during model training are stored.",
    )
    parser.add_argument(
        "--sample_proportion", help="Sample proportion for the dataset", type=float
    )

    args = parser.parse_args()

    # The dataset manager expects pre-split train/dev/test files under DATA_DIR.
    DATA_PATH = pathlib.Path(DATA_DIR)
    train_file = DATA_PATH.joinpath("pubmedSeq2seq.train")
    dev_file = DATA_PATH.joinpath("pubmedSeq2seq.dev")
    test_file = DATA_PATH.joinpath("pubmedSeq2seq.test")

    data_manager = AbstractiveSummarizationDatasetManager(
        train_filename=str(train_file),
        dev_filename=str(dev_file),
        test_filename=str(test_file),
    )

    # The decoder needs the token vocabulary to map generated ids back to words.
    vocab = data_manager.build_vocab()["tokens"]

    # Plain GloVe word embeddings shared by the encoder and the decoder.
    embedder = WordEmbedder(embedding_type=args.emb_type, device=args.device)

    encoder = Lstm2SeqEncoder(
        embedder=embedder,
        hidden_dim=args.hidden_dim,
        bidirectional=args.bidirectional,
        combine_strategy=args.combine_strategy,
        device=torch.device(args.device),
    )

    # NOTE(review): the original script also computed an unused
    # encoding_dim (2 * hidden_dim when bidirectional with "concat");
    # Seq2SeqModel below receives args.hidden_dim directly — confirm that
    # is the intended encoder output size for the bidirectional case.

    decoder = Lstm2SeqDecoder(
        embedder=embedder,
        hidden_dim=args.hidden_dim,
        bidirectional=args.bidirectional,
        combine_strategy=args.combine_strategy,
        device=torch.device(args.device),
        max_length=args.pred_max_length,
        vocab=vocab,
    )

    model = Seq2SeqModel(
        rnn2seqencoder=encoder,
        rnn2seqdecoder=decoder,
        enc_hidden_dim=args.hidden_dim,
        datasets_manager=data_manager,
        device=args.device,
        bidirectional=args.bidirectional,
    )

    optimizer = optim.Adam(params=model.parameters(), lr=args.lr)
    # Separate metric instances so per-split ROUGE accumulators do not mix.
    train_metric = SummarizationMetrics(datasets_manager=data_manager)
    dev_metric = SummarizationMetrics(datasets_manager=data_manager)
    test_metric = SummarizationMetrics(datasets_manager=data_manager)

    engine = Engine(
        model=model,
        datasets_manager=data_manager,
        optimizer=optimizer,
        batch_size=args.bs,
        save_dir=args.model_save_dir,
        num_epochs=args.epochs,
        save_every=args.save_every,
        train_metric=train_metric,
        validation_metric=dev_metric,
        test_metric=test_metric,
        log_train_metrics_every=args.log_train_metrics_every,
        device=torch.device(args.device),
        use_wandb=True,
        experiment_name=args.exp_name,
        experiment_hyperparams=vars(args),
        track_for_best="rouge_1",
        sample_proportion=args.sample_proportion,
    )

    engine.run()
20 changes: 20 additions & 0 deletions
20
examples/summarization_abstractive/pubmed_summarization_bilstm_seq2seq.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#!/usr/bin/env bash
# Launch the PubMed abstractive-summarization baseline:
# GloVe embeddings + bidirectional LSTM encoder/decoder.

EXPERIMENT_PREFIX="lstm_seq2seq"
SCRIPT_FILE="pubmed_summarization_bilstm_seq2seq.py"

python "${SCRIPT_FILE}" \
  --exp_name "${EXPERIMENT_PREFIX}_pubmed" \
  --model_save_dir "./pubmed_seq2seq" \
  --device "cuda:0" \
  --bs 8 \
  --emb_type "glove_6B_50" \
  --hidden_dim 256 \
  --bidirectional \
  --lr 1e-3 \
  --combine_strategy concat \
  --epochs 100 \
  --save_every 10 \
  --log_train_metrics_every 10 \
  --sample_proportion 1 \
  --pred_max_length 500
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from sciwing.tokenizers.word_tokenizer import WordTokenizer | ||
from sciwing.tokenizers.BaseTokenizer import BaseTokenizer | ||
from sciwing.data.token import Token | ||
from typing import Union, List, Dict, Any | ||
from collections import defaultdict | ||
|
||
|
||
class SeqSentence:
    """A sequence of sentences tokenized under one or more namespaces.

    Every tokenizer in ``tokenizers`` defines a namespace; each sentence in
    ``sents`` is tokenized once per namespace and stored in
    ``self.tokens[namespace]`` as a list of per-sentence token lists.
    """

    def __init__(self, sents: List[str], tokenizers: Dict[str, BaseTokenizer] = None):
        """
        Parameters
        ----------
        sents : List[str]
            Raw sentences to tokenize.
        tokenizers : Dict[str, BaseTokenizer], optional
            Mapping from namespace name to tokenizer. Defaults to a single
            ``"tokens"`` namespace backed by ``WordTokenizer``.
        """
        if tokenizers is None:
            tokenizers = {"tokens": WordTokenizer()}
        self.sents = sents
        self.tokenizers = tokenizers
        # namespace -> list of sentences, each a list of Token objects
        self.tokens: Dict[str, List[List[Any]]] = defaultdict(list)
        self.namespaces = list(tokenizers.keys())

        for namespace, tokenizer in tokenizers.items():
            for sent in sents:
                sent_tokens = tokenizer.tokenize(sent)
                self.add_sent_tokens(tokens=sent_tokens, namespace=namespace)

    def add_sent_tokens(self, tokens: Union[List[str], List[Token]], namespace: str):
        """Append one sentence's tokens to ``namespace``.

        Raw strings are wrapped in ``Token``; anything else is stored as-is.
        """
        sent_tokens = [
            Token(token) if isinstance(token, str) else token for token in tokens
        ]
        self.tokens[namespace].append(sent_tokens)

    def add_tokens(self, sents: List[str], tokenizers: Dict[str, BaseTokenizer] = None):
        """Tokenize and append additional sentences.

        BUGFIX: the original iterated ``tokenizers.items()`` unconditionally
        and crashed with ``AttributeError`` when the default ``None`` was
        used; we now fall back to the tokenizers given at construction.
        """
        if tokenizers is None:
            tokenizers = self.tokenizers
        for namespace, tokenizer in tokenizers.items():
            for sent in sents:
                sent_tokens = tokenizer.tokenize(sent)
                self.add_sent_tokens(tokens=sent_tokens, namespace=namespace)
Empty file.
Oops, something went wrong.