In [4]:
import torch
import os
from pandas import read_csv
from torch.optim import SGD
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from config import DATA_DIR, MODELS_DIR
from src.loss import Loss
from src.model import HistogramMF
from src.runner import Runner
from src.create_dataset import create_dataset, create_histogram_features
from src.data_processor import DataProcessor
from src.data_encoder import DataEncoder

# Location of the raw, per-rater, song-level static annotations for DEAM
# songs 1-2000 (each rater's individual Valence/Arousal score per song).
DF_PATH = "".join(
    [
        f"{DATA_DIR}",
        "/DEAM/annotations/annotations per each rater/",
        "song_level/static_annotations_songs_1_2000_raw.csv",
    ]
)

In [5]:
# Raw DEAM columns to load; workerID/SongId are renamed below to the
# user/item vocabulary that the matrix-factorisation code expects.
columns = ["workerID", "SongId", "Valence", "Arousal"]

# Rename by name rather than by position: `usecols` keeps the CSV's own column
# order, so positionally overwriting `.columns` would silently mislabel data if
# that order ever changed. The final selection pins a fixed column order for
# downstream code that may unpack rows positionally.
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns).rename(
    columns={"workerID": "user_id", "SongId": "item_id"}
)[["user_id", "item_id", "Valence", "Arousal"]]
original_df.head()

Unnamed: 0,user_id,item_id,Valence,Arousal
0,6010bbc8e7ef4b21fa38f9c3a9754ef3,2,5,2
1,3c888e77b992ae3cd2adfe16774e23b9,2,2,3
2,2afd218c3aecb6828d2be327f8b9c46f,2,3,3
3,fd5b08ce362d855ca9152a894348130c,2,4,4
4,9c8073214a052e414811b76012df8847,2,2,2


In [6]:
# Valence model: fit a HistogramMF on the per-rater valence ratings, or reuse
# previously saved weights when a checkpoint already exists.
VALENCE_MODEL_PATH = f"{MODELS_DIR}/DEAM/valence.pt"

valence_dataframe = original_df[["user_id", "item_id", "Valence"]].copy()
valence_dataframe.columns = ["user_id", "item_id", "rating"]
# Called for its effect on valence_dataframe; the return value is not used.
# NOTE(review): presumably adds histogram feature columns in place — confirm in src.create_dataset.
create_histogram_features(data_frame=valence_dataframe)

data_encoder = DataEncoder(original_df=valence_dataframe)
data_processor = DataProcessor(original_df=valence_dataframe)

n_users = valence_dataframe.user_id.nunique()
n_items = valence_dataframe.item_id.nunique()

min_rating = valence_dataframe.rating.min()
max_rating = valence_dataframe.rating.max()

valence_model = HistogramMF(
    n_users=n_users,
    n_items=n_items,
    data_encoder=data_encoder,
    data_processor=data_processor,
    min_rating=min_rating,
    max_rating=max_rating,
)

if os.path.exists(VALENCE_MODEL_PATH):
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine; load_state_dict then copies the tensors onto the model's device.
    valence_model.load_state_dict(torch.load(VALENCE_MODEL_PATH, map_location="cpu"))
else:
    epochs = 100

    criterion = Loss()
    optimizer = SGD(valence_model.parameters(), lr=5, weight_decay=1e-7)
    runner = Runner(model=valence_model, criterion=criterion, optimizer=optimizer)

    train_set = create_dataset(data_encoder=data_encoder)
    train_load = DataLoader(train_set, batch_size=1000, shuffle=True)

    with SummaryWriter("runs/DEAM/dev/valence") as writer:
        for epoch in range(epochs):
            epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
            print(f"epoch={epoch + 1}, loss={epoch_loss}")

    # Create the checkpoint directory up front; otherwise torch.save would fail
    # only after the whole training run had finished.
    os.makedirs(os.path.dirname(VALENCE_MODEL_PATH), exist_ok=True)
    torch.save(valence_model.state_dict(), VALENCE_MODEL_PATH)

  0%|          | 0/17464 [00:00<?, ?it/s]

  0%|          | 0/17464 [00:00<?, ?it/s]

100%|██████████| 18/18 [00:04<00:00,  4.42batch/s, train_loss=0.311]


epoch=1, loss=6.383931920544853


100%|██████████| 18/18 [00:03<00:00,  5.61batch/s, train_loss=0.263]


epoch=2, loss=4.732503604888917


100%|██████████| 18/18 [00:03<00:00,  4.79batch/s, train_loss=0.223]


epoch=3, loss=4.150253273536419


100%|██████████| 18/18 [00:03<00:00,  5.38batch/s, train_loss=0.186]


epoch=4, loss=3.565020938084043


100%|██████████| 18/18 [00:02<00:00,  6.11batch/s, train_loss=0.163]


epoch=5, loss=3.0460250245456044


100%|██████████| 18/18 [00:02<00:00,  6.17batch/s, train_loss=0.158]


epoch=6, loss=2.6578228046811865


100%|██████████| 18/18 [00:02<00:00,  6.25batch/s, train_loss=0.139]


epoch=7, loss=2.3213844880728884


100%|██████████| 18/18 [00:02<00:00,  6.17batch/s, train_loss=0.108]


epoch=8, loss=1.994448245476032


100%|██████████| 18/18 [00:03<00:00,  5.88batch/s, train_loss=0.103] 


epoch=9, loss=1.7074475283129462


100%|██████████| 18/18 [00:03<00:00,  5.38batch/s, train_loss=0.0827]


epoch=10, loss=1.4205988262439595


100%|██████████| 18/18 [00:03<00:00,  5.95batch/s, train_loss=0.0599]


epoch=11, loss=1.185410393846446


100%|██████████| 18/18 [00:02<00:00,  6.02batch/s, train_loss=0.0511]


epoch=12, loss=0.9926497220335335


100%|██████████| 18/18 [00:02<00:00,  6.14batch/s, train_loss=0.0464]


epoch=13, loss=0.8260459106379542


100%|██████████| 18/18 [00:03<00:00,  5.88batch/s, train_loss=0.0371]


epoch=14, loss=0.6844703639457965


100%|██████████| 18/18 [00:03<00:00,  5.98batch/s, train_loss=0.0285]


epoch=15, loss=0.5624886925795982


100%|██████████| 18/18 [00:03<00:00,  5.35batch/s, train_loss=0.0234]


epoch=16, loss=0.4780568293703013


100%|██████████| 18/18 [00:03<00:00,  5.62batch/s, train_loss=0.024] 


epoch=17, loss=0.4077889876365661


100%|██████████| 18/18 [00:03<00:00,  5.90batch/s, train_loss=0.024] 


epoch=18, loss=0.35359360375897636


100%|██████████| 18/18 [00:03<00:00,  5.92batch/s, train_loss=0.0179]


epoch=19, loss=0.3069056426410018


100%|██████████| 18/18 [00:03<00:00,  5.97batch/s, train_loss=0.0145]


epoch=20, loss=0.2688478798372993


100%|██████████| 18/18 [00:02<00:00,  6.06batch/s, train_loss=0.0131]


epoch=21, loss=0.2336178421645329


100%|██████████| 18/18 [00:03<00:00,  5.67batch/s, train_loss=0.0133] 


epoch=22, loss=0.20777721770878496


100%|██████████| 18/18 [00:03<00:00,  4.97batch/s, train_loss=0.0135] 


epoch=23, loss=0.18944488273817917


100%|██████████| 18/18 [00:03<00:00,  5.91batch/s, train_loss=0.0102] 


epoch=24, loss=0.17009191212982966


100%|██████████| 18/18 [00:02<00:00,  6.28batch/s, train_loss=0.0137] 


epoch=25, loss=0.15622532178615703


100%|██████████| 18/18 [00:02<00:00,  6.02batch/s, train_loss=0.00923]


epoch=26, loss=0.14089556685809432


100%|██████████| 18/18 [00:03<00:00,  5.71batch/s, train_loss=0.00637]


epoch=27, loss=0.12749240341679802


100%|██████████| 18/18 [00:03<00:00,  5.56batch/s, train_loss=0.00529]


epoch=28, loss=0.1117434310090953


100%|██████████| 18/18 [00:02<00:00,  6.20batch/s, train_loss=0.00286]


epoch=29, loss=0.09681476047121246


100%|██████████| 18/18 [00:03<00:00,  5.97batch/s, train_loss=0.00455]


epoch=30, loss=0.08964730649158872


100%|██████████| 18/18 [00:03<00:00,  5.99batch/s, train_loss=0.00298]


epoch=31, loss=0.08023167870784627


100%|██████████| 18/18 [00:03<00:00,  6.00batch/s, train_loss=0.00287]


epoch=32, loss=0.0744772935139722


100%|██████████| 18/18 [00:03<00:00,  5.23batch/s, train_loss=0.00477]


epoch=33, loss=0.0693534399887611


100%|██████████| 18/18 [00:03<00:00,  5.89batch/s, train_loss=0.004]  


epoch=34, loss=0.060735426939766975


100%|██████████| 18/18 [00:02<00:00,  6.04batch/s, train_loss=0.00428]


epoch=35, loss=0.05600862792853652


100%|██████████| 18/18 [00:03<00:00,  5.97batch/s, train_loss=0.00291]


epoch=36, loss=0.052765076799639335


100%|██████████| 18/18 [00:02<00:00,  6.03batch/s, train_loss=0.00339]


epoch=37, loss=0.048469042778015135


100%|██████████| 18/18 [00:03<00:00,  5.82batch/s, train_loss=0.00333]


epoch=38, loss=0.04435193820862934


100%|██████████| 18/18 [00:02<00:00,  6.11batch/s, train_loss=0.00391]


epoch=39, loss=0.04230885281233952


100%|██████████| 18/18 [00:03<00:00,  5.14batch/s, train_loss=0.000939]


epoch=40, loss=0.03829761135269855


100%|██████████| 18/18 [00:02<00:00,  6.02batch/s, train_loss=0.000904]


epoch=41, loss=0.037157582969501106


100%|██████████| 18/18 [00:03<00:00,  5.97batch/s, train_loss=0.00375]


epoch=42, loss=0.03693986038092909


100%|██████████| 18/18 [00:03<00:00,  5.29batch/s, train_loss=0.00349]


epoch=43, loss=0.034717151378763134


100%|██████████| 18/18 [00:03<00:00,  5.89batch/s, train_loss=0.00287]


epoch=44, loss=0.03115682227241582


100%|██████████| 18/18 [00:02<00:00,  6.06batch/s, train_loss=0.00198] 


epoch=45, loss=0.02941639056390729


100%|██████████| 18/18 [00:03<00:00,  5.19batch/s, train_loss=0.00138]


epoch=46, loss=0.02962011884307039


100%|██████████| 18/18 [00:03<00:00,  5.48batch/s, train_loss=0.000993]


epoch=47, loss=0.02887800595575365


100%|██████████| 18/18 [00:03<00:00,  5.47batch/s, train_loss=0.00249] 


epoch=48, loss=0.027331079605324507


100%|██████████| 18/18 [00:03<00:00,  5.69batch/s, train_loss=0.00117]


epoch=49, loss=0.02744373628394357


100%|██████████| 18/18 [00:03<00:00,  5.75batch/s, train_loss=0.00113] 


epoch=50, loss=0.026555833144434566


100%|██████████| 18/18 [00:02<00:00,  6.03batch/s, train_loss=0.00249] 


epoch=51, loss=0.025805867178686735


100%|██████████| 18/18 [00:03<00:00,  5.71batch/s, train_loss=0.000524]


epoch=52, loss=0.02388806276639988


100%|██████████| 18/18 [00:02<00:00,  6.08batch/s, train_loss=0.00201] 


epoch=53, loss=0.026303588462286975


100%|██████████| 18/18 [00:03<00:00,  5.13batch/s, train_loss=0.00335] 


epoch=54, loss=0.02524398691695312


100%|██████████| 18/18 [00:03<00:00,  5.78batch/s, train_loss=0.000933]


epoch=55, loss=0.024058763433119346


100%|██████████| 18/18 [00:02<00:00,  6.04batch/s, train_loss=0.00147] 


epoch=56, loss=0.021917306515677224


100%|██████████| 18/18 [00:03<00:00,  5.90batch/s, train_loss=0.00104]


epoch=57, loss=0.022069511391479394


100%|██████████| 18/18 [00:03<00:00,  5.87batch/s, train_loss=0.00211] 


epoch=58, loss=0.022769924882157094


100%|██████████| 18/18 [00:03<00:00,  5.92batch/s, train_loss=0.000485]


epoch=59, loss=0.02258325687333428


100%|██████████| 18/18 [00:03<00:00,  5.17batch/s, train_loss=0.00135] 


epoch=60, loss=0.02262802719761586


100%|██████████| 18/18 [00:02<00:00,  6.17batch/s, train_loss=0.00256] 


epoch=61, loss=0.02377507988440579


100%|██████████| 18/18 [00:03<00:00,  5.58batch/s, train_loss=0.00151] 


epoch=62, loss=0.022125290347584363


100%|██████████| 18/18 [00:03<00:00,  5.40batch/s, train_loss=0.00221] 


epoch=63, loss=0.022638217911638062


100%|██████████| 18/18 [00:03<00:00,  5.80batch/s, train_loss=0.00192] 


epoch=64, loss=0.022183081940330307


100%|██████████| 18/18 [00:02<00:00,  6.12batch/s, train_loss=0.000523]


epoch=65, loss=0.021040986193408225


100%|██████████| 18/18 [00:03<00:00,  5.81batch/s, train_loss=0.000984]


epoch=66, loss=0.021120962287845286


100%|██████████| 18/18 [00:02<00:00,  6.02batch/s, train_loss=0.00231] 


epoch=67, loss=0.02256446745786174


100%|██████████| 18/18 [00:03<00:00,  5.05batch/s, train_loss=0.00194] 


epoch=68, loss=0.020577304285661926


100%|██████████| 18/18 [00:03<00:00,  5.84batch/s, train_loss=0.00218] 


epoch=69, loss=0.021670266309688827


100%|██████████| 18/18 [00:03<00:00,  5.93batch/s, train_loss=0.000229]


epoch=70, loss=0.02188924962885935


100%|██████████| 18/18 [00:03<00:00,  5.90batch/s, train_loss=0.00227] 


epoch=71, loss=0.022223877158658253


100%|██████████| 18/18 [00:03<00:00,  5.77batch/s, train_loss=0.00234] 


epoch=72, loss=0.021639764163000823


100%|██████████| 18/18 [00:03<00:00,  5.45batch/s, train_loss=0.00103] 


epoch=73, loss=0.02131764173610457


100%|██████████| 18/18 [00:03<00:00,  5.51batch/s, train_loss=0.00181] 


epoch=74, loss=0.02136322732321147


100%|██████████| 18/18 [00:02<00:00,  6.17batch/s, train_loss=0.000648]


epoch=75, loss=0.020542784378446383


100%|██████████| 18/18 [00:03<00:00,  5.90batch/s, train_loss=0.001]   


epoch=76, loss=0.021324166302023266


100%|██████████| 18/18 [00:02<00:00,  6.01batch/s, train_loss=0.00112] 


epoch=77, loss=0.02168897149788922


100%|██████████| 18/18 [00:03<00:00,  5.23batch/s, train_loss=0.00133] 


epoch=78, loss=0.02134483542832835


100%|██████████| 18/18 [00:03<00:00,  5.46batch/s, train_loss=0.00132] 


epoch=79, loss=0.02187742165758692


100%|██████████| 18/18 [00:02<00:00,  6.13batch/s, train_loss=0.00111] 


epoch=80, loss=0.022332172088581938


100%|██████████| 18/18 [00:03<00:00,  5.11batch/s, train_loss=0.00252] 


epoch=81, loss=0.021598286616391148


100%|██████████| 18/18 [00:03<00:00,  5.94batch/s, train_loss=0.000914]


epoch=82, loss=0.020500023462135215


100%|██████████| 18/18 [00:02<00:00,  6.00batch/s, train_loss=0.00104] 


epoch=83, loss=0.02167596210539341


100%|██████████| 18/18 [00:03<00:00,  5.04batch/s, train_loss=0.000988]


epoch=84, loss=0.021214436924149253


100%|██████████| 18/18 [00:02<00:00,  6.18batch/s, train_loss=0.00212] 


epoch=85, loss=0.022050384291287122


100%|██████████| 18/18 [00:03<00:00,  5.74batch/s, train_loss=0.00185] 


epoch=86, loss=0.022011519738312425


100%|██████████| 18/18 [00:03<00:00,  5.94batch/s, train_loss=0.00183]


epoch=87, loss=0.021381726839419066


100%|██████████| 18/18 [00:03<00:00,  5.76batch/s, train_loss=0.0029]  


epoch=88, loss=0.02181437947626772


100%|██████████| 18/18 [00:03<00:00,  5.65batch/s, train_loss=0.00168] 


epoch=89, loss=0.02206848480372594


100%|██████████| 18/18 [00:03<00:00,  5.40batch/s, train_loss=0.00218] 


epoch=90, loss=0.021316534741171476


100%|██████████| 18/18 [00:03<00:00,  5.78batch/s, train_loss=0.00137] 


epoch=91, loss=0.02344620013545299


100%|██████████| 18/18 [00:03<00:00,  5.90batch/s, train_loss=0.000891]


epoch=92, loss=0.021899200996962086


100%|██████████| 18/18 [00:03<00:00,  6.00batch/s, train_loss=0.00196] 


epoch=93, loss=0.02258331509602481


100%|██████████| 18/18 [00:03<00:00,  5.77batch/s, train_loss=0.00227] 


epoch=94, loss=0.02221345600177502


100%|██████████| 18/18 [00:03<00:00,  5.08batch/s, train_loss=0.00344] 


epoch=95, loss=0.022293677373179072


100%|██████████| 18/18 [00:03<00:00,  5.81batch/s, train_loss=0.000232]


epoch=96, loss=0.022182538376799944


100%|██████████| 18/18 [00:03<00:00,  5.32batch/s, train_loss=0.00119] 


epoch=97, loss=0.021998145067486274


100%|██████████| 18/18 [00:03<00:00,  5.71batch/s, train_loss=0.00131] 


epoch=98, loss=0.022202859875457046


100%|██████████| 18/18 [00:03<00:00,  5.74batch/s, train_loss=0.00162] 


epoch=99, loss=0.0223618272193547


100%|██████████| 18/18 [00:03<00:00,  5.10batch/s, train_loss=0.0031]  

epoch=100, loss=0.022680542747522223





In [7]:
# Arousal model: fit a HistogramMF on the per-rater arousal ratings, or reuse
# previously saved weights when a checkpoint already exists.
AROUSAL_MODEL_PATH = f"{MODELS_DIR}/DEAM/arousal.pt"

arousal_dataframe = original_df[["user_id", "item_id", "Arousal"]].copy()
arousal_dataframe.columns = ["user_id", "item_id", "rating"]
# Called for its effect on arousal_dataframe; the return value is not used.
# NOTE(review): presumably adds histogram feature columns in place — confirm in src.create_dataset.
create_histogram_features(data_frame=arousal_dataframe)

data_encoder = DataEncoder(original_df=arousal_dataframe)
data_processor = DataProcessor(original_df=arousal_dataframe)

n_users = arousal_dataframe.user_id.nunique()
n_items = arousal_dataframe.item_id.nunique()

min_rating = arousal_dataframe.rating.min()
max_rating = arousal_dataframe.rating.max()

arousal_model = HistogramMF(
    n_users=n_users,
    n_items=n_items,
    data_encoder=data_encoder,
    data_processor=data_processor,
    min_rating=min_rating,
    max_rating=max_rating,
)

if os.path.exists(AROUSAL_MODEL_PATH):
    # Load into arousal_model (previously these weights were loaded into
    # valence_model, leaving the arousal model untrained on cached runs).
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine.
    arousal_model.load_state_dict(torch.load(AROUSAL_MODEL_PATH, map_location="cpu"))
else:
    epochs = 100

    criterion = Loss()
    optimizer = SGD(arousal_model.parameters(), lr=5, weight_decay=1e-7)
    runner = Runner(model=arousal_model, criterion=criterion, optimizer=optimizer)

    train_set = create_dataset(data_encoder=data_encoder)
    train_load = DataLoader(train_set, batch_size=1000, shuffle=True)

    with SummaryWriter("runs/DEAM/dev/arousal") as writer:
        for epoch in range(epochs):
            epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
            print(f"epoch={epoch + 1}, loss={epoch_loss}")

    # Create the checkpoint directory up front; otherwise torch.save would fail
    # only after the whole training run had finished.
    os.makedirs(os.path.dirname(AROUSAL_MODEL_PATH), exist_ok=True)
    torch.save(arousal_model.state_dict(), AROUSAL_MODEL_PATH)

  0%|          | 0/17464 [00:00<?, ?it/s]

  0%|          | 0/17464 [00:00<?, ?it/s]

100%|██████████| 18/18 [00:04<00:00,  4.29batch/s, train_loss=0.291]


epoch=1, loss=6.560973037982809


100%|██████████| 18/18 [00:03<00:00,  5.49batch/s, train_loss=0.255]


epoch=2, loss=4.800009394810116


100%|██████████| 18/18 [00:02<00:00,  6.23batch/s, train_loss=0.228]


epoch=3, loss=4.121323705081282


100%|██████████| 18/18 [00:02<00:00,  6.01batch/s, train_loss=0.172]


epoch=4, loss=3.4163448737571986


100%|██████████| 18/18 [00:02<00:00,  6.32batch/s, train_loss=0.152]


epoch=5, loss=3.011182019726983


100%|██████████| 18/18 [00:03<00:00,  5.39batch/s, train_loss=0.149]


epoch=6, loss=2.676035742134884


100%|██████████| 18/18 [00:03<00:00,  5.56batch/s, train_loss=0.145]


epoch=7, loss=2.382401096870159


100%|██████████| 18/18 [00:03<00:00,  5.59batch/s, train_loss=0.122]


epoch=8, loss=2.055350344559242


100%|██████████| 18/18 [00:02<00:00,  6.00batch/s, train_loss=0.0881]


epoch=9, loss=1.7405594166722789


100%|██████████| 18/18 [00:03<00:00,  5.80batch/s, train_loss=0.0826]


epoch=10, loss=1.4674977299262741


100%|██████████| 18/18 [00:03<00:00,  5.87batch/s, train_loss=0.0693]


epoch=11, loss=1.2247930056473306


100%|██████████| 18/18 [00:03<00:00,  5.66batch/s, train_loss=0.0628]


epoch=12, loss=1.005058127370374


100%|██████████| 18/18 [00:03<00:00,  5.88batch/s, train_loss=0.0504]


epoch=13, loss=0.8224670210213497


100%|██████████| 18/18 [00:03<00:00,  5.68batch/s, train_loss=0.0316]


epoch=14, loss=0.670354204687579


100%|██████████| 18/18 [00:03<00:00,  5.97batch/s, train_loss=0.0283]


epoch=15, loss=0.5609392130785976


100%|██████████| 18/18 [00:03<00:00,  5.60batch/s, train_loss=0.0252]


epoch=16, loss=0.4732061998761932


100%|██████████| 18/18 [00:03<00:00,  5.90batch/s, train_loss=0.0233]


epoch=17, loss=0.4092210789877793


100%|██████████| 18/18 [00:03<00:00,  5.62batch/s, train_loss=0.0192]


epoch=18, loss=0.3528593425750733


100%|██████████| 18/18 [00:03<00:00,  5.96batch/s, train_loss=0.0183]


epoch=19, loss=0.30300438369553667


100%|██████████| 18/18 [00:03<00:00,  5.55batch/s, train_loss=0.0117]


epoch=20, loss=0.264198019874507


100%|██████████| 18/18 [00:03<00:00,  5.52batch/s, train_loss=0.0136]


epoch=21, loss=0.23390771002605043


100%|██████████| 18/18 [00:03<00:00,  4.69batch/s, train_loss=0.0171]


epoch=22, loss=0.21287885727553524


100%|██████████| 18/18 [00:03<00:00,  6.00batch/s, train_loss=0.0131] 


epoch=23, loss=0.18591664292072427


100%|██████████| 18/18 [00:02<00:00,  6.37batch/s, train_loss=0.00641]


epoch=24, loss=0.1635105659509527


100%|██████████| 18/18 [00:02<00:00,  6.63batch/s, train_loss=0.00715]


epoch=25, loss=0.14628627765589747


100%|██████████| 18/18 [00:02<00:00,  6.30batch/s, train_loss=0.00901]


epoch=26, loss=0.1308389554928089


100%|██████████| 18/18 [00:02<00:00,  6.63batch/s, train_loss=0.0072] 


epoch=27, loss=0.11623109927670709


100%|██████████| 18/18 [00:02<00:00,  6.34batch/s, train_loss=0.00418]


epoch=28, loss=0.10149495212990663


100%|██████████| 18/18 [00:02<00:00,  6.71batch/s, train_loss=0.00356]


epoch=29, loss=0.08946685215522504


100%|██████████| 18/18 [00:02<00:00,  6.32batch/s, train_loss=0.00429]


epoch=30, loss=0.08193518233710324


100%|██████████| 18/18 [00:02<00:00,  6.63batch/s, train_loss=0.0101] 


epoch=31, loss=0.07358506238049474


100%|██████████| 18/18 [00:02<00:00,  6.43batch/s, train_loss=0.00151]


epoch=32, loss=0.058409005998537454


100%|██████████| 18/18 [00:02<00:00,  6.46batch/s, train_loss=0.00475]


epoch=33, loss=0.057249595300904635


100%|██████████| 18/18 [00:02<00:00,  6.37batch/s, train_loss=0.00305]


epoch=34, loss=0.04874607763208193


100%|██████████| 18/18 [00:02<00:00,  6.70batch/s, train_loss=0.00428]


epoch=35, loss=0.045439261656382984


100%|██████████| 18/18 [00:02<00:00,  6.28batch/s, train_loss=0.00357]


epoch=36, loss=0.042333711396003594


100%|██████████| 18/18 [00:02<00:00,  6.68batch/s, train_loss=0.00141]


epoch=37, loss=0.03911272967375557


100%|██████████| 18/18 [00:02<00:00,  6.37batch/s, train_loss=0.00239]


epoch=38, loss=0.036254613013103086


100%|██████████| 18/18 [00:02<00:00,  6.70batch/s, train_loss=0.00261]


epoch=39, loss=0.035198182823329136


100%|██████████| 18/18 [00:02<00:00,  6.43batch/s, train_loss=0.00178]


epoch=40, loss=0.032995167044730024


100%|██████████| 18/18 [00:02<00:00,  6.66batch/s, train_loss=0.00141]


epoch=41, loss=0.02893742339878247


100%|██████████| 18/18 [00:02<00:00,  6.34batch/s, train_loss=0.00086]


epoch=42, loss=0.027550265181167374


100%|██████████| 18/18 [00:02<00:00,  6.74batch/s, train_loss=0.0024]  


epoch=43, loss=0.02713426171294574


100%|██████████| 18/18 [00:02<00:00,  6.37batch/s, train_loss=0.00306] 


epoch=44, loss=0.02690864181929621


100%|██████████| 18/18 [00:02<00:00,  6.01batch/s, train_loss=0.000834]


epoch=45, loss=0.024179839513425172


100%|██████████| 18/18 [00:03<00:00,  5.64batch/s, train_loss=0.00203] 


epoch=46, loss=0.025007542340919887


100%|██████████| 18/18 [00:03<00:00,  5.86batch/s, train_loss=0.00055] 


epoch=47, loss=0.023086379990495485


100%|██████████| 18/18 [00:02<00:00,  6.41batch/s, train_loss=0.00087] 


epoch=48, loss=0.02198525396194951


100%|██████████| 18/18 [00:02<00:00,  6.64batch/s, train_loss=0.000332]


epoch=49, loss=0.021445421184959082


100%|██████████| 18/18 [00:02<00:00,  6.27batch/s, train_loss=0.00106] 


epoch=50, loss=0.021233560326798212


100%|██████████| 18/18 [00:02<00:00,  6.80batch/s, train_loss=0.00135] 


epoch=51, loss=0.021185576904436634


100%|██████████| 18/18 [00:02<00:00,  6.33batch/s, train_loss=0.00108] 


epoch=52, loss=0.021081938380825112


100%|██████████| 18/18 [00:02<00:00,  6.69batch/s, train_loss=0.00188] 


epoch=53, loss=0.02029948357672527


100%|██████████| 18/18 [00:02<00:00,  6.15batch/s, train_loss=0.0016]  


epoch=54, loss=0.020090572945002855


100%|██████████| 18/18 [00:02<00:00,  6.69batch/s, train_loss=0.00279] 


epoch=55, loss=0.020339019324245124


100%|██████████| 18/18 [00:02<00:00,  6.35batch/s, train_loss=0.00215] 


epoch=56, loss=0.020892584268389077


100%|██████████| 18/18 [00:02<00:00,  6.65batch/s, train_loss=0.000573]


epoch=57, loss=0.019379515489113742


100%|██████████| 18/18 [00:02<00:00,  6.34batch/s, train_loss=0.000497]


epoch=58, loss=0.018053792283195876


100%|██████████| 18/18 [00:02<00:00,  6.57batch/s, train_loss=0.00024] 


epoch=59, loss=0.01903321313870878


100%|██████████| 18/18 [00:02<00:00,  6.36batch/s, train_loss=0.000197]


epoch=60, loss=0.018474427978154913


100%|██████████| 18/18 [00:02<00:00,  6.66batch/s, train_loss=0.000762]


epoch=61, loss=0.018629667972696235


100%|██████████| 18/18 [00:02<00:00,  6.13batch/s, train_loss=0.00117] 


epoch=62, loss=0.019082345618256204


100%|██████████| 18/18 [00:03<00:00,  5.96batch/s, train_loss=0.000786]


epoch=63, loss=0.01794840656580596


100%|██████████| 18/18 [00:03<00:00,  5.96batch/s, train_loss=0.000308]


epoch=64, loss=0.019204470225449265


100%|██████████| 18/18 [00:03<00:00,  4.96batch/s, train_loss=0.00106] 


epoch=65, loss=0.018938778393227475


100%|██████████| 18/18 [00:03<00:00,  5.58batch/s, train_loss=0.00304] 


epoch=66, loss=0.019969119825753674


100%|██████████| 18/18 [00:02<00:00,  6.14batch/s, train_loss=0.000721]


epoch=67, loss=0.018221944941015075


100%|██████████| 18/18 [00:03<00:00,  5.86batch/s, train_loss=0.000576]


epoch=68, loss=0.018032457382000724


100%|██████████| 18/18 [00:02<00:00,  6.15batch/s, train_loss=0.000164]


epoch=69, loss=0.017713539210758335


100%|██████████| 18/18 [00:03<00:00,  5.60batch/s, train_loss=0.000569]


epoch=70, loss=0.017902080375058898


100%|██████████| 18/18 [00:03<00:00,  5.43batch/s, train_loss=0.000941]


epoch=71, loss=0.018031290833806165


100%|██████████| 18/18 [00:03<00:00,  5.84batch/s, train_loss=0.00123] 


epoch=72, loss=0.01834413431122385


100%|██████████| 18/18 [00:02<00:00,  6.25batch/s, train_loss=0.00119] 


epoch=73, loss=0.017259263882349277


100%|██████████| 18/18 [00:03<00:00,  5.66batch/s, train_loss=0.00014] 


epoch=74, loss=0.01786349843313982


100%|██████████| 18/18 [00:03<00:00,  5.85batch/s, train_loss=0.00124] 


epoch=75, loss=0.018812811311976662


100%|██████████| 18/18 [00:03<00:00,  4.69batch/s, train_loss=0.00298] 


epoch=76, loss=0.019064778124977803


100%|██████████| 18/18 [00:03<00:00,  5.72batch/s, train_loss=0.00126] 


epoch=77, loss=0.017550762551611872


100%|██████████| 18/18 [00:03<00:00,  5.85batch/s, train_loss=0.00104] 


epoch=78, loss=0.01943461649017087


100%|██████████| 18/18 [00:03<00:00,  5.34batch/s, train_loss=0.00221] 


epoch=79, loss=0.01892996363393192


100%|██████████| 18/18 [00:03<00:00,  5.58batch/s, train_loss=0.000967]


epoch=80, loss=0.018128993886298142


100%|██████████| 18/18 [00:02<00:00,  6.08batch/s, train_loss=0.0031]  


epoch=81, loss=0.019275142720033382


100%|██████████| 18/18 [00:03<00:00,  5.87batch/s, train_loss=0.00207] 


epoch=82, loss=0.01847763595087775


100%|██████████| 18/18 [00:02<00:00,  6.14batch/s, train_loss=0.00128] 


epoch=83, loss=0.0183011861248263


100%|██████████| 18/18 [00:03<00:00,  5.07batch/s, train_loss=0.00181] 


epoch=84, loss=0.018271280254783303


100%|██████████| 18/18 [00:03<00:00,  5.87batch/s, train_loss=0.00133] 


epoch=85, loss=0.019077914934219987


100%|██████████| 18/18 [00:03<00:00,  5.83batch/s, train_loss=0.000294]


epoch=86, loss=0.01707346561936469


100%|██████████| 18/18 [00:02<00:00,  6.10batch/s, train_loss=0.00214] 


epoch=87, loss=0.01914774719188953


100%|██████████| 18/18 [00:03<00:00,  5.87batch/s, train_loss=0.00035] 


epoch=88, loss=0.018142500984771496


100%|██████████| 18/18 [00:03<00:00,  4.99batch/s, train_loss=0.000147]


epoch=89, loss=0.017427950210098568


100%|██████████| 18/18 [00:03<00:00,  5.84batch/s, train_loss=0.00226] 


epoch=90, loss=0.018679503691607507


100%|██████████| 18/18 [00:02<00:00,  6.27batch/s, train_loss=0.000461]


epoch=91, loss=0.01878380195308348


100%|██████████| 18/18 [00:03<00:00,  5.92batch/s, train_loss=0.000988]


epoch=92, loss=0.018246654516664046


100%|██████████| 18/18 [00:02<00:00,  6.06batch/s, train_loss=0.00175] 


epoch=93, loss=0.01956295191008469


100%|██████████| 18/18 [00:03<00:00,  5.91batch/s, train_loss=0.000739]


epoch=94, loss=0.018711874898651552


100%|██████████| 18/18 [00:03<00:00,  5.32batch/s, train_loss=0.000934]


epoch=95, loss=0.018399095772669235


100%|██████████| 18/18 [00:03<00:00,  5.77batch/s, train_loss=0.000483]


epoch=96, loss=0.01788058820906384


100%|██████████| 18/18 [00:02<00:00,  6.25batch/s, train_loss=0.00202] 


epoch=97, loss=0.018817319060194078


100%|██████████| 18/18 [00:03<00:00,  5.89batch/s, train_loss=0.00259] 


epoch=98, loss=0.01892592593051236


100%|██████████| 18/18 [00:02<00:00,  6.23batch/s, train_loss=0.00275] 


epoch=99, loss=0.02067312302877163


100%|██████████| 18/18 [00:03<00:00,  5.47batch/s, train_loss=0.000902]

epoch=100, loss=0.017178350812916098





In [8]:
from collections import namedtuple

from pandas import DataFrame

# One output record per original (user, item) annotation, with both rating
# dimensions replaced by the rounded model predictions.
Row = namedtuple("Row", "user_id item_id Valence Arousal")

# Encoder built over the combined frame, used only to map raw user/item ids to
# the integer indices fed to the models.
# NOTE(review): assumes this yields the same indices as the encoders built from
# valence_dataframe / arousal_dataframe (same user_id/item_id columns) — confirm.
data_encoder = DataEncoder(original_df=original_df)
dataframe_after_mf = []

with torch.no_grad():
    # NOTE(review): the models are not switched to eval mode here; if HistogramMF
    # uses dropout or batch norm, consider calling .eval() on both first.
    for annotation in original_df.itertuples(index=False):
        users = torch.LongTensor([data_encoder.get_encoded_user_id(original_id=annotation.user_id)])
        items = torch.LongTensor([data_encoder.get_encoded_item_id(original_id=annotation.item_id)])

        # Element 0 of each squeezed model output is treated as the predicted score.
        # NOTE(review): output layout is defined by HistogramMF — confirm index 0 is the rating.
        raw_valence = valence_model(users=users, items=items).squeeze()[0]
        raw_arousal = arousal_model(users=users, items=items).squeeze()[0]

        dataframe_after_mf.append(
            Row(
                user_id=annotation.user_id,
                item_id=annotation.item_id,
                Valence=torch.round(raw_valence).item(),
                Arousal=torch.round(raw_arousal).item(),
            )
        )

df_after_mf = DataFrame(dataframe_after_mf, columns=["user_id", "item_id", "Valence", "Arousal"])
df_after_mf.head()

Unnamed: 0,user_id,item_id,Valence,Arousal
0,6010bbc8e7ef4b21fa38f9c3a9754ef3,2,5.0,2.0
1,3c888e77b992ae3cd2adfe16774e23b9,2,2.0,3.0
2,2afd218c3aecb6828d2be327f8b9c46f,2,3.0,3.0
3,fd5b08ce362d855ca9152a894348130c,2,4.0,4.0
4,9c8073214a052e414811b76012df8847,2,2.0,2.0


In [9]:
# Both frames are built row-for-row from the same annotations with a default
# RangeIndex, so the comparison below is row-aligned.
# `mask` is True where BOTH rounded predictions reproduce the original rating.
mask = (original_df["Valence"] == df_after_mf["Valence"]) & (
    original_df["Arousal"] == df_after_mf["Arousal"]
)

# NOTE(review): despite its name, `changes` holds the matching rows ("hits"),
# not the rows whose ratings changed; original_df[~mask] would give those.
changes = original_df[mask].assign(
    **{"New Valence": df_after_mf.Valence, "New Arousal": df_after_mf.Arousal}
)
print(f"Number of hits: {len(changes)} / {len(original_df)}")
changes

Number of hits: 17345 / 17464


Unnamed: 0,user_id,item_id,Valence,Arousal,New Valence,New Arousal
0,6010bbc8e7ef4b21fa38f9c3a9754ef3,2,5,2,5.0,2.0
1,3c888e77b992ae3cd2adfe16774e23b9,2,2,3,2.0,3.0
2,2afd218c3aecb6828d2be327f8b9c46f,2,3,3,3.0,3.0
3,fd5b08ce362d855ca9152a894348130c,2,4,4,4.0,4.0
4,9c8073214a052e414811b76012df8847,2,2,2,2.0,2.0
...,...,...,...,...,...,...
17459,607f6e34a0b5923333f6b16d3a59cc98,2000,6,5,6.0,5.0
17460,78b5e9744073532cc376976b5fc6b2fc,2000,7,7,7.0,7.0
17461,7cecbffe1da5ae974952db6c13695afe,2000,4,5,4.0,5.0
17462,ed7ed76453bd846859f5e6b9149df276,2000,6,7,6.0,7.0
