In [None]:
!pip install flask pyngrok torch kss

from flask import Flask, request, jsonify
from pyngrok import ngrok
import torch
import torch.nn as nn
import torch.nn.functional as F
import kss
import pickle

# ===== HAN model class definitions =====
class WordAttention(nn.Module):
    """Word-level encoder of the HAN: bidirectional GRU followed by attention pooling.

    Each input sequence of embeddings is reduced to a single vector by weighting
    the GRU outputs with a softmax over their similarity to a learned context vector.

    Args:
        embed_size: width of each input embedding.
        hidden_size: GRU hidden size per direction; outputs are 2 * hidden_size wide.
    """

    def __init__(self, embed_size, hidden_size):
        super().__init__()
        width = hidden_size * 2  # forward and backward GRU states are concatenated
        # gru / fc / context are state_dict keys: keep these names so saved
        # checkpoints keep loading.
        self.gru = nn.GRU(embed_size, hidden_size, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(width, width)
        self.context = nn.Parameter(torch.randn(width))

    def forward(self, x):
        """Pool x of shape (batch, seq_len, embed_size) into (batch, 2 * hidden_size).

        No padding mask is applied: padded positions take part in the softmax.
        """
        states, _ = self.gru(x)                               # (batch, seq_len, 2H)
        scores = torch.tanh(self.fc(states)) @ self.context   # (batch, seq_len)
        weights = F.softmax(scores, dim=1)
        return (states * weights.unsqueeze(-1)).sum(dim=1)

class SentenceAttention(nn.Module):
    """Sentence-level encoder of the HAN: bidirectional GRU over sentence vectors
    plus attention pooling into one document vector.

    Args:
        hidden_size: GRU hidden size per direction. Inputs and outputs are both
            2 * hidden_size wide, matching the word-level encoder's output.
    """

    def __init__(self, hidden_size):
        super().__init__()
        dim = 2 * hidden_size
        # Keep the gru / fc / context attribute names: they are state_dict keys.
        self.gru = nn.GRU(dim, hidden_size, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(dim, dim)
        self.context = nn.Parameter(torch.randn(dim))

    def forward(self, x):
        """Pool x of shape (batch, num_sents, 2 * hidden_size) into (batch, 2 * hidden_size)."""
        h, _ = self.gru(x)
        alpha = F.softmax(torch.matmul(torch.tanh(self.fc(h)), self.context), dim=1)
        return torch.sum(h * alpha.unsqueeze(-1), dim=1)

class HAN(nn.Module):
    """Hierarchical Attention Network document classifier.

    Word ids -> embeddings -> ``WordAttention`` (one vector per sentence) ->
    ``SentenceAttention`` (one vector per document) -> linear layer -> logits.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_size: word embedding width.
        hidden_size: per-direction GRU hidden size (encoders emit 2 * hidden_size).
        num_classes: number of output logits.

    Token id 0 is the embedding ``padding_idx`` (its row receives no gradient).
    Attribute names are state_dict keys and must not change.
    """

    def __init__(self, vocab_size, embed_size=128, hidden_size=64, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.word_attn = WordAttention(embed_size, hidden_size)
        self.sen_attn = SentenceAttention(hidden_size)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):  # x: (B, S, W)
        """Return logits of shape (B, num_classes) for token ids x of shape (B, S, W).

        B = documents, S = sentences per document, W = words per sentence.
        """
        B, S, W = x.shape
        # Fold the sentence axis into the batch axis so all sentences of all
        # documents go through ONE word-level GRU call, instead of S sequential
        # calls from a Python loop. Row b*S + s of the flat batch is x[b, s].
        e = self.embedding(x.reshape(B * S, W))   # (B*S, W, E)
        s_vec = self.word_attn(e)                 # (B*S, H*2)
        s_mat = s_vec.reshape(B, S, -1)           # (B, S, H*2)
        doc_vec = self.sen_attn(s_mat)            # (B, H*2)
        return self.fc(doc_vec)                   # (B, C)

# ===== Flask app initialization =====
app = Flask(__name__)

# ===== Load the model and vocabulary =====
device = torch.device("cpu")
model_path = "han_model_cpu.pkl"

# The checkpoint is a pickled dict read below via the keys 'vocab', 'config'
# and 'model_state'.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only point model_path at a checkpoint you produced or otherwise trust.
with open(model_path, "rb") as checkpoint_file:
    saved = pickle.load(checkpoint_file)

vocab, config = saved["vocab"], saved["config"]

# Rebuild the network with the saved hyper-parameters so the state_dict shapes
# match (the embedding has len(vocab) rows).
model = HAN(
    vocab_size=len(vocab),
    embed_size=config["embed_size"],
    hidden_size=config["hidden_size"],
    num_classes=config["num_classes"],
)
model.load_state_dict(saved["model_state"])
model = model.to(device)
model.eval()  # switch to inference mode

# ===== Preprocessing =====
def encode_korean(text):
    """Encode raw text into a fixed-size ``torch.long`` tensor of word ids.

    The text is split into sentences with ``kss.split_sentences`` and only the
    first ``config['SENT_MAXLEN']`` sentences are kept. Each sentence is split on
    whitespace, truncated to ``config['WORD_MAXLEN']`` tokens, and mapped through
    ``vocab``, with unknown tokens mapped to 1. Each row is then right-padded with 0.
    Missing sentences are filled with all-zero rows, so the result has shape
    (SENT_MAXLEN, WORD_MAXLEN).
    """
    max_sents = config['SENT_MAXLEN']
    max_words = config['WORD_MAXLEN']

    def _encode_sentence(sentence):
        ids = [vocab.get(token, 1) for token in sentence.split()[:max_words]]
        return ids + [0] * (max_words - len(ids))

    rows = [_encode_sentence(s) for s in kss.split_sentences(str(text))[:max_sents]]
    # Pad the document with empty (all-padding) sentences.
    rows.extend([0] * max_words for _ in range(max_sents - len(rows)))
    return torch.tensor(rows, dtype=torch.long)

# ===== Prediction API =====
@app.route('/predict', methods=['POST'])
def predict():
    """Classify a document posted as JSON ``{"title": ..., "content": ...}``.

    Title and content are joined with a space, encoded with ``encode_korean``,
    and run through ``model``.

    Returns:
        200 with ``{"prediction": <class index>, "probabilities": {"class_<i>": <percent>}}``,
        with one ``class_<i>`` entry per model output, rounded to 2 decimals.
        400 if the body is not a JSON object or no non-blank text was supplied.
        500 with a JSON error body if preprocessing or inference fails.
    """
    # get_json() still lets Flask reject a non-JSON content type or malformed JSON.
    # A valid body such as `null` or `[...]` is not a dict, though, and would
    # otherwise crash on `.get` with a 500.
    data = request.get_json()
    if not isinstance(data, dict):
        return jsonify({'error': '요청 본문은 JSON 객체여야 합니다.'}), 400

    # Treat explicit JSON nulls as missing; otherwise f"{None}" would inject the
    # literal text "None" into the model input.
    title = data.get("title")
    content = data.get("content")
    if title is None:
        title = ""
    if content is None:
        content = ""

    full_text = f"{title} {content}".strip()
    # Checking after strip() also rejects whitespace-only input, which would
    # otherwise be classified as an all-padding document.
    if not full_text:
        return jsonify({'error': '입력 텍스트가 없습니다.'}), 400

    try:
        input_tensor = encode_korean(full_text).unsqueeze(0).to(device)  # batch of 1
        with torch.no_grad():
            logits = model(input_tensor)                       # (1, num_classes)
            probs = torch.softmax(logits, dim=1)[0].tolist()   # class probabilities
            pred = logits.argmax(dim=1)[0].item()
    except Exception:
        # Request boundary: log the full traceback, and answer with JSON instead
        # of Flask's default HTML 500 page.
        app.logger.exception("Prediction failed")
        return jsonify({'error': '예측 중 오류가 발생했습니다.'}), 500

    return jsonify({
        'prediction': pred,
        # One entry per class rather than a hard-coded class_0/class_1, so a
        # model with num_classes != 2 neither raises IndexError nor drops classes.
        'probabilities': {f'class_{i}': round(p * 100, 2) for i, p in enumerate(probs)},
    })
    
# ===== Root usage message =====
@app.route('/')
def index():
    """Return a plain-text hint telling clients to POST 'title'/'content' to /predict."""
    message = "HAN 모델 API입니다. POST /predict로 'title'과 'content'를 보내주세요."
    status = 200
    return message, status

# ===== Run the server behind an ngrok tunnel =====
if __name__ == '__main__':
    import os

    # Never hard-code the ngrok auth token in source or notebooks: anyone who can
    # read the file can open tunnels on your account. Supply it through the
    # environment (e.g. `set NGROK_AUTHTOKEN=...` on Windows), and revoke any
    # token that was ever committed or shared.
    AUTH_TOKEN = os.environ.get("NGROK_AUTHTOKEN")
    if AUTH_TOKEN:
        ngrok.set_auth_token(AUTH_TOKEN)
    else:
        print("NGROK_AUTHTOKEN is not set; relying on ngrok's existing configuration, if any.")

    PORT = 5000  # the tunnel and the Flask server must target the same port
    public_url = ngrok.connect(PORT)
    print("ngrok tunnel URL:", public_url)
    app.run(port=PORT)

Collecting flask
  Using cached flask-3.1.1-py3-none-any.whl.metadata (3.0 kB)
Collecting kss
  Using cached kss-6.0.4-py3-none-any.whl
Collecting pecab (from kss)
  Using cached pecab-1.0.8-py3-none-any.whl
Collecting cmudict (from kss)
  Using cached cmudict-1.0.32-py3-none-any.whl.metadata (3.6 kB)
Collecting bs4 (from kss)
  Using cached bs4-0.0.2-py2.py3-none-any.whl.metadata (411 bytes)
Using cached flask-3.1.1-py3-none-any.whl (103 kB)
Using cached bs4-0.0.2-py2.py3-none-any.whl (1.2 kB)
Using cached cmudict-1.0.32-py3-none-any.whl (939 kB)
Installing collected packages: pecab, flask, cmudict, bs4, kss
Successfully installed bs4-0.0.2 cmudict-1.0.32 flask-3.1.1 kss-6.0.4 pecab-1.0.8



[notice] A new release of pip is available: 24.0 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


ngrok tunnel URL: NgrokTunnel: "https://1989-218-237-232-179.ngrok-free.app" -> "http://localhost:5000"
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
[Kss]: [33mPress CTRL+C to quit[0m
[Kss]: Because there's no supported C++ morpheme analyzer, Kss will take pecab as a backend. :D
For your information, Kss also supports mecab backend.
We recommend you to install mecab or konlpy.tag.Mecab for faster execution of Kss.
Please refer to following web sites for details:
- mecab: https://cleancode-ws.tistory.com/97
- konlpy.tag.Mecab: https://uwgdqo.tistory.com/363

[Kss]: 127.0.0.1 - - [18/May/2025 01:36:02] "POST /predict HTTP/1.1" 200 -
[Kss]: 127.0.0.1 - - [18/May/2025 01:36:03] "POST /predict HTTP/1.1" 200 -
  from_pos_data.costs[idx]
[Kss]: Exception on /predict [POST]
Traceback (most recent call last):
  File "E:\vegetable-dragon\model\han_env\Lib\site-packages\flask\app.py", line 1511, in wsgi_app
    response = self.full_dispatch_request()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\vegetable-dragon\model\han_env\Lib\site-packages\flask\app.py", line 919, in full_dispatch_request
   