In [1]:
import pandas as pd 
import numpy as np 
import ast
import shutil
import gradio as gr
import os  
import torch
import random
from transformers import T5Tokenizer
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Multi-label (tagging) predictions; the columns are renamed from the test-set header further down.
data_ml = pd.read_csv('data-ml/ML_predictions_class.csv')
# Test-set catalogue; its 'iauname' and 'png_loc' columns are used below.
test_set = pd.read_csv('data-ml/test.csv')
# Multi-task (tree) predictions; the cells are parsed later by process_column.
data_mtl = pd.read_csv('data-ml/tree_pred-2022-6-14.csv')

# Process Results

In [3]:
# Label the multi-label predictions with the test-set column names
# (everything after the first two columns).
columns = list(test_set.columns)[2:]
data_ml.columns = columns

# `newnames` is the IAU name with every '+' removed; the app looks objects up
# through this column.
iaunames = test_set['iauname'].values
newnames = [name.replace('+', '') for name in iaunames]

# `assign` returns a fresh frame, so the new column is never written into a
# slice of `test_set` (the pattern that raises pandas' SettingWithCopyWarning,
# which this notebook silences with warnings.filterwarnings('ignore')).
png_locs = test_set[['iauname', 'png_loc']].assign(newnames=newnames)

# Column-wise concat: the three identifier columns come first, the
# predictions follow. Rows are aligned on the index.
data_ml_new = pd.concat([png_locs, data_ml], axis=1)
data_mtl_new = pd.concat([png_locs, data_mtl], axis=1)
nobjects = data_mtl_new.shape[0]

In [23]:
png_locs.head()

Unnamed: 0,iauname,png_loc,newnames
0,J140750.54+151031.7,J140/J140750.54+151031.7.png,J140750.54151031.7
1,J135718.46+250352.5,J135/J135718.46+250352.5.png,J135718.46250352.5
2,J091543.01+300914.0,J091/J091543.01+300914.0.png,J091543.01300914.0
3,J081848.46+054220.7,J081/J081848.46+054220.7.png,J081848.46054220.7
4,J131316.58+093030.6,J131/J131316.58+093030.6.png,J131316.58093030.6


In [4]:
def process_column(dataframe):
    """Replace every prediction cell with the first element of its parsed literal.

    Columns from the fourth onwards are treated as task columns; the first
    three are left untouched. Each non-null cell is parsed with
    ``ast.literal_eval`` and replaced by element ``[0]`` of the result, while
    null cells become ``np.nan``.

    The input frame is not modified: the work is done on a copy, so a frame
    the caller still holds (or has already displayed) keeps its raw values.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Frame whose columns after the third hold Python-literal strings.

    Returns
    -------
    pandas.DataFrame
        A new frame with the same columns and index as ``dataframe``.

    Raises
    ------
    ValueError, SyntaxError
        Propagated from ``ast.literal_eval`` when a cell is not a valid literal.
    """
    def _first_item(cell_value):
        # Missing predictions stay missing.
        if pd.isnull(cell_value):
            return np.nan
        return ast.literal_eval(cell_value)[0]

    processed = dataframe.copy()
    for taskname in list(processed.columns)[3:]:
        # Pull the column's values once instead of once per row.
        raw_cells = dataframe[taskname].values
        processed[taskname] = [_first_item(cell) for cell in raw_cells]
    return processed

In [5]:
data_mtl_new = process_column(data_mtl_new)

In [6]:
# data_ml_new.to_csv('data-ml/processed_ml.csv')
# data_mtl_new.to_csv('data-ml/processed_mtl.csv')

# Gather Images

In [7]:
# Machine-specific root folder of the original galaxy images; in the visible
# notebook it is only referenced by the commented-out copy loop in the next cell.
src_main = '/home/arrykrishna/Desktop/GalaxyTest/'
# Destination folder of that copy loop.
dst = 'images/'

In [8]:
# for path in png_locs['png_loc'].values:
#     src = src_main + path 
#     shutil.copy(src, dst)

# App Components

In [9]:
# For every task, maps the classifier's answer label to the keyword phrase
# that is passed on to the text-generation model.
TASK_MAPPING = {
    'task_1': {
        'Smooth': 'smooth',
        'Featured or Disk': 'has features or disk',
        'Artifact': 'artifact',
    },
    'task_2': {
        'Round': 'round',
        'In Between': 'elliptical',
        'Cigar Shaped': 'cigar-shaped',
    },
    'task_3': {
        'Edge On Disk (Yes)': 'has an edge-on disk',
        'Edge On Disk (No)': 'does not have an edge-on disk',
    },
    'task_4': {
        'Merging (Merger)': 'merging',
        'Merging (Major Disturbance)': 'merging with major disturbance',
        'Merging (Minor Disturbance)': 'merging with minor disturbance',
        'Merging (None)': 'not merging',
    },
    'task_5': {
        'Bulge (Rounded)': 'rounded central bulge',
        'Bulge (Boxy)': 'boxy central bulge',
        'Bulge (None)': 'no central bulge',
    },
    'task_6': {
        'No Bar': 'no bar',
        'Weak Bar': 'weak bar',
        'Strong Bar': 'strong bar',
    },
    'task_7': {
        'Spiral Arms (Yes)': 'has spiral arms',
        'Spiral Arms (No)': 'does not have spiral arms',
    },
    'task_8': {
        'Spiral Winding (Tight)': 'tight spiral winding',
        'Spiral Winding (Medium)': 'medium spiral winding',
        'Spiral Winding (Loose)': 'loose spiral winding',
    },
    'task_9': {
        'Spiral Arms (1)': 'one spiral arm',
        'Spiral Arms (2)': 'two spiral arms',
        'Spiral Arms (3)': 'three spiral arms',
        'Spiral Arms (4)': 'four spiral arms',
        'Spiral Arms (More Than 4)': 'more than four spiral arms',
        'Spiral Arms (cannot tell)': 'no spiral arms',
    },
    'task_10': {
        'Central Bulge (None)': 'no central bulge',
        'Central Bulge (Small)': 'small central bulge',
        'Central Bulge (Moderate)': 'moderate central bulge',
        'Central Bulge (Large)': 'large central bulge',
        'Central Bulge (Dominant)': 'dominant central bulge',
    },
}

# Question text shown to the user for each task; the keys match TASK_MAPPING.
QUESTIONS = {
    'task_1': 'Is the galaxy simply smooth and rounded, with no sign of a disk?',
    'task_2': 'How rounded is it?',
    'task_3': 'Could this be a disk viewed edge-on?',
    'task_4': 'Is the galaxy merging or disturbed?',
    'task_5': 'Does the galaxy have a bulge at its centre? If so, what shape?',
    'task_6': 'Is there a bar feature through the centre of the galaxy?',
    'task_7': 'Is there any sign of a spiral arm pattern?',
    'task_8': 'How tightly wound do the spiral arms appear?',
    'task_9': 'How many spiral arms are there?',
    # Typo fixed: this read "Is so" instead of "If so".
    'task_10': 'Is there a central bulge? If so, how large is it compared with the galaxy?'
}


In [10]:
%%capture
# The T5 tokenizer is needed because it was used to train the model.
TOKENIZER = T5Tokenizer.from_pretrained('t5-base')
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on a machine without CUDA.
DEV = 'cuda' if torch.cuda.is_available() else 'cpu'

# The fine-tuned model.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints that come from a trusted source.
# map_location puts the weights on DEV even when the checkpoint was saved
# from a different device.
MODEL = torch.load('/home/arrykrishna/Documents/Oxford/Projects/Galaxy-Zoo/models/t5-model-finetuned-50-epochs.bin',
                   map_location=DEV)
MODEL.eval()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [11]:
def prompt_generate(keywords):
    """Generate a sentence from a keyword string with the fine-tuned model.

    The keywords (with an explicit ``</s>`` appended) are tokenized, moved to
    ``DEV`` and passed to ``MODEL.generate``. Generation uses sampling
    (``do_sample=True``), so repeated calls with the same keywords can return
    different sentences.

    Parameters
    ----------
    keywords : str
        Keyword phrases, as assembled by the ``tree_*`` helpers.

    Returns
    -------
    str
        The decoded text of the first generated sequence, without special
        tokens and without surrounding whitespace.
    """
    input_ids = TOKENIZER.encode(keywords + "</s>", max_length=512, truncation=True,
                                 return_tensors="pt")
    input_ids = input_ids.to(DEV)
    outputs = MODEL.generate(input_ids, do_sample=True, max_length=1024)
    # Let the tokenizer drop special tokens such as <pad> and </s>. Slicing
    # fixed character offsets ([6:-4]) only works when the output is wrapped
    # in exactly "<pad> ... </s>" and cuts real text when generation stops at
    # max_length without an end-of-sequence token.
    output_text = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
    return output_text.strip()

In [12]:
def order_labels(mtl_labels):
    """Pick the decision-tree path that matches the answered tasks.

    The answered tasks are the index entries of ``mtl_labels`` whose value is
    not null. A path whose task set is exactly the answered set wins. This
    matters for a row where only ``task_1`` is answered: every path contains
    ``task_1``, so a plain "answered tasks are contained in the path" test
    would pick ``tree_1`` and hand back null labels, and ``tree_5`` could
    never be selected.

    If no path matches exactly, the first path that contains every answered
    task is used, as before.

    Parameters
    ----------
    mtl_labels : pandas.Series
        Answers indexed by task name (``'task_1'`` ... ``'task_10'``), with
        null values for unanswered tasks.

    Returns
    -------
    tuple
        ``(tree_name, labels)`` where ``labels`` is ``mtl_labels`` restricted
        to, and ordered by, the tasks of the chosen path.

    Raises
    ------
    ValueError
        If the answered tasks fit none of the known paths.
    """
    trees = {
        'tree_1': ['task_1', 'task_2', 'task_4'],
        'tree_2': ['task_1', 'task_3', 'task_5', 'task_4'],
        'tree_3': ['task_1', 'task_3', 'task_6', 'task_7', 'task_10', 'task_4'],
        'tree_4': ['task_1', 'task_3', 'task_6', 'task_7', 'task_8', 'task_9', 'task_10', 'task_4'],
        'tree_5': ['task_1'],
    }

    answered = set(mtl_labels.index[~pd.isnull(mtl_labels)])

    # First choice: the path whose tasks are exactly the answered ones.
    for treename, tree in trees.items():
        if answered == set(tree):
            return treename, mtl_labels[tree]

    # Fallback: the first path that at least contains every answered task.
    for treename, tree in trees.items():
        if answered.issubset(tree):
            return treename, mtl_labels[tree]

    raise ValueError(f'answered tasks {sorted(answered)} do not match any known tree')

In [13]:


def tree_1(label_1, label_2, label_3):
    """Generate a sentence for the task_1 -> task_2 -> task_4 path.

    Each answer label is translated to its keyword phrase through
    TASK_MAPPING; the phrases are joined with ' | ' and passed to
    prompt_generate, whose result is returned.
    """
    path = ('task_1', 'task_2', 'task_4')
    answers = (label_1, label_2, label_3)
    phrases = [TASK_MAPPING[task][answer] for task, answer in zip(path, answers)]
    return prompt_generate(" | ".join(phrases))

def tree_2(label_1, label_2, label_3, label_4):
    """Generate a sentence for the task_1 -> task_3 -> task_5 -> task_4 path.

    Each answer label is translated to its keyword phrase through
    TASK_MAPPING; the phrases are joined with ' | ' and passed to
    prompt_generate, whose result is returned.
    """
    path = ('task_1', 'task_3', 'task_5', 'task_4')
    answers = (label_1, label_2, label_3, label_4)
    phrases = [TASK_MAPPING[task][answer] for task, answer in zip(path, answers)]
    return prompt_generate(" | ".join(phrases))

def tree_3(label_1, label_2, label_3, label_4, label_5, label_6):
    """Generate a sentence for the six-task path ending in task_10 and task_4.

    The labels are, in order, the answers to task_1, task_3, task_6, task_7,
    task_10 and task_4. Each one is translated to its keyword phrase through
    TASK_MAPPING; the phrases are joined with ' | ' and passed to
    prompt_generate, whose result is returned.
    """
    path = ('task_1', 'task_3', 'task_6', 'task_7', 'task_10', 'task_4')
    answers = (label_1, label_2, label_3, label_4, label_5, label_6)
    phrases = [TASK_MAPPING[task][answer] for task, answer in zip(path, answers)]
    return prompt_generate(" | ".join(phrases))


def tree_4(label_1, label_2, label_3, label_4, label_5, label_6, label_7, label_8):
    """Generate a sentence for the full eight-task path.

    The labels are, in order, the answers to task_1, task_3, task_6, task_7,
    task_8, task_9, task_10 and task_4. Each one is translated to its keyword
    phrase through TASK_MAPPING; the phrases are joined with ' | ' and passed
    to prompt_generate, whose result is returned.
    """
    path = ('task_1', 'task_3', 'task_6', 'task_7',
            'task_8', 'task_9', 'task_10', 'task_4')
    answers = (label_1, label_2, label_3, label_4,
               label_5, label_6, label_7, label_8)
    phrases = [TASK_MAPPING[task][answer] for task, answer in zip(path, answers)]
    return prompt_generate(" | ".join(phrases))

def tree_5(label_1):
    """Generate a sentence from the task_1 answer alone.

    The label is translated to its keyword phrase through TASK_MAPPING and
    that phrase is passed to prompt_generate, whose result is returned.
    """
    phrase = TASK_MAPPING['task_1'][label_1]
    return prompt_generate(phrase)

In [14]:
def generate_sentence(treename, labels):
    """Generate the descriptive sentence for one classified object.

    Parameters
    ----------
    treename : str
        One of ``'tree_1'`` ... ``'tree_5'``, as returned by ``order_labels``.
    labels : pandas.Series
        Answers for the tasks of that tree, in the tree's task order. Its
        values are unpacked as positional arguments of the matching
        ``tree_*`` function.

    Returns
    -------
    str
        Whatever the selected ``tree_*`` function returns.

    Raises
    ------
    ValueError
        If ``treename`` is not a known tree. Without this check the function
        would fail with an UnboundLocalError that hides the real cause.
    """
    if treename == 'tree_1':
        builder = tree_1
    elif treename == 'tree_2':
        builder = tree_2
    elif treename == 'tree_3':
        builder = tree_3
    elif treename == 'tree_4':
        builder = tree_4
    elif treename == 'tree_5':
        builder = tree_5
    else:
        raise ValueError(f'unknown tree name: {treename!r}')

    return builder(*labels.values)

def generate_question_answer(mtl_labels):
    """Pair each task's question text with the answer stored for that task.

    The index of ``mtl_labels`` holds task names; each one is looked up in
    QUESTIONS and used as the key, with the corresponding value of
    ``mtl_labels`` as the answer. The order of the series is kept.
    """
    return {
        QUESTIONS[task]: answer
        for task, answer in zip(mtl_labels.index, mtl_labels.values)
    }

def generate_tags(ml_labels):
    """Return the index entries of ``ml_labels`` whose value equals 1, in order."""
    is_tagged = ml_labels.eq(1)
    tag_names = ml_labels.index[is_tagged]
    return list(tag_names)

# Components

In [15]:
# Row used to try out the helper functions in the cells below.
index = 3
# Skip the three identifier columns (iauname, png_loc, newnames) that were
# concatenated in front of the predictions, keeping only the prediction columns.
mtl_test = data_mtl_new.iloc[index][3:]
ml_test = data_ml_new.iloc[index][3:]

In [16]:
generate_tags(ml_test)

['Featured or Disk',
 'Cigar Shaped',
 'Edge On Disk (No)',
 'Merging (None)',
 'Bulge (None)',
 'Spiral Arms (Yes)',
 'Central Bulge (Small)']

In [17]:
new_tree, new_test = order_labels(mtl_test)

In [18]:
generate_question_answer(new_test)

{'Is the galaxy simply smooth and rounded, with no sign of a disk?': 'Featured or Disk',
 'Could this be a disk viewed edge-on?': 'Edge On Disk (No)',
 'Is there a bar feature through the centre of the galaxy?': 'No Bar',
 'Is there any sign of a spiral arm pattern?': 'Spiral Arms (Yes)',
 'How tightly wound do the spiral arms appear?': 'Spiral Winding (Tight)',
 'How many spiral arms are there?': 'Spiral Arms (cannot tell)',
 'Is there a central bulge? Is so, how large is it compared with the galaxy?': 'Central Bulge (Small)',
 'Is the galaxy merging or disturbed?': 'Merging (None)'}

In [19]:
generate_sentence(new_tree, new_test)

'This galaxy image does not have an edge-on disk, a bar, or spiral arms, but it does have a small central bulge and tight spiral winding.'

# Application

In [21]:
def get_images(nimages=4, img_path='images/'):
    """Pick random images from a folder for the gallery.

    The sample is drawn from the files that are actually present in
    ``img_path``. Sampling indices from the number of catalogue rows instead
    fails with an IndexError when the folder holds fewer files than that, and
    never reaches the extra files when it holds more.

    Parameters
    ----------
    nimages : int, optional
        Number of images to return. If the folder holds fewer files, all of
        them are returned.
    img_path : str, optional
        Folder to pick the images from.

    Returns
    -------
    list of tuple
        ``(path, caption)`` pairs, with captions ``'Image 0'``, ``'Image 1'``, ...

    Raises
    ------
    OSError
        Propagated from ``os.listdir`` when ``img_path`` cannot be listed.
    """
    img_names = os.listdir(img_path)
    n_chosen = min(nimages, len(img_names))
    chosen_images = random.sample(img_names, n_chosen)
    return [(os.path.join(img_path, name), f"Image {i}")
            for i, name in enumerate(chosen_images)]

def get_select_index(evt: gr.SelectData):
    """Return the ``index`` attribute of the selection event ``evt``.

    Registered on ``gallery.select`` in the app cell, with the
    'Image Selected' number box as its output.
    """
    return evt.index

def machine_learning(number, gallery):
    """Run the three outputs (tags, question/answer list, sentence) for one image.

    Parameters
    ----------
    number : int or float
        Position of the selected image in ``gallery``.
    gallery : list
        Gallery value; element ``[number][0]`` is the path of the selected image.

    Returns
    -------
    tuple of str
        ``(tags_out, mtl_ques_ans, sentence)``: one tag per line, the numbered
        questions with their answers, and the generated sentence.

    Raises
    ------
    ValueError
        If no image is selected, or if the selected image's name is not found
        in the prediction tables.
    """
    if number is None or not gallery:
        raise ValueError('load the images and select one before running the models')

    # The number box may deliver a float; a list index has to be an int.
    selected_path = gallery[int(number)][0]
    # File name without its extension. '+' is removed as well because the
    # lookup column `newnames` is built from the IAU names without '+'.
    obj_name = os.path.splitext(os.path.basename(selected_path))[0].replace('+', '')

    test_obj = data_ml_new[data_ml_new['newnames'] == obj_name]
    test_obj_mtl = data_mtl_new[data_mtl_new['newnames'] == obj_name]
    if test_obj.empty or test_obj_mtl.empty:
        raise ValueError(f'no predictions found for object {obj_name!r}')

    # multilabel case (taggings); the first three columns are identifiers
    tags = generate_tags(test_obj.iloc[0][3:])
    tags_out = ''.join(f'{t}\n' for t in tags)

    # multitask case: classify the SELECTED object's row, not the global
    # `mtl_test` sample row defined earlier in the notebook
    new_tree, new_test = order_labels(test_obj_mtl.iloc[0][3:])
    mtl_out = generate_question_answer(new_test)
    mtl_ques_ans = ''.join(
        f'{i+1}) {question}\n{answer}\n\n'
        for i, (question, answer) in enumerate(mtl_out.items())
    )

    # generate sentence
    sentence = generate_sentence(new_tree, new_test)

    return tags_out, mtl_ques_ans, sentence

In [24]:
# <img src="https://raw.githubusercontent.com/Harry45/azml/main/projects/GalaxyGenius/paper-images/tree.png" width="400"/>

In [42]:
with gr.Blocks() as demo:
    gr.HTML("""
    
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>GalaxyGenius</title>
    <style>
        .container {
            display: flex;
            justify-content: space-between;
            align-items: flex-start;
            padding: 20px;
        }
        .column {
            flex: 1;
            padding: 10px;
        }
        .column img {
            max-width: 100%;
            height: auto;
        }
        p {
            text-align: justify;
        }
    </style>
</head>
<body>

<div class="container">
    <div class="column">
        <h2>GalaxyGenius</h2>
        
        <h3>Multi-label Learning for Galaxy Image Tagging</h3>
        <p>Our innovative application utilizes multi-label learning techniques to assign relevant tags to galaxy images. By analyzing various features and characteristics within each image, the system can accurately identify and label multiple attributes such as galaxy type, shape, and size. This approach enhances the efficiency of cataloging vast collections of astronomical data and provides valuable insights into the diverse properties of galaxies across the universe.</p>

        <h3>Multi-Task Learning for Hierarchical Classification of Galaxies</h3>
        <p>Our cutting-edge system employs multi-task learning methodologies to navigate through hierarchical taxonomies and classify galaxies. By simultaneously training on multiple related tasks, the model learns to identify the optimal path down the hierarchical tree structure, enabling precise categorization of galaxies based on their intrinsic properties. This approach streamlines the classification process and enhances the accuracy and granularity of galaxy classification, facilitating comprehensive studies of cosmic structures and evolution.</p>

        <h3>NLP-Powered Object Description for Astronomical Images</h3>
        <p>This tool leverages advanced Natural Language Processing (NLP) techniques and the
        Multi-Task Learning (MTL) method developed above, to analyze word descriptions across various tasks simultaneously, 
        generating prompts tailored to individual needs. Whether you're a writer seeking inspiration, 
        a student preparing for presentations, or a professional brainstorming ideas, this application 
        adapts seamlessly to provide relevant and engaging prompts. With its user-friendly interface and 
        sophisticated algorithms, this application empowers users to unlock their full creative potential 
        effortlessly. Experience the future of prompt generation and redefine the way you approach your 
        creative endeavors with this powerful tool.</p>


    </div>
    <div class="column">
        <img src="https://raw.githubusercontent.com/Harry45/azml/main/projects/GalaxyGenius/paper-images/tree.png" width="420" alt="Application Image">
    </div>
</div>

</body>
</html>
    """) 

    gallery = gr.Gallery(label="Your Images", columns=[2], rows=[2], min_width=200, interactive=False, height = 700)
    btn = gr.Button("Get Your Images", scale=0)
    btn.click(get_images, None, gallery)
    # precision=0 makes the number box hand over an int; machine_learning uses
    # the value as a list index, which a float cannot be.
    selected = gr.Number(label = 'Image Selected', show_label=True, precision=0)
    # Clicking a gallery image writes its position into the number box.
    index_selected = gallery.select(get_select_index, None, selected)
    gr.Interface(machine_learning, [selected, gallery], [gr.Textbox(label = 'Tags'), 
                                                         gr.Textbox(label = 'Classifier'),
                                                         gr.Textbox(label = 'NLP Model')])
    
if __name__ == "__main__":
    demo.launch()

Running on local URL:  http://127.0.0.1:7878

To create a public link, set `share=True` in `launch()`.
