Backward on parallel do using nccl #8361

Merged
38 commits merged on Feb 20, 2018. Changes below are shown from 33 of the 38 commits.

Commits:
67881ad  compile with nccl2  (Feb 6, 2018)
634f523  add ncclGroup; it is necessary in nccl2  (Feb 7, 2018)
1c91574  backward insert callback pass compile  (Feb 9, 2018)
672cdc2  add nccl  (Feb 9, 2018)
e9ddaab  disable ncclInit infer shape & var type  (Feb 9, 2018)
f2129b1  pass run time  (Feb 10, 2018)
0815c0f  add assign op  (Feb 10, 2018)
23bbaad  Merge branch 'develop' of http://github.com/paddlepaddle/paddle…  (wangkuiyi, Feb 10, 2018)
0d57ca4  nccl pass parallel_do test  (Feb 10, 2018)
bb3ae20  nccl pass parallel_do test  (Feb 10, 2018)
4bb492e  pass tiny data  (Feb 11, 2018)
bfa78ca  clean up log(info)  (Feb 11, 2018)
cd9e660  merge develop  (Feb 11, 2018)
3067114  clean up  (Feb 11, 2018)
82c33c6  Fix constructor bug in mixed_vector  (reyoung, Feb 11, 2018)
816fa8f  Fix warnings  (reyoung, Feb 11, 2018)
ae2296e  Clean code  (reyoung, Feb 11, 2018)
190119b  Extract for-loop init. Make nvcc happy  (reyoung, Feb 11, 2018)
0e2deaa  Merge remote-tracking branch 'pr/8364' into backward_on_parallel_do  (Feb 11, 2018)
0c45eab  no getmutable nccl_com  (Feb 11, 2018)
f35401c  disable debug string due to vector bug  (Feb 11, 2018)
37792e5  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…  (reyoung, Feb 12, 2018)
3c47c73  add back libnccl-dev  (Feb 12, 2018)
da97d9d  merge develop  (Feb 12, 2018)
5f343e3  Merge remote-tracking branch 'pr/8411' into backward_on_parallel_do  (Feb 12, 2018)
a259ad4  remove duplicated cbegin and cend in mixed vector  (Feb 12, 2018)
7129fa3  merge develop  (Feb 13, 2018)
e021ad6  Merge remote-tracking branch 'upstream/develop' into backward_on_para…  (Feb 13, 2018)
3f09620  pass compile  (Feb 13, 2018)
bea80b0  Merge remote-tracking branch 'upstream/develop' into backward_on_para…  (Feb 14, 2018)
9d26f1a  callback to list of callbacks  (Feb 15, 2018)
1d9fd1c  pass test_recognize_digits  (Feb 16, 2018)
5229ccb  merge develop  (Feb 16, 2018)
eb82b5c  test error clip  (Feb 16, 2018)
3494b79  test error clip  (Feb 16, 2018)
ec01f63  merge develop  (Feb 16, 2018)
ae69f0b  merge develop  (Feb 17, 2018)
4b957af  clean up  (Feb 17, 2018)
9 changes: 5 additions & 4 deletions paddle/fluid/framework/executor.cc
@@ -55,11 +55,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::NCCL_COM) {
// GetMutable will be called in ncclInit
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER]",
"LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
var_type);
}
}
@@ -120,14 +122,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,

for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(4) << op->DebugStringEx(local_scope);

platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(op->Type(), pool.Get(place_));

op->Run(*local_scope, place_);
// Wait current device context.
VLOG(3) << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);

if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_);
1 change: 1 addition & 0 deletions paddle/fluid/framework/framework.proto
@@ -112,6 +112,7 @@ message VarType {
LOD_TENSOR_ARRAY = 7;
PLACE_LIST = 8;
READER = 9;
NCCL_COM = 10;
}

required Type type = 1;
5 changes: 3 additions & 2 deletions paddle/fluid/operators/conv_op.cc
@@ -60,8 +60,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
"Due to the settings of paddings, filter_dims and "
"dilations, the output size is less than 0, please check "
"again.");
output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i], strides[i]));
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
Author: When I was debugging conv_op.cc, it looked like OutputSize was being linked to another function...

dilations[i], paddings[i],
strides[i]));
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
4 changes: 2 additions & 2 deletions paddle/fluid/operators/conv_op.h
@@ -28,8 +28,8 @@ using Tensor = framework::Tensor;

// Base convolution operator definations for other conv
// like operators to reuse the implementation.
inline int OutputSize(int input_size, int filter_size, int dilation,
int padding, int stride) {
inline int ConvOutputSize(int input_size, int filter_size, int dilation,
Author (@tonyyang-svail, Feb 16, 2018): The name OutputSize is too general, and it sits in the very broad paddle namespace scope.

int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
const int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
return output_size;
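As a side note on the renamed helper: the arithmetic above is the standard convolution output-size formula. A minimal Python sketch of the same computation (the function name is chosen here for illustration and is not part of the diff):

    def conv_output_size(input_size, filter_size, dilation, padding, stride):
        # effective kernel extent once dilation is applied
        dkernel = dilation * (filter_size - 1) + 1
        # integer division, mirroring the C++ int arithmetic
        return (input_size + 2 * padding - dkernel) // stride + 1

    # e.g. a 32-pixel edge, 3x3 filter, dilation 1, padding 1, stride 1 -> 32
    assert conv_output_size(32, 3, 1, 1, 1) == 32
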
46 changes: 37 additions & 9 deletions paddle/fluid/operators/nccl_op.cc
@@ -14,10 +14,13 @@ limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
Contributor: paddle/fluid/operators/nccl/nccl_gpu_common.h is included twice.
Author: Typo; thanks for pointing it out.


namespace paddle {
namespace operators {

static constexpr char kParallelScopes[] = "parallel_scopes";

// NCCLinitOp
class NCCLInitOp : public framework::OperatorBase {
public:
@@ -29,11 +32,22 @@ class NCCLInitOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kParallelScopes)),
"Can not find variable '%s' in the scope.",
kParallelScopes);
const auto &name = Output("Communicator");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
"Can not find variable '%s' in the scope.", name);
std::vector<int> gpus = Attr<std::vector<int>>("gpus");
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
// A parallel do may not use all the gpus. For example, the batch size is 7
// in the last batch while we have 8 gpu. In this case, parallel_do will
// create 7 parallel scopes, so should ncclInitOp create 7 gpu peers
Contributor: You mentioned "last batch"; does that imply ncclInitOp will be called for every mini-batch?
Author: Yes.

auto &parallel_scopes = scope.FindVar(Input(kParallelScopes))
->Get<std::vector<framework::Scope *>>();
Contributor: Does Scope support serialization?
Author: When do we need to serialize scope?
Contributor: Sorry, never mind, I got confused.

std::vector<int> gpus(parallel_scopes.size());
for (int i = 0; i < static_cast<int>(parallel_scopes.size()); ++i) {
Contributor: Since only parallel_scopes.size() is used, could we just pass kNumParallelScopes instead of kParallelScopes?
Author: We don't know kNumParallelScopes at compile time.

gpus[i] = i;
}
PADDLE_ENFORCE(!gpus.empty(), "NCCL init with 0 gpus.");

if (scope.FindVar(name) == nullptr) {
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
@@ -45,17 +59,29 @@ class NCCLInitOp : public framework::OperatorBase {
}
};

class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Communicator").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::NCCL_COM;
out_var.SetType(var_type);
}
};

class NCCLInitOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};

class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kParallelScopes, "The working place of parallel do.");
AddOutput("Communicator",
"Create Communicator for communicating between gpus");
AddAttr<std::vector<int>>("gpus", "(vector<int>) GPU id lists");
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::DataType::FP32);
AddComment(R"DOC(
NCCLInit Operator.

@@ -78,7 +104,7 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
ctx->HasInput("Communicator"),
" Input(Communicator) of AllReduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Input(X) of AllReduce op input should not be NULL");
" Output(Out) of AllReduce op output should not be NULL");

auto x_dims = ctx->GetInputsDim("X");

@@ -215,7 +241,9 @@ Bcast the tensors.

namespace ops = paddle::operators;
REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker,
ops::NCCLInitOpVarTypeInference,
ops::NCCLInitOpShapeInference);

REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
ops::NCCLAllReduceOpMaker);
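To recap the review discussion on NCCLInitOp: the GPU list is no longer a compile-time "gpus" attribute; it is derived at runtime from however many parallel scopes parallel_do actually created, so a short last batch (say 7 examples on 8 GPUs) initializes only 7 peers. A rough Python sketch of that derivation; the function and variable names are illustrative only:

    def gpu_ids_for_nccl_init(parallel_scopes):
        # one NCCL peer per parallel scope created by parallel_do; a short
        # last batch may produce fewer scopes than there are physical GPUs
        gpus = list(range(len(parallel_scopes)))
        assert gpus, "NCCL init with 0 gpus."
        return gpus

    # batch of 7 on an 8-GPU machine -> peers [0, 1, 2, 3, 4, 5, 6]
    print(gpu_ids_for_nccl_init(["scope_%d" % i for i in range(7)]))
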
26 changes: 24 additions & 2 deletions paddle/fluid/operators/parallel_do_op.cc
@@ -30,6 +30,7 @@ static constexpr char kOutputs[] = "outputs";
static constexpr char kParallelScopes[] = "parallel_scopes";

static constexpr char kParallelBlock[] = "sub_block";
static constexpr char kUseNCCL[] = "use_nccl";

using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
@@ -194,6 +195,8 @@ class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput(kOutputs, "").AsDuplicable();
AddOutput(kParallelScopes, "");
AddAttr<framework::BlockDesc *>(kParallelBlock, "");
AddAttr<bool>(kUseNCCL, "true if we use nccl on backward")
.SetDefault(false);
AddComment(R"DOC(
ParallelDo Operator.
)DOC");
@@ -216,7 +219,6 @@ class ParallelDoGradOp : public framework::OperatorBase {

auto &sub_scopes = scope.FindVar(Input(kParallelScopes))
->Get<std::vector<framework::Scope *>>();

auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();

// feed output@grad
@@ -243,14 +245,34 @@
}
WaitOnPlaces(places);

AccumulateGrad(scope, place, sub_scopes, places);
// NCCL allreduce op will be added by backward,
// so no need to explicitly accumulate grad
if (!(Attr<bool>(kUseNCCL))) {
AccumulateGrad(scope, place, sub_scopes, places);
} else {
for (auto &place : places) {
PADDLE_ENFORCE(platform::is_gpu_place(place),
"NCCL only supports cuda place");
}
}
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
if (s == "@EMPTY@") {
Author (@tonyyang-svail, Feb 16, 2018): Backward changes some gradients to @EMPTY@ when we don't need to compute them; for example, we never need the gradient of layer.data. In that case parallel_do should skip them.
Contributor: Can @EMPTY@ be put into a constant?
Author: Sure.

continue;
}
VLOG(3) << "Moving " << s;
CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
}
WaitOnPlaces(places);
Contributor: Not related to this PR, but I am curious why parallel_do has to wait for all streams to complete. I thought even the executor does not wait.
Author: I can't think of a case where it would be wrong not to wait, but just as we join threads after launching them, it feels natural for parallel_do to wait for all of its streams.
Contributor: That could affect performance; we are introducing a synchronization point we are not sure we need.
Author: There are several places in parallel_do where we already have to wait, so this line won't have a large effect.

}

void AccumulateGrad(const framework::Scope &scope,
const platform::Place &place,
const std::vector<framework::Scope *> &sub_scopes,
const platform::PlaceList &places) const {
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
if (s == "@EMPTY@") {
continue;
}
VLOG(3) << "Accumulating " << s;
if (s == framework::kEmptyVarName) continue;
std::string tmp_name;
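Tying this file to the backward.py changes below: with use_nccl set, ParallelDoGradOp skips AccumulateGrad because the backward pass has already appended, for every parameter gradient, a ncclAllReduce into a temporary buffer (the allreduce is not in-place) followed by an assign back to the original gradient name. A hedged Python sketch of the per-gradient op sequence under the two modes; the op tuples and the "sum across sub-scopes" label are illustrative, not the exact operators:

    def grad_aggregation_ops(param_grad, use_nccl):
        # illustrative only: how one parameter gradient (e.g. 'fc_0.w_0@GRAD')
        # ends up being aggregated across devices
        if use_nccl:
            tmp = param_grad + "__nccl_all_reduce__"
            return [
                # allreduce writes into a temporary buffer...
                ("ncclAllReduce", {"X": [param_grad]}, {"Out": [tmp]}),
                # ...and assign copies the reduced result back
                ("assign", {"X": [tmp]}, {"Out": [param_grad]}),
            ]
        # without NCCL, ParallelDoGradOp itself sums the per-device copies
        return [("sum across sub-scopes", {"X": [param_grad]},
                 {"Out": [param_grad]})]

    for op in grad_aggregation_ops("fc_0.w_0@GRAD", use_nccl=True):
        print(op)
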
3 changes: 2 additions & 1 deletion paddle/fluid/pybind/protobuf.cc
@@ -241,7 +241,8 @@ void BindVarDsec(py::module &m) {
.value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
.value("PLACE_LIST", proto::VarType::PLACE_LIST)
.value("READER", proto::VarType::READER);
.value("READER", proto::VarType::READER)
.value("NCCL_COM", proto::VarType::NCCL_COM);
}

void BindOpDesc(py::module &m) {
108 changes: 94 additions & 14 deletions python/paddle/v2/fluid/backward.py
@@ -199,12 +199,76 @@ def _op_can_be_removed_(op_desc, no_grad_set):
return op_descs


import proto.framework_pb2 as framework_pb2


def serialize_op_decs(op_desc):
protostr = op_desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(str(protostr))
return proto.__str__()


def _callback_lookup_(op):
"""
Only used in _append_backward_ops_
Build and returns a callback function for certain op. For example
Contributor: Can you add more comments about what the callback function is? (e.g., is it something that gets called after an op is completed?)


parallel_do: AllReduce

:param op:
:return: callback function
"""
if op.type == 'parallel_do' and op.attr('use_nccl'):
param_names = set(op.input('parameters'))
param_grad_names = [n + "@GRAD" for n in param_names]

class ParallelDoCallBack(object):
def __init__(self, param_grad_names, parallel_scopes_name):
self.has_inserted_nccl_init = False
self.param_grad_names = param_grad_names
self.parallel_scopes_name = parallel_scopes_name

def __call__(self, block, context):
if not self.has_inserted_nccl_init:
op_desc = _create_op_desc_(
"ncclInit",
{"parallel_scopes": self.parallel_scopes_name},
{"Communicator": ['nccl_com__do_not_change_']}, {})
Contributor: Can you put nccl_com__do_not_change_ into a constant?

block.program.global_block().desc.append_op().copy_from(
op_desc)
self.has_inserted_nccl_init = True

current_op_desc = context["__current_op_desc__"]
for o_param in current_op_desc.output_names():
for o_argu in current_op_desc.output(o_param):
if o_argu in self.param_grad_names:
allreduce_out_name = o_argu + "__nccl_all_reduce__"
op_desc = _create_op_desc_(
"ncclAllReduce", {
"X": [o_argu],
"Communicator":
['nccl_com__do_not_change_']
}, {"Out": [allreduce_out_name]},
Contributor: allreduce_out_name is assigned to o_argu in the next op; why can't o_argu be the output here, so we wouldn't need the extra assign op?
Author: ncclAllReduce requires a buffer to hold the result, i.e. it doesn't support in-place operation.

{"reduction": "ncclSum"})
block.desc.append_op().copy_from(op_desc)

op_desc = _create_op_desc_(
"assign", {"X": [allreduce_out_name]},
{"Out": [o_argu]}, {})
block.desc.append_op().copy_from(op_desc)

return ParallelDoCallBack(param_grad_names,
op.output("parallel_scopes"))
else:
return None


def _append_backward_ops_(block,
ops,
target_block,
no_grad_dict,
grad_to_var,
callback=None):
callbacks=None):
"""
Create all grad ops, and insert them into given block

@@ -220,14 +284,11 @@ def _append_backward_ops_(block,
val(str): corresponding forward variable name
callback(callable object): a callable object used to decorate new generated grad ops
"""
if callback is None:

def empty_callback(block, context):
pass

callback = empty_callback
elif not hasattr(callback, '__call__'):
raise ValueError("'callback' must be a callable object.")
if callbacks is not None:
assert (isinstance(callbacks, list))
for cb in callbacks:
if not hasattr(cb, '__call__'):
raise ValueError("'callback' must be a callable object.")

# grad_op_descs holds created grad_op, and will be appended to target_block
grad_op_descs = []
@@ -238,8 +299,17 @@ def empty_callback(block, context):
if op.has_attr("sub_block"):
sub_block = program.block(op.block_attr("sub_block"))
grad_sub_block = program.create_block(parent_idx=sub_block.idx)
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var)
cb = _callback_lookup_(op)
if cb is not None:
if callbacks is None:
new_callbacks = [cb]
else:
new_callbacks = callbacks + [_callback_lookup_(op)]
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var, new_callbacks)
Contributor: Do we need to do callbacks = new_callbacks, since callbacks is used later as well?
Author: The callbacks used later should not contain _callback_lookup_(op).

else:
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var, callbacks)
grad_sub_block_list.append(grad_sub_block.desc)

# Getting op's corresponding grad_op
@@ -258,7 +328,11 @@
for op_desc in grad_op_descs:
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc)
callback(block=target_block, context=grad_to_var)
grad_to_var["__current_op_desc__"] = new_op_desc
if callbacks is not None:
assert (isinstance(callbacks, list))
for cb in callbacks:
cb(block=target_block, context=grad_to_var)


def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
@@ -296,6 +370,9 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
# infer_shape and infer_type
op_desc.infer_var_type(block.desc)
op_desc.infer_shape(block.desc)
# ncclInit dones't need to set data_type
if op_desc.type() == 'ncclInit':
continue
for arg in op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_(arg, block)
@@ -335,7 +412,8 @@ def _get_stop_gradients_(program):
return no_grad_dict


def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
def append_backward(loss, parameter_list=None, no_grad_set=None,
callbacks=None):
"""
Append backward part to main_program

@@ -351,6 +429,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
(list[(Variable,Variable)]): list of (parameter, gradient) pair.
"""
assert isinstance(loss, framework.Variable)
if callbacks is not None:
isinstance(callbacks, list)

program = loss.block.program
if no_grad_set is None:
@@ -378,7 +458,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))

_append_backward_ops_(root_block, op_path, root_block, no_grad_dict,
grad_to_var, callback)
grad_to_var, callbacks)

# Because calc_gradient may be called multiple times,
# we need rename the internal gradient variables so that they have
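For readers of the new callbacks interface: each callback is a callable taking (block, context) keyword arguments and is invoked once after every grad op desc is appended, with context["__current_op_desc__"] holding the op that was just added; this is how ParallelDoCallBack spots parameter gradients and inserts the allreduce. A minimal, hypothetical usage sketch (the logging callback and the avg_cost variable are not part of this PR):

    class LogGradOpsCallback(object):
        """Hypothetical callback: records the type of every appended grad op."""

        def __init__(self):
            self.grad_op_types = []

        def __call__(self, block, context):
            # context is grad_to_var plus the op desc that was just appended
            current_op_desc = context["__current_op_desc__"]
            self.grad_op_types.append(current_op_desc.type())

    # Assuming avg_cost is a loss Variable built earlier with the fluid API:
    # logger = LogGradOpsCallback()
    # params_grads = append_backward(avg_cost, callbacks=[logger])
    # print(logger.grad_op_types)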