In [None]:
import sunpy
import sunpy.map
from sunpy.net import vso

import datetime
from datetime import datetime
from datetime import timedelta

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline

import astropy.units as u
from astropy.io import fits
from astropy.vo.samp import SAMPIntegratedClient

from functools import reduce

In [None]:
class CutoutInfo(object):
    """Rectangular cutout region, expressed in arcseconds.

    Parameters
    ----------
    enabled : whether the cutout should be applied at all.
    x0, y0  : centre of the region (plain numbers, interpreted as arcsec).
    w, h    : width and height of the region (arcsec). The 2458 arcsec
              default presumably spans the whole solar disk - TODO confirm.

    Attributes ``a`` and ``b`` are the [min, max] ranges along x and y,
    in the form ``sunpy.map.Map.submap`` is called with elsewhere.
    """
    def __init__(self, enabled, x0=0, y0=0, w=2458, h=2458):
        arcsec = u.arcsec
        self.enabled = enabled

        # Centre and extent as astropy quantities.
        self.x0, self.y0 = x0 * arcsec, y0 * arcsec
        self.w, self.h = w * arcsec, h * arcsec

        halfWidth = self.w / 2
        halfHeight = self.h / 2
        self.a = u.Quantity([self.x0 - halfWidth, self.x0 + halfWidth])
        self.b = u.Quantity([self.y0 - halfHeight, self.y0 + halfHeight])

In [None]:
class LayerHandler(object):
    """Query, download and plot VSO data for one layer of a SAMP load request.

    Construction builds the instrument-specific VSO search attributes once.
    The ``createQuery*`` methods add a time range, run the query and store
    the result in ``self.query``. ``showData`` downloads the results and
    plots them.
    """

    # Format of the incoming layer timestamp (VSOHandler parses the same format).
    DATE_FORMAT = '%Y/%m/%dT%H:%M:%S.%f'

    # Measurement names translated to VSO "physobs" values; any other
    # measurement string is passed through unchanged.
    _PHYSOBS_ALIASES = {
        'CONTINUUM INTENSITY': 'intensity',
        'FD_Continuum': 'intensity',
        'MAGNETOGRAM': 'LOS_magnetic_field',
        'FD_Magnetogram': 'LOS_magnetic_field',
    }

    def __init__(self, client, timestamp, observatory, instrument, detector, measurement, cutout):
        """Parse the layer description and prepare its VSO search attributes.

        Args:
            client: VSO client; its ``query`` and ``get`` methods are used.
            timestamp: layer time as a string in ``DATE_FORMAT``.
            observatory: observatory / source name (used for SECCHI queries).
            instrument: instrument name; selects how the query is built.
            detector: detector name (used for LASCO and SECCHI queries).
            measurement: wavelength in Angstrom for instruments queried by
                wavelength, otherwise an observable name (see ``_PHYSOBS_ALIASES``).
            cutout: CutoutInfo; when ``cutout.enabled`` the plotted maps are
                cut down to ``cutout.a`` x ``cutout.b``.

        Raises:
            ValueError: if ``timestamp`` does not match ``DATE_FORMAT``, the
                instrument is unsupported, or a wavelength measurement is not
                an integer string.
        """
        self.timestamp = datetime.strptime(timestamp, self.DATE_FORMAT)
        self.observatory = observatory
        self.instrument = instrument
        self.detector = detector
        self.measurement = measurement
        self.cutout = cutout

        self.__attr = self.__createAttr()
        self.__client = client

    def __createAttr(self):
        """Build the instrument-specific VSO attributes (without a time range).

        Raises:
            ValueError: if the instrument is not one this handler knows how to
                query. Previously ``None`` was returned here, which only failed
                later and obscurely when combined with a time range.
        """
        instrument = self.instrument
        if instrument in ('AIA', 'EIT', 'SWAP'):
            return self.__createWaveAttr()
        if instrument == 'XRT':
            # TODO: filter wheel not supported yet, so query by instrument only
            return self.__createNoneAttr()
        if instrument in ('HMI', 'MDI'):
            # TODO: confirm how HMI should be queried
            return self.__createPhysobsAttr()
        if instrument == 'LASCO':
            return self.__createDetectorAttr()
        if instrument == 'SECCHI':
            # EUVI is the only SECCHI detector queried by wavelength as well.
            if self.detector == 'EUVI':
                return self.__createFullAttr()
            return self.__createSecchiAttr()
        if instrument == 'TRACE':
            # TODO: trace
            return self.__createWaveAttr()
        if instrument == 'SXT':
            # TODO: sxt
            return self.__createNoneAttr()
        raise ValueError('Unsupported instrument for VSO query: %r' % (instrument,))

    def __createWaveAttr(self):
        """Instrument plus a single wavelength (``measurement`` in Angstrom)."""
        wave = int(self.measurement)
        return (vso.attrs.Instrument(self.instrument)
                & vso.attrs.Wave(wave * u.AA, wave * u.AA))

    def __createDetectorAttr(self):
        """Instrument plus detector."""
        return (vso.attrs.Instrument(self.instrument)
                & vso.attrs.Detector(self.detector))

    def __createPhysobsAttr(self):
        """Instrument plus physical observable, translating known measurement aliases."""
        # TODO: extend the alias table as more measurements are supported
        obs = self._PHYSOBS_ALIASES.get(self.measurement, self.measurement)
        return (vso.attrs.Instrument(self.instrument)
                & vso.attrs.Physobs(obs))

    def __createSecchiAttr(self):
        """Instrument, source (observatory) and detector."""
        return (vso.attrs.Instrument(self.instrument)
                & vso.attrs.Source(self.observatory)
                & vso.attrs.Detector(self.detector))

    def __createFullAttr(self):
        """Instrument, source, detector and a single wavelength in Angstrom."""
        wave = int(self.measurement)
        return (vso.attrs.Instrument(self.instrument)
                & vso.attrs.Source(self.observatory)
                & vso.attrs.Detector(self.detector)
                & vso.attrs.Wave(wave * u.AA, wave * u.AA))

    def __createNoneAttr(self):
        """Instrument only."""
        return vso.attrs.Instrument(self.instrument)

    def createQuery(self, start, end):
        """Run a VSO query for this layer between ``start`` and ``end``.

        The combined attributes are kept (name-mangled as
        ``_LayerHandler__fullAttr``) because VSOHandler.showQuery reads them.

        Returns:
            The query result, also stored in ``self.query``.
        """
        self.__fullAttr = vso.attrs.Time(start, end) & self.__attr
        self.query = self.__client.query(self.__fullAttr)
        return self.query

    def createQueryDelta(self, deltaFrom=timedelta(minutes = 2), deltaTo=timedelta(seconds=0)):
        """Query the window [timestamp - deltaFrom, timestamp + deltaTo]."""
        self.createQuery(self.timestamp - deltaFrom, self.timestamp + deltaTo)
        return self.query

    def createQuerySingleResult(self):
        """Search for a window ending at ``timestamp`` that yields exactly one record.

        The search starts with a 2-minute window. If that returns several
        records, a 5-second window is tried next. After that, up to 11 further
        queries adjust the window length:

        - While no upper bound (a window with more than one record) is known,
          the window is doubled.
        - Once an upper bound is known, the window is bisected between the
          longest empty window and the shortest over-full window.

        The original code kept doubling after an over-full window had been
        seen, overshooting the known bound and wasting queries.

        Returns:
            The last query result, which may not contain exactly one record
            if the iteration budget runs out.
        """
        # Per the original author: images are usually stored every 3 minutes,
        # but every ~3 seconds when something interesting happens.
        deltaLow = timedelta(0)   # longest window seen returning no records
        deltaHigh = None          # shortest window seen returning >1 records
        deltaCur = timedelta(minutes=2)
        self.createQueryDelta(deltaFrom=deltaCur)

        if len(self.query) > 1:
            deltaHigh = deltaCur
            deltaCur = timedelta(seconds=5)
            self.createQueryDelta(deltaFrom=deltaCur)

        counter = 0
        while len(self.query) != 1 and counter <= 10:
            if len(self.query) < 1:
                deltaLow = deltaCur
            else:
                deltaHigh = deltaCur

            if deltaHigh is None:
                # No over-full window yet: grow geometrically.
                deltaCur = deltaCur * 2
            else:
                # Both bounds known: bisect.
                deltaCur = (deltaLow + deltaHigh) / 2
            self.createQueryDelta(deltaFrom=deltaCur)
            counter += 1
        return self.query

    def showQuery(self):
        """Print the most recent query result."""
        print(self.query)

    def showData(self, path='/Data/{instrument}/{file}.fits'):
        """Download the current query results, load them as sunpy maps and plot them.

        Each downloaded file is loaded into a full map and appended to
        ``self.smap``. When the cutout is enabled, the sub-map is appended to
        ``self.submap`` and plotted instead of the full map.

        Args:
            path: download path template passed to ``client.get``. The default
                keeps the original ``/Data/{instrument}/{file}.fits`` layout.
        """
        self.files = self.__client.get(self.query, path=path).wait()
        self.smap = []
        if self.cutout.enabled:
            self.submap = []

        for file in self.files:
            curMap = sunpy.map.Map(file)
            self.smap.append(curMap)

            if self.cutout.enabled:
                curMap = curMap.submap(self.cutout.a, self.cutout.b)
                self.submap.append(curMap)

            plt.figure(figsize=(15, 15))
            curMap.plot()
            curMap.draw_limb()
            plt.colorbar()
            plt.show()

In [None]:
class VSOHandler(object):
    """Fan VSO query/download operations out over a list of LayerHandler objects."""

    # Format of the start/end/timestamp strings (LayerHandler uses the same format).
    DATE_FORMAT = '%Y/%m/%dT%H:%M:%S.%f'

    def __init__(self, client, start, end, timestamp, cadence, layers):
        """Store the client and layers and parse the time strings.

        Args:
            client: VSO client used by ``showQuery``.
            start, end, timestamp: time strings in ``DATE_FORMAT``.
            cadence: stored as given; not used by any method of this class.
            layers: list of LayerHandler-like objects.

        Raises:
            ValueError: if a time string does not match ``DATE_FORMAT``.
        """
        fmt = self.DATE_FORMAT
        self.client = client
        self.timestamp = datetime.strptime(timestamp, fmt)
        self.start = datetime.strptime(start, fmt)
        self.end = datetime.strptime(end, fmt)
        self.cadence = cadence
        self.layers = layers

    def createQuery(self):
        """Query every layer over [start, end]."""
        for layer in self.layers:
            layer.createQuery(self.start, self.end)

    def createQueryDelta(self, deltaFrom=timedelta(minutes = 2), deltaTo=timedelta(seconds=0)):
        """Query every layer over its own [timestamp - deltaFrom, timestamp + deltaTo]."""
        for layer in self.layers:
            layer.createQueryDelta(deltaFrom, deltaTo)

    def createQuerySingleResult(self):
        """Let every layer search for a window containing a single record."""
        for layer in self.layers:
            layer.createQuerySingleResult()

    def showQuery(self):
        """Run one VSO query that ORs together the last query attributes of all layers.

        Returns:
            The client's query result.

        Raises:
            ValueError: if there are no layers (``reduce`` would otherwise fail
                with an unhelpful TypeError).
            AttributeError: if some layer has not run a ``createQuery*`` method yet.
        """
        if not self.layers:
            raise ValueError('VSOHandler has no layers to query')
        try:
            # LayerHandler keeps its combined attrs in a name-mangled attribute.
            layerAttrs = [layer._LayerHandler__fullAttr for layer in self.layers]
        except AttributeError:
            raise AttributeError('every layer must run a createQuery* method before showQuery()')
        combined = reduce(lambda x, y: x | y, layerAttrs)
        return self.client.query(combined)

    def showData(self):
        """Download and plot the data of every layer."""
        for layer in self.layers:
            layer.showData()

In [None]:
class Receiver(object):
    """SAMP receiver that stores the latest message parameters and builds a VSOHandler from them."""

    def __init__(self, client):
        self.client = client
        # Set by the receive_* callbacks; cleared once createHandler consumes the params.
        self.received = False

    def receive_call(self, private_key, sender_id, msg_id, mtype, params, extra):
        """SAMP call callback: store params, flag receipt and reply with samp.ok."""
        self.params = params
        self.received = True
        self.client.reply(msg_id, {"samp.status": "samp.ok", "samp.result": {}})

    def receive_notification(self, private_key, sender_id, mtype, params, extra):
        """SAMP notification callback: store params and flag receipt (no reply)."""
        self.params = params
        self.received = True

    @staticmethod
    def _parseSampBool(value):
        """Interpret a boolean parameter received over SAMP.

        SAMP transports scalar values as strings, so ``bool("0")`` and
        ``bool("false")`` would both be True. Strings are therefore compared
        against the usual truthy spellings. Non-string values fall back to
        ``bool``.
        """
        if isinstance(value, str):
            return value.strip().lower() in ('1', 'true', 'yes', 'on')
        return bool(value)

    def createHandler(self):
        """Build a VSOHandler from the last received parameters.

        Returns:
            A VSOHandler, or None if no message has been received since the
            last call.
        """
        if not self.received:
            return None

        params = self.params
        client = vso.VSOClient()
        timestamp = params['timestamp']
        start = params['start']
        end = params['end']
        cadence = params['cadence']

        cutoutEnabled = self._parseSampBool(params['cutout.set'])
        if cutoutEnabled:
            cutout = CutoutInfo(cutoutEnabled,
                                float(params['cutout.x0']),
                                float(params['cutout.y0']),
                                float(params['cutout.w']),
                                float(params['cutout.h']))
        else:
            cutout = CutoutInfo(cutoutEnabled)

        # One LayerHandler per entry, sharing the same client and cutout.
        layers = [LayerHandler(client,
                               layerInfo['timestamp'],
                               layerInfo['observatory'],
                               layerInfo['instrument'],
                               layerInfo['detector'],
                               layerInfo['measurement'],
                               cutout)
                  for layerInfo in params['layers']]

        self.received = False
        return VSOHandler(client, start, end, timestamp, cadence, layers)

In [None]:
# Connect to the SAMP hub and route "jhv.vso.load" messages, whether sent as
# calls or as notifications, to a Receiver instance.
JHV_MTYPE = "jhv.vso.load"

client = SAMPIntegratedClient()
client.connect()

r = Receiver(client)
client.bind_receive_call(JHV_MTYPE, r.receive_call)
client.bind_receive_notification(JHV_MTYPE, r.receive_notification)

In [None]:
# Wait for the SAMP hub to deliver a message: check the receiver's flag every
# 0.1 s, then turn the received parameters into a VSOHandler.
import time

POLL_INTERVAL_S = 0.1
while True:
    time.sleep(POLL_INTERVAL_S)
    if not r.received:
        continue
    handler = r.createHandler()
    break

In [None]:
# For every layer, adjust the time window before its timestamp (bounded number
# of attempts) aiming for a VSO query that returns exactly one record.
handler.createQuerySingleResult()

In [None]:
# Re-run the layers' last query attributes as one combined (OR) VSO query and
# display the result. Requires a createQuery* call on the handler first.
handler.showQuery()

In [None]:
# Download every layer's query results and plot each map (the cutout instead
# of the full map when a cutout was requested).
handler.showData()

In [None]:
# Re-plot the first downloaded full map (not the cutout) of one layer.
# PLOT_LAYER_INDEX is 0-based, in the order the layers arrived in the SAMP message.
PLOT_LAYER_INDEX = 2

if PLOT_LAYER_INDEX >= len(handler.layers):
    raise IndexError('PLOT_LAYER_INDEX=%d but only %d layer(s) were loaded'
                     % (PLOT_LAYER_INDEX, len(handler.layers)))
if not handler.layers[PLOT_LAYER_INDEX].smap:
    raise IndexError('layer %d has no downloaded maps; run handler.showData() first '
                     'and check its query returned records' % PLOT_LAYER_INDEX)
curMap = handler.layers[PLOT_LAYER_INDEX].smap[0]

plt.figure(figsize=(15, 15))

curMap.plot()
curMap.draw_limb()

plt.colorbar()
plt.show()