In [1]:
import pandas as pd
import numpy as np
import os
from PIL import Image, ImageDraw
from matplotlib import pyplot as plt

from astropy.io import fits
from astropy.wcs import WCS
from photutils import detect_threshold
from astropy.stats import gaussian_fwhm_to_sigma
from photutils import detect_sources
from astropy.convolution import Gaussian2DKernel

from pyspark import SparkContext
from pyspark.sql import SparkSession

import os, fnmatch
import warnings
import shutil



In [2]:
try:
    sc = SparkContext(appName="SDDM") #, master='spark://fs.dslc.liacs.nl:7078')
except ValueError:
    warnings.warn("SparkContext already exists in this scope")
    

In [3]:
IMAGE_PATH = '/data/astronomy-big-data/bc96e9620e41b9aba98292d37b5eec24/LoTSS_DR2_mosaic/' # './images_upper'
WRITE_PATH = './mosaic_images'

In [5]:
# Take two fits files
fits_files = [IMAGE_PATH + f for f in os.listdir(IMAGE_PATH) if not os.path.isfile(os.path.join(WRITE_PATH, f.split('_')[0]) + '.png')]

In [10]:
# Put the paths in an RDD and determine number of partitions
# More partitions == more cpu and faster (can crash when partitions are too high)
file_paths = sc.parallelize(fits_files, 64) # , len(fits_files)
print(len(file_paths.collect()))

54


In [11]:
# Map the RDD with paths to the fits.open() function to get the content in the .fits
fits_content = file_paths.map(lambda file: fits.open(file)[0], 64)
fits_content.getNumPartitions()

64

In [12]:
# Save pair of fits content: threshold matrix in a RDD.
fits_thresh = fits_content.map(lambda content: (content, detect_threshold(content.data , nsigma=3.)))

In [13]:
sigma = 3.0 * gaussian_fwhm_to_sigma  # FWHM = 3.

kernel = Gaussian2DKernel(sigma, x_size=3, y_size=3)
kernel.normalize()

# sources = object with label of sources for each pixel in the fits file
# Save pair of fits content: source matrix in a RDD.
fits_sources = fits_thresh.map(lambda ft: (ft[0].header,
                                           detect_sources(ft[0].data, ft[1], npixels=16, filter_kernel=kernel),
                                           )
                                     )
fits_sources.getNumPartitions()

64

In [14]:
def getImage(header, sources):    
    cmap = sources.make_cmap(random_state=12345)
    x = header['NAXIS1']
    y = header['NAXIS2']

    fig = plt.figure(frameon=False)
    fig.set_size_inches(x/1000,y/1000)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(sources, origin='upper', cmap=cmap)
    fig.savefig('%s/%s.png'%(WRITE_PATH,header["OBJECT"]), dpi=1000)
    return (header["OBJECT"], True)

In [15]:
images = fits_sources.map(lambda x: getImage(x[0], x[1])).collect()