<a href="https://colab.research.google.com/github/sohamch/VKMC/blob/SymmNets/Lattice_Gas/CE_Symmetry/Numba_Cuda/Numba_GConv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import math
from numba import cuda, jit, float32, float64, int64, uint8, int16

In [2]:
# Mount Google Drive (Colab only) so that the data directory FP is readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Directory on the mounted Drive that holds the .npy input files for this notebook.
FP = "/content/drive/My Drive/Colab/Numba_Cuda/GConv/"

In [5]:
# Upper bounds used to size the kernel's statically allocated buffers.
# Sites held by the per-channel input buffer.
NsitesMax = 1024
# Output channels the filter buffer can accommodate.
NChMax = 100
# Neighbor slots per site: the maximum for close-packed structures,
# plus one slot for the site itself.
NngbMax = 13
# Capacity of the flattened filter buffer (one weight per channel per neighbor slot).
NChSitesMax = NngbMax * NChMax

In [101]:
# Device functions for the softplus activation and its derivative.
@cuda.jit(device=True)
def softPlus(x, beta, threshold):
    """Softplus activation log(1 + exp(beta*x))/beta with saturated tails.

    The cut-offs are applied to the scaled argument beta*x, because that is
    the quantity passed to exp(): above `threshold` the result is x itself
    (this avoids overflowing exp), below `-threshold` the result is 0.
    For beta == 1 this is the same as thresholding on x.
    """
    z = beta*x
    if z < -threshold:
        return 0.0
    elif z > threshold:
        return x
    else:
        return math.log(1.0 + math.exp(z))/beta

@cuda.jit(device=True)
def gradSoftPlus(x, beta, threshold):
    """Derivative of softPlus with respect to x: the logistic sigmoid of beta*x.

    Uses the same cut-offs on beta*x as softPlus, so the derivative is
    exactly 0 / 1 in the tails where softPlus returns 0 / x.
    """
    z = beta*x
    if z < -threshold:
        return 0.0
    elif z > threshold:
        return 1.0
    else:
        return 1./(1.0 + math.exp(-z))

In [149]:
@cuda.jit
def Gconv(InputImage, Psi, OutImage, OutF, SiteNeighbors, GnnPerms, N_array,
          sp_beta, sp_threshold):
    """Neighbor convolution under every group operation, then a group-averaged softplus.

    Launch layout (as read from the block/thread indices below):
      * blockIdx.x -> sample in the batch
      * blockIdx.z -> group operation
      * blockIdx.y*blockDim.y + threadIdx.y -> flattened (output channel, site)
        index, i.e. outCh*Nsites + siteInd
    Only threadIdx.y is used to identify a thread inside its block.

    Parameters
    ----------
    InputImage : read as [sample, input channel, site]
    Psi : filters, read as [input channel, outCh*N_ngb + neighbor slot]
    OutImage : written as [sample, output channel, site]. softPlus(sum)/Ng is
        atomically ADDED once per group operation, so the buffer must be
        zeroed before every launch.
    OutF : written as [sample, output channel, group op, site]; receives the
        linear (pre-activation) sum.
    SiteNeighbors : read as [site, neighbor slot] -> neighbor site index
    GnnPerms : read as [group op, neighbor slot] -> permuted neighbor slot
    N_array : [Nsites, NInCh, NOutCh, N_ngb, Ng]
    sp_beta, sp_threshold : forwarded unchanged to softPlus

    The static buffers require Nsites <= NsitesMax, N_ngb <= NngbMax and
    NOutCh*N_ngb <= NChSitesMax; this is not checked on the device.
    """
    # first get locations
    ty = cuda.threadIdx.y
    bx, by, bz = cuda.blockIdx.x, cuda.blockIdx.y, cuda.blockIdx.z
    bSizeY = cuda.blockDim.y

    Nsites, NInCh, NOutCh, N_ngb, Ng = N_array[0], N_array[1], N_array[2], N_array[3], N_array[4]

    # Get the necessary output indices
    batchInd = bx  # which sample the thread is working with
    flatInd = by*bSizeY + ty
    outCh = flatInd//Nsites   # thread's output channel
    siteInd = flatInd%Nsites  # which site the conv is over
    gInd = bz  # which group operation the thread is handling

    # When NOutCh*Nsites is not a multiple of the block size, the last block
    # holds threads past the final output channel. They still take part in the
    # cooperative loads and barriers, but must not convolve or write.
    inRange = outCh < NOutCh

    # Get the neighborhood of the current site
    # in the thread's local memory
    NgbIndices = cuda.local.array(shape=(NngbMax,), dtype=int64)
    for ngb in range(N_ngb):
        NgbIndices[ngb] = SiteNeighbors[siteInd, ngb]

    # create the shared arrays to store filters and input elements
    InChannel = cuda.shared.array(shape=(NsitesMax,), dtype=float64)
    Filter = cuda.shared.array(shape=(NChSitesMax,), dtype=float64)

    # Store the Group rotations of nns into shared memory.
    # Swept like the other loads so that blocks with fewer than N_ngb
    # threads still fill every entry.
    gnnRotShared = cuda.shared.array(shape=(NngbMax,), dtype=uint8)
    for sweep in range(N_ngb//bSizeY + 1):
        threadNgbInd = sweep*bSizeY + ty
        if threadNgbInd < N_ngb:
            gnnRotShared[threadNgbInd] = GnnPerms[gInd, threadNgbInd]

    linSum = 0.
    for inCh in range(NInCh):
        # First read the input channel into shared memory
        for sweep in range(Nsites//bSizeY + 1):
            threadSiteInd = sweep*bSizeY + ty
            if threadSiteInd < Nsites:
                InChannel[threadSiteInd] = InputImage[batchInd, inCh, threadSiteInd]

        # Then read the (unpermuted) filter for this input channel;
        # the group permutation is applied when the filter is indexed below.
        for sweep in range((NOutCh*N_ngb)//bSizeY + 1):
            threadElemInd = sweep*bSizeY + ty
            if threadElemInd < NOutCh*N_ngb:
                Filter[threadElemInd] = Psi[inCh, threadElemInd]

        # synchronize the block: all shared data must be in place before use
        cuda.syncthreads()

        # Reading phase is done - now convolve
        if inRange:
            for ngb in range(N_ngb):
                ngbSite = NgbIndices[ngb]
                linSum += Filter[outCh*N_ngb + gnnRotShared[ngb]] * InChannel[ngbSite]

        # Second barrier: no thread may start overwriting the shared buffers
        # for the next input channel while another is still convolving.
        cuda.syncthreads()

    if inRange:
        OutF[batchInd, outCh, gInd, siteInd] = linSum
        nonLin = softPlus(linSum, sp_beta, sp_threshold)/Ng
        # atomically sum out the group channel
        cuda.atomic.add(OutImage, (batchInd, outCh, siteInd), nonLin)

In [134]:
# Read the precomputed lattice tables from the Drive folder.
# The neighbor table is transposed on load so that it is indexed [site, neighbor slot].
NNSites = np.load(FP + "NNsites_sitewise.npy").T
# Indexed [group operation, neighbor slot].
GNNperms = np.load(FP + "GroupNNpermutations.npy")
# Presumably site-index <-> position lookups, judging by the names;
# they are not used by the cells shown in this notebook.
RtoSiteInd = np.load(FP + "RtoSiteInd.npy")
SiteIndtoR = np.load(FP + "SiteIndtoR.npy")

# Problem dimensions implied by the tables.
Nsites, N_ngb = NNSites.shape
Ng = len(GNNperms)
print(N_ngb, Nsites, Ng)

9 512 48


In [135]:
# # load the pickle files
# import pickle
# with open(FP + "supercellBCC.pkl", "rb") as fl:
#     superBCC = pickle.load(fl)

# with open(FP + "GroupOpsIndices.pkl", "rb") as fl:
#     GIndices = pickle.load(fl)

# with open(FP + "jnetBCC.pkl", "rb") as fl:
#     jNetBCC = pickle.load(fl)

In [164]:
# Create a random input image and zero-initialised output buffers.
# Seed the global NumPy RNG so that this input (and the filter drawn from the
# same RNG in the next cell) is identical on every Restart & Run All.
np.random.seed(0)
# Take the site count from the loaded neighbor table instead of hard-coding it,
# so the image always matches the site indices stored in NNSites.
Nsites = NNSites.shape[0]
Nbatch = 512
NchIn = 1
NchOut = 1
InImage = np.random.rand(Nbatch, NchIn, Nsites)
# OutImage receives the group-averaged activation, OutImageF the per-group linear sums.
OutImage = np.zeros((Nbatch, NchOut, Nsites))
OutImageF = np.zeros((Nbatch, NchOut, Ng, Nsites))

In [165]:
# Now set up a random filter.
# Rows are input channels; column outCh*N_ngb + k holds the weight of output
# channel outCh for neighbor slot k (this is how both the kernel and the CPU
# check below index it).
Psi = np.random.rand(NchIn, NchOut*N_ngb)

In [166]:
# Launch geometry for the kernel.
ty = 512  # threads per block along y
NbatchRun = 512  # number of samples to process in this launch
bX = NbatchRun  # grid x: one block per sample
# grid y: smallest block count whose threads cover every (output channel, site) pair
bY = (NchOut*Nsites + ty - 1)//ty
bZ = Ng  # grid z: one block per group operation

In [171]:
%%time
# Upload the kernel inputs to the GPU.
d_input = cuda.to_device(InImage)
d_NNSites = cuda.to_device(NNSites)
d_GNNperms = cuda.to_device(GNNperms)
d_Psi = cuda.to_device(Psi)

# The problem sizes travel as one small array; the kernel unpacks them by
# position, so the order [Nsites, NInCh, NOutCh, N_ngb, Ng] must be kept.
N_array = np.array([Nsites, NchIn, NchOut, N_ngb, Ng])
d_Narray = cuda.to_device(N_array)

CPU times: user 4.1 ms, sys: 998 µs, total: 5.1 ms
Wall time: 5.19 ms


In [172]:
# Upload the zero-initialised output buffers.
# The kernel accumulates into d_output with atomic adds, so this cell has to be
# re-run before every launch; otherwise results pile up across launches.
d_output = cuda.to_device(OutImage)
d_outF = cuda.to_device(OutImageF)

In [173]:
%%time
# Launch with grid (bX, bY, bZ) = (samples, (output channel, site) tiles, group ops)
# and ty threads per block along y. The two trailing scalars are sp_beta and sp_threshold.
Gconv[(bX, bY, bZ), (1, ty, 1)](d_input, d_Psi, d_output, d_outF, 
                                d_NNSites, d_GNNperms, d_Narray,
                                1.0, 20.0)

TypeError: ignored

In [142]:
# Copy the output to the host
# HostOut: group-averaged activation, HostF: per-group linear sums.
HostOut = d_output.copy_to_host()
HostF = d_outF.copy_to_host()

In [143]:
# Now let's do the conv explicitly
# We'll only test a few chosen samples to keep the run time manageable
def softPlusCPU(x, beta, threshold):
    """CPU reference softplus: log(1 + exp(beta*x))/beta with saturated tails.

    The cut-offs are applied to beta*x, the actual argument of exp(), so a
    large beta can no longer make math.exp overflow before the linear branch
    is taken. Above `threshold` the result is x, below `-threshold` it is 0.
    For beta == 1 this is the same as thresholding on x.
    """
    z = beta*x
    if z < -threshold:
        return 0.0
    elif z > threshold:
        return x
    else:
        return math.log(1.0 + math.exp(z))/beta

# Sample to check. Fixed at 4 here; a random pick would be np.random.randint(0, 512).
sampInd = 4
# Work on copies of this sample's slices of the host buffers.
outSamp = OutImage[sampInd].copy()
outSampF = OutImageF[sampInd].copy()
for outCh in range(NchOut):
    # first column of this output channel's weights inside Psi
    filtStart = outCh*N_ngb
    for siteInd in range(Nsites):
        siteNgbs = NNSites[siteInd]
        # accumulates the softplus averaged over all group operations
        gsum = 0.
        for gInd in range(Ng):
            gPerm = GNNperms[gInd]
            # linear (pre-activation) sum over input channels and neighbors
            linSum = 0.
            for inCh in range(NchIn):
                for ngb in range(N_ngb):
                    filt = Psi[inCh, filtStart + gPerm[ngb]]
                    linSum += filt * InImage[sampInd, inCh, siteNgbs[ngb]]

            outSampF[outCh, gInd, siteInd] = linSum

            gsum += softPlusCPU(linSum, 1, 20)/Ng

        outSamp[outCh, siteInd] = gsum
                

In [147]:
# CPU reference linear sums for the checked sample, indexed [outCh, gInd, siteInd].
outSampF

array([[[2.03696419, 2.42512621, 1.87881436, ..., 2.23762819,
         1.49430208, 1.5518961 ],
        [1.87122017, 2.41304652, 1.98418778, ..., 2.00494354,
         1.2239683 , 1.76007763],
        [1.87836218, 2.16484055, 1.69372213, ..., 2.40231442,
         1.15666919, 1.36851503],
        ...,
        [2.06833637, 2.49414805, 1.70043077, ..., 2.03345768,
         1.51398718, 1.53692203],
        [2.0281402 , 2.5939968 , 1.69671463, ..., 1.89067055,
         1.5956805 , 1.90005372],
        [1.94326061, 2.45165771, 1.45563583, ..., 1.83537716,
         1.43415504, 1.77729499]]])

In [145]:
# Shape of the GPU result for the checked sample: (output channels, group ops, sites).
HostF[sampInd].shape

(1, 48, 512)

In [146]:
# GPU linear sums for the checked sample; these should agree with outSampF.
# NOTE(review): the stored output of this cell differs from the stored outSampF
# output, and the cell execution counts are out of order, so the saved outputs
# may come from different versions of the kernel. Re-run the notebook top to
# bottom before trusting the comparison.
HostF[sampInd]

array([[[4.07899376, 4.29092511, 3.73490141, ..., 4.14735469,
         2.88518304, 3.56469688],
        [4.07899376, 4.17969302, 3.78667675, ..., 4.14735469,
         2.88518304, 3.56469688],
        [4.07899376, 4.02279228, 3.69589603, ..., 4.14735469,
         2.88518304, 3.56469688],
        ...,
        [4.07899376, 4.29092511, 3.73490141, ..., 4.14735469,
         2.88518304, 3.56469688],
        [4.07899376, 4.02279228, 3.69589603, ..., 4.14735469,
         2.88518304, 3.56469688],
        [4.07899376, 3.77545258, 3.72511403, ..., 4.14735469,
         2.88518304, 3.56469688]]])

In [148]:
# Repeat of the shape check on the GPU result for the checked sample.
HostF[sampInd].shape

(1, 48, 512)

In [163]:
# Read the site count back from the device to confirm the sizes were uploaded intact.
# Nsites is element 0 of the packed size array d_Narray; no separate d_Nsites
# buffer is created anywhere in this notebook.
Nsh = d_Narray.copy_to_host()[0]
int(Nsh)

512