Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions RESHAPE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Mat.reshape() Implementation

This package now includes a `reshape()` method for the `Mat` class that was missing from the original OpenCV.js build.

## Usage

```javascript
import cv from "@techstark/opencv-js";

// After OpenCV is loaded
const img = new cv.Mat(4, 4, cv.CV_8UC3); // 4x4 RGB image

// Reshape to different dimensions while preserving total data elements
const vectorized = img.reshape(-1, 3); // Auto-calculate channels, 3 rows
const singleChannel = img.reshape(1, 12); // Convert to single channel, 12 rows

// Clean up
img.delete();
vectorized.delete();
singleChannel.delete();
```

## Parameters

- `cn`: Number of channels in the result matrix. Use `-1` to auto-calculate based on the total elements and rows.
- `rows` (optional): Number of rows in the result matrix. If not specified, attempts to maintain matrix structure.

## Behavior

The `reshape()` method reorganizes matrix data without copying it, similar to OpenCV's native `reshape()` function:

- Total number of data elements (`rows × cols × channels`) must remain constant
- When `cn = -1`: Auto-calculates channels, usually defaults to 1 channel for vectorization
- When `rows` is specified: Calculates columns to fit the total elements
- Returns a new `Mat` object with the reshaped dimensions

## Common Use Cases

1. **Image vectorization**: Convert 2D image to 1D vector
```javascript
const vector = image.reshape(-1, 1); // Single row vector
```

2. **Channel reorganization**: Change number of channels
```javascript
const singleChannel = image.reshape(1); // Convert to grayscale layout
```

3. **Matrix flattening**: Convert multi-dimensional to 2D
```javascript
const flattened = matrix.reshape(-1, totalPixels); // One row per pixel
```
1 change: 1 addition & 0 deletions doc/cvKeys.json
Original file line number Diff line number Diff line change
Expand Up @@ -1488,6 +1488,7 @@
"matSize",
"mul",
"ptr",
"reshape",
"roi",
"row",
"rowRange",
Expand Down
6 changes: 6 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
export * from "./types/opencv";
import { extendMatWithReshape } from "./mat-extensions";

// Extend Mat with missing methods when OpenCV is loaded
if (typeof global !== 'undefined' && global.cv) {
extendMatWithReshape();
}
160 changes: 160 additions & 0 deletions src/mat-extensions.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import type { Mat } from "./types/opencv/Mat";
import type { int } from "./types/opencv/_types";

declare global {
interface Mat {
reshape(cn: int, rows?: int): Mat;
}
}

// Extend Mat prototype with reshape method
export function extendMatWithReshape() {
if (typeof global !== 'undefined' && global.cv && global.cv.Mat) {
const MatPrototype = global.cv.Mat.prototype;

if (!MatPrototype.reshape) {
MatPrototype.reshape = function(cn: int, rows?: int): Mat {
// Get current matrix properties
const currentRows = this.rows;
const currentCols = this.cols;
const currentChannels = this.channels();
const currentType = this.type();
const currentDepth = currentType & 7; // Extract depth (CV_8U, CV_16S, etc.)

const totalDataElements = currentRows * currentCols * currentChannels;

let newChannels: int;
let newRows: int;
let newCols: int;

// OpenCV reshape semantics:
// - cn = -1 means "auto-calculate channels"
// - rows = -1 or undefined means "auto-calculate rows"
// - The total number of elements must remain constant

if (cn === -1) {
// Auto-calculate channels based on rows
if (rows === undefined || rows === 0) {
throw new Error("When cn=-1, rows parameter must be specified");
}

newRows = rows;
// Calculate how many elements per row we need
const elementsPerRow = totalDataElements / newRows;
if (Math.floor(elementsPerRow) !== elementsPerRow) {
throw new Error(`Cannot reshape: total elements (${totalDataElements}) not evenly divisible by rows (${newRows})`);
}

// Try to fit this into a reasonable matrix structure
// First, try to keep channels as 1 (most common case for vectorization)
newChannels = 1;
newCols = elementsPerRow;

// If that creates too many columns, try other channel arrangements
if (newCols > 10000) { // Arbitrary large number check
// Try to use original channels if it makes sense
if (elementsPerRow % currentChannels === 0) {
newChannels = currentChannels;
newCols = elementsPerRow / currentChannels;
} else {
// Try common channel counts
for (const testChannels of [3, 4, 2]) {
if (elementsPerRow % testChannels === 0) {
newChannels = testChannels;
newCols = elementsPerRow / testChannels;
break;
}
}
}
}
} else {
// Channels specified
newChannels = cn;

if (rows === undefined || rows === 0) {
// Auto-calculate rows - keep matrix as close to original as possible
const matrixElements = totalDataElements / newChannels;
if (Math.floor(matrixElements) !== matrixElements) {
throw new Error(`Cannot reshape: total elements (${totalDataElements}) not evenly divisible by channels (${newChannels})`);
}

// Try to keep close to original shape
newRows = currentRows;
newCols = matrixElements / newRows;

if (Math.floor(newCols) !== newCols) {
// Original shape doesn't work, find best factorization
newRows = Math.floor(Math.sqrt(matrixElements));
newCols = Math.floor(matrixElements / newRows);

if (newRows * newCols !== matrixElements) {
for (let r = 1; r <= matrixElements; r++) {
if (matrixElements % r === 0) {
newRows = r;
newCols = matrixElements / r;
break;
}
}
}
}
} else {
// Both channels and rows specified
newRows = rows;
const matrixElements = totalDataElements / newChannels;
if (Math.floor(matrixElements) !== matrixElements) {
throw new Error(`Cannot reshape: total elements (${totalDataElements}) not evenly divisible by channels (${newChannels})`);
}

newCols = matrixElements / newRows;
if (Math.floor(newCols) !== newCols) {
throw new Error(`Cannot reshape: matrix elements (${matrixElements}) not evenly divisible by rows (${newRows})`);
}
}
}

// Final validation
if (newRows * newCols * newChannels !== totalDataElements) {
throw new Error(`Reshape validation failed: ${newRows} × ${newCols} × ${newChannels} = ${newRows * newCols * newChannels} ≠ ${totalDataElements}`);
}

// Create the new matrix type
let newType: int;
switch (newChannels) {
case 1:
newType = currentDepth;
break;
case 2:
newType = currentDepth + 8;
break;
case 3:
newType = currentDepth + 16;
break;
case 4:
newType = currentDepth + 24;
break;
default:
newType = currentDepth + ((newChannels - 1) << 3);
break;
}

try {
// Create new matrix with calculated dimensions
const result = new global.cv.Mat(newRows, newCols, newType);

// Copy all the data (should be same amount, just organized differently)
const srcData = this.data;
const dstData = result.data;
const copyLength = Math.min(srcData.length, dstData.length);

for (let i = 0; i < copyLength; i++) {
dstData[i] = srcData[i];
}

return result;
} catch (error) {
throw new Error(`Failed to create reshaped matrix: ${error instanceof Error ? error.message : String(error)}`);
}
};
}
}
}
4 changes: 4 additions & 0 deletions test/cv.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import "../src";
import { extendMatWithReshape } from "../src/mat-extensions";

export async function setupOpenCv() {
const _cv = await require("../dist/opencv.js");
global.cv = _cv;

// Apply our extensions after OpenCV is loaded
extendMatWithReshape();
}

export function translateException(err: any) {
Expand Down
85 changes: 85 additions & 0 deletions test/reshape.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import { setupOpenCv, translateException } from "./cv";

beforeAll(setupOpenCv);

describe("Mat.reshape", () => {
it("should fix the original issue", async () => {
try {
// Create a simple test matrix
const origImg = new cv.Mat(4, 4, cv.CV_8UC4); // 4x4 RGBA image
const img = new cv.Mat();
cv.cvtColor(origImg, img, cv.COLOR_RGBA2RGB); // Convert to RGB (3 channels)

// This should now work (not throw "img.reshape is not a function")
expect(() => {
const vectorized = img.reshape(-1, 3);
vectorized.delete();
}).not.toThrow("img.reshape is not a function");

origImg.delete();
img.delete();
} catch (err) {
throw translateException(err);
}
});

it("should implement reshape functionality", async () => {
try {
// Create a 2x3 matrix with 2 channels (12 elements total)
const mat = new cv.Mat(2, 3, cv.CV_8UC2);

// Fill with test data
for (let i = 0; i < 2; i++) {
for (let j = 0; j < 3; j++) {
mat.ucharPtr(i, j)[0] = i * 6 + j * 2; // First channel
mat.ucharPtr(i, j)[1] = i * 6 + j * 2 + 1; // Second channel
}
}

// Test reshape: convert 2x3x2 to 3x2x2 (same total elements)
const reshaped = mat.reshape(2, 3);

expect(reshaped.rows).toBe(3);
expect(reshaped.cols).toBe(2);
expect(reshaped.channels()).toBe(2);
expect(reshaped.total() * reshaped.channels()).toBe(mat.total() * mat.channels());

// Test reshape with auto-calculated channels: total=12, rows=3, so 4 elements per row
// With -1 (auto-calculate), it should default to 1 channel, so 3x4x1
const reshaped2 = mat.reshape(-1, 3);

expect(reshaped2.rows).toBe(3);
expect(reshaped2.cols).toBe(4); // 12 total elements / 3 rows / 1 channel = 4 cols
expect(reshaped2.channels()).toBe(1); // Auto-calculated as 1 channel

mat.delete();
reshaped.delete();
reshaped2.delete();
} catch (err) {
throw translateException(err);
}
});

it("should handle edge cases", async () => {
try {
// Test with 1D vector
const mat = new cv.Mat(1, 6, cv.CV_8UC1);

// Reshape to 2x3
const reshaped = mat.reshape(1, 2);
expect(reshaped.rows).toBe(2);
expect(reshaped.cols).toBe(3);
expect(reshaped.channels()).toBe(1);

// Test invalid reshape (mismatched total elements)
expect(() => {
mat.reshape(1, 5); // 1*5 = 5, but original has 6 elements
}).toThrow();

mat.delete();
reshaped.delete();
} catch (err) {
throw translateException(err);
}
});
});