-
Notifications
You must be signed in to change notification settings - Fork 7
/
FitsHduBintable.scala
323 lines (289 loc) · 10.3 KB
/
FitsHduBintable.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
/*
* Copyright 2018 AstroLab Software
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.astrolabsoftware.sparkfits
import scala.util.Try
import org.apache.spark.sql.types.StructField
import com.astrolabsoftware.sparkfits.FitsHdu._
import com.astrolabsoftware.sparkfits.FitsSchema.ReadMyType
/**
 * Contains the class and methods to manipulate Bintable HDUs.
 */
object FitsHduBintable {

  /**
   * Main class for Bintable HDU.
   *
   * @param header : (Array[String])
   *   Header lines of the HDU; parsed with FitsLib.parseHeader.
   * @param selectedColumns : (List[String])
   *   Names of the columns to read, matched case-insensitively against the
   *   TTYPEn values of the header. `null` (the default) selects every column.
   */
  case class BintableHDU(header : Array[String],
      selectedColumns: List[String] = null) extends HDU {

    val keyValues = FitsLib.parseHeader(header)

    // Column names keyed by their header keyword (TTYPEn -> name).
    val colNames: Map[String, String] = extractColNames(keyValues)

    // Columns requested by the user, or every column declared in the header.
    val selectedColNames: List[String] =
      Option(selectedColumns).getOrElse(colNames.values.toList)

    // Zero-based positions of the selected columns, sorted in header order.
    val colPositions: List[Int] =
      selectedColNames.map(name => getColumnPos(keyValues, name)).sorted

    val rowTypes: List[String] = getColTypes(keyValues)
    val ncols: Int = rowTypes.size

    // splitLocations is a list containing the location of elements
    // (byte index) in a row. Example if we have a row with [20A, E, E], one
    // will have splitLocations = [0, 20, 24, 28] that is a string on 20 Bytes,
    // followed by 2 floats on 4 bytes each.
    val splitLocations: List[Int] = rowSplitLocations(rowTypes).scan(0)(_ + _)

    // Indexed copies used by getRow, which runs once per row: List.apply is
    // linear, so indexing the Lists directly made each row decode O(ncols^2).
    private val splitLocationsIndexed: Vector[Int] = splitLocations.toVector
    private val rowTypesIndexed: Vector[String] = rowTypes.toVector

    /** Bintables are implemented */
    override def implemented: Boolean = true

    /**
     * Get the number of row of a HDU.
     * We rely on what's written in the header, meaning
     * here we do not access the data directly.
     *
     * @param keyValues : (Map[String, String])
     *   keyValues from the header of the HDU (see parseHeader).
     * @return (Long), the number of rows as written in KEYWORD=NAXIS2.
     */
    override def getNRows(keyValues : Map[String, String]) : Long = {
      keyValues("NAXIS2").toLong
    }

    /**
     * Get the size (bytes) of each row of a HDU.
     * We rely on what's written in the header, meaning
     * here we do not access the data directly.
     *
     * @param keyValues : (Map[String, String])
     *   keyValues from the header of the HDU (see parseHeader).
     * @return (Int), the size (bytes) of one row as written in KEYWORD=NAXIS1.
     */
    override def getSizeRowBytes(keyValues: Map[String, String]) : Int = {
      keyValues("NAXIS1").toInt
    }

    /**
     * Get the number of columns of a HDU.
     * We rely on what's written in the header, meaning
     * here we do not access the data directly.
     *
     * @param keyValues : (Map[String, String])
     *   keyValues from the header of the HDU (see parseHeader).
     * @return (Long), the number of columns as written in KEYWORD=TFIELDS,
     *   or 0 when the keyword is absent.
     */
    override def getNCols(keyValues : Map[String, String]) : Long = {
      keyValues.get("TFIELDS").map(_.toLong).getOrElse(0L)
    }

    /**
     * Return the types of elements for each column as a list.
     *
     * @param keyValues : (Map[String, String])
     *   keyValues from the header of the HDU (see parseHeader).
     * @return (List[String]), list with the types of elements for each column
     *   as given by the header (TFORM1 .. TFORMn, in order).
     */
    override def getColTypes(keyValues: Map[String, String]): List[String] = {
      (0 until getNCols(keyValues).toInt).map(col => getColType(keyValues, col)).toList
    }

    /**
     * Get the type of the elements of a column with index `colIndex` of a HDU.
     *
     * @param keyValues : (Map[String, String])
     *   The header of the HDU.
     * @param colIndex : (Int)
     *   Index (zero-based) of a column.
     * @return (String), the type (FITS convention) of the elements of the column.
     */
    def getColType(keyValues : Map[String, String], colIndex : Int) : String = {
      // Header keywords are one-based (TFORM1 is the first column).
      FitsLib.shortStringValue(keyValues("TFORM" + (colIndex + 1).toString))
    }

    /**
     * Build a list of StructField from header information.
     * The list of StructField is then used to build the DataFrame schema.
     *
     * @return (List[StructField]) List of StructField with column name,
     *   column type, and whether the column is nullable, one per selected
     *   column in header order.
     */
    override def listOfStruct : List[StructField] = {
      colPositions.map { colIndex =>
        val colName = FitsLib.shortStringValue(
          colNames("TTYPE" + (colIndex + 1).toString))
        ReadMyType(colName, rowTypes(colIndex))
      }
    }

    /**
     * Convert a bintable row from binary to primitives.
     *
     * @param buf : (Array[Byte])
     *   Array of bytes.
     * @return (List[Any]) : Decoded row as a list of primitives, one entry per
     *   selected column in header order.
     */
    override def getRow(buf: Array[Byte]): List[Any] = {
      colPositions.map { col =>
        val start = splitLocationsIndexed(col)
        val end = splitLocationsIndexed(col + 1)
        getElementFromBuffer(buf.slice(start, end), rowTypesIndexed(col))
      }
    }

    /**
     * Description of a row in terms of bytes sizes.
     * rowSplitLocations returns the size (bytes) of each column, starting at
     * column `col`. Example if we have a row with [20A, E, E], one
     * will have rowSplitLocations -> [20, 4, 4], that is a string
     * on 20 Bytes, followed by 2 floats on 4 bytes each. The cumulative sum
     * of this list gives the byte positions stored in splitLocations.
     *
     * @param rowTypes : (List[String])
     *   Types (TFORM) of every column of the row.
     * @param col : (Int)
     *   First column to describe. Should be left at 0.
     * @return (List[Int]), the size (bytes) of each element in a row.
     */
    def rowSplitLocations(rowTypes: List[String], col : Int = 0) : List[Int] = {
      rowTypes.drop(col).map(getSplitLocation)
    }

    /**
     * Companion routine to rowSplitLocations. Returns the size of a column
     * according to its type from the FITS header. More information about handled types can be
     * found Table 18 of https://fits.gsfc.nasa.gov/standard40/fits_standard40aa.pdf
     *
     * The TFORM value has the form rTa: an optional repeat count r (default 1),
     * a type letter T, and optional type-specific trailing characters a
     * (e.g. the substring width in `40A10`, or `(emax)` for variable arrays).
     *
     * @param fitstype : (String)
     *   Element type according to FITS standards (I, J, K, E, D, L, A, etc)
     * @return (Int), the size (bytes) of the element, or 0 (with a warning
     *   printed) when the type is not recognised.
     */
    def getSplitLocation(fitstype : String) : Int = {
      val shortType = FitsLib.shortStringValue(fitstype).trim
      val repeatDigits = shortType.takeWhile(_.isDigit)
      val repeat = if (repeatDigits.isEmpty) 1 else repeatDigits.toInt
      shortType.drop(repeatDigits.length).headOption match {
        // Logical, unsigned byte, and character: 1 byte each.
        case Some('L') | Some('B') | Some('A') => repeat
        // 16-bit integer.
        case Some('I') => 2 * repeat
        // 32-bit integer and single precision float.
        case Some('J') | Some('E') => 4 * repeat
        // 64-bit integer, double, single precision complex, and the 32-bit
        // variable-length array descriptor.
        case Some('K') | Some('D') | Some('C') | Some('P') => 8 * repeat
        // Double precision complex and the 64-bit array descriptor.
        case Some('M') | Some('Q') => 16 * repeat
        // Bit arrays are padded to a whole number of bytes (16X -> 2, 3X -> 1).
        case Some('X') => (repeat + BYTE_SIZE - 1) / BYTE_SIZE
        case _ =>
          println(s"\nFitsHduBintable.getSplitLocation> Cannot infer size of type $shortType\n" +
            "from the header! See com.astrolabsoftware.sparkfits.FitsHduBintable.getSplitLocation\n")
          0
      }
    }

    /**
     * Get the position (zero based) of a column with name `colName` of a HDU.
     * The match is case-insensitive; if several TTYPEn share the name, the
     * lowest n wins.
     *
     * @param keyValues : (Map[String, String])
     *   keyValues from the header of the HDU (see parseHeader).
     * @param colName : (String)
     *   The name of the column
     * @return (Int), position (zero-based) of the column.
     * @throws AssertionError if no TTYPEn keyword holds `colName`.
     */
    def getColumnPos(keyValues : Map[String, String], colName : String) : Int = {
      // One-based indices n of every TTYPEn keyword whose value is colName.
      val matches = extractColNames(keyValues).toList.flatMap {
        case (key, name) if name.equalsIgnoreCase(colName) =>
          Try(key.stripPrefix("TTYPE").toInt).toOption.filter(_ >= 1)
        case _ => None
      }
      if (matches.isEmpty) {
        throw new AssertionError(s"\n$colName is not a valid column name!\n")
      }
      // Zero based
      matches.min - 1
    }

    /**
     * Extract the column names declared in the header (TTYPEn keywords).
     * Values are expected as FITS strings ('NAME    '): the text after the
     * first quote, up to the next one, is kept and trimmed. A value without
     * quotes is used as-is (trimmed) instead of raising an exception.
     *
     * @param keyValues : (Map[String, String])
     *   keyValues from the header of the HDU (see parseHeader).
     * @return (Map[String, String]), TTYPEn keyword -> column name.
     */
    private def extractColNames(keyValues: Map[String, String]): Map[String, String] = {
      keyValues.collect {
        case (key, value) if key.contains("TTYPE") =>
          val parts = value.split("'")
          val name = if (parts.length > 1) parts(1) else value
          (key, name.trim())
      }
    }
  }
}