# Setup

In [1]:
from colorama import Fore, Style
import json

# Dataset download

In [None]:
from xami_dataset import XAMIDataset

# Download the dataset
# The resulting `xami_dataset` object is reused by the later cells
# (annotation paths, category/filter tables, plots).
# NOTE(review): `data_path` is presumably the local destination directory
# for the downloaded files — confirm against XAMIDataset.
xami_dataset = XAMIDataset(
    repo_id="iulia-elisa/XAMI-dataset", 
    dataset_name="xami_dataset", 
    data_path='./dest_dir')

# Dataset info

In [None]:
splits = ['train', 'valid']
annotation_files = [
    xami_dataset.train_annotations_path,
    xami_dataset.valid_annotations_path,
]

# Pair each split name with its annotations file so that every summary line
# states which split it describes (the original messages ended in a dangling
# "in:" and never used `splits`).
for split, annotations_file in zip(splits, annotation_files):
    data = xami_dataset.get_data_from_json(annotations_file)
    # Only the category names (second element of the returned triple) are shown.
    _, cat_names, _ = xami_dataset.get_categories(data)
    print(annotations_file)
    print(f"{Fore.BLUE}Class categories in {split}: {cat_names}{Style.RESET_ALL}")
    print(f"{Fore.BLUE}Number of annotations in {split}: {len(data['annotations'])}{Style.RESET_ALL}")
    print(f"{Fore.BLUE}Number of images in {split}: {len(data['images'])}{Style.RESET_ALL}")

In [None]:
# Build the category table for the dataset and render it in the notebook.
# NOTE(review): presumably a pandas DataFrame, judging by the `_df` name — confirm.
categories_df = xami_dataset.get_category_table()
display(categories_df) 

We can also see how many images per filter we have.

In [None]:
# One counter per single-letter filter code, all starting at zero.
# NOTE(review): presumably get_filters_table() fills in these counts from the
# dataset and returns them as a table — confirm against its implementation.
filters_count = {'U': 0, 'V': 0, 'B': 0, 'W': 0, 'S': 0, 'M': 0, 'L': 0}
filters_df = xami_dataset.get_filters_table(filters_count)
display(filters_df)

## Mask heatmap

In [None]:
# Generate the mask heatmap for the dataset.
# NOTE(review): `output_path` is presumably where the figure is written, and
# the ./plots directory may need to exist beforehand — confirm.
xami_dataset.generate_heatmap(output_path='./plots/artefact_distributions.png')

## Galactic coordinates plotting

In [None]:
# Plot the galactic-coordinate distribution for both dataset splits.
# NOTE(review): `obs_coords_file` presumably holds the per-observation
# coordinates and `output_path` the destination image — confirm.
xami_dataset.generate_galactic_distribution_plot(
    splits=['train', 'valid'], 
    obs_coords_file='./xami_utils/obs_info_1024_all.json', 
    output_path='dataset_galactic_distribution.png')

# (Optional) Split the dataset using the multilabel Stratified K-Fold CSV files

The CSV files were previously generated using the multilabel Stratified K-Fold (**mskf**) technique in order to balance the class distributions across dataset splits. If you wish to generate different versions of the dataset splits (e.g. with a different *k*, a different algorithm, etc.), you can do that in the `stratified_kfold.ipynb` notebook. We provide these splits as a baseline, since they are the ones used for our metrics.

Each CSV file has the following columns:

- `IMADE_ID` - an integer representing the image ID from the original JSON annotations file
- `IMAGE_PATH` - the corresponding image path
- `SPLIT` - either **train** or **valid**

In [None]:
# Toggle for the optional cell below: set to True to run the mskf CSV-based split step.
use_skf_splits = False

In [None]:
if use_skf_splits:  # Whether to use the 4 split folds.
    # Imported here so the optional dependencies are only needed when the
    # split step is actually enabled.
    import pandas as pd
    from xami_utils import utils

    csv_files = ['mskf_0.csv', 'mskf_1.csv', 'mskf_2.csv', 'mskf_3.csv']

    # NOTE(review): `data_in` is not defined in any cell shown above; it must
    # be set before enabling this cell, otherwise the call below raises
    # NameError. 'path/to/dest' also looks like a placeholder destination —
    # confirm both against utils.create_directories_and_copy_files.
    for idx, csv_file in enumerate(csv_files):
        # Each CSV describes one fold (see the column list above).
        mskf = pd.read_csv(csv_file)
        utils.create_directories_and_copy_files('path/to/dest', data_in, mskf, idx)