In [None]:
#> Imports
import functools

import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact

In [None]:
#> How to read in the augmentations.txt file.
def read_in_augmentations_file(dataset):
    
    augmented_images = []
    
    # In order to balance the morphological distribution in the final datasets, the number of augmentations needed per morph class differs.
    # Hence the need to read in the jaggered data like below.    
    with open(dataset + "/augmentations.txt", "rb") as f:
        
        while True:
            try: augmented_images.append([np.load(f, allow_pickle = True) for _ in range(3 if "_V" in dataset else 1)])
            except EOFError: break
            
    return augmented_images

In [None]:
#> Handle plot creation depending on the dataset.
def handle_combined_image_plots(axN, source_id, number_of_augs, augmented_images):
    """Plot every augmentation of one source whose surveys are merged into colour channels.

    Draws a single row of ``number_of_augs`` panels. Panel ``i`` shows
    ``augmented_images[source_id][0][i]``. The figure title states the channel
    mapping: red = FIRST, green = LoTSS, blue = NVSS.

    Parameters
    ----------
    axN : object
        Ignored. It is kept only for call-site compatibility, because a new
        figure is always created.
    source_id : int
        Index into ``augmented_images`` selecting the source to plot.
    number_of_augs : int
        Number of augmentation panels to draw. Must be >= 1.
    augmented_images : list
        Per-source records as returned by ``read_in_augmentations_file``.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    # squeeze=False keeps `ax` 2-D even when number_of_augs == 1. Otherwise
    # plt.subplots returns a bare Axes and indexing it raises TypeError.
    fig, ax = plt.subplots(1, number_of_augs, figsize=(5 * number_of_augs, 5), squeeze=False)
    fig.suptitle(f"Augmented Images for Source ID: {source_id} (red = FIRST, green = LoTSS, blue = NVSS)", fontsize=16)

    for i in range(number_of_augs):
        ax[0, i].imshow(augmented_images[source_id][0][i])
        ax[0, i].set_title(f"Augmentation: {i}")

    return fig
       
        
def handle_seperate_image_plots(axN, source_id, number_of_augs, augmented_images):
    """Plot every augmentation of one source stored as three separate survey images.

    Draws a 3 x ``number_of_augs`` grid. Row ``j`` is survey ``j``
    (FIRST, LoTSS, NVSS) and column ``i`` is augmentation ``i``, taken from
    ``augmented_images[source_id][j][i]``.

    Parameters
    ----------
    axN : object
        Ignored. It is kept only for call-site compatibility, because a new
        figure is always created.
    source_id : int
        Index into ``augmented_images`` selecting the source to plot.
    number_of_augs : int
        Number of augmentation columns to draw. Must be >= 1.
    augmented_images : list
        Per-source records as returned by ``read_in_augmentations_file``, with
        one entry per survey.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    surveys = ["FIRST", "LoTSS", "NVSS"]

    # squeeze=False keeps `ax` 2-D when number_of_augs == 1. Otherwise
    # plt.subplots returns a 1-D array of 3 Axes and `ax[j, i]` raises IndexError.
    fig, ax = plt.subplots(len(surveys), number_of_augs,
                           figsize=(5 * number_of_augs, 5 * len(surveys)), squeeze=False)
    fig.suptitle(f"Augmented Images for Source ID: {source_id}", fontsize=16)

    for i in range(number_of_augs):
        for j, survey in enumerate(surveys):
            ax[j, i].imshow(augmented_images[source_id][j][i])
            ax[j, i].set_title(f"Augmentation: {i}, Survey: {survey}")

    return fig


def handle_visibility_plots(ax, source_id, number_of_augs, augmented_images):
    """Plot amplitude and phase visibilities for every augmentation of one source.

    Each augmented sample ``augmented_images[source_id][0][i]`` is sliced as
    ``[:, :, channel]``. Channels 0-2 are drawn in rows 0-2 as the AMPLITUDE
    panels for FIRST, LoTSS and NVSS. Channels 3-5 are drawn in rows 3-5 as the
    matching PHASE panels. The result is a 6 x ``number_of_augs`` grid with
    baselines on the x-axis and timesteps on the y-axis.

    Parameters
    ----------
    ax : object
        Ignored. It is kept only for call-site compatibility, because a new
        figure is always created.
    source_id : int
        Index into ``augmented_images`` selecting the source to plot.
    number_of_augs : int
        Number of augmentation columns to draw. Must be >= 1.
    augmented_images : list
        Per-source records as returned by ``read_in_augmentations_file``.
        Presumably each sample has at least 6 trailing channels (TODO confirm).

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    surveys = ["FIRST", "LoTSS", "NVSS"]

    # squeeze=False keeps the axes grid 2-D when number_of_augs == 1. Otherwise
    # plt.subplots returns a 1-D array and `axes[row, i]` raises IndexError.
    fig, axes = plt.subplots(6, number_of_augs, figsize=(6 * number_of_augs, 5 * 6), squeeze=False)
    fig.suptitle(f"Visibilities of Augmented Images for Source ID: {source_id}", fontsize=16)

    for i in range(number_of_augs):
        sample = augmented_images[source_id][0][i]
        for j, survey in enumerate(surveys):
            # Amplitude for survey j lives in channel j; its phase in channel j + 3.
            for row, label in ((j, "AMPLITUDE"), (j + 3, "PHASE")):
                panel = axes[row, i]
                panel.imshow(sample[:, :, row], aspect="auto")
                panel.set_xlabel("Baselines")
                panel.set_ylabel("Timesteps")
                panel.set_title(f"Augmentation: {i}, Survey: {survey}, {label}")

    return fig

In [None]:
#> How to use and display the interactive plots.
@functools.lru_cache(maxsize=1)
def _load_augmentations_cached(dataset):
    """Return ``read_in_augmentations_file(dataset)``, memoising only the latest dataset.

    ``interact`` calls the viewer again on every widget change. Caching the
    current dataset lets the user step through source IDs without re-reading and
    re-unpickling the whole file each time. Holding a single dataset keeps memory
    bounded when switching datasets. Call
    ``_load_augmentations_cached.cache_clear()`` after regenerating a file on disk.
    """
    return read_in_augmentations_file(dataset)


def show_augmented_images(dataset, source_id):
    """Print summary info for one source and plot all of its augmentations.

    The plot layout is chosen from the dataset name:
    ``"RADVIS"`` gives visibility amplitude/phase grids, ``"RADCAT_V"`` gives
    separate per-survey images, and anything else (RADCAT_F / RADCAT_Fc) gives a
    combined colour image.

    Parameters
    ----------
    dataset : str
        Directory holding ``augmentations.txt``.
    source_id : int
        Position of the source's record within ``augmentations.txt``.

    Raises
    ------
    IndexError
        If ``source_id`` is not a valid record position for ``dataset``.
    """
    augmented_images = _load_augmentations_cached(dataset)

    print(len(augmented_images), "\t\t-> Number of source IDs (Total unique val-train sources pre-augmentations).")

    # NOTE(review): source_id is used directly as a list index into the records
    # read from augmentations.txt. This assumes the IDs offered by the widget are
    # 0-based record positions; confirm against how augmentations.txt is written.
    if not 0 <= source_id < len(augmented_images):
        raise IndexError(
            f"source_id {source_id} is out of range for {dataset}: "
            f"augmentations.txt holds {len(augmented_images)} source records"
        )

    print(len(augmented_images[source_id]), "\t\t-> Either '1' for single survey-combined-in-channels image or '3' for seperate monochrome-survey images.")
    print(augmented_images[source_id][0].shape, "-> Shape of a single source (number of augmented images, pixels, pixels, colour channels).")

    # The leading axis of the first array is the number of augmentations.
    number_of_augs = augmented_images[source_id][0].shape[0]

    if "RADVIS" in dataset:
        handle_visibility_plots(None, source_id, number_of_augs, augmented_images)
    elif "RADCAT_V" in dataset:
        handle_seperate_image_plots(None, source_id, number_of_augs, augmented_images)
    else:  # RADCAT_F / RADCAT_Fc
        handle_combined_image_plots(None, source_id, number_of_augs, augmented_images)

    plt.show()
    
# Interactive viewer: pick a dataset and a source ID from the dropdowns.
# ndmin=1 keeps a single-ID file as a 1-element array. Without it np.loadtxt
# returns a 0-d array, which sorted() cannot iterate. .tolist() hands the widget
# plain Python ints.
interact(show_augmented_images,
         dataset = ["RADCAT_F", "RADCAT_Fc", "RADCAT_V", "RADCAT_Vc", "RADVIS", "RADVISc"],
         source_id = sorted(np.loadtxt("source_ids_val_train.txt", dtype=int, ndmin=1).tolist()))