In [1]:
from pandas import read_csv
from torch import manual_seed
from torch.optim import SGD
from torch.utils.data import DataLoader

from config import DATA_DIR
from src.loss import MiningOutliersLoss
from src.model import MF
from src.runner import Runner
from src.utils import create_dataset, mine_outliers, DataConverter, DataProcessor

# Location of the MovieLens ratings file, relative to the configured data root.
DF_PATH = "{}/MovieLens/ratings.csv".format(DATA_DIR)

In [29]:
# Map the raw MovieLens header to the project's snake_case schema.
# Rename by name, not by position: `read_csv(usecols=...)` returns columns in
# *file* order, not in the order they are listed, so positional relabelling
# silently mislabels data if the file's column order ever differs.
RATING_COLUMNS = {"userId": "user_id", "movieId": "item_id", "rating": "rating"}
columns = list(RATING_COLUMNS.values())

# NOTE(review): presumably DataConverter injects synthetic "random" users so the
# outlier mining can be sanity-checked against known outliers — confirm in src.utils.
N_RANDOM_USERS = 1
N_RATINGS_PER_RANDOM_USER = 50

# New frame with canonical column order, independent of the CSV's layout.
original_df = read_csv(
    DF_PATH, skipinitialspace=True, usecols=list(RATING_COLUMNS)
).rename(columns=RATING_COLUMNS)[columns]

data_converter = DataConverter(
    original_df=original_df,
    n_random_users=N_RANDOM_USERS,
    n_ratings_per_random_user=N_RATINGS_PER_RANDOM_USER,
)
data_processor = DataProcessor(original_df=data_converter.original_df)

In [None]:
# Training hyperparameters, named so they can be tuned in one place.
# NOTE(review): lr=5 is very large for plain SGD — presumably tuned for
# MiningOutliersLoss; confirm it is intentional.
SEED = 42
LEARNING_RATE = 5
WEIGHT_DECAY = 1e-5
BATCH_SIZE = 10_000
epochs = 10

# Seed torch before building the model so weight initialisation and the
# DataLoader's shuffle order are reproducible across Restart & Run All.
manual_seed(SEED)

model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
)

criterion = MiningOutliersLoss(data_converter=data_converter, data_processor=data_processor)
optimizer = SGD(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
runner = Runner(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
)

# NOTE(review): in the saved run an epoch was ~25k batches at ~2.5 batch/s
# (~2.8 h per epoch), and the cell was interrupted during the first epoch, so
# `mine_outliers` below never executed. Consider fewer epochs, a larger batch,
# or caching the trained model before a full run.
train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
for epoch in range(epochs):
    epoch_loss = runner.train(train_loader=train_load)
    print(f"epoch={epoch + 1}, loss={epoch_loss}")

outliers = mine_outliers(model=model, data_converter=data_converter)

 32%|███▏      | 7882/25001 [53:07<1:55:21,  2.47batch/s, train_loss=217]  


KeyboardInterrupt: 

In [28]:
# List mined users ordered by the score returned by `mine_outliers` (ascending,
# so the most negative score prints first), with each user's rating count.
# NOTE(review): depends on `outliers` from the training cell above; in the saved
# notebook that cell was interrupted, so re-run it to completion first.
items_group_by_users = data_converter.original_df.groupby("user_id")
# Per-user rating counts computed once, instead of materialising every
# user's sub-frame via get_group just to take its length.
items_per_user = items_group_by_users.size()
outliers = dict(sorted(outliers.items(), key=lambda user_and_dist: user_and_dist[1]))
for user_id, dist in outliers.items():
    number_of_items = items_per_user[user_id]
    print(f"user: {user_id}, dist: {dist}, #items: {number_of_items}")

user: 620, dist: -54.56575012207031, #items: 43
user: 363, dist: -42.14396286010742, #items: 21
user: 74, dist: -41.66022872924805, #items: 22
user: 250, dist: -40.431488037109375, #items: 22
user: 299, dist: -38.344688415527344, #items: 23
user: 205, dist: -38.11968231201172, #items: 26
user: 533, dist: -37.82539367675781, #items: 25
user: 463, dist: -37.01531219482422, #items: 21
user: 32, dist: -36.62696075439453, #items: 33
user: 602, dist: -36.610984802246094, #items: 20
user: 214, dist: -35.97991180419922, #items: 35
user: 22, dist: -35.353851318359375, #items: 22
user: 10, dist: -35.29178237915039, #items: 53
user: 328, dist: -35.174888610839844, #items: 67
user: 644, dist: -35.11091613769531, #items: 36
user: 352, dist: -34.93895721435547, #items: 33
user: 539, dist: -34.55622863769531, #items: 36
user: 341, dist: -34.1892204284668, #items: 41
user: 330, dist: -34.03994369506836, #items: 22
user: 200, dist: -33.48986053466797, #items: 20
user: 604, dist: -33.486351013183594, #i