In [1]:
import csv
import json
import pickle
import joblib
import sqlite3

from collections import defaultdict, Counter

import hazm
from parsivar import Normalizer

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from tqdm.notebook import tqdm_notebook

from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import torch
from torch import nn

from datasets import (
    Dataset,
    DatasetDict,
    load_dataset,
    load_metric,
    load_from_disk,
    concatenate_datasets,
)
from transformers import AutoTokenizer, AutoModel
from transformers import AdamW

In [2]:
class MyNormalizer:
    """Two-stage Persian text normalizer: parsivar first, then hazm.

    Both underlying normalizers are configured once at construction time.
    Call `normalize` (or the legacy, misspelled `normilize`) on each string.
    """

    def __init__(self):
        # Stage 1 (parsivar): ZWNJ half-spaces are rendered as plain spaces.
        self.parsivar_normalizer = Normalizer(
            statistical_space_correction=True,
            half_space_char=" ",
            pinglish_conversion_needed=True,
        )
        # Stage 2 (hazm): applied to parsivar's output.
        self.hazm_normalizer = hazm.Normalizer(
            remove_extra_spaces=True,
            persian_numbers=True,
            persian_style=True,
            punctuation_spacing=False,
            remove_diacritics=True,
            affix_spacing=False,
            token_based=True,
        )

    def normalize(self, txt):
        """Return `txt` normalized for matching.

        Pre-processing turns newlines and ZWNJ (U+200C) into spaces, lowercases,
        and strips surrounding whitespace. The result is then passed through the
        parsivar normalizer and finally the hazm normalizer.

        Args:
            txt: input string.

        Returns:
            The normalized string.
        """
        flattened = txt.replace("\n", " ").replace("\u200c", " ").lower().strip()
        return self.hazm_normalizer.normalize(
            self.parsivar_normalizer.normalize(flattened)
        )

    # Backward-compatible alias for the original (misspelled) method name,
    # which existing callers in this notebook still use.
    normilize = normalize


In [3]:
class JsonFileIterator:
    """Single-pass iterator over a JSON Lines file, yielding one decoded object per line.

    `len()` returns the total number of lines, counted in a separate read pass
    at construction, so progress bars can show a total. The file handle is
    closed automatically once end-of-file is reached, or explicitly via
    `close()` / the context-manager protocol.

    Args:
        path: path to the .jsonl file.
        encoding: text encoding used to read the file (default "utf-8",
            matching the other files this notebook reads and writes).
    """

    def __init__(self, path, encoding="utf-8"):
        self.path = path
        self.encoding = encoding
        # Count first so a failure here cannot leak an already-open handle.
        self.length = self.counter_lines()
        self.f = open(path, 'r', encoding=encoding)
        self.i = 0  # number of records yielded so far

    def __iter__(self):
        # Single-pass: iterating again continues from the current position.
        return self

    def __next__(self):
        # After exhaustion (or close()) keep signalling StopIteration instead of
        # raising ValueError from reading a closed file.
        if self.f.closed:
            raise StopIteration
        line = self.f.readline()
        if not line:
            # End of file
            self.f.close()
            raise StopIteration
        self.i += 1
        return json.loads(line)

    def counter_lines(self):
        """Return the number of lines in the file (one full read pass)."""
        with open(self.path, 'r', encoding=self.encoding) as f1:
            return sum(1 for _ in f1)

    def close(self):
        """Close the underlying file handle; safe to call more than once."""
        self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def __len__(self):
        return self.length

In [4]:
def find_new_words(titles_list, index_high_len, stop_words=None):
    """Collect words from the other titles that the reference title lacks.

    Args:
        titles_list: list of title strings for one product.
        index_high_len: index of the reference title in `titles_list`. The
            caller passes the longest one. Its words form the baseline
            vocabulary.
        stop_words: iterable of words to ignore. When None, falls back to the
            notebook-level `stopwords` list.

    Returns:
        A list of distinct words that appear in at least one other title but
        neither in the reference title nor in `stop_words`. Words are returned
        in first-seen order, so the output (and any description built from it)
        is reproducible across runs, unlike `set` iteration order.
    """
    if stop_words is None:
        stop_words = stopwords  # notebook-level global
    excluded = set(titles_list[index_high_len].split()) | set(stop_words)

    new_words = {}  # dict used as an insertion-ordered set
    for i, title in enumerate(titles_list):
        # Every title except the reference one contributes. The previous loop
        # started at 1 and silently ignored titles_list[0].
        if i == index_high_len:
            continue
        for word in title.split():
            if word not in excluded:
                new_words.setdefault(word, None)
    return list(new_words)

In [5]:
# Load the stop-word list (one word per line). Each entry goes through the same
# MyNormalizer instance that is later used to normalize queries.
normalizer = MyNormalizer()
with open('stop-words.txt', encoding='utf-8' ) as stopword_file:
    stopwords = [normalizer.normilize(raw.strip()) for raw in stopword_file]

In [8]:
# Build the offline test set: one CSV row per (query, candidate product) pair,
# combining product metadata from SQLite with per-query price statistics.

TEST_CSV_HEADER = [
    "row_number", "query", "product_id", "p_des", "category_name",
    "min_num_shops", "max_num_shops", "avg_num_shops",
    "min_price", "max_price", "avg_price",
    "mean_min_prices", "mean_max_prices", "mean_avg_prices",
    "std_min_prices", "std_max_prices", "std_avg_prices",
]

# Positional columns of `SELECT * FROM products`, as this cell uses them.
# NOTE(review): inferred from usage, not from the table schema. Confirm.
COL_CATEGORY, COL_TITLES = 1, 2
COL_MIN_PRICE, COL_MAX_PRICE, COL_AVG_PRICE = 3, 4, 5
COL_MIN_SHOPS, COL_MAX_SHOPS, COL_AVG_SHOPS = 6, 7, 8


def _fetch_product(cursor, product_id):
    """Return the `products` row whose id is `product_id`, or None if absent."""
    cursor.execute('SELECT * FROM products WHERE id = ?', (product_id,))
    return cursor.fetchone()


def _mean_std(values):
    """Return (mean, std) of `values`, or (nan, nan) for an empty list.

    np.mean/np.std also return nan for empty input, but emit a RuntimeWarning.
    """
    if not values:
        return np.nan, np.nan
    return np.mean(values), np.std(values)


def _present(rows, col):
    """Return the non-NULL values of column `col` across the found rows."""
    return [row[col] for row in rows if row is not None and row[col] is not None]


test_data = JsonFileIterator("./test-offline-data_v1.jsonl")
conn = None
try:
    # Read-only URI: fail fast if the DB file is missing instead of silently
    # creating an empty my_database.db and then failing on the SELECT.
    conn = sqlite3.connect('file:my_database.db?mode=ro', uri=True)
    c = conn.cursor()
    with open('test_dataset_v4.txt', 'w', encoding="utf-8", newline='') as csvfile:
        wrtiter = csv.writer(csvfile)
        wrtiter.writerow(TEST_CSV_HEADER)
        for idx, test_rec in enumerate(tqdm_notebook(test_data)):
            # The query is the same for every candidate, so normalize it once.
            query_norm = normalizer.normilize(test_rec["raw_query"])
            results = test_rec["result_not_ranked"]

            # Look each candidate up once; both passes below reuse the row.
            products = {}
            for product_id in results:
                if product_id is not None and product_id not in products:
                    products[product_id] = _fetch_product(c, product_id)
            # Aligned with `results` (duplicates kept); None for null/unknown ids.
            found_rows = [products.get(product_id) for product_id in results]

            # Pass 1: per-query price statistics, computed once after all
            # candidates are collected (previously recomputed per candidate).
            mean_min_prices, std_min_prices = _mean_std(_present(found_rows, COL_MIN_PRICE))
            mean_max_prices, std_max_prices = _mean_std(_present(found_rows, COL_MAX_PRICE))
            mean_avg_prices, std_avg_prices = _mean_std(_present(found_rows, COL_AVG_PRICE))

            # Pass 2: one output row per candidate that has at least one title.
            for product_id, row in zip(results, found_rows):
                if row is None or row[COL_TITLES] is None:
                    # Null/unknown id (the original crashed with TypeError here)
                    # or a product with no stored titles.
                    continue
                titles_list_product = json.loads(row[COL_TITLES])
                if not titles_list_product:
                    continue
                # Description = longest title + words only found in other titles.
                highest_len_product = max(titles_list_product, key=len)
                index_highest_len_product = titles_list_product.index(highest_len_product)
                product_title_new_words = find_new_words(titles_list_product, index_highest_len_product)
                p_des = " ".join([highest_len_product, " ".join(product_title_new_words)]).replace("\u200c", " ")
                wrtiter.writerow([
                    idx, query_norm, product_id, p_des,
                    row[COL_CATEGORY],
                    row[COL_MIN_SHOPS], row[COL_MAX_SHOPS], row[COL_AVG_SHOPS],
                    row[COL_MIN_PRICE], row[COL_MAX_PRICE], row[COL_AVG_PRICE],
                    mean_min_prices, mean_max_prices, mean_avg_prices,
                    std_min_prices, std_max_prices, std_avg_prices,
                ])
finally:
    if conn is not None:
        conn.close()
    # Release the JSONL handle even if the loop stopped early.
    test_data.f.close()
                


  0%|          | 0/23140 [00:00<?, ?it/s]