# Exercise sheet 5
### Due 02/12/2022, 14:00 
- Name 1
- Name 2
- Name 3

## Fun with files: bright astronomical transients
The *Bright Transient Survey* (BTS) is an astronomical survey carried out by the Zwicky Transient Facility (ZTF), a multi-telescope research infrastructure targeting transient astronomical phenomena. The BTS website provides a catalogue of very bright astronomical transients, such as supernovae, tidal disruption events and others.

### 1. Download the current version of the catalogue
- Open the [BTS explorer](https://sites.astro.caltech.edu/ztf/bts/explorer.php) webpage. Keep all the default options and select "Display as: CSV".
- Download the webpage as a CSV file and place it at `"./data/bts.csv"` where `.` is the "current working directory" of the notebook.

In [None]:
import os
from pathlib import Path

"""
Define the path to the catalogue relative to the current working directory of the notebook.
Remember that your code will be executed on another computer, where "absolute" paths (for example, containing your username or system) will not work!
"""
base_dir = Path(os.getcwd())
data_dir = base_dir / "data"
data_dir.mkdir(exist_ok=True)
filename = "bts.csv"
# The downloaded catalogue is expected at ./data/bts.csv
catalogue_path = data_dir / filename

catalogue_path.is_file()

### 2. Read the file
- Define a function to read the file making use of the [CSV library](https://docs.python.org/3/library/csv.html).
- The function should return a dictionary mapping a field name to a list, each list should contain the corresponding field values for all the entries in the file. In other words, `dataset["redshift"][i]` should give the "redshift" value of the i-th entry in the CSV file.
- Notice that the first line of the CSV file contains the field names. These are what we call *metadata*, information that *describe the data*. Use these names as keys of the dictionary.
- Make sure that numerical values are actually stored as numbers and not as strings. You can keep "RA" and "Dec" as strings since we do not know how to deal with dates, times and coordinates (yet).
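
The requested layout (field name mapped to the list of its values) can be sketched on a small in-memory sample before touching the real file. The toy rows below are made up and are not real BTS entries; the helper names `to_number()` and `read_columns()` and the reduced list of numeric fields are hypothetical choices for illustration. The conversion rules (`-` and empty strings become `None`, `>`/`<` limits are discarded) follow the ones adopted in the solution cell.

```python
import csv
import io

# Made-up toy content in the same spirit as the BTS CSV (NOT real catalogue entries).
TOY_CSV = """IAUID,type,peakmag,redshift
SN 2099aaa,SN Ia,18.2,0.05
SN 2099bbb,TDE,>19.0,-
SN 2099ccc,SN II,17.5,
"""

# Hypothetical reduced list of numeric fields for this toy sample.
NUMERIC_FIELDS = ["peakmag", "redshift"]


def to_number(val):
    """Convert a CSV string to float; missing values and upper/lower limits become None."""
    if val is None or val in ("-", "") or val.startswith((">", "<")):
        return None
    return float(val)


def read_columns(f):
    """Return a dict mapping each field name to the list of its values (column-oriented)."""
    reader = csv.DictReader(f)  # the first line provides the field names
    columns = {name: [] for name in reader.fieldnames}
    for row in reader:
        for name in reader.fieldnames:
            columns[name].append(row[name])
    for name in NUMERIC_FIELDS:
        columns[name] = [to_number(v) for v in columns[name]]
    return columns


toy = read_columns(io.StringIO(TOY_CSV))
print(toy["redshift"])  # [0.05, None, None]
print(toy["IAUID"][1])  # SN 2099bbb
```

With this layout, `toy["redshift"][i]` is the redshift of the i-th entry, exactly as required for `dataset`.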

In [None]:
import csv

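def read_bts(file_path):
    # file_path can be a string or a pathlib.Path; newline="" is what the csv documentation recommends
    with open(file_path, newline="") as f: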
        for row in f:
            if row.isspace():
                # this allows us to skip empty lines, if any
                continue
            else:
                # we derive the list of fields, and sanitize them with "strip" to remove \n where present
                fieldnames = [field.strip() for field in row.split(',')]
                break
        print(f"Parsed CSV with fields:  {fieldnames}")

        data = { field : list() for field in fieldnames}

        reader = csv.DictReader(f, fieldnames=fieldnames)
        for r in reader:
            for field in fieldnames:
                data[field].append(r[field])

        """
        Let's now convert the numeric fields to float.
        """

        numeric_fields = ['peakt', 'peakmag', 'peakabs', 'duration', 'rise', 'fade', 'redshift','b','A_V']

        """
        Some values cannot be converted directly; we address these cases one by one until the conversion no longer raises errors:
        - missing values are marked with '-' (or left empty) and converted to None (in the future, with numpy, this could be a NaN);
        - lower or upper limits ('>', '<') can be treated in different ways: either converted to `None`, or by taking the limit itself as the numeric value. Since upper/lower limits are not really "homogeneous" with actual measurements, discarding them is the most robust choice.
        """
        for field in numeric_fields:
            column = data[field]
            numeric_column = list()
            for val in column:
                if val in ["-", ""] or val.startswith(('>','<')):
                    num_val = None
                else:
                    num_val = float(val)
                numeric_column.append(num_val)
            data[field] = numeric_column
        
    return data

dataset = read_bts(catalogue_path)

### 3. Catalogue as a class
Define a class `Catalogue` that is meant to **store the catalogue** and provide a set of utility and analysis methods according to the following:
- the `__init__()` method of the class should take as an argument the dictionary as defined in point `#2`;
- provide the class with a `from_csv()` class method that takes as an argument a CSV file path and returns an instance of `Catalogue` representing the file content. You may recycle the code of your `read_bts()` function but make sure all the necessary code is contained in the class (in other words, do not call `read_bts()` from inside the class);
- provide the `Catalogue` class with a non-modifiable `origin` attribute (to be taken care of in `__init__()`). `origin` should tell whether the object has been created from a dictionary, another catalogue (see later `filter_by_types()`), or a file. In the case of a file, `origin` should keep track of the file path in the form of a string;
- provide the class with a `__repr__()` method. Remember: this method is meant to return a string representation of the object but in general not of its full content. Use the `__repr__()` method to return a string containing at least the size of the catalogue and the `origin` information;
- provide the class with a `__str__()` method. Remember: this method is meant to return a more descriptive representation than `__repr__()` and is used when the instance is passed to the `print()` function. The `__str__()` method should return a string with a table-like representation of the catalogue content, limited to the following fields: `IAUID`, `RA`, `Dec`, `peakabs`, `type`, `redshift`. Since it may not always be desirable to print a long catalogue, make `__str__()` depend on an instance attribute `print_max_nsources` that can be set for the `Catalogue` object;
- provide a method `filter_by_types()` that takes as an argument a list of strings and returns a new instance of `Catalogue` containing only the sources whose `type` field value is contained in the list; the syntax of the operation should be like `new_cat = full_cat.filter_by_types(["SN Ia"])`;
- provide methods `get_brightest_source()` and `get_closest_source()`, where the brightest source is the one with the lowest value of `peakmag` (astronomical magnitudes decrease as brightness increases) and the closest source is the one with the lowest value of `redshift`; the return value should contain the data of the source in the form of a dictionary (ideally, we would define a dedicated `Source` class in such circumstances, but this is not a requirement for this exercise);
- provide a method `__eq__(self, other)` that takes as an argument another instance of `Catalogue` and returns `True` if the **content** of the two catalogues is the same, regardless of the `origin`. Tip: remember you can compare list contents with `==` and that you can make use of the built-in functions `all()` and `any()` for booleans.
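
Two of the requirements above rely on standard Python mechanisms: a read-only attribute can be obtained with a `property` that has no setter, and content equality can be expressed with `all()` over the columns. A minimal sketch on a hypothetical, stripped-down `ToyCatalogue` class (not the full solution, and with made-up data):

```python
class ToyCatalogue:
    """Hypothetical stand-in for `Catalogue`, reduced to the two features shown here."""

    def __init__(self, input_dict, origin="runtime dictionary"):
        self.data = input_dict
        self._origin = origin  # "private" storage, assigned only once here

    @property
    def origin(self):
        # A property without a setter: reading works, assignment raises AttributeError.
        return self._origin

    def __eq__(self, other):
        if not isinstance(other, ToyCatalogue):
            return NotImplemented  # let Python fall back to its default comparison
        # Same field names, and every column has identical content; origin is ignored.
        return self.data.keys() == other.data.keys() and all(
            self.data[key] == other.data[key] for key in self.data
        )


a = ToyCatalogue({"IAUID": ["X"], "redshift": [0.1]})
b = ToyCatalogue({"IAUID": ["X"], "redshift": [0.1]}, origin="file = toy.csv")
print(a == b)  # True: same content, different origin

try:
    a.origin = "something else"
except AttributeError:
    print("origin is read-only")
```

Note that `_origin` is private only by convention: the property prevents accidental reassignment of `origin`, not deliberate tampering with `_origin`.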


In [None]:
import csv
import json

class Catalogue:
    # Set here a class-wide default value, then provide a method to configure it.
    default_print_max_nsources = 10

    # We may need to use a reference field name in the methods, so we set it here.
    id_field = "IAUID" 

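    def __init__(self, input_dict : dict, origin: str = "runtime dictionary" ):
        self.data = input_dict
        # Store the origin in a "private" attribute, exposed read-only through the `origin` property.
        self._origin = origin
        self.print_max_nsources = Catalogue.default_print_max_nsources

    @property
    def origin(self):
        """Where the catalogue comes from; there is no setter, so it cannot be reassigned."""
        return self._origin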

    def set_print_max_nsources(self, n):
        self.print_max_nsources = n

    def __eq__(self, other):
        """
        If `other` is a `Catalogue` we compare the data (for other objects, accessing `.data` could fail).
        Otherwise the comparison is simply False.
        """
        if isinstance(other, Catalogue):
            return self.data == other.data
        else:
            return False

    def __len__(self):
        return len(self.data[self.id_field])

    def __repr__(self):
        # Avoid naming the variable `repr`, which would shadow the built-in function.
        text = f"Catalogue with {len(self.data.keys())} fields holding {len(self)} objects."
        text += f"\nOrigin: {self.origin}"
        return text

    def __str__(self):
        """
        Build a string that looks like a table: a header line with the field names,
        followed by at most `print_max_nsources` lines of actual data.
        Only the fields required by the exercise are displayed.
        For simplicity, every value is converted to a string and padded to 8 characters (":8s" format string).
        It would be more elegant to define a different format string for each field and better control the display of float values!
        """
        shown_fields = ["IAUID", "RA", "Dec", "peakabs", "type", "redshift"]

        header_str = ""
        for fieldname in shown_fields:
            header_str += f"| {fieldname:8s} "
        header_str += "|\n"

        # Number of rows actually printed (also correct for short or empty catalogues).
        n_shown = min(len(self), self.print_max_nsources)
        table_str = ""
        for r in range(n_shown):
            for fieldname in shown_fields:
                table_str += f"| {str(self.data[fieldname][r]):8s} "
            table_str += "|\n"

        # Let's be nice and informative about the entries we did not print!
        hidden_entries = len(self) - n_shown
        if hidden_entries > 0:
            table_str += f"{hidden_entries} additional entries not shown."

        return header_str + table_str

    def filter_by_types(self, typelist):
        _data = self.data # just a shorthand notation
        filtered_dict = { key : list() for key in _data.keys() }
        for r, row in enumerate(_data[self.id_field]):
            if _data['type'][r] in typelist:
                for key in _data.keys():
                    filtered_dict[key].append(_data[key][r])
        origin = f"Filter by types {typelist} on Catalogue with origin:\n - {self.origin}"  
        return Catalogue(filtered_dict, origin=origin)

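    def get_closest_source(self):
        z = self.data['redshift']
        # Missing redshifts are stored as None, which cannot be compared with floats:
        # search for the minimum among the valid entries only.
        valid = [i for i, val in enumerate(z) if val is not None]
        i_min = min(valid, key=lambda i: z[i])
        return {key: self.data[key][i_min] for key in self.data.keys()}

    def get_brightest_source(self):
        m = self.data['peakmag']
        # Lower magnitudes mean brighter sources; None values (missing data) are skipped.
        valid = [i for i, val in enumerate(m) if val is not None]
        i_min = min(valid, key=lambda i: m[i])
        return {key: self.data[key][i_min] for key in self.data.keys()}

    # We can actually generalise this logic: `selection` is a function such as `min` or `max`,
    # applied to the valid (non-None) values of the chosen field.
    def select_source(self, fieldname, selection):
        d = self.data[fieldname]
        i = d.index(selection(val for val in d if val is not None))
        return {key: self.data[key][i] for key in self.data.keys()}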


    def to_json(self, json_path):
        """
        A Catalogue object cannot be dumped to JSON directly, but a dictionary
        wrapping its data and its `origin` can.
        """
        with open(json_path, "w") as f:
            json.dump({"origin": self.origin, "data": self.data}, f)

    @classmethod
    def from_json(cls, json_path):
        with open(json_path, "r") as f:
            json_content = json.load(f)
        # Keep track of both the JSON file and the origin stored inside it.
        json_origin = f"JSON file = {json_path}\n - stored origin: {json_content['origin']}"
        return cls(input_dict=json_content["data"], origin=json_origin)

    @classmethod
    def from_csv(cls, file_path):
        """ This code is the same as `read_bts` """
        with open(file_path, newline="") as f:
            for row in f:
                if row.isspace():
                    continue
                else:
                    fieldnames = [field.strip() for field in row.split(',')]
                    break

            data = { field : list() for field in fieldnames}

            reader = csv.DictReader(f, fieldnames=fieldnames)
            for r in reader:
                for field in fieldnames:
                    data[field].append(r[field])
                    
        numeric_fields = ['peakt', 'peakmag', 'peakabs', 'duration', 'rise', 'fade', 'redshift','b','A_V']

        for field in numeric_fields:
            column = data[field]
            numeric_column = list()
            for val in column:
                if val in ["-", ""] or val.startswith(('>','<')):
                    num_val = None
                else:
                    num_val = float(val)
                numeric_column.append(num_val)
            data[field] = numeric_column
        
        origin = f"file = {file_path}"

        return cls(input_dict=data, origin=origin)


In [None]:
"""
Test the two ways of building a Catalogue object!
"""
cat_a = Catalogue(input_dict=dataset)

cat_b = Catalogue.from_csv(catalogue_path)

print(cat_a.origin)

print(cat_b.origin)

print("cat_a == cat_b : ", cat_a == cat_b)

In [None]:
cat_c = Catalogue.from_csv(catalogue_path)
cat_c.data["ZTFID"][0] = "INVALID"

print("cat_a == cat_c : ", cat_a == cat_c)

In [None]:
# Test __repr__()
cat_a

In [None]:
# Test __str__()
print(cat_a)

In [None]:
cat_tde = cat_a.filter_by_types(["TDE"])

cat_tde

In [None]:
print(cat_tde)

In [None]:
print(cat_tde.get_brightest_source())

In [None]:
print(cat_tde.get_closest_source())
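
Missing redshifts and magnitudes are stored as `None`, and in Python 3 comparing `None` with a float raises a `TypeError`, so a plain `min()` over a column with gaps would fail. A minimal sketch of a `None`-aware selection on the column-oriented dictionary of point 2; the helper names `select_index()` and `source_at()` and the toy data are hypothetical:

```python
def select_index(values, selection=min):
    """Index of the value picked by `selection` (e.g. min or max), skipping None entries."""
    valid = [i for i, v in enumerate(values) if v is not None]
    if not valid:
        raise ValueError("no valid (non-None) values to select from")
    # Compare indices through the values they point to.
    return selection(valid, key=lambda i: values[i])


def source_at(data, i):
    """Collect the i-th entry of every column into a single dictionary."""
    return {key: column[i] for key, column in data.items()}


toy = {"IAUID": ["A", "B", "C"], "redshift": [None, 0.3, 0.02]}
print(source_at(toy, select_index(toy["redshift"])))  # {'IAUID': 'C', 'redshift': 0.02}
```

Working on indices rather than values keeps the link between a value and the rest of its row, which is what the `get_*_source()` methods need.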

In [None]:
# Check this alternative and more general implementation!
print(cat_tde.select_source('redshift', min))

### 4. CSV to JSON and vice versa
- Provide the class with a method `to_json()` that takes as an argument a file path and writes a JSON representation of the object content to a file. Store the `origin` information in the output file. Remember: while a generic object cannot be directly *dumped* to JSON, a dictionary wrapping the relevant data can!
- Provide the class with a class method `from_json()` that takes as an argument a file path and returns a `Catalogue` instance based on the file content.
- Test that the read/write functionality works by verifying the equivalence (you have defined an `__eq__` method for this)!
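
JSON only knows objects, arrays, strings, numbers, booleans and `null`, so the usual approach is to wrap the catalogue content and its `origin` in a plain dictionary. A minimal sketch of the round trip on made-up data; the `{"origin": ..., "data": ...}` layout is an assumption rather than a fixed format, and a temporary directory is used so nothing is left on disk:

```python
import json
import os
import tempfile

# Toy catalogue content (made up); None marks a missing value.
data = {"IAUID": ["A", "B"], "redshift": [0.05, None]}
origin = "file = data/bts.csv"

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.json")
    # Wrap data and metadata in a plain dictionary, which json can serialise.
    with open(path, "w") as f:
        json.dump({"origin": origin, "data": data}, f)
    with open(path) as f:
        content = json.load(f)

print(content["data"] == data)  # True: None -> null -> None, floats round-trip
print(content["origin"])        # file = data/bts.csv
```

Because `None` becomes `null` and back, the content read from JSON compares equal to the original, which is what makes the `__eq__` check in the test cell meaningful.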

In [None]:
""" TEST EXAMPLE """
json_path = data_dir / "bts.json" # define your path to a file here

csv_cat = Catalogue.from_csv(catalogue_path)
csv_cat.to_json(json_path)
json_cat = Catalogue.from_json(json_path)

print(json_cat == csv_cat) # here python should invoke your `__eq__` method.

In [None]:
json_cat