# In [None]:
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import traceback
# Flask application object. CORS(app) applies flask_cors with its default
# settings to every route of this app.
app = Flask(__name__)
CORS(app)


# Tokenizers: the Medium model uses the tokenizer of the published
# 'turkish-gpt2-medium-350m-instruct-v0.1' checkpoint, the Base model the
# tokenizer of 'ytu-ce-cosmos/turkish-gpt2'.
tokenizer = GPT2Tokenizer.from_pretrained('ytu-ce-cosmos/turkish-gpt2-medium-350m-instruct-v0.1')
tokenizerBase = GPT2Tokenizer.from_pretrained('ytu-ce-cosmos/turkish-gpt2')
# Special tokens used to frame prompts; api() builds prompts of the form
# "<BOS> <Title> {title} <EndTitle>".
# NOTE(review): presumably these must match the tokens the ./Medium and ./Base
# checkpoints were fine-tuned with (including resized embeddings) — confirm.
special_tokens_dict = {
    "bos_token": "<BOS>",
    "eos_token": "<EOS>",
    "pad_token": "<PAD>",
    "additional_special_tokens": ["<Title>","<EndTitle>"]
    }
# Register the special tokens on both tokenizers so they encode as single ids.
tokenizer.add_special_tokens(special_tokens_dict)
tokenizerBase.add_special_tokens(special_tokens_dict)

# Model weights loaded from local directories (not from the Hub).
modelMed = GPT2LMHeadModel.from_pretrained('./Medium')
modelBase = GPT2LMHeadModel.from_pretrained('./Base')

# Use the GPU when CUDA is available, otherwise the CPU; both models live on
# the same device, and the generate_* functions move their inputs there too.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modelMed.to(device)
modelBase.to(device)

def generate_text_Medium(prompt, max_length=750, *, num_beams=3, repetition_penalty=1.2):
    """Generate a continuation of *prompt* with the Medium model (``modelMed``).

    Args:
        prompt: Full prompt string, already containing its framing special
            tokens (api() passes ``"<BOS> <Title> ... <EndTitle>"``).
        max_length: Upper bound, in tokens, on the returned sequence
            *including* the prompt tokens (``generate``'s ``max_length``).
        num_beams: Beam width used for beam search.
        repetition_penalty: Penalty applied to tokens that were already generated.

    Returns:
        The decoded best beam (prompt included) with special tokens removed.
    """
    # Calling the tokenizer (instead of ``encode``) also returns the
    # attention mask, so ``generate`` does not have to infer it.
    inputs = tokenizer(prompt, return_tensors='pt').to(device)

    outputs = modelMed.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        # Use the registered <PAD> token explicitly rather than relying on the
        # model's generation config, which may fall back to eos_token_id.
        pad_token_id=tokenizer.pad_token_id,
        repetition_penalty=repetition_penalty,
        early_stopping=True,
        num_beams=num_beams,
    )

    # outputs[0] is the single (highest-scoring) returned sequence.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def generate_text_Base(prompt, max_length=750, *, num_beams=3, repetition_penalty=1.2):
    """Generate a continuation of *prompt* with the Base model (``modelBase``).

    Args:
        prompt: Full prompt string, already containing its framing special
            tokens (api() passes ``"<BOS> <Title> ... <EndTitle>"``).
        max_length: Upper bound, in tokens, on the returned sequence
            *including* the prompt tokens (``generate``'s ``max_length``).
        num_beams: Beam width used for beam search.
        repetition_penalty: Penalty applied to tokens that were already generated.

    Returns:
        The decoded best beam (prompt included) with special tokens removed.
    """
    # Calling the tokenizer (instead of ``encode``) also returns the
    # attention mask, so ``generate`` does not have to infer it.
    inputs = tokenizerBase(prompt, return_tensors='pt').to(device)

    outputs = modelBase.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=1,
        eos_token_id=tokenizerBase.eos_token_id,
        # Use the registered <PAD> token explicitly rather than relying on the
        # model's generation config, which may fall back to eos_token_id.
        pad_token_id=tokenizerBase.pad_token_id,
        repetition_penalty=repetition_penalty,
        early_stopping=True,
        num_beams=num_beams,
    )

    # outputs[0] is the single (highest-scoring) returned sequence.
    return tokenizerBase.decode(outputs[0], skip_special_tokens=True)

@app.route('/api', methods=['POST'])
def api():
    """Generate text for a title posted as JSON.

    Expects a JSON object body such as ``{"title": "...", "model": "Medium"}``.
    ``"model": "Medium"`` selects the Medium model; any other value (or a
    missing key) uses the Base model.

    Returns:
        A JSON string containing the generated text with the echoed title
        removed, or a ``{"error": ...}`` JSON object with HTTP 400 when the
        body is not a JSON object or has no string ``title``.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # raising, so all malformed requests get the same 400 response below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    title = data.get('title')
    if not isinstance(title, str):
        # Previously a missing title crashed on string concatenation (HTTP 500).
        return jsonify({"error": "'title' must be a string"}), 400

    modelType = data.get('model')
    app.logger.debug("Requested model: %s", modelType)

    prompt = "<BOS> <Title> " + title + " <EndTitle>"
    generate = generate_text_Medium if modelType == "Medium" else generate_text_Base
    text = generate(prompt, 750)

    # The decoded output begins with the prompt, so remove only that first
    # occurrence of the title; removing every occurrence would also delete
    # legitimate mentions of the title inside the generated text.
    text = text.replace(title, "", 1)
    return jsonify(text)

if __name__ == '__main__':
    # Top-level boundary: serve on localhost:5000 with the debugger/reloader
    # off, and print the full traceback of any exception raised by app.run.
    try:
        app.run(host='localhost', port=5000, debug=False)
    except Exception:
        traceback.print_exc()
