In [None]:
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

%matplotlib inline

In [None]:
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.wcs import WCS
from astropy.utils.data import download_file
from astropy.visualization import MinMaxInterval, AsinhStretch, SqrtStretch, ImageNormalize
import astropy.units as u

In [None]:
import io
import os
from getpass import getpass

import requests

## An optional helper class to manage interaction with the API

In [None]:
# NOTE(review): the default server is plain HTTP, so the password sent by authorize()
# and every bearer token travel unencrypted; switch to https:// if the server supports it.
BASE_URL = 'http://alws.arizona.edu/api'

# Helper class for accessing the ALWS API
class AlwsApi(object):
    """Thin wrapper around ``requests`` for the ALWS REST API.

    Call :meth:`authorize` with valid credentials; afterwards :meth:`get` and
    :meth:`post` attach the resulting JWT as a Bearer ``Authorization`` header,
    and :attr:`auth_header` can be handed to other HTTP clients
    (e.g. ``astropy.utils.data.download_file``).
    """

    def __init__(self, base_url=BASE_URL):
        """Create a client for the API rooted at ``base_url``."""
        self.base_url = base_url
        # Filled in by authorize(); initialized here so calls made before
        # authorizing do not fail with AttributeError.
        self.token_pair = None   # JSON body returned by auth/token/create/
        self.auth_token = None   # "Bearer <access token>", or None before authorizing
        self.auth_header = {}    # {'Authorization': auth_token} once authorized


    def compose(self, url):
        """Return the absolute URL for the API-relative path ``url``.

        Exactly one '/' separates the base URL from the path, so ``'imgmd/'``
        and ``'/imgmd/'`` resolve to the same endpoint.
        """
        return f"{self.base_url.rstrip('/')}/{url.lstrip('/')}"


    def _with_auth(self, kwargs):
        """Return a copy of request ``kwargs`` with the Authorization header added.

        The caller's ``headers`` dict is copied, never mutated. Before a token
        has been obtained, no Authorization header is added.
        """
        request_kwargs = dict(kwargs)
        headers = dict(request_kwargs.get('headers') or {})
        if self.auth_token:
            headers['Authorization'] = self.auth_token
        request_kwargs['headers'] = headers
        return request_kwargs


    def get(self, url, params=None, **kwargs):
        """Send a GET to the API path ``url`` and return the ``requests`` response.

        ``params`` becomes the query string; any other ``kwargs`` are forwarded
        to ``requests.get``.
        """
        return requests.get(self.compose(url), params=params, **self._with_auth(kwargs))


    def post(self, url, data=None, json=None, **kwargs):
        """Send a POST to the API path ``url`` and return the ``requests`` response.

        ``data`` / ``json`` form the request body; any other ``kwargs`` are
        forwarded to ``requests.post``.
        """
        return requests.post(self.compose(url), data=data, json=json, **self._with_auth(kwargs))


    def authorize(self, email, password):
        """Request a JWT for ``email`` / ``password`` and store it on this instance.

        On HTTP 200 the token response is saved in :attr:`token_pair` and the
        Bearer token in :attr:`auth_token` and :attr:`auth_header`; on any other
        status the stored token is left unchanged. The response is always
        returned so the caller can inspect failures.
        """
        resp = requests.post(self.compose('auth/token/create/'), data={'email': email, 'password': password})
        if resp.status_code == 200:
            self.token_pair = resp.json()
            self.auth_token = f"Bearer {self.token_pair.get('access')}"
            self.auth_header = {'Authorization': self.auth_token}
        return resp

#### Create an instance of the API helper class (uses the default server base url)

In [None]:
# A single shared client for the rest of the notebook, pointed at the default server.
api = AlwsApi(base_url=BASE_URL)

#### Authorize a user and store the resulting JWT access token in the helper-class instance

In [None]:
# Credentials come from the ALWS_EMAIL / ALWS_PASSWORD environment variables when set,
# otherwise they are prompted for, so they are never stored in the notebook itself.
alws_email = os.environ.get('ALWS_EMAIL') or input('ALWS email: ')
alws_password = os.environ.get('ALWS_PASSWORD') or getpass('ALWS password: ')
resp = api.authorize(alws_email, alws_password)
# Fail here rather than in a later cell (e.g. a missing/empty auth header) if the login was rejected.
if resp.status_code != 200:
    raise RuntimeError(f"ALWS authorization failed (HTTP {resp.status_code}): {resp.text}")

## Load and plot an entire image from the Astrolabe image server

#### First, let's see what images are available from the image server:

In [None]:
resp = api.get('imgmd/')
# Surface HTTP errors directly; otherwise an error payload would make the
# comprehension below fail with an unrelated AttributeError/TypeError.
resp.raise_for_status()
[record.get('file_name') for record in resp.json()]

#### Since memory is limited, let's select a **small** image to display. (The following assumes that the `HorseHead.fits` file was listed above. If not, substitute another filename below.)

In [None]:
# Resolve the fetch-by-filename endpoint, then download the FITS file (sending the
# bearer token) to a local path that download_file returns.
fetch_path = 'img/fetch_by_filename/' + 'HorseHead.fits'
imgURL = api.compose(fetch_path)
image_file = download_file(imgURL, http_headers=api.auth_header)

In [None]:
## Alternatively, one can download the bytes into memory and then read them directly:
# resp = api.get('img/fetch_by_filename/HorseHead.fits')
# if (resp.status_code == 200):
#     image_file = io.BytesIO(resp.content)
# else:
#     print(resp.json())  # print error message    

In [None]:
# Open the downloaded file as a FITS HDU list (read-only, astropy's default mode).
image_hdus = fits.open(name=image_file)

#### After downloading and opening the image, we can inspect its headers and properties.

In [None]:
# Print a one-line summary per HDU to stdout (astropy's default when output is None).
image_hdus.info(output=None)

In [None]:
# This notebook reads the pixel array from the primary HDU (index 0); show its dimensions.
primary_hdu = image_hdus[0]
image = primary_hdu.data
image.shape

#### Finally, we can plot the image.

In [None]:
# Explicit Figure/Axes API. origin='lower' puts pixel (0, 0) at the bottom-left,
# following the FITS convention and matching the cutout plot later in the notebook.
fig, ax = plt.subplots()
im = ax.imshow(image, origin='lower', cmap='magma')
fig.colorbar(im, ax=ax, label='Pixel value')
ax.set(title='HorseHead.fits', xlabel='x (pixel)', ylabel='y (pixel)');

## Query for, load, and plot a cutout from the Astrolabe image server

#### First, request a 10-arcsecond cutout from an image in the specified collection, taken with the specified filter, that contains the specified point (RA/Dec).

In [None]:
# Cutout query: sky position, cutout size in arcseconds, filter, and image collection.
params = {'ra': '53.1617', 'dec': '-27.78', 'sizeArcSec': '10', 'filter': 'F356W', 'collection': 'DC20'}
resp = api.get('cuts/fetch_cutout/', params=params)
# Stop on failure: continuing would leave co_hdus undefined (or stale from an
# earlier run) for the cells below. resp.text is used because an error body
# is not guaranteed to be JSON.
if resp.status_code != 200:
    raise RuntimeError(f"Cutout request failed (HTTP {resp.status_code}): {resp.text}")
co_hdus = fits.open(io.BytesIO(resp.content))

#### After downloading and opening the cutout, we can inspect its headers and properties.

In [None]:
# Print a one-line summary per HDU of the cutout to stdout (output=None is astropy's default).
co_hdus.info(output=None)

In [None]:
# Take the primary HDU, build a WCS from its header for sky-coordinate plotting,
# and keep its pixel array; show the array's dimensions.
co_hdu = co_hdus[0]
cutout_header = co_hdu.header
wcs = WCS(cutout_header)
cutout = co_hdu.data
cutout.shape

#### Finally, we can plot the image data.

In [None]:
# Map the full min..max range of the cutout through a square-root stretch.
# For an asinh stretch instead, pass stretch=AsinhStretch().
norm_cutout = ImageNormalize(data=cutout, interval=MinMaxInterval(), stretch=SqrtStretch())

In [None]:
# Plot on WCSAxes so the axes are labelled in sky coordinates rather than pixel indices.
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(111, projection=wcs)
im = ax.imshow(cutout, origin='lower', cmap='gray', norm=norm_cutout)
fig.colorbar(im, ax=ax)
ax.set_xlabel('RA')
ax.set_ylabel('Dec')
ax.set_title(f"{params['collection']} {params['filter']} cutout "
             f"({params['sizeArcSec']} arcsec, RA={params['ra']}, Dec={params['dec']})");