In [2]:
from pandas import read_csv
from torch.optim import SGD
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torch import randperm
from config import DATA_DIR
from src.data_set import RatingsDataset
from src.loss import MiningOutliersLoss
from src.model import MF
from src.runner import Runner
from src.utils import create_dataset, mine_outliers, DataConverter, DataProcessor

"""
The Deam dataset is based on Arousal-Valence 2D emotional model.
The Valence/Arousal ratings were collected using Amazon Mechanical Turks service.
Each turk from the collected crowd were asked to mark his own emotion for the current song on a 2D plane, Arousal/Valence.
For more information please read: https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0173392
"""

DF_PATH = f"{DATA_DIR}" \
          f"/DEAM/annotations/annotations per each rater/" \
          f"song_level/static_annotations_songs_1_2000.csv"

def select_n_random(trainset: RatingsDataset):
    """
    Selects n random data points and their corresponding labels from a dataset
    """
    perm = randperm(len(trainset))
    return trainset[perm][:100]

In [None]:
"""
This block of code calculates the outliers alongside the valence axis
"""
columns = ["workerID", "SongId", "Valence"]
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)
original_df.columns = ["user_id", "item_id", "rating"]

data_converter = DataConverter(
        original_df=original_df, n_random_users=10, n_ratings_per_random_user=200
    )
data_processor = DataProcessor(original_df=data_converter.original_df)

valence_model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
)

criterion = MiningOutliersLoss(data_converter=data_converter, data_processor=data_processor)
optimizer = SGD(valence_model.parameters(), lr=5, weight_decay=1e-5)
runner = Runner(
    model=valence_model,
    criterion=criterion,
    optimizer=optimizer,
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=1000, shuffle=True)
users, items, ratings = select_n_random(train_set)

epochs = 50
with SummaryWriter("runs/DEAM_experiment_1") as writer:
    writer.add_graph(valence_model, (users, items))

    for epoch in range(epochs):
        epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
        print(f"epoch={epoch + 1}, loss={epoch_loss}")

100%|██████████| 20/20 [00:31<00:00,  1.60s/batch, train_loss=0.467]


epoch=1, loss=11.493567992374818


100%|██████████| 20/20 [00:24<00:00,  1.20s/batch, train_loss=0.364]


epoch=2, loss=7.903910919715619


100%|██████████| 20/20 [00:27<00:00,  1.37s/batch, train_loss=0.306]


epoch=3, loss=6.59335493127231


100%|██████████| 20/20 [00:23<00:00,  1.19s/batch, train_loss=0.28] 


epoch=4, loss=5.686996783552498


100%|██████████| 20/20 [00:20<00:00,  1.04s/batch, train_loss=0.275]


epoch=5, loss=5.0058617873356255


100%|██████████| 20/20 [00:21<00:00,  1.07s/batch, train_loss=0.21] 


epoch=6, loss=4.431820030870108


100%|██████████| 20/20 [00:22<00:00,  1.14s/batch, train_loss=0.203]


epoch=7, loss=4.043893142831736


100%|██████████| 20/20 [00:20<00:00,  1.03s/batch, train_loss=0.185]


epoch=8, loss=3.6776609365529036


100%|██████████| 20/20 [00:23<00:00,  1.17s/batch, train_loss=0.16] 


epoch=9, loss=3.340875648103912


100%|██████████| 20/20 [00:21<00:00,  1.08s/batch, train_loss=0.171]


epoch=10, loss=3.0644424573964084


100%|██████████| 20/20 [00:17<00:00,  1.14batch/s, train_loss=0.124]


epoch=11, loss=2.766368329146812


100%|██████████| 20/20 [00:18<00:00,  1.07batch/s, train_loss=0.134]


epoch=12, loss=2.4753091285310944


100%|██████████| 20/20 [00:20<00:00,  1.01s/batch, train_loss=0.103] 


epoch=13, loss=2.190321176792013


100%|██████████| 20/20 [00:19<00:00,  1.01batch/s, train_loss=0.0969]


epoch=14, loss=1.9174654922485348


100%|██████████| 20/20 [00:21<00:00,  1.09s/batch, train_loss=0.0869]


epoch=15, loss=1.7007017655865901


100%|██████████| 20/20 [00:22<00:00,  1.11s/batch, train_loss=0.0684]


epoch=16, loss=1.5250331412677107


100%|██████████| 20/20 [00:23<00:00,  1.17s/batch, train_loss=0.0668]


epoch=17, loss=1.3763928622212904


100%|██████████| 20/20 [00:19<00:00,  1.03batch/s, train_loss=0.0721]


epoch=18, loss=1.2505457648573248


100%|██████████| 20/20 [00:26<00:00,  1.30s/batch, train_loss=0.0607]


epoch=19, loss=1.1387970248255237


100%|██████████| 20/20 [00:25<00:00,  1.30s/batch, train_loss=0.0395]


epoch=20, loss=1.0134161180956611


100%|██████████| 20/20 [00:19<00:00,  1.01batch/s, train_loss=0.0368]


epoch=21, loss=0.9291339572709183


 10%|█         | 2/20 [00:03<00:26,  1.48s/batch, train_loss=0.0372]

In [3]:
"""
This block of code calculates the outliers alongside the arousal axis
"""
columns = ["workerID", "SongId", "Arousal"]
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)
original_df.columns = ["user_id", "item_id", "rating"]

data_converter = DataConverter(
        original_df=original_df, n_random_users=10, n_ratings_per_random_user=200
    )
data_processor = DataProcessor(original_df=data_converter.original_df)

arousal_model = MF(
    n_users=data_converter.n_users,
    n_items=data_converter.n_item,
)

criterion = MiningOutliersLoss(data_converter=data_converter, data_processor=data_processor)
optimizer = SGD(arousal_model.parameters(), lr=5, weight_decay=1e-5)
runner = Runner(
    model=valence_model,
    criterion=criterion,
    optimizer=optimizer,
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=1000, shuffle=True)
users, items, ratings = select_n_random(train_set)

epochs = 50
with SummaryWriter("runs/DEAM_experiment_1") as writer:
    writer.add_graph(arousal_model, (users, items))

    for epoch in range(epochs):
        epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
        print(f"epoch={epoch + 1}, loss={epoch_loss}")

100%|██████████| 18/18 [00:07<00:00,  2.48batch/s, train_loss=0.215]


epoch=1, loss=4.8530552320072164


100%|██████████| 18/18 [00:07<00:00,  2.42batch/s, train_loss=0.184]


epoch=2, loss=3.582961639641788


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.156]


epoch=3, loss=2.9279809759711473


100%|██████████| 18/18 [00:07<00:00,  2.41batch/s, train_loss=0.131]


epoch=4, loss=2.4210107844608766


100%|██████████| 18/18 [00:07<00:00,  2.38batch/s, train_loss=0.115]


epoch=5, loss=2.0669470141518445


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0932]


epoch=6, loss=1.7427693262359971


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.0817]


epoch=7, loss=1.4450892597005522


100%|██████████| 18/18 [00:07<00:00,  2.42batch/s, train_loss=0.0657]


epoch=8, loss=1.186641391902582


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.0539]


epoch=9, loss=0.9511631658234948


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.0478]


epoch=10, loss=0.755659080490528


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0336]


epoch=11, loss=0.5950050783713968


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0224]


epoch=12, loss=0.4832098444482231


100%|██████████| 18/18 [00:07<00:00,  2.36batch/s, train_loss=0.024] 


epoch=13, loss=0.4126446677330403


100%|██████████| 18/18 [00:07<00:00,  2.41batch/s, train_loss=0.018] 


epoch=14, loss=0.3607474857805304


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.0216]


epoch=15, loss=0.32677929757262947


100%|██████████| 18/18 [00:07<00:00,  2.41batch/s, train_loss=0.0152]


epoch=16, loss=0.29571252019859934


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.0201]


epoch=17, loss=0.27952176673310275


100%|██████████| 18/18 [00:07<00:00,  2.41batch/s, train_loss=0.0161]


epoch=18, loss=0.25950544421125477


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.0152]


epoch=19, loss=0.24703320566689455


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.012] 


epoch=20, loss=0.2347337348693076


100%|██████████| 18/18 [00:07<00:00,  2.38batch/s, train_loss=0.0153] 


epoch=21, loss=0.22741118147215492


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.0141]


epoch=22, loss=0.21884117284915788


100%|██████████| 18/18 [00:07<00:00,  2.38batch/s, train_loss=0.0159] 


epoch=23, loss=0.21336611238249548


100%|██████████| 18/18 [00:07<00:00,  2.42batch/s, train_loss=0.00796]


epoch=24, loss=0.20264165859370845


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.00899]


epoch=25, loss=0.19722299874895738


100%|██████████| 18/18 [00:07<00:00,  2.41batch/s, train_loss=0.0112] 


epoch=26, loss=0.19410287695265005


100%|██████████| 18/18 [00:07<00:00,  2.38batch/s, train_loss=0.00787]


epoch=27, loss=0.18797647948209414


100%|██████████| 18/18 [00:07<00:00,  2.41batch/s, train_loss=0.0126] 


epoch=28, loss=0.18716849370503702


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.0117] 


epoch=29, loss=0.18413281504560539


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0108] 


epoch=30, loss=0.1816701666334724


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0114] 


epoch=31, loss=0.17905734159232115


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0104] 


epoch=32, loss=0.17499821539704438


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.00965]


epoch=33, loss=0.17216927585713132


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0129] 


epoch=34, loss=0.1733520513582786


100%|██████████| 18/18 [00:07<00:00,  2.38batch/s, train_loss=0.0132] 


epoch=35, loss=0.17185413415329928


100%|██████████| 18/18 [00:07<00:00,  2.42batch/s, train_loss=0.00999]


epoch=36, loss=0.167853492426965


100%|██████████| 18/18 [00:07<00:00,  2.38batch/s, train_loss=0.0124] 


epoch=37, loss=0.16872088284065753


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0114] 


epoch=38, loss=0.16630374486529872


100%|██████████| 18/18 [00:07<00:00,  2.38batch/s, train_loss=0.00865]


epoch=39, loss=0.1647232688966892


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.00924]


epoch=40, loss=0.1645201262388712


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.00994]


epoch=41, loss=0.16426420393064328


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0109] 


epoch=42, loss=0.1635742612775662


100%|██████████| 18/18 [00:07<00:00,  2.35batch/s, train_loss=0.0082] 


epoch=43, loss=0.162080067417501


100%|██████████| 18/18 [00:07<00:00,  2.41batch/s, train_loss=0.0114] 


epoch=44, loss=0.16334514247582582


100%|██████████| 18/18 [00:07<00:00,  2.39batch/s, train_loss=0.0108] 


epoch=45, loss=0.1621288004385358


100%|██████████| 18/18 [00:07<00:00,  2.41batch/s, train_loss=0.0108] 


epoch=46, loss=0.16170301209535115


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0122] 


epoch=47, loss=0.16275515672854413


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.00798]


epoch=48, loss=0.15975271047042963


100%|██████████| 18/18 [00:07<00:00,  2.40batch/s, train_loss=0.0101] 


epoch=49, loss=0.16075622285108157


100%|██████████| 18/18 [00:07<00:00,  2.37batch/s, train_loss=0.00942]

epoch=50, loss=0.16043610548509238





In [4]:
valence_outliers = mine_outliers(model=valence_model, data_converter=data_converter)
arousal_outliers = mine_outliers(model=arousal_model, data_converter=data_converter)

items_group_by_users = data_converter.original_df.groupby("user_id")
combined_outliers = {}
for user_id, valence_dist in valence_outliers.items():
    arousal_dist = arousal_outliers[user_id]
    combined_outliers[user_id] = valence_dist + arousal_dist

combined_outliers = dict(sorted(combined_outliers.items(), key=lambda item: item[1]))
for user_id, item_id in combined_outliers.items():
    number_of_items = len(items_group_by_users.get_group(user_id))
    print(f"user: {user_id}, dist: {item_id}, #items: {number_of_items}")

user: d88c800327bffffea5562e23c276ede3, dist: -8.16646671295166, #items: 2
user: b37092cf23b42f8b8497d8ba89be157a, dist: -7.9408345222473145, #items: 2
user: 65794ea9f5122952403585a237bc5e52, dist: -6.504777908325195, #items: 3
user: 49be5653eaba26f1eb60fd9d63f23502, dist: -3.5293586254119873, #items: 3
user: 651938620e6e6c78bfa7854784fe62c2, dist: -2.5150487422943115, #items: 3
user: 15ec33e862185406170ff931583b014f, dist: -2.2195191383361816, #items: 4
user: d0c51e42ea093dc9a9a98ef888637c8e, dist: -2.0096826553344727, #items: 2
user: da37d1548ffd0631809f7be341e4fe4d, dist: -1.3743306398391724, #items: 3
user: 38531641e6c0628757776b0088bcc854, dist: -0.7693064212799072, #items: 7
user: 807f0025a626896f04566aa37cfbce0d, dist: -0.43722257018089294, #items: 3
user: a186cdd58a92051b7c73adc9bd6e65ca, dist: -0.29314088821411133, #items: 7
user: 27f51a4a7fe8565d26cadb88584441e5, dist: -0.279450386762619, #items: 2
user: 5b044cf509da1d8444b6f60c465240ef, dist: -0.1574951410293579, #items: 3
u