Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
with
1,052 additions
and 259 deletions.
- +54 −0 deepmatcher/batch.py
- +4 −5 deepmatcher/data/__init__.py
- +220 −239 deepmatcher/data/{torchtext_extensions.py → dataset.py}
- +133 −0 deepmatcher/data/field.py
- +53 −0 deepmatcher/data/iterator.py
- +91 −13 deepmatcher/data/process.py
- +130 −0 deepmatcher/model.py
- +2 −2 deepmatcher/models/_utils.py
- +105 −0 deepmatcher/test.py
- +20 −0 docs/Makefile
- +36 −0 docs/make.bat
- +178 −0 docs/source/conf.py
- +20 −0 docs/source/index.rst
- +6 −0 setup.cfg
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@@ -0,0 +1,54 @@ | ||
from collections import namedtuple | ||
|
||
import torch | ||
|
||
AttrTensor_ = namedtuple('AttrTensor', ['data', 'lengths', 'word_probs', 'pc']) | ||
|
||
|
||
class AttrTensor(AttrTensor_): | ||
|
||
@staticmethod | ||
def __new__(cls, *args, **kwargs): | ||
if len(kwargs) == 0: | ||
return super(AttrTensor, cls).__new__(cls, *args) | ||
else: | ||
name = kwargs['name'] | ||
attr = kwargs['attr'] | ||
train_dataset = kwargs['train_dataset'] | ||
if isinstance(attr, tuple): | ||
data = attr[0] | ||
lengths = attr[1] | ||
else: | ||
data = attr | ||
lengths = None | ||
word_probs = None | ||
if 'word_probs' in train_dataset.metadata: | ||
raw_word_probs = train_dataset.metadata['word_probs'][name] | ||
word_probs = torch.Tensor( | ||
[[raw_word_probs[w] for w in b] for b in data.data]) | ||
if data.is_cuda: | ||
word_probs = word_probs.cuda() | ||
pc = None | ||
if 'pc' in train_dataset.metadata: | ||
pc = torch.Tensor(train_dataset.metadata['pc'][name]) | ||
if data.is_cuda: | ||
pc = pc.cuda() | ||
return AttrTensor(data, lengths, word_probs, pc) | ||
|
||
@staticmethod | ||
def from_old_metadata(data, old_attrtensor): | ||
return AttrTensor(data, *old_attrtensor[1:]) | ||
|
||
|
||
class Batch(object): | ||
|
||
def __init__(self, input, train_dataset): | ||
copy_fields = train_dataset.all_text_fields | ||
for name in copy_fields: | ||
setattr(self, name, | ||
AttrTensor( | ||
name=name, attr=getattr(input, name), | ||
train_dataset=train_dataset)) | ||
for name in [train_dataset.label_field, train_dataset.id_field]: | ||
if name is not None: | ||
setattr(self, name, getattr(input, name)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@@ -1,7 +1,6 @@ | ||
from .field import MatchingField | ||
from .dataset import MatchingDataset | ||
from .iterator import MatchingIterator | ||
from .process import process | ||
|
||
__all__ = [MatchingField, MatchingDataset, MatchingIterator, process] |
Oops, something went wrong.