
Commit

Merge 5e0f761 into 1deb01a
JohnGiorgi committed Oct 7, 2019
2 parents 1deb01a + 5e0f761 commit ba62a4b
Showing 10 changed files with 28 additions and 30 deletions.
2 changes: 1 addition & 1 deletion saber/models/bert_for_ner.py
@@ -2,7 +2,7 @@
 from itertools import zip_longest
 
 import torch
-from pytorch_transformers import BertTokenizer
+from transformers import BertTokenizer
 from tqdm import tqdm
 
 from .. import constants
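
The change here is purely the package rename: pytorch-transformers was re-released as transformers in version 2.0, and BertTokenizer is re-exported under the new top-level name. A minimal sketch of the migrated import in use (the checkpoint name and sentence are only examples):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
wordpieces = tokenizer.tokenize("Saber tags biomedical entities.")
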
4 changes: 2 additions & 2 deletions saber/models/bert_for_ner_and_re.py
@@ -2,8 +2,8 @@
 from itertools import zip_longest
 
 import torch
-from pytorch_transformers import BertConfig
-from pytorch_transformers import BertTokenizer
+from transformers import BertConfig
+from transformers import BertTokenizer
 from tqdm import tqdm
 
 from .. import constants
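
BertConfig and BertTokenizer are likewise exposed under the same names from the renamed package, so only the import line changes. A sketch of loading both (the checkpoint name is illustrative):

from transformers import BertConfig, BertTokenizer

config = BertConfig.from_pretrained("bert-base-cased")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
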
13 changes: 6 additions & 7 deletions saber/models/modules/bert_for_entity_and_relation_extraction.py
@@ -2,9 +2,9 @@
 from itertools import permutations
 
 import torch
-from pytorch_transformers import BertModel
-# TODO (John): This can be shortned to from pytorch_transformers import x after next release
-from pytorch_transformers.modeling_bert import BertPreTrainedModel
+from transformers import BertModel
+# TODO (John): This can be shortned to from transformers import x after next release
+from transformers.modeling_bert import BertPreTrainedModel
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
@@ -126,11 +126,10 @@ def __init__(self, config):
         # Biaffine transformation for relation classification
         self.rel_classifier = BiaffineAttention(head_tail_ffnns_size // 2, self.num_rel_labels)
 
-    def forward(self, input_ids, orig_to_tok_map, token_type_ids=None, attention_mask=None,
-                ent_labels=None, rel_labels=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids, orig_to_tok_map, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, ent_labels=None, rel_labels=None):
         # Forward pass through BERT
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+        outputs = self.bert(input_ids, attention_mask, token_type_ids, position_ids, head_mask)
         sequence_output = outputs[0]
 
         # NER classification
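
Besides the rename, the forward signature is reordered because transformers 2.0 appears to place attention_mask before token_type_ids in BertModel.forward, and the new one-line call passes the tensors positionally in that order. A keyword-based sketch of the same call, which would be more robust to further signature changes (assuming the 2.0 argument names):

outputs = self.bert(input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    position_ids=position_ids,
                    head_mask=head_mask)
sequence_output = outputs[0]  # last-layer hidden states, shape (batch, seq_len, hidden_size)
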
11 changes: 5 additions & 6 deletions saber/models/modules/bert_for_token_classification_multi_task.py
@@ -3,8 +3,8 @@
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-from pytorch_transformers.modeling_bert import BertPreTrainedModel
-from pytorch_transformers import BertModel
+from transformers.modeling_bert import BertPreTrainedModel
+from transformers import BertModel
 
 logger = logging.getLogger(__name__)
 
@@ -91,10 +91,9 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
-                position_ids=None, head_mask=None, model_idx=-1):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
+                head_mask=None, labels=None, model_idx=-1):
+        outputs = self.bert(input_ids, attention_mask, token_type_ids, position_ids, head_mask)
         sequence_output = outputs[0]
 
         sequence_output = self.dropout(sequence_output)
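
The multi-task head gets the same reordering, which also moves labels towards the end of the parameter list. Any caller that previously passed token_type_ids or labels positionally would now bind them to the wrong parameters, so a hypothetical caller sketch (model, input_ids, attention_mask and labels stand in for the real objects) should use keywords:

outputs = model(input_ids,
                attention_mask=attention_mask,
                labels=labels,
                model_idx=0)
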
2 changes: 1 addition & 1 deletion saber/tests/conftest.py
@@ -2,7 +2,7 @@
 """
 import pytest
 import spacy
-from pytorch_transformers import BertTokenizer
+from transformers import BertTokenizer
 
 from ..config import Config
 from ..constants import SPACY_MODEL
10 changes: 5 additions & 5 deletions saber/tests/test_bert_for_ner.py
@@ -2,11 +2,11 @@
 """
 import os
 
-from pytorch_transformers import CONFIG_NAME
-from pytorch_transformers import WEIGHTS_NAME
-from pytorch_transformers import BertForTokenClassification
-from pytorch_transformers import BertTokenizer
-from pytorch_transformers.optimization import AdamW
+from transformers import CONFIG_NAME
+from transformers import WEIGHTS_NAME
+from transformers import BertForTokenClassification
+from transformers import BertTokenizer
+from transformers.optimization import AdamW
 
 from ..constants import PARTITIONS
 from ..constants import WORDPIECE
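
CONFIG_NAME and WEIGHTS_NAME keep their usual values ("config.json" and "pytorch_model.bin"), so a save-and-reload sketch with the migrated imports looks like this (the checkpoint name and directory are illustrative):

import os

from transformers import CONFIG_NAME, WEIGHTS_NAME, BertForTokenClassification

model = BertForTokenClassification.from_pretrained("bert-base-cased")
model.save_pretrained("/tmp/saber_bert")  # writes WEIGHTS_NAME and CONFIG_NAME into the directory
assert os.path.isfile(os.path.join("/tmp/saber_bert", WEIGHTS_NAME))
assert os.path.isfile(os.path.join("/tmp/saber_bert", CONFIG_NAME))
reloaded = BertForTokenClassification.from_pretrained("/tmp/saber_bert")
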
10 changes: 5 additions & 5 deletions saber/tests/test_bert_for_ner_and_re.py
@@ -3,11 +3,11 @@
 """
 import os
 
-from pytorch_transformers import CONFIG_NAME
-from pytorch_transformers import WEIGHTS_NAME
-from pytorch_transformers import BertForTokenClassification
-from pytorch_transformers import BertTokenizer
-from pytorch_transformers.optimization import AdamW
+from transformers import CONFIG_NAME
+from transformers import WEIGHTS_NAME
+from transformers import BertForTokenClassification
+from transformers import BertTokenizer
+from transformers.optimization import AdamW
 
 from ..constants import PARTITIONS
 from ..constants import WORDPIECE
2 changes: 1 addition & 1 deletion saber/tests/test_bert_utils.py
@@ -1,7 +1,7 @@
 """Test suite for the `bert_utils` module (saber.utils.bert_utils).
 """
 import torch
-from pytorch_transformers.optimization import AdamW
+from transformers.optimization import AdamW
 from torch.utils.data import RandomSampler
 from torch.utils.data import SequentialSampler
 
2 changes: 1 addition & 1 deletion saber/utils/bert_utils.py
@@ -1,6 +1,6 @@
 import torch
 from keras_preprocessing.sequence import pad_sequences
-from pytorch_transformers.optimization import AdamW
+from transformers.optimization import AdamW
 from torch.utils import data
 
 from saber import constants
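
AdamW appears to keep the same constructor in transformers.optimization, so the optimizer setup in bert_utils should work unchanged after the import swap. A minimal sketch (the learning rate and the stand-in module are only illustrative):

import torch
from transformers.optimization import AdamW

model = torch.nn.Linear(768, 2)  # stand-in for the real BERT model
optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
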
2 changes: 1 addition & 1 deletion setup.py
@@ -35,7 +35,7 @@
         'scikit-learn>=0.20.1',
         'scikit-multilearn>=0.2.0',
         'torch>=1.2.0',
-        'pytorch-transformers>=1.2.0',
+        'transformers>=2.0.0',
         'Flask>=1.0.2',
         'waitress>=1.1.0',
         'Keras-Preprocessing>=1.1.0',
