Skip to content

Commit

Permalink
Merge remote-tracking branch 'zeynepakkalyoncu/master' into dogfood
Browse files — browse the repository at this point in the history
  • Loading branch information
achyudh committed Dec 19, 2018
2 parents f94b673 + 308d48d commit 34c4dc7
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 26 deletions.
Binary file modified dist/tardis-0.1-py3.6.egg
Binary file not shown.
1 change: 1 addition & 0 deletions lib/data/fetch.py
Expand Up @@ -97,6 +97,7 @@ def en_vi(path, source_vocab=None, target_vocab=None, reverse=False, replace_unk
print("Converting words to indices for", splits, "split...")
encoder_input_data, decoder_input_data, decoder_target_data = build_indices(source_data, target_data, source_vocab,
target_vocab, one_hot)

if splits.lower() == 'test':
return encoder_input_data, decoder_input_data, decoder_target_data, raw_target_data, source_vocab, target_vocab
else:
Expand Down
37 changes: 14 additions & 23 deletions lib/model/__main__.py
@@ -1,19 +1,13 @@
import os
import socket
from copy import deepcopy
import multiprocessing

import time

from keras.backend.tensorflow_backend import set_session
import tensorflow as tf

from elephas.spark_model import SparkModel
from elephas.utils.rdd_utils import to_simple_rdd

from contextlib import contextmanager
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

from elephas.utils.rdd_utils import to_simple_rdd
from elephas.spark_model import SparkModel

from keras.callbacks import ModelCheckpoint

Expand Down Expand Up @@ -118,29 +112,26 @@
model = Seq2Seq(model_config)

if args.ensemble:
conf = SparkConf().setAppName('Tardis').setMaster('local[*]').set('spark.executor.instances', '4') #.set('spark.driver.allowMultipleContexts', 'true')
# sc = SparkContext.getOrCreate(conf=conf)
sc = SparkContext(conf=conf)
conf = SparkConf().setAppName('tardis').setMaster('local')
sc = SparkContext.getOrCreate(conf=conf)

model = SparkModel(model.model, frequency='epoch', mode='asynchronous') # Distributed ensemble
# TODO: fix
train_input = np.dstack((encoder_train_input, decoder_train_input))
rdd = to_simple_rdd(sc, train_input, decoder_train_target)

# train_pairs = [(x, y) for x, y in zip([encoder_train_input, decoder_train_input], decoder_train_target)]
# train_rdd = sc.parallelize(train_pairs, model_config.num_workers)
encoder_train_rdd = sc.parallelize(encoder_train_input)
decoder_train_rdd = sc.parallelize(decoder_train_input)
decoder_train_target = sc.parallelize(decoder_train_target)

train_rdd = to_simple_rdd(sc, [encoder_train_input, decoder_train_input], decoder_train_target)
model = Seq2Seq(model_config)
spark_model = SparkModel(model.model, frequenc='epoch', mode='synchronous')

# test_pairs = [(x, y) for x, y in zip([encoder_test_input, decoder_test_input], raw_test_target)]
# test_rdd = sc.parallelize(test_pairs, model_config.num_workers)

# TODO: fix - multiple context!
model.fit(train_rdd,
spark_model.fit(train_rdd,
batch_size=model_config.batch_size,
epochs=model_config.epochs,
validation_split=0.20,
verbose=1)

sc.stop()

else:
model.train_generator(training_generator, validation_generator)
model.evaluate(encoder_test_input, raw_test_target)
3 changes: 0 additions & 3 deletions lib/model/seq2seq.py
Expand Up @@ -15,9 +15,6 @@


class Seq2Seq:
config = None
model = None

def __init__(self, config):
self.config = config
recurrent_unit = self.config.recurrent_unit.lower()
Expand Down

0 comments on commit 34c4dc7

Please sign in to comment.