In [None]:
import os

from transformers import AutoModelForSequenceClassification, pipeline

from article_classifier.dataset import (
    create_prompt,
    id2label,
    label2id,
    load_arxiv_dataset,
)

# TODO: replace the local checkpoint with the published
# "hacker1337/article-classifier" (the stock base model would be
# "distilbert/distilbert-base-cased").
# A run-specific directory was tried earlier:
#   ~\cache\huggingface\checkpoints\distilbert-arxiv\runs\Jun21_00-41-18_amir-xp
# Current source: fine-tuned checkpoint stored under the user's home directory.
checkpoint_subdirs = (".cache", "huggingface", "checkpoints", "distilbert-arxiv")
model_path = os.path.join(os.path.expanduser("~"), *checkpoint_subdirs)

# Classification-head settings: the label maps come from
# article_classifier.dataset and the problem type is set to multi-label.
# NOTE(review): `model` is not referenced by any later cell visible in this
# notebook; the pipeline cell loads the checkpoint again from `model_path`.
head_options = {
    "num_labels": len(id2label),
    "id2label": id2label,
    "label2id": label2id,
    "problem_type": "multi_label_classification",
}
model = AutoModelForSequenceClassification.from_pretrained(model_path, **head_options)


In [None]:
# Display the resolved checkpoint directory (an absolute, user-specific path).
model_path

'C:\\Users\\amirf\\.cache\\huggingface\\checkpoints\\distilbert-arxiv'

In [None]:
# List the checkpoint directory in pure Python rather than shelling out to the
# Windows-only `dir` with an unquoted, interpolated path: this also works on
# other platforms and for paths that contain spaces.
sorted(os.listdir(model_path))

 Volume in drive C is OS
 Volume Serial Number is 58DB-6C24

 Directory of C:\Users\amirf\.cache\huggingface\checkpoints\distilbert-arxiv

22.06.2025  15:59    <DIR>          .
21.06.2025  00:35    <DIR>          ..
22.06.2025  15:59               810 config.json
22.06.2025  15:59       263 153 916 model.safetensors
21.06.2025  00:41    <DIR>          runs
22.06.2025  15:59               132 special_tokens_map.json
22.06.2025  15:59           668 923 tokenizer.json
22.06.2025  15:59             1 284 tokenizer_config.json
22.06.2025  15:59           213 450 vocab.txt
               6 File(s)    264 038 515 bytes
               3 Dir(s)  59 368 357 888 bytes free


In [None]:
# Build a text-classification pipeline straight from the checkpoint directory;
# the same path is used for both the model weights and the tokenizer.
# `pipeline` is imported in the notebook's first (imports) cell so that all
# dependencies are declared in one place.
classifier = pipeline(
    "text-classification",
    model=model_path,
    tokenizer=model_path,
)

Device set to use cuda:0


In [None]:
# Load the dataset through the project helper; `load_arxiv_dataset` is
# imported in the notebook's first (imports) cell.
dataset = load_arxiv_dataset()

In [None]:
# Show the dataset summary (feature names and number of rows).
dataset

Dataset({
    features: ['titles', 'summaries', 'terms'],
    num_rows: 51774
})

In [None]:
# Build the full prompt (title + abstract) for the first record and print it
# together with that record's ground-truth category terms.
first_record = dataset[0]
sample_prompt_full = create_prompt(first_record["titles"], first_record["summaries"])
classes = first_record["terms"]

print(sample_prompt_full)
print("Classes:", classes)

# title:
Survey on Semantic Stereo Matching / Semantic Depth Estimation
# abstract:
Stereo matching is one of the widely used techniques for inferring depth from
stereo images owing to its robustness and speed. It has become one of the major
topics of research since it finds its applications in autonomous driving,
robotic navigation, 3D reconstruction, and many other fields. Finding pixel
correspondences in non-textured, occluded and reflective areas is the major
challenge in stereo matching. Recent developments have shown that semantic cues
from image segmentation can be used to improve the results of stereo matching.
Many deep neural network architectures have been proposed to leverage the
advantages of semantic segmentation in stereo matching. This paper aims to give
a comparison among the state of art networks both in terms of accuracy and in
terms of speed which are of higher importance in real-time applications.
Classes: ['cs.CV', 'cs.LG']


In [None]:
# Score the full prompt. top_k=None requests scores for all labels instead of
# only the top one (the saved output below lists all five labels).
predictions = classifier(sample_prompt_full, top_k=None)
predictions

[{'label': 'CV', 'score': 0.9944138526916504},
 {'label': 'ML', 'score': 0.10464362800121307},
 {'label': 'AI', 'score': 0.04399743676185608},
 {'label': 'NE', 'score': 0.009200998581945896},
 {'label': 'CL', 'score': 0.00635348679497838}]

In [None]:
# Ablation prompts for the first record: one with the abstract blanked out,
# one with the title blanked out.
title_text = dataset[0]["titles"]
abstract_text = dataset[0]["summaries"]

only_title_prompt = create_prompt(title_text, "")
only_abstract_prompt = create_prompt("", abstract_text)

In [None]:
# Print the per-label scores for each ablated prompt, title-only first.
ablation_prompts = (
    ("Only title prompt predictions:", only_title_prompt),
    ("Only abstract prompt predictions:", only_abstract_prompt),
)
for heading, prompt_text in ablation_prompts:
    print(heading, classifier(prompt_text, top_k=None))

Only title prompt predictions: [{'label': 'CV', 'score': 0.8638411164283752}, {'label': 'ML', 'score': 0.37290239334106445}, {'label': 'AI', 'score': 0.1514018476009369}, {'label': 'NE', 'score': 0.06561543792486191}, {'label': 'CL', 'score': 0.046929094940423965}]
Only abstract prompt predictions: [{'label': 'CV', 'score': 0.9935577511787415}, {'label': 'ML', 'score': 0.1251823604106903}, {'label': 'AI', 'score': 0.04704591631889343}, {'label': 'NE', 'score': 0.009932062588632107}, {'label': 'CL', 'score': 0.006572176702320576}]
