# Setting Up the Environment

In [1]:
# Install the `transformers` package, upgrading it if already present (-U).
!pip install -U transformers



In [2]:
# Mount Google Drive at /content/drive so the project folder stored there
# can be reached through ordinary filesystem paths.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Project root on the mounted Drive; the .env file and dataset paths used
# later in the notebook are joined onto this folder.
FOLDER_PROJECT = "/content/drive/MyDrive/ScholarAI/"

In [4]:
import os
from dotenv import load_dotenv
# Load variables from the project's .env file into the process environment;
# later cells read the access tokens from it with os.getenv().
load_dotenv(dotenv_path=os.path.join(FOLDER_PROJECT, "server/predict/.env"))

True

In [5]:
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using the token stored in the
# TOKEN_HUGGINGFACE_HUB environment variable.
# NOTE(review): os.getenv() returns None when the variable is unset —
# confirm how login() should behave in that case.
login(token=os.getenv("TOKEN_HUGGINGFACE_HUB"))

In [6]:
# Path of the dataset file that is later read into the `df` DataFrame.
OUTPUT_DATASET_JSON = os.path.join(FOLDER_PROJECT, "data/output_dataset.json")
# Hugging Face Hub identifier of the classification model used throughout.
MODEL_NAME = "VTKK/bert-news-category-classification"

# Using the Model

In [7]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
import torch.nn.functional as F
import pandas as pd

In [8]:
# Build a text-classification pipeline for MODEL_NAME.
# NOTE(review): `pipe` is not referenced again in the visible notebook;
# load_model() loads the same checkpoint separately — confirm whether this
# cell is still needed.
pipe = pipeline("text-classification", model=MODEL_NAME)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Device set to use cpu


In [9]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The helper functions below read this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cpu


In [10]:
def load_model(model_name: str):
    """Load the tokenizer and sequence-classification model for ``model_name``.

    The model is configured to return hidden states, moved to the
    module-level ``device`` and switched to evaluation mode.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    classifier = AutoModelForSequenceClassification.from_pretrained(
        model_name, output_hidden_states=True
    ).to(device)
    # Inference only: put the module in evaluation mode.
    classifier.eval()
    return tok, classifier

In [11]:
def get_embedding(text: str, tokenizer, model):
    """Embed ``text`` as the encoder output at the first token position.

    The text is tokenized (with truncation), moved to the module-level
    ``device`` and passed through ``model.bert`` without gradient tracking,
    so ``model`` must expose a ``bert`` sub-module.

    Args:
        text: Input text to embed.
        tokenizer: Tokenizer callable producing PyTorch tensors.
        model: Model exposing a ``bert`` encoder.

    Returns:
        A NumPy array holding ``last_hidden_state[0, 0]`` (first sequence,
        first token position).
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    encoded = encoded.to(device)
    with torch.no_grad():
        encoder_out = model.bert(**encoded, output_hidden_states=True)
    first_token_vec = encoder_out.last_hidden_state[0, 0]
    return first_token_vec.cpu().numpy()

In [12]:
def cosine_similarity(a, b):
    """Return the cosine similarity of two 1-D vectors as a Python float.

    Args:
        a: First vector (anything ``torch.tensor`` accepts, e.g. a list or
            NumPy array).
        b: Second vector, same length as ``a``.

    Returns:
        The cosine similarity between ``a`` and ``b``.
    """
    # Add a leading batch dimension so the comparison runs along dim=1.
    row_a = torch.tensor(a).unsqueeze(0)
    row_b = torch.tensor(b).unsqueeze(0)
    similarity = F.cosine_similarity(row_a, row_b)
    return similarity.item()


In [13]:
def sort_by_similarity(query: str, copus: list[str], tokenizer, model):
    """Rank the sentences of ``copus`` by similarity to ``query``.

    Each sentence is embedded with ``get_embedding`` and compared to the
    query embedding with ``cosine_similarity``.

    Args:
        query: Text to compare against.
        copus: Iterable of candidate sentences.
        tokenizer: Tokenizer forwarded to ``get_embedding``.
        model: Model forwarded to ``get_embedding``.

    Returns:
        A list of ``(position, score)`` tuples sorted by descending score,
        where ``position`` is the sentence's 0-based position in iteration
        order over ``copus`` (not any index label it may carry).
    """
    query_vec = get_embedding(query, tokenizer, model)

    scored = [
        (position, cosine_similarity(query_vec, get_embedding(sentence, tokenizer, model)))
        for position, sentence in enumerate(copus)
    ]

    return sorted(scored, key=lambda pair: pair[1], reverse=True)

In [14]:
def predict_label(text: str, max_length: int = 512, top_k: int = 1):
    """Classify ``text`` and return its ``top_k`` most probable labels.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: Input text to classify.
        max_length: Maximum token length passed to the tokenizer; longer
            inputs are truncated.
        top_k: Number of labels to return.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts ordered from
        highest to lowest softmax probability.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        padding=True,
    ).to(device)

    with torch.no_grad():
        logits = model(**encoded).logits
        # Softmax over the class dimension; [0] picks the single input's row.
        probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]

    label_names = model.config.id2label
    # argsort is ascending, so reverse it before taking the first top_k.
    ranked = probabilities.argsort()[::-1][:top_k]

    predictions = []
    for class_idx in ranked:
        class_idx = int(class_idx)
        predictions.append(
            {"label": label_names[class_idx], "score": float(probabilities[class_idx])}
        )
    return predictions

In [15]:
# Module-level tokenizer/model pair; predict_label() and find_similarity()
# read these globals.
tokenizer, model = load_model(MODEL_NAME)

In [16]:
# Read the dataset as line-delimited JSON (one record per line) into the
# module-level `df` used by find_similarity(), then preview the first rows.
df = pd.read_json(OUTPUT_DATASET_JSON, lines=True)
df.head()

Unnamed: 0,link,headline,category,short_description,authors,date,text
0,https://www.huffpost.com/entry/covid-boosters-...,Over 4 Million Americans Roll Up Sleeves For O...,WORLD NEWS,Health experts said it is too early to predict...,"Carla K. Johnson, AP",2022-09-23,Over 4 Million Americans Roll Up Sleeves For O...
1,https://www.huffpost.com/entry/american-airlin...,"American Airlines Flyer Charged, Banned For Li...",WORLD NEWS,He was subdued by passengers and crew when he ...,Mary Papenfuss,2022-09-23,"American Airlines Flyer Charged, Banned For Li..."
2,https://www.huffpost.com/entry/funniest-tweets...,23 Of The Funniest Tweets About Cats And Dogs ...,COMEDY,"""Until you have a dog you don't understand wha...",Elyse Wanshel,2022-09-23,23 Of The Funniest Tweets About Cats And Dogs ...
3,https://www.huffpost.com/entry/funniest-parent...,The Funniest Tweets From Parents This Week (Se...,PARENTING,"""Accidentally put grown-up toothpaste on my to...",Caroline Bologna,2022-09-23,The Funniest Tweets From Parents This Week (Se...
4,https://www.huffpost.com/entry/amy-cooper-lose...,Woman Who Called Cops On Black Bird-Watcher Lo...,WORLD NEWS,Amy Cooper accused investment firm Franklin Te...,Nina Golgowski,2022-09-22,Woman Who Called Cops On Black Bird-Watcher Lo...


In [17]:
def find_similarity(query, limit=10):
    """Return articles similar to ``query`` from its predicted category.

    The query is classified with ``predict_label``; the first ``limit``
    rows of ``df`` in the predicted category are then scored against the
    query with ``sort_by_similarity`` on their ``short_description``.
    Only those first ``limit`` rows are scored, not the whole category.

    Args:
        query: Text to search for.
        limit: Maximum number of category rows to score and return.

    Returns:
        A list of dicts (``id``, ``score``, ``link``, ``headline``,
        ``short_description``, ``date``, ``category``, ``authors``) sorted
        by descending similarity. ``id`` is the row's index label in ``df``.
        An empty list is returned when no label is predicted.
    """
    predictions = predict_label(query, top_k=1)
    if not predictions:
        return []
    label = predictions[0]["label"]

    # Candidate descriptions keep their original df index labels.
    candidates = df.loc[df["category"] == label, "short_description"].iloc[:limit]
    # tolist() yields native Python values, which stay JSON-serialisable.
    row_labels = candidates.index.tolist()

    # sort_by_similarity reports positions within the list it is given, so
    # pass a plain list and translate positions back to df row labels below.
    # Using the positions directly as df labels would fetch unrelated rows.
    ranked = sort_by_similarity(query, candidates.tolist(), tokenizer, model)

    res = []
    for position, score in ranked:
        row_id = row_labels[position]
        res.append(
            {
                "id": row_id,
                "score": score,
                "link": df.at[row_id, "link"],
                "headline": df.at[row_id, "headline"],
                "short_description": df.at[row_id, "short_description"],
                "date": df.at[row_id, "date"],
                "category": df.at[row_id, "category"],
                "authors": df.at[row_id, "authors"],
            }
        )

    return res


# APIs

In [18]:
# Install the tunnelling helpers used to expose the Flask app publicly.
# NOTE(review): only `pyngrok` and `flask` are imported in the visible
# notebook; `flask-ngrok` does not appear to be used — confirm it is needed.
!pip install pyngrok
!pip install flask-ngrok



In [19]:
from pyngrok import ngrok
from flask import Flask, request, jsonify

In [20]:
# Register the ngrok auth token read from the TOKEN_NGROK environment variable.
# NOTE(review): os.getenv() returns None when the variable is unset — confirm
# the .env file defines it.
ngrok.set_auth_token(os.getenv("TOKEN_NGROK"))

In [21]:
app = Flask(__name__)

@app.route("/", methods=["POST"])
def predict():
    """Handle ``POST /``: return articles similar to the posted query.

    Expects a JSON object body with optional keys ``query`` (string,
    default ``""``) and ``limit`` (non-negative integer, default ``10``).

    Returns:
        A JSON list produced by ``find_similarity`` on success, or a JSON
        ``{"error": ...}`` object with HTTP status 400 when the body is not
        a JSON object or a field has an invalid value.
    """
    # silent=True gives None instead of raising on a bad/missing JSON body,
    # so malformed requests get a clear 400 rather than an unhandled error.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    query = payload.get("query", "")
    if not isinstance(query, str):
        return jsonify({"error": "'query' must be a string."}), 400

    try:
        limit = int(payload.get("limit", 10))
    except (TypeError, ValueError):
        return jsonify({"error": "'limit' must be an integer."}), 400
    # A negative limit would silently turn into an "all but the last N" slice.
    if limit < 0:
        return jsonify({"error": "'limit' must not be negative."}), 400

    return jsonify(find_similarity(query, limit))

In [22]:
# One constant for the port so the ngrok tunnel target and the Flask server
# cannot drift apart.
SERVER_PORT = 5000

# Open the public tunnel and print its URL before starting the server;
# app.run() blocks until the server is stopped.
public_url = ngrok.connect(SERVER_PORT)
print(public_url)
app.run(port=SERVER_PORT)

NgrokTunnel: "https://12cf3f6789d2.ngrok-free.app" -> "http://localhost:5000"
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m
INFO:werkzeug:127.0.0.1 - - [11/Dec/2025 15:29:38] "POST / HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [11/Dec/2025 15:30:10] "POST / HTTP/1.1" 200 -
