From 14b2b7f552996a40cb3a42713c118c48f7370c37 Mon Sep 17 00:00:00 2001
From: superjom
Date: Fri, 25 Aug 2017 15:11:01 -0400
Subject: [PATCH 1/6] init

---
 doc/design/python/user_interface.md           | 362 ++++++++++++++++++
 python/paddle/python_wrapper_demo/__init__.py |   4 +
 python/paddle/python_wrapper_demo/block.py    | 135 +++++++
 python/paddle/python_wrapper_demo/common.py   |   7 +
 python/paddle/python_wrapper_demo/layer.py    |  10 +
 .../paddle/python_wrapper_demo/namespace.py   |  63 +++
 python/paddle/python_wrapper_demo/op.py       |  55 +++
 python/paddle/python_wrapper_demo/utils.py    |  59 +++
 python/paddle/python_wrapper_demo/variable.py |  67 ++++
 9 files changed, 762 insertions(+)
 create mode 100644 doc/design/python/user_interface.md
 create mode 100644 python/paddle/python_wrapper_demo/__init__.py
 create mode 100644 python/paddle/python_wrapper_demo/block.py
 create mode 100644 python/paddle/python_wrapper_demo/common.py
 create mode 100644 python/paddle/python_wrapper_demo/layer.py
 create mode 100644 python/paddle/python_wrapper_demo/namespace.py
 create mode 100644 python/paddle/python_wrapper_demo/op.py
 create mode 100644 python/paddle/python_wrapper_demo/utils.py
 create mode 100644 python/paddle/python_wrapper_demo/variable.py

diff --git a/doc/design/python/user_interface.md b/doc/design/python/user_interface.md
new file mode 100644
index 0000000000000..0b42dbe5bfd9a
--- /dev/null
+++ b/doc/design/python/user_interface.md
@@ -0,0 +1,362 @@
+# User Interface Design
+
+## Basic Concepts
+### Variable
+A `Variable` represents shared, persistent state manipulated by a Paddle model program.
+
+Variables are maintained by the `pd.Variable` class,
+each `pd.Variable` represents a tensor whose value can be changed by running ops on it.
+
+A basic way to create a variable is:
+
+```python
+import paddle as pd
+
+v = pd.Variable(shape=[20, 20])
+```
+
+To make it more convenient to share a variable, each `pd.Variable` has a name;
+one can use a name to get or create a `pd.Variable` by calling `pd.get_variable`, for example:
+
+```python
+# equivalent to the variable created above
+v = pd.get_variable(name="v", shape=[20, 20])
+```
+
+By default, Variables are model parameters and will be updated after the network's back propagation.
+
+One can freeze a variable by setting `trainable` to `False`:
+
+```python
+v = pd.Variable(shape=[20, 20], trainable=False)
+```
+
+Some initialization strategies may be applied to variables; for example, we may initialize a variable with zeros or Gaussian noise:
+
+```python
+v = pd.Variable(shape=[20, 20], initializer=pd.zero_initializer())
+z = pd.Variable(shape=[20, 20], initializer=pd.gaussian_initializer(mean=0., std=0.1))
+```
+
+To get the value of a variable, one can call
+
+```python
+print v.val()
+```
+
+
+### Block
+Paddle uses a `Block` to represent and execute a user's program;
+it is a basic concept users meet when writing a Paddle program.
+
+In computer programming, a block is a lexical structure of source code which is grouped together.
+In most programming languages, blocks are useful when defining a function or some conditional statements such as `if-else`, `while`.
+
+Similarly, the function of `pd.Block` in Paddle is to enable a group of operators to be treated as if it were one operator, which makes the declaration of `if_else_op` or `RNNOp` simpler; Python's `with` statement is used to make the code look much like a block.
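+
+As a minimal sketch of the idea (illustrative only; `pd.add_two` is reused from the examples below), the operators created under the `with` statement are grouped into one block:
+
+```python
+import paddle as pd
+
+with pd.Block('demo'):
+    a = pd.Variable(shape=[20, 20])
+    b = pd.Variable(shape=[20, 20])
+    c = pd.add_two(a, b)  # appended to this block, not executed yet
+```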
+
+For example, when defining a `RNNOp`, we can use `pd.Block` to help configure a step network:
+
+```python
+v = some_op()
+m_boot = some_op()
+
+W = pd.Variable(shape=[20, 20])
+U = pd.Variable(shape=[20, 20])
+
+rnn0 = RNNOp()
+with rnn0.stepnet() as net:
+    # declare stepnet's inputs
+    x = net.add_input(v)
+    # declare memories
+    h = net.add_memory(m_boot)
+
+    fc_out = pd.matmul(W, x)
+    hidden_out = pd.matmul(U, h)
+    sum = pd.add_two(fc_out, hidden_out)
+    act = pd.sigmoid(sum)
+
+    # declare stepnet's outputs
+    net.add_output(act, hidden_out)
+
+acts, hs = rnn0()
+```
+
+The operators inside the `with` statement define the RNN's step network,
+and will be put into a `pd.Block`.
+
+Another example is the definition of `if_else_op`:
+
+```python
+# v0 is an output of some_op
+v0 = some_op()
+v1 = some_op()
+
+ifelseop = pd.if_else_op()
+with ifelseop.true_block() as net:
+    x0, x1 = net.add_input(v0, v1)
+
+    y = pd.fc(x0)
+    z = pd.add_two(x1, y)
+
+    net.add_output(z)
+
+with ifelseop.false_block() as net:
+    x0, x1 = net.add_input(v0, v1)
+
+    y = pd.add_two(x0, x1)
+
+    net.add_output(y)
+
+# output of ifelseop
+out = ifelseop()
+```
+
+In most cases, users need not create a `pd.Block` directly, but it is the basis of a Paddle program:
+
+- a user's program is stored in `pd.Block`
+- when we want to run the code, we just need to execute the corresponding `pd.Block`
+
+A `pd.Block` can have its own namespace, which makes it possible to hide its local variables from other blocks.
+
+```python
+W = pd.Variable(shape=[20, 20])
+
+# a and b are outputs of some_op
+a = some_op()
+b = some_op()
+
+with pd.Block('namespace0'):
+    # W is a local variable and has its own value
+    W = pd.Variable(shape=[20, 20])
+    x = pd.matmul(W, a)
+    y = x + b
+
+with pd.Block('namespace1'):
+    # W is the global variable
+    z = pd.matmul(W, a)
+
+# g uses local variables from both namespace0 and namespace1
+g = pd.add_two(y, z)
+```
+
+### Op (short for Operator)
+`Op` defines the basic operation unit of the optimized computation graph in Paddle; an `Op` has several input and output variables, and some attributes.
+
+Take `pd.matmul` for example; one can use it like this:
+
+```python
+out = pd.matmul(a, b)
+```
+
+which means that the operator `pd.matmul` takes two variables `a` and `b` as input
+and returns a variable `out`.
+
+### Layer
+`Layer` defines a more complex operation which may combine several `Op`s; its usage is the same as that of an `Op`.
+
+Take `pd.fc` for example; one can use it like this:
+
+```python
+out = pd.fc(x, param_names=['W'])
+```
+
+which means that `pd.fc` takes a variable `x` and sets its `param_names` attribute to `['W']`,
+which will determine the names of its parameters.
+
+Both `Op` and `Layer` are appended to the current `pd.Block` when they are created,
+so there is a sequence of Ops/Layers in the `pd.Block`;
+when the `pd.Block` is executed, all the Ops/Layers in it are called in order.
+
+### Special Ops
+#### Initializer Ops
+These ops initialize variables; for example, we may have
+
+- `pd.zero_initializer()`
+- `pd.gaussian_random_initializer(mean, std)`
+
+Each trainable variable has an initializer Op.
+
+#### Optimizer Ops
+These ops will help to optimize trainable variables after backward propagation finishes;
+each variable will have an optimizer.
+
+## Compatible with V2 Syntax
+
+## Some Demos
+### MNIST Task Demo
+
+```python
+import paddle as pd
+
+# the first dimension of the shape is None, which means the batch size is not yet known.
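+# (illustrative note, not part of the original design: feeding a batch of 32
+# images below would give `image` the concrete runtime shape [32, 128])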
+image = pd.Variable(shape=[None, 128])
+label = pd.Variable(shape=[None, 1])
+
+# network config
+W1 = pd.Variable('W1', shape=[128, 64])
+
+fc_out = pd.matmul(image, W1)
+prediction = pd.softmax(fc_out, size=10)
+
+cost = pd.cross_entropy(prediction, label)
+
+optimizer = pd.SGDOptimizer().minimize(cost)
+
+
+# training details
+def data_provider(path):
+    images = []
+    labels = []
+    with open(path) as f:
+        for no, line in enumerate(f):
+            fs = line.split('\t')
+            assert len(fs) == 2
+            image_record = map(int, fs[0].split())
+            label_record = [int(fs[1])]
+            images.append(image_record)
+            labels.append(label_record)
+            if no > 0 and no % 100 == 0:
+                yield np.array(images), np.array(labels)
+                images = []
+                labels = []
+
+
+for pass_no in range(100):
+    for batch_no, batch in enumerate(data_provider('./data.txt')):
+        # train mode
+        _, cost_ = pd.eval(
+            [optimizer, cost], feeds={image: batch[0],
+                                      label: batch[1]})
+        print '%dth pass train cost: %f' % (pass_no, cost_)
+        # test mode
+        if batch_no > 0 and batch_no % 10 == 0:
+            cost_ = pd.eval(cost)
+            print '%dth pass test cost: %f' % (pass_no, cost_)
+```
+
+### GAN Task Demo
+
+```python
+import paddle as pd
+
+# input variable whose batch size is unknown now
+X = pd.Variable(shape=[None, 128])
+
+# Discriminator Net
+# define parameters
+
+# Generator Net
+Z = pd.data(pd.float_vector(100))
+
+# the Generator's parameters (created inside generator() below)
+theta_G = [G_W1, G_W2, G_b1, G_b2]
+
+
+def sample_Z(m, n):
+    return np.random.uniform(-1., 1., size=[m, n])
+
+
+def discriminator(x):
+    # use block with namespace to hide local variables
+    with pd.Block('discriminator') as block:
+        # declare model parameters
+        W1 = pd.get_variable(
+            'W1',
+            shape=[784, 128],
+            initializer=pd.gaussian_random_initializer(std=0.1),
+            reuse=True)
+        b1 = pd.get_variable(
+            'b1', data=np.zeros(128),
+            reuse=True
+        )  # a variable also supports initialization from numpy data
+        W2 = pd.get_variable('W2', data=np.random.rand(128, 1),
+                             reuse=True)
+        b2 = pd.get_variable('b2', data=np.zeros(128),
+                             reuse=True)
+
+        # network config
+        h1 = pd.relu(pd.matmul(x, W1) + b1)
+        fake = pd.matmul(h1, W2) + b2
+        prob = pd.sigmoid(fake)
+        return prob, fake
+
+
+# the Discriminator's parameters (created inside discriminator() above)
+theta_D = [D_W1, D_b1, D_W2, D_b2]
+
+
+def generator(z):
+    with pd.Block('generator') as block:
+        # declare model parameters
+        W1 = pd.get_variable(
+            'W1',
+            shape=[784, 128],
+            initializer=pd.gaussian_random_initializer())
+        b1 = pd.get_variable(
+            'b1', data=np.zeros(128)
+        )  # a variable also supports initialization from numpy data
+        W2 = pd.get_variable('W2', data=np.random.rand(128, 1))
+        b2 = pd.get_variable('b2', data=np.zeros(128))
+
+        # network config
+        h1 = pd.relu(pd.matmul(z, W1) + b1)
+        log_prob = pd.matmul(h1, W2) + b2
+        prob = pd.sigmoid(log_prob)
+        return prob
+
+
+# a mini-batch of 1. labels, as probability 100%
+ones_label = pd.Variable(shape=[None, 1])
+# a mini-batch of 0. labels, as probability 0%
+zeros_label = pd.Variable(shape=[None, 1])
+
+# model config
+G_sample = generator(Z)
+D_real_prob, D_real_image = discriminator(X)
+D_fake_prob, D_fake_image = discriminator(G_sample)
+
+D_loss_real = pd.reduce_mean(
+    pd.cross_entropy(data=D_real_prob, label=ones_label))
+D_loss_fake = pd.reduce_mean(
+    pd.cross_entropy(data=D_fake_prob, label=zeros_label))
+D_loss = D_loss_real + D_loss_fake
+
+G_loss = pd.reduce_mean(pd.cross_entropy(data=D_fake_prob, label=ones_label))
+
+D_solver = pd.AdamOptimizer().minimize(D_loss, var_list=theta_D)
+G_solver = pd.AdamOptimizer().minimize(G_loss, var_list=theta_G)
+
+# init all parameters
+initializer = pd.variable_initializer()
+# also ok: initializer = pd.variable_initializer(vars=theta_D + theta_G)
+pd.eval(initializer)
+
+
+def data_provider(path):
+    # ...
+    yield batch
+
+
+for i in range(10000):
+    for batch_no, batch in enumerate(data_provider('train_data.txt')):
+        # train the Discriminator first
+        _, D_loss_cur = pd.eval(
+            [D_solver, D_loss],
+            feeds={
+                X: batch,
+                Z: sample_Z(batch.size, 100),
+                ones_label: np.ones([batch.size, 1]),
+                zeros_label: np.zeros([batch.size, 1])
+            })
+        # get the Generator's fake samples
+        samples = pd.eval(G_sample, feeds={Z: sample_Z(16, 100)})
+
+        # then train the Generator
+        _, G_loss_cur = pd.eval(
+            [G_solver, G_loss],
+            feeds={
+                Z: sample_Z(batch.size, 100),
+                ones_label: np.ones([batch.size, 1]),
+                zeros_label: np.zeros([batch.size, 1])
+            })
+
+        if batch_no % 100 == 0:
+            logger.info("batch %d, D loss: %f" % (batch_no, D_loss_cur))
+            logger.info("batch %d, G loss: %f" % (batch_no, G_loss_cur))
+```
+
diff --git a/python/paddle/python_wrapper_demo/__init__.py b/python/paddle/python_wrapper_demo/__init__.py
new file mode 100644
index 0000000000000..2fc6b5f044732
--- /dev/null
+++ b/python/paddle/python_wrapper_demo/__init__.py
@@ -0,0 +1,4 @@
+from block import *
+from variable import *
+from op import *
+from layer import *
diff --git a/python/paddle/python_wrapper_demo/block.py b/python/paddle/python_wrapper_demo/block.py
new file mode 100644
index 0000000000000..39bc5803d1e6f
--- /dev/null
+++ b/python/paddle/python_wrapper_demo/block.py
@@ -0,0 +1,135 @@
+import paddle.v2.framework.core as core
+
+from namespace import Namespace
+from utils import DependenceGraph
+from variable import Variable
+
+__all__ = [
+    'Block',
+    'g_block',
+    'block',
+    'eval',
+]
+
+
+class Block(object):
+    '''
+    Block is the concept of a code block; it holds a sequence of local
+    Variables and Operators.
+    '''
+
+    def __init__(self):
+        self.cmds = []
+
+    def append(self, cmd):
+        '''
+        cmd: Block or Op or Layer
+        '''
+        self.cmds.append(cmd)
+
+    def execute(self):
+        '''
+        Execute this block; this will run all the operators and update the
+        corresponding output variables.
+        '''
+        self._build_nn()
+        self.net.run()
+
+    def _build_nn(self):
+        self.net = core.Net.create()
+        ops = self.__extract_op_from_block(self.cmds)
+        for op in ops:
+            self.net.append_op(op)
+        self.net.complete_add_op(True)
+
+    def __extract_op_from_block(self, cmds):
+        ops = []
+        for cmd in cmds:
+            if type(cmd) is Block:
+                # recursively flatten the ops of a nested block
+                child_ops = self.__extract_op_from_block(cmd.cmds)
+                ops += child_ops
+            else:
+                ops.append(cmd)
+        return ops
+
+
+g_block = Block()
+
+
+# TODO: this needs to be renamed
+class block_guard(object):
+    '''
+    A wrapper for Block, which automatically changes g_block and Namespace.
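+
+    On `__enter__` it pushes a fresh current block (and namespace); on
+    `__exit__` it restores the previous ones.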
+
+    Usage:
+
+        import paddle as pd
+
+        with pd.block() as block:
+            a = pd.Variable()
+            b = pd.Variable()
+            c = pd.add_two(a, b)
+        block.execute()
+    '''
+    cur_block = g_block
+    last_block = None
+    counter = 0
+
+    def __init__(self, namespace='', block=None, execute_immediately=True):
+        '''
+        namespace: str
+            the current block's namespace; if left empty, the parent's
+            namespace will be used.
+        block: Block
+            the current block; if None, a new Block will be created.
+        execute_immediately: bool
+            if True, all the operators of this block will be inserted into
+            the parent's block immediately; if False, this block is
+            independent of the parent's block.
+        '''
+        self.namespace = namespace if namespace else str(
+            block_guard.inc_counter())
+        self.block = block if block else Block()
+        self.execute_immediately = execute_immediately
+
+    def __enter__(self):
+        Namespace(self.namespace).enter()
+        block_guard.last_block = block_guard.cur_block
+        block_guard.cur_block = self.block
+
+        if self.execute_immediately:
+            block_guard.last_block.append(block_guard.cur_block)
+        return self.block
+
+    def __exit__(self, type, value, traceback):
+        Namespace.cur().exit()
+        block_guard.cur_block = block_guard.last_block
+
+    @staticmethod
+    def inc_counter():
+        c = block_guard.counter
+        block_guard.counter += 1
+        return c
+
+
+# alias so that the exported name `block` in __all__ refers to the guard
+block = block_guard
+
+
+def eval(fetches, block=g_block):
+    '''
+    fetches: list of Variable
+    block: Block
+
+    Evaluate all the variables in `fetches`. In detail, this traces the
+    sub-graph that the variables in `fetches` depend on, compiles a new
+    corresponding block, and executes it.
+
+    Usage:
+
+        # use g_block by default
+        var_value = pd.eval([var])
+
+        # use a specific block
+        var_value = pd.eval([var], block=a_block)
+    '''
+    assert all(isinstance(_, Variable) for _ in fetches), "fetches should be Variables"
+    graph = DependenceGraph(block.cmds)
+    op_or_layers = graph.DFS_with_targets(fetches)
+
+    with block_guard(execute_immediately=False) as B:
+        for cmd in op_or_layers:
+            B.append(cmd)
+        B.execute()
+    # return python values
+    return [v.val() for v in fetches]
diff --git a/python/paddle/python_wrapper_demo/common.py b/python/paddle/python_wrapper_demo/common.py
new file mode 100644
index 0000000000000..8caa023b51559
--- /dev/null
+++ b/python/paddle/python_wrapper_demo/common.py
@@ -0,0 +1,7 @@
+import logging
+import paddle.v2.framework.core as core
+
+g_scope = core.Scope()
+
+logger = logging.getLogger("paddle python")
+logger.setLevel(logging.INFO)
diff --git a/python/paddle/python_wrapper_demo/layer.py b/python/paddle/python_wrapper_demo/layer.py
new file mode 100644
index 0000000000000..017ac06a854b1
--- /dev/null
+++ b/python/paddle/python_wrapper_demo/layer.py
@@ -0,0 +1,10 @@
+class Layer(object):
+    def __init__(self, type, *args, **kwargs):
+        self.inputs = {}
+        self.outputs = {}
+
+    def __hash__(self):
+        raise NotImplementedError
+
+    def run(self):
+        raise NotImplementedError
diff --git a/python/paddle/python_wrapper_demo/namespace.py b/python/paddle/python_wrapper_demo/namespace.py
new file mode 100644
index 0000000000000..f514d7241ee43
--- /dev/null
+++ b/python/paddle/python_wrapper_demo/namespace.py
@@ -0,0 +1,63 @@
+import os
+from common import g_scope
+
+
+class Namespace(object):
+    '''
+    Namespace is similar to tf.variable_scope; it helps to store the local
+    variables of different `Block`s in one Scope and makes it possible
+    to reference a Variable across different Blocks.
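+
+    A usage sketch (the names 'outer' and 'inner' are illustrative only):
+
+        with Namespace('outer'):
+            with Namespace('inner'):
+                print Namespace.cur()  # -> <Namespace outer/inner>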
+    '''
+
+    stack = []
+
+    def __init__(self, name):
+        '''
+        name: str
+        '''
+        self.name = name
+        parents = [n.name for n in Namespace.stack]
+        self.prefix = os.path.join(*parents) if parents else ''
+
+    def sub_namespace(self, name):
+        prefix = os.path.join(self.name, name)
+        return Namespace(prefix)
+
+    def parent_namespace(self):
+        if not Namespace.stack:
+            return None
+        return Namespace.stack[-1]
+
+    def getname(self, name):
+        return os.path.join(self.prefix, self.name, name)
+
+    @staticmethod
+    def cur():
+        if Namespace.stack:
+            return Namespace.stack[-1]
+        return Namespace('')
+
+    def enter(self):
+        Namespace.stack.append(self)
+
+    def exit(self):
+        Namespace.stack.pop()
+
+    def __repr__(self):
+        return "<Namespace %s>" % os.path.join(self.prefix, self.name)
+
+    def __enter__(self):
+        self.enter()
+
+    def __exit__(self, type, value, traceback):
+        self.exit()
+
+
+if __name__ == '__main__':
+    with Namespace('apple'):
+        print Namespace.cur()
+
+        with Namespace('beer'):
+            print Namespace.cur()
+        print 'should be apple', Namespace.cur()
+    print 'should be empty', Namespace.cur()
diff --git a/python/paddle/python_wrapper_demo/op.py b/python/paddle/python_wrapper_demo/op.py
new file mode 100644
index 0000000000000..9d8adb138b7a5
--- /dev/null
+++ b/python/paddle/python_wrapper_demo/op.py
@@ -0,0 +1,55 @@
+import paddle.v2.framework.core as core
+from paddle.v2.framework.op import Operator
+from variable import Variable
+
+
+class Op(object):
+    '''
+    Operator wrapper for core operators.
+    '''
+
+    def __init__(self, type):
+        self.type = type
+        self.inputs = {}
+        self.outputs = {}
+
+    def __call__(self, *args, **kwargs):
+        self._prepare_inputs(args, kwargs)
+        self._prepare_outputs()
+        self._extract_str_args_for_op()
+        self._create_op()
+
+    def _prepare_inputs(self, args, kwargs):
+        for idx, input_name in enumerate(
+                Operator.get_op_input_names(self.type)):
+            if idx >= len(args):
+                # positional arguments are exhausted; fall back to keywords
+                self.inputs[input_name] = kwargs.pop(input_name)
+                continue
+            self.inputs[input_name] = args[idx]
+        # the remaining keyword arguments (e.g. attributes) are kept as-is
+        for k, v in kwargs.items():
+            self.inputs[k] = v
+
+    def _prepare_outputs(self):
+        for out in Operator.get_op_output_names(self.type):
+            # generate a unique variable name for each output slot
+            name = "%s-%d" % (out, Variable.counter)
+            var = Variable(name)
+            self.outputs[out] = var
+
+    def _extract_str_args_for_op(self):
+        self.str_arg = {}
+        for k, v in self.inputs.items():
+            self.str_arg[k] = v.name if isinstance(v, Variable) else v
+
+        for k, v in self.outputs.items():
+            self.str_arg[k] = v.name
+
+    def _create_op(self):
+        self.op = Operator(self.type, **self.str_arg)
+
+    def __hash__(self):
+        return hash("%s-%s" % (self.type, self.str_arg))
+
+    def __repr__(self):
+        return "<op {type} {args}>".format(
+            type=self.type,
+            args=' '.join('%s->%s' % (k, v) for k, v in self.str_arg.items()))
diff --git a/python/paddle/python_wrapper_demo/utils.py b/python/paddle/python_wrapper_demo/utils.py
new file mode 100644
index 0000000000000..b10fe103e58fa
--- /dev/null
+++ b/python/paddle/python_wrapper_demo/utils.py
@@ -0,0 +1,59 @@
+from layer import Layer
+from op import Op
+from variable import Variable
+
+
+class DependenceGraph(object):
+    def __init__(self, nodes):
+        '''
+        nodes: list of Variable, Op or Layer
+        '''
+        # node -> set of nodes it depends on
+        self.edges = {}
+        self.nodes = nodes
+        self._build()
+
+    def DFS_with_targets(self, targets, filter_type=(Op, Layer)):
+        '''
+        targets: list of Variable
+        '''
+        visited = set()
+        visited_nodes = []
+
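+        # The edges built in _build() point from each node to the nodes it
+        # depends on, so a DFS from the fetch targets visits exactly the
+        # sub-graph needed to compute them; the filter below then keeps only
+        # the executable nodes (ops and layers).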
+        stack = list(targets)  # copy, so the caller's list is not mutated
+        while stack:
+            cur = stack.pop()
+            if cur not in visited:
+                visited.add(cur)
+                visited_nodes.append(cur)
+                stack.extend(self.edges.get(cur, []))
+
+        visited_nodes = filter(
+            lambda _: any(isinstance(_, t) for t in filter_type),
+            visited_nodes)
+
+        return visited_nodes
+
+    def _build(self):
+        assert not self.edges, "graph can't be built more than once"
+        for node in self.nodes:
+            if not isinstance(node, Op):
+                continue
+            # op -> its input variables
+            for input in node.inputs.values():
+                if node not in self.edges:
+                    self.edges[node] = set()
+                self.edges[node].add(input)
+            # output variable -> the op that produces it
+            for output in node.outputs.values():
+                if output not in self.edges:
+                    self.edges[output] = set()
+                self.edges[output].add(node)
diff --git a/python/paddle/python_wrapper_demo/variable.py b/python/paddle/python_wrapper_demo/variable.py
new file mode 100644
index 0000000000000..693abfdf28548
--- /dev/null
+++ b/python/paddle/python_wrapper_demo/variable.py
@@ -0,0 +1,67 @@
+from namespace import Namespace
+from common import g_scope
+
+
+class Variable(object):
+    '''
+    Variable is the data type of an Operator's inputs and outputs.
+    '''
+    counter = 0
+    __varset__ = set()
+
+    def __init__(self,
+                 name=None,
+                 shape=[],
+                 data=None,
+                 initializer=None,
+                 scope=g_scope,
+                 trainable=True,
+                 learning_rate=0.01):
+        '''
+        name: str
+            name of this variable; a unique name will be generated if left None.
+        shape: list of int
+            shape of the tensor stored in the Variable.
+        initializer: Op
+            the initializer op used to initialize this Variable.
+        data: numpy array
+            initialize this variable from numpy data directly.
+        trainable: bool
+            whether this variable can be updated by optimizers.
+        learning_rate: float
+            the learning rate used when an optimizer updates this variable.
+        '''
+        self.name = self._gen_unique_name(name)
+        assert shape, "shape of Variable should be set"
+        self.shape = shape
+        self.scope = scope
+        # TODO(jacquesqiao) this state can be used by optimizers to determine
+        # the variables that need to be updated.
+        self.trainable = trainable
+        self.learning_rate = learning_rate
+
+    def val(self):
+        '''
+        get the python value of this Variable.
+        '''
+        return self._tensor.as_numpy()
+
+    def __repr__(self):
+        return "<Variable %s>" % self.name
+
+    def _create_core_variable(self):
+        self._core_var = self.scope.new_var(self.name)
+        self._tensor = self._core_var.get_tensor()
+        self._tensor.set_dims(self.shape)
+
+    def _gen_unique_name(self, name=None):
+        if not name:
+            name = "var-%d" % Variable.counter
+            Variable.counter += 1
+        # qualify the name with the current namespace
+        name = Namespace.cur().getname(name)
+
+        assert name not in Variable.__varset__, "Variable name [%s] duplicate" % name
+        Variable.__varset__.add(name)
+        return name

From 052668d9b352de2d3d31188627b9dd6c5eab044c Mon Sep 17 00:00:00 2001
From: superjom
Date: Fri, 25 Aug 2017 16:46:49 -0400
Subject: [PATCH 2/6] remove model keyword

---
 doc/design/python/user_interface.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/design/python/user_interface.md b/doc/design/python/user_interface.md
index 0b42dbe5bfd9a..b952391f2a266 100644
--- a/doc/design/python/user_interface.md
+++ b/doc/design/python/user_interface.md
@@ -2,7 +2,7 @@
 
 ## Basic Concepts
 ### Variable
-A `Variable` represents shared, persistent state manipulated by a Paddle model program.
+A `Variable` represents shared, persistent state manipulated by a Paddle program.
 
 Variables are maintained by the `pd.Variable` class,
 each `pd.Variable` represents a tensor whose value can be changed by running ops on it.
From 9cabce03f5acbd58ffc3963412677005344f3734 Mon Sep 17 00:00:00 2001
From: superjom
Date: Fri, 25 Aug 2017 19:51:25 -0400
Subject: [PATCH 3/6] fix grammar

---
 doc/design/python/user_interface.md | 20 ++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/doc/design/python/user_interface.md b/doc/design/python/user_interface.md
index b952391f2a266..eb39040579534 100644
--- a/doc/design/python/user_interface.md
+++ b/doc/design/python/user_interface.md
@@ -5,9 +5,9 @@
 A `Variable` represents shared, persistent state manipulated by a Paddle program.
 
 Variables are maintained by the `pd.Variable` class,
-each `pd.Variable` represents a tensor whose value can be changed by running ops on it.
+each `pd.Variable` represents a tensor whose value can be changed by ops.
 
-A basic way to create a variable is:
+A basic way to create a variable is
 
 ```python
 import paddle as pd
@@ -50,7 +50,7 @@
 Paddle uses a `Block` to represent and execute a user's program;
 it is a basic concept users meet when writing a Paddle program.
 
 In computer programming, a block is a lexical structure of source code which is grouped together.
-In most programming languages, blocks are useful when defining a function or some conditional statements such as `if-else`, `while`.
+In most programming languages, blocks are useful when defining a function or some conditional statements such as `if-else` and `while`.
 
 Similarly, the function of `pd.Block` in Paddle is to enable a group of operators to be treated as if it were one operator, which makes the declaration of `if_else_op` or `RNNOp` simpler; Python's `with` statement is used to make the code look much like a block.
@@ -114,9 +114,9 @@ out = ifelseop()
 
 In most cases, users need not create a `pd.Block` directly, but it is the basis of a Paddle program:
 
 - a user's program is stored in `pd.Block`
-- when we want to run the code, we just need to execute the corresponding `pd.Block`
+- to run the code, just execute the corresponding `pd.Block`
 
-A `pd.Block` can have its own namespace, which makes it possible to hide its local variables from other blocks.
+A `pd.Block` can have its own namespace, which makes it possible to hide its local variables from other blocks, for example:
@@ -178,6 +178,15 @@ These ops will help to optimize trainable variables after backward propagation f
 each variable will have an optimizer.
 
 ## Compatible with V2 Syntax
+We have new concepts like Variable, Block and Op; they are very basic concepts,
+and it should be possible to keep compatibility with the V2 API on top of the new underlying architecture.
+
+What's more, some recent models like GAN and tree-LSTM are hard to express using just the V2 API,
+so the new concepts are vital for users to write new models in the future.
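+
+For illustration only (the concrete mapping is not designed yet), a V2-style layer and a rough equivalent in the new concepts might look like:
+
+```python
+# V2 style
+fc = paddle.layer.fc(input=image, size=64)
+
+# new style: the same computation built from Variable and Op directly
+W = pd.Variable(shape=[128, 64])
+fc_out = pd.matmul(image, W)
+```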
+
+**In a word, the new Python interface will stay compatible with the V2 API, but must add a few new concepts like `Variable` and `Op`, and several helper functions, to express more complex models.**
+
+
 ## Some Demos
@@ -359,4 +368,3 @@ for i in range(10000):
         logger.info("batch %d, D loss: %f" % (batch_no, D_loss_cur))
         logger.info("batch %d, G loss: %f" % (batch_no, G_loss_cur))
 ```
-

From ada1ca75657a5469e926b9232dc6b6cc7f58cfa2 Mon Sep 17 00:00:00 2001
From: superjom
Date: Sat, 26 Aug 2017 20:37:14 -0400
Subject: [PATCH 4/6] add backwork

---
 doc/design/python/user_interface.md | 84 +++--------------------------
 1 file changed, 7 insertions(+), 77 deletions(-)

diff --git a/doc/design/python/user_interface.md b/doc/design/python/user_interface.md
index eb39040579534..3ee5c13eae8ac 100644
--- a/doc/design/python/user_interface.md
+++ b/doc/design/python/user_interface.md
@@ -45,78 +45,7 @@
 print v.val()
 ```
 
-### Block
-Paddle uses a `Block` to represent and execute a user's program;
-it is a basic concept users meet when writing a Paddle program.
-
-In computer programming, a block is a lexical structure of source code which is grouped together.
-In most programming languages, blocks are useful when defining a function or some conditional statements such as `if-else` and `while`.
-
-Similarly, the function of `pd.Block` in Paddle is to enable a group of operators to be treated as if it were one operator, which makes the declaration of `if_else_op` or `RNNOp` simpler; Python's `with` statement is used to make the code look much like a block.
-
-For example, when defining a `RNNOp`, we can use `pd.Block` to help configure a step network:
-
-```python
-v = some_op()
-m_boot = some_op()
-
-W = pd.Variable(shape=[20, 20])
-U = pd.Variable(shape=[20, 20])
-
-rnn0 = RNNOp()
-with rnn0.stepnet() as net:
-    # declare stepnet's inputs
-    x = net.add_input(v)
-    # declare memories
-    h = net.add_memory(m_boot)
-
-    fc_out = pd.matmul(W, x)
-    hidden_out = pd.matmul(U, h)
-    sum = pd.add_two(fc_out, hidden_out)
-    act = pd.sigmoid(sum)
-
-    # declare stepnet's outputs
-    net.add_output(act, hidden_out)
-
-acts, hs = rnn0()
-```
-
-The operators inside the `with` statement define the RNN's step network,
-and will be put into a `pd.Block`.
-
-Another example is the definition of `if_else_op`:
-
-```python
-# v0 is an output of some_op
-v0 = some_op()
-v1 = some_op()
-
-ifelseop = pd.if_else_op()
-with ifelseop.true_block() as net:
-    x0, x1 = net.add_input(v0, v1)
-
-    y = pd.fc(x0)
-    z = pd.add_two(x1, y)
-
-    net.add_output(z)
-
-with ifelseop.false_block() as net:
-    x0, x1 = net.add_input(v0, v1)
-
-    y = pd.add_two(x0, x1)
-
-    net.add_output(y)
-
-# output of ifelseop
-out = ifelseop()
-```
-
-In most cases, users need not create a `pd.Block` directly, but it is the basis of a Paddle program:
-
-- a user's program is stored in `pd.Block`
-- to run the code, just execute the corresponding `pd.Block`
-
-A `pd.Block` can have its own namespace, which makes it possible to hide its local variables from other blocks, for example:
+### namespace
 
 ```python
 W = pd.Variable(shape=[20, 20])
 
 # a and b are outputs of some_op
 a = some_op()
 b = some_op()
 
-with pd.Block('namespace0'):
+with pd.namespace('namespace0'):
     # W is a local variable and has its own value
     W = pd.Variable(shape=[20, 20])
     x = pd.matmul(W, a)
     y = x + b
 
-with pd.Block('namespace1'):
+with pd.namespace('namespace1'):
     # W is the global variable
     z = pd.matmul(W, a)
 
 # g uses local variables from both namespace0 and namespace1
 g = pd.add_two(y, z)
 ```
 
-**In a word, the new Python interface will stay compatible with the V2 API, but must add a few new concepts like `Variable` and `Op`, and several helper functions, to express more complex models.**
+**In a word, the new Python interface will stay compatible with the V2 API, but must add a few new concepts like `Variable` and `Op`,
+and several helper functions, to express more complex models.**

From 827784abb7ce8ba891f9f8768c0283649c7e59ae Mon Sep 17 00:00:00 2001
From: superjom
Date: Sat, 26 Aug 2017 20:39:52 -0400
Subject: [PATCH 5/6] delete mnist demo

---
 doc/design/python/user_interface.md | 55 ++---------------------------
 1 file changed, 2 insertions(+), 53 deletions(-)

diff --git a/doc/design/python/user_interface.md b/doc/design/python/user_interface.md
index 3ee5c13eae8ac..3c4151ed429dd 100644
--- a/doc/design/python/user_interface.md
+++ b/doc/design/python/user_interface.md
 
 ## Some Demos
-### MNIST Task Demo
-
-```python
-import paddle as pd
-
-# the first dimension of the shape is None, which means the batch size is not yet known.
-# (illustrative note, not part of the original design: feeding a batch of 32
-# images below would give `image` the concrete runtime shape [32, 128])
-image = pd.Variable(shape=[None, 128])
-label = pd.Variable(shape=[None, 1])
-
-# network config
-W1 = pd.Variable('W1', shape=[128, 64])
-
-fc_out = pd.matmul(image, W1)
-prediction = pd.softmax(fc_out, size=10)
-
-cost = pd.cross_entropy(prediction, label)
-
-optimizer = pd.SGDOptimizer().minimize(cost)
-
-
-# training details
-def data_provider(path):
-    images = []
-    labels = []
-    with open(path) as f:
-        for no, line in enumerate(f):
-            fs = line.split('\t')
-            assert len(fs) == 2
-            image_record = map(int, fs[0].split())
-            label_record = [int(fs[1])]
-            images.append(image_record)
-            labels.append(label_record)
-            if no > 0 and no % 100 == 0:
-                yield np.array(images), np.array(labels)
-                images = []
-                labels = []
-
-
-for pass_no in range(100):
-    for batch_no, batch in enumerate(data_provider('./data.txt')):
-        # train mode
-        _, cost_ = pd.eval(
-            [optimizer, cost], feeds={image: batch[0],
-                                      label: batch[1]})
-        print '%dth pass train cost: %f' % (pass_no, cost_)
-        # test mode
-        if batch_no > 0 and batch_no % 10 == 0:
-            cost_ = pd.eval(cost)
-            print '%dth pass test cost: %f' % (pass_no, cost_)
-```
-
 ### GAN Task Demo
 
 ```python
 import paddle as pd
 
 def discriminator(x):
     # use block with namespace to hide local variables
-    with pd.Block('discriminator') as block:
+    with pd.namespace('discriminator') as block:
         # declare model parameters
         W1 = pd.get_variable(
 
 def generator(z):
-    with pd.Block('generator') as block:
+    with pd.namespace('generator') as block:
         # declare model parameters
         W1 = pd.get_variable(

From bfbb5bb13d66bdb0f57b9540910a223910999d1b Mon Sep 17 00:00:00 2001
From: superjom
Date: Sat, 26 Aug 2017 23:22:00 -0400
Subject: [PATCH 6/6] restore block

---
 doc/design/python/user_interface.md | 142 ++++++++++++++++++++++++++++
 1 file changed, 142 insertions(+)

diff --git a/doc/design/python/user_interface.md b/doc/design/python/user_interface.md
index 3c4151ed429dd..75f5b7f4accfa 100644
--- a/doc/design/python/user_interface.md
+++ b/doc/design/python/user_interface.md
 print v.val()
 ```
 
+### Block
+Paddle uses a `Block` to represent and execute a user's program;
+it is a basic concept users meet when writing a Paddle program.
+
+In computer programming, a block is a lexical structure of source code which is grouped together.
+In most programming languages, blocks are useful when defining a function or some conditional statements such as `if-else` and `while`.
+
+Similarly, the function of `pd.Block` in Paddle is to enable a group of operators to be treated as if it were one operator, which makes the declaration of `if_else_op` or `RNNOp` simpler; Python's `with` statement is used to make the code look much like a block.
+
+For example, when defining a `RNNOp`, we can use `pd.Block` to help configure a step network:
+
+```python
+v = some_op()
+m_boot = some_op()
+
+W = pd.Variable(shape=[20, 20])
+U = pd.Variable(shape=[20, 20])
+
+rnn0 = RNNOp()
+with rnn0.stepnet() as net:
+    # declare stepnet's inputs
+    x = net.add_input(v)
+    # declare memories
+    h = net.add_memory(m_boot)
+
+    fc_out = pd.matmul(W, x)
+    hidden_out = pd.matmul(U, h)
+    sum = pd.add_two(fc_out, hidden_out)
+    act = pd.sigmoid(sum)
+
+    # declare stepnet's outputs
+    net.add_output(act, hidden_out)
+
+acts, hs = rnn0()
+```
+
+The operators inside the `with` statement define the RNN's step network,
+and will be put into a `pd.Block`.
+
+Another example is the definition of `if_else_op`:
+
+```python
+# v0 is an output of some_op
+v0 = some_op()
+v1 = some_op()
+
+ifelseop = pd.if_else_op()
+with ifelseop.true_block() as net:
+    x0, x1 = net.add_input(v0, v1)
+
+    y = pd.fc(x0)
+    z = pd.add_two(x1, y)
+
+    net.add_output(z)
+
+with ifelseop.false_block() as net:
+    x0, x1 = net.add_input(v0, v1)
+
+    y = pd.add_two(x0, x1)
+
+    net.add_output(y)
+
+# output of ifelseop
+out = ifelseop()
+```
+
+In most cases, users need not create a `pd.Block` directly, but it is the basis of a Paddle program:
+
+- a user's program is stored in `pd.Block`
+- to run the code, just execute the corresponding `pd.Block`
+
 ### namespace