diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java index 719ad3a9cd2..ac4e8955d39 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java @@ -34,7 +34,7 @@ public class DenseBlockFP64 extends DenseBlockDRB public DenseBlockFP64(int[] dims) { super(dims); - reset(_rlen, _odims, 0); + resetNoFill(_rlen, _odims); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java index 1e1d8a3618e..30ea45c80dc 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java @@ -835,9 +835,18 @@ private int internPosFIndexLTE(int r, int c) { } @Override - public int posFIndexGTE(int r, int c) { - int index = internPosFIndexGTE(r, c); - return (index>=0) ? index-pos(r) : index; + public final int posFIndexGTE(int r, int c) { + final int pos = pos(r); + final int len = size(r); + final int end = pos + len; + + // search for existing col index + int index = Arrays.binarySearch(_indexes, pos, end, c); + if(index < 0) + // search gt col index (see binary search) + index = Math.abs(index + 1); + + return (index < end) ? index - pos : -1; } private int internPosFIndexGTE(int r, int c) { diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java index 1886d5eb880..e889d58b68f 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java @@ -49,7 +49,7 @@ public SparseBlockMCSR(SparseBlock sblock) _rows = new SparseRow[orows.length]; for( int i=0; i<_rows.length; i++ ) if( orows[i] != null ) - _rows[i] = new SparseRowVector(orows[i]); + _rows[i] = orows[i].copy(true); } //general case SparseBlock else { @@ -58,10 +58,17 @@ public SparseBlockMCSR(SparseBlock sblock) if( !sblock.isEmpty(i) ) { int apos = sblock.pos(i); int alen = sblock.size(i); - _rows[i] = new SparseRowVector(alen); - ((SparseRowVector)_rows[i]).setSize(alen); - System.arraycopy(sblock.indexes(i), apos, _rows[i].indexes(), 0, alen); - System.arraycopy(sblock.values(i), apos, _rows[i].values(), 0, alen); + if(alen == 0){ + // do nothing + } + else if(alen == 1) + _rows[i] = new SparseRowScalar(sblock.indexes(i)[apos], sblock.values(i)[apos]); + else{ + _rows[i] = new SparseRowVector(alen); + ((SparseRowVector)_rows[i]).setSize(alen); + System.arraycopy(sblock.indexes(i), apos, _rows[i].indexes(), 0, alen); + System.arraycopy(sblock.values(i), apos, _rows[i].values(), 0, alen); + } } } } @@ -183,7 +190,7 @@ public boolean isContiguous() { } @Override - public boolean isAllocated(int r) { + public final boolean isAllocated(int r) { return _rows[r] != null; } @@ -283,8 +290,8 @@ public long size(int rl, int ru, int cl, int cu) { } @Override - public boolean isEmpty(int r) { - return (!isAllocated(r) || _rows[r].isEmpty()); + public final boolean isEmpty(int r) { + return !isAllocated(r) || _rows[r].isEmpty(); } @Override @@ -426,6 +433,18 @@ public int posFIndexGT(int r, int c) { _rows[r] = new SparseRowVector(_rows[r]); return ((SparseRowVector)_rows[r]).searchIndexesFirstGT(c); } + + public void setNnzEstimatePerRow(int nnzPerCol, int nCol){ + for(SparseRow s : _rows){ + if(s instanceof SparseRowVector){ + SparseRowVector sv = (SparseRowVector)s; + sv.setEstimatedNzs(nnzPerCol); + } + else if(s == null){ + s = new SparseRowVector(nnzPerCol, nCol); + } + } + } @Override public String toString() { diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseRow.java b/src/main/java/org/apache/sysds/runtime/data/SparseRow.java index bfae2c21c45..e0b47a895ae 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseRow.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseRow.java @@ -131,6 +131,14 @@ public abstract class SparseRow implements Serializable * @param eps epsilon value */ public abstract void compact(double eps); + + /** + * Make a copy of this row. + * + * @param deep if the copy should be deep + * @return A copy + */ + public abstract SparseRow copy(boolean deep); @Override public String toString() { diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseRowScalar.java b/src/main/java/org/apache/sysds/runtime/data/SparseRowScalar.java index 0b1ca982cd0..b4b07c8bae9 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseRowScalar.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseRowScalar.java @@ -41,8 +41,8 @@ public int size() { } @Override - public boolean isEmpty() { - return (index < 0); + public final boolean isEmpty() { + return index < 0; } @Override @@ -115,4 +115,9 @@ public int getIndex(){ public double getValue(){ return value; } + + @Override + public SparseRow copy(boolean deep){ + return new SparseRowScalar(index, value); + } } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java b/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java index 7419f126d1e..3e433f15fcc 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java @@ -24,18 +24,30 @@ import org.apache.sysds.runtime.util.SortUtils; import org.apache.sysds.runtime.util.UtilFunctions; -public final class SparseRowVector extends SparseRow{ +/** + * A sparse row vector that is able to grow dynamically as values are appended to it. + */ +public final class SparseRowVector extends SparseRow { private static final long serialVersionUID = 2971077474424464992L; - //initial capacity of any created sparse row - //WARNING: be aware that this affects the core memory estimates (incl. implicit assumptions)! + /** + *

Initial capacity of any created sparse row

+ * WARNING: be aware that this affects the core memory estimates (incl. implicit assumptions)! + */ public static final int initialCapacity = 4; + /** + * An estimate of the number of non zero values in this row. + * The estimate is used to set a threshold on how much the array should grow at certain + * lengths to not double the size at all times. + */ private int estimatedNzs = initialCapacity; - private int maxNzs = Integer.MAX_VALUE; - private int size = 0; - private double[] values = null; - private int[] indexes = null; + /** The current size of the row vector */ + private int size; + /** The values contained in the vector, can be allocated larger than needed */ + private double[] values; + /** The column indexes of the values contained, can be allocated larger than needed */ + private int[] indexes; public SparseRowVector() { this(initialCapacity); @@ -45,6 +57,7 @@ public SparseRowVector(int capacity) { estimatedNzs = capacity; values = new double[capacity]; indexes = new int[capacity]; + size = 0; } public SparseRowVector(int nnz, double[] v, int vlen) { @@ -59,6 +72,12 @@ public SparseRowVector(int nnz, double[] v, int vlen) { size = nnz; } + public SparseRowVector(double[] v, int[] i){ + values = v; + indexes = i; + size = v.length; + } + /** * Sparse row vector constructor that take a dense array, and allocate sparsely by ignoring zero values * @param v The dense row @@ -83,11 +102,10 @@ public SparseRowVector(double[] v){ public SparseRowVector(int estnnz, int maxnnz) { if( estnnz > initialCapacity ) estimatedNzs = estnnz; - maxNzs = maxnnz; - int capacity = ((estnnz0) ? - estnnz : initialCapacity); + int capacity = initialCapacity; values = new double[capacity]; indexes = new int[capacity]; + size = 0; } public SparseRowVector(SparseRow that) { @@ -109,8 +127,8 @@ public void setSize(int newsize) { } @Override - public boolean isEmpty() { - return (size == 0); + public final boolean isEmpty() { + return size == 0; } @Override @@ -157,10 +175,14 @@ public void copy(SparseRow that) @Override public void reset(int estnns, int maxnns) { estimatedNzs = estnns; - maxNzs = maxnns; + // maxNzs = maxnns; size = 0; } + public void setEstimatedNzs(int estnnz){ + estimatedNzs = estnnz; + } + private void recap(int newCap) { if( newCap<=values.length ) return; @@ -179,11 +201,13 @@ private void recap(int newCap) { */ private int newCapacity() { final double currLen = values.length; + final boolean lessThanEstimate = currLen < estimatedNzs; + final double factor = lessThanEstimate ? + SparseBlock.RESIZE_FACTOR1 : SparseBlock.RESIZE_FACTOR2; //scale length exponentially based on estimated number of non-zeros - final int nextLen = (int)Math.ceil(currLen * ((currLen < estimatedNzs) ? - SparseBlock.RESIZE_FACTOR1 : SparseBlock.RESIZE_FACTOR2)); + final int nextLen = (int)Math.ceil(currLen * factor); //cap at max number of non-zeros with robustness of initial zero - return Math.max(2, Math.min(maxNzs, nextLen)); + return Math.max(2, nextLen); } @Override @@ -391,26 +415,27 @@ public void setIndexRange(int cl, int cu, double[] v, int[] vix, int vpos, int v } } - private void resizeAndInsert(int index, int col, double v) { - //allocate new arrays - int newCap = newCapacity(); - double[] oldvalues = values; - int[] oldindexes = indexes; + private final void resizeAndInsert(int index, int col, double v) { + final int newCap = newCapacity(); + resizeVals(newCap, index, v); + resizeIndex(newCap, index, col); + size++; + } + + private final void resizeVals(int newCap, int index, double v){ + double[] old = values; values = new double[newCap]; + System.arraycopy(old, 0, values, 0, index); + values[index] = v; + System.arraycopy(old, index, values, index+1, size-index); + } + + private final void resizeIndex(int newCap, int index, int col){ + int[] old = indexes; indexes = new int[newCap]; - - //copy lhs values to new array - System.arraycopy(oldvalues, 0, values, 0, index); - System.arraycopy(oldindexes, 0, indexes, 0, index); - - //insert new value + System.arraycopy(old, 0, indexes, 0, index); indexes[index] = col; - values[index] = v; - - //copy rhs values to new array - System.arraycopy(oldvalues, index, values, index+1, size-index); - System.arraycopy(oldindexes, index, indexes, index+1, size-index); - size++; + System.arraycopy(old, index, indexes, index+1, size-index); } private void shiftRightAndInsert(int index, int col, double v) { @@ -467,4 +492,9 @@ public void compact(double eps) { } size = nnz; //adjust row size } + + @Override + public SparseRow copy(boolean deep){ + return new SparseRowVector(this); + } } diff --git a/src/test/java/org/apache/sysds/test/component/matrix/SparseCSRTest.java b/src/test/java/org/apache/sysds/test/component/matrix/SparseCSRTest.java new file mode 100644 index 00000000000..b7070c55b86 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/matrix/SparseCSRTest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.matrix; + +import static org.junit.Assert.assertEquals; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.junit.Test; + +public class SparseCSRTest { + protected static final Log LOG = LogFactory.getLog(CompressedMatrixBlock.class.getName()); + + @Test + public void testGTE() { + int[] rs = new int[] {0, 9}; + int[] colInd = new int[] {10, 20, 30, 40, 50, 60, 80, 90, 100}; + double[] val = new double[] {1, 1, 1, 1, 1, 1, 1, 1, 1}; + SparseBlockCSR b = new SparseBlockCSR(rs, colInd, val, val.length); + + assertEquals(0, b.posFIndexGTE(0, 0)); + assertEquals(0, b.posFIndexGTE(0, 10)); + assertEquals(1, b.posFIndexGTE(0, 11)); + assertEquals(7, b.posFIndexGTE(0, 90)); + assertEquals(8, b.posFIndexGTE(0, 91)); + assertEquals(-1, b.posFIndexGTE(0, 101)); + assertEquals(-1, b.posFIndexGTE(0, 10100)); + + } + + @Test + public void testGTE2Rows() { + int[] rs = new int[] {0, 0, 9}; + int[] colInd = new int[] {10, 20, 30, 40, 50, 60, 80, 90, 100}; + double[] val = new double[] {1, 1, 1, 1, 1, 1, 1, 1, 1}; + SparseBlockCSR b = new SparseBlockCSR(rs, colInd, val, val.length); + LOG.error(b); + + assertEquals(0, b.posFIndexGTE(1, 0)); + assertEquals(0, b.posFIndexGTE(1, 10)); + assertEquals(1, b.posFIndexGTE(1, 11)); + assertEquals(7, b.posFIndexGTE(1, 90)); + assertEquals(8, b.posFIndexGTE(1, 91)); + assertEquals(-1, b.posFIndexGTE(1, 101)); + assertEquals(-1, b.posFIndexGTE(1, 10100)); + + } + + @Test + public void testGTE2RowsNN() { + int[] rs = new int[] {0, 1, 10}; + int[] colInd = new int[] {100, 10, 20, 30, 40, 50, 60, 80, 90, 100}; + double[] val = new double[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + SparseBlockCSR b = new SparseBlockCSR(rs, colInd, val, val.length); + LOG.error(b); + + assertEquals(0, b.posFIndexGTE(1, 0)); + assertEquals(0, b.posFIndexGTE(1, 10)); + assertEquals(1, b.posFIndexGTE(1, 11)); + assertEquals(7, b.posFIndexGTE(1, 90)); + assertEquals(8, b.posFIndexGTE(1, 91)); + assertEquals(-1, b.posFIndexGTE(1, 101)); + assertEquals(-1, b.posFIndexGTE(1, 10100)); + + } +}