In [5]:
from pandas import read_csv
from torch.optim import SGD
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torch import randperm
from config import DATA_DIR
from src.data_set import RatingsDataset
from src.loss import MiningOutliersLoss
from src.model import MF
from src.runner import Runner
from src.utils import (
    create_dataset,
    mine_outliers_scipy,
    mine_outliers_sklearn,
    mine_outliers_torch,
    DataConverter,
    DataProcessor,
    mean_normalized
)


"""
The DEAM dataset is based on the Arousal-Valence 2D emotional model.
The Valence/Arousal ratings were collected using the Amazon Mechanical Turk service.
Each worker from the crowd was asked to mark their own emotion for the current song on a 2D Arousal/Valence plane.
For more information please read: https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0173392
"""

# Location of the per-rater, song-level static annotations CSV under DATA_DIR.
DF_PATH = (
    f"{DATA_DIR}/DEAM/annotations/annotations per each rater/"
    "song_level/static_annotations_songs_1_2000.csv"
)

def select_n_random(trainset: RatingsDataset, n: int = 100):
    """
    Selects n random data points and their corresponding labels from a dataset

    :param trainset: dataset that supports ``len()`` and indexing with a tensor
        of indices (the same indexing the original implementation relied on).
    :param n: number of samples to draw; slicing caps it at ``len(trainset)``.
    :return: whatever ``trainset[indices]`` yields for the sampled indices.
    """
    perm = randperm(len(trainset))
    # Slice the permutation itself. Slicing the *indexed result* (as before)
    # only trims the container of fields returned by the dataset, so every
    # sample was returned instead of n of them.
    return trainset[perm[:n]]

In [6]:
"""
This block of code trains the MF model on the ratings of the valence axis.
The outliers themselves are mined from the trained model in a later cell.
"""
columns = ["workerID", "SongId", "Valence"]
# `usecols` does not guarantee the order of the returned columns, so rename
# them by name rather than assigning `.columns` positionally, which would
# silently mislabel fields if the CSV layout differed from `columns`.
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns).rename(
    columns=dict(zip(columns, ["user_id", "item_id", "rating"]))
)
original_df = mean_normalized(dataframe=original_df)

data_converter = DataConverter(
        original_df=original_df, n_random_users=10, n_ratings_per_random_user=9
)

valence_model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
)
epochs = 50

criterion = MSELoss()
optimizer = SGD(valence_model.parameters(), lr=5, weight_decay=1e-3)
runner = Runner(
    model=valence_model,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=1000, shuffle=True)
# Sample inputs used only to trace the model graph for TensorBoard.
users, items, ratings = select_n_random(train_set)

with SummaryWriter("runs/DEAM/valence") as writer:
    writer.add_graph(valence_model, (users, items))

    for epoch in range(epochs):
        epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
        print(f"epoch={epoch + 1}, loss={epoch_loss}")

_mean_centralised: 100%|██████████| 17464/17464 [00:16<00:00, 1063.16it/s]
100%|██████████| 18/18 [00:00<00:00, 30.62batch/s, train_loss=0.00692]


epoch=1, loss=0.0814835940810317


100%|██████████| 18/18 [00:00<00:00, 46.13batch/s, train_loss=0.00606]


epoch=2, loss=0.06798636331713157


100%|██████████| 18/18 [00:00<00:00, 33.55batch/s, train_loss=0.00481]


epoch=3, loss=0.058502753491005734


100%|██████████| 18/18 [00:00<00:00, 50.38batch/s, train_loss=0.00455]


epoch=4, loss=0.04980152355118349


100%|██████████| 18/18 [00:00<00:00, 45.85batch/s, train_loss=0.00405]


epoch=5, loss=0.041331847930212745


100%|██████████| 18/18 [00:00<00:00, 34.70batch/s, train_loss=0.00342]


epoch=6, loss=0.03483164211386808


100%|██████████| 18/18 [00:00<00:00, 46.60batch/s, train_loss=0.00283]


epoch=7, loss=0.030199822081555534


100%|██████████| 18/18 [00:00<00:00, 32.28batch/s, train_loss=0.00268]


epoch=8, loss=0.027198234273208183


100%|██████████| 18/18 [00:00<00:00, 52.22batch/s, train_loss=0.00227]


epoch=9, loss=0.02461367974728884


100%|██████████| 18/18 [00:00<00:00, 49.18batch/s, train_loss=0.00219]


epoch=10, loss=0.02250662789103787


100%|██████████| 18/18 [00:00<00:00, 31.82batch/s, train_loss=0.00214]


epoch=11, loss=0.020644006363536478


100%|██████████| 18/18 [00:00<00:00, 49.06batch/s, train_loss=0.00164] 


epoch=12, loss=0.018736409624561075


100%|██████████| 18/18 [00:00<00:00, 32.76batch/s, train_loss=0.00175] 


epoch=13, loss=0.017561593937099197


100%|██████████| 18/18 [00:00<00:00, 50.03batch/s, train_loss=0.0016]  


epoch=14, loss=0.01633946739228624


100%|██████████| 18/18 [00:00<00:00, 34.71batch/s, train_loss=0.00134] 


epoch=15, loss=0.015246117658993828


100%|██████████| 18/18 [00:00<00:00, 45.91batch/s, train_loss=0.00123] 


epoch=16, loss=0.014375878025263222


100%|██████████| 18/18 [00:00<00:00, 48.19batch/s, train_loss=0.00147] 


epoch=17, loss=0.01384298129331334


100%|██████████| 18/18 [00:00<00:00, 35.17batch/s, train_loss=0.00126] 


epoch=18, loss=0.013119633222099676


100%|██████████| 18/18 [00:00<00:00, 45.74batch/s, train_loss=0.00122] 


epoch=19, loss=0.012568821118196426


100%|██████████| 18/18 [00:00<00:00, 33.53batch/s, train_loss=0.00112] 


epoch=20, loss=0.01208951144528303


100%|██████████| 18/18 [00:00<00:00, 46.08batch/s, train_loss=0.00142] 


epoch=21, loss=0.01193111716904795


100%|██████████| 18/18 [00:00<00:00, 32.82batch/s, train_loss=0.00112] 


epoch=22, loss=0.01138399215144801


100%|██████████| 18/18 [00:00<00:00, 44.91batch/s, train_loss=0.000928]


epoch=23, loss=0.010958327877177227


100%|██████████| 18/18 [00:00<00:00, 43.47batch/s, train_loss=0.000964]


epoch=24, loss=0.010754962533712387


100%|██████████| 18/18 [00:00<00:00, 31.86batch/s, train_loss=0.0011]  


epoch=25, loss=0.010625270753883713


100%|██████████| 18/18 [00:00<00:00, 43.43batch/s, train_loss=0.000939]


epoch=26, loss=0.010338973437513254


100%|██████████| 18/18 [00:00<00:00, 31.02batch/s, train_loss=0.0011]  


epoch=27, loss=0.01027334788172684


100%|██████████| 18/18 [00:00<00:00, 42.81batch/s, train_loss=0.000955]


epoch=28, loss=0.010026671610046379


100%|██████████| 18/18 [00:00<00:00, 43.79batch/s, train_loss=0.00109] 


epoch=29, loss=0.009983128501082156


100%|██████████| 18/18 [00:00<00:00, 32.30batch/s, train_loss=0.000839]


epoch=30, loss=0.009673364326196456


100%|██████████| 18/18 [00:00<00:00, 42.19batch/s, train_loss=0.00112] 


epoch=31, loss=0.009761239019972322


100%|██████████| 18/18 [00:00<00:00, 30.62batch/s, train_loss=0.00103] 


epoch=32, loss=0.009600785347206068


100%|██████████| 18/18 [00:00<00:00, 40.87batch/s, train_loss=0.00106] 


epoch=33, loss=0.00952313748905805


100%|██████████| 18/18 [00:00<00:00, 29.21batch/s, train_loss=0.00102] 


epoch=34, loss=0.009407324361672038


100%|██████████| 18/18 [00:00<00:00, 41.78batch/s, train_loss=0.000912]


epoch=35, loss=0.009260369393584529


100%|██████████| 18/18 [00:00<00:00, 43.11batch/s, train_loss=0.000859]


epoch=36, loss=0.00911826607143836


100%|██████████| 18/18 [00:00<00:00, 32.11batch/s, train_loss=0.000804]


epoch=37, loss=0.00903819609333892


100%|██████████| 18/18 [00:00<00:00, 43.80batch/s, train_loss=0.000786]


epoch=38, loss=0.008988331359323613


100%|██████████| 18/18 [00:00<00:00, 26.90batch/s, train_loss=0.000777]


epoch=39, loss=0.00890268324439276


100%|██████████| 18/18 [00:00<00:00, 43.79batch/s, train_loss=0.000967]


epoch=40, loss=0.00899173850014752


100%|██████████| 18/18 [00:00<00:00, 25.02batch/s, train_loss=0.00102] 


epoch=41, loss=0.008974533970282826


100%|██████████| 18/18 [00:00<00:00, 31.45batch/s, train_loss=0.000903]


epoch=42, loss=0.008864385681139432


100%|██████████| 18/18 [00:00<00:00, 32.17batch/s, train_loss=0.0008]  


epoch=43, loss=0.008722887440924181


100%|██████████| 18/18 [00:00<00:00, 26.30batch/s, train_loss=0.000811]


epoch=44, loss=0.008691942779058154


100%|██████████| 18/18 [00:00<00:00, 33.40batch/s, train_loss=0.000971]


epoch=45, loss=0.008768643032234928


100%|██████████| 18/18 [00:00<00:00, 25.93batch/s, train_loss=0.000777]


epoch=46, loss=0.008587183776960477


100%|██████████| 18/18 [00:00<00:00, 34.52batch/s, train_loss=0.00083] 


epoch=47, loss=0.008595360188923158


100%|██████████| 18/18 [00:00<00:00, 34.49batch/s, train_loss=0.000944]


epoch=48, loss=0.008636535273024321


100%|██████████| 18/18 [00:00<00:00, 25.89batch/s, train_loss=0.000779]


epoch=49, loss=0.00852047487307972


100%|██████████| 18/18 [00:00<00:00, 31.07batch/s, train_loss=0.00088] 

epoch=50, loss=0.0085368808904925





In [7]:
"""
This block of code trains the MF model on the ratings of the arousal axis.
The outliers themselves are mined from the trained model in a later cell.
"""
columns = ["workerID", "SongId", "Arousal"]
# `usecols` does not guarantee the order of the returned columns, so rename
# them by name rather than assigning `.columns` positionally, which would
# silently mislabel fields if the CSV layout differed from `columns`.
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns).rename(
    columns=dict(zip(columns, ["user_id", "item_id", "rating"]))
)
original_df = mean_normalized(dataframe=original_df)

data_converter = DataConverter(
        original_df=original_df, n_random_users=10, n_ratings_per_random_user=9
)

arousal_model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
)
# Single source of truth for the epoch count: it is handed to the Runner and
# also drives the training loop below.
epochs = 50

criterion = MSELoss()
optimizer = SGD(arousal_model.parameters(), lr=5, weight_decay=1e-3)
runner = Runner(
    model=arousal_model,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=1000, shuffle=True)
# Sample inputs used only to trace the model graph for TensorBoard.
users, items, ratings = select_n_random(train_set)

with SummaryWriter("runs/DEAM/arousal") as writer:
    writer.add_graph(arousal_model, (users, items))

    for epoch in range(epochs):
        epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
        print(f"epoch={epoch + 1}, loss={epoch_loss}")

_mean_centralised: 100%|██████████| 17464/17464 [00:14<00:00, 1164.34it/s]
100%|██████████| 18/18 [00:00<00:00, 30.84batch/s, train_loss=0.00685]


epoch=1, loss=0.07889618623988293


100%|██████████| 18/18 [00:00<00:00, 47.26batch/s, train_loss=0.00564]


epoch=2, loss=0.06499546795717646


100%|██████████| 18/18 [00:00<00:00, 31.14batch/s, train_loss=0.00512]


epoch=3, loss=0.05478513373192467


100%|██████████| 18/18 [00:00<00:00, 44.69batch/s, train_loss=0.00477]


epoch=4, loss=0.04448580438307477


100%|██████████| 18/18 [00:00<00:00, 45.92batch/s, train_loss=0.00356]


epoch=5, loss=0.03762155086512169


100%|██████████| 18/18 [00:00<00:00, 32.44batch/s, train_loss=0.00312]


epoch=6, loss=0.03381245666763842


100%|██████████| 18/18 [00:00<00:00, 47.05batch/s, train_loss=0.00272]


epoch=7, loss=0.030447905169497327


100%|██████████| 18/18 [00:00<00:00, 30.34batch/s, train_loss=0.00309]


epoch=8, loss=0.02797248597782011


100%|██████████| 18/18 [00:00<00:00, 45.30batch/s, train_loss=0.00245]


epoch=9, loss=0.024940464426033763


100%|██████████| 18/18 [00:00<00:00, 46.24batch/s, train_loss=0.00251]


epoch=10, loss=0.022814417329936255


100%|██████████| 18/18 [00:00<00:00, 29.70batch/s, train_loss=0.00205]


epoch=11, loss=0.020702147776469428


100%|██████████| 18/18 [00:00<00:00, 40.50batch/s, train_loss=0.002]   


epoch=12, loss=0.01913249135921148


100%|██████████| 18/18 [00:00<00:00, 31.53batch/s, train_loss=0.00168] 


epoch=13, loss=0.017612296510259164


100%|██████████| 18/18 [00:00<00:00, 44.76batch/s, train_loss=0.00166] 


epoch=14, loss=0.01643942211250966


100%|██████████| 18/18 [00:00<00:00, 32.42batch/s, train_loss=0.00142] 


epoch=15, loss=0.015332354153321537


100%|██████████| 18/18 [00:00<00:00, 44.51batch/s, train_loss=0.00142] 


epoch=16, loss=0.014498798413181992


100%|██████████| 18/18 [00:00<00:00, 47.56batch/s, train_loss=0.00127] 


epoch=17, loss=0.013664679345671447


100%|██████████| 18/18 [00:00<00:00, 30.88batch/s, train_loss=0.00109] 


epoch=18, loss=0.012933039145779522


100%|██████████| 18/18 [00:00<00:00, 43.90batch/s, train_loss=0.00125] 


epoch=19, loss=0.012485657034152683


100%|██████████| 18/18 [00:00<00:00, 31.31batch/s, train_loss=0.00137] 


epoch=20, loss=0.012112686167339985


100%|██████████| 18/18 [00:00<00:00, 40.73batch/s, train_loss=0.0013]  


epoch=21, loss=0.011670205339628002


100%|██████████| 18/18 [00:00<00:00, 45.07batch/s, train_loss=0.00122] 


epoch=22, loss=0.0112654181131387


100%|██████████| 18/18 [00:00<00:00, 31.46batch/s, train_loss=0.00108] 


epoch=23, loss=0.01087668497381658


100%|██████████| 18/18 [00:00<00:00, 48.41batch/s, train_loss=0.00108] 


epoch=24, loss=0.010597760508637134


100%|██████████| 18/18 [00:00<00:00, 31.64batch/s, train_loss=0.00101] 


epoch=25, loss=0.010339698937610598


100%|██████████| 18/18 [00:00<00:00, 46.89batch/s, train_loss=0.00103] 


epoch=26, loss=0.010135833440274538


100%|██████████| 18/18 [00:00<00:00, 31.50batch/s, train_loss=0.00123] 


epoch=27, loss=0.010100816605215899


100%|██████████| 18/18 [00:00<00:00, 45.44batch/s, train_loss=0.000931]


epoch=28, loss=0.009740626135565316


100%|██████████| 18/18 [00:00<00:00, 41.99batch/s, train_loss=0.00107] 


epoch=29, loss=0.009703272086188252


100%|██████████| 18/18 [00:00<00:00, 32.28batch/s, train_loss=0.0011]  


epoch=30, loss=0.009606495914368855


100%|██████████| 18/18 [00:00<00:00, 42.50batch/s, train_loss=0.00104] 


epoch=31, loss=0.009411596840361825


100%|██████████| 18/18 [00:00<00:00, 32.26batch/s, train_loss=0.0009]  


epoch=32, loss=0.009245627500950645


100%|██████████| 18/18 [00:00<00:00, 40.97batch/s, train_loss=0.0009]  


epoch=33, loss=0.009156464362079916


100%|██████████| 18/18 [00:00<00:00, 30.82batch/s, train_loss=0.00101] 


epoch=34, loss=0.009165615738836865


100%|██████████| 18/18 [00:00<00:00, 44.85batch/s, train_loss=0.000925]


epoch=35, loss=0.009015929130655761


100%|██████████| 18/18 [00:00<00:00, 42.88batch/s, train_loss=0.000932]


epoch=36, loss=0.008961428360710936


100%|██████████| 18/18 [00:00<00:00, 31.07batch/s, train_loss=0.0011]  


epoch=37, loss=0.00899576840493223


100%|██████████| 18/18 [00:00<00:00, 45.62batch/s, train_loss=0.000802]


epoch=38, loss=0.00873401103488805


100%|██████████| 18/18 [00:00<00:00, 31.31batch/s, train_loss=0.00103] 


epoch=39, loss=0.008851888255952498


100%|██████████| 18/18 [00:00<00:00, 46.74batch/s, train_loss=0.00109] 


epoch=40, loss=0.008819737506902606


100%|██████████| 18/18 [00:00<00:00, 32.44batch/s, train_loss=0.000975]


epoch=41, loss=0.008706249832461457


100%|██████████| 18/18 [00:00<00:00, 26.54batch/s, train_loss=0.000841]


epoch=42, loss=0.00856932501242049


100%|██████████| 18/18 [00:00<00:00, 31.66batch/s, train_loss=0.000732]


epoch=43, loss=0.008448995469494418


100%|██████████| 18/18 [00:00<00:00, 24.48batch/s, train_loss=0.00085] 


epoch=44, loss=0.008494445961842899


100%|██████████| 18/18 [00:00<00:00, 30.33batch/s, train_loss=0.000835]


epoch=45, loss=0.008457641193044747


100%|██████████| 18/18 [00:00<00:00, 24.24batch/s, train_loss=0.00107] 


epoch=46, loss=0.008588205877193906


100%|██████████| 18/18 [00:00<00:00, 30.87batch/s, train_loss=0.000733]


epoch=47, loss=0.008312527638049764


100%|██████████| 18/18 [00:00<00:00, 32.59batch/s, train_loss=0.000901]


epoch=48, loss=0.008427848756420914


100%|██████████| 18/18 [00:00<00:00, 27.30batch/s, train_loss=0.00085] 


epoch=49, loss=0.00836037932503094


100%|██████████| 18/18 [00:00<00:00, 32.79batch/s, train_loss=0.000882]

epoch=50, loss=0.008336783372430594





In [8]:
# NOTE(review): `data_converter` is rebound by the arousal training cell, so
# the valence model is mined against that converter as well — confirm this is
# intended (e.g. that both converters encode users/items identically).
valence_outliers = mine_outliers_scipy(model=valence_model, data_converter=data_converter)
arousal_outliers = mine_outliers_scipy(model=arousal_model, data_converter=data_converter)

items_group_by_users = data_converter.original_df.groupby("user_id")
# Combined per-user score: valence distance plus arousal distance.
combined_outliers = {
    user_id: valence_dist + arousal_outliers[user_id]
    for user_id, valence_dist in valence_outliers.items()
}

# Report users ordered by combined distance, ascending.
combined_outliers = dict(sorted(combined_outliers.items(), key=lambda item: item[1]))
for user_id, dist in combined_outliers.items():
    number_of_items = len(items_group_by_users.get_group(user_id))
    print(f"user: {user_id}, dist: {dist}, #items: {number_of_items}")

user: ff18a27328ffd40ef52b7ebb7a0ded94, dist: -74.27330861100704, #items: 20
user: 19fee46f2810f34a8b69a7768d897a59, dist: -53.744675938102034, #items: 1
user: random_guy_191, dist: -47.90946700519498, #items: 9
user: random_guy_190, dist: -46.617390778418574, #items: 9
user: random_guy_188, dist: -43.83338179209181, #items: 9
user: 3111e02887b600ee085c72c0a3df33e8, dist: -36.50610478373544, #items: 1
user: random_guy_189, dist: -35.82082827645257, #items: 9
user: fd5b08ce362d855ca9152a894348130c, dist: -35.6837828862546, #items: 222
user: 2f790705ae66e70e81cc0f11ce0f4b9b, dist: -32.06156850854283, #items: 2
user: b8ef6a913a63225faafd661ee2e1a7c0, dist: -23.481714259168395, #items: 10
user: e105c200f413d7b2c5850c0df4b9687e, dist: -23.32396391727844, #items: 2
user: 7cecbffe1da5ae974952db6c13695afe, dist: -19.721673771900445, #items: 428
user: 027cefa6afc040448d29558b3175cdc1, dist: -9.499037434780577, #items: 9
user: 8eb1abd1acca601d1e23e85c69b1742a, dist: -6.8273288473123115, #items: 