In [3]:
from pandas import read_csv
from torch.optim import SGD
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torch import randperm
from config import DATA_DIR
from src.data_set import RatingsDataset
from src.loss import MiningOutliersLoss
from src.model import MF
from src.runner import Runner
from src.utils import create_dataset, mine_outliers, DataConverter, DataProcessor, mean_centralised


"""
The DEAM dataset is based on the Arousal-Valence 2D emotional model.
The Valence/Arousal ratings were collected using the Amazon Mechanical Turk service.
Each worker in the crowd was asked to mark their own emotion for the current song on the 2D Arousal/Valence plane.
For more information please read: https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0173392
"""

# Location of the per-rater, song-level static annotations CSV (songs 1-2000)
# inside the DEAM folder under DATA_DIR.
DF_PATH = (
    f"{DATA_DIR}"
    "/DEAM/annotations/annotations per each rater/"
    "song_level/static_annotations_songs_1_2000.csv"
)

def select_n_random(trainset: "RatingsDataset", n: int = 100):
    """
    Selects n random data points and their corresponding labels from a dataset

    :param trainset: dataset that supports ``len()`` and indexing with a
        tensor of indices.
    :param n: number of samples to draw without replacement; when it exceeds
        ``len(trainset)`` every sample is returned (in random order).
    :return: whatever ``trainset[indices]`` yields for the sampled indices.
    :raises ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    perm = randperm(len(trainset))
    # Slice the permutation, not the indexing result: callers unpack the
    # result into (users, items, ratings), so slicing it afterwards would trim
    # that container instead of limiting the number of samples.
    return trainset[perm[:n]]

In [4]:
"""
This block of code trains the MF model on the valence ratings; the outliers are mined from the trained models in a later cell.
"""
# Load the per-rater valence ratings as (user_id, item_id, rating) rows.
columns = ["workerID", "SongId", "Valence"]
# `usecols` ignores the order of the list it is given, so index with `columns`
# to pin the column order before the positional rename on the next line.
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)[columns]
original_df.columns = ["user_id", "item_id", "rating"]
original_df = mean_centralised(dataframe=original_df)

data_converter = DataConverter(
    original_df=original_df, n_random_users=10, n_ratings_per_random_user=9
)
# Valence-specific handle: the arousal cell rebinds `data_converter`, so this
# name is the only way to reach the converter this model was trained with
# once that cell has run.
valence_data_converter = data_converter

valence_model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
)
epochs = 50

criterion = MSELoss()
optimizer = SGD(valence_model.parameters(), lr=5, weight_decay=1e-3)
runner = Runner(
    model=valence_model,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=1000, shuffle=True)
# Random sample of inputs; in this cell only `users` and `items` are used,
# as the example input for `add_graph`.
users, items, ratings = select_n_random(train_set)

with SummaryWriter("runs/DEAM/valence") as writer:
    writer.add_graph(valence_model, (users, items))

    # One `runner.train` call per epoch; the writer is passed through to it.
    for epoch in range(epochs):
        epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
        print(f"epoch={epoch + 1}, loss={epoch_loss}")

_mean_centralised: 100%|██████████| 17464/17464 [00:07<00:00, 2212.91it/s]
100%|██████████| 18/18 [00:01<00:00, 17.59batch/s, train_loss=0.00658]


epoch=1, loss=0.08053845308403676


100%|██████████| 18/18 [00:00<00:00, 35.10batch/s, train_loss=0.00632]


epoch=2, loss=0.06756571026829605


100%|██████████| 18/18 [00:00<00:00, 46.90batch/s, train_loss=0.0052] 


epoch=3, loss=0.0590124578570631


100%|██████████| 18/18 [00:00<00:00, 36.72batch/s, train_loss=0.00471]


epoch=4, loss=0.05122729203503054


100%|██████████| 18/18 [00:00<00:00, 130.18batch/s, train_loss=0.00414]


epoch=5, loss=0.04297406689957161


100%|██████████| 18/18 [00:00<00:00, 128.22batch/s, train_loss=0.00329]


epoch=6, loss=0.03576669341691564


100%|██████████| 18/18 [00:00<00:00, 65.89batch/s, train_loss=0.00285]


epoch=7, loss=0.030837352539227758


100%|██████████| 18/18 [00:00<00:00, 135.47batch/s, train_loss=0.00254]


epoch=8, loss=0.027356661591719206


100%|██████████| 18/18 [00:00<00:00, 128.33batch/s, train_loss=0.00234]


epoch=9, loss=0.024586266924758255


100%|██████████| 18/18 [00:00<00:00, 134.13batch/s, train_loss=0.0023] 


epoch=10, loss=0.022441622821002233


100%|██████████| 18/18 [00:00<00:00, 89.79batch/s, train_loss=0.00181]


epoch=11, loss=0.02025371558528515


100%|██████████| 18/18 [00:00<00:00, 135.05batch/s, train_loss=0.00177] 


epoch=12, loss=0.01866542534836793


100%|██████████| 18/18 [00:00<00:00, 120.57batch/s, train_loss=0.00168] 


epoch=13, loss=0.017365150405611804


100%|██████████| 18/18 [00:00<00:00, 92.98batch/s, train_loss=0.00186]


epoch=14, loss=0.01642101346542689


100%|██████████| 18/18 [00:00<00:00, 132.44batch/s, train_loss=0.00147] 


epoch=15, loss=0.015228533875640979


100%|██████████| 18/18 [00:00<00:00, 133.71batch/s, train_loss=0.00151] 


epoch=16, loss=0.014484790862467317


100%|██████████| 18/18 [00:00<00:00, 82.84batch/s, train_loss=0.00131]  


epoch=17, loss=0.013666090992813934


100%|██████████| 18/18 [00:00<00:00, 123.22batch/s, train_loss=0.00127] 


epoch=18, loss=0.013062221970583989


100%|██████████| 18/18 [00:00<00:00, 130.74batch/s, train_loss=0.00145] 


epoch=19, loss=0.012714754198002042


100%|██████████| 18/18 [00:00<00:00, 126.43batch/s, train_loss=0.0013]  


epoch=20, loss=0.012162771804022877


100%|██████████| 18/18 [00:00<00:00, 89.80batch/s, train_loss=0.000968]


epoch=21, loss=0.011552533589115212


100%|██████████| 18/18 [00:00<00:00, 126.18batch/s, train_loss=0.00112] 


epoch=22, loss=0.01135573537599309


100%|██████████| 18/18 [00:00<00:00, 116.29batch/s, train_loss=0.00126] 


epoch=23, loss=0.01118507367630728


100%|██████████| 18/18 [00:00<00:00, 89.26batch/s, train_loss=0.00133]  


epoch=24, loss=0.01098383079603691


100%|██████████| 18/18 [00:00<00:00, 125.22batch/s, train_loss=0.00126] 


epoch=25, loss=0.010729852475629388


100%|██████████| 18/18 [00:00<00:00, 126.47batch/s, train_loss=0.00101] 


epoch=26, loss=0.010340175223909991


100%|██████████| 18/18 [00:00<00:00, 113.79batch/s, train_loss=0.000999]


epoch=27, loss=0.01016490375081124


100%|██████████| 18/18 [00:00<00:00, 87.19batch/s, train_loss=0.000928]


epoch=28, loss=0.009970848901069551


100%|██████████| 18/18 [00:00<00:00, 127.86batch/s, train_loss=0.00103] 


epoch=29, loss=0.009910300728646427


100%|██████████| 18/18 [00:00<00:00, 126.47batch/s, train_loss=0.000956]


epoch=30, loss=0.009746217223496214


100%|██████████| 18/18 [00:00<00:00, 89.32batch/s, train_loss=0.0011] 


epoch=31, loss=0.009733267645973591


100%|██████████| 18/18 [00:00<00:00, 126.02batch/s, train_loss=0.000883]


epoch=32, loss=0.009476388634542265


100%|██████████| 18/18 [00:00<00:00, 101.79batch/s, train_loss=0.000956]


epoch=33, loss=0.00943965810601892


100%|██████████| 18/18 [00:00<00:00, 88.99batch/s, train_loss=0.00105] 


epoch=34, loss=0.009426988211672227


100%|██████████| 18/18 [00:00<00:00, 126.92batch/s, train_loss=0.000896]


epoch=35, loss=0.009247743907818295


100%|██████████| 18/18 [00:00<00:00, 125.65batch/s, train_loss=0.000938]


epoch=36, loss=0.009195164889204802


100%|██████████| 18/18 [00:00<00:00, 126.63batch/s, train_loss=0.000965]


epoch=37, loss=0.009168972294253133


100%|██████████| 18/18 [00:00<00:00, 90.31batch/s, train_loss=0.000901]


epoch=38, loss=0.009059180966891104


100%|██████████| 18/18 [00:00<00:00, 124.31batch/s, train_loss=0.000825]


epoch=39, loss=0.008947177203123319


100%|██████████| 18/18 [00:00<00:00, 128.20batch/s, train_loss=0.000921]


epoch=40, loss=0.00895131824991333


100%|██████████| 18/18 [00:00<00:00, 70.32batch/s, train_loss=0.00107]


epoch=41, loss=0.009019364911941845


100%|██████████| 18/18 [00:00<00:00, 89.09batch/s, train_loss=0.000877]


epoch=42, loss=0.008820980685605037


100%|██████████| 18/18 [00:00<00:00, 89.92batch/s, train_loss=0.000933]


epoch=43, loss=0.008809247613491134


100%|██████████| 18/18 [00:00<00:00, 89.21batch/s, train_loss=0.000903]


epoch=44, loss=0.008742301880345018


100%|██████████| 18/18 [00:00<00:00, 62.63batch/s, train_loss=0.000854]


epoch=45, loss=0.008687706102855798


100%|██████████| 18/18 [00:00<00:00, 91.21batch/s, train_loss=0.000883]


epoch=46, loss=0.008659508456929926


100%|██████████| 18/18 [00:00<00:00, 83.39batch/s, train_loss=0.000885]


epoch=47, loss=0.00865069864814032


100%|██████████| 18/18 [00:00<00:00, 66.15batch/s, train_loss=0.000959]


epoch=48, loss=0.008663185936449238


100%|██████████| 18/18 [00:00<00:00, 93.62batch/s, train_loss=0.000829]


epoch=49, loss=0.008515949558480122


100%|██████████| 18/18 [00:00<00:00, 88.51batch/s, train_loss=0.000734]

epoch=50, loss=0.008435969171218494





In [6]:
"""
This block of code trains the MF model on the arousal ratings; the outliers are mined from the trained models in a later cell.
"""
# Load the per-rater arousal ratings as (user_id, item_id, rating) rows.
columns = ["workerID", "SongId", "Arousal"]
# `usecols` ignores the order of the list it is given, so index with `columns`
# to pin the column order before the positional rename on the next line.
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)[columns]
original_df.columns = ["user_id", "item_id", "rating"]
original_df = mean_centralised(dataframe=original_df)

data_converter = DataConverter(
    original_df=original_df, n_random_users=10, n_ratings_per_random_user=9
)
# Arousal-specific handle, so later cells do not have to rely on the shared
# `data_converter` name that every training cell rebinds.
arousal_data_converter = data_converter

arousal_model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
)
# Single definition: it feeds both the Runner and the training loop below.
epochs = 50

criterion = MSELoss()
optimizer = SGD(arousal_model.parameters(), lr=5, weight_decay=1e-3)
runner = Runner(
    model=arousal_model,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=1000, shuffle=True)
# Random sample of inputs; in this cell only `users` and `items` are used,
# as the example input for `add_graph`.
users, items, ratings = select_n_random(train_set)

with SummaryWriter("runs/DEAM/arousal") as writer:
    writer.add_graph(arousal_model, (users, items))

    # One `runner.train` call per epoch; the writer is passed through to it.
    for epoch in range(epochs):
        epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
        print(f"epoch={epoch + 1}, loss={epoch_loss}")

_mean_centralised: 100%|██████████| 17464/17464 [00:14<00:00, 1194.54it/s]
100%|██████████| 18/18 [00:02<00:00,  7.17batch/s, train_loss=0.00701]


epoch=1, loss=0.08052921229493318


100%|██████████| 18/18 [00:00<00:00, 36.93batch/s, train_loss=0.00578]


epoch=2, loss=0.0665833346662969


100%|██████████| 18/18 [00:00<00:00, 30.34batch/s, train_loss=0.00438]


epoch=3, loss=0.055215709966012286


100%|██████████| 18/18 [00:00<00:00, 28.97batch/s, train_loss=0.00429]


epoch=4, loss=0.044405589749236404


100%|██████████| 18/18 [00:00<00:00, 36.54batch/s, train_loss=0.00322]


epoch=5, loss=0.03694505412139617


100%|██████████| 18/18 [00:00<00:00, 35.27batch/s, train_loss=0.00319]


epoch=6, loss=0.03313680183198908


100%|██████████| 18/18 [00:00<00:00, 35.02batch/s, train_loss=0.00305]


epoch=7, loss=0.030072614759744723


100%|██████████| 18/18 [00:00<00:00, 29.32batch/s, train_loss=0.00275]


epoch=8, loss=0.02718027227622077


100%|██████████| 18/18 [00:00<00:00, 39.26batch/s, train_loss=0.00238]


epoch=9, loss=0.02464074395150484


100%|██████████| 18/18 [00:00<00:00, 39.80batch/s, train_loss=0.00214]


epoch=10, loss=0.022535072281042158


100%|██████████| 18/18 [00:00<00:00, 22.88batch/s, train_loss=0.00214]


epoch=11, loss=0.02092130002476248


100%|██████████| 18/18 [00:00<00:00, 31.16batch/s, train_loss=0.00199] 


epoch=12, loss=0.019446505890211043


100%|██████████| 18/18 [00:00<00:00, 35.00batch/s, train_loss=0.00177] 


epoch=13, loss=0.01801234378819001


100%|██████████| 18/18 [00:00<00:00, 33.70batch/s, train_loss=0.00172] 


epoch=14, loss=0.01690285271793496


100%|██████████| 18/18 [00:00<00:00, 28.10batch/s, train_loss=0.0015]  


epoch=15, loss=0.015770653200493822


100%|██████████| 18/18 [00:00<00:00, 41.95batch/s, train_loss=0.00145] 


epoch=16, loss=0.014912321170099375


100%|██████████| 18/18 [00:00<00:00, 39.95batch/s, train_loss=0.00125] 


epoch=17, loss=0.014043218545534981


100%|██████████| 18/18 [00:00<00:00, 29.36batch/s, train_loss=0.00136] 


epoch=18, loss=0.013485987454114837


100%|██████████| 18/18 [00:00<00:00, 35.78batch/s, train_loss=0.00113] 


epoch=19, loss=0.012768196799049307


100%|██████████| 18/18 [00:00<00:00, 37.99batch/s, train_loss=0.00127] 


epoch=20, loss=0.012360191030837998


100%|██████████| 18/18 [00:00<00:00, 28.70batch/s, train_loss=0.00116] 


epoch=21, loss=0.011886979221436953


100%|██████████| 18/18 [00:00<00:00, 38.90batch/s, train_loss=0.00121] 


epoch=22, loss=0.011544935643672944


100%|██████████| 18/18 [00:00<00:00, 40.16batch/s, train_loss=0.00107] 


epoch=23, loss=0.011142786511660485


100%|██████████| 18/18 [00:00<00:00, 41.16batch/s, train_loss=0.0011]  


epoch=24, loss=0.01084591759133425


100%|██████████| 18/18 [00:00<00:00, 25.66batch/s, train_loss=0.00121] 


epoch=25, loss=0.010707250558081945


100%|██████████| 18/18 [00:00<00:00, 39.23batch/s, train_loss=0.00109] 


epoch=26, loss=0.010396727009047674


100%|██████████| 18/18 [00:00<00:00, 40.17batch/s, train_loss=0.00105] 


epoch=27, loss=0.010174616547698148


100%|██████████| 18/18 [00:00<00:00, 28.28batch/s, train_loss=0.000767]


epoch=28, loss=0.009807351991372847


100%|██████████| 18/18 [00:00<00:00, 37.43batch/s, train_loss=0.000916]


epoch=29, loss=0.009753259412002906


100%|██████████| 18/18 [00:00<00:00, 37.84batch/s, train_loss=0.000987]


epoch=30, loss=0.009681299725917273


100%|██████████| 18/18 [00:00<00:00, 37.20batch/s, train_loss=0.000911]


epoch=31, loss=0.009511773929376462


100%|██████████| 18/18 [00:00<00:00, 24.90batch/s, train_loss=0.00119] 


epoch=32, loss=0.009595828999788752


100%|██████████| 18/18 [00:00<00:00, 32.20batch/s, train_loss=0.00105] 


epoch=33, loss=0.00941797518342841


100%|██████████| 18/18 [00:00<00:00, 29.85batch/s, train_loss=0.000883]


epoch=34, loss=0.009215831590581026


100%|██████████| 18/18 [00:00<00:00, 18.75batch/s, train_loss=0.000914]


epoch=35, loss=0.009138573634495373


100%|██████████| 18/18 [00:01<00:00, 14.09batch/s, train_loss=0.000801]


epoch=36, loss=0.009000555558863102


100%|██████████| 18/18 [00:11<00:00,  1.53batch/s, train_loss=0.000844]


epoch=37, loss=0.00895001822204366


100%|██████████| 18/18 [00:03<00:00,  5.11batch/s, train_loss=0.000778]


epoch=38, loss=0.008846968130837277


100%|██████████| 18/18 [00:01<00:00, 12.27batch/s, train_loss=0.00104] 


epoch=39, loss=0.00899892393550718


100%|██████████| 18/18 [00:00<00:00, 21.40batch/s, train_loss=0.00103] 


epoch=40, loss=0.008920032724684328


100%|██████████| 18/18 [00:01<00:00, 17.93batch/s, train_loss=0.000939]


epoch=41, loss=0.008783408561338158


100%|██████████| 18/18 [00:01<00:00, 10.63batch/s, train_loss=0.00101] 


epoch=42, loss=0.008819652932860793


100%|██████████| 18/18 [00:00<00:00, 22.65batch/s, train_loss=0.000901]


epoch=43, loss=0.008688358242868946


100%|██████████| 18/18 [00:00<00:00, 28.30batch/s, train_loss=0.000905]


epoch=44, loss=0.008651415245627668


100%|██████████| 18/18 [00:00<00:00, 22.88batch/s, train_loss=0.000916]


epoch=45, loss=0.008624133140063888


100%|██████████| 18/18 [00:00<00:00, 30.02batch/s, train_loss=0.000836]


epoch=46, loss=0.008528756508973532


100%|██████████| 18/18 [00:00<00:00, 29.68batch/s, train_loss=0.000772]


epoch=47, loss=0.00846604097696418


100%|██████████| 18/18 [00:00<00:00, 30.18batch/s, train_loss=0.000874]


epoch=48, loss=0.008489833172585561


100%|██████████| 18/18 [00:00<00:00, 25.13batch/s, train_loss=0.00107] 


epoch=49, loss=0.008631662661525747


100%|██████████| 18/18 [00:00<00:00, 27.84batch/s, train_loss=0.000992]

epoch=50, loss=0.008527898675375466





In [7]:
# Mine per-user values from both trained models and rank users by their sum.
# NOTE(review): `data_converter` is whatever the most recently executed
# training cell bound. Run top to bottom, that is the arousal converter, so
# the valence model is mined with a converter it was not trained with.
# Confirm whether `mine_outliers` needs the matching converter.
valence_outliers = mine_outliers(model=valence_model, data_converter=data_converter)
arousal_outliers = mine_outliers(model=arousal_model, data_converter=data_converter)

items_group_by_users = data_converter.original_df.groupby("user_id")
# Count the rows per user once instead of building a sub-frame per user.
items_per_user = items_group_by_users.size()

# Sum the two per-axis values for every user reported for valence; a user
# missing from the arousal result raises KeyError, as before.
combined_outliers = {}
for user_id, valence_dist in valence_outliers.items():
    arousal_dist = arousal_outliers[user_id]
    combined_outliers[user_id] = valence_dist + arousal_dist

# Ascending order of the combined value.
combined_outliers = dict(sorted(combined_outliers.items(), key=lambda item: item[1]))
for user_id, combined_dist in combined_outliers.items():
    number_of_items = items_per_user.loc[user_id]
    print(f"user: {user_id}, dist: {combined_dist}, #items: {number_of_items}")

user: ff18a27328ffd40ef52b7ebb7a0ded94, dist: -75.23977661132812, #items: 20
user: random_guy_194, dist: -65.95697021484375, #items: 9
user: random_guy_195, dist: -65.27352905273438, #items: 9
user: 19fee46f2810f34a8b69a7768d897a59, dist: -48.07280349731445, #items: 1
user: 2f790705ae66e70e81cc0f11ce0f4b9b, dist: -40.8505859375, #items: 2
user: random_guy_191, dist: -40.778480529785156, #items: 9
user: fd5b08ce362d855ca9152a894348130c, dist: -36.675235748291016, #items: 222
user: 3111e02887b600ee085c72c0a3df33e8, dist: -32.53719711303711, #items: 1
user: random_guy_190, dist: -28.196727752685547, #items: 9
user: random_guy_192, dist: -28.08419418334961, #items: 9
user: random_guy_193, dist: -26.771160125732422, #items: 9
user: e105c200f413d7b2c5850c0df4b9687e, dist: -22.180830001831055, #items: 2
user: 7cecbffe1da5ae974952db6c13695afe, dist: -22.118755340576172, #items: 428
user: b8ef6a913a63225faafd661ee2e1a7c0, dist: -15.897381782531738, #items: 10
user: random_guy_188, dist: -15.346