In [1]:
import torch
from pandas import read_csv
from torch.optim import SGD
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torch import randperm
from config import DATA_DIR, MODELS_DIR
from src.data_set import RatingsDataset
from src.model import MF
from src.runner import Runner
from src.utils import (
    create_dataset,
    mine_outliers_scipy,
    DataConverter,
)
from src.consistency import direct_calculation


"""
The Deam dataset is based on Arousal-Valence 2D emotional model.
The Valence/Arousal ratings were collected using Amazon Mechanical Turks service.
Each turk from the collected crowd were asked to mark his own emotion for the current song on a 2D plane, Arousal/Valence.
For more information please read: https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0173392
"""

DF_PATH = f"{DATA_DIR}" \
          f"/DEAM/annotations/annotations per each rater/" \
          f"song_level/static_annotations_songs_1_2000.csv"

def select_n_random(trainset: RatingsDataset):
    """
    Selects n random data points and their corresponding labels from a dataset
    """
    perm = randperm(len(trainset))
    return trainset[perm][:100]

In [2]:
"""
This block analyze raw data consistency using the direct calculation defined by:
consistency += row.rating - row.song.mean() for all rows in dataset
"""
columns = ["workerID", "SongId", "Valence"]
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)
original_df.columns = ["user_id", "item_id", "rating"]
consistency = direct_calculation(data_frame=original_df)
print(f"Raw data consistency according to direct calculation is: \x1b[33m{consistency}\x1b[32m")

  0%|          | 0/17464 [00:00<?, ?it/s]

Raw data consistency according to direct calculation is: [33m2.5579538487363607e-13[32m


In [3]:
"""
This block of code calculates the outliers alongside the valence axis
"""
columns = ["workerID", "SongId", "Valence"]
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)
original_df.columns = ["user_id", "item_id", "rating"]

data_converter = DataConverter(
        original_df=original_df, n_random_users=0, n_ratings_per_random_user=9
)

valence_model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
    include_bias=True
)
epochs = 100

criterion = MSELoss()
optimizer = SGD(valence_model.parameters(), lr=5, weight_decay=1e-3)
runner = Runner(
    model=valence_model,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=1000, shuffle=True)
users, items, ratings = select_n_random(train_set)

with SummaryWriter("runs/DEAM/valence") as writer:
    writer.add_graph(valence_model, (users, items))

    for epoch in range(epochs):
        epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)

        print(f"epoch={epoch + 1}, loss={epoch_loss}")

torch.save(valence_model.state_dict(), f"{MODELS_DIR}/DEAM/raw/valence.pt")

100%|██████████| 18/18 [00:00<00:00, 103.29batch/s, train_loss=0.0122] 


epoch=1, loss=0.17982098235755134


100%|██████████| 18/18 [00:00<00:00, 54.19batch/s, train_loss=0.00933]


epoch=2, loss=0.09419888197964636


100%|██████████| 18/18 [00:00<00:00, 67.95batch/s, train_loss=0.00823]


epoch=3, loss=0.07617476498258524


100%|██████████| 18/18 [00:00<00:00, 73.22batch/s, train_loss=0.00658]


epoch=4, loss=0.06299791464312324


100%|██████████| 18/18 [00:00<00:00, 43.06batch/s, train_loss=0.00585]


epoch=5, loss=0.05227582643360927


100%|██████████| 18/18 [00:00<00:00, 58.88batch/s, train_loss=0.00517]


epoch=6, loss=0.04426340256066158


100%|██████████| 18/18 [00:00<00:00, 51.58batch/s, train_loss=0.00421]


epoch=7, loss=0.038168219979467066


100%|██████████| 18/18 [00:00<00:00, 38.79batch/s, train_loss=0.00435]


epoch=8, loss=0.03458417874369128


100%|██████████| 18/18 [00:00<00:00, 54.35batch/s, train_loss=0.00385]


epoch=9, loss=0.03114878068299129


100%|██████████| 18/18 [00:00<00:00, 50.31batch/s, train_loss=0.00348]


epoch=10, loss=0.028427012295558534


100%|██████████| 18/18 [00:00<00:00, 46.85batch/s, train_loss=0.00326]


epoch=11, loss=0.026185961468466395


100%|██████████| 18/18 [00:00<00:00, 37.10batch/s, train_loss=0.00237]


epoch=12, loss=0.023765828763616496


100%|██████████| 18/18 [00:00<00:00, 46.42batch/s, train_loss=0.0025] 


epoch=13, loss=0.022419746304380483


100%|██████████| 18/18 [00:00<00:00, 50.28batch/s, train_loss=0.00244]


epoch=14, loss=0.021154647374975265


100%|██████████| 18/18 [00:00<00:00, 43.42batch/s, train_loss=0.00245]


epoch=15, loss=0.020126481128150013


100%|██████████| 18/18 [00:00<00:00, 67.73batch/s, train_loss=0.00276] 


epoch=16, loss=0.019487486985223047


100%|██████████| 18/18 [00:00<00:00, 64.59batch/s, train_loss=0.00164] 


epoch=17, loss=0.017810346320785327


100%|██████████| 18/18 [00:00<00:00, 53.88batch/s, train_loss=0.00204] 


epoch=18, loss=0.017517926406243754


100%|██████████| 18/18 [00:00<00:00, 80.11batch/s, train_loss=0.00187]


epoch=19, loss=0.016851023875433822


100%|██████████| 18/18 [00:00<00:00, 81.22batch/s, train_loss=0.0019] 


epoch=20, loss=0.016373711768923136


100%|██████████| 18/18 [00:00<00:00, 63.04batch/s, train_loss=0.00211]


epoch=21, loss=0.01604013973988336


100%|██████████| 18/18 [00:00<00:00, 76.34batch/s, train_loss=0.00186] 


epoch=22, loss=0.015529928134433152


100%|██████████| 18/18 [00:00<00:00, 86.10batch/s, train_loss=0.00122] 


epoch=23, loss=0.014675856945843533


100%|██████████| 18/18 [00:00<00:00, 53.05batch/s, train_loss=0.00169] 


epoch=24, loss=0.014731369165510965


100%|██████████| 18/18 [00:00<00:00, 86.42batch/s, train_loss=0.00184]


epoch=25, loss=0.01457916102943749


100%|██████████| 18/18 [00:00<00:00, 74.81batch/s, train_loss=0.00161] 


epoch=26, loss=0.014184908036527962


100%|██████████| 18/18 [00:00<00:00, 67.81batch/s, train_loss=0.00194] 


epoch=27, loss=0.014214383048230205


100%|██████████| 18/18 [00:00<00:00, 77.07batch/s, train_loss=0.00157] 


epoch=28, loss=0.013704967798857855


100%|██████████| 18/18 [00:00<00:00, 80.88batch/s, train_loss=0.00162] 


epoch=29, loss=0.013575415390318837


100%|██████████| 18/18 [00:00<00:00, 83.81batch/s, train_loss=0.00179]


epoch=30, loss=0.01356268791905765


100%|██████████| 18/18 [00:00<00:00, 52.33batch/s, train_loss=0.00182] 


epoch=31, loss=0.013427715383726977


100%|██████████| 18/18 [00:00<00:00, 76.10batch/s, train_loss=0.00185] 


epoch=32, loss=0.01330922019173359


100%|██████████| 18/18 [00:00<00:00, 75.40batch/s, train_loss=0.00143] 


epoch=33, loss=0.012862727388225753


100%|██████████| 18/18 [00:00<00:00, 45.42batch/s, train_loss=0.00163] 


epoch=34, loss=0.012861349410024182


100%|██████████| 18/18 [00:00<00:00, 62.77batch/s, train_loss=0.00189] 


epoch=35, loss=0.013021687715217984


100%|██████████| 18/18 [00:00<00:00, 75.28batch/s, train_loss=0.0015]  


epoch=36, loss=0.012569405940072286


100%|██████████| 18/18 [00:00<00:00, 36.69batch/s, train_loss=0.00129] 


epoch=37, loss=0.012303183231888145


100%|██████████| 18/18 [00:00<00:00, 69.69batch/s, train_loss=0.00154] 


epoch=38, loss=0.012460330950802767


100%|██████████| 18/18 [00:00<00:00, 77.97batch/s, train_loss=0.00152]


epoch=39, loss=0.012364707713497097


100%|██████████| 18/18 [00:00<00:00, 39.01batch/s, train_loss=0.00164] 


epoch=40, loss=0.012381994821901979


100%|██████████| 18/18 [00:00<00:00, 58.66batch/s, train_loss=0.00131] 


epoch=41, loss=0.012076462695310854


100%|██████████| 18/18 [00:00<00:00, 58.94batch/s, train_loss=0.00157] 


epoch=42, loss=0.012187839932482817


100%|██████████| 18/18 [00:00<00:00, 58.78batch/s, train_loss=0.00151] 


epoch=43, loss=0.012010826598981332


100%|██████████| 18/18 [00:00<00:00, 71.87batch/s, train_loss=0.00125]


epoch=44, loss=0.011833128597201974


100%|██████████| 18/18 [00:00<00:00, 73.78batch/s, train_loss=0.00124] 


epoch=45, loss=0.011760105359143224


100%|██████████| 18/18 [00:00<00:00, 56.11batch/s, train_loss=0.00188] 


epoch=46, loss=0.012217574965337229


100%|██████████| 18/18 [00:00<00:00, 67.63batch/s, train_loss=0.00146] 


epoch=47, loss=0.011826435475513854


100%|██████████| 18/18 [00:00<00:00, 72.05batch/s, train_loss=0.00155] 


epoch=48, loss=0.011832504799653747


100%|██████████| 18/18 [00:00<00:00, 53.75batch/s, train_loss=0.00135] 


epoch=49, loss=0.011637733918839488


100%|██████████| 18/18 [00:00<00:00, 80.89batch/s, train_loss=0.00141]


epoch=50, loss=0.011664526383424629


100%|██████████| 18/18 [00:00<00:00, 66.62batch/s, train_loss=0.00156] 


epoch=51, loss=0.0116963249157215


100%|██████████| 18/18 [00:00<00:00, 55.86batch/s, train_loss=0.00129] 


epoch=52, loss=0.011461943410593887


100%|██████████| 18/18 [00:00<00:00, 78.17batch/s, train_loss=0.00132] 


epoch=53, loss=0.011444385316865197


100%|██████████| 18/18 [00:00<00:00, 72.64batch/s, train_loss=0.00126] 


epoch=54, loss=0.01141723427279242


100%|██████████| 18/18 [00:00<00:00, 57.38batch/s, train_loss=0.00137] 


epoch=55, loss=0.011446676621149328


100%|██████████| 18/18 [00:00<00:00, 71.63batch/s, train_loss=0.00143] 


epoch=56, loss=0.011419739596802614


100%|██████████| 18/18 [00:00<00:00, 83.92batch/s, train_loss=0.00152]


epoch=57, loss=0.011527984519456995


100%|██████████| 18/18 [00:00<00:00, 73.56batch/s, train_loss=0.00155] 


epoch=58, loss=0.011480296978662752


100%|██████████| 18/18 [00:00<00:00, 60.51batch/s, train_loss=0.00137] 


epoch=59, loss=0.011299198227709736


100%|██████████| 18/18 [00:00<00:00, 80.14batch/s, train_loss=0.00156] 


epoch=60, loss=0.011436411033416617


100%|██████████| 18/18 [00:00<00:00, 71.85batch/s, train_loss=0.00134] 


epoch=61, loss=0.0112477100635397


100%|██████████| 18/18 [00:00<00:00, 58.73batch/s, train_loss=0.00156] 


epoch=62, loss=0.011411345855943089


100%|██████████| 18/18 [00:00<00:00, 72.50batch/s, train_loss=0.00141] 


epoch=63, loss=0.011299013372125298


100%|██████████| 18/18 [00:00<00:00, 73.45batch/s, train_loss=0.00146] 


epoch=64, loss=0.011232685353221564


100%|██████████| 18/18 [00:00<00:00, 57.63batch/s, train_loss=0.0012]  


epoch=65, loss=0.011038808759944193


100%|██████████| 18/18 [00:00<00:00, 81.67batch/s, train_loss=0.00135]


epoch=66, loss=0.011127979496429706


100%|██████████| 18/18 [00:00<00:00, 74.17batch/s, train_loss=0.00119] 


epoch=67, loss=0.01097597570152118


100%|██████████| 18/18 [00:00<00:00, 56.05batch/s, train_loss=0.0012]  


epoch=68, loss=0.010997138428276982


100%|██████████| 18/18 [00:00<00:00, 73.96batch/s, train_loss=0.000885]


epoch=69, loss=0.010752483389500913


100%|██████████| 18/18 [00:00<00:00, 68.80batch/s, train_loss=0.00111] 


epoch=70, loss=0.010932396418061749


100%|██████████| 18/18 [00:00<00:00, 59.46batch/s, train_loss=0.00141] 


epoch=71, loss=0.01115465666918919


100%|██████████| 18/18 [00:00<00:00, 77.32batch/s, train_loss=0.0012] 


epoch=72, loss=0.010959448395104243


100%|██████████| 18/18 [00:00<00:00, 78.26batch/s, train_loss=0.00103] 


epoch=73, loss=0.01077587304495532


100%|██████████| 18/18 [00:00<00:00, 58.25batch/s, train_loss=0.00163] 


epoch=74, loss=0.0113163848344622


100%|██████████| 18/18 [00:00<00:00, 71.07batch/s, train_loss=0.00134] 


epoch=75, loss=0.010993758824364892


100%|██████████| 18/18 [00:00<00:00, 79.67batch/s, train_loss=0.00157] 


epoch=76, loss=0.011156056487354737


100%|██████████| 18/18 [00:00<00:00, 71.17batch/s, train_loss=0.00164] 


epoch=77, loss=0.01124121210903957


100%|██████████| 18/18 [00:00<00:00, 51.90batch/s, train_loss=0.00186] 


epoch=78, loss=0.011329361295905605


100%|██████████| 18/18 [00:00<00:00, 74.15batch/s, train_loss=0.00126] 


epoch=79, loss=0.010884048969581209


100%|██████████| 18/18 [00:00<00:00, 65.34batch/s, train_loss=0.00118] 


epoch=80, loss=0.010821593013303035


100%|██████████| 18/18 [00:00<00:00, 57.54batch/s, train_loss=0.00117] 


epoch=81, loss=0.010789312576425484


100%|██████████| 18/18 [00:00<00:00, 80.72batch/s, train_loss=0.00176]


epoch=82, loss=0.011250177050458974


100%|██████████| 18/18 [00:00<00:00, 81.11batch/s, train_loss=0.00168] 


epoch=83, loss=0.01116783401883882


100%|██████████| 18/18 [00:00<00:00, 58.25batch/s, train_loss=0.00114] 


epoch=84, loss=0.010661192929950256


100%|██████████| 18/18 [00:00<00:00, 83.70batch/s, train_loss=0.00104] 


epoch=85, loss=0.010650614601784738


100%|██████████| 18/18 [00:00<00:00, 73.34batch/s, train_loss=0.00156] 


epoch=86, loss=0.011085799779357582


100%|██████████| 18/18 [00:00<00:00, 60.02batch/s, train_loss=0.00142] 


epoch=87, loss=0.010898823605529193


100%|██████████| 18/18 [00:00<00:00, 73.96batch/s, train_loss=0.00137] 


epoch=88, loss=0.010897431673674749


100%|██████████| 18/18 [00:00<00:00, 82.56batch/s, train_loss=0.00133]


epoch=89, loss=0.010845802966890665


100%|██████████| 18/18 [00:00<00:00, 56.88batch/s, train_loss=0.00144] 


epoch=90, loss=0.01090586003250089


100%|██████████| 18/18 [00:00<00:00, 67.18batch/s, train_loss=0.00136] 


epoch=91, loss=0.010843664110734546


100%|██████████| 18/18 [00:00<00:00, 64.12batch/s, train_loss=0.00139] 


epoch=92, loss=0.010878787949167449


100%|██████████| 18/18 [00:00<00:00, 49.20batch/s, train_loss=0.00109] 


epoch=93, loss=0.010641334273691835


100%|██████████| 18/18 [00:00<00:00, 60.85batch/s, train_loss=0.00121] 


epoch=94, loss=0.010715804649838085


100%|██████████| 18/18 [00:00<00:00, 63.02batch/s, train_loss=0.0012]  


epoch=95, loss=0.010748627249536843


100%|██████████| 18/18 [00:00<00:00, 49.90batch/s, train_loss=0.0013]  


epoch=96, loss=0.010771890905396693


100%|██████████| 18/18 [00:00<00:00, 64.05batch/s, train_loss=0.00122] 


epoch=97, loss=0.010692014925438783


100%|██████████| 18/18 [00:00<00:00, 62.13batch/s, train_loss=0.00142] 


epoch=98, loss=0.010878291750776356


100%|██████████| 18/18 [00:00<00:00, 51.50batch/s, train_loss=0.0013]  


epoch=99, loss=0.010777630216088789


100%|██████████| 18/18 [00:00<00:00, 63.92batch/s, train_loss=0.00119] 

epoch=100, loss=0.010644557937465868





In [4]:
"""
This block of code calculates the outliers alongside the arousal axis
"""
columns = ["workerID", "SongId", "Arousal"]
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)
original_df.columns = ["user_id", "item_id", "rating"]

data_converter = DataConverter(
        original_df=original_df, n_random_users=0, n_ratings_per_random_user=9
)

arousal_model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
    include_bias=True
)
epochs = 100

criterion = MSELoss()
optimizer = SGD(arousal_model.parameters(), lr=5, weight_decay=1e-3)
runner = Runner(
    model=arousal_model,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=1000, shuffle=True)
users, items, ratings = select_n_random(train_set)

with SummaryWriter("runs/DEAM/arousal") as writer:
    writer.add_graph(arousal_model, (users, items))

    for epoch in range(epochs):
        epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
        print(f"epoch={epoch + 1}, loss={epoch_loss}")

torch.save(arousal_model.state_dict(), f"{MODELS_DIR}/DEAM/raw/arousal.pt")

100%|██████████| 18/18 [00:00<00:00, 79.73batch/s, train_loss=0.0135]


epoch=1, loss=0.17883310205360936


100%|██████████| 18/18 [00:00<00:00, 57.15batch/s, train_loss=0.01]   


epoch=2, loss=0.09467792048947565


100%|██████████| 18/18 [00:00<00:00, 83.23batch/s, train_loss=0.00665]


epoch=3, loss=0.07207826583138828


100%|██████████| 18/18 [00:00<00:00, 77.29batch/s, train_loss=0.00688]


epoch=4, loss=0.05761670312388191


100%|██████████| 18/18 [00:00<00:00, 58.24batch/s, train_loss=0.00491]


epoch=5, loss=0.047798272297300146


100%|██████████| 18/18 [00:00<00:00, 82.25batch/s, train_loss=0.0045] 


epoch=6, loss=0.04247609856210906


100%|██████████| 18/18 [00:00<00:00, 79.08batch/s, train_loss=0.00451]


epoch=7, loss=0.03850702433339481


100%|██████████| 18/18 [00:00<00:00, 63.37batch/s, train_loss=0.00326]


epoch=8, loss=0.034191212645892435


100%|██████████| 18/18 [00:00<00:00, 78.13batch/s, train_loss=0.00387]


epoch=9, loss=0.03175173814132296


100%|██████████| 18/18 [00:00<00:00, 87.27batch/s, train_loss=0.00298]


epoch=10, loss=0.02863527478431833


100%|██████████| 18/18 [00:00<00:00, 73.08batch/s, train_loss=0.00345]


epoch=11, loss=0.02711151904689854


100%|██████████| 18/18 [00:00<00:00, 61.36batch/s, train_loss=0.00271]


epoch=12, loss=0.024808207914747037


100%|██████████| 18/18 [00:00<00:00, 69.17batch/s, train_loss=0.00266]


epoch=13, loss=0.023318361105590033


100%|██████████| 18/18 [00:00<00:00, 70.82batch/s, train_loss=0.00238]


epoch=14, loss=0.021831014536578083


100%|██████████| 18/18 [00:00<00:00, 55.61batch/s, train_loss=0.00246] 


epoch=15, loss=0.020697899390911236


100%|██████████| 18/18 [00:00<00:00, 72.50batch/s, train_loss=0.00247]


epoch=16, loss=0.019738359177934713


100%|██████████| 18/18 [00:00<00:00, 70.97batch/s, train_loss=0.00215] 


epoch=17, loss=0.01861864854960606


100%|██████████| 18/18 [00:00<00:00, 60.47batch/s, train_loss=0.0017]  


epoch=18, loss=0.017473338151800223


100%|██████████| 18/18 [00:00<00:00, 82.97batch/s, train_loss=0.00151] 


epoch=19, loss=0.016655148891539413


100%|██████████| 18/18 [00:00<00:00, 83.50batch/s, train_loss=0.00171]


epoch=20, loss=0.016324336677789687


100%|██████████| 18/18 [00:00<00:00, 60.14batch/s, train_loss=0.00139] 


epoch=21, loss=0.015546332926585756


100%|██████████| 18/18 [00:00<00:00, 80.34batch/s, train_loss=0.0018]  


epoch=22, loss=0.015433423471861872


100%|██████████| 18/18 [00:00<00:00, 83.98batch/s, train_loss=0.00184]


epoch=23, loss=0.01508861992071415


100%|██████████| 18/18 [00:00<00:00, 64.51batch/s, train_loss=0.00169]


epoch=24, loss=0.014641115945988688


100%|██████████| 18/18 [00:00<00:00, 87.86batch/s, train_loss=0.0021]   


epoch=25, loss=0.014643714302572712


100%|██████████| 18/18 [00:00<00:00, 57.55batch/s, train_loss=0.00159] 


epoch=26, loss=0.013968993467503582


100%|██████████| 18/18 [00:00<00:00, 35.23batch/s, train_loss=0.00148] 


epoch=27, loss=0.013651461511850359


100%|██████████| 18/18 [00:00<00:00, 42.65batch/s, train_loss=0.00166] 


epoch=28, loss=0.013624645785011096


100%|██████████| 18/18 [00:00<00:00, 49.21batch/s, train_loss=0.00193] 


epoch=29, loss=0.013637118998272666


100%|██████████| 18/18 [00:00<00:00, 19.06batch/s, train_loss=0.002]   


epoch=30, loss=0.01356735977119413


100%|██████████| 18/18 [00:00<00:00, 26.49batch/s, train_loss=0.00196] 


epoch=31, loss=0.013325002370209529


100%|██████████| 18/18 [00:00<00:00, 48.76batch/s, train_loss=0.00164] 


epoch=32, loss=0.012898232369587337


100%|██████████| 18/18 [00:00<00:00, 35.20batch/s, train_loss=0.00158] 


epoch=33, loss=0.012757887422010816


100%|██████████| 18/18 [00:00<00:00, 48.21batch/s, train_loss=0.00138] 


epoch=34, loss=0.012506732647788937


100%|██████████| 18/18 [00:00<00:00, 59.66batch/s, train_loss=0.00157] 


epoch=35, loss=0.012569092670391345


100%|██████████| 18/18 [00:00<00:00, 67.33batch/s, train_loss=0.00134] 


epoch=36, loss=0.01223512912002103


100%|██████████| 18/18 [00:00<00:00, 48.86batch/s, train_loss=0.00149] 


epoch=37, loss=0.012233913331196224


100%|██████████| 18/18 [00:00<00:00, 69.75batch/s, train_loss=0.00136] 


epoch=38, loss=0.012119631891620571


100%|██████████| 18/18 [00:00<00:00, 38.68batch/s, train_loss=0.00113] 


epoch=39, loss=0.011862889054520376


100%|██████████| 18/18 [00:00<00:00, 31.41batch/s, train_loss=0.00156] 


epoch=40, loss=0.012126652686760341


100%|██████████| 18/18 [00:00<00:00, 23.86batch/s, train_loss=0.00158] 


epoch=41, loss=0.012104234854722843


100%|██████████| 18/18 [00:00<00:00, 28.37batch/s, train_loss=0.00141] 


epoch=42, loss=0.011856966771956148


100%|██████████| 18/18 [00:00<00:00, 45.12batch/s, train_loss=0.00129] 


epoch=43, loss=0.01170292911858394


100%|██████████| 18/18 [00:00<00:00, 64.91batch/s, train_loss=0.00157] 


epoch=44, loss=0.011902371007820656


100%|██████████| 18/18 [00:00<00:00, 65.47batch/s, train_loss=0.00169] 


epoch=45, loss=0.011946059367780029


100%|██████████| 18/18 [00:00<00:00, 39.53batch/s, train_loss=0.00161] 


epoch=46, loss=0.011815503164612013


100%|██████████| 18/18 [00:00<00:00, 67.16batch/s, train_loss=0.00141] 


epoch=47, loss=0.011619970632010495


100%|██████████| 18/18 [00:00<00:00, 71.96batch/s, train_loss=0.0016]  


epoch=48, loss=0.011703841397474551


100%|██████████| 18/18 [00:00<00:00, 53.17batch/s, train_loss=0.00126] 


epoch=49, loss=0.011387174252806036


100%|██████████| 18/18 [00:00<00:00, 60.15batch/s, train_loss=0.00165] 


epoch=50, loss=0.011689461938266097


100%|██████████| 18/18 [00:00<00:00, 29.10batch/s, train_loss=0.00112] 


epoch=51, loss=0.011212628335788332


100%|██████████| 18/18 [00:00<00:00, 55.12batch/s, train_loss=0.00136] 


epoch=52, loss=0.01137335842231224


100%|██████████| 18/18 [00:00<00:00, 48.72batch/s, train_loss=0.00145] 


epoch=53, loss=0.011443277657032012


100%|██████████| 18/18 [00:00<00:00, 74.16batch/s, train_loss=0.00127] 


epoch=54, loss=0.011191426000718412


100%|██████████| 18/18 [00:00<00:00, 76.71batch/s, train_loss=0.00155] 


epoch=55, loss=0.011447990990918259


100%|██████████| 18/18 [00:00<00:00, 59.48batch/s, train_loss=0.00138] 


epoch=56, loss=0.011256182243084086


100%|██████████| 18/18 [00:00<00:00, 75.39batch/s, train_loss=0.0013]  


epoch=57, loss=0.01117065344391198


100%|██████████| 18/18 [00:00<00:00, 77.80batch/s, train_loss=0.00163] 


epoch=58, loss=0.011412690320919301


100%|██████████| 18/18 [00:00<00:00, 58.68batch/s, train_loss=0.00121] 


epoch=59, loss=0.010999818883065518


100%|██████████| 18/18 [00:00<00:00, 78.78batch/s, train_loss=0.00123] 


epoch=60, loss=0.011033524824627515


100%|██████████| 18/18 [00:00<00:00, 76.87batch/s, train_loss=0.00111] 


epoch=61, loss=0.010931790250128713


100%|██████████| 18/18 [00:00<00:00, 60.04batch/s, train_loss=0.00149]


epoch=62, loss=0.011182150031985909


100%|██████████| 18/18 [00:00<00:00, 77.94batch/s, train_loss=0.0014]  


epoch=63, loss=0.011146392148116538


100%|██████████| 18/18 [00:00<00:00, 88.74batch/s, train_loss=0.00113] 


epoch=64, loss=0.010854975544173143


100%|██████████| 18/18 [00:00<00:00, 58.43batch/s, train_loss=0.00198] 


epoch=65, loss=0.011590972148138904


100%|██████████| 18/18 [00:00<00:00, 77.92batch/s, train_loss=0.00133] 


epoch=66, loss=0.011000273217414987


100%|██████████| 18/18 [00:00<00:00, 80.26batch/s, train_loss=0.00128]


epoch=67, loss=0.010952967426900205


100%|██████████| 18/18 [00:00<00:00, 58.84batch/s, train_loss=0.00113]


epoch=68, loss=0.010785556015269511


100%|██████████| 18/18 [00:00<00:00, 67.94batch/s, train_loss=0.00152] 


epoch=69, loss=0.011099390512910381


100%|██████████| 18/18 [00:00<00:00, 74.91batch/s, train_loss=0.00132] 


epoch=70, loss=0.01091842549011625


100%|██████████| 18/18 [00:00<00:00, 58.07batch/s, train_loss=0.00101] 


epoch=71, loss=0.010670775893433342


100%|██████████| 18/18 [00:00<00:00, 84.70batch/s, train_loss=0.00098]


epoch=72, loss=0.010630900467264242


100%|██████████| 18/18 [00:00<00:00, 88.84batch/s, train_loss=0.00144] 


epoch=73, loss=0.010983074768863876


100%|██████████| 18/18 [00:00<00:00, 57.10batch/s, train_loss=0.00141] 


epoch=74, loss=0.010986764982856551


100%|██████████| 18/18 [00:00<00:00, 78.17batch/s, train_loss=0.000931]


epoch=75, loss=0.01058172884172407


100%|██████████| 18/18 [00:00<00:00, 81.28batch/s, train_loss=0.00174]


epoch=76, loss=0.011219659782689193


100%|██████████| 18/18 [00:00<00:00, 59.25batch/s, train_loss=0.00114] 


epoch=77, loss=0.010682949155569077


100%|██████████| 18/18 [00:00<00:00, 76.33batch/s, train_loss=0.00125] 


epoch=78, loss=0.010787951480725717


100%|██████████| 18/18 [00:00<00:00, 72.75batch/s, train_loss=0.0014]  


epoch=79, loss=0.010864962835764061


100%|██████████| 18/18 [00:00<00:00, 78.52batch/s, train_loss=0.00136] 


epoch=80, loss=0.010844634256486234


100%|██████████| 18/18 [00:00<00:00, 58.73batch/s, train_loss=0.00148] 


epoch=81, loss=0.010967199950382625


100%|██████████| 18/18 [00:00<00:00, 85.91batch/s, train_loss=0.000997]


epoch=82, loss=0.010541122527471905


100%|██████████| 18/18 [00:00<00:00, 85.71batch/s, train_loss=0.00107]


epoch=83, loss=0.01057154871009547


100%|██████████| 18/18 [00:00<00:00, 59.52batch/s, train_loss=0.00114] 


epoch=84, loss=0.010643825430294563


100%|██████████| 18/18 [00:00<00:00, 78.27batch/s, train_loss=0.00161]


epoch=85, loss=0.01102953392061694


100%|██████████| 18/18 [00:00<00:00, 64.53batch/s, train_loss=0.00105]


epoch=86, loss=0.010548803306859115


100%|██████████| 18/18 [00:00<00:00, 45.93batch/s, train_loss=0.00167] 


epoch=87, loss=0.01105461206415604


100%|██████████| 18/18 [00:00<00:00, 80.80batch/s, train_loss=0.0012]  


epoch=88, loss=0.010606162642610486


100%|██████████| 18/18 [00:00<00:00, 85.46batch/s, train_loss=0.00145]


epoch=89, loss=0.010881263182081026


100%|██████████| 18/18 [00:00<00:00, 61.01batch/s, train_loss=0.0017] 


epoch=90, loss=0.011000490454764202


100%|██████████| 18/18 [00:00<00:00, 64.73batch/s, train_loss=0.00142] 


epoch=91, loss=0.010814318221190881


100%|██████████| 18/18 [00:00<00:00, 64.63batch/s, train_loss=0.00117] 


epoch=92, loss=0.01056087257738771


100%|██████████| 18/18 [00:00<00:00, 49.72batch/s, train_loss=0.00143] 


epoch=93, loss=0.010864609404884536


100%|██████████| 18/18 [00:00<00:00, 70.25batch/s, train_loss=0.00102] 


epoch=94, loss=0.010464558875766292


100%|██████████| 18/18 [00:00<00:00, 65.72batch/s, train_loss=0.0016]  


epoch=95, loss=0.010958411687407


100%|██████████| 18/18 [00:00<00:00, 47.33batch/s, train_loss=0.00143] 


epoch=96, loss=0.01080928281668959


100%|██████████| 18/18 [00:00<00:00, 61.29batch/s, train_loss=0.00121] 


epoch=97, loss=0.01061578077283399


100%|██████████| 18/18 [00:00<00:00, 62.52batch/s, train_loss=0.00117] 


epoch=98, loss=0.010581601203515612


100%|██████████| 18/18 [00:00<00:00, 59.55batch/s, train_loss=0.000914]


epoch=99, loss=0.010408482466792239


100%|██████████| 18/18 [00:00<00:00, 45.81batch/s, train_loss=0.00158] 


epoch=100, loss=0.010911776545746573


In [5]:
valence_outliers = mine_outliers_scipy(model=valence_model, data_converter=data_converter)
arousal_outliers = mine_outliers_scipy(model=arousal_model, data_converter=data_converter)

items_group_by_users = data_converter.original_df.groupby("user_id")
combined_outliers = {}
for user_id, valence_dist in valence_outliers.items():
    arousal_dist = arousal_outliers[user_id]
    combined_outliers[user_id] = valence_dist + arousal_dist

combined_outliers = dict(sorted(combined_outliers.items(), key=lambda item: item[1]))
for user_id, item_id in combined_outliers.items():
    number_of_items = len(items_group_by_users.get_group(user_id))
    print(f"user: {user_id}, dist: {item_id}, #items: {number_of_items}")

user: 2a6b63b7690efa2390c8d9fee11b1407, dist: -24.267710561123145, #items: 3
user: ad3b997c4f2382a66e49f035cacfa682, dist: -12.296149919475926, #items: 3
user: 65794ea9f5122952403585a237bc5e52, dist: 3.9826956566666603, #items: 3
user: 374a5659c02e12b01db6319436f17a7d, dist: 9.869268496983022, #items: 3
user: 615d836ba25132081e0ebd2182221a59, dist: 11.557994451083452, #items: 6
user: 623681f76a3eab5d9c86fbc0e1ca264b, dist: 12.954068824121107, #items: 12
user: fd5b08ce362d855ca9152a894348130c, dist: 15.645982217473833, #items: 222
user: da37d1548ffd0631809f7be341e4fe4d, dist: 20.323636435677617, #items: 3
user: a30d244141cb2f51e0803e79bc4bd147, dist: 23.57655772998743, #items: 985
user: 6222da90667e5b0de990ce6c26dcfa15, dist: 25.76786325272345, #items: 12
user: 46a2289decf79f747406fa91cd92fc27, dist: 29.46003879800084, #items: 333
user: 00de940f0b5cfc82cca4791199e3bfb3, dist: 30.80378344960348, #items: 751
user: 485d8e33a731a830ef0aebd71b016d08, dist: 33.635854594780135, #items: 6
user: