In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from scipy import sparse

from lib import *
from classes import *

from lightfm import LightFM
from tqdm.notebook import tqdm
tqdm.pandas()

In [2]:
DATA_PATH = 'data/'
# NOTE(review): the meaning of this value is defined by lib.train_test_split —
# presumably a time/day cut-off between train and test reviews; confirm.
TRAIN_TEST_SPLIT_AT = 1166

users, orgs, reviews = get_data(DATA_PATH)
train_reviews, test_reviews = train_test_split(reviews, TRAIN_TEST_SPLIT_AT)
# Held-out features/labels built from the test split of the reviews.
X_test, y_test = process_reviews(test_reviews)

In [38]:
def filter_min_pos_rating(reviews, orgs, min_pos_rating, pos_threshold=4):
    """Keep only organizations with at least ``min_pos_rating`` positive reviews.

    A review counts as positive when its ``rating`` is >= ``pos_threshold``.

    Parameters
    ----------
    reviews : pd.DataFrame
        Must contain ``org_id`` and ``rating`` columns.
    orgs : pd.DataFrame
        Must contain an ``org_id`` column.
    min_pos_rating : int
        Minimum number of positive reviews an organization needs to be kept.
        Values <= 0 disable the filter: every organization and review is kept.
    pos_threshold : int, default 4
        Rating at or above which a review is considered positive.

    Returns
    -------
    (pd.DataFrame, pd.DataFrame)
        The filtered ``orgs`` and ``reviews`` (original row order preserved),
        each with a fresh RangeIndex.
    """
    if min_pos_rating <= 0:
        # A per-org count of positives can never be zero, so a zero threshold
        # must be handled explicitly to really mean "no filtering" (this keeps
        # orgs with no positive reviews, or no reviews at all).
        return orgs.reset_index(drop=True), reviews.reset_index(drop=True)

    pos_counts = reviews.loc[reviews['rating'] >= pos_threshold, 'org_id'].value_counts()
    keep = pos_counts.index[pos_counts >= min_pos_rating]
    return (
        orgs[orgs['org_id'].isin(keep)].reset_index(drop=True),
        reviews[reviews['org_id'].isin(keep)].reset_index(drop=True),
    )

def get_other_city(city):
    """Return the opposite city code: 'spb' for 'msk', and 'msk' for anything else."""
    return 'spb' if city == 'msk' else 'msk'

class BaseLightFMSolver:
    """LightFM (WARP loss) recommender suggesting organizations in the *other*
    city for each user.

    Workflow: construct, call ``fit`` to build the interaction matrix and an
    untrained model, call ``fit_partial`` one or more times to train, then
    call ``predict``.
    """

    def __init__(self, users, orgs):
        """
        Parameters
        ----------
        users : pd.DataFrame
            Needs ``user_id`` and ``city`` columns (``city`` is read in ``predict``).
        orgs : pd.DataFrame
            Needs ``org_id`` and ``city`` columns.
        """
        self.users = users
        self.orgs = orgs
        # Unfiltered catalogue: every ``fit`` filters from this, so repeated
        # fits do not compound earlier filters applied to ``self.orgs``.
        self._all_orgs = orgs
        self.cities = ['msk', 'spb']
        # city code -> array of internal org indices that are candidates there
        self.orgs_c = dict()
        # NOTE(review): create_mappings comes from a star import; presumably it
        # returns (code -> index, index -> code) dicts — confirm in lib.
        self.org_ctoi, self.org_itoc = create_mappings(self.orgs['org_id'])
        self.user_ctoi, self.user_itoc = create_mappings(self.users['user_id'])

    def fit(self, reviews, min_pos_rating=1):
        """Build the user x org interaction matrix and a fresh, untrained model.

        Organizations with fewer than ``min_pos_rating`` positive reviews are
        removed (via ``filter_min_pos_rating``) both from the candidate pool
        used by ``predict`` and from the training interactions. Ratings of each
        (user, org) pair are averaged; an average >= 4 becomes +1, else -1.
        """
        self.orgs, reviews = filter_min_pos_rating(reviews, self._all_orgs, min_pos_rating)
        for city in self.cities:
            # Use the filtered ``self.orgs`` (not a module-level ``orgs``) so
            # organizations removed by the filter are never recommended.
            in_city = self.orgs['city'] == city
            self.orgs_c[city] = self.orgs.loc[in_city, 'org_id'].map(self.org_ctoi).values

        cd_reviews = (
            reviews[['user_id', 'org_id', 'rating']]
            .groupby(['user_id', 'org_id'])
            .mean()
            .reset_index()
        )
        # NOTE(review): LightFM's WARP loss appears to use only positive
        # entries, so the -1 cells may have no effect on training — confirm.
        cd_reviews['rating'] = np.where(cd_reviews['rating'] >= 4, 1, -1)

        rows = cd_reviews['user_id'].map(self.user_ctoi)
        cols = cd_reviews['org_id'].map(self.org_ctoi)
        values = cd_reviews['rating']
        self.train = sparse.coo_matrix(
            (values, (rows, cols)),
            shape=(max(self.user_itoc) + 1, max(self.org_itoc) + 1),
        )
        self.model = LightFM(loss='warp', learning_rate=0.01, no_components=10,
                             item_alpha=5e-6, user_alpha=5e-6)

    def fit_partial(self, epochs):
        """Train the model for ``epochs`` more epochs on ``self.train``.

        LightFM's ``fit_partial`` resumes from the current model state, so this
        can be called repeatedly to keep training.
        """
        self.model.fit_partial(self.train, epochs=epochs, verbose=True)

    def _top_orgs(self, user, candidates, n):
        """Return org codes of the ``n`` best-scored ``candidates`` for ``user``.

        ``user`` is an internal user index and ``candidates`` an array of
        internal org indices. The result is ordered best first and has
        ``min(n, len(candidates))`` entries.
        """
        k = min(n, len(candidates))
        if k <= 0:
            return np.array([])
        scores = self.model.predict(np.full(len(candidates), user), candidates)
        # argpartition selects the k best without a full sort; only those k
        # are then sorted, descending.
        best = np.argpartition(scores, -k)[-k:]
        best = best[np.argsort(scores[best])[::-1]]
        # Build the array from the full list so numpy infers the dtype from
        # all codes (np.vectorize would infer it from the first one only).
        return np.array([self.org_itoc[i] for i in candidates[best]])

    def predict(self, X_test, path=None, top_n=None):
        """Recommend the top orgs in the other city for each row of ``X_test``.

        Parameters
        ----------
        X_test : pd.DataFrame
            Must contain a ``user_id`` column.
        path : str, optional
            If given, also write the predictions to this CSV path with each
            target space-joined.
        top_n : int, optional
            Recommendations per user. Defaults to the module-level ``N``
            (not defined in this notebook; presumably from ``from lib import *``).

        Returns
        -------
        pd.DataFrame
            Copy of ``X_test`` with a ``target`` column holding an array of
            org codes (best first). Rows whose user is not in ``self.users``
            get NaN.
        """
        n = N if top_n is None else top_n

        # Look up each test user's home city while keeping X_test's own index,
        # so targets are assigned back to the right rows even when X_test has
        # a non-default index or contains users unknown to ``self.users``.
        city_by_user = self.users.drop_duplicates('user_id').set_index('user_id')['city']
        known = X_test['user_id'].isin(city_by_user.index)
        known_users = X_test.loc[known, 'user_id']
        located = pd.DataFrame({'user_id': known_users,
                                'city': known_users.map(city_by_user)})

        def recommend(row):
            user = self.user_ctoi[row['user_id']]
            candidates = self.orgs_c[get_other_city(row['city'])]
            return self._top_orgs(user, candidates, n)

        if located.empty:
            target = pd.Series(index=located.index, dtype=object)
        else:
            target = located.progress_apply(recommend, axis=1)

        predictions = X_test.copy()
        predictions['target'] = target  # index-aligned; unknown users -> NaN

        if path is not None:
            predictions_str = predictions.copy()
            predictions_str['target'] = predictions_str['target'].map(
                lambda recs: ' '.join(map(str, recs)), na_action='ignore'
            )
            predictions_str.to_csv(path, index=False)

        return predictions

In [33]:
# NOTE(review): superseded — the following cell rebuilds `lfm` with
# min_pos_rating=5, so this fitted instance is discarded; consider deleting.
lfm = BaseLightFMSolver(users=users, orgs=orgs)
lfm.fit(train_reviews, min_pos_rating=0)

In [39]:
# Keep only organizations with at least this many positive (rating >= 4) reviews.
MIN_POS_REVIEWS = 5
lfm = BaseLightFMSolver(users, orgs)
lfm.fit(train_reviews, min_pos_rating=MIN_POS_REVIEWS)

In [40]:
N_EPOCHS = 60
# NOTE(review): the saved progress bar for this cell reports 80/80 epochs, so it
# is stale output from a run with a different epoch count — re-run to refresh.
lfm.fit_partial(N_EPOCHS)

Epoch: 100%|██████████| 80/80 [13:28<00:00, 10.10s/it]


In [41]:
# NOTE(review): the rendered result shows an 'Unnamed: 0' column, i.e. this CSV
# carries a saved index that is passed straight through to answers.csv —
# confirm that is intended (pd.read_csv(..., index_col=0) would drop it).
test_users = pd.read_csv('data/test_users.csv')
lfm.predict(test_users, path="answers.csv")

  0%|          | 0/16967 [00:00<?, ?it/s]

Unnamed: 0,user_id,target
0,3545210947248911048,"[12046097390037935713, 6838233943148091808, 50..."
1,15271987121288045390,"[6838233943148091808, 12046097390037935713, 50..."
2,15016858616184265932,"[12046097390037935713, 14814427257061788801, 5..."
3,12457244142928722989,"[12046097390037935713, 14814427257061788801, 5..."
4,13339684649926251468,"[15250345250621165867, 1625971115460696067, 91..."
...,...,...
16962,1191875913294598364,"[12046097390037935713, 8773822990269846303, 13..."
16963,3866507700167344338,"[12046097390037935713, 6838233943148091808, 15..."
16964,11434952144484188987,"[12046097390037935713, 5002407858008059043, 68..."
16965,7010426792722803474,"[11229813210509740706, 9104453017196776235, 15..."
