In [74]:
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm

from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset, DataLoader

# Для моделирования
from sklearn.linear_model import LogisticRegression
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer("intfloat/multilingual-e5-large-instruct").to(device)

In [96]:
class CustomDataset(Dataset):
    def __init__(self, path, text_column_name: str = 'text', is_test: bool = False):
        self.df = pd.read_csv(path)
        self.text_column_name = text_column_name
        self.is_test = is_test

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if self.is_test:
            return self.df.loc[idx, self.text_column_name]
        return self.df.loc[idx, self.text_column_name], self.df.loc[idx, 'target']

In [62]:
data = CustomDataset('train.csv')
dataloader = DataLoader(data, batch_size=1024)

In [None]:
embs = np.array(range(1024))[np.newaxis, :]
targets = torch.tensor([])
with torch.no_grad():
    for text, target in tqdm(dataloader):
        emb = model.encode(text)
        targets = torch.concat([targets, target])
        embs = np.concatenate([embs, emb])
embs = embs[1:]

In [72]:
np.save('embs.npy', embs)
np.save('targets.npy', targets)

In [83]:
embs = np.load('embs.npy')
targets = np.load('targets.npy')

In [75]:
X_train, X_test, y_train, y_test = train_test_split(
    embs, targets, test_size=0.2, random_state=42
)

In [76]:
# Обучение линейной модели - Логистическая регрессия
logistic_model = LogisticRegression(max_iter=1000)
logistic_model.fit(X_train, y_train)
y_pred_logistic = logistic_model.predict(X_test)
accuracy_logistic = accuracy_score(y_test, y_pred_logistic)
print(f'Точность Логистической регрессии: {accuracy_logistic:.4f}')

Точность Логистической регрессии: 0.9148


In [None]:
eval_dataset = Pool(X_test,
                    np.array(y_test)
                    )

In [85]:
# Обучение CatBoost классификатора
catboost_model = CatBoostClassifier(learning_rate=0.01, iterations=10000, task_type="GPU", devices='0')
catboost_model.fit(X_train, np.array(y_train), eval_set=eval_dataset)
y_pred_catboost = catboost_model.predict(X_test)
accuracy_catboost = accuracy_score(y_test, y_pred_catboost)
print(f'Точность CatBoost классификатора: {accuracy_catboost:.4f}')



0:	learn: 2.1729764	test: 2.1727854	best: 2.1727854 (0)	total: 47.6ms	remaining: 7m 55s
1:	learn: 2.1500469	test: 2.1497707	best: 2.1497707 (1)	total: 92.6ms	remaining: 7m 42s
2:	learn: 2.1280267	test: 2.1274486	best: 2.1274486 (2)	total: 137ms	remaining: 7m 36s
3:	learn: 2.1070935	test: 2.1062693	best: 2.1062693 (3)	total: 183ms	remaining: 7m 37s
4:	learn: 2.0871976	test: 2.0862140	best: 2.0862140 (4)	total: 230ms	remaining: 7m 38s
5:	learn: 2.0678329	test: 2.0667040	best: 2.0667040 (5)	total: 276ms	remaining: 7m 39s
6:	learn: 2.0489219	test: 2.0477037	best: 2.0477037 (6)	total: 322ms	remaining: 7m 39s
7:	learn: 2.0308489	test: 2.0295246	best: 2.0295246 (7)	total: 367ms	remaining: 7m 38s
8:	learn: 2.0139409	test: 2.0124486	best: 2.0124486 (8)	total: 411ms	remaining: 7m 35s
9:	learn: 1.9975628	test: 1.9959944	best: 1.9959944 (9)	total: 454ms	remaining: 7m 33s
10:	learn: 1.9814448	test: 1.9797474	best: 1.9797474 (10)	total: 499ms	remaining: 7m 32s
11:	learn: 1.9656819	test: 1.9639229	be

In [97]:
test_data = CustomDataset('test_news.csv', 'content', True)
test_dataloader = DataLoader(test_data, batch_size=2048, shuffle=False)

In [99]:
embs = np.array(range(1024))[np.newaxis, :]
with torch.no_grad():
    for text in tqdm(test_dataloader):
        emb = model.encode(text)
        embs = np.concatenate([embs, emb])
embs = embs[1:]

100%|██████████| 13/13 [16:43<00:00, 77.21s/it]


In [101]:
embs.shape

(26275, 1024)

In [103]:
y_test_catboost = catboost_model.predict(embs)

In [106]:
pd.DataFrame({'topic': y_test_catboost[:, 0]}).reset_index().to_csv('transformer_answer.csv', index=False)

Ответ на том же уровне 0.746))) Так что походу сильно роляют данные, плохо напарсил....