# WALLABY update data products

This notebook is intended to be used to update product files in the WALLABY database for accepted sources.

In [None]:
import os
import getpass
import requests
import getpass
import pyvo as vo
from pyvo.auth import authsession, securitymethods
from astropy.io.votable import from_table, parse_single_table
from astropy.table import vstack

In [None]:
# Enter WALLABY user username and password
# The password is prompted for interactively with getpass rather than being written into the notebook.

username = 'wallaby_user'
password = getpass.getpass('Enter your password')

In [None]:
# Open an authenticated session against the WALLABY TAP endpoint
# (HTTP basic auth using the username/password entered above)

URL = "https://wallaby.aussrc.org/tap"
auth = authsession.AuthSession()
auth.add_security_method_for_url(URL, securitymethods.BASIC)
auth.credentials.set_password(username, password)
tap = vo.dal.TAPService(URL, session=auth)

## 1. Select run

In [None]:
# Get all runs
# The full wallaby.run table is fetched so that the chosen run name can be validated below.

query = "SELECT * FROM wallaby.run"
votable = tap.search(query)
run_table = votable.to_table()

<span style="font-weight: bold; color: #FF0000;">⚠ Update the cell below with the Run that you would like to update products for</span>

In [None]:
# Select run
# `run_name` is also used later as the local directory that products are downloaded into.

run_name = 'SER_223-22'
# Fail early if the name is not present in the `name` column of the run table
assert run_name in run_table['name'], 'Run does not exist'

## 2. Get detections and products

In [None]:
# Retrieve catalog as Astropy table
# Builds the ADQL query for detections of the selected run that have a source name.
# The run name is substituted into the query text via the $RUN_NAME placeholder.

default_query = """SELECT * FROM wallaby.detection d 
        LEFT JOIN wallaby.run r ON d.run_id = r.id 
        WHERE d.source_name is not null AND r.name = '$RUN_NAME'"""
query = default_query.replace('$RUN_NAME', run_name)

In [None]:
# Run TAP query
# Executes the query built in the previous cell and converts the result to an Astropy table.

result = tap.search(query)
table = result.to_table()
# Bare expression so the notebook displays the table
table

In [None]:
# useful function for downloading table products (requires authentication)

def download_products(row, products_filename, chunk_size=8192):
    """Download the SoFiA-2 product archive for one detection row.

    The row's `access_url` is parsed as a VOTable listing the available
    products; the entry whose description is 'SoFiA-2 Detection Products'
    is then streamed to disk using the notebook-level `username` and
    `password` for HTTP basic authentication.

    The data is first written to `<products_filename>.part` and only renamed
    to `products_filename` once the transfer has finished, so an interrupted
    download never leaves a truncated archive under the final name.

    Args:
        row: Table row providing the 'source_name' and 'access_url' fields.
        products_filename: Destination path of the downloaded archive.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        IndexError: If the product listing has no SoFiA-2 products entry.
        requests.HTTPError: If the download request returns an error status.
    """
    name = row['source_name']
    access_url = row['access_url']
    votable = parse_single_table(access_url)
    product_table = votable.to_table()
    matches = product_table[product_table['description'] == 'SoFiA-2 Detection Products']
    if len(matches) == 0:
        # Same exception type as the bare `[0]` lookup, but with a usable message
        raise IndexError(f'No SoFiA-2 Detection Products entry found for {name}')
    url = matches[0]['access_url']

    partial_filename = f'{products_filename}.part'
    try:
        with requests.get(url, auth=(username, password), stream=True) as r:
            r.raise_for_status()
            with open(partial_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Publish under the final name only after the whole body was written
        os.replace(partial_filename, products_filename)
    except BaseException:
        # Do not leave a half-written file behind on any failure or interrupt
        if os.path.exists(partial_filename):
            os.remove(partial_filename)
        raise
    print(f'Downloaded completed for {name}')
    return

def download_table_products(table, directory, chunk_size=8192):
    """Download the product archive of every row in an ADQL-queried table.

    Each row is saved as `<directory>/<source_name>.tar`. The download is
    best-effort: a failure for one row is reported, including the underlying
    error, and the remaining rows are still processed.

    Args:
        table: Iterable of rows, each providing a 'source_name' field.
        directory: Output directory; created (with parents) if missing.
        chunk_size: Number of bytes requested per streamed chunk.
    """
    # makedirs also handles nested paths and an already-existing directory
    os.makedirs(directory, exist_ok=True)
    print(f'Saving products to {directory}')
    for row in table:
        # Pre-bind so the error message below is valid even when the
        # failure happens before the filename could be built
        products_filename = None
        try:
            name = row['source_name']
            products_filename = os.path.join(directory, f'{name}.tar')
            download_products(row, products_filename, chunk_size)
        except Exception as e:
            print(f'Error downloading {products_filename}: {e!r}')
            continue
    print('Downloads complete')
    return

In [None]:
# Download product files
# Archives are saved into a local directory named after the run.
# NOTE(review): `table[0:2]` limits the download to the first two detections of the run;
# pass `table` to fetch products for every detection - confirm whether the slice is intentional.

download_table_products(table[0:2], run_name)

## 3. Update products

In [None]:
import tarfile
import glob
from astropy.io import fits

<span style="font-weight: bold; color: #FF0000;">⚠ Update the cell below with how you would like to modify the fits files</span>

In [None]:
# Update this function with how you would like to modify the fits files

def update_fits(hdul):
    """Apply the desired modifications to an opened FITS file, in place.

    Currently sets the 'SBID' keyword in the header of the first HDU.

    Args:
        hdul: HDU list whose first element exposes a mutable `header`.
    """
    primary_hdu = hdul[0]
    primary_hdu.header['SBID'] = '50095 50583'

In [None]:
# Get all product files
# `product_tarfiles` are the downloaded archives of the run; `product_files` are the
# matching extraction directories (archive path without the .tar extension).

# Escape the run name so glob metacharacters in it are matched literally,
# and sort so progress output is reproducible between executions
product_tarfiles = sorted(glob.glob(os.path.join(glob.escape(run_name), '*.tar')))
# Strip only the trailing extension; str.replace would also remove a '.tar'
# occurring elsewhere in the path
product_files = [os.path.splitext(f)[0] for f in product_tarfiles]

In [None]:
# Update all product files

# Extract
# Each archive is unpacked into its corresponding entry of `product_files`, so the
# directories scanned in the update loop below are exactly the ones written here.
# The downloaded archives themselves are left in place.
for f, filename in zip(product_tarfiles, product_files):
    with tarfile.open(f) as tf:
        if hasattr(tarfile, 'data_filter'):
            # Extraction filters are available: refuse unsafe members
            # (e.g. paths escaping the target directory)
            tf.extractall(path=filename, filter='data')
        else:
            tf.extractall(path=filename)

# Update fits files
print('Updating fits files')
for idx_pf, pf in enumerate(product_files):
    print(f'Folder {pf} [{idx_pf + 1}/{len(product_files)}]')
    fits_files = glob.glob(os.path.join(pf, '*.fits'))
    for idx_ff, ff in enumerate(fits_files):
        print(f'[{idx_ff + 1}/{len(fits_files)}] {ff}')
        # Opened in update mode so the edits made by update_fits are written back to the file
        with fits.open(ff, mode='update') as hdul:
            update_fits(hdul)
            hdul.flush()

## 4. Re-upload to database

**NOTE:** This makes some important assumptions about the run name and source names. The structure that is expected is: `$CWD/<run_name>/<source_name>/<product_file>`. It will parse the filename to get the run name and source name to update.

You will need to update the code below with the path to the environment file that holds the database connection credentials.

In [None]:
import asyncio
import asyncpg
from dotenv import load_dotenv

<span style="font-weight: bold; color: #FF0000;">⚠ Update `database_env` in the cell below with the path to the database credentials environment file.</span>

In [None]:
# Database connection
# `database_env` is the path of the credentials file that is passed to `load_dotenv` in the next cell.

database_env = '/path/to/file'
# Stop here if the path has not been updated to an existing file
assert os.path.exists(database_env), 'Database credentials environment variable not provided'

In [None]:
# Load environment variables from the credentials file and collect the connection settings

load_dotenv(database_env)
creds = {
    'host': os.getenv('DATABASE_HOST'),
    'database': os.getenv('DATABASE_NAME'),
    'user': os.getenv('DATABASE_USER'),
    'password': os.getenv('DATABASE_PASSWORD'),
    # Environment values are strings; normalise so the port is always an int,
    # falling back to 5432 when the variable is unset or empty
    'port': int(os.getenv('DATABASE_PORT') or 5432),
}

In [None]:
# Establish and test connection
# Creates the connection pool used by the cells below (no DSN; settings come from `creds`)
# and runs a simple query against wallaby.run as a smoke test.

pool = await asyncpg.create_pool(None, **creds)
async with pool.acquire() as conn:
    res = await conn.fetch('SELECT * FROM wallaby.run')
assert res is not None, 'Connection did not work...'

In [None]:
# Fits file to bytes function

import os
import aiofiles

async def _get_file_bytes(path: str, mode: str = 'rb'):
    buffer = []
    if not os.path.isfile(path):
        return b''
    async with aiofiles.open(path, mode) as f:
        while True:
            buff = await f.read()
            if not buff:
                break
            buffer.append(buff)
        if 'b' in mode:
            return b''.join(buffer)
        else:
            return ''.join(buffer)

In [None]:
# Update fits files

update_query = "UPDATE wallaby.product SET $COLUMN = ($1) WHERE id=$2"

print(f'Re-uploading products to database for run {run_name}')
async with pool.acquire() as conn:
    async with conn.transaction():
        for idx_pf, pf in enumerate(product_files):
            source_name = pf.split('/')[1]
            print(f'Source {source_name} [{idx_pf + 1}/{len(product_files)}]')
            fits_files = glob.glob(os.path.join(pf, '*.fits'))
        
            # get product id
            get_product_id = """SELECT pr.id FROM wallaby.product pr
                                LEFT JOIN wallaby.detection d ON d.id = pr.detection_id
                                LEFT JOIN wallaby.run r ON d.run_id = r.id
                                WHERE (r.name = $1 AND d.source_name = $2)"""
            res = await conn.fetchrow(get_product_id, run_name, source_name)
            product_id = int(res['id'])
            
            for idx_ff, ff in enumerate(fits_files):
                print(f'[{idx_ff + 1}/{len(fits_files)}] Re-uploaded {ff}')
                
                # TODO: database update
                suffix = ff.rsplit('_', 1)[1].replace('.fits', '')
                p_bytes = await _get_file_bytes(ff)
                
                res = await conn.execute(update_query.replace('$COLUMN', suffix), p_bytes, product_id)

In [None]:
# Close database connection

pool.close()