<a href="https://colab.research.google.com/github/AliJaffery12/Construction-of-a-parallel-multilingual-corpus-with-NLP-and-SemanticWeb/blob/main/FlaskApp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import time
import dash
import dash_html_components as html
import dash_core_components as dcc
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
from transformers import BartTokenizer, BartForConditionalGeneration, MarianMTModel, MarianTokenizer, DistilBertTokenizer, DistilBertForQuestionAnswering
import torch
import nltk

nltk.download("punkt")
from nltk.tokenize import sent_tokenize

# Initialize the required variables and models
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Load Summarization Model
pretrained_summ = "sshleifer/distilbart-xsum-12-6"
model_summ = BartForConditionalGeneration.from_pretrained(pretrained_summ)
tokenizer_summ = BartTokenizer.from_pretrained(pretrained_summ)

# On GPU, cast to FP16 for faster inference; then move to `device` and switch to eval mode.
if device == "cuda":
    model_summ = model_summ.half()
model_summ.to(device)
model_summ.eval()

# Load Translation Model (English -> several Romance languages, chosen per request
# by a ">>xx<<" target-language token).
pretrained_trans = "Helsinki-NLP/opus-mt-en-ROMANCE"
model_trans = MarianMTModel.from_pretrained(pretrained_trans)
tokenizer_trans = MarianTokenizer.from_pretrained(pretrained_trans)
# Keep the translator on `device` like the summarizer: the translation callback
# places its input ids on `device`, and weights/inputs must share a device or
# generation fails with a device-mismatch error on GPU.
model_trans.to(device)
model_trans.eval()

# Load Question Answering Model
# NOTE: intentionally not moved to `device`; this model and the tensors the QA
# callback feeds it must always live on the same device.
pretrained_qa = "distilbert-base-uncased-distilled-squad"
model_qa = DistilBertForQuestionAnswering.from_pretrained(pretrained_qa)
tokenizer_qa = DistilBertTokenizer.from_pretrained(pretrained_qa)

# Create the app. Each page layout is swapped into "page-content" by the URL
# callback, so most callback component ids are absent from the initial layout;
# suppress Dash's "nonexistent id" callback validation errors for them.
app = dash.Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
    suppress_callback_exceptions=True,
)
server = app.server

# Layout for text summarization
# Summarization page.
# Left column: a controls card (output length, beam size, run button + timing
# text) above the summary output box. Right column: the text to summarize.
# The ids "max-length", "num-beams", "button-run", "time-taken",
# "summarized-content" and "original-text" are wired to the summarization callback.
layout_summ = dbc.Container(
    fluid=True,
    children=[
        html.H1(" Automatic Summarization (with DistilBART)"),
        html.Hr(),
        dbc.Row(
            [
                # Left column: generation controls and the summary output.
                dbc.Col(
                    width=5,
                    children=[
                        dbc.Card(
                            body=True,
                            children=[
                                dbc.FormGroup(
                                    [
                                        dbc.Label("Output Length "),
                                        # Passed by the callback as generate(max_length=...).
                                        dcc.Slider(
                                            id="max-length",
                                            min=10,
                                            max=50,
                                            value=30,
                                            marks={i: str(i) for i in range(10, 51, 10)},
                                        ),
                                    ]
                                ),
                                dbc.FormGroup(
                                    [
                                        dbc.Label("Beam Size"),
                                        # Passed by the callback as generate(num_beams=...).
                                        dcc.Slider(
                                            id="num-beams",
                                            min=2,
                                            max=6,
                                            value=4,
                                            marks={i: str(i) for i in [2, 4, 6]},
                                        ),
                                    ]
                                ),
                                dbc.FormGroup(
                                    [
                                        # Loading indicator around the run button and
                                        # the timing text the callback writes.
                                        dbc.Spinner(
                                            [
                                                dbc.Button("Summarize", id="button-run"),
                                                html.Div(id="time-taken"),
                                            ]
                                        )
                                    ]
                                ),
                            ],
                            # Fixed height; the summary textarea below subtracts the same 275px.
                            style={"height": "275px"},
                        ),
                        dbc.Card(
                            body=True,
                            children=[
                                dbc.FormGroup(
                                    [
                                        dbc.Label("Summarized Content"),
                                        dcc.Textarea(
                                            id="summarized-content",
                                            # Fill the rest of 75vh so the left column roughly
                                            # lines up with the 75vh text area on the right.
                                            style={
                                                "width": "100%",
                                                "height": "calc(75vh - 275px)",
                                            },
                                        ),
                                    ]
                                )
                            ],
                        ),
                    ],
                ),
                # Right column: the source text. The callback reads it as State, so
                # editing it alone does not trigger a summarization run.
                dbc.Col(
                    width=7,
                    children=[
                        dbc.Card(
                            body=True,
                            children=[
                                dbc.FormGroup(
                                    [
                                        dbc.Label("Original Text (Paste here)"),
                                        dcc.Textarea(
                                            id="original-text",
                                            style={"width": "100%", "height": "75vh"},
                                        ),
                                    ]
                                )
                            ],
                        )
                    ],
                ),
            ]
        ),
    ],
)

# Layout for text translation
# Translation page: a Translate button with timing text, then side-by-side
# source/target text areas, each headed by a language selector. The ids
# "button-translate", "time-output", "source-language", "target-language",
# "source-text" and "target-text" are wired to the translation callback.
layout_trans = dbc.Container(
    fluid=True,
    children=[
        html.H1("Translation"),
        html.Hr(),
        # Loading indicator around the button row and its timing output.
        dbc.Spinner(
            dbc.Row(
                [
                    dbc.Col(dbc.Button("Translate", id="button-translate"), width=2),
                    dbc.Col(
                        html.Div(id="time-output", style={"margin-top": "8px"}),
                        width=10,
                    ),
                ],
                style={"margin-bottom": "15px"},
            )
        ),
        dbc.Row(
            [
                # Source side: English is the only source language offered.
                dbc.Col(
                    [
                        dbc.InputGroup(
                            [
                                dbc.InputGroupAddon(
                                    "Source Language", addon_type="prepend"
                                ),
                                dbc.Select(
                                    id="source-language",
                                    options=[{"label": "English", "value": "en"}],
                                    value="en",
                                ),
                            ]
                        ),
                        dbc.Textarea(
                            id="source-text",
                            style={"margin-top": "15px", "height": "65vh"},
                        ),
                    ]
                ),
                # Target side: one option per code reported by the Marian tokenizer.
                # Labels strip the ">>"/"<<" markers; values keep the raw token
                # (e.g. ">>fr<<"), which the callback receives as tgt_lang.
                dbc.Col(
                    [
                        dbc.InputGroup(
                            [
                                dbc.InputGroupAddon(
                                    "Target Language", addon_type="prepend"
                                ),
                                dbc.Select(
                                    id="target-language",
                                    options=[
                                        {
                                            "label": v.replace(">>", "").replace(
                                                "<<", ""
                                            ),
                                            "value": v,
                                        }
                                        for v in tokenizer_trans.supported_language_codes
                                    ],
                                    # NOTE(review): assumes ">>fr<<" is among
                                    # supported_language_codes — confirm.
                                    value=">>fr<<",
                                ),
                            ]
                        ),
                        dbc.Textarea(
                            id="target-text",
                            style={"margin-top": "15px", "height": "65vh"},
                        ),
                    ]
                ),
            ]
        ),
    ],
)

# Layout for question answering
# Question-answering page: a question input, a context textarea, an Answer
# button and an output div. The ids "question-input", "context-input",
# "qa-button" and "qa-output" are wired to the QA callback.
layout_qa = dbc.Container(
    fluid=True,
    children=[
        html.H1("Question Answering Pipelines (with DistilBERT)"),
        html.Hr(),
        # NOTE(review): the whole form, inputs included, sits inside the Spinner,
        # so it is presumably hidden by the loading indicator while the QA
        # callback runs — confirm that this is intended.
        dbc.Spinner(
            [
                dbc.FormGroup(
                    [
                        dbc.Label("Question"),
                        dbc.Input(id="question-input", type="text"),
                    ]
                ),
                dbc.FormGroup(
                    [
                        dbc.Label("Context"),
                        dbc.Textarea(id="context-input", style={"height": "25vh"}),
                    ]
                ),
                dbc.Button("Answer", id="qa-button", color="primary", className="mt-2"),
                # Receives the extracted answer (or a prompt message) from the callback.
                html.Div(id="qa-output", className="mt-2"),
            ]
        ),
    ],
)

# Create the main app layout with a callback to switch between the layouts based on the URL path
app.layout = html.Div(
    children=[
        # Tracks the browser URL without reloading; its pathname drives display_page.
        dcc.Location(id="url", refresh=False),
        # Placeholder that receives whichever page layout display_page returns.
        html.Div(id="page-content"),
    ]
)

@app.callback(Output("page-content", "children"), [Input("url", "pathname")])
def display_page(pathname):
    """Return the page layout matching the current URL path.

    "/translation" maps to the translation page and "/QuestionsAnswer" to the
    QA page; every other path (including None) falls back to summarization.
    """
    routes = {
        "/translation": layout_trans,
        "/QuestionsAnswer": layout_qa,
    }
    return routes.get(pathname, layout_summ)

# Summarization callback: runs DistilBART on the pasted text
@app.callback(
    [Output("summarized-content", "value"), Output("time-taken", "children")],
    [
        Input("button-run", "n_clicks"),
        Input("max-length", "value"),
        Input("num-beams", "value"),
    ],
    [State("original-text", "value")],
)
def summarize(n_clicks, max_len, num_beams, original_text):
    """Summarize the pasted text with DistilBART.

    Fires on a button click or when either slider moves; the text itself is read
    as State, so typing alone does not trigger a run.

    Args:
        n_clicks: Click count of the "Summarize" button (only acts as a trigger).
        max_len: Maximum summary length, from the "max-length" slider.
        num_beams: Beam-search width, from the "num-beams" slider.
        original_text: Contents of the "original-text" textarea; may be None.

    Returns:
        A ``(summary, status_message)`` tuple for the two Outputs.
    """
    # Whitespace-only input is as empty as "": there is nothing to summarize.
    if original_text is None or not original_text.strip():
        return "", "Did not run"

    t0 = time.time()

    # BART has 1024 positions; truncate explicitly rather than relying on the
    # tokenizer's implicit (warning-emitting) fallback when max_length is given.
    inputs = tokenizer_summ(
        [original_text], max_length=1024, truncation=True, return_tensors="pt"
    ).to(device)

    # Generate Summary
    summary_ids = model_summ.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=num_beams,
        max_length=max_len,
        early_stopping=True,
    )
    # One input and the default single return sequence -> decode only the first row.
    summary = tokenizer_summ.decode(
        summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    t1 = time.time()
    time_taken = f"Summarized on {device} in {t1 - t0:.2f}s"

    return summary, time_taken

# Translation callback: English -> selected Romance language with MarianMT
@app.callback(
    [Output("target-text", "value"), Output("time-output", "children")],
    [
        Input("button-translate", "n_clicks"),
        Input("source-language", "value"),
        Input("target-language", "value"),
    ],
    [State("source-text", "value")],
)
def translate(n_clicks, src_lang, tgt_lang, src_text):
    """Translate the source text into the selected target language.

    Args:
        n_clicks: Click count of the "Translate" button (only acts as a trigger).
        src_lang: Selected source language; only "en" is offered, so unused.
        tgt_lang: Raw Marian target-language token, e.g. ">>fr<<".
        src_text: Contents of the "source-text" textarea; may be None.

    Returns:
        A ``(translation, status_message)`` tuple for the two Outputs.
    """
    # Whitespace-only input is as empty as "": there is nothing to translate.
    if src_text is None or not src_text.strip():
        return "", "Did not run."

    t0 = time.time()

    # Follow the model's actual device so inputs and weights always match.
    model_device = model_trans.device

    # opus-mt-en-ROMANCE is multi-target: the target language is chosen by
    # prefixing its token to the source *text*. Encoding the token on its own and
    # concatenating ids would also splice that encoding's trailing </s> between
    # the language token and the sentence, corrupting the input.
    inputs = tokenizer_trans(
        [f"{tgt_lang} {src_text}"], truncation=True, return_tensors="pt"
    ).to(model_device)

    translated = model_trans.generate(**inputs)
    tgt_text = tokenizer_trans.decode(translated[0], skip_special_tokens=True)

    t1 = time.time()
    # `.type` yields "cuda"/"cpu", the same format as the module-level `device`.
    time_output = f"Translated on {model_device.type} in {t1 - t0:.2f}s"

    return tgt_text, time_output


def _best_answer_span(start_logits, end_logits, context_start, context_end, max_answer_len=30):
    """Return ``(start, end)`` token indices of the best-scoring valid answer span.

    A span is valid when ``start <= end``, it covers fewer than ``max_answer_len``
    tokens, and both ends lie in the context range ``[context_start, context_end)``.
    Its score is ``start_logits[start] + end_logits[end]``. Both logit tensors are
    1-D over the same token positions.
    """
    seq_len = start_logits.shape[0]
    positions = torch.arange(seq_len, device=start_logits.device)
    # span_len[i, j] = j - i, i.e. end minus start for every (start, end) pair.
    span_len = positions[None, :] - positions[:, None]
    in_context = (positions >= context_start) & (positions < context_end)
    valid = (
        (span_len >= 0)
        & (span_len < max_answer_len)
        & in_context[:, None]
        & in_context[None, :]
    )
    scores = start_logits[:, None] + end_logits[None, :]
    scores = scores.masked_fill(~valid, float("-inf"))
    # Row index = start, column index = end in the flattened (seq_len, seq_len) grid.
    return divmod(int(torch.argmax(scores)), seq_len)


@app.callback(
    Output("qa-output", "children"),
    [Input("qa-button", "n_clicks")],
    [
        State("question-input", "value"),
        State("context-input", "value"),
    ],
)
def answer_question(n_clicks, question, context):
    """Extract an answer to ``question`` from ``context`` with DistilBERT (SQuAD).

    Args:
        n_clicks: Click count of the "Answer" button (only acts as a trigger).
        question: Text of the "question-input" field; may be None.
        context: Text of the "context-input" textarea; may be None.

    Returns:
        The extracted answer text, or a message when no answer can be produced.
    """
    if not question or not context:
        return "Please provide both a question and context."

    # Truncate only the context so question + context fit the model's max length
    # (contexts longer than that previously crashed the position embeddings).
    inputs = tokenizer_qa(
        question, context, truncation="only_second", return_tensors="pt"
    ).to(model_qa.device)

    with torch.no_grad():
        outputs = model_qa(**inputs)

    input_ids = inputs["input_ids"][0]
    # Sequence layout is [CLS] question [SEP] context [SEP]: context tokens lie
    # strictly between the first and the last [SEP].
    sep_positions = (input_ids == tokenizer_qa.sep_token_id).nonzero(as_tuple=True)[0]
    context_start = int(sep_positions[0]) + 1
    context_end = int(sep_positions[-1])
    if context_end <= context_start:
        return "No answer found in the context."

    # Choose start/end jointly; independent argmaxes can yield end < start (an
    # empty answer) or a span inside the question or on special tokens.
    answer_start, answer_end = _best_answer_span(
        outputs.start_logits[0], outputs.end_logits[0], context_start, context_end
    )
    answer = tokenizer_qa.convert_tokens_to_string(
        tokenizer_qa.convert_ids_to_tokens(
            input_ids[answer_start : answer_end + 1].tolist()
        )
    )

    return answer


if __name__ == "__main__":
    # Run Dash's built-in server with debug mode enabled.
    # NOTE(review): `server = app.server` is also exposed at module level,
    # presumably for WSGI hosting instead of this dev server — confirm.
    app.run_server(debug=True)
