In [1]:
%%writefile mini_batch_train.cu
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cuda_runtime.h>

#define CHECK(call) { cudaError_t err = call; if(err!=cudaSuccess){ \
    fprintf(stderr,"CUDA error %s:%d: %s\n",__FILE__,__LINE__,cudaGetErrorString(err)); exit(1);} }

const int N=8;
const int BATCH=4;
const float ETA=0.1f;

__global__ void forwardKernel(const float* x,const float* w,const float* bias,float* y,int n,int batch){
  int i=blockIdx.x*blockDim.x+threadIdx.x;
  int b=blockIdx.y*blockDim.y+threadIdx.y;
  if(i<n && b<batch){
    int idx=b*n+i;
    float bi=bias?bias[i]:0.0f;
    y[idx]=w[i]*x[idx]+bi;
  }
}

__global__ void updateKernel(float* w,const float* x,const float* y,const float* t,int n,int batch,float eta){
  int i=blockIdx.x*blockDim.x+threadIdx.x;
  if(i<n){
    float g=0.0f;
    for(int b=0;b<batch;++b){
      int idx=b*n+i;
      float err=y[idx]-t[idx];
      g+=err*x[idx];
    }
    g/=float(batch);
    w[i]-=eta*g;
  }
}

void printArray(const char* msg,const float* A,int n){
  printf("%s",msg);
  for(int i=0;i<n;++i) printf(" %.4f",A[i]);
  printf("\n");
}

int main(){
  printf("N=%d, BATCH=%d, ETA=%0.3f\n",N,BATCH,ETA);
  const int total=N*BATCH;
  float *hx=(float*)malloc(sizeof(float)*total);
  float *hw=(float*)malloc(sizeof(float)*N);
  float *hb=(float*)malloc(sizeof(float)*N);
  float *ht=(float*)malloc(sizeof(float)*total);

  for(int i=0;i<N;++i){ hw[i]=0.5f+0.1f*i; hb[i]=0.0f; }
  for(int b=0;b<BATCH;++b)
    for(int i=0;i<N;++i){
      int idx=b*N+i;
      hx[idx]=1.0f+0.5f*i+0.1f*b;
      ht[idx]=(hw[i]*hx[idx])+0.5f;
    }

  printArray("Initial weights:",hw,N);

  float *dx,*dw,*db,*dt,*dy;
  CHECK(cudaMalloc(&dx,sizeof(float)*total));
  CHECK(cudaMalloc(&dw,sizeof(float)*N));
  CHECK(cudaMalloc(&db,sizeof(float)*N));
  CHECK(cudaMalloc(&dt,sizeof(float)*total));
  CHECK(cudaMalloc(&dy,sizeof(float)*total));

  CHECK(cudaMemcpy(dx,hx,sizeof(float)*total,cudaMemcpyHostToDevice));
  CHECK(cudaMemcpy(dw,hw,sizeof(float)*N,cudaMemcpyHostToDevice));
  CHECK(cudaMemcpy(db,hb,sizeof(float)*N,cudaMemcpyHostToDevice));
  CHECK(cudaMemcpy(dt,ht,sizeof(float)*total,cudaMemcpyHostToDevice));

  dim3 blockF(8,4);
  dim3 gridF((N+blockF.x-1)/blockF.x,(BATCH+blockF.y-1)/blockF.y);
  int tpb=128; dim3 blockU((N<tpb)?N:tpb); dim3 gridU((N+blockU.x-1)/blockU.x);

  printf("\n--- Serial kernels (default stream) ---\n");
  forwardKernel<<<gridF,blockF>>>(dx,dw,db,dy,N,BATCH);
  updateKernel<<<gridU,blockU>>>(dw,dx,dy,dt,N,BATCH,ETA);
  CHECK(cudaDeviceSynchronize());
  CHECK(cudaMemcpy(hw,dw,sizeof(float)*N,cudaMemcpyDeviceToHost));
  printArray("Weights after serial update:",hw,N);

  for(int i=0;i<N;++i) hw[i]=0.5f+0.1f*i;
  CHECK(cudaMemcpy(dw,hw,sizeof(float)*N,cudaMemcpyHostToDevice));

  printf("\n--- Concurrent streams ---\n");
  cudaStream_t s0,s1; CHECK(cudaStreamCreate(&s0)); CHECK(cudaStreamCreate(&s1));
  forwardKernel<<<gridF,blockF,0,s0>>>(dx,dw,db,dy,N,BATCH);
  updateKernel<<<gridU,blockU,0,s1>>>(dw,dx,dy,dt,N,BATCH,ETA);
  CHECK(cudaStreamSynchronize(s0)); CHECK(cudaStreamSynchronize(s1));
  CHECK(cudaMemcpy(hw,dw,sizeof(float)*N,cudaMemcpyDeviceToHost));
  printArray("Weights after concurrent streams:",hw,N);

  CHECK(cudaStreamDestroy(s0)); CHECK(cudaStreamDestroy(s1));
  cudaFree(dx); cudaFree(dw); cudaFree(db); cudaFree(dt); cudaFree(dy);
  free(hx); free(hw); free(hb); free(ht);
  return 0;
}


Writing mini_batch_train.cu


In [2]:
!nvcc mini_batch_train.cu -o mini_batch_train

In [3]:
!./mini_batch_train

N=8, BATCH=4, ETA=0.100
Initial weights: 0.5000 0.6000 0.7000 0.8000 0.9000 1.0000 1.1000 1.2000

--- Serial kernels (default stream) ---
Weights after serial update: 0.5000 0.6000 0.7000 0.8000 0.9000 1.0000 1.1000 1.2000

--- Concurrent streams ---
Weights after concurrent streams: 0.5000 0.6000 0.7000 0.8000 0.9000 1.0000 1.1000 1.2000
