Skip to content

Commit

Permalink
Update tensor formatter (lutzroeder#961)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Sep 11, 2022
1 parent 264390e commit 4fc9f6f
Showing 1 changed file with 109 additions and 103 deletions.
212 changes: 109 additions & 103 deletions source/view-sidebar.js
Original file line number Diff line number Diff line change
Expand Up @@ -1437,7 +1437,7 @@ sidebar.Tensor = class {
const dataType = this._type.dataType;
const context = {};
context.layout = this._layout;
context.dimensions = this._type.shape.dimensions;
context.dimensions = this._type.shape.dimensions.map((value) => !Number.isInteger(value) && value.toNumber ? value.toNumber() : value);
context.dataType = dataType;
const size = context.dimensions.reduce((a, b) => a * b, 1);
switch (this._layout) {
Expand All @@ -1446,18 +1446,20 @@ sidebar.Tensor = class {
context.data = (this._data instanceof Uint8Array || this._data instanceof Int8Array) ? this._data : this._data.peek();
context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
if (sidebar.Tensor.dataTypes.has(dataType)) {
const itemsize = sidebar.Tensor.dataTypes.get(dataType);
if (this._data.length < (itemsize * size)) {
context.itemsize = sidebar.Tensor.dataTypes.get(dataType);
if (this._data.length < (context.itemsize * size)) {
throw new Error('Invalid tensor data size.');
}
}
else if (dataType.startsWith('uint') && !isNaN(parseInt(dataType.substring(4), 10))) {
context.dataType = 'uint';
context.bits = parseInt(dataType.substring(4), 10);
context.itemsize = 1;
}
else if (dataType.startsWith('int') && !isNaN(parseInt(dataType.substring(3), 10))) {
context.dataType = 'int';
context.bits = parseInt(dataType.substring(3), 10);
context.itemsize = 1;
}
else {
throw new Error("Tensor data type '" + dataType + "' is not implemented.");
Expand Down Expand Up @@ -1546,111 +1548,115 @@ sidebar.Tensor = class {
const dataType = context.dataType;
const view = context.view;
if (dimension == dimensions.length - 1) {
for (let i = 0; i < size; i++) {
if (context.count > context.limit) {
results.push('...');
return results;
}
switch (dataType) {
case 'boolean':
results.push(view.getUint8(context.index) === 0 ? false : true);
context.index++;
context.count++;
break;
case 'qint8':
case 'int8':
results.push(view.getInt8(context.index));
context.index++;
context.count++;
break;
case 'qint16':
case 'int16':
results.push(view.getInt16(context.index, this._littleEndian));
context.index += 2;
context.count++;
break;
case 'qint32':
case 'int32':
results.push(view.getInt32(context.index, this._littleEndian));
context.index += 4;
context.count++;
break;
case 'int64':
results.push(view.getInt64(context.index, this._littleEndian));
context.index += 8;
context.count++;
break;
case 'int':
results.push(view.getIntBits(context.index, context.bits));
context.index++;
context.count++;
break;
case 'quint8':
case 'uint8':
results.push(view.getUint8(context.index));
context.index++;
context.count++;
break;
case 'quint16':
case 'uint16':
results.push(view.getUint16(context.index, true));
context.index += 2;
context.count++;
break;
case 'quint32':
case 'uint32':
results.push(view.getUint32(context.index, true));
context.index += 4;
context.count++;
break;
case 'uint64':
results.push(view.getUint64(context.index, true));
context.index += 8;
context.count++;
break;
case 'uint':
results.push(view.getUintBits(context.index, context.bits));
context.index++;
context.count++;
break;
case 'float16':
results.push(view.getFloat16(context.index, this._littleEndian));
context.index += 2;
context.count++;
break;
case 'float32':
results.push(view.getFloat32(context.index, this._littleEndian));
context.index += 4;
context.count++;
break;
case 'float64':
results.push(view.getFloat64(context.index, this._littleEndian));
context.index += 8;
context.count++;
break;
case 'bfloat16':
results.push(view.getBfloat16(context.index, this._littleEndian));
context.index += 2;
context.count++;
break;
case 'complex64':
results.push(view.getComplex64(i << 3, this._littleEndian));
const ellipsis = (context.count + size) > context.limit;
const length = ellipsis ? context.limit - context.count : size;
let i = context.index;
const max = i + (length * context.itemsize);
switch (dataType) {
case 'boolean':
for (; i < max; i += 1) {
results.push(view.getUint8(i) === 0 ? false : true);
}
break;
case 'qint8':
case 'int8':
for (; i < max; i++) {
results.push(view.getInt8(i));
}
break;
case 'qint16':
case 'int16':
for (; i < max; i += 2) {
results.push(view.getInt16(i, this._littleEndian));
}
break;
case 'qint32':
case 'int32':
for (; i < max; i += 4) {
results.push(view.getInt32(i, this._littleEndian));
}
break;
case 'int64':
for (; i < max; i += 8) {
results.push(view.getInt64(i, this._littleEndian));
}
break;
case 'int':
for (; i < size; i++) {
results.push(view.getIntBits(i, context.bits));
}
break;
case 'quint8':
case 'uint8':
for (; i < max; i++) {
results.push(view.getUint8(i));
}
break;
case 'quint16':
case 'uint16':
for (; i < max; i += 2) {
results.push(view.getUint16(i, true));
}
break;
case 'quint32':
case 'uint32':
for (; i < max; i += 4) {
results.push(view.getUint32(i, true));
}
break;
case 'uint64':
for (; i < max; i += 8) {
results.push(view.getUint64(i, true));
}
break;
case 'uint':
for (; i < max; i++) {
results.push(view.getUintBits(i, context.bits));
}
break;
case 'float16':
for (; i < max; i += 2) {
results.push(view.getFloat16(i, this._littleEndian));
}
break;
case 'float32':
for (; i < max; i += 4) {
results.push(view.getFloat32(i, this._littleEndian));
}
break;
case 'float64':
for (; i < max; i += 8) {
results.push(view.getFloat64(i, this._littleEndian));
}
break;
case 'bfloat16':
for (; i < max; i += 2) {
results.push(view.getBfloat16(i, this._littleEndian));
}
break;
case 'complex64':
for (; i < max; i += 8) {
results.push(view.getComplex64(i, this._littleEndian));
context.index += 8;
context.count++;
break;
case 'complex128':
results.push(view.getComplex128(i << 4, this._littleEndian));
context.index += 16;
context.count++;
break;
default:
throw new Error("Unsupported tensor data type '" + dataType + "'.");
}
}
break;
case 'complex128':
for (; i < size; i += 16) {
results.push(view.getComplex128(i, this._littleEndian));
}
break;
default:
throw new Error("Unsupported tensor data type '" + dataType + "'.");
}
context.index = i;
context.count += length;
if (ellipsis) {
results.push('...');
}
}
else {
for (let j = 0; j < size; j++) {
if (context.count > context.limit) {
if (context.count >= context.limit) {
results.push('...');
return results;
}
Expand Down

0 comments on commit 4fc9f6f

Please sign in to comment.