Old bug which prevented support of 16-channel inputs in img_acts is now gone.
Alexander Krizhevsky committed Jan 28, 2015
1 parent a113fac commit 5620d0a
Showing 11 changed files with 1,676 additions and 1,132 deletions.
77 changes: 39 additions & 38 deletions cudaconv3/include/conv_util.cuh
@@ -28,10 +28,10 @@
 void convLocalMaxUndo(NVMatrix& images, NVMatrix& maxGrads, NVMatrix& maxActs, NVMatrix& target,
                       int subsX, int startX, int strideX, int outputsX);
 void convLocalAvgUndo(NVMatrix& avgGrads, NVMatrix& target,
-                      int subsX, int startX, int strideX, int outputsX, int imgSize);
+                      int subsX, int startX, int strideX, int outputsX, int imgSize, bool sum);
 
 void convLocalAvgUndo(NVMatrix& avgGrads, NVMatrix& target,
-                      int subsX, int startX, int strideX, int outputsX, int imgSize,
+                      int subsX, int startX, int strideX, int outputsX, int imgSize, bool sum,
                       float scaleTargets, float scaleOutput);
 void convLocalMaxUndo(NVMatrix& images, NVMatrix& maxGrads, NVMatrix& maxActs, NVMatrix& target,
                       int subsX, int startX, int strideX, int outputsX, float scaleTargets, float scaleOutput);
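[Editor's note] The new trailing bool sum selects between average-undo and sum-undo behavior when distributing pooled gradients back to the image grid. A minimal single-plane CPU sketch of that rule (my own illustration, not code from this commit; it assumes target is zero-initialized and that the divisor is the clipped window area, matching the regionSize convention of the pooling kernels below):

    #include <algorithm>

    // Sketch: undo of local average/sum pooling for one (filter, image) plane.
    // avgGrads holds one gradient per output position; target is the
    // image-sized gradient buffer being accumulated into.
    void localAvgUndoRef(const float* avgGrads, float* target,
                         int subsX, int startX, int strideX,
                         int outputsX, int imgSize, bool sum) {
        for (int oy = 0; oy < outputsX; ++oy) {
            for (int ox = 0; ox < outputsX; ++ox) {
                int y0 = std::max(0, startX + oy * strideX);
                int x0 = std::max(0, startX + ox * strideX);
                int y1 = std::min(imgSize, startX + oy * strideX + subsX);
                int x1 = std::min(imgSize, startX + ox * strideX + subsX);
                int regionSize = (y1 - y0) * (x1 - x0);
                float g = avgGrads[oy * outputsX + ox];
                float v = sum ? g : g / regionSize; // sum-pooling passes through unscaled
                for (int y = y0; y < y1; ++y)
                    for (int x = x0; x < x1; ++x)
                        target[y * imgSize + x] += v;
            }
        }
    }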
@@ -68,7 +68,7 @@ void convReflectHorizontal(NVMatrix& images, NVMatrix& targets, int imgSize);
 void convCrossMapMaxPoolUndo(NVMatrix& images, NVMatrix& maxGrads, NVMatrix& maxActs, NVMatrix& target,
                              const int imgSize, const int startF, const int poolSize,
                              const int stride, const float scaleTargets, const float scaleOutputs);
-
+template<bool sum>
 class AvgPooler {
 public:
     __device__ inline float operator()(const float a, const float b) const {
@@ -78,7 +78,7 @@ public:
         return 0;
     }
     __device__ inline float output(const float a, const int regionSize) const {
-        return a / regionSize;
+        return sum ? a : (a / regionSize);
     }
 };
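[Editor's note] AvgPooler is now templated on sum, so AvgPooler<false> keeps the old mean-pooling output while AvgPooler<true> reuses the same accumulator as a sum pooler. A CPU mock-up of the two instantiations (the __device__ qualifiers are dropped so it runs anywhere; an illustration, not the commit's code):

    // Host-side mock of the templated pooler above.
    template<bool sum>
    struct AvgPoolerRef {
        float operator()(float a, float b) const { return a + b; } // accumulate
        float getBaseValue() const { return 0; }                   // reduction identity
        float output(float a, int regionSize) const {
            return sum ? a : (a / regionSize);                     // sum vs. mean
        }
    };
    // AvgPoolerRef<false>().output(6.0f, 4) == 1.5f   (average pooling)
    // AvgPoolerRef<true>().output(6.0f, 4)  == 6.0f   (sum pooling)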

@@ -88,7 +88,7 @@ public:
         return fmaxf(a, b);
     }
     __device__ inline float getBaseValue() const {
-        return -2e38;
+        return -2e38;
     }
     __device__ inline float output(const float a, const int regionSize) const {
         return a;
@@ -112,15 +112,15 @@ public:
 * Block size B_YxB_X
 * blockIdx.x determines output.x, image idx in batches of B_X*imgsPerThread
 * blockIdx.y determines output.y, filter idx in batches of B_Y*filtersPerThread
-*
+*
 * So each block does one output for some number of images/filters.
-*
+*
 * threadIdx.x determines img idx
 * threadIdx.y determines filter idx
-*
+*
 * imgs: (numFilters, imgPixels, numImages)
 * target: (numFilters, numOutputs, numImages)
-*
+*
 * numImages must be divisible by B_X*imgsPerThread if checkCaseBounds is false
 */

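[Editor's note] From that comment, kLocalPool packs the output coordinate together with the image/filter batch into the two grid dimensions. A sketch of grid sizing consistent with that mapping (assumed for illustration, not taken from this commit; DIVUP is defined here for self-containment):

    #define DIVUP(a, b) (((a) + (b) - 1) / (b))

    // Grid sizing consistent with the block mapping described above:
    // blockIdx.x packs (output.x, image batch), blockIdx.y packs (output.y, filter batch).
    dim3 localPoolGrid(int numImages, int numFilters, int outputsX,
                       int B_X, int B_Y, int imgsPerThread, int filtersPerThread) {
        int numImgBlocks    = DIVUP(numImages, B_X * imgsPerThread);
        int numFilterBlocks = DIVUP(numFilters, B_Y * filtersPerThread);
        return dim3(numImgBlocks * outputsX, numFilterBlocks * outputsX);
    }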
@@ -138,27 +138,27 @@ __global__ void kLocalPool(float* imgs, float* target, const int imgSize, const
     if (myFilterIdx >= numFilters) {
         return;
     }
 
     const int outputIdx = outputIdxY * outputsX + outputIdxX;
     const int numOutputs = outputsX * outputsX;
     const int imgPixels = imgSize * imgSize;
 
     const int startImgPxX = startX + outputIdxX * strideX;
     const int startImgPxY = startX + outputIdxY * strideX;
     const int imgIdx = blockImgIdx + threadIdx.x;
 
     imgs += myFilterIdx * imgPixels * numImages + imgIdx;
     target += (myFilterIdx * numOutputs + outputIdx) * numImages + imgIdx;
 
     float prod[filtersPerThread][imgsPerThread];
     #pragma unroll
     for (int f = 0; f < filtersPerThread; f++) {
         #pragma unroll
         for (int i = 0; i < imgsPerThread; i++) {
-            prod[f][i] = agg.getBaseValue();
+            prod[f][i] = agg.getBaseValue();
         }
     }
 
     const int loopStartY = MAX(0, startImgPxY);
     const int loopStartX = MAX(0, startImgPxX);
     const int loopEndY = MIN(imgSize, startImgPxY + subsX);
@@ -178,18 +178,19 @@ __global__ void kLocalPool(float* imgs, float* target, const int imgSize, const
             }
         }
     }
 
     #pragma unroll
     for (int i = 0; i < imgsPerThread; i++) {
         if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
             #pragma unroll
             for (int f = 0; f < filtersPerThread; f++) {
-                target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i], regionSize);
+                target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i], regionSize);
             }
         }
     }
 }
 
+
 /*
  * Block size B_YxB_X
  * blockIdx.x determines pixel.x, image idx in batches of B_X*imgsPerThread
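[Editor's note] Taken together, the two kLocalPool hunks pin down the kernel's contract: for each (filter, output, image) triple, reduce the pooling window with agg and post-process with agg.output(acc, regionSize), where regionSize is the clipped window area. A single-plane CPU reference of that contract (illustrative; the helper name is mine):

    #include <algorithm>

    // CPU reference for kLocalPool's per-plane semantics.
    template<class Agg>
    void localPoolRef(const float* img, float* out, int imgSize,
                      int subsX, int startX, int strideX, int outputsX, Agg agg) {
        for (int oy = 0; oy < outputsX; ++oy) {
            for (int ox = 0; ox < outputsX; ++ox) {
                int y0 = std::max(0, startX + oy * strideX);
                int x0 = std::max(0, startX + ox * strideX);
                int y1 = std::min(imgSize, startX + oy * strideX + subsX);
                int x1 = std::min(imgSize, startX + ox * strideX + subsX);
                float acc = agg.getBaseValue();
                for (int y = y0; y < y1; ++y)
                    for (int x = x0; x < x1; ++x)
                        acc = agg(acc, img[y * imgSize + x]);
                out[oy * outputsX + ox] = agg.output(acc, (y1 - y0) * (x1 - x0));
            }
        }
    }

For instance, localPoolRef(img, out, 24, 3, 0, 2, 12, AvgPoolerRef<false>()) would mean-pool a 24x24 plane with 3x3 windows at stride 2.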
@@ -316,26 +317,26 @@ void convPoolCrossMap(NVMatrix& images, NVMatrix& target, const int startF, cons
 * Block size 16xB_X
 * blockIdx.x determines 4x4 pixel.x region, image idx in batches of B_X*imgsPerThread
 * blockIdx.y determines 4x4 pixel.y region, filter idx in batches of filtersPerThread
-*
+*
 * So each block does a 4x4 region for some number of images/filters.
-*
+*
 * threadIdx.x determines img idx
 * threadIdx.y determines pixel idx
-*
+*
 * imgs: (numFilters, imgPixels, numImages)
 * target: (numFilters, numOutputs, numImages)
-*
+*
 * B_X one of 8, 16, 32
 * imgsPerThread one of 1, 2, 4, 8, 16
-*
+*
 * B_XximgsPerThread MUST be divisible by 32.
 * Number of filters MUST be divisible by filtersPerThread.
-*
+*
 * numImages must be divisible by B_X*imgsPerThread if checkCaseBounds is false
-*
+*
 * Final write-out will not be fully coalesced unless B_X is 32. But there's a lot more
 * reading than writing here, and the reading is all coalesced, so it should be OK.
-*
+*
 * To be used when the stride is 1 and the pooling region is fairly large.
 */
 template<class Agg, int B_X, int imgsPerThread, int filtersPerThread, bool checkCaseBounds>
@@ -349,26 +350,26 @@ __global__ void kLocalPool2(float* imgs, float* target, const int imgSize, const
     const int blockOutputY = 4*(blockIdx.y / numFilterBlocks);
     const int blockImgIdx = (blockIdx.x % numImgBlocks) * B_X * imgsPerThread;
     const int blockFilterIdx = (blockIdx.y % numFilterBlocks) * filtersPerThread;
 
     // const int blockOutputIdx = blockOutputY * outputsX + blockOutputX;
     const int numOutputs = outputsX * outputsX;
     const int imgPixels = imgSize * imgSize;
 
     const int tidx = threadIdx.y * B_X + threadIdx.x;
     const int loadY = tidx / 32, loadX = tidx % 32;
 
     const int myX = threadIdx.y % 4;
     const int myY = threadIdx.y / 4;
 
     const int myOutputIdxY = blockOutputY + myY;
     const int myOutputIdxX = blockOutputX + myX;
     const int myOutputIdx = myOutputIdxY * outputsX + myOutputIdxX;
 
     const int startImgPxX = startX + blockOutputX;
     const int startImgPxY = startX + blockOutputY;
     const int endImgPxX = startImgPxX + subsX;
     const int endImgPxY = startImgPxY + subsX;
 
     const int myStartImgPxY = startImgPxY + myY;
     const int myStartImgPxX = startImgPxX + myX;
     const int myEndImgPxY = endImgPxY + myY;
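[Editor's note] The index arithmetic above assigns each of the 16 threadIdx.y values to one pixel of the block's 4x4 output tile: thread y handles output (blockOutputX + y % 4, blockOutputY + y / 4). A toy check of that mapping (illustration only, not part of the commit):

    #include <cassert>

    // Each threadIdx.y in [0, 16) covers one cell of the 4x4 output tile exactly once.
    int main() {
        bool covered[4][4] = {};
        for (int ty = 0; ty < 16; ++ty) {
            int myX = ty % 4, myY = ty / 4; // same decode as the kernel
            covered[myY][myX] = true;
        }
        for (int y = 0; y < 4; ++y)
            for (int x = 0; x < 4; ++x)
                assert(covered[y][x]);
        return 0;
    }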
@@ -378,18 +379,18 @@ __global__ void kLocalPool2(float* imgs, float* target, const int imgSize, const
     const int loopStartX = MAX(startImgPxX, 0);
     const int loopEndY = MIN(imgSize, endImgPxY + 3);
     const int loopEndX = MIN(imgSize, endImgPxX + 3);
 
     const int imgIdx = blockImgIdx + threadIdx.x;
 
     imgs += (blockFilterIdx + loadY) * imgPixels * numImages + blockImgIdx + loadX;
     target += (blockFilterIdx * numOutputs + myOutputIdx) * numImages + imgIdx;
 
     float prod[filtersPerThread][imgsPerThread];
     #pragma unroll
     for (int f = 0; f < filtersPerThread; f++) {
         #pragma unroll
         for (int i = 0; i < imgsPerThread; i++) {
-            prod[f][i] = agg.getBaseValue();
+            prod[f][i] = agg.getBaseValue();
         }
     }
     int regionSize = 0;
@@ -434,7 +435,7 @@ __global__ void kLocalPool2(float* imgs, float* target, const int imgSize, const
         if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
             #pragma unroll
             for (int f = 0; f < filtersPerThread; f++) {
-                target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i], regionSize);
+                target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i], regionSize);
             }
         }
     }
@@ -453,7 +454,7 @@ void convLocalPool(NVMatrix& images, NVMatrix& target, int numFilters,
     assert(images.getNumRows() == numFilters * imgPixels);
     int imgSize = int(sqrt(imgPixels));
     assert(imgSize * imgSize == imgPixels);
 
     assert(!images.isTrans());
     assert(!target.isTrans());
     assert(images.isContiguous());
@@ -463,7 +464,7 @@
     int outputs = outputsX * outputsX;
     target.resize(numFilters*outputs, numImages);
 
-    if (strideX == 1 && subsX >= 6) {
+    if (strideX == 1 && subsX >= 6 && outputsX > 1) {
         // NOTE: this part has not been optimized for Kepler
         int imgsPerThread = numImages % 128 == 0 ? 8 : 4;
         int filtersPerThread = numFilters % 4 == 0 ? 4 : numFilters % 3 == 0 ? 3 : numFilters % 2 == 0 ? 2 : 1;
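[Editor's note] The added outputsX > 1 guard keeps the degenerate one-output-per-image case off the 4x4-tile path: with a single output, most of a kLocalPool2 block's threads would map to output pixels that don't exist, so the generic kLocalPool path handles it instead. Condensed, the dispatch now reads (a paraphrase of the function's structure, not the full code):

    if (strideX == 1 && subsX >= 6 && outputsX > 1) {
        // Unit stride with large, overlapping windows: use kLocalPool2,
        // which amortizes coalesced reads across a 4x4 tile of outputs per block.
    } else {
        // General case, now including outputsX == 1: one output per block
        // via the plain kLocalPool kernel.
    }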