In [21]:
import torch
import os
from pandas import read_csv
from torch.optim import SGD
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from config import DATA_DIR, MODELS_DIR
from src.loss import Loss
from src.model import HistogramMF
from src.runner import Runner
from src.create_dataset import create_dataset, create_histogram_features
from src.data_processor import DataProcessor
from src.data_encoder import DataEncoder

# Path of the DEAM per-rater, song-level static annotations CSV, rooted at
# DATA_DIR. The components are joined with "/" exactly as in the raw string.
DF_PATH = "/".join(
    [
        f"{DATA_DIR}",
        "DEAM",
        "annotations",
        "annotations per each rater",
        "song_level",
        "static_annotations_songs_1_2000_raw.csv",
    ]
)

In [22]:
# Per-rater annotation columns to load from the CSV, and the generic names the
# rest of the notebook uses for the two identifier columns.
columns = ["workerID", "SongId", "Valence", "Arousal"]
column_renames = {"workerID": "user_id", "SongId": "item_id"}

# `usecols` only filters columns; the resulting column order follows the file,
# not the list. Select `columns` explicitly and rename by name so that labels
# can never be attached to the wrong data, and so that later positional
# unpacking of rows (user_id, item_id, Valence, Arousal) stays valid.
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)[columns].rename(
    columns=column_renames
)
original_df.head()

Unnamed: 0,user_id,item_id,Valence,Arousal
0,6010bbc8e7ef4b21fa38f9c3a9754ef3,2,5,2
1,3c888e77b992ae3cd2adfe16774e23b9,2,2,3
2,2afd218c3aecb6828d2be327f8b9c46f,2,3,3
3,fd5b08ce362d855ca9152a894348130c,2,4,4
4,9c8073214a052e414811b76012df8847,2,2,2


In [23]:
# --- Valence: fit (or restore) a HistogramMF model on the valence ratings ---

# Keep only the valence score and expose it under the generic "rating" name.
valence_dataframe = original_df[["user_id", "item_id", "Valence"]].rename(
    columns={"Valence": "rating"}
)
# Return value is discarded; presumably this works on the frame in place
# (TODO confirm against src.create_dataset).
create_histogram_features(data_frame=valence_dataframe)

data_encoder = DataEncoder(original_df=valence_dataframe)
data_processor = DataProcessor(original_df=valence_dataframe)

# Model dimensions and rating bounds are taken from the data itself.
n_users = valence_dataframe["user_id"].nunique()
n_items = valence_dataframe["item_id"].nunique()

min_rating = min(valence_dataframe["rating"].values)
max_rating = max(valence_dataframe["rating"].values)

valence_model = HistogramMF(
    n_users=n_users,
    n_items=n_items,
    data_encoder=data_encoder,
    data_processor=data_processor,
    min_rating=min_rating,
    max_rating=max_rating,
)

# Train only when no saved weights exist; otherwise reuse the checkpoint.
valence_checkpoint = f"{MODELS_DIR}/DEAM/valence.pt"
if not os.path.exists(valence_checkpoint):
    epochs = 100

    criterion = Loss()
    optimizer = SGD(valence_model.parameters(), lr=5, weight_decay=1e-7)
    runner = Runner(model=valence_model, criterion=criterion, optimizer=optimizer)

    train_set = create_dataset(data_encoder=data_encoder)
    train_load = DataLoader(train_set, batch_size=1000, shuffle=True)

    # Per-epoch loss is printed and the writer is handed to the runner.
    with SummaryWriter("runs/DEAM/dev/valence") as writer:
        for epoch in range(epochs):
            epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
            print(f"epoch={epoch + 1}, loss={epoch_loss}")

    torch.save(valence_model.state_dict(), valence_checkpoint)
else:
    valence_model.load_state_dict(torch.load(valence_checkpoint))

  0%|          | 0/17464 [00:00<?, ?it/s]

  0%|          | 0/17464 [00:00<?, ?it/s]

100%|██████████| 18/18 [00:06<00:00,  2.86batch/s, train_loss=0.325]


epoch=1, loss=6.521853092851309


100%|██████████| 18/18 [00:04<00:00,  3.98batch/s, train_loss=0.267]


epoch=2, loss=4.776729629516602


100%|██████████| 18/18 [00:04<00:00,  3.95batch/s, train_loss=0.22] 


epoch=3, loss=4.199384980431919


100%|██████████| 18/18 [00:04<00:00,  3.91batch/s, train_loss=0.186]


epoch=4, loss=3.6083344463479925


100%|██████████| 18/18 [00:04<00:00,  3.89batch/s, train_loss=0.161]


epoch=5, loss=3.0649780572036214


100%|██████████| 18/18 [00:04<00:00,  3.80batch/s, train_loss=0.159]


epoch=6, loss=2.6652764484800144


100%|██████████| 18/18 [00:04<00:00,  3.97batch/s, train_loss=0.129]


epoch=7, loss=2.324542478166778


100%|██████████| 18/18 [00:04<00:00,  3.97batch/s, train_loss=0.109]


epoch=8, loss=2.033963793721693


100%|██████████| 18/18 [00:04<00:00,  3.86batch/s, train_loss=0.0937]


epoch=9, loss=1.7261992470971468


100%|██████████| 18/18 [00:04<00:00,  4.08batch/s, train_loss=0.0883]


epoch=10, loss=1.459153352540115


100%|██████████| 18/18 [00:04<00:00,  3.91batch/s, train_loss=0.0757]


epoch=11, loss=1.21397439851432


100%|██████████| 18/18 [00:04<00:00,  3.83batch/s, train_loss=0.0579]


epoch=12, loss=1.0169709300994874


100%|██████████| 18/18 [00:04<00:00,  3.91batch/s, train_loss=0.0492]


epoch=13, loss=0.8405246078228129


100%|██████████| 18/18 [00:04<00:00,  3.86batch/s, train_loss=0.0404]


epoch=14, loss=0.698511457509008


100%|██████████| 18/18 [00:04<00:00,  4.19batch/s, train_loss=0.0345]


epoch=15, loss=0.5773671921697158


100%|██████████| 18/18 [00:04<00:00,  3.86batch/s, train_loss=0.029] 


epoch=16, loss=0.48673683436163534


100%|██████████| 18/18 [00:04<00:00,  4.14batch/s, train_loss=0.0271]


epoch=17, loss=0.4172325006024591


100%|██████████| 18/18 [00:04<00:00,  3.93batch/s, train_loss=0.0195]


epoch=18, loss=0.3666163476582231


100%|██████████| 18/18 [00:04<00:00,  3.98batch/s, train_loss=0.0136]


epoch=19, loss=0.3190625543676574


100%|██████████| 18/18 [00:04<00:00,  3.87batch/s, train_loss=0.0138]


epoch=20, loss=0.2813132001695962


100%|██████████| 18/18 [00:04<00:00,  3.95batch/s, train_loss=0.0161]


epoch=21, loss=0.2508610283177475


100%|██████████| 18/18 [00:04<00:00,  3.77batch/s, train_loss=0.0164]


epoch=22, loss=0.22307591611763525


100%|██████████| 18/18 [00:04<00:00,  4.02batch/s, train_loss=0.0105] 


epoch=23, loss=0.19503997856173022


100%|██████████| 18/18 [00:04<00:00,  3.72batch/s, train_loss=0.00844]


epoch=24, loss=0.17326280653476717


100%|██████████| 18/18 [00:04<00:00,  3.89batch/s, train_loss=0.0096] 


epoch=25, loss=0.15866537060408759


100%|██████████| 18/18 [00:04<00:00,  3.94batch/s, train_loss=0.00756]


epoch=26, loss=0.13817068136560506


100%|██████████| 18/18 [00:04<00:00,  3.98batch/s, train_loss=0.00775]


epoch=27, loss=0.1239402412546092


100%|██████████| 18/18 [00:04<00:00,  3.90batch/s, train_loss=0.00712]


epoch=28, loss=0.1110527341530241


100%|██████████| 18/18 [00:05<00:00,  3.59batch/s, train_loss=0.00554]


epoch=29, loss=0.10217976634255771


100%|██████████| 18/18 [00:05<00:00,  3.53batch/s, train_loss=0.00449]


epoch=30, loss=0.09314877112158415


100%|██████████| 18/18 [00:04<00:00,  4.13batch/s, train_loss=0.00526]


epoch=31, loss=0.08695946130670348


100%|██████████| 18/18 [00:04<00:00,  3.97batch/s, train_loss=0.00329]


epoch=32, loss=0.07624106483212834


100%|██████████| 18/18 [00:04<00:00,  4.07batch/s, train_loss=0.00353]


epoch=33, loss=0.0699313923141052


100%|██████████| 18/18 [00:04<00:00,  3.77batch/s, train_loss=0.00391]


epoch=34, loss=0.06313802824965839


100%|██████████| 18/18 [00:04<00:00,  3.87batch/s, train_loss=0.00475]


epoch=35, loss=0.058460074597391574


100%|██████████| 18/18 [00:04<00:00,  3.80batch/s, train_loss=0.00459]


epoch=36, loss=0.05205999017583913


100%|██████████| 18/18 [00:04<00:00,  3.89batch/s, train_loss=0.00226]


epoch=37, loss=0.0485165479368177


100%|██████████| 18/18 [00:04<00:00,  3.86batch/s, train_loss=0.00242]


epoch=38, loss=0.046665991388518235


100%|██████████| 18/18 [00:04<00:00,  3.98batch/s, train_loss=0.00329]


epoch=39, loss=0.044168888135203


100%|██████████| 18/18 [00:04<00:00,  3.89batch/s, train_loss=0.00311]


epoch=40, loss=0.03978913072676495


100%|██████████| 18/18 [00:04<00:00,  4.01batch/s, train_loss=0.00246]


epoch=41, loss=0.03608081063114364


100%|██████████| 18/18 [00:05<00:00,  3.58batch/s, train_loss=0.000818]


epoch=42, loss=0.03214301371111952


100%|██████████| 18/18 [00:04<00:00,  3.85batch/s, train_loss=0.00139] 


epoch=43, loss=0.03280105018615723


100%|██████████| 18/18 [00:04<00:00,  3.91batch/s, train_loss=0.0018]  


epoch=44, loss=0.03156157575188012


100%|██████████| 18/18 [00:04<00:00,  3.90batch/s, train_loss=0.000618]


epoch=45, loss=0.029922672630384048


100%|██████████| 18/18 [00:04<00:00,  3.91batch/s, train_loss=0.00369] 


epoch=46, loss=0.028370108672257123


100%|██████████| 18/18 [00:04<00:00,  3.79batch/s, train_loss=0.00114]


epoch=47, loss=0.026008306745825143


100%|██████████| 18/18 [00:04<00:00,  3.66batch/s, train_loss=0.000845]


epoch=48, loss=0.026198463672707818


100%|██████████| 18/18 [00:04<00:00,  3.87batch/s, train_loss=0.00087] 


epoch=49, loss=0.02491024161566948


100%|██████████| 18/18 [00:04<00:00,  3.87batch/s, train_loss=0.000698]


epoch=50, loss=0.023526768446482462


100%|██████████| 18/18 [00:04<00:00,  4.18batch/s, train_loss=0.00156] 


epoch=51, loss=0.02366144766478703


100%|██████████| 18/18 [00:04<00:00,  3.95batch/s, train_loss=0.00112] 


epoch=52, loss=0.022509636895409947


100%|██████████| 18/18 [00:04<00:00,  4.03batch/s, train_loss=0.00094] 


epoch=53, loss=0.022044451294787996


100%|██████████| 18/18 [00:04<00:00,  3.94batch/s, train_loss=0.00147] 


epoch=54, loss=0.02257246486894016


100%|██████████| 18/18 [00:04<00:00,  3.90batch/s, train_loss=0.000432]


epoch=55, loss=0.022417290372838232


100%|██████████| 18/18 [00:04<00:00,  4.00batch/s, train_loss=0.00215] 


epoch=56, loss=0.02208857688102229


100%|██████████| 18/18 [00:04<00:00,  4.09batch/s, train_loss=0.000246]


epoch=57, loss=0.022092115905392787


100%|██████████| 18/18 [00:04<00:00,  3.81batch/s, train_loss=0.00096] 


epoch=58, loss=0.021407837559950762


100%|██████████| 18/18 [00:04<00:00,  3.89batch/s, train_loss=0.0023]  


epoch=59, loss=0.02167872947248919


100%|██████████| 18/18 [00:04<00:00,  3.69batch/s, train_loss=0.00413] 


epoch=60, loss=0.022750064570328285


100%|██████████| 18/18 [00:04<00:00,  3.80batch/s, train_loss=0.0023]  


epoch=61, loss=0.02181090334571641


100%|██████████| 18/18 [00:04<00:00,  3.79batch/s, train_loss=0.00227] 


epoch=62, loss=0.02155881706805065


100%|██████████| 18/18 [00:04<00:00,  3.83batch/s, train_loss=0.00048] 


epoch=63, loss=0.02015174902384651


100%|██████████| 18/18 [00:04<00:00,  3.84batch/s, train_loss=7.08e-5] 


epoch=64, loss=0.01985748724632993


100%|██████████| 18/18 [00:04<00:00,  3.86batch/s, train_loss=0.0018]  


epoch=65, loss=0.021430513803301184


100%|██████████| 18/18 [00:04<00:00,  3.89batch/s, train_loss=0.00051]


epoch=66, loss=0.020634151709747726


100%|██████████| 18/18 [00:04<00:00,  4.21batch/s, train_loss=0.00163] 


epoch=67, loss=0.02066695100582879


100%|██████████| 18/18 [00:04<00:00,  3.99batch/s, train_loss=0.00163] 


epoch=68, loss=0.021364695570592226


100%|██████████| 18/18 [00:04<00:00,  4.15batch/s, train_loss=0.00205] 


epoch=69, loss=0.020954180239603437


100%|██████████| 18/18 [00:04<00:00,  3.90batch/s, train_loss=0.00133] 


epoch=70, loss=0.021492125515280094


100%|██████████| 18/18 [00:04<00:00,  4.07batch/s, train_loss=0.00138] 


epoch=71, loss=0.021624359149357366


100%|██████████| 18/18 [00:04<00:00,  4.00batch/s, train_loss=0.00157] 


epoch=72, loss=0.021646074115202342


100%|██████████| 18/18 [00:04<00:00,  4.13batch/s, train_loss=0.00141] 


epoch=73, loss=0.02094285467369803


100%|██████████| 18/18 [00:04<00:00,  3.89batch/s, train_loss=0.00162] 


epoch=74, loss=0.02199914417903999


100%|██████████| 18/18 [00:04<00:00,  3.97batch/s, train_loss=0.000528]


epoch=75, loss=0.02096446021492111


100%|██████████| 18/18 [00:04<00:00,  3.91batch/s, train_loss=0.0011]  


epoch=76, loss=0.0197589353497686


100%|██████████| 18/18 [00:04<00:00,  4.08batch/s, train_loss=0.00114] 


epoch=77, loss=0.021812060509776243


100%|██████████| 18/18 [00:04<00:00,  3.99batch/s, train_loss=0.00124] 


epoch=78, loss=0.020828154524852493


100%|██████████| 18/18 [00:04<00:00,  3.98batch/s, train_loss=0.000628]


epoch=79, loss=0.021225404536929626


100%|██████████| 18/18 [00:04<00:00,  4.06batch/s, train_loss=0.00158] 


epoch=80, loss=0.021683403436479898


100%|██████████| 18/18 [00:04<00:00,  4.04batch/s, train_loss=0.00022]


epoch=81, loss=0.020821052888857907


100%|██████████| 18/18 [00:04<00:00,  4.04batch/s, train_loss=0.00103] 


epoch=82, loss=0.021445686650687253


100%|██████████| 18/18 [00:04<00:00,  4.11batch/s, train_loss=0.000806]


epoch=83, loss=0.021395423811571352


100%|██████████| 18/18 [00:04<00:00,  3.84batch/s, train_loss=0.00247] 


epoch=84, loss=0.02205955018360039


100%|██████████| 18/18 [00:04<00:00,  3.93batch/s, train_loss=0.00204] 


epoch=85, loss=0.021189972958688078


100%|██████████| 18/18 [00:04<00:00,  3.91batch/s, train_loss=0.00124] 


epoch=86, loss=0.021080183519371624


100%|██████████| 18/18 [00:04<00:00,  4.08batch/s, train_loss=0.00279] 


epoch=87, loss=0.020512428500528996


100%|██████████| 18/18 [00:04<00:00,  3.78batch/s, train_loss=0.000439]


epoch=88, loss=0.020223835183114842


100%|██████████| 18/18 [00:04<00:00,  4.05batch/s, train_loss=0.00296] 


epoch=89, loss=0.022856281567236472


100%|██████████| 18/18 [00:04<00:00,  4.00batch/s, train_loss=0.00274] 


epoch=90, loss=0.022178628646094226


100%|██████████| 18/18 [00:04<00:00,  4.04batch/s, train_loss=0.00222] 


epoch=91, loss=0.02223058477146872


100%|██████████| 18/18 [00:04<00:00,  4.02batch/s, train_loss=0.00107] 


epoch=92, loss=0.022048661013101715


100%|██████████| 18/18 [00:04<00:00,  4.09batch/s, train_loss=0.000864]


epoch=93, loss=0.022130224667232617


100%|██████████| 18/18 [00:04<00:00,  3.97batch/s, train_loss=0.00256] 


epoch=94, loss=0.022802576079450805


100%|██████████| 18/18 [00:04<00:00,  4.15batch/s, train_loss=0.00189] 


epoch=95, loss=0.022131025454093668


100%|██████████| 18/18 [00:04<00:00,  3.95batch/s, train_loss=0.00132] 


epoch=96, loss=0.021535125089102783


100%|██████████| 18/18 [00:04<00:00,  3.96batch/s, train_loss=0.00188] 


epoch=97, loss=0.022129894928685553


100%|██████████| 18/18 [00:04<00:00,  4.01batch/s, train_loss=0.000971]


epoch=98, loss=0.021398619867090525


100%|██████████| 18/18 [00:04<00:00,  4.02batch/s, train_loss=0.00131] 


epoch=99, loss=0.02111411698213939


100%|██████████| 18/18 [00:04<00:00,  3.89batch/s, train_loss=0.00189] 

epoch=100, loss=0.021990848892721637





In [24]:
# --- Arousal: fit (or restore) a HistogramMF model on the arousal ratings ---

# Keep only the arousal score and expose it under the generic "rating" name.
arousal_dataframe = original_df[["user_id", "item_id", "Arousal"]].copy()
arousal_dataframe.columns = ["user_id", "item_id", "rating"]
# Return value is discarded; presumably this works on the frame in place
# (TODO confirm against src.create_dataset).
create_histogram_features(data_frame=arousal_dataframe)

data_encoder = DataEncoder(original_df=arousal_dataframe)
data_processor = DataProcessor(original_df=arousal_dataframe)

# Model dimensions and rating bounds are taken from the data itself.
n_users = arousal_dataframe.user_id.nunique()
n_items = arousal_dataframe.item_id.nunique()

min_rating = min(arousal_dataframe.rating.values)
max_rating = max(arousal_dataframe.rating.values)

arousal_model = HistogramMF(
    n_users=n_users,
    n_items=n_items,
    data_encoder=data_encoder,
    data_processor=data_processor,
    min_rating=min_rating,
    max_rating=max_rating,
)

# Reuse saved weights when available; otherwise train and save them.
arousal_checkpoint = f"{MODELS_DIR}/DEAM/arousal.pt"
if os.path.exists(arousal_checkpoint):
    # The arousal checkpoint must be restored into the arousal model. Loading
    # it into `valence_model` would overwrite the valence weights and leave
    # `arousal_model` at its initial weights.
    arousal_model.load_state_dict(torch.load(arousal_checkpoint))
else:
    epochs = 100

    criterion = Loss()
    optimizer = SGD(arousal_model.parameters(), lr=5, weight_decay=1e-7)
    runner = Runner(model=arousal_model, criterion=criterion, optimizer=optimizer)

    train_set = create_dataset(data_encoder=data_encoder)
    train_load = DataLoader(train_set, batch_size=1000, shuffle=True)

    # Per-epoch loss is printed and the writer is handed to the runner.
    with SummaryWriter("runs/DEAM/dev/arousal") as writer:
        for epoch in range(epochs):
            epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
            print(f"epoch={epoch + 1}, loss={epoch_loss}")

    torch.save(arousal_model.state_dict(), arousal_checkpoint)

  0%|          | 0/17464 [00:00<?, ?it/s]

  0%|          | 0/17464 [00:00<?, ?it/s]

100%|██████████| 18/18 [00:05<00:00,  3.32batch/s, train_loss=0.317]


epoch=1, loss=6.653590763486664


100%|██████████| 18/18 [00:04<00:00,  3.68batch/s, train_loss=0.259]


epoch=2, loss=4.875762284245984


100%|██████████| 18/18 [00:04<00:00,  3.99batch/s, train_loss=0.222]


epoch=3, loss=4.2069693765311404


100%|██████████| 18/18 [00:05<00:00,  3.51batch/s, train_loss=0.192]


epoch=4, loss=3.4802345990148087


100%|██████████| 18/18 [00:05<00:00,  3.04batch/s, train_loss=0.158]


epoch=5, loss=3.0146817446741565


100%|██████████| 18/18 [00:05<00:00,  3.31batch/s, train_loss=0.143]


epoch=6, loss=2.6893231701028757


100%|██████████| 18/18 [00:04<00:00,  3.87batch/s, train_loss=0.129]


epoch=7, loss=2.3916961538380592


100%|██████████| 18/18 [00:05<00:00,  3.59batch/s, train_loss=0.107]


epoch=8, loss=2.0895474605560302


100%|██████████| 18/18 [00:04<00:00,  3.89batch/s, train_loss=0.107] 


epoch=9, loss=1.8123756814496268


100%|██████████| 18/18 [00:05<00:00,  3.47batch/s, train_loss=0.0821]


epoch=10, loss=1.5375270250254662


100%|██████████| 18/18 [00:05<00:00,  3.40batch/s, train_loss=0.075] 


epoch=11, loss=1.286348326255535


100%|██████████| 18/18 [00:05<00:00,  3.11batch/s, train_loss=0.0553]


epoch=12, loss=1.0726481052924846


100%|██████████| 18/18 [00:06<00:00,  2.96batch/s, train_loss=0.0433]


epoch=13, loss=0.861028903073278


100%|██████████| 18/18 [00:05<00:00,  3.43batch/s, train_loss=0.0402]


epoch=14, loss=0.7111846896204456


100%|██████████| 18/18 [00:06<00:00,  2.93batch/s, train_loss=0.0398]


epoch=15, loss=0.5941096956318823


100%|██████████| 18/18 [00:05<00:00,  3.48batch/s, train_loss=0.025] 


epoch=16, loss=0.4859649959761521


100%|██████████| 18/18 [00:05<00:00,  3.48batch/s, train_loss=0.0259]


epoch=17, loss=0.4240261362667741


100%|██████████| 18/18 [00:05<00:00,  3.06batch/s, train_loss=0.018] 


epoch=18, loss=0.3525913274699245


100%|██████████| 18/18 [00:04<00:00,  3.65batch/s, train_loss=0.0182]


epoch=19, loss=0.30556839069826847


100%|██████████| 18/18 [00:04<00:00,  3.70batch/s, train_loss=0.013] 


epoch=20, loss=0.26426265335905147


100%|██████████| 18/18 [00:06<00:00,  2.95batch/s, train_loss=0.00965]


epoch=21, loss=0.22901904919229704


100%|██████████| 18/18 [00:05<00:00,  3.13batch/s, train_loss=0.0137]


epoch=22, loss=0.2061722755596556


100%|██████████| 18/18 [00:05<00:00,  3.12batch/s, train_loss=0.0112] 


epoch=23, loss=0.18733901139785503


100%|██████████| 18/18 [00:05<00:00,  3.22batch/s, train_loss=0.0105] 


epoch=24, loss=0.16793370092326193


100%|██████████| 18/18 [00:05<00:00,  3.04batch/s, train_loss=0.00578]


epoch=25, loss=0.14603691434038094


100%|██████████| 18/18 [00:05<00:00,  3.24batch/s, train_loss=0.00978]


epoch=26, loss=0.13494262599122936


100%|██████████| 18/18 [00:05<00:00,  3.02batch/s, train_loss=0.00626]


epoch=27, loss=0.11873746588723413


100%|██████████| 18/18 [00:05<00:00,  3.33batch/s, train_loss=0.00436]


epoch=28, loss=0.10259792983943018


100%|██████████| 18/18 [00:04<00:00,  3.97batch/s, train_loss=0.00313]


epoch=29, loss=0.09058991734940432


100%|██████████| 18/18 [00:04<00:00,  3.67batch/s, train_loss=0.00425]


epoch=30, loss=0.08207690201545585


100%|██████████| 18/18 [00:04<00:00,  3.69batch/s, train_loss=0.00462]


epoch=31, loss=0.07239723721043817


100%|██████████| 18/18 [00:04<00:00,  3.89batch/s, train_loss=0.00792]


epoch=32, loss=0.0685143592645382


100%|██████████| 18/18 [00:03<00:00,  4.77batch/s, train_loss=0.00203]


epoch=33, loss=0.056750008130895674


100%|██████████| 18/18 [00:03<00:00,  5.10batch/s, train_loss=0.00484]


epoch=34, loss=0.05419179551354769


100%|██████████| 18/18 [00:03<00:00,  4.94batch/s, train_loss=0.000668]


epoch=35, loss=0.044306838238547586


100%|██████████| 18/18 [00:03<00:00,  5.28batch/s, train_loss=0.00243]


epoch=36, loss=0.040157288284137324


100%|██████████| 18/18 [00:03<00:00,  4.82batch/s, train_loss=0.00261]


epoch=37, loss=0.036540877502540065


100%|██████████| 18/18 [00:03<00:00,  5.12batch/s, train_loss=0.00122]


epoch=38, loss=0.03448344397236561


100%|██████████| 18/18 [00:03<00:00,  4.98batch/s, train_loss=0.0013]  


epoch=39, loss=0.03230363645841335


100%|██████████| 18/18 [00:04<00:00,  4.28batch/s, train_loss=0.0023]  


epoch=40, loss=0.03130907623110146


100%|██████████| 18/18 [00:04<00:00,  4.20batch/s, train_loss=0.00239]


epoch=41, loss=0.030546304949398695


100%|██████████| 18/18 [00:04<00:00,  4.34batch/s, train_loss=0.00217]


epoch=42, loss=0.02903238309251851


100%|██████████| 18/18 [00:03<00:00,  5.08batch/s, train_loss=0.00158] 


epoch=43, loss=0.02762838407734345


100%|██████████| 18/18 [00:03<00:00,  4.74batch/s, train_loss=0.00124]


epoch=44, loss=0.025470956724265525


100%|██████████| 18/18 [00:03<00:00,  4.81batch/s, train_loss=7.45e-5] 


epoch=45, loss=0.023982965813692792


100%|██████████| 18/18 [00:03<00:00,  4.85batch/s, train_loss=0.0018]  


epoch=46, loss=0.02458366893797085


100%|██████████| 18/18 [00:03<00:00,  4.78batch/s, train_loss=0.0019] 


epoch=47, loss=0.023890199976748433


100%|██████████| 18/18 [00:03<00:00,  5.00batch/s, train_loss=0.00104] 


epoch=48, loss=0.022584085332936252


100%|██████████| 18/18 [00:03<00:00,  4.62batch/s, train_loss=0.000272]


epoch=49, loss=0.021129438413885128


100%|██████████| 18/18 [00:03<00:00,  5.42batch/s, train_loss=0.0022]  


epoch=50, loss=0.022042858941801663


100%|██████████| 18/18 [00:03<00:00,  5.38batch/s, train_loss=0.00215] 


epoch=51, loss=0.022029137906329384


100%|██████████| 18/18 [00:03<00:00,  4.91batch/s, train_loss=0.00116] 


epoch=52, loss=0.02130525616325181


100%|██████████| 18/18 [00:03<00:00,  5.54batch/s, train_loss=0.000592]


epoch=53, loss=0.020847168047366475


100%|██████████| 18/18 [00:03<00:00,  5.46batch/s, train_loss=0.00129] 


epoch=54, loss=0.02142924555108465


100%|██████████| 18/18 [00:03<00:00,  5.06batch/s, train_loss=0.00162] 


epoch=55, loss=0.021651024961266023


100%|██████████| 18/18 [00:03<00:00,  5.27batch/s, train_loss=0.00135] 


epoch=56, loss=0.020914725355033213


100%|██████████| 18/18 [00:03<00:00,  5.60batch/s, train_loss=0.00142] 


epoch=57, loss=0.02086870620168489


100%|██████████| 18/18 [00:03<00:00,  5.03batch/s, train_loss=0.00214] 


epoch=58, loss=0.020798911952766883


100%|██████████| 18/18 [00:03<00:00,  5.27batch/s, train_loss=0.000921]


epoch=59, loss=0.02039702874730373


100%|██████████| 18/18 [00:03<00:00,  5.55batch/s, train_loss=0.00246] 


epoch=60, loss=0.020446650082695073


100%|██████████| 18/18 [00:03<00:00,  5.59batch/s, train_loss=0.00193] 


epoch=61, loss=0.019532819146740026


100%|██████████| 18/18 [00:03<00:00,  4.82batch/s, train_loss=0.00112] 


epoch=62, loss=0.019824865964979958


100%|██████████| 18/18 [00:03<00:00,  5.49batch/s, train_loss=0.00105] 


epoch=63, loss=0.0185151201959314


100%|██████████| 18/18 [00:03<00:00,  5.32batch/s, train_loss=0.00199] 


epoch=64, loss=0.019460120041822564


100%|██████████| 18/18 [00:03<00:00,  5.02batch/s, train_loss=0.00135] 


epoch=65, loss=0.01924190310157579


100%|██████████| 18/18 [00:03<00:00,  5.43batch/s, train_loss=0.000108]


epoch=66, loss=0.019107435373460938


100%|██████████| 18/18 [00:03<00:00,  5.47batch/s, train_loss=0.00117] 


epoch=67, loss=0.019289774070525997


100%|██████████| 18/18 [00:03<00:00,  5.61batch/s, train_loss=0.000822]


epoch=68, loss=0.01836721531301737


100%|██████████| 18/18 [00:03<00:00,  5.60batch/s, train_loss=0.000987]


epoch=69, loss=0.01946514279472417


100%|██████████| 18/18 [00:03<00:00,  5.51batch/s, train_loss=0.00231] 


epoch=70, loss=0.019142309229949424


100%|██████████| 18/18 [00:03<00:00,  5.70batch/s, train_loss=0.00246] 


epoch=71, loss=0.018988157063208777


100%|██████████| 18/18 [00:03<00:00,  5.44batch/s, train_loss=0.00132] 


epoch=72, loss=0.019319955806280008


100%|██████████| 18/18 [00:03<00:00,  5.45batch/s, train_loss=5.69e-5] 


epoch=73, loss=0.017990240328784646


100%|██████████| 18/18 [00:03<00:00,  4.62batch/s, train_loss=0.00228] 


epoch=74, loss=0.019895005854039353


100%|██████████| 18/18 [00:03<00:00,  5.44batch/s, train_loss=0.00202] 


epoch=75, loss=0.0191118312858302


100%|██████████| 18/18 [00:03<00:00,  5.53batch/s, train_loss=0.00167] 


epoch=76, loss=0.019140082791447636


100%|██████████| 18/18 [00:03<00:00,  5.54batch/s, train_loss=0.000708]


epoch=77, loss=0.01868549384690564


100%|██████████| 18/18 [00:03<00:00,  5.56batch/s, train_loss=0.00231] 


epoch=78, loss=0.02023201712041066


100%|██████████| 18/18 [00:03<00:00,  5.46batch/s, train_loss=0.00049] 


epoch=79, loss=0.01808939960290646


100%|██████████| 18/18 [00:03<00:00,  5.38batch/s, train_loss=0.00135] 


epoch=80, loss=0.01980272831587956


100%|██████████| 18/18 [00:03<00:00,  5.03batch/s, train_loss=0.00181] 


epoch=81, loss=0.019272557081847358


100%|██████████| 18/18 [00:03<00:00,  5.29batch/s, train_loss=0.000776]


epoch=82, loss=0.019927310186213455


100%|██████████| 18/18 [00:03<00:00,  5.64batch/s, train_loss=0.000409]


epoch=83, loss=0.01923749831344547


100%|██████████| 18/18 [00:03<00:00,  5.40batch/s, train_loss=0.00239] 


epoch=84, loss=0.019740279500854427


100%|██████████| 18/18 [00:03<00:00,  5.11batch/s, train_loss=0.00143] 


epoch=85, loss=0.019442916704919835


100%|██████████| 18/18 [00:03<00:00,  4.80batch/s, train_loss=7.34e-5] 


epoch=86, loss=0.019137741559024514


100%|██████████| 18/18 [00:03<00:00,  4.91batch/s, train_loss=0.00422] 


epoch=87, loss=0.020768778474423395


100%|██████████| 18/18 [00:03<00:00,  5.55batch/s, train_loss=0.00109] 


epoch=88, loss=0.018316379101625806


100%|██████████| 18/18 [00:03<00:00,  5.61batch/s, train_loss=0.00023] 


epoch=89, loss=0.0186156763200616


100%|██████████| 18/18 [00:03<00:00,  4.90batch/s, train_loss=0.000929]


epoch=90, loss=0.018793911740697666


100%|██████████| 18/18 [00:03<00:00,  5.52batch/s, train_loss=0.00245] 


epoch=91, loss=0.01958235697499637


100%|██████████| 18/18 [00:03<00:00,  5.55batch/s, train_loss=0.000522]


epoch=92, loss=0.018560167475764094


100%|██████████| 18/18 [00:03<00:00,  5.15batch/s, train_loss=0.000525]


epoch=93, loss=0.018281301565211395


100%|██████████| 18/18 [00:03<00:00,  5.28batch/s, train_loss=0.00193] 


epoch=94, loss=0.018411687458897458


100%|██████████| 18/18 [00:03<00:00,  5.30batch/s, train_loss=0.00252] 


epoch=95, loss=0.018444745976349403


100%|██████████| 18/18 [00:03<00:00,  5.26batch/s, train_loss=0.00162] 


epoch=96, loss=0.01972683303006764


100%|██████████| 18/18 [00:03<00:00,  5.11batch/s, train_loss=0.00199] 


epoch=97, loss=0.019145616908011764


100%|██████████| 18/18 [00:03<00:00,  5.19batch/s, train_loss=0.00179] 


epoch=98, loss=0.02004924507696053


100%|██████████| 18/18 [00:03<00:00,  5.51batch/s, train_loss=0.00189] 


epoch=99, loss=0.018742961042913894


100%|██████████| 18/18 [00:03<00:00,  5.39batch/s, train_loss=4.42e-5] 

epoch=100, loss=0.017887878675848762





In [25]:
from pandas import DataFrame
from collections import namedtuple

# One record per original annotation, holding the models' rounded outputs.
Row = namedtuple("Row", ["user_id", "item_id", "Valence", "Arousal"])

data_encoder = DataEncoder(original_df=original_df)
dataframe_after_mf = []

# Inference only, so gradient tracking is disabled.
# NOTE(review): the models are not switched to eval mode here; confirm that
# HistogramMF has no train-time-only layers (dropout, batch norm).
with torch.no_grad():
    for user_id, item_id, valence, arousal in original_df.itertuples(index=False, name=None):
        # Translate the raw ids into the encoder's ids, wrapped as
        # single-element LongTensors for the models.
        users = torch.LongTensor([data_encoder.get_encoded_user_id(original_id=user_id)])
        items = torch.LongTensor([data_encoder.get_encoded_item_id(original_id=item_id)])

        # First element of the squeezed output is used as the prediction.
        valence_score = valence_model(users=users, items=items).squeeze()[0]
        arousal_score = arousal_model(users=users, items=items).squeeze()[0]

        dataframe_after_mf.append(
            Row(
                user_id=user_id,
                item_id=item_id,
                Valence=torch.round(valence_score).item(),
                Arousal=torch.round(arousal_score).item(),
            )
        )

df_after_mf = DataFrame(dataframe_after_mf, columns=list(Row._fields))
df_after_mf.head()

Unnamed: 0,user_id,item_id,Valence,Arousal
0,6010bbc8e7ef4b21fa38f9c3a9754ef3,2,5.0,2.0
1,3c888e77b992ae3cd2adfe16774e23b9,2,2.0,3.0
2,2afd218c3aecb6828d2be327f8b9c46f,2,3.0,3.0
3,fd5b08ce362d855ca9152a894348130c,2,4.0,4.0
4,9c8073214a052e414811b76012df8847,2,2.0,2.0


In [26]:
# A "hit" is a row where the rounded model output reproduces BOTH of the
# original annotations.
valence_matches = original_df["Valence"] == df_after_mf["Valence"]
arousal_matches = original_df["Arousal"] == df_after_mf["Arousal"]
mask = valence_matches & arousal_matches

# Despite its name, `changes` holds the matching rows, with the predicted
# values attached (aligned on the index) for side-by-side inspection.
changes = original_df[mask].assign(
    **{"New Valence": df_after_mf.Valence, "New Arousal": df_after_mf.Arousal}
)

n_hits, n_total = len(changes), len(original_df)
print(f"Number of hits: {n_hits} / {n_total}")
print(f"Hit ratio: {(n_hits / n_total) * 100}")
changes.head(n_hits)

Number of hits: 17345 / 17464
Hit ratio: 99.31859825927623


Unnamed: 0,user_id,item_id,Valence,Arousal,New Valence,New Arousal
0,6010bbc8e7ef4b21fa38f9c3a9754ef3,2,5,2,5.0,2.0
1,3c888e77b992ae3cd2adfe16774e23b9,2,2,3,2.0,3.0
2,2afd218c3aecb6828d2be327f8b9c46f,2,3,3,3.0,3.0
3,fd5b08ce362d855ca9152a894348130c,2,4,4,4.0,4.0
4,9c8073214a052e414811b76012df8847,2,2,2,2.0,2.0
...,...,...,...,...,...,...
17459,607f6e34a0b5923333f6b16d3a59cc98,2000,6,5,6.0,5.0
17460,78b5e9744073532cc376976b5fc6b2fc,2000,7,7,7.0,7.0
17461,7cecbffe1da5ae974952db6c13695afe,2000,4,5,4.0,5.0
17462,ed7ed76453bd846859f5e6b9149df276,2000,6,7,6.0,7.0
