In [None]:
import sunpy
import sunpy.map
from sunpy.net import Fido, attrs

import datetime
from datetime import datetime
from datetime import timedelta

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline

import astropy.units as u
from astropy.io import fits
from astropy.samp import SAMPIntegratedClient
from astropy.utils.console import ProgressBar
from astropy.coordinates import SkyCoord

from functools import reduce

### Config
You can change the location to which the FITS files are downloaded by changing the value of the following variable.

In [None]:
# Path template passed as `path=` to Fido.fetch() in LayerHandler.downloadData().
# NOTE(review): this is an absolute path; the {...} placeholders are presumably
# expanded per downloaded file by Fido - confirm against the Fido.fetch docs.
downloadPath = '/Data/{source}/{instrument}/{time.start}/{file}'

### Setup
The following cells define the classes used by this notebook as well as the code responsible for connecting to a SAMP hub.

They only need to be executed once, before JHelioviewer sends a SAMP message.

In [None]:
def parseDate(dateString):
    """Convert an ISO-like timestamp string into a ``datetime``.

    Two layouts are accepted: ``YYYY-MM-DDTHH:MM:SS.ffffff`` (fractional
    seconds) and ``YYYY-MM-DDTHH:MM:SS``. The fractional layout is tried
    first.

    Raises:
        ValueError: if the string matches neither layout.
    """
    withFraction = '%Y-%m-%dT%H:%M:%S.%f'
    wholeSeconds = '%Y-%m-%dT%H:%M:%S'

    try:
        return datetime.strptime(dateString, withFraction)
    except ValueError:
        # No fractional part present; fall through to the plain layout.
        pass

    return datetime.strptime(dateString, wholeSeconds)

In [None]:
class CutoutInfo(object):
    """Rectangular cutout region described by its centre and size.

    The constructor converts centre/width/height into the two opposite
    corners ``(x0, y0)`` and ``(x1, y1)``, each stored as a quantity in
    arcseconds. ``enabled`` is stored unchanged and tells consumers whether
    the cutout should be applied.
    """

    def __init__(self, enabled, x_center=0, y_center=0, w=2458, h=2458):
        halfWidth = w/2
        halfHeight = h/2

        self.enabled = enabled
        # Corner with the smaller coordinates
        self.x0 = (x_center - halfWidth) * u.arcsec
        self.y0 = (y_center - halfHeight) * u.arcsec
        # Corner with the larger coordinates
        self.x1 = (x_center + halfWidth) * u.arcsec
        self.y1 = (y_center + halfHeight) * u.arcsec

In [None]:
class LayerHandler(object):
    """Query, download and display the data belonging to one JHelioviewer layer.

    The layer's metadata is translated into search attributes once, at
    construction time. The ``createQuery*`` methods combine those attributes
    with a time range and run the search through ``Fido``.

    Attributes set by the methods:
        query          -- result of the most recent search
        files          -- result of ``Fido.fetch`` (set by downloadData)
        smap           -- list of maps loaded by showData
        submap         -- list of cut-out maps created by showData
                          (stays empty when no cutout is applied)
        dataDownloaded -- False until downloadData ran for the current query
    """

    def __init__(self, timestamp, observatory, instrument, detector, measurement, cutout):
        """Store the layer metadata and build the layer-specific search attributes.

        ``timestamp`` is a string understood by ``parseDate``; ``cutout`` is
        an object providing ``enabled``, ``x0``, ``y0``, ``x1`` and ``y1``.
        """
        self.timestamp = parseDate(timestamp)
        self.observatory = observatory
        self.instrument = instrument
        self.detector = detector
        self.measurement = measurement
        self.cutout = cutout
        self.dataDownloaded = False

        self.__attr = self.__createAttr()

    def __createAttr(self):
        """Select the attribute builder that matches the layer's instrument."""
        if self.instrument in ('AIA', 'EIT', 'SWAP'):
            return self.__createWaveAttr()

        elif self.instrument == 'XRT':
            #TODO: Filter wheel not supported yet
            return self.__createNoneAttr()

        elif self.instrument in ('HMI', 'MDI'):
            return self.__createPhysobsAttr()

        elif self.instrument == 'LASCO':
            return self.__createDetectorAttr()

        elif self.instrument == 'SECCHI' and self.detector == 'EUVI':
            return self.__createFullAttr()

        elif self.instrument == 'SECCHI':
            return self.__createSecchiAttr()

        elif self.instrument == 'TRACE':
            #TODO: trace
            return self.__createWaveAttr()

        elif self.instrument == 'SXT':
            #TODO: sxt
            return self.__createNoneAttr()

        # Unknown instrument: search by instrument name only instead of
        # returning None, which would only fail later when the attributes
        # are combined with the time range.
        return self.__createNoneAttr()

    def __createWaveAttr(self):
        """Instrument plus a single wavelength (``measurement`` read as Angstrom)."""
        wave = int(self.measurement)
        return (attrs.Instrument(self.instrument)
                & attrs.Wavelength(wave * u.AA, wave * u.AA))

    def __createDetectorAttr(self):
        """Instrument plus detector."""
        return (attrs.Instrument(self.instrument)
               & attrs.vso.Detector(self.detector))

    def __createPhysobsAttr(self):
        """Instrument plus physical observable derived from ``measurement``.

        Known measurement names are mapped to 'intensity' or
        'LOS_magnetic_field'; any other value is passed on unchanged.
        """
        obs = self.measurement
        #TODO:
        if obs in ('CONTINUUM INTENSITY', 'FD_Continuum'):
            obs = 'intensity'
        elif obs in ('MAGNETOGRAM', 'FD_Magnetogram', 'FD_Magnetogram_Sum'):
            obs = 'LOS_magnetic_field'
        return (attrs.Instrument(self.instrument)
               & attrs.vso.Physobs(obs))

    def __createSecchiAttr(self):
        """Instrument plus source (the observatory) plus detector."""
        return (attrs.Instrument(self.instrument)
               & attrs.vso.Source(self.observatory)
               & attrs.vso.Detector(self.detector))

    def __createFullAttr(self):
        """Instrument, source, detector and a single wavelength."""
        wave = int(self.measurement)
        return (attrs.Instrument(self.instrument)
               & attrs.vso.Source(self.observatory)
               & attrs.vso.Detector(self.detector)
               & attrs.Wavelength(wave * u.AA, wave * u.AA))

    def __createNoneAttr(self):
        """Instrument only."""
        return attrs.Instrument(self.instrument)

    def createQuery(self, start, end):
        """Search the time range ``start``..``end`` and return the result.

        The result is also stored in ``self.query`` and any previous download
        is marked as outdated.
        """
        self.dataDownloaded = False
        # Kept on the instance because VSOHandler.showQuery reads it.
        self.__fullAttr = attrs.Time(start, end) & self.__attr
        self.query = Fido.search(self.__fullAttr)
        return self.query

    def createQueryDelta(self, deltaFrom=timedelta(minutes = 2), deltaTo=timedelta(seconds=0)):
        """Search from ``timestamp - deltaFrom`` to ``timestamp + deltaTo``."""
        self.createQuery(self.timestamp - deltaFrom, self.timestamp + deltaTo)
        return self.query

    def __queryLength(self):
        """Total number of records over all elements of ``self.query``."""
        return sum(len(elem) for elem in self.query)

    def createQuerySingleResult(self):
        """Shrink or grow the look-back window until the query has exactly one record.

        The search starts with a two minute window. The window is doubled
        while nothing is found and bisected while more than one record is
        found. After at most 11 refinement rounds the last query is returned,
        even if it does not contain exactly one record.
        """
        deltaFromCur = timedelta(minutes = 2)
        deltaFromMin = timedelta(seconds = 0)

        # images are usually stored every 3 minutes
        # unless something interesting happens, then it is every 3 seconds
        self.createQueryDelta(deltaFrom = deltaFromCur)

        if self.__queryLength() > 1:
            self.createQueryDelta(deltaFrom = timedelta(seconds=5))
            # Count records, not result blocks: a single block may still hold
            # several records, in which case the 5 s window must become the
            # new upper bound of the search.
            if self.__queryLength() != 1:
                deltaFromCur = timedelta(seconds=5)

        counter = 0
        while self.__queryLength() != 1 and counter <= 10:

            if self.__queryLength() < 1:
                # Nothing found: remember the window as lower bound and widen it
                deltaFromMin = deltaFromCur
                deltaFromCur = deltaFromCur * 2
            else:
                # Too many records: bisect between lower bound and current window
                deltaFromCur = (deltaFromMin + deltaFromCur)/2
            self.createQueryDelta(deltaFrom = deltaFromCur)
            counter = counter+1
        return self.query

    def createQuerySample(self, start, end, cadence):
        """Search ``start``..``end`` with the results sampled at ``cadence``."""
        self.dataDownloaded = False
        self.__fullAttr = attrs.Time(start, end) & self.__attr & attrs.vso.Sample(cadence)
        self.query = Fido.search(self.__fullAttr)
        return self.query

    def showQuery(self):
        """Print the result of the most recent search."""
        print(self.query)

    def downloadData(self):
        """Fetch the files of the current query into ``downloadPath``."""
        self.files = Fido.fetch(self.query, path=downloadPath)
        self.dataDownloaded = True

    def showData(self, figsize=(15,15)):
        """Load every downloaded file as a map and plot it.

        Data is downloaded first if that has not happened for the current
        query. Files that cannot be read are reported and skipped. When the
        cutout is enabled (and the instrument is neither XRT nor LASCO) the
        cut-out map is plotted instead of the full map; if creating the cutout
        fails, the full map is plotted.
        """
        if not self.dataDownloaded:
            self.downloadData()

        self.smap = list()
        # Always present so that callers can test it for emptiness.
        self.submap = list()

        for file in self.files:
            try:
                curMap = sunpy.map.Map(file)
            except Exception as err:
                print("File could not be read as sunpy.map.Map(file)")
                print(file)
                print(err)
            else:
                self.smap.append(curMap)

                #XRT and LASCO do not seem to work with cutout!
                if(self.cutout.enabled and self.instrument != 'XRT' and self.instrument != 'LASCO'):
                    try:
                        p0 = SkyCoord(self.cutout.x0, self.cutout.y0, frame=curMap.coordinate_frame)
                        p1 = SkyCoord(self.cutout.x1, self.cutout.y1, frame=curMap.coordinate_frame)
                        curMap = curMap.submap(p0, p1)
                        self.submap.append(curMap)
                    except Exception as err:
                        # self.instrument, not self.__instrument: the latter is
                        # name-mangled to an attribute that does not exist.
                        print("Creating cutout failed for the instrument '" + self.instrument + "'")
                        print(err)

                plt.figure(figsize=figsize)

                curMap.plot()
                curMap.draw_limb()

                plt.colorbar()
                plt.show()

In [None]:
class VSOHandler(object):
    """Applies query, download and display operations to a list of layers.

    ``start``, ``end`` and ``timestamp`` are strings understood by
    ``parseDate`` and are stored as parsed values. ``cadence`` is stored
    as given and ``layers`` is the list of per-layer handlers that every
    method delegates to.
    """

    def __init__(self, start, end, timestamp, cadence, layers):
        self.layers    = layers
        self.cadence   = cadence
        self.timestamp = parseDate(timestamp)
        self.start     = parseDate(start)
        self.end       = parseDate(end)

    def _forEachLayer(self, action):
        """Call ``action(layer)`` for every layer while advancing a progress bar."""
        with ProgressBar(len(self.layers), True) as bar:
            for layer in self.layers:
                action(layer)
                bar.update()

    def createQuery(self):
        """Query the whole ``start``..``end`` range for every layer."""
        self._forEachLayer(lambda layer: layer.createQuery(self.start, self.end))

    def createQueryDelta(self, deltaFrom=timedelta(minutes=2), deltaTo=timedelta(seconds=0)):
        """Query a window around each layer's own timestamp."""
        self._forEachLayer(lambda layer: layer.createQueryDelta(deltaFrom, deltaTo))

    def createQuerySingleResult(self):
        """Let every layer narrow its query down to a single record."""
        self._forEachLayer(lambda layer: layer.createQuerySingleResult())

    def createQuerySample(self, cadence=None):
        """Query ``start``..``end`` for every layer, sampled at ``cadence``.

        Without an explicit ``cadence`` the stored one is used, divided by
        1000 and expressed in seconds.
        """
        if cadence is None:
            # presumably self.cadence is given in milliseconds - TODO confirm
            cadence = (int(self.cadence)/1000)*u.second

        self._forEachLayer(
            lambda layer: layer.createQuerySample(self.start, self.end, cadence))

    def showQuery(self):
        """Run one search over the OR-combination of all layers' last query attributes.

        Every layer must have created a query before this is called.
        """
        perLayer = [layer._LayerHandler__fullAttr for layer in self.layers]
        combined = reduce(lambda left, right: left|right, perLayer)
        return Fido.search(combined)

    def downloadData(self):
        """Download the data of every layer's current query."""
        self._forEachLayer(lambda layer: layer.downloadData())

    def showData(self, figsize=(15, 15)):
        """Download (if necessary) and plot the data of every layer."""
        self._forEachLayer(lambda layer: layer.showData(figsize))

In [None]:
class Receiver(object):
    """Stores the parameters of the latest SAMP message and turns them into a handler.

    ``receive_call`` and ``receive_notification`` are meant to be bound to a
    SAMP client; both remember the message parameters and raise the
    ``received`` flag. ``createHandler`` consumes the stored parameters.
    """

    def __init__(self, client):
        self.client = client
        self.received = False
        # Parameters of the most recent message; None until one arrives
        self.params = None

    @staticmethod
    def _parseSampBool(value):
        """Interpret a SAMP-encoded boolean.

        SAMP transports values as strings, so ``bool(value)`` would treat
        "0" and "false" as true. Empty strings, "0", "false" and "no"
        (case-insensitive) are false; any other string is true. Non-string
        values fall back to ``bool(value)``.
        """
        if isinstance(value, str):
            return value.strip().lower() not in ('', '0', 'false', 'no')
        return bool(value)

    def receive_call(self, private_key, sender_id, msg_id, mtype, params, extra):
        """Store the parameters of a call and acknowledge it to the sender."""
        self.params = params
        self.received = True
        self.client.reply(msg_id, {"samp.status": "samp.ok", "samp.result": {}})

    def receive_notification(self, private_key, sender_id, mtype, params, extra):
        """Store the parameters of a notification (no reply is sent)."""
        self.params = params
        self.received = True

    def createHandler(self):
        """Build a ``VSOHandler`` from the stored message parameters.

        Returns:
            The new handler, or ``None`` if no unprocessed message is
            available. The ``received`` flag is cleared when a message is
            consumed.
        """
        if not self.received:
            return None

        params = self.params
        self.received = False

        timestamp = params['timestamp']
        start = params['start']
        end = params['end']
        cadence = params['cadence']

        cutoutEnabled = self._parseSampBool(params['cutout.set'])
        if cutoutEnabled:
            cutout = CutoutInfo(cutoutEnabled,
                                float(params['cutout.x0']),
                                float(params['cutout.y0']),
                                float(params['cutout.w']),
                                float(params['cutout.h']))
        else:
            cutout = CutoutInfo(cutoutEnabled)

        # One LayerHandler per layer; all of them share the same cutout
        layers = [
            LayerHandler(layerInfo['timestamp'],
                         layerInfo['observatory'],
                         layerInfo['instrument'],
                         layerInfo['detector'],
                         layerInfo['measurement'],
                         cutout)
            for layerInfo in params['layers']
        ]

        return VSOHandler(start, end, timestamp, cadence, layers)

The next cell connects to an active Hub (see http://docs.astropy.org/en/stable/vo/samp/example_table_image.html)

In [None]:
# Create a SAMP client and connect it to the hub
client = SAMPIntegratedClient()
client.connect()

# Route both calls and notifications of type "jhv.vso.load" to the receiver,
# which stores the message parameters until createHandler() is called.
r = Receiver(client)
client.bind_receive_call("jhv.vso.load", r.receive_call)
client.bind_receive_notification("jhv.vso.load", r.receive_notification)

### Downloading the data
When a message is sent over SAMP, run the following to create a new Handler with the parameters given over SAMP.

A Handler object named handler will be created which offers several utility methods to create a VSOQuery and download the data.
* __handler.createQuerySingleResult()__   
  Creates a VSO-Query to only retrieve the currently shown images in JHelioviewer for each active Layer
* __handler.createQuery()__   
  Creates a VSO-Query for the complete timespan visible in JHelioviewer
* __handler.createQuerySample(cadence)__   
  Creates a VSO-Query for the complete timespan visible in JHelioviewer, but only provides an image every *cadence* seconds. If *cadence* is not provided, the same cadence as in JHelioviewer is used.
  
* __handler.showQuery()__   
  Can be used to check the query result before downloading the actual data
* __handler.downloadData()__   
  Downloads the data found by the current query
* __handler.showData(figsize)__   
  Downloads and shows the data by the current query

The start time, end time and timestamp can also be read from the handler.

In [None]:
# Poll the receiver every 0.1s until the hub has delivered a message,
# then build the handler from the received parameters.
import time

time.sleep(0.1)
while not r.received:
    time.sleep(0.1)
handler = r.createHandler()

In [None]:
# Show the time range and the current timestamp received from JHelioviewer
print("Start:\t\t", handler.start)
print("Timestamp:\t", handler.timestamp)
print("End:\t\t", handler.end)

First, we create a query that fetches only the currently shown image for each layer, and check the results afterwards.

In [None]:
# Alternative: query the whole time range, sampled at the JHelioviewer cadence
#handler.createQuerySample()
# Narrow each layer's query down to a single record near its timestamp
handler.createQuerySingleResult()

In [None]:
# Runs one search over the combined attributes of all layers and displays the result
handler.showQuery()

The data can then be downloaded and displayed with showData().

In [None]:
# Alternative: only download the files without plotting them
#handler.downloadData()
# showData() downloads the files first if that has not happened yet
handler.showData(figsize=(7, 7))

### Using the data
For each layer that was visible in JHelioviewer, a corresponding layer is created on the handler. The layers can be accessed individually and provide the same methods as the handler. Additionally, the actual data as well as the metadata are accessible.

The metadata can easily be accessed for each layer

In [None]:
# Inspect the metadata of the first layer
layer = handler.layers[0]
print("Timestamp:\t", layer.timestamp)
print("Observatory:\t", layer.observatory)
print("Instrument:\t", layer.instrument)
print("Detector:\t", layer.detector)
print("Measurement:\t", layer.measurement)

The data itself for each layer is available as a sunpy.Map in a list named smap. If a cutout was applied, the cut-out maps are available in a list named submap.

In [None]:
# smap and submap are filled by showData(); submap stays empty (or is not
# created at all) when no cutout was applied, so guard both lookups instead
# of indexing blindly.
layerMaps = getattr(layer, 'smap', None)
layerSubmaps = getattr(layer, 'submap', None)

print("Map:")
if layerMaps:
    print(layerMaps[0])
else:
    print("No map available (showData() has not loaded any file)")

print("Submap:")
if layerSubmaps:
    print(layerSubmaps[0])
else:
    print("No submap available (no cutout was applied)")