In [3]:
!pip install datasets
!pip install transformers
!pip install accelerate
!pip install fuzzywuzzy
!pip install python-Levenshtein
!pip install sentence_transformers
!pip install einops



In [4]:
from typing import Dict, List, Union

In [5]:
from datasets import load_dataset
from fuzzywuzzy import fuzz, process
from sklearn.metrics import classification_report
from tqdm.notebook import tqdm
from transformers import pipeline

from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine

In [6]:
def prepare_message_for_llm(text: Union[str, List[str]], categories: List[str]) -> Dict[str, Union[List[Dict[str, str]], List[List[Dict[str, str]]]]]:
    if len(categories) < 2:
        raise RuntimeError(f'The category list is too small! Expected 2 or more categories, got {len(categories)} ones.')
    categories_as_string = ', '.join(categories[:-1]) + ' и ' + categories[-1]
    if isinstance(text, str):
        prompt = (f"""Прочитайте текст ниже и определите одну основную категорию из предложенного списка.
            Выберите ту категорию, которая наилучшим образом отражает содержание всего текста.
            Не добавляйте пояснений и не выходите за рамки предложенных тем — укажите только название одной темы.
            Список тем: {categories_as_string}
            Убедитесь, что выбираете категорию только из списка.
            
            **Обязательно**:
            1. Будьте конкретны и ясны.
            2. Поймите текст, после чего выберите одну из категрий, указанных выше.
            
            **Не допускается**:
            1. Не делайте предположений и не придумывайте факты.
            2. Не создавайте новые категории. Используйте только указанные выше.
            
            Текст: {" ".join(text.split())}
            Ваш ответ:
            """
        )

        messages = [
            {
                'role': 'system',
                'content': (
                    'Вы — эксперт по анализу текста. Ваша задача — с предельной точностью определить основную тему, '
                    'которая наилучшим образом охватывает содержание всего текста. Выберите одну тему из списка, '
                    'опираясь на глубокое понимание и точный анализ текста. Дайте исключительно название темы, без дополнительных слов.'
                )
            },
            {
                'role': 'user',
                'content': prompt
            }
        ]

    else:
        messages = []
        for it in text:
            prompt = (f"""Прочитайте текст ниже и определите одну основную категорию из предложенного списка.
                Выберите ту категорию, которая наилучшим образом отражает содержание всего текста.
                Не добавляйте пояснений и не выходите за рамки предложенных тем — укажите только название одной темы.
                Список тем: {categories_as_string}
                Убедитесь, что выбираете категорию только из списка.

                **Обязательно**:
                1. Будьте конкретны и ясны.
                2. Поймите текст, после чего выберите одну из категрий, указанных выше.

                **Не допускается**:
                1. Не делайте предположений и не придумывайте факты.
                2. Не создавайте новые категории. Используйте только указанные выше.

                Текст: {" ".join(text.split())}
                Ваш ответ:
                """
            )

            messages.append([
                {
                    'role': 'system',
                    'content': (
                        'Вы — эксперт по анализу текста. Ваша задача — с предельной точностью определить основную тему, '
                        'которая наилучшим образом охватывает содержание всего текста. Выберите одну тему из списка, '
                        'опираясь на глубокое понимание и точный анализ текста. Дайте исключительно название темы, без дополнительных слов.'
                    )
                },
                {
                    'role': 'user',
                    'content': prompt
                }
            ])

    return {'message_for_llm': messages}

In [7]:
llm_pipeline = pipeline(model='Qwen/Qwen2-7B-Instruct', device_map='auto', torch_dtype='auto')

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/27.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

In [8]:
DATASET_NAME = 'Davlan/sib200'
DATASET_LANGUAGE = 'rus_Cyrl'
train_set = load_dataset(DATASET_NAME, DATASET_LANGUAGE, split='train')
validation_set = load_dataset(DATASET_NAME, DATASET_LANGUAGE, split='validation')
test_set = load_dataset(DATASET_NAME, DATASET_LANGUAGE, split='test')

README.md:   0%|          | 0.00/47.9k [00:00<?, ?B/s]

data/rus_Cyrl/train.tsv:   0%|          | 0.00/195k [00:00<?, ?B/s]

data/rus_Cyrl/dev.tsv:   0%|          | 0.00/25.3k [00:00<?, ?B/s]

data/rus_Cyrl/test.tsv:   0%|          | 0.00/57.4k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [9]:
list_of_categories = sorted(list(
    set(train_set['category']) | set(validation_set['category']) | set(test_set['category'])
))
print(f'Categories for classification are: {list_of_categories}')

Categories for classification are: ['entertainment', 'geography', 'health', 'politics', 'science/technology', 'sports', 'travel']


In [10]:
validation_set_for_llm = validation_set.map(lambda it: prepare_message_for_llm(it['text'], list_of_categories))
test_set_for_llm = test_set.map(lambda it: prepare_message_for_llm(it['text'], list_of_categories))

Map:   0%|          | 0/99 [00:00<?, ? examples/s]

Map:   0%|          | 0/204 [00:00<?, ? examples/s]

In [11]:
class PostProcessor:
    def __init__(self, model_name='cde-small'):
        self.model_name = model_name
        if self.model_name == 'cde-small':
            self.model = SentenceTransformer("jxm/cde-small-v1", trust_remote_code=True)

    def process(self, true_cats, pred_cats):
        embs_dict = {cat : self.model.encode(cat) for cat in list_of_categories}
        preds = [x[-1]['content'] for x in y_pred]
        
        pred_processed = []
        for pred in pred_cats:
            res_cat = pred
            if pred not in embs_dict.keys():
                emb = self.model.encode(pred)
                max_dist = 0
                for cat in list_of_categories:
                    if cosine(emb, embs_dict[cat]) > max_dist:
                        max_dist = cosine(emb, embs_dict[cat])
                        res_cat = cat
            pred_processed.append(res_cat)

        return pred_processed

In [11]:
y_pred = list(map(
    lambda x: llm_pipeline(x, max_new_tokens=10)[0]['generated_text'],
    tqdm(validation_set_for_llm['message_for_llm'])
))
y_true = validation_set['category']

  0%|          | 0/99 [00:00<?, ?it/s]

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


In [12]:
print(classification_report(y_true=y_true, y_pred=[x[-1]['content'] for x in y_pred]))

                    precision    recall  f1-score   support

           culture       0.00      0.00      0.00         0
     entertainment       0.67      0.44      0.53         9
         geography       0.78      0.88      0.82         8
            health       0.88      0.64      0.74        11
           history       0.00      0.00      0.00         0
  media-production       0.00      0.00      0.00         0
             music       0.00      0.00      0.00         0
          politics       0.87      0.93      0.90        14
science/technology       0.85      0.92      0.88        25
            sports       0.92      1.00      0.96        12
         transport       0.00      0.00      0.00         0
            travel       0.80      0.60      0.69        20

          accuracy                           0.79        99
         macro avg       0.48      0.45      0.46        99
      weighted avg       0.83      0.79      0.80        99



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [13]:
y_pred_with_normalization = list(map(
    lambda it: process.extractOne(it[-1]['content'], list_of_categories, scorer=fuzz.token_sort_ratio)[0],
    y_pred
))

In [14]:
print(classification_report(y_true=y_true, y_pred=y_pred_with_normalization))

                    precision    recall  f1-score   support

     entertainment       0.57      0.44      0.50         9
         geography       0.78      0.88      0.82         8
            health       0.88      0.64      0.74        11
          politics       0.81      0.93      0.87        14
science/technology       0.85      0.92      0.88        25
            sports       0.75      1.00      0.86        12
            travel       0.75      0.60      0.67        20

          accuracy                           0.79        99
         macro avg       0.77      0.77      0.76        99
      weighted avg       0.78      0.79      0.78        99



In [15]:
pp = PostProcessor(model_name='cde-small')
y_pred_proc = pp.process(list_of_categories, [x[-1]['content'] for x in y_pred])
print(classification_report(y_true=y_true, y_pred=y_pred_proc))

modules.json:   0%|          | 0.00/149 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/276k [00:00<?, ?B/s]

sentence_transformers_impl.py:   0%|          | 0.00/6.10k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jxm/cde-small-v1:
- sentence_transformers_impl.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


sentence_bert_config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/916 [00:00<?, ?B/s]

misc.py:   0%|          | 0.00/720 [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jxm/cde-small-v1:
- misc.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.py:   0%|          | 0.00/40.7k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jxm/cde-small-v1:
- model.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

configuration_hf_nomic_bert.py:   0%|          | 0.00/1.96k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/nomic-ai/nomic-bert-2048:
- configuration_hf_nomic_bert.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_hf_nomic_bert.py:   0%|          | 0.00/85.7k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/nomic-ai/nomic-bert-2048:
- modeling_hf_nomic_bert.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


pytorch_model.bin:   0%|          | 0.00/549M [00:00<?, ?B/s]

  state_dict = loader(resolved_archive_file)


tokenizer_config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Disabled 37 dropout modules from model type <class 'transformers_modules.jxm.cde-small-v1.9e2ed1d8d569d34458913d2d246935c1b2324d11.model.BiEncoder'>
modified 12 rotary modules – set rotary_start_pos to 512
Disabled 74 dropout modules from model type <class 'transformers_modules.jxm.cde-small-v1.9e2ed1d8d569d34458913d2d246935c1b2324d11.model.DatasetTransformer'>


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

                    precision    recall  f1-score   support

     entertainment       0.67      0.44      0.53         9
         geography       0.78      0.88      0.82         8
            health       0.78      0.64      0.70        11
          politics       0.87      0.93      0.90        14
science/technology       0.82      0.92      0.87        25
            sports       0.92      1.00      0.96        12
            travel       0.74      0.70      0.72        20

          accuracy                           0.81        99
         macro avg       0.80      0.79      0.79        99
      weighted avg       0.80      0.81      0.80        99



In [12]:
y_pred = list(map(
    lambda x: llm_pipeline(x, max_new_tokens=10)[0]['generated_text'],
    tqdm(test_set_for_llm['message_for_llm'])
))
y_true = test_set['category']

  0%|          | 0/204 [00:00<?, ?it/s]

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


In [13]:
print(classification_report(y_true=y_true, y_pred=[x[-1]['content'] for x in y_pred]))

                    precision    recall  f1-score   support

           animals       0.00      0.00      0.00         0
         astronomy       0.00      0.00      0.00         0
    communications       0.00      0.00      0.00         0
           culture       0.00      0.00      0.00         0
     entertainment       0.92      0.58      0.71        19
         geography       0.88      0.82      0.85        17
            health       0.90      0.86      0.88        22
        literature       0.00      0.00      0.00         0
             music       0.00      0.00      0.00         0
outdoor-activities       0.00      0.00      0.00         0
          politics       0.74      0.97      0.84        30
science/technology       0.89      0.94      0.91        51
            sports       0.91      0.84      0.87        25
         transport       0.00      0.00      0.00         0
            travel       1.00      0.70      0.82        40
           weather       0.00      0.00

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [14]:
y_pred_with_normalization = list(map(
    lambda it: process.extractOne(it[-1]['content'], list_of_categories, scorer=fuzz.token_sort_ratio)[0],
    y_pred
))

In [15]:
print(classification_report(y_true=y_true, y_pred=y_pred_with_normalization))

                    precision    recall  f1-score   support

     entertainment       0.92      0.58      0.71        19
         geography       0.88      0.82      0.85        17
            health       0.83      0.86      0.84        22
          politics       0.71      0.97      0.82        30
science/technology       0.86      0.94      0.90        51
            sports       0.81      0.84      0.82        25
            travel       0.93      0.70      0.80        40

          accuracy                           0.83       204
         macro avg       0.85      0.82      0.82       204
      weighted avg       0.85      0.83      0.83       204



In [16]:
pp = PostProcessor(model_name='cde-small')
y_pred_proc = pp.process(list_of_categories, [x[-1]['content'] for x in y_pred])
print(classification_report(y_true=y_true, y_pred=y_pred_proc))

modules.json:   0%|          | 0.00/149 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/276k [00:00<?, ?B/s]

sentence_transformers_impl.py:   0%|          | 0.00/6.10k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jxm/cde-small-v1:
- sentence_transformers_impl.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


sentence_bert_config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/916 [00:00<?, ?B/s]

misc.py:   0%|          | 0.00/720 [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jxm/cde-small-v1:
- misc.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.py:   0%|          | 0.00/40.7k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jxm/cde-small-v1:
- model.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

configuration_hf_nomic_bert.py:   0%|          | 0.00/1.96k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/nomic-ai/nomic-bert-2048:
- configuration_hf_nomic_bert.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_hf_nomic_bert.py:   0%|          | 0.00/85.7k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/nomic-ai/nomic-bert-2048:
- modeling_hf_nomic_bert.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


pytorch_model.bin:   0%|          | 0.00/549M [00:00<?, ?B/s]

  state_dict = loader(resolved_archive_file)


tokenizer_config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Disabled 37 dropout modules from model type <class 'transformers_modules.jxm.cde-small-v1.9e2ed1d8d569d34458913d2d246935c1b2324d11.model.BiEncoder'>
modified 12 rotary modules – set rotary_start_pos to 512
Disabled 74 dropout modules from model type <class 'transformers_modules.jxm.cde-small-v1.9e2ed1d8d569d34458913d2d246935c1b2324d11.model.DatasetTransformer'>


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

                    precision    recall  f1-score   support

     entertainment       0.92      0.58      0.71        19
         geography       0.88      0.82      0.85        17
            health       0.86      0.86      0.86        22
          politics       0.74      0.97      0.84        30
science/technology       0.86      0.94      0.90        51
            sports       0.88      0.84      0.86        25
            travel       0.94      0.82      0.88        40

          accuracy                           0.86       204
         macro avg       0.87      0.83      0.84       204
      weighted avg       0.87      0.86      0.86       204

