This repository has been archived by the owner on Nov 16, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 130
/
graph.ts
548 lines (479 loc) · 17.4 KB
/
graph.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
import {onnx} from 'onnx-proto';
import {Attribute} from './attribute';
import {Tensor} from './tensor';
import {ProtoUtil} from './util';
export declare namespace Graph {
  export interface Shape {
    readonly dims: ReadonlyArray<number>;
  }
  export interface ValueType {
    readonly tensorType: Tensor.DataType;
    readonly shape: Shape;
  }
  export interface Value {
    // the tensor data. set for initializers and for the outputs of folded 'Constant' nodes; empty otherwise
    readonly tensor?: Tensor;
    // index of the Node that produces this value. -1 for initializers and for the outputs of folded 'Constant' nodes.
    // NOTE: for graph inputs that are not initializers no producer is ever assigned, so this reads as undefined at
    // runtime despite the declared type.
    readonly from: number;
    // indices of the Nodes that consume this value. a Node appears once for every input slot that uses this value.
    readonly to: ReadonlyArray<number>;
    // value type specification. set for values declared as graph inputs or graph outputs; empty for intermediate values
    readonly type?: ValueType;
  }
  export interface Node {
    // name of the node (a generated 'unnamed_<opType>_<k>' name when the model does not provide one)
    readonly name: string;
    // the operator type
    readonly opType: string;
    // indices of the Values where the inputs come from.
    readonly inputs: ReadonlyArray<number>;
    // indices of the Values where the outputs go to.
    readonly outputs: ReadonlyArray<number>;
    // the attributes used by the operator
    readonly attributes: Attribute;
  }
  /**
   * a Transformer exposes the transformation operations that can be applied to a graph while it is being built
   */
  export interface Transformer {
    removeAllIdentityNodes(): void;
    removeAllDropoutNodes(): void;
    // TODO: add generic functions to manipulate the graph
  }
  // an initializer can use a transformer to transform the graph before it is finalized
  export interface Initializer {
    transformGraph(transformer: Transformer): void;
  }
}

export interface Graph {
  // indices (into getValues()) of the graph inputs that are not backed by an initializer
  getInputIndices(): ReadonlyArray<number>;
  // names of those inputs, in the same order as getInputIndices()
  getInputNames(): ReadonlyArray<string>;
  // indices (into getValues()) of the graph outputs
  getOutputIndices(): ReadonlyArray<number>;
  // names of the graph outputs, in the same order as getOutputIndices()
  getOutputNames(): ReadonlyArray<string>;
  // every value (edge) in the graph
  getValues(): ReadonlyArray<Graph.Value>;
  // every executable node in the graph
  getNodes(): ReadonlyArray<Graph.Node>;
}

// tslint:disable-next-line:variable-name
export const Graph = {
  /**
   * construct a graph from a graph protobuf type.
   *
   * @param graphProto the ONNX graph to load
   * @param initializer optional hook that may transform the graph before it is finalized
   * @throws if the protobuf is malformed (missing sections, duplicated or unknown names) or the graph is cyclic
   */
  from: (graphProto: onnx.IGraphProto, initializer?: Graph.Initializer) => new GraphImpl(graphProto, initializer)
};
/**
 * Mutable value (edge) record used while building and compacting the graph.
 * The underscore-prefixed fields are written directly by GraphImpl; consumers read them through the getters.
 */
class Value implements Graph.Value {
  // producing node index; -1 marks an initializer / folded constant, -2 marks an orphan awaiting removal.
  // stays undefined until a producer is known.
  _from?: number;
  // indices of the nodes consuming this value
  _to: number[];
  // declared type; only present when built from a ValueInfoProto
  type?: Graph.ValueType;
  // constant payload, if any
  tensor?: Tensor;

  /**
   * @param valueInfo declaration from the graph's input/output list; omitted for intermediate values
   */
  constructor(valueInfo?: onnx.IValueInfoProto) {
    this._from = undefined;
    this._to = [];
    this.tensor = undefined;
    this.type = valueInfo ? ProtoUtil.tensorValueTypeFromProto(valueInfo.type!.tensorType!) : undefined;
  }

  get from() {
    return this._from!;
  }

  get to() {
    return this._to;
  }
}
/**
 * Mutable node record used while building the graph.
 * `inputs`/`outputs` start empty and are filled with value indices by GraphImpl;
 * `executeNode` is cleared for nodes that are folded or removed by a transform.
 */
class Node implements Graph.Node {
  name: string;
  opType: string;
  inputs: number[];
  outputs: number[];
  attributes: Attribute;
  executeNode: boolean;

  constructor(_nodeProto: onnx.INodeProto) {
    const {name, opType, attribute} = _nodeProto;
    this.name = name!;
    this.opType = opType!;
    this.inputs = [];
    this.outputs = [];
    this.attributes = new Attribute(attribute);
    this.executeNode = true;
  }
}
/**
 * Concrete Graph built from an ONNX GraphProto.
 *
 * Values (`_allData`) and nodes (`_nodes`) reference each other by array index. Construction builds both tables,
 * applies the common transforms (Identity / Dropout removal) plus any initializer-specific transform, compacts the
 * tables (finalizeGraph), and finally validates the result (checkIsAcyclic).
 */
class GraphImpl implements Graph, Graph.Transformer {
  private _allData: Value[];

  private _allInputIndices: number[];
  private _allInputNames: string[];

  private _allOutputIndices: number[];
  private _allOutputNames: string[];

  private _nodes: Node[];

  constructor(graph: onnx.IGraphProto, graphInitializer?: Graph.Initializer) {
    if (!graph) {
      throw new TypeError('graph is empty');
    }

    // build the graph - will throw exceptions if something fatal is detected
    this.buildGraph(graph);

    // execute any transformation logic for the graph (if applicable)
    this.transformGraph(graphInitializer);

    // check for cycles and other inconsistencies - will throw exceptions if something fatal is detected
    this.checkIsAcyclic();
  }

  getInputIndices(): ReadonlyArray<number> {
    return this._allInputIndices;
  }

  getInputNames(): ReadonlyArray<string> {
    return this._allInputNames;
  }

  getOutputIndices(): ReadonlyArray<number> {
    return this._allOutputIndices;
  }

  getOutputNames(): ReadonlyArray<string> {
    return this._allOutputNames;
  }

  getValues(): ReadonlyArray<Graph.Value> {
    return this._allData;
  }

  getNodes(): ReadonlyArray<Graph.Node> {
    return this._nodes;
  }

  /**
   * Populate the value and node tables from the protobuf.
   *
   * Value indices are assigned in this order: graph inputs, graph outputs, then node outputs not already known.
   * Every initializer must name an existing graph input. 'Constant' nodes are folded: their output value gets
   * `from === -1` and a tensor, and the node is marked as not executed.
   *
   * NOTE: nodes without a name are given one by writing `nodeProto.name` on the passed-in protobuf.
   *
   * @throws Error on missing sections, duplicated names, unknown references or malformed 'Constant' nodes
   */
  private buildGraph(graph: onnx.IGraphProto) {
    const dataIndices = new Map<string, number>();
    this._allData = [];

    this._allInputIndices = [];
    this._allInputNames = [];

    this._allOutputIndices = [];
    this._allOutputNames = [];

    this._nodes = [];

    const nodesIndices = new Map<string, number>();

    // scan all inputs
    if (!graph.input) {
      throw new Error('missing information in graph: input');
    }
    const inputValueNames: string[] = [];
    for (const i of graph.input) {
      if (dataIndices.has(i.name!)) {
        throw new Error(`duplicated input name: ${i.name}`);
      }
      const currentIndex = this._allData.push(new Value(i)) - 1;
      dataIndices.set(i.name!, currentIndex);
      inputValueNames.push(i.name!);
    }

    // scan all initializers
    if (!graph.initializer) {
      throw new Error('missing information in graph: initializer');
    }
    for (const i of graph.initializer) {
      if (!dataIndices.has(i.name!)) {
        throw new Error(`invalid name for initializer: ${i.name}`);
      }
      const index = dataIndices.get(i.name!)!;
      this._allData[index]._from = -1;
      this._allData[index].tensor = Tensor.fromProto(i);
    }

    // filter out input indices: at this point _allData holds only graph inputs, so index i matches
    // inputValueNames[i]. inputs backed by an initializer are not real inputs.
    for (let i = 0; i < this._allData.length; i++) {
      if (!this._allData[i].tensor) {
        this._allInputIndices.push(i);
        this._allInputNames.push(inputValueNames[i]);
      }
    }

    // scan all outputs
    if (!graph.output) {
      throw new Error('missing information in graph: output');
    }
    for (const i of graph.output) {
      if (dataIndices.has(i.name!)) {
        throw new Error(`duplicated output name: ${i.name}`);
      }
      const currentIndex = this._allData.push(new Value(i)) - 1;
      dataIndices.set(i.name!, currentIndex);
      this._allOutputIndices.push(currentIndex);
      this._allOutputNames.push(i.name!);
    }

    // scan all nodes
    if (!graph.node) {
      throw new Error('missing information in graph: node');
    }
    // collect explicit names up front so that a generated name can never collide with the explicit name of a node
    // appearing later in the list
    const explicitNodeNames = new Set<string>();
    for (const nodeProto of graph.node) {
      if (nodeProto.name) {
        explicitNodeNames.add(nodeProto.name);
      }
    }
    for (const nodeProto of graph.node) {
      if (!nodeProto.name) {
        // assign a name to the node if it doesn't have one
        for (let pick = 0;; pick++) {
          const name = `unnamed_${nodeProto.opType}_${pick}`;
          if (!nodesIndices.has(name) && !explicitNodeNames.has(name)) {
            nodeProto.name = name;
            break;
          }
        }
      }

      if (nodesIndices.has(nodeProto.name)) {
        throw new Error(`duplicated node name: ${nodeProto.name}`);
      }
      const currentIndex = this._nodes.push(new Node(nodeProto)) - 1;
      nodesIndices.set(nodeProto.name, currentIndex);
    }

    // scan node's outputs
    for (let i = 0; i < this._nodes.length; i++) {
      const node = this._nodes[i];
      const nodeProto = graph.node[i];
      if (!nodeProto.output) {
        throw new Error(`missing output for node: ${nodeProto.name}`);
      }
      for (const output of nodeProto.output) {
        let dataIndex = dataIndices.get(output);
        if (typeof dataIndex === 'undefined') {
          dataIndex = this._allData.push(new Value()) - 1;
          dataIndices.set(output, dataIndex);
        }
        node.outputs.push(dataIndex);

        // a value that already has a producer (another node, or an initializer marked -1) cannot be re-produced
        if (this._allData[dataIndex]._from !== undefined) {
          throw new Error(`multiple nodes output to one data value: ${dataIndex}`);
        }
        this._allData[dataIndex]._from = i;

        // for the 'Constant' operator, just create a new edge in the graph corresponding to the 'output' of the
        // operator and ignore the node from the graph
        if (nodeProto.opType === 'Constant') {
          if (!nodeProto.attribute || nodeProto.attribute.length !== 1 || !nodeProto.attribute[0].t) {
            throw new Error(`missing attributes or missing tensor value in attributes for this Constant operator`);
          }
          if (!nodeProto.output || nodeProto.output.length !== 1) {
            throw new Error(`missing output or incorrect number of outputs for this Constant operator`);
          }
          node.outputs.pop();
          node.executeNode = false;

          this._allData[dataIndex]._from = -1;
          this._allData[dataIndex].tensor = Tensor.fromProto(nodeProto.attribute[0].t);
        }
      }
    }

    // scan node's inputs
    for (let i = 0; i < this._nodes.length; i++) {
      const node = this._nodes[i];
      const nodeProto = graph.node[i];

      if (!nodeProto.input) {
        throw new Error(`missing input for node: ${nodeProto.name}`);
      }
      for (const input of nodeProto.input) {
        const dataIndex = dataIndices.get(input);
        if (typeof dataIndex === 'undefined') {
          throw new Error(`unrecognized input '${input}' for node: ${nodeProto.name}`);
        }
        node.inputs.push(dataIndex);

        // one entry per input slot, so a node using the same value twice is listed twice
        this._allData[dataIndex]._to.push(i);
      }
    }

    return true;
  }

  /**
   * Validate the finalized graph: detect cycles (iterative three-colour DFS) and check that every node output is
   * uninitialized and points back at its producer.
   *
   * Consumers of graph inputs are visited first; every other node is seeded underneath them so that sub-graphs fed
   * only by initializers or folded constants are validated as well.
   *
   * @throws Error if the graph is cyclic or a value/node back-reference is inconsistent
   */
  private checkIsAcyclic() {
    const starters: Set<number> = new Set<number>();
    this._allInputIndices.forEach(i => {
      const data = this._allData[i];
      data._to.forEach(j => {
        starters.add(j);
      });
    });

    // seeds at the bottom of the stack are only popped once everything above them is black, so a node that is
    // already processed when its seed entry is popped is simply (re-)marked black.
    const nodesStack: number[] = [];
    for (let n = this._nodes.length - 1; n >= 0; n--) {
      if (!starters.has(n)) {
        nodesStack.push(n);
      }
    }
    starters.forEach(s => {
      nodesStack.push(s);
    });

    const nodesState = new Array<string>(this._nodes.length).fill('white');
    const nodesProcessed = new Set<number>();

    while (nodesStack.length > 0) {
      const nodeIndex = nodesStack.pop()!;
      if (nodesProcessed.has(nodeIndex)) {
        // second visit: every node reachable from here is done. mark this node 'black' to denote this.
        nodesState[nodeIndex] = 'black';
      } else {
        // first visit: mark 'gray' and re-push so the second visit happens after all descendants are handled
        nodesProcessed.add(nodeIndex);
        nodesStack.push(nodeIndex);
        nodesState[nodeIndex] = 'gray';

        this._nodes[nodeIndex].outputs.forEach((outgoingEdgeIndex) => {
          const data = this._allData[outgoingEdgeIndex];
          if (typeof data.tensor !== 'undefined') {
            throw new Error(`node outputs should not be initialized`);
          }
          if (data._from !== nodeIndex) {
            throw new Error(`from property of the Value object doesn't match index of Node being processed`);
          }
          data._to.forEach((downstreamNodeIndex) => {
            if (nodesState[downstreamNodeIndex] === 'gray') {
              // back edge found - cyclic
              throw new Error(`model graph is cyclic`);
            } else if (nodesState[downstreamNodeIndex] === 'white') {
              // tree edge found - continue processing by adding it to stack
              nodesStack.push(downstreamNodeIndex);
            }
          });
        });
      }
    }
  }

  /**
   * Apply the common transforms, then the initializer-specific one (if any), then compact the tables.
   */
  private transformGraph(graphInitializer?: Graph.Initializer): void {
    // apply common transform
    this.removeAllIdentityNodes();
    this.removeAllDropoutNodes();

    // apply initializer specific transform
    if (graphInitializer) {
      graphInitializer.transformGraph(this);
    }

    // finalize graph
    this.finalizeGraph();
  }

  /**
   * finalize the graph.
   *
   * this function should be called after all the transformation completed.
   * this function removes every node whose `executeNode` flag is cleared, then every value that is left without a
   * producer and is not a graph output, and rewrites all index references (value <-> node, graph inputs/outputs)
   * to the compacted positions.
   */
  finalizeGraph() {
    // replace the first occurrence of `from` in `indices` with `to`
    const replaceFirst = (indices: number[], from: number, to: number) => {
      const at = indices.indexOf(from);
      if (at !== -1) {
        indices[at] = to;
      }
    };
    // replace every occurrence of `from` in `indices` with `to`
    const replaceAll = (indices: number[], from: number, to: number) => {
      for (let k = 0; k < indices.length; k++) {
        if (indices[k] === from) {
          indices[k] = to;
        }
      }
    };

    let offset = 0;
    // delete all nodes that are not being executed
    for (let i = 0; i < this._nodes.length; i++) {
      if (!this._nodes[i].executeNode) {
        // delete this node and shift all subsequent nodes up
        offset++;
        // orphan all its output values (-2); they are dropped below unless they are graph outputs
        this._nodes[i].outputs.forEach(ind => {
          this._allData[ind]._from = -2;
        });
        this._nodes.splice(i, 1);
        i--;
        continue;
      }
      if (offset > 0) {
        // node moved from i + offset to i: update the value table. replacing one occurrence per input slot keeps
        // duplicate entries (a node using the same value twice) consistent.
        this._nodes[i].inputs.forEach(value => {
          replaceFirst(this._allData[value]._to, i + offset, i);
        });
        this._nodes[i].outputs.forEach(value => {
          if (this._allData[value]._from === i + offset) {
            this._allData[value]._from = i;
          }
        });
      }
    }

    offset = 0;
    // delete all values that are not being referenced
    for (let i = 0; i < this._allData.length; i++) {
      // position of this value before any value in this pass was removed
      const originalIndex = i + offset;

      // if current value is neither produced by a remaining node, nor an output value, remove it.
      if (this._allData[i].from === -2 && this._allOutputIndices.indexOf(originalIndex) === -1) {
        offset++;
        this._allData.splice(i, 1);
        i--;
        continue;
      }

      if (offset > 0) {
        const value = this._allData[i];
        const from = value.from;
        if (from !== undefined && from >= 0) {
          // produced by a node: update the corresponding node output
          replaceFirst(this._nodes[from].outputs, originalIndex, i);
        } else {
          // graph input, initializer, or an orphaned (-2) value kept because it is a graph output:
          // only graph inputs are referenced from the input index list
          replaceFirst(this._allInputIndices, originalIndex, i);
        }

        // update the input reference of every node consuming this value
        value.to.forEach(node => {
          replaceFirst(this._nodes[node].inputs, originalIndex, i);
        });

        // a graph output may also feed other nodes, and node removal may make two output slots alias the same
        // value, so every occurrence is updated regardless of consumers
        replaceAll(this._allOutputIndices, originalIndex, i);
      }
    }
  }

  /**
   * Delete the specified node. Assume the node has only one input and the first output connected to other nodes.
   * Consumers of the node's first output are rewired to the node's input, and a graph output pointing at that
   * output is redirected to the input as well. The node is only flagged; finalizeGraph() removes it.
   * @param nodeIndex The index of node to be deleted
   * @throws Error if the node has more than one input, no input or no output, or if any output other than the
   *     first is consumed by another node
   */
  private deleteNode(nodeIndex: number) {
    const node = this._nodes[nodeIndex];
    if (node.inputs.length > 1) {
      throw new Error(`Node deletion with multiple inputs is not supported. `);
    }
    if (node.outputs.length > 1) {
      for (let i = 1; i < node.outputs.length; i++) {
        if (this._allData[node.outputs[i]].to.length > 0) {
          throw new Error(`Node deletion with more than one output connected to other nodes is not supported. `);
        }
      }
    }
    if (node.inputs.length === 0 || node.outputs.length === 0) {
      // nothing to rewire consumers to / from
      throw new Error(`Node deletion requires the node to have one input and at least one output. `);
    }

    // this node will not be executed
    node.executeNode = false;

    const inputValueIndex = node.inputs[0];
    const outputValueIndex = node.outputs[0];

    // keep a reference to the current consumer list; _to is replaced (not mutated) below
    const nodesConsumingOutput = this._allData[outputValueIndex].to;

    // remove this node from the to property of the input Value
    const delIndex = this._allData[inputValueIndex].to.indexOf(nodeIndex);
    // should not happen
    if (delIndex === -1) {
      throw new Error(`The Value object doesn't have the current Node in it's 'to' property `);
    }
    this._allData[inputValueIndex].to.splice(delIndex, 1);

    // clear node indices consuming this output Value
    this._allData[outputValueIndex]._to = [];

    // if the output of this node is a graph output, adjust the index appropriately
    const index = this._allOutputIndices.indexOf(outputValueIndex);
    if (index !== -1) {
      this._allOutputIndices[index] = inputValueIndex;
    }

    // override the inputs for nodes consuming this node's output with the input to this node. a consumer using the
    // output in several slots is listed several times, and each pass rewires one slot.
    if (nodesConsumingOutput && nodesConsumingOutput.length > 0) {
      for (const consumerIndex of nodesConsumingOutput) {
        const replaceIndex = this._nodes[consumerIndex].inputs.indexOf(outputValueIndex);
        // should not happen
        if (replaceIndex === -1) {
          throw new Error(`The Node object doesn't have the output Value in it's 'inputs' property `);
        }
        this._nodes[consumerIndex].inputs[replaceIndex] = inputValueIndex;
        this._allData[inputValueIndex].to.push(consumerIndex);
      }
    }
  }

  /**
   * Remove every 'Dropout' node (a no-op at inference time), rewiring its consumers to its input.
   * @throws Error if a Dropout node has other than one input, other than 1 or 2 outputs, or a consumed mask output
   */
  removeAllDropoutNodes() {
    let nodeIndex = 0;
    for (const node of this._nodes) {
      // weed out 'Dropout' nodes so that no time is wasted in execution
      if (node.opType === 'Dropout') {
        // the node should have exactly 1 input and 1 or 2 outputs
        if (node.inputs.length !== 1) {
          throw new Error(`Dropout nodes should only contain one input. `);
        }
        if (node.outputs.length !== 1 && node.outputs.length !== 2) {
          throw new Error(`Dropout nodes should contain either 1 or 2 output(s)`);
        }
        // the second output should not be referenced by any other node
        if (node.outputs.length === 2 && this._allData[node.outputs[1]]._to.length !== 0) {
          throw new Error(`Dropout nodes's second output should not be referenced by other nodes`);
        }
        this.deleteNode(nodeIndex);
      }
      nodeIndex++;
    }
  }

  /**
   * Remove every 'Identity' node, rewiring its consumers to its input.
   */
  removeAllIdentityNodes() {
    let nodeIndex = 0;
    for (const node of this._nodes) {
      // weed out 'Identity' nodes so that no time is wasted in execution
      if (node.opType === 'Identity') {
        this.deleteNode(nodeIndex);
      }
      nodeIndex++;
    }
  }
}