Permalink
Browse files

Tiling (tt_densetile()) parallelized.

  • Loading branch information...
ShadenSmith committed Nov 25, 2017
1 parent 99fb5f8 commit b4bbad44824e1fd7893beb79c4ac99a25edff47e
Showing with 108 additions and 36 deletions.
  1. +108 −36 src/tile.c
View
@@ -5,7 +5,9 @@
#include "tile.h"
#include "sort.h"
#include "timer.h"
#include "io.h"
#include "thd_info.h"
#include "thread_partition.h"
#include "util.h"
/******************************************************************************
// tt_densetile: regroup the non-zeros of `tt` in place so that entries that
// fall in the same tile are stored contiguously.
//
//   tt        - tensor to tile; its vals[] and ind[m][] arrays are overwritten
//               with the reordered non-zeros.
//   tile_dims - tile_dims[m] is the number of tiles along mode m; the total
//               tile count is the product over all modes.
//
// Returns an array of ntiles+1 offsets: entry t is the index of the first
// non-zero of tile t and entry ntiles equals tt->nnz. The array comes from
// splatt_malloc and is not released here (presumably the caller frees it).
//
// NOTE(review): this listing is a rendered diff (+108/-36 according to the
// page header) whose +/- markers were stripped, so removed and added lines are
// interleaved and the text does not compile as-is. Statements that use the
// plain `tcounts` array, the serial copy loop, the plain memcpy calls and
// `return tcounts;` appear to be the removed side; the tcounts_global and
// tcounts_thread code appears to be the added side. Confirm against src/tile.c.
@@ -261,64 +263,134 @@ idx_t * tt_densetile(
sptensor_t * const tt,
idx_t const * const tile_dims)
{
timer_start(&timers[TIMER_TILE]);
idx_t const nmodes = tt->nmodes;
/*
* Count tiles and compute their dimensions.
*/
idx_t ntiles = 1;
for(idx_t m=0; m < nmodes; ++m) {
ntiles *= tile_dims[m];
}
/* the actual number of indices to place in each tile */
idx_t tsizes[MAX_NMODES];
for(idx_t m=0; m < nmodes; ++m) {
tsizes[m] = SS_MAX(tt->dims[m] / tile_dims[m], 1);
}
// NOTE(review): serial counting pass over all non-zeros using `tcounts`;
// looks like pre-commit code that the parallel version below replaces.
idx_t * tcounts = calloc(ntiles+2, sizeof(*tcounts));
/* count tile sizes (in nnz) */
idx_t coord[MAX_NMODES];
for(idx_t x=0; x < tt->nnz; ++x) {
for(idx_t m=0; m < nmodes; ++m) {
/* capping at dims-1 fixes overflow when dims don't divide evenly */
coord[m] = SS_MIN(tt->ind[m][x] / tsizes[m], tile_dims[m]-1);
}
/* offset by 1 to make prefix sum easy */
idx_t const id = get_tile_id(tile_dims, nmodes, coord);
assert(id < ntiles);
++tcounts[2+id];
}
/* We'll copy the newly tiled non-zeros into this one, then copy back */
sptensor_t * newtt = tt_alloc(tt->nnz, tt->nmodes);
/* prefix sum */
for(idx_t t=3; t <= ntiles+1; ++t) {
tcounts[t] += tcounts[t-1];
/*
* Count of non-zeros per tile. We use +1 because after a prefix sum, this
* becomes a pointer into the non-zeros for each tile (e.g., csr->row_ptr).
*/
idx_t * tcounts_global = splatt_malloc((ntiles+1) * sizeof(*tcounts_global));
for(idx_t t=0; t < ntiles+1; ++t) {
tcounts_global[t] = 0;
}
// NOTE(review): second declaration of newtt (an identical tt_alloc call
// appears a few lines up); one of the two is diff residue.
sptensor_t * newtt = tt_alloc(tt->nnz, tt->nmodes);
/* copy old tensor into new tiled one */
for(idx_t x=0; x < tt->nnz; ++x) {
for(idx_t m=0; m < nmodes; ++m) {
coord[m] = SS_MIN(tt->ind[m][x] / tsizes[m], tile_dims[m]-1);
/*
* A matrix of thread-local counters.
*/
int const nthreads = splatt_omp_get_max_threads();
idx_t * * tcounts_thread = splatt_malloc(
(nthreads+1) * sizeof(*tcounts_thread));
/* After the prefix sum, the global counter will have the sum of all nnz in
* each tile (across threads), and thus can be returned. */
tcounts_thread[nthreads] = tcounts_global;
// Row layout: thread `tid` counts into row tid+1 (see tcounts_local below),
// so the thread with tid == nthreads-1 counts directly into tcounts_global.
/* partition the non-zeros */
idx_t * thread_parts = partition_simple(tt->nnz, nthreads);
// NOTE(review): nthreads comes from splatt_omp_get_max_threads(), but row tid
// is only allocated by the thread that receives that id inside the region
// below. The prefix sum later walks all nthreads rows, so this assumes the
// team really has nthreads members - confirm, or pin it with num_threads().
#pragma omp parallel
{
int const tid = splatt_omp_get_thread_num();
// [nnz_start, nnz_end) is this thread's slice of the non-zeros; presumably
// partition_simple() returns nthreads+1 boundaries - verify.
idx_t const nnz_start = thread_parts[tid];
idx_t const nnz_end = thread_parts[tid+1];
/* allocate / initialize thread-local counters */
tcounts_thread[tid] = splatt_malloc(ntiles * sizeof(**tcounts_thread));
for(idx_t tile=0; tile < ntiles; ++tile) {
tcounts_thread[tid][tile] = 0;
}
// every row must exist and be zeroed before a neighbouring thread starts
// counting into it (thread tid writes row tid+1, allocated by another thread)
#pragma omp barrier
/* offset by 1 to make prefix sum easy */
idx_t const id = get_tile_id(tile_dims, nmodes, coord);
assert(id < ntiles);
idx_t * tcounts_local = tcounts_thread[tid+1];
// counting into row tid+1 leaves row tid free to hold this thread's starting
// offsets once the prefix sum has run
/* count tile sizes (in nnz) */
idx_t coord[MAX_NMODES];
for(idx_t x=nnz_start; x < nnz_end; ++x) {
for(idx_t m=0; m < nmodes; ++m) {
/* capping at dims-1 fixes overflow when dims don't divide evenly */
coord[m] = SS_MIN(tt->ind[m][x] / tsizes[m], tile_dims[m]-1);
}
idx_t const id = get_tile_id(tile_dims, nmodes, coord);
assert(id < ntiles);
++tcounts_local[id];
}
// NOTE(review): the next four lines use `tcounts` and look like the tail of
// the removed serial copy loop.
idx_t newidx = tcounts[id+1]++;
newtt->vals[newidx] = tt->vals[x];
for(idx_t m=0; m < nmodes; ++m) {
newtt->ind[m][newidx] = tt->ind[m][x];
// all per-thread counts must be final before the single-threaded prefix sum
#pragma omp barrier
#pragma omp single
{
/* prefix sum for each tile */
for(idx_t tile=0; tile < ntiles; ++tile) {
// running sum down the rows: afterwards row t holds the first output index
// for thread t in this tile, and row nthreads (tcounts_global) holds the
// index one past the tile's last non-zero
for(int thread=0; thread < nthreads; ++thread) {
tcounts_thread[thread+1][tile] += tcounts_thread[thread][tile];
}
/* carry over to next tile */
if(tile < (ntiles-1)) {
tcounts_thread[0][tile+1] += tcounts_thread[nthreads][tile];
}
}
} /* implied barrier */
/* grab my starting indices now */
tcounts_local = tcounts_thread[tid];
/*
* Rearrange old tensor into new tiled one.
*/
for(idx_t x=nnz_start; x < nnz_end; ++x) {
for(idx_t m=0; m < nmodes; ++m) {
coord[m] = SS_MIN(tt->ind[m][x] / tsizes[m], tile_dims[m]-1);
}
/* offset by 1 to make prefix sum easy */
idx_t const id = get_tile_id(tile_dims, nmodes, coord);
assert(id < ntiles);
// post-increment hands out consecutive slots; each thread advances only its
// own row, so the destination ranges of different threads do not overlap
idx_t const newidx = tcounts_local[id]++;
newtt->vals[newidx] = tt->vals[x];
for(idx_t m=0; m < nmodes; ++m) {
newtt->ind[m][newidx] = tt->ind[m][x];
}
}
}
// NOTE(review): the serial memcpy below looks like removed code; par_memcpy
// further down copies the same array.
/* copy data into old struct */
memcpy(tt->vals, newtt->vals, tt->nnz * sizeof(*tt->vals));
// releases this thread's own row (tcounts_local was re-pointed to row tid
// above); tcounts_global sits in row nthreads and is not freed here
splatt_free(tcounts_local);
} /* end omp parallel */
/* copy tiled data into old struct */
par_memcpy(tt->vals, newtt->vals, tt->nnz * sizeof(*tt->vals));
for(idx_t m=0; m < nmodes; ++m) {
// NOTE(review): memcpy and par_memcpy target the same array; the memcpy looks
// like the removed line.
memcpy(tt->ind[m], newtt->ind[m], tt->nnz * sizeof(**tt->ind));
par_memcpy(tt->ind[m], newtt->ind[m], tt->nnz * sizeof(**tt->ind));
}
/* shift counts to the right by 1 to make proper pointer */
// memmove rather than memcpy because source and destination overlap
memmove(tcounts_global+1, tcounts_global, ntiles * sizeof(*tcounts_global));
tcounts_global[0] = 0;
assert(tcounts_global[ntiles] == tt->nnz);
tt_free(newtt);
splatt_free(tcounts_thread);
splatt_free(thread_parts);
// NOTE(review): `return tcounts;` looks like the removed return; as written it
// would make the timer_stop and the second return unreachable.
return tcounts;
timer_stop(&timers[TIMER_TILE]);
return tcounts_global;
}

0 comments on commit b4bbad4

Please sign in to comment.