### To do list...

- [x] Implement Direct Imager utilising cupy for single-precision imaging
- [ ] Modify the implementation to support double or single toggling
- [ ] Modify the implementation to support shared memory (for efficiency)
- [ ] Modify the implementation to support batched imaging (i.e., produce multiple images, each built from a subset of the full observation's timesteps)

In [None]:
%matplotlib inline

import numpy as np
import cupy as cp
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = [10, 10]

#==============================================================#

def show_image(image, title, flip_x_axis=False):
    """Render a 2D array as a grayscale plot with a title and a colour bar.

    Args:
        image: Array handed to ``plt.imshow``.
        title: Text placed above the plot.
        flip_x_axis: When true, the array is mirrored left-to-right with
            ``np.fliplr`` before it is drawn.
    """
    to_draw = np.fliplr(image) if flip_x_axis else image
    grayscale = plt.get_cmap("gray")
    plt.imshow(to_draw, cmap=grayscale)
    plt.title(title)
    plt.colorbar()
    plt.show()
    
#==============================================================#

def normalise(data):
    """Linearly rescale ``data`` so its minimum maps to 0 and its maximum to 1.

    Args:
        data: Array-like of numbers (anything ``np.min``/``np.max`` accept).

    Returns:
        An array of the same shape as ``data`` holding
        ``(data - min) / (max - min)``. When every element is equal (so
        ``max - min`` is zero) an all-zero array is returned instead of the
        NaNs a 0/0 division would produce.
    """
    lowest = np.min(data)
    span = np.max(data) - lowest
    shifted = data - lowest
    if span == 0:
        # `shifted` is already all zeros here; dividing by 1 only promotes it
        # to the float dtype that the regular division below would yield.
        return np.true_divide(shifted, 1)
    return shifted / span

#==============================================================#

# CUDA kernel, wrapped as a CuPy RawKernel, that forms an image from
# visibilities by direct summation with one GPU thread per image pixel.
#
# Per thread:
#   * pixel_index is the global thread index; threads with
#     pixel_index >= image_size * image_size return without doing any work.
#   * (x, y) is the pixel's column/row offset from image_size / 2, scaled by
#     cell_size_rads.
#   * image_correction = sqrt(1 - x^2 - y^2); w_correction = image_correction - 1.
#   * For each of the num_uvw_coords * num_channels visibilities (index v):
#     the uvw coordinate is selected with v / num_channels and the channel with
#     v % num_channels; the coordinate is multiplied by
#     (frequency_hz_start + bandwidth_increment * channel) / speed_of_light;
#     theta = 2 * pi * (x * u + y * v + w_correction * w); and the real part of
#     vis * (cos(theta) + i * sin(theta)) is added to a running sum.
#   * sum * image_correction is accumulated (+=) into image[pixel_index], so
#     the image buffer should be zeroed by the caller unless accumulation
#     across launches is intended.
#
# Every floating-point parameter is declared `float` and every count is
# declared `unsigned int`; launch arguments need to match those C types.
direct_imaging_kernel = cp.RawKernel(r'''
extern "C" __global__
void direct_imaging_with_w_correction(
    float *image, 
    const float *vis_real,
    const float *vis_imag,
    const float *u_coord,
    const float *v_coord,
    const float *w_coord,
    const unsigned int num_uvw_coords, 
    const unsigned int num_channels,
    const unsigned int image_size,
    const float cell_size_rads, 
    const float frequency_hz_start, 
    const float bandwidth_increment,
    const float pi,
    const float speed_of_light
)
{
    unsigned int pixel_index = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int total_vis = num_uvw_coords * num_channels;
    if(pixel_index >= image_size*image_size)
        return;

    float x = ((int)(pixel_index % image_size) - (int)image_size/2) * cell_size_rads;
    float y = ((int)(pixel_index / image_size) - (int)image_size/2) * cell_size_rads;

    float image_correction = sqrt(1.0f - (x * x) - (y * y));
    float w_correction = image_correction - 1.0f;

    float sum = 0.0f;

    for(unsigned int v = 0; v < total_vis; v++)
    {	
        float2 current_vis = make_float2(vis_real[v], vis_imag[v]);
        unsigned int current_baseline = v / num_channels;
        unsigned int current_channel = v % num_channels;
        float metres_to_wavelength = (frequency_hz_start + (bandwidth_increment * current_channel)) / speed_of_light;
        float3 uvw_coord = make_float3(u_coord[current_baseline], v_coord[current_baseline], w_coord[current_baseline]);
        uvw_coord.x *= metres_to_wavelength;
        uvw_coord.y *= metres_to_wavelength;
        uvw_coord.z *= metres_to_wavelength;
        
        float2 theta_complex = make_float2(0.0, 0.0);
        float theta = 2.0f * pi * (x * uvw_coord.x + y * uvw_coord.y + w_correction * uvw_coord.z);
        sincos(theta, &(theta_complex.y), &(theta_complex.x));
        // accumulate real complex product of vis and theta to running pixel sum
        sum += current_vis.x * theta_complex.x - current_vis.y * theta_complex.y;
    }

    image[pixel_index] += (sum * image_correction);
}''', "direct_imaging_with_w_correction")

#==============================================================#

In [None]:
# Configurable params
dataset_folder = "../datasets/gleam_small/"
vis_intensity_file = f"{dataset_folder}gleam_small_ts_bl_ch.vis"
vis_uvw_file = f"{dataset_folder}gleam_small_ts_bl_ch.uvw"
image_size = 100  # the image is image_size x image_size pixels
fov_degrees = 1.0
num_timesteps = 30
num_chans = 1
num_recv = 512
freq_start_hz = 140000000
freq_bandwidth = 0
right_ascension = True  # flip u and w coords
timesteps_per_image = 5  # number of timesteps to use per generated image

# Calculated params
num_baselines = num_recv * (num_recv - 1) // 2  # unordered receiver pairs
num_uvw_coords = num_timesteps * num_baselines
num_visibilities = num_chans * num_uvw_coords
half_fov_radians = 0.5 * fov_degrees * np.pi / 180.0
image_cell_size_radians = np.arcsin(2.0 * np.sin(half_fov_radians) / image_size)
# Integer division: any timesteps left over after the last full group are unused.
num_images_to_generate = num_timesteps // timesteps_per_image

# GPU work distribution: one thread per pixel, at most 1024 threads per block,
# and enough blocks to cover every pixel.
num_pixels = image_size**2
max_threads_per_block = np.minimum(1024, num_pixels)
num_blocks = np.int32(np.ceil(num_pixels / max_threads_per_block))
kernel_threads = (max_threads_per_block, 1, 1)
kernel_blocks = (num_blocks, 1, 1)

In [None]:
# Data preparation
# Both input files are read as: one int32 record count, then float32 values
# starting right after that count.
header_bytes = np.dtype(np.int32).itemsize

vis_count_from_file = np.fromfile(vis_intensity_file, dtype=np.int32, count=1)[0]
uvw_count_from_file = np.fromfile(vis_uvw_file, dtype=np.int32, count=1)[0]

# Two float32 values per visibility -> array of shape (vis_count_from_file, 2).
visibilities_host = np.fromfile(
    vis_intensity_file,
    dtype=np.float32,
    count=vis_count_from_file * 2,
    offset=header_bytes,
).reshape(vis_count_from_file, 2)

# Three float32 values per coordinate -> array of shape (uvw_count_from_file, 3).
uvw_coords_host = np.fromfile(
    vis_uvw_file,
    dtype=np.float32,
    count=uvw_count_from_file * 3,
    offset=header_bytes,
).reshape(uvw_count_from_file, 3)
# print(uvw_coords_host[0])

In [None]:
# Batched imaging: generate one image per group of `timesteps_per_image`
# consecutive timesteps and display each image as soon as it is computed.
vis_per_timestep = num_baselines * num_chans
# One uvw coordinate per baseline per timestep. The kernel looks a coordinate
# up with (visibility index / num_channels), so all channels of a baseline
# share it and the count must not be multiplied by num_chans.
uvw_per_timestep = num_baselines

print(f"Vis per timestep: {vis_per_timestep}")

vis_per_image = timesteps_per_image * vis_per_timestep
uvw_per_image = timesteps_per_image * uvw_per_timestep

for batch in range(num_images_to_generate):

    timestep_range_start = batch * timesteps_per_image
    timestep_range_end = timestep_range_start + timesteps_per_image

    # Row ranges of this batch inside the host arrays.
    vis_batch_start = timestep_range_start * vis_per_timestep
    vis_batch_end = vis_batch_start + vis_per_image
    uvw_batch_start = timestep_range_start * uvw_per_timestep
    uvw_batch_end = uvw_batch_start + uvw_per_image

    # Device preparation: the kernel takes real/imag and u/v/w as separate
    # arrays, so each column is uploaded on its own.
    vis_real_gpu = cp.asarray(visibilities_host[vis_batch_start:vis_batch_end, 0])
    vis_imag_gpu = cp.asarray(visibilities_host[vis_batch_start:vis_batch_end, 1])

    u_coords_gpu = cp.asarray(uvw_coords_host[uvw_batch_start:uvw_batch_end, 0])
    v_coords_gpu = cp.asarray(uvw_coords_host[uvw_batch_start:uvw_batch_end, 1])
    w_coords_gpu = cp.asarray(uvw_coords_host[uvw_batch_start:uvw_batch_end, 2])

    if right_ascension:
        u_coords_gpu *= -1.0
        w_coords_gpu *= -1.0

    # The kernel accumulates into the image with +=, so start from zeros.
    image_gpu = cp.zeros((image_size, image_size), dtype=np.float32)

    # Kernel execution.
    # Scalars are wrapped in explicit NumPy types matching the kernel
    # signature (`unsigned int` / `float`): CuPy passes plain Python ints as
    # 64-bit integers, which does not match a 32-bit `unsigned int` parameter.
    direct_imaging_kernel(kernel_blocks, kernel_threads, (
        image_gpu,
        vis_real_gpu,
        vis_imag_gpu,
        u_coords_gpu,
        v_coords_gpu,
        w_coords_gpu,
        np.uint32(uvw_per_image),
        np.uint32(num_chans),
        np.uint32(image_size),
        np.float32(image_cell_size_radians),
        np.float32(freq_start_hz),
        np.float32(freq_bandwidth),
        np.float32(np.pi),
        np.float32(299792458.0)
    ))

    # Obtain generated image
    image_host = cp.asnumpy(image_gpu)
    show_image(image_host, f"Direct Image (timesteps {timestep_range_start} to {timestep_range_end - 1})")

In [None]:
# Comparing against all timesteps IDFT from C implementation
# NOTE(review): `image_host` holds only the image produced by the final batch
# of the imaging loop, while the reference file name suggests it covers
# timesteps 0-29. The difference below is like-for-like only when
# timesteps_per_image == num_timesteps -- confirm this is the intended setup.
reference = np.fromfile("../data/direct_image_ts_0_29.bin", dtype=np.float32)
# Use the configured image size rather than a hard-coded 100 so that a size
# mismatch with the reference file fails here, at the reshape.
reference = reference.reshape(image_size, image_size)
show_image(reference, "Reference from C")

show_image(np.absolute(reference - image_host), "Abs diff")

In [None]:
# CuPy clean-up: ask the default device memory pool to free its blocks.
# Per the original author's note, the memory won't be fully deallocated until
# the Python instance itself has exited.
mempool = cp.get_default_memory_pool()
mempool.free_all_blocks()