In [3]:
from flask import Flask, request, jsonify
from pyngrok import ngrok
import torch
import torch.nn as nn
import torch.nn.functional as F
import kss
import pickle

# ===== HAN 모델 클래스 정의 =====
class WordAttention(nn.Module):
    """Bidirectional GRU encoder followed by attention pooling over time steps.

    Args:
        embed_size: Feature size of each input step fed to the GRU.
        hidden_size: GRU hidden size per direction; pooled vectors have
            ``2 * hidden_size`` features.
    """

    def __init__(self, embed_size, hidden_size):
        super().__init__()
        feat = hidden_size * 2  # forward and backward GRU states concatenated
        self.gru = nn.GRU(
            input_size=embed_size,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        self.fc = nn.Linear(feat, feat)
        # Learned context vector used to score every time step.
        self.context = nn.Parameter(torch.randn(feat))

    def forward(self, x):
        """Pool a batched ``(batch, steps, embed_size)`` input.

        Returns:
            Tensor of shape ``(batch, 2 * hidden_size)``: the attention-weighted
            sum of the GRU outputs along the step axis.
        """
        states, _ = self.gru(x)
        # Score each step against the context vector, then normalise per sequence.
        scores = self.fc(states).tanh() @ self.context
        weights = scores.softmax(dim=1)[..., None]
        return (states * weights).sum(dim=1)

class SentenceAttention(nn.Module):
    """Attention pooling over a sequence of sentence vectors.

    A bidirectional GRU reads vectors of size ``2 * hidden_size``; its outputs
    are scored against a learned context vector and summed with the resulting
    softmax weights.

    Args:
        hidden_size: GRU hidden size per direction. Both the expected input
            feature size and the output feature size are ``2 * hidden_size``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        width = hidden_size * 2
        self.gru = nn.GRU(
            input_size=width,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        self.fc = nn.Linear(width, width)
        # Learned context vector that ranks the sequence positions.
        self.context = nn.Parameter(torch.randn(width))

    def forward(self, x):
        """Reduce a batched ``(batch, steps, 2 * hidden_size)`` input to ``(batch, 2 * hidden_size)``."""
        encoded, _ = self.gru(x)
        projected = torch.tanh(self.fc(encoded))
        # One scalar score per position, normalised along the step axis.
        alpha = torch.softmax(projected @ self.context, dim=1)
        weighted = encoded * alpha.unsqueeze(-1)
        return weighted.sum(dim=1)

class HAN(nn.Module):
    """Hierarchical attention classifier: words -> sentence vectors -> document vector -> logits.

    Args:
        vocab_size: Number of rows in the embedding table (index 0 is padding).
        embed_size: Embedding dimension.
        hidden_size: GRU hidden size per direction in both attention layers.
        num_classes: Number of output logits.
    """

    def __init__(self, vocab_size, embed_size=128, hidden_size=64, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.word_attn = WordAttention(embed_size, hidden_size)
        self.sen_attn = SentenceAttention(hidden_size)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):  # x: (B, S, W)
        """Compute class logits for a batch of documents.

        Args:
            x: Integer tensor of token indices with shape
                ``(batch, sentences, words)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        B, S, W = x.shape
        # Every sentence is encoded independently, so the sentence axis can be
        # folded into the batch axis and all sentences processed in one call
        # instead of one embedding/GRU call per sentence. Row ``b * S + s`` of
        # the flattened batch corresponds to sentence ``s`` of document ``b``.
        flat_tokens = x.reshape(B * S, W)             # (B*S, W)
        embedded = self.embedding(flat_tokens)        # (B*S, W, E)
        sent_vecs = self.word_attn(embedded)          # (B*S, H*2)
        s_mat = sent_vecs.reshape(B, S, sent_vecs.shape[-1])  # (B, S, H*2)
        doc_vec = self.sen_attn(s_mat)                # (B, H*2)
        return self.fc(doc_vec)                       # (B, C)

# ===== Flask app initialisation =====
app = Flask(__name__)

# ===== Load model and vocabulary =====
device = torch.device("cpu")
model_path = "han_model_cpu.pkl"

# The checkpoint is a pickled dict; the keys read below are 'vocab', 'config'
# and 'model_state'.
# NOTE(review): pickle.load can execute arbitrary code from the file, so only
# load checkpoints from a trusted source.
with open(model_path, 'rb') as f:
    saved = pickle.load(f)

vocab, config = saved['vocab'], saved['config']

# Rebuild the network from the stored hyper-parameters, restore the trained
# weights, and switch to inference mode.
model = HAN(
    len(vocab),
    **{key: config[key] for key in ('embed_size', 'hidden_size', 'num_classes')},
)
model.load_state_dict(saved['model_state'])
model = model.to(device)
model.eval()

# ===== 전처리 함수 =====
def encode_korean(text):
    """Convert raw text into a fixed-size matrix of vocabulary indices.

    The text is split into sentences with ``kss`` and each sentence into
    whitespace-separated tokens. At most ``config['SENT_MAXLEN']`` sentences
    and ``config['WORD_MAXLEN']`` tokens per sentence are kept. Tokens missing
    from ``vocab`` map to index 1 and unused slots are filled with 0.

    Args:
        text: Input document; it is passed through ``str()`` before splitting.

    Returns:
        A ``torch.long`` tensor built from ``SENT_MAXLEN`` rows of
        ``WORD_MAXLEN`` indices each.
    """
    max_sents = config['SENT_MAXLEN']
    max_words = config['WORD_MAXLEN']
    pad_row = [0] * max_words

    rows = []
    for sentence in kss.split_sentences(str(text))[:max_sents]:
        ids = [vocab.get(token, 1) for token in sentence.split()[:max_words]]
        # Right-pad the sentence with zeros up to the fixed word count.
        rows.append(ids + pad_row[len(ids):])

    # Pad the document with all-zero sentences up to the fixed sentence count.
    rows.extend(list(pad_row) for _ in range(max_sents - len(rows)))
    return torch.tensor(rows, dtype=torch.long)

# ===== 예측 API =====
@app.route('/predict', methods=['POST'])
def predict():
    """Classify a document sent as JSON and return the predicted class index.

    The request body must be a JSON object. Its optional ``"title"`` and
    ``"content"`` fields are joined with a space and fed to the model.

    Returns:
        A JSON response ``{"prediction": <int>}`` on success, or
        ``{"error": <message>}`` with status 400 when the body is not a JSON
        object or contains no text.
    """
    # silent=True yields None instead of raising on a missing or malformed JSON
    # body, so bad requests get a JSON 400 rather than an unhandled exception.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': '요청 본문은 JSON 객체여야 합니다.'}), 400

    # A JSON null is treated like a missing field rather than the text "None".
    title = data.get("title")
    content = data.get("content")
    title = "" if title is None else str(title)
    content = "" if content is None else str(content)

    # Join title and content with a space; whitespace-only input counts as empty
    # so the model is never run on an all-padding tensor.
    full_text = f"{title} {content}".strip()
    if not full_text:
        return jsonify({'error': '입력 텍스트가 없습니다.'}), 400

    input_tensor = encode_korean(full_text).unsqueeze(0).to(device)  # (1, S, W)
    with torch.no_grad():
        output = model(input_tensor)
    pred = output.argmax(dim=1).item()

    return jsonify({'prediction': pred})

# ===== 루트 안내 메시지 =====
@app.route('/')
def index():
    """Return a short plain-text usage hint for the API root with status 200."""
    usage = "HAN 모델 API입니다. POST /predict로 'title'과 'content'를 보내주세요."
    return usage, 200

# ===== 서버 실행 및 ngrok 연결 =====
if __name__ == '__main__':
    import os

    # Single source of truth for the port shared by the tunnel and Flask.
    port = 5000

    # Credentials must not live in source control, so the ngrok token is read
    # from the NGROK_AUTHTOKEN environment variable. The token that used to be
    # hard-coded here has been exposed and should be revoked in the ngrok
    # dashboard.
    auth_token = os.environ.get("NGROK_AUTHTOKEN")
    if auth_token:
        ngrok.set_auth_token(auth_token)
    # NOTE(review): when the variable is unset, ngrok.connect relies on any
    # token already stored in the local ngrok configuration — confirm this is
    # the desired fallback.

    public_url = ngrok.connect(port)
    print("ngrok tunnel URL:", public_url)
    app.run(port=port)

ngrok tunnel URL: NgrokTunnel: "https://c8bc-166-104-245-67.ngrok-free.app" -> "http://localhost:5000"
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
[Kss]: Press CTRL+C to quit
  from_pos_data.costs[idx]
[Kss]: 127.0.0.1 - - [15/May/2025 19:33:47] "POST /predict HTTP/1.1" 200 -
