In [None]:
%%capture
# Environment setup cell; the %%capture magic above hides the installers' output.
# NOTE(review): only xformers is version-pinned here, so a fresh run may resolve
# different versions of the other packages — consider pinning them.
# Packages imported later in the notebook (models, Flask API, ngrok tunnel).
!pip install transformers indobenchmark-toolkit bitsandbytes pyngrok flask flask-cors
# Install unsloth alongside a pinned xformers build ...
!pip install unsloth "xformers==0.0.28.post2"
# ... then replace that unsloth with a fresh build from the GitHub repository.
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth @ git+https://github.com/unslothai/unsloth.git"
# Swap whatever torch stack pip resolved above for the CUDA 12.1 wheels.
# NOTE(review): this runs after xformers/unsloth were installed against the
# previous torch — confirm the resulting versions are compatible.
!pip uninstall -y torch torchvision torchaudio
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [None]:
from unsloth import FastLanguageModel

In [None]:
import os
import random
import re
import subprocess
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

import torch
import numpy as np
import bitsandbytes
from flask import Flask, request, jsonify
from flask_cors import CORS
from pyngrok import ngrok
from transformers import MBartForConditionalGeneration, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
from indobenchmark import IndoNLGTokenizer

In [None]:
def set_seed(seed):
    """Seed every RNG used during generation and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), then configures cuDNN for reproducible execution.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers the current device as well as any others; silently ignored when
    # CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select a different algorithm on each run, which
    # defeats the deterministic flag above, so it has to be switched off.
    torch.backends.cudnn.benchmark = False

def format_text(text):
    """Normalise sentence casing and punctuation spacing of a text.

    The text is split at sentence-ending punctuation (``.``, ``!``, ``?``).
    Each sentence is stripped and passed through ``str.capitalize`` (first
    character upper-cased, the rest lower-cased), and every punctuation mark
    is followed by exactly one space.

    A mark that sits between two digits (a decimal such as ``2.5``) is not
    treated as a sentence boundary, and consecutive marks (``...``, ``?!``)
    are kept together instead of being spaced apart.

    Args:
        text: Raw text to format.

    Returns:
        The formatted text with surrounding whitespace removed.
    """
    # One or more sentence-ending marks; a mark is skipped only when it has a
    # digit on BOTH sides. The single capture group makes re.split keep the
    # delimiters, so the result alternates text, marks, text, marks, ...
    boundary = r'((?:(?<!\d)[.!?]|[.!?](?!\d))+)'
    parts = re.split(boundary, text.strip())

    formatted = []
    for index, part in enumerate(parts):
        part = part.strip()
        if not part:
            continue
        if index % 2:
            # Odd positions hold the captured punctuation.
            formatted.append(f'{part} ')
        else:
            formatted.append(part.capitalize())
    return ''.join(formatted).strip()

In [None]:
# The ngrok auth token is a credential and must not be stored in the notebook:
# read it from the NGROK_AUTH_TOKEN environment variable instead.
# NOTE(review): the token that was previously hardcoded here should be treated
# as leaked and revoked in the ngrok dashboard.
ngrok_auth_token = os.environ.get("NGROK_AUTH_TOKEN")
if not ngrok_auth_token:
    raise RuntimeError(
        "Set the NGROK_AUTH_TOKEN environment variable before running this cell."
    )
ngrok.set_auth_token(ngrok_auth_token)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint identifiers handed to the ``from_pretrained`` calls below.
model_name_bart = "Danda245/BART-medical-qa-id"
model_name_gpt2 = "lafarizo/indo_medical_gpt2_v4"
model_name_t5 = "Yaziddd/indoT5-base-medical-qa-3"
model_name_phi = "Rizald95/phi.35-instruct-medical-qa-id"

# BART: MBart architecture paired with the IndoNLG tokenizer.
model_bart = MBartForConditionalGeneration.from_pretrained(model_name_bart).to(device)
tokenizer_bart = IndoNLGTokenizer.from_pretrained(model_name_bart)

# GPT-2: causal LM. If its tokenizer defines no pad token, reuse EOS so the
# ``padding=True`` tokenizer calls and ``pad_token_id`` in generate() work.
model_gpt2 = AutoModelForCausalLM.from_pretrained(model_name_gpt2).to(device)
tokenizer_gpt2 = AutoTokenizer.from_pretrained(model_name_gpt2)
if tokenizer_gpt2.pad_token is None:
    tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token

# T5: sequence-to-sequence model with the fast tokenizer implementation.
model_t5 = AutoModelForSeq2SeqLM.from_pretrained(model_name_t5).to(device)
tokenizer_t5 = AutoTokenizer.from_pretrained(model_name_t5, use_fast=True)

# Phi: Unsloth returns the model and its tokenizer together. ``for_inference``
# is applied to the loaded model in place, so ``model_phi`` refers to that
# same instance instead of depending on the call's return value.
model_instance_phi, tokenizer_phi = FastLanguageModel.from_pretrained(model_name_phi)
FastLanguageModel.for_inference(model_instance_phi)
model_phi = model_instance_phi

In [None]:
# Flask application serving the API. Cross-origin requests are allowed from
# any origin, but only on the /generate route.
app = Flask(__name__)
generate_cors_rules = {r"/generate": {"origins": "*"}}
CORS(app, resources=generate_cors_rules)

# Sampling settings shared by every model's ``generate`` call.
_SAMPLING_KWARGS = {
    'top_k': 20,
    'top_p': 0.7,
    'no_repeat_ngram_size': 2,
    'do_sample': True,
    'temperature': 0.7,
}

# ``set_seed`` and sampling both go through process-wide RNG state. Without
# this lock the worker threads (and concurrent requests) overwrite each
# other's seeds, so the fixed per-model seeds would not give repeatable output.
_generation_lock = threading.Lock()


def _strip_prompt_echo(generated, prompt):
    """Remove a leading, case-insensitive echo of ``prompt`` from ``generated``.

    Returns the remainder stripped of surrounding whitespace when the echo is
    present, otherwise ``generated`` unchanged.
    """
    echo_length = len(prompt)
    if generated[:echo_length].lower() == prompt.lower():
        return generated[echo_length:].strip()
    return generated


def _seeded_generate(model, seed, **generate_kwargs):
    """Seed the global RNGs and run ``model.generate`` as one atomic step."""
    with _generation_lock:
        set_seed(seed)
        with torch.no_grad():
            return model.generate(**_SAMPLING_KWARGS, **generate_kwargs)


@app.route('/generate', methods=['POST'])
def generate():
    """Answer one question with all four models.

    Expects a JSON body of the form ``{"input": "<question>"}``.

    Returns:
        A JSON object mapping ``"BART"``, ``"GPT2"``, ``"T5"`` and ``"PHI3.5"``
        to the text generated by the corresponding model, or a 400 response
        with an ``"error"`` message when the body is not a JSON object with a
        non-blank string ``"input"``.
    """
    payload = request.get_json(silent=True)
    user_input = payload.get('input') if isinstance(payload, dict) else None
    if not isinstance(user_input, str) or not user_input.strip():
        return jsonify({
            'error': "Request body must be a JSON object with a non-empty string 'input'."
        }), 400

    # Prompt layout used for the Phi model only.
    formatted_input = (
        f"Input:\n {user_input}\n"
        "Response:"
    )

    ########
    # BART #
    ########
    def generate_bart():
        # The BART pipeline works on lower-cased text; format_text restores
        # sentence casing on the way out.
        prompt = user_input.lower()
        inputs = tokenizer_bart.prepare_input_for_generation(
            prompt,
            return_tensors='pt',
            lang_token='[indonesian]',
            decoder_lang_token='[indonesian]',
        ).to(device)

        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        # Add a batch dimension when the tokenizer returned 1-D tensors.
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)

        output_ids = _seeded_generate(
            model_bart,
            101,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=128,
            num_beams=5,
        )
        answer = tokenizer_bart.decode(output_ids[0], skip_special_tokens=True)
        return format_text(_strip_prompt_echo(answer, prompt))

    ########
    # GPT2 #
    ########
    def generate_gpt2():
        inputs = tokenizer_gpt2(
            user_input,
            return_tensors="pt",
            truncation=True,
            padding=True,
        ).to(device)

        output_ids = _seeded_generate(
            model_gpt2,
            11,
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=128,
            num_beams=5,
            eos_token_id=tokenizer_gpt2.eos_token_id,
            pad_token_id=tokenizer_gpt2.pad_token_id,
        )
        answer = tokenizer_gpt2.decode(output_ids[0], skip_special_tokens=True)
        return _strip_prompt_echo(answer, user_input)

    ######
    # T5 #
    ######
    def generate_t5():
        inputs = tokenizer_t5(
            user_input,
            return_tensors="pt",
            truncation=True,
            padding=True,
        ).to(device)

        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        # Add a batch dimension when the tokenizer returned 1-D tensors.
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)

        output_ids = _seeded_generate(
            model_t5,
            41,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=128,
            num_beams=5,
        )
        answer = tokenizer_t5.decode(output_ids[0], skip_special_tokens=True)
        return _strip_prompt_echo(answer, user_input)

    #######
    # PHI #
    #######
    def generate_phi():
        # Uses the shared ``device`` like the other models instead of a
        # hard-coded 'cuda' string.
        inputs = tokenizer_phi(
            formatted_input,
            return_tensors="pt",
            truncation=True,
            padding=True,
        ).to(device)

        output_ids = _seeded_generate(
            model_phi,
            3407,
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=128,
            num_beams=1,
        )
        answer = tokenizer_phi.decode(output_ids[0], skip_special_tokens=True)
        return _strip_prompt_echo(answer, formatted_input)

    ###########################
    # Run all four generators #
    ###########################
    # The workers are dispatched together so tokenisation and decoding can
    # overlap; the seeded sampling step itself is serialised by the lock.
    workers = {
        'BART': generate_bart,
        'GPT2': generate_gpt2,
        'T5': generate_t5,
        'PHI3.5': generate_phi,
    }
    with ThreadPoolExecutor(max_workers=len(workers)) as executor:
        futures = {name: executor.submit(worker) for name, worker in workers.items()}
        # A failure in any worker propagates from result() and fails the request.
        results = {name: future.result() for name, future in futures.items()}

    return jsonify(results)

In [None]:
if __name__ == '__main__':
    # One value for both the tunnel and the server so they cannot drift apart.
    port = 5000
    public_url = ngrok.connect(port)
    print(f"Public URL: {public_url}")
    try:
        # Blocks until the development server is stopped.
        app.run(host='0.0.0.0', port=port)
    finally:
        # Shut ngrok down with the server; otherwise the tunnel stays open and
        # re-running this cell would stack another one on top of it.
        ngrok.kill()