Skip to content
Browse files
Setup sphinx quickstart, reorg data load, adding docs
  • Loading branch information
sidharthms committed Apr 2, 2018
1 parent 693ba0e commit 6eee0152a930e1106c7c5086af566cdad1d6b5a9
@@ -0,0 +1,54 @@
from collections import namedtuple

import torch

AttrTensor_ = namedtuple('AttrTensor', ['data', 'lengths', 'word_probs', 'pc'])

class AttrTensor(AttrTensor_):

def __new__(cls, *args, **kwargs):
if len(kwargs) == 0:
return super(AttrTensor, cls).__new__(cls, *args)
name = kwargs['name']
attr = kwargs['attr']
train_dataset = kwargs['train_dataset']
if isinstance(attr, tuple):
data = attr[0]
lengths = attr[1]
data = attr
lengths = None
word_probs = None
if 'word_probs' in train_dataset.metadata:
raw_word_probs = train_dataset.metadata['word_probs'][name]
word_probs = torch.Tensor(
[[raw_word_probs[w] for w in b] for b in])
if data.is_cuda:
word_probs = word_probs.cuda()
pc = None
if 'pc' in train_dataset.metadata:
pc = torch.Tensor(train_dataset.metadata['pc'][name])
if data.is_cuda:
pc = pc.cuda()
return AttrTensor(data, lengths, word_probs, pc)

def from_old_metadata(data, old_attrtensor):
return AttrTensor(data, *old_attrtensor[1:])

class Batch(object):

def __init__(self, input, train_dataset):
copy_fields = train_dataset.all_text_fields
for name in copy_fields:
setattr(self, name,
name=name, attr=getattr(input, name),
for name in [train_dataset.label_field, train_dataset.id_field]:
if name is not None:
setattr(self, name, getattr(input, name))
@@ -1,7 +1,6 @@
from .torchtext_extensions import (AttrTensor, MatchingField, MatchingDataset,
MatchingBatch, MatchingIterator)
from .field import MatchingField
from .dataset import MatchingDataset
from .iterator import MatchingIterator
from .process import process

__all__ = [
AttrTensor, MatchingField, MatchingDataset, MatchingBatch, MatchingIterator, process
__all__ = [MatchingField, MatchingDataset, MatchingIterator, process]

0 comments on commit 6eee015

Please sign in to comment.