In [6]:
import gradio as gr
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
from sklearn.feature_extraction.text import TfidfVectorizer

# === Model Class for TF-IDF model (required to load)
class SpamClassifier(nn.Module):
    """Linear classifier mapping a feature vector to two class logits.

    Kept in this file because the saved TF-IDF model (``spam_model.pt``) is
    loaded as a whole pickled module and needs this class to unpickle.
    """

    def __init__(self, input_dim):
        super(SpamClassifier, self).__init__()
        # The submodule must stay named ``fc``: the pickled model stores its
        # layer under this attribute and forward() looks it up by that name.
        self.fc = nn.Linear(in_features=input_dim, out_features=2)

    def forward(self, x):
        """Return raw logits: input (..., input_dim) -> output (..., 2)."""
        logits = self.fc(x)
        return logits

# === Load TF-IDF vectorizer and model
# NOTE(review): pickle.load and torch.load(weights_only=False) can execute
# arbitrary code stored in the file -- only load artifacts you produced yourself.
with open("vectorizer.pkl", "rb") as fh:  # `with` closes the handle promptly
    vectorizer = pickle.load(fh)
# weights_only=False because the file holds a whole pickled module (not just a
# state_dict); unpickling it needs the SpamClassifier class defined above.
tfidf_model = torch.load("spam_model.pt", map_location=torch.device("cpu"), weights_only=False)
tfidf_model.eval()  # inference mode

# === (Optional) Load LSTM model and word2idx
# Best-effort: if either artifact is missing or cannot be unpickled, the app
# keeps running with the TF-IDF model only and predict_lstm() reports that the
# LSTM model is unavailable.
word2idx = None
lstm_model = None
try:
    with open("word2idx.pkl", "rb") as fh:
        word2idx = pickle.load(fh)
    # weights_only=False mirrors the TF-IDF load. Without it, PyTorch >= 2.6
    # defaults to weights_only=True, a fully pickled module fails to load, and
    # the LSTM path would be silently disabled.
    # NOTE(review): unpickling a whole module needs its class to be importable
    # here; only SpamClassifier is defined in this file -- TODO confirm the
    # LSTM model class is available at load time.
    lstm_model = torch.load("lstm_model.pt", map_location=torch.device("cpu"), weights_only=False)
    lstm_model.eval()
except Exception as err:  # optional-feature boundary: report, don't crash the app
    print(f"⚠️ LSTM model not loaded ({type(err).__name__}: {err})")
    word2idx = None
    lstm_model = None
    has_lstm = False
else:
    has_lstm = True

# === TF-IDF Prediction
def predict_tfidf(text):
    """Classify *text* with the TF-IDF vectorizer and linear model.

    Returns "SPAM ❌" when class index 1 has the highest logit, else "HAM ✅".
    """
    dense = vectorizer.transform([text]).toarray()  # one row: 1 x n_features
    features = torch.tensor(dense, dtype=torch.float32)
    with torch.no_grad():
        logits = tfidf_model(features)
        label = logits.max(dim=1).indices.item()
    if label == 1:
        return "SPAM ❌"
    return "HAM ✅"

# === LSTM Prediction
def predict_lstm(text, max_len=30):
    """Classify *text* with the optional LSTM model.

    The text is lower-cased and split on whitespace. Each word is mapped to an
    id through the module-level ``word2idx``, with unknown words getting id 1
    (``<UNK>``). The ids are then truncated or right-padded with 0 to exactly
    ``max_len``.

    Args:
        text: Raw message string.
        max_len: Sequence length fed to the model. Defaults to 30, the length
            this app has always used; presumably it must match the length the
            model was trained with -- confirm before changing.

    Returns:
        "SPAM ❌" if the model's top class is 1, "HAM ✅" otherwise, or a
        warning string when the LSTM model is not loaded.
    """
    if not has_lstm:
        return "⚠️ LSTM model not available."

    ids = [word2idx.get(w, 1) for w in text.lower().split()][:max_len]  # 1 = <UNK>
    # Right-pad with id 0 (assumed <PAD> -- TODO confirm against training vocab).
    ids += [0] * (max_len - len(ids))
    # Explicit integer dtype: torch.tensor([]) would infer float32 for empty or
    # whitespace-only input, and the model presumably expects integer token ids.
    padded = torch.tensor(ids, dtype=torch.long)

    with torch.no_grad():
        output = lstm_model(padded.unsqueeze(0))  # add batch dimension -> (1, max_len)
        _, pred = torch.max(output, 1)

    return "SPAM ❌" if pred.item() == 1 else "HAM ✅"

# === Gradio Interface
def classify(text, model_type):
    """Route *text* to the classifier chosen in the UI.

    "TF-IDF" selects the TF-IDF model; any other value goes to the LSTM path.
    """
    predictor = predict_tfidf if model_type == "TF-IDF" else predict_lstm
    return predictor(text)

# UI text is in Hausa: "Sakonka" ~ "your message", "Zaɓi Model" ~ "choose model";
# the description says the app uses TF-IDF or LSTM to detect spam messages.
# NOTE(review): the description spells "saƙon" with ƙ; the label may be meant
# as "Saƙonka" -- confirm with a Hausa speaker.
iface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Textbox(label="Sakonka"),
        # Default to "TF-IDF": with no value the radio starts unselected, and
        # classify() would receive None and silently fall through to the LSTM
        # branch (which may not even be loaded).
        gr.Radio(["TF-IDF", "LSTM"], value="TF-IDF", label="Zaɓi Model"),
    ],
    outputs="text",
    title="Spam Message Detector",
    description="Wannan app yana amfani da TF-IDF ko LSTM domin gano saƙon spam.",
)

iface.launch()


* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


