This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Merge pull request #29 from PetrochukM/pytorch
Upgrade to PyTorch 0.4 + Small Fixes
PetrochukM committed May 6, 2018
2 parents 9fc48f5 + 00f043d commit 588c340
Showing 27 changed files with 117 additions and 195 deletions.
23 changes: 12 additions & 11 deletions README.md
@@ -11,28 +11,29 @@ PyTorch-NLP, or torchnlp for short, is a library of neural network layers, text
Join our community, add datasets and neural network layers! Chat with us on [Gitter](https://gitter.im/PyTorch-NLP/Lobby) and join the [Google Group](https://groups.google.com/forum/#!forum/pytorch-nlp); we're eager to collaborate with you.

![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pytorch-nlp.svg?style=flat-square)
[![Codecov](https://img.shields.io/codecov/c/github/PetrochukM/PyTorch-NLP/master.svg?style=flat-square)](https://codecov.io/gh/PetrochukM/PyTorch-NLP)
[![Documentation Status](https://img.shields.io/readthedocs/pytorchnlp/latest.svg?style=flat-square)](http://pytorchnlp.readthedocs.io/en/latest/?badge=latest&style=flat-square)
[![Build Status](https://img.shields.io/travis/PetrochukM/PyTorch-NLP/master.svg?style=flat-square)](https://travis-ci.org/PetrochukM/PyTorch-NLP)


## Installation

Make sure you have Python 3.5+ and PyTorch 0.2.0 or newer. You can then install `pytorch-nlp` using pip:

pip install pytorch-nlp
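
A quick way to confirm the environment meets those requirements (a sketch; the exact version string will differ per install):

```python
import sys

import torch

# PyTorch-NLP targets Python 3.5+ and a recent PyTorch release.
assert sys.version_info >= (3, 5)
print(torch.__version__)  # e.g. '0.4.0'
```
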
## Docs 📖

The complete documentation for PyTorch-NLP is available via [our ReadTheDocs website](https://pytorchnlp.readthedocs.io).

## Basics

Add PyTorch-NLP to your project by following one of the common use cases:

### Load a [Dataset](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.datasets.html)

Load the IMDB dataset, for example:

```python
from torchnlp.datasets import imdb_dataset
@@ -42,7 +43,7 @@ train = imdb_dataset(train=True)
train[0] # RETURNS: {'text': 'For a movie that gets..', 'sentiment': 'pos'}
```
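
Beyond a single split, the same loader can hand back several splits in one call; a minimal sketch, assuming `imdb_dataset` accepts both flags and returns the splits as a tuple of datasets whose rows are dicts like the one above:

```python
from torchnlp.datasets import imdb_dataset

# Assumption: passing both flags returns the two splits as a (train, test) tuple.
train, test = imdb_dataset(train=True, test=True)

print(len(train), len(test))
print(train[0]['text'][:40], train[0]['sentiment'])  # e.g. "For a movie that gets.." 'pos'
```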

### Apply [Neural Networks](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.nn.html) Layers

For example, from the neural network package, apply a Simple Recurrent Unit (SRU):

@@ -56,12 +57,12 @@ sru = SRU(10, 20)
# Apply a Simple Recurrent Unit to `input_`
sru(input_)
# RETURNS: (
# output [torch.FloatTensor (6x3x20)],
# hidden_state [torch.FloatTensor (2x3x20)]
# )
```
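
The hunk above only shows the tail of the snippet; a minimal end-to-end sketch in the PyTorch 0.4 style this commit adopts, with the 6x3x10 input shape inferred from the output sizes shown:

```python
import torch
from torchnlp.nn import SRU

# 6 timesteps, batch of 3, 10 input features (shape inferred from the 6x3x20 output above).
input_ = torch.randn(6, 3, 10)
sru = SRU(10, 20)

output, hidden_state = sru(input_)
print(output.shape)        # expected: torch.Size([6, 3, 20])
print(hidden_state.shape)  # expected: torch.Size([2, 3, 20])
```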

### [Encode Text](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.text_encoders.html)

Tokenize and encode text as a tensor. For example, a `WhitespaceEncoder` breaks text into terms whenever it encounters a whitespace character.

@@ -77,7 +78,7 @@ encoder.decode(encoder.encode("This ain't funny.")) # RETURNS: "this ain't funny
```
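
The top of this snippet is folded out of the diff; a minimal sketch of the round trip, assuming `WhitespaceEncoder` builds its vocabulary from a small sample corpus passed to the constructor:

```python
from torchnlp.text_encoders import WhitespaceEncoder

# Assumption: the encoder derives its vocabulary from this small sample corpus.
corpus = ["This ain't funny.", "Don't?"]
encoder = WhitespaceEncoder(corpus)

tokens = encoder.encode("This ain't funny.")
print(tokens)                  # a torch.LongTensor of token ids
print(encoder.decode(tokens))  # round-trips back to the whitespace-joined text
```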

### Load [Word Vectors](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.word_to_vector.html)

For example, load FastText, state-of-the-art English word vectors:
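
The code for this example is folded out of the diff view; a minimal sketch, assuming `torchnlp.word_to_vector.FastText` downloads its default pretrained English vectors on first use:

```python
from torchnlp.word_to_vector import FastText

# Note: this downloads a large pretrained-vector file the first time it runs.
vectors = FastText()

print(vectors['hello'].shape)  # a 1-D torch.FloatTensor (300 dims for the default English vectors)
```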

@@ -128,8 +129,8 @@ AllenNLP is designed to be a platform for research. PyTorch-NLP is designed to b

## Authors

* [Michael Petrochuk](https://github.com/PetrochukM/) — Developer
* [Chloe Yeo](http://www.yeochloe.com/) — Logo Design

## Citing

5 changes: 3 additions & 2 deletions build_tools/travis/install.sh
@@ -26,6 +26,7 @@ pip install -r requirements.txt
# Optional Requirements
pip install spacy
pip install nltk
pip install sacremoses

# SpaCy English web model
python -m spacy download en
@@ -35,9 +36,9 @@ python -m nltk.downloader perluniprops nonbreaking_prefixes

# Install PyTorch Dependencies
if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then
pip install http://download.pytorch.org/whl/cpu/torch-0.3.1-cp36-cp36m-linux_x86_64.whl
pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl
fi
if [[ $TRAVIS_PYTHON_VERSION == '3.5' ]]; then
pip install http://download.pytorch.org/whl/cpu/torch-0.3.1-cp35-cp35m-linux_x86_64.whl
pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl
fi
pip install torchvision
2 changes: 2 additions & 0 deletions codecov.yml
@@ -3,6 +3,8 @@ coverage:
patch:
  default:
    target: 90
    threshold: 1
project:
  default:
    target: 90
    threshold: 1
43 changes: 0 additions & 43 deletions examples/awd-lstm-lm/embed_regularize.py

This file was deleted.

26 changes: 12 additions & 14 deletions examples/awd-lstm-lm/main.py
@@ -3,8 +3,6 @@
import math
import numpy as np
import torch
from torch.autograd import Variable

import model

from utils import repackage_hidden
@@ -43,7 +41,7 @@
help='amount of weight dropout to apply to the RNN hidden to hidden matrix')
parser.add_argument('--seed', type=int, default=1111, help='random seed')
parser.add_argument('--nonmono', type=int, default=5, help='non-monotone interval')
parser.add_argument('--cuda', action='store_false', help='use CUDA')
parser.add_argument('--cuda', action='store_true', default=False, help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N', help='report interval')
randomhash = ''.join(str(time.time()).split('.'))
parser.add_argument(
@@ -182,13 +180,14 @@ def evaluate(data_source, source_sampler, target_sampler, batch_size=10):

for source_sample, target_sample in zip(source_sampler, target_sampler):
model.train()
data = Variable(torch.stack([data_source[i] for i in source_sample]), volatile=True)
targets = Variable(torch.stack([data_source[i] for i in target_sample])).view(-1)
output, hidden = model(data, hidden)
data = torch.stack([data_source[i] for i in source_sample])
targets = torch.stack([data_source[i] for i in target_sample]).view(-1)
with torch.no_grad():
output, hidden = model(data, hidden)
total_loss += len(data) * criterion(model.decoder.weight, model.decoder.bias, output,
targets).data
targets).item()
hidden = repackage_hidden(hidden)
return total_loss[0] / len(data_source)
return total_loss / len(data_source)


def train():
@@ -201,9 +200,8 @@
batch = 0
for source_sample, target_sample in zip(train_source_sampler, train_target_sampler):
model.train()
data = Variable(torch.stack([train_data[i] for i in source_sample])).t_().contiguous()
targets = Variable(torch.stack(
[train_data[i] for i in target_sample])).t_().contiguous().view(-1)
data = torch.stack([train_data[i] for i in source_sample]).t_().contiguous()
targets = torch.stack([train_data[i] for i in target_sample]).t_().contiguous().view(-1)

# Starting each batch, we detach the hidden state from how it was previously produced.
# If we didn't, the model would try backpropagating all the way to start of the dataset.
@@ -226,12 +224,12 @@

# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
if args.clip:
torch.nn.utils.clip_grad_norm(params, args.clip)
torch.nn.utils.clip_grad_norm_(params, args.clip)
optimizer.step()

total_loss += raw_loss.data
total_loss += raw_loss.item()
if batch % args.log_interval == 0 and batch > 0:
cur_loss = total_loss[0] / args.log_interval
cur_loss = total_loss / args.log_interval
elapsed = time.time() - start_time
print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
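
The recurring change in this file is the PyTorch 0.4 pattern of evaluating under `torch.no_grad()` and reading scalar losses with `.item()`; a condensed, self-contained sketch of that pattern (the model, criterion, and data below are stand-ins, not the script's real objects):

```python
import torch
import torch.nn as nn

# Stand-ins for the script's language model, criterion, and batches.
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
data = torch.randn(32, 10)
targets = torch.randn(32, 1)

model.eval()
total_loss = 0.
with torch.no_grad():                                # replaces Variable(..., volatile=True)
    output = model(data)
    total_loss += criterion(output, targets).item()  # replaces loss.data[0]
print(total_loss)
```
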
21 changes: 8 additions & 13 deletions examples/awd-lstm-lm/model.py
@@ -1,8 +1,6 @@
import torch
import torch.nn as nn
from torch.autograd import Variable

from embed_regularize import embedded_dropout
from torchnlp.nn import LockedDropout
from torchnlp.nn import WeightDrop

@@ -98,9 +96,8 @@ def init_weights(self):
self.decoder.weight.data.uniform_(-initrange, initrange)

def forward(self, input, hidden, return_h=False):
emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0)
#emb = self.idrop(emb)

emb = self.encoder(input)
emb = self.emb_drop(emb)

raw_output = emb
@@ -129,17 +126,15 @@ def forward(self, input, hidden, return_h=False):
def init_hidden(self, bsz):
weight = next(self.parameters()).data
if self.rnn_type == 'LSTM':
return [(Variable(
weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else
(self.ninp if self.tie_weights else self.nhid)).zero_()),
Variable(
weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else
(self.ninp if self.tie_weights else self.nhid)).zero_()))
return [(weight.new_zeros(1, bsz, self.nhid if l != self.nlayers - 1 else
(self.ninp if self.tie_weights else self.nhid)),
weight.new_zeros(1, bsz, self.nhid if l != self.nlayers - 1 else
(self.ninp if self.tie_weights else self.nhid)))
for l in range(self.nlayers)]
elif self.rnn_type == 'QRNN' or self.rnn_type == 'GRU':
return [
Variable(
weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else
(self.ninp if self.tie_weights else self.nhid)).zero_())
weight.new_zeros(1, bsz, self.nhid
if l != self.nlayers - 1 else (self.ninp
if self.tie_weights else self.nhid))
for l in range(self.nlayers)
]
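
`new_zeros` is the 0.4 replacement for `weight.new(...).zero_()` wrapped in `Variable`: it allocates a zero tensor with the same dtype and device as the source tensor. A small illustration of the idiom (shapes are illustrative only):

```python
import torch

weight = torch.randn(100, 20)

# Zero tensor with the same dtype and device as `weight`, no Variable wrapper needed.
hidden = weight.new_zeros(1, 3, 20)
print(hidden.shape, hidden.dtype, hidden.device)
```
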
6 changes: 3 additions & 3 deletions examples/awd-lstm-lm/splitcross.py
@@ -186,11 +186,11 @@ def forward(self, weight, bias, hiddens, targets, verbose=False):
optimizer = torch.optim.SGD(list(embed.parameters()) + list(crit.parameters()), lr=1)

for _ in range(E):
prev = torch.autograd.Variable((torch.rand(N, 1) * 0.999 * V).int().long())
x = torch.autograd.Variable((torch.rand(N, 1) * 0.999 * V).int().long())
prev = (torch.rand(N, 1) * 0.999 * V).int().long()
x = (torch.rand(N, 1) * 0.999 * V).int().long()
y = embed(prev).squeeze()
c = crit(embed.weight, bias, y, x.view(N))
print('Crit', c.exp().data[0])
print('Crit', c.exp().item())

logprobs = crit.logprob(embed.weight, bias, y[:2]).exp()
print(logprobs)
8 changes: 4 additions & 4 deletions examples/awd-lstm-lm/utils.py
@@ -1,9 +1,9 @@
from torch.autograd import Variable
import torch


def repackage_hidden(h):
"""Wraps hidden states in new Variables, to detach them from their history."""
if type(h) == Variable:
return Variable(h.data)
"""Wraps hidden states in new Tensors, to detach them from their history."""
if isinstance(h, torch.Tensor):
return h.detach()
else:
return tuple(repackage_hidden(v) for v in h)
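
To see why the training loop calls this between batches, here is a minimal sketch of truncated backpropagation with `repackage_hidden` (the LSTM, shapes, and data are stand-ins):

```python
import torch
import torch.nn as nn

def repackage_hidden(h):
    """Detach hidden states from their history (as in the function above)."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)

rnn = nn.LSTM(input_size=8, hidden_size=16)
hidden = (torch.zeros(1, 4, 16), torch.zeros(1, 4, 16))

for _ in range(3):                     # three stand-in batches
    batch = torch.randn(5, 4, 8)       # (seq_len, batch, features)
    output, hidden = rnn(batch, hidden)
    loss = output.sum()
    loss.backward()                    # gradients stop at the previous detach
    hidden = repackage_hidden(hidden)  # so the next batch starts a fresh graph
```
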
7 changes: 3 additions & 4 deletions examples/snli/model.py
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
from torch.autograd import Variable


class Bottle(nn.Module):
@@ -33,7 +32,7 @@ def __init__(self, config):
def forward(self, inputs):
batch_size = inputs.size()[1]
state_shape = self.config.n_cells, batch_size, self.config.d_hidden
h0 = c0 = Variable(inputs.data.new(*state_shape).zero_())
h0 = c0 = inputs.detach().new_zeros(*state_shape)
outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(
batch_size, -1)
@@ -62,8 +61,8 @@ def forward(self, premise, hypothesis):
prem_embed = self.embed(premise)
hypo_embed = self.embed(hypothesis)
if self.config.fix_emb:
prem_embed = Variable(prem_embed.data)
hypo_embed = Variable(hypo_embed.data)
prem_embed = prem_embed.detach()
hypo_embed = hypo_embed.detach()
if self.config.projection:
prem_embed = self.relu(self.projection(prem_embed))
hypo_embed = self.relu(self.projection(hypo_embed))
16 changes: 9 additions & 7 deletions examples/snli/train.py
@@ -99,6 +99,7 @@

# switch model to training mode, clear gradient accumulators
model.train()
torch.set_grad_enabled(True)
opt.zero_grad()

iterations += 1
@@ -108,7 +109,7 @@

# calculate accuracy of predictions in the current batch
n_correct += (torch.max(answer,
1)[1].view(label_batch.size()).data == label_batch.data).sum()
1)[1].view(label_batch.size()) == label_batch).sum()
n_total += premise_batch.size()[1]
train_acc = 100. * n_correct / n_total

@@ -123,7 +124,7 @@
if iterations % args.save_every == 0:
snapshot_prefix = os.path.join(args.save_path, 'snapshot')
snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(
train_acc, loss.data[0], iterations)
train_acc, loss.item(), iterations)
torch.save(model, snapshot_path)
for f in glob.glob(snapshot_prefix + '*'):
if f != snapshot_path:
@@ -134,6 +135,7 @@

# switch model to evaluation mode
model.eval()
torch.set_grad_enabled(False)

# calculate accuracy on validation set
n_dev_correct, dev_loss = 0, 0
@@ -149,15 +151,15 @@
label_batch) in enumerate(dev_iterator):
answer = model(premise_batch, hypothesis_batch)
n_dev_correct += (torch.max(answer, 1)[1].view(
label_batch.size()).data == label_batch.data).sum()
label_batch.size()) == label_batch).sum()
dev_loss = criterion(answer, label_batch)
dev_acc = 100. * n_dev_correct / len(dev)

print(
dev_log_template.format(time.time() - start, epoch, iterations, 1 + batch_idx,
len(train_sampler),
100. * (1 + batch_idx) / len(train_sampler), loss.data[0],
dev_loss.data[0], train_acc, dev_acc))
100. * (1 + batch_idx) / len(train_sampler), loss.item(),
dev_loss.item(), train_acc, dev_acc))

# update best validation set accuracy
if dev_acc > best_dev_acc:
@@ -167,7 +169,7 @@
best_dev_acc = dev_acc
snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')
snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(
dev_acc, dev_loss.data[0], iterations)
dev_acc, dev_loss.item(), iterations)

# save model, delete previous 'best_snapshot' files
torch.save(model, snapshot_path)
@@ -181,4 +183,4 @@
print(
log_template.format(time.time() - start, epoch, iterations, 1 + batch_idx,
len(train_sampler), 100. * (1 + batch_idx) / len(train_sampler),
loss.data[0], ' ' * 8, n_correct / n_total * 100, ' ' * 12))
loss.item(), ' ' * 8, n_correct / n_total * 100, ' ' * 12))
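
The repeated edits in this file pair `model.train()`/`model.eval()` with `torch.set_grad_enabled(...)` and read scalars via `.item()`; a compact sketch of that toggle (the model and batch are stand-ins for the SNLI objects):

```python
import torch
import torch.nn as nn

# Stand-ins for the SNLI model and a batch.
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
batch = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,), dtype=torch.long)

# Training step: gradients on.
model.train()
torch.set_grad_enabled(True)
loss = criterion(model(batch), labels)
loss.backward()
print('train loss', loss.item())   # .item() replaces loss.data[0]

# Evaluation: gradients off.
model.eval()
torch.set_grad_enabled(False)
answer = model(batch)
n_correct = (torch.max(answer, 1)[1] == labels).sum().item()
print('correct', n_correct)

torch.set_grad_enabled(True)        # restore the default for any later training
```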
