In [1]:
from pandas import read_csv
from torch.optim import SGD
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torch import randperm
from config import DATA_DIR
from src.data_set import RatingsDataset
from src.model import MF
from src.runner import Runner
from src.utils import (
    create_dataset,
    mine_outliers_scipy,
    mine_outliers_sklearn,
    mine_outliers_torch,
    DataConverter,
    DataProcessor,
    mean_centralised
)
from src.consistency import direct_calculation

"""
The Deam dataset is based on Arousal-Valence 2D emotional model.
The Valence/Arousal ratings were collected using Amazon Mechanical Turks service.
Each turk from the collected crowd were asked to mark his own emotion for the current song on a 2D plane, Arousal/Valence.
For more information please read: https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0173392
"""

DF_PATH = f"{DATA_DIR}" \
          f"/DEAM/annotations/annotations per each rater/" \
          f"song_level/static_annotations_songs_1_2000_mean_centralised.csv"

def select_n_random(trainset: RatingsDataset):
    """
    Selects n random data points and their corresponding labels from a dataset
    """
    perm = randperm(len(trainset))
    return trainset[perm][:100]

In [2]:
"""
This block analyze mean-centralized data consistency using the direct calculation defined by:
consistency += row.rating - row.song.mean() for all rows in dataset
"""
columns = ["user_id", "item_id", "rating"]
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)
original_df.columns = ["user_id", "item_id", "rating"]
mean_centralised(dataframe=original_df)
consistency = direct_calculation(data_frame=original_df)
print(f"Raw data consistency according to direct calculation is: \x1b[31m{consistency}\x1b[0m")

  0%|          | 0/3 [00:00<?, ?it/s]

KeyError: 'rating'

In [2]:
"""
This block of code calculates the outliers alongside the valence axis
"""
columns = ["workerID", "SongId", "Valence"]
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)
original_df.columns = ["user_id", "item_id", "rating"]
original_df = mean_centralised(dataframe=original_df)

data_converter = DataConverter(
        original_df=original_df, n_random_users=10, n_ratings_per_random_user=9
)

valence_model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
)
epochs = 100

criterion = MSELoss()
optimizer = SGD(valence_model.parameters(), lr=5, weight_decay=1e-3)
runner = Runner(
    model=valence_model,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=1000, shuffle=True)
users, items, ratings = select_n_random(train_set)

with SummaryWriter("runs/DEAM/valence") as writer:
    writer.add_graph(valence_model, (users, items))

    for epoch in range(epochs):
        epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
        print(f"epoch={epoch + 1}, loss={epoch_loss}")

_mean_centralised: 100%|██████████| 17464/17464 [00:03<00:00, 4730.12it/s]
100%|██████████| 18/18 [00:00<00:00, 82.71batch/s, train_loss=0.00541]


epoch=1, loss=0.06094814507797737


100%|██████████| 18/18 [00:00<00:00, 100.18batch/s, train_loss=0.00597]


epoch=2, loss=0.06078394703486335


100%|██████████| 18/18 [00:00<00:00, 87.67batch/s, train_loss=0.0059] 


epoch=3, loss=0.05954172505798753


100%|██████████| 18/18 [00:00<00:00, 108.87batch/s, train_loss=0.00504]


epoch=4, loss=0.0553037546455645


100%|██████████| 18/18 [00:00<00:00, 103.85batch/s, train_loss=0.00476]


epoch=5, loss=0.04786025161949735


100%|██████████| 18/18 [00:00<00:00, 90.32batch/s, train_loss=0.00366]


epoch=6, loss=0.0396903656970723


100%|██████████| 18/18 [00:00<00:00, 115.20batch/s, train_loss=0.00316]


epoch=7, loss=0.0336275645257764


100%|██████████| 18/18 [00:00<00:00, 91.06batch/s, train_loss=0.00282]


epoch=8, loss=0.02946694544162131


100%|██████████| 18/18 [00:00<00:00, 103.81batch/s, train_loss=0.00237]


epoch=9, loss=0.026311553996392536


100%|██████████| 18/18 [00:00<00:00, 77.13batch/s, train_loss=0.00228] 


epoch=10, loss=0.02392782592343079


100%|██████████| 18/18 [00:00<00:00, 90.71batch/s, train_loss=0.00217]


epoch=11, loss=0.021821238131729703


100%|██████████| 18/18 [00:00<00:00, 73.52batch/s, train_loss=0.002] 


epoch=12, loss=0.020054897148901806


100%|██████████| 18/18 [00:00<00:00, 67.70batch/s, train_loss=0.00176] 


epoch=13, loss=0.01850904777635306


100%|██████████| 18/18 [00:00<00:00, 119.22batch/s, train_loss=0.00177] 


epoch=14, loss=0.017358578108277994


100%|██████████| 18/18 [00:00<00:00, 89.53batch/s, train_loss=0.00165] 


epoch=15, loss=0.016252334740618078


100%|██████████| 18/18 [00:00<00:00, 108.76batch/s, train_loss=0.0016]  


epoch=16, loss=0.015397468968203782


100%|██████████| 18/18 [00:00<00:00, 75.40batch/s, train_loss=0.00147]  


epoch=17, loss=0.01456084486982022


100%|██████████| 18/18 [00:00<00:00, 117.02batch/s, train_loss=0.00146] 


epoch=18, loss=0.013933922126835432


100%|██████████| 18/18 [00:00<00:00, 114.47batch/s, train_loss=0.00124] 


epoch=19, loss=0.013208779066478304


100%|██████████| 18/18 [00:00<00:00, 88.12batch/s, train_loss=0.00146]


epoch=20, loss=0.012917098163697693


100%|██████████| 18/18 [00:00<00:00, 113.04batch/s, train_loss=0.00115] 


epoch=21, loss=0.012282135931808597


100%|██████████| 18/18 [00:00<00:00, 91.41batch/s, train_loss=0.00126] 


epoch=22, loss=0.012002097581267789


100%|██████████| 18/18 [00:00<00:00, 117.18batch/s, train_loss=0.00108] 


epoch=23, loss=0.011576259160084843


100%|██████████| 18/18 [00:00<00:00, 109.29batch/s, train_loss=0.00134] 


epoch=24, loss=0.011484834173095785


100%|██████████| 18/18 [00:00<00:00, 87.43batch/s, train_loss=0.00107]


epoch=25, loss=0.01106275428740126


100%|██████████| 18/18 [00:00<00:00, 115.97batch/s, train_loss=0.001]   


epoch=26, loss=0.010801246668672732


100%|██████████| 18/18 [00:00<00:00, 86.58batch/s, train_loss=0.00102]


epoch=27, loss=0.010650357319667451


100%|██████████| 18/18 [00:00<00:00, 108.23batch/s, train_loss=0.00101] 


epoch=28, loss=0.010456837111754538


100%|██████████| 18/18 [00:00<00:00, 87.00batch/s, train_loss=0.000958] 


epoch=29, loss=0.010271411168661357


100%|██████████| 18/18 [00:00<00:00, 107.72batch/s, train_loss=0.00106] 


epoch=30, loss=0.010238658978835767


100%|██████████| 18/18 [00:00<00:00, 117.37batch/s, train_loss=0.00109] 


epoch=31, loss=0.010140101463140564


100%|██████████| 18/18 [00:00<00:00, 89.11batch/s, train_loss=0.000948]


epoch=32, loss=0.009945880801023559


100%|██████████| 18/18 [00:00<00:00, 117.18batch/s, train_loss=0.00104] 


epoch=33, loss=0.009899177687047619


100%|██████████| 18/18 [00:00<00:00, 82.83batch/s, train_loss=0.000859]


epoch=34, loss=0.009681856770162547


100%|██████████| 18/18 [00:00<00:00, 108.41batch/s, train_loss=0.000991]


epoch=35, loss=0.009702568755575895


100%|██████████| 18/18 [00:00<00:00, 85.65batch/s, train_loss=0.00097]  


epoch=36, loss=0.009589485083999184


100%|██████████| 18/18 [00:00<00:00, 116.44batch/s, train_loss=0.00107] 


epoch=37, loss=0.009586916365563225


100%|██████████| 18/18 [00:00<00:00, 115.19batch/s, train_loss=0.00111] 


epoch=38, loss=0.009571164331173639


100%|██████████| 18/18 [00:00<00:00, 78.33batch/s, train_loss=0.000985]


epoch=39, loss=0.009411893893019817


100%|██████████| 18/18 [00:00<00:00, 119.36batch/s, train_loss=0.000928]


epoch=40, loss=0.009334851530915133


100%|██████████| 18/18 [00:00<00:00, 93.75batch/s, train_loss=0.000973]


epoch=41, loss=0.009286942095317566


100%|██████████| 18/18 [00:00<00:00, 121.89batch/s, train_loss=0.000944]


epoch=42, loss=0.00922769113987792


100%|██████████| 18/18 [00:00<00:00, 116.73batch/s, train_loss=0.000838]


epoch=43, loss=0.009120424749296063


100%|██████████| 18/18 [00:00<00:00, 92.67batch/s, train_loss=0.000924]


epoch=44, loss=0.00912650846800219


100%|██████████| 18/18 [00:00<00:00, 124.02batch/s, train_loss=0.00092] 


epoch=45, loss=0.00905084888427266


100%|██████████| 18/18 [00:00<00:00, 94.06batch/s, train_loss=0.00108]


epoch=46, loss=0.009156656246215429


100%|██████████| 18/18 [00:00<00:00, 121.71batch/s, train_loss=0.000839]


epoch=47, loss=0.008933796103573019


100%|██████████| 18/18 [00:00<00:00, 93.61batch/s, train_loss=0.000939]


epoch=48, loss=0.00898100371018644


100%|██████████| 18/18 [00:00<00:00, 124.22batch/s, train_loss=0.000911]


epoch=49, loss=0.008952625444864968


100%|██████████| 18/18 [00:00<00:00, 124.40batch/s, train_loss=0.000939]


epoch=50, loss=0.008933515845868562


100%|██████████| 18/18 [00:00<00:00, 96.38batch/s, train_loss=0.000913]


epoch=51, loss=0.008903035915070063


100%|██████████| 18/18 [00:00<00:00, 127.40batch/s, train_loss=0.00077] 


epoch=52, loss=0.008758784768275837


100%|██████████| 18/18 [00:00<00:00, 95.33batch/s, train_loss=0.000863]


epoch=53, loss=0.008806829313508869


100%|██████████| 18/18 [00:00<00:00, 125.57batch/s, train_loss=0.00085] 


epoch=54, loss=0.008772761566113911


100%|██████████| 18/18 [00:00<00:00, 92.60batch/s, train_loss=0.00104]  


epoch=55, loss=0.00887598709251046


100%|██████████| 18/18 [00:00<00:00, 117.21batch/s, train_loss=0.00104] 


epoch=56, loss=0.008872301675997918


100%|██████████| 18/18 [00:00<00:00, 127.03batch/s, train_loss=0.00103] 


epoch=57, loss=0.008846213233707614


100%|██████████| 18/18 [00:00<00:00, 101.11batch/s, train_loss=0.00122]


epoch=58, loss=0.00895187586632016


100%|██████████| 18/18 [00:00<00:00, 133.59batch/s, train_loss=0.000874]


epoch=59, loss=0.008651261609921818


100%|██████████| 18/18 [00:00<00:00, 94.09batch/s, train_loss=0.000814]


epoch=60, loss=0.008632509602966722


100%|██████████| 18/18 [00:00<00:00, 132.02batch/s, train_loss=0.000997]


epoch=61, loss=0.008768848138272977


100%|██████████| 18/18 [00:00<00:00, 126.53batch/s, train_loss=0.000845]


epoch=62, loss=0.008618426372535822


100%|██████████| 18/18 [00:00<00:00, 104.00batch/s, train_loss=0.000975]


epoch=63, loss=0.008707423531836981


100%|██████████| 18/18 [00:00<00:00, 129.47batch/s, train_loss=0.000872]


epoch=64, loss=0.008599406719100174


100%|██████████| 18/18 [00:00<00:00, 101.03batch/s, train_loss=0.000759]


epoch=65, loss=0.008529145813375603


100%|██████████| 18/18 [00:00<00:00, 124.58batch/s, train_loss=0.00102] 


epoch=66, loss=0.008694615820039482


100%|██████████| 18/18 [00:00<00:00, 92.22batch/s, train_loss=0.000851]


epoch=67, loss=0.008553055419065462


100%|██████████| 18/18 [00:00<00:00, 130.13batch/s, train_loss=0.000821]


epoch=68, loss=0.008523507755155597


100%|██████████| 18/18 [00:00<00:00, 127.33batch/s, train_loss=0.00114] 


epoch=69, loss=0.008758145368486535


100%|██████████| 18/18 [00:00<00:00, 95.99batch/s, train_loss=0.000959]


epoch=70, loss=0.008599895050809702


100%|██████████| 18/18 [00:00<00:00, 132.49batch/s, train_loss=0.00069] 


epoch=71, loss=0.008400601187337608


100%|██████████| 18/18 [00:00<00:00, 101.31batch/s, train_loss=0.000776]


epoch=72, loss=0.008468005591136023


100%|██████████| 18/18 [00:00<00:00, 134.38batch/s, train_loss=0.000805]


epoch=73, loss=0.008461287516118817


100%|██████████| 18/18 [00:00<00:00, 102.84batch/s, train_loss=0.000913]


epoch=74, loss=0.008548406488090646


100%|██████████| 18/18 [00:00<00:00, 132.74batch/s, train_loss=0.000932]


epoch=75, loss=0.008557328562766637


100%|██████████| 18/18 [00:00<00:00, 132.25batch/s, train_loss=0.001]   


epoch=76, loss=0.008586688164016401


100%|██████████| 18/18 [00:00<00:00, 101.08batch/s, train_loss=0.000798]


epoch=77, loss=0.008419108584899764


100%|██████████| 18/18 [00:00<00:00, 121.76batch/s, train_loss=0.000781]


epoch=78, loss=0.008417506901258167


100%|██████████| 18/18 [00:00<00:00, 83.48batch/s, train_loss=0.000881]


epoch=79, loss=0.008494522238143514


100%|██████████| 18/18 [00:00<00:00, 116.87batch/s, train_loss=0.000811]


epoch=80, loss=0.008415539929152395


100%|██████████| 18/18 [00:00<00:00, 126.84batch/s, train_loss=0.000778]


epoch=81, loss=0.008404972948967766


100%|██████████| 18/18 [00:00<00:00, 99.62batch/s, train_loss=0.000829]


epoch=82, loss=0.008430143194938825


100%|██████████| 18/18 [00:00<00:00, 126.12batch/s, train_loss=0.000822]


epoch=83, loss=0.00842412122901166


100%|██████████| 18/18 [00:00<00:00, 95.97batch/s, train_loss=0.000994]


epoch=84, loss=0.008520774500572292


100%|██████████| 18/18 [00:00<00:00, 127.51batch/s, train_loss=0.000764]


epoch=85, loss=0.008375604412077997


100%|██████████| 18/18 [00:00<00:00, 98.29batch/s, train_loss=0.000867]


epoch=86, loss=0.008432529157464686


100%|██████████| 18/18 [00:00<00:00, 129.71batch/s, train_loss=0.0012]  


epoch=87, loss=0.008660535102716853


100%|██████████| 18/18 [00:00<00:00, 126.92batch/s, train_loss=0.000827]


epoch=88, loss=0.008353821198109686


100%|██████████| 18/18 [00:00<00:00, 94.42batch/s, train_loss=0.000933]


epoch=89, loss=0.008458711426288214


100%|██████████| 18/18 [00:00<00:00, 126.66batch/s, train_loss=0.000931]


epoch=90, loss=0.008455822082417967


100%|██████████| 18/18 [00:00<00:00, 76.34batch/s, train_loss=0.000782]


epoch=91, loss=0.008368600310831724


100%|██████████| 18/18 [00:00<00:00, 92.80batch/s, train_loss=0.000745]


epoch=92, loss=0.00833303556685413


100%|██████████| 18/18 [00:00<00:00, 72.86batch/s, train_loss=0.000786]


epoch=93, loss=0.008348107977141542


100%|██████████| 18/18 [00:00<00:00, 87.42batch/s, train_loss=0.000936]


epoch=94, loss=0.008422330342582848


100%|██████████| 18/18 [00:00<00:00, 91.52batch/s, train_loss=0.000822]


epoch=95, loss=0.008371224005622553


100%|██████████| 18/18 [00:00<00:00, 77.54batch/s, train_loss=0.000757]


epoch=96, loss=0.008311983256133455


100%|██████████| 18/18 [00:00<00:00, 85.86batch/s, train_loss=0.00102]


epoch=97, loss=0.008487910872224436


100%|██████████| 18/18 [00:00<00:00, 71.36batch/s, train_loss=0.000954]


epoch=98, loss=0.008457367343485142


100%|██████████| 18/18 [00:00<00:00, 91.54batch/s, train_loss=0.00087] 


epoch=99, loss=0.008397215937126414


100%|██████████| 18/18 [00:00<00:00, 83.87batch/s, train_loss=0.000754]

epoch=100, loss=0.008303904206636578





In [3]:
"""
This block of code calculates the outliers alongside the arousal axis
"""
columns = ["workerID", "SongId", "Arousal"]
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)
original_df.columns = ["user_id", "item_id", "rating"]
original_df = mean_centralised(dataframe=original_df)

data_converter = DataConverter(
        original_df=original_df, n_random_users=10, n_ratings_per_random_user=9
)

arousal_model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
)
epochs = 100

criterion = MSELoss()
optimizer = SGD(arousal_model.parameters(), lr=5, weight_decay=1e-3)
runner = Runner(
    model=arousal_model,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=1000, shuffle=True)
users, items, ratings = select_n_random(train_set)

with SummaryWriter("runs/DEAM/arousal") as writer:
    writer.add_graph(arousal_model, (users, items))

    for epoch in range(epochs):
        epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
        print(f"epoch={epoch + 1}, loss={epoch_loss}")

_mean_centralised: 100%|██████████| 17464/17464 [00:03<00:00, 4476.74it/s]
100%|██████████| 18/18 [00:00<00:00, 88.02batch/s, train_loss=0.00531]


epoch=1, loss=0.06177968693468115


100%|██████████| 18/18 [00:00<00:00, 99.79batch/s, train_loss=0.00548]


epoch=2, loss=0.061219564320808724


100%|██████████| 18/18 [00:00<00:00, 76.59batch/s, train_loss=0.00539]


epoch=3, loss=0.05880930537523346


100%|██████████| 18/18 [00:00<00:00, 99.93batch/s, train_loss=0.00424]


epoch=4, loss=0.05056525908050124


100%|██████████| 18/18 [00:00<00:00, 104.21batch/s, train_loss=0.00413]


epoch=5, loss=0.04294225284855288


100%|██████████| 18/18 [00:00<00:00, 126.04batch/s, train_loss=0.00367]


epoch=6, loss=0.03844450640161975


100%|██████████| 18/18 [00:00<00:00, 135.05batch/s, train_loss=0.0035] 


epoch=7, loss=0.03446296998912247


100%|██████████| 18/18 [00:00<00:00, 96.58batch/s, train_loss=0.00276]


epoch=8, loss=0.03055297167895073


100%|██████████| 18/18 [00:00<00:00, 124.81batch/s, train_loss=0.00273]


epoch=9, loss=0.02759223397501109


100%|██████████| 18/18 [00:00<00:00, 101.52batch/s, train_loss=0.00254]


epoch=10, loss=0.02492386337738175


100%|██████████| 18/18 [00:00<00:00, 134.45batch/s, train_loss=0.00204]


epoch=11, loss=0.022456975424332743


100%|██████████| 18/18 [00:00<00:00, 136.21batch/s, train_loss=0.00197]


epoch=12, loss=0.02074015529104088


100%|██████████| 18/18 [00:00<00:00, 97.95batch/s, train_loss=0.00183]


epoch=13, loss=0.019218768156822837


100%|██████████| 18/18 [00:00<00:00, 135.46batch/s, train_loss=0.00168] 


epoch=14, loss=0.017852716377711034


100%|██████████| 18/18 [00:00<00:00, 99.86batch/s, train_loss=0.00159]


epoch=15, loss=0.016716037929488434


100%|██████████| 18/18 [00:00<00:00, 126.75batch/s, train_loss=0.00145] 


epoch=16, loss=0.015654625729318132


100%|██████████| 18/18 [00:00<00:00, 97.92batch/s, train_loss=0.00175]  


epoch=17, loss=0.015065829108553242


100%|██████████| 18/18 [00:00<00:00, 126.21batch/s, train_loss=0.00155] 


epoch=18, loss=0.01419418268061717


100%|██████████| 18/18 [00:00<00:00, 124.62batch/s, train_loss=0.00138] 


epoch=19, loss=0.013487674159693808


100%|██████████| 18/18 [00:00<00:00, 100.60batch/s, train_loss=0.00124]


epoch=20, loss=0.012864570419495718


100%|██████████| 18/18 [00:00<00:00, 108.15batch/s, train_loss=0.00137] 


epoch=21, loss=0.01248558176940959


100%|██████████| 18/18 [00:00<00:00, 88.58batch/s, train_loss=0.00105] 


epoch=22, loss=0.011887845435512625


100%|██████████| 18/18 [00:00<00:00, 132.56batch/s, train_loss=0.000997]


epoch=23, loss=0.011544134702708316


100%|██████████| 18/18 [00:00<00:00, 107.79batch/s, train_loss=0.00134] 


epoch=24, loss=0.01148600603139788


100%|██████████| 18/18 [00:00<00:00, 78.90batch/s, train_loss=0.00107] 


epoch=25, loss=0.01104184166425402


100%|██████████| 18/18 [00:00<00:00, 103.66batch/s, train_loss=0.000977]


epoch=26, loss=0.010758811741099031


100%|██████████| 18/18 [00:00<00:00, 75.97batch/s, train_loss=0.00117]


epoch=27, loss=0.010670429566491814


100%|██████████| 18/18 [00:00<00:00, 110.63batch/s, train_loss=0.0011]  


epoch=28, loss=0.010483177275756634


100%|██████████| 18/18 [00:00<00:00, 84.58batch/s, train_loss=0.00107]  


epoch=29, loss=0.010304130183122647


100%|██████████| 18/18 [00:00<00:00, 113.77batch/s, train_loss=0.000947]


epoch=30, loss=0.010078343936144658


100%|██████████| 18/18 [00:00<00:00, 105.48batch/s, train_loss=0.00127]


epoch=31, loss=0.010188875414942147


100%|██████████| 18/18 [00:00<00:00, 81.54batch/s, train_loss=0.00117] 


epoch=32, loss=0.01000419374913085


100%|██████████| 18/18 [00:00<00:00, 101.16batch/s, train_loss=0.000967]


epoch=33, loss=0.00976070651272144


100%|██████████| 18/18 [00:00<00:00, 91.00batch/s, train_loss=0.00114] 


epoch=34, loss=0.009789351888081657


100%|██████████| 18/18 [00:00<00:00, 117.57batch/s, train_loss=0.000947]


epoch=35, loss=0.00957501825216875


100%|██████████| 18/18 [00:00<00:00, 120.30batch/s, train_loss=0.00106] 


epoch=36, loss=0.009574013761019448


100%|██████████| 18/18 [00:00<00:00, 82.63batch/s, train_loss=0.000976]


epoch=37, loss=0.009433566319490598


100%|██████████| 18/18 [00:00<00:00, 119.76batch/s, train_loss=0.000821]


epoch=38, loss=0.009268240552946979


100%|██████████| 18/18 [00:00<00:00, 91.32batch/s, train_loss=0.000936]


epoch=39, loss=0.009273213625922531


100%|██████████| 18/18 [00:00<00:00, 125.04batch/s, train_loss=0.000947]


epoch=40, loss=0.009234206074519278


100%|██████████| 18/18 [00:00<00:00, 93.98batch/s, train_loss=0.00096]  


epoch=41, loss=0.00922168349014723


100%|██████████| 18/18 [00:00<00:00, 121.44batch/s, train_loss=0.00108] 


epoch=42, loss=0.009228686052431698


100%|██████████| 18/18 [00:00<00:00, 117.38batch/s, train_loss=0.000786]


epoch=43, loss=0.008996064990113357


100%|██████████| 18/18 [00:00<00:00, 92.78batch/s, train_loss=0.000915]


epoch=44, loss=0.00904342381510924


100%|██████████| 18/18 [00:00<00:00, 123.35batch/s, train_loss=0.000903]


epoch=45, loss=0.009009467224782125


100%|██████████| 18/18 [00:00<00:00, 96.72batch/s, train_loss=0.000987]


epoch=46, loss=0.009036302367057181


100%|██████████| 18/18 [00:00<00:00, 125.68batch/s, train_loss=0.000973]


epoch=47, loss=0.008993057610756224


100%|██████████| 18/18 [00:00<00:00, 119.46batch/s, train_loss=0.000821]


epoch=48, loss=0.008829355460749637


100%|██████████| 18/18 [00:00<00:00, 98.94batch/s, train_loss=0.000833]


epoch=49, loss=0.008812095176226827


100%|██████████| 18/18 [00:00<00:00, 119.05batch/s, train_loss=0.000951]


epoch=50, loss=0.008880359372507364


100%|██████████| 18/18 [00:00<00:00, 73.57batch/s, train_loss=0.000966]


epoch=51, loss=0.008855069578339477


100%|██████████| 18/18 [00:00<00:00, 83.83batch/s, train_loss=0.00106]


epoch=52, loss=0.008902532685534618


100%|██████████| 18/18 [00:00<00:00, 61.69batch/s, train_loss=0.000848]


epoch=53, loss=0.008722343965450349


100%|██████████| 18/18 [00:00<00:00, 99.94batch/s, train_loss=0.00099]  


epoch=54, loss=0.008824968575355379


100%|██████████| 18/18 [00:00<00:00, 69.08batch/s, train_loss=0.000908]


epoch=55, loss=0.008755824677780646


100%|██████████| 18/18 [00:00<00:00, 65.21batch/s, train_loss=0.000901]


epoch=56, loss=0.008714007237327659


100%|██████████| 18/18 [00:00<00:00, 96.72batch/s, train_loss=0.00104] 


epoch=57, loss=0.008790068184533274


100%|██████████| 18/18 [00:00<00:00, 53.79batch/s, train_loss=0.000897]


epoch=58, loss=0.008679119224582768


100%|██████████| 18/18 [00:00<00:00, 53.58batch/s, train_loss=0.000891]


epoch=59, loss=0.008652061496077893


100%|██████████| 18/18 [00:00<00:00, 73.49batch/s, train_loss=0.000911]


epoch=60, loss=0.008669177349508886


100%|██████████| 18/18 [00:00<00:00, 63.76batch/s, train_loss=0.000792]


epoch=61, loss=0.008583724424942306


100%|██████████| 18/18 [00:00<00:00, 92.14batch/s, train_loss=0.000944]


epoch=62, loss=0.008650948699200627


100%|██████████| 18/18 [00:00<00:00, 70.59batch/s, train_loss=0.000838]


epoch=63, loss=0.00857929152207254


100%|██████████| 18/18 [00:00<00:00, 104.79batch/s, train_loss=0.001]  


epoch=64, loss=0.0086605291658145


100%|██████████| 18/18 [00:00<00:00, 72.41batch/s, train_loss=0.000878]


epoch=65, loss=0.008569704672489787


100%|██████████| 18/18 [00:00<00:00, 106.59batch/s, train_loss=0.000971]


epoch=66, loss=0.008628490057662938


100%|██████████| 18/18 [00:00<00:00, 100.85batch/s, train_loss=0.000902]


epoch=67, loss=0.00856293073221234


100%|██████████| 18/18 [00:00<00:00, 80.17batch/s, train_loss=0.000904]


epoch=68, loss=0.008561188391830087


100%|██████████| 18/18 [00:00<00:00, 112.03batch/s, train_loss=0.000989]


epoch=69, loss=0.008593135593062275


100%|██████████| 18/18 [00:00<00:00, 87.72batch/s, train_loss=0.000952]


epoch=70, loss=0.008585417657660233


100%|██████████| 18/18 [00:00<00:00, 113.90batch/s, train_loss=0.000746]


epoch=71, loss=0.008416415419496785


100%|██████████| 18/18 [00:00<00:00, 115.32batch/s, train_loss=0.00089] 


epoch=72, loss=0.008506804903599329


100%|██████████| 18/18 [00:00<00:00, 85.97batch/s, train_loss=0.000893]


epoch=73, loss=0.008522618005852406


100%|██████████| 18/18 [00:00<00:00, 106.14batch/s, train_loss=0.000915]


epoch=74, loss=0.008520061059871735


100%|██████████| 18/18 [00:00<00:00, 87.80batch/s, train_loss=0.00084]


epoch=75, loss=0.008453872442137893


100%|██████████| 18/18 [00:00<00:00, 117.25batch/s, train_loss=0.000875]


epoch=76, loss=0.008467059796682764


100%|██████████| 18/18 [00:00<00:00, 90.09batch/s, train_loss=0.000922] 


epoch=77, loss=0.008516995383944324


100%|██████████| 18/18 [00:00<00:00, 113.78batch/s, train_loss=0.00082] 


epoch=78, loss=0.008419798068944299


100%|██████████| 18/18 [00:00<00:00, 99.67batch/s, train_loss=0.000891] 


epoch=79, loss=0.008465827810957975


100%|██████████| 18/18 [00:00<00:00, 69.54batch/s, train_loss=0.00088] 


epoch=80, loss=0.008453096466374311


100%|██████████| 18/18 [00:00<00:00, 99.50batch/s, train_loss=0.000777]


epoch=81, loss=0.008362331789861087


100%|██████████| 18/18 [00:00<00:00, 85.13batch/s, train_loss=0.000889]


epoch=82, loss=0.008452140074667087


100%|██████████| 18/18 [00:00<00:00, 105.17batch/s, train_loss=0.000847]


epoch=83, loss=0.008433972404751964


100%|██████████| 18/18 [00:00<00:00, 117.07batch/s, train_loss=0.000825]


epoch=84, loss=0.008391477279929909


100%|██████████| 18/18 [00:00<00:00, 84.16batch/s, train_loss=0.00101]


epoch=85, loss=0.008537266037094033


100%|██████████| 18/18 [00:00<00:00, 109.22batch/s, train_loss=0.000907]


epoch=86, loss=0.008429921326008944


100%|██████████| 18/18 [00:00<00:00, 96.08batch/s, train_loss=0.000926]


epoch=87, loss=0.008442570243178722


100%|██████████| 18/18 [00:00<00:00, 124.29batch/s, train_loss=0.000691]


epoch=88, loss=0.008263153724937232


100%|██████████| 18/18 [00:00<00:00, 86.12batch/s, train_loss=0.000817] 


epoch=89, loss=0.00833997714702403


100%|██████████| 18/18 [00:00<00:00, 114.65batch/s, train_loss=0.000971]


epoch=90, loss=0.008460861172809498


100%|██████████| 18/18 [00:00<00:00, 82.20batch/s, train_loss=0.000957]


epoch=91, loss=0.008437148560685802


100%|██████████| 18/18 [00:00<00:00, 71.51batch/s, train_loss=0.00064] 


epoch=92, loss=0.008230023853400124


100%|██████████| 18/18 [00:00<00:00, 82.36batch/s, train_loss=0.000899]


epoch=93, loss=0.008391896040968946


100%|██████████| 18/18 [00:00<00:00, 73.23batch/s, train_loss=0.000964]


epoch=94, loss=0.008446841176774097


100%|██████████| 18/18 [00:00<00:00, 84.48batch/s, train_loss=0.000923]


epoch=95, loss=0.008400783438544843


100%|██████████| 18/18 [00:00<00:00, 87.85batch/s, train_loss=0.000849]


epoch=96, loss=0.00834931815176234


100%|██████████| 18/18 [00:00<00:00, 69.56batch/s, train_loss=0.000884]


epoch=97, loss=0.008387542540093192


100%|██████████| 18/18 [00:00<00:00, 83.60batch/s, train_loss=0.000886]


epoch=98, loss=0.008367749842172926


100%|██████████| 18/18 [00:00<00:00, 75.66batch/s, train_loss=0.000901]


epoch=99, loss=0.008399177120480727


100%|██████████| 18/18 [00:00<00:00, 94.28batch/s, train_loss=0.00084] 

epoch=100, loss=0.008326734231589934





In [4]:
valence_outliers = mine_outliers_scipy(model=valence_model, data_converter=data_converter)
arousal_outliers = mine_outliers_scipy(model=arousal_model, data_converter=data_converter)

items_group_by_users = data_converter.original_df.groupby("user_id")
combined_outliers = {}
for user_id, valence_dist in valence_outliers.items():
    arousal_dist = arousal_outliers[user_id]
    combined_outliers[user_id] = valence_dist + arousal_dist

combined_outliers = dict(sorted(combined_outliers.items(), key=lambda item: item[1]))
for user_id, item_id in combined_outliers.items():
    number_of_items = len(items_group_by_users.get_group(user_id))
    print(f"user: {user_id}, dist: {item_id}, #items: {number_of_items}")

user: random_guy_187, dist: -79.87926167514736, #items: 9
user: random_guy_196, dist: -61.504188142567614, #items: 9
user: random_guy_194, dist: -49.64625212429112, #items: 9
user: random_guy_188, dist: -38.71032156074537, #items: 9
user: 19fee46f2810f34a8b69a7768d897a59, dist: -36.5754049091067, #items: 1
user: random_guy_190, dist: -34.09597821371794, #items: 9
user: ff18a27328ffd40ef52b7ebb7a0ded94, dist: -17.543451446355363, #items: 20
user: random_guy_191, dist: -17.20264469597077, #items: 9
user: random_guy_189, dist: -4.519798138322379, #items: 9
user: random_guy_192, dist: -3.8970388278711994, #items: 9
user: fd5b08ce362d855ca9152a894348130c, dist: -2.962559945499283, #items: 222
user: e105c200f413d7b2c5850c0df4b9687e, dist: -1.7301799548594072, #items: 2
user: c3c21239b85dcdd6679fc212afd02a49, dist: 6.671131268637211, #items: 9
user: 7cecbffe1da5ae974952db6c13695afe, dist: 11.214156184287361, #items: 428
user: a0f5cedc3a2371ec13663226c4b44771, dist: 13.959845416054605, #items: