# Writing data into FITS files with PySpark

Up to version 0.7, spark-fits can read and distribute FITS files, but it cannot write DataFrame data to FITS files on disk. A native writer should come eventually; in the meantime, if you use PySpark for your computation, you can write FITS files in just a few steps.

Note: this notebook presents _workarounds_ or _tricks_ to quickly write DataFrame data into FITS files. They are not meant for production use!

## Initial data set

Let's load a test FITS file from the spark-fits repo. At this stage, we could load data from any file format supported by Spark, since the read and write processes are independent.

In [1]:
df = spark.read.format("fits").option("hdu", 1).load("../../src/test/resources/test_file.fits")

In [2]:
df.show(5)

+----------+---------+--------------------+-----+-----+
|    target|       RA|                 Dec|Index|RunId|
+----------+---------+--------------------+-----+-----+
|NGC0000000| 3.448297| -0.3387486324784641|    0|    1|
|NGC0000001| 4.493667| -1.4414990980543227|    1|    1|
|NGC0000002| 3.787274|  1.3298379564211742|    2|    1|
|NGC0000003| 3.423602|-0.29457151504987844|    3|    1|
|NGC0000004|2.6619017|  1.3957536426732444|    4|    1|
+----------+---------+--------------------+-----+-----+
only showing top 5 rows



Let's force the repartitioning of the data set in case it has only one partition, and add a column with the partition ID.

In [3]:
from pyspark.sql.functions import spark_partition_id

numPart = df.rdd.getNumPartitions()
if numPart == 1:
    # Spread the data over several partitions to mimic a distributed data set
    df = df.repartition(4)
df = df.withColumn("partId", spark_partition_id())

print("Number of partitions: ", df.rdd.getNumPartitions())
df.show()

Number of partitions:  4
+----------+----------+--------------------+-----+-----+------+
|    target|        RA|                 Dec|Index|RunId|partId|
+----------+----------+--------------------+-----+-----+------+
|NGC0001880|  5.261077| -1.2750667383256107| 1880|    1|     0|
|NGC0016891| 1.2029302|  0.8990373524969306|16891|    1|     0|
|NGC0010745| 2.8954773|-0.20057668623844616|10745|    1|     0|
|NGC0008653| 1.5082117|  0.9318336009561894| 8653|    1|     0|
|NGC0006277| 5.1847734|0.023247432847754768| 6277|    1|     0|
|NGC0006331|  2.391221| -0.2515119465585818| 6331|    1|     0|
|NGC0000829| 0.9999453|  1.0383455102037864|  829|    1|     0|
|NGC0012088| 1.6934042| -0.5576700853686516|12088|    1|     0|
|NGC0004790| 4.1513066|  1.2978590130800844| 4790|    1|     0|
|NGC0004439|  5.759273|  0.5185091392821048| 4439|    1|     0|
|NGC0014863| 2.6574259| -1.3711896358852964|14863|    1|     0|
|NGC0007658| 2.5824444| -1.1253206398660374| 7658|    1|     0|
|NGC0007935|0.3

## Writing data with minimal effort

In this first example, the user provides the column names and data types.
Since FITS has its own way of expressing data types (the TFORM codes), the types must be converted before writing the data.
This is done by the `toTFORM` routine.
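The TFORM codes are letters defined by the FITS standard for binary-table columns, and character columns are written `rA`, where the repeat count `r` is the string width (e.g. `10A`). `toTFORM` relies on private astropy helpers (`_dtype_to_recformat`, `_convert_record2fits`) for the lookup. As a sketch of what that mapping amounts to, here is a plain NumPy version; `dtype_to_tform` and `TFORM_CODES` are hypothetical names, and only scalar columns are handled (no vector or variable-length columns).

```python
import numpy as np

# TFORM codes defined by the FITS standard for scalar binary-table columns,
# keyed by NumPy (kind, itemsize in bytes). Signed 8-bit integers have no
# direct TFORM code, so they are deliberately absent.
TFORM_CODES = {
    ("b", 1): "L",   # logical
    ("u", 1): "B",   # unsigned 8-bit integer
    ("i", 2): "I",   # 16-bit integer
    ("i", 4): "J",   # 32-bit integer
    ("i", 8): "K",   # 64-bit integer
    ("f", 4): "E",   # single-precision float
    ("f", 8): "D",   # double-precision float
    ("c", 8): "C",   # single-precision complex
    ("c", 16): "M",  # double-precision complex
}


def dtype_to_tform(dtype):
    """Return the TFORM code of a NumPy dtype (scalar columns only)."""
    dtype = np.dtype(dtype)
    if dtype.kind == "U":
        # NumPy stores unicode strings with 4 bytes per character
        return "{}A".format(dtype.itemsize // 4)
    if dtype.kind == "S":
        # Byte strings: 1 byte per character
        return "{}A".format(dtype.itemsize)
    key = (dtype.kind, dtype.itemsize)
    if key not in TFORM_CODES:
        raise ValueError("No TFORM code for dtype {}".format(dtype))
    return TFORM_CODES[key]


print(dtype_to_tform(np.float32))                      # E
print(dtype_to_tform(np.int16))                        # I
print(dtype_to_tform(np.array(["NGC0000000"]).dtype))  # 10A
```

Working from dtypes rather than from Python values is what allows single-precision columns to stay `E`.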

In [4]:
from astropy.io import fits
import numpy as np
def toTFORM(value):
    """ Simple data type converter.
    
    NOTE: Python has a single float type (64-bit) and a single
    int type, so floats are converted to double (D) and ints to
    long (K, on most platforms) automatically.

    Parameters
    ----------
    value: Any
        Input value from which we want to know the 
        name of the type in the FITS language (TFORM).

    Returns
    ----------
    out : str
        Corresponding TFORM.

    Examples
    ----------
    >>> toTFORM(1)
    'K'
    >>> toTFORM("toto")
    'A4'
    >>> toTFORM(np.int16(10))
    'I'
    >>> toTFORM(3.4)
    'D'
    """
    if isinstance(value, str):
        ft = "A" + str(len(value))
    else:
        # Private astropy helpers: Python type -> NumPy record format -> TFORM
        tt = fits.column._dtype_to_recformat(type(value))[0]
        ft = fits.column._convert_record2fits(tt)
    return ft

def write_from_user(part, colnames, coltypes, fitsname):
    """ Write DataFrame data into FITS file on disk.
    By default, there is one file per partition, 
    and data is written in the HDU 1.
    
    Parameters
    ----------
    part : Iterator
        Iterator containing data and partition ID.
    colnames : List of str
        List containing the names of the columns
    coltypes : List of str
        List containing the data types (FITS language - TFORM)
    fitsname : str
        Name for the output set of files (<blah>.fits). The files will
        then be named part<number>_<blah>.fits.
    """
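    # We assume each row contains the data columns (0, ..., N-1)
    # followed by the partition ID (Nth)
    rows = list(part)
    if not rows:
        # Empty partition: nothing to write
        return
    # One array per column: zip(*rows) lets each column keep its own dtype,
    # whereas np.transpose would coerce mixed-type rows to strings
    data = [np.array(col) for col in zip(*rows)]
    data_fits = data[:-1]
    partId = data[-1][0]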
    
    # Create fake primary HDU
    hdr = fits.Header()
    primary_hdu = fits.PrimaryHDU(header=hdr)
    
    # HDU containing data
    # Loop over columns
    cols_ = []
    for d, k, v in zip(data_fits, colnames, coltypes):
        cols_.append(fits.Column(name=k, format=v, array=d))
    cols = fits.ColDefs(cols_)
    hdu1 = fits.BinTableHDU.from_columns(cols)

    hdul = fits.HDUList([primary_hdu, hdu1])
    
    fnout = "part{}_{}".format(partId, fitsname)
    # overwrite=True lets the cell be re-run without failing on existing files
    hdul.writeto(fnout, overwrite=True)
    
    # Yield a dummy value so that count() has something to count
    yield 0
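
Both writers start by turning the partition iterator (a sequence of `Row` tuples) into per-column arrays. A minimal sketch of that step, runnable with NumPy alone, illustrates two pitfalls: `np.transpose` on mixed-type rows picks one common dtype, so every column becomes a string array, and an empty partition leaves no row to read the partition ID from. Plain tuples stand in for Spark `Row`s, and `rows_to_columns` is a hypothetical helper.

```python
import numpy as np

# Plain tuples standing in for Spark Rows: (target, RA, Dec, Index, RunId, partId)
rows = [
    ("NGC0001880", 5.261077, -1.2750667383256107, 1880, 1, 0),
    ("NGC0016891", 1.2029302, 0.8990373524969306, 16891, 1, 0),
]

# np.transpose first builds a single array, so NumPy picks one common dtype
# for all values: here a unicode string type.
print(np.transpose(rows).dtype.kind)  # U


def rows_to_columns(rows):
    """Split rows into per-column arrays and the partition ID.

    Hypothetical helper; returns None for an empty partition.
    """
    rows = list(rows)
    if not rows:
        return None
    # zip(*rows) gives one tuple per column, so each column keeps its own dtype
    columns = [np.array(col) for col in zip(*rows)]
    return columns[:-1], columns[-1][0]


data_fits, part_id = rows_to_columns(iter(rows))
print([col.dtype.kind for col in data_fits])  # ['U', 'f', 'f', 'i', 'i']
print(part_id)                                 # 0
print(rows_to_columns(iter([])))               # None
```

Returning `None` (or, inside a generator, returning before the first `yield`) simply makes an empty partition produce no file.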

In [5]:
# Name for the output
fitsname = "from_write.fits"

# Names and TFORMs of the data columns (partId, the last column, is excluded)
colnames = df.columns[:-1]
oneRow = df.take(1)[0]
coltypes = [toTFORM(i) for i in oneRow[:-1]]

# Write data on disk. The count() is just here
# to trigger the mapPartitions.
df.rdd.mapPartitions(
    lambda part: write_from_user(part, colnames, coltypes, fitsname)).count()

4

Check that the process went right:

In [6]:
df2 = spark.read.format("fits").option("hdu", 1).load("part*{}".format(fitsname))
print("INPUT: ", df.count(), " elements")
df.drop("partId").orderBy("target").show(5)
df.drop("partId").printSchema()
print("OUTPUT: ", df2.count(), " elements")
df2.orderBy("target").show(5)
df2.printSchema()

INPUT:  20000  elements
+----------+---------+--------------------+-----+-----+
|    target|       RA|                 Dec|Index|RunId|
+----------+---------+--------------------+-----+-----+
|NGC0000000| 3.448297| -0.3387486324784641|    0|    1|
|NGC0000001| 4.493667| -1.4414990980543227|    1|    1|
|NGC0000002| 3.787274|  1.3298379564211742|    2|    1|
|NGC0000003| 3.423602|-0.29457151504987844|    3|    1|
|NGC0000004|2.6619017|  1.3957536426732444|    4|    1|
+----------+---------+--------------------+-----+-----+
only showing top 5 rows

root
 |-- target: string (nullable = true)
 |-- RA: float (nullable = true)
 |-- Dec: double (nullable = true)
 |-- Index: long (nullable = true)
 |-- RunId: integer (nullable = true)

OUTPUT:  20000  elements
+----------+------------------+--------------------+-----+-----+
|    target|                RA|                 Dec|Index|RunId|
+----------+------------------+--------------------+-----+-----+
|NGC0000000|3.4482970237731934| -0.3387486
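
The `RA` values look different in the output because of their type: the input column is single precision (`float` in the schema), while `toTFORM` received Python floats and returned `D`, so the written column is double precision. The values themselves are preserved; the widening from float32 to float64 only exposes more digits. A quick NumPy check with the first `RA` value shown above:

```python
import numpy as np

# First RA value of the input, stored as single precision in the FITS file
ra32 = np.float32(3.448297)

# Spark hands float columns to Python as (64-bit) floats, and the
# TFORM 'D' keeps that double precision when writing
ra64 = float(ra32)

print(ra64)                      # shows the extra float64 digits
print(np.float32(ra64) == ra32)  # True: converting back gives the same float32
```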

### Limitations

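- When using `toTFORM`, floats are automatically cast to double and ints to long, because Python has a single float type and a single int type (compare the `RA` column in the input and output above). This is not handled yet; to avoid it, enter the TFORM codes manually instead of using `toTFORM`.
- `toTFORM` infers the width of string columns from the first row only, so longer strings in other rows may be truncated.
- Data is written to the local file system of the machine that processes each partition, not to the distributed file system (DFS).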
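
Instead of entering the formats by hand, the TFORM codes can also be derived from the DataFrame schema, which keeps single-precision columns as `E`. `df.dtypes` returns `(name, type)` pairs such as `('RA', 'float')`; the sketch below maps the common Spark SQL type names to TFORM codes. `SPARK_TO_TFORM` and `tforms_from_dtypes` are hypothetical names, and string widths must be supplied by the caller.

```python
# Spark SQL type names, as returned by df.dtypes, mapped to FITS TFORM codes.
# 'tinyint' (signed 8-bit) has no direct FITS equivalent and is left out.
SPARK_TO_TFORM = {
    "boolean": "L",
    "smallint": "I",
    "int": "J",
    "bigint": "K",
    "float": "E",
    "double": "D",
}


def tforms_from_dtypes(dtypes, str_widths):
    """Map (name, Spark type) pairs to TFORM codes (hypothetical helper).

    str_widths gives the width to use for each string column.
    """
    tforms = []
    for name, spark_type in dtypes:
        if spark_type == "string":
            # Binary-table character columns are written rA, r = width
            tforms.append("{}A".format(str_widths[name]))
        elif spark_type in SPARK_TO_TFORM:
            tforms.append(SPARK_TO_TFORM[spark_type])
        else:
            raise ValueError(
                "Unsupported type {} for column {}".format(spark_type, name))
    return tforms


# Schema of the test data set (partId excluded), as listed by printSchema above
dtypes = [("target", "string"), ("RA", "float"), ("Dec", "double"),
          ("Index", "bigint"), ("RunId", "int")]
print(tforms_from_dtypes(dtypes, {"target": 10}))  # ['10A', 'E', 'D', 'K', 'J']
```

With a live DataFrame, the pairs would come from `df.drop("partId").dtypes`, and a safe string width from something like `df.select(F.max(F.length("target"))).first()[0]` (with `from pyspark.sql import functions as F`), which also avoids truncating strings longer than the first row's.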

## Writing data from a FITS header

If the structure of your data hasn't changed, you can also pass the input header directly when writing data.
Alternatively, you can define your own header matching the final data to be written, and pass it instead.

In [7]:
from astropy.io import fits

# Grab the input header (copied, so it stays usable after the file is closed)
with fits.open("../../src/test/resources/test_file.fits") as hdul:
    header = hdul[1].header.copy()

In [8]:
import numpy as np

def write_from_header(part, header, fitsname):
    """ Write DataFrame data into FITS file on disk, using FITS header.
    By default, there is one file per partition, 
    and data is written in the HDU 1.
    
    Parameters
    ----------
    part : Iterator
        Iterator containing data and partition ID.
    header : astropy.io.fits.header.Header
        Instance of the FITS header class.
    fitsname : str
        Name for the output set of files (<blah>.fits). The files will
        then be named part<number>_<blah>.fits.
    """
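    # We assume each row contains the data columns (0, ..., N-1)
    # followed by the partition ID (Nth)
    rows = list(part)
    if not rows:
        # Empty partition: nothing to write
        return
    # One array per column: zip(*rows) lets each column keep its own dtype,
    # whereas np.transpose would coerce mixed-type rows to strings
    data = [np.array(col) for col in zip(*rows)]
    data_fits = data[:-1]
    partId = data[-1][0]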
    
    # Create fake primary header
    hdr = fits.Header()
    primary_hdu = fits.PrimaryHDU(header=hdr)
    
    # Grab column names and column data types
    # from the header
    names, types = [], []
    for k, v in header.items():
        if k.startswith("TTYPE"):
            names.append(v)
        if k.startswith("TFORM"):
            types.append(v)

    # HDU containing data
    # Loop over columns
    cols_ = []
    for d, k, v in zip(data_fits, names, types):
        cols_.append(fits.Column(name=k, format=v, array=d))
    cols = fits.ColDefs(cols_)
    hdu1 = fits.BinTableHDU.from_columns(cols)

    hdul = fits.HDUList([primary_hdu, hdu1])
    
    fnout = "part{}_{}".format(partId, fitsname)
    # overwrite=True lets the cell be re-run without failing on existing files
    hdul.writeto(fnout, overwrite=True)
    
    yield 0
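
`write_from_header` collects column names and formats in the order the `TTYPEn`/`TFORMn` keywords appear in the header. A sketch of an alternative indexes them through the mandatory `TFIELDS` keyword (the number of columns), so the result does not depend on keyword layout. `columns_from_header` is a hypothetical helper; it accepts any mapping, so a hand-written dict can stand in for your own header, as below. The dict is an illustrative header for the test data, not a copy of the test file's actual header.

```python
def columns_from_header(header):
    """Return the column names and TFORMs of a binary-table header.

    Hypothetical helper: works with any mapping supporting header["KEY"],
    e.g. an astropy Header or a plain dict, and assumes every column has
    both a TTYPEn and a TFORMn keyword.
    """
    nfields = int(header["TFIELDS"])  # mandatory keyword: number of columns
    names = [header["TTYPE{}".format(i)] for i in range(1, nfields + 1)]
    tforms = [header["TFORM{}".format(i)] for i in range(1, nfields + 1)]
    return names, tforms


# A hand-written header for the test data (plain dict stand-in)
my_header = {
    "XTENSION": "BINTABLE",
    "TFIELDS": 5,
    "TTYPE1": "target", "TFORM1": "10A",
    "TTYPE2": "RA", "TFORM2": "E",
    "TTYPE3": "Dec", "TFORM3": "D",
    "TTYPE4": "Index", "TFORM4": "K",
    "TTYPE5": "RunId", "TFORM5": "J",
}
names, tforms = columns_from_header(my_header)
print(names)   # ['target', 'RA', 'Dec', 'Index', 'RunId']
print(tforms)  # ['10A', 'E', 'D', 'K', 'J']
```

The two returned lists play the same role as `names` and `types` inside `write_from_header`.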

In [9]:
# Name for the output
fitsname = "from_header.fits"

# Write data on disk. The count() is just here
# to trigger the mapPartitions.
df.rdd.mapPartitions(
    lambda part: write_from_header(part, header, fitsname)).count()

4

Check that the process went right:

In [10]:
df2 = spark.read.format("fits").option("hdu", 1).load("part*{}".format(fitsname))
print("INPUT: ", df.count(), " elements")
df.drop("partId").orderBy("target").show(5)
df.drop("partId").printSchema()
print("OUTPUT: ", df2.count(), " elements")
df2.orderBy("target").show(5)
df2.printSchema()

INPUT:  20000  elements
+----------+---------+--------------------+-----+-----+
|    target|       RA|                 Dec|Index|RunId|
+----------+---------+--------------------+-----+-----+
|NGC0000000| 3.448297| -0.3387486324784641|    0|    1|
|NGC0000001| 4.493667| -1.4414990980543227|    1|    1|
|NGC0000002| 3.787274|  1.3298379564211742|    2|    1|
|NGC0000003| 3.423602|-0.29457151504987844|    3|    1|
|NGC0000004|2.6619017|  1.3957536426732444|    4|    1|
+----------+---------+--------------------+-----+-----+
only showing top 5 rows

root
 |-- target: string (nullable = true)
 |-- RA: float (nullable = true)
 |-- Dec: double (nullable = true)
 |-- Index: long (nullable = true)
 |-- RunId: integer (nullable = true)

OUTPUT:  20000  elements
+----------+---------+--------------------+-----+-----+
|    target|       RA|                 Dec|Index|RunId|
+----------+---------+--------------------+-----+-----+
|NGC0000000| 3.448297| -0.3387486324784641|    0|    1|
|NGC0000001| 

### Limitations

- Data is written to the local file system of the machine that processes each partition, not to the distributed file system (DFS).