Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 21 changed files with 2,660 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import sys | ||
import numpy as np | ||
import sklearn.metrics as metrics | ||
import argparse | ||
import time | ||
import json | ||
import importlib | ||
|
||
print "==> parsing input arguments" | ||
parser = argparse.ArgumentParser() | ||
|
||
# TODO: add argument to choose training set | ||
parser.add_argument('--network', type=str, default="network_batch", help='embeding size (50, 100, 200, 300 only)') | ||
parser.add_argument('--epochs', type=int, default=500, help='number of epochs to train') | ||
parser.add_argument('--load_state', type=str, default="", help='state file path') | ||
parser.add_argument('--mode', type=str, default="train", help='mode: train/test/test_on_train') | ||
parser.add_argument('--batch_size', type=int, default=32, help='no commment') | ||
parser.add_argument('--l2', type=float, default=0, help='L2 regularization') | ||
parser.add_argument('--log_every', type=int, default=100, help='print information every x iteration') | ||
parser.add_argument('--save_every', type=int, default=50000, help='save state every x iteration') | ||
parser.add_argument('--prefix', type=str, default="", help='optional prefix of network name') | ||
parser.add_argument('--dropout', type=float, default=0.0, help='dropout rate (between 0 and 1)') | ||
parser.add_argument('--no-batch_norm', dest="batch_norm", action='store_false', help='batch normalization') | ||
parser.add_argument('--rnn_num_units', type=int, default=500, help='number of hidden units if the network is RNN') | ||
parser.add_argument('--equal_split', type=bool, default=False, help='use trainEqual.csv and valEqual.csv') | ||
parser.add_argument('--forward_cnt', type=int, default=1, help='if forward pass is nondeterministic, then how many forward passes are averaged') | ||
|
||
parser.set_defaults(batch_norm=True) | ||
args = parser.parse_args() | ||
print args | ||
|
||
# Select the train/validation list files.  The "equal split" variants
# presumably contain class-balanced example lists -- TODO confirm against
# the data-preparation step.
# NOTE(review): dataset paths are hard-coded to one specific machine.
if (args.equal_split):
    train_listfile = open("/mnt/hdd615/Hrayr/Spoken-language-identification/trainEqual.csv", "r")
    test_listfile = open("/mnt/hdd615/Hrayr/Spoken-language-identification/valEqual.csv", "r")
else:
    train_listfile = open("/mnt/hdd615/Hrayr/Spoken-language-identification/trainingDataNew.csv", "r")
    test_listfile = open("/mnt/hdd615/Hrayr/Spoken-language-identification/valDataNew.csv", "r")

# One raw "name,label" line per example (format consumed by read_batch
# in the network classes, which split each line on ',').
train_list_raw = train_listfile.readlines()
test_list_raw = test_listfile.readlines()

print "==> %d training examples" % len(train_list_raw)
print "==> %d validation examples" % len(test_list_raw)

train_listfile.close()
test_listfile.close()

# Forward every parsed CLI option, plus the data lists and the folder of
# spectrogram PNGs, to the network constructor as keyword arguments.
args_dict = dict(args._get_kwargs())
args_dict['train_list_raw'] = train_list_raw
args_dict['test_list_raw'] = test_list_raw
args_dict['png_folder'] = "/mnt/hdd615/Hrayr/Spoken-language-identification/train/png/"
|
||
|
||
|
||
# Dynamically import the implementation selected with --network; the module
# must live in the networks package and expose a Network class.
print "==> using network %s" % args.network
network_module = importlib.import_module("networks." + args.network)
network = network_module.Network(**args_dict)

# Descriptive name used for checkpoint files:
#   <prefix><network-specific name>.bs<batch_size>[.bn][.d<dropout>]
network_name = args.prefix + '%s.bs%d%s%s' % (
    network.say_name(),
    args.batch_size,
    ".bn" if args.batch_norm else "",
    (".d" + str(args.dropout)) if args.dropout>0 else "")

print "==> network_name:", network_name

# Resume from a saved checkpoint if requested; load_state returns the epoch
# the state was saved at, so training continues from the following epoch.
start_epoch = 0
if args.load_state != "":
    start_epoch = network.load_state(args.load_state) + 1
|
||
def do_epoch(mode, epoch):
    # Run one full pass over the relevant data split and return the average
    # per-batch loss.
    # mode is 'train' or 'test' or 'predict' (or 'predict_on_train');
    # only 'train' updates network parameters (dispatched in network.step).
    y_true = []
    y_pred = []
    avg_loss = 0.0
    prev_time = time.time()

    batches_per_epoch = network.get_batches_per_epoch(mode)
    all_prediction = []  # per-batch probability arrays, used in predict modes only

    for i in range(0, batches_per_epoch):
        step_data = network.step(i, mode)
        prediction = step_data["prediction"]
        answers = step_data["answers"]
        current_loss = step_data["current_loss"]
        log = step_data["log"]

        # NOTE(review): avg_loss accumulates the FIRST forward pass's loss
        # only; the multi-pass average computed below is never folded back in.
        avg_loss += current_loss
        if (mode == "predict" or mode == "predict_on_train"):
            all_prediction.append(prediction)
            # Average forward_cnt stochastic forward passes.  prediction is
            # updated in place (+=, /=), so the array already appended to
            # all_prediction ends up holding the averaged values as well.
            for pass_id in range(args.forward_cnt-1):
                step_data = network.step(i, mode)
                prediction += step_data["prediction"]
                current_loss += step_data["current_loss"]
            prediction /= args.forward_cnt
            current_loss /= args.forward_cnt

        for x in answers:
            y_true.append(x)

        # argmax over the class axis -> one predicted label per example
        for x in prediction.argmax(axis=1):
            y_pred.append(x)

        if ((i + 1) % args.log_every == 0):
            cur_time = time.time()
            print (" %sing: %d.%d / %d \t loss: %3f \t avg_loss: %.5f \t %s \t time: %.2fs" %
                (mode, epoch, (i + 1) * args.batch_size, batches_per_epoch * args.batch_size,
                current_loss, avg_loss / (i + 1), log, cur_time - prev_time))
            prev_time = cur_time

    #print "confusion matrix:"
    #print metrics.confusion_matrix(y_true, y_pred)
    # Accuracy over every example processed this epoch; the denominator
    # equals len(y_true) because each batch contributes batch_size examples.
    accuracy = sum([1 if t == p else 0 for t, p in zip(y_true, y_pred)])
    print "accuracy: %.2f percent" % (accuracy * 100.0 / batches_per_epoch / args.batch_size)

    if (mode == "predict"):
        # Dump the (averaged) class probabilities, one CSV row per example,
        # into a file named after the loaded checkpoint.
        # NOTE(review): assumes the predictions/ directory already exists.
        all_prediction = np.vstack(all_prediction)
        pred_filename = "predictions/" + ("equal_split." if args.equal_split else "") + \
            args.load_state[args.load_state.rfind('/')+1:] + ".csv"
        with open(pred_filename, 'w') as pred_csv:
            for x in all_prediction:
                print >> pred_csv, ",".join([("%.6f" % prob) for prob in x])

    return avg_loss / batches_per_epoch
|
||
|
||
# Entry point: dispatch on --mode.
if args.mode == 'train':
    print "==> training"
    for epoch in range(start_epoch, args.epochs):
        do_epoch('train', epoch)
        test_loss = do_epoch('test', epoch)
        # Checkpoint after every epoch, embedding the validation loss in the
        # file name.  NOTE(review): assumes the states/ directory exists.
        state_name = 'states/%s.epoch%d.test%.5f.state' % (network_name, epoch, test_loss)
        print "==> saving ... %s" % state_name
        network.save_params(state_name, epoch)

elif args.mode == 'test':
    # "test" here means: run prediction over the validation split and write
    # a predictions CSV (see do_epoch's 'predict' branch).
    do_epoch('predict', 0)
elif args.mode == 'test_on_train':
    do_epoch('predict_on_train', 0)
else:
    raise Exception("unknown mode")
Empty file.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import cPickle as pickle | ||
|
||
|
||
class BaseNetwork:
    # Shared persistence / batching / stepping logic for the concrete network
    # implementations.  Subclasses are expected to define:
    #   self.params                       -- list of shared parameter variables
    #   self.train_list_raw, self.test_list_raw, self.batch_size
    #   self.train_fn, self.test_fn       -- compiled functions (data, answers)
    #   read_batch(list_raw, batch_index) -- returns (data, answers)

    def say_name(self):
        # Human-readable identifier used in checkpoint names; subclasses
        # override this.
        return "unknown"


    def save_params(self, file_name, epoch, **kwargs):
        # Pickle all parameter values together with the epoch number.
        # NOTE(review): protocol -1 is a binary protocol but the file is
        # opened in text mode ('w'); fine on POSIX, would corrupt the file
        # on Windows -- consider 'wb'.  Extra kwargs are silently ignored.
        with open(file_name, 'w') as save_file:
            pickle.dump(
                obj = {
                    'params' : [x.get_value() for x in self.params],
                    'epoch' : epoch,
                },
                file = save_file,
                protocol = -1
            )


    def load_state(self, file_name):
        # Restore parameter values written by save_params into self.params
        # (positional pairing).  Returns the epoch the state was saved at.
        print "==> loading state %s" % file_name
        epoch = 0
        with open(file_name, 'r') as load_file:
            # NOTE(review): local name shadows the builtin `dict`.
            dict = pickle.load(load_file)
            loaded_params = dict['params']
            for (x, y) in zip(self.params, loaded_params):
                x.set_value(y)
            epoch = dict['epoch']
        return epoch


    def get_batches_per_epoch(self, mode):
        # Whole batches per epoch for the given mode; a trailing partial
        # batch is dropped (Python 2 integer division).
        if (mode == 'train' or mode == 'predict_on_train'):
            return len(self.train_list_raw) / self.batch_size
        elif (mode == 'test' or mode == 'predict'):
            return len(self.test_list_raw) / self.batch_size
        else:
            raise Exception("unknown mode")


    def step(self, batch_index, mode):
        # Process one batch: choose the data list and compiled function for
        # the mode, run it, and package the results.  Only 'train' uses
        # train_fn (which updates parameters); both predict modes reuse
        # test_fn, differing only in which list they read from.

        if (mode == "train"):
            data, answers = self.read_batch(self.train_list_raw, batch_index)
            theano_fn = self.train_fn
        elif (mode == "test" or mode == "predict"):
            data, answers = self.read_batch(self.test_list_raw, batch_index)
            theano_fn = self.test_fn
        elif (mode == "predict_on_train"):
            data, answers = self.read_batch(self.train_list_raw, batch_index)
            theano_fn = self.test_fn
        else:
            raise Exception("unrecognized mode")

        # Compiled functions return [prediction, loss] (see subclasses'
        # theano.function outputs).  "log" is free-form extra info for the
        # progress printer; empty here.
        ret = theano_fn(data, answers)
        return {"prediction": ret[0],
                "answers": answers,
                "current_loss": ret[1],
                "log": "",
                }
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import random | ||
import numpy as np | ||
|
||
import theano | ||
import theano.tensor as T | ||
|
||
import lasagne | ||
from lasagne import layers | ||
from lasagne.nonlinearities import rectify, softmax, sigmoid, tanh | ||
|
||
import PIL.Image as Image | ||
from base_network import BaseNetwork | ||
|
||
floatX = theano.config.floatX | ||
|
||
|
||
class Network(BaseNetwork):
    # GRU-based classifier over spectrogram images:
    #   input (batch, 858, 256) -> GRU (final hidden state only)
    #   -> dense softmax over 176 classes.
    # Presumably 858 is the time axis and 256 the frequency axis of the
    # PNG spectrograms -- TODO confirm against the data pipeline.

    def __init__(self, train_list_raw, test_list_raw, png_folder, batch_size, l2, mode, rnn_num_units, **kwargs):

        # Any CLI options this architecture does not consume arrive in
        # kwargs and are merely reported.
        print "==> not used params in DMN class:", kwargs.keys()
        self.train_list_raw = train_list_raw
        self.test_list_raw = test_list_raw
        self.png_folder = png_folder
        self.batch_size = batch_size
        self.l2 = l2
        self.mode = mode
        self.num_units = rnn_num_units

        # (batch, time, features) float input; int32 class labels.
        self.input_var = T.tensor3('input_var')
        self.answer_var = T.ivector('answer_var')

        print "==> building network"
        # Random sample batch used only to print layer output shapes below.
        example = np.random.uniform(size=(self.batch_size, 858, 256), low=0.0, high=1.0).astype(np.float32) #########
        answer = np.random.randint(low=0, high=176, size=(self.batch_size,)) #########

        # InputLayer
        network = layers.InputLayer(shape=(None, 858, 256), input_var=self.input_var)
        print layers.get_output(network).eval({self.input_var:example}).shape

        # GRULayer -- only_return_final=True keeps just the last hidden state
        # (batch, num_units) instead of the full sequence.
        network = layers.GRULayer(incoming=network, num_units=self.num_units, only_return_final=True)
        print layers.get_output(network).eval({self.input_var:example}).shape

        # Last layer: classification over the 176 classes
        network = layers.DenseLayer(incoming=network, num_units=176, nonlinearity=softmax)
        print layers.get_output(network).eval({self.input_var:example}).shape

        self.params = layers.get_all_params(network, trainable=True)
        self.prediction = layers.get_output(network)

        # Cross-entropy loss, plus optional L2 penalty over all parameters.
        self.loss_ce = lasagne.objectives.categorical_crossentropy(self.prediction, self.answer_var).mean()
        if (self.l2 > 0):
            self.loss_l2 = self.l2 * lasagne.regularization.regularize_network_params(network,
                                                                                      lasagne.regularization.l2)
        else:
            self.loss_l2 = 0
        self.loss = self.loss_ce + self.loss_l2

        #updates = lasagne.updates.adadelta(self.loss, self.params)
        updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.0005)

        # train_fn is only compiled when training (compilation is expensive);
        # test_fn is always available and is also used by the predict modes.
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.input_var, self.answer_var],
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)

        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.input_var, self.answer_var],
                                       outputs=[self.prediction, self.loss])


    def say_name(self):
        # Identifier embedded in checkpoint file names (see training script).
        return "rnn.GRU.num_units%d" % self.num_units


    def read_batch(self, data_raw, batch_index):
        # Load one batch of spectrograms and labels.
        # data_raw: list of "name,label" CSV lines; batch_index selects the
        # contiguous slice [batch_index*batch_size, ...+batch_size).
        # Returns (data, answers): float32 (batch, 858, 256) scaled to
        # [0, 1), and int32 label vector.

        start_index = batch_index * self.batch_size
        end_index = start_index + self.batch_size

        data = np.zeros((self.batch_size, 858, 256), dtype=np.float32)
        answers = []

        for i in range(start_index, end_index):
            answers.append(int(data_raw[i].split(',')[1]))
            name = data_raw[i].split(',')[0]
            path = self.png_folder + name + ".png"
            im = Image.open(path)
            # PNG is stored transposed relative to the network input; divide
            # by 256.0 to map 8-bit pixel values into [0, 1).
            data[i - start_index, :, :] = np.transpose(np.array(im).astype(np.float32) / 256.0)

        answers = np.array(answers, dtype=np.int32)
        return data, answers
|
||
|
Oops, something went wrong.