In [1]:
import pandas as pd
from tqdm import tqdm

In [2]:
# index, session_id, song_id, unix_played_at, play_status, login_type, listening_order
train_source = pd.read_parquet("../../datagame-2023/label_train_source.parquet")
# index, session_id, song_id, unix_played_at, play_status, login_type, listening_order
train_target = pd.read_parquet("../../datagame-2023/label_train_target.parquet")
# index, session_id, song_id, unix_played_at, play_status, login_type, listening_order
test_source = pd.read_parquet("../../datagame-2023/label_test_source.parquet")
# index, song_id, artist_id, song_length, album_id, language_id, album_month
meta_song = pd.read_parquet("../../datagame-2023/meta_song.parquet")
# index, song_id, composer_id
meta_song_composer = pd.read_parquet("../../datagame-2023/meta_song_composer.parquet")
# index, song_id, genre_id
meta_song_genre = pd.read_parquet("../../datagame-2023/meta_song_genre.parquet")
# index, song_id, lyricist_id
meta_song_lyricist = pd.read_parquet("../../datagame-2023/meta_song_lyricist.parquet")
# index, song_id, producer_id
meta_song_producer = pd.read_parquet("../../datagame-2023/meta_song_producer.parquet")
# index, song_id, title_text_id
meta_song_titletext = pd.read_parquet("../../datagame-2023/meta_song_titletext.parquet")

In [3]:
# Preparse session's songs
from collections import defaultdict

session_to_songs = defaultdict(list)  # key -> session_id, value -> songs

test_source = test_source.sort_values(by=['session_id', 'listening_order'])

group_by_session = test_source.groupby('session_id')

for _, group_song in tqdm(group_by_session):
    session_id = group_song['session_id'].iloc[0]
    session_to_songs[session_id] = group_song['song_id'].tolist()

100%|██████████| 143064/143064 [00:05<00:00, 25688.14it/s]


In [4]:
session_to_time = dict()
for row in tqdm(train_target.itertuples(), total=len(train_target)):
    session_to_time[row.session_id] = row.unix_played_at
for row in tqdm(test_source.itertuples(), total=len(test_source)):
    session_to_time[row.session_id] = row.unix_played_at

100%|██████████| 2861295/2861295 [00:01<00:00, 1639204.10it/s]
100%|██████████| 2861280/2861280 [00:01<00:00, 1754781.41it/s]


In [5]:
session_to_time[8]

1664085793

In [6]:
# Preparse session's songs
from collections import defaultdict

session_to_songs = defaultdict(list)  # key -> session_id, value -> songs

test_source = test_source.sort_values(by=['session_id', 'listening_order'])

group_by_session = test_source.groupby('session_id')

for _, group_song in tqdm(group_by_session):
    session_id = group_song['session_id'].iloc[0]
    session_to_songs[session_id] = group_song['song_id'].tolist()

100%|██████████| 143064/143064 [00:05<00:00, 25582.03it/s]


In [7]:
""" For Jelinek-Mercer smoothing """
from pyserini.search.lucene import LuceneSearcher


class MyLuceneSearcher(LuceneSearcher):
    def set_jmlm(self, Lambda: float = 0.9999):
        """ Set the Jelinek-Mercer smoothing with lambda

        Reference java code:
            public void set_qld(float mu) {
              this.similarity = new LMDirichletSimilarity(mu); # SimpleSearcher

              // We need to re-initialize the searcher
              searcher = new IndexSearcher(reader); # SimpleSearcher.searcher
              searcher.setSimilarity(similarity); # SimpleSearcher.searcher
            }
        Args:
            l (float): Lamda
        """
        from jnius import autoclass

        LMDirichletSimilarity = autoclass("org.apache.lucene.search.similarities.LMJelinekMercerSimilarity")
        self.object.similarity = LMDirichletSimilarity(Lambda)

        # We need to re-initialize the searcher
        IndexSearcher = autoclass("org.apache.lucene.search.IndexSearcher")
        self.object.searcher = IndexSearcher(self.object.reader)
        self.object.searcher.setSimilarity(self.object.similarity)
        print("set to jmlm with lambda = {}".format(Lambda))

In [8]:
from pyserini.index.lucene import IndexReader
from pyserini.search.lucene import LuceneSearcher, querybuilder
from pyserini.analysis import get_lucene_analyzer

fields = ['artist', 'album', 'language', 'genre']


class Searcher():
    def __init__(self, searcher: LuceneSearcher, reader: IndexReader, is_stemming=False) -> None:
        self.searcher: LuceneSearcher = searcher
        self.searcher.set_analyzer(get_lucene_analyzer(stemming=is_stemming))
        self.total_docs = reader.stats()['documents']
        self.mu = reader.stats()['total_terms'] / reader.stats()['unique_terms']

    def song_to_contents(self, song_id):
        return self.searcher.doc(song_id).contents() if self.searcher.doc(song_id) else ""

    def songs_to_query(self, session, time_range, song_ids, cut_song_token=0):
        time_min = session_to_time[session] - time_range
        time_max = session_to_time[session] + time_range
        contents = [self.song_to_contents(song_id) for song_id in song_ids]
        contents = [content.split() for content in contents]

        query_text = []
        query_text_len = 0
        count = 10

        for content in contents:
            max_song_token = count
            is_first = True
            first_index = -1
            last_index = -1
            for i in range(len(content)):
                if max_song_token <= 0:
                    break
                if any(field in content[i] for field in fields):
                    query_text.append(content[i])
                    query_text_len += len(content[i])
                    continue
                int_x = session_to_time[int(content[i], 16)]
                if time_min < int_x < time_max:
                    if is_first:
                        is_first = False
                        first_index = i
                    last_index = i
                    query_text.append(content[i])
                    query_text_len += len(content[i])
                    max_song_token -= 1

                if int_x > time_max:
                    if is_first:
                        is_first = False
                        first_index = i - 1
                        last_index = i - 1
                    break

            for i in range(first_index, -1, -1):
                if max_song_token <= 0 or any(field in content[i] for field in fields):
                    break
                query_text.append(content[i])
                query_text_len += len(content[i])
                max_song_token -= 1

            for i in range(last_index + 1, len(content)):
                if max_song_token <= 0 or any(field in content[i] for field in fields):
                    break
                query_text.append(content[i])
                query_text_len += len(content[i])
                max_song_token -= 1

        query_text = " ".join(query_text)

        return query_text

    def set_max_clause_count(self, max_clause_count):
        print(type(self.searcher.object.searcher))
        print("Original maxCaluseCount:", self.searcher.object.searcher.maxClauseCount)
        self.searcher.object.searcher.setMaxClauseCount(max_clause_count)
        print("Updated maxCaluseCount:", self.searcher.object.searcher.maxClauseCount)

    def search(self, queries, args):
        # self.searcher.set_bm25(b=0, k1=0)
        # self.searcher.set_qld(self.mu)
        self.searcher.set_jmlm()

        results = []

        for session_id, qtext in tqdm(queries):
            hits = self.searcher.search(qtext, args.k)
            results.append([session_id, [hit.docid for hit in hits]])
        return results

In [9]:
class Arg:
    def __init__(self, k):
        self.k = k


args = Arg(k=100)
index = "indexes/collection_jsonl_sparse"
stem = False

luceneSearcher = MyLuceneSearcher(index)
reader = IndexReader(index)
searcher = Searcher(searcher=luceneSearcher, reader=reader, is_stemming=stem)

max_clause_count = 1000000
searcher.set_max_clause_count(max_clause_count)

<class 'jnius.reflect.org.apache.lucene.search.IndexSearcher'>
Original maxCaluseCount: 1024
Updated maxCaluseCount: 1000000


In [10]:
# Prepare queries = [[session_id, query], ...]
print("Preparing queries...")
last_n_song = 10  # TODO : check this
queries = []
too_long_count = 0
for session, songs in tqdm(session_to_songs.items()):
    qtext = searcher.songs_to_query(session, 0, songs[-5:])
    queries.append([session, qtext])

Preparing queries...


100%|██████████| 143064/143064 [04:08<00:00, 576.05it/s] 


In [11]:
qtext

'artist2 album87608 language3 genrece4db56f6a48426643b08038139a8a75 7905a 29da8 2f62c 5795f 5795f ac049 71dd9 7905a 47b5b 7905a artist2 album87608 language3 genreb856b6781d370a3645c6dde0c20b3597 2e321 7f3b0 1b0bf 7a00e 6c5f6 7a00d 8d853 6423c 61982 89041 artist2 album87608 language3 genrece4db56f6a48426643b08038139a8a75 92989 98367 8d853 766f9 764aa 7bda8 8c331 69540 29b2b 80f76 artist31165878 album127736324 language3 genrece4db56f6a48426643b08038139a8a75 a9f6e 21293 3c6af 3c6af 7ca24 44804 944ea 944ea 944ea 944ea artist6278 a4c9f a4c9f 4d26c 8755c a4c9d a4c9d 4b758 9f4d1 327b9 a4c9b'

In [12]:
songs

['700a3bbe2b689e2da396bee4daafa4b2',
 '7ce1913e1511f3d77da7a0b32e640604',
 '700a3bbe2b689e2da396bee4daafa4b2',
 'bc603ae5839065a50a23592003bf4233',
 'c6cbfaccb4c07120a76da6b9c14e5902',
 '14b087038ee0d59c55ae0a8e6cbbe081',
 '110aed5fac7d6f46e7a667ad1261d42e',
 '488eb3b766d16e5e13009566aeb5ab5f',
 '68b5f72cb29c3ad2fffa47209597860f',
 'a20728aa7e8122584e8b5863c7d0bc02',
 '015c340cf3a75afb53040ce0d01e6b13',
 '8a4c8f80d095a42feaedaee4cf25be84',
 '723ab216ae4e5161e397c462a1cf8954',
 'e667f16939964f81bccad710ed0adce9',
 '1353435a907399cb65b925c6b5e3960b',
 '00cef2617cceaa2299a47a79f6100ee2',
 '605b1abcba2f893bdceaf20276be07c7',
 '3df18462598942a0e906c9327fc0e738',
 '3b4f31812ec47aaf14ab56939dbe9b57',
 '1c285118397adfe939edb78504fe6259']

In [14]:
print("Searching...")
results = searcher.search(queries, args)

Searching...
set to jmlm with lambda = 0.9999


100%|██████████| 143064/143064 [26:48<00:00, 88.93it/s] 


In [15]:
import pickle

with open('jmlm_0.9999_token10.pkl', 'wb') as file:
    pickle.dump(results, file)