Skip to content

Commit

Permalink
[SYSTEMDS-1151] Fix error handling left indexing w/ scalars
Browse files Browse the repository at this point in the history
This patch creates consistency of error messages for left indexing
(e.g., A[a:b,c:d] = x) for matrix and scalar right-hand-sides.
  • Loading branch information
mboehm7 committed Mar 22, 2024
1 parent 3f9e903 commit af2c896
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
Expand Up @@ -4166,19 +4166,10 @@ public final MatrixBlock leftIndexingOperations(MatrixBlock rhsMatrix,
}

public MatrixBlock leftIndexingOperations(MatrixBlock rhsMatrix,
int rl, int ru, int cl, int cu, MatrixBlock ret, UpdateType update) {
int rl, int ru, int cl, int cu, MatrixBlock ret, UpdateType update)
{
// Check the validity of bounds
if( rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows()
|| cl < 0 || cl >= getNumColumns() || cu < cl || cu >= getNumColumns() ) {
throw new DMLRuntimeException("Invalid values for matrix indexing: ["+(rl+1)+":"+(ru+1)+","
+ (cl+1)+":"+(cu+1)+"] " + "must be within matrix dimensions ["+getNumRows()+","+getNumColumns()+"].");
}
if( (ru-rl+1) != rhsMatrix.getNumRows() || (cu-cl+1) != rhsMatrix.getNumColumns() ) {
throw new DMLRuntimeException("Invalid values for matrix indexing: " +
"dimensions of the source matrix ["+rhsMatrix.getNumRows()+"x" + rhsMatrix.getNumColumns() + "] " +
"do not match the shape of the matrix specified by indices [" +
(rl+1) +":" + (ru+1) + ", " + (cl+1) + ":" + (cu+1) + "] (i.e., ["+(ru-rl+1)+"x"+(cu-cl+1)+"]).");
}
checkDimsForLeftIndexing(rl, ru, cl, cu, true, rhsMatrix.rlen, rhsMatrix.clen);

MatrixBlock result = ret;
boolean sp = estimateSparsityOnLeftIndexing(rlen, clen, nonZeros,
Expand Down Expand Up @@ -4260,9 +4251,12 @@ else if( !result.sparse && sp )
* @param update ?
* @return matrix block
*/
public MatrixBlock leftIndexingOperations(ScalarObject scalar, int rl, int cl, MatrixBlock ret, UpdateType update) {
public MatrixBlock leftIndexingOperations(ScalarObject scalar, int rl, int cl,
MatrixBlock ret, UpdateType update)
{
double inVal = scalar.getDoubleValue();
boolean sp = estimateSparsityOnLeftIndexing(rlen, clen, nonZeros, 1, 1, (inVal!=0)?1:0);
checkDimsForLeftIndexing(rl, rl, cl, cl, false, -1, -1);

if( !update.isInPlace() ) { //general case
if(ret==null)
Expand All @@ -4283,7 +4277,28 @@ public MatrixBlock leftIndexingOperations(ScalarObject scalar, int rl, int cl, M
ret.quickSetValue(rl, cl, inVal);
return ret;
}


private void checkDimsForLeftIndexing(int rl, int ru, int cl, int cu,
boolean checkSrc, int rhsr, int rhsc)
{
int rlen = getNumRows(), clen = getNumColumns();
if( rl < 0 || rl >= rlen || ru < rl || ru >= rlen
|| cl < 0 || cl >= clen || cu < cl || cu >= clen ) {
throw new DMLRuntimeException("Invalid values for matrix indexing: "
+ "["+(rl+1)+":"+(ru+1)+"," + (cl+1)+":"+(cu+1)+"] "
+ "must be within matrix dimensions ["+rlen+"x"+clen+"].");
}
if( checkSrc ) {
if( (ru-rl+1) != rhsr || (cu-cl+1) != rhsc ) {
throw new DMLRuntimeException("Invalid values for matrix indexing: "
+ "dimensions of the source matrix ["+rhsr+"x"+rhsc+"] "
+ "do not match the shape of the matrix specified by indices "
+ "["+(rl+1)+":"+(ru+1)+", "+(cl+1)+":"+(cu+1)+"] "
+ "(i.e., ["+(ru-rl+1)+"x"+(cu-cl+1)+"]).");
}
}
}

@Override
public final MatrixBlock slice(IndexRange ixrange, MatrixBlock ret) {
return slice(
Expand Down

0 comments on commit af2c896

Please sign in to comment.