
Commit

Merge pull request #161 from pskrunner14/develop
Integration test expected CUDA bug resolved
pskrunner14 committed Sep 1, 2018
2 parents 1d6c6e2 + d86e0fa commit f119537
Showing 6 changed files with 32 additions and 14 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -78,3 +78,10 @@ target/
.vagrant

vocab_pickle
+
+# VS Code config
+.vscode/*
+
+# local dev folders
+data/*
+experiment/*
4 changes: 2 additions & 2 deletions Vagrantfile
@@ -11,7 +11,7 @@ Vagrant.configure("2") do |config|
# https://docs.vagrantup.com.

# Every Vagrant development environment requires a box. You can search for
-# boxes at https://atlas.hashicorp.com/search.
+# boxes at https://app.vagrantup.com/boxes/search.
config.vm.box = "ubuntu/trusty64"

# Disable automatic box update checking. If you disable this, then
@@ -48,7 +48,7 @@ Vagrant.configure("2") do |config|
vb.gui = false

# Customize the amount of memory on the VM:
vb.memory = "2048"
vb.memory = "4096"
end
#
# View the documentation for the provider you are using for more
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -64,7 +64,7 @@

# General information about the project.
project = u'pytorch-seq2seq'
-copyright = u'2017, pytorch-seq2seq Contritors'
+copyright = u'2018, pytorch-seq2seq Contributors'
author = u'pytorch-seq2seq Contributors'

# The version info for the project you're documenting, acts as replacement for
2 changes: 1 addition & 1 deletion scripts/integration_test.py
@@ -88,7 +88,7 @@ def len_filter(example):
eos_id=tgt.eos_id, sos_id=tgt.sos_id)
seq2seq = Seq2seq(encoder, decoder)
if torch.cuda.is_available():
-seq2seq.cuda()
+seq2seq = seq2seq.cuda()

for param in seq2seq.parameters():
param.data.uniform_(-0.08, 0.08)
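Note: for an nn.Module, .cuda() moves the parameters in place and returns the module itself, so reassigning the result is an equivalent, more defensive form; it also mirrors the Tensor API, where .cuda() returns a new tensor that must be kept. A minimal sketch of that distinction, using a stand-in nn.Linear model rather than the repository's Seq2seq:

    # Sketch only: nn.Linear stands in for the Seq2seq model used in the script.
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    if torch.cuda.is_available():
        # Module.cuda() moves parameters in place and returns self,
        # so the reassignment is safe and keeps the intent explicit.
        model = model.cuda()

    x = torch.randn(3, 4)
    if torch.cuda.is_available():
        # Tensor.cuda() returns a copy; without reassignment x stays on the CPU.
        x = x.cuda()

    print(model(x).device)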
13 changes: 12 additions & 1 deletion seq2seq/models/TopKDecoder.py
@@ -2,6 +2,8 @@
import torch.nn.functional as F
from torch.autograd import Variable

+CUDA = torch.cuda.is_available()
+
def _inflate(tensor, times, dim):
"""
Given a tensor, 'inflates' it along the given dimension by replicating each slice specified number of times (in-place)
@@ -116,6 +118,10 @@ def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None, functi

# Initialize the input vector
input_var = Variable(torch.transpose(torch.LongTensor([[self.SOS] * batch_size * self.k]), 0, 1))
+if CUDA:
+    self.pos_index = self.pos_index.cuda()
+    input_var = input_var.cuda()
+    sequence_scores = sequence_scores.cuda()

# Store decisions for backtracking
stored_outputs = list()
@@ -223,9 +229,14 @@ def _backtrack(self, nw_output, nw_hidden, predecessors, symbols, scores, b, hid
# the last hidden state of decoding.
if lstm:
    state_size = nw_hidden[0][0].size()
-    h_n = tuple([torch.zeros(state_size), torch.zeros(state_size)])
+    if CUDA:
+        h_n = tuple([torch.zeros(state_size).cuda(), torch.zeros(state_size).cuda()])
+    else:
+        h_n = tuple([torch.zeros(state_size), torch.zeros(state_size)])
else:
    h_n = torch.zeros(nw_hidden[0].size())
+    if CUDA:
+        h_n = h_n.cuda()
l = [[self.rnn.max_length] * self.k for _ in range(b)] # Placeholder for lengths of top-k sequences
# Similar to `h_n`

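Note: the module-level CUDA flag is evaluated once at import time and then used to move the beam search's bookkeeping tensors (input_var, sequence_scores, pos_index, and the zero-initialized h_n) onto the GPU, so they live on the same device as the model's parameters; mixing CPU and CUDA tensors in a single operation raises an error. A minimal sketch of the pattern, with illustrative names (init_beam_state, scores, hidden) that are not part of TopKDecoder:

    # Sketch only: the "check CUDA once, move auxiliary tensors" pattern.
    import torch

    CUDA = torch.cuda.is_available()  # evaluated once, reused everywhere

    def init_beam_state(batch_size, k, hidden_size):
        # Bookkeeping tensors are created on the CPU by default...
        scores = torch.zeros(batch_size * k, 1)
        hidden = torch.zeros(1, batch_size * k, hidden_size)
        if CUDA:
            # ...and must be moved so they match the CUDA model parameters.
            scores = scores.cuda()
            hidden = hidden.cuda()
        return scores, hidden

    scores, hidden = init_beam_state(batch_size=2, k=3, hidden_size=8)
    print(scores.device, hidden.device)

One trade-off of a module-level flag is that device placement is fixed at import time for the whole process, which is why later PyTorch code usually threads an explicit device argument through instead.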
18 changes: 9 additions & 9 deletions tests/test_checkpoint.py
@@ -27,7 +27,7 @@ def test_get_latest_checkpoint(self, mock_listdir):
'2017_05_22_09_47_31',
'2017_05_23_10_47_29']
latest_checkpoint = Checkpoint.get_latest_checkpoint(self.EXP_DIR)
-self.assertEquals(latest_checkpoint,
+self.assertEqual(latest_checkpoint,
os.path.join(self.EXP_DIR,
'checkpoints/2017_05_23_10_47_29'))

@@ -50,15 +50,15 @@ def test_save_checkpoint_calls_torch_save(self, mock_open, mock_dill, mock_torch

path = chk_point.save(self._get_experiment_dir())

-self.assertEquals(2, mock_torch.save.call_count)
+self.assertEqual(2, mock_torch.save.call_count)
mock_torch.save.assert_any_call(state_dict,
os.path.join(chk_point.path, Checkpoint.TRAINER_STATE_NAME))
mock_torch.save.assert_any_call(mock_model,
os.path.join(chk_point.path, Checkpoint.MODEL_NAME))
-self.assertEquals(2, mock_open.call_count)
+self.assertEqual(2, mock_open.call_count)
mock_open.assert_any_call(os.path.join(path, Checkpoint.INPUT_VOCAB_FILE), ANY)
mock_open.assert_any_call(os.path.join(path, Checkpoint.OUTPUT_VOCAB_FILE), ANY)
-self.assertEquals(2, mock_dill.dump.call_count)
+self.assertEqual(2, mock_dill.dump.call_count)
mock_dill.dump.assert_any_call(mock_vocab,
mock_open.return_value.__enter__.return_value)

@@ -80,11 +80,11 @@ def test_load(self, mock_open, mock_dill, mock_torch):
mock_torch.load.assert_any_call(
os.path.join("mock_checkpoint_path", Checkpoint.MODEL_NAME))

-self.assertEquals(loaded_chk_point.epoch, torch_dict['epoch'])
-self.assertEquals(loaded_chk_point.optimizer, torch_dict['optimizer'])
-self.assertEquals(loaded_chk_point.step, torch_dict['step'])
-self.assertEquals(loaded_chk_point.input_vocab, dummy_vocabulary)
-self.assertEquals(loaded_chk_point.output_vocab, dummy_vocabulary)
+self.assertEqual(loaded_chk_point.epoch, torch_dict['epoch'])
+self.assertEqual(loaded_chk_point.optimizer, torch_dict['optimizer'])
+self.assertEqual(loaded_chk_point.step, torch_dict['step'])
+self.assertEqual(loaded_chk_point.input_vocab, dummy_vocabulary)
+self.assertEqual(loaded_chk_point.output_vocab, dummy_vocabulary)

def _get_experiment_dir(self):
root_dir = os.path.dirname(os.path.realpath(__file__))
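Note: assertEquals is a deprecated alias of assertEqual in unittest; it emits a DeprecationWarning and has been removed in recent Python releases, while assertEqual behaves identically. A minimal illustration (a stand-alone test case, not taken from the repository):

    # Sketch only: the supported spelling is assertEqual.
    import unittest

    class ExampleTest(unittest.TestCase):
        def test_addition(self):
            self.assertEqual(2 + 2, 4)

    if __name__ == '__main__':
        unittest.main()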
