In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [4]:
MODEL_PATH = "../models/distilbert_finetuned_agnews"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
model.eval()

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [5]:
label_names = ["World", "Sports", "Business", "Sci/Tech"]

In [6]:
def classify_text(text: str):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1)
        pred = probs.argmax(dim=1).item()

    return {
        "predicted_label": label_names[pred],
        "confidence": float(probs[0, pred])
    }

In [7]:
example_text = """
Apple announced a new generation of processors designed
for artificial intelligence workloads and data centers.
"""

result = classify_text(example_text)
result

{'predicted_label': 'Sci/Tech', 'confidence': 0.9915846586227417}

In [8]:
examples = [
    "The stock market reacted positively to the new interest rate decision.",
    "The football team won the championship after a dramatic final match.",
    "Scientists discovered a new particle that could change physics.",
    "The president met with foreign leaders to discuss global security."
]

for text in examples:
    res = classify_text(text)
    print(f"TEXT: {text}")
    print(f" - PREDICTION: {res['predicted_label']} (confidence={res['confidence']:.2f})\n")

TEXT: The stock market reacted positively to the new interest rate decision.
 - PREDICTION: Business (confidence=0.99)

TEXT: The football team won the championship after a dramatic final match.
 - PREDICTION: World (confidence=0.87)

TEXT: Scientists discovered a new particle that could change physics.
 - PREDICTION: Sci/Tech (confidence=0.87)

TEXT: The president met with foreign leaders to discuss global security.
 - PREDICTION: World (confidence=1.00)

