## Импорт библиотек

In [1]:
import argparse

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
from dateutil import parser

from find_brand import find_brand

## Загрузка модели, токенизатора и тестовых данных

In [2]:
# Inference settings.
BATCH_SIZE = 16 * 5  # samples per generation batch
NUM_BEAMS = 2  # beam width passed to model.generate
# Token length every input is padded/truncated to; the prediction loop also
# uses it as the generation max_length.
MAX_SOURCE_LENGTH = 70

# NOTE(review): pickle.load can execute arbitrary code stored in the file --
# only load model/tokenizer pickles that come from a trusted source.
with open('model.pkl', 'rb') as r:
    model = pickle.load(r).to('cuda')

with open('tokenizer.pkl', 'rb') as r:
    tokenizer = pickle.load(r)

test_dataset = pd.read_csv('supermarket_val.tsv', sep='\t')

# Compute the unique product names once, so the submission rows and the model
# inputs are built from the very same sequence in the same order
# (Series.unique keeps the order of first appearance).
unique_names = test_dataset['name'].unique()
submission = pd.DataFrame({'name': unique_names})

# Instruction prompt prepended to every product name before tokenization.
PROMPT = 'Определи название и бренд товара. '
test_input = [PROMPT + name for name in unique_names.tolist()]

# Each item is [encoding, prompted_text]. Every string is tokenized on its own,
# so each encoding presumably holds tensors of shape (1, MAX_SOURCE_LENGTH) --
# the prediction loop reshapes batches to (-1, MAX_SOURCE_LENGTH).
test_input = [
    [
        tokenizer(
            inp,
            padding="max_length",
            max_length=MAX_SOURCE_LENGTH,
            truncation=True,
            return_tensors="pt",
        ),
        inp,
    ]
    for inp in test_input
]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 142/142 [02:37<00:00,  1.11s/it]


## Предсказание и сохранение результатов

In [None]:
# Generate "<name>[SEPARATE]<brand>" outputs for every prompted input, split
# them into product name and brand, and write the submission file.
pred_names = []
pred_brands = []

# shuffle defaults to False, so batches come out in the same order as
# test_input -- this keeps predictions aligned with submission['name'].
test_loader = DataLoader(test_input, batch_size=BATCH_SIZE)

# The model was unpickled, so it keeps whatever train/eval mode it was saved
# in; generate() only disables gradients. Switch dropout etc. off explicitly.
model.eval()

for encoding_x, input_text in tqdm(test_loader):
    # Default collation stacks the per-sample encodings (presumably into
    # (batch, 1, MAX_SOURCE_LENGTH)); flatten to (batch, MAX_SOURCE_LENGTH).
    input_ids = encoding_x['input_ids'].reshape(-1, MAX_SOURCE_LENGTH)
    attention_mask = encoding_x['attention_mask'].reshape(-1, MAX_SOURCE_LENGTH)

    with torch.no_grad():
        generated = model.generate(
            input_ids=input_ids.to('cuda'),
            attention_mask=attention_mask.to('cuda'),
            max_length=MAX_SOURCE_LENGTH,
            num_beams=NUM_BEAMS,
        )

    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    # Split each output into [name, brand, ...] on the separator token.
    parts = [text.split('[SEPARATE]') for text in decoded]

    pred_names.extend(part[0].strip().lower() for part in parts)
    # A brand is taken only when the output contains exactly one separator;
    # otherwise it is recorded as missing. find_brand receives the prompted
    # source text together with the predicted brand.
    pred_brands.extend(
        find_brand(src_text, part[1].strip().lower()) if len(part) == 2 else np.nan
        for part, src_text in zip(parts, input_text)
    )

# Treat the literal string 'nan' as a missing value.
pred_names = [name if name != 'nan' else np.nan for name in pred_names]
pred_brands = [brand if brand != 'nan' else np.nan for brand in pred_brands]

submission['good'] = pred_names
submission['brand'] = pred_brands
submission.to_csv('item_ner.csv', index=False)