In [1]:
import os
import warnings

import gradio as gr
import torch
from transformers import T5Tokenizer
warnings.filterwarnings('ignore')

In [4]:
%%capture
# The t5-base tokenizer is needed because it was used to fine-tune the model.
TOKENIZER = T5Tokenizer.from_pretrained('t5-base')
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEV = 'cuda' if torch.cuda.is_available() else 'cpu'

# The fine-tuned model file; override with the GALAXY_T5_MODEL environment
# variable instead of editing this absolute, machine-specific default.
MODEL_PATH = os.environ.get(
    'GALAXY_T5_MODEL',
    '/home/arrykrishna/Documents/Oxford/Projects/Galaxy-Zoo/models/t5-model-finetuned-50-epochs.bin',
)

# torch.load unpickles the whole model object (it is used directly via
# .eval() / .generate()), so only load files from a trusted source.
# map_location remaps tensors saved on a GPU onto DEV, so loading also works
# on a CPU-only machine.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which refuses
# full pickled modules; pass weights_only=False there if loading fails.
MODEL = torch.load(MODEL_PATH, map_location=DEV)
MODEL = MODEL.to(DEV)
MODEL.eval()

In [5]:
def prompt_generate(keywords):
    """Generate a free-text galaxy description from a keyword prompt.

    Parameters
    ----------
    keywords : str
        ' | '-separated keywords built from the decision-tree answers
        (see the tree_* callbacks).

    Returns
    -------
    str
        The decoded model output with special tokens (``<pad>``, ``</s>``)
        removed.
    """
    # The explicit "</s>" suffix is kept unchanged to match the prompt format
    # the model was fine-tuned with (presumably — TODO confirm against the
    # training code).
    input_ids = TOKENIZER.encode(keywords + "</s>", max_length=512, truncation=True,
                                 return_tensors="pt")
    input_ids = input_ids.to(DEV)
    # do_sample=True: the same keywords can yield a different description on each call.
    outputs = MODEL.generate(input_ids, do_sample=True, max_length=1024)
    # Let the tokenizer drop special tokens instead of slicing fixed character
    # offsets: a fixed [6:-4] slice cuts real text whenever generation stops
    # at max_length without emitting </s>.
    return TOKENIZER.decode(outputs[0], skip_special_tokens=True).strip()

In [6]:
%%capture

# For each decision-tree task: UI radio label -> keyword fed to the model.
# Entry order matters: it is the order the choices are listed in LABELS.
# NOTE(review): 'Spiral Arms (cannot tell)' maps to 'no spiral arms';
# presumably this mirrors the keywords used in training — confirm.
TASK_MAPPING = {
    'task_1': {
        'Smooth': 'smooth',
        'Featured or Disk': 'has features or disk',
        'Artifact': 'artifact',
    },
    'task_2': {
        'Round': 'round',
        'In Between': 'elliptical',
        'Cigar Shaped': 'cigar-shaped',
    },
    'task_3': {
        'Edge On Disk (Yes)': 'has an edge-on disk',
        'Edge On Disk (No)': 'does not have an edge-on disk',
    },
    'task_4': {
        'Merging (Merger)': 'merging',
        'Merging (Major Disturbance)': 'merging with major disturbance',
        'Merging (Minor Disturbance)': 'merging with minor disturbance',
        'Merging (None)': 'not merging',
    },
    'task_5': {
        'Bulge (Rounded)': 'rounded central bulge',
        'Bulge (Boxy)': 'boxy central bulge',
        'Bulge (None)': 'no central bulge',
    },
    'task_6': {
        'No Bar': 'no bar',
        'Weak Bar': 'weak bar',
        'Strong Bar': 'strong bar',
    },
    'task_7': {
        'Spiral Arms (Yes)': 'has spiral arms',
        'Spiral Arms (No)': 'does not have spiral arms',
    },
    'task_8': {
        'Spiral Winding (Tight)': 'tight spiral winding',
        'Spiral Winding (Medium)': 'medium spiral winding',
        'Spiral Winding (Loose)': 'loose spiral winding',
    },
    'task_9': {
        'Spiral Arms (1)': 'one spiral arm',
        'Spiral Arms (2)': 'two spiral arms',
        'Spiral Arms (3)': 'three spiral arms',
        'Spiral Arms (4)': 'four spiral arms',
        'Spiral Arms (More Than 4)': 'more than four spiral arms',
        'Spiral Arms (cannot tell)': 'no spiral arms',
    },
    'task_10': {
        'Central Bulge (None)': 'no central bulge',
        'Central Bulge (Small)': 'small central bulge',
        'Central Bulge (Moderate)': 'moderate central bulge',
        'Central Bulge (Large)': 'large central bulge',
        'Central Bulge (Dominant)': 'dominant central bulge',
    },
}

# The selectable radio labels per task, derived from TASK_MAPPING so the two
# can never drift apart (dicts preserve insertion order).
LABELS = {task: list(choices) for task, choices in TASK_MAPPING.items()}

# Callbacks for the "Tree" tabs: each converts the selected radio labels into
# the ' | '-separated keyword prompt and returns the generated description.
def _build_keywords(tasks, labels):
    """Translate radio selections into the keyword prompt.

    Parameters
    ----------
    tasks : sequence of str
        TASK_MAPPING keys, in the order the keywords must appear.
    labels : sequence of str or None
        Selected radio labels, aligned one-to-one with ``tasks``.

    Returns
    -------
    str
        The mapped keywords joined with ' | '.

    Raises
    ------
    ValueError
        If a question was left unanswered (Gradio passes ``None``) or a label
        is not a known choice for its task.
    """
    keywords = []
    for task, label in zip(tasks, labels):
        mapping = TASK_MAPPING[task]
        if label not in mapping:
            raise ValueError(
                f"Please answer every question before submitting "
                f"(got {label!r} for {task})."
            )
        keywords.append(mapping[label])
    return " | ".join(keywords)


def tree_1(label_1, label_2, label_3):
    """Keywords in order: task_1, task_2, task_4."""
    return prompt_generate(_build_keywords(
        ('task_1', 'task_2', 'task_4'),
        (label_1, label_2, label_3),
    ))


def tree_2(label_1, label_2, label_3, label_4):
    """Keywords in order: task_1, task_3, task_5, task_4."""
    return prompt_generate(_build_keywords(
        ('task_1', 'task_3', 'task_5', 'task_4'),
        (label_1, label_2, label_3, label_4),
    ))


def tree_3(label_1, label_2, label_3, label_4, label_5, label_6):
    """Keywords in order: task_1, task_3, task_6, task_7, task_10, task_4."""
    return prompt_generate(_build_keywords(
        ('task_1', 'task_3', 'task_6', 'task_7', 'task_10', 'task_4'),
        (label_1, label_2, label_3, label_4, label_5, label_6),
    ))


def tree_4(label_1, label_2, label_3, label_4, label_5, label_6, label_7, label_8):
    """Keywords in order: task_1, task_3, task_6, task_7, task_8, task_9, task_10, task_4."""
    return prompt_generate(_build_keywords(
        ('task_1', 'task_3', 'task_6', 'task_7', 'task_8', 'task_9', 'task_10', 'task_4'),
        (label_1, label_2, label_3, label_4, label_5, label_6, label_7, label_8),
    ))


def tree_5(label_1):
    """Keywords: task_1 only."""
    return prompt_generate(_build_keywords(('task_1',), (label_1,)))


# Decision-tree questions, used as the label of each radio input in the tabs.
questions = {
    'q1': 'Is the galaxy simply smooth and rounded, with no sign of a disk?',
    'q2': 'How rounded is it?',
    'q3': 'Could this be a disk viewed edge-on?',
    'q4': 'Is the galaxy merging or disturbed?',
    'q5': 'Does the galaxy have a bulge at its centre? If so, what shape?',
    'q6': 'Is there a bar feature through the centre of the galaxy?',
    'q7': 'Is there any sign of a spiral arm pattern?',
    'q8': 'How tightly wound do the spiral arms appear?',
    'q9': 'How many spiral arms are there?',
    'q10': 'Is there a central bulge? If so, how large is it compared with the galaxy?',
}


# Welcome tab: static explanation of how to use the "Tree" tabs.
with gr.Blocks(analytics_enabled=False) as app1:
    gr.Markdown("""
    # Generate your Galaxy

    Each "Tree" tab follows one branch of the galaxy classification decision
    tree, picked by the answer to the first question (smooth, featured or
    disk, or artifact). Answer the questions in a tab and press **Submit**:
    a fine-tuned T5 model turns your answers into a text description of the
    galaxy. Generation is sampled, so submitting again gives a new description.
    """)

# Tabs "Tree 1".."Tree 5": one Interface per decision-tree branch. Each input
# list must follow the positional parameter order of its tree_* callback.
def _radio(task, question, choices=None):
    """Build the Radio input for one decision-tree question.

    Parameters
    ----------
    task : str
        TASK_MAPPING/LABELS key; its labels are offered when ``choices`` is None.
    question : str
        Key into ``questions`` used as the radio's label.
    choices : list of str, optional
        Restrict the options to this subset of the task's labels.

    When only one option is offered it is preselected, so submitting without
    clicking the single possible answer no longer sends ``None`` to the callback.
    """
    options = list(LABELS[task]) if choices is None else list(choices)
    default = options[0] if len(options) == 1 else None
    return gr.Radio(options, value=default, label=questions[question])


app2 = gr.Interface(fn=tree_1,
                    inputs=[_radio('task_1', 'q1', ['Smooth']),
                            _radio('task_2', 'q2'),
                            _radio('task_4', 'q4')],
                    outputs="text")

app3 = gr.Interface(fn=tree_2,
                    inputs=[_radio('task_1', 'q1', ['Featured or Disk']),
                            _radio('task_3', 'q3'),
                            _radio('task_5', 'q5'),
                            _radio('task_4', 'q4')],
                    outputs="text")

app4 = gr.Interface(fn=tree_3,
                    inputs=[_radio('task_1', 'q1', ['Featured or Disk']),
                            _radio('task_3', 'q3'),
                            _radio('task_6', 'q6'),
                            _radio('task_7', 'q7', ['Spiral Arms (No)']),
                            _radio('task_10', 'q10'),
                            _radio('task_4', 'q4')],
                    outputs="text")

app5 = gr.Interface(fn=tree_4,
                    inputs=[_radio('task_1', 'q1', ['Featured or Disk']),
                            _radio('task_3', 'q3'),
                            _radio('task_6', 'q6'),
                            _radio('task_7', 'q7', ['Spiral Arms (Yes)']),
                            _radio('task_8', 'q8'),
                            _radio('task_9', 'q9'),
                            _radio('task_10', 'q10'),
                            _radio('task_4', 'q4')],
                    outputs="text")

app6 = gr.Interface(fn=tree_5,
                    inputs=[_radio('task_1', 'q1', ['Artifact'])],
                    outputs="text")

demo = gr.TabbedInterface([app1, app2, app3, app4, app5, app6], ["Welcome", "Tree 1", "Tree 2", "Tree 3", "Tree 4", "Tree 5"])

In [7]:
# Serve the tabbed app on a local URL (printed below); Gradio's own log notes
# that launch(share=True) creates a public link instead.
demo.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


